diff --git a/.gitignore b/.gitignore index 057169ec42..f714acdefa 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,20 @@ mindspore/lib output *.ir +# flatbuffer +mindspore/lite/tools/converter/parser/tflite/schema_generated.h +mindspore/lite/tools/converter/parser/caffe/caffe.pb.cc +mindspore/lite/tools/converter/parser/caffe/caffe.pb.h +mindspore/lite/tools/converter/parser/onnx/onnx.pb.h +mindspore/lite/tools/converter/parser/onnx/onnx.pb.h +mindspore/lite/tools/converter/schema/*.h +mindspore/lite/tools/converter/schema/inner +mindspore/lite/schema/*.h +mindspore/lite/schema/inner + +mindspore/lite/src/runtime/kernel/opencl/cl/fp16/*.inc +mindspore/lite/src/runtime/kernel/opencl/cl/fp32/*.inc + # Cmake files CMakeFiles/ cmake_install.cmake @@ -27,6 +41,7 @@ cmake-build-debug *.pb.h *.pb.cc *.pb +*_grpc.py # Object files *.o @@ -71,5 +86,6 @@ test_temp_summary_event_file/ mindspore/version.py mindspore/default_config.py mindspore/.commit_id -onnx.proto -mindspore/ccsrc/onnx.proto + +# lite test file +mindspore/lite/test/do_test/ diff --git a/.gitmodules b/.gitmodules index 9eb6c53c34..fe94639740 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,7 @@ [submodule "third_party/flatbuffers"] path = third_party/flatbuffers url = https://github.com/google/flatbuffers.git + ignore = all [submodule "third_party/googletest"] path = third_party/googletest url = https://github.com/google/googletest.git @@ -10,9 +11,26 @@ [submodule "third_party/protobuf"] path = third_party/protobuf url = https://github.com/protocolbuffers/protobuf.git + ignore = all [submodule "akg"] path = akg url = https://gitee.com/mindspore/akg.git [submodule "graphengine"] path = graphengine url = https://gitee.com/ms-incubator/graphengine.git +[submodule "third_party/OpenCL-CLHPP"] + path = third_party/OpenCL-CLHPP + url = https://github.com/KhronosGroup/OpenCL-CLHPP.git +[submodule "third_party/OpenCL-Headers"] + path = third_party/OpenCL-Headers + url = https://github.com/KhronosGroup/OpenCL-Headers.git +[submodule "third_party/opencv"] + path = third_party/opencv + url = https://github.com/opencv/opencv.git +[submodule "third_party/eigen"] + path = third_party/eigen + url = https://gitlab.com/libeigen/eigen.git +[submodule "third_party/libjpeg-turbo"] + path = third_party/libjpeg-turbo + url = https://github.com/libjpeg-turbo/libjpeg-turbo.git + ignore = dirty diff --git a/CMakeLists.txt b/CMakeLists.txt index c4da105cac..3cebff4083 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,7 @@ if (ENABLE_PYTHON) add_compile_definitions(ENABLE_PYTHON) endif() -set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)=' -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") +set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -42,7 +42,7 @@ if (NOT Patch_FOUND) endif () message(PATCH_EXECUTABLE = ${Patch_EXECUTABLE}) -if (ENABLE_AKG AND ENABLE_D) +if (ENABLE_AKG AND (ENABLE_D OR ENABLE_GPU)) add_subdirectory("${CMAKE_SOURCE_DIR}/akg") endif() @@ -51,6 +51,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include/flatbuffers) +if (NOT ENABLE_ACL) + include(${CMAKE_SOURCE_DIR}/cmake/dependency_utils.cmake) find_package(Python3 3.7 COMPONENTS Interpreter Development) if(Python3_FOUND) @@ -100,8 +102,13 @@ if (ENABLE_TESTCASES) add_subdirectory(tests) endif() +endif() # NOT ENABLE_ACL + if (ENABLE_SERVING) add_subdirectory(serving) + add_subdirectory(serving/example/cpp_client) endif() +if (NOT ENABLE_ACL) include(cmake/package.cmake) +endif() # NOT ENABLE_ACL diff --git a/README.md b/README.md index 25abdd6fcb..50e21fae70 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ For installation using `pip`, take `CPU` and `Ubuntu-x86` build version as an ex 1. Download whl from [MindSpore download page](https://www.mindspore.cn/versions/en), and install the package. ``` - pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.5.0-beta/MindSpore/cpu/ubuntu_x86/mindspore-0.5.0-cp37-cp37m-linux_x86_64.whl + pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.6.0-beta/MindSpore/cpu/ubuntu_x86/mindspore-0.6.0-cp37-cp37m-linux_x86_64.whl ``` 2. Run the following command to verify the install. @@ -132,8 +132,8 @@ currently the containerized build options are supported as follows: For `CPU` backend, you can directly pull and run the latest stable image using the below command: ``` - docker pull mindspore/mindspore-cpu:0.5.0-beta - docker run -it mindspore/mindspore-cpu:0.5.0-beta /bin/bash + docker pull mindspore/mindspore-cpu:0.6.0-beta + docker run -it mindspore/mindspore-cpu:0.6.0-beta /bin/bash ``` * GPU @@ -150,8 +150,8 @@ currently the containerized build options are supported as follows: Then you can pull and run the latest stable image using the below command: ``` - docker pull mindspore/mindspore-gpu:0.5.0-beta - docker run -it --runtime=nvidia --privileged=true mindspore/mindspore-gpu:0.5.0-beta /bin/bash + docker pull mindspore/mindspore-gpu:0.6.0-beta + docker run -it --runtime=nvidia --privileged=true mindspore/mindspore-gpu:0.6.0-beta /bin/bash ``` To test if the docker image works, please execute the python code below and check the output: @@ -202,10 +202,10 @@ Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/comm ### Communication -- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers. +- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/zt-dgk65rli-3ex4xvS4wHX7UDmsQmfu8w) - Communication platform for developers. - IRC channel at `#mindspore` (only for meeting minutes logging purpose) -- Video Conferencing: https://meet.jit.si -- Mailing-list: https://mailweb.mindspore.cn/postorius/lists +- Video Conferencing: TBD +- Mailing-list: ## Contributing diff --git a/RELEASE.md b/RELEASE.md index def72cbb20..f954102731 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,74 @@ +# Release 0.6.0-beta + +## Major Features and Improvements +### Ascend 910 Training and Inference Framework +* New models + * There are official, research and community under modelzoo. + * Official is maintained with the newest APIs by MindSpore team, MaskRCNN are added. + * Research is uploaded by researchers for official review, and APIs may not be updated in time. + * Community reprints the relevant links of partner research results. + * Hub added on the same level as modelzoo, synchronous storage of materials needed for official hub web pages which will be launched soon. + * Support pre-trained models, few lines of code can be used to download and load pre-trained models, supporting inference or transfer learning. +* Frontend and user interface + * Supports user side operator compilation and graph execution error rendering. + * Uniform definition dynamic learning rate behavior in optimizers. + * Support IndexSlice in sparse expression. + * Support use parent construct method during construct. + * Support asynchronous execution save checkpoint file. + * Support implicit type conversion in pynative mode. + * User interfaces change log + * unform learning rate behavior in optimizers([!2755](https://gitee.com/mindspore/mindspore/pulls/2755)) + * rename operator of sparse optimizer([!3217](https://gitee.com/mindspore/mindspore/pulls/3217)) + * move profiler module from mindinsight to mindspore([!3075](https://gitee.com/mindspore/mindspore/pulls/3075)) + * VOCDataset output change to multi-columns([!3093](https://gitee.com/mindspore/mindspore/pulls/3093)) + * GetDatasize feature([!3212](https://gitee.com/mindspore/mindspore/pulls/3212)) + * dataset: modify config api([!2936](https://gitee.com/mindspore/mindspore/pulls/2936)) +* Executor and performance optimization + * Decouple C++ and python, so make the architecture more extensible. + * Parameter Server for distributed deep learning supported. + * Serving:a flexible service deployment framework for deep learning models. + * Memory reuse is enhanced, and the batch size of Bert large model is increased from 96 to 160 on a single server. +* Data processing, augmentation, and save format + * Support MindRecord save operator after date processing + * Support automatic fusion operator, such as decode/resize/crop + * Support CSV dataset loading +### Other Hardware Support +* GPU platform + * New model supported: ResNext50, WarpCTC and GoogLeNet. + * Support hyperparametric search and data enhanced automl on GPU. + * Support Resnet50 automatic parallel in GPU backend. + +## Bugfixes +* Models + * Improved the performance and accuracy on ResNet50([!3456](https://gitee.com/mindspore/mindspore/pulls/3456)) + * Fixed the performance test case of bert([!3486](https://gitee.com/mindspore/mindspore/pulls/3486)) +* Python API + * Fix assign used in while loop([!2720](https://gitee.com/mindspore/mindspore/pulls/2720)) + * Revert optimize the graph output of all nop node.([!2857](https://gitee.com/mindspore/mindspore/pulls/2857)) + * Print tensor as numpy.([!2859](https://gitee.com/mindspore/mindspore/pulls/2859)) + * Support weight decay for sparse optimizer([!2668](https://gitee.com/mindspore/mindspore/pulls/2668)) + * Fix BatchToSpaceND([!2741](https://gitee.com/mindspore/mindspore/pulls/2741)) + * Fixing type check mistakes of InplaceAdd and Inplace Sub ops([!2744](https://gitee.com/mindspore/mindspore/pulls/2744])) + * Change order param only equal to group param([!2748](https://gitee.com/mindspore/mindspore/pulls/2748)) +* Executor + * The performance of graph whith control flow is optimized([!2931](https://gitee.com/mindspore/mindspore/pulls/2931)) + * Fix bug of wrong number of tuple layers([!3390](https://gitee.com/mindspore/mindspore/pulls/3390)) + * Fix cpu multi graph memory exception([!3631](https://gitee.com/mindspore/mindspore/pulls/3631)) + * Enable data sync when calling operator without defining a cell([!3081](https://gitee.com/mindspore/mindspore/pulls/3081)) + * Fix argmaxwith value error in pynative mode on GPU([!3082](https://gitee.com/mindspore/mindspore/pulls/3082)) + * Fix precision error with fp16 input on pynative mode([!3196](https://gitee.com/mindspore/mindspore/pulls/3196)) +* Data processing + * Fix bug of RandomColor and RandomSharpness default parameter checking ([!2833](https://gitee.com/mindspore/mindspore/pulls/2833)) + * Fix process hung when training and eval ([!3469](https://gitee.com/mindspore/mindspore/pulls/3469)) + +## Contributors +Thanks goes to these wonderful people: + +Alexey Shevlyakov, avakh, baihuawei, BowenK, buxue, caifubi, caojian05, Cathy Wong, changzherui, chenfei, chengxianbin, chenhaozhe, chenjianping, chentingting, chenzomi, chujinjin, Danish Farid, dayschan, dengwentao, dinghao, etone-chan, fangzehua, fary86, geekun, Giancarlo Colmenares, gong chen, gukecai, guohongzilong, hangangqiang, heleiwang, hesham, He Wei, hexia, hongxing, huangdongrun, huanghui, islam_amin, Jamie Nisbet, Jesse Lee, jiangjinsheng, jiangzhiwen, jinyaohui, jjfeing, jojobugfree, Jonathan Yan, jonyguo, Junhan Hu, Kang, kingfo, kouzhenzhong, kpy, kswang, laiyongqiang, leopz, liangzelang, lichenever, lihongkang, Li Hongzhang, lilei, limingqi107, lirongzhen1, liubuyu, liuchongming74, liuwenhao4, liuxiao, Lixia Chen, liyanliu, liyong, lizhenyu, lvliang, Mahdi, Margaret_wangrui, meixiaowei, ms_yan, nhussain, ougongchang, panfengfeng, panyifeng, peilinwang, Peilin Wang, pkuliuliu, qianlong, rick_sanchez, shibeiji, Shida He, shijianning, simson, sunsuodong, suteng, Tinazhang, Tron Zhang, unknown, VectorSL, wandongdong, wangcong, wangdongxu, wangdongxu6, wanghua, wangnan39, Wei Luning, wenchunjiang, wenkai, wilfChen, WilliamLian, wukesong, Xian Weizhao, Xiaoda Zhang, xiefangqi, xulei2020, xunxue, xutianchun, Yang, yanghaitao, yanghaitao1, yanghaoran, yangjie, yangjie159, YangLuo, Yanjun Peng, yankai, yanzhenxiang2020, yao_yf, Yi Huaijie, yoonlee666, yuchaojie, yujianfeng, zhangzhongpeng, zhangdengcheng, Zhang Qinghua, zhangyinxia, zhangz0911gm, zhaojichen, zhaoting, zhaozhenlong, zhoufeng, zhouneng, zhousiyi, Zirui Wu, Ziyan, zjun, ZPaC, lihongzhang, wangdongxu + +Contributions of any kind are welcome! + + # Release 0.5.0-beta ## Major Features and Improvements diff --git a/akg b/akg index df57a6cf94..5fe7e5c837 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35 +Subproject commit 5fe7e5c8377dccfd35c9f661e10ed3dc136208c5 diff --git a/build.sh b/build.sh index cfa657ff3e..23d26e4eab 100755 --- a/build.sh +++ b/build.sh @@ -24,8 +24,9 @@ usage() { echo "Usage:" echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" - echo " [-a on|off] [-Q on|off] [-S on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" - echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off] [-E] [-l on|off]" + echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" + echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\" + echo " [-B on|off] [-w on|off] [-E] [-l on|off] [-n]" echo "" echo "Options:" echo " -d Debug mode" @@ -48,13 +49,14 @@ usage() echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" echo " -Q Enable dump memory, default off" echo " -D Enable dumping of function graph ir, default on" - echo " -S Enable async data dump, default off" echo " -z Compile dataset & mindrecord, default on" + echo " -n Compile minddata lite" echo " -M Enable MPI and NCCL for GPU training, gpu default on" echo " -V Specify the minimum required cuda version, default CUDA 10.1" - echo " -I Compile predict, default off" + echo " -I Compile lite" echo " -K Compile with AKG, default on" echo " -s Enable serving module, default off" + echo " -w Enable acl module, default off" echo " -B Enable debugger, default off" echo " -E Enable IBVERBS for parameter server, default off" echo " -l Compile with python dependency, default on" @@ -89,28 +91,34 @@ checkopts() ENABLE_TIMELINE="off" ENABLE_DUMP2PROTO="on" ENABLE_DUMPE2E="off" - ENABLE_DATA_DUMP="off" ENABLE_DUMP_IR="on" COMPILE_MINDDATA="on" + COMPILE_MINDDATA_LITE="off" ENABLE_MPI="off" CUDA_VERSION="10.1" - COMPILE_PREDICT="off" + COMPILE_LITE="off" + LITE_PLATFORM="" + SUPPORT_TRAIN="off" USE_GLOG="on" - PREDICT_PLATFORM="" ENABLE_AKG="on" ENABLE_SERVING="off" + ENABLE_ACL="off" ENABLE_DEBUGGER="off" ENABLE_IBVERBS="off" ENABLE_PYTHON="on" + ENABLE_GPU="off" # Process the options - while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:S:D:zM:V:K:sB:E' opt + while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:En' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in d) DEBUG_MODE="on" ;; + n) + COMPILE_MINDDATA_LITE="on" + ;; r) DEBUG_MODE="off" ;; @@ -186,6 +194,7 @@ checkopts() elif [[ "X$OPTARG" == "Xd" || "X$OPTARG" == "Xascend" ]]; then ENABLE_D="on" ENABLE_CPU="on" + ENABLE_SERVING="on" elif [[ "X$OPTARG" == "Xcpu" ]]; then ENABLE_CPU="on" else @@ -220,11 +229,6 @@ checkopts() ENABLE_DUMPE2E="$OPTARG" echo "enable dump end to end" ;; - S) - check_on_off $OPTARG S - ENABLE_DATA_DUMP="$OPTARG" - echo "enable data dump" - ;; D) check_on_off $OPTARG D ENABLE_DUMP_IR="$OPTARG" @@ -244,13 +248,16 @@ checkopts() fi ;; I) - COMPILE_PREDICT="on" + COMPILE_LITE="on" if [[ "$OPTARG" == "arm64" ]]; then - PREDICT_PLATFORM="arm64" + LITE_PLATFORM="arm64" + elif [[ "$OPTARG" == "arm32" ]]; then + LITE_PLATFORM="arm32" elif [[ "$OPTARG" == "x86_64" ]]; then - PREDICT_PLATFORM="x86_64" + ENABLE_CONVERTER="on" + LITE_PLATFORM="x86_64" else - echo "-I parameter must be arm64 or x86_64" + echo "-I parameter must be arm64、arm32 or x86_64" exit 1 fi ;; @@ -262,6 +269,10 @@ checkopts() ENABLE_SERVING="on" echo "enable serving" ;; + w) + ENABLE_ACL="on" + echo "enable acl" + ;; B) check_on_off $OPTARG B ENABLE_DEBUGGER="on" @@ -279,10 +290,13 @@ checkopts() done } checkopts "$@" +if [[ "X$ENABLE_GPU" = "Xon" ]] && [[ "X$ENABLE_DUMPE2E" = "Xon" ]]; then + ENABLE_DEBUGGER="on" +fi echo "---------------- MindSpore: build start ----------------" mkdir -pv "${BUILD_PATH}/package/mindspore/lib" git submodule update --init graphengine -if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then +if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = "Xon" ]]; then git submodule update --init --recursive akg fi @@ -328,9 +342,6 @@ build_mindspore() if [[ "X$ENABLE_DUMPE2E" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_E2E=ON" fi - if [[ "X$ENABLE_DATA_DUMP" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DATA_DUMP=ON" - fi CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_IR=${ENABLE_DUMP_IR}" CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PYTHON=${ENABLE_PYTHON}" if [[ "X$ENABLE_MPI" = "Xon" ]]; then @@ -340,7 +351,7 @@ build_mindspore() CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_D=ON" fi if [[ "X$ENABLE_GPU" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DCUDA_PATH=$CUDA_PATH -DCUDNN_PATH=$CUDNN_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}" + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DUSE_CUDA=ON -DCUDA_PATH=$CUDA_PATH -DCUDNN_PATH=$CUDNN_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}" fi if [[ "X$ENABLE_CPU" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON" @@ -351,12 +362,15 @@ build_mindspore() if [[ "X$USE_GLOG" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON" fi - if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then + if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" fi if [[ "X$ENABLE_SERVING" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON" fi + if [[ "X$ENABLE_ACL" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ACL=ON" + fi if [[ "X$ENABLE_DEBUGGER" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DEBUGGER=ON" fi @@ -371,132 +385,319 @@ build_mindspore() if [[ -n "$VERBOSE" ]]; then CMAKE_VERBOSE="--verbose" fi + if [[ "X$ENABLE_ACL" = "Xon" ]]; then + cmake --build . ${CMAKE_VERBOSE} -j$THREAD_NUM + else cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM + fi echo "success to build mindspore project!" } -build_predict() -{ - git submodule update --init --recursive third_party/incubator-tvm - echo "start build predict project" - - git submodule update --init --recursive third_party/flatbuffers - git submodule update --init --recursive third_party/googletest - git submodule update --init --recursive third_party/protobuf - - rm -rf "${BASEPATH}/predict/build" - mkdir -pv "${BASEPATH}/predict/build" - rm -rf "${BASEPATH}/predict/output" - mkdir -pv "${BASEPATH}/predict/output" - - if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - if [ "${ANDROID_NDK}" ]; then - echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m" - else - echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r16b/ \e[0m" - exit 1 - fi +checkndk() { + if [ "${ANDROID_NDK}" ]; then + echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m" + else + echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m" + exit 1 fi +} - #build flatbuf - cd "${BASEPATH}/third_party/flatbuffers" - rm -rf build && mkdir -p build && cd build && cmake .. && make -j$THREAD_NUM - FLATC="${BASEPATH}"/third_party/flatbuffers/build/flatc - cd "${BASEPATH}"/predict/schema && mkdir -p "${BASEPATH}"/predict/schema/inner +gene_flatbuffer() { + FLAT_DIR="${BASEPATH}/mindspore/lite/schema" + cd ${FLAT_DIR} && rm -rf "${FLAT_DIR}/inner" && mkdir -p "${FLAT_DIR}/inner" find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b - find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o ${BASEPATH}/predict/schema/inner - - # check LLVM_PATH - if [ "${LLVM_PATH}" == "" ]; then - echo "Please set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config" - exit - fi - - #build tvm - tvm_open_source="${BASEPATH}/third_party/incubator-tvm" - tvm_kernel_build="${BASEPATH}/predict/module/tvm_kernel" - if [ ! -f "${tvm_kernel_build}"/incubator-tvm/build/libtvm.so ]; then - rm -fr "${tvm_kernel_build}"/incubator-tvm - cp -fr "${tvm_open_source}" "${tvm_kernel_build}" - mkdir -p "${tvm_kernel_build}"/incubator-tvm/build - patch -d "${tvm_kernel_build}"/incubator-tvm -p1 < "${BASEPATH}"/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch - cp "${tvm_kernel_build}"/lite/src/codegen/llvm/lite_rtfunc_reset.cc "${tvm_kernel_build}"/incubator-tvm/src/codegen/llvm/ - cp "${tvm_open_source}"/cmake/config.cmake "${tvm_kernel_build}"/incubator-tvm - if [ "${LLVM_PATH}" ]; then - sed -i "s#set(USE_LLVM .*)#set(USE_LLVM \"${LLVM_PATH}\")#g" "${tvm_kernel_build}"/incubator-tvm/config.cmake - else - echo "need set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config" + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o "${FLAT_DIR}/inner" + + FLAT_DIR="${BASEPATH}/mindspore/lite/tools/converter/parser/tflite" + cd ${FLAT_DIR} + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o "${FLAT_DIR}/" +} + +build_flatbuffer() { + cd ${BASEPATH} + FLATC="${BASEPATH}"/third_party/flatbuffers/build/flatc + if [[ ! -f "${FLATC}" ]]; then + git submodule update --init --recursive third_party/flatbuffers + cd ${BASEPATH}/third_party/flatbuffers + rm -rf build && mkdir -pv build && cd build && cmake .. && make -j$THREAD_NUM + gene_flatbuffer + fi + if [[ "${INC_BUILD}" == "off" ]]; then + gene_flatbuffer + fi +} + +gene_protobuf() { + PROTO_SRC_DIR="${BASEPATH}/mindspore/lite/tools/converter/parser/caffe" + find ${PROTO_SRC_DIR} -name "*.proto" -print0 | xargs -0 "${PROTOC}" -I"${PROTO_SRC_DIR}" --cpp_out="${PROTO_SRC_DIR}" + PROTO_SRC_DIR="${BASEPATH}/mindspore/lite/tools/converter/parser/onnx" + find ${PROTO_SRC_DIR} -name "*.proto" -print0 | xargs -0 "${PROTOC}" -I"${PROTO_SRC_DIR}" --cpp_out="${PROTO_SRC_DIR}" +} + +build_protobuf() { + cd ${BASEPATH} + PROTOC="${BASEPATH}"/third_party/protobuf/build/bin/protoc + if [[ ! -f "${PROTOC}" ]]; then + git submodule update --init --recursive third_party/protobuf + cd ${BASEPATH}/third_party/protobuf + rm -rf build && mkdir -pv build && ./autogen.sh + ./configure --prefix=${BASEPATH}/third_party/protobuf/build + make clean && make -j$THREAD_NUM && make install + gene_protobuf + fi + if [[ "${INC_BUILD}" == "off" ]]; then + gene_protobuf + fi +} + +build_gtest() { + cd ${BASEPATH} + git submodule update --init --recursive third_party/googletest +} + +gene_clhpp() { + CL_SRC_DIR="${BASEPATH}/mindspore/lite/src/runtime/kernel/opencl/cl" + for sub_dir in "${CL_SRC_DIR}"/* + do + data_type="$(basename ${sub_dir})" + if [ ! -d ${CL_SRC_DIR}/${data_type} ]; then + continue fi - cd "${tvm_kernel_build}"/incubator-tvm/build - cmake .. - make -j$THREAD_NUM - else - cd "${tvm_kernel_build}"/incubator-tvm/build - make -j$THREAD_NUM - fi - - #gen op - predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_x86" - predict_platform="x86" - if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_arm64" - predict_platform="arm64" - fi - - need_get_libs=true - if [ -d "${predict_tvm_op_lib_path}" ]; then - file_list=$(ls "${predict_tvm_op_lib_path}") - if [ -n "${file_list}" ]; then - libstime=$(stat -c %Y "${predict_tvm_op_lib_path}"/* | sort -u | tail -n1) - pythontime=$(find "${BASEPATH}"/predict/module/tvm_kernel/lite/python/ -name "*.py" -exec stat -c %Y {} \; | - sort -u | tail -n1) - if [ "${libstime}" -ge "${pythontime}" ]; then - need_get_libs=false - else - rm -fr "${predict_tvm_op_lib_path}" + cd ${CL_SRC_DIR}/${data_type} + rm -rf *.inc + echo "$(cd "$(dirname $0)"; pwd)" + for file_path in "${CL_SRC_DIR}/${data_type}"/* + do + file="$(basename ${file_path})" + inc_file=`echo ${CL_SRC_DIR}/${data_type}/${file} | sed 's/$/.inc/'` + sed 's/^/\"/;s/$/ \\n\" \\/' ${CL_SRC_DIR}/${data_type}/${file} > ${inc_file} + kernel_name=`echo ${file} | sed s'/.\{3\}$//'` + sed -i "1i\static const char *${kernel_name}_source_${data_type} =\"\\n\" \\" ${inc_file} + sed -i '$a\;' ${inc_file} + done + done +} + +gene_ocl_program() { + CL_SRC_DIR="${BASEPATH}/mindspore/lite/src/runtime/kernel/opencl/cl" + SPIRV_DIR=build/spirv + rm -rf ${SPIRV_DIR} + mkdir -pv ${SPIRV_DIR} + for sub_dir in "${CL_SRC_DIR}"/* + do + data_type="$(basename ${sub_dir})" + if [ ! -d ${CL_SRC_DIR}/${data_type} ]; then + continue fi - fi + #echo $(cd "$(dirname $0)"; pwd) + for file_path in "${CL_SRC_DIR}/${data_type}"/* + do + file="$(basename ${file_path})" + if [ "${file##*.}" != "cl" ]; then + continue + fi + clang -Xclang -finclude-default-header -cl-std=CL2.0 --target=spir64-unknown-unknown -emit-llvm \ + -c -O0 -o ${SPIRV_DIR}/${file%.*}.bc ${CL_SRC_DIR}/${data_type}/${file} + done + done + + bcs=`ls ${SPIRV_DIR}/*.bc` + llvm-link ${bcs} -o ${SPIRV_DIR}/program.bc + llvm-spirv -o ${SPIRV_DIR}/program.spv ${SPIRV_DIR}/program.bc + + CL_PROGRAM_PATH="${BASEPATH}/mindspore/lite/src/runtime/kernel/opencl/cl/program.inc" + echo "#include " > ${CL_PROGRAM_PATH} + echo "std::vector g_program_binary = {" >> ${CL_PROGRAM_PATH} + #hexdump -v -e '16/1 "0x%02x, " "\n"' ${SPIRV_DIR}/program.spv >> ${CL_PROGRAM_PATH} + hexdump -v -e '1/1 "0x%02x, "' ${SPIRV_DIR}/program.spv >> ${CL_PROGRAM_PATH} + echo "};" >> ${CL_PROGRAM_PATH} + echo "Compile SPIRV done" +} + +build_opencl() { + cd ${BASEPATH} + git submodule update --init third_party/OpenCL-Headers + git submodule update --init third_party/OpenCL-CLHPP + if [[ "${OPENCL_OFFLINE_COMPILE}" == "on" ]]; then + gene_ocl_program + else + gene_clhpp + fi +} + +build_opencv() { + cd ${BASEPATH} + if [[ "${INC_BUILD}" == "off" ]]; then + git submodule update --init --recursive third_party/opencv + cd ${BASEPATH}/third_party/opencv + rm -rf build && mkdir -p build && cd build && cmake ${CMAKE_MINDDATA_ARGS} -DBUILD_SHARED_LIBS=ON -DBUILD_ANDROID_PROJECTS=OFF \ + -DBUILD_LIST=core,imgcodecs,imgproc -DBUILD_ZLIB=ON .. && make -j$THREAD_NUM + fi +} + +build_jpeg_turbo() { + cd ${BASEPATH} + if [[ "${INC_BUILD}" == "off" ]]; then + git submodule update --init --recursive third_party/libjpeg-turbo + cd ${BASEPATH}/third_party/libjpeg-turbo + rm -rf build && mkdir -p build && cd build && cmake ${CMAKE_MINDDATA_ARGS} -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="${BASEPATH}/third_party/libjpeg-turbo" .. && make -j$THREAD_NUM && make install fi +} + +build_eigen() { + cd ${BASEPATH} + git submodule update --init --recursive third_party/eigen +} + +build_minddata_lite_deps() +{ + echo "start build minddata lite project" + if [[ "${LITE_PLATFORM}" == "arm64" ]]; then + CMAKE_MINDDATA_ARGS="-DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake -DANDROID_NATIVE_API_LEVEL=19 \ + -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang \ + -DANDROID_STL=c++_shared -DCMAKE_BUILD_TYPE=${BUILD_TYPE}" + elif [[ "${LITE_PLATFORM}" == "arm32" ]]; then + CMAKE_MINDDATA_ARGS="-DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake -DANDROID_NATIVE_API_LEVEL=19 \ + -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=armeabi-v7a -DANDROID_TOOLCHAIN_NAME=clang \ + -DANDROID_STL=c++_shared -DCMAKE_BUILD_TYPE=${BUILD_TYPE}" + else + CMAKE_MINDDATA_ARGS="-DCMAKE_BUILD_TYPE=${BUILD_TYPE}" + fi + build_opencv + build_eigen + build_jpeg_turbo +} - if $need_get_libs; then - PYTHONPATH_OLD=${PYTHONPATH} - export PYTHONPATH="${tvm_kernel_build}/incubator-tvm/python:${tvm_kernel_build}/incubator-tvm/topi/python:${tvm_kernel_build}/incubator-tvm/nnvm/python:${tvm_kernel_build}/lite/python:" - cd "${BASEPATH}"/predict/module/tvm_kernel/lite/python/at_ops - python3 at_gen_strip.py ${predict_platform} - export PYTHONPATH=${PYTHONPATH_OLD} +build_lite() +{ + echo "start build mindspore lite project" + + if [ "${ENABLE_GPU}" == "on" ] && [ "${LITE_PLATFORM}" == "arm64" ]; then + echo "start build opencl" + build_opencl fi + if [[ "${LITE_PLATFORM}" == "x86_64" ]]; then + build_protobuf + fi + build_flatbuffer + build_gtest - cd "${BASEPATH}/predict/build" - if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ - -DANDROID_NATIVE_API_LEVEL=android-19 -DANDROID_NDK="${ANDROID_NDK}" \ - -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" -DANDROID_STL="c++_shared" \ - -DANDROID_ABI="arm64-v8a" -DENABLE_PREDICT_ARM64=ON -DANDROID_ALLOW_UNDEFINED_SYMBOLS=TRUE .. - elif [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then - cmake .. + if [ "${COMPILE_MINDDATA_LITE}" == "on" ]; then + build_minddata_lite_deps fi - make ${VERBOSE} -j$THREAD_NUM - if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then - cd "${BASEPATH}/predict/build/test" && ./run_tests.sh + cd "${BASEPATH}/mindspore/lite" + if [[ "${INC_BUILD}" == "off" ]]; then + rm -rf build + fi + mkdir -pv build + cd build + BUILD_TYPE="Release" + if [[ "${DEBUG_MODE}" == "on" ]]; then + BUILD_TYPE="Debug" fi - # copy securec include files - mkdir -p "${BASEPATH}/predict/output/include/securec/include" - cp "${BASEPATH}"/third_party/securec/include/* "${BASEPATH}"/predict/output/include/securec/include + if [[ "${LITE_PLATFORM}" == "arm64" ]]; then + checkndk + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ + -DANDROID_STL="c++_shared" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DBUILD_DEVICE=on -DPLATFORM_ARM64=on -DBUILD_CONVERTER=off -DENABLE_NEON=on -DENABLE_FP16="off" \ + -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ + "${BASEPATH}/mindspore/lite" + elif [[ "${LITE_PLATFORM}" == "arm32" ]]; then + checkndk + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ + -DANDROID_STL="c++_shared" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DBUILD_DEVICE=on -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DBUILD_CONVERTER=off \ + -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ + "${BASEPATH}/mindspore/lite" + else + cmake -DBUILD_DEVICE=on -DPLATFORM_ARM64=off -DBUILD_CONVERTER=${ENABLE_CONVERTER} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ + -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} "${BASEPATH}/mindspore/lite" + fi + VERBOSE=2 make -j$THREAD_NUM + COMPILE_RET=$? - cd "${BASEPATH}/predict/output/" - if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then - tar -cf MSPredict-0.5.0-linux_x86_64.tar.gz include/ lib/ --warning=no-file-changed - elif [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - tar -cf MSPredict-0.5.0-linux_aarch64.tar.gz include/ lib/ --warning=no-file-changed + if [[ "${COMPILE_RET}" -ne 0 ]]; then + echo "---------------- mindspore lite: build failed ----------------" + else + mkdir -pv ${BASEPATH}/mindspore/lite/output/ + if [[ "$LITE_PLATFORM" == "x86_64" ]]; then + OUTPUT_DIR=${BASEPATH}/output/MSLite-0.6.0-linux_x86_64 + rm -rf ${OUTPUT_DIR} && mkdir -p ${OUTPUT_DIR} && cd ${OUTPUT_DIR} + mkdir -p ${OUTPUT_DIR}/converter && mkdir -p ${OUTPUT_DIR}/time_profile + mkdir -p ${OUTPUT_DIR}/benchmark && mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib + mkdir -p ${OUTPUT_DIR}/third_party + cp ${BASEPATH}/mindspore/lite/build/tools/converter/converter_lite ${OUTPUT_DIR}/converter/ + cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/ + cp ${BASEPATH}/mindspore/lite/build/tools/time_profile/timeprofile ${OUTPUT_DIR}/time_profile/ + cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/ + mkdir -p ${OUTPUT_DIR}/include/ir/dtype/ + cp ${BASEPATH}/mindspore/core/ir/dtype/type_id.h ${OUTPUT_DIR}/include/ir/dtype/ + mkdir -p ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/schema/*.h ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/protobuf/lib + cp -r ${BASEPATH}/third_party/protobuf/build/include/ ${OUTPUT_DIR}/third_party/protobuf/ + cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19 ${OUTPUT_DIR}/third_party/protobuf/lib/ + cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 ${OUTPUT_DIR}/third_party/protobuf/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers + cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/ + cd .. + tar -czf MSLite-0.6.0-linux_x86_64.tar.gz MSLite-0.6.0-linux_x86_64/ --warning=no-file-changed + sha256sum MSLite-0.6.0-linux_x86_64.tar.gz > MSLite-0.6.0-linux_x86_64.tar.gz.sha256 + rm -rf MSLite-0.6.0-linux_x86_64/ + elif [[ "$LITE_PLATFORM" == "arm64" ]]; then + OUTPUT_DIR=${BASEPATH}/output/MSLite-0.6.0-linux_arm64 + rm -rf ${OUTPUT_DIR} && mkdir -p ${OUTPUT_DIR} && cd ${OUTPUT_DIR} + mkdir -p ${OUTPUT_DIR}/time_profile && mkdir -p ${OUTPUT_DIR}/benchmark + mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib + mkdir -p ${OUTPUT_DIR}/third_party + cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/ + cp ${BASEPATH}/mindspore/lite/build/tools/time_profile/timeprofile ${OUTPUT_DIR}/time_profile/ + cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/ + mkdir -p ${OUTPUT_DIR}/include/ir/dtype/ + cp ${BASEPATH}/mindspore/core/ir/dtype/type_id.h ${OUTPUT_DIR}/include/ir/dtype/ + mkdir -p ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/schema/*.h ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers + cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/ + cd .. + tar -czf MSLite-0.6.0-linux_arm64.tar.gz MSLite-0.6.0-linux_arm64/ --warning=no-file-changed + sha256sum MSLite-0.6.0-linux_arm64.tar.gz > MSLite-0.6.0-linux_arm64.tar.gz.sha256 + rm -rf MSLite-0.6.0-linux_arm64/ + elif [[ "$LITE_PLATFORM" == "arm32" ]]; then + OUTPUT_DIR=${BASEPATH}/output/MSLite-0.6.0-linux_arm32 + rm -rf ${OUTPUT_DIR} && mkdir -p ${OUTPUT_DIR} && cd ${OUTPUT_DIR} + mkdir -p ${OUTPUT_DIR}/time_profile && mkdir -p ${OUTPUT_DIR}/benchmark + mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib + mkdir -p ${OUTPUT_DIR}/third_party + cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/ + cp ${BASEPATH}/mindspore/lite/build/tools/time_profile/timeprofile ${OUTPUT_DIR}/time_profile/ + cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/ + mkdir -p ${OUTPUT_DIR}/include/ir/dtype/ + cp ${BASEPATH}/mindspore/core/ir/dtype/type_id.h ${OUTPUT_DIR}/include/ir/dtype/ + mkdir -p ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/schema/*.h ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers + cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/ + cd .. + tar -czf MSLite-0.6.0-linux_arm32.tar.gz MSLite-0.6.0-linux_arm32/ --warning=no-file-changed + sha256sum MSLite-0.6.0-linux_arm32.tar.gz > MSLite-0.6.0-linux_arm32.tar.gz.sha256 + rm -rf MSLite-0.6.0-linux_arm32/ + fi + echo "---------------- mindspore lite: build success ----------------" fi - echo "success to build predict project!" } -if [[ "X$COMPILE_PREDICT" = "Xon" ]]; then - build_predict - echo "---------------- mindspore: build end ----------------" +if [[ "X$COMPILE_LITE" = "Xon" ]]; then + build_lite exit else build_mindspore diff --git a/cmake/external_libs/dlpack.cmake b/cmake/external_libs/dlpack.cmake deleted file mode 100644 index a2375c7d35..0000000000 --- a/cmake/external_libs/dlpack.cmake +++ /dev/null @@ -1,7 +0,0 @@ -mindspore_add_pkg(dlpack - VER 0.2 - HEAD_ONLY ./ - URL https://github.com/dmlc/dlpack/archive/0acb731e0e43d15deee27b66f10e4c5b4e667913.zip - MD5 6b8093f17ad4e830d3c63eb3171c4b45) - - diff --git a/cmake/external_libs/dmlc_core.cmake b/cmake/external_libs/dmlc_core.cmake deleted file mode 100644 index e07df83fd6..0000000000 --- a/cmake/external_libs/dmlc_core.cmake +++ /dev/null @@ -1,7 +0,0 @@ -mindspore_add_pkg(dmlc-core - VER 0.3 - HEAD_ONLY ./ - URL https://github.com/dmlc/dmlc-core/archive/808f485387f9a03f78fa9f1159f387d0d91b7a28.zip - MD5 ea36f94c57752bf40fb02dfc362f1ed9) - - diff --git a/cmake/external_libs/glog.cmake b/cmake/external_libs/glog.cmake index d7942a4efd..f372c8e3c2 100644 --- a/cmake/external_libs/glog.cmake +++ b/cmake/external_libs/glog.cmake @@ -1,4 +1,4 @@ -set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS}") +set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") mindspore_add_pkg(glog VER 0.4.0 diff --git a/cmake/external_libs/jpeg_turbo.cmake b/cmake/external_libs/jpeg_turbo.cmake index 6c2c70c709..b37089bfa1 100644 --- a/cmake/external_libs/jpeg_turbo.cmake +++ b/cmake/external_libs/jpeg_turbo.cmake @@ -12,6 +12,7 @@ mindspore_add_pkg(jpeg_turbo URL https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.4.tar.gz MD5 44c43e4a9fb352f47090804529317c88 CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DCMAKE_SKIP_RPATH=TRUE + PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/jpeg_turbo/jpeg_turbo.patch001 ) include_directories(${jpeg_turbo_INC}) add_library(mindspore::jpeg_turbo ALIAS jpeg_turbo::jpeg) diff --git a/cmake/external_libs/mkl_dnn.cmake b/cmake/external_libs/mkl_dnn.cmake index 85a3132ba1..2c8a17c4e7 100644 --- a/cmake/external_libs/mkl_dnn.cmake +++ b/cmake/external_libs/mkl_dnn.cmake @@ -2,18 +2,18 @@ set(onednn_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(onednn_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") if (CMAKE_SYSTEM_NAME MATCHES "Windows") mindspore_add_pkg(onednn - VER 1.1.1 + VER 1.5 LIBS dnnl mkldnn HEAD_ONLY ./include RELEASE on - URL https://github.com/oneapi-src/oneDNN/releases/download/v1.1.1/dnnl_win_1.1.1_cpu_vcomp.zip - MD5 ecaab9ed549643067699c80e5cea1c23) + URL https://github.com/oneapi-src/oneDNN/releases/download/v1.5/dnnl_win_1.5.0_cpu_vcomp.zip + MD5 17757c84f49edd42d34ae8c9288110a1) else() mindspore_add_pkg(onednn - VER 1.1.2 + VER 1.5 LIBS dnnl mkldnn - URL https://github.com/oneapi-src/oneDNN/archive/v1.1.2.tar.gz - MD5 ab40d52230f3ad1d7a6f06ce0f6bc17a + URL https://github.com/oneapi-src/oneDNN/archive/v1.5.tar.gz + MD5 5d97e0e8f4c0b37da5f524533b7a644b CMAKE_OPTION -DDNNL_ARCH_OPT_FLAGS='' -DDNNL_CPU_RUNTIME='SEQ' -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF) endif() diff --git a/cmake/external_libs/rang.cmake b/cmake/external_libs/rang.cmake deleted file mode 100644 index 45ea375cb5..0000000000 --- a/cmake/external_libs/rang.cmake +++ /dev/null @@ -1,7 +0,0 @@ -mindspore_add_pkg(rang - VER 3.1.0 - HEAD_ONLY ./ - URL https://github.com/agauniyal/rang/archive/cabe04d6d6b05356fa8f9741704924788f0dd762.zip - MD5 0c5c9b251fea9ee7ce32f188655be0ea) - - diff --git a/cmake/external_libs/sentencepiece.cmake b/cmake/external_libs/sentencepiece.cmake new file mode 100644 index 0000000000..5d74e2015f --- /dev/null +++ b/cmake/external_libs/sentencepiece.cmake @@ -0,0 +1,25 @@ +if (WIN32) + set(sentencepiece_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -Wno-unused-result -Wno-stringop-overflow -Wno-format-extra-args -Wno-format") + set(sentencepiece_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") + mindspore_add_pkg(sentencepiece + VER 0.1.92 + LIBS sentencepiece sentencepiece_train + URL https://github.com/google/sentencepiece/archive/v0.1.92.tar.gz + CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DSPM_USE_BUILTIN_PROTOBUF=ON + MD5 5dfd2241914b5598a68b2a8542ed8e91 + ) +else () + set(sentencepiece_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -Wno-unused-result -Wno-sign-compare") + set(sentencepiece_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") + mindspore_add_pkg(sentencepiece + VER 0.1.92 + LIBS sentencepiece sentencepiece_train + URL https://github.com/google/sentencepiece/archive/v0.1.92.tar.gz + CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DSPM_USE_BUILTIN_PROTOBUF=OFF -DSPM_ENABLE_SHARED=OFF -DPROTOBUF_INC=${protobuf_INC} + MD5 5dfd2241914b5598a68b2a8542ed8e91 + PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sentencepiece/sentencepiece.patch001 + ) +endif () +include_directories(${sentencepiece_INC}) +add_library(mindspore::sentencepiece ALIAS sentencepiece::sentencepiece) +add_library(mindspore::sentencepiece_train ALIAS sentencepiece::sentencepiece_train) \ No newline at end of file diff --git a/cmake/external_libs/tvm_gpu.cmake b/cmake/external_libs/tvm_gpu.cmake deleted file mode 100644 index 834e2d159d..0000000000 --- a/cmake/external_libs/tvm_gpu.cmake +++ /dev/null @@ -1,15 +0,0 @@ -set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") -set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") -mindspore_add_pkg(incubator_tvm_gpu - VER 0.6.0 - LIBS tvm - URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz - MD5 9cbbd32545a776023acabbba270449fe - CUSTOM_CMAKE ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/ - SUBMODULES ${dlpack_DIRPATH} ${dmlc-core_DIRPATH} ${rang_DIRPATH} - SOURCEMODULES topi/python/topi python/tvm - PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/find_library.patch - ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch - ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch - CMAKE_OPTION " ") -add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) \ No newline at end of file diff --git a/cmake/mind_expression.cmake b/cmake/mind_expression.cmake index 9002c23976..8e1e9ce553 100644 --- a/cmake/mind_expression.cmake +++ b/cmake/mind_expression.cmake @@ -15,7 +15,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake) include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) -if (ENABLE_DEBUGGER OR ENABLE_SERVING) +if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES) # build dependencies of gRPC include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake) @@ -30,7 +30,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake) if(USE_GLOG) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake) endif() -if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows" AND NOT ENABLE_GE) +if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zeromq.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pslite.cmake) endif() @@ -38,19 +38,15 @@ endif() find_package(Python3) include_directories(${Python3_INCLUDE_DIRS}) include_directories(${CMAKE_SOURCE_DIR}/third_party) +if (ENABLE_MPI) + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake) +endif() + if (ENABLE_CPU) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake) - if (ENABLE_MPI) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake) - endif() endif() if (ENABLE_GPU) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/dlpack.cmake) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/dmlc_core.cmake) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/rang.cmake) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tvm_gpu.cmake) - if (ENABLE_MPI) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake) endif() @@ -69,12 +65,16 @@ endif() if (ENABLE_MINDDATA) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/icu4c.cmake) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sqlite.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tinyxml2.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cppjieba.cmake) + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sentencepiece.cmake) +endif() + +if (ENABLE_MINDDATA OR ENABLE_SERVING) + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake) endif() include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) diff --git a/cmake/options.cmake b/cmake/options.cmake index 2470c25a90..7971d9e342 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -70,6 +70,10 @@ if (ENABLE_GPU) add_compile_definitions(ENABLE_GPU_COLLECTIVE) endif() +if (ENABLE_CPU) + add_compile_definitions(ENABLE_CPU) +endif() + if (ENABLE_GE) add_compile_definitions(ENABLE_GE) add_compile_definitions(CUSTOM_OP) @@ -116,10 +120,6 @@ if(ENABLE_DUMP_E2E) add_compile_definitions(ENABLE_DUMP_E2E) endif() -if(ENABLE_DATA_DUMP) - add_compile_definitions(ENABLE_DATA_DUMP) -endif() - if(ENABLE_DEBUGGER) add_compile_definitions(ENABLE_DEBUGGER) endif() diff --git a/cmake/package.cmake b/cmake/package.cmake index 7b3c2f7bb2..edc88c2b34 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -40,6 +40,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows") set(jpeg_turbo_LIBPATH ${jpeg_turbo_LIBPATH}/../bin/) set(sqlite_LIBPATH ${sqlite_LIBPATH}/../bin/) set(tinyxml2_LIBPATH ${tinyxml2_LIBPATH}/../bin/) + set(sentencepiece_LIBPATH ${sentencepiece_LIBPATH}/../bin/) else () set(INSTALL_LIB_DIR "lib") endif () @@ -91,6 +92,14 @@ if (ENABLE_MINDDATA) DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore ) + file(GLOB_RECURSE SENTENCEPIECE_LIB_LIST + ${sentencepiece_LIBPATH}/libsentencepiece* + ) + install( + FILES ${SENTENCEPIECE_LIB_LIST} + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) if (CMAKE_SYSTEM_NAME MATCHES "Windows") message("icu4c does not support windows system temporarily") else() @@ -123,24 +132,39 @@ if (ENABLE_CPU) endif () if (ENABLE_MPI) - install( - TARGETS _ms_mpi - DESTINATION ${INSTALL_BASE_DIR} - COMPONENT mindspore + if (ENABLE_GPU) + install( + TARGETS _ms_mpi + DESTINATION ${INSTALL_BASE_DIR} + COMPONENT mindspore + ) + endif () + if (ENABLE_CPU) + install( + TARGETS mpi_adapter + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) + endif () + file(GLOB_RECURSE MPI_LIB_LIST + ${ompi_LIBPATH}/libmpi${CMAKE_SHARED_LIBRARY_SUFFIX}* + ${ompi_LIBPATH}/libopen*${CMAKE_SHARED_LIBRARY_SUFFIX}* ) install( - TARGETS mpi_adapter + FILES ${MPI_LIB_LIST} DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore ) endif () if (ENABLE_GPU) + if (ENABLE_MPI) install( TARGETS gpu_collective DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore ) + endif () install( TARGETS gpu_queue DESTINATION ${INSTALL_LIB_DIR} @@ -216,34 +240,12 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/common ${CMAKE_SOURCE_DIR}/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/communication + ${CMAKE_SOURCE_DIR}/mindspore/profiler DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore ) -if (ENABLE_GPU) - install( - DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/_akg - DESTINATION ${INSTALL_PY_DIR}/../ - COMPONENT mindspore - ) - if (EXISTS ${incubator_tvm_gpu_ROOT}) - file(GLOB_RECURSE GLOG_LIB_LIST ${incubator_tvm_gpu_LIBPATH}/lib*) - install( - FILES ${GLOG_LIB_LIST} - DESTINATION ${INSTALL_LIB_DIR} - COMPONENT mindspore - ) - install( - DIRECTORY - ${incubator_tvm_gpu_ROOT}/topi/python/topi - ${incubator_tvm_gpu_ROOT}/python/tvm - DESTINATION ${INSTALL_PY_DIR}/../_akg - COMPONENT mindspore - ) - endif () -endif () - -if (ENABLE_D AND ENABLE_AKG) +if ((ENABLE_D OR ENABLE_GPU) AND ENABLE_AKG) set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg) install( DIRECTORY @@ -268,6 +270,13 @@ if (ENABLE_SERVING) COMPONENT mindspore ) + install( + FILES ${CMAKE_SOURCE_DIR}/build/mindspore/serving/ms_service_pb2.py + ${CMAKE_SOURCE_DIR}/build/mindspore/serving/ms_service_pb2_grpc.py + DESTINATION ${INSTALL_PY_DIR} + COMPONENT mindspore + ) + install( TARGETS inference DESTINATION ${INSTALL_LIB_DIR} diff --git a/cmake/package_script.cmake b/cmake/package_script.cmake index 94ffc71b49..0ade0af696 100644 --- a/cmake/package_script.cmake +++ b/cmake/package_script.cmake @@ -1,13 +1,16 @@ # find exec find_package(Python3 3.7 COMPONENTS Interpreter Development) if (NOT Python3_FOUND) - message("No python3 found.") - return () + message(FATAL_ERROR "No python3 found.") endif () set(PYTHON ${Python3_EXECUTABLE}) set(PYTHON_VERSION ${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}) +if (NOT PYTHON_VERSION MATCHES "3.7") + message(FATAL_ERROR "FIND PYTHON VERSION ${PYTHON_VERSION} BUT CAN NOT MATCH PYTHON VERSION 3.7") +endif () + find_package(Git) if (NOT GIT_FOUND) message("No git found.") diff --git a/config/data_dump.json b/config/data_dump.json index fc08f78590..f1e8177eaf 100644 --- a/config/data_dump.json +++ b/config/data_dump.json @@ -1,14 +1,16 @@ { "DumpSettings": { "net_name": "ResNet50", - "mode": 1, + "dump_mode": 1, + "op_debug_mode": 3, "iteration": 0, "kernels": ["Default/Conv2D-op2", "Default/TensorAdd-op10"] }, "DumpSettingsSpec": { "net_name": "net name eg:ResNet50", - "mode": "0: dump all kernels, 1: dump kernels in kernels list", + "dump_mode": "0: dump all kernels, 1: dump kernels in kernels list", + "op_debug_mode": "0: close debug, 1: debug ai-core overflow, 2: debug atomic overflow, 3: debug all overflow", "iteration": "specified iteration ", "kernels": "op's full scope name which need to be dump" } diff --git a/docker/mindspore-cpu/0.6.0-beta/Dockerfile b/docker/mindspore-cpu/0.6.0-beta/Dockerfile new file mode 100644 index 0000000000..c203cf1721 --- /dev/null +++ b/docker/mindspore-cpu/0.6.0-beta/Dockerfile @@ -0,0 +1,67 @@ +FROM ubuntu:18.04 + +MAINTAINER leonwanghui + +# Set env +ENV PYTHON_ROOT_PATH /usr/local/python-3.7.5 +ENV PATH /usr/local/bin:$PATH + +# Install base tools +RUN apt update \ + && DEBIAN_FRONTEND=noninteractive apt install -y \ + vim \ + wget \ + curl \ + xz-utils \ + net-tools \ + openssh-client \ + git \ + ntpdate \ + tzdata \ + tcl \ + sudo \ + bash-completion + +# Install compile tools +RUN DEBIAN_FRONTEND=noninteractive apt install -y \ + gcc \ + g++ \ + zlibc \ + make \ + libgmp-dev \ + patch \ + autoconf \ + libtool \ + automake \ + flex + +# Set bash +RUN echo "dash dash/sh boolean false" | debconf-set-selections +RUN DEBIAN_FRONTEND=noninteractive dpkg-reconfigure dash + +# Install python (v3.7.5) +RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \ + libgdbm-dev libgdbm-compat-dev liblzma-dev libreadline-dev libsqlite3-dev \ + && cd /tmp \ + && wget https://github.com/python/cpython/archive/v3.7.5.tar.gz \ + && tar -xvf v3.7.5.tar.gz \ + && cd /tmp/cpython-3.7.5 \ + && mkdir -p ${PYTHON_ROOT_PATH} \ + && ./configure --prefix=${PYTHON_ROOT_PATH} \ + && make -j4 \ + && make install -j4 \ + && rm -f /usr/local/bin/python \ + && rm -f /usr/local/bin/pip \ + && ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \ + && ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \ + && rm -rf /tmp/cpython-3.7.5 \ + && rm -f /tmp/v3.7.5.tar.gz + +# Set pip source +RUN mkdir -pv /root/.pip \ + && echo "[global]" > /root/.pip/pip.conf \ + && echo "trusted-host=mirrors.aliyun.com" >> /root/.pip/pip.conf \ + && echo "index-url=http://mirrors.aliyun.com/pypi/simple/" >> /root/.pip/pip.conf + +# Install MindSpore cpu whl package +RUN pip install --no-cache-dir https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.6.0-beta/MindSpore/cpu/ubuntu_x86/mindspore-0.6.0-cp37-cp37m-linux_x86_64.whl diff --git a/docker/mindspore-gpu/0.6.0-beta/Dockerfile b/docker/mindspore-gpu/0.6.0-beta/Dockerfile new file mode 100644 index 0000000000..90d90a4e17 --- /dev/null +++ b/docker/mindspore-gpu/0.6.0-beta/Dockerfile @@ -0,0 +1,83 @@ +FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 + +MAINTAINER leonwanghui + +# Set env +ENV PYTHON_ROOT_PATH /usr/local/python-3.7.5 +ENV OMPI_ROOT_PATH /usr/local/openmpi-3.1.5 +ENV PATH ${OMPI_ROOT_PATH}/bin:/usr/local/bin:$PATH +ENV LD_LIBRARY_PATH ${OMPI_ROOT_PATH}/lib:$LD_LIBRARY_PATH + +# Install base tools +RUN apt update \ + && DEBIAN_FRONTEND=noninteractive apt install -y \ + vim \ + wget \ + curl \ + xz-utils \ + net-tools \ + openssh-client \ + git \ + ntpdate \ + tzdata \ + tcl \ + sudo \ + bash-completion + +# Install compile tools +RUN DEBIAN_FRONTEND=noninteractive apt install -y \ + gcc \ + g++ \ + zlibc \ + make \ + libgmp-dev \ + patch \ + autoconf \ + libtool \ + automake \ + flex \ + libnccl2=2.4.8-1+cuda10.1 \ + libnccl-dev=2.4.8-1+cuda10.1 + +# Set bash +RUN echo "dash dash/sh boolean false" | debconf-set-selections +RUN DEBIAN_FRONTEND=noninteractive dpkg-reconfigure dash + +# Install python (v3.7.5) +RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \ + libgdbm-dev libgdbm-compat-dev liblzma-dev libreadline-dev libsqlite3-dev \ + && cd /tmp \ + && wget https://github.com/python/cpython/archive/v3.7.5.tar.gz \ + && tar -xvf v3.7.5.tar.gz \ + && cd /tmp/cpython-3.7.5 \ + && mkdir -p ${PYTHON_ROOT_PATH} \ + && ./configure --prefix=${PYTHON_ROOT_PATH} \ + && make -j4 \ + && make install -j4 \ + && rm -f /usr/local/bin/python \ + && rm -f /usr/local/bin/pip \ + && ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \ + && ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \ + && rm -rf /tmp/cpython-3.7.5 \ + && rm -f /tmp/v3.7.5.tar.gz + +# Set pip source +RUN mkdir -pv /root/.pip \ + && echo "[global]" > /root/.pip/pip.conf \ + && echo "trusted-host=mirrors.aliyun.com" >> /root/.pip/pip.conf \ + && echo "index-url=http://mirrors.aliyun.com/pypi/simple/" >> /root/.pip/pip.conf + +# Install openmpi (v3.1.5) +RUN cd /tmp \ + && wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.5.tar.gz \ + && tar -xvf openmpi-3.1.5.tar.gz \ + && cd /tmp/openmpi-3.1.5 \ + && mkdir -p ${OMPI_ROOT_PATH} \ + && ./configure --prefix=${OMPI_ROOT_PATH} \ + && make -j4 \ + && make install -j4 \ + && rm -rf /tmp/openmpi-3.1.5 \ + && rm -f /tmp/openmpi-3.1.5.tar.gz + +# Install MindSpore cuda-10.1 whl package +RUN pip install --no-cache-dir https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.6.0-beta/MindSpore/gpu/ubuntu_x86/cuda-10.1/mindspore_gpu-0.6.0-cp37-cp37m-linux_x86_64.whl diff --git a/graphengine b/graphengine index eee707935c..e64a1cfc04 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit eee707935c066c16e9b9cd207f8125871b6b97cf +Subproject commit e64a1cfc0457c96859bc9be1693443aa14f2e9df diff --git a/mindspore/_akg/ops/__init__.py b/hub/docs/.gitkeep similarity index 100% rename from mindspore/_akg/ops/__init__.py rename to hub/docs/.gitkeep diff --git a/mindspore/_akg/ops/array/__init__.py b/hub/images/.gitkeep similarity index 100% rename from mindspore/_akg/ops/array/__init__.py rename to hub/images/.gitkeep diff --git a/mindspore/_akg/ops/math/__init__.py b/hub/scripts/.gitkeep similarity index 100% rename from mindspore/_akg/ops/math/__init__.py rename to hub/scripts/.gitkeep diff --git a/include/infer_log.h b/include/infer_log.h new file mode 100644 index 0000000000..869588bda3 --- /dev/null +++ b/include/infer_log.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INFERENCE_LOG_H_ +#define MINDSPORE_INFERENCE_LOG_H_ + +#include +#include +#include +#include +#include +#include + +#ifndef ENABLE_ACL +#include "mindspore/core/utils/log_adapter.h" +#else // ENABLE_ACL +#include "acl/acl.h" +#endif + +namespace mindspore::inference { + +class LogStream { + public: + LogStream() { sstream_ = std::make_shared(); } + ~LogStream() = default; + + template + LogStream &operator<<(const T &val) noexcept { + (*sstream_) << val; + return *this; + } + + LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept { + (*sstream_) << func; + return *this; + } + + friend class LogWriter; + friend class Status; + + private: + std::shared_ptr sstream_; +}; + +#ifndef ENABLE_ACL +#define MSI_LOG(level) MS_LOG(level) + +#define MSI_LOG_DEBUG MSI_LOG(DEBUG) +#define MSI_LOG_INFO MSI_LOG(INFO) +#define MSI_LOG_WARNING MSI_LOG(WARNING) +#define MSI_LOG_ERROR MSI_LOG(ERROR) + +#define MSI_ASSERT(item) MS_ASSERT(item) + +#else // ENABLE_ACL + +class LogWriter { + public: + LogWriter(const char *file, int line, const char *func, aclLogLevel log_level) + : file_(file), line_(line), func_(func), log_level_(log_level) {} + ~LogWriter() = default; + + void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))) { + std::ostringstream msg; + msg << stream.sstream_->rdbuf(); + OutputLog(msg); + } + + private: + void OutputLog(const std::ostringstream &msg) const { aclAppLog(log_level_, func_, file_, line_, msg.str().c_str()); } + + const char *file_; + int line_; + const char *func_; + aclLogLevel log_level_; +}; + +#define MSILOG_IF(level) inference::LogWriter(__FILE__, __LINE__, __FUNCTION__, ACL_##level) < inference::LogStream() + +#define MSI_LOG(level) MSI_LOG_##level + +#define MSI_LOG_DEBUG MSILOG_IF(DEBUG) +#define MSI_LOG_INFO MSILOG_IF(INFO) +#define MSI_LOG_WARNING MSILOG_IF(WARNING) +#define MSI_LOG_ERROR MSILOG_IF(ERROR) + +#define MSI_ASSERT(item) + +#endif // ENABLE_ACL + +#define INFER_STATUS(code) inference::Status(code) < inference::LogStream() + +} // namespace mindspore::inference + +#endif // MINDSPORE_INFERENCE_LOG_H_ diff --git a/include/infer_tensor.h b/include/infer_tensor.h new file mode 100644 index 0000000000..7aed3e9a47 --- /dev/null +++ b/include/infer_tensor.h @@ -0,0 +1,204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INCLUDE_INFER_TENSOR_H_ +#define MINDSPORE_INCLUDE_INFER_TENSOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "securec/include/securec.h" +#include "include/infer_log.h" + +namespace mindspore { +#define MS_API __attribute__((visibility("default"))) +namespace inference { + +enum DataType { + kMSI_Unknown = 0, + kMSI_Bool = 1, + kMSI_Int8 = 2, + kMSI_Int16 = 3, + kMSI_Int32 = 4, + kMSI_Int64 = 5, + kMSI_Uint8 = 6, + kMSI_Uint16 = 7, + kMSI_Uint32 = 8, + kMSI_Uint64 = 9, + kMSI_Float16 = 10, + kMSI_Float32 = 11, + kMSI_Float64 = 12, +}; + +class InferTensorBase { + public: + InferTensorBase() = default; + virtual ~InferTensorBase() = default; + + virtual DataType data_type() const = 0; + virtual void set_data_type(DataType type) = 0; + virtual std::vector shape() const = 0; + virtual void set_shape(const std::vector &shape) = 0; + virtual const void *data() const = 0; + virtual size_t data_size() const = 0; + virtual bool resize_data(size_t data_len) = 0; + virtual void *mutable_data() = 0; + + bool set_data(const void *data, size_t data_len) { + resize_data(data_len); + if (mutable_data() == nullptr) { + MSI_LOG_ERROR << "set data failed, data len " << data_len; + return false; + } + if (data_size() != data_len) { + MSI_LOG_ERROR << "set data failed, tensor current data size " << data_size() << " not match data len " + << data_len; + return false; + } + if (data_len == 0) { + return true; + } + memcpy_s(mutable_data(), data_size(), data, data_len); + return true; + } + + int64_t ElementNum() const { + std::vector shapex = shape(); + return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies()); + } + + int GetTypeSize(DataType type) const { + const std::map type_size_map{ + {kMSI_Bool, sizeof(bool)}, {kMSI_Float64, sizeof(double)}, {kMSI_Int8, sizeof(int8_t)}, + {kMSI_Uint8, sizeof(uint8_t)}, {kMSI_Int16, sizeof(int16_t)}, {kMSI_Uint16, sizeof(uint16_t)}, + {kMSI_Int32, sizeof(int32_t)}, {kMSI_Uint32, sizeof(uint32_t)}, {kMSI_Int64, sizeof(int64_t)}, + {kMSI_Uint64, sizeof(uint64_t)}, {kMSI_Float16, sizeof(uint16_t)}, {kMSI_Float32, sizeof(float)}, + }; + auto it = type_size_map.find(type); + if (it != type_size_map.end()) { + return it->second; + } + return 0; + } +}; + +class InferTensor : public InferTensorBase { + public: + DataType type_; + std::vector shape_; + std::vector data_; + + public: + InferTensor() = default; + InferTensor(DataType type, std::vector shape, const void *data, size_t data_len) { + set_data_type(type); + set_shape(shape); + set_data(data, data_len); + } + + void set_data_type(DataType type) override { type_ = type; } + DataType data_type() const override { return type_; } + + void set_shape(const std::vector &shape) override { shape_ = shape; } + std::vector shape() const override { return shape_; } + + const void *data() const override { return data_.data(); } + size_t data_size() const override { return data_.size(); } + + bool resize_data(size_t data_len) override { + data_.resize(data_len); + return true; + } + void *mutable_data() override { return data_.data(); } +}; + +class InferImagesBase { + public: + virtual size_t batch_size() const = 0; + virtual bool get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const = 0; + virtual size_t input_index() const = 0; // the index of images as input in model +}; + +class RequestBase { + public: + virtual size_t size() const = 0; + virtual const InferTensorBase *operator[](size_t index) const = 0; +}; + +class ImagesRequestBase { + public: + virtual size_t size() const = 0; + virtual const InferImagesBase *operator[](size_t index) const = 0; +}; + +class ReplyBase { + public: + virtual size_t size() const = 0; + virtual InferTensorBase *operator[](size_t index) = 0; + virtual const InferTensorBase *operator[](size_t index) const = 0; + virtual InferTensorBase *add() = 0; + virtual void clear() = 0; +}; + +class VectorInferTensorWrapReply : public ReplyBase { + public: + explicit VectorInferTensorWrapReply(std::vector &tensor_list) : tensor_list_(tensor_list) {} + + size_t size() const { return tensor_list_.size(); } + InferTensorBase *operator[](size_t index) { + if (index >= tensor_list_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size(); + return nullptr; + } + return &(tensor_list_[index]); + } + const InferTensorBase *operator[](size_t index) const { + if (index >= tensor_list_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size(); + return nullptr; + } + return &(tensor_list_[index]); + } + InferTensorBase *add() { + tensor_list_.push_back(InferTensor()); + return &(tensor_list_.back()); + } + void clear() { tensor_list_.clear(); } + std::vector &tensor_list_; +}; + +class VectorInferTensorWrapRequest : public RequestBase { + public: + explicit VectorInferTensorWrapRequest(const std::vector &tensor_list) : tensor_list_(tensor_list) {} + + size_t size() const { return tensor_list_.size(); } + const InferTensorBase *operator[](size_t index) const { + if (index >= tensor_list_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size(); + return nullptr; + } + return &(tensor_list_[index]); + } + const std::vector &tensor_list_; +}; + +} // namespace inference +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_INFER_TENSOR_H_ diff --git a/include/inference.h b/include/inference.h index 7e5ee27d49..8598401b75 100644 --- a/include/inference.h +++ b/include/inference.h @@ -20,25 +20,63 @@ #include #include #include -#include "include/ms_tensor.h" +#include "include/infer_tensor.h" +#include "include/infer_log.h" namespace mindspore { -class FuncGraph; namespace inference { -class MS_API MSSession { - public: - MSSession() = default; - static std::shared_ptr CreateSession(const std::string &device, uint32_t device_id); +enum StatusCode { SUCCESS = 0, FAILED, INVALID_INPUTS }; - virtual uint32_t CompileGraph(std::shared_ptr funcGraphPtr) = 0; +class Status { + public: + Status() : status_code_(FAILED) {} + Status(enum StatusCode status_code, const std::string &status_msg = "") + : status_code_(status_code), status_msg_(status_msg) {} + bool IsSuccess() const { return status_code_ == SUCCESS; } + enum StatusCode StatusCode() const { return status_code_; } + std::string StatusMessage() const { return status_msg_; } + bool operator==(const Status &other) const { return status_code_ == other.status_code_; } + bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } + bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } + bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } + operator bool() const = delete; + Status &operator<(const LogStream &stream) noexcept __attribute__((visibility("default"))) { + status_msg_ = stream.sstream_->str(); + return *this; + } - virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector> &inputs) = 0; + private: + enum StatusCode status_code_; + std::string status_msg_; }; -std::shared_ptr MS_API LoadModel(const char *model_buf, size_t size, const std::string &device); +class MS_API InferSession { + public: + InferSession() = default; + virtual ~InferSession() = default; + virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0; + virtual Status FinalizeEnv() = 0; + virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0; + virtual Status UnloadModel(uint32_t model_id) = 0; + // override this method to avoid request/reply data copy + virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0; + + virtual Status ExecuteModel(uint32_t model_id, const std::vector &inputs, + std::vector &outputs) { + VectorInferTensorWrapRequest request(inputs); + VectorInferTensorWrapReply reply(outputs); + return ExecuteModel(model_id, request, reply); + } + // default not support input data preprocess(decode, resize, crop, crop&paste, etc.) + virtual Status ExecuteModel(uint32_t /*model_id*/, + const ImagesRequestBase & /*images_inputs*/, // images for preprocess + const RequestBase & /*request*/, ReplyBase & /*reply*/) { + return FAILED; + } + static std::shared_ptr CreateSession(const std::string &device, uint32_t device_id); +}; -void MS_API ExitInference(); } // namespace inference } // namespace mindspore #endif // MINDSPORE_INCLUDE_MS_SESSION_H diff --git a/include/ms_tensor.h b/include/ms_tensor.h deleted file mode 100644 index fc59e12328..0000000000 --- a/include/ms_tensor.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_ -#define MINDSPORE_INCLUDE_MS_TENSOR_H_ - -#include -#include -#include -#include "mindspore/core/ir/dtype/type_id.h" - -namespace mindspore { -#define MS_API __attribute__((visibility("default"))) -namespace inference { -class MS_API MSTensor { - public: - MSTensor() = default; - // brief Create a MSTensor pointer. - // - // param data_type DataTypeId of tensor to be created. - // param shape Shape of tensor to be created. - // return MSTensor pointer. - static MSTensor *CreateTensor(TypeId data_type, const std::vector &shape); - - ~MSTensor() = default; - - virtual TypeId data_type() const = 0; - - virtual TypeId set_data_type(const TypeId data_type) = 0; - - virtual std::vector shape() const = 0; - - virtual size_t set_shape(const std::vector &shape) = 0; - - virtual int DimensionSize(size_t index) const = 0; - // brief Get number of element in MSTensor. - // - // return Number of element in MSTensor. - virtual int ElementsNum() const = 0; - - virtual std::size_t hash() const = 0; - // brief Get byte size of data in MSTensor. - // - // return Byte size of data in MSTensor. - virtual size_t Size() const = 0; - // brief Get pointer of data in MSTensor. - // - // The data pointer can be used to both write or read data in MSTensor. - // - // return A pointer points to data in MSTensor. - virtual void *MutableData() const = 0; -}; -using MultiTensor = std::vector>; -} // namespace inference -} // namespace mindspore -#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_ diff --git a/mindspore/_akg/__init__.py b/mindspore/_akg/__init__.py deleted file mode 100644 index d0c1f0ffe4..0000000000 --- a/mindspore/_akg/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""__init__""" -from . import add_path -from .op_build import op_build -from .message import compilewithjson diff --git a/mindspore/_akg/add_path.py b/mindspore/_akg/add_path.py deleted file mode 100644 index d1e50f8177..0000000000 --- a/mindspore/_akg/add_path.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""add tvm path""" -import sys -import os - - -def AKGAddPath(): - """_akg add path.""" - pwd = os.path.dirname(os.path.realpath(__file__)) - tvm_path = os.path.realpath(pwd) - if tvm_path not in sys.path: - sys.path.insert(0, tvm_path) - else: - sys.path.remove(tvm_path) - sys.path.insert(0, tvm_path) - - -class AKGMetaPathFinder: - """class AKGMetaPath finder.""" - - def find_module(self, fullname, path=None): - """method _akg find module.""" - _ = path - if fullname.startswith("_akg.tvm"): - rname = fullname[5:] - return AKGMetaPathLoader(rname) - if fullname.startswith("_akg.topi"): - rname = fullname[5:] - return AKGMetaPathLoader(rname) - return None - - -class AKGMetaPathLoader: - """class AKGMetaPathLoader loader.""" - - def __init__(self, rname): - self.__rname = rname - - def load_module(self, fullname): - if self.__rname in sys.modules: - sys.modules.pop(self.__rname) - AKGAddPath() - __import__(self.__rname, globals(), locals()) - self.__target_module = sys.modules[self.__rname] - sys.modules[fullname] = self.__target_module - return self.__target_module - - -sys.meta_path.insert(0, AKGMetaPathFinder()) diff --git a/mindspore/_akg/gpu/__init__.py b/mindspore/_akg/gpu/__init__.py deleted file mode 100644 index 4c11499594..0000000000 --- a/mindspore/_akg/gpu/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""__init__""" -from .equal import Equal -from .equal import gpu_schedule_Equal -from .tile import Tile -from .tile import gpu_schedule_Tile -from .cast import Cast -from .cast import gpu_schedule_Cast -from .relu6 import ReLU6, gpu_schedule_ReLU6 -from .relu6_grad import ReLU6Grad, gpu_schedule_ReLU6Grad -from .squeeze import Squeeze, gpu_schedule_Squeeze -from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad -from .mean import SimpleMean, gpu_schedule_SimpleMean -from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad -from .mul import Mul, gpu_schedule_Mul -from .hsigmoid import HSigmoid, gpu_schedule_HSigmoid -from .hsigmoid_grad import HSigmoidGrad, gpu_schedule_HSigmoidGrad -from .hswish import HSwish, gpu_schedule_HSwish -from .hswish_grad import HSwishGrad, gpu_schedule_HSwishGrad -from .logical_or import LogicalOr, gpu_schedule_LogicalOr -from .logical_not import LogicalNot, gpu_schedule_LogicalNot -from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd -from .sub import Sub, gpu_schedule_Sub -from .less_equal import LessEqual, gpu_schedule_LessEqual -from .notequal import NotEqual, gpu_schedule_NotEqual -from .greater_equal import GreaterEqual, gpu_schedule_GreaterEqual diff --git a/mindspore/_akg/gpu/cast.py b/mindspore/_akg/gpu/cast.py deleted file mode 100644 index d6b38b6e9b..0000000000 --- a/mindspore/_akg/gpu/cast.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""cast""" -import logging -import _akg.tvm -from _akg.ops.math import cast -from _akg.topi.generic import schedule_elemwise - -def Cast(x, dst_type): - """cast.""" - return cast.cast(x, dst_type) - - -def gpu_schedule_Cast(outs): - """ - gpu schedule for cast. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - logging.info("Skip because %s is not enabled", device) - return None - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/default_schedule.py b/mindspore/_akg/gpu/default_schedule.py deleted file mode 100644 index 811cc2d710..0000000000 --- a/mindspore/_akg/gpu/default_schedule.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""default schedule function for GPU""" -from queue import Queue - -import _akg.tvm as tvm - -DEFAULT_GPU_THREAD = 1024 - - -def default_schedule(outs): - """ - default schedule function. - - Args: - outs (Union[tvm.tensor.Tensor, list[tvm.tensor.Tensor]]): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - if not isinstance(outs, tvm.tensor.Tensor) and not isinstance(outs, list): - raise ValueError("outs should be list of _akg.tvm.tensor.Tensor or _akg.tvm.tensor.Tensor") - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - outs_list = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - with tvm.target.create(device): - sch = tvm.create_schedule(outs_list[0].op) - outputs_tensor = Queue() - outputs_tensor.put(outs_list[0]) - op_list = [] - while not outputs_tensor.empty(): - out = outputs_tensor.get() - if out.op not in op_list and isinstance(out.op, tvm.tensor.ComputeOp): - op_list.append(out.op) - for input_tensor in out.op.input_tensors: - outputs_tensor.put(input_tensor) - for op in op_list: - stage = sch[op.output(0)] - bx, tx = stage.split(op.axis[0], factor=DEFAULT_GPU_THREAD) - stage.bind(bx, tvm.thread_axis("blockIdx.x")) - stage.bind(tx, tvm.thread_axis("threadIdx.x")) - return sch diff --git a/mindspore/_akg/gpu/equal.py b/mindspore/_akg/gpu/equal.py deleted file mode 100644 index 3321c10b2c..0000000000 --- a/mindspore/_akg/gpu/equal.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""equal""" -import _akg.tvm -from _akg.ops.math import equal -from _akg.topi.generic import schedule_elemwise - -def Equal(x, y): - """equal.""" - return equal.equal(x, y) - - -def gpu_schedule_Equal(outs): - """ - gpu schedule for Equal. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/greater_equal.py b/mindspore/_akg/gpu/greater_equal.py deleted file mode 100644 index 0212cac03c..0000000000 --- a/mindspore/_akg/gpu/greater_equal.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""greater_equal""" -import _akg.tvm -from _akg.ops.math import greater_equal -from _akg.topi.generic import schedule_elemwise - -def GreaterEqual(x, y): - """GreaterEqual.""" - return greater_equal.greater_equal(x, y) - - -def gpu_schedule_GreaterEqual(outs): - """ - GPU schedule for GreaterEqual. - - Args: - outs (tvm.tensor.Tensor): Outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/hsigmoid.py b/mindspore/_akg/gpu/hsigmoid.py deleted file mode 100644 index b313c2fd5a..0000000000 --- a/mindspore/_akg/gpu/hsigmoid.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""hsigmoid""" -import _akg.topi as topi -import _akg.tvm as tvm -from _akg.topi import tag - - -@tvm.tag_scope(tag=tag.ELEMWISE) -def topi_nn_hsigmoid(x): - """ - topi hsigmoid - Args: - x: - - Returns: - - """ - return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, - tvm.if_then_else(x(*i) >= 3, 1, - (x(*i) + 3) / 6))) - - -def HSigmoid(x): - """ - HSigmoid - Args: - x: - - Returns: - - """ - return topi_nn_hsigmoid(x) - - -def gpu_schedule_HSigmoid(outs): - """ - gpu schedule HSigmoid - Args: - outs: - - Returns: - - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with tvm.target.create(device): - sch = topi.cuda.schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/hsigmoid_grad.py b/mindspore/_akg/gpu/hsigmoid_grad.py deleted file mode 100644 index bdde4ed3ca..0000000000 --- a/mindspore/_akg/gpu/hsigmoid_grad.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""HSigmoid grad""" -import _akg.topi as topi -import _akg.tvm as tvm - - -def HSigmoidGrad(y_grad, x): - """ - HSigmoidGrad - Args: - y_grad: - x: - - Returns: - - """ - return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, - tvm.if_then_else(x(*i) >= 3, 0, - y_grad(*i) / 6))) - - -def gpu_schedule_HSigmoidGrad(outs): - """ - gpu schedule ReLU6Grad - Args: - outs: - - Returns: - - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - - with tvm.target.create(device): - sch = topi.cuda.schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/hswish.py b/mindspore/_akg/gpu/hswish.py deleted file mode 100644 index 3def3c4b35..0000000000 --- a/mindspore/_akg/gpu/hswish.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""HSwish""" -import _akg.topi as topi -import _akg.tvm as tvm -from _akg.topi import tag - - -@tvm.tag_scope(tag=tag.ELEMWISE) -def topi_nn_HSwish(x): - """ - topi HSwish - Args: - x: - - Returns: - - """ - return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, - tvm.if_then_else(x(*i) >= 3, x(*i), - x(*i) * (x(*i) + 3) / 6))) - - -def HSwish(x): - """ - HSwish - Args: - x: - - Returns: - - """ - return topi_nn_HSwish(x) - - -def gpu_schedule_HSwish(outs): - """ - gpu schedule HSwish - Args: - outs: - - Returns: - - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with tvm.target.create(device): - sch = topi.cuda.schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/hswish_grad.py b/mindspore/_akg/gpu/hswish_grad.py deleted file mode 100644 index cadbf0f663..0000000000 --- a/mindspore/_akg/gpu/hswish_grad.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""HSwishGrad""" -import _akg.topi as topi -import _akg.tvm as tvm - - -def HSwishGrad(y_grad, x): - """ - HSwishGrad - Args: - y_grad: - x: - - Returns: - - """ - shape = x.shape - - res0 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, y_grad(*i) * (2 * x(*i) + 3) / 6)) - res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= 3, y_grad(*i), res0(*i))) - return res6 - - -def gpu_schedule_HSwishGrad(outs): - """ - gpu schedule HSwishGrad - Args: - outs: - - Returns: - - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - - with tvm.target.create(device): - sch = topi.cuda.schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/less_equal.py b/mindspore/_akg/gpu/less_equal.py deleted file mode 100644 index c58346e929..0000000000 --- a/mindspore/_akg/gpu/less_equal.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""less_equal""" -import _akg.tvm -from _akg.ops.math import less_equal -from _akg.topi.generic import schedule_elemwise - -def LessEqual(x, y): - """LessEqual.""" - return less_equal.less_equal(x, y) - - -def gpu_schedule_LessEqual(outs): - """ - GPU schedule for LessEqual. - - Args: - outs (tvm.tensor.Tensor): Outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/logical_and.py b/mindspore/_akg/gpu/logical_and.py deleted file mode 100644 index 6453901458..0000000000 --- a/mindspore/_akg/gpu/logical_and.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""logical_and""" -import _akg.tvm -from _akg.ops.math import logical_and -from _akg.topi.generic import schedule_elemwise - -def LogicalAnd(x, y): - """LogicalAnd.""" - return logical_and.logical_and(x, y) - - -def gpu_schedule_LogicalAnd(outs): - """ - GPU schedule for LogicalAnd. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/logical_not.py b/mindspore/_akg/gpu/logical_not.py deleted file mode 100644 index 0a38107187..0000000000 --- a/mindspore/_akg/gpu/logical_not.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""logical_not""" -import _akg.tvm -from _akg.ops.math import logical_not -from _akg.topi.generic import schedule_elemwise - -def LogicalNot(x): - """LogicalNot.""" - return logical_not.logical_not(x) - - -def gpu_schedule_LogicalNot(outs): - """ - GPU schedule for LogicalNot. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/logical_or.py b/mindspore/_akg/gpu/logical_or.py deleted file mode 100644 index 1bd49bedbc..0000000000 --- a/mindspore/_akg/gpu/logical_or.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""logical_or""" -import _akg.tvm -from _akg.ops.math import logical_or -from _akg.topi.generic import schedule_elemwise - -def LogicalOr(x, y): - """LogicalOr.""" - return logical_or.logical_or(x, y) - - -def gpu_schedule_LogicalOr(outs): - """ - GPU schedule for LogicalOr. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/mean.py b/mindspore/_akg/gpu/mean.py deleted file mode 100644 index e9cdb6d551..0000000000 --- a/mindspore/_akg/gpu/mean.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""mean op compute and schedule""" -import _akg.tvm as tvm -from _akg.ops.math.mean import mean -from .default_schedule import DEFAULT_GPU_THREAD - -def Mean(x, axis=None, keepdims=True): - """mean.""" - outs = mean(x, axis, keepdims) - - # remove useless mean_output - if isinstance(outs, tuple): - outs = outs[0] - if outs.op.name == "mean_output": - outs = outs.op.input_tensors[0] - return outs - - -def gpu_schedule_Mean(outs): - """ - gpu schedule function for mean. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - out = outs[0] if isinstance(outs, list) else outs - - device = "cuda" - with tvm.target.create(device): - sch = tvm.create_schedule(out.op) - if out.op.name == "T_divide": - tensor_c = out - else: # squeeze - tensor_c = out.op.input_tensors[0] - - tensor_b = tensor_c.op.input_tensors[0] - if len(tensor_c.op.axis) >= 2: - sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[1]) - else: - sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[0]) - - bx, tx = sch[tensor_c].split(tensor_c.op.axis[0], factor=DEFAULT_GPU_THREAD) - sch[tensor_c].bind(bx, tvm.thread_axis("blockIdx.x")) - sch[tensor_c].bind(tx, tvm.thread_axis("threadIdx.x")) - return sch - -def SimpleMean(x): - """ - SimpleMean compute the mean of the input 4D Tensor over last two axises and keep reduced dimensions. - - Args: - x (tvm.tensor.Tensor): Tensor of type float16, float32. - - Returns: - tvm.tensor.Tensor, has the same type as x, output shape will be (a, b, 1, 1) if input Tensor x is (a, b, c, d). - """ - axis = (2, 3) - keepdims = True - return Mean(x, axis, keepdims) - - -def gpu_schedule_SimpleMean(outs): - """gpu schedule function for SimpleMean.""" - return gpu_schedule_Mean(outs) diff --git a/mindspore/_akg/gpu/mean_grad.py b/mindspore/_akg/gpu/mean_grad.py deleted file mode 100644 index 9d91ee3f40..0000000000 --- a/mindspore/_akg/gpu/mean_grad.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""mean_grad""" -import _akg.tvm as tvm -import _akg -from _akg.ops.math import mean -from .default_schedule import DEFAULT_GPU_THREAD - - -def mean_ad(head, input_shape, axis, keepdims): - """mean autodiff.""" - tensor_a = tvm.placeholder(input_shape, head.dtype, "A") - tensor_b = mean.mean(tensor_a, axis, keepdims) - - # remove useless mean_output - if isinstance(tensor_b, tuple): - tensor_b = tensor_b[0] - if tensor_b.op.name == "mean_output": - tensor_b = tensor_b.op.input_tensors[0] - - jacs = list(_akg.differentiate(tensor_b, [tensor_a], head)) - return jacs[0] - - -def MeanGrad(y_grad, input_shape, axis=None, keepdims=True): - """Mean Grad.""" - if axis is None and not keepdims: - raise ValueError("Mean not support (axis=None && keepdims=False) now") - return mean_ad(y_grad, input_shape, axis, keepdims) - - -def gpu_schedule_MeanGrad(outs): - """gpu schedule MeanGrad.""" - out = outs[0] if isinstance(outs, list) else outs - - device = "cuda" - with tvm.target.create(device): - sch = tvm.create_schedule(out.op) - tensor_c = out - tensor_b = tensor_c.op.input_tensors[0] - if len(tensor_c.op.axis) >= 2: - sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[1]) - else: - sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[0]) - - bx, tx = sch[tensor_c].split(tensor_c.op.axis[0], factor=DEFAULT_GPU_THREAD) - sch[tensor_c].bind(bx, tvm.thread_axis("blockIdx.x")) - sch[tensor_c].bind(tx, tvm.thread_axis("threadIdx.x")) - - return sch - -def SimpleMeanGrad(HEAD, input_shape): - """ - Compute Simple Mean Grad. - - Args: - HEAD (tvm.tensor.Tensor): output gradient, dy, defined in Primitive. - input_shape (Union[list[int], tuple[int]]): shape of mean input, x.shape. - - Returns: - tvm.tensor.Tensor, gradient of mean input. - """ - axis = (2, 3) - keepdims = True - return MeanGrad(HEAD, input_shape, axis, keepdims) - - -def gpu_schedule_SimpleMeanGrad(outs): - """ - gpu schedule SimpleMeanGrad. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - return gpu_schedule_MeanGrad(outs) diff --git a/mindspore/_akg/gpu/mul.py b/mindspore/_akg/gpu/mul.py deleted file mode 100644 index 5c289a62a6..0000000000 --- a/mindspore/_akg/gpu/mul.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""mul""" -import _akg.topi as topi -import _akg.tvm as tvm -from _akg.ops.math import mul - -def Mul(x, y): - """mul.""" - return mul.mul(x, y) - - -def gpu_schedule_Mul(outs): - """ - gpu schedule for mul. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with tvm.target.create(device): - sch = topi.cuda.schedule_broadcast(outs) - return sch diff --git a/mindspore/_akg/gpu/notequal.py b/mindspore/_akg/gpu/notequal.py deleted file mode 100644 index 3e3a6561a1..0000000000 --- a/mindspore/_akg/gpu/notequal.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""notequal""" -import _akg.tvm -from _akg.ops.math import notequal -from _akg.topi.generic import schedule_elemwise - -def NotEqual(x, y): - """notequal.""" - return notequal.notequal(x, y) - - -def gpu_schedule_NotEqual(outs): - """ - gpu schedule for NotEqual. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/relu6.py b/mindspore/_akg/gpu/relu6.py deleted file mode 100644 index 9a0a3d7a45..0000000000 --- a/mindspore/_akg/gpu/relu6.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""relu6""" -import _akg.topi as topi -import _akg.tvm as tvm -from _akg.topi import tag - -@tvm.tag_scope(tag=tag.ELEMWISE) -def topi_nn_relu6(x): - """topi nn relu6.""" - return tvm.compute(x.shape, lambda *i: tvm.min(tvm.max(x(*i), tvm.const(0, x.dtype)), tvm.const(6, x.dtype))) - -def ReLU6(x): - """ - Compute elementwise with function: min(max(x, 0), 6). - - Args: - x (tvm.tensor.Tensor): Tensor of type float16, float32. - - Returns: - tvm.tensor.Tensor, has same type and shape as input. - """ - return topi_nn_relu6(x) - - -def gpu_schedule_ReLU6(outs): - """ - gpu schedule ReLU6. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with tvm.target.create(device): - sch = topi.cuda.schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/relu6_grad.py b/mindspore/_akg/gpu/relu6_grad.py deleted file mode 100644 index 62aeabb4c0..0000000000 --- a/mindspore/_akg/gpu/relu6_grad.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""relu6 grad""" -import _akg.topi as topi -import _akg.tvm as tvm - -def ReLU6Grad(y_grad, x): - """ - Computes Gradients of Rectified Linear 6. - - Args: - y_grad (tvm.tensor.Tensor): Tensor of type float16, float32, gradients backpropagated to the ReLU6 op. - x (tvm.tensor.Tensor): Tensor of type float16/float32, inputs that where passed to the ReLU6 op, or its outputs. - - Returns: - tvm.tensor.Tensor, has same type and shape as x. - """ - shape = x.shape - dtype = x.dtype - - zero = tvm.const(0, dtype) - six = tvm.const(6, dtype) - - res0 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= zero, x(*i), zero)) - res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= six, zero, res0(*i))) - res = tvm.compute(shape, lambda *i: tvm.if_then_else(res6(*i) == zero, zero, y_grad(*i))) - return res - - -def gpu_schedule_ReLU6Grad(outs): - """ - gpu schedule ReLU6Grad. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - - with tvm.target.create(device): - sch = topi.cuda.schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/squeeze.py b/mindspore/_akg/gpu/squeeze.py deleted file mode 100644 index b5f55facaa..0000000000 --- a/mindspore/_akg/gpu/squeeze.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""squeeze""" -import _akg.topi as topi -import _akg.tvm as tvm - -def Squeeze(x, axis=None): - """ - Remove the dimensions which have shape size 1. - - Args: - x (tvm.tensor.Tensor): Tensor, input whose shape is to be squeeze. - axis (Union[list, tuple, int, None]): specify which size 1 dimension to be removed. - - Returns: - tvm.tensor.Tensor, has the same type and element as x, but some size 1 dimensions are removed. - """ - return topi.squeeze(x, axis) - - -def gpu_schedule_Squeeze(outs): - """ - gpu schedule Squeeze. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - - with tvm.target.create(device): - sch = topi.cuda.schedule_injective(outs) - return sch diff --git a/mindspore/_akg/gpu/squeeze_grad.py b/mindspore/_akg/gpu/squeeze_grad.py deleted file mode 100644 index ae31de8e84..0000000000 --- a/mindspore/_akg/gpu/squeeze_grad.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""squeeze grad""" -import _akg.topi as topi - - -def SqueezeGrad(y_grad, x_shape): - """ - Computes gradients for squeeze op. - - Args: - y_grad (tvm.tensor.Tensor): the gradient needed to be propagation. - x_shape (Union[list, tuple]): output Tensor shape. - - Returns: - tvm.tensor.Tensor: output gradient. - """ - return topi.reshape(y_grad, x_shape) - - -def gpu_schedule_SqueezeGrad(outs): - """ - gpu schedule SqueezeGrad. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - from .default_schedule import default_schedule - return default_schedule(outs) diff --git a/mindspore/_akg/gpu/sub.py b/mindspore/_akg/gpu/sub.py deleted file mode 100644 index 611e4228fd..0000000000 --- a/mindspore/_akg/gpu/sub.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""sub""" -import _akg.tvm -from _akg.ops.math import sub -from _akg.topi.generic import schedule_elemwise - -def Sub(x, y): - """Sub.""" - return sub.sub(x, y) - - -def gpu_schedule_Sub(outs): - """ - GPU schedule for Sub. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - sch = schedule_elemwise(outs) - return sch diff --git a/mindspore/_akg/gpu/tile.py b/mindspore/_akg/gpu/tile.py deleted file mode 100644 index 1eb6979b09..0000000000 --- a/mindspore/_akg/gpu/tile.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""tile""" -import _akg.tvm -from _akg.ops.array import tile -from _akg.topi.generic import schedule_elemwise - -def Tile(x, multiples): - """tile.""" - return tile.tile(x, multiples) - -def gpu_schedule_Tile(outs): - """ - gpu schedule for tile. - - Args: - outs (tvm.tensor.Tensor): outputs of compute. - - Returns: - sch (schedule.Schedule): The created schedule. - """ - device = 'cuda' - ctx = _akg.tvm.context(device, 0) - if not ctx.exist: - raise SystemError("Skip because %s is not enabled" % device) - with _akg.tvm.target.create(device): - s = schedule_elemwise(outs) - return s diff --git a/mindspore/_akg/message.py b/mindspore/_akg/message.py deleted file mode 100644 index 3d1f81f914..0000000000 --- a/mindspore/_akg/message.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""message""" -import importlib.util -import json -import json.decoder as jd -import logging -import traceback -import os.path -from pathlib import Path -import _akg.tvm -from _akg.utils import validation_check as vc_util -from _akg.utils.dsl_create import TensorUtils -from . import gpu -from . import op_build - - -@vc_util.check_input_type(str) -def compilewithjson(json_str): - """compile with json.""" - try: - kernel_info = json.loads(json_str) - except jd.JSONDecodeError: - logging.error(traceback.format_exc()) - return False - - op_name = kernel_info['name'] - op_func = None - processor = 'aicore' - if 'process' in kernel_info: - processor = kernel_info['process'] - # get custom ops implementation first. - if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: - impl_path = os.path.realpath(kernel_info['impl_path']) - if os.path.isfile(impl_path): - custom_mod_name = Path(impl_path).resolve().stem - mod_spec = importlib.util.spec_from_file_location( - custom_mod_name, impl_path) - custom_mod = importlib.util.module_from_spec(mod_spec) - mod_spec.loader.exec_module(custom_mod) - op_func = getattr(custom_mod, op_name, None) - - # get built-in ops. - if op_func is None: - if processor == 'cuda': - op_func = getattr(gpu, op_name, None) - - if op_func is None: - logging.error( - "this op not supported, please check op name %s", str(op_name)) - return False - - args = {} - tsr = [] - for input_desc in kernel_info['input_desc']: - if len(input_desc) == 1: - tensor_shape = input_desc[0]['shape'] - tensor_shape = (1,) if not tensor_shape else tensor_shape - vc_util.shape_dtype_max_size_check(tensor_shape) - args[input_desc[0]['name']] = _akg.tvm.placeholder( - shape=tensor_shape, name=input_desc[0]['tensor_name'], dtype=input_desc[0]['data_type']) - tsr.append(args[input_desc[0]['name']]) - else: - tmp_input = [] - for tmp_desc in input_desc: - tensor_shape = tmp_desc['shape'] - tensor_shape = (1,) if not tensor_shape else tensor_shape - vc_util.shape_dtype_max_size_check(tensor_shape) - tmp_input.append(_akg.tvm.placeholder( - shape=tensor_shape, name=tmp_desc['tensor_name'], dtype=tmp_desc['data_type'])) - args[input_desc[0]['name']] = tmp_input - tsr = tsr + tmp_input - - if kernel_info['attr']: - for ext_arg in kernel_info['attr']: - args[ext_arg['name']] = ext_arg['value'] - - output = op_func(**args) - - if isinstance(output, (list, tuple)): - from inspect import isfunction - tmp_outputs = [] - for elem in output: - if not isfunction(elem) or isinstance(elem, dict): - tmp_outputs.append(elem) - - output = tmp_outputs - else: - output = [output] - - tsr = tsr + [i for i in output if TensorUtils.is_output_value(i)] - return op_build([op_name], output, tsr, processor, kernel_info['op']) diff --git a/mindspore/_akg/op_build.py b/mindspore/_akg/op_build.py deleted file mode 100644 index 92101f657e..0000000000 --- a/mindspore/_akg/op_build.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""op_build""" -import os -import fcntl -import types -import typing -import logging -import traceback -import _akg.tvm -import _akg -from _akg import save_gpu_param as gpu_utils -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(list, (list, tuple), (list, tuple), str, str) -def op_build(opnames, computes, args, device, kernel_name): - """op_build""" - kernel_meta_path = "./cuda_meta_" + str(os.getpid()) + "/" - if device == "cuda": - cuda_path = os.path.realpath(kernel_meta_path) - if not os.path.isdir(cuda_path): - os.makedirs(cuda_path) - if not opnames: - logging.error("no opname given.") - return None - - schedule_name = 'gpu_schedule_' + opnames[0] - schedule_func = getattr(_akg.gpu, schedule_name) - if not isinstance(schedule_func, (types.FunctionType, typing.Callable)): - logging.error("no schedule func found %s", str(schedule_name)) - return None - - ptx_file = os.path.realpath(kernel_meta_path + kernel_name + ".ptx") - if os.path.exists(ptx_file): - os.chmod(ptx_file, 0o600) - try: - with open(ptx_file, 'at') as file: - fcntl.flock(file.fileno(), fcntl.LOCK_EX) - file.seek(0, 2) - if file.tell() == 0: - s = schedule_func(computes) - foo = _akg.tvm.build(s, args, device, name=kernel_name) - ptx_code = foo.imported_modules[0].get_source("ptx") - file.write(ptx_code) - json_file = os.path.realpath( - kernel_meta_path + kernel_name + ".json") - kernel_info = (ptx_code, json_file, kernel_name) - gpu_utils.save_gpu_params(s, args, kernel_info) - os.chmod(ptx_file, 0o400) - except IOError: - logging.error(traceback.format_exc()) - return None - return True - - logging.error("Not support device %s.", device) - return None diff --git a/mindspore/_akg/ops/array/tile.py b/mindspore/_akg/ops/array/tile.py deleted file mode 100644 index 2fa485ea36..0000000000 --- a/mindspore/_akg/ops/array/tile.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: tile""" -import _akg.tvm -import _akg.topi -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, (list, tuple)) -def tile(data, multiples): - """ - Repeats the data in the specified dimensions according to the multiples. - - Args: - data (tvm.tensor.Tensor): Tensor. - multiples (Union[list, tuple]): Elements must be int. The number of repetitions. - - Returns: - tvm.tensor.Tensor, has the same dtype as data. - """ - vc_util.check_shape(data.shape) - vc_util.check_int_list(multiples, "multiples") - output = _akg.topi.tile(data, multiples) - return output diff --git a/mindspore/_akg/ops/math/cast.py b/mindspore/_akg/ops/math/cast.py deleted file mode 100644 index 78140bfe27..0000000000 --- a/mindspore/_akg/ops/math/cast.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: cast""" -import _akg.tvm -import _akg.topi -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, str) -def cast(data, dst_type): - """ - cast data to target type. - - Args: - data (tvm.tensor.Tensor): Tensor to be casted. - dst_type (str): target cast type. - - Returns: - tvm.tensor.Tensor, type is dst_type. - """ - vc_util.check_shape(data.shape) - out = _akg.topi.cast(data, dst_type) - - return out diff --git a/mindspore/_akg/ops/math/equal.py b/mindspore/_akg/ops/math/equal.py deleted file mode 100644 index 2dbb1ba733..0000000000 --- a/mindspore/_akg/ops/math/equal.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: equal""" -import _akg.tvm -import _akg.topi -from _akg.utils.dsl_create import produce_shapes -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def equal(input1, input2): - """ - check whether input1 equals to input2. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - input2 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. If input1 equal to input2 return True, else return False. - """ - shape1 = [x.value for x in input1.shape] - shape2 = [x.value for x in input2.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - - shape1, shape2, shape = produce_shapes(shape1, shape2) - - vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) - dtype = input1.dtype - - # get equal compute - t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") - f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") - - input1_bro = _akg.topi.broadcast_to(input1, shape) - input2_bro = _akg.topi.broadcast_to(input2, shape) - c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice], - t_value[indice], f_value[indice]), name="C") - res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") - - return res diff --git a/mindspore/_akg/ops/math/greater_equal.py b/mindspore/_akg/ops/math/greater_equal.py deleted file mode 100644 index 00ad016643..0000000000 --- a/mindspore/_akg/ops/math/greater_equal.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: greaterequal""" -import _akg.tvm -import _akg.topi -from _akg.utils.dsl_create import produce_shapes -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def greater_equal(input1, input2): - """ - Check whether input1 greaterquals to input2. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - input2 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. If input1 greaterquals to input2 return True, else return False. - """ - shape1 = [x.value for x in input1.shape] - shape2 = [x.value for x in input2.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - - shape1, shape2, shape = produce_shapes(shape1, shape2) - - vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) - dtype = input1.dtype - - # get greaterquals compute - t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") - f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") - - input1_bro = _akg.topi.broadcast_to(input1, shape) - input2_bro = _akg.topi.broadcast_to(input2, shape) - c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] >= input2_bro[indice], - t_value[indice], f_value[indice]), name="C") - res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") - - return res diff --git a/mindspore/_akg/ops/math/less_equal.py b/mindspore/_akg/ops/math/less_equal.py deleted file mode 100644 index 5a566fbbca..0000000000 --- a/mindspore/_akg/ops/math/less_equal.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: lessequal""" -import _akg.tvm -import _akg.topi -from _akg.utils.dsl_create import produce_shapes -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def less_equal(input1, input2): - """ - Check whether input1 lessequals to input2. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - input2 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. If input1 lessequal to input2 return True, else return False. - """ - shape1 = [x.value for x in input1.shape] - shape2 = [x.value for x in input2.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - - shape1, shape2, shape = produce_shapes(shape1, shape2) - - vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) - dtype = input1.dtype - - # get lessequal compute - t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") - f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") - - input1_bro = _akg.topi.broadcast_to(input1, shape) - input2_bro = _akg.topi.broadcast_to(input2, shape) - c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice], - t_value[indice], f_value[indice]), name="C") - res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") - - return res diff --git a/mindspore/_akg/ops/math/logical_and.py b/mindspore/_akg/ops/math/logical_and.py deleted file mode 100644 index 480d4e1741..0000000000 --- a/mindspore/_akg/ops/math/logical_and.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: logical_and""" -import _akg.tvm -import _akg.topi -from _akg.utils import validation_check as vc_util - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def logical_and(input1, input2): - """ - Compute logical_and of input1 and input2. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - input2 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. LogicalAnd of input1 and input2. - """ - - vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) - - shape1 = [x.value for x in input1.shape] - shape2 = [x.value for x in input2.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - - res = _akg.topi.logical_and(input1, input2) - return res diff --git a/mindspore/_akg/ops/math/logical_not.py b/mindspore/_akg/ops/math/logical_not.py deleted file mode 100644 index 9befe7e816..0000000000 --- a/mindspore/_akg/ops/math/logical_not.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: logical_not""" -import _akg.tvm -import _akg.topi -from _akg.utils import validation_check as vc_util - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor) -def logical_not(input1): - """ - Compute logical_not of input1. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. - """ - res = _akg.topi.logical_not(input1) - return res diff --git a/mindspore/_akg/ops/math/logical_or.py b/mindspore/_akg/ops/math/logical_or.py deleted file mode 100644 index 8fb0b80567..0000000000 --- a/mindspore/_akg/ops/math/logical_or.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: logical_or""" -import _akg.tvm -import _akg.topi -from _akg.utils import validation_check as vc_util - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def logical_or(input1, input2): - """ - Compute logical_or of input1 and input2. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - input2 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. LogicalOr of input1 and input2. - """ - - vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) - - shape1 = [x.value for x in input1.shape] - shape2 = [x.value for x in input2.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - - res = _akg.topi.logical_or(input1, input2) - return res diff --git a/mindspore/_akg/ops/math/mean.py b/mindspore/_akg/ops/math/mean.py deleted file mode 100644 index e8300f22fc..0000000000 --- a/mindspore/_akg/ops/math/mean.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: mean""" -import _akg.topi -import _akg.tvm -from _akg.utils import format_transform as ft_util -from _akg.utils import validation_check as vc_util -from _akg.ops.math import sum_value - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None))) -def mean(data, axis=None, keepdims=False): - """ - Computes the mean of the values of a Tensor over the whole dataset. - - Args: - data (tvm.tensor.Tensor): Tensor. - axis (Union[list, tuple, int, None]): If the tuple is empty, the axis equal to None. - keepdims (bool): If keepdims equal to True, the result shape length is same to input shape length. - - Returns: - tvm.tensor.Tensor, has the same type as data. If keepdims equal to True, all reduced dimensions are - retained with length 1. else these reduced axis will be eliminate. - """ - shape = [x.value for x in data.shape] - vc_util.reduce_axis_check(shape, axis) - axis = ft_util.refine_reduce_axis(data, axis) - - count = 1 - for i in axis: - count *= shape[i] - output, _ = sum_value.sum_value(data, axis, keepdims) - res = _akg.topi.divide(output, count) - - return res diff --git a/mindspore/_akg/ops/math/mul.py b/mindspore/_akg/ops/math/mul.py deleted file mode 100644 index a690089da2..0000000000 --- a/mindspore/_akg/ops/math/mul.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: mul""" -import _akg.topi -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def mul(l_input, r_input): - """ - Calculate x * y element-wise. - - Note: - mul supports broadcasting. - - Args: - l_input (tvm.tensor.Tensor): Tensor. - r_input (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor, has the same type as l_input and r_input. - """ - shape1 = [x.value for x in l_input.shape] - shape2 = [x.value for x in r_input.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - vc_util.auto_broadcast_check(shape1, shape2) - vc_util.elemwise_dtype_check(l_input.dtype, r_input.dtype) - output = _akg.topi.multiply(l_input, r_input) - - return output diff --git a/mindspore/_akg/ops/math/notequal.py b/mindspore/_akg/ops/math/notequal.py deleted file mode 100644 index 16d5e4a0f4..0000000000 --- a/mindspore/_akg/ops/math/notequal.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: notequal""" -import _akg.tvm -import _akg.topi -from _akg.utils.dsl_create import produce_shapes -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def notequal(input1, input2): - """ - check whether input1 notequals to input2. - - Args: - input1 (tvm.tensor.Tensor): Tensor. - input2 (tvm.tensor.Tensor): Tensor. - - Returns: - tvm.tensor.Tensor. If input1 notequal to input2 return True, else return False. - """ - shape1 = [x.value for x in input1.shape] - shape2 = [x.value for x in input2.shape] - vc_util.check_shape(shape1) - vc_util.check_shape(shape2) - - shape1, shape2, shape = produce_shapes(shape1, shape2) - - vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) - dtype = input1.dtype - - # get notequal compute - t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") - f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") - - input1_bro = _akg.topi.broadcast_to(input1, shape) - input2_bro = _akg.topi.broadcast_to(input2, shape) - c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] != input2_bro[indice], - t_value[indice], f_value[indice]), name="C") - res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") - - return res diff --git a/mindspore/_akg/ops/math/sub.py b/mindspore/_akg/ops/math/sub.py deleted file mode 100644 index 6ae2ee51ef..0000000000 --- a/mindspore/_akg/ops/math/sub.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: sub""" -import _akg.topi -import _akg.tvm -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) -def sub(data1, data2): - """ - Computes data1 - data2 elementwise, broadcast is supported. - - Args: - data1 (tvm.tensor.Tensor): Tensor. - data2 (tvm.tensor.Tensor): Tensor of same type as data1, if shape(data2) != shape(data1), broadcast will happen. - - Returns: - tvm.tensor.Tensor, subtracted result, with same type as input tensors and broadcasted shape of data1 and data2. - """ - vc_util.elemwise_dtype_check(data1.dtype, data2.dtype) - vc_util.check_shape(data1.shape) - vc_util.check_shape(data2.shape) - vc_util.auto_broadcast_check(data1.shape, data2.shape) - - res = _akg.topi.subtract(data1, data2) - - return res diff --git a/mindspore/_akg/ops/math/sum_value.py b/mindspore/_akg/ops/math/sum_value.py deleted file mode 100644 index b9720469a6..0000000000 --- a/mindspore/_akg/ops/math/sum_value.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""operator dsl function: sum""" - -import _akg.topi -import _akg.tvm -from _akg.utils import format_transform as ft_util -from _akg.utils import validation_check as vc_util - - -@vc_util.check_input_type(_akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None))) -def sum_value(inputs, axis=None, keepdims=False): - """ - Compute the sum of elements across dimensions of a tensor. - - Args: - inputs (tvm.tensor.Tensor): Tensor. - axis (Union[list, tuple, int, None]): If the list or tuple is empty, the axis equal to None. - keepdims (bool): If keepdims equal to True, the result shape length is same to input shape length. - - Returns: - tvm.tensor.Tensor, has same type as input. If keepdims is True, all reduced dimensions are retained - with length 1, else these reduced axis will be eliminate. - """ - axis = ft_util.refine_reduce_axis(inputs, axis) - vc_util.check_shape(inputs.shape) - - if not axis: - output = _akg.topi.identity(inputs) - else: - output = _akg.topi.sum(inputs, axis=axis, keepdims=keepdims) - - return output diff --git a/mindspore/_akg/save_gpu_param.py b/mindspore/_akg/save_gpu_param.py deleted file mode 100644 index ed2c9fe23a..0000000000 --- a/mindspore/_akg/save_gpu_param.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""save gpu param""" -import os -import hashlib -import _akg.tvm -from _akg.tvm import schedule -from _akg.utils import validation_check as vc_util - - -def get_dim(dim, axis=True): - """get dim info""" - dims_str = { - "grid_dim0": "// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = ", - "grid_dim1": "// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = ", - "grid_dim2": "// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = ", - "block_dim0": "// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = ", - "block_dim1": "// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = ", - "block_dim2": "// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = " - } - dim_to_axis = { - "grid_dim0": '"blockIdx.x" : ', - "grid_dim1": '"blockIdx.y" : ', - "grid_dim2": '"blockIdx.z" : ', - "block_dim0": '"threadIdx.x" : ', - "block_dim1": '"threadIdx.y" : ', - "block_dim2": '"threadIdx.z" : ' - } - if axis: - return dim_to_axis.get(dim) - return dims_str.get(dim) - - -def parse_params(file, dim, ir): - """parse parameters""" - dim_str = get_dim(dim, axis=False) - pos = ir.find(dim_str) - if pos != -1: - index = pos + len(dim_str) - param_temp = get_dim(dim) - - while ir[index].isdigit(): - param_temp += ir[index] - index += 1 - file.write(param_temp + ",\n") - else: - param_temp = get_dim(dim) + '1' - file.write(param_temp + ",\n") - - -@vc_util.check_input_type(schedule.Schedule, (list, tuple), tuple) -def save_gpu_params(s, args, kernel_info): - """save gpu parameters""" - ptx_code = kernel_info[0] - file_name = kernel_info[1] - kernel_name = kernel_info[2] - ir = str(_akg.tvm.lower(s, args, simple_mode=True)) - file_path = os.path.realpath(file_name) - if os.path.exists(file_path): - os.remove(file_path) - - sha256 = hashlib.sha256() - sha256.update(ptx_code.encode("utf-8")) - hash_str = sha256.hexdigest() - with os.fdopen(os.open(file_path, os.O_WRONLY | os.O_CREAT, 0o400), 'w') as fo: - fo.write("{\n") - fo.write('"kernelName" : ' + '"' + kernel_name + "_kernel0" + '",\n') - parse_params(fo, "grid_dim0", ir) - parse_params(fo, "grid_dim1", ir) - parse_params(fo, "grid_dim2", ir) - parse_params(fo, "block_dim0", ir) - parse_params(fo, "block_dim1", ir) - parse_params(fo, "block_dim2", ir) - fo.write('"sha256" : ' + '"' + hash_str + '"\n') - fo.write("}\n") diff --git a/mindspore/_akg/utils/dsl_create.py b/mindspore/_akg/utils/dsl_create.py deleted file mode 100644 index 9d27039b28..0000000000 --- a/mindspore/_akg/utils/dsl_create.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""dsl create helping function""" -import _akg -from _akg.utils import format_transform as ft_util - -class TensorUtils: - """Class for creating tensor.""" - CREATE_SCH_ONLY = 'create_sch_only' - - @classmethod - def get_tensor_attrs(cls, tensor): - """get tensor attrs.""" - tensor_attrs = dict() - if "attrs" in dir(tensor.op): - tensor_attrs = dict(tensor.op.attrs.items()) - return tensor_attrs - - @classmethod - def update_tensor_attrs(cls, tensor, attrs): - """update tensor attrs.""" - tensor_attrs = cls.get_tensor_attrs(tensor) - tensor_attrs.update(attrs) - tensor = _akg.tvm.compute(tensor.shape, - lambda *indice: tensor[indice], - name=tensor.op.name, - tag=tensor.op.tag, - attrs=tensor_attrs) - return tensor - - @classmethod - def is_create_sch_only(cls, tensor): - tensor_attrs = cls.get_tensor_attrs(tensor) - if cls.CREATE_SCH_ONLY in tensor_attrs.keys(): - return True - return False - - @classmethod - def is_output_value(cls, tensor): - """check output value.""" - return not cls.is_create_sch_only(tensor) - - @classmethod - def inplace_set(cls, input_tensor, output_tensor, buffer_name="data_buf"): - """inplace set.""" - input_tensor_shape = ft_util.get_shape(input_tensor) - output_tensor_shape = ft_util.get_shape(output_tensor) - if not input_tensor_shape == output_tensor_shape: - raise RuntimeError("Shape of the input_tensor and the output_tensor should be equal, " - "but got %s and %s"%(input_tensor_shape, output_tensor_shape)) - output_tensor = cls.update_tensor_attrs(output_tensor, {cls.CREATE_SCH_ONLY: 1}) - data_buf = _akg.tvm.decl_buffer(input_tensor.shape, input_tensor.dtype, name=buffer_name) - binds_info = {input_tensor: data_buf, output_tensor: data_buf} - return output_tensor, binds_info - - @classmethod - def inplace_set_tensors(cls, input_tensors, output_tensors, buffer_names=None): - """ - inplace set for tensors - - Args: - in_tensors (Union[list, tuple]): Origin input tensors. - out_tensors (Union[list, tuple]): Origin output tensors. - buffer_names (Union[list, tuple] or None): Buffer names used to bind. - - Return: - inplace_tensors (list): Output tensors with the inplace info. - binds_infos (dict): Dictionary that maps the input tensor and the output - tensor to buffer. - """ - if not buffer_names: - buffer_names = ["data_buf_%s" % i for i in range(len(input_tensors))] - for arg in (input_tensors, output_tensors, buffer_names): - if not isinstance(arg, (tuple, list)): - raise RuntimeError("arg must be tuple or list!") - if len(input_tensors) != len(output_tensors) or len(input_tensors) != len(buffer_names): - raise RuntimeError("length of the input_tensors, output_tensors and buffer_names must be equal!") - - inplace_tensors = [] - binds_infos = dict() - for input_tensor, output_tensor, buffer_name in zip(input_tensors, output_tensors, buffer_names): - inplace_tensor, binds_info = cls.inplace_set(input_tensor, output_tensor, buffer_name) - inplace_tensors.append(inplace_tensor) - binds_infos.update(binds_info) - return inplace_tensors, binds_infos - -def produce_shapes(shape1, shape2): - """two input shapes produce three output shape.""" - shape1 = list(shape1) - shape2 = list(shape2) - flag = 0 - if len(shape1) < len(shape2): - shape1, shape2 = shape2, shape1 - flag = 1 - - output_shape_len = len(shape1) - dec = output_shape_len - len(shape2) - for i in range(dec): - shape2 = [1] + shape2 - - out_shape = [] - for i in range(output_shape_len): - if (shape1[i] != shape2[i]) and (shape1[i] != 1) and (shape2[i] != 1): - raise RuntimeError("input shapes not match!") - out_shape.append(shape1[i] if shape1[i] > shape2[i] else shape2[i]) - - if flag == 1: - shape1, shape2 = shape2, shape1 - - return shape1, shape2, out_shape diff --git a/mindspore/_akg/utils/format_transform.py b/mindspore/_akg/utils/format_transform.py deleted file mode 100644 index c7a69b26cc..0000000000 --- a/mindspore/_akg/utils/format_transform.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""format transform function""" -import _akg - -def refine_reduce_axis(input_content, axis): - """make reduce axis legal.""" - shape = get_shape(input_content) - if axis is None: - axis = [i for i in range(len(shape))] - elif isinstance(axis, int): - axis = [axis] - elif not isinstance(axis, (tuple, list)): - raise TypeError("axis must be one of the type int,tuple,list or None") - - if len(axis) > len(shape): - raise ValueError("axis size must not larger than shape size") - - axis = list(axis) - - for i, _ in enumerate(axis): - if axis[i] < 0: - axis[i] += len(shape) - - if axis[i] >= len(shape): - raise ValueError(("axis value-{} exceeds len(axis) which is invalid".format(axis[i]))) - - axis.sort(reverse=True) - - return axis - - -def get_shape_from_tensor(data): - """translate _akg.tvm.shape to list type in python.""" - tvm_shape = data.shape - py_shape = [] - for i in tvm_shape: - if isinstance(i, _akg.tvm.expr.Var): - py_shape.append(i) - else: - py_shape.append(i.value) - return py_shape - - -def tvm_shape_to_list(tvm_shape): - """translate _akg.tvm.shape to list type in python.""" - py_shape = [] - for i in tvm_shape: - if isinstance(i, _akg.tvm.expr.Var): - py_shape.append(i) - else: - py_shape.append(i.value) - return py_shape - - -def get_shape(data): - """get shape and save it as list.""" - if isinstance(data, _akg.tvm.tensor.Tensor): - shape = get_shape_from_tensor(data) - elif isinstance(data, _akg.tvm.container.Array): - shape = tvm_shape_to_list(data) - elif isinstance(data, int): - shape = [data] - elif isinstance(data, (tuple, list)): - shape = list(data) - else: - raise TypeError("Refine axis does not support type {} for now.".format(type(data))) - return shape diff --git a/mindspore/_akg/utils/validation_check.py b/mindspore/_akg/utils/validation_check.py deleted file mode 100644 index 1231b3110e..0000000000 --- a/mindspore/_akg/utils/validation_check.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""validation check functions""" -from functools import wraps, reduce -from _akg.utils.format_transform import get_shape - -MAX_DATA_SIZE = 2 ** 31 - -def check_input_type_dict(input_dict, input_key, input_name): - """ - check input parameter type for new type: dict. - - Note: - rule1: key of input_dict should be in the input_key - rule2: type of input_dict[shape] should be in (list, tuple), if have shape - rule3: type of input_dict[dtype] should be in (str), if have dtype - - Args: - input_dict (dict): input_dict - input_key (list or tuple): all input key list, the key of input must in input_key - input_name (str): input param name, only used for error print - - Returns: - None - """ - def _check_input_type(input_key, input_type): - if not isinstance(input_dict[input_key], input_type): - raise RuntimeError( - "the input parameter %s[%s] must be %s, while type of input is %s" % - (input_name, input_key, input_type, type(input_dict[input_key]))) - - for key in input_dict.keys(): - if key not in input_key: - raise RuntimeError( - "the input parameter %s must have arrt <%s>" % - (input_name, key)) - - # check shape's type of input_dict, if have shape - if key == "shape": - _check_input_type(key, (list, tuple)) - - # check dtype's type of input_dict, if have dtype - if key == "dtype": - _check_input_type(key, (str,)) - - -def check_input_type_list_tuple(inputs, expect): - """check inputs by a list or tuple of expected types.""" - if not isinstance(inputs, expect[1][0]): - raise RuntimeError("the input parameter %s must be (list, tuple), while" - " type of input is %s" % (expect[0], type(inputs))) - for inp in inputs: - if not isinstance(inp, expect[1][1]): - raise RuntimeError("The element in parameter %s must be %s, while " - "type of input is %s" % ( - expect[0], expect[1][1], type(inp))) - - -def check_input_type(*type_args, **_type_kwargs): - """check input parameter type.""" - def out_wrapper(func): - """outer wrapper function.""" - formal_parameter = func.__code__.co_varnames - formal_parameter_list = list(zip(formal_parameter, type_args)) - - @wraps(func) - def in_wrapper(*args, **kwargs): - """inner wrapper function.""" - for i, arg_v in enumerate(args): - # add for new input dict, if dict, will check shape and dtype - if isinstance(arg_v, dict): - check_input_type_dict(arg_v, arg_v.keys(), - formal_parameter_list[i][0]) - - if isinstance(formal_parameter_list[i][1], tuple): - if isinstance(formal_parameter_list[i][1][0], tuple) \ - and len(formal_parameter_list[i][1]) == 2: - check_input_type_list_tuple(arg_v, formal_parameter_list[i]) - continue - - if not isinstance(arg_v, formal_parameter_list[i][1]): - raise RuntimeError("the %sth input parameter %s must be %s, " - "while type of input is %s" % (str(i), formal_parameter_list[i][0], - formal_parameter_list[i][1], - type(arg_v))) - for i in kwargs: - for j in formal_parameter_list: - if i in j: - if not isinstance(kwargs[i], j[1]): - raise RuntimeError("the input parameter %s must be " - "%s, while type of input is %s" - "" % (i, j[1], type(kwargs[i]))) - break - return func(*args, **kwargs) - - return in_wrapper - - return out_wrapper - - -def shape_dtype_max_size_check(shape): - """check validation of tensor's shape.""" - if shape: - mul = int(reduce(lambda x, y: int(x) * int(y), shape)) - if mul > MAX_DATA_SIZE: - error_msg = "*".join([str(sh) for sh in shape]) - raise RuntimeError("Invalid shape, data is {} bytes ({}), which " - "exceed max data size {} bytes" - .format(mul, error_msg, MAX_DATA_SIZE)) - - -def check_shape(tensor, length=None, tensor_name=""): - """The common check rule for placeholder data.""" - shape = get_shape(tensor) - if not shape: - raise RuntimeError("The ndim of input tensor {} must more than 0, " - "actual input is {}".format(tensor_name, len(shape))) - - for shape_v in shape: - if not isinstance(shape_v, int) or shape_v <= 0: - raise RuntimeError("The type of tensor {} axis value must be " - "positive int and value more than 0," - "actual input is ({}) {}". - format(tensor_name, type(shape_v), shape_v)) - - if length and len(shape) != length: - raise ValueError('The length of {} should be {}, while actual length is {}'. - format(tensor_name, length, len(shape))) - - -def ops_dtype_check(dtype, args): - """check validation of op's dtype.""" - expected_dtype = list() - - def _get_expect_dtype(expected_dtype, arg): - if isinstance(arg, str): - expected_dtype.append(arg) - elif isinstance(arg, (list, tuple)): - for t in arg: - _get_expect_dtype(expected_dtype, t) - else: - raise TypeError("arg should be either a string, " - "or a list/tuple of string, " - "while current is {}".format(type(arg))) - - _get_expect_dtype(expected_dtype, args) - - if isinstance(dtype, (list, tuple)): - checking_dtype = [d.lower() for d in dtype] - elif isinstance(dtype, str): - checking_dtype = [dtype.lower()] - else: - raise TypeError("dtype should be either a string or a tuple/list of string") - error_msg = "Supported dtype: {}, while received dtype: {}" - if not set(checking_dtype).issubset(set(expected_dtype)): - raise RuntimeError(error_msg.format(expected_dtype, checking_dtype)) - - -def reduce_axis_check(reduce_shape, reduce_axis): - """check validation of reduce axis for certain reduce shape.""" - dim = len(reduce_shape) - if dim == 1 and int(reduce_shape[0]) == 1: - raise RuntimeError("Error, reduce shape is 1. Scalar is not supported " - "for reduction, please input a vector.") - if isinstance(reduce_axis, int): - if reduce_axis not in range(-dim, dim): - raise RuntimeError("Reduce axis should be in range [%d. %d)" - "" % (-dim, dim)) - elif isinstance(reduce_axis, (tuple, list)): - if len(reduce_axis) > len(reduce_shape): - raise RuntimeError("Reduce axis list exceed reduce shape length: " - "%d vs %d, error" % (len(reduce_axis), len(reduce_shape))) - processed_axis = [] - for axis in reduce_axis: - processed_axis.append(int(axis + dim) if axis < 0 else int(axis)) - if len(set(processed_axis)) < len(processed_axis): - raise RuntimeError("Reduce axis list contains %d duplicated element, please check" - % (len(processed_axis) - len(set(processed_axis)))) - for axis in processed_axis: - if axis >= dim: - raise RuntimeError("Invalid reduce axis, axis should less than %d" % dim) - elif reduce_axis is not None: - raise RuntimeError("axis should be a list, tuple or int.") - - -def elemwise_dtype_check(dtype_a, dtype_b, supported_type=None): - """check validation of tensor's dtype for element-wise op.""" - if supported_type: - ops_dtype_check(dtype_a, supported_type) - ops_dtype_check(dtype_b, supported_type) - if dtype_a.lower() != dtype_b.lower(): - raise RuntimeError("Element-wise operation needs same data type, while " - "current is %s vs %s" % (dtype_a.lower(), dtype_b.lower())) - - -def auto_broadcast_check(shape_a, shape_b): - """automatic broadcast check.""" - shape_l = get_shape(shape_a) - shape_r = get_shape(shape_b) - - if len(shape_l) <= len(shape_r): - shape_short = shape_l - shape_long = shape_r - else: - shape_short = shape_r - shape_long = shape_l - - dim_diff = len(shape_long) - len(shape_short) - for i in range(dim_diff): - shape_short.insert(0, 1) - for i, shp in enumerate(shape_short): - if int(shp) != int(shape_long[i]) and 1 not in [int(shp), int(shape_long[i])]: - raise RuntimeError("Invalid auto broadcast, dim %d should be 1 or equal, " - "while now is %d vs %d" % (i, shp, shape_long[i])) - - -def check_int_list(array, array_name): - """check whether all the elements are integers.""" - for num in array: - if not isinstance(num, int): - raise RuntimeError("Type of value in %s should be int, but got type %s" % (array_name, type(num))) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index d5ac7c3e33..801e3ac554 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -672,7 +672,7 @@ def check_input_data(*data, data_class): def check_output_data(data): """Output data check.""" - if not data: + if data is None: raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.') diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 6bd382c1b6..1eade2d86d 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -14,7 +14,10 @@ # ============================================================================ """builtin_operations""" import numpy as np +from mindspore.ops import functional as F +from mindspore.ops import composite as C from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype @@ -113,6 +116,7 @@ def bool_or(x, y): """Implement `bool_or`.""" return x or y + def vm_compare(*args): """Implement `vm_compare` for tensor.""" obj_str = args[-1] @@ -141,10 +145,12 @@ def list_len(x): """Implement `list_len`.""" return len(x) + def Depend(value, expr): """Implement `Depend`.""" return value + # only used in PyNative mode def make_ref(key, value, ref): return value @@ -171,3 +177,16 @@ def tuple_to_array(x): def stop_gradient(x): """Implement `stop_gradient`.""" return x + + +hyper_map = C.HyperMap() + + +def mixed_precision_cast(dst_type, x): + """Implement `mixed_precision_cast`.""" + def cast_inner(data): + if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): + return F.cast(data, dst_type) + return data + + return hyper_map(cast_inner, x) diff --git a/mindspore/_extends/parallel_compile/akg_compiler/__init__.py b/mindspore/_extends/parallel_compile/akg_compiler/__init__.py index e30774307c..c336f0dafc 100644 --- a/mindspore/_extends/parallel_compile/akg_compiler/__init__.py +++ b/mindspore/_extends/parallel_compile/akg_compiler/__init__.py @@ -12,3 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +""" +Extension functions. + +Python functions that will be called in the c++ parts of MindSpore. +""" diff --git a/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py b/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py new file mode 100644 index 0000000000..757008a022 --- /dev/null +++ b/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""akg process""" +import os +import subprocess +import sys +from multiprocessing import Pool, cpu_count + +def _compile_akg_task(*json_strs): + """ + compile func called in single process + + Parameters: + json_strs: list. List contains multiple kernel infos, suitable for json compile api. + """ + akg_compiler = os.path.join(os.path.split( + os.path.realpath(__file__))[0], "compiler.py") + for json_str in json_strs: + res = subprocess.run( + [sys.executable, akg_compiler, json_str], text=True) + if res.returncode != 0: + raise ValueError("Failed, args: {}!".format(json_str)) + +def create_akg_parallel_process(process_num, wait_time): + """ + create AkgParallelCompiler object + + Returns: + AkgParallelCompiler + """ + return AkgProcess(process_num, wait_time) + +class AkgProcess: + """akg kernel parallel process""" + + def __init__(self, process_num, wait_time): + """ + Args: + process_num: int. processes number + waittime: int. max time the function blocked + """ + if not isinstance(process_num, int): + raise ValueError("process number must be a num") + if not isinstance(wait_time, int): + raise ValueError("wait time must be a num") + if process_num == 0: + process_num = 1 + max_proc_num = 16 + self.process_num = min([cpu_count(), max_proc_num, process_num]) + self.args = [[] for _ in range(self.process_num)] + self.wait_time = wait_time + self.argc = 0 + + def compile(self): + """ + compile kernel by multi processes + Return: + True for all compile success, False for some failed. + """ + if self.argc == 0: + raise ValueError("json must be not null") + with Pool(processes=self.process_num) as pool: + res = pool.starmap_async(_compile_akg_task, self.args) + res.get(timeout=self.wait_time) + return True + + def accept_json(self, json): + """ + accept json data before compile + Args: + json: str. kernel info. + """ + if not isinstance(json, str): + raise ValueError("json must be a str") + self.args[self.argc % self.process_num].append(json) + self.argc += 1 diff --git a/mindspore/_extends/parallel_compile/akg_compiler/multi_process_compiler.py b/mindspore/_extends/parallel_compile/akg_compiler/multi_process_compiler.py deleted file mode 100644 index ffe9c85dc3..0000000000 --- a/mindspore/_extends/parallel_compile/akg_compiler/multi_process_compiler.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Providing multi process compile with json""" -import os -import subprocess -import sys -from multiprocessing import Pool, cpu_count - - -def _compile_akg_task(*json_strs): - """ - compile func called in single process - - Parameters: - json_strs: list. List contains multiple kernel infos, suitable for json compile api. - """ - akg_compiler = os.path.join(os.path.split( - os.path.realpath(__file__))[0], "compiler.py") - for json_str in json_strs: - res = subprocess.run( - [sys.executable, akg_compiler, json_str], text=True) - if res.returncode != 0: - raise ValueError("Failed, args: {}!".format(json_str)) - - -def compile_akg_kernel_parallel(json_infos, process, waitime): - """ - compile kernel use multi processes - - Parameters: - json_infos: list. list contain kernel info(task id and json str) - process: int. processes num - waittime: int. max time the function blocked - - Returns: - True for all compile success, False for some failed. - """ - if not isinstance(json_infos, list): - raise ValueError("json_infos must be a list") - if not isinstance(process, int): - raise ValueError("process must be a num") - if not isinstance(waitime, int): - raise ValueError("waittime must be a num") - - if process == 0 and json_infos: - process = 1 - - cpu_proc_num = cpu_count() - max_proc_num = 16 - process = min([cpu_proc_num, max_proc_num, process]) - - args = [[] for _ in range(process)] - for p, info in enumerate(json_infos): - args[p % process].append(info) - - with Pool(processes=process) as pool: - res = pool.starmap_async(_compile_akg_task, args) - res.get(timeout=waitime) - return True diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py index 80b50c45a9..12bdd8ea38 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py @@ -22,14 +22,14 @@ import json from .common import check_kernel_info, TBEException from .helper import _op_select_format, _check_supported -def create_tbe_parallel_compiler(): +def create_tbe_parallel_process(): """ create TBEParallelCompiler object Returns: TBEParallelCompiler """ - return compile_pool + return tbe_process def op_select_format(op_json: str): """ @@ -98,8 +98,8 @@ def run_compiler(op_json): except subprocess.CalledProcessError as e: return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json -class CompilerPool: - """compiler pool""" +class TbeProcess: + """tbe process""" def __init__(self): self.__processe_num = multiprocessing.cpu_count() @@ -168,5 +168,4 @@ class CompilerPool: if self.__running_tasks: self.__running_tasks.clear() - -compile_pool = CompilerPool() +tbe_process = TbeProcess() diff --git a/mindspore/_extends/parse/__init__.py b/mindspore/_extends/parse/__init__.py index 323932560a..10a991c1ed 100644 --- a/mindspore/_extends/parse/__init__.py +++ b/mindspore/_extends/parse/__init__.py @@ -21,12 +21,12 @@ from .parser import (Parser, create_obj_instance, generate_scope, get_class_member_namespace_symbol, create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_module_namespace, get_obj_type, get_object_key, - get_default_input, get_parse_method_of_class, get_scope_name, + get_parse_method_of_class, get_scope_name, is_class_member, parse_cb, resolve_symbol) from .serialize import * __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', - 'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member', + 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', diff --git a/mindspore/_extends/parse/namespace.py b/mindspore/_extends/parse/namespace.py index 8d8b6fd30e..f32abed284 100644 --- a/mindspore/_extends/parse/namespace.py +++ b/mindspore/_extends/parse/namespace.py @@ -99,12 +99,19 @@ class ClassMemberNamespace(Namespace): obj (Object): A python class object. """ def __init__(self, obj): + self.__class_member_namespace__ = True label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' super().__init__(label, obj) def __getitem__(self, name): d, = self.dicts + if name == "self": + return d + if name == "namespace": + return self try: - return getattr(d, name) + if hasattr(d, name): + return getattr(d, name) + return d.__dict__[name] except ValueError: raise UnboundLocalError(name) diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 9d715fdf53..695f449f83 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -70,6 +70,7 @@ parse_expr_statement_white_list = ( "append", ) + def create_slice_obj(start, end, step): """Create slice object""" return slice(start, end, step) @@ -201,17 +202,9 @@ def get_object_key(obj): if isinstance(obj, types.MethodType): method_instance = obj.__self__ instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance)) - obj_id = instance_id + obj_id + obj_id = instance_id + obj_id + str(obj.__hash__()) return obj_id, obj_key -def get_default_input(obj): - if hasattr(obj, '__parameter__'): - return obj.default_input - if isinstance(obj, tuple): - convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x - args = tuple(convert(x) for x in obj) - return args - return obj def is_class_member(node): """Check the attr is class member variable.""" @@ -224,10 +217,12 @@ def is_class_member(node): return True return False + def get_obj_id(obj): """Get the obj id.""" return str(id(obj)) + def get_obj_type(obj): """Get the obj type.""" obj_type = RESOLVE_TYPE_INVALID @@ -320,6 +315,7 @@ def get_dataclass_methods(cls): if isinstance(getattr(cls, name), (types.FunctionType,))} return methods + class Parser: """ Parser python code to ast tree. @@ -453,6 +449,28 @@ class Parser: logger.debug("ops info = %r", ops_info) return ops_info + def analyze_super(self, class_type_node, subclass_instance): + """Analyze super and return a class instance.""" + sub_class = type(subclass_instance) + if class_type_node is None: + return super(sub_class, subclass_instance) + if isinstance(class_type_node, ast.Name): + class_name = getattr(class_type_node, 'id') + elif isinstance(class_type_node, ast.Attribute): + class_name = getattr(class_type_node, 'attr') + else: + raise ValueError(f"When call 'super', the first arg should be a class type, " + f"but got {class_type_node.__class__.__name__}.") + + target_father_class = None + for class_element in sub_class.mro(): + if class_element.__name__ == class_name: + target_father_class = class_element + break + if target_father_class is None: + raise ValueError("When call 'super', the second arg should be an instance of first arg.") + return super(target_father_class, subclass_instance) + def get_location(self, node): """ Get location of node start and end line no. diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index e60b70efac..e39a1c266c 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -17,7 +17,7 @@ """Resources for ast tree parse.""" import ast import math -from mindspore import IndexedSlices +from mindspore import RowTensor, SparseTensor from mindspore.ops.composite import multitype_ops from mindspore.ops import functional as F, composite as C from . import standard_method as M @@ -117,6 +117,7 @@ convert_object_map = { T.zip: C.zip_operation, T.print: F.print_, T.enumerate: M.enumerate_, + T.isinstance: M.isinstance_, # custom define operation T.iter: M.ms_iter, @@ -139,5 +140,6 @@ convert_object_map = { math.tan: NO_IMPLEMENT, # user defined - IndexedSlices: F.make_indexed_slices, + RowTensor: F.make_row_tensor, + SparseTensor: F.make_sparse_tensor, } diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index d70c6edcf4..763a4da780 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -27,6 +27,42 @@ from ...ops.composite.base import _append __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] trans = P.Transpose() +shape_ = P.Shape() +dtype_ = P.DType() + + +def all_(x, axis=(), keep_dims=False): + """ + Check all array elements along a given axis evaluate to True. + + Args: + x (Tensor): A Tensor to be reduced. + axis (Union[None, int, tuple(int)): Dimensions of reduction. + keep_dims (bool): Whether to keep the reduced dimensions. + + Returns: + Tensor, has the same data type as x. + """ + + reduce_all = P.ReduceAll(keep_dims) + return reduce_all(x, axis) + + +def any_(x, axis=(), keep_dims=False): + """ + Check any array element along a given axis evaluate to True. + + Args: + x (Tensor): A Tensor to be reduced. + axis (Union[None, int, tuple(int)): Dimensions of reduction. + keep_dims (bool): Whether to keep the reduced dimensions. + + Returns: + Tensor, has the same data type as x. + """ + + reduce_any = P.ReduceAny(keep_dims) + return reduce_any(x, axis) def transpose(x): @@ -114,6 +150,12 @@ def enumerate_(x, start=0): return ret +def isinstance_(x, base_type): + """Determine whether x is an instance of base_type.""" + x_type = F.typeof(x) + return check_type_same(x_type, base_type) + + def while_cond(x): """For while condtion, if the condition is a tensor, the loop will not be unrolled""" if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): @@ -123,6 +165,14 @@ def while_cond(x): return x +@constexpr +def check_type_same(x_type, base_type): + """Check x_type is same as base_type.""" + if mstype.issubclass_(x_type, base_type): + return True + raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.") + + @constexpr def check_is_tuple_or_list(x, op_name, arg_name): """check whether x is list or tuple.""" @@ -146,7 +196,8 @@ def check_is_tensor_bool_cond(shp): """check if tensor is a bool condition""" if shp in ((), (1,)): return True - raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp) + raise ValueError("The truth value of an array with several elements is ambiguous.") + @constexpr def const_tensor_to_bool(x): @@ -154,13 +205,12 @@ def const_tensor_to_bool(x): if x is None: raise ValueError("Only constant tensor bool can be converted to bool") x = x.asnumpy() - if x.shape not in ((), (1,)): - raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape) if x.shape == (): - value = bool(x) - else: - value = bool(x[0]) - return value + return bool(x) + if x.shape == (1,): + return bool(x[0]) + raise ValueError("The truth value of an array with several elements is ambiguous.") + def tensor_bool(x): """tensor as conditon, if is constant, return immediate bool value""" diff --git a/mindspore/_extends/parse/trope.py b/mindspore/_extends/parse/trope.py index 28f3196975..674715ef59 100644 --- a/mindspore/_extends/parse/trope.py +++ b/mindspore/_extends/parse/trope.py @@ -27,7 +27,7 @@ from operator import ( # noqa # support system function call from builtins import ( # noqa - bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate + bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate, isinstance ) # support functools @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'matmul', 'getitem', 'setitem', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', - 'partial', 'print', 'enumerate', + 'partial', 'print', 'enumerate', 'isinstance', 'exp', 'log', 'sin', 'cos', 'tan'] diff --git a/mindspore/_extends/remote/__init__.py b/mindspore/_extends/remote/__init__.py new file mode 100644 index 0000000000..b786fcce9d --- /dev/null +++ b/mindspore/_extends/remote/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Server functions. + +Python functions that will be called in the c++ client part of MindSpore. +""" diff --git a/mindspore/_extends/remote/kernel_build_server.py b/mindspore/_extends/remote/kernel_build_server.py new file mode 100644 index 0000000000..48ad7b3e96 --- /dev/null +++ b/mindspore/_extends/remote/kernel_build_server.py @@ -0,0 +1,174 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""kernel build server""" +import os +import time + +class Messager: + '''Messager''' + + def __init__(self, fdin, fdout): + self.fdin = fdin + self.fdout = fdout + self.fin = os.fdopen(fdin, "r") + self.fout = os.fdopen(fdout, "w") + self.message = '' + + def __del__(self): + os.close(self.fdin) + os.close(self.fdout) + + def get_message(self): + """ + Get message from remote + + Returns: + message + """ + try: + # Not read by input() anymore + res = self.fin.readline() + if not res: + logger.debug('[TRACE]', "read nothing...") + self.exit() + if res[len(res) - 1] == '\n': + res = res[0:len(res)-1] + self.message = res + logger.debug('[IN]', self.message) + except EOFError: + self.exit() + finally: + pass + if self.message == '' or self.message == 'FINISH': + self.send_ack() + self.exit() + return self.message + + def send_res(self, res, keep_format=True): + """ + Send result to remote + + Args: + keep_format: True or False + """ + logger.debug('[OUT]', str(res)) + if keep_format: + res_str = str(res).replace('\n', '[LF]').replace('\r', '[CR]').replace(' ', '[SP]') + else: + res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '') + tag = '[~]' # The same as client kTAG + + # Not write by print(tag + res_str, flush=True) any more + try: + self.fout.write(tag + res_str + "\n") + self.fout.flush() + except BrokenPipeError as err: + logger.info('[TRACE]', 'Write, ' + str(err)) + self.exit() + finally: + pass + + def send_ack(self, success=True): + """ + Send ack to remote + + Args: + success: True or False + """ + if success: + self.send_res('ACK') + else: + self.send_res('ERR') + + def loop(self): + """ + Messaging loop + """ + while True: + self.handle() + + def run(self): + self.loop() + + def handle(self): + """ + A interface communicates with remote. + + Note: + All subclasses should override this interface. + """ + raise NotImplementedError + + def exit(self): + """ + A interface handles the procedure before exit. + + Note: + All subclasses should override this interface. + """ + raise NotImplementedError + +class Logger: + """ + Replace dummy 'logger' to output log as below: + logger = Logger(0, True, "remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log") + """ + def __init__(self, level=1, dumpfile=False, filename='Logger.log'): + """ + Args: + level: 0 for debug and info, 1 for info + dumpfile: if dump log into file + """ + self.level = level + self.dumpfile = dumpfile + if self.dumpfile: + self.log = open(filename, "a") + + def write(self, msg): + self.log.write(msg) + self.flush() + + def writeline(self, tag, msg): + prefix = tag + ' REMOTE(' + str(os.getpid()) + ',python)' + line = prefix + '\t' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ':\t' + msg + print(line, flush=True) + if self.dumpfile: + self.write(line + '\n') + + def debug(self, tag, msg): + if self.level == 0: + self.writeline('[DEBUG]' + tag, msg) + + def info(self, tag, msg): + self.writeline('[INFO]' + tag, msg) + + def flush(self): + self.log.flush() + +class DummyLogger: + """DummyLogger""" + def __init__(self): + pass + + def debug(self, tag, msg): + pass + + def info(self, tag, msg): + pass + +logger = DummyLogger() + +def get_logger(): + return logger diff --git a/mindspore/_extends/remote/kernel_build_server_ascend.py b/mindspore/_extends/remote/kernel_build_server_ascend.py new file mode 100644 index 0000000000..e77beda2f3 --- /dev/null +++ b/mindspore/_extends/remote/kernel_build_server_ascend.py @@ -0,0 +1,148 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""kernel build server for ascend""" +import sys +from mindspore._extends.remote.kernel_build_server import Messager, get_logger +from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported +from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process + +class TbeBuilder: + """Tbe building wrapper""" + + def __init__(self): + self.tbe_builder = create_tbe_parallel_process() + + def start(self, json): + return self.tbe_builder.start_compile_op(json) + + def wait(self): + return self.tbe_builder.wait_one() + + def reset(self): + self.tbe_builder.reset_task_info() + + def exit(self): + self.tbe_builder.exit() + +class AkgBuilder: + """Akg building wrapper""" + + def __init__(self): + pass + + def create(self, process_num, waitime): + self.akg_builder = create_akg_parallel_process(process_num, waitime) + + def accept_json(self, json): + return self.akg_builder.accept_json(json) + + def compile(self): + return self.akg_builder.compile() + +class AscendMessager(Messager): + ''' + Ascend Messager + It works as a server, communicating with c++ client. + ''' + + def __init__(self, fdin, fdout): + super().__init__(fdin, fdout) + get_logger().info('[TRACE]', 'Ascend Messager init...') + self.tbe_builder = TbeBuilder() + self.akg_builder = AkgBuilder() + + def handle(self): + """ + Communicate with remote client. + Reference protocol between them at PR#3821 and PR#3935 + """ + arg = self.get_message() + if arg == 'TBE/START': + self.send_ack() + json = self.get_message() + res = self.tbe_builder.start(json) + self.send_res(res) + elif arg == 'TBE/WAIT': + self.send_ack() + task_id, res, pre = self.tbe_builder.wait() + get_logger().debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre)) + if self.get_message() != 'CONTINUE': + self.send_ack(False) + self.exit() + self.send_res(task_id) + if self.get_message() != 'CONTINUE': + self.send_ack(False) + self.exit() + self.send_res(res) + if self.get_message() != 'CONTINUE': + self.send_ack(False) + self.exit() + self.send_res(pre) + elif arg == 'TBE/RESET': + self.tbe_builder.reset() + self.send_ack() + elif arg == 'AKG/START': + self.send_ack() + process_num_str = self.get_message() + self.send_ack() + wait_time_str = self.get_message() + self.akg_builder.create(int(process_num_str), int(wait_time_str)) + self.send_ack() + elif arg == 'AKG/DATA': + self.send_ack() + while True: + req = self.get_message() + if req.startswith('{'): + self.akg_builder.accept_json(req) + self.send_ack() + elif req == 'AKG/WAIT': + res = self.akg_builder.compile() + self.send_res(res) + break + else: + self.send_ack(False) + break + elif arg == 'FORMAT': + self.send_ack() + json = self.get_message() + self.send_res(op_select_format(json)) + elif arg == 'SUPPORT': + self.send_ack() + json = self.get_message() + get_logger().debug('[SUPPORT]', json) + try: + res = check_supported(json) + except json.decoder.JSONDecodeError: + self.send_ack(False) + self.exit() + finally: + pass + self.send_res(res) + else: + self.send_ack(False) + self.exit() + + def exit(self): + self.tbe_builder.reset() + self.tbe_builder.exit() + get_logger().info('[TRACE]', 'Ascend Messager Exit...') + exit() + +if __name__ == '__main__': + if len(sys.argv) != 3: + raise Exception('Incorrect argv: {}'.format(sys.argv)) + get_logger().debug('[TRACE]', 'argv: ' + str(sys.argv)) + messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2])) + messager.run() diff --git a/mindspore/_extends/remote/kernel_build_server_gpu.py b/mindspore/_extends/remote/kernel_build_server_gpu.py new file mode 100644 index 0000000000..8bdf5805af --- /dev/null +++ b/mindspore/_extends/remote/kernel_build_server_gpu.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""kernel build server for gpu""" +import os +import sys +from mindspore._extends.remote.kernel_build_server import Messager, get_logger +from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single + +class GpuMessager(Messager): + ''' + GPU Messager + It works as a server, communicating with c++ client. + ''' + + def __init__(self, fdin, fdout): + super().__init__(fdin, fdout) + get_logger().info('[TRACE]', 'GPU Messager init...') + + def handle(self): + """ + Communicate with remote client. + Reference protocol between them at PR#4063 + """ + arg = self.get_message() + if arg == 'AKG/PID': + self.send_res(os.getpid()) + elif arg == 'AKG/COMPILE': + self.send_ack() + json = self.get_message() + try: + akg_compile_single(json) + except ValueError: + self.send_ack(False) + self.exit() + finally: + pass + self.send_ack() + else: + self.send_ack(False) + self.exit() + + def exit(self): + get_logger().info('[TRACE]', 'GPU Messager Exit...') + exit() + +if __name__ == '__main__': + if len(sys.argv) != 3: + raise Exception('Incorrect argv: {}'.format(sys.argv)) + get_logger().debug('[TRACE]', 'argv: ' + str(sys.argv)) + messager = GpuMessager(int(sys.argv[1]), int(sys.argv[2])) + messager.run() diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index bb02f338f6..41d2c7e726 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -44,7 +44,7 @@ if(ENABLE_GPU) "backend/kernel_compiler/akg/akg_kernel_attrs_process.cc" ) - list(APPEND CUDA_NVCC_FLAGS -arch=sm_53) + list(APPEND CUDA_NVCC_FLAGS -arch=sm_53 --expt-relaxed-constexpr) list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/blocking_queue.cc" "runtime/device/gpu/gpu_buffer_mgr.cc") list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/mpi/mpi_initializer.cc" "runtime/device/gpu/distribution/collective_wrapper.cc" @@ -60,11 +60,6 @@ if(ENABLE_GPU) add_compile_definitions(ENABLE_GPU) endif () -## make flatuffer files -include_directories("${CMAKE_BINARY_DIR}/predict/schema/inner") -file(GLOB_RECURSE FLATBUFFER_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "predict/schema/*.fbs") -set(FLATBUFFER_OU "${CMAKE_BINARY_DIR}/predict/schema/inner") -ms_build_flatbuffers("${FLATBUFFER_IN}" "${FLATBUFFER_IN}" flat_input "${FLATBUFFER_OU}") ## make protobuf files file(COPY "${ms_onnx_INC}/onnx/onnx.proto" DESTINATION ${CMAKE_BINARY_DIR}/proto) @@ -104,13 +99,9 @@ endif () if (ENABLE_D) include_directories("${CMAKE_BINARY_DIR}/backend/kernel_compiler/aicpu") - include_directories("${CMAKE_BINARY_DIR}/predict/generator/ir") file(GLOB_RECURSE PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "backend/kernel_compiler/aicpu/proto/*.proto") ms_protobuf_generate(PROTOSRCS PROTOHDRS ${PROTO_IN}) - file(GLOB_RECURSE PROTO_INNER RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "predict/proto/*.proto") - ms_protobuf_generate(PREDICT_PROTOSRCS PREDICT_PROTOHDRS ${PROTO_INNER}) - file(GLOB_RECURSE PROTO_DUMP RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "runtime/device/ascend/dump/proto/*.proto") ms_protobuf_generate(DUMP_PROTOSRCS PROTOHDRS ${PROTO_DUMP}) @@ -139,7 +130,7 @@ set(SUB_COMP frontend/operator pipeline/jit pipeline/pynative - common debug gvar predict pybind_api utils vm + common debug gvar pybind_api utils vm ) foreach (_comp ${SUB_COMP}) @@ -147,16 +138,18 @@ foreach (_comp ${SUB_COMP}) string(REPLACE "/" "_" sub ${_comp}) if (TARGET _mindspore_${sub}_obj) list(APPEND SUB_OBJECTS_SRC $) - add_dependencies(_mindspore_${sub}_obj proto_input flat_input) + add_dependencies(_mindspore_${sub}_obj proto_input ) endif () endforeach () add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/base base) list(APPEND SUB_OBJECTS_SRC $) add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/abstract abstract) list(APPEND SUB_OBJECTS_SRC $) +add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/utils util) +list(APPEND SUB_OBJECTS_SRC $) add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir) list(APPEND SUB_OBJECTS_SRC $) -add_dependencies(_mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input flat_input) +add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input ) set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME) add_library(mindspore STATIC ${SUB_OBJECTS_SRC}) @@ -169,7 +162,7 @@ if (ENABLE_DEBUGGER) endif() target_link_libraries(mindspore proto_input) -if (ENABLE_MPI) +if (ENABLE_MPI AND ENABLE_CPU) target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter) else () target_link_libraries(mindspore securec mindspore::flatbuffers) @@ -252,15 +245,15 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows") target_link_libraries(mindspore mindspore_gvar) target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) else () - target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) - target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) - target_link_libraries(_c_expression PRIVATE mindspore_gvar) - if (NOT ENABLE_GE) - target_link_libraries(_c_expression PRIVATE mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) + if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) + target_link_libraries(mindspore mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) if (${ENABLE_IBVERBS} STREQUAL "ON") - target_link_libraries(_c_expression PRIVATE ibverbs rdmacm) + target_link_libraries(mindspore ibverbs rdmacm) endif() endif() + target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) + target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) + target_link_libraries(_c_expression PRIVATE mindspore_gvar) endif () if (USE_GLOG) @@ -278,7 +271,11 @@ if (ENABLE_GPU) ${CUDA_PATH}/lib64/libcurand.so ${CUDNN_PATH}/lib64/libcudnn.so ${CUDA_PATH}/lib64/libcudart.so - ${CUDA_PATH}/lib64/stubs/libcuda.so) + ${CUDA_PATH}/lib64/stubs/libcuda.so + ${CUDA_PATH}/lib64/libcusolver.so) + if (ENABLE_MPI) + set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH}) + endif() endif () if (ENABLE_CPU) @@ -296,7 +293,7 @@ set(LOAD_ONNX_SRC ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc ) add_library(inference SHARED - ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc ${LOAD_ONNX_SRC} ) target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt index b412d83d11..4de32578a5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt +++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt @@ -26,14 +26,6 @@ if (ENABLE_CPU) "cpu/*.cc" ) - list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc" - "cpu/ps/pull_kernel.cc" - "cpu/ps/embedding_look_up_ps_kernel.cc" - "cpu/ps/embedding_look_up_proxy_kernel.cc" - "cpu/ps/apply_momentum_ps_kernel.cc" - "cpu/ps/sparse_apply_adam_ps_kernel.cc" - "cpu/ps/sparse_apply_ftrl_ps_kernel.cc") - if (NOT ENABLE_MPI) list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc") @@ -41,6 +33,18 @@ if (ENABLE_CPU) endif () endif () +if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/pserver_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/pull_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_adam_ps_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_ftrl_ps_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc") +endif() + if (ENABLE_GPU) file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cu" diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc index 7e7fd20f39..1e855f8fc0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc @@ -30,7 +30,7 @@ #include "proto/attr.pb.h" #include "proto/node_def.pb.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/aicpu/aicpu_util.h" #include "backend/session/kernel_graph.h" #include "backend/kernel_compiler/common_utils.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h index 6e2ee3959b..a1b5b4b84e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_BUILD_H_ #include #include "backend/kernel_compiler/kernel.h" @@ -24,4 +24,4 @@ KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h index e21f4eace4..d82fa6b02b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_META_DATA_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_META_DATA_H_ #include #include @@ -27,4 +27,4 @@ namespace kernel { void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc index e18b3169f3..b2f992fc82 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -26,7 +26,7 @@ #include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" #include "utils/convert_utils.h" #include "backend/kernel_compiler/aicpu/aicpu_util.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" using AicpuTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h index 82260010ea..9bc75d1110 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_MOD_H_ #include #include #include @@ -72,4 +72,4 @@ using AicputOpKernelModPtrList = std::vector; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h index fd4495afeb..d68aef3f86 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_UTIL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_UTIL_H_ #include #include @@ -61,4 +61,4 @@ class AicpuOpUtil { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_UTIL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc index 73fdb5c11b..e4d0a6c00a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc @@ -18,6 +18,7 @@ #include #include "backend/session/anf_runtime_algorithm.h" #include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -75,15 +76,7 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { std::string dst_type; TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - if (output_type == kFloat32->type_id()) { - dst_type = "float32"; - } else if (output_type == kFloat16->type_id()) { - dst_type = "float16"; - } else if (output_type == kInt32->type_id()) { - dst_type = "int32"; - } else { - MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); - } + dst_type = TypeId2String(output_type); AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h index 9ba724db42..2db84631ef 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H -#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H #include #include @@ -55,4 +55,4 @@ const std::unordered_map #include #include #include @@ -31,19 +30,16 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/convert_utils.h" #include "utils/any.h" #include "utils/utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" +#include "backend/session/kernel_build_client.h" namespace mindspore { namespace kernel { -constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; -constexpr int32_t ARGS_SIZE = 1; -constexpr auto kCompileWithJsonFunc = "compilewithjson"; - // json key constexpr auto kOpDesc = "op_desc"; constexpr auto kInputDesc = "input_desc"; @@ -70,25 +66,6 @@ std::string Vector2Str(const std::vector &inputs) { } } // namespace -std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, const std::pair &position) { if (node_json.count(tag) == 0) { @@ -528,32 +505,11 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod return cached_kernel_pack; } - PyObject *pModule = nullptr; - PyObject *pFunc = nullptr; - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - - pModule = PyImport_ImportModule(kAkgModule); - if (pModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "]."; - return nullptr; - } - - pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc); - pArg = PyTuple_New(ARGS_SIZE); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str())); - (void)alarm(AUTODIFF_COMPILE_OVERTIME); - pRes = PyEval_CallObject(pFunc, pArg); + auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(node_json); (void)alarm(0); - if (pRes == nullptr) { - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; - return nullptr; - } - if (PyObject_IsTrue(pRes) != 1) { - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; + if (!res) { + MS_LOG(ERROR) << "Akg compile failed, json: " << node_json; return nullptr; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h index 7b6a2f0b86..41073007bc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_ #include #include #include @@ -24,6 +24,7 @@ #include #include "backend/kernel_compiler/kernel.h" #include "ir/dtype.h" +#include "ir/primitive.h" #include #include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/oplib/oplib.h" @@ -73,4 +74,4 @@ std::string GetTensorName(const nlohmann::json &node_json, const std::string &ta } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h index 02785c6cdb..b8b7b3885d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_METADATA_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_METADATA_H_ #include #include @@ -28,4 +28,4 @@ namespace kernel { void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_METADATA_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc index d698c89bc9..57217c66b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc @@ -23,7 +23,6 @@ #include #include #include -#include #include "ir/dtype.h" #include "ir/func_graph.h" #include "backend/kernel_compiler/kernel.h" @@ -32,10 +31,10 @@ #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" #include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_build_client.h" namespace mindspore { namespace kernel { -constexpr int32_t PARALLEL_ARGS_SIZE = 3; constexpr int32_t PROCESS_NUM = 16; constexpr int32_t TIME_OUT = 300; @@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type"; constexpr auto kInputDesc = "input_desc"; constexpr auto kOutputDesc = "output_desc"; constexpr auto kTensorName = "tensor_name"; -constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; -constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; + namespace { void UpdateTensorNameInJson(const std::vector &anf_nodes, std::map *node_json_map) { @@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector &anf return true; } -void GenParallelCompileFuncArgs(const std::vector &kernel_jsons, PyObject **p_args) { - MS_EXCEPTION_IF_NULL(p_args); - *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); - - PyObject *arg1 = PyList_New(kernel_jsons.size()); - for (int i = 0; i < PyList_Size(arg1); ++i) { - PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); - } - PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); - PyObject *arg3 = Py_BuildValue("i", TIME_OUT); - - (void)PyTuple_SetItem(*p_args, 0, arg1); - (void)PyTuple_SetItem(*p_args, 1, arg2); - (void)PyTuple_SetItem(*p_args, 2, arg3); -} - bool AkgOpParallelBuild(const std::vector> &build_args) { auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); if (jsons.empty()) { return true; } - // Try to call python method to compile nodes parallely. - PyObject *p_module = nullptr; - PyObject *p_func = nullptr; - PyObject *p_arg = nullptr; - PyObject *p_res = nullptr; - - p_module = PyImport_ImportModule(kMultiProcModule); - if (p_module == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; + // Start building in AKG + if (!AscendKernelBuildClient::Instance().AkgStart(PROCESS_NUM, TIME_OUT)) { + MS_LOG(ERROR) << "Akg start failed."; return false; } - - p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); - GenParallelCompileFuncArgs(jsons, &p_arg); - MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() - << " Akg kernels parallelly."; - p_res = PyEval_CallObject(p_func, p_arg); - if (p_res == nullptr) { - PyErr_Print(); - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + if (!AscendKernelBuildClient::Instance().AkgSendData(jsons)) { + MS_LOG(ERROR) << "Akg send data failed."; return false; } - if (PyObject_IsTrue(p_res) != 1) { - PyErr_Print(); - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + if (!AscendKernelBuildClient::Instance().AkgWait()) { + MS_LOG(ERROR) << "Akg compile failed."; return false; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h index 713b65a451..de301c378e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ #include #include @@ -53,4 +53,4 @@ bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc index 8bb4940778..8affa12c32 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc @@ -26,7 +26,7 @@ #include "runtime/rt.h" #include "utils/log_adapter.h" #include "utils/convert_utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h index 3ea36f1a23..147349d747 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ #include #include #include @@ -51,4 +51,4 @@ using AkgKernelModPtr = std::shared_ptr; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc index 96fcd1869e..e2cae2873a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc @@ -20,7 +20,7 @@ #include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/akg/akg_kernel_build.h" #include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h index abb6d1f030..685fad862d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ #include "backend/kernel_compiler/kernel.h" #include "base/base.h" @@ -25,4 +25,4 @@ KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc index d527f8ec76..65ba8b0b74 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc @@ -18,7 +18,7 @@ #include #include #include "nlohmann/json.hpp" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h index a6a17d033f..b87d223f7f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ #include #include #include @@ -79,4 +79,4 @@ using GpuKernelModPtr = std::shared_ptr; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h index c6398eda9e..a7e747fca1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h @@ -14,16 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_KERNEL_MOD_H_ #include #include #include "framework/ge_runtime/task_info.h" #include "backend/kernel_compiler/kernel.h" -#ifdef ENABLE_DATA_DUMP #include "debug/data_dump_parser.h" -#endif using TaskInfoPtr = std::shared_ptr; namespace mindspore { @@ -34,13 +32,7 @@ class AscendKernelMod : public KernelMod { const std::vector &, uint32_t) = 0; uint32_t block_dim() { return block_dim_; } uint32_t stream_id() { return stream_id_; } - virtual bool NeedDump() { -#ifdef ENABLE_DATA_DUMP - return DataDumpParser::GetInstance().NeedDump(kernel_name_); -#else - return false; -#endif - } + virtual bool NeedDump() { return DataDumpParser::GetInstance().NeedDump(kernel_name_); } protected: uint32_t block_dim_{1}; @@ -49,4 +41,4 @@ class AscendKernelMod : public KernelMod { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc index f4495cdb9d..bf383ff9d7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -20,15 +20,16 @@ #include #include #include +#include #include #include "nlohmann/json.hpp" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "ir/manager.h" #include "ir/meta_tensor.h" #include "ir/func_graph.h" #include "frontend/operator/ops.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" namespace mindspore { namespace kernel { @@ -73,8 +74,12 @@ const std::unordered_map fusion_type_maps = { {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, }; -void KernelMeta::Initialize() { - kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; +void KernelMeta::Initialize(int pid) { + if (pid == -1) { + kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; + } else { + kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/"; + } // remove old kernel cache RemoveKernelCache(); @@ -499,235 +504,329 @@ int Sign(float x) { return 0; } -void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim) { - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - std::unordered_map index_map; - size_t unique_indices_size = 0; - for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { - int index = origin_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= first_dim) { - continue; - } - auto iter = index_map.find(index); - if (iter == index_map.end()) { - index_map[index] = unique_indices_size; - unique_grad->indices_[unique_indices_size] = index; - size_t start_index = unique_indices_size * outer_dim; - size_t end_index = start_index + outer_dim; - for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { - unique_grad->value_[j] = origin_sparse_grad.value_[k]; - } - unique_indices_size++; - } else { - size_t first_index = iter->second; - size_t start_index = first_index * outer_dim; - size_t end_index = start_index + outer_dim; - for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { - unique_grad->value_[j] += origin_sparse_grad.value_[k]; - } +namespace { +struct BucketSparseGradient { + float *value_; + int *indices_; + int *global_indices_; + size_t indices_size_; +}; + +struct MultiThreadReduceSparseGradientParam { + SparseGradient *input_grad_{nullptr}; + SparseGradient *workspace_grad_{nullptr}; + SparseGradient *output_grad_{nullptr}; + size_t max_index_{0}; + size_t value_stride_{0}; + size_t thread_num_{0}; + bool use_sort_reduce_{false}; +}; + +void CalculateEachBucketSize(const std::shared_ptr &sparse_grad, size_t max_index, + std::vector *each_bucket_size) { + MS_LOG(DEBUG) << "Start"; + MS_EXCEPTION_IF_NULL(sparse_grad); + MS_EXCEPTION_IF_NULL(sparse_grad->indices_); + MS_EXCEPTION_IF_NULL(each_bucket_size); + size_t bucket_num = each_bucket_size->size(); + for (size_t i = 0; i < sparse_grad->indices_size_; ++i) { + int index = sparse_grad->indices_[i]; + if (index >= 0 && IntToSize(index) < max_index) { + auto bucket_id = index % bucket_num; + each_bucket_size->at(bucket_id)++; } } - unique_grad->indices_size_ = unique_indices_size; + MS_LOG(DEBUG) << "End"; } -struct WorkerParamsForReduceSparseGradient { - size_t slice_start_{0}; - size_t slice_end_{0}; - size_t max_length_{0}; - size_t outer_dim_{0}; - std::vector> *sorted_indices_{nullptr}; - std::vector *slice_positions_{nullptr}; - float *src_value_{nullptr}; - SparseGradient *unique_grad_{nullptr}; -}; +void SplitAndCalculateSegmentBucketSize(const MultiThreadReduceSparseGradientParam ¶m, + std::vector> *segments_ptr, + std::vector>> *segment_bucket_sizes_ptr) { + MS_EXCEPTION_IF_NULL(param.input_grad_); + MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr); + MS_EXCEPTION_IF_NULL(segments_ptr); + auto &segments = *segments_ptr; + auto &segment_bucket_sizes = *segment_bucket_sizes_ptr; + auto input_grad = param.input_grad_; + if (param.thread_num_ < 1) { + MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!"; + } + size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_; + size_t left_indices_size = input_grad->indices_size_ % param.thread_num_; + std::vector threads; + threads.reserve(param.thread_num_); + segments.reserve(param.thread_num_); -void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) { - MS_EXCEPTION_IF_NULL(param.sorted_indices_); - MS_EXCEPTION_IF_NULL(param.slice_positions_); - MS_EXCEPTION_IF_NULL(param.src_value_); - MS_EXCEPTION_IF_NULL(param.unique_grad_); - auto outer_dim = param.outer_dim_; - auto &sorted_indices = *(param.sorted_indices_); - auto &slice_positions = *(param.slice_positions_); - auto unique_grad = param.unique_grad_; - for (size_t slice_id = param.slice_start_; slice_id < param.slice_end_; ++slice_id) { - size_t cur_pos = slice_positions[slice_id]; - int index = sorted_indices[cur_pos].first; - unique_grad->indices_[slice_id] = index; - size_t start_index = slice_id * outer_dim; - auto ret_code = memcpy_s(unique_grad->value_ + start_index, (param.max_length_ - start_index) * sizeof(float), - param.src_value_ + sorted_indices[cur_pos].second, outer_dim * sizeof(float)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; - } - cur_pos++; - size_t end_pos; - if (slice_id + 1 < slice_positions.size()) { - end_pos = slice_positions[slice_id + 1]; - } else { - end_pos = sorted_indices.size(); - } - while (cur_pos < end_pos) { - for (size_t i = 0; i < outer_dim; ++i) { - unique_grad->value_[start_index + i] += param.src_value_[sorted_indices[cur_pos].second + i]; - } - cur_pos++; + size_t current_indices_offset = 0; + for (size_t i = 0; i < param.thread_num_; ++i) { + segment_bucket_sizes.emplace_back(std::make_shared>(param.thread_num_, 0)); + size_t indices_size = thread_indices_size; + if (i < left_indices_size) { + indices_size += 1; } + segments.emplace_back(std::make_shared()); + segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_; + segments[i]->indices_ = input_grad->indices_ + current_indices_offset; + segments[i]->indices_size_ = indices_size; + threads.emplace_back( + std::thread(CalculateEachBucketSize, segments[i], param.max_index_, segment_bucket_sizes[i].get())); + current_indices_offset += indices_size; + } + + for (size_t i = 0; i < param.thread_num_; ++i) { + threads[i].join(); } } -void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, - size_t outer_dim, std::vector> *sorted_indices, - std::vector *slice_positions) { +void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam ¶m, + const std::shared_ptr &segment, size_t bucket_offset, + const std::vector> &buckets) { MS_LOG(DEBUG) << "Start"; - size_t thread_num = 24; - if (slice_positions->size() < thread_num) { - thread_num = slice_positions->size(); + MS_EXCEPTION_IF_NULL(segment); + MS_EXCEPTION_IF_NULL(segment->indices_); + std::vector bucket_data_num(param.thread_num_, 0); + for (size_t i = 0; i < segment->indices_size_; ++i) { + int index = segment->indices_[i]; + if (index >= 0 && IntToSize(index) < param.max_index_) { + auto bucket_id = index % param.thread_num_; + auto bucket_index = bucket_data_num[bucket_id]; + buckets[bucket_id]->indices_[bucket_index] = index; + buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i; + bucket_data_num[bucket_id]++; + } } - size_t stride = (slice_positions->size() + thread_num - 1) / thread_num; - thread_num = (slice_positions->size() + stride - 1) / stride; - std::vector threads; - size_t max_length = sorted_indices->size() * outer_dim; + MS_LOG(DEBUG) << "End"; +} + +void GatherSegmentIndicesToOutputBucket(const MultiThreadReduceSparseGradientParam ¶m, + const std::vector> &segments, + const std::vector>> &segment_bucket_sizes, + std::vector> *buckets_ptr) { + MS_EXCEPTION_IF_NULL(param.output_grad_); + MS_EXCEPTION_IF_NULL(param.output_grad_->value_); + MS_EXCEPTION_IF_NULL(param.output_grad_->indices_); + MS_EXCEPTION_IF_NULL(buckets_ptr); + auto &buckets = *buckets_ptr; + size_t thread_num = param.thread_num_; + if (thread_num != segment_bucket_sizes.size()) { + MS_EXCEPTION(ArgumentError) << "Input param thread num not equal to segment size!"; + } + std::vector bucket_data_size(thread_num, 0); for (size_t i = 0; i < thread_num; ++i) { - size_t slice_start = i * stride; - size_t slice_end = 0; - if (i == thread_num - 1) { - slice_end = slice_positions->size(); - } else { - slice_end = slice_start + stride; + for (size_t j = 0; j < thread_num; ++j) { + bucket_data_size[j] += segment_bucket_sizes[i]->at(j); } - WorkerParamsForReduceSparseGradient params{ - slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_, - unique_grad}; - threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params)); + } + size_t current_indices_offset = 0; + for (size_t i = 0; i < thread_num; ++i) { + buckets.emplace_back(std::make_shared()); + buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_; + buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset; + buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset; + buckets[i]->indices_size_ = bucket_data_size[i]; + current_indices_offset += bucket_data_size[i]; + } + std::vector tmp_bucket_data_size(thread_num, 0); + std::vector>> each_thread_buckets; + for (size_t i = 0; i < thread_num; ++i) { + std::vector> thread_buckets; + for (size_t j = 0; j < thread_num; ++j) { + thread_buckets.emplace_back(std::make_shared()); + thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j]; + thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j]; + thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_; + thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j); + tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j); + } + each_thread_buckets.emplace_back(thread_buckets); + } + std::vector threads; + threads.reserve(thread_num); + current_indices_offset = 0; + for (size_t i = 0; i < thread_num; ++i) { + threads.emplace_back( + std::thread(CopySegmentIndicesToBucket, param, segments[i], current_indices_offset, each_thread_buckets[i])); + current_indices_offset += segments[i]->indices_size_; } for (size_t i = 0; i < thread_num; ++i) { threads[i].join(); } - MS_LOG(DEBUG) << "End"; } -void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim, bool use_multi_threads) { +void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam ¶m, + const std::shared_ptr &bucket, + const std::shared_ptr &reduced_bucket) { MS_LOG(DEBUG) << "Start"; - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - std::vector> sorted_indices; - sorted_indices.reserve(origin_sparse_grad.indices_size_); - for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { - int index = origin_sparse_grad.indices_[i]; - if (index >= 0 && IntToSize(index) < first_dim) { - sorted_indices.emplace_back(std::pair(index, i * outer_dim)); - } - } - std::sort( - sorted_indices.begin(), sorted_indices.end(), - [](const std::pair &left, const std::pair &right) { return left.first < right.first; }); - int last_index = 0; - std::vector slice_positions; - slice_positions.reserve(sorted_indices.size()); + MS_EXCEPTION_IF_NULL(bucket); + MS_EXCEPTION_IF_NULL(bucket->value_); + MS_EXCEPTION_IF_NULL(bucket->indices_); + MS_EXCEPTION_IF_NULL(reduced_bucket); + MS_EXCEPTION_IF_NULL(reduced_bucket->value_); + MS_EXCEPTION_IF_NULL(reduced_bucket->indices_); + std::vector> sorted_indices; + sorted_indices.reserve(bucket->indices_size_); + for (size_t i = 0; i < bucket->indices_size_; ++i) { + int index = bucket->indices_[i]; + int global_index = bucket->global_indices_[i]; + sorted_indices.emplace_back(std::pair(index, global_index)); + } + std::sort(sorted_indices.begin(), sorted_indices.end()); + + float *global_value = param.input_grad_->value_; + size_t unique_indices_size = 0; + size_t max_length = reduced_bucket->indices_size_ * param.value_stride_; + int last_index{0}; + size_t value_offset{0}; for (size_t i = 0; i < sorted_indices.size(); ++i) { - if (i == 0 || last_index != sorted_indices[i].first) { - slice_positions.emplace_back(i); + int index = sorted_indices[i].first; + int global_index = sorted_indices[i].second; + int global_value_offset = global_index * param.value_stride_; + if (i == 0 || index != last_index) { + if (i != 0) { + unique_indices_size++; + } + reduced_bucket->indices_[unique_indices_size] = index; + value_offset = unique_indices_size * param.value_stride_; + auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float), + global_value + global_value_offset, param.value_stride_ * sizeof(float)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + } else { + for (size_t j = 0; j < param.value_stride_; ++j) { + reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j]; + } } - last_index = sorted_indices[i].first; + last_index = index; } - if (use_multi_threads) { - RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions); - } else { - size_t max_length = sorted_indices.size() * outer_dim; - WorkerParamsForReduceSparseGradient params{0, - slice_positions.size(), - max_length, - outer_dim, - &sorted_indices, - &slice_positions, - origin_sparse_grad.value_, - unique_grad}; - WorkerForReduceSparseGradient(params); - } - unique_grad->indices_size_ = slice_positions.size(); + reduced_bucket->indices_size_ = unique_indices_size; MS_LOG(DEBUG) << "End"; } -void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, - SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim) { +void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam ¶m, + const std::shared_ptr &bucket, + const std::shared_ptr &reduced_bucket) { MS_LOG(DEBUG) << "Start"; - if (unique_slice_grads.empty()) { - return; - } - size_t index_data_size = outer_dim * sizeof(float); + MS_EXCEPTION_IF_NULL(bucket); + MS_EXCEPTION_IF_NULL(bucket->value_); + MS_EXCEPTION_IF_NULL(bucket->indices_); + MS_EXCEPTION_IF_NULL(reduced_bucket); + MS_EXCEPTION_IF_NULL(reduced_bucket->value_); + MS_EXCEPTION_IF_NULL(reduced_bucket->indices_); + + float *global_value = param.input_grad_->value_; + std::unordered_map index_map; size_t unique_indices_size = 0; - for (size_t i = 0; i < unique_slice_grads.size(); ++i) { - auto &slice_grad = unique_slice_grads[i]; - auto ret_code = memcpy_s(tmp_grad->value_ + unique_indices_size * outer_dim, - (tmp_grad->indices_size_ - unique_indices_size) * index_data_size, slice_grad->value_, - slice_grad->indices_size_ * index_data_size); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; - } - ret_code = - memcpy_s(tmp_grad->indices_ + unique_indices_size, (tmp_grad->indices_size_ - unique_indices_size) * sizeof(int), - slice_grad->indices_, slice_grad->indices_size_ * sizeof(int)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; + size_t max_length = reduced_bucket->indices_size_ * param.value_stride_; + for (size_t i = 0; i < bucket->indices_size_; ++i) { + int index = bucket->indices_[i]; + int global_index = bucket->global_indices_[i]; + auto iter = index_map.find(index); + if (iter == index_map.end()) { + reduced_bucket->indices_[unique_indices_size] = index; + size_t start_index = unique_indices_size * param.value_stride_; + index_map[index] = start_index; + auto ret_code = memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float), + global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + unique_indices_size++; + } else { + size_t start_index = iter->second; + size_t end_index = start_index + param.value_stride_; + for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) { + reduced_bucket->value_[j] += global_value[k]; + } } - unique_indices_size += slice_grad->indices_size_; } - tmp_grad->indices_size_ = unique_indices_size; - ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim); + reduced_bucket->indices_size_ = unique_indices_size; MS_LOG(DEBUG) << "End"; } -void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, - SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) { - MS_LOG(DEBUG) << "Start"; - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - MS_EXCEPTION_IF_NULL(tmp_grad); - MS_EXCEPTION_IF_NULL(tmp_grad->value_); - MS_EXCEPTION_IF_NULL(tmp_grad->indices_); - size_t thread_num = 24; - if (origin_sparse_grad.indices_size_ < thread_num) { - thread_num = origin_sparse_grad.indices_size_; - } - size_t thread_indices_size = origin_sparse_grad.indices_size_ / thread_num; - size_t left_indices_size = origin_sparse_grad.indices_size_ % thread_num; +void ReduceBucketSparseGradientToWorkspace(const MultiThreadReduceSparseGradientParam ¶m, + const std::vector> &buckets, + std::vector> *reduced_buckets_ptr) { + MS_EXCEPTION_IF_NULL(param.workspace_grad_); + MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_); + MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_); + MS_EXCEPTION_IF_NULL(reduced_buckets_ptr); + auto &reduced_buckets = *reduced_buckets_ptr; + size_t thread_num = buckets.size(); std::vector threads; threads.reserve(thread_num); - std::vector> unique_slice_grads; + + size_t current_indices_offset = 0; for (size_t i = 0; i < thread_num; ++i) { - size_t indices_size = thread_indices_size; - if (i == thread_num - 1) { - indices_size = thread_indices_size + left_indices_size; + reduced_buckets.emplace_back(std::make_shared()); + reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_; + reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset; + reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_; + if (param.use_sort_reduce_) { + threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i])); + } else { + threads.emplace_back(std::thread(ReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i])); } - size_t value_offset = i * thread_indices_size * outer_dim; - size_t indices_offset = i * thread_indices_size; - auto slice_grad = SparseGradient( - {origin_sparse_grad.value_ + value_offset, origin_sparse_grad.indices_ + indices_offset, indices_size}); - unique_slice_grads.emplace_back(std::make_shared()); - unique_slice_grads[i]->value_ = unique_grad->value_ + value_offset; - unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset; - unique_slice_grads[i]->indices_size_ = indices_size; - threads.emplace_back( - std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false)); + current_indices_offset += buckets[i]->indices_size_; } for (size_t i = 0; i < thread_num; ++i) { threads[i].join(); } - ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim); +} + +void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam ¶m, + const std::vector> &reduced_buckets) { + MS_EXCEPTION_IF_NULL(param.output_grad_); + auto output_grad = param.output_grad_; + MS_EXCEPTION_IF_NULL(output_grad->value_); + MS_EXCEPTION_IF_NULL(output_grad->indices_); + size_t stride_data_size = param.value_stride_ * sizeof(float); + size_t unique_indices_size = 0; + for (size_t i = 0; i < reduced_buckets.size(); ++i) { + auto &bucket = reduced_buckets[i]; + MS_EXCEPTION_IF_NULL(bucket); + if (bucket->indices_size_ == 0) { + continue; + } + auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_, + (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_, + bucket->indices_size_ * stride_data_size); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + ret_code = memcpy_s(output_grad->indices_ + unique_indices_size, + (output_grad->indices_size_ - unique_indices_size) * sizeof(int), bucket->indices_, + bucket->indices_size_ * sizeof(int)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + unique_indices_size += bucket->indices_size_; + } + output_grad->indices_size_ = unique_indices_size; +} +} // namespace + +void BucketReduceSparseGradient(const ReduceSparseGradientParam ¶m) { + MS_LOG(DEBUG) << "Start"; + MS_EXCEPTION_IF_NULL(param.input_grad_); + size_t thread_num = 23; + if (param.input_grad_->indices_size_ < thread_num) { + thread_num = param.input_grad_->indices_size_; + } + MultiThreadReduceSparseGradientParam multi_thread_param({param.input_grad_, param.workspace_grad_, param.output_grad_, + param.max_index_, param.value_stride_, thread_num, + param.use_sort_reduce_}); + std::vector> segments; + std::vector>> segment_bucket_sizes; + SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes); + + std::vector> buckets; + GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets); + + std::vector> reduced_buckets; + ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets); + + MergeReduceSparseGradient(multi_thread_param, reduced_buckets); MS_LOG(DEBUG) << "End"; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h index 8c9ea84b34..d9ebfe3c4c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_ #include #include @@ -40,7 +40,6 @@ constexpr auto kProcessorCuda = "cuda"; constexpr auto kJsonSuffix = ".json"; constexpr auto kInfoSuffix = ".info"; constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; -constexpr auto kAkgModule = "_akg"; constexpr auto kArgDataformat = "data_format"; const std::vector support_devices = {"aicore", "aicpu", "cuda"}; @@ -54,7 +53,7 @@ using KernelMetaPtr = std::shared_ptr; class KernelMeta { public: KernelMeta() = default; - void Initialize(); + void Initialize(int pid); void RemoveKernelCache(); std::string Search(const std::string &kernel_name) const; bool Insert(const std::string &kernel_name, const std::string &kernel_json); @@ -73,9 +72,18 @@ class KernelMeta { }; struct SparseGradient { - float *value_; - int *indices_; - size_t indices_size_; + float *value_{nullptr}; + int *indices_{nullptr}; + size_t indices_size_{0}; +}; + +struct ReduceSparseGradientParam { + SparseGradient *input_grad_{nullptr}; + SparseGradient *workspace_grad_{nullptr}; + SparseGradient *output_grad_{nullptr}; + size_t max_index_{0}; + size_t value_stride_{0}; + bool use_sort_reduce_{false}; }; struct MultiThreadComputeParams { @@ -112,10 +120,6 @@ void SaveJsonInfo(const std::string &json_name, const std::string &info); std::string GetProcessor(const AnfNodePtr &anf_node); bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b); int Sign(float x); -void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim); -void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim, bool use_multi_threads = true); std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index); std::vector>> GetInputIndex(const std::vector &node_list, const std::vector &input_list); @@ -130,16 +134,9 @@ void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *sorted_indices, - std::vector *slice_positions); -void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, - SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim); -void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, - SparseGradient *unique_grad, size_t first_dim, size_t outer_dim); +void BucketReduceSparseGradient(const ReduceSparseGradientParam ¶m); std::vector GetReduceAttrAxis(const CNodePtr &cnode); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h index 925f0fab50..bf73b812a7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -45,4 +45,4 @@ MS_REG_CPU_KERNEL(AddN, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h index 42c83ccf0b..ed9c372671 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -41,4 +41,4 @@ MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32). } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc index c1ff8d54bd..76acf852d7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h index 23e8488890..2d60bf3b75 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ #include #include @@ -55,4 +55,4 @@ MS_REG_CPU_KERNEL(ApplyMomentum, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h index 3883344f96..dc2cfcefac 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARGMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARGMAX_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -42,4 +42,4 @@ MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutpu } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARGMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h index c572f68230..60c1b5a6a2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_ #include #include @@ -43,4 +43,4 @@ MS_REG_CPU_KERNEL( BiasAddCPUKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h index a5743879a7..f86805455e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_ #include #include @@ -40,4 +40,4 @@ MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).Add BiasAddGradCPUKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h index 94e4ad40f3..207e04b7ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -47,4 +47,4 @@ MS_REG_CPU_KERNEL(Concat, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONCAT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index f2aa292c6e..1e6ab7b182 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_ #include #include @@ -84,4 +84,4 @@ class CPUKernelUtils { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h index 80f9a342ac..90ab6f3ff0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_ #include #include @@ -23,7 +23,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "runtime/device/cpu/kernel_select_cpu.h" @@ -76,4 +76,4 @@ class CPUKernelRegistrar { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc index 344f03cc53..6bbf6c8a5d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "backend/kernel_compiler/cpu/debug_cpu_kernel.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #ifdef ENABLE_DEBUGGER #include "debug/debugger/debugger.h" #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h index 18302e8992..1e9d72b6ff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEBUG_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEBUG_CPU_KERNEL_H_ #include #include @@ -38,4 +38,4 @@ MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEBUG_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h index 3e3807f58e..e97a96780a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -43,4 +43,4 @@ MS_REG_CPU_KERNEL(EmbeddingLookupCommGrad, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc index b2feb9204f..1021bb1b6c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc @@ -17,160 +17,58 @@ #include #include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "runtime/device/cpu/mpi/mpi_adapter.h" #include "ir/primitive.h" namespace mindspore { namespace kernel { +namespace { +void LookUpTableTask(const float *input_addr, const int *indices_addr, float *output_addr, size_t indices_lens, + size_t outer_dim_size, int offset, size_t first_dim_size) { + size_t lens = outer_dim_size * sizeof(float); + for (size_t i = 0; i < indices_lens; ++i) { + int index = indices_addr[i] - offset; + if (index >= 0 && index < SizeToInt(first_dim_size)) { + size_t pos = index * outer_dim_size; + auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; + } + } else { + auto ret = memset_s(output_addr, lens, 0, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; + } + } + output_addr += outer_dim_size; + } +} +} // namespace + void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_lens_ = 1; - for (auto shape : input_shape_) { - input_lens_ = input_lens_ * shape; - } - indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - indices_lens_ = 1; - for (auto shape : indices_shape_) { - indices_lens_ = indices_lens_ * shape; + std::vector input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.empty()) { + MS_LOG(EXCEPTION) << "param must be at least 1D"; } - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - axis_ = 4 - input_shape_.size(); - if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) { - reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrReduceScatterFlag); + first_dim_size_ = input_shape[0]; + for (size_t i = 1; i < input_shape.size(); ++i) { + outer_dim_size_ *= input_shape[i]; } -#ifdef ENABLE_MPI - if (reduce_scatter_flag_) { - size_t gatherv2_out_lens = 1; - for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { - if (i == 0) { - for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) { - gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j]; - } - } else { - gatherv2_out_lens = gatherv2_out_lens * input_shape_[i]; - } - } - gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float); - gather_v2_out_ = malloc(gatherv2_out_lens_); - if (gather_v2_out_ == nullptr) { - MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; - } - auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_); - if (ret != 0) { - MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed"; - } - split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); + std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (const auto &shape : indices_shape) { + indices_lens_ *= shape; } -#else - if (reduce_scatter_flag_) { - MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; - } -#endif if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) { offset_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrOffset); } - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); } bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast(gather_v2_out_) : output_addr; - size_t dim0 = input_shape_[0]; - size_t dim1 = input_shape_[1]; - size_t dim2 = input_shape_[2]; - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - LookUpTable(inputs, i, j, k, &gather_out_addr); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - LookUpTable(inputs, i, j, 0, &gather_out_addr); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - LookUpTable(inputs, i, 0, 0, &gather_out_addr); - } - } else if (axis_ == 0) { - LookUpTable(inputs, 0, 0, 0, &gather_out_addr); - } -#ifdef ENABLE_MPI - if (reduce_scatter_flag_) { - size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); - size_t reduce_scatter_out_lens = one_split_lens / 8; - const std::vector &group = {0, 1, 2, 3, 4, 5, 6, 7}; - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - for (int i = 0; i < split_num_; i++) { - mpi_instance->ReduceScatter(reinterpret_cast(gather_v2_out_) + i * one_split_lens, - output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum"); - } - } -#endif - return true; -} - -void LookUpTable_task(const float *input_addr, float *output_addr, const int *indices_addr, size_t indices_lens, - size_t num, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, - std::vector input_shape, size_t input_lens) { - size_t lens = num * sizeof(float); - for (size_t i = 0; i < indices_lens; ++i) { - int indices = indices_addr[i] - offset; - if (indices >= 0) { - size_t index = IntToSize(indices); - if (index < input_shape[axis]) { - size_t pos = 0; - if (axis == 3) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index); - } else if (axis == 2) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0); - } else if (axis == 1) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0); - } else if (axis == 0) { - pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0); - } - if (pos + num <= input_lens) { - auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - output_addr += num; - } -} - -void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr) { auto input_addr = reinterpret_cast(inputs[0]->addr); auto indices_addr = reinterpret_cast(inputs[1]->addr); - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); - float *task_out_addr = *output_addr; + auto output_addr = reinterpret_cast(outputs[0]->addr); const size_t thread_num = 8; std::thread threads[8]; size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; @@ -183,8 +81,8 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector } MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; threads[i] = - std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset, - task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_); + std::thread(LookUpTableTask, input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_, + task_proc_lens, outer_dim_size_, offset_, first_dim_size_); task_offset += task_proc_lens; if (task_offset + task_proc_lens > indices_lens_) { task_proc_lens = indices_lens_ - task_offset; @@ -193,14 +91,14 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector for (size_t j = 0; j < i; j++) { threads[j].join(); } - *output_addr += num * indices_lens_; + return true; } void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); if (input_shape.size() > 4) { MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() - << ", but EmbeddingLookUpCPUKernel olny support 4d or lower."; + << ", but EmbeddingLookUpCPUKernel only support 4d or lower."; } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h index 6c61ee346c..bbde7157cd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -24,44 +24,20 @@ namespace mindspore { namespace kernel { class EmbeddingLookUpCPUKernel : public CPUKernel { public: - EmbeddingLookUpCPUKernel() { - axis_ = 0; - offset_ = 0; - split_num_ = 0; - input_lens_ = 0; - indices_lens_ = 0; - gatherv2_out_lens_ = 0; - reduce_scatter_flag_ = false; - gather_v2_out_ = nullptr; - } - ~EmbeddingLookUpCPUKernel() override { - if (gather_v2_out_ != nullptr) { - free(gather_v2_out_); - gather_v2_out_ = nullptr; - } - } + EmbeddingLookUpCPUKernel() {} + ~EmbeddingLookUpCPUKernel() override {} void InitKernel(const CNodePtr &kernel_node) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - private: - void LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr); + protected: void CheckParam(const CNodePtr &kernel_node); - std::vector input_shape_; - std::vector indices_shape_; - std::vector output_shape_; - int axis_; - int offset_; - int split_num_; - size_t input_lens_; - size_t indices_lens_; - size_t gatherv2_out_lens_; - bool reduce_scatter_flag_; - - void *gather_v2_out_; + int offset_{0}; + size_t indices_lens_{1}; + size_t first_dim_size_{1}; + size_t outer_dim_size_{1}; }; MS_REG_CPU_KERNEL( @@ -71,4 +47,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h index 6e4ed6d5f1..3b977e48e5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EQUAL_COUNT_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -40,4 +40,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EQUAL_COUNT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h index 8fdac0dfde..9d6ec0de68 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -49,4 +49,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc index e58b1d319c..25a01fdcac 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h" #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -32,8 +32,6 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) { @@ -57,6 +55,7 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { std::vector int_padding_r; const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + std::vector kernel_size({weight_shape[2], weight_shape[3]}); GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { MS_LOG(EXCEPTION) << "get padding failed"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h index c0c64ba4da..2861325f2d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_ #include #include @@ -40,4 +40,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc index 3fa6a91405..077cf9a619 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h" #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -32,8 +32,6 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { @@ -53,6 +51,7 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); std::vector int_padding_l; std::vector int_padding_r; + std::vector kernel_size({weight_shape[2], weight_shape[3]}); GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { MS_LOG(EXCEPTION) << "get padding failed"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h index ae8269c142..efddccfa9d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ #include #include @@ -40,4 +40,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc index 1f02d70f86..2073c5bf20 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc @@ -17,7 +17,7 @@ #include #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { @@ -33,7 +33,6 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - int kernel_size = SizeToInt(weight_shape[3]); auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { @@ -52,6 +51,7 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { std::vector int_padding_l; std::vector int_padding_r; const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + std::vector kernel_size({weight_shape[2], weight_shape[3]}); GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { MS_LOG(EXCEPTION) << "conv2d grad get padding failed"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h index 6f699130a8..c3618dcb6c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ #include #include @@ -40,4 +40,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc index 626fd1934e..72ce1fd9c1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h" #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -29,36 +29,7 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); using tag = dnnl::memory::format_tag; using dim = dnnl::memory::dims; - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); - input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); - hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); - num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); - has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); - batch_size_ = SizeToInt(src_shape[1]); - seq_len_ = SizeToInt(src_shape[0]); - num_directions_ = 1; - if (bidirectional_) { - num_directions_ = 2; - } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { - MS_LOG(EXCEPTION) << "error iteration shape!"; - } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } - if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { - MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; - } - const int gate_size = 4 * hidden_size_; - for (int i = 0; i < num_layers_; ++i) { - weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); - weight_h_size_ += gate_size * hidden_size_; - } - weight_size_ = weight_size_ * num_directions_; - weight_h_size_ = weight_h_size_ * num_directions_; + CheckParam(kernel_node); auto eng = MKLKernelEngine::Get().engine(); dnnl::stream s(eng); dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; @@ -99,6 +70,39 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); } +void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) { + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); + input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); + hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); + num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); + batch_size_ = SizeToInt(src_shape[1]); + seq_len_ = SizeToInt(src_shape[0]); + num_directions_ = 1; + if (bidirectional_) { + num_directions_ = 2; + } + const int gate_size = 4 * hidden_size_; + for (int i = 0; i < num_layers_; ++i) { + weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); + weight_h_size_ += gate_size * hidden_size_; + } + weight_size_ = weight_size_ * num_directions_; + weight_h_size_ = weight_h_size_ * num_directions_; + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } + if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { + MS_LOG(EXCEPTION) << "lstm only support 3-D input!"; + } +} + bool LstmCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h index 761494a931..3f2aa916d7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_ #if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) #define PLATFORM_86 #endif @@ -37,6 +37,7 @@ class LstmCPUKernel : public MKLCPUKernel { const std::vector &outputs) override; private: + void CheckParam(const CNodePtr &kernel_node); int weight_size_ = 0; int weight_h_size_ = 0; int input_size_; @@ -67,4 +68,4 @@ MS_REG_CPU_KERNEL(LSTM, LstmCPUKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc index 56da8ec808..ea2ea9824d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -18,7 +18,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -28,37 +28,8 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); using tag = dnnl::memory::format_tag; using dim = dnnl::memory::dims; + CheckParam(kernel_node); auto eng = MKLKernelEngine::Get().engine(); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); - input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); - hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); - num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); - has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); - batch_size_ = SizeToInt(src_shape[1]); - seq_len_ = SizeToInt(src_shape[0]); - num_directions_ = 1; - if (bidirectional_) { - num_directions_ = 2; - } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { - MS_LOG(EXCEPTION) << "error iteration shape!"; - } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } - if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { - MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; - } - const int gate_size = 4 * hidden_size_; - for (int i = 0; i < num_layers_; ++i) { - weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); - weight_h_size_ += gate_size * hidden_size_; - } - weight_size_ = weight_size_ * num_directions_; - weight_h_size_ = weight_h_size_ * num_directions_; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; if (bidirectional_) { direction = dnnl::rnn_direction::bidirectional_concat; @@ -91,7 +62,14 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { dst_h_desc, dst_c_desc); prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); primitive_ = std::make_shared(prim_backward_desc_); + AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); + AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); +} +void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc, + const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc, + const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc, + const dnnl::memory::desc &dst_c_desc) { AddArgument(DNNL_ARG_SRC_LAYER, src_desc); AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); @@ -101,7 +79,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { AddArgument(DNNL_ARG_DST_LAYER, dst_desc); AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); @@ -113,6 +90,72 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); } +void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); + input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); + hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); + num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); + batch_size_ = SizeToInt(src_shape[1]); + seq_len_ = SizeToInt(src_shape[0]); + num_directions_ = 1; + if (bidirectional_) { + num_directions_ = 2; + } + const int gate_size = 4 * hidden_size_; + for (int i = 0; i < num_layers_; ++i) { + weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); + weight_h_size_ += gate_size * hidden_size_; + } + weight_size_ = weight_size_ * num_directions_; + weight_h_size_ = weight_h_size_ * num_directions_; + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } + if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { + MS_LOG(EXCEPTION) << "lstm only support 3-D input!"; + } +} + +void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector &inputs, + const std::vector &outputs, + const dnnl::memory &weights_memory, const dnnl::memory &weights_h_memory, + const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory, + const dnnl::memory &diff_weights_h_memory, + const dnnl::memory &diff_bias_memory) { + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); +} + +void LSTMGradCPUKernel::Memset_op(const dnnl::memory &mem, string name) { + if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) { + MS_LOG(EXCEPTION) << name << " memset error"; + } +} + bool LSTMGradCPUKernel::Launch(const std::vector &inputs, const std::vector &workspace /*workspace*/, const std::vector &outputs) { @@ -145,14 +188,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); user_diff_weights_memory.set_data_handle(outputs[3]->addr); user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0, - user_diff_weights_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "user weights grad memset error"; - } - if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0, - user_diff_weights_h_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "user weights iter grad memset error"; - } + Memset_op(user_diff_weights_memory, "user weights grad"); + Memset_op(user_diff_weights_h_memory, "user weights iter grad"); + Memset_op(diff_weights_memory, "weights grad"); + Memset_op(diff_weights_h_memory, "weights iter grad"); if (has_bias_) { diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); } @@ -160,33 +199,8 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, prim_backward_desc_.diff_bias_desc().get_size())) { MS_LOG(EXCEPTION) << "bias grad memset error"; } - if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0, - diff_weights_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "weights grad memset error"; - } - if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0, - diff_weights_h_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "weights iter grad memset error"; - } - SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); - SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); + SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory, + diff_weights_h_memory, diff_bias_memory); ExecutePrimitive(); Reorder(&diff_weights_memory, &user_diff_weights_memory); Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h index b95b5ba792..700bc67bea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_ +#include #include #include #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" @@ -32,6 +33,17 @@ class LSTMGradCPUKernel : public MKLCPUKernel { const std::vector &outputs) override; private: + void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc, + const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc, + const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc, + const dnnl::memory::desc &dst_c_desc); + void SetArgumentHandleOp(const std::vector &inputs, + const std::vector &outputs, const dnnl::memory &weights_memory, + const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory, + const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory, + const dnnl::memory &diff_bias_memory); + void Memset_op(const dnnl::memory &mem, string name); + void CheckParam(const CNodePtr &kernel_node); int weight_size_ = 0; int weight_h_size_ = 0; int input_size_; @@ -68,4 +80,4 @@ MS_REG_CPU_KERNEL(LSTMGrad, LSTMGradCPUKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc index 4bbaa6459f..ee9c4eb300 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc @@ -17,7 +17,7 @@ #include #include #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h index ef52f652d0..23d3666567 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATMUL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATMUL_CPU_KERNEL_H_ #include #include @@ -47,4 +47,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATMUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc index c71abe809d..7f66b81b82 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc @@ -17,13 +17,13 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" namespace mindspore { namespace kernel { void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, - const std::vector &src_shape, int kernel_size, int stride, + const std::vector &src_shape, const std::vector &kernel_size, int stride, std::vector *padding_l, std::vector *padding_r) { MS_EXCEPTION_IF_NULL(kernel_node); if (src_shape.size() < 2) { @@ -32,11 +32,13 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa std::vector weight_height; weight_height.emplace_back(src_shape[src_shape.size() - 2]); weight_height.emplace_back(src_shape[src_shape.size() - 1]); - int rad = kernel_size / 2; - int need_pad = kernel_size - 1; + MS_LOG(INFO) << "pad mode " << pad_mode; if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { - for (auto wh : weight_height) { + for (size_t i = 0; i < weight_height.size(); ++i) { + auto wh = weight_height[i]; + int rad = kernel_size[i] / 2; + int need_pad = kernel_size[i] - 1; int re = (wh - 1) % stride; int pad = std::max(rad - (re / 2), 0); padding_r->emplace_back(pad); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h index fc7128b10e..7f145c7116 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_CPU_KERNEL_H_ #include #include @@ -33,7 +33,8 @@ class MKLCPUKernel : public CPUKernel { protected: void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector &src_shape, - int kernel_size, int stride, std::vector *padding_l, std::vector *padding_r); + const std::vector &kernel_size, int stride, std::vector *padding_l, + std::vector *padding_r); void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); void SetArgumentHandle(int arg_key, void *ptr); dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; @@ -49,4 +50,4 @@ class MKLCPUKernel : public CPUKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h index 99e7ecdfe0..7e14b6681a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h @@ -23,7 +23,7 @@ #include #include #include "dnnl.hpp" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc index fddd769047..a59cc3c4b6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h index 182679f59d..d67626deca 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MUL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MUL_CPU_KERNEL_H_ #include #include @@ -39,4 +39,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc index e4bedf23b9..ad0ff274c1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h" #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -28,17 +28,18 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - std::vector kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); + std::vector origin_kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); std::vector strides = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - if (kernel_sizes.size() != 4 || strides.size() != 4) { - MS_LOG(EXCEPTION) << "invalid kernel size " << kernel_sizes.size() << " or stride size " << strides.size(); + if (origin_kernel_sizes.size() != 4 || strides.size() != 4) { + MS_LOG(EXCEPTION) << "invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size(); } dnnl::memory::dims strides_dims{strides[2], strides[3]}; - dnnl::memory::dims kernels_dims{kernel_sizes[2], kernel_sizes[3]}; + dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]}; const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); std::vector int_padding_l; std::vector int_padding_r; - GetPadding(kernel_node, pad_mode, src_shape, kernel_sizes[3], strides[3], &int_padding_l, &int_padding_r); + std::vector kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])}); + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, strides[3], &int_padding_l, &int_padding_r); if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { MS_LOG(EXCEPTION) << "pooling get padding failed"; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h index 8187eaffda..f9791eb6e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_CPU_KERNEL_H_ #include #include @@ -38,4 +38,4 @@ MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc index 8189df07ff..bd9bbe7b11 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc @@ -17,7 +17,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -34,7 +34,7 @@ void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { } std::vector padding_r; const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); - kernel_size_ = kernel_sizes[3]; + kernel_size_ = {IntToSize(kernel_sizes[2]), IntToSize(kernel_sizes[3])}; stride_ = strides[3]; GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); } @@ -77,7 +77,7 @@ void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *d size_t diff_index = 0; for (size_t h = 0; h < dst_shape_[2]; ++h) { box[0].first = IntToSize(std::max(h_start, 0)); - box[0].second = IntToSize(std::min(h_start + kernel_size_, src_height)); + box[0].second = IntToSize(std::min(h_start + SizeToInt(kernel_size_[1]), src_height)); for (size_t w = 0; w < src_shape_[3]; ++w) { row_max_pair[w].first = 0; row_max_pair[w].second = 0; @@ -85,7 +85,7 @@ void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *d int w_start = -padding_l_[1]; for (size_t w = 0; w < dst_shape_[3]; ++w) { box[1].first = IntToSize(std::max(w_start, 0)); - box[1].second = IntToSize(std::min(w_start + kernel_size_, src_width)); + box[1].second = IntToSize(std::min(w_start + SizeToInt(kernel_size_[0]), src_width)); RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair); diff_index += 1; w_start += stride_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h index 95a7bb3f66..6c3f6a4ef1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_GRAD_CPU_KERNEL_H_ #include #include @@ -37,7 +37,8 @@ class PoolingGradCPUKernel : public MKLCPUKernel { void RowPoolingGrad(const float *input, float *output, float diff, const std::vector> &box, std::vector> *row_max_pair); void ChannelPoolingGrad(const float *input, const float *diff, float *output); - int stride_{0}, kernel_size_{0}; + int stride_{0}; + std::vector kernel_size_; std::vector padding_l_; std::vector src_shape_; std::vector dst_shape_; @@ -53,4 +54,4 @@ MS_REG_CPU_KERNEL(MaxPoolGrad, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc index 29ac9a1062..558b2d7065 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h index a2da2480e2..902885ee59 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_CPU_KERNEL_H_ #include #include @@ -37,4 +37,4 @@ MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc index 9139aa7862..3188b0622c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h index c895ab2756..ff418ae316 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_GRAD_CPU_KERNEL_H_ #include #include @@ -40,4 +40,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc index 94271b8a69..de9a8890df 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h index 2812dd31af..fbe3607640 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CPU_KERNEL_H_ #include #include @@ -38,4 +38,4 @@ MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc index 889e2abdec..d6aadad436 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -19,7 +19,7 @@ #include #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h index d05cb49b7b..cc9346fe1b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ #include #include @@ -50,4 +50,4 @@ MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc index b8bf7b318a..b1299f1ae0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -19,7 +19,7 @@ #include #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h index 0d79b0514b..5bb5d65dcc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ #include #include @@ -50,4 +50,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h index 393b0e8c41..52a94c448d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ONE_HOT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ONE_HOT_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -48,4 +48,4 @@ MS_REG_CPU_KERNEL(OneHot, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ONE_HOT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h index a78f40d04b..18b040cc86 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ #include #include @@ -40,4 +40,4 @@ class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKerne } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index 59ab65014b..ea22397a3e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -22,8 +22,11 @@ namespace kernel { namespace ps { void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { EmbeddingLookUpCPUKernel::InitKernel(kernel_node); - - for (auto dim : input_shape_) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + size_t axis = kShape2dDims - input_shape.size(); + for (auto dim : input_shape) { input_dims_ *= dim; } @@ -32,14 +35,15 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { } std::vector keys{key_, key_, key_}; std::vector values; - values.insert(values.end(), input_shape_.begin(), input_shape_.end()); - values.insert(values.end(), indices_shape_.begin(), indices_shape_.end()); - values.insert(values.end(), output_shape_.begin(), output_shape_.end()); - std::vector lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()), - SizeToInt(output_shape_.size())}; + values.insert(values.end(), input_shape.begin(), input_shape.end()); + values.insert(values.end(), indices_shape.begin(), indices_shape.end()); + values.insert(values.end(), output_shape.begin(), output_shape.end()); + MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape + << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; + std::vector lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())}; const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { - parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]); + parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape[axis]); parallel::ps::Worker::GetInstance().InitPSEmbeddingTable(keys, values, lens); } } @@ -53,15 +57,15 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector &i size_t output_size = outputs[0]->size; size_t size = input_size / sizeof(float); - ::ps::SArray lookup_ids(size, 0); + ::ps::SArray lookup_ids(size, 0); ::ps::SArray lengths{size}; - ::ps::SArray lookup_result; + ::ps::SArray lookup_result(output_size / sizeof(float), 0); auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size); if (ret != EOK) { MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; } - parallel::ps::Worker::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result, + parallel::ps::Worker::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result, parallel::ps::kEmbeddingLookupCmd); auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h index 45e0a23fcb..dabe2c3ed8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ #include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" #include @@ -46,4 +46,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc index bcb3ca8ae8..4a36628dc7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -25,47 +25,43 @@ namespace mindspore { namespace kernel { namespace ps { using mindspore::parallel::ps::Util; +constexpr int kAxis = 0; void EmbeddingLookUpPSKernel::InitKernel( const std::shared_ptr>>> &shapes) { const std::vector>> &shape_vec = *shapes; input_shape_ = *(shape_vec[0]); - input_lens_ = 1; - for (auto shape : input_shape_) { - input_lens_ = input_lens_ * shape; + first_dim_size_ = input_shape_[0]; + for (size_t i = 1; i < input_shape_.size(); ++i) { + outer_dim_size_ *= input_shape_[i]; } - indices_shape_ = *(shape_vec[1]); + auto indices_shape = *(shape_vec[1]); indices_lens_ = 1; - for (auto shape : indices_shape_) { + for (auto shape : indices_shape) { indices_lens_ = indices_lens_ * shape; } - output_shape_ = *(shape_vec[2]); - axis_ = 2; - reduce_scatter_flag_ = false; + auto output_shape = *(shape_vec[2]); size_t offset = 0; for (size_t i = 0; i < rank_id_; i++) { - offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_); + offset += Util::LocalShard(input_shape_[kAxis], i, pserver_num_); } offset_ = offset; - split_num_ = pserver_num_; // input shape should be sharded after computing offset_; - Shard(input_shape_, axis_); + Shard(&input_shape_, kAxis); size_t output_size = - std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies()); + std::accumulate(output_shape.begin(), output_shape.end(), sizeof(float), std::multiplies()); output_size_list_.emplace_back(output_size); - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); } void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr>>> &shapes) { const std::vector>> &shape_vec = *shapes; - const auto &indices_shape_ = *(shape_vec[0]); - indices_lens_ = indices_shape_[0]; + const auto &indices_shape = *(shape_vec[0]); + indices_lens_ = indices_shape[0]; size_t output_size = sizeof(float) * indices_lens_; - for (size_t i = axis_ + 1; i < input_shape_.size(); i++) { + for (size_t i = kAxis + 1; i < input_shape_.size(); i++) { output_size *= input_shape_[i]; } output_size_list_.clear(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h index e23a90a11c..987de740d8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ #include #include @@ -38,9 +38,12 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK const std::vector &input_sizes() const override; const std::vector &output_sizes() const override; const std::vector &workspace_sizes() const override; + + private: + std::vector input_shape_; }; } // namespace ps } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h index a2b6c4fa61..158b890929 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_ #include #include @@ -31,8 +31,9 @@ class PServerKernel { ~PServerKernel() = default; PServerKernel(const PServerKernel &) = delete; PServerKernel &operator=(const PServerKernel &) = delete; - virtual void InitKernel(const std::shared_ptr>>> &) {} + virtual void InitKernel(const CNodePtr &cnode, + const std::shared_ptr>>> &) {} virtual void ReInit(const std::shared_ptr>>> &) {} virtual bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) = 0; @@ -54,4 +55,4 @@ class PServerKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h index 84dd9b819e..350b503d8b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PULL_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PULL_KERNEL_H_ #include #include @@ -33,8 +33,9 @@ class PullKernel : public CPUKernel { ~PullKernel() override = default; bool Launch(const std::vector &inputs, const std::vector &, const std::vector &) { - // If the paramter is embedding table, don't Pull from PServer. - if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) { + bool init_in_server = mindspore::parallel::ps::Worker::GetInstance().GetParamInitInServer(param_name_); + // If init_in_server, forward kernel should run in server too. + if (!init_in_server) { parallel::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size); } return true; @@ -82,4 +83,4 @@ class PullKernel : public CPUKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PULL_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc index 96c1f15bda..2322d4ee3a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc @@ -34,5 +34,13 @@ MS_REG_CPU_KERNEL_T(Push, MS_REG_CPU_KERNEL_T( Push, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), PushKernel, float); + +MS_REG_CPU_KERNEL_T(Push, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeUInt64), + PushKernel, float); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h index 938792f3bf..800315e5f3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PUSH_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PUSH_KERNEL_H_ #include #include @@ -43,7 +43,10 @@ class PushKernel : public CPUKernel { sizes.push_back(SizeToInt(input->size) / sizeof(T)); } parallel::ps::Worker::GetInstance().Push(keys, addrs, sizes); - memcpy(outputs[0]->addr, &key_, sizeof(size_t)); + auto ret = memcpy_s(outputs[0]->addr, sizeof(size_t), &key_, sizeof(size_t)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } return true; } @@ -77,4 +80,4 @@ class PushKernel : public CPUKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PUSH_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc index c7283954f8..222a980fc6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace kernel { namespace ps { void SparseApplyAdamPSKernel::InitKernel( - const std::shared_ptr>>> &shapes) { + const CNodePtr &cnode, const std::shared_ptr>>> &shapes) { const std::vector>> &shape_vec = *shapes; std::vector &var_shape = *(shape_vec[0]); std::vector &m_shape = *(shape_vec[1]); @@ -55,11 +55,11 @@ void SparseApplyAdamPSKernel::InitKernel( if (grad_shape[0] != indices_size_) { MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; } - /* - if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { - use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(cnode, "use_nesterov"); } - */ + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); @@ -75,7 +75,7 @@ void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr &inputs) { const auto &indices_addr = inputs[10]; - indices_size_ = indices_addr->size; + indices_size_ = indices_addr->size / sizeof(int); workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); workspace_size_list_[1] = indices_size_ * sizeof(int); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h index 337fcb3bf0..bd3e021a69 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ #include #include @@ -30,7 +30,8 @@ class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerK SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} ~SparseApplyAdamPSKernel() override = default; - void InitKernel(const std::shared_ptr>>> &) override; + void InitKernel(const CNodePtr &cnode, + const std::shared_ptr>>> &) override; void ReInit(const std::shared_ptr>>> &) override; bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; @@ -46,4 +47,4 @@ class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerK } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc index 0392bd5a69..afd676382f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace kernel { namespace ps { void SparseApplyFtrlPSKernel::InitKernel( - const std::shared_ptr>>> &shapes) { + const CNodePtr &cnode, const std::shared_ptr>>> &shapes) { const std::vector>> &shape_vec = *shapes; std::vector var_shape = *(shape_vec[0]); std::vector accum_shape = *(shape_vec[1]); @@ -46,10 +46,24 @@ void SparseApplyFtrlPSKernel::InitKernel( if (grad_shape[0] != indices_size_) { MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; } - lr_ = 0.01; - l1_ = 1e-8; - l2_ = 1e-8; - lr_power_ = -0.5; + lr_ = AnfAlgo::GetNodeAttr(cnode, "lr"); + if (lr_ <= 0) { + MS_LOG(EXCEPTION) << "lr should be a positive scalar"; + } + l1_ = AnfAlgo::GetNodeAttr(cnode, "l1"); + if (l1_ < 0) { + MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; + } + l2_ = AnfAlgo::GetNodeAttr(cnode, "l2"); + if (l2_ < 0) { + MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; + } + lr_power_ = AnfAlgo::GetNodeAttr(cnode, "lr_power"); + if (lr_power_ > 0) { + MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; + } + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); } @@ -64,7 +78,7 @@ void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr &inputs) { const auto &indices_addr = inputs[4]; - indices_size_ = indices_addr->size; + indices_size_ = indices_addr->size / sizeof(int); workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); workspace_size_list_[1] = indices_size_ * sizeof(int); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h index d97f19d349..3a5dfc738e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ #include #include @@ -30,7 +30,8 @@ class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerK SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} ~SparseApplyFtrlPSKernel() override = default; - void InitKernel(const std::shared_ptr>>> &) override; + void InitKernel(const CNodePtr &cnode, + const std::shared_ptr>>> &) override; void ReInit(const std::shared_ptr>>> &) override; bool Execute(const std::vector &inputs, const std::vector &workspace, @@ -47,4 +48,4 @@ class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerK } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc new file mode 100644 index 0000000000..03949b3685 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" +#include +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void SparseApplyLazyAdamPSKernel::InitKernel( + const CNodePtr &cnode, const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector &var_shape = *(shape_vec[0]); + std::vector &m_shape = *(shape_vec[1]); + std::vector &v_shape = *(shape_vec[2]); + const std::vector &grad_shape = *(shape_vec[9]); + const std::vector &indices_shape = *(shape_vec[10]); + + Shard(&var_shape, 0); + Shard(&m_shape, 0); + Shard(&v_shape, 0); + + if (!IsSameShape(var_shape, m_shape)) { + MS_LOG(EXCEPTION) << "var and m should have the same shape"; + } + if (!IsSameShape(var_shape, v_shape)) { + MS_LOG(EXCEPTION) << "var and v should have the same shape"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be 1D"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; + } + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(cnode, "use_nesterov"); + } + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); +} + +void SparseApplyLazyAdamPSKernel::ReInit( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + const std::vector &indices_shape = *(shape_vec[0]); + indices_size_ = indices_shape[0]; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +void SparseApplyLazyAdamPSKernel::ReInit(const std::vector &inputs) { + const auto &indices_addr = inputs[10]; + indices_size_ = indices_addr->size / sizeof(int); + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +bool SparseApplyLazyAdamPSKernel::Execute(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + ReInit(inputs); + int *indices = reinterpret_cast(inputs[10]->addr); + for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) { + indices[i] -= rank_id_ * var_first_dim_size_; + } + return Launch(inputs, workspace, outputs); +} + +const std::vector &SparseApplyLazyAdamPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &SparseApplyLazyAdamPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &SparseApplyLazyAdamPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h new file mode 100644 index 0000000000..595f2ab6a3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::kernel::SparseApplyLazyAdamCPUKernel; +class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public PServerKernel { + public: + SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~SparseApplyLazyAdamPSKernel() override = default; + + void InitKernel(const CNodePtr &cnode, + const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; + + protected: + void ReInit(const std::vector &) override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc index 0dddf1d3c4..c2075b76c5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc @@ -79,9 +79,6 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool ReduceCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspaces*/, const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } size_t out_float_size = left_dims_ * sizeof(float); size_t in_float_size = stride_ * out_float_size; if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) { @@ -106,6 +103,11 @@ bool ReduceCPUKernel::Launch(const std::vector &inputs, } (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end()); Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]); + ConvertDataToOutput(&new_input[0], output); + return true; +} + +void ReduceCPUKernel::ConvertDataToOutput(const float *new_input, float *output) { if (reduce_type_ == kReduceTypeMax) { for (size_t i = 0; i < left_dims_; ++i) { float value = new_input[i * stride_]; @@ -129,7 +131,6 @@ bool ReduceCPUKernel::Launch(const std::vector &inputs, } } } - return true; } void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector &input_shape, const std::vector &input_axis, const int shape_size, float *output) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h index a9696bad49..b0df49395a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_ #include #include #include @@ -34,7 +34,8 @@ class ReduceCPUKernel : public CPUKernel { private: void Transpose(const int size, const float *input, const std::vector &input_shape, const std::vector &input_axis, const int shape_size, float *output); - size_t reduce_type_; + void ConvertDataToOutput(const float *input, float *output); + size_t reduce_type_ = 0; std::vector axis_; std::vector shape_; size_t left_dims_ = 1; @@ -48,4 +49,4 @@ MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu ReduceCPUKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h index 317d7df443..6af5f4c117 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -42,4 +42,4 @@ MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h index 04f1db3304..915e1e8616 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESHAPE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESHAPE_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -50,4 +50,4 @@ MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOut } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESHAPE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h index 03b7ecdc17..8facbb957d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -54,4 +54,4 @@ MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).Ad } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h index ec480d7e80..e9ba06cc67 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_GRAD_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -56,4 +56,4 @@ MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32 } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc index 2ff8e77fcd..bd57a022fd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc @@ -27,12 +27,12 @@ void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t en auto m = input_params->m_; auto m_t = input_params->m_t_; auto v = input_params->v_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - auto use_nesterov = input_params->use_nesterov_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; + const auto beta1 = input_params->beta1_; + const auto beta2 = input_params->beta2_; + const auto use_nesterov = input_params->use_nesterov_; + const auto unique_sparse_grad = input_params->sparse_grad_; + const auto var_first_dim_size = input_params->var_first_dim_size_; + const auto var_outer_dim_size = input_params->var_outer_dim_size_; for (size_t i = start; i < end; ++i) { int index = unique_sparse_grad.indices_[i]; if (index < 0 || IntToSize(index) >= var_first_dim_size) { @@ -55,8 +55,8 @@ void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_ MS_EXCEPTION_IF_NULL(input_params); auto m = input_params->m_; auto v = input_params->v_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; + const auto beta1 = input_params->beta1_; + const auto beta2 = input_params->beta2_; for (size_t i = start; i < end; ++i) { m[i] *= beta1; v[i] *= beta2; @@ -66,10 +66,10 @@ void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_ void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) { MS_EXCEPTION_IF_NULL(input_params); auto var = input_params->var_; - auto m = input_params->m_; - auto v = input_params->v_; - auto lr = input_params->lr_; - auto epsilon = input_params->epsilon_; + const auto *m = input_params->m_; + const auto *v = input_params->v_; + const auto lr = input_params->lr_; + const auto epsilon = input_params->epsilon_; for (size_t i = start; i < end; ++i) { var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); } @@ -81,6 +81,8 @@ void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) MS_EXCEPTION_IF_NULL(kernel_node); workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); } @@ -142,11 +144,21 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector &inp auto indices = reinterpret_cast(inputs[10]->addr); auto new_grad = reinterpret_cast(workspace[0]->addr); auto new_indices = reinterpret_cast(workspace[1]->addr); - auto m_t = reinterpret_cast(workspace[2]->addr); + auto workspace_grad = reinterpret_cast(workspace[2]->addr); + auto workspace_indices = reinterpret_cast(workspace[3]->addr); + auto m_t = reinterpret_cast(workspace[4]->addr); SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); + SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_}); + SparseGradient input_sparse_grad({grad, indices, indices_size_}); + ReduceSparseGradientParam param; + param.input_grad_ = &input_sparse_grad; + param.workspace_grad_ = &workspace_sparse_grad; + param.output_grad_ = &unique_sparse_grad; + param.max_index_ = var_first_dim_size_; + param.value_stride_ = var_outer_dim_size_; + BucketReduceSparseGradient(param); + size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h index 3a7a449246..6cf716839c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ #include #include @@ -60,4 +60,4 @@ MS_REG_CPU_KERNEL(FusedSparseAdam, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc index 2662604e19..e375310229 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc @@ -27,13 +27,13 @@ void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t en auto var = input_params->var_; auto accum = input_params->accum_; auto linear = input_params->linear_; - auto lr = input_params->lr_; - auto l1 = input_params->l1_; - auto l2_plus = 2 * input_params->l2_; - auto lr_power = input_params->lr_power_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; + const auto lr = input_params->lr_; + const auto l1 = input_params->l1_; + const auto l2_plus = 2 * input_params->l2_; + const auto lr_power = input_params->lr_power_; + const auto unique_sparse_grad = input_params->sparse_grad_; + const auto var_first_dim_size = input_params->var_first_dim_size_; + const auto var_outer_dim_size = input_params->var_outer_dim_size_; for (size_t i = start; i < end; ++i) { int index = unique_sparse_grad.indices_[i]; if (index < 0 || IntToSize(index) >= var_first_dim_size) { @@ -132,12 +132,19 @@ bool SparseApplyFtrlCPUKernel::Launch(const std::vector &inp auto indices = reinterpret_cast(inputs[4]->addr); auto new_grad = reinterpret_cast(workspace[0]->addr); auto new_indices = reinterpret_cast(workspace[1]->addr); - auto tmp_grad = reinterpret_cast(workspace[2]->addr); - auto tmp_indices = reinterpret_cast(workspace[3]->addr); + auto workspace_grad = reinterpret_cast(workspace[2]->addr); + auto workspace_indices = reinterpret_cast(workspace[3]->addr); + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_}); - TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad, - var_first_dim_size_, var_outer_dim_size_); + SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_}); + SparseGradient input_sparse_grad({grad, indices, indices_size_}); + ReduceSparseGradientParam param; + param.input_grad_ = &input_sparse_grad; + param.workspace_grad_ = &workspace_sparse_grad; + param.output_grad_ = &unique_sparse_grad; + param.max_index_ = var_first_dim_size_; + param.value_stride_ = var_outer_dim_size_; + BucketReduceSparseGradient(param); MultiThreadComputeParams input_params; input_params.var_ = var; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h index c24ce8c703..a4523e8530 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -56,4 +56,4 @@ MS_REG_CPU_KERNEL(FusedSparseFtrl, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc index 636d92dcbb..24a48e2d7b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc @@ -27,14 +27,14 @@ void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_ auto var = input_params->var_; auto m = input_params->m_; auto v = input_params->v_; - auto lr = input_params->lr_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - auto epsilon = input_params->epsilon_; - auto use_nesterov = input_params->use_nesterov_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; + const auto lr = input_params->lr_; + const auto beta1 = input_params->beta1_; + const auto beta2 = input_params->beta2_; + const auto epsilon = input_params->epsilon_; + const auto use_nesterov = input_params->use_nesterov_; + const auto unique_sparse_grad = input_params->sparse_grad_; + const auto var_first_dim_size = input_params->var_first_dim_size_; + const auto var_outer_dim_size = input_params->var_outer_dim_size_; for (size_t i = start; i < end; ++i) { int index = unique_sparse_grad.indices_[i]; if (index < 0 || IntToSize(index) >= var_first_dim_size) { @@ -123,13 +123,19 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector auto indices = reinterpret_cast(inputs[10]->addr); auto new_grad = reinterpret_cast(workspace[0]->addr); auto new_indices = reinterpret_cast(workspace[1]->addr); - auto tmp_grad = reinterpret_cast(workspace[2]->addr); - auto tmp_indices = reinterpret_cast(workspace[3]->addr); + auto workspace_grad = reinterpret_cast(workspace[2]->addr); + auto workspace_indices = reinterpret_cast(workspace[3]->addr); SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_}); - TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad, - var_first_dim_size_, var_outer_dim_size_); + SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_}); + SparseGradient input_sparse_grad({grad, indices, indices_size_}); + ReduceSparseGradientParam param; + param.input_grad_ = &input_sparse_grad; + param.workspace_grad_ = &workspace_sparse_grad; + param.output_grad_ = &unique_sparse_grad; + param.max_index_ = var_first_dim_size_; + param.value_stride_ = var_outer_dim_size_; + BucketReduceSparseGradient(param); lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); MultiThreadComputeParams input_params; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h index e588702aea..2235c22ea5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ #include #include @@ -33,7 +33,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - private: + protected: size_t indices_size_{0}; size_t var_first_dim_size_{0}; size_t var_outer_dim_size_{1}; @@ -60,4 +60,4 @@ MS_REG_CPU_KERNEL(FusedSparseLazyAdam, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc index efba35ad8c..9e066c5879 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc @@ -26,12 +26,12 @@ void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start MS_EXCEPTION_IF_NULL(input_params); auto var = input_params->var_; auto accum = input_params->accum_; - auto lr = input_params->lr_; - auto l1 = input_params->l1_; - auto l2 = input_params->l2_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; + const auto lr = input_params->lr_; + const auto l1 = input_params->l1_; + const auto l2 = input_params->l2_; + const auto unique_sparse_grad = input_params->sparse_grad_; + const auto var_first_dim_size = input_params->var_first_dim_size_; + const auto var_outer_dim_size = input_params->var_outer_dim_size_; for (size_t i = start; i < end; ++i) { int index = unique_sparse_grad.indices_[i]; if (index < 0 || IntToSize(index) >= var_first_dim_size) { @@ -61,6 +61,8 @@ void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &ke MS_EXCEPTION_IF_NULL(kernel_node); workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); } void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { @@ -119,9 +121,19 @@ bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector(inputs[6]->addr); auto new_grad = reinterpret_cast(workspace[0]->addr); auto new_indices = reinterpret_cast(workspace[1]->addr); + auto workspace_grad = reinterpret_cast(workspace[2]->addr); + auto workspace_indices = reinterpret_cast(workspace[3]->addr); + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); + SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_}); + SparseGradient input_sparse_grad({grad, indices, indices_size_}); + ReduceSparseGradientParam param; + param.input_grad_ = &input_sparse_grad; + param.workspace_grad_ = &workspace_sparse_grad; + param.output_grad_ = &unique_sparse_grad; + param.max_index_ = var_first_dim_size_; + param.value_stride_ = var_outer_dim_size_; + BucketReduceSparseGradient(param); MultiThreadComputeParams input_params; input_params.var_ = var; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h index 616fb9b954..7bd38f5560 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ #include #include @@ -54,4 +54,4 @@ MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h index d1b55ded90..53bef5f7dc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SUB_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SUB_CPU_KERNEL_H_ #include #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -42,4 +42,4 @@ MS_REG_CPU_KERNEL( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SUB_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h index 15796f9f3c..b80007e016 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ #include #include #include @@ -41,4 +41,4 @@ MS_REG_CPU_KERNEL(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu TransposeCPUFwdKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h index 61a53c5b40..13a7d4380b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXGPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXGPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -103,4 +103,4 @@ class ArgmaxGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h index d2369023fb..a6cb342268 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -93,4 +93,4 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc index 5d34a1c9c2..3e7cb788ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc @@ -30,5 +30,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).A ArrayReduceGpuKernel, float) MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ArrayReduceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArrayReduceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ArrayReduceGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h index b96f63670d..0f32519959 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYREDUCE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYREDUCE_GPU_KERNEL_H_ #include #include @@ -29,6 +29,7 @@ const std::map kReduceTypeMap = { {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, + {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, }; template class ArrayReduceGpuKernel : public GpuKernel { @@ -234,4 +235,4 @@ class ArrayReduceGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYREDUCE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc new file mode 100644 index 0000000000..96e82bc5f3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastToGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastToGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h new file mode 100644 index 0000000000..280879b81c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_TO_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_TO_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BroadcastToGpuKernel : public GpuKernel { + public: + BroadcastToGpuKernel() {} + ~BroadcastToGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output_shape_[0], output_shape_[1], + output_shape_[2], output_shape_[3], input_addr, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (input_shapes.size() > 4 || output_shapes.size() > 4) { + MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4"; + } + + size_t offset = output_shapes.size() - input_shapes.size(); + for (size_t i = 0; i < input_shapes.size(); i++) { + input_shape_[i + offset] = input_shapes[i]; + } + + for (size_t j = 0; j < output_shapes.size(); j++) { + output_shape_[j] = output_shapes[j]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T)); + output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T)); + } + + private: + int input_shape_[4] = {1, 1, 1, 1}; + int output_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_TO_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h index 15ccedcaec..c87082c9f2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h @@ -14,10 +14,11 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H #include +#include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" @@ -27,40 +28,35 @@ namespace kernel { template class ConcatV2GpuFwdKernel : public GpuKernel { public: - ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} + ConcatV2GpuFwdKernel() + : axis_(0), + input_num_(1), + output_size_(0), + all_size_before_axis_(1), + all_size_axis_(1), + inputs_host_(nullptr), + len_axis_(nullptr) {} ~ConcatV2GpuFwdKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &, + bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { - if (inputs.size() == 2) { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, - reinterpret_cast(stream_ptr)); - } - - if (inputs.size() == 3) { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *input_2 = GetDeviceAddress(inputs, 2); - T *output = GetDeviceAddress(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, - reinterpret_cast(stream_ptr)); - } - - if (inputs.size() == 4) { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *input_2 = GetDeviceAddress(inputs, 2); - T *input_3 = GetDeviceAddress(inputs, 3); - T *output = GetDeviceAddress(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, - reinterpret_cast(stream_ptr)); + T *output = GetDeviceAddress(outputs, 0); + T **inputs_device = GetDeviceAddress(workspace, 0); + int *len_axis_device = GetDeviceAddress(workspace, 1); + for (size_t i = 0; i < inputs.size(); i++) { + inputs_host_[i] = GetDeviceAddress(inputs, i); } + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "ConcatV2 opt cudaMemcpyAsync inputs failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(len_axis_device, len_axis_.get(), sizeof(int) * input_num_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "ConcatV2 opt cudaMemcpyAsync length on axis failed"); + ConcatKernel(output_size_, input_num_, all_size_before_axis_, all_size_axis_, len_axis_device, inputs_device, + output, reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { @@ -74,25 +70,34 @@ class ConcatV2GpuFwdKernel : public GpuKernel { axis_ += SizeToInt(input_shape.size()); } - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; i++) { - auto input_size = sizeof(T); + input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node)); + inputs_host_ = std::make_unique(input_num_); + len_axis_ = std::make_unique(input_num_); + for (int i = 0; i < input_num_; i++) { + size_t input_size = 1; auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); for (size_t j = 0; j < input_shape.size(); j++) { - input_size *= SizeToInt(input_shape[j]); - if (j >= IntToSize(axis_)) { - w_[i] *= SizeToInt(input_shape[j]); - } - input_size_list_.push_back(input_size); + input_size *= input_shape[j]; } + input_size_list_.push_back(input_size * sizeof(T)); + len_axis_[i] = SizeToInt(input_shape[axis_]); } + workspace_size_list_.push_back(sizeof(T *) * input_num_); + workspace_size_list_.push_back(sizeof(int) * input_num_); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - output_size_ = sizeof(T); - for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ = 1; + for (int i = 0; i < SizeToInt(output_shape.size()); i++) { output_size_ *= output_shape[i]; + if (i > axis_) { + all_size_before_axis_ *= output_shape[i]; + all_size_axis_ *= output_shape[i]; + } + if (i == axis_) { + all_size_before_axis_ *= output_shape[i]; + } } - output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_ * sizeof(T)); InitSizeLists(); return true; @@ -103,11 +108,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel { private: bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num < 2 || input_num > 4) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; - return false; - } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output."; @@ -115,9 +115,13 @@ class ConcatV2GpuFwdKernel : public GpuKernel { } return true; } - int w_[4] = {1, 1, 1, 1}; int axis_; + int input_num_; size_t output_size_; + int all_size_before_axis_; + int all_size_axis_; + std::unique_ptr inputs_host_; + std::unique_ptr len_axis_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; @@ -125,4 +129,4 @@ class ConcatV2GpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc new file mode 100644 index 0000000000..38f168a9b7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherNdGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + GatherNdGpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherNdGpuFwdKernel, int, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h new file mode 100644 index 0000000000..d4e8d3d8ad --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h @@ -0,0 +1,171 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GATHERND_GPU_KERNEL_H +#define MINDSPORE_GATHERND_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh" + +namespace mindspore { +namespace kernel { +template +class GatherNdGpuFwdKernel : public GpuKernel { + public: + GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr), memcpy_flag_(false) {} + ~GatherNdGpuFwdKernel() { + if (dev_batch_strides_ != nullptr) { + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast(dev_batch_strides_)); + } + if (dev_batch_indices_ != nullptr) { + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast(dev_batch_indices_)); + } + } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input_addr = GetDeviceAddress(inputs, 0); + S *indices_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + if (!memcpy_flag_) { + const size_t strides_len = sizeof(S) * batch_strides_.size(); + const size_t indices_len = sizeof(S) * batch_indices_.size(); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_batch_strides_, &batch_strides_[0], strides_len, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in GatherNdGpuFwdKernel::Launch."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_batch_indices_, &batch_indices_[0], indices_len, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in GatherNdGpuFwdKernel::Launch."); + memcpy_flag_ = true; + } + + GatherNd(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], dev_batch_strides_, + dev_batch_indices_, reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + memcpy_flag_ = false; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherNdGpuFwdKernel needs 2."; + } + input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + Reshape(); + + size_t dim_indices_last = dims_[dims_.size() - 1]; + batch_strides_.resize(dim_indices_last, 0); + batch_indices_.resize(dim_indices_last, 0); + + if (dim_indices_last > 0) { + batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1]; + batch_indices_[dim_indices_last - 1] = dims_[1]; + } + for (size_t i = dim_indices_last - 1; i > 0; --i) { + batch_strides_[i - 1] = input_shapes_[i - 1]; + batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i]; + } + + const size_t strides_len = sizeof(S) * batch_strides_.size(); + void *dev_batch_strides_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(strides_len); + if (dev_batch_strides_work == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_strides_work, size: " << strides_len; + } + dev_batch_strides_ = static_cast(dev_batch_strides_work); + + const size_t indices_len = sizeof(S) * batch_indices_.size(); + void *dev_batch_indices_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len); + if (dev_batch_indices_work == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_indices_work, size: " << indices_len; + } + dev_batch_indices_ = static_cast(dev_batch_indices_work); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t size = GetSize(input_shapes_); + input_size_list_.push_back(size); + + size = GetSize(indices_shapes_); + input_size_list_.push_back(size); + + size = GetSize(output_shapes_); + output_size_list_.push_back(size); + } + + private: + void Reshape() { + size_t dim_of_indices = 1; + for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); i++) { + dim_of_indices *= indices_shapes_[i]; + } + + size_t dim_after_indices = 1; + size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)]; + for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) { + dim_after_indices *= input_shapes_[i]; + } + dims_.emplace_back(dim_of_indices); + dims_.emplace_back(dim_after_indices); + dims_.emplace_back(dim_indices_last); + return; + } + size_t GetSize(const std::vector &shape) const { + if (shape.size() == 0) { + return 0; + } + size_t result = sizeof(T); + for (size_t i = 0; i < shape.size(); i++) { + result *= shape[i]; + } + return result; + } + + std::vector input_shapes_; + std::vector indices_shapes_; + std::vector output_shapes_; + + std::vector dims_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector batch_strides_; + std::vector batch_indices_; + + S *dev_batch_strides_; + S *dev_batch_indices_; + bool memcpy_flag_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_GATHERND_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h index 6c46a63e69..13d2eb32f2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ONEHOT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ONEHOT_GPU_KERNEL_H #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -102,4 +102,4 @@ class OneHotGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ONEHOT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc new file mode 100644 index 0000000000..1400f623b7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + OnesLikeGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + OnesLikeGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + OnesLikeGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h new file mode 100644 index 0000000000..11f972f07e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ONESLIKE_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ONESLIKE_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh" +namespace mindspore { +namespace kernel { +template +class OnesLikeGpuKernel : public GpuKernel { + public: + OnesLikeGpuKernel() : input_size_(0), output_size_(0) {} + ~OnesLikeGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int size = SizeToInt(input_size_ / sizeof(T)); + CalOnesLike(size, input, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but oneslike needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but oneslike needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + size_t shape_size = input_shape.size(); + + input_size_ = 1; + for (size_t i = 0; i < shape_size; i++) { + input_size_ *= input_shape[i]; + } + input_size_ *= sizeof(T); + output_size_ = input_size_; + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + return; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ONESLIKE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc new file mode 100644 index 0000000000..3e38ca599e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeNearestNeighborGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeNearestNeighborGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ResizeNearestNeighborGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h new file mode 100644 index 0000000000..ac0f6b402e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ResizeNearestNeighborGpuKernel : public GpuKernel { + public: + ResizeNearestNeighborGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {} + ~ResizeNearestNeighborGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int size = SizeToInt(output_size_ / sizeof(T)); + float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); + float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); + CalResizeNearestNeighbor(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output, + output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], align_corners_, + h_scale, w_scale, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + shape_size_ = input_shape.size(); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) { + MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " + << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; + return false; + } + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + output_size_ *= output_shape[i]; + output_shape_.push_back(output_shape[i]); + } + output_size_ *= sizeof(T); + align_corners_ = GetAttr(kernel_node, "align_corners"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + float Scaling(const int in_size, const int out_size, bool align_corners) { + return (align_corners && out_size > 1) ? (in_size - 1) / static_cast(out_size - 1) + : in_size / static_cast(out_size); + } + + bool align_corners_; + size_t shape_size_; + std::vector input_shape_; + std::vector output_shape_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc new file mode 100644 index 0000000000..14a886a4b8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeNearestNeighborGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeNearestNeighborGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ResizeNearestNeighborGradGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h new file mode 100644 index 0000000000..6d32d8da73 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ResizeNearestNeighborGradGpuKernel : public GpuKernel { + public: + ResizeNearestNeighborGradGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {} + ~ResizeNearestNeighborGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int input_size = SizeToInt(input_size_ / sizeof(T)); + float h_scale = Scaling(output_shape_[2], input_shape_[2], align_corners_); + float w_scale = Scaling(output_shape_[3], input_shape_[3], align_corners_); + CalResizeNearestNeighborGrad(input_size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], + align_corners_, h_scale, w_scale, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + shape_size_ = input_shape.size(); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) { + MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " + << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; + return false; + } + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + output_size_ *= output_shape[i]; + output_shape_.push_back(output_shape[i]); + } + output_size_ *= sizeof(T); + align_corners_ = GetAttr(kernel_node, "align_corners"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + float Scaling(const int in_size, const int out_size, bool align_corners) { + return (align_corners && out_size > 1) ? (in_size - 1) / static_cast(out_size - 1) + : in_size / static_cast(out_size); + } + + bool align_corners_; + size_t shape_size_; + std::vector input_shape_; + std::vector output_shape_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc new file mode 100644 index 0000000000..3a9aa6e075 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + ScatterNd, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ScatterNdGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + ScatterNd, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ScatterNdGpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO( + ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ScatterNdGpuFwdKernel, int, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h new file mode 100644 index 0000000000..7cc0d1f858 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class ScatterNdGpuFwdKernel : public GpuKernel { + public: + ScatterNdGpuFwdKernel() + : input_size_(1), + indices_size_(1), + output_size_(1), + block_size_(1), + indices_stride_(nullptr), + work_shape_(nullptr), + indices_dim_0_(0), + indices_dim_1_(0), + memcpy_flag_(false) {} + ~ScatterNdGpuFwdKernel() { + if (indices_stride_ != nullptr) { + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast(indices_stride_)); + } + if (work_shape_ != nullptr) { + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast(work_shape_)); + } + } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + S *indices = GetDeviceAddress(inputs, 0); + T *update = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + + if (!memcpy_flag_) { + const size_t indices_len = sizeof(S) * vec_indices_stride_.size(); + const size_t vec_work_len = sizeof(S) * vec_work_shape_.size(); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(indices_stride_, &vec_indices_stride_[0], indices_len, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpy failed in ScatterNdGpuFwdKernel::Launch."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpy failed in ScatterNdGpuFwdKernel::Launch."); + memcpy_flag_ = true; + } + + const size_t input_size = input_size_ / sizeof(T); + const size_t output_size = output_size_ / sizeof(T); + + ScatterNd(indices, update, output, block_size_, input_size, output_size, indices_dim_0_, indices_dim_1_, + indices_stride_, work_shape_, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + memcpy_flag_ = false; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 2 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; + return false; + } + + input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + vec_work_shape_ = GetAttr>(kernel_node, "shape"); + + GetSize(); + + const size_t indices_len = sizeof(S) * vec_indices_stride_.size(); + void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len); + if (indices_stride_work == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len; + } + indices_stride_ = static_cast(indices_stride_work); + + const size_t vec_work_len = sizeof(S) * vec_work_shape_.size(); + void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len); + if (work_shape_work == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len; + } + work_shape_ = static_cast(work_shape_work); + + InitSizeLists(); + + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(indices_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + return; + } + + void GetSize() { + indices_size_ = sizeof(S); + for (size_t i = 0; i < indices_shapes_.size(); i++) { + indices_size_ *= indices_shapes_[i]; + } + input_size_ = sizeof(T); + for (size_t i = 0; i < input_shapes_.size(); i++) { + input_size_ *= input_shapes_[i]; + } + output_size_ = sizeof(T); + for (size_t i = 0; i < output_shapes_.size(); i++) { + output_size_ *= output_shapes_[i]; + } + + // calculate indices dim 0/1 + indices_dim_0_ = indices_shapes_[0]; + indices_dim_1_ = indices_shapes_[indices_shapes_.size() - 1]; + + // calculate block_size + for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) { + block_size_ *= output_shapes_[i]; + } + + // calculate indices_stride + vec_indices_stride_.resize(indices_dim_1_, 0); + vec_indices_stride_[indices_dim_1_ - 1] = block_size_; + + for (size_t i = indices_dim_1_ - 1; i > 0; --i) { + vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i]; + } + } + + private: + std::vector input_shapes_; + std::vector indices_shapes_; + std::vector output_shapes_; + std::vector vec_indices_stride_; + std::vector vec_work_shape_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t indices_size_; + size_t output_size_; + size_t block_size_; + + S *indices_stride_; + S *work_shape_; + size_t indices_dim_0_; + size_t indices_dim_1_; + bool memcpy_flag_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h index 73e60c44bd..6dbb72cc85 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SELECT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SELECT_GPU_KERNEL_H #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -92,4 +92,4 @@ class SelectGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SELECT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc index 4c9ff2b7f4..37c1fc6d4d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc @@ -24,11 +24,5 @@ MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutp SliceGpuFwdKernel, int) MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), SliceGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGpuFwdKernel, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h index f8ecb9ccf0..8fc129ea43 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -41,14 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel { } T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); - if (is_strided_slice_) { - CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output, - reinterpret_cast(stream_ptr)); - } else { - Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], - input_shape_[1], input_shape_[2], input_shape_[3], input, output, - reinterpret_cast(stream_ptr)); - } + Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], + input_shape_[1], input_shape_[2], input_shape_[3], input, output, + reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { @@ -159,4 +154,4 @@ class SliceGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc index 2eeb3acf73..da3353b823 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc @@ -29,11 +29,5 @@ MS_REG_GPU_KERNEL_ONE( SliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), SliceGradGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGradGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h index 006cbf0266..45566c9d69 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -38,13 +38,8 @@ class SliceGradGpuKernel : public GpuKernel { T *dy = GetDeviceAddress(inputs, 0); T *dx = GetDeviceAddress(outputs, 0); FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast(stream_ptr)); - if (is_strided_slice_) { - CalStridedSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, strides_, dx, - reinterpret_cast(stream_ptr)); - } else { - CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, - reinterpret_cast(stream_ptr)); - } + CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, + reinterpret_cast(stream_ptr)); return true; } @@ -140,8 +135,8 @@ class SliceGradGpuKernel : public GpuKernel { size_t input_size_; size_t output_size_; size_t workspace_size_; -}; +}; // namespace kernel } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc new file mode 100644 index 0000000000..0101f65001 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SplitGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Split, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SplitGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE( + Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SplitGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h new file mode 100644 index 0000000000..2a7cb6c5dc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h @@ -0,0 +1,153 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPLIT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPLIT_GPU_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SplitGpuFwdKernel : public GpuKernel { + public: + SplitGpuFwdKernel() + : axis_(0), + output_num_(1), + input_size_(1), + axis_step_(1), + all_size_before_axis_(1), + all_size_axis_(1), + outputs_host_(nullptr) {} + ~SplitGpuFwdKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T **outputs_device = GetDeviceAddress(workspace, 0); + for (size_t i = 0; i < outputs.size(); i++) { + outputs_host_[i] = GetDeviceAddress(outputs, i); + } + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "Split opt cudaMemcpyAsync outputs failed"); + SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + axis_ = GetAttr(kernel_node, "axis"); + if (axis_ < 0) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + axis_ += SizeToInt(input_shape.size()); + } + output_num_ = GetAttr(kernel_node, "output_num"); + + if (!CheckParam(kernel_node)) { + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = 1; + all_size_before_axis_ = 1; + all_size_axis_ = 1; + + for (int i = 0; i < SizeToInt(input_shape.size()); i++) { + input_size_ *= input_shape[i]; + if (i > axis_) { + all_size_before_axis_ *= input_shape[i]; + all_size_axis_ *= input_shape[i]; + } + if (i == axis_) { + all_size_before_axis_ *= input_shape[i]; + } + } + input_size_list_.push_back(input_size_ * sizeof(T)); + axis_step_ = input_shape[axis_] / output_num_; + + for (int i = 0; i < output_num_; i++) { + size_t output_size = 1; + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i); + for (size_t j = 0; j < output_shape.size(); j++) { + output_size *= output_shape[j]; + } + output_size_list_.push_back(output_size * sizeof(T)); + } + workspace_size_list_.push_back(sizeof(T *) * output_num_); + InitSizeLists(); + outputs_host_ = std::make_unique(output_num_); + return true; + } + + protected: + void InitSizeLists() override {} + + private: + bool CheckParam(const CNodePtr &kernel_node) { + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + int dims = SizeToInt(input_shape.size()); + int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node)); + + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input."; + return false; + } + if (dims == 0) { + MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported."; + return false; + } + if (axis_ < -dims || axis_ >= dims) { + MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in " << -dims << "~" << dims; + return false; + } + if (output_num_ > SizeToInt(input_shape[axis_])) { + MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_]; + return false; + } + if (input_shape[axis_] % output_num_ != 0) { + MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_]; + return false; + } + if (output_num_ != output_num) { + MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_; + return false; + } + return true; + } + int axis_; + int output_num_; + size_t input_size_; + int axis_step_; + int all_size_before_axis_; + int all_size_axis_; + std::unique_ptr outputs_host_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPLIT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc new file mode 100644 index 0000000000..5ecb9d2a55 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + StridedSliceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + StridedSliceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + StridedSliceGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h new file mode 100644 index 0000000000..aa37e8e6f9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h @@ -0,0 +1,199 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr int MAX_DIMS = 7; +template +class StridedSliceGpuKernel : public GpuKernel { + public: + StridedSliceGpuKernel() : null_output_(false) {} + ~StridedSliceGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (null_output_) { + return true; + } + + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + StridedSlice(input_shape_, begin_, strides_, output_shape_, input, output, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape_.size() > MAX_DIMS) { + MS_LOG(ERROR) << "StridedSlice support support dims less than " << input_shape_.size(); + return false; + } + + FillEmptyDims(kernel_node); + ParseMasks(kernel_node); + FillOutputDim(); + null_output_ = IsNullOutput(); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t size = sizeof(T); + for (size_t i = 0; i < MAX_DIMS; i++) { + size *= input_shape_[i]; + } + input_size_list_.push_back(size); + + int size1 = sizeof(T); + for (size_t i = 0; i < MAX_DIMS; i++) { + size1 *= output_shape_[i]; + } + output_size_list_.push_back(size1); + } + + private: + void FillEmptyDims(const CNodePtr &kernel_node) { + begin_ = GetAttr>(kernel_node, "begin"); + end_ = GetAttr>(kernel_node, "end"); + strides_ = GetAttr>(kernel_node, "strides"); + + for (size_t i = 0; i < MAX_DIMS; i++) { + if (i < begin_.size()) { + begin_[i] = + std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1)); + } else { + begin_.push_back(0); + } + + if (i < end_.size()) { + end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1); + } else { + end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); + } + + if (i >= strides_.size()) { + strides_.push_back(1); + } + + if (i >= input_shape_.size()) { + input_shape_.push_back(1); + } + } + } + + void ParseMasks(const CNodePtr &kernel_node) { + auto begin_mask_int = GetAttr(kernel_node, "begin_mask"); + auto begin_mask = Dec2Bin(begin_mask_int); + for (size_t i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) { + begin_[i] = 0; + } + } + + auto end_mask_int = GetAttr(kernel_node, "end_mask"); + auto end_mask = Dec2Bin(end_mask_int); + for (size_t j = 0; j < end_mask.size(); j++) { + if (end_mask[j]) { + end_[j] = input_shape_[j]; + } + } + + auto ellipsis_mask_int = GetAttr(kernel_node, "ellipsis_mask"); + auto ellipsis_mask = Dec2Bin(ellipsis_mask_int); + for (size_t k = 0; k < ellipsis_mask.size(); k++) { + if (ellipsis_mask[k]) { + begin_[k] = 0; + end_[k] = input_shape_[k]; + strides_[k] = 1; + } + } + + auto shrink_axis_mask_str = GetAttr(kernel_node, "shrink_axis_mask"); + auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); + for (size_t l = 0; l < shrink_axis_mask.size(); l++) { + if (shrink_axis_mask[l]) { + end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; + strides_[l] = end_[l] > begin_[l] ? 1 : -1; + } + } + } + + std::vector Dec2Bin(const int &mask) { + auto mask_str = std::bitset(mask).to_string(); + int dim_idx = 0; + std::vector result = {false, false, false, false}; + for (int i = mask_str.size() - 1; i >= 0; i--) { + if (mask_str[i] == '1') { + result[dim_idx] = true; + } + dim_idx++; + } + return result; + } + + void FillOutputDim() { + for (int i = 0; i < MAX_DIMS; i++) { + if (begin_[i] <= end_[i] && strides_[i] > 0) { + output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1); + } else if (begin_[i] > end_[i] && strides_[i] < 0) { + output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1); + } else { + output_shape_.push_back(0); + } + } + } + + bool IsNullOutput() { + for (int i = 0; i < MAX_DIMS; i++) { + if (begin_[i] >= end_[i] && strides_[i] > 0) { + return true; + } + if (begin_[i] < end_[i] && strides_[i] < 0) { + return true; + } + } + return false; + } + + std::vector begin_; + std::vector end_; + std::vector strides_; + std::vector input_shape_; + std::vector output_shape_; + int null_output_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc new file mode 100644 index 0000000000..bbcce07a09 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + StridedSliceGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + StridedSliceGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + StridedSliceGradGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h new file mode 100644 index 0000000000..f9cc3bcbfd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h @@ -0,0 +1,200 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr int MAX_DIMS = 7; +template +class StridedSliceGradGpuKernel : public GpuKernel { + public: + StridedSliceGradGpuKernel() : null_output_(false) {} + ~StridedSliceGradGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *dy = GetDeviceAddress(inputs, 0); + T *dx = GetDeviceAddress(outputs, 0); + + FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast(stream_ptr)); + if (null_output_) { + return true; + } + + StridedSliceGrad(output_shape_, begin_, strides_, input_shape_, dy, dx, reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + input_shape_ = GetAttr>(kernel_node, "shapex"); + if (input_shape_.size() > MAX_DIMS) { + MS_LOG(ERROR) << "StridedSliceGrad support support dims less than " << input_shape_.size(); + return false; + } + + FillEmptyDims(kernel_node); + ParseMasks(kernel_node); + FillOutputDim(); + null_output_ = IsNullOutput(); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + int size = sizeof(T); + for (size_t i = 0; i < MAX_DIMS; i++) { + size *= output_shape_[i]; + } + input_size_list_.push_back(size); + + int size1 = sizeof(T); + for (size_t i = 0; i < MAX_DIMS; i++) { + size1 *= input_shape_[i]; + } + output_size_list_.push_back(size1); + } + + private: + void FillEmptyDims(const CNodePtr &kernel_node) { + begin_ = GetAttr>(kernel_node, "begin"); + end_ = GetAttr>(kernel_node, "end"); + strides_ = GetAttr>(kernel_node, "strides"); + + for (size_t i = 0; i < MAX_DIMS; i++) { + if (i < begin_.size()) { + begin_[i] = + std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1)); + } else { + begin_.push_back(0); + } + + if (i < end_.size()) { + end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1); + } else { + end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); + } + + if (i >= strides_.size()) { + strides_.push_back(1); + } + + if (i >= input_shape_.size()) { + input_shape_.push_back(1); + } + } + } + + void ParseMasks(const CNodePtr &kernel_node) { + auto begin_mask_int = GetAttr(kernel_node, "begin_mask"); + auto begin_mask = Dec2Bin(begin_mask_int); + for (size_t i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) { + begin_[i] = 0; + } + } + + auto end_mask_int = GetAttr(kernel_node, "end_mask"); + auto end_mask = Dec2Bin(end_mask_int); + for (size_t j = 0; j < end_mask.size(); j++) { + if (end_mask[j]) { + end_[j] = input_shape_[j]; + } + } + + auto ellipsis_mask_int = GetAttr(kernel_node, "ellipsis_mask"); + auto ellipsis_mask = Dec2Bin(ellipsis_mask_int); + for (size_t k = 0; k < ellipsis_mask.size(); k++) { + if (ellipsis_mask[k]) { + begin_[k] = 0; + end_[k] = input_shape_[k]; + strides_[k] = 1; + } + } + + auto shrink_axis_mask_str = GetAttr(kernel_node, "shrink_axis_mask"); + auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); + for (size_t l = 0; l < shrink_axis_mask.size(); l++) { + if (shrink_axis_mask[l]) { + end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; + strides_[l] = end_[l] > begin_[l] ? 1 : -1; + } + } + } + + std::vector Dec2Bin(const int &mask) { + auto mask_str = std::bitset(mask).to_string(); + int dim_idx = 0; + std::vector result = {false, false, false, false}; + for (int i = mask_str.size() - 1; i >= 0; i--) { + if (mask_str[i] == '1') { + result[dim_idx] = true; + } + dim_idx++; + } + return result; + } + + void FillOutputDim() { + for (int i = 0; i < MAX_DIMS; i++) { + if (begin_[i] <= end_[i] && strides_[i] > 0) { + output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1); + } else if (begin_[i] > end_[i] && strides_[i] < 0) { + output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1); + } else { + output_shape_.push_back(0); + } + } + } + + bool IsNullOutput() { + for (int i = 0; i < MAX_DIMS; i++) { + if (begin_[i] >= end_[i] && strides_[i] > 0) { + return true; + } + if (begin_[i] < end_[i] && strides_[i] < 0) { + return true; + } + } + return false; + } + + std::vector begin_; + std::vector end_; + std::vector strides_; + std::vector input_shape_; + std::vector output_shape_; + int null_output_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc new file mode 100644 index 0000000000..59503128e9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(TopK, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32), + TopKGpuKernel, float, int) +} +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h new file mode 100644 index 0000000000..093527b5d9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class TopKGpuKernel : public GpuKernel { + public: + TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {} + ~TopKGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + S *k = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + S *indices = GetDeviceAddress(outputs, 1); + T *data_buff = nullptr; + S *index_buff = nullptr; + if (use_share_mem_ == false) { + data_buff = GetDeviceAddress(workspaces, 0); + index_buff = GetDeviceAddress(workspaces, 1); + } + + TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff, + reinterpret_cast(stream_ptr)); + + if (sorted_ == false) { + BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff, + reinterpret_cast(stream_ptr)); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shapes.size() - 1; i++) { + outer_size_ *= input_shapes[i]; + } + inner_size_ = input_shapes[input_shapes.size() - 1]; + k_ = output_shapes[output_shapes.size() - 1]; + + sorted_ = GetAttr(kernel_node, "sorted"); + + ceil_power2_ = RoundUpPower2(inner_size_); + size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S)); + if (buffer_size > SHARED_MEM_PER_BLOCK) { + use_share_mem_ = false; + MS_LOG(WARNING) << "CUDA share memory not enough, sort with RAM"; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(outer_size_ * inner_size_ * sizeof(T)); + input_size_list_.push_back(sizeof(S)); + output_size_list_.push_back(outer_size_ * k_ * sizeof(T)); + output_size_list_.push_back(outer_size_ * k_ * sizeof(S)); + if (use_share_mem_ == false) { + workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T)); + workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S)); + } + } + + private: + bool sorted_; + int outer_size_; + int inner_size_; + int k_; + bool use_share_mem_; + int ceil_power2_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // TopKpuKernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h index 0f9c710e3e..ce0093cf34 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -108,4 +108,4 @@ class TransposeGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h index 1f7884c650..80dd18a8cb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORT_SEGMENT_SUM_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORT_SEGMENT_SUM_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -91,4 +91,4 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORT_SEGMENT_SUM_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h index 7de32ade4f..9d4d5c8cd0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONTROL_RECV_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONTROL_RECV_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -63,4 +63,4 @@ class RecvGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONTROL_RECV_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h index beea19a435..d9b70e5629 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONTROL_SEND_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONTROL_SEND_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -63,4 +63,4 @@ class SendGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONTROL_SEND_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu new file mode 100644 index 0000000000..38f84dc618 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh" + +template +__global__ void BoundingBoxDecodeKernel(const size_t size, const T *rois, const T *deltas, T *bboxes, const float m1, + const float m2, const float m3, const float m4, const float s1, const float s2, + const float s3, const float s4, const int max_height, const int max_width, + const float ratio_clip) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const size_t left_x = i * 4; + const size_t left_y = i * 4 + 1; + const size_t right_x = i * 4 + 2; + const size_t right_y = i * 4 + 3; + + T dx = deltas[left_x] * s1 + m1; + T dy = deltas[left_y] * s2 + m2; + T dw = deltas[right_x] * s3 + m3; + T dh = deltas[right_y] * s4 + m4; + + T max_ratio = abs(log(ratio_clip)); + + dw = dw > max_ratio ? max_ratio : (dw < (-max_ratio) ? (-max_ratio) : dw); + dh = dh > max_ratio ? max_ratio : (dh < (-max_ratio) ? (-max_ratio) : dh); + + T px = (rois[left_x] + rois[right_x]) * 0.5f; + T py = (rois[left_y] + rois[right_y]) * 0.5f; + T pw = rois[right_x] - rois[left_x] + 1.0f; + T ph = rois[right_y] - rois[left_y] + 1.0f; + + T gx = px + pw * dx; + T gy = py + ph * dy; + T gw = pw * exp(dw); + T gh = ph * exp(dh); + + T x1 = gx - gw * 0.5f + 0.5f; + T y1 = gy - gh * 0.5f + 0.5f; + T x2 = gx + gw * 0.5f - 0.5f; + T y2 = gy + gh * 0.5f - 0.5f; + + x1 = x1 > max_width ? max_width : (x1 < 0 ? 0 : x1); + y1 = y1 > max_height ? max_height : (y1 < 0 ? 0 : y1); + x2 = x2 > max_width ? max_width : (x2 < 0 ? 0 : x2); + y2 = y2 > max_height ? max_height : (y2 < 0 ? 0 : y2); + + bboxes[left_x] = x1; + bboxes[left_y] = y1; + bboxes[right_x] = x2; + bboxes[right_y] = y2; + } +} + +template +void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2, + const float &m3, const float &m4, const float &s1, const float &s2, const float &s3, + const float &s4, const int &max_height, const int &max_width, const float &ratio_clip, + cudaStream_t cuda_stream) { + BoundingBoxDecodeKernel<<>>(size, rois, deltas, bboxes, m1, m2, m3, m4, + s1, s2, s3, s4, max_height, max_width, + ratio_clip); +} + +template void BoundingBoxDecode(const size_t size, const float *rois, const float *deltas, float *bboxes, + const float &m1, const float &m2, const float &m3, const float &m4, + const float &s1, const float &s2, const float &s3, const float &s4, + const int &max_height, const int &max_width, const float &ratio_clip, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh new file mode 100644 index 0000000000..ccd3914a1c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2, + const float &m3, const float &m4, const float &s1, const float &s2, const float &s3, + const float &s4, const int &max_height, const int &max_width, const float &ratio_clip, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu new file mode 100644 index 0000000000..cf0ee68ae0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh" + +template +__global__ void BoundingBoxEncodeKernel(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, + const float m1, const float m2, const float m3, const float m4, const float s1, + const float s2, const float s3, const float s4) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const size_t left_x = i * 4; + const size_t left_y = i * 4 + 1; + const size_t right_x = i * 4 + 2; + const size_t right_y = i * 4 + 3; + + T px = (anchor_box[left_x] + anchor_box[right_x]) * 0.5f; + T py = (anchor_box[left_y] + anchor_box[right_y]) * 0.5f; + T pw = anchor_box[right_x] - anchor_box[left_x] + 1.0f; + T ph = anchor_box[right_y] - anchor_box[left_y] + 1.0f; + + T gx = (groundtruth_box[left_x] + groundtruth_box[right_x]) * 0.5f; + T gy = (groundtruth_box[left_y] + groundtruth_box[right_y]) * 0.5f; + T gw = groundtruth_box[right_x] - groundtruth_box[left_x] + 1.0f; + T gh = groundtruth_box[right_y] - groundtruth_box[left_y] + 1.0f; + + T dx = (gx - px) / pw; + T dy = (gy - py) / ph; + T dw = log(gw / pw); + T dh = log(gh / ph); + + deltas[left_x] = (dx - m1) / s1; + deltas[left_y] = (dy - m2) / s2; + deltas[right_x] = (dw - m3) / s3; + deltas[right_y] = (dh - m4) / s4; + } +} + +template +void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1, + const float &m2, const float &m3, const float &m4, const float &s1, const float &s2, + const float &s3, const float &s4, cudaStream_t cuda_stream) { + BoundingBoxEncodeKernel<<>>(size, anchor_box, groundtruth_box, deltas, + m1, m2, m3, m4, s1, s2, s3, s4); +} + +template void BoundingBoxEncode(const size_t size, const float *anchor_box, const float *groundtruth_box, + float *deltas, const float &m1, const float &m2, const float &m3, + const float &m4, const float &s1, const float &s2, const float &s3, + const float &s4, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh new file mode 100644 index 0000000000..8ab810d7b9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1, + const float &m2, const float &m3, const float &m4, const float &s1, const float &s2, + const float &s3, const float &s4, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index a72daa4234..827bec11f9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -69,6 +69,33 @@ struct AddFunc { __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); } }; +template +struct FloorDivFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return floor(static_cast(lhs / rhs)); } +}; + +template <> +struct FloorDivFunc { + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { + return __float2half(floor(__half2float(lhs)/ __half2float(rhs))); + } +}; + +template <> +struct FloorDivFunc { + // invalid branch + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } +}; + +template +struct AbsGradFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { + T zero = 0.0; + return lhs < zero ? -rhs : rhs; + } +}; + + template <> struct PowerFunc { // invalid branch @@ -77,6 +104,7 @@ struct PowerFunc { __device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + template __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, const int &r3, @@ -126,6 +154,12 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const case BROADCAST_TYPE_ADD: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, output); + case BROADCAST_TYPE_FLOORDIV: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_ABSGRAD: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); } } @@ -167,6 +201,10 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const return NoBroadcastOperator>(nums, input0, input1, output); case BROADCAST_TYPE_ADD: return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_FLOORDIV: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_ABSGRAD: + return NoBroadcastOperator>(nums, input0, input1, output); } } @@ -176,6 +214,28 @@ void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, cons NoBroadcastKernel<<>>(nums, op, input0, input1, output); } +template +__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, + const int o1, const int o2, const int o3, const T *input_addr, T *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) { + int i = pos / (o1 * o2 * o3) % o0; + int j = pos / (o2 * o3) % o1; + int k = pos / o3 % o2; + int l = pos % o3; + + int input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3); + output_addr[pos] = input_addr[input_idx]; + } +} + +template +void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) { + int nums = o0 * o1 * o2 * o3; + BroadcastToKernel<<>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr, + output_addr); +} + template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, const float *input0, const float *input1, bool *output, @@ -196,6 +256,10 @@ template void Broadcast(const int &l0, const int &l1, const int &l2, const int & const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, const int *input0, const int *input1, int *output, cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const int *input0, const int *input1, bool *output, + cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, bool *output, cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, @@ -206,3 +270,10 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half * half *output, cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, + bool *output, cudaStream_t stream); +template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const float *input_addr, float *output_addr, + cudaStream_t stream); +template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index dfc4c75c93..7d762c34d9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -29,6 +29,8 @@ enum BroadcastOpType { BROADCAST_TYPE_MUL = 6, BROADCAST_TYPE_SUB = 7, BROADCAST_TYPE_ADD = 8, + BROADCAST_TYPE_FLOORDIV = 9, + BROADCAST_TYPE_ABSGRAD = 10, BROADCAST_TYPE_INVALID = 0xffffffff, }; @@ -41,4 +43,8 @@ template void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, cudaStream_t stream); +template +void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream); + #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu new file mode 100644 index 0000000000..588f8c60e2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh" + +template +__global__ void CheckValidKernel(const size_t size, const T *box, const T *img_metas, S *valid) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const size_t left_x = i * 4; + const size_t left_y = i * 4 + 1; + const size_t right_x = i * 4 + 2; + const size_t right_y = i * 4 + 3; + + S valid_flag = false; + valid_flag |= !(box[left_x] >= 0.f); + valid_flag |= !(box[left_y] >= 0.f); + valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]); + valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]); + + valid[i] = !valid_flag; + } + + return; +} + +template +void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream) { + CheckValidKernel<<>>(size, box, img_metas, valid); +} + +template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh new file mode 100644 index 0000000000..fa82f10960 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu index 147782591a..4866d61dd9 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu @@ -19,90 +19,51 @@ #include #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" template -__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2); - int m = pos % (w1 + w2); - output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m]; +__global__ void Concat(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, T** inputs, T* output) { + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int num = pos % all_size_before_axis / all_size_axis; + int block = -1; + int axis_inc = 0; + int block_len = 0; + for (int i = 0; i < input_num; i++) { + if (axis_inc <= num) { + block++; + axis_inc += len_axis[i]; + } else { + break; + } + } + block_len = len_axis[block]; + axis_inc -= len_axis[block]; + int block_pos = pos / all_size_before_axis * block_len * all_size_axis + + (num - axis_inc) * all_size_axis + pos % all_size_axis;; + output[pos] = inputs[block][block_pos]; } return; } template -__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2 + w3); - int m = pos % (w1 + w2 + w3); - output[pos] = m < w1 ? input_1[n * w1 + m] : - m < w1 + w2 ? input_2[n * w2 + m - w1] : - input_3[n * w3 + m - w1 - w2]; - } - return; -} - -template -__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2 + w3 + w4); - int m = pos % (w1 + w2 + w3 + w4); - output[pos] = m < w1 ? input_1[n * w1 + m] : - m < w1 + w2 ? input_2[n * w2 + m - w1]: - m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: - input_4[n * w4 + m - w1 - w2 - w3]; - } - return; -} - -template -void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream) { - Concat<<>>(size, w1, w2, input_1, input_2, output); - return; -} - -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output, +void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, T** inputs, T* output, cudaStream_t cuda_stream) { - Concat<<>>(size, w1, w2, w3, input_1, input_2, input_3, output); + Concat<<>>(size, input_num, + all_size_before_axis, all_size_axis, + len_axis, inputs, output); return; } -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, - cudaStream_t cuda_stream) { - Concat<<>>(size, w1, w2, w3, w4, input_1, - input_2, input_3, input_4, output); - return; -} - -template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, - half* output, cudaStream_t cuda_stream); - -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const float* input_1, const float* input_2, const float* input_3, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const int* input_1, const int* input_2, const int* input_3, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const half* input_1, const half* input_2, const half* input_3, - half* output, cudaStream_t cuda_stream); - -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const float* input_1, const float* input_2, const float* input_3, const float* input_4, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const int* input_1, const int* input_2, const int* input_3, const int* input_4, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const half* input_1, const half* input_2, const half* input_3, const half* input_4, - half* output, cudaStream_t cuda_stream); - +template void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, float** inputs, float* output, + cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, int** inputs, int* output, + cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, half** inputs, half* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh index 7bd32c140f..6e469e8028 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh @@ -19,13 +19,8 @@ #include "runtime/device/gpu/cuda_common.h" template -void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream); -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, +void ConcatKernel(const size_t size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, T** inputs, T* output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cu new file mode 100644 index 0000000000..c4bba2863c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cu @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "convert_gradient_impl.cuh" + +template +__global__ void ConvertGradientKernel(const size_t size, const size_t height_h, const size_t height_w, + const size_t batchwidth, const size_t width, T *input_addr, T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t dst_batchIdx = pointIdx / (height_h * height_w); + size_t dst_batchIdxX = dst_batchIdx / batchwidth; + size_t dst_batchIdxY = dst_batchIdx % batchwidth; + size_t dst_x = (pointIdx - dst_batchIdx * height_h * height_w) / height_w; + size_t dst_y = (pointIdx - dst_batchIdx * height_h * height_w) % height_w; + size_t src_coordinate = dst_batchIdxX * height_h * width + dst_x * width + dst_batchIdxY * height_w + dst_y; + output_addr[pointIdx] = input_addr[src_coordinate]; + } +} + +template +__global__ void ConvertGradientBackKernel(const size_t size, const size_t height_h, const size_t height_w, + const size_t batchwidth, const size_t width, T *input_addr, T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t dst_batchIdx = pointIdx / (height_h * height_w); + size_t dst_batchIdxX = dst_batchIdx / batchwidth; + size_t dst_batchIdxY = dst_batchIdx % batchwidth; + size_t dst_x = (pointIdx - dst_batchIdx * height_h * height_w) / height_w; + size_t dst_y = (pointIdx - dst_batchIdx * height_h * height_w) % height_w; + size_t src_coordinate = dst_batchIdxX * height_h * width + dst_x * width + dst_batchIdxY * height_w + dst_y; + output_addr[src_coordinate] = input_addr[pointIdx]; + } +} + +template +__global__ void ConvertGradientBackKernel(const size_t size, const size_t height_h, const size_t height_w, + const size_t ori_h, const size_t ori_w, const size_t batchwidth, + const size_t width, T *input_addr, T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t dst_batchIdx = pointIdx / (height_h * height_w); + size_t dst_batchIdxX = dst_batchIdx / batchwidth; + size_t dst_batchIdxY = dst_batchIdx % batchwidth; + size_t dst_x = (pointIdx - dst_batchIdx * height_h * height_w) / height_w; + size_t dst_y = (pointIdx - dst_batchIdx * height_h * height_w) % height_w; + size_t src_x = dst_batchIdxX * height_h + dst_x; + size_t src_y = dst_batchIdxY * height_w + dst_y; + if (src_x < ori_h && src_y < ori_w) { + size_t src_coordinate = src_x * ori_w + src_y; + output_addr[src_coordinate] = input_addr[pointIdx]; + } + } +} + +template +void ConvertGradient(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth, + const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { + ConvertGradientKernel<<>>(size, height_h, height_w, batchwidth, width, + input_addr, output_addr); +} + +template +void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth, + const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { + ConvertGradientBackKernel<<>>(size, height_h, height_w, batchwidth, + width, input_addr, output_addr); +} + +template +void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t ori_h, + const size_t ori_w, const size_t batchwidth, const size_t width, T *input_addr, T *output_addr, + cudaStream_t cuda_stream) { + ConvertGradientBackKernel<<>>( + size, height_h, height_w, ori_h, ori_w, batchwidth, width, input_addr, output_addr); +} + +template void ConvertGradient(const size_t size, const size_t height_h, const size_t height_w, + const size_t batchwidth, const size_t width, float *input_addr, float *output_addr, + cudaStream_t cuda_stream); + +template void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, + const size_t batchwidth, const size_t width, float *input_addr, + float *output_addr, cudaStream_t cuda_stream); + +template void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, + const size_t ori_h, const size_t ori_w, const size_t batchwidth, + const size_t width, float *input_addr, float *output_addr, + cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cuh new file mode 100644 index 0000000000..354ffe3965 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cuh @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CONVERTGRADIENT_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CONVERTGRADIENT_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void ConvertGradient(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth, + const size_t width, T *input_addr, T *outt_addr, cudaStream_t cuda_stream); + +template +void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth, + const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream); + +template +void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t ori_h, + const size_t ori_w, const size_t batchwidth, const size_t width, T *input_addr, T *output_addr, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CONVERTGRADIENT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu new file mode 100644 index 0000000000..dfc62147b7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cumsum_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (size_t j = 0; j < dim1; ++j) { + size_t read_index = j * stride2 + offset; + if (j == 0) { + output[read_index] = input[read_index]; + } else { + size_t read_index2 = (j - 1) * stride2 + offset; + output[read_index] = output[read_index2] + input[read_index]; + } + } + } +} +template +void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, + cudaStream_t stream) { + int size = dim0 * dim2; + CumSumKernel<<>>(input, output, dim0, dim1, dim2, stride, stride2); + return; +} + +template void CumSum(float *input, float *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh new file mode 100644 index 0000000000..85ca551643 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ +template +void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, + cudaStream_t stream); +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu new file mode 100644 index 0000000000..3d02723218 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void GatherNdKernel(T *input, S *indices, T *output, const size_t output_dim0, const size_t output_dim1, + const size_t indices_dim1, S *batch_indices, S *batch_strides) { + int num = output_dim0 * output_dim1; + int i, j; + for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / output_dim1 % output_dim0; + j = write_index % output_dim1; + + bool out_of_bound = false; + int read_index = 0; + int indices_i = 0; + for (size_t k = 0; k < indices_dim1; k++) { + size_t ind = indices_dim1 * i + k; + indices_i = indices[ind]; + out_of_bound |= !(indices_i < batch_indices[k]); + read_index += indices_i * batch_strides[k]; + } + read_index += j; + + if (!out_of_bound) { + output[write_index] = input[read_index]; + } else { + output[write_index] = 0; + } + } + return; +} +template +void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, + const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream) { + int size = output_dim0 * output_dim1; + GatherNdKernel<<>>(input, indices, output, output_dim0, output_dim1, + indices_dim1, batch_indices, batch_strides); + return; +} + +template void GatherNd(float *input, int *indices, float *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); +template void GatherNd(half *input, int *indices, half *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); +template void GatherNd(int *input, int *indices, int *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh new file mode 100644 index 0000000000..c6cbbf7603 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GATHERND_GPU_CU_H +#define MINDSPORE_GATHERND_GPU_CU_H + +#include "runtime/device/gpu/cuda_common.h" + +template +void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, + const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream); + +#endif // MINDSPORE_GATHERND_GPU_CU_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu new file mode 100644 index 0000000000..ecb44c45ac --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "identity_impl.cuh" +#include +template +__global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t batchIdx = pointIdx / (dim * dim); + size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim; + size_t dst_y = (pointIdx - batchIdx * dim * dim) % dim; + if (dst_x == dst_y) { + output_addr[pointIdx] = 1; + } else { + output_addr[pointIdx] = 0; + } + } +} + +template +void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) { + IdentityKernel<<>>(size, dim, output_addr); + return; +} + +template void Identity(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh new file mode 100644 index 0000000000..b8fd4a0be3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu new file mode 100644 index 0000000000..a3cdd7e131 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh" + +template +__device__ T CoordinateMax(const T a, const T b) { + return (a > b ? a : b); +} + +template +__device__ T CoordinateMin(const T a, const T b) { + return (a < b ? a : b); +} + +template +__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode, + const size_t input_len_0) { + T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION]; + T overlaps_coordinate[IOU_DIMENSION]; + const T epsilon = 1e-10; + + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + for (size_t j = 0; j < IOU_DIMENSION; j++) { + location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j]; + location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j]; + } + + overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]); + overlaps_coordinate[1] = CoordinateMax(location_coordinate[0][1], location_coordinate[1][1]); + overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]); + overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]); + + T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1); + T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1); + T overlaps = overlaps_w * overlaps_h; + + T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] - + location_coordinate[0][1] + 1); + T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] - + location_coordinate[1][1] + 1); + if (mode == 0) { + iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon); + } else { + iou_results[i] = overlaps / (area2 + epsilon); + } + } + + return; +} + +template +void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const size_t &mode, + const size_t &input_len_0, cudaStream_t cuda_stream) { + IOUKernel<<>>(size, box1, box2, iou_results, mode, input_len_0); +} + +template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode, + const size_t &input_len_0, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh new file mode 100644 index 0000000000..f8e7d98a24 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IOU_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IOU_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +#define IOU_LOCATION_NUM 2 +#define IOU_DIMENSION 4 + +template +void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const size_t &mode, + const size_t &input_len_0, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IOU_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu index fcb7418952..35d200b92a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu @@ -34,9 +34,9 @@ inline __device__ half my_pow(half a, double b) { } template -inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, - const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, - T* dg, T* db) { +inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim, + const int &mean_dim, const T &epsilon, const T *dy, const T *x, + const T *mean, const T *var, T *dg, T *db) { int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { @@ -46,14 +46,15 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_d } int pos = row * col_dim + col; - dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); + int mean_offset = pos / mean_dim; + dg[0] += dy[pos] * my_pow(var[mean_offset] + epsilon, -0.5) * (x[pos] - mean[mean_offset]); db[0] += dy[pos]; } } } template -inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { +inline __device__ void GammaAndBetaWarpReduce(T *dg, T *db) { for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta); db[0] += __shfl_down_sync(0xffffffff, db[0], delta); @@ -61,12 +62,8 @@ inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { } template -inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr, - T* db_addr) { - if (threadIdx.x >= row_dim) { - return; - } - +inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr, + T *db_addr) { // load data to share memory // thread(0, 32, 64, 96, ...) keep the data DynamicSharedMem share_mem; @@ -93,8 +90,9 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_di } template -__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x, - const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) { +__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon, + const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr, + T *db_addr) { // row: [0:param_axis] // col: [param_axis:] // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) @@ -102,16 +100,16 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { T dg = 0; T db = 0; - GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); + GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); GammaAndBetaWarpReduce(&dg, &db); GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); } } template -inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, - T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean, - const T* var, const T* gamma) { +inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int ¶m_dim, const T &epsilon, + T *sum1, T *sum2, T *sum3, const T *dy, const T *x, const T *mean, + const T *var, const T *gamma) { int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { @@ -133,9 +131,9 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con } template <> -inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, - half* sum1, half* sum2, half* sum3, const half* dy, const half* x, - const half* mean, const half* var, const half* gamma) { +inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int ¶m_dim, const half &epsilon, + half *sum1, half *sum2, half *sum3, const half *dy, const half *x, + const half *mean, const half *var, const half *gamma) { int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { @@ -157,7 +155,7 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con } template -inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { +inline __device__ void InputWarpReduce(T *sum1, T *sum2, T *sum3) { for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); @@ -166,11 +164,7 @@ inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { } template -inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) { - if (threadIdx.x >= col_dim) { - return; - } - +inline __device__ void InputBlockReduce(const int &col_dim, T *sum1, T *sum2, T *sum3, T *share_mem) { // load data to share memory // thread(0, 32, 64, 96, ...) keep the data if (threadIdx.x % WARP_SIZE == 0) { @@ -193,9 +187,9 @@ inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* } template -inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, - const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx, - const T* share_mem) { +inline __device__ void InputProp(const int &row, const int &col_dim, const int ¶m_dim, const T &epsilon, + const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx, + const T *share_mem) { for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { int pos = (row * col_dim + col); int gamma_offset = pos % param_dim; @@ -208,9 +202,9 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int& } template <> -inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, - const half* dy, const half* x, const half* mean, const half* var, const half* gamma, - half* dx, const half* share_mem) { +inline __device__ void InputProp(const int &row, const int &col_dim, const int ¶m_dim, const half &epsilon, + const half *dy, const half *x, const half *mean, const half *var, const half *gamma, + half *dx, const half *share_mem) { for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { int pos = (row * col_dim + col); int gamma_offset = pos % param_dim; @@ -218,14 +212,14 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int& half v2 = x[pos] - mean[row]; half v3 = my_pow(var[row] + epsilon, -0.5); dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 + - (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\ - * __float2half(1.0 / col_dim); + (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) * + __float2half(1.0 / col_dim); } } template -__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx) { +__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *dy, + const T *x, const T *mean, const T *var, const T *gamma, T *dx) { for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { T sum1 = 0; T sum2 = 0; @@ -239,21 +233,25 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int } template -void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { - int share_mem_size = - ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); - InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, - gamma, dx); +void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, const T &epsilon, const T *dy, + const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) { + const int thread_per_block = 256; + int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T); + InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, + mean, var, gamma, dx); - share_mem_size = - ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); - GammaAndBetaPropKernel<<>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); + share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T); + // GammaAndBetaPropKernel<<>>(row_dim, col_dim, epsilon, dy, x, + // mean, + // var, dg, db); + int param_reduce_dim = row_dim * col_dim / param_dim; + GammaAndBetaPropKernel<<>>(param_reduce_dim, param_dim, col_dim, + epsilon, dy, x, mean, var, dg, db); } -template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, - const float* dy, const float* x, const float* mean, const float* var, const float* gamma, - float* dx, float* dg, float* db, cudaStream_t stream); -template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon, - const half* dy, const half* x, const half* mean, const half* var, const half* gamma, - half* dx, half* dg, half* db, cudaStream_t stream); +template void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon, + const float *dy, const float *x, const float *mean, const float *var, const float *gamma, + float *dx, float *dg, float *db, cudaStream_t stream); +template void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, const half &epsilon, + const half *dy, const half *x, const half *mean, const half *var, const half *gamma, + half *dx, half *dg, half *db, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu index 138300b303..5797a3d711 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu @@ -73,10 +73,6 @@ inline __device__ void WarpReduce(T *mean, T *var, T *num) { template inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr, T *share_mem) { - if (threadIdx.x >= col_dim) { - return; - } - // load data to share memory // thread(0, 32, 64, 96, ...) keep the data if (threadIdx.x % WARP_SIZE == 0) { @@ -146,13 +142,11 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const T &epsilon, const T *x, const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) { - const dim3 block(row_dim); - const dim3 thread(256); + const int thread_per_block = 256; // keep the mean/var/num after warp reduce - int share_mem_size = - ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); - LayerNormKernel<<>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, - mean, var); + int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T); + LayerNormKernel<<>>(row_dim, col_dim, param_dim, epsilon, x, gamma, + beta, y, mean, var); } template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu new file mode 100644 index 0000000000..9a2c560bc7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "loss_with_reduction_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void Copy(T *loss, T *tmp_loss, int reduction, int input_size) { + loss[0] += tmp_loss[0]; + if (reduction == 1) { + loss[0] /= input_size; + } +} + +template +__global__ void AddTile(T *tmp_loss, int index) { + tmp_loss[0] += tmp_loss[index]; +} +template +__global__ void PartialSum(T *tmp_loss, int stride) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < stride; i += blockDim.x * gridDim.x) { + tmp_loss[i] += tmp_loss[i + stride]; + } +} + +template +__global__ void LossInitKernel(T *loss) { + loss[0] = static_cast(0.); +} + +template +__global__ void KLDivLossKernel(const int input_size, const int reduction, const T *input_x, const T *input_y, T *loss, + T *tmp_loss) { + T epsilon = 1e-6; + if (reduction == 0) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = max(input_y[i], epsilon); + T value = input_y[i] * (logf(denominator) - input_x[i]); + loss[i] = value; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = max(input_y[i], epsilon); + T value = input_y[i] * (logf(denominator) - input_x[i]); + tmp_loss[i] = value; + } + } +} + +template +void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, + cudaStream_t stream) { + LossInitKernel<<<1, 1, 0, stream>>>(loss); + T *tmp_loss; + if (reduction != 0) { + cudaMalloc(reinterpret_cast(&tmp_loss), input_size * sizeof(T)); + } + KLDivLossKernel<<>>(input_size, reduction, input_x, input_y, loss, + tmp_loss); + if (reduction != 0) { + if (input_size % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1); + } + for (int stride = input_size / 2; stride > 0; stride >>= 1) { + PartialSum<<>>(tmp_loss, stride); + if (stride > 2 && stride % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, stride - 1); + } + } + Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size); + } + cudaFree(tmp_loss); +} + +template +__global__ void KLDivLossGradKernel(const int input_size, const int reduction, const T *input_x, const T *input_y, + const T *dloss, T *dx, T *dy) { + T epsilon = 1e-6; + if (reduction == 0) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = max(input_y[i], epsilon); + dx[i] = -input_y[i] * dloss[i]; + dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss[i]; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = max(input_y[i], epsilon); + dx[i] = -input_y[i] * dloss[0]; + dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss[0]; + } + } +} + +template +void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss, + T *dx, T *dy, cudaStream_t stream) { + KLDivLossGradKernel<<>>(input_size, reduction, input_x, input_y, + dloss, dx, dy); +} + +template +__global__ void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const T *input_x, + const T *input_y, const T *weight, T *loss, T *tmp_loss) { + T epsilon = 1e-6; + if (reduction == 0) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T antilogarithm = max(input_x[i], epsilon); + T antilogarithm2 = min(1 - input_x[i], 1 - epsilon); + T value = -weight[i] * (input_y[i] * logf(antilogarithm) + (1 - input_y[i]) * logf(antilogarithm2)); + loss[i] = value; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T antilogarithm = max(input_x[i], epsilon); + T antilogarithm2 = min(1 - input_x[i], 1 - epsilon); + T value = -weight[i] * (input_y[i] * logf(antilogarithm) + (1 - input_y[i]) * logf(antilogarithm2)); + tmp_loss[i] = value; + } + } +} + +template +void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, + const T *weight, T *loss, cudaStream_t stream) { + LossInitKernel<<<1, 1, 0, stream>>>(loss); + T *tmp_loss; + if (reduction != 0) { + cudaMalloc(reinterpret_cast(&tmp_loss), input_size * sizeof(T)); + } + BinaryCrossEntropyLossKernel<<>>(input_size, reduction, input_x, + input_y, weight, loss, tmp_loss); + if (reduction != 0) { + if (input_size % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1); + } + for (int stride = input_size / 2; stride > 0; stride >>= 1) { + PartialSum<<>>(tmp_loss, stride); + if (stride > 2 && stride % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, stride - 1); + } + } + Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size); + } + cudaFree(tmp_loss); +} + +template +__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int reduction, const T *input_x, + const T *input_y, const T *weight, const T *dloss, T *dx) { + T epsilon = 1e-6; + if (reduction == 0) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = max(input_x[i] * (1 - input_x[i]), epsilon); + T value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = max(input_x[i] * (1 - input_x[i]), epsilon); + T value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[0]; + } + } +} + +template +void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, + const T *weight, const T *dloss, T *dx, cudaStream_t stream) { + BinaryCrossEntropyLossGradKernel<<>>(input_size, reduction, input_x, + input_y, weight, dloss, dx); +} + +template void KLDivLoss(const int &input_size, const int &reduction, const float *input_x, const float *input_y, + float *loss, cudaStream_t stream); + +template void KLDivLossGrad(const int &input_size, const int &reduction, const float *input_x, const float *input_y, + const float *dloss, float *dx, float *dy, cudaStream_t stream); + +template void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const float *input_x, + const float *input_y, const float *weight, float *loss, cudaStream_t stream); + +template void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const float *input_x, + const float *input_y, const float *weight, const float *dloss, float *dx, + cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh new file mode 100644 index 0000000000..a01ca830f7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH +template +void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, + const T *weight, T *loss, cudaStream_t stream); +template +void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, + const T *weight, const T *dloss, T *dx, cudaStream_t stream); +template +void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, + cudaStream_t stream); +template +void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss, + T *dx, T *dy, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu new file mode 100644 index 0000000000..b1bd5fdb69 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "matrix_combine_impl.cuh" +#include +template +__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width, + const size_t dst_width, T *input_addr, T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t batchIdx = pointIdx / (src_height * src_width); + size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width; + size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width; + size_t dst_h = src_height * batchIdx + src_h; + size_t dst_w = src_width * batchIdx + src_w; + output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx]; + } +} + +template +__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width, + const size_t dst_width, const size_t res_width, const size_t batch, T *input_addr, + T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t batchIdx = pointIdx / (src_height * src_width); + if (batchIdx != (batch - 1)) { + size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width; + size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width; + size_t dst_h = src_height * batchIdx + src_h; + size_t dst_w = src_width * batchIdx + src_w; + output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx]; + } else { + size_t src_h = (pointIdx - (batch - 1) * src_height * src_width) / res_width; + size_t src_w = (pointIdx - (batch - 1) * src_height * src_width) % res_width; + size_t src_coordinate = (batch - 1) * src_height * src_width + src_h * src_width + src_w; + size_t dst_h = src_height * (batch - 1) + src_h; + size_t dst_w = src_width * (batch - 1) + src_w; + output_addr[dst_h * dst_width + dst_w] = input_addr[src_coordinate]; + } + } +} + +template +void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width, + const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr, + cudaStream_t cuda_stream) { + if (residual == 0) { + MatrixCombineKernel<<>>(size, src_height, src_width, dst_width, + input_addr, output_addr); + } else { + MatrixCombineKernel<<>>(size, src_height, src_width, dst_width, + res_width, batch, input_addr, output_addr); + } + return; +} + +template void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, + const size_t dst_width, const size_t residual, const size_t res_width, + const size_t batch, float *input_addr, float *output_addr, cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh new file mode 100644 index 0000000000..737ec13383 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width, + const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_ + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu new file mode 100644 index 0000000000..15013377fe --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "matrix_split_impl.cuh" +#include +template +__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, + T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t batchIdx = pointIdx / (split_dim * split_dim); + size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim; + size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim; + size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y; + output_addr[pointIdx] = input_addr[src_coordinate]; + } +} + +template +__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, const size_t res_dim, + T *input_addr, T *output_addr) { + for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { + size_t batchIdx = pointIdx / (split_dim * split_dim); + size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim; + size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim; + size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y; + size_t batch_lower = dim / split_dim; + if (batchIdx < batch_lower) { + output_addr[pointIdx] = input_addr[src_coordinate]; + } else { + if (dst_x < res_dim && dst_y < res_dim) { + output_addr[pointIdx] = input_addr[src_coordinate]; + } else if (dst_x == dst_y) { + output_addr[pointIdx] = 1; + } else { + output_addr[pointIdx] = 0; + } + } + } +} + +template +void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr, + cudaStream_t cuda_stream) { + size_t batch = dim / split_dim; + size_t res_dim = dim - batch * split_dim; + if (res_dim == 0) { + MatrixSplitKernel<<>>(size, split_dim, dim, input_addr, output_addr); + } else { + MatrixSplitKernel<<>>(size, split_dim, dim, res_dim, input_addr, + output_addr); + } + return; +} + +template void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, float *input_addr, + float *output_addr, cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh new file mode 100644 index 0000000000..edae55c14d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu new file mode 100644 index 0000000000..863f3a7a85 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu @@ -0,0 +1,226 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "maxpool_with_argmax_grad_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "include/cuda_fp16.h" + +template +__global__ void MaxPoolWithArgmaxGrad(const T* x, + const T* dy, + const S* index, + const int n, + const int c, + const int xHeight, + const int xWidth, + const int dyHeight, + const int dyWidth, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + const int xNCHW, + const int xCHW, + const int xHW, + const int dyCHW, + const int dyHW, + T* dx) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; + pos < (xNCHW); + pos += blockDim.x * gridDim.x) { + const int posn = pos / xCHW; + const int posc = pos / xHW % c; + const int posh = pos / xHeight % xHeight; + const int posw = pos % xWidth; + const S posIdx = posh*xWidth + posw; + int hstart = posh+padTop; + if (hstart < windowHeight) { + hstart = 0; + } else { + hstart = (hstart-windowHeight)/strideHeight + 1; + } + int wstart = posw+padLeft; + if (wstart < windowWidth) { + wstart = 0; + } else { + wstart = (wstart-windowWidth)/strideWidth + 1; + } + const int hend = min((posh+padTop)/strideHeight +1, dyHeight); + const int wend = min((posw+padLeft)/strideWidth +1, dyWidth); + const int channelStart = posn*dyCHW + posc*dyHW; + T dySum = static_cast(0.0); + for (int hcur = hstart; hcur < hend; ++hcur) { + for (int wcur = wstart; wcur < wend; ++wcur) { + const int curIdx = hcur*dyWidth + wcur; + S maxIdx = index[channelStart+curIdx]; + if (maxIdx == posIdx) { + dySum += dy[channelStart+curIdx]; + } + } + } + dx[pos] = dySum; + } + return; +} + +template <> +__global__ void MaxPoolWithArgmaxGrad(const half* x, + const half* dy, + const int* index, + const int n, + const int c, + const int xHeight, + const int xWidth, + const int dyHeight, + const int dyWidth, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + const int xNCHW, + const int xCHW, + const int xHW, + const int dyCHW, + const int dyHW, + half* dx) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; + pos < (xNCHW); + pos += blockDim.x * gridDim.x) { + const int posn = pos / xCHW; + const int posc = pos / xHW % c; + const int posh = pos / xHeight % xHeight; + const int posw = pos % xWidth; + const int posIdx = posh*xWidth + posw; + int hstart = posh+padTop; + if (hstart < windowHeight) { + hstart = 0; + } else { + hstart = (hstart-windowHeight)/strideHeight + 1; + } + int wstart = posw+padLeft; + if (wstart < windowWidth) { + wstart = 0; + } else { + wstart = (wstart-windowWidth)/strideWidth + 1; + } + const int hend = min((posh+padTop)/strideHeight +1, dyHeight); + const int wend = min((posw+padLeft)/strideWidth +1, dyWidth); + const int channelStart = posn*dyCHW + posc*dyHW; + float dySum = 0.0f; + for (int hcur = hstart; hcur < hend; ++hcur) { + for (int wcur = wstart; wcur < wend; ++wcur) { + const int curIdx = hcur*dyWidth + wcur; + int maxIdx = index[channelStart+curIdx]; + if (maxIdx == posIdx) { + dySum += __half2float(dy[channelStart+curIdx]); + } + } + } + dx[pos] = __float2half(dySum); + } + return; +} + +template +void CalMaxPoolWithArgmaxGrad(const T* x, + const T* dy, + const S* index, + const int n, + const int c, + const int xHeight, + const int xWidth, + const int dyHeight, + const int dyWidth, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + T* dx, + cudaStream_t cuda_stream) { + const int xHW = xHeight*xWidth; + const int xCHW = c*xHW; + const int xNCHW = n*xCHW; + const int dyHW = dyHeight*dyWidth; + const int dyCHW = c*dyHW; + MaxPoolWithArgmaxGrad<<>>( + x, + dy, + index, + n, + c, + xHeight, + xWidth, + dyHeight, + dyWidth, + windowHeight, + windowWidth, + strideHeight, + strideWidth, + padTop, + padLeft, + xNCHW, + xCHW, + xHW, + dyCHW, + dyHW, + dx); + return; +} + +template void CalMaxPoolWithArgmaxGrad(const float* x, + const float* dy, + const int* index, + const int n, + const int c, + const int xHeight, + const int xWidth, + const int dyHeight, + const int dyWidth, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + float* dx, + cudaStream_t cuda_stream); +template void CalMaxPoolWithArgmaxGrad(const half* x, + const half* dy, + const int* index, + const int n, + const int c, + const int xHeight, + const int xWidth, + const int dyHeight, + const int dyWidth, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + half* dx, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh new file mode 100644 index 0000000000..fe378acec6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_ +template +void CalMaxPoolWithArgmaxGrad(const T* x, const T* dy, const S* index, const int n, const int c, const int xHeight, + const int xWidth, const int dyHeight, const int dyWidth, const int windowHeight, + const int windowWidth, const int strideHeight, const int strideWidth, const int padTop, + const int padLeft, T* dx, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu new file mode 100644 index 0000000000..7126a3feda --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "maxpool_with_argmax_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "include/cuda_fp16.h" +template +__global__ void MaxPoolWithArgmax(const T* input, + const int n, + const int c, + const int h, + const int w, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + const int outputHeight, + const int outputWidth, + const int outputNCHW, + const int outputCHW, + const int outputHW, + T* output, + S *index) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; + pos < (outputNCHW); + pos += blockDim.x * gridDim.x) { + const int posn = pos / outputCHW; + const int posc = pos / outputHW % c; + const int posh = pos / outputHeight % outputHeight; + const int posw = pos % outputWidth; + int hstart = posh * strideHeight - padTop; + int wstart = posw * strideWidth - padLeft; + const int hend = min(hstart + windowHeight, h); + const int wend = min(wstart + windowWidth, w); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + S inputStart = posn*c*h*w + posc*h*w; + S maxIdx = hstart*w + wstart; + T maxData = input[inputStart+maxIdx]; + for (int hcur = hstart; hcur < hend; ++hcur) { + for (int wcur = wstart; wcur < wend; ++wcur) { + S inputIdx = hcur*w + wcur; + T inputData = input[inputStart+inputIdx]; + if (inputData > maxData) { + maxIdx = inputIdx; + maxData = inputData; + } + } + } + output[pos] = maxData; + index[pos] = maxIdx; + } + return; +} + +template +void CalMaxPoolWithArgmax(const T* input, + const int n, + const int c, + const int h, + const int w, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + const int outputHeight, + const int outputWidth, + T* output, + S *index, + cudaStream_t cuda_stream) { + const int outputNCHW = n*c*outputHeight*outputWidth; + const int outputCHW = c*outputHeight*outputWidth; + const int outputHW = outputHeight*outputWidth; + MaxPoolWithArgmax<<>>( + input, + n, + c, + h, + w, + windowHeight, + windowWidth, + strideHeight, + strideWidth, + padTop, + padLeft, + outputHeight, + outputWidth, + outputNCHW, + outputCHW, + outputHW, + output, + index); + return; +} + +template void CalMaxPoolWithArgmax(const float* input, + const int n, + const int c, + const int h, + const int w, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + const int outputHeight, + const int outputWidth, + float* output, + int* index, + cudaStream_t cuda_stream); + +template void CalMaxPoolWithArgmax(const half* input, + const int n, + const int c, + const int h, + const int w, + const int windowHeight, + const int windowWidth, + const int strideHeight, + const int strideWidth, + const int padTop, + const int padLeft, + const int outputHeight, + const int outputWidth, + half* output, + int* index, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh new file mode 100644 index 0000000000..8b088067ed --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_ +template +void CalMaxPoolWithArgmax(const T* input, const int n, const int c, const int h, const int w, const int windowHeight, + const int windowWidth, const int strideHeight, const int strideWidth, const int padTop, + const int padLeft, const int outputHeight, const int outputWidth, T* output, S *index, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu new file mode 100755 index 0000000000..c1117b85ff --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh" + +__inline__ __device__ bool range_check(int x, int y, int padded_width, int padded_height) { + // check for existence in current padded array + if (((x >= 0) && (x <= padded_width - 1)) && ((y >= 0) && (y <= padded_height - 1))) { + return true; + } + return false; +} + +template +__global__ void MirrorPad(const size_t size, const T *input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int padd_dim, + const int *paddings, int mode, T *output) { + int padd_offset = 4 * (padd_dim - 2); + int pad_left_ = paddings[padd_offset + 4]; + int pad_top_ = paddings[padd_offset + 0]; + + // Create anchor points for old tensor positions inside new tensor + int ap1_x = pad_left_; + int ap1_y = pad_top_; + int ap2_x = pad_left_ + old_width - 1; + int ap2_y = pad_top_ + old_height - 1; + + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int block_num = (pos / padded_width) / padded_height; + const int padded_x = pos % padded_width; + const int padded_y = (pos / padded_width) % padded_height; + + // distance to move from anchor point + int x_dist = 0; + int y_dist = 0; + + // x,y value to mirror in new tenspr + int matchval_x_index = padded_x; + int matchval_y_index = padded_y; + + if (padded_y - pad_top_ < 0 || padded_x - pad_left_ < 0 || padded_y - pad_top_ >= old_height || + padded_x - pad_left_ >= old_width) { + if ((padded_x < ap1_x) || (padded_x > ap2_x)) { + x_dist = (padded_x < ap1_x) ? (ap1_x - padded_x) : (padded_x - ap2_x); // GEN DIST + matchval_x_index = (padded_x < ap1_x) ? (ap1_x + x_dist - mode) : (ap2_x - x_dist + mode); + } + if ((padded_y < ap1_y) || (padded_y > ap2_y)) { + y_dist = (padded_y < ap1_y) ? (ap1_y - padded_y) : (padded_y - ap2_y); + matchval_y_index = (padded_y < ap1_y) ? (ap1_y + y_dist - mode) : (ap2_y - y_dist + mode); + } + output[pos] = + input[(block_num * old_height + matchval_y_index - pad_top_) * old_width + matchval_x_index - pad_left_]; + } else { + // existing values remain the same + output[pos] = input[(block_num * old_height + padded_y - pad_top_) * old_width + padded_x - pad_left_]; + } + } + return; +} + +template +__global__ void MirrorPadGrad(const size_t size, const T *dy, const int num, const int channels, + const int padded_height, const int padded_width, const int old_height, + const int old_width, const int padd_dim, const int *paddings, int mode, T *dx) { + int padd_offset = 4 * (padd_dim - 2); + int pad_left_ = paddings[padd_offset + 4]; + int pad_top_ = paddings[padd_offset + 0]; + + // Create anchor points for positions in the dy array + int ap1_x = pad_left_; + int ap1_y = pad_top_; + int ap2_x = pad_left_ + old_width - 1; + int ap2_y = pad_top_ + old_height - 1; + + int adjust = 0; // adjust dist from reflection axis for symmetric padding + if (mode == 1) { + adjust = 1; + } + + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int block_num = (pos / old_width) / old_height; + + // refer to indices of original values inside padded array + const int padded_x = (pos % old_width) + pad_left_; + const int padded_y = ((pos / old_width) % old_height) + pad_top_; + + // copy positions own value into output + dx[pos] = dx[pos] + dy[(block_num * padded_height + padded_y) * padded_width + padded_x]; + + int x_dist_1 = (ap1_x - padded_x - adjust); + int y_dist_1 = (ap1_y - padded_y - adjust); + int x_dist_2 = (ap2_x - padded_x + adjust); + int y_dist_2 = (ap2_y - padded_y + adjust); + + int axis_dist[] = {x_dist_1, x_dist_2, y_dist_1, y_dist_2}; + int anch_point[] = {ap1_x, ap2_x, ap1_y, ap2_y}; + bool x_axis_check[] = {true, true, false, false}; // true - update X , false - update Y + + int temp_x = 0; + int temp_y = 0; + + // mirroring in axis lines + for (int x = 0; x < 4; x++) { + if (axis_dist[x] != 0) { + if (x_axis_check[x]) { + temp_y = padded_y; + temp_x = anch_point[x] + axis_dist[x]; + } else { + temp_x = padded_x; + temp_y = anch_point[x] + axis_dist[x]; + } + if (range_check(temp_x, temp_y, padded_width, padded_height)) { + dx[pos] = dx[pos] + dy[(block_num * padded_height + temp_y) * padded_width + temp_x]; + } + } + } + + // mirroring at corners + for (int x = 0; x < 2; x++) { + for (int y = 2; y < 4; y++) { + if ((axis_dist[x] != 0) && (axis_dist[y] != 0)) { + temp_x = anch_point[x] + axis_dist[x]; + temp_y = anch_point[y] + axis_dist[y]; + if (range_check(temp_x, temp_y, padded_width, padded_height)) { + dx[pos] = dx[pos] + dy[(block_num * padded_height + temp_y) * padded_width + temp_x]; + } + } + } + } + } + return; +} + +template +void CalMirrorPad(const size_t size, const T *input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, int padd_num, + const int *paddings, const int mode, T *output, cudaStream_t cuda_stream) { + MirrorPad<<>>( + size, input, num, channels, old_height, old_width, padded_height, padded_width, padd_num, paddings, mode, output); + return; +} + +template +void CalMirrorPadGrad(const size_t size, const T *dy, const int num, const int channels, const int padded_height, + const int padded_width, const int old_height, const int old_width, const int padd_dim, + const int *paddings, int mode, T *dx, cudaStream_t cuda_stream) { + MirrorPadGrad<<>>(size, dy, num, channels, padded_height, padded_width, + old_height, old_width, padd_dim, paddings, mode, dx); + return; +} + +template void CalMirrorPad(const size_t size, const float *input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, int padd_num, const int *paddings, int mode, float *output, + cudaStream_t cuda_stream); +template void CalMirrorPadGrad(const size_t size, const float *dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int padd_dim, const int *paddings, int mode, + float *dx, cudaStream_t cuda_stream); +template void CalMirrorPad(const size_t size, const half *input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, int padd_num, const int *paddings, int mode, half *output, + cudaStream_t cuda_stream); +template void CalMirrorPadGrad(const size_t size, const half *dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int padd_dim, const int *paddings, int mode, + half *dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh new file mode 100755 index 0000000000..d2bf4dff11 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_ +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void CalMirrorPad(const size_t size, const T *input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, int padd_num, + const int *paddings, int mode, T *output, cudaStream_t cuda_stream); +template +void CalMirrorPadGrad(const size_t size, const T *dy, const int num, const int channels, const int padded_height, + const int padded_width, const int old_height, const int old_width, const int padd_dim, + const int *paddings, int mode, T *dx, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu index 5a1c9eb687..03a4ccb617 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu @@ -15,9 +15,9 @@ */ #include "momentum_impl.cuh" -template +template __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate, - const T *gradient, const S *momentum) { + const G *gradient, const S *momentum) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { accumulation[i] = momentum[0] * accumulation[i] + gradient[i]; variable[i] -= learning_rate[0] * accumulation[i]; @@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, } return; } -template -void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, +template <> +__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation, + const float *learning_rate, const half *gradient, + const float *momentum) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]); + variable[i] -= learning_rate[0] * accumulation[i]; + } + return; +} +template +void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient, const S *momentum, cudaStream_t cuda_stream) { MomentumUpdateVariableKernel<<>>(size, variable, accumulation, learning_rate, gradient, momentum); return; } -template void MomentumUpdateVariable(const size_t size, float *variable, float *accumulation, +template void MomentumUpdateVariable(const size_t size, float *variable, float *accumulation, const float *learning_rate, const float *gradient, const float *momentum, cudaStream_t cuda_stream); -template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation, +template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation, const half *learning_rate, const half *gradient, const half *momentum, cudaStream_t cuda_stream); -template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation, +template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation, + const float *learning_rate, const half *gradient, + const float *momentum, cudaStream_t cuda_stream); +template void MomentumUpdateVariable(const size_t size, float *variable, float *accumulation, const float *learning_rate, const half *gradient, const float *momentum, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh index 62708663ad..e5a22e4791 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh @@ -18,8 +18,8 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ #include "runtime/device/gpu/cuda_common.h" -template -void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, +template +void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient, const S *momentum, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu new file mode 100644 index 0000000000..36c1b2ee48 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, softwareg + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nms_with_mask_impl.cuh" +#include +#include + +int RoundUpPower2M(int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__inline__ __device__ void SwapM(T *lhs, T *rhs) { + T tmp = lhs[0]; + lhs[0] = rhs[0]; + rhs[0] = tmp; +} + +template +__inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start, T *area, + float IOU_value) { + T x_1 = max(output[box_A_start + 0], output[box_B_start + 0]); + T y_1 = max(output[box_A_start + 1], output[box_B_start + 1]); + T x_2 = min(output[box_A_start + 2], output[box_B_start + 2]); + T y_2 = min(output[box_A_start + 3], output[box_B_start + 3]); + T width = max(x_2 - x_1, T(0)); // in case of no overlap + T height = max(y_2 - y_1, T(0)); + T combined_area = area[box_A_ix] + area[box_B_ix]; + // return decision to keep or remove box + return !(((width * height) / (combined_area - (width * height))) > IOU_value); +} + +template +__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) { + for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) { + sel_idx[box_num] = box_num; + area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) * + (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]); + } +} + +template +__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, + int box_size_) { + for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) { + // represents highest score box in that GPU block + if (threadIdx.x == 0) { + sel_boxes[box_num] = true; + continue; + } + int box_start_index = box_num * box_size_; // start index adjustment + int block_max_box_num = ((blockIdx.x * blockDim.x) + 0); + int block_max_box_start_index = block_max_box_num * box_size_; // start index adjustment + sel_boxes[box_num] = + IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area, + IOU_value); // update mask + } +} + +template +__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) { + int box_i, box_j; // access all shared mem meta data with these + int box_i_start_index, box_j_start_index; // actual input data indexing + for (int i = 0; i < num - 1; i++) { + box_i = i; + box_i_start_index = box_i * box_size_; // adjust starting index + if (sel_boxes[box_i]) { + for (int j = i + 1; j < num; j++) { + box_j = j; + box_j_start_index = box_j * box_size_; + if (sel_boxes[box_j]) { + sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value); + } + } + } + } +} + +template +__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in, + S *data_out, T *index_buff, S *data_buff, int box_size_) { + // default: sort with share memory + extern __shared__ T share_mem_NMS[]; + T *index_arr = share_mem_NMS; + S *data_arr = reinterpret_cast(index_arr + ceil_power2); + // sort with RAM + if (index_buff != nullptr && data_buff != nullptr) { + index_arr = index_buff + blockIdx.x * ceil_power2; + data_arr = data_buff + blockIdx.x * ceil_power2; + } + for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) { + index_arr[i] = (i < inner) ? T(i) : std::numeric_limits::max(); + // populated directly from input data + data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits::max(); + } + __syncthreads(); + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + if ((tid & i) == 0) { + if (data_arr[tid] > data_arr[tid_comp]) { + SwapM(&index_arr[tid], &index_arr[tid_comp]); + SwapM(&data_arr[tid], &data_arr[tid_comp]); + } + } else { + if (data_arr[tid] < data_arr[tid_comp]) { + SwapM(&index_arr[tid], &index_arr[tid_comp]); + SwapM(&data_arr[tid], &data_arr[tid_comp]); + } + } + } + } + __syncthreads(); + } + } + T correct_index; + for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) { + correct_index = index_arr[(inner - 1) - tid]; + // moved data from input to output, correct ordering using sorted index array + for (auto i : {0, 1, 2, 3, 4}) { + data_out[(blockIdx.x * inner + tid) * box_size_ + i] = + data_in[(blockIdx.x * inner + correct_index) * box_size_ + i]; + } + } +} + +template +void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) { + Preprocess<<>>(num, sel_idx, area, output, box_size_); +} + +template +void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff, + int box_size_, cudaStream_t stream) { + int ceil_power2 = RoundUpPower2M(inner); + size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S)); + if (share_mem > SHARED_MEM_PER_BLOCK) { + share_mem = 0; + } else { + data_buff = nullptr; + index_buff = nullptr; + } + int thread = std::min(ceil_power2, GET_THREADS); + BitonicSortByKeyKernelM<<>>(outer, inner, ceil_power2, data_in, data_out, + index_buff, data_buff, box_size_); +} + +template +void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, + cudaStream_t cuda_stream) { + NMSWithMaskKernel<<>>(num, IOU_value, output, area, sel_boxes, + box_size_); +} + +template +void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, + cudaStream_t cuda_stream) { + FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_); +} + +template void CalPreprocess(const int num, int *sel_idx, float *area, float *output, int box_size_, + cudaStream_t cuda_stream); + +template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff, + float *data_buff, int box_size_, cudaStream_t stream); + +template void CalNMSWithMask(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes, + int box_size_, cudaStream_t cuda_stream); + +template void CalFinalPass(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes, + int box_size_, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh new file mode 100644 index 0000000000..0eafd51389 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream); + +template +void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, + cudaStream_t cuda_stream); + +template +void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff, + int box_size_, cudaStream_t stream); + +template +void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu new file mode 100644 index 0000000000..dc1d9cd206 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "oneslike_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void OnesLike(const int size, const T* input, T* output) { + int one = 1; + T val = static_cast(one); + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = val; + } + return; +} +template +void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream) { + OnesLike<<>>(size, input, output); + return; +} + +template void CalOnesLike(const int size, const float* input, float* output, cudaStream_t cuda_stream); +template void CalOnesLike(const int size, const half* input, half* output, cudaStream_t cuda_stream); +template void CalOnesLike(const int size, const int* input, int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh new file mode 100644 index 0000000000..81e92c1d09 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ + +template +void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu new file mode 100644 index 0000000000..6ce1fda22b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu @@ -0,0 +1,265 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" +#include + +int RcwmRoundUpPower2(int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__inline__ __device__ void Swap(T *lhs, T *rhs) { + T tmp = lhs[0]; + lhs[0] = rhs[0]; + rhs[0] = tmp; +} + +template +__global__ void InitArray(const int input_size, const int ceil_power2, const T *input, S *mask_buff, S *rank_buff) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ceil_power2; pos += blockDim.x * gridDim.x) { + mask_buff[pos] = (pos < input_size) ? static_cast(input[pos]) : 0; + rank_buff[pos] = (pos < input_size && input[pos] != false) ? pos : (ceil_power2 + 1); + } +} + +template +__device__ void WarpReduce(volatile T *sdata, size_t tid) { + if (blockSize >= 64) sdata[tid] += sdata[tid + 32]; + if (blockSize >= 32) sdata[tid] += sdata[tid + 16]; + if (blockSize >= 16) sdata[tid] += sdata[tid + 8]; + if (blockSize >= 8) sdata[tid] += sdata[tid + 4]; + if (blockSize >= 4) sdata[tid] += sdata[tid + 2]; + if (blockSize >= 2) sdata[tid] += sdata[tid + 1]; +} + +template +__global__ void ReductionSum(T *g_idata, T *g_odata, size_t n) { + __shared__ T sdata[blockSize]; + + size_t tid = threadIdx.x; + size_t i = blockIdx.x * (blockSize) + tid; + size_t gridSize = blockSize * gridDim.x; + sdata[tid] = 0; + + while (i < n) { + sdata[tid] += g_idata[i]; + i += gridSize; + } + + __syncthreads(); + + if (blockSize >= 1024) { + if (tid < 512) { + sdata[tid] += sdata[tid + 512]; + } + __syncthreads(); + } + if (blockSize >= 512) { + if (tid < 256) { + sdata[tid] += sdata[tid + 256]; + } + __syncthreads(); + } + if (blockSize >= 256) { + if (tid < 128) { + sdata[tid] += sdata[tid + 128]; + } + __syncthreads(); + } + if (blockSize >= 128) { + if (tid < 64) { + sdata[tid] += sdata[tid + 64]; + } + __syncthreads(); + } + + if (tid < 32) WarpReduce(sdata, tid); + if (tid == 0) g_odata[blockIdx.x] = sdata[0]; +} + +template +__global__ void Reshape2Index(const int input_size, const int input_shape_size, const int d1, const int d2, + const int d3, const int d4, const int d5, const T *input, S *output_index) { + int pos_array[MAX_DIMENSION]; + int index_pos; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) { + pos_array[0] = pos / (d2 * d3 * d4 * d5) % d1; + pos_array[1] = pos / (d3 * d4 * d5) % d2; + pos_array[2] = pos / (d4 * d5) % d3; + pos_array[3] = pos / (d5) % d4; + pos_array[4] = pos % d5; + + index_pos = pos * input_shape_size; + if (input[pos] == false) { + for (int i = 0; i < input_shape_size; i++) { + output_index[index_pos++] = 0; + } + } else { + for (int i = MAX_DIMENSION - input_shape_size; i < MAX_DIMENSION; i++) { + output_index[index_pos++] = pos_array[i]; + } + } + } +} + +template +__global__ void Copy(const T *src, T *dst, const int n) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n; pos += blockDim.x * gridDim.x) { + dst[pos] = src[pos]; + } +} + +template +__global__ void Sort(const int ceil_power2, T *rank_buff) { + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + if ((tid & i) == 0) { + if (rank_buff[tid] > rank_buff[tid_comp]) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); + } + } else { + if (rank_buff[tid] < rank_buff[tid_comp]) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ void SrandInit(const int ceil_power2, curandState *globalState, const int seedc) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < ceil_power2; i += blockDim.x * gridDim.x) { + curand_init(seedc, i, 0, &globalState[i]); + } +} + +template +__global__ void Shuffle(const int ceil_power2, curandState *globalState, T *rank_buff) { + int limit = ceil_power2 + 1; + int value; + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + value = static_cast(curand(&globalState[tid])); + if (value & 1) { + if (rank_buff[tid] != limit && rank_buff[tid_comp] != limit) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void MoveToOutput(const int input_shape_size, const int count, const T *input, S *output_index, + T *output_mask, S *index_buff, S *rank_buff, S *Tnum_buff) { + int Tnum = static_cast(Tnum_buff[0]); + int idx = 0; + int pos; + if (count <= Tnum) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + idx = rank_buff[i]; + pos = i; + output_mask[pos] = input[idx]; + pos *= input_shape_size; + idx *= input_shape_size; + for (size_t j = 0; j < input_shape_size; j++) { + output_index[pos] = index_buff[idx]; + pos++; + idx++; + } + } + } else { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + if (i < Tnum) { + idx = rank_buff[i]; + pos = i; + output_mask[pos] = input[idx]; + pos *= input_shape_size; + idx *= input_shape_size; + for (size_t j = 0; j < input_shape_size; j++) { + output_index[pos] = index_buff[idx]; + pos++; + idx++; + } + } else { + pos = i; + output_mask[pos] = static_cast(0); + pos *= input_shape_size; + for (size_t j = 0; j < input_shape_size; j++) { + output_index[pos] = static_cast(0); + pos++; + } + } + } + } +} + +template +void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, + const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, + const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff, + S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream) { + int ceil_power2 = RcwmRoundUpPower2(input_size); + + InitArray<<>>(input_size, ceil_power2, input, mask_buff, rank_buff); + + size_t BLOCKNUM; + size_t n = ceil_power2; + Copy<<>>(mask_buff, tmp_buff, ceil_power2); + do { + BLOCKNUM = std::ceil(static_cast(n) / BLOCKSIZE); + ReductionSum<<>>(tmp_buff, Tnum_buff, n); + Copy<<>>(Tnum_buff, tmp_buff, BLOCKNUM); + n = BLOCKNUM; + } while (n > BLOCKSIZE); + if (n > 1) ReductionSum<<<1, BLOCKSIZE, 0, stream>>>(Tnum_buff, Tnum_buff, n); + + Reshape2Index<<>>(input_size, input_shape_size, d1, d2, d3, d4, d5, + input, index_buff); + + Sort<<>>(ceil_power2, rank_buff); + + SrandInit<<>>(ceil_power2, globalState, seedc); + Shuffle<<>>(ceil_power2, globalState, rank_buff); + + MoveToOutput<<>>(input_shape_size, count, input, output_index, output_mask, + index_buff, rank_buff, Tnum_buff); +} + +template void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, + const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, + const bool *input, int *output_index, bool *output_mask, int *index_buff, + int *mask_buff, int *rank_buff, int *Tnum_buff, int *tmp_buff, + curandState *globalState, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh new file mode 100644 index 0000000000..bb654e4b58 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_ + +#include +#include +#include "runtime/device/gpu/cuda_common.h" +#define BLOCKSIZE 256 +#define MAX_DIMENSION 5 + +template +void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, + const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, + const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff, + S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream); + +int RcwmRoundUpPower2(int v); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu index 6f99394562..19a1273cb3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu @@ -24,6 +24,18 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size return; } +template +__global__ void UniformKernel(int seed, curandState *globalState, T *input1, size_t input_size_1, + T *input2, size_t input_size_2, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + input1[i] = (input_size_1 == 1 ? input1[0] : input1[i]); + input2[i] = (input_size_2 == 1 ? input2[0] : input2[i]); + curand_init(seed, i, 0, &globalState[i]); + output[i] = curand_uniform(&globalState[i]) * (input2[i] - input1[i]) + input1[i]; + } + return; +} + template void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) { int RNG_seed = 0; @@ -38,5 +50,17 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si return; } +template +void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1, + T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) { + seed = (seed == 0 ? time(NULL):seed); + UniformKernel<<>> + (seed, globalState, input1, input_size_1, input2, input_size_2, output, count); + return; +} + template void StandardNormal(int seed, int seed2, curandState *globalState, float *output, size_t count, cudaStream_t cuda_stream); +template void UniformReal(int seed, curandState *globalState, float *input1, size_t input_size_1, + float *input2, size_t input_size_2, float *output, size_t count, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh index b099ead9bf..f5699cee0a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh @@ -23,4 +23,8 @@ template void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream); +template +void UniformReal(int seed, curandState *globalState, + T *input1, size_t input_size_1, T *input2, size_t input_size_2, + T *output, size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu new file mode 100644 index 0000000000..edb509a38d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu @@ -0,0 +1,90 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" + +template +__global__ void InitZero(T *output, const int output_size) { + for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (output_size); pos += gridDim.x * blockDim.x) { + output[pos] = static_cast(0); + } +} + +template +__global__ void ResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, + const int s3, const int s4, T *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale) { + // initialization + // HalfPixelCenters false + int output_pos; + int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION]; + int out_height = d3; + int out_width = d4; + // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + + // pos_array[1] * output_shape[2] * output_shape[3] + + // pos_array[2] * output_shape[3] + + // pos_array[3] + int in_h; + int in_w; + + for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (input_size); pos += gridDim.x * blockDim.x) { + pos_array[0] = pos / (s2 * s3 * s4) % s1; + pos_array[1] = pos / (s3 * s4) % s2; + pos_array[2] = pos / (s4) % s3; + pos_array[3] = pos % s4; + in_h = pos_array[2]; + in_w = pos_array[3]; + const int out_y = + min((align_corners) ? static_cast(roundf(in_h * h_scale)) : static_cast(floorf(in_h * h_scale)), + out_height - 1); + const int out_x = + min((align_corners) ? static_cast(roundf(in_w * w_scale)) : static_cast(floorf(in_w * w_scale)), + out_width - 1); + // pos_array[0] N, pos_array[1] C, out_y H, out_x W + output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x; + ms_atomic_add(&output[output_pos], input[pos]); + } +} + +template +void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, const int d4, + bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) { + int output_size = d1 * d2 * d3 * d4; + InitZero<<>>(output, output_size); + ResizeNearestNeighborGrad<<>>( + input_size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); + return; +} + +template void CalResizeNearestNeighborGrad(const int input_size, const float *input, const int s1, const int s2, + const int s3, const int s4, float *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighborGrad(const int input_size, const half *input, const int s1, const int s2, + const int s3, const int s4, half *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighborGrad(const int input_size, const int *input, const int s1, const int s2, + const int s3, const int s4, int *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh new file mode 100644 index 0000000000..c7f85e694a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ +#include +#include "runtime/device/gpu/cuda_common.h" +#define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4 + +template +void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, const int d4, + bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu new file mode 100644 index 0000000000..2cca9bd7a8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh" + +template +__global__ void ResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, const int d4, + bool align_corners, float h_scale, float w_scale) { + // initialization + // HalfPixelCenters false + int input_pos; + int pos_array[RESIZENEARESTNEIGHBOR_DIMENSION]; + int in_height = s3; + int in_width = s4; + // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + + // pos_array[1] * output_shape[2] * output_shape[3] + + // pos_array[2] * output_shape[3] + + // pos_array[3] + int out_h; + int out_w; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + pos_array[0] = pos / (d2 * d3 * d4) % d1; + pos_array[1] = pos / (d3 * d4) % d2; + pos_array[2] = pos / (d4) % d3; + pos_array[3] = pos % d4; + out_h = pos_array[2]; + out_w = pos_array[3]; + const int in_y = + min((align_corners) ? static_cast(roundf(out_h * h_scale)) : static_cast(floorf(out_h * h_scale)), + in_height - 1); + const int in_x = + min((align_corners) ? static_cast(roundf(out_w * w_scale)) : static_cast(floorf(out_w * w_scale)), + in_width - 1); + // pos_array[0] N, pos_array[1] C, in_y H, in_x W + input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; + output[pos] = input[input_pos]; + } + return; +} + +template +void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4, + T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, + float h_scale, float w_scale, cudaStream_t cuda_stream) { + ResizeNearestNeighbor<<>>(size, input, s1, s2, s3, s4, output, d1, d2, + d3, d4, align_corners, h_scale, w_scale); + return; +} + +template void CalResizeNearestNeighbor(const int size, const float *input, const int s1, const int s2, + const int s3, const int s4, float *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighbor(const int size, const half *input, const int s1, const int s2, + const int s3, const int s4, half *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighbor(const int size, const int *input, const int s1, const int s2, const int s3, + const int s4, int *output, const int d1, const int d2, const int d3, + const int d4, bool align_corners, float h_scale, float w_scale, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh new file mode 100644 index 0000000000..a9eafe36ce --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ +#include +#include "runtime/device/gpu/cuda_common.h" +#define RESIZENEARESTNEIGHBOR_DIMENSION 4 + +template +void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4, + T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, + float h_scale, float w_scale, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu new file mode 100644 index 0000000000..5706aa15fc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu @@ -0,0 +1,237 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "roi_align_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high, + int *y_high, T *w1, T *w2, T *w3, T *w4) { + // return 0 if out of map boundary + if (y <= static_cast(-1.0) || y >= static_cast(height) || x <= static_cast(-1.0) || + x >= static_cast(width)) { + *w1 = *w2 = *w3 = *w4 = 0; + *x_low = *x_high = *y_low = *y_high = -1; + return; + } + + // low bounder is at least zero + y = y <= static_cast(.0) ? static_cast(.0) : y; + x = x <= static_cast(.0) ? static_cast(.0) : x; + + // top left point + *y_low = static_cast(y); + *x_low = static_cast(x); + + // bottom right point + if (*y_low >= height - 1) { + *y_high = *y_low = height - 1; + y = static_cast(*y_low); + } else { + *y_high = *y_low + 1; + } + + if (*x_low >= width - 1) { + *x_high = *x_low = width - 1; + x = static_cast(*x_low); + } else { + *x_high = *x_low + 1; + } + + // distance to nearest points + T lx, ly, hx, hy; + ly = y - static_cast(*y_low), lx = x - static_cast(*x_low); + hy = static_cast(1.) - ly, hx = static_cast(1.) - lx; + + // weight is evaluated by the distance to point away. + // the closer to point home, the more weight, the farther to point away. + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; + return; +} + +template +__device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_scale, const int sample_num, + int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, int *offset, int *n, int *c, int *ph, int *pw, + int *roi_bin_grid_h, int *roi_bin_grid_w, T *bin_size_h, T *bin_size_w, T *roi_start_h, + T *roi_start_w) { + // (n, c, ph, pw) is the base param of pooled map + *pw = thread_idx % pooled_width; + *ph = (thread_idx / pooled_width) % pooled_height; + *c = (thread_idx / pooled_width / pooled_height) % channels; + *n = thread_idx / pooled_width / pooled_height / channels; + + // Roi has + // 1. 4 points, or + // 2. indicator + 4 points (1 + 4) + const T *roi_box = roi_boxes + (*n) * roi_cols; + int roi_batch_ind = 0; + if (roi_cols == 5) { + roi_batch_ind = roi_box[0]; + roi_box++; + } + + // Scale and shift ROI + T roi_offset = roi_end_mode == 1 ? static_cast(0.5) : static_cast(.0); + *roi_start_w = roi_box[0] * spatial_scale - roi_offset; + *roi_start_h = roi_box[1] * spatial_scale - roi_offset; + T roi_end_w = roi_box[2] * spatial_scale - roi_offset; + T roi_end_h = roi_box[3] * spatial_scale - roi_offset; + + // New ROI height/width + T roi_width = roi_end_w - (*roi_start_w); + T roi_height = roi_end_h - (*roi_start_h); + + // ratio of roi / pooled + *bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + *bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + *offset = (roi_batch_ind * channels + (*c)) * height * width; + + // grid (int) by Sample ratio if defined, otherwise by pooled H/W + *roi_bin_grid_h = (sample_num > 0) ? sample_num : static_cast(roi_height / static_cast(pooled_height)); + *roi_bin_grid_w = (sample_num > 0) ? sample_num : static_cast(roi_width / static_cast(pooled_width)); + return; +} + +template +__global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes, int roi_cols, T *out_data, + const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width) { + for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; + thread_idx += blockDim.x * gridDim.x) { + int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; + T bin_size_h, bin_size_w, roi_start_h, roi_start_w; + + bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, + pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h, + &bin_size_w, &roi_start_h, &roi_start_w); + + // (n, c, ph, pw) is the base param of pooled map + const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w; + + T accumulate_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + // Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT + const T y = roi_start_h + static_cast(ph) * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + static_cast(pw) * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + // bilinear interpolate by shifted y / x + // calculate bilinear interpolation + int x_low, y_low, x_high, y_high; + T w1, w2, w3, w4; + bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4); + T v1 = input[y_low * width + x_low + offset]; + T v2 = input[y_low * width + x_high + offset]; + T v3 = input[y_high * width + x_low + offset]; + T v4 = input[y_high * width + x_high + offset]; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + accumulate_val += val; + } + } + accumulate_val /= count_points_in_grid_cell; + + out_data[thread_idx] = accumulate_val; + } +} + +template +void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale, + const int sample_num, int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) { + size_t size = roi_rows * channels * pooled_height * pooled_width; + ROIAlignKernel<<>>(size, x, roi_boxes, roi_cols, out_data, + spatial_scale, sample_num, roi_end_mode, channels, + height, width, pooled_height, pooled_width); + return; +} + +template void ROIAlign(const float *x, const float *roi_boxes, int roi_rows, int roi_cols, float *out_data, + const float spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width, + cudaStream_t cuda_stream); + +template void ROIAlign(const half *x, const half *roi_boxes, int roi_rows, int roi_cols, half *out_data, + const half spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width, + cudaStream_t cuda_stream); + +template +__global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, int roi_cols, T *dx, + const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width) { + for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; + thread_idx += blockDim.x * gridDim.x) { + int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; + T bin_size_h, bin_size_w, roi_start_h, roi_start_w; + + bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, + pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h, + &bin_size_w, &roi_start_h, &roi_start_w); + + // (n, c, ph, pw) is the base param of pooled map + const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T *offset_top_diff = dy + top_offset; + const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + // Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT + const T y = roi_start_h + static_cast(ph) * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + static_cast(pw) * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + // bilinear interpolate by shifted y / x + // calculate bilinear interpolation + int x_low, y_low, x_high, y_high; + T w1, w2, w3, w4; + bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4); + T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell; + T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell; + T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell; + T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(dx + offset + y_low * width + x_low, static_cast(g1)); + atomicAdd(dx + offset + y_low * width + x_high, static_cast(g2)); + atomicAdd(dx + offset + y_high * width + x_low, static_cast(g3)); + atomicAdd(dx + offset + y_high * width + x_high, static_cast(g4)); + } + } + } + } +} + +template +void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T *dx, const T spatial_scale, + const int sample_num, int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) { + size_t size = roi_rows * channels * pooled_height * pooled_width; + ROIAlignGradKernel<<>>( + size, dy, roi_boxes, roi_cols, dx, spatial_scale, sample_num, roi_end_mode, channels, height, width, pooled_height, + pooled_width); + return; +} + +template void ROIAlignGrad(const float *dy, const float *roi_boxes, int roi_rows, int roi_cols, float *dx, + const float spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh new file mode 100644 index 0000000000..aad65e7ba3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ +template +void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale, + const int sample_num, int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, cudaStream_t cuda_stream); + +template +void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T *dx, const T spatial_scale, + const int sample_num, int roi_end_mode, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu new file mode 100644 index 0000000000..80258e718d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size, + const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, + S *indices_stride, S *work_shape) { + int i, j; + for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size; + read_index += blockDim.x * gridDim.x) { + int write_index = 0; + bool out_bound = false; + + i = read_index / block_size; + j = read_index % block_size; + + for (size_t k = 0; k < indices_dim_1; k++) { + S indices_i = indices[i * indices_dim_1 + k]; + out_bound |= indices_i >= work_shape[k]; + write_index += indices_i * indices_stride[k]; + } + + write_index += j; + out_bound |= write_index >= output_size; + + if (!out_bound) { + ms_atomic_add(&output[write_index], update[read_index]); + } + } +} + +template +void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, + S *work_shape, cudaStream_t stream) { + ScatterNdKernel<<>>(indices, update, output, block_size, input_size, + output_size, indices_dim_0, indices_dim_1, + indices_stride, work_shape); + return; +} + +template void ScatterNd(int *indices, float *update, float *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); +template void ScatterNd(int *indices, half *update, half *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); +template void ScatterNd(int *indices, int *update, int *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh new file mode 100644 index 0000000000..7573239743 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SCATTER_ND_GPU_CU_H +#define MINDSPORE_SCATTER_ND_GPU_CU_H + +#include "runtime/device/gpu/cuda_common.h" + +template +void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, + S *work_shape, cudaStream_t stream); +#endif // MINDSPORE_SCATTER_ND_GPU_CU_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu new file mode 100644 index 0000000000..66451c2390 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh" + +template +__global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad, + const T *momentum, const T *lr, T *param, T *accum, T *stat) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + T grad_new = grad[i]; + if (weight_decay > static_cast(0)) { + grad_new += param[i] * weight_decay; + } + + if (momentum[0] > static_cast(0)) { + if (stat[i] > static_cast(0)) { + accum[i] = grad_new; + stat[i] = 0; + } else { + accum[i] = accum[i] * momentum[0] + (1.0 - dampening) * grad_new; + } + + if (nesterov) { + grad_new += accum[i] * momentum[0]; + } else { + grad_new = accum[i]; + } + } + + param[i] -= lr[0] * grad_new; + } +} + +template +void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum, + const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) { + SGDKernel<<>>(size, dampening, weight_decay, nesterov, grad, momentum, + lr, param, accum, stat); +} + +template void SGD(const int size, const float dampening, const float weight_decay, const bool nesterov, const float *lr, + const float *momentum, const float *grad, float *param, float *accum, float *stat, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh new file mode 100644 index 0000000000..bc2fa3304d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum, + const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu old mode 100755 new mode 100644 index dd4effc174..6e73e29b5a --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu @@ -21,48 +21,29 @@ #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" template -__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output) { +__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, + const T *input, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { int i = pos / (l2 * l3 * l4) % l1; int j = pos / (l3 * l4) % l2; int k = pos / l4 % l3; int o = pos % l4; - int offset = (i + s1) * (d2 * d3 * d4) + - (j + s2) * (d3 * d4) + - (k + s3) * d4 + - (o + s4); + int offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4); output[pos] = input[offset]; } } template -__global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { +__global__ void SliceGrad(const T *dy, int p, int start, int length, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { output[start + pos] = dy[p + pos]; } return; } + template -__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); - pos += blockDim.x * gridDim.x) { - output[p + pos] = input[start + pos * stride]; - } - return; -} -template -__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); - pos += blockDim.x * gridDim.x) { - dx[start + pos * stride] = dy[p + pos]; - } - return; -} -template -__global__ void FillArray(T* addr, const size_t len, const float value) { +__global__ void FillArray(T *addr, const size_t len, const float value) { T value_ = static_cast(value); for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < len; pos += blockDim.x * gridDim.x) { addr[pos] = value_; @@ -70,21 +51,20 @@ __global__ void FillArray(T* addr, const size_t len, const float value) { return; } template -void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream) { +void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream) { FillArray<<>>(addr, input_size, value); return; } template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output, cudaStream_t stream) { - Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, - d1, d2, d3, d4, input, output); +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, + const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output, + cudaStream_t stream) { + Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4, + input, output); } template -void CalSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, const std::vector begin, - const std::vector size, T* output, cudaStream_t cuda_stream) { +void CalSliceGrad(const size_t input_size, const T *dy, const std::vector in_shape, const std::vector begin, + const std::vector size, T *output, cudaStream_t cuda_stream) { int block = in_shape[1] * in_shape[2] * in_shape[3]; int map = in_shape[2] * in_shape[3]; int w = in_shape[3]; @@ -100,92 +80,120 @@ void CalSliceGrad(const size_t input_size, const T* dy, const std::vector i } } } + template -void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int ended = end[3]; - int p = 0; - int start = 0; - for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0])); i += std::abs(strides[0])) { - for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1])); j += std::abs(strides[1])) { - for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2])); k += std::abs(strides[2])) { - start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + - (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; - StridedSlice<<>>(input, p, start, begin[3], strides[3], - ended, output); - p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); - } - } +__global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int b4, + const int b5, const int b6, const int s0, const int s1, const int s2, + const int s3, const int s4, const int s5, const int s6, const int i0, + const int i1, const int i2, const int i3, const int i4, const int i5, + const int i6, const int o0, const int o1, const int o2, const int o3, + const int o4, const int o5, const int o6, const T *input_addr, T *output_addr) { + int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { + int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0; + int j = pos / (o2 * o3 * o4 * o5 * o6) % o1; + int k = pos / (o3 * o4 * o5 * o6) % o2; + int l = pos / (o4 * o5 * o6) % o3; + int m = pos / (o5 * o6) % o4; + int n = pos / (o6) % o5; + int o = pos % o6; + + int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \ + + (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \ + + (n * s5 + b5) * i6 + (o * s6 + b6); + output_addr[pos] = input_addr[input_idx]; } } + template -void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* dx, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int ended = end[3]; - int p = 0; - int start = 0; - for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0] + 1)); i += std::abs(strides[0])) { - for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1] + 1)); - j += std::abs(strides[1])) { - for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2] + 1)); - k += std::abs(strides[2])) { - start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + - (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; - StridedSliceGrad<<>>(dy, p, start, begin[3], strides[3], - ended, dx); - p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); - } - } +void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const T *input, T *output, + cudaStream_t cuda_stream) { + int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \ + * output_shape[4] * output_shape[5] * output_shape[6]; + StridedSliceKernel<<>>( + begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], + strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6], + input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4], input_shape[5], input_shape[6], + output_shape[0], output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5], + output_shape[6], input, output); +} + +template +__global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int b4, + const int b5, const int b6, const int s0, const int s1, const int s2, + const int s3, const int s4, const int s5, const int s6, const int i0, + const int i1, const int i2, const int i3, const int i4, const int i5, + const int i6, const int o0, const int o1, const int o2, const int o3, + const int o4, const int o5, const int o6, const T *dy, T *dx) { + int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { + int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0; + int j = pos / (o2 * o3 * o4 * o5 * o6) % o1; + int k = pos / (o3 * o4 * o5 * o6) % o2; + int l = pos / (o4 * o5 * o6) % o3; + int m = pos / (o5 * o6) % o4; + int n = pos / (o6) % o5; + int o = pos % o6; + + int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \ + + (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \ + + (n * s5 + b5) * i6 + (o * s6 + b6); + dx[input_idx] = dy[pos]; } + return; +} + +template +void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, + const std::vector &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream) { + int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6]; + StridedSliceGradKernel<<>>( + begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], + strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6], + dx_shape[0], dx_shape[1], dx_shape[2], dx_shape[3], dx_shape[4], dx_shape[5], dx_shape[6], + dy_shape[0], dy_shape[1], dy_shape[2], dy_shape[3], dy_shape[4], dy_shape[5], dy_shape[6], + dy, dx); } -template void FillDeviceArray(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, +template void FillDeviceArray(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, const float *input, float *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, float* output, +template void CalSliceGrad(const size_t input_size, const float *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, float *output, cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const float* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, float* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, float* dx, cudaStream_t cuda_stream); -template void FillDeviceArray(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, +template void FillDeviceArray(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, const half *input, half *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, half* output, +template void CalSliceGrad(const size_t input_size, const half *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, half *output, cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const half* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, half* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, half* dx, cudaStream_t cuda_stream); -template void FillDeviceArray(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, +template void FillDeviceArray(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, + const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, const int *input, int *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, int* output, +template void CalSliceGrad(const size_t input_size, const int *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, int *output, cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const int* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, int* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, int* dx, cudaStream_t cuda_stream); + +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const float *input, + float *output, cudaStream_t cuda_stream); +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const half *input, + half *output, cudaStream_t cuda_stream); +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const int *input, + int *output, cudaStream_t cuda_stream); + +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const float *dy, + float *dx, cudaStream_t cuda_stream); +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const half *dy, + half *dx, cudaStream_t cuda_stream); +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, const int *dy, + int *dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh old mode 100755 new mode 100644 index e04f277c3d..70b013174e --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh @@ -21,23 +21,20 @@ #include #include "runtime/device/gpu/cuda_common.h" - template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output, cudaStream_t stream); +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, + const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output, + cudaStream_t stream); template -void CalSliceGrad(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector size, T* output, cudaStream_t cuda_stream); +void CalSliceGrad(const size_t input_size, const T *input, const std::vector in_shape, + const std::vector begin, const std::vector size, T *output, cudaStream_t cuda_stream); template -void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* output, cudaStream_t cuda_stream); +void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, const T *input, T *output, + cudaStream_t cuda_stream); template -void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* dx, cudaStream_t cuda_stream); +void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, + const std::vector &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream); template -void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream); +void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu new file mode 100755 index 0000000000..e892a3b47d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh" +template +__global__ void Split(const size_t size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const T* input, T** outputs) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int num = pos % all_size_before_axis / all_size_axis; + int block = num / axis_step; + int block_pos = pos / all_size_before_axis * axis_step * all_size_axis + + num % axis_step * all_size_axis + pos % all_size_axis; + outputs[block][block_pos] = input[pos]; + } + return; +} + +template +void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) { + Split<<>>(size, axis_step, all_size_before_axis, + all_size_axis, input, outputs); + return; +} + +template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const float* input, float** outputs, + cudaStream_t cuda_stream); +template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const int* input, int** outputs, + cudaStream_t cuda_stream); +template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const half* input, half** outputs, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh new file mode 100755 index 0000000000..b8abce290d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu new file mode 100644 index 0000000000..6e5ac52903 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu @@ -0,0 +1,162 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" +#include +#include + +int RoundUpPower2(int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__inline__ __device__ void Swap(T *lhs, T *rhs) { + T tmp = lhs[0]; + lhs[0] = rhs[0]; + rhs[0] = tmp; +} + +template +__global__ void TopkKernel(const int outer, const int inner, const int ceil_power2, const T *input, const S *k, + T *output, S *indices, T *data_buff, S *index_buff) { + // default: sort with share memory + extern __shared__ T share_mem[]; + T *data_arr = share_mem; + S *index_arr = reinterpret_cast(data_arr + ceil_power2); + // sort with RAM + if (data_buff != nullptr && index_buff != nullptr) { + data_arr = data_buff + blockIdx.x * ceil_power2; + index_arr = index_buff + blockIdx.x * ceil_power2; + } + + for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) { + data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits::max(); + index_arr[i] = i; + } + __syncthreads(); + + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + if ((tid & i) == 0) { + if (data_arr[tid] > data_arr[tid_comp]) { + Swap(&data_arr[tid], &data_arr[tid_comp]); + Swap(&index_arr[tid], &index_arr[tid_comp]); + } + } else { + if (data_arr[tid] < data_arr[tid_comp]) { + Swap(&data_arr[tid], &data_arr[tid_comp]); + Swap(&index_arr[tid], &index_arr[tid_comp]); + } + } + } + } + __syncthreads(); + } + } + + for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) { + output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1]; + indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1]; + } +} + +template +void TopK(const int &outer, const int &inner, const T *input, const S *k, T *output, S *indices, T *data_buff, + S *index_buff, cudaStream_t stream) { + int ceil_power2 = RoundUpPower2(inner); + int share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0; + int thread = std::min(ceil_power2, GET_THREADS); + TopkKernel<<>>(outer, inner, ceil_power2, input, k, output, indices, data_buff, + index_buff); +} + +template +__global__ void BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input, + S *indices, T *data_buff, S *index_buff) { + // default: sort with share memory + extern __shared__ T share_mem[]; + T *data_arr = share_mem; + S *index_arr = reinterpret_cast(data_arr + ceil_power2); + // sort with RAM + if (data_buff != nullptr && index_buff != nullptr) { + data_arr = data_buff + blockIdx.x * ceil_power2; + index_arr = index_buff + blockIdx.x * ceil_power2; + } + + for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) { + data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits::max(); + index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits::max();; + } + __syncthreads(); + + for (size_t i = 2; i <= ceil_power2; i <<= 1) { + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + if ((tid & i) == 0) { + if (index_arr[tid] > index_arr[tid_comp]) { + Swap(&data_arr[tid], &data_arr[tid_comp]); + Swap(&index_arr[tid], &index_arr[tid_comp]); + } + } else { + if (index_arr[tid] < index_arr[tid_comp]) { + Swap(&data_arr[tid], &data_arr[tid_comp]); + Swap(&index_arr[tid], &index_arr[tid_comp]); + } + } + } + } + __syncthreads(); + } + } + + for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) { + input[blockIdx.x * inner + tid] = data_arr[tid]; + indices[blockIdx.x * inner + tid] = index_arr[tid]; + } +} + +template +void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff, + cudaStream_t stream) { + int ceil_power2 = RoundUpPower2(inner); + size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S)); + if (share_mem > SHARED_MEM_PER_BLOCK) { + share_mem = 0; + } else { + data_buff = nullptr; + index_buff = nullptr; + } + int thread = std::min(ceil_power2, GET_THREADS); + BitonicSortByKeyKernel<<>>(outer, inner, ceil_power2, input, indices, data_buff, + index_buff); +} + +template void TopK(const int &outer, const int &inner, const float *input_addr, const int *k, float *output, + int *indices, float *data_buff, int *index_buff, cudaStream_t stream); +template void BitonicSortByKey(const int &outer, const int &inner, float *input, int *indices, float *data_buff, + int *index_buff, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh new file mode 100644 index 0000000000..014044296a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ + +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void TopK(const int &outer, const int &inner, const T *input_addr, const S *k, T *output, S *indices, T *data_buff, + S *index_buff, cudaStream_t stream); + +template +void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff, + cudaStream_t stream); +int RoundUpPower2(int v); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu index 09b347e3d5..629c4c29dc 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu @@ -103,6 +103,35 @@ __global__ void ZeroslikeKernel(T *output, size_t count) { return; } template +__global__ void AbsKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = abs(input[i]); + } + return; +} +template <> +__global__ void AbsKernel(half *input, half *output, size_t count) { + half zero = 0.0; + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = input[i] < zero ? -input[i] : input[i]; + } + return; +} +template +__global__ void FloorKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = floor(input[i]); + } + return; +} +template <> +__global__ void FloorKernel(half *input, half *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = hfloor(input[i]); + } + return; +} +template void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ExponentialKernel<<>>(input, output, count); return; @@ -147,6 +176,16 @@ void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { ZeroslikeKernel<<>>(output, count); return; } +template +void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + AbsKernel<<>>(input, output, count); + return; +} +template +void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + FloorKernel<<>>(input, output, count); + return; +} template void Exponential(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(float *input, float *output, size_t count, cudaStream_t cuda_stream); @@ -156,6 +195,8 @@ template void Square(float *input, float *output, size_t count, cudaStrea template void Sqrt(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(float *output, size_t count, cudaStream_t cuda_stream); +template void Abs(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Floor(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Exponential(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Negative(half *input, half *output, size_t count, cudaStream_t cuda_stream); @@ -164,3 +205,5 @@ template void Square(half *input, half *output, size_t count, cudaStream_t template void Sqrt(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(half *output, size_t count, cudaStream_t cuda_stream); +template void Abs(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Floor(half *input, half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh index cf8b30866e..4020f93df2 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -34,5 +34,9 @@ template void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); +template +void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh new file mode 100644 index 0000000000..9da273a661 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); } + +inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); } + +inline __device__ half ms_atomic_add(half *address, half val) { + unsigned int *aligned = + reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); + unsigned int old = *aligned; + unsigned int assumed; + unsigned short old_as_us; //NOLINT + do { + assumed = old; + old_as_us = static_cast(reinterpret_cast(address) & 2 ? old >> 16 : old & 0xffff); //NOLINT + half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast(val)); + unsigned short sum_as_us = __half_as_ushort(sum); //NOLINT + unsigned int sum_as_ui = + reinterpret_cast(address) & 2 ? (sum_as_us << 16) | (old & 0xffff) : (old & 0xffff0000) | sum_as_us; + old = atomicCAS(aligned, assumed, sum_as_ui); + } while (assumed != old); + __half_raw raw = {old_as_us}; + return half(raw); +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h index b52f79d6f3..df5a6d36ac 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_DATASET_UTILS_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_DATASET_UTILS_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DATASET_UTILS_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DATASET_UTILS_KERNEL_H_ #include #include "ir/dtype/type.h" @@ -25,4 +25,4 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t); int ElementNums(const std::vector &shape); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_DATASET_UTILS_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DATASET_UTILS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h index 4c179f2173..a6a25096fb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNEL_H_ #include #include @@ -103,4 +103,4 @@ class GpuKernel : public KernelMod { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc index 4a0191abd7..8cd5a3c4ee 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc @@ -19,7 +19,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "runtime/device/kernel_info.h" #include "runtime/device/gpu/cuda_common.h" #include "backend/kernel_compiler/common_utils.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h index 8834fa0f1a..f6ea0f0efb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_ #include #include @@ -88,6 +88,12 @@ class GpuKernelRegister { static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ []() { return new OPCLASS(); }); + +// register of mixed accuracy kernels which use template and maintain three typename +#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \ + static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ + static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \ + #OPNAME, ATTR, []() { return new OPCLASS(); }); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h index 7c9f209a9e..67f4a9b9a9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_KERNEL_CONSTANTS_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_KERNEL_CONSTANTS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_KERNEL_CONSTANTS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_KERNEL_CONSTANTS_H_ #include #include @@ -52,4 +52,4 @@ static std::map kCudaDtypeMap = {{"kNumberTypeFloat } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_KERNEL_CONSTANTS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_KERNEL_CONSTANTS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h index b69bd20216..49b6471e28 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_ #include #include @@ -140,4 +140,4 @@ class AddNGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h index 04a74b3412..f0e689f423 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ASSIGNADD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ASSIGNADD_GPU_KERNEL_H #include #include @@ -92,4 +92,4 @@ class AssignAddGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ASSIGNADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc index 41e7147328..ccccd767a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -51,6 +51,14 @@ MS_REG_GPU_KERNEL_TWO( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + FloorDiv, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + AbsGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) // fp16 MS_REG_GPU_KERNEL_TWO( @@ -85,8 +93,19 @@ MS_REG_GPU_KERNEL_TWO( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + FloorDiv, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + AbsGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) // int32 +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int, bool) MS_REG_GPU_KERNEL_TWO( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int, int) @@ -99,5 +118,14 @@ MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO( Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index aaf827723a..b6ac5a3688 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_GPU_KERNEL_H_ #include #include @@ -96,9 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel { std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); static std::map kBroadcastTypeMap = { - {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, - {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, - {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, + {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, + {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, + {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, + {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, }; auto iter = kBroadcastTypeMap.find(kernel_name); @@ -137,4 +138,4 @@ class BroadcastOpGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h index 6258c5c4e2..831d04aad7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_GPU_KERNEL_H_ #include #include @@ -144,4 +144,4 @@ class BroadcastOpGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc new file mode 100644 index 0000000000..9ef1429568 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CholeskyGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h new file mode 100644 index 0000000000..abbbe049d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h @@ -0,0 +1,254 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H +#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class CholeskyGpuKernel : public GpuKernel { + public: + CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {} + ~CholeskyGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + if (!use_split_matrix) { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = input1_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + } + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + float alpha = 1; + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, + d_array_addr, lda_, d_identity_addr, ldb_, batch_), + "cublas trsm batched Fail"); + } else { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + auto d_batch_input_addr = GetDeviceAddress(workspace, 3); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = d_batch_input_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + } + Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast(stream_ptr)); + MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + float alpha = 1; + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m_, m_, &alpha, + d_array_addr, lda_, d_identity_addr, ldb_, batch_), + "cublas trsm batched Fail"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); + blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + split_dim = GetAttr(kernel_node, "split_dim"); + if (split_dim == 0) { + use_split_matrix = false; + if (in_shape.size() == 2) { + batch_ = 1; + if (in_shape[0] != in_shape[1]) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + } + } else if (in_shape.size() == 3) { + batch_ = SizeToInt(in_shape[0]); + if (in_shape[1] != in_shape[2]) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + } + } else { + MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; + } + + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } else { + if (in_shape.size() != 2) { + MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2."; + } + height = in_shape[0]; + width = in_shape[1]; + if (height != width) { + MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input."; + } + if (SizeToInt(height) <= split_dim) { + use_split_matrix = false; + batch_ = 1; + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } else { + use_split_matrix = true; + int batch = SizeToInt(in_shape[1]) / split_dim; + res_dim = in_shape[1] - batch * split_dim; + if (res_dim == 0) { + batch_ = batch; + } else { + batch_ = batch + 1; + } + m_ = split_dim; + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } + } + return true; + } + + protected: + void InitSizeLists() override { + if (!use_split_matrix) { + size_t unit_size = sizeof(T); + size_t input_size = batch_ * m_ * lda_ * unit_size; + input_size_list_.push_back(input_size); + size_t output_size = batch_ * m_ * lda_ * unit_size; + output_size_list_.push_back(output_size); + size_t workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(int); + workspace_size_list_.push_back(workspace_size); + } else { + size_t unit_size = sizeof(T); + size_t input_size = height * width * unit_size; + input_size_list_.push_back(input_size); + size_t output_size = batch_ * m_ * lda_ * unit_size; + output_size_list_.push_back(output_size); + size_t workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(T *); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * sizeof(int); + workspace_size_list_.push_back(workspace_size); + workspace_size = batch_ * m_ * lda_ * unit_size; + workspace_size_list_.push_back(workspace_size); + } + } + + private: + size_t batch_; + size_t m_; + size_t lda_; + size_t ldb_; + int res_dim; + int split_dim; + bool is_null_input_; + bool use_split_matrix; + size_t height; + size_t width; + cusolverDnHandle_t handle_; + cublasHandle_t blas_handle_; + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + std::vector h_array; + std::vector h_identity; + std::vector h_identity_data; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc new file mode 100644 index 0000000000..deb5e39ff7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CumSumGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h new file mode 100644 index 0000000000..92e9232416 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class CumSumGpuKernel : public GpuKernel { + public: + CumSumGpuKernel() : axis_(0), input_size_0_(0), stride_(0), stride2_(0) {} + ~CumSumGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + CumSum(input_addr, output_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumSumGpuKernel needs 1."; + return false; + } + input_size_0_ = sizeof(T); + shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + axis_ = GetAttr(kernel_node, "axis"); + int input_dim_length = SizeToInt(shape_.size()); + if (axis_ >= input_dim_length) { + MS_LOG(EXCEPTION) << "Axis out of bounds."; + } + while (axis_ < 0) { + axis_ += input_dim_length; + } + for (size_t i = 0; i < shape_.size(); i++) { + input_size_0_ *= shape_[i]; + } + Reshape(); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_0_); + output_size_list_.push_back(input_size_0_); + } + + private: + void Reshape() { + dims_[0] = 1; + dims_[1] = shape_[IntToSize(axis_)]; + dims_[2] = 1; + for (size_t i = 0; i < IntToSize(axis_); i++) { + dims_[0] *= shape_[i]; + } + for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) { + dims_[2] *= shape_[i]; + } + stride_ = dims_[1] * dims_[2]; + stride2_ = dims_[2]; + return; + } + int axis_; + size_t input_size_0_; + size_t stride_; + size_t stride2_; + size_t dims_[3] = {}; + std::vector shape_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h index be74f2e9dc..1d3cfcc8ce 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FLOAT_STATUS_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FLOAT_STATUS_GPU_KERNEL_H #include #include @@ -127,4 +127,4 @@ class FloatStatusGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc new file mode 100644 index 0000000000..5def6b3af4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(NMSWithMask, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeBool), + NMSWithMaskGpuFwdKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h new file mode 100644 index 0000000000..a5e0464cb9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class NMSWithMaskGpuFwdKernel : public GpuKernel { + public: + NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {} + ~NMSWithMaskGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *data_buff = GetDeviceAddress(workspace, 0); // sort buffer + int *index_buff = GetDeviceAddress(workspace, 1); + T *area = GetDeviceAddress(workspace, 2); // store area values for all boxes + T *output = GetDeviceAddress(outputs, 0); + int *sel_idx = GetDeviceAddress(outputs, 1); + bool *sel_boxes = GetDeviceAddress(outputs, 2); + + BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_, + reinterpret_cast(stream_ptr)); + CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast(stream_ptr)); + CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_, + reinterpret_cast(stream_ptr)); + CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + iou_value_ = GetAttr(kernel_node, "iou_threshold"); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but NMSWithMask needs 1 input."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 3) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but NMSWithMask needs 3 output."; + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (CHECK_NULL_INPUT(input_shape)) { + MS_LOG(WARNING) << "NMSWithMask input is null"; + InitSizeLists(); + return true; + } + + num_input_ = input_shape[0]; // Get N value in [N,5] data + + input_size_ = num_input_ * sizeof(T) * box_size_; // 5 values per bbox + output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool)); + workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int)); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + // N sized input/output data + input_size_list_.push_back(num_input_ * sizeof(T) * box_size_); + output_size_list_.push_back(num_input_ * sizeof(T) * box_size_); + output_size_list_.push_back(num_input_ * sizeof(int)); + output_size_list_.push_back(num_input_ * sizeof(bool)); + + // N sized workspace arrs + workspace_size_list_.push_back(num_input_ * sizeof(T)); + workspace_size_list_.push_back(num_input_ * sizeof(int)); + workspace_size_list_.push_back(num_input_ * sizeof(T)); + } + + private: + int num_input_; + float iou_value_; + static const int box_size_ = 5; // pre_defined box width + // int box_size__ = 5; // current size of bboxes + // default values + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc index c72c271c52..8dfd4eef08 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc @@ -20,5 +20,12 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), RandomOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(UniformReal, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RandomOpGpuKernel, float) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h index 785ac02ee5..98a421c922 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOMOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOMOP_GPU_KERNEL_H_ #include #include @@ -28,17 +28,22 @@ namespace mindspore { namespace kernel { -enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 }; +enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 }; -const std::map kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}}; +const std::map kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}, + {"UniformReal", RANDOM_OP_UNIFORM_REAL}}; template class RandomOpGpuKernel : public GpuKernel { public: RandomOpGpuKernel() : random_op_type_(RANDOM_OP_INVALID_TYPE), - input_size_0_(0), + input_size_0_(sizeof(int)), + input_size_1_(sizeof(T)), + input_size_2_(sizeof(T)), output_size_(sizeof(T)), - workspace_size_(sizeof(curandState)) {} + workspace_size_(sizeof(curandState)), + seed_(0), + seed2_(0) {} ~RandomOpGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -57,12 +62,21 @@ class RandomOpGpuKernel : public GpuKernel { reinterpret_cast(stream_ptr)); break; } + case RANDOM_OP_UNIFORM_REAL: { + T *input_addr_1 = GetDeviceAddress(inputs, 1); + T *input_addr_2 = GetDeviceAddress(inputs, 2); + UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2, + inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); + break; + } default: { MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; } } return true; } + bool Init(const CNodePtr &kernel_node) override { std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); auto iter = kRandomOpTypeMap.find(kernel_name); @@ -72,10 +86,14 @@ class RandomOpGpuKernel : public GpuKernel { random_op_type_ = iter->second; } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { + if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) { MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; return false; } + if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs."; + return false; + } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output."; @@ -86,13 +104,25 @@ class RandomOpGpuKernel : public GpuKernel { input_size_0_ += input_shape_0[i]; } input_size_0_ *= sizeof(int); + if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) { + auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < input_shape_1.size(); i++) { + input_size_1_ *= input_shape_1[i]; + } + auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + for (size_t i = 0; i < input_shape_2.size(); i++) { + input_size_2_ *= input_shape_2[i]; + } + } auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); for (size_t i = 0; i < output_shape.size(); i++) { output_size_ *= output_shape[i]; workspace_size_ *= output_shape[i]; } seed_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); - seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); + if (random_op_type_ == RANDOM_OP_NORMAL) { + seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); + } InitSizeLists(); return true; } @@ -100,6 +130,10 @@ class RandomOpGpuKernel : public GpuKernel { protected: void InitSizeLists() override { input_size_list_.push_back(input_size_0_); + if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) { + input_size_list_.push_back(input_size_1_); + input_size_list_.push_back(input_size_2_); + } output_size_list_.push_back(output_size_); workspace_size_list_.push_back(workspace_size_); } @@ -107,6 +141,8 @@ class RandomOpGpuKernel : public GpuKernel { private: RandomOptype random_op_type_; size_t input_size_0_; + size_t input_size_1_; + size_t input_size_2_; size_t output_size_; size_t workspace_size_; int seed_; @@ -118,4 +154,4 @@ class RandomOpGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOMOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc index ae8e7bbd0b..d646ef417c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -46,5 +46,13 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h index 26993bc3bd..7e3f2c862e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GPU_KERNEL_H_ #include #include @@ -36,6 +36,8 @@ enum UnaryOptype { UNARY_OP_SQUARE, UNARY_OP_SQRT, UNARY_OP_RSQRT, + UNARY_OP_ABS, + UNARY_OP_FLOOR, UNARY_OP_INVALID_TYPE = 255 }; static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, @@ -45,7 +47,9 @@ static const std::map kUnaryOpTypeMap = {{"Exp", UNARY {"ZerosLike", UNARY_OP_ZEROSLIKE}, {"Square", UNARY_OP_SQUARE}, {"Sqrt", UNARY_OP_SQRT}, - {"Rsqrt", UNARY_OP_RSQRT}}; + {"Rsqrt", UNARY_OP_RSQRT}, + {"Abs", UNARY_OP_ABS}, + {"Floor", UNARY_OP_FLOOR}}; template class UnaryOpGpuKernel : public GpuKernel { public: @@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel { Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); return true; } + case UNARY_OP_ABS: { + Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_FLOOR: { + Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } default: { MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; } @@ -158,4 +170,4 @@ class UnaryOpGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.cc new file mode 100644 index 0000000000..0746a70262 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/update_thor_gradient.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(UpdateThorGradient, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + UpdateThorGradientGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h new file mode 100644 index 0000000000..92d777d8d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h @@ -0,0 +1,241 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_UPDATE_THOR_GRADIENT_GPU_KERNEL_H +#define MINDSPORE_UPDATE_THOR_GRADIENT_GPU_KERNEL_H +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +struct GradientSize { + size_t batch_h; + size_t batch_w; + size_t h; + size_t w; + size_t ori_h; + size_t ori_w; + size_t pad_h; + size_t pad_w; + bool need_convert; + cudaDataType_t dtype; +}; +template +class UpdateThorGradientGpuKernel : public GpuKernel { + public: + UpdateThorGradientGpuKernel() : split_dim(128) {} + ~UpdateThorGradientGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + auto input1_addr = GetDeviceAddress(inputs, 0); + auto input2_addr = GetDeviceAddress(inputs, 1); + auto input3_addr = GetDeviceAddress(inputs, 2); + auto workspace1_addr = GetDeviceAddress(workspace, 0); + T *workspace2_addr = nullptr; + T *workspace3_addr = nullptr; + if (gradient_size.need_convert) { + workspace2_addr = GetDeviceAddress(workspace, 1); + workspace3_addr = GetDeviceAddress(workspace, 2); + } + T *workspace4_addr = nullptr; + auto output_addr = GetDeviceAddress(outputs, 0); + if (gradient_size.pad_h != 0 || gradient_size.pad_w != 0) { + workspace4_addr = GetDeviceAddress(workspace, 3); + const size_t size = (gradient_size.ori_h + gradient_size.pad_h) * (gradient_size.ori_w + gradient_size.pad_w); + CalPad(size, input2_addr, 1, 1, gradient_size.ori_h, gradient_size.ori_w, + gradient_size.ori_h + gradient_size.pad_h, gradient_size.ori_w + gradient_size.pad_w, 0, 0, 0.0, + workspace4_addr, reinterpret_cast(stream_ptr)); + cudaMemsetAsync(workspace1_addr, 0, + gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * sizeof(T), + reinterpret_cast(stream_ptr)); + input2_addr = workspace4_addr; + } + const float alpha = 1; + const float beta = 0; + const int lda = SizeToInt(gradient_size.h); + const int ldb = SizeToInt(gradient_size.ori_w + gradient_size.pad_w); + const int ldc = SizeToInt(gradient_size.ori_w + gradient_size.pad_w); + + auto stride_a = SizeToInt(gradient_size.h * gradient_size.h); + auto stride_b = SizeToInt(gradient_size.h * (gradient_size.ori_w + gradient_size.pad_w)); + auto stride_c = SizeToInt(gradient_size.h * (gradient_size.ori_w + gradient_size.pad_w)); + + try { + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(gradient_size.ori_w), + SizeToInt(gradient_size.h), SizeToInt(gradient_size.h), &alpha, input2_addr, + gradient_size.dtype, ldb, stride_b, input1_addr, gradient_size.dtype, lda, stride_a, + &beta, workspace1_addr, gradient_size.dtype, ldc, stride_c, gradient_size.batch_h, + CUDA_R_32F, algo_), + "cublasSgemm Call Fail"); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << "when invoke cubals cublasGemmStridedBatchedEx"; + } + + auto r_input_addr = workspace1_addr; + if (gradient_size.need_convert) { + size_t size = gradient_size.batch_w * gradient_size.batch_h * gradient_size.w * gradient_size.h; + ConvertGradient(size, gradient_size.h, gradient_size.w, gradient_size.batch_w, + gradient_size.batch_w * gradient_size.w, workspace1_addr, workspace2_addr, + reinterpret_cast(stream_ptr)); + r_input_addr = workspace2_addr; + } + + const int lda_r = SizeToInt(gradient_size.w); + const int ldb_r = SizeToInt(gradient_size.w); + const int ldc_r = SizeToInt(gradient_size.w); + + stride_a = SizeToInt(gradient_size.h * gradient_size.w); + stride_b = SizeToInt(gradient_size.w * gradient_size.w); + stride_c = SizeToInt(gradient_size.h * gradient_size.w); + auto r_output_addr = output_addr; + if (gradient_size.need_convert) { + r_output_addr = workspace3_addr; + } + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(gradient_size.w), + SizeToInt(gradient_size.h), SizeToInt(gradient_size.w), &alpha, input3_addr, + gradient_size.dtype, ldb_r, stride_b, r_input_addr, gradient_size.dtype, lda_r, + stride_a, &beta, r_output_addr, gradient_size.dtype, ldc_r, stride_c, + gradient_size.batch_h * gradient_size.batch_w, CUDA_R_32F, algo_), + "cublasSgemm Call Fail"); + if (gradient_size.need_convert) { + size_t size = gradient_size.batch_w * gradient_size.batch_h * gradient_size.w * gradient_size.h; + if (gradient_size.pad_h == 0 && gradient_size.pad_w == 0) { + ConvertGradientBack(size, gradient_size.h, gradient_size.w, gradient_size.batch_w, + gradient_size.batch_w * gradient_size.w, r_output_addr, output_addr, + reinterpret_cast(stream_ptr)); + } else { + ConvertGradientBack(size, gradient_size.h, gradient_size.w, gradient_size.ori_h, gradient_size.ori_w, + gradient_size.batch_w, gradient_size.ori_w, r_output_addr, output_addr, + reinterpret_cast(stream_ptr)); + } + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + SetProperty(kernel_node); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t unit_size = sizeof(T); + size_t input_size_ = gradient_size.h * gradient_size.h * gradient_size.batch_h * unit_size; + input_size_list_.push_back(input_size_); + + input_size_ = gradient_size.ori_h * gradient_size.ori_w * unit_size; + input_size_list_.push_back(input_size_); + + input_size_ = gradient_size.w * gradient_size.w * gradient_size.batch_w * unit_size; + input_size_list_.push_back(input_size_); + + size_t output_size = gradient_size.ori_h * gradient_size.ori_w * unit_size; + output_size_list_.push_back(output_size); + + size_t workspace_size_ = 0; + workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size; + workspace_size_list_.push_back(workspace_size_); + + if (gradient_size.need_convert) { + workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size; + workspace_size_list_.push_back(workspace_size_); + workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size; + workspace_size_list_.push_back(workspace_size_); + } + + if (gradient_size.pad_h != 0 || gradient_size.pad_w != 0) { + workspace_size_ = + (gradient_size.ori_w + gradient_size.pad_w) * (gradient_size.ori_h + gradient_size.pad_h) * unit_size; + workspace_size_list_.push_back(workspace_size_); + } + } + + private: + void SetProperty(const CNodePtr &kernel_node) { + auto matrix_a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto matrix_g_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + + split_dim = size_t(GetAttr(kernel_node, "split_dim")); + + gradient_size.batch_h = gradient_shape[0] / split_dim; + gradient_size.batch_w = gradient_shape[1] / split_dim; + if (gradient_size.batch_h * split_dim != gradient_shape[0]) { + gradient_size.batch_h += 1; + if (gradient_shape[0] > split_dim) { + gradient_size.h = split_dim; + gradient_size.pad_h = gradient_size.batch_h * split_dim - gradient_shape[0]; + } else { + gradient_size.h = gradient_shape[0]; + gradient_size.pad_h = 0; + } + } else { + gradient_size.h = split_dim; + gradient_size.pad_h = 0; + } + + if (gradient_size.batch_w * split_dim != gradient_shape[1]) { + gradient_size.batch_w += 1; + if (gradient_shape[1] > split_dim) { + gradient_size.w = split_dim; + gradient_size.pad_w = gradient_size.batch_w * split_dim - gradient_shape[1]; + } else { + gradient_size.w = gradient_shape[1]; + gradient_size.pad_w = 0; + } + } else { + gradient_size.w = split_dim; + gradient_size.pad_w = 0; + } + + if (gradient_size.batch_w * gradient_size.w <= split_dim) { + gradient_size.need_convert = false; + } else { + gradient_size.need_convert = true; + } + + gradient_size.ori_w = gradient_shape[1]; + gradient_size.ori_h = gradient_shape[0]; + gradient_size.dtype = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); + } + + size_t split_dim; + struct GradientSize gradient_size; + cublasHandle_t handle_; + cublasGemmAlgo_t algo_ = CUBLAS_GEMM_DEFAULT; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc index c6e3c4c043..8374914dd5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc @@ -24,17 +24,28 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(AllReduce, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + NcclGpuKernel, int) + MS_REG_GPU_KERNEL_ONE( AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), NcclGpuKernel, float) MS_REG_GPU_KERNEL_ONE( AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(AllGather, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + NcclGpuKernel, int) + MS_REG_GPU_KERNEL_ONE( ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), NcclGpuKernel, float) MS_REG_GPU_KERNEL_ONE( ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ReduceScatter, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + NcclGpuKernel, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h index 9701738bfc..18caa149f6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_GPU_KERNEL_H_ #include #include @@ -185,4 +185,4 @@ class NcclGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h index d651da75e0..b434ddadd5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_ #include #include @@ -139,4 +139,4 @@ class ActivationGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h index ffdb618098..2d7b2012f3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_ #include #include @@ -143,4 +143,4 @@ class ActivationGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h index e2fc87ed51..8ac9839dcf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -139,4 +139,4 @@ class AdamGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h index 3e15b818be..dd44682e83 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ #include #include @@ -155,4 +155,4 @@ class BiasAddGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.cc new file mode 100644 index 0000000000..4f6fd5a2c1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropy, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BinaryCrossEntropyGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h new file mode 100644 index 0000000000..8ccbc22d68 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BinaryCrossEntropyGpuKernel : public GpuKernel { + public: + BinaryCrossEntropyGpuKernel() : input_size_(1), reduction_(1) {} + ~BinaryCrossEntropyGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_x = GetDeviceAddress(inputs, 0); + T *input_y = GetDeviceAddress(inputs, 1); + T *weight = GetDeviceAddress(inputs, 2); + T *loss = GetDeviceAddress(outputs, 0); + + BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + + string reduction = GetAttr(kernel_node, "reduction"); + if (reduction == "none") { + reduction_ = 0; + } else if (reduction == "sum") { + reduction_ = 2; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + if (reduction_ == 0) { + output_size_list_.push_back(input_size_ * sizeof(T)); + } else { + output_size_list_.push_back(sizeof(T)); + } + } + + private: + size_t input_size_; + int reduction_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BINARY_CROSS_ENTROPY_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.cc new file mode 100644 index 0000000000..f55644173c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropyGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BinaryCrossEntropyGradGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h new file mode 100644 index 0000000000..326d1c82c2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BinaryCrossEntropyGradGpuKernel : public GpuKernel { + public: + BinaryCrossEntropyGradGpuKernel() : input_size_(1), reduction_(1) {} + ~BinaryCrossEntropyGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_x = GetDeviceAddress(inputs, 0); + T *input_y = GetDeviceAddress(inputs, 1); + T *dloss = GetDeviceAddress(inputs, 2); + T *weight = GetDeviceAddress(inputs, 3); + T *dx = GetDeviceAddress(outputs, 0); + BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + + string reduction = GetAttr(kernel_node, "reduction"); + if (reduction == "none") { + reduction_ = 0; + } else if (reduction == "sum") { + reduction_ = 2; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + if (reduction_ == 0) { + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(T)); + } else { + input_size_list_.push_back(sizeof(T)); + output_size_list_.push_back(sizeof(T)); + } + } + + private: + size_t input_size_; + int reduction_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h index 6072614e22..c5e8a26801 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_ #include #include @@ -109,12 +109,14 @@ class Conv2dGpuFwdKernel : public GpuKernel { Set4DDesc(in_shape, filter_shape, output_shape); group_ = GetAttr(kernel_node, "group"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; + auto pad_list = GetAttr>(kernel_node, "pad_list"); + pad_height_ = pad_list[0]; + pad_width_ = pad_list[2]; + auto symmetry_pad = (pad_height_ == pad_list[1]) && (pad_width_ == pad_list[3]); pad_mode_ = GetAttr(kernel_node, "pad_mode"); SetStrideAndDilation(kernel_node); cudnnTensorDescriptor_t input_descriptor_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) { SetPad(in_shape, kernel_node); input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_; } else { @@ -317,4 +319,4 @@ class Conv2dGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h index 638da4a99f..ac4d127e43 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ #include #include @@ -113,12 +113,14 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { group_ = GetAttr(kernel_node, "group"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; + auto pad_list = GetAttr>(kernel_node, "pad_list"); + pad_height_ = pad_list[0]; + pad_width_ = pad_list[2]; + auto symmetry_pad = (pad_height_ == pad_list[1]) && (pad_width_ == pad_list[3]); pad_mode_ = GetAttr(kernel_node, "pad_mode"); SetStrideAndDilation(kernel_node); cudnnTensorDescriptor_t x_desc_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) { SetPad(in_shape, kernel_node); x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_; } else { @@ -317,4 +319,4 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h index a9a1e5c0cc..e40bd6898f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ #include #include @@ -114,12 +114,14 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { group_ = GetAttr(kernel_node, "group"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; + auto pad_list = GetAttr>(kernel_node, "pad_list"); + pad_height_ = pad_list[0]; + pad_width_ = pad_list[2]; + auto symmetry_pad = (pad_height_ == pad_list[1]) && (pad_width_ == pad_list[3]); pad_mode_ = GetAttr(kernel_node, "pad_mode"); SetStrideAndDilation(kernel_node); cudnnTensorDescriptor_t dx_desc_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) { SetPad(input_shape, kernel_node); dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_; } else { @@ -312,4 +314,4 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h index 8b02354516..cf72a32b65 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_ #include #include @@ -51,10 +51,12 @@ class CtcLossGpuKernel : public GpuKernel { float *grads = GetDeviceAddress(outputs, 1); // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires - void *labels_host = nullptr; + int *labels_host = nullptr; + int *no_blank_labels_host = nullptr; void *input_lengths_host = nullptr; void *label_lengths_host = nullptr; CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); cudaStream_t stream = reinterpret_cast(stream_ptr); @@ -68,12 +70,21 @@ class CtcLossGpuKernel : public GpuKernel { "cudaMemcpyAsync failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + + size_t j = 0; + for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) { + if (labels_host[i] != 0) { + no_blank_labels_host[j] = labels_host[i]; + j++; + } + } + size_t workspace_size = 0; CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(labels_host), - reinterpret_cast(label_lengths_host), - reinterpret_cast(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, - ctcloss_desc_, &workspace_size), + cudnnGetCTCLossWorkspaceSize( + cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(no_blank_labels_host), + reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), + CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size), "cudnnGetCTCLossWorkspaceSize failed."); void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); if (workspace == nullptr) { @@ -81,7 +92,7 @@ class CtcLossGpuKernel : public GpuKernel { } CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(labels_host), + cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(no_blank_labels_host), reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), costs, probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), "cudnnCtcLoss failed."); @@ -91,6 +102,7 @@ class CtcLossGpuKernel : public GpuKernel { CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed."); return true; } bool Init(const CNodePtr &kernel_node) override { @@ -163,4 +175,4 @@ class CtcLossGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h index 2104d7af35..632caef9ec 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -115,4 +115,4 @@ class DropoutGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h index a3a7250c9b..626692516c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GRAD_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -97,4 +97,4 @@ class DropoutGradGpuBwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_DROPOUT_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h index a140579a3c..baf6e35f2e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FLATTEN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FLATTEN_GPU_KERNEL_H_ #include #include @@ -75,4 +75,4 @@ class FlattenGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FLATTEN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h index b21327bc3b..700fe5884a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ #include #include @@ -86,4 +86,4 @@ class FlattenGardGpuBkwKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h index ea08741dba..1fbe2aa332 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FTRL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FTRL_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -127,4 +127,4 @@ class FtrlGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h index c4fd31a737..10e90aef8d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -100,4 +100,4 @@ class FusedAdamWeightDecayGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc index 2ce39b63a0..ddd9c2f8d0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc @@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, KernelAttr() .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), FusedBatchNormGpuKernel, half) MS_REG_GPU_KERNEL_ONE(BatchNorm, KernelAttr() @@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, MS_REG_GPU_KERNEL_ONE(BatchNorm, KernelAttr() .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), FusedBatchNormGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h index 774428dc40..b029929b02 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel { return true; } auto x = GetDeviceAddress(inputs, 0); - auto scale = GetDeviceAddress(inputs, 1); - auto bias = GetDeviceAddress(inputs, 2); - auto runing_mean = GetDeviceAddress(inputs, 3); - auto runnig_variance = GetDeviceAddress(inputs, 4); + auto scale = GetDeviceAddress(inputs, 1); + auto bias = GetDeviceAddress(inputs, 2); + auto runing_mean = GetDeviceAddress(inputs, 3); + auto runnig_variance = GetDeviceAddress(inputs, 4); auto y = GetDeviceAddress(outputs, 0); const float alpha = 1; const float beta = 0; if (is_train_) { - auto save_mean = GetDeviceAddress(outputs, 3); - auto save_variance = GetDeviceAddress(outputs, 4); + auto save_mean = GetDeviceAddress(outputs, 3); + auto save_variance = GetDeviceAddress(outputs, 4); CHECK_CUDNN_RET_WITH_EXCEPT( cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, @@ -187,4 +187,4 @@ class FusedBatchNormGpuKernel : public GpuKernel { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc index 546e034f6b..7cd993d0f0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc @@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, KernelAttr() .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), FusedBatchNormGradGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h index a2d0d741b1..b22cc2f03f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { } auto dy = GetDeviceAddress(inputs, 0); auto x = GetDeviceAddress(inputs, 1); - auto scale = GetDeviceAddress(inputs, 2); - auto save_mean = GetDeviceAddress(inputs, 3); - auto save_variance = GetDeviceAddress(inputs, 4); + auto scale = GetDeviceAddress(inputs, 2); + auto save_mean = GetDeviceAddress(inputs, 3); + auto save_variance = GetDeviceAddress(inputs, 4); auto dx = GetDeviceAddress(outputs, 0); - auto bn_scale = GetDeviceAddress(outputs, 1); - auto bn_bias = GetDeviceAddress(outputs, 2); + auto bn_scale = GetDeviceAddress(outputs, 1); + auto bn_bias = GetDeviceAddress(outputs, 2); const float alpha_data_diff = 1; const float beta_data_diff = 0; @@ -175,4 +175,4 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h index 823da1fe9f..27afbda57e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GELU_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GELU_GRAD_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -72,4 +72,4 @@ class GeLUGpuGradKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h index 76d3861d55..41cf065a83 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GELU_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GELU_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -69,4 +69,4 @@ class GeluGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc new file mode 100644 index 0000000000..b7a71308ce --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Im2ColGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + Im2ColGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h new file mode 100644 index 0000000000..53a711775a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h @@ -0,0 +1,269 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class Im2ColGpuFwdKernel : public GpuKernel { + public: + Im2ColGpuFwdKernel() + : cudnn_handle_(nullptr), + input_desc_(nullptr), + output_desc_(nullptr), + filter_desc_(nullptr), + conv_desc_(nullptr), + padded_desc_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + is_null_input_(false), + input_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~Im2ColGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded_addr = GetDeviceAddress(workspace, 0); + CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnIm2Col(cudnn_handle_, padded_desc_, padded_addr, filter_desc_, conv_desc_, output_addr), + "cudnnIm2ColForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnIm2Col(cudnn_handle_, input_desc_, input_addr, filter_desc_, conv_desc_, output_addr), + "cudnnIm2ColForward failed"); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto filter_shape = GetAttr>(kernel_node, "kernel_size"); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "cudnnIm2ColForward input is null."; + InitSizeLists(); + return true; + } + Set4DDesc(in_shape, filter_shape, output_shape); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, 1), "cudnnSetConvGroupCount failed"); + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(in_shape, kernel_node); + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + workspace_size_list_.push_back(padded_size_); + } + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but Im2Col needs 1 inputs."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but Im2Col needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_height_ = SizeToInt(in_shape[2]); + old_width_ = SizeToInt(in_shape[3]); + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + + // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, + old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + + void Set4DDesc(const std::vector &in_shape, const std::vector &filter_shape, + const std::vector &output_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), + SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 1, + SizeToInt(in_shape[1]), filter_shape[0], filter_shape[1]), + "cudnnSetFilter4dDescriptor failed"); + + auto out_H = output_shape[0] * output_shape[1] * output_shape[2]; + auto out_W = output_shape[3] * output_shape[4] * output_shape[5]; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + SizeToInt(out_H), SizeToInt(out_W), 1, 1), + "cudnnSetTensor4dDescriptor failed"); + } + + void SetStrideAndDilation(const CNodePtr &kernel_node) { + stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); + dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); + if (stride_.size() != 4) { + MS_LOG(EXCEPTION) << "Im2Col's stride must be 4d!"; + } + if (stride_[0] != 1 || stride_[1] != 1) { + MS_LOG(EXCEPTION) << "Im2Col's stride only support 1 in N axis and C axis!"; + } + if (dilation_.size() != 4) { + MS_LOG(EXCEPTION) << "Im2Col's dilation must be 4d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "Im2Col's dilation only support 1 in N axis and C axis!"; + } + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnConvolutionFwdAlgo_t conv_algorithm_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t padded_desc_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + const float pad_value_ = 0.0; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc new file mode 100644 index 0000000000..20cb060ea0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + KLDivLoss, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KLDivLossGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h new file mode 100644 index 0000000000..43aced9494 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class KLDivLossGpuKernel : public GpuKernel { + public: + KLDivLossGpuKernel() : input_size_(1), reduction_(1) {} + ~KLDivLossGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_x = GetDeviceAddress(inputs, 0); + T *input_y = GetDeviceAddress(inputs, 1); + T *loss = GetDeviceAddress(outputs, 0); + + KLDivLoss(input_size_, reduction_, input_x, input_y, loss, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + + string reduction = GetAttr(kernel_node, "reduction"); + if (reduction == "none") { + reduction_ = 0; + } else if (reduction == "sum") { + reduction_ = 2; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + if (reduction_ == 0) { + output_size_list_.push_back(input_size_ * sizeof(T)); + } else { + output_size_list_.push_back(sizeof(T)); + } + } + + private: + size_t input_size_; + int reduction_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc new file mode 100644 index 0000000000..83371f580b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(KLDivLossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KLDivLossGradGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h new file mode 100644 index 0000000000..37a0c76a8c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class KLDivLossGradGpuKernel : public GpuKernel { + public: + KLDivLossGradGpuKernel() : input_size_(1), reduction_(1) {} + ~KLDivLossGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_x = GetDeviceAddress(inputs, 0); + T *input_y = GetDeviceAddress(inputs, 1); + T *dloss = GetDeviceAddress(inputs, 2); + T *dx = GetDeviceAddress(outputs, 0); + T *dy = GetDeviceAddress(outputs, 1); + KLDivLossGrad(input_size_, reduction_, input_x, input_y, dloss, dx, dy, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + string reduction = GetAttr(kernel_node, "reduction"); + if (reduction == "none") { + reduction_ = 0; + } else if (reduction == "sum") { + reduction_ = 2; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(T)); + if (reduction_ == 0) { + input_size_list_.push_back(input_size_ * sizeof(T)); + } else { + input_size_list_.push_back(sizeof(T)); + } + } + + private: + size_t input_size_; + int reduction_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_KL_DIV_LOSS_GRAD_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h index 74669e03de..7f65573a96 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -100,4 +100,4 @@ class LayerNormGpuKernel : public GpuKernel { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h index 93967adad3..8e4b12af2d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -104,4 +104,4 @@ class LayerNormGradGpuKernel : public GpuKernel { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h index ad3e588f00..2db0b15200 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GPU_KERNEL_H_ #include #include @@ -244,4 +244,4 @@ class LstmGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h index 6d6bed5555..29d1de9a89 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ #include #include @@ -281,4 +281,4 @@ class LstmGradDataGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h index 445d2ce199..8a73de52a3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ #include #include @@ -228,4 +228,4 @@ class LstmGradWeightGpuKernel : public GpuKernel { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc new file mode 100644 index 0000000000..1866c83466 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + MaxPoolWithArgmax, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + MaxPoolWithArgmaxGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + MaxPoolWithArgmax, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + MaxPoolWithArgmaxGpuFwdKernel, half, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h new file mode 100644 index 0000000000..aef408c403 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class MaxPoolWithArgmaxGpuFwdKernel : public GpuKernel { + public: + MaxPoolWithArgmaxGpuFwdKernel() + : n_(0), + c_(0), + input_height_(0), + input_width_(0), + window_height_(0), + window_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + stride_height_(0), + stride_width_(0), + output_height_(0), + output_width_(0), + input_size_(0), + output_size_(0) {} + ~MaxPoolWithArgmaxGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + S *index_addr = GetDeviceAddress(outputs, 1); + CalMaxPoolWithArgmax(input_addr, n_, c_, input_height_, input_width_, window_height_, window_width_, stride_height_, + stride_width_, pad_top_, pad_left_, output_height_, output_width_, output_addr, index_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but MaxPoolWithArgmax needs 1 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but MaxPoolWithArgmax needs 2 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (auto x : input_shape) { + input_size_ *= x; + } + output_size_ = sizeof(T); + for (auto x : output_shape) { + output_size_ *= x; + } + n_ = SizeToInt(input_shape[0]); + c_ = SizeToInt(input_shape[1]); + input_height_ = SizeToInt(input_shape[2]); + input_width_ = SizeToInt(input_shape[3]); + output_height_ = SizeToInt(output_shape[2]); + output_width_ = SizeToInt(output_shape[3]); + auto window = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); + window_height_ = window[1]; + window_width_ = window[2]; + auto stride = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); + stride_height_ = stride[1]; + stride_width_ = stride[2]; + pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); + pad_top_ = 0; + pad_left_ = 0; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_ / sizeof(T) * sizeof(S)); + } + + private: + void SetPad() { + pad_height_ = std::max( + 0, (((input_height_ / stride_height_) * stride_height_ == input_height_ ? (input_height_ / stride_height_) + : (input_height_ / stride_height_) + 1) - + 1) * + stride_height_ + + window_height_ - input_height_); + pad_width_ = std::max( + 0, (((input_width_ / stride_width_) * stride_width_ == input_width_ ? (input_width_ / stride_width_) + : (input_width_ / stride_width_) + 1) - + 1) * + stride_width_ + + window_width_ - input_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + } + + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int n_; + int c_; + int input_height_; + int input_width_; + int window_height_; + int window_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int stride_height_; + int stride_width_; + int output_height_; + int output_width_; + + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc new file mode 100644 index 0000000000..954a5cfbf9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(MaxPoolGradWithArgmax, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + MaxPoolWithArgmaxGradGpuKernel, float, int) +MS_REG_GPU_KERNEL_TWO(MaxPoolGradWithArgmax, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + MaxPoolWithArgmaxGradGpuKernel, half, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h new file mode 100644 index 0000000000..9d90e2d9f4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h @@ -0,0 +1,168 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel { + public: + MaxPoolWithArgmaxGradGpuKernel() + : n_(0), + c_(0), + x_height_(0), + x_width_(0), + dy_height_(0), + dy_width_(0), + x_size_(0), + dy_size_(0), + index_size_(0), + dx_size_(0) {} + ~MaxPoolWithArgmaxGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + T *x_addr = GetDeviceAddress(inputs, 0); + T *dy_addr = GetDeviceAddress(inputs, 1); + S *index_addr = GetDeviceAddress(inputs, 2); + T *dx_addr = GetDeviceAddress(outputs, 0); + CalMaxPoolWithArgmaxGrad(x_addr, dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_, + window_height_, window_width_, stride_height_, stride_width_, pad_top_, pad_left_, dx_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but MaxPoolGradWithArgmax needs 3 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but MaxPoolGradWithArgmax needs 1 output."; + return false; + } + auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto index_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + auto dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + x_size_ = sizeof(T); + for (auto x : x_shape) { + x_size_ *= x; + } + dy_size_ = sizeof(T); + for (auto x : dy_shape) { + dy_size_ *= x; + } + index_size_ = sizeof(S); + for (auto x : index_shape) { + index_size_ *= x; + } + dx_size_ = sizeof(T); + for (auto x : dx_shape) { + dx_size_ *= x; + } + n_ = SizeToInt(x_shape[0]); + c_ = SizeToInt(x_shape[1]); + x_height_ = SizeToInt(x_shape[2]); + x_width_ = SizeToInt(x_shape[3]); + dy_height_ = SizeToInt(dy_shape[2]); + dy_width_ = SizeToInt(dy_shape[3]); + auto window = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); + window_height_ = window[1]; + window_width_ = window[2]; + auto stride = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); + stride_height_ = stride[1]; + stride_width_ = stride[2]; + pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); + pad_top_ = 0; + pad_left_ = 0; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(x_size_); + input_size_list_.push_back(dy_size_); + input_size_list_.push_back(index_size_); + output_size_list_.push_back(dx_size_); + } + + private: + void SetPad() { + pad_height_ = std::max( + 0, (((x_height_ / stride_height_) * stride_height_ == x_height_ ? (x_height_ / stride_height_) + : (x_height_ / stride_height_) + 1) - + 1) * + stride_height_ + + window_height_ - x_height_); + pad_width_ = + std::max(0, (((x_width_ / stride_width_) * stride_width_ == x_width_ ? (x_width_ / stride_width_) + : (x_width_ / stride_width_) + 1) - + 1) * + stride_width_ + + window_width_ - x_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + } + + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int n_; + int c_; + int x_height_; + int x_width_; + int dy_height_; + int dy_width_; + int window_height_; + int window_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int stride_height_; + int stride_width_; + + size_t x_size_; + size_t dy_size_; + size_t index_size_; + size_t dx_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc new file mode 100644 index 0000000000..306a7b09ab --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + MirrorPad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + MirrorPadGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + MirrorPad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + MirrorPadGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h new file mode 100644 index 0000000000..c7c4aefb76 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class MirrorPadGpuFwdKernel : public GpuKernel { + public: + MirrorPadGpuFwdKernel() + : num_input_(0), num_paddings_(0), mode_(0), input_size_(1), output_size_(1), workspace_size_(0) {} + ~MirrorPadGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + int *paddings = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = output_size_ / sizeof(T); + int dim_offset = output_shape_.size() - 2; + + CalMirrorPad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + output_shape_[dim_offset + 0], output_shape_[dim_offset + 1], num_paddings_, paddings, mode_, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but MirrorPad needs 2 input."; + return false; + } + // check number of output -> should be 1 + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; + return false; + } + + string mode = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("mode")); + + if (mode == "REFLECT") { + mode_ = 0; // reflected mirroring + } else { + mode_ = 1; // symmetric mirroring + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + // shape adjustement -> from 2d/3d to 4d to standardize + if (input_shape.size() == 4) { + } else if (input_shape.size() == 3) { + auto it = input_shape.begin(); + input_shape.insert(it, 1); // batch padding + } else if (input_shape.size() == 2) { + auto it = input_shape.begin(); + input_shape.insert(it, 2, 1); // channel padding + } + + for (auto in_shape : input_shape) { + input_size_ *= in_shape; + input_shape_.push_back(in_shape); + } + num_input_ = input_size_; + input_size_ *= sizeof(T); + + auto padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + num_paddings_ = padding_shape[0]; + input_size_ += 2 * num_paddings_ * sizeof(int); + + output_size_ = sizeof(T); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto x : output_shape) { + output_size_ *= x; + output_shape_.push_back(x); + } + + int max_width = input_shape_[3]; + int max_height = input_shape_[2]; + + // basic error check for padding value + if (mode_ == 1) { // symmetric + max_width = max_width + (2 * max_width); + max_height = max_height + (2 * max_height); + } else { // reflect + max_width = max_width + (2 * (max_width - 1)); + max_height = max_height + (2 * (max_height - 1)); + } + + if (output_shape_[(output_shape_.size() - 2) + 0] > max_width || + output_shape_[(output_shape_.size() - 2) + 1] > max_width) { + MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more dims"; + return false; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(num_input_ * sizeof(T)); + input_size_list_.push_back(2 * num_paddings_ * sizeof(int)); + output_size_list_.push_back(output_size_); + } + + private: + size_t num_input_; + int num_paddings_; + int mode_; + std::vector input_shape_; // dims of the input data + std::vector output_shape_; // dims of the output data + // default + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc new file mode 100644 index 0000000000..599e51272b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + MirrorPadGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + MirrorPadGpuBackKernel, float) +MS_REG_GPU_KERNEL_ONE( + MirrorPadGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + MirrorPadGpuBackKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h new file mode 100644 index 0000000000..2f793aba77 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class MirrorPadGpuBackKernel : public GpuKernel { + public: + MirrorPadGpuBackKernel() + : num_input_(0), num_paddings_(0), mode_(0), input_size_(1), output_size_(1), workspace_size_(0) {} + ~MirrorPadGpuBackKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + int *paddings = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = output_size_ / sizeof(T); + int dim_offset = output_shape_.size() - 2; + + CalMirrorPadGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + output_shape_[dim_offset + 0], output_shape_[dim_offset + 1], num_paddings_, paddings, mode_, + output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but MirrorPadGrad needs 2 input."; + return false; + } + // check number of output -> should be 1 + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but MirrorPadGrad needs 1 output."; + return false; + } + + string mode = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("mode")); + + if (mode == "REFLECT") { + mode_ = 0; // reflected mirroring + } else { + mode_ = 1; // symmetric mirroring + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + // shape adjustement -> from 2d/3d to 4d to standardize + if (input_shape.size() == 4) { + } else if (input_shape.size() == 3) { + auto it = input_shape.begin(); + input_shape.insert(it, 1); // batch padding + } else if (input_shape.size() == 2) { + auto it = input_shape.begin(); + input_shape.insert(it, 2, 1); // channel padding + } + + for (auto in_shape : input_shape) { + input_size_ *= in_shape; + input_shape_.push_back(in_shape); + } + num_input_ = input_size_; + input_size_ *= sizeof(T); + + auto padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + num_paddings_ = padding_shape[0]; + input_size_ += +(2 * num_paddings_ * sizeof(int)); + + output_size_ = sizeof(T); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto x : output_shape) { + output_size_ *= x; + output_shape_.push_back(x); + } + + int max_width = input_shape_[3]; + int max_height = input_shape_[2]; + + // basic error check for padding value + if (mode_ == 1) { // symmetric + max_width = max_width + (2 * max_width); + max_height = max_height + (2 * max_height); + } else { // reflect + max_width = max_width + (2 * (max_width - 1)); + max_height = max_height + (2 * (max_height - 1)); + } + + if (output_shape_[(output_shape_.size() - 2) + 0] > max_width || + output_shape_[(output_shape_.size() - 2) + 1] > max_width) { + MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more DIMS"; + return false; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(num_input_ * sizeof(T)); + input_size_list_.push_back(2 * num_paddings_ * sizeof(int)); + output_size_list_.push_back(output_size_); + } + + private: + size_t num_input_; + int num_paddings_; + int mode_; + std::vector input_shape_; // dims of the input data + std::vector output_shape_; // dims of the output data + // default + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc index 99ae2affe8..96411e9bbf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc @@ -18,32 +18,41 @@ namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - MomentumGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - MomentumGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16), - MomentumGpuKernel, half, float) +MS_REG_GPU_KERNEL_THREE(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + MomentumGpuKernel, float, float, float) +MS_REG_GPU_KERNEL_THREE(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + MomentumGpuKernel, half, half, half) +MS_REG_GPU_KERNEL_THREE(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16), + MomentumGpuKernel, half, float, half) +MS_REG_GPU_KERNEL_THREE(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + MomentumGpuKernel, float, float, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h index 32d3fbb079..091e2d6e9f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MOMENTUM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MOMENTUM_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -23,7 +23,7 @@ #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" namespace mindspore { namespace kernel { -template +template class MomentumGpuKernel : public GpuKernel { public: MomentumGpuKernel() @@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel { T *variable = GetDeviceAddress(inputs, 0); T *accumulation = GetDeviceAddress(inputs, 1); S *learning_rate = GetDeviceAddress(inputs, 2); - T *gradient = GetDeviceAddress(inputs, 3); + G *gradient = GetDeviceAddress(inputs, 3); S *momentum = GetDeviceAddress(inputs, 4); MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, reinterpret_cast(stream_ptr)); @@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel { variable_size_ = sizeof(T); accumulation_size_ = sizeof(T); learning_rate_size_ = sizeof(S); - gradient_size_ = sizeof(T); + gradient_size_ = sizeof(G); momentum_size_ = sizeof(S); auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -97,4 +97,4 @@ class MomentumGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MOMENTUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.cc new file mode 100644 index 0000000000..3b43451f4c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PadGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + PadGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h new file mode 100644 index 0000000000..64595d38ea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PAD_GPU_FWD_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PAD_GPU_FWD_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class PadGpuFwdKernel : public GpuKernel { + public: + PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(0), output_size_(0), workspace_size_(0) {} + ~PadGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + size_t size = output_size_ / sizeof(T); + int pad_left = paddings[3][0]; + int pad_top = paddings[2][0]; + T pad_value = 0.0; + CalPad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output_shape_[2], + output_shape_[3], pad_top, pad_left, pad_value, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + // check number of inputs -> should be 1 + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but Pad needs 1 input."; + return false; + } + // check number of output -> should be 1 + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + shape_size_ = input_shape.size(); + // shape adjustement -> from 2d/3d to 4d to standardize + if (shape_size_ == 4) { + } else if (shape_size_ == 3) { + auto it = input_shape.begin(); + input_shape.insert(it, 1); // batch padding + shape_size_ = 4; + } else if (shape_size_ == 2) { + auto it = input_shape.begin(); + input_shape.insert(it, 2, 1); // channel padding + shape_size_ = 4; + } + paddings = GetValue>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("paddings")); + // shape adjustement -> from 2d/3d to 4d to standardize + if (paddings.size() == 4) { + } else if (paddings.size() == 3) { + auto it = paddings.begin(); + paddings.insert(it, 1, {0, 0}); // batch padding + } else if (paddings.size() == 2) { + auto it = paddings.begin(); + paddings.insert(it, 2, {0, 0}); // channel padding + } + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + temp = input_shape[i] + (paddings[i][0] + paddings[i][1]); // compute new dim size + output_size_ *= temp; + output_shape_.push_back(temp); // correct new dimension size + } + output_size_ *= sizeof(T); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + size_t shape_size_; + size_t temp; + std::vector> paddings; // list of paddings (tuple of tuple in python) + std::vector input_shape_; // dims of the input data + std::vector output_shape_; // dims of the output data + // default + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PAD_GPU_FWD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h index 908a4e9b99..e9cf05d0dd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GPU_KERNEL_H_ #include #include @@ -249,4 +249,4 @@ class PoolingGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h index a066eacfa0..0d16fc48a2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ #include #include @@ -293,4 +293,4 @@ class PoolingGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h index 9811c71094..cf31abea02 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RMSPROP_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RMSPROP_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc new file mode 100644 index 0000000000..c79e3af080 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + ROIAlign, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ROIAlignGpuFwdKernel, float) + +MS_REG_GPU_KERNEL_ONE( + ROIAlign, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ROIAlignGpuFwdKernel, half) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h new file mode 100644 index 0000000000..78bc1d0b61 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ROIAlignGpuFwdKernel : public GpuKernel { + public: + ROIAlignGpuFwdKernel() : x_size_(0), rois_size_(0), output_size_(0) {} + ~ROIAlignGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + const T *x = GetDeviceAddress(inputs, 0); + const T *rois = GetDeviceAddress(inputs, 1); + + T *out_data = GetDeviceAddress(outputs, 0); + + ROIAlign(x, rois, roi_rows_, roi_cols_, out_data, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_, + width_, pooled_height_, pooled_width_, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + // Get the number of input args + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ROIAlign needs 2 input."; + return false; + } + + // Get the number of output args + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ROIAlign needs 1 output."; + return false; + } + + // Get the input shapes + auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + + auto x_shape_size = x_shape.size(); + if (x_shape_size != 4) { + MS_LOG(ERROR) << "x shape size is " << x_shape_size << ", but shoud be 4."; + return false; + } + + // Get channels, height & width + int batch_N = x_shape[0]; + channels_ = x_shape[1]; + height_ = x_shape[2]; + width_ = x_shape[3]; + x_shape_ = {batch_N, channels_, height_, width_}; + x_size_ = batch_N * channels_ * height_ * width_ * sizeof(T); + + // Get rois rows and cols + roi_rows_ = rois_shape[0]; + roi_cols_ = rois_shape[1]; + rois_size_ = roi_rows_ * roi_cols_ * sizeof(T); + rois_shape_ = {roi_rows_, roi_cols_}; + + // Get primitive args + pooled_height_ = GetAttr(kernel_node, "pooled_height"); + pooled_width_ = GetAttr(kernel_node, "pooled_width"); + spatial_scale_ = static_cast(GetAttr(kernel_node, "spatial_scale")); + sample_num_ = GetAttr(kernel_node, "sample_num"); + roi_end_mode_ = GetAttr(kernel_node, "roi_end_mode"); + + // Get output_shape + output_shape_ = {roi_rows_, channels_, pooled_height_, pooled_width_}; + output_size_ = 1; + for (size_t i = 0; i < 4; i++) { + output_size_ *= output_shape_[i]; + } + output_size_ *= sizeof(T); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(x_size_); + input_size_list_.push_back(rois_size_); + output_size_list_.push_back(output_size_); + } + + private: + int pooled_height_; + int pooled_width_; + T spatial_scale_; + int sample_num_; + int roi_end_mode_; + + int roi_rows_; + int roi_cols_; + int channels_; + int height_; + int width_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector x_shape_; + std::vector rois_shape_; + std::vector output_shape_; + + size_t x_size_; + size_t rois_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc new file mode 100644 index 0000000000..5d08e3d470 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + ROIAlignGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ROIAlignGradGpuFwdKernel, float) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h new file mode 100644 index 0000000000..5d63083e03 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GRAD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GRAD_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ROIAlignGradGpuFwdKernel : public GpuKernel { + public: + ROIAlignGradGpuFwdKernel() : dy_size_(0), rois_size_(0), output_size_(0) {} + ~ROIAlignGradGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + const T *dy = GetDeviceAddress(inputs, 0); + const T *rois = GetDeviceAddress(inputs, 1); + + T *dx = GetDeviceAddress(outputs, 0); + + ROIAlignGrad(dy, rois, roi_rows_, roi_cols_, dx, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_, + width_, pooled_height_, pooled_width_, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + // Get the number of input args + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ROIAlignGrad needs 2 input."; + return false; + } + + // Get the number of output args + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ROIAlignGrad needs 1 output."; + return false; + } + + // Get the input shapes + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + + auto dy_shape_size = dy_shape.size(); + if (dy_shape_size != 4) { + MS_LOG(ERROR) << "dy shape size is " << dy_shape_size << ", but shoud be 4."; + return false; + } + + // Parse y diff + dy_shape_ = {static_cast(dy_shape[0]), static_cast(dy_shape[1]), static_cast(dy_shape[2]), + static_cast(dy_shape[3])}; + dy_size_ = dy_shape_[0] * dy_shape_[1] * dy_shape_[2] * dy_shape_[3] * sizeof(T); + + // Get rois rows and cols + roi_rows_ = rois_shape[0]; + roi_cols_ = rois_shape[1]; + rois_shape_ = {roi_rows_, roi_cols_}; + rois_size_ = roi_rows_ * roi_cols_ * sizeof(T); + + // Get primitive args + xdiff_shape_ = GetAttr>(kernel_node, "xdiff_shape"); + pooled_height_ = GetAttr(kernel_node, "pooled_height"); + pooled_width_ = GetAttr(kernel_node, "pooled_width"); + spatial_scale_ = static_cast(GetAttr(kernel_node, "spatial_scale")); + sample_num_ = GetAttr(kernel_node, "sample_num"); + roi_end_mode_ = 1; + + // Get channels, height & width + channels_ = xdiff_shape_[1]; + height_ = xdiff_shape_[2]; + width_ = xdiff_shape_[3]; + + // Get output_shape + output_shape_ = {roi_rows_, channels_, height_, width_}; + output_size_ = roi_rows_ * channels_ * height_ * width_ * sizeof(T); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(dy_size_); + input_size_list_.push_back(rois_size_); + output_size_list_.push_back(output_size_); + } + + private: + std::vector xdiff_shape_; + int pooled_height_; + int pooled_width_; + T spatial_scale_; + int sample_num_; + int roi_end_mode_; + + int roi_rows_; + int roi_cols_; + int channels_; + int height_; + int width_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector dy_shape_; + std::vector rois_shape_; + std::vector output_shape_; + + size_t dy_size_; + size_t rois_size_; + size_t output_size_; +}; // namespace kernel +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GRAD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc new file mode 100644 index 0000000000..7b022699f7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(SGD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SGDGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h new file mode 100644 index 0000000000..70a57cded0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class SGDGpuKernel : public GpuKernel { + public: + SGDGpuKernel() : size_(1), dampening_(0.0), weight_decay_(0.0), nesterov_(false) {} + ~SGDGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream) override { + T *param = GetDeviceAddress(inputs, 0); + T *grad = GetDeviceAddress(inputs, 1); + T *lr = GetDeviceAddress(inputs, 2); + T *accum = GetDeviceAddress(inputs, 3); + T *momentum = GetDeviceAddress(inputs, 4); + T *stat = GetDeviceAddress(inputs, 5); + + SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat, + reinterpret_cast(stream)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + dampening_ = GetAttr(kernel_node, "dampening"); + weight_decay_ = GetAttr(kernel_node, "weight_decay"); + nesterov_ = GetAttr(kernel_node, "nesterov"); + + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto &dim : input_shape) { + size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = size_ * sizeof(T); + input_size_list_.push_back(input_size); // parameter + input_size_list_.push_back(input_size); // gradient + input_size_list_.push_back(sizeof(T)); // lr + input_size_list_.push_back(input_size); // accum + input_size_list_.push_back(sizeof(T)); // momentum + input_size_list_.push_back(input_size); // stat + output_size_list_.push_back(input_size); + } + + private: + size_t size_; + float dampening_; + float weight_decay_; + bool nesterov_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h index a2d3aabb68..bf133e2bb7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -94,4 +94,4 @@ class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h index 88ab46a6ba..873f9c5be1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -93,4 +93,4 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h index dc20f75077..1ebd56874b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -72,4 +72,4 @@ class SmoothL1LossGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h index 02be336932..88e8bbd30e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -73,4 +73,4 @@ class SmoothL1LossGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h index e56cb96fd7..8a93c4c455 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ #include #include @@ -202,4 +202,4 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h index 279bac3aa9..9369ba0a55 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -249,4 +249,4 @@ class SoftmaxGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h index b814be9969..7bea9b3569 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -216,4 +216,4 @@ class SoftmaxGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h index bcb8a6b333..88ec631017 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ #include #include @@ -203,4 +203,4 @@ class SparseSoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h index 76e863393c..2be341f50a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ASSIGN_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ASSIGN_GPU_KERNEL_H #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -90,4 +90,4 @@ class AssignGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ASSIGN_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc new file mode 100644 index 0000000000..d08b671241 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + BoundingBoxDecode, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BoundingBoxDecodeGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h new file mode 100644 index 0000000000..0f1d9ac917 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class BoundingBoxDecodeGpuKernel : public GpuKernel { + public: + BoundingBoxDecodeGpuKernel() : rois_size_(0), deltas_size_(0), bboxes_size_(0), wh_ratio_clip_(0.016) {} + + ~BoundingBoxDecodeGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *rois_addr = GetDeviceAddress(inputs, 0); + T *deltas_addr = GetDeviceAddress(inputs, 1); + T *bboxes_addr = GetDeviceAddress(outputs, 0); + + if (inputs[0]->size != inputs[1]->size) { + MS_LOG(ERROR) << "Rois box size must equal with deltas box size -" << inputs[1]->size << ", but got" + << inputs[0]->size; + return false; + } + + const size_t coordinate = 4; + const size_t block_size = inputs[0]->size / sizeof(T); + if ((block_size % coordinate) != 0) { + MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; + return false; + } + + BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2], + means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs."; + return false; + } + rois_size_ = sizeof(T); + deltas_size_ = sizeof(T); + bboxes_size_ = sizeof(T); + + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < logits_shape.size(); i++) { + rois_size_ *= logits_shape[i]; + } + + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < labels_shape.size(); i++) { + deltas_size_ *= labels_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + bboxes_size_ *= output_shape[i]; + } + + InitSizeLists(); + + const size_t coordinate_size = 4; + if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa() || + AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa()) { + means_ = GetAttr>(kernel_node, "means"); + } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa()) { + float mean = GetAttr(kernel_node, "means"); + for (size_t i = 0; i < coordinate_size; i++) { + means_.emplace_back(mean); + } + } else { + MS_LOG(EXCEPTION) << "Attribute means type is invalid."; + } + + if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa() || + AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa()) { + stds_ = GetAttr>(kernel_node, "stds"); + } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa()) { + float std = GetAttr(kernel_node, "stds"); + for (size_t i = 0; i < coordinate_size; i++) { + stds_.emplace_back(std); + } + } else { + MS_LOG(EXCEPTION) << "Attribute stds type is invalid."; + } + + max_shape_ = GetAttr>(kernel_node, "max_shape"); + wh_ratio_clip_ = GetAttr(kernel_node, "wh_ratio_clip"); + + if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { + MS_LOG(EXCEPTION) << "The size of means or stds is less than 4."; + } + + if (max_shape_.size() < 2) { + MS_LOG(EXCEPTION) << "The size of max_shape is less than 2."; + } + + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(rois_size_); + input_size_list_.push_back(deltas_size_); + output_size_list_.push_back(bboxes_size_); + } + + private: + size_t rois_size_; + size_t deltas_size_; + size_t bboxes_size_; + std::vector means_; + std::vector stds_; + std::vector max_shape_; + float wh_ratio_clip_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc new file mode 100644 index 0000000000..98ee8104e0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + BoundingBoxEncode, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BoundingBoxEncodeGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h new file mode 100644 index 0000000000..564751cda4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class BoundingBoxEncodeGpuKernel : public GpuKernel { + public: + BoundingBoxEncodeGpuKernel() : anchor_size_(0), groundtruth_size_(0), deltas_size_(0) {} + + ~BoundingBoxEncodeGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *anchor_addr = GetDeviceAddress(inputs, 0); + T *groundtruth_addr = GetDeviceAddress(inputs, 1); + T *deltas_addr = GetDeviceAddress(outputs, 0); + + if (inputs[0]->size != inputs[1]->size) { + MS_LOG(ERROR) << "Anchor box size must equal with groundtruth box size -" << inputs[1]->size << ", but got" + << inputs[0]->size; + return false; + } + + const size_t coordinate = 4; + const size_t block_size = inputs[0]->size / sizeof(T); + if ((block_size % coordinate) != 0) { + MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; + return false; + } + + BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], means_[1], + means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3], + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs."; + return false; + } + anchor_size_ = sizeof(T); + groundtruth_size_ = sizeof(T); + deltas_size_ = sizeof(T); + + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < logits_shape.size(); i++) { + anchor_size_ *= logits_shape[i]; + } + + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < labels_shape.size(); i++) { + groundtruth_size_ *= labels_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + deltas_size_ *= output_shape[i]; + } + + InitSizeLists(); + + const size_t coordinate_size = 4; + if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa() || + AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa()) { + means_ = GetAttr>(kernel_node, "means"); + } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa()) { + float mean = GetAttr(kernel_node, "means"); + for (size_t i = 0; i < coordinate_size; i++) { + means_.emplace_back(mean); + } + } else { + MS_LOG(EXCEPTION) << "Attribute means type is invalid."; + } + + if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa() || + AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa()) { + stds_ = GetAttr>(kernel_node, "stds"); + } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa()) { + float std = GetAttr(kernel_node, "stds"); + for (size_t i = 0; i < coordinate_size; i++) { + stds_.emplace_back(std); + } + } else { + MS_LOG(EXCEPTION) << "Attribute stds type is invalid."; + } + + if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { + MS_LOG(EXCEPTION) << "The size of means or stds is less than 4."; + } + + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(anchor_size_); + input_size_list_.push_back(groundtruth_size_); + output_size_list_.push_back(deltas_size_); + } + + private: + size_t anchor_size_; + size_t groundtruth_size_; + size_t deltas_size_; + std::vector means_; + std::vector stds_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc new file mode 100644 index 0000000000..208e217e1d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + CheckValid, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + CheckValidGpuKernel, float, bool) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h new file mode 100644 index 0000000000..36a69c28b4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class CheckValidGpuKernel : public GpuKernel { + public: + CheckValidGpuKernel() : anchor_boxes_size_(0), img_metas_size_(0), valid_size_(0) {} + + ~CheckValidGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *anchor_boxes_addr = GetDeviceAddress(inputs, 0); + T *img_metas_addr = GetDeviceAddress(inputs, 1); + S *valid_addr = GetDeviceAddress(outputs, 0); + + const size_t coordinate = 4; + const size_t block_size = inputs[0]->size / sizeof(T); + if ((block_size % coordinate) != 0) { + MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; + return false; + } + + const size_t size = block_size / coordinate; + CheckValid(size, anchor_boxes_addr, img_metas_addr, valid_addr, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs."; + return false; + } + anchor_boxes_size_ = sizeof(T); + img_metas_size_ = sizeof(T); + valid_size_ = sizeof(S); + + auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < anchor_boxes_shape.size(); i++) { + anchor_boxes_size_ *= anchor_boxes_shape[i]; + } + + auto img_metas_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < img_metas_shape.size(); i++) { + img_metas_size_ *= img_metas_shape[i]; + } + + auto valid_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < valid_shape.size(); i++) { + valid_size_ *= valid_shape[i]; + } + + InitSizeLists(); + + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(anchor_boxes_size_); + input_size_list_.push_back(img_metas_size_); + output_size_list_.push_back(valid_size_); + } + + private: + size_t anchor_boxes_size_; + size_t img_metas_size_; + size_t valid_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc new file mode 100644 index 0000000000..5d3f0f202b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/other/iou_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + IOUGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h new file mode 100644 index 0000000000..c28e4f91ec --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class IOUGpuKernel : public GpuKernel { + public: + IOUGpuKernel() : gt_boxes_size_(0), anchor_boxes_size_(0), iou_size_(0), mode_(0) {} + + ~IOUGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *gt_boxes_addr = GetDeviceAddress(inputs, 0); + T *anchor_boxes_addr = GetDeviceAddress(inputs, 1); + T *iou_addr = GetDeviceAddress(outputs, 0); + + const size_t coordinate = 4; + const size_t block_size_0 = inputs[0]->size / sizeof(T); + const size_t block_size_1 = inputs[1]->size / sizeof(T); + if ((block_size_0 % coordinate) != 0 || (block_size_1 % coordinate) != 0) { + MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; + return false; + } + + const size_t input_len_0 = block_size_0 / coordinate; + const size_t input_len_1 = block_size_1 / coordinate; + IOU(input_len_0 * input_len_1, gt_boxes_addr, anchor_boxes_addr, iou_addr, mode_, input_len_0, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs."; + return false; + } + gt_boxes_size_ = sizeof(T); + anchor_boxes_size_ = sizeof(T); + iou_size_ = sizeof(T); + + auto gt_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < gt_boxes_shape.size(); i++) { + gt_boxes_size_ *= gt_boxes_shape[i]; + } + + auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < anchor_boxes_shape.size(); i++) { + anchor_boxes_size_ *= anchor_boxes_shape[i]; + } + + auto iou_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < iou_shape.size(); i++) { + iou_size_ *= iou_shape[i]; + } + + InitSizeLists(); + + std::string mode = GetAttr(kernel_node, "mode"); + + if (mode == "iou") { + mode_ = 0; + } else if (mode == "iof") { + mode_ = 1; + } else { + MS_LOG(ERROR) << "Mode only support 'iou' or 'iof'."; + return false; + } + + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(gt_boxes_size_); + input_size_list_.push_back(anchor_boxes_size_); + output_size_list_.push_back(iou_size_); + } + + private: + size_t gt_boxes_size_; + size_t anchor_boxes_size_; + size_t iou_size_; + size_t mode_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h index 83600e20df..32d587a90d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -129,4 +129,4 @@ class BatchNormFold2GpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h index 3335210925..61abfaf113 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -165,4 +165,4 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h index 11b150686c..18c7e3fb91 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -206,4 +206,4 @@ class BatchNormFoldGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h index 93a3cbf46e..45d18602dd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -163,4 +163,4 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h index 4ba6285e4b..20b413da71 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CORRECTIONMUL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CORRECTIONMUL_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -94,4 +94,4 @@ class CorrectionMulGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CORRECTIONMUL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h index b9fcbf0787..533266e185 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -102,4 +102,4 @@ class CorrectionMulGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h index 8e2c9524b2..cd5102bcd6 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -60,4 +60,4 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h index c2611ab8a2..94e0c96a0f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -56,4 +56,4 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h index 6df4da3104..0bfeb8c202 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -57,4 +57,4 @@ class FakeQuantPerLayerGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h index 475723f684..1c90b09abc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -57,4 +57,4 @@ class FakeQuantPerLayerGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h index 9a0fe23e6a..671d2cb40f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -52,4 +52,4 @@ class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h index 80ce6185c0..b4811466ae 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" @@ -51,4 +51,4 @@ class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc new file mode 100644 index 0000000000..9d810878b0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + RandomChoiceWithMask, + KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + RandomChoiceWithMaskGpuKernel, bool, int) +} +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h new file mode 100644 index 0000000000..c4c3380723 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class RandomChoiceWithMaskGpuKernel : public GpuKernel { + public: + RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {} + ~RandomChoiceWithMaskGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + S *output_index = GetDeviceAddress(outputs, 0); + T *output_mask = GetDeviceAddress(outputs, 1); + S *index_buff = GetDeviceAddress(workspaces, 0); + S *mask_buff = GetDeviceAddress(workspaces, 1); + S *rank_buff = GetDeviceAddress(workspaces, 2); + S *Tnum_buff = GetDeviceAddress(workspaces, 3); + S *tmp_buff = GetDeviceAddress(workspaces, 4); + void *States = GetDeviceAddress(workspaces, 5); + curandState *devStates = reinterpret_cast(States); + CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2], + input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask, + index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_shape_size_ = input_shape.size(); + if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) { + MS_LOG(ERROR) << "Input is " << input_shape_size_ + << "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs."; + return false; + } + // convert size_t to int + for (auto i = 0; i < input_shape_size_; i++) { + input_shape_5D_.push_back(input_shape[i]); + } + // convert shape to 5D + while (input_shape_5D_.size() != MAX_DIMENSION) { + input_shape_5D_.insert(input_shape_5D_.begin(), 1); + } + // init seedc_ + int seed = GetAttr(kernel_node, "seed"); + int seed2 = GetAttr(kernel_node, "seed2"); + if (seed2 != 0) + seedc_ = seed2; + else if (seed != 0) + seedc_ = seed; + else + seedc_ = time(NULL); + // init memory + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + count_ = GetAttr(kernel_node, "count"); + // upper ceiling for input for ceil_power2 + ceil_power2_ = RcwmRoundUpPower2(input_size_); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S)); + output_size_list_.push_back(count_ * sizeof(T)); + workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + int blocknum = std::ceil(static_cast(ceil_power2_) / BLOCKSIZE); + workspace_size_list_.push_back(blocknum * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); + } + + private: + int input_shape_size_; + int seedc_; + int input_size_; + int count_; + int ceil_power2_; + std::vector input_shape_5D_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 5ec4f52574..bf948498aa 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -18,7 +18,7 @@ #include "runtime/device/ascend/tasksink/runtime_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" using HcclTaskInfoPtr = std::shared_ptr; using ge::model_runner::HcclTaskInfo; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h index db7a0fbf7c..330692e461 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_KERNEL_H_ #include #include @@ -26,7 +26,7 @@ #include "backend/kernel_compiler/ascend_kernel_mod.h" #include "backend/kernel_compiler/hccl/hcom_util.h" #include "hccl/hcom.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h index 21b34d6522..8d8436af86 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_KERNEL_BUILD_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc index 55742d383c..4b8654ed84 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc @@ -20,6 +20,7 @@ #include "utils/utils.h" #include "backend/kernel_compiler/hccl/hcom_util.h" #include "backend/session/anf_runtime_algorithm.h" +#include "frontend/parallel/context.h" namespace mindspore { namespace kernel { @@ -27,6 +28,11 @@ namespace { std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { const std::set kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; auto op_name = AnfAlgo::GetCNodeName(kernel_node); + auto parallel_context_instance = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel_context_instance); + if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { + return kOpFormat_DEFAULT; + } auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); if (op_name != kReduceScatter && op_name != kAllGatherOpName) { return format; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h index 25891fdaf6..1ff5ecf769 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ #include #include #include @@ -26,4 +26,4 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector #include -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h index 6434b5fb9c..d6489642d6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_BROADCAST_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_BROADCAST_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc index 201071dcb5..2502ec799b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc @@ -20,7 +20,7 @@ #include #include -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h index 21d8ffa484..36a11d70c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_GATHER_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_GATHER_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc index 533ce1b087..70857d6a6c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc @@ -20,7 +20,7 @@ #include #include -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h index 39641f7448..99ea18f700 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_REDUCE_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_REDUCE_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc index 32c6dacb01..ca38f3e73b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc @@ -20,7 +20,7 @@ #include #include -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h index 2f4ace5aea..987982a73c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h index dc9596cf5c..2979fc5ed8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_UTILS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_UTILS_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel.h b/mindspore/ccsrc/backend/kernel_compiler/kernel.h index 2d240338f3..5725bc80ae 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_ #include #include #include @@ -138,4 +138,4 @@ using KernelModPtr = std::shared_ptr; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc index 68392d1871..cf7c1af9cc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc @@ -52,19 +52,22 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const { return outputs_device_type_[output_index]; } -std::vector KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } +const std::vector &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } -std::vector KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } +const std::vector &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } -std::vector KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; } +const std::vector &KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; } -std::vector KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; } +const std::vector &KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; } size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { + if (input_reshape_type_.empty()) { + return {}; + } if (input_index >= input_reshape_type_.size()) { MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " << input_reshape_type_.size(); @@ -73,6 +76,9 @@ std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const } std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { + if (output_reshape_type_.empty()) { + return {}; + } if (output_index >= output_reshape_type_.size()) { MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " << output_reshape_type_.size(); @@ -158,13 +164,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) std::shared_ptr KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType( const std::vector> &input_reshape_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->input_reshape_type_ = input_reshape_type; } -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType( const std::vector> &output_reshape_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->output_reshape_type_ = output_reshape_type; @@ -189,5 +195,36 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string } kernel_build_info_->outputs_format_[index] = format; } +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector &input_reshape_type, + size_t index) { + if (index >= kernel_build_info_->input_reshape_type_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + std::copy(input_reshape_type.begin(), input_reshape_type.end(), + std::back_inserter(kernel_build_info_->input_reshape_type_[index])); +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector &output_reshape_type, + size_t index) { + if (index >= kernel_build_info_->output_reshape_type_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + std::copy(output_reshape_type.begin(), output_reshape_type.end(), + std::back_inserter(kernel_build_info_->output_reshape_type_[index])); +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) { + if (index >= kernel_build_info_->outputs_device_type_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + kernel_build_info_->outputs_device_type_[index] = output_device_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) { + if (index >= kernel_build_info_->inputs_device_type_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + kernel_build_info_->inputs_device_type_[index] = input_device_type; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h index be243c9ae0..f45a1b4887 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ #include #include #include @@ -63,13 +63,17 @@ class KernelBuildInfo { std::vector GetOutputReshapeType(size_t input_index) const; - std::vector GetAllInputFormats() const; + const std::vector &GetAllInputFormats() const; - std::vector GetAllOutputFormats() const; + const std::vector &GetAllOutputFormats() const; - std::vector GetAllInputDeviceTypes() const; + const std::vector &GetAllInputDeviceTypes() const; - std::vector GetAllOutputDeviceTypes() const; + const std::vector &GetAllOutputDeviceTypes() const; + + std::vector> GetAllOutputReshapeType() const; + + std::vector> GetAllInputReshapeType() const; OpPattern op_pattern() const { return op_pattern_; } @@ -109,7 +113,22 @@ class KernelBuildInfo::KernelBuildInfoBuilder { KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared(); } explicit KernelBuildInfoBuilder(std::shared_ptr kernel_build_info) - : kernel_build_info_(std::move(kernel_build_info)) {} + : kernel_build_info_(std::make_shared()) { + SetKernelType(kernel_build_info->kernel_type()); + SetFusionType(kernel_build_info->fusion_type()); + SetProcessor(kernel_build_info->processor()); + OpPattern(kernel_build_info->op_pattern()); + for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) { + kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index)); + kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index)); + kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index)); + } + for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) { + kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index)); + kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index)); + kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index)); + } + } ~KernelBuildInfoBuilder() = default; @@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder { void SetOutputsDeviceType(const std::vector &outputs_device_type); - void SetInputReshapeType(const std::vector> &input_reshape_type); + void SetInputsReshapeType(const std::vector> &input_reshape_type); - void SetOutputReshapeType(const std::vector> &output_reshape_type); + void SetOutputsReshapeType(const std::vector> &output_reshape_type); void SetFusionType(FusionType fusion_type); @@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder { void SetOutputFormat(const std::string &format, size_t index); + void SetInputReshapeType(const std::vector &input_reshape_type, size_t index); + + void SetOutputReshapeType(const std::vector &output_reshape_type, size_t index); + + void SetInputDeviceType(const TypeId &input_device_type, size_t index); + + void SetOutputDeviceType(const TypeId &output_device_type, size_t index); + std::shared_ptr Build(); private: @@ -144,4 +171,4 @@ class KernelBuildInfo::KernelBuildInfoBuilder { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc index 0045e49bef..d4ef905729 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/tbe/tbe_kernel_build.h" #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" @@ -101,14 +101,14 @@ std::map KernelFusion(const std::vector int build_failed_num = 0; while (!build_manger->IsAllTaskFinish()) { int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; + std::string task_result; + std::string pre_build_result; auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); if (!ret) { MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; } - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + if (task_result != "Success") { MS_LOG(INFO) << "Fusion warning: Fuison op build failed, err log: " << task_result << " change to single op build."; build_failed_num++; diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h index 2fb3a05b4b..089f41f2b8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNELFUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNELFUSION_H_ #include #include #include "backend/kernel_compiler/kernel.h" @@ -25,6 +25,10 @@ namespace kernel { * @brief fuse op and return a callable mod */ struct FusionScopeInfo { + FusionScopeInfo() {} + FusionScopeInfo(int32_t id, const std::vector &in, const std::vector &comp, + const std::vector &out) + : scope_id(id), input_nodes(in), compute_nodes(comp), output_nodes(out) {} int32_t scope_id; std::vector input_nodes; std::vector compute_nodes; @@ -35,4 +39,4 @@ std::map KernelFusion(const std::vector } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNELFUSION_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc index 81b5d0f996..343fdd4896 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc @@ -23,7 +23,7 @@ #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "backend/kernel_compiler/akg/akg_kernel_metadata.h" #include "backend/session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { @@ -31,12 +31,16 @@ namespace { void FilterInvalidKernelInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_info_list); + MS_EXCEPTION_IF_NULL(kernel_node); + size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node); + size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node); std::vector> filtered_list; - (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), - [&kernel_node](const std::shared_ptr &kernel_build_info) { - return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && - AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); - }); + (void)std::copy_if( + kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), + [output_tensor_num, input_tensor_num](const std::shared_ptr &kernel_build_info) { + return kernel_build_info->GetOutputNum() == output_tensor_num && + kernel_build_info->GetInputNum() == input_tensor_num; + }); if (!filtered_list.empty()) { kernel_info_list->clear(); (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); @@ -44,21 +48,20 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, MS_LOG(INFO) << "All kernel Info list does not match any kernel info "; for (size_t index = 0; index < kernel_info_list->size(); ++index) { std::ostringstream buffer; - auto kernel_info = kernel_info_list->at(index); + auto &kernel_info = kernel_info_list->at(index); MS_EXCEPTION_IF_NULL(kernel_info); - if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) { - buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + if (kernel_info->GetOutputNum() != output_tensor_num) { + buffer << "Kernel node's output size [" << output_tensor_num << "]" << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]"; } else { - buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]" + buffer << "Kernel node's output size [" << input_tensor_num << "]" << " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]"; } MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str(); } kernel_info_list->clear(); - MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" - << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" - << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; + MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]" + << "input size : [" << input_tensor_num << "] cannot match any kernelInfo !"; } } } // namespace diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h index 20458f48d0..9b4b49ac06 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_QUERY_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_QUERY_H_ #include #include @@ -32,4 +32,4 @@ bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_QUERY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h index 64ae1009d1..cf566a7d16 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ -#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ #include #include #include @@ -59,13 +59,13 @@ class OpIOInfo { ~OpIOInfo() = default; int index() const { return index_; } - std::string name() const { return name_; } + const std::string &name() const { return name_; } bool need_compile() const { return need_compile_; } - std::string param_type() const { return param_type_; } - std::string reshape_type() const { return reshape_type_; } - std::string shape() const { return shape_; } - std::vector dtypes() const { return dtypes_; } - std::vector formats() const { return formats_; } + const std::string ¶m_type() const { return param_type_; } + const std::string &reshape_type() const { return reshape_type_; } + const std::string &shape() const { return shape_; } + const std::vector &dtypes() const { return dtypes_; } + const std::vector &formats() const { return formats_; } void set_index(const int index) { index_ = index; } void set_name(const std::string &name) { name_ = name; } @@ -172,4 +172,4 @@ class OpInfo { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc index 69c4ca7db1..9f3099c415 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc @@ -22,7 +22,7 @@ #include #include "utils/log_adapter.h" #include "utils/overload.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { @@ -60,7 +60,7 @@ constexpr auto kFormat = "format"; constexpr auto kNeedCompile = "need_compile"; constexpr auto kShape = "shape"; constexpr auto kProcessor = "processor"; -std::vector> OpLib::op_info_; +std::multimap> OpLib::op_info_; static std::string ImplTypeToStr(OpImplyType impl_type) { switch (impl_type) { @@ -133,11 +133,11 @@ void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_p } bool OpLib::RegOpFromLocalInfo() { - MS_LOG(INFO) << "Start"; static bool has_load = false; if (has_load) { return true; } + MS_LOG(INFO) << "Start"; has_load = true; std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH"); if (dir.empty()) { @@ -224,7 +224,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI MS_LOG(ERROR) << "GetRefInfo Failed"; return false; } - op_info_.push_back(op_info); + op_info_.emplace(op_info->op_name(), op_info); return true; } @@ -336,16 +336,17 @@ std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType im << ", current op num: " << op_info_.size(); return nullptr; } - for (const auto &op_info : op_info_) { + std::string target_processor = is_gpu ? kCUDA : kAiCore; + for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) { + auto &op_info = iter->second; MS_EXCEPTION_IF_NULL(op_info); - if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { - auto akg_processor_match = [&]() { - return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore; - }; - if (imply_type != kAKG || akg_processor_match()) { - return op_info; - } + if (op_info->imply_type() != imply_type) { + continue; + } + if (imply_type == kAKG && op_info->processor() != target_processor) { + continue; } + return op_info; } MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) << ", current op num: " << op_info_.size(); @@ -378,7 +379,8 @@ bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - for (const auto &exist_op_info : op_info_) { + for (auto [iter, end] = op_info_.equal_range(op_info->op_name()); iter != end; ++iter) { + auto &exist_op_info = iter->second; MS_EXCEPTION_IF_NULL(exist_op_info); if (exist_op_info->equals_to(op_info)) { return true; diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h index 845edbfc2a..808fa14413 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h @@ -14,12 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ -#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPLIB_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPLIB_H_ #include #include #include +#include #include +#include "utils/ms_utils.h" #include "backend/kernel_compiler/oplib/opinfo.h" namespace mindspore { @@ -29,12 +31,12 @@ class OpLib { OpLib() = default; virtual ~OpLib() = default; static bool RegOp(const std::string &json_string, const std::string &impl_path); - static void RegOpInfo(const std::shared_ptr &opinfo) { op_info_.emplace_back(opinfo); } + static void RegOpInfo(const std::shared_ptr &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); } static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); - static const std::vector> &GetAllOpsInfo() { return op_info_; } + static const std::multimap> &GetAllOpsInfo() { return op_info_; } protected: - static std::vector> op_info_; + static std::multimap> op_info_; private: static bool RegOpFromLocalInfo(); @@ -52,4 +54,4 @@ class OpLib { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPLIB_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h index 6b2981e5b3..fbdf69c495 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h @@ -32,7 +32,7 @@ class OpInfoLoaderPy { auto ops = OpLib::GetAllOpsInfo(); auto op_infos = new std::vector(); for (auto op_info : ops) { - auto new_op_info = new OpInfo(*op_info); + auto new_op_info = new OpInfo(*op_info.second); op_infos->emplace_back(new_op_info); } return (size_t)op_infos; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc index 552468bb71..49666293b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc @@ -19,7 +19,7 @@ #include #include "runtime/mem.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using ge::model_runner::MemcpyAsyncTaskInfo; using MemcpyAsyncTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h index cff946cc36..b60969c220 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H -#define MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_ASSIGN_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_ASSIGN_H #include #include "backend/kernel_compiler/rts/rt_kernel.h" @@ -38,4 +38,4 @@ MS_REG_RTKERNEL(assign, AssignKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_ASSIGN_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc index 8ec460fe0b..81c85bd370 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc @@ -20,7 +20,7 @@ #include "runtime/stream.h" #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using ge::model_runner::LabelGotoTaskInfo; using LabelGotoTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h index 2680d916a5..a073555241 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_GOTO_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_GOTO_H #include #include @@ -44,4 +44,4 @@ MS_REG_RTKERNEL(labelgoto, LabelGotoKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_GOTO_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc index 909885ff17..5945dc9f52 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc @@ -20,7 +20,7 @@ #include "runtime/stream.h" #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using ge::model_runner::LabelSetTaskInfo; using LabelSetTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h index 8d0cfdfb20..721bc39cff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_SET_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_SET_H #include #include @@ -44,4 +44,4 @@ MS_REG_RTKERNEL(labelset, LabelSetKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_SET_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc index ccb49d9497..8e06569370 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc @@ -21,7 +21,7 @@ #include "runtime/stream.h" #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using ge::model_runner::LabelSwitchTaskInfo; using LabelSwitchTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h index 1860d38d74..b573782f0c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_SWITCH_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_SWITCH_H #include #include @@ -54,4 +54,4 @@ MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_LABEL_SWITCH_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc index ca1114a83f..9546f38e6b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc @@ -20,10 +20,10 @@ #include #include "runtime/mem.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "common/trans.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" using ge::model_runner::MemcpyAsyncTaskInfo; using MemcpyAsyncTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h index 07a782be50..4e66a212b2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H -#define MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_MEMCPY_ASYNC_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_MEMCPY_ASYNC_H #include #include @@ -53,4 +53,4 @@ MS_REG_RTKERNEL(memcpy_async, MemCpyAsyncKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_MEMCPY_ASYNC_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h index cdb43afb3e..239cf8e222 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_PROFILING_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_PROFILING_KERNEL_MOD_H_ #include #include "backend/kernel_compiler/rts/rt_kernel.h" namespace mindspore { @@ -37,4 +37,4 @@ class ProfilingKernelMod : public RtKernel { MS_REG_RTKERNEL(profiling, ProfilingKernelMod); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_PROFILING_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc index cee0ef2fdc..d51663f354 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc @@ -17,11 +17,11 @@ #include "backend/kernel_compiler/rts/recv.h" #include #include "runtime/stream.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "runtime/device/ascend/ascend_stream_assign.h" #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h index 73e0214eae..13dd91d55e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RECV_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RECV_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RECV_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RECV_H #include #include @@ -43,4 +43,4 @@ MS_REG_RTKERNEL(recv, RecvKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RECV_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RECV_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h index dc0aa3e283..0c41015383 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_H #include #include @@ -74,4 +74,4 @@ class _RtKernelRegister { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h index ccfb8d923b..66c11b0a34 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_BUILD_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_BUILD_H #include #include @@ -26,4 +26,4 @@ KernelModPtr RtOpBuild(const AnfNodePtr &anf_node); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_BUILD_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc index 9501aed5f2..59ac61fd81 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc @@ -19,7 +19,7 @@ #include #include "utils/convert_utils.h" #include "utils/utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/session/anf_runtime_algorithm.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h index 6048fb3779..28489bbdc0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_INFO_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_INFO_H #include #include @@ -72,4 +72,4 @@ void GetRtKelInfo(const CNodePtr &kernel_node, std::vector; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h index dbadb1ef44..6550a3b11a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_SEND_H -#define MINDSPORE_CCSRC_KERNEL_RTS_SEND_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_SEND_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_SEND_H #include #include #include "backend/kernel_compiler/rts/rt_kernel.h" @@ -41,4 +41,4 @@ MS_REG_RTKERNEL(send, SendKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_SEND_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_SEND_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc index e33549973d..4e48366f45 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc @@ -20,7 +20,7 @@ #include "runtime/stream.h" #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using ge::model_runner::StreamActiveTaskInfo; using StreamActiveTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h index 409c3437dc..7164b7d9fc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H -#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_STREAM_ACTIVE_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_STREAM_ACTIVE_H #include #include #include "backend/kernel_compiler/rts/rt_kernel.h" @@ -43,4 +43,4 @@ MS_REG_RTKERNEL(streamactive, StreamActiveKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_STREAM_ACTIVE_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc index 5fe03b1960..aecd7f69e8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc @@ -22,7 +22,7 @@ #include "runtime/stream.h" #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using ge::model_runner::StreamSwitchTaskInfo; using StreamSwitchTaskInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h index 64a51f68bf..78e9f80e4d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H -#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_STREAM_SWITCH_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_STREAM_SWITCH_H #include #include @@ -46,4 +46,4 @@ MS_REG_RTKERNEL(streamswitch, StreamSwitchKernel); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_STREAM_SWITCH_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/task_stream.h b/mindspore/ccsrc/backend/kernel_compiler/task_stream.h index babfe7eecd..ccc930fa70 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/task_stream.h +++ b/mindspore/ccsrc/backend/kernel_compiler/task_stream.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TASK_STREAM_H_ -#define MINDSPORE_CCSRC_KERNEL_TASK_STREAM_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TASK_STREAM_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TASK_STREAM_H_ #include #include @@ -56,4 +56,4 @@ class TaskStream { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TASK_STREAM_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TASK_STREAM_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc index 449a9f4556..f3ef4e24f4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc @@ -37,6 +37,7 @@ static std::map tbe_func_adapter_map = { {"re_lu6", "relu6"}, {"re_lu6_grad", "relu6_grad"}, {"re_lu", "relu"}, + {"reverse_v2", "reverse_v2_d"}, {"re_luv2", "relu_v2"}, {"p_re_lu", "prelu"}, {"p_re_lu_grad", "prelu_grad"}, @@ -45,6 +46,7 @@ static std::map tbe_func_adapter_map = { {"reduce_max", "reduce_max_d"}, {"reduce_min", "reduce_min_d"}, {"avg_pool_grad", "avg_pool_grad_d"}, + {"avg_pool_grad_vm", "avg_pool_grad_d"}, {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, {"conv2d_backprop_input", "conv2d_backprop_input_d"}, {"depthwise_conv2d_native", "depthwise_conv2d"}, @@ -63,6 +65,7 @@ static std::map tbe_func_adapter_map = { {"b_n_training_update_grad", "bn_training_update_grad"}, {"b_n_infer", "bn_infer"}, {"b_n_infer_grad", "bn_infer_grad"}, + {"b_n_inference", "bninference_d"}, {"n_pu_clear_float_status", "n_p_u_clear_float_status"}, {"n_pu_get_float_status", "n_p_u_get_float_status"}, {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"}, @@ -81,6 +84,7 @@ static std::map tbe_func_adapter_map = { {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, {"apply_add_sign", "apply_add_sign_d"}, {"apply_power_sign", "apply_power_sign_d"}, + {"apply_centered_rms_prop", "apply_centered_rms_prop_d"}, {"transpose", "transpose_d"}, {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, @@ -124,6 +128,7 @@ static std::map tbe_func_adapter_map = { {"apply_rms_prop", "apply_rms_prop_d"}, {"cum_prod", "cumprod_d"}, {"reduce_all", "reduce_all_d"}, + {"reduce_any", "reduce_any_d"}, {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, {"unsorted_segment_min", "unsorted_segment_min_d"}, {"reduce_prod", "reduce_prod_d"}, diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h index aa09efc11f..f72de02e8f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_ADAPTER_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_ADAPTER_H #include #include @@ -65,4 +65,4 @@ class TbeAdapter { } // namespace tbe } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_ADAPTER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc index e7fd94ef84..34165c4799 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc @@ -21,7 +21,7 @@ #include #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h index dea058cd56..78b6f0c7e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_COMMON_UTILS_H_ #include #include "backend/kernel_compiler/kernel.h" @@ -39,4 +39,4 @@ std::string GetProcessor(const AnfNodePtr &anf_node); } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index 73642b291a..c6b62a1766 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -22,7 +22,6 @@ #include "frontend/parallel/ops_info/ops_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h" -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" @@ -43,6 +42,7 @@ constexpr auto kJInputs = "inputs"; constexpr auto kJOutputs = "outputs"; constexpr auto kJAttrs = "attrs"; constexpr auto kJKernelName = "kernel_name"; +constexpr auto kJFullName = "full_name"; constexpr auto kJOpInfo = "op_info"; constexpr auto kJDtype = "dtype"; constexpr auto kJtype = "type"; @@ -125,6 +125,7 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptrfullname_with_scope(); if (creater_type_ == SINGLE_BUILD) { TbeUtils::SaveJsonInfo(json_name_, json_info_); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h index 768f811055..3a00169632 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_BUILD_H_ #include #include @@ -26,7 +26,6 @@ #include #include "ir/dtype.h" #include "backend/kernel_compiler/kernel.h" -#include "pybind11/stl.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h" @@ -119,4 +118,4 @@ class TbeKernelJsonCreator { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc index e6cb4cf30d..933fcb1566 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc @@ -17,7 +17,7 @@ #include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" #include #include "runtime/rt.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "graphengine/inc/framework/ge_runtime/task_info.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h index de48c83d9b..70d17a02c6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_MOD_H_ #include #include @@ -54,4 +54,4 @@ class TbeKernelMod : public AscendKernelMod { using TbeKernelModPtr = std::shared_ptr; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc index 48223f40c6..7a625268d3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -23,24 +23,18 @@ #include #include -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_build.h" #include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" #include "backend/session/anf_runtime_algorithm.h" #include "./common.h" -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" namespace mindspore { namespace kernel { using mindspore::kernel::tbe::TbeUtils; -constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; -constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler"; -constexpr auto kStartCompileOp = "start_compile_op"; -constexpr auto kWaitOne = "wait_one"; -constexpr auto kResetTaskInfo = "reset_task_info"; bool TbeOpParallelPreBuild(const std::vector &anf_nodes) { auto build_manger = std::make_shared(); @@ -61,14 +55,14 @@ bool TbeOpParallelPreBuild(const std::vector &anf_nodes) { } while (!build_manger->IsAllPreTaskFinish()) { int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; + std::string task_result; + std::string pre_build_result; auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); if (!ret) { MS_EXCEPTION(ArgumentError) << "Pre Build Failed. wait one ret:" << ret << ", task id:" << task_id; } - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + if (task_result != "Success") { MS_EXCEPTION(ArgumentError) << "task pre compile Failed, task id:" << task_id << ", cause:" << task_result; } @@ -116,14 +110,14 @@ bool TbeOpParallelBuild(const std::vector &anf_nodes) { } while (!build_manger->IsAllTaskFinish()) { int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; + std::string task_result; + std::string pre_build_result; auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); if (!ret) { MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; } - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + if (task_result != "Success") { MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; } (void)build_manger->TaskFinishProcess(task_id); @@ -131,43 +125,10 @@ bool TbeOpParallelBuild(const std::vector &anf_nodes) { return build_manger->GenSameOpKernelMod(); } -ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); } +ParallelBuildManager::ParallelBuildManager() {} ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); } -int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const { - PyObject *pRes = nullptr; - PyObject *pArgs = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); - (void)PyTuple_SetItem(pArgs, 0, arg1); - pRes = PyObject_CallMethod(tbe_parallel_compiler_, kStartCompileOp, "O", pArgs); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function start_compile_op"; - } - int task_id; - (void)PyArg_Parse(pRes, "i", &task_id); - MS_LOG(INFO) << "start compile , task id:" << task_id; - return task_id; -} - -bool ParallelBuildManager::WaitOne(int *task_id, char **task_result, char **pre_build_result) const { - MS_LOG(INFO) << "wait task start."; - MS_EXCEPTION_IF_NULL(task_id); - MS_EXCEPTION_IF_NULL(task_result); - PyObject *pRes = nullptr; - PyObject *pArg = Py_BuildValue("()"); - pRes = PyObject_CallMethod(tbe_parallel_compiler_, kWaitOne, "O", pArg); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function wait_one"; - return false; - } - (void)PyArg_ParseTuple(pRes, "iss", task_id, task_result, pre_build_result); - return true; -} - void ParallelBuildManager::SavePreTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node) { MS_LOG(INFO) << "SavePreTaskInfo, task id: " << task_id; pre_task_map_[task_id] = anf_node; @@ -310,6 +271,15 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s return kernel_mod_ptr; } +int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) { + return AscendKernelBuildClient::Instance().TbeStart(kernel_json.dump()); +} + +bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) { + MS_EXCEPTION_IF_NULL(task_id); + return AscendKernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result); +} + void ParallelBuildManager::ResetTaskInfo() { if (task_map_.empty()) { MS_LOG(INFO) << "All tasks are compiled success."; @@ -317,10 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() { } task_map_.clear(); same_op_list_.clear(); - if (tbe_parallel_compiler_ != nullptr) { - PyObject *pArg = Py_BuildValue("()"); - (void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg); - } + AscendKernelBuildClient::Instance().TbeReset(); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h index a29469b47c..a026f186c0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h @@ -14,16 +14,18 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ #include #include #include #include -#include "backend/kernel_compiler/kernel.h" -#include "pybind11/stl.h" #include + +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_build_client.h" + namespace mindspore { namespace kernel { bool TbeOpParallelPreBuild(const std::vector &anf_nodes); @@ -42,7 +44,6 @@ class ParallelBuildManager { public: ParallelBuildManager(); ~ParallelBuildManager(); - int32_t StartCompileOp(const nlohmann::json &kernel_json) const; void SavePreTaskInfo(int32_t task_id, const AnfNodePtr &anf_node); void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name, const std::vector &input_size_list, const std::vector &output_size_list, @@ -54,7 +55,6 @@ class ParallelBuildManager { const std::vector &input_size_list, const std::vector &output_size_list, AnfNode *node) const; - bool WaitOne(int *task_id, char **task_result, char **pre_build_result) const; bool IsAllPreTaskFinish() const; bool IsAllTaskFinish() const; void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result); @@ -62,10 +62,13 @@ class ParallelBuildManager { KernelModPtr GenKernelMod(const string &json_name, const string &processor, const std::vector &input_size_list, const std::vector &output_size_list, const KernelPackPtr &kernel_pack) const; + + // Interactive with real backend, who could be implemented by Python. + int StartCompileOp(const nlohmann::json &kernel_json); + bool WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result); void ResetTaskInfo(); private: - PyObject *tbe_parallel_compiler_; std::map pre_task_map_; std::map task_map_; std::vector same_op_list_; @@ -73,4 +76,4 @@ class ParallelBuildManager { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h index c07197610e..b0002a7033 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_ #include #include namespace mindspore { @@ -27,4 +27,4 @@ using SupportFormatItem = std::vector; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h index 4685df6724..fb5f1f554b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_BROADCAST_SELECTER_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_BROADCAST_SELECTER_H_ #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h index 196bb7b06a..48b3e3a3f1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_REDUCE_SELECTER_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_REDUCE_SELECTER_H_ #include #include #include diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index 21f2347629..5635811425 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -23,14 +23,15 @@ #include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/tbe/tbe_kernel_build.h" #include "nlohmann/json.hpp" -#include "utils/context/ms_context.h" -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/helper.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "frontend/parallel/ops_info/ops_utils.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h" +#include "backend/session/kernel_build_client.h" namespace mindspore { namespace kernel { @@ -59,6 +60,10 @@ void TbeKernelSelect::TbeMetadataInfoEx() { MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; return; } + if (!TbePropertyChecker::CheckTbeProperties(cnode_ptr_)) { + MS_LOG(INFO) << "Warning: node(" << cnode_ptr_->fullname_with_scope() << ") not support tbe aicore."; + return; + } MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_ << ", node name: " << cnode_ptr_->fullname_with_scope(); OpPattern pattern = op_info_ptr->op_pattern(); @@ -77,7 +82,6 @@ void TbeKernelSelect::TbeMetadataInfoEx() { } // check support FilterInVaildKernelInfo(); - MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; } void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { @@ -113,7 +117,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { } builder.SetInputsDeviceType(inputs_device_type); builder.SetInputsFormat(inputs_format); - builder.SetInputReshapeType(inputs_reshape_type); + builder.SetInputsReshapeType(inputs_reshape_type); // output std::vector outputs_format; std::vector outputs_device_type; @@ -124,7 +128,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { } builder.SetOutputsDeviceType(outputs_device_type); builder.SetOutputsFormat(outputs_format); - builder.SetOutputReshapeType(outputs_reshape_type); + builder.SetOutputsReshapeType(outputs_reshape_type); kernel_info_list_->emplace_back(builder.Build()); } MS_LOG(INFO) << "end."; @@ -216,38 +220,37 @@ void TbeKernelSelect::FilterInVaildKernelInfo() { MS_LOG(INFO) << "Warning: get kernel build info failed."; return; } - auto kernel_build_info_iter = kernel_info_list_->begin(); - while (kernel_build_info_iter != kernel_info_list_->end()) { - if (!FilterInVaildShape(kernel_build_info_iter)) { - MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString(); - kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); + std::vector> new_kernel_info_list; + for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) { + if (!FilterInVaildShape(iter)) { + MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString(); continue; } - if (!TbeCheckSupported(kernel_build_info_iter)) { - MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString(); - kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); + if (!TbeCheckSupported(iter)) { + MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); continue; } - kernel_build_info_iter++; + new_kernel_info_list.emplace_back(*iter); } + (*kernel_info_list_) = new_kernel_info_list; } bool TbeKernelSelect::FilterInVaildShape( const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); - auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); + const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); - auto format = kernel_build_info_inputs_format.at(i); + const auto &format = kernel_build_info_inputs_format[i]; if (!IsShapeMatchFormat(shape, format)) { MS_LOG(INFO) << "The " << i << "th input check failed."; return false; } } - auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); + const auto &kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) { auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j); - auto format = kernel_build_info_outputs_format.at(j); + const auto &format = kernel_build_info_outputs_format[j]; if (!IsShapeMatchFormat(shape, format)) { MS_LOG(INFO) << "The " << j << "th input check failed."; return false; @@ -309,7 +312,7 @@ bool TbeKernelSelect::TbeCheckSupported( if (!ret) { MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; } - ret = TbePythonFuncs::CheckSupported(kernel_json); + ret = AscendKernelBuildClient::Instance().CheckSupported(kernel_json.dump()); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); return ret; } @@ -339,12 +342,12 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind size_t io_info_num = ios_info.size(); for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) { std::shared_ptr io_info_item = ios_info[io_info_index]; - auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index); + const auto &kernel_build_info_dtype = io_info_item->dtypes()[kernel_build_info_index]; std::string kernel_build_info_format; if (!io_info_item->formats().empty()) { - kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index); + kernel_build_info_format = io_info_item->formats()[kernel_build_info_index]; } - std::string io_param_type = io_info_item->param_type(); + const std::string &io_param_type = io_info_item->param_type(); std::vector reshape_type; StringToAxisVector(io_info_item->reshape_type(), &reshape_type); if (io_param_type == kParamTypeDynamic) { @@ -362,6 +365,7 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind } dynamic_input_index++; real_io_tensor_index += dynamic_input_size; + } else { if (ios_info.size() != 1) { MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output."; @@ -383,7 +387,6 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type; } } - if (io_info_index != io_info_num) { MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num << "), this node may has optional input/output."; @@ -483,7 +486,7 @@ std::string TbeKernelSelect::OpSelectFormat() { if (!ret) { MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; } - res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); + res_json_str = AscendKernelBuildClient::Instance().SelectFormat(kernel_json.dump()); if (res_json_str.empty()) { MS_LOG(EXCEPTION) << "op select format error."; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc new file mode 100644 index 0000000000..0ce5e77880 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/parallel/ops_info/ops_utils.h" + +namespace mindspore { +namespace kernel { +using CheckSupportFun = bool (*)(const CNodePtr &cnode); + +constexpr char kAttrStrides[] = "strides"; + +static bool CheckStridedSlice(const CNodePtr &cnode) { + // check stride[-1] != 1 TODO + if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) { + auto strides = AnfAlgo::GetNodeAttr>(cnode, kAttrStrides); + if (!strides.empty() && strides[strides.size() - 1] == 1) { + return true; + } + } + // last tensor TODO + return true; +} + +bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + static std::map tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}}; + auto cnode_type = AnfAlgo::GetCNodeName(cnode); + auto find_iter = tbe_property_checker.find(cnode_type); + if (find_iter != tbe_property_checker.end()) { + return find_iter->second(cnode); + } + return true; +} + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h new file mode 100644 index 0000000000..83e6f35de8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_SELECT_TBE_PROPERTY_CHECKER_H +#define MINDSPORE_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_SELECT_TBE_PROPERTY_CHECKER_H +#include "mindspore/core/ir/anf.h" + +namespace mindspore { +namespace kernel { +class TbePropertyChecker { + public: + TbePropertyChecker() = default; + ~TbePropertyChecker() = default; + static bool CheckTbeProperties(const mindspore::CNodePtr &cnode); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_BACKEND_KERNEL_COMPILER_TBE_TBE_KERNEL_SELECT_TBE_PROPERTY_CHECKER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc deleted file mode 100644 index facb07991a..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" -#include "backend/kernel_compiler/tbe/tbe_utils.h" -#include "common/utils.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; -constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; -constexpr auto kOpSelectFormatFunc = "op_select_format"; -constexpr auto kCheckSupportedFunc = "check_supported"; -constexpr auto kTBEException = "TBEException"; - -PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; -PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; -PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr; -PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr; -bool TbePythonFuncs::Init() { - static bool initialized = false; - if (initialized) { - return true; - } - // Initialize cache - TbeUtils::LoadCache(); - - // tbe_process - PyObject *pTbeProcessModule = nullptr; - pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule); - if (pTbeProcessModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module."; - return false; - } - - pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc); - if (pCreateTbeParallelCompilerFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kCreateTbeParallelCompilerFunc << "]."; - return false; - } - - pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr); - if (pTbeCompiler_ == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler."; - return false; - } - - pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc); - if (pOpSelectFormatFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kOpSelectFormatFunc << "]."; - return false; - } - - pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc); - if (pCheckSupportedFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kCheckSupportedFunc << "]."; - return false; - } - initialized = true; - MS_LOG(INFO) << "TbePythonFuncs initialized Success."; - return true; -} - -std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - -std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) { - PyObject *pArg = nullptr; - PyObject *pRet = nullptr; - std::string res_json_str; - - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return res_json_str; - } - - // assembly Args - pArg = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", json_str.c_str())); - if (pArg == nullptr) { - MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; - return res_json_str; - } - - // call functions - if (pOpSelectFormatFunc_ == nullptr) { - MS_LOG(ERROR) << "function is nullptr."; - return res_json_str; - } - - pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg); - if (pRet == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc - << "], function args:" << PyObjectToStr(pArg); - } - - char *pstr = nullptr; - (void)PyArg_Parse(pRet, "s", &pstr); - res_json_str = pstr; - if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str - << " ,function args:" << PyObjectToStr(pArg); - } - return res_json_str; -} - -bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - bool ret = false; - - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return ret; - } - // assembly Args - pArg = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); - (void)PyTuple_SetItem(pArg, 0, arg1); - if (pArg == nullptr) { - MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; - return ret; - } - - // call functions - if (pCheckSupportedFunc_ == nullptr) { - MS_LOG(ERROR) << "function is nullptr."; - return ret; - } - - pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc - << "], function args: " << PyObjectToStr(pArg); - } - if (PyBool_Check(pRes)) { - ret = PyObject_IsTrue(pRes) != 0; - } else { - char *pstr = nullptr; - (void)PyArg_Parse(pRes, "s", &pstr); - std::string res_str = pstr; - if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str - << ", function args: " << PyObjectToStr(pArg); - } - } - - return ret; -} - -PyObject *TbePythonFuncs::TbeParallelCompiler() { - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return nullptr; - } - return pTbeCompiler_; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h deleted file mode 100644 index 4e1746475c..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_PYTHON_FUNCS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_PYTHON_FUNCS_H_ - -#include -#include -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace kernel { -class TbePythonFuncs { - public: - TbePythonFuncs() = default; - ~TbePythonFuncs() = default; - static std::string OpSelectFormat(const nlohmann::json &kernel_json); - static bool CheckSupported(const nlohmann::json &kernel_json); - static PyObject *TbeParallelCompiler(); - - private: - static bool Init(); - static std::string PyObjectToStr(_object *PyObj); - static PyObject *pCreateTbeParallelCompilerFunc_; - static PyObject *pTbeCompiler_; - static PyObject *pOpSelectFormatFunc_; - static PyObject *pCheckSupportedFunc_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_PYTHON_FUNCS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc index 76ef7b08d5..b68d30633d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc @@ -30,7 +30,7 @@ #include "backend/kernel_compiler/oplib/oplib.h" #include "utils/utils.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "runtime/device/kernel_info.h" #include "ir/dtype/type.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h index 39ddaaa73d..7d3f639b5e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_UTILS_H_ #include #include #include @@ -83,4 +83,4 @@ class KernelMeta { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_UTILS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 64d76ab358..c1bc8ec638 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -16,7 +16,6 @@ #include "backend/optimizer/ascend/ascend_backend_optimization.h" #include #include -#include #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/ascend/ir_fission/bn_split.h" #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" @@ -24,6 +23,7 @@ #include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h" #include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h" #include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h" +#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h" #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" #include "backend/optimizer/pass/communication_op_fusion.h" @@ -59,6 +59,7 @@ #include "backend/optimizer/ascend/format_type/insert_trans_op.h" #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" +#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" #include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/pass/optimize_dependence.h" #include "backend/optimizer/pass/erase_visit_attr.h" @@ -87,6 +88,7 @@ #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h" #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" #include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h" #include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" @@ -97,7 +99,11 @@ #include "backend/optimizer/ascend/format_type/modify_ops_attrs.h" #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" -#include "utils/context/ms_context.h" +#include "backend/optimizer/ascend/format_type/remove_internal_output.h" +#include "backend/optimizer/ascend/ir_fission/concat_fission.h" +#include "backend/optimizer/ascend/ir_fission/pack_fission.h" +#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" +#include "utils/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" #include "debug/anf_ir_utils.h" @@ -105,18 +111,9 @@ namespace mindspore { namespace opt { namespace { -void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { +void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) { MS_EXCEPTION_IF_NULL(ir_fusion_pm); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); @@ -127,10 +124,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); @@ -140,6 +133,27 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); +} + +void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { + MS_EXCEPTION_IF_NULL(ir_fusion_pm); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); @@ -147,13 +161,12 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } } // namespace @@ -201,6 +214,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); optimizer->AddPassManager(data_layout_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); @@ -221,7 +235,9 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); optimizer->AddPassManager(mixed_precision_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); @@ -254,9 +270,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); } ir_fusion_pm->AddPass(std::make_shared()); - if (context_ptr->ir_fusion_flag()) { - AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); - } + AddAscendIRFusionRulesPass(ir_fusion_pm.get()); + AddAscendIRFusionPass(ir_fusion_pm.get()); if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { ir_fusion_pm->AddPass(std::make_shared()); @@ -331,8 +346,10 @@ void AscendBackendOptimization(const std::shared_ptr &kern auto other_pm = std::make_shared("other_pm"); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); optimizer->AddPassManager(other_pm); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h index 8194ab467b..a25c8e0e5c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ #include #include "backend/session/kernel_graph.h" namespace mindspore { @@ -35,4 +35,4 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr #include "common/trans.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/optimizer/common/helper.h" #include "utils/utils.h" #include "runtime/device/kernel_info.h" @@ -26,7 +26,7 @@ #include "frontend/operator/ops.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_graph.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { @@ -51,11 +51,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = node; + AnfNodePtr input_node = nullptr; CNodePtr trans_data = nullptr; std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; - std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); + std::vector padding_axis; MS_EXCEPTION_IF_NULL(node); // if insert transdata for input we need to change the input if (is_insert_input) { @@ -66,12 +66,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); input_node = AnfAlgo::GetInputNode(cnode, insert_index); padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); + } else { + input_node = node; + padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); } + + auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0); bool need_padding = false; if (is_insert_input) { - need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size())); } else { - need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size())); } if (!need_padding) { // don't need padding insert transdata only @@ -80,8 +85,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt } else if (is_insert_input) { // if need padding & is input need insert a transdata // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); + auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0)); auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); trans_node = trans_data; @@ -89,8 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt // if need padding & is output need insert a transdata // node -> transdata[padding shape] -> reshape[ori_shape] trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - auto reshape_node = - CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); + auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape); trans_node = reshape_node; } // refresh the transdata's format to ori format & dst format @@ -140,9 +143,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + auto kernel_graph = func_graph->cast(); + size_t out_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); if (output_format == kOpFormat_NC1KHKWHWC0) { MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " @@ -150,8 +154,12 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const } auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); + if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) { + auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) { + kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0); + } + make_tuple_inputs.push_back(trans_op); } else { // No need insert trans op. make_tuple_inputs.push_back(tuple_getitem); @@ -162,15 +170,20 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const } } // namespace void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type) { + const AnfNodePtr &trans_data, const std::vector &reshape_type, + const TypeId &type_id) { MS_EXCEPTION_IF_NULL(trans_data); auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); MS_EXCEPTION_IF_NULL(ori_build_info); auto builder = std::make_shared(ori_build_info); builder->SetInputsFormat({input_format}); - builder->SetInputReshapeType({reshape_type}); - builder->SetOutputReshapeType({reshape_type}); + builder->SetInputsReshapeType({reshape_type}); + builder->SetOutputsReshapeType({reshape_type}); builder->SetOutputsFormat({output_format}); + if (type_id != kTypeUnknown) { + builder->SetOutputsDeviceType({type_id}); + builder->SetInputsDeviceType({type_id}); + } AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); } @@ -178,15 +191,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const bool need_padding, const std::string &op_name) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(input); - std::vector trans_inputs; - auto prim = std::make_shared(op_name); - trans_inputs.push_back(NewValueNode(prim)); - trans_inputs.push_back(input); - CNodePtr trans_node = func_graph->NewCNode(trans_inputs); + CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared(op_name)), input}); MS_EXCEPTION_IF_NULL(trans_node); - auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); if (need_padding) { // if need padding we should set the transdata node's shape to the padding shape + auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, trans_node.get()); @@ -202,6 +211,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, MS_EXCEPTION_IF_NULL(kernel_select); kernel_select->SelectKernel(trans_node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue>({}), trans_node); MS_EXCEPTION_IF_NULL(trans_node); trans_node->set_scope(input->scope()); return trans_node; @@ -213,11 +223,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr MS_EXCEPTION_IF_NULL(func_graph); std::string input_format = format; std::string output_format = format; - std::vector new_cast_inputs; - auto prim = std::make_shared(prim::kPrimCast->name()); - new_cast_inputs.push_back(NewValueNode(prim)); - new_cast_inputs.push_back(input); - CNodePtr cast = func_graph->NewCNode(new_cast_inputs); + CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared(prim::kPrimCast->name())), input}); MS_EXCEPTION_IF_NULL(cast); // set kernel build info kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; @@ -240,6 +246,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue>({}), cast); return cast; } @@ -249,9 +256,14 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP if (outputs_num == 0) { return node; } + auto kernel_graph = func_graph->cast(); // Single output if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) { - return InsertTransOpForSingleOutput(func_graph, node, kernel_select); + auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) { + kernel_graph->ReplaceInternalOutput(node, new_node); + } + return new_node; } // Multiple output return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); @@ -263,7 +275,8 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + size_t in_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < in_num; ++input_index) { AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); MS_EXCEPTION_IF_NULL(input_node); new_inputs.push_back(input_node); @@ -284,8 +297,10 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + size_t in_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < in_num; ++input_index) { + auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index); + const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); TypeId origin_type(kTypeUnknown); auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); @@ -294,20 +309,19 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod // weight origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); if (origin_type == kTypeUnknown) { - origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); + origin_type = AnfAlgo::GetOutputDeviceDataType(prev_node.first, prev_node.second); } } else { // feature map - origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + origin_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); } const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); - const std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); - const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); + const std::vector origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); // In graph kernel, we check parameter, // the eliminate pass will not eliminate this case, so we just do not insert the noused cast. if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode(cur_input)) { new_inputs.push_back(cur_input); - } else if (origin_type != device_type) { + } else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) { auto cast = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); MS_EXCEPTION_IF_NULL(cast); @@ -339,6 +353,7 @@ AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node MS_EXCEPTION_IF_NULL(new_node); new_node->set_abstract(node->abstract()); new_node->set_scope(node->scope()); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue>({}), new_node); return new_node; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index cb308a09a0..4d2833b999 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ #include #include @@ -86,7 +86,8 @@ class OpFinder { using OpFinderPtr = std::shared_ptr; void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); + const AnfNodePtr &trans_data, const std::vector &reshape_type = {}, + const TypeId &type_id = kTypeUnknown); CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, const bool need_padding, const std::string &op_name); @@ -106,4 +107,4 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index 22183c9050..dd0d61e1ac 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h index dfc45b4688..a3b2853755 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc index 59915d43d4..74ef83dcf4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h index abaf264d2e..9ca88959de 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class BnupdateEltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc index 1bfff1b50e..eecda879d9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h index 6bf74d5268..a70ff91ba5 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ #include #include @@ -44,4 +44,4 @@ class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc index 144ab4b53f..d9d6c37848 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h index 93aa324566..a1037621a9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ #include #include @@ -44,4 +44,4 @@ class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc index a2ebfbe79e..d32269fab1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc @@ -23,7 +23,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h index 224422530b..f0a93a0b0a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class ConvBnReduceFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_CONV_BNREDUCE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_CONV_BNREDUCE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc index 1a67e3c39b..d5626bc414 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h index 911cf744de..dbeae26170 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ #include #include @@ -44,4 +44,4 @@ class ConvDoubleInFusionPass : public FusionBasePass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc index 1eb26b12bc..f31d8f7580 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h index 6dddd600c2..202ddb0a7a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class ConvSingleInFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc index 285b8f6c07..6e48048b28 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc @@ -23,7 +23,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h index 6746dad984..43652f6057 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class DepthwiseConvEltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc index 1e24cce0e4..865ae57349 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h index ae63687631..ea126594eb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ #include #include @@ -43,4 +43,4 @@ class EltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc index 27a7a786d1..f66dbeda99 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc @@ -17,7 +17,7 @@ #include #include #include "debug/anf_ir_dump.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" #include "backend/session/anf_runtime_algorithm.h" diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h index dced2c2fa2..024ce416e3 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ #include #include #include @@ -68,4 +68,4 @@ class FusionBasePass : public Pass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc index 7fcc6e45e0..f026bc8b01 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h index e0d08bb58d..9ca2a00395 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class MatmulEltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc index 58a219aec7..bf29ca1416 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h index 40a45360a1..34533684f2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class MultiOutputFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc index 95955818eb..741be4835d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc @@ -23,7 +23,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h index 4d56eee7b3..f238f425f5 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class ReduceEltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWSIE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_REDUCE_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc index f2117f9374..6ae541c8cb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc @@ -22,7 +22,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h index f3b97f8357..011f134996 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class SegmentEltwiseFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWSIE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_SEGMENT_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc index d93b47b66c..a6222a13ae 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc @@ -23,7 +23,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h index 371c206399..419e701274 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ #include #include @@ -45,4 +45,4 @@ class StridedReadConvStridedWriteFusionPass : public FusionBasePass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index 9685530705..353d3f080a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -28,7 +28,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" #include "runtime/device/kernel_info.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { @@ -120,8 +120,8 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector inputs_data_type; for (const auto &input : inputs_list) { auto real_input = AnfAlgo::VisitKernel(input, 0); - inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); - inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); + inputs_format.emplace_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); + inputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); } // outputs format and data type std::vector outputs_format; @@ -130,13 +130,13 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vectorname()) { auto tuple_getitem = output->cast(); MS_EXCEPTION_IF_NULL(tuple_getitem); - outputs_format.push_back(AnfAlgo::GetOutputFormat( + outputs_format.emplace_back(AnfAlgo::GetOutputFormat( tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( + outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType( tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); } else { - outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); + outputs_format.emplace_back(AnfAlgo::GetOutputFormat(output, 0)); + outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); } } builder.SetInputsFormat(inputs_format); @@ -229,7 +229,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, for (auto &buffer_fusion_info : *buffer_fusion_infos) { auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; + const auto &fusion_info = buffer_fusion_info.second; for (const auto &node : fusion_info.anf_nodes) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -237,10 +237,10 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == fusion_info.anf_nodes.end()) { - if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), - (*buffer_fusion_infos)[fusion_id].inputs_list.end(), - cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { - (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); + if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), + (*buffer_fusion_infos)[fusion_id].inputs_list.end(), + in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { + (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in); } } } @@ -277,7 +277,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, for (auto &buffer_fusion_info : *buffer_fusion_infos) { auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; + const auto &fusion_info = buffer_fusion_info.second; for (const auto &node : fusion_info.anf_nodes) { if (AnfAlgo::GetOutputTensorNum(node) == 1) { for (auto use_node : manager->node_users()[node]) { @@ -294,7 +294,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, std::back_inserter(tuple_getitem_nodes), [](const std::pair &use_node) { return use_node.first; }); std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); - for (auto getitem : tuple_getitem_nodes) { + for (auto &getitem : tuple_getitem_nodes) { MS_EXCEPTION_IF_NULL(getitem); auto getitem_ptr = getitem->cast(); auto input2 = getitem_ptr->input(2); @@ -304,7 +304,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); } prev_idx = output_idx + 1; - for (auto item_use_node : manager->node_users()[getitem]) { + for (auto &item_use_node : manager->node_users()[getitem]) { if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == fusion_info.anf_nodes.end()) { (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); @@ -365,31 +365,25 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph MS_EXCEPTION_IF_NULL(kernel_graph); bool change = false; std::unordered_map buffer_fusion_infos; - buffer_fusion_infos.clear(); GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); std::vector fusion_scope_infos; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - mindspore::kernel::FusionScopeInfo fusion_scope_info; - fusion_scope_info.scope_id = buffer_fusion_info.first; - fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; - fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; - fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; - fusion_scope_infos.push_back(fusion_scope_info); -#ifdef DEBUG - DumpFusionScopeInfo(fusion_scope_info); -#endif - } + std::transform( + buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos), + [](const std::pair &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo { + return mindspore::kernel::FusionScopeInfo(buffer_fusion_info.first, buffer_fusion_info.second.inputs_list, + buffer_fusion_info.second.anf_nodes, + buffer_fusion_info.second.outputs_list); + }); auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); - std::vector fusion_ids; + std::set fusion_ids; for (auto &buffer_fusion_info : buffer_fusion_infos) { MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); - fusion_ids.push_back(buffer_fusion_info.first); + fusion_ids.insert(buffer_fusion_info.first); } // Replace fusion op from return to head - std::sort(fusion_ids.begin(), fusion_ids.end()); for (auto &fusion_id : fusion_ids) { // Get kernel mod when supporting tbe if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { @@ -414,9 +408,10 @@ bool UbPatternFusion::ReplaceFusionOp(std::unordered_map types; std::vector> shapes; for (const auto &out_node : buffer_fusion_info.outputs_list) { - for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { - types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); - shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); + size_t out_num = AnfAlgo::GetOutputTensorNum(out_node); + for (size_t idx = 0; idx < out_num; ++idx) { + types.emplace_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(out_node, idx)); } } if (types.empty() || shapes.empty()) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h index 69eb0f43d4..b507a527b4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ #include #include #include @@ -47,4 +47,4 @@ class UbPatternFusion : public Pass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc new file mode 100644 index 0000000000..2ca18ec7e9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +void AddOutputs(const AnfNodePtr &node, int rank_size) { + MS_EXCEPTION_IF_NULL(node); + auto origin_abstract = node->abstract(); + MS_EXCEPTION_IF_NULL(origin_abstract); + auto tuple_abstract = origin_abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + auto &origin_abstracts = tuple_abstract->elements(); + AbstractBasePtrList abstract_list; + std::vector outputs_device_type; + std::vector outputs_device_format; + for (int i = 0; i < rank_size; ++i) { + for (size_t j = 0; j < origin_abstracts.size(); ++j) { + abstract_list.push_back(origin_abstracts[j]); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j)); + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j)); + } + } + // Update abstract + auto new_abstracts = std::make_shared(abstract_list); + node->set_abstract(new_abstracts); + // Update kernel build info + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + builder->SetOutputsDeviceType(outputs_device_type); + builder->SetOutputsFormat(outputs_device_format); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); +} +} // namespace + +AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::vector &new_tuple_getitems, + int rank_size) const { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector make_tuple_inputs; + size_t inputs_size = AnfAlgo::GetInputTensorNum(node); + for (size_t i = 0; i < inputs_size; ++i) { + for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) { + std::vector concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + concat_inputs.push_back(new_tuple_getitems[idx]); + auto concat = func_graph->NewCNode(concat_inputs); + MS_EXCEPTION_IF_NULL(concat); + MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]); + concat->set_abstract(new_tuple_getitems[idx]->abstract()); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); + std::vector dyn_input_size{rank_size}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); + kernel_select_->SelectKernel(concat); + make_tuple_inputs.push_back(concat); + } + } + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} + +const BaseRef ConcatOutputsForAllGather::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto prim = std::make_shared(kAllGatherOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { + return nullptr; + } + auto fusion = AnfAlgo::GetNodeAttr(cnode, kAttrFusion); + if (fusion <= 0) { + return nullptr; + } + auto rank_size = AnfAlgo::GetNodeAttr(node, kAttrRankSize); + AddOutputs(node, rank_size); + std::vector new_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); + return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h new file mode 100644 index 0000000000..7b4d0d5427 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class ConcatOutputsForAllGather : public PatternProcessPass { + public: + explicit ConcatOutputsForAllGather(bool multigraph = true) + : PatternProcessPass("concat_outputs_for_all_gather", multigraph), + kernel_select_(std::make_shared()) {} + ~ConcatOutputsForAllGather() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::vector &new_tuple_getitems, int rank_size) const; + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc index a729cdd0f9..a1d957f72c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc @@ -56,12 +56,19 @@ const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, co return nullptr; } - // 3. next_node is not nop node and it has only one input which is memcpy's output + // 3. next_node is not nop node, not graph output and it has only one input which is memcpy's output for (auto &item : next_nodes) { auto next_node = item.first->cast(); if (opt::IsNopNode(next_node)) { return nullptr; } + + auto graph_outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); + auto iter = std::find(graph_outputs.begin(), graph_outputs.end(), next_node); + if (iter != graph_outputs.end()) { + return nullptr; + } + if (next_node->inputs().size() != 2) { MS_LOG(DEBUG) << "next node has more than one input"; return nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h index 365088b34a..eef34ff238 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class GetnextMemcpyElimination : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc new file mode 100644 index 0000000000..0f1946926d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h" +#include +#include +#include +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +namespace { +bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(cur_hccl); + MS_EXCEPTION_IF_NULL(graph); + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(prev_node); + if (!AnfAlgo::IsCommunicationOp(prev_node)) { + return false; + } + auto prev_hccl_op = prev_node->cast(); + MS_EXCEPTION_IF_NULL(prev_hccl_op); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(prev_hccl_op); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + for (const auto &node_index : iter->second) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + bool is_contain = false; + for (size_t i = 1; i < cur_hccl->size(); ++i) { + if (cur_hccl->input(i) == output) { + is_contain = true; + break; + } + } + if (!is_contain) { + return true; + } + } + } + return false; +} +} // namespace + +AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(hccl_node); + std::vector memcpy_async_list; + std::vector new_inputs = {hccl_node->input(0)}; + for (size_t i = 1; i < hccl_node->size(); ++i) { + auto input = hccl_node->input(i); + MS_EXCEPTION_IF_NULL(input); + // when input is also a hccl op and just part outputs of it linking with cur_hccl_op + if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) { + auto memcpy_async = CreateMemcpyAsyncOp(graph, input); + auto kernel_info = std::make_shared(); + memcpy_async->set_kernel_info(kernel_info); + MS_EXCEPTION_IF_NULL(kernel_select_); + kernel_select_->SelectKernel(memcpy_async->cast()); + new_inputs.push_back(memcpy_async); + memcpy_async_list.push_back(memcpy_async); + } else { + new_inputs.push_back(input); + } + } + + if (!memcpy_async_list.empty()) { + CNodePtr new_hccl_node = std::make_shared(*hccl_node); + new_hccl_node->set_inputs(new_inputs); + return new_hccl_node; + } + return nullptr; +} + +const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (func_graph == nullptr || node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + if (!AnfAlgo::IsCommunicationOp(node)) { + return nullptr; + } + return InsertMemcpyAsync(func_graph, cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h new file mode 100644 index 0000000000..e1a29f5741 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertMemcpyAsyncForCascade : public PatternProcessPass { + public: + explicit InsertMemcpyAsyncForCascade(bool multigraph = true) + : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph), + kernel_select_(std::make_shared()) {} + ~InsertMemcpyAsyncForCascade() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_OP_CASCADE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h index 6fefc32230..bd4c406310 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ #include "backend/optimizer/common/optimizer.h" @@ -32,4 +32,4 @@ class InsertMemcpyAsyncForGetNext : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc index 2585006be6..b0bdfd30cd 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -32,12 +32,49 @@ const std::set kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe bool IsParameterOrValueNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - return kernel_with_index.first->isa() || kernel_with_index.first->isa(); + auto real_node = kernel_with_index.first; + MS_EXCEPTION_IF_NULL(real_node); + if (real_node->isa()) { + return true; + } + return real_node->isa(); +} + +void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node, + const std::vector &memcpy_async_list) { + MS_EXCEPTION_IF_NULL(control_depend); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(hccl_node); + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); + make_tuple_inputs.emplace_back(hccl_node); + auto make_tuple = graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + control_depend->set_input(IntToSize(index), make_tuple); +} + +void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node, + const std::vector &memcpy_async_list) { + MS_EXCEPTION_IF_NULL(tuple_getitem); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(tuple_getitem); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + for (const auto &node_index : iter->second) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { + SetInput(output->cast(), node_index.second, graph, hccl_node, memcpy_async_list); + } + } } -void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { +void TransferControl(const CNodePtr &hccl_node, const std::vector &memcpy_async_list, + const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(hccl_node); - MS_EXCEPTION_IF_NULL(memcpy_async); MS_EXCEPTION_IF_NULL(graph); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -49,48 +86,49 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, // find hccl_node's output which is a control depend for (const auto &node_index : iter->second) { AnfNodePtr output = node_index.first; - int output_index = node_index.second; + MS_EXCEPTION_IF_NULL(output); if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { - CNodePtr control_depend = output->cast(); - MS_EXCEPTION_IF_NULL(control_depend); - std::vector new_inputs; - for (size_t i = 0; i < control_depend->size(); ++i) { - if (i == IntToSize(output_index)) { - new_inputs.push_back(memcpy_async); - } else { - new_inputs.push_back(control_depend->input(i)); - } - } - control_depend->set_inputs(new_inputs); + SetInput(output->cast(), node_index.second, graph, hccl_node, memcpy_async_list); + } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) { + DealControlForGetitem(output->cast(), graph, hccl_node, memcpy_async_list); } } } } // namespace -bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { +bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, + const CNodePtr &cur_node) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(cur_node); // when input is a parameter or is a value node if (IsParameterOrValueNode(input)) { return true; } - // when input is a Ref or some special cnodes - if (kernel_query_->IsTbeRef(input) || - kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { - return true; - } + if (input->isa()) { + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - auto iter = node_users.find(input); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - // when input is used by others - if (iter->second.size() > 1) { - return true; + // when input is a Ref cnode + if (kernel_query_->IsTbeRef(input)) { + return true; + } + + // when input is some special cnodes + if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { + return true; + } + + // when input is used by others + auto iter = node_users.find(input); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + if (iter->second.size() > 1) { + return true; + } } return false; } @@ -98,21 +136,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(hccl_node); - bool has_insert_memcpy = false; - AnfNodePtr memcpy_async = nullptr; + std::vector memcpy_async_list; std::vector new_inputs = {hccl_node->input(0)}; for (size_t i = 1; i < hccl_node->size(); ++i) { auto input = hccl_node->input(i); - if (NeedInsertMemcpy(graph, input)) { - memcpy_async = CreateMemcpyAsyncOp(graph, input); - has_insert_memcpy = true; + if (NeedInsertMemcpy(graph, input, hccl_node)) { + auto memcpy_async = CreateMemcpyAsyncOp(graph, input); new_inputs.push_back(memcpy_async); + memcpy_async_list.push_back(memcpy_async); } else { new_inputs.push_back(input); } } - if (has_insert_memcpy) { + if (!memcpy_async_list.empty()) { CNodePtr new_hccl_node = std::make_shared(*hccl_node); new_hccl_node->set_inputs(new_inputs); auto manager = graph->manager(); @@ -122,9 +159,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co MS_LOG(DEBUG) << "end replace"; // transer hccl op's control to the memcpy_async - if (hccl_node->size() == 2) { - TransferControl(new_hccl_node, memcpy_async, graph); - } + TransferControl(new_hccl_node, memcpy_async_list, graph); } } @@ -133,11 +168,10 @@ const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_gr if (func_graph == nullptr || node == nullptr || !node->isa()) { return nullptr; } - auto cnode = node->cast(); if (!AnfAlgo::IsCommunicationOp(node)) { return nullptr; } - InsertMemcpyAsync(func_graph, cnode); + InsertMemcpyAsync(func_graph, node->cast()); return nullptr; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h index 7bd730a84d..e69866c0b9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -32,9 +32,9 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { private: void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; - bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; + bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const; KernelQueryPtr kernel_query_; }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h index 6aed678ff2..dadf1f1384 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/pass.h" @@ -32,4 +32,4 @@ class InsertPadForNMSWithMask : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc index f508bb2868..03c4618069 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc @@ -22,7 +22,7 @@ #include "utils/utils.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/common_utils.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h index 6bf1287ae7..4be8d3faf6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class ChangeAxisOfReduceKernel : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc index 7da0027310..24fda289ac 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc @@ -21,7 +21,7 @@ #include "utils/utils.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/common_utils.h" namespace mindspore { @@ -30,12 +30,13 @@ namespace { bool CheckFormatForConsistency(const CNodePtr &node, const size_t input_index) { MS_EXCEPTION_IF_NULL(node); // get prior node's device output format - string pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(node, input_index); + auto prev_node = AnfAlgo::GetPrevNodeOutput(node, input_index); + string pre_output_format = AnfAlgo::GetOutputFormat(prev_node.first, prev_node.second); string selected_input_format = AnfAlgo::GetInputFormat(node, input_index); if (pre_output_format == selected_input_format) { return true; } - auto input_origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, input_index); + auto input_origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); if (pre_output_format == kOpFormat_DEFAULT || selected_input_format == kOpFormat_DEFAULT) { string checking_format = (pre_output_format == kOpFormat_DEFAULT) ? selected_input_format : pre_output_format; // when input shape size is 1D, default format and NC1HWC0 are compatible @@ -87,7 +88,8 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt for (auto &t : todos) { CNodePtr cnode = t->cast(); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { + size_t in_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t i = 0; i < in_num; ++i) { if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "[" << cnode->DebugString() << "]"; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h index bf956895de..ab016757a8 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class CheckConsistency : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc index 4375a08031..88a8e7a9c0 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP if (origin_format != cur_format && cur_shape.size() > 1) { auto kernel_select = std::make_shared(); final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(cur_format, origin_format, final_node); + RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); final_index = 0; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h index cb3b13dc49..da85844db8 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ #include "ir/anf.h" #include "backend/optimizer/common/optimizer.h" @@ -33,4 +33,4 @@ class DealRefTransAndCast : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc index c3f7900645..11dcb14855 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc @@ -39,8 +39,10 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo MS_EXCEPTION_IF_NULL(cnode); std::vector make_tuple_inputs; AbstractBasePtrList abstract_list; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { + make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + auto kernel_graph = func_graph->cast(); + size_t out_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { AnfNodePtr replace_node = nullptr; const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); @@ -64,13 +66,16 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo MS_EXCEPTION_IF_NULL(replace_node); replace_node->set_scope(cnode->scope()); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) { + kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0); + } } else { replace_node = getitem; } } else { replace_node = getitem; } - abstract_list.push_back(replace_node->abstract()); + abstract_list.emplace_back(replace_node->abstract()); make_tuple_inputs.push_back(replace_node); } AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); @@ -87,6 +92,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c return cnode; } MS_EXCEPTION_IF_NULL(cnode->Type()); + auto kernel_graph = func_graph->cast(); // Single output if (!cnode->Type()->isa()) { if (!need_insert_cast[0]) { @@ -109,6 +115,9 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c MS_EXCEPTION_IF_NULL(replace_node); replace_node->set_scope(cnode->scope()); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) { + kernel_graph->ReplaceInternalOutput(cnode, replace_node); + } } return replace_node; } @@ -188,6 +197,10 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo CNodePtr cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); auto new_node = InsertCastForInput(func_graph, cnode); + auto kernel_graph = func_graph->cast>(); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { + kernel_graph->ReplaceInternalOutput(node, new_node); + } // process output return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h index 19c282aac9..c628ee7c08 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -33,4 +33,4 @@ class InsertCast : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc index a22a1faa5f..9788db6773 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -22,7 +22,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" #include "backend/kernel_compiler/oplib/oplib.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { @@ -46,14 +46,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { return nullptr; } - AnfNodePtr front_node; + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + MS_LOG(DEBUG) << "process op: " << node->DebugString(); + AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); auto kernel_graph = func_graph->cast>(); if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { - front_node = kernel_graph->GetFrontNodeByInternalOutput(node); + kernel_graph->ReplaceInternalOutput(node, new_node); } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - MS_LOG(DEBUG) << "====process op: " << node->DebugString(); - AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { @@ -61,12 +60,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An return new_node; } } - auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); - if (kernel_graph != nullptr && front_node != nullptr) { - auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); - kernel_graph->ReplaceInternalOutput(old_node, final_node); - } - return final_node; + return InsertTransOpForOutput(func_graph, new_node, kernel_select_); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h index 0b21375327..284988642f 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ #include #include @@ -40,4 +40,4 @@ class InsertTransOp : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h index 82ff5f2b9a..67b8f823c9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ #include #include @@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h index d0e467b7a3..e0c4cd3cff 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H #include #include "backend/optimizer/common/optimizer.h" @@ -37,4 +37,4 @@ class MergeCastToOp : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h index f5608db05a..7a1473b3ab 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class ModifyOpAttrs : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc index 91b9326cc1..cff8ce5673 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -25,7 +25,7 @@ #include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "backend/kernel_compiler/common_utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/helper.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h index cc9333a013..e2b55ae75e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H #include #include #include @@ -44,4 +44,4 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc new file mode 100644 index 0000000000..530e180ed2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/remove_internal_output.h" +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +bool UsedForOutputOnly(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(node); + if (iter == node_users.end()) { + return false; + } + const auto &node_set = iter->second; + for (const auto &node_index : node_set) { + if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimMakeTuple)) { + return false; + } + } + return true; +} +} // namespace +const BaseRef RemoveInternalOutputTransOp::DefinePattern() const { + VarPtr X = std::make_shared(); + auto prim = std::make_shared(kTransDataOpName); + return VectorRef({prim, X}); +} + +const BaseRef RemoveInternalOutputCast::DefinePattern() const { + VarPtr X = std::make_shared(); + return VectorRef({prim::kPrimCast, X}); +} + +const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto kernel_graph = func_graph->cast(); + if (kernel_graph == nullptr) { + return nullptr; + } + if (!kernel_graph->IsInternalOutput(node, 0)) { + return nullptr; + } + if (!UsedForOutputOnly(func_graph, node)) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + CheckCNodeInputSize(cnode, kTransOpInputNum); + auto input_node = cnode->input(1); + if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) { + kernel_graph->ReplaceInternalOutput(node, input_node); + } else { + auto tuple_getitem = input_node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + int idx = AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem); + AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem); + kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx); + } + return input_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h new file mode 100644 index 0000000000..0af2482ea2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class RemoveInternalOutput : public PatternProcessPass { + public: + explicit RemoveInternalOutput(const std::string &name, bool multigraph = true) + : PatternProcessPass(name, multigraph) {} + ~RemoveInternalOutput() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class RemoveInternalOutputTransOp : public RemoveInternalOutput { + public: + explicit RemoveInternalOutputTransOp(bool multigraph = true) + : RemoveInternalOutput("remove_internal_output_trans_op", multigraph) {} + ~RemoveInternalOutputTransOp() override = default; + const BaseRef DefinePattern() const override; +}; + +class RemoveInternalOutputCast : public RemoveInternalOutput { + public: + explicit RemoveInternalOutputCast(bool multigraph = true) + : RemoveInternalOutput("remove_internal_output_cast", multigraph) {} + ~RemoveInternalOutputCast() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h index 135f11f52c..fb1f45ad61 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class RemoveNoUseReshapeOp : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc new file mode 100644 index 0000000000..92f6c5799b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +const BaseRef SplitUnsupportedTransData::DefinePattern() const { + VarPtr X = std::make_shared(); + return VectorRef({prim::KPrimTransData, X}); +} + +const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + auto ori_trans_data = node->cast(); + if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) { + return nullptr; + } + auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data); + MS_EXCEPTION_IF_NULL(kernel_info); + if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) { + MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1" + << ori_trans_data->DebugString(); + } + return SplitTransData(func_graph, ori_trans_data); +} +AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const { + auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node); + if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || + kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { + return trans_node; + } + auto builder_info_to_default = std::make_shared(kernel_info); + auto builder_info_to_special_foramt = std::make_shared(kernel_info); + builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); + builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); + std::vector next_trans_node_inputs = { + NewValueNode(std::make_shared(prim::KPrimTransData->name())), trans_node}; + auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); + next_trans_node->set_abstract(trans_node->abstract()); + AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get()); + AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); + return next_trans_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h new file mode 100644 index 0000000000..d4df2b57a8 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class SplitUnsupportedTransData : public PatternProcessPass { + public: + explicit SplitUnsupportedTransData(bool multigraph = true) + : PatternProcessPass("split_unsupported_transdata", multigraph) {} + ~SplitUnsupportedTransData() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc index a3fd704bc5..860f6c397a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc @@ -27,7 +27,7 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ MS_EXCEPTION_IF_NULL(origin_addn_cnode); std::vector new_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; for (size_t i = begin_index; i < begin_index + offset; ++i) { - new_addn_inputs.push_back(origin_addn_cnode->input(i)); + new_addn_inputs.emplace_back(origin_addn_cnode->input(i)); } CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs); MS_EXCEPTION_IF_NULL(new_addn); @@ -66,7 +66,7 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN cur_input_index += inputs_divisor_; } for (size_t i = cur_input_index; i <= origin_input_size; i++) { - base_addn_inputs.push_back(new_cnode->input(i)); + base_addn_inputs.emplace_back(new_cnode->input(i)); } CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); MS_EXCEPTION_IF_NULL(base_addn); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h index e04cdfdf7b..9df122f1d0 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_ADDN_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_ADDN_FISSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -34,4 +34,4 @@ class AddnFission : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_ADDN_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc index f0edefd5f5..728ced95be 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc @@ -37,7 +37,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s } size_t output_num = 0; for (const auto &node_index : manager->node_users()[bn]) { - AnfNodePtr output = node_index.first; + const AnfNodePtr &output = node_index.first; MS_EXCEPTION_IF_NULL(output); if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { continue; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h index 23f0e56035..0bc29376d0 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class BatchNormBertFission : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc index 97c67e4441..ffecb745b6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc @@ -32,7 +32,7 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { return false; } for (const auto &node_index : manager->node_users()[node]) { - AnfNodePtr output = node_index.first; + const AnfNodePtr &output = node_index.first; MS_EXCEPTION_IF_NULL(output); if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { continue; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h index 97100de284..8180bf2cd7 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -47,4 +47,4 @@ class BatchNormGradInferFission : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc index 97122386c6..260fc90f3d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc @@ -20,8 +20,8 @@ #include #include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" +#include "utils/ms_context.h" +#include "utils/ms_utils.h" #include "backend/optimizer/common/helper.h" #include "runtime/device/kernel_info.h" #include "backend/session/anf_runtime_algorithm.h" @@ -33,7 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra std::vector *bn_update_grad_outputs) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); + const auto &bn_grad_inputs = bn_grad_node->inputs(); if (bn_grad_inputs.size() < kBNGradInputNum) { MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; } @@ -58,7 +58,8 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra std::vector *bn_reduce_grad_outputs) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); + MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs); + const auto &bn_grad_inputs = bn_grad_node->inputs(); if (bn_grad_inputs.size() < kBNGradInputNum) { MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h index e5378d8332..d217a82fe6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/helper.h" @@ -30,4 +30,4 @@ class BatchNormGradSplit : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc index 6c4e226120..45d629d4e9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc @@ -20,8 +20,8 @@ #include #include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" +#include "utils/ms_context.h" +#include "utils/ms_utils.h" #include "backend/optimizer/common/helper.h" #include "runtime/device/kernel_info.h" #include "backend/session/anf_runtime_algorithm.h" diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h index 6fe78d4724..2e5b512dd1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/helper.h" @@ -30,4 +30,4 @@ class BnGradSplit : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc index 33670e5703..8ebbd9269a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc @@ -20,7 +20,7 @@ #include #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/optimizer/common/helper.h" #include "runtime/device/kernel_info.h" #include "backend/session/anf_runtime_algorithm.h" diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h index 4340ba0af6..d14d1357a6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_ #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/helper.h" @@ -30,4 +30,4 @@ class BnSplit : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc new file mode 100644 index 0000000000..63fb40fa8b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/concat_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index, + size_t offset) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(origin_concat_cnode); + std::vector new_concat_inputs = {NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + for (size_t i = begin_index; i < begin_index + offset; ++i) { + new_concat_inputs.emplace_back(origin_concat_cnode->input(i)); + } + CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs); + MS_EXCEPTION_IF_NULL(new_concat); + new_concat->set_scope(origin_concat_cnode->scope()); + // Set attrs + AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat); + AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(offset)), new_concat); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat); + // infer shape + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0); + auto axis = AnfAlgo::GetNodeAttr(origin_concat_cnode, kAttrAxis); + if (axis < 0) { + axis += input_shape.size(); + } + auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0); + if (axis < 0 || axis >= SizeToInt(output_shape.size()) || axis >= SizeToInt(input_shape.size())) { + MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"; + } + output_shape[axis] = input_shape[axis] * offset; + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape}, + new_concat.get()); + return new_concat; +} +} // namespace + +const BaseRef ConcatFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimConcat, Xs}); +} + +const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // The real input begins with index 1. + size_t origin_input_size = cnode->inputs().size() - 1; + if (origin_input_size <= inputs_divisor_) { + return nullptr; + } + CNodePtr new_cnode = cnode; + while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); + std::vector base_concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + size_t cur_input_index = 1; + // Divide the inputs of concat by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { + base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_)); + cur_input_index += inputs_divisor_; + } + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_concat_inputs.emplace_back(new_cnode->input(i)); + } + CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs); + MS_EXCEPTION_IF_NULL(base_concat); + base_concat->set_scope(new_cnode->scope()); + base_concat->set_abstract(new_cnode->abstract()); + // Set attrs + AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat); + AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + std::vector dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat); + + new_cnode = base_concat; + origin_input_size = base_concat->inputs().size() - 1; + } + + return new_cnode; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h new file mode 100644 index 0000000000..21f652b5af --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_CONCAT_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_CONCAT_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr size_t kConcatInputsDivisor = 63; +class ConcatFission : public PatternProcessPass { + public: + explicit ConcatFission(bool multigraph = true) + : PatternProcessPass("concat_fission", multigraph), inputs_divisor_(kConcatInputsDivisor) {} + ~ConcatFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + size_t inputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_CONCAT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc index e8a778b36f..563e44b501 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc @@ -31,9 +31,8 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum; } - std::vector inputs = {NewValueNode(std::make_shared(kSquareSumAllOpName))}; - inputs.push_back(lars_v2->input(1)); - inputs.push_back(lars_v2->input(2)); + std::vector inputs = {NewValueNode(std::make_shared(kSquareSumAllOpName)), lars_v2->input(1), + lars_v2->input(2)}; auto square_sum_all = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(square_sum_all); square_sum_all->set_scope(lars_v2->scope()); @@ -56,13 +55,13 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2, if (lars_v2->size() != kLarsV2InputNum) { MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum; } - std::vector inputs = {NewValueNode(std::make_shared(kLarsV2UpdateOpName))}; - inputs.push_back(lars_v2->input(1)); - inputs.push_back(lars_v2->input(2)); - inputs.push_back(square_sum_all_outputs[0]); - inputs.push_back(square_sum_all_outputs[1]); - inputs.push_back(lars_v2->input(3)); - inputs.push_back(lars_v2->input(4)); + std::vector inputs = {NewValueNode(std::make_shared(kLarsV2UpdateOpName)), + lars_v2->input(1), + lars_v2->input(2), + square_sum_all_outputs[0], + square_sum_all_outputs[1], + lars_v2->input(3), + lars_v2->input(4)}; auto lars_v2_update = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(lars_v2_update); lars_v2_update->set_scope(lars_v2->scope()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h index 3a165f2b29..eca5e56897 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class LarsV2Fission : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc index 1d19def787..15e04b697b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc @@ -22,7 +22,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" #include "ir/primitive.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/utils.h" namespace mindspore { @@ -32,6 +32,7 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop( std::vector *layer_norm_x_backprop_outputs) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(layer_norm_grad); + MS_EXCEPTION_IF_NULL(layer_norm_x_backprop_outputs); auto prim = std::make_shared(kLayerNormXBackpropOpName); std::vector layer_norm_x_backprop_inputs = {NewValueNode(prim)}; for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h index c1501b1593..a8f6c418c3 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ #include #include @@ -39,4 +39,4 @@ class LayerNormGradSplit : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc new file mode 100644 index 0000000000..2037f61b4b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/pack_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index, + size_t offset) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(origin_pack_cnode); + std::vector new_pack_inputs{NewValueNode(std::make_shared(prim::kPrimPack->name()))}; + for (size_t i = begin_index; i < begin_index + offset; ++i) { + new_pack_inputs.push_back(origin_pack_cnode->input(i)); + } + CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs); + MS_EXCEPTION_IF_NULL(new_pack); + new_pack->set_scope(origin_pack_cnode->scope()); + new_pack->set_abstract(origin_pack_cnode->abstract()); + AnfAlgo::CopyNodeAttr(kAttrAxis, origin_pack_cnode, new_pack); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_pack); + AnfAlgo::SetNodeAttr(kAttrNum, MakeValue(SizeToInt(offset)), new_pack); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack); + // infer shape + auto output_shape = AnfAlgo ::GetOutputInferShape(origin_pack_cnode, 0); + auto axis = AnfAlgo::GetNodeAttr(new_pack, kAttrAxis); + if (axis < 0) { + axis += output_shape.size(); + } + if (axis < 0) { + MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"; + } + std::vector new_shape; + for (size_t i = 0; i < output_shape.size() + 1; ++i) { + if (i < IntToSize(axis)) { + new_shape.push_back(output_shape[i]); + } else if (i == IntToSize(axis)) { + new_shape.push_back(offset); + } else { + new_shape.push_back(output_shape[SizeToInt(i) - 1]); + } + } + new_shape.erase(new_shape.begin() + axis + 1); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape}, + new_pack.get()); + return new_pack; +} +} // namespace + +const BaseRef PackFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimPack, Xs}); +} + +const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // The real input begins with index 1. + size_t origin_input_size = cnode->inputs().size() - 1; + if (origin_input_size <= inputs_divisor_) { + return nullptr; + } + std::vector base_concat_inputs{NewValueNode(std::make_shared(prim::kPrimConcat->name()))}; + size_t cur_input_index = 1; + // Divide the inputs of pack by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { + base_concat_inputs.emplace_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_)); + cur_input_index += inputs_divisor_; + } + if (cur_input_index <= origin_input_size) { + base_concat_inputs.emplace_back( + CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1)); + } + + CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs); + MS_EXCEPTION_IF_NULL(base_concat); + base_concat->set_scope(cnode->scope()); + base_concat->set_abstract(cnode->abstract()); + AnfAlgo::CopyNodeAttr(kAttrAxis, cnode, base_concat); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat); + std::vector dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat); + + return base_concat; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.h new file mode 100644 index 0000000000..c11c98698d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_PACK_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_PACK_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr size_t kPackInputsDivisor = 63; +class PackFission : public PatternProcessPass { + public: + explicit PackFission(bool multigraph = true) + : PatternProcessPass("pack_fission", multigraph), inputs_divisor_(kPackInputsDivisor) {} + ~PackFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + size_t inputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_PACK_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc new file mode 100644 index 0000000000..f6a941cc69 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(old_node); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimReduceMin->name())), input}; + CNodePtr reduce_min = graph->NewCNode(inputs); + reduce_min->set_scope(old_node->scope()); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min); + return reduce_min; +} + +bool NeedOptimize(const TypeId &dtype, const std::vector &shape, const std::vector &axis) { + if (dtype != kNumberTypeFloat32) { + MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!"; + return false; + } + if (shape.size() == 0 || shape.size() == 1) { + MS_LOG(INFO) << "ReduceMin's input shape size is " << shape.size() << ", no need optimize!"; + return false; + } + if (axis.size() == 1) { + MS_LOG(INFO) << "ReduceMin axis size is 1, no need optimize!"; + return false; + } + int last_dim = SizeToInt(shape.size() - 1); + if (std::find(axis.begin(), axis.end(), -1) == axis.end() && + std::find(axis.begin(), axis.end(), last_dim) == axis.end()) { + MS_LOG(INFO) << "Attribute of axis does not contain the last axis, not match!"; + return false; + } + return true; +} + +std::vector CalFirstAxis(const std::vector &shape, const std::vector &axis) { + std::vector axis_fisrt; + int last_dim = SizeToInt(shape.size() - 1); + std::copy_if(axis.begin(), axis.end(), std::back_inserter(axis_fisrt), + [&last_dim](int v) { return v != -1 && v != last_dim; }); + + int dim_size = SizeToInt(shape.size()); + if (axis_fisrt.empty()) { + for (int i = 0; i < dim_size - 1; ++i) { + axis_fisrt.push_back(i); + } + } + + for (size_t i = 0; i < axis_fisrt.size(); ++i) { + if (axis_fisrt[i] < -dim_size || axis_fisrt[i] > dim_size - 1) { + MS_LOG(EXCEPTION) << "The axis of ReduceMin verify failed, quit optimizing"; + } + if (axis_fisrt[i] < 0) { + axis_fisrt[i] = dim_size + axis_fisrt[i]; + } + } + return axis_fisrt; +} + +std::vector GetInferShape(const std::vector &shape, const std::vector &axis_first, + bool keep_dims) { + std::vector shape_first; + for (size_t item = 0; item < shape.size(); ++item) { + if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) { + if (keep_dims) { + // If keep_dims is true, current dimension set to 1 + shape_first.push_back(1); + } + } else { + // item is not in ConstValueAxis + shape_first.push_back(shape[item]); + } + } + return shape_first; +} +} // namespace + +const BaseRef ReduceMinFission::DefinePattern() const { + VarPtr X = std::make_shared(); + return VectorRef({prim::kPrimReduceMin, X}); +} + +const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + if (graph == nullptr || node == nullptr) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + CheckCNodeInputSize(cnode, 2); + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); + auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) { + MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!"; + return nullptr; + } + auto axis_value = prim->GetAttr(kAttrAxis); + MS_EXCEPTION_IF_NULL(axis_value); + if (!axis_value->isa()) { + return nullptr; + } + auto axis = AnfAlgo::GetNodeAttr>(cnode, kAttrAxis); + auto keep_dims = AnfAlgo::GetNodeAttr(cnode, kAttrKeepDims); + + if (!NeedOptimize(dtype, shape, axis)) { + MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString(); + return nullptr; + } + + // Create reduce_min1 + CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode); + std::vector axis_first = CalFirstAxis(shape, axis); + std::vector shape_first = GetInferShape(shape, axis_first, keep_dims); + AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get()); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1); + + // Create reduce_min2 + CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode); + reduce_min2->set_abstract(cnode->abstract()); + std::vector axis_last = {-1}; + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_last), reduce_min2); + return reduce_min2; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.h new file mode 100644 index 0000000000..66976cb0b5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class ReduceMinFission : public PatternProcessPass { + public: + explicit ReduceMinFission(bool multigraph = true) : PatternProcessPass("reduce_min_fission", multigraph) {} + ~ReduceMinFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h index fb641c12d6..94d352bc1a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class SingleBatchNormFission : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc index 063f81a1ca..a37c2f38a9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc @@ -96,17 +96,16 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, const std::vector &size_splits_base, int split_dim, int num_split) { SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); - std::vector base_type_ids; - std::vector> base_output_shapes_base; auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + std::vector base_type_ids(num_split, type_id); + std::vector> base_output_shapes_base; if (split_dim < 0) { split_dim += output_shape.size(); } for (int i = 0; i < num_split; ++i) { output_shape[split_dim] = size_splits_base[i]; base_output_shapes_base.emplace_back(output_shape); - base_type_ids.emplace_back(type_id); } AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); } @@ -118,17 +117,14 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); - std::vector size_splits_new; - for (int i = 0; i < divisor; ++i) { - size_splits_new.emplace_back(small_split_size); - } + std::vector size_splits_new(divisor, small_split_size); // Create new output shape and new output type id for each new Splitv node which has full inputs. std::vector new_type_ids; std::vector> new_output_shapes; CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); // Create make_tuple input to create a make_tuple for replacing the old Split node. - std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; // Start to divide the outputs of Split. std::vector size_splits_base; const auto base_split_size = divisor * small_split_size; @@ -147,10 +143,7 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int auto last_node_num_split = num_split - cur_output_index; if (last_node_num_split > 1) { CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - std::vector size_splits_new_last; - for (int i = 0; i < last_node_num_split; ++i) { - size_splits_new_last.emplace_back(small_split_size); - } + std::vector size_splits_new_last(last_node_num_split, small_split_size); SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); // Create new output shape and new output type id for the last Splitv node std::vector last_new_type_ids; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h index 6428a21e73..ce12031a94 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -34,4 +34,4 @@ class SplitFission : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc index 6eeb7a61f7..8e18560260 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc @@ -24,7 +24,7 @@ #include "backend/session/kernel_graph.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h index e005a83a2f..cda8d50341 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TOPK_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TOPK_SPLIT_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -35,4 +35,4 @@ class TopKSplit : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TOPK_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h index bc681944c3..cde03f68f4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ #include #include #include @@ -42,4 +42,4 @@ class TransDataSplit : public Pass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h index 683a345cdb..2fe2d32fa0 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ #include #include @@ -92,4 +92,4 @@ class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h index 2d599a8cc9..98bc63a6f1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ #include #include @@ -108,4 +108,4 @@ class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h index 6e5560bfb0..3eb2755a55 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ #include #include @@ -36,4 +36,4 @@ class AddInputToOutput : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h index 46872aa959..f6c69e1a83 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class BatchNorm2BNInfer : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h index 0676f8a040..79188c72ee 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -31,4 +31,4 @@ class BatchNormGrad2BNInferGrad : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc index 1d89bfd388..de6d4e6a51 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc @@ -21,7 +21,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "ir/primitive.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/utils.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h index 9282b75527..5584cd7fdc 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ #include #include @@ -48,4 +48,4 @@ class ClipByNormNoDivSquareSumFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h index 05bf713bdd..c78542bf54 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -37,4 +37,4 @@ class ClipByValueFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h index 932f0d2890..930198f36c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -38,4 +38,4 @@ class ConfusionMulGradFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h index e3a86e22c9..1954dd66b1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -40,4 +40,4 @@ class ConfusionSoftmaxGradRule : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc index 0fe042dc4e..4cf83df43c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc @@ -15,6 +15,7 @@ */ #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" #include +#include #include #include "backend/session/anf_runtime_algorithm.h" #include "ir/primitive.h" @@ -111,6 +112,13 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); + // Add attr mapping from original nodes to fusion nodes + auto original_names = + MakeValue>({relu->fullname_with_scope(), relu_grad->fullname_with_scope()}); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, original_names, relu_v2); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, original_names, relu_grad_v2); + AnfAlgo::SetNodeAttr(kAttrDatadumpIsMultiop, MakeValue(true), relu_v2); + AnfAlgo::SetNodeAttr(kAttrDatadumpIsMultiop, MakeValue(true), relu_grad_v2); auto manage = graph->manager(); MS_EXCEPTION_IF_NULL(manage); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h index 7506960ecb..7a831fcf3f 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_DERELU_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_DERELU_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -30,4 +30,4 @@ class DereluFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_DERELU_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc index dbff0374f3..b4a2af2bd1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc @@ -44,7 +44,7 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; } for (const auto &node_index : manager->node_users()[bn]) { - AnfNodePtr output = node_index.first; + const AnfNodePtr &output = node_index.first; MS_EXCEPTION_IF_NULL(output); bn_outputs->push_back(output); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h index b3bbedc36e..04cbed35f2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ #include #include @@ -80,4 +80,4 @@ class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h index 45738c289c..b38f89c761 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ #include #include #include #include #include "ir/anf.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace opt { @@ -61,4 +61,4 @@ class InputToOutputRegistry { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h index d14ce6e3fe..36389dbb04 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ #include #include @@ -125,4 +125,4 @@ class LambNextMVRuleCond4 : public LambNextMVRule { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h index 23114c37ee..ca76d82678 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ #include #include @@ -107,4 +107,4 @@ class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h index 58f05c37ba..ff59e4005a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ #include #include @@ -65,4 +65,4 @@ class LambNextMVWithDecayV1Rule : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h index 67687cc037..e13cab3f7f 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -51,4 +51,4 @@ class LambNextRightRule : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc index 8e38c3cc2e..550301dc11 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc @@ -21,7 +21,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "ir/primitive.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/utils.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h index 5ea01ccf65..db48a4a6b2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -52,4 +52,4 @@ class LambUpdateWithLRRuleFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h index c5396178a5..ece158974e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ #include #include @@ -46,4 +46,4 @@ class LambUpdateWithLrV2 : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h index 5bf1608143..a8414023cd 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -38,4 +38,4 @@ class LayerNormBetaGammaBackpropFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h index 8c762435a9..a2f9e8b0ff 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -31,4 +31,4 @@ class MatmulBiasaddFusion : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h index 8d36684a11..0625e00f72 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -31,4 +31,4 @@ class MomentumLossscaleFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h index 0ad13e10e6..aca5049249 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class MulAddFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h index 484cb75237..f954dbd14d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class MulAddNFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h index 0479fd3d63..605eb19d1c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ #include #include diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h index 122bdf55ca..9dec340112 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ #include #include @@ -37,4 +37,4 @@ class RefreshParameterFormat : public Pass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h index 848713201a..13aa28d185 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ #include #include @@ -35,4 +35,4 @@ class RemoveReshapePair : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h index a76538019e..4932cb485c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ #include #include @@ -43,4 +43,4 @@ class ReshapeTransposeFusion : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h index 1b884b2726..b7ae7dda1c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ #include #include @@ -59,4 +59,4 @@ class SoftmaxGradExtFusionV3 : public SoftmaxGradExtFusion { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h index 54189606ba..e378415f3c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class SquareSumFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h index 39b8fe4687..ac4ca604d6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ #include #include @@ -43,4 +43,4 @@ class TransposeReshapeFusion : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h index 852d5194ec..4b1ec803c2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ #include #include @@ -49,4 +49,4 @@ class TransposeTransDataFusion : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc index 887b9a76a1..236535ef3b 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc @@ -22,7 +22,8 @@ #include "backend/optimizer/pass/convert_const_input_to_tensor_input.h" #include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h" #include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" -#include "utils/context/ms_context.h" +#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h" +#include "utils/ms_context.h" #include "debug/anf_ir_dump.h" namespace mindspore { @@ -47,8 +48,9 @@ void BackendCommonOptimization(const std::shared_ptr &kern common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); optimizer->AddPassManager(common_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h index 4127fc05de..e673e101bb 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ #include #include "backend/session/kernel_graph.h" namespace mindspore { @@ -23,4 +23,4 @@ void BackendCommonOptimization(const std::shared_ptr &kern } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h index bdee5ee84a..b4cc2e505e 100644 --- a/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h +++ b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_FUSION_ID_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_FUSION_ID_ALLOCATOR_H_ #include #include "base/base.h" @@ -41,4 +41,4 @@ using FusionIdAllocatorPtr = std::shared_ptr; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_FUSION_ID_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 266130c6b1..5dd3f3b9dc 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -23,12 +23,12 @@ #include #include #include "utils/utils.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" #include "backend/session/anf_runtime_algorithm.h" #include "frontend/operator/ops.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "runtime/device/kernel_info.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { @@ -313,9 +313,9 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(outputs); for (size_t i = 0; i < output_num; i++) { - auto idx = NewValueNode(SizeToInt(i)); - MS_EXCEPTION_IF_NULL(idx); int temp = SizeToInt(i); + auto idx = NewValueNode(temp); + MS_EXCEPTION_IF_NULL(idx); auto imm = std::make_shared(temp); auto abstract_scalar = std::make_shared(imm); idx->set_abstract(abstract_scalar); @@ -781,5 +781,27 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set &suppor MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); return false; } + +ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + ValueNodePtr new_value_node = std::make_shared(value_node->value()); + new_value_node->set_abstract(value_node->abstract()); + // create kernel_info fo new value node + auto kernel_info = std::make_shared(); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { + types.push_back(kTypeUnknown); + } + kernel_build_info_builder->SetOutputsDeviceType(types); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + return new_value_node; +} + } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index a267e65b53..16bdeb79fe 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ #include #include @@ -24,7 +24,7 @@ #include #include "ir/func_graph.h" #include "backend/session/kernel_graph.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/optimizer/common/pattern_engine.h" namespace mindspore { @@ -194,6 +194,9 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); // Check node's data type is in supported data type set bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); + +// Create a new value node of func graph,not kernel graph +ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/node_pass.h b/mindspore/ccsrc/backend/optimizer/common/node_pass.h index 780ae1a056..a790153f1b 100644 --- a/mindspore/ccsrc/backend/optimizer/common/node_pass.h +++ b/mindspore/ccsrc/backend/optimizer/common/node_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ #include #include @@ -33,4 +33,4 @@ class NodePass : public Pass { using NodePassPtr = std::shared_ptr; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/optimizer.h b/mindspore/ccsrc/backend/optimizer/common/optimizer.h index 0b03c9c0ee..9277d19d64 100644 --- a/mindspore/ccsrc/backend/optimizer/common/optimizer.h +++ b/mindspore/ccsrc/backend/optimizer/common/optimizer.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_OPTIMIZER_H_ #include #include @@ -26,8 +26,8 @@ #include "ir/primitive.h" #include "backend/optimizer/common/pass_manager.h" #include "backend/optimizer/common/pattern_engine.h" -#include "utils/graph_utils.h" -#include "common/utils.h" +#include "ir/graph_utils.h" +#include "utils/ms_utils.h" #include "backend/optimizer/common/helper.h" namespace mindspore { @@ -86,4 +86,4 @@ class GraphOptimizer { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pass.h b/mindspore/ccsrc/backend/optimizer/common/pass.h index 6e35fb1dc4..cb8ad63b46 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pass.h +++ b/mindspore/ccsrc/backend/optimizer/common/pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ #include #include @@ -38,4 +38,4 @@ using PassPtr = std::shared_ptr; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc index f9f41237e0..a0f03e85ad 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc @@ -25,7 +25,7 @@ #include "ir/func_graph.h" #include "ir/manager.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.h b/mindspore/ccsrc/backend/optimizer/common/pass_manager.h index 51db27d250..f758ee86ea 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pass_manager.h +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ #include #include @@ -58,4 +58,4 @@ using PassManagerPtr = std::shared_ptr; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h index 51fa8801b2..e461916c05 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h +++ b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_ #include #include @@ -36,7 +36,7 @@ #include "backend/optimizer/common/visit.h" #include "base/base.h" #include "utils/log_adapter.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" namespace mindspore { class CondVar; @@ -201,4 +201,4 @@ struct hash { } }; } // namespace std -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/visit.h b/mindspore/ccsrc/backend/optimizer/common/visit.h index 9799d3f9c1..53316f714b 100644 --- a/mindspore/ccsrc/backend/optimizer/common/visit.h +++ b/mindspore/ccsrc/backend/optimizer/common/visit.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_VISIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_VISIT_H_ #include #include @@ -27,7 +27,7 @@ #include #include "base/base.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" // namespace to support utils definition namespace mindspore { @@ -58,4 +58,4 @@ class DefaultVisitor : public Visitor { std::shared_ptr ExpandList(const std::vector &list); bool CheckIfNeedExpand(const std::vector &list); } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_VISIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc index 41e4abee27..d9194950b4 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc @@ -60,10 +60,11 @@ const BaseRef AdamFusion::DefinePattern() const { {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; + + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); + return next_param; } const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h index f87defc04c..1fa339c3f3 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADAM_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADAM_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -53,4 +53,4 @@ class AdamFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADAM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc index c95945c980..f0f4ac6f36 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc @@ -62,10 +62,11 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const { VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; + + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); + return next_param; } const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h index 53477ec898..015ce63206 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADAM_WEIGHT_DECAY_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADAM_WEIGHT_DECAY_FUSION_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -55,4 +55,4 @@ class AdamWeightDecayFusion : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADAM_WEIGHT_DECAY_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc new file mode 100644 index 0000000000..575a01cc24 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/replace_addn_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ReplaceAddNFusion::DefinePattern() const { + VectorRef addn = VectorRef({prim::kPrimAddN, A, B}); + return addn; +} + +const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + + auto A = AnfAlgo::GetInputNode(utils::cast(node), 0); + auto B = AnfAlgo::GetInputNode(utils::cast(node), 1); + MS_EXCEPTION_IF_NULL(A); + MS_EXCEPTION_IF_NULL(B); + int num_input = AnfAlgo::GetNodeAttr(node, "n"); + + if (num_input == 2) { + auto prim = std::make_shared(prim::kPrimTensorAdd->name()); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = {NewValueNode(prim), A, B}; + auto add_new = graph->NewCNode(inputs); + std::vector outputs_type; + std::vector> outputs_shape; + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(A, 0)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(A, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, add_new.get()); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(utils::cast(node), utils::cast(add_new)); + return add_new; + } else { + return nullptr; + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.h new file mode 100644 index 0000000000..d83da2b067 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceAddNFusion : public PatternProcessPass { + public: + explicit ReplaceAddNFusion(bool multigraph = true) : PatternProcessPass("replace_addn", multigraph) { + A = std::make_shared(); + B = std::make_shared(); + } + ~ReplaceAddNFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr A; + VarPtr B; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc new file mode 100644 index 0000000000..2d48e5b002 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/replace_bn_cast_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ReplaceBNCastFusion::DefinePattern() const { + VectorRef in_cast = VectorRef({prim::kPrimCast, x_}); + VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_}); + VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_}); + return tupleget; +} + +const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + + auto fbn2 = AnfAlgo::GetInputNode(utils::cast(node), 0); + auto x_after = AnfAlgo::GetInputNode(utils::cast(fbn2), 0); + auto x_before = AnfAlgo::GetInputNode(utils::cast(x_after), 0); + auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2), 1); + auto bias = AnfAlgo::GetInputNode(utils::cast(fbn2), 2); + auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2), 3); + auto var = AnfAlgo::GetInputNode(utils::cast(fbn2), 4); + + MS_EXCEPTION_IF_NULL(fbn2); + MS_EXCEPTION_IF_NULL(x_after); + MS_EXCEPTION_IF_NULL(x_before); + MS_EXCEPTION_IF_NULL(scale); + MS_EXCEPTION_IF_NULL(bias); + MS_EXCEPTION_IF_NULL(mean); + MS_EXCEPTION_IF_NULL(var); + std::vector outputs_type; + std::vector> outputs_shape; + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + auto outlist = GetRealNodeUsedList(graph, fbn2); + for (size_t i = 0; i < outlist->size(); i++) { + auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(i).first), 1); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx == 0) { + auto cast = GetRealNodeUsedList(graph, outlist->at(i).first); + if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { + return nullptr; + } + manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(i).first)); + outputs_type.push_back(kNumberTypeFloat16); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get()); + } + } + + manager->Replace(utils::cast(x_after), utils::cast(x_before)); + outputs_type.clear(); + outputs_shape.clear(); + auto output_num = AnfAlgo::GetOutputTensorNum(fbn2); + for (size_t i = 0; i < output_num; i++) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(fbn2, i)); + } + outputs_type[0] = kNumberTypeFloat16; + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get()); + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h new file mode 100644 index 0000000000..6b1e2ad7b1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_CAST_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_CAST_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceBNCastFusion : public PatternProcessPass { + public: + explicit ReplaceBNCastFusion(bool multigraph = true) : PatternProcessPass("replace_bn_cast", multigraph) { + x_ = std::make_shared(); + scale_ = std::make_shared(); + bias_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + y_ = std::make_shared(); + running_mean_ = std::make_shared(); + running_var_ = std::make_shared(); + save_mean_ = std::make_shared(); + save_var_ = std::make_shared(); + index_ = std::make_shared(); + } + ~ReplaceBNCastFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr x_; + VarPtr scale_; + VarPtr bias_; + VarPtr mean_; + VarPtr var_; + VarPtr y_; + VarPtr running_mean_; + VarPtr running_var_; + VarPtr save_mean_; + VarPtr save_var_; + VarPtr index_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_CAST_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc new file mode 100644 index 0000000000..4e1be81ab7 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ReplaceBNGradCastFusion::DefinePattern() const { + VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_}); + VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_}); + VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_}); + return tupleget; +} + +const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + + auto fbn2g = AnfAlgo::GetInputNode(utils::cast(node), 0); + + auto dy_after = AnfAlgo::GetInputNode(utils::cast(fbn2g), 0); + auto dy_before = AnfAlgo::GetInputNode(utils::cast(dy_after), 0); + auto x_ = AnfAlgo::GetInputNode(utils::cast(fbn2g), 1); + auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0); + // if x_type is fp32, the cast is necessary. + if (x_type == kNumberTypeFloat32) { + return nullptr; + } + auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2g), 2); + auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2g), 3); + auto var = AnfAlgo::GetInputNode(utils::cast(fbn2g), 4); + + MS_EXCEPTION_IF_NULL(fbn2g); + MS_EXCEPTION_IF_NULL(dy_after); + MS_EXCEPTION_IF_NULL(dy_before); + MS_EXCEPTION_IF_NULL(scale); + MS_EXCEPTION_IF_NULL(x_); + MS_EXCEPTION_IF_NULL(mean); + MS_EXCEPTION_IF_NULL(var); + std::vector outputs_type; + std::vector> outputs_shape; + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + // 1. get all of the fusedbatchnormgrad nodes connected after dy_after. + auto fbn2g_all = GetRealNodeUsedList(graph, dy_after); + for (size_t i = 0; i < fbn2g_all->size(); i++) { + outputs_type.clear(); + outputs_shape.clear(); + auto kernel = utils::cast(fbn2g_all->at(i).first); + auto kernel_name = AnfAlgo::GetCNodeName(kernel); + // 2. deal all of the fusedbatchnormgrad, change the data type. + if (kernel_name == AnfAlgo::GetCNodeName(utils::cast(fbn2g))) { + auto output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t j = 0; j < output_num; j++) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel, j)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(kernel, j)); + } + outputs_type[0] = kNumberTypeFloat16; + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, kernel.get()); + } + // 3. handle the output of fusedbatchnormgrad: tuplegetitem + auto outlist = GetRealNodeUsedList(graph, kernel); + for (size_t j = 0; j < outlist->size(); j++) { + outputs_type.clear(); + outputs_shape.clear(); + auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(j).first), 1); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx == 0) { + auto cast = GetRealNodeUsedList(graph, outlist->at(j).first); + if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { + continue; + } + manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(j).first)); + outputs_type.push_back(kNumberTypeFloat16); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get()); + } + } + } + manager->Replace(utils::cast(dy_after), utils::cast(dy_before)); + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h new file mode 100644 index 0000000000..b937aa25bf --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceBNGradCastFusion : public PatternProcessPass { + public: + explicit ReplaceBNGradCastFusion(bool multigraph = true) : PatternProcessPass("replace_bn_grad_cast", multigraph) { + dy_ = std::make_shared(); + x_ = std::make_shared(); + scale_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + dx_ = std::make_shared(); + bn_scale_ = std::make_shared(); + bn_bias_ = std::make_shared(); + index_ = std::make_shared(); + } + ~ReplaceBNGradCastFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr dy_; + VarPtr x_; + VarPtr scale_; + VarPtr mean_; + VarPtr var_; + VarPtr dx_; + VarPtr bn_scale_; + VarPtr bn_bias_; + VarPtr index_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc new file mode 100644 index 0000000000..864bb026af --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ReplaceMomentumCastFusion::DefinePattern() const { + VectorRef grad_cast = VectorRef({prim::kPrimCast, grad_}); + VectorRef momentum = VectorRef({prim::kPrimApplyMomentum, var_, acc_, lr_, grad_cast, mom_}); + return momentum; +} + +const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + + auto grad_cast = AnfAlgo::GetInputNode(utils::cast(node), 3); + auto grad = AnfAlgo::GetInputNode(utils::cast(grad_cast), 0); + MS_EXCEPTION_IF_NULL(grad_cast); + MS_EXCEPTION_IF_NULL(grad); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(utils::cast(grad_cast), utils::cast(grad)); + std::vector outputs_type; + std::vector> outputs_shape; + auto output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t i = 0; i < output_num; i++) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i)); + } + outputs_type[3] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0); + + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get()); + + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h new file mode 100644 index 0000000000..f67033dcbe --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceMomentumCastFusion : public PatternProcessPass { + public: + explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) { + var_ = std::make_shared(); + acc_ = std::make_shared(); + lr_ = std::make_shared(); + grad_ = std::make_shared(); + mom_ = std::make_shared(); + } + ~ReplaceMomentumCastFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr var_; + VarPtr acc_; + VarPtr lr_; + VarPtr grad_; + VarPtr mom_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h index 4b928d6565..53dcca8eec 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_KERNEL_REFCOUNT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_KERNEL_REFCOUNT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_ #include #include #include @@ -25,7 +25,8 @@ namespace mindspore { namespace memreuse { enum RefCountType { kDynamicRefCount, kStaticRefCount }; -enum NodeType { NORMAL, SPECIAL }; +enum NodeType { kCommonNode, kCommunicationNode }; +enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary }; static constexpr int kInitIndex = -1; class KernelRefCount { public: @@ -36,6 +37,7 @@ class KernelRefCount { size_t offset_; size_t size_; int index_; + KernelRefType type_; // remember to reset offset KernelRefCount() : stream_id_(0), @@ -44,6 +46,7 @@ class KernelRefCount { offset_(0), size_(0), index_(kInitIndex), + type_(kCommon), reftype_(kStaticRefCount) {} ~KernelRefCount() = default; void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype); @@ -65,7 +68,7 @@ class KernelDef { KernelMap inputs_; KernelMap outputs_; KernelMap wk_space_; - NodeType dirty = NORMAL; + NodeType type_ = kCommonNode; KernelDef() = default; ~KernelDef() = default; void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; } @@ -95,4 +98,4 @@ class KernelDef { using KernelDefPtr = std::shared_ptr; } // namespace memreuse } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_KERNEL_REFCOUNT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h index 1952415515..b299f4c4b1 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_COPY_MANAGER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_COPY_MANAGER_H_ #include #include +#include #include #include #include @@ -40,29 +41,58 @@ struct TensorInfo { struct KernelExecutionInfo { size_t topo_order_{0}; float execution_perform_{0.0}; - bool trigger_swap_{false}; - bool need_swap_{false}; - // output index to topo orders of node users + bool trigger_swap_out_{false}; + bool trigger_swap_in_{false}; + size_t swap_in_task_num_{0}; + // Key: output index, value: topo orders of node users std::map> node_users_map_; - // kernel output idx to host addr - std::map host_addrs_; + // Key: output index, value: pair (host addr, dirty or not) + std::map> host_addrs_; - KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {} - explicit KernelExecutionInfo(size_t topo_order) - : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {} - KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap) + KernelExecutionInfo() {} + explicit KernelExecutionInfo(size_t topo_order) : KernelExecutionInfo(topo_order, 0.0, false, false, 0) {} + KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap_out, bool trigger_swap_in, + size_t swap_in_task_num) : topo_order_(topo_order), execution_perform_(execution_perform), - trigger_swap_(trigger_swap), - need_swap_(need_swap) {} + trigger_swap_out_(trigger_swap_out), + trigger_swap_in_(trigger_swap_in), + swap_in_task_num_(swap_in_task_num) {} }; -// trigger swap struct MemSwapInfo { SwapKind swap_kind_; - // kernel need to be swapped - AnfNodePtr kernel_{nullptr}; + // Topo order of kernel need be swapped + size_t topo_order_; size_t output_idx_{0}; + // Record the swapping out position of swapping in tensor + size_t swap_out_pos_; +}; + +struct SwapInfoComp { + bool operator()(const MemSwapInfo &a, const MemSwapInfo &b) { + int swap_kind_a = static_cast(a.swap_kind_); + int swap_kind_b = static_cast(b.swap_kind_); + if (swap_kind_a < swap_kind_b) { + return true; + } else if (swap_kind_a > swap_kind_b) { + return false; + } + + if (a.swap_out_pos_ < b.swap_out_pos_) { + return true; + } else if (a.swap_out_pos_ > b.swap_out_pos_) { + return false; + } + + if (a.topo_order_ < b.topo_order_) { + return true; + } else if (a.topo_order_ > b.topo_order_) { + return false; + } + + return a.output_idx_ < b.output_idx_; + } }; class MemCopyManager { @@ -75,7 +105,12 @@ class MemCopyManager { virtual void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} - virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} + virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr, bool profiling, + float *cost_time) {} + + virtual void AddMemSwapOutTaskMock(const DeviceAddressPtr &device_address) {} + + virtual void AddMemSwapInTaskMock(const DeviceAddressPtr &device_address) {} virtual bool SyncMemCopyStream(SwapKind swap_kind) { return true; } @@ -83,15 +118,22 @@ class MemCopyManager { virtual DeviceAddressPtr UpdateSwapInQueue() { return nullptr; } + virtual DeviceAddressPtr UpdateSwapOutQueueMock() { return nullptr; } + + virtual DeviceAddressPtr UpdateSwapInQueueMock() { return nullptr; } + virtual bool AllocHostPinnedMem(size_t size, void **addr) const { return true; } virtual void FreeHostPinnedMem(void *addr) const {} virtual void ClearSwapQueue() {} + + virtual void ClearSwapQueueMock() {} }; using MemCopyManagerPtr = std::shared_ptr; +using MemSwapInfoSet = std::set; } // namespace memswap } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc index 8f705be556..990a9d9c29 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc @@ -15,7 +15,7 @@ */ #include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/convert_utils.h" #include "utils/log_adapter.h" @@ -191,7 +191,8 @@ void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { MS_EXCEPTION_IF_NULL(device_addr); auto mem_block = FindMemBlock(device_addr); if (mem_block == nullptr) { - MS_LOG(WARNING) << "Can't find the mem_block of the device address[" << device_addr << "]."; + // May be destory the memory pool first, then destory the address, so this is normal case. + MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "]."; return; } CombineMemBuf(mem_block, device_addr); @@ -262,7 +263,7 @@ void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr devi } void DynamicMemPoolBestFit::ReleaseDeviceRes() { - MS_LOG(INFO) << "The dynamic memmory pool total size is " << total_mem_statistics_ << ", total used size is " + MS_LOG(INFO) << "The dynamic memory pool total size is " << total_mem_statistics_ << ", total used size is " << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << "."; for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { auto device_addr = (*iter)->device_addr(); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h index 07efa267aa..72765f4ecf 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_DYNAMIC_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_DYNAMIC_ALLOCATOR_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_DYNAMIC_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_DYNAMIC_ALLOCATOR_H_ #include #include @@ -139,4 +139,4 @@ class DynamicMemPoolBestFit { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_DYNAMIC_ALLOCATOR_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_DYNAMIC_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc index 263ceaec63..c45504e214 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc @@ -46,6 +46,8 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() { if (iter == kernel_output_refs_.end()) { auto output_sizes = kernel_mod->GetOutputSizeList(); KernelRefCountPtrList kernel_refs; + bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel_cnode); + size_t output_index = 0; for (auto size : output_sizes) { total_dy_size_ += size; // do not MallocDynamicMem just record this @@ -54,9 +56,32 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() { auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode); kernel_ref->stream_id_ = curr_stream_id; kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount); + if (is_comm_op) { + kernel_ref->type_ = kCommReuse; + } else { + session::AnfWithOutIndex out_pair(kernel_cnode, output_index); + if (graph_->IsInRefOutputMap(out_pair)) { + kernel_ref->type_ = kRefNodeOutput; + auto origin_pair = graph_->GetRefCorrespondOutput(out_pair); + MS_EXCEPTION_IF_NULL(origin_pair.first); + MS_LOG(INFO) << "REF origin op is " << origin_pair.first->fullname_with_scope() << ", output index is " + << origin_pair.second << ", cur op is " << kernel_cnode->fullname_with_scope() + << ", out index is " << output_index; + if (origin_pair.first->isa()) { + auto cnode = origin_pair.first->cast(); + auto ref_ptr = GetRef(cnode, origin_pair.second); + if (ref_ptr != nullptr) { + ref_ptr->type_ = kRefNodeInput; + } + } + } else { + kernel_ref->type_ = kCommon; + } + } kernel_refs.push_back(kernel_ref); kernel_out_ref_num++; total_refs_list_.push_back(kernel_ref); + output_index++; } if (!kernel_refs.empty()) { kernel_output_refs_[key] = kernel_refs; @@ -155,9 +180,19 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel_def_ptr); auto key = kernel.get(); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel); + size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_tensor_num; ++i) { auto ref_ptr = GetKernelInputRef(kernel, i); if (ref_ptr != nullptr) { + if (is_comm_op) { + if (input_tensor_num == 1) { + ref_ptr->type_ = kCommReuse; + } else { + ref_ptr->type_ = kCommNotReuse; + } + } + if (ref_ptr->reftype() == kStaticRefCount) { continue; } else if (ref_ptr->reftype() == kDynamicRefCount) { @@ -258,6 +293,11 @@ void MemReuseUtil::SetKernelDefMap() { auto key = kernel.get(); kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]); kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]); + if (AnfAlgo::IsCommunicationOp(kernel)) { + kernel_def_ptr->type_ = kCommunicationNode; + } else { + kernel_def_ptr->type_ = kCommonNode; + } kernel_def_ptr_list_.push_back(kernel_def_ptr); kernel_map_[key] = kernel_def_ptr; } @@ -337,6 +377,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() { KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; kernel_ref->ref_count_ = kMaxRefCount; kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; + kernel_ref->type_ = kSummary; total_summary_size += kernel_ref->size_; MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; } else { @@ -344,12 +385,29 @@ void MemReuseUtil::SetSummaryNodesRefCount() { } } #ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_); #endif MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size; } +void MemReuseUtil::SetRefNodesInputRefCount() { + size_t total_size = 0; + for (auto iter : kernel_output_refs_) { + for (auto &ref_count : iter.second) { + MS_EXCEPTION_IF_NULL(ref_count); + if (ref_count->type_ == kRefNodeInput) { + ref_count->ref_count_ = kMaxRefCount; + total_size += ref_count->size_; + } + } + } + + MS_LOG(INFO) << "Special Tensor total size: RefNodeInput: " << total_size; +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_); +#endif +} + void MemReuseUtil::SetGraphOutputRefCount() { auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { @@ -376,8 +434,7 @@ void MemReuseUtil::SetGraphOutputRefCount() { } } #ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_); #endif } @@ -390,13 +447,14 @@ void MemReuseUtil::ResetDynamicUsedRefCount() { } } -void MemReuseUtil::SetAllInfo(KernelGraph *graph) { +void MemReuseUtil::SetAllInfo(const KernelGraph *graph) { if (!InitDynamicKernelRef(graph)) { MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; } SetKernelDefMap(); SetReuseRefCount(); SetSummaryNodesRefCount(); + SetRefNodesInputRefCount(); SetWorkSpaceList(); #ifdef MEM_REUSE_DEBUG MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h index b286bcbc2c..ad884f44b4 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ #include #include #include @@ -52,7 +52,7 @@ class MemReuseUtil { MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_; } - void SetAllInfo(KernelGraph *graph); + void SetAllInfo(const KernelGraph *graph); bool InitDynamicOutputKernelRef(); bool InitDynamicWorkspaceKernelRef(); bool InitDynamicKernelRef(const KernelGraph *graph); @@ -64,6 +64,7 @@ class MemReuseUtil { void SetKernelDefInputs(); void SetReuseRefCount(); void SetSummaryNodesRefCount(); + void SetRefNodesInputRefCount(); // Set the reference count of graph output specially. void SetGraphOutputRefCount(); // Reset the dynamic used reference count by ref_count_. @@ -83,6 +84,7 @@ class MemReuseUtil { void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; } uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const; uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const; + bool is_all_nop_node() const { return is_all_nop_node_; } private: int util_index_; @@ -104,4 +106,4 @@ using MemReuseUtilPtr = std::shared_ptr; } // namespace memreuse } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc index d1a50a0dfe..e791d318fa 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc @@ -33,11 +33,11 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) { set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list()); // check info Correctness for (auto &tensor : tensor_ptr_list_) { - tensor->size_ = AlignMemorySize(tensor->size_); + tensor->size_ = AlignCommonMemorySize(tensor->size_); } // align wk size to 512 && refcount == 1 for (auto &wk : wk_tensor_list_) { - wk->size_ = AlignMemorySize(wk->size_); + wk->size_ = AlignCommonMemorySize(wk->size_); wk->ref_count_ = 1; } #ifdef ENABLE_D @@ -90,7 +90,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr auto curr_stream_id = kernel_curr->stream_id(); auto prev_stream_id = kernel_prev->stream_id(); if (curr_stream_id == prev_stream_id) { - mem_buf->type_ = IN_STREAM_REUSE; + mem_buf->type_ = kInStreamReuse; return true; } @@ -117,7 +117,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr } if (reuse_between_streams) { - mem_buf->type_ = BETWEEN_STREAMS_REUSE; + mem_buf->type_ = kBetweenStreamReuse; return true; } @@ -128,18 +128,33 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr auto kernel_curr_front = iter->second; auto depend_count = kernel_curr_front.count(kernel_prev); if (depend_count) { - mem_buf->type_ = KERNEL_DEPENDENCE_REUSE; + mem_buf->type_ = kKernelDependenceReuse; return true; } return false; } -void BestFitMemReuse::AssignNodeOutputOffset() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { +void BestFitMemReuse::AssignCommonNodeOutputOffset() { + MS_EXCEPTION_IF_NULL(current_kernel_); + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[index]; MS_EXCEPTION_IF_NULL(tensor_desc); + if (tensor_desc->type_ == kRefNodeInput) { + total_refinput_size += tensor_desc->size_; + } else if (tensor_desc->type_ == kRefNodeOutput) { + total_refoutput_size += tensor_desc->size_; + // no need to alloc refnode output's memory + continue; + } else if (tensor_desc->type_ == kCommNotReuse) { + total_comm_not_reuse_size += tensor_desc->size_; + } else if (tensor_desc->type_ == kCommReuse) { + // get align size for communication op's single input + tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_); + total_comm_reuse_size += tensor_desc->size_; + } + auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_); if (!reusable_membuf_map.empty()) { auto membuf_index = reusable_membuf_map.begin()->second; @@ -152,6 +167,94 @@ void BestFitMemReuse::AssignNodeOutputOffset() { MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; #endif } + // skip left align border for communication op single input to used + if (tensor_desc->type_ == kCommReuse) { + tensor_desc->offset_ += kDefaultMemAlignSize; + } + } +} + +void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { + size_t total_kernel_output_size = 0; + size_t output_num = 0; + // get all output size + MS_EXCEPTION_IF_NULL(current_kernel_); + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + if (tensor_desc->type_ == kCommReuse) { + total_comm_reuse_size += tensor_desc->size_; + total_comm_output_reuse_size += tensor_desc->size_; + total_kernel_output_size += tensor_desc->size_; + } else { + MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:" + << current_kernel_->scope_full_name() << " output index:" << tensor_idx + << " tensor_type:" << tensor_desc->type_; + continue; + } + } + total_kernel_output_size = AlignCommunicationMemorySize(total_kernel_output_size); + + // add left align border for the first output and right align border for the last output to alloc align border memory + size_t output_index = 0; + auto output_ref_indexes = current_kernel_->GetOutputRefIndexs(); + for (const auto &tensor_idx : output_ref_indexes) { + size_t index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + if (output_index == 0 || output_index == output_num - 1) { + tensor_desc->size_ += kDefaultMemAlignSize; + } + + if ((output_index == 0) && (output_ref_indexes.size() == 1)) { + // add right align border for single output + tensor_desc->size_ += kDefaultMemAlignSize; + } + + output_index++; + } + + auto reusable_membuf_map = GetReusableMembufMap(total_kernel_output_size); + if (!reusable_membuf_map.empty()) { + auto membuf_index = reusable_membuf_map.begin()->second; + output_index = 0; + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + ReuseExistMembuf(tensor_desc.get(), membuf_index + output_index, kDynamicMem); + // skip skip left align border for communication op's first output to used + if (output_index == 0) { + tensor_desc->offset_ += kDefaultMemAlignSize; + } + output_index++; + } + } else { + // no membuf can reuse, add new membuf after the membuf_ptr_list + output_index = 0; + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + AddNewMembufPtr(tensor_desc.get(), kDynamicMem); + // skip align size offset for first output to used + if (output_index == 0) { + tensor_desc->offset_ += kDefaultMemAlignSize; + } + output_index++; +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; +#endif + } + } +} + +void BestFitMemReuse::AssignNodeOutputOffset() { + if (current_kernel_->type_ == kCommunicationNode) { + AssignCommunicationNodeOutputOffset(); + } else { + AssignCommonNodeOutputOffset(); } } @@ -231,7 +334,7 @@ void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) { } auto membuf_size = tensor_desc->size_; auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); - auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_); + auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, kNew, current_kernel_); membuf_ptr_list_.push_back(membuf); tensor_desc->offset_ = membuf_offset; } @@ -253,7 +356,7 @@ void BestFitMemReuse::UpdateNodeInputAndMembuf() { } void BestFitMemReuse::ReleaseNodeUnusedOutput() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t tensor_index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[tensor_index]; MS_EXCEPTION_IF_NULL(tensor_desc); @@ -319,11 +422,17 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) { } } -size_t BestFitMemReuse::AlignMemorySize(size_t size) const { +size_t BestFitMemReuse::AlignCommonMemorySize(size_t size) const { // memory size 512 align return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize; } +size_t BestFitMemReuse::AlignCommunicationMemorySize(size_t size) const { + // memory size 512 align and add communication memory: left align border memory - data - right align border memory + return kDefaultMemAlignSize + (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize + + kDefaultMemAlignSize; +} + size_t BestFitMemReuse::GetAllocatedSize() { size_t AllocatedSize = kTotalSize; if (membuf_ptr_list_.empty()) { @@ -412,6 +521,10 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { ++op_num; #endif } + MS_LOG(INFO) << "Special Tensor total size: RefInput: " << total_refinput_size + << " RefOutput: " << total_refoutput_size << " CommReuse: " << total_comm_reuse_size + << " CommOutputReuse: " << total_comm_output_reuse_size + << " CommNotReuse: " << total_comm_not_reuse_size; #ifdef MEM_REUSE_DEBUG MemReuseChecker::GetInstance().ExportMembufInfoIR(); MemReuseChecker::GetInstance().ExportAddNewMmebufIR(); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h index ef1cfd3e11..fdfd6f39ac 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ #include #include #include @@ -40,11 +40,11 @@ static constexpr int kDynamicMem = -1; static constexpr int kWorkspaceMem = 1; static constexpr size_t kTotalSize = 0; enum Status { kUnused, kReused }; -enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE }; +enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse }; class Membuf { public: Membuf() = default; - Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel) + Membuf(Status status, size_t size, size_t offset, int index, MemType type, const KernelDefPtr &used_kernel) : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {} ~Membuf() = default; // Memory block status flags @@ -53,7 +53,7 @@ class Membuf { size_t offset_{0}; // Store the tensor index stored in this memory block at a certain moment int index_{0}; - MEMTYPE type_{NEW}; + MemType type_{kNew}; KernelDefPtr used_kernel_; }; using MembufPtr = std::shared_ptr; @@ -74,6 +74,14 @@ class BestFitMemReuse { * Assign output tensor memory offset of current kernel */ void AssignNodeOutputOffset(); + /** + * Assign output tensor memory offset of common kernel + */ + void AssignCommonNodeOutputOffset(); + /** + * Assign output tensor memory offset of communication kernel + */ + void AssignCommunicationNodeOutputOffset(); /** * Update input tensor's status of current kernel, and the status of membuf used by current kernel */ @@ -110,8 +118,10 @@ class BestFitMemReuse { void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag); // Merge unused membuf void ReleaseMembuf(size_t tensor_index, int flag); - // Memory address alignment 512 - size_t AlignMemorySize(size_t size) const; + // Memory address alignment for common memory + size_t AlignCommonMemorySize(size_t size) const; + // Memory address alignment for communication used memory + size_t AlignCommunicationMemorySize(size_t size) const; int GetRealIndex(size_t index, int flag = kDynamicMem) const; size_t GetTensorIndex(int index) const; size_t GetWorkspaceIndex(int index) const; @@ -153,7 +163,12 @@ class BestFitMemReuse { // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def std::map> kernel_front_map_; std::vector> stream_groups_; + size_t total_refinput_size{0}; + size_t total_refoutput_size{0}; + size_t total_comm_reuse_size{0}; + size_t total_comm_output_reuse_size{0}; + size_t total_comm_not_reuse_size{0}; }; } // namespace memreuse } // namespace mindspore -#endif // #define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#endif // #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc index b93bf42f9f..81dc3f8ba0 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc @@ -83,7 +83,7 @@ int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const { return static_input_size; } -int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const { +int64_t MemReuseChecker::CalculOriValue(const KernelGraph *graph) const { MS_EXCEPTION_IF_NULL(graph); int64_t static_value_size = 0; for (auto &value_node : graph->graph_value_nodes()) { @@ -101,7 +101,7 @@ int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const { return static_value_size; } -int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const { +int64_t MemReuseChecker::CalculOriStatic(const KernelGraph *graph) const { // cal static inputs auto static_input_size = CalculOriInput(graph); // do not calcul outpput size @@ -154,7 +154,7 @@ std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const { } void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, - const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) { + const KernelDefPtrMaps &kernel_def_ptr_list, const KernelGraph *graph) { total_ori_static_size_ = CalculOriStatic(graph); total_ori_input_size_ = CalculOriInput(graph); total_ori_value_size_ = CalculOriValue(graph); @@ -170,12 +170,14 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li ofs << "all_tensor_refs:\n"; ofs << "index:" << "\tsize:" - << "\trefcount:\n"; + << "\trefcount:" + << "\ttype:\n"; for (auto &ref : total_refs_list) { ofs << "%" << ref->index_ << "T" << "\t" << "#" << ref->size_ << "S" << "\t" << ref->ref_count_ << "C" + << "\t" << ref->type_ << "t" << "\n"; } ofs << "kernel_def exc_order:\n"; @@ -241,7 +243,7 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) { auto scope_name = def->scope_full_name(); std::string split_name = GetSplitName(scope_name); - ofs << "$" << def_idx << "\t" << split_name << "\t"; + ofs << "$" << def_idx << "\t" << split_name << "\t" << static_cast(def->type_) << "\t"; ofs << "inputs["; for (auto &in : def->inputs_) { for (auto &in_ref : in.second) { diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h index 3c4a00a3ca..c5a5a128a1 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_ #include #include #include @@ -43,10 +43,10 @@ class MemReuseChecker { void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); bool CheckGraphOutputAssigned(const session::KernelGraph *graph); void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, - KernelGraph *graph); - int64_t CalculOriStatic(KernelGraph *graph) const; + const KernelGraph *graph); + int64_t CalculOriStatic(const KernelGraph *graph) const; int64_t CalculOriInput(const KernelGraph *graph) const; - int64_t CalculOriValue(KernelGraph *graph) const; + int64_t CalculOriValue(const KernelGraph *graph) const; int64_t CalculOriDy(const KernelGraph *graph) const; int64_t CalculOriWk(const KernelGraph *graph) const; std::string GetSplitName(const std::string &scope_name) const; @@ -94,4 +94,4 @@ class MemReuseChecker { }; } // namespace memreuse } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc index 41bf5460c3..1de2021490 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc @@ -22,22 +22,16 @@ namespace mindspore { namespace device { namespace memswap { -void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { +bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size) { MS_EXCEPTION_IF_NULL(kernel_graph); - graph_manager_ = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(graph_manager_); - auto &kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { - if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { - execution_order_.push_back(kernel); - } - } + execution_order_ = kernel_graph->execution_order(); + kernel_graph_ = kernel_graph; size_t kernel_index = 0; for (const auto &kernel : execution_order_) { - // parse topo order of kernel + // Parse topo order of kernel (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); - // parse tensor info + // Parse tensor info auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); @@ -48,7 +42,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { } } - // parse topo order of user kernel + // Parse topo order of user kernel SaveUserKernelTopoOrder(); sort(ordered_tensors_.begin(), ordered_tensors_.end(), @@ -61,18 +55,131 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { tensor_size_num_++; } } - tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; - tensor_size_threshold_idx_ = 0; - - distance_threshold_ = kernel_index / kDistanceInitFactor; + if (!InitSwapThreshold(0)) { + return false; + } mem_swap_initialized_ = true; MS_EXCEPTION_IF_NULL(mem_copy_manager_); mem_copy_manager_->Init(); + return true; +} + +bool MemSwapManager::InitSwapThreshold(size_t swap_mem_size) { + distance_threshold_ = execution_order_.size() / kDistanceInitFactor; + distance_decay_step_ = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; + if (distance_decay_step_ <= 1) { + distance_decay_step_ = 1; + } + tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; + tensor_size_threshold_idx_ = 0; + + size_t accumulation = 0; + while (accumulation < swap_mem_size) { + accumulation = 0; + for (const auto &tensor_info : ordered_tensors_) { + size_t tensor_size = tensor_info.tensor_size_; + if (tensor_size < tensor_size_threshold_) { + break; + } + if (!CheckDistanceBetweenKernels(tensor_info)) { + continue; + } + + accumulation += tensor_info.tensor_size_; + if (accumulation >= swap_mem_size) { + return true; + } + } + RetreatSwapThreshold(); + if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { + MS_LOG(ERROR) << "Init swap threshold info failed"; + return false; + } + } + return true; +} + +void MemSwapManager::RetreatSwapThreshold() { + if (distance_threshold_ >= kDistanceLowerBound) { + bool update_one_decay_step = (distance_threshold_ > distance_decay_step_) && + (distance_threshold_ - distance_decay_step_ >= kDistanceLowerBound); + if (update_one_decay_step) { + distance_threshold_ -= distance_decay_step_; + } else if (distance_threshold_ >= kDistanceLowerBound) { + size_t new_distance_decay_step = (distance_threshold_ - kDistanceLowerBound) / 4; + if (new_distance_decay_step < 1) { + new_distance_decay_step = 1; + } + distance_threshold_ -= new_distance_decay_step; + } + } + + while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { + ++tensor_size_threshold_idx_; + if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { + tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; + break; + } + } +} + +bool MemSwapManager::CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const { + const AnfNodePtr &kernel = tensor_info.kernel_; + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &node_users_map = kernel_exec_info.node_users_map_; + + auto iter = node_users_map.find(tensor_info.output_idx_); + if (iter == node_users_map.end()) { + return false; + } + + auto &node_users = iter->second; + if (node_users.front() - kernel_exec_info.topo_order_ > distance_threshold_) { + return true; + } + + for (size_t i = 1; i < node_users.size(); ++i) { + if (node_users[i] - node_users[i - 1] > distance_threshold_) { + return true; + } + } + return false; +} + +std::vector> MemSwapManager::CheckDistanceBetweenKernelsWithIdx( + const TensorInfo &tensor_info) const { + const AnfNodePtr &kernel = tensor_info.kernel_; + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &node_users_map = kernel_exec_info.node_users_map_; + std::vector> need_swap_topo_pair_list; + + auto iter = node_users_map.find(tensor_info.output_idx_); + if (iter == node_users_map.end()) { + return need_swap_topo_pair_list; + } + auto &node_users = iter->second; + if (node_users.front() - kernel_exec_info.topo_order_ > distance_threshold_) { + need_swap_topo_pair_list.emplace_back(kernel_exec_info.topo_order_, node_users.front()); + } + + for (size_t i = 1; i < node_users.size(); ++i) { + if (node_users[i] - node_users[i - 1] > distance_threshold_) { + need_swap_topo_pair_list.emplace_back(node_users[i - 1], node_users[i]); + } + } + return need_swap_topo_pair_list; } bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); - NodeUsersMap &user_map = graph_manager_->node_users(); + if (AnfAlgo::IsCommunicationOp(kernel)) { + return true; + } + + MS_EXCEPTION_IF_NULL(kernel_graph_); + const auto &graph_manager = kernel_graph_->manager(); + MS_EXCEPTION_IF_NULL(graph_manager); + NodeUsersMap &user_map = graph_manager->node_users(); auto iter = user_map.find(kernel); bool adjacent_with_communication_op = false; if (iter != user_map.end()) { @@ -81,11 +188,14 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { node_set.begin(), node_set.end(), [](const std::pair &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); } - return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; + return adjacent_with_communication_op; } void MemSwapManager::SaveUserKernelTopoOrder() { - NodeUsersMap &user_map = graph_manager_->node_users(); + MS_EXCEPTION_IF_NULL(kernel_graph_); + const auto &graph_manager = kernel_graph_->manager(); + MS_EXCEPTION_IF_NULL(graph_manager); + NodeUsersMap &user_map = graph_manager->node_users(); for (const auto &kernel : execution_order_) { auto iter = user_map.find(kernel); if (iter == user_map.end()) { @@ -95,7 +205,7 @@ void MemSwapManager::SaveUserKernelTopoOrder() { auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); for (auto &node_pair : node_set) { auto user_kernel = node_pair.first; - if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { + if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) { continue; } @@ -120,60 +230,55 @@ void MemSwapManager::AddSwapInfo() { break; } - size_t output_idx = tensor.output_idx_; const AnfNodePtr &kernel = tensor.kernel_; if (IsCommunicationRelevantOp(kernel)) { continue; } - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - auto &node_users_map = kernel_exec_info.node_users_map_; - auto iter = node_users_map.find(output_idx); - if (iter == node_users_map.end()) { - continue; - } - auto &node_users = iter->second; - bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) || - (node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_); - if (!need_swap) { + auto need_swap_topo_pair_list = CheckDistanceBetweenKernelsWithIdx(tensor); + if (need_swap_topo_pair_list.empty()) { continue; } - AddKernelNeedSwap(kernel, true); HostAddress host_addr; host_addr.size = tensor_size; - auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast(&host_addr.addr)); - if (!ret) { - MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; - } - kernel_exec_info.host_addrs_[output_idx] = host_addr; - MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; - if (node_users.size() > 1) { - AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); - AddKernelTriggerSwap(execution_order_[node_users[0]], true); - } else { - AddKernelMemSwapInfo(kernel, mem_swap_out_info); - AddKernelTriggerSwap(kernel, true); - } + host_addr.addr = nullptr; - size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1; - if (swap_in_order <= kernel_exec_info.topo_order_) { - MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + size_t output_idx = tensor.output_idx_; + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.host_addrs_[output_idx] = std::make_pair(host_addr, true); + + for (auto &swap_topo_pair : need_swap_topo_pair_list) { + size_t swap_out_order = swap_topo_pair.first; + MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel_exec_info.topo_order_, output_idx, + swap_out_order}; + AddKernelMemSwapInfo(execution_order_[swap_out_order], mem_swap_out_info); + + size_t swap_in_order = swap_topo_pair.second - 1; + MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel_exec_info.topo_order_, output_idx, + swap_out_order}; + if (swap_in_order <= swap_out_order) { + MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + AddKernelMemSwapInfo(execution_order_[swap_in_order], mem_swap_in_info); } - auto swap_in_kernel = execution_order_[swap_in_order]; - MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; - AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); - AddKernelTriggerSwap(swap_in_kernel, true); - - host_addrs_list_.push_back(host_addr); } } void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, - const HostAddress &host_address) const { + const HostAddress &host_address, bool mock, bool profiling, + float *cost_time) const { + if (!mock) { + if (swap_kind == SwapKind::kDeviceToHost) { + mem_copy_manager_->AddMemSwapOutTask(device_address, host_address); + } else if (swap_kind == SwapKind::kHostToDevice) { + mem_copy_manager_->AddMemSwapInTask(device_address, host_address, profiling, cost_time); + } + } + if (swap_kind == SwapKind::kDeviceToHost) { - mem_copy_manager_->AddMemSwapOutTask(device_address, host_address); + mem_copy_manager_->AddMemSwapOutTaskMock(device_address); } else if (swap_kind == SwapKind::kHostToDevice) { - mem_copy_manager_->AddMemSwapInTask(device_address, host_address); + mem_copy_manager_->AddMemSwapInTaskMock(device_address); } } @@ -181,34 +286,30 @@ bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) const { return mem_copy_manager_->SyncMemCopyStream(swap_kind); } -DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { +DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind, bool mock) const { + if (!mock) { + if (swap_kind == SwapKind::kDeviceToHost) { + return mem_copy_manager_->UpdateSwapOutQueue(); + } else { + return mem_copy_manager_->UpdateSwapInQueue(); + } + } + if (swap_kind == SwapKind::kDeviceToHost) { - return mem_copy_manager_->UpdateSwapOutQueue(); + return mem_copy_manager_->UpdateSwapOutQueueMock(); } else { - return mem_copy_manager_->UpdateSwapInQueue(); + return mem_copy_manager_->UpdateSwapInQueueMock(); } } -// retreat to find a workable swap scheme +// Retreat to find a workable swap scheme bool MemSwapManager::RetreatSwapInfo() { if (!trigger_swap_) { trigger_swap_ = true; } if (swap_info_already_set_) { ResetSwapInfo(); - if (distance_threshold_ >= kDistanceLowerBound) { - auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; - distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1); - } - - while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { - ++tensor_size_threshold_idx_; - if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { - tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; - break; - } - } - + RetreatSwapThreshold(); if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { MS_LOG(ERROR) << "Retreat swap info failed"; return false; @@ -220,6 +321,114 @@ bool MemSwapManager::RetreatSwapInfo() { return true; } +void MemSwapManager::AdjustSwapInPos(const AnfNodePtr &kernel, size_t index) { + if (kernel_first_move_cache_map_.find(kernel.get()) == kernel_first_move_cache_map_.end()) { + CacheCurSwapInfoSet(kernel); + } + + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + size_t kernel_pos = kernel_exec_info.topo_order_; + auto &mem_swap_info = mem_swap_info_cache_list_[index]; + + if (QueryFirstTimeMovePos(kernel, index)) { + best_and_cur_pos_cache_.first = BestSwapInPerformPos(kernel, mem_swap_info); + best_and_cur_pos_cache_.second = best_and_cur_pos_cache_.first; + size_t best_pos = best_and_cur_pos_cache_.first; + if (best_pos != kernel_pos) { + MoveSwapInfoPos(best_pos, kernel_pos, mem_swap_info); + } + AddFirstTimeMovePos(kernel, index, false); + return; + } + + auto &cur_pos = best_and_cur_pos_cache_.second; + if (cur_pos < kernel_pos) { + MoveSwapInfoPos(cur_pos + 1, cur_pos, mem_swap_info); + cur_pos++; + } +} + +void MemSwapManager::CacheCurSwapInfoSet(const AnfNodePtr &kernel) { + if (!kernel_first_move_cache_map_.empty()) { + kernel_first_move_cache_map_.clear(); + } + if (!mem_swap_info_cache_list_.empty()) { + mem_swap_info_cache_list_.clear(); + } + + auto mem_swap_info_set = QueryKernelMemSwapInfo(kernel); + size_t swap_in_task_cnt = 0; + for (auto &mem_swap_info : mem_swap_info_set) { + if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + (void)mem_swap_info_cache_list_.push_back(mem_swap_info); + kernel_first_move_cache_map_[kernel.get()].push_back(true); + swap_in_task_cnt++; + } + } + size_t swap_in_task_num = QueryKernelTriggerSwapInTaskNum(kernel); + if (swap_in_task_cnt != swap_in_task_num) { + MS_LOG(EXCEPTION) << "Swap_in_task_cnt :" << swap_in_task_cnt + << "must equal Swap_in_task_num: " << swap_in_task_num; + } +} + +void MemSwapManager::AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time) { + auto iter = kernel_first_move_cache_map_.find(kernel.get()); + if (iter == kernel_first_move_cache_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + auto &first_move_list = iter->second; + if (index >= first_move_list.size()) { + MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; + } + first_move_list[index] = first_time; +} + +bool MemSwapManager::QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const { + auto iter = kernel_first_move_cache_map_.find(kernel.get()); + if (iter == kernel_first_move_cache_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + const auto &first_move_list = iter->second; + if (index >= first_move_list.size()) { + MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; + } + return first_move_list[index]; +} + +size_t MemSwapManager::BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const { + auto need_swap_kernel = QueryKernelByTopoOrder(mem_swap_info.topo_order_); + const PerformPair &perform_pair = QueryKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_); + float swap_in_cost_time = perform_pair.second; + size_t swap_out_pos = mem_swap_info.swap_out_pos_; + auto &kernel_exec_info = SearchKernelExecutionInfo(trigger_kernel); + size_t trigger_kernel_pos = kernel_exec_info.topo_order_; + float kernel_execution_time = 0; + + size_t pos = trigger_kernel_pos; + for (; pos > swap_out_pos + 1; pos--) { + auto kernel = QueryKernelByTopoOrder(pos - 1); + if (QueryKernelTriggerSwapIn(kernel)) { + return pos; + } + kernel_execution_time += QueryKernelExecutionPerform(QueryKernelByTopoOrder(pos)); + if (kernel_execution_time >= swap_in_cost_time) { + return pos - 1; + } + } + return pos; +} + +void MemSwapManager::MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info) { + if (des_pos == src_pos) { + MS_LOG(EXCEPTION) << "destination pos can not equal source pos"; + } + auto des_kernel = QueryKernelByTopoOrder(des_pos); + auto src_kernel = QueryKernelByTopoOrder(src_pos); + AddKernelMemSwapInfo(des_kernel, mem_swap_info); + RemoveKernelMemSwapInfo(src_kernel, mem_swap_info); +} + KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); auto iter = kernel_execution_info_.find(kernel.get()); @@ -234,25 +443,53 @@ void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float p kernel_exec_info.execution_perform_ = perform; } -void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.trigger_swap_ = trigger_swap; -} - -void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.need_swap_ = need_swap; -} - void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, const std::pair &perform) { MS_EXCEPTION_IF_NULL(kernel); - kernel_swap_perform_[kernel.get()][output_idx] = perform; + auto iter = kernel_swap_perform_.find(kernel.get()); + if (iter == kernel_swap_perform_.end()) { + kernel_swap_perform_[kernel.get()][output_idx] = perform; + } } void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { MS_EXCEPTION_IF_NULL(kernel); - mem_swap_info_[kernel.get()].push_back(mem_swap_info); + (void)mem_swap_info_map_[kernel.get()].insert(mem_swap_info); + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { + kernel_exec_info.trigger_swap_out_ = true; + } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + kernel_exec_info.swap_in_task_num_++; + kernel_exec_info.trigger_swap_in_ = true; + } +} + +void MemSwapManager::RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { + MS_EXCEPTION_IF_NULL(kernel); + if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + auto map_iter = mem_swap_info_map_.find(kernel.get()); + if (map_iter == mem_swap_info_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + MemSwapInfoSet &mem_swap_info_set = map_iter->second; + + auto set_iter = mem_swap_info_set.find(mem_swap_info); + if (set_iter == mem_swap_info_set.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information in mem swap info set"; + } + mem_swap_info_set.erase(set_iter); + + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + if (kernel_exec_info.swap_in_task_num_ > 0) { + kernel_exec_info.swap_in_task_num_--; + } + if (kernel_exec_info.swap_in_task_num_ == 0) { + kernel_exec_info.trigger_swap_in_ = false; + } + if (mem_swap_info_set.empty()) { + (void)mem_swap_info_map_.erase(kernel.get()); + } + } } float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { @@ -262,12 +499,29 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.trigger_swap_; + return kernel_exec_info.trigger_swap_out_ || kernel_exec_info.trigger_swap_in_; +} + +bool MemSwapManager::QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.trigger_swap_in_; } -bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { +size_t MemSwapManager::QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const { const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.need_swap_; + return kernel_exec_info.swap_in_task_num_; +} + +const AnfNodePtr MemSwapManager::QueryKernelByTopoOrder(size_t index) const { + if (index >= execution_order_.size()) { + MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; + } + return execution_order_[index]; +} + +size_t MemSwapManager::QueryKernelTopoOrder(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.topo_order_; } const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { @@ -286,30 +540,68 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern return iter_output->second; } -const std::vector &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { +const MemSwapInfoSet &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); - auto iter = mem_swap_info_.find(kernel.get()); - if (iter == mem_swap_info_.end()) { - MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + auto iter = mem_swap_info_map_.find(kernel.get()); + if (iter == mem_swap_info_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; } return iter->second; } -void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } - -bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { - auto iter = swap_in_blacklist_.find(device_ptr); - return iter != swap_in_blacklist_.end(); +void MemSwapManager::AssignHostMemory() { + for (auto &kernel_exec_info_pair : kernel_execution_info_) { + auto &kernel_exec_info = kernel_exec_info_pair.second; + auto &host_addrs_map = kernel_exec_info.host_addrs_; + for (auto &host_addr_pair : host_addrs_map) { + auto &host_addr = host_addr_pair.second.first; + auto ret = AllocHostPinnedMem(host_addr.size, reinterpret_cast(&host_addr.addr)); + if (!ret) { + MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << host_addr.size << "] failed."; + } + host_addrs_list_.push_back(host_addr); + } + } } -const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { +const HostAddress &MemSwapManager::QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const { auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); auto &host_addrs = kernel_exec_info.host_addrs_; auto iter = host_addrs.find(output_idx); if (iter == host_addrs.end()) { MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; } - return iter->second; + return (iter->second).first; +} + +void MemSwapManager::AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + (iter->second).second = dirty; +} + +bool MemSwapManager::QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return (iter->second).second; +} + +void MemSwapManager::ResetHostAddrIsDirty() { + for (auto &kernel_exec_info_pair : kernel_execution_info_) { + auto &kernel_exec_info = kernel_exec_info_pair.second; + auto &host_addrs = kernel_exec_info.host_addrs_; + for (auto &host_addr : host_addrs) { + host_addr.second.second = true; + } + } } bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { @@ -325,19 +617,69 @@ void MemSwapManager::ReleaseHostPinnedMem() { host_addrs_list_.clear(); } -void MemSwapManager::ClearSwapQueue() const { mem_copy_manager_->ClearSwapQueue(); } +void MemSwapManager::ClearSwapQueue(bool mock) const { + if (!mock) { + mem_copy_manager_->ClearSwapQueue(); + } else { + mem_copy_manager_->ClearSwapQueueMock(); + } +} void MemSwapManager::ResetSwapInfo() { - ClearSwapQueue(); + ClearSwapQueue(true); for (auto &kernel_exec_info_pair : kernel_execution_info_) { auto &kernel_exec_info = kernel_exec_info_pair.second; - kernel_exec_info.trigger_swap_ = false; - kernel_exec_info.need_swap_ = false; + kernel_exec_info.trigger_swap_out_ = false; + kernel_exec_info.trigger_swap_in_ = false; + kernel_exec_info.swap_in_task_num_ = 0; kernel_exec_info.host_addrs_.clear(); } - ReleaseHostPinnedMem(); - swap_in_blacklist_.clear(); - mem_swap_info_.clear(); + mem_swap_info_map_.clear(); +} + +void MemSwapManager::DumpSwapInfo() const { + for (auto &kernel : execution_order_) { + if (!QueryKernelTriggerSwap(kernel)) { + continue; + } + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + MS_LOG(WARNING) << "Trigger kernel topo order[" << kernel_exec_info.topo_order_ << "] , op name[" + << AnfAlgo::GetCNodeName(kernel) << "]"; + + const MemSwapInfoSet &mem_swap_info_set = QueryKernelMemSwapInfo(kernel); + for (auto &mem_swap_info : mem_swap_info_set) { + if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { + MS_LOG(WARNING) << " Swap Out Task: swapped kernel topo order[" << mem_swap_info.topo_order_ << "], op name[" + << AnfAlgo::GetCNodeName(QueryKernelByTopoOrder(mem_swap_info.topo_order_)) << "], output idx[" + << mem_swap_info.output_idx_ << "]"; + } else { + MS_LOG(WARNING) << " Swap In Task: swapped kernel topo order[" << mem_swap_info.topo_order_ << "], op name[" + << AnfAlgo::GetCNodeName(QueryKernelByTopoOrder(mem_swap_info.topo_order_)) << "], output idx[" + << mem_swap_info.output_idx_ << "]"; + } + } + } +} + +void MemSwapManager::DumpUserNodes() const { + for (auto &kernel : execution_order_) { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + const auto &node_users_map = kernel_exec_info.node_users_map_; + MS_LOG(WARNING) << "Kernel topo order[" << kernel_exec_info.topo_order_ << "], op name[" + << AnfAlgo::GetCNodeName(kernel) << "]"; + if (node_users_map.empty()) { + MS_LOG(WARNING) << " Kernel does not own any user node"; + } + + for (auto &item : node_users_map) { + size_t output_idx = item.first; + auto &node_users = item.second; + for (auto &order : node_users) { + MS_LOG(WARNING) << " Output index[" << output_idx << "] tensor is used by kernel[" + << AnfAlgo::GetCNodeName(QueryKernelByTopoOrder(order)) << "], topo order[" << order << "]"; + } + } + } } } // namespace memswap } // namespace device diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h index d8620c8516..fa2da8d721 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_SWAP_MANAGER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_SWAP_MANAGER_H_ #include #include @@ -32,7 +32,11 @@ namespace memswap { class MemSwapManager { public: explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) - : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { + : tensor_size_threshold_(0), + tensor_size_threshold_idx_(0), + tensor_size_num_(1), + distance_threshold_(1), + distance_decay_step_(1) { mem_copy_manager_ = mem_copy_manager; } @@ -42,24 +46,23 @@ class MemSwapManager { ~MemSwapManager() = default; - void Init(const mindspore::session::KernelGraph *kernel_graph); + bool Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size = 0); - void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, - const HostAddress &host_address) const; + void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, const HostAddress &host_address, + bool mock, bool profiling = false, float *cost_time = nullptr) const; bool SyncMemCopyStream(SwapKind swap_kind) const; - DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; + DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind, bool mock) const; - // retreat to find a workable swap scheme bool RetreatSwapInfo(); + void AdjustSwapInPos(const AnfNodePtr &kernel, size_t index); + bool trigger_swap() const { return trigger_swap_; } bool mem_swap_init() const { return mem_swap_initialized_; } - KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const; - void AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform); float QueryKernelExecutionPerform(const AnfNodePtr &kernel) const; @@ -70,53 +73,90 @@ class MemSwapManager { bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; - bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; + bool QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const; + + size_t QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const; + + const AnfNodePtr QueryKernelByTopoOrder(size_t index) const; + + size_t QueryKernelTopoOrder(const AnfNodePtr &kernel) const; - const std::vector &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; + const MemSwapInfoSet &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; - void InsertSwapInBlackList(const void *device_ptr); + void AssignHostMemory(); - bool FindInSwapInBlackList(const void *device_ptr) const; + const HostAddress &QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const; - const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; + void AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty); + + bool QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const; + + void ResetHostAddrIsDirty(); bool AllocHostPinnedMem(size_t size, void **addr) const; void ReleaseHostPinnedMem(); - void ClearSwapQueue() const; + void ClearSwapQueue(bool mock) const; + + void DumpSwapInfo() const; + + void DumpUserNodes() const; private: + KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const; + void AddSwapInfo(); void ResetSwapInfo(); void SaveUserKernelTopoOrder(); - void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); + bool InitSwapThreshold(size_t swap_mem_size); + + void RetreatSwapThreshold(); + + void CacheCurSwapInfoSet(const AnfNodePtr &kernel); + + void AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time); + + bool QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const; - void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); + size_t BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const; + + void MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info); void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + void RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + + bool CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const; + + std::vector> CheckDistanceBetweenKernelsWithIdx(const TensorInfo &tensor_info) const; + bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; std::vector execution_order_; std::vector ordered_tensors_; std::unordered_map kernel_execution_info_; std::unordered_map> kernel_swap_perform_; - // trigger swap kernel key : MemSwapInfo of kernel need to be swapped - std::unordered_map> mem_swap_info_; + // Key: trigger swap kernel, value: MemSwapInfoSet of kernel need to be swapped + std::unordered_map mem_swap_info_map_; std::vector host_addrs_list_; - std::unordered_set swap_in_blacklist_; + + // Key: cache kernel address, value: lists of first time move pos or not + std::map> kernel_first_move_cache_map_; + std::vector mem_swap_info_cache_list_; + std::pair best_and_cur_pos_cache_; size_t tensor_size_threshold_; size_t tensor_size_threshold_idx_; size_t tensor_size_num_; size_t distance_threshold_; + size_t distance_decay_step_; MemCopyManagerPtr mem_copy_manager_{nullptr}; - FuncGraphManagerPtr graph_manager_{nullptr}; + const mindspore::session::KernelGraph *kernel_graph_{nullptr}; bool mem_swap_initialized_{false}; bool swap_info_already_set_{false}; bool trigger_swap_{false}; @@ -129,4 +169,4 @@ using MemSwapManagerPtr = std::shared_ptr; } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_SWAP_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc index 900dd0d563..1bdf464d26 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc @@ -20,7 +20,7 @@ #include #include "frontend/operator/ops.h" #include "utils/utils.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/log_adapter.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_graph.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h index 7e3fbdb472..c6e31d6ce1 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H_ #include #include "backend/session/kernel_graph.h" @@ -26,4 +26,4 @@ void AddAtomicClean(const std::shared_ptr &kernel_graph); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc index 133a7e764a..f2062b6f39 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc @@ -16,6 +16,8 @@ #include "backend/optimizer/pass/common_subexpression_elimination.h" #include #include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/flags.h" namespace mindspore { namespace opt { @@ -33,48 +35,60 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { } return false; } + +bool HasSideEffectAttr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) { + return false; + } + return AnfAlgo::GetNodeAttr(cnode, GRAPH_FLAG_SIDE_EFFECT); +} } // namespace -bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { +bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); - bool replace = false; if (main->isa() && node->isa()) { auto main_value = GetValueNode(main); auto node_value = GetValueNode(node); if (main_value->isa() && node_value->isa()) { - replace = false; + return false; } else if (main_value->isa() && node_value->isa()) { - replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); + return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); } else { - replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); } } else if (main->isa() && node->isa()) { + if (check_side_effect && HasSideEffectAttr(main)) { + return false; + } if (!CheckEqualKernelBuildInfo(main, node)) { - replace = false; - } else { - auto c_main = main->cast(); - MS_EXCEPTION_IF_NULL(c_main); - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - const auto &inp1 = c_main->inputs(); - const auto &inp2 = c_node->inputs(); - if (inp1.size() == inp2.size()) { - bool appsame = true; - for (size_t j = 0; j < inp1.size(); j++) { - MS_EXCEPTION_IF_NULL(inp1[j]); - MS_EXCEPTION_IF_NULL(inp2[j]); - if (!(*inp1[j] == *inp2[j])) { - appsame = false; - break; - } - } - replace = appsame; + return false; + } + auto c_main = main->cast(); + MS_EXCEPTION_IF_NULL(c_main); + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() != inp2.size()) { + return false; + } + for (size_t j = 0; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + return false; } } + return true; } - return replace; + return false; } bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) { diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h index bac870e59f..aab8eee5a7 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ #include "backend/optimizer/common/pass.h" #include "frontend/optimizer/cse.h" @@ -36,4 +36,4 @@ class BackendCSE : public CSE { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc index 3ba055880c..024f69e843 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -16,10 +16,11 @@ #include "backend/optimizer/pass/communication_op_fusion.h" #include +#include #include #include -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "frontend/operator/ops.h" #include "runtime/device/kernel_info.h" #include "backend/session/anf_runtime_algorithm.h" @@ -89,6 +90,13 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) { } return group + op + std::to_string(fusion); } + +void CheckInputs(const std::vector &fusion_inputs) { + std::set inputs_set(fusion_inputs.begin(), fusion_inputs.end()); + if (inputs_set.size() < fusion_inputs.size()) { + MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input"; + } +} } // namespace bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, @@ -100,7 +108,10 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic auto parallel_context = parallel::ParallelContext::GetInstance(); MS_EXCEPTION_IF_NULL(parallel_context); - const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); + std::vector split_indices; + if (!parallel_context->enable_parallel_optimizer()) { + split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); + } size_t segments = 0; if (split_indices.size() != 0) { @@ -160,6 +171,7 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr MS_EXCEPTION_IF_NULL(cnode); fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); } + CheckInputs(fusion_inputs); AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); MS_EXCEPTION_IF_NULL(fused_node); auto kernel_info = std::make_shared(); @@ -169,9 +181,6 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr for (size_t idx = start_index; idx <= end_index; ++idx) { auto cnode = communication_op_info.communication_op_nodes[idx]; MS_EXCEPTION_IF_NULL(cnode); - AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node); - AnfAlgo::CopyNodeAttr("op", cnode, fused_node); - AnfAlgo::CopyNodeAttr("group", cnode, fused_node); abstract_list.push_back(cnode->abstract()); } auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); @@ -179,6 +188,13 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr auto abstract_tuple = std::make_shared(abstract_list); MS_EXCEPTION_IF_NULL(abstract_tuple); fused_node->set_abstract(abstract_tuple); + auto final_node = communication_op_info.communication_op_nodes[end_index]; + AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node); + AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node); + AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node); + if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node)) { + AnfAlgo::CopyNodeAttr(kAttrRankSize, final_node, fused_node); + } return fused_node; } diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h index 0e7cf9762d..446b214c1f 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ #include #include #include @@ -77,4 +77,4 @@ class ReduceScatterFusion : public CommunicationOpFusion { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc index 814ad9567c..16f0b18711 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -26,6 +26,7 @@ namespace opt { ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimCast->name(), {1}); Register(prim::kPrimAvgPoolGrad->name(), {0}); + Register(prim::kPrimAvgPoolGradVm->name(), {0}); Register(prim::kPrimConv2DBackpropInput->name(), {2}); Register(prim::kPrimConv2DBackpropFilter->name(), {2}); Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); @@ -46,6 +47,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimCumSum->name(), {1}); Register(prim::kPrimCumProd->name(), {1}); Register(prim::kPrimReduceAll->name(), {1}); + Register(prim::kPrimReduceAny->name(), {1}); Register(prim::kPrimUnsortedSegmentMin->name(), {2}); Register(kSparseGatherV2, {2}); Register(kUnsortedSegmentProdOpName, {2}); diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h index bd6cac1322..5f6b659372 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace opt { @@ -75,4 +75,4 @@ struct ConstInputToAttrInfoReceiver { ::mindspore::opt::ConstInputToAttrInfoRegister(op_name) } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc index 51d399bbcd..7d4b54b23a 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc @@ -18,7 +18,7 @@ #include #include "backend/session/anf_runtime_algorithm.h" #include "ir/primitive.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/utils.h" #include "abstract/abstract_value.h" #include "backend/optimizer/common/helper.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h index 83b44d5f51..b2faa0437a 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -31,4 +31,4 @@ class ConstToAttrStridedSliceGradPass : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index f2e35351b4..2c24687c9e 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -23,7 +23,7 @@ #include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/helper.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "frontend/operator/ops.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/common_utils.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h index e6def42fa1..529bfade57 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ #include #include #include @@ -37,4 +37,4 @@ class ConvertConstInputToAttr : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc index f204841f3c..c4e3f38bef 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc @@ -19,7 +19,7 @@ #include #include -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "backend/optimizer/common/helper.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_graph.h" @@ -29,28 +29,8 @@ namespace mindspore { namespace opt { namespace { -ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} -AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { +AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { MS_EXCEPTION_IF_NULL(input_node); auto value_node = input_node->cast(); MS_EXCEPTION_IF_NULL(value_node); @@ -60,6 +40,9 @@ AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePt if (value->isa()) { tensor_ptr = ScalarToTensor(value->cast()); } else if (value->isa()) { + if (!AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } tensor_ptr = CreateTupleTensor(value->cast()); } else { MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple"; @@ -93,7 +76,7 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt for (size_t i = 0; i < inputs.size() - 1; ++i) { auto input_node = inputs[i + 1]; if (IsValueNode(input_node) || IsValueNode(input_node)) { - auto tensor_input = CreateTensorInput(kernel_graph, input_node); + auto tensor_input = CreateTensorInput(cnode, kernel_graph, input_node); if (tensor_input == nullptr) { new_inputs.push_back(input_node); continue; @@ -142,6 +125,9 @@ const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &fun if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { return nullptr; } + if (!node->isa()) { + return nullptr; + } if (AnfAlgo::IsGraphKernel(node)) { return ProcessGraphKernelOp(node); } else { diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h index 072652497a..b2854ba300 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ #include #include "ir/anf.h" +#include "utils/convert_utils.h" #include "backend/optimizer/common/optimizer.h" namespace mindspore { @@ -32,4 +33,4 @@ class ConvertConstInputToTensorInput : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_scalar_to_tensor.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_scalar_to_tensor.cc new file mode 100644 index 0000000000..61825854bc --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_scalar_to_tensor.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h" + +#include +#include +#include + +#include "ir/graph_utils.h" +#include "utils/convert_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(input_node); + if (!input_node->isa()) { + return nullptr; + } + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + return nullptr; + } + tensor::TensorPtr tensor_ptr = ScalarToTensor(value->cast()); + if (tensor_ptr == nullptr) { + MS_LOG(WARNING) << "Create tensor of" << input_node->DebugString() << "failed"; + return nullptr; + } + auto tensor_input = std::make_shared(tensor_ptr); + MS_EXCEPTION_IF_NULL(tensor_input); + tensor_input->set_abstract(tensor_ptr->ToAbstract()); + if (kernel_graph != nullptr) { + tensor_input = kernel_graph->NewValueNode(tensor_input); + kernel_graph->AddValueNodeToGraph(tensor_input); + } else { + tensor_input = MakeValueNode(tensor_input); + } + tensor_input->set_scope(input_node->scope()); + return tensor_input; +} +} // namespace + +const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return nullptr; + } + if (!node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + bool input_changed = false; + for (size_t i = 0; i < cnode->inputs().size(); ++i) { + auto new_input = CreateTensorInput(func_graph->cast(), cnode->inputs()[i]); + if (new_input != nullptr) { + cnode->set_input(i, new_input); + input_changed = true; + } + } + auto kernel_graph = func_graph->cast(); + if (kernel_graph == nullptr || !input_changed) { + return nullptr; + } + return kernel_graph->NewCNode(cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_scalar_to_tensor.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_scalar_to_tensor.h new file mode 100644 index 0000000000..520da08b45 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_scalar_to_tensor.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_ +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertConstScalarToTensor : public PatternProcessPass { + public: + explicit ConvertConstScalarToTensor(bool multigraph = true) + : PatternProcessPass("convert_const_scalar_to_tensor", multigraph) {} + ~ConvertConstScalarToTensor() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc index b96a7af8f3..ddb01bde93 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc @@ -27,86 +27,33 @@ namespace mindspore { namespace opt { namespace { -bool MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return false; - } - - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - TypeId infer_data_type; - if (AnfAlgo::GetOutputTensorNum(value_node) == 0) { - infer_data_type = kTypeUnknown; - } else { - infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0); - } - kernel_build_info_builder->SetOutputsDeviceType(std::vector{infer_data_type}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get()); - return true; -} - -void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node, - std::vector *plant_inputs, std::vector *dyn_input_sizes) { - MS_EXCEPTION_IF_NULL(plant_inputs); - MS_EXCEPTION_IF_NULL(dyn_input_sizes); - MS_EXCEPTION_IF_NULL(graph); - auto output_size = AnfAlgo::GetOutputTensorNum(input_node); - dyn_input_sizes->push_back(output_size); - std::vector convert_inputs; - auto kernel_graph = graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - if (input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node); - } else { - for (size_t index = 0; index < output_size; ++index) { - auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)}, - {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get()); - convert_inputs.emplace_back(tuple_get_item); - } - } - (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs)); -} - void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(graph); - auto &ori_args = cnode_ptr->inputs(); - if (ori_args.size() < 1) { - return; - } std::vector plant_inputs; std::vector dyn_input_sizes; - plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); - for (size_t i = 1; i < ori_args.size(); ++i) { - auto input_node = ori_args[i]; - if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { + plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) { + auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i); + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) { auto input_size = AnfAlgo::GetOutputTensorNum(input_node); dyn_input_sizes.push_back(input_size); - auto cnode = input_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - for (size_t j = 1; j < inputs.size(); ++j) { - MS_EXCEPTION_IF_NULL(inputs[j]); - if (IsValueNode(inputs[j])) { - auto success = MakeValueNode(inputs[j]); + auto make_tuple = input_node->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) { + auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j); + MS_EXCEPTION_IF_NULL(dyn_input_node); + if (IsValueNode(dyn_input_node)) { + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto success = kernel_graph->NewValueNode(dyn_input_node->cast()); if (!success) { - MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); + MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString(); } } - plant_inputs.push_back(inputs[j]); + plant_inputs.push_back(dyn_input_node); } - } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { - ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); } else { dyn_input_sizes.push_back(-1); plant_inputs.push_back(input_node); @@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu for (auto &t : todos) { ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast()); } - } else { - ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); } + ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); return node; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h index 63d2415dc5..34b566b0d8 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ #include #include @@ -38,4 +38,4 @@ class ConvertTupleInputToDynamicInput : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc index 34ba83ef17..8bdc234e81 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -17,6 +17,7 @@ #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "backend/optimizer/common/helper.h" @@ -25,29 +26,25 @@ namespace mindspore { namespace opt { namespace { -CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(cnode_ptr); +AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf, + std::unordered_map *transed_nodes) { + MS_EXCEPTION_IF_NULL(tuple_anf); MS_EXCEPTION_IF_NULL(graph); - std::vector convert_inputs = {cnode_ptr->input(0)}; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { - auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); - if (AnfAlgo::IsTupleOutput(input_node)) { - std::vector types; - std::vector> shapes; - std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; - for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { - make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); - types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); - shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); - } - auto make_tuple = graph->NewCNode(make_tuple_inputs_list); - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); - convert_inputs.emplace_back(make_tuple); - } else { - convert_inputs.push_back(input_node); - } + MS_EXCEPTION_IF_NULL(transed_nodes); + + if (!AnfAlgo::IsTupleOutput(tuple_anf)) { + return tuple_anf; + } + auto transed_node_it = transed_nodes->find(tuple_anf); + if (transed_node_it != transed_nodes->end()) { + return transed_node_it->second; } - return graph->NewCNode(convert_inputs); + auto kernel_graph = graph->cast(); + auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); + (*transed_nodes)[tuple_anf] = make_tuple; + // replace graph inputs if input is a parameter + kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); + return make_tuple; } } // namespace @@ -64,15 +61,24 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); + std::unordered_map transed_nodes; if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { return nullptr; } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { - return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); - })) { - return ConvertTupleInputToMakeTuple(func_graph, cnode); + bool cnode_input_changed = false; + for (size_t i = 0; i < cnode->inputs().size(); ++i) { + const auto &input = cnode->inputs()[i]; + if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && + !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { + cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes)); + cnode_input_changed = true; + } + } + auto kernel_graph = func_graph->cast(); + if (kernel_graph == nullptr || !cnode_input_changed) { + return nullptr; } - return nullptr; + return kernel_graph->NewCNode(cnode); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h index 2fb4715cff..5f5f853855 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ELIMINATE_REDUNDANT_OP_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ELIMINATE_REDUNDANT_OP_H_ #include #include @@ -46,4 +46,4 @@ class EliminateRedundantOp : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ELIMINATE_REDUNDANT_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h index 37b88a4e39..e5c47d0f1d 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h +++ b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ERASE_VISIT_ATTR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ERASE_VISIT_ATTR_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -32,4 +32,4 @@ class EraseVisitAttr : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ERASE_VISIT_ATTR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc index 32655f1ec2..59f7e0f401 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc @@ -26,7 +26,7 @@ #include "frontend/operator/ops.h" #include "utils/utils.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "backend/optimizer/common/helper.h" #include "backend/session/anf_runtime_algorithm.h" #include "vm/segment_runner.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h index 9b3916fe28..a756564b4a 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_ #include #include "backend/optimizer/common/optimizer.h" @@ -26,4 +26,4 @@ namespace opt { void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc index e04110d8a0..238d3573d0 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc @@ -27,7 +27,7 @@ #include "frontend/operator/ops.h" #include "utils/utils.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "backend/optimizer/common/helper.h" #include "backend/session/anf_runtime_algorithm.h" #include "vm/segment_runner.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h index e14661dfdf..b241122d00 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_ #include #include @@ -60,4 +60,4 @@ void ReplaceNewFuseCNode(const std::shared_ptr &kernel_gra void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select = false); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h index 9a25b924bd..ca29e53d2b 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h +++ b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_GETITEM_TUPLE_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_GETITEM_TUPLE_SPLIT_H_ #include "backend/optimizer/common/optimizer.h" @@ -29,4 +29,4 @@ class GetitemTuple : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_GETITEM_TUPLE_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc index 710e130a85..09f387d974 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -156,6 +156,5 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c } return replace_node; } - } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h index 8ddd4d662e..c1b2f53d71 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_DEPENDENCE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_DEPENDENCE_H_ #include "backend/optimizer/common/optimizer.h" @@ -31,4 +31,4 @@ class OptimizeDependence : public PatternProcessPass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_DEPENDENCE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc index cd34464cda..53faa131b1 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc @@ -71,7 +71,6 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { AbstractBasePtrList abstract_list; AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); - AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); abstract_list.push_back(cnode->abstract()); auto abstract_tuple = std::make_shared(abstract_list); diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h index 382b08304f..0ecd8f3ed2 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_REPLACE_NODE_BY_PROXY_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_REPLACE_NODE_BY_PROXY_H_ #include #include #include @@ -38,4 +38,4 @@ class ReplaceNodeByProxy : public Pass { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_REPLACE_NODE_BY_PROXY_H_ diff --git a/mindspore/ccsrc/backend/session/CMakeLists.txt b/mindspore/ccsrc/backend/session/CMakeLists.txt index b7b791ada9..b6d340a325 100644 --- a/mindspore/ccsrc/backend/session/CMakeLists.txt +++ b/mindspore/ccsrc/backend/session/CMakeLists.txt @@ -1,4 +1,5 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "kernel_build_client.cc" "kernel_graph.cc" "session_basic.cc" "session_factory.cc" diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 38c040e6b1..34471537db 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -27,7 +27,7 @@ #include "backend/optimizer/common/helper.h" #include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel_build_info.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/trans.h" namespace mindspore { @@ -221,6 +221,11 @@ std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) { } auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); MS_EXCEPTION_IF_NULL(func_graph); + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + std::string fg_name = "GraphKernel_"; + fg_name += GetValue(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + return fg_name; + } return func_graph->ToString(); } MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString(); @@ -405,7 +410,7 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod } auto node = cnode->input(input_idx + 1); MS_EXCEPTION_IF_NULL(node); - return VisitKernel(node, 0); + return VisitKernelWithReturnType(node, 0); } std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { @@ -707,6 +712,18 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz return addr; } +// get workspace device mutable addr of anf_node +DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetMutableWorkspaceAddr(index); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << "] workspace addr is not exist"; + } + return addr; +} + // set infer shapes and types of anf node void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector &types, const std::vector> &shapes, AnfNode *node) { @@ -728,7 +745,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector & for (size_t i = 0; i < types.size(); ++i) { std::vector shape_int; std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt); - abstract_list.push_back(std::make_shared(TypeIdToType(types[i]), shape_int)); + abstract_list.emplace_back(std::make_shared(TypeIdToType(types[i]), shape_int)); } auto abstract_tuple = std::make_shared(abstract_list); node->set_abstract(abstract_tuple); diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 4fa3150e36..c08819e2dc 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H -#define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H #include #include #include @@ -149,6 +149,8 @@ class AnfRuntimeAlgorithm { static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); // get workspace device addr of anf_node static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); + // get workspace device mutable addr of anf_node + static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index); // set infer shapes and types of anf node static void SetOutputInferTypeAndShape(const std::vector &types, const std::vector> &shapes, AnfNode *node); @@ -213,4 +215,4 @@ class AnfRuntimeAlgorithm { } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 274b355679..81516d9481 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -18,9 +18,12 @@ #include #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "utils/union_find_set.h" #include "runtime/device/ascend/ascend_label_assign.h" +#include "utils/ms_context.h" +#include "debug/anf_ir_dump.h" static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodeCallArg = 1; @@ -104,7 +107,7 @@ static void ReuseParameter(NotNull root_kg, static CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { for (size_t i = start; i < list.size() - 1; ++i) { - if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { + if (AnfAlgo::IsRealKernel(list[i])) { return list[i]; } } @@ -168,18 +171,43 @@ static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNullerase(exec_iter); } +void AscendControlParser::AttachChildGraphToReturnNode(NotNull graph, + const NotNull *> memo) { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + const std::vector> &child_graph_order = graph->child_graph_order(); + if (child_graph_order.empty()) { + return; + } + + std::vector depend_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; + for (auto &cg : child_graph_order) { + MS_EXCEPTION_IF_NULL(cg); + auto fg = cg->cast(); + MS_EXCEPTION_IF_NULL(fg); + depend_inputs.emplace_back(NewValueNode(fg)); + AttachChildGraphToReturnNode(NOT_NULL(cg), memo); + } + auto child_graphs = graph->NewCNode(depend_inputs); + InsertDependToGraph(graph, NOT_NULL(child_graphs)); +} + void AscendControlParser::LinkGraph(NotNull kg) { std::set memo; std::vector> link_list; // Insert Assign ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo)); + memo.clear(); // Reuse Parameter ReuseParameter(kg, link_list); // replace call by label goto / label switch - memo.clear(); (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); + memo.clear(); // assign label resource device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); + AttachChildGraphToReturnNode(kg, NOT_NULL(&memo)); } void AscendControlParser::EraseParameter(NotNull root_graph, @@ -248,10 +276,14 @@ void AscendControlParser::EraseParameter(NotNull root_graph, } MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); - - auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first; - parameter_count.AddReadCount(source, -1); + auto source = assign_node->input(kCNodeAssignSource); + MS_EXCEPTION_IF_NULL(source); + auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; parameter_count.AddWriteCount(para, -1); + parameter_count.AddReadCount(para, -1); + if (visit_source->isa()) { + parameter_count.AddReadCount(visit_source, read - 1); + } for (auto &node : all_nodes) { for (size_t i = 0; i < node->size(); ++i) { if (node->input(i) == para) { @@ -260,8 +292,6 @@ void AscendControlParser::EraseParameter(NotNull root_graph, } } } - parameter_count.AddReadCount(source, 1); - parameter_count.AddReadCount(para, -1); } root_graph->set_execution_order(exec_order); } @@ -318,6 +348,17 @@ void AscendControlParser::ExecutorValidate(NotNull root_graph) { (void)RecurseGraph(root_graph, NOT_NULL(&memo)); EraseParameter(root_graph, memo); EraseLabel(root_graph); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (context_ptr->save_graphs_flag()) { + std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir"; + DumpIR(file_path, root_graph.get()); + } } std::vector>> AscendControlParser::ParseCallNode( @@ -654,6 +695,9 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull fr } for (size_t i = 0; i < from_outputs.size(); i++) { auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + if (assign_node == nullptr) { + continue; + } const auto &from_graph_exe_order = from_graph->execution_order(); std::vector real_exe_order(from_graph_exe_order.size()); size_t real_exe_order_size = 0; diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index ac24735139..3c370fa500 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H #include #include @@ -23,7 +23,7 @@ #include #include #include "backend/session/kernel_graph.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" #include "utils/contract.h" #include "utils/union_find_set.h" @@ -66,7 +66,8 @@ class AscendControlParser { static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); static std::vector>> ParseCallNode(NotNull call_node); static std::tuple> ParsePartial(NotNull node); - + static void AttachChildGraphToReturnNode(NotNull graph, + const NotNull *> memo); // root graph order static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, NotNull graph); @@ -89,4 +90,4 @@ class AscendControlParser::ReferenceCounter { } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc index d251eb2039..fd0bec6d5e 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include #include "backend/session/ascend_inference_session.h" #include "frontend/operator/ops.h" #include "ir/tensor.h" @@ -20,9 +22,8 @@ #include "ir/param_value.h" #include "runtime/device/kernel_runtime.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/trans.h" -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" #include "utils/config_manager.h" #include "utils/base_ref_extends.h" @@ -74,7 +75,7 @@ GraphId AscendInferenceSession::CompileGraph(NotNull func_graph) { if (AnfAlgo::IsParameterWeight(pk_node)) { const auto ¶m_value = pk_node->default_param(); MS_EXCEPTION_IF_NULL(param_value); - auto tensor = std::dynamic_pointer_cast(param_value->value()); + auto tensor = std::dynamic_pointer_cast(param_value); MS_EXCEPTION_IF_NULL(tensor); if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), LongToSize(tensor->data().nbytes()), tensor->data_type(), @@ -85,5 +86,131 @@ GraphId AscendInferenceSession::CompileGraph(NotNull func_graph) { } return graph_id; } + +bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector &inputs, + std::string *error_msg) const { + MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id; + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto kernel_graph_inputs = kernel_graph->inputs(); + size_t no_weight_input = 0; + vector paras; + // find parameters of graph inputs + for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) { + if (!kernel_graph_inputs[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; + continue; + } + auto parameter = kernel_graph_inputs[i]->cast(); + if (!AnfAlgo::IsParameterWeight(parameter)) { + paras.push_back(parameter); + } + } + + // check inputs + for (size_t i = 0; i < paras.size(); ++i) { + // compare input number + if (paras.size() != inputs.size()) { + MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size() + << "] but the graph input number is [" << paras.size() << "]"; + MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs); + if (error_msg != nullptr) { + std::stringstream str_stream; + str_stream << "Input number is inconsistent. The given input number [" << inputs.size() + << "] but the graph input number is [" << paras.size() << "]\n"; + str_stream << "InputsInfo --" << InputsInfo(paras, inputs); + *error_msg = str_stream.str(); + } + return false; + } + auto input = inputs[no_weight_input++]; + if (!CompareInput(input, paras[i])) { + MS_LOG(ERROR) << "Please check the input information."; + MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs); + if (error_msg != nullptr) { + std::stringstream str_stream; + str_stream << "Please check the input information.\n"; + str_stream << "InputsInfo --" << InputsInfo(paras, inputs); + *error_msg = str_stream.str(); + } + return false; + } + } + return true; +} + +bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const { + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(parameter); + // compare dims + auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); + + // compare shape + auto input_shape = input->shape(); + vector trans_input; + (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input), + [](const int dim) { return static_cast(dim); }); + if (trans_input != parameter_shape) { + MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input) + << ", but the parameter shape is " << PrintInputShape(parameter_shape) + << ". parameter : " << parameter->DebugString(); + return false; + } + + // compare data type + auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); + if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) { + MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type() + << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0) + << ". parameter : " << parameter->DebugString(); + return false; + } + return true; +} + +template +std::string AscendInferenceSession::PrintInputShape(std::vector shape) const { + string res = "["; + for (auto dim : shape) { + res += " " + std::to_string(dim); + } + return res + " ]"; +} + +std::string AscendInferenceSession::InputsInfo(const std::vector ¶s, + const std::vector &inputs) const { + const std::map dtype_name_map{ + {TypeId::kNumberTypeBegin, "Unknown"}, {TypeId::kNumberTypeBool, "Bool"}, + {TypeId::kNumberTypeFloat64, "Float64"}, {TypeId::kNumberTypeInt8, "Int8"}, + {TypeId::kNumberTypeUInt8, "Uint8"}, {TypeId::kNumberTypeInt16, "Int16"}, + {TypeId::kNumberTypeUInt16, "Uint16"}, {TypeId::kNumberTypeInt32, "Int32"}, + {TypeId::kNumberTypeUInt32, "Uint32"}, {TypeId::kNumberTypeInt64, "Int64"}, + {TypeId::kNumberTypeUInt64, "Uint64"}, {TypeId::kNumberTypeFloat16, "Float16"}, + {TypeId::kNumberTypeFloat32, "Float32"}, + }; + auto data_type_to_string = [&dtype_name_map](TypeId type_id) { + auto it = dtype_name_map.find(type_id); + if (it == dtype_name_map.end()) { + return std::string("Unknown"); + } + return it->second; + }; + + std::string graph = "graph inputs:{ "; + for (size_t i = 0; i < paras.size(); ++i) { + auto ¶ = paras[i]; + graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(para, 0).size()) + + ", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(para, 0)) + ", data type " + + data_type_to_string(AnfAlgo::GetSelectKernelBuildInfo(para)->GetOutputDeviceType(0)) + " }"; + } + + std::string actual = "given inputs:{ "; + for (size_t i = 0; i < inputs.size(); ++i) { + actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " + + PrintInputShape(inputs[i]->shape()) + ", data type " + data_type_to_string(inputs[i]->data_type()) + " }"; + } + return graph + " " + actual; +} + } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h index 5364ae8d4e..d092b3ccb3 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.h +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_INFERENCE_SESSION_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_INFERENCE_SESSION_H #include #include #include @@ -39,8 +39,14 @@ class AscendInferenceSession : public AscendSession { void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; GraphId CompileGraph(NotNull func_graph) override; + bool CheckModelInputs(uint32_t graph_id, const std::vector &inputs, + std::string *error_msg) const override; + bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const; + template + std::string PrintInputShape(std::vector shape) const; + std::string InputsInfo(const std::vector ¶s, const std::vector &inputs) const; }; MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_INFERENCE_SESSION_H diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 75bc4e2d05..59fc846759 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -34,16 +34,14 @@ #include "runtime/device/kernel_adjust.h" #include "runtime/device/ascend/ascend_stream_assign.h" #include "runtime/device/ascend/ascend_label_assign.h" -#include "predict/predict.h" #include "backend/session/anf_runtime_algorithm.h" #include "ir/scalar.h" #include "debug/anf_ir_dump.h" #include "debug/anf_ir_utils.h" #include "debug/draw.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/optimizer/common/helper.h" #include "runtime/device/kernel_runtime_manager.h" -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" #include "utils/config_manager.h" #include "utils/base_ref_extends.h" #include "debug/tensor_load.h" @@ -82,24 +80,6 @@ void DumpGraphExeOrder(const std::vector &execution_order, const std:: i++; } buf << "================== execution order ==================\n"; - // std::cout << buf.str() << std::endl; -} - -void DumpGraphInputArgs(const VectorRef &args) { - MS_LOG(INFO) << "Args size[%lu]" << args.size(); - for (size_t i = 0; i < args.size(); i++) { - if (utils::isa(args[i])) { - auto anf = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(anf); - MS_LOG(INFO) << "Parameter arg" << i << " = [%s]" << anf->DebugString(); - } else if (utils::isa(args[i])) { - auto value = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString(); - } else { - MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString(); - } - } } void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) { @@ -109,52 +89,6 @@ void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool } } -std::vector GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { - MS_EXCEPTION_IF_NULL(graph); - std::vector graph_inputs = graph->inputs(); - auto valid_inputs = graph->valid_inputs(); - size_t real_args_size = 0; - std::vector real_args = {}; - for (size_t i = 0; i < args.size(); i++) { - if (utils::isa(args[i])) { - auto tmp_args = AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}); - for (auto &real_arg : tmp_args) { - auto anf_node = utils::cast(real_arg); - MS_EXCEPTION_IF_NULL(anf_node); - auto abstract = anf_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && - !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - real_args_size += tuple_abstract->size(); - continue; - } - real_args_size += 1; - real_args.push_back(real_arg); - } - } else { - real_args_size += 1; - real_args.push_back(args[i]); - } - } - if (graph_inputs.size() != valid_inputs.size()) { - MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size() - << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; - } - if (real_args_size != graph_inputs.size()) { - for (size_t j = 0; j < valid_inputs.size(); j++) { - if (valid_inputs[j]) { - MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); - } - } - MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() - << " not equal"; - } - return real_args; -} - std::vector GetCNodes(const std::vector &anf_nodes) { std::vector cnodes = {}; size_t i = 0; @@ -168,128 +102,6 @@ std::vector GetCNodes(const std::vector &anf_nodes) { return cnodes; } -static std::vector> GetChildList(const std::vector &cnodes, - const std::set &cut_prims) { - size_t after_cut_index = 0; - std::vector> ret; - for (size_t i = 0; i < cnodes.size(); ++i) { - bool is_cut_node = false; - for (auto &prim : cut_prims) { - if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { - is_cut_node = true; - break; - } - } - if (is_cut_node) { - // is call and not switch call,cut to 3 lists - if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { - // if is not a call,cut to 2 lists - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); - after_cut_index = i; - } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); - ret.emplace_back(1, cnodes[i]); - after_cut_index = i + 1; - continue; - } - } - // get last child graph list - if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); - continue; - } - } - return ret; -} - -static void BindCallArgsWithParameter(const std::vector ¶meters, const std::vector &args, - const KernelGraphPtr &graph, KernelGraphPtr child_graph, - const NotNull *> memo) { - MS_EXCEPTION_IF_NULL(child_graph); - MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id(); - if (args.empty()) { - return; - } - if (parameters.size() != args.size()) { - MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() - << " and args size:" << args.size() << " not equal!"; - } - child_graph->SetExecOrderByDefault(); - for (size_t i = 0; i < parameters.size(); i++) { - MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]" - << args[i]->DebugString(); - if (args[i] == parameters[i]) { - MS_LOG(INFO) << "Parameter and arg are same."; - continue; - } - child_graph->SetRealInput(parameters[i], args[i]); - if (memo->find(child_graph) != memo->end() || !args[i]->isa()) { - MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id(); - child_graph->AddUnreuseArgs(args[i], graph); - } - } -} - -// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of -// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] -static void UpdateRealInput(NotNull graph, bool split_flag, - const NotNull *> memo) { - MS_EXCEPTION_IF_NULL(memo.get()); - auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); - if (child_graphs.size() == 1) { - MS_EXCEPTION_IF_NULL(child_graphs[0]); - std::vector real_args = - std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()); - std::vector child_inputs = child_graphs[0]->inputs(); - BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo); - if (split_flag) { - call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); - } - } else if (child_graphs.size() == 2) { - auto get_partial_args = [&](size_t input_index) -> std::vector { - auto switch_node = call_node->input(1); - MS_EXCEPTION_IF_NULL(switch_node); - auto switch_cnode = switch_node->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - auto partial = switch_cnode->input(input_index); - MS_EXCEPTION_IF_NULL(partial); - if (IsValueNode(partial)) { - return {}; - } - auto partial_cnode = partial->cast(); - MS_EXCEPTION_IF_NULL(partial_cnode); - auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); - if (split_flag) { - partial_cnode->set_inputs( - std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); - } - return ret; - }; - BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo); - BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo); - } - } -} - -static void RecurseToUpdateCallRealInput(NotNull graph, - const NotNull *> memo) { - memo->insert(graph.get()); - MS_LOG(INFO) << "Start graph id:" << graph->graph_id(); - for (auto &child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) != memo->end()) { - MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() - << ",parent graph:" << graph->parent_graph()->graph_id(); - continue; - } - RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); - } - // this action should from bottom to top - graph->UpdateCallRealInput(); -} - void InsertMakeTupleForOutput(NotNull root_graph) { auto return_node = root_graph->get_return(); MS_EXCEPTION_IF_NULL(return_node); @@ -353,6 +165,10 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { RootGraphExecutorValidate(NOT_NULL(root_graph)); // adjust kernel AdjustKernel(root_graph); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + // Assign parameter keys. + AssignParamKey(root_graph); +#endif // assign stream AssignStream(NOT_NULL(root_graph)); // insert profiling point @@ -405,8 +221,6 @@ void AscendSession::BuildGraph(GraphId graph_id) { } // insert assigns to child graph InsertAllAssigns(); - // insert switch and active to child graph - MergeSwitchCompile(); SetFinalGraphSummaryFlag(graph); // OptChildGraphs auto graph_order = GetGraphOrder(final_graph_id_); @@ -419,7 +233,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { auto child_graph = GetGraph(graph_order[i]); CompileChildGraph(child_graph); } - GetSummaryNodes(graph.get()); + SetSummaryNodes(graph.get()); // merge child graph MergeGraphExecOrder(); } else { @@ -487,8 +301,6 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; DumpIR(file_path, child_graph); } - // convert kernel Graph to model - predictmodel::StepConvertGraph(child_graph); // optimize graph HardwareOptimize(child_graph); // assign static memory of parameters @@ -511,8 +323,10 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector &kernel_gra } bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { - if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { - return true; - } - - return false; + return run_op_graphs_.find(graph_info) != run_op_graphs_.end(); } void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, @@ -598,14 +408,25 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; // malloc mem - RunOpMemoryAlloc(input_tensors, graph.get()); + RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get()); // load input data to device LoadInputData(graph, input_tensors); // run op RunOpExecTask(graph); // get output VectorRef outputs; - UpdateOutputs(graph, &outputs, input_tensors); + if (op_run_info.value != nullptr) { + std::vector pre_output_tensors; + TensorValueToTensor(op_run_info.value, &pre_output_tensors); + for (auto &pre_output : pre_output_tensors) { + tensor::TensorPtr tensor = std::make_shared(pre_output->data_type(), pre_output->shape()); + tensor->set_device_address(pre_output->device_address()); + tensor->set_dirty(false); + outputs.emplace_back(tensor); + } + } else { + UpdateOutputs(graph, &outputs, input_tensors); + } // trans output to tuple auto output_tensors = TransformBaseRefListToTuple(outputs); if (!utils::isa(output_tensors) || @@ -734,14 +555,15 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { MS_LOG(INFO) << "Finish!"; } -void AscendSession::RunOpMemoryAlloc(const std::vector &input_tensors, +void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value, + const std::vector &input_tensors, KernelGraph *kernel_graph) const { MS_LOG(INFO) << "Start memory alloc!"; MS_EXCEPTION_IF_NULL(kernel_graph); opt::RemoveNopNode(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); + runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph); MS_LOG(INFO) << "Finish!"; } @@ -834,55 +656,14 @@ void AscendSession::LoadTensor(const std::shared_ptr &kernel_graph) MS_LOG(INFO) << "Finish!"; } -GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { - MS_LOG(INFO) << "Start! Args size " << args.size(); - auto final_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(final_graph); - final_graph_id_ = final_graph->graph_id(); - MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success"; - // init private variables and bind them with final_graph_id - graph_execute_orders_[final_graph_id_] = std::vector(); - graph_order_types_[final_graph_id_] = std::vector(); - for (const auto ¶meter : args) { - MS_EXCEPTION_IF_NULL(parameter); - if (!parameter->isa()) { - MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!"; - } - AnfNodePtr parameter_backend = nullptr; - // if function return UINT_MAX,the parameter is not exist in child graph - auto parameter_belong_graph_id = GetGraphIdByNode(parameter); - if (parameter_belong_graph_id == kInvalidGraphId) { - parameter_backend = CreateNewParameterFromParameter(parameter, true, final_graph.get()); - final_graph->FrontBackendlMapAdd(parameter, parameter_backend); - MS_LOG(INFO) << "New parameter" << parameter->DebugString() << "in final_graph"; - } else { - // parametr is a parameter of child graph - auto graph = GetGraph(parameter_belong_graph_id); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph [" - << parameter_belong_graph_id << "]"; - parameter_backend = graph->GetBackendAnfByFrontAnf(parameter); - // add parameter in backend to final graph inputs - auto final_graph_inputs = final_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(final_graph_inputs); - final_graph_inputs->push_back(parameter_backend); - } - MS_EXCEPTION_IF_NULL(parameter_backend); - MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id " - << AnfAlgo::GetGraphId(parameter_backend.get()); - } - MS_LOG(INFO) << "End final_graph_id " << final_graph_id_; - return final_graph_id_; -} - -void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, +void AscendSession::RecurseSetSummaryNodes(KernelGraph *graph, std::map> *summary) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(summary); // if final graph have no child graph auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); if (graph_order_iter == graph_execute_orders_.end()) { - SessionBasic::GetSummaryNodes(graph); + SessionBasic::SetSummaryNodes(graph); auto summary_nodes = graph->summary_nodes(); summary->insert(summary_nodes.begin(), summary_nodes.end()); return; @@ -894,293 +675,25 @@ void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, if (child_graph == nullptr) { continue; } - SessionBasic::GetSummaryNodes(child_graph.get()); + SessionBasic::SetSummaryNodes(child_graph.get()); auto child_graph_summary = child_graph->summary_nodes(); summary->insert(child_graph_summary.begin(), child_graph_summary.end()); - RecurseGetSummaryNodes(child_graph.get(), summary); + RecurseSetSummaryNodes(child_graph.get(), summary); } graph->set_summary_nodes(*summary); } -void AscendSession::GetSummaryNodes(KernelGraph *graph) { +void AscendSession::SetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); auto summary_nodes = graph->summary_nodes(); std::map> summary; summary.insert(summary_nodes.begin(), summary_nodes.end()); - RecurseGetSummaryNodes(graph, &summary); + RecurseSetSummaryNodes(graph, &summary); graph->set_summary_nodes(summary); MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); } -AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { - auto fake_graph = GetGraph(fake_graph_id); - MS_EXCEPTION_IF_NULL(fake_graph); - auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr { - auto parameter = fake_graph->NewParameter(); - MS_EXCEPTION_IF_NULL(parameter); - parameter->set_abstract(abstract); - auto new_parameter = fake_graph->NewParameter(parameter); - // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory. - auto graph_inputs = fake_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->push_back(new_parameter); - return new_parameter; - }; - auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr { - MS_EXCEPTION_IF_NULL(cnode); - auto abstract = cnode->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa()) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]"; - return create_parameter((*tuple_abstract)[output_idx]); - } - return create_parameter(cnode->abstract()); - }; - if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) { - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - auto make_tuple = output_item_with_index.first->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - for (size_t i = 1; i < make_tuple->inputs().size(); i++) { - auto input = make_tuple->inputs()[i]; - make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input)); - } - return fake_graph->NewCNode(make_tuple_inputs); - } - return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); -} - -void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { - // get the backend anf node related to the output node of front - auto output_from_graph_id = GetGraphIdByNode(node); - auto output_from_graph = GetGraph(output_from_graph_id); - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id - << "] to final graph"; - MS_EXCEPTION_IF_NULL(output_from_graph); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - // if output is from final graph,it remarks no child graph exist - if (final_graph_id_ == output_from_graph_id) { - MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); - final_graph->set_output(ConstructOutput({node}, final_graph)); - final_graph->set_executable(false); - return; - } - final_graph->set_output(output_from_graph->output()); -} - -void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { - auto value_node = NewValueNode(value); - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - value_node->set_abstract(abstract::FromValue(value)); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); - final_graph->set_executable(false); - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; -} - -void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { - for (auto &output : vec_output) { - if (utils::isa(output)) { - auto output_anf_node = utils::cast(output); - SetFinalGraphOutput(output_anf_node); - } else if (utils::isa(output)) { - auto value = utils::cast(output); - SetFinalGraphOutput(value); - } else { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } - } -} - -void AscendSession::SetFinalGraphOutput(const BaseRef &output) { - if (utils::isa(output)) { - auto output_anf_node = utils::cast(output); - SetFinalGraphOutput(output_anf_node); - } else if (utils::isa(output)) { - auto value = utils::cast(output); - SetFinalGraphOutput(value); - } else if (utils::isa(output)) { - auto vec_output = utils::cast(output); - SetFinalGraphOutput(vec_output); - } else { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } -} - -void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { - MS_LOG(INFO) << "Start!"; - MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; - auto condition_graph = GetGraph(condition_graph_id); - MS_EXCEPTION_IF_NULL(condition_graph); - tensor::TensorPtr tensor = std::make_shared(kNumberTypeInt32, std::vector{1}); - int32_t *val = nullptr; - val = static_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - auto value_node = std::make_shared(tensor); - value_node->set_abstract(abstract::FromValue(tensor, false)); - auto counter_const = condition_graph->NewValueNode(value_node); - condition_graph->AddValueNodeToGraph(counter_const); - // create a new switch op - auto switch_primitive = std::make_shared("StreamSwitch"); - auto cond_output_it = condition_output_.find(condition_graph_id); - if (cond_output_it == condition_output_.end()) { - MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; - } - auto cond_output_kernel = - AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; - MS_EXCEPTION_IF_NULL(cond_output_kernel); - std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; - CNodePtr switch_node = condition_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(switch_node); - switch_node->set_abstract(std::make_shared()); - AnfAlgo::SetGraphId(condition_graph_id, switch_node.get()); - // set attr: cond_ RT_GREATER - AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue(static_cast(RT_GREATER)), switch_node); - // set attr:data_type - AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue(static_cast(RT_SWITCH_INT64)), switch_node); - // set attr:true branch graph id ,which is same to stream distinction label - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(true_graph_id), switch_node); - // append switch at the end of condition graph - auto return_node = condition_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - InsertControlDependToGraph(condition_graph_id, return_node->input(kReturnDataIndex), switch_node); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { - auto &graph_execute_order = GetGraphOrder(final_graph_id_); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id); - if (false_index == kInvalidIndex || false_index == 0) { - return; - } - for (int i = SizeToInt(false_index) - 1; i >= 0; i--) { - size_t graph_index = IntToSize(i); - if (graph_index >= graph_execute_order.size()) { - MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]"; - } - if (graph_order_type[graph_index] == COMMON_GRAPH) { - auto true_last_id = graph_execute_order[graph_index]; - MS_LOG(INFO) << "The last graph of if true branch is " << true_last_id; - auto true_last = GetGraph(true_last_id); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - auto false_last = GetGraph(false_graph_id); - MS_EXCEPTION_IF_NULL(true_last); - MS_EXCEPTION_IF_NULL(false_last); - MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id; - // create fake output - auto fake_output_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(fake_output_graph); - graph_execute_order.push_back(fake_output_graph->graph_id()); - graph_order_type.push_back(COMMON_GRAPH); - fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output())); - final_graph->set_output(fake_output_graph->output()); - InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output()); - InsertMultipleAssignToGraph(false_graph_id, false_last->output(), final_graph->output()); - // insert stream active for loop sink - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1) { - // insert active in true graph, another active will be inserted in kernel adjust - InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); - } - break; - } - } -} - -void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, - const AnfNodePtr &output) { - if (switches_.find(cond_graph_id) != switches_.end()) { - MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; - return; - } - switches_[cond_graph_id] = std::pair(true_graph_id, false_graph_id); - condition_output_[cond_graph_id] = output; - MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; - // set the type of condition graph - auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - if (cond_graph_index >= graph_order_type.size()) { - MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size(); - } - graph_order_type[cond_graph_index] = CONDITION_GRAPH; - // update distinction label of false graph,update before merge to sure the distinction - if (false_graph_id != kInvalidGraphId) { - // false graph and condition in graph same stream - auto condition_graph = GetGraph(cond_graph_id); - MS_EXCEPTION_IF_NULL(condition_graph); - SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); - // if false graph is a condition graph and has been switch compiled before,it's false should be updated again - auto cond_it = switches_.find(false_graph_id); - while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) { - cond_graph_id = cond_it->first; - false_graph_id = cond_it->second.second; - condition_graph = GetGraph(cond_graph_id); - if (condition_graph == nullptr) { - continue; - } - SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); - cond_it = switches_.find(false_graph_id); - } - } -} // namespace session - -void AscendSession::MergeSwitchCompile() { - auto graph_execute_order = GetGraphOrder(final_graph_id_); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - for (auto switch_compile : switches_) { - auto cond_graph_id = switch_compile.first; - auto true_graph_id = switch_compile.second.first; - auto false_graph_id = switch_compile.second.second; - MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; - auto condition_graph = GetGraph(cond_graph_id); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(condition_graph); - MS_EXCEPTION_IF_NULL(final_graph); - // insert switch to condition graph - InsertSwitchToGraph(cond_graph_id, true_graph_id); - auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); - auto prev_graph_id = kInvalidGraphId; - // if condition graph is the first graph and final graph has assign op,then the final graph is the common graph - if (cond_graph_index == 0 && !final_graph->execution_order().empty()) { - prev_graph_id = final_graph_id_; - // set the distinction label of final graph - SetStreamDistinctionLabel(final_graph, final_graph_id_, true); - // if condition graph is not the first graph - } else if ((cond_graph_index - 1 < graph_execute_order.size()) && - (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) { - prev_graph_id = graph_execute_order[cond_graph_index - 1]; - } - // insert stream active to common graph - if (prev_graph_id != kInvalidGraphId) { - InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label()); - } - // if this is a 'if' condition - auto it = while_condition_graphs_.find(cond_graph_id); - if (it == while_condition_graphs_.end()) { - CopyOutputOfIf(false_graph_id); - } else { - // if it is a while,insert a stream active to true graph - GraphId from_graph = it->second; - InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label()); - } - } - MS_LOG(INFO) << "Finish!"; -} - void AscendSession::InsertAllAssigns() { std::vector> assigns; for (auto assign : assigns_) { @@ -1212,173 +725,6 @@ void AscendSession::InsertAllAssigns() { } } -// insert active to graph -void AscendSession::SetActive(GraphId from, GraphId to) { - if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) { - MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from " - << while_condition_graphs_[to]; - return; - } - MS_LOG(INFO) << "From " << from << " to " << to; - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - std::vector graph_order_new; - std::vector graph_type_new; - for (size_t i = 0; i < graph_order.size(); i++) { - auto graph_id = graph_order[i]; - graph_order_new.push_back(graph_id); - graph_type_new.push_back(graph_type[i]); - if (from == graph_id) { - graph_order_new.push_back(kInvalidGraphId); - graph_type_new.push_back(BRANCH_END); - } - } - graph_order = graph_order_new; - graph_type = graph_type_new; - // set the graph type of condition graph - graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH; - // record the condition graph into while condition set - while_condition_graphs_[to] = from; -} - -void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(front_anf); - auto from_graph_id = GetGraphIdByNode(front_anf); - auto from_graph = GetGraph(from_graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - MS_EXCEPTION_IF_NULL(backend_parameter); - auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); - MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" - << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) - << "]"; - // a node should not assign to itself - if (backend_arg.get() == backend_parameter.get()) { - return; - } - // if arg is the the parameter of child graph,it is parameter of final graph too - if (front_anf->isa()) { - MS_EXCEPTION_IF_NULL(backend_arg); - MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() - << "] will be replaced."; - to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); - return; - } - MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node" - << backend_parameter->DebugString() << "of graph " << to_graph_id; - assigns_.emplace_back(std::tuple(front_anf, to_graph_id, input_idx)); -} - -void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, - size_t input_idx) { - MS_LOG(INFO) << "Start!"; - std::pair graph_input_pair(to_graph_id, input_idx); - initial_tenosrs_[graph_input_pair] = front_tensor; - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::UpdateGraphOrder(GraphId to_graph_id) { - MS_LOG(INFO) << "To_graph_id " << to_graph_id; - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - for (size_t i = 0; i < graph_order.size(); i++) { - if (graph_order[i] == to_graph_id) { - return; - } - } - // if graph is not in graph order,add it to graph order - SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false); - graph_order.push_back(to_graph_id); - graph_type.push_back(COMMON_GRAPH); - for (size_t i = 0; i < graph_order.size(); i++) { - MS_LOG(INFO) << "Index " << i << ",graph_id " << graph_order[i] << ",graph_type" << graph_type[i]; - } -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto output_num = AnfAlgo::GetOutputTensorNum(node); - if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { - return input_index + output_num; - } - auto valid_inputs = graph->valid_inputs(); - if (valid_inputs[input_index]) { - SetChildGraphParameter(node, graph->graph_id(), input_index); - } else { - MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString(); - } - return ++input_index; -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString(); - } - SetChildGraphParameter(value->cast(), graph->graph_id(), input_index); - return ++input_index; -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) { - auto index = input_index; - for (auto &arg : vec_args) { - if (utils::isa(arg)) { - // arg is a anf node - auto node = utils::cast(arg); - index = SetChildGraphInput(graph, node, input_index); - } else if (utils::isa(arg)) { - // arg is a tensor - auto value = utils::cast(arg); - index = SetChildGraphInput(graph, value, input_index); - } else { - MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString(); - } - } - return index; -} - -void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { - MS_LOG(INFO) << "Set input of graph " << g; - auto to_graph = GetGraph(g); - MS_EXCEPTION_IF_NULL(to_graph); - DumpGraphInputArgs(args); - UpdateGraphOrder(g); - auto &graph_inputs = to_graph->inputs(); - auto real_args = GetRealArgs(to_graph, args); - size_t input_index = 0; - for (size_t i = 0; i < real_args.size(); i++) { - if (input_index >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size(); - } - auto &real_arg = real_args[i]; - if (utils::isa(real_arg)) { - // arg is a anf node - auto node = utils::cast(real_arg); - input_index = SetChildGraphInput(to_graph, node, input_index); - } else if (utils::isa(real_arg)) { - // arg is a tensor - auto value = utils::cast(real_arg); - input_index = SetChildGraphInput(to_graph, value, input_index); - } else if (utils::isa(real_arg)) { - // arg is a VectorRef - auto vec_args = utils::cast(real_arg); - input_index = SetChildGraphInput(to_graph, vec_args, input_index); - } else { - MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString(); - } - } - MS_LOG(INFO) << "Finish!"; -} - GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const { for (const auto &graph_item : graphs_) { auto graph = graph_item.second; @@ -1470,63 +816,10 @@ void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from MS_EXCEPTION_IF_NULL(assign_node); assign_node->set_abstract(to->abstract()); // append the assign at the end of from graph - InsertDependToGraph(graph_id, assign_node); -} - -void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { - std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); - std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); - MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to [" - << AnfAlgo::GetGraphId(to.get()) << "]"; - if (from_outputs.size() != to_outputs.size()) { - MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]"; - MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" - << to_outputs.size() << "]"; - } - for (size_t i = 0; i < from_outputs.size(); i++) { - InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]); - } -} - -void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) { - MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream; - auto from_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - std::vector inputs = {NewValueNode(std::make_shared("StreamActive"))}; - auto active_node = from_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(active_node); - active_node->set_abstract(std::make_shared()); - // set the active stream id into the attr of active node - std::vector active_index_value = {}; - active_index_value.push_back(actived_stream); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_value), active_node); - // append the active node at the end of from graph - auto return_node = from_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - InsertControlDependToGraph(graph_id, return_node->input(kReturnDataIndex), active_node); + AscendControlParser::InsertDependToGraph(NOT_NULL(graph), NOT_NULL(assign_node)); } -void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { - AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); -} - -void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, - const AnfNodePtr &second_node) { - AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), - NOT_NULL(second_node)); -} - -size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { - auto &graph_order = GetGraphOrder(final_graph); - for (size_t i = 0; i < graph_order.size(); i++) { - if (child_graph == graph_order[i]) { - return i; - } - } - return kInvalidIndex; -} - -std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) { +const std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) const { auto graph_order_iter = graph_execute_orders_.find(final_graph_id); if (graph_order_iter == graph_execute_orders_.end()) { MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph"; @@ -1534,8 +827,7 @@ std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) { return graph_order_iter->second; } -// get graph order type vector by graph id -std::vector &AscendSession::GetGraphOrderType(GraphId final_graph_id) { +const std::vector &AscendSession::GetGraphOrderType(GraphId final_graph_id) const { auto graph_type_iter = graph_order_types_.find(final_graph_id); if (graph_type_iter == graph_order_types_.end()) { MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_"; @@ -1567,85 +859,6 @@ void AscendSession::SyncInitialTenosrToDevice() { } } -static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector &list) { - // count the output of every anf node - std::set has_output_nodes; - for (auto &anf_node : list) { - MS_EXCEPTION_IF_NULL(anf_node); - for (auto &input : anf_node->inputs()) { - (void)has_output_nodes.insert(input); - } - } - - auto make_tuple_primitve = NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())); - std::vector make_tuple_inputs = {make_tuple_primitve}; - int output_idx = 0; - MS_EXCEPTION_IF_NULL(new_kernel_graph); - for (auto &anf_node : list) { - if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { - new_kernel_graph->set_return(anf_node); - } - if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString(); - make_tuple_inputs.push_back(anf_node); - } - } - if (new_kernel_graph->get_return() == nullptr) { - new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); - } -} - -std::vector AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, - const std::vector &list) { - MS_EXCEPTION_IF_NULL(new_kernel_graph); - MS_LOG(INFO) << "Start contruct splited kernel graph:" << new_kernel_graph->graph_id(); - MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); - std::vector call_node_inputs; - std::vector new_graph_inputs; - // create new parameter from cnode - for (auto &anf_node : list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto input = cnode->inputs()[input_idx]; - MS_EXCEPTION_IF_NULL(input); - AnfNodePtr new_parameter = nullptr; - // check whether input has been put into args of call, if mulptiple use of one parameter or cnode, only set one - // parameter in graph inputs and one arg in call node - auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input); - if (call_input_it != call_node_inputs.end()) { - cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]); - continue; - } - // value node consider move to new graph - if (input->isa()) { - cnode->set_input(input_idx, input); - continue; - } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { - // if is cnode and not in current child graph - new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); - cnode->set_input(input_idx, new_parameter); - } else { - // if is a cnode and in current graph - continue; - } - new_graph_inputs.push_back(new_parameter); - call_node_inputs.push_back(input); - } - } - // set graph inputs of new graph - auto graph_inputs = new_kernel_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->clear(); - std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs)); - - MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); - ConstructSplitedGraphOutput(new_kernel_graph, list); - MS_LOG(INFO) << "End"; - return call_node_inputs; -} - void AscendSession::BackendOptimization(const std::vector &all_graphs) { MS_LOG(INFO) << "Start BackendCommonOptimization"; for (auto &graph : all_graphs) { @@ -1654,155 +867,28 @@ void AscendSession::BackendOptimization(const std::vector &all_g MS_LOG(INFO) << "End."; } -void AscendSession::SplitGraphs(NotNull root_graph) { - std::set memo; - // if output of graph is nullptr,no need insert maketuple at the end of graph - if (root_graph->output() == nullptr) { - return; - } - // if root graph output is a call node ,the root graph is condition graph of 'if' sentence - auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; - if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { - SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo)); - for (auto &child_graph : root_graph->child_graph_order()) { - RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); - } - } else { - RecurseSplitGraph(root_graph, NOT_NULL(&memo)); - } - memo.clear(); - // add maketuple to the end of the last child graph to suit old process - auto output_graph = root_graph->child_graph_order().empty() ? root_graph : root_graph->child_graph_order().back(); - auto make_tuple = output_graph->NewCNode( - {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), output_graph->output()}); - output_graph->set_output(make_tuple); - // replace the real input if the real input is a call - RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); -} - -AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull graph, - const std::vector &child_graph_list) { - // if child graph list only has a call ,then return the exist call - if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { - return child_graph_list[0]; - } - // create new child graph - auto child_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(child_graph); - // create new value node to bind child graph - auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph)); - std::vector new_call_input = {NewValueNode(std::make_shared(prim::kPrimCall->name())), - graph_value_node}; - // set the graph id of all node of child graph - for (auto &child_graph_node : child_graph_list) { - AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); - } - auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); - std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); - auto new_call = graph->NewCNode(new_call_input); - AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call); - return new_call; -} - -void AscendSession::SplitGraph(NotNull graph, const std::set &cut_prims, - const NotNull *> memo) { - MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); - bool split_flag = false; - auto apply_list = GetCNodes(TopoSort(graph->get_return())); - // update the root graph child graph order - graph->UpdateChildGraphOrder(); - // get child list from current graph - std::vector> child_graph_lists = GetChildList(apply_list, cut_prims); - if (child_graph_lists.size() > 1) { - std::list depend_input = {}; - for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { - auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]); - MS_EXCEPTION_IF_NULL(call_node); - // if call node is the last call of true graph,no need create child graph after that - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); - depend_input.push_front(call_node); - if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) { - break; - } - } - depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name())))); - auto depend = graph->NewCNode(std::vector(depend_input.begin(), depend_input.end())); - auto new_return_primitive = - graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name()))); - graph->set_return(graph->NewCNode({new_return_primitive, depend})); - AnfNodePtr pre_call_node = nullptr; - AnfNodePtr cur_call_node = nullptr; - auto iter = depend_input.begin(); - for (++iter; iter != depend_input.end(); ++iter) { - pre_call_node = cur_call_node; - cur_call_node = *iter; - if (pre_call_node != nullptr && cur_call_node != nullptr) { - AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); - } - } - split_flag = true; - } - graph->UpdateChildGraphOrder(); - UpdateRealInput(graph, split_flag, memo); - MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; -} - -void AscendSession::RecurseSplitGraph(NotNull graph, const NotNull *> memo) { - memo->insert(graph.get()); - SplitGraph(graph, {prim::kPrimCall}, memo); - for (auto &child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) == memo->end()) { - RecurseSplitGraph(NOT_NULL(child_graph), memo); - } - } -} - void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } void AscendSession::RootGraphExecutorValidate(NotNull graph) { AscendControlParser::ExecutorValidate(graph); } -void AscendSession::RecurseCompileGraph(NotNull graph, const NotNull *> memo) { - memo->insert(graph.get()); - CompileChildGraph(graph); - for (auto child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) != memo->end()) { - continue; - } - RecurseCompileGraph(NOT_NULL(child_graph), memo); - // copy ref map to final graph - auto child_ref_map = child_graph->GetRefMap(); - for (auto &item : child_ref_map) { - if (graph->IsInRefOutputMap(item.first)) { - MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; - } - graph->AddRefCorrespondPairs(item.first, item.second); - } - } -} - void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNull *> memo) { if (memo->find(graph.get()) != memo->end()) { return; } memo->insert(graph.get()); - graph->UpdateChildGraphOrder(); for (auto &child_graph : graph->child_graph_order()) { CreateMultiBranchOutput(NOT_NULL(child_graph), memo); } - std::map need_replace_list; auto node_list = GetCNodes(TopoSort(graph->get_return())); for (auto &node : node_list) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output - // auto multi_output_param = graph->NewParameter(); - auto origin_inputs = graph->inputs(); - auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get()); + auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); MS_EXCEPTION_IF_NULL(graph->MutableInputs()); - graph->MutableInputs()->operator=(origin_inputs); graph->AddChildGraphResult(output_param); std::vector depend_inputs = { @@ -1815,12 +901,16 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); for (auto &child_graph : child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); + // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert + // assign from condition to true graph + if (memo->find(child_graph) != memo->end()) { + continue; + } if (child_graph->get_output_null()) { continue; } - auto graph_output = child_graph->output(); - AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output), - NOT_NULL(output_param)); + AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, + NOT_NULL(child_graph->output()), NOT_NULL(output_param)); } } } @@ -1834,6 +924,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu } } } + memo->erase(graph.get()); } void AscendSession::IrFusionPass(const NotNull graph, NotNull *> memo) { @@ -1878,11 +969,11 @@ void AscendSession::SelectKernel(NotNull root_graph) { MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->execution_mode() == kGraphMode) { if (raise_precision_count > 0) { - MS_LOG(WARNING) << "There has " << raise_precision_count + MS_LOG(WARNING) << "There are " << raise_precision_count << " node/nodes used raise precision to selected the kernel!"; } if (reduce_precision_count > 0) { - MS_LOG(WARNING) << "There has " << raise_precision_count + MS_LOG(WARNING) << "There are " << reduce_precision_count << " node/nodes used reduce precision to selected the kernel!"; } } @@ -1942,8 +1033,6 @@ void AscendSession::HardwareOptimize(NotNull graph, memo->insert(graph.get()); MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id(); - // convert kernel Graph to model - predictmodel::StepConvertGraph(graph.get()); HardwareOptimize(graph.get()); for (auto &child_graph : graph->child_graph_order()) { diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 11cb1c92d2..a42377bbaa 100755 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H #include #include #include @@ -51,26 +51,16 @@ class AscendSession : public SessionBasic { py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors) override; - // set parameters of final graph - GraphId SetFinalGraphInput(const std::vector &args) override; - // set output of final graph - void SetFinalGraphOutput(const BaseRef &output) override; - // insert switch and set the relative active ops - void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; - // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter - void SetChildGraphInput(GraphId g, const VectorRef &args) override; // get graph id in child graphs by ME front anf node pointer GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; // get graph id of final graph GraphId GetFinalRunGraph() const override { return final_graph_id_; } - // insert active to graph - void SetActive(GraphId, GraphId) override; // compile child graph when session have multiple child graphs void CompileChildGraph(const KernelGraphPtr &child_graph); - void RecurseGetSummaryNodes(KernelGraph *graph, std::map> *summary); - void GetSummaryNodes(KernelGraph *graph); private: + void RecurseSetSummaryNodes(KernelGraph *graph, std::map> *summary); + void SetSummaryNodes(KernelGraph *graph) override; void InitRuntimeResource(); void SelectKernel(const KernelGraph &kernel_graph) const; void HardwareOptimize(const std::shared_ptr &kernel_graph) const; @@ -79,7 +69,8 @@ class AscendSession : public SessionBasic { void AssignStream(NotNull kernel_graph) const; void BuildKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; - void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; + void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector &input_tensors, + KernelGraph *kernel_graph) const; void RunOpMemoryClear(const KernelGraph *kernel_graph) const; void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; void LoadTask(const std::shared_ptr &kernel_graph) const; @@ -91,63 +82,21 @@ class AscendSession : public SessionBasic { void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; void RunOpExecTask(const std::shared_ptr &kernel_graph) const; - size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index); - size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); - size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); - - void SetFinalGraphOutput(const AnfNodePtr &node); - void SetFinalGraphOutput(const ValuePtr &value); - void SetFinalGraphOutput(const VectorRef &vec_output); - - void SplitGraph(NotNull graph, const std::set &cut_prims, - const NotNull *> memo); - // split graphs with recurse from root graph - void SplitGraphs(NotNull root_graph); - void BackendOptimization(const std::vector &all_graphs); - void LinkChildGraphs(NotNull graph); + static void BackendOptimization(const std::vector &all_graphs); + static void LinkChildGraphs(NotNull graph); void RootGraphExecutorValidate(NotNull graph); - std::vector ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, - const std::vector &list); - void RecurseCompileGraph(NotNull graph, const NotNull *> memo); - void RecurseSplitGraph(NotNull graph, const NotNull *> memo); - AnfNodePtr BindNewCallToNewGraph(NotNull graph, const std::vector &child_graph_list); - // merge execution order list of child graphs void MergeGraphExecOrder(); // insert assion op to sync data bettween different graphs void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); - // insert mutiple assigns to graph - void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); - // insert active op to graph - void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream); - // get execute index of graph - size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph); - // handle condition graph from vm - void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id); - // insert depend to graph, used to attch control nodes to graph - void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); - // insert depend to graph, used to attch control nodes to graph - void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); - // set child graph parameter if front arg is a anf - void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); - // set child graph parameter if front arg is a tensor - void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx); - // update the execution order of all child graphs - void UpdateGraphOrder(GraphId to_graph); - // handle switch when merge - void MergeSwitchCompile(); // get graph order vector by graph id - std::vector &GetGraphOrder(GraphId final_graph_id); + const std::vector &GetGraphOrder(GraphId final_graph_id) const; // get graph order type vector by graph id - std::vector &GetGraphOrderType(GraphId final_graph_id); - // copy output of if and else - void CopyOutputOfIf(GraphId false_graph_id); + const std::vector &GetGraphOrderType(GraphId final_graph_id) const; // check if graph cache exist bool GraphCacheExist(const GraphInfo &graph_info) const; // insert all assign to child graph void InsertAllAssigns(); - // create fake output of final graph - AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); // sync intial tensors' data to device void SyncInitialTenosrToDevice(); void SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph); @@ -161,16 +110,10 @@ class AscendSession : public SessionBasic { void AssignStaticMemory(const NotNull graph, NotNull *> memo) const; void UpdateRefOutputMap(const NotNull graph, NotNull *> memo) const; - // member variables // key is final_graph_id,value is child graph execute order of final graph std::unordered_map> graph_execute_orders_; // key is final_graph_id,value is the graph types of child graphs std::unordered_map> graph_order_types_; - // record condition graph of while - std::unordered_map while_condition_graphs_; - // record all conditions - std::unordered_map> switches_; - std::unordered_map condition_output_; // share parameters std::vector> assigns_; // initial tensors, these tensor will sync data to device before run graph @@ -181,4 +124,4 @@ class AscendSession : public SessionBasic { MS_REG_SESSION(kAscendDevice, AscendSession); } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index ca1c78d206..4ba62e53b7 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -16,18 +16,24 @@ #include "backend/session/cpu_session.h" #include +#include #include "ir/tensor.h" #include "ir/anf.h" #include "backend/kernel_compiler/kernel.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_runtime.h" -#include "predict/predict.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "runtime/device/cpu/kernel_select_cpu.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/replace_node_by_proxy.h" #ifdef ENABLE_DEBUGGER #include "debug/debugger/debugger.h" #endif +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "frontend/parallel/ps/util.h" +#endif namespace mindspore { namespace session { @@ -49,13 +55,29 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, return new_parameter; } +void CPUSession::Optimize(const std::shared_ptr &kernel_graph) { + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + std::string pass_name = "replace_node_by_proxy"; + pass_name.append(std::to_string(graph_sum_)); + pm->AddPass(std::make_shared(pass_name)); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { auto graph_id = graph_sum_; auto graph = ConstructKernelGraph(lst, outputs); MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Set kernel info"; SetKernelInfo(graph.get()); - predictmodel::StepConvertGraph(graph); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + AssignParamKey(graph); + if (parallel::ps::Util::IsRoleOfWorker()) { + Optimize(graph); + } +#endif MS_LOG(INFO) << "Build kernel"; BuildKernel(graph.get()); MS_LOG(INFO) << "Assign kernel address"; @@ -66,11 +88,13 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList void CPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { auto &kernel_graph = graphs_[graph_id]; MS_EXCEPTION_IF_NULL(kernel_graph); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + InitPSParamAndOptim(kernel_graph, inputs); +#endif MS_LOG(INFO) << "Bind input output address"; std::vector need_sync_outputs; runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs); MS_LOG(INFO) << "Run graph start"; - predictmodel::StepConvertWeight(inputs); auto execution_order = kernel_graph->execution_order(); Reorder(&execution_order); @@ -78,7 +102,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vectorset_execution_order(execution_order); NamedSummaryOutputs summary_outputs; if (enable_summary) { - GetSummaryNodes(kernel_graph.get()); + SetSummaryNodes(kernel_graph.get()); summary_outputs = kernel_graph->summary_nodes(); runtime_.IncreaseSummaryRefCount(summary_outputs); } @@ -119,6 +143,48 @@ void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { } } +namespace { +void KernelNotSupportException(const AnfNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + std::stringstream operator_info; + operator_info << "Operator[" << kernel_name << "] "; + auto kernel_info = dynamic_cast(kernel_node->kernel_info()); + if (kernel_info == nullptr) { + operator_info << "is not support."; + MS_LOG(EXCEPTION) << operator_info.str(); + } + auto kernel_build_Info = kernel_info->select_kernel_build_info(); + if (kernel_build_Info == nullptr) { + operator_info << "is not support."; + MS_LOG(EXCEPTION) << operator_info.str(); + } + size_t input_num = kernel_build_Info->GetInputNum(); + if (input_num > 0) { + operator_info << " input("; + for (size_t i = 0; i < input_num; ++i) { + operator_info << TypeIdLabel(kernel_build_Info->GetInputDeviceType(i)); + if (i != input_num - 1) { + operator_info << ","; + } + } + operator_info << ") "; + } + size_t output_num = kernel_build_Info->GetOutputNum(); + if (output_num > 0) { + operator_info << "output("; + for (size_t i = 0; i < output_num; ++i) { + operator_info << TypeIdLabel(kernel_build_Info->GetOutputDeviceType(i)); + if (i != kernel_build_Info->GetOutputNum() - 1) { + operator_info << ","; + } + } + operator_info << ") "; + } + operator_info << "is not support."; + MS_LOG(EXCEPTION) << operator_info.str(); +} +} // namespace + void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); auto &kernel_nodes = kernel_graph->execution_order(); @@ -129,7 +195,7 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { std::shared_ptr cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node); if (cpu_kernel == nullptr) { - MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support."; + KernelNotSupportException(kernel_node); } cpu_kernel->Init(kernel_node); AnfAlgo::SetKernelMod(cpu_kernel, kernel_node.get()); diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index b0dbd1cc2b..014b4168ab 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_CPU_SESSION_H -#define MINDSPORE_CCSRC_SESSION_CPU_SESSION_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_CPU_SESSION_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_CPU_SESSION_H #include #include #include @@ -37,6 +37,7 @@ class CPUSession : public SessionBasic { protected: ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; + void Optimize(const std::shared_ptr &kernel_graph); private: void SetKernelInfo(const KernelGraph *kernel_graph); @@ -46,4 +47,4 @@ class CPUSession : public SessionBasic { MS_REG_SESSION(kCPUDevice, CPUSession); } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_CPU_SESSION_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_CPU_SESSION_H diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 14e30c1a44..6e720babcf 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "debug/anf_ir_utils.h" #include "backend/session/gpu_session.h" #include "runtime/device/gpu/kernel_info_setter.h" #include "runtime/device/gpu/gpu_kernel_build.h" @@ -25,12 +26,16 @@ #include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" #include "backend/optimizer/gpu/adam_fusion.h" +#include "backend/optimizer/gpu/replace_bn_cast_fusion.h" +#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" +#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" +#include "backend/optimizer/gpu/replace_addn_fusion.h" #include "runtime/device/kernel_runtime_manager.h" -#include "predict/predict.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/trans.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/base_ref_extends.h" +#include "debug/tensor_load.h" namespace mindspore { namespace session { @@ -59,6 +64,10 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { auto pm = std::make_shared(); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); @@ -90,12 +99,13 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { runtime_instance->AssignMemory(kernel_graph); } -void GPUSession::RunOpAllocateMemory(const std::vector &input_tensors, +void GPUSession::RunOpAllocateMemory(const ValuePtr &pre_output_value, + const std::vector &input_tensors, KernelGraph *kernel_graph) const { MS_EXCEPTION_IF_NULL(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); + runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph); } void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { @@ -153,7 +163,11 @@ void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, void GPUSession::Execute(const std::shared_ptr &kernel_graph) const { auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); +#ifdef ENABLE_DEBUGGER + if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) { +#else if (!runtime_instance->Run(kernel_graph.get())) { +#endif MS_LOG(EXCEPTION) << "GPU execute graph failed!"; } } @@ -163,16 +177,30 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList auto graph_id = graph_sum_; auto graph = ConstructKernelGraph(lst, outputs); MS_EXCEPTION_IF_NULL(graph); + // Prepare ms context info for dump .pb graph + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); // Optimize Optimize(graph); // Select kernel build info SelectKernel(graph); - // Convert kernel Graph to model - predictmodel::StepConvertGraph(graph); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + // Assign parameter keys. + AssignParamKey(graph); +#endif // Start gpu kernel runtime StartKernelRT(); + // Dump .pb graph before hardware optimization + if (save_graphs) { + DumpIRProto(graph, "before_hwopt_" + std::to_string(graph_id)); + } // HardwareOptimize HardwareOptimize(graph); + // Dump .pb graph after hardware optimization + if (save_graphs) { + DumpIRProto(graph, "after_hwopt_" + std::to_string(graph_id)); + } // Assign CUDA streams AssignStream(graph); // Hide NoOp from execution graph @@ -184,7 +212,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList Reorder(&execution_order); graph->set_execution_order(execution_order); // Get summary nodes. - GetSummaryNodes(graph.get()); + SetSummaryNodes(graph.get()); // Remove NoOp from execution graph opt::RemoveNopNode(graph.get()); // Set graph manager. @@ -202,16 +230,24 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList void GPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { auto &kernel_graph = graphs_[graph_id]; +#ifdef ENABLE_DEBUGGER + PreIterationDbg(kernel_graph); +#endif // Load input data from user input LoadInputData(kernel_graph, inputs); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + // Initialize parameter server + InitPSParamAndOptim(kernel_graph, inputs); +#endif MS_EXCEPTION_IF_NULL(kernel_graph); - // Convert inputs to model - predictmodel::StepConvertWeight(inputs); { py::gil_scoped_release gil_release; // Run graph on GPU Execute(kernel_graph); } +#ifdef ENABLE_DEBUGGER + PostLoadTensor(kernel_graph); +#endif // Get result from GPU UpdateOutputs(kernel_graph, outputs, inputs); // Summary @@ -220,6 +256,9 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vectorenable_gpu_summary()) { Summary(kernel_graph.get()); } +#ifdef ENABLE_DEBUGGER + PostIterationDbg(kernel_graph); +#endif } void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, @@ -245,13 +284,27 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph MS_EXCEPTION_IF_NULL(kernel_graph); // Remove NoOp from execution graph opt::RemoveNopNode(kernel_graph.get()); - RunOpAllocateMemory(input_tensors, kernel_graph.get()); + RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get()); // Execute the computation LoadInputData(kernel_graph, input_tensors); - Execute(kernel_graph); + { + py::gil_scoped_release gil_release; + Execute(kernel_graph); + } // Fetch outputs VectorRef outputs; - UpdateOutputs(kernel_graph, &outputs, input_tensors); + if (op_run_info.value != nullptr) { + std::vector pre_output_tensors; + TensorValueToTensor(op_run_info.value, &pre_output_tensors); + for (auto &pre_output : pre_output_tensors) { + tensor::TensorPtr tensor = std::make_shared(pre_output->data_type(), pre_output->shape()); + tensor->set_device_address(pre_output->device_address()); + tensor->set_dirty(false); + outputs.emplace_back(tensor); + } + } else { + UpdateOutputs(kernel_graph, &outputs, input_tensors); + } // Trans output to tuple auto output_tensors = TransformBaseRefListToTuple(outputs); if (!utils::isa(output_tensors) || @@ -263,6 +316,70 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph RunOpClearMemory(kernel_graph.get()); return tuple_tensors; } + +#ifdef ENABLE_DEBUGGER +void GPUSession::Dump(const std::shared_ptr &kernel_graph) const { +#ifdef ENABLE_DUMP_E2E + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + (void)runtime_instance->DumpData(kernel_graph.get(), debugger_.get()); +#endif +} + +bool GPUSession::DumpDataEnabledIteration() const { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + return runtime_instance->DumpDataEnabledIteration(); +} + +void GPUSession::PreIterationDbg(const std::shared_ptr &kernel_graph) const { + if (debugger_) { + debugger_->PreExecute(kernel_graph); + } + PreLoadTensor(kernel_graph); +} + +void GPUSession::PostIterationDbg(const std::shared_ptr &kernel_graph) const { + bool dump_enabled = DumpDataEnabledIteration(); + // debug used for dump + if (debugger_ && dump_enabled) { + Dump(kernel_graph); + } + if (debugger_) { + debugger_->PostExecute(); + } +} + +void GPUSession::PreLoadTensor(const std::shared_ptr &kernel_graph) const { + bool dump_enabled = DumpDataEnabledIteration(); + if (!(debugger_ && (debugger_->debugger_enabled() || dump_enabled))) { + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + DebugServices *debug_services = debugger_->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + tensor_loader->EmptyTensor(); + uint32_t iter_num = tensor_loader->GetIterNum(); + tensor_loader->set_iter_num(++iter_num); +} + +void GPUSession::PostLoadTensor(const std::shared_ptr &kernel_graph) const { + bool dump_enabled = DumpDataEnabledIteration(); + if (!(debugger_ && (debugger_->debugger_enabled() || dump_enabled))) { + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + DebugServices *debug_services = debugger_->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + tensor_loader->EmptyPrevTensor(); +} +#endif + } // namespace gpu } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 7e07dfbcbd..70d904ef7a 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_GPU_SESSION_H -#define MINDSPORE_CCSRC_SESSION_GPU_SESSION_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H #include #include @@ -59,7 +59,8 @@ class GPUSession : public SessionBasic { void AllocateMemory(KernelGraph *kernel_graph) const; - void RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const; + void RunOpAllocateMemory(const ValuePtr &pre_output_value, const std::vector &input_tensors, + KernelGraph *kernel_graph) const; void RunOpClearMemory(KernelGraph *kernel_graph) const; @@ -67,10 +68,24 @@ class GPUSession : public SessionBasic { const std::vector &inputs_const) const override; void Execute(const std::shared_ptr &kernel_graph) const; + +#ifdef ENABLE_DEBUGGER + void Dump(const std::shared_ptr &kernel_graph) const; + + bool DumpDataEnabledIteration() const; + + void PreIterationDbg(const std::shared_ptr &kernel_graph) const; + + void PostIterationDbg(const std::shared_ptr &kernel_graph) const; + + void PreLoadTensor(const std::shared_ptr &kernel_graph) const; + + void PostLoadTensor(const std::shared_ptr &kernel_graph) const; +#endif }; using GPUSessionPtr = std::shared_ptr; MS_REG_SESSION(kGPUDevice, GPUSession); } // namespace gpu } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_GPU_SESSION_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H diff --git a/mindspore/ccsrc/backend/session/infer_session.cc b/mindspore/ccsrc/backend/session/infer_session.cc new file mode 100644 index 0000000000..b7829795b2 --- /dev/null +++ b/mindspore/ccsrc/backend/session/infer_session.cc @@ -0,0 +1,377 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/infer_session.h" +#include +#include +#include "include/inference.h" +#include "utils/load_onnx/anf_converter.h" +#include "backend/session/session_basic.h" +#include "backend/session/session_factory.h" +#include "base/base_ref_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "utils/context/context_extends.h" + +#ifdef ENABLE_D +#include "utils/ms_context.h" +#endif + +using std::string; +using std::vector; + +namespace py = pybind11; +namespace mindspore::inference { + +std::shared_ptr InferSession::CreateSession(const std::string &device, uint32_t device_id) { + try { + auto session = std::make_shared(); + Status ret = session->InitEnv(device, device_id); + if (ret != SUCCESS) { + return nullptr; + } + return session; + } catch (std::bad_alloc &e) { + MS_LOG(ERROR) << "Inference CreatSession failed, failed to alloc memory"; + return nullptr; + } +} + +MSInferSession::MSInferSession() = default; +MSInferSession::~MSInferSession() = default; + +std::shared_ptr> MSInferSession::ReadFile(const std::string &file) { + if (file.empty()) { + MS_LOG(ERROR) << "file is nullptr"; + return nullptr; + } + std::string realPath = file; + std::ifstream ifs(realPath); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << realPath << " is not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << realPath << "open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + std::shared_ptr> buf(new (std::nothrow) std::vector(size)); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf->data(), size); + ifs.close(); + + return buf; +} + +Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) { + auto graphBuf = ReadFile(file_name); + if (graphBuf == nullptr) { + MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); + return FAILED; + } + auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_); + if (graph == nullptr) { + MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + return FAILED; + } + Status ret = CompileGraph(graph, model_id); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str(); + return FAILED; + } + MS_LOG(INFO) << "Load model from file " << file_name << " success"; + +#ifdef ENABLE_D + // set d context + rtError_t rt_ret = rtCtxGetCurrent(&context_); + if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { + MS_LOG(ERROR) << "the ascend device context is null"; + return FAILED; + } +#endif + + return SUCCESS; +} + +Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; } + +Status ServingTensor2MSTensor(size_t index, const InferTensorBase &out_tensor, tensor::TensorPtr &ms_tensor) { + std::vector shape; + for (auto dim : out_tensor.shape()) { + shape.push_back(static_cast(dim)); + } + TypeId data_type; + const std::map type2id_map{ + {inference::kMSI_Unknown, TypeId::kNumberTypeBegin}, {inference::kMSI_Bool, TypeId::kNumberTypeBool}, + {inference::kMSI_Int8, TypeId::kNumberTypeInt8}, {inference::kMSI_Uint8, TypeId::kNumberTypeUInt8}, + {inference::kMSI_Int16, TypeId::kNumberTypeInt16}, {inference::kMSI_Uint16, TypeId::kNumberTypeUInt16}, + {inference::kMSI_Int32, TypeId::kNumberTypeInt32}, {inference::kMSI_Uint32, TypeId::kNumberTypeUInt32}, + {inference::kMSI_Int64, TypeId::kNumberTypeInt64}, {inference::kMSI_Uint64, TypeId::kNumberTypeUInt64}, + {inference::kMSI_Float16, TypeId::kNumberTypeFloat16}, {inference::kMSI_Float32, TypeId::kNumberTypeFloat32}, + {inference::kMSI_Float64, TypeId::kNumberTypeFloat64}, + }; + auto it = type2id_map.find(out_tensor.data_type()); + if (it == type2id_map.end()) { + MSI_LOG_WARNING << "undefined MSI data type " << out_tensor.data_type(); + return FAILED; + } else { + data_type = it->second; + } + + ms_tensor = std::make_shared(data_type, shape); + if (ms_tensor->Size() != out_tensor.data_size()) { + MSI_LOG_ERROR << "input " << std::to_string(index) + << " data size not match shape and dtype, calculated required size " << ms_tensor->Size() + << ", given " << out_tensor.data_size(); + return INFER_STATUS(INVALID_INPUTS) << "input " << std::to_string(index) + << " data size not match shape and dtype, calculated required size " + << ms_tensor->Size() << ", given " << out_tensor.data_size(); + } + memcpy_s(ms_tensor->data_c(), ms_tensor->Size(), out_tensor.data(), out_tensor.data_size()); + return SUCCESS; +} + +void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_tensor) { + vector shape; + for (auto dim : ms_tensor->shape()) { + shape.push_back(dim); + } + out_tensor.set_shape(shape); + + const std::map id2type_map{ + {TypeId::kNumberTypeBegin, inference::kMSI_Unknown}, {TypeId::kNumberTypeBool, inference::kMSI_Bool}, + {TypeId::kNumberTypeFloat64, inference::kMSI_Float64}, {TypeId::kNumberTypeInt8, inference::kMSI_Int8}, + {TypeId::kNumberTypeUInt8, inference::kMSI_Uint8}, {TypeId::kNumberTypeInt16, inference::kMSI_Int16}, + {TypeId::kNumberTypeUInt16, inference::kMSI_Uint16}, {TypeId::kNumberTypeInt32, inference::kMSI_Int32}, + {TypeId::kNumberTypeUInt32, inference::kMSI_Uint32}, {TypeId::kNumberTypeInt64, inference::kMSI_Int64}, + {TypeId::kNumberTypeUInt64, inference::kMSI_Uint64}, {TypeId::kNumberTypeFloat16, inference::kMSI_Float16}, + {TypeId::kNumberTypeFloat32, inference::kMSI_Float32}, + }; + auto it = id2type_map.find(ms_tensor->data_type()); + if (it == id2type_map.end()) { + MSI_LOG_WARNING << "undefined MS data type " << ms_tensor->data_type(); + out_tensor.set_data_type(inference::kMSI_Unknown); + } else { + out_tensor.set_data_type(it->second); + } + out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size()); +} + +Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) { +#ifdef ENABLE_D + if (context_ == nullptr) { + MS_LOG(ERROR) << "rtCtx is nullptr"; + return FAILED; + } + rtError_t rt_ret = rtCtxSetCurrent(context_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "set Ascend rtCtx failed"; + return FAILED; + } +#endif + + vector inputs; + for (size_t i = 0; i < request.size(); i++) { + if (request[i] == nullptr) { + MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i; + return FAILED; + } + tensor::TensorPtr input = nullptr; + auto ret = ServingTensor2MSTensor(i, *request[i], input); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Tensor convert failed"; + return ret; + } + inputs.push_back(input); + } + auto ret = CheckModelInputs(model_id, inputs); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed"; + return ret; + } + vector outputs = RunGraph(model_id, inputs); + if (outputs.empty()) { + MS_LOG(ERROR) << "Execute Model " << model_id << " Failed"; + return FAILED; + } + reply.clear(); + for (const auto &tensor : outputs) { + auto out_tensor = reply.add(); + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed"; + return FAILED; + } + MSTensor2ServingTensor(tensor, *out_tensor); + } + return SUCCESS; +} + +Status MSInferSession::FinalizeEnv() { + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return FAILED; + } + if (!context::CloseTsd(ms_context)) { + MS_LOG(ERROR) << "Inference CloseTsd failed!"; + return FAILED; + } + return SUCCESS; +} + +std::shared_ptr MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) { + try { + auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference LoadModel failed"; + return nullptr; + } +} + +void MSInferSession::RegAllOp() { + static std::mutex init_mutex; + static bool Initialized = false; + + std::lock_guard lock(init_mutex); + if (Initialized) { + return; + } + Initialized = true; + MsContext::GetInstance()->set_execution_mode(kGraphMode); + Py_Initialize(); + auto c_expression = PyImport_ImportModule("mindspore._c_expression"); + if (c_expression == nullptr) { + MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; + return; + } + PyObject *c_expression_dict = PyModule_GetDict(c_expression); + + PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); + if (op_info_loader_class == nullptr) { + MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; + return; + } + PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); + if (op_info_loader == nullptr) { + MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; + return; + } + PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); + if (op_info_loader_ins == nullptr) { + MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; + return; + } + auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); + if (all_ops_info_vector_addr_ul == nullptr) { + MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; + return; + } + auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); + auto all_ops_info = static_cast *>(all_ops_info_vector_addr); + for (auto op_info : *all_ops_info) { + kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); + } + all_ops_info->clear(); + delete all_ops_info; + Py_DECREF(op_info_loader); + Py_DECREF(op_info_loader_class); + Py_DECREF(c_expression_dict); + Py_DECREF(c_expression); + return; +} + +Status MSInferSession::CompileGraph(std::shared_ptr funcGraphPtr, uint32_t &model_id) { + MS_ASSERT(session_impl_ != nullptr); + try { + auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + py::gil_scoped_release gil_release; + model_id = graph_id; + return SUCCESS; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CompileGraph failed"; + return FAILED; + } +} + +std::vector MSInferSession::RunGraph(uint32_t graph_id, + const std::vector &inputs) { + try { + VectorRef outputs; + session_impl_->RunGraph(graph_id, inputs, &outputs); + + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference Rungraph failed"; + return std::vector(); + } +} + +string MSInferSession::AjustTargetName(const std::string &device) { + if (device == kAscendDevice) { + return std::string(kAscendDevice) + "Inference"; + } else { + MS_LOG(ERROR) << "Only support device Ascend right now"; + return ""; + } +} + +Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) { + RegAllOp(); + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return FAILED; + } + ms_context->set_execution_mode(kGraphMode); + ms_context->set_device_id(device_id); + auto ajust_device = AjustTargetName(device); + if (ajust_device == "") { + return FAILED; + } + ms_context->set_device_target(device); + session_impl_ = session::SessionFactory::Get().Create(ajust_device); + if (session_impl_ == nullptr) { + MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; + return FAILED; + } + session_impl_->Init(device_id); + if (!context::OpenTsd(ms_context)) { + MS_LOG(ERROR) << "Session init OpenTsd failed!"; + return FAILED; + } + return SUCCESS; +} + +Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector &inputs) const { + MS_ASSERT(session_impl_ != nullptr); + std::string error_msg; + if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) { + return INFER_STATUS(INVALID_INPUTS) << error_msg; + } + return SUCCESS; +} + +} // namespace mindspore::inference diff --git a/mindspore/ccsrc/backend/session/infer_session.h b/mindspore/ccsrc/backend/session/infer_session.h new file mode 100644 index 0000000000..c58e16e382 --- /dev/null +++ b/mindspore/ccsrc/backend/session/infer_session.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H +#define MINDSPORE_CCSRC_SESSION_SESSION_H + +#include +#include +#include +#include +#include +#include + +#include "backend/session/session_basic.h" +#include "ir/anf.h" +#include "include/inference.h" + +#ifdef ENABLE_D +#include "runtime/context.h" +#endif + +namespace mindspore { +namespace inference { +class MSInferSession : public InferSession { + public: + MSInferSession(); + ~MSInferSession(); + + Status InitEnv(const std::string &device_type, uint32_t device_id) override; + Status FinalizeEnv() override; + Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override; + Status UnloadModel(uint32_t model_id) override; + Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override; + + private: + std::shared_ptr session_impl_ = nullptr; + std::vector graph_id_; + std::string device_type_; + int32_t device_id_; +#ifdef ENABLE_D + rtContext_t context_ = nullptr; +#endif + + std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device); + std::shared_ptr> ReadFile(const std::string &file); + static void RegAllOp(); + string AjustTargetName(const std::string &device); + Status CompileGraph(std::shared_ptr funcGraphPtr, uint32_t &model_id); + Status CheckModelInputs(uint32_t graph_id, const std::vector &inputs) const; + std::vector RunGraph(uint32_t graph_id, const std::vector &inputs); +}; +} // namespace inference +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.cc b/mindspore/ccsrc/backend/session/kernel_build_client.cc new file mode 100644 index 0000000000..e12b6de9bf --- /dev/null +++ b/mindspore/ccsrc/backend/session/kernel_build_client.cc @@ -0,0 +1,177 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/kernel_build_client.h" + +#include +#include + +namespace mindspore { +namespace kernel { +void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { + std::string::size_type start = 0; + while ((start = (*dest).find(replace, start)) != std::string::npos) { + (*dest).replace(start, replace.size(), 1, new_char); + start++; // Replaced 1 charactor. + } +} + +int AscendKernelBuildClient::TbeStart(const std::string &json) { + // Start compiling.. + auto res = SendRequest(kTbeStart); + if (res != kAck) { + MS_LOG(ERROR) << "START failed, res: " << res; + return -1; + } + // Send the json data. + res = SendRequest(json); + if (res == kFailed) { + MS_LOG(ERROR) << "TBE/START responds failed, res: " << res; + return -1; + } + // Return task id. + return std::stoi(res); +} + +bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) { + // Start waiting.. + auto res = SendRequest(kTbeWait); + if (res != kAck) { + MS_LOG(ERROR) << "TBE/WAIT failed, res: " << res; + return false; + } + // Request task id. + *task_id = std::stoi(SendRequest(kContinue)); + // Requst task result. + *task_result = SendRequest(kContinue); + // Request prebuild result. + *pre_build_result = SendRequest(kContinue); + return true; +} + +void AscendKernelBuildClient::TbeReset() { + // Start compiling.. + auto res = SendRequest(kTbeReset); + if (res != kAck) { + MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res; + } +} + +bool AscendKernelBuildClient::AkgStart(int process_num, int wait_time) { + // Start compiling.. + auto res = SendRequest(kAkgStart); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/START failed, res: " << res; + return false; + } + std::string process_num_str = std::to_string(process_num); + res = SendRequest(process_num_str); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/START(process_num) responds failed, res: " << res; + return false; + } + std::string wait_time_str = std::to_string(wait_time); + res = SendRequest(wait_time_str); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/START(wait_time) responds failed, res: " << res; + return false; + } + return true; +} + +bool AscendKernelBuildClient::AkgSendData(const std::vector &jsons) { + auto res = SendRequest(kAkgData); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/DATA failed, res: " << res; + return false; + } + for (auto &json : jsons) { + res = SendRequest(json); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/DATA.. responds failed, res: " << res << ", when sending [" << json << "]"; + return false; + } + } + return true; +} + +// Fetch the result of AKG compiling. +bool AscendKernelBuildClient::AkgWait() { + auto res = SendRequest(kAkgWait); + if (res != kTrue) { + MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res; + return false; + } + return true; +} + +std::string AscendKernelBuildClient::SelectFormat(const std::string &json) { + // Start compiling.. + auto res = SendRequest(kFormat); + if (res != kAck) { + MS_LOG(ERROR) << "FORMAT failed, res: " << res; + return ""; + } + // Send the json data. + res = SendRequest(json); + if (res == kErr) { + MS_LOG(ERROR) << "FORMAT responds failed, res: " << res; + return ""; + } + return res; +} + +bool AscendKernelBuildClient::CheckSupported(const std::string &json) { + // Checking support.. + auto res = SendRequest(kSupport); + if (res != kAck) { + MS_LOG(ERROR) << "SUPPORT failed, res: " << res; + return false; + } + // Send the json data. + res = SendRequest(json); + if (res != kTrue) { + MS_LOG(INFO) << "SUPPORT responds failed, res: " << res; + return false; + } + return true; +} + +int GpuKernelBuildClient::AkgGetPid() { + auto res = SendRequest(kAkgPid); + if (res == kErr) { + MS_LOG(ERROR) << "AKG/PID failed, res: " << res; + return -1; + } + return std::stoi(res); +} + +bool GpuKernelBuildClient::AkgCompileSingle(const std::string json) { + auto res = SendRequest(kAkgCompileOp); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/COMPILE failed, res: " << res; + return false; + } + // Send single json data. + res = SendRequest(json); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/COMPILE responds failed, res: " << res; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.h b/mindspore/ccsrc/backend/session/kernel_build_client.h new file mode 100644 index 0000000000..f442c8846d --- /dev/null +++ b/mindspore/ccsrc/backend/session/kernel_build_client.h @@ -0,0 +1,260 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ +#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ + +#include +#include +#include +#include +#include + +#include "common/duplex_pipe.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace kernel { +void ReplaceStr(std::string *dest, const std::string &replace, char new_char); + +constexpr inline static int kBufferSize = 4096; +// The TAG as prefix of real command from remote. +constexpr inline static auto kTag = "[~]"; + +class KernelBuildClient { + public: + // Send Finish request to server + constexpr inline static auto kFinish = "FINISH"; + // Receive the response from server + constexpr inline static auto kAck = "ACK"; + constexpr inline static auto kErr = "ERR"; + constexpr inline static auto kTrue = "True"; + constexpr inline static auto kSuccess = "Success"; + + // Revert \n, \r, [space]. + constexpr inline static auto kLF = "[LF]"; + constexpr inline static auto kCR = "[CR]"; + constexpr inline static auto kSP = "[SP]"; + + constexpr inline static unsigned int kTimeOutSeconds = 350; + + virtual std::string GetEnv() = 0; + virtual std::string GetScript() = 0; + + void Open() { + if (!init_) { + // Exception's thrown if open failed + if (dp_->Open({GetEnv(), GetScript()}, true) != -1) { + dp_->SetTimeOutSeconds(kTimeOutSeconds); + dp_->SetTimeOutCallback([this]() { SendRequest(kFinish); }); + init_ = true; + } + } + } + void Close() { + if (init_) { + dp_->Close(); + init_ = false; + } + } + + // Send a request and fetch its response + std::string SendRequest(std::string data) { + Request(data); + return Response(); + } + void Request(std::string req) { + if (!init_) { + MS_LOG(EXCEPTION) << "Try to send request before Open()"; + } + MS_LOG(DEBUG) << "\t[" << req << "]"; + *dp_ << req; + } + std::string Response() { + if (!init_) { + MS_LOG(EXCEPTION) << "Try to get response before Open()"; + } + std::string res; + *dp_ >> res; + // Filter out the interference + auto start = res.find(kTag); + if (start == std::string::npos) { + MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res; + } + res = res.substr(start + std::strlen(kTag), res.size() - start); + // Revert the line feed and space + if (res != kSuccess && res != kAck && res != kErr && res != kTrue) { + ReplaceStr(&res, kLF, '\n'); + ReplaceStr(&res, kSP, ' '); + } + MS_LOG(DEBUG) << "\t[" << res << "]"; + return res; + } + + protected: + KernelBuildClient() : init_(false), dp_(std::make_shared()) {} + virtual ~KernelBuildClient() = default; + + private: + bool init_; + std::shared_ptr dp_; +}; + +static inline std::string GetScriptFilePath(const std::string cmd_env, const std::string &cmd_script) { + std::string cmd = cmd_env; + (void)cmd.append(1, ' ').append(cmd_script); + FILE *fpipe = popen(cmd.c_str(), "r"); + if (fpipe == nullptr) { + MS_LOG(EXCEPTION) << "popen failed, " << strerror(errno) << "(" << errno << ")"; + } + bool start = false; + std::string result; + char buf[kBufferSize]; + while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) { + if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) { + start = true; + } + // Filter with 'kTAG' and '\n' + if (start) { + auto size = std::strlen(buf); + bool line_end = buf[size - 1] == '\n'; + result.append(buf, line_end ? size - 1 : size); + if (line_end) { + break; + } + } + } + pclose(fpipe); + const std::string py_suffix = ".py"; + if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) { + MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}"; + } + result = result.substr(strlen(kTag)); + MS_LOG(DEBUG) << "result: " << result; + return result; +} + +class AscendKernelBuildClient : public KernelBuildClient { + public: + // Server configure + constexpr inline static auto kEnv = "python"; + constexpr inline static auto kGetPathScript = + "-c " + "\"" + "import pkgutil;" + "path = pkgutil" + ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_ascend\\\")" // Server module name + ".get_filename();" + "print('[~]' + path)" + "\""; + + // Receive the response from server + constexpr inline static auto kFailed = "-1"; + + // Send building request to server + constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued + constexpr inline static auto kTbeStart = "TBE/START"; + constexpr inline static auto kTbeWait = "TBE/WAIT"; + constexpr inline static auto kTbeReset = "TBE/RESET"; + constexpr inline static auto kAkgStart = "AKG/START"; + constexpr inline static auto kAkgData = "AKG/DATA"; + constexpr inline static auto kAkgWait = "AKG/WAIT"; + + // Send server info. query to server + constexpr inline static auto kFormat = "FORMAT"; + constexpr inline static auto kSupport = "SUPPORT"; + + static AscendKernelBuildClient &Instance() { + static AscendKernelBuildClient instance; + return instance; + } + + std::string GetEnv() override { return kEnv; } + + std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); } + + // Before building. + std::string SelectFormat(const std::string &json); + bool CheckSupported(const std::string &json); + + // Run TBE building. + int TbeStart(const std::string &json); + bool TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result); + void TbeReset(); + + // Run AKG building. + bool AkgStart(int process_num, int wait_time); + bool AkgSendData(const std::vector &jsons); + bool AkgWait(); + bool AkgCompileSingle(const std::string json); + + AscendKernelBuildClient(const AscendKernelBuildClient &) = delete; + AscendKernelBuildClient &operator=(const AscendKernelBuildClient &) = delete; + + AscendKernelBuildClient(AscendKernelBuildClient &&) = delete; + AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete; + + private: + AscendKernelBuildClient() { Open(); } + ~AscendKernelBuildClient() override { Close(); } +}; + +class GpuKernelBuildClient : public KernelBuildClient { + public: + // Server configure + constexpr inline static auto kEnv = "python"; + constexpr inline static auto kGetPathScript = + "-c " + "\"" + "import pkgutil;" + "path = pkgutil" + ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_gpu\\\")" // Server module name + ".get_filename();" + "print('[~]' + path)" + "\""; + + // Send building request to server + constexpr inline static auto kAkgPid = "AKG/PID"; + constexpr inline static auto kAkgCompileOp = "AKG/COMPILE"; // Compile a single op + + static GpuKernelBuildClient &Instance() { + static GpuKernelBuildClient instance; + return instance; + } + + std::string GetEnv() override { return kEnv; } + + std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); } + + // Fetch pid(pid_t) from remote. + int AkgGetPid(); + // Run AKG building. + bool AkgCompileSingle(const std::string json); + + GpuKernelBuildClient(const GpuKernelBuildClient &) = delete; + GpuKernelBuildClient &operator=(const GpuKernelBuildClient &) = delete; + + GpuKernelBuildClient(GpuKernelBuildClient &&) = delete; + GpuKernelBuildClient &operator=(GpuKernelBuildClient &&) = delete; + + private: + GpuKernelBuildClient() { Open(); } + ~GpuKernelBuildClient() override { Close(); } +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index df810fe6ef..6b49b4b878 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -79,31 +79,6 @@ std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { return real_inputs; } -AnfNodePtr MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return nullptr; - } - - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} - bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { if (left == right) { return true; @@ -120,7 +95,50 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { } return false; } + +void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector *device_formats, + std::vector *device_types) { + MS_EXCEPTION_IF_NULL(value_node); + MS_EXCEPTION_IF_NULL(device_formats); + MS_EXCEPTION_IF_NULL(device_types); + ValuePtr value = value_node->value(); + std::vector tensors; + TensorValueToTensor(value, &tensors); + if (!tensors.empty()) { + if (tensors.size() != AnfAlgo::GetOutputTensorNum(value_node)) { + MS_LOG(EXCEPTION) << "The size of tensors converted from value [" << tensors.size() + << "] is not equal to output size of value node [" << AnfAlgo::GetOutputTensorNum(value_node) + << "]"; + } + device_formats->clear(); + device_types->clear(); + for (const auto &tensor : tensors) { + MS_EXCEPTION_IF_NULL(tensor); + auto device_sync = tensor->device_address(); + if (device_sync != nullptr) { + auto device_address = std::dynamic_pointer_cast(device_sync); + MS_EXCEPTION_IF_NULL(device_address); + device_formats->emplace_back(device_address->format()); + device_types->emplace_back(device_address->type_id()); + continue; + } + device_formats->emplace_back(kOpFormat_DEFAULT); + device_types->emplace_back(kTypeUnknown); + } + } +} } // namespace +AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return nullptr; + } + ValueNodePtr new_value_node = std::make_shared(value_node->value()); + new_value_node->set_abstract(value_node->abstract()); + this->SetKernelInfoForNode(new_value_node); + return new_value_node; +} + std::vector KernelGraph::outputs() const { auto graph_output = output(); if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { @@ -290,28 +308,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { MS_EXCEPTION_IF_NULL(cnode); cnode->set_abstract(std::make_shared()); CreateKernelInfoFromNewParameter(cnode); - - auto kernel_info = std::make_shared(); - std::vector feature_map_input_indexs; - // if the node only has the primitive(such as getNext) or the node's input has a feature map input - // then the node's output is a feature map output - for (size_t index = 1; index < inputs.size(); ++index) { - auto node = inputs[index]; - if (AnfAlgo::IsFeatureMapOutput(node)) { - feature_map_input_indexs.push_back(index); - } - } if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); } - if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { - kernel_info->SetFeatureMapFlag(true); - } - if (AnfAlgo::IsRealCNodeKernel(cnode)) { - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); - AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); - } - cnode->set_kernel_info(kernel_info); + SetKernelInfoForNode(cnode); AnfAlgo::SetGraphId(graph_id_, cnode.get()); return cnode; } @@ -351,6 +351,53 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { } } +void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = std::make_shared(); + node->set_kernel_info(kernel_info); + if (node->isa()) { + std::vector feature_map_input_indexs; + kernel_info->SetFeatureMapFlag(false); + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { + if (AnfAlgo::IsFeatureMapInput(node, index)) { + kernel_info->SetFeatureMapFlag(true); + feature_map_input_indexs.push_back(index); + } + } + if (AnfAlgo::GetInputTensorNum(node) == 0) { + kernel_info->SetFeatureMapFlag(true); + } + if (AnfAlgo::IsRealKernel(node)) { + // if the node only has the primitive(such as getNext) or the node's input has a feature map input + // then the node's output is a feature map output + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node); + AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node); + } + return; + } + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + std::vector types; + std::vector formats = {kOpFormat_DEFAULT}; + if (node->isa()) { + kernel_info->SetFeatureMapFlag(false); + types.emplace_back(kTypeUnknown); + auto value_node = node->cast(); + SyncDeviceInfoToValueNode(value_node, &formats, &types); + } + if (node->isa()) { + auto parameter = node->cast(); + MS_EXCEPTION_IF_NULL(parameter); + bool is_weight = AnfAlgo ::IsParameterWeight(parameter); + kernel_info->SetFeatureMapFlag(!is_weight); + types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); + } + // set parameter initaial device data type + kernel_build_info_builder->SetOutputsFormat(formats); + kernel_build_info_builder->SetOutputsDeviceType(types); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); +} + CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); auto new_cnode = std::make_shared(*cnode); @@ -366,81 +413,31 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { } ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { - ParameterPtr new_parameter = add_parameter(); + auto abstract = parameter == nullptr ? std::make_shared() : parameter->abstract(); + auto new_parameter = NewParameter(abstract); MS_EXCEPTION_IF_NULL(new_parameter); - // create kernel_info form new parameter - auto kernel_info = std::make_shared(); - size_t output_tensor_num = 1; - // if use default parameter = nullptr,it remarks create a new parameter from no parameter - if (parameter == nullptr) { - new_parameter->set_abstract(std::make_shared()); - kernel_info->SetFeatureMapFlag(true); - } else { - // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter - new_parameter->set_abstract(parameter->abstract()); + + // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter + if (parameter != nullptr) { new_parameter->set_name(parameter->name()); if (AnfAlgo::IsParameterWeight(parameter)) { new_parameter->set_default_param(parameter->default_param()); - kernel_info->SetFeatureMapFlag(false); - } else { - kernel_info->SetFeatureMapFlag(true); } } - new_parameter->set_kernel_info(kernel_info); - // create kernel_build_info for new parameter - auto kernel_build_info_builder = std::make_shared(); - // create init data type, - std::vector init_data_type = {}; - - TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); - init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); - - // set the format of parameter to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); - // set parameter initaial device data type - kernel_build_info_builder->SetOutputsDeviceType(init_data_type); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get()); + // create kernel_info form new parameter + SetKernelInfoForNode(new_parameter); AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); return new_parameter; } -std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - auto node_value = value_node->value(); - auto output_size = AnfAlgo::GetOutputTensorNum(value_node); - std::vector convert_inputs; - if (!node_value->isa()) { - MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); - } - auto value_tuple = node_value->cast(); - MS_EXCEPTION_IF_NULL(value_tuple); - if (value_tuple->size() != output_size) { - MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size() - << " is not mathced with the value node's output size" << output_size; - } - for (size_t index = 0; index < value_tuple->value().size(); ++index) { - auto new_value_node = std::make_shared(value_tuple->value()[index]); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)}, - {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get()); - AddValueNodeToGraph(new_value_node); - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - kernel_info->SetFeatureMapFlag(false); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); - AddValueNodeToGraph(new_value_node); - convert_inputs.emplace_back(new_value_node); - } - if (!RemoveValueNodeFromGraph(value_node)) { - MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); - } - return convert_inputs; +ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) { + ParameterPtr new_parameter = add_parameter(); + new_parameter->set_abstract(abstract); + MS_EXCEPTION_IF_NULL(new_parameter); + // create kernel_info form new parameter + SetKernelInfoForNode(new_parameter); + AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); + return new_parameter; } ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { @@ -450,6 +447,110 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { return new_value_node; } +ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) { + MS_EXCEPTION_IF_NULL(abstract); + MS_EXCEPTION_IF_NULL(value); + ValueNodePtr new_value_node = std::make_shared(value); + new_value_node->set_abstract(abstract); + SetKernelInfoForNode(new_value_node); + AnfAlgo::SetGraphId(graph_id(), new_value_node.get()); + return new_value_node; +} + +AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value) { + MS_EXCEPTION_IF_NULL(abstract); + MS_EXCEPTION_IF_NULL(value); + if (!abstract->isa()) { + auto new_value_node = NewValueNode(abstract, value); + AddValueNodeToGraph(new_value_node); + return new_value_node; + } + auto tuple_abstract = abstract->cast(); + auto value_tuple = value->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + MS_EXCEPTION_IF_NULL(value_tuple); + if (tuple_abstract->size() != value_tuple->size()) { + MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size() + << " is not equal to value size:" << value_tuple->size(); + } + std::vector make_tuple_inputs = { + mindspore::NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; + for (size_t index = 0; index < tuple_abstract->size(); ++index) { + make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index])); + } + auto make_tuple = NewCNode(make_tuple_inputs); + make_tuple->set_abstract(tuple_abstract); + return make_tuple; +} + +AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) { + MS_EXCEPTION_IF_NULL(abstract); + if (!abstract->isa()) { + return NewParameter(abstract); + } + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + std::vector make_tuple_inputs = { + mindspore::NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; + for (size_t index = 0; index < tuple_abstract->size(); ++index) { + make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index])); + } + auto make_tuple = NewCNode(make_tuple_inputs); + make_tuple->set_abstract(tuple_abstract); + return make_tuple; +} + +AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) { + auto idx = mindspore::NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector types; + std::vector> shapes; + std::vector make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)}; + for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(node); ++tuple_out_index) { + make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index)); + types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index)); + } + auto make_tuple = NewCNode(make_tuple_inputs_list); + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); + return make_tuple; +} + +AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::IsTupleOutput(node)) { + return node; + } + if (node->isa()) { + return TransParameterTuple(node->abstract()); + } else if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value()); + if (RemoveValueNodeFromGraph(value_node)) { + MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); + } + return make_tuple; + } else if (node->isa()) { + return TransCNodeTuple(node->cast()); + } + MS_LOG(EXCEPTION) << "Unexpected node:" << node->DebugString(); +} + const std::vector &KernelGraph::inputs() const { MS_EXCEPTION_IF_NULL(inputs_); return *inputs_; @@ -747,6 +848,23 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { return false; } +void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) { + // update graph inputs + MS_EXCEPTION_IF_NULL(old_parameter); + MS_EXCEPTION_IF_NULL(new_parameter); + if (old_parameter == new_parameter) { + return; + } + for (size_t i = 0; i < inputs_->size(); i++) { + if ((*inputs_)[i] == old_parameter) { + MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString() + << ",new graph input:" << new_parameter->DebugString(); + (*inputs_)[i] = new_parameter; + break; + } + } +} + void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { MS_EXCEPTION_IF_NULL(inputs_); { @@ -770,15 +888,7 @@ void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNullset_input(i, new_anf_node); } } - // update graph inputs - for (size_t i = 0; i < inputs_->size(); i++) { - if ((*inputs_)[i] == old_anf_node.get()) { - MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() - << ",new graph input:" << new_anf_node->DebugString(); - (*inputs_)[i] = new_anf_node.get(); - break; - } - } + ReplaceGraphInput(old_anf_node, new_anf_node); } // update front to backend map FrontBackendlMapUpdate(old_anf_node, new_anf_node); @@ -787,27 +897,6 @@ void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull seed_nodes; UpdateNodeEdgeList(&seed_nodes); } - // update graph inputs in child graph - auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(), - [&old_anf_node](const std::pair> &n) -> bool { - return n.first == old_anf_node.get(); - }); - if (it_real_inputs != real_inputs_.end()) { - // erase old parameter in map - auto old_args = it_real_inputs->second; - real_inputs_.erase(it_real_inputs); - // insert new parameter to map - auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(), - [&new_anf_node](const std::pair> &n) -> bool { - return n.first == new_anf_node.get(); - }); - if (iter != real_inputs_.end()) { - MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited."; - iter->second = old_args; - } else { - real_inputs_.emplace_back(new_anf_node, old_args); - } - } } void KernelGraph::UpdateExecuteKernelStreamLabel() { @@ -842,56 +931,6 @@ std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi return result; } -void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { - MS_EXCEPTION_IF_NULL(parameter); - MS_EXCEPTION_IF_NULL(arg); - MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); - MS_EXCEPTION_IF_NULL(parameter); - MS_EXCEPTION_IF_NULL(arg); - auto iter = std::find_if( - real_inputs_.begin(), real_inputs_.end(), - [¶meter](const std::pair> &n) -> bool { return n.first == parameter; }); - if (iter != real_inputs_.end()) { - auto &args = iter->second; - args.push_back(arg); - } else { - real_inputs_.emplace_back(parameter, std::vector(1, arg)); - } -} - -void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph) { - unreuse_args_[arg] = from_graph; -} - -void KernelGraph::UpdateCallRealInput() { - MS_LOG(INFO) << "Update graph id: " << graph_id_; - std::vector>> real_inputs_map; - for (auto &it : real_inputs_) { - auto parameter = it.first; - MS_EXCEPTION_IF_NULL(parameter); - auto real_inputs = it.second; - std::vector new_real_inputs; - for (auto &real_input : real_inputs) { - // if real input is a call node ,find the child graph output act as the new real input - auto tmp_real_input = GetCallRealOutputs(real_input); - std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); - // replace the call in unreuse_args_ - auto unreuse_arg_it = unreuse_args_.find(real_input); - if (unreuse_arg_it != unreuse_args_.end()) { - auto old_graph = unreuse_arg_it->second; - for (auto new_real_input : new_real_inputs) { - // if call reference graph output is parameter, it will be allowed to reuse - if (!new_real_input->isa()) { - unreuse_args_[new_real_input] = old_graph; - } - } - } - } - real_inputs_map.emplace_back(parameter, new_real_inputs); - } - real_inputs_ = real_inputs_map; -} - void KernelGraph::PrintGraphExecuteOrder() const { MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; for (size_t i = 0; i < execution_order_.size(); i++) { @@ -922,17 +961,44 @@ void KernelGraph::PrintGraphExecuteOrder() const { } } -void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { +void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx, + bool unique_target) { if (front_node == nullptr || node == nullptr) { MS_LOG(INFO) << "Front node or node is nullptr"; return; } MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); front_to_internal_outputs_map_[front_node] = node; - internal_outputs_to_front_map_[node] = front_node; + if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { + output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast()); + } + internal_outputs_to_front_map_[node][output_idx] = std::pair(front_node, unique_target); +} + +void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor) { + if (node == nullptr) { + return; + } + internal_outputs_tensor_map_[node][output_idx] = tensor; +} + +tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, int output_idx) { + if (node == nullptr) { + return nullptr; + } + auto iter = internal_outputs_tensor_map_.find(node); + if (iter == internal_outputs_tensor_map_.end()) { + return nullptr; + } + auto idx_iter = iter->second.find(output_idx); + if (idx_iter == iter->second.end()) { + return nullptr; + } + return idx_iter->second; } -void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) { +void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx, + int dst_output_idx) { if (new_node == nullptr || node == nullptr) { MS_LOG(INFO) << "New node or node is nullptr"; return; @@ -947,9 +1013,30 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr return; } MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString(); - internal_outputs_to_front_map_[new_node] = iter->second; - front_to_internal_outputs_map_[iter->second] = new_node; - internal_outputs_to_front_map_.erase(iter); + auto &front_nodes = iter->second; + // Move all front nodes to new node mapping + if (src_output_idx == -1) { + internal_outputs_to_front_map_[new_node] = front_nodes; + for (const auto &front_node_iter : front_nodes) { + front_to_internal_outputs_map_[front_node_iter.second.first] = new_node; + } + internal_outputs_to_front_map_.erase(iter); + return; + } + // Move specified front node to new node mapping + int index = SizeToInt(src_output_idx); + auto front_node_iter = front_nodes.find(index); + if (front_node_iter == front_nodes.end()) { + MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node"; + return; + } + auto front_node_pair = front_node_iter->second; + internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node_pair; + front_to_internal_outputs_map_[front_node_pair.first] = new_node; + front_nodes.erase(index); + if (front_nodes.empty()) { + internal_outputs_to_front_map_.erase(iter); + } } AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { @@ -960,36 +1047,32 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod return nullptr; } -bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { - if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { - return true; +bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const { + auto front_nodes_iter = internal_outputs_to_front_map_.find(node); + if (front_nodes_iter == internal_outputs_to_front_map_.end()) { + return false; } - return false; -} - -AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const { - auto iter = internal_outputs_to_front_map_.find(node); - if (iter != internal_outputs_to_front_map_.end()) { - return iter->second; + if (output_idx == -1) { + return true; } - return nullptr; -} - -void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { - if (node == nullptr) { - return; + auto &front_nodes = front_nodes_iter->second; + if (front_nodes.find(output_idx) == front_nodes.end()) { + return false; } - (void)final_output_kernels_.insert(node); + return true; } -bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { - if (node == nullptr) { +bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const { + auto front_nodes_iter = internal_outputs_to_front_map_.find(node); + if (front_nodes_iter == internal_outputs_to_front_map_.end()) { return false; } - if (final_output_kernels_.find(node) != final_output_kernels_.end()) { - return true; + auto &front_nodes = front_nodes_iter->second; + auto idx_iter = front_nodes.find(output_idx); + if (idx_iter == front_nodes.end()) { + return false; } - return false; + return idx_iter->second.second; } void KernelGraph::UpdateChildGraphOrder() { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 48df351120..047c21ea20 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H -#define MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H #include #include @@ -27,7 +27,7 @@ #include #include "ir/func_graph.h" #include "ir/anf.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/contract.h" #include "runtime/device/kernel_info.h" @@ -49,13 +49,17 @@ class KernelGraph : public FuncGraph { const std::vector &inputs() const; std::vector *MutableInputs() const { return inputs_.get(); } + void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter); std::vector outputs() const; CNodePtr NewCNode(const std::vector &inputs) override; void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); CNodePtr NewCNode(const CNodePtr &cnode); ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); + ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); + ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value); ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); - std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); + // trans tuple output to maketuple + no_tuple out + AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node); void set_execution_order(const std::vector &order) { execution_order_ = order; } const std::vector &execution_order() const { return execution_order_; } void SetExecOrderByDefault(); @@ -127,16 +131,8 @@ class KernelGraph : public FuncGraph { void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; - // get real inputs - const std::vector>> &real_inputs() const { return real_inputs_; } - void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); - // mark unreused args - void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph); - const std::map> &unreuse_args() const { return unreuse_args_; } // used to dump ir std::string ToString() const override; - // update the real input if the node is a call - void UpdateCallRealInput(); void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } CNodePtr get_start_label() { return start_label_; } @@ -147,13 +143,16 @@ class KernelGraph : public FuncGraph { void PrintGraphExecuteOrder() const; const std::map> &summary_nodes() const { return summary_nodes_; } void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } - void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); - void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); + void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx = 0, + bool unique_target = false); + void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1, + int dst_output_idx = -1); AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; - bool IsInternalOutput(const AnfNodePtr &node) const; - AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const; - void AddFinalOutputKernel(const AnfNodePtr &node); - bool IsFinalOutputKernel(const AnfNodePtr &node) const; + bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const; + bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const; + void AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor); + tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, int output_idx); + uint32_t current_epoch() const { return current_epoch_; } void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } void UpdateChildGraphOrder(); @@ -166,6 +165,8 @@ class KernelGraph : public FuncGraph { private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); + void SetKernelInfoForNode(const AnfNodePtr &node) const; + AnfNodePtr MakeValueNode(const AnfNodePtr &node); void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, std::unordered_set *visited_nodes); // update node edge list @@ -177,6 +178,10 @@ class KernelGraph : public FuncGraph { bool HandleControlDependNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes); void UpdateControlDependRelations(const std::vector &depends); + AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); + AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); + AnfNodePtr TransCNodeTuple(const CNodePtr &node); + AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); std::shared_ptr> inputs_; std::vector child_graph_result_; @@ -204,9 +209,6 @@ class KernelGraph : public FuncGraph { // valid inputs std::vector valid_inputs_; - // new members for control sink process - // all child grahs refers to partial node - std::map> node_to_child_graphs_; // child graph execute order in root graph std::vector> child_graph_order_; @@ -215,19 +217,16 @@ class KernelGraph : public FuncGraph { // parameter graph std::shared_ptr parent_graph_; - // record real parameters,inputs_ is the formal parameters - std::vector>> real_inputs_; - std::map> unreuse_args_; CNodePtr start_label_; CNodePtr end_goto_; bool null_output_; std::unordered_map front_to_internal_outputs_map_; - std::unordered_map internal_outputs_to_front_map_; - std::set final_output_kernels_; + std::unordered_map>> internal_outputs_to_front_map_; + std::unordered_map> internal_outputs_tensor_map_; uint32_t current_epoch_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H diff --git a/mindspore/ccsrc/backend/session/session.cc b/mindspore/ccsrc/backend/session/session.cc deleted file mode 100644 index 95484a1113..0000000000 --- a/mindspore/ccsrc/backend/session/session.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "include/inference.h" -#include "backend/session/session.h" -#include "utils/load_onnx/anf_converter.h" -#include "backend/session/session_basic.h" -#include "backend/session/session_factory.h" -#include "utils/base_ref_utils.h" -#include "backend/kernel_compiler/oplib/oplib.h" -#ifdef ENABLE_D -#include "utils/context/ms_context.h" -#include "backend/session/ascend_session.h" -#else -#include "backend/session/cpu_session.h" -#endif - -namespace py = pybind11; -namespace mindspore::inference { -std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { - try { - inference::Session::RegAllOp(); - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); - return anf_graph; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference LoadModel failed"; - return nullptr; - } -} - -void ExitInference() { - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return; - } - if (!ms_context->CloseTsd()) { - MS_LOG(ERROR) << "Inference CloseTsd failed!"; - return; - } -} - -std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { - try { - auto session = std::make_shared(); - auto ret = session->Init(device, device_id); - if (ret != 0) { - return nullptr; - } - return session; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CreatSession failed"; - return nullptr; - } -} - -void Session::RegAllOp() { - static std::mutex init_mutex; - static bool Initialized = false; - - std::lock_guard lock(init_mutex); - if (Initialized) { - return; - } - Initialized = true; - MsContext::GetInstance()->set_execution_mode(kGraphMode); - Py_Initialize(); - auto c_expression = PyImport_ImportModule("mindspore._c_expression"); - if (c_expression == nullptr) { - MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; - return; - } - PyObject *c_expression_dict = PyModule_GetDict(c_expression); - - PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); - if (op_info_loader_class == nullptr) { - MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; - return; - } - PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); - if (op_info_loader == nullptr) { - MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; - return; - } - PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); - if (op_info_loader_ins == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; - return; - } - auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); - if (all_ops_info_vector_addr_ul == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; - return; - } - auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); - auto all_ops_info = static_cast *>(all_ops_info_vector_addr); - for (auto op_info : *all_ops_info) { - kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); - } - all_ops_info->clear(); - delete all_ops_info; - Py_DECREF(op_info_loader); - Py_DECREF(op_info_loader_class); - Py_DECREF(c_expression_dict); - Py_DECREF(c_expression); - return; -} - -uint32_t Session::CompileGraph(std::shared_ptr funcGraphPtr) { - MS_ASSERT(session_impl_ != nullptr); - try { - auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); - py::gil_scoped_release gil_release; - return graph_id; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CompileGraph failed"; - return static_cast(-1); - } -} - -MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { - try { - std::vector inTensors; - inTensors.resize(inputs.size()); - bool has_error = false; - std::transform(inputs.begin(), inputs.end(), inTensors.begin(), - [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; - has_error = true; - return nullptr; - } - auto tensor = static_cast(tensor_ptr.get()); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; - has_error = true; - return nullptr; - } - return tensor->tensor(); - }); - if (has_error) { - MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; - std::vector> multiTensor; - return multiTensor; - } - VectorRef outputs; - session_impl_->RunGraph(graph_id, inTensors, &outputs); - - return TransformVectorRefToMultiTensor(outputs); - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference Rungraph failed"; - return MultiTensor(); - } -} -namespace { -string AjustTargetName(const std::string &device) { - if (device == kAscendDevice) { - return std::string(kAscendDevice) + "Inference"; - } else { - MS_LOG(ERROR) << "Only support device Ascend right now"; - return ""; - } -} -} // namespace -int Session::Init(const std::string &device, uint32_t device_id) { - RegAllOp(); - auto ms_context = MsContext::GetInstance(); - ms_context->set_execution_mode(kGraphMode); - ms_context->set_device_id(device_id); - auto ajust_device = AjustTargetName(device); - if (ajust_device == "") { - return -1; - } - ms_context->set_device_target(device); - session_impl_ = session::SessionFactory::Get().Create(ajust_device); - if (session_impl_ == nullptr) { - MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; - return -1; - } - session_impl_->Init(device_id); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return -1; - } - if (!ms_context->OpenTsd()) { - MS_LOG(ERROR) << "Session init OpenTsd failed!"; - return -1; - } - return 0; -} - -Session::Session() = default; -} // namespace mindspore::inference diff --git a/mindspore/ccsrc/backend/session/session.h b/mindspore/ccsrc/backend/session/session.h deleted file mode 100644 index 6ea9cfaa47..0000000000 --- a/mindspore/ccsrc/backend/session/session.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H -#define MINDSPORE_CCSRC_SESSION_SESSION_H - -#include -#include -#include -#include -#include -#include - -#include "backend/session/session_basic.h" -#include "ir/anf.h" -#include "include/inference.h" - -namespace mindspore { -namespace inference { -class Session : public MSSession { - public: - Session(); - - uint32_t CompileGraph(std::shared_ptr funcGraphPtr) override; - - MultiTensor RunGraph(uint32_t graph_id, const std::vector> &inputs) override; - - int Init(const std::string &device, uint32_t device_id); - - static void RegAllOp(); - - private: - std::shared_ptr session_impl_ = nullptr; - std::vector graph_id_; -}; -} // namespace inference -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 9755dfc7d0..bf60814564 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -24,26 +24,30 @@ #include "backend/kernel_compiler/common_utils.h" #include "frontend/operator/ops.h" #include "common/trans.h" -#include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "backend/optimizer/common/common_backend_optimization.h" #include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/helper.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "ir/dtype.h" #include "ir/anf.h" #include "ir/func_graph_cloner.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "frontend/parallel/ps/worker.h" +#include "frontend/parallel/ps/common.h" +#include "frontend/parallel/ps/util.h" +#endif namespace mindspore { namespace session { -static std::shared_ptr> python_paras; +static std::shared_ptr> python_paras; void ClearPythonParasMap() { python_paras = nullptr; } namespace { const int kSummaryGetItem = 2; -ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) { +ValuePtr GetParamDefaultValue(const AnfNodePtr &node) { if (node == nullptr) { return nullptr; } @@ -54,9 +58,53 @@ ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) { return parameter->default_param(); } -BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, +tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph, + const DeviceAddressPtr &address) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); + if (type_id == kTypeUnknown) { + type_id = AnfAlgo::GetOutputInferDataType(node, output_index); + } + tensor::TensorPtr tensor; + std::vector temp_shape; + if (graph->IsUniqueTargetInternalOutput(node, output_index)) { + temp_shape.emplace_back(1); + tensor = std::make_shared(type_id, temp_shape); + tensor->set_device_address(address); + tensor->set_dirty(false); + return tensor; + } + + tensor = graph->GetInternalOutputTensor(node, output_index); + if (tensor == nullptr) { + auto shape = AnfAlgo::GetOutputInferShape(node, output_index); + (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); + tensor = std::make_shared(type_id, temp_shape); + bool is_internal_output = graph->IsInternalOutput(node, output_index); + if (is_internal_output) { + graph->AddInternalOutputTensor(node, output_index, tensor); + } + } + // if in paynative mode,data only copyed to host when user want to print data + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + MS_EXCEPTION_IF_NULL(address); + if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { + tensor->set_device_address(address); + tensor->set_dirty(false); + } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), + LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { + MS_LOG(INFO) << "Output sync device to host error!!!"; + tensor->set_dirty(false); + } + return tensor; +} + +BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph, const std::vector &input_tensors) { MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; // if node is a value node, no need sync addr from device to host if (!AnfAlgo::OutputAddrExist(node, output_index)) { @@ -66,48 +114,22 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne return value_node->value(); } if (node->isa()) { - for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { + for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) { if (input_idx >= input_tensors.size()) { MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); } - if (graph.inputs()[input_idx] == node) { + if (graph->inputs()[input_idx] == node) { return input_tensors[input_idx]; } } MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; } } - // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); - MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, output_index); - TypeId type_id = kNumberTypeFloat32; - type_id = AnfAlgo::GetOutputInferDataType(node, output_index); - std::vector temp_shape; - if (graph.IsInternalOutput(node)) { - temp_shape.emplace_back(1); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - tensor->set_device_address(address); - tensor->set_dirty(false); - return tensor; - } - (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - // if in paynative mode,data only copyed to host when user want to print data - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { - tensor->set_device_address(address); - tensor->set_dirty(false); - } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), - LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { - MS_LOG(INFO) << "Output sync device to host error!!!"; - tensor->set_dirty(false); - } - return tensor; + return CreateOutputTensor(node, output_index, graph, address); } -BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, +BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph, const std::vector &input_tensors) { MS_EXCEPTION_IF_NULL(anf); MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; @@ -205,8 +227,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, auto param = graph->NewParameter(); MS_EXCEPTION_IF_NULL(param); if (tensor_mask == kParameterWeightTensorMask) { - auto param_value_new = std::make_shared(); - param->set_default_param(param_value_new); + param->set_default_param(input_tensor); } // set the kernel info of parameter auto kernel_build_info_builder = std::make_shared(); @@ -267,11 +288,27 @@ bool ExistSummaryNode(const KernelGraph *graph) { } return false; } + +bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &node_inputs = cnode->inputs(); + for (size_t i = 1; i < node_inputs.size(); ++i) { + if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) { + return false; + } + } + return true; +} } // namespace GraphId SessionBasic::graph_sum_ = 0; -KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { +KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { auto it = graphs_.find(graph_id); if (it == graphs_.end()) { MS_LOG(WARNING) << "Can't find graph " << graph_id; @@ -295,21 +332,24 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const MS_LOG(INFO) << "No corresponding internal output for output node"; return; } - auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); + size_t output_idx = 0; + if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { + output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast()); + } + auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx); auto ref_real_node = real_kernel.first; auto ref_real_node_index = real_kernel.second; - if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && - node_graph->IsFinalOutputKernel(ref_real_node)) { + if (ref_real_node->isa() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) { auto kernel_info = ref_real_node->kernel_info(); if (kernel_info == nullptr || !kernel_info->has_build_info()) { MS_LOG(INFO) << "No kernel info"; return; } - auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); - if (address == nullptr) { + if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) { MS_LOG(INFO) << "No kernel address"; return; } + auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); auto d_kernel_info = std::make_shared(); @@ -320,6 +360,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const builder.SetOutputsFormat({format}); d_kernel_info->set_select_kernel_build_info(builder.Build()); AnfAlgo::SetOutputAddr(address, 0, parameter.get()); + AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get()); } } @@ -329,8 +370,11 @@ std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr MS_EXCEPTION_IF_NULL(graph); std::vector parameters; std::vector pre_graph_out = {node}; + if (IgnoreCreateParameterForMakeTuple(node)) { + pre_graph_out.clear(); + } // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive - if (!AnfAlgo::IsRealKernel(node)) { + if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) { pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); } auto valid_inputs = graph->MutableValidInputs(); @@ -382,7 +426,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf ParameterPtr new_parameter = nullptr; // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter if (python_paras == nullptr) { - python_paras = std::make_shared>(); + python_paras = std::make_shared>(); } auto iter = python_paras->find(param_value); if (iter != python_paras->end()) { @@ -406,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; auto parameters = CreateParameterFromTuple(anf, valid_input, graph); if (parameters.empty()) { - MS_LOG(EXCEPTION) << "No parameter exist!!"; + MS_LOG(INFO) << "Empty parameter from cnode"; + return nullptr; } if (parameters.size() == 1) { return parameters[0]; @@ -441,10 +486,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K } auto origin_inputs = cnode->inputs(); bool optimize_depend = false; + bool optimize_control_depend = false; if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && origin_inputs[kRealInputIndexInDepend]->isa()) { optimize_depend = true; } + if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) { + optimize_control_depend = true; + } // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { auto anf = origin_inputs[input_idx]; @@ -463,7 +512,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K cnode_inputs.emplace_back(new_value_node); } continue; - } else if (anf->isa() && AnfAlgo::GetOutputTensorNum(anf) == 1) { + } else if (anf->isa()) { auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); cnode_inputs.push_back(new_parameter); if (GetGraphIdByNode(anf) == kInvalidGraphId) { @@ -475,10 +524,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); continue; + } else if (optimize_control_depend) { + cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); } else { *from_other_graph = true; // the input node is a cnode from other graph auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); + if (parameter_from_cnode == nullptr) { + parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); + } cnode_inputs.push_back(parameter_from_cnode); (*other_graph_cnode)[anf] = parameter_from_cnode; } @@ -653,7 +707,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph auto param_value = GetParamDefaultValue(anf); ParameterPtr new_parameter = nullptr; if (python_paras == nullptr) { - python_paras = std::make_shared>(); + python_paras = std::make_shared>(); } auto iter = python_paras->find(param_value); if (iter != python_paras->end()) { @@ -804,6 +858,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector ¶ } } +namespace { +bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0); + if (ms_context->enable_pynative_infer()) { + return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; + } + if (tensor->is_dirty()) { + return true; + } + if (tensor->device_address() != device_address) { + (void)tensor->data_sync(); + return true; + } + return false; +} +} // namespace + // run graph steps void SessionBasic::LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const { @@ -813,7 +886,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap if (kernel_graph->input_ctrl_tensors()) { input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); } - auto input_nodes = kernel_graph->inputs(); + std::vector input_nodes; + for (const auto &input_node : kernel_graph->inputs()) { + auto params = AnfAlgo::GetAllOutput(input_node); + std::copy(params.begin(), params.end(), std::back_inserter(input_nodes)); + } if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() << ", input_ctrl_size:" << input_ctrl_size; @@ -825,32 +902,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap MS_EXCEPTION_IF_NULL(tensor); auto input_node = input_nodes[i]; MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { - auto pk_node = input_node->cast(); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - bool need_sync = false; - if (ms_context->enable_pynative_infer()) { - if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { - need_sync = true; - } - } else { - if (tensor->is_dirty()) { - need_sync = true; - } else if (tensor->device_address() != device_address) { - (void)tensor->data_sync(); - need_sync = true; - } + if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { + auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); + if (ms_context->execution_mode() == kPynativeMode || + AnfAlgo::IsParameterWeight(input_node->cast())) { + tensor->set_device_address(device_address); } - if (need_sync) { - if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { - tensor->set_device_address(device_address); - } - MS_EXCEPTION_IF_NULL(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } + MS_EXCEPTION_IF_NULL(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; } } tensor->set_dirty(false); @@ -865,7 +927,7 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_grap for (auto &item : anf_outputs) { MS_EXCEPTION_IF_NULL(item); MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; - outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors)); + outputs->emplace_back(CreateTensorForOutput(item, kernel_graph, input_tensors)); } } @@ -876,7 +938,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { void SessionBasic::Reorder(std::vector *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } -void SessionBasic::GetSummaryNodes(KernelGraph *graph) { +void SessionBasic::SetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); if (!graph->summary_node_exist()) { @@ -916,7 +978,7 @@ void SessionBasic::Summary(KernelGraph *graph) { if (!exist_summary) { return; } - GetSummaryNodes(graph); + SetSummaryNodes(graph); auto summary_outputs = graph->summary_nodes(); std::map params_list; // fetch outputs apply kernel in session & run callback functions @@ -944,6 +1006,71 @@ void SessionBasic::Summary(KernelGraph *graph) { summary_callback_(0, params_list); } +namespace { +bool CNodePrimIsValueNode(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + auto cnode = node->cast(); + if (cnode == nullptr) { + return false; + } + auto prim = cnode->input(kAnfPrimitiveIndex); + if (prim == nullptr || !prim->isa()) { + return false; + } + return true; +} + +void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, + const FuncGraphManagerPtr &front_func_graph_manager, + const std::shared_ptr &backend_graph) { + auto node_users = front_func_graph_manager->node_users(); + auto users = node_users[front_node]; + auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); + auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); + + auto front_real_kernel = front_real_kernel_pair.first; + std::string kernel_target = GetCNodeTarget(front_real_kernel); + bool internal_output = CNodePrimIsValueNode(front_real_kernel); + bool unique_target = true; + if (internal_output && opt::IsNopNode(front_real_kernel)) { + auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0); + auto pre_node_target = GetCNodeTarget(pre_node_pair.first); + if (pre_node_target != kernel_target) { + unique_target = false; + } + } + if (internal_output) { + for (auto user : users) { + auto cnode = user.first->cast(); + if (cnode == nullptr) { + internal_output = false; + break; + } + auto prim = cnode->input(kAnfPrimitiveIndex); + if (prim == nullptr || !prim->isa()) { + internal_output = false; + break; + } + if (!AnfAlgo::IsRealKernel(user.first)) { + internal_output = false; + break; + } + if (kernel_target != GetCNodeTarget(user.first)) { + unique_target = false; + } + } + } + if (internal_output) { + MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << "To " + << backend_real_kernel_pair.first->DebugString(); + backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second, + unique_target); + } +} +} // namespace + CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector output_args; @@ -959,8 +1086,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: if (context_ptr->execution_mode() == kPynativeMode) { return backend_anf; } - auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); - auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); + MS_EXCEPTION_IF_NULL(out); auto out_func_graph = out->func_graph(); MS_EXCEPTION_IF_NULL(out_func_graph); @@ -968,20 +1094,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: if (out_func_graph_manager == nullptr) { return backend_anf; } - auto node_users = out_func_graph_manager->node_users(); - auto users = node_users[out]; - bool internal_output = true; - std::string kernel_target = GetCNodeTarget(front_real_kernel.first); - for (auto user : users) { - if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { - internal_output = false; - break; - } - } - if (internal_output) { - MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); - graph->AddInternalOutput(out, backend_real_kernel.first); - } + HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph); return backend_anf; } MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; @@ -1097,5 +1210,90 @@ KernelGraphPtr SessionBasic::NewKernelGraph() { graphs_[graph_sum_++] = graph; return graph; } + +AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector &node_list) { + MS_EXCEPTION_IF_NULL(push_node); + for (auto &node : node_list) { + if (node != nullptr && node->isa()) { + for (auto input : node->cast()->inputs()) { + if (push_node == AnfAlgo::VisitKernel(input, 0).first) { + if (AnfAlgo::GetCNodeName(node) != kPullOpName) { + MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid."; + } + return node; + } + } + } + } + return nullptr; +} + +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { + if (!parallel::ps::Util::IsRoleOfWorker()) { + MS_LOG(INFO) << "Not parameter server mode."; + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector node_list = TopoSort(kernel_graph->get_return()); + for (auto &node : node_list) { + if (node != nullptr && node->isa()) { + // Assign key for forward kernel EmbeddingLookup. + // The key will be assigned to embedding table ande Push kernel as well. + if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { + size_t embedding_table_idx = 0; + auto embedding_table = AnfAlgo::GetInputNode(node->cast(), embedding_table_idx); + size_t key = parallel::ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope()); + AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); + } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) { + auto pull_node = FindPullNode(node, node_list); + if (!pull_node) { + MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find Pull node of the Push node."; + } + + // Second input of Pull node is the trainable parameter. + size_t parameter_index = 1; + auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast(), parameter_index); + size_t key = parallel::ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope()); + AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); + AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node); + + std::string optimizer_name = AnfAlgo::GetNodeAttr(node, kAttrOptimizerType); + parallel::ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name); + } + } + } +} + +void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, + const std::vector &inputs_const) { + if (!parallel::ps::Util::IsRoleOfWorker()) { + return; + } + std::vector inputs(inputs_const); + size_t input_ctrl_size = 1; + MS_EXCEPTION_IF_NULL(kernel_graph); + if (kernel_graph->input_ctrl_tensors()) { + input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); + } + auto input_nodes = kernel_graph->inputs(); + if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { + MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() + << ", input_ctrl_size:" << input_ctrl_size; + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + for (size_t i = 0; i < inputs.size(); ++i) { + auto tensor = inputs[i]; + MS_EXCEPTION_IF_NULL(tensor); + auto input_node = input_nodes[i]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { + auto pk_node = input_node->cast(); + mindspore::parallel::ps::Worker::GetInstance().InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor); + } + } +} +#endif } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index c662e3978b..f20c27473e 100755 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H -#define MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H #include #include @@ -32,6 +32,7 @@ #include "utils/contract.h" #include "pipeline/pynative/pynative_execute.h" #include "runtime/device/kernel_info.h" +#include "utils/ms_context.h" #ifdef ENABLE_DEBUGGER #include "debug/debugger/debugger.h" #endif @@ -91,31 +92,30 @@ class SessionBasic { CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); - // set parameters of final graph - virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } - // set output of final graph - virtual void SetFinalGraphOutput(const BaseRef &) {} - // insert switch and set the relative active ops - virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} - // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter - virtual void SetChildGraphInput(GraphId, const VectorRef &) {} // get graph id in child graphs by ME front anf node pointer virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } - virtual void SetActive(GraphId, GraphId) {} - virtual void GetSummaryNodes(KernelGraph *graph); + void AssignParamKey(const KernelGraphPtr &kernel_graph); + void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector &inputs_const); + virtual bool CheckModelInputs(uint32_t graph_id, const std::vector &inputs, + std::string *error_msg) const { + return true; + } #ifdef ENABLE_DEBUGGER // set debugger void SetDebugger() { debugger_ = Debugger::GetInstance(); - debugger_->Init(device_id_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + debugger_->Init(device_id_, ms_context->device_target()); } #endif protected: + virtual void SetSummaryNodes(KernelGraph *graph); // Get graph by graph id ,if not exist return null ptr - KernelGraphPtr GetGraph(GraphId graph_id); + KernelGraphPtr GetGraph(GraphId graph_id) const; virtual void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, @@ -140,6 +140,7 @@ class SessionBasic { AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); + AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector &node_list); std::unordered_map> graphs_; std::unordered_map> run_op_graphs_; @@ -157,4 +158,4 @@ using SessionPtr = std::shared_ptr; using NamedSummaryOutputs = std::map>; } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/backend/session/session_context.h b/mindspore/ccsrc/backend/session/session_context.h index 22cc0c813a..6acb3ae18c 100644 --- a/mindspore/ccsrc/backend/session/session_context.h +++ b/mindspore/ccsrc/backend/session/session_context.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H -#define MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_CONTEXT_H +#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_CONTEXT_H #include #include #include @@ -24,7 +24,7 @@ #include "ir/tensor.h" #include "pipeline/jit/resource.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace session { const char kInputCtrlTensors[] = "input_ctrl_tensors"; @@ -47,4 +47,4 @@ class Context : public pipeline::ResourceBase { } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_CONTEXT_H diff --git a/mindspore/ccsrc/backend/session/session_factory.h b/mindspore/ccsrc/backend/session/session_factory.h index 054f03cf4b..6e4833fc2f 100644 --- a/mindspore/ccsrc/backend/session/session_factory.h +++ b/mindspore/ccsrc/backend/session/session_factory.h @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ -#define MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_FACTORY_H_ +#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_FACTORY_H_ #include #include #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/session/session_basic.h" namespace mindspore { namespace session { @@ -53,4 +53,4 @@ class SessionRegistrar { } // namespace session } // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_FACTORY_H_ diff --git a/mindspore/ccsrc/common.h b/mindspore/ccsrc/common.h index 6b882a15d4..635010cea8 100644 --- a/mindspore/ccsrc/common.h +++ b/mindspore/ccsrc/common.h @@ -25,7 +25,7 @@ #include "abstract/dshape.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/parse.h" #include "pipeline/jit/parse/parse_base.h" diff --git a/mindspore/ccsrc/common/CMakeLists.txt b/mindspore/ccsrc/common/CMakeLists.txt index 6673af4b70..cc393db54a 100644 --- a/mindspore/ccsrc/common/CMakeLists.txt +++ b/mindspore/ccsrc/common/CMakeLists.txt @@ -1,3 +1,16 @@ -file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +if (CMAKE_SYSTEM_NAME MATCHES "Windows") + file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "trans.cc" + "utils.cc" + "duplex_pipe_win.cc" + ) +else() + file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "trans.cc" + "utils.cc" + "duplex_pipe.cc" + ) +endif() + set_property(SOURCE ${_COMMON_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON) add_library(_mindspore_common_obj OBJECT ${_COMMON_ALL_SRC_FILES}) diff --git a/mindspore/ccsrc/common/duplex_pipe.cc b/mindspore/ccsrc/common/duplex_pipe.cc new file mode 100644 index 0000000000..b41be6e573 --- /dev/null +++ b/mindspore/ccsrc/common/duplex_pipe.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/duplex_pipe.h" + +#include +#include +#include +#include + +namespace mindspore { +int DuplexPipe::Open(std::initializer_list arg_list, bool append_fds) { + if (pipe(fd1_) == -1) { + DP_EXCEPTION << "pipe 1 failed, " << strerror(errno) << "(" << errno << ")"; + } + if (pipe(fd2_) == -1) { + close(fd1_[0]); + close(fd1_[1]); + DP_EXCEPTION << "pipe 2 failed, " << strerror(errno) << "(" << errno << ")"; + } + + pid_ = fork(); + if (pid_ < 0) { + close(fd1_[0]); + close(fd1_[1]); + close(fd2_[0]); + close(fd2_[1]); + DP_EXCEPTION << "fork failed, " << strerror(errno) << "(" << errno << ")"; + } else if (pid_ == 0) { // Remote process + DP_INFO << "Remote process, pid: " << getpid() << ", " << fd1_[0] << "/" << fd2_[1]; + remote_stdout_ = dup(STDOUT_FILENO); + remote_stdin_ = dup(STDIN_FILENO); + remote_stderr_ = dup(STDERR_FILENO); + close(fd1_[1]); + close(fd2_[0]); + if (!append_fds) { + dup2(fd1_[0], STDIN_FILENO); + dup2(fd2_[1], STDOUT_FILENO); + } + std::vector args; + std::transform(arg_list.begin(), arg_list.end(), std::back_inserter(args), + [](const std::string &arg) -> const char * { return arg.c_str(); }); + if (append_fds) { + std::string fd10 = std::to_string(fd1_[0]).c_str(); + args.emplace_back(fd10.c_str()); + std::string fd21 = std::to_string(fd2_[1]).c_str(); + args.emplace_back(fd21.c_str()); + } + args.emplace_back(nullptr); + if (execvp(args[0], const_cast(&args[0])) == -1) { + DP_EXCEPTION << "execute " << args[0] << " failed, " << strerror(errno) << "(" << errno << ")"; + } + } else { // Local process + DP_INFO << "Local process, id: " << getpid() << ", " << fd2_[0] << "/" << fd1_[1]; + local_stdout_ = dup(STDOUT_FILENO); + local_stdin_ = dup(STDIN_FILENO); + local_stderr_ = dup(STDERR_FILENO); + close(fd1_[0]); + close(fd2_[1]); + } + return 0; +} + +void DuplexPipe::Write(const std::string &buf, bool flush) { + // Write the string into pipe + if (write(fd1_[1], buf.data(), buf.size()) == -1) { + DP_ERROR << "write failed, error: " << strerror(errno) << "(" << errno << ")"; + return; + } + if (flush) { + // Flush into the pipe + if (write(fd1_[1], "\n", 1) == -1) { + DP_ERROR << "write failed, error: " << strerror(errno) << "(" << errno << ")"; + return; + } + } + DP_DEBUG << "<< [" << buf << "]"; +} + +std::string DuplexPipe::Read() { + // Read the string from pipe + std::string buf; + ssize_t size; + // MAYBE BLOCKED + // Read one line or multiple lines + while (SetTimeOut(), (size = read(fd2_[0], c_buf_, kBufferSize)) > 0) { // Till reading something + CancelTimeOut(); + DP_DEBUG << ">> [" << c_buf_ << "]"; + bool line_end = c_buf_[size - 1] == '\n'; + buf.append(c_buf_, line_end ? size - 1 : size); // Copy without the last '\n' + if (line_end) { + break; + } + } + return buf; +} + +void DuplexPipe::WriteWithStdout(const std::string &buf, bool flush) { + dup2(fd1_[1], STDOUT_FILENO); + // Write the string into pipe + std::cout << buf; + if (flush) { + // Flush into the pipe + std::cout << std::endl; + } + dup2(local_stdout_, STDOUT_FILENO); +} + +std::string DuplexPipe::ReadWithStdin() { + std::string buf; + dup2(fd2_[0], STDIN_FILENO); + // Maybe blocked + SetTimeOut(); + std::getline(std::cin, buf); // Not use 'std::cin >>' to include space + CancelTimeOut(); + dup2(local_stdin_, STDIN_FILENO); + return buf; +} + +DuplexPipe &DuplexPipe::operator<<(const std::string &buf) { + Write(buf); + return *this; +} + +DuplexPipe &DuplexPipe::operator>>(std::string &buf) { + buf = Read(); + return *this; +} + +void DuplexPipe::Close() { + close(fd1_[0]); + close(fd1_[1]); + close(fd2_[0]); + close(fd2_[1]); +} + +void DuplexPipe::Alarm::Set(std::shared_ptr dp, unsigned int interval_secs) { + dp_ = dp; + signal(SIGALRM, SigHandler); + alarm(interval_secs); +} + +void DuplexPipe::Alarm::Cancel() { + alarm(0); + dp_.reset(); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/common/duplex_pipe.h b/mindspore/ccsrc/common/duplex_pipe.h new file mode 100644 index 0000000000..a051e9b560 --- /dev/null +++ b/mindspore/ccsrc/common/duplex_pipe.h @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_COMMON_DUPLEX_PIPE_H_ +#define MINDSPORE_CCSRC_COMMON_DUPLEX_PIPE_H_ + +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" +#define DP_DEBUG MS_LOG(DEBUG) << "[DuplexPipe] " +#define DP_INFO MS_LOG(INFO) << "[DuplexPipe] " +#define DP_ERROR MS_LOG(ERROR) << "[DuplexPipe] " +#define DP_EXCEPTION MS_LOG(EXCEPTION) << "[DuplexPipe] " + +namespace mindspore { +// A tool to run a command as child process and build a duplex pipe between them. +// Similar to 'popen()', but use duplex not simplex pipe, more like 'socketpair'. +class DuplexPipe : public std::enable_shared_from_this { + public: + constexpr inline static int kBufferSize = 4096; + constexpr inline static unsigned int kTimeOutSeconds = 5; + + DuplexPipe() = default; + ~DuplexPipe() = default; + + // Create a subprocess and open a duplex pipe between local and remote + int Open(std::initializer_list arg_list, bool append_fds = false); + void Close(); + void SetTimeOutSeconds(unsigned int secs) { time_out_secs_ = secs; } + void SetTimeOutCallback(const std::function &cb) { + has_time_out_callback_ = true; + time_out_callback_ = cb; + } + + // Write the 'buf' to remote stdin + void Write(const std::string &buf, bool flush = true); + // Read from remote stdout/stderr into 'c_buf_' + std::string Read(); + + void WriteWithStdout(const std::string &buf, bool flush); + std::string ReadWithStdin(); + + DuplexPipe &operator<<(const std::string &buf); + DuplexPipe &operator>>(std::string &buf); + + private: + void SetTimeOut() { alarm_.Set(shared_from_this(), time_out_secs_); } + void CancelTimeOut() { alarm_.Cancel(); } + void TimeOut() { + if (has_time_out_callback_) { + time_out_callback_(); + } + Close(); + DP_EXCEPTION << "Time out when read from pipe"; + } + + // Subprocess id in parent process, + // otherwise zero in child process. + pid_t pid_; + + // Pipe: { Local:fd1_[1] --> Remote:fd1_[0] } + // Remote:fd1_[0] would be redirected by subprocess's stdin. + // Local:fd1_[1] would be used by 'Write()' as output. + int fd1_[2]; + + // Pipe: { Remote:fd2_[1] --> Local:fd2_[0] } + // Remote:fd2_[1] would be redirected by subprocess's stdout. + // Local:fd2_[0] would be used by 'Read()' as input. + int fd2_[2]; + + // // Used and returned by 'Read()'. + // std::string buf_; + char c_buf_[kBufferSize]; + + int local_stdin_; + int local_stdout_; + int local_stderr_; + int remote_stdin_; + int remote_stdout_; + int remote_stderr_; + + class Alarm { + public: + Alarm() = default; + ~Alarm() = default; + + void Set(std::shared_ptr dp, unsigned int interval_secs); + void Cancel(); + + private: + static void SigHandler(int sig) { + DP_INFO << "Signal: " << sig; + dp_->TimeOut(); + } + + inline static std::shared_ptr dp_; + }; + + unsigned int time_out_secs_ = kTimeOutSeconds; + bool has_time_out_callback_ = false; + std::function time_out_callback_; + Alarm alarm_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_COMMON_DUPLEX_PIPE_H_ diff --git a/mindspore/ccsrc/common/duplex_pipe_win.cc b/mindspore/ccsrc/common/duplex_pipe_win.cc new file mode 100644 index 0000000000..7face96522 --- /dev/null +++ b/mindspore/ccsrc/common/duplex_pipe_win.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/duplex_pipe.h" + +#include +#include + +namespace mindspore { +int DuplexPipe::Open(std::initializer_list arg_list, bool append_fds) { + DP_EXCEPTION << "Not support for Windows by now."; +} + +void DuplexPipe::Write(const std::string &buf, bool flush) { DP_EXCEPTION << "Not support for Windows by now."; } + +std::string DuplexPipe::Read() { DP_EXCEPTION << "Not support for Windows by now."; } + +void DuplexPipe::WriteWithStdout(const std::string &buf, bool flush) { + DP_EXCEPTION << "Not support for Windows by now."; +} + +std::string DuplexPipe::ReadWithStdin() { DP_EXCEPTION << "Not support for Windows by now."; } + +DuplexPipe &DuplexPipe::operator<<(const std::string &buf) { DP_EXCEPTION << "Not support for Windows by now."; } + +DuplexPipe &DuplexPipe::operator>>(std::string &buf) { DP_EXCEPTION << "Not support for Windows by now."; } + +void DuplexPipe::Close() { DP_EXCEPTION << "Not support for Windows by now."; } + +void DuplexPipe::Alarm::Set(std::shared_ptr dp, unsigned int interval_secs) { + DP_EXCEPTION << "Not support for Windows by now."; +} + +void DuplexPipe::Alarm::Cancel() { DP_EXCEPTION << "Not support for Windows by now."; } +} // namespace mindspore diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 1841826ca9..1b10a7d2f7 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -17,7 +17,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/kernel.h" #include "runtime/device/convert_tensor_utils.h" diff --git a/mindspore/ccsrc/common/utils.cc b/mindspore/ccsrc/common/utils.cc deleted file mode 100644 index 7109c121e5..0000000000 --- a/mindspore/ccsrc/common/utils.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/utils.h" -#include -#include -#include - -namespace mindspore { -namespace common { -const int CACHED_STR_NUM = 1 << 8; -const int CACHED_STR_MASK = CACHED_STR_NUM - 1; -std::vector STR_HOLDER(CACHED_STR_NUM); -const char *SafeCStr(const std::string &&str) { - static std::atomic index{0}; - uint32_t cur_index = index++; - cur_index = cur_index & CACHED_STR_MASK; - STR_HOLDER[cur_index] = str; - return STR_HOLDER[cur_index].c_str(); -} -} // namespace common -} // namespace mindspore diff --git a/mindspore/ccsrc/debug/CMakeLists.txt b/mindspore/ccsrc/debug/CMakeLists.txt index 37ffcceeaf..1f44c5e708 100644 --- a/mindspore/ccsrc/debug/CMakeLists.txt +++ b/mindspore/ccsrc/debug/CMakeLists.txt @@ -3,10 +3,6 @@ set(_DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/anf_ir_utils.cc" "${CMAKE_CURRENT_SOURCE_DIR}/draw.cc" "${CMAKE_CURRENT_SOURCE_DIR}/dump_proto.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/info.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/label.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/trace_info.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/trace_base.cc" "${CMAKE_CURRENT_SOURCE_DIR}/trace.cc" ) @@ -16,6 +12,7 @@ if (ENABLE_DEBUGGER) "${CMAKE_CURRENT_SOURCE_DIR}/debugger/grpc_client.cc" "${CMAKE_CURRENT_SOURCE_DIR}/debugger/proto_exporter.cc" "${CMAKE_CURRENT_SOURCE_DIR}/debug_services.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/common.cc" ) endif (ENABLE_DEBUGGER) @@ -23,9 +20,7 @@ if (ENABLE_D) list(APPEND _DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/common.cc" ) - if (ENABLE_DATA_DUMP) - list(APPEND _DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/data_dump_parser.cc") - endif(ENABLE_DATA_DUMP) + list(APPEND _DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/data_dump_parser.cc") endif() if (ENABLE_DUMP_E2E) diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index 42d372cefb..a1cc80f96b 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -25,8 +25,9 @@ #include "ir/primitive.h" #include "ir/func_graph.h" #include "runtime/device/kernel_info.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "backend/session/anf_runtime_algorithm.h" +#include "frontend/parallel/ops_info/operator_info.h" namespace mindspore { const std::string ToShortString(const TypeId &typeId) { @@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptroperator_info(); + auto operator_info = node->user_data(); if (operator_info == nullptr) { return; } diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 273a6f6458..261e811c91 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -23,11 +23,11 @@ #include #include -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/symbolic.h" #include "ir/meta_func_graph.h" #include "ir/param_value.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" #include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/resolve.h" #include "frontend/operator/composite/composite.h" @@ -36,8 +36,8 @@ #include "utils/ordered_set.h" #include "utils/utils.h" #include "debug/trace.h" -#include "debug/label.h" -#include "utils/context/ms_context.h" +#include "utils/label.h" +#include "utils/ms_context.h" #include "frontend/operator/ops.h" using mindspore::tensor::TensorPy; @@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap * │ └── MapPy * ├── Tail * ├── MakeTupleGradient + * ├── MakeListGradient * ├── GradOperation * └── TupleAdd */ @@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_ // do nothing } else if (meta_func_graph->isa()) { // do nothing + } else if (meta_func_graph->isa()) { + // do nothing } else if (meta_func_graph->isa()) { // do nothing } else if (meta_func_graph->isa()) { @@ -1667,7 +1670,7 @@ class IrParser { // load parameter default value from serialized file py::object default_obj = LoadObject(lexer_.GetTokenText()); - auto param_value_new = py::cast(default_obj); + auto param_value_new = py::cast(default_obj); param->set_default_param(param_value_new); tok = lexer_.GetNextToken(); diff --git a/mindspore/ccsrc/debug/common.cc b/mindspore/ccsrc/debug/common.cc index 6caf7e2c39..1683a3f803 100644 --- a/mindspore/ccsrc/debug/common.cc +++ b/mindspore/ccsrc/debug/common.cc @@ -21,7 +21,7 @@ #include "utils/system/env.h" #include "utils/system/file_system.h" #include "utils/log_adapter.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { std::optional Common::GetRealPath(const std::string &input_path) { @@ -80,9 +80,9 @@ bool Common::CreateNotExistDirs(const std::string &path) { char tmp_char = temp_path[i]; temp_path[i] = '\0'; std::string path_handle(temp_path); - if (!fs->FileExist(temp_path)) { + if (!fs->FileExist(path_handle)) { MS_LOG(INFO) << "Dir " << path_handle << " does not exit, creating..."; - if (!fs->CreateDir(temp_path)) { + if (!fs->CreateDir(path_handle)) { MS_LOG(ERROR) << "Create " << path_handle << " dir error"; return false; } @@ -120,6 +120,10 @@ std::optional Common::GetConfigFile(const std::string &env) { MS_LOG(ERROR) << dump_config_file << " not exist."; return {}; } + auto suffix = dump_config_file.substr(dump_config_file.find_last_of('.') + 1); + if (suffix != "json") { + MS_LOG(EXCEPTION) << "[DataDump] dump config file suffix only support json! But got:." << suffix; + } return dump_config_file; } } // namespace mindspore diff --git a/mindspore/ccsrc/debug/common.h b/mindspore/ccsrc/debug/common.h index 8d4a6cb467..1dafbfd99e 100644 --- a/mindspore/ccsrc/debug/common.h +++ b/mindspore/ccsrc/debug/common.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEBUG_COMMON_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEBUG_COMMON_H_ +#ifndef MINDSPORE_CCSRC_DEBUG_COMMON_H_ +#define MINDSPORE_CCSRC_DEBUG_COMMON_H_ #include #include @@ -33,4 +33,4 @@ class Common { static bool CreateNotExistDirs(const std::string &path); }; } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEBUG_COMMON_H_ +#endif // MINDSPORE_CCSRC_DEBUG_COMMON_H_ diff --git a/mindspore/ccsrc/debug/data_dump_parser.cc b/mindspore/ccsrc/debug/data_dump_parser.cc index 259ec388d3..abe9938ed3 100644 --- a/mindspore/ccsrc/debug/data_dump_parser.cc +++ b/mindspore/ccsrc/debug/data_dump_parser.cc @@ -17,25 +17,31 @@ #include "debug/data_dump_parser.h" #include -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "debug/common.h" -constexpr auto kDataDumpConfigPtah = "DATA_DUMP_CONFIG_PATH"; -constexpr auto kEnableDataDump = "ENABLE_DATA_DUMP"; -constexpr auto kDataDumpPath = "DATA_DUMP_PATH"; +static constexpr auto kDataDumpConfigPtah = "DATA_DUMP_CONFIG_PATH"; +static constexpr auto kEnableDataDump = "ENABLE_DATA_DUMP"; +static constexpr auto kDataDumpPath = "DATA_DUMP_PATH"; +static constexpr auto kConfigDumpMode = "dump_mode"; +static constexpr auto kConfigOpDebugMode = "op_debug_mode"; +static constexpr auto kConfigNetName = "net_name"; +static constexpr auto kConfigIteration = "iteration"; +static constexpr auto kConfigKernels = "kernels"; + namespace mindspore { void DataDumpParser::ResetParam() { enable_ = false; net_name_.clear(); dump_mode_ = 0; dump_step_ = 0; - kernel_set_.clear(); + kernel_map_.clear(); } bool DataDumpParser::DumpEnabled() const { auto enable_dump = std::getenv(kEnableDataDump); - if (!enable_dump) { - MS_LOG(WARNING) << "[DataDump] enable dump is null. Please export ENABLE_DATA_DUMP"; + if (enable_dump == nullptr) { + MS_LOG(INFO) << "[DataDump] enable dump is null. Please export ENABLE_DATA_DUMP"; return false; } @@ -55,14 +61,25 @@ bool DataDumpParser::DumpEnabled() const { std::optional DataDumpParser::GetDumpPath() const { auto dump_path = std::getenv(kDataDumpPath); - if (!dump_path) { + if (dump_path == nullptr) { MS_LOG(ERROR) << "[DataDump] dump path is null. Please export DATA_DUMP_PATH"; return {}; } std::string dump_path_str(dump_path); + if (!std::all_of(dump_path_str.begin(), dump_path_str.end(), + [](char c) { return ::isalpha(c) || ::isdigit(c) || c == '-' || c == '_' || c == '/'; })) { + MS_LOG(EXCEPTION) << "[DataDump] dump path only support alphabets, digit or {'-', '_', '/'}, but got:" + << dump_path_str; + } return dump_path_str; } +std::string GetIfstreamString(const std::ifstream &ifstream) { + std::stringstream buffer; + buffer << ifstream.rdbuf(); + return buffer.str(); +} + void DataDumpParser::ParseDumpConfig() { std::lock_guard guard(lock_); MS_LOG(INFO) << "[DataDump] parse start"; @@ -84,7 +101,12 @@ void DataDumpParser::ParseDumpConfig() { } nlohmann::json j; - json_file >> j; + try { + json_file >> j; + } catch (nlohmann::json::parse_error &e) { + MS_LOG(ERROR) << "[DataDump] json contents:" << GetIfstreamString(json_file); + MS_LOG(EXCEPTION) << "[DataDump] parse json failed, error:" << e.what(); + } if (j.find("DumpSettings") == j.end()) { MS_LOG(EXCEPTION) << "[DataDump] DumpSettings is not exist."; } @@ -111,13 +133,16 @@ bool DataDumpParser::NeedDump(const std::string &op_full_name) const { if (dump_mode_ == 0) { return true; } - auto iter = kernel_set_.find(op_full_name); - return iter != kernel_set_.end(); + auto iter = kernel_map_.find(op_full_name); + return iter != kernel_map_.end(); } bool DataDumpParser::IsConfigExist(const nlohmann::json &dump_settings) const { - if (dump_settings.find("mode") == dump_settings.end() || dump_settings.find("net_name") == dump_settings.end() || - dump_settings.find("iteration") == dump_settings.end() || dump_settings.find("kernels") == dump_settings.end()) { + if (dump_settings.find(kConfigDumpMode) == dump_settings.end() || + dump_settings.find(kConfigNetName) == dump_settings.end() || + dump_settings.find(kConfigOpDebugMode) == dump_settings.end() || + dump_settings.find(kConfigIteration) == dump_settings.end() || + dump_settings.find(kConfigKernels) == dump_settings.end()) { MS_LOG(ERROR) << "[DataDump] DumpSettings keys are not exist."; return false; } @@ -125,28 +150,63 @@ bool DataDumpParser::IsConfigExist(const nlohmann::json &dump_settings) const { } bool DataDumpParser::ParseDumpSetting(const nlohmann::json &dump_settings) { - auto mode = dump_settings.at("mode"); - auto net_name = dump_settings.at("net_name"); - auto iteration = dump_settings.at("iteration"); - auto kernels = dump_settings.at("kernels"); - if (!(mode.is_number() && net_name.is_string() && iteration.is_number() && kernels.is_array())) { + auto mode = dump_settings.at(kConfigDumpMode); + auto op_debug_mode = dump_settings.at(kConfigOpDebugMode); + auto net_name = dump_settings.at(kConfigNetName); + auto iteration = dump_settings.at(kConfigIteration); + auto kernels = dump_settings.at(kConfigKernels); + if (!(mode.is_number_unsigned() && op_debug_mode.is_number_unsigned() && net_name.is_string() && + iteration.is_number_unsigned() && kernels.is_array())) { MS_LOG(ERROR) << "[DataDump] Element's type in Dump config json is invalid."; enable_ = false; return false; } + CheckDumpMode(mode); + CheckOpDebugMode(op_debug_mode); + enable_ = true; auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); dump_mode_ = mode; + op_debug_mode_ = op_debug_mode; net_name_ = net_name; dump_step_ = iteration; for (const auto &kernel : kernels) { auto kernel_str = kernel.dump(); kernel_str.erase(std::remove(kernel_str.begin(), kernel_str.end(), '\"'), kernel_str.end()); MS_LOG(INFO) << "[DataDump] Need dump kernel:" << kernel_str; - kernel_set_.insert(kernel_str); + kernel_map_.insert({kernel_str, 0}); } return true; } + +void DataDumpParser::MatchKernel(const std::string &kernel_name) { + auto iter = kernel_map_.find(kernel_name); + if (iter == kernel_map_.end()) { + return; + } + iter->second = iter->second + 1; + MS_LOG(INFO) << "Match dump kernel:" << iter->first << " match times:" << iter->second; +} + +void DataDumpParser::PrintUnusedKernel() { + for (const auto &iter : kernel_map_) { + if (iter.second == 0) { + MS_LOG(WARNING) << "[DataDump] Unused Kernel in json:" << iter.first; + } + } +} + +void DataDumpParser::CheckDumpMode(uint32_t dump_mode) const { + if (dump_mode != 0 && dump_mode != 1) { + MS_LOG(EXCEPTION) << "[DataDump] dump_mode in config json should be 0 or 1"; + } +} + +void DataDumpParser::CheckOpDebugMode(uint32_t op_debug_mode) const { + if (op_debug_mode < 0 || op_debug_mode > 3) { + MS_LOG(EXCEPTION) << "[DataDump] op_debug_mode in config json file should be [0-3]"; + } +} } // namespace mindspore diff --git a/mindspore/ccsrc/debug/data_dump_parser.h b/mindspore/ccsrc/debug/data_dump_parser.h index 751c61dd1a..95a730fff9 100644 --- a/mindspore/ccsrc/debug/data_dump_parser.h +++ b/mindspore/ccsrc/debug/data_dump_parser.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ +#ifndef MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ +#define MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ #include -#include +#include #include #include #include "nlohmann/json.hpp" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { class DataDumpParser { @@ -38,8 +38,10 @@ class DataDumpParser { bool enable() const { return enable_; } const std::string &net_name() const { return net_name_; } uint32_t dump_mode() const { return dump_mode_; } + uint32_t op_debug_mode() const { return op_debug_mode_; } uint32_t dump_step() const { return dump_step_; } - const std::set &kernel_set() const { return kernel_set_; } + void MatchKernel(const std::string &kernel_name); + void PrintUnusedKernel(); private: DataDumpParser() = default; @@ -49,13 +51,16 @@ class DataDumpParser { void ResetParam(); bool IsConfigExist(const nlohmann::json &dump_settings) const; bool ParseDumpSetting(const nlohmann::json &dump_settings); + void CheckDumpMode(uint32_t dump_mode) const; + void CheckOpDebugMode(uint32_t op_debug_mode) const; std::mutex lock_; bool enable_{false}; std::string net_name_; + uint32_t op_debug_mode_{0}; uint32_t dump_mode_{0}; uint32_t dump_step_{0}; - std::set kernel_set_; + std::map kernel_map_; }; } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ +#endif // MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index cc6c5c53ad..1e99168c1e 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -171,6 +171,61 @@ void DebugServices::CheckWatchpoints(std::vector *name, std::vector } } +void DebugServices::CheckSingleWatchpoint(std::shared_ptr watchtensor, std::string *name, std::string *slot, + char **data_ptr, unsigned int *data_size, int *condition, + unsigned int *wacthpoint_id) { + std::lock_guard lg(lock_); + + std::string current_watchtensor_name; + current_watchtensor_name = watchtensor->GetName(); + mindspore::tensor::TensorPtr tensor_ptr = watchtensor->GetTensor(); + int tensor_data_type = tensor_ptr->data_type_c(); + watchpoint_t watchpoint_to_check; + + for (auto w_table_item : watchpoint_table) { + auto check_node_list = std::get<1>(w_table_item).check_node_list; + for (auto check_node : check_node_list) { + std::string w_name = std::get<0>(check_node); + bool w_type = std::get<1>(check_node); + // get current the full info including condition, id..., for current watchtensor + std::string current_node_name = current_watchtensor_name.substr(0, current_watchtensor_name.find_first_of(":")); + if ((w_type == true && (current_watchtensor_name.find(w_name) != string::npos || w_name == "*")) || + (w_type == false && current_node_name == w_name)) { + watchpoint_to_check = w_table_item.second; + // need to add support for float16 and float64, and other types when we support conditions beyond inf and nan + if (tensor_data_type != kNumberTypeFloat && tensor_data_type != kNumberTypeFloat32) { + return; + } + break; + } + } + } + + float *start_addr = reinterpret_cast(tensor_ptr->data_c()); + unsigned int num_elements = (tensor_ptr->data().nbytes()) / sizeof(float); + + for (unsigned int index = 0; index < num_elements; index++) { + float x = start_addr[index]; + if (((watchpoint_to_check.conditions.inf.enabled || watchpoint_to_check.conditions.neg_inf.enabled) && isinf(x)) || + (watchpoint_to_check.conditions.nan.enabled && isnan(x))) { + std::string name_no_slot = current_watchtensor_name.substr(0, current_watchtensor_name.find_first_of(":")); + *name = name_no_slot; + *slot = std::to_string(watchtensor->GetSlot()); + *data_ptr = reinterpret_cast(tensor_ptr->data_c()); + *data_size = tensor_ptr->data().nbytes(); + int condition_item = -1; + if (watchpoint_to_check.conditions.nan.enabled) { + condition_item = 0; + } else if (watchpoint_to_check.conditions.inf.enabled || watchpoint_to_check.conditions.neg_inf.enabled) { + condition_item = 1; + } + *condition = condition_item; + + *wacthpoint_id = watchpoint_to_check.id; + } + } +} + void DebugServices::ReadNodesTensors(std::vector name, std::vector *ret_name, std::vector *data_ptr, std::vector *data_size, std::vector *dtype, std::vector> *shape) { diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h index 41400af1d5..b664a9b9e9 100644 --- a/mindspore/ccsrc/debug/debug_services.h +++ b/mindspore/ccsrc/debug/debug_services.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CCSRC_DEBUG_DEBUG_SERVICES_H_ #define MINDSPORE_CCSRC_DEBUG_DEBUG_SERVICES_H_ +#include #include #include #include @@ -77,6 +78,9 @@ class DebugServices { std::vector *data_size, std::vector *condition, std::vector *wacthpoint_id); + void CheckSingleWatchpoint(std::shared_ptr watchnode, std::string *name, std::string *slot, + char **data_ptr, unsigned int *data_size, int *condition, unsigned int *wacthpoint_id); + void ReadNodesTensors(std::vector name, std::vector *ret_name, std::vector *data_ptr, std::vector *data_size, std::vector *dtype, std::vector> *shape); diff --git a/mindspore/ccsrc/debug/debugger/debug_grpc.proto b/mindspore/ccsrc/debug/debugger/debug_grpc.proto index f742987a4e..27c93787b8 100644 --- a/mindspore/ccsrc/debug/debugger/debug_grpc.proto +++ b/mindspore/ccsrc/debug/debugger/debug_grpc.proto @@ -31,6 +31,10 @@ service EventListener { message Metadata { string device_name = 1; int32 cur_step = 2; + // define the backend is 'GPU' or "Ascend" + string backend = 3; + // the full name of current node + string cur_node = 4; } message EventReply { @@ -44,12 +48,22 @@ message EventReply { oneof cmd { bool exit = 2; - int32 run_cmd = 3; + RunCMD run_cmd = 3; SetCMD set_cmd = 4; ViewCMD view_cmd = 5; } } +message RunCMD { + // step level or node level. "step" or "node" + string run_level = 1; + oneof cmd { + int32 run_steps = 2; + // the next node full name + string node_name = 3; + } +} + message SetCMD { repeated WatchNode watch_nodes = 1; WatchCondition watch_condition = 2; diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc index dd89e17e2d..77e75a5f19 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.cc +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -21,6 +21,7 @@ #include "debug/debugger/debugger.h" #include "pipeline/jit/pipeline.h" #include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_runtime_manager.h" using debugger::EventReply; using debugger::GraphProto; @@ -41,17 +42,23 @@ Debugger::Debugger() : grpc_client_(nullptr), debug_services_(nullptr), device_id_(0), + device_target_(""), num_step_(0), debugger_enabled_(false), + run_level_(""), + node_name_(""), + cur_name_(""), is_dataset_graph_(false), partial_memory_(false) {} -void Debugger::Init(const uint32_t device_id) { +void Debugger::Init(const uint32_t device_id, const std::string device_target) { // access lock for public method std::lock_guard a_lock(access_lock_); // save device_id MS_LOG(INFO) << "Debugger got device_id: " << device_id; device_id_ = device_id; + MS_LOG(INFO) << "Debugger got device_target: " << device_target; + device_target_ = device_target; } void Debugger::EnableDebugger() { @@ -62,6 +69,14 @@ void Debugger::EnableDebugger() { grpc_client_ = nullptr; debug_services_ = nullptr; + // see if dump is enabled + bool dump_enabled = false; + if (device_target_ == kGPUDevice) { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + dump_enabled = runtime_instance->DumpDataEnabled(); + } + // get env variables to configure debugger const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER"); if (env_enable_str != nullptr) { @@ -70,7 +85,8 @@ void Debugger::EnableDebugger() { debugger_enabled_ = true; } } - if (!debugger_enabled_) { + + if (!debugger_enabled_ && !dump_enabled) { MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger."; return; } @@ -118,7 +134,10 @@ void Debugger::EnableDebugger() { } // initialize grpc client - grpc_client_ = std::make_unique(host, port); + if (debugger_enabled_) { + grpc_client_ = std::make_unique(host, port); + } + debug_services_ = std::make_unique(); } @@ -127,6 +146,7 @@ void Debugger::Reset() { std::lock_guard a_lock(access_lock_); // reset components device_id_ = 0; + device_target_ = ""; num_step_ = 0; debugger_enabled_ = false; is_dataset_graph_ = false; @@ -147,10 +167,46 @@ void Debugger::PostExecute() { // access lock for public method std::lock_guard a_lock(access_lock_); // analyze tensor data and send the watchpoints been hit + if (run_level_ == "node") { + MS_LOG(INFO) << "Debugger is in node level mode "; + return; + } if (debugger_enabled_ && !is_dataset_graph_) { - num_step_++; MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_; - SendWatchpointsAndSuspend(CheckWatchpoints()); + CommandLoop(); + } +} + +bool Debugger::ReadNodeDataRequired() { + if (debugger_enabled_ && !is_dataset_graph_) { + auto watchpoint_table = debug_services_->GetWatchpointTable(); + auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, watchpoint_table); + // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data + if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) { + return true; + } + } + return false; +} + +void Debugger::PostExecuteNode() { + // access lock for public method + std::lock_guard a_lock(access_lock_); + if (debugger_enabled_ && !is_dataset_graph_) { + auto watchpoint_table = debug_services_->GetWatchpointTable(); + auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, watchpoint_table); + // if kernel is watchpoint,and get hit. suspend. + if (is_watchpoint) { + auto hits = CheckSingleWatchpoint(cur_name_); + if (!hits.empty()) { + SendWatchpointsAndSuspend(hits); + } + } + // if kernel is not watchpoint and is next_to or continue_to node, suspend. + if (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) { + CommandLoop(); + } + return; } } @@ -215,6 +271,8 @@ void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) { Metadata metadata; metadata.set_device_name(device_name); metadata.set_cur_step(num_step_); + metadata.set_backend(device_target_); + metadata.set_cur_node(cur_name_); EventReply reply_metadata = grpc_client_->SendMetadata(metadata); if (reply_metadata.status() != reply_metadata.OK) { MS_LOG(ERROR) << "Error: SendMetadata failed"; @@ -232,8 +290,11 @@ void Debugger::CommandLoop() { // prepare metadata std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id()); Metadata metadata; + metadata.set_device_name(device_name); metadata.set_cur_step(num_step_); + metadata.set_backend(device_target_); + metadata.set_cur_node(cur_name_); // loop exit flag bool run = false; @@ -274,6 +335,16 @@ void Debugger::CommandLoop() { break; case DebuggerCommand::kRunCMD: MS_LOG(INFO) << "RunCMD"; + { + // print run cmd content + // get run_level and node_name + run_level_ = GetRunLevel(reply); + node_name_ = GetNodeName(reply); + + MS_LOG(INFO) << "run_level: " << run_level_; + MS_LOG(INFO) << "node_name_: " << node_name_; + } + // exit loop run = true; break; @@ -428,6 +499,35 @@ std::list Debugger::CheckWatchpoints() const { return hits; } +std::list Debugger::CheckSingleWatchpoint(std::string watchnode) const { + auto tensor_loader = debug_services_->tensor_loader(); + auto tensors = tensor_loader->GetNodeTensorMap(watchnode); + std::list hits; + for (std::vector>::iterator it = tensors.begin(); it != tensors.end(); ++it) { + auto cur_tensor = *it; + std::string name = ""; + std::string slot = ""; + char *data_ptr = nullptr; + unsigned int data_size = 0; + int condition = -1; + unsigned int watchpoint_id = -1; + WatchpointHit hit; + debug_services_->CheckSingleWatchpoint(cur_tensor, &name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id); + if (name != "") { + hit.set_id(watchpoint_id); + // here TensorProto act as a tensor indicator, not sending tensor content + TensorProto *tensor_item = hit.mutable_tensor(); + tensor_item->set_node_name(name); + tensor_item->set_slot(slot); + tensor_item->set_finished(true); + WatchCondition *condition_item = hit.mutable_watch_condition(); + condition_item->set_condition(debugger::WatchCondition_Condition(condition)); + hits.push_back(hit); + } + } + return hits; +} + void Debugger::SendWatchpointsAndSuspend(const std::list &points) { // send info about watchpoint if (!points.empty()) { @@ -474,6 +574,24 @@ ProtoVector GetWatchnodes(const EventReply &reply) { return reply.set_cmd().watch_nodes(); } +std::string GetRunLevel(const EventReply &reply) { + if (!reply.has_run_cmd()) { + MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: " + ""; + return ""; + } + return reply.run_cmd().run_level(); +} + +std::string GetNodeName(const EventReply &reply) { + if (!reply.has_run_cmd()) { + MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: " + ""; + return ""; + } + return reply.run_cmd().node_name(); +} + WatchCondition GetWatchcondition(const EventReply &reply) { if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) { MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition()."; @@ -519,4 +637,20 @@ std::string GetTensorFullName(const TensorProto &tensor) { bool Debugger::partial_memory() { return partial_memory_; } +void Debugger::SetCurNode(std::string cur_name) { + // access lock for public method + std::lock_guard a_lock(access_lock_); + cur_name_ = cur_name; +} + +std::string Debugger::run_level() const { return run_level_; } + +void Debugger::SetStepNum(int32_t cur_num_step) { + // access lock for public method + std::lock_guard a_lock(access_lock_); + num_step_ = cur_num_step; +} + +int32_t Debugger::step_num() const { return num_step_; } + } // namespace mindspore diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h index 5a3965d7cc..ea035708ea 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.h +++ b/mindspore/ccsrc/debug/debugger/debugger.h @@ -55,7 +55,7 @@ class Debugger : public std::enable_shared_from_this { // init // only save device_id - void Init(const uint32_t device_id); + void Init(const uint32_t device_id, const std::string device_target); // reset debugger void Reset(); @@ -69,6 +69,10 @@ class Debugger : public std::enable_shared_from_this { // don't need a graph_ptr because it is saved during pre_execute void PostExecute(); + bool ReadNodeDataRequired(); + + void PostExecuteNode(); + // suspend the execution after a debug_op void PostDebugOp(); @@ -78,6 +82,14 @@ class Debugger : public std::enable_shared_from_this { bool partial_memory(); + void SetCurNode(std::string cur_name); + + std::string run_level() const; + + void SetStepNum(int32_t cur_num_step); + + int32_t step_num() const; + private: // private constructor for singleton Debugger(); @@ -119,6 +131,7 @@ class Debugger : public std::enable_shared_from_this { // analyze tensors and check watchpoint conditions // return names of tensors and what condition they hit std::list CheckWatchpoints() const; + std::list CheckSingleWatchpoint(std::string watchnode) const; // send watchpoints that hit and enter command wait loop void SendWatchpointsAndSuspend(const std::list &points); @@ -128,8 +141,12 @@ class Debugger : public std::enable_shared_from_this { std::unique_ptr debug_services_; KernelGraphPtr graph_ptr_; uint32_t device_id_; + std::string device_target_; int32_t num_step_; bool debugger_enabled_; + std::string run_level_; + std::string node_name_; + std::string cur_name_; bool is_dataset_graph_; bool partial_memory_; std::mutex access_lock_; @@ -153,6 +170,8 @@ DebuggerCommand GetCommand(const EventReply &reply); // parse other data out of EventReply ProtoVector GetWatchnodes(const EventReply &reply); +std::string GetNodeName(const EventReply &reply); +std::string GetRunLevel(const EventReply &reply); WatchCondition GetWatchcondition(const EventReply &reply); int32_t GetWatchpointID(const EventReply &reply); bool GetWatchpointDelete(const EventReply &reply); diff --git a/mindspore/ccsrc/debug/debugger/proto_exporter.cc b/mindspore/ccsrc/debug/debugger/proto_exporter.cc index b4b4de9d99..35e7d906ed 100644 --- a/mindspore/ccsrc/debug/debugger/proto_exporter.cc +++ b/mindspore/ccsrc/debug/debugger/proto_exporter.cc @@ -24,7 +24,7 @@ #include "debug/debugger/debugger.h" #include "proto/debug_graph.pb.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/symbolic.h" namespace mindspore { diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index ff8132fb28..bf16fbf537 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -23,17 +23,14 @@ #include #include -#include "pybind11/pybind11.h" #include "ir/meta_func_graph.h" #include "ir/param_value.h" #include "ir/primitive.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/utils.h" #include "frontend/operator/composite/composite.h" #include "ir/tensor.h" -namespace py = pybind11; - namespace mindspore { // namespace to support debug utils @@ -321,8 +318,9 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { buffer_ << parameter->ToString(); auto param = parameter->cast(); if (param->has_default()) { - auto tensor = param->default_param()->value(); - if (tensor) { + auto tensor_v = param->default_param(); + if (tensor_v && tensor_v->isa()) { + auto tensor = tensor_v->cast(); auto &shape = tensor->shape(); std::ostringstream shape_str; std::copy(shape.begin(), shape.end(), std::ostream_iterator(shape_str, ",")); @@ -437,7 +435,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr) { return; } - auto distributed_operation_info = node->operator_info(); + auto distributed_operation_info = node->user_data(); if (distributed_operation_info != nullptr) { auto strategyPtr = distributed_operation_info->strategy(); if (strategyPtr != nullptr) { @@ -645,5 +643,13 @@ void ModelDigraph::Edge(AnfNodePtr start, AnfNodePtr end, int idx, int id_start) buffer_ << "[arrowhead=vee,"; buffer_ << "]" << std::endl; } + +struct DrawerRegister { + DrawerRegister() { + FuncGraph::set_drawer( + [](const std::string &filename, const FuncGraphPtr &func_graph) { Draw(filename, func_graph); }); + } + ~DrawerRegister() = default; +} drawer_regsiter; } // namespace draw } // namespace mindspore diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc index 35cdfafe26..9594fa6b52 100644 --- a/mindspore/ccsrc/debug/dump_proto.cc +++ b/mindspore/ccsrc/debug/dump_proto.cc @@ -24,7 +24,7 @@ #include "debug/anf_ir_utils.h" #include "proto/anf_ir.pb.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/symbolic.h" namespace mindspore { @@ -120,10 +120,12 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); } } - } else if (type->isa()) { + } else if (type->isa()) { // Do Nothing } else if (type->isa()) { // Do Nothing + } else if (type->isa()) { + // Do Nothing } else if (type->isa()) { TuplePtr tuple_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_TUPLE); diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc index 9037a6d00b..147582a360 100644 --- a/mindspore/ccsrc/debug/e2e_dump.cc +++ b/mindspore/ccsrc/debug/e2e_dump.cc @@ -23,7 +23,7 @@ #include "utils/system/file_system.h" #include "utils/system/env.h" #include "utils/convert_utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "debug/common.h" using json = nlohmann::json; diff --git a/mindspore/ccsrc/debug/info.cc b/mindspore/ccsrc/debug/info.cc deleted file mode 100644 index f58522cf33..0000000000 --- a/mindspore/ccsrc/debug/info.cc +++ /dev/null @@ -1,222 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "debug/info.h" -#include -#include -#include -#include -#include "ir/anf.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { - std::string temp_line = line; - if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && - tip != kSourceLineTipDiscard) { - std::string start = temp_line.substr(0, IntToSize(col_begin)); - std::string trimmed = temp_line.substr(IntToSize(col_begin), IntToSize(col_end - col_begin)); - std::string end = temp_line.substr(IntToSize(col_end), IntToSize(SizeToInt(temp_line.length()) - col_end)); - std::stringstream oss; - std::stringstream tip_ss; - std::string start_spaces(start.length(), ' '); - if (tip == kSourceLineTipInLine) { - temp_line = start + "<" + trimmed + ">" + end; - } else if (tip == kSourceLineTipNextLine) { - tip_ss << start_spaces << "^"; - } - oss << temp_line << "\n" << tip_ss.str(); - return oss.str(); - } - return temp_line; -} -// Generate debug information for the location node . -// print the file name, line no and column no, and part of the content -std::string Location::ToString(SourceLineTip tip) { - std::stringstream debug_info_ss; - debug_info_ss << " In file " << file_name_ << "(" << line_ << ")" << std::endl; - if (line_ <= 0) { - return debug_info_ss.str(); - } - - char path[PATH_MAX + 1] = {0x00}; -#if defined(_WIN32) || defined(_WIN64) - if (file_name_.size() > PATH_MAX || _fullpath(path, file_name_.c_str(), PATH_MAX) == nullptr) { - return debug_info_ss.str(); - } -#else - if (file_name_.size() > PATH_MAX || realpath(file_name_.c_str(), path) == nullptr) { - return debug_info_ss.str(); - } -#endif - auto src_path = std::string(path); - std::ifstream file(src_path); - if (!file.is_open()) { - return debug_info_ss.str(); - } - - int line_num = 0; - std::string line; - (void)getline(file, line); - while (line_num != line_ - 1) { - (void)getline(file, line); - line_num++; - } - file.close(); - - debug_info_ss << HighLightLine(line, column_, column_end_, tip) << std::endl; - return debug_info_ss.str(); -} - -void TraceContext::ProcessAttributeFromContext() { - trace_info_ = nullptr; - location_ = nullptr; - func_name_ = ""; - // if there is trace context, get info from previous context - if (!TraceManager::trace_context_stack_.empty()) { - TraceContextPtr top = TraceManager::trace_context_stack_.top(); - trace_info_ = top->trace_info_; - location_ = top->location_; - func_name_ = top->func_name_; - } -} - -DebugInfo::DebugInfo() { - InitValueFromContext(); - unique_id_ = gen_unique_id(); - debug_id_ = -1; - name_ = ""; -} - -DebugInfo::DebugInfo(const std::string &name) { - InitValueFromContext(); - unique_id_ = gen_unique_id(); - debug_id_ = -1; - name_ = name; -} - -DebugInfo::DebugInfo(const LocationPtr &loc) { - InitValueFromContext(); - unique_id_ = gen_unique_id(); - debug_id_ = -1; - location_ = loc; -} - -int64_t DebugInfo::debug_id() { - // cppcheck-suppress variableScope - static int64_t cur_debug_id = 0; - if (debug_id_ == -1) { - debug_id_ = cur_debug_id; - cur_debug_id++; - } - return debug_id_; -} - -int64_t DebugInfo::unique_id_through_copy() const { - auto info = trace_info(); - if (info != nullptr) { - if (info->isa() && info->debug_info() != nullptr) { - return info->debug_info()->unique_id_through_copy(); - } - } - return unique_id(); -} - -std::string DebugInfo::debug_name() { - if (!name_.empty()) { - return name_; - } - std::string debug_name = std::to_string(debug_id()); - name_ = debug_name; - return debug_name; -} - -std::string NodeDebugInfo::debug_name() { - if (!name_.empty()) { - return name_; - } - std::string prefix = ""; - if (node_.lock() != nullptr) { - std::ostringstream oss; - oss << "[" << node_.lock()->type_name() << "]"; - prefix = oss.str(); - } - name_ = prefix + DebugInfo::debug_name(); - return name_; -} - -std::string GraphDebugInfo::debug_name() { - std::string prefix = ""; - return prefix + DebugInfo::debug_name(); -} - -LocationPtr GraphDebugInfo::location() { - // function may have decorator which is included in its location - if (deco_loc_ != nullptr) { - LocationPtr loc = std::make_shared(*DebugInfo::location()); - loc->set_line(loc->line() + (deco_loc_->line_end() - deco_loc_->line() + 1)); - return loc; - } - return DebugInfo::location(); -} -void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } - -TraceContextPtr TraceManager::CurrentContextInfo() { - if (!TraceManager::trace_context_stack_.empty()) { - return TraceManager::trace_context_stack_.top(); - } - return nullptr; -} - -void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { - TraceContextPtr context = std::make_shared(location); - context->set_func_name(func_name); - TraceManager::trace_context_stack_.push(context); -} - -void TraceManager::DebugTrace(const LocationPtr &location) { - TraceContextPtr context = std::make_shared(location); - TraceManager::trace_context_stack_.push(context); -} - -void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { - if (trace_info == nullptr) { - MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; - } - TraceContextPtr context = std::make_shared(trace_info); - if (trace_info->debug_info() == nullptr) { - MS_LOG(EXCEPTION) << "Trace debug info is null"; - } - TraceManager::trace_context_stack_.push(context); -} - -void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { - if (trace_info == nullptr) { - MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; - } - auto cloned_info = trace_info->clone(); - cloned_info->set_debug_info(debug_info); - if (cloned_info->debug_info() == nullptr) { - MS_LOG(EXCEPTION) << "Trace debug info is null with cloned trace"; - } - TraceContextPtr context = std::make_shared(cloned_info); - TraceManager::trace_context_stack_.push(context); -} - -void TraceManager::EndTrace() { TraceManager::trace_context_stack_.pop(); } - -std::stack TraceManager::trace_context_stack_; -} // namespace mindspore diff --git a/mindspore/ccsrc/debug/info.h b/mindspore/ccsrc/debug/info.h deleted file mode 100644 index 39475a4606..0000000000 --- a/mindspore/ccsrc/debug/info.h +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_INFO_H_ -#define MINDSPORE_CCSRC_IR_INFO_H_ - -#include -#include -#include -#include -#include -#include - -#include "base/base.h" -#include "debug/trace_info.h" - -namespace mindspore { -// namespace to support intermediate representation definition -enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 }; - -// Location class record the location in source code. -class Location { - public: - Location(const std::string &file_name, int line, int column, int line_end, int column_end) - : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} - Location(const Location &loc) - : file_name_(loc.file_name_), - line_(loc.line_), - column_(loc.column_), - line_end_(loc.line_end_), - column_end_(loc.column_end_) {} - std::string ToString(SourceLineTip tip = kSourceLineTipNextLine); - std::string file_name() { return file_name_; } - int line() const { return line_; } - void set_line(int line) { line_ = line; } - int line_end() const { return line_end_; } - void set_line_end(int line) { line_end_ = line; } - int column() const { return column_; } - void set_column(int column) { column_ = column; } - int column_end() const { return column_end_; } - void set_column_end(int column) { column_end_ = column; } - ~Location() = default; - - private: - std::string file_name_; - int line_; - int column_; - int line_end_; - int column_end_; -}; -class TraceContext; -using TraceContextPtr = std::shared_ptr; -class FuncGraph; -using FuncGraphPtr = std::shared_ptr; -using FuncGraphWeakPtr = std::weak_ptr; -class AnfNode; -using AnfNodeWeakPtr = std::weak_ptr; - -class TraceManager { - public: - TraceManager() = default; - ~TraceManager() = default; - static TraceContextPtr CurrentContextInfo(); - static void DebugTrace(const std::string &func_name, const LocationPtr &location); - static void DebugTrace(const LocationPtr &location); - static void DebugTrace(const TraceInfoPtr &trace_info); - // debug trace with a cloned trace info with debug_info - static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); - static void EndTrace(); - static std::stack trace_context_stack_; -}; - -class TraceGuard { - public: - explicit TraceGuard(const std::string func_name, const LocationPtr &location) { - TraceManager::DebugTrace(func_name, location); - } - explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } - ~TraceGuard() { TraceManager::EndTrace(); } -}; - -class TraceContext { - public: - LocationPtr location_; - TraceInfoPtr trace_info_; - std::string func_name_; - - protected: - void ProcessAttributeFromContext(); - - public: - ~TraceContext() = default; - explicit TraceContext(const LocationPtr &loc) { - ProcessAttributeFromContext(); - location_ = loc; - } - explicit TraceContext(const std::string &func_name) { - ProcessAttributeFromContext(); - func_name_ = func_name; - } - explicit TraceContext(const TraceInfoPtr &trace_info) { - ProcessAttributeFromContext(); - trace_info_ = trace_info; - } - void set_location(const LocationPtr &loc) { location_ = loc; } - LocationPtr location() { return location_; } - void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } - TraceInfoPtr trace_info() const { return trace_info_; } - void set_func_name(const std::string &func_name) { func_name_ = func_name; } - std::string func_name() { return func_name_; } -}; - -class DebugInfo : public Base { - public: - DebugInfo(); - - explicit DebugInfo(const std::string &name); - - explicit DebugInfo(const LocationPtr &loc); - - ~DebugInfo() override = default; - MS_DECLARE_PARENT(DebugInfo, Base); - int64_t debug_id(); - int64_t unique_id() const { return unique_id_; } - int64_t unique_id_through_copy() const; - std::string get_id() { return std::to_string(debug_id()); } - - void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } - TraceInfoPtr trace_info() const { return trace_info_; } - void set_location(const LocationPtr &loc) { location_ = loc; } - virtual LocationPtr location() { return location_; } - std::string name() { return name_; } - void set_name(const std::string &name) { name_ = name; } - virtual std::string debug_name(); - - virtual std::string get_python_func_belonged() { return ""; } - - protected: - template - std::shared_ptr shared_from_base() { - return std::static_pointer_cast(shared_from_this()); - } - - private: - void InitValueFromContext() { - if (TraceManager::CurrentContextInfo() != nullptr) { - auto context_info = TraceManager::CurrentContextInfo(); - trace_info_ = context_info->trace_info(); - location_ = context_info->location(); - } - } - static int64_t gen_unique_id() { - static int64_t cur_unique_id = 0; - return cur_unique_id++; - } - - protected: - int64_t unique_id_; - int64_t debug_id_; - TraceInfoPtr trace_info_; - LocationPtr location_; - std::string name_; -}; - -class NodeDebugInfo : public DebugInfo { - public: - NodeDebugInfo() { - if (TraceManager::CurrentContextInfo() != nullptr) { - auto context_info = TraceManager::CurrentContextInfo(); - py_func_belonged_ = context_info->func_name(); - } - } - explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { - if (TraceManager::CurrentContextInfo() != nullptr) { - auto context_info = TraceManager::CurrentContextInfo(); - py_func_belonged_ = context_info->func_name(); - } - } - ~NodeDebugInfo() override = default; - - std::string debug_name() override; - void set_node(const std::shared_ptr &node) { node_ = AnfNodeWeakPtr(node); } - std::shared_ptr get_node() const { return node_.lock(); } - void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } - std::string get_python_func_belonged() override { return py_func_belonged_; } - AnfNodeWeakPtr node_; - std::string py_func_belonged_; -}; -using NodeDebugInfoPtr = std::shared_ptr; - -class GraphDebugInfo : public DebugInfo { - public: - GraphDebugInfo() { - if (TraceManager::CurrentContextInfo() != nullptr) { - auto context_info = TraceManager::CurrentContextInfo(); - py_func_name_ = context_info->func_name(); - deco_loc_ = nullptr; - } - } - - explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { - if (TraceManager::CurrentContextInfo() != nullptr) { - auto context_info = TraceManager::CurrentContextInfo(); - py_func_name_ = context_info->func_name(); - deco_loc_ = nullptr; - } - } - ~GraphDebugInfo() override = default; - std::string debug_name() override; - LocationPtr location() override; - LocationPtr deco_location() { return deco_loc_; } - void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } - FuncGraphPtr get_graph() const { return func_graph_.lock(); } - void set_full_name(const std::string &name) { full_name_ = name; } - std::string get_full_name() { return full_name_; } - void set_deco_location(const LocationPtr &deco_list_loc); - std::string get_python_func_belonged() override { return py_func_name_; } - FuncGraphWeakPtr func_graph_; - LocationPtr deco_loc_; - std::string py_func_name_; - std::string full_name_; -}; - -using GraphDebugInfoPtr = std::shared_ptr; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_INFO_H_ diff --git a/mindspore/ccsrc/debug/label.cc b/mindspore/ccsrc/debug/label.cc deleted file mode 100644 index d8c4986482..0000000000 --- a/mindspore/ccsrc/debug/label.cc +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "debug/label.h" -#include -#include -#include - -#include "debug/info.h" -#include "ir/func_graph.h" - -namespace mindspore { -namespace label_manage { -static TraceLabelType trace_type = TraceLabelType::kShortSymbol; -TraceLabelType GetGlobalTraceLabelType() { return trace_type; } -void SetGlobalTraceLabelType(TraceLabelType label_type) { trace_type = label_type; } -struct NameWithTrace { - std::string name; - std::vector trace_labels; -}; -static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { - switch (trace_label) { - case TraceLabelType::kShortSymbol: - return trace_info->symbol(); - case TraceLabelType::kFullName: - return "_" + trace_info->full_name() + "_"; - default: - return ""; - } -} - -NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { - NameWithTrace trace_name; - // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node - auto temp_info = debug_info; - while (temp_info != nullptr) { - if (temp_info->trace_info() != nullptr) { - if (temp_info->trace_info()->isa() || temp_info->trace_info()->isa() || - temp_info->trace_info()->isa()) { - break; - } - trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); - temp_info = temp_info->trace_info()->debug_info(); - } else { - break; - } - } - if (!temp_info->name().empty()) { - trace_name.name = temp_info->name(); - } else { - trace_name.name = temp_info->debug_name(); - } - return trace_name; -} - -std::string CombineTraceTypes(const std::string &root_name, const std::vector &trace_labels) { - std::string tags = ""; - for (auto &itr : trace_labels) { - std::string symbol = itr; - tags = tags + symbol; - } - return tags + root_name; -} - -// get the label name of the node debug info -std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { - NameWithTrace trace_name = RootName(debug_info, trace_label); - return CombineTraceTypes(trace_name.name, trace_name.trace_labels); -} - -std::string CombineUniqueID(const DebugInfoPtr &debug_info) { - auto temp_info = debug_info; - std::string label = ""; - while (temp_info != nullptr) { - if (!temp_info->name().empty()) { - label = label + temp_info->name(); - } else { - // the symbol 'U' is for identification of number - label = label + "U" + std::to_string(temp_info->unique_id()); - } - - if (temp_info->trace_info() != nullptr) { - label = label + "_" + temp_info->trace_info()->full_name() + "_"; - temp_info = temp_info->trace_info()->debug_info(); - } else { - temp_info = nullptr; - } - } - return label; -} - -// get trace with unique id chain -std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); } - -std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { - if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { - return LabelStringUnique(debug_info); - } - return LabelString(debug_info, trace_label); -} -} // namespace label_manage -} // namespace mindspore diff --git a/mindspore/ccsrc/debug/tensor_load.h b/mindspore/ccsrc/debug/tensor_load.h index ae0e89aae2..8c4072ec49 100644 --- a/mindspore/ccsrc/debug/tensor_load.h +++ b/mindspore/ccsrc/debug/tensor_load.h @@ -24,6 +24,10 @@ #include #include #include "debug/tensor_data.h" +#include "ir/dtype.h" +#ifdef ENABLE_DUMP_E2E +#include "debug/e2e_dump.h" +#endif namespace mindspore { class TensorLoader { public: @@ -43,6 +47,9 @@ class TensorLoader { } tensor_list.push_back(tensor); tensor_list_map.insert({tensor->GetName(), tensor}); + auto node_name = tensor->GetName(); + node_name = node_name.substr(0, node_name.find_first_of(":")); + node_tensor_map.insert({node_name, tensor}); return true; } std::vector> GetTensor() { return tensor_list; } @@ -50,6 +57,17 @@ class TensorLoader { uint32_t GetIterNum() { return iter_num; } std::map> GetTensorMap() { return tensor_list_map; } + + std::vector> GetNodeTensorMap(std::string node_name) { + std::vector> tensors; + for (auto itr = node_tensor_map.begin(); itr != node_tensor_map.end(); itr++) { + if (itr->first == node_name) { + tensors.push_back(itr->second); + } + } + return tensors; + } + void SearchTensors(const std::vector &search_list, std::vector>> *result_list) { for (auto i : search_list) { @@ -66,17 +84,65 @@ class TensorLoader { void EmptyTensor() { std::lock_guard lg(lock_); prev_tensor_list_map.clear(); + node_tensor_map.clear(); tensor_list_map.swap(prev_tensor_list_map); tensor_list.clear(); } void EmptyPrevTensor() { prev_tensor_list_map.clear(); } + void EmptyCurrentTensor() { + tensor_list_map.clear(); + tensor_list.clear(); + } + void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; } +#ifdef ENABLE_DUMP_E2E + bool DumpTensorToFile(std::string tensor_name, bool trans_flag, const std::string &filepath, + const std::string &host_fmt, const std::vector &host_shape, TypeId host_type, + TypeId addr_type_id, std::string addr_format, size_t slot) const { + bool ret = false; + if (filepath.empty()) { + MS_LOG(ERROR) << "Dump file path is null!"; + return ret; + } + std::string shape = "shape"; + if (host_shape.size()) { + for (auto &value : host_shape) { + shape = shape + '_' + std::to_string(value); + } + } else { + shape = shape + "_0"; + } + std::string file_extension = ".bin"; + std::string path = ""; + if (trans_flag) { + path = filepath + '_' + shape + '_' + TypeIdLabel(host_type) + '_' + host_fmt + file_extension; + } else { + path = filepath + '_' + shape + '_' + TypeIdToType(addr_type_id)->ToString() + '_' + addr_format + file_extension; + } + + MS_LOG(INFO) << "Dump path is " << path; + + std::string tensor_loader_name = tensor_name + ":" + std::to_string(slot); + auto iter = tensor_list_map.find(tensor_loader_name); + if (iter != tensor_list_map.end()) { + std::shared_ptr node = iter->second; + mindspore::tensor::TensorPtr out_tensor = node->GetTensor(); + size_t host_size = out_tensor->data().nbytes(); + + ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size); + } + + return ret; + } +#endif + private: std::vector> tensor_list; std::map> tensor_list_map; + std::multimap> node_tensor_map; std::map> prev_tensor_list_map; uint32_t iter_num; std::mutex lock_; diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index b8d3f0a7c7..605d09ce7f 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -28,11 +28,12 @@ #include #include "ir/meta_func_graph.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "frontend/operator/composite/composite.h" #include "ir/tensor.h" #include "debug/anf_ir_utils.h" #include "pipeline/jit/static_analysis/evaluator.h" +#include "utils/log_adapter.h" namespace mindspore { // namespace to support debug trace infomation @@ -495,5 +496,17 @@ void ClearTraceStack() { } cnode_debug_stack.clear(); } + +// Register trace provider to LogWriter. +struct TraceProviderRegister { + TraceProviderRegister() { + LogWriter::set_trace_provider([](std::ostringstream &oss) { + TraceGraphEval(); + GetEvalStackInfo(oss); + }); + } + ~TraceProviderRegister() = default; +} trace_provider_regsiter; + } // namespace trace } // namespace mindspore diff --git a/mindspore/ccsrc/debug/trace.h b/mindspore/ccsrc/debug/trace.h index 7cf45abe30..19add25876 100644 --- a/mindspore/ccsrc/debug/trace.h +++ b/mindspore/ccsrc/debug/trace.h @@ -23,8 +23,8 @@ #include #include -#include "debug/trace_base.h" -#include "debug/info.h" +#include "utils/trace_base.h" +#include "utils/info.h" #include "ir/anf.h" #include "ir/func_graph.h" #include "pipeline/jit/static_analysis/static_analysis.h" diff --git a/mindspore/ccsrc/debug/trace_base.cc b/mindspore/ccsrc/debug/trace_base.cc deleted file mode 100644 index 6cd41d7f2d..0000000000 --- a/mindspore/ccsrc/debug/trace_base.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "debug/trace_base.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "utils/graph_utils.h" - -namespace mindspore { -// namespace to support debug trace infomation -namespace trace { -std::vector GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { - std::vector debug_with_loc_vec; - while (debug_info != nullptr) { - if (debug_info->location() != nullptr) { - debug_with_loc_vec.push_back(debug_info); - } - if (debug_info->trace_info() != nullptr) { - debug_info = debug_info->trace_info()->debug_info(); - } else { - break; - } - } - return debug_with_loc_vec; -} - -DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { - auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); - if (debug_with_loc_vec.size() > 0) { - return debug_with_loc_vec[0]; - } else { - return info; - } -} - -std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { - if (info == nullptr) { - return ""; - } - auto src_info = GetSourceCodeDebugInfo(info); - if (src_info->location() != nullptr) { - return src_info->location()->ToString(tip); - } - return ""; -} - -// a trace info identifies a node transform, so we can trace the node transform through -// a link of trace info and debug info -std::string GetInfoWithAction(const std::vector &info_vec, SourceLineTip tip) { - if (info_vec.size() < 1) { - return ""; - } - if (info_vec.size() == 1) { - return info_vec[0]->location()->ToString(tip); - } - std::string traced_info = info_vec[0]->location()->ToString(tip); - for (size_t i = 1; i < info_vec.size(); i++) { - auto action_name = info_vec[i - 1]->trace_info()->GetActionBetweenNode(info_vec[i]); - if (action_name == "") { - break; - } - traced_info = traced_info + action_name + info_vec[i]->location()->ToString(tip); - } - return traced_info; -} - -std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { - if (info == nullptr) { - return ""; - } - auto info_vec = GetSourceCodeDebugInfoVec(info); - if (info_vec.size() == 0) { - return ""; - } else if (info_vec.size() == 1) { - return info_vec[0]->location()->ToString(tip); - } else if (info_vec.size() > 1) { - return GetInfoWithAction(info_vec, tip); - } - return ""; -} - -std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { - std::ostringstream oss; - if (info == nullptr) { - return ""; - } - - auto debug_info = GetTracedDebugInfo(info, tip); - if (tip == kSourceLineTipDiscard) { - std::replace(debug_info.begin(), debug_info.end(), '\r', '/'); - std::replace(debug_info.begin(), debug_info.end(), '\n', '/'); - } - oss << prefix << debug_info; - return oss.str(); -} -} // namespace trace -} // namespace mindspore diff --git a/mindspore/ccsrc/debug/trace_base.h b/mindspore/ccsrc/debug/trace_base.h deleted file mode 100644 index 774931cc09..0000000000 --- a/mindspore/ccsrc/debug/trace_base.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEBUG_TRACE_BASE_H_ -#define MINDSPORE_CCSRC_DEBUG_TRACE_BASE_H_ - -#include -#include -#include -#include -#include -#include - -#include "debug/info.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "utils/any.h" - -namespace mindspore { -namespace trace { -std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); -std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, - SourceLineTip tip = kSourceLineTipNextLine); -} // namespace trace -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEBUG_TRACE_BASE_H_ diff --git a/mindspore/ccsrc/debug/trace_info.cc b/mindspore/ccsrc/debug/trace_info.cc deleted file mode 100644 index 048bf2bdf0..0000000000 --- a/mindspore/ccsrc/debug/trace_info.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "debug/trace_info.h" -#include -#include -#include -#include "ir/anf.h" - -namespace mindspore { -std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { - if (info == nullptr) { - return ""; - } - std::string act_name = action_name(); - if (debug_info() == nullptr) { - MS_LOG(EXCEPTION) << "Traced debug info is null"; - } - if (debug_info() == info) { - return act_name; - } else if (debug_info()->trace_info() != nullptr) { - return act_name + debug_info()->trace_info()->GetActionBetweenNode(info); - } - return "Not in the traced info"; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h deleted file mode 100644 index 62908cb449..0000000000 --- a/mindspore/ccsrc/debug/trace_info.h +++ /dev/null @@ -1,417 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEBUG_TRACE_INFO_H_ -#define MINDSPORE_CCSRC_DEBUG_TRACE_INFO_H_ - -#include -#include -#include -#include -#include -#include - -#include "base/base.h" - -namespace mindspore { -class TraceInfo; -using TraceInfoPtr = std::shared_ptr; -class Location; -using LocationPtr = std::shared_ptr; -class DebugInfo; -using DebugInfoPtr = std::shared_ptr; - -// namespace to support intermediate representation definition -class TraceInfo : public Base { - public: - TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { - symbol_ = symbol; - full_name_ = full_name; - name_ = full_name_; - debug_info_ = info; - } - TraceInfo(const TraceInfo &info) - : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} - ~TraceInfo() override = default; - MS_DECLARE_PARENT(TraceInfo, Base); - virtual std::string name() { return name_; } - virtual std::string symbol() { return symbol_; } - virtual std::string full_name() { return full_name_; } - virtual TraceInfoPtr clone() { return shared_from_base(); } - virtual std::string action_name() { return ""; } - virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); - void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } - DebugInfoPtr debug_info() { return debug_info_; } - DebugInfoPtr DebugInfoHasLoc(); - std::vector> GetSourceCodeDebugInfo(); - - protected: - DebugInfoPtr debug_info_; - std::string symbol_; - std::string full_name_; - std::string name_; -}; - -class TracePhi : public TraceInfo { - public: - explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} - MS_DECLARE_PARENT(TracePhi, TraceInfo); - ~TracePhi() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceIfStmtTrueBranch : public TraceInfo { - public: - TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; - explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} - MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); - ~TraceIfStmtTrueBranch() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceIfStmtFalseBranch : public TraceInfo { - public: - TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; - explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} - MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); - ~TraceIfStmtFalseBranch() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceIfStmtAfterBranch : public TraceInfo { - public: - explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} - MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); - ~TraceIfStmtAfterBranch() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceIfExpTrueBranch : public TraceInfo { - public: - explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} - MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); - ~TraceIfExpTrueBranch() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceIfExpFalseBranch : public TraceInfo { - public: - explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} - MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); - ~TraceIfExpFalseBranch() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceCopy : public TraceInfo { - public: - TraceCopy() : TraceInfo(nullptr, "copy", "") {} - explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} - MS_DECLARE_PARENT(TraceCopy, TraceInfo); - ~TraceCopy() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceIterator : public TraceInfo { - public: - explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} - MS_DECLARE_PARENT(TraceIterator, TraceInfo); - ~TraceIterator() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceWhileHeader : public TraceInfo { - public: - explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} - MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); - ~TraceWhileHeader() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceWhileBody : public TraceInfo { - public: - explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} - MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); - ~TraceWhileBody() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceWhileAfter : public TraceInfo { - public: - explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} - MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); - ~TraceWhileAfter() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceForHeader : public TraceInfo { - public: - explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {} - MS_DECLARE_PARENT(TraceForHeader, TraceInfo); - ~TraceForHeader() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceForBody : public TraceInfo { - public: - explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} - MS_DECLARE_PARENT(TraceForBody, TraceInfo); - ~TraceForBody() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceForAfter : public TraceInfo { - public: - explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} - MS_DECLARE_PARENT(TraceForAfter, TraceInfo); - ~TraceForAfter() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceLoopEnd : public TraceInfo { - public: - explicit TraceLoopEnd(const DebugInfoPtr &info) : TraceInfo(info, "loop_end", "↓↓") {} - MS_DECLARE_PARENT(TraceLoopEnd, TraceInfo); - ~TraceLoopEnd() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceEquiv : public TraceInfo { - public: - explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} - MS_DECLARE_PARENT(TraceEquiv, TraceInfo); - ~TraceEquiv() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceGradFpropApp : public TraceInfo { - public: - TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} - explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} - MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); - ~TraceGradFpropApp() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceGradBpropApp : public TraceInfo { - public: - TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} - explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} - MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); - ~TraceGradBpropApp() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceGradFprop : public TraceInfo { - public: - TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} - explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} - MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); - ~TraceGradFprop() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceGradBprop : public TraceInfo { - public: - TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} - explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} - MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); - ~TraceGradBprop() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceGradSens : public TraceInfo { - public: - TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} - explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {} - MS_DECLARE_PARENT(TraceGradSens, TraceInfo); - ~TraceGradSens() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceSpecialize : public TraceInfo { - public: - explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } - MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); - std::string name() override { return full_name_ + counter_; } - std::string symbol() override { return counter_ + "_"; } - std::string full_name() override { return full_name_ + counter_ + "_"; } - ~TraceSpecialize() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } - std::string counter_; -}; - -class TraceGradOperation : public TraceInfo { - public: - explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} - MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); - ~TraceGradOperation() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceForceBool : public TraceInfo { - public: - explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} - MS_DECLARE_PARENT(TraceForceBool, TraceInfo); - ~TraceForceBool() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceForceWhileCond : public TraceInfo { - public: - explicit TraceForceWhileCond(const DebugInfoPtr &info) : TraceInfo(info, "force_while_cond", "") {} - MS_DECLARE_PARENT(TraceForceWhileCond, TraceInfo); - ~TraceForceWhileCond() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceExpandJ : public TraceInfo { - public: - explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} - MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); - ~TraceExpandJ() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceGenMetaFuncGraph : public TraceInfo { - public: - explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} - MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); - ~TraceGenMetaFuncGraph() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceEvaluatorGenGraph : public TraceInfo { - public: - explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} - MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); - ~TraceEvaluatorGenGraph() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceResolve : public TraceInfo { - public: - explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} - MS_DECLARE_PARENT(TraceResolve, TraceInfo); - ~TraceResolve() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceTransform : public TraceInfo { - public: - TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } - explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { - transform_name_ = transform_name; - } - - std::string full_name() override { return full_name_ + transform_name_; } - MS_DECLARE_PARENT(TraceTransform, TraceInfo); - std::string symbol() override { - if (transform_name_.empty()) { - return ""; - } - return transform_name_ + "_"; - } - - ~TraceTransform() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } - std::string transform_name_; -}; - -class TraceGenerateVarArg : public TraceInfo { - public: - explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} - MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); - ~TraceGenerateVarArg() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceGenerateKwArg : public TraceInfo { - public: - explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} - MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); - ~TraceGenerateKwArg() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceTrasformK : public TraceInfo { - public: - explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} - MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); - ~TraceTrasformK() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TracePartialTransform : public TraceInfo { - public: - explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} - MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); - ~TracePartialTransform() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; - -class TraceGetEnv : public TraceInfo { - public: - explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} - MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); - ~TraceGetEnv() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceDoSignature : public TraceInfo { - public: - explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} - MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); - ~TraceDoSignature() override = default; - TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } -}; - -class TraceCombileLikeGraphs : public TraceInfo { - public: - TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} - explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {} - MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); - ~TraceCombileLikeGraphs() override = default; - TraceInfoPtr clone() override { - return std::make_shared(*shared_from_base()); - } -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEBUG_TRACE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.cc b/mindspore/ccsrc/frontend/operator/cc_implementations.cc index 3ec3455be7..2df68037e5 100644 --- a/mindspore/ccsrc/frontend/operator/cc_implementations.cc +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.cc @@ -23,7 +23,7 @@ #include "utils/misc.h" #include "utils/log_adapter.h" #include "utils/convert_utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { // namespace to support primitive operators definition diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.h b/mindspore/ccsrc/frontend/operator/cc_implementations.h index cef34da4f4..ffe75cb0c0 100644 --- a/mindspore/ccsrc/frontend/operator/cc_implementations.h +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_CC_IMPLEMENTATIONS_H_ -#define MINDSPORE_CCSRC_OPERATOR_CC_IMPLEMENTATIONS_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_CC_IMPLEMENTATIONS_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_CC_IMPLEMENTATIONS_H_ #include #include @@ -56,4 +56,4 @@ std::vector BroadcastShape_(std::vector s1, std::vector s2); } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_CC_IMPLEMENTATIONS_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_CC_IMPLEMENTATIONS_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 7d2573e50a..fbcb06629d 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -25,7 +25,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "abstract/dshape.h" #include "abstract/param_validator.h" #include "frontend/operator/cc_implementations.h" @@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg return fg; } +FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + int list_size = SizeToInt(args_spec_list.size()); + + std::ostringstream ss; + ss << "▶make_list_" << list_size; + FuncGraphPtr fg = std::make_shared(); + fg->debug_info()->set_name(ss.str()); + + std::vector params; + params.push_back(NewValueNode(prim::kPrimMakeList)); + for (int i = 0; i < list_size; ++i) { + params.push_back(fg->add_parameter()); + } + + // make fprob first result, maketuple's forward result. + AnfNodePtr out = fg->NewCNode(params); + + // make fprob second result, maketuple's backward function. + FuncGraphPtr b = std::make_shared(); + + ss.clear(); + ss << "◀make_list_" << list_size; + b->debug_info()->set_name(ss.str()); + AnfNodePtr dout = b->add_parameter(); + + std::vector grads; + grads.push_back(NewValueNode(prim::kPrimMakeTuple)); + grads.push_back(NewValueNode(newenv)); + for (int i = 0; i < list_size; ++i) { + grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)})); + } + + b->set_flag(FUNC_GRAPH_FLAG_CORE, true); + b->set_output(b->NewCNode(grads)); + + fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); + (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList)); + return fg; +} + GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { if (get_by_list) { @@ -803,6 +844,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li abstract::AbstractTuplePtr a_tuple = dyn_cast(abs_a); abstract::AbstractTuplePtr b_tuple = dyn_cast(abs_b); if (a_tuple == nullptr || b_tuple == nullptr) { + TypePtrList types; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), + [](const AbstractBasePtr &arg) -> TypePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->BuildType(); + }); + auto stub = GenerateStubFunc(types); + if (stub != nullptr) { + MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd " + << ", function: " << stub->ToString(); + return stub; + } MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", " << args_spec_list[1]->ToString(); } diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h index 3821192dba..21a4588958 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.h +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ #include #include @@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph { }; using MakeTupleGradientPtr = std::shared_ptr; +class MakeListGradient : public MetaFuncGraph { + public: + explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {} + ~MakeListGradient() override = default; + MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; } +}; +using MakeListGradientPtr = std::shared_ptr; + class GradOperation : public MetaFuncGraph { public: explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, @@ -189,4 +199,4 @@ using TupleGetItemTensorPtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.h b/mindspore/ccsrc/frontend/operator/composite/do_signature.h index 9139be806a..33ef78e27c 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.h +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ #include #include @@ -30,7 +30,7 @@ #include "utils/any.h" #include "ir/dtype.h" #include "ir/meta_func_graph.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { // namespace to support composite operators definition @@ -66,4 +66,4 @@ AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/list_append_operation.h b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.h index 1da3f9a009..9555010d01 100644 --- a/mindspore/ccsrc/frontend/operator/composite/list_append_operation.h +++ b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_LIST_APPEND_OPERATION_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_LIST_APPEND_OPERATION_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_LIST_APPEND_OPERATION_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_LIST_APPEND_OPERATION_H_ #include #include @@ -42,4 +42,4 @@ using ListAppendPtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_LIST_APPEND_OPERATION_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_LIST_APPEND_OPERATION_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/map.cc b/mindspore/ccsrc/frontend/operator/composite/map.cc index a5f674187b..f49c19aa9c 100644 --- a/mindspore/ccsrc/frontend/operator/composite/map.cc +++ b/mindspore/ccsrc/frontend/operator/composite/map.cc @@ -23,7 +23,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "abstract/dshape.h" #include "pybind_api/api_register.h" #include "debug/trace.h" diff --git a/mindspore/ccsrc/frontend/operator/composite/map.h b/mindspore/ccsrc/frontend/operator/composite/map.h index 428014f9c4..da8edcef43 100644 --- a/mindspore/ccsrc/frontend/operator/composite/map.h +++ b/mindspore/ccsrc/frontend/operator/composite/map.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ #include #include @@ -95,4 +95,4 @@ using MapPyPtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc index ba0d3d9ebb..e9418bb5db 100644 --- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -25,12 +25,12 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "abstract/dshape.h" #include "abstract/param_validator.h" #include "frontend/operator/cc_implementations.h" #include "frontend/optimizer/opt.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/symbolic.h" #include "pybind_api/api_register.h" #include "./common.h" @@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { return py::none(); } -FuncGraphPtr GenerateStubFunc(const TypePtrList &types) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); - if (!enable_sparse) { - return nullptr; - } - - std::vector parameters; - ParameterPtr undetermined_param = nullptr; - auto stub = std::make_shared(); - for (size_t i = 0; i < types.size(); ++i) { - auto param = stub->add_parameter(); - parameters.push_back(param); - if (types[i]->type_id() == kObjectTypeUndeterminedType) { - undetermined_param = param; - } - } - if (undetermined_param != nullptr) { - std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; - for (size_t i = 0; i < types.size(); ++i) { - if (types[i]->type_id() == kObjectTypeFunction) { - std::vector call_prim{parameters[i], undetermined_param}; - inputs.push_back(stub->NewCNode(call_prim)); - } else { - inputs.push_back(parameters[i]); - } - } - auto stub_output = stub->NewCNode(inputs); - stub->set_output(stub_output); - stub->set_stub(true); - return stub; - } - return nullptr; -} - FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { auto py_fn = SignMatch(types); std::ostringstream buffer; diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h index 2139a0e9d1..9bcfdb2ee2 100644 --- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ #include #include @@ -62,4 +62,4 @@ using MultitypeFuncGraphPtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.h b/mindspore/ccsrc/frontend/operator/composite/unpack_call.h index 79c2600f36..5a5016f93b 100644 --- a/mindspore/ccsrc/frontend/operator/composite/unpack_call.h +++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_UNPACK_CALL_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_UNPACK_CALL_H_ #include #include @@ -30,7 +30,7 @@ #include "utils/any.h" #include "ir/dtype.h" #include "ir/meta_func_graph.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { // namespace to support composite operators definition @@ -49,4 +49,4 @@ using UnpackCallPtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_UNPACK_CALL_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/zip_operation.h b/mindspore/ccsrc/frontend/operator/composite/zip_operation.h index 96697cb472..6c142e930f 100644 --- a/mindspore/ccsrc/frontend/operator/composite/zip_operation.h +++ b/mindspore/ccsrc/frontend/operator/composite/zip_operation.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ #include #include @@ -56,4 +56,4 @@ using ZipOperationPtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ diff --git a/mindspore/ccsrc/frontend/operator/ops.cc b/mindspore/ccsrc/frontend/operator/ops.cc deleted file mode 100755 index 5c7672ee3c..0000000000 --- a/mindspore/ccsrc/frontend/operator/ops.cc +++ /dev/null @@ -1,288 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "frontend/operator/ops.h" -#include -#include - -namespace mindspore { -// namespace to support primitive operators -namespace prim { -// Arithmetic -const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); -const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); -const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); -const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); -const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); -const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); -const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); -const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); -const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); -const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); -const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); -const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); -const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); -const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); -const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); -const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); - -// Comparisons -const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); -const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); -const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); -const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); -const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); -const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); -const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); -const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); -const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); -const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); -const PrimitivePtr kPrimGreater = std::make_shared("Greater"); -const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); -const PrimitivePtr kPrimLess = std::make_shared("Less"); -const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); -const PrimitivePtr kPrimEqual = std::make_shared("Equal"); -const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); - -// Type introspection -const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); -const PrimitivePtr kPrimHasType = std::make_shared("hastype"); - -// Statements -const PrimitivePtr kPrimSwitch = std::make_shared("switch"); -const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); -const PrimitivePtr kPrimReturn = std::make_shared("return"); -const PrimitivePtr kPrimAssign = std::make_shared("Assign"); -const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); -const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); -const PrimitivePtr kPrimSelect = std::make_shared("Select"); -const PrimitivePtr kPrimCall = std::make_shared("call"); - -const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); -const PrimitivePtr kPrimDot = std::make_shared("dot"); -const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); -const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); -const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); -const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); - -const PrimitivePtr kPrimResolve = std::make_shared("resolve"); -const PrimitivePtr kPrimEmbed = std::make_shared("embed"); -const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); -const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); - -const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); -const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); -const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); - -// Structure -const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); -const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); -const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); -const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); -const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); -const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); -const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); -const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); -const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); -const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); -const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); -const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); -const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); -const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); -const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); -const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); -const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); -const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); -const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); -const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); -const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); -const PrimitivePtr kPrimListLen = std::make_shared("list_len"); -const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); -const PrimitivePtr kPrimListMap = std::make_shared("list_map"); -const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); -const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); - -const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); -const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); -const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); -const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); -const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); -const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); -const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); -const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); -const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); -const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); -const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); - -// Arrays -const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); -const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); -const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); -const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); -const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); -const PrimitivePtr kPrimShape = std::make_shared("Shape"); -const PrimitivePtr kPrimCast = std::make_shared("Cast"); -const PrimitivePtr kPrimConcat = std::make_shared("Concat"); -const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); -const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); -const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); -const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); -const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); -const PrimitivePtr kPrimSize = std::make_shared("Size"); -const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); -const PrimitivePtr kPrimPack = std::make_shared("Pack"); -const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); -const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); -const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); -const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); -const PrimitivePtr kPrimTile = std::make_shared("Tile"); -const PrimitivePtr kPrimAddN = std::make_shared("AddN"); -const PrimitivePtr KPrimTransData = std::make_shared("TransData"); -const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); -const PrimitivePtr kPrimPad = std::make_shared("Pad"); -const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); - -// Maths -const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); -const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); -const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); -const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); -const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); -const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); -const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); -const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); -const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); -const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); -const PrimitivePtr kPrimNeg = std::make_shared("Neg"); -const PrimitivePtr kPrimSub = std::make_shared("Sub"); -const PrimitivePtr kPrimMul = std::make_shared("Mul"); -const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); -const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); -const PrimitivePtr kPrimSquare = std::make_shared("Square"); -const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); -const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); -const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); -const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); -const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); -const PrimitivePtr kPrimPow = std::make_shared("Pow"); -const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); -const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); -const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); -const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); - -// NN -const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); -const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); -const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); -const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); -const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); -const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); -const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); -const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); -const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); -const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); -const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); -const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); -const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); -const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); -const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); -const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); -const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); -const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); -const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); -const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = - std::make_shared("DepthwiseConv2dNativeBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = - std::make_shared("DepthwiseConv2dNativeBackpropInput"); -const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); -const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared("SoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = - std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); -const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); -const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); -const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); -const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); -const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); -const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); -const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); -const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); -const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); -const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); -const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); -const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); -const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); -const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); -const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); -const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); -const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); -const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); - -// Other miscellaneous -const PrimitivePtr kPrimIdentity = std::make_shared("identity"); -const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -const PrimitivePtr kPrimJ = std::make_shared("J"); -const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); -const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); -const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); -const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); -const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); -const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); -const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); -const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); -const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); -const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); -const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); -const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); -const PrimitivePtr kPrimPrint = std::make_shared("Print"); - -const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); -const PrimitivePtr kPrimDepend = std::make_shared("Depend"); -const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); - -const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); -const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); -const PrimitivePtr kPrimIs_ = std::make_shared("is_"); -const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); -const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); -const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); -const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); -const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); -const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); - -// Comm ops -const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); -const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); -const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); -const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); - -// Debug ops -const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); -const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); -const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); -const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -const PrimitivePtr kPrimDebug = std::make_shared("Debug"); - -// IndexedSlices -const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared("MakeIndexedSlices"); -const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared("IndexedSlicesGetValues"); -const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared("IndexedSlicesGetIndices"); -const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared("IndexedSlicesGetDenseShape"); -const PrimitivePtr kPrimIsIndexedSlices = std::make_shared("IsIndexedSlices"); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h index 0dea045a6e..85b0ac474f 100755 --- a/mindspore/ccsrc/frontend/operator/ops.h +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -14,14 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_OPS_H_ -#define MINDSPORE_CCSRC_OPERATOR_OPS_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_OPS_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_OPS_H_ #include #include #include #include "ir/anf.h" #include "ir/primitive.h" +#include "base/core_ops.h" namespace mindspore { // namespace to support primitive operators @@ -31,290 +32,159 @@ ValuePtr GetPythonOps(const std::string &op_name, bool use_signature = false); // Arithmetic -extern const PrimitivePtr kPrimScalarAdd; -extern const PrimitivePtr kPrimScalarSub; -extern const PrimitivePtr kPrimScalarMul; -extern const PrimitivePtr kPrimScalarDiv; -extern const PrimitivePtr kPrimScalarFloordiv; -extern const PrimitivePtr kPrimScalarMod; -extern const PrimitivePtr kPrimScalarPow; -extern const PrimitivePtr kPrimScalarTrunc; -extern const PrimitivePtr kPrimScalarFloor; -extern const PrimitivePtr kPrimScalarUadd; -extern const PrimitivePtr kPrimScalarUsub; -extern const PrimitivePtr kPrimScalarExp; -extern const PrimitivePtr kPrimScalarLog; -extern const PrimitivePtr kPrimScalarSin; -extern const PrimitivePtr kPrimScalarCos; -extern const PrimitivePtr kPrimScalarTan; +inline const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); +inline const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); +inline const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); +inline const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); +inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); +inline const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); +inline const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); +inline const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); +inline const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); +inline const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); +inline const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); +inline const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); +inline const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); +inline const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); +inline const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); +inline const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); // Comparisons -extern const PrimitivePtr kPrimScalarEq; -extern const PrimitivePtr kPrimScalarLt; -extern const PrimitivePtr kPrimScalarGt; -extern const PrimitivePtr kPrimScalarNe; -extern const PrimitivePtr kPrimScalarLe; -extern const PrimitivePtr kPrimScalarGe; -extern const PrimitivePtr kPrimBoolNot; -extern const PrimitivePtr kPrimBoolAnd; -extern const PrimitivePtr kPrimBoolOr; -extern const PrimitivePtr kPrimBoolEq; -extern const PrimitivePtr kPrimGreater; -extern const PrimitivePtr kPrimGreaterEqual; -extern const PrimitivePtr kPrimLess; -extern const PrimitivePtr kPrimLessEqual; -extern const PrimitivePtr kPrimEqual; -extern const PrimitivePtr kPrimNotEqual; +inline const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); +inline const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); +inline const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); +inline const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); +inline const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); +inline const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); +inline const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); +inline const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); +inline const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); +inline const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); +inline const PrimitivePtr kPrimGreater = std::make_shared("Greater"); +inline const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); +inline const PrimitivePtr kPrimLess = std::make_shared("Less"); +inline const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); +inline const PrimitivePtr kPrimEqual = std::make_shared("Equal"); +inline const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); // Type introspection -extern const PrimitivePtr kPrimTypeOf; -extern const PrimitivePtr kPrimHasType; +inline const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); +inline const PrimitivePtr kPrimHasType = std::make_shared("hastype"); -// Statements -extern const PrimitivePtr kPrimSwitch; -extern const PrimitivePtr kPrimSwitchLayer; -extern const PrimitivePtr kPrimReturn; -extern const PrimitivePtr kPrimAssign; -extern const PrimitivePtr kPrimAssignAdd; -extern const PrimitivePtr kPrimAssignSub; -extern const PrimitivePtr kPrimSelect; -extern const PrimitivePtr kPrimCall; +inline const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); +inline const PrimitivePtr kPrimDot = std::make_shared("dot"); +inline const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); +inline const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); +inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); +inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); -extern const PrimitivePtr kPrimDistribute; -extern const PrimitivePtr kPrimDot; -extern const PrimitivePtr kPrimIm2Col; -extern const PrimitivePtr kPrimCol2Im; -extern const PrimitivePtr kPrimIm2ColV1; -extern const PrimitivePtr kPrimCol2ImV1; +inline const PrimitivePtr kPrimResolve = std::make_shared("resolve"); +inline const PrimitivePtr kPrimEmbed = std::make_shared("embed"); +inline const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); +inline const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); -extern const PrimitivePtr kPrimResolve; -extern const PrimitivePtr kPrimEmbed; -extern const PrimitivePtr kPrimRefToEmbed; -extern const PrimitivePtr kPrimCreateInstance; - -extern const PrimitivePtr kPrimLabelGoto; -extern const PrimitivePtr kPrimLabelSwitch; -extern const PrimitivePtr kPrimLabelSet; - -// Structure -extern const PrimitivePtr kPrimStringEqual; -extern const PrimitivePtr kPrimStringConcat; -extern const PrimitivePtr kPrimMakeTuple; -extern const PrimitivePtr kPrimMakeList; -extern const PrimitivePtr kPrimMakeDict; -extern const PrimitivePtr kPrimMakeKeywordArg; -extern const PrimitivePtr kPrimExtractKeywordArg; -extern const PrimitivePtr kPrimMakeSlice; -extern const PrimitivePtr kPrimMakeRecord; -extern const PrimitivePtr kPrimTupleGetItem; -extern const PrimitivePtr kPrimListGetItem; -extern const PrimitivePtr kPrimArrayGetItem; -extern const PrimitivePtr kPrimTupleSetItem; -extern const PrimitivePtr kPrimListSetItem; -extern const PrimitivePtr kPrimArraySetItem; -extern const PrimitivePtr kPrimDictGetItem; -extern const PrimitivePtr kPrimDictSetItem; -extern const PrimitivePtr kPrimListAppend; -extern const PrimitivePtr kPrimGetAttr; -extern const PrimitivePtr kPrimTupleLen; -extern const PrimitivePtr kPrimDictLen; -extern const PrimitivePtr kPrimListLen; -extern const PrimitivePtr kPrimArrayLen; -extern const PrimitivePtr kPrimListMap; -extern const PrimitivePtr kPrimListReduce; -extern const PrimitivePtr kPrimTupleReversed; -extern const PrimitivePtr kPrimTileShape; -extern const PrimitivePtr kPrimReducedShape; -extern const PrimitivePtr kPrimTupleDiv; -extern const PrimitivePtr kPrimTupleToArray; -extern const PrimitivePtr kPrimShapeMul; -extern const PrimitivePtr kPrimGenerateShapeIndex; -extern const PrimitivePtr kPrimGenerateInverseIndex; -extern const PrimitivePtr kPrimTupleEqual; -extern const PrimitivePtr kPrimListEqual; -extern const PrimitivePtr kPrimMakeRange; -extern const PrimitivePtr kPrimStopGradient; +inline const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); +inline const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); +inline const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); // Arrays -extern const PrimitivePtr kPrimScalarToArray; -extern const PrimitivePtr kPrimArrayToScalar; -extern const PrimitivePtr kPrimBroadcastShape; -extern const PrimitivePtr kPrimArrayMap; -extern const PrimitivePtr kPrimArrayReduce; -extern const PrimitivePtr kPrimShape; -extern const PrimitivePtr kPrimCast; -extern const PrimitivePtr kPrimConcat; -extern const PrimitivePtr kPrimSqueeze; -extern const PrimitivePtr kPrimTranspose; -extern const PrimitivePtr kPrimGatherV2; -extern const PrimitivePtr kPrimEmbeddingLookup; -extern const PrimitivePtr kPrimEmbeddingLookupCommGrad; -extern const PrimitivePtr kPrimSize; -extern const PrimitivePtr kPrimArgMax; -extern const PrimitivePtr kPrimPack; -extern const PrimitivePtr kPrimUnpack; -extern const PrimitivePtr kPrimUnsortedSegmentMin; -extern const PrimitivePtr kPrimUnsortedSegmentSum; -extern const PrimitivePtr kPrimConcatOffset; -extern const PrimitivePtr kPrimReshape; -extern const PrimitivePtr kPrimTile; -extern const PrimitivePtr kPrimAddN; -extern const PrimitivePtr KPrimTransData; -extern const PrimitivePtr kPrimNMSWithMask; -extern const PrimitivePtr kPrimPad; -extern const PrimitivePtr kPrimArgMaxWithValue; -extern const PrimitivePtr kPrimRealDiv; -extern const PrimitivePtr kPrimSqrt; -extern const PrimitivePtr kPrimReciprocal; -extern const PrimitivePtr kPrimExpandDims; - -// Maths -extern const PrimitivePtr kPrimTensorAdd; -extern const PrimitivePtr kPrimMatMul; -extern const PrimitivePtr kPrimBatchMatMul; -extern const PrimitivePtr kPrimMaximumGrad; -extern const PrimitivePtr kPrimMinimumGrad; -extern const PrimitivePtr kPrimReduceMean; -extern const PrimitivePtr kPrimReduceSum; -extern const PrimitivePtr kPrimReduceAll; -extern const PrimitivePtr kPrimReduceMax; -extern const PrimitivePtr kPrimReduceMin; -extern const PrimitivePtr kPrimNeg; -extern const PrimitivePtr kPrimSub; -extern const PrimitivePtr kPrimMul; -extern const PrimitivePtr kPrimRealDiv; -extern const PrimitivePtr kPrimMinimum; -extern const PrimitivePtr kPrimMaximum; -extern const PrimitivePtr kPrimSquare; -extern const PrimitivePtr kPrimSqrt; -extern const PrimitivePtr kPrimEqual; -extern const PrimitivePtr kPrimLess; -extern const PrimitivePtr kPrimLessEqual; -extern const PrimitivePtr kPrimCumSum; -extern const PrimitivePtr kPrimCumProd; -extern const PrimitivePtr kPrimSubscalar; -extern const PrimitivePtr kPrimInplaceAdd; -extern const PrimitivePtr kPrimInplaceSub; -extern const PrimitivePtr kPrimPow; +inline const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); +inline const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); +inline const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); +inline const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); +inline const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); +inline const PrimitivePtr kPrimCast = std::make_shared("Cast"); +inline const PrimitivePtr kPrimConcat = std::make_shared("Concat"); +inline const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); +inline const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); +inline const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); +inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); +inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); +inline const PrimitivePtr kPrimSize = std::make_shared("Size"); +inline const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); +inline const PrimitivePtr kPrimPack = std::make_shared("Pack"); +inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); +inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); +inline const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); +inline const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); +inline const PrimitivePtr kPrimTile = std::make_shared("Tile"); +inline const PrimitivePtr kPrimAddN = std::make_shared("AddN"); +inline const PrimitivePtr KPrimTransData = std::make_shared("TransData"); +inline const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); +inline const PrimitivePtr kPrimPad = std::make_shared("Pad"); +inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); +inline const PrimitivePtr kPrimUnique = std::make_shared("Unique"); +inline const PrimitivePtr kPrimUniqueGrad = std::make_shared("UniqueGrad"); // NN -extern const PrimitivePtr kPrimFlatten; -extern const PrimitivePtr kPrimSoftmax; -extern const PrimitivePtr kPrimLogSoftmax; -extern const PrimitivePtr kPrimLogSoftmaxGrad; -extern const PrimitivePtr kPrimApplyCenteredRMSProp; -extern const PrimitivePtr kPrimTanh; -extern const PrimitivePtr kPrimTanhGrad; -extern const PrimitivePtr kPrimPooling; -extern const PrimitivePtr kPrimPoolingGrad; -extern const PrimitivePtr kPrimFusedBatchNorm; -extern const PrimitivePtr kPrimBatchNorm; -extern const PrimitivePtr kPrimBatchNormGrad; -extern const PrimitivePtr kPrimConv2D; -extern const PrimitivePtr kPrimMaxPool; -extern const PrimitivePtr kPrimMaxPoolGrad; -extern const PrimitivePtr kPrimAvgPoolGrad; -extern const PrimitivePtr kPrimFusedBatchNormGrad; -extern const PrimitivePtr kPrimReluGrad; -extern const PrimitivePtr kPrimConv2DBackpropInput; -extern const PrimitivePtr kPrimConv2DBackpropFilter; -extern const PrimitivePtr kPrimDepthwiseConv2dNative; -extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter; -extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput; - -extern const PrimitivePtr kPrimBiasAddGrad; -extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits; -extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits; -extern const PrimitivePtr kPrimMomentum; -extern const PrimitivePtr kPrimApplyMomentum; -extern const PrimitivePtr kPrimLayerNorm; -extern const PrimitivePtr kPrimLayerNormGrad; -extern const PrimitivePtr kPrimLayerNormXBackprop; -extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; -extern const PrimitivePtr kPrimDropoutGenMask; -extern const PrimitivePtr kPrimDropoutDoMask; -extern const PrimitivePtr kPrimOneHot; -extern const PrimitivePtr kPrimGelu; -extern const PrimitivePtr kPrimGeluGrad; -extern const PrimitivePtr kPrimRelu; -extern const PrimitivePtr kPrimReluV2; -extern const PrimitivePtr kPrimActivation; -extern const PrimitivePtr kPrimZerosLike; -extern const PrimitivePtr kPrimFakeBprop; -extern const PrimitivePtr kPrimBpropCut; -extern const PrimitivePtr kPrimFakeQuantPerLayer; -extern const PrimitivePtr kPrimFakeQuantPerChannel; -extern const PrimitivePtr kPrimApplyRMSProp; - -// Other Miscellaneous -extern const PrimitivePtr kPrimIdentity; -extern const PrimitivePtr kPrimPartial; -extern const PrimitivePtr kPrimJ; -extern const PrimitivePtr kPrimEnvSetItem; -extern const PrimitivePtr kPrimEnvGetItem; -extern const PrimitivePtr kPrimEnvAdd; -extern const PrimitivePtr kPrimMakeRefKey; -extern const PrimitivePtr kPrimMakeRef; -extern const PrimitivePtr kPrimGetRefKey; -extern const PrimitivePtr kPrimGetRefValue; -extern const PrimitivePtr kPrimGetRefOrigin; -extern const PrimitivePtr kPrimInsertGradientOf; -extern const PrimitivePtr kPrimHookBackward; -extern const PrimitivePtr kPrimPrintShapeType; -extern const PrimitivePtr kPrimPrint; -extern const PrimitivePtr kPrimSameTypeShape; -extern const PrimitivePtr kPrimCheckBprop; -extern const PrimitivePtr kPrimDepend; -extern const PrimitivePtr kPrimStateSetItem; -extern const PrimitivePtr kPrimScalarSummary; -extern const PrimitivePtr kPrimImageSummary; -extern const PrimitivePtr kPrimTensorSummary; -extern const PrimitivePtr kPrimHistogramSummary; -extern const PrimitivePtr kPrimBroadcastGradientArgs; -extern const PrimitivePtr kPrimControlDepend; -extern const PrimitivePtr kPrimIs_; -extern const PrimitivePtr kPrimIsNot; -extern const PrimitivePtr kPrimInDict; -extern const PrimitivePtr kPrimNotInDict; -extern const PrimitivePtr kPrimMixedPrecisionCast; -extern const PrimitivePtr kPrimIsConsant; -extern const PrimitivePtr kPrimEquivFormat; -extern const PrimitivePtr kPrimDebug; +inline const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); +inline const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); +inline const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); +inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); +inline const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); +inline const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); +inline const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); +inline const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); +inline const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); +inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); +inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); +inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); +inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared("AvgPoolGradVm"); +inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); +inline const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); +inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); +inline const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); +inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); +inline const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); +inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); +inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); +inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); +inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = + std::make_shared("DepthwiseConv2dNativeBackpropFilter"); +inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = + std::make_shared("DepthwiseConv2dNativeBackpropInput"); +inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); +inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = + std::make_shared("SoftmaxCrossEntropyWithLogits"); +inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = + std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); +inline const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); +inline const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); +inline const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); +inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); +inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); +inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); +inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); +inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); +inline const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); +inline const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); +inline const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); +inline const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); +inline const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); +inline const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); +inline const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); +inline const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); +inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); +inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); +inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); // Comm ops -extern const PrimitivePtr kPrimAllReduce; -extern const PrimitivePtr kPrimMirror; -extern const PrimitivePtr kPrimVirtualDiv; -extern const PrimitivePtr kPrimVirtualDataset; - -// IndexedSlices -extern const PrimitivePtr kPrimMakeIndexedSlices; -extern const PrimitivePtr kPrimIndexedSlicesGetValues; -extern const PrimitivePtr kPrimIndexedSlicesGetIndices; -extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; -extern const PrimitivePtr kPrimIsIndexedSlices; - -// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll -const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; -// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it -// will be sunk(i.e. not unrolled) -const int MAX_FOR_LOOP_COUNT = 600; - -class DoSignaturePrimitive : public Primitive { - public: - explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) - : Primitive("S-Prim-" + name), function_(function) {} - - ~DoSignaturePrimitive() override = default; - - MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) - - const ValuePtr function() const { return function_; } - - private: - ValuePtr function_; -}; -using DoSignaturePrimitivePtr = std::shared_ptr; +inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); +inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); +inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); +inline const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); + +// RowTensor +inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared("MakeRowTensor"); +inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared("RowTensorGetValues"); +inline const PrimitivePtr kPrimRowTensorGetIndices = std::make_shared("RowTensorGetIndices"); +inline const PrimitivePtr kPrimRowTensorGetDenseShape = std::make_shared("RowTensorGetDenseShape"); + +// SparseTensor +inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared("MakeSparseTensor"); +inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared("SparseTensorGetValues"); +inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared("SparseTensorGetIndices"); +inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared("SparseTensorGetDenseShape"); class UnpackGraphPrimitive : public Primitive { public: @@ -333,4 +203,4 @@ using UnpackGraphPrimitivePtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_OPS_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_OPS_H_ diff --git a/mindspore/ccsrc/frontend/operator/prim_arrays.cc b/mindspore/ccsrc/frontend/operator/prim_arrays.cc index caaf1d1b2a..ea0725ae6e 100644 --- a/mindspore/ccsrc/frontend/operator/prim_arrays.cc +++ b/mindspore/ccsrc/frontend/operator/prim_arrays.cc @@ -15,7 +15,6 @@ */ #include "pipeline/jit/static_analysis/prim.h" -#include "frontend/operator/ops.h" #include "abstract/utils.h" #include "frontend/operator/cc_implementations.h" #include "abstract/param_validator.h" @@ -80,23 +79,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti return std::make_shared(elems); } -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, 0); - MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString(); - - AbstractBasePtrList values; - auto shp = arg->shape(); - for (int entry : shp->shape()) { - auto entry_v = MakeValue(entry); - values.push_back(std::make_shared(entry_v, entry_v->type())); - } - return std::make_shared(values); -} - AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a tensor and a tuple. @@ -166,5 +148,47 @@ AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &pri ret->set_shape(std::make_shared(shape)); return ret; } + +AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // inputs: a 1-d Tensor + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); + + auto shape = input->shape(); + if (shape->shape().size() != 1) { + MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1."; + } + std::vector ids_shape = {Shape::SHP_ANY}; + std::vector min_shape = {1}; + std::vector max_shape = shape->shape(); + auto ids = + std::make_shared(input->element(), std::make_shared(ids_shape, min_shape, max_shape)); + auto ids_idx = std::make_shared(std::make_shared(32), shape->shape()); + // outputs: ids, ids_idx + AbstractBasePtrList elements = {ids, ids_idx}; + return std::make_shared(elements); +} + +AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // inputs: a 1-d Tensor + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr dout = CheckArg(op_name, args_spec_list, 0); + CheckArgsSize(op_name + " dout", dout->elements(), 2); + auto ids = CheckArg(op_name, dout->elements(), 0); + auto ids_idx = CheckArg(op_name, dout->elements(), 1); + if (ids->shape()->shape().size() != 1) { + MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "' input must be 1."; + } + if (ids_idx->shape()->shape().size() != 1) { + MS_LOG(EXCEPTION) << "Dims of dout[1] of " << op_name << "' input must be 1."; + } + + // outputs: dx + return std::make_shared(ids->element(), ids_idx->shape()); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_maths.cc b/mindspore/ccsrc/frontend/operator/prim_maths.cc index e4543a3821..5d06fb8603 100644 --- a/mindspore/ccsrc/frontend/operator/prim_maths.cc +++ b/mindspore/ccsrc/frontend/operator/prim_maths.cc @@ -18,7 +18,7 @@ #include "frontend/operator/ops.h" #include "abstract/utils.h" #include "abstract/param_validator.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace abstract { diff --git a/mindspore/ccsrc/frontend/operator/prim_nn.cc b/mindspore/ccsrc/frontend/operator/prim_nn.cc index 96c86d815d..67c23307e5 100644 --- a/mindspore/ccsrc/frontend/operator/prim_nn.cc +++ b/mindspore/ccsrc/frontend/operator/prim_nn.cc @@ -402,31 +402,36 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti for (std::size_t i = 0; i < x_shape->size(); ++i) { auto value_track = x_shape_data[i]->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); - if (!value_track->isa()) { - MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString() << "."; + int64_t e_value = 0; + if (value_track->isa()) { + e_value = GetValue(value_track); + } else if (value_track->isa()) { + e_value = static_cast(GetValue(value_track)); + } else { + MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int64 or int32, but " + << value_track->ToString() << "."; } - int e_value = GetValue(value_track); if (e_value <= 0) { MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0"; } - if (std::numeric_limits::max() / count / e_value < 1) { + if (std::numeric_limits::max() / count / e_value < 1) { MS_LOG(EXCEPTION) << "integer multiply integer overflow"; } count = count * e_value; } // convert to bytes(8 bits) mask, using round up - int n128s = count / 128; + int64_t n128s = count / 128; if ((count % 128) != 0) { n128s++; } - int bytes_count = n128s * 16; - std::vector shape_y{bytes_count}; + int64_t bytes_count = n128s * 16; + std::vector shape_y{bytes_count}; primitive->set_attr("T", kInt32); return std::make_shared(std::make_shared(kAnyValue, kUInt8), - std::make_shared(std::vector{shape_y})); + std::make_shared(std::vector{shape_y})); } } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc index 530ad6a10c..4b2a5be482 100644 --- a/mindspore/ccsrc/frontend/operator/prim_others.cc +++ b/mindspore/ccsrc/frontend/operator/prim_others.cc @@ -18,12 +18,12 @@ #include #include "ir/dtype.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/operator/ops.h" #include "abstract/param_validator.h" #include "pipeline/jit/static_analysis/prim.h" #include "abstract/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/symbolic.h" namespace mindspore { @@ -340,8 +340,8 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv return std::make_shared(kAnyValue, kBool); } -AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { +AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors and a tuple. const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 3); @@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim auto values = CheckArg(op_name, args_spec_list, 1); auto dense_shape = CheckArg(op_name, args_spec_list, 2); + auto indices_dtype = indices->element()->BuildType(); + if (!indices_dtype->isa()) { + MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString(); + } + auto indices_shp = indices->shape()->shape(); + if (indices_shp.size() != 1) { + MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size() + << " dimension tensor"; + } + auto values_shp = values->shape()->shape(); + if (indices_shp[0] != values_shp[0]) { + MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values " + << values_shp[0] << ", but got " << indices_shp[0]; + } + + for (auto elem_type : dense_shape->ElementsType()) { + if (!elem_type->isa()) { + MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString(); + } + } auto dense_shape_value = dense_shape->BuildValue()->cast(); MS_EXCEPTION_IF_NULL(dense_shape_value); auto shp = dense_shape_value->value(); @@ -358,53 +378,145 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim auto elem = GetValue(e); return elem; }); - auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); + if (dense_shape_vec.size() != values_shp.size()) { + MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values " + << values_shp.size() << ", but got " << dense_shape_value->size(); + } + for (size_t i = 0; i < dense_shape_vec.size(); i++) { + if (dense_shape_vec[i] < 0) { + MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got " + << dense_shape_vec[i]; + } + // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection + if (i != 0 && dense_shape_vec[i] != values_shp[i]) { + MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i + << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; + } + } + auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); ret->set_indices(indices); ret->set_values(values); ret->set_dense_shape(dense_shape); return ret; } -AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, +AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto row_tensor = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(row_tensor->values()); + return row_tensor->values(); +} + +AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto row_tensor = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(row_tensor->indices()); + return row_tensor->indices(); +} + +AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors and a tuple. const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->values()); - return indexed_slices->values(); + auto row_tensor = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(row_tensor->dense_shape()); + return row_tensor->dense_shape(); } -AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { +AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto indices = CheckArg(op_name, args_spec_list, 0); + auto values = CheckArg(op_name, args_spec_list, 1); + auto dense_shape = CheckArg(op_name, args_spec_list, 2); + + auto indices_dtype = indices->element()->BuildType(); + if (!indices_dtype->isa()) { + MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString(); + } + auto indices_shp = indices->shape()->shape(); + if (indices_shp.size() != 2) { + MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size() + << " dimension tensor"; + } + auto values_shp = values->shape()->shape(); + if (values_shp.size() != 1) { + MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size() + << " dimension tensor"; + } + if (indices_shp[0] != values_shp[0]) { + MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values " + << values_shp[0] << ", but got " << indices_shp[0]; + } + + for (auto elem_type : dense_shape->ElementsType()) { + if (!elem_type->isa()) { + MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString(); + } + } + auto dense_shape_value = dense_shape->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(dense_shape_value); + auto shp = dense_shape_value->value(); + std::vector dense_shape_vec; + (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), + [](const ValuePtr &e) -> int { + auto elem = GetValue(e); + return elem; + }); + if (IntToSize(indices_shp[1]) != dense_shape_vec.size()) { + MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal with the second dimension of indices " + << indices_shp[1] << ", but got " << dense_shape_vec.size(); + } + for (auto dense_shape_elem : dense_shape_vec) { + if (dense_shape_elem < 0) { + MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got " + << dense_shape_value->ToString(); + } + } + auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); + ret->set_indices(indices); + ret->set_values(values); + ret->set_dense_shape(dense_shape); + return ret; +} + +AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors and a tuple. const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->indices()); - return indexed_slices->indices(); + auto sparse_tensor = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(sparse_tensor->values()); + return sparse_tensor->values(); } -AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { +AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors and a tuple. const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape()); - return indexed_slices->dense_shape(); + auto sparse_tensor = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(sparse_tensor->indices()); + return sparse_tensor->indices(); } -AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { +AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 1); - bool ret = false; - if (args_spec_list[0]->isa()) { - ret = true; - } - MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString(); - return std::make_shared(ret); + auto sparse_tensor = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape()); + return sparse_tensor->dense_shape(); } } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_statement.cc b/mindspore/ccsrc/frontend/operator/prim_statement.cc index bb421bdf8a..6a7f54007b 100644 --- a/mindspore/ccsrc/frontend/operator/prim_statement.cc +++ b/mindspore/ccsrc/frontend/operator/prim_statement.cc @@ -108,11 +108,6 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &p auto fb = args_spec_list[2]; MS_EXCEPTION_IF_NULL(cond); - auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG); - if (unroll_flag != nullptr && GetValue(unroll_flag) == 0) { - return tb->Join(fb); - } - ValuePtr v = cond->GetValueTrack(); MS_EXCEPTION_IF_NULL(v); // for tensor as condition, keeps both true and false branch. @@ -137,7 +132,17 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP // Inputs: index, branch const std::string op_name = primitive->name(); abstract::CheckArgsSize(op_name, args_spec_list, 2); - (void)CheckArg(op_name, args_spec_list, 0); + auto index = CheckArg(op_name, args_spec_list, 0); + auto &input_shape = index->shape()->shape(); + if (input_shape.size() != 0) { + MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size() + << " dimension tensor"; + } + auto dtype = index->element()->BuildType(); + if (dtype->type_id() != kInt32->type_id()) { + MS_EXCEPTION(ValueError) << op_name << " index must be a int32, but got " << dtype->ToString(); + } + AbstractTuplePtr branches_abs = CheckArg(op_name, args_spec_list, 1); AbstractBasePtrList branches = branches_abs->elements(); const size_t maximum_layer_num = 1000; diff --git a/mindspore/ccsrc/frontend/operator/prim_structures.cc b/mindspore/ccsrc/frontend/operator/prim_structures.cc index b602b07a0c..cc53f9aa22 100644 --- a/mindspore/ccsrc/frontend/operator/prim_structures.cc +++ b/mindspore/ccsrc/frontend/operator/prim_structures.cc @@ -21,7 +21,7 @@ #include "abstract/param_validator.h" #include "frontend/operator/ops.h" #include "utils/convert_utils.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" using mindspore::tensor::TensorPy; @@ -150,11 +150,12 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr for (size_t index = 0; index < args_size; index++) { MS_EXCEPTION_IF_NULL(args_spec_list[index]); if (!args_spec_list[index]->isa() && !args_spec_list[index]->isa()) { - MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; + MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; } if (args_spec_list[index]->isa() && !dyn_cast(args_spec_list[index])->BuildValue()->isa()) { - MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number."; + MS_EXCEPTION(TypeError) << "MakeSlice eval " << index + << " parameter is an AbstractScalar, but is not an int32 number."; } } // Slice: start, end, step diff --git a/mindspore/ccsrc/frontend/operator/prim_to_function.h b/mindspore/ccsrc/frontend/operator/prim_to_function.h index 285ab8d3ab..fe434d0129 100644 --- a/mindspore/ccsrc/frontend/operator/prim_to_function.h +++ b/mindspore/ccsrc/frontend/operator/prim_to_function.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPERATOR_PRIM_TO_FUNCTION_H_ -#define MINDSPORE_CCSRC_OPERATOR_PRIM_TO_FUNCTION_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_PRIM_TO_FUNCTION_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_PRIM_TO_FUNCTION_H_ #include #include @@ -60,4 +60,4 @@ class PrimToFunction { }; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPERATOR_PRIM_TO_FUNCTION_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_PRIM_TO_FUNCTION_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h index 37986e6810..5515ad459c 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_ADJOINT_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_ADJOINT_H_ #include #include @@ -54,4 +54,4 @@ using AdjointPtr = std::shared_ptr; } // namespace ad } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_ADJOINT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index b314b22f81..d4fe201710 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -22,7 +22,7 @@ #include "ir/anf.h" #include "ir/meta_func_graph.h" -#include "debug/info.h" +#include "utils/info.h" #include "ir/func_graph_cloner.h" #include "ir/manager.h" #include "pipeline/jit/resource.h" @@ -32,7 +32,7 @@ #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" #include "utils/symbolic.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "./common.h" namespace mindspore { @@ -216,6 +216,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { TraceManager::DebugTrace(std::make_shared(cnode_morph->debug_info())); auto k_app = k_graph_->NewCNode(inputs); TraceManager::EndTrace(); + ReplaceEquivdout(k_app, cnode_morph->forward()); for (size_t i = 0; i < param_adjoints.size(); ++i) { param_adjoints[i]->RegisterKUser(k_app, i); } @@ -237,6 +238,37 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { return node_adjoint; } +void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward) { + if (forward == nullptr) { + return; + } + auto &input = cnode->input(0); + if (!IsValueNode(input)) { + return; + } + auto fg = GetValueNode(input); + auto output = fg->output(); + if (!output->isa()) { + return; + } + auto cnode_output = output->cast(); + auto &cnode_input = cnode_output->input(1); + if (!cnode_input->isa()) { + return; + } + auto &input_fg = cnode_output->input(2); + if (!IsValueNode(input_fg)) { + return; + } + auto equivdout = cnode_input->cast(); + auto func_graph = GetValueNode(input_fg); + auto manager = Manage({fg, func_graph}, false); + MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward; + auto value_node = NewValueNode(forward); + value_node->set_has_new_value(true); + manager->Replace(equivdout, value_node); +} + bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { // Do not care about non-CNode if (!node->isa()) { diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 9ee93334e8..0a25d3e396 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_ #include #include @@ -95,6 +95,7 @@ class DFunctor : public std::enable_shared_from_this { // Update k hole with adjoint_definition, only applied in recursive case. void UpdateAdjoint(const AdjointPtr &adjoint_definition); void CallDoutHoleOnTape(); + void ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward); std::unordered_map anfnode_to_adjoin_; // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. @@ -207,4 +208,4 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { } // namespace ad } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc index ef2d7d400a..b11d063db6 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -17,9 +17,9 @@ #include "frontend/optimizer/ad/grad.h" #include "frontend/optimizer/ad/dfunctor.h" #include "ir/func_graph_cloner.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/symbolic.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" namespace mindspore { namespace ad { diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.h b/mindspore/ccsrc/frontend/optimizer/ad/grad.h index ee9ab79ffb..8d3de5e201 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_ #include #include @@ -35,4 +35,4 @@ void CleanRes(); } // namespace ad } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 5ca2ca6c43..4380d16282 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -20,7 +20,7 @@ #include #include #include "ir/anf.h" -#include "ir/primitive_py.h" +#include "utils/primitive_py.h" #include "ir/meta_func_graph.h" #include "ir/func_graph_cloner.h" #include "ir/manager.h" @@ -32,8 +32,8 @@ #include "frontend/operator/composite/composite.h" #include "utils/symbolic.h" #include "utils/primitive_utils.h" -#include "utils/context/ms_context.h" -#include "debug/info.h" +#include "utils/ms_context.h" +#include "utils/info.h" #include "debug/trace.h" #include "./common.h" @@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { return meta; } + if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { + MetaFuncGraphPtr meta = std::make_shared("make_list_gradient"); + bprop_registry_meta_[prim::kPrimMakeList] = meta; + return meta; + } + MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; } @@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R return fprop; } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { return nullptr; + } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { + return nullptr; } FuncGraphPtr bprop_fg = nullptr; @@ -264,7 +272,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re return IsPrimitiveCNode(user.first, prim); }); if (cnode == users.end()) { - MS_LOG(EXCEPTION) << "Fail to find cnode."; + MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString(); } auto inputs_num = cnode->first->cast()->inputs().size() - 1; diff --git a/mindspore/ccsrc/frontend/optimizer/anf_visitor.h b/mindspore/ccsrc/frontend/optimizer/anf_visitor.h new file mode 100644 index 0000000000..a1dd3aed04 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/anf_visitor.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_ANF_VISITOR_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_ANF_VISITOR_H_ + +#include +#include "ir/visitor.h" +#include "frontend/optimizer/optimizer_caller.h" + +namespace mindspore { +class AnfVisitor : public AnfIrVisitor, public OptimizerCaller {}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_ANF_VISITOR_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc index e35760ceaf..edff3a8d79 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.cc +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -34,7 +34,9 @@ using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractSparseTensor; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractUndetermined; @@ -59,11 +61,33 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) { [](const AbstractAttribute &item) { return item.second; }); return std::make_shared(baselist); } + + return nullptr; +} + +static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) { + if (t == nullptr) { + return nullptr; + } + if (t->isa()) { auto abs_list = dyn_cast(t); return std::make_shared(abs_list->elements()); } + if (t->isa()) { + auto abs_sparse = dyn_cast(t); + std::vector abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()}; + return std::make_shared(abstract_list); + } + + if (t->isa()) { + auto abs_row_tensor = dyn_cast(t); + std::vector abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(), + abs_row_tensor->dense_shape()}; + return std::make_shared(abstract_list); + } + return nullptr; } @@ -358,7 +382,71 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr new_node = EraseMakeKeywordArgNode(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { new_node = EraseExtractKeywordArg(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { + } + + if (new_node != nullptr) { + new_node->set_abstract(node->abstract()); + MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString(); + (void)manager->Replace(node, new_node); + changed = true; + } + } + + for (auto &node : manager->all_nodes()) { + auto ret = Reabs(node->abstract()); + if (ret) { + MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with " + << ret->ToString(); + node->set_abstract(ret); + changed = true; + } + } + return changed; +} + +AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items; + (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end()); + return node->func_graph()->NewCNode(inputs); +} + +AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int &index) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + const auto &inputs = node->inputs(); + // Inputs should be [spase_getattr, sparse] + if (inputs.size() < 2) { + MS_LOG(EXCEPTION) << "Node's input number < 2."; + } + + AnfNodePtr sparse = inputs[1]; + MS_EXCEPTION_IF_NULL(sparse); + auto cons_node = NewValueNode(index); + AbstractBasePtr aptr = std::make_shared(std::make_shared(index)); + cons_node->set_abstract(aptr); + + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node}); +} + +bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + bool changed = false; + + // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var + auto all_node = manager->all_nodes(); + for (auto &node : all_node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + AnfNodePtr new_node = nullptr; + if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { new_node = ConvertMakeListToMakeTuple(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) { new_node = ConvertListGetItemToTupleGetItem(cnode); @@ -366,6 +454,18 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr new_node = ConvertListSetItemToTupleSetItem(cnode); } else if (IsValueNode(node)) { new_node = ConvertValueListNodeToValueTupleNode(node->cast()); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) || + IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) { + new_node = ConvertMakeSparseToMakeTuple(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) || + IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) { + new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0); + } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) || + IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) { + new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1); + } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) || + IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) { + new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2); } if (new_node != nullptr) { @@ -377,7 +477,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr } for (auto &node : manager->all_nodes()) { - auto ret = Reabs(node->abstract()); + auto ret = AdaptAbs(node->abstract()); if (ret) { MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with " << ret->ToString(); diff --git a/mindspore/ccsrc/frontend/optimizer/clean.h b/mindspore/ccsrc/frontend/optimizer/clean.h index 54faabaa63..abd847f8c1 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.h +++ b/mindspore/ccsrc/frontend/optimizer/clean.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_ #include #include "ir/anf.h" @@ -32,6 +32,7 @@ namespace opt { // Remove the class type from graphs bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); +bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); // Remove most uses of tuples from the graph // tuples that are returned will be kept @@ -40,4 +41,4 @@ void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/control_depend.h b/mindspore/ccsrc/frontend/optimizer/control_depend.h index 076e2c0229..60a00e5b51 100644 --- a/mindspore/ccsrc/frontend/optimizer/control_depend.h +++ b/mindspore/ccsrc/frontend/optimizer/control_depend.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CONTROL_DEPEND_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CONTROL_DEPEND_H_ #include "ir/anf.h" @@ -25,4 +25,4 @@ namespace opt { void AddControlDepend(const FuncGraphPtr &graph); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CONTROL_DEPEND_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index 4d968d6d74..c80b54097d 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) { if (node_abs == nullptr) { return kAnyValue; } + // Ignore the tracking_id and prim pointer hash; + if (node_abs->isa()) { + auto prim_abs = node_abs->cast(); + return prim_abs->prim(); + } return node_abs; } diff --git a/mindspore/ccsrc/frontend/optimizer/cse.h b/mindspore/ccsrc/frontend/optimizer/cse.h index 140f592715..55058f60e8 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.h +++ b/mindspore/ccsrc/frontend/optimizer/cse.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CSE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CSE_H_ #include #include @@ -58,4 +58,4 @@ BasePtr AbsOf(const AnfNodePtr &node); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CSE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc index c157777040..ef526559e5 100644 --- a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc @@ -19,7 +19,7 @@ #include #include #include "./common.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" namespace mindspore { /* namespace to support opt */ diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h index a79ef3ce6d..0a786e8ead 100644 --- a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H -#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H #include #include @@ -49,4 +49,4 @@ class GraphKernelReuse { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index efc3795a4c..d1d29fcbae 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -41,8 +41,10 @@ #include "frontend/optimizer/irpass/symbol_resolver.h" #include "frontend/optimizer/irpass/tile_eliminate.h" #include "frontend/optimizer/irpass/transpose_eliminate.h" +#include "frontend/optimizer/irpass/value_based_eliminate.h" #include "frontend/optimizer/opt.h" -#include "frontend/optimizer/irpass/indexed_slices_eliminate.h" +#include "frontend/optimizer/irpass/row_tensor_eliminate.h" +#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" namespace mindspore { namespace opt { @@ -64,7 +66,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // ops eliminate item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", - {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); @@ -155,13 +157,24 @@ OptimizeIRPassLib::OptimizeIRPassLib() { mark_interface_fusion_ = MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); - // IndexedSlices Eliminate - indexed_slices_eliminate_ = MakeSubstitution( - std::make_shared(), "indexed_slices_eliminate", - {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); + // RowTensor Eliminate + row_tensor_eliminate_ = MakeSubstitution( + std::make_shared(), "row_tensor_eliminate", + {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape}); + + // SparseTensor Eliminate + sparse_tensor_eliminate_ = MakeSubstitution( + std::make_shared(), "sparse_tensor_eliminate", + {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape}); + + // Value_Based Eliminate + value_based_eliminate_ = MakeSubstitution(std::make_shared(), "value_based_eliminate", + {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum}); } ResolveIRPassLib::ResolveIRPassLib() { + resolver_resolve_attr_ = + MakeSubstitution(std::make_shared(), "resolver_resolve_attr", prim::kPrimGetAttr); resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 4af8c0789d..5a0f2ed5b7 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_ #include #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/opt.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" namespace mindspore { namespace opt { @@ -105,8 +105,14 @@ class OptimizeIRPassLib { // Fusion SubstitutionPtr mark_interface_fusion_; - // IndexedSlices Eliminate - SubstitutionPtr indexed_slices_eliminate_; + // RowTensor Eliminate + SubstitutionPtr row_tensor_eliminate_; + + // SparseTensor Eliminate + SubstitutionPtr sparse_tensor_eliminate_; + + // Value_Based Eliminate + SubstitutionPtr value_based_eliminate_; }; // the collection of irpass for resolve action @@ -115,6 +121,7 @@ class ResolveIRPassLib { ResolveIRPassLib(); ~ResolveIRPassLib() = default; + SubstitutionPtr resolver_resolve_attr_; SubstitutionPtr resolver_resolve_; SubstitutionPtr resolver_getattr_; }; @@ -189,4 +196,4 @@ inline bool IsCNodeDup(const AnfNodePtr &node) { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index 83f7fae582..46cc91443b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -14,542 +14,76 @@ * limitations under the License. */ -#include -#include -#include -#include - #include "frontend/optimizer/irpass/arithmetic_simplify.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "frontend/operator/ops.h" -#include "frontend/optimizer/irpass.h" -#include "frontend/optimizer/irpass/prim_eliminate.h" -#include "frontend/optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarMul)(node); - - if (is_zero_) { - return NewValueNode(zero_); - } - if (is_one_) { - return x_; - } - return nullptr; -} - -void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { - if (is_one_ || node->isa()) { - x_ = node; - return; - } - - AnfVisitor::Visit(node); - if (!is_one_) { - x_ = node; - } -} - -void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (*value == *zero_) { - is_zero_ = true; - } else if (*value == *one_) { - is_one_ = true; - } -} - -void MultiplyByZeroOrOne::Reset() { - x_ = nullptr; - is_one_ = false; - is_zero_ = false; -} - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > FLT_EPSILON) { - return false; - } - } - return true; - } else if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > DBL_EPSILON) { - return false; - } - } - return true; - } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (data2[i] != check_value_) { - return false; - } - } - return true; - } - // input Data Types is not supported - return false; -} - -bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { - return false; - } - return IsTensorConstant(value); -} - -void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { - if (!node->isa()) { +AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { return nullptr; } + PatternNode x, y, z, xs; + PConstant one_(node, false, 1); + PConstant one_scalar_(node, false, 1, true); + PConstant zero_(node, false, 0); + PConstant zero_scalar_(node, false, 0, true); + PConstant const_(node); + PConstant const_2(node); + PConstant any_const(node); - auto value = node->cast()->value(); + MATCH_REPLACE(node, x + zero_, x); // Add by zero + MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero + MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one - if (!value->isa()) { - return nullptr; - } - - tensor::TensorPtr tensor_ptr = dyn_cast(value); - return tensor_ptr->data_c(); -} + // Scalar Mul by zero + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue()); + // Prim Eliminate (identity) + MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); -// Make a new tensor (when possible) with the same shape as of `node` -// If x is nullptr then fill new tensor will "0" -// If x is a tensor with empty shape then fill new tensor with the single value of x -// If x is a tensor with same shape as `node` then return x as result -AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - if (x == nullptr) { - std::memset(data, 0, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - // x is not nullptr - if (x->isa()) { - if ((x->abstract() == nullptr) || !x->abstract()->isa()) { - return nullptr; - } - auto x_abstract = x->abstract()->cast(); - std::vector x_shape = x_abstract->shape()->shape(); - - if (x_shape != tensor_shape) { - return nullptr; + // ConstantDuplicateMul + auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr { + auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node)); + auto mul_node = node->cast()->inputs()[0]; + if (new_mul_tensor == nullptr) { + auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph()); + return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph()); } - return x; - } + return NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph()); + }; + MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda); - if (!x->isa()) { - return nullptr; - } - auto x_value = x->cast()->value(); - if (!x_value->isa()) { + if (node->func_graph() == nullptr) { return nullptr; } - auto x_tensor_ptr = dyn_cast(x_value); + // OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y} + MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0), + PPrimitive(prim::kPrimMakeTuple, z, y)); - if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { - return nullptr; - } - char *source_data = reinterpret_cast(GetPointerToTensorData(x)); - if (x_tensor_ptr->DataSize() == 1) { - for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { - memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); - } - } else { - memcpy(data, source_data, mem_size); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_zero_) { - if (x_->func_graph() != node->func_graph()) { - return nullptr; - } - return NewTensorFilledWithData(node); - } - return nullptr; -} - -void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { - if (is_zero_) { - x_ = node; - return; - } - - if (IsParam(node)) { - x_ = node; - return; - } - - if (IsCNode(node)) { - CNodePtr cnode = node->cast(); - if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } - x_ = node; - return; - } - auto value = node->cast()->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); + // PowerOneEliminate + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x, + one_scalar_.CheckFunc(IsValueNode, node)); - if (is_one_) { - return NewTensorFilledWithData(node, x_); - } return nullptr; } -void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { - if (is_one_) { - x_ = node; - return; - } - - if (IsParam(node) || IsCNode(node)) { - x_ = node; - return; - } - - auto value = node->cast()->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByOne::Reset() { - x_ = nullptr; - is_one_ = false; -} - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void AddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && - ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void AddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimTensorAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void TensorAddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void TensorAddByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } -} - -void TensorAddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { - return nullptr; - } - - // {PrimMomentum, {...}, Y, Z, Xs} - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { - return nullptr; - } - auto y = inputs[2]; - auto z = inputs[3]; - - // {kPrimZerosLike, X} - if (inputs[1]->cast()->size() != 2) { - return nullptr; - } - - // {prim::kPrimMakeTuple, Z, Y} - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); -} - -// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -// Support function to multiply two constant tensors: partially support broadcasting shapes -template -void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, - void **out_data, int out_data_size) { - T *data_1 = reinterpret_cast(in_data_1); - T *data_2 = reinterpret_cast(in_data_2); - T *data_out = new T[out_data_size]; - - if (in_data_1_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[i]; - } - } - if (in_data_2_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; - } - } - *out_data = reinterpret_cast(data_out); - return; -} - -AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, - const AnfNodePtr &node_3) { - if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || - (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { - return nullptr; - } - - auto value_1 = GetValueNode(vnode_1); - auto value_2 = GetValueNode(vnode_2); - - if (!value_1->isa() || !value_2->isa()) { - return nullptr; - } - - auto tensor_ptr_1 = dyn_cast(value_1); - auto tensor_ptr_2 = dyn_cast(value_2); - - auto tensor_1_abstract = vnode_1->abstract()->cast(); - auto tensor_2_abstract = vnode_1->abstract()->cast(); - auto tensor_3_abstract = node_3->abstract()->cast(); - - TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); - TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); - TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); - - if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || - (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { - return nullptr; - } - - std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); - - int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); - - if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { - return nullptr; - } - if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { - return nullptr; - } - - void *data_out; - - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), - &data_out, data_out_size); - } else { - if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - // Un-support data types - return nullptr; - } - } - } - - auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); - size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - memcpy(data, data_out, mem_size); - - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - // {prim::kPrimMul, Tensor1, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return nullptr; - } - - if (!IsCNode(c_p_node_)) { - return nullptr; - } - - auto tensor1 = vnode_; - auto mul = c_p_node_->cast(); - - Reset(); - // {prim::kPrimMul, Tensor2, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); - if (vnode_ == nullptr || c_p_node_ == nullptr) { +AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { return nullptr; } - auto tensor2 = vnode_; - auto c_p_node = c_p_node_; + PatternNode x, y; + PConstant zero_(node, false, 0); - auto PrimMul = GetValueNode(mul->input(0)); - auto fg = node->func_graph(); + // Multiply by zero + MATCH_REPLACE_IF(node, x * zero_, zero_.WithShapeAs(node), + !zero_.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph()); + auto zero_prim = PPrimitive(prim::kPrimZerosLike, y); + MATCH_REPLACE_IF(node, x * zero_prim, zero_.WithShapeAs(node), + !zero_prim.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph()); - auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); - if (new_mul_tensor == nullptr) { - auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); - return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); - } - return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); -} - -void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { - if (IsValueNode(node)) { - vnode_ = node; - } - - if (IsCNode(node) || IsParam(node)) { - c_p_node_ = node; - } -} - -void ConstantDuplicateMul::Reset() { - vnode_ = nullptr; - c_p_node_ = nullptr; -} - -AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (!IsValueNode(inputs[2])) { - return nullptr; - } - auto scalar = GetValueNode(inputs[2]); - if (scalar->isa() && GetValue(scalar) == 1.0) { - return inputs[1]; - } else if (scalar->isa() && GetValue(scalar) == 1) { - return inputs[1]; - } return nullptr; } @@ -654,27 +188,6 @@ void AdjustAllReduceMulAdd::Reset() { all_reduce_fg_ = nullptr; } -AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} - -AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h index 3088231396..699005f7bf 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h @@ -14,166 +14,22 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #include #include #include -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/prim_eliminate.h" -#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" +#include "ir/pattern_matcher.h" namespace mindspore { namespace opt { namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -class MultiplyByZeroOrOne : public AnfVisitor { - public: - MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} - ~MultiplyByZeroOrOne() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}, is_one_{false}; - ValuePtr zero_, one_; - AnfNodePtr x_{nullptr}; -}; - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -class CheckTensorConstant { - public: - explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} - ~CheckTensorConstant() = default; - - bool IsTensorConstant(const ValuePtr &value); - bool IsTensorScalarConstant(const ValuePtr &value); - - private: - int check_value_; -}; - -class TensorMultiplyBase : public AnfVisitor { - protected: - void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); - - // Make a new tensor (when possible) with the same shape as of `node` - // If x is nullptr then fill new tensor will "0" - // If x is a tensor with empty shape then fill new tensor with the single value of x - // If x is a tensor with same shape as `node` then return x as result - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); - - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -class TensorMultiplyByZero : public TensorMultiplyBase { - public: - TensorMultiplyByZero() : zero_(MakeValue(0)) {} - ~TensorMultiplyByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; -}; - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -class TensorMultiplyByOne : public TensorMultiplyBase { - public: - TensorMultiplyByOne() {} - ~TensorMultiplyByOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_one_{false}; -}; - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -class AddByZero : public AnfVisitor { - public: - AddByZero() : zero_(MakeValue(0)) {} - ~AddByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -class TensorAddByZero : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - AnfNodePtr x_{nullptr}; -}; - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -class OptUpdateZeroTensor : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - -// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -class ConstantDuplicateMul : public AnfVisitor { - public: - // Support function to multiply two constant tensors: partially support broadcasting shapes - template - void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size); - - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - AnfNodePtr vnode_; - AnfNodePtr c_p_node_; -}; - -class PowerOneEliminate : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - // grad = AllReduce(grad) / worker_number // grad = grad + weight * decy // -> @@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { class ArithmeticSimplify : public OptimizerCaller { public: - ArithmeticSimplify() - : multiply_by_zero_or_one_(std::make_shared()), - tensor_multiply_by_one_(std::make_shared()), - add_by_zero_(std::make_shared()), - tensor_add_by_zero_(std::make_shared()), - identity_(std::make_shared(prim::kPrimIdentity)), - opt_update_zero_tensor_(std::make_shared()), - constant_duplicate_mul_(std::make_shared()), - power_one_(std::make_shared()) { - eliminaters_.emplace_back(multiply_by_zero_or_one_); - eliminaters_.emplace_back(tensor_multiply_by_one_); - eliminaters_.emplace_back(add_by_zero_); - eliminaters_.emplace_back(tensor_add_by_zero_); - eliminaters_.emplace_back(identity_); - eliminaters_.emplace_back(opt_update_zero_tensor_); - eliminaters_.emplace_back(constant_duplicate_mul_); - eliminaters_.emplace_back(power_one_); - } - ~ArithmeticSimplify() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr multiply_by_zero_or_one_; - OptimizerCallerPtr tensor_multiply_by_one_; - OptimizerCallerPtr add_by_zero_; - OptimizerCallerPtr tensor_add_by_zero_; - OptimizerCallerPtr identity_; - OptimizerCallerPtr opt_update_zero_tensor_; - OptimizerCallerPtr constant_duplicate_mul_; - OptimizerCallerPtr power_one_; - - std::vector eliminaters_{}; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; }; // Arithmetic Simplifications should be done after step_parallel. @@ -242,18 +66,10 @@ class ArithmeticSimplify : public OptimizerCaller { // ArithmeticSimplify and deferred until step_parallel. class ArithmeticSimplify2 : public OptimizerCaller { public: - ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { - eliminaters_.emplace_back(tensor_multiply_by_zero_); - } - ~ArithmeticSimplify2() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr tensor_multiply_by_zero_; - std::vector eliminaters_{}; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; }; + } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index b3f3fe4733..72a6a4df9f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ #include #include #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" +#include "frontend/optimizer/optimizer_caller.h" #include "ir/pattern_matcher.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" @@ -51,8 +51,8 @@ class SwitchSimplify : public OptimizerCaller { } }; -// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => -// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} +// {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} => +// {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} class FloatTupleGetItemSwitch : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -98,19 +98,12 @@ class ConvertSwitchReplacement : public OptimizerCaller { return nullptr; } - auto cnode_ = node->cast(); - if (cnode_->size() < 1) { - return nullptr; - } - - auto node_ = cnode_->input(0); - PatternNode cond, true_br, false_br; - auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto g1_ = GetValueNode(true_br.GetNode(node_)); - auto g2_ = GetValueNode(false_br.GetNode(node_)); - auto x_ = cond.GetNode(node_); + auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node)); + auto g2_ = GetValueNode(false_br.GetNode(node)); + auto x_ = cond.GetNode(node); // for switch replace method, only graphs without graph inside can be replaced for (auto &item : g1_->value_nodes()) { @@ -133,7 +126,7 @@ class ConvertSwitchReplacement : public OptimizerCaller { auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; - auto fg = node_->func_graph(); + auto fg = node->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); @@ -142,8 +135,8 @@ class ConvertSwitchReplacement : public OptimizerCaller { }; MATCH_REPLACE_LAMBDA_IF( - node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, - true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); + node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node) && false_br.CheckFunc(IsValueNode, node)); return nullptr; } @@ -152,4 +145,4 @@ class ConvertSwitchReplacement : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ +#endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc index ddb84806e1..e556402c44 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc @@ -17,7 +17,7 @@ #include "frontend/optimizer/irpass/cast_eliminate.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "ir/func_graph.h" #include "pipeline/jit/parse/data_converter.h" diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h index d5222d4310..46c86ab716 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -78,4 +78,4 @@ class CastEliminater : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/convert.h b/mindspore/ccsrc/frontend/optimizer/irpass/convert.h index d887874203..f05fd8cd4c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/convert.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/convert.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_H_ #include #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "ir/func_graph.h" #include "frontend/operator/ops.h" @@ -59,4 +59,4 @@ class PrintTupleWrapper : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ +#endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index 14fd8743ff..1fee007a88 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #include #include @@ -25,8 +25,8 @@ #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -361,4 +361,4 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc index 44c1b62fa5..8950084437 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc @@ -24,7 +24,7 @@ #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h index f6992a87c6..d9040044d9 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ #include #include @@ -26,7 +26,7 @@ #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" @@ -51,4 +51,4 @@ class GradVarPrepare : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h index 82312d9e37..aefd6d73bc 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ #include #include @@ -23,8 +23,8 @@ #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" -#include "common/utils.h" +#include "frontend/optimizer/anf_visitor.h" +#include "utils/ms_utils.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/ad/grad.h" @@ -58,4 +58,4 @@ class ExpandJPrim : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h index 2f6404458f..ccfef5bcfb 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ #include #include @@ -24,7 +24,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "frontend/operator/ops.h" @@ -205,4 +205,4 @@ class IncorporateCallSwitch : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 828e205e4f..9c02df6b2f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #include #include @@ -25,8 +25,8 @@ #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -413,4 +413,4 @@ class IncorporateGetitemSet : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h deleted file mode 100644 index dfe345fe01..0000000000 --- a/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ - -#include -#include - -#include "frontend/optimizer/irpass.h" -#include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" -#include "frontend/operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}} -// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}} -// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}} -class IndexedSlicesEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(1); - } - AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(2); - } - AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(3); - } - return nullptr; - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) { - tuple_ = cnode; - is_match_ = true; - } - } - - void Reset() { - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - CNodePtr tuple_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index 8cafb268b4..0be228f44b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_ #include #include @@ -23,7 +23,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "frontend/operator/ops.h" @@ -201,4 +201,4 @@ class Inliner : public InlinerBase { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h index acd6844ee7..e794671f98 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #include #include #include -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -38,6 +38,7 @@ class GetitemEliminater : public AnfVisitor { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); if (is_match_) { return tuple_->input(id_); @@ -46,14 +47,18 @@ class GetitemEliminater : public AnfVisitor { } void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { tuple_ = cnode; } } void Visit(const ValueNodePtr &vnode) override { if (tuple_ != nullptr && IsValueNode(vnode)) { - id_ = IntToSize(GetValue(vnode->value()) + 1); + int idx = GetValue(vnode->value()); + if (idx < 0) { + idx = idx + tuple_->size() - 1; + } + id_ = IntToSize(idx + 1); if (tuple_->size() > id_) { is_match_ = true; } @@ -80,9 +85,12 @@ class GetitemConstEliminater : public AnfVisitor { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsVNode, IsVNode})(node); if (is_match_) { - return NewValueNode((*tuple_)[id_]); + auto out = NewValueNode((*tuple_)[id_]); + out->set_has_new_value(has_new_value_); + return out; } return nullptr; } @@ -90,6 +98,7 @@ class GetitemConstEliminater : public AnfVisitor { void Visit(const ValueNodePtr &vnode) override { if (IsValueNode(vnode)) { tuple_ = GetValueNode(vnode); + has_new_value_ = vnode->has_new_value(); } if (tuple_ != nullptr && IsValueNode(vnode)) { id_ = IntToSize(GetValue(vnode->value())); @@ -109,6 +118,7 @@ class GetitemConstEliminater : public AnfVisitor { bool is_match_{false}; size_t id_{0}; ValueTuplePtr tuple_{nullptr}; + bool has_new_value_{false}; }; // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) @@ -138,7 +148,7 @@ class SetitemEliminater : public AnfVisitor { } void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { auto &inputs = cnode->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_)); } @@ -234,6 +244,7 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsValueNode})(node); if (x_ == nullptr) { return nullptr; } @@ -298,4 +309,4 @@ class ItemTupleEliminater : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h index 8d3839bd9e..ffa383cc7a 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H #include #include @@ -24,9 +24,9 @@ #include "backend/session/anf_runtime_algorithm.h" #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "frontend/operator/composite/composite.h" namespace mindspore { @@ -83,4 +83,4 @@ class MarkInterfaceFusion : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h index a3cf6e2231..5e0c598501 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_ #include #include @@ -23,7 +23,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" namespace mindspore { @@ -37,9 +37,11 @@ class MergeAddN : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { Reset(); - optimizer_ = optimizer; + mng_ = optimizer->resource()->manager(); is_outer_ = true; AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); + // do not hold this manager + mng_ = nullptr; if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } @@ -104,8 +106,7 @@ class MergeAddN : public AnfVisitor { } bool is_unique(const AnfNodePtr &node) { - auto mng = optimizer_->resource()->manager(); - auto &node_users = mng->node_users(); + auto &node_users = mng_->node_users(); if (node_users.find(node) == node_users.end()) { return false; } @@ -124,7 +125,7 @@ class MergeAddN : public AnfVisitor { } private: - OptimizerPtr optimizer_{nullptr}; + FuncGraphManagerPtr mng_{nullptr}; std::vector Xs_{}, Ys_{}, args_{}; bool is_inner_{false}, is_outer_{false}, is_match_{false}; }; @@ -317,4 +318,4 @@ class AddNEliminater : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h index 658a287234..079e28528d 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ #include #include #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" namespace mindspore { @@ -107,4 +107,4 @@ class MinMaximumGrad : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h b/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h index 999376e528..5b8a4030f1 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ #include #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "pipeline/jit/parse/parse.h" @@ -57,4 +57,4 @@ class ReplaceOldParam : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h index 32fc5abc7d..dc63789a4b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ #include #include @@ -23,7 +23,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" namespace mindspore { @@ -76,4 +76,4 @@ class PartialEliminater : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h index d8c96825c9..8bce179287 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" namespace mindspore { namespace opt { @@ -46,4 +46,4 @@ class PrimEliminater : public AnfVisitor { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h index 78b7d3f4f1..5543783f5f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ #include #include @@ -23,7 +23,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "abstract/dshape.h" @@ -157,4 +157,4 @@ class ReduceOneEliminater : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h index 86eb4e761d..b7759daad4 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ #include @@ -42,8 +42,8 @@ class GetRefParamEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x; - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x); return nullptr; } }; @@ -91,4 +91,4 @@ class ReplaceRefkeyByParam : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h index 27d4bdad3d..0270224a26 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ #include #include "ir/func_graph.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -151,4 +151,4 @@ class ReshapeEliminater : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h new file mode 100644 index 0000000000..c237bec0ec --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_ + +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/anf_visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}} +// {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}} +// {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}} +class RowTensorEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(1); + } + AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(2); + } + AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(3); + } + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) { + tuple_ = cnode; + is_match_ = true; + } + } + + void Reset() { + tuple_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + CNodePtr tuple_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h new file mode 100644 index 0000000000..07fb4e80b1 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_ + +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}} +// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}} +// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}} +class SparseTensorEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(1); + } + AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(2); + } + AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(3); + } + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) { + tuple_ = cnode; + is_match_ = true; + } + } + + void Reset() { + tuple_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + CNodePtr tuple_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index 01efa85e8d..09b54afb75 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -14,17 +14,17 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ #include #include #include #include -#include "ir/optimizer_caller.h" +#include "frontend/optimizer/optimizer_caller.h" #include "ir/pattern_matcher.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/prim_eliminate.h" @@ -207,4 +207,4 @@ class DependValueElim : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h index d8a15f6d83..6cb9312028 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ #include #include @@ -26,7 +26,7 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "ir/manager.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" @@ -302,4 +302,4 @@ class UnusedOutputEliminater : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h index de9e533550..e529c7ce04 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h @@ -14,22 +14,28 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ #include #include #include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/optimizer_caller.h" #include "frontend/optimizer/irpass.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" +#include "ir/pattern_matcher.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/parse_base.h" namespace mindspore { namespace opt { namespace irpass { + +const char PARSE_SUPER_NAME[] = "namespace"; + // {prim::kPrimResolve, Ns, Sym} class ResolverResolve : public AnfVisitor { public: @@ -90,7 +96,35 @@ class ResolverGetattr : public AnfVisitor { parse::NameSpacePtr ns_{nullptr}; parse::SymbolPtr sym_{nullptr}; }; + +// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node} +class ResolveAttr : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + PatternNode ns_node, sym_node, attr_node; + auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr { + auto node_to_getattr = node->cast()->input(1); + std::string attr_as_string = GetValueNode(attr_node.GetNode(node))->value(); + + auto ns_ = GetValueNode(ns_node.GetNode(node)); + auto sym_ = GetValueNode(sym_node.GetNode(node)); + + if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) { + // deal with the case of getting attr from a class member + // and avoid the case of getting attr from self (the result of ParseSuper) + auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string); + return result; + } + return nullptr; + }; + MATCH_REPLACE_LAMBDA_IF( + node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node), + ResolveAttrLambda, attr_node.CheckFunc(IsValueNode, node)); + + return nullptr; + } +}; } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h index f561e04c10..decfd6a995 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ #include #include #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" namespace mindspore { @@ -74,4 +74,4 @@ class TileMultiplyByOne : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h index 70b8898462..ecf8b92743 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ #include #include #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "ir/visitor.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" namespace mindspore { @@ -76,4 +76,4 @@ class TransposeSameIOEliminater : public AnfVisitor { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc new file mode 100644 index 0000000000..38b59afe96 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/value_based_eliminate.h" + +namespace mindspore { +namespace opt { +namespace irpass { +#define UPPER_FLT_LIMIT (FLT_MAX / 2.0) +#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) + +bool IsCNodePositive(const AnfNodePtr &node) { + if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { + return IsCNodePositive(node->cast()->input(1)); + } + if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { + return true; + } + return false; +} + +// check if a value is bigger than UPPER_FLT_LIMIT +bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return false; + } + + auto value = value_node->value(); + if (value == nullptr) { + return false; + } + + auto scalar = value->cast(); + if (scalar != nullptr) { + if (scalar->isa()) { + return GetValue(scalar) > UPPER_FLT_LIMIT; + } + } + // Check for Tensor [] or Tensor [1] + auto tensor_ptr = value->cast(); + if (tensor_ptr == nullptr) { + return false; + } + if (tensor_ptr->DataSize() > 1) { + return false; + } + + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data = reinterpret_cast(tensor_ptr->data_c()); + return data[0] > UPPER_FLT_LIMIT; + } + + return false; +} + +// check if a value is smaller than LOWER_FLT_LIMIT +bool IsNodeScalarMinFLT(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return false; + } + + auto value = value_node->value(); + if (value == nullptr) { + return false; + } + + auto scalar = value->cast(); + if (scalar != nullptr) { + if (scalar->isa()) { + return GetValue(scalar) < LOWER_FLT_LIMIT; + } + } + // Check for Tensor [] or Tensor [1] + auto tensor_ptr = value->cast(); + if (tensor_ptr == nullptr) { + return false; + } + if (tensor_ptr->DataSize() > 1) { + return false; + } + + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data = reinterpret_cast(tensor_ptr->data_c()); + return data[0] < LOWER_FLT_LIMIT; + } + + return false; +} + +AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode x, y, z; + PConstant zero_(node, false, 0); + PConstant zero_scalar_(node, false, 0, true); + + // {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0 + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y, + IsCNodePositive(x.GetNode(node))); + + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, + IsCNodePositive(x.GetNode(node))); + + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node))); + + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node))); + + return nullptr; +} + +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h new file mode 100644 index 0000000000..3ae2d90a5c --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "frontend/optimizer/anf_visitor.h" +#include "ir/pattern_matcher.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0 +// {prim::kPrimMaximum, X, Y} -> X when Y is smaller than LOWER_FLT_LIMIT +// {prim::kPrimMinimum, X, Y} -> X when Y is greater than UPPER_FLT_LIMIT +class ValueBasedEliminate : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/opt.h b/mindspore/ccsrc/frontend/optimizer/opt.h index f440cc71dc..32f32803a5 100644 --- a/mindspore/ccsrc/frontend/optimizer/opt.h +++ b/mindspore/ccsrc/frontend/optimizer/opt.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ #include #include @@ -23,7 +23,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" -#include "ir/optimizer_caller.h" +#include "frontend/optimizer/optimizer_caller.h" #include "frontend/operator/ops.h" namespace mindspore { @@ -75,4 +75,4 @@ class SubstitutionList { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer.h b/mindspore/ccsrc/frontend/optimizer/optimizer.h index a1f11e74d0..bc60181587 100644 --- a/mindspore/ccsrc/frontend/optimizer/optimizer.h +++ b/mindspore/ccsrc/frontend/optimizer/optimizer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_ #include #include @@ -34,7 +34,7 @@ #include "frontend/optimizer/opt.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/action.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { @@ -239,4 +239,4 @@ class Optimizer : public std::enable_shared_from_this { }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer_caller.h b/mindspore/ccsrc/frontend/optimizer/optimizer_caller.h new file mode 100644 index 0000000000..ffb70b28a1 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/optimizer_caller.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_CALLER_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_CALLER_H_ + +#include + +#include "ir/anf.h" +#include "ir/visitor.h" + +namespace mindspore { +namespace opt { +class Optimizer; +using OptimizerPtr = std::shared_ptr; +using OptimizerWeakPtr = std::weak_ptr; +using PredicateFuncType = mindspore::PredicateFuncType; +} // namespace opt + +class OptimizerCaller { + public: + virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } +}; +using OptimizerCallerPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_CALLER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.h b/mindspore/ccsrc/frontend/optimizer/pass_group.h index 08fa8018d6..22b17b81b1 100644 --- a/mindspore/ccsrc/frontend/optimizer/pass_group.h +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PASS_GROUP_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PASS_GROUP_H_ #include #include @@ -58,4 +58,4 @@ using PassGroupPtr = std::shared_ptr; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PASS_GROUP_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.cc b/mindspore/ccsrc/frontend/optimizer/pattern.cc new file mode 100644 index 0000000000..412c0bdb46 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pattern.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/optimizer/pattern.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +int Pattern::g_id_ = 0; + +MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) { + if (!IsValueNode(node)) { + return nullptr; + } + MatchResultPtr res = std::make_shared(); + if (IsValueNode(node)) { + // iterate over all primitives + for (auto &iter : primitives_) { + if (IsPrimitive(node, iter) || iter->name() == "*") { + matched_prim_ = iter; + res->add_entry(shared_from_base(), node); + return res; + } + } + } + return nullptr; +} + +MatchResultPtr CallWith::match(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node)) { + return nullptr; + } + MatchResultPtr res = std::make_shared(); + // IsPrimitiveCNode + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // Check Primitive ValueNode + if (prim_pattern_ != nullptr) { + // Passed in prim_pattern + auto prim_value_res = prim_pattern_->match(cnode->input(0)); + if (prim_value_res == nullptr) { + return nullptr; + } + res->merge(prim_value_res); + } else if (prim_ != nullptr) { + // Passed in primitive/primitive str + if (!IsPrimitive(cnode->input(0), prim_)) { + return nullptr; + } + } else { + MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern."; + } + // Check inputs + auto p_inputs_size = inputs_.size(); + auto node_inputs_size = cnode->size() - 1; + if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) { + return nullptr; + } + // If inputs is not specified, add node without looking into its inputs + if (p_inputs_size == 0) { + res->add_entry(shared_from_base(), cnode->input(0)); + return res; + } + bool failed = false; + for (std::size_t i = 0; i < node_inputs_size; i++) { + auto pattern = inputs_[i]; + auto input = cnode->input(i + 1); + auto input_match_result = pattern->match(input); + if (input_match_result == nullptr) { + failed = true; + break; + } + res->merge(input_match_result); + } + if (!failed) { + res->add_entry(shared_from_base(), cnode->input(0)); + return res; + } + return nullptr; +} + +MatchResultPtr IsIn::match(const AnfNodePtr &node) { + for (auto &iter : patterns_) { + auto res = iter->match(node); + if (res != nullptr) { + return res; + } + } + return nullptr; +} + +MatchResultPtr IsNot::match(const AnfNodePtr &node) { + for (auto &iter : patterns_) { + auto res = iter->match(node); + if (res != nullptr) { + return nullptr; + } + } + auto res = std::make_shared(); + res->add_entry(shared_from_base(), node); + return res; +} + +MatchResultPtr AnyPattern::match(const AnfNodePtr &node) { + MatchResultPtr res = std::make_shared(); + res->add_entry(shared_from_base(), node); + return res; +} + +AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) { + auto entry = match_result_.find(pattern); + if (entry == match_result_.end()) { + return nullptr; + } + return entry->second; +} + +void MatchResult::merge(const MatchResultPtr &other_result) { + auto other_result_map = other_result->_result(); + // add/update entries in other_result + for (auto &iter : other_result_map) { + match_result_[iter.first] = iter.second; + } +} + +REGISTER_PYBIND_DEFINE( + Pattern, ([](const py::module *m) { + (void)py::class_>(*m, "Pattern").def(py::init<>()); + (void)py::class_, Pattern>(*m, "IsIn_").def(py::init>()); + (void)py::class_, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr()) + .def(py::init, string, bool>()) + .def(py::init, string, bool>()); + (void)py::class_, Pattern>(*m, "CallWith_") + .def(py::init, bool>()) + .def(py::init, bool>()) + .def(py::init, bool>()); + (void)py::class_, Pattern>(*m, "IsNot_").def(py::init>()); + (void)py::class_, Pattern>(*m, "AnyPattern").def(py::init<>()); + (void)py::class_, Pattern>(*m, "NewTensor_") + .def(py::init()); + })); +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.h b/mindspore/ccsrc/frontend/optimizer/pattern.h new file mode 100644 index 0000000000..8d567e5ab2 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pattern.h @@ -0,0 +1,228 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "utils/primitive_py.h" +#include "utils/tensor_py.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +using std::string; +using std::vector; + +class MatchResult; +using MatchResultPtr = std::shared_ptr; +class Pattern; +using PatternPtr = std::shared_ptr; +class IsPrimTypeOf; +using IsPrimTypeOfPtr = std::shared_ptr; +class CallWith; +using CallWithPtr = std::shared_ptr; +class NewTensor; +using NewTensorPtr = std::shared_ptr; +struct PatternHasher; +struct PatternEqual; +using PatternNodeMap = std::unordered_map; + +class Pattern : public Base { + public: + Pattern() : unique_name_(std::to_string(g_id_++)) {} + virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; } + virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; } + string unique_name() const { return unique_name_; } + vector inputs() { return inputs_; } + bool should_replace() { return should_replace_; } + virtual void reset() {} + + protected: + static int g_id_; + // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed + string unique_name_; + vector inputs_; + bool should_replace_ = true; +}; + +struct PatternEqual { + bool operator()(PatternPtr const &p1, PatternPtr const &p2) const { + MS_EXCEPTION_IF_NULL(p1); + MS_EXCEPTION_IF_NULL(p2); + return p1->unique_name() == p2->unique_name(); + } +}; + +struct PatternHasher { + std::size_t operator()(PatternPtr const &p) const { + MS_EXCEPTION_IF_NULL(p); + return std::hash()(p->unique_name()); + } +}; + +class IsPrimTypeOf : public Pattern { + public: + IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); } + IsPrimTypeOf(vector prims, string name, bool should_replace) + : primitives_(prims), name_(name), matched_prim_(nullptr) { + unique_name_ = std::to_string(g_id_++) + "_" + name; + should_replace_ = should_replace; + if (!should_replace) { + matched_prim_ = prims[0]; + } + } + IsPrimTypeOf(vector types, string name, bool should_replace) : types_(types), name_(name) { + unique_name_ = std::to_string(g_id_++) + "_" + name; + // Make primitives_ + for (auto &iter : types) { + primitives_.push_back(std::make_shared(iter, py::cast(nullptr))); + } + should_replace_ = should_replace; + if (!should_replace) { + matched_prim_ = primitives_[0]; + } + } + MS_DECLARE_PARENT(IsPrimTypeOf, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + PrimitivePyPtr matched_primitive() { return matched_prim_; } + void reset() override { + if (should_replace_) { + matched_prim_ = nullptr; + } + } + + private: + vector types_; + vector primitives_; + string name_; + PrimitivePyPtr matched_prim_; +}; + +class CallWith : public Pattern { + public: + CallWith() { unique_name_ = std::to_string(g_id_++); } + CallWith(PatternPtr prim_pattern, vector inputs, bool should_replace) { + // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting + prim_pattern_ = prim_pattern; + unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name(); + inputs_ = inputs; + should_replace_ = should_replace; + } + CallWith(PrimitivePyPtr prim, vector inputs, bool should_replace) { + prim_ = prim; + unique_name_ = std::to_string(g_id_++) + prim_->ToString(); + inputs_ = inputs; + should_replace_ = should_replace; + } + CallWith(string prim_str, vector inputs, bool should_replace) { + prim_ = std::make_shared(prim_str, py::cast(nullptr)); + unique_name_ = std::to_string(g_id_++) + prim_->ToString(); + inputs_ = inputs; + should_replace_ = should_replace; + } + MS_DECLARE_PARENT(CallWith, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + PrimitivePtr prim_value() { return prim_; } + PatternPtr prim_pattern() { return prim_pattern_; } + + private: + PatternPtr prim_pattern_ = nullptr; + PrimitivePtr prim_ = nullptr; + vector types_; + string name_; +}; + +class IsIn : public Pattern { + public: + IsIn() { unique_name_ = std::to_string(g_id_++); } + explicit IsIn(vector patterns) : patterns_(patterns) { + unique_name_ = std::to_string(g_id_++); + for (auto &iter : patterns) { + unique_name_ = unique_name_ + "_" + iter->unique_name(); + } + } + MS_DECLARE_PARENT(IsIn, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + + private: + vector patterns_; +}; + +class IsNot : public Pattern { + public: + IsNot() { unique_name_ = std::to_string(g_id_++); } + explicit IsNot(vector patterns) : patterns_(patterns) { + unique_name_ = std::to_string(g_id_++); + for (auto &iter : patterns) { + unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name(); + } + } + MS_DECLARE_PARENT(IsNot, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + + private: + vector patterns_; +}; + +class AnyPattern : public Pattern { + public: + AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; } + MS_DECLARE_PARENT(AnyPattern, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; +}; + +class NewTensor : public Pattern { + public: + NewTensor() { unique_name_ = std::to_string(g_id_++); } + explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; } + MS_DECLARE_PARENT(NewTensor, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override { + MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n"; + } + tensor::TensorPtr input_tensor() { return input_tensor_; } + + private: + tensor::TensorPtr input_tensor_; +}; + +class MatchResult { + public: + MatchResult() {} + void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } + PatternNodeMap _result() { return match_result_; } + AnfNodePtr get_node(const PatternPtr &pattern); + void merge(const MatchResultPtr &other_result); + void clear() { match_result_.clear(); } + void dump() { + MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n"; + for (auto &iter : match_result_) { + MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n"; + } + } + + private: + PatternNodeMap match_result_; +}; +} // namespace python_pass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index c1bf40fcbb..362427d227 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -22,6 +22,7 @@ #include "ir/func_graph.h" #include "ir/manager.h" +#include "utils/primitive_py.h" #include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/resource.h" @@ -29,6 +30,8 @@ namespace mindspore { namespace opt { namespace python_pass { namespace internal { +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res); + std::string GetNodeRepr(AnfNodePtr node) { if (node != nullptr) { if (node->isa()) { @@ -50,125 +53,104 @@ std::string GetNodeRepr(AnfNodePtr node) { return ""; } -void ResolveFuncGraph_(const FuncGraphPtr &fg) { - auto manager = Manage(fg, false); - parse::python_adapter::set_use_signature_in_resolve(false); - parse::ResolveAll(manager); - parse::python_adapter::set_use_signature_in_resolve(true); -} - -bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { +bool IsTraversable(const AnfNodePtr &node) { if (node == nullptr) { return false; } - MS_EXCEPTION_IF_NULL(pattern); - if (pattern->isa()) { - if (!node->isa()) { - return false; - } - if (GetNodeRepr(pattern) == GetNodeRepr(node)) { - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); - return true; - } - return false; - } else if (pattern->isa()) { - MS_LOG(DEBUG) << pattern->ToString() + "\n"; - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); + if (node->isa() || node->isa()) { return true; - } else if (pattern->isa()) { - // match every single sub ANode - if (!node->isa()) { - return false; - } - auto pattern_inputs = pattern->cast()->inputs(); - auto node_inputs = node->cast()->inputs(); - if (pattern_inputs.size() != node_inputs.size()) { - return false; - } - for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); - p_item++, node_item++) { - auto res = Match(*p_item, *node_item, equiv_ptr); - if (!res) { - return false; - } - } + } + if (IsValueNode(node) || IsValueNode(node)) { return true; } - MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; + return false; } -AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, - const NodeEquivPtr &equiv_ptr) { - if (cur_raw_dst_node_->isa()) { - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; - } else if (cur_raw_dst_node_->isa()) { - // check primitive ValueNode - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast()->value()->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - return cur_raw_dst_node_; - } else if (cur_raw_dst_node_->isa()) { - std::vector new_inputs; - auto inputs = cur_raw_dst_node_->cast()->inputs(); - for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { - auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); - new_inputs.push_back(subed); - } - return func_graph->NewCNode(new_inputs); - } - MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); +AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) { + // Build up AnfNode from primitive + auto prim_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(prim_pattern); + PrimitivePyPtr prim = prim_pattern->matched_primitive(); + MS_EXCEPTION_IF_NULL(prim); + // Make value node out of primitives + return std::make_shared(prim); } -bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->isa() || node->isa()) { - return true; - } - if (IsValueNode(node) || IsValueNode(node)) { - return true; +AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) { + // Build a ValueNode from TensorPtr + auto new_tensor_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(new_tensor_pattern); + auto input_tensor = new_tensor_pattern->input_tensor(); + MS_EXCEPTION_IF_NULL(input_tensor); + return std::make_shared(input_tensor); +} + +AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) { + auto call_with_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(call_with_pattern); + auto prim = call_with_pattern->prim_value(); + if (prim != nullptr) { + return std::make_shared(prim); } - return false; + auto prim_pattern = call_with_pattern->prim_pattern(); + MS_EXCEPTION_IF_NULL(prim_pattern); + return ProcessSinglePattern(prim_pattern, res); } -} // namespace internal -void PythonPass::Build(const py::function &src, const py::function &dst) { - // 1. get FuncGraph from py::function - auto src_fg_ = parse::ParsePythonCode(src); - auto dst_fg_ = parse::ParsePythonCode(dst); - if (src_fg_ == nullptr || dst_fg_ == nullptr) { - MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) { + if (pattern->should_replace()) { + // Find replacement in the MatchResult + auto target_node = res->get_node(pattern); + if (target_node == nullptr) { + MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n"; + } + return target_node; } - // 2. Resolve - internal::ResolveFuncGraph_(src_fg_); - internal::ResolveFuncGraph_(dst_fg_); - // 3. from FuncGraphPtr to ValueNode - src_node_ = src_fg_->output(); - dst_node_ = dst_fg_->output(); + // Build up new node from pattern + if (pattern->isa()) { + return BuildPrimitive(pattern, res); + } else if (pattern->isa()) { + return BuildNewTensor(pattern, res); + } else if (pattern->isa()) { + return BuildPrimitiveValueNode(pattern, res); + } + return nullptr; } -PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, - bool multigraph) - : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { - Build(src, dst); +AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { + auto target_inputs = pattern->inputs(); + if (target_inputs.size() == 0) { + return ProcessSinglePattern(pattern, res); + } + // Build up the AnfNode in a recursive manner + std::vector new_inputs; + auto prim_value_node = ProcessSinglePattern(pattern, res); + MS_EXCEPTION_IF_NULL(prim_value_node); + new_inputs.push_back(prim_value_node); + for (auto &iter : target_inputs) { + if (iter == pattern) { + MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() + + "\n"; + } + new_inputs.push_back(BuildTarget(iter, func_graph, res)); + } + return func_graph->NewCNode(new_inputs); } +} // namespace internal AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - auto equiv_ptr = std::make_shared(); - bool is_a_match = internal::Match(src_node_, node, equiv_ptr); - if (is_a_match) { - auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); + MS_EXCEPTION_IF_NULL(src_pattern_); + MS_EXCEPTION_IF_NULL(dst_pattern_); + auto res = src_pattern_->match(node); + if (res != nullptr) { + res->dump(); + MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name(); + auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); + dst_pattern_->reset(); MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; return new_node; } + src_pattern_->reset(); return nullptr; } @@ -187,14 +169,12 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) { while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); - - // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { + // Check whether this node has been matched. + if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) { continue; } node->seen_ = seen; - - // select nodes that this transform can be applied. + // Select nodes that this transform can be applied. AnfNodePtr new_node = Run(func_graph, node); bool change = (new_node != nullptr); if (new_node != nullptr && new_node != node) { @@ -205,17 +185,14 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) { if (run_only_once_) { return change; } - - // find success, and add them to todo list + // Find success, and add them to todo list if (IsValueNode(node)) { todo.push_back(GetValueNode(node)->output()); } - if (node->isa()) { auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } - auto &node_users = manager->node_users(); if (change && node_users.find(node) != node_users.end()) { for (auto &use : node_users[node]) { diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.h b/mindspore/ccsrc/frontend/optimizer/py_pass.h index b01bf7c942..022c16a686 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.h @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PASS_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PASS_H_ #include #include #include #include "ir/anf.h" +#include "frontend/optimizer/pattern.h" #include "pybind_api/api_register.h" #include "pybind_api/export_flags.h" @@ -33,17 +34,17 @@ using NodeEquivPtr = std::shared_ptr; class PythonPass { public: - explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst, - bool run_only_once = false, bool multigraph = true); + explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false, + bool multigraph = true) + : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {} ~PythonPass() = default; bool Run(const FuncGraphPtr &func_graph); std::string name() const { return name_; } AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); private: - void Build(const py::function &src, const py::function &dst); - AnfNodePtr src_node_ = nullptr; - AnfNodePtr dst_node_ = nullptr; + PatternPtr src_pattern_; + PatternPtr dst_pattern_; const std::string name_; bool run_only_once_; bool multigraph_ = true; @@ -53,4 +54,4 @@ using PythonPassPtr = std::shared_ptr; } // namespace python_pass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PASS_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc index 86d7067d1c..a269788dfe 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -49,7 +49,7 @@ PyPassManager::PyPassManager() { phase_to_group_[Phase::OPT] = std::make_shared(); } -void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, +void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, Phase phase, bool run_only_once, bool multigraph) { auto cur_pm = GetPassGroup(phase); MS_EXCEPTION_IF_NULL(cur_pm); diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h index 84868862a7..1bb619264e 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_ #include #include @@ -23,11 +23,12 @@ #include "ir/anf.h" #include "ir/func_graph.h" -#include "ir/primitive_py.h" -#include "utils/graph_utils.h" -#include "common/utils.h" +#include "utils/primitive_py.h" +#include "ir/graph_utils.h" +#include "utils/ms_utils.h" #include "pipeline/jit/parse/resolve.h" +#include "frontend/optimizer/pattern.h" #include "frontend/optimizer/py_pass.h" #include "frontend/optimizer/pass_group.h" @@ -51,7 +52,7 @@ class PyPassManager { // Access the only global instance static PyPassManagerPtr GetInstance(); virtual ~PyPassManager() = default; - void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, + void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); void Unregiste(const std::string &pass_name, Phase phase); PassGroupPtr GetPassGroup(Phase phase); @@ -63,4 +64,4 @@ class PyPassManager { } // namespace python_pass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/CMakeLists.txt b/mindspore/ccsrc/frontend/parallel/CMakeLists.txt index d2a099cf41..b66ca834b5 100644 --- a/mindspore/ccsrc/frontend/parallel/CMakeLists.txt +++ b/mindspore/ccsrc/frontend/parallel/CMakeLists.txt @@ -1,5 +1,12 @@ file(GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc" "ps/scheduler.cc" "ps/optimizer_info.cc" "ps/optimizer_info_builder.cc") + +if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) + list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/optimizer_info_builder.cc") + list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/optimizer_info.cc") + list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/scheduler.cc") + list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc") +endif() + if (ENABLE_DUMP_PROTO) list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") endif () diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc index 70ae5a7d20..927bb2e73b 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc @@ -50,7 +50,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + if (IsParallelCareNode(cnode) && cnode->has_user_data()) { (void)cnode_set.emplace(cnode); } else { auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); @@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi return cnode_dist; } + auto operator_info = cnode->user_data(); MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) - << " operator_info: " << (cnode->operator_info() != nullptr); + << " operator_info: " << (operator_info != nullptr); - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); + if (IsParallelCareNode(cnode) && (operator_info != nullptr)) { + auto cost = operator_info->GetForwardMemoryCostFromCNode(); MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; if (allreduce_graph_.NodeInGraph(cnode)) { diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h index 7383c477a6..412673e7be 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ #include #include @@ -76,4 +76,4 @@ class AllreduceFusion { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h index a47039f070..89081ed189 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ #include #include @@ -82,4 +82,4 @@ class AllreduceGraph { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc index 1c478887df..05b40e3bb9 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc @@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { } auto para_ptr = node_ptr->cast(); MS_EXCEPTION_IF_NULL(para_ptr); - auto layout_ptr = para_ptr->tensor_layout(); + auto layout_ptr = para_ptr->user_data(); if (layout_ptr == nullptr) { MS_LOG(ERROR) << "layout_ptr is nullptr!"; return FAILED; diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h index 6538381f27..6741461b24 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ #include #include @@ -63,4 +63,4 @@ class AllreduceNode { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h index 2612e71984..a569645142 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ #include "frontend/optimizer/optimizer.h" @@ -29,4 +29,4 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h index cc4508681b..a60cbc0428 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ #include #include @@ -79,6 +79,8 @@ class StrategyWithCost { public: StrategyWithCost(StrategyPtr strategy, std::vector inputs_, std::vector outputs_) : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} + StrategyWithCost(StrategyPtr strategy, CostPtrList c_list) + : strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {} StrategyWithCost(const StrategyWithCost &swc) = delete; StrategyWithCost(StrategyWithCost &&swc) @@ -99,6 +101,7 @@ enum DecisionType { EDGE_ELIMINATION, MERGE_ELIMINATION, CONTRACT_ELIMINATION, + SOURCE_ELIMINATION, TRIANGLE_ELIMINATION, STAR_ELIMINATION, FINAL_TYPE, @@ -199,6 +202,38 @@ struct ContractEliminationDecision : public Decision { MS_DECLARE_PARENT(ContractEliminationDecision, Decision); }; +/* 'SourceEliminationDecision' is for the source Elimination in DP algorithm: + * 1 1,5 + * / \ // \\ + * / \ // \\ + * / \ // \\ + * / \ // \\ + * 2 <- 5 -> 3 ==> 2 3 + * \ / \ / + * \ / \ / + * \ / \ / + * 4 4 + * + * In the original graph, '1' has two alive outgoing edges and no incoming edges. '5' has two alive outgoing edges and + * no incoming edges. '4' has two alive incoming edges and no outgoing edges. Source Elimination will merge '5' into + * '1' new edges are generated to replace the old ones incident to '1' and '5'. + * + */ +struct SourceEliminationDecision : public Decision { + SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c) + : op1_strategy_(std::move(op1_stra)), + op1_cost_(std::move(op1_c)), + op2_strategy_(std::move(op2_stra)), + op2_cost_(std::move(op2_c)) { + type_ = DecisionType::SOURCE_ELIMINATION; + } + StrategyPtr op1_strategy_; + CostPtr op1_cost_; + StrategyPtr op2_strategy_; + CostPtr op2_cost_; + MS_DECLARE_PARENT(SourceEliminationDecision, Decision); +}; + /* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm: * * u @@ -296,6 +331,7 @@ using OpEliminationDecisionPtr = std::shared_ptr; using EdgeEliminationDecisionPtr = std::shared_ptr; using MergeEliminationDecisionPtr = std::shared_ptr; using ContractEliminationDecisionPtr = std::shared_ptr; +using SourceEliminationDecisionPtr = std::shared_ptr; using TriangleEliminationDecisionPtr = std::shared_ptr; using StarEliminationDecisionPtr = std::shared_ptr; using FinalDecisionPtr = std::shared_ptr; @@ -308,4 +344,4 @@ void RefineForPracticalCost(const CostPtr &, bool is_redistribution); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc index 9408596111..b49ca05f3e 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc @@ -42,66 +42,76 @@ Status GetStrategy(const CostGraphPtr &graph) { auto elimi = std::make_shared(n_edge, l_edge, node, r_edge); eliminations.emplace_back(std::move(elimi)); } - auto edges = graph->CheckEdgeElimination(); - if ((!flag) && (!edges.empty())) { - // Applying the Edge Elimination - flag = true; - auto n_edge = graph->EliminationEdges(edges); - auto elimi = std::make_shared(n_edge, edges); - eliminations.emplace_back(std::move(elimi)); + if (!flag) { + auto edges = graph->CheckEdgeElimination(); + if (!edges.empty()) { + // Applying the Edge Elimination + flag = true; + auto n_edge = graph->EliminationEdges(edges); + auto elimi = std::make_shared(n_edge, edges); + eliminations.emplace_back(std::move(elimi)); + } } - auto merge_node = graph->CheckMergeElimination(); - if ((!flag) && (merge_node != nullptr)) { - // Applying the Merge Elimination - flag = true; - auto succ_edge = merge_node->GetAliveSuccEdges()[0]; - auto target_node = graph->EliminationMerge(merge_node); - auto elimi = std::make_shared(merge_node, succ_edge, target_node); - eliminations.emplace_back(std::move(elimi)); + if (!flag) { + auto merge_node = graph->CheckMergeElimination(); + if (merge_node != nullptr) { + // Applying the Merge Elimination + flag = true; + auto succ_edge = merge_node->GetAliveSuccEdges()[0]; + auto target_node = graph->EliminationMerge(merge_node); + auto elimi = std::make_shared(merge_node, succ_edge, target_node); + eliminations.emplace_back(std::move(elimi)); + } } - auto contracted_node = graph->CheckContractElimination(); - if ((!flag) && (contracted_node != nullptr)) { - // Applying the Contract Elimination - flag = true; - auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; - auto target_node = graph->EliminationContract(contracted_node); - auto elimi = std::make_shared(target_node, prev_edge, contracted_node); - eliminations.emplace_back(std::move(elimi)); + if (!flag) { + auto contracted_node = graph->CheckContractElimination(); + if ((contracted_node != nullptr)) { + // Applying the Contract Elimination + flag = true; + auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; + auto target_node = graph->EliminationContract(contracted_node); + auto elimi = std::make_shared(target_node, prev_edge, contracted_node); + eliminations.emplace_back(std::move(elimi)); + } } - auto triangle_pair = graph->CheckTriangleElimination(); - if ((!flag) && (triangle_pair.first != nullptr)) { - // Applying the Triangle Elimination - flag = true; - auto eliminated_node = triangle_pair.first; - auto l_r_edge = triangle_pair.second; + if (!flag) { + auto triangle_pair = graph->CheckTriangleElimination(); + if (triangle_pair.first != nullptr) { + // Applying the Triangle Elimination + flag = true; + auto eliminated_node = triangle_pair.first; + auto l_r_edge = triangle_pair.second; - auto left_node = l_r_edge->prev_operator(); - auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; - auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(left_edge); - if (left_edge->next_operator() != left_node) { - auto tmp = left_edge; - left_edge = right_edge; - right_edge = tmp; + auto left_node = l_r_edge->prev_operator(); + auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; + auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; + MS_EXCEPTION_IF_NULL(left_edge); + if (left_edge->next_operator() != left_node) { + auto tmp = left_edge; + left_edge = right_edge; + right_edge = tmp; + } + auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); + auto right_node = l_r_edge->next_operator(); + auto elimi = + std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); + eliminations.emplace_back(std::move(elimi)); } - auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); - auto right_node = l_r_edge->next_operator(); - auto elimi = - std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); - eliminations.emplace_back(std::move(elimi)); } - auto star_center = graph->CheckStarElimination(); - if ((!flag) && (star_center != nullptr)) { - // Applying the Star Elimination - flag = true; - auto succ_edges = graph->EliminationStar(star_center); - std::vector succ_nodes; - for (size_t i = 0; i < succ_edges.size(); ++i) { - MS_EXCEPTION_IF_NULL(succ_edges[i]); - succ_nodes.push_back(succ_edges[i]->next_operator()); + if (!flag) { + auto star_center = graph->CheckStarElimination(); + if (star_center != nullptr) { + // Applying the Star Elimination + flag = true; + auto succ_edges = graph->EliminationStar(star_center); + std::vector succ_nodes; + for (size_t i = 0; i < succ_edges.size(); ++i) { + MS_EXCEPTION_IF_NULL(succ_edges[i]); + succ_nodes.push_back(succ_edges[i]->next_operator()); + } + auto elimi = std::make_shared(star_center, succ_edges, succ_nodes); + eliminations.emplace_back(std::move(elimi)); } - auto elimi = std::make_shared(star_center, succ_edges, succ_nodes); - eliminations.emplace_back(std::move(elimi)); } } diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h index 812f375f0b..ec131e519f 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ #include #include @@ -42,7 +42,7 @@ namespace parallel { // the operators' strategies can be all determined. struct Elimination : public Base { - enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR }; + enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR }; Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} EdgePtr new_edge_; @@ -100,6 +100,26 @@ struct ContractElimination : public Elimination { MS_DECLARE_PARENT(ContractElimination, Elimination); }; +// Source Elimination +struct SourceElimination : public Elimination { + SourceElimination(OperatorInfoPtr p_source, std::vector p_succ_edges, std::vector p_new_succ_edges, + OperatorInfoPtr s_source, std::vector s_succ_edges, std::vector s_new_succ_edges) + : Elimination(nullptr, Elimination::EliminationType::SOURCE), + primary_source_(std::move(p_source)), + primary_succ_edges_(std::move(p_succ_edges)), + primary_new_succ_edges_(std::move(p_new_succ_edges)), + secondary_source_(std::move(s_source)), + secondary_succ_edges_(std::move(s_succ_edges)), + secondary_new_succ_edges_(std::move(s_new_succ_edges)) {} + OperatorInfoPtr primary_source_; + std::vector primary_succ_edges_; + std::vector primary_new_succ_edges_; + OperatorInfoPtr secondary_source_; + std::vector secondary_succ_edges_; + std::vector secondary_new_succ_edges_; + MS_DECLARE_PARENT(SourceElimination, Elimination); +}; + // Triangle Elimination struct TriangleElimination : public Elimination { TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, @@ -138,6 +158,7 @@ using OpEliminationPtr = std::shared_ptr; using EdgeEliminationPtr = std::shared_ptr; using MergeEliminationPtr = std::shared_ptr; using ContractEliminationPtr = std::shared_ptr; +using SourceEliminationPtr = std::shared_ptr; using TriangleEliminationPtr = std::shared_ptr; using StarEliminationPtr = std::shared_ptr; @@ -149,4 +170,4 @@ Status RecoverStrategy(std::vector eliminations); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc index e3f1de7207..59be491852 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc @@ -320,5 +320,17 @@ Status Edge::CalculateMemoryCostForInference() { } return SUCCESS; } + +void Edge::SetCostMapAndInputOutput(std::map &cost_map) { + cost_map_ = cost_map; + pre_op_output_.clear(); + next_op_input_.clear(); + + for (auto &key_value : cost_map_) { + auto &key_pair = key_value.first; + pre_op_output_.emplace_back(std::pair>(key_pair.first, {})); + next_op_input_.emplace_back(std::pair>(key_pair.second, {})); + } +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h index 3fffd1b86d..9a09021380 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h @@ -22,7 +22,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/auto_parallel/costmodel.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/tensor_layout/tensor_info.h" @@ -80,6 +80,8 @@ class Edge { std::string edge_name() const { return edge_name_; } // Init cost_map_: for each output layout and input layout, calculate the cost Status InitEdgeCost(); + std::map GetCostMap() { return cost_map_; } + void SetCostMapAndInputOutput(std::map &); // For two operators u--->v, given the output tensor layout of u, // and the input tensor layout of v, return the redistribution cost, // and the op_list to carry out the redistribution. diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc index 1c1fc3a700..5313062e9c 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -794,6 +794,191 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const { return nullptr; } +std::pair CostGraph::CheckSourceElimination() const { + size_t source_count = 0; + std::vector op_vector(2, nullptr); + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0; + if (bool_test) { + op_vector[source_count++] = op; + if (source_count == 2) { + return std::make_pair(op_vector[0], op_vector[1]); + } + } + } + return std::make_pair(nullptr, nullptr); +} + +void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist, + StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist, + CostPtrList *op1_new_clist) { + for (auto &op1_cost : op1_old_clist) { + for (auto &op2_cost : op2_old_clist) { + double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_; + double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_; + double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_; + double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_; + double communication_without_para = + op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_; + auto decision = std::make_shared(op1_old_stra, op1_cost, op2_old_stra, op2_cost); + auto new_cost = std::make_shared(computation, communication, decision); + MS_EXCEPTION_IF_NULL(new_cost); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; + MS_EXCEPTION_IF_NULL(op1_new_clist); + op1_new_clist->emplace_back(std::move(new_cost)); + } + } +} + +std::pair>, std::vector>> CostGraph::EliminationSources( + OperatorInfoPtr op1, OperatorInfoPtr op2) { + MS_EXCEPTION_IF_NULL(op1); + MS_EXCEPTION_IF_NULL(op2); + MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name(); + + auto op1_old_succ_edges = op1->GetAliveSuccEdges(); + std::vector>>> op1_edges_reorganised_cost( + op1_old_succ_edges.size()); + std::vector> op1_new_edges_cost(op1_old_succ_edges.size()); + std::vector> op1_new_succ_edges(op1_old_succ_edges.size()); + + auto op2_old_succ_edges = op2->GetAliveSuccEdges(); + std::vector>>> op2_edges_reorganised_cost( + op2_old_succ_edges.size()); + std::vector> op2_new_edges_cost(op2_old_succ_edges.size()); + std::vector> op2_new_succ_edges(op2_old_succ_edges.size()); + + // Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost' + for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { + const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap(); + std::map>> from_tocost; + for (const auto &key_value : op1_cost_map) { + const auto &from_to_strategies = key_value.first; + const auto &costlist = key_value.second; + from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist)); + } + op1_edges_reorganised_cost[i] = from_tocost; + } + + for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { + const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap(); + std::map>> from_tocost; + for (const auto &key_value : op2_cost_map) { + const auto &from_to_strategies = key_value.first; + const auto &costlist = key_value.second; + from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist)); + } + op2_edges_reorganised_cost[i] = from_tocost; + } + + // Merge op2 into op1 + const auto &op1_old_stra_cost = op1->GetStrategyCost(); + const auto &op2_old_stra_cost = op2->GetStrategyCost(); + std::vector> op1_new_stra_cost; + + for (auto &op1_stra_cost : op1_old_stra_cost) { + auto op1_old_stra = op1_stra_cost->strategy_ptr; + auto op1_old_costlist = op1_stra_cost->cost_list; + + for (auto &op2_stra_cost : op2_old_stra_cost) { + auto op2_stra = op2_stra_cost->strategy_ptr; + auto op2_costlist = op2_stra_cost->cost_list; + + StrategyPtr op1_new_stra = std::make_shared(*op1_old_stra); + op1_new_stra->CoverStrategy(op2_stra); + CostPtrList op1_new_costlist; + // Calculate new cost for 'op1_new_costlist' + CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist); + std::shared_ptr swc = std::make_shared(op1_new_stra, op1_new_costlist); + op1_new_stra_cost.emplace_back(swc); + + // Set cost for new successive edges of op1 and op2 + for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { + auto &from_tocost = op1_edges_reorganised_cost[i]; + auto &to_cost = from_tocost[op1_old_stra]; + auto &new_cost_map = op1_new_edges_cost[i]; + for (auto &stra_costlit : to_cost) { + auto &to_strategy = stra_costlit.first; + auto &edge_costlist = stra_costlit.second; + CostPtrKey new_key = {op1_new_stra, to_strategy}; + new_cost_map[new_key] = edge_costlist; + } + } + for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { + auto &from_tocost = op2_edges_reorganised_cost[i]; + auto &to_cost = from_tocost[op2_stra]; + auto &new_cost_map = op2_new_edges_cost[i]; + for (auto &stra_costlist : to_cost) { + auto &to_strategy = stra_costlist.first; + auto &edge_costlist = stra_costlist.second; + CostPtrKey new_key = {op1_new_stra, to_strategy}; + new_cost_map[new_key] = edge_costlist; + } + } + } + } + op1->SetStrategyCost(op1_new_stra_cost); + op2->SetNotAlive(); + + // Update the edges incident to op1, and edges incident to op2 + for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { + auto &new_cost_map = op1_new_edges_cost[i]; + auto &ith_edge = op1_old_succ_edges[i]; + + std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name(); + std::shared_ptr new_edge; + if (ith_edge->is_combined()) { + std::vector output_indexs, input_indexs; + output_indexs = ith_edge->prev_op_output_indexs(); + input_indexs = ith_edge->next_op_input_indexs(); + new_edge = + std::make_shared(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true); + } else { + size_t output_index, input_index; + output_index = ith_edge->prev_op_output_index(); + input_index = ith_edge->next_op_input_index(); + new_edge = + std::make_shared(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); + } + new_edge->SetCostMapAndInputOutput(new_cost_map); + // replace the old successive edges with the new ones. + op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); + ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); + op1_new_succ_edges[i] = new_edge; + } + for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { + auto &new_cost_map = op2_new_edges_cost[i]; + auto &ith_edge = op2_old_succ_edges[i]; + const auto &destination = ith_edge->next_operator(); + + std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); + std::shared_ptr new_edge; + if (ith_edge->is_combined()) { + std::vector output_indexs, input_indexs; + output_indexs = ith_edge->prev_op_output_indexs(); + input_indexs = ith_edge->next_op_input_indexs(); + new_edge = std::make_shared(new_edge_name, op1, destination, output_indexs, input_indexs, true); + } else { + size_t output_index, input_index; + output_index = ith_edge->prev_op_output_index(); + input_index = ith_edge->next_op_input_index(); + new_edge = std::make_shared(new_edge_name, op1, destination, output_index, input_index, false); + } + new_edge->SetCostMapAndInputOutput(new_cost_map); + // replace the old successive edges with the new ones. + destination->ReplacePreEdge(op2, new_edge); + op1->AddSuccEdge(new_edge); + op2_new_succ_edges[i] = new_edge; + } + MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded."; + return {op1_new_succ_edges, op2_new_succ_edges}; +} + // Check the graph whether a TriangleElimination can be performed std::pair> CostGraph::CheckTriangleElimination() const { for (auto &op : ops_) { @@ -1395,7 +1580,7 @@ Status CostGraph::InitSelectedStrategy() { if (stra.empty()) { MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; } - std::vector stra_inputs = {stra}; + Strategys stra_inputs = {stra}; StrategyPtr reshape_stra = std::make_shared((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); reshape_info->set_strategy(reshape_stra); diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h index 87f13e3383..116d188c0e 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ #include #include @@ -23,7 +23,7 @@ #include #include #include "mindspore/ccsrc/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/auto_parallel/edge_costmodel.h" #include "frontend/parallel/costmodel_context.h" #include "frontend/parallel/ops_info/operator_info.h" @@ -31,24 +31,6 @@ namespace mindspore { namespace parallel { -#define OPERATOR_TO_OPERATOR_CONNECTOR "-" -#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) -#define DEFAULT_COST_MODEL_ALPHA 1.0 -#define DEFAULT_COST_MODEL_BETA 400.0 -#define DEFAULT_COST_MODEL_GAMMA 0.001 -#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true -#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 -#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0 -#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 -#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false -#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 -#define DEFAULT_FULLY_USE_DEVICES true -#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false -#define DEFAULT_IS_MULTI_SUBGRAPHS false -#define DEFAULT_RUN_PHASE 0 -#define TRAINING_PHASE 0 -#define INFERENCE_PHASE 1 - class CostGraph; using CostGraphPtr = std::shared_ptr; extern CostGraphPtr entire_costgraph; @@ -73,7 +55,7 @@ class CostGraph { CostGraph() { dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA; + costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; } ~CostGraph() = default; void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } @@ -179,6 +161,14 @@ class CostGraph { void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, const StrategyPtr &, const CostPtrList &, std::vector, CostPtrList &, CostPtrList &, CostPtrList *); + // Return . we merge 'op2' into 'op1' + std::pair CheckSourceElimination() const; + void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &, + CostPtrList *); + // We merge 'op2' into op1. The returned value are ''. 'Edges1' are newly updated edges for 'op1', + // 'Edges2' are newly updated edges for 'op2'. + std::pair>, std::vector>> EliminationSources( + OperatorInfoPtr op1, OperatorInfoPtr op2); // Calculate memory cost for training phase or inference phase. Status CalculateMemoryCost(); // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then @@ -235,4 +225,4 @@ class CostGraph { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc index aaf3fdff3c..63524ec3fe 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -198,8 +198,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector &inputs // this operator uses double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { - TensorInfo input0_info = inputs[0]; - Shape input0_slice_shape = input0_info.slice_shape(); + TensorInfo input0 = inputs[0]; + Shape input0_slice_shape = input0.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); } @@ -240,12 +240,16 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, c // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, +double SoftmaxCost::GetForwardComputationCost(const std::vector &, const std::vector &outputs, int32_t) const { - // In the forward phase, the computation cost = slice(A) - TensorInfo input0 = inputs[0]; - Shape input0_slice_shape = input0.slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); + if (outputs.empty() || outputs_type_lengths_.empty()) { + MS_LOG(EXCEPTION) << "The outputs or outputs_type_length is empty"; + } + + // use output for Tile operator + TensorInfo output_info = outputs[0]; + Shape output_slice_shape = output_info.slice_shape(); + return ListProduct(output_slice_shape) * static_cast(outputs_type_lengths_[0]); } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index dda597bd1f..8cd1370fb5 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -170,6 +170,8 @@ class ActivationCost : public OperatorCost { using ActivationCostPtr = std::shared_ptr; using TransposeCost = ActivationCost; using TransposeCostPtr = std::shared_ptr; +using StridedSliceCost = ActivationCost; +using StridedSliceCostPtr = std::shared_ptr; class SoftmaxCost : public OperatorCost { public: @@ -195,6 +197,8 @@ class SoftmaxCost : public OperatorCost { int32_t) const override; }; using SoftmaxCostPtr = std::shared_ptr; +using TileCost = SoftmaxCost; +using TileCostPtr = std::shared_ptr; class TmpIdentityCost : public OperatorCost { public: diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 68b776155a..e51bd579f3 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -45,10 +45,9 @@ void GenerateStrategy(const std::shared_ptr &graph, const std::vector> PrepareMatMul(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies; +Strategys PrepareMatMul(const std::shared_ptr &graph, const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + Strategys strategies; auto attrs = ops[iter_ops]->attrs(); bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); @@ -105,41 +104,40 @@ std::vector> PrepareMatMul(const std::shared_ptr &gr } for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - std::vector s; + Dimensions s; if (transpose_a && (iter_op_inputs == 0)) { s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); } else if (transpose_b && (iter_op_inputs == 1)) { s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); } else { s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); } strategies.push_back(s); } return strategies; } -std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { - std::vector> strategies; +Strategys PrepareBiasAdd(const std::shared_ptr &s) { + Strategys strategies; strategies.push_back(*s); - std::vector s_biasadd; + Dimensions s_biasadd; s_biasadd.push_back(s->at(1)); strategies.push_back(s_biasadd); return strategies; } -std::vector> PrepareOneHot(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); +Strategys PrepareOneHot(const std::shared_ptr &graph, const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); int32_t axis = -1; auto iter = ops[iter_ops]->attrs().find(AXIS); @@ -158,15 +156,14 @@ std::vector> PrepareOneHot(const std::shared_ptr &gr graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; } - std::vector s_empty = {}; + Dimensions s_empty = {}; strategies.push_back(s_empty); strategies.push_back(s_empty); return strategies; } -std::vector> PrepareGatherV2(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - std::vector> strategies; +Strategys PrepareGatherV2(const std::vector> &ops, const size_t iter_ops, Dimensions s) { + Strategys strategies; auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); if (axis_input < 0) { @@ -179,23 +176,104 @@ std::vector> PrepareGatherV2(const std::vectorname().find("Info"); - auto name = ops[iter_ops]->name().substr(0, pos); - if (name == "GatherV2") { - return strategies; + return strategies; +} + +Strategys PrepareGatherV2P(const std::vector> &ops, const size_t iter_ops, Dimensions s) { + Strategys strategies; + + auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape(); + Dimensions index(output_shape.size() - 1, 0); + for (size_t i = 0; i < index.size(); i++) { + index[i] = i; + } + std::sort(index.begin(), index.end(), + [&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); }); + std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; }); + index.insert(index.begin(), 0); + + Dimensions strategie(output_shape.size(), 1); + size_t num_device = g_device_manager->DeviceNum(); + size_t cut = 1; + for (size_t i = 0; i < index.size(); i++) { + while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) { + output_shape[index[i]] /= 2; + cut *= 2; + strategie[index[i]] *= 2; + } + if (cut == num_device) { + break; + } } - std::vector s_indices; - for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { - s_indices.push_back(1); + auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); + if (axis_input < 0) { + axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); + } + int32_t axis = axis_input; + if (axis >= SizeToInt(s.size())) { + MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; + } + if (axis == 0) { + s.clear(); + s.push_back(1); + for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) { + s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]); + } + strategies.push_back(s); + s.clear(); + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { + s.push_back(strategie[i]); + } + strategies.push_back(s); + } else if (axis == 1) { + s.clear(); + s.push_back(strategie[0]); + s.push_back(1); + strategies.push_back(s); + s.clear(); + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { + s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]); + } + strategies.push_back(s); + } else { + MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1."; } - strategies.push_back(s_indices); return strategies; } -std::vector> PrepareL2Normalize(const std::vector> &ops, - const size_t iter_ops, std::vector s) { +Dimensions PrepareGatherV2POutputStrategy(const std::vector> &ops, + const size_t incoming_op_index) { + auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape(); + Dimensions index(output_shape.size() - 1, 0); + for (size_t i = 0; i < index.size(); i++) { + index[i] = i; + } + std::sort(index.begin(), index.end(), + [&output_shape](const int &a, const int &b) { return (output_shape[a + 1] > output_shape[b + 1]); }); + std::transform(std::begin(index), std::end(index), std::begin(index), [](int x) { return x + 1; }); + index.insert(index.begin(), 0); + + Dimensions strategie(output_shape.size(), 1); + size_t num_device = g_device_manager->DeviceNum(); + size_t cut = 1; + for (size_t i = 0; i < index.size(); i++) { + while (output_shape[index[i]] % 2 == 0 && output_shape[index[i]] > 0 && cut < num_device) { + output_shape[index[i]] /= 2; + cut *= 2; + strategie[index[i]] *= 2; + } + if (cut == num_device) { + break; + } + } + + return strategie; +} + +Strategys PrepareL2Normalize(const std::vector> &ops, const size_t iter_ops, + Dimensions s) { int32_t axis = 0; auto iter = ops[iter_ops]->attrs().find(AXIS); if (iter != ops[iter_ops]->attrs().end()) { @@ -215,14 +293,14 @@ std::vector> PrepareL2Normalize(const std::vector> strategies; + Strategys strategies; strategies.push_back(s); return strategies; } -std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { +Strategys MakeRecSearchStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_graph, + const size_t iter_ops) { if (ops.empty()) { MS_LOG(EXCEPTION) << "Failure: Operators is empty."; } @@ -231,31 +309,31 @@ std::vector> MakeRecSearchStrategy(const std::shared_ptrstrategy(); - std::vector> strategies; + Strategys strategies; for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; } size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); - std::vector s; + Dimensions s; if (output_size == 4) { s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); } else if (output_size == 2) { s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); } else if (output_size == 1) { s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); } else if (output_size == 0) { s = {}; } else { @@ -266,9 +344,9 @@ std::vector> MakeRecSearchStrategy(const std::shared_ptr> MakeDataParallelStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { +Strategys MakeDataParallelStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_graph, + const size_t iter_ops) { if (ops.empty()) { MS_LOG(EXCEPTION) << "Failure: Operators is empty."; } @@ -277,7 +355,7 @@ std::vector> MakeDataParallelStrategy(const std::shared_ptr } StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - std::vector> strategies; + Strategys strategies; size_t max_device_num = g_device_manager->DeviceNum(); size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { @@ -285,7 +363,7 @@ std::vector> MakeDataParallelStrategy(const std::shared_ptr MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; } - std::vector s; + Dimensions s; size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); for (size_t dim = 0; dim < input_size; dim++) { if (input_size == 1 || input_size == 2 || input_size == 4) { @@ -318,9 +396,8 @@ std::vector> MakeDataParallelStrategy(const std::shared_ptr return strategies; } -std::vector> PrepareStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { +Strategys PrepareStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { if (ops.empty()) { MS_LOG(EXCEPTION) << "Failure: Operators is empty."; } @@ -348,7 +425,7 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, const std::vector> &ops, const std::shared_ptr> &index_list) { for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { - std::vector> strategies; + Strategys strategies; size_t iter_graph = index_list->at(iter_ops); if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); @@ -375,10 +452,10 @@ size_t FindIndexOfOperatorIncoming(const std::vector> & return incoming_op_index; } -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_ops, const size_t iter_graph) { - std::vector s; +Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_ops, const size_t iter_graph) { + Dimensions s; for (auto input : ops[iter_ops]->inputs_tensor_info()) { auto input_stra_dim = input.shape().size(); if (input_stra_dim == 0) { @@ -402,13 +479,23 @@ std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr PrepareIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t incoming_op_index) { - std::vector s; - if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || - ops[incoming_op_index]->type() == TRANSPOSE) { +Dimensions PrepareIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t incoming_op_index) { + Dimensions s; + if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) { return s; } + if (ops[incoming_op_index]->type() == GATHERV2) { + auto pos = ops[incoming_op_index]->name().find("Info"); + auto name = ops[incoming_op_index]->name().substr(0, pos); + if (name == "GatherV2") { + return s; + } else if (name == "GatherV2P") { + return PrepareGatherV2POutputStrategy(ops, incoming_op_index); + } else { + MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl; + } + } auto strategy = ops[incoming_op_index]->selected_strategy(); if (strategy->GetInputNumber() == 0) { return s; @@ -426,8 +513,8 @@ std::vector PrepareIncomingOperatorInputStrategy(const std::vector GetAxisList(const std::vector> &ops, const int iter_ops) { - std::vector axis_list; +Dimensions GetAxisList(const std::vector> &ops, const int iter_ops) { + Dimensions axis_list; auto axis_param = ops[iter_ops]->attrs().find(AXIS)->second; std::vector elements; if (axis_param->isa()) { @@ -448,10 +535,10 @@ std::vector GetAxisList(const std::vector return axis_list; } -std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - std::vector s_Squeeze; - std::vector stra_dim_list; +Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, + const size_t incoming_op_index, Dimensions s) { + Dimensions s_Squeeze; + Dimensions stra_dim_list; for (size_t i = 0; i < s.size(); i++) { stra_dim_list.push_back(i); } @@ -488,8 +575,8 @@ bool GetKeepDims(const std::vector> &ops, const si return keepdims; } -std::vector GetDimList(const std::vector> &ops, const size_t iter_ops) { - std::vector dim_list; +Dimensions GetDimList(const std::vector> &ops, const size_t iter_ops) { + Dimensions dim_list; bool keep_dims = GetKeepDims(ops, iter_ops); if (keep_dims != false) { return dim_list; @@ -499,10 +586,13 @@ std::vector GetDimList(const std::vector> if (input_value.back()->isa()) { auto attr_axis = GetValue>(input_value.back()); if (attr_axis.empty()) { - MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; - } - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + for (size_t i = 0; i < input_dim; i++) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } } } else if (input_value.back()->isa()) { int axis = GetValue(input_value.back()); @@ -513,10 +603,10 @@ std::vector GetDimList(const std::vector> return dim_list; } -std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - std::vector s_Reduce; - std::vector axis_list; +Dimensions ModifyStrategyIfReduceIncoming(const std::vector> &ops, + const size_t incoming_op_index, Dimensions s) { + Dimensions s_Reduce; + Dimensions axis_list; for (size_t i = 0; i < s.size(); i++) { axis_list.push_back(i); } @@ -536,8 +626,8 @@ std::vector ModifyStrategyIfReduceIncoming(const std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { - std::vector dim_list; +Dimensions GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { + Dimensions dim_list; auto iter = ops[iter_ops]->attrs().find(AXIS); if (iter == ops[iter_ops]->attrs().end()) { MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; @@ -564,15 +654,15 @@ std::vector GetDimListFromAttrs(const std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { +Dimensions ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, Dimensions s) { bool keepdims = GetKeepDims(ops, incoming_op_index); if (keepdims) { return s; } - std::vector s_Arg; - std::vector axis_list; + Dimensions s_Arg; + Dimensions axis_list; for (size_t i = 0; i < s.size(); i++) { axis_list.push_back(i); } @@ -592,9 +682,9 @@ std::vector ModifyStrategyIfArgIncoming(const std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t iter_ops, const size_t incoming_op_index) { - std::vector s; +Dimensions CopyIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t incoming_op_index) { + Dimensions s; s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index); if (s.size() != 0) { if (ops[incoming_op_index]->type() == SQUEEZE) { @@ -611,11 +701,9 @@ std::vector CopyIncomingOperatorInputStrategy(const std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, - const size_t iter_ops, - std::vector basic_stra) { - std::vector s_empty = {}; - std::vector> stra; +Strategys GenerateStrategiesFromStrategy(const std::vector> &ops, const size_t iter_ops, + Dimensions basic_stra) { + Strategys stra; MS_EXCEPTION_IF_NULL(ops[iter_ops]); if (basic_stra.size() == 0) { @@ -626,26 +714,124 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect return stra; } - auto s_ptr = std::make_shared>(basic_stra); + auto s_ptr = std::make_shared(basic_stra); if (ops[iter_ops]->type() == BIAS_ADD) { return PrepareBiasAdd(s_ptr); } if (ops[iter_ops]->type() == GATHERV2) { - return PrepareGatherV2(ops, iter_ops, basic_stra); + auto pos = ops[iter_ops]->name().find("Info"); + auto name = ops[iter_ops]->name().substr(0, pos); + if (name == "GatherV2") { + return PrepareGatherV2(ops, iter_ops, basic_stra); + } else if (name == "GatherV2P") { + return PrepareGatherV2P(ops, iter_ops, basic_stra); + } else { + MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl; + } } if (ops[iter_ops]->type() == L2_NORMALIZE) { return PrepareL2Normalize(ops, iter_ops, basic_stra); } + if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL || + ops[iter_ops]->type() == DIV) { + return CheckBroadcast(ops, iter_ops, basic_stra); + } + + return CheckDivisible(ops, iter_ops, basic_stra); +} + +// Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc. +Strategys CheckBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s) { + Strategys stra; + + size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size(); + + // Do Broadcasting in the second tensor. + if (second_tensor_dim < first_tensor_dim) { + bool braoadcast_first_tensor = false; + // Push back the first tensor's strategy. + stra.push_back(s); + // Push back the second tensor's strategy after applying broadcast. + stra.push_back(ApplyBroadcast(ops, iter_ops, s, second_tensor_dim, first_tensor_dim, braoadcast_first_tensor)); + } else if (second_tensor_dim > first_tensor_dim) { // Do Broadcasting in the first tensor. + bool braoadcast_first_tensor = true; + // Push back the first tensor's strategy after applying broadcast. + stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, braoadcast_first_tensor)); + // Push back the second tensor's strategy. + stra.push_back(s); + } else { // Broadcasting can be ignored or No broadcasting needs to be applied. + stra = CheckDivisible(ops, iter_ops, s); + } + return stra; +} + +Dimensions ApplyBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s, + size_t target_tensor_dim, size_t refer_tensor_dim, bool braoadcast_first_tensor) { + Dimensions s_empty = {}; + Dimensions s_broadcast; + int target_tensor_index = 0; + int refer_tensor_index = 0; + + // Indexing target and refer tensor. + if (braoadcast_first_tensor) { + target_tensor_index = 0; + refer_tensor_index = 1; + } else { + target_tensor_index = 1; + refer_tensor_index = 0; + } + + // When target tensor with an empty dim. + if (target_tensor_dim == 0) { + return s_empty; + } else if (target_tensor_dim == 1) { // When target tensor with a single dim. + bool broadcast_dim_found = false; + for (size_t iter = 0; iter < refer_tensor_dim; iter++) { + // Find and copy that dim's strategy from the refer tensor. + if ((ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] == + ops[iter_ops]->inputs_tensor_info()[target_tensor_index].shape()[0]) && + (ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] > 1) && + (refer_tensor_dim == s.size())) { + s_broadcast.push_back(s.at(iter)); + broadcast_dim_found = true; + break; + } + } + // Cannot decide which dim it is, push back one. + if (broadcast_dim_found == false) { + s_broadcast.push_back(1); + } + } else { + // Cannot decide which dim needs to do broadcast, push back one(s). + for (size_t iter = 0; iter < target_tensor_dim; iter++) { + s_broadcast.push_back(1); + } + } + + return s_broadcast; +} + +// Check whether the operator can be divided by the current strategy. +Strategys CheckDivisible(const std::vector> &ops, const size_t iter_ops, + Dimensions basic_stra) { + Dimensions s_empty = {}; + Strategys stra; + + // For all the input tensors. for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + // If input tensor is empty, return strategy as void. if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) { stra.push_back(s_empty); continue; } - std::vector tmp_stra = basic_stra; + Dimensions tmp_stra = basic_stra; bool modified = false; + + // Make sure each tensor's dim shape is greater than 1. If not, push back strategy as 1 instead. for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) { if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) { tmp_stra[j] = 1; @@ -658,6 +844,7 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect stra.push_back(basic_stra); } } + return stra; } @@ -673,8 +860,8 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &gra for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { size_t iter_ops = no_stra_op_list->at(iter_list - 1); - std::vector> stra; - std::vector s; + Strategys stra; + Dimensions s; size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); if (incoming_op_index != SIZE_MAX) { auto iter_graph = index_list->at(incoming_op_index); @@ -701,9 +888,9 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &gra } } -std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - std::vector s_Squeeze; +Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, const size_t iter_ops, + Dimensions s) { + Dimensions s_Squeeze; auto axis_list = GetAxisList(ops, iter_ops); size_t s_index = 0; size_t axis_list_index = 0; @@ -728,10 +915,10 @@ std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, - const std::vector> &input_tensor_names, - const size_t iter_ops) { - std::vector s; +Dimensions CopyOutgoingOperatorInputStrategy(const std::vector> &ops, + const std::vector> &input_tensor_names, + const size_t iter_ops) { + Dimensions s; if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || @@ -775,8 +962,8 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vectorsize(); iter_list > 0; iter_list--) { auto iter_ops = no_stra_op_list->at(iter_list - 1); - std::vector> stra; - std::vector s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); + Strategys stra; + Dimensions s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) { s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); @@ -815,8 +1002,8 @@ void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) { auto iter_ops = no_stra_op_list->at(iter_list); - std::vector> stra; - std::vector s; + Strategys stra; + Dimensions s; size_t max_dim_num = 0; for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h index 9acd05e0a9..2263deb588 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -31,61 +31,63 @@ void GenerateStrategy(const std::shared_ptr &graph, const std::vector>> &eli_list, const std::vector> &input_tensor_names, const std::shared_ptr> &index_list); -std::vector> PrepareMatMul(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareBiasAdd(const std::shared_ptr> &s); -std::vector> PrepareOneHot(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareGatherV2(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector> PrepareL2Normalize(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); +Strategys PrepareMatMul(const std::shared_ptr &graph, const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +Strategys PrepareBiasAdd(const std::shared_ptr &s); +Strategys PrepareOneHot(const std::shared_ptr &graph, const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +Strategys PrepareGatherV2(const std::vector> &ops, const size_t iter_ops, Dimensions s); +Strategys PrepareGatherV2P(const std::vector> &ops, const size_t iter_ops, Dimensions s); +Dimensions PrepareGatherV2POutputStrategy(const std::vector> &ops, + const size_t incoming_op_index); +Strategys PrepareL2Normalize(const std::vector> &ops, const size_t iter_ops, + Dimensions s); +Strategys MakeRecSearchStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_graph, + const size_t iter_ops); +Strategys CheckBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s); +Dimensions ApplyBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s, + size_t target_tensor_dim, size_t refer_tensor_dim, bool braoadcast_first_tensor); +Strategys CheckDivisible(const std::vector> &ops, const size_t iter_ops, Dimensions s); +Strategys MakeDataParallelStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_graph, + const size_t iter_ops); +Strategys PrepareStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, const std::vector> &ops, const std::shared_ptr> &index_list); size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, const size_t iter_ops); -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_ops, const size_t iter_graph); -std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t incoming_op_index); -std::vector GetAxisList(const std::vector> &ops, const int iter_ops); -std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); +Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_ops, const size_t iter_graph); +Dimensions PrepareIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t incoming_op_index); +Dimensions GetAxisList(const std::vector> &ops, const int iter_ops); +Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, + const size_t incoming_op_index, Dimensions s); bool GetKeepDims(const std::vector> &ops, const size_t iter_ops); -std::vector GetDimList(const std::vector> &ops, const size_t iter_ops); -std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); -std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t iter_ops, const size_t incoming_op_index); -std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, - const size_t iter_ops, - std::vector basic_stra); +Dimensions GetDimList(const std::vector> &ops, const size_t iter_ops); +Dimensions ModifyStrategyIfReduceIncoming(const std::vector> &ops, + const size_t incoming_op_index, Dimensions s); +Dimensions GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); +Dimensions ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, Dimensions s); +Dimensions CopyIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t incoming_op_index); +Strategys GenerateStrategiesFromStrategy(const std::vector> &ops, const size_t iter_ops, + Dimensions basic_stra); void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, const std::vector> &ops, const std::vector> &input_tensor_names, const std::shared_ptr> &index_list, const std::shared_ptr> &no_stra_op_list); -std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, - const std::vector> &input_tensor_names, - const size_t iter_ops); +Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, const size_t iter_ops, + Dimensions s); +Dimensions CopyOutgoingOperatorInputStrategy(const std::vector> &ops, + const std::vector> &input_tensor_names, + const size_t iter_ops); void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, const std::vector> &input_tensor_names, const std::shared_ptr> &no_stra_op_list); diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 7164660be0..24c6823a1a 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -24,12 +24,12 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/device_manager.h" namespace mindspore { namespace parallel { -static std::map> param_shapes; +static std::map param_shapes; std::vector PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL}; @@ -44,7 +44,10 @@ std::shared_ptr ParallelContext::GetInstance() { return inst_context_; } -ParallelContext::ParallelContext() { Reset(); } +ParallelContext::ParallelContext() { + communication_backend_ = HCCL_BACKEND; + Reset(); +} void ParallelContext::Reset() { mirror_mean_ = false; @@ -53,7 +56,6 @@ void ParallelContext::Reset() { loss_repeated_mean_ = true; device_num_ = 1; global_rank_ = 0; - communication_backend_ = HCCL_BACKEND; device_num_is_set_ = false; global_rank_is_set_ = false; parallel_mode_ = STAND_ALONE; @@ -63,6 +65,8 @@ void ParallelContext::Reset() { strategy_ckpt_load_file_ = ""; strategy_ckpt_save_file_ = ""; enable_parallel_optimizer_ = false; + all_reduce_fusion_split_indices_.clear(); + all_reduce_fusion_split_sizes_.clear(); } void ParallelContext::set_device_num(int32_t device_num) { @@ -169,7 +173,7 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); return; } - std::vector shape = iter->second; + Shape shape = iter->second; std::shared_ptr base_shape = std::make_shared(shape); ptr->set_shape(base_shape); MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; @@ -185,7 +189,10 @@ void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, cons return; } - std::vector shape = dyn_cast(ptr->GetShapeTrack())->shape(); + std::vector shape_int = dyn_cast(ptr->GetShapeTrack())->shape(); + Shape shape; + (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape), + [](const int &value) { return static_cast(value); }); auto ret = param_shapes.try_emplace(param_node->name(), shape); if (!ret.second) { MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index 1bb40d5c29..3436372641 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ -#define MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ #include #include @@ -28,7 +28,7 @@ #include "utils/convert_utils.h" #include "ir/anf.h" #include "ir/func_graph.h" -#include "debug/info.h" +#include "utils/info.h" #include "abstract/abstract_value.h" namespace mindspore { @@ -139,4 +139,4 @@ void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, cons } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc index 67d087eabd..536895c8de 100644 --- a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc @@ -19,7 +19,7 @@ #include #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" -#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "utils/ms_context.h" namespace mindspore { namespace parallel { @@ -41,7 +41,7 @@ CostModelContext::CostModelContext() { void CostModelContext::ResetCostModel() { device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY; costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA; + costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA; costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; @@ -66,6 +66,12 @@ void CostModelContext::ResetAlgoParameters() { elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; } +void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { + if (device_target == kGPUDevice) { + costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; + } +} + void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } @@ -128,5 +134,14 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { } void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } + +struct CostRegister { + CostRegister() { + MsContext::device_seter([](const std::string &device_target) { + CostModelContext::GetInstance()->set_costmodel_context_for_device(device_target); + }); + } + ~CostRegister() = default; +} cost_regsiter; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h index bddab683ff..b1668e13ef 100644 --- a/mindspore/ccsrc/frontend/parallel/costmodel_context.h +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.h @@ -14,17 +14,36 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_COSTMODEL_CONTEXT_H_ -#define MINDSPORE_CCSRC_PARALLEL_COSTMODEL_CONTEXT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_COSTMODEL_CONTEXT_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_COSTMODEL_CONTEXT_H_ #include #include #include #include "utils/log_adapter.h" +#include "utils/ms_context.h" namespace mindspore { namespace parallel { +#define OPERATOR_TO_OPERATOR_CONNECTOR "-" +#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) +#define DEFAULT_COST_MODEL_ALPHA 1.0 +#define DEFAULT_COST_MODEL_BETA_ASCEND 400.0 // for 'device_target = Ascend' +#define DEFAULT_COST_MODEL_BETA_GPU 50.0 // for 'device_target = GPU' +#define DEFAULT_COST_MODEL_GAMMA 0.001 +#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true +#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 +#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0 +#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 +#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false +#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 +#define DEFAULT_FULLY_USE_DEVICES true +#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false +#define DEFAULT_IS_MULTI_SUBGRAPHS false +#define DEFAULT_RUN_PHASE 0 +#define TRAINING_PHASE 0 +#define INFERENCE_PHASE 1 class CostModelContext { public: ~CostModelContext() = default; @@ -35,6 +54,7 @@ class CostModelContext { static std::shared_ptr GetInstance(); + void set_costmodel_context_for_device(const std::string &); // DEVICE_MEMORY_CAPACITY void set_device_memory_capacity(double); double device_memory_capacity() const { return device_memory_capacity_; } @@ -178,4 +198,4 @@ class CostModelContext { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_COSTMODEL_CONTEXT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_COSTMODEL_CONTEXT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/device.h b/mindspore/ccsrc/frontend/parallel/device.h index c9633623d2..c69438b23e 100644 --- a/mindspore/ccsrc/frontend/parallel/device.h +++ b/mindspore/ccsrc/frontend/parallel/device.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_H_ #include #include @@ -42,4 +42,4 @@ class Device { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_H_ diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.cc b/mindspore/ccsrc/frontend/parallel/device_manager.cc index d3657afdb8..3a6f6878c6 100644 --- a/mindspore/ccsrc/frontend/parallel/device_manager.cc +++ b/mindspore/ccsrc/frontend/parallel/device_manager.cc @@ -345,9 +345,6 @@ std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { // name. Otherwise, let the pointer g point to that group. Group DeviceManager::CreateGroup(const std::string &group_name, const std::vector &devices) { - if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { - MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; - } Group g; (void)gm_.CreateGroup(group_name, devices, &g); return g; diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h index 654acd9dff..3023f4a355 100644 --- a/mindspore/ccsrc/frontend/parallel/device_manager.h +++ b/mindspore/ccsrc/frontend/parallel/device_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_ #include #include @@ -25,7 +25,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/device.h" #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/group_manager.h" @@ -127,4 +127,4 @@ class DeviceManager { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.cc b/mindspore/ccsrc/frontend/parallel/device_matrix.cc index 9cc85d9701..e54f6d84ee 100644 --- a/mindspore/ccsrc/frontend/parallel/device_matrix.cc +++ b/mindspore/ccsrc/frontend/parallel/device_matrix.cc @@ -159,7 +159,7 @@ std::string ShapeToString(const Shape &shape) { return str + "]"; } -std::string ListToString(const std::vector &list) { +std::string ListToString(const RankList &list) { std::string str = "["; for (auto &element : list) { str += std::to_string(element) + ", "; diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.h b/mindspore/ccsrc/frontend/parallel/device_matrix.h index f1e7acec39..bf3bea1f23 100644 --- a/mindspore/ccsrc/frontend/parallel/device_matrix.h +++ b/mindspore/ccsrc/frontend/parallel/device_matrix.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_ #include #include @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { using RankList = std::vector; -using Shape = std::vector; +using Shape = std::vector; class DeviceMatrix { public: @@ -48,8 +48,8 @@ class DeviceMatrix { }; std::string ShapeToString(const Shape &shape); -std::string ListToString(const std::vector &list); +std::string ListToString(const RankList &list); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_ diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 3ba40fade9..88f5b67355 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ -#define MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ #include #include @@ -133,7 +133,10 @@ REGISTER(SigmoidCrossEntropyWithLogitsInfo); REGISTER(SquareInfo); REGISTER(GatherV2PInfo); REGISTER(EmbeddingLookupInfo); +REGISTER(TileInfo); +REGISTER(StridedSliceInfo); +REGISTER(DropoutInfo); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h index b3ef54a22e..b6e417fab8 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ -#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ #include #include @@ -66,4 +66,4 @@ class GenerateGraph { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc index 21298697f4..fba348229c 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "ir/func_graph.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/graph_util/graph_info.h" @@ -37,14 +37,21 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { for (auto para : graph_params) { std::string name = std::static_pointer_cast(para)->name(); - std::shared_ptr tensor_layout = std::static_pointer_cast(para)->tensor_layout(); + auto tensor_layout = para->user_data(); if (tensor_layout == nullptr) { MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; } else { auto device_arrangement = tensor_layout->device_arrangement().array(); auto tensor_map = tensor_layout->tensor_map().array(); auto slice_shape = tensor_layout->slice_shape().array(); - std::vector> layout = {device_arrangement, tensor_map, slice_shape}; + int32_t _field_size = tensor_layout->get_field_size(); + Shape field_size; + if (_field_size != 0) { + field_size.push_back(_field_size); + } else { + field_size = {0}; + } + std::vector layout = {device_arrangement, tensor_map, slice_shape, field_size}; dict[py::str(name)] = layout; MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); } @@ -63,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { if (node->isa()) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - auto distributed_operation_info = cnode->operator_info(); + auto distributed_operation_info = cnode->user_data(); if (distributed_operation_info != nullptr) { auto strategyPtr = distributed_operation_info->strategy(); if (strategyPtr != nullptr) { diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h index e34d628b2b..083e27243a 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ #include "pybind11/stl.h" #include "pybind11/pybind11.h" @@ -30,4 +30,4 @@ py::dict GetAllreduceFusion(const FuncGraphPtr &graph); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc index 45a88c3a23..0121c70d40 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc @@ -19,8 +19,8 @@ #include "debug/anf_ir_utils.h" #include "debug/draw.h" #include "ir/func_graph.h" -#include "utils/context/ms_context.h" -#include "utils/graph_utils.h" +#include "utils/ms_context.h" +#include "ir/graph_utils.h" namespace mindspore { namespace parallel { diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h index de800f0981..a8837fe265 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GRAPH_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GRAPH_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_INFO_H_ #include #include @@ -29,4 +29,4 @@ void DumpGraph(const FuncGraphPtr &root, const std::string &name); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GRAPH_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc index e50df2818b..5d9c999b8b 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc @@ -38,7 +38,12 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { if (!para_ptr->has_default()) { return false; } - return para_ptr->default_param()->requires_grad(); + auto obj = py::cast(para_ptr->default_param()); + auto param_value = py::cast(obj.attr("_value")); + if (param_value == nullptr) { + return false; + } + return param_value->requires_grad(); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h index 6037c466cd..ea97747e1d 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ #include #include "base/base.h" @@ -28,4 +28,4 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc index 8929af7b0b..98fca25b3d 100644 --- a/mindspore/ccsrc/frontend/parallel/group_manager.cc +++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc @@ -70,11 +70,9 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto mindspore::parallel::Group *const group) { // it is simple to use size to determine whether it is a world group uint32_t world_size = 0; - if (world_group_ != NCCL_WORLD_GROUP) { - (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size); - } + (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size); - if ((world_group_ == NCCL_WORLD_GROUP) || (devices.size() == world_size)) { + if (devices.size() == world_size) { auto it = groups_.find(world_group_); if (it == groups_.end()) { (void)group->Init(world_group_, devices); diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.h b/mindspore/ccsrc/frontend/parallel/group_manager.h index b9cf9663b0..5d4eaef815 100644 --- a/mindspore/ccsrc/frontend/parallel/group_manager.h +++ b/mindspore/ccsrc/frontend/parallel/group_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ -#define MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GROUP_MANAGER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GROUP_MANAGER_H_ #include #include @@ -72,4 +72,4 @@ class GroupManager { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GROUP_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/node_check.h b/mindspore/ccsrc/frontend/parallel/node_check.h index 8b628f31b1..ac8f79ee12 100644 --- a/mindspore/ccsrc/frontend/parallel/node_check.h +++ b/mindspore/ccsrc/frontend/parallel/node_check.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_NODE_CHECK_H_ -#define MINDSPORE_CCSRC_PARALLEL_NODE_CHECK_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_NODE_CHECK_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_NODE_CHECK_H_ #include "ir/primitive.h" @@ -25,4 +25,4 @@ bool IsInBlackList(const PrimitivePtr &prim); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_NODE_CHECK_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_NODE_CHECK_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc index 35cac1480c..3d37e98c51 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include "ir/value.h" #include "frontend/parallel/auto_parallel/costmodel.h" @@ -54,6 +56,29 @@ Status Activation::CheckStrategy(const StrategyPtr &strategy) { return SUCCESS; } +Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + // dropout don't support repeated calculation + CheckGlobalDeviceManager(); + auto input_strategy = strategy->GetInputDim().at(0); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies()); + if (IntToSize(product_p) != dev_num) { + MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; + return FAILED; + } + + return SUCCESS; +} + Status ActivationInfo::GetAttrs() { if (attrs_.size() < ACTIVATION_ATTR_SIZE) { MS_LOG(ERROR) << name_ << " : The size of attrs small than 1."; @@ -120,6 +145,27 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return SUCCESS; } +Status DropoutInfo::GenerateStrategies(int32_t stage_id) { + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + Status Softmax::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { @@ -130,7 +176,7 @@ Status Softmax::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); Dimensions input_strategy = stra.at(0); for (auto &element : axis_) { @@ -181,7 +227,7 @@ Status Softmax::GetAttrs() { MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; return FAILED; } - MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ShapeToString(axis_); + MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ListToString(axis_); } else { MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple int."; return FAILED; @@ -258,7 +304,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { } Status ActivationBase::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions input_strategy = stra.at(0); dev_matrix_shape_ = input_strategy; @@ -296,11 +342,11 @@ Status ActivationBase::InferForwardCommunication() { } Status ActivationBase::InferTensorMap() { - std::vector tensor_map_index; + Shape tensor_map_index; size_t size = inputs_shape_.at(0).size(); // such as 4: tensor_map_index [3,2,1,0] for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - i - 1)); + tensor_map_index.push_back((int64_t)(size - i - 1)); } inputs_tensor_map_.push_back(tensor_map_index); @@ -334,6 +380,32 @@ Status ActivationBase::InferTensorInfo() { return SUCCESS; } +Status DropoutInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + // dropout has two outputs + Strategys outputs_strategy = {inputs_strategy.at(0), inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + TensorLayout input_tensor_layout; + if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { + return FAILED; + } + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + inputs_tensor_info_.push_back(input_tensor_info); + // the two outputs of dropout all have the same tensor_info as input + outputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(input_tensor_info); + + return SUCCESS; +} + Status ActivationBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; @@ -425,7 +497,7 @@ Status ExpandDimsInfo::InferTensorMap() { // for example: if the dimension of input is 3, and the axis is 2, // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0] - std::vector input_tensor_map, output_tensor_map; + Shape input_tensor_map, output_tensor_map; size_t size = inputs_shape_[0].size(); for (size_t i = 0; i < size; ++i) { input_tensor_map.push_back(SizeToInt(size - i - 1)); @@ -607,7 +679,7 @@ Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { Status SqueezeInfo::InferTensorMap() { // for example: if the shape of input is [32, 32, 1], and the axis is (2, ), // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1] - std::vector input_tensor_map, output_tensor_map; + Shape input_tensor_map, output_tensor_map; if (inputs_shape_.empty()) { MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; return FAILED; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h index a74707efbe..67ddf466ea 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ #include #include @@ -219,6 +219,20 @@ class SigmoidInfo : public ActivationOther { : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SigmoidInfo() override = default; }; + +class DropoutInfo : public ActivationOther { + public: + DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~DropoutInfo() override = default; + Status GenerateStrategies(int32_t stage_id) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override { return SUCCESS; } + Status InferTensorInfo() override; +}; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc index 1dd9c899ca..3517bea32f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc @@ -54,9 +54,9 @@ Shapes ArithmeticBase::InferExpendShape() { return input_shapes; } -std::vector ExpendStrategy(const StrategyPtr &strategy) { - std::vector expend_strategy; - std::vector stra = strategy->GetInputDim(); +Strategys ExpendStrategy(const StrategyPtr &strategy) { + Strategys expend_strategy; + Strategys stra = strategy->GetInputDim(); Dimensions sub_a_strategy = stra.at(0); Dimensions sub_b_strategy = stra.at(1); size_t input_a_size = sub_a_strategy.size(); @@ -83,7 +83,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } Shapes input_shapes = InferExpendShape(); - std::vector expend_strategy = ExpendStrategy(strategy); + Strategys expend_strategy = ExpendStrategy(strategy); Dimensions sub_a_strategy = expend_strategy.at(0); Dimensions sub_b_strategy = expend_strategy.at(1); Shape input_a_shape = input_shapes.at(0); @@ -103,7 +103,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { } Status ArithmeticBase::InferDevMatrixShape() { - std::vector expend_strategy = ExpendStrategy(strategy_); + Strategys expend_strategy = ExpendStrategy(strategy_); Dimensions sub_a_strategy = expend_strategy.at(0); Dimensions sub_b_strategy = expend_strategy.at(1); Shape dev_shape; @@ -123,7 +123,7 @@ TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shap TensorMap tensor_map_index; for (size_t i = 0; i < strategy.size(); ++i) { if (strategy[i] == dev_matrix_shape[i]) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i)); + tensor_map_index.push_back((int64_t)(LAST_INDEX(strategy.size()) - i)); } else { tensor_map_index.push_back(-1); } @@ -159,15 +159,15 @@ void ArithmeticBase::ReComputeBatchSplitFlagList() { } Status ArithmeticBase::InferTensorMap() { - std::vector tensor_map_index; - std::vector expend_strategy = ExpendStrategy(strategy_); + Shape tensor_map_index; + Strategys expend_strategy = ExpendStrategy(strategy_); Dimensions sub_a_expend_strategy = expend_strategy.at(0); Dimensions sub_b_expend_strategy = expend_strategy.at(1); Strategys stra = strategy_->GetInputDim(); Dimensions sub_a_strategy = stra.at(0); Dimensions sub_b_strategy = stra.at(1); for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i)); + tensor_map_index.push_back((int64_t)(LAST_INDEX(sub_a_expend_strategy.size()) - i)); } Shape dev_shape; @@ -261,7 +261,7 @@ Status ArithmeticBase::InferTensorInfo() { // infer slice shape Shapes inputs_slice_shape, outputs_slice_shape; - std::vector expend_strategy = ExpendStrategy(strategy_); + Strategys expend_strategy = ExpendStrategy(strategy_); Dimensions sub_a_expend_strategy = expend_strategy.at(0); Dimensions sub_b_expend_strategy = expend_strategy.at(1); Strategys inputs_strategy = strategy_->GetInputDim(); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h index 1d347e4ec1..c3f8226e92 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ #include #include @@ -132,4 +132,4 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc index 64aceb90f6..5f727ab55c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc @@ -43,13 +43,13 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { dev_num_ = dev_num; size_t strategy_size = strategy->GetInputNumber(); - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); for (size_t i = 0; i < strategy_size; ++i) { Shape sub_strategy = stra.at(i); size_t strategy_len = sub_strategy.size(); bool flag = false; for (size_t j = 0; j < strategy_len; ++j) { - int32_t strategy_value = sub_strategy.at(j); + int64_t strategy_value = sub_strategy.at(j); if (strategy_value > 1) { if (flag || strategy_value != dev_num_) { if (is_auto_parallel_) { @@ -95,7 +95,7 @@ Status BatchParallelInfo::InferTensorMap() { return FAILED; } for (size_t i = 0; i < inputs_shape_.size(); i++) { - std::vector tensor_map_index; + Shape tensor_map_index; for (size_t j = 0; j < inputs_shape_[i].size(); ++j) { if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) { tensor_map_index.push_back(0); @@ -106,7 +106,7 @@ Status BatchParallelInfo::InferTensorMap() { inputs_tensor_map_.push_back(tensor_map_index); } for (size_t i = 0; i < outputs_shape_.size(); i++) { - std::vector tensor_map_index; + Shape tensor_map_index; for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { if (i == 0 && j == 0) { tensor_map_index.push_back(0); @@ -123,7 +123,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() { Strategys outputs_strategy; for (size_t i = 0; i < outputs_shape_.size(); ++i) { - std::vector strategy; + Dimensions strategy; for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { if (i == 0 && j == 0) { strategy.push_back(dev_num_); @@ -201,7 +201,7 @@ Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { is_auto_parallel_ = true; size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); StrategyPtr sp; - std::vector strategy; + Strategys strategy; for (size_t i = 0; i < inputs_shape_.size(); i++) { Shape temp(inputs_shape_[i].size(), 1); if (split_flag_list_[i]) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h index 0ba30c385a..3d47b53b54 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ #include #include @@ -69,4 +69,4 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc index e8b3afba16..25d5e72112 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc @@ -36,11 +36,11 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { } return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); Dimensions sub_a_strategy = stra.at(0); Dimensions sub_b_strategy = stra.at(1); - int32_t channel_a_strategy = sub_a_strategy.at(1); - int32_t channel_b_strategy = sub_b_strategy.at(0); + int64_t channel_a_strategy = sub_a_strategy.at(1); + int64_t channel_b_strategy = sub_b_strategy.at(0); if (channel_a_strategy != channel_b_strategy) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -53,7 +53,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { } Status BiasAddInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions sub_a_strategy = stra.at(0); dev_matrix_shape_ = sub_a_strategy; return SUCCESS; @@ -67,13 +67,13 @@ void BiasAddInfo::ReComputeBatchSplitFlagList() { Status BiasAddInfo::InferTensorMap() { TensorMap sub_a_tensor_map; TensorMap sub_b_tensor_map; - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions sub_a_strategy = stra.at(0); size_t sub_a_strategy_size = sub_a_strategy.size(); for (size_t i = 0; i < sub_a_strategy_size; ++i) { - sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - i)); + sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(sub_a_strategy_size) - i)); } - sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - 1)); + sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(sub_a_strategy_size) - 1)); inputs_tensor_map_.push_back(sub_a_tensor_map); inputs_tensor_map_.push_back(sub_b_tensor_map); @@ -213,7 +213,7 @@ Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; for (auto &sp : sp_vector) { - std::vector tmp_strategy; + Strategys tmp_strategy; Dimensions input0_strategy = sp->GetInputDim()[0]; tmp_strategy.push_back(input0_strategy); // input0 diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h index 3ede65a3ba..62045e0dcc 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ #include @@ -56,4 +56,4 @@ class BiasAddInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h index 2829889846..e2cbf3f6c7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ #include #include @@ -62,4 +62,4 @@ class MinimumInfo : public ArithmeticBase { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc index 3b411ccb0e..c17389ae56 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc @@ -38,7 +38,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); if (stra.size() != 1) { MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1"; return FAILED; @@ -68,7 +68,7 @@ Status DropoutDoMaskInfo::InferDevMatrixShape() { return FAILED; } - std::vector strategy = strategy_->GetInputDim(); + Strategys strategy = strategy_->GetInputDim(); if (strategy.empty()) { MS_LOG(ERROR) << name_ << ": The strategy is empty"; return FAILED; @@ -84,7 +84,7 @@ Status DropoutDoMaskInfo::InferTensorMap() { return FAILED; } - std::vector tensor_map_index; + Shape tensor_map_index; size_t size = inputs_shape_[0].size(); // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0] for (size_t i = 0; i < size; ++i) { @@ -169,13 +169,13 @@ Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -std::shared_ptr>> DropoutDoMaskInfo::GenerateBatchStrategies() { +std::shared_ptr DropoutDoMaskInfo::GenerateBatchStrategies() { CheckGlobalDeviceManager(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); Dimensions strategy(inputs_shape_[0].size() - 1, 1); (void)strategy.insert(strategy.begin(), SizeToInt(dev_num)); - std::vector strategy_v = {strategy}; - return std::make_shared>>(strategy_v); + Strategys strategy_v = {strategy}; + return std::make_shared(strategy_v); } Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { @@ -259,8 +259,10 @@ void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { if (manager == nullptr) { MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr."; } - - ValuePtr new_shape = MakeValue(input_slice_shape); + std::vector input_slice_shape_int; + (void)std::transform(input_slice_shape.begin(), input_slice_shape.end(), std::back_inserter(input_slice_shape_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr new_shape = MakeValue(input_slice_shape_int); AnfNodePtr val = NewValueNode(new_shape); (void)manager->Replace(dropout_gen_mask_cnode->input(1), val); } @@ -306,8 +308,10 @@ std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape); return replace_ops; } - - ValuePtr new_shape = MakeValue(input_slice_shape); + std::vector input_slice_shape_int; + (void)std::transform(input_slice_shape.begin(), input_slice_shape.end(), std::back_inserter(input_slice_shape_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr new_shape = MakeValue(input_slice_shape_int); Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0)); Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1)); OperatorAttrs attrs = {attr_0, attr_1}; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h index ea7d590071..53f8d3e52f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ #include #include @@ -40,7 +40,7 @@ class DropoutDoMaskInfo : public OperatorInfo { Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override; - std::shared_ptr>> GenerateBatchStrategies() override; + std::shared_ptr GenerateBatchStrategies() override; std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); protected: @@ -57,4 +57,4 @@ using DropoutDoMaskInfoPtr = std::shared_ptr; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h index e25da9e743..77d798d20f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ #include #include @@ -66,4 +66,4 @@ class LogicalNotInfo : public ActivationOther { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc index 4e6e947f68..f9683f3d08 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc @@ -109,7 +109,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { } Status GatherV2Info::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); dev_matrix_shape_ = stra.at(0); return SUCCESS; } @@ -129,8 +129,8 @@ Status GatherV2Info::InferTensorMap() { << outputs_shape_.size(); return FAILED; } - std::vector tensor_map_in; - std::vector tensor_map_out; + Shape tensor_map_in; + Shape tensor_map_out; size_t size = inputs_shape_.at(0).size(); // such as 4: tensor_map_index [3,2,1,0] for (size_t i = 0; i < size; ++i) { @@ -149,7 +149,7 @@ Status GatherV2Info::InferTensorMap() { return FAILED; } - std::vector tensor_map_in_index; + Shape tensor_map_in_index; if (index_size_ >= 1) { tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1)); } @@ -323,7 +323,7 @@ Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SUCCESS; } -std::shared_ptr>> GatherV2Info::GenerateBatchStrategies() { +std::shared_ptr GatherV2Info::GenerateBatchStrategies() { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " << inputs_shape_.size(); @@ -343,8 +343,8 @@ std::shared_ptr>> GatherV2Info::GenerateBatchSt for (size_t i = 1; i < inputs_shape_[0].size(); i++) { strategy.push_back(1); } - std::vector strategy_v = {strategy}; - return std::make_shared>>(strategy_v); + Strategys strategy_v = {strategy}; + return std::make_shared(strategy_v); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h index b3dc0fab87..d7ceeda2e1 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ #include #include @@ -50,7 +50,7 @@ class GatherV2Info : public OperatorInfo { Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - std::shared_ptr>> GenerateBatchStrategies() override; + std::shared_ptr GenerateBatchStrategies() override; protected: Status CheckStrategy(const StrategyPtr &strategy) override; @@ -70,4 +70,4 @@ class GatherV2Info : public OperatorInfo { }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index eb3c9900f8..b7e57dd1aa 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -73,8 +73,8 @@ Status GatherV2PInfo::GetAttrs() { MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; return FAILED; } - param_split_shapes_.push_back(static_cast(GetValue(value_vector[0]))); - index_offsets_.push_back(static_cast(GetValue(value_vector[1]))); + param_split_shapes_.push_back(static_cast(GetValue(value_vector[0]))); + index_offsets_.push_back(static_cast(GetValue(value_vector[1]))); } else { MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; return FAILED; @@ -93,14 +93,14 @@ Status GatherV2PInfo::GetAttrs() { Status GatherV2PInfo::CheckManualSplit() { auto param_shape = inputs_shape_.at(0); - int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, - [](int32_t s, int32_t shape) { return s + shape; }); + int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, + [](int64_t s, int64_t shape) { return s + shape; }); if (split_shape_sum < param_shape.at(0)) { MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; return FAILED; } - if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) { + if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) { MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; return FAILED; } @@ -269,8 +269,8 @@ Status GatherV2PInfo::InferTensorMap() { size_t param_size = inputs_shape_.at(0).size(); size_t index_size = inputs_shape_.at(1).size(); size_t total_size = param_size + index_size; - std::vector tensor_map_index; - std::vector tensor_map_params; + Shape tensor_map_index; + Shape tensor_map_params; auto param_strategy = strategy_->GetInputDim().at(0); if (param_strategy.at(IntToSize(axis_)) != 1) { tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); @@ -288,7 +288,7 @@ Status GatherV2PInfo::InferTensorMap() { } // infer output tensor map - std::vector tensor_map_out; + Shape tensor_map_out; if (param_strategy.at(IntToSize(axis_)) == 1) { // param_strategy(axis) == 1 for (size_t i = 0; i < param_size; ++i) { @@ -427,8 +427,8 @@ Status GatherV2PInfo::InferGroup() { return SUCCESS; } -std::vector GetRankFromGroup(const Group &group) { - std::vector rank_list; +RankList GetRankFromGroup(const Group &group) { + RankList rank_list; auto device_list = group.GetDevicesList(); for (auto &device : device_list) { rank_list.insert(rank_list.end(), device.rank() % 8); @@ -455,6 +455,9 @@ Status GatherV2PInfo::InferForwardCommunication() { MS_LOG(ERROR) << name_ << ": Infer Group failed."; return FAILED; } + if (group_.name().empty()) { + return SUCCESS; + } attr_group = std::make_pair(GROUP, MakeValue(group_.name())); Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); OperatorAttrs attrs = {attr_op, attr_group}; @@ -472,7 +475,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { MS_LOG(ERROR) << "GenerateGraph Init failed"; return FAILED; } - if (manual_split_) { + if (manual_split_ && target_ != CPU) { if (InferOffset() != SUCCESS) { MS_LOG(ERROR) << name_ << ": Infer Bias failed."; return FAILED; @@ -519,7 +522,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { } ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { - if (manual_split_) { + if (manual_split_ && target_ != CPU) { if (ComputeReplaceGraph(cnode) != SUCCESS) { MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; return nullptr; @@ -540,13 +543,25 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { } Status GatherV2PInfo::ComputeReplaceOp() { - if (InferBias() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer offset failed."; - return FAILED; + int64_t bias = 0; + if (manual_split_) { + if (InferOffset() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer offset failed."; + return FAILED; + } + bias = index_offset_; + } else { + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer offset failed."; + return FAILED; + } + bias = bias_; } + OperatorName op_name = EMBEDDING_LOOKUP; OperatorAttrs attrs; - Attr param_offset = std::make_pair("offset", MakeValue(bias_)); + int32_t bias_int = static_cast(bias); + Attr param_offset = std::make_pair("offset", MakeValue(bias_int)); OperatorParams params = {std::make_pair(param_offset, 3)}; OperatorArgs args = std::make_pair(attrs, params); Operator op = std::make_pair(op_name, args); @@ -620,7 +635,7 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -std::shared_ptr>> GatherV2PInfo::GenerateBatchStrategies() { +std::shared_ptr GatherV2PInfo::GenerateBatchStrategies() { CheckGlobalDeviceManager(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); Dimensions param_strategy(inputs_shape_[0].size(), 1); @@ -629,8 +644,8 @@ std::shared_ptr>> GatherV2PInfo::GenerateBatchS for (size_t i = 1; i < inputs_shape_[1].size(); i++) { index_strategy.push_back(1); } - std::vector strategy_v = {param_strategy, index_strategy}; - return std::make_shared>>(strategy_v); + Strategys strategy_v = {param_strategy, index_strategy}; + return std::make_shared(strategy_v); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h index eb26c616d0..42c8fee9e4 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ #include #include @@ -45,7 +45,9 @@ class GatherV2PInfo : public OperatorInfo { Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; - std::shared_ptr>> GenerateBatchStrategies() override; + std::shared_ptr GenerateBatchStrategies() override; + const std::vector ¶m_split_shapes() const { return param_split_shapes_; } + const std::vector &index_offsets() const { return index_offsets_; } protected: Status CheckStrategy(const StrategyPtr &strategy) override; @@ -56,7 +58,6 @@ class GatherV2PInfo : public OperatorInfo { Status InferTensorMap() override; Status GetAttrs() override; - private: Status ComputeReplaceGraph(const CNodePtr &cnode); Status CheckManualSplit(); Status ComputeReplaceOp(); @@ -67,14 +68,14 @@ class GatherV2PInfo : public OperatorInfo { int32_t axis_; std::string target_ = DEVICE; std::string replace_op_name_ = GATHERV2; - int32_t bias_; - int32_t index_offset_; - int32_t slice_size_; + int64_t bias_; + int64_t index_offset_; + int64_t slice_size_; Shape out_dev_matrix_shape_; Group group_; bool manual_split_ = false; - std::vector param_split_shapes_; - std::vector index_offsets_; + std::vector param_split_shapes_; + std::vector index_offsets_; }; class SparseGatherV2Info : public GatherV2PInfo { @@ -97,4 +98,4 @@ class EmbeddingLookupInfo : public GatherV2PInfo { }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc index 3606732156..b9599ae95e 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc @@ -118,7 +118,7 @@ Status GetNextInfo::Init(const StrategyPtr &strategy) { } Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { - std::vector stras = strategy->GetInputDim(); + Strategys stras = strategy->GetInputDim(); for (Dimensions stra : stras) { if (stra.size() != 0) { if (is_auto_parallel_) { @@ -215,7 +215,15 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { out_shapes[i][0] = out_shapes[i][0] / dev_num_; } } - ValuePtr new_shapes = MakeValue(out_shapes); + std::vector> out_shapes_int; + (void)std::transform(out_shapes.begin(), out_shapes.end(), std::back_inserter(out_shapes_int), + [](const std::vector &shape) { + std::vector shape_int; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), + [](const int64_t &v) { return static_cast(v); }); + return shape_int; + }); + ValuePtr new_shapes = MakeValue(out_shapes_int); Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); Attr attr_shapes = std::make_pair(SHAPES, new_shapes); Attr attr_num = std::make_pair(GETNEXT_NUM, attrs_[GETNEXT_NUM]); @@ -254,7 +262,7 @@ Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { Status GetNextInfo::GenerateStrategies(int32_t stage_id) { is_auto_parallel_ = true; - std::vector stra; + Strategys stra; StrategyPtr sp = std::make_shared(stage_id, stra); if (SetCostUnderStrategy(sp) == SUCCESS) { MS_LOG(INFO) << name_ << " : Successfully generated strategy."; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h index 36e7a0fcb3..bf30529aaf 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ #include #include @@ -66,4 +66,4 @@ class GetNextInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc index 126fdcf84e..2a513b2d2e 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc @@ -37,7 +37,7 @@ Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); Dimensions input_strategy = stra.at(0); int32_t axis_index = axis_; if (axis_ < 0) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h index c74dde4b4b..1e3442e9e3 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ #include #include @@ -47,4 +47,4 @@ class L2NormalizeInfo : public Activation { }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc index 62d7c6d61e..f0b62370e9 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc @@ -49,7 +49,7 @@ Status LayerNormInfo::GetAttrs() { Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { MS_EXCEPTION_IF_NULL(strategy); - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); if (stra.size() != LAYER_NORM_INPUT_SIZE) { MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); return FAILED; @@ -104,7 +104,7 @@ Status LayerNormInfo::InferDevMatrixShape() { MS_LOG(ERROR) << name_ << ": The strategy is null"; return FAILED; } - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); if (stra.empty()) { MS_LOG(ERROR) << name_ << ": The strategy is empty"; return FAILED; @@ -228,7 +228,7 @@ Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector tmp_strategy; + Strategys tmp_strategy; Dimensions input_strategy = sp->GetInputDim()[0]; Dimensions gamma_strategy = input_strategy; (void)gamma_strategy.erase(gamma_strategy.begin(), diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h index 9ee11bb215..908c3f587b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ #include #include @@ -73,4 +73,4 @@ class LayerNormInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc index 889f204fb0..0ef7fa8e4f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc @@ -38,7 +38,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); Dimensions input_strategy = stra.at(0); Dimensions label_strategy = stra.at(1); if (input_strategy != label_strategy) { @@ -52,8 +52,8 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle axis_index = static_cast(input_dim) + axis_; } - int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index)); - int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index)); + int64_t input_axis_strategy = input_strategy.at(IntToSize(axis_index)); + int64_t label_axis_strategy = label_strategy.at(IntToSize(axis_index)); // Dimension corresponding to axis is un-splittable if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) { if (is_auto_parallel_) { @@ -82,21 +82,21 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() { } Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions input_strategy = stra.at(0); dev_matrix_shape_ = input_strategy; return SUCCESS; } Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() { - std::vector tensor_map_index; + Shape tensor_map_index; size_t size = inputs_shape_[0].size(); // such as 4: tensor_map_index [3,2,1,0] for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - i - 1)); + tensor_map_index.push_back((int64_t)(size - i - 1)); } - std::vector first_output_tensor_map = {tensor_map_index[0]}; + Shape first_output_tensor_map = {tensor_map_index[0]}; inputs_tensor_map_.push_back(tensor_map_index); // input inputs_tensor_map_.push_back(tensor_map_index); // label outputs_tensor_map_.push_back(first_output_tensor_map); // output-0 diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h index 7e5478bedf..99d5c23e21 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LOSS_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LOSS_INFO_H_ #include #include @@ -64,4 +64,4 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LOSS_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc index 60a3d60b39..f6a0c10383 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc @@ -105,6 +105,17 @@ Status MatMulBase::GetAttrs() { } } + auto field_size_iter = attrs_.find(FIELD_SIZE); + if (field_size_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(field_size_iter->second); + if (field_size_iter->second->isa()) { + field_size_ = field_size_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of field_size is not int."; + return FAILED; + } + } + // infer inputs dimension size if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) { MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; @@ -147,7 +158,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); Dimensions mat_a_strategy = stra.at(0); Dimensions mat_b_strategy = stra.at(1); @@ -196,7 +207,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) { } Status MatMulBase::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions mat_a_strategy = stra.at(0); Dimensions mat_b_strategy = stra.at(1); @@ -268,10 +279,10 @@ Status MatMulBase::InferTensorMap() { size = dev_matrix_shape_.size() - 1; } - std::vector tensor_map_index; + Shape tensor_map_index; // such as 5: tensor_map_index [4,3,2,1,0] for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); + tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i)); } // infer output tensor map: [4,3,2,0], delete the second-from-end element @@ -298,7 +309,7 @@ Status MatMulBase::InferTensorMap() { mat_b_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_b_dimension_)); if (transpose_b_) { // swap the last two elements - int32_t last_value = mat_b_tensor_map.back(); + int64_t last_value = mat_b_tensor_map.back(); mat_b_tensor_map.pop_back(); (void)mat_b_tensor_map.insert( mat_b_tensor_map.begin() + static_cast(LAST_INDEX(mat_b_tensor_map.size())), last_value); @@ -346,6 +357,10 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts return FAILED; } + if (field_size_ != 0) { + mat_b_layout.set_field_size(field_size_); + } + inputs_layout->push_back(mat_a_layout); inputs_layout->push_back(mat_b_layout); outputs_layout->push_back(output_layout); @@ -421,7 +436,7 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { return FAILED; } CheckGlobalDeviceManager(); - std::vector dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); size_t dev_num = dev_list.size(); Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1]; if (transpose_a_) { @@ -488,13 +503,14 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { - int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); + int64_t product = + std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); if (!FULLY_USE_DEVICES) { - if (IntToSize(product) > dev_num) { + if (LongToSize(product) > dev_num) { return FAILED; } } else { - if (IntToSize(product) != dev_num) { + if (LongToSize(product) != dev_num) { return FAILED; } } @@ -535,7 +551,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; } } - std::vector stras; + Strategys stras; stras.push_back(input0_partitions); stras.push_back(input1_partitions); (*sp) = std::make_shared(stage_id, stras); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h index d4e144c2b6..1ff3276474 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_INFO_H_ #include #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "ir/value.h" #include "frontend/parallel/auto_parallel/operator_costmodel.h" #include "frontend/parallel/ops_info/operator_info.h" @@ -62,6 +62,7 @@ class MatMulBase : public OperatorInfo { bool transpose_a_ = false; bool transpose_b_ = false; bool forward_reduce_scatter_ = false; + int32_t field_size_ = 0; size_t mat_a_dimension_ = 0; size_t mat_b_dimension_ = 0; }; @@ -93,4 +94,4 @@ class BatchMatMulInfo : public MatMul { }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc index 15acb085f5..1042a8ebf7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc @@ -77,7 +77,7 @@ Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { } Status OneHotInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions input_strategy = stra.at(0); // Now input only support 1-D tensor, so the output is a 2-D tensor @@ -96,16 +96,16 @@ Status OneHotInfo::InferDevMatrixShape() { } Status OneHotInfo::InferTensorMap() { - std::vector input_tensor_map_index, output_tensor_map_index; + Shape input_tensor_map_index, output_tensor_map_index; size_t size = outputs_shape_[0].size(); // such as 2: tensor_map_index [1,0] if (axis_ == 0) { for (size_t i = 0; i < size; ++i) { - output_tensor_map_index.push_back((int32_t)(i)); + output_tensor_map_index.push_back((int64_t)(i)); } } else { for (size_t i = 0; i < size; ++i) { - output_tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); + output_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i)); } } outputs_tensor_map_.push_back(output_tensor_map_index); @@ -299,13 +299,13 @@ Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SUCCESS; } -std::shared_ptr>> OneHotInfo::GenerateBatchStrategies() { +std::shared_ptr OneHotInfo::GenerateBatchStrategies() { CheckGlobalDeviceManager(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); Dimensions strategy = {SizeToInt(dev_num), 1}; Dimensions empty_strategy; - std::vector strategy_v = {strategy, empty_strategy, empty_strategy}; - return std::make_shared>>(strategy_v); + Strategys strategy_v = {strategy, empty_strategy, empty_strategy}; + return std::make_shared(strategy_v); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h index dfd7e6cbaf..362c5a57a3 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ #include #include @@ -41,7 +41,7 @@ class OneHotInfo : public OperatorInfo { Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; - std::shared_ptr>> GenerateBatchStrategies() override; + std::shared_ptr GenerateBatchStrategies() override; protected: Status CheckStrategy(const StrategyPtr &strategy) override; @@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo { }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index 3dd47b1de6..420e7f9f96 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -30,7 +30,7 @@ #include "frontend/parallel/auto_parallel/edge_costmodel.h" #include "frontend/parallel/auto_parallel/graph_costmodel.h" #include "frontend/parallel/context.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/log_adapter.h" namespace mindspore { @@ -52,7 +52,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); for (size_t i = 0; i < strategy_size; ++i) { Shape sub_strategy = stra.at(i); Shape sub_input_shape = inputs_shape.at(i); @@ -70,7 +70,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap } for (size_t j = 0; j < strategy_len; ++j) { - int32_t strategy_value = sub_strategy.at(j); + int64_t strategy_value = sub_strategy.at(j); if (strategy_value < MIN_SLICE_NUM) { if (is_auto_parallel) { MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; @@ -89,7 +89,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap return FAILED; } - int32_t shape_value = sub_input_shape.at(j); + int64_t shape_value = sub_input_shape.at(j); if ((shape_value % strategy_value) != 0) { if (is_auto_parallel) { MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; @@ -138,9 +138,9 @@ void OperatorInfo::SetDeviceListByStrategy() { } Status OperatorInfo::InferRepeatedCalcInfo() { - int32_t g_dev_list_size = SizeToInt(global_device_list_.size()); - int32_t dev_matrix_size = - std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); + int64_t g_dev_list_size = SizeToLong(global_device_list_.size()); + int64_t dev_matrix_size = + std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); if (dev_matrix_size == 0) { MS_LOG(ERROR) << name_ << ": The dev matrix size is 0"; return FAILED; @@ -149,7 +149,7 @@ Status OperatorInfo::InferRepeatedCalcInfo() { if (g_dev_list_size == dev_matrix_size) { repeated_calc_num_ = 1; } else if (g_dev_list_size % dev_matrix_size == 0) { - repeated_calc_num_ = g_dev_list_size / dev_matrix_size; + repeated_calc_num_ = ((int32_t)(g_dev_list_size / dev_matrix_size)); } else { MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " << dev_matrix_size; @@ -326,7 +326,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { Shape slice_shape; - if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { + if (std::any_of(strategy.begin(), strategy.end(), [](int64_t value) { return value <= 0; })) { MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; return slice_shape; } @@ -430,7 +430,8 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat return FAILED; } - used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); + used_devices_ = + ((int32_t)(std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()))); // must be after InferDevMatrixShape if (InferRepeatedCalcInfo() != SUCCESS) { @@ -646,8 +647,8 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr &op, con succ_edges_ = new_succ_edges; } -std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes &shapes, const std::vector &split_flag_list) { +std::shared_ptr GenerateBatchStrategiesBySplitFlag(const Shapes &shapes, + const std::vector &split_flag_list) { if (shapes.size() != split_flag_list.size()) { MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " << shapes.size(); @@ -655,21 +656,21 @@ std::shared_ptr>> GenerateBatchStrategiesBySpli } CheckGlobalDeviceManager(); int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - std::vector> strategy_v; + Strategys strategy_v; for (size_t i = 0; i != shapes.size(); i++) { if (shapes[i].empty()) { MS_LOG(INFO) << "Elements of shapes is empty."; - std::vector empty_element; + Dimensions empty_element; strategy_v.push_back(empty_element); } else { - std::vector element(shapes[i].size(), 1); + Dimensions element(shapes[i].size(), 1); if (split_flag_list[i]) { element[0] = dev_num; } strategy_v.push_back(element); } } - return std::make_shared>>(strategy_v); + return std::make_shared(strategy_v); } void OperatorInfo::ReComputeBatchSplitFlagList() { @@ -692,26 +693,26 @@ Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &input MS_LOG(ERROR) << "The strategy is null."; return FAILED; } - int32_t product = 1; + int64_t product = 1; for (auto &input_partition : inputs_partitions) { - product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); + product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); } if (!FULLY_USE_DEVICES) { - if (IntToSize(product) > dev_num) { + if (LongToSize(product) > dev_num) { return FAILED; } } else { - if ((product != 1) && (IntToSize(product) != dev_num)) { + if ((product != 1) && (LongToSize(product) != dev_num)) { return FAILED; } } - std::vector stras(inputs_partitions); + Strategys stras(inputs_partitions); (*sp) = std::make_shared(stage_id, stras); return SUCCESS; } -std::shared_ptr>> OperatorInfo::GenerateBatchStrategies() { +std::shared_ptr OperatorInfo::GenerateBatchStrategies() { ComputeBatchSplitFlagList(); return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); } @@ -793,7 +794,7 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs // second, get the correct strategy for input0 for (auto &sp : *sp_vector) { - std::vector tmp_strategy; + Strategys tmp_strategy; Dimensions input0_strategy = sp->GetInputDim()[0]; size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); @@ -842,7 +843,7 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &input // second, get the correct strategy for input1 for (auto &sp : *sp_vector) { - std::vector tmp_strategy; + Strategys tmp_strategy; tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 Dimensions input1_strategy = sp->GetInputDim()[1]; @@ -1175,7 +1176,7 @@ int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const S // The number of repetitions is equal to the number of all devices divided by the number of devices use for // tensor map. - int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); + int64_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. if (element == MAP_NONE) { @@ -1194,7 +1195,7 @@ int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const S } } - return device_num; + return (int32_t)device_num; } Status OperatorInfo::InferAsLossDivisor() { @@ -1330,5 +1331,9 @@ void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { PrintStrategy(s_strategy); } } + +void OperatorInfo::SetStrategyCost(const std::vector> &stra_cost) { + strategy_cost_ = stra_cost; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index 8641c47491..7801482c56 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ #include #include @@ -25,7 +25,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "base/base.h" #include "frontend/parallel/auto_parallel/costmodel.h" #include "frontend/parallel/auto_parallel/operator_costmodel.h" @@ -43,11 +43,10 @@ using ForwardOp = OperatorVector; using MirrorOps = std::vector; using Ops = std::vector; using VirtualDivOp = OperatorVector; -using TensorMaps = std::vector>; +using TensorMaps = std::vector; using TensorLayouts = std::vector; using different_type = std::vector::difference_type; using PrimitiveAttrs = std::unordered_map; -using Strategys = std::vector; using ReplaceGraphPtr = std::shared_ptr>, AnfNodePtr>>; class Edge; @@ -88,7 +87,7 @@ class OperatorInfo { void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; - virtual std::shared_ptr>> GenerateBatchStrategies(); + virtual std::shared_ptr GenerateBatchStrategies(); virtual void ReComputeBatchSplitFlagList(); void ComputeBatchSplitFlagList(); @@ -97,6 +96,7 @@ class OperatorInfo { // is checked Status SetCostUnderStrategyBase(const StrategyPtr &strategy); std::vector> GetStrategyCost() { return strategy_cost_; } + void SetStrategyCost(const std::vector> &); // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory // at the end of forward phase. @@ -163,6 +163,9 @@ class OperatorInfo { const std::string &type() const { return type_; } const std::unordered_map &attrs() const { return attrs_; } + // Key for user data. + constexpr static char key[] = "OpInfo"; + protected: // needed by rec_parser std::string type_; @@ -267,8 +270,8 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); -std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes &shapes, const std::vector &split_flag_list); +std::shared_ptr GenerateBatchStrategiesBySplitFlag(const Shapes &shapes, + const std::vector &split_flag_list); void PrintStrategy(const StrategyPtr &strategy); // generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) @@ -286,4 +289,4 @@ Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index bc732ed234..de9481c4d5 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ #include "frontend/parallel/ops_info/activation_info.h" #include "frontend/parallel/ops_info/arithmetic_info.h" @@ -37,5 +37,7 @@ #include "frontend/parallel/ops_info/transpose_info.h" #include "frontend/parallel/ops_info/virtual_dataset_info.h" #include "frontend/parallel/ops_info/gather_v2_p_info.h" +#include "frontend/parallel/ops_info/tile_info.h" +#include "frontend/parallel/ops_info/strided_slice_info.h" -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 79dfb56693..4b75c2ca95 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_ namespace mindspore { namespace parallel { @@ -29,6 +29,11 @@ constexpr int32_t NO_SPLIT_STRATEGY = 1; constexpr int32_t SPLIT_FLAG = 1; constexpr int32_t NO_SPLIT_FLAG = 0; constexpr size_t MATMUL_ATTRS_SIZE = 2; +constexpr size_t STRIDED_SLICE_ATTRS_SIZE = 5; +constexpr size_t STRIDED_SLICE_INPUTS_SIZE = 4; +constexpr size_t STRIDED_SLICE_BEGIN_INDEX = 1; +constexpr size_t STRIDED_SLICE_END_INDEX = 2; +constexpr size_t STRIDED_SLICE_STRIDES_INDEX = 3; constexpr size_t MATMUL_INPUTS_SIZE = 2; constexpr size_t MATMUL_OUTPUTS_SIZE = 1; constexpr size_t ACTIVATION_ATTR_SIZE = 1; @@ -100,6 +105,7 @@ constexpr char CONCAT_DIM[] = "concat_dim"; constexpr char FORWARD[] = "forward"; constexpr char BACKWARD[] = "backward"; constexpr char REDISTRIBUTION[] = "redistribution"; +constexpr char SKIP_REDISTRIBUTION[] = "skip_redistribution"; constexpr char REPLACE[] = "replace"; constexpr char CONNSYMBOL[] = "/"; constexpr char INSTANCE_NAME[] = "instance_name"; @@ -131,6 +137,7 @@ constexpr char FORWARD_OP[] = "forward_op"; constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; constexpr char DARA_PARALLEL[] = "data_parallel"; constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; +constexpr char FIELD_SIZE[] = "field_size"; constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; constexpr char DEVICE[] = "Device"; @@ -180,6 +187,7 @@ constexpr char RELU[] = "ReLU"; constexpr char ONEHOT[] = "OneHot"; constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask"; constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask"; +constexpr char TILE[] = "Tile"; constexpr char REDUCE_MAX[] = "ReduceMax"; constexpr char REDUCE_MIN[] = "ReduceMin"; constexpr char REDUCE_SUM[] = "ReduceSum"; @@ -230,6 +238,8 @@ constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; constexpr char ADD[] = "Add"; +constexpr char DROPOUT[] = "Dropout"; +constexpr char KStridedSlice[] = "StridedSlice"; // Parallel don't care constexpr char TUPLE_GETITEM[] = "tuple_getitem"; @@ -293,4 +303,4 @@ constexpr size_t THIRD_FROM_END(size_t s) { return s - 3; } } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPS_UTILS_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc index 57b35b69f7..90513b712f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc @@ -43,7 +43,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { } return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy size."; @@ -67,7 +67,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { * device matrix is same with the strategy matrix */ Status PReLUInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions input_strategy = stra.at(0); input_strategy_ = input_strategy; dev_matrix_shape_ = input_strategy; @@ -103,7 +103,7 @@ Status PReLUInfo::InferTensorMap() { TensorMap input_tensor_map; // such as 4: input_tensor_map [3,2,1,0] for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { - input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1)); + input_tensor_map.push_back((int64_t)(inputs_shape_[0].size() - i - 1)); } TensorMap param_tensor_map; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h index e6e5e23bac..49d298e73b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PRELU_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PRELU_INFO_H_ #include #include @@ -60,4 +60,4 @@ class PReLUInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PRELU_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc index 0488dceeca..d685be5c0d 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -43,7 +43,7 @@ Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { } Status ReduceMethod::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions input_strategy = stra.at(0); dev_matrix_shape_ = input_strategy; @@ -119,11 +119,12 @@ Status ReduceMethod::GetAttrs() { } Status ReduceMethod::InferTensorMap() { - std::vector tensor_map_index, dim_list, output_tensor_map; + Shape tensor_map_index, output_tensor_map; + std::vector dim_list; size_t size = inputs_shape_.at(0).size(); // such as 4: tensor_map_index [3,2,1,0] for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - 1 - i)); + tensor_map_index.push_back((int64_t)(size - 1 - i)); } dim_list = reduce_dim(); for (size_t i = 0; i < size; ++i) { @@ -462,7 +463,7 @@ Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) { std::vector dim_list = reduce_dim(); MS_ASSERT(dim_list.size() == 1); - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); MS_ASSERT(stra.size() == 1); Shape input_strategy = stra.at(0); MS_ASSERT(dim_list.at(0) < input_strategy.size()); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h index ed9ab0721d..56a83d61d5 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ #include #include @@ -138,4 +138,4 @@ class ReduceMinInfo : public ReduceMethod { }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc index fb62c1d02c..8b1537d421 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -18,6 +18,7 @@ #include #include +#include #include "frontend/parallel/device_manager.h" #include "frontend/parallel/device_matrix.h" @@ -56,7 +57,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) */ Status ReshapeInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); input_strategy_ = stra.at(0); dev_matrix_shape_.push_back(input_strategy_[0]); return SUCCESS; @@ -145,17 +146,25 @@ Status ReshapeInfo::ComputeReplaceOp() { MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); - RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); - if (redistribution_oplist_ptr == nullptr) { - if (is_generating_costs_) { - MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; - } else { - MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; + if (is_skip_) { + ConstructOperator constructor; + replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array()); + replace_op_info_.clear(); + MS_LOG(INFO) << "skip reshape redistribution and reshape slice_shape is " + << ShapeToString(output_layout_.slice_shape().array()); + } else { + RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); + if (redistribution_oplist_ptr == nullptr) { + if (is_generating_costs_) { + MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; + } else { + MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; + } + return FAILED; } - return FAILED; + replace_op_ = redistribution_oplist_ptr->first; + replace_op_info_ = redistribution_oplist_ptr->second; } - replace_op_ = redistribution_oplist_ptr->first; - replace_op_info_ = redistribution_oplist_ptr->second; MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); return SUCCESS; } @@ -172,7 +181,7 @@ Status ReshapeInfo::InferTensorMap() { return FAILED; } - std::vector tensor_map_index_input; + Shape tensor_map_index_input; tensor_map_index_input.push_back(0); for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { @@ -180,7 +189,7 @@ Status ReshapeInfo::InferTensorMap() { } inputs_tensor_map_.push_back(tensor_map_index_input); - std::vector tensor_map_index_output; + Shape tensor_map_index_output; tensor_map_index_output.push_back(0); for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { @@ -196,7 +205,7 @@ Status ReshapeInfo::InferTensorMap() { */ Strategys ReshapeInfo::GetOutputsStrategy() { Strategys outputs_strategy; - std::vector strategy; + Dimensions strategy; strategy.push_back(input_strategy_[0]); for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { strategy.push_back(1); @@ -255,6 +264,19 @@ Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayout } Status ReshapeInfo::InferTensorInfo() { + // skip reshape infer if skip_redistribution is true + if (is_skip_) { + TensorLayout layout; + Shape shape; + Shape slice_shape; + layout.set_skip_redistribution(true); + TensorInfo tensor_info_in(layout, shape, slice_shape); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_in); + MS_LOG(DEBUG) << name() << "skip redistribution reshape InferTensorInfo"; + return SUCCESS; + } + Shapes inputs_slice_shape, outputs_slice_shape; Strategys inputs_strategy = strategy_->GetInputDim(); Strategys outputs_strategy = GetOutputsStrategy(); @@ -303,7 +325,7 @@ void ReshapeInfo::device_number(const StrategyPtr &strategy) { } Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { - std::vector tensor_map_index; + Shape tensor_map_index; for (size_t i = 0; i < shape.size(); i++) { tensor_map_index.push_back(MAP_NONE); } @@ -316,6 +338,16 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l } Status ReshapeInfo::Init(const StrategyPtr &strategy) { + auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION); + if (reshape_skip_redis_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second); + if (!reshape_skip_redis_iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool."; + return FAILED; + } + is_skip_ = reshape_skip_redis_iter->second->cast()->value(); + } + ResetQueueMember(); device_number(strategy); if (strategy) { @@ -472,7 +504,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector stra_inputs = {stra}; + Strategys stra_inputs = {stra}; StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); if (next_stra_costs.empty()) { if (Init(nullptr) == FAILED) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h index 2463b440f8..91e4382726 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ #include @@ -97,11 +97,12 @@ class ReshapeInfo : public OperatorInfo { TensorLayout output_layout_; bool input_layout_set_flag_; bool output_layout_set_flag_; - bool is_generating_costs_; + bool is_generating_costs_ = false; + bool is_skip_ = false; std::string pre_operator_name_; std::string next_operator_name_; }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc new file mode 100644 index 0000000000..cda37cdf20 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/strided_slice_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace parallel { +Status StridedSliceInfo::GetMask(const std::string &mask_name, int32_t *mask_value) { + if (mask_value == nullptr) { + return FAILED; + } + auto mask_iter = attrs_.find(mask_name); + if (mask_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(mask_iter->second); + if (mask_iter->second->isa()) { + *mask_value = mask_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << ": The value of " << mask_name << " is not int"; + return FAILED; + } + } + return SUCCESS; +} + +Status GetInput(const ValuePtr &input_value, std::vector *input) { + MS_EXCEPTION_IF_NULL(input_value); + ValueTuplePtr value_tuple = input_value->cast(); + if (value_tuple == nullptr) { + MS_LOG(ERROR) << "Input value must be ValueTuplePtr."; + return FAILED; + } + + for (auto &element : value_tuple->value()) { + MS_EXCEPTION_IF_NULL(element); + if (element->isa()) { + int32_t value = element->cast()->value(); + input->push_back(value); + } else { + MS_LOG(ERROR) << "The value must be int32"; + return FAILED; + } + } + + return SUCCESS; +} + +Status StridedSliceInfo::GetAttrs() { + if (attrs_.size() < STRIDED_SLICE_ATTRS_SIZE) { + MS_LOG(ERROR) << name_ << ": The size of attrs small than " << STRIDED_SLICE_ATTRS_SIZE; + return FAILED; + } + + if ((GetMask(BEGIN_MASK, &begin_mask_) != SUCCESS) || (GetMask(END_MASK, &end_mask_) != SUCCESS) || + (GetMask(ELLIPSIS_MASK, &ellipsis_mask_) != SUCCESS) || (GetMask(NEW_AXIS_MASK, &new_axis_mask_) != SUCCESS) || + (GetMask(SHRINK_AXIS_MASK, &shrink_axis_mask_) != SUCCESS)) { + return FAILED; + } + has_mask_ = ((begin_mask_ != 0) || (end_mask_ != 0) || (ellipsis_mask_ != 0) || (new_axis_mask_ != 0) || + (shrink_axis_mask_ != 0)); + + if (input_value_.size() != STRIDED_SLICE_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": The size of input value must be " << STRIDED_SLICE_INPUTS_SIZE << ", but got " + << input_value_.size(); + return FAILED; + } + + if ((GetInput(input_value_[STRIDED_SLICE_BEGIN_INDEX], &begin_) != SUCCESS) || + (GetInput(input_value_[STRIDED_SLICE_END_INDEX], &end_) != SUCCESS) || + (GetInput(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS)) { + return FAILED; + } + + return SUCCESS; +} + +Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is empty"; + return FAILED; + } + + Dimensions strategy_value = stra[0]; + bool has_split = std::any_of(strategy_value.begin(), strategy_value.end(), [](int32_t v) { return v > 1; }); + + if (has_split && has_mask_) { + MS_LOG(ERROR) << name_ << ": When there is a mask, the input is not supported to be split"; + return FAILED; + } + + if (strategy_value.size() < strides_.size()) { + MS_LOG(ERROR) << name_ << ": The size of strategy must be larger or equal to the size of strides"; + return FAILED; + } + for (size_t i = 0; i < strides_.size(); ++i) { + if ((strides_[i] != 1) && (strategy_value[i] > 1)) { + MS_LOG(ERROR) << name_ << ": When a certain dimension is split, now does not support that the stride is not 1"; + return FAILED; + } + } + + if ((begin_.size() != end_.size()) || (begin_.size() != strides_.size())) { + MS_LOG(ERROR) << name_ << ": The size of begin " << begin_.size() << ", end " << end_.size() << " and strides " + << strides_.size() << " must be equal"; + return FAILED; + } + + for (size_t i = 0; i < begin_.size(); ++i) { + bool no_fully_fetch = ((begin_[i] != 0) || (end_[i] < inputs_shape_[0][i])); + if (no_fully_fetch && (strategy_value[i] != 1)) { + MS_LOG(ERROR) << name_ << "When a dimension is not fully fetched, the dimension can not be split now"; + return FAILED; + } + } + + return SUCCESS; +} + +Status StridedSliceInfo::InferDevMatrixShape() { + MS_EXCEPTION_IF_NULL(strategy_); + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << "The strategy is empty"; + return FAILED; + } + + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +Status StridedSliceInfo::InferTensorMap() { + TensorMap tensor_map; + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << "The inputs shape is empty"; + return FAILED; + } + + // cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices. + int32_t size = SizeToInt(inputs_shape_[0].size()); + for (int i = 0; i < size; ++i) { + tensor_map.push_back(size - i - 1); + } + + inputs_tensor_map_.push_back(tensor_map); + outputs_tensor_map_.push_back(tensor_map); + return SUCCESS; +} + +Status StridedSliceInfo::InferMirrorOps() { + mirror_ops_.clear(); + if (inputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty"; + return FAILED; + } + Shape input_tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group for input failed."; + return FAILED; + } + + if (group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror group is empty."; + return SUCCESS; + } + + OperatorVector input_op, begin_op, end_op, strides_op; + input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(input_op); + mirror_ops_.push_back(begin_op); + mirror_ops_.push_back(end_op); + mirror_ops_.push_back(strides_op); + return SUCCESS; +} + +Status StridedSliceInfo::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": Invalid args"; + return FAILED; + } + // infer tensor layout + TensorLayout input_layout, output_layout; + if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed."; + return FAILED; + } + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; + return FAILED; + } + + TensorInfo input_tensor_info(input_layout); + TensorInfo output_tensor_info(output_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +// Note: if the batch dimension is not fully fetched, the batch strategy may not work. +std::shared_ptr StridedSliceInfo::GenerateBatchStrategies() { + split_flag_list_ = {true}; + return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); +} + +Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + return FAILED; + } + + return SUCCESS; +} + +Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) { + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer attrs failed"; + return FAILED; + } + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + Shape input_split(inputs_shape_[0].size(), 1); + if (has_mask_) { + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + input_split[i] = 0; + } + } else { + for (size_t i = 0; i < begin_.size(); ++i) { + bool no_fully_fetch = ((begin_[i] != 0) || (end_[i] < inputs_shape_[0][i])); + if (no_fully_fetch || (strides_[i] != 1)) { + input_split[i] = 0; + } + } + } + Shapes splittable_inputs = {input_split}; + + std::vector sp_vector; + is_auto_parallel_ = true; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status StridedSliceInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status StridedSliceInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h new file mode 100644 index 0000000000..3225308bf2 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_ + +#include + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class StridedSliceInfo : public OperatorInfo { + public: + StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~StridedSliceInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + std::shared_ptr GenerateBatchStrategies() override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetMask(const std::string &mask_name, int32_t *mask_value); + + private: + std::vector begin_; + std::vector end_; + std::vector strides_; + int32_t begin_mask_ = 0; + int32_t end_mask_ = 0; + int32_t ellipsis_mask_ = 0; + int32_t new_axis_mask_ = 0; + int32_t shrink_axis_mask_ = 0; + bool has_mask_ = false; +}; + +using StridedSliceInfoPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc new file mode 100644 index 0000000000..6949c7e0c1 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc @@ -0,0 +1,255 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/tile_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace parallel { +// get the multiples +Status TileInfo::GetAttrs() { + if (input_value_.size() < 2) { + MS_LOG(ERROR) << name_ << ": The size of input value is smaller than 2."; + return FAILED; + } + if (input_value_[1] == nullptr) { + MS_LOG(ERROR) << name_ << ": The multiples is null."; + return FAILED; + } + + std::vector elements; + ValueTuplePtr multiples = input_value_[1]->cast(); + if (multiples == nullptr) { + MS_LOG(ERROR) << name_ << ": Input_value_[1] must be ValueTuplePtr."; + return FAILED; + } + elements = multiples->value(); + if (elements.size() != outputs_shape_[0].size()) { + MS_LOG(ERROR) << name_ << ": Elements size must equal to outputs shape[0] size."; + return FAILED; + } + + for (auto &element : elements) { + MS_EXCEPTION_IF_NULL(element); + if (element->isa()) { + int64_t axis = static_cast(element->cast()->value()); + full_multiples_.push_back(axis); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; + return FAILED; + } + } + + return SUCCESS; +} + +Status TileInfo::CheckStrategy(const StrategyPtr &strategy) { + Shapes multiples = {full_multiples_}; + if (CheckStrategyValue(strategy, multiples, is_auto_parallel_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + return FAILED; + } + + return SUCCESS; +} + +Status TileInfo::InferDevMatrixShape() { + MS_EXCEPTION_IF_NULL(strategy_); + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << "The strategy is empty"; + return FAILED; + } + if (full_multiples_.size() != stra[0].size()) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + + dev_matrix_shape_ = stra[0]; + + slice_multiples_ = full_multiples_; + for (size_t i = 0; i < full_multiples_.size(); ++i) { + slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i]; + } + return SUCCESS; +} + +Status TileInfo::InferTensorMap() { + TensorMap input_tensor_map; + TensorMap output_tensor_map; + if (inputs_shape_.empty() || outputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty"; + return FAILED; + } + + // the input tensor cannot be split + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + input_tensor_map.push_back(MAP_NONE); + } + + // cannot use dev_matrix_shape_ replace outputs_shape_[0], because it may not be fully split in all devices. + int32_t size = SizeToInt(outputs_shape_[0].size()); + for (int i = 0; i < size; ++i) { + output_tensor_map.push_back(size - i - 1); + } + + inputs_tensor_map_.push_back(input_tensor_map); + outputs_tensor_map_.push_back(output_tensor_map); + return SUCCESS; +} + +Status TileInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group for input failed."; + return FAILED; + } + + if (group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror group is empty."; + return SUCCESS; + } + + OperatorVector input_op, multiples_op; + input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(input_op); + mirror_ops_.push_back(multiples_op); + return SUCCESS; +} + +Status TileInfo::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": Invalid args"; + return FAILED; + } + // infer tensor layout + TensorLayout input_layout, output_layout; + if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed."; + return FAILED; + } + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; + return FAILED; + } + + TensorInfo input_tensor_info(input_layout); + TensorInfo output_tensor_info(output_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +void TileInfo::UpdateMultiples(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != 3) { + MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3"; + } + + if (!IsValueNode(cnode->input(2))) { + MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple."; + } + + auto func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + std::vector slice_multiples_int; + (void)std::transform(slice_multiples_.begin(), slice_multiples_.end(), std::back_inserter(slice_multiples_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr new_multiples = MakeValue(slice_multiples_int); + AnfNodePtr val = NewValueNode(new_multiples); + (void)manager->Replace(cnode->input(2), val); +} + +std::shared_ptr TileInfo::GenerateBatchStrategies() { + if (InferAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed"; + } + Shapes multiples_shape = {full_multiples_}; + split_flag_list_ = {true}; + return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_); +} + +Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + return FAILED; + } + + return SUCCESS; +} + +Status TileInfo::GenerateStrategies(int32_t stage_id) { + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer attrs failed"; + return FAILED; + } + Shape multiples_split(full_multiples_.size(), 1); + Shapes splittable_inputs = {multiples_split}; + + std::vector sp_vector; + is_auto_parallel_ = true; + Shapes tmp_inputs_shape = {full_multiples_}; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status TileInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status TileInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h new file mode 100644 index 0000000000..5ff576e308 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TILE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TILE_INFO_H_ + +#include + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class TileInfo : public OperatorInfo { + public: + TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TileInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + std::shared_ptr GenerateBatchStrategies() override; + void UpdateMultiples(const CNodePtr &cnode); + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + + private: + std::vector full_multiples_; + std::vector slice_multiples_; +}; + +using TileInfoPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TILE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc index ed6eaa89f1..90b0c7e94b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc @@ -37,18 +37,18 @@ Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &st } Status TmpIdentityInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions input_strategy = stra.at(0); dev_matrix_shape_ = input_strategy; return SUCCESS; } Status TmpIdentityInfo::InferTensorMap() { - std::vector tensor_map_index; + Shape tensor_map_index; size_t size = inputs_shape_[0].size(); // such as 4: tensor_map_index [3,2,1,0] for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - 1 - i)); + tensor_map_index.push_back((int64_t)(size - 1 - i)); } inputs_tensor_map_.push_back(tensor_map_index); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h index 7f73f81180..474e18db20 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ #include #include @@ -55,4 +55,4 @@ class TmpIdentityInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc index b6bb875abc..ebcebbc66c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc @@ -41,7 +41,7 @@ Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { } Status TransposeInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); input_strategy_ = stra.at(0); for (auto &iter : input_strategy_) { dev_matrix_shape_.push_back(iter); @@ -105,13 +105,13 @@ Status TransposeInfo::InferTensorMap() { return FAILED; } - std::vector tensor_map_index_input; + Shape tensor_map_index_input; for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1)); } inputs_tensor_map_.push_back(tensor_map_index_input); - std::vector tensor_map_index_output = tensor_map_index_input; + Shape tensor_map_index_output = tensor_map_index_input; for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) { tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])]; } @@ -122,7 +122,7 @@ Status TransposeInfo::InferTensorMap() { // the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v Strategys TransposeInfo::GetOutputsStrategy() { Strategys outputs_strategy; - std::vector strategy = input_strategy_; + Dimensions strategy = input_strategy_; for (uint32_t i = 0; i < strategy.size(); i++) { strategy[i] = input_strategy_[IntToUint(axis_v_[i])]; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h index d3b62dc234..0d017c1821 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ #include #include @@ -61,4 +61,4 @@ class TransposeInfo : public OperatorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc index 3b89d7c84c..5e3676fe05 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc @@ -38,7 +38,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - std::vector stra = strategy->GetInputDim(); + Strategys stra = strategy->GetInputDim(); if (stra.size() < 1) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1."; @@ -80,12 +80,12 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { } Status VirtualDatasetInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); + Strategys stra = strategy_->GetInputDim(); Dimensions strategy_first = stra.at(0); int32_t stage = strategy_->GetInputStage(); CheckGlobalDeviceManager(); int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - int32_t batch_split_num = strategy_first.at(0); + int32_t batch_split_num = ((int32_t)(strategy_first.at(0))); dev_matrix_shape_.push_back(batch_split_num); if (dev_num > batch_split_num) { dev_matrix_shape_.push_back(dev_num / batch_split_num); @@ -103,11 +103,11 @@ Status VirtualDatasetInfo::InferTensorMap() { bool full_batch = ParallelContext::GetInstance()->full_batch(); for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - std::vector tensor_map_index; + Shape tensor_map_index; if (full_batch) { tensor_map_index.push_back(MAP_NONE); } else { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); + tensor_map_index.push_back((int64_t)(LAST_INDEX(dev_matrix_shape_.size()))); } for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { tensor_map_index.push_back(MAP_NONE); @@ -193,7 +193,7 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); } StrategyPtr sp; - std::vector strategy; + Strategys strategy; for (auto &shape : inputs_shape_) { Shape temp; temp.emplace_back(SizeToInt(total_dev_num)); diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h index 5e136c816f..b0d557dc1f 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/common.h +++ b/mindspore/ccsrc/frontend/parallel/ps/common.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_COMMON_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_COMMON_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_COMMON_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_COMMON_H_ #include #include @@ -57,14 +57,22 @@ constexpr char kMomentum[] = "momentum"; constexpr char kApplyMomentum[] = "ApplyMomentum"; constexpr char kSparseAdam[] = "Adam"; constexpr char kSparseFtrl[] = "Ftrl"; +constexpr char kApplyMomentumOp[] = "Momentum"; +constexpr char kSparseAdamOp[] = "Adam"; +constexpr char kSparseFtrlOp[] = "FTRL"; constexpr int kInitWeightsCmd = 10; constexpr int kInitWeightToOptimIdCmd = 11; constexpr int kInitOptimInputsShapeCmd = 12; +constexpr int kInitKeyToPushNodeIdCmd = 13; constexpr int kInitEmbeddingsCmd = 20; +constexpr int kCheckReadyForPushCmd = 25; +constexpr int kCheckReadyForPullCmd = 26; constexpr int kEmbeddingLookupCmd = 30; +constexpr int kFinalizeCmd = 40; constexpr size_t kInvalidKey = UINT64_MAX; +constexpr int kInvalidID = -1; using Key = ::ps::Key; using Keys = ::ps::SArray; @@ -72,16 +80,13 @@ using Values = ::ps::SArray; using ValuesPtr = std::shared_ptr; using Weight = ::ps::SArray; using Grad = ::ps::SArray; -using LookupIds = ::ps::SArray; +using LookupIds = ::ps::SArray; using Lengths = ::ps::SArray; using WeightPtr = std::shared_ptr; using GradPtr = std::shared_ptr; -// using EmbeddingTable = std::unordered_map; -// using EmbeddingTable = ::ps::SArray; -// using EmbeddingTablePtr = std::shared_ptr; using InputsShape = std::vector>>; using InputsShapePtr = std::shared_ptr>>>; } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_COMMON_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_COMMON_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index e16c713e3c..5f25b79c23 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -57,6 +57,18 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { } } +void DenseOptimInfo::ComputeMean(size_t n) { + if (n > 1) { + float *accum_grad_data = reinterpret_cast(gradient()->addr); + size_t size = gradient()->size / sizeof(float); + for (size_t i = 0; i < size; i++) { + accum_grad_data[i] /= n; + } + } +} + +void DenseOptimInfo::Reset() { memset_s(gradient()->addr, gradient()->size, 0x00, gradient()->size); } + void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { // Append grad data to the end float *accum_grad_data = reinterpret_cast(gradient()->addr); @@ -118,11 +130,13 @@ const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } +size_t MomentumOptimInfo::grad_index() { return 1; } + SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power, const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1, const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, - const AddressPtr &indices, size_t grads_offset, size_t indices_offset) { + const AddressPtr &indices) { inputs_.push_back(weight); inputs_.push_back(m); inputs_.push_back(v); @@ -134,8 +148,8 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address inputs_.push_back(epsilon); inputs_.push_back(grad); inputs_.push_back(indices); - grads_offset_ = grads_offset; - indices_offset_ = indices_offset; + grads_offset_ = 0; + indices_offset_ = 0; } void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { @@ -159,15 +173,14 @@ size_t SparseAdamOptimInfo::grad_index() { return 6; } size_t SparseAdamOptimInfo::indices_index() { return 7; } SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, - const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, - size_t indices_offset) { + const AddressPtr &grad, const AddressPtr &indices) { inputs_.push_back(weight); inputs_.push_back(accum); inputs_.push_back(linear); inputs_.push_back(grad); inputs_.push_back(indices); - grads_offset_ = grads_offset; - indices_offset_ = indices_offset; + grads_offset_ = 0; + indices_offset_ = 0; } const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; } diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h index bb9a64acdb..dc567e023c 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_H_ #include #include "backend/kernel_compiler/kernel.h" @@ -33,6 +33,7 @@ class OptimizerInfo { virtual void Update(const Values &values, const Lengths &lengths) {} virtual void UpdateWeight(const WeightPtr &weight); virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; + virtual void ComputeMean(size_t n) {} virtual void Reset() {} void AddWorkspace(const AddressPtr &workspace); @@ -58,6 +59,8 @@ class DenseOptimInfo : public OptimizerInfo { ~DenseOptimInfo() override = default; void Accumulate(const Values &values, const Lengths &lens) override; + void ComputeMean(size_t n) override; + void Reset() override; }; class SparseOptimInfo : public OptimizerInfo { @@ -81,6 +84,7 @@ class MomentumOptimInfo : public DenseOptimInfo { const AddressPtr &gradient(); const AddressPtr &indices(); + size_t grad_index() override; }; class SparseAdamOptimInfo : public SparseOptimInfo { @@ -88,7 +92,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo { SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power, const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1, const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, - const AddressPtr &indices, size_t grads_offset, size_t indices_offset); + const AddressPtr &indices); ~SparseAdamOptimInfo() override = default; void Update(const Values &values, const Lengths &lens) override; @@ -102,7 +106,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo { class SparseFtrlOptimInfo : public SparseOptimInfo { public: SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, - const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, size_t indices_offset); + const AddressPtr &grad, const AddressPtr &indices); ~SparseFtrlOptimInfo() override = default; const AddressPtr &gradient(); @@ -114,4 +118,4 @@ class SparseFtrlOptimInfo : public SparseOptimInfo { } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc index 159a50793e..e1d5ffb32a 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ -48,20 +48,26 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co size_t worker_num) { AddressPtr weight_addr = std::make_shared(); weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); + weight_addr->size = weight->size() * sizeof(float); void *data_ptr = values.data(); + void *copy_data_ptr = new float[values.size()]; + auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } AddressPtr accumulate = std::make_shared(); accumulate->addr = new float[weight->size()]; - accumulate->size = weight->size(); + accumulate->size = weight->size() * sizeof(float); + memset_s(accumulate->addr, accumulate->size, 0x00, accumulate->size); AddressPtr learning_rate = std::make_shared(); - learning_rate->addr = data_ptr; - learning_rate->size = lens[0]; + learning_rate->addr = copy_data_ptr; + learning_rate->size = lens[0] * sizeof(float); AddressPtr gradient = std::make_shared(); gradient->addr = reinterpret_cast(learning_rate->addr) + lens[0]; - gradient->size = lens[1]; + gradient->size = lens[1] * sizeof(float); AddressPtr momentum = std::make_shared(); momentum->addr = reinterpret_cast(gradient->addr) + lens[1]; - momentum->size = lens[2]; + momentum->size = lens[2] * sizeof(float); return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum); } @@ -71,17 +77,25 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, size_t worker_num) { AddressPtr weight_addr = std::make_shared(); weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); + weight_addr->size = weight->size() * sizeof(float); AddressPtr m = std::make_shared(); m->addr = new float[weight->size()]; m->size = weight->size() * sizeof(float); + int ret = memset_s(m->addr, m->size, 0x00, m->size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } AddressPtr v = std::make_shared(); v->addr = new float[weight->size()]; v->size = weight->size() * sizeof(float); + ret = memset_s(v->addr, v->size, 0x00, v->size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } void *data_ptr = values.data(); void *copy_data_ptr = new float[values.size()]; - auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); + ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } @@ -114,10 +128,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies()); AddressPtr grad = std::make_shared(); grad->addr = new float[total_grad_size * worker_num]; - auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast(epsilon->addr) + lens[5], - lens[6] * sizeof(float)); - if (ret2 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + ret = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast(epsilon->addr) + lens[5], + lens[6] * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } grad->size = lens[6] * sizeof(float); @@ -126,15 +140,15 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies()); AddressPtr indices = std::make_shared(); indices->addr = new float[total_indice_size * worker_num]; - auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float), - reinterpret_cast(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float)); - if (ret3 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")"; + ret = memcpy_s(indices->addr, lens[7] * sizeof(float), reinterpret_cast(epsilon->addr) + lens[5] + lens[6], + lens[7] * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } - indices->size = lens[7] * sizeof(float); + indices->size = lens[7] * sizeof(int); return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, - grad, indices, total_grad_size, total_indice_size); + grad, indices); } OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, @@ -142,7 +156,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, size_t worker_num) { AddressPtr weight_addr = std::make_shared(); weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); + weight_addr->size = weight->size() * sizeof(float); AddressPtr accum = std::make_shared(); accum->addr = new float[weight->size()]; accum->size = weight->size() * sizeof(float); @@ -152,14 +166,17 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, } AddressPtr linear = std::make_shared(); linear->addr = new float[weight->size()]; - memcpy_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); + int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")"; + } linear->size = weight->size() * sizeof(float); const std::shared_ptr> &grad_shape = (*inputs_shape)[3]; size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies()); AddressPtr grad = std::make_shared(); grad->addr = new float[total_grad_size * worker_num]; - auto ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); + ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } @@ -170,14 +187,14 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies()); AddressPtr indices = std::make_shared(); indices->addr = new float[total_indice_size * worker_num]; - auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], - lens[1] * sizeof(float)); - if (ret2 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + ret = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], + lens[1] * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } - indices->size = lens[1] * sizeof(float); + indices->size = lens[1] * sizeof(int); - return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, total_grad_size, total_indice_size); + return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices); } } // namespace ps } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h index c5aae32921..03b842ed33 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ #include #include #include "backend/kernel_compiler/kernel.h" -#include "backend/kernel_compiler/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" #include "frontend/parallel/ps/optimizer_info.h" namespace mindspore { @@ -63,4 +63,4 @@ class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder { } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h old mode 100755 new mode 100644 index 1afb4c9fa6..6eedbd76b3 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_PARAMETER_SERVER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_PARAMETER_SERVER_H_ #include #include @@ -28,9 +28,9 @@ #include #include #include +#include #include "ir/func_graph.h" #include "backend/session/session_basic.h" -#include "backend/session/kernel_graph.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/session_factory.h" #include "frontend/parallel/ps/common.h" @@ -38,14 +38,14 @@ #include "frontend/parallel/ps/optimizer_info_builder.h" #include "frontend/parallel/ps/util.h" #include "runtime/device/cpu/kernel_select_cpu.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/kernel_compiler/kernel.h" -#include "backend/kernel_compiler/ps/pserver_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" -#include "backend/kernel_compiler/ps/sparse_apply_adam_ps_kernel.h" -#include "backend/kernel_compiler/ps/sparse_apply_ftrl_ps_kernel.h" -#include "backend/kernel_compiler/ps/apply_momentum_ps_kernel.h" -#include "backend/kernel_compiler/ps/embedding_look_up_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" namespace mindspore { namespace parallel { @@ -70,24 +70,39 @@ class ParameterServer { ps_(new ::ps::KVServer(0)), handler_(nullptr), func_graph_(nullptr), - kernel_graph_(nullptr), sess_(nullptr), + running_(true), thread_(nullptr) {} ~ParameterServer() = default; ParameterServer(const ParameterServer &) = delete; ParameterServer &operator=(const ParameterServer &) = delete; - struct ServerHandler { + class ServerHandler { + public: explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} + void Init(); void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); - void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data); + + private: + void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleInitWeights(const ::ps::KVPairs &req_data); - void HandleInitWeightToOptimId(const ::ps::KVPairs &req_data); - void HandleInitInputsShape(const ::ps::KVPairs &req_data); - void HandleInitEmbeddings(const ::ps::KVPairs &req_data); + void HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res); + void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + ParameterServer *ps_; + typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res); + std::unordered_map handlers_; + std::unordered_map init_weights_; + std::unordered_map init_weight_to_optim_; + std::unordered_map init_optim_info_; }; bool Init(const FuncGraphPtr &func_graph); @@ -98,15 +113,18 @@ class ParameterServer { void InitGrad(const Key &key, const GradPtr &grad); void InitEmbeddingTable(const Key &key, const std::shared_ptr>>> &shapes); + void Finalize(); void UpdateWeights(); void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); WeightPtr weight(const Key &key); void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); int SumOfShapes(const std::vector &shapes) const; - size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); bool ReadyForUpdateWeights(); - bool ReadyForAccumGrads(); + bool ReadyForPush(const Key &key); + bool ReadyForPull(const Key &key); void ResetGradAccumCount(); + const CNodePtr GetCNode(const std::string &name) const; + std::mutex &mutex(); size_t pserver_num_; size_t worker_num_; @@ -115,31 +133,28 @@ class ParameterServer { std::unique_ptr<::ps::KVServer> ps_; std::unique_ptr handler_; FuncGraphPtr func_graph_; - std::shared_ptr kernel_graph_; std::shared_ptr sess_; + bool running_; - std::unordered_map> optimizers_; + std::unordered_map> optimizers_; std::unordered_map optim_inputs_shape_; std::unordered_map> optim_infos_; std::unordered_map> optim_info_builders_; std::unordered_map weight_key_to_optims_; + std::unordered_map weight_key_to_optim_op_; std::unordered_map weights_; + std::unordered_map is_embedding_; std::unordered_map grads_; std::unordered_map grads_accum_counter_; - // std::unordered_map embeddings_; std::unordered_map> embedding_lookup_ops_; - std::unordered_map embedding_row_lens_; - - T learning_rate_; - T momentum_; + std::unordered_map tokens_; std::mutex mutex_; std::condition_variable apply_grads_cv_; - std::condition_variable accum_grads_cv_; std::unique_ptr thread_; - friend struct ServerHandler; + friend class ServerHandler; }; class FuncGraph; @@ -147,33 +162,32 @@ template void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server) { ::ps::KVPairs res; - if (req_meta.cmd == kInitWeightsCmd) { - MS_LOG(ERROR) << "handle init weights cmd" << std::endl; - HandleInitWeights(req_data); - } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { - MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; - HandleInitWeightToOptimId(req_data); - } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { - MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; - HandleInitInputsShape(req_data); - } else if (req_meta.cmd == kInitEmbeddingsCmd) { - MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; - HandleInitEmbeddings(req_data); - } else if (req_meta.cmd == kEmbeddingLookupCmd) { - MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; - HandleEmbeddingLookup(req_meta, req_data, &res); + if (handlers_.count(req_meta.cmd) > 0) { + auto &handler_ptr = handlers_[req_meta.cmd]; + (this->*handler_ptr)(req_meta, req_data, &res); } else if (req_meta.push) { - MS_LOG(ERROR) << "handle push req cmd" << std::endl; - HandlePushReq(req_meta, req_data); + HandlePushReq(req_meta, req_data, &res); } else { - MS_LOG(ERROR) << "handle pull req cmd" << std::endl; HandlePullReq(req_meta, req_data, &res); } server->Response(req_meta, res); } template -void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::Init() { + handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights; + handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId; + handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape; + handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings; + handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; + handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; + handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; + handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; +} + +template +void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); } @@ -186,7 +200,9 @@ void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me } template -void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { + std::unique_lock lock(ps_->mutex()); size_t key_num = req_data.keys.size(); T *data_ptr = req_data.vals.data(); size_t pos = 0; @@ -205,22 +221,41 @@ void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs } template -void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + std::unique_lock lock(ps_->mutex()); size_t key_num = req_data.keys.size(); for (size_t i = 0; i < key_num; i++) { Key key = req_data.keys[i]; T val = req_data.vals[i]; + if (init_weight_to_optim_[key]) { + continue; + } else { + init_weight_to_optim_[key] = true; + } ps_->InitWeightKeyToOptims(key, val); } } template -void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { + std::unique_lock lock(ps_->mutex()); + const Key &key = req_data.keys[0]; + if (init_optim_info_[key]) { + return; + } else { + init_optim_info_[key] = true; + } ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); } template -void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs &req_data) { +void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { + std::unique_lock lock(ps_->mutex()); + const Key &key = req_data.keys[0]; std::shared_ptr>>> shapes = std::make_shared>>>(); std::shared_ptr> input_shape = std::make_shared>(); @@ -230,7 +265,6 @@ void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs shapes->push_back(indices_shape); shapes->push_back(output_shape); - const Key &key = req_data.keys[0]; const Lengths &lens = req_data.lens; size_t index = 0; for (int i = 0; i < lens[0]; i++) { @@ -245,32 +279,52 @@ void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs ps_->InitEmbeddingTable(key, shapes); } +template +void ParameterServer::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + const Key &key = req_data.keys[0]; + bool ready = ps_->ReadyForPush(key); + res->keys.push_back(key); + res->vals.push_back(ready); +} + +template +void ParameterServer::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + const Key &key = req_data.keys[0]; + bool ready = ps_->ReadyForPull(key); + res->keys.push_back(key); + res->vals.push_back(ready); +} + template void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { const Key &key = req_data.keys[0]; - ps_->DoEmbeddingLookup(key, req_data.vals, res); - for (size_t i = 0; i < req_data.vals.size(); i++) { - res->keys->push_back(req_data.vals[i]); + for (size_t i = 0; i < req_data.keys.size(); i++) { + res->keys.push_back(req_data.keys[i]); } + ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); +} + +template +void ParameterServer::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + ps_->Finalize(); } template bool ParameterServer::Init(const FuncGraphPtr &func_graph) { - const char *server_num = getenv(kEnvPServerNum); - const char *worker_num = getenv(kEnvWorkerNum); - if (server_num != nullptr) { - pserver_num_ = *server_num - '0'; - } - if (worker_num != nullptr) { - worker_num_ = *worker_num - '0'; - } + pserver_num_ = ::ps::NumServers(); + worker_num_ = ::ps::NumWorkers(); func_graph_ = func_graph; rank_id_ = ::ps::MyRank(); handler_.reset(new ServerHandler(this)); + handler_->Init(); InitOptimInfoBuilders(); - ps_->set_request_handle(*handler_); thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); return true; @@ -288,10 +342,11 @@ void ParameterServer::InitOptimInfoBuilders() { template void ParameterServer::InitWeightKeyToOptims(const Key &key, const int &optim_id) { - if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") { + if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") { return; } weight_key_to_optims_[key] = Util::optimizer_name(optim_id); + weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id); } template @@ -314,31 +369,49 @@ void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &va } if (weight_key_to_optims_.count(key) > 0) { const std::string &optim_name = weight_key_to_optims_[key]; - if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) { + const std::string &optim_op_name = weight_key_to_optim_op_[key]; + if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) { + const CNodePtr cnode = GetCNode(optim_op_name); + MS_EXCEPTION_IF_NULL(cnode); if (optim_name == kSparseAdam) { std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); - optimizer->InitKernel(optim_inputs_shape_[key]); - optimizers_[optim_name] = optimizer; + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(cnode, optim_inputs_shape_[key]); + optimizers_[key] = optimizer; } else if (optim_name == kApplyMomentum) { std::shared_ptr optimizer = std::make_shared(rank_id_, pserver_num_); - optimizer->InitKernel(optim_inputs_shape_[key]); - optimizers_[optim_name] = optimizer; + optimizer->InitKernel(cnode, optim_inputs_shape_[key]); + optimizers_[key] = optimizer; } else if (optim_name == kSparseFtrl) { std::shared_ptr optimizer = std::make_shared(rank_id_, pserver_num_); - optimizer->InitKernel(optim_inputs_shape_[key]); - optimizers_[optim_name] = optimizer; + optimizer->InitKernel(cnode, optim_inputs_shape_[key]); + optimizers_[key] = optimizer; } } } } +template +const CNodePtr ParameterServer::GetCNode(const std::string &name) const { + std::list cnodes = func_graph_->GetOrderedCnodes(); + for (CNodePtr cnode : cnodes) { + std::string fullname = cnode->fullname_with_scope(); + if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) { + return cnode; + } + } + return nullptr; +} + template void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { - if (weights_.count(key) == 0) { + MS_LOG(INFO) << "Initializing weight for key " << key; + if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) { weights_[key] = weight; + tokens_[key] = 0; + is_embedding_[key] = false; } } @@ -353,7 +426,7 @@ void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { template void ParameterServer::InitEmbeddingTable( const Key &key, const std::shared_ptr>>> &shapes) { - // Init embedding lookup kernel + MS_LOG(INFO) << "Initializing embedding table for key " << key; std::shared_ptr lookup = std::make_shared(rank_id_, pserver_num_); lookup->InitKernel(shapes); embedding_lookup_ops_[key] = lookup; @@ -364,17 +437,35 @@ void ParameterServer::InitEmbeddingTable( for (auto shape : input_shapes) { total_dims *= shape; } - WeightPtr embedding = std::make_shared(total_dims, 0.01); + + WeightPtr embedding = std::make_shared(total_dims, 0); + T *embedding_data = embedding->data(); + std::default_random_engine engine; + std::normal_distribution random(0, 0.01); + for (size_t i = 0; i < total_dims; i++) { + embedding_data[i] = random(engine); + } weights_[key] = embedding; + tokens_[key] = 0; + is_embedding_[key] = true; grads_accum_counter_[key] = 0; } +template +void ParameterServer::Finalize() { + running_ = false; + apply_grads_cv_.notify_one(); +} + template void ParameterServer::UpdateWeights() { while (true) { std::unique_lock lock(mutex_); - apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); + apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; }); + if (!running_) { + break; + } for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { Key key = iter->first; @@ -382,8 +473,7 @@ void ParameterServer::UpdateWeights() { std::shared_ptr optimizer = nullptr; if (weight_key_to_optims_.count(key) > 0) { - const std::string &optim_name = weight_key_to_optims_[key]; - optimizer = optimizers_[optim_name]; + optimizer = optimizers_[key]; } MS_EXCEPTION_IF_NULL(optimizer); @@ -391,32 +481,31 @@ void ParameterServer::UpdateWeights() { if (optim_info == nullptr) { continue; } - const WeightPtr &weight = weights_[key]; - optim_info->UpdateWeight(weight); const std::vector &inputs = optim_info->inputs(); const std::vector &workspaces = optim_info->workspaces(); const std::vector &outputs = optim_info->outputs(); + optim_info->ComputeMean(worker_num_); optimizer->Execute(inputs, workspaces, outputs); optim_info->Reset(); + if (!is_embedding_[key]) { + tokens_[key] = worker_num_; + } } ResetGradAccumCount(); - accum_grads_cv_.notify_all(); } } template void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { std::unique_lock lock(mutex_); - accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); }); - const Key &key = keys[0]; std::shared_ptr optim_info = optim_infos_[key]; // Create or update the optimizer info if (optim_info == nullptr) { const std::shared_ptr &builder = optim_info_builders_[weight_key_to_optims_[key]]; - std::shared_ptr pserver_kernel = optimizers_[weight_key_to_optims_[key]]; + std::shared_ptr pserver_kernel = optimizers_[key]; if (pserver_kernel == nullptr) { MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; } @@ -427,10 +516,8 @@ void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const optim_infos_[key] = optim_info; } else { optim_info->Update(values, lengths); + optim_info->Accumulate(values, lengths); } - MS_EXCEPTION_IF_NULL(optim_info); - - optim_info->Accumulate(values, lengths); grads_accum_counter_[key] += 1; if (grads_accum_counter_[key] == worker_num_) { @@ -444,14 +531,13 @@ void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const template WeightPtr ParameterServer::weight(const Key &key) { std::unique_lock lock(mutex_); - if (weights_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid weight key " << key; - return nullptr; + MS_LOG(EXCEPTION) << "Invalid weight key " << key; } WeightPtr weight_ptr = weights_[key]; WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray>(weight_ptr->size(), 0); copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size()); + tokens_[key] -= 1; return copy_weight_ptr; } @@ -485,8 +571,13 @@ void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, inputs.push_back(indices); embedding_table->addr = table_ptr->data(); embedding_table->size = table_ptr->size() * sizeof(T); - indices->addr = lookup_ids.data(); - indices->size = lookup_ids.size() * sizeof(T); + + std::unique_ptr tmp_ids(new int[lookup_ids.size()]); + for (size_t i = 0; i < lookup_ids.size(); i++) { + tmp_ids[i] = static_cast(lookup_ids[i]); + } + indices->addr = tmp_ids.get(); + indices->size = lookup_ids.size() * sizeof(int); std::vector workspaces; std::vector outputs; @@ -499,7 +590,7 @@ void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, table_lookup_op->Execute(inputs, workspaces, outputs); res->vals = *addr; - res->lens.push_back(res.vals.size()); + res->lens.push_back(res->vals.size()); } template @@ -512,27 +603,27 @@ int ParameterServer::SumOfShapes(const std::vector &shapes) const { } template -size_t ParameterServer::PreComputeCapacity(const Keys &keys, const Lengths &lens) { - size_t capacity = 0; - for (size_t i = 0; i < keys.size(); i++) { - Key key = keys[i]; - if (embedding_row_lens_.count(key) > 0) { - capacity += embedding_row_lens_[key] * lens[i]; - } else { - MS_LOG(ERROR) << "Invalid embedding lookup id " << key; - } - } - return capacity; +inline bool ParameterServer::ReadyForUpdateWeights() { + return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); } template -inline bool ParameterServer::ReadyForUpdateWeights() { - return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); +inline bool ParameterServer::ReadyForPush(const Key &key) { + std::unique_lock lock(mutex_); + if (weights_.empty()) { + MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send " + "kInitWeightsCmd command. 2.The Server failed to initialize weights."; + } + return grad_accum_count_ < weights_.size() && tokens_[key] <= 0; } template -inline bool ParameterServer::ReadyForAccumGrads() { - return grad_accum_count_ < weights_.size(); +inline bool ParameterServer::ReadyForPull(const Key &key) { + std::unique_lock lock(mutex_); + if (tokens_.count(key) == 0 || weights_[key] == 0) { + MS_LOG(EXCEPTION) << "Invalid weight key " << key; + } + return tokens_[key] > 0; } template @@ -543,6 +634,11 @@ inline void ParameterServer::ResetGradAccumCount() { } } +template +inline std::mutex &ParameterServer::mutex() { + return mutex_; +} + template void ParameterServer::Run(const FuncGraphPtr &func_graph) { ::ps::Start(0); @@ -552,8 +648,10 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { } Init(func_graph); thread_->join(); + ::ps::Finalize(0, true); + exit(1); } } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_PARAMETER_SERVER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc index 274b7259b0..04c259487f 100755 --- a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc @@ -23,9 +23,8 @@ namespace parallel { namespace ps { void Scheduler::Run() { ::ps::Start(0); - while (true) { - sleep(1); - } + ::ps::Finalize(0, true); + exit(1); } } // namespace ps } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.h b/mindspore/ccsrc/frontend/parallel/ps/scheduler.h index e656bcfd22..4ec0b137cf 100755 --- a/mindspore/ccsrc/frontend/parallel/ps/scheduler.h +++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_SCHEDULER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_SCHEDULER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_SCHEDULER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_SCHEDULER_H_ namespace mindspore { namespace parallel { namespace ps { @@ -37,4 +37,4 @@ class Scheduler { } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_SCHEDULER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_SCHEDULER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc index fc63e88901..1bda9c1323 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/util.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc @@ -17,7 +17,7 @@ #include "frontend/parallel/ps/util.h" #include #include "frontend/parallel/ps/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace parallel { @@ -33,6 +33,13 @@ std::unordered_map Util::id_to_optimizers{ {1, kSparseAdam}, {2, kSparseFtrl}, }; + +std::unordered_map Util::id_to_optimizer_nodes{ + {0, kApplyMomentumOp}, + {1, kSparseAdamOp}, + {2, kSparseFtrlOp}, +}; + bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); } bool Util::IsRoleOfWorker() { @@ -112,6 +119,13 @@ std::string Util::optimizer_name(int id) { return ""; } +std::string Util::optimizer_node_name(int id) { + if (id_to_optimizer_nodes.count(id) > 0) { + return id_to_optimizer_nodes[id]; + } + return ""; +} + bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; } int Util::LocalShard(int first_dim, int rank_id, int server_num) { diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h index 8947ad36de..c56ab8264e 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/util.h +++ b/mindspore/ccsrc/frontend/parallel/ps/util.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_UTIL_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_UTIL_H_ #include #include @@ -34,14 +34,16 @@ class Util { static void SetInternalEnvVar(); static int optimizer_id(std::string name); static std::string optimizer_name(int id); + static std::string optimizer_node_name(int id); static bool is_optimizer(std::string name); static int LocalShard(int first_dim, int rank_id, int server_num); private: static std::unordered_map optimizer_to_ids; static std::unordered_map id_to_optimizers; + static std::unordered_map id_to_optimizer_nodes; }; } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_UTIL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index 9ecbc28fc5..4908534cb7 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_H_ #include #include @@ -24,6 +24,7 @@ #include #include "ps/ps.h" #include "utils/log_adapter.h" +#include "ir/tensor.h" #include "frontend/parallel/ps/util.h" #include "frontend/parallel/ps/common.h" #include "frontend/parallel/ps/worker_proxy.h" @@ -43,17 +44,20 @@ class Worker { void Push(const std::vector &keys, std::vector addrs, const std::vector &sizes); void Pull(const size_t key, void *dev_addr, const size_t size); size_t SetParamKey(const std::string ¶m_name); + void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); + bool GetParamInitInServer(const std::string ¶m_name); void SetKeyOptimId(size_t key, const std::string &optimizer_name); void SetOptimInputShapes(size_t key, const std::vector &shape); void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes); - void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size); - void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + void InitPSParamAndOptim(const std::string ¶m_name, tensor::TensorPtr tensor); + void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd); + void Finalize(); private: Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} - ~Worker() { ::ps::Finalize(0, true); } + ~Worker() = default; Worker(const Worker &) = delete; Worker &operator=(const Worker &) = delete; @@ -72,6 +76,7 @@ class Worker { std::map init_keys_; std::map key_to_optimId_; std::map>> key_to_optim_shapes_; + std::map param_to_init_in_server_; }; template @@ -80,7 +85,6 @@ void Worker::Run() { MS_LOG(INFO) << "'Worker is already running."; return; } - ::ps::Start(0); if (!::ps::IsWorker()) { MS_LOG(EXCEPTION) << "The role is not worker."; @@ -98,23 +102,45 @@ void Worker::Push(const std::vector &keys, std::vector add ::ps::SArray total_buffer(total_size, 0); size_t offset = 0; for (size_t i = 0; i < sizes.size(); i++) { - memcpy(total_buffer.data() + offset / sizeof(T), addrs[i], sizes[i] * sizeof(T)); + auto ret = memcpy_s(total_buffer.data() + offset / sizeof(T), sizes[i] * sizeof(T), + reinterpret_cast(addrs[i]), sizes[i] * sizeof(T)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } offset += sizes[i] * sizeof(T); } + while (!kv_worker_->IsReadyForPush(keys[0])) { + continue; + } kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes)); } template void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { ::ps::SArray variables(size / sizeof(T), 0); + while (!kv_worker_->IsReadyForPull(key)) { + continue; + } kv_worker_->Wait(kv_worker_->ZPull({key}, &variables)); - memcpy(dev_addr, variables.data(), size); + auto ret = memcpy_s(dev_addr, size, variables.data(), size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } } template -void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, +void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd) { - kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, &lookup_result, cmd); + kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); +} + +template +void Worker::Finalize() { + if (running_) { + kv_worker_->Finalize(); + kv_worker_.reset(); + running_ = false; + } } template @@ -154,9 +180,9 @@ void Worker::InitPSOptimInputShapes(const size_t key) { } } } - MS_LOG(ERROR) << "keys:" << keys; - MS_LOG(ERROR) << "shape_len:" << shape_len; - MS_LOG(ERROR) << "all_shape:" << all_shape; + MS_LOG(INFO) << "keys:" << keys; + MS_LOG(INFO) << "shape_len:" << shape_len; + MS_LOG(INFO) << "all_shape:" << all_shape; if (!init_keys_[key]) { init_keys_[key] = true; } @@ -185,12 +211,26 @@ size_t Worker::SetParamKey(const std::string ¶m_name) { return key; } +template +void Worker::SetParamInitInServer(const std::string ¶m_name, bool init_in_server) { + MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server; + param_to_init_in_server_[param_name] = init_in_server; +} + +template +bool Worker::GetParamInitInServer(const std::string ¶m_name) { + if (param_to_init_in_server_.count(param_name) == 0) { + return false; + } + return param_to_init_in_server_[param_name]; +} + template size_t Worker::GetParamKey(const std::string ¶m_name) { size_t key = kInvalidKey; if (param_to_key_.find(param_name) != param_to_key_.end()) { key = param_to_key_[param_name]; - MS_LOG(ERROR) << "Get key of parameter " << param_name << " key is " << key; + MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key; } return key; } @@ -230,17 +270,27 @@ void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vecto template // Initialize parameters and optimizer kernels of Parameter Server. -void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size) { +void Worker::InitPSParamAndOptim(const std::string ¶m_name, tensor::TensorPtr tensor) { + void *param_data = tensor->data_c(); + size_t param_size = LongToSize(tensor->data().nbytes()); + std::vector param_shape = tensor->shape_c(); + size_t param_key = GetParamKey(param_name); if (param_key == kInvalidKey) { MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned."; return; } + bool init_in_server = false; + std::vector shape_init_in_server = {1}; + if (param_shape == shape_init_in_server) { + init_in_server = true; + } + SetParamInitInServer(param_name, init_in_server); bool init = IsKeyInit(param_key); if (!init) { - MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name; - // No need to push embedding table data to Parameter Server. - if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) { + MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name + << ", whether init in server: " << init_in_server; + if (!init_in_server) { InitPSParamData({param_key}, param_data, param_size); } InitPSOptimId(param_key); @@ -250,10 +300,14 @@ void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_d template void Worker::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { + bool has_init = IsKeyInit(key); + if (has_init) { + return; + } kv_worker_->AddEmbeddingTable(key, row_count); } } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h index a0f58d39a4..7f73081ab7 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -14,14 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_ #include #include #include #include #include +#include #include "ps/ps.h" #include "frontend/parallel/ps/util.h" @@ -34,42 +35,40 @@ class WorkerProxy : public ::ps::KVWorker { using Worker = ::ps::KVWorker; using Callback = std::function; using SlicedKVs = std::vector>>; - using Slicer = - std::function &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; + using Slicer = std::function &send, const std::vector<::ps::Range> &ranges, + SlicedKVs *sliced)>; using ::ps::SimpleApp::obj_; explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { - using _1 = std::placeholders::_1; - using _2 = std::placeholders::_2; - using _3 = std::placeholders::_3; + using std::placeholders::_1; + using std::placeholders::_2; + using std::placeholders::_3; + using std::placeholders::_4; lookup_customer_ = std::unique_ptr<::ps::Customer>( new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); - lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3); - init_embedding_slicer_ = std::bind(&WorkerProxy::EmbeddingTableInitSlicer, this, _1, _2, _3); - push_slicer_ = std::bind(&WorkerProxy::PushSlicer, this, _1, _2, _3); - broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3); + lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3, _4); + broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3, _4); } ~WorkerProxy() override = default; void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); - void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *outs, int cmd = 0, const Callback &cb = nullptr, int priority = 0); int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int priority = 0); + bool IsReadyForPush(const Key &key); + bool IsReadyForPull(const Key &key); void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, int cmd = 0, int priority = 0); + void Finalize(); private: template - int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, + int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, const Callback &cb); - void LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + void LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced); - void EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); - void PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); - void BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + void BroadcastSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced); void ProcessLookupResult(const ::ps::Message &msg); void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, @@ -80,10 +79,9 @@ class WorkerProxy : public ::ps::KVWorker { std::unordered_map>> lookup_results_; std::mutex mutex_; Slicer lookup_slicer_; - Slicer init_embedding_slicer_; - Slicer push_slicer_; Slicer broadcast_slicer_; std::unordered_map lookup_callbacks_; + std::unordered_map expected_result_count_; }; template @@ -108,17 +106,21 @@ void WorkerProxy::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_c } template -void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, +void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *outs, int cmd, const Callback &cb, int priority) { int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); ::ps::KVPairs kvs; kvs.keys = keys; - kvs.vals = lookup_ids; - kvs.lens = lens; + kvs.lens = lookup_ids; kvs.priority = priority; - Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); + expected_result_count_[ts] = 0; + Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_); + int server_num = ::ps::NumServers(); + int expect_rt_count = expected_result_count_[ts]; + lookup_customer_->AddResponse(ts, server_num - expect_rt_count); lookup_customer_->WaitRequest(ts); + expected_result_count_.erase(ts); } template @@ -130,10 +132,32 @@ int WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons kvs.vals = vals; kvs.lens = lens; kvs.priority = priority; - Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); + Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_); return ts; } +template +bool WorkerProxy::IsReadyForPush(const Key &key) { + ::ps::SArray result(1, 0); + this->Wait(this->ZPull({key}, &result, nullptr, kCheckReadyForPushCmd)); + if (result[0] > 0) { + return true; + } else { + return false; + } +} + +template +bool WorkerProxy::IsReadyForPull(const Key &key) { + ::ps::SArray result(1, 0); + this->Wait(this->ZPull({key}, &result, nullptr, kCheckReadyForPullCmd)); + if (result[0] > 0) { + return true; + } else { + return false; + } +} + template void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens, int cmd, int priority) { @@ -143,13 +167,24 @@ void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::S kvs.vals = vals; kvs.lens = lens; kvs.priority = priority; - Send(obj_, ts, true, false, cmd, kvs, push_slicer_); + Send(obj_, ts, true, false, cmd, kvs, broadcast_slicer_); obj_->WaitRequest(ts); } +template +void WorkerProxy::Finalize() { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys.push_back(0); + kvs.vals.push_back(0.0f); + Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_); + obj_->WaitRequest(ts); + ::ps::Finalize(0, true); +} + template template -int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, +int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *lookup_result, int cmd, const Callback &cb) { int ts = lookup_customer_->NewRequest(::ps::kServerGroup); const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { @@ -157,20 +192,8 @@ int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps: auto &kvs = lookup_results_[ts]; mutex_.unlock(); - size_t total_len = 0; - const auto &s = kvs[0]; - for (size_t i = 0; i < s.lens.size(); i++) { - total_len += s.lens[i]; - } - lookup_result->resize(total_len, 0); - T *result_addr = lookup_result->data(); - - for (const auto &s : kvs) { - size_t offset = 0; - for (size_t i = 0; i < s.vals.size(); i++) { - result_addr[offset++] += s.vals[i]; - } - } + auto &s = kvs[0]; + *lookup_result = s.vals; mutex_.lock(); lookup_results_.erase(ts); @@ -182,68 +205,36 @@ int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps: } template -void WorkerProxy::LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, +void WorkerProxy::LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced) { - int *data = send.lens.data(); - size_t size = send.lens.size(); - std::vector lookup_ids(data, data + size); - std::sort(lookup_ids.begin(), lookup_ids.end()); + int *lookup_ids = send.lens.data(); + size_t id_size = send.lens.size(); const Key &key = send.keys[0]; const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); sliced->resize(ranges.size()); - size_t index = 0; for (size_t i = 0; i < ranges.size(); i++) { - const ::ps::Range &range = ranges[i]; - const auto &begin = range.begin(); - const auto &end = range.end(); auto &kvs = sliced->at(i).second; - auto lookup_id = static_cast(lookup_ids[index]); - while (lookup_id >= begin && lookup_id <= end) { - kvs.vals.push_back(lookup_id); - if (++index >= lookup_ids.size()) { - break; - } - lookup_id = static_cast(lookup_ids[index]); - } kvs.keys.push_back(key); - kvs.lens.push_back(kvs.vals.size()); + kvs.vals.push_back(0.0f); + for (size_t j = 0; j < id_size; j++) { + kvs.keys.push_back(lookup_ids[j]); + kvs.vals.push_back(0.0f); + } - if (kvs.vals.size() == 0) { + if (kvs.keys.size() <= 1) { sliced->at(i).first = false; } else { sliced->at(i).first = true; + expected_result_count_[timestamp] += 1; } } } template -void WorkerProxy::EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { - const Key &key = send.keys[0]; - const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); - sliced->resize(ranges.size()); - for (size_t i = 0; i < ranges.size(); i++) { - sliced->at(i).first = true; - sliced->at(i).second = send; - } -} - -template -void WorkerProxy::PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { - auto server_num = ::ps::Postoffice::Get()->num_servers(); - sliced->resize(server_num); - for (int i = 0; i < server_num; i++) { - sliced->at(i).first = true; - sliced->at(i).second = send; - } -} - -template -void WorkerProxy::BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, +void WorkerProxy::BroadcastSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced) { auto server_num = ::ps::Postoffice::Get()->num_servers(); sliced->resize(server_num); @@ -268,7 +259,7 @@ void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { lookup_results_[ts].push_back(kvs); mutex_.unlock(); } - if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { + if (lookup_customer_->NumResponse(ts) == expected_result_count_[ts] - 1) { const auto &cb = lookup_callbacks_[ts]; cb(); lookup_callbacks_.erase(ts); @@ -279,7 +270,7 @@ template void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, const Slicer &slicer) { SlicedKVs sliced; - slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); + slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); for (size_t i = 0; i < sliced.size(); i++) { const auto &s = sliced[i]; @@ -308,4 +299,4 @@ void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bo } // namespace ps } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_ diff --git a/mindspore/ccsrc/frontend/parallel/status.h b/mindspore/ccsrc/frontend/parallel/status.h index 6bfe9f0e72..e066b79635 100644 --- a/mindspore/ccsrc/frontend/parallel/status.h +++ b/mindspore/ccsrc/frontend/parallel/status.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_STATUS_H_ -#define MINDSPORE_CCSRC_PARALLEL_STATUS_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STATUS_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STATUS_H_ #include @@ -29,4 +29,4 @@ enum Status { } } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_STATUS_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STATUS_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 8d54eb454a..33d2bbc609 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -41,6 +41,7 @@ #include "frontend/parallel/context.h" #include "frontend/parallel/ops_info/tmp_identity_info.h" #include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/graph_util/node_info.h" #include "frontend/parallel/step_parallel.h" #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "pipeline/jit/parse/python_adapter.h" @@ -122,12 +123,7 @@ std::vector ExtractInputParameterByNode(const CNodePtr &node) { if (input->isa()) { auto input_parameter = input->cast(); - if (input_parameter->has_default()) { - bool requires_grad = input_parameter->default_param()->requires_grad(); - is_parameter.push_back(requires_grad); - } else { - is_parameter.push_back(false); - } + is_parameter.push_back(ParameterRequireGrad(input_parameter)); } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { is_parameter.push_back(false); } @@ -260,7 +256,7 @@ bool IsSplittableOperator(const std::string &op_name) { REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, - STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, + STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; // clang-format on @@ -338,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_cnode(cnode); // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); + std::string strategy_key_name = ""; + auto param_names = NodeParameterName(cnode); + if (!param_names.empty()) { + strategy_key_name = param_names[0].first; + } bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); // If no strategy has been configured for this operator, then candidate strategies are generated for @@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); + cnode->set_user_data(operator_info); MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); @@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); + cnode->set_user_data(operator_info); MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); @@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() << " does not match the Prim: " << prim->name(); } - (void)cnode->set_operator_info(current_op_ptr); + cnode->set_user_data(current_op_ptr); MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); @@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { PrimitivePtr prim = GetValueNode(prim_anf_node); size_t edge_count = 0; + auto node_op_info = cnode->user_data(); + for (size_t i = 1; i < inputs.size(); ++i) { auto prev_cnode = inputs[i]->cast(); bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); @@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); while (bool_result) { if (IsAutoParallelCareNode(prev_cnode)) { - std::string edge_name = - prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); + auto prev_op_info = prev_cnode->user_data(); + std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); // If the edge between these two operators already has been added, then the edge will not be added again. if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { break; @@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { if (follow_strategy) { // Redistribution in not allowed on the edge. // Elementwise operators have the same strategy as their previous operators. - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false, true); + edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true); } else { - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false); + edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, i - 1, false); } // Init costs for this edge if (edge_ptr->InitEdgeCost() != SUCCESS) { MS_LOG(EXCEPTION) << "Edge cost initialization failed"; } - cnode->operator_info()->AddPrevEdge(edge_ptr); - prev_cnode->operator_info()->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); - MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " - << cnode->operator_info()->name(); + node_op_info->AddPrevEdge(edge_ptr); + prev_op_info->AddSuccEdge(edge_ptr); + entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); + MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " + << node_op_info->name(); edge_count++; break; @@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); } } - MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); + MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); } MS_LOG(INFO) << "Constructing edges for cost graph ends."; @@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector &all_nodes) { for (auto &target : target_set) { auto target_cnode = target.first->cast(); auto input_index = target.second; - (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); + (void)target_without_duplicate.insert(std::to_string(input_index) + + target_cnode->user_data()->name()); } if (target_without_duplicate.size() <= 1) { continue; @@ -782,7 +783,7 @@ void AugmentCostGraph(const std::vector &all_nodes) { std::vector shape_int = input_shape->shape(); Shape shape; (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape), - [](int sub_shape) { return static_cast(sub_shape); }); + [](int sub_shape) { return static_cast(sub_shape); }); Shapes inputs_shape = {shape}; Shapes outputs_shape = {shape}; // 2) init the attr @@ -797,12 +798,7 @@ void AugmentCostGraph(const std::vector &all_nodes) { std::vector is_parameter; auto casted_target_parameter = target_parameter->cast(); MS_EXCEPTION_IF_NULL(casted_target_parameter); - if (casted_target_parameter->has_default()) { - bool requires_grad = casted_target_parameter->default_param()->requires_grad(); - is_parameter.push_back(requires_grad); - } else { - is_parameter.push_back(false); - } + is_parameter.push_back(ParameterRequireGrad(casted_target_parameter)); if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) { MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed"; } @@ -830,24 +826,24 @@ void AugmentCostGraph(const std::vector &all_nodes) { auto target_cnode = target.first->cast(); auto prim = GetValueNode(target_cnode->input(0)); auto input_index = target.second; + auto target_op_info = target_cnode->user_data(); - std::string edge_name = - std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); + std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); // If the edge between these two operators already has been added, then the edge will not be added again. if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { continue; } - std::shared_ptr edge_ptr = std::make_shared( - edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); + std::shared_ptr edge_ptr = + std::make_shared(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true); if (edge_ptr->InitEdgeCost() != SUCCESS) { MS_LOG(EXCEPTION) << "Edge cost initialization failed"; } - target_cnode->operator_info()->AddPrevEdge(edge_ptr); + target_op_info->AddPrevEdge(edge_ptr); tmp_identity_ptr->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr); + entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr); MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " - << target_cnode->operator_info()->name(); + << target_op_info->name(); add_identity_edge = true; } if (new_identity && add_identity_edge) { @@ -861,20 +857,13 @@ bool FindReshape(const CNodePtr &cnode) { if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { return false; } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { + if (!IsParallelCareNode(cnode) || !cnode->has_user_data()) { return false; } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); PrimitivePtr prim = GetValueNode(prim_anf_node); MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; - } - if (prim->name() != RESHAPE) { - return false; - } - return true; + return (prim->name() == RESHAPE); } // find previous node, then obtain its strategy_cost_ vector to get its layout vector. @@ -890,8 +879,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ if (!IsValueNode(cnode->input(0))) { return false; } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - *pre_operator_info = cnode->operator_info(); + auto node_op_info = cnode->user_data(); + if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { + *pre_operator_info = node_op_info; *out_index = 0; return true; } @@ -905,8 +895,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; } CNodePtr pre_cnode = pre_node->cast(); - if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { - *pre_operator_info = pre_cnode->operator_info(); + auto pre_op_info = pre_cnode->user_data(); + if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { + *pre_operator_info = pre_op_info; return true; } return false; @@ -945,14 +936,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + auto op_info = use_apply->user_data(); + if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); - *next_operator_info = use_apply->operator_info(); + *next_operator_info = op_info; *in_index = node_pair.second - 1; return true; } MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); + << " " << (op_info != nullptr); if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { return true; @@ -973,8 +965,8 @@ void ReshapeCostCompute(const std::vector &all_nodes) { int32_t out_index = 0; OperatorInfoPtr pre_operator_info; std::vector> pre_stra_costs; + auto operator_info = cnode->user_data(); if (pre_node->isa()) { - OperatorInfoPtr operator_info = cnode->operator_info(); auto reshape_info = std::dynamic_pointer_cast(operator_info); reshape_info->SetCostForReshapeWithParameter(); pre_operator_info = reshape_info; @@ -995,7 +987,6 @@ void ReshapeCostCompute(const std::vector &all_nodes) { } // set input_layout and output_layout for reshape. // init reshape and set cost for each input_layout and output_layout. - OperatorInfoPtr operator_info = cnode->operator_info(); auto reshape_info = std::dynamic_pointer_cast(operator_info); reshape_info->set_pre_operator_name(pre_operator_info->name()); reshape_info->set_pre_operator_index(out_index); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 6b9cfd9d37..f8ec0d3288 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { if (!IsParallelCareNode(node)) { return nullptr; } - OperatorInfoPtr distribute_operator = node->operator_info(); + OperatorInfoPtr distribute_operator = node->user_data(); if (distribute_operator == nullptr) { MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; } @@ -302,16 +302,20 @@ void Redistribution(const std::pair &node_pair, const OperatorI MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString(); // extract tensor layout in and out if (distribute_operator->outputs_tensor_info().empty()) { - MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty"; + MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name(); + return; } if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is " - << next_distribute_operator->inputs_tensor_info().size(); + MS_LOG(WARNING) << "The index is out of range, the index is " << index - 1 << ", the vector size is " + << next_distribute_operator->inputs_tensor_info().size() << "next operator name is " + << next_distribute_operator->name(); + return; } TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator); + if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) { MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name; MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " @@ -405,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { if (prim->name() == GET_NEXT) { return true; } - if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { + if ((prim->name() == CAST) && !cnode->has_user_data()) { return false; } @@ -442,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { + if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data()) { Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, pre_node); } else { @@ -455,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(next_node); - OperatorInfoPtr op_info = next_node->operator_info(); + OperatorInfoPtr op_info = next_node->user_data(); MS_EXCEPTION_IF_NULL(op_info); // If the shape of tensor is [] or [1], no need to split it. @@ -580,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { // step1:get graph manager distribute_operator - OperatorInfoPtr distribute_operator = node->operator_info(); + OperatorInfoPtr distribute_operator = node->user_data(); if (distribute_operator == nullptr) { MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; } @@ -618,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { (void)prim->SetAttrs(attrs); } if (index == replace_op.size() - 1) { - (void)replace_node->set_operator_info(node->operator_info()); + replace_node->set_user_data(node->user_data()); } replace_node->set_in_forward_flag(true); replace_input[0]->set_scope(scope); @@ -698,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { auto pre_cnode = pre_node->cast(); MS_EXCEPTION_IF_NULL(pre_cnode); auto pre_prim = GetValueNode(pre_cnode->input(0)); - if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + if (pre_prim->name() == CAST && !pre_cnode->has_user_data()) { pre_node = pre_cnode->input(1); } @@ -1015,14 +1019,16 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { } if (var->size() > 0) { std::vector elements = var->value(); - std::vector strategy; + Strategys strategy; for (uint32_t index = 0; index < elements.size(); ++index) { Dimensions dim; if (elements[index]->isa()) { ValueTuplePtr value_tuple = elements[index]->cast(); std::vector value_vector = value_tuple->value(); - (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), - [](const ValuePtr &value) { return static_cast(GetValue(value)); }); + (void)std::transform( + value_vector.begin(), value_vector.end(), std::back_inserter(dim), [](const ValuePtr &value) { + return value->isa() ? GetValue(value) : static_cast(GetValue(value)); + }); strategy.push_back(dim); } else { MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; @@ -1071,12 +1077,20 @@ Shapes GetNodeShape(const AnfNodePtr &node) { for (auto &shape : tuple_shape) { auto each_shape = dyn_cast(shape); MS_EXCEPTION_IF_NULL(each_shape); - shapes.push_back(each_shape->shape()); + std::vector shape_int = each_shape->shape(); + Shape new_shape; + (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(new_shape), + [](const int &value) { return static_cast(value); }); + shapes.push_back(new_shape); } } else { auto shape_ptr = dyn_cast(base_shape_ptr); MS_EXCEPTION_IF_NULL(shape_ptr); - shapes.push_back(shape_ptr->shape()); + std::vector shape_int = shape_ptr->shape(); + Shape new_shape; + (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(new_shape), + [](const int &value) { return static_cast(value); }); + shapes.push_back(new_shape); } return shapes; } @@ -1194,7 +1208,7 @@ std::pair FindParallelCareNode(const AnfNodePtr &node) { if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + if (IsParallelCareNode(cnode) && cnode->has_user_data()) { return node_pair; } else if (FindParallelCareNode(node_pair.first).first != nullptr) { return FindParallelCareNode(node_pair.first); @@ -1244,7 +1258,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pairToString() << " shape " << parameter->Shape()->ToString(); CNodePtr cnode = res.first->cast(); MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = cnode->operator_info(); + OperatorInfoPtr distribute_operator = cnode->user_data(); if (distribute_operator == nullptr) { MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; } @@ -1267,7 +1281,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::paircast(); MS_EXCEPTION_IF_NULL(parameter_ptr); - parameter_ptr->set_tensor_layout(std::make_shared(tensor_layout)); + parameter_ptr->set_user_data(std::make_shared(tensor_layout)); } void CoverSliceShape(const FuncGraphPtr &root) { @@ -1291,11 +1305,8 @@ void CoverSliceShape(const FuncGraphPtr &root) { g_RefMap.clear(); } -bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_node) { - MS_EXCEPTION_IF_NULL(root); +bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { MS_EXCEPTION_IF_NULL(parameter_node); - FuncGraphManagerPtr manager = root->manager(); - MS_EXCEPTION_IF_NULL(manager); auto cloned_parameter = parameter_node->cast(); MS_EXCEPTION_IF_NULL(cloned_parameter); @@ -1303,8 +1314,12 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_nod if (!cloned_parameter->has_default()) { return false; } - - bool cloned = cloned_parameter->default_param()->cloned(); + auto obj = py::cast(cloned_parameter->default_param()); + auto param_value = py::cast(obj.attr("_value")); + if (param_value == nullptr) { + return false; + } + bool cloned = param_value->cloned(); if (!cloned) { return false; } @@ -1320,12 +1335,16 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { auto cloned_parameter = cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(cloned_parameter); - if (!ParameterIsCloned(root, cloned_parameter_node)) { + if (!ParameterIsCloned(cloned_parameter_node)) { + continue; + } + auto obj = py::cast(cloned_parameter->default_param()); + auto param_value = py::cast(obj.attr("_value")); + if (param_value == nullptr) { continue; } - // get the cloned index - int32_t cloned_index = cloned_parameter->default_param()->cloned_index(); + int32_t cloned_index = param_value->cloned_index(); // find the be cloned parameter bool found_be_cloned_parameter = false; @@ -1340,12 +1359,18 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { } const auto ¶m_value_cloned = be_cloned_parameter->default_param(); - if (!param_value_cloned->be_cloned()) { + + auto obj_in = py::cast(param_value_cloned); + auto param_value_in = py::cast(obj_in.attr("_value")); + if (param_value_in == nullptr) { + continue; + } + if (!param_value_in->be_cloned()) { continue; } // get the be cloned index - auto &be_cloned_index = param_value_cloned->be_cloned_index(); + auto &be_cloned_index = param_value_in->be_cloned_index(); if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) { found_be_cloned_parameter = true; cloned_from_parameter = be_cloned_parameter; @@ -1355,7 +1380,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { if (found_be_cloned_parameter) { // set the shape and tensor layout for cloned parameter - cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); + cloned_parameter->set_user_data(cloned_from_parameter->user_data()); MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); @@ -1370,10 +1395,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { << cloned_index << ", but not found the be cloned parameter"; } } - std::string env = common::GetEnv("SLICE_ENV"); - if (!env.empty()) { - MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; - } } void SetVirtualDatasetStrategy(const CNodePtr &node) { @@ -1401,7 +1422,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { if (shape_list[0][i].empty()) { MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; } - std::vector input_strategy = {dev_num}; + Dimensions input_strategy = {dev_num}; for (size_t j = 1; j < shape_list[0][i].size(); j++) { input_strategy.push_back(1); } @@ -1454,18 +1475,22 @@ void ExtractInformation(const std::vector &all_nodes) { (*operator_).set_outputs_dtype(cnode->Type()); (*operator_).set_cnode(cnode); if (prim->name() == RESHAPE) { - (void)cnode->set_operator_info(operator_); + cnode->set_user_data(operator_); continue; } // load strategy checkpoint // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); + std::string strategy_key_name = ""; + auto param_names = NodeParameterName(cnode); + if (!param_names.empty()) { + strategy_key_name = param_names[0].first; + } bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() << " is empty, using batch parallel"; - std::shared_ptr> strategy_v_ptr = operator_->GenerateBatchStrategies(); + std::shared_ptr strategy_v_ptr = operator_->GenerateBatchStrategies(); if (strategy_v_ptr == nullptr) { MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; } @@ -1489,7 +1514,7 @@ void ExtractInformation(const std::vector &all_nodes) { if (operator_->Init(strategyPtr) == FAILED) { MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; } - (void)cnode->set_operator_info(operator_); + cnode->set_user_data(operator_); } else { MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; } @@ -1532,13 +1557,13 @@ std::shared_ptr FindNextLayout(const CNodePtr &cnode) { if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + if (IsParallelCareNode(use_apply) && use_apply->has_user_data()) { MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); auto layout = GetInputLayoutFromCNode(node_pair); return std::make_shared(layout); } MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); + << " " << use_apply->has_user_data(); auto layout_ptr = FindNextLayout(use_apply); if (layout_ptr) { @@ -1570,7 +1595,7 @@ std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &n if (!IsValueNode(cnode->input(0))) { return nullptr; } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + if (IsParallelCareNode(cnode) && cnode->has_user_data()) { auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); if (!layout_ptr) { MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; @@ -1614,7 +1639,7 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (!IsValueNode(cnode->input(0))) { return nullptr; } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + if (IsParallelCareNode(cnode) && cnode->has_user_data()) { auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); if (!layout_ptr) { MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; @@ -1654,12 +1679,12 @@ void ReshapeInit(const std::vector &all_nodes) { continue; } ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { + if (!IsParallelCareNode(cnode) || !cnode->has_user_data()) { continue; } PrimitivePtr prim = GetValueNode(prim_anf_node); MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); + OperatorInfoPtr operator_info = cnode->user_data(); if (operator_info == nullptr) { MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; } @@ -1704,7 +1729,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { auto current_prim = GetValueNode(pre_cnode->input(0)); // return -> cast - if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + if (current_prim->name() == CAST && !pre_cnode->has_user_data()) { pre_cnode = pre_cnode->input(1)->cast(); MS_EXCEPTION_IF_NULL(pre_cnode); current_prim = GetValueNode(pre_cnode->input(0)); @@ -1761,7 +1786,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { return ret; } - OperatorInfoPtr operator_info = loss_cnode->operator_info(); + OperatorInfoPtr operator_info = loss_cnode->user_data(); MS_EXCEPTION_IF_NULL(operator_info); TensorInfo loss_grad_tensor_info; size_t op_output_size = operator_info->outputs_tensor_info().size(); @@ -1799,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay if (sens_tensor_node->isa()) { auto sens_tensor_param = sens_tensor_node->cast(); MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + sens_tensor_param->set_user_data(std::make_shared(loss_grad_layout)); } MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; return; @@ -1824,7 +1849,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay cloned_abstract->set_shape(parallel_shape); sens_tensor_node->set_abstract(cloned_abstract); auto sens_tensor_param = sens_tensor_node->cast(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + sens_tensor_param->set_user_data(std::make_shared(loss_grad_layout)); return; } MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; @@ -1890,8 +1915,24 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); } +void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() < 3 || !IsValueNode(cnode->input(0))) { + return; + } + auto prim = GetValueNode(cnode->input(0)); + if (prim->name() != TILE) { + return; + } + + TileInfoPtr tile = std::dynamic_pointer_cast(distribute_operator); + MS_EXCEPTION_IF_NULL(tile); + tile->UpdateMultiples(cnode); +} + void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { HandleDropoutNode(distribute_operator, cnode); + HandleTileNode(distribute_operator, cnode); } std::set FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { @@ -2081,26 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector> NodeParameterName(const CNodePtr &node) { std::vector node_inputs{node->inputs()}; - for (auto input : node_inputs) { + std::vector> param_names; + for (int i = 0; i < UintToInt(node_inputs.size()); ++i) { + auto input = node_inputs[i]; if (input->isa()) { auto input_parameter = input->cast(); if (input_parameter->has_default()) { - const auto ¶m_value = input_parameter->default_param(); - if (param_value->requires_grad()) { - return param_value->name(); + if (ParameterRequireGrad(input_parameter)) { + param_names.push_back({input_parameter->name(), i}); } } } } - return ""; + return param_names; } void CheckpointStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; StrategyMap stra_map; + TensorInfoMap tensor_info_map; + ManualShapeMap manual_shape_map; auto ret = func_graph->get_return(); auto all_nodes = DeepScopedGraphSearch(ret); for (auto &node : all_nodes) { @@ -2109,23 +2153,45 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; } - std::string param_name = NodeParameterName(cnode); - if (param_name.empty()) { + auto param_names = NodeParameterName(cnode); + if (param_names.empty()) { continue; } + string param_name = param_names[0].first; PrimitivePtr prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); + OperatorInfoPtr operator_info = cnode->user_data(); if (operator_info) { if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { continue; } + std::vector input_tensor_info = operator_info->inputs_tensor_info(); StrategyPtr strategyPtr = operator_info->strategy(); MS_EXCEPTION_IF_NULL(node->scope()); stra_map[param_name] = strategyPtr; + for (auto param_name_pair : param_names) { + if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) { + continue; + } + tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1]; + } + if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos || + operator_info->name().find(GATHERV2) != std::string::npos) { + auto gatherv2_info = std::dynamic_pointer_cast(operator_info); + auto param_split_shapes = gatherv2_info->param_split_shapes(); + auto index_offsets = gatherv2_info->index_offsets(); + if (param_split_shapes.size() != index_offsets.size()) { + MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same."; + } + std::vector> manual_shape; + for (int i = 0; i < UintToInt(param_split_shapes.size()); ++i) { + manual_shape.push_back({param_split_shapes[i], index_offsets[i]}); + } + manual_shape_map[param_name] = manual_shape; + } } } - if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { + if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) { MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; } } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index f9fe67ea6b..4e142e4aee 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_H_ #include @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector &all_nodes); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -std::string NodeParameterName(const CNodePtr &node); +std::vector> NodeParameterName(const CNodePtr &node); void CheckpointStrategy(const FuncGraphPtr &func_graph); @@ -152,4 +152,4 @@ std::set ForwardGraph(const FuncGraphPtr &root); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/strategy.h b/mindspore/ccsrc/frontend/parallel/strategy.h index ca01164a6a..95b09c6cb0 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy.h +++ b/mindspore/ccsrc/frontend/parallel/strategy.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ -#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_ #include #include @@ -24,30 +24,46 @@ #include #include "frontend/parallel/status.h" +#include "frontend/parallel/device_matrix.h" namespace mindspore { namespace parallel { #define MIN_SLICE_NUM 1 -using Dimensions = std::vector; - +using Dimensions = Shape; +using Strategys = std::vector; class Strategy; using StrategyPtr = std::shared_ptr; class Strategy { public: - Strategy(int32_t stage, std::vector inputs) : stage_(stage), inputs_(std::move(inputs)) {} + Strategy(int32_t stage, Strategys inputs) + : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {} + + Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) { + inputs_ = another_stra.GetInputDim(); + internal_size_ = another_stra.GetInternalSize(); + if (internal_size_ != 0) { + internal_stragies_ = another_stra.GetInternalStrategies(); + } else { + internal_stragies_ = {}; + } + } + ~Strategy() = default; size_t GetInputNumber() const { return inputs_.size(); } - std::vector GetInputDim() const { return inputs_; } + Strategys GetInputDim() const { return inputs_; } int32_t GetInputStage() const { return stage_; } void ExpandInputDimFromOneToTwo() { if (inputs_.size() == 1) { inputs_.push_back(inputs_[0]); } } - void ResetInputs(const std::vector &input) { inputs_ = input; } + void ResetInputs(const Strategys &input) { inputs_ = input; } + std::vector GetInternalStrategies() const { return internal_stragies_; } + size_t GetInternalSize() const { return internal_size_; } + // TODO(Xiaoda): need fix for adapting 'CoverStrategy' bool IsEqual(const StrategyPtr &another_stra) { if (another_stra == nullptr) { return false; @@ -58,17 +74,25 @@ class Strategy { return true; } + // Include 'another_stra' into this strategy + void CoverStrategy(const StrategyPtr &another_stra) { + internal_stragies_.push_back(another_stra); + internal_size_++; + } + private: const int32_t stage_; // The size of Dimensions must equal to inputs_ tensor dimension. - std::vector inputs_; + Strategys inputs_; + size_t internal_size_ = 0; + std::vector internal_stragies_; }; -inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { +inline StrategyPtr NewStrategy(const int32_t stage, const Strategys &inputs) { return std::make_shared(stage, inputs); } } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_ diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index bf7c4e29ab..9ab55e6915 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -20,7 +20,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/convert_utils.h" #include "utils/log_adapter.h" #include "proto/node_strategy.pb.h" @@ -66,10 +66,10 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys(); auto stage = (int32_t)parallel_strategys.stage(); size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size()); - std::vector> strategy_inputs; + Strategys strategy_inputs; for (size_t j = 0; j < strategys_num; j++) { straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j)); - std::vector dimension; + Dimensions dimension; size_t dim_num = IntToSize(parallel_strategy.dim_size()); for (size_t k = 0; k < dim_num; k++) { dimension.push_back(parallel_strategy.dim(SizeToInt(k))); @@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, + ManualShapeMap *manual_shape_map) { straspb::ParallelStrategyMap parallel_strategy_map; parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); for (auto &node_stra : strategy_map) { @@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { } } } + for (auto &node_tensor_info : tensor_info_map) { + TensorInfo tensor_info = node_tensor_info.second; + TensorLayout tensor_layout = tensor_info.tensor_layout(); + straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item(); + MS_EXCEPTION_IF_NULL(parallel_layout_item); + parallel_layout_item->set_param_name(node_tensor_info.first); + straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts(); + straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix(); + MS_EXCEPTION_IF_NULL(dev_matrix); + for (auto dim : tensor_layout.device_arrangement().array()) { + dev_matrix->add_dim(IntToUint(dim)); + } + straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map(); + MS_EXCEPTION_IF_NULL(tensor_map); + for (auto dim : tensor_layout.tensor_map().array()) { + tensor_map->add_dim(dim); + } + straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape(); + straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset(); + MS_EXCEPTION_IF_NULL(manual_shape_map); + auto manual_shape = (*manual_shape_map)[node_tensor_info.first]; + for (auto dim_pair : manual_shape) { + param_split_shape->add_dim(dim_pair.first); + indices_offset->add_dim(dim_pair.second); + } + } + std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); if (!parallel_strategy_map.SerializeToOstream(&output)) { MS_LOG(ERROR) << "Save strategy file failed"; diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h index 67cbb92ee2..1b1a4e17c0 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -14,18 +14,24 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ -#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ #include #include +#include +#include #include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/strategy.h" #include "frontend/parallel/context.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" namespace mindspore { namespace parallel { using StrategyMap = std::unordered_map; +using TensorInfoMap = std::unordered_map; +using ManualShapeMap = std::unordered_map>>; class StrategyCheckpoint { public: StrategyCheckpoint() { @@ -38,7 +44,7 @@ class StrategyCheckpoint { ~StrategyCheckpoint() = default; Status Load(StrategyMap *strategy_map); - Status Save(const StrategyMap &strategy_map); + Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map); static StrategyCheckpoint &GetInstance(); bool LoadCheckPointOn() const { return load_checkpoint_on_; } @@ -55,4 +61,4 @@ class StrategyCheckpoint { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc index cff3d53a88..565750b944 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc @@ -18,7 +18,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/status.h" #include "frontend/parallel/tensor_layout/shape_util.h" #include "utils/convert_utils.h" @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Arrangement::Init(const std::vector &array) { +Status Arrangement::Init(const Shape &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -40,7 +40,7 @@ Status Arrangement::Init(const std::vector &array) { } bool Arrangement::IsValidArrangement() { - return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; }); + return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0; }); } void Arrangement::ComputeSize() { @@ -57,14 +57,14 @@ void Arrangement::ComputeSize() { * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1], * if value > size_, return [] */ -std::vector Arrangement::GetFrontElementByValue(int32_t value) const { - std::vector out; +Shape Arrangement::GetFrontElementByValue(int64_t value) const { + Shape out; if (GetDimSize() == 0) { return out; } if (value <= size_) { - int32_t size = 1; - uint32_t shape_list_idx = 0; + int64_t size = 1; + size_t shape_list_idx = 0; while (size < value) { size *= array_[shape_list_idx]; if (size <= value) { @@ -88,9 +88,9 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft if (expand_list.size() != GetDimSize()) { return nullptr; } - std::vector new_shape; - for (uint32_t i = 0; i < expand_list.size(); i++) { - std::vector expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i)); + Shape new_shape; + for (size_t i = 0; i < expand_list.size(); i++) { + Shape expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i)); if (expand_shape.empty()) { new_shape.push_back(GetDimByIdx(i)); } else { @@ -109,11 +109,11 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft * arrangement_list = [[4, 2], [2, 2]] */ std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { - int32_t size = 1; - uint32_t ind = 0; + int64_t size = 1; + size_t ind = 0; std::vector arrangement_list; - std::vector shape; - for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) { + Shape shape; + for (size_t i = 0; i < expand_shape.GetDimSize(); i++) { size *= expand_shape.GetDimByIdx(i); if (size > GetDimByIdx(ind)) { MS_LOG(ERROR) << "invalid expand_shape"; @@ -145,7 +145,7 @@ std::shared_ptr, Arrangement>> Arrangement::G if (expand_shape_list_ptr == nullptr) { return nullptr; } - std::vector expand_num_list_shape; + Shape expand_num_list_shape; (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), std::back_inserter(expand_num_list_shape), [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); @@ -158,9 +158,9 @@ std::shared_ptr, Arrangement>> Arrangement::G return std::make_shared, Arrangement>>(out_value); } -std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() const { - std::vector shape_accum; - int32_t size = 0; +Shape Arrangement::ComputeReverseAccumulateSumInReverseOrder() const { + Shape shape_accum; + int64_t size = 0; for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) { shape_accum.push_back(size); size += *iter; @@ -173,11 +173,11 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLef if (expand_list.size() != GetDimSize()) { return nullptr; } - std::vector new_shape; - for (uint32_t i = 0; i < expand_list.size(); i++) { + Shape new_shape; + for (size_t i = 0; i < expand_list.size(); i++) { if (expand_list[i].GetDimSize() >= 1) { - int32_t size = 1; - for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) { + int64_t size = 1; + for (size_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) { new_shape.push_back(expand_list[i].GetDimByIdx(k)); size *= expand_list[i].GetDimByIdx(k); } @@ -207,7 +207,7 @@ std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2 if (status != Status::SUCCESS) { return nullptr; } - std::vector out_shape; + Shape out_shape; status = AccumulateProductToShape(out_accum, &out_shape); if (status != Status::SUCCESS) { return nullptr; @@ -231,8 +231,8 @@ std::vector Arrangement::GetSqueezeIdx() const { } Arrangement Arrangement::GetSqueezeArrangement() const { - std::vector out_shape(array_.size()); - auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; }); + Shape out_shape(array_.size()); + auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int64_t value) { return value != 1; }); out_shape.resize(LongToSize(std::distance(out_shape.begin(), it))); // if all elements are 1, out_shape = {1} diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h index ab807fb20a..184739e41d 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ #include #include @@ -32,11 +32,11 @@ class Arrangement : public Array { public: Arrangement() : size_(1) {} ~Arrangement() override = default; - Status Init(const std::vector &array) override; - int32_t size() const { return size_; } - std::vector GetFrontElementByValue(int32_t value) const; + Status Init(const Shape &array) override; + int64_t size() const { return size_; } + Shape GetFrontElementByValue(int64_t value) const; std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; - std::vector ComputeReverseAccumulateSumInReverseOrder() const; + Shape ComputeReverseAccumulateSumInReverseOrder() const; std::shared_ptr GetExpandedShapeByExpandListReserveLeft( const std::vector &expand_list) const; std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( @@ -50,9 +50,9 @@ class Arrangement : public Array { private: bool IsValidArrangement(); void ComputeSize(); - int32_t size_; + int64_t size_; }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc index 4e1f467793..71c58f1ddb 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc @@ -31,14 +31,14 @@ std::string Array::ToString() const { return buffer.str(); } -Status Array::Init(const std::vector &array) { +Status Array::Init(const Shape &array) { array_ = array; return IsvalidArray() ? Status::SUCCESS : Status::FAILED; } bool Array::IsvalidArray() const { return true; } -int32_t Array::GetDimByIdx(uint32_t idx) const { +int64_t Array::GetDimByIdx(size_t idx) const { size_t mod_idx = idx; if (idx >= GetDimSize()) { MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize(); @@ -46,7 +46,7 @@ int32_t Array::GetDimByIdx(uint32_t idx) const { return array_[mod_idx]; } -int32_t Array::GetDimByReverseIdx(uint32_t idx) const { +int64_t Array::GetDimByReverseIdx(size_t idx) const { size_t mod_idx = idx; if (idx >= GetDimSize()) { MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize(); diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h index 13b3982a18..c939d607a7 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ #include #include @@ -23,6 +23,7 @@ #include #include #include "frontend/parallel/status.h" +#include "frontend/parallel/device_matrix.h" namespace mindspore { namespace parallel { @@ -31,18 +32,18 @@ class Array { Array() = default; virtual ~Array() = default; std::string ToString() const; - virtual Status Init(const std::vector &array); + virtual Status Init(const Shape &array); bool IsvalidArray() const; - std::vector array() const { return array_; } + Shape array() const { return array_; } size_t GetDimSize() const { return array_.size(); } - int32_t GetDimByIdx(uint32_t idx) const; - int32_t GetDimByReverseIdx(uint32_t idx) const; + int64_t GetDimByIdx(size_t idx) const; + int64_t GetDimByReverseIdx(size_t idx) const; bool operator==(const Array &a1) const; protected: - std::vector array_; + Shape array_; }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc index 9395d3df89..3080c8279d 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc @@ -18,6 +18,7 @@ #include #include +#include namespace mindspore { namespace parallel { @@ -28,9 +29,25 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix return Status::SUCCESS; } +// skip redistribution for reshape operator +OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) { + OperatorAttrs attrs; + std::vector shape_int; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr param_value = MakeValue(shape_int); + Attr param = std::make_pair(SHAPE, param_value); + OperatorParams params = {std::make_pair(param, 2)}; + OperatorArgs args = std::make_pair(attrs, params); + Operator op = std::make_pair(RESHAPE, args); + OperatorVector opvector; + opvector.push_back(op); + return opvector; +} + Status ConstructOperator::ReshapeOP(Shape shape) { - int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies()); + int64_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + int64_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies()); if (prod != prod_expect) { ValuePtr ptr = MakeValue(shape); MS_EXCEPTION_IF_NULL(ptr); @@ -38,7 +55,10 @@ Status ConstructOperator::ReshapeOP(Shape shape) { return Status::INVALID_ARGUMENT; } OperatorAttrs attrs; - ValuePtr param_value = MakeValue(shape); + std::vector shape_int; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr param_value = MakeValue(shape_int); Attr param = std::make_pair(SHAPE, param_value); OperatorParams params = {std::make_pair(param, 2)}; OperatorArgs args = std::make_pair(attrs, params); @@ -55,12 +75,21 @@ Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &en Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value); OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask}; - ValuePtr param_begin_value = MakeValue(begin); + std::vector begin_int; + (void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr param_begin_value = MakeValue(begin_int); Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2); - ValuePtr param_end_value = MakeValue(end); + std::vector end_int; + (void)std::transform(end.begin(), end.end(), std::back_inserter(end_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr param_end_value = MakeValue(end_int); Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3); - ValuePtr param_strides_value = MakeValue(strides); + std::vector strides_int; + (void)std::transform(strides.begin(), strides.end(), std::back_inserter(strides_int), + [](const int64_t &value) { return static_cast(value); }); + ValuePtr param_strides_value = MakeValue(strides_int); Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4); OperatorParams params = {param_begin, param_end, param_strides}; OperatorArgs op_args = std::make_pair(attrs, params); @@ -73,16 +102,16 @@ Status ConstructOperator::StridedSliceOP(Args args) { MS_LOG(ERROR) << "args size should not be less than 3!"; return Status::FAILED; } - int32_t split_count = args[0]; + int64_t split_count = args[0]; if (split_count <= 0) { MS_LOG(ERROR) << "split_count should not be less than 0!"; return Status::FAILED; } - int32_t split_dim = args[1]; - int32_t dev_dim = args[2]; + int64_t split_dim = args[1]; + int64_t dev_dim = args[2]; std::vector group_list; - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { + if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) { MS_LOG(ERROR) << "stride slice op: create group failed"; return FAILED; } else if (group_list.empty()) { // this group only has one device, don't need do StridedSlice @@ -101,7 +130,7 @@ Status ConstructOperator::StridedSliceOP(Args args) { Shape strides(size, 1); size_t index = 0; for (auto num : tensor_shape_) { - if (index != IntToSize(split_dim)) { + if (index != LongToSize(split_dim)) { begin[index] = 0; end[index] = num; } else { @@ -110,9 +139,9 @@ Status ConstructOperator::StridedSliceOP(Args args) { << "! when construct StridedSlice operator"; return Status::INVALID_ARGUMENT; } - int32_t count = num / split_count; - begin[index] = SizeToInt(rank) * count; - end[index] = (SizeToInt(rank) + 1) * count; + int64_t count = num / split_count; + begin[index] = SizeToLong(rank) * count; + end[index] = (SizeToLong(rank) + 1) * count; } index++; } @@ -122,7 +151,7 @@ Status ConstructOperator::StridedSliceOP(Args args) { return Status::SUCCESS; } -Status ConstructOperator::AllGatherOP(int32_t dev_dim) { +Status ConstructOperator::AllGatherOP(int64_t dev_dim) { if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AllGather operator!"; return Status::INVALID_ARGUMENT; @@ -147,7 +176,7 @@ Status ConstructOperator::AllGatherOP(int32_t dev_dim) { return Status::SUCCESS; } -Status ConstructOperator::ConcatOP(int32_t concat_dim) { +Status ConstructOperator::ConcatOP(int64_t concat_dim) { if (IntToSize(concat_dim) >= tensor_shape_.size()) { MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!"; return Status::INVALID_ARGUMENT; @@ -161,7 +190,7 @@ Status ConstructOperator::ConcatOP(int32_t concat_dim) { return Status::SUCCESS; } -Status ConstructOperator::SplitOP(int32_t split_count) { +Status ConstructOperator::SplitOP(int64_t split_count) { if (split_count <= 0) { MS_LOG(ERROR) << "Invalid split count when construct Split operator!"; return Status::FAILED; @@ -183,30 +212,30 @@ Status ConstructOperator::AlltoAllOP(Args args) { MS_LOG(ERROR) << "args size should not be less than 4!"; return Status::FAILED; } - int32_t split_count = args[0]; - int32_t split_dim = args[1]; - int32_t concat_dim = args[2]; - int32_t dev_dim = args[3]; + int64_t split_count = args[0]; + int64_t split_dim = args[1]; + int64_t concat_dim = args[2]; + int64_t dev_dim = args[3]; if (split_count <= 0) { MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!"; return Status::FAILED; } - if (tensor_shape_[IntToSize(split_dim)] % split_count != 0) { + if (tensor_shape_[LongToSize(split_dim)] % split_count != 0) { MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim << "when construct AlltoAll operator!"; return Status::INVALID_ARGUMENT; } - if (IntToSize(concat_dim) >= tensor_shape_.size()) { + if (LongToSize(concat_dim) >= tensor_shape_.size()) { MS_LOG(ERROR) << "Invalid split count " << split_count << " when construct AlltoAll operator!"; return Status::INVALID_ARGUMENT; } - if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { + if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AlltoAll operator!"; return Status::INVALID_ARGUMENT; } std::vector group_list; - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { + if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) { MS_LOG(ERROR) << "AlltoAll op: create group failed"; return FAILED; } else if (group_list.empty()) { // this group only has one device, don't need do alltoall diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h index b06d70af36..8a3c780e6d 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ #include #include @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -using Args = std::vector; +using Args = std::vector; class ConstructOperator { public: @@ -35,11 +35,12 @@ class ConstructOperator { ConstructOperator() : dev_size_(0) {} ~ConstructOperator() = default; Status Init(const RankList &dev_list, const Shape &dev_matrix_shape); + OperatorVector SkipRedisReshapeOP(Shape shape); Status ReshapeOP(Shape shape); Status StridedSliceOP(Args args); - Status AllGatherOP(int32_t dev_dim); - Status SplitOP(int32_t split_count); - Status ConcatOP(int32_t concat_dim); + Status AllGatherOP(int64_t dev_dim); + Status SplitOP(int64_t split_count); + Status ConcatOP(int64_t concat_dim); Status AlltoAllOP(Args args); Operator GetOperator() const { return op_; } void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } @@ -55,4 +56,4 @@ class ConstructOperator { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc index d5d34a484f..a53235e206 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc @@ -15,7 +15,7 @@ */ #include "frontend/parallel/tensor_layout/layout_transfer.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/status.h" namespace mindspore { diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h index 01c56fc7cf..6c3dcab4db 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ #include #include "frontend/parallel/status.h" @@ -45,4 +45,4 @@ class LayoutTransfer { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc index 184f0c7530..b6c6904d4e 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc @@ -18,7 +18,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/status.h" #include "frontend/parallel/tensor_layout/shape_util.h" #include "utils/convert_utils.h" @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Map::Init(const std::vector &array) { +Status Map::Init(const Shape &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -39,11 +39,11 @@ Status Map::Init(const std::vector &array) { } bool Map::IsValidMap() { - if (std::any_of(array_.begin(), array_.end(), [](int32_t value) { return ((value < 0) && (value != MAP_NONE)); })) { + if (std::any_of(array_.begin(), array_.end(), [](int64_t value) { return ((value < 0) && (value != MAP_NONE)); })) { return false; } // check that all none -1 value in array_ is different - std::vector sorted_array = array_; + Shape sorted_array = array_; std::sort(sorted_array.begin(), sorted_array.end()); int32_t value = MAP_NONE; for (auto &element : sorted_array) { @@ -58,7 +58,7 @@ bool Map::IsValidMap() { return true; } -int32_t Map::GetMaxItem() const { +int64_t Map::GetMaxItem() const { if (!array_.empty()) { return *std::max_element(array_.begin(), array_.end()); } else { @@ -66,7 +66,7 @@ int32_t Map::GetMaxItem() const { } } -int32_t Map::GetIndexByValue(int32_t value) const { +int32_t Map::GetIndexByValue(int64_t value) const { auto iter = find(array_.begin(), array_.end(), value); if (iter != array_.end()) { return static_cast(std::distance(array_.begin(), iter)); @@ -82,15 +82,15 @@ std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) co if (expand_num_list.GetDimSize() != GetDimSize()) { return nullptr; } - std::vector new_shape; - for (uint32_t i = 0; i != GetDimSize(); i++) { + Shape new_shape; + for (size_t i = 0; i != GetDimSize(); i++) { if (GetDimByIdx(i) == MAP_NONE) { - for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) { + for (int64_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) { new_shape.push_back(MAP_NONE); } } else { new_shape.push_back(GetDimByIdx(i)); - int32_t j = 1; + int64_t j = 1; while (j < expand_num_list.GetDimByIdx(i)) { new_shape.push_back(MAP_NONE); j++; @@ -106,17 +106,17 @@ std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) co * expand.size() should be equal to array_.size() */ std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { - if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { + if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { return nullptr; } - std::vector new_shape; - for (uint32_t i = 0; i < GetDimSize(); i++) { + Shape new_shape; + for (size_t i = 0; i < GetDimSize(); i++) { if (GetDimByIdx(i) == MAP_NONE) { new_shape.push_back(MAP_NONE); } else { - int32_t start_map = - expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast(GetDimByIdx(i))]; - for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast(GetDimByIdx(i))) - 1; k >= 0; k--) { + int64_t start_map = + expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast(GetDimByIdx(i))]; + for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast(GetDimByIdx(i))) - 1; k >= 0; k--) { new_shape.push_back(k + start_map); } } @@ -127,16 +127,16 @@ std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_nu } std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { - if (GetMaxItem() >= static_cast(input_vector.size())) { + if (GetMaxItem() >= static_cast(input_vector.size())) { return nullptr; } std::vector out; Arrangement empty_arrangement; - for (uint32_t i = 0; i < GetDimSize(); i++) { + for (size_t i = 0; i < GetDimSize(); i++) { if (GetDimByIdx(i) == MAP_NONE) { out.push_back(empty_arrangement); } else { - out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]); + out.push_back(input_vector[input_vector.size() - 1 - LongToSize(GetDimByIdx(i))]); } } return std::make_shared>(out); @@ -144,7 +144,7 @@ std::shared_ptr> Map::ReMapVector(const std::vector idx_list) const { for (auto &value : idx_list) { - if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { + if (GetDimByIdx(value) != MAP_NONE) { return false; } } @@ -152,11 +152,11 @@ bool Map::CheckNoneByIdxList(std::vector idx_list) const { } Map Map::SqueezeMapByIdxList(std::vector idx_list) const { - std::vector out_shape; + Shape out_shape; for (size_t i = 0; i < GetDimSize(); i++) { auto it = std::find(idx_list.begin(), idx_list.end(), i); if (it == idx_list.end()) { - out_shape.push_back(GetDimByIdx(SizeToUint(i))); + out_shape.push_back(GetDimByIdx(i)); } } if (out_shape.empty()) { diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h index 3d299d4b90..3a18bab028 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_MAP_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_MAP_H_ #include #include @@ -34,9 +34,9 @@ class Map : public Array { public: Map() = default; ~Map() override = default; - Status Init(const std::vector &array) override; - int32_t GetMaxItem() const; - int32_t GetIndexByValue(int32_t value) const; + Status Init(const Shape &array) override; + int64_t GetMaxItem() const; + int32_t GetIndexByValue(int64_t value) const; std::shared_ptr ExpandMapByNone(const Arrangement &expand_num_list) const; std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; std::shared_ptr> ReMapVector(const std::vector &input_vector) const; @@ -49,4 +49,4 @@ class Map : public Array { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_MAP_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h index 0347b6423a..589eea2383 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ #include #include "frontend/parallel/status.h" @@ -37,4 +37,4 @@ class RedistributionLayoutTransfer : public LayoutTransfer { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc index 6ac24418b7..56912f1c7c 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc @@ -47,8 +47,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, cons constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); size_t key = 0; - std::vector map = in_tensor_map_.array(); - for (int32_t item : map) { + Shape map = in_tensor_map_.array(); + for (int64_t item : map) { map_[key++] = item; } @@ -83,9 +83,9 @@ Status RedistributionOperatorInfer::InferRedistributionOperator() { // break loop structure with concat_by_axis if (len_global == operator_list_.size() && !map_.empty()) { size_t index = map_.begin()->first; - int32_t in_dim = map_[index]; + int64_t in_dim = map_[index]; map_[index] = NONE; - Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; + Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))}; if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { return Status::FAILED; } @@ -97,8 +97,8 @@ Status RedistributionOperatorInfer::InferRedistributionOperator() { Status RedistributionOperatorInfer::InferSplitByAxis() { for (auto iter = map_.begin(); iter != map_.end();) { uint32_t index = iter->first; - int32_t in_dim = iter->second; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + int64_t in_dim = iter->second; + int64_t out_dim = out_tensor_map_.GetDimByIdx(index); if (in_dim == out_dim) { (void)map_.erase(iter++); continue; @@ -122,8 +122,8 @@ Status RedistributionOperatorInfer::InferSplitByAxis() { Status RedistributionOperatorInfer::InferPermuteByAxis() { for (auto iter = map_.begin(); iter != map_.end();) { uint32_t index = iter->first; - int32_t in_dim = map_[index]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + int64_t in_dim = map_[index]; + int64_t out_dim = out_tensor_map_.GetDimByIdx(index); if (in_dim == out_dim) { (void)map_.erase(iter++); continue; @@ -132,9 +132,9 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { std::any_of(map_.begin(), map_.end(), [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); - int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); + int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim)); if (is_cost_model_) { - int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); + int64_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, dev_num}; if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { @@ -165,10 +165,10 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { Status RedistributionOperatorInfer::InferConcatByAxis() { for (auto iter = map_.begin(); iter != map_.end();) { uint32_t index = iter->first; - int32_t in_dim = map_[index]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + int64_t in_dim = map_[index]; + int64_t out_dim = out_tensor_map_.GetDimByIdx(index); if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) { - Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; + Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))}; if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; return Status::FAILED; @@ -215,7 +215,7 @@ Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) { MS_LOG(ERROR) << "args size should not be less than 3!"; return Status::FAILED; } - uint32_t index = IntToUint(args[1]); + size_t index = LongToSize(args[1]); if (constructor_.StridedSliceOP(args) != Status::SUCCESS) { return Status::FAILED; } else { @@ -239,11 +239,11 @@ Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) { operator_vector_.push_back(constructor_.GetOperator()); output_info_vector_.push_back(std::make_pair(false, 0)); } - uint32_t index = IntToUint(args[1]); - int32_t val = args[2]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + size_t index = LongToSize(args[1]); + int64_t val = args[2]; + int64_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (cur_tensor_layout_.UpdateTensorMap(IntToUint(val), NONE) == Status::FAILED) { + if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) { return Status::FAILED; } if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) { @@ -257,9 +257,9 @@ Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { MS_LOG(ERROR) << "args size should not be less than 3!"; return Status::FAILED; } - int32_t tensor_dim = args[0]; - int32_t dev_dim = args[1]; - int32_t split_count = args[2]; + int64_t tensor_dim = args[0]; + int64_t dev_dim = args[1]; + int64_t split_count = args[2]; if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) { return Status::FAILED; } else { @@ -280,7 +280,7 @@ Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { output_info_vector_.push_back(std::make_pair(false, 0)); } } - if (cur_tensor_layout_.UpdateTensorMap(IntToUint(tensor_dim), NONE) == Status::FAILED) { + if (cur_tensor_layout_.UpdateTensorMap(LongToSize(tensor_dim), NONE) == Status::FAILED) { return Status::FAILED; } return Status::SUCCESS; diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h index 66cdb3f925..56e98a24b6 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ #include #include @@ -28,10 +28,10 @@ #include "utils/convert_utils.h" namespace mindspore { namespace parallel { -using DeviceArrangement = std::vector; -using TensorMap = std::vector; -using TensorShape = std::vector; -using RedistributionOperatorMap = std::unordered_map; +using DeviceArrangement = Shape; +using TensorMap = Shape; +using TensorShape = Shape; +using RedistributionOperatorMap = std::unordered_map; using OperatorR = std::pair; using OperatorC = std::pair; using OperatorList = std::vector; @@ -74,4 +74,4 @@ class RedistributionOperatorInfer { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h index f9ebe9e32b..fc9583f38a 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ #include #include "frontend/parallel/status.h" @@ -45,4 +45,4 @@ class ReshapeLayoutTransfer : public LayoutTransfer { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc index 83282d16b3..453ad8066f 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc @@ -26,7 +26,7 @@ namespace parallel { * shape = [2, 8, 32] * shape_accum = [2, 2 * 8, 2 * 8 * 32] */ -Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum) { +Status ShapeToAccumulateProduct(const Shape &shape, Shape *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -47,7 +47,7 @@ Status ShapeToAccumulateProduct(const std::vector &shape, std::vector &shape, std::vector *shape_accum) { +Status ShapeToAccumulateProductReverse(const Shape &shape, Shape *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector &shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape) { +Status AccumulateProductToShape(const Shape &shape_accum, Shape *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -81,7 +81,7 @@ Status AccumulateProductToShape(const std::vector &shape_accum, std::ve MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; return Status::FAILED; } - shape->push_back(static_cast((*iter) / value)); + shape->push_back(static_cast((*iter) / value)); value = (*iter); } return Status::SUCCESS; @@ -92,7 +92,7 @@ Status AccumulateProductToShape(const std::vector &shape_accum, std::ve * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] * shape = [2, 8, 32] */ -Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape) { +Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -105,7 +105,7 @@ Status AccumulateProductReverseToShape(const std::vector &shape_accum_r MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; return Status::FAILED; } - (void)shape->insert(shape->begin(), static_cast((*iter) / value)); + (void)shape->insert(shape->begin(), static_cast((*iter) / value)); value = *iter; } return Status::SUCCESS; @@ -122,8 +122,7 @@ Status AccumulateProductReverseToShape(const std::vector &shape_accum_r * in2 = [8, 16] * *out = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, - std::vector *out_accum) { +Status UnifyAccumulateProduct(const Shape &in1_accum, const Shape &in2_accum, Shape *out_accum) { MS_EXCEPTION_IF_NULL(out_accum); out_accum->clear(); auto in1_iter = in1_accum.begin(); @@ -159,19 +158,19 @@ Status UnifyAccumulateProduct(const std::vector &in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out) { +Status UnifyShape(const Shape &in1, const Shape &in2, Shape *out) { MS_EXCEPTION_IF_NULL(out); - std::vector in1_accum; + Shape in1_accum; Status status = ShapeToAccumulateProduct(in1, &in1_accum); if (status != Status::SUCCESS) { return status; } - std::vector in2_accum; + Shape in2_accum; status = ShapeToAccumulateProduct(in2, &in2_accum); if (status != Status::SUCCESS) { return status; } - std::vector out_accum; + Shape out_accum; status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); if (status != Status::SUCCESS) { return status; @@ -194,9 +193,8 @@ Status UnifyShape(const std::vector &in1, const std::vector &i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, - const std::vector &expand_accum_reverse, - std::vector *out_accum_reverse) { +Status ExpandAccumulateProduct(const Shape &in_accum_reverse, const Shape &expand_accum_reverse, + Shape *out_accum_reverse) { MS_EXCEPTION_IF_NULL(out_accum_reverse); out_accum_reverse->clear(); auto in_riter = in_accum_reverse.rbegin(); @@ -236,19 +234,19 @@ Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out) { +Status ExpandShape(const Shape &in, const Shape &expand, Shape *out) { MS_EXCEPTION_IF_NULL(out); - std::vector in_accum_reverse; + Shape in_accum_reverse; Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); if (status != Status::SUCCESS) { return status; } - std::vector expand_accum_reverse; + Shape expand_accum_reverse; status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse); if (status != Status::SUCCESS) { return status; } - std::vector out_accum_reverse; + Shape out_accum_reverse; status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse); if (status != Status::SUCCESS) { return status; diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h index 49dd39ffd6..b54603f83e 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ #include #include @@ -24,6 +24,7 @@ #include #include "frontend/parallel/status.h" +#include "frontend/parallel/device_matrix.h" namespace mindspore { namespace parallel { @@ -39,7 +40,7 @@ namespace parallel { * shape_accum = [2, 2 * 8, 2 * 8 * 32] * */ -Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); +Status ShapeToAccumulateProduct(const Shape &shape, Shape *shape_accum); /* * compute the accumulating product of all the values in shape from right to left, @@ -53,7 +54,7 @@ Status ShapeToAccumulateProduct(const std::vector &shape, std::vector &shape, std::vector *shape_accum); +Status ShapeToAccumulateProductReverse(const Shape &shape, Shape *shape_accum); /* * compute the original shape from the accumulating product shape_accum, @@ -68,7 +69,7 @@ Status ShapeToAccumulateProductReverse(const std::vector &shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); +Status AccumulateProductToShape(const Shape &shape_accum, Shape *shape); /* * compute the original shape from the accumulating product shape_accum, @@ -83,7 +84,7 @@ Status AccumulateProductToShape(const std::vector &shape_accum, std::ve * shape = [2, 8, 32] * */ -Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); +Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *shape); /* * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, @@ -101,8 +102,7 @@ Status AccumulateProductReverseToShape(const std::vector &shape_accum_r * in2_accum = [8, 16] * out_accum = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, - std::vector *out_accum); +Status UnifyAccumulateProduct(const Shape &in1_accum, const Shape &in2_accum, Shape *out_accum); /* * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] @@ -117,7 +117,7 @@ Status UnifyAccumulateProduct(const std::vector &in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); +Status UnifyShape(const Shape &in1, const Shape &in2, Shape *out); /* * given two accumulate product in reverse order of in and expand, @@ -141,9 +141,8 @@ Status UnifyShape(const std::vector &in1, const std::vector &i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, - const std::vector &expand_accum_reverse, - std::vector *out_accum_reverse); +Status ExpandAccumulateProduct(const Shape &in_accum_reverse, const Shape &expand_accum_reverse, + Shape *out_accum_reverse); /* * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], @@ -165,8 +164,8 @@ Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); +Status ExpandShape(const Shape &in, const Shape &expand, Shape *out); } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h index fc78b1f59c..042851b0df 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ #include #include @@ -68,4 +68,4 @@ class TensorInfo { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc index b9c6cc78de..203b4f9958 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc @@ -17,7 +17,7 @@ #include "frontend/parallel/tensor_layout/tensor_layout.h" #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "ir/value.h" #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/status.h" @@ -64,8 +64,8 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens } } -Status TensorLayout::InitFromVector(const std::vector &device_arrangement, - const std::vector &tensor_map, const std::vector &tensor_shape) { +Status TensorLayout::InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, + const Shape &tensor_shape) { if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { return FAILED; } @@ -82,7 +82,7 @@ Status TensorLayout::InitFromVector(const std::vector &device_arrangeme } bool TensorLayout::IsValidTensorLayout() const { - if (tensor_map_origin_.GetMaxItem() >= static_cast(device_arrangement_origin_.GetDimSize())) { + if (tensor_map_origin_.GetMaxItem() >= static_cast(device_arrangement_origin_.GetDimSize())) { MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!"; return false; } @@ -114,18 +114,18 @@ bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const { } void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { - std::vector device_arrangement_shape; - std::vector tensor_map_shape = tensor_map_origin_.array(); - uint32_t dev_num = SizeToUint(device_arrangement_origin_.GetDimSize()); - int32_t dev_num_left = SizeToInt(device_arrangement_origin_.GetDimSize()); - for (uint32_t i = 0; i < dev_num; i++) { + Shape device_arrangement_shape; + Shape tensor_map_shape = tensor_map_origin_.array(); + size_t dev_num = device_arrangement_origin_.GetDimSize(); + size_t dev_num_left = device_arrangement_origin_.GetDimSize(); + for (size_t i = 0; i < dev_num; i++) { if (device_arrangement_origin_.GetDimByIdx(i) == 1) { - int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast(dev_num - 1 - i)); + int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast(dev_num - 1 - i)); if (idx != -1) { tensor_map_shape[static_cast(idx)] = -1; } for (auto &value : tensor_map_shape) { - if (value >= dev_num_left - 1 - static_cast(i)) { + if (value >= SizeToLong(dev_num_left) - 1 - static_cast(i)) { value--; } } @@ -139,7 +139,7 @@ void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { } // if idx is not in tensor_map, return -1 -int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const { +int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const { return tensor_map_.GetIndexByValue(idx); } @@ -288,7 +288,7 @@ std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrang } bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { - std::vector in_expand_shape_shape; + Shape in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { return false; @@ -297,7 +297,7 @@ bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) con } std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { - std::vector in_expand_shape_shape; + Shape in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { return nullptr; @@ -311,14 +311,14 @@ std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arra } Arrangement TensorLayout::slice_shape() const { - std::vector shape; - for (uint32_t index = 0; index < tensor_map_.GetDimSize(); index++) { - int32_t dim = tensor_map_.GetDimByIdx(index); - int32_t num = tensor_shape_.GetDimByIdx(index); + Shape shape; + for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) { + int64_t dim = tensor_map_.GetDimByIdx(index); + int64_t num = tensor_shape_.GetDimByIdx(index); if (dim == -1) { shape.push_back(num); } else { - int32_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); + int64_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); shape.push_back(num / divisor); } } @@ -331,7 +331,7 @@ Arrangement TensorLayout::slice_shape() const { } } -Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { +Status TensorLayout::UpdateTensorMap(size_t index, int64_t value) { if (index >= tensor_map_.GetDimSize()) { MS_LOG(ERROR) << "Index is out of the size of the tensor map!"; return Status::FAILED; diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h index a9fdc9610c..9832964005 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ #include #include @@ -38,8 +38,15 @@ class TensorLayout { std::string StandardToString() const; std::string OriginToString() const; Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); - Status InitFromVector(const std::vector &device_arrangement, const std::vector &tensor_map, - const std::vector &tensor_shape); + Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape); + + bool skip_redistribution() const { return skip_redistribution_; } + + void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; } + + int32_t get_field_size() const { return field_size_; } + + void set_field_size(int32_t field_size) { field_size_ = field_size; } Arrangement device_arrangement() const { return device_arrangement_; } @@ -71,10 +78,13 @@ class TensorLayout { Arrangement slice_shape() const; - Status UpdateTensorMap(uint32_t index, int32_t value); + Status UpdateTensorMap(size_t index, int64_t value); TensorLayout SqueezeShape() const; + // Key for user data. + constexpr static char key[] = "TLayout"; + private: std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( const Arrangement &expanded_shape) const; @@ -84,7 +94,7 @@ class TensorLayout { int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const; bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const; - int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const; + int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const; Arrangement device_arrangement_origin_; Map tensor_map_origin_; @@ -92,8 +102,10 @@ class TensorLayout { Arrangement device_arrangement_; Map tensor_map_; Arrangement tensor_shape_; + bool skip_redistribution_ = false; + int32_t field_size_ = 0; }; } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc index 43bb330787..87d385c81b 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc @@ -18,7 +18,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "frontend/parallel/status.h" #include "frontend/parallel/tensor_layout/shape_util.h" diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h index df4bd1570f..196827d18a 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ #include #include @@ -87,4 +87,4 @@ class TensorRedistribution { } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index df9729c4ee..81032180b6 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -47,7 +47,9 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") -ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR}) +include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") +set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") +ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) ################## Include sub-modules ############################### add_subdirectory(util) @@ -64,6 +66,7 @@ add_dependencies(kernels core) add_dependencies(engine-datasetops-source core) add_dependencies(engine-datasetops-source-sampler core) add_dependencies(engine-datasetops core) +add_dependencies(engine-datasetops-mapop core) add_dependencies(engine-opt core) add_dependencies(engine-perf core) add_dependencies(engine-gnn core) @@ -87,6 +90,7 @@ set(submodules $ $ $ + $ $ $ $ @@ -112,23 +116,33 @@ endif () add_dependencies(_c_dataengine generated_engine_files) +if (ENABLE_PYTHON) set_target_properties(_c_dataengine PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}" SUFFIX "${PYTHON_MODULE_EXTENSION}" ) +endif() ###################################################################### ################# Link with external libraries ######################## target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) + if (ENABLE_PYTHON) + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) + else() + target_link_libraries(_c_dataengine PRIVATE mindspore::protobuf ${SECUREC_LIBRARY}) + endif() else() set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n) - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) + if (ENABLE_PYTHON) + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) + else() + target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY}) + endif() endif() target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs - mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB}) + mindspore::opencv_imgproc mindspore::tinyxml2 mindspore::sentencepiece mindspore::sentencepiece_train ${ICU_LIB}) if (ENABLE_GPUQUE) target_link_libraries(_c_dataengine PRIVATE gpu_queue ${CUDNN_PATH}/lib64/libcudnn.so @@ -146,6 +160,12 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) else() target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) + if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) + target_link_libraries(_c_dataengine PRIVATE mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) + if (${ENABLE_IBVERBS} STREQUAL "ON") + target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) + endif() + endif() endif() if (USE_GLOG) diff --git a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt index ae0b9cc28e..1f1c3c9bdd 100644 --- a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt @@ -1,16 +1,29 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) if (ENABLE_PYTHON) - add_library(APItoPython OBJECT - de_pipeline.cc - python_bindings.cc - ) - target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS}) -endif() + add_library(APItoPython OBJECT + python/de_pipeline.cc + python/pybind_register.cc + python/bindings.cc + python/bindings/dataset/engine/cache/bindings.cc + python/bindings/dataset/core/bindings.cc + python/bindings/dataset/kernels/data/bindings.cc + python/bindings/dataset/kernels/bindings.cc + python/bindings/dataset/engine/datasetops/bindings.cc + python/bindings/dataset/engine/datasetops/source/bindings.cc + python/bindings/dataset/engine/gnn/bindings.cc + python/bindings/dataset/kernels/image/bindings.cc + python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc + python/bindings/dataset/text/bindings.cc + python/bindings/dataset/text/kernels/bindings.cc + python/bindings/mindrecord/include/bindings.cc + ) + target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS}) +endif () add_library(cpp-API OBJECT - datasets.cc - iterator.cc - transforms.cc - samplers.cc - ) + datasets.cc + iterator.cc + transforms.cc + samplers.cc + ) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 3072a62dc9..0e6090f128 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -17,33 +17,47 @@ #include #include "minddata/dataset/include/datasets.h" -#include "minddata/dataset/include/transforms.h" #include "minddata/dataset/include/samplers.h" +#include "minddata/dataset/include/transforms.h" #include "minddata/dataset/engine/dataset_iterator.h" +// Source dataset headers (in alphabetical order) +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +// Dataset operator headers (in alphabetical order) #include "minddata/dataset/engine/datasetops/batch_op.h" -#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/concat_op.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" +#include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/rename_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/shuffle_op.h" -#include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/skip_op.h" +#include "minddata/dataset/engine/datasetops/take_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" + +// Sampler headers (in alphabetical order) #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/path.h" namespace mindspore { namespace dataset { namespace api { -#define RETURN_NULL_IF_ERROR(_s) \ - do { \ - Status __rc = (_s); \ - if (__rc.IsError()) { \ - return nullptr; \ - } \ +#define RETURN_EMPTY_IF_ERROR(_s) \ + do { \ + Status __rc = (_s); \ + if (__rc.IsError()) { \ + MS_LOG(ERROR) << __rc; \ + return {}; \ + } \ } while (false) // Function to create the iterator, which will build and launch the execution tree. @@ -53,7 +67,7 @@ std::shared_ptr Dataset::CreateIterator() { iter = std::make_shared(); Status rc = iter->BuildAndLaunchTree(shared_from_this()); if (rc.IsError()) { - MS_LOG(ERROR) << "CreateIterator failed."; + MS_LOG(ERROR) << "CreateIterator failed." << rc; return nullptr; } @@ -75,6 +89,45 @@ Dataset::Dataset() { connector_que_size_ = cfg->op_connector_size(); } +// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS +// (In alphabetical order) + +// Function to create a CelebADataset. +std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &dataset_type, + const std::shared_ptr &sampler, const bool &decode, + const std::set &extensions) { + auto ds = std::make_shared(dataset_dir, dataset_type, sampler, decode, extensions); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a Cifar10Dataset. +std::shared_ptr Cifar10(const std::string &dataset_dir, std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a Cifar100Dataset. +std::shared_ptr Cifar100(const std::string &dataset_dir, std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a CocoDataset. +std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, + const std::string &task, const bool &decode, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, annotation_file, task, decode, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + // Function to create a ImageFolderDataset. std::shared_ptr ImageFolder(std::string dataset_dir, bool decode, std::shared_ptr sampler, std::set extensions, @@ -97,15 +150,35 @@ std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptrValidateParams() ? ds : nullptr; } -// Function to create a Cifar10Dataset. -std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, - std::shared_ptr sampler) { - auto ds = std::make_shared(dataset_dir, num_samples, sampler); +// Function to overload "+" operator to concat two datasets +std::shared_ptr operator+(const std::shared_ptr &datasets1, + const std::shared_ptr &datasets2) { + std::shared_ptr ds = std::make_shared(std::vector({datasets1, datasets2})); + + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a VOCDataset. +std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode, + const std::map &class_index, bool decode, + std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, task, mode, class_index, decode, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a ZipDataset. +std::shared_ptr Zip(const std::vector> &datasets) { + auto ds = std::make_shared(datasets); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } +// FUNCTIONS TO CREATE DATASETS FOR DATASET OPS +// (In alphabetical order) + // Function to create a Batch dataset std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remainder) { // Default values @@ -123,6 +196,57 @@ std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remai return ds; } +// Function to create a Concat dataset +std::shared_ptr Dataset::Concat(const std::vector> &datasets) { + auto ds = std::make_shared(datasets); + ds->children.push_back(shared_from_this()); + + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a Map dataset. +std::shared_ptr Dataset::Map(std::vector> operations, + std::vector input_columns, + std::vector output_columns, + const std::vector &project_columns) { + auto ds = std::make_shared(operations, input_columns, output_columns, project_columns); + + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Function to create a ProjectDataset. +std::shared_ptr Dataset::Project(const std::vector &columns) { + auto ds = std::make_shared(columns); + // Call derived class validation method. + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Function to create a RenameDataset. +std::shared_ptr Dataset::Rename(const std::vector &input_columns, + const std::vector &output_columns) { + auto ds = std::make_shared(input_columns, output_columns); + // Call derived class validation method. + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + // Function to create Repeat dataset. std::shared_ptr Dataset::Repeat(int32_t count) { // Workaround for repeat == 1, do not inject repeat. @@ -141,12 +265,10 @@ std::shared_ptr Dataset::Repeat(int32_t count) { return ds; } -// Function to create a Map dataset. -std::shared_ptr Dataset::Map(std::vector> operations, - std::vector input_columns, - std::vector output_columns, - const std::vector &project_columns) { - auto ds = std::make_shared(operations, input_columns, output_columns, project_columns); +// Function to create a ShuffleOp +std::shared_ptr Dataset::Shuffle(int32_t shuffle_size) { + // Pass in reshuffle_each_epoch with true + auto ds = std::make_shared(shuffle_size, true); if (!ds->ValidateParams()) { return nullptr; @@ -157,11 +279,11 @@ std::shared_ptr Dataset::Map(std::vector Dataset::Shuffle(int32_t shuffle_size) { - // Pass in reshuffle_each_epoch with true - auto ds = std::make_shared(shuffle_size, true); +// Function to create a SkipDataset. +std::shared_ptr Dataset::Skip(int32_t count) { + auto ds = std::make_shared(count); + // Call derived class validation method. if (!ds->ValidateParams()) { return nullptr; } @@ -171,9 +293,16 @@ std::shared_ptr Dataset::Shuffle(int32_t shuffle_size) { return ds; } -// Function to create a ProjectDataset. -std::shared_ptr Dataset::Project(const std::vector &columns) { - auto ds = std::make_shared(columns); +// Function to create a TakeDataset. +std::shared_ptr Dataset::Take(int32_t count) { + // If count is greater than the number of element in dataset or equal to -1, + // all the element in dataset will be taken + if (count == -1) { + return shared_from_this(); + } + + auto ds = std::make_shared(count); + // Call derived class validation method. if (!ds->ValidateParams()) { return nullptr; @@ -184,15 +313,237 @@ std::shared_ptr Dataset::Project(const std::vector return ds; } +// Function to create a Zip dataset +std::shared_ptr Dataset::Zip(const std::vector> &datasets) { + // Default values + auto ds = std::make_shared(datasets); + ds->children.push_back(shared_from_this()); + + return ds->ValidateParams() ? ds : nullptr; +} + +// OTHER FUNCTIONS +// (In alphabetical order) + // Helper function to create default RandomSampler. std::shared_ptr CreateDefaultSampler() { - int32_t num_samples = 0; // 0 means to sample all ids. + const int32_t num_samples = 0; // 0 means to sample all ids. bool replacement = false; return std::make_shared(replacement, num_samples); } +// Helper function to validate dataset params +bool ValidateCommonDatasetParams(std::string dataset_dir) { + if (dataset_dir.empty()) { + MS_LOG(ERROR) << "No dataset path is specified"; + return false; + } + return true; +} + /* ####################################### Derived Dataset classes ################################# */ +// DERIVED DATASET CLASSES LEAF-NODE DATASETS +// (In alphabetical order) + +// Constructor for CelebADataset +CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type, + const std::shared_ptr &sampler, const bool &decode, + const std::set &extensions) + : dataset_dir_(dataset_dir), + dataset_type_(dataset_type), + sampler_(sampler), + decode_(decode), + extensions_(extensions) {} + +bool CelebADataset::ValidateParams() { + Path dir(dataset_dir_); + if (!dir.IsDirectory()) { + MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; + return false; + } + std::set dataset_type_list = {"all", "train", "valid", "test"}; + auto iter = dataset_type_list.find(dataset_type_); + if (iter == dataset_type_list.end()) { + MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'."; + return false; + } + return true; +} + +// Function to build CelebADataset +std::vector> CelebADataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + std::unique_ptr schema = std::make_unique(); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + // label is like this:0 1 0 0 1...... + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, + decode_, dataset_type_, extensions_, std::move(schema), + std::move(sampler_->Build()))); + return node_ops; +} + +// Constructor for Cifar10Dataset +Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr sampler) + : dataset_dir_(dataset_dir), sampler_(sampler) {} + +bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } + +// Function to build CifarOp for Cifar10 +std::vector> Cifar10Dataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_, + dataset_dir_, connector_que_size_, std::move(schema), + std::move(sampler_->Build()))); + return node_ops; +} + +// Constructor for Cifar100Dataset +Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr sampler) + : dataset_dir_(dataset_dir), sampler_(sampler) {} + +bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } + +// Function to build CifarOp for Cifar100 +std::vector> Cifar100Dataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_, + dataset_dir_, connector_que_size_, std::move(schema), + std::move(sampler_->Build()))); + return node_ops; +} + +// Constructor for CocoDataset +CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, + const bool &decode, const std::shared_ptr &sampler) + : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} + +bool CocoDataset::ValidateParams() { + Path dir(dataset_dir_); + if (!dir.IsDirectory()) { + MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; + return false; + } + Path annotation_file(annotation_file_); + if (!annotation_file.Exists()) { + MS_LOG(ERROR) << "annotation_file is invalid or not exist"; + return false; + } + std::set task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"}; + auto task_iter = task_list.find(task_); + if (task_iter == task_list.end()) { + MS_LOG(ERROR) << "Invalid task type"; + return false; + } + return true; +} + +// Function to build CocoDataset +std::vector> CocoDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + CocoOp::TaskType task_type; + if (task_ == "Detection") { + task_type = CocoOp::TaskType::Detection; + } else if (task_ == "Stuff") { + task_type = CocoOp::TaskType::Stuff; + } else if (task_ == "Keypoint") { + task_type = CocoOp::TaskType::Keypoint; + } else if (task_ == "Panoptic") { + task_type = CocoOp::TaskType::Panoptic; + } + + std::unique_ptr schema = std::make_unique(); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor(std::string("image"), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + switch (task_type) { + case CocoOp::TaskType::Detection: + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case CocoOp::TaskType::Stuff: + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("segmentation"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case CocoOp::TaskType::Keypoint: + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("keypoints"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("num_keypoints"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case CocoOp::TaskType::Panoptic: + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn( + ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_EMPTY_IF_ERROR( + schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + default: + MS_LOG(ERROR) << "CocoDataset::Build : Invalid task type"; + return {}; + } + std::shared_ptr op = + std::make_shared(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, + connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); + node_ops.push_back(op); + return node_ops; +} + ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, std::set extensions, std::map class_indexing) @@ -203,16 +554,9 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std class_indexing_(class_indexing), exts_(extensions) {} -bool ImageFolderDataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - - return true; -} +bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } -std::shared_ptr>> ImageFolderDataset::Build() { +std::vector> ImageFolderDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -225,29 +569,22 @@ std::shared_ptr>> ImageFolderDataset::Bui // This arg is exist in ImageFolderOp, but not externalized (in Python API). std::unique_ptr schema = std::make_unique(); TensorShape scalar = TensorShape::CreateScalar(); - RETURN_NULL_IF_ERROR( + RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_NULL_IF_ERROR( + RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_, class_indexing_, std::move(schema), std::move(sampler_->Build()))); - return std::make_shared>>(node_ops); + return node_ops; } MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr sampler) : dataset_dir_(dataset_dir), sampler_(sampler) {} -bool MnistDataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } +bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } - return true; -} - -std::shared_ptr>> MnistDataset::Build() { +std::vector> MnistDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -258,16 +595,84 @@ std::shared_ptr>> MnistDataset::Build() { // Do internal Schema generation. auto schema = std::make_unique(); - RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); TensorShape scalar = TensorShape::CreateScalar(); - RETURN_NULL_IF_ERROR( + RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), std::move(sampler_->Build()))); - return std::make_shared>>(node_ops); + return node_ops; +} + +// Constructor for VOCDataset +VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, + const std::map &class_index, bool decode, + std::shared_ptr sampler) + : dataset_dir_(dataset_dir), + task_(task), + mode_(mode), + class_index_(class_index), + decode_(decode), + sampler_(sampler) {} + +bool VOCDataset::ValidateParams() { + Path dir(dataset_dir_); + if (!dir.IsDirectory()) { + MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; + return false; + } + if (task_ == "Segmentation") { + if (!class_index_.empty()) { + MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task."; + return false; + } + Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt"; + if (!imagesets_file.Exists()) { + MS_LOG(ERROR) << "[Segmentation] imagesets_file is invalid or not exist"; + return false; + } + } else if (task_ == "Detection") { + Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt"; + if (!imagesets_file.Exists()) { + MS_LOG(ERROR) << "[Detection] imagesets_file is invalid or not exist."; + return false; + } + } else { + MS_LOG(ERROR) << "Invalid task: " << task_; + return false; + } + return true; +} + +// Function to build VOCDataset +std::vector> VOCDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(dataset_dir_); + (void)builder->SetTask(task_); + (void)builder->SetMode(mode_); + (void)builder->SetNumWorkers(num_workers_); + (void)builder->SetSampler(std::move(sampler_->Build())); + (void)builder->SetDecode(decode_); + (void)builder->SetClassIndex(class_index_); + + std::shared_ptr op; + RETURN_EMPTY_IF_ERROR(builder->Build(&op)); + node_ops.push_back(op); + return node_ops; } +// DERIVED DATASET CLASSES LEAF-NODE DATASETS +// (In alphabetical order) + BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, std::map>> pad_map) : batch_size_(batch_size), @@ -276,7 +681,7 @@ BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, st cols_to_map_(cols_to_map), pad_map_(pad_map) {} -std::shared_ptr>> BatchDataset::Build() { +std::vector> BatchDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -288,34 +693,39 @@ std::shared_ptr>> BatchDataset::Build() { node_ops.push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, cols_to_map_, pad_map_)); #endif - return std::make_shared>>(node_ops); + return node_ops; } bool BatchDataset::ValidateParams() { if (batch_size_ <= 0) { + MS_LOG(ERROR) << "Batch: Batch size cannot be negative"; return false; } return true; } -RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} - -std::shared_ptr>> RepeatDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - node_ops.push_back(std::make_shared(repeat_count_)); - return std::make_shared>>(node_ops); +// Function to build ConcatOp +ConcatDataset::ConcatDataset(const std::vector> &datasets) : datasets_(datasets) { + this->children = datasets_; } -bool RepeatDataset::ValidateParams() { - if (repeat_count_ <= 0) { +bool ConcatDataset::ValidateParams() { + if (datasets_.empty()) { + MS_LOG(ERROR) << "Concat: concatenated datasets are not specified."; return false; } - return true; } + +std::vector> ConcatDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(connector_que_size_)); + return node_ops; +} + MapDataset::MapDataset(std::vector> operations, std::vector input_columns, std::vector output_columns, const std::vector &project_columns) : operations_(operations), @@ -323,13 +733,10 @@ MapDataset::MapDataset(std::vector> operations, output_columns_(output_columns), project_columns_(project_columns) {} -std::shared_ptr>> MapDataset::Build() { +std::vector> MapDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; - // Currently default is true, and this is not exposed to user. - bool perf_mode = true; - std::vector> tensor_ops; // Build tensorOp from tensorOperation vector @@ -340,19 +747,82 @@ std::shared_ptr>> MapDataset::Build() { // This parameter will be removed with next rebase std::vector col_orders; - auto map_op = - std::make_shared(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_, perf_mode); + auto map_op = std::make_shared(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); if (!project_columns_.empty()) { auto project_op = std::make_shared(project_columns_); node_ops.push_back(project_op); } node_ops.push_back(map_op); - return std::make_shared>>(node_ops); + return node_ops; } bool MapDataset::ValidateParams() { if (operations_.empty()) { + MS_LOG(ERROR) << "Map: No operation is specified."; + return false; + } + + return true; +} + +// Function to build ProjectOp +ProjectDataset::ProjectDataset(const std::vector &columns) : columns_(columns) {} + +bool ProjectDataset::ValidateParams() { + if (columns_.empty()) { + MS_LOG(ERROR) << "No columns are specified."; + return false; + } + return true; +} + +std::vector> ProjectDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(columns_)); + return node_ops; +} + +// Function to build RenameOp +RenameDataset::RenameDataset(const std::vector &input_columns, + const std::vector &output_columns) + : input_columns_(input_columns), output_columns_(output_columns) {} + +bool RenameDataset::ValidateParams() { + if (input_columns_.empty() || output_columns_.empty()) { + MS_LOG(ERROR) << "input and output columns must be specified"; + return false; + } + if (input_columns_.size() != output_columns_.size()) { + MS_LOG(ERROR) << "input and output columns must be the same size"; + return false; + } + return true; +} + +std::vector> RenameDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(input_columns_, output_columns_, connector_que_size_)); + return node_ops; +} + +RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} + +std::vector> RepeatDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(repeat_count_)); + return node_ops; +} + +bool RepeatDataset::ValidateParams() { + if (repeat_count_ <= 0) { + MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative"; return false; } @@ -364,13 +834,13 @@ ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch) : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {} // Function to build the ShuffleOp -std::shared_ptr>> ShuffleDataset::Build() { +std::vector> ShuffleDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; node_ops.push_back(std::make_shared(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, rows_per_buffer_)); - return std::make_shared>>(node_ops); + return node_ops; } // Function to validate the parameters for ShuffleDataset @@ -383,62 +853,71 @@ bool ShuffleDataset::ValidateParams() { return true; } -// Constructor for Cifar10Dataset -Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} +// Constructor for SkipDataset +SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {} -bool Cifar10Dataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - if (num_samples_ < 0) { - MS_LOG(ERROR) << "Number of samples cannot be negative"; +// Function to build the SkipOp +std::vector> SkipDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(skip_count_, connector_que_size_)); + return node_ops; +} + +// Function to validate the parameters for SkipDataset +bool SkipDataset::ValidateParams() { + if (skip_count_ <= -1) { + MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_; return false; } + return true; } -// Function to build CifarOp -std::shared_ptr>> Cifar10Dataset::Build() { +// Constructor for TakeDataset +TakeDataset::TakeDataset(int32_t count) : take_count_(count) {} + +// Function to build the TakeOp +std::vector> TakeDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; - // If user does not specify Sampler, create a default sampler based on the shuffle variable. - if (sampler_ == nullptr) { - sampler_ = CreateDefaultSampler(); - } + node_ops.push_back(std::make_shared(take_count_, connector_que_size_)); + return node_ops; +} - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_NULL_IF_ERROR( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); +// Function to validate the parameters for TakeDataset +bool TakeDataset::ValidateParams() { + if (take_count_ < -1) { + MS_LOG(ERROR) << "Take: Invalid input, take_count: " << take_count_; + return false; + } - node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_, - dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_->Build()))); - return std::make_shared>>(node_ops); + return true; } -// Function to build ProjectOp -ProjectDataset::ProjectDataset(const std::vector &columns) : columns_(columns) {} +// Function to build ZipOp +ZipDataset::ZipDataset(const std::vector> &datasets) : datasets_(datasets) { + for (auto dataset : datasets_) { + this->children.push_back(dataset); + } +} -bool ProjectDataset::ValidateParams() { - if (columns_.empty()) { - MS_LOG(ERROR) << "No columns are specified."; +bool ZipDataset::ValidateParams() { + if (datasets_.empty()) { + MS_LOG(ERROR) << "Zip: dataset to zip are not specified."; return false; } return true; } -std::shared_ptr>> ProjectDataset::Build() { +std::vector> ZipDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; - node_ops.push_back(std::make_shared(columns_)); - return std::make_shared>>(node_ops); + node_ops.push_back(std::make_shared(rows_per_buffer_, connector_que_size_)); + return node_ops; } } // namespace api diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc deleted file mode 100644 index 2a6166f868..0000000000 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc +++ /dev/null @@ -1,1605 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/api/de_pipeline.h" - -#include -#include -#include - -#include "common/utils.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/cache/cache_client.h" -#include "minddata/dataset/engine/dataset_iterator.h" -#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" -#include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/datasetops/filter_op.h" -#include "minddata/dataset/engine/datasetops/source/celeba_op.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" -#include "minddata/dataset/engine/datasetops/source/clue_op.h" -#include "minddata/dataset/engine/datasetops/source/coco_op.h" -#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" -#include "minddata/dataset/engine/datasetops/source/manifest_op.h" -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/random_data_op.h" -#include "minddata/dataset/engine/datasetops/source/text_file_op.h" -#include "minddata/dataset/engine/datasetops/source/voc_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/kernels/py_func_op.h" -#include "minddata/dataset/util/random.h" -#include "minddata/dataset/util/status.h" -#include "minddata/mindrecord/include/shard_category.h" -#include "minddata/mindrecord/include/shard_distributed_sample.h" -#include "minddata/mindrecord/include/shard_sample.h" -#include "minddata/mindrecord/include/shard_shuffle.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); - -static std::unordered_map g_parse_op_func_ = { - {kShuffle, &DEPipeline::ParseShuffleOp}, - {kMindrecord, &DEPipeline::ParseMindRecordOp}, - {kMap, &DEPipeline::ParseMapOp}, - {kFilter, &DEPipeline::ParseFilterOp}, - {kBatch, &DEPipeline::ParseBatchOp}, - {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, - {kBarrier, &DEPipeline::ParseBarrierOp}, - {kRepeat, &DEPipeline::ParseRepeatOp}, - {kSkip, &DEPipeline::ParseSkipOp}, - {kZip, &DEPipeline::ParseZipOp}, - {kConcat, &DEPipeline::ParseConcatOp}, - {kRename, &DEPipeline::ParseRenameOp}, - {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, - {kGenerator, &DEPipeline::ParseGeneratorOp}, - {kTfReader, &DEPipeline::ParseTFReaderOp}, - {kProject, &DEPipeline::ParseProjectOp}, - {kTake, &DEPipeline::ParseTakeOp}, - {kImageFolder, &DEPipeline::ParseImageFolderOp}, - {kMnist, &DEPipeline::ParseMnistOp}, - {kManifest, &DEPipeline::ParseManifestOp}, - {kVoc, &DEPipeline::ParseVOCOp}, - {kCoco, &DEPipeline::ParseCocoOp}, - {kCifar10, &DEPipeline::ParseCifar10Op}, - {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}, - {kRandomData, &DEPipeline::ParseRandomDataOp}, - {kTextFile, &DEPipeline::ParseTextFileOp}, - {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, - {kClue, &DEPipeline::ParseClueOp}}; - -DEPipeline::DEPipeline() : iterator_(nullptr) { - try { - // One time init - (void)GlobalInit(); - - // Instantiate the execution tree - tree_ = std::make_shared(); - repeat_num_ = 1; - batch_size_ = 1; - num_rows_ = 0; - num_classes_ = 0; - temp_batch_size_ = 1; - temp_drop_remainder_ = false; - } catch (const std::exception &err) { - MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << "."; - return; - } -} - -DEPipeline::~DEPipeline() { - { - // Release GIL before joining all threads - py::gil_scoped_release gil_release; - // Release tree - tree_.reset(); - } -} - -// Function to add a Node to the Execution Tree. -Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) { - // For each operator, Parse through the list of arguments, then call the respective builder/constructor. - // Note that each call to the parse function may result in building more than one dataset operator. - // For example, one call to ParseNNNOp may result in multiple internal C nodes: - // nodeA - // | - // nodeB - // | - // nodeC - // However, the python side dataset is more abstract, and it does not know about the potential subtree that - // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the - // python side needs to know about nodeA and NodeC to be able to appropriately hook up parents and child - // to this subtee. - // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse - // function. - DsOpPtr top = nullptr; - DsOpPtr bottom = nullptr; - auto iter = g_parse_op_func_.find(op_name); - if (iter != g_parse_op_func_.end()) { - pFunction func = iter->second; - RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom)); - - if (top == nullptr) { - RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node."); - } - - // It is not required that the parse function always produces the bottom pointer. If it's still null, - // then set top and bottom to be the same operator - if (bottom == nullptr) bottom = top; - - // Pack these pointers into a py dict so that we can return both back to python. - (*output)["top"] = top; - (*output)["bottom"] = bottom; - } else { - RETURN_STATUS_UNEXPECTED("No such Op"); - } - // Associate current dataset op node with the tree. - RETURN_IF_NOT_OK(tree_->AssociateNode(top)); - return Status::OK(); -} -// Function to add a child and parent relationship. -Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) { - // Link this relationship. - // Note parent node takes ownership of the child - return (parent_op->AddChild(child_op)); -} - -// Function to assign the node as root. -Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } - -// Function to launch the tree execution. -Status DEPipeline::LaunchTreeExec() { - RETURN_IF_NOT_OK(tree_->Prepare()); - RETURN_IF_NOT_OK(tree_->Launch()); - iterator_ = std::make_unique(tree_); - if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); - return Status::OK(); -} - -void DEPipeline::PrintTree() { - for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) { - std::stringstream ss; - ss << *itr; - MS_LOG(DEBUG) << "Operator ID is " << itr->id() << ". Details: " << ss.str().c_str() << "."; - } -} - -Status DEPipeline::GetNextAsMap(py::dict *output) { - TensorMap row; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetNextAsMap(&row); - } - RETURN_IF_NOT_OK(s); - // Generate Python dict as return - for (auto el : row) { - (*output)[common::SafeCStr(el.first)] = el.second; - } - return Status::OK(); -} - -Status DEPipeline::GetNextAsList(py::list *output) { - TensorRow row; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->FetchNextTensorRow(&row); - } - RETURN_IF_NOT_OK(s); - // Generate Python list as return - for (auto el : row) { - output->append(el); - } - return Status::OK(); -} - -Status DEPipeline::GetOutputShapes(py::list *output) { - std::vector shapes; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetOutputShapes(&shapes); - } - RETURN_IF_NOT_OK(s); - for (auto el : shapes) { - py::list shape; - for (auto dim : el.AsVector()) { - shape.append(dim); - } - output->append(shape); - } - return Status::OK(); -} - -Status DEPipeline::GetOutputTypes(py::list *output) { - std::vector types; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetOutputTypes(&types); - } - RETURN_IF_NOT_OK(s); - for (auto el : types) { - output->append(el.AsNumpyType()); - } - return Status::OK(); -} - -int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; } - -int DEPipeline::GetBatchSize() const { return batch_size_; } - -int DEPipeline::GetRepeatCount() const { return repeat_num_; } - -float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -std::string ToString(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -std::vector ToStringVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (!l.is_none()) - vector.push_back(py::str(l)); - else - vector.emplace_back(""); - } - return vector; -} - -std::set ToStringSet(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::set set; - for (auto l : list) { - if (!l.is_none()) { - (void)set.insert(py::str(l)); - } - } - return set; -} - -std::map ToStringMap(const py::handle handle) { - py::dict dict = py::reinterpret_borrow(handle); - std::map map; - for (auto p : dict) { - (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second))); - } - return map; -} - -std::vector ToIntVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (!l.is_none()) { - vector.push_back(ToInt(l)); - } - } - return vector; -} - -std::vector ToTypeVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (l.is_none()) { - vector.emplace_back(DataType()); - } else { - vector.push_back(l.cast()); - } - } - return vector; -} - -Status DEPipeline::SetBatchParameters(const py::dict &args) { - if (args["batch_size"].is_none()) { - std::string err_msg = "Error: batchSize is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - temp_batch_size_ = ToInt(args["batch_size"]); - CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid."); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "drop_remainder") { - temp_drop_remainder_ = ToBool(value); - } - } - } - - return Status::OK(); -} - -Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - if (!args["buffer_size"].is_none()) { - (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); - } else { - std::string err_msg = "Error: Shuffle buffer size is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "reshuffle_each_epoch") { - (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, - std::vector> *operators, - int num_padded) { - auto sampler = py::reinterpret_borrow(handle); - auto create = sampler.attr("create_for_minddataset"); - auto op = create().cast>(); - std::stack> stack_ops; - while (op != nullptr) { - auto sampler_op = std::dynamic_pointer_cast(op); - if (sampler_op && num_padded > 0) { - sampler_op->SetNumPaddedSamples(num_padded); - stack_ops.push(sampler_op); - } else { - stack_ops.push(op); - } - op = op->GetChildOp(); - } - while (!stack_ops.empty()) { - operators->push_back(stack_ops.top()); - stack_ops.pop(); - } - return Status::OK(); -} - -Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_file"].is_none()) { - std::string err_msg = "Error: at least one of dataset_files is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - bool load_dataset = ToBool(args["load_dataset"]); - if (load_dataset == true) { - (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); - } else { - (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); - } - (void)builder->SetLoadDataset(load_dataset); - std::vector in_col_names; - if (!args["columns_list"].is_none()) { - in_col_names = ToStringVector(args["columns_list"]); - if (in_col_names.empty() || in_col_names[0].empty()) { - std::string err_msg = "Error: columns_list is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetColumnsToLoad(in_col_names); - } - - if (!args["padded_sample"].is_none()) { - (void)builder->SetPaddedSample(args["padded_sample"]); - (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); - } - std::vector> operators; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumMindRecordWorkers(ToInt(value)); - } else if (key == "block_reader" && ToBool(value) == true) { - (void)builder->SetBlockReader(); - } else if (key == "sampler") { - int num_padded = 0; - if (!args["num_padded"].is_none()) { - num_padded = ToInt(args["num_padded"]); - } - RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); - } - } - } - - if (!operators.empty()) { - (void)builder->SetOperators(operators); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - num_rows_ = op->num_rows(); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - MapOp::Builder map_builder; - std::vector> tensor_op_list; - std::vector project_columns; - std::shared_ptr cache_client = nullptr; - int num_workers = 0; - - if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "input_columns") { - std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)map_builder.SetInColNames(in_col_names); - } else if (key == "output_columns") { - (void)map_builder.SetOutColNames(ToStringVector(value)); - } else if (key == "columns_order") { - project_columns = ToStringVector(value); - } else if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)map_builder.SetNumWorkers(num_workers); - } else if (key == "prefetch_size") { - (void)map_builder.SetOpConnectorSize(ToInt(value)); - } else if (key == "operations") { - py::handle tensor_ops = args["operations"]; - // operation can be a list of TensorOps or a single TensorOp. - if (py::isinstance(tensor_ops)) { - for (auto op : tensor_ops) { - std::shared_ptr tensor_op; - if (py::isinstance(op)) { - tensor_op = op.cast>(); - } else if (py::isinstance(op)) { - tensor_op = std::make_shared(op.cast()); - } else { - RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); - } - tensor_op_list.push_back(tensor_op); - } - } - if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); - (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); - } else if (key == "cache") { - cache_client = value.cast>(); - } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); - } - } - } - - std::shared_ptr map_op; - RETURN_IF_NOT_OK(map_builder.Build(&map_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); - *top = map_op; - - // Add a project op over top of the map if the user wanted to reposition the columns - if (!project_columns.empty()) { - ProjectOp::Builder proj_builder(project_columns); - std::shared_ptr proj_op; - RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); - RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); - *top = proj_op; - *bottom = map_op; - } - - // Additionally, add a cache if required. This will go over top of the project op if one - // was created, otherwise it goes over top of the map op - if (cache_client) { - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); - *top = cache_op; - *bottom = map_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - - if (args["predicate"].is_none()) { - RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); - } - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "predicate") { - py::handle op = args["predicate"]; - if (!py::isinstance(op)) { - RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); - } - py::function predicate_func = op.cast(); - (void)builder->SetPredicateFunc(std::move(predicate_func)); - } else if (key == "input_columns") { - std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)builder->SetInColNames(in_col_names); - } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - repeat_num_ = ToInt(args["count"]); - std::shared_ptr op; - RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "source") { - py::object obj = py::cast(&value); - if (!py::isinstance(obj)) { - std::string err_msg = "Error: generator is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetGeneratorFunction(obj.cast()); - } else if (key == "column_names") { - (void)builder->SetColumnNames(ToStringVector(value)); - } else if (key == "column_types") { - (void)builder->SetColumnTypes(ToTypeVector(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder; - if (py::isinstance(args["batch_size"])) { - batch_size_ = ToInt(args["batch_size"]); - CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); - builder = std::make_shared(ToInt(args["batch_size"])); - } else if (py::isinstance(args["batch_size"])) { - builder = std::make_shared(1); - (void)builder->SetBatchSizeFunc(args["batch_size"].cast()); - } else { - std::string err_msg = "Error: batch_size is neither an Integer nor a python function"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "drop_remainder") { - (void)builder->SetDrop(ToBool(value)); - } - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } - if (key == "per_batch_map") { - (void)builder->SetBatchMapFunc(value.cast()); - } - if (key == "input_columns") { - (void)builder->SetColumnsToMap(ToStringVector(value)); - } - if (key == "pad_info") { - PadInfo pad_info; - RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); - (void)builder->SetPaddingMap(pad_info, true); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", - "bucket_batch_sizes"}; - for (auto name : mandatory_arguments) { - if (args[name.c_str()].is_none()) { - std::string err_msg = "Error: " + name + " is not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - - std::shared_ptr builder = std::make_shared( - ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), - ToIntVector(args[mandatory_arguments[2].c_str()])); - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "length_dependent_columns") { - (void)builder->SetLengthDependentColumns(ToStringVector(value)); - } - if (key == "bucket_boundaries") { - (void)builder->SetBucketBoundaries(ToIntVector(value)); - } - if (key == "bucket_batch_sizes") { - (void)builder->SetBucketBatchSizes(ToIntVector(value)); - } - if (key == "element_length_function") { - (void)builder->SetElementLengthFunction(value.cast()); - } - if (key == "pad_info") { - PadInfo pad_info; - RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); - (void)builder->SetPadInfo(pad_info); - } - if (key == "pad_to_bucket_boundary") { - (void)builder->SetPadToBucketBoundary(ToBool(value)); - } - if (key == "drop_remainder") { - (void)builder->SetDropRemainder(ToBool(value)); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - // Right now barrier should only take num_rows_per_buffer = 1 - // The reason for this is because having it otherwise can lead to blocking issues - // See barrier_op.h for more details - (void)builder->SetRowsPerBuffer(1); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "condition_name") { - (void)builder->SetConditionName(ToString(value)); - } else if (key == "condition_func") { - (void)builder->SetConditionFunc(value.cast()); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - int32_t prefetch_size = 0; - if (args.contains("prefetch_size")) { - if (args["prefetch_size"].is_none()) { - prefetch_size = 16; - } else { - prefetch_size = ToInt(args["prefetch_size"]); - } - } - std::shared_ptr builder = std::make_shared(prefetch_size); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "queue_name") { - (void)builder->SetChannelName(ToString(value)); - } else if (key == "device_type") { - (void)builder->SetDeviceType(ToString(value)); - } else if (key == "device_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "num_batch") { - (void)builder->SetNumBatch(ToInt(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector in_col_names; - std::vector out_col_names; - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "input_columns") { - in_col_names = ToStringVector(value); - } else if (key == "output_columns") { - out_col_names = ToStringVector(value); - } - } - } - if (in_col_names.empty() || in_col_names[0].empty()) { - std::string err_msg = "Error: input_column_names is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (out_col_names.empty() || out_col_names[0].empty()) { - std::string err_msg = "Error: output_column_names is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetInColNames(in_col_names); - (void)builder->SetOutColNames(out_col_names); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - std::vector files_list; - std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; - int num_workers = 0; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetDatasetFilesList(files_list); - } else { - std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_load; - bool schema_exists = false; - bool shuffle_required = false; - int64_t num_devices = 0; - int64_t total_rows = 0; - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)builder->SetNumWorkers(num_workers); - } else if (key == "columns_list") { - columns_to_load = ToStringVector(value); - (void)builder->SetColumnsToLoad(columns_to_load); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "schema_file_path" || key == "schema_json_string") { - schema_exists = true; - } else if (key == "num_samples") { - total_rows = ToInt(value); - (void)builder->setTotalRows(total_rows); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "shard_equal_rows") { - (void)builder->SetShardEqualRows(ToBool(value)); - } else if (key == "cache") { - cache_client = value.cast>(); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); - } - } - } - if (schema_exists) { - std::unique_ptr schema = std::make_unique(); - if (args.contains("schema_file_path")) { - RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); - } else { - RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); - } - (void)builder->SetDataSchema(std::move(schema)); - } - - // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed - // because TFReaderOp is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. - if (sampler) { - (void)builder->SetSampler(std::move(sampler)); - } else if (cache_client) { - int64_t num_samples = 0; - int64_t start_index = 0; - sampler = std::make_shared(num_samples, start_index); - (void)builder->SetSampler(std::move(sampler)); - } - - std::shared_ptr tf_op; - RETURN_IF_NOT_OK(builder->Build(&tf_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op)); - *top = tf_op; - - if (!cache_client && shuffle_required) { - const boolean estimate = true; - const int64_t workers = 8; - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset via estimate and then compute the shuffle size - RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op)); - *top = shuffle_op; - *bottom = tf_op; - } - - // Add a cache op over this op if required and update the output subtree (top/bottom) - if (cache_client) { - // Note, it is not allowed to have both shuffle and cache - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op)); - *top = cache_op; - *bottom = tf_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["columns"].is_none()) { - std::string err_msg = "Error: columns is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_project = ToStringVector(args["columns"]); - std::shared_ptr builder = std::make_shared(columns_to_project); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - int num_workers = 0; - std::shared_ptr cache_client = nullptr; - std::shared_ptr builder = std::make_shared(); - (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)builder->SetNumWorkers(num_workers); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "extensions") { - (void)builder->SetExtensions(ToStringSet(value)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "cache") { - cache_client = value.cast>(); - } - } - } - std::shared_ptr if_op; - RETURN_IF_NOT_OK(builder->Build(&if_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(if_op)); - *top = if_op; - - // Additionally, add a cache if required. - // Note that this cache op is only acting as a place holder for the caching position - // within the tree. Later, a pre-pass will execute a tree transform to set up the actual - // caching logic in the tree. - if (cache_client) { - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op)); - *top = cache_op; - *bottom = if_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_file"].is_none()) { - std::string err_msg = "Error: No dataset files specified for manifest"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr builder = std::make_shared(); - (void)builder->SetManifestFile(ToString(args["dataset_file"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "usage") { - (void)builder->SetUsage(ToString(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["task"].is_none()) { - std::string err_msg = "Error: No task specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["mode"].is_none()) { - std::string err_msg = "Error: No mode specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - (void)builder->SetTask(ToString(args["task"])); - (void)builder->SetMode(ToString(args["mode"])); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - - return Status::OK(); -} - -Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["annotation_file"].is_none()) { - std::string err_msg = "Error: No annotation_file specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["task"].is_none()) { - std::string err_msg = "Error: No task specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - (void)builder->SetFile(ToString(args["annotation_file"])); - (void)builder->SetTask(ToString(args["task"])); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetCifarDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - - (void)builder->SetCifarType(true); - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetCifarDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - - (void)builder->SetCifarType(false); - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - RandomDataOp::Builder builder; - std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; - int num_workers = 0; - - if (args["total_rows"].is_none()) { - std::string err_msg = "Error: total_rows is a required argument"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_load; - bool schema_exists = false; - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)builder.SetNumWorkers(num_workers); - } else if (key == "schema_file_path" || key == "schema_json_string") { - schema_exists = true; - } else if (key == "columns_list") { - columns_to_load = ToStringVector(value); - } else if (key == "total_rows") { - // This is not sampling here. The random data op needs to know how much data to generate. - (void)builder.SetTotalRows(ToInt(value)); - } else if (key == "cache") { - cache_client = value.cast>(); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); - } - } - } - if (schema_exists) { - std::unique_ptr schema = std::make_unique(); - if (args.contains("schema_file_path")) { - RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); - } else { - RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); - } - (void)builder.SetDataSchema(std::move(schema)); - } - - // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed - // because RandomDataOp is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. - if (sampler) { - (void)builder.SetSampler(std::move(sampler)); - } else if (cache_client) { - int64_t num_samples = 0; - int64_t start_index = 0; - sampler = std::make_shared(num_samples, start_index); - (void)builder.SetSampler(std::move(sampler)); - } - - std::shared_ptr random_op = nullptr; - RETURN_IF_NOT_OK(builder.Build(&random_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); - *top = random_op; - - // Add a cache op over this op if required and update the output subtree (top/bottom) - if (cache_client) { - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); - *top = cache_op; - *bottom = random_op; - } - - return Status::OK(); -} - -int32_t DEPipeline::GetNumClasses() const { return num_classes_; } - -Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - - std::shared_ptr builder = std::make_shared(); - if (builder == nullptr) { - std::string err_msg = "Create celebaop builder failed"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - (void)builder->SetCelebADir(ToString(args["dataset_dir"])); - for (const auto &arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "extensions") { - (void)builder->SetExtensions(ToStringSet(value)); - } else if (key == "dataset_type") { - (void)builder->SetDatasetType(ToString(value)); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetTextFilesList(files_list); - } else { - RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); - } - // Optional arguments - bool shuffle_required = false; - int64_t num_devices = 0; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "num_samples") { - (void)builder->SetTotalRows(ToInt(value)); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } - } - } - - std::shared_ptr txt_op; - RETURN_IF_NOT_OK(builder->Build(&txt_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); - *top = txt_op; - - if (shuffle_required) { - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset and then compute the shuffle size - RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); - *top = shuffle_op; - *bottom = txt_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { - for (auto p : py::reinterpret_borrow(value)) { - if (!p.second.is_none()) { - auto tp = py::reinterpret_borrow(p.second); - CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); - TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); - std::shared_ptr pad_val = nullptr; - if (py::isinstance(tp[1])) { - std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED( - Tensor::CreateTensor(&pad_val, std::vector{pad_val_string}, TensorShape::CreateScalar()), - "Cannot create pad_value Tensor"); - } else { - float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(), - DataType(DataType::DE_FLOAT32)), - "Cannot create pad_value Tensor"); - pad_val->SetItemAt({}, pad_val_float); - } - (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); - } else { // tuple is None - (void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}}); - } - } - return Status::OK(); -} - -Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "freq_range") { - py::tuple tp = py::reinterpret_borrow(value); - if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow(tp[0])); - if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow(tp[1])); - } else if (key == "top_k") { - builder->SetTopK(py::reinterpret_borrow(value)); - } else if (key == "columns") { - (void)builder->SetColumnNames(ToStringVector(value)); - } else if (key == "vocab") { - (void)builder->SetVocab(value.cast>()); - } else if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "special_first") { - (void)builder->SetSpecialFirst(ToBool(value)); - } else if (key == "special_tokens") { - (void)builder->SetSpecialTokens(ToStringVector(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetClueFilesList(files_list); - } else { - RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); - } - // Optional arguments - bool shuffle_required = false; - int64_t num_devices = 0; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "num_samples") { - (void)builder->SetNumSamples(ToInt(value)); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "cols_to_keyword") { - std::map map_dict; - for (auto p : py::reinterpret_borrow(value)) { - if (!p.second.is_none()) { - map_dict.insert({ToString(p.first), ToString(p.second)}); - } else { - map_dict.insert({ToString(p.first), ToString(p.first)}); - } - } - (void)builder->SetColsKeyMap(map_dict); - } - } - } - - std::shared_ptr clue_op; - RETURN_IF_NOT_OK(builder->Build(&clue_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); - *top = clue_op; - - if (shuffle_required) { - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset and then compute the shuffle size - RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op)); - *top = shuffle_op; - *bottom = clue_op; - } - - return Status::OK(); -} - -// Helper function to inject the cache operator over top of the current operation being built. -Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num_workers, - std::shared_ptr input_op, std::shared_ptr *cache_op) { - std::shared_ptr new_cache_op = nullptr; - CacheOp::Builder cache_builder; - // use the same number of workers as the leaf. We need some optimization here, the user does not - // give the cache op number of workers directly. - if (num_workers != 0) { - (void)cache_builder.SetNumWorkers(num_workers); - } - (void)cache_builder.SetClient(cache_client); - RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op)); - RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op)); - // We have now created: - // - // CacheOp - // | - // input_op - // - *cache_op = new_cache_op; - - return Status::OK(); -} - -// Helper function to inject a shuffle operator over top of the current operation being built. -Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, - std::shared_ptr *shuffle_op) { - std::shared_ptr new_shuffle_op = nullptr; - ShuffleOp::Builder shuffle_builder; - - (void)shuffle_builder.SetShuffleSize(shuffle_size); - RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op)); - RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op)); - // We have now created: - // - // ShuffleOp - // | - // input_op - // - *shuffle_op = new_shuffle_op; - - return Status::OK(); -} - -// Common code for computing a default shuffle size -Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int64_t *shuffle_size) { - const int64_t average_files_multiplier = 4; - const int64_t shuffle_max = 10000; - int64_t avg_rows_per_file = 0; - - // Adjust the num rows per shard if sharding was given - if (num_devices > 0) { - if (num_rows % num_devices == 0) { - num_rows = num_rows / num_devices; - } else { - num_rows = (num_rows / num_devices) + 1; - } - } - - // Cap based on total rows directive. Some ops do not have this and give value of 0. - if (total_rows > 0) { - num_rows = std::min(num_rows, total_rows); - } - - // get the average per file - avg_rows_per_file = num_rows / num_files; - - *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h deleted file mode 100644 index 755e827ef2..0000000000 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h +++ /dev/null @@ -1,225 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_API_DE_PIPELINE_H_ -#define DATASET_API_DE_PIPELINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "minddata/dataset/core/client.h" // DE client -#include "minddata/dataset/engine/dataset_iterator.h" -#include "minddata/dataset/util/status.h" -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -using DsOpPtr = std::shared_ptr; - -class CacheClient; - -// enum for the dataset operator names -enum OpName { - kShuffle, - kMindrecord, - kBatch, - kBucketBatch, - kBarrier, - kCache, - kRepeat, - kSkip, - kTake, - kZip, - kConcat, - kMap, - kFilter, - kDeviceQueue, - kGenerator, - kRename, - kTfReader, - kProject, - kImageFolder, - kMnist, - kManifest, - kVoc, - kCoco, - kCifar10, - kCifar100, - kCelebA, - kRandomData, - kTextFile, - kBuildVocab, - kClue -}; - -// The C++ binder class that we expose to the python script. -class DEPipeline { - public: - DEPipeline(); - - ~DEPipeline(); - - // Function to add a Node to the Execution Tree. - Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); - - // Function to add a child and parent relationship. - static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); - - // Function to assign the node as root. - Status AssignRootNode(const DsOpPtr &dataset_op); - - // Function to launch the tree execution. - Status LaunchTreeExec(); - - // Get a row of data as dictionary of column name to the value. - Status GetNextAsMap(py::dict *output); - - // Get a row of data as list. - Status GetNextAsList(py::list *output); - - Status GetOutputShapes(py::list *output); - - Status GetOutputTypes(py::list *output); - - int GetDatasetSize() const; - - int GetBatchSize() const; - - int GetRepeatCount() const; - - Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status BuildMindrecordSamplerChain(const py::handle &handle, - std::vector> *operators, - int num_padded); - - Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom); - - Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - void PrintTree(); - - int32_t GetNumClasses() const; - - Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status SetBatchParameters(const py::dict &args); - - Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - private: - // Execution tree that links the dataset operators. - std::shared_ptr tree_; - - std::unique_ptr iterator_; - - static Status ParsePadInfo(py::handle value, PadInfo *pad_info); - - /// \brief Helper function to inject a cache operator over top of the current operation being built. - /// \param[in] cache_client The client to use for caching - /// \param[in] num_workers The number of workers to use in the cache op - /// \param[in] input_op The operator to build the cache on top of - /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be - /// the cache operator - /// \return Status return code - Status AddCacheOp(std::shared_ptr cache_client, int num_workers, std::shared_ptr input_op, - std::shared_ptr *cache_op); - - /// \brief Helper function to inject a shuffle operator over top of the current operation being built. - /// \param[in] shuffle_size The size to use in the shuffle buffer - /// \param[in] input_op The operator to build shuffle on top of - /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be - /// the shuffle operator - /// \return Status return code - Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, - std::shared_ptr *shuffle_op); - - /// \brief Helper function to compute the shuffle size - /// \param[in] num_files The number of files in the dataset - /// \param[in] num_devices The number of devices in the dataset - /// \param[in] num_rows The number of rows in the dataset - /// \param[in] total_rows An upper bound on the total rows in the dataset - /// \param[out] shuffle_size The resultant computed shuffle size - /// \return Status return code - Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int64_t *shuffle_size); - - int batch_size_; - int repeat_num_; - int num_rows_; - int num_classes_; - - int temp_batch_size_; - bool temp_drop_remainder_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_API_DE_PIPELINE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc b/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc new file mode 100644 index 0000000000..efb2dbdf97 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/de_tensor.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/de_tensor.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "mindspore/core/ir/dtype/type_id.h" +#include "utils/hashing.h" +#include "mindspore/lite/src/ir/tensor.h" + +namespace mindspore { +namespace tensor { +dataset::DataType MSTypeToDEType(TypeId data_type) { + switch (data_type) { + case kNumberTypeBool: + return dataset::DataType(dataset::DataType::DE_BOOL); + case kNumberTypeInt8: + return dataset::DataType(dataset::DataType::DE_INT8); + case kNumberTypeUInt8: + return dataset::DataType(dataset::DataType::DE_UINT8); + case kNumberTypeInt16: + return dataset::DataType(dataset::DataType::DE_INT16); + case kNumberTypeUInt16: + return dataset::DataType(dataset::DataType::DE_UINT16); + case kNumberTypeInt32: + return dataset::DataType(dataset::DataType::DE_INT32); + case kNumberTypeUInt32: + return dataset::DataType(dataset::DataType::DE_UINT32); + case kNumberTypeInt64: + return dataset::DataType(dataset::DataType::DE_INT64); + case kNumberTypeUInt64: + return dataset::DataType(dataset::DataType::DE_UINT64); + case kNumberTypeFloat16: + return dataset::DataType(dataset::DataType::DE_FLOAT16); + case kNumberTypeFloat32: + return dataset::DataType(dataset::DataType::DE_FLOAT32); + case kNumberTypeFloat64: + return dataset::DataType(dataset::DataType::DE_FLOAT64); + default: + return dataset::DataType(dataset::DataType::DE_UNKNOWN); + } +} + +TypeId DETypeToMSType(dataset::DataType data_type) { + switch (data_type.value()) { + case dataset::DataType::DE_BOOL: + return mindspore::TypeId::kNumberTypeBool; + case dataset::DataType::DE_INT8: + return mindspore::TypeId::kNumberTypeInt8; + case dataset::DataType::DE_UINT8: + return mindspore::TypeId::kNumberTypeUInt8; + case dataset::DataType::DE_INT16: + return mindspore::TypeId::kNumberTypeInt16; + case dataset::DataType::DE_UINT16: + return mindspore::TypeId::kNumberTypeUInt16; + case dataset::DataType::DE_INT32: + return mindspore::TypeId::kNumberTypeInt32; + case dataset::DataType::DE_UINT32: + return mindspore::TypeId::kNumberTypeUInt32; + case dataset::DataType::DE_INT64: + return mindspore::TypeId::kNumberTypeInt64; + case dataset::DataType::DE_UINT64: + return mindspore::TypeId::kNumberTypeUInt64; + case dataset::DataType::DE_FLOAT16: + return mindspore::TypeId::kNumberTypeFloat16; + case dataset::DataType::DE_FLOAT32: + return mindspore::TypeId::kNumberTypeFloat32; + case dataset::DataType::DE_FLOAT64: + return mindspore::TypeId::kNumberTypeFloat64; + default: + return kTypeUnknown; + } +} + +MSTensor *DETensor::CreateTensor(TypeId data_type, const std::vector &shape) { + return new DETensor(data_type, shape); +} + +MSTensor *DETensor::CreateTensor(const std::string &path) { + std::shared_ptr t; + (void)dataset::Tensor::CreateFromFile(path, &t); + return new DETensor(std::move(t)); +} + +DETensor::DETensor(TypeId data_type, const std::vector &shape) { + std::vector t_shape; + t_shape.reserve(shape.size()); + std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape), + [](int s) -> dataset::dsize_t { return static_cast(s); }); + dataset::Tensor::CreateEmpty(dataset::TensorShape(t_shape), MSTypeToDEType(data_type), &this->tensor_impl_); +} + +DETensor::DETensor(std::shared_ptr tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); } + +MSTensor *DETensor::ConvertToLiteTensor() { + // static MSTensor::CreateTensor is only for the LiteTensor + MSTensor *tensor = MSTensor::CreateTensor(this->data_type(), this->shape()); + MS_ASSERT(tensor->Size() == this->Size()); + memcpy_s(tensor->MutableData(), tensor->Size(), this->MutableData(), this->Size()); + return tensor; +} + +std::shared_ptr DETensor::tensor() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_; +} + +TypeId DETensor::data_type() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return DETypeToMSType(this->tensor_impl_->type()); +} + +TypeId DETensor::set_data_type(TypeId data_type) { + MS_ASSERT(this->tensor_impl_ != nullptr); + if (data_type != this->data_type()) { + std::shared_ptr temp; + dataset::Tensor::CreateFromMemory(this->tensor_impl_->shape(), MSTypeToDEType(data_type), + this->tensor_impl_->GetBuffer(), &temp); + this->tensor_impl_ = temp; + } + return data_type; +} + +std::vector DETensor::shape() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + std::vector t_shape = this->tensor_impl_->shape().AsVector(); + std::vector shape; + shape.reserve(t_shape.size()); + std::transform(t_shape.begin(), t_shape.end(), std::back_inserter(shape), + [](dataset::dsize_t s) -> int { return static_cast(s); }); + return shape; +} + +size_t DETensor::set_shape(const std::vector &shape) { + MS_ASSERT(this->tensor_impl_ != nullptr); + std::vector t_shape; + t_shape.reserve(shape.size()); + std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape), + [](int s) -> dataset::dsize_t { return static_cast(s); }); + dataset::Status rc = this->tensor_impl_->Reshape(dataset::TensorShape(t_shape)); + return shape.size(); +} + +int DETensor::DimensionSize(size_t index) const { + MS_ASSERT(this->tensor_impl_ != nullptr); + int dim_size = -1; + auto shape = this->shape(); + if (index < shape.size()) { + dim_size = shape[index]; + } else { + MS_LOG(ERROR) << "Dimension index is wrong: " << index; + } + return dim_size; +} + +int DETensor::ElementsNum() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->Size(); +} + +std::size_t DETensor::hash() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + auto shape = this->shape(); + std::size_t hash_value = std::hash{}(SizeToInt(this->data_type())); + hash_value = hash_combine(hash_value, std::hash{}(shape.size())); + // hash all elements may costly, so only take at most 4 elements into account based on + // some experiments. + for (size_t i = 0; (i < shape.size()) && (i < 4); ++i) { + hash_value = hash_combine(hash_value, (std::hash{}(shape[i]))); + } + return hash_value; +} + +size_t DETensor::Size() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->SizeInBytes(); +} + +void *DETensor::MutableData() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->GetMutableBuffer(); +} + +} // namespace tensor +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/execute.cc b/mindspore/ccsrc/minddata/dataset/api/execute.cc new file mode 100644 index 0000000000..548bb8866b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/execute.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/execute.h" +#include "minddata/dataset/include/de_tensor.h" +#include "minddata/dataset/include/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +namespace api { + +Execute::Execute(std::shared_ptr op) : op_(std::move(op)) {} + +std::shared_ptr Execute::operator()(std::shared_ptr input) { + // Build the op + if (op_ == nullptr) { + MS_LOG(ERROR) << "Input TensorOperation is not valid"; + return nullptr; + } + + std::shared_ptr de_input = std::dynamic_pointer_cast(input)->tensor(); + if (de_input == nullptr) { + MS_LOG(ERROR) << "Input Tensor is not valid"; + return nullptr; + } + std::shared_ptr transform = op_->Build(); + std::shared_ptr de_output; + Status rc = transform->Compute(de_input, &de_output); + + if (rc.IsError()) { + // execution failed + MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString(); + return nullptr; + } + return std::make_shared(std::move(de_output)); +} + +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index 068bcfaa04..e60ddc7643 100644 --- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -52,9 +52,17 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { // Iterative BFS converting Dataset tree into runtime Execution tree. std::queue, std::shared_ptr>> q; - if (ds != nullptr) { + if (ds == nullptr) { + RETURN_STATUS_UNEXPECTED("Input is null pointer"); + } else { // Convert the current root node. - auto root_op = ds->Build()->front(); + auto root_ops = ds->Build(); + if (root_ops.empty()) { + RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); + } + + auto root_op = root_ops.front(); + RETURN_UNEXPECTED_IF_NULL(root_op); RETURN_IF_NOT_OK(tree_->AssociateNode(root_op)); @@ -68,20 +76,22 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { // Iterate through all the direct children of the first element in our BFS queue for (auto child : node_pair.first->children) { auto child_ops = child->Build(); - RETURN_UNEXPECTED_IF_NULL(child_ops); + if (child_ops.empty()) { + RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); + } auto node_op = node_pair.second; // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them // with the execution tree and add the child and parent relationship between the nodes // Note that some Dataset objects might return more than one DatasetOps // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset - for (auto child_op : *child_ops) { + for (auto child_op : child_ops) { RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); RETURN_IF_NOT_OK(node_op->AddChild(child_op)); node_op = child_op; } // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current // execution tree) to the BFS queue - q.push(std::make_pair(child, child_ops->back())); + q.push(std::make_pair(child, child_ops.back())); } } RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc new file mode 100644 index 0000000000..c578c295e5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/api/python/de_pipeline.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER( + DEPipeline, 0, ([](const py::module *m) { + (void)py::class_(*m, "DEPipeline") + .def(py::init<>()) + .def( + "AddNodeToTree", + [](DEPipeline &de, const OpName &op_name, const py::dict &args) { + py::dict out; + THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); + return out; + }, + py::return_value_policy::reference) + .def_static("AddChildToParentNode", + [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { + THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); + }) + .def("AssignRootNode", + [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) + .def("SetBatchParameters", + [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) + .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); }) + .def("GetNextAsMap", + [](DEPipeline &de) { + py::dict out; + THROW_IF_ERROR(de.GetNextAsMap(&out)); + return out; + }) + .def("GetNextAsList", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetNextAsList(&out)); + return out; + }) + .def("GetOutputShapes", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetOutputShapes(&out)); + return out; + }) + .def("GetOutputTypes", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetOutputTypes(&out)); + return out; + }) + .def("GetDatasetSize", &DEPipeline::GetDatasetSize) + .def("GetBatchSize", &DEPipeline::GetBatchSize) + .def("GetNumClasses", &DEPipeline::GetNumClasses) + .def("GetRepeatCount", &DEPipeline::GetRepeatCount) + .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); }) + .def("SaveDataset", [](DEPipeline &de, const std::vector &file_names, const std::string &file_type) { + THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); + return true; + }); + })); + +PYBIND_REGISTER(OpName, 0, ([](const py::module *m) { + (void)py::enum_(*m, "OpName", py::arithmetic()) + .value("SHUFFLE", OpName::kShuffle) + .value("BATCH", OpName::kBatch) + .value("BUCKETBATCH", OpName::kBucketBatch) + .value("BARRIER", OpName::kBarrier) + .value("MINDRECORD", OpName::kMindrecord) + .value("CACHE", OpName::kCache) + .value("REPEAT", OpName::kRepeat) + .value("SKIP", OpName::kSkip) + .value("TAKE", OpName::kTake) + .value("ZIP", OpName::kZip) + .value("CONCAT", OpName::kConcat) + .value("MAP", OpName::kMap) + .value("FILTER", OpName::kFilter) + .value("DEVICEQUEUE", OpName::kDeviceQueue) + .value("GENERATOR", OpName::kGenerator) + .export_values() + .value("RENAME", OpName::kRename) + .value("TFREADER", OpName::kTfReader) + .value("PROJECT", OpName::kProject) + .value("IMAGEFOLDER", OpName::kImageFolder) + .value("MNIST", OpName::kMnist) + .value("MANIFEST", OpName::kManifest) + .value("VOC", OpName::kVoc) + .value("COCO", OpName::kCoco) + .value("CIFAR10", OpName::kCifar10) + .value("CIFAR100", OpName::kCifar100) + .value("RANDOMDATA", OpName::kRandomData) + .value("BUILDVOCAB", OpName::kBuildVocab) + .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab) + .value("CELEBA", OpName::kCelebA) + .value("TEXTFILE", OpName::kTextFile) + .value("EPOCHCTRL", OpName::kEpochCtrl) + .value("CSV", OpName::kCsv) + .value("CLUE", OpName::kClue); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc new file mode 100644 index 0000000000..affc3600df --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/api/python/de_pipeline.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(GlobalContext, 0, ([](const py::module *m) { + (void)py::class_(*m, "GlobalContext") + .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference); + })); + +PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) { + (void)py::class_>(*m, "ConfigManager") + .def("__str__", &ConfigManager::ToString) + .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) + .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) + .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) + .def("set_op_connector_size", &ConfigManager::set_op_connector_size) + .def("set_seed", &ConfigManager::set_seed) + .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) + .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) + .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) + .def("get_worker_connector_size", &ConfigManager::worker_connector_size) + .def("get_op_connector_size", &ConfigManager::op_connector_size) + .def("get_seed", &ConfigManager::seed) + .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) + .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); + })); + +PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) { + (void)py::class_>(*m, "Tensor", py::buffer_protocol()) + .def(py::init([](py::array arr) { + std::shared_ptr out; + THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out)); + return out; + })) + .def_buffer([](Tensor &tensor) { + py::buffer_info info; + THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); + return info; + }) + .def("__str__", &Tensor::ToString) + .def("shape", &Tensor::shape) + .def("type", &Tensor::type) + .def("as_array", [](py::object &t) { + auto &tensor = py::cast(t); + if (tensor.type() == DataType::DE_STRING) { + py::array res; + tensor.GetDataAsNumpyStrings(&res); + return res; + } + py::buffer_info info; + THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); + return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t); + }); + })); + +PYBIND_REGISTER(TensorShape, 0, ([](const py::module *m) { + (void)py::class_(*m, "TensorShape") + .def(py::init()) + .def("__str__", &TensorShape::ToString) + .def("as_list", &TensorShape::AsPyList) + .def("is_known", &TensorShape::known); + })); + +PYBIND_REGISTER(DataType, 0, ([](const py::module *m) { + (void)py::class_(*m, "DataType") + .def(py::init()) + .def(py::self == py::self) + .def("__str__", &DataType::ToString) + .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); + })); + +PYBIND_REGISTER(BorderType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "BorderType", py::arithmetic()) + .value("DE_BORDER_CONSTANT", BorderType::kConstant) + .value("DE_BORDER_EDGE", BorderType::kEdge) + .value("DE_BORDER_REFLECT", BorderType::kReflect) + .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric) + .export_values(); + })); + +PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) { + (void)py::enum_(*m, "InterpolationMode", py::arithmetic()) + .value("DE_INTER_LINEAR", InterpolationMode::kLinear) + .value("DE_INTER_CUBIC", InterpolationMode::kCubic) + .value("DE_INTER_AREA", InterpolationMode::kArea) + .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) + .export_values(); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc new file mode 100644 index 0000000000..aa5ba9e561 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/engine/cache/cache_client.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { + (void)py::class_>(*m, "CacheClient") + .def(py::init()); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc new file mode 100644 index 0000000000..54844b8a2a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/engine/datasetops/batch_op.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(CBatchInfo, 0, ([](const py::module *m) { + (void)py::class_(*m, "CBatchInfo") + .def(py::init()) + .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) + .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); + })); + +PYBIND_REGISTER(DatasetOp, 0, ([](const py::module *m) { + (void)py::class_>(*m, "DatasetOp"); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc new file mode 100644 index 0000000000..a44e538e37 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc @@ -0,0 +1,186 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/api/python/pybind_register.h" + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "CifarOp") + .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { + int64_t count = 0; + THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); + return count; + }); + })); + +PYBIND_REGISTER(ClueOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "ClueOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file)); + } + THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count)); + return count; + }); + })); + +PYBIND_REGISTER(CsvOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "CsvOp") + .def_static("get_num_rows", [](const py::list &files, bool csv_header) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file)); + } + THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count)); + return count; + }); + })); +PYBIND_REGISTER(CocoOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "CocoOp") + .def_static("get_class_indexing", + [](const std::string &dir, const std::string &file, const std::string &task) { + std::vector>> output_class_indexing; + THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing)); + return output_class_indexing; + }) + .def_static("get_num_rows", + [](const std::string &dir, const std::string &file, const std::string &task) { + int64_t count = 0; + THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count)); + return count; + }); + })); + +PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "ImageFolderOp") + .def_static("get_num_rows_and_classes", [](const std::string &path) { + int64_t count = 0, num_classes = 0; + THROW_IF_ERROR( + ImageFolderOp::CountRowsAndClasses(path, std::set{}, &count, &num_classes)); + return py::make_tuple(count, num_classes); + }); + })); + +PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "ManifestOp") + .def_static("get_num_rows_and_classes", + [](const std::string &file, const py::dict &dict, const std::string &usage) { + int64_t count = 0, num_classes = 0; + THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); + return py::make_tuple(count, num_classes); + }) + .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, + const std::string &usage) { + std::map output_class_indexing; + THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); + return output_class_indexing; + }); + })); +PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "MindRecordOp") + .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, + const py::object &sampler, const int64_t num_padded) { + int64_t count = 0; + std::shared_ptr op; + if (py::hasattr(sampler, "create_for_minddataset")) { + auto create = sampler.attr("create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); + return count; + }); + })); + +PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "MnistOp") + .def_static("get_num_rows", [](const std::string &dir) { + int64_t count = 0; + THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); + return count; + }); + })); + +PYBIND_REGISTER(TextFileOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "TextFileOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); + return count; + }); + })); + +PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "TFReaderOp") + .def_static( + "get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) { + int64_t count = 0; + std::vector filenames; + for (auto l : files) { + !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate)); + return count; + }); + })); + +PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "VOCOp") + .def_static("get_num_rows", + [](const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, int64_t numSamples) { + int64_t count = 0; + THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); + return count; + }) + .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, + const std::string &task_mode, const py::dict &dict) { + std::map output_class_indexing; + THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); + return output_class_indexing; + }); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc new file mode 100644 index 0000000000..29aa0f12d6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) { + (void)py::class_>(*m, "Sampler") + .def("set_num_rows", + [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) + .def("set_num_samples", + [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) + .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) + .def("get_indices", + [](Sampler &self) { + py::array ret; + THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); + return ret; + }) + .def("add_child", [](std::shared_ptr self, std::shared_ptr child) { + THROW_IF_ERROR(self->AddChild(child)); + }); + })); + +PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DistributedSampler") + .def(py::init()); + })); + +PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) { + (void)py::class_>(*m, "PKSampler") + .def(py::init()); + })); + +PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) { + (void)py::class_>(*m, "PythonSampler") + .def(py::init()); + })); + +PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) { + (void)py::class_>(*m, "RandomSampler") + .def(py::init()); + })); + +PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) { + (void)py::class_>(*m, + "SequentialSampler") + .def(py::init()); + })); + +PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "SubsetRandomSampler") + .def(py::init>()); + })); + +PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "WeightedRandomSampler") + .def(py::init, bool>()); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc new file mode 100644 index 0000000000..18dcfb470a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" + +#include "minddata/dataset/engine/gnn/graph.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER( + Graph, 0, ([](const py::module *m) { + (void)py::class_>(*m, "Graph") + .def(py::init([](std::string dataset_file, int32_t num_workers) { + std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); + THROW_IF_ERROR(g_out->Init()); + return g_out; + })) + .def("get_all_nodes", + [](gnn::Graph &g, gnn::NodeType node_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); + return out; + }) + .def("get_all_edges", + [](gnn::Graph &g, gnn::EdgeType edge_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); + return out; + }) + .def("get_nodes_from_edges", + [](gnn::Graph &g, std::vector edge_list) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); + return out; + }) + .def("get_all_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); + return out; + }) + .def("get_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, + std::vector neighbor_types) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); + return out; + }) + .def("get_neg_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, + gnn::NodeType neg_neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); + return out; + }) + .def("get_node_feature", + [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); + return out.getRow(); + }) + .def("get_edge_feature", + [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); + return out.getRow(); + }) + .def("graph_info", + [](gnn::Graph &g) { + py::dict out; + THROW_IF_ERROR(g.GraphInfo(&out)); + return out; + }) + .def("random_walk", + [](gnn::Graph &g, std::vector node_list, std::vector meta_path, + float step_home_param, float step_away_param, gnn::NodeIdType default_node) { + std::shared_ptr out; + THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); + return out; + }); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/bindings.cc new file mode 100644 index 0000000000..32e52aed42 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/bindings.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/api/python/de_pipeline.h" + +#include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h" +#include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h" +#include "minddata/dataset/kernels/py_func_op.h" +#include "mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.h" +#include "mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.h" + +namespace mindspore { +namespace dataset { + +Status PyListToTensorOps(const py::list &py_ops, std::vector> *ops) { + RETURN_UNEXPECTED_IF_NULL(ops); + for (auto op : py_ops) { + if (py::isinstance(op)) { + ops->emplace_back(op.cast>()); + } else if (py::isinstance(op)) { + ops->emplace_back(std::make_shared(op.cast())); + } else { + RETURN_STATUS_UNEXPECTED("element is neither a TensorOp nor a pyfunc."); + } + } + CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty."); + for (auto const &op : *ops) { + RETURN_UNEXPECTED_IF_NULL(op); + } + return Status::OK(); +} + +PYBIND_REGISTER(TensorOp, 0, ([](const py::module *m) { + (void)py::class_>(*m, "TensorOp") + .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); + })); + +PYBIND_REGISTER(ComposeOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "ComposeOp") + .def(py::init([](const py::list &ops) { + std::vector> t_ops; + THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops)); + return std::make_shared(t_ops); + })); + })); + +PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "NoOp", "TensorOp that does nothing, for testing purposes only.") + .def(py::init<>()); + })); + +PYBIND_REGISTER(RandomChoiceOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "RandomChoiceOp") + .def(py::init([](const py::list &ops) { + std::vector> t_ops; + THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops)); + return std::make_shared(t_ops); + })); + })); + +PYBIND_REGISTER(RandomApplyOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "RandomApplyOp") + .def(py::init([](double prob, const py::list &ops) { + std::vector> t_ops; + THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops)); + if (prob < 0 || prob > 1) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1].")); + } + return std::make_shared(prob, t_ops); + })); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc new file mode 100644 index 0000000000..3500885dac --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/data/bindings.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" +#include "minddata/dataset/kernels/data/duplicate_op.h" +#include "minddata/dataset/kernels/data/fill_op.h" +#include "minddata/dataset/kernels/data/mask_op.h" +#include "minddata/dataset/kernels/data/one_hot_op.h" +#include "minddata/dataset/kernels/data/pad_end_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" +#include "minddata/dataset/kernels/data/to_float16_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(ConcatenateOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ConcatenateOp", "Tensor operation concatenate tensors.") + .def(py::init, std::shared_ptr>(), py::arg("axis"), + py::arg("prepend").none(true), py::arg("append").none(true)); + })); + +PYBIND_REGISTER(DuplicateOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "DuplicateOp", + "Duplicate tensor.") + .def(py::init<>()); + })); + +PYBIND_REGISTER(FillOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") + .def(py::init>()); + })); + +PYBIND_REGISTER(MaskOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MaskOp", "Tensor mask operation using relational comparator") + .def(py::init, DataType>()); + })); + +PYBIND_REGISTER(OneHotOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.") + .def(py::init()); + })); + +PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.") + .def(py::init>()); + })); + +PYBIND_REGISTER(SliceOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "SliceOp", + "Tensor slice operation.") + .def(py::init()) + .def(py::init([](const py::list &py_list) { + std::vector c_list; + for (auto l : py_list) { + if (!l.is_none()) { + c_list.push_back(py::reinterpret_borrow(l)); + } + } + return std::make_shared(c_list); + })) + .def(py::init([](const py::tuple &py_slice) { + if (py_slice.size() != 3) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + Slice c_slice; + if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), + py::reinterpret_borrow(py_slice[1]), + py::reinterpret_borrow(py_slice[2])); + } else if (py_slice[0].is_none() && py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[1])); + } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), + py::reinterpret_borrow(py_slice[1])); + } + + if (!c_slice.valid()) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + return std::make_shared(c_slice); + })); + })); + +PYBIND_REGISTER(ToFloat16Op, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ToFloat16Op", py::dynamic_attr(), + "Tensor operator to type cast float32 data to a float16 type.") + .def(py::init<>()); + })); + +PYBIND_REGISTER(TypeCastOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); + })); + +PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) { + (void)py::enum_(*m, "RelationalOp", py::arithmetic()) + .value("EQ", RelationalOp::kEqual) + .value("NE", RelationalOp::kNotEqual) + .value("LT", RelationalOp::kLess) + .value("LE", RelationalOp::kLessEqual) + .value("GT", RelationalOp::kGreater) + .value("GE", RelationalOp::kGreaterEqual) + .export_values(); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc new file mode 100644 index 0000000000..b2873cd5a4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc @@ -0,0 +1,346 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/kernels/py_func_op.h" +#include "minddata/dataset/kernels/image/auto_contrast_op.h" +#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/equalize_op.h" +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/invert_op.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_op.h" +#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_resize_op.h" +#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/rescale_op.h" +#include "minddata/dataset/kernels/image/resize_bilinear_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/uniform_aug_op.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(AutoContrastOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "AutoContrastOp", "Tensor operation to apply autocontrast on an image.") + .def(py::init>(), py::arg("cutoff") = AutoContrastOp::kCutOff, + py::arg("ignore") = AutoContrastOp::kIgnore); + })); + +PYBIND_REGISTER(NormalizeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") + .def(py::init(), py::arg("meanR"), py::arg("meanG"), + py::arg("meanB"), py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); + })); + +PYBIND_REGISTER(EqualizeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.") + .def(py::init<>()); + })); + +PYBIND_REGISTER(InvertOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "InvertOp", "Tensor operation to apply invert on RGB images.") + .def(py::init<>()); + })); + +PYBIND_REGISTER(RescaleOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") + .def(py::init(), py::arg("rescale"), py::arg("shift")); + })); + +PYBIND_REGISTER(CenterCropOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "CenterCropOp", + "Tensor operation to crop and image in the middle. Takes height and width (optional)") + .def(py::init(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); + })); + +PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = ResizeOp::kDefWidth, + py::arg("interpolation") = ResizeOp::kDefInterpolation); + })); + +PYBIND_REGISTER(ResizeWithBBoxOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, + py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); + })); + +PYBIND_REGISTER( + RandomResizeWithBBoxOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomResizeWithBBoxOp", + "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); + })); +PYBIND_REGISTER(UniformAugOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") + .def(py::init>, int32_t>(), py::arg("transforms"), + py::arg("NumOps") = UniformAugOp::kDefNumOps); + })); +PYBIND_REGISTER(BoundingBoxAugmentOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BoundingBoxAugmentOp", + "Tensor operation to apply a transformation on a random choice of bounding boxes.") + .def(py::init, float>(), py::arg("transform"), + py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio); + })); +PYBIND_REGISTER(ResizeBilinearOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ResizeBilinearOp", + "Tensor operation to resize an image using " + "Bilinear mode. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = ResizeBilinearOp::kDefWidth); + })); + +PYBIND_REGISTER(DecodeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DecodeOp", "Tensor operation to decode a jpg image") + .def(py::init<>()) + .def(py::init(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat); + })); + +PYBIND_REGISTER(RandomHorizontalFlipOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.") + .def(py::init(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability); + })); + +PYBIND_REGISTER( + RandomHorizontalFlipWithBBoxOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomHorizontalFlipWithBBoxOp", + "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.") + .def(py::init(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability); + })); +PYBIND_REGISTER(RandomVerticalFlipOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.") + .def(py::init(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability); + })); +PYBIND_REGISTER(RandomVerticalFlipWithBBoxOp, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "RandomVerticalFlipWithBBoxOp", + "Tensor operation to randomly flip an image vertically" + " and adjust bounding boxes.") + .def(py::init(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability); + })); +PYBIND_REGISTER( + RandomCropOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "RandomCropOp", + "Gives random crop of specified size " + "Takes crop size") + .def( + py::init(), + py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop, + py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft, + py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType, + py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR, + py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB); + })); +PYBIND_REGISTER( + HwcToChwOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "ChannelSwapOp").def(py::init<>()); + })); +PYBIND_REGISTER( + RandomCropWithBBoxOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomCropWithBBoxOp", + "Gives random crop of given " + "size + adjusts bboxes " + "Takes crop size") + .def( + py::init(), + py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop, + py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom, + py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft, + py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight, + py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType, + py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded, + py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG, + py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB); + })); +PYBIND_REGISTER(CutOutOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "CutOutOp", + "Tensor operation to randomly erase a portion of the image. Takes height and width.") + .def(py::init(), py::arg("boxHeight"), + py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor, + py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG, + py::arg("fillB") = CutOutOp::kDefFillB); + })); +PYBIND_REGISTER(PadOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "PadOp", + "Pads image with specified color, default black, " + "Takes amount to pad for top, bottom, left, right of image, boarder type and color") + .def(py::init(), + py::arg("padTop"), py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), + py::arg("borderTypes") = PadOp::kDefBorderType, py::arg("fillR") = PadOp::kDefFillR, + py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); + })); + +PYBIND_REGISTER(RandomCropDecodeResizeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding") + .def(py::init(), + py::arg("targetHeight"), py::arg("targetWidth"), + py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb, + py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation, + py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter); + })); + +PYBIND_REGISTER( + RandomResizeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomResizeOp", + "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); + })); + +PYBIND_REGISTER(RandomColorAdjustOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomColorAdjustOp", + "Tensor operation to adjust an image's color randomly." + "Takes range for brightness, contrast, saturation, hue and") + .def(py::init(), + py::arg("bright_factor_start"), py::arg("bright_factor_end"), py::arg("contrast_factor_start"), + py::arg("contrast_factor_end"), py::arg("saturation_factor_start"), + py::arg("saturation_factor_end"), py::arg("hue_factor_start"), py::arg("hue_factor_end")); + })); + +PYBIND_REGISTER(RandomCropAndResizeWithBBoxOp, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "RandomCropAndResizeWithBBoxOp", + "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size." + "Takes output height and width and" + "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," + "interpolation mode, and max attempts to crop") + .def(py::init(), + py::arg("targetHeight"), py::arg("targetWidth"), + py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb, + py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation, + py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter); + })); + +PYBIND_REGISTER(RandomCropAndResizeOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomCropAndResizeOp", + "Tensor operation to randomly crop an image and resize to a given size." + "Takes output height and width and" + "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," + "interpolation mode, and max attempts to crop") + .def(py::init(), + py::arg("targetHeight"), py::arg("targetWidth"), + py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb, + py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation, + py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter); + })); + +PYBIND_REGISTER(RandomRotationOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomRotationOp", + "Tensor operation to apply RandomRotation." + "Takes a range for degrees and " + "optional parameters for rotation center and image expand") + .def( + py::init(), + py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX, + py::arg("centerY") = RandomRotationOp::kDefCenterY, + py::arg("interpolation") = RandomRotationOp::kDefInterpolation, + py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR, + py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); + })); + +PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RandomSelectSubpolicyOp") + .def(py::init([](const py::list &py_policy) { + std::vector cpp_policy; + for (auto &py_sub : py_policy) { + cpp_policy.push_back({}); + for (auto handle : py_sub.cast()) { + py::tuple tp = handle.cast(); + if (tp.is_none() || tp.size() != 2) { + THROW_IF_ERROR( + Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob).")); + } + std::shared_ptr t_op; + if (py::isinstance(tp[0])) { + t_op = (tp[0]).cast>(); + } else if (py::isinstance(tp[0])) { + t_op = std::make_shared((tp[0]).cast()); + } else { + THROW_IF_ERROR( + Status(StatusCode::kUnexpectedError, "op is neither a tensorOp nor a pyfunc.")); + } + double prob = (tp[1]).cast(); + if (prob < 0 || prob > 1) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1].")); + } + cpp_policy.back().emplace_back(std::make_pair(t_op, prob)); + } + } + return std::make_shared(cpp_policy); + })); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/bindings.cc new file mode 100644 index 0000000000..244b3402ab --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/bindings.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/text/vocab.h" +#include "minddata/dataset/text/sentence_piece_vocab.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(Vocab, 0, ([](const py::module *m) { + (void)py::class_>(*m, "Vocab") + .def(py::init<>()) + .def_static("from_list", + [](const py::list &words, const py::list &special_tokens, bool special_first) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); + return v; + }) + .def_static( + "from_file", + [](const std::string &path, const std::string &dlm, int32_t vocab_size, + const py::list &special_tokens, bool special_first) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); + return v; + }) + .def_static("from_dict", [](const py::dict &words) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v)); + return v; + }); + })); + +PYBIND_REGISTER(SentencePieceVocab, 0, ([](const py::module *m) { + (void)py::class_>(*m, "SentencePieceVocab") + .def(py::init<>()) + .def_static("from_file", + [](const py::list &paths, const int vocab_size, const float character_coverage, + const SentencePieceModel model_type, const py::dict ¶ms) { + std::shared_ptr v; + std::vector path_list; + for (auto path : paths) { + path_list.emplace_back(py::str(path)); + } + std::unordered_map param_map; + for (auto param : params) { + std::string key = py::reinterpret_borrow(param.first); + if (key == "input" || key == "vocab_size" || key == "model_prefix" || + key == "character_coverage" || key == "model_type") { + continue; + } + param_map[key] = py::reinterpret_borrow(param.second); + } + THROW_IF_ERROR(SentencePieceVocab::BuildFromFile( + path_list, vocab_size, character_coverage, model_type, param_map, &v)); + return v; + }) + .def_static("save_model", [](const std::shared_ptr *vocab, std::string path, + std::string filename) { + THROW_IF_ERROR(SentencePieceVocab::SaveModel(vocab, path, filename)); + }); + })); + +PYBIND_REGISTER(SentencePieceModel, 0, ([](const py::module *m) { + (void)py::enum_(*m, "SentencePieceModel", py::arithmetic()) + .value("DE_SENTENCE_PIECE_UNIGRAM", SentencePieceModel::kUnigram) + .value("DE_SENTENCE_PIECE_BPE", SentencePieceModel::kBpe) + .value("DE_SENTENCE_PIECE_CHAR", SentencePieceModel::kChar) + .value("DE_SENTENCE_PIECE_WORD", SentencePieceModel::kWord) + .export_values(); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc new file mode 100644 index 0000000000..5949b0f358 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc @@ -0,0 +1,244 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" +#include "minddata/dataset/api/python/pybind_register.h" + +#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" +#include "minddata/dataset/text/kernels/lookup_op.h" +#include "minddata/dataset/text/kernels/ngram_op.h" +#include "minddata/dataset/text/kernels/sliding_window_op.h" +#include "minddata/dataset/text/kernels/to_number_op.h" +#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h" +#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h" + +#ifdef ENABLE_ICU4C +#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" +#include "minddata/dataset/text/kernels/bert_tokenizer_op.h" +#include "minddata/dataset/text/kernels/case_fold_op.h" +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" +#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" +#endif + +namespace mindspore { +namespace dataset { + +#ifdef ENABLE_ICU4C + +PYBIND_REGISTER(BasicTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.") + .def(py::init(), + py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, + py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, + py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, + py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, + py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets); + })); + +PYBIND_REGISTER(WhitespaceTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.") + .def(py::init(), py::arg(" with_offsets ") = WhitespaceTokenizerOp::kDefWithOffsets); + })); + +PYBIND_REGISTER(UnicodeScriptTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "UnicodeScriptTokenizerOp", + "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.") + .def(py::init<>()) + .def(py::init(), + py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace, + py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets); + })); + +PYBIND_REGISTER(CaseFoldOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor") + .def(py::init<>()); + })); + +PYBIND_REGISTER(NormalizeUTF8Op, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.") + .def(py::init<>()) + .def(py::init(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm); + })); + +PYBIND_REGISTER(RegexReplaceOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RegexReplaceOp", + "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.") + .def(py::init(), py::arg("pattern"), + py::arg("replace"), py::arg("replace_all")); + })); + +PYBIND_REGISTER(RegexTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.") + .def(py::init(), py::arg("delim_pattern"), + py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets); + })); +PYBIND_REGISTER(BertTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BertTokenizerOp", "Tokenizer used for Bert text process.") + .def(py::init &, const std::string &, const int &, const std::string &, + const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(), + py::arg("vocab"), + py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), + py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, + py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), + py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, + py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, + py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, + py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, + py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); + })); + +PYBIND_REGISTER(NormalizeForm, 0, ([](const py::module *m) { + (void)py::enum_(*m, "NormalizeForm", py::arithmetic()) + .value("DE_NORMALIZE_NONE", NormalizeForm::kNone) + .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc) + .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc) + .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd) + .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd) + .export_values(); + })); + +#endif + +PYBIND_REGISTER(JiebaTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "JiebaTokenizerOp", "") + .def(py::init(), + py::arg("hmm_path"), py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix, + py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets) + .def("add_word", [](JiebaTokenizerOp &self, const std::string word, int freq) { + THROW_IF_ERROR(self.AddWord(word, freq)); + }); + })); + +PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.") + .def(py::init(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets); + })); + +PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "LookupOp", "Tensor operation to LookUp each word.") + .def(py::init([](std::shared_ptr vocab, const py::object &py_word) { + if (vocab == nullptr) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); + } + if (py_word.is_none()) { + return std::make_shared(vocab, Vocab::kNoTokenExists); + } + std::string word = py::reinterpret_borrow(py_word); + WordIdType default_id = vocab->Lookup(word); + if (default_id == Vocab::kNoTokenExists) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, + "default unknown token: " + word + " doesn't exist in vocab.")); + } + return std::make_shared(vocab, default_id); + })); + })); + +PYBIND_REGISTER(NgramOp, 1, ([](const py::module *m) { + (void)py::class_>(*m, "NgramOp", + "TensorOp performs ngram mapping.") + .def(py::init &, int32_t, int32_t, const std::string &, + const std::string &, const std::string &>(), + py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), + py::arg("r_pad_token"), py::arg("separator")); + })); + +PYBIND_REGISTER( + WordpieceTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.") + .def( + py::init &, const std::string &, const int &, const std::string &, const bool &>(), + py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), + py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, + py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), + py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); + })); + +PYBIND_REGISTER(SlidingWindowOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.") + .def(py::init(), py::arg("width"), py::arg("axis")); + })); + +PYBIND_REGISTER( + SentencePieceTokenizerOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "SentencePieceTokenizerOp", "Tokenize scalar token or 1-D tokens to tokens by sentence piece.") + .def( + py::init &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(), + py::arg("vocab"), py::arg("load_type") = SPieceTokenizerLoadType::kModel, + py::arg("out_type") = SPieceTokenizerOutType::kString) + .def(py::init(), + py::arg("model_path"), py::arg("model_filename"), py::arg("load_type") = SPieceTokenizerLoadType::kFile, + py::arg("out_type") = SPieceTokenizerOutType::kString); + })); + +PYBIND_REGISTER(ToNumberOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ToNumberOp", "TensorOp to convert strings to numbers.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); + })); + +PYBIND_REGISTER(TruncateSequencePairOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") + .def(py::init()); + })); + +PYBIND_REGISTER(JiebaMode, 0, ([](const py::module *m) { + (void)py::enum_(*m, "JiebaMode", py::arithmetic()) + .value("DE_JIEBA_MIX", JiebaMode::kMix) + .value("DE_JIEBA_MP", JiebaMode::kMp) + .value("DE_JIEBA_HMM", JiebaMode::kHmm) + .export_values(); + })); + +PYBIND_REGISTER(SPieceTokenizerOutType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "SPieceTokenizerOutType", py::arithmetic()) + .value("DE_SPIECE_TOKENIZER_OUTTYPE_KString", SPieceTokenizerOutType::kString) + .value("DE_SPIECE_TOKENIZER_OUTTYPE_KINT", SPieceTokenizerOutType::kInt) + .export_values(); + })); + +PYBIND_REGISTER(SPieceTokenizerLoadType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "SPieceTokenizerLoadType", py::arithmetic()) + .value("DE_SPIECE_TOKENIZER_LOAD_KFILE", SPieceTokenizerLoadType::kFile) + .value("DE_SPIECE_TOKENIZER_LOAD_KMODEL", SPieceTokenizerLoadType::kModel) + .export_values(); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc new file mode 100644 index 0000000000..5c785fcd37 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" + +#include "minddata/dataset/util/random.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_pk_sample.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_sequential_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) { + (void)py::class_>( + *m, "ShardOperator") + .def("add_child", + [](std::shared_ptr self, + std::shared_ptr child) { self->SetChildOp(child); }); + })); + +PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) { + (void)py::class_>(*m, + "MindrecordDistributedSampler") + .def(py::init()); + })); + +PYBIND_REGISTER( + ShardPkSample, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MindrecordPkSampler") + .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { + if (shuffle == true) { + return std::make_shared(kColumn, kVal, std::numeric_limits::max(), + GetSeed()); + } else { + return std::make_shared(kColumn, kVal); + } + })); + })); + +PYBIND_REGISTER( + ShardSample, 0, ([](const py::module *m) { + (void)py::class_>( + *m, "MindrecordSubsetRandomSampler") + .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + })); + +PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) { + (void)py::class_>(*m, + "MindrecordSequentialSampler") + .def(py::init([](int num_samples, int start_index) { + return std::make_shared(num_samples, start_index); + })); + })); + +PYBIND_REGISTER( + ShardShuffle, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MindrecordRandomSampler") + .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { + return std::make_shared(GetSeed(), num_samples, replacement, reshuffle_each_epoch); + })); + })); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc new file mode 100644 index 0000000000..408ed3b270 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -0,0 +1,1985 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/api/python/de_pipeline.h" + +#include +#include +#include + +#include "utils/ms_utils.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/kernels/py_func_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "pybind11/stl.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using json = nlohmann::json; +using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); + +static std::unordered_map g_parse_op_func_ = { + {kShuffle, &DEPipeline::ParseShuffleOp}, + {kMindrecord, &DEPipeline::ParseMindRecordOp}, + {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, + {kBatch, &DEPipeline::ParseBatchOp}, + {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, + {kRepeat, &DEPipeline::ParseRepeatOp}, + {kSkip, &DEPipeline::ParseSkipOp}, + {kZip, &DEPipeline::ParseZipOp}, + {kConcat, &DEPipeline::ParseConcatOp}, + {kRename, &DEPipeline::ParseRenameOp}, + {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, + {kGenerator, &DEPipeline::ParseGeneratorOp}, + {kTfReader, &DEPipeline::ParseTFReaderOp}, + {kProject, &DEPipeline::ParseProjectOp}, + {kTake, &DEPipeline::ParseTakeOp}, + {kImageFolder, &DEPipeline::ParseImageFolderOp}, + {kMnist, &DEPipeline::ParseMnistOp}, + {kManifest, &DEPipeline::ParseManifestOp}, + {kVoc, &DEPipeline::ParseVOCOp}, + {kCoco, &DEPipeline::ParseCocoOp}, + {kCifar10, &DEPipeline::ParseCifar10Op}, + {kCifar100, &DEPipeline::ParseCifar100Op}, + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kRandomData, &DEPipeline::ParseRandomDataOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}, + {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, + {kClue, &DEPipeline::ParseClueOp}, + {kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}, + {kCsv, &DEPipeline::ParseCsvOp}, + {kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}}; + +DEPipeline::DEPipeline() : iterator_(nullptr) { + try { + // One time init + (void)GlobalInit(); + + // Instantiate the execution tree + tree_ = std::make_shared(); + repeat_num_ = 1; + batch_size_ = 1; + num_rows_ = 0; + num_classes_ = 0; + temp_batch_size_ = 1; + temp_drop_remainder_ = false; + } catch (const std::exception &err) { + MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << "."; + return; + } +} + +DEPipeline::~DEPipeline() { + { + // Release GIL before joining all threads + py::gil_scoped_release gil_release; + // Release tree + tree_.reset(); + } +} + +// Function to add a Node to the Execution Tree. +Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) { + // For each operator, Parse through the list of arguments, then call the respective builder/constructor. + // Note that each call to the parse function may result in building more than one dataset operator. + // For example, one call to ParseNNNOp may result in multiple internal C nodes: + // nodeA + // | + // nodeB + // | + // nodeC + // However, the python side dataset is more abstract, and it does not know about the potential subtree that + // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the + // python side needs to know about nodeA and NodeC to be able to appropriately hook up parents and child + // to this subtee. + // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse + // function. + DsOpPtr top = nullptr; + DsOpPtr bottom = nullptr; + auto iter = g_parse_op_func_.find(op_name); + if (iter != g_parse_op_func_.end()) { + pFunction func = iter->second; + RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom)); + + if (top == nullptr) { + RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node."); + } + + // It is not required that the parse function always produces the bottom pointer. If it's still null, + // then set top and bottom to be the same operator + if (bottom == nullptr) bottom = top; + + // Pack these pointers into a py dict so that we can return both back to python. + (*output)["top"] = top; + (*output)["bottom"] = bottom; + } else { + RETURN_STATUS_UNEXPECTED("No such Op"); + } + // Associate current dataset op node with the tree. + RETURN_IF_NOT_OK(tree_->AssociateNode(top)); + return Status::OK(); +} +// Function to add a child and parent relationship. +Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) { + // Link this relationship. + // Note parent node takes ownership of the child + return (parent_op->AddChild(child_op)); +} + +// Function to assign the node as root. +Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } + +// Function to launch the tree execution. +Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) { + RETURN_IF_NOT_OK(tree_->Prepare(num_epochs)); + RETURN_IF_NOT_OK(tree_->Launch()); + iterator_ = std::make_unique(tree_); + if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); + return Status::OK(); +} + +void DEPipeline::PrintTree() { + for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) { + std::stringstream ss; + ss << *itr; + MS_LOG(DEBUG) << "Operator ID is " << itr->id() << ". Details: " << ss.str().c_str() << "."; + } +} + +Status DEPipeline::GetNextAsMap(py::dict *output) { + TensorMap row; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetNextAsMap(&row); + } + RETURN_IF_NOT_OK(s); + // Generate Python dict as return + for (auto el : row) { + (*output)[common::SafeCStr(el.first)] = el.second; + } + return Status::OK(); +} + +Status DEPipeline::GetNextAsList(py::list *output) { + TensorRow row; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->FetchNextTensorRow(&row); + } + RETURN_IF_NOT_OK(s); + // Generate Python list as return + for (auto el : row) { + output->append(el); + } + return Status::OK(); +} + +Status DEPipeline::GetOutputShapes(py::list *output) { + std::vector shapes; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetOutputShapes(&shapes); + } + RETURN_IF_NOT_OK(s); + for (auto el : shapes) { + py::list shape; + for (auto dim : el.AsVector()) { + shape.append(dim); + } + output->append(shape); + } + return Status::OK(); +} + +Status DEPipeline::GetOutputTypes(py::list *output) { + std::vector types; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetOutputTypes(&types); + } + RETURN_IF_NOT_OK(s); + for (auto el : types) { + output->append(el.AsNumpyType()); + } + return Status::OK(); +} + +int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; } + +int DEPipeline::GetBatchSize() const { return batch_size_; } + +int DEPipeline::GetRepeatCount() const { return repeat_num_; } + +float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +Status DEPipeline::StopSend() { + // tree_.root() must be DeviceQueueOp + DeviceQueueOp *op = dynamic_cast(tree_->root().get()); + if (op == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp"); + } + op->StopSend(); + return Status::OK(); +} + +int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +std::string ToString(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +std::vector ToStringVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (!l.is_none()) + vector.push_back(py::str(l)); + else + vector.emplace_back(""); + } + return vector; +} + +std::set ToStringSet(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::set set; + for (auto l : list) { + if (!l.is_none()) { + (void)set.insert(py::str(l)); + } + } + return set; +} + +std::map ToStringMap(const py::handle handle) { + py::dict dict = py::reinterpret_borrow(handle); + std::map map; + for (auto p : dict) { + (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second))); + } + return map; +} + +std::vector ToIntVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (!l.is_none()) { + vector.push_back(ToInt(l)); + } + } + return vector; +} + +std::vector ToTypeVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (l.is_none()) { + vector.emplace_back(DataType()); + } else { + vector.push_back(l.cast()); + } + } + return vector; +} + +Status DEPipeline::SetBatchParameters(const py::dict &args) { + if (args["batch_size"].is_none()) { + std::string err_msg = "Error: batchSize is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + temp_batch_size_ = ToInt(args["batch_size"]); + CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid."); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "drop_remainder") { + temp_drop_remainder_ = ToBool(value); + } + } + } + + return Status::OK(); +} + +Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + if (!args["buffer_size"].is_none()) { + (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); + } else { + std::string err_msg = "Error: Shuffle buffer size is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "reshuffle_each_epoch") { + (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::SaveDataset(const std::vector &file_names, const std::string &file_type) { + Status s; + auto mr_header = std::make_shared(); + auto mr_writer = std::make_unique(); + std::vector blob_fields; + uint64_t mr_schema_id = 0; + if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) { + RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter."); + } + + TensorRow row; + std::unordered_map column_name_id_map; + for (auto el : iterator_->GetColumnNameMap()) { + std::string column_name = el.first; + std::transform(column_name.begin(), column_name.end(), column_name.begin(), + [](unsigned char c) { return ispunct(c) ? '_' : c; }); + column_name_id_map[column_name] = el.second; + } + bool first_loop = true; // build schema in first loop + do { + json row_raw_data; + std::map>> row_bin_data; + { + py::gil_scoped_release gil_release; + s = iterator_->FetchNextTensorRow(&row); + } + RETURN_IF_NOT_OK(s); + if (row.empty()) break; + if (first_loop) { + json mr_json; + std::vector index_fields; + s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields); + RETURN_IF_NOT_OK(s); + MS_LOG(DEBUG) << "Schema of saved mindrecord: " << mr_json.dump(); + if (mindrecord::SUCCESS != + mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) { + RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader."); + } + mr_writer->SetShardHeader(mr_header); + first_loop = false; + } + // construct data + if (!row.empty()) { // write data + s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data); + RETURN_IF_NOT_OK(s); + std::shared_ptr> output_bin_data; + mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data); + std::map> raw_data; + raw_data.insert(std::pair>(mr_schema_id, std::vector{row_raw_data})); + std::vector> bin_data; + if (nullptr != output_bin_data) { + bin_data.emplace_back(*output_bin_data); + } + mr_writer->WriteRawData(raw_data, bin_data); + } + } while (!row.empty()); + mr_writer->Commit(); + if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) { + RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator."); + } + return Status::OK(); +} + +Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row, + const std::unordered_map &column_name_id_map, + json *row_raw_data, + std::map>> *row_bin_data) { + if (row_raw_data == nullptr) { + RETURN_STATUS_UNEXPECTED("error: row raw data is NULL."); + } + if (row_bin_data == nullptr) { + RETURN_STATUS_UNEXPECTED("error: row bin data is NULL."); + } + if (column_name_id_map.empty()) { + RETURN_STATUS_UNEXPECTED("Error: column not found"); + } + Status s; + for (auto &col : column_name_id_map) { + auto idx = col.second; + auto column_name = col.first; + auto &tensor = row[idx]; + auto column_type = tensor->type(); + + std::unique_ptr> data_ptr; + if (column_type == DataType::DE_INT8) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT16) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT16) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT8) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT32) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT32) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT64) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_FLOAT32) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_FLOAT64) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_STRING) { + std::string_view sv; + RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {0})); // assume scalar string tensor + std::string ss(sv); + (*row_raw_data)[column_name] = std::move(ss); + continue; + } else { + RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data."); + } + RETURN_IF_NOT_OK(s); + if (data_ptr != nullptr) { + (*row_bin_data)[column_name] = std::move(data_ptr); + } + } + return Status::OK(); +} + +template +Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, + std::unique_ptr *data, std::unique_ptr> *data_ptr, + std::unique_ptr *s, bool need_convert) { + if (nullptr == src) { + RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL."); + } + *data_ptr = std::make_unique>(num_of_elements * sizeof(T)); + if (need_convert) { + auto tmp_ptr = std::make_unique>(num_of_elements * sizeof(S)); + std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin()); + auto s_ptr = reinterpret_cast(&(*(tmp_ptr->begin()))); + auto el = std::make_unique(); + for (uint32_t i = 0; i < num_of_elements; ++i) { + *el = *(s_ptr + i); + auto t_ptr = reinterpret_cast(el.get()); + for (uint32_t j = 0; j < sizeof(T); ++j) { + *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j); + } + } + } else { + std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin()); + } + if (shape.empty()) { + *data = std::make_unique(); + auto t_ptr = reinterpret_cast((*data).get()); + for (uint32_t i = 0; i < sizeof(T); ++i) { + *(t_ptr + i) = *((*data_ptr)->begin() + i); + } + } + return Status::OK(); +} + +Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map &column_name_id_map, + const TensorRow &row, json *schema, std::vector *index_fields) { + if (schema == nullptr) { + RETURN_STATUS_UNEXPECTED("error: schema is NULL."); + } + if (index_fields == nullptr) { + RETURN_STATUS_UNEXPECTED("error: index fields is NULL."); + } + if (column_name_id_map.empty()) { + RETURN_STATUS_UNEXPECTED("Error: column not found."); + } + json dataset_schema; + for (auto &col : column_name_id_map) { + auto idx = col.second; + auto column_name = col.first; + auto &tensor = row[idx]; + auto column_type = tensor->type(); + auto column_shape = tensor->shape(); + + std::string mr_type; + auto shapes = column_shape.AsVector(); + std::vector mr_shape(shapes.begin(), shapes.end()); + std::string el = column_type.ToString(); + dataset_schema[column_name] = el; + if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) { + std::string err_msg("Error: can not support data type: " + el); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + mr_type = mindrecord::kTypesMap.at(el); + } + if (mr_shape.empty()) { + if (mr_type == "bytes") { // map to int32 when bytes without shape. + mr_type == "int32"; + } + (*schema)[column_name] = {{"type", mr_type}}; + } else { + if (mr_type == "string") { // mindrecord can not support string with shape. + std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (mr_type == "bytes") { // ignore shape of bytes in minrecord + (*schema)[column_name] = {{"type", mr_type}}; + } else { + (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}}; + } + } + if (mr_type == "bytes" || !mr_shape.empty()) continue; + index_fields->emplace_back(column_name); // candidate of index fields + } + MS_LOG(DEBUG) << "Schema of dataset: " << dataset_schema.dump(); + return Status::OK(); +} +Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, + std::vector> *operators, + int num_padded) { + auto sampler = py::reinterpret_borrow(handle); + auto create = sampler.attr("create_for_minddataset"); + auto op = create().cast>(); + std::stack> stack_ops; + while (op != nullptr) { + auto sampler_op = std::dynamic_pointer_cast(op); + if (sampler_op && num_padded > 0) { + sampler_op->SetNumPaddedSamples(num_padded); + stack_ops.push(sampler_op); + } else { + stack_ops.push(op); + } + op = op->GetChildOp(); + } + while (!stack_ops.empty()) { + operators->push_back(stack_ops.top()); + stack_ops.pop(); + } + return Status::OK(); +} + +Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_file"].is_none()) { + std::string err_msg = "Error: at least one of dataset_files is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + bool load_dataset = ToBool(args["load_dataset"]); + if (load_dataset == true) { + (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); + } else { + (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); + } + (void)builder->SetLoadDataset(load_dataset); + std::vector in_col_names; + if (!args["columns_list"].is_none()) { + in_col_names = ToStringVector(args["columns_list"]); + if (in_col_names.empty() || in_col_names[0].empty()) { + std::string err_msg = "Error: columns_list is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetColumnsToLoad(in_col_names); + } + + if (!args["padded_sample"].is_none()) { + (void)builder->SetPaddedSample(args["padded_sample"]); + (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); + } + std::vector> operators; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumMindRecordWorkers(ToInt(value)); + } else if (key == "sampler") { + int num_padded = 0; + if (!args["num_padded"].is_none()) { + num_padded = ToInt(args["num_padded"]); + } + RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); + } + } + } + + if (!operators.empty()) { + (void)builder->SetOperators(operators); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + num_rows_ = op->num_rows(); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + MapOp::Builder map_builder; + std::vector> tensor_op_list; + std::vector project_columns; + std::shared_ptr cache_client = nullptr; + int num_workers = 0; + + if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)map_builder.SetInColNames(in_col_names); + } else if (key == "output_columns") { + (void)map_builder.SetOutColNames(ToStringVector(value)); + } else if (key == "columns_order") { + project_columns = ToStringVector(value); + } else if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)map_builder.SetNumWorkers(num_workers); + } else if (key == "prefetch_size") { + (void)map_builder.SetOpConnectorSize(ToInt(value)); + } else if (key == "operations") { + py::handle tensor_ops = args["operations"]; + // operation can be a list of TensorOps or a single TensorOp. + if (py::isinstance(tensor_ops)) { + for (auto op : tensor_ops) { + std::shared_ptr tensor_op; + if (py::isinstance(op)) { + tensor_op = op.cast>(); + } else if (py::isinstance(op)) { + tensor_op = std::make_shared(op.cast()); + } else { + RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); + } + tensor_op_list.push_back(tensor_op); + } + } + if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); + (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr map_op; + RETURN_IF_NOT_OK(map_builder.Build(&map_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); + *top = map_op; + + // Add a project op over top of the map if the user wanted to reposition the columns + if (!project_columns.empty()) { + ProjectOp::Builder proj_builder(project_columns); + std::shared_ptr proj_op; + RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); + RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); + *top = proj_op; + *bottom = map_op; + } + + // Additionally, add a cache if required. This will go over top of the project op if one + // was created, otherwise it goes over top of the map op + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); + *top = cache_op; + *bottom = map_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + + if (args["predicate"].is_none()) { + RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "predicate") { + py::handle op = args["predicate"]; + if (!py::isinstance(op)) { + RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); + } + py::function predicate_func = op.cast(); + (void)builder->SetPredicateFunc(std::move(predicate_func)); + } else if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)builder->SetInColNames(in_col_names); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + repeat_num_ = ToInt(args["count"]); + std::shared_ptr op; + RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "source") { + py::object obj = py::cast(&value); + if (!py::isinstance(obj)) { + std::string err_msg = "Error: generator is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetGeneratorFunction(obj.cast()); + } else if (key == "column_names") { + (void)builder->SetColumnNames(ToStringVector(value)); + } else if (key == "column_types") { + (void)builder->SetColumnTypes(ToTypeVector(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder; + if (py::isinstance(args["batch_size"])) { + batch_size_ = ToInt(args["batch_size"]); + CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); + builder = std::make_shared(ToInt(args["batch_size"])); + } else if (py::isinstance(args["batch_size"])) { + builder = std::make_shared(1); + (void)builder->SetBatchSizeFunc(args["batch_size"].cast()); + } else { + std::string err_msg = "Error: batch_size is neither an Integer nor a python function"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "drop_remainder") { + (void)builder->SetDrop(ToBool(value)); + } + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } + if (key == "per_batch_map") { + (void)builder->SetBatchMapFunc(value.cast()); + } + if (key == "input_columns") { + (void)builder->SetColumnsToMap(ToStringVector(value)); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPaddingMap(pad_info, true); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", + "bucket_batch_sizes"}; + for (auto name : mandatory_arguments) { + if (args[name.c_str()].is_none()) { + std::string err_msg = "Error: " + name + " is not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + std::shared_ptr builder = std::make_shared( + ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), + ToIntVector(args[mandatory_arguments[2].c_str()])); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "length_dependent_columns") { + (void)builder->SetLengthDependentColumns(ToStringVector(value)); + } + if (key == "bucket_boundaries") { + (void)builder->SetBucketBoundaries(ToIntVector(value)); + } + if (key == "bucket_batch_sizes") { + (void)builder->SetBucketBatchSizes(ToIntVector(value)); + } + if (key == "element_length_function") { + (void)builder->SetElementLengthFunction(value.cast()); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPadInfo(pad_info); + } + if (key == "pad_to_bucket_boundary") { + (void)builder->SetPadToBucketBoundary(ToBool(value)); + } + if (key == "drop_remainder") { + (void)builder->SetDropRemainder(ToBool(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + // Right now barrier should only take num_rows_per_buffer = 1 + // The reason for this is because having it otherwise can lead to blocking issues + // See barrier_op.h for more details + (void)builder->SetRowsPerBuffer(1); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "condition_name") { + (void)builder->SetConditionName(ToString(value)); + } else if (key == "condition_func") { + (void)builder->SetConditionFunc(value.cast()); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + int32_t prefetch_size = 0; + if (args.contains("prefetch_size")) { + if (args["prefetch_size"].is_none()) { + prefetch_size = 16; + } else { + prefetch_size = ToInt(args["prefetch_size"]); + } + } + std::shared_ptr builder = std::make_shared(prefetch_size); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "queue_name") { + (void)builder->SetChannelName(ToString(value)); + } else if (key == "device_type") { + (void)builder->SetDeviceType(ToString(value)); + } else if (key == "device_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "send_epoch_end") { + (void)builder->SetSendEpochEnd(ToBool(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector in_col_names; + std::vector out_col_names; + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "input_columns") { + in_col_names = ToStringVector(value); + } else if (key == "output_columns") { + out_col_names = ToStringVector(value); + } + } + } + if (in_col_names.empty() || in_col_names[0].empty()) { + std::string err_msg = "Error: input_column_names is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (out_col_names.empty() || out_col_names[0].empty()) { + std::string err_msg = "Error: output_column_names is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetInColNames(in_col_names); + (void)builder->SetOutColNames(out_col_names); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetDatasetFilesList(files_list); + } else { + std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_load; + bool schema_exists = false; + bool shuffle_required = false; + int64_t num_devices = 0; + int64_t total_rows = 0; + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); + } else if (key == "columns_list") { + columns_to_load = ToStringVector(value); + (void)builder->SetColumnsToLoad(columns_to_load); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "schema_file_path" || key == "schema_json_string") { + schema_exists = true; + } else if (key == "num_samples") { + total_rows = ToInt(value); + (void)builder->setTotalRows(total_rows); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "shard_equal_rows") { + (void)builder->SetShardEqualRows(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); + } + } + } + if (schema_exists) { + std::unique_ptr schema = std::make_unique(); + if (args.contains("schema_file_path")) { + RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); + } else { + RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); + } + (void)builder->SetDataSchema(std::move(schema)); + } + + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because TFReaderOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder->SetSampler(std::move(sampler)); + } else if (cache_client) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder->SetSampler(std::move(sampler)); + } + + std::shared_ptr tf_op; + RETURN_IF_NOT_OK(builder->Build(&tf_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op)); + *top = tf_op; + + if (!cache_client && shuffle_required) { + const boolean estimate = true; + const int64_t workers = 8; + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset via estimate and then compute the shuffle size + RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op)); + *top = shuffle_op; + *bottom = tf_op; + } + + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + // Note, it is not allowed to have both shuffle and cache + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op)); + *top = cache_op; + *bottom = tf_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["columns"].is_none()) { + std::string err_msg = "Error: columns is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_project = ToStringVector(args["columns"]); + std::shared_ptr builder = std::make_shared(columns_to_project); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + int num_workers = 0; + std::shared_ptr cache_client = nullptr; + std::shared_ptr builder = std::make_shared(); + (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "extensions") { + (void)builder->SetExtensions(ToStringSet(value)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } + } + } + std::shared_ptr if_op; + RETURN_IF_NOT_OK(builder->Build(&if_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(if_op)); + *top = if_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op)); + *top = cache_op; + *bottom = if_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_file"].is_none()) { + std::string err_msg = "Error: No dataset files specified for manifest"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr builder = std::make_shared(); + (void)builder->SetManifestFile(ToString(args["dataset_file"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "usage") { + (void)builder->SetUsage(ToString(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["task"].is_none()) { + std::string err_msg = "Error: No task specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["mode"].is_none()) { + std::string err_msg = "Error: No mode specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + (void)builder->SetTask(ToString(args["task"])); + (void)builder->SetMode(ToString(args["mode"])); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + + return Status::OK(); +} + +Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["annotation_file"].is_none()) { + std::string err_msg = "Error: No annotation_file specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["task"].is_none()) { + std::string err_msg = "Error: No task specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + (void)builder->SetFile(ToString(args["annotation_file"])); + (void)builder->SetTask(ToString(args["task"])); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetCifarDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + + (void)builder->SetCifarType(true); + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetCifarDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + + (void)builder->SetCifarType(false); + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + RandomDataOp::Builder builder; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + + if (args["total_rows"].is_none()) { + std::string err_msg = "Error: total_rows is a required argument"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_load; + bool schema_exists = false; + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder.SetNumWorkers(num_workers); + } else if (key == "schema_file_path" || key == "schema_json_string") { + schema_exists = true; + } else if (key == "columns_list") { + columns_to_load = ToStringVector(value); + } else if (key == "total_rows") { + // This is not sampling here. The random data op needs to know how much data to generate. + (void)builder.SetTotalRows(ToInt(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); + } + } + } + if (schema_exists) { + std::unique_ptr schema = std::make_unique(); + if (args.contains("schema_file_path")) { + RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); + } else { + RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); + } + (void)builder.SetDataSchema(std::move(schema)); + } + + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because RandomDataOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder.SetSampler(std::move(sampler)); + } else if (cache_client) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder.SetSampler(std::move(sampler)); + } + + std::shared_ptr random_op = nullptr; + RETURN_IF_NOT_OK(builder.Build(&random_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); + *top = random_op; + + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); + *top = cache_op; + *bottom = random_op; + } + + return Status::OK(); +} + +int32_t DEPipeline::GetNumClasses() const { return num_classes_; } + +Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + + std::shared_ptr builder = std::make_shared(); + if (builder == nullptr) { + std::string err_msg = "Create celebaop builder failed"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + (void)builder->SetCelebADir(ToString(args["dataset_dir"])); + for (const auto &arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "extensions") { + (void)builder->SetExtensions(ToStringSet(value)); + } else if (key == "dataset_type") { + (void)builder->SetDatasetType(ToString(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetTextFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetTotalRows(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } + } + } + + std::shared_ptr txt_op; + RETURN_IF_NOT_OK(builder->Build(&txt_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); + *top = txt_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); + *top = shuffle_op; + *bottom = txt_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + auto tp = py::reinterpret_borrow(p.second); + CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); + TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); + std::shared_ptr pad_val = nullptr; + if (py::isinstance(tp[1])) { + std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); + CHECK_FAIL_RETURN_UNEXPECTED( + Tensor::CreateFromVector(std::vector{pad_val_string}, TensorShape::CreateScalar(), &pad_val), + "Cannot create pad_value Tensor"); + } else { + float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]); + CHECK_FAIL_RETURN_UNEXPECTED( + Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val), + "Cannot create pad_value Tensor"); + pad_val->SetItemAt({}, pad_val_float); + } + (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); + } else { // tuple is None + (void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}}); + } + } + return Status::OK(); +} + +Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "freq_range") { + py::tuple tp = py::reinterpret_borrow(value); + if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow(tp[0])); + if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow(tp[1])); + } else if (key == "top_k") { + builder->SetTopK(py::reinterpret_borrow(value)); + } else if (key == "columns") { + (void)builder->SetColumnNames(ToStringVector(value)); + } else if (key == "vocab") { + (void)builder->SetVocab(value.cast>()); + } else if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "special_first") { + (void)builder->SetSpecialFirst(ToBool(value)); + } else if (key == "special_tokens") { + (void)builder->SetSpecialTokens(ToStringVector(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "vocab_size") { + builder->SetVocabSize(ToInt(value)); + } else if (key == "character_coverage") { + (void)builder->SetCharacterCoverage(ToFloat(value)); + } else if (key == "params") { + std::unordered_map params; + for (auto param : py::reinterpret_borrow(value)) { + std::string param_key = py::reinterpret_borrow(param.first); + if (param_key == "input" || param_key == "vocab_size" || param_key == "model_prefix" || + param_key == "character_coverage" || param_key == "model_type") { + continue; + } + params[param_key] = py::reinterpret_borrow(param.second); + } + (void)builder->SetParams(params); + } else if (key == "vocab") { + (void)builder->SetVocab(value.cast>()); + } else if (key == "model_type") { + (void)builder->SetModelType(value.cast()); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetClueFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "cols_to_keyword") { + std::map map_dict; + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + map_dict.insert({ToString(p.first), ToString(p.second)}); + } else { + map_dict.insert({ToString(p.first), ToString(p.first)}); + } + } + (void)builder->SetColsKeyMap(map_dict); + } + } + } + + std::shared_ptr clue_op; + RETURN_IF_NOT_OK(builder->Build(&clue_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); + *top = clue_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op)); + *top = shuffle_op; + *bottom = clue_op; + } + + return Status::OK(); +} + +// Helper function to inject the cache operator over top of the current operation being built. +Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num_workers, + std::shared_ptr input_op, std::shared_ptr *cache_op) { + std::shared_ptr new_cache_op = nullptr; + CacheOp::Builder cache_builder; + // use the same number of workers as the leaf. We need some optimization here, the user does not + // give the cache op number of workers directly. + if (num_workers != 0) { + (void)cache_builder.SetNumWorkers(num_workers); + } + (void)cache_builder.SetClient(cache_client); + RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op)); + RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op)); + // We have now created: + // + // CacheOp + // | + // input_op + // + *cache_op = new_cache_op; + + return Status::OK(); +} + +Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetCsvFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + std::vector col_names; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "field_delim") { + (void)builder->SetFieldDelim(ToString(value)[0]); + } else if (key == "column_defaults") { + py::list py_object_list = py::reinterpret_borrow(value); + std::vector> column_default_list; + for (auto l : py_object_list) { + std::string type_s = (std::string)py::str(l.get_type().attr("__name__")); + if (type_s == "int") { + column_default_list.push_back(std::make_shared>(CsvOp::INT, ToInt(l))); + } else if (type_s == "float") { + column_default_list.push_back(std::make_shared>(CsvOp::FLOAT, ToFloat(l))); + } else if (type_s == "str") { + column_default_list.push_back(std::make_shared>(CsvOp::STRING, ToString(l))); + } else { + RETURN_STATUS_UNEXPECTED("Record type is not allowed"); + } + } + (void)builder->SetColumDefault(column_default_list); + } else if (key == "column_names") { + col_names = ToStringVector(value); + (void)builder->SetColumName(col_names); + } + } + } + + std::shared_ptr csv_op; + RETURN_IF_NOT_OK(builder->Build(&csv_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op)); + *top = csv_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op)); + *top = shuffle_op; + *bottom = csv_op; + } + + return Status::OK(); +} + +// Helper function to inject a shuffle operator over top of the current operation being built. +Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op) { + std::shared_ptr new_shuffle_op = nullptr; + ShuffleOp::Builder shuffle_builder; + + (void)shuffle_builder.SetShuffleSize(shuffle_size); + RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op)); + RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op)); + // We have now created: + // + // ShuffleOp + // | + // input_op + // + *shuffle_op = new_shuffle_op; + + return Status::OK(); +} + +// Common code for computing a default shuffle size +Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size) { + const int64_t average_files_multiplier = 4; + const int64_t shuffle_max = 10000; + int64_t avg_rows_per_file = 0; + + // Adjust the num rows per shard if sharding was given + if (num_devices > 0) { + if (num_rows % num_devices == 0) { + num_rows = num_rows / num_devices; + } else { + num_rows = (num_rows / num_devices) + 1; + } + } + + // Cap based on total rows directive. Some ops do not have this and give value of 0. + if (total_rows > 0) { + num_rows = std::min(num_rows, total_rows); + } + + // get the average per file + avg_rows_per_file = num_rows / num_files; + + *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h new file mode 100644 index 0000000000..80d524982a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h @@ -0,0 +1,254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/client.h" // DE client +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/util/status.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +using json = nlohmann::json; +using DsOpPtr = std::shared_ptr; + +class CacheClient; + +// enum for the dataset operator names +enum OpName { + kShuffle, + kMindrecord, + kBatch, + kBucketBatch, + kBarrier, + kCache, + kRepeat, + kSkip, + kTake, + kZip, + kConcat, + kMap, + kFilter, + kDeviceQueue, + kGenerator, + kRename, + kTfReader, + kProject, + kImageFolder, + kMnist, + kManifest, + kVoc, + kCoco, + kCifar10, + kCifar100, + kCelebA, + kRandomData, + kTextFile, + kBuildVocab, + kClue, + kEpochCtrl, + kSentencePieceVocab, + kCsv +}; + +// The C++ binder class that we expose to the python script. +class DEPipeline { + public: + DEPipeline(); + + ~DEPipeline(); + + // Function to add a Node to the Execution Tree. + Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); + + // Function to add a child and parent relationship. + static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); + + // Function to assign the node as root. + Status AssignRootNode(const DsOpPtr &dataset_op); + + // Function to launch the tree execution. + Status LaunchTreeExec(int32_t num_epochs); + + // Get a row of data as dictionary of column name to the value. + Status GetNextAsMap(py::dict *output); + + // Get a row of data as list. + Status GetNextAsList(py::list *output); + + Status GetOutputShapes(py::list *output); + + Status GetOutputTypes(py::list *output); + + Status SaveDataset(const std::vector &file_names, const std::string &file_type); + + int GetDatasetSize() const; + + int GetBatchSize() const; + + int GetRepeatCount() const; + + Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + template + Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, + std::unique_ptr *data, std::unique_ptr> *data_ptr, + std::unique_ptr *s, bool need_convert = false); + + Status FetchMetaFromTensorRow(const std::unordered_map &column_name_id_map, + const TensorRow &row, json *schema, std::vector *index_fields); + + Status FetchDataFromTensorRow(const TensorRow &row, + const std::unordered_map &column_name_id_map, json *row_raw_data, + std::map>> *row_bin_data); + + Status BuildMindrecordSamplerChain(const py::handle &handle, + std::vector> *operators, + int num_padded); + + Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom); + + Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + void PrintTree(); + + int32_t GetNumClasses() const; + + Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status SetBatchParameters(const py::dict &args); + + Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status StopSend(); + Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom); + + Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCsvOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + private: + // Execution tree that links the dataset operators. + std::shared_ptr tree_; + + std::unique_ptr iterator_; + + static Status ParsePadInfo(py::handle value, PadInfo *pad_info); + + /// \brief Helper function to inject a cache operator over top of the current operation being built. + /// \param[in] cache_client The client to use for caching + /// \param[in] num_workers The number of workers to use in the cache op + /// \param[in] input_op The operator to build the cache on top of + /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be + /// the cache operator + /// \return Status return code + Status AddCacheOp(std::shared_ptr cache_client, int num_workers, std::shared_ptr input_op, + std::shared_ptr *cache_op); + + /// \brief Helper function to inject a shuffle operator over top of the current operation being built. + /// \param[in] shuffle_size The size to use in the shuffle buffer + /// \param[in] input_op The operator to build shuffle on top of + /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be + /// the shuffle operator + /// \return Status return code + Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op); + + /// \brief Helper function to compute the shuffle size + /// \param[in] num_files The number of files in the dataset + /// \param[in] num_devices The number of devices in the dataset + /// \param[in] num_rows The number of rows in the dataset + /// \param[in] total_rows An upper bound on the total rows in the dataset + /// \param[out] shuffle_size The resultant computed shuffle size + /// \return Status return code + Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size); + + int batch_size_; + int repeat_num_; + int num_rows_; + int num_classes_; + + int temp_batch_size_; + bool temp_drop_remainder_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.cc b/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.cc new file mode 100644 index 0000000000..11a520be5e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/api/python/pybind_register.h" + +namespace mindspore { +namespace dataset { + +PybindDefinedFunctionRegister &PybindDefinedFunctionRegister::GetSingleton() { + static PybindDefinedFunctionRegister instance; + return instance; +} + +// This is where we externalize the C logic as python modules +PYBIND11_MODULE(_c_dataengine, m) { + m.doc() = "pybind11 for _c_dataengine"; + + auto all_fns = mindspore::dataset::PybindDefinedFunctionRegister::AllFunctions(); + + for (auto &item : all_fns) { + for (auto &func : item.second) { + func.second(&m); + } + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h b/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h new file mode 100644 index 0000000000..8717a2844c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h @@ -0,0 +1,81 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef API_PYBIND_API_H_ +#define API_PYBIND_API_H_ + +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; +namespace mindspore { + +namespace dataset { +#define THROW_IF_ERROR(s) \ + do { \ + Status rc = std::move(s); \ + if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ + } while (false) + +using PybindDefineFunc = std::function; + +class PybindDefinedFunctionRegister { + public: + static void Register(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) { + return GetSingleton().RegisterFn(name, priority, fn); + } + + PybindDefinedFunctionRegister(const PybindDefinedFunctionRegister &) = delete; + + PybindDefinedFunctionRegister &operator=(const PybindDefinedFunctionRegister &) = delete; + + static std::map> &AllFunctions() { + return GetSingleton().module_fns_; + } + std::map> module_fns_; + + protected: + PybindDefinedFunctionRegister() = default; + + virtual ~PybindDefinedFunctionRegister() = default; + + static PybindDefinedFunctionRegister &GetSingleton(); + + void RegisterFn(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) { + module_fns_[priority][name] = fn; + } +}; + +class PybindDefineRegisterer { + public: + PybindDefineRegisterer(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) { + PybindDefinedFunctionRegister::Register(name, priority, fn); + } + ~PybindDefineRegisterer() = default; +}; + +#ifdef ENABLE_PYTHON +#define PYBIND_REGISTER(name, priority, define) PybindDefineRegisterer g_pybind_define_f_##name(#name, priority, define) +#endif +} // namespace dataset +} // namespace mindspore +#endif // API_PYBIND_API_H_ diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc deleted file mode 100644 index 145291ec3b..0000000000 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ /dev/null @@ -1,954 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "minddata/dataset/api/de_pipeline.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" -#include "minddata/dataset/engine/datasetops/source/clue_op.h" -#include "minddata/dataset/engine/datasetops/source/coco_op.h" -#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" -#include "minddata/dataset/engine/datasetops/source/manifest_op.h" -#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/random_data_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "minddata/dataset/engine/datasetops/source/text_file_op.h" -#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" -#include "minddata/dataset/engine/datasetops/source/voc_op.h" -#include "minddata/dataset/engine/cache/cache_client.h" -#include "minddata/dataset/engine/gnn/graph.h" -#include "minddata/dataset/engine/jagged_connector.h" -#include "minddata/dataset/kernels/data/concatenate_op.h" -#include "minddata/dataset/kernels/data/duplicate_op.h" -#include "minddata/dataset/kernels/data/fill_op.h" -#include "minddata/dataset/kernels/data/mask_op.h" -#include "minddata/dataset/kernels/data/one_hot_op.h" -#include "minddata/dataset/kernels/data/pad_end_op.h" -#include "minddata/dataset/kernels/data/slice_op.h" -#include "minddata/dataset/kernels/data/to_float16_op.h" -#include "minddata/dataset/kernels/data/type_cast_op.h" -#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" -#include "minddata/dataset/kernels/image/center_crop_op.h" -#include "minddata/dataset/kernels/image/cut_out_op.h" -#include "minddata/dataset/kernels/image/decode_op.h" -#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" -#include "minddata/dataset/kernels/image/image_utils.h" -#include "minddata/dataset/kernels/image/normalize_op.h" -#include "minddata/dataset/kernels/image/pad_op.h" -#include "minddata/dataset/kernels/image/random_color_adjust_op.h" -#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" -#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" -#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" -#include "minddata/dataset/kernels/image/random_crop_op.h" -#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h" -#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" -#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" -#include "minddata/dataset/kernels/image/random_resize_op.h" -#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" -#include "minddata/dataset/kernels/image/random_rotation_op.h" -#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" -#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" -#include "minddata/dataset/kernels/image/rescale_op.h" -#include "minddata/dataset/kernels/image/resize_bilinear_op.h" -#include "minddata/dataset/kernels/image/resize_op.h" -#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" -#include "minddata/dataset/kernels/image/uniform_aug_op.h" -#include "minddata/dataset/kernels/no_op.h" -#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" -#include "minddata/dataset/text/kernels/lookup_op.h" -#include "minddata/dataset/text/kernels/ngram_op.h" -#include "minddata/dataset/text/kernels/to_number_op.h" -#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" -#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" -#include "minddata/dataset/text/vocab.h" -#include "minddata/dataset/util/random.h" -#include "minddata/mindrecord/include/shard_distributed_sample.h" -#include "minddata/mindrecord/include/shard_operator.h" -#include "minddata/mindrecord/include/shard_pk_sample.h" -#include "minddata/mindrecord/include/shard_sample.h" -#include "minddata/mindrecord/include/shard_sequential_sample.h" -#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "pybind11/stl_bind.h" - -#ifdef ENABLE_ICU4C -#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" -#include "minddata/dataset/text/kernels/bert_tokenizer_op.h" -#include "minddata/dataset/text/kernels/case_fold_op.h" -#include "minddata/dataset/text/kernels/normalize_utf8_op.h" -#include "minddata/dataset/text/kernels/regex_replace_op.h" -#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" -#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" -#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" -#endif - -namespace py = pybind11; - -namespace mindspore { -namespace dataset { -#define THROW_IF_ERROR(s) \ - do { \ - Status rc = std::move(s); \ - if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ - } while (false) - -void bindDEPipeline(py::module *m) { - (void)py::class_(*m, "DEPipeline") - .def(py::init<>()) - .def( - "AddNodeToTree", - [](DEPipeline &de, const OpName &op_name, const py::dict &args) { - py::dict out; - THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); - return out; - }, - py::return_value_policy::reference) - .def_static("AddChildToParentNode", - [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { - THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); - }) - .def("AssignRootNode", - [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) - .def("SetBatchParameters", - [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) - .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) - .def("GetNextAsMap", - [](DEPipeline &de) { - py::dict out; - THROW_IF_ERROR(de.GetNextAsMap(&out)); - return out; - }) - .def("GetNextAsList", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetNextAsList(&out)); - return out; - }) - .def("GetOutputShapes", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetOutputShapes(&out)); - return out; - }) - .def("GetOutputTypes", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetOutputTypes(&out)); - return out; - }) - .def("GetDatasetSize", &DEPipeline::GetDatasetSize) - .def("GetBatchSize", &DEPipeline::GetBatchSize) - .def("GetNumClasses", &DEPipeline::GetNumClasses) - .def("GetRepeatCount", &DEPipeline::GetRepeatCount); -} -void bindDatasetOps(py::module *m) { - (void)py::class_>(*m, "TFReaderOp") - .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) { - int64_t count = 0; - std::vector filenames; - for (auto l : files) { - !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back(""); - } - THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate)); - return count; - }); - - (void)py::class_>(*m, "CifarOp") - .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { - int64_t count = 0; - THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); - return count; - }); - - (void)py::class_>(*m, "ImageFolderOp") - .def_static("get_num_rows_and_classes", [](const std::string &path) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set{}, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }); - - (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, const py::object &sampler, - const int64_t num_padded) { - int64_t count = 0; - std::shared_ptr op; - if (py::hasattr(sampler, "create_for_minddataset")) { - auto create = sampler.attr("create_for_minddataset"); - op = create().cast>(); - } - THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); - return count; - }); - - (void)py::class_>(*m, "ManifestOp") - .def_static("get_num_rows_and_classes", - [](const std::string &file, const py::dict &dict, const std::string &usage) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }) - .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) { - std::map output_class_indexing; - THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); - return output_class_indexing; - }); - - (void)py::class_>(*m, "MnistOp") - .def_static("get_num_rows", [](const std::string &dir) { - int64_t count = 0; - THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); - return count; - }); - - (void)py::class_>(*m, "TextFileOp") - .def_static("get_num_rows", [](const py::list &files) { - int64_t count = 0; - std::vector filenames; - for (auto file : files) { - !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); - } - THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); - return count; - }); - - (void)py::class_>(*m, "ClueOp") - .def_static("get_num_rows", [](const py::list &files) { - int64_t count = 0; - std::vector filenames; - for (auto file : files) { - file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file)); - } - THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count)); - return count; - }); - - (void)py::class_>(*m, "VOCOp") - .def_static("get_num_rows", - [](const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples) { - int64_t count = 0; - THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); - return count; - }) - .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, - const std::string &task_mode, const py::dict &dict) { - std::map output_class_indexing; - THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); - return output_class_indexing; - }); - (void)py::class_>(*m, "CocoOp") - .def_static("get_class_indexing", - [](const std::string &dir, const std::string &file, const std::string &task) { - std::vector>> output_class_indexing; - THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing)); - return output_class_indexing; - }) - .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) { - int64_t count = 0; - THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count)); - return count; - }); -} -void bindTensor(py::module *m) { - (void)py::class_(*m, "GlobalContext") - .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference); - - (void)py::class_>(*m, "ConfigManager") - .def("__str__", &ConfigManager::ToString) - .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) - .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) - .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) - .def("set_op_connector_size", &ConfigManager::set_op_connector_size) - .def("set_seed", &ConfigManager::set_seed) - .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) - .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) - .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) - .def("get_worker_connector_size", &ConfigManager::worker_connector_size) - .def("get_op_connector_size", &ConfigManager::op_connector_size) - .def("get_seed", &ConfigManager::seed) - .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) - .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); - - (void)py::class_>(*m, "Tensor", py::buffer_protocol()) - .def(py::init([](py::array arr) { - std::shared_ptr out; - THROW_IF_ERROR(Tensor::CreateTensor(&out, arr)); - return out; - })) - .def_buffer([](Tensor &tensor) { - py::buffer_info info; - THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); - return info; - }) - .def("__str__", &Tensor::ToString) - .def("shape", &Tensor::shape) - .def("type", &Tensor::type) - .def("as_array", [](py::object &t) { - auto &tensor = py::cast(t); - if (tensor.type() == DataType::DE_STRING) { - py::array res; - tensor.GetDataAsNumpyStrings(&res); - return res; - } - py::buffer_info info; - THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); - return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t); - }); - - (void)py::class_(*m, "TensorShape") - .def(py::init()) - .def("__str__", &TensorShape::ToString) - .def("as_list", &TensorShape::AsPyList) - .def("is_known", &TensorShape::known); - - (void)py::class_(*m, "DataType") - .def(py::init()) - .def(py::self == py::self) - .def("__str__", &DataType::ToString) - .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); -} - -void bindTensorOps1(py::module *m) { - (void)py::class_>(*m, "TensorOp") - .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); - - (void)py::class_>( - *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") - .def(py::init(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), - py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); - - (void)py::class_>( - *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") - .def(py::init(), py::arg("rescale"), py::arg("shift")); - - (void)py::class_>( - *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)") - .def(py::init(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); - - (void)py::class_>( - *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); - - (void)py::class_>( - *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, - py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); - - (void)py::class_>( - *m, "RandomResizeWithBBoxOp", - "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); - - (void)py::class_>( - *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") - .def(py::init>, int32_t>(), py::arg("operations"), - py::arg("NumOps") = UniformAugOp::kDefNumOps); - - (void)py::class_>( - *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.") - .def(py::init, float>(), py::arg("transform"), - py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio); - - (void)py::class_>( - *m, "ResizeBilinearOp", - "Tensor operation to resize an image using " - "Bilinear mode. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth); - - (void)py::class_>(*m, "DecodeOp", - "Tensor operation to decode a jpg image") - .def(py::init<>()) - .def(py::init(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat); - - (void)py::class_>( - *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.") - .def(py::init(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability); - - (void)py::class_>( - *m, "RandomHorizontalFlipWithBBoxOp", - "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.") - .def(py::init(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability); -} - -void bindTensorOps2(py::module *m) { - (void)py::class_>( - *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.") - .def(py::init(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability); - - (void)py::class_>( - *m, "RandomVerticalFlipWithBBoxOp", - "Tensor operation to randomly flip an image vertically" - " and adjust bounding boxes.") - .def(py::init(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability); - - (void)py::class_>(*m, "RandomCropOp", - "Gives random crop of specified size " - "Takes crop size") - .def(py::init(), - py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop, - py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft, - py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType, - py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR, - py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB); - (void)py::class_>(*m, "ChannelSwapOp").def(py::init<>()); - - (void)py::class_>(*m, "RandomCropWithBBoxOp", - "Gives random crop of given " - "size + adjusts bboxes " - "Takes crop size") - .def(py::init(), - py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop, - py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom, - py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft, - py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight, - py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType, - py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded, - py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG, - py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB); - - (void)py::class_>( - *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.") - .def(py::init()); - - (void)py::class_>( - *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") - .def(py::init>()); - - (void)py::class_>(*m, "SliceOp", "Tensor slice operation.") - .def(py::init()) - .def(py::init([](const py::list &py_list) { - std::vector c_list; - for (auto l : py_list) { - if (!l.is_none()) { - c_list.push_back(py::reinterpret_borrow(l)); - } - } - return std::make_shared(c_list); - })) - .def(py::init([](const py::tuple &py_slice) { - if (py_slice.size() != 3) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); - } - Slice c_slice; - if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1]), - py::reinterpret_borrow(py_slice[2])); - } else if (py_slice[0].is_none() && py_slice[2].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[1])); - } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1])); - } - - if (!c_slice.valid()) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); - } - return std::make_shared(c_slice); - })); - - (void)py::enum_(*m, "RelationalOp", py::arithmetic()) - .value("EQ", RelationalOp::kEqual) - .value("NE", RelationalOp::kNotEqual) - .value("LT", RelationalOp::kLess) - .value("LE", RelationalOp::kLessEqual) - .value("GT", RelationalOp::kGreater) - .value("GE", RelationalOp::kGreaterEqual) - .export_values(); - - (void)py::class_>(*m, "MaskOp", - "Tensor mask operation using relational comparator") - .def(py::init, DataType>()); - - (void)py::class_>(*m, "DuplicateOp", "Duplicate tensor.") - .def(py::init<>()); - - (void)py::class_>( - *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") - .def(py::init()); - - (void)py::class_>(*m, "ConcatenateOp", - "Tensor operation concatenate tensors.") - .def(py::init, std::shared_ptr>(), py::arg("axis"), - py::arg("prepend").none(true), py::arg("append").none(true)); - - (void)py::class_>( - *m, "RandomRotationOp", - "Tensor operation to apply RandomRotation." - "Takes a range for degrees and " - "optional parameters for rotation center and image expand") - .def(py::init(), - py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX, - py::arg("centerY") = RandomRotationOp::kDefCenterY, - py::arg("interpolation") = RandomRotationOp::kDefInterpolation, - py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR, - py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); - - (void)py::class_>( - *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.") - .def(py::init>()); -} - -void bindTensorOps3(py::module *m) { - (void)py::class_>( - *m, "RandomCropAndResizeOp", - "Tensor operation to randomly crop an image and resize to a given size." - "Takes output height and width and" - "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," - "interpolation mode, and max attempts to crop") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb, - py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation, - py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter); - - (void)py::class_>( - *m, "RandomCropAndResizeWithBBoxOp", - "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size." - "Takes output height and width and" - "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," - "interpolation mode, and max attempts to crop") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb, - py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation, - py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter); - - (void)py::class_>( - *m, "RandomColorAdjustOp", - "Tensor operation to adjust an image's color randomly." - "Takes range for brightness, contrast, saturation, hue and") - .def(py::init(), py::arg("bright_factor_start"), - py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"), - py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"), - py::arg("hue_factor_end")); - - (void)py::class_>( - *m, "RandomResizeOp", - "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); - - (void)py::class_>( - *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.") - .def(py::init(), py::arg("boxHeight"), - py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor, - py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG, - py::arg("fillB") = CutOutOp::kDefFillB); -} - -void bindTensorOps4(py::module *m) { - (void)py::class_>( - *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") - .def(py::init(), py::arg("data_type")) - .def(py::init(), py::arg("data_type")); - - (void)py::class_>(*m, "NoOp", - "TensorOp that does nothing, for testing purposes only.") - .def(py::init<>()); - - (void)py::class_>( - *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.") - .def(py::init<>()); - - (void)py::class_>( - *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb, - py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation, - py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter); - - (void)py::class_>( - *m, "PadOp", - "Pads image with specified color, default black, " - "Takes amount to pad for top, bottom, left, right of image, boarder type and color") - .def(py::init(), py::arg("padTop"), - py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType, - py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); - (void)py::class_>(*m, "ToNumberOp", - "TensorOp to convert strings to numbers.") - .def(py::init(), py::arg("data_type")) - .def(py::init(), py::arg("data_type")); -} - -void bindTokenizerOps(py::module *m) { - (void)py::class_>(*m, "JiebaTokenizerOp", "") - .def(py::init(), py::arg("hmm_path"), - py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix, - py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets) - .def("add_word", - [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); }); - (void)py::class_>( - *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.") - .def(py::init(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets); - (void)py::class_>(*m, "LookupOp", - "Tensor operation to LookUp each word.") - .def(py::init([](std::shared_ptr vocab, const py::object &py_word) { - if (vocab == nullptr) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); - } - if (py_word.is_none()) { - return std::make_shared(vocab, Vocab::kNoTokenExists); - } - std::string word = py::reinterpret_borrow(py_word); - WordIdType default_id = vocab->Lookup(word); - if (default_id == Vocab::kNoTokenExists) { - THROW_IF_ERROR( - Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab.")); - } - return std::make_shared(vocab, default_id); - })); - (void)py::class_>(*m, "NgramOp", "TensorOp performs ngram mapping.") - .def(py::init &, int32_t, int32_t, const std::string &, const std::string &, - const std::string &>(), - py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"), - py::arg("separator")); - (void)py::class_>( - *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.") - .def( - py::init &, const std::string &, const int &, const std::string &, const bool &>(), - py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), - py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, - py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), - py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); -} - -void bindDependIcuTokenizerOps(py::module *m) { -#ifdef ENABLE_ICU4C - (void)py::class_>( - *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.") - .def(py::init(), py::arg("with_offsets") = WhitespaceTokenizerOp::kDefWithOffsets); - (void)py::class_>( - *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.") - .def(py::init<>()) - .def(py::init(), - py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace, - py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets); - (void)py::class_>( - *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor") - .def(py::init<>()); - (void)py::class_>( - *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.") - .def(py::init<>()) - .def(py::init(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm); - (void)py::class_>( - *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.") - .def(py::init(), py::arg("pattern"), py::arg("replace"), - py::arg("replace_all")); - (void)py::class_>( - *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.") - .def(py::init(), py::arg("delim_pattern"), - py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets); - (void)py::class_>( - *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.") - .def(py::init(), - py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, - py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, - py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, - py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, - py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets); - (void)py::class_>(*m, "BertTokenizerOp", - "Tokenizer used for Bert text process.") - .def(py::init &, const std::string &, const int &, const std::string &, const bool &, - const bool &, const NormalizeForm &, const bool &, const bool &>(), - py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), - py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, - py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), - py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, - py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, - py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, - py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, - py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); -#endif -} - -void bindSamplerOps(py::module *m) { - (void)py::class_>(*m, "Sampler") - .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) - .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) - .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) - .def("get_indices", - [](Sampler &self) { - py::array ret; - THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); - return ret; - }) - .def("add_child", - [](std::shared_ptr self, std::shared_ptr child) { THROW_IF_ERROR(self->AddChild(child)); }); - - (void)py::class_>(*m, "ShardOperator") - .def("add_child", [](std::shared_ptr self, - std::shared_ptr child) { self->SetChildOp(child); }); - - (void)py::class_>(*m, "DistributedSampler") - .def(py::init()); - - (void)py::class_>(*m, "PKSampler") - .def(py::init()); - - (void)py::class_>(*m, "RandomSampler") - .def(py::init()); - - (void)py::class_>(*m, "SequentialSampler") - .def(py::init()); - - (void)py::class_>(*m, "SubsetRandomSampler") - .def(py::init>()); - - (void)py::class_>( - *m, "MindrecordSubsetRandomSampler") - .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); - - (void)py::class_>( - *m, "MindrecordPkSampler") - .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { - if (shuffle == true) { - return std::make_shared(kColumn, kVal, std::numeric_limits::max(), - GetSeed()); - } else { - return std::make_shared(kColumn, kVal); - } - })); - - (void)py::class_>(*m, "MindrecordDistributedSampler") - .def(py::init()); - - (void)py::class_>( - *m, "MindrecordRandomSampler") - .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { - return std::make_shared(GetSeed(), num_samples, replacement, reshuffle_each_epoch); - })); - - (void)py::class_>(*m, "MindrecordSequentialSampler") - .def(py::init([](int num_samples, int start_index) { - return std::make_shared(num_samples, start_index); - })); - - (void)py::class_>(*m, "WeightedRandomSampler") - .def(py::init, bool>()); - - (void)py::class_>(*m, "PythonSampler") - .def(py::init()); -} - -void bindInfoObjects(py::module *m) { - (void)py::class_(*m, "CBatchInfo") - .def(py::init()) - .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) - .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); -} - -void bindCacheClient(py::module *m) { - (void)py::class_>(*m, "CacheClient") - .def(py::init()); -} - -void bindVocabObjects(py::module *m) { - (void)py::class_>(*m, "Vocab") - .def(py::init<>()) - .def_static("from_list", - [](const py::list &words, const py::list &special_tokens, bool special_first) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); - return v; - }) - .def_static("from_file", - [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens, - bool special_first) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); - return v; - }) - .def_static("from_dict", [](const py::dict &words) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v)); - return v; - }); -} - -void bindGraphData(py::module *m) { - (void)py::class_>(*m, "Graph") - .def(py::init([](std::string dataset_file, int32_t num_workers) { - std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); - THROW_IF_ERROR(g_out->Init()); - return g_out; - })) - .def("get_all_nodes", - [](gnn::Graph &g, gnn::NodeType node_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); - return out; - }) - .def("get_all_edges", - [](gnn::Graph &g, gnn::EdgeType edge_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); - return out; - }) - .def("get_nodes_from_edges", - [](gnn::Graph &g, std::vector edge_list) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); - return out; - }) - .def("get_all_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); - return out; - }) - .def("get_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, - std::vector neighbor_types) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); - return out; - }) - .def("get_neg_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, - gnn::NodeType neg_neighbor_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); - return out; - }) - .def("get_node_feature", - [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { - TensorRow out; - THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); - return out.getRow(); - }) - .def("get_edge_feature", - [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { - TensorRow out; - THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); - return out.getRow(); - }) - .def("graph_info", - [](gnn::Graph &g) { - py::dict out; - THROW_IF_ERROR(g.GraphInfo(&out)); - return out; - }) - .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, - float step_home_param, float step_away_param, gnn::NodeIdType default_node) { - std::shared_ptr out; - THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); - return out; - }); -} - -// This is where we externalize the C logic as python modules -PYBIND11_MODULE(_c_dataengine, m) { - m.doc() = "pybind11 for _c_dataengine"; - (void)py::class_>(m, "DatasetOp"); - - (void)py::enum_(m, "OpName", py::arithmetic()) - .value("SHUFFLE", OpName::kShuffle) - .value("BATCH", OpName::kBatch) - .value("BUCKETBATCH", OpName::kBucketBatch) - .value("BARRIER", OpName::kBarrier) - .value("MINDRECORD", OpName::kMindrecord) - .value("CACHE", OpName::kCache) - .value("REPEAT", OpName::kRepeat) - .value("SKIP", OpName::kSkip) - .value("TAKE", OpName::kTake) - .value("ZIP", OpName::kZip) - .value("CONCAT", OpName::kConcat) - .value("MAP", OpName::kMap) - .value("FILTER", OpName::kFilter) - .value("DEVICEQUEUE", OpName::kDeviceQueue) - .value("GENERATOR", OpName::kGenerator) - .export_values() - .value("RENAME", OpName::kRename) - .value("TFREADER", OpName::kTfReader) - .value("PROJECT", OpName::kProject) - .value("IMAGEFOLDER", OpName::kImageFolder) - .value("MNIST", OpName::kMnist) - .value("MANIFEST", OpName::kManifest) - .value("VOC", OpName::kVoc) - .value("COCO", OpName::kCoco) - .value("CIFAR10", OpName::kCifar10) - .value("CIFAR100", OpName::kCifar100) - .value("RANDOMDATA", OpName::kRandomData) - .value("BUILDVOCAB", OpName::kBuildVocab) - .value("CELEBA", OpName::kCelebA) - .value("TEXTFILE", OpName::kTextFile) - .value("CLUE", OpName::kClue); - - (void)py::enum_(m, "JiebaMode", py::arithmetic()) - .value("DE_JIEBA_MIX", JiebaMode::kMix) - .value("DE_JIEBA_MP", JiebaMode::kMp) - .value("DE_JIEBA_HMM", JiebaMode::kHmm) - .export_values(); - -#ifdef ENABLE_ICU4C - (void)py::enum_(m, "NormalizeForm", py::arithmetic()) - .value("DE_NORMALIZE_NONE", NormalizeForm::kNone) - .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc) - .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc) - .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd) - .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd) - .export_values(); -#endif - - (void)py::enum_(m, "InterpolationMode", py::arithmetic()) - .value("DE_INTER_LINEAR", InterpolationMode::kLinear) - .value("DE_INTER_CUBIC", InterpolationMode::kCubic) - .value("DE_INTER_AREA", InterpolationMode::kArea) - .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) - .export_values(); - - (void)py::enum_(m, "BorderType", py::arithmetic()) - .value("DE_BORDER_CONSTANT", BorderType::kConstant) - .value("DE_BORDER_EDGE", BorderType::kEdge) - .value("DE_BORDER_REFLECT", BorderType::kReflect) - .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric) - .export_values(); - bindDEPipeline(&m); - bindTensor(&m); - bindTensorOps1(&m); - bindTensorOps2(&m); - bindTensorOps3(&m); - bindTensorOps4(&m); - bindTokenizerOps(&m); - bindSamplerOps(&m); - bindDatasetOps(&m); - bindInfoObjects(&m); - bindCacheClient(&m); - bindVocabObjects(&m); - bindGraphData(&m); - bindDependIcuTokenizerOps(&m); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index 91421f0ff8..75c2c6bcc1 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {} /// Function to create a Distributed Sampler. std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, - int64_t num_samples, uint32_t seed) { - auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed); + int64_t num_samples, uint32_t seed, bool even_dist) { + auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed, even_dist); // Input validation if (!sampler->ValidateParams()) { return nullptr; @@ -71,8 +71,8 @@ std::shared_ptr SequentialSampler(int64_t start_index, int } /// Function to create a Subset Random Sampler. -std::shared_ptr SubsetRandomSampler(const std::vector &indices, int64_t num_samples) { - auto sampler = std::make_shared(indices, num_samples); +std::shared_ptr SubsetRandomSampler(std::vector indices, int64_t num_samples) { + auto sampler = std::make_shared(std::move(indices), num_samples); // Input validation if (!sampler->ValidateParams()) { return nullptr; @@ -81,9 +81,9 @@ std::shared_ptr SubsetRandomSampler(const std::vector WeightedRandomSampler(const std::vector &weights, int64_t num_samples, +std::shared_ptr WeightedRandomSampler(std::vector weights, int64_t num_samples, bool replacement) { - auto sampler = std::make_shared(weights, num_samples, replacement); + auto sampler = std::make_shared(std::move(weights), num_samples, replacement); // Input validation if (!sampler->ValidateParams()) { return nullptr; @@ -95,8 +95,13 @@ std::shared_ptr WeightedRandomSampler(const std::vecto // DistributedSampler DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, - uint32_t seed) - : num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {} + uint32_t seed, bool even_dist) + : num_shards_(num_shards), + shard_id_(shard_id), + shuffle_(shuffle), + num_samples_(num_samples), + seed_(seed), + even_dist_(even_dist) {} bool DistributedSamplerObj::ValidateParams() { if (num_shards_ <= 0) { @@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() { } std::shared_ptr DistributedSamplerObj::Build() { - return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_); + return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, + even_dist_); } // PKSampler @@ -184,8 +190,8 @@ std::shared_ptr SequentialSamplerObj::Build() { } // SubsetRandomSampler -SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples) - : indices_(indices), num_samples_(num_samples) {} +SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector indices, int64_t num_samples) + : indices_(std::move(indices)), num_samples_(num_samples) {} bool SubsetRandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { @@ -202,9 +208,8 @@ std::shared_ptr SubsetRandomSamplerObj::Build() { } // WeightedRandomSampler -WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples, - bool replacement) - : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} +WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector weights, int64_t num_samples, bool replacement) + : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} bool WeightedRandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index 59a25ef9f5..a68fc7747f 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -16,18 +16,21 @@ #include "minddata/dataset/include/transforms.h" #include "minddata/dataset/kernels/image/image_utils.h" -#include "minddata/dataset/kernels/image/normalize_op.h" + +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/crop_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/decode_op.h" -#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_crop_op.h" -#include "minddata/dataset/kernels/image/center_crop_op.h" -#include "minddata/dataset/kernels/image/uniform_aug_op.h" #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" -#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" #include "minddata/dataset/kernels/image/random_rotation_op.h" -#include "minddata/dataset/kernels/image/cut_out_op.h" -#include "minddata/dataset/kernels/image/random_color_adjust_op.h" -#include "minddata/dataset/kernels/image/pad_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/swap_red_blue_op.h" +#include "minddata/dataset/kernels/image/uniform_aug_op.h" namespace mindspore { namespace dataset { @@ -38,9 +41,9 @@ TensorOperation::TensorOperation() {} // Transform operations for computer vision. namespace vision { -// Function to create NormalizeOperation. -std::shared_ptr Normalize(std::vector mean, std::vector std) { - auto op = std::make_shared(mean, std); +// Function to create CenterCropOperation. +std::shared_ptr CenterCrop(std::vector size) { + auto op = std::make_shared(size); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -48,9 +51,9 @@ std::shared_ptr Normalize(std::vector mean, std::vect return op; } -// Function to create DecodeOperation. -std::shared_ptr Decode(bool rgb) { - auto op = std::make_shared(rgb); +// Function to create CropOperation. +std::shared_ptr Crop(std::vector coordinates, std::vector size) { + auto op = std::make_shared(coordinates, size); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -58,9 +61,9 @@ std::shared_ptr Decode(bool rgb) { return op; } -// Function to create ResizeOperation. -std::shared_ptr Resize(std::vector size, InterpolationMode interpolation) { - auto op = std::make_shared(size, interpolation); +// Function to create CutOutOp. +std::shared_ptr CutOut(int32_t length, int32_t num_patches) { + auto op = std::make_shared(length, num_patches); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -68,10 +71,9 @@ std::shared_ptr Resize(std::vector size, Interpolation return op; } -// Function to create RandomCropOperation. -std::shared_ptr RandomCrop(std::vector size, std::vector padding, - bool pad_if_needed, std::vector fill_value) { - auto op = std::make_shared(size, padding, pad_if_needed, fill_value); +// Function to create DecodeOperation. +std::shared_ptr Decode(bool rgb) { + auto op = std::make_shared(rgb); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -79,9 +81,9 @@ std::shared_ptr RandomCrop(std::vector size, std:: return op; } -// Function to create CenterCropOperation. -std::shared_ptr CenterCrop(std::vector size) { - auto op = std::make_shared(size); +// Function to create NormalizeOperation. +std::shared_ptr Normalize(std::vector mean, std::vector std) { + auto op = std::make_shared(mean, std); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -89,10 +91,10 @@ std::shared_ptr CenterCrop(std::vector size) { return op; } -// Function to create UniformAugOperation. -std::shared_ptr UniformAugment(std::vector> operations, - int32_t num_ops) { - auto op = std::make_shared(operations, num_ops); +// Function to create PadOperation. +std::shared_ptr Pad(std::vector padding, std::vector fill_value, + BorderType padding_mode) { + auto op = std::make_shared(padding, fill_value, padding_mode); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -100,9 +102,11 @@ std::shared_ptr UniformAugment(std::vector RandomHorizontalFlip(float prob) { - auto op = std::make_shared(prob); +// Function to create RandomColorAdjustOperation. +std::shared_ptr RandomColorAdjust(std::vector brightness, + std::vector contrast, + std::vector saturation, std::vector hue) { + auto op = std::make_shared(brightness, contrast, saturation, hue); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -110,9 +114,20 @@ std::shared_ptr RandomHorizontalFlip(float prob) return op; } -// Function to create RandomVerticalFlipOperation. -std::shared_ptr RandomVerticalFlip(float prob) { - auto op = std::make_shared(prob); +// Function to create RandomCropOperation. +std::shared_ptr RandomCrop(std::vector size, std::vector padding, + bool pad_if_needed, std::vector fill_value) { + auto op = std::make_shared(size, padding, pad_if_needed, fill_value); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomHorizontalFlipOperation. +std::shared_ptr RandomHorizontalFlip(float prob) { + auto op = std::make_shared(prob); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -132,10 +147,9 @@ std::shared_ptr RandomRotation(std::vector degre return op; } -// Function to create PadOperation. -std::shared_ptr Pad(std::vector padding, std::vector fill_value, - BorderType padding_mode) { - auto op = std::make_shared(padding, fill_value, padding_mode); +// Function to create RandomVerticalFlipOperation. +std::shared_ptr RandomVerticalFlip(float prob) { + auto op = std::make_shared(prob); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -143,9 +157,9 @@ std::shared_ptr Pad(std::vector padding, std::vector CutOut(int32_t length, int32_t num_patches) { - auto op = std::make_shared(length, num_patches); +// Function to create ResizeOperation. +std::shared_ptr Resize(std::vector size, InterpolationMode interpolation) { + auto op = std::make_shared(size, interpolation); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -153,11 +167,9 @@ std::shared_ptr CutOut(int32_t length, int32_t num_patches) { return op; } -// Function to create RandomColorAdjustOperation. -std::shared_ptr RandomColorAdjust(std::vector brightness, - std::vector contrast, - std::vector saturation, std::vector hue) { - auto op = std::make_shared(brightness, contrast, saturation, hue); +// Function to create SwapRedBlueOperation. +std::shared_ptr SwapRedBlue() { + auto op = std::make_shared(); // Input validation if (!op->ValidateParams()) { return nullptr; @@ -165,195 +177,119 @@ std::shared_ptr RandomColorAdjust(std::vector return op; } -/* ####################################### Derived TensorOperation classes ################################# */ - -// NormalizeOperation -NormalizeOperation::NormalizeOperation(std::vector mean, std::vector std) : mean_(mean), std_(std) {} - -bool NormalizeOperation::ValidateParams() { - if (mean_.size() != 3) { - MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size(); - return false; - } - - if (std_.size() != 3) { - MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size(); - return false; +// Function to create UniformAugOperation. +std::shared_ptr UniformAugment(std::vector> transforms, + int32_t num_ops) { + auto op = std::make_shared(transforms, num_ops); + // Input validation + if (!op->ValidateParams()) { + return nullptr; } - - return true; -} - -std::shared_ptr NormalizeOperation::Build() { - return std::make_shared(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); + return op; } -// DecodeOperation -DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {} - -bool DecodeOperation::ValidateParams() { return true; } - -std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } +/* ####################################### Derived TensorOperation classes ################################# */ -// ResizeOperation -ResizeOperation::ResizeOperation(std::vector size, InterpolationMode interpolation) - : size_(size), interpolation_(interpolation) {} +// CenterCropOperation +CenterCropOperation::CenterCropOperation(std::vector size) : size_(size) {} -bool ResizeOperation::ValidateParams() { +bool CenterCropOperation::ValidateParams() { if (size_.empty() || size_.size() > 2) { - MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size(); + MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; return false; } return true; } -std::shared_ptr ResizeOperation::Build() { - int32_t height = size_[0]; - int32_t width = 0; +std::shared_ptr CenterCropOperation::Build() { + int32_t crop_height = size_[0]; + int32_t crop_width = 0; - // User specified the width value. + // User has specified crop_width. if (size_.size() == 2) { - width = size_[1]; + crop_width = size_[1]; } - return std::make_shared(height, width, interpolation_); + std::shared_ptr tensor_op = std::make_shared(crop_height, crop_width); + return tensor_op; } -// RandomCropOperation -RandomCropOperation::RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, - std::vector fill_value) - : size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {} +// CropOperation. +CropOperation::CropOperation(std::vector coordinates, std::vector size) + : coordinates_(coordinates), size_(size) {} -bool RandomCropOperation::ValidateParams() { - if (size_.empty() || size_.size() > 2) { - MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size(); - return false; - } - - if (padding_.empty() || padding_.size() != 4) { - MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()"; +bool CropOperation::ValidateParams() { + // Do some input validation. + if (coordinates_.empty() || coordinates_.size() > 2) { + MS_LOG(ERROR) << "Crop: coordinates must be a vector of one or two values"; return false; } - - if (fill_value_.empty() || fill_value_.size() != 3) { - MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()"; + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "Crop: size must be a vector of one or two values"; return false; } return true; } -std::shared_ptr RandomCropOperation::Build() { - int32_t crop_height = size_[0]; - int32_t crop_width = 0; +std::shared_ptr CropOperation::Build() { + int32_t x, y, height, width; - int32_t pad_top = padding_[0]; - int32_t pad_bottom = padding_[1]; - int32_t pad_left = padding_[2]; - int32_t pad_right = padding_[3]; + x = coordinates_[0]; + y = coordinates_[1]; - uint8_t fill_r = fill_value_[0]; - uint8_t fill_g = fill_value_[1]; - uint8_t fill_b = fill_value_[2]; + height = size_[0]; + width = size_[1]; - // User has specified the crop_width value. - if (size_.size() == 2) { - crop_width = size_[1]; - } - - auto tensor_op = std::make_shared(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, - BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b); + std::shared_ptr tensor_op = std::make_shared(x, y, height, width); return tensor_op; } -// CenterCropOperation -CenterCropOperation::CenterCropOperation(std::vector size) : size_(size) {} +// CutOutOperation +CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {} -bool CenterCropOperation::ValidateParams() { - if (size_.empty() || size_.size() > 2) { - MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; +bool CutOutOperation::ValidateParams() { + if (length_ < 0) { + MS_LOG(ERROR) << "CutOut: length cannot be negative"; return false; } - return true; -} - -std::shared_ptr CenterCropOperation::Build() { - int32_t crop_height = size_[0]; - int32_t crop_width = 0; - - // User has specified crop_width. - if (size_.size() == 2) { - crop_width = size_[1]; + if (num_patches_ < 0) { + MS_LOG(ERROR) << "CutOut: number of patches cannot be negative"; + return false; } - - std::shared_ptr tensor_op = std::make_shared(crop_height, crop_width); - return tensor_op; -} - -// UniformAugOperation -UniformAugOperation::UniformAugOperation(std::vector> operations, int32_t num_ops) - : operations_(operations), num_ops_(num_ops) {} - -bool UniformAugOperation::ValidateParams() { return true; } - -std::shared_ptr UniformAugOperation::Build() { - std::vector> tensor_ops; - (void)std::transform(operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), - [](std::shared_ptr op) -> std::shared_ptr { return op->Build(); }); - std::shared_ptr tensor_op = std::make_shared(tensor_ops, num_ops_); - return tensor_op; + return true; } -// RandomHorizontalFlipOperation -RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} - -bool RandomHorizontalFlipOperation::ValidateParams() { return true; } - -std::shared_ptr RandomHorizontalFlipOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(probability_); +std::shared_ptr CutOutOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(length_, length_, num_patches_, false, 0, 0, 0); return tensor_op; } -// RandomVerticalFlipOperation -RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} +// DecodeOperation +DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {} -bool RandomVerticalFlipOperation::ValidateParams() { return true; } +bool DecodeOperation::ValidateParams() { return true; } -std::shared_ptr RandomVerticalFlipOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(probability_); - return tensor_op; -} +std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } -// Function to create RandomRotationOperation. -RandomRotationOperation::RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, - bool expand, std::vector center, - std::vector fill_value) - : degrees_(degrees), - interpolation_mode_(interpolation_mode), - expand_(expand), - center_(center), - fill_value_(fill_value) {} +// NormalizeOperation +NormalizeOperation::NormalizeOperation(std::vector mean, std::vector std) : mean_(mean), std_(std) {} -bool RandomRotationOperation::ValidateParams() { - if (degrees_.empty() || degrees_.size() != 2) { - MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()"; - return false; - } - if (center_.empty() || center_.size() != 2) { - MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()"; +bool NormalizeOperation::ValidateParams() { + if (mean_.size() != 3) { + MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size(); return false; } - if (fill_value_.empty() || fill_value_.size() != 3) { - MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()"; + + if (std_.size() != 3) { + MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size(); return false; } + return true; } -std::shared_ptr RandomRotationOperation::Build() { - std::shared_ptr tensor_op = - std::make_shared(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_, - fill_value_[0], fill_value_[1], fill_value_[2]); - return tensor_op; +std::shared_ptr NormalizeOperation::Build() { + return std::make_shared(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); } // PadOperation @@ -411,26 +347,6 @@ std::shared_ptr PadOperation::Build() { return tensor_op; } -// CutOutOperation -CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {} - -bool CutOutOperation::ValidateParams() { - if (length_ < 0) { - MS_LOG(ERROR) << "CutOut: length cannot be negative"; - return false; - } - if (num_patches_ < 0) { - MS_LOG(ERROR) << "CutOut: number of patches cannot be negative"; - return false; - } - return true; -} - -std::shared_ptr CutOutOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(length_, length_, num_patches_, false, 0, 0, 0); - return tensor_op; -} - // RandomColorAdjustOperation. RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector brightness, std::vector contrast, std::vector saturation, std::vector hue) @@ -485,6 +401,153 @@ std::shared_ptr RandomColorAdjustOperation::Build() { return tensor_op; } +// RandomCropOperation +RandomCropOperation::RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, + std::vector fill_value) + : size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {} + +bool RandomCropOperation::ValidateParams() { + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size(); + return false; + } + + if (padding_.empty() || padding_.size() != 4) { + MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()"; + return false; + } + + if (fill_value_.empty() || fill_value_.size() != 3) { + MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()"; + return false; + } + return true; +} + +std::shared_ptr RandomCropOperation::Build() { + int32_t crop_height = size_[0]; + int32_t crop_width = 0; + + int32_t pad_top = padding_[0]; + int32_t pad_bottom = padding_[1]; + int32_t pad_left = padding_[2]; + int32_t pad_right = padding_[3]; + + uint8_t fill_r = fill_value_[0]; + uint8_t fill_g = fill_value_[1]; + uint8_t fill_b = fill_value_[2]; + + // User has specified the crop_width value. + if (size_.size() == 2) { + crop_width = size_[1]; + } + + auto tensor_op = std::make_shared(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, + BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b); + return tensor_op; +} + +// RandomHorizontalFlipOperation +RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} + +bool RandomHorizontalFlipOperation::ValidateParams() { return true; } + +std::shared_ptr RandomHorizontalFlipOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +// Function to create RandomRotationOperation. +RandomRotationOperation::RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, + bool expand, std::vector center, + std::vector fill_value) + : degrees_(degrees), + interpolation_mode_(interpolation_mode), + expand_(expand), + center_(center), + fill_value_(fill_value) {} + +bool RandomRotationOperation::ValidateParams() { + if (degrees_.empty() || degrees_.size() != 2) { + MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()"; + return false; + } + if (center_.empty() || center_.size() != 2) { + MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()"; + return false; + } + if (fill_value_.empty() || fill_value_.size() != 3) { + MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()"; + return false; + } + return true; +} + +std::shared_ptr RandomRotationOperation::Build() { + std::shared_ptr tensor_op = + std::make_shared(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_, + fill_value_[0], fill_value_[1], fill_value_[2]); + return tensor_op; +} + +// RandomVerticalFlipOperation +RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} + +bool RandomVerticalFlipOperation::ValidateParams() { return true; } + +std::shared_ptr RandomVerticalFlipOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +// ResizeOperation +ResizeOperation::ResizeOperation(std::vector size, InterpolationMode interpolation) + : size_(size), interpolation_(interpolation) {} + +bool ResizeOperation::ValidateParams() { + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size(); + return false; + } + return true; +} + +std::shared_ptr ResizeOperation::Build() { + int32_t height = size_[0]; + int32_t width = 0; + + // User specified the width value. + if (size_.size() == 2) { + width = size_[1]; + } + + return std::make_shared(height, width, interpolation_); +} + +// SwapRedBlueOperation. +SwapRedBlueOperation::SwapRedBlueOperation() {} + +bool SwapRedBlueOperation::ValidateParams() { return true; } + +std::shared_ptr SwapRedBlueOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(); + return tensor_op; +} + +// UniformAugOperation +UniformAugOperation::UniformAugOperation(std::vector> transforms, int32_t num_ops) + : transforms_(transforms), num_ops_(num_ops) {} + +bool UniformAugOperation::ValidateParams() { return true; } + +std::shared_ptr UniformAugOperation::Build() { + std::vector> tensor_ops; + (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops), + [](std::shared_ptr op) -> std::shared_ptr { return op->Build(); }); + std::shared_ptr tensor_op = std::make_shared(tensor_ops, num_ops_); + return tensor_op; +} + } // namespace vision } // namespace api } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/core/client.h b/mindspore/ccsrc/minddata/dataset/core/client.h index 78b298e616..3de90cfeb2 100644 --- a/mindspore/ccsrc/minddata/dataset/core/client.h +++ b/mindspore/ccsrc/minddata/dataset/core/client.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_CLIENT_H_ -#define DATASET_CORE_CLIENT_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CLIENT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CLIENT_H_ // client.h // Include file for DE client functions @@ -25,20 +25,24 @@ #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/dataset_iterator.h" + +#ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#endif #ifdef ENABLE_PYTHON #include "minddata/dataset/engine/datasetops/barrier_op.h" #include "minddata/dataset/engine/datasetops/filter_op.h" #include "minddata/dataset/engine/datasetops/source/generator_op.h" #include "minddata/dataset/engine/datasetops/build_vocab_op.h" +#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" #endif #include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" -#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" #include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/rename_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" @@ -58,4 +62,4 @@ extern Status GlobalInit(); } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_CLIENT_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h index a8e1907c41..4d25c472e0 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.h +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_CONFIG_MANAGER_H_ -#define DATASET_CORE_CONFIG_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONFIG_MANAGER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONFIG_MANAGER_H_ #include #include @@ -134,4 +134,4 @@ class ConfigManager { } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_CONFIG_MANAGER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONFIG_MANAGER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index c85ef52bf5..8c8c0044a6 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_CONSTANTS_H_ -#define DATASET_CORE_CONSTANTS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_ #include #include @@ -26,6 +26,9 @@ namespace dataset { using uchar = unsigned char; using dsize_t = int64_t; +// Target devices to perform map operation +enum class MapTargetDevice { kCpu, kGpu, kDvpp }; + // Possible dataset types for holding the data and client type enum class DatasetType { kUnknown, kArrow, kTf }; @@ -63,4 +66,4 @@ using row_id_type = int64_t; } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_CONSTANTS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc index 5af748b5de..79c27d45cb 100644 --- a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc @@ -23,16 +23,33 @@ namespace mindspore { namespace dataset { -CVTensor::CVTensor(const TensorShape &shape, const DataType &type) : Tensor(shape, type) { + +CVTensor::CVTensor(std::shared_ptr tensor) : Tensor(std::move(*tensor)) { (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); } -CVTensor::CVTensor(const TensorShape &shape, const DataType &type, const uchar *data) : Tensor(shape, type, data) { - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +Status CVTensor::CreateEmpty(const TensorShape &shape, DataType type, CVTensorPtr *out) { + const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); + *out = std::allocate_shared(*alloc, shape, type); + int64_t byte_size = (*out)->SizeInBytes(); + // Don't allocate if we have a tensor with no elements. + if (byte_size != 0) { + RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size)); + } + + return (*out)->MatInit((*out)->GetMutableBuffer(), (*out)->shape_, (*out)->type_, &(*out)->mat_); } -CVTensor::CVTensor(std::shared_ptr tensor) : Tensor(std::move(*tensor)) { - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +Status CVTensor::CreateFromMat(const cv::Mat &mat, CVTensorPtr *out) { + TensorPtr out_tensor; + cv::Mat mat_local = mat; + // if the input Mat's memory is not continuous, copy it to one block of memory + if (!mat.isContinuous()) mat_local = mat.clone(); + TensorShape shape(mat.size, mat_local.type()); + DataType type = DataType::FromCVType(mat_local.type()); + RETURN_IF_NOT_OK(CreateFromMemory(shape, type, mat_local.data, &out_tensor)); + *out = AsCVTensor(out_tensor); + return Status::OK(); } std::pair, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) { @@ -57,7 +74,8 @@ std::shared_ptr CVTensor::AsCVTensor(std::shared_ptr t) { if (cv_t != nullptr) { return cv_t; } else { - return std::make_shared(t); + const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); + return std::allocate_shared(*alloc, t); } } @@ -97,5 +115,13 @@ void CVTensor::Squeeze() { Tensor::Squeeze(); (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); } + +Status CVTensor::MatAtIndex(const std::vector &index, cv::Mat *mat) { + uchar *start = nullptr; + TensorShape remaining({-1}); + RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); + RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h index a614418be6..f32d422672 100644 --- a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_CV_TENSOR_H_ -#define DATASET_CORE_CV_TENSOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CV_TENSOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CV_TENSOR_H_ #include #include @@ -30,56 +30,60 @@ namespace mindspore { namespace dataset { +using CVTensorPtr = std::shared_ptr; class CVTensor : public Tensor { public: - // Create an empty CVTensor of shape `shape` and type `type`. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - CVTensor(const TensorShape &shape, const DataType &type); - - // Create a CVTensor from a given buffer, shape and type. - // @note This constructor allocates a new space in the memory and copies the buffer into it. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - CVTensor(const TensorShape &shape, const DataType &type, const uchar *data); - - // Create a CVTensor from a given CV::Mat. - // @note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. - // @param mat CV::Mat - explicit CVTensor(const cv::Mat &mat) - : CVTensor(TensorShape(mat.size, mat.type()), DataType::FromCVType(mat.type()), mat.data) {} - - ~CVTensor() = default; - - // Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, - // this function would be treated as a no-op. Fot other tensor types, a new CVTensor is created based on the data - // provided. The Passed Tensor will be invalidated. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor - // @return CVTensor - static std::shared_ptr AsCVTensor(std::shared_ptr tensor); - - // Create a CVTensor from a given tensor. The input tensor will be invalidated (i.e., the shape and type will be - // set to unknown and the data buffer will point to null. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor + // Inherit Tensor's constructors + using Tensor::Tensor; + + /// Create a CVTensor from a given tensor. This constructor should not be used directly, use Create* instead. + /// The input tensor will be invalidated (i.e., the shape and type will be + /// set to unknown and the data buffer will point to null. + /// \note there is no memory copying here, the buffer will be assigned to the constructed tensor. + /// \param tensor explicit CVTensor(std::shared_ptr tensor); - // Getter function for the CV::Mat - // @return + /// Create CV tensor with type and shape. Items of the tensor would be uninitialized. + /// \param shape [in] shape of the output tensor + /// \param type [in] type of the output tensor + /// \param out [out] Generated tensor + /// \return Status code + static Status CreateEmpty(const TensorShape &shape, DataType type, CVTensorPtr *out); + + /// Create CV tensor from cv::Mat + /// \note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. + /// \param mat [in] cv::Mat to be copied into the new tensor. + /// \param out [out] Generated tensor + /// \return Status code + static Status CreateFromMat(const cv::Mat &mat, CVTensorPtr *out); + + ~CVTensor() override = default; + + /// Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, + /// this function would be treated as a no-op. Fot other tensor types, a new CVTensor is created based on the data + /// provided. The Passed Tensor will be invalidated. + /// \note the input tensor will be invalidated. + /// \note there is no memory copying here, the buffer will be assigned to the constructed tensor. + /// \param tensor [in] + /// \return CVTensor + static std::shared_ptr AsCVTensor(std::shared_ptr tensor); + + /// Get a reference to the CV::Mat + /// \return a reference to the internal CV::Mat cv::Mat mat() const { return mat_; } - // Static function to check if the passed information (shape and type) can be treated as a valid description - // of an image in OpenCV. Moreover, it returns OpenCV shape and type - // For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. - // In case of invalid shape or type, the function will return pair - // @param shape TensorShape - // @param type DataType - // @return std::pair of OpenCV shape and type - std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); + /// Get a copy of the CV::Mat + /// \return a copy of internal CV::Mat + cv::Mat matCopy() const { return mat_.clone(); } + + /// Static function to check if the passed information (shape and type) can be treated as a valid description + /// of an image in OpenCV. Moreover, it returns OpenCV shape and type + /// For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. + /// In case of invalid shape or type, the function will return pair + /// \param shape [in] TensorShape + /// \param type [in] DataType + /// \return std::pair of OpenCV shape and type + static std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); Status Reshape(const TensorShape &shape) override; @@ -87,20 +91,21 @@ class CVTensor : public Tensor { void Squeeze() override; - Status Mat(const std::vector &index, cv::Mat *mat) { - uchar *start = nullptr; - TensorShape remaining({-1}); - RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); - RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); - return Status::OK(); - } + Status MatAtIndex(const std::vector &index, cv::Mat *mat); private: + /// Opencv Mat object wrapping the raw data of the tensor. + /// Modifying the content of the matrix, modifies the tensor. cv::Mat mat_; - // Initialize CV::Mat with the data_, shape_ and type_ + /// Create cv::Mat from data, TensorShape and DataType + /// \param data [in] Pointer to the data in memory. + /// \param shape [in] Shape of the tensor. + /// \param type [in] Type of the tensor. + /// \param mat [out] cv::Mat initialized with the provided data. + /// \return Status code Status MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat); }; } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_CV_TENSOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CV_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/core/data_type.h index db4834cae2..ab48c3fc78 100644 --- a/mindspore/ccsrc/minddata/dataset/core/data_type.h +++ b/mindspore/ccsrc/minddata/dataset/core/data_type.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_DATA_TYPE_H_ -#define DATASET_CORE_DATA_TYPE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ #include @@ -284,6 +284,11 @@ inline DataType DataType::FromCType() { return DataType(DataType::DE_STRING); } +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_STRING); +} + template <> inline bool DataType::IsLooselyCompatible() const { return type_ == DataType::DE_BOOL; @@ -347,4 +352,4 @@ inline bool DataType::IsLooselyCompatible() const { } } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_DATA_TYPE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/global_context.h b/mindspore/ccsrc/minddata/dataset/core/global_context.h index fe0847f639..031c591ed8 100644 --- a/mindspore/ccsrc/minddata/dataset/core/global_context.h +++ b/mindspore/ccsrc/minddata/dataset/core/global_context.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_GLOBAL_CONTEXT_H_ -#define DATASET_CORE_GLOBAL_CONTEXT_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_GLOBAL_CONTEXT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_GLOBAL_CONTEXT_H_ #include #include @@ -105,4 +105,4 @@ class GlobalContext { } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_GLOBAL_CONTEXT_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_GLOBAL_CONTEXT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.cc b/mindspore/ccsrc/minddata/dataset/core/tensor.cc index 842615f9e1..2c7bbb5b51 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.cc +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.cc @@ -24,7 +24,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/core/global_context.h" @@ -59,49 +59,11 @@ Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), data_allocator_ = std::make_unique>(global_pool); } -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data) : Tensor(shape, type) { - if (type.IsNumeric()) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Given the shape/type of this tensor, compute the data size and copy in the input bytes. - int64_t byte_size = this->SizeInBytes(); - Status s = this->AllocateBuffer(byte_size); // Allocates data_ inside itself - if (s.IsOk() && data_ != nullptr) { - int ret_code = memcpy_s(data_, byte_size, data, byte_size); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } else { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - } - } else { - MS_LOG(ERROR) << "Type should be numeric to use this constructor."; - } -} - -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length) - : Tensor(shape, type) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Allocates data_ inside itself - Status s = AllocateBuffer(length); - if (s.IsError()) { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - if (data_ != nullptr) { - int ret_code = memcpy_s(data_, length, data, length); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } - } -} - Tensor::Tensor(Tensor &&other) noexcept : shape_(other.shape()), type_(other.type()), data_(other.GetMutableBuffer()), + data_end_(other.data_end_), data_allocator_(std::move(other.data_allocator_)) { other.Invalidate(); } @@ -117,118 +79,61 @@ Tensor &Tensor::operator=(Tensor &&other) noexcept { } return *this; } +Status Tensor::CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(shape.known(), "Invalid shape."); + CHECK_FAIL_RETURN_UNEXPECTED(type != DataType::DE_UNKNOWN, "Invalid data type."); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, shape, type); + // if it's a string tensor and it has no elements, Just initialize the shape and type. + if (!type.IsNumeric() && shape.NumOfElements() == 0) { + return Status::OK(); + } -Tensor::Tensor(const std::vector &strings, const TensorShape &shape) - : Tensor(TensorShape({static_cast(strings.size())}), DataType(DataType::DE_STRING)) { - auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; - dsize_t total_length = std::accumulate(strings.begin(), strings.end(), 0, length_sum); - - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize + 1) * shape_.NumOfElements() + kOffsetSize + total_length; - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); + CHECK_FAIL_RETURN_UNEXPECTED(type.IsNumeric(), "Number of elements is not 0. The type should be numeric."); - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (const auto &str : strings) { - // insert the start index of the string. - offset_arr[i++] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; + int64_t byte_size = (*out)->SizeInBytes(); + // Don't allocate if we have a tensor with no elements. + if (byte_size != 0) { + RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size)); } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - this->data_end_ = data_ + offset_arr[i]; - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); + return Status::OK(); } - -Tensor::Tensor(const dataengine::BytesList &bytes_list, const TensorShape &shape) - : Tensor(TensorShape({static_cast(bytes_list.value_size())}), DataType(DataType::DE_STRING)) { - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize)*shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (; i < bytes_list.value_size(); i++) { - const std::string &str = bytes_list.value(i); - // insert the start index of the string. - offset_arr[i] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) { - MS_LOG(ERROR) << "Cannot copy string into Tensor"; - } - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; +Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out) { + RETURN_IF_NOT_OK(CreateEmpty(shape, type, out)); + if (src != nullptr) { + // Given the shape/type of this tensor, compute the data size and copy in the input bytes. + int64_t byte_size = (*out)->SizeInBytes(); + int ret_code = memcpy_s((*out)->data_, byte_size, src, byte_size); + CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor."); } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - data_end_ = data_ + offset_arr[i]; - - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); + return Status::OK(); } -Status Tensor::CreateTensor(std::shared_ptr *ptr, TensorImpl tensor_impl, const TensorShape &shape, - DataType type, const unsigned char *data) { - if (!shape.known()) { - RETURN_STATUS_UNEXPECTED("Invalid shape."); - } - if (type == DataType::DE_UNKNOWN) { - RETURN_STATUS_UNEXPECTED("Invalid data type."); +Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type, const unsigned char *src, + const dsize_t &length, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr, "Pointer to source data is null."); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, shape, type); + if (type.IsNumeric()) { + dsize_t calculated_length = (*out)->SizeInBytes(); + CHECK_FAIL_RETURN_UNEXPECTED(calculated_length == length, "Length of source data does not match the shape."); + } else { + // min_length is the length of a tensor with empty strings + // min_length = the number of bytes needed to store the offsets + 1 byte for each element + dsize_t min_length = (shape.NumOfElements() + 1) * kOffsetSize + shape.NumOfElements(); + CHECK_FAIL_RETURN_UNEXPECTED(min_length <= length, "Length of source data does not match the shape."); } - switch (tensor_impl) { - case TensorImpl::kFlexible: { - // The flex tensor is really just the base class tensor implementation - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - case TensorImpl::kCv: { - const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - default: { - std::string err_msg("Invalid tensor implementation type."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); // returns base-class shared_ptr + RETURN_IF_NOT_OK((*out)->AllocateBuffer(length)); + int ret_code = memcpy_s((*out)->data_, length, src, length); + CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor."); + + return Status::OK(); } #ifdef ENABLE_PYTHON -Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr) { +Status Tensor::CreateFromNpString(py::array arr, std::shared_ptr *out) { std::vector shape; for (dsize_t i = 0; i < arr.ndim(); i++) { shape.push_back(static_cast(arr.shape()[i])); @@ -244,34 +149,38 @@ Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::arr arr.resize(shape); // resize arr back to the original shape - return CreateTensor(ptr, strings, TensorShape{shape}); + return CreateFromVector(strings, TensorShape{shape}, out); } -Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { +Status Tensor::CreateFromNpArray(const py::array &arr, std::shared_ptr *out) { if (DataType::FromNpArray(arr) == DataType::DE_STRING) { - return CreateTensorFromNumpyString(ptr, arr); + return CreateFromNpString(arr, out); } const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, TensorShape({}), DataType(DataType::DE_UNKNOWN)); + *out = std::allocate_shared(*alloc, TensorShape::CreateScalar(), DataType(DataType::DE_UNKNOWN)); std::vector shape; for (dsize_t i = 0; i < arr.ndim(); i++) { shape.push_back(static_cast(arr.shape()[i])); } - (*ptr)->shape_ = TensorShape(shape); - (*ptr)->type_ = DataType::FromNpArray(arr); - if (!(*ptr)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); + (*out)->shape_ = TensorShape(shape); + (*out)->type_ = DataType::FromNpArray(arr); + if (!(*out)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); - if ((*ptr)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); + if ((*out)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); - (*ptr)->data_allocator_ = std::make_unique>(global_pool); - int64_t byte_size = (*ptr)->SizeInBytes(); - RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size)); + (*out)->data_allocator_ = std::make_unique>(global_pool); + int64_t byte_size = (*out)->SizeInBytes(); + if (byte_size == 0) { + return Status::OK(); + } + + RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size)); unsigned char *data = static_cast(arr.request().ptr); - if ((*ptr)->data_ == nullptr) { + if ((*out)->data_ == nullptr) { RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); } @@ -282,61 +191,92 @@ Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { // check if strides are contiguous bool is_strided = false; - dsize_t count = (*ptr)->shape_.NumOfElements(); + dsize_t count = (*out)->shape_.NumOfElements(); for (size_t i = 0; i < shape.size(); i++) { count /= shape[i]; - if (strides[i] != (*ptr)->type_.SizeInBytes() * count) { + if (strides[i] != (*out)->type_.SizeInBytes() * count) { is_strided = true; break; } } if (is_strided) { - RETURN_IF_NOT_OK(CopyStridedArray((*ptr)->data_, data, shape, strides, (*ptr)->type_.SizeInBytes())); + RETURN_IF_NOT_OK(CopyStridedArray((*out)->data_, data, shape, strides, (*out)->type_.SizeInBytes())); } else { - int ret_code = memcpy_s((*ptr)->data_, byte_size, data, byte_size); + int ret_code = memcpy_s((*out)->data_, byte_size, data, byte_size); if (ret_code != 0) { RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); } } - return Status::OK(); // returns base-class shared_ptr + return Status::OK(); } #endif -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape) { +#ifndef ENABLE_ANDROID +Status Tensor::CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out) { const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, strings, shape); - return Status::OK(); -} + *out = std::allocate_shared(*alloc, TensorShape({static_cast(bytes_list.value_size())}), + DataType(DataType::DE_STRING)); + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize) * (*out)->shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, bytes_list, shape); + (*out)->data_ = (*out)->data_allocator_->allocate(num_bytes); + + auto offset_arr = reinterpret_cast((*out)->data_); + uchar *buf = (*out)->GetStringsBuffer(); + + offset_t offset = buf - (*out)->data_; // the first string will start here + uint32_t i = 0; + for (; i < bytes_list.value_size(); i++) { + const std::string &str = bytes_list.value(i); + // insert the start index of the string. + offset_arr[i] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s((*out)->data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) { + MS_LOG(ERROR) << "Cannot copy string into Tensor"; + } + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + (*out)->data_end_ = (*out)->data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + (*out)->Reshape(shape); return Status::OK(); } +#endif -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::string &file_path) { +Status Tensor::CreateFromFile(const std::string &path, std::shared_ptr *out) { std::ifstream fs; - fs.open(file_path, std::ios::binary | std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path); + fs.open(path, std::ios::binary | std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + path); int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); - RETURN_IF_NOT_OK( - Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8))); - int64_t written_bytes = fs.read(reinterpret_cast((*ptr)->GetMutableBuffer()), num_bytes).gcount(); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape{num_bytes}, DataType(DataType::DE_UINT8), out)); + int64_t written_bytes = fs.read(reinterpret_cast((*out)->GetMutableBuffer()), num_bytes).gcount(); CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor"); fs.close(); return Status::OK(); } -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type)); +#ifndef ENABLE_ANDROID +Status Tensor::CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, + const DataType &type, dsize_t pad_size, TensorPtr *out) { + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, type, out)); - unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer(); + unsigned char *current_tensor_addr = (*out)->GetMutableBuffer(); int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size; for (int i = 0; i < bytes_list.value_size(); i++) { @@ -361,6 +301,7 @@ Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::Byte return Status::OK(); } +#endif // Memcpy the given strided array's used part to consecutive memory // Consider a 3-d array @@ -368,7 +309,7 @@ Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::Byte // Here we convert array C to array A, by memcpy index by index (Note that not all elements in C is copied) Status Tensor::CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, std::vector strides, uint8_t type_size) { - dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); for (dsize_t i = 0; i < size; ++i) { dsize_t offset = 0; dsize_t count = i; @@ -429,29 +370,29 @@ void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) c MS_ASSERT(data_); switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, bool); + CASE_PRINT_HEX(DataType::DE_BOOL, bool) - CASE_PRINT_HEX(DataType::DE_INT8, int8_t); + CASE_PRINT_HEX(DataType::DE_INT8, int8_t) - CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t); + CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t) - CASE_PRINT(DataType::DE_INT16, int16_t); + CASE_PRINT(DataType::DE_INT16, int16_t) - CASE_PRINT(DataType::DE_UINT16, uint16_t); + CASE_PRINT(DataType::DE_UINT16, uint16_t) - CASE_PRINT(DataType::DE_INT32, int32_t); + CASE_PRINT(DataType::DE_INT32, int32_t) - CASE_PRINT(DataType::DE_UINT32, uint32_t); + CASE_PRINT(DataType::DE_UINT32, uint32_t) - CASE_PRINT(DataType::DE_INT64, int64_t); + CASE_PRINT(DataType::DE_INT64, int64_t) - CASE_PRINT(DataType::DE_UINT64, uint64_t); + CASE_PRINT(DataType::DE_UINT64, uint64_t) - CASE_PRINT(DataType::DE_FLOAT16, float16); + CASE_PRINT(DataType::DE_FLOAT16, float16) - CASE_PRINT(DataType::DE_FLOAT32, float); + CASE_PRINT(DataType::DE_FLOAT32, float) - CASE_PRINT(DataType::DE_FLOAT64, double); + CASE_PRINT(DataType::DE_FLOAT64, double) case DataType::DE_STRING: { std::string_view o{""}; @@ -501,50 +442,14 @@ void Tensor::Print(std::ostream &out) const { } } Status Tensor::AllocateBuffer(const dsize_t &length) { + RETURN_UNEXPECTED_IF_NULL(data_allocator_); if (data_ == nullptr) { - if (data_allocator_ != nullptr) { - data_ = data_allocator_->allocate(length); - RETURN_UNEXPECTED_IF_NULL(data_); - data_end_ = data_ + length; - } else { - data_ = static_cast(malloc(length)); - data_end_ = data_ + length; - RETURN_UNEXPECTED_IF_NULL(data_); - } + data_ = data_allocator_->allocate(length); + CHECK_FAIL_RETURN_UNEXPECTED(data_ != nullptr, "Failed to allocate memory for tensor."); + data_end_ = data_ + length; } return Status::OK(); } -const unsigned char *Tensor::GetBuffer() const { - // This version cannot modify anything. data_ could possibly be null. - return data_; -} - -// check for empty -bool Tensor::HasData() const { - if (data_ == nullptr) { - return true; - } else { - return false; - } -} - -unsigned char *Tensor::GetMutableBuffer() { - if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { - return nullptr; - } - // If the data area is already created, return the pointer to it - if (data_ != nullptr) { - return data_; - } else { - // If the data area is not created, then identify the memory size based - // on the shape and type and allocate it. - if (this->AllocateBuffer(this->SizeInBytes()).IsOk()) { - return data_; - } else { - return nullptr; - } - } -} Status Tensor::Reshape(const TensorShape &shape) { if (shape.NumOfElements() == shape_.NumOfElements()) { @@ -621,16 +526,34 @@ Status Tensor::StartAddrOfIndex(std::vector ind, uchar **start_addr_of_ return Status::OK(); } -Status Tensor::InsertTensor(const std::vector &ind, const std::shared_ptr &tensor) { +Status Tensor::InsertTensor(const std::vector &ind, const std::shared_ptr &tensor, + const bool partial_insert) { std::string err_msg; - err_msg += (this->type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; - err_msg += (!this->shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; - err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : ""; - err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; + if (partial_insert) { + err_msg += (ind.size() != 1) + ? "[Tensor] only supports 1D insertion of elements not along the full length of the axis\n" + : ""; + err_msg += + (ind.at(0) + tensor->shape().NumOfElements() > shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; + } else { + err_msg += (ind.size() + tensor->Rank() != Rank()) ? "[Tensor] incorrect index\n" : ""; + } + err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot insert into a tensor of type string\n" : ""; + err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; + + err_msg += tensor->type().SizeInBytes() != type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; uchar *start_addr_of_ind = nullptr; - TensorShape remaining_shape({-1}); - err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; - err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : ""; + if (partial_insert) { + TensorShape remaining_shape = tensor->shape(); + err_msg += + (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; + } else { + TensorShape remaining_shape = TensorShape::CreateUnknownRankShape(); + err_msg += + (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; + err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : ""; + } + if (!err_msg.empty()) { MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; RETURN_STATUS_UNEXPECTED(err_msg); @@ -651,39 +574,6 @@ Status Tensor::InsertTensor(const std::vector &ind, const std::shared_p } } -Status Tensor::Concatenate(const std::vector &index, const std::shared_ptr &tensor) { - std::string err_msg; - err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : ""; - err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; - err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; - - err_msg += - (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; - err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; - uchar *start_addr_of_ind = nullptr; - - TensorShape remaining_shape = tensor->shape(); - StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape); - err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : ""; - - if (!err_msg.empty()) { - MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; - - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - int ret_code = - memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); - - if (ret_code == 0) { - return Status::OK(); - } else { - err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; - MS_LOG(DEBUG) << "Tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } -} - Status Tensor::ExpandDim(const dsize_t &axis) { if (axis > Rank()) { std::string err = "Axis is out of bound"; @@ -697,7 +587,7 @@ Status Tensor::ExpandDim(const dsize_t &axis) { return Status::OK(); } -std::vector Tensor::Strides() { +std::vector Tensor::Strides() const { std::vector strides = shape_.Strides(); uint8_t size = type_.SizeInBytes(); std::transform(strides.begin(), strides.end(), strides.begin(), [&size](const auto &c) { return c * size; }); @@ -765,7 +655,6 @@ Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) #ifdef ENABLE_PYTHON // return data as numpy, should return status Status Tensor::GetDataAsNumpy(py::array *data) { - RETURN_UNEXPECTED_IF_NULL(data_); RETURN_UNEXPECTED_IF_NULL(data); if (type_ == DataType::DE_BOOL) { *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); @@ -974,7 +863,9 @@ Status Tensor::CopyLastDimAt(const std::shared_ptr &src, const std::vect } Status Tensor::Slice(std::shared_ptr *out, const std::vector &indices) { CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); - CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); + if (indices.empty()) { + return CreateEmpty(TensorShape({0}), type_, out); + } if (type_.IsNumeric()) { return SliceNumeric(out, indices); } else { @@ -982,8 +873,7 @@ Status Tensor::Slice(std::shared_ptr *out, const std::vector &i } } Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { - RETURN_IF_NOT_OK( - CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); + RETURN_IF_NOT_OK(CreateEmpty(TensorShape({static_cast(indices.size())}), type_, out)); (*out)->GetMutableBuffer(); dsize_t out_index = 0; dsize_t dim_length = shape_[0]; @@ -1027,7 +917,7 @@ Status Tensor::SliceString(std::shared_ptr *out, const std::vector(strings.size())}), out); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h index b0b173e9c3..b2fe352c1d 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_TENSOR_H_ -#define DATASET_CORE_TENSOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ #include #include @@ -33,16 +33,26 @@ #include "pybind11/stl.h" #endif +#include "utils/ms_utils.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/util/status.h" +#ifndef ENABLE_ANDROID #include "proto/example.pb.h" +#else +#include "minddata/dataset/include/de_tensor.h" +#endif #ifdef ENABLE_PYTHON namespace py = pybind11; #endif namespace mindspore { +#ifdef ENABLE_ANDROID +namespace tensor { +class DETensor; +} // namespace tensor +#endif namespace dataset { class Tensor; template @@ -50,170 +60,157 @@ class Allocator; using CharAllocPtr = std::unique_ptr>; using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors +using offset_t = uint32_t; // type of offset values to store strings locations +using TensorPtr = std::shared_ptr; class Tensor { public: Tensor() = delete; - - // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - Tensor(const TensorShape &shape, const DataType &type); - - // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data); - - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length); - Tensor(const Tensor &other) = delete; - Tensor &operator=(const Tensor &other) = delete; + /// Create a tensor using shape and type. This constructor should not be used directly, use CreateFromTensor instead + /// \note The shape and type information should be known and valid + /// \note The constructor does not allocate data + /// \param shape TensorShape + /// \param type DataType + Tensor(const TensorShape &shape, const DataType &type); + + /// Move constructor + /// \param other Tensor to be moved Tensor(Tensor &&other) noexcept; + /// Move assigment operator + /// \param other Tensor to be moved Tensor &operator=(Tensor &&other) noexcept; - Status AllocateBuffer(const dsize_t &length); - - // type of offest values to store strings information - using offset_t = uint32_t; - // const of the size of the offset variable - static constexpr uint8_t kOffsetSize = sizeof(offset_t); - // Tensor base class which holds the data in an unsigned char* buffer. - - // Construct a scalar string Tensor - explicit Tensor(const std::string &str) : Tensor(std::vector{str}, TensorShape::CreateScalar()) {} - - // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is - // the size of the vector `strings`. - // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. - // Thr offset array will store one extra value to find the length of the last string. - // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn - // The value of each offset is the start index of the corresponding string - // Offsets is of type offest_t - // strings will ne null-terminated - // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) - // |----------------------------------------------------------------| - // | OFFSET ARRAY | STRINGS | - // | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | - // | 11 | 15 | 18 | abc\0 | de\0 | - // |----------------------------------------------------------------| - explicit Tensor(const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // Same as Tensor(vector) but the input is protobuf bytelist - explicit Tensor(const dataengine::BytesList &bytes_list, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // A static factory method to create the given flavour of derived Tensor - // Returns the base class reference for the Tensor. - // @param ptr output argument to hold the created Tensor of given tensor_impl - // @param tensor_impl - which implementation of Tensor - // @param shape - shape of the tensor - // @param type - datatype of the tensor - // @param data - data to be copied to Tensor new allocation - // @return Status Code - static Status CreateTensor(std::shared_ptr *, TensorImpl tensor_impl, const TensorShape &shape, DataType type, - const unsigned char *data = nullptr); - - // Create a copy of the input tensor - // @param out [out] output tensor to be generated - // @param in [in] orginal tensor to be copied - // @return Status - static Status CreateTensor(std::shared_ptr *out, const std::shared_ptr &in) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *out = std::allocate_shared(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes()); - return Status::OK(); + /// Create a numeric tensor with type and shape. Items of the tensor would be uninitialized. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out); + + /// Create a numeric tensor from a pointer in memory. Length of the source data is determined from the shape and type. + /// Data will be copied into the new created tensor. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out); + + /// Create a tensor from a pointer in memory and length. Data will be copied into the new created tensor. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[in] length length of the src data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, + const dsize_t &length, TensorPtr *out); + + /// Create a copy of the input tensor + /// \param[in] in original tensor to be copied + /// \param[out] out output tensor to be generated + /// \return Status + static Status CreateFromTensor(const TensorPtr &in, TensorPtr *out) { + return CreateFromMemory(in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes(), out); } #ifdef ENABLE_PYTHON - // A static factory method to create a Tensor from a given py::array. - // @param ptr output argument to hold the created Tensor - // @param arr py::array - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, py::array arr); - - // Helper function to create a tensor from Numpy of strings - static Status CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr); + /// Create a Tensor from a given py::array + /// \param[in] arr py::array + /// \param[out] out Created tensor + /// \return Status Code + static Status CreateFromNpArray(const py::array &arr, TensorPtr *out); +#endif + +#ifndef ENABLE_ANDROID + /// Create a tensor of type DE_STRING from a BytesList. + /// \param[in] bytes_list protobuf's Bytelist + /// \param[in] shape shape of the outout tensor + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out); + + /// Create a tensor of type UINT8 or INT8 from a BytesList. + /// The tensor will be padded with ' ' to reach the required pad_size. + /// \param[in] bytes_list protobuf's Bytelist + /// \param[in] shape shape of the output tensor + /// \param[in] type type of created tensor. Should be DE_UINT8 or INT8 + /// \param[in] pad_size The size of the tensor after padding + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, + const DataType &type, dsize_t pad_size, TensorPtr *out); #endif - // A static factory method to create a Tensor from a given list of strings. - // @param ptr output argument to hold the created Tensor - // @param strings elements of the tensor - // @param shape shape of the tensor - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // create tensor from protobuf bytelist with strings - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape); - - // A static factory method to create a Tensor from a given list of numbers. - // @param ptr output argument to hold the created Tensor - // @param items elements of the tensor - // @param shape shape of the tensor - // @return Status Code + /// Create a Tensor from a given list of values. + /// \tparam type of the values to be inserted. + /// \param[in] items elements of the tensor + /// \param[in] shape shape of the output tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code template - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &items, - const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) { + static Status CreateFromVector(const std::vector &items, const TensorShape &shape, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); DataType type = DataType::FromCType(); + // if items is empty, items_ptr would be nullptr. CreateFromMemory will handle this case. auto items_ptr = reinterpret_cast(&items[0]); - TensorShape shape = shape_req; - if (!shape.known()) { - shape = TensorShape({static_cast(items.size())}); - } - return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr); + return CreateFromMemory(shape, type, items_ptr, out); } - // A static factory method to create a Tensor from a given number. - // @param ptr output argument to hold the created Tensor - // @param item value - // @return Status Code + /// Create a 1D Tensor from a given list of values. + /// \tparam type of the values to be inserted. + /// \param[in] items elements of the tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code template - static Status CreateTensor(std::shared_ptr *ptr, const T &item) { - return CreateTensor(ptr, {item}, TensorShape::CreateScalar()); + static Status CreateFromVector(const std::vector &items, TensorPtr *out) { + return CreateFromVector(items, TensorShape({static_cast(items.size())}), out); } - // Create tensor from protobuf bytelist with uint8 or int8 types - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size); - - static Status CreateTensor(std::shared_ptr *ptr, const std::string &path); + /// Create a numeric scalar Tensor from the given value. + /// \tparam T type of value + /// \param[in] item value + /// \param[out] out Created tensor + /// \return Status code + template + static Status CreateScalar(const T &item, TensorPtr *out) { + DataType type = DataType::FromCType(); + auto item_ptr = reinterpret_cast(&item); + return CreateFromMemory(TensorShape::CreateScalar(), type, item_ptr, out); + } - // Copy raw data of a array based on shape and strides to the destination pointer - // @param dst Pointer to the destination array where the content is to be copied - // @param src Pointer to the source of strided array to be copied - // @param shape - shape of the source array - // @param strides - strides of the source array - // @param type_size - number of bytes needed to store one array element's type - // @return Status Code - static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, - std::vector strides, uint8_t type_size); + /// Create a tensor from a binary file on disk. + /// \param[in] path file to be read + /// \param[out] out Created Tensor + /// \return Status code + static Status CreateFromFile(const std::string &path, TensorPtr *out); - // Release the memory using the allocator + /// Destruct the tensor and release the memory using the allocator virtual ~Tensor(); - // compare the tensor shape and data + /// Equality operator. compares tensor shape, type and data + /// \param[in] rhs Tensor to be compared with + /// \return bool bool operator==(const Tensor &rhs) const; bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } - // Get item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return the item specified at index + /// Get item located at `index`, caller needs to provide the type. + /// \tparam T + /// \param[in] index vector + /// \return return the item specified at index template Status GetItemAt(T *o, const std::vector &index) const; - // Get string located at `index`. - // @param index vector - // @return return std::string_view specified at index + /// Get string located at `index`. + /// \param[in] index vector + /// \return return std::string_view specified at index Status GetItemAt(std::string_view *o, const std::vector &index) const; template @@ -225,22 +222,21 @@ class Tensor { template Status GetFloatAt(T *o, const std::vector &index) const; - // set item at location specified by index - // @tparam `T` - // @param index - // @param value of type `T` + /// set item at location specified by index + /// \tparam `T` + /// \param[in] index + /// \param[in] value of type `T` template Status SetItemAt(const std::vector &index, const T &value) { - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); T *ptr = nullptr; RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); *ptr = value; return Status::OK(); } - // set string item at location specified by index - // @param index - // @param value of type std::string + /// set string item at location specified by index + /// \param[in] index + /// \param[in] value of type std::string Status SetItemAt(const std::vector &index, const std::string &value) { RETURN_UNEXPECTED_IF_NULL(data_); uchar *ptr = nullptr; @@ -253,7 +249,8 @@ class Tensor { return Status::OK(); } - // fill tensor with Zeros. Does not support strings. + + /// fill tensor with Zeros. Does not support strings. Status Zero() { CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.."); dsize_t size = SizeInBytes(); @@ -262,13 +259,12 @@ class Tensor { return Status::OK(); } - // Fill all elements in the Tensor with the given value of type `T`. Does not support strings. - // @tparam T - // @param value + /// Fill all elements in the Tensor with the given value of type `T`. Does not support strings. + /// \tparam T + /// \param value[in] template Status Fill(const T &value) { CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); int64_t cellSize = type_.SizeInBytes(); if ((data_ != nullptr) && type_.IsCompatible()) { for (dsize_t i = 0; i < Size(); i++) { @@ -283,86 +279,88 @@ class Tensor { } } - // Getter function for shape - // @return + /// Getter function for shape + /// \return const TensorShape &shape() const { return shape_; } /// Check if tensor has data /// \return bool - true if tensor is empty - bool HasData() const; + bool HasData() const { return data_ != nullptr; } - // Reshape the tensor. The given shape should have the same number of elements in the Tensor - // @param shape + /// Reshape the tensor. The given shape should have the same number of elements in the Tensor + /// \param shape virtual Status Reshape(const TensorShape &shape); - // @return number of elements in this tensor + /// \return number of elements in this tensor dsize_t Size() const { return shape().NumOfElements(); } - // @return the number of bytes this tensor is needs + /// \return the number of bytes this tensor is needs dsize_t SizeInBytes() const { if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); return data_end_ - data_; } - // @return the rank of the tensor + /// \return the rank of the tensor dsize_t Rank() const { return shape().Rank(); } - // Get the starting memory address as a constant for the data of the tensor. This potentially - // drives an allocation if the data area. - // @return const unsigned char* - const unsigned char *GetBuffer() const; + /// Get the starting memory address as a constant for the data of the tensor. This potentially + /// drives an allocation if the data area. + /// \return const unsigned char* + const unsigned char *GetBuffer() const { return data_; } - // Getter of the type - // @return + /// Getter of the type + /// \return DataType type() const { return type_; } - // Provide stream operator for displaying it - // @param output stream - // @param so the Tensor object to be printed - // @return output stream + /// Provide stream operator for displaying it + /// \param output stream + /// \param so the Tensor object to be printed + /// \return output stream friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { so.Print(out); return out; } - // Invalidate this Tensor by setting the type and shape to unknown and MData to null. - // Calling this method will make the Tensor and its data inaccessible, use it with caution. + /// Invalidate this Tensor by setting the type and shape to unknown and MData to null. + /// Calling this method will make the Tensor and its data inaccessible, use it with caution. void Invalidate(); - // Copy input tensor into self at the location index. - // Index is a vector of axises which can be incomplete: - // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. - // @param index - // @param input - // @return Status code - Status InsertTensor(const std::vector &index, const std::shared_ptr &input); - - // Find the address of the given index. Used in InsertTensor. - // Example: - // Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 - // @param index incomplete index - // @param output: startAddrofIndex - // @param output: remaining - // @return Status code + /// Copy input tensor into self at the location index. + /// Index is a vector of axises which can be incomplete: + /// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. + /// \param index + /// \param input + /// \param partial_insert: boolean to determine if insertion along the full axis is enforced + /// \return Status code + Status InsertTensor(const std::vector &index, const std::shared_ptr &input, + const bool partial_insert = false); + + /// Find the address of the given index. Used in InsertTensor. + /// Example: + /// Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 + /// \param index incomplete index + /// \param output: startAddrofIndex + /// \param output: remaining + /// \return Status code Status StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining); - // Expand the shape of the Tensor with one extra dimension. - // For example, if the shape is <512,512,3>: - // *- ExpandDim(0) gives: <1,512,512,3> - // *- ExpandDim(1) gives: <512,1,512,3> - // *- ExpandDim(3) gives: <512,512,3,1> - // @param axis location of the dim + /// Expand the shape of the Tensor with one extra dimension. + /// For example, if the shape is <512,512,3>: + /// *- ExpandDim(0) gives: <1,512,512,3> + /// *- ExpandDim(1) gives: <512,1,512,3> + /// *- ExpandDim(3) gives: <512,512,3,1> + /// \param axis location of the dim virtual Status ExpandDim(const dsize_t &axis); virtual void Squeeze(); - // Calculates the strides of the Tensor - // Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) - // The strides will be {6,2,1}. - // Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) - // The strides will be {24,8,4}. - // @return vector of integers - std::vector Strides(); + /// Calculates the strides of the Tensor + /// Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) + /// The strides will be {6,2,1}. + /// Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) + /// The strides will be {24,8,4}. + /// \return vector of integers + std::vector Strides() const; std::string ToString() { std::stringstream ss; @@ -370,26 +368,26 @@ class Tensor { return ss.str(); } - // Handle negative indices. + /// Handle negative indices. static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } - // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. - // Based on the type of tensor, SliceNumeric or SliceString will be called - // @param out Tensor - // @param indices vector of indices - // @return Status error code - Status Slice(std::shared_ptr *out, const std::vector &indices); + /// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. + /// Based on the type of tensor, SliceNumeric or SliceString will be called + /// \param[out] out Tensor + /// \param[in] indices vector of indices + /// \return Status error code + Status Slice(TensorPtr *out, const std::vector &indices); - // Slice numeric tensors. - Status SliceNumeric(std::shared_ptr *out, const std::vector &indices); + /// Slice numeric tensors. + Status SliceNumeric(TensorPtr *out, const std::vector &indices); - // Slice string tensors - Status SliceString(std::shared_ptr *out, const std::vector &indices); + /// Slice string tensors + Status SliceString(TensorPtr *out, const std::vector &indices); #ifdef ENABLE_PYTHON - // Constructs numpy array from input tensor - // @param data this data is the location of python data - // @return Status code + /// Constructs numpy array from input tensor + /// \param[in] data this data is the location of python data + /// \return Status code Status GetDataAsNumpy(py::array *data); Status GetDataAsNumpyStrings(py::array *data); @@ -397,12 +395,9 @@ class Tensor { static Status GetBufferInfo(Tensor *t, py::buffer_info *out); #endif - // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor - Status Concatenate(const std::vector &index, const std::shared_ptr &input); - - // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor - // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 - // @tparam T type of values in the Tensor Iterator + /// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor + /// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 + /// \tparam T type of values in the Tensor Iterator template class TensorIterator { public: @@ -493,7 +488,7 @@ class Tensor { }; // Specialization of TensorIterator for strings. It returns std::string_view for every item. - // @tparam DUMMY, used to mbe able to specialize the inner class + // \tparam DUMMY, used to mbe able to specialize the inner class template class TensorIterator { public: @@ -580,89 +575,197 @@ class Tensor { const char *data_; }; - // Return a TensorIterator that points to the start of the Tensor. - // It's the user responsibility to use the correct type that matches the Tensor type - // @param T The type of values in the Tensor - // @return TensorIterator + /// Return a TensorIterator that points to the start of the Tensor. + /// It's the user responsibility to use the correct type that matches the Tensor type + /// \tparam T The type of values in the Tensor + /// \return TensorIterator template TensorIterator begin() { - AllocateBuffer(SizeInBytes()); return TensorIterator(data_); } - // Return a linear iterator that points to the place after the last element of the Tensor. - // @tparam T The type of values in the Tensor - // @return TensorIterator + /// Return a linear iterator that points to the place after the last element of the Tensor. + /// \tparam T The type of values in the Tensor + /// \return TensorIterator template TensorIterator end() { return TensorIterator(data_end_); } - // Copies the last dimension at `index` from Tensor `src` to this Tensor. - // @param src Tensor - // @param index vector to the start of the dimension. The last dim should be 0 - // @return Status + /// Copies the last dimension at `index` from Tensor `src` to this Tensor. + /// \param[in] src Tensor + /// \param[in] index vector to the start of the dimension. The last dim should be 0 + /// \return Status Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); protected: - // Get the starting memory address for the data of the tensor. This potentially - // drives an allocation if the data is null. - // @return unsigned char* - unsigned char *GetMutableBuffer(); - - // A function that prints Tensor recursively, first called by print - // @param out - // @param cur_dim - // @param cur_index + /// Allocate memory for the tensor using the data_allocator + /// \param[in] length number of bytes to be allocated + /// \return Error Status + Status AllocateBuffer(const dsize_t &length); + + /// Get the starting memory address for the data of the tensor. This potentially + /// drives an allocation if the data is null. + /// \return unsigned char* + unsigned char *GetMutableBuffer() { return data_; } + + /// A function that prints Tensor recursively, first called by print + /// \param[in] out + /// \param[in] cur_dim + /// \param[in] cur_index void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; - // A function that prints info about the tensor - // @param out output stream + /// A function that prints info about the tensor + /// \param[out] out output stream void Print(std::ostream &out) const; - // A function that print the value as specified by its index - // @param index vector representing the index - // @param out + /// A function that print the value as specified by its index + /// \param[in] index vector representing the index + /// \param[out] out void PrintItemAt(const std::vector &index, std::ostream &out) const; - // Get pointer to item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return a pointer to the item specified at index of type `T` + /// Get pointer to item located at `index`, caller needs to provide the type. + /// \tparam T + /// \param[in] index vector + /// \return return a pointer to the item specified at index of type `T` template Status GetItemPtr(T **, const std::vector &index) const; - // Get pointer to string located at `index` and the length of string - // @param index vector - // @return return a pointer to the string specified at index and the length of the string + /// Get pointer to string located at `index` and the length of string + /// \param[in] index vector + /// \return return a pointer to the string specified at index and the length of the string Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; - // Given a flat index of an item string, return the start and length of the item - // @param index flat index of the item - // @return start address of the ths string - // @return length of the string + /// Given a flat index of an item string, return the start and length of the item + /// \param[in] index flat index of the item + /// \param[out] start address of the ths string + /// \param[out] length of the string Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; - // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the - // tensor's type is a string, otherwise undefined address would be returned. - // @return address of the first string of the tensor. + /// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if + /// the tensor's type is a string, otherwise undefined address would be returned. \return address of the first string + /// of the tensor. uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } - // all access to shape_ should be via shape + /// all access to shape_ should be via shape TensorShape shape_; - // data type of tensor + /// data type of tensor DataType type_; - // pointer to the start of the physical data + /// pointer to the start of the physical data unsigned char *data_; - // An allocator for data_ + /// An allocator for data_ CharAllocPtr data_allocator_; - // pointer to the end of the physical data + /// pointer to the end of the physical data unsigned char *data_end_ = nullptr; + + private: +#ifdef ENABLE_ANDROID + friend class tensor::DETensor; +#endif + /// Copy raw data of a array based on shape and strides to the destination pointer + /// \param dst [out] Pointer to the destination array where the content is to be copied + /// \param[in] src Pointer to the source of strided array to be copied + /// \param[in] shape shape of the source array + /// \param[in] strides strides of the source array + /// \param[in] type_size number of bytes needed to store one array element's type + /// \return Status Code + static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size); + + /// const of the size of the offset variable + static constexpr uint8_t kOffsetSize = sizeof(offset_t); + +#ifdef ENABLE_PYTHON + /// Helper function to create a tensor from Numpy array of strings + /// \param[in] arr Numpy array + /// \param[out] out Created Tensor + /// \return Status + static Status CreateFromNpString(py::array arr, TensorPtr *out); +#endif }; template <> inline Tensor::TensorIterator Tensor::end() { return TensorIterator(data_, shape_.NumOfElements()); } + +/// Create a Tensor from a given list of strings. +/// @note: The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. +/// The offset array will store one extra value to find the length of the last string. +/// OFFSET_1, OFFSET_2, ..., OFFSET_n+1, STRING_1, STRING_2, ..., STRING_n +/// The value of each offset is the start index of the corresponding string +/// Offsets is of type offset_t +/// strings will ne null-terminated +/// example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) +/// |----------------------------------------------------------------| +/// | OFFSET ARRAY | STRINGS | +/// | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | +/// | 11 | 15 | 18 | abc\0 | de\0 | +/// |----------------------------------------------------------------| +/// \param[in] items elements of the tensor +/// \param[in] shape shape of the output tensor +/// \param[out] out output argument to hold the created Tensor +/// \return Status Code +template <> +inline Status Tensor::CreateFromVector(const std::vector &items, const TensorShape &shape, + TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, TensorShape({static_cast(items.size())}), + DataType(DataType::DE_STRING)); + if (items.size() == 0) { + if (shape.known()) { + return (*out)->Reshape(shape); + } + } + auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; + dsize_t total_length = std::accumulate(items.begin(), items.end(), 0, length_sum); + + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize + 1) * (*out)->shape_.NumOfElements() + kOffsetSize + total_length; + + (*out)->AllocateBuffer(num_bytes); + auto offset_arr = reinterpret_cast((*out)->data_); + uchar *buf = (*out)->GetStringsBuffer(); + + offset_t offset = buf - (*out)->data_; // the first string will start here + uint32_t i = 0; + for (const auto &str : items) { + // insert the start index of the string. + offset_arr[i++] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s((*out)->data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + (*out)->data_end_ = (*out)->data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) { + RETURN_IF_NOT_OK((*out)->Reshape(shape)); + } + return Status::OK(); +} +/// Create a string scalar Tensor from the given value. +/// \param[in] item value +/// \param[out] out Created tensor +/// \return Status code +template <> +inline Status Tensor::CreateScalar(const std::string &item, TensorPtr *out) { + return CreateFromVector({item}, TensorShape::CreateScalar(), out); +} } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_TENSOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_row.h b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h index e8f066c87b..613c256017 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor_row.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_CORE_TENSOR_ROW_H_ -#define DATASET_CORE_TENSOR_ROW_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ #include #include @@ -128,4 +128,4 @@ class TensorRow { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_TENSOR_ROW_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc index ff40062d37..19c3a6b457 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc @@ -19,7 +19,7 @@ #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/log_adapter.h" #include "minddata/dataset/core/constants.h" diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h index 4944f9e32c..179c922c6c 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CORE_TENSOR_SHAPE_H_ -#define DATASET_CORE_TENSOR_SHAPE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ #include #include @@ -193,4 +193,4 @@ class TensorShape { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_TENSOR_SHAPE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt index e3ead16d05..342eac8fb4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt @@ -15,12 +15,15 @@ add_library(engine OBJECT data_schema.cc dataset_iterator.cc ) -target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) + +if (ENABLE_PYTHON) + target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) +endif() if (ENABLE_TDTQUE) add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf - engine-cache-client engine-cache-server) + engine-cache-client engine-cache-server engine-datasetops-mapop) else () add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf - engine-cache-client engine-cache-server) + engine-cache-client engine-cache-server engine-datasetops-mapop) endif () diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h index f25db87578..963d2e7e89 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_CACHE_CLIENT_H_ -#define DATASET_ENGINE_CACHE_CLIENT_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ #include #include @@ -23,9 +23,9 @@ #include #include -#include "./de_tensor_generated.h" #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/engine/cache/de_tensor_generated.h" #include "minddata/dataset/util/lock.h" namespace mindspore { @@ -138,4 +138,4 @@ class CacheClient { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_CACHE_CLIENT_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index 3b7fc057a2..a460e43aea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -141,8 +141,9 @@ Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const Re #undef CASE DataType type(dest); - std::shared_ptr ts = - std::make_shared(shape, type, static_cast(data.GetPointer()), data.GetSize()); + std::shared_ptr ts; + RETURN_IF_NOT_OK( + Tensor::CreateFromMemory(shape, type, static_cast(data.GetPointer()), data.GetSize(), &ts)); // Next we restore the real data which can be embedded or stored separately. if (ts->SizeInBytes() != data.GetSize()) { MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index 3d0edc6dd8..6851cebe0c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_CACHE_REQ_H_ -#define DATASET_ENGINE_CACHE_REQ_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_ #include #include @@ -23,8 +23,8 @@ #include #include -#include "./de_tensor_generated.h" #include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/engine/cache/de_tensor_generated.h" #include "minddata/dataset/util/slice.h" #include "minddata/dataset/util/wait_post.h" @@ -111,6 +111,7 @@ class BatchFetchRequest : public BaseRequest { friend class CacheService; BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id) : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} + ~BatchFetchRequest() = default; Status RestoreRows(TensorTable *out); private: @@ -130,6 +131,8 @@ class CreationCacheRequest : public BaseRequest { CreateCacheFlag flag = CreateCacheFlag::kNone) : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} + ~CreationCacheRequest() = default; + std::string cookie() const { return cookie_; } private: @@ -142,6 +145,8 @@ class PurgeCacheRequest : public BaseRequest { public: friend class CacheServer; explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} + + ~PurgeCacheRequest() = default; }; /// \brief Request to destroy a cache class DestroyCacheRequest : public BaseRequest { @@ -149,6 +154,9 @@ class DestroyCacheRequest : public BaseRequest { friend class CacheServer; explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kDestroyCache) {} + + /// \brief Destructor + ~DestroyCacheRequest() = default; }; /// \brief Obtain the statistics of the current connection class GetStatRequest : public BaseRequest { @@ -156,6 +164,9 @@ class GetStatRequest : public BaseRequest { friend class CacheServer; friend class CacheService; explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} + + ~GetStatRequest() = default; + row_id_type GetMinRowId() const { auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); return msg->min_row_id(); @@ -217,9 +228,11 @@ class BuildPhaseDoneRequest : public BaseRequest { BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} + ~BuildPhaseDoneRequest() = default; + private: std::string cookie_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_CACHE_SERVICE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h index 13b68c4389..c0dc8c467b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_CACHE_SERVER_H_ -#define DATASET_ENGINE_CACHE_SERVER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ #include #include @@ -95,4 +95,4 @@ class CacheServer : public Service { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_CORE_CACHE_TENSOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CACHE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h index bf324e82e3..f4bd13e6ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_CACHE_SERVICE_H_ -#define DATASET_ENGINE_CACHE_SERVICE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ #include #include @@ -25,10 +25,10 @@ #include #include -#include "./de_tensor_generated.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/engine/cache/de_tensor_generated.h" #include "minddata/dataset/util/arena.h" #include "minddata/dataset/util/btree.h" #include "minddata/dataset/util/cache_pool.h" @@ -84,6 +84,7 @@ class CacheService : public Service { public: using state_type = std::underlying_type::type; ServiceStat() : min_(0), max_(0), state_(0) {} + ~ServiceStat() = default; CachePool::CacheStat stat_{}; row_id_type min_; row_id_type max_; @@ -140,4 +141,4 @@ class CacheService : public Service { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_CACHE_SERVICE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/connector.h b/mindspore/ccsrc/minddata/dataset/engine/connector.h index a91d8e68e9..0366609b3f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/connector.h +++ b/mindspore/ccsrc/minddata/dataset/engine/connector.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_CONNECTOR_H_ -#define DATASET_ENGINE_CONNECTOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONNECTOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONNECTOR_H_ #include #include @@ -208,4 +208,4 @@ class Connector { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_CONNECTOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONNECTOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h index 5fcb4c21a5..01f2b3a881 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATA_BUFFER_H_ -#define DATASET_ENGINE_DATA_BUFFER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_ #include #include @@ -105,4 +105,4 @@ class DataBuffer { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATA_BUFFER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc b/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc index 50d910251d..6db8ec5614 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc @@ -23,7 +23,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/status.h" #include "minddata/dataset/core/tensor_shape.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_schema.h b/mindspore/ccsrc/minddata/dataset/engine/data_schema.h index 96f6f2b118..a53c37f8c1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/data_schema.h +++ b/mindspore/ccsrc/minddata/dataset/engine/data_schema.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATA_SCHEMA_H_ -#define DATASET_ENGINE_DATA_SCHEMA_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_ #include #include @@ -205,4 +205,4 @@ class DataSchema { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATA_SCHEMA_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_SCHEMA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc index f75ca5d097..99c5c96b40 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc @@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) { out_map->clear(); TensorRow curr_row; + MS_LOG(INFO) << "get next as map start."; RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); + MS_LOG(INFO) << "fetchNextTensor success."; // Return empty map if there's no data if (curr_row.empty()) { @@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { // Once eof is handled, always return empty row. Class must be destroyed and recreated if you // want to iterate again. if (eof_handled_) { - return Status::OK(); + std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; + RETURN_STATUS_UNEXPECTED(err); } // Check if we need to get a new DataBuffer to iterate. @@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually // handle eoe and eof messages here. // - // An eoe buffer means we have iterated fully to the end of the tree. - // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of - // all operators. + // An eoe buffer means we have iterated an epoch. + // The next buffer in the pipeline might be an EOF or a databuffer for next epoch if (curr_buffer_->eoe()) { - MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; - - // Before returning the last empty vector, fetch the eof buffer which should be the last - // buffer, and then free it. - RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); - - if (!curr_buffer_->eof()) { - RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!"); - } - eof_handled_ = true; - curr_buffer_.reset(); // explicitly free the eof buffer - // Set tree to Finished state - root_->Tree()->SetFinished(); - + MS_LOG(INFO) << "End of data iteration."; + curr_buffer_.reset(); // explicitly free the eoe buffer return Status::OK(); } + // An eof buffer means it is the end of execution and all operators are shutting down. + // Because there is no more data to return to the caller, this will change `eof_handled_` state and + // returns status unexpected error. if (curr_buffer_->eof()) { - // An eof by itself, without being preceded by an eoe, is possible if a repeat operator - // exists below us in the stack. Repeat operator eats eoe's but eventually allows the - // flow of an eof up the pipeline by itself. eof_handled_ = true; curr_buffer_.reset(); // explicitly free the eof buffer - // Set tree to Finished state - root_->Tree()->SetFinished(); - return Status::OK(); + std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; + RETURN_STATUS_UNEXPECTED(err); } } @@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) { // Once eof is handled, always return empty row. Class must be destroyed and recreated if you // want to iterate again. if (eof_handled_) { - return Status::OK(); + std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; + RETURN_STATUS_UNEXPECTED(err); } // Check if we need to get a new DataBuffer to iterate. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { + // GetNextInput() depends on current_op's EoeReceived. So, EOE buffer might be already be handled and + // this child iterator might not see EOE buffer. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); - // Unlike the DatasetIterator, this child iterator does not quit after eoe. - // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the + // If an eoe is picked up here, we simply return an empty vector and it's up to the // caller to decide what it wants to do next. if (curr_buffer_->eoe()) { MS_LOG(DEBUG) << "Child iterator picked up EOE."; end_epoch_ = true; return Status::OK(); + } else { + end_epoch_ = false; } if (curr_buffer_->eof()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h index 253d1604e2..a309fdda72 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h +++ b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASET_ITERATOR_H_ -#define DATASET_ENGINE_DATASET_ITERATOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASET_ITERATOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASET_ITERATOR_H_ #include #include @@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase { // @return The string to column id mapping. std::unordered_map GetColumnNameMap() const override; + // Return T/F if end of epoch + bool end_of_epoch() { return end_epoch_; } + private: DatasetOp *current_op_; // The parent operator. We consume from it's children. int32_t child_idx_; // The specific child this iterator will fetch from. @@ -153,4 +156,4 @@ class ChildIterator : public IteratorBase { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASET_ITERATOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASET_ITERATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt index a2cd6dc07a..ef97f3322e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(source) +add_subdirectory(map_op) file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) @@ -9,7 +10,6 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES pipeline_op.cc batch_op.cc device_queue_op.cc - map_op.cc project_op.cc rename_op.cc repeat_op.cc @@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES shuffle_op.cc zip_op.cc concat_op.cc + epoch_ctrl_op.cc cache_base_op.cc cache_lookup_op.cc cache_op.cc @@ -31,8 +32,8 @@ if (ENABLE_PYTHON) barrier_op.cc filter_op.cc build_vocab_op.cc + build_sentence_piece_vocab_op.cc ) endif() add_library(engine-datasetops OBJECT ${DATASET_ENGINE_DATASETOPS_SRC_FILES}) - diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc index 51ea232e68..2090661981 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc @@ -212,8 +212,6 @@ Status BarrierOp::getNextTensorRow(TensorRow *new_row) { // A function that prints info about the Operator void BarrierOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h index a3ac843272..cdbae0941e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ #include #include @@ -121,6 +121,10 @@ class BarrierOp : public PipelineOp { // @param show_all - if it should print everything void Print(std::ostream &out, bool show_all) const override; + // Op name getter + // @return Name of the current Op + std::string Name() const override { return kBarrierOp; } + // Provide stream operator for displaying it friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { bo.Print(out, false); @@ -166,4 +170,4 @@ class BarrierOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 844d054307..6a681c1660 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -18,7 +18,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #ifdef ENABLE_PYTHON #include "minddata/dataset/core/pybind_support.h" #endif @@ -135,8 +135,6 @@ Status BatchOp::operator()() { } void BatchOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -176,12 +174,15 @@ Status BatchOp::BatchRows(const std::unique_ptr *src, const std::u std::shared_ptr new_tensor; if (first_type.IsNumeric()) { // numeric tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, first_type, &new_tensor)); dsize_t j = 0; for (auto row : **src) { std::shared_ptr old_tensor = row.at(i); // row j, column i if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first - RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); + if (new_shape.NumOfElements() != 0) { + RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); + } + // Don't do anything if the tensor has no data } else { RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i)); } @@ -194,7 +195,7 @@ Status BatchOp::BatchRows(const std::unique_ptr *src, const std::u strings.emplace_back(*itr); } } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, new_shape, &new_tensor)); } batched_row.emplace_back(new_tensor); } @@ -352,7 +353,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou py::list output_list = py::cast(ret_tuple[i]); for (size_t j = 0; j < output_list.size(); j++) { std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast(output_list[j]))); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast(output_list[j]), &out)); output_batch.push_back(std::move(out)); } output->push_back(std::move(output_batch)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h index 0c042433f7..503415704f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ #include #include @@ -200,7 +200,7 @@ class BatchOp : public ParallelOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "BatchOp"; } + std::string Name() const override { return kBatchOp; } // batch the rows in src table then put it to dest table // @param const std::unique_ptr *src - table that has the rows for batching @@ -284,4 +284,4 @@ class BatchOp : public ParallelOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc index 138bb7980b..971f14c669 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -107,8 +107,6 @@ Status BucketBatchByLengthOp::EoeReceived(int32_t) { return Status::OK(); } -void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } - Status BucketBatchByLengthOp::operator()() { TaskManager::FindMe()->Post(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h index 332ff4bb22..e14a5ff760 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ #include #include @@ -109,10 +109,7 @@ class BucketBatchByLengthOp : public PipelineOp { // @return Status - The error code returned Status EoeReceived(int32_t) override; - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; + std::string Name() const override { return kBucketBatchByLengthOp; } // << Stream output operator overload // @notes This allows you to write the debug print info using stream operators @@ -152,4 +149,4 @@ class BucketBatchByLengthOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.cc new file mode 100644 index 0000000000..379adc1fc7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" + +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr vocab, + std::vector col_names, uint32_t vocab_size, + float character_coverage, SentencePieceModel model_type, + const std::unordered_map ¶ms, + int32_t op_conn_size) + : PipelineOp(op_conn_size), + vocab_size_(vocab_size), + vocab_(vocab), + col_names_(col_names), + character_coverage_(character_coverage), + model_type_(model_type), + params_(params), + col_id_(0) { + sentence_queue_ = std::make_unique>(op_conn_size); + read_done_ = false; + ret_status_ = Status::OK(); +} + +Status BuildSentencePieceVocabOp::operator()() { + RETURN_UNEXPECTED_IF_NULL(tree_); + RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK( + tree_->AllTasks()->CreateAsyncTask("sentenceTask", std::bind(&BuildSentencePieceVocabOp::SentenceThread, this))); + TaskManager::FindMe()->Post(); + child_iterator_ = std::make_unique(this, 0, 0); + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + + bool eoe_warning = false; // give out warning if receive more than 1 eoe + while (child_iterator_->eof_handled() == false) { + while (new_row.empty() == false) { + RETURN_IF_NOT_OK(sentence_queue_->EmplaceBack(new_row)); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)"); + eoe_warning = true; + } + // add empty tensorRow for quit + TensorRow empty_row = {}; + sentence_queue_->EmplaceBack(empty_row); + return Status::OK(); +} + +Status BuildSentencePieceVocabOp::SentenceThread() { + TaskManager::FindMe()->Post(); + if (col_names_.empty() == true) { + auto itr = column_name_id_map_.find("text"); + CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), + "'text' column doesn't exist when column name is empty"); + col_id_ = itr->second; + } else { + auto itr = column_name_id_map_.find(col_names_[0]); + CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col_names_[0] + "column doesn't exist"); + col_id_ = itr->second; + } + std::unique_ptr sentence_iter = std::make_unique(this); + std::string model_proto; + sentencepiece::util::Status s_status = + sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto); + if (!s_status.ok()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message()); + } else { + if (vocab_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "sentencepiece vocab ptr must not be nullptr"); + } + vocab_->set_model_proto(model_proto); + } + RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + return Status::OK(); +} + +std::unordered_map BuildSentencePieceVocabOp::BuildParams() { + std::unordered_map ret_params; + ret_params["vocab_size"] = std::to_string(vocab_size_); + ret_params["character_coverage"] = std::to_string(character_coverage_); + if (model_type_ == SentencePieceModel::kBpe) { + ret_params["model_type"] = "BPE"; + } else if (model_type_ == SentencePieceModel::kChar) { + ret_params["model_type"] = "CHAR"; + } else if (model_type_ == SentencePieceModel::kWord) { + ret_params["model_type"] = "WORD"; + } else { + ret_params["model_type"] = "UNIGRAM"; + } + // filter some params that set by function param + // filter model_prefix that must be empty + for (auto param : params_) { + std::string key = param.first; + if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" || + key == "model_type") { + continue; + } + ret_params[key] = param.second; + } + + ret_params["model_prefix"] = ""; + ret_params["minloglevel"] = "1"; + return ret_params; +} + +bool BuildSentencePieceVocabOp::Done() { return read_done_; } + +void BuildSentencePieceVocabOp::Next(std::string *sentence) { + TensorRow new_row; + Status s = sentence_queue_->PopFront(&new_row); + + if (s.IsError()) { + read_done_ = true; + ret_status_ = s; + return; + } + if (new_row.empty() == true) { + read_done_ = true; + ret_status_ = Status::OK(); + return; + } + + if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) { + ret_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "for dataset only words on string columns or must bu scalar"); + read_done_ = true; + return; + } + + std::string_view sentence_v; + new_row[col_id_]->GetItemAt(&sentence_v, {}); + + std::string st{sentence_v}; + *sentence = st; + ret_status_ = Status::OK(); +} + +// Pre-Visitor accept method for NodePass +Status BuildSentencePieceVocabOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +Status BuildSentencePieceVocabOp::Builder::Build(std::shared_ptr *op) { + (*op) = std::make_shared(builder_vocab_, builder_col_names_, builder_vocab_size_, + builder_character_coverage_, builder_model_type_, builder_params_, + builder_connector_size_); + return Status::OK(); +} + +BuildSentencePieceVocabOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_connector_size_ = cfg->op_connector_size(); +} + +BuildSentencePieceVocabOp::DatasetSentenceIterator::DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr) + : s_p_vocab_ptr_(s_p_vocab_ptr) {} + +bool BuildSentencePieceVocabOp::DatasetSentenceIterator::done() const { + if (s_p_vocab_ptr_ == nullptr) { + return true; + } + return s_p_vocab_ptr_->Done(); +} + +void BuildSentencePieceVocabOp::DatasetSentenceIterator::Next() { + if (s_p_vocab_ptr_ == nullptr) { + return; + } + s_p_vocab_ptr_->Next(&value_); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h new file mode 100644 index 0000000000..c868585344 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h @@ -0,0 +1,189 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/text/sentence_piece_vocab.h" +#include "pybind11/pybind11.h" + +namespace mindspore { +namespace dataset { +namespace py = pybind11; + +class BuildSentencePieceVocabOp : public PipelineOp { + public: + class Builder { + public: + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param uint32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(uint32_t size) { + builder_connector_size_ = size; + return *this; + } + + // Setter method + // @param uint32_t size + // @return Builder & reference to builder class object + Builder &SetVocabSize(uint32_t size) { + builder_vocab_size_ = size; + return *this; + } + + // Setter method + // @param float charactor corverage - to determine the minimum symbols + // @return Builder & reference to builder class object + Builder &SetCharacterCoverage(float character_coverage) { + builder_character_coverage_ = character_coverage; + return *this; + } + + // Setter method + // @param SentencePieceModel model_type - model algorithm + // @return Builder & reference to builder class object + Builder &SetModelType(SentencePieceModel model_type) { + builder_model_type_ = model_type; + return *this; + } + + // Setter method + // @param std::unordered_map params + // @return Builder & reference to builder class object + Builder &SetParams(std::unordered_map params) { + builder_params_ = params; + return *this; + } + + // Setter method + // @param std::shared_ptr vocab + // @return Builder & reference to builder class object + Builder &SetVocab(std::shared_ptr vocab) { + builder_vocab_ = vocab; + return *this; + } + + // set columns names + // @param const std::vector & col_names - name of columns to get words + // @return Builder & reference to builder class object + Builder &SetColumnNames(const std::vector &col_names) { + builder_col_names_ = col_names; + return *this; + } + + // The builder "build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + uint32_t builder_connector_size_; + uint32_t builder_vocab_size_; + float builder_character_coverage_; + SentencePieceModel builder_model_type_; + std::unordered_map builder_params_; + std::vector builder_col_names_; + std::shared_ptr builder_vocab_; + }; + + public: + class DatasetSentenceIterator : public sentencepiece::SentenceIterator { + public: + explicit DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr); + ~DatasetSentenceIterator() {} + + bool done() const override; + void Next() override; + const std::string &value() const override { return value_; } + sentencepiece::util::Status status() const override { return sentencepiece::util::Status(); } + + private: + std::string value_; + BuildSentencePieceVocabOp *s_p_vocab_ptr_; + }; + + BuildSentencePieceVocabOp(std::shared_ptr vocab, std::vector col_names, + uint32_t vocab_size, float character_coverage, SentencePieceModel model_type, + const std::unordered_map ¶ms, int32_t op_conn_size); + + ~BuildSentencePieceVocabOp() = default; + + // the thread for sentence train + Status SentenceThread(); + + Status EofReceived(int32_t) override { return Status::OK(); } + + Status EoeReceived(int32_t) override { return Status::OK(); } + + Status operator()() override; + + // Getter + // @return the number of workers + int32_t num_producers() const override { return 1; } + + // Getter + // @return the number of threads consuming from the previous Connector + int32_t num_consumers() const override { return 1; } + + Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); } + + std::string Name() const override { return kBuildSentencePieceVocabOp; } + + // build the input params for sentence api + std::unordered_map BuildParams(); + + bool Done(); + void Next(std::string *sentence); + + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + private: + bool read_done_; + Status ret_status_; + uint32_t vocab_size_; + float character_coverage_; + SentencePieceModel model_type_; + std::unordered_map params_; + std::shared_ptr vocab_; + std::vector col_names_; + uint32_t col_id_; + std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 + std::unique_ptr> sentence_queue_; // master thread assigns each worker TensorRow via this +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc index 8ed51ebbb6..5ab4f7251b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc @@ -17,11 +17,13 @@ #include "minddata/dataset/engine/datasetops/build_vocab_op.h" #include +#include #include #include #include #include #include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -202,5 +204,27 @@ BuildVocabOp::Builder::Builder() builder_num_workers_ = cfg->num_parallel_workers(); builder_connector_size_ = cfg->op_connector_size(); } + +// A print method typically used for debugging +void BuildVocabOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCode is needed here to show more info about the op." + << "\n\n"; + } +} + +// Pre-Visitor accept method for NodePass +Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h index 42ea0deb5c..07650381f8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ #include #include @@ -131,6 +131,22 @@ class BuildVocabOp : public ParallelOp { ~BuildVocabOp() = default; + /// \brief A print method typically used for debugging + /// \param[out] out The output stream to write output to + /// \param[in] show_all A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + std::string Name() const override { return kBuildVocabOp; } + + /// \briefStream output operator overload + /// \notes This allows you to write the debug print info using stream operators + /// \param[out] out Reference to the output stream being overloaded + /// \param[in] vop - reference to the BuildVocabOp to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) { + vop.Print(out, false); + return out; + } + Status WorkerEntry(int32_t worker_id) override; // collect the work product from each worker @@ -152,6 +168,12 @@ class BuildVocabOp : public ParallelOp { Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + private: const int32_t interval_; bool special_first_; @@ -171,4 +193,4 @@ class BuildVocabOp : public ParallelOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc index 1b0890686f..8e3a291d72 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -22,8 +22,6 @@ namespace mindspore { namespace dataset { // A print method typically used for debugging void CacheBase::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -91,13 +89,14 @@ Status CacheBase::FetchSamplesToWorkers() { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); // If repeat but the not last repeat, wait for reset. - if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (!IsLastIteration()) { MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; RETURN_IF_NOT_OK(epoch_sync_.Wait()); } else { // We can break out from the loop. break; } + UpdateRepeatAndEpochCounter(); } while (true); // Flow the eof before exit RETURN_IF_NOT_OK( diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h index fb3e999b76..40f3426394 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ #include #include @@ -59,6 +59,9 @@ class CacheBase : public ParallelOp { /// \param show_all A bool to control if you want to show all info or just a summary void Print(std::ostream &out, bool show_all) const override; + /// \brief Gives a name to the class, typically used for debugging + std::string Name() const override { return kCacheBase; } + /// \brief << Stream output operator overload /// \notes This allows you to write the debug print info using stream operators /// \param out reference to the output stream being overloaded @@ -105,4 +108,4 @@ class CacheBase : public ParallelOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h index 46a58c5d02..adec3d8283 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ #include #include @@ -100,7 +100,7 @@ class CacheLookupOp : public CacheBase, public Sampler { Status GetNextSample(std::unique_ptr *out_buffer) override; void Print(std::ostream &out, bool show_all) const override; bool AllowCacheMiss() override { return true; } - std::string Name() const override { return "CacheLookupOp"; } + std::string Name() const override { return kCacheLookupOp; } /// \brief Base-class override for NodePass visitor acceptor /// \param[in] p The node to visit @@ -119,4 +119,4 @@ class CacheLookupOp : public CacheBase, public Sampler { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 75579dc3a6..b9be973d9c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -28,9 +28,7 @@ namespace mindspore { namespace dataset { CacheMergeOp::~CacheMergeOp() = default; -void CacheMergeOp::Print(std::ostream &out, bool show_all) - const { // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; +void CacheMergeOp::Print(std::ostream &out, bool show_all) const { if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -96,7 +94,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); } } - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + RETURN_IF_NOT_OK(EofReceived(worker_id)); return Status::OK(); } Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { @@ -226,7 +224,8 @@ void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { if (GetState() == State::kEmpty) { // We will do a deep copy for (auto &ts : row) { - auto out_ts = std::make_shared(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes()); + std::shared_ptr out_ts; + Tensor::CreateFromTensor(ts, &out_ts); cleaner_copy_.push_back(out_ts); } cleaner_copy_.setId(row.getId()); @@ -293,10 +292,24 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { Status CacheMergeOp::EoeReceived(int32_t worker_id) { // If we are in a repeat path, send the eoe up. // Otherwise ignore it. - if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { + if (op_total_repeats_ > 1) { return DatasetOp::EoeReceived(worker_id); } return Status::OK(); } + +// Base-class override for handling cases when an eof is received. +Status CacheMergeOp::EofReceived(int32_t worker_id) { + // If we are not in a repeated path, then the merge op gets a eof by itself, without first + // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. + // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class + // provides that for us. + if (op_total_repeats_ == 1) { + MS_LOG(DEBUG) << "Cache merge sending eoe"; + RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); + } + MS_LOG(DEBUG) << "Cache merge sending eof"; + return DatasetOp::EofReceived(worker_id); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h index df37465fc4..a4d92d1221 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ #include #include @@ -140,6 +140,8 @@ class CacheMergeOp : public ParallelOp { std::shared_ptr cache_client, const std::shared_ptr &sampler); ~CacheMergeOp(); void Print(std::ostream &out, bool show_all) const override; + std::string Name() const override { return kCacheMergeOp; } + friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) { mo.Print(out, false); return out; @@ -176,6 +178,11 @@ class CacheMergeOp : public ParallelOp { /// \return Status object Status EoeReceived(int32_t worker_id) override; + /// \brief Base-class override for handling cases when an eof is received. + /// \param worker_id - The worker id + /// \return Status - The error code return + Status EofReceived(int32_t worker_id) override; + protected: Status ComputeColMap() override; @@ -193,4 +200,4 @@ class CacheMergeOp : public ParallelOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index 143c45b2dc..c742d82522 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -85,6 +85,10 @@ Status CacheOp::operator()() { TaskManager::FindMe()->Post(); // Wait for the workers to finish caching the rows. RETURN_IF_NOT_OK(WaitForCachingAllRows()); + // Current repeats and current epochs may have increased when caching all rows with DatasetOp::GetNextInput. + // But they shouldn't be increased because now cache op is starting to act as a leaf and its epoch hasn't started. + op_current_repeats_ = 0; + op_current_epochs_ = 0; RETURN_IF_NOT_OK(FetchSamplesToWorkers()); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h index dd34d54973..f6af02fdba 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ #include #include @@ -140,7 +140,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { /// \brief Base-class override for handling cases if we allow cache miss bool AllowCacheMiss() override { return false; } /// \brief Base-class override for the name of this operator - std::string Name() const override { return "CacheOp"; } + std::string Name() const override { return kCacheOp; } /// \brief A public wrapper for creating the cache through the client /// \param[in] cache_crc The crc that identifies the cache /// \see cache_pass.cc @@ -165,4 +165,4 @@ class CacheOp : public CacheBase, public RandomAccessOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc index 7acb68350b..d2dc8f535c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -16,7 +16,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/datasetops/concat_op.h" @@ -42,8 +42,6 @@ ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), c // A function that prints info about the Operator void ConcatOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); @@ -87,6 +85,7 @@ Status ConcatOp::operator()() { auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); } + UpdateRepeatAndEpochCounter(); } CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_, "Something went wrong, eof count does not match the number of children."); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h index 3d3d9df71c..58653b5a01 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ #include #include @@ -77,7 +77,7 @@ class ConcatOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "ConcatOp"; } + std::string Name() const override { return kConcatOp; } // Private function for computing the assignment of the column name map. // @return - Status @@ -94,4 +94,4 @@ class ConcatOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 9254141308..f866885f7f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -26,6 +26,7 @@ #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/opt/pass.h" @@ -41,7 +42,10 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler operator_id_(kInvalidOperatorId), tree_(nullptr), state_(OpState::kDeOpIdle), - op_ctrl_flags_(kDeOpNone), + op_total_repeats_(kInfiniteRepeat), + op_num_repeats_per_epoch_(kInfiniteRepeat), + op_current_repeats_(0), + op_current_epochs_(0), out_connector_(nullptr) { // The operator starts out with an invalid operator id. The only way to // get it out of invalid state is to assign the operator to an execution tree. @@ -102,6 +106,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr to_add) { } return Status::OK(); } +// Removes child operator in this operator. +Status DatasetOp::RemoveChildren() { + for (const auto &child : child_) { + child->RemoveParent(this); + } + child_.clear(); + + return Status::OK(); +} // Adds a parent operator to this operator void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); } @@ -185,6 +198,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const { } } +// Getter function to get all of our children. +std::vector> DatasetOp::children() const { return child_; } + +// Getter function to get all of our parents. +std::vector DatasetOp::parents() const { return parent_; } + // Creates the connector within this operator void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) { MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers @@ -206,7 +225,10 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { // When show_all is true, we display more detailed output for the op. // Derived printers should show their own header info, then call base class printer, followed by // derived-specific items. - // For now, the base class doesn't have any summary info to show so it's a no-op in that case. + + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; + if (show_all) { // The detailed display will show common base class info of the op. Allow the derived class to print // it's own id and name though as the first line. @@ -218,8 +240,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { for (size_t i = 0; i < parent_.size(); i++) { out << "\n Parent[" << i << "] id: " << parent_[i]->id(); } - out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex - << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); + out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ + << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; if (sampler_) { sampler_->Print(out, show_all); } @@ -228,15 +250,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { // Gets the next buffer from the given child Status DatasetOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { -#if defined(_WIN32) || defined(_WIN64) - RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), p_buffer, retry_if_eoe)); -#else - std::unique_ptr next_buff; // pop is a blocked call and will throw an interruption if the whole group shuts down. - RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), &next_buff, retry_if_eoe)); - - *p_buffer = std::move(next_buff); -#endif + RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), p_buffer, retry_if_eoe)); return Status::OK(); } @@ -253,6 +268,7 @@ Status DatasetOp::GetNextInput(std::unique_ptr *p_buffer, int32_t wo RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); // Loop until non EOE is received while (buf->eoe()) { + UpdateRepeatAndEpochCounter(); RETURN_IF_NOT_OK(EoeReceived(worker_id)); if (state_ == OpState::kDeOpIdle) { *p_buffer = std::move(buf); @@ -372,6 +388,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { op->tree_->Print(ss, op); std::string ss_str = ss.str(); + // Filter out the Num workers field when generating the check sum + ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), ""); + + // Filter out Number of rows when generating the check sum + ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), ""); + // Filter out the Operator control flags field when generating the check sum ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), ""); @@ -384,8 +407,15 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), ""); ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), ""); + MS_LOG(DEBUG) << "Printing the tree for generating crc:\n" << ss_str; + uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); return cache_crc; } + +void DatasetOp::UpdateRepeatAndEpochCounter() { + op_current_repeats_++; + if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index b4630c1652..3c83582c9f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ -#define DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ #include #include @@ -28,6 +28,31 @@ namespace mindspore { namespace dataset { +constexpr char kBarrierOp[] = "BarrierOp"; +constexpr char kBatchOp[] = "BatchOp"; +constexpr char kBucketBatchByLengthOp[] = "BucketBatchByLengthOp"; +constexpr char kBuildSentencePieceVocabOp[] = "BuildSentencePieceVocabOp"; +constexpr char kBuildVocabOp[] = "BuildVocabOp"; +constexpr char kCacheBase[] = "CacheBase"; +constexpr char kCacheLookupOp[] = "CacheLookupOp"; +constexpr char kCacheMergeOp[] = "CacheMergeOp"; +constexpr char kCacheOp[] = "CacheOp"; +constexpr char kConcatOp[] = "ConcatOp"; +constexpr char kDatasetOp[] = "DatasetOp"; +constexpr char kDeviceQueueOp[] = "DeviceQueueOp"; +constexpr char kEpochCtrlOp[] = "EpochCtrlOp"; +constexpr char kFilterOp[] = "FilterOp"; +constexpr char kMapOp[] = "MapOp"; +constexpr char kParallelOp[] = "ParallelOp"; +constexpr char kPipelineOp[] = "PipelineOp"; +constexpr char kProjectOp[] = "ProjectOp"; +constexpr char kRenameOp[] = "RenameOp"; +constexpr char kRepeatOp[] = "RepeatOp"; +constexpr char kShuffleOp[] = "ShuffleOp"; +constexpr char kSkipOp[] = "SkipOp"; +constexpr char kTakeOp[] = "TakeOp"; +constexpr char kZipOp[] = "ZipOp"; + // Forward declare class ExecutionTree; @@ -45,13 +70,7 @@ class DatasetOp : public std::enable_shared_from_this { public: static constexpr int32_t kInvalidOperatorId = -1; - - // Operator control flags - enum OpControlFlags { - kDeOpNone = 0, - kDeOpRepeated = 1, // Operator is a node in a repeat path - kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop - }; + static constexpr int32_t kInfiniteRepeat = -1; // Flags that control operator runtime behaviours enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; @@ -76,6 +95,9 @@ class DatasetOp : public std::enable_shared_from_this { /// \return Status eerror code returned Status Remove(); + // Removes child operator in this operator. + Status RemoveChildren(); + /// \brief Getter function to get a shared pointer to our child /// \param[in] child_index An operator can have n children. Indicates which child to return. /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index @@ -86,6 +108,12 @@ class DatasetOp : public std::enable_shared_from_this { /// \param[in] parent_index An operator can have n parents. Indicates which parent to return. void Parent(DatasetOp **parent, int32_t parent_index) const; + // Getter function to get all of our children. + std::vector> children() const; + + // Getter function to get all of our parents. + std::vector parents() const; + // Inserts a operator as the parent current op. // Inserted op will become the sole parent of the current op. // The existing parent of the current op will be transferred to the inserted op. @@ -204,13 +232,23 @@ class DatasetOp : public std::enable_shared_from_this { /// \return T/F if this is an inlined operator bool inlined() const { return (oc_queue_size_ == 0); } - /// \brief Setter function - /// \return Sets the control flags - void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } + /// \brief Setter function, set the number of total repeats for the operator + void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; } + + /// \brief Setter function, set the number of repeats per epoch for the operator + void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; } - /// \brief Setter function - /// \return Sets the control flags - void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } + /// \brief Getter function + /// \return The number of required repeats for the operator + int32_t op_total_repeats() { return op_total_repeats_; } + + /// \brief Getter function + /// \return The number of required epochs for the operator + int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; } + + /// \brief Getter function + /// \return The number of repeats per epoch for the operator + int32_t op_num_repeats_per_epoch() { return op_num_repeats_per_epoch_; } /// \brief Register the internal worker connectors. No op unless it is a parallel op /// \return Status @@ -283,7 +321,7 @@ class DatasetOp : public std::enable_shared_from_this { /// Op name getter /// \return Name of the current Op - virtual std::string Name() const { return "DatasetOp"; } + virtual std::string Name() const = 0; /// Execution Tree getter /// \return Pointer to the ExecutionTree the current op belongs to, no ownership @@ -316,6 +354,10 @@ class DatasetOp : public std::enable_shared_from_this { /// \return boolean returns true if it's a leaf bool IsLeaf() { return (child_.empty()); } + /// Checks if an operator has reached its last iteration + /// \return boolean returns true if it's last iteration + bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; } + protected: /// \brief Removes a parent operator from this operator /// \notes External callers do not have access to this function @@ -334,6 +376,10 @@ class DatasetOp : public std::enable_shared_from_this { /// \return - Status virtual Status ComputeColMap(); + /// Increase op_current_repeats_ by 1 when one repeat finished. + /// If this repeat happen to be the last repeat in the current epoch, also increase op_current_epochs_ by 1. + void UpdateRepeatAndEpochCounter(); + std::vector> child_; // Child nodes std::vector parent_; // Parent nodes. No ownership std::shared_ptr sampler_; // Some leaf ops might have a sampler @@ -341,7 +387,10 @@ class DatasetOp : public std::enable_shared_from_this { int32_t operator_id_; // Generated id for the node ExecutionTree *tree_; // Back pointer to our tree. OpState state_; // The state of the operator, Running, Idle, Terminated - uint32_t op_ctrl_flags_; // Flags for the operator + int32_t op_total_repeats_; // Required number of repeats for the operator + int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator + int32_t op_current_repeats_; // Current number of repeats the operator has handled + int32_t op_current_epochs_; // Current number of epochs the operator has handled std::unique_ptr out_connector_; // Output Connector std::unordered_map column_name_id_map_; // Mapping between col index and col name std::mutex column_name_map_mutex_; // For protecting shared access to the column map @@ -360,4 +409,4 @@ class DatasetOp : public std::enable_shared_from_this { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 4fe779246b..72f47c0fbf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -25,19 +25,21 @@ #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/device_queue_tracing.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" #include "minddata/dataset/util/status.h" #include "minddata/dataset/util/task_manager.h" namespace mindspore { namespace dataset { DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - int32_t op_connector_size, int64_t num_batch) + int32_t op_connector_size, bool send_epoch_end) : PipelineOp(op_connector_size), channel_name_(channel_name), device_type_(device_type), device_id_(device_id), prefetch_size_(prefetch_size), - num_batch_(num_batch) {} + send_epoch_end_(send_epoch_end), + stop_send_(false) {} DeviceQueueOp::~DeviceQueueOp() {} @@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size) : builder_prefetch_size_(prefetch_size), builder_device_id_(0), builder_device_type_(DeviceType::CPU), - builder_channel_name_(""), - builder_num_batch_(0) { + builder_channel_name_("") { std::shared_ptr cfg = GlobalContext::config_manager(); builder_op_connector_size_ = cfg->op_connector_size(); } @@ -64,6 +65,19 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) { return Status::OK(); } +Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) const { + // this method checks if the buffer meets the conditions to be sent to TDT + if (buffer->NumRows() != 0) { + TensorRow row; + buffer->GetRow(0, &row); + for (const auto &item : row) { + CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); + CHECK_FAIL_RETURN_UNEXPECTED(item->HasData(), "Cannot send tensor with no data."); + } + } + return Status::OK(); +} + Status DeviceQueueOp::operator()() { TaskManager::FindMe()->Post(); @@ -82,23 +96,10 @@ Status DeviceQueueOp::operator()() { return Status::OK(); } -Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) const { - // this method checks if the buffer meets the conditions to be sent to TDT - if (buffer->NumRows() != 0) { - TensorRow row; - buffer->GetRow(0, &row); - for (const auto &item : row) { - CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); - } - } - return Status::OK(); -} - #ifdef ENABLE_TDTQUE Status DeviceQueueOp::SendDataToAscend() { MS_LOG(INFO) << "Device queue, sending data to Ascend."; int64_t total_batch = 0; - bool is_break_loop = false; double batch_start_time, end_time; int32_t batch_cost, tdt_cost; int32_t connector_size = 0; @@ -115,15 +116,20 @@ Status DeviceQueueOp::SendDataToAscend() { std::unique_ptr current_buffer; RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - while (!current_buffer->eof() && !is_break_loop) { - while (!current_buffer->eoe() && !is_break_loop) { + while (!current_buffer->eof()) { + while (!current_buffer->eoe()) { RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); TensorRow currRow; - for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { + for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) { RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); if (status == TdtStatus::FAILED) { - return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); + if (stop_send_) { + MS_LOG(INFO) << "stop_send received"; + return Status::OK(); + } else { + return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); + } } if (isProfilingEnable) { @@ -140,9 +146,6 @@ Status DeviceQueueOp::SendDataToAscend() { profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); } total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - is_break_loop = true; - } } if (isProfilingEnable) { connector_size = ChildOpConnectorSize(); @@ -150,6 +153,19 @@ Status DeviceQueueOp::SendDataToAscend() { } RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); } + if (current_buffer->eoe() && send_epoch_end_) { + TensorRow currRow; + auto status = + tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE); + if (status == TdtStatus::FAILED) { + if (stop_send_) { + MS_LOG(INFO) << "stop_send received"; + return Status::OK(); + } else { + return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); + } + } + } if (isProfilingEnable) { connector_size = ChildOpConnectorSize(); connector_capacity = ChildOpConnectorCapacity(); @@ -158,7 +174,7 @@ Status DeviceQueueOp::SendDataToAscend() { } tree_->SetFinished(); - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + MS_LOG(INFO) << "Device queue total batch is " << total_batch; return Status::OK(); } @@ -196,27 +212,22 @@ Status DeviceQueueOp::SendDataToGPU() { } RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - is_break_loop = true; - } } - if (!TaskManager::FindMe()->Interrupted()) + if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed()) RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); else is_break_loop = true; } - if (!TaskManager::FindMe()->Interrupted()) + if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed()) RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); else is_break_loop = true; } - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + MS_LOG(INFO) << "Device queue total batch is " << total_batch << "."; GpuBufferMgr::GetInstance().Close(handle); - GpuBufferMgr::GetInstance().CloseConfirm(); - return Status::OK(); } @@ -240,8 +251,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, con if (ret == BlockQueueStatus_T::ERROR_INPUT) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); } else { - MS_LOG(WARNING) << "Retry pushing data..."; - continue; + if (!stop_send_) { + MS_LOG(WARNING) << "Retry pushing data..."; + continue; + } + break; } } else { break; @@ -283,20 +297,16 @@ Status DeviceQueueOp::SendDataToCPU() { MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - break; - } + if (stop_send_) break; } } - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + MS_LOG(INFO) << "Device queue total batch is " << total_batch << "."; return Status::OK(); } void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h index 0fb4fb093d..224d36b85f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ #include #include #include #include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/util/status.h" #ifdef ENABLE_TDTQUE @@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp { return *this; } - Builder &SetNumBatch(int64_t num_batch) { - builder_num_batch_ = num_batch; + Builder &SetSendEpochEnd(bool send_epoch_end) { + builder_send_epoch_end_ = send_epoch_end; return *this; } @@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp { // to call this Build() method. It will instantiate the DeviceQueueOp // and return it to caller as a shared pointer. Status Build(std::shared_ptr *ptr) { - *ptr = std::make_shared(builder_channel_name_, builder_device_type_, builder_device_id_, - builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); + *ptr = + std::make_shared(builder_channel_name_, builder_device_type_, builder_device_id_, + builder_prefetch_size_, builder_op_connector_size_, builder_send_epoch_end_); return Status::OK(); } @@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp { int32_t builder_device_id_; DeviceType builder_device_type_; std::string builder_channel_name_; - int64_t builder_num_batch_; int32_t builder_op_connector_size_; + bool builder_send_epoch_end_; }; // Name: constructor // Description DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - int32_t op_connector_size, int64_t num_batch); + int32_t op_connector_size, bool send_epoch_end); // Name: destructor // Description @@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp { const int32_t get_prefetch_size() { return prefetch_size_; } + void StopSend() { stop_send_ = true; } + // Name: Print() // Description: A function that prints info about the node void Print(std::ostream &out, // In: The output stream to print to @@ -142,13 +146,14 @@ class DeviceQueueOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "DeviceQueueOp"; } + std::string Name() const override { return kDeviceQueueOp; } private: // Name: checkExceptions(DataBuffer); // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp Status CheckExceptions(const std::unique_ptr &buffer) const; + private: #ifdef ENABLE_TDTQUE Status SendDataToAscend(); #endif @@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp { DeviceType device_type_; const int32_t device_id_; const int32_t prefetch_size_; - const int64_t num_batch_; + const bool send_epoch_end_; + bool stop_send_; #ifdef ENABLE_TDTQUE std::shared_ptr tdtInstancePtr; @@ -172,4 +178,4 @@ class DeviceQueueOp : public PipelineOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc new file mode 100644 index 0000000000..1343fd4608 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { + +// The builder "build" method creates the final object. +Status EpochCtrlOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_repeats_); + return Status::OK(); +} + +// Constructor +EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) : RepeatOp(num_epoch) { MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; } + +// Destructor +EpochCtrlOp::~EpochCtrlOp() {} + +// A print method typically used for debugging +void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [epochs: " << num_repeats_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ + << "\nLeaf Nodes in execution path:"; + if (!eoe_ops_.empty()) { + for (size_t i = 0; i < eoe_ops_.size(); i++) { + out << "\n Operator: " << eoe_ops_[i]->id(); + } + } else { + out << " None."; + } + out << "\n\n"; + } +} + +Status EpochCtrlOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { + if (child_.empty()) { + RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node."); + } + + std::unique_ptr buf; + + // `retry_if_eoe` is false because EpochCtrlOp does not eat EOE. + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, false)); + + // Only intercept EOE for EoeReceived processing, after that the EOE is forwarded to next op. + // Other databuffers containing data or EOF will simply be forwarded. + // EOF can simply be forwarded because this op does not spawn any thread, thus does not require clean up. + if (buf->eoe()) { + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + } + + *p_buffer = std::move(buf); + return Status::OK(); +} + +Status EpochCtrlOp::EoeReceived(int32_t worker_id) { + UpdateRepeatAndEpochCounter(); + repeat_count_++; + MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_ + << ". Max epochs: " << num_repeats_; + + // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. + state_ = OpState::kDeOpIdle; + + if (repeat_count_ != num_repeats_) { + for (auto &eoe_op : eoe_ops_) { + MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + } + + return Status::OK(); +} + +// Pre-Visitor accept method for NodePass +Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +// Visitor accept method for NodePass +Status EpochCtrlOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h new file mode 100644 index 0000000000..c494208116 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_ +#define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_ + +#include +#include +#include +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class EpochCtrlOp : public RepeatOp { + public: + class Builder : public RepeatOp::Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @param count - The number of repeats to do + // @return This is a constructor. + explicit Builder(int32_t count) : RepeatOp::Builder(count) {} + + // Default destructor + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new EpochCtrlOp object + Status Build(std::shared_ptr *); + }; + + // Contructor + explicit EpochCtrlOp(int32_t num_epoch); + + // Destructor + ~EpochCtrlOp(); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + std::string Name() const override { return kEpochCtrlOp; } + + // This function returns the buffer that is at the top of our output connector. The caller is + // typically our parent node, when the parent is asking us to provide the next buffer of data. + // Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us + // will simply bounce you to get a buffer from our child. + // Epoch Control Op does not eat the EOE, it will pass the EOE to the next op. + Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; + + // Base-class override for handling cases when an eoe is received. + // @param worker_id - The worker id + Status EoeReceived(int32_t worker_id) override; + + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc index f32648a3df..10dd4f71d7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -90,8 +90,6 @@ Status FilterOp::ValidateInColumns(const std::vector *input_columns // A print method typically used for debugging. void FilterOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -119,6 +117,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); if (in_buffer->eoe()) { filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + UpdateRepeatAndEpochCounter(); continue; } else if (in_buffer->eof()) { filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h index fcc6e577df..8cc0cd55ff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ #include #include @@ -129,7 +129,7 @@ class FilterOp : public ParallelOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "FilterOp"; } + std::string Name() const override { return kFilterOp; } private: // predicate_func python callable which returns a boolean value. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc deleted file mode 100644 index e5e70dbbdf..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc +++ /dev/null @@ -1,373 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/datasetops/map_op.h" -#include -#include -#include -#include -#include -#include "minddata/dataset/core/config_manager.h" - -#include "minddata/dataset/core/constants.h" -#include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/data_buffer.h" -#include "minddata/dataset/engine/db_connector.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "minddata/dataset/engine/opt/pass.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "utils/log_adapter.h" -#include "minddata/dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -MapOp::Builder::Builder() : build_perf_mode_(true) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); -} - -// Check if the required parameters are set by the builder. -Status MapOp::Builder::sanityCheck() const { - if (build_tensor_funcs_.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Building a MapOp that has not provided any function/operation to apply"); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status MapOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(sanityCheck()); - *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), - std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, - build_perf_mode_); - return Status::OK(); -} - -// Constructor of MapOp -MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode) - : ParallelOp(num_workers, op_connector_size), - tfuncs_(std::move(tensor_funcs)), - in_columns_(in_col_names), - out_columns_(out_col_names), - perf_mode_(perf_mode) { - // If caller didn't specify the out_col_names, assume they are same as the in_columns. - if (out_columns_.empty() || out_columns_[0].empty()) { - out_columns_ = in_columns_; - } - MS_LOG(DEBUG) << "Performance Mode in map operator is " << perf_mode_ << "."; -} - -// The number of threads consuming data from previous op's output Connector. -int32_t MapOp::num_consumers() const { - // When Performance Mode is on, there is only one thread consuming from the previous Connector. - return perf_mode_ == true ? 1 : num_workers_; -} - -// A print method typically used for debugging -void MapOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nInput column names:"; - for (size_t i = 0; i < in_columns_.size(); i++) { - out << " " << in_columns_[i]; - } - out << "\n TensorOps:"; - for (size_t i = 0; i < tfuncs_.size(); i++) { - out << " " << *(tfuncs_[i].get()); - } - out << "\n\n"; - } -} - -// This class functor will provide the master loop that drives the logic for performing the work -Status MapOp::operator()() { - if (perf_mode_) { - // Create and register the local queues. - local_queues_.Init(num_workers_, oc_queue_size_); - Status rc = local_queues_.Register(tree_->AllTasks()); - if (rc.IsError()) { - TaskManager::FindMe()->Post(); - return rc; - } - } - - // The operator class just starts off threads by calling the tree_ function - Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); - // Synchronize with TaskManager - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - - if (perf_mode_) { - int64_t que_id = 0; - std::unique_ptr buff; - bool is_eof = false; - // Draining output connector of the previous op and distribute it to local queues. - // Stop when all worker threads are finished (received EOF). - while (!is_eof) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); - is_eof = buff->eof(); - RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); - que_id = (que_id + 1) % num_workers_; - } - } - - return Status::OK(); -} - -// Private function for worker/thread to loop continuously. It comprises the main -// logic of MapOp: getting the data from previous Op, validating user specified column names, -// applying a list of TensorOps to each of the data, process the results and then -// pushing them back to MapOp's output Connector to be fetched by the next Op. -Status MapOp::WorkerEntry(int32_t worker_id) { - // Handshake with TaskManager that thread creation is successful. - TaskManager::FindMe()->Post(); - std::unique_ptr in_buffer; - - // Getting a databuffer to work on. - // Perform the first fetch here outside of the loop. This allows us to execute one-time only - // initializations that happen after the first fetch. - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - - // Sanity check the databuffer. - // Special case: if there's more threads than buffers, some threads simply get the final control - // messages (eoe/eof), and so they will not perform the check. - if (!in_buffer->eoe() && !in_buffer->eof()) { - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - if (num_rows == 0 || num_cols == 0) { - RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); - } - } - - // Now that init work is done, drop into the main fetching loop. - // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself - // rather than use the base-class defaults. - while (true) { - // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work - // with Performance Mode design. - if (in_buffer->eoe()) { - // Calling base class EoeReceived to forward eoe buffer. - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - continue; - } else if (in_buffer->eof()) { - // Calling base class EofReceived to forward eof buffer. - RETURN_IF_NOT_OK(EofReceived(worker_id)); - break; - } - - std::unique_ptr new_tensor_table(std::make_unique()); - // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. - RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get())); - - // Replace the TensorTable in DataBuffer with the new one. - in_buffer->set_tensor_table(std::move(new_tensor_table)); - - // Push the buffer onto the connector for next operator to consume. - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(worker_id), std::move(in_buffer))); - - // Fetch the next buffer and loop back to the top. - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - } - - return Status::OK(); -} - -Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table) { - // Getting number of rows and cols in this buffer. - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - - for (int32_t r = 0; r < num_rows; r++) { - // to_process : A vector of Tensors only holding cols in input_columns. - // result_row; : A vector of Tensors to hold the result after Compute(). - // cur_row : A vector of Tensors holding all the columns from DataBuffer. - TensorRow to_process, result_row, cur_row; - RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); - - // Populate the Tensor from the current row to be processed by TensorOp - for (const auto &idx : to_process_indices_) { - to_process.push_back(std::move(cur_row[idx])); - } - - // Looping over multiple TensorOps supplied in to MapOp. - // The assumption is that the result of one TensorOp matches the required input to the next TensorOp. - for (size_t i = 0; i < tfuncs_.size(); i++) { - // TensorOp can operate on single col or multiple cols. MapOp always call compute for multiple cols. - // TensorOp base class will call the single column Compute() depending on the ops. - // Note: The columns of the result_row is not preallocated, the compute function of each tensor op are - // required to resize/push back the result_row - RETURN_IF_NOT_OK(tfuncs_[i]->Compute(to_process, &result_row)); - - // Assign result_row to to_process for the next TensorOp processing, except for the last TensorOp in the list. - if (i + 1 < tfuncs_.size()) { - to_process = std::move(result_row); - } - } - - if (out_columns_.size() != result_row.size()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Result of a tensorOp doesn't match output column names"); - } - - if (in_columns_.size() == out_columns_.size()) { - for (size_t i = 0; i < result_row.size(); i++) { - cur_row[to_process_indices_[i]] = std::move(result_row[i]); - } - new_tensor_table->push_back(std::move(cur_row)); - } else { - // Add the columns we did not touch to the result_row. - for (int32_t i = 0; i < num_cols; i++) { - if (keep_input_columns_[i]) { - result_row.push_back(std::move(cur_row[i])); - } - } - - // Add this final result_row to our new TensorTable. - new_tensor_table->push_back(std::move(result_row)); - } - } - - return Status::OK(); -} - -Status MapOp::ComputeColMap() { - // If the map has not been set up yet in the base class, then set it up - if (column_name_id_map_.empty()) { - std::unordered_map current_name_id_map = child_[0]->column_name_id_map(); - // Initialize private variables - RETURN_IF_NOT_OK(InitPrivateVariable(¤t_name_id_map)); - // Create the final column name to index mapping in the base class field - CreateFinalColMap(¤t_name_id_map); - MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// Validating if each of the input_columns exists in the DataBuffer. -Status MapOp::ValidateInColumns(const std::unordered_map &col_name_id_map) { - for (const auto &inCol : in_columns_) { - bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; - if (!found) { - std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); -} - -Status MapOp::InitPrivateVariable(std::unordered_map *col_name_id_map) { - // If input_columns is empty(), The col at index-0 will be picked. - if (in_columns_.empty()) { - for (const auto &pair : *col_name_id_map) { - if (pair.second == 0) { - MS_LOG(INFO) << "Input columns empty for map op, will apply to the first column in the current table."; - in_columns_.push_back(pair.first); - break; - } - } - - // If caller didn't specify the out_col_names, assume they are same as the input_columns. - // This was done in the constructor, but if input columns was empty to start we have to redo it here. - if (out_columns_.empty() || out_columns_[0].empty()) { - out_columns_ = in_columns_; - } - } - - // Before we continue, issue a sanity check to make sure the input columns from user and the incoming - // columns from child are correct - RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); - - // initialize keep_input_columns, true means to keep the column. - keep_input_columns_.resize(col_name_id_map->size(), true); - for (const auto &col_name : in_columns_) { - int32_t missed = (*col_name_id_map)[col_name]; - keep_input_columns_[missed] = false; - } - - // initialize to_process_indices. - for (const auto &col_name : in_columns_) { - to_process_indices_.push_back((*col_name_id_map)[col_name]); - } - return Status::OK(); -} - -// Create the final column name to index mapping and get indices of the columns this mapop does not use. -void MapOp::CreateFinalColMap(std::unordered_map *col_name_id_map) { - std::unordered_map final_col_name_id_map; - size_t num_cols = col_name_id_map->size(); - std::vector new_ids(num_cols); - if (in_columns_.size() == out_columns_.size()) { - for (size_t i = 0; i < in_columns_.size(); i++) { - int32_t loc = (*col_name_id_map)[in_columns_[i]]; - (void)col_name_id_map->erase(in_columns_[i]); - (*col_name_id_map)[out_columns_[i]] = loc; - } - - // Set the base class final column id map result - column_name_id_map_ = *col_name_id_map; - } else { - int32_t fill_idx = 0; - // First columns of the tables are occupied by the output columns from tensorOp. - for (const auto &col_name : out_columns_) { - final_col_name_id_map[col_name] = fill_idx++; - } - - // Creating new_ids mapping for the columns we keep. - for (size_t i = 0; i < num_cols; i++) { - if (keep_input_columns_[i]) { - new_ids[i] = fill_idx++; - } - } - - // Iterating through the old mapping to update the final mapping for the columns we kept. - std::string name; - for (const auto &pair : *col_name_id_map) { - name = pair.first; - int32_t old_id = pair.second; - if (keep_input_columns_[old_id]) { - final_col_name_id_map[name] = new_ids[old_id]; - } - } - - // Set the base class final column id map result - column_name_id_map_ = final_col_name_id_map; - } -} - -// Visitor accept method for NodePass -Status MapOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h deleted file mode 100644 index b1cd58010f..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_MAP_OP_H_ - -#include -#include -#include -#include -#include -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/queue.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class DataBuffer; -class ExecutionTree; - -// MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. -// The column order behavior after MapOp is as follows. -// [Case 1] If the number of Input Columns == the number of Output Column, column ordering after MapOp -// is the same as the original column order where the Remainder Columns stay in the same position, -// and the Output Columns are placed the same position of the Input Columns. -// For example, initially if the dataset has column order |A, B, C, D, E|, -// and we apply MapOp() with Input Columns {B, C} and Output Columns {X, Y}. -// The column order after applying MapOp will be |A, X, Y, D, E|. -// Note that in this case, |X, Y| is the Output Columns and |A, D, E| which is the Remainder Columns stay in -// their original position, and column B is replaced by column X and column C is replace by column Y. -// [Case 2] If the number of Input Columns != the number of Output Column, column ordering after MapOp -// is Output Columns followed by Remainder Columns. -// For example, initially if the dataset has column order |A, B, C, D, E|, -// and we apply MapOp() with Input Columns {B, C, A} and Output Columns {X, Y}. -// The column order after applying MapOp will be |X, Y, D, E|. -// Note that in this case, |X, Y| is the Output Columns and |D, E| is the Remainder Columns, -// and the Input Columns are gone and replaced by the Output Columns. - -// Keywords: -// Input Columns : a vector of column names (string) passed to MapOp specifying the column names from which -// Tensors are taken and passed to the TensorOp Compute(). -// Output Columns : a vector of column names (string) passed to MapOp specifying what are the column names -// for the Tensors produced by TensorOp Compute(). -// Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns. -// These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns. -class MapOp : public ParallelOp { - public: - // The nested builder class inside of the MapOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetInColNames(const std::vector &in_col_names) { - build_in_col_names_ = in_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOutColNames(const std::vector &out_col_names) { - build_out_col_names_ = out_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetTensorFuncs(std::vector> funcs) { - build_tensor_funcs_ = std::move(funcs); - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetPerformanceMode(bool perf_mode) { - build_perf_mode_ = perf_mode; - return *this; - } - - // The builder "build" method creates the final object. - // @param ptr The shared_ptr to the new MapOp object - // @return Status - Status Build(std::shared_ptr *ptr); - - private: - std::vector build_in_col_names_; - std::vector build_out_col_names_; - std::vector> build_tensor_funcs_; - int32_t build_num_workers_; - int32_t build_op_connector_size_; - bool build_perf_mode_; // Default true. - - // Check if the required parameters are set by the builder. - // @return Status The error code return - Status sanityCheck() const; - }; - - // Constructor of MapOp - // @note The builder class should be used to call it. - // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs). - // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs). - // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data. - // @param num_workers The number of worker threads. - // @param op_connector_size The size of each queue in the connector. - MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode); - - // Destructor - ~MapOp() = default; - - // A print method typically used for debugging - // @param out The output stream to write output to - // @param show_all A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out reference to the output stream being overloaded - // @param mo reference to the MapOp to display - // @return the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const MapOp &mo) { - mo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status The error code return - Status operator()() override; - - // Getter - // @return the number of threads consuming data from previous op's output Connector. - int32_t num_consumers() const override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MapOp"; } - - // List of tensor ops getter/setter - // @Return the vector of tensor ops by non-const reference - - auto &TFuncs() { return tfuncs_; } - - const auto &TFuncs() const { return tfuncs_; } - - private: - // Local queues where worker threads can pop from. - // Popping directly from the Connector can block if the previous designated threads haven't pop. - // Setting the size of these queues to 0 is essentially the same as pulling directly from Connector. - QueueList> local_queues_; - - // Static variables to be ready by worker threads, no modification and readonly - std::vector> tfuncs_; - - // Variable to store the column name that the tensorOps are consuming - std::vector in_columns_; - - // Variable to store the column name that the tensorOps are producing - std::vector out_columns_; - - // Boolean mapping, true means to keep the column. - std::vector keep_input_columns_; - - // Indices of the columns to process. - std::vector to_process_indices_; - - // Performance mode is when the main thread creates local queues, pulls databuffers from the previous - // op's Connector and distributes them to the local queues. Workers pull from the local queues. - // If this flag is false, each worker pulls directly from the Connector. This use less resources - // (thread and memory), but when the computation cost is heavy (e.g. DecodeOp) and fluctuating, it can - // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order. - bool perf_mode_; - - // Private function for worker/thread to loop continuously. It comprises the main - // logic of MapOp: getting the data from previous Op, validating user specified column names, - // applying a list of TensorOps to each of the data, process the results and then - // pushing them back to MapOp's output Connector to be fetched by the next Op. - // @param worker_id The id assigned to this thread/worker upon creation. - // @return Status The error code return - Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ - - // Private helper function for getting the next buffer - // When PerformanceMode is enabled, workers pop from the local queue. - // Otherwise, workers pop from the first child output Connector. - // @param p_buffer - the buffer to return - // @return Status return code - Status FetchNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id) { - if (perf_mode_) { - RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(p_buffer)); - } else { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id)); - } - return Status::OK(); - } - - // Private function for worker thread to perform TensorOp's compute function and get the result. - // @param in_buffer A raw pointer to the DataBuffer. A raw pointer is fine because this function doesn't manage memory - // and is not shared with other threads. - // @param[out] new_tensor_table A new Tensor Table to be populated in this function. - Status WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table); - - // Private function that create the final column name to index mapping and - // get indices of the columns this mapop does not use. - // @param col_name_id_map The column name to index mapping obtained from child operator - void CreateFinalColMap(std::unordered_map *col_name_id_map); - - // Validating if each of the input_columns exists in the DataBuffer. - // @param - the column map to check - // @return - status return code - Status ValidateInColumns(const std::unordered_map &col_name_id_map); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - // Private function for initializing private variables such as in_columns_, out_columns_. - // @return - Status - Status InitPrivateVariable(std::unordered_map *col_name_id_map); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/CMakeLists.txt new file mode 100644 index 0000000000..a877016c9c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/CMakeLists.txt @@ -0,0 +1,10 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +set(DATASET_ENGINE_DATASETOPS_MAPOP_SRC_FILES + map_op.cc + cpu_map_job.cc + gpu_map_job.cc + ) + +add_library(engine-datasetops-mapop OBJECT ${DATASET_ENGINE_DATASETOPS_MAPOP_SRC_FILES}) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/cpu_map_job.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/cpu_map_job.cc new file mode 100644 index 0000000000..8b6b753aae --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/cpu_map_job.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" + +namespace mindspore { +namespace dataset { + +// Constructor +CpuMapJob::CpuMapJob() = default; + +// Constructor +CpuMapJob::CpuMapJob(std::vector> operations) : MapJob(operations) {} + +// Destructor +CpuMapJob::~CpuMapJob() = default; + +// A function to execute a cpu map job +Status CpuMapJob::Run(std::vector in, std::vector *out) { + int32_t num_rows = in.size(); + for (int32_t row = 0; row < num_rows; row++) { + TensorRow input_row = in[row]; + TensorRow result_row; + for (size_t i = 0; i < ops_.size(); i++) { + // Call compute function for cpu + RETURN_IF_NOT_OK(ops_[i]->Compute(input_row, &result_row)); + + // Assign result_row to to_process for the next TensorOp processing, except for the last TensorOp in the list. + if (i + 1 < ops_.size()) { + input_row = std::move(result_row); + } + } + out->push_back(std::move(result_row)); + } + + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/cpu_map_job.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/cpu_map_job.h new file mode 100644 index 0000000000..330b676865 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/cpu_map_job.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_CPU_MAP_JOB_H_ +#define DATASET_ENGINE_DATASETOPS_MAP_OP_CPU_MAP_JOB_H_ + +#include +#include +#include "minddata/dataset/engine/datasetops/map_op/map_job.h" + +namespace mindspore { +namespace dataset { +class CpuMapJob : public MapJob { + public: + // Constructor + CpuMapJob(); + + // Constructor + explicit CpuMapJob(std::vector> operations); + + // Destructor + ~CpuMapJob(); + + // A pure virtual run function to execute a cpu map job + Status Run(std::vector in, std::vector *out) override; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_CPU_MAP_JOB_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/gpu_map_job.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/gpu_map_job.cc new file mode 100644 index 0000000000..64502d6da2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/gpu_map_job.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" + +namespace mindspore { +namespace dataset { + +// Constructor +GpuMapJob::GpuMapJob(std::vector> operations) : MapJob(operations) {} + +// Destructor +GpuMapJob::~GpuMapJob() = default; + +// A function to execute a cpu map job +Status GpuMapJob::Run(std::vector in, std::vector *out) { + // Do nothing for now + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/gpu_map_job.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/gpu_map_job.h new file mode 100644 index 0000000000..743c5104c9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/gpu_map_job.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_GPU_MAP_JOB_H_ +#define DATASET_ENGINE_DATASETOPS_MAP_OP_GPU_MAP_JOB_H_ + +#include +#include +#include "minddata/dataset/engine/datasetops/map_op/map_job.h" + +namespace mindspore { +namespace dataset { +class GpuMapJob : public MapJob { + public: + // Constructor + explicit GpuMapJob(std::vector> operations); + + // Destructor + ~GpuMapJob(); + + // A pure virtual run function to execute a cpu map job + Status Run(std::vector in, std::vector *out) override; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_GPU_MAP_JOB_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_job.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_job.h new file mode 100644 index 0000000000..fd05dfd53f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_job.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_MAP_JOB_H_ +#define DATASET_ENGINE_DATASETOPS_MAP_OP_MAP_JOB_H_ + +#include +#include + +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class MapJob { + public: + // Constructor + explicit MapJob(std::vector> operations) : ops_(operations) {} + + // Constructor + MapJob() = default; + + // Destructor + ~MapJob() = default; + + Status AddOperation(std::shared_ptr operation) { + ops_.push_back(operation); + return Status::OK(); + } + + // A pure virtual run function to execute a particular map job + virtual Status Run(std::vector in, std::vector *out) = 0; + + protected: + std::vector> ops_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_MAP_JOB_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc new file mode 100644 index 0000000000..dff595e11e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -0,0 +1,432 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" +#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" +#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +MapOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + build_op_connector_size_ = cfg->op_connector_size(); +} + +// Check if the required parameters are set by the builder. +Status MapOp::Builder::sanityCheck() const { + if (build_tensor_funcs_.empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Building a MapOp that has not provided any function/operation to apply"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object. +Status MapOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(sanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), + std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_); + return Status::OK(); +} + +// Constructor of MapOp +MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size) + : ParallelOp(num_workers, op_connector_size), + tfuncs_(std::move(tensor_funcs)), + in_columns_(in_col_names), + out_columns_(out_col_names) { + // If caller didn't specify the out_col_names, assume they are same as the in_columns. + if (out_columns_.empty() || out_columns_[0].empty()) { + out_columns_ = in_columns_; + } +} + +// The number of threads consuming data from previous op's output Connector. +int32_t MapOp::num_consumers() const { + // When Performance Mode is on, there is only one thread consuming from the previous Connector. + return 1; +} + +// A print method typically used for debugging +void MapOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nInput column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } + out << "\n TensorOps:"; + for (size_t i = 0; i < tfuncs_.size(); i++) { + out << " " << *(tfuncs_[i].get()); + } + out << "\n\n"; + } +} + +// A helper function that fetch worker map job from local queues and extract the data and map job list +Status MapOp::FetchNextWork(uint32_t worker_id, std::unique_ptr *db, + std::vector> *job_list) { + std::unique_ptr worker_job; + // Fetch the next worker job and data buffer + RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(&worker_job)); + // Extract the databuffer and job list from the map worker job. + *db = std::move(worker_job->databuffer); + *job_list = std::move(worker_job->jobs); + + return Status::OK(); +} + +Status MapOp::GenerateWorkerJob(const std::unique_ptr *worker_job) { + std::shared_ptr map_job = nullptr; + MapTargetDevice prev_target; + for (size_t i = 0; i < tfuncs_.size(); i++) { + // Currently we only have CPU as the device target + // In the future, we will have heuristic or control from user to select target device + // MapTargetDevice target_device; + // RETURN_IF_NOT_OK(SelectTarget(tfuncs_[i], &target_device)); + MapTargetDevice target_device = MapTargetDevice::kCpu; + + switch (target_device) { + case MapTargetDevice::kCpu: + // If there is no existing map_job, we will create one. + // map_job could be nullptr when we are at the first tensor op or when the target device of the prev op + // is different with that of the current op. + if (map_job == nullptr) { + map_job = std::make_shared(); + } + map_job->AddOperation(tfuncs_[i]); + break; + + case MapTargetDevice::kGpu: + break; + + case MapTargetDevice::kDvpp: + break; + + default: + break; + } + + // Push map_job into worker_job if one of the two conditions is true: + // 1) It is the last tensor operation in tfuncs_ + // 2) The the target device of the current tensor operation is different with previous one + if ((i + 1 == tfuncs_.size()) || ((i != 0) && (prev_target != target_device))) { + (*worker_job)->jobs.push_back(std::move(map_job)); + } + + prev_target = target_device; + } + + return Status::OK(); +} + +// This class functor will provide the master loop that drives the logic for performing the work +Status MapOp::operator()() { + // Create and register the local queues. + local_queues_.Init(num_workers_, oc_queue_size_); + Status rc = local_queues_.Register(tree_->AllTasks()); + if (rc.IsError()) { + TaskManager::FindMe()->Post(); + return rc; + } + + // The operator class just starts off threads by calling the tree_ function + rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); + // Synchronize with TaskManager + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + + int64_t que_id = 0; + std::unique_ptr buff; + bool is_eof = false; + // Drain output connector of the previous op, generate jobs for worker threads, and distribute them via local queues + // Stop when all worker threads are finished (received EOF) + while (!is_eof) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); + is_eof = buff->eof(); + + // Create an empty map worker job to be populated by a databuffer and map jobs + std::unique_ptr worker_job = std::make_unique(); + worker_job->databuffer = std::move(buff); + + // Populate map worker job for a worker to execute + RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job)); + + // Push map worker job to the corresponding worker's queue + RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(worker_job))); + que_id = (que_id + 1) % num_workers_; + } + + return Status::OK(); +} + +// Private function for worker/thread to loop continuously. It comprises the main +// logic of MapOp: getting the data from previous Op, validating user specified column names, +// applying a list of TensorOps to each of the data, process the results and then +// pushing them back to MapOp's output Connector to be fetched by the next Op. +Status MapOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + + std::unique_ptr in_buffer; + std::vector> job_list; + // Fetch next data buffer and map job list + RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); + + // Sanity check the databuffer. + // Special case: if there's more threads than buffers, some threads simply get the final control + // messages (eoe/eof), and so they will not perform the check. + if (!in_buffer->eoe() && !in_buffer->eof()) { + int32_t num_rows = in_buffer->NumRows(); + int32_t num_cols = in_buffer->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); + } + } + + // Now that init work is done, drop into the main fetching loop. + // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself + // rather than use the base-class defaults. + while (true) { + // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work + // with Performance Mode design. + if (in_buffer->eoe()) { + UpdateRepeatAndEpochCounter(); + // Calling base class EoeReceived to forward eoe buffer. + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + // Fetch next data buffer and map job list + RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); + continue; + } else if (in_buffer->eof()) { + // Calling base class EofReceived to forward eof buffer. + RETURN_IF_NOT_OK(EofReceived(worker_id)); + break; + } + + std::unique_ptr new_tensor_table(std::make_unique()); + // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list)); + // Replace the TensorTable in DataBuffer with the new one. + in_buffer->set_tensor_table(std::move(new_tensor_table)); + // Push the buffer onto the connector for next operator to consume. + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(worker_id), std::move(in_buffer))); + // Fetch next data buffer and map job list + RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); + } + return Status::OK(); +} + +Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table, + const std::vector> &job_list) { + int32_t num_rows = in_buffer->NumRows(); + int32_t num_cols = in_buffer->NumCols(); + + std::vector job_input_table; + std::vector original_table; + + // Prepare the data that we need from in_buffer + for (int32_t r = 0; r < num_rows; r++) { + // to_process : A vector of Tensors only holding cols in input_columns. + // cur_row : A vector of Tensors holding all the cols from DataBuffer. + TensorRow to_process, cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + // From the current row, select the Tensor that need to be passed to TensorOp + (void)std::transform(to_process_indices_.begin(), to_process_indices_.end(), std::back_inserter(to_process), + [&cur_row](const auto &it) { return std::move(cur_row[it]); }); + job_input_table.push_back(std::move(to_process)); + original_table.push_back(std::move(cur_row)); + } + + // Variable to keep the result after executing the job. + std::vector result_table; + // Executing the list of jobs + for (size_t i = 0; i < job_list.size(); i++) { + // Executre MapJob. + RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table)); + // Assign the pocessed data as an input for the next job processing, except for the last TensorOp in the list. + if (i + 1 < job_list.size()) { + job_input_table = std::move(result_table); + } + } + + // Sanity check a row in result_table + if (!result_table.empty() && out_columns_.size() != result_table[0].size()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Result of a tensorOp doesn't match output column names"); + } + + // Merging the data processed by job (result_table) with the data that are not used. + for (int32_t r = 0; r < num_rows; r++) { + TensorRow out_row; + if (in_columns_.size() == out_columns_.size()) { + // Place the processed tensor back into the original index of the input tensor + for (size_t i = 0; i < result_table[r].size(); i++) { + original_table[r][to_process_indices_[i]] = std::move(result_table[r][i]); + } + out_row = std::move(original_table[r]); + } else { + // Append the data in the original table that we did not use to the end of each row in result_table. + for (int32_t i = 0; i < num_cols; i++) { + if (keep_input_columns_[i]) { + result_table[r].push_back(std::move(original_table[r][i])); + } + } + out_row = std::move(result_table[r]); + } + // Add this final out_row to our new TensorTable. + new_tensor_table->push_back(std::move(out_row)); + } + return Status::OK(); +} + +Status MapOp::ComputeColMap() { + // If the map has not been set up yet in the base class, then set it up + if (column_name_id_map_.empty()) { + std::unordered_map current_name_id_map = child_[0]->column_name_id_map(); + // Initialize private variables + RETURN_IF_NOT_OK(InitPrivateVariable(¤t_name_id_map)); + // Create the final column name to index mapping in the base class field + CreateFinalColMap(¤t_name_id_map); + MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString(); + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +// Validating if each of the input_columns exists in the DataBuffer. +Status MapOp::ValidateInColumns(const std::unordered_map &col_name_id_map) { + for (const auto &inCol : in_columns_) { + bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +Status MapOp::InitPrivateVariable(std::unordered_map *col_name_id_map) { + // If input_columns is empty(), The col at index-0 will be picked. + if (in_columns_.empty()) { + auto itr = + std::find_if(col_name_id_map->begin(), col_name_id_map->end(), [](const auto &it) { return it.second == 0; }); + CHECK_FAIL_RETURN_UNEXPECTED(itr != col_name_id_map->end(), "Column name id map doesn't have id 0"); + MS_LOG(INFO) << "Input columns empty for map op, will apply to the first column in the current table."; + in_columns_.push_back(itr->first); + + // If caller didn't specify the out_col_names, assume they are same as the input_columns. + // This was done in the constructor, but if input columns was empty to start we have to redo it here. + if (out_columns_.empty() || out_columns_[0].empty()) { + out_columns_ = in_columns_; + } + } + + // Before we continue, issue a sanity check to make sure the input columns from user and the incoming + // columns from child are correct + RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); + + // initialize keep_input_columns, true means to keep the column. + keep_input_columns_.resize(col_name_id_map->size(), true); + for (const auto &col_name : in_columns_) { + int32_t missed = (*col_name_id_map)[col_name]; + keep_input_columns_[missed] = false; + } + + // initialize to_process_indices. + for (const auto &col_name : in_columns_) { + to_process_indices_.push_back((*col_name_id_map)[col_name]); + } + return Status::OK(); +} + +// Create the final column name to index mapping and get indices of the columns this mapop does not use. +void MapOp::CreateFinalColMap(std::unordered_map *col_name_id_map) { + std::unordered_map final_col_name_id_map; + size_t num_cols = col_name_id_map->size(); + std::vector new_ids(num_cols); + if (in_columns_.size() == out_columns_.size()) { + for (size_t i = 0; i < in_columns_.size(); i++) { + int32_t loc = (*col_name_id_map)[in_columns_[i]]; + (void)col_name_id_map->erase(in_columns_[i]); + (*col_name_id_map)[out_columns_[i]] = loc; + } + + // Set the base class final column id map result + column_name_id_map_ = *col_name_id_map; + } else { + int32_t fill_idx = 0; + // First columns of the tables are occupied by the output columns from tensorOp. + for (const auto &col_name : out_columns_) { + final_col_name_id_map[col_name] = fill_idx++; + } + + // Creating new_ids mapping for the columns we keep. + for (size_t i = 0; i < num_cols; i++) { + if (keep_input_columns_[i]) { + new_ids[i] = fill_idx++; + } + } + + // Iterating through the old mapping to update the final mapping for the columns we kept. + std::string name; + for (const auto &pair : *col_name_id_map) { + name = pair.first; + int32_t old_id = pair.second; + if (keep_input_columns_[old_id]) { + final_col_name_id_map[name] = new_ids[old_id]; + } + } + + // Set the base class final column id map result + column_name_id_map_ = final_col_name_id_map; + } +} + +// Visitor accept method for NodePass +Status MapOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h new file mode 100644 index 0000000000..77ee94d86d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h @@ -0,0 +1,254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/engine/datasetops/map_op/map_job.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. +// The column order behavior after MapOp is as follows. +// [Case 1] If the number of Input Columns == the number of Output Column, column ordering after MapOp +// is the same as the original column order where the Remainder Columns stay in the same position, +// and the Output Columns are placed the same position of the Input Columns. +// For example, initially if the dataset has column order |A, B, C, D, E|, +// and we apply MapOp() with Input Columns {B, C} and Output Columns {X, Y}. +// The column order after applying MapOp will be |A, X, Y, D, E|. +// Note that in this case, |X, Y| is the Output Columns and |A, D, E| which is the Remainder Columns stay in +// their original position, and column B is replaced by column X and column C is replace by column Y. +// [Case 2] If the number of Input Columns != the number of Output Column, column ordering after MapOp +// is Output Columns followed by Remainder Columns. +// For example, initially if the dataset has column order |A, B, C, D, E|, +// and we apply MapOp() with Input Columns {B, C, A} and Output Columns {X, Y}. +// The column order after applying MapOp will be |X, Y, D, E|. +// Note that in this case, |X, Y| is the Output Columns and |D, E| is the Remainder Columns, +// and the Input Columns are gone and replaced by the Output Columns. + +// Keywords: +// Input Columns : a vector of column names (string) passed to MapOp specifying the column names from which +// Tensors are taken and passed to the TensorOp Compute(). +// Output Columns : a vector of column names (string) passed to MapOp specifying what are the column names +// for the Tensors produced by TensorOp Compute(). +// Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns. +// These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns. +class MapOp : public ParallelOp { + public: + // The nested builder class inside of the MapOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + build_in_col_names_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOutColNames(const std::vector &out_col_names) { + build_out_col_names_ = out_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetTensorFuncs(std::vector> funcs) { + build_tensor_funcs_ = std::move(funcs); + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + // The builder "build" method creates the final object. + // @param ptr The shared_ptr to the new MapOp object + // @return Status + Status Build(std::shared_ptr *ptr); + + private: + std::vector build_in_col_names_; + std::vector build_out_col_names_; + std::vector> build_tensor_funcs_; + int32_t build_num_workers_; + int32_t build_op_connector_size_; + + // Check if the required parameters are set by the builder. + // @return Status The error code return + Status sanityCheck() const; + }; + + // Constructor of MapOp + // @note The builder class should be used to call it. + // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs). + // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs). + // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data. + // @param num_workers The number of worker threads. + // @param op_connector_size The size of each queue in the connector. + MapOp(const std::vector &in_col_names, const std::vector &out_col_names, + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size); + + // Destructor + ~MapOp() = default; + + // A print method typically used for debugging + // @param out The output stream to write output to + // @param show_all A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out reference to the output stream being overloaded + // @param mo reference to the MapOp to display + // @return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const MapOp &mo) { + mo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // This main thread creates local queues, pulls databuffers from the previous + // op's Connector and distributes them to the local queues. Workers pull from the local queues. + // @return Status The error code return + Status operator()() override; + + // Getter + // @return the number of threads consuming data from previous op's output Connector. + int32_t num_consumers() const override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return kMapOp; } + + // List of tensor ops getter/setter + // @Return the vector of tensor ops by non-const reference + + auto &TFuncs() { return tfuncs_; } + + const auto &TFuncs() const { return tfuncs_; } + + private: + // A unit of job for map worker thread. + // MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob. + struct MapWorkerJob { + std::vector> jobs; + std::unique_ptr databuffer; + }; + + // A helper function to create jobs for workers. + Status GenerateWorkerJob(const std::unique_ptr *worker_job); + + // A helper function that fetch worker map job from local queues and extract the data and map job list + Status FetchNextWork(uint32_t worker_id, std::unique_ptr *db, + std::vector> *job_list); + + // Local queues where worker threads get a job from + QueueList> local_queues_; + + // Tensorops to be read and applied by worker threads + std::vector> tfuncs_; + + // Variable to store the column name that the tensorOps are consuming + std::vector in_columns_; + + // Variable to store the column name that the tensorOps are producing + std::vector out_columns_; + + // Boolean mapping, true means to keep the column. + std::vector keep_input_columns_; + + // Indices of the columns to process. + std::vector to_process_indices_; + + // Private function for worker/thread to loop continuously. It comprises the main + // logic of MapOp: getting the data from previous Op, validating user specified column names, + // applying a list of TensorOps to each of the data, process the results and then + // pushing them back to MapOp's output Connector to be fetched by the next Op. + // @param worker_id The id assigned to this thread/worker upon creation. + // @return Status The error code return + Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ + + // Private function for worker thread to perform TensorOp's compute function and get the result. + // @param in_buffer A raw pointer to the DataBuffer. A raw pointer is fine because this function doesn't manage memory + // and is not shared with other threads. + // @param[out] new_tensor_table A new Tensor Table to be populated in this function. + Status WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table, + const std::vector> &job_list); + + // Private function that create the final column name to index mapping and + // get indices of the columns this mapop does not use. + // @param col_name_id_map The column name to index mapping obtained from child operator + void CreateFinalColMap(std::unordered_map *col_name_id_map); + + // Validating if each of the input_columns exists in the DataBuffer. + // @param - the column map to check + // @return - status return code + Status ValidateInColumns(const std::unordered_map &col_name_id_map); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + // Private function for initializing private variables such as in_columns_, out_columns_. + // @return - Status + Status InitPrivateVariable(std::unordered_map *col_name_id_map); +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h index da54ce1331..8d7ba6302a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ -#define DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ #include +#include #include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/datasetops/dataset_op.h" @@ -54,6 +55,7 @@ class ParallelOp : public DatasetOp { // @param out - The output stream to write output to // @param show_all - A bool to control if you want to show all info or just a summary void Print(std::ostream &out, bool show_all) const override; + std::string Name() const override { return kParallelOp; } // << Stream output operator overload // @notes This allows you to write the debug print info using stream operators @@ -123,4 +125,4 @@ class ParallelOp : public DatasetOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h index 0538349f48..88faad8265 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ #include +#include #include #include "minddata/dataset/engine/datasetops/dataset_op.h" @@ -42,6 +43,7 @@ class PipelineOp : public DatasetOp { // @param out - The output stream to write output to // @param show_all - A bool to control if you want to show all info or just a summary void Print(std::ostream &out, bool show_all) const override; + std::string Name() const override { return kPipelineOp; } // << Stream output operator overload // @notes This allows you to write the debug print info using stream operators @@ -95,4 +97,4 @@ class PipelineOp : public DatasetOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc index e232a64164..9265a59028 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc @@ -51,8 +51,6 @@ ProjectOp::ProjectOp(const std::vector &columns_to_project) : PipelineOp(0), columns_to_project_(columns_to_project) {} void ProjectOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); @@ -76,6 +74,9 @@ Status ProjectOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t w if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) { RETURN_IF_NOT_OK(Project(p_buffer)); } + if ((*p_buffer)->eoe()) { + UpdateRepeatAndEpochCounter(); + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h index c2f14d34b7..864baab0fa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ -#define DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ #include #include @@ -109,7 +109,7 @@ class ProjectOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "ProjectOp"; } + std::string Name() const override { return kProjectOp; } private: std::vector columns_to_project_; @@ -124,4 +124,4 @@ class ProjectOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc index d12660e6f9..0eeccea50a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc @@ -141,8 +141,6 @@ Status RenameOp::ComputeColMap() { // prints rename void RenameOp::Print(std::ostream &out, // In: The output stream to print to bool show_all) const { // In: T/F if it should print everything - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h index d846bb1b40..25c9e46896 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ -#define DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ #include #include @@ -118,7 +118,7 @@ class RenameOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "RenameOp"; } + std::string Name() const override { return kRenameOp; } protected: // Rename core functionality @@ -135,4 +135,4 @@ class RenameOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 6d3dc91ed3..123cb4451c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -28,10 +28,10 @@ namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. -RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {} +RepeatOp::Builder::Builder(int32_t count) : build_num_repeats_(count) {} Status RepeatOp::Builder::SanityCheck() const { - if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) { + if (build_num_repeats_ < kInfiniteRepeat || build_num_repeats_ == 0) { std::string err_msg("Repeat count must be > 0 or -1."); RETURN_STATUS_UNEXPECTED(err_msg); } @@ -41,30 +41,28 @@ Status RepeatOp::Builder::SanityCheck() const { // The builder "build" method creates the final object. Status RepeatOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_repeats_); + *ptr = std::make_shared(build_num_repeats_); return Status::OK(); } // Constructor of the RepeatOp. -RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {} +RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), num_repeats_(count), repeat_count_(0) {} // Destructor RepeatOp::~RepeatOp() {} // A print method typically used for debugging void RepeatOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); // Then show any custom derived-internal 1-liner info for this op - out << " [repeats: " << max_repeats_ << "]\n"; + out << " [repeats: " << num_repeats_ << "]\n"; } else { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ + out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ << "\nLeaf Nodes in execution path:"; if (!eoe_ops_.empty()) { for (size_t i = 0; i < eoe_ops_.size(); i++) { @@ -109,22 +107,13 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t wo // Base-class override for handling cases when an eoe is received. Status RepeatOp::EoeReceived(int32_t worker_id) { + UpdateRepeatAndEpochCounter(); + repeat_count_++; MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; - bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); - bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); - // If we've reached the requested repeat count, then flag the eoe nodes - // to tell them they've got one more epoch to perform. When they reach the end - // of the last epoch, they quit rather than loop again. This happens in two cases: - // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or, - // 2- We are not repeated - if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) { - for (auto &eoe_op : eoe_ops_) { - eoe_op->set_control_flag(kDeOpLastRepeat); - } - } - if (repeat_count_ == max_repeats_) { + + if (repeat_count_ == num_repeats_) { repeat_count_ = 0; state_ = OpState::kDeOpIdle; return Status::OK(); @@ -132,6 +121,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { // Invoke a reset against the eoe nodes only. for (auto &eoe_op : eoe_ops_) { + MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); RETURN_IF_NOT_OK(eoe_op->Reset()); } @@ -167,8 +157,9 @@ int32_t RepeatOp::num_consumers() const { Status RepeatOp::Reset() { // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. // In that case, we now have to bounce the reset down to our own eoe ops. - MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset."; + MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; for (auto &eoe_op : eoe_ops_) { + MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); RETURN_IF_NOT_OK(eoe_op->Reset()); } state_ = OpState::kDeOpRunning; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index f5259de30e..bdd4953541 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ -#define DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ #include #include @@ -26,8 +26,6 @@ namespace mindspore { namespace dataset { class RepeatOp : public PipelineOp { public: - static constexpr int32_t kInfiniteRepeat = -1; - // The nested builder class inside of the RepeatOp is used to help manage all of the arguments // for constructing it. This repeat op is very simple though, so this builder is really just // provided for a consistent look and feel for creators of Dataset operators overall. @@ -46,8 +44,8 @@ class RepeatOp : public PipelineOp { // @return shared_ptr to the new RepeatOp object Status Build(std::shared_ptr *); - private: - int32_t build_max_repeats_; + protected: + int32_t build_num_repeats_; Status SanityCheck() const; }; @@ -129,18 +127,29 @@ class RepeatOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "RepeatOp"; } + std::string Name() const override { return kRepeatOp; } + + /// \brief Getter function + /// \return The number of repeats that the user requested + int32_t num_repeats() { return num_repeats_; } - /// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes - /// \param[in] eoe_op The input leaf/eoe operator to add to the list + // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes + // \param[in] eoe_op The input leaf/eoe operator to add to the list void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } - private: - int32_t max_repeats_; // The number of repeats that the user requested - int32_t repeat_count_; // A counter for the current number of executed repeats + protected: + // The number of repeats that the user requested. + // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. + // For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4), + // num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6. + int32_t num_repeats_; + // A counter for the current number of executed repeats. + // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class + // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. + int32_t repeat_count_; std::vector> eoe_ops_; // List of operators that can generate EOE underneath this repeat. }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc index 0eb5f29eaf..2b4e64cfad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc @@ -100,8 +100,6 @@ Status ShuffleOp::SelfReset() { // A print method typically used for debugging void ShuffleOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h index 86bea7cc77..37ae9230c7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ #include #include @@ -163,7 +163,7 @@ class ShuffleOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "ShuffleOp"; } + std::string Name() const override { return kShuffleOp; } private: // Private function to add a new row to the shuffle buffer. @@ -201,4 +201,4 @@ class ShuffleOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc index 2fe8cbeaa6..d25e66ee7b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc @@ -58,8 +58,6 @@ SkipOp::~SkipOp() {} // A print method typically used for debugging void SkipOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h index a717d0efa4..657da1fe84 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ #include #include @@ -82,7 +82,7 @@ class SkipOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "SkipOp"; } + std::string Name() const override { return kSkipOp; } private: int32_t max_skips_; // The number of skips that the user requested @@ -91,4 +91,4 @@ class SkipOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt index 389e3f5af6..868c6fdb89 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt @@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES celeba_op.cc text_file_op.cc clue_op.cc + csv_op.cc ) set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES @@ -29,4 +30,4 @@ if (ENABLE_PYTHON) ) endif() -add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) \ No newline at end of file +add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index 9d7d5622a6..0cb4ec5559 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -293,7 +293,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -310,6 +310,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); } + UpdateRepeatAndEpochCounter(); } } @@ -359,7 +360,7 @@ Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::paircolumn(1).tensorImpl(), - TensorShape({1, (uint32_t)image_label.second.size()}), - data_schema_->column(1).type())); + RETURN_IF_NOT_OK( + Tensor::CreateEmpty(TensorShape({1, (uint32_t)image_label.second.size()}), data_schema_->column(1).type(), &label)); RETURN_IF_NOT_OK(label->Zero()); for (uint32_t index = 0; index < image_label.second.size(); index++) { if (image_label.second[index] == 1) { @@ -387,8 +387,6 @@ Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::pair:"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h index ef183f8e65..a3b1212c74 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H -#define DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H #include #include @@ -237,4 +237,4 @@ class CelebAOp : public ParallelOp, RandomAccessOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index 06be682bfd..fceec890b2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -20,7 +20,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" @@ -120,7 +120,7 @@ Status CifarOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -137,6 +137,7 @@ Status CifarOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } @@ -190,15 +191,12 @@ Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) { std::shared_ptr label; std::shared_ptr fine_label; std::shared_ptr ori_image = cifar_image_label_pairs_[index].first; - std::shared_ptr copy_image = - std::make_shared(ori_image->shape(), ori_image->type(), ori_image->GetBuffer()); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), - reinterpret_cast(&cifar_image_label_pairs_[index].second[0]))); + std::shared_ptr copy_image; + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(ori_image, ©_image)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cifar_image_label_pairs_[index].second[0], &label)); + if (cifar_image_label_pairs_[index].second.size() > 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &fine_label, data_schema_->column(2).tensorImpl(), data_schema_->column(2).shape(), - data_schema_->column(2).type(), reinterpret_cast(&cifar_image_label_pairs_[index].second[1]))); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cifar_image_label_pairs_[index].second[1], &fine_label)); (*trow) = TensorRow(index, {copy_image, std::move(label), std::move(fine_label)}); } else { (*trow) = TensorRow(index, {copy_image, std::move(label)}); @@ -220,8 +218,6 @@ Status CifarOp::LoadBuffer(const std::vector &keys, std::unique_ptr:"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -359,9 +355,8 @@ Status CifarOp::ParseCifarData() { } std::shared_ptr image_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image_tensor, data_schema_->column(0).tensorImpl(), - TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), - data_schema_->column(0).type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), + data_schema_->column(0).type(), &image_tensor)); auto itr = image_tensor->begin(); uint32_t total_pix = kCifarImageHeight * kCifarImageWidth; for (int pix = 0; pix < total_pix; ++pix) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h index 60169f32bf..f6703f0544 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ #include #include @@ -233,4 +233,4 @@ class CifarOp : public ParallelOp, public RandomAccessOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 958514583a..7b0650e962 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -127,7 +127,7 @@ Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr (*tensor_table)->push_back(std::move(tRow)); std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); (**tensor_table)[row][0] = std::move(tensor); return Status::OK(); } @@ -144,26 +144,19 @@ Status ClueOp::GetValue(const nlohmann::json &js, std::vector key_c std::string final_str = key_chain.back(); switch (cursor.type()) { case nlohmann::detail::value_t::string: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get()}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; - case nlohmann::detail::value_t::number_integer: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; case nlohmann::detail::value_t::number_unsigned: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; case nlohmann::detail::value_t::number_float: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32))); - (*t)->SetItemAt({0}, cursor.get()); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; case nlohmann::detail::value_t::array: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get>()}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(cursor.get>(), t)); break; default: break; @@ -278,13 +271,14 @@ Status ClueOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); @@ -318,8 +312,6 @@ Status ClueOp::WorkerEntry(int32_t worker_id) { // A print method typically used for debugging void ClueOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index ab429561ec..d4873ec697 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ #include #include @@ -274,4 +274,4 @@ class ClueOp : public ParallelOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index daef2f284b..db98416d6b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -18,7 +18,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" @@ -167,7 +167,7 @@ Status CocoOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); @@ -184,12 +184,11 @@ Status CocoOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } void CocoOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -215,7 +214,7 @@ Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Te auto itr = coordinate_map_.find(image_id); if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - std::string kImageFile = image_folder_path_ + image_id; + std::string kImageFile = image_folder_path_ + std::string("/") + image_id; RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); auto bboxRow = itr->second; @@ -239,9 +238,8 @@ Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Te } std::vector bbox_dim = {bbox_row_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&coordinate, data_schema_->column(1).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(1).type(), - reinterpret_cast(&bbox_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(bbox_row, TensorShape(bbox_dim), &coordinate)); + if (task_type_ == TaskType::Detection) { RETURN_IF_NOT_OK(LoadDetectionTensorRow(row_id, image_id, image, coordinate, trow)); } else if (task_type_ == TaskType::Stuff || task_type_ == TaskType::Keypoint) { @@ -278,13 +276,12 @@ Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &ima iscrowd_row.push_back(annotation[i]); } } - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector( + category_id_row, TensorShape({static_cast(category_id_row.size()), 1}), &category_id)); + + RETURN_IF_NOT_OK( + Tensor::CreateFromVector(iscrowd_row, TensorShape({static_cast(iscrowd_row.size()), 1}), &iscrowd)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)}); return Status::OK(); } @@ -302,9 +299,8 @@ Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_ item_queue = itr_item->second; std::vector bbox_dim = {static_cast(item_queue.size()), 1}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&item, data_schema_->column(2).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(2).type(), - reinterpret_cast(&item_queue[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(item_queue, TensorShape(bbox_dim), &item)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)}); return Status::OK(); } @@ -334,18 +330,14 @@ Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id, area_row.push_back(annotation[i]); } } + RETURN_IF_NOT_OK(Tensor::CreateFromVector( + category_id_row, TensorShape({static_cast(category_id_row.size()), 1}), &category_id)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + RETURN_IF_NOT_OK( + Tensor::CreateFromVector(iscrowd_row, TensorShape({static_cast(iscrowd_row.size()), 1}), &iscrowd)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(area_row, TensorShape({static_cast(area_row.size()), 1}), &area)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &area, data_schema_->column(4).tensorImpl(), TensorShape({static_cast(area_row.size()), 1}), - data_schema_->column(4).type(), reinterpret_cast(&area_row[0]))); (*trow) = TensorRow( row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)}); return Status::OK(); @@ -596,7 +588,7 @@ Status CocoOp::LaunchThreadsAndInitOp() { } Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor)); if (decode_ == true) { Status rc = Decode(*tensor, tensor); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index 31070c26f5..209a7726ca 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_COC0_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COC0_OP_H_ #include #include @@ -337,4 +337,4 @@ class CocoOp : public ParallelOp, public RandomAccessOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_Coco_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_Coco_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc new file mode 100644 index 0000000000..8fbcc7205e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -0,0 +1,780 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/csv_op.h" + +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +CsvOp::Builder::Builder() + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(-1), builder_shuffle_files_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status CsvOp::Builder::ValidateInputs() const { + std::string err; + err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; + err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err); +} + +Status CsvOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_csv_files_list_.size()) { + builder_num_workers_ = builder_csv_files_list_.size(); + MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + std::shared_ptr csv_op = std::make_shared( + builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, + builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); + RETURN_IF_NOT_OK(csv_op->Init()); + *op = std::move(csv_op); + + return Status::OK(); +} + +CsvOp::CsvOp(const std::vector &csv_files_list, char field_delim, + const std::vector> &column_default, + const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, + int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, + int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + csv_files_list_(std::move(csv_files_list)), + field_delim_(field_delim), + column_default_list_(column_default), + column_name_list_(column_name), + rows_per_buffer_(rows_per_buffer), + num_rows_per_shard_(0), + all_num_rows_(0), + num_samples_(num_samples), + filename_index_(std::make_unique()), + load_jagged_connector_(true), + shuffle_files_(shuffle_files), + finished_reading_dataset_(false), + num_devices_(num_device), + device_id_(device_id), + load_io_block_queue_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status CsvOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(csv_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + jagged_buffer_connector_ = std::make_shared(num_workers_, 1, worker_connector_size_); + + return Status::OK(); +} + +int CsvOp::CsvParser::put_record(char c) { + std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_); + std::shared_ptr t; + if (cur_col_ >= column_default_.size()) { + err_message_ = "Number of file columns does not match the default records"; + return -1; + } + switch (column_default_[cur_col_]->type) { + case CsvOp::INT: + Tensor::CreateScalar(std::stoi(s), &t); + break; + case CsvOp::FLOAT: + Tensor::CreateScalar(std::stof(s), &t); + break; + default: + Tensor::CreateScalar(s, &t); + break; + } + if (cur_col_ >= (*tensor_table_)[cur_row_].size()) { + err_message_ = "Number of file columns does not match the tensor table"; + return -1; + } + (*tensor_table_)[cur_row_][cur_col_] = std::move(t); + pos_ = 0; + cur_col_++; + return 0; +} + +int CsvOp::CsvParser::put_row(char c) { + if (total_rows_ < start_offset_) { + total_rows_++; + cur_col_ = 0; + return 0; + } + + if (total_rows_ >= end_offset_) { + cur_col_ = 0; + return 0; + } + + int ret = put_record(c); + if (ret < 0) { + return ret; + } + + total_rows_++; + cur_row_++; + cur_col_ = 0; + + if (cur_row_ == csv_rows_per_buffer_) { + cur_buffer_->set_tensor_table(std::move(tensor_table_)); + buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); + + cur_buffer_ = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + tensor_table_ = std::make_unique(); + cur_row_ = 0; + } + return 0; +} + +int CsvOp::CsvParser::end_file(char c) { + if (cur_col_ > 0) { + put_row(c); + } + if (cur_row_ > 0) { + cur_buffer_->set_tensor_table(std::move(tensor_table_)); + buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); + } + return 0; +} + +int CsvOp::CsvParser::countRows(int c) { + Message m; + if (c == '"') { + m = Message::MS_QUOTE; + } else if (c == '\r' || c == '\n' || c == std::char_traits::eof()) { + m = Message::MS_END_OF_LINE; + } else { + m = Message::MS_NORMAL; + } + StateDiagram::iterator it = sdl.find({cur_state_, m}); + if (it == sd.end()) { + return -1; + } + cur_state_ = it->second.first; + return it->second.second(*this, c); +} + +Status CsvOp::CsvParser::initCsvParser() { + str_buf_.resize(CSV_BUFFER_SIZE); + + // State diagram for counting rows + sdl = {// START_OF_FILE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ null_func │ + // └───────────┴───────────┴─────────────┘ + {{State::START_OF_FILE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::START_OF_FILE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, + + // UNQUOTE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ add_row │ + // └───────────┴───────────┴─────────────┘ + {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::UNQUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}}, + + // QUOTE + // ┌───────────┬──────────────┬───────────┐ + // │ abc │ " │ \n │ + // ├───────────┼──────────────┼───────────┤ + // │ QUOTE │ SECOND_QUOTE │ QUOTE │ + // ├───────────┼──────────────┼───────────┤ + // | null_func │ null_func │ null_func │ + // └───────────┴──────────────┴───────────┘ + {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::null_func}}, + {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}}, + {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::null_func}}, + + // SECOND_QUOTE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ add_row │ + // └───────────┴───────────┴─────────────┘ + {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}}, + + // END_OF_LINE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ null_func │ + // └───────────┴───────────┴─────────────┘ + {{State::END_OF_LINE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::END_OF_LINE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}}; + + // State diagram for CSV parser + sd = {// START_OF_FILE + // ┌───────────┬──────────┬──────────┬────────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤ + // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤ + // | lambda │ lambda │ lambda │ null_func │ null_func │ + // └───────────┴──────────┴──────────┴────────────────┴────────────────┘ + {{State::START_OF_FILE, Message::MS_NORMAL}, + {State::UNQUOTE, + [this](CsvParser &, char c) -> int { + this->tensor_table_ = std::make_unique(); + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->str_buf_[0] = c; + this->pos_ = 1; + return 0; + }}}, + {{State::START_OF_FILE, Message::MS_DELIM}, + {State::DELIM, + [this](CsvParser &, char c) -> int { + this->tensor_table_ = std::make_unique(); + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + return this->put_record(c); + }}}, + {{State::START_OF_FILE, Message::MS_QUOTE}, + {State::QUOTE, + [this](CsvParser &, char c) -> int { + this->tensor_table_ = std::make_unique(); + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->pos_ = 0; + return 0; + }}}, + {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, + {{State::START_OF_FILE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::null_func}}, + + // UNQUOTE + // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // │ UNQUOTE │ DELIM │ EXCEPTION │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // | put_char │ put_record │ exception │ put_row │ end_file │ + // └───────────┴────────────┴───────────┴─────────────┴────────────────┘ + {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}}, + {{State::UNQUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, + {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, + {{State::UNQUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, + // UNQUOTE-Exception + {{State::UNQUOTE, Message::MS_QUOTE}, {State::EXCEPTION, &CsvParser::catch_exception}}, + + // DELIM + // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // | put_char │ put_record │ lambda │ put_row │ end_file │ + // └───────────┴────────────┴───────────┴─────────────┴────────────────┘ + {{State::DELIM, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}}, + {{State::DELIM, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, + {{State::DELIM, Message::MS_QUOTE}, + {State::QUOTE, + [this](CsvParser &, char c) -> int { + this->pos_ = 0; + return 0; + }}}, + {{State::DELIM, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, + {{State::DELIM, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, + + // QUOTE + // ┌───────────┬──────────┬──────────────┬──────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤ + // │ QUOTE │ QUOTE │ SECOND_QUOTE │ QUOTE │ EXCEPTION │ + // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤ + // | put_char │ put_char │ null_func │ put_char │ exception │ + // └───────────┴──────────┴──────────────┴──────────┴────────────────┘ + {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::put_char}}, + {{State::QUOTE, Message::MS_DELIM}, {State::QUOTE, &CsvParser::put_char}}, + {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}}, + {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::put_char}}, + // QUOTE-Exception + {{State::QUOTE, Message::MS_END_OF_FILE}, {State::EXCEPTION, &CsvParser::catch_exception}}, + + // SECOND_QUOTE + // ┌───────────┬────────────┬──────────┬─────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤ + // │ EXCEPTION │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤ + // | exception │ put_record │ put_char │ put_row │ end_file │ + // └───────────┴────────────┴──────────┴─────────────┴────────────────┘ + {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::put_char}}, + {{State::SECOND_QUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, + {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, + {{State::SECOND_QUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, + // SECOND_QUOTE-Exception + {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::EXCEPTION, &CsvParser::catch_exception}}, + + // END_OF_LINE + // ┌─────────┬────────┬────────┬─────────────┬─────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├─────────┼────────┼────────┼─────────────┼─────────────┤ + // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├─────────┼────────┼────────┼─────────────┼─────────────┤ + // | lambda │ lambda │ lambda │ null_func │ end_file │ + // └─────────┴────────┴────────┴─────────────┴─────────────┘ + {{State::END_OF_LINE, Message::MS_NORMAL}, + {State::UNQUOTE, + [this](CsvParser &, char c) -> int { + if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) { + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + } + this->str_buf_[0] = c; + this->pos_ = 1; + return 0; + }}}, + {{State::END_OF_LINE, Message::MS_DELIM}, + {State::DELIM, + [this](CsvParser &, char c) -> int { + if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) { + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + } + return this->put_record(c); + }}}, + {{State::END_OF_LINE, Message::MS_QUOTE}, + {State::QUOTE, + [this](CsvParser &, char c) -> int { + if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) { + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + } + return 0; + }}}, + {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, + {{State::END_OF_LINE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}}; + return Status::OK(); +} + +Status CsvOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_); + csv_parser.setStartOffset(start_offset); + csv_parser.setEndOffset(end_offset); + std::ifstream ifs; + ifs.open(file, std::ifstream::in); + if (column_name_list_.empty()) { + std::string tmp; + getline(ifs, tmp); + } + csv_parser.Reset(); + try { + while (ifs.good()) { + // when ifstream reachs the end of file, the function get() return std::char_traits::eof() + // which is a 32-bit -1, it's not equal to the 8-bit -1 on Euler OS. So instead of char, we use + // int to receive its return value. + int chr = ifs.get(); + if (csv_parser.processMessage(chr) != 0) { + RETURN_STATUS_UNEXPECTED("Failed to parse file " + file + ":" + std::to_string(csv_parser.total_rows_ + 1) + + ". error message: " + csv_parser.err_message_); + } + } + } catch (std::invalid_argument &ia) { + std::string err_row = std::to_string(csv_parser.total_rows_ + 1); + RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", type does not match"); + } catch (std::out_of_range &oor) { + std::string err_row = std::to_string(csv_parser.total_rows_ + 1); + RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of range"); + } + return Status::OK(); +} + +Status CsvOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this))); + + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + NotifyToFillIOBlockQueue(); + + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == -1 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (IsLastIteration()) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + UpdateRepeatAndEpochCounter(); + } + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + return Status::OK(); +} + +Status CsvOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// A print method typically used for debugging +void CsvOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ + << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n"; + for (int i = 0; i < csv_files_list_.size(); ++i) { + out << " " << csv_files_list_[i]; + } + out << "\n\n"; + } +} + +// Pops an element from a queue in io_block_queues +Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +Status CsvOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +Status CsvOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status CsvOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +Status CsvOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API CsvDataset. Please check file path or dataset API " + "validation first."); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +int64_t CsvOp::CountTotalRows(const std::string &file) { + CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_); + std::ifstream ifs; + ifs.open(file, std::ifstream::in); + if (column_name_list_.empty()) { + std::string tmp; + getline(ifs, tmp); + } + csv_parser.Reset(); + while (ifs.good()) { + int chr = ifs.get(); + if (csv_parser.countRows(chr) != 0) { + break; + } + } + + return csv_parser.total_rows_; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status CsvOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +Status CsvOp::CountAllFileRows(const std::vector &files, bool csv_header, int64_t *count) { + std::shared_ptr op; + *count = 0; + if (csv_header) { + RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op)); + } else { + RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op)); + } + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} + +std::vector CsvOp::split(const std::string &s, char delim) { + std::vector res; + std::stringstream ss(s); + std::string item; + + while (getline(ss, item, delim)) { + res.push_back(item); + } + return res; +} + +Status CsvOp::ComputeColMap() { + // Set the column name mapping (base class field) + if (column_name_id_map_.empty()) { + if (column_name_list_.empty()) { + std::string line; + std::ifstream handle(csv_files_list_[0]); + getline(handle, line); + std::vector col_names = split(line, field_delim_); + for (int32_t i = 0; i < col_names.size(); i++) { + column_name_id_map_[col_names[i]] = i; + } + } else { + for (int32_t i = 0; i < column_name_list_.size(); i++) { + column_name_id_map_[column_name_list_[i]] = i; + } + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + if (column_default_list_.size() < column_name_id_map_.size()) { + for (int32_t i = column_default_list_.size(); i < column_name_id_map_.size(); i++) { + column_default_list_.push_back(std::make_shared>(CsvOp::STRING, "")); + } + } + if (column_default_list_.size() != column_name_id_map_.size()) { + RETURN_STATUS_UNEXPECTED("The number of column names does not match the column defaults"); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h new file mode 100644 index 0000000000..1921b61bdc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -0,0 +1,453 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" + +namespace mindspore { +namespace dataset { + +const size_t CSV_BUFFER_SIZE = 4096; +using StringIndex = AutoIndexObj; +class JaggedConnector; + +class CsvOp : public ParallelOp { + public: + enum RecordType : uint8_t { INT = 0, FLOAT, STRING }; + + struct BaseRecord { + public: + BaseRecord() = default; + explicit BaseRecord(RecordType t) : type(t) {} + virtual ~BaseRecord() {} + RecordType type; + }; + + template + class Record : public BaseRecord { + public: + Record() = default; + Record(RecordType t, T v) : BaseRecord(t), value(v) {} + ~Record() {} + T value; + }; + + // CsvParser is a class that parsing CSV file. + // We design a state machine to implement CSV syntactic analysis. It contains two state diagram,'sd' and 'sdl'. + // The 'sd' is used for parsing CSV syntactic, it's complete and complicate. + // The 'sdl' is used for counting the record rows, it's concise and it runs fast. + struct CsvParser { + public: + CsvParser() = delete; + + CsvParser(int32_t worker_id, std::shared_ptr connector, int64_t rows_per_buffer, char field_delim, + std::vector> column_default) + : worker_id_(worker_id), + buffer_connector_(connector), + csv_rows_per_buffer_(rows_per_buffer), + csv_field_delim_(field_delim), + column_default_(column_default), + cur_state_(START_OF_FILE), + pos_(0), + cur_row_(0), + cur_col_(0), + total_rows_(0), + start_offset_(0), + end_offset_(std::numeric_limits::max()), + err_message_("unkonw") { + cur_buffer_ = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + initCsvParser(); + } + + ~CsvParser() = default; + + void Reset() { + cur_state_ = START_OF_FILE; + pos_ = 0; + cur_row_ = 0; + cur_col_ = 0; + } + + void setStartOffset(int64_t start_offset) { start_offset_ = start_offset; } + + void setEndOffset(int64_t end_offset) { end_offset_ = end_offset; } + + int processMessage(int c) { + Message m = getMessage(c); + StateDiagram::iterator it = sd.find({cur_state_, m}); + if (it == sd.end()) { + return -1; + } + cur_state_ = it->second.first; + return it->second.second(*this, c); + } + + int countRows(int c); + + Status initCsvParser(); + + enum State : uint8_t { + START_OF_FILE = 0, + UNQUOTE, + DELIM, + QUOTE, + SECOND_QUOTE, + END_OF_LINE, + END_OF_FILE, + EXCEPTION + }; + + enum Message : uint8_t { + MS_NORMAL = 0, + MS_DELIM, + MS_QUOTE, + MS_END_OF_LINE, + MS_END_OF_FILE, + }; + + typedef std::pair StateMessagePair; + typedef std::pair> StateActionPair; + typedef std::map StateDiagram; + + Message getMessage(int c) { + if (c == csv_field_delim_) { + return Message::MS_DELIM; + } else if (c == '"') { + return Message::MS_QUOTE; + } else if (c == '\r' || c == '\n') { + return Message::MS_END_OF_LINE; + } else if (c == std::char_traits::eof()) { + return Message::MS_END_OF_FILE; + } else { + return Message::MS_NORMAL; + } + } + + int null_func(char c) { return 0; } + + int put_char(char c) { + if (pos_ >= str_buf_.size()) { + str_buf_.resize(str_buf_.size() * 2); + } + str_buf_[pos_] = c; + pos_++; + return 0; + } + + int put_record(char c); + + int put_row(char c); + + int end_file(char c); + + int add_row(char c) { + total_rows_++; + return 0; + } + + int catch_exception(char c) { + MS_LOG(ERROR) << "Invalid syntax!"; + return -1; + } + + int32_t worker_id_; + std::shared_ptr buffer_connector_; + int64_t csv_rows_per_buffer_; + const char csv_field_delim_; + std::vector> column_default_; + State cur_state_; + size_t pos_; + int cur_row_; + int cur_col_; + int64_t total_rows_; + int64_t start_offset_; + int64_t end_offset_; + StateDiagram sd; + StateDiagram sdl; + std::vector str_buf_; + std::unique_ptr tensor_table_; + std::unique_ptr cur_buffer_; + std::string err_message_; + }; + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetCsvFilesList(const std::vector &files_list) { + builder_csv_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetFieldDelim(char field_delim) { + builder_field_delim_ = field_delim; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetColumDefault(std::vector> record_list) { + builder_column_default_list_ = record_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetColumName(std::vector col_name_list) { + builder_column_name_list_ = col_name_list; + return *this; + } + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_csv_files_list_; + bool builder_shuffle_files_; + char builder_field_delim_; + std::vector> builder_column_default_list_; + std::vector builder_column_name_list_; + }; + + // Constructor of CsvOp + CsvOp() = delete; + + CsvOp(const std::vector &csv_files_list, char field_delim, + const std::vector> &column_default, const std::vector &column_name, + int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~CsvOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all csv files. + // @param csv_header - a bool that indicates csv file include header line + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, bool csv_header, int64_t *count); + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return csv_files_list_; } + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a csv file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - csv file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + // Split string based on a character delimiter + // @return - the a string vector + std::vector split(const std::string &s, char delim); + + int32_t device_id_; + bool shuffle_files_; + bool finished_reading_dataset_; + int32_t num_devices_; + int64_t rows_per_buffer_; + bool load_io_block_queue_; + int64_t num_rows_per_shard_; + int64_t all_num_rows_; + int64_t num_samples_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + std::vector csv_files_list_; + WaitPost io_block_queue_wait_post_; + std::shared_ptr jagged_buffer_connector_; + QueueList> io_block_queues_; + bool load_jagged_connector_; + char field_delim_; + std::vector> column_default_list_; + std::vector column_name_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 773dfc78b6..a70f54bdee 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -61,8 +61,6 @@ GeneratorOp::GeneratorOp(py::function generator_function, std::vectorDealloc(); } void GeneratorOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); @@ -129,7 +127,7 @@ Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) "Generator should return a tuple of numpy arrays."); } std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, ret_py_ele.cast())); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast(), &tensor)); if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) && (column_types_[i] != tensor->type())) { return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator type check failed."); @@ -218,7 +216,7 @@ Status GeneratorOp::operator()() { MS_LOG(DEBUG) << "Generator operator sends out EOE."; std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { // If last repeat or not repeated, push out EOF and exit master loop MS_LOG(DEBUG) << "Generator operator sends out EOF."; std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); @@ -233,6 +231,7 @@ Status GeneratorOp::operator()() { // Clear the status of the wait post wp_.Clear(); } + UpdateRepeatAndEpochCounter(); } } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h index d09bfc3d71..1d7f2b97f3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ #include #include @@ -27,6 +27,9 @@ #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/pipeline_op.h" #include "minddata/dataset/util/wait_post.h" +#include "pybind11/pybind11.h" + +namespace py = pybind11; namespace mindspore { namespace dataset { @@ -160,4 +163,4 @@ class GeneratorOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 85839303db..78dfae5dbe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -16,7 +16,7 @@ #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" @@ -151,7 +151,7 @@ Status ImageFolderOp::operator()() { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); @@ -168,6 +168,7 @@ Status ImageFolderOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } @@ -201,10 +202,8 @@ Status ImageFolderOp::WorkerEntry(int32_t worker_id) { // Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer Status ImageFolderOp::LoadTensorRow(row_id_type row_id, ImageLabelPair pairPtr, TensorRow *trow) { std::shared_ptr image, label; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), - reinterpret_cast(&pairPtr->second))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, folder_path_ + (pairPtr->first))); + RETURN_IF_NOT_OK(Tensor::CreateScalar(pairPtr->second, &label)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pairPtr->first), &image)); if (decode_ == true) { Status rc = Decode(image, &image); @@ -230,8 +229,6 @@ Status ImageFolderOp::LoadBuffer(const std::vector &keys, std::unique_p } void ImageFolderOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h index 153751d3c5..219a4e53f9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ #include #include @@ -271,4 +271,4 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h index df26aa1fc1..86ae3c9c56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ #include #include @@ -122,4 +122,4 @@ class FilenameBlock : public IOBlock { }; // class TFBlock } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 0476baf56f..dfb131a43d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -20,7 +20,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" @@ -112,7 +112,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -129,6 +129,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } @@ -185,17 +186,14 @@ Status ManifestOp::LoadTensorRow(row_id_type row_id, const std::pair label_index(data.second.size()); (void)std::transform(data.second.begin(), data.second.end(), label_index.begin(), [this](const std::string &label_name) { return label_index_[label_name]; }); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(label_index, &label)); if (label_index.size() == 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), TensorShape({}), - data_schema_->column(1).type(), - reinterpret_cast(&label_index[0]))); + label->Reshape(TensorShape({})); } else { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &label, data_schema_->column(1).tensorImpl(), TensorShape(std::vector(1, label_index.size())), - data_schema_->column(1).type(), reinterpret_cast(&label_index[0]))); + label->Reshape(TensorShape(std::vector(1, label_index.size()))); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data.first)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.first, &image)); if (decode_ == true) { Status rc = Decode(image, &image); if (rc.IsError()) { @@ -220,8 +218,6 @@ Status ManifestOp::LoadBuffer(const std::vector &keys, std::unique_ptr< } void ManifestOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h index bac8f04c94..5a1a1d726b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ #include #include @@ -247,4 +247,4 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index cf1493eb78..8bc314b0af 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/global_context.h" @@ -51,7 +51,6 @@ MindRecordOp::Builder::Builder() : build_dataset_file_({}) { build_num_mind_record_workers_ = kDefaultMindRecordWorkers; build_rows_per_buffer_ = cfg->rows_per_buffer(); build_op_connector_queue_size_ = cfg->op_connector_size(); - build_block_reader_ = false; builder_num_workers_ = 0; build_num_padded_ = 0; build_sample_ = nullptr; @@ -69,10 +68,10 @@ Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { if (build_num_padded_ > 0) { sample_json = ToJson(build_sample_); } - new_mind_record_op = std::make_shared( - build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, - build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_, build_num_padded_, - sample_json, build_sample_bytes_); + new_mind_record_op = + std::make_shared(build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, + build_load_dataset_, build_op_connector_queue_size_, build_columns_to_load_, + build_operators_, build_num_padded_, sample_json, build_sample_bytes_); RETURN_IF_NOT_OK(new_mind_record_op->Init()); *ptr = std::move(new_mind_record_op); @@ -113,9 +112,8 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader, - int64_t num_padded, const mindrecord::json &sample_json, - const std::map &sample_bytes) + const std::vector> &operators, int64_t num_padded, + const mindrecord::json &sample_json, const std::map &sample_bytes) : ParallelOp(num_mind_record_workers, op_connector_queue_size), rows_per_buffer_(rows_per_buffer), dataset_file_(dataset_file), @@ -123,27 +121,21 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf columns_to_load_(columns_to_load), operators_(operators), num_mind_record_workers_(num_mind_record_workers), - block_reader_(block_reader), num_rows_(0), buffers_needed_(0), buf_cnt_(0), ended_worker_(0), - buffer_water_mark_(0), num_padded_(num_padded), sample_json_(sample_json), sample_bytes_(sample_bytes) { io_blk_queues_.Init(num_workers_, op_connector_queue_size); - if (!block_reader_) return; - for (int32_t i = 0; i < num_workers_; ++i) { - block_buffer_.emplace_back(std::make_unique>(std::vector{})); - } } // Private helper method to encapsulate some common construction/reset tasks Status MindRecordOp::Init() { shard_reader_ = std::make_unique(); auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, - block_reader_, num_padded_); + num_padded_); CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed. Error message: " + ErrnoToMessage(rc)); @@ -204,8 +196,6 @@ MindRecordOp::~MindRecordOp() {} // A print method typically used for debugging void MindRecordOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -215,7 +205,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info ParallelOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\n Dataset file : "; + out << "\nDataset file : "; for (auto &file : dataset_file_) { out << file << " "; } @@ -264,23 +254,6 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) { } RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); - if (!block_reader_) { - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - - // update block-reader buffer - block_buffer_[buffer_id % num_workers_]->clear(); - { - std::unique_lock lck(mtx_block_reader_); - if (buffer_id == buffer_water_mark_) { - buffer_water_mark_++; - while (block_set_.count(buffer_water_mark_) > 0) (void)block_set_.erase(buffer_water_mark_++); - } else { - (void)block_set_.insert(buffer_id); - } - } - cv_reader_.notify_one(); RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); } RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); @@ -291,23 +264,16 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_bu *fetched_buffer = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); std::unique_ptr tensor_table = std::make_unique(); for (int32_t i = 0; i < rows_per_buffer_; ++i) { - ShardTuple tupled_buffer; - mindrecord::TaskType task_type = mindrecord::TaskType::kCommonTask; - if (block_reader_) { - if (i >= block_buffer_[buffer_id % num_workers_]->size()) break; - tupled_buffer = block_buffer_[buffer_id % num_workers_]->at(i); - } else { - int32_t row_id = buffer_id * rows_per_buffer_ + i; - auto rc = shard_reader_->GetNextById(row_id, worker_id); - task_type = rc.first; - tupled_buffer = rc.second; - if (task_type == mindrecord::TaskType::kPaddedTask) { - TensorRow tensor_row; - RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); - tensor_table->push_back(std::move(tensor_row)); - } - if (tupled_buffer.empty()) break; + int32_t row_id = buffer_id * rows_per_buffer_ + i; + auto rc = shard_reader_->GetNextById(row_id, worker_id); + auto task_type = rc.first; + auto tupled_buffer = rc.second; + if (task_type == mindrecord::TaskType::kPaddedTask) { + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); + tensor_table->push_back(std::move(tensor_row)); } + if (tupled_buffer.empty()) break; if (task_type == mindrecord::TaskType::kCommonTask) { for (const auto &tupled_row : tupled_buffer) { std::vector columns_blob = std::get<0>(tupled_row); @@ -381,36 +347,21 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector(num_elements), &new_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(new_shape, type, data, &tensor)); } else { std::vector shapeDetails = {static_cast(num_elements)}; auto new_shape = TensorShape(shapeDetails); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(new_shape, type, data, &tensor)); } tensor_row->push_back(std::move(tensor)); } return Status::OK(); } -Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { - { - std::unique_lock lck(mtx_block_reader_); - cv_reader_.wait(lck, [buffer_id, this] { return buffer_id < buffer_water_mark_ + num_workers_; }); - } - for (int32_t i = 0; i < rows_per_buffer_; i++) { - // Block reader does NOT care about argument - auto rc = shard_reader_->GetNextById(i, i); - ShardTuple tuple_buffer = rc.second; - if (tuple_buffer.empty()) break; - block_buffer_[buffer_id % num_workers_]->push_back(std::move(tuple_buffer)); - } - return Status::OK(); -} - // Class functor operator () override. // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // provide the master loop that drives the logic for performing the work @@ -423,12 +374,11 @@ Status MindRecordOp::operator()() { while (true) { // each iterator is 1 epoch for (int32_t i = 0; i < buffers_needed_; ++i) { - if (block_reader_) RETURN_IF_NOT_OK(FetchBlockBuffer(i)); std::vector keys(1, i); RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -446,6 +396,7 @@ Status MindRecordOp::operator()() { RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); shard_reader_wait_post_.Clear(); } + UpdateRepeatAndEpochCounter(); } } @@ -455,12 +406,7 @@ Status MindRecordOp::operator()() { Status MindRecordOp::Reset() { RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. - if (block_reader_) { - shard_reader_->Reset(); - buffer_water_mark_ = 0; - } else { - shard_reader_->ShuffleTask(); - } + shard_reader_->ShuffleTask(); shard_reader_wait_post_.Set(); return Status::OK(); @@ -473,7 +419,7 @@ Status MindRecordOp::LaunchThreadAndInitOp() { RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); - if (shard_reader_->Launch(!block_reader_) == MSRStatus::FAILED) { + if (shard_reader_->Launch(true) == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); } // Launch main workers that load DataBuffers by reading all images diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h index 367505b172..939e48b616 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ #pragma once #include @@ -94,11 +94,6 @@ class MindRecordOp : public ParallelOp { return *this; } - Builder &SetBlockReader() { - build_block_reader_ = true; - return *this; - } - Builder &SetLoadDataset(bool load_dataset) { build_load_dataset_ = load_dataset; return *this; @@ -132,7 +127,6 @@ class MindRecordOp : public ParallelOp { bool build_load_dataset_; std::vector build_columns_to_load_; std::vector> build_operators_; - bool build_block_reader_; int64_t build_num_padded_; py::handle build_sample_; std::map build_sample_bytes_; @@ -148,9 +142,8 @@ class MindRecordOp : public ParallelOp { // @param operators - ShardOperators for Shuffle, Category, Sample MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader, - int64_t num_padded_, const mindrecord::json &sample_json, - const std::map &sample_bytes_); + const std::vector> &operators, int64_t num_padded_, + const mindrecord::json &sample_json, const std::map &sample_bytes_); // Destructor ~MindRecordOp() override; @@ -206,8 +199,6 @@ class MindRecordOp : public ParallelOp { // Getter method std::vector columns_to_load() const { return columns_to_load_; } - bool block_reader() const { return block_reader_; } - bool load_dataset() const { return load_dataset_; } Status Init(); @@ -232,8 +223,6 @@ class MindRecordOp : public ParallelOp { Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, const mindrecord::json &columns_json, const mindrecord::TaskType task_type); - Status FetchBlockBuffer(const int32_t &buffer_id); - // Private function for computing the assignment of the column name map. // @return - Status Status ComputeColMap() override; @@ -244,12 +233,10 @@ class MindRecordOp : public ParallelOp { std::vector columns_to_load_; // Columns to load from dataset std::vector> operators_; // ShardOperators to use int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader - bool block_reader_; // block reader switch int32_t buffers_needed_; // Counter for the buffers that were fetched int64_t buf_cnt_; // Buffer counter int32_t num_rows_; // One more than the last row id in the range for this cache std::atomic ended_worker_; - std::atomic buffer_water_mark_; int64_t num_padded_; mindrecord::json sample_json_; @@ -263,14 +250,8 @@ class MindRecordOp : public ParallelOp { WaitPost shard_reader_wait_post_; QueueList> io_blk_queues_; - // For block reader - std::mutex mtx_block_reader_; - std::condition_variable cv_reader_; - std::vector>> block_buffer_; - std::unordered_set block_set_; - std::mutex ended_worker_mutex_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index 11ad18865e..731309b36a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -17,7 +17,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" @@ -111,7 +111,7 @@ Status MnistOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( @@ -128,6 +128,7 @@ Status MnistOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } @@ -160,12 +161,10 @@ Status MnistOp::WorkerEntry(int32_t worker_id) { // Load 1 TensorRow (image,label) using 1 MnistLabelPair. Status MnistOp::LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *trow) { std::shared_ptr image, label; - int32_t l = mnist_pair.second; // make a copy of cached tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), mnist_pair.first->shape(), - mnist_pair.first->type(), mnist_pair.first->GetBuffer())); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), reinterpret_cast(&l))); + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(mnist_pair.first, &image)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(mnist_pair.second, &label)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); return Status::OK(); } @@ -183,8 +182,6 @@ Status MnistOp::LoadBuffer(const std::vector &keys, std::unique_ptr:"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -325,8 +322,8 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la pixels[m] = (pixels[m] == 0) ? 0 : 255; } std::shared_ptr image; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), img_tensor_shape, - data_schema_->column(0).type(), reinterpret_cast(pixels))); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(img_tensor_shape, data_schema_->column(0).type(), + reinterpret_cast(pixels), &image)); image_label_pairs_.emplace_back(std::make_pair(image, labels_buf[j])); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index 039f6b112f..1f2e3dd730 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ #include #include @@ -40,7 +40,7 @@ namespace dataset { template class Queue; -using MnistLabelPair = std::pair, int32_t>; +using MnistLabelPair = std::pair, uint32_t>; class MnistOp : public ParallelOp, public RandomAccessOp { public: @@ -249,4 +249,4 @@ class MnistOp : public ParallelOp, public RandomAccessOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 46f3adfa62..292fb4cdf0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -91,8 +91,6 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64 // A print method typically used for debugging void RandomDataOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -221,7 +219,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { all_out_.Wait(); // If we are not in a repeat loop, or that was the last repeat already, then setup our exit // condition from the master loop. - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { *quitting = true; } @@ -231,6 +229,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { if (last_guy_in) { MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " << eoe_worker_id_; + UpdateRepeatAndEpochCounter(); // Prepare for sync all_out_.Clear(); // Always flow eoe at the end @@ -361,8 +360,7 @@ Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor."); } - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get())); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(*new_shape, current_col.type(), buf.get(), &new_tensor)); // Add this tensor to the tensor row for output (*new_row).push_back(std::move(new_tensor)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h index c77695439d..4644c29f0c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ #include #include @@ -288,4 +288,4 @@ class RandomDataOp : public ParallelOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 2b5e7c67c8..2299f802bb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -24,13 +24,14 @@ namespace mindspore { namespace dataset { DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, - uint32_t seed) + uint32_t seed, bool even_dist) : Sampler(num_samples, std::numeric_limits::max()), cnt_(0), seed_(seed == std::numeric_limits::max() ? GetSeed() : seed), device_id_(dev_id), num_devices_(num_dev), - shuffle_(shuffle) {} + shuffle_(shuffle), + even_dist_(even_dist) {} Status DistributedSampler::InitSampler() { // Special value of 0 for num_samples means that the user wants to sample the entire set of data. @@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, "fail to init DistributedSampler"); rnd_.seed(seed_++); - samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) + if (even_dist_) { + samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) + } else { + int64_t mod = num_rows_ % num_devices_; + samples_per_buffer_ = num_rows_ / num_devices_; + if (mod > device_id_) { + samples_per_buffer_++; + } + } samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; if (shuffle_ == true) { shuffle_vec_.reserve(num_rows_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h index 76bcf052f9..215611cfbe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ #include #include @@ -27,26 +27,32 @@ namespace mindspore { namespace dataset { class DistributedSampler : public Sampler { public: - // @param num_samples - // @param int64_t num_dev - // @param int64_t dev_id - // @param bool shuffle + /// \brief Constructor + /// \param[in] num_samples The total number of rows in the dataset + /// \param[in] num_dev Total number of shards for the distributed sampler + /// \param[in] dev_id Device id of the shard + /// \param[in] shuffle Option to shuffle + /// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will + /// result in different samples being picked + /// \param even_dist The option to indicate whether or not each shard returns the same number of rows. + /// This option is not exposed in the python API. Current behavior is that the remainder will always + /// be handled by the first n shards, n being the corresponding device id. DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, - uint32_t seed = std::numeric_limits::max()); + uint32_t seed = std::numeric_limits::max(), bool even_dist = true); - // default destructor + /// \brief default destructor ~DistributedSampler() = default; - // @param std::unique_ptr * pBuffer - // @param int32_t workerId - // @return - The error code return + /// \param std::unique_ptr * pBuffer + /// \param int32_t workerId + /// \return Status code Status GetNextSample(std::unique_ptr *out_buffer) override; - // Init sampler, called by base class or python + /// Init sampler, called by base class or python Status InitSampler() override; - // for next epoch of sampleIds - // @return - The error code return + /// \brief for next epoch of sampleIds + /// \return Status code Status ResetSampler() override; void Print(std::ostream &out, bool show_all) const override; @@ -59,8 +65,9 @@ class DistributedSampler : public Sampler { bool shuffle_; std::mt19937 rnd_; std::vector shuffle_vec_; + bool even_dist_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h index aed61fa273..e51c419cd4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ #include #include @@ -73,4 +73,4 @@ class PKSampler : public Sampler { // NOT YET FINISHED } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc index 50c67bca6c..a501a2dcb0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -41,7 +41,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { try { py::object py_ret = py_sampler_instance.attr("_get_indices")(); py::array np_sample_ids = py_ret.cast(); - Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + Tensor::CreateFromNpArray(np_sample_ids, &sample_ids); // copy numpy to tensor if (HasChildSampler()) { for (auto it = sample_ids->begin(); it != sample_ids->end(); ++it) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h index 61716feb94..0700edee27 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ #include #include @@ -63,4 +63,4 @@ class PythonSampler : public Sampler { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h index 6e21b088b9..fe5330a42f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ #include #include @@ -63,4 +63,4 @@ class RandomSampler : public Sampler { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index 60d75d2eec..f13e8122c8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -73,9 +73,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t col_desc_ = std::make_unique("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); } TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type())); - RETURN_IF_NOT_OK( - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes())); // allocate memory in case user forgets! + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc_->type(), sample_ids)); return Status::OK(); } @@ -103,6 +101,13 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { // check this buffer is not a ctrl buffer CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); + + // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch + RETURN_IF_NOT_OK(GetNextSample(&db)); + CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); + // Reset Sampler since this is the end of the epoch + RETURN_IF_NOT_OK(ResetSampler()); + { py::gil_scoped_acquire gil_acquire; if (Py_IsInitialized() == 0) { @@ -114,11 +119,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { return Status(StatusCode::kPyFuncException, e.what()); } } - // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch - RETURN_IF_NOT_OK(GetNextSample(&db)); - CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); - // Reset Sampler since this is the end of the epoch - RETURN_IF_NOT_OK(ResetSampler()); return Status::OK(); } #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h index 4cae935a42..268eb1256d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ #include #include @@ -158,4 +158,4 @@ class Sampler { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h index c6ccd0d1eb..2a313347f1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ #include #include @@ -62,4 +62,4 @@ class SequentialSampler : public Sampler { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index fccc15e57b..0a1feef0a9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ #include #include @@ -72,4 +72,4 @@ class SubsetRandomSampler : public Sampler { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index b1a531abe9..9bcb2bac22 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ #include #include diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index c1f5b13a94..f069139859 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/engine/datasetops/source/text_file_op.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/util/task_manager.h" @@ -58,7 +58,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { // Throttle the number of workers if we have more workers than files! if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { builder_num_workers_ = builder_text_files_list_.size(); - MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + MS_LOG(DEBUG) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; } builder_schema_ = std::make_unique(); @@ -98,8 +98,6 @@ TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t tot // A print method typically used for debugging void TextFileOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -146,7 +144,7 @@ Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptrpush_back(std::move(tRow)); std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); (**tensor_table)[row][0] = std::move(tensor); return Status::OK(); } @@ -421,13 +419,14 @@ Status TextFileOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index 68c226ab80..9dfb4ac2ae 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ #include #include @@ -286,4 +286,4 @@ class TextFileOp : public ParallelOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index ae7907b5ce..98c3622415 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -27,7 +27,7 @@ #include "proto/example.pb.h" #include "./securec.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/connector.h" @@ -158,8 +158,6 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 // A print method typically used for debugging void TFReaderOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -310,13 +308,14 @@ Status TFReaderOp::operator()() { std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { finished_reading_dataset_ = true; NotifyToFillIOBlockQueue(); } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; } + UpdateRepeatAndEpochCounter(); } std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); @@ -677,8 +676,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table // into the tensor TensorShape current_shape = TensorShape::CreateUnknownRankShape(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, ¤t_shape)); - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&ts, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(current_shape, current_col.type(), data_ptr, &ts)); break; } case dataengine::Feature::KindCase::kInt64List: { @@ -735,7 +733,7 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataeng if (current_col.type() == DataType::DE_STRING) { TensorShape shape = TensorShape::CreateScalar(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape)); + RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, shape, tensor)); return Status::OK(); } @@ -763,7 +761,7 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataeng // know how many elements there are and the total bytes, create tensor here: TensorShape current_shape = TensorShape::CreateScalar(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, ¤t_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, current_shape, current_col.type(), pad_size)); + RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, current_shape, current_col.type(), pad_size, tensor)); return Status::OK(); } @@ -836,10 +834,7 @@ Status TFReaderOp::LoadIntList(const ColDescriptor ¤t_col, const dataengin // know how many elements there are, create tensor here: TensorShape current_shape = TensorShape::CreateUnknownRankShape(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, ¤t_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type())); - - // Tensors are lazily allocated, this eagerly allocates memory for the tensor. - RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(current_shape, current_col.type(), tensor)); int64_t i = 0; auto it = (*tensor)->begin(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h index c03f3957e9..8a295291b4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ #include #include @@ -417,4 +417,4 @@ class TFReaderOp : public ParallelOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index e90d423ef4..49f25aa695 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -19,7 +19,7 @@ #include #include #include "./tinyxml2.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" @@ -34,7 +34,10 @@ namespace mindspore { namespace dataset { const char kColumnImage[] = "image"; const char kColumnTarget[] = "target"; -const char kColumnAnnotation[] = "annotation"; +const char kColumnBbox[] = "bbox"; +const char kColumnLabel[] = "label"; +const char kColumnDifficult[] = "difficult"; +const char kColumnTruncate[] = "truncate"; const char kJPEGImagesFolder[] = "/JPEGImages/"; const char kSegmentationClassFolder[] = "/SegmentationClass/"; const char kAnnotationsFolder[] = "/Annotations/"; @@ -70,7 +73,13 @@ Status VOCOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(builder_schema_->AddColumn( ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); } *ptr = std::make_shared(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, @@ -136,7 +145,7 @@ Status VOCOp::operator()() { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + if (IsLastIteration()) { std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); @@ -153,12 +162,11 @@ Status VOCOp::operator()() { wp_.Clear(); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + UpdateRepeatAndEpochCounter(); } } void VOCOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info ParallelOp::Print(out, show_all); @@ -190,14 +198,16 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target)); (*trow) = TensorRow(row_id, {std::move(image), std::move(target)}); } else if (task_type_ == TaskType::Detection) { - std::shared_ptr image, annotation; + std::shared_ptr image; + TensorRow annotation; const std::string kImageFile = folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); const std::string kAnnotationFile = folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation)); - (*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)}); + RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation)); + trow->push_back(std::move(image)); + trow->insert(trow->end(), annotation.begin(), annotation.end()); } return Status::OK(); } @@ -271,7 +281,7 @@ Status VOCOp::ParseAnnotationIds() { const std::string kAnnotationName = folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension); RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName)); - if (label_map_.find(kAnnotationName) != label_map_.end()) { + if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) { new_image_ids.push_back(id); } } @@ -293,7 +303,7 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) { if (!Path(path).Exists()) { RETURN_STATUS_UNEXPECTED("File is not found : " + path); } - Bbox bbox; + Annotation annotation; XMLDocument doc; XMLError e = doc.LoadFile(common::SafeCStr(path)); if (e != XMLError::XML_SUCCESS) { @@ -332,13 +342,13 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) { } if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 && ymin > 0 && xmax > xmin && ymax > ymin) { - std::vector bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult}; - bbox.emplace_back(std::make_pair(label_name, bbox_list)); + std::vector bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, difficult, truncated}; + annotation.emplace_back(std::make_pair(label_name, bbox_list)); label_index_[label_name] = 0; } object = object->NextSiblingElement("object"); } - if (bbox.size() > 0) label_map_[path] = bbox; + if (annotation.size() > 0) annotation_map_[path] = annotation; return Status::OK(); } @@ -364,7 +374,7 @@ Status VOCOp::LaunchThreadsAndInitOp() { } Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor)); if (decode_ == true) { Status rc = Decode(*tensor, tensor); if (rc.IsError()) { @@ -374,31 +384,38 @@ Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &co return Status::OK(); } -Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, - std::shared_ptr *tensor) { - Bbox bbox_info = label_map_[path]; - std::vector bbox_row; - dsize_t bbox_column_num = 0, bbox_num = 0; - for (auto box : bbox_info) { - if (label_index_.find(box.first) != label_index_.end()) { - std::vector bbox; - bbox.insert(bbox.end(), box.second.begin(), box.second.end()); - if (class_index_.find(box.first) != class_index_.end()) { - bbox.push_back(static_cast(class_index_[box.first])); +// When task is Detection, user can get bbox data with four columns: +// column ["bbox"] with datatype=float32 +// column ["label"] with datatype=uint32 +// column ["difficult"] with datatype=uint32 +// column ["truncate"] with datatype=uint32 +Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) { + Annotation annotation = annotation_map_[path]; + std::shared_ptr bbox, label, difficult, truncate; + std::vector bbox_data; + std::vector label_data, difficult_data, truncate_data; + dsize_t bbox_num = 0; + for (auto item : annotation) { + if (label_index_.find(item.first) != label_index_.end()) { + if (class_index_.find(item.first) != class_index_.end()) { + label_data.push_back(static_cast(class_index_[item.first])); } else { - bbox.push_back(static_cast(label_index_[box.first])); - } - bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); - if (bbox_column_num == 0) { - bbox_column_num = static_cast(bbox.size()); + label_data.push_back(static_cast(label_index_[item.first])); } + CHECK_FAIL_RETURN_UNEXPECTED(item.second.size() == 6, "annotation only support 6 parameters."); + + std::vector tmp_bbox = {(item.second)[0], (item.second)[1], (item.second)[2], (item.second)[3]}; + bbox_data.insert(bbox_data.end(), tmp_bbox.begin(), tmp_bbox.end()); + difficult_data.push_back(static_cast((item.second)[4])); + truncate_data.push_back(static_cast((item.second)[5])); bbox_num++; } } - - std::vector bbox_dim = {bbox_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(), - reinterpret_cast(&bbox_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(bbox_data, TensorShape({bbox_num, 4}), &bbox)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(label_data, TensorShape({bbox_num, 1}), &label)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(difficult_data, TensorShape({bbox_num, 1}), &difficult)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(truncate_data, TensorShape({bbox_num, 1}), &truncate)); + (*row) = TensorRow({std::move(bbox), std::move(label), std::move(difficult), std::move(truncate)}); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h index e0c46c7a94..ef5578f467 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ #include #include @@ -40,7 +40,7 @@ namespace dataset { template class Queue; -using Bbox = std::vector>>; +using Annotation = std::vector>>; class VOCOp : public ParallelOp, public RandomAccessOp { public: @@ -234,10 +234,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp { Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); // @param const std::string &path - path to the image file - // @param const ColDescriptor &col - contains tensor implementation and datatype - // @param std::shared_ptr tensor - return + // @param TensorRow *row - return // @return Status - The error code return - Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); + Status ReadAnnotationToTensor(const std::string &path, TensorRow *row); // @param const std::vector &keys - keys in ioblock // @param std::unique_ptr db @@ -287,8 +286,8 @@ class VOCOp : public ParallelOp, public RandomAccessOp { QueueList> io_block_queues_; std::map class_index_; std::map label_index_; - std::map label_map_; + std::map annotation_map_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc index d1f07983f7..f754b4898a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc @@ -16,7 +16,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/datasetops/take_op.h" @@ -53,8 +53,6 @@ TakeOp::TakeOp(int32_t count, int32_t op_connector_size) // A print method typically used for debugging void TakeOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); @@ -84,6 +82,7 @@ Status TakeOp::operator()() { // Loop until non EOE is received if (buf->eoe()) { + UpdateRepeatAndEpochCounter(); take_count_ = 0; RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h index 7f3f821bd8..d055207520 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ #include #include @@ -86,7 +86,7 @@ class TakeOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "TakeOp"; } + std::string Name() const override { return kTakeOp; } private: int32_t max_takes_; // The number of takes that the user requested @@ -97,4 +97,4 @@ class TakeOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc index 88019c30fc..1b6a0ecb79 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc @@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) { if (eof_) { return Status::OK(); } + // One of our child iterators encounter EOE. Returns and proceed with draining phase. if (new_row.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!"); + return Status::OK(); } // Pack this first row into our tensor table @@ -207,8 +208,6 @@ Status ZipOp::drainPipeline() { // A function that prints info about the Operator void ZipOp::Print(std::ostream &out, // In: The output stream to print to bool show_all) const { // In: T/F if it should print everything - // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; if (!show_all) { // Call the super class for displaying any common 1-liner info PipelineOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h index c9466e26e2..2995b49c23 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ #include #include @@ -112,7 +112,7 @@ class ZipOp : public PipelineOp { // Op name getter // @return Name of the current Op - std::string Name() const override { return "ZipOp"; } + std::string Name() const override { return kZipOp; } private: // Handles preprocessing of the main loop, used when starting new epoch @@ -155,4 +155,4 @@ class ZipOp : public PipelineOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/db_connector.h b/mindspore/ccsrc/minddata/dataset/engine/db_connector.h index 4a5c20bc12..2d2cf6d226 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/db_connector.h +++ b/mindspore/ccsrc/minddata/dataset/engine/db_connector.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_DB_CONNECTOR_H_ -#define DATASET_ENGINE_DB_CONNECTOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DB_CONNECTOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DB_CONNECTOR_H_ #include #include @@ -95,4 +95,4 @@ class DbConnector : public Connector> { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_DB_CONNECTOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DB_CONNECTOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index 55dec24e79..79c954595e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -23,6 +23,7 @@ #include "minddata/dataset/engine/opt/pre/removal_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/opt/post/repeat_pass.h" +#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/monitor.h" @@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr &op) { if (op->tree_ == this) { return Status::OK(); } - if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { + if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) { std::string err_msg = "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast(tree_state_)) + " Expected states: " + std::to_string(static_cast(kDeTStateInit)) + " or " + - std::to_string(static_cast(kDeTStateBuilding)); + std::to_string(static_cast(kDeTStateBuilding)) + " or " + std::to_string(static_cast(kDeTStatePrepare)); RETURN_STATUS_UNEXPECTED(err_msg); } @@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::functionPrepareTreePreAction()); @@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() { std::vector> pre_actions; // Construct pre actions MS_LOG(INFO) << "Running pre pass loops."; + pre_actions.push_back(std::make_unique()); pre_actions.push_back(std::make_unique()); pre_actions.push_back(std::make_unique()); // Apply pre action passes @@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() { " Expected state: " + std::to_string(static_cast(kDeTStatePrepare)); RETURN_STATUS_UNEXPECTED(err_msg); } + + if (root_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree."); + } + // Start the recursive prepare RETURN_IF_NOT_OK(this->PrepareNode(root_)); tree_state_ = kDeTStateReady; diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h index b62bf8e85d..41d51733c1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_EXECUTION_TREE_H_ -#define DATASET_ENGINE_EXECUTION_TREE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_ #include #include @@ -176,7 +176,7 @@ class ExecutionTree { // For example, repeatOp inlining // // @return Status - The error code return - Status Prepare(); + Status Prepare(int num_epochs = -1); // Compulsory transformation/action pre optimization. // @return Status - The error code return @@ -193,6 +193,7 @@ class ExecutionTree { // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively // walk the tree to perform modifications to the tree or specific nodes within the tree to get // it ready for execution. + // @param Total number of epochs that will be run on this tree // @return Status - The error code return Status PrepareDeprecated(); @@ -231,6 +232,10 @@ class ExecutionTree { // Optional optimizations status bool OptimizationEnabled() const { return optimize_; } + // Getter function to get the total number of epochs to be run on this tree. + // @return total number of epochs + int32_t num_epochs() { return num_epochs_; } + private: // A helper functions for doing the recursive printing // @param dataset_op - The dataset op to print @@ -245,6 +250,7 @@ class ExecutionTree { int32_t id_count_; // Counter for generating operator id's uint32_t prepare_flags_; // Flags used during tree prepare TreeState tree_state_; // Tracking the current tree state + int32_t num_epochs_; // Total number of epochs to run for this tree std::unique_ptr perf_monitor_; // Performance Monitor std::unique_ptr profiling_manager_; // Profiling manager bool optimize_; // Flag to enable optional optimizations @@ -254,4 +260,4 @@ inline bool operator==(const ExecutionTree::Iterator &lhs, const ExecutionTree:: } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_EXECUTION_TREE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h index c62c088bab..e7f4eef793 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_EDGE_H_ -#define DATASET_ENGINE_GNN_EDGE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_EDGE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_EDGE_H_ #include #include @@ -83,4 +83,4 @@ class Edge { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_EDGE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_EDGE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h index 0d7eba1009..0151ada706 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_FEATURE_H_ -#define DATASET_ENGINE_GNN_FEATURE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_FEATURE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_FEATURE_H_ #include @@ -49,4 +49,4 @@ class Feature { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_FEATURE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_FEATURE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc index 9083eb4c4b..7cbfedcf46 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc @@ -57,8 +57,7 @@ Status Graph::CreateTensorByVector(const std::vector> &data, Data std::shared_ptr tensor; size_t m = data.size(); size_t n = data[0].size(); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast(m), static_cast(n)}), type, &tensor)); auto ptr = tensor->begin(); for (const auto &id_m : data) { CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); @@ -310,8 +309,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); shape = shape.PrependDim(size); std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->Value()->type(), &fea_tensor)); dsize_t index = 0; for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { @@ -358,8 +356,7 @@ Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::ve dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); shape = shape.PrependDim(size); std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->Value()->type(), &fea_tensor)); dsize_t index = 0; for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h index 76930d91f2..cb755b0bed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_GRAPH_H_ -#define DATASET_ENGINE_GNN_GRAPH_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ #include #include @@ -264,4 +264,4 @@ class Graph { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_GRAPH_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 9d2c6211f4..2339b02de2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -125,7 +125,7 @@ Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrec (*feature_map)[node_type].insert(ind); if ((*default_feature)[ind] == nullptr) { std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); RETURN_IF_NOT_OK(zero_tensor->Zero()); (*default_feature)[ind] = std::make_shared(ind, zero_tensor); } @@ -151,7 +151,7 @@ Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrec (*feature_map)[edge_type].insert(ind); if ((*default_feature)[ind] == nullptr) { std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); RETURN_IF_NOT_OK(zero_tensor->Zero()); (*default_feature)[ind] = std::make_shared(ind, zero_tensor); } @@ -170,9 +170,9 @@ Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector< key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, - std::move(TensorShape({static_cast(n_bytes / col_type_size)})), - std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast(n_bytes / col_type_size)})), + std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), + data, tensor)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h index f7f9245b8a..e59b13837c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_ -#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_ #include #include @@ -126,4 +126,4 @@ class GraphLoader { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h index d112972f8f..e9c7ba7f0e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_ -#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_EDGE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_EDGE_H_ #include #include @@ -57,4 +57,4 @@ class LocalEdge : public Edge { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_EDGE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h index 9c122931e7..350797ac75 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_ -#define DATASET_ENGINE_GNN_LOCAL_NODE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_NODE_H_ #include #include @@ -79,4 +79,4 @@ class LocalNode : public Node { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_LOCAL_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h index a7c803fee2..c89bb0e905 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_GNN_NODE_H_ -#define DATASET_ENGINE_GNN_NODE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_NODE_H_ #include #include @@ -84,4 +84,4 @@ class Node { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_GNN_NODE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h b/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h index cee0b7abf3..bfd543da56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h +++ b/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_JAGGED_CONNECTOR_H_ -#define DATASET_ENGINE_JAGGED_CONNECTOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_JAGGED_CONNECTOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_JAGGED_CONNECTOR_H_ #include #include @@ -85,4 +85,4 @@ class JaggedConnector : public Connector> { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_JAGGED_CONNECTOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_JAGGED_CONNECTOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 0ab1fb7925..50346ffad8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -3,9 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE add_library(engine-opt OBJECT pass.cc post/repeat_pass.cc - pre/cache_pass.cc pre/cache_transform_pass.cc - pre/removal_nodes.cc + pre/epoch_injection_pass.cc pre/removal_pass.cc optional/tensor_op_fusion_pass.cc util/printer_pass.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc index d8ce2dd863..fc0eb027b6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc @@ -17,7 +17,7 @@ #include #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/kernels/image/decode_op.h" -#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" #include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" namespace mindspore { diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h index a109af396c..614c267c05 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TENSOR_OP_FUSION_PASS_H_ -#define DATASET_TENSOR_OP_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TENSOR_OP_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TENSOR_OP_FUSION_PASS_H_ #include #include "minddata/dataset/engine/opt/pass.h" @@ -35,4 +35,4 @@ class TensorOpFusionPass : public NodePass { } // namespace dataset } // namespace mindspore -#endif // DATASET_TENSOR_OP_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TENSOR_OP_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index 4a8bbaf38f..4a2041e63d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -16,12 +16,15 @@ #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/build_vocab_op.h" +#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" -#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" #include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/rename_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" @@ -31,10 +34,14 @@ #include "minddata/dataset/engine/datasetops/source/cifar_op.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#endif #include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#endif #include "minddata/dataset/engine/datasetops/source/voc_op.h" #ifdef ENABLE_PYTHON #include "minddata/dataset/engine/datasetops/filter_op.h" @@ -133,6 +140,7 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { return RunOnNode(std::static_pointer_cast(node), modified); } +#ifndef ENABLE_ANDROID Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); @@ -142,6 +150,7 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); } +#endif #ifdef ENABLE_PYTHON Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { @@ -153,6 +162,16 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); } + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} #endif Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { @@ -190,21 +209,11 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { return RunOnNode(std::static_pointer_cast(node), modified); } -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); } -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); @@ -230,6 +239,11 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) return RunOnNode(std::static_pointer_cast(node), modified); } +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return PreRunOnNode(std::static_pointer_cast(node), modified); @@ -244,5 +258,20 @@ Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified // Fallback to base class visitor by default return PreRunOnNode(std::static_pointer_cast(node), modified); } + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index 845ab34d66..f154b6c205 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_OPT_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_ #include #include @@ -37,9 +37,11 @@ class SkipOp; class ShuffleOp; +#ifndef ENABLE_ANDROID class MindRecordOp; class TFReaderOp; +#endif #ifdef ENABLE_PYTHON class FilterOp; @@ -77,6 +79,12 @@ class CacheMergeOp; class CacheLookupOp; +class EpochCtrlOp; + +class BuildVocabOp; + +class BuildSentencePieceVocabOp; + // The base class Pass is the basic unit of tree transformation. // The actual implementation of the passes will be derived from here. class Pass : public std::enable_shared_from_this { @@ -85,6 +93,8 @@ class Pass : public std::enable_shared_from_this { // @param tree - Pointer to the execution tree to be transformed. // @param modified - Pointer to the modified flag, virtual Status Run(ExecutionTree *tree, bool *modified) = 0; + + virtual ~Pass() = default; }; // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. @@ -150,14 +160,20 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); +#ifndef ENABLE_ANDROID virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status RunOnNode(std::shared_ptr node, bool *modified); +#endif #ifdef ENABLE_PYTHON virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); #endif virtual Status RunOnNode(std::shared_ptr node, bool *modified); @@ -174,12 +190,8 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status RunOnNode(std::shared_ptr node, bool *modified); @@ -190,12 +202,20 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + private: // Helper function to perform DFS visit Status DFSNodeVisit(std::shared_ptr node, bool *modified); @@ -210,4 +230,4 @@ class NodePass : public Pass { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_OPT_PASS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index 59a3f71c53..aac0eaa2e9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -20,19 +20,68 @@ #include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" namespace mindspore { namespace dataset { -RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} +RepeatPass::RepeatPass() + : is_repeated_(false), + nested_repeats_(0), + num_repeats_(1), + num_epochs_(1), + is_merge_(false), + is_cached_(false), + cache_lookup_(nullptr) {} // Identifies the subtree below this node as being in a repeated path of the tree. Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Create a new stack for eoe operators and push onto our stack of stacks. + std::unique_ptr new_stack = std::make_unique(); + eoe_op_stacks_.push(std::move(new_stack)); // If we are already repeated, then this is a nested repeat. if (is_repeated_) { nested_repeats_++; } is_repeated_ = true; + + // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. + // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. + if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { + num_repeats_ = -num_repeats_; + } + // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times. + // + // Consider this example: + // tfreader --> map --> repeat(2) --> epoch ctrl(3) + // num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3), + // meaning repeat op should be set to read 6 times (2*3), do does map op and tfreader op. + // + // Another example: + // tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4) + // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4), + // meaning repeat2 and map op should be set to read 8 times (2*4). + // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times. + num_repeats_ *= node->num_repeats(); + return Status::OK(); +} + +// Identifies the subtree below this node as being in a repeated path of the tree. +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // EpochCtrl is derived from RepeatOp. Generally it should do the identical setup + // that RepeatOp does. However, epoch control is actually simpler because it can + // only exist as the root node so it doesn't need all the nested code. + // Create a new stack for eoe operators and push onto our stack of stacks. + std::unique_ptr new_stack = std::make_unique(); + eoe_op_stacks_.push(std::move(new_stack)); + is_repeated_ = true; + // Get the total number of epochs from the EpochCtrlOp parameter + num_epochs_ = node->num_repeats(); + // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. + // For example: tfreader --> epoch ctrl(3) + // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3), + // meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op. + num_repeats_ *= num_epochs_; return Status::OK(); } @@ -43,42 +92,86 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modifi return Status::OK(); } +// Identifies the subtree below this node as being cached +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that we're under a merge op + is_cached_ = true; + return Status::OK(); +} + // Hooks up any identified eoe nodes under this repeat. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { node->AddToEoeList(leaf_op); leaf_op = PopFromEOEOpStack(); } + // At this point, we are done with the save area stack. It's a unique pointer to an empty stack + // at this time, so we can pop it to get rid of it. + op_stack *current_stack = eoe_op_stacks_.top().get(); + if (!current_stack->empty()) { + RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); + } + eoe_op_stacks_.pop(); + // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up - // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. + // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed + // from the save area, because the merge op above us may also take action on it later for a different + // case when there is no repeat in the merge leg. if (is_merge_ && cache_lookup_) { - cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); + cache_lookup_->set_total_repeats(num_repeats_); + cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); node->AddToEoeList(std::move(cache_lookup_)); } // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. if (nested_repeats_ > 0) { - node->set_control_flag(DatasetOp::kDeOpRepeated); AddToEOEOpStack(node); nested_repeats_--; - } - - // If we are not nested, or we were the top-most repeat, now we clear the flag - if (nested_repeats_ == 0) { + } else { + // If we are not nested, or we were the top-most repeat, now we clear the flag + if (nested_repeats_ != 0) { + RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); + } is_repeated_ = false; } + if (is_cached_) { + AddToCachedOpStack(node); + } + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + // We finish the walk of this RepeatOp's descendent nodes. + // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. + // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, + // so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. + num_repeats_ /= node->num_repeats(); + return Status::OK(); +} +// Hooks up any identified eoe nodes under this repeat. +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Pop the leaf ops from the save-area stack and add them to the eoe node tracking + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + node->AddToEoeList(leaf_op); + leaf_op = PopFromEOEOpStack(); + } + is_repeated_ = false; + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + // We finish the walk of this EpochCtrl's descendent nodes. + num_repeats_ /= node->num_repeats(); return Status::OK(); } // CacheOp removes previous leaf ops and replaces them with itself Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + is_cached_ = false; if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); // if we are a cache within a repeat path of the tree, then there will be // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the // repeat or epoch ctrl operators can work with them for repeat activity during runtime. @@ -90,13 +183,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // the repeating behaviours shall be invoked against the cache op. std::shared_ptr leaf_op = PopFromEOEOpStack(); while (leaf_op != nullptr) { - leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); - leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); leaf_op = PopFromEOEOpStack(); } AddToEOEOpStack(std::static_pointer_cast(node)); + + // adjust the total epochs and total repeats for ops under this cache op + std::shared_ptr cached_op = PopFromCachedOpStack(); + while (cached_op != nullptr) { + int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; + cached_op->set_total_repeats(cached_op_total_repeats); + // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 + cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); + cached_op = PopFromCachedOpStack(); + } } + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); return Status::OK(); } @@ -105,22 +208,36 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // If we are in a repeat path, then set our repeated flag if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - // if we are a leaf node then save ourself in a stack for the repeat operator above us if (node->IsLeaf()) { AddToEOEOpStack(node); } } + if (is_cached_) { + AddToCachedOpStack(node); + } + // Set total repeats and total epochs for the node + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); return Status::OK(); } // Turns off the tracking for operations under merge op Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Setting the flag is needed since we didn't call the base class DatasetOp version - if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); + if (is_repeated_) { + // If there was not any repeat in the merge cache miss leg, then the cache_lookup + // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack + if (cache_lookup_) { + cache_lookup_->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + AddToEOEOpStack(std::move(cache_lookup_)); + } + } + node->set_total_repeats(num_repeats_); + node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used is_merge_ = false; - cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed return Status::OK(); } @@ -131,29 +248,42 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); } - // If we are in a repeat path already, then there must be a repeat above the merge op - // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. - if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - AddToEOEOpStack(node); - } else { - // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we - // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself - // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. - cache_lookup_ = std::static_pointer_cast(node); - } + // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we + // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself + // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. + // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will + // add the lookup to the eoe stack + cache_lookup_ = std::static_pointer_cast(node); + return Status::OK(); } // Adds an operator to the eoe operator stack save area -void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } +void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { + op_stack *current_stack = eoe_op_stacks_.top().get(); + current_stack->push(dataset_op); +} // Pops an operator from the eoe operator stack save area std::shared_ptr RepeatPass::PopFromEOEOpStack() { std::shared_ptr top_op = nullptr; - if (!eoe_stack_.empty()) { - top_op = eoe_stack_.top(); - eoe_stack_.pop(); + op_stack *current_stack = eoe_op_stacks_.top().get(); + if (current_stack != nullptr && !current_stack->empty()) { + top_op = current_stack->top(); + current_stack->pop(); + } + return top_op; +} + +// Adds an operator to the cached operator stack save area +void RepeatPass::AddToCachedOpStack(std::shared_ptr dataset_op) { cached_op_stacks_.push(dataset_op); } + +// Pops an operator from the cached operator stack save area +std::shared_ptr RepeatPass::PopFromCachedOpStack() { + std::shared_ptr top_op = nullptr; + if (!cached_op_stacks_.empty()) { + top_op = cached_op_stacks_.top(); + cached_op_stacks_.pop(); } return top_op; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h index 9b733e2329..1e865eadac 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ -#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ #include #include @@ -30,27 +30,50 @@ namespace dataset { /// to the eoe-producing (typically leaf) nodes underneath it. class RepeatPass : public NodePass { public: + using op_stack = std::stack>; + /// \brief Constructor RepeatPass(); + /// \brief Destructor + ~RepeatPass() = default; + /// \brief Identifies the subtree below this node as being in a repeated path of the tree. /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all /// \return Status The error code return Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Identifies the subtree below this node as being in a repeated path of the tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Identifies the subtree below this node as being in a cache merge path /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all /// \return Status The error code return Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Identifies the subtree below this node as being cached + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Hooks up any identified eoe nodes under this repeat. /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all /// \return Status The error code return Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Hooks up any identified eoe nodes under this repeat. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief CacheOp removes previous leaf ops and replaces them with itself /// \param[in] node The node being visited /// \param[inout] modified Indicator if the node was changed at all @@ -86,13 +109,26 @@ class RepeatPass : public NodePass { /// \return shared_ptr to the popped operator std::shared_ptr PopFromEOEOpStack(); - bool is_repeated_; // T/F if we are processing under a repeat - bool is_merge_; // T/F if we are processing under a cache merge op - int32_t nested_repeats_; // A counter for nested repeats - std::stack> eoe_stack_; // A save area for leaf/eoe ops - std::shared_ptr cache_lookup_; // A save area for a cache lookup op + /// \brief Adds an operator to the cached operator stack save area + /// \param op - The dataset op to work add to cached stack + /// \return Status - The error code return + void AddToCachedOpStack(std::shared_ptr dataset_op); + + /// \brief Pops an operator from the cached operator stack save area + /// \return shared_ptr to the popped operator + std::shared_ptr PopFromCachedOpStack(); + + bool is_repeated_; // T/F if we are processing under a repeat + bool is_merge_; // T/F if we are processing under a cache merge op + bool is_cached_; // T/F is we are processing under a cache op + int32_t nested_repeats_; // A counter for nested repeats + int32_t num_repeats_; // A multiplier to the total number of repeats + int32_t num_epochs_; // To save the total number of epochs + std::stack> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) + op_stack cached_op_stacks_; // A save area for ops under a cache op + std::shared_ptr cache_lookup_; // A save area for a cache lookup op }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc deleted file mode 100644 index 09b5f14a17..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "minddata/dataset/engine/opt/pre/cache_pass.h" -#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" -#include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/datasetops/source/celeba_op.h" -#include "minddata/dataset/engine/datasetops/source/generator_op.h" -#include "minddata/dataset/engine/datasetops/source/manifest_op.h" -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/voc_op.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" -#include "minddata/dataset/engine/datasetops/source/coco_op.h" -#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" -#include "minddata/dataset/engine/datasetops/source/random_data_op.h" -#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" -#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" - -namespace mindspore { -namespace dataset { - -// Constructor -CachePass::CachePass(CacheTransformPass *transform_pass) - : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} - -// Identifies the subtree below this node as a cached descendant tree. -Status CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); - } - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache -// transformation -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - is_caching_ = false; // We a no longer in a cache subtree. clear the flag. - if (leaf_op_) { - MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; - // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, - // using base class pointers. - transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); - } else { - // If there was no leaf_op set, then this is a non-mappable scenario. - - if (sampler_) { - // Grab the sampler that was saved from the leaf and plug it into the cache op - node->SetSampler(std::move(sampler_)); - MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; - } else { - // We're a cache op but no sampler was saved from leaf, so create a default sampler - int64_t num_samples = 0; - int64_t start_index = 0; - sampler_ = std::make_shared(num_samples, start_index); - node->SetSampler(std::move(sampler_)); - MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; - } - - // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache - uint32_t cache_crc = DatasetOp::GenerateCRC(node); - RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); - } - - return Status::OK(); -} - -// Common code for mappable leaf setup. -Status CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { - // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. - if (is_caching_ && leaf_op_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); - } - - // If we are a leaf in the caching path, then save this leaf. - if (is_caching_) { - MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; - leaf_op_ = std::move(leaf_op); - } - return Status::OK(); -} - -// Common code for non mappable leaf setup. -Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { - // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. - if (is_caching_ && leaf_op_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); - } - - // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf - // as save it for use by cache op in ascendant tree. - if (is_caching_) { - RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); - MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; - } else { - // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can - // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) - std::shared_ptr sampler_from_leaf; - RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); - } - return Status::OK(); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic - // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. - node->MakeSimpleProducer(); - } - return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h deleted file mode 100644 index cbc805cd3e..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ - -#include -#include -#include -#include "minddata/dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class CacheTransformPass; - -/// \class CachePass cache_pass.h -/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache -/// transformation. It works in conjunction with the CacheTransformPass -class CachePass : public NodePass { - public: - /// \brief Constructor - /// \param[in] transform_pass Raw pointer back to controlling tree pass - explicit CachePass(CacheTransformPass *transform_pass); - - /// \brief Identifies the subtree below this node as a cached descendant tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache - /// transformation - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - /// \brief Common code for mappable leaf setup. - /// \param[in] node The leaf node performing setup work. - /// \return Status The error code return - Status MappableCacheLeafSetup(std::shared_ptr leaf_op); - - /// \brief Common code for non-mappable leaf setup. - /// \param[in] node The leaf node performing setup work. - /// \return Status The error code return - Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op); - - bool is_caching_; - std::shared_ptr leaf_op_; - std::shared_ptr sampler_; - CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index 033150e8f4..8a463aecfa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -15,17 +15,193 @@ */ #include -#include "minddata/dataset/engine/opt/pre/cache_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" + +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#endif + +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" + +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#endif + +#ifdef ENABLE_PYTHON +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#endif namespace mindspore { namespace dataset { +// Constructor +CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); + } + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache +// transformation +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + is_caching_ = false; // We a no longer in a cache subtree. clear the flag. + if (leaf_op_) { + MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; + // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, + // using base class pointers. + AddMappableCacheOperators(std::move(leaf_op_), node); + } else { + // If there was no leaf_op set, then this is a non-mappable scenario. + + if (sampler_) { + // Grab the sampler that was saved from the leaf and plug it into the cache op + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; + } else { + // We're a cache op but no sampler was saved from leaf, so create a default sampler + int64_t num_samples = 0; + int64_t start_index = 0; + sampler_ = std::make_shared(num_samples, start_index); + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; + } + + // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache + uint32_t cache_crc = DatasetOp::GenerateCRC(node); + RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); + } + + return Status::OK(); +} + +// Common code for mappable leaf setup. +Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // If we are a leaf in the caching path, then save this leaf. + if (is_caching_) { + MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; + leaf_op_ = std::move(leaf_op); + } + return Status::OK(); +} + +// Common code for non mappable leaf setup. +Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf + // as save it for use by cache op in ascendant tree. + if (is_caching_) { + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); + MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; + } else { + // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can + // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) + std::shared_ptr sampler_from_leaf; + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); + } + return Status::OK(); +} + +#ifndef ENABLE_ANDROID +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic + // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} +#endif + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +#ifndef ENABLE_ANDROID +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} +#endif + +#ifdef ENABLE_PYTHON +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} +#endif + +// Assigns the leaf and cache operators that are involved in a cache transformation +void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr leaf_op, + std::shared_ptr cache_op) { + cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); +} + // constructor CacheTransformPass::CacheTransformPass() {} @@ -34,11 +210,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { MS_LOG(INFO) << "Pre pass: Cache transform pass started."; // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will // use to execute a transform. - std::unique_ptr cache_pass = std::make_unique(this); - RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); + CachePass cache_pass = CachePass(); + RETURN_IF_NOT_OK(cache_pass.Run(tree, modified)); // Then, execute the transform for each pair - for (auto cache_pair : cache_pairs_) { + for (auto cache_pair : cache_pass.cache_pairs()) { MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); } @@ -98,11 +274,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share return Status::OK(); } - -// Assigns the leaf and cache operators that are involved in a cache transformation -void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr leaf_op, - std::shared_ptr cache_op) { - cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h index 02c22c4472..970461d48f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ #include #include @@ -33,21 +33,143 @@ class CacheClient; /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching /// operations class CacheTransformPass : public TreePass { + /// \class CachePass + /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache + /// transformation. It works in conjunction with the CacheTransformPass + class CachePass : public NodePass { + public: + /// \brief Constructor + /// \param[in] transform_pass Raw pointer back to controlling tree pass + CachePass(); + + /// \brief Destructor + ~CachePass() = default; + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree and assigns the operators that + /// will be involved in a cache transformation + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + +#ifndef ENABLE_ANDROID + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + +#ifdef ENABLE_PYTHON + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + +#ifndef ENABLE_ANDROID + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + /// \brief Getter + std::vector, std::shared_ptr>> cache_pairs() { return cache_pairs_; } + + private: + /// \brief Common code for mappable leaf setup. + /// \param[in] node The leaf node performing setup work. + /// \return Status The error code return + Status MappableCacheLeafSetup(std::shared_ptr leaf_op); + + /// \brief Common code for non-mappable leaf setup. + /// \param[in] node The leaf node performing setup work. + /// \return Status The error code return + Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op); + + /// \brief Assigns the leaf and cache operators that are involved in a cache transformation + /// \param[in] leaf_op The leaf operator involved in the cache transform + /// \param[in] cache_op The cache operator involved in the cache transform + void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op); + + bool is_caching_; + std::shared_ptr leaf_op_; + std::shared_ptr sampler_; + // The two operators that work together to establish the cache transform + std::vector, std::shared_ptr>> cache_pairs_; + }; + public: /// \brief Constructor CacheTransformPass(); + /// \brief Destructor + ~CacheTransformPass() = default; + /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations /// \param[inout] tree The tree to operate on. /// \param[inout] Indicate of the tree was modified. /// \return Status The error code return Status RunOnTree(ExecutionTree *tree, bool *modified) override; - /// \brief Assigns the leaf and cache operators that are involved in a cache transformation - /// \param[in] leaf_op The leaf operator involved in the cache transform - /// \param[in] cache_op The cache operator involved in the cache transform - void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op); - private: /// \brief Helper function to execute the cache transformation. /// @@ -69,11 +191,8 @@ class CacheTransformPass : public TreePass { /// \return Status The error code return Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, std::shared_ptr cache_op, std::shared_ptr cache_client); - - // The two operators that work together to establish the cache transform - std::vector, std::shared_ptr>> cache_pairs_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc new file mode 100644 index 0000000000..2cd1f74089 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" + +namespace mindspore { +namespace dataset { + +// constructor +EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr node) : injection_point_(node) {} + +// Performs finder work for BuildVocabOp that has special rules about epoch control injection +Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { + injection_point_ = nullptr; + return Status::OK(); +} + +// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection +Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, + bool *modified) { + injection_point_ = nullptr; + return Status::OK(); +} + +// Temporary code to prevent the injection of epoch control when cache op is present +// Remove this code in cache op phase 2 +Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { + injection_point_ = nullptr; + return Status::OK(); +} + +Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr node, bool *modified) { + // Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here. + injection_point_ = node->child(0); + return Status::OK(); +} + +// constructor +EpochInjectionPass::EpochInjectionPass() {} + +// Runs an injection pass to inject in operators needed at the pre pass stage +Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { + MS_LOG(INFO) << "Pre pass: Injection pass started."; + + // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. + // The finder can make updates to the EpochInjectionPass object. + EpochInjectionPass::InjectionFinder finder(tree->root()); + RETURN_IF_NOT_OK(finder.Run(tree, modified)); + + // The first injection logic is to check if we should inject the epoch control op as the root node. + // Do not inject the op if the number of epochs is 1. + int32_t num_epochs = tree->num_epochs(); + std::shared_ptr epoch_inject_node = finder.injection_point(); + if (num_epochs != 1 && epoch_inject_node != nullptr) { + std::shared_ptr epoch_ctrl_op; + RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); + RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); + epoch_inject_node->InsertAsParent(epoch_ctrl_op); + } + + MS_LOG(INFO) << "Pre pass: Injection pass complete."; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h new file mode 100644 index 0000000000..292f411aff --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +/// \class EpochInjectionPass epoch_injection_pass.h +/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api +/// parsing. +class EpochInjectionPass : public TreePass { + /// \class InjectionFinder + /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for + /// operators that need to be injected. It is run first by the main injection pass to find out what operators + /// it may need to inject. + class InjectionFinder : public NodePass { + public: + /// \brief Constructor + explicit InjectionFinder(std::shared_ptr node); + + /// \brief Destructor + ~InjectionFinder() = default; + + /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Temporary code to prevent the injection of epoch control when cache op is present. + /// Remove this code in cache op phase 2 + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Register the DeviceQueueOp for further action. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + std::shared_ptr injection_point() { return injection_point_; } + + private: + std::shared_ptr injection_point_; + }; + + public: + /// \brief Constructor + EpochInjectionPass(); + + /// \brief Runs an injection pass to inject in operators needed at the pre pass stage + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + Status RunOnTree(ExecutionTree *tree, bool *modified) override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc deleted file mode 100644 index f04d7bc07d..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "minddata/dataset/engine/opt/pre/removal_nodes.h" -#include "minddata/dataset/engine/opt/pre/removal_pass.h" -#include "minddata/dataset/engine/datasetops/shuffle_op.h" - -namespace mindspore { -namespace dataset { - -RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} - -// Identifies the subtree below this node as a cached descendant tree. -Status RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree -Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; - is_caching_ = false; - return Status::OK(); -} - -// Perform ShuffleOp removal check. -Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - // If we are in a cache descendant tree, then this shuffle op needs to be removed - if (is_caching_) { - MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; - if (removal_pass_) { - removal_pass_->AddToRemovalList(std::static_pointer_cast(node)); - } else { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); - } - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h deleted file mode 100644 index 32025cd597..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ - -#include -#include "minddata/dataset/engine/opt/pass.h" -#include "minddata/dataset/engine/opt/pre/removal_pass.h" - -namespace mindspore { -namespace dataset { -/// \class RemovalNodes removal_nodes.h -/// \brief This is a NodePass who's job is to identify which nodes should be removed. -/// It works in conjunction with the removal_pass. -class RemovalNodes : public NodePass { - public: - /// \brief Constructor - /// \param[in] removal_pass Raw pointer back to controlling tree pass - explicit RemovalNodes(RemovalPass *removal_pass); - - /// \brief Identifies the subtree below this node as a cached descendant tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Resets the tracking of the cache within the tree - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Destructor - ~RemovalNodes() = default; - - /// \brief Perform ShuffleOp removal check - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - bool is_caching_; - RemovalPass *removal_pass_; // Back pointer to the owning removal pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc index 0db422a7c2..14c00b4a63 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc @@ -16,32 +16,58 @@ #include #include -#include "minddata/dataset/engine/opt/pre/removal_nodes.h" #include "minddata/dataset/engine/opt/pre/removal_pass.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/execution_tree.h" namespace mindspore { namespace dataset { +RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree +Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; + is_caching_ = false; + return Status::OK(); +} + +// Perform ShuffleOp removal check. +Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + // If we are in a cache descendant tree, then this shuffle op needs to be removed + if (is_caching_) { + MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; + nodes_to_remove_.push_back(std::static_pointer_cast(node)); + } + return Status::OK(); +} + // constructor RemovalPass::RemovalPass() {} -// Runs a removal_nodes pass first to find out which nodes to remove, then removes them. +// Walk the tree to collect the nodes to remove, then removes them. Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { MS_LOG(INFO) << "Pre pass: removal pass started."; // Create the removal node pass which can identify which nodes need to be removed. - std::unique_ptr removal_nodes = std::make_unique(this); + std::unique_ptr removal_nodes = std::make_unique(); RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); // Then, execute the removal of any nodes that were set up for removal - for (auto node : removal_nodes_) { + for (auto node : removal_nodes->nodes_to_remove()) { node->Remove(); } MS_LOG(INFO) << "Pre pass: removal pass complete."; return Status::OK(); } - -// Adds an operator to the list of operators to be removed -void RemovalPass::AddToRemovalList(std::shared_ptr dataset_op) { removal_nodes_.push_back(dataset_op); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h index bcab7cf08c..f1e8b79495 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ #include #include @@ -30,6 +30,45 @@ class DatasetOp; /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which /// nodes should be removed, and then removes them. class RemovalPass : public TreePass { + /// \class RemovalNodes + /// \brief This is a NodePass who's job is to identify which nodes should be removed. + /// It works in conjunction with the removal_pass. + class RemovalNodes : public NodePass { + public: + /// \brief Constructor + /// \param[in] removal_pass Raw pointer back to controlling tree pass + RemovalNodes(); + + /// \brief Destructor + ~RemovalNodes() = default; + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform ShuffleOp removal check + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + /// \return All the nodes to be removed + std::vector> nodes_to_remove() { return nodes_to_remove_; } + + private: + bool is_caching_; + std::vector> nodes_to_remove_; + }; + public: /// \brief Constructor RemovalPass(); @@ -42,15 +81,8 @@ class RemovalPass : public TreePass { /// \param[inout] Indicate of the tree was modified. /// \return Status The error code return Status RunOnTree(ExecutionTree *tree, bool *modified) override; - - /// \brief Adds an operator to the list of operators to be removed - /// \param[in] dataset_op The operator to add to the removal list - void AddToRemovalList(std::shared_ptr dataset_op); - - private: - std::vector> removal_nodes_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc index eb74d8fcc3..02f7bf8dfa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc @@ -60,7 +60,7 @@ Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { std::cout << "Visiting ShuffleOp" << '\n'; return Status::OK(); } - +#ifndef ENABLE_ANDROID Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { *modified = false; std::cout << "Visiting MindRecordOp" << '\n'; @@ -72,6 +72,7 @@ Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) std::cout << "Visiting TFReaderOp" << '\n'; return Status::OK(); } +#endif #ifdef ENABLE_PYTHON Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h index 527df3ccc9..d469554a93 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H -#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H #include #include "minddata/dataset/engine/opt/pass.h" @@ -39,9 +39,11 @@ class PrinterPass : public NodePass { Status RunOnNode(std::shared_ptr node, bool *modified) override; +#ifndef ENABLE_ANDROID Status RunOnNode(std::shared_ptr node, bool *modified) override; Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif #ifdef ENABLE_PYTHON Status RunOnNode(std::shared_ptr node, bool *modified) override; @@ -61,4 +63,4 @@ class PrinterPass : public NodePass { } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h index 61ba06a76f..efafd2860c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_CONNECTOR_SIZE_H -#define DATASET_CONNECTOR_SIZE_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_SIZE_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_SIZE_H #include #include @@ -69,4 +69,4 @@ class ConnectorSize : public Sampling { } // namespace dataset } // namespace mindspore -#endif // DATASET_CONNECTOR_SIZE_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_SIZE_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc index b5e2efaf73..693fb2d65d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc @@ -43,14 +43,17 @@ Status ConnectorThroughput::Sample() { out_buffer_count_row[col] = cur_out_buffer_count; auto sz = timestamps_.size(); cur_time = std::chrono::steady_clock::now(); - auto _dt = std::chrono::duration_cast(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]); - auto dt = std::chrono::duration(_dt).count(); + double dt = 0; + if (sz > 1) { + auto _dt = std::chrono::duration_cast(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]); + dt = std::chrono::duration(_dt).count(); + } auto prev_out_buffer_count = out_buffer_count_table_[col][out_buffer_count_table_.size() - 1]; if (dt != 0) { auto thr = (cur_out_buffer_count - prev_out_buffer_count) / (1000 * dt); throughput_row[col] = thr; } else { - throughput_row[col] = -1; + throughput_row[col] = 0; } col++; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h index 9cf387230a..9871aa425b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_CONNECTOR_THROUGHPUT_H -#define DATASET_CONNECTOR_THROUGHPUT_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_THROUGHPUT_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_THROUGHPUT_H #include #include @@ -100,4 +100,4 @@ class ConnectorThroughput : public Sampling { } // namespace dataset } // namespace mindspore -#endif // DATASET_CONNECTOR_THROUGHPUT_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_THROUGHPUT_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h b/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h index 2dfc3fd99d..d0a30fd2b7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_CYCLIC_ARRAY_H -#define DATASET_CYCLIC_ARRAY_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CYCLIC_ARRAY_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CYCLIC_ARRAY_H #include #include @@ -194,4 +194,4 @@ class CyclicArray { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_CYCLIC_ARRAY_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CYCLIC_ARRAY_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h b/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h index 8f215fd8df..b67ab5f4b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_PERF_DATA_H -#define DATASET_PERF_DATA_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PERF_DATA_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_PERF_DATA_H #include #include "minddata/dataset/core/constants.h" @@ -85,4 +85,4 @@ class PerfData { } // namespace dataset } // namespace mindspore -#endif // DATASET_PERF_DATA_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PERF_DATA_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc index f5c018c03b..4fdc6174a3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc @@ -17,7 +17,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/engine/perf/monitor.h" #include "minddata/dataset/engine/perf/device_queue_tracing.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h index 24f7f2efe8..b6f0ad2ab7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_PROFILE_H_ -#define DATASET_UTIL_PROFILE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_PROFILE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_PROFILE_H_ #include #include diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index 126291179a..dd57fa7ea4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "minddata/dataset/engine/tdt/tdt_plugin.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/log_adapter.h" #include "minddata/dataset/engine/perf/profiling.h" @@ -29,20 +29,27 @@ std::shared_ptr TdtPlugin::GetInstance() { return instance_ptr_; } -TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) { +TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time, + tdt::TdtDataType tdt_type) { MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; std::vector items; double start_time; - auto ret = translate(ts_row, items); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "TDT converting tensor failed!"; - return FAILED; + if (tdt_type == tdt::TDT_TENSOR) { + auto ret = translate(ts_row, items); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "TDT converting tensor failed!"; + return FAILED; + } + } else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) { + DataItem data_item; + data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE; + items.emplace_back(data_item); + MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE"; } if (profiling) { start_time = ProfilingTime::GetCurMilliSecond(); } if (tdt::TdtHostPushData(channel_name, items) != 0) { - MS_LOG(ERROR) << "TDT pushing data failed!"; return FAILED; } if (profiling) { @@ -122,8 +129,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector &i data_item.dataPtr_ = std::shared_ptr(reinterpret_cast(&(*ts->begin())), [](const void *elem) {}); items.emplace_back(data_item); - MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is " - << ts->Size() << "."; + MS_LOG(INFO) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes + << ", data length is " << ts->Size() << "."; } return SUCCESS; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h index a7db08b7f5..1275918c9f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_TDT_TDT_PLUGIN_H_ -#define DATASET_ENGINE_TDT_TDT_PLUGIN_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_PLUGIN_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_PLUGIN_H_ #include #include @@ -38,7 +38,8 @@ class TdtPlugin { public: static std::shared_ptr GetInstance(); - TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time); + TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time, + tdt::TdtDataType tdt_type = tdt::TDT_TENSOR); private: TdtPlugin() {} @@ -51,4 +52,4 @@ class TdtPlugin { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_TDT_TDT_PLUGIN_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_PLUGIN_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h deleted file mode 120000 index 22fe6d07e1..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h +++ /dev/null @@ -1 +0,0 @@ -../../../core/constants.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h deleted file mode 120000 index 37a0e1b686..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h +++ /dev/null @@ -1 +0,0 @@ -../../../core/data_type.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h deleted file mode 120000 index 1fb7a24d91..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h +++ /dev/null @@ -1 +0,0 @@ -../../../core/tensor_shape.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h deleted file mode 120000 index b06279c05b..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h +++ /dev/null @@ -1 +0,0 @@ -../../../util/status.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 6f38f5ea16..eefcb024f6 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_INCLUDE_DATASETS_H_ -#define DATASET_INCLUDE_DATASETS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ #include #include @@ -40,14 +40,76 @@ namespace api { class TensorOperation; class SamplerObj; +// Datasets classes (in alphabetical order) +class CelebADataset; +class Cifar10Dataset; +class Cifar100Dataset; +class CocoDataset; class ImageFolderDataset; class MnistDataset; +class VOCDataset; +// Dataset Op classes (in alphabetical order) class BatchDataset; -class RepeatDataset; +class ConcatDataset; class MapDataset; -class ShuffleDataset; -class Cifar10Dataset; class ProjectDataset; +class RenameDataset; +class RepeatDataset; +class ShuffleDataset; +class SkipDataset; +class TakeDataset; +class ZipDataset; + +/// \brief Function to create a CelebADataset +/// \notes The generated dataset has two columns ['image', 'attr']. +// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. +/// \param[in] dataset_dir Path to the root directory that contains the dataset. +/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'. +/// \param[in] decode Decode the images after reading (default=False). +/// \param[in] extensions List of file extensions to be included in the dataset (default=None). +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` +/// will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current Dataset +std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all", + const std::shared_ptr &sampler = nullptr, const bool &decode = false, + const std::set &extensions = {}); + +/// \brief Function to create a Cifar10 Dataset +/// \notes The generated dataset has two columns ['image', 'label'] +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` +/// will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current Dataset +std::shared_ptr Cifar10(const std::string &dataset_dir, std::shared_ptr sampler = nullptr); + +/// \brief Function to create a Cifar100 Dataset +/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label'] +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` +/// will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current Dataset +std::shared_ptr Cifar100(const std::string &dataset_dir, + std::shared_ptr sampler = nullptr); + +/// \brief Function to create a CocoDataset +/// \notes The generated dataset has multi-columns : +/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32], +/// ['iscrowd', dtype=uint32]]. +/// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]]. +/// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32], +/// ['num_keypoints', dtype=uint32]]. +/// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32], +/// ['iscrowd', dtype=uint32], ['area', dtype=uitn32]]. +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] annotation_file Path to the annotation json +/// \param[in] task Set the task type of reading coco data, now support 'Detection'/'Stuff'/'Panoptic'/'Keypoint' +/// \param[in] decode Decode the images after reading +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` +/// will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current Dataset +std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, + const std::string &task = "Detection", const bool &decode = false, + const std::shared_ptr &sampler = nullptr); /// \brief Function to create an ImageFolderDataset /// \notes A source dataset that reads images from a tree of directories @@ -73,15 +135,37 @@ std::shared_ptr ImageFolder(std::string dataset_dir, bool de /// \return Shared pointer to the current MnistDataset std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler = nullptr); -/// \brief Function to create a Cifar10 Dataset -/// \notes The generated dataset has two columns ['image', 'label'] +/// \brief Function to create a ConcatDataset +/// \notes Reload "+" operator to concat two datasets +/// \param[in] datasets1 Shared pointer to the first dataset to be concatenated +/// \param[in] datasets2 Shared pointer to the second dataset to be concatenated +/// \return Shared pointer to the current ConcatDataset +std::shared_ptr operator+(const std::shared_ptr &datasets1, + const std::shared_ptr &datasets2); + +/// \brief Function to create a VOCDataset +/// \notes The generated dataset has multi-columns : +/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32], +/// ['difficult', dtype=uint32], ['truncate', dtype=uint32]]. +/// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]]. /// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] num_samples The number of images to be included in the dataset +/// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection" +/// \param[in] mode Set the data list txt file to be readed +/// \param[in] class_indexing A str-to-int mapping from label name to index +/// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, - std::shared_ptr sampler); +std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", + const std::string &mode = "train", + const std::map &class_index = {}, bool decode = false, + std::shared_ptr sampler = nullptr); + +/// \brief Function to create a ZipDataset +/// \notes Applies zip to the dataset +/// \param[in] datasets List of shared pointers to the datasets that we want to zip +/// \return Shared pointer to the current Dataset +std::shared_ptr Zip(const std::vector> &datasets); /// \class Dataset datasets.h /// \brief A base class to represent a dataset in the data pipeline. @@ -96,8 +180,8 @@ class Dataset : public std::enable_shared_from_this { ~Dataset() = default; /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object - /// \return shared pointer to the list of newly created DatasetOps - virtual std::shared_ptr>> Build() = 0; + /// \return The list of shared pointers to the newly created DatasetOps + virtual std::vector> Build() = 0; /// \brief Pure virtual function for derived class to implement parameters validation /// \return bool True if all the params are valid @@ -125,13 +209,11 @@ class Dataset : public std::enable_shared_from_this { /// \return Shared pointer to the current BatchDataset std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); - /// \brief Function to create a RepeatDataset - /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1 - /// \param[in] count Number of times the dataset should be repeated - /// \return Shared pointer to the current Dataset - /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset` - /// due to a limitation in the current implementation - std::shared_ptr Repeat(int32_t count = -1); + /// \brief Function to create a ConcatDataset + /// \notes Concat the datasets in the input + /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together + /// \return Shared pointer to the current ConcatDataset + std::shared_ptr Concat(const std::vector> &datasets); /// \brief Function to create a MapDataset /// \notes Applies each operation in operations to this dataset @@ -153,17 +235,51 @@ class Dataset : public std::enable_shared_from_this { std::vector output_columns = {}, const std::vector &project_columns = {}); + /// \brief Function to create a Project Dataset + /// \notes Applies project to the dataset + /// \param[in] columns The name of columns to project + /// \return Shared pointer to the current Dataset + std::shared_ptr Project(const std::vector &columns); + + /// \brief Function to create a Rename Dataset + /// \notes Renames the columns in the input dataset + /// \param[in] input_columns List of the input columns to rename + /// \param[in] output_columns List of the output columns + /// \return Shared pointer to the current Dataset + std::shared_ptr Rename(const std::vector &input_columns, + const std::vector &output_columns); + + /// \brief Function to create a RepeatDataset + /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1 + /// \param[in] count Number of times the dataset should be repeated + /// \return Shared pointer to the current Dataset + /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset` + /// due to a limitation in the current implementation + std::shared_ptr Repeat(int32_t count = -1); + /// \brief Function to create a Shuffle Dataset /// \notes Randomly shuffles the rows of this dataset /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling /// \return Shared pointer to the current ShuffleDataset std::shared_ptr Shuffle(int32_t shuffle_size); - /// \brief Function to create a Project Dataset - /// \notes Applies project to the dataset - /// \param[in] columns The name of columns to project + /// \brief Function to create a SkipDataset + /// \notes Skips count elements in this dataset. + /// \param[in] count Number of elements the dataset to be skipped. + /// \return Shared pointer to the current SkipDataset + std::shared_ptr Skip(int32_t count); + + /// \brief Function to create a TakeDataset + /// \notes Takes count elements in this dataset. + /// \param[in] count Number of elements the dataset to be taken. /// \return Shared pointer to the current Dataset - std::shared_ptr Project(const std::vector &columns); + std::shared_ptr Take(int32_t count = -1); + + /// \brief Function to create a Zip Dataset + /// \notes Applies zip to the dataset + /// \param[in] datasets A list of shared pointers to the datasets that we want to zip + /// \return Shared pointer to the current Dataset + std::shared_ptr Zip(const std::vector> &datasets); protected: std::vector> children; @@ -176,6 +292,99 @@ class Dataset : public std::enable_shared_from_this { /* ####################################### Derived Dataset classes ################################# */ +class CelebADataset : public Dataset { + public: + /// \brief Constructor + CelebADataset(const std::string &dataset_dir, const std::string &dataset_type, + const std::shared_ptr &sampler, const bool &decode, + const std::set &extensions); + + /// \brief Destructor + ~CelebADataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + std::string dataset_type_; + bool decode_; + std::set extensions_; + std::shared_ptr sampler_; +}; + +class Cifar10Dataset : public Dataset { + public: + /// \brief Constructor + Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr sampler); + + /// \brief Destructor + ~Cifar10Dataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + std::shared_ptr sampler_; +}; + +class Cifar100Dataset : public Dataset { + public: + /// \brief Constructor + Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr sampler); + + /// \brief Destructor + ~Cifar100Dataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + std::shared_ptr sampler_; +}; + +class CocoDataset : public Dataset { + public: + /// \brief Constructor + CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, + const bool &decode, const std::shared_ptr &sampler); + + /// \brief Destructor + ~CocoDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + std::string annotation_file_; + std::string task_; + bool decode_; + std::shared_ptr sampler_; +}; + /// \class ImageFolderDataset /// \brief A Dataset derived class to represent ImageFolder dataset class ImageFolderDataset : public Dataset { @@ -188,8 +397,8 @@ class ImageFolderDataset : public Dataset { ~ImageFolderDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid @@ -212,9 +421,31 @@ class MnistDataset : public Dataset { /// \brief Destructor ~MnistDataset() = default; + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + std::shared_ptr sampler_; +}; + +class VOCDataset : public Dataset { + public: + /// \brief Constructor + VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, + const std::map &class_index, bool decode, std::shared_ptr sampler); + + /// \brief Destructor + ~VOCDataset() = default; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid @@ -222,6 +453,10 @@ class MnistDataset : public Dataset { private: std::string dataset_dir_; + std::string task_; + std::string mode_; + std::map class_index_; + bool decode_; std::shared_ptr sampler_; }; @@ -235,8 +470,8 @@ class BatchDataset : public Dataset { ~BatchDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid @@ -250,6 +485,91 @@ class BatchDataset : public Dataset { std::map>> pad_map_; }; +class ConcatDataset : public Dataset { + public: + /// \brief Constructor + explicit ConcatDataset(const std::vector> &datasets); + + /// \brief Destructor + ~ConcatDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector> datasets_; +}; + +class MapDataset : public Dataset { + public: + /// \brief Constructor + MapDataset(std::vector> operations, std::vector input_columns = {}, + std::vector output_columns = {}, const std::vector &columns = {}); + + /// \brief Destructor + ~MapDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector> operations_; + std::vector input_columns_; + std::vector output_columns_; + std::vector project_columns_; +}; + +class ProjectDataset : public Dataset { + public: + /// \brief Constructor + explicit ProjectDataset(const std::vector &columns); + + /// \brief Destructor + ~ProjectDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector columns_; +}; + +class RenameDataset : public Dataset { + public: + /// \brief Constructor + explicit RenameDataset(const std::vector &input_columns, const std::vector &output_columns); + + /// \brief Destructor + ~RenameDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector input_columns_; + std::vector output_columns_; +}; + class RepeatDataset : public Dataset { public: /// \brief Constructor @@ -259,8 +579,8 @@ class RepeatDataset : public Dataset { ~RepeatDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid @@ -276,7 +596,7 @@ class ShuffleDataset : public Dataset { ~ShuffleDataset() = default; - std::shared_ptr>> Build() override; + std::vector> Build() override; bool ValidateParams() override; @@ -286,72 +606,67 @@ class ShuffleDataset : public Dataset { bool reset_every_epoch_; }; -class MapDataset : public Dataset { +class SkipDataset : public Dataset { public: /// \brief Constructor - MapDataset(std::vector> operations, std::vector input_columns = {}, - std::vector output_columns = {}, const std::vector &columns = {}); + explicit SkipDataset(int32_t count); /// \brief Destructor - ~MapDataset() = default; + ~SkipDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: - std::vector> operations_; - std::vector input_columns_; - std::vector output_columns_; - std::vector project_columns_; + int32_t skip_count_; }; -class Cifar10Dataset : public Dataset { +class TakeDataset : public Dataset { public: /// \brief Constructor - Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler); + explicit TakeDataset(int32_t count); /// \brief Destructor - ~Cifar10Dataset() = default; + ~TakeDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: - std::string dataset_dir_; - int32_t num_samples_; - std::shared_ptr sampler_; + int32_t take_count_; }; -class ProjectDataset : public Dataset { +class ZipDataset : public Dataset { public: /// \brief Constructor - explicit ProjectDataset(const std::vector &columns); + explicit ZipDataset(const std::vector> &datasets); /// \brief Destructor - ~ProjectDataset() = default; + ~ZipDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: - std::vector columns_; + std::vector> datasets_; }; + } // namespace api } // namespace dataset } // namespace mindspore -#endif // DATASET_INCLUDE_DATASETS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/de_tensor.h b/mindspore/ccsrc/minddata/dataset/include/de_tensor.h new file mode 100644 index 0000000000..749b9d35d9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/de_tensor.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_DETENSOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_DETENSOR_H_ +#include +#include +#include +#include "include/ms_tensor.h" +#include "minddata/dataset/include/tensor.h" +#include "minddata/dataset/util/status.h" +namespace mindspore { +namespace tensor { +class DETensor : public MSTensor { + public: + /// \brief Create a MSTensor pointer. + /// \param[data_type] DataTypeId of tensor to be created. + /// \param[shape] Shape of tensor to be created. + /// \return - MSTensor pointer. + static MSTensor *CreateTensor(TypeId data_type, const std::vector &shape); + + /// \brief Create a MSTensor pointer. + /// \param[path] Path file to be read. + /// \return - MSTensor pointer. + static MSTensor *CreateTensor(const std::string &path); + + DETensor(TypeId data_type, const std::vector &shape); + + explicit DETensor(std::shared_ptr tensor_ptr); + + ~DETensor() = default; + + /// \brief Create a duplicate instance, convert the DETensor to the LiteTensor. + /// \return - MSTensor pointer. + MSTensor *ConvertToLiteTensor(); + + std::shared_ptr tensor() const; + + TypeId data_type() const override; + + TypeId set_data_type(const TypeId data_type) override; + + std::vector shape() const override; + + size_t set_shape(const std::vector &shape) override; + + int DimensionSize(size_t index) const override; + + int ElementsNum() const override; + + std::size_t hash() const override; + + size_t Size() const override; + + void *MutableData() const override; + + protected: + std::shared_ptr tensor_impl_; +}; +} // namespace tensor +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_DETENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/execute.h b/mindspore/ccsrc/minddata/dataset/include/execute.h new file mode 100644 index 0000000000..53d6ee5572 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/execute.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_API_EXECUTE_H_ +#define DATASET_API_EXECUTE_H_ + +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/include/de_tensor.h" +#include "minddata/dataset/include/transforms.h" + +namespace mindspore { +namespace dataset { + +class TensorOp; + +namespace api { + +// class to run tensor operations in eager mode +class Execute { + public: + /// \brief Constructor + explicit Execute(std::shared_ptr op); + + /// \brief callable function to execute the TensorOperation in eager mode + /// \param[inout] input - the tensor to be transformed + /// \return - the output tensor, nullptr if Compute fails + std::shared_ptr operator()(std::shared_ptr input); + + private: + std::shared_ptr op_; +}; + +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_API_EXECUTE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h index c3784821a6..241a8c65b1 100644 --- a/mindspore/ccsrc/minddata/dataset/include/iterator.h +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_INCLUDE_ITERATOR_H_ -#define DATASET_INCLUDE_ITERATOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ #include #include @@ -112,4 +112,4 @@ class Iterator { } // namespace api } // namespace dataset } // namespace mindspore -#endif // DATASET_INCLUDE_ITERATOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 3d57e67059..9d423c78fa 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_API_SAMPLERS_H_ -#define DATASET_API_SAMPLERS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_ #include #include @@ -52,9 +52,12 @@ class WeightedRandomSamplerObj; /// \param[in] shuffle - If true, the indices are shuffled. /// \param[in] num_samples - The number of samples to draw (default to all elements). /// \param[in] seed - The seed in use when shuffle is true. +/// \param[in] even_dist - If true, each shard would return the same number of rows (default to true). +/// If false the total rows returned by all the shards would not have overlap. /// \return Shared pointer to the current Sampler. std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, - int64_t num_samples = 0, uint32_t seed = 1); + int64_t num_samples = 0, uint32_t seed = 1, + bool even_dist = true); /// Function to create a PK Sampler. /// \notes Samples K elements for each P class in the dataset. @@ -84,8 +87,7 @@ std::shared_ptr SequentialSampler(int64_t start_index = 0, /// \param[in] indices - A vector sequence of indices. /// \param[in] num_samples - The number of samples to draw (default to all elements). /// \return Shared pointer to the current Sampler. -std::shared_ptr SubsetRandomSampler(const std::vector &indices, - int64_t num_samples = 0); +std::shared_ptr SubsetRandomSampler(std::vector indices, int64_t num_samples = 0); /// Function to create a Weighted Random Sampler. /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given @@ -94,13 +96,14 @@ std::shared_ptr SubsetRandomSampler(const std::vector WeightedRandomSampler(const std::vector &weights, - int64_t num_samples = 0, bool replacement = true); +std::shared_ptr WeightedRandomSampler(std::vector weights, int64_t num_samples = 0, + bool replacement = true); /* ####################################### Derived Sampler classes ################################# */ class DistributedSamplerObj : public SamplerObj { public: - DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed); + DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed, + bool even_dist); ~DistributedSamplerObj() = default; @@ -114,6 +117,7 @@ class DistributedSamplerObj : public SamplerObj { bool shuffle_; int64_t num_samples_; uint32_t seed_; + bool even_dist_; }; class PKSamplerObj : public SamplerObj { @@ -164,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj { class SubsetRandomSamplerObj : public SamplerObj { public: - SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples); + SubsetRandomSamplerObj(std::vector indices, int64_t num_samples); ~SubsetRandomSamplerObj() = default; @@ -173,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj { bool ValidateParams() override; private: - const std::vector &indices_; + const std::vector indices_; int64_t num_samples_; }; class WeightedRandomSamplerObj : public SamplerObj { public: - explicit WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples = 0, - bool replacement = true); + explicit WeightedRandomSamplerObj(std::vector weights, int64_t num_samples = 0, bool replacement = true); ~WeightedRandomSamplerObj() = default; @@ -189,11 +192,11 @@ class WeightedRandomSamplerObj : public SamplerObj { bool ValidateParams() override; private: - const std::vector &weights_; + const std::vector weights_; int64_t num_samples_; bool replacement_; }; } // namespace api } // namespace dataset } // namespace mindspore -#endif // DATASET_API_SAMPLERS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/status.h b/mindspore/ccsrc/minddata/dataset/include/status.h deleted file mode 120000 index bba92b63ad..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/status.h +++ /dev/null @@ -1 +0,0 @@ -../util/status.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/status.h b/mindspore/ccsrc/minddata/dataset/include/status.h new file mode 100644 index 0000000000..b919b4dc4e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/status.h @@ -0,0 +1,137 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STATUS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STATUS_H_ + +#if defined(__GNUC__) || defined(__clang__) +#define DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define DEPRECATED __declspec(deprecated) +#else +#pragma message("WARNING: You need to implement DEPRECATED for this compiler") +#define DEPRECATED +#endif + +#include +#include +#include + +namespace mindspore { +namespace dataset { +#define RETURN_IF_NOT_OK(_s) \ + do { \ + Status __rc = (_s); \ + if (__rc.IsError()) { \ + return __rc; \ + } \ + } while (false) + +#define RETURN_STATUS_UNEXPECTED(_e) \ + do { \ + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, _e); \ + } while (false) + +#define CHECK_FAIL_RETURN_UNEXPECTED(_condition, _e) \ + do { \ + if (!(_condition)) { \ + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, _e); \ + } \ + } while (false) + +#define RETURN_UNEXPECTED_IF_NULL(_ptr) \ + do { \ + if ((_ptr) == nullptr) { \ + std::string err_msg = "The pointer[" + std::string(#_ptr) + "] is null."; \ + RETURN_STATUS_UNEXPECTED(err_msg); \ + } \ + } while (false) + +enum class StatusCode : char { + kOK = 0, + kOutOfMemory = 1, + kShapeMisMatch = 2, + kInterrupted = 3, + kNoSpace = 4, + kPyFuncException = 5, + kDuplicateKey = 6, + kPythonInterpreterFailure = 7, + kTDTPushFailure = 8, + kFileNotExist = 9, + kProfilingError = 10, + kBoundingBoxOutOfBounds = 11, + kBoundingBoxInvalidShape = 12, + // Make this error code the last one. Add new error code above it. + kUnexpectedError = 127 +}; + +std::string CodeAsString(const StatusCode c); + +class Status { + public: + Status() noexcept; + + explicit Status(StatusCode c) noexcept; + + ~Status() noexcept; + + // Copy constructor + Status(const Status &s); + + Status &operator=(const Status &s); + + // Move constructor + Status(Status &&) noexcept; + + Status &operator=(Status &&) noexcept; + + Status(const StatusCode code, const std::string &msg); + + Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); + + // Return a success status + static Status OK() { return Status(StatusCode::kOK); } + + std::string ToString() const; + + StatusCode get_code() const; + + friend std::ostream &operator<<(std::ostream &os, const Status &s); + + explicit operator bool() const { return (get_code() == StatusCode::kOK); } + + bool operator==(const Status &other) const { return (this->get_code() == other.get_code()); } + + bool operator!=(const Status &other) const { return !(*this == other); } + + bool IsOk() const { return (get_code() == StatusCode::kOK); } + + bool IsError() const { return !IsOk(); } + + bool IsOutofMemory() const { return (get_code() == StatusCode::kOutOfMemory); } + + bool IsInterrupted() const { return (get_code() == StatusCode::kInterrupted); } + + bool IsShapeIncorrect() const { return (get_code() == StatusCode::kShapeMisMatch); } + + bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); } + + private: + StatusCode code_; + std::string err_msg_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STATUS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/tensor.h b/mindspore/ccsrc/minddata/dataset/include/tensor.h deleted file mode 120000 index 34b5e020a9..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/tensor.h +++ /dev/null @@ -1 +0,0 @@ -../core/tensor.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/tensor.h b/mindspore/ccsrc/minddata/dataset/include/tensor.h new file mode 100644 index 0000000000..b2fe352c1d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/tensor.h @@ -0,0 +1,771 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ + +#include +#include +#include +#include +#include "./securec.h" +#include "utils/log_adapter.h" +#if defined(_WIN32) || defined(_WIN64) +#undef HAVE_STDDEF_H +#undef HAVE_STDLIB_H +#endif + +#ifdef ENABLE_PYTHON +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#endif + +#include "utils/ms_utils.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/status.h" +#ifndef ENABLE_ANDROID +#include "proto/example.pb.h" +#else +#include "minddata/dataset/include/de_tensor.h" +#endif + +#ifdef ENABLE_PYTHON +namespace py = pybind11; +#endif +namespace mindspore { +#ifdef ENABLE_ANDROID +namespace tensor { +class DETensor; +} // namespace tensor +#endif +namespace dataset { +class Tensor; +template +class Allocator; + +using CharAllocPtr = std::unique_ptr>; +using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors +using offset_t = uint32_t; // type of offset values to store strings locations +using TensorPtr = std::shared_ptr; + +class Tensor { + public: + Tensor() = delete; + Tensor(const Tensor &other) = delete; + Tensor &operator=(const Tensor &other) = delete; + + /// Create a tensor using shape and type. This constructor should not be used directly, use CreateFromTensor instead + /// \note The shape and type information should be known and valid + /// \note The constructor does not allocate data + /// \param shape TensorShape + /// \param type DataType + Tensor(const TensorShape &shape, const DataType &type); + + /// Move constructor + /// \param other Tensor to be moved + Tensor(Tensor &&other) noexcept; + + /// Move assigment operator + /// \param other Tensor to be moved + Tensor &operator=(Tensor &&other) noexcept; + + /// Create a numeric tensor with type and shape. Items of the tensor would be uninitialized. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out); + + /// Create a numeric tensor from a pointer in memory. Length of the source data is determined from the shape and type. + /// Data will be copied into the new created tensor. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out); + + /// Create a tensor from a pointer in memory and length. Data will be copied into the new created tensor. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[in] length length of the src data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, + const dsize_t &length, TensorPtr *out); + + /// Create a copy of the input tensor + /// \param[in] in original tensor to be copied + /// \param[out] out output tensor to be generated + /// \return Status + static Status CreateFromTensor(const TensorPtr &in, TensorPtr *out) { + return CreateFromMemory(in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes(), out); + } + +#ifdef ENABLE_PYTHON + /// Create a Tensor from a given py::array + /// \param[in] arr py::array + /// \param[out] out Created tensor + /// \return Status Code + static Status CreateFromNpArray(const py::array &arr, TensorPtr *out); +#endif + +#ifndef ENABLE_ANDROID + /// Create a tensor of type DE_STRING from a BytesList. + /// \param[in] bytes_list protobuf's Bytelist + /// \param[in] shape shape of the outout tensor + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out); + + /// Create a tensor of type UINT8 or INT8 from a BytesList. + /// The tensor will be padded with ' ' to reach the required pad_size. + /// \param[in] bytes_list protobuf's Bytelist + /// \param[in] shape shape of the output tensor + /// \param[in] type type of created tensor. Should be DE_UINT8 or INT8 + /// \param[in] pad_size The size of the tensor after padding + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, + const DataType &type, dsize_t pad_size, TensorPtr *out); +#endif + + /// Create a Tensor from a given list of values. + /// \tparam type of the values to be inserted. + /// \param[in] items elements of the tensor + /// \param[in] shape shape of the output tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code + template + static Status CreateFromVector(const std::vector &items, const TensorShape &shape, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); + DataType type = DataType::FromCType(); + // if items is empty, items_ptr would be nullptr. CreateFromMemory will handle this case. + auto items_ptr = reinterpret_cast(&items[0]); + return CreateFromMemory(shape, type, items_ptr, out); + } + + /// Create a 1D Tensor from a given list of values. + /// \tparam type of the values to be inserted. + /// \param[in] items elements of the tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code + template + static Status CreateFromVector(const std::vector &items, TensorPtr *out) { + return CreateFromVector(items, TensorShape({static_cast(items.size())}), out); + } + + /// Create a numeric scalar Tensor from the given value. + /// \tparam T type of value + /// \param[in] item value + /// \param[out] out Created tensor + /// \return Status code + template + static Status CreateScalar(const T &item, TensorPtr *out) { + DataType type = DataType::FromCType(); + auto item_ptr = reinterpret_cast(&item); + return CreateFromMemory(TensorShape::CreateScalar(), type, item_ptr, out); + } + + /// Create a tensor from a binary file on disk. + /// \param[in] path file to be read + /// \param[out] out Created Tensor + /// \return Status code + static Status CreateFromFile(const std::string &path, TensorPtr *out); + + /// Destruct the tensor and release the memory using the allocator + virtual ~Tensor(); + + /// Equality operator. compares tensor shape, type and data + /// \param[in] rhs Tensor to be compared with + /// \return bool + bool operator==(const Tensor &rhs) const; + + bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } + + /// Get item located at `index`, caller needs to provide the type. + /// \tparam T + /// \param[in] index vector + /// \return return the item specified at index + template + Status GetItemAt(T *o, const std::vector &index) const; + + /// Get string located at `index`. + /// \param[in] index vector + /// \return return std::string_view specified at index + Status GetItemAt(std::string_view *o, const std::vector &index) const; + + template + Status GetUnsignedIntAt(T *o, const std::vector &index) const; + + template + Status GetSignedIntAt(T *o, const std::vector &index) const; + + template + Status GetFloatAt(T *o, const std::vector &index) const; + + /// set item at location specified by index + /// \tparam `T` + /// \param[in] index + /// \param[in] value of type `T` + template + Status SetItemAt(const std::vector &index, const T &value) { + T *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *ptr = value; + return Status::OK(); + } + + /// set string item at location specified by index + /// \param[in] index + /// \param[in] value of type std::string + Status SetItemAt(const std::vector &index, const std::string &value) { + RETURN_UNEXPECTED_IF_NULL(data_); + uchar *ptr = nullptr; + offset_t length = 0; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index, &length)); + if (value.length() != length) { + RETURN_STATUS_UNEXPECTED("Length of the new string does not match the item."); + } + memcpy_s(reinterpret_cast(ptr), length, value.c_str(), length); + + return Status::OK(); + } + + /// fill tensor with Zeros. Does not support strings. + Status Zero() { + CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.."); + dsize_t size = SizeInBytes(); + CHECK_FAIL_RETURN_UNEXPECTED(memset_sp(GetMutableBuffer(), size, 0, size) == 0, + "Failed to fill tensor with zeroes."); + return Status::OK(); + } + + /// Fill all elements in the Tensor with the given value of type `T`. Does not support strings. + /// \tparam T + /// \param value[in] + template + Status Fill(const T &value) { + CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); + int64_t cellSize = type_.SizeInBytes(); + if ((data_ != nullptr) && type_.IsCompatible()) { + for (dsize_t i = 0; i < Size(); i++) { + CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err"); + } + return Status::OK(); + } else { + std::string err; + err += (data_ == nullptr) ? "data_ is nullptr \t" : ""; + err += type_.IsCompatible() ? "data type not compatible\t" : ""; + return Status(StatusCode::kUnexpectedError, err); + } + } + + /// Getter function for shape + /// \return + const TensorShape &shape() const { return shape_; } + + /// Check if tensor has data + /// \return bool - true if tensor is empty + bool HasData() const { return data_ != nullptr; } + + /// Reshape the tensor. The given shape should have the same number of elements in the Tensor + /// \param shape + virtual Status Reshape(const TensorShape &shape); + + /// \return number of elements in this tensor + dsize_t Size() const { return shape().NumOfElements(); } + + /// \return the number of bytes this tensor is needs + dsize_t SizeInBytes() const { + if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); + return data_end_ - data_; + } + + /// \return the rank of the tensor + dsize_t Rank() const { return shape().Rank(); } + + /// Get the starting memory address as a constant for the data of the tensor. This potentially + /// drives an allocation if the data area. + /// \return const unsigned char* + const unsigned char *GetBuffer() const { return data_; } + + /// Getter of the type + /// \return + DataType type() const { return type_; } + + /// Provide stream operator for displaying it + /// \param output stream + /// \param so the Tensor object to be printed + /// \return output stream + friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { + so.Print(out); + return out; + } + + /// Invalidate this Tensor by setting the type and shape to unknown and MData to null. + /// Calling this method will make the Tensor and its data inaccessible, use it with caution. + void Invalidate(); + + /// Copy input tensor into self at the location index. + /// Index is a vector of axises which can be incomplete: + /// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. + /// \param index + /// \param input + /// \param partial_insert: boolean to determine if insertion along the full axis is enforced + /// \return Status code + Status InsertTensor(const std::vector &index, const std::shared_ptr &input, + const bool partial_insert = false); + + /// Find the address of the given index. Used in InsertTensor. + /// Example: + /// Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 + /// \param index incomplete index + /// \param output: startAddrofIndex + /// \param output: remaining + /// \return Status code + Status StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining); + + /// Expand the shape of the Tensor with one extra dimension. + /// For example, if the shape is <512,512,3>: + /// *- ExpandDim(0) gives: <1,512,512,3> + /// *- ExpandDim(1) gives: <512,1,512,3> + /// *- ExpandDim(3) gives: <512,512,3,1> + /// \param axis location of the dim + virtual Status ExpandDim(const dsize_t &axis); + + virtual void Squeeze(); + + /// Calculates the strides of the Tensor + /// Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) + /// The strides will be {6,2,1}. + /// Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) + /// The strides will be {24,8,4}. + /// \return vector of integers + std::vector Strides() const; + + std::string ToString() { + std::stringstream ss; + this->Print(ss); + return ss.str(); + } + + /// Handle negative indices. + static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } + + /// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. + /// Based on the type of tensor, SliceNumeric or SliceString will be called + /// \param[out] out Tensor + /// \param[in] indices vector of indices + /// \return Status error code + Status Slice(TensorPtr *out, const std::vector &indices); + + /// Slice numeric tensors. + Status SliceNumeric(TensorPtr *out, const std::vector &indices); + + /// Slice string tensors + Status SliceString(TensorPtr *out, const std::vector &indices); + +#ifdef ENABLE_PYTHON + /// Constructs numpy array from input tensor + /// \param[in] data this data is the location of python data + /// \return Status code + Status GetDataAsNumpy(py::array *data); + + Status GetDataAsNumpyStrings(py::array *data); + + static Status GetBufferInfo(Tensor *t, py::buffer_info *out); +#endif + + /// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor + /// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 + /// \tparam T type of values in the Tensor Iterator + template + class TensorIterator { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = T *; + using reference = T &; + + explicit TensorIterator(uchar *ptr = nullptr) { ptr_ = reinterpret_cast(ptr); } + + TensorIterator(const TensorIterator &raw_iterator) { ptr_ = raw_iterator.ptr_; } + + ~TensorIterator() = default; + + TensorIterator &operator=(const TensorIterator &rhs) { + ptr_ = rhs.ptr_; + return *this; + } + + TensorIterator &operator=(T *rhs) { + ptr_ = rhs; + return *this; + } + + bool operator==(const TensorIterator &rhs) { return ptr_ == rhs.ptr_; } + + bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } + + operator bool() const { return ptr_ != nullptr; } + + T &operator*() { return *ptr_; } + + const T &operator*() const { return *ptr_; } + + T *operator->() { return ptr_; } + + TensorIterator &operator+=(const ptrdiff_t &inc) { + ptr_ += inc; + return *this; + } + + TensorIterator &operator-=(const ptrdiff_t &inc) { + ptr_ -= inc; + return *this; + } + + TensorIterator &operator++() { + ++ptr_; + return *this; + } + + TensorIterator &operator--() { + --ptr_; + return *this; + } + + TensorIterator operator++(int) { + auto temp(*this); + ++ptr_; + return temp; + } + + TensorIterator operator--(int) { + auto temp(*this); + --ptr_; + return temp; + } + + TensorIterator operator+(const ptrdiff_t &inc) { + auto oldPtr = ptr_; + ptr_ += inc; + auto temp(*this); + ptr_ = oldPtr; + return temp; + } + + TensorIterator operator-(const ptrdiff_t &inc) { + auto oldPtr = ptr_; + ptr_ -= inc; + auto temp(*this); + ptr_ = oldPtr; + return temp; + } + + protected: + T *ptr_; + }; + + // Specialization of TensorIterator for strings. It returns std::string_view for every item. + // \tparam DUMMY, used to mbe able to specialize the inner class + template + class TensorIterator { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = std::string_view; + using difference_type = ptrdiff_t; + using pointer = std::string_view *; + using reference = std::string_view &; + + explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) { + data_ = reinterpret_cast(data); + index_ = index; + } + + TensorIterator(const TensorIterator &raw_iterator) { + data_ = raw_iterator.data_; + index_ = raw_iterator.index_; + } + + ~TensorIterator() = default; + + bool operator==(const TensorIterator &rhs) { return data_ == rhs.data_ && index_ == rhs.index_; } + + bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } + + operator bool() const { return data_ != nullptr; } + + std::string_view operator*() const { + auto offset_ = reinterpret_cast(data_); + offset_t start = offset_[index_]; + return std::string_view{data_ + start}; + } + + TensorIterator &operator+=(const dsize_t &inc) { + index_ += inc; + return *this; + } + + TensorIterator &operator-=(const dsize_t &inc) { + index_ -= inc; + return *this; + } + + TensorIterator &operator++() { + ++index_; + return *this; + } + + TensorIterator &operator--() { + --index_; + return *this; + } + + TensorIterator operator++(int) { + auto temp(*this); + ++index_; + return temp; + } + + TensorIterator operator--(int) { + auto temp(*this); + --index_; + return temp; + } + + TensorIterator operator+(const dsize_t &inc) { + auto oldPtr = index_; + index_ += inc; + auto temp(*this); + index_ = oldPtr; + return temp; + } + + TensorIterator operator-(const dsize_t &inc) { + auto oldPtr = index_; + index_ -= inc; + auto temp(*this); + index_ = oldPtr; + return temp; + } + + protected: + dsize_t index_; + const char *data_; + }; + + /// Return a TensorIterator that points to the start of the Tensor. + /// It's the user responsibility to use the correct type that matches the Tensor type + /// \tparam T The type of values in the Tensor + /// \return TensorIterator + template + TensorIterator begin() { + return TensorIterator(data_); + } + + /// Return a linear iterator that points to the place after the last element of the Tensor. + /// \tparam T The type of values in the Tensor + /// \return TensorIterator + template + TensorIterator end() { + return TensorIterator(data_end_); + } + + /// Copies the last dimension at `index` from Tensor `src` to this Tensor. + /// \param[in] src Tensor + /// \param[in] index vector to the start of the dimension. The last dim should be 0 + /// \return Status + Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); + + protected: + /// Allocate memory for the tensor using the data_allocator + /// \param[in] length number of bytes to be allocated + /// \return Error Status + Status AllocateBuffer(const dsize_t &length); + + /// Get the starting memory address for the data of the tensor. This potentially + /// drives an allocation if the data is null. + /// \return unsigned char* + unsigned char *GetMutableBuffer() { return data_; } + + /// A function that prints Tensor recursively, first called by print + /// \param[in] out + /// \param[in] cur_dim + /// \param[in] cur_index + void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; + + /// A function that prints info about the tensor + /// \param[out] out output stream + void Print(std::ostream &out) const; + + /// A function that print the value as specified by its index + /// \param[in] index vector representing the index + /// \param[out] out + void PrintItemAt(const std::vector &index, std::ostream &out) const; + + /// Get pointer to item located at `index`, caller needs to provide the type. + /// \tparam T + /// \param[in] index vector + /// \return return a pointer to the item specified at index of type `T` + template + Status GetItemPtr(T **, const std::vector &index) const; + + /// Get pointer to string located at `index` and the length of string + /// \param[in] index vector + /// \return return a pointer to the string specified at index and the length of the string + Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; + + /// Given a flat index of an item string, return the start and length of the item + /// \param[in] index flat index of the item + /// \param[out] start address of the ths string + /// \param[out] length of the string + Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; + + /// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if + /// the tensor's type is a string, otherwise undefined address would be returned. \return address of the first string + /// of the tensor. + uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } + + /// all access to shape_ should be via shape + TensorShape shape_; + /// data type of tensor + DataType type_; + /// pointer to the start of the physical data + unsigned char *data_; + /// An allocator for data_ + CharAllocPtr data_allocator_; + /// pointer to the end of the physical data + unsigned char *data_end_ = nullptr; + + private: +#ifdef ENABLE_ANDROID + friend class tensor::DETensor; +#endif + /// Copy raw data of a array based on shape and strides to the destination pointer + /// \param dst [out] Pointer to the destination array where the content is to be copied + /// \param[in] src Pointer to the source of strided array to be copied + /// \param[in] shape shape of the source array + /// \param[in] strides strides of the source array + /// \param[in] type_size number of bytes needed to store one array element's type + /// \return Status Code + static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size); + + /// const of the size of the offset variable + static constexpr uint8_t kOffsetSize = sizeof(offset_t); + +#ifdef ENABLE_PYTHON + /// Helper function to create a tensor from Numpy array of strings + /// \param[in] arr Numpy array + /// \param[out] out Created Tensor + /// \return Status + static Status CreateFromNpString(py::array arr, TensorPtr *out); +#endif +}; +template <> +inline Tensor::TensorIterator Tensor::end() { + return TensorIterator(data_, shape_.NumOfElements()); +} + +/// Create a Tensor from a given list of strings. +/// @note: The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. +/// The offset array will store one extra value to find the length of the last string. +/// OFFSET_1, OFFSET_2, ..., OFFSET_n+1, STRING_1, STRING_2, ..., STRING_n +/// The value of each offset is the start index of the corresponding string +/// Offsets is of type offset_t +/// strings will ne null-terminated +/// example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) +/// |----------------------------------------------------------------| +/// | OFFSET ARRAY | STRINGS | +/// | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | +/// | 11 | 15 | 18 | abc\0 | de\0 | +/// |----------------------------------------------------------------| +/// \param[in] items elements of the tensor +/// \param[in] shape shape of the output tensor +/// \param[out] out output argument to hold the created Tensor +/// \return Status Code +template <> +inline Status Tensor::CreateFromVector(const std::vector &items, const TensorShape &shape, + TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, TensorShape({static_cast(items.size())}), + DataType(DataType::DE_STRING)); + if (items.size() == 0) { + if (shape.known()) { + return (*out)->Reshape(shape); + } + } + auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; + dsize_t total_length = std::accumulate(items.begin(), items.end(), 0, length_sum); + + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize + 1) * (*out)->shape_.NumOfElements() + kOffsetSize + total_length; + + (*out)->AllocateBuffer(num_bytes); + auto offset_arr = reinterpret_cast((*out)->data_); + uchar *buf = (*out)->GetStringsBuffer(); + + offset_t offset = buf - (*out)->data_; // the first string will start here + uint32_t i = 0; + for (const auto &str : items) { + // insert the start index of the string. + offset_arr[i++] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s((*out)->data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + (*out)->data_end_ = (*out)->data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) { + RETURN_IF_NOT_OK((*out)->Reshape(shape)); + } + return Status::OK(); +} +/// Create a string scalar Tensor from the given value. +/// \param[in] item value +/// \param[out] out Created tensor +/// \return Status code +template <> +inline Status Tensor::CreateScalar(const std::string &item, TensorPtr *out) { + return CreateFromVector({item}, TensorShape::CreateScalar(), out); +} +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 31531a20af..1788f9ce51 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_API_TRANSFORMS_H_ -#define DATASET_API_TRANSFORMS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_ #include #include @@ -46,57 +46,21 @@ class TensorOperation : public std::enable_shared_from_this { // Transform operations for performing computer vision. namespace vision { -class NormalizeOperation; +// Transform Op classes (in alphabetical order) +class CenterCropOperation; +class CropOperation; +class CutOutOperation; class DecodeOperation; -class ResizeOperation; +class NormalizeOperation; +class PadOperation; +class RandomColorAdjustOperation; class RandomCropOperation; -class CenterCropOperation; -class UniformAugOperation; class RandomHorizontalFlipOperation; -class RandomVerticalFlipOperation; class RandomRotationOperation; -class PadOperation; -class CutOutOperation; -class RandomColorAdjustOperation; - -/// \brief Function to create a Normalize TensorOperation. -/// \notes Normalize the input image with respect to mean and standard deviation. -/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. -/// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr Normalize(std::vector mean, std::vector std); - -/// \brief Function to create a Decode TensorOperation. -/// \notes Decode the input image in RGB mode. -/// \param[in] rgb - a boolean of whether to decode in RGB mode or not. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr Decode(bool rgb = true); - -/// \brief Function to create a Resize TensorOperation. -/// \notes Resize the input image to the given size.. -/// \param[in] size - a vector representing the output size of the resized image. -/// If size is a single value, the image will be resized to this value with -/// the same image aspect ratio. If size has 2 values, it should be (height, width). -/// \param[in] interpolation An enum for the mode of interpolation -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr Resize(std::vector size, - InterpolationMode interpolation = InterpolationMode::kLinear); - -/// \brief Function to create a RandomCrop TensorOperation. -/// \notes Crop the input image at a random location. -/// \param[in] size - a vector representing the output size of the cropped image. -/// If size is a single value, a square crop of size (size, size) is returned. -/// If size has 2 values, it should be (height, width). -/// \param[in] padding - a vector with the value of pixels to pad the image. If 4 values are provided, -/// it pads the left, top, right and bottom respectively. -/// \param[in] pad_if_needed - a boolean whether to pad the image if either side is smaller than -/// the given output size. -/// \param[in] fill_value - a vector representing the pixel intensity of the borders, it is used to -/// fill R, G, B channels respectively. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomCrop(std::vector size, std::vector padding = {0, 0, 0, 0}, - bool pad_if_needed = false, - std::vector fill_value = {0, 0, 0}); +class RandomVerticalFlipOperation; +class ResizeOperation; +class SwapRedBlueOperation; +class UniformAugOperation; /// \brief Function to create a CenterCrop TensorOperation. /// \notes Crops the input image at the center to the given size. @@ -106,37 +70,32 @@ std::shared_ptr RandomCrop(std::vector size, std:: /// \return Shared pointer to the current TensorOperation. std::shared_ptr CenterCrop(std::vector size); -/// \brief Function to create a UniformAugment TensorOperation. -/// \notes Tensor operation to perform randomly selected augmentation. -/// \param[in] operations - a vector of TensorOperation operations. -/// \param[in] num_ops - integer representing the number of OPs to be selected and applied. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr UniformAugment(std::vector> operations, - int32_t num_ops = 2); +/// \brief Function to create a Crop TensorOp +/// \notes Crop an image based on location and crop size +/// \param[in] coordinates Starting location of crop. Must be a vector of two values, in the form of {x_coor, y_coor} +/// \param[in] size Size of the cropped area. Must be a vector of two values, in the form of {height, width} +/// \return Shared pointer to the current TensorOp +std::shared_ptr Crop(std::vector coordinates, std::vector size); -/// \brief Function to create a RandomHorizontalFlip TensorOperation. -/// \notes Tensor operation to perform random horizontal flip. -/// \param[in] prob - float representing the probability of flip. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomHorizontalFlip(float prob = 0.5); +/// \brief Function to create a CutOut TensorOp +/// \notes Randomly cut (mask) out a given number of square patches from the input image +/// \param[in] length Integer representing the side length of each square patch +/// \param[in] num_patches Integer representing the number of patches to be cut out of an image +/// \return Shared pointer to the current TensorOp +std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1); -/// \brief Function to create a RandomVerticalFlip TensorOperation. -/// \notes Tensor operation to perform random vertical flip. -/// \param[in] prob - float representing the probability of flip. +/// \brief Function to create a Decode TensorOperation. +/// \notes Decode the input image in RGB mode. +/// \param[in] rgb - a boolean of whether to decode in RGB mode or not. /// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomVerticalFlip(float prob = 0.5); +std::shared_ptr Decode(bool rgb = true); -/// \brief Function to create a RandomRotation TensorOp -/// \notes Rotates the image according to parameters -/// \param[in] degrees A float vector size 2, representing the starting and ending degree -/// \param[in] resample An enum for the mode of interpolation -/// \param[in] expand A boolean representing whether the image is expanded after rotation -/// \param[in] center A float vector size 2, representing the x and y center of rotation. -/// \param[in] fill_value A uint8_t vector size 3, representing the rgb value of the fill color -/// \return Shared pointer to the current TensorOp -std::shared_ptr RandomRotation( - std::vector degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, - std::vector center = {-1, -1}, std::vector fill_value = {0, 0, 0}); +/// \brief Function to create a Normalize TensorOperation. +/// \notes Normalize the input image with respect to mean and standard deviation. +/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. +/// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Normalize(std::vector mean, std::vector std); /// \brief Function to create a Pad TensorOp /// \notes Pads the image according to padding parameters @@ -160,13 +119,6 @@ std::shared_ptr RandomRotation( std::shared_ptr Pad(std::vector padding, std::vector fill_value = {0}, BorderType padding_mode = BorderType::kConstant); -/// \brief Function to create a CutOut TensorOp -/// \notes Randomly cut (mask) out a given number of square patches from the input image -/// \param[in] length Integer representing the side length of each square patch -/// \param[in] num_patches Integer representing the number of patches to be cut out of an image -/// \return Shared pointer to the current TensorOp -std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1); - /// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} @@ -183,119 +135,202 @@ std::shared_ptr RandomColorAdjust(std::vector std::vector saturation = {1.0, 1.0}, std::vector hue = {0.0, 0.0}); +/// \brief Function to create a RandomCrop TensorOperation. +/// \notes Crop the input image at a random location. +/// \param[in] size - a vector representing the output size of the cropped image. +/// If size is a single value, a square crop of size (size, size) is returned. +/// If size has 2 values, it should be (height, width). +/// \param[in] padding - a vector with the value of pixels to pad the image. If 4 values are provided, +/// it pads the left, top, right and bottom respectively. +/// \param[in] pad_if_needed - a boolean whether to pad the image if either side is smaller than +/// the given output size. +/// \param[in] fill_value - a vector representing the pixel intensity of the borders, it is used to +/// fill R, G, B channels respectively. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomCrop(std::vector size, std::vector padding = {0, 0, 0, 0}, + bool pad_if_needed = false, + std::vector fill_value = {0, 0, 0}); + +/// \brief Function to create a RandomHorizontalFlip TensorOperation. +/// \notes Tensor operation to perform random horizontal flip. +/// \param[in] prob - float representing the probability of flip. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomHorizontalFlip(float prob = 0.5); + +/// \brief Function to create a RandomRotation TensorOp +/// \notes Rotates the image according to parameters +/// \param[in] degrees A float vector size 2, representing the starting and ending degree +/// \param[in] resample An enum for the mode of interpolation +/// \param[in] expand A boolean representing whether the image is expanded after rotation +/// \param[in] center A float vector size 2, representing the x and y center of rotation. +/// \param[in] fill_value A uint8_t vector size 3, representing the rgb value of the fill color +/// \return Shared pointer to the current TensorOp +std::shared_ptr RandomRotation( + std::vector degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, + std::vector center = {-1, -1}, std::vector fill_value = {0, 0, 0}); + +/// \brief Function to create a RandomVerticalFlip TensorOperation. +/// \notes Tensor operation to perform random vertical flip. +/// \param[in] prob - float representing the probability of flip. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomVerticalFlip(float prob = 0.5); + +/// \brief Function to create a Resize TensorOperation. +/// \notes Resize the input image to the given size.. +/// \param[in] size - a vector representing the output size of the resized image. +/// If size is a single value, the image will be resized to this value with +/// the same image aspect ratio. If size has 2 values, it should be (height, width). +/// \param[in] interpolation An enum for the mode of interpolation +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Resize(std::vector size, + InterpolationMode interpolation = InterpolationMode::kLinear); + +/// \brief Function to create a SwapRedBlue TensorOp +/// \notes Swaps the red and blue channels in image +/// \return Shared pointer to the current TensorOp +std::shared_ptr SwapRedBlue(); + +/// \brief Function to create a UniformAugment TensorOperation. +/// \notes Tensor operation to perform randomly selected augmentation. +/// \param[in] transforms - a vector of TensorOperation transforms. +/// \param[in] num_ops - integer representing the number of OPs to be selected and applied. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr UniformAugment(std::vector> transforms, + int32_t num_ops = 2); + /* ####################################### Derived TensorOperation classes ################################# */ -class NormalizeOperation : public TensorOperation { +class CenterCropOperation : public TensorOperation { public: - NormalizeOperation(std::vector mean, std::vector std); + explicit CenterCropOperation(std::vector size); - ~NormalizeOperation() = default; + ~CenterCropOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector mean_; - std::vector std_; + std::vector size_; }; -class DecodeOperation : public TensorOperation { +class CropOperation : public TensorOperation { public: - explicit DecodeOperation(bool rgb = true); + CropOperation(std::vector coordinates, std::vector size); - ~DecodeOperation() = default; + ~CropOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - bool rgb_; + std::vector coordinates_; + std::vector size_; }; -class ResizeOperation : public TensorOperation { +class CutOutOperation : public TensorOperation { public: - explicit ResizeOperation(std::vector size, - InterpolationMode interpolation_mode = InterpolationMode::kLinear); + explicit CutOutOperation(int32_t length, int32_t num_patches = 1); - ~ResizeOperation() = default; + ~CutOutOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector size_; - InterpolationMode interpolation_; + int32_t length_; + int32_t num_patches_; }; -class RandomCropOperation : public TensorOperation { +class DecodeOperation : public TensorOperation { public: - RandomCropOperation(std::vector size, std::vector padding = {0, 0, 0, 0}, - bool pad_if_needed = false, std::vector fill_value = {0, 0, 0}); + explicit DecodeOperation(bool rgb = true); - ~RandomCropOperation() = default; + ~DecodeOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector size_; - std::vector padding_; - bool pad_if_needed_; - std::vector fill_value_; + bool rgb_; }; -class CenterCropOperation : public TensorOperation { +class NormalizeOperation : public TensorOperation { public: - explicit CenterCropOperation(std::vector size); + NormalizeOperation(std::vector mean, std::vector std); - ~CenterCropOperation() = default; + ~NormalizeOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector size_; + std::vector mean_; + std::vector std_; }; -class UniformAugOperation : public TensorOperation { +class PadOperation : public TensorOperation { public: - explicit UniformAugOperation(std::vector> operations, int32_t num_ops = 2); + PadOperation(std::vector padding, std::vector fill_value = {0}, + BorderType padding_mode = BorderType::kConstant); - ~UniformAugOperation() = default; + ~PadOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector> operations_; - int32_t num_ops_; + std::vector padding_; + std::vector fill_value_; + BorderType padding_mode_; }; -class RandomHorizontalFlipOperation : public TensorOperation { +class RandomColorAdjustOperation : public TensorOperation { public: - explicit RandomHorizontalFlipOperation(float probability = 0.5); + RandomColorAdjustOperation(std::vector brightness = {1.0, 1.0}, std::vector contrast = {1.0, 1.0}, + std::vector saturation = {1.0, 1.0}, std::vector hue = {0.0, 0.0}); - ~RandomHorizontalFlipOperation() = default; + ~RandomColorAdjustOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - float probability_; + std::vector brightness_; + std::vector contrast_; + std::vector saturation_; + std::vector hue_; }; -class RandomVerticalFlipOperation : public TensorOperation { +class RandomCropOperation : public TensorOperation { public: - explicit RandomVerticalFlipOperation(float probability = 0.5); + RandomCropOperation(std::vector size, std::vector padding = {0, 0, 0, 0}, + bool pad_if_needed = false, std::vector fill_value = {0, 0, 0}); - ~RandomVerticalFlipOperation() = default; + ~RandomCropOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector size_; + std::vector padding_; + bool pad_if_needed_; + std::vector fill_value_; +}; + +class RandomHorizontalFlipOperation : public TensorOperation { + public: + explicit RandomHorizontalFlipOperation(float probability = 0.5); + + ~RandomHorizontalFlipOperation() = default; std::shared_ptr Build() override; @@ -324,57 +359,63 @@ class RandomRotationOperation : public TensorOperation { std::vector fill_value_; }; -class PadOperation : public TensorOperation { +class RandomVerticalFlipOperation : public TensorOperation { public: - PadOperation(std::vector padding, std::vector fill_value = {0}, - BorderType padding_mode = BorderType::kConstant); + explicit RandomVerticalFlipOperation(float probability = 0.5); - ~PadOperation() = default; + ~RandomVerticalFlipOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector padding_; - std::vector fill_value_; - BorderType padding_mode_; + float probability_; }; -class CutOutOperation : public TensorOperation { +class ResizeOperation : public TensorOperation { public: - explicit CutOutOperation(int32_t length, int32_t num_patches = 1); + explicit ResizeOperation(std::vector size, + InterpolationMode interpolation_mode = InterpolationMode::kLinear); - ~CutOutOperation() = default; + ~ResizeOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - int32_t length_; - int32_t num_patches_; + std::vector size_; + InterpolationMode interpolation_; }; -class RandomColorAdjustOperation : public TensorOperation { +class UniformAugOperation : public TensorOperation { public: - RandomColorAdjustOperation(std::vector brightness = {1.0, 1.0}, std::vector contrast = {1.0, 1.0}, - std::vector saturation = {1.0, 1.0}, std::vector hue = {0.0, 0.0}); + explicit UniformAugOperation(std::vector> transforms, int32_t num_ops = 2); - ~RandomColorAdjustOperation() = default; + ~UniformAugOperation() = default; std::shared_ptr Build() override; bool ValidateParams() override; private: - std::vector brightness_; - std::vector contrast_; - std::vector saturation_; - std::vector hue_; + std::vector> transforms_; + int32_t num_ops_; +}; + +class SwapRedBlueOperation : public TensorOperation { + public: + SwapRedBlueOperation(); + + ~SwapRedBlueOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; }; } // namespace vision } // namespace api } // namespace dataset } // namespace mindspore -#endif // DATASET_API_TRANSFORMS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h b/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h deleted file mode 120000 index f2c939bc0b..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h +++ /dev/null @@ -1 +0,0 @@ -../../../../utils/log_adapter.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/utils/overload.h b/mindspore/ccsrc/minddata/dataset/include/utils/overload.h deleted file mode 120000 index 7dc313d512..0000000000 --- a/mindspore/ccsrc/minddata/dataset/include/utils/overload.h +++ /dev/null @@ -1 +0,0 @@ -../../../../utils/overload.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt index 8a9096ff23..3273822db5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt @@ -4,11 +4,16 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) if (ENABLE_PYTHON) add_library(kernels OBJECT + data/compose_op.cc + data/random_apply_op.cc + data/random_choice_op.cc py_func_op.cc tensor_op.cc) target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS}) else() add_library(kernels OBJECT + data/compose_op.cc + data/random_apply_op.cc + data/random_choice_op.cc tensor_op.cc) endif() - diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.cc new file mode 100644 index 0000000000..b50e0ae24b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/compose_op.h" + +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status ComposeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + std::vector in_shapes = inputs; + for (auto &op : ops_) { + RETURN_IF_NOT_OK(op->OutputShape(in_shapes, outputs)); + in_shapes = std::move(outputs); // outputs become empty after move + } + outputs = std::move(in_shapes); + return Status::OK(); +} +Status ComposeOp::OutputType(const std::vector &inputs, std::vector &outputs) { + std::vector in_types = inputs; + for (auto &op : ops_) { + RETURN_IF_NOT_OK(op->OutputType(in_types, outputs)); + in_types = std::move(outputs); // outputs become empty after move + } + outputs = std::move(in_types); + return Status::OK(); +} +Status ComposeOp::Compute(const TensorRow &inputs, TensorRow *outputs) { + IO_CHECK_VECTOR(inputs, outputs); + TensorRow in_rows = inputs; + for (auto &op : ops_) { + RETURN_IF_NOT_OK(op->Compute(in_rows, outputs)); + in_rows = std::move(*outputs); // after move, *outputs become empty + } + (*outputs) = std::move(in_rows); + return Status::OK(); +} + +ComposeOp::ComposeOp(const std::vector> &ops) : ops_(ops) { + if (ops_.empty()) { + MS_LOG(ERROR) << "op_list is empty this might lead to Segmentation Fault."; + } else if (ops_.size() == 1) { + MS_LOG(WARNING) << "op_list has only 1 op. Compose is probably not needed."; + } +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h new file mode 100644 index 0000000000..d639936378 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_COMPOSE_OP_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_COMPOSE_OP_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class ComposeOp : public TensorOp { + public: + /// constructor + /// \param[in] ops list of TensorOps to compose into 1 TensorOp + explicit ComposeOp(const std::vector> &ops); + + /// default destructor + ~ComposeOp() override = default; + + /// return the number of inputs the first tensorOp in compose takes + /// \return number of input tensors + uint32_t NumInput() override { return ops_.front()->NumInput(); } + + /// return the number of outputs the last tensorOp in compose produces + /// \return number of output tensors + uint32_t NumOutput() override { return ops_.back()->NumOutput(); } + + /// \param[in] inputs + /// \param[out] outputs + /// \return Status code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + /// \param[in] inputs + /// \param[out] outputs + /// \return Status code + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + /// \param[in] input + /// \param[out] output + /// \return Status code + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kComposeOp; } + + private: + std::vector> ops_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_COMPOSE_OP_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h index 46cc613049..cf2afb30d2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ -#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_CONCATENATE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_CONCATENATE_OP_H_ #include #include @@ -40,7 +40,6 @@ class ConcatenateOp : public TensorOp { /// Print method to see which tensor Op this is. /// @param std::ostream &out - output stream object. - void Print(std::ostream &out) const override { out << "ConcatenateOp"; } /// Compute method allowing multiple tensors as inputs /// @param TensorRow &input - input tensor rows diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index b1d51a6c08..5632dddeec 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -97,7 +97,7 @@ Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *ou if (input->Rank() == 1) num_elements = input->shape()[0]; TensorShape out_shape({num_elements, num_classes}); std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(out_shape, input->type(), &out)); RETURN_IF_NOT_OK(out->Zero()); for (dsize_t i = 0; i < num_elements; ++i) { if (input->type().IsUnsignedInt()) { @@ -133,7 +133,9 @@ Status Fill(const std::shared_ptr input, std::shared_ptr *output fill_output = fill_value; } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type)); + if (input_type.IsNumeric()) { + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input_shape, input_type, &out)); + } switch (input_type.value()) { case DataType::DE_BOOL: { @@ -216,7 +218,7 @@ Status Fill(const std::shared_ptr input, std::shared_ptr *output for (int i = 0; i < input_shape.NumOfElements(); i++) { strings.emplace_back(fill_string); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, input_shape, &out)); break; } case DataType::DE_UNKNOWN: { @@ -285,9 +287,8 @@ void CastFrom(const std::shared_ptr &input, std::shared_ptr *out // Type cast operator Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output)); - RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); switch (input->type().value()) { case DataType::DE_BOOL: CastFrom(input, output); @@ -335,8 +336,7 @@ Status TypeCast(const std::shared_ptr &input, std::shared_ptr *o Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { // initiate new tensor for type cast DataType new_type = DataType("float16"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type)); - RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), new_type, output)); auto in_itr = input->begin(); auto out_itr = (*output)->begin(); @@ -387,7 +387,7 @@ Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr (*dst) = src; // if no padding, copy the pointer } else { CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(pad_shape), src->type(), dst)); auto tensor_type = src->type().value(); if (pad_val == 0) { // if pad with zero, don't care what type it is RETURN_IF_NOT_OK((*dst)->Zero()); @@ -447,7 +447,7 @@ Status PadEndString(const std::shared_ptr &src, std::shared_ptr std::vector cur_ind(src->Rank(), 0); std::vector strings; RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, TensorShape(pad_shape), dst)); } return Status::OK(); } @@ -521,7 +521,7 @@ Status Mask(const std::shared_ptr &input, std::shared_ptr *outpu "Cannot convert constant value to the type of the input tensor."); CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL))); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_BOOL), output)); std::unique_ptr value_cast_op(new TypeCastOp(input->type())); std::shared_ptr casted_value; @@ -580,77 +580,73 @@ Status Mask(const std::shared_ptr &input, std::shared_ptr *outpu Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, std::shared_ptr append) { - CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported"); - CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported"); - axis = Tensor::HandleNeg(axis, input[0]->shape().Rank()); CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported"); - std::shared_ptr out; + TensorShape t = TensorShape::CreateScalar(); + + DataType first_dtype = input[0]->type(); + + TensorRow tensor_list; + if (prepend != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == prepend->type(), "Tensor types do not match"); CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported"); - RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0])); - } else { - out = input[0]; + tensor_list.emplace_back(prepend); } - for (dsize_t i = 1; i < input.size(); i++) { - std::shared_ptr out_t; + + for (dsize_t i = 0; i < input.size(); i++) { + CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == input[i]->type(), "Tensor types do not match"); CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported"); - RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i])); - out = out_t; + tensor_list.emplace_back(input[i]); } - std::shared_ptr out_t; + if (append != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == append->type(), "Tensor types do not match"); CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported"); - RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append)); - } else { - out_t = out; + tensor_list.emplace_back(append); } - output->push_back(out_t); - - return Status::OK(); -} - -Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, - std::shared_ptr append) { - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match"); - TensorShape t({}); - - for (dsize_t i = 0; i < input->shape().Rank(); i++) { + // create final shape + for (dsize_t i = 0; i < tensor_list[0]->shape().Rank(); i++) { if (i != axis) { - t = t.AppendDim(input->shape()[i]); + t = t.AppendDim(tensor_list[0]->shape()[i]); } else { - dsize_t new_shape = input->shape()[i] + append->shape()[i]; - + dsize_t new_shape = 0; + for (dsize_t j = 0; j < tensor_list.size(); j++) { + new_shape = tensor_list[j]->shape()[i] + new_shape; + } t = t.AppendDim(new_shape); } } + std::shared_ptr out; - if (input->type().IsNumeric()) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); + if (input[0]->type().IsNumeric()) { + RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, tensor_list[0]->type(), &out)); + std::vector index(axis + 1, 0); - RETURN_IF_NOT_OK(out->Concatenate({0}, input)); - RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); - *output = out; + int n = index.size() - 1; + for (dsize_t i = 0; i < tensor_list.size(); i++) { + RETURN_IF_NOT_OK(out->InsertTensor({index}, tensor_list[i], true)); + index[n] = index[n] + tensor_list[i]->shape()[axis]; + } } else { std::vector strings; - auto itr = input->begin(); - for (; itr != input->end(); itr++) { - strings.emplace_back(*itr); - } - itr = append->begin(); - for (; itr != append->end(); itr++) { - strings.emplace_back(*itr); + for (dsize_t i = 0; i < tensor_list.size(); i++) { + auto itr = tensor_list[i]->begin(); + for (; itr != tensor_list[i]->end(); itr++) { + strings.emplace_back(*itr); + } } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); - - *output = out; + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, &out)); } + output->push_back(out); + return Status::OK(); } + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h index 141545a583..5e82b41024 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_DATA_UTILS_H_ -#define DATASET_KERNELS_DATA_DATA_UTILS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_DATA_UTILS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_DATA_UTILS_H_ #include #include @@ -152,12 +152,7 @@ Status Mask(const std::shared_ptr &input, std::shared_ptr *outpu Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, std::shared_ptr append); - -// helper for concat, always append to the input, and pass that to the output -Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, - std::shared_ptr append); - } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DATA_DATA_UTILS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_DATA_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc index 57a424704f..c7fc6c1d7e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc @@ -26,7 +26,7 @@ Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { IO_CHECK_VECTOR(input, output); CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, input[0])); + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input[0], &out)); output->push_back(input[0]); output->push_back(out); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h index 60b2d8c33b..bf4aa4bda8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_DUPLICATE_OP_H_ -#define DATASET_KERNELS_DATA_DUPLICATE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_DUPLICATE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_DUPLICATE_OP_H_ #include #include @@ -32,8 +32,6 @@ class DuplicateOp : public TensorOp { ~DuplicateOp() override = default; - void Print(std::ostream &out) const override { out << "DuplicateOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; uint32_t NumOutput() override { return 2; } @@ -42,4 +40,4 @@ class DuplicateOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DUPLICATE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DUPLICATE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h index af0d9e7941..e2761142df 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_FILL_OP_H_ -#define DATASET_KERNELS_DATA_FILL_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_FILL_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_FILL_OP_H_ #include #include @@ -31,7 +31,6 @@ class FillOp : public TensorOp { explicit FillOp(std::shared_ptr value) : fill_value_(value) {} ~FillOp() override = default; - void Print(std::ostream &out) const override { out << "FillOp"; } Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h index e6ac8c3964..762cf1de40 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h @@ -37,8 +37,6 @@ class MaskOp : public TensorOp { ~MaskOp() override = default; - void Print(std::ostream &out) const override { out << "MaskOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; Status OutputType(const std::vector &inputs, std::vector &outputs) override; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h new file mode 100644 index 0000000000..18761cde25 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_NO_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_NO_OP_H_ + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class NoOp : public TensorOp { + public: + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { + *output = input; + return Status::OK(); + } + + std::string Name() const override { return kNoOp; } +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_NO_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h index 06a4823573..629a7e3082 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_ONE_HOT_OP_H_ -#define DATASET_KERNELS_DATA_ONE_HOT_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_ONE_HOT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_ONE_HOT_OP_H_ #include #include @@ -31,8 +31,6 @@ class OneHotOp : public TensorOp { ~OneHotOp() override = default; - void Print(std::ostream &out) const override { out << "OneHotOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; Status OutputShape(const std::vector &inputs, std::vector &outputs) override; @@ -44,4 +42,4 @@ class OneHotOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_ONE_HOT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h index c28f7250e0..019124382a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_ -#define DATASET_KERNELS_DATA_PAD_END_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_PAD_END_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_PAD_END_OP_H_ #include #include @@ -32,8 +32,6 @@ class PadEndOp : public TensorOp { ~PadEndOp() override = default; - void Print(std::ostream &out) const override { out << "PadEndOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; Status OutputShape(const std::vector &inputs, std::vector &outputs) override; @@ -46,4 +44,4 @@ class PadEndOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DATA_PAD_END_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_PAD_END_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc new file mode 100644 index 0000000000..9fe1d875e2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/random_apply_op.h" + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +uint32_t RandomApplyOp::NumOutput() { + if (compose_->NumOutput() != NumInput()) { + MS_LOG(WARNING) << "NumOutput!=NumInput (randomApply would randomly affect number of outputs)."; + return 0; + } + return compose_->NumOutput(); +} + +Status RandomApplyOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(compose_->OutputShape(inputs, outputs)); + // randomApply either runs all ops or do nothing. If the two methods don't give the same result. return unknown shape. + if (inputs != outputs) { // when RandomApply is not applied, input should be the same as output + outputs.clear(); + outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape()); + } + return Status::OK(); +} +Status RandomApplyOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(compose_->OutputType(inputs, outputs)); + if (inputs != outputs) { // when RandomApply is not applied, input should be the same as output + outputs.clear(); + outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN)); + } + return Status::OK(); +} +Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) { + if (rand_double_(gen_) <= prob_) { + RETURN_IF_NOT_OK(compose_->Compute(input, output)); + } else { + IO_CHECK_VECTOR(input, output); + *output = input; // copy over the tensors + } + return Status::OK(); +} +RandomApplyOp::RandomApplyOp(double prob, const std::vector> &ops) + : prob_(prob), gen_(GetSeed()), rand_double_(0, 1) { + compose_ = std::make_unique(ops); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.h new file mode 100644 index 0000000000..8edc3e1d70 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_APPLY_OP_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_APPLY_OP_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/compose_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +class RandomApplyOp : public TensorOp { + public: + /// constructor + /// \param[in] prob probability whether the list of TensorOps will be applied + /// \param[in] ops the list of TensorOps to apply with prob likelihood + explicit RandomApplyOp(double prob, const std::vector> &ops); + + /// default destructor + ~RandomApplyOp() = default; + + /// return the number of inputs the first tensorOp in compose takes + /// \return number of input tensors + uint32_t NumInput() override { return compose_->NumInput(); } + + /// return the number of outputs + /// \return number of output tensors + uint32_t NumOutput() override; + + /// return output shape if randomApply won't affect the output shape, otherwise return unknown shape + /// \param[in] inputs + /// \param[out] outputs + /// \return Status code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + /// return output type if randomApply won't affect the output type, otherwise return unknown type + /// \param[in] inputs + /// \param[out] outputs + /// \return Status code + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + /// \param[in] input + /// \param[out] output + /// \return Status code + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomApplyOp; } + + private: + double prob_; + std::shared_ptr compose_; + std::mt19937 gen_; // mersenne_twister_engine + std::uniform_real_distribution rand_double_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_APPLY_OP_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc new file mode 100644 index 0000000000..ee505b0dc2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/random_choice_op.h" + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +uint32_t RandomChoiceOp::NumInput() { + uint32_t num_input = ops_.front()->NumInput(); + for (auto &op : ops_) { + uint32_t cur_num = op->NumInput(); + if (num_input != cur_num && cur_num > 0) { + MS_LOG(WARNING) << "Unable to determine NumInput, ops in RandomChoice don't take the same number of input."; + return 0; + } + } + return num_input; +} + +uint32_t RandomChoiceOp::NumOutput() { + uint32_t num_output = ops_.front()->NumOutput(); + for (auto &op : ops_) { + uint32_t cur_num = op->NumOutput(); + if (num_output != cur_num) { + MS_LOG(WARNING) << "Unable to determine NumInput, ops in RandomChoice don't have the same number of input."; + return 0; + } + } + return num_output; +} + +Status RandomChoiceOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs)); + for (auto &op : ops_) { + std::vector out_shapes; + RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes)); + if (outputs != out_shapes) { + MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape."; + outputs.clear(); + outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape()); + return Status::OK(); + } + } + return Status::OK(); +} + +Status RandomChoiceOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs)); + for (auto &op : ops_) { + std::vector out_types; + RETURN_IF_NOT_OK(op->OutputType(inputs, out_types)); + if (outputs != out_types) { + MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType."; + outputs.clear(); + outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN)); + return Status::OK(); + } + } + return Status::OK(); +} + +Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) { + size_t rand_num = rand_int_(gen_); + CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num)); + RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output)); + return Status::OK(); +} +RandomChoiceOp::RandomChoiceOp(const std::vector> &ops) + : ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) { + if (ops_.empty()) { + MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty."; + } else if (ops_.size() == 1) { + MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.h new file mode 100644 index 0000000000..7244ca8931 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_CHOICE_OP_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_CHOICE_OP_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/compose_op.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +class RandomChoiceOp : public TensorOp { + public: + /// constructor + /// \param[in] ops list of TensorOps to randomly choose 1 from + explicit RandomChoiceOp(const std::vector> &ops); + + /// default destructor + ~RandomChoiceOp() = default; + + /// return the number of inputs. All op in ops_ should have the same number of inputs + /// \return number of input tensors + uint32_t NumInput() override; + + /// return the number of outputs. All op in ops_ should have the same number of outputs + /// \return number of input tensors + uint32_t NumOutput() override; + + /// return output shape if all ops in ops_ return the same shape, otherwise return unknown shape + /// \param[in] inputs + /// \param[out] outputs + /// \return Status code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + /// return output type if all ops in ops_ return the same type, otherwise return unknown type + /// \param[in] inputs + /// \param[out] outputs + /// \return Status code + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + /// \param[in] input + /// \param[out] output + /// \return Status code + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomChoiceOp; } + + private: + std::vector> ops_; + std::mt19937 gen_; // mersenne_twister_engine + std::uniform_int_distribution rand_int_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RANDOM_CHOICE_OP_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h index 1cf99830c9..39042cf6d4 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_ -#define DATASET_KERNELS_DATA_SLICE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_SLICE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_SLICE_OP_H_ #include #include @@ -67,8 +67,6 @@ class SliceOp : public TensorOp { ~SliceOp() override = default; - void Print(std::ostream &out) const override { out << "SliceOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; std::string Name() const override { return kSliceOp; } @@ -84,4 +82,4 @@ class SliceOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DATA_SLICE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_SLICE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h index 91f660ca9c..6d0b2f6f30 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h @@ -39,8 +39,6 @@ class ToFloat16Op : public TensorOp { // @return Status - The error code return Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - void Print(std::ostream &out) const override { out << "ToFloat16Op"; } - Status OutputType(const std::vector &inputs, std::vector &outputs) override; std::string Name() const override { return kToFloat16Op; } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h index b82bc32342..64a84713da 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ -#define DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ #include #include @@ -39,7 +39,6 @@ class TypeCastOp : public TensorOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - void Print(std::ostream &out) const override { out << "TypeCastOp"; } Status OutputType(const std::vector &inputs, std::vector &outputs) override; std::string Name() const override { return kTypeCastOp; } @@ -50,4 +49,4 @@ class TypeCastOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index c0c575de9a..fc4a6790be 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -1,11 +1,15 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(kernels-image OBJECT + auto_contrast_op.cc center_crop_op.cc + crop_op.cc cut_out_op.cc decode_op.cc + equalize_op.cc hwc_to_chw_op.cc image_utils.cc + invert_op.cc normalize_op.cc pad_op.cc random_color_adjust_op.cc @@ -19,11 +23,13 @@ add_library(kernels-image OBJECT bounding_box_augment_op.cc random_resize_op.cc random_rotation_op.cc + random_select_subpolicy_op.cc random_vertical_flip_op.cc random_vertical_flip_with_bbox_op.cc rescale_op.cc resize_bilinear_op.cc resize_op.cc + swap_red_blue_op.cc uniform_aug_op.cc resize_with_bbox_op.cc random_resize_with_bbox_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc new file mode 100644 index 0000000000..417d16783c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "minddata/dataset/kernels/image/auto_contrast_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { + +const float AutoContrastOp::kCutOff = 0.0; +const std::vector AutoContrastOp::kIgnore = {}; + +Status AutoContrastOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return AutoContrast(input, output, cutoff_, ignore_); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h new file mode 100644 index 0000000000..6d8c847b67 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class AutoContrastOp : public TensorOp { + public: + /// Default cutoff to be used + static const float kCutOff; + /// Default ignore to be used + static const std::vector kIgnore; + + AutoContrastOp(const float &cutoff, const std::vector &ignore) : cutoff_(cutoff), ignore_(ignore) {} + + ~AutoContrastOp() override = default; + + /// Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const AutoContrastOp &so) { + so.Print(out); + return out; + } + + std::string Name() const override { return kAutoContrastOp; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + private: + float cutoff_; + std::vector ignore_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h index 8e30c5738d..c992c9196e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ -#define DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ #include #include @@ -47,8 +47,6 @@ class BoundingBoxAugmentOp : public TensorOp { return out; } - void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kBoundingBoxAugmentOp; } @@ -62,4 +60,4 @@ class BoundingBoxAugmentOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc index 35079b05cd..a27b2cb000 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc @@ -15,7 +15,7 @@ */ #include "minddata/dataset/kernels/image/center_crop_op.h" #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/util/status.h" @@ -36,6 +36,10 @@ Status CenterCropOp::Compute(const std::shared_ptr &input, std::shared_p int32_t top = crop_het_ - input->shape()[0]; // number of pixels to pad (top and bottom) int32_t left = crop_wid_ - input->shape()[1]; std::shared_ptr pad_image; + + CHECK_FAIL_RETURN_UNEXPECTED((top < input->shape()[0] * 3 && left < input->shape()[1] * 3), + "CenterCropOp padding size is too big, it's more than 3 times the original size."); + if (top > 0 && left > 0) { // padding only return Pad(input, output, top / 2 + top % 2, top / 2, left / 2 + left % 2, left / 2, BorderType::kConstant); } else if (top > 0) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h index 1f8cbcf230..532aaf70d5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ #include #include @@ -49,4 +49,4 @@ class CenterCropOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/crop_op.cc new file mode 100644 index 0000000000..5d1853b78d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/crop_op.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/crop_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status CropOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape size " + std::to_string(input->shape().Size()) + + " of input tensor is invalid"); + int32_t input_h = static_cast(input->shape()[0]); + int32_t input_w = static_cast(input->shape()[1]); + CHECK_FAIL_RETURN_UNEXPECTED(y_ + height_ <= input_h, "Crop height dimensions exceed image dimensions"); + CHECK_FAIL_RETURN_UNEXPECTED(x_ + width_ <= input_w, "Crop width dimensions exceed image dimensions"); + return Crop(input, output, x_, y_, height_, width_); +} + +Status CropOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out = TensorShape{height_, width_}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/crop_op.h new file mode 100644 index 0000000000..35d3ebeda5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/crop_op.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CROP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CROP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class CropOp : public TensorOp { + public: + /// \brief Constructor to Crop Op + /// \param[in] x - the horizontal starting coordinate + /// \param[in] y - the vertical starting coordinate + /// \param[in] height - the height of the crop box + /// \param[in] width - the width of the crop box + explicit CropOp(int32_t x, int32_t y, int32_t height, int32_t width) : x_(x), y_(y), height_(height), width_(width) {} + + CropOp(const CropOp &rhs) = default; + + CropOp(CropOp &&rhs) = default; + + ~CropOp() override = default; + + void Print(std::ostream &out) const override { + out << "CropOp x: " << x_ << " y: " << y_ << " w: " << width_ << " h: " << height_; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kCropOp; } + + protected: + int32_t x_; + int32_t y_; + int32_t height_; + int32_t width_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CROP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h index 263cbdb27c..3893c84fb3 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ -#define DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ #include #include @@ -76,4 +76,4 @@ class CutOutOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h index 29bf1d0146..b5bfe5a014 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_DECODE_OP_H_ -#define DATASET_KERNELS_IMAGE_DECODE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DECODE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DECODE_OP_H_ #include #include @@ -37,7 +37,6 @@ class DecodeOp : public TensorOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - void Print(std::ostream &out) const override { out << "DecodeOp"; } Status OutputShape(const std::vector &inputs, std::vector &outputs) override; Status OutputType(const std::vector &inputs, std::vector &outputs) override; @@ -49,4 +48,4 @@ class DecodeOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_DECODE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DECODE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc new file mode 100644 index 0000000000..e5bf0fd628 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/equalize_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { + +// only supports RGB images + +Status EqualizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return Equalize(input, output); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h new file mode 100644 index 0000000000..d7bc46b480 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class EqualizeOp : public TensorOp { + public: + EqualizeOp() {} + ~EqualizeOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kEqualizeOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h index 0d5f70f895..c2eb2e8759 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ -#define DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ #include #include @@ -28,8 +28,6 @@ namespace mindspore { namespace dataset { class HwcToChwOp : public TensorOp { public: - void Print(std::ostream &out) const override { out << "HwcToChw"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; Status OutputShape(const std::vector &inputs, std::vector &outputs) override; @@ -38,4 +36,4 @@ class HwcToChwOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index ddbce3e23a..86de12597b 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -20,7 +20,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/core/tensor.h" @@ -63,9 +63,8 @@ int GetCVBorderType(BorderType type) { Status Flip(std::shared_ptr input, std::shared_ptr *output, int flip_code) { std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes())); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); if (input_cv->mat().data) { try { @@ -110,8 +109,9 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out TensorShape shape{output_height, output_width}; int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &output_cv)); + auto cv_mode = GetCVInterpolationMode(mode); cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode); *output = std::static_pointer_cast(output_cv); @@ -147,8 +147,8 @@ Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *o RETURN_STATUS_UNEXPECTED(err); } cv::cvtColor(img_mat, img_mat, static_cast(cv::COLOR_BGR2RGB)); - std::shared_ptr output_cv = std::make_shared(img_mat); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(img_mat, &output_cv)); *output = std::static_pointer_cast(output_cv); return Status::OK(); } catch (const cv::Exception &e) { @@ -309,7 +309,8 @@ Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr(ts, DataType(DataType::DE_UINT8)); + std::shared_ptr output_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(ts, DataType(DataType::DE_UINT8), &output_tensor)); const int buffer_size = output_tensor->SizeInBytes(); JSAMPLE *buffer = reinterpret_cast(&(*output_tensor->begin())); const int max_scanlines_to_read = skipped_scanlines + crop_h; @@ -331,8 +332,8 @@ Status Rescale(const std::shared_ptr &input, std::shared_ptr *ou RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); } cv::Mat input_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv)); try { input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift); *output = std::static_pointer_cast(output_cv); @@ -354,8 +355,8 @@ Status Crop(const std::shared_ptr &input, std::shared_ptr *outpu TensorShape shape{h, w}; int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &output_cv)); cv::Rect roi(x, y, w, h); (input_cv->mat())(roi).copyTo(output_cv->mat()); *output = std::static_pointer_cast(output_cv); @@ -386,10 +387,11 @@ Status HwcToChw(std::shared_ptr input, std::shared_ptr *output) int height = input_cv->shape()[0]; int width = input_cv->shape()[1]; - auto output_cv = std::make_unique(TensorShape{num_channels, height, width}, input_cv->type()); + std::shared_ptr output_cv; + CVTensor::CreateEmpty(TensorShape{num_channels, height, width}, input_cv->type(), &output_cv); for (int i = 0; i < num_channels; ++i) { cv::Mat mat; - RETURN_IF_NOT_OK(output_cv->Mat({i}, &mat)); + RETURN_IF_NOT_OK(output_cv->MatAtIndex({i}, &mat)); cv::extractChannel(input_cv->mat(), mat, i); } *output = std::move(output_cv); @@ -406,8 +408,9 @@ Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *ou if (input_cv->shape().Size() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); + cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_BGR2RGB)); *output = std::static_pointer_cast(output_cv); return Status::OK(); @@ -440,8 +443,8 @@ Status CropAndResize(const std::shared_ptr &input, std::shared_ptrshape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr cvt_out = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(cvt_out); + std::shared_ptr cvt_out; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &cvt_out)); cv::resize(cv_in(roi), cvt_out->mat(), cv::Size(target_width, target_height), 0, 0, cv_mode); *output = std::static_pointer_cast(cvt_out); return Status::OK(); @@ -475,8 +478,7 @@ Status Rotate(const std::shared_ptr &input, std::shared_ptr *out if (!expand) { // this case means that the shape doesn't change, size stays the same // We may not need this memcpy if it is in place. - output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); // using inter_nearest to comply with python default cv::warpAffine(input_img, output_cv->mat(), rot, input_img.size(), GetCVInterpolationMode(interpolation), cv::BORDER_CONSTANT, fill_color); @@ -489,7 +491,7 @@ Status Rotate(const std::shared_ptr &input, std::shared_ptr *out // use memcpy and don't compute the new shape since openCV has a rounding problem cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation), cv::BORDER_CONSTANT, fill_color); - output_cv = std::make_shared(output_img); + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &output_cv)); RETURN_UNEXPECTED_IF_NULL(output_cv); } *output = std::static_pointer_cast(output_cv); @@ -506,8 +508,8 @@ Status Normalize(const std::shared_ptr &input, std::shared_ptr * RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); } cv::Mat in_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv)); mean->Squeeze(); if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { std::string err_msg = "Mean tensor should be of size 3 and type float."; @@ -548,8 +550,8 @@ Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptrRank() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); output_cv->mat() = input_img * alpha; *output = std::static_pointer_cast(output_cv); } catch (const cv::Exception &e) { @@ -572,8 +574,8 @@ Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr(cv::mean(gray).val[0] + 0.5); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); output_img = cv::Mat::zeros(input_img.rows, input_img.cols, CV_8UC1); output_img = output_img + mean_img; cv::cvtColor(output_img, output_img, CV_GRAY2RGB); @@ -585,6 +587,112 @@ Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &cutoff, + const std::vector &ignore) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Shape not or "); + } + // Reshape to extend dimension if rank is 2 for algorithm to work. then reshape output to be of rank 2 like input + if (input_cv->Rank() == 2) { + RETURN_IF_NOT_OK(input_cv->ExpandDim(2)); + } + // Get number of channels and image matrix + std::size_t num_of_channels = input_cv->shape()[2]; + if (num_of_channels != 1 && num_of_channels != 3) { + RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); + } + cv::Mat image = input_cv->mat(); + // Separate the image to channels + std::vector planes(num_of_channels); + cv::split(image, planes); + cv::Mat b_hist, g_hist, r_hist; + // Establish the number of bins and set variables for histogram + int32_t hist_size = 256; + int32_t channels = 0; + float range[] = {0, 256}; + const float *hist_range[] = {range}; + bool uniform = true, accumulate = false; + // Set up lookup table for LUT(Look up table algorithm) + std::vector table; + std::vector image_result; + for (std::size_t layer = 0; layer < planes.size(); layer++) { + // Reset lookup table + table = std::vector{}; + // Calculate Histogram for channel + cv::Mat hist; + cv::calcHist(&planes[layer], 1, &channels, cv::Mat(), hist, 1, &hist_size, hist_range, uniform, accumulate); + hist.convertTo(hist, CV_32SC1); + std::vector hist_vec; + hist.col(0).copyTo(hist_vec); + // Ignore values in ignore + for (const auto &item : ignore) hist_vec[item] = 0; + int32_t n = std::accumulate(hist_vec.begin(), hist_vec.end(), 0); + // Find pixel values that are in the low cutoff and high cutoff. + int32_t cut = static_cast((cutoff / 100.0) * n); + if (cut != 0) { + for (int32_t lo = 0; lo < 256 && cut > 0; lo++) { + if (cut > hist_vec[lo]) { + cut -= hist_vec[lo]; + hist_vec[lo] = 0; + } else { + hist_vec[lo] -= cut; + cut = 0; + } + } + cut = static_cast((cutoff / 100.0) * n); + for (int32_t hi = 255; hi >= 0 && cut > 0; hi--) { + if (cut > hist_vec[hi]) { + cut -= hist_vec[hi]; + hist_vec[hi] = 0; + } else { + hist_vec[hi] -= cut; + cut = 0; + } + } + } + int32_t lo = 0; + int32_t hi = 255; + for (; lo < 256 && !hist_vec[lo]; lo++) { + } + for (; hi >= 0 && !hist_vec[hi]; hi--) { + } + if (hi <= lo) { + for (int32_t i = 0; i < 256; i++) { + table.push_back(i); + } + } else { + const float scale = 255.0 / (hi - lo); + const float offset = -1 * lo * scale; + for (int32_t i = 0; i < 256; i++) { + int32_t ix = static_cast(i * scale + offset); + ix = std::max(ix, 0); + ix = std::min(ix, 255); + table.push_back(ix); + } + } + cv::Mat result_layer; + cv::LUT(planes[layer], table, result_layer); + image_result.push_back(result_layer); + } + cv::Mat result; + cv::merge(image_result, result); + result.convertTo(result, input_cv->mat().type()); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv)); + (*output) = std::static_pointer_cast(output_cv); + (*output) = std::static_pointer_cast(output_cv); + (*output)->Reshape(input->shape()); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in auto contrast"); + } + return Status::OK(); +} + Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { try { std::shared_ptr input_cv = CVTensor::AsCVTensor(input); @@ -596,8 +704,8 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptrRank() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); cv::Mat output_img = output_cv->mat(); cv::Mat gray; cv::cvtColor(input_img, gray, CV_RGB2GRAY); @@ -625,8 +733,8 @@ Status AdjustHue(const std::shared_ptr &input, std::shared_ptr * if (input_cv->Rank() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); cv::Mat output_img; cv::cvtColor(input_img, output_img, CV_RGB2HSV_FULL); for (int y = 0; y < output_img.cols; y++) { @@ -646,6 +754,47 @@ Status AdjustHue(const std::shared_ptr &input, std::shared_ptr * return Status::OK(); } +Status Equalize(const std::shared_ptr &input, std::shared_ptr *output) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Shape not or "); + } + // For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2. + if (input_cv->Rank() == 2) { + RETURN_IF_NOT_OK(input_cv->ExpandDim(2)); + } + // Get number of channels and image matrix + std::size_t num_of_channels = input_cv->shape()[2]; + if (num_of_channels != 1 && num_of_channels != 3) { + RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); + } + cv::Mat image = input_cv->mat(); + // Separate the image to channels + std::vector planes(num_of_channels); + cv::split(image, planes); + // Equalize each channel separately + std::vector image_result; + for (std::size_t layer = 0; layer < planes.size(); layer++) { + cv::Mat channel_result; + cv::equalizeHist(planes[layer], channel_result); + image_result.push_back(channel_result); + } + cv::Mat result; + cv::merge(image_result, result); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv)); + (*output) = std::static_pointer_cast(output_cv); + (*output)->Reshape(input->shape()); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in equalize."); + } + return Status::OK(); +} + Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { @@ -723,9 +872,11 @@ Status Pad(const std::shared_ptr &input, std::shared_ptr *output } else { cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type); } - std::shared_ptr output_cv = std::make_shared(out_image); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_image, &output_cv)); // pad the dimension if shape information is only 2 dimensional, this is grayscale + CHECK_FAIL_RETURN_UNEXPECTED(input_cv->Rank() == 3, + "Pad error: invalid image shape, only support 3 channels image."); int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2); *output = std::static_pointer_cast(output_cv); @@ -788,7 +939,7 @@ Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, } } std::shared_ptr retV; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast(*bboxCount), bboxDim}))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(copyVals, TensorShape({static_cast(*bboxCount), bboxDim}), &retV)); (*bboxList) = retV; // reset pointer return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index f489c7367b..4e1b689bec 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ -#define DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ #include @@ -175,6 +175,14 @@ Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &alpha); +// Returns image with contrast maximized. +// @param input: Tensor of shape // in RGB/Grayscale and any OpenCv compatible type, see CVTensor. +// @param cutoff: Cutoff percentage of how many pixels are to be removed (high pixels change to 255 and low change to 0) +// from the high and low ends of the histogram. +// @param ignore: Pixel values to be ignored in the algorithm. +Status AutoContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &cutoff, + const std::vector &ignore); + // Returns image with adjusted saturation. // @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. // @param alpha: Alpha value to adjust saturation by. Should be a positive number. @@ -192,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &hue); +/// \brief Returns image with equalized histogram. +/// \param[in] input: Tensor of shape // in RGB/Grayscale and +/// any OpenCv compatible type, see CVTensor. +/// \param[out] output: Equalized image of same shape and type. +Status Equalize(const std::shared_ptr &input, std::shared_ptr *output); + // Masks out a random section from the image with set dimension // @param input: input Tensor // @param output: cutOut Tensor @@ -256,4 +270,4 @@ Status UpdateBBoxesForResize(const std::shared_ptr &bboxList, const size } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc new file mode 100644 index 0000000000..ed46194baa --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/invert_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +// only supports RGB images + +Status InvertOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + cv::Mat input_img = input_cv->mat(); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + + if (input_cv->Rank() != 3) { + RETURN_STATUS_UNEXPECTED("Shape not "); + } + int num_channels = input_cv->shape()[2]; + if (num_channels != 3) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: num of channels != 3"); + } + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); + RETURN_UNEXPECTED_IF_NULL(output_cv); + + output_cv->mat() = cv::Scalar::all(255) - input_img; + *output = std::static_pointer_cast(output_cv); + } + + catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in invert"); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.h new file mode 100644 index 0000000000..3a9ee464c9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_INVERT_OP_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_INVERT_OP_H + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class InvertOp : public TensorOp { + public: + InvertOp() {} + ~InvertOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kInvertOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_INVERT_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc index de5deb31ef..56593e33ca 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc @@ -24,20 +24,14 @@ namespace mindspore { namespace dataset { NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) { - int size[] = {3}; - cv::Mat mean_cv(1, size, CV_32F); - mean_cv.at(0) = mean_r; - mean_cv.at(1) = mean_g; - mean_cv.at(2) = mean_b; - mean_ = std::make_shared(mean_cv); - mean_->Squeeze(); - - cv::Mat std_cv(1, size, CV_32F); - std_cv.at(0) = std_r; - std_cv.at(1) = std_g; - std_cv.at(2) = std_b; - std_ = std::make_shared(std_cv); - std_->Squeeze(); + Status s = Tensor::CreateFromVector({mean_r, mean_g, mean_b}, &mean_); + if (s.IsError()) { + MS_LOG(ERROR) << "Could not create mean tensor."; + } + s = Tensor::CreateFromVector({std_r, std_g, std_b}, &std_); + if (s.IsError()) { + MS_LOG(ERROR) << "Could not create std tensor."; + } } Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { @@ -47,9 +41,7 @@ Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_pt } void NormalizeOp::Print(std::ostream &out) const { - out << "NormalizeOp, mean: " << mean_->mat().at(0) << ", " << mean_->mat().at(1) << ", " - << mean_->mat().at(2) << "std: " << std_->mat().at(0) << ", " << std_->mat().at(1) << ", " - << std_->mat().at(2) << std::endl; + out << "NormalizeOp, mean: " << mean_ << std::endl << "std: " << std_ << std::endl; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h index 7821869c8f..4e4b760abd 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ #include #include @@ -39,10 +39,10 @@ class NormalizeOp : public TensorOp { std::string Name() const override { return kNormalizeOp; } private: - std::shared_ptr mean_; - std::shared_ptr std_; + std::shared_ptr mean_; + std::shared_ptr std_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h index 9437058406..7f2e313c38 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_PAD_OP_H_ -#define DATASET_KERNELS_IMAGE_PAD_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_PAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_PAD_OP_H_ #include #include @@ -49,8 +49,6 @@ class PadOp : public TensorOp { ~PadOp() override = default; - void Print(std::ostream &out) const override { out << "PadOp: "; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; Status OutputShape(const std::vector &inputs, std::vector &outputs) override; @@ -69,4 +67,4 @@ class PadOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_PAD_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_PAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h index fb29b57062..5555963ac5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ #include #include @@ -47,10 +47,6 @@ class RandomColorAdjustOp : public TensorOp { ~RandomColorAdjustOp() override = default; - // Print function for RandomJitter. - // @param out output stream to print to. - void Print(std::ostream &out) const override { out << "RandomColorAdjustOp: "; } - // Overrides the base class compute function. // Calls multiple transform functions in ImageUtils, this function takes an input tensor. // and transforms its data using openCV, the output memory is manipulated to contain the result. @@ -77,4 +73,4 @@ class RandomColorAdjustOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h index 41d775fdf7..ed92c0cfdd 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ #include #include @@ -75,4 +75,4 @@ class RandomCropAndResizeOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h index ddaac10fac..43255b5e18 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" #include @@ -46,4 +46,4 @@ class RandomCropAndResizeWithBBoxOp : public RandomCropAndResizeOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h index 863fd48c14..161bdaf42f 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ #include #include @@ -40,8 +40,7 @@ class RandomCropDecodeResizeOp : public RandomCropAndResizeOp { ~RandomCropDecodeResizeOp() override = default; void Print(std::ostream &out) const override { - out << "RandomCropDecodeResize: " << RandomCropAndResizeOp::target_height_ << " " - << RandomCropAndResizeOp::target_width_; + out << Name() << ": " << RandomCropAndResizeOp::target_height_ << " " << RandomCropAndResizeOp::target_width_; } Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; @@ -51,4 +50,4 @@ class RandomCropDecodeResizeOp : public RandomCropAndResizeOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc index 51772e9ec3..c53e0c06d2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc @@ -56,6 +56,10 @@ Status RandomCropOp::ImagePadding(const std::shared_ptr &input, std::sha *t_pad_left = pad_left_; *t_pad_right = pad_right_; + CHECK_FAIL_RETURN_UNEXPECTED(pad_top_ < input->shape()[0] * 3 && pad_bottom_ < input->shape()[0] * 3 && + pad_left_ < input->shape()[1] * 3 && pad_right_ < input->shape()[1] * 3, + "RandomCropBBoxOp padding size is too big, it's more than 3 times the original size."); + RETURN_IF_NOT_OK( Pad(input, pad_image, pad_top_, pad_bottom_, pad_left_, pad_right_, border_type_, fill_r_, fill_g_, fill_b_)); CHECK_FAIL_RETURN_UNEXPECTED((*pad_image)->shape().Size() >= 2, "Abnormal shape"); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h index 44f1789f9d..3dfb3f713d 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ #include #include @@ -52,7 +52,7 @@ class RandomCropOp : public TensorOp { ~RandomCropOp() override = default; - void Print(std::ostream &out) const override { out << "RandomCropOp: " << crop_height_ << " " << crop_width_; } + void Print(std::ostream &out) const override { out << Name() << ": " << crop_height_ << " " << crop_width_; } Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; @@ -98,4 +98,4 @@ class RandomCropOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h index bfcd1610d3..479f087954 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ #include #include @@ -38,7 +38,7 @@ class RandomCropWithBBoxOp : public RandomCropOp { ~RandomCropWithBBoxOp() override = default; void Print(std::ostream &out) const override { - out << "RandomCropWithBBoxOp: " << RandomCropOp::crop_height_ << " " << RandomCropOp::crop_width_; + out << Name() << ": " << RandomCropOp::crop_height_ << " " << RandomCropOp::crop_width_; } Status Compute(const TensorRow &input, TensorRow *output) override; @@ -48,4 +48,4 @@ class RandomCropWithBBoxOp : public RandomCropOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h index 9e08929180..53c11df1a6 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ #include #include @@ -44,8 +44,6 @@ class RandomHorizontalFlipOp : public TensorOp { return out; } - void Print(std::ostream &out) const override { out << "RandomHorizontalFlipOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; std::string Name() const override { return kRandomHorizontalFlipOp; } @@ -57,4 +55,4 @@ class RandomHorizontalFlipOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h index d98669ea13..f8e1e847f6 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ #include #include @@ -45,8 +45,6 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp { return out; } - void Print(std::ostream &out) const override { out << "RandomHorizontalFlipWithBBoxOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kRandomHorizontalFlipWithBBoxOp; } @@ -58,4 +56,4 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h index 8b2b067751..77dee5b4d9 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ #include #include @@ -40,9 +40,7 @@ class RandomResizeOp : public ResizeOp { ~RandomResizeOp() = default; // Description: A function that prints info about the node - void Print(std::ostream &out) const override { - out << "RandomResizeOp: " << ResizeOp::size1_ << " " << ResizeOp::size2_; - } + void Print(std::ostream &out) const override { out << Name() << ": " << ResizeOp::size1_ << " " << ResizeOp::size2_; } Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; @@ -55,4 +53,4 @@ class RandomResizeOp : public ResizeOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h index 6bad0d30fa..dbca032520 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H -#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H #include #include @@ -42,7 +42,7 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { // Description: A function that prints info about the node void Print(std::ostream &out) const override { - out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; + out << Name() << ": " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; } Status Compute(const TensorRow &input, TensorRow *output) override; @@ -56,4 +56,4 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h index ea679ccb56..bdd5cda97a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ #include #include @@ -58,10 +58,6 @@ class RandomRotationOp : public TensorOp { ~RandomRotationOp() override = default; - // Print function for RandomRotation - // @param out output stream to print to - void Print(std::ostream &out) const override { out << "RandomRotationOp: "; } - // Overrides the base class compute function // Calls the rotate function in ImageUtils, this function takes an input tensor // and transforms its data using openCV, the output memory is manipulated to contain the result @@ -87,4 +83,4 @@ class RandomRotationOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc new file mode 100644 index 0000000000..d01231f1f8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h" + +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) { + TensorRow in_row = input; + size_t rand_num = rand_int_(gen_); + CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(), "invalid rand_num:" + std::to_string(rand_num)); + for (auto &sub : policy_[rand_num]) { + if (rand_double_(gen_) <= sub.second) { + RETURN_IF_NOT_OK(sub.first->Compute(in_row, output)); + in_row = std::move(*output); + } + } + *output = std::move(in_row); + return Status::OK(); +} + +uint32_t RandomSelectSubpolicyOp::NumInput() { + uint32_t num_in = policy_.front().front().first->NumInput(); + for (auto &sub : policy_) { + for (auto p : sub) { + if (num_in != p.first->NumInput()) { + MS_LOG(WARNING) << "Unable to determine numInput."; + return 0; + } + } + } + return num_in; +} + +uint32_t RandomSelectSubpolicyOp::NumOutput() { + uint32_t num_out = policy_.front().front().first->NumOutput(); + for (auto &sub : policy_) { + for (auto p : sub) { + if (num_out != p.first->NumOutput()) { + MS_LOG(WARNING) << "Unable to determine numInput."; + return 0; + } + } + } + return num_out; +} + +Status RandomSelectSubpolicyOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + outputs.clear(); + outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape()); + return Status::OK(); +} + +Status RandomSelectSubpolicyOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs)); + for (auto &sub : policy_) { + for (auto p : sub) { + std::vector tmp_types; + RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types)); + if (outputs != tmp_types) { + outputs.clear(); + outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN)); + return Status::OK(); + } + } + } + return Status::OK(); +} +RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector &policy) + : gen_(GetSeed()), policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) { + if (policy_.empty()) { + MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty."; + } +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.h new file mode 100644 index 0000000000..aea76c2ace --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { + +using Subpolicy = std::vector, double>>; + +class RandomSelectSubpolicyOp : public TensorOp { + public: + /// constructor + /// \param[in] policy policy to choose subpolicy from + explicit RandomSelectSubpolicyOp(const std::vector &policy); + + /// destructor + ~RandomSelectSubpolicyOp() override = default; + + /// return number of input tensors + /// \return number of inputs if all ops in policy have the same NumInput, otherwise return 0 + uint32_t NumInput() override; + + /// return number of output tensors + /// \return number of outputs if all ops in policy have the same NumOutput, otherwise return 0 + uint32_t NumOutput() override; + + /// return unknown shapes + /// \param[in] inputs + /// \param[out] outputs + /// \return Status Code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + /// return output type if all ops in policy return the same type, otherwise return unknown type + /// \param[in] inputs + /// \param[out] outputs + /// \return Status Code + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + /// \param[in] input + /// \param[out] output + /// \return Status code + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomSelectSubpolicyOp; } + + private: + std::vector policy_; + std::mt19937 gen_; // mersenne_twister_engine + std::uniform_int_distribution rand_int_; + std::uniform_real_distribution rand_double_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h index cee5869c71..1724d7a57d 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ #include #include @@ -38,8 +38,6 @@ class RandomVerticalFlipOp : public TensorOp { ~RandomVerticalFlipOp() override = default; - void Print(std::ostream &out) const override { out << "RandomVerticalFlipOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; std::string Name() const override { return kRandomVerticalFlipOp; } @@ -51,4 +49,4 @@ class RandomVerticalFlipOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h index c9f19f5217..f46101cc48 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ #include #include @@ -29,7 +29,6 @@ namespace mindspore { namespace dataset { class RandomVerticalFlipWithBBoxOp : public TensorOp { public: - // Default values, also used by python_bindings.cc static const float kDefProbability; // Constructor for RandomVerticalFlipWithBBoxOp // @param probability: Probablity of Image flipping, 0.5 by default @@ -39,8 +38,6 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp { ~RandomVerticalFlipWithBBoxOp() override = default; - void Print(std::ostream &out) const override { out << "RandomVerticalFlipWithBBoxOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kRandomVerticalFlipWithBBoxOp; } @@ -52,4 +49,4 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h index c70b7bf6cf..1a6f597ad5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RESCALE_OP_H_ -#define DATASET_KERNELS_IMAGE_RESCALE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESCALE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESCALE_OP_H_ #include #include @@ -33,7 +33,7 @@ class RescaleOp : public TensorOp { ~RescaleOp() override = default; void Print(std::ostream &out) const override { - out << "RescaleOp: shift: " << shift_ << ", Rescale: " << rescale_ << std::endl; + out << Name() << ": shift: " << shift_ << ", Rescale: " << rescale_ << std::endl; } Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; @@ -47,4 +47,4 @@ class RescaleOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RESCALE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESCALE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc index 48a8fbbc53..db10f5f10c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc @@ -22,6 +22,5 @@ namespace mindspore { namespace dataset { const int32_t ResizeBilinearOp::kDefWidth = 0; -void ResizeBilinearOp::Print(std::ostream &out) const { out << "ResizeBilinearOp: "; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h index fd8f940946..ab5ecd292a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ -#define DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ #include #include @@ -48,13 +48,9 @@ class ResizeBilinearOp : public ResizeOp { // Description: Destructor ~ResizeBilinearOp() = default; - // Name: Print() - // Description: A function that prints info about the node - void Print(std::ostream &out) const override; - std::string Name() const override { return kResizeBilinearOp; } }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h index 3f847243ff..149cab6e85 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RESIZE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_OP_H_ #include #include @@ -50,7 +50,7 @@ class ResizeOp : public TensorOp { ~ResizeOp() override = default; - void Print(std::ostream &out) const override { out << "ResizeOp: " << size1_ << " " << size2_; } + void Print(std::ostream &out) const override { out << Name() << ": " << size1_ << " " << size2_; } Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; Status OutputShape(const std::vector &inputs, std::vector &outputs) override; @@ -65,4 +65,4 @@ class ResizeOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RESIZE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc index 9df2d8a25e..8d40514f1b 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc @@ -20,7 +20,6 @@ #include "minddata/dataset/kernels/image/resize_op.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/core/cv_tensor.h" -#include "minddata/dataset/core/pybind_support.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/util/status.h" diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h index d2b5c96bf3..18781da1a5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H -#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H #include #include "minddata/dataset/core/tensor.h" @@ -34,7 +34,7 @@ class ResizeWithBBoxOp : public ResizeOp { ~ResizeWithBBoxOp() override = default; - void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; } + void Print(std::ostream &out) const override { out << Name() << ": " << size1_ << " " << size2_; } Status Compute(const TensorRow &input, TensorRow *output) override; @@ -43,4 +43,4 @@ class ResizeWithBBoxOp : public ResizeOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/swap_red_blue_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/swap_red_blue_op.cc new file mode 100644 index 0000000000..cee93a323a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/swap_red_blue_op.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/swap_red_blue_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status SwapRedBlueOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return SwapRedAndBlue(input, output); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/swap_red_blue_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/swap_red_blue_op.h new file mode 100644 index 0000000000..c42bbb80c7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/swap_red_blue_op.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_SWAP_RED_BLUE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_SWAP_RED_BLUE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class SwapRedBlueOp : public TensorOp { + public: + // SwapRedBlues the image to the output specified size. If only one value is provided, + // the it will crop the smaller size and maintains the aspect ratio. + // @param size1: the first size of output. If only this parameter is provided + // the smaller dimension will be cropd to this and then the other dimension changes + // such that the aspect ratio is maintained. + // @param size2: the second size of output. If this is also provided, the output size + // will be (size1, size2) + // @param InterpolationMode: the interpolation mode being used. + SwapRedBlueOp() {} + + SwapRedBlueOp(const SwapRedBlueOp &rhs) = default; + + SwapRedBlueOp(SwapRedBlueOp &&rhs) = default; + + ~SwapRedBlueOp() override = default; + + void Print(std::ostream &out) const override { out << "SwapRedBlueOp x"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kSwapRedBlueOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_SWAP_RED_BLUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h index 0ae0fda92b..ec7080d1de 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ -#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ #include #include @@ -40,7 +40,7 @@ class UniformAugOp : public TensorOp { // Destructor ~UniformAugOp() override = default; - void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } + void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; } // Overrides the base class compute function // @return Status - The error code return @@ -56,4 +56,4 @@ class UniformAugOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/no_op.h b/mindspore/ccsrc/minddata/dataset/kernels/no_op.h deleted file mode 100644 index f5a6a58f2b..0000000000 --- a/mindspore/ccsrc/minddata/dataset/kernels/no_op.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_NO_OP_H_ -#define DATASET_KERNELS_NO_OP_H_ - -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class NoOp : public TensorOp { - public: - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { - *output = input; - return Status::OK(); - } - - void Print(std::ostream &out) const override { out << "NoOp"; }; - - std::string Name() const override { return kNoOp; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_NO_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc index f501dd4b4f..dbf2dfe73e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc @@ -49,7 +49,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { if (py::isinstance(ret_py_obj)) { // In case of a n-1 mapping, the return value will be a numpy array std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast())); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_obj.cast(), &out)); output->push_back(out); } else if (py::isinstance(ret_py_obj)) { // In case of a n-m mapping, the return value will be a tuple of numpy arrays @@ -61,7 +61,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { goto ShapeMisMatch; } std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast())); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast(), &out)); output->push_back(out); } } else { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h index 75d222b433..d11437b490 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_KERNELS_PY_FUNC_OP_H_ -#define DATASET_KERNELS_PY_FUNC_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_PY_FUNC_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_PY_FUNC_OP_H_ #include #include @@ -47,4 +47,4 @@ class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_PY_FUNC_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_PY_FUNC_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc index b625e3b532..e394284679 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc @@ -48,8 +48,6 @@ Status TensorOp::Compute(const TensorRow &input, TensorRow *output) { "Is this TensorOp oneToOne? If no, please implement this Compute() in the derived class."); } -void TensorOp::Print(std::ostream &out) const { out << "TensorOp" << std::endl; } - Status TensorOp::OutputShape(const std::vector &inputs, std::vector &outputs) { if (inputs.size() != NumInput()) return Status(StatusCode::kUnexpectedError, diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 3bcba4b463..b6fad31330 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_TENSOR_OP_H_ -#define DATASET_KERNELS_TENSOR_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_ #include #include @@ -86,12 +86,19 @@ namespace mindspore { namespace dataset { +// base class +constexpr char kTensorOp[] = "TensorOp"; + // image +constexpr char kAutoContrastOp[] = "AutoContrastOp"; constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; constexpr char kDecodeOp[] = "DecodeOp"; constexpr char kCenterCropOp[] = "CenterCropOp"; constexpr char kCutOutOp[] = "CutOutOp"; +constexpr char kCropOp[] = "CropOp"; +constexpr char kEqualizeOp[] = "EqualizeOp"; constexpr char kHwcToChwOp[] = "HwcToChwOp"; +constexpr char kInvertOp[] = "InvertOp"; constexpr char kNormalizeOp[] = "NormalizeOp"; constexpr char kPadOp[] = "PadOp"; constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; @@ -111,6 +118,7 @@ constexpr char kRescaleOp[] = "RescaleOp"; constexpr char kResizeBilinearOp[] = "ResizeBilinearOp"; constexpr char kResizeOp[] = "ResizeOp"; constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; +constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; constexpr char kUniformAugOp[] = "UniformAugOp"; // text @@ -120,6 +128,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp"; constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; constexpr char kLookupOp[] = "LookupOp"; constexpr char kNgramOp[] = "NgramOp"; +constexpr char kSlidingWindowOp[] = "SlidingWindowOp"; constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; @@ -129,9 +138,14 @@ constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp"; constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp"; constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp"; constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp"; +constexpr char kRandomChoiceOp[] = "RandomChoiceOp"; +constexpr char kRandomApplyOp[] = "RandomApplyOp"; +constexpr char kComposeOp[] = "ComposeOp"; +constexpr char kRandomSelectSubpolicyOp[] = "RandomSelectSubpolicyOp"; +constexpr char kSentencepieceTokenizerOp[] = "SentencepieceTokenizerOp"; // data -constexpr char kConcatenateOp[] = "kConcatenateOp"; +constexpr char kConcatenateOp[] = "ConcatenateOp"; constexpr char kDuplicateOp[] = "DuplicateOp"; constexpr char kFillOp[] = "FillOp"; constexpr char kMaskOp[] = "MaskOp"; @@ -154,7 +168,7 @@ class TensorOp { // A function that prints info about the tensor operation // @param out - virtual void Print(std::ostream &out) const; + virtual void Print(std::ostream &out) const { out << Name() << std::endl; } // Provide stream operator for displaying it // @param output stream @@ -209,4 +223,4 @@ class TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_TENSOR_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt index 605b2644b7..04c65fec2d 100644 --- a/mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt @@ -4,6 +4,7 @@ file(GLOB _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(text OBJECT vocab.cc + sentence_piece_vocab.cc ) add_dependencies(text text-kernels) \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt index 449bb93d8b..a9903ec0a0 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt @@ -12,12 +12,15 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows")) whitespace_tokenizer_op.cc) endif() add_library(text-kernels OBJECT + data_utils.cc lookup_op.cc jieba_tokenizer_op.cc unicode_char_tokenizer_op.cc ngram_op.cc + sliding_window_op.cc wordpiece_tokenizer_op.cc truncate_sequence_pair_op.cc to_number_op.cc + sentence_piece_tokenizer_op.cc ${ICU_DEPEND_FILES} ) diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc index 6195572944..f530edf779 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc @@ -136,8 +136,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptrbegin(); iter != input->end(); iter++) { RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++])); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } Status BasicTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h index cbc21273c2..39b4ec521d 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ #include #include #include @@ -45,8 +45,6 @@ class BasicTokenizerOp : public TensorOp { ~BasicTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "BasicTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; protected: @@ -74,4 +72,4 @@ class BasicTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h index b281903349..5af84da5d9 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ #include #include @@ -42,8 +42,6 @@ class BertTokenizerOp : public TensorOp { ~BertTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "BertTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kBertTokenizerOp; } @@ -54,4 +52,4 @@ class BertTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc index 0ea5cadedb..b38df2f0f6 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc @@ -39,8 +39,7 @@ Status CaseFoldOp::Compute(const std::shared_ptr &input, std::shared_ptr nfkc_case_fold->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h index f7a2105269..9b2f4bef1d 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ -#define DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ #include #include @@ -31,12 +31,10 @@ class CaseFoldOp : public TensorOp { ~CaseFoldOp() override = default; - void Print(std::ostream &out) const override { out << "CaseFoldOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; std::string Name() const override { return kCaseFoldOp; } }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc new file mode 100644 index 0000000000..17b4c613ba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/data_utils.h" + +#include +#include +#include +#include + +#include "minddata/dataset/core/pybind_support.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" + +namespace mindspore { +namespace dataset { +Status SlidingWindowHelper(const std::shared_ptr &input, std::shared_ptr *output, TensorShape out_shape, + uint32_t width, int32_t axis) { + // if the data row has fewer items than width, the corresponding result row will be empty + if (out_shape.Size() == 0) { + MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty."; + return Tensor::CreateEmpty(TensorShape({0}), input->type(), output); + } + + axis = Tensor::HandleNeg(axis, input->shape().Size()); + int32_t axis_end = input->shape()[axis]; + std::shared_ptr tmp; + auto concatenate_op = std::make_unique(axis, nullptr, nullptr); + + // Slice on specified axis and concatenate on new axis + for (int32_t i = 0; i + width <= axis_end; i++) { + auto slice_op = std::make_unique(Slice(i, i + width, 1)); + slice_op->Compute(input, &tmp); + if (i == 0) { + *output = tmp; + } else { + TensorRow in({*output, tmp}); + TensorRow out_row; + concatenate_op->Compute(in, &out_row); + *output = out_row[0]; + } + } + (*output)->Reshape(out_shape); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h new file mode 100644 index 0000000000..6c0e1f8dce --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_ + +#include +#include +#include +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +/// \brief Helper method that perform sliding window on input tensor. +/// \param[in] input - Input tensor. +/// \param[in] out_shape - Output shape of output tensor. +/// \param[in] width - The axis along which sliding window is computed. +/// \param[in] axis - The width of the window. +/// \param[out] output - Output tensor +/// \return Status return code +Status SlidingWindowHelper(const std::shared_ptr &input, std::shared_ptr *output, TensorShape out_shape, + uint32_t width, int32_t axis); +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc index 0a1ae92d14..abcf72c9da 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc @@ -68,15 +68,12 @@ Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { offsets_limit.push_back(static_cast(item.offset + item.word.length())); } } - token_tensor = std::make_shared(words, TensorShape({(dsize_t)words.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(words, &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h index 4e49891c00..a319ccd015 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_ -#define DATASET_ENGINE_TEXT_JIEBA_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TEXT_JIEBA_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TEXT_JIEBA_OP_H_ #include #include @@ -46,8 +46,7 @@ class JiebaTokenizerOp : public TensorOp { ~JiebaTokenizerOp() override = default; void Print(std::ostream &out) const override { - out << "JiebaTokenizerOp: " << jieba_mode_ << "hmm_model_path_ " << hmm_model_path_ << "mp_dict_path_" - << mp_dict_path_; + out << Name() << ": " << jieba_mode_ << "hmm_model_path_ " << hmm_model_path_ << "mp_dict_path_" << mp_dict_path_; } Status Compute(const TensorRow &input, TensorRow *output) override; @@ -68,4 +67,4 @@ class JiebaTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_TEXT_JIEBA_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TEXT_JIEBA_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc index 02b75bc4f9..0317804416 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc @@ -26,7 +26,7 @@ LookupOp::LookupOp(std::shared_ptr vocab, WordIdType default_id) Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); RETURN_UNEXPECTED_IF_NULL(vocab_); - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor."); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None string tensor received."); std::vector word_ids; word_ids.reserve(input->Size()); for (auto itr = input->begin(); itr != input->end(); itr++) { @@ -34,16 +34,14 @@ Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptrshape(), type_, - reinterpret_cast(word_ids.data()))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(word_ids, input->shape(), output)); return Status::OK(); } Status LookupOp::OutputType(const std::vector &inputs, std::vector &outputs) { - CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match"); - CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type"); + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match."); + CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type."); outputs[0] = type_; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h index 4efc64321b..bd1bf67cd3 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_LOOKUP_OP_H_ -#define DATASET_TEXT_KERNELS_LOOKUP_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_ #include #include @@ -64,4 +64,4 @@ class LookupOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_LOOKUP_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc index 36781b9b4d..27b8cb6065 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc @@ -67,7 +67,7 @@ Status NgramOp::Compute(const std::shared_ptr &input, std::shared_ptr(res.size())}))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(res, TensorShape({static_cast(res.size())}), output)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h index 6ce3881638..32032ee7a4 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_NGRAM_OP_H_ -#define DATASET_TEXT_KERNELS_NGRAM_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_NGRAM_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_NGRAM_OP_H_ #include #include @@ -72,4 +72,4 @@ class NgramOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_NGRAM_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_NGRAM_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc index 0c0aa5fa2d..b669ca9a8a 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc @@ -68,8 +68,7 @@ Status NormalizeUTF8Op::Compute(const std::shared_ptr &input, std::share normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h index f914be1c58..66b630adb1 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ -#define DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ #include #include @@ -39,8 +39,6 @@ class NormalizeUTF8Op : public TensorOp { ~NormalizeUTF8Op() override = default; - void Print(std::ostream &out) const override { out << "NormalizeUTF8Op"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; std::string Name() const override { return kNormalizeUTF8Op; } @@ -50,4 +48,4 @@ class NormalizeUTF8Op : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc index c370393e76..b36afba8fc 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc @@ -50,8 +50,7 @@ Status RegexReplaceOp::Compute(const std::shared_ptr &input, std::shared for (auto iter = input->begin(); iter != input->end(); iter++) { RETURN_IF_NOT_OK(RegexReplace(&matcher, *iter, &strs[i])); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h index ac3d3f7ff0..ae9723da09 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ -#define DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ #include #include @@ -38,8 +38,6 @@ class RegexReplaceOp : public TensorOp { ~RegexReplaceOp() override = default; - void Print(std::ostream &out) const override { out << "RegexReplaceOp"; } - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; std::string Name() const override { return kRegexReplaceOp; } @@ -54,4 +52,4 @@ class RegexReplaceOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc index 7ff1d994be..95cb455276 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc @@ -120,15 +120,11 @@ Status RegexTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; RETURN_IF_NOT_OK(input[0]->GetItemAt(&text, {})); RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens, &offsets_start, &offsets_limit)); - token_tensor = std::make_shared(std::move(tokens), TensorShape({(dsize_t)tokens.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(std::move(tokens), &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h index 56271f9551..eabed89480 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_REGEX_TOKENIZER_OP_H_ -#define DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_REGEX_TOKENIZER_OP_H_ #include #include #include @@ -43,8 +43,6 @@ class RegexTokenizerOp : public TensorOp { ~RegexTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "RegexTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; protected: @@ -63,4 +61,4 @@ class RegexTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_REGEX_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc new file mode 100644 index 0000000000..919f108237 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h" + +#include +#include + +#include "utils/ms_utils.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { + +SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr vocab, + const SPieceTokenizerLoadType load_type, + const SPieceTokenizerOutType out_type) + : vocab_(vocab), load_type_(load_type), out_type_(out_type) { + auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto()); + if (!status.ok()) { + model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parser vocab model filed."); + } else { + model_status_ = Status::OK(); + } +} + +SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename, + const SPieceTokenizerLoadType load_type, + const SPieceTokenizerOutType out_type) + : load_type_(load_type), out_type_(out_type) { + (void)GetModelRealPath(model_path, model_filename); + auto status = processor_.Load(file_path_); + if (!status.ok()) { + model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model filed."); + } else { + model_status_ = Status::OK(); + } +} + +Status SentencePieceTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (!model_status_.IsOk()) { + return model_status_; + } + + if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); + } + + std::string_view sentence_v; + RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {})); + std::string sentence{sentence_v}; + + if (out_type_ == SPieceTokenizerOutType::kString) { + std::vector pieces; + auto status = processor_.Encode(sentence, &pieces); + if (!status.ok()) { + RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error"); + } + RETURN_IF_NOT_OK(Tensor::CreateFromVector(pieces, output)); + } else { + std::vector ids; + auto status = processor_.Encode(sentence, &ids); + if (!status.ok()) { + RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error"); + } + RETURN_IF_NOT_OK(Tensor::CreateFromVector(ids, output)); + } + return Status::OK(); +} + +Status SentencePieceTokenizerOp::GetModelRealPath(const std::string &model_path, const std::string &filename) { + char real_path[PATH_MAX] = {0}; + if (file_path_.size() >= PATH_MAX) { + RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid."); + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(real_path, common::SafeCStr(model_path), PATH_MAX) == nullptr) { + RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid."); + } +#else + if (realpath(common::SafeCStr(model_path), real_path) == nullptr) { + RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid."); + } +#endif + std::string abs_path = real_path; + file_path_ = (Path(abs_path) / Path(filename)).toString(); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h new file mode 100644 index 0000000000..c7baca00b5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_SENTENCE_PIECE_TOKENIZER_OP_H +#define DATASET_SENTENCE_PIECE_TOKENIZER_OP_H + +#include + +#include +#include +#include + +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/text/sentence_piece_vocab.h" + +namespace mindspore { +namespace dataset { +enum class SPieceTokenizerOutType { kString = 0, kInt = 1 }; +enum class SPieceTokenizerLoadType { kFile = 0, kModel = 1 }; + +class SentencePieceTokenizerOp : public TensorOp { + public: + SentencePieceTokenizerOp(const std::shared_ptr vocab, SPieceTokenizerLoadType load_type, + const SPieceTokenizerOutType out_type); + + SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename, + const SPieceTokenizerLoadType load_type, const SPieceTokenizerOutType out_type); + + ~SentencePieceTokenizerOp() override = default; + + Status GetModelRealPath(const std::string &model_path, const std::string &filename); + + void Print(std::ostream &out) const override { + out << Name() << " out_type = " << out_type_ << " load_type = " << load_type_; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kSentencepieceTokenizerOp; } + + protected: + SPieceTokenizerOutType out_type_; + std::shared_ptr vocab_; + std::string file_path_; + SPieceTokenizerLoadType load_type_; + sentencepiece::SentencePieceProcessor processor_; + Status model_status_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_SENTENCE_PIECE_TOKENIZER_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc new file mode 100644 index 0000000000..f857f1ab96 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/sliding_window_op.h" + +namespace mindspore { +namespace dataset { +Status SlidingWindowOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SlidingWindosOp supports 1D Tensors only for now."); + CHECK_FAIL_RETURN_UNEXPECTED(axis_ == 0 || axis_ == -1, "axis supports 0 or -1 only for now."); + + std::vector input_shape = {input->shape()}; + std::vector output_shape = {TensorShape({})}; + RETURN_IF_NOT_OK(OutputShape(input_shape, output_shape)); + + RETURN_IF_NOT_OK(SlidingWindowHelper(input, output, output_shape[0], width_, axis_)); + return Status::OK(); +} + +Status SlidingWindowOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); + int32_t axis = Tensor::HandleNeg(axis_, inputs[0].Size()); + TensorShape input_shape = inputs[0]; + std::vector output_shape_initializer; + + // if a data row has fewer items than width, the corresponding result row will be empty. + if (input_shape[axis] >= width_) { + for (int32_t idx = 0; idx < input_shape.Size(); ++idx) { + if (idx != axis) { + output_shape_initializer.push_back(input_shape[idx]); + } else { + output_shape_initializer.push_back(input_shape[idx] - (width_ - 1)); + output_shape_initializer.push_back(width_); + } + } + } + + outputs.pop_back(); + outputs.emplace_back(TensorShape(output_shape_initializer)); + CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h new file mode 100644 index 0000000000..5725e94a7b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sliding_window_op.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/kernels/data_utils.h" + +namespace mindspore { +namespace dataset { + +class SlidingWindowOp : public TensorOp { + public: + /// \brief Constructor of SlidingWindowOp. + /// \param[in] width - The axis along which sliding window is computed. + /// \param[in] axis - The width of the window. + /// \return Status return code + explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {} + + /// \brief Destructor of SlidingWindowOp. + ~SlidingWindowOp() override = default; + + /// \brief Perform sliding window to tensor. + /// \param[in] input - Input tensor of Op. + /// \param[out] output - output tensor of Op. + /// \return Status return code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + /// \brief Calculate tensor shape for output tensor. + /// \param[in] inputs - Input tensor shapes. + /// \param[out] outputs - Output tensor shapes. + /// \return Status return code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + /// \brief Print name of op. + std::string Name() const override { return kSlidingWindowOp; } + + private: + uint32_t width_; // The width of the window. Must be an integer and greater than zero. + int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now. +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc index a6685a2d64..3fda769ea2 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc @@ -114,7 +114,7 @@ Status ToNumberOp::ToSignedIntegral(const std::shared_ptr &input, std::s casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } @@ -157,7 +157,7 @@ Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std: casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } @@ -165,7 +165,7 @@ Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_p // special case, float16 does not exist in c++, no native support for // casting, so cast to float first then use this method, which use Eigen. std::shared_ptr temp; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType("float32"), &temp)); RETURN_IF_NOT_OK(ToFloat(input, &temp)); RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); return Status::OK(); @@ -200,7 +200,7 @@ Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } @@ -233,7 +233,7 @@ Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_pt casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h index 8582fcf073..d13c5e9236 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ -#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ #include #include @@ -78,4 +78,4 @@ class ToNumberOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h index ce82735645..dbfcd6b61a 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ -#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ #include #include @@ -36,8 +36,6 @@ class TruncateSequencePairOp : public TensorOp { ~TruncateSequencePairOp() override = default; - void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kTruncateSequencePairOp; } @@ -47,4 +45,4 @@ class TruncateSequencePairOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc index e08f61100b..c8b33d0ce4 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc @@ -55,15 +55,13 @@ Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output offsets_start.push_back(0); offsets_limit.push_back(0); } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor)); + output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h index 415d99b451..8a0bc01391 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ #include #include @@ -33,8 +33,6 @@ class UnicodeCharTokenizerOp : public TensorOp { ~UnicodeCharTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kUnicodeCharTokenizerOp; } @@ -45,4 +43,4 @@ class UnicodeCharTokenizerOp : public TensorOp { } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc index 60fe8dd0e4..43ebbda42f 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc @@ -96,15 +96,12 @@ Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *outp offsets_start.push_back(0); offsets_limit.push_back(0); } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h index fc3b9e620a..2fe5629682 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ #include #include @@ -36,8 +36,6 @@ class UnicodeScriptTokenizerOp : public TensorOp { ~UnicodeScriptTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "UnicodeScriptTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kUnicodeScriptTokenizerOp; } @@ -48,4 +46,4 @@ class UnicodeScriptTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc index d3bb32081e..c872777813 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc @@ -79,15 +79,12 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) offsets_start.push_back(0); offsets_limit.push_back(0); } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h index 7cc37fd705..adbc6f6244 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ #include #include @@ -33,8 +33,6 @@ class WhitespaceTokenizerOp : public TensorOp { ~WhitespaceTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "WhitespaceTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; std::string Name() const override { return kWhitespaceTokenizerOp; } @@ -44,4 +42,4 @@ class WhitespaceTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc index f0bd448e39..04a1274b03 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc @@ -1,157 +1,154 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" -#include -#include - -namespace mindspore { -namespace dataset { - -const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; -const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; -const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; -const bool WordpieceTokenizerOp::kDefWithOffsets = false; - -WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, - const int &max_bytes_per_token, const std::string &unknown_token, - const bool &with_offsets) - : vocab_(vocab), - suffix_indicator_(suffix_indicator), - max_bytes_per_token_(max_bytes_per_token), - unknown_token_(unknown_token), - with_offsets_(with_offsets) {} - -Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, - bool *out_found, int *out_end) const { - CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); - *out_found = false; - for (int i = runes.size() - 1; i >= 0; i--) { - *out_end = runes[i].offset + runes[i].len; - int len = *out_end - start; - std::string word = input_token.substr(start, len); - if (start > 0) { - word = suffix_indicator_ + word; - } - if (vocab_->Lookup(word) != Vocab::kNoTokenExists) { - *out_found = true; - break; - } - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start, - std::vector *out_tokens, std::vector *offsets_start, - std::vector *offsets_limit) const { - out_tokens->clear(); - offsets_start->push_back(basic_start); - if (unknown_token_.empty()) { - out_tokens->emplace_back(input_token); - offsets_limit->push_back(basic_start + input_token.length()); - } else { - out_tokens->emplace_back(unknown_token_); - offsets_limit->push_back(basic_start + input_token.length()); - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end, - std::vector *out_tokens) const { - CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); - std::string subword = input_token.substr(start, end - start); - if (start > 0) { - subword = suffix_indicator_ + subword; - } - out_tokens->emplace_back(subword); - return Status::OK(); -} - -Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start, - std::vector *out_tokens, std::vector *offsets_start, - std::vector *offsets_limit) const { - if (input_token.size() > max_bytes_per_token_) { - offsets_start->push_back(basic_start); - if (!unknown_token_.empty()) { - offsets_limit->push_back(basic_start + unknown_token_.size()); - out_tokens->emplace_back(unknown_token_); - } else { - out_tokens->emplace_back(input_token); - offsets_limit->push_back(basic_start + input_token.size()); - } - return Status::OK(); - } - RuneStrArray runes; - if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - int end = 0; - for (int start = 0; start < input_token.size();) { - bool found = false; - RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); - if (found) { - RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); - offsets_start->push_back(static_cast(basic_start + start)); - offsets_limit->push_back(static_cast(basic_start + end)); - start = end; - } else { - return FoundNoToken(input_token, basic_start, out_tokens, offsets_start, offsets_limit); - } - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); - } - dsize_t count = 0; - std::vector out_tokens; - std::vector offsets_start, offsets_limit; - std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; - for (auto iter = input[0]->begin(); iter != input[0]->end(); iter++) { - uint32_t basic_start = 0; - std::vector temp_tokens; - if (with_offsets_ && input.size() == 3) { - RETURN_IF_NOT_OK(input[1]->GetItemAt(&basic_start, {count, 0})); - } - RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit)); - out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); - count++; - } - if (out_tokens.empty()) { - out_tokens.emplace_back(""); - offsets_start.push_back(0); - offsets_limit.push_back(0); - } - token_tensor = std::make_shared(out_tokens, TensorShape({(dsize_t)out_tokens.size()})); - output->push_back(token_tensor); - if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); - output->push_back(offsets_start_tensor); - output->push_back(offsets_limit_tensor); - } - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include +#include + +namespace mindspore { +namespace dataset { + +const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; +const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; +const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; +const bool WordpieceTokenizerOp::kDefWithOffsets = false; + +WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, + const int &max_bytes_per_token, const std::string &unknown_token, + const bool &with_offsets) + : vocab_(vocab), + suffix_indicator_(suffix_indicator), + max_bytes_per_token_(max_bytes_per_token), + unknown_token_(unknown_token), + with_offsets_(with_offsets) {} + +Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, + bool *out_found, int *out_end) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); + *out_found = false; + for (int i = runes.size() - 1; i >= 0; i--) { + *out_end = runes[i].offset + runes[i].len; + int len = *out_end - start; + std::string word = input_token.substr(start, len); + if (start > 0) { + word = suffix_indicator_ + word; + } + if (vocab_->Lookup(word) != Vocab::kNoTokenExists) { + *out_found = true; + break; + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + out_tokens->clear(); + offsets_start->push_back(basic_start); + if (unknown_token_.empty()) { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.length()); + } else { + out_tokens->emplace_back(unknown_token_); + offsets_limit->push_back(basic_start + input_token.length()); + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_tokens) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); + std::string subword = input_token.substr(start, end - start); + if (start > 0) { + subword = suffix_indicator_ + subword; + } + out_tokens->emplace_back(subword); + return Status::OK(); +} + +Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + if (input_token.size() > max_bytes_per_token_) { + offsets_start->push_back(basic_start); + if (!unknown_token_.empty()) { + offsets_limit->push_back(basic_start + unknown_token_.size()); + out_tokens->emplace_back(unknown_token_); + } else { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.size()); + } + return Status::OK(); + } + RuneStrArray runes; + if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + int end = 0; + for (int start = 0; start < input_token.size();) { + bool found = false; + RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); + if (found) { + RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); + offsets_start->push_back(static_cast(basic_start + start)); + offsets_limit->push_back(static_cast(basic_start + end)); + start = end; + } else { + return FoundNoToken(input_token, basic_start, out_tokens, offsets_start, offsets_limit); + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); + } + dsize_t count = 0; + std::vector out_tokens; + std::vector offsets_start, offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + for (auto iter = input[0]->begin(); iter != input[0]->end(); iter++) { + uint32_t basic_start = 0; + std::vector temp_tokens; + if (with_offsets_ && input.size() == 3) { + RETURN_IF_NOT_OK(input[1]->GetItemAt(&basic_start, {count, 0})); + } + RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit)); + out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); + count++; + } + if (out_tokens.empty()) { + out_tokens.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + Tensor::CreateFromVector(out_tokens, &token_tensor); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h index 4f9c76f57e..d636ab8e0f 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ #include #include #include @@ -44,8 +44,6 @@ class WordpieceTokenizerOp : public TensorOp { ~WordpieceTokenizerOp() override = default; - void Print(std::ostream &out) const override { out << "WordpieceTokenizerOp"; } - Status Compute(const TensorRow &input, TensorRow *output) override; protected: @@ -69,4 +67,4 @@ class WordpieceTokenizerOp : public TensorOp { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/sentence_piece_vocab.cc b/mindspore/ccsrc/minddata/dataset/text/sentence_piece_vocab.cc new file mode 100644 index 0000000000..d9935112f4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/sentence_piece_vocab.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/sentence_piece_vocab.h" + +#include +#include +#include + +#include "utils/ms_utils.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { + +SentencePieceVocab::SentencePieceVocab() : model_proto_("") {} + +Status SentencePieceVocab::BuildFromFile(const std::vector &path_list, const int vocab_size, + const float character_coverage, const SentencePieceModel model_type, + const std::unordered_map ¶ms, + std::shared_ptr *vocab) { + std::unordered_map unorder_map; + + // the input of sentence is comma separated string + std::string input_str = ""; + for (auto path : path_list) { + input_str += path; + input_str += ","; + } + input_str.pop_back(); + unorder_map["input"] = input_str; + unorder_map["vocab_size"] = std::to_string(vocab_size); + unorder_map["model_prefix"] = ""; + unorder_map["minloglevel"] = "1"; + unorder_map["character_coverage"] = std::to_string(character_coverage); + if (model_type == SentencePieceModel::kWord) { + unorder_map["model_type"] = "WORD"; + } else if (model_type == SentencePieceModel::kBpe) { + unorder_map["model_type"] = "BPE"; + } else if (model_type == SentencePieceModel::kChar) { + unorder_map["model_type"] = "CHAR"; + } else { + unorder_map["model_type"] = "UNIGRAM"; + } + + // filter some params that set by function param + // filter model_prefix that must be empty + for (auto param : params) { + std::string key = param.first; + if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" || + key == "model_type") { + continue; + } + unorder_map[key] = param.second; + } + + // set sentence lib's log + unorder_map["minloglevel"] = "1"; + *vocab = std::make_shared(); + std::string model_proto; + sentencepiece::util::Status s_status = sentencepiece::SentencePieceTrainer::Train(unorder_map, nullptr, &model_proto); + if (!s_status.ok()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message()); + } + vocab->get()->set_model_proto(model_proto); + + return Status::OK(); +} + +Status SentencePieceVocab::SaveModel(const std::shared_ptr *vocab, std::string path, + std::string filename) { + char real_path[PATH_MAX] = {0}; + + if (path.size() >= PATH_MAX) { + RETURN_STATUS_UNEXPECTED("sentence model path is invalid."); + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { + RETURN_STATUS_UNEXPECTED("sentence model path is invalid."); + } +#else + if (realpath(common::SafeCStr(path), real_path) == nullptr) { + RETURN_STATUS_UNEXPECTED("sentence model path is invalid."); + } +#endif + + std::string abs_real_path = (Path(real_path) / Path(filename)).toString(); + std::ofstream os_file(abs_real_path, std::ios::out); + (void)os_file.write(vocab->get()->model_proto().data(), vocab->get()->model_proto().size()); + os_file.close(); + return Status::OK(); +} + +const std::string &SentencePieceVocab::model_proto() { return model_proto_; } + +void SentencePieceVocab::set_model_proto(const std::string model_proto) { model_proto_ = model_proto; } + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/sentence_piece_vocab.h b/mindspore/ccsrc/minddata/dataset/text/sentence_piece_vocab.h new file mode 100644 index 0000000000..efe4981af7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/sentence_piece_vocab.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_SENTENCE_PIECE_VOCAB_H_ +#define DATASET_TEXT_SENTENCE_PIECE_VOCAB_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +enum class SentencePieceModel { kUnigram = 0, kBpe = 1, kChar = 2, kWord = 3 }; +class SentencePieceVocab { + public: + static Status BuildFromFile(const std::vector &path_list, const int vocab_size, + const float character_coverage, const SentencePieceModel model_type, + const std::unordered_map ¶ms, + std::shared_ptr *vocab); + static Status SaveModel(const std::shared_ptr *vocab, std::string path, std::string filename); + SentencePieceVocab(); + + ~SentencePieceVocab() = default; + + const std::string &model_proto(); + + void set_model_proto(const std::string model_proto); + + private: + std::string model_proto_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_SENTENCE_PIECE_VOCAB_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/vocab.h b/mindspore/ccsrc/minddata/dataset/text/vocab.h index 6bf6c488c5..06da5f8f33 100644 --- a/mindspore/ccsrc/minddata/dataset/text/vocab.h +++ b/mindspore/ccsrc/minddata/dataset/text/vocab.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_TEXT_VOCAB_H_ -#define DATASET_TEXT_VOCAB_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VOCAB_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VOCAB_H_ #include #include @@ -85,4 +85,4 @@ class Vocab { } // namespace dataset } // namespace mindspore -#endif // DATASET_TEXT_VOCAB_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VOCAB_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h index b5eaed97a6..8c64c2940e 100644 --- a/mindspore/ccsrc/minddata/dataset/util/allocator.h +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_ALLOCATOR_H_ -#define DATASET_UTIL_ALLOCATOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_ #include #include @@ -175,4 +175,4 @@ class MemGuard { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_ALLOCATOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.h b/mindspore/ccsrc/minddata/dataset/util/arena.h index 8887757af1..2525bde142 100644 --- a/mindspore/ccsrc/minddata/dataset/util/arena.h +++ b/mindspore/ccsrc/minddata/dataset/util/arena.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_ARENA_H_ -#define DATASET_UTIL_ARENA_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ARENA_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ARENA_H_ #include #include @@ -102,4 +102,4 @@ class Arena : public MemoryPool { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_ARENA_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ARENA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/auto_index.h b/mindspore/ccsrc/minddata/dataset/util/auto_index.h index 0fe55159e6..a1c3a613e2 100644 --- a/mindspore/ccsrc/minddata/dataset/util/auto_index.h +++ b/mindspore/ccsrc/minddata/dataset/util/auto_index.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_AUTO_INDEX_H_ -#define DATASET_UTIL_AUTO_INDEX_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_AUTO_INDEX_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_AUTO_INDEX_H_ #include #include @@ -96,4 +96,4 @@ class AutoIndexObj : public BPlusTree { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_AUTO_INDEX_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_AUTO_INDEX_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/bit.h b/mindspore/ccsrc/minddata/dataset/util/bit.h index f02f39585c..e4872a3662 100644 --- a/mindspore/ccsrc/minddata/dataset/util/bit.h +++ b/mindspore/ccsrc/minddata/dataset/util/bit.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_BIT_H_ -#define DATASET_UTIL_BIT_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_BIT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_BIT_H_ namespace mindspore { namespace dataset { @@ -72,4 +72,4 @@ Enum operator~(Enum v) { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_BIT_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_BIT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/btree.h b/mindspore/ccsrc/minddata/dataset/util/btree.h index 828976a0a1..06c84f7a66 100644 --- a/mindspore/ccsrc/minddata/dataset/util/btree.h +++ b/mindspore/ccsrc/minddata/dataset/util/btree.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_INDEX_H_ -#define DATASET_UTIL_INDEX_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INDEX_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INDEX_H_ #include #include @@ -453,7 +453,7 @@ class BPlusTree { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_INDEX_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INDEX_H_ #include "btree_impl.tpp" #include "btree_iterator.tpp" diff --git a/mindspore/ccsrc/minddata/dataset/util/buddy.h b/mindspore/ccsrc/minddata/dataset/util/buddy.h index b1bcd3ce41..d95c2b0528 100644 --- a/mindspore/ccsrc/minddata/dataset/util/buddy.h +++ b/mindspore/ccsrc/minddata/dataset/util/buddy.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_BUDDY_H_ -#define DATASET_UTIL_BUDDY_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_BUDDY_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_BUDDY_H_ #include #include @@ -130,4 +130,4 @@ class BuddySpace { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_BUDDY_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_BUDDY_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc index 22fb72eb8a..37c6107fb0 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/cache_pool.h" #include "minddata/dataset/util/services.h" diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h index cdb6da16b6..9bed5a2ef3 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_CACHE_POOL_H_ -#define DATASET_UTIL_CACHE_POOL_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_CACHE_POOL_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_CACHE_POOL_H_ #include #include diff --git a/mindspore/ccsrc/minddata/dataset/util/circular_pool.h b/mindspore/ccsrc/minddata/dataset/util/circular_pool.h index a63afbd691..56e28aa6a2 100644 --- a/mindspore/ccsrc/minddata/dataset/util/circular_pool.h +++ b/mindspore/ccsrc/minddata/dataset/util/circular_pool.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_CIRCULAR_POOL_H_ -#define DATASET_UTIL_CIRCULAR_POOL_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_CIRCULAR_POOL_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_CIRCULAR_POOL_H_ #include #include @@ -105,4 +105,4 @@ class CircularPool : public MemoryPool { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_CIRCULAR_POOL_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_CIRCULAR_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cond_var.h b/mindspore/ccsrc/minddata/dataset/util/cond_var.h index 88fcad24a2..24fd2bc2bc 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cond_var.h +++ b/mindspore/ccsrc/minddata/dataset/util/cond_var.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_COND_VAR_H_ -#define DATASET_UTIL_COND_VAR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_COND_VAR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_COND_VAR_H_ #include #include @@ -56,4 +56,4 @@ class CondVar : public IntrpResource { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_COND_VAR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_COND_VAR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h b/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h index 9d78e2cd32..00ba0d84bb 100644 --- a/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_INTRP_RESOURCE_H_ -#define DATASET_UTIL_INTRP_RESOURCE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INTRP_RESOURCE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INTRP_RESOURCE_H_ #include #include "minddata/dataset/util/status.h" @@ -49,4 +49,4 @@ class IntrpResource { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_INTRP_RESOURCE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INTRP_RESOURCE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc b/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc index a82c82cdc9..80417ac2a0 100644 --- a/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc @@ -15,7 +15,7 @@ */ #include "minddata/dataset/util/intrp_service.h" #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/services.h" #include "minddata/dataset/util/task_manager.h" diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_service.h b/mindspore/ccsrc/minddata/dataset/util/intrp_service.h index cb6bf30c73..2aa6987e6b 100644 --- a/mindspore/ccsrc/minddata/dataset/util/intrp_service.h +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_service.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_INTRP_SERVICE_H_ -#define DATASET_UTIL_INTRP_SERVICE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INTRP_SERVICE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INTRP_SERVICE_H_ #include #include @@ -60,4 +60,4 @@ class IntrpService : public Service { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_INTRP_SERVICE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_INTRP_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/list.h b/mindspore/ccsrc/minddata/dataset/util/list.h index 06f26ab57c..b7eb8ebd01 100644 --- a/mindspore/ccsrc/minddata/dataset/util/list.h +++ b/mindspore/ccsrc/minddata/dataset/util/list.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_LIST_H_ -#define DATASET_UTIL_LIST_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LIST_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LIST_H_ #include #include @@ -213,4 +213,4 @@ struct List { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_LIST_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LIST_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/lock.h b/mindspore/ccsrc/minddata/dataset/util/lock.h index 9492d34bdf..35f15ef934 100644 --- a/mindspore/ccsrc/minddata/dataset/util/lock.h +++ b/mindspore/ccsrc/minddata/dataset/util/lock.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_LOCK_H_ -#define DATASET_UTIL_LOCK_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOCK_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOCK_H_ #include #include @@ -170,4 +170,4 @@ class LockGuard { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_LOCK_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOCK_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/memory_pool.h b/mindspore/ccsrc/minddata/dataset/util/memory_pool.h index c7cc473109..33e6012626 100644 --- a/mindspore/ccsrc/minddata/dataset/util/memory_pool.h +++ b/mindspore/ccsrc/minddata/dataset/util/memory_pool.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_MEMORY_POOL_H_ -#define DATASET_UTIL_MEMORY_POOL_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_MEMORY_POOL_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_MEMORY_POOL_H_ #include #include @@ -56,4 +56,4 @@ void operator delete(void *, std::shared_ptr); void operator delete[](void *, std::shared_ptr); -#endif // DATASET_UTIL_MEMORY_POOL_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/path.cc b/mindspore/ccsrc/minddata/dataset/util/path.cc index 8740ecb8e0..97416be6a8 100644 --- a/mindspore/ccsrc/minddata/dataset/util/path.cc +++ b/mindspore/ccsrc/minddata/dataset/util/path.cc @@ -22,7 +22,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/log_adapter.h" namespace mindspore { diff --git a/mindspore/ccsrc/minddata/dataset/util/path.h b/mindspore/ccsrc/minddata/dataset/util/path.h index 8bc07ca8f3..17dc015c70 100644 --- a/mindspore/ccsrc/minddata/dataset/util/path.h +++ b/mindspore/ccsrc/minddata/dataset/util/path.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_PATH_H_ -#define DATASET_UTIL_PATH_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_PATH_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_PATH_H_ #include #include @@ -111,4 +111,4 @@ class Path { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_PATH_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_PATH_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/queue.h b/mindspore/ccsrc/minddata/dataset/util/queue.h index 7a0a987499..021ee5ab10 100644 --- a/mindspore/ccsrc/minddata/dataset/util/queue.h +++ b/mindspore/ccsrc/minddata/dataset/util/queue.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_QUEUE_H_ -#define DATASET_UTIL_QUEUE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_ #include #include @@ -24,7 +24,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/log_adapter.h" #include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/services.h" @@ -253,4 +253,4 @@ class QueueList { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_QUEUE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/random.h b/mindspore/ccsrc/minddata/dataset/util/random.h index d2658f67ec..7e8dd48e88 100644 --- a/mindspore/ccsrc/minddata/dataset/util/random.h +++ b/mindspore/ccsrc/minddata/dataset/util/random.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_RANDOM_H_ -#define DATASET_UTIL_RANDOM_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_RANDOM_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_RANDOM_H_ #if defined(_WIN32) || defined(_WIN64) #include @@ -71,4 +71,4 @@ inline uint32_t GetSeed() { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_RANDOM_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_RANDOM_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.h b/mindspore/ccsrc/minddata/dataset/util/semaphore.h index d07398acb1..88935dc6f7 100644 --- a/mindspore/ccsrc/minddata/dataset/util/semaphore.h +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_SEMAPHORE_H_ -#define DATASET_UTIL_SEMAPHORE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SEMAPHORE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SEMAPHORE_H_ #include "minddata/dataset/util/cond_var.h" @@ -51,4 +51,4 @@ class Semaphore { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_SEMAPHORE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SEMAPHORE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/service.cc b/mindspore/ccsrc/minddata/dataset/util/service.cc index 19d60ab47a..d7e1e02e6a 100644 --- a/mindspore/ccsrc/minddata/dataset/util/service.cc +++ b/mindspore/ccsrc/minddata/dataset/util/service.cc @@ -34,7 +34,18 @@ Status Service::ServiceStart() { state_ = STATE::kStartInProg; // At this point, we will let go of the lock. This allow others to proceed. lck.Unlock(); - RETURN_IF_NOT_OK(DoServiceStart()); + // Call the real implementation from the derived class. + Status rc = DoServiceStart(); + // If we hit any error, change the state back into the initial state. + // It is possible that the user may want to drive a clean up by calling + // ServiceStop but if it will end up in a loop because of the state is still + // kStartInProg. + if (rc.IsError()) { + lck.Lock(); + state_ = STATE::kStopped; + lck.Unlock(); + return rc; + } // Lock again to change state. lck.Lock(); state_ = STATE::kRunning; diff --git a/mindspore/ccsrc/minddata/dataset/util/service.h b/mindspore/ccsrc/minddata/dataset/util/service.h index 2b9c7197fe..325f5f059a 100644 --- a/mindspore/ccsrc/minddata/dataset/util/service.h +++ b/mindspore/ccsrc/minddata/dataset/util/service.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_SERVICE_H_ -#define DATASET_UTIL_SERVICE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICE_H_ #include #include "minddata/dataset/util/lock.h" @@ -50,4 +50,4 @@ class Service { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_SERVICE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/services.cc b/mindspore/ccsrc/minddata/dataset/util/services.cc index 547773e0f1..44eba24ca6 100644 --- a/mindspore/ccsrc/minddata/dataset/util/services.cc +++ b/mindspore/ccsrc/minddata/dataset/util/services.cc @@ -77,9 +77,13 @@ Status Services::CreateAllInstances() { rc = sa_[kSlotTaskMgr_]->ServiceStart(); RETURN_IF_NOT_OK(rc); // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers +#if !defined(_WIN32) && !defined(_WIN64) sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); RETURN_IF_NOT_OK(rc); rc = sa_[kSlotCacheMgr_]->ServiceStart(); +#else + sa_[kSlotCacheMgr_] = nullptr; +#endif return rc; } diff --git a/mindspore/ccsrc/minddata/dataset/util/services.h b/mindspore/ccsrc/minddata/dataset/util/services.h index c7adea0b6e..9d4dca9765 100644 --- a/mindspore/ccsrc/minddata/dataset/util/services.h +++ b/mindspore/ccsrc/minddata/dataset/util/services.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_SERVICES_H_ -#define DATASET_UTIL_SERVICES_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ #include #include @@ -101,4 +101,4 @@ class Services { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_SERVICES_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/sig_handler.h b/mindspore/ccsrc/minddata/dataset/util/sig_handler.h index af40738feb..03f644a45c 100644 --- a/mindspore/ccsrc/minddata/dataset/util/sig_handler.h +++ b/mindspore/ccsrc/minddata/dataset/util/sig_handler.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_SIG_HANDLER_H_ -#define DATASET_UTIL_SIG_HANDLER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SIG_HANDLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SIG_HANDLER_H_ #include #include @@ -33,4 +33,4 @@ extern void IntHandler(int sig_num, // The signal that was raised } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_SIG_HANDLER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SIG_HANDLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.h b/mindspore/ccsrc/minddata/dataset/util/slice.h index 1caee0f816..304a7e8698 100644 --- a/mindspore/ccsrc/minddata/dataset/util/slice.h +++ b/mindspore/ccsrc/minddata/dataset/util/slice.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_SLICE_H_ -#define DATASET_UTIL_SLICE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SLICE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SLICE_H_ #include #include @@ -125,4 +125,4 @@ class WritableSlice : public ReadableSlice { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_SLICE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/status.cc b/mindspore/ccsrc/minddata/dataset/util/status.cc index 3fc498b701..9d60bfe6a6 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.cc +++ b/mindspore/ccsrc/minddata/dataset/util/status.cc @@ -15,7 +15,7 @@ */ #include "minddata/dataset/util/status.h" #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/task_manager.h" namespace mindspore { diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h index 7a480f4239..b919b4dc4e 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.h +++ b/mindspore/ccsrc/minddata/dataset/util/status.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_STATUS_H_ -#define DATASET_UTIL_STATUS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STATUS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STATUS_H_ #if defined(__GNUC__) || defined(__clang__) #define DEPRECATED __attribute__((deprecated)) @@ -134,4 +134,4 @@ class Status { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_STATUS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STATUS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc index 506495227d..b64926e596 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc @@ -19,7 +19,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/status.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.h b/mindspore/ccsrc/minddata/dataset/util/storage_container.h index a304012b60..1eec09f942 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_container.h +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ -#define DATASET_UTIL_STORAGE_CONTAINER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STORAGE_CONTAINER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STORAGE_CONTAINER_H_ #include #include @@ -76,4 +76,4 @@ class StorageContainer { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc index 2f85d00a45..82f39fd7ae 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc @@ -19,7 +19,7 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/services.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.h b/mindspore/ccsrc/minddata/dataset/util/storage_manager.h index e79e7c6e63..764ac83575 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_manager.h +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ -#define DATASET_UTIL_STORAGE_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STORAGE_MANAGER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STORAGE_MANAGER_H_ #include #include @@ -73,4 +73,4 @@ class StorageManager : public Service { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_STORAGE_MANAGER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/system_pool.h b/mindspore/ccsrc/minddata/dataset/util/system_pool.h index 3a7e61d16b..789252dc8c 100644 --- a/mindspore/ccsrc/minddata/dataset/util/system_pool.h +++ b/mindspore/ccsrc/minddata/dataset/util/system_pool.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_SYSTEM_POOL_H_ -#define DATASET_UTIL_SYSTEM_POOL_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SYSTEM_POOL_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SYSTEM_POOL_H_ #include #include @@ -72,4 +72,4 @@ class SystemPool : public MemoryPool { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_SYSTEM_POOL_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SYSTEM_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/task.cc b/mindspore/ccsrc/minddata/dataset/util/task.cc index 39d754e806..fb71e93379 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task.cc +++ b/mindspore/ccsrc/minddata/dataset/util/task.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "minddata/dataset/util/task.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/util/task_manager.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/minddata/dataset/util/task.h b/mindspore/ccsrc/minddata/dataset/util/task.h index 9309a3de7b..951beb9a4c 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task.h +++ b/mindspore/ccsrc/minddata/dataset/util/task.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_TASK_H_ -#define DATASET_UTIL_TASK_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ #include #include @@ -122,4 +122,4 @@ extern thread_local Task *gMyTask; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_TASK_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.cc b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc index fefea0b97c..e72fed5d07 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc @@ -296,7 +296,13 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio return Status::OK(); } -void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } +void TaskGroup::interrupt_all() noexcept { + // There is a racing condition if we don't stop the interrupt service at this point. New resource + // may come in and not being picked up after we call InterruptAll(). So stop new comers and then + // interrupt any existing resources. + (void)intrp_svc_->ServiceStop(); + intrp_svc_->InterruptAll(); +} Status TaskGroup::join_all(Task::WaitFlag wf) { Status rc; @@ -312,7 +318,6 @@ Status TaskGroup::join_all(Task::WaitFlag wf) { } Status TaskGroup::DoServiceStop() { - intrp_svc_->ServiceStop(); interrupt_all(); return (join_all(Task::WaitFlag::kNonBlocking)); } diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.h b/mindspore/ccsrc/minddata/dataset/util/task_manager.h index 3030390bab..7b81bc8c71 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task_manager.h +++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_TASK_MANAGER_H_ -#define DATASET_UTIL_TASK_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_ #if !defined(_WIN32) && !defined(_WIN64) #include @@ -178,4 +178,4 @@ inline Status GetInterruptStatus() { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_TASK_MANAGER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/treap.h b/mindspore/ccsrc/minddata/dataset/util/treap.h index 777b7c5701..50c59b4f76 100644 --- a/mindspore/ccsrc/minddata/dataset/util/treap.h +++ b/mindspore/ccsrc/minddata/dataset/util/treap.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_TREAP_H_ -#define DATASET_UTIL_TREAP_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_ #include #include @@ -404,4 +404,4 @@ class Treap { }; } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_TREAP_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TREAP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/wait_post.h b/mindspore/ccsrc/minddata/dataset/util/wait_post.h index afd3bea38b..7af4151a57 100644 --- a/mindspore/ccsrc/minddata/dataset/util/wait_post.h +++ b/mindspore/ccsrc/minddata/dataset/util/wait_post.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef DATASET_UTIL_WAIT_POST_H_ -#define DATASET_UTIL_WAIT_POST_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_WAIT_POST_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_WAIT_POST_H_ #include #include "minddata/dataset/util/cond_var.h" @@ -50,4 +50,4 @@ class WaitPost { } // namespace dataset } // namespace mindspore -#endif // DATASET_UTIL_WAIT_POST_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_WAIT_POST_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc index d9e51efc4e..7c7f79ccfb 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -16,7 +16,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/mindrecord/include/common/shard_utils.h" #include "minddata/mindrecord/include/shard_error.h" #include "minddata/mindrecord/include/shard_index_generator.h" @@ -133,6 +133,7 @@ void BindGlobalParams(py::module *m) { (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; (*m).attr("MIN_SHARD_COUNT") = kMinShardCount; (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; + (*m).attr("MAX_FILE_COUNT") = kMaxFileCount; (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; (void)(*m).def("get_max_thread_num", &GetMaxThreadNum); } diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc index b5021802a0..185f6cbb60 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc @@ -15,7 +15,7 @@ */ #include "minddata/mindrecord/include/common/shard_utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "./securec.h" using mindspore::LogStream; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h index 3b3698ca68..dfc41aa4a0 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ -#define MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ #include #include @@ -37,4 +37,4 @@ py::object FromJsonImpl(const json &j); json ToJsonImpl(const py::handle &obj); } // namespace detail } // namespace nlohmann -#endif // MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h index bd1cda8a99..5ee6f70a82 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ -#define MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ #include #include @@ -104,7 +104,8 @@ const uint64_t kInt64Len = 8; const uint64_t kMinFileSize = kInt64Len; const int kMinShardCount = 1; -const int kMaxShardCount = 1000; +const int kMaxShardCount = 1000; // write +const int kMaxFileCount = 4096; // read const int kMinConsumerCount = 1; const int kMaxConsumerCount = 128; @@ -137,6 +138,10 @@ const std::set kScalarFieldTypeSet = {"string", "int32", "int64", " // number field list const std::set kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; +const std::unordered_map kTypesMap = { + {"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"}, + {"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"}, + {"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}}; /// \brief split a string using a character /// \param[in] field target string /// \param[in] separator a character for spliting @@ -179,4 +184,4 @@ uint32_t GetMaxThreadNum(); } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h index ed1e748afe..b4a0bcae6d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ -#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ #include #include @@ -60,4 +60,4 @@ class ShardCategory : public ShardOperator { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h index f6353ed3ce..9510eeed1c 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_ -#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_COLUMN_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_COLUMN_H_ #include #include @@ -164,4 +164,4 @@ class ShardColumn { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_COLUMN_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h index f166ec1e6c..6cd332c028 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ #include #include @@ -29,9 +29,10 @@ namespace mindspore { namespace mindrecord { class ShardDistributedSample : public ShardSample { public: - ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); + ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed, + int no_of_samples = 0); - ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed); + ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0); void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; } @@ -50,4 +51,4 @@ class ShardDistributedSample : public ShardSample { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h index 8488ca70ce..dca284ea1c 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_ERROR_H_ -#define MINDRECORD_INCLUDE_SHARD_ERROR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_ERROR_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_ERROR_H_ #include #include @@ -81,4 +81,4 @@ std::string ErrnoToMessage(MSRStatus status); } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_ERROR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_ERROR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h index 67169e8696..51928d7874 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_HEADER_H_ -#define MINDRECORD_INCLUDE_SHARD_HEADER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ #include #include @@ -124,6 +124,10 @@ class ShardHeader { MSRStatus FileToPages(const std::string dump_file_name); + static MSRStatus initialize(const std::shared_ptr *header_ptr, const json &schema, + const std::vector &index_fields, std::vector &blob_fields, + uint64_t &schema_id); + private: MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset); @@ -148,7 +152,7 @@ class ShardHeader { MSRStatus CheckIndexField(const std::string &field, const json &schema); - void ParsePage(const json &page, int shard_index, bool load_dataset); + MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset); MSRStatus ParseStatistics(const json &statistics); @@ -183,4 +187,4 @@ class ShardHeader { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_HEADER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h index 79b10893fb..4e38c54fd2 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INDEX_H -#define MINDRECORD_INDEX_H +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INDEX_H +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INDEX_H #pragma once #include @@ -62,4 +62,4 @@ class Index { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INDEX_H +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INDEX_H diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h index fb85d9adbc..7814840aef 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ -#define MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ #include #include @@ -57,6 +57,8 @@ class ShardIndexGenerator { /// \brief create databases for indexes MSRStatus WriteToDatabase(); + static MSRStatus finalize(const std::vector file_names); + private: static int Callback(void *not_used, int argc, char **argv, char **az_col_name); @@ -117,4 +119,4 @@ class ShardIndexGenerator { }; } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h index b5ea53b759..2ba4498063 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ -#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ #include #include "minddata/mindrecord/include/shard_task.h" @@ -60,4 +60,4 @@ class ShardOperator { }; } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h index 01c70acf29..2e9cc7dd1a 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_PAGE_H_ -#define MINDRECORD_INCLUDE_SHARD_PAGE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_PAGE_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_PAGE_H_ #include #include @@ -103,4 +103,4 @@ class Page { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_PAGE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_PAGE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h index 2d420b563d..04f47db358 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ #include #include @@ -46,4 +46,4 @@ class ShardPkSample : public ShardCategory { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index b1b0c1397a..e08375b10f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_READER_H_ -#define MINDRECORD_INCLUDE_SHARD_READER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_ #include #include @@ -63,7 +63,6 @@ using ROW_GROUP_BRIEF = using TASK_RETURN_CONTENT = std::pair, json>>>>; const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode -const int kNumPageInBuffer = 16; // page buffer size in block-reader mode class ShardReader { public: @@ -77,12 +76,10 @@ class ShardReader { /// \param[in] n_consumer number of threads when reading /// \param[in] selected_columns column list to be populated /// \param[in] operators operators applied to data, operator type is shuffle, sample or category - /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode /// \return MSRStatus the status of MSRStatus MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, const std::vector &selected_columns = {}, - const std::vector> &operators = {}, const bool &block_reader = false, - const int num_padded = 0); + const std::vector> &operators = {}, const int num_padded = 0); /// \brief open files and initialize reader, python API /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list @@ -189,10 +186,6 @@ class ShardReader { std::pair, json>>> GetNextById(const int64_t &task_id, const int32_t &consumer_id); - /// \brief return a batch in block-reader mode, given that one is ready - /// \return a batch of images and image data - std::vector, json>> GetBlockNext(); - /// \brief return a batch, given that one is ready, python API /// \return a batch of images and image data std::vector>, pybind11::object>> GetNextPy(); @@ -242,9 +235,6 @@ class ShardReader { /// \brief populate one row by task list in row-reader mode MSRStatus ConsumerByRow(int consumer_id); - /// \brief populate one row by task list in block-reader mode - MSRStatus ConsumerByBlock(int consumer_id); - /// \brief get offset address of images within page std::vector> GetImageOffset(int group_id, int shard_id, const std::pair &criteria = {"", ""}); @@ -262,10 +252,6 @@ class ShardReader { const std::pair &criteria = {"", ""}); - /// \brief create task list in block-reader mode - MSRStatus CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators); - /// \brief create category-applied task list MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, const std::shared_ptr &op); @@ -290,15 +276,10 @@ class ShardReader { /// \brief read one row by one task TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); - /// \brief get one row from buffer in block-reader mode - std::shared_ptr, json>>> GetRowFromBuffer(int bufId, int rowId); - /// \brief get labels from binary file std::pair> GetLabelsFromBinaryFile( int shard_id, const std::vector &columns, const std::vector> &label_offsets); - MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); - /// \brief get classes in one shard void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); @@ -349,18 +330,8 @@ class ShardReader { // map of delivery std::unordered_map, json>>>> delivery_map_; // Delivery/Iterator mode end - - // Block reader mode begin - bool block_reader_; // block-reader mode - int row_id_; // row id in one page - int num_blocks_; // number of pages - // raw data page - std::vector>, std::vector>>> delivery_block_; - std::unordered_set delivery_block_set_; // set of delivered pages - std::vector> buf_; // page buffer - // Block reader mode end }; } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_READER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h index ce813bc4bf..6e5d85372c 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ #include #include @@ -32,7 +32,7 @@ class ShardSample : public ShardOperator { ShardSample(int num, int den); - ShardSample(int num, int den, int par); + ShardSample(int num, int den, int par, int no_of_samples = 0); ShardSample(const std::vector &indices, uint32_t seed); @@ -58,4 +58,4 @@ class ShardSample : public ShardOperator { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h index 56eae85e5a..6efcca4dc1 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ -#define MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ #include #include @@ -87,4 +87,4 @@ class Schema { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h index 45d9bda338..fd7f7dee21 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ -#define MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ #include #include @@ -99,4 +99,4 @@ class ShardSegment : public ShardReader { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h index 724be9acaf..4205c405b9 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ #include #include @@ -45,4 +45,4 @@ class ShardSequentialSample : public ShardSample { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h index d7f736b55b..4fca36a5e4 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ #include #include "minddata/mindrecord/include/shard_operator.h" @@ -45,4 +45,4 @@ class ShardShuffle : public ShardOperator { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h index f100bb9833..cefe7a11ae 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h @@ -15,8 +15,8 @@ */ #pragma once -#ifndef MINDRECORD_STATISTICS_H -#define MINDRECORD_STATISTICS_H +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_STATISTICS_H +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_STATISTICS_H #include #include @@ -88,4 +88,4 @@ class Statistics { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_STATISTICS_H +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_STATISTICS_H diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h index f07da656f2..6074a036da 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_ -#define MINDRECORD_INCLUDE_SHARD_TASK_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ #include #include @@ -64,4 +64,4 @@ class ShardTask { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_TASK_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h index 833928773e..ddb7e7cb8f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDRECORD_INCLUDE_SHARD_WRITER_H_ -#define MINDRECORD_INCLUDE_SHARD_WRITER_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include #include @@ -108,6 +108,13 @@ class ShardWriter { std::map> &blob_data, bool sign = true, bool parallel_writer = false); + MSRStatus MergeBlobData(const std::vector &blob_fields, + const std::map>> &row_bin_data, + std::shared_ptr> *output); + + static MSRStatus initialize(const std::unique_ptr *writer_ptr, + const std::vector &file_names); + private: /// \brief write shard header data to disk MSRStatus WriteShardHeader(); @@ -254,4 +261,4 @@ class ShardWriter { } // namespace mindrecord } // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_WRITER_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc index f9b18a3bf0..e2d6247735 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc @@ -16,7 +16,7 @@ #include #include "minddata/mindrecord/include/shard_index_generator.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using mindspore::LogStream; using mindspore::ExceptionType::NoExceptionType; @@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() { shard_no = task_++; } } +MSRStatus ShardIndexGenerator::finalize(const std::vector file_names) { + if (file_names.empty()) { + MS_LOG(ERROR) << "Mindrecord files is empty."; + return FAILED; + } + ShardIndexGenerator sg{file_names[0]}; + if (SUCCESS != sg.Build()) { + MS_LOG(ERROR) << "Failed to build index generator."; + return FAILED; + } + if (SUCCESS != sg.WriteToDatabase()) { + MS_LOG(ERROR) << "Failed to write to database."; + return FAILED; + } + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 84d7fddb6f..c42b732463 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -16,7 +16,7 @@ #include "minddata/mindrecord/include/shard_distributed_sample.h" #include "minddata/mindrecord/include/shard_reader.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using mindspore::LogStream; using mindspore::ExceptionType::NoExceptionType; @@ -43,9 +43,6 @@ ShardReader::ShardReader() { page_size_ = 0; header_size_ = 0; num_rows_ = 0; - row_id_ = 0; - num_blocks_ = 0; - block_reader_ = false; num_padded_ = 0; } @@ -252,7 +249,7 @@ std::vector> ShardReader::ReadRowGroupSummar if (shard_count <= 0) { return row_group_summary; } - if (shard_count <= kMaxShardCount) { + if (shard_count <= kMaxFileCount) { for (int shard_id = 0; shard_id < shard_count; ++shard_id) { // return -1 when page's size equals to 0. auto last_page_id = shard_header_->GetLastPageId(shard_id); @@ -855,8 +852,7 @@ MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, const std::vector &selected_columns, - const std::vector> &operators, const bool &block_reader, - int num_padded) { + const std::vector> &operators, int num_padded) { // Open file and set header by ShardReader auto ret = Init(file_paths, load_dataset); if (SUCCESS != ret) { @@ -890,19 +886,8 @@ MSRStatus ShardReader::Open(const std::vector &file_paths, bool loa operators_ = operators; - if (block_reader) { - block_reader_ = true; - if (Open() == FAILED) { - return FAILED; - } - delivery_block_ = std::vector>, std::vector>>>( - kNumPageInBuffer, std::shared_ptr>, std::vector>>{}); - buf_ = std::vector>(kNumPageInBuffer, std::vector(page_size_)); - } else { - block_reader_ = false; - if (Open(n_consumer) == FAILED) { - return FAILED; - } + if (Open(n_consumer) == FAILED) { + return FAILED; } return SUCCESS; } @@ -960,29 +945,13 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { } for (int x = 0; x < n_consumer_; ++x) { - if (block_reader_) { - thread_set_[x] = std::thread(&ShardReader::ConsumerByBlock, this, x); - } else { - thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); - } + thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); } MS_LOG(INFO) << "Launch read thread successfully."; return SUCCESS; } -MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators) { - CheckIfColumnInIndex(selected_columns_); - for (const auto &rg : row_group_summary) { - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto n_Rows = std::get<3>(rg); - tasks_.InsertTask(TaskType::kCommonTask, shard_id, group_id, std::vector{n_Rows}, json{}); - } - return SUCCESS; -} - MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, const std::shared_ptr &op) { CheckIfColumnInIndex(selected_columns_); @@ -1054,7 +1023,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector(ret); auto local_columns = std::get<2>(ret); - if (shard_count_ <= kMaxShardCount) { + if (shard_count_ <= kMaxFileCount) { for (int shard_id = 0; shard_id < shard_count_; shard_id++) { for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], @@ -1070,47 +1039,39 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, const std::vector> &operators) { - if (block_reader_) { - if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { + int category_operator = -1; + for (uint32_t i = 0; i < operators.size(); ++i) { + const auto &op = operators[i]; + if (std::dynamic_pointer_cast(op)) { + category_operator = static_cast(i); + break; + } + } + if (-1 == category_operator) { + if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { return FAILED; } - } else { - int category_operator = -1; - for (uint32_t i = 0; i < operators.size(); ++i) { - const auto &op = operators[i]; - if (std::dynamic_pointer_cast(op)) { - category_operator = static_cast(i); - break; + if (num_padded_ > 0) { + for (int i = 0; i < num_padded_; ++i) { + tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); } } - if (-1 == category_operator) { - if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { - return FAILED; - } - if (num_padded_ > 0) { - for (int i = 0; i < num_padded_; ++i) { - tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); - } - } - } else { - if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { - return FAILED; - } + } else { + if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { + return FAILED; } } for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) { const auto &op = operators[operator_no]; if (std::dynamic_pointer_cast(op)) continue; - if (block_reader_ && std::dynamic_pointer_cast(op)) continue; if (SUCCESS != (*op)(tasks_)) { return FAILED; } } if (tasks_.permutation_.empty()) tasks_.MakePerm(); - num_rows_ = block_reader_ ? tasks_.SizeOfRows() : tasks_.Size(); - num_blocks_ = block_reader_ ? tasks_.Size() : 0; + num_rows_ = tasks_.Size(); MS_LOG(INFO) << "Total rows is " << num_rows_; return SUCCESS; } @@ -1207,140 +1168,10 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { } } -MSRStatus ShardReader::ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, - const int &buf_id) { - auto &io_seekg = file_streams_[shard_id]->seekg(page_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf_[buf_id][0]), page_length); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { - // Set thread name -#if !defined(_WIN32) && !defined(_WIN64) - auto thread_id = kThreadName + std::to_string(consumer_id); - prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); -#endif - - // Loop forever - for (;;) { - int task_id = 0; - - // Get next task ID - task_id = task_id_++; - - // All tasks are done, either quit or repeat again - if (task_id >= num_blocks_) { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [this] { return interrupt_ || task_id_ < num_blocks_; }); - if (interrupt_) { - return SUCCESS; - } - continue; - } - - // Pick up task from task list - auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - - auto shard_id = std::get<0>(std::get<1>(task)); - auto group_id = std::get<1>(std::get<1>(task)); - auto row_group_brief = ReadRowGroupBrief(group_id, shard_id, selected_columns_); - if (SUCCESS != std::get<0>(row_group_brief)) { - return FAILED; - } - auto page_length = std::get<2>(row_group_brief); - auto page_offset = std::get<3>(row_group_brief); - - MS_LOG(DEBUG) << "Block task " << task_id << tasks_.permutation_[task_id] << ", shard " << shard_id << ", group " - << group_id << ", page length " << page_length << ", page offset " << page_offset; - - // Deliver block data to output map - auto offset_and_labels = std::make_pair(std::get<4>(row_group_brief), std::get<5>(row_group_brief)); - - int deliver_id = deliver_id_; - // Hanging if maximum map size exceeded otherwise, set batch data in buffer - { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id < deliver_id_ + kNumPageInBuffer; }); - if (interrupt_) { - return SUCCESS; - } - } - - auto buf_id = task_id % kNumPageInBuffer; - delivery_block_[buf_id] = - std::make_shared>, std::vector>>(offset_and_labels); - - // Read blob - if (ReadBlob(shard_id, page_offset, page_length, buf_id) != SUCCESS) { - return FAILED; - } - - { - std::unique_lock lck(mtx_delivery_); - delivery_block_set_.insert(task_id); - } - cv_iterator_.notify_one(); - } -} - -std::shared_ptr, json>>> ShardReader::GetRowFromBuffer(int buf_id, - int rowId) { - auto &blob_page = buf_[buf_id]; - auto &offsets = (*delivery_block_[buf_id]).first; - auto &labels = (*delivery_block_[buf_id]).second; - auto &addr_start = offsets[rowId][0]; - auto &addr_end = offsets[rowId][1]; - std::vector images(blob_page.begin() + addr_start, blob_page.begin() + addr_end); - std::vector, json>> batch; - batch.emplace_back(std::move(images), std::move(labels[rowId])); - return std::make_shared, json>>>(std::move(batch)); -} - -std::vector, json>> ShardReader::GetBlockNext() { - if (deliver_id_ >= num_blocks_) { - return std::vector, json>>(); - } - - if (row_id_ == 0) { - std::unique_lock lck(mtx_delivery_); - cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_block_set_.count(deliver_id_) > 0); }); - - if (interrupt_) { - return std::vector, json>>(); - } - } - auto buf_id = deliver_id_ % kNumPageInBuffer; - auto res = GetRowFromBuffer(buf_id, row_id_); - - row_id_++; - if (row_id_ == (*delivery_block_[buf_id]).first.size()) { - row_id_ = 0; - { - std::unique_lock lck(mtx_delivery_); - delivery_block_set_.erase(deliver_id_++); - } - cv_delivery_.notify_all(); - } - - return *res; -} - std::vector, json>> ShardReader::GetNext() { if (interrupt_) { return std::vector, json>>(); } - if (block_reader_) return GetBlockNext(); if (deliver_id_ >= static_cast(tasks_.Size())) { return std::vector, json>>(); } @@ -1366,9 +1197,6 @@ std::pair, json>>> ShardRe if (interrupt_) { return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); } - if (block_reader_) { - return std::make_pair(TaskType::kCommonTask, GetBlockNext()); - } const auto &ret = ConsumerOneTask(task_id, consumer_id); if (SUCCESS != ret.first) { return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); @@ -1423,7 +1251,6 @@ void ShardReader::Reset() { } void ShardReader::ShuffleTask() { - if (block_reader_) return; // exist shuffle and distributed sampler in ops, skip shuffle bool has_sharding = false; for (const auto &op : operators_) { diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc index eda8924e13..a9a4a79cdf 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc @@ -15,7 +15,7 @@ */ #include "minddata/mindrecord/include/shard_segment.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "./securec.h" #include "minddata/mindrecord/include/common/shard_utils.h" diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc index e85229cc34..bf702180ab 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc @@ -15,7 +15,7 @@ */ #include "minddata/mindrecord/include/shard_writer.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/mindrecord/include/common/shard_utils.h" #include "./securec.h" @@ -83,7 +83,7 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) { // if not append and mindrecord file exist, return FAILED fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); if (fs->good()) { - MS_LOG(ERROR) << "MindRecord file already existed."; + MS_LOG(ERROR) << "MindRecord file already existed, please delete file: " << common::SafeCStr(file); fs->close(); return FAILED; } @@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map *row_count = std::get<2>(v); return SUCCESS; } +MSRStatus ShardWriter::MergeBlobData(const std::vector &blob_fields, + const std::map>> &row_bin_data, + std::shared_ptr> *output) { + if (blob_fields.empty()) { + return SUCCESS; + } + if (blob_fields.size() == 1) { + auto &blob = row_bin_data.at(blob_fields[0]); + auto blob_size = blob->size(); + *output = std::make_shared>(blob_size); + std::copy(blob->begin(), blob->end(), (*output)->begin()); + } else { + size_t output_size = 0; + for (auto &field : blob_fields) { + output_size += row_bin_data.at(field)->size(); + } + output_size += blob_fields.size() * sizeof(uint64_t); + *output = std::make_shared>(output_size); + std::vector buf(sizeof(uint64_t), 0); + size_t idx = 0; + for (auto &field : blob_fields) { + auto &blob = row_bin_data.at(field); + uint64_t blob_size = blob->size(); + // big edian + for (size_t i = 0; i < buf.size(); ++i) { + buf[buf.size() - 1 - i] = std::numeric_limits::max() & blob_size; + blob_size >>= 8u; + } + std::copy(buf.begin(), buf.end(), (*output)->begin() + idx); + idx += buf.size(); + std::copy(blob->begin(), blob->end(), (*output)->begin() + idx); + idx += blob->size(); + } + } + return SUCCESS; +} MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, std::vector> &blob_data, bool sign, bool parallel_writer) { @@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr &la last_blob_page = page.first; } } + +MSRStatus ShardWriter::initialize(const std::unique_ptr *writer_ptr, + const std::vector &file_names) { + if (nullptr == writer_ptr) { + MS_LOG(ERROR) << "ShardWriter pointer is NULL."; + return FAILED; + } + auto res = (*writer_ptr)->Open(file_names, false); + if (SUCCESS != res) { + MS_LOG(ERROR) << "Failed to open mindrecord files to writer."; + return FAILED; + } + (*writer_ptr)->SetHeaderSize(1 << 24); + (*writer_ptr)->SetPageSize(1 << 25); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc index 4cc5e9f413..47e001e8f8 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc @@ -16,7 +16,7 @@ #include "minddata/mindrecord/include/shard_column.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/mindrecord/include/common/shard_utils.h" #include "minddata/mindrecord/include/shard_error.h" diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc index 4c7abbb4b4..6bc1c1408d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc @@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR; namespace mindspore { namespace mindrecord { ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, - uint32_t seed) - : ShardSample(1, num_shards, shard_id), + uint32_t seed, int no_of_samples) + : ShardSample(1, num_shards, shard_id, no_of_samples), shuffle_(shuffle), no_of_padded_samples_(no_of_padded_samples), first_epoch_(true) { shuffle_op_ = std::make_shared(seed, kShuffleSample); } -ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed) - : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {} +ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, + int no_of_samples) + : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {} int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { if (no_of_padded_samples_ <= 0) { diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc index 500037399b..9f75d84e7a 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc @@ -22,7 +22,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/mindrecord/include/shard_error.h" #include "minddata/mindrecord/include/shard_page.h" @@ -55,7 +55,9 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool l header_size_ = header["header_size"].get(); page_size_ = header["page_size"].get(); } - ParsePage(header["page"], shard_index, load_dataset); + if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { + return FAILED; + } shard_index++; } return SUCCESS; @@ -248,11 +250,16 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { return SUCCESS; } -void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { +MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { // set shard_index when load_dataset is false - if (pages_.empty() && shard_count_ <= kMaxShardCount) { + if (shard_count_ > kMaxFileCount) { + MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount; + return FAILED; + } + if (pages_.empty() && shard_count_ <= kMaxFileCount) { pages_.resize(shard_count_); } + for (auto &page : pages) { int page_id = page["page_id"]; int shard_id = page["shard_id"]; @@ -275,6 +282,7 @@ void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_datase pages_[shard_index].push_back(std::move(parsed_page)); } } + return SUCCESS; } MSRStatus ShardHeader::ParseStatistics(const json &statistics) { @@ -715,11 +723,43 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { std::string line; while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line), -1, true); + if (SUCCESS != ParsePage(json::parse(line), -1, true)) { + return FAILED; + } } page_in_handle.close(); return SUCCESS; } + +MSRStatus ShardHeader::initialize(const std::shared_ptr *header_ptr, const json &schema, + const std::vector &index_fields, std::vector &blob_fields, + uint64_t &schema_id) { + if (nullptr == header_ptr) { + MS_LOG(ERROR) << "ShardHeader pointer is NULL."; + return FAILED; + } + auto schema_ptr = Schema::Build("mindrecord", schema); + if (nullptr == schema_ptr) { + MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema."; + return FAILED; + } + schema_id = (*header_ptr)->AddSchema(schema_ptr); + // create index + std::vector> id_index_fields; + if (!index_fields.empty()) { + for (auto &el : index_fields) { + id_index_fields.emplace_back(schema_id, el); + } + if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) { + MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index."; + return FAILED; + } + } + + auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; + blob_fields = build_schema_ptr->GetBlobFields(); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc index 808ab55bfb..b8be83735b 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc @@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den) indices_({}), sampler_type_(kCustomTopPercentSampler) {} -ShardSample::ShardSample(int num, int den, int par) +ShardSample::ShardSample(int num, int den, int par, int no_of_samples) : numerator_(num), denominator_(den), partition_id_(par), - no_of_samples_(0), + no_of_samples_(no_of_samples), indices_({}), sampler_type_(kCustomTopPercentSampler) {} @@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python } } else { + int count = 0; for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { + if (no_of_samples_ != 0 && count == no_of_samples_) break; new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start + count++; } } std::swap(tasks, new_tasks); @@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { return FAILED; } total_no = static_cast(tasks.permutation_.size()); + int count = 0; for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { + if (no_of_samples_ != 0 && count == no_of_samples_) break; new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); + count++; } std::swap(tasks, new_tasks); } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc index 093be9792f..b9b26e33d1 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc @@ -15,7 +15,7 @@ */ #include "minddata/mindrecord/include/shard_schema.h" -#include "common/utils.h" +#include "utils/ms_utils.h" using mindspore::LogStream; using mindspore::ExceptionType::NoExceptionType; diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc index 6f8e440f91..972e3b2d14 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc @@ -15,7 +15,7 @@ */ #include "minddata/mindrecord/include/shard_task.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/mindrecord/include/common/shard_utils.h" using mindspore::LogStream; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 74eb9f3f9b..c5b38fe829 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -34,12 +34,17 @@ #include "pipeline/jit/static_analysis/static_analysis.h" #include "pipeline/jit/static_analysis/program_specialize.h" #include "pipeline/jit/resource.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "pipeline/jit/remove_value_node_dup.h" #include "frontend/optimizer/optimizer.h" #include "vm/transform.h" #include "parse/python_adapter.h" #include "frontend/optimizer/py_pass_manager.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "frontend/parallel/ps/parameter_server.h" +#include "frontend/parallel/ps/scheduler.h" +#include "frontend/parallel/ps/worker.h" +#endif namespace mindspore { namespace pipeline { @@ -228,8 +233,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { for (const auto ¶m : func_graph->parameters()) { auto param_node = std::static_pointer_cast(param); if (param_node->has_default()) { - const auto ¶m_value = param_node->default_param(); - ValuePtr value = param_value->value(); + ValuePtr value = param_node->default_param(); constexpr bool broaden = true; AbstractBasePtr ptr = abstract::FromValue(value, broaden); @@ -347,7 +351,7 @@ bool ExecuteAction(const ResourcePtr &res) { } auto graph_id = res->results()[kOutput].cast(); std::shared_ptr bc_ptr = res->results()[kBackend].cast>(); - std::shared_ptr msbc_ptr = std::dynamic_pointer_cast(bc_ptr); + compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast(bc_ptr).get(); MS_EXCEPTION_IF_NULL(msbc_ptr); compile::VmEvalFuncPtr run = std::make_shared([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { @@ -374,6 +378,25 @@ bool ExecuteAction(const ResourcePtr &res) { return true; } +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +bool StartPSWorkerAction(const ResourcePtr &res) { + parallel::ps::Worker::GetInstance().Run(); + return true; +} + +bool StartPSServerAction(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + auto &ps = parallel::ps::ParameterServer::GetInstance(); + ps.Run(func_graph); + return true; +} + +bool StartPSSchedulerAction(const ResourcePtr &res) { + parallel::ps::Scheduler::GetInstance().Run(); + return true; +} +#endif + // The parallel primitive related valuenode might be partitioned so that its value changes by device, // that will result in a syncronization error due to different executing order. // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, @@ -481,7 +504,11 @@ std::vector VmPipeline() { actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); actions.emplace_back(std::make_pair("validate", ValidateAction)); - +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (parallel::ps::Util::IsRoleOfWorker()) { + actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); + } +#endif // compile the ANF graph actions.emplace_back(std::make_pair("task_emit", TaskEmitAction)); @@ -490,5 +517,21 @@ std::vector VmPipeline() { return actions; } + +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +std::vector PServerPipeline() { + auto actions = CommonPipeline(); + actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); + actions.emplace_back(std::make_pair("validate", ValidateAction)); + actions.emplace_back(std::make_pair("pserver", StartPSServerAction)); + return actions; +} + +std::vector PSchedulerPipeline() { + std::vector actions; + actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction)); + return actions; +} +#endif } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/action.h b/mindspore/ccsrc/pipeline/jit/action.h index 0a1feab1c9..59a750f5ad 100644 --- a/mindspore/ccsrc/pipeline/jit/action.h +++ b/mindspore/ccsrc/pipeline/jit/action.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_ACTION_H_ -#define MINDSPORE_CCSRC_PIPELINE_ACTION_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ACTION_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_ACTION_H_ #include #include @@ -38,9 +38,14 @@ bool VmOptimizeAction(const ResourcePtr &res); bool PynativeOptimizeAction(const ResourcePtr &res); bool TaskEmitAction(const ResourcePtr &res); bool ExecuteAction(const ResourcePtr &res); +bool StartPSWorkerAction(const ResourcePtr &res); +bool StartPSServerAction(const ResourcePtr &res); +bool StartPSSchedulerAction(const ResourcePtr &res); std::vector GePipeline(); std::vector VmPipeline(); +std::vector PServerPipeline(); +std::vector PSchedulerPipeline(); abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, const abstract::AbstractBasePtrList &args_spec, bool clear = false); FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, @@ -50,4 +55,4 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_ACTION_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_ACTION_H_ diff --git a/mindspore/ccsrc/pipeline/jit/base.h b/mindspore/ccsrc/pipeline/jit/base.h index 0a8a2b75f3..595c2decd4 100644 --- a/mindspore/ccsrc/pipeline/jit/base.h +++ b/mindspore/ccsrc/pipeline/jit/base.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_BASE_H_ -#define MINDSPORE_CCSRC_PIPELINE_BASE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_BASE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_BASE_H_ #include #include @@ -24,7 +24,7 @@ #include "ir/anf.h" #include "pipeline/jit/resource.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace pipeline { @@ -59,4 +59,4 @@ inline std::string GetFilePathName(const std::string &file_name) { } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_BASE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 65adebb6e2..8dc9d4a63a 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -68,6 +68,8 @@ PYBIND11_MODULE(_c_expression, m) { py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") + .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"), + py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.") .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"), "Get Parameter Tensor Layout Dictionary.") .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"), @@ -81,9 +83,7 @@ PYBIND11_MODULE(_c_expression, m) { .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); - (void)py::class_>(m, "EnvInstance_") - .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) - .def(py::init()); + (void)py::class_>(m, "EnvInstance_").def(py::init()); (void)m.def("generate_key", &mindspore::pipeline::GenerateKey, "Generate the function graph key."); (void)m.def("real_run_op", &mindspore::pynative::RunOp, "Run op pynatively."); @@ -111,8 +111,6 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") - .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.") - .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, @@ -125,10 +123,6 @@ PYBIND11_MODULE(_c_expression, m) { "Set whether to enable reduce precision.") .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") - .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.") - .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.") - .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.") - .def("set_save_ms_model_path", &mindspore::MsContext::set_save_ms_model_path, "Set path to save ms model") .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.") .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.") .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.") diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index baef64481b..13dfaeb6c0 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -30,7 +30,7 @@ #include "frontend/operator/composite/composite.h" #include "ir/func_graph_cloner.h" #include "utils/symbolic.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "debug/trace.h" #include "frontend/optimizer/ad/grad.h" @@ -73,7 +73,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) { namespace { bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python tuple"; - py::tuple tuple = obj.cast(); + auto tuple = obj.cast(); std::vector value_list; for (size_t it = 0; it < tuple.size(); ++it) { ValuePtr out = nullptr; @@ -91,7 +91,7 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python list"; - py::list list = obj.cast(); + auto list = obj.cast(); std::vector value_list; for (size_t it = 0; it < list.size(); ++it) { ValuePtr out = nullptr; @@ -124,11 +124,12 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { MS_LOG(DEBUG) << "Converting python dict"; - py::dict dict_values = obj.cast(); + auto dict_values = obj.cast(); std::vector> key_values; for (auto item : dict_values) { if (!py::isinstance(item.first)) { - MS_LOG(EXCEPTION) << "The key of dict is only support str."; + MS_LOG(ERROR) << "The key of dict is only support str."; + return false; } std::string key = py::str(item.first); ValuePtr out = nullptr; @@ -158,7 +159,7 @@ void ConvertDataClass(py::object obj, ValuePtr *const data) { } bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { - MS_LOG(DEBUG) << "Converting primitive object"; + MS_LOG(DEBUG) << "Converting primitive object" << use_signature; // need check the primitive is class type or instance auto obj_type = data_converter::GetObjType(obj); @@ -184,6 +185,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = } else { *data = primitive; } + MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString(); } return true; } @@ -203,45 +205,10 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_ return true; } -bool ConvertDataType(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting type object"; - auto typeptr = obj.cast(); - if (typeptr == nullptr) { - MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null"; - return false; - } - *data = typeptr; - return true; -} - -bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting MetaTensor object."; - - auto m_tensor = obj.cast(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null."; - return false; - } - *data = m_tensor; - return true; -} - -bool ConvertTensor(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting tensor object"; - - auto m_tensor = obj.cast(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null"; - return false; - } - *data = m_tensor; - return true; -} - bool ConvertSlice(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting slice object"; - py::slice slice_obj = obj.cast(); + auto slice_obj = obj.cast(); auto convert_func = [obj](std::string attr) -> ValuePtr { auto py_attr = py::getattr(obj, attr.c_str()); if (py::isinstance(py_attr)) { @@ -359,17 +326,19 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature ConvertDataClass(obj, &converted); } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) { ret = ConvertPrimitive(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) { + } else if (py::isinstance(obj)) { ret = ConvertMetaFuncGraph(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) { - ret = ConvertDataType(obj, &converted); - } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) { - ret = ConvertTensor(obj, &converted); - } else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) { - ret = ConvertMetaTensor(obj, &converted); - } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { - std::shared_ptr env = obj.cast>(); + } else if (py::isinstance(obj)) { + converted = obj.cast(); + } else if (py::isinstance(obj)) { + converted = obj.cast(); + } else if (py::isinstance(obj)) { + converted = obj.cast(); + } else if (py::isinstance(obj)) { + auto env = obj.cast>(); converted = env; + } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { + converted = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); } else if (py::hasattr(obj, "__parameter__")) { auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); ret = ConvertData(to_convert, &converted); @@ -387,12 +356,12 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python std::string obj_id = results[0] + python_mod_get_parse_method; std::string obj_key = results[1]; FuncGraphPtr func_graph = nullptr; - Any value = Any(); + ValuePtr value = nullptr; bool is_cache = data_converter::GetObjectValue(obj_id, &value); if (is_cache) { - if (value.is()) { + if (value && value->isa()) { MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; - func_graph = value.cast(); + func_graph = value->cast(); return func_graph; } } @@ -405,7 +374,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python data_converter::MakeProperNameToFuncGraph(func_graph, obj_id); data_converter::CacheObjectValue(obj_id, func_graph); - if (obj_key != "") { + if (!obj_key.empty()) { MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); data_converter::SetObjGraphValue(obj_key, func_graph); } @@ -413,10 +382,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python return func_graph; } namespace data_converter { -static std::unordered_map object_map_ = std::unordered_map(); +static std::unordered_map object_map_; -static std::unordered_map> object_graphs_map_ = - std::unordered_map>(); +static std::unordered_map> object_graphs_map_; void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { object_graphs_map_[obj_key].push_back(data); @@ -428,8 +396,8 @@ const std::unordered_map> &GetObjGraphs() return object_graphs_map_; } -void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } -bool GetObjectValue(const std::string &obj_key, Any *const data) { +void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; } +bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) { if (object_map_.count(obj_key)) { *data = object_map_[obj_key]; return true; @@ -472,7 +440,7 @@ bool IsCellInstance(const py::object &obj) { py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object obj; - if (params.size() == 0) { + if (params.empty()) { obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); } else { obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); @@ -531,7 +499,7 @@ ClassPtr ParseDataClass(const py::object &cls_obj) { ClassAttrVector attributes; py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); for (auto &item : names) { - TypePtr type_value = item.second.cast(); + auto type_value = item.second.cast(); MS_EXCEPTION_IF_NULL(type_value); MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; attributes.push_back(std::make_pair(py::cast(item.first), type_value)); @@ -540,8 +508,8 @@ ClassPtr ParseDataClass(const py::object &cls_obj) { std::unordered_map methods_map; py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); for (auto &item : methods) { - std::string fun_name = item.first.cast(); - py::object obj = py::cast(item.second); + auto fun_name = item.first.cast(); + auto obj = py::cast(item.second); std::shared_ptr method_obj = std::make_shared(obj, fun_name); methods_map[fun_name] = method_obj; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h index 6632d4801e..e279069d73 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_PARSE_DATA_CONVERTER_H_ -#define PIPELINE_PARSE_DATA_CONVERTER_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_DATA_CONVERTER_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_DATA_CONVERTER_H_ #include #include @@ -32,8 +32,8 @@ namespace mindspore { namespace parse { // data convert for parse namespace data_converter { -void CacheObjectValue(const std::string &obj_key, const Any &data); -bool GetObjectValue(const std::string &obj_key, Any *const data); +void CacheObjectValue(const std::string &obj_key, const ValuePtr &data); +bool GetObjectValue(const std::string &obj_key, ValuePtr *const data); void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); @@ -58,4 +58,4 @@ void CleanDataClassToClassMap(); } // namespace parse } // namespace mindspore -#endif // PIPELINE_PARSE_DATA_CONVERTER_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_DATA_CONVERTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index b52dddda66..e424c9c2f7 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -23,7 +23,7 @@ #include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/parse.h" #include "frontend/operator/ops.h" -#include "debug/info.h" +#include "utils/info.h" #include "debug/trace.h" #include "pybind11/pybind11.h" @@ -99,7 +99,7 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) { } // Resolve class member, two possible: method, member variable -AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { +AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) { py::object namespace_var = parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj()); NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); @@ -109,7 +109,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { // Make a resolve node for symbol string AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { - if (value.compare(0, strlen("self."), "self.") == 0) { + if (value.compare(0, strlen("self"), "self") == 0) { auto start = value.find_first_of('.') + 1; if (start >= value.size()) { MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value; @@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; + auto removable = CollectRemovablePhi(phi); + // If the phi node is not necessary, not need to add to jumps_ of the prev blocks. + if (removable) { + MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var; + return; + } for (auto &pred : prev_blocks_) { MS_EXCEPTION_IF_NULL(pred); MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); @@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { CNodePtr jump = pred->jumps_[this]; jump->add_input(arg_node); } - // If the phi node in the body part of a for/while loop is being removed, - // then the closure convert phase will generate a cycle in graph if the - // loop is kept after specialization. This should be investigate further. - // Just now user has to set a flag on a function to indicate the for loop - // will definitely can be unroll as the sequence in for statement is fixed - // size in compile time. - if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || - parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - CollectRemovablePhi(phi); - } } AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { @@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame // 2. it's costly to iterate the graph to replace the phi for each phi. // Args : // phi : This parameter node is functioning as a phi node. -void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { +bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { MS_EXCEPTION_IF_NULL(phi); std::string var = phi_nodes_[phi]; - MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); + MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var; if (prev_blocks_.size() == 0) { - MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); - return; + MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var; + return false; } AnfNodePtr arg_node = SearchReplaceNode(var, phi); if (arg_node != nullptr) { @@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { const auto ¶m = phi_iter.second->cast(); if (param == phi) { MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() - << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); + << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString() + << " in graph " << arg_node->func_graph()->ToString(); prev->removable_phis_[phi_iter.first] = arg_node; } } } } + return true; } + return false; } // A block should be marked matured if its predecessor blocks have been processed @@ -299,13 +298,8 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); } - // Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch' - auto prim_switch = std::make_shared(prim::kPrimSwitch->name()); - if (!unroll_loop) { - prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0)); - } CNodePtr switch_app = - func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()), + func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph())}); CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); func_graph()->set_output(switch_app_new); diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h index cbf75a3dd8..e598790cd4 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_PARSE_FUNCTION_BLOCK_H_ -#define PIPELINE_PARSE_FUNCTION_BLOCK_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_ #include #include @@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this { AnfNodePtr ReadVariable(const std::string &var_name); void AddPrevBlock(const FunctionBlockPtr &block); void SetPhiArgument(const ParameterPtr &phi); - void CollectRemovablePhi(const ParameterPtr &phi); + bool CollectRemovablePhi(const ParameterPtr &phi); // A block is matured if all its predecessors is generated void Mature(); CNodePtr ForceToBoolNode(const AnfNodePtr &cond); @@ -68,7 +68,7 @@ class FunctionBlock : public std::enable_shared_from_this { void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } AnfNodePtr MakeResolveAstOp(const py::object &op); - AnfNodePtr MakeResolveClassMember(std::string attr); + AnfNodePtr MakeResolveClassMember(const std::string &attr); AnfNodePtr MakeResolveSymbol(const std::string &value); AnfNodePtr MakeResolveOperation(const std::string &value); AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); @@ -115,4 +115,4 @@ class FunctionBlock : public std::enable_shared_from_this { } // namespace parse } // namespace mindspore -#endif // PIPELINE_PARSE_FUNCTION_BLOCK_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index edc9a66594..cadb0f6199 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -17,15 +17,18 @@ */ #include "pipeline/jit/parse/parse.h" + +#include #include #include #include #include #include +#include "pipeline/jit/parse/resolve.h" #include "frontend/operator/ops.h" #include "pipeline/jit/parse/data_converter.h" #include "frontend/operator/composite/composite.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "debug/trace.h" namespace mindspore { @@ -504,14 +507,45 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); return block->func_graph()->NewCNode(make_tuple_nodes); } + +AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) { + py::object father_class; + if (args.empty()) { + father_class = py::none(); + } else if (args.size() == 2) { + father_class = args[0]; + auto arg_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1]))); + if (arg_type != AST_SUB_TYPE_NAME || py::cast(python_adapter::GetPyObjAttr(args[1], "id")) != "self") { + MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'."; + } + } else { + MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << "."; + } + py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj()); + py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance); + NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); + SymbolPtr symbol = std::make_shared("namespace"); + return block->MakeResolve(name_space, symbol); +} + // process function call, eg : f1(x, y) ... AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Call"; // process function call py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); + py::list args = python_adapter::GetPyObjAttr(node, "args"); + + auto arg_type = + AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node))); + if (arg_type == AST_SUB_TYPE_NAME) { + auto name_id = py::cast(python_adapter::GetPyObjAttr(function_ast_node, "id")); + if (name_id == "super") { + return ParseSuper(block, args); + } + } + AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); // function call arguments should be passed in as groups and unpacked later using unpack call - py::list args = python_adapter::GetPyObjAttr(node, "args"); std::vector packed_arguments; std::vector group_arguments; @@ -1027,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For, create an if else statement"; MS_EXCEPTION_IF_NULL(block); - // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' + // create statement 'len(xs) < MAX_FOR_LOOP_COUNT' AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); - CNodePtr bool_node = block->func_graph()->NewCNode( - {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); + CNodePtr bool_node = + block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)}); // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); @@ -1157,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); MS_EXCEPTION_IF_NULL(iter_node); - CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); + // Generate node for loop count and convert it to tensor, to make the loop not unroll + CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node}); + auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); + auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)}); + + CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len}); FunctionBlockPtr header_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); @@ -1165,7 +1204,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o // create loop variable 'i' ParameterPtr loop_var = header_block->func_graph()->add_parameter(); // create loop condition 'i < len(xs)' - CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter}); + auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); + auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)}); + CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter}); // generate the body of the for statement FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); @@ -1436,35 +1477,59 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje return block; } +AnfNodePtr FindPhis(const std::unordered_map &removable_phis, const AnfNodePtr &node) { + const auto &inp = node->cast(); + const auto &iter = removable_phis.find(inp); + if (iter == removable_phis.end()) { + return node; + } + return FindPhis(removable_phis, iter->second); +} + void Parser::RemoveUnnecessaryPhis() { // merge all removable phis to one map; std::unordered_map removable_phis; + std::vector phis; for (FunctionBlockPtr &block : func_block_list_) { MS_EXCEPTION_IF_NULL(block); removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); + std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis), + [](std::pair pair) { return pair.first; }); } - if (removable_phis.size() == 0) { return; } - for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { - if (node->isa()) { - const auto &cnode = node->cast(); - auto &inputs = cnode->inputs(); - for (std::size_t i = 0; i < inputs.size(); i++) { - if (inputs[i]->isa()) { - const auto &inp = inputs[i]->cast(); - const auto &iter = removable_phis.find(inp); - if (iter == removable_phis.end()) { - continue; - } - auto &argNode = iter->second; - MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " - << cnode->DebugString() << " with " << argNode->DebugString(); - cnode->set_input(i, argNode); - } - } + auto fg_name = func_graph_->ToString(); + auto mng = Manage(func_graph_, false); + // replace the nodes + // remove from inside to outside + for (int idx = SizeToInt(phis.size() - 1); idx >= 0; idx--) { + auto phi = phis[IntToSize(idx)]; + auto new_node = FindPhis(removable_phis, phi); + MS_LOG(DEBUG) << "phi " << phi->DebugString() << " to " << new_node->DebugString(); + mng->Replace(phi, new_node); + } + // remove the parameter + for (FunctionBlockPtr &block : func_block_list_) { + MS_EXCEPTION_IF_NULL(block); + auto &local_removable_phis = block->removable_phis(); + if (local_removable_phis.size() == 0) { + continue; } + auto func_graph = block->func_graph(); + auto ¶meters = func_graph->parameters(); + std::vector new_parameters(parameters.size()); + auto it = std::copy_if( + parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) { + return local_removable_phis.find(param->cast()) == local_removable_phis.end(); + }); + + // shrink container to new size + new_parameters.resize(std::distance(new_parameters.begin(), it)); + func_graph->set_parameters(new_parameters); + } + for (auto fg : mng->func_graphs()) { + fg->ClearAllManagerInfo(); } } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 90e965389f..afb72ba5c9 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_PARSE_PARSE_H_ -#define PIPELINE_PARSE_PARSE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ #include #include @@ -48,6 +48,10 @@ enum ParseStatusCode : int { PARSE_FAILURE = 0xFF }; +// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it +// will be sunk(i.e. not unrolled) +const int MAX_FOR_LOOP_COUNT = 600; + class AstNodeType; class ParseAst; @@ -138,6 +142,8 @@ class Parser { AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); // process a function call AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); + // process function 'super' + AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args); // process the if expression AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); // process class type define @@ -357,4 +363,4 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo } // namespace parse } // namespace mindspore -#endif // PIPELINE_PARSE_PARSE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index bdd79d00bd..d3c6851eda 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef PIPELINE_PARSE_PARSE_BASE_H_ -#define PIPELINE_PARSE_PARSE_BASE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_BASE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_BASE_H_ #include #include #include "pybind11/pybind11.h" @@ -81,10 +81,10 @@ const char PYTHON_PARSE_GET_LOCATION[] = "get_location"; const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; +const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; -const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input"; // define the common name const char NAMED_PRIMITIVE_LEN[] = "len"; @@ -149,4 +149,4 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, } // namespace parse } // namespace mindspore -#endif // PIPELINE_PARSE_PARSE_BASE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h index 0f49539bc8..d2cabdb2d8 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h +++ b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef PIPELINE_PARSE_PYTHON_ADAPTER_H_ -#define PIPELINE_PARSE_PYTHON_ADAPTER_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PYTHON_ADAPTER_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PYTHON_ADAPTER_H_ #include #include #include @@ -75,4 +75,4 @@ py::object CallPyFn(const std::string &module, const std::string &name, T... arg } // namespace parse } // namespace mindspore -#endif // PIPELINE_PARSE_PYTHON_ADAPTER_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PYTHON_ADAPTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 8d4c402639..48f3a24652 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -103,10 +103,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object } if (para_node == nullptr) { auto node = top_graph->AddWeightParameter(param_name); - auto param_value = py::cast(python_adapter::GetPyObjAttr(obj, "_value")); - node->set_default_param(param_value); + auto value = py::cast(obj); + node->set_default_param(value); // set_abstract for parameter - ValuePtr value = param_value->value(); constexpr bool broaden = true; node->set_abstract(abstract::FromValue(value, broaden)); para_node = node; @@ -228,19 +227,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func return true; } -} // namespace - -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, - const AnfNodePtr &node) { - if (node->func_graph() == nullptr || manager == nullptr) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; - } - SymbolResolver symbol_resolver(name_space, symbol, node); - if (!symbol_resolver.Resolve()) { - MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - py::object obj = symbol_resolver.result(); +// resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager +AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj, + const AnfNodePtr &node) { ScopeGuard scope_guard(node->scope()); AnfNodePtr resolved_node = nullptr; TraceManager::DebugTrace(std::make_shared(node->debug_info())); @@ -262,10 +252,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr TraceManager::EndTrace(); return resolved_node; } +} // namespace + +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { + if (node->func_graph() == nullptr || manager == nullptr) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; + } + SymbolResolver symbol_resolver(name_space, symbol, node); + if (!symbol_resolver.Resolve()) { + MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + + py::object obj = symbol_resolver.result(); + + AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node); + + return resolved_node; +} + +AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, + const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) { + if (node->func_graph() == nullptr || manager == nullptr) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; + } + SymbolResolver symbol_resolver(name_space, symbol, node); + if (!symbol_resolver.Resolve()) { + MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + + py::object obj = symbol_resolver.result(); + if (!data_converter::IsCellInstance(obj)) { + return nullptr; + } + py::object obj_attr = obj.attr(attr.c_str()); + + AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node); + + return resolved_node; +} namespace { opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { opt::OptPassGroupMap map({ + {"resolve_attr", + { + // for resolve primitive; + irpass.resolver_resolve_attr_, + }}, {"resolve", { // for resolve and getattr primitive; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h index d924f1ef44..db937daebf 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef PIPELINE_PARSE_RESOLVE_H_ -#define PIPELINE_PARSE_RESOLVE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ #include #include @@ -80,7 +80,7 @@ using SymbolPtr = std::shared_ptr; // PyObjectWrapper class wrappers resolved python object for further processing. class PyObjectWrapper : public Named { public: - explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + explicit PyObjectWrapper(const py::object &obj, const std::string &name = "Python object") : Named(name), obj_(obj) {} ~PyObjectWrapper() override = default; MS_DECLARE_PARENT(PyObjectWrapper, Named); py::object obj() { return obj_; } @@ -93,7 +93,7 @@ class PyObjectWrapper : public Named { // ClassObject class wrappers dataclass class ClassObject : public PyObjectWrapper { public: - explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") + explicit ClassObject(const py::object &obj, const std::string &name = "Python dataclass") : PyObjectWrapper(obj, name) {} ~ClassObject() override = default; MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); @@ -103,7 +103,7 @@ class ClassObject : public PyObjectWrapper { // ClassType class wrappers class name in python class ClassType : public PyObjectWrapper { public: - explicit ClassType(const py::object &obj, const std::string name = "Python class type") + explicit ClassType(const py::object &obj, const std::string &name = "Python class type") : PyObjectWrapper(obj, name) {} ~ClassType() override = default; MS_DECLARE_PARENT(ClassType, PyObjectWrapper); @@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr; AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node); +// Resolve Cell with attr name. +AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, + const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr); + // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); @@ -155,4 +159,4 @@ bool ResolveAll(const FuncGraphManagerPtr &manager); } // namespace parse } // namespace mindspore -#endif // PIPELINE_PARSE_RESOLVE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index bb9a517556..0c27ba7c48 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) { return true; } +bool CleanAfterOptAPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + + FuncGraphPtr func_graph = res->func_graph(); + bool changed = opt::CleanAfterOptA(func_graph, res->manager()); + + abstract::AbstractBasePtrList args_spec; + auto parameters = func_graph->parameters(); + (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); + if (changed) { + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + } + res->set_args_spec(args_spec); + return true; +} + namespace { OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ @@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { // Safe inlining irpass.inline_, + irpass.sparse_tensor_eliminate_, }); opt::OptPassConfig a_2 = opt::OptPassConfig({ irpass.merge_addn_, @@ -143,20 +162,15 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { } OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = opt::OptPassConfig({ - irpass.zero_like_fill_zero_, - irpass.item_tuple_eliminate_, - irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, - irpass.inline_, - irpass.special_op_eliminate_, - irpass.get_make_ref_eliminate_, - }); + opt::OptPassConfig b_1 = + opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, + irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, irpass.get_ref_param_eliminate_, - irpass.indexed_slices_eliminate_, + irpass.row_tensor_eliminate_, }); OptPassGroupMap map({ {"b_1", b_1}, @@ -321,19 +335,23 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { return true; } -std::vector kVmPasses = {{"opt_a", OptPassAGroup}, - {"simplify_data_structures", SimplifyDataStructuresPass}, +std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_a", OptPassAGroup}, + {"clean_after_opta", CleanAfterOptAPass}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_control_depend", AddControlDependPass}}; -std::vector kGePasses = { - {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, - {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, - {"cconv", CconvPass}}; +std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_a", OptPassAGroup}, + {"clean_after_opta", CleanAfterOptAPass}, + {"opt_b", OptPassBGroup}, + {"add_control_depend", AddControlDependPass}, + {"opt_control", ControlGroup}, + {"opt_prepare", PrepareGroup}, + {"cconv", CconvPass}}; std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; } // namespace pipeline diff --git a/mindspore/ccsrc/pipeline/jit/pass.h b/mindspore/ccsrc/pipeline/jit/pass.h index 0233b6cf26..6176113a15 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.h +++ b/mindspore/ccsrc/pipeline/jit/pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_PASS_H_ -#define MINDSPORE_CCSRC_PIPELINE_PASS_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PASS_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PASS_H_ #include #include @@ -40,4 +40,4 @@ void ReclaimOptimizer(); } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_PASS_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PASS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 05699793ff..d0993f0fcc 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -32,6 +32,7 @@ #include "debug/anf_ir_utils.h" #include "utils/config_manager.h" #include "utils/convert_utils.h" +#include "utils/context/context_extends.h" #include "utils/utils.h" #include "vm/segment_runner.h" #include "frontend/parallel/context.h" @@ -40,6 +41,13 @@ #include "debug/trace.h" #include "pipeline/pynative/pynative_execute.h" #include "frontend/optimizer/py_pass_manager.h" +#include "pybind_api/pybind_patch.h" + +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "frontend/parallel/ps/common.h" +#include "frontend/parallel/ps/util.h" +#include "frontend/parallel/ps/worker.h" +#endif #if (ENABLE_GE || ENABLE_D) #include "pipeline/jit/pipeline_ge.h" @@ -119,7 +127,7 @@ py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple i size_t count = 0; for (auto arg_obj : inputs) { - if (py::hasattr(arg_obj, PYTHON_TENSOR_FLAG)) { + if (py::isinstance(arg_obj)) { MS_LOG(DEBUG) << "Verify Tensor"; std::shared_ptr m_tensor = arg_obj.cast>(); if (m_tensor == nullptr) { @@ -255,6 +263,7 @@ void ExecutorPy::DelNetRes(const std::string &id) { for (auto &item : tmp_info) { if (item.first.find(id) != string::npos) { MS_LOG(DEBUG) << "Delete network res:" << item.first; + item.second = nullptr; (void)info_.erase(item.first); flag = true; } @@ -378,16 +387,6 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { MS_LOG(INFO) << "End save compiled func graph!"; } -bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { - std::string phase_prefix = GetPhasePrefix(phase_s); - - if (use_vm && phase_prefix == "export") { - MS_LOG(INFO) << "Use ge backend to export geir"; - use_vm = false; - } - return use_vm; -} - void ExecutorPy::GetGeBackendPolicy() const { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -397,6 +396,40 @@ void ExecutorPy::GetGeBackendPolicy() const { } } +bool IsPhaseExportGeir(const std::string &phase_s) { + auto phase_to_export = "export.geir"; + return phase_s.rfind(phase_to_export) != std::string::npos; +} + +std::vector GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) { + bool is_geir = IsPhaseExportGeir(phase_s); + + std::string backend = MsContext::GetInstance()->backend_policy(); + +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (mindspore::parallel::ps::Util::IsParamServerMode()) { + mindspore::parallel::ps::Util::SetInternalEnvVar(); + } + if (parallel::ps::Util::IsRoleOfPServer()) { + resource->results()[kBackend] = compile::CreateBackend(); + return PServerPipeline(); + } + if (parallel::ps::Util::IsRoleOfScheduler()) { + return PSchedulerPipeline(); + } +#endif + + if (use_vm && backend != "ge" && !is_geir) { + // Create backend and session + auto backend_ptr = compile::CreateBackend(); + // Connect session to debugger + backend_ptr->SetDebugger(); + resource->results()[kBackend] = backend_ptr; + return VmPipeline(); + } + return GePipeline(); +} + bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { MS_LOG(DEBUG) << "Start ExecutorPy compile!"; if ((!py::isinstance(phase))) { @@ -415,22 +448,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons std::string phase_s = py::cast(phase); MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!"; ResourcePtr resource = std::make_shared(obj); - std::vector p_actions; - - use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s); - - std::string backend = MsContext::GetInstance()->backend_policy(); - if (use_vm && backend != "ge") { - // Create backend and session - auto backend_ptr = compile::CreateBackend(); - // Connect session to debugger - backend_ptr->SetDebugger(); - resource->results()[kBackend] = backend_ptr; - p_actions = VmPipeline(); - } else { - p_actions = GePipeline(); - } + auto p_actions = GetPipline(resource, phase_s, use_vm); std::shared_ptr pip = std::make_shared(resource, FilterActions(p_actions, phase_s)); // get the parameters items and add the value to args_spec @@ -464,8 +483,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons } std::vector ExecutorPy::FilterActions(const std::vector &actions, const std::string &phase) { - // phase does not contain 'export_onnx' - if (GetPhasePrefix(phase).find("export_onnx") == std::string::npos) { + // filter action after validate when 'export'. + if (GetPhasePrefix(phase).rfind("export", 0) == std::string::npos) { return actions; } MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'"; @@ -521,6 +540,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: } catch (const py::index_error &ex) { ReleaseResource(phase); throw py::index_error(ex); + } catch (const py::attribute_error &ex) { + ReleaseResource(phase); + throw py::attribute_error(ex); } catch (const std::exception &ex) { ReleaseResource(phase); // re-throw this exception to Python interpreter to handle it @@ -698,7 +720,11 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef if (!param_ptr->has_default()) { MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; } - arg_list->push_back(param_ptr->default_param()->value()); + if (!param_ptr->default_param()->isa()) { + MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString() + << "] is not initialized, need to call `.init_data()`"; + } + arg_list->push_back(param_ptr->default_param()); } } } @@ -761,6 +787,24 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::stri #endif } +void ExecutorPy::UpdataParamNodeDefaultInput(const std::string &phase, + const std::unordered_map ¶ms_value) { + FuncGraphPtr func_graph = info_[phase]->resource->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase + << ")!"; + auto ¶ms = func_graph->parameters(); + for (const auto ¶m : params) { + MS_EXCEPTION_IF_NULL(param); + auto param_cast = param->cast(); + MS_EXCEPTION_IF_NULL(param_cast); + auto iter = params_value.find(param_cast->name()); + if (iter != params_value.end()) { + param_cast->set_default_param(iter->second); + } + } +} + void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { #if ENABLE_GE RunGEInitGraph(init_params, phase); @@ -774,10 +818,13 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba #ifndef NO_DLIB auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->IsTsdOpened() || !ms_context->IsGeInited()) { + if (!context::IsTsdOpened(ms_context) || !context::IsGeInited(ms_context)) { (void)InitBackend(); } #endif + if (iter_num == -1) { + iter_num = INT32_MAX; + } if (name == kMsConvert || name == kMsVm) { return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); } @@ -865,7 +912,7 @@ void InitHccl() { mindspore::parse::python_adapter::set_python_env_flag(true); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - (void)ms_context->OpenTsd(); + (void)context::OpenTsd(ms_context); uint32_t device_id = ms_context->device_id(); std::string device_name = ms_context->device_target(); ms_context->set_enable_hccl(true); @@ -898,8 +945,8 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s void ReleaseGeTsd() { auto context_ptr = MsContext::GetInstance(); if (context_ptr != nullptr) { - (void)context_ptr->FinalizeGe(true); - (void)context_ptr->CloseTsd(true); + (void)context::FinalizeGe(context_ptr, true); + (void)context::CloseTsd(context_ptr, true); } } @@ -909,17 +956,17 @@ void InitBackend() { // open tsd before ge initialize auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->OpenTsd()) { + if (!context::OpenTsd(ms_context)) { MS_LOG(EXCEPTION) << "Open tsd failed"; } - (void)ms_context->InitGe(); + (void)context::InitGe(ms_context); } void FinalizeBackend() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - (void)context_ptr->FinalizeGe(); - (void)context_ptr->CloseTsd(); + (void)context::FinalizeGe(context_ptr); + (void)context::CloseTsd(context_ptr); } void ClearResAtexit() { @@ -927,12 +974,19 @@ void ClearResAtexit() { pynative::ClearPyNativeSession(); session::ClearPythonParasMap(); device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (mindspore::parallel::ps::Util::IsParamServerMode()) { + if (parallel::ps::Util::IsRoleOfWorker()) { + parallel::ps::Worker::GetInstance().Finalize(); + } + } +#endif ad::g_k_prims.clear(); abstract::ClearPrimEvaluatorMap(); compile::ClearConvertCache(); pipeline::GetMethodMap().clear(); + pipeline::GetAttrMap().clear(); pipeline::ExecutorPy::ClearRes(); pipeline::ReclaimOptimizer(); pynative::PynativeExecutor::GetInstance()->ClearRes(); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 705853d086..0dac6e94a9 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ -#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_H_ #include #include @@ -88,6 +88,8 @@ class ExecutorPy : public std::enable_shared_from_this { FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, const py::object &broadcast_params = {}); + void UpdataParamNodeDefaultInput(const std::string &phase, + const std::unordered_map ¶ms); void RunInitGraph(const py::dict &init_params, const std::string &phase); py::dict GetParameterLayout(const std::string &phase); py::dict GetCNodeStrategy(const std::string &phase); @@ -101,7 +103,6 @@ class ExecutorPy : public std::enable_shared_from_this { private: ExecutorPy(); void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); - bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; void GetGeBackendPolicy() const; // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after // 'validate' stage @@ -145,4 +146,4 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc index e08af4f2dc..dade1ada38 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc @@ -161,7 +161,7 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) // convert int to tensor with shape([1]) tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); - } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { + } else if (py::isinstance(item.second.attr("default_input"))) { // cast tensor tensor = py::cast>(item.second.attr("default_input")); } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.h b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h index f834125231..7054d2ecf4 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_ge.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ -#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_GE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_GE_H_ #include #include @@ -52,4 +52,4 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase); } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h index b36544bdba..fd52924d58 100644 --- a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ -#define MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_REMOVE_VALUE_NODE_DUP_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_REMOVE_VALUE_NODE_DUP_H_ #include #include @@ -31,4 +31,4 @@ void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_REMOVE_VALUE_NODE_DUP_H_ diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index ece128b77b..0125108734 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -17,172 +17,187 @@ */ #include "pipeline/jit/resource.h" -#include "pipeline/jit/pipeline.h" #include "pipeline/jit/static_analysis/static_analysis.h" -#include "debug/draw.h" #include "debug/trace.h" #include "ir/dtype.h" #include "pipeline/jit/parse/data_converter.h" #include "frontend/operator/ops.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "frontend/optimizer/ad/dfunctor.h" -#include "vm/segment_runner.h" namespace mindspore { // namespace to support opmap definition namespace pipeline { -MethodMap &GetMethodMap() { - static MethodMap method_map = { - {kObjectTypeString, - { - {"__bool__", std::string("str_bool")} // C.str_bool - }}, - {kMetaTypeNone, - { - {"__bool__", std::string("none_bool")} // C.none_bool - }}, - {kNumberTypeBool, - { - {"__and__", prim::kPrimBoolAnd}, // P.bool_and - {"__or__", prim::kPrimBoolOr}, // P.bool_or - {"__eq__", prim::kPrimBoolEq}, // P.bool_eq - {"__ne__", std::string("bool_ne")}, // C.bool_ne - {"__bool__", prim::kPrimIdentity} // P.identity - }}, - {kNumberTypeInt, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul - {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv - {"__truediv__", std::string("int_truediv")}, // C.int_truediv - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow - {"__floor__", prim::kPrimIdentity}, // P.identity - {"__trunc__", prim::kPrimIdentity}, // P.identity - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt - {"__le__", prim::kPrimScalarLe}, // P.scalar_le - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge - {"__bool__", std::string("int_bool")}, // C.int_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array - }}, - {kNumberTypeUInt, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, - {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, - {"__truediv__", std::string("int_truediv")}, // C.int_truediv - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, - {"__floor__", prim::kPrimIdentity}, // P.identity, - {"__trunc__", prim::kPrimIdentity}, // P.identity, - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, - {"__le__", prim::kPrimScalarLe}, // P.scalar_le, - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, - {"__bool__", std::string("int_bool")}, // C.int_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, - }}, - {kNumberTypeFloat, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, - {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv - {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, - {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, - {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, - {"__le__", prim::kPrimScalarLe}, // P.scalar_le, - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, - {"__bool__", std::string("float_bool")}, // C.float_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, - }}, - {kObjectTypeTuple, - { - {"__len__", prim::kPrimTupleLen}, // P.tuple_len, - {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, - {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, - {"__ms_iter__", prim::kPrimIdentity}, // P.identity, - {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, - {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext - {"__bool__", std::string("tuple_bool")} // C.tuple_bool - }}, - {kObjectTypeList, - { - {"__len__", prim::kPrimListLen}, // P.list_len, - {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, - {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, - {"__ms_iter__", prim::kPrimIdentity}, // P.identity - {"__ms_next__", std::string("list_next")}, // C.list_next - {"append", std::string("list_append")}, // C.list_next - {"__bool__", std::string("list_bool")}, // C.list_bool - {"__ms_hasnext__", std::string("list_hasnext")}, - }}, - {kObjectTypeDictionary, +BuiltInTypeMap &GetMethodMap() { + static BuiltInTypeMap method_map = {{kObjectTypeString, + { + {"__bool__", std::string("str_bool")} // C.str_bool + }}, + {kMetaTypeNone, + { + {"__bool__", std::string("none_bool")} // C.none_bool + }}, + {kNumberTypeBool, + { + {"__and__", prim::kPrimBoolAnd}, // P.bool_and + {"__or__", prim::kPrimBoolOr}, // P.bool_or + {"__eq__", prim::kPrimBoolEq}, // P.bool_eq + {"__ne__", std::string("bool_ne")}, // C.bool_ne + {"__bool__", prim::kPrimIdentity} // P.identity + }}, + {kNumberTypeInt, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul + {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv + {"__truediv__", std::string("int_truediv")}, // C.int_truediv + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow + {"__floor__", prim::kPrimIdentity}, // P.identity + {"__trunc__", prim::kPrimIdentity}, // P.identity + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt + {"__le__", prim::kPrimScalarLe}, // P.scalar_le + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge + {"__bool__", std::string("int_bool")}, // C.int_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array + }}, + {kNumberTypeUInt, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, + {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, + {"__truediv__", std::string("int_truediv")}, // C.int_truediv + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, + {"__floor__", prim::kPrimIdentity}, // P.identity, + {"__trunc__", prim::kPrimIdentity}, // P.identity, + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, + {"__le__", prim::kPrimScalarLe}, // P.scalar_le, + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, + {"__bool__", std::string("int_bool")}, // C.int_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, + }}, + {kNumberTypeFloat, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, + {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv + {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, + {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, + {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, + {"__le__", prim::kPrimScalarLe}, // P.scalar_le, + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, + {"__bool__", std::string("float_bool")}, // C.float_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, + }}, + {kObjectTypeTuple, + { + {"__len__", prim::kPrimTupleLen}, // P.tuple_len, + {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, + {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, + {"__ms_iter__", prim::kPrimIdentity}, // P.identity, + {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, + {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext + {"__bool__", std::string("tuple_bool")} // C.tuple_bool + }}, + {kObjectTypeList, + { + {"__len__", prim::kPrimListLen}, // P.list_len, + {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, + {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, + {"__ms_iter__", prim::kPrimIdentity}, // P.identity + {"__ms_next__", std::string("list_next")}, // C.list_next + {"append", std::string("list_append")}, // C.list_next + {"__bool__", std::string("list_bool")}, // C.list_bool + {"__ms_hasnext__", std::string("list_hasnext")}, + }}, + {kObjectTypeDictionary, + { + {"__len__", prim::kPrimDictLen}, // P.dict_len + {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem + {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, + {"__bool__", std::string("dict_bool")} // C.dict_bool + }}, + {kObjectTypeTensorType, + { + {"all", std::string("all_")}, // C.reduce_all + {"any", std::string("any_")}, // C.reduce_any + {"__add__", std::string("add")}, // C.add + {"__sub__", std::string("sub")}, // C.sub + {"__mul__", std::string("mul")}, // C.mul + {"__truediv__", std::string("truediv")}, // C.truediv + {"__floordiv__", std::string("floordiv")}, // C.floordiv + {"__mod__", std::string("mod")}, // C.mod + {"__pow__", std::string("pow_")}, // C.pow + {"__floor__", std::string("array_floor")}, // C.array_floor + {"__trunc__", std::string("array_trunc")}, // C.array_trunc + {"__pos__", std::string("array_uadd")}, // C.array_uadd + {"__neg__", std::string("array_usub")}, // C.array_usub + {"__eq__", std::string("eq")}, // C.eq + {"__ne__", std::string("ne")}, // C.ne + {"__lt__", std::string("lt")}, // C.lt + {"__gt__", std::string("gt")}, // C.gt + {"__le__", std::string("le")}, // C.le + {"__ge__", std::string("ge")}, // C.ge + {"__matmul__", prim::kPrimDot}, // P.dot, + {"__len__", prim::kPrimArrayLen}, // P.array_len, + {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, + {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, + {"__ms_iter__", std::string("array_iter")}, // C.array_iter + {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, + {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, + {"transpose", std::string("transpose")}, // P.transpose + {"__bool__", std::string("tensor_bool")}, // C.tensor_bool + }}, + {kObjectTypeJTagged, {}}, + {kObjectTypeSymbolicKeyType, {}}, + {kObjectTypeEnvType, {}}}; + return method_map; +} + +BuiltInTypeMap &GetAttrMap() { + static BuiltInTypeMap attr_map = { + {kObjectTypeTensorType, { - {"__len__", prim::kPrimDictLen}, // P.dict_len - {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem - {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, - {"__bool__", std::string("dict_bool")} // C.dict_bool + {"shape", std::string("shape_")}, // C.shape_ + {"dtype", std::string("dtype_")}, // C.dtype_ }}, - {kObjectTypeTensorType, + {kObjectTypeRowTensorType, { - {"__add__", std::string("add")}, // C.add - {"__sub__", std::string("sub")}, // C.sub - {"__mul__", std::string("mul")}, // C.mul - {"__truediv__", std::string("truediv")}, // C.truediv - {"__floordiv__", std::string("floordiv")}, // C.floordiv - {"__mod__", std::string("mod")}, // C.mod - {"__pow__", std::string("pow_")}, // C.pow - {"__floor__", std::string("array_floor")}, // C.array_floor - {"__trunc__", std::string("array_trunc")}, // C.array_trunc - {"__pos__", std::string("array_uadd")}, // C.array_uadd - {"__neg__", std::string("array_usub")}, // C.array_usub - {"__eq__", std::string("eq")}, // C.eq - {"__ne__", std::string("ne")}, // C.ne - {"__lt__", std::string("lt")}, // C.lt - {"__gt__", std::string("gt")}, // C.gt - {"__le__", std::string("le")}, // C.le - {"__ge__", std::string("ge")}, // C.ge - {"__matmul__", prim::kPrimDot}, // P.dot, - {"__len__", prim::kPrimArrayLen}, // P.array_len, - {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, - {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, - {"__ms_iter__", std::string("array_iter")}, // C.array_iter - {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, - {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, - {"transpose", std::string("transpose")}, // P.transpose - {"__bool__", std::string("tensor_bool")}, // C.tensor_bool + {"values", prim::kPrimRowTensorGetValues}, // F.row_tensor_get_values + {"indices", prim::kPrimRowTensorGetIndices}, // F.row_tensor_get_indices + {"dense_shape", prim::kPrimRowTensorGetDenseShape}, // F.row_tensor_get_dense_shape }}, - {kObjectTypeIndexedSlicesType, + {kObjectTypeSparseTensorType, { - {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values - {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices - {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape + {"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values + {"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices + {"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape }}, - {kObjectTypeJTagged, {}}, - {kObjectTypeSymbolicKeyType, {}}, - {kObjectTypeEnvType, {}}}; - return method_map; + }; + return attr_map; } Resource::Resource(const py::object &obj) @@ -193,6 +208,7 @@ Resource::Resource(const py::object &obj) Resource::~Resource() { MS_LOG(DEBUG) << "Resource clear"; + std::unordered_map().swap(results_); // If exit normally, these global variables will be cleaned // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION, // these global variables may not being cleaned, it may @@ -212,31 +228,42 @@ Resource::~Resource() { } } -bool Resource::IsTypeInMethodMap(const TypeId &type) { - TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); - auto iter = method_map.find(static_cast(type_id)); - if (iter != method_map.end()) { - return true; +Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) { + auto type_method_map = method_map.find(static_cast(type_id)); + if (type_method_map == method_map.end()) { + return Any(); + } + auto method = type_method_map->second.find(name); + if (method == type_method_map->second.end()) { + return Any(); } - return false; + return method->second; } -Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { +bool Resource::IsTypeInBuiltInMap(const TypeId &type) { TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); + const BuiltInTypeMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter == method_map.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; - return Any(); + const BuiltInTypeMap &attr_map = GetAttrMap(); + iter = attr_map.find(static_cast(type_id)); + if (iter == attr_map.end()) { + return false; + } } + return true; +} - auto iter_map = iter->second.find(name); - if (iter_map == iter->second.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name; - return Any(); - } - return iter_map->second; +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { + TypeId type_id = NormalizeTypeId(type); + const BuiltInTypeMap &method_map = GetMethodMap(); + return GetMethodOrAttr(name, type_id, method_map); +} + +Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) { + TypeId type_id = NormalizeTypeId(type); + const BuiltInTypeMap &attr_map = GetAttrMap(); + return GetMethodOrAttr(name, type_id, attr_map); } void Resource::Clean() { diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h index 819fdd3d20..243e424d03 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.h +++ b/mindspore/ccsrc/pipeline/jit/resource.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ -#define MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_RESOURCE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_RESOURCE_H_ #include #include @@ -44,9 +44,11 @@ const char kOutput[] = "output"; class InferenceResource; -using MethodMap = std::unordered_map>; +using BuiltInTypeMap = std::unordered_map>; -MethodMap &GetMethodMap(); +BuiltInTypeMap &GetMethodMap(); + +BuiltInTypeMap &GetAttrMap(); class ResourceBase { public: @@ -87,10 +89,12 @@ class Resource : public ResourceBase { abstract::AnalysisEnginePtr engine() { return engine_; } - static bool IsTypeInMethodMap(const TypeId &type); + static bool IsTypeInBuiltInMap(const TypeId &type); static Any GetMethodPtr(const TypeId &type, const std::string &name); + static Any GetAttrPtr(const TypeId &type, const std::string &name); + const py::object &input() const { return input_; } FuncGraphPtr func_graph() const { return func_graph_; } @@ -117,4 +121,4 @@ using ResourcePtr = std::shared_ptr; } // namespace pipeline } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_RESOURCE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc deleted file mode 100644 index 8bdb2a0c6c..0000000000 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc +++ /dev/null @@ -1,361 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/jit/static_analysis/abstract_function.h" - -#include - -#include "pipeline/jit/static_analysis/static_analysis.h" - -namespace mindspore { -namespace abstract { -class Evaluator; -class AnalysisEngine; - -AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { - if (func_list.size() == 1) { - return func_list[0]; - } - return std::make_shared(func_list); -} - -AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { - auto this_func = shared_from_base(); - if (other->isa()) { - if (*this_func == *other) { - return this_func; - } - return std::make_shared(this_func, other); - } - auto other_union = dyn_cast(other); - if (other_union->IsSuperSet(this_func)) { - return other; - } - return std::make_shared(this_func, other); -} - -void AbstractFuncAtom::Visit(std::function visit_func) const { - visit_func(const_cast(this)->shared_from_base()); -} - -bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; } - -AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; } - -AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) { - AbstractFuncAtomPtrList new_func_list; - auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); }; - - first->Visit(build_func_list); - second->Visit(build_func_list); - func_list_ = new_func_list; -} - -std::string AbstractFuncUnion::ToString() const { - std::ostringstream buffer; - buffer << "AbstractFuncUnion({"; - int i = 0; - for (const auto &func : func_list_) { - MS_EXCEPTION_IF_NULL(func); - buffer << "[" << i << "]: " << func->ToString() << ", "; - i++; - } - buffer << "})"; - return buffer.str(); -} - -bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) { - MS_EXCEPTION_IF_NULL(other); - std::vector is_in_list; - auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) { - auto iter = find(func_list_.begin(), func_list_.end(), func); - if (iter == func_list_.end()) { - is_in_list.push_back(false); - } - return true; - }; - other->Visit(build_in_list); - return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; }); -} - -AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { - auto this_func = shared_from_base(); - if (other->isa()) { - if (IsSuperSet(other)) { - return this_func; - } - return std::make_shared(this_func, other); - } - auto other_union = dyn_cast(other); - if (other_union->IsSuperSet(this_func)) { - return other; - } - return std::make_shared(this_func, other); -} - -void AbstractFuncUnion::Visit(std::function visit_func) const { - for (AbstractFuncAtomPtr poss : func_list_) { - visit_func(poss); - } -} - -bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_union = static_cast(&other); - if (func_list_.size() != other_union->func_list_.size()) { - return false; - } - if (func_list_ == other_union->func_list_) { - return true; - } - return false; -} - -std::size_t AbstractFuncUnion::hash() const { - std::size_t hash_sum = 0; - for (auto f : func_list_) { - hash_sum = hash_combine(hash_sum, f->hash()); - } - return hash_sum; -} - -EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_prim = static_cast(&other); - if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { - return true; - } - return false; -} - -std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } - -EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_fg = static_cast(&other); - if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { - return true; - } - return false; -} - -std::size_t FuncGraphAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), func_graph_->hash()); - hash_value = hash_combine(hash_value, context_->hash()); - return hash_value; -} - -std::string FuncGraphAbstractClosure::ToString() const { - std::stringstream ss; - ss << "FuncGraphAbstractClosure: " - << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); - return ss.str(); -} - -EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_meta_fg = static_cast(&other); - if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { - return true; - } - return false; -} - -std::size_t MetaFuncGraphAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); - return hash_value; -} - -std::string MetaFuncGraphAbstractClosure::ToString() const { - return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name(); -} - -bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_partial = static_cast(&other); - if (fn_ != other_partial->fn_) { - return false; - } - if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_partial->args_spec_list_) { - return true; - } - return false; -} - -std::size_t PartialAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), fn_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -std::string PartialAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; - for (auto arg : args_spec_list_) { - buffer << arg->ToString() << ", "; - } - buffer << "))"; - return buffer.str(); -} - -EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_transformed = static_cast(&other); - if (fn_ == other_transformed->fn_) { - return true; - } - return false; -} - -std::size_t JTransformedAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), fn_->hash()); - return hash_value; -} - -EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_virtual = static_cast(&other); - if (output_ != other_virtual->output_) { - return false; - } - if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_virtual->args_spec_list_) { - return true; - } - return false; -} - -std::size_t VirtualAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), output_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -std::string VirtualAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "VirtualAbstractClosure(args: {"; - int i = 0; - for (const auto &arg : args_spec_list_) { - MS_EXCEPTION_IF_NULL(arg); - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - buffer << "}, output: " << output_->ToString() << ")"; - return buffer.str(); -} - -EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_typed = static_cast(&other); - if (output_ != other_typed->output_) { - return false; - } - if (prim_ != other_typed->prim_) { - return false; - } - if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_typed->args_spec_list_) { - return true; - } - return false; -} - -std::size_t TypedPrimitiveAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), prim_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -std::string TypedPrimitiveAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {"; - int i = 0; - for (const auto &arg : args_spec_list_) { - MS_EXCEPTION_IF_NULL(arg); - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - buffer << "}, output: " << output_->ToString() << ")"; - return buffer.str(); -} - -bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - return true; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h deleted file mode 100644 index 0823b21cd7..0000000000 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h +++ /dev/null @@ -1,303 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ -#define PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ - -#include -#include - -#include "abstract/abstract_value.h" -#include "abstract/analysis_context.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -namespace abstract { -class AbstractFuncAtom : public AbstractFunction { - public: - AbstractFuncAtom() = default; - ~AbstractFuncAtom() override = default; - MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) - - AbstractFunctionPtr GetUnique() override { return shared_from_base(); } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { - MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; - } - - AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; - void Visit(std::function) const final; - bool operator==(const AbstractFunction &other) const override; - - std::size_t hash() const override { return tid(); } -}; - -class AbstractFuncUnion : public AbstractFunction { - public: - explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list); - AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second); - ~AbstractFuncUnion() override = default; - MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction) - - std::string ToString() const override; - - AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { - MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; - } - bool IsSuperSet(const AbstractFunctionPtr &other); - AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; - void Visit(std::function) const final; - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } - - private: - AbstractFuncAtomPtrList func_list_; -}; - -class PrimitiveAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Primitive. - // prim: The primitive - // tracking_id: Identifies different uses of the same primitive. - explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr) - : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {} - ~PrimitiveAbstractClosure() override = default; - MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - PrimitivePtr prim() { return prim_; } - - AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } - - void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } - - AbstractFunctionPtr Copy() const override { return std::make_shared(prim_, tracking_id()); } - - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override { return "Prim: " + prim_->name(); } - - private: - PrimitivePtr prim_; - // store it as weak_ptr to break reference cycle. - // one reference cycle example is Graph::set_output() input0 local variable. - AnfNodeWeakPtr tracking_id_; -}; - -class FuncGraphAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Graph in a certain Context. - // context: The context, or Context.empty() - FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : func_graph_(func_graph), context_(context) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(context); - } - ~FuncGraphAbstractClosure() override = default; - MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - FuncGraphPtr func_graph() { return func_graph_; } - - AnalysisContextPtr context() const override { return context_; } - - AbstractFunctionPtr Copy() const override { - return std::make_shared(func_graph_, context_); - } - - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - FuncGraphPtr func_graph_; - AnalysisContextPtr context_; -}; -using FuncGraphAbstractClosurePtr = std::shared_ptr; - -class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { - public: - explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) - : meta_func_graph_(meta_func_graph), scope_(scope) {} - ~MetaFuncGraphAbstractClosure() override = default; - MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) - - MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; } - - AnalysisContextPtr context() const override { return kDummyAnalysisContext; } - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - ScopePtr GetScope() { return scope_; } - - AbstractFunctionPtr Copy() const override { return std::make_shared(meta_func_graph_); } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - MetaFuncGraphPtr meta_func_graph_; - ScopePtr scope_; -}; -using MetaFuncGraphAbstractClosurePtr = std::shared_ptr; - -class PartialAbstractClosure : public AbstractFuncAtom { - public: - // Represents a partial application. - // args_spec_list: The first few arguments of that function - PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, - const AnfNodePtr &node = nullptr) - : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} - ~PartialAbstractClosure() override = default; - MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - AbstractFunctionPtr fn() { return fn_; } - AbstractBasePtrList args() { return args_spec_list_; } - AnfNodePtr node() { return node_.lock(); } - void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } - AbstractFunctionPtr Copy() const override { - return std::make_shared(fn_, args_spec_list_, node_.lock()); - } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - AbstractFuncAtomPtr fn_; - AbstractBasePtrList args_spec_list_; - // The CNode which this PartialAbstractClosure evaluated from. - AnfNodeWeakPtr node_; -}; - -class JTransformedAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Function transformed through the application of J. - explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} - ~JTransformedAbstractClosure() override = default; - MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - AbstractFuncAtomPtr fn() { return fn_; } - AbstractFunctionPtr Copy() const override { return std::make_shared(fn_); } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override { return "J(" + fn_->ToString() + ")"; } - - private: - AbstractFuncAtomPtr fn_; -}; - -class VirtualAbstractClosure : public AbstractFuncAtom { - public: - // Represents some function with an explicitly fixed type signature. - // args_spec_list: The arguments as abstract value given to the function - // output: The output which is abstract value. - VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec) - : args_spec_list_(args_spec_list), output_(output_spec) {} - VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec) - : args_spec_list_({args_spec}), output_(output_spec) {} - ~VirtualAbstractClosure() override = default; - MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - AbstractBasePtrList args_spec_list() { return args_spec_list_; } - - AbstractBasePtr output() { return output_; } - AbstractFunctionPtr Copy() const override { - return std::make_shared(args_spec_list_, output_); - } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - AbstractBasePtrList args_spec_list_; - AbstractBasePtr output_; -}; -using VirtualAbstractClosurePtr = std::shared_ptr; - -class TypedPrimitiveAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Primitive with an explicitly fixed type signature. - // args_spec_list: The arguments as abstract value given to the Primitive - // output: The output which is abstract value. - TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list, - const AbstractBasePtr &output_spec) - : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {} - ~TypedPrimitiveAbstractClosure() override = default; - MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - PrimitivePtr prim() { return prim_; } - AbstractBasePtrList args_spec_list() { return args_spec_list_; } - AbstractBasePtr output() { return output_; } - AbstractFunctionPtr Copy() const override { - return std::make_shared(prim_, args_spec_list_, output_); - } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - PrimitivePtr prim_; - AbstractBasePtrList args_spec_list_; - AbstractBasePtr output_; -}; - -// Represents a function that can't be called. -class DummyAbstractClosure : public AbstractFuncAtom { - public: - DummyAbstractClosure() = default; - ~DummyAbstractClosure() override = default; - MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; } - - AbstractFunctionPtr Copy() const override { return std::make_shared(); } - bool operator==(const AbstractFunction &other) const override; - - std::string ToString() const override { return "DummyAbstractClosure()"; } -}; - -struct AbstractFunctionHasher { - std::size_t operator()(const AbstractFunctionPtr &t) const { - std::size_t hash = t->hash(); - return hash; - } -}; - -struct AbstractFunctionEqual { - bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; } -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 3e820eed3a..424a057bc3 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -22,6 +22,7 @@ #include "ir/func_graph_cloner.h" #include "abstract/utils.h" #include "debug/trace.h" +#include "utils/ms_context.h" namespace mindspore { namespace abstract { @@ -143,6 +144,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList MS_EXCEPTION_IF_NULL(arg); return arg->Broaden(); }); + if (func_graph_->joined_shapes_.size() != broaded_list.size()) { + MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size() + << " does not equal to number of original buffer arguments " + << func_graph_->joined_shapes_.size(); + } + for (size_t i = 0; i < broaded_list.size(); ++i) { + broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); + } MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) << ", broaded: " << mindspore::ToString(broaded_list); return broaded_list; @@ -170,6 +179,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. if (!(joined_args_spec_list == args_spec_list)) { func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + func_graph_->joined_shapes_.clear(); + std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), + std::back_inserter(func_graph_->joined_shapes_), + [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } return joined_args_spec_list; @@ -184,6 +197,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa if (!(joined_args_spec_list == args_spec_list)) { trace_.push_back(joined_args_spec_list); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + func_graph_->joined_shapes_.clear(); + std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), + std::back_inserter(func_graph_->joined_shapes_), + [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); @@ -373,9 +390,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) AbstractBasePtrList bparams; bparams.push_back(SensitivityTransform(orig_func_)); - (void)std::transform( - args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), - [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), + [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { + if (enable_sparse && arg_spec->isa()) { + return std::make_shared(); + } + return SensitivityTransform(arg_spec); + }); AbstractBasePtr bparams_final = std::make_shared(bparams); AbstractFunctionPtr bprop = std::make_shared(SensitivityTransform(result->abstract()), bparams_final); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h index 461574257d..af597d1d33 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ -#define PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ #include #include @@ -25,7 +25,7 @@ #include #include "pipeline/jit/static_analysis/static_analysis.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace abstract { @@ -327,4 +327,4 @@ class JEvaluator : public Evaluator { }; } // namespace abstract } // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 99e613395c..2dd0ba6b49 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -21,9 +21,9 @@ #include #include #include -#include #include #include +#include #include "frontend/operator/cc_implementations.h" #include "frontend/operator/ops.h" @@ -31,15 +31,13 @@ #include "frontend/operator/prim_to_function.h" #include "abstract/utils.h" #include "utils/symbolic.h" -#include "./common.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/parse/resolve.h" -#include "ir/tensor.h" #include "utils/convert_utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "pipeline/jit/parse/data_converter.h" #include "abstract/param_validator.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace abstract { @@ -64,8 +62,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimShape, {InferImplShape, true}}, {prim::kPrimPack, {InferImplPack, true}}, + {prim::kPrimUnique, {InferImplUnique, true}}, + {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}}, @@ -133,20 +132,27 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimControlDepend, {InferImplControlDepend, true}}, // Debug {prim::kPrimDebug, {InferImplDebug, true}}, - // IndexedSlices - {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, - {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, - {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, - {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, - {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, + // RowTensor + {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, + {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, + {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, + {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, + // SparseTensor + {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, + {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, + {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, + {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, }; return prim_eval_implement_map; } using mindspore::parse::PyObjectWrapper; +std::unordered_set prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem", + "env_getitem"}; + EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { - if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { + if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { auto ret_abstract = AbstractEval(args); if (ret_abstract != nullptr) { MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; @@ -166,17 +172,23 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - auto ret_abstract = AbstractEval(args_spec_list); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; - return ret_abstract; + auto do_signature = prim_->cast(); + auto &func = do_signature->function(); + if (func->isa()) { + auto sig_prim = func->cast(); + if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) { + auto ret_abstract = AbstractEval(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined"; + return ret_abstract; + } + } } if (out_conf->node() == nullptr || !out_conf->node()->isa()) { MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; } - auto do_signature = dyn_cast(prim_); auto out_node = dyn_cast(out_conf->node()); const auto &out_node_inputs = out_node->inputs(); if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { @@ -380,8 +392,26 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { if (abs_base->isa()) { auto arg_tensor = dyn_cast(abs_base); dic["shape"] = arg_tensor->shape()->shape(); + if (MsContext::GetInstance()->execution_mode() == kGraphMode) { + const auto &min_shape = arg_tensor->shape()->min_shape(); + const auto &max_shape = arg_tensor->shape()->max_shape(); + if (!min_shape.empty() && !max_shape.empty()) { + dic["min_shape"] = min_shape; + dic["max_shape"] = max_shape; + } + } dic["dtype"] = arg_tensor->BuildType(); dic["value"] = BuildValue(arg_tensor->BuildValue()); + } else if (abs_base->isa()) { + auto arg = dyn_cast(abs_base); + dic["shape"] = arg->shape()->shape(); + dic["dtype"] = arg->BuildType(); + dic["value"] = BuildValue(arg->BuildValue()); + } else if (abs_base->isa()) { + auto arg = dyn_cast(abs_base); + dic["shape"] = arg->shape()->shape(); + dic["dtype"] = arg->BuildType(); + dic["value"] = BuildValue(arg->BuildValue()); } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { std::vector shape; dic["shape"] = shape; @@ -436,6 +466,11 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic["shape"] = py::none(); dic["dtype"] = abs_base->BuildType(); dic["value"] = py::none(); + } else if (abs_base->isa()) { + auto arg = dyn_cast(abs_base); + dic["shape"] = py::none(); + dic["dtype"] = arg->BuildType(); + dic["value"] = py::none(); } else { auto value = abs_base->BuildValue(); if ((*value == *kAnyValue)) { @@ -479,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic if (output["value"].is_none()) { auto out_shape = output["shape"]; auto out_dtype = output["dtype"]; - return PyListDtype2AbstractTensor(out_shape, out_dtype); + py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none(); + py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none(); + + return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape); } // Convert pyobject to Value, then to AbstractValue ValuePtr converted_ret = nullptr; @@ -523,14 +561,8 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs return iter->second; } auto py_args = PreparePyInputs(prim_py_, args); - - auto pyobj = prim_py_->GetPyObj(); - if (pyobj == nullptr) { - MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; - } - auto infer_fuc = pyobj.attr("__infer__"); prim_py_->BeginRecordAddAttr(); - py::dict output = infer_fuc(*py_args); + py::dict output = prim_py_->RunInfer(py_args); prim_py_->EndRecordAddAttr(); auto added_attrs = prim_py_->evaluate_added_attrs(); MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); @@ -625,7 +657,7 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm } const int kResolveCaseUserDefineClass = 1; -const int kResolveCaseBuildinTypeMethod = 2; +const int kResolveCaseBuiltInType = 2; const int kResolveCaseFunction = 3; int GetResolveCase(const TypePtr &data_type) { MS_EXCEPTION_IF_NULL(data_type); @@ -634,8 +666,8 @@ int GetResolveCase(const TypePtr &data_type) { } // try method map, if not in method map, the data_type should be External type. - if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) { - return kResolveCaseBuildinTypeMethod; + if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) { + return kResolveCaseBuiltInType; } return kResolveCaseFunction; @@ -665,8 +697,10 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun manager->AddFuncGraph(func_graph); } -EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &old_conf) { +enum REQUIRE_TYPE { ATTR, METHOD }; + +EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf, + REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) { MS_EXCEPTION_IF_NULL(old_conf); AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); @@ -692,6 +726,9 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_ MS_EXCEPTION_IF_NULL(old_conf); FuncGraphPtr func_graph = old_conf->node()->func_graph(); CNodePtr new_cnode = func_graph->NewCNode(input); + if (require_type == REQUIRE_TYPE::ATTR) { + new_cnode = func_graph->NewCNode({new_cnode}); + } AnalysisEnginePtr eng = old_conf->engine(); AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context()); return eng->ForwardConfig(old_conf, fn_conf); @@ -763,8 +800,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng ValuePtr method = cls->GetMethod(item_name); if (method->isa()) { - MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() - << ", item value: " << item_v->ToString(); + MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() + << ", item value: " << item_v->ToString(); } // Infer class method @@ -772,9 +809,9 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng return StaticGetterInferred(converted_v, data_conf, out_conf); } -EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, - const TypePtr &data_type, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &out_conf) { +EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, + const TypePtr &data_type, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &out_conf) { MS_EXCEPTION_IF_NULL(item_v); MS_EXCEPTION_IF_NULL(data_type); // The method maybe a Primitive or Composite @@ -783,22 +820,29 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &eng } std::string item_name = item_v->cast()->value(); - Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); - if (method.empty()) { - MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name; + REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD; + Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); + if (require.empty()) { + require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name); + if (require.empty()) { + MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name; + } + require_type = REQUIRE_TYPE::ATTR; } ValuePtr converted_v = nullptr; - if (method.is()) { + if (require.is()) { // composite registered in standard_method_map go to this branch - converted_v = prim::GetPythonOps(method.cast()); - AddToManager(engine, converted_v->cast()); - } else if (method.is()) { - converted_v = method.cast(); + converted_v = prim::GetPythonOps(require.cast()); + if (!converted_v->isa()) { + AddToManager(engine, converted_v->cast()); + } + } else if (require.is()) { + converted_v = require.cast(); } else { - MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString(); + MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString(); } - return StaticGetterInferred(converted_v, data_conf, out_conf); + return StaticGetterInferred(converted_v, data_conf, out_conf, require_type); } EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, @@ -822,8 +866,8 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt int case_v = GetResolveCase(data_type); if (case_v == kResolveCaseUserDefineClass) { return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); - } else if (case_v == kResolveCaseBuildinTypeMethod) { - return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf); + } else if (case_v == kResolveCaseBuiltInType) { + return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf); } else { return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 692fbe66e8..48bb0e990c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_STATIC_ANALYSIS_PRIM_H_ -#define PIPELINE_STATIC_ANALYSIS_PRIM_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_PRIM_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_PRIM_H_ #include #include @@ -218,10 +218,6 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -246,10 +242,12 @@ AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const Primitiv const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); @@ -350,17 +348,23 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, +AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); } // namespace abstract } // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_PRIM_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_PRIM_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index ad39190dc3..25b34d3681 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -23,8 +23,8 @@ #include "./common.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/do_signature.h" -#include "pipeline/jit/static_analysis/abstract_function.h" -#include "utils/graph_utils.h" +#include "abstract/abstract_function.h" +#include "ir/graph_utils.h" #include "utils/log_adapter.h" #include "utils/profile.h" #include "debug/trace.h" @@ -82,6 +82,9 @@ std::shared_ptr ProgramSpecializer::GetFuncGraphSpecialize if (iter != specializations_.end()) { return iter->second; } + if (context->func_graph()) { + MS_LOG(EXCEPTION) << "Specialize inner error"; + } return nullptr; } @@ -115,6 +118,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod std::shared_ptr specializer = shared_from_this(); while (fg != nullptr && fg != specializer->func_graph_) { specializer = specializer->parent_; + MS_EXCEPTION_IF_NULL(specializer); } // If had replicated, just return that. auto iter = specializer->repl_node_->find(node); @@ -539,8 +543,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early if (status == kSpecializeFindUniqueArgvalPoly || - (func->isa() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || - func->abstract()->isa()))) { + (func->isa() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) { auto wrapped_node = BuildSpecializedParameterNode(new_node); new_inputs[0] = wrapped_node; } @@ -654,17 +657,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c } } if (!is_attr_same) { - if (prim->isa()) { - PrimitivePyPtr prim_py = prim->cast(); - auto clone_fn = prim_py->GetPyObj().attr("_clone"); - py::object new_obj = clone_fn(); - auto cloned_prim = new_obj.cast(); - for (auto &item : *attrs) { - cloned_prim->AddAttr(item.first, item.second); - } - return cloned_prim; - } - auto cloned_prim = std::make_shared(*prim); + auto cloned_prim = prim->Clone(); for (auto &item : *attrs) { cloned_prim->AddAttr(item.first, item.second); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h index d7f95be4ca..2c08ea00ef 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ -#define PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ #include #include @@ -133,4 +133,4 @@ class FuncGraphSpecializer : public std::enable_shared_from_thisvalue(), conf->context(), conf); + auto out = ToAbstract(value_node->value(), conf->context(), conf); + if (value_node->has_new_value()) { + out = out->Broaden(); + } + return out; } EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { @@ -434,8 +438,30 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptrGetEvaluator(shared_from_this()); - return evaluator; + if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; + } else if (func->isa()) { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; + } else if (func->isa()) { + MS_LOG(EXCEPTION) << "A dummy function cannot eval"; + } else { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction"; + } + return nullptr; } EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { @@ -444,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); } MS_EXCEPTION_IF_NULL(func); - if (func->tracking_id() == nullptr) { + if (func->tracking_id() == nullptr || func->isa() || + func->isa()) { EvaluatorPtr evaluator = _GetEvaluatorFor(func); return evaluator; } @@ -613,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { } abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context) { + const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) { AnalysisContextPtr temp_context = context; if (temp_context == nullptr) { temp_context = abstract::AnalysisContext::DummyContext(); } - return std::make_shared(func_graph, temp_context); + return std::make_shared(func_graph, temp_context, anf_node); } abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { @@ -626,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_ if (anf_node == nullptr) { meta_func_graph_fn = std::make_shared(meta_func_graph); } else { - meta_func_graph_fn = std::make_shared(meta_func_graph, anf_node->scope()); + meta_func_graph_fn = + std::make_shared(meta_func_graph, anf_node, anf_node->scope()); } return meta_func_graph_fn; } @@ -637,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con } AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { - if (value->isa()) { - auto func_graph = value->cast(); - return MakeAbstractClosure(func_graph, context); - } AnfNodePtr anf_node = nullptr; if (conf != nullptr) { anf_node = conf->node(); } + if (value->isa()) { + auto func_graph = value->cast(); + return MakeAbstractClosure(func_graph, context, anf_node); + } if (value->isa()) { auto meta_func_graph = value->cast(); return MakeAbstractClosure(meta_func_graph, anf_node); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 181696f756..57e78dcec8 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ -#define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ #include #include @@ -33,9 +33,9 @@ #include "utils/log_adapter.h" #include "ir/anf.h" -#include "ir/primitive_py.h" +#include "utils/primitive_py.h" #include "abstract/analysis_context.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "pipeline/jit/parse/parse.h" namespace mindspore { @@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this { const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; - std::unordered_map constructors_; + std::unordered_map constructors_; AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. std::list> eval_trace_; @@ -277,4 +277,4 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_ } // namespace abstract } // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 04aa6efd05..9655f7a659 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -32,10 +32,11 @@ using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractError; using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractIndexedSlices; using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractSparseTensor; using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractType; @@ -94,8 +95,8 @@ void ValidateAbstract(const AnfNodePtr &node) { } if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa()) { + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa() || ptrBase->isa()) { return; } diff --git a/mindspore/ccsrc/pipeline/jit/validator.h b/mindspore/ccsrc/pipeline/jit/validator.h index 041448aed9..819cf89f18 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.h +++ b/mindspore/ccsrc/pipeline/jit/validator.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ -#define MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_VALIDATOR_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_VALIDATOR_H_ #include #include @@ -35,4 +35,4 @@ void ValidateOperation(const AnfNodePtr &node); } // namespace validator } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H__ +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_VALIDATOR_H__ diff --git a/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt b/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt index c15928ee76..661b62ca61 100644 --- a/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt +++ b/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc") +file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute.cc") if (ENABLE_GE) file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc") diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index afb6d0982b..2fa238f2b8 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PYNATIVE_BASE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_BASE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_ #include #include @@ -26,7 +26,8 @@ #include #include "pybind11/pybind11.h" -#include "ir/primitive_py.h" +#include "ir/anf.h" +#include "utils/primitive_py.h" #include "abstract/abstract_value.h" namespace mindspore { @@ -48,19 +49,21 @@ enum PynativeStatusCode { enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; struct OpExecInfo { - PrimitivePyPtr py_primitive; std::string op_name; + std::string prim_id; + PrimitivePyPtr py_primitive; AbstractBasePtr abstract; + ValuePtr value = nullptr; - py::tuple op_inputs; - py::tuple inputs_mask; + py::list op_inputs; py::dict op_attrs; + std::vector inputs_mask; }; using OpExecInfoPtr = std::shared_ptr; -OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args); +OpExecInfoPtr GenerateOpExecInfo(const py::args &args); -const std::set ignore_infer_prim = {"make_ref"}; +const std::set ignore_infer_prim = {"make_ref", "mixed_precision_cast"}; } // namespace pynative } // namespace mindspore -#endif // MINDSPORE_CCSRC_PYNATIVE_BASE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 5e3add1b5f..703f3dff7e 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -23,11 +23,12 @@ #include #include "debug/trace.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" #include "ir/param_value.h" #include "utils/any.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" +#include "utils/context/context_extends.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/do_signature.h" @@ -57,7 +58,7 @@ using mindspore::tensor::TensorPy; const char SINGLE_OP_GRAPH[] = "single_op_graph"; // primitive unable to infer value for constant input in PyNative mode -const std::set vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; +const std::set vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"}; namespace mindspore { namespace pynative { @@ -111,7 +112,7 @@ inline ValuePtr PyAttrValue(const py::object &obj) { return converted_ret; } -std::string GetId(const py::object &obj) { +static std::string GetId(const py::object &obj) { py::object to_process = obj; std::string prefix = ""; if (py::isinstance(to_process)) { @@ -141,10 +142,10 @@ std::string GetId(const py::object &obj) { return py::cast(ret); } -py::object GetTupleObj(const py::object &obj) { - py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); - py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); - return obj_tuple; +static std::string GetOpId(const OpExecInfoPtr &op_exec_info) { + auto id = GetId(op_exec_info->py_primitive->GetPyObj()); + op_exec_info->prim_id = id; + return id; } std::map> GetTypeIndex(const std::vector &dtypes) { @@ -180,8 +181,10 @@ std::map GetDstType(const py::tuple &py_args, if (!has_int && !py::isinstance(py_args[index]) && py::isinstance(py_args[index])) { has_int = true; } - if (py::isinstance(py_args[index])) { - auto arg = py::cast(py_args[index]); + + auto obj = py_args[index]; + if (py::isinstance(obj)) { + auto arg = py::cast(obj); TypeId arg_type_id = arg->data_type(); auto type_priority = prim::type_map.find(arg_type_id); if (type_priority == prim::type_map.end()) { @@ -231,24 +234,19 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { return RunOp(args)[0]; } -py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, - py::list *const out_args_list) { - auto &py_args = *out_args; - py::tuple input_mask(args.size()); - for (size_t i = 0; i < args.size(); ++i) { - input_mask[i] = py::hasattr(args[i], "__parameter__"); - py_args[i] = GetTupleObj(args[i]); - } + +void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) { + auto &out_args = op_exec_info->op_inputs; auto signature = prim->signatures(); std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { - return input_mask; + return; } auto type_indexes = GetTypeIndex(dtypes); - auto dst_type = GetDstType(py_args, type_indexes); + auto dst_type = GetDstType(out_args, type_indexes); for (size_t i = 0; i < dtypes.size(); ++i) { if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { @@ -258,8 +256,10 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu if (it == dst_type.end() || it->second == kTypeUnknown) { continue; } - if (py::isinstance(py_args[i])) { - auto arg = py::cast(py_args[i]); + + auto obj = out_args[i]; + if (py::isinstance(obj)) { + auto arg = py::cast(obj); if (arg->data_type() == it->second) { continue; } @@ -268,26 +268,29 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu TypeIdToMsTypeStr(it->second)); } } - py::object cast_output = DoAutoCast(py_args[i], it->second); - (*out_args)[i] = cast_output; - (*out_args_list)[i] = cast_output; + + if (!py::isinstance(obj) && !py::isinstance(obj) && !py::isinstance(obj)) { + MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is a not support type: " + << py::cast(obj.attr("__class__").attr("__name__")) << ", and the value is " + << py::cast(obj) << "."; + } + py::object cast_output = DoAutoCast(out_args[i], it->second); + out_args[i] = cast_output; + ValuePtr input_value = PyAttrValue(cast_output); } - return input_mask; } -void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { - size_t size = py_args.size(); - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < size; i++) { - ValuePtr input_value = PyAttrValue(py_args[i]); - args_spec_list.emplace_back(abstract::FromValueInside( - input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa())); - } +void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info, + const abstract::AbstractBasePtrList &args_spec_list) { + MS_LOG(DEBUG) << "prim " << prim->name() << "input infer" << mindspore::ToString(args_spec_list); + prim->BeginRecordAddAttr(); AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); + prim->EndRecordAddAttr(); op_exec_info->abstract = infer_res; + MS_LOG(DEBUG) << "prim " << prim->name() << "infer result " << op_exec_info->abstract->ToString(); } -OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) { +OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { MS_LOG(ERROR) << "Three args are needed by RunOp"; return nullptr; @@ -296,26 +299,19 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) MS_EXCEPTION_IF_NULL(op_exec_info); op_exec_info->op_name = py::cast(args[PY_NAME]); auto prim = py::cast(args[PY_PRIM]); - auto pyobj = prim->GetPyObj(); - if (pyobj == nullptr) { + if (!prim->HasPyObj()) { MS_LOG(EXCEPTION) << "pyobj is empty"; } - - py::list a = args[PY_INPUTS]; - size_t input_num = a.size(); - op_exec_info->op_inputs = py::tuple(input_num); - - op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args); - // use python infer method - if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { - PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); - } op_exec_info->py_primitive = prim; op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); - if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { - MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; - return nullptr; + auto inst = PynativeExecutor::GetInstance(); + if (inst->grad_flag()) { + op_exec_info->value = inst->GetForwardValue(op_exec_info); + } else { + (void)GetOpId(op_exec_info); } + op_exec_info->op_inputs = args[PY_INPUTS]; + ConvertInputs(prim, args[PY_INPUTS], op_exec_info); return op_exec_info; } @@ -324,18 +320,21 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, MS_EXCEPTION_IF_NULL(op_exec_info); std::string graph_info; // get input tensor info - size_t input_num = op_exec_info->op_inputs.size(); - for (size_t index = 0; index < input_num; ++index) { - auto input = op_exec_info->op_inputs[index]; - if (py::isinstance(input)) { - auto tensor_ptr = py::cast(input); - (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); - } + for (const auto &tensor : input_tensors) { + MS_EXCEPTION_IF_NULL(tensor); + auto tensor_shape = tensor->shape(); + (void)std::for_each(tensor_shape.begin(), tensor_shape.end(), + [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); }); + (void)graph_info.append(std::to_string(tensor->data_type()) + "_"); } // get prim and abstract info - MS_EXCEPTION_IF_NULL(op_exec_info->abstract); - (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + - op_exec_info->abstract->ToString()); + (void)graph_info.append(op_exec_info->prim_id + "_"); + // get attr info + const auto &op_prim = op_exec_info->py_primitive; + MS_EXCEPTION_IF_NULL(op_prim); + const auto &attr_map = op_prim->evaluate_added_attrs(); + (void)std::for_each(attr_map.begin(), attr_map.end(), + [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); }); return graph_info; } @@ -345,14 +344,12 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat MS_EXCEPTION_IF_NULL(status); MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); + + auto &op_inputs = op_exec_info->op_inputs; if (op_exec_info->op_name == "HookBackward") { - auto op_inputs = op_exec_info->op_inputs; py::tuple result(op_inputs.size()); for (size_t i = 0; i < op_inputs.size(); i++) { py::object input = op_inputs[i]; - if (py::hasattr(input, "__parameter__")) { - input = py::getattr(input, "data"); - } auto tensor = py::cast(input); auto new_tensor = std::make_shared(tensor->data_type(), tensor->shape(), tensor->data_ptr()); new_tensor->set_device_address(tensor->device_address()); @@ -363,19 +360,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat MS_LOG(INFO) << "RunOpInVM end"; return std::move(result); } - auto func = op_exec_info->py_primitive->GetComputeFunction(); - if (py::isinstance(func)) { - MS_LOG(ERROR) << "VM failed to get func"; + auto primitive = op_exec_info->py_primitive; + MS_EXCEPTION_IF_NULL(primitive); + auto result = primitive->RunPyComputeFunction(op_inputs); + if (py::isinstance(result)) { + MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func"; *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; py::tuple err_ret(0); return std::move(err_ret); } // execute op - py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); + py::tuple tuple_result = py::make_tuple(result); *status = PYNATIVE_SUCCESS; MS_LOG(INFO) << "RunOpInVM end"; - return std::move(result); + return std::move(tuple_result); } bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, @@ -394,7 +393,9 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_i ValuePtr value = parse::data_converter::PyDataToValue(input_object); MS_EXCEPTION_IF_NULL(value); auto input_name = input_names_vec[input_index]; - op_prim->set_attr(input_name, value); + op_prim->BeginRecordAddAttr(); + op_prim->AddAttr(input_name, value); + op_prim->EndRecordAddAttr(); return true; } return false; @@ -442,8 +443,9 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, const Primitiv if (tuple_inputs.size() == 0) { MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; } - if (py::isinstance(tuple_inputs[0])) { - PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); + auto inputs = py::cast(input_object); + if (py::isinstance(inputs[0])) { + PlantTensorTupleToVector(inputs, op_prim, input_tensors); } else { ConvertValueTupleToTensor(input_object, input_tensors); *tensor_mask = kValueNodeTensorMask; @@ -495,10 +497,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *te PrimitivePtr op_prim = op_run_info->py_primitive; MS_EXCEPTION_IF_NULL(op_prim); - if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) { - MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size " - << op_run_info->inputs_mask.size(); - } opt::ConstInputToAttrInfoRegister reg; bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); size_t input_num = op_run_info->op_inputs.size(); @@ -509,7 +507,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *te continue; } // convert const and tuple input to tensor - int tensor_mask = py::cast(op_run_info->inputs_mask[index]); + int tensor_mask = static_cast(op_run_info->inputs_mask[index]); ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask); // mark tensors, data : 0, weight : 1, valuenode: 2 std::vector new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); @@ -536,7 +534,6 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat MS_EXCEPTION_IF_NULL(op_exec_info); MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); ms_context->set_enable_pynative_infer(true); std::string device_target = ms_context->device_target(); if (device_target != kAscendDevice && device_target != kGPUDevice) { @@ -545,9 +542,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat if (session == nullptr) { session = session::SessionFactory::Get().Create(device_target); + MS_EXCEPTION_IF_NULL(session); + session->Init(ms_context->device_id()); } - MS_EXCEPTION_IF_NULL(session); - session->Init(ms_context->device_id()); std::vector input_tensors; std::vector tensors_mask; @@ -559,6 +556,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); ms_context->set_enable_pynative_infer(false); *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; return result; } @@ -599,29 +597,78 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn return result; } -AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { - if (!grad_flag_ || graph_info_map_.empty()) { - return nullptr; - } +ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { + auto id = GetOpId(op_exec_info); + auto op = id; + op.append(std::to_string(op_id_map_[id])); + auto iter = op_forward_map_.find(op); + if (iter != op_forward_map_.end()) { + ++op_id_map_[id]; + MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; + return iter->second; + } + return nullptr; +} + +AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + abstract::AbstractBasePtrList *args_spec_list) { + CNodePtr cnode = nullptr; std::vector inputs; auto prim = op_exec_info->py_primitive; inputs.push_back(NewValueNode(prim)); - py::tuple op_masks = op_exec_info->inputs_mask; - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < args.size(); i++) { - auto node = GetInput(args[i], op_masks[i]); - args_spec_list.push_back(node->abstract()); + + size_t size = op_exec_info->op_inputs.size(); + for (size_t i = 0; i < size; i++) { + auto obj = op_exec_info->op_inputs[i]; + bool op_mask = py::hasattr(obj, "__parameter__"); + (*op_masks).push_back(op_mask); + MS_LOG(DEBUG) << "gen args i " << i << op_exec_info->op_name << " op mask" << op_mask << "grad_flag_" << grad_flag_; + + AnfNodePtr node = nullptr; + abstract::AbstractBasePtr abs = nullptr; + auto id = GetId(obj); + if (node_abs_map_.find(id) != node_abs_map_.end()) { + abs = node_abs_map_[id]; + } + if (!graph_info_map_.empty()) { + node = GetInput(obj, op_mask); + } + if (node != nullptr && node->abstract() != nullptr) { + abs = node->abstract(); + } + if (abs == nullptr || prim->is_const_value()) { + MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; + ValuePtr input_value = PyAttrValue(obj); + bool broaden = !prim->is_const_value() && input_value->isa(); + abs = abstract::FromValueInside(input_value, broaden); + node_abs_map_[id] = abs; + } + (*args_spec_list).push_back(abs); inputs.push_back(node); } - auto cnode = curr_g_->NewCNode(inputs); - MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); - py::object out_real = out; - if (out.size() == 1) { - MS_LOG(DEBUG) << "MakeCnode out size is one."; - out_real = out[0]; + MS_LOG(DEBUG) << "MakeCnode args end"; + if (grad_flag_) { + if (curr_g_ != nullptr) { + cnode = curr_g_->NewCNode(inputs); + MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); + } + } + + return cnode; +} + +void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out_real, + const AnfNodePtr &cnode) { + if (!grad_flag_ || graph_info_map_.empty()) { + MS_LOG(DEBUG) << "no graph cnode"; + return; } + std::string obj_id = GetId(out_real); + MS_EXCEPTION_IF_NULL(cnode); + MS_LOG(DEBUG) << "MakeCnode set obj node id " << cnode->DebugString(4) << "id " << obj_id; + if (py::isinstance(out_real)) { auto value = py::cast(out_real); if (value.size() > 1) { @@ -632,28 +679,65 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const } } } - MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id; set_obj_node_map(curr_g_, obj_id, cnode); set_pyobj(curr_g_, obj_id); - return cnode; +} + +void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) { + auto id = GetOpId(op_exec_info); + auto op = id; + op.append(std::to_string(op_id_map_[id])); + auto iter = op_forward_map_.find(op); + if (iter != op_forward_map_.end()) { + return; + } + op_forward_map_[op] = value; + ++op_id_map_[id]; + MS_LOG(DEBUG) << "Save: " << op_exec_info->op_name << "(" << op << "), " << value; +} + +void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { + if (!grad_flag_ || op_exec_info->value != nullptr) { + return; + } + py::object out_real = out; + if (out.size() == 1) { + out_real = out[0]; + } + auto value = PyAttrValue(out_real); + if (cnode != nullptr) { + cnode->set_forward(value); + } + SaveOpForwardValue(op_exec_info, value); } AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { - auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)]; + auto id = GetId(obj); + auto &out = graph_info_map_[curr_g_].obj_node_map[id]; if (out.second.size() == 1 && out.second[0] == -1) { return out.first; } auto node = out.first; MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); + auto abs = node->abstract(); for (auto &idx : out.second) { std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)}; node = curr_g_->NewCNode(tuple_get_item_inputs); + if (abs != nullptr && abs->isa()) { + auto prim_abs = dyn_cast(abs)->elements()[idx]; + MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString(); + node->set_abstract(prim_abs); + } + } + if (node->abstract() != nullptr) { + node_abs_map_[id] = node->abstract(); } MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); + node->cast()->set_forward(PyAttrValue(obj)); return node; } -py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { +py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; mindspore::parse::python_adapter::set_python_env_flag(true); MsBackendPolicy backend_policy; @@ -668,7 +752,7 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { #else auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->PynativeInitGe(); + context::PynativeInitGe(ms_context); backend_policy = kMsBackendGeOnly; #endif if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { @@ -683,41 +767,89 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { return err_ret; } - auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); - if (node != nullptr) { - node->set_abstract(op_exec_info->abstract); - MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString(); - } MS_LOG(DEBUG) << "RunOp end"; return result; } -py::tuple RunOpInner(const py::args &args) { +py::tuple PynativeExecutor::RunOpInner(const py::args &args) { MS_LOG(DEBUG) << "RunOp start" << args.size(); - py::list args_input = args[PY_INPUTS]; + OpExecInfoPtr op_exec_info = nullptr; + auto prim = py::cast(args[PY_PRIM]); + auto name = py::cast(args[PY_NAME]); + abstract::AbstractBasePtrList args_spec_list; + std::vector op_masks; + op_exec_info = GenerateOpExecInfo(args); + if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { + return RunOpInner(op_exec_info); + } + auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); + bool is_find = false; + if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) { + auto abs_list = prim_abs_list[prim->id()]; + MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); + if (abs_list.find(args_spec_list) != abs_list.end()) { + MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name; + op_exec_info->abstract = abs_list[args_spec_list].abs; + prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); + is_find = true; + } + } - OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input); - MS_EXCEPTION_IF_NULL(op_exec_info); + if (op_exec_info->abstract == nullptr) { + // use python infer method + if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { + PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); + } + } + if (cnode != nullptr) { + cnode->set_abstract(op_exec_info->abstract); + MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString(); + } + + op_exec_info->inputs_mask = op_masks; + MS_EXCEPTION_IF_NULL(op_exec_info); if (op_exec_info->abstract != nullptr) { + MS_LOG(DEBUG) << "run op infer" << name << op_exec_info->abstract->ToString(); py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); if (!output["value"].is_none()) { py::tuple value_ret(1); value_ret[0] = output["value"]; return value_ret; } - if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { + if (op_exec_info->py_primitive->is_const_value()) { py::tuple value_ret(1); value_ret[0] = ""; return value_ret; } } - return RunOpInner(op_exec_info, args_input); + + if (!is_find) { + // const_value need infer every step + auto &out = prim_abs_list[prim->id()]; + out[args_spec_list].abs = op_exec_info->abstract; + out[args_spec_list].attrs = prim->evaluate_added_attrs(); + MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); + } + + auto result = RunOpInner(op_exec_info); + py::object out_real = result; + if (result.size() == 1) { + MS_LOG(DEBUG) << "MakeCnode out size is one."; + out_real = result[0]; + } + std::string obj_id = GetId(out_real); + node_abs_map_[obj_id] = op_exec_info->abstract; + PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, out_real, cnode); + if (cnode != nullptr) { + PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast(), result); + } + return result; } py::tuple RunOp(const py::args &args) { try { - return RunOpInner(args); + return PynativeExecutor::GetInstance()->RunOpInner(args); } catch (const py::error_already_set &ex) { // print function call stack info before release std::ostringstream oss; @@ -759,6 +891,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { auto cell_id = GetId(cell); if (cell_graph_map_.count(cell_id) != 0) { + if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) { + resource_ = cell_resource_map_[cell_id]; + } MS_LOG(DEBUG) << "Newgraph already compiled"; return; } @@ -767,6 +902,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg if (top_g_ == nullptr) { top_g_ = curr_g_ = g; + resource_ = std::make_shared(); + cell_resource_map_[cell_id] = resource_; df_builder_ = std::make_shared(); MS_LOG(DEBUG) << "First new graph" << top_g_.get(); Pushp(); @@ -792,11 +929,11 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str return node; } -AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) { +AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { AnfNodePtr node = nullptr; std::string obj_id = GetId(obj); - if (op_mask != nullptr && py::cast(op_mask)) { + if (op_mask) { MS_LOG(DEBUG) << "Topgraph free parameter"; // get the parameter name from parameter object auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); @@ -807,8 +944,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { auto free_param = df_builder_->add_parameter(); free_param->set_name(param_name); - auto free_param_new = py::cast(obj.attr("_value")); - free_param->set_default_param(free_param_new); + free_param->set_default_param(py::cast(obj)); free_param->debug_info()->set_name(param_name); MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; graph_info_map_[df_builder_].param_map[obj_id] = free_param; @@ -841,8 +977,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o auto tuple_size = static_cast(tuple.size()); for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], py::object())); + args.push_back(GetInput(tuple[i], false)); } + auto cnode = curr_g_->NewCNode(args); set_obj_node_map(curr_g_, GetId(obj), cnode); node = cnode; @@ -896,15 +1033,15 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o auto tuple_size = static_cast(tuple.size()); auto cnode = curr_g_->NewCNode(args); for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], py::object())); + args.push_back(GetInput(tuple[i], false)); set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); SetTupleOutput(tuple[i], cnode, std::vector{i}); } cnode->set_inputs(args); set_obj_node_map(curr_g_, out_id, cnode); } else { - MS_LOG(ERROR) << "Graph has no this out: " << out_id; - return; + MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id; + MakeValueNode(out, out_id); } } EndGraphByOutId(out_id, cell, out, args); @@ -936,7 +1073,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje if (curr_g_ != top_g_) { Popp(); for (size_t i = 0; i < args.size(); i++) { - auto input = GetInput(args[i], py::object()); + auto input = GetInput(args[i], false); inputs.push_back(input); } auto out_cnode = curr_g_->NewCNode(inputs); @@ -968,15 +1105,24 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh AnfNodePtr para_node = nullptr; if (graph_info_map_[df_builder_].param_map.count(param_id)) { para_node = graph_info_map_[df_builder_].param_map[param_id]; - - AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); - AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); - auto refkey = std::make_shared(para_node->cast()->name()); - AnfNodePtr ref_key_node = NewValueNode(refkey); - AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); - - w_args.push_back(ref_node); + } else { + auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; + } + auto param_name = py::cast(name_attr); + auto free_param = df_builder_->add_parameter(); + free_param->set_name(param_name); + free_param->set_default_param(py::cast(param)); + free_param->debug_info()->set_name(param_name); + para_node = free_param; } + AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); + AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); + auto refkey = std::make_shared(para_node->cast()->name()); + AnfNodePtr ref_key_node = NewValueNode(refkey); + AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); + w_args.push_back(ref_node); } } else { MS_LOG(DEBUG) << "training not paramter_tuple"; @@ -1003,8 +1149,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args for (const auto ¶m : df_builder_->parameters()) { auto param_node = std::static_pointer_cast(param); if (param_node->has_default()) { - const auto ¶m_value = param_node->default_param(); - ValuePtr value = param_value->value(); + ValuePtr value = param_node->default_param(); AbstractBasePtr ptr = abstract::FromValue(value, true); if (ptr == nullptr) { MS_LOG(EXCEPTION) << "Args convert error"; @@ -1065,9 +1210,10 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje void PynativeExecutor::Clear(const std::string &flag) { if (!flag.empty()) { - MS_LOG(INFO) << "Clear res"; + MS_LOG(DEBUG) << "Clear res"; (void)graph_map_.erase(flag); (void)cell_graph_map_.erase(flag); + (void)cell_resource_map_.erase(flag); Clean(); // Maybe exit in the pynative runing op, so need reset pynative flag. auto ms_context = MsContext::GetInstance(); @@ -1077,18 +1223,22 @@ void PynativeExecutor::Clear(const std::string &flag) { return; } - MS_LOG(INFO) << "Clear"; + MS_LOG(DEBUG) << "Clear"; + grad_flag_ = false; top_g_ = nullptr; + df_builder_ = nullptr; curr_g_ = nullptr; graph_info_map_.clear(); + op_id_map_.clear(); + // node_abs_map_.clear(); std::stack().swap(graph_p_); } void PynativeExecutor::Clean() { - MS_LOG(INFO) << "Clean all res"; + MS_LOG(DEBUG) << "Clean all res"; Clear(); grad_flag_ = false; - df_builder_ = nullptr; + op_forward_map_.clear(); ad::CleanRes(); pipeline::ReclaimOptimizer(); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 152d58aca4..246ceada15 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ #include #include @@ -29,7 +29,7 @@ #include "pybind11/numpy.h" #include "pipeline/pynative/base.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "ir/anf.h" #include "pipeline/jit/resource.h" #include "frontend/operator/composite/composite.h" @@ -41,12 +41,20 @@ namespace py = pybind11; using ResourcePtr = std::shared_ptr; using GradOperationPtr = std::shared_ptr; +struct PrimAbsInfo { + abstract::AbstractBasePtr abs; + std::unordered_map attrs; +}; + +using AbstractListMap = std::unordered_map; + py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); py::tuple RunOp(const py::args &args); -py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, - py::list *const out_args_list); +void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, + py::list *const out_args_list); void ClearPyNativeSession(); @@ -82,7 +90,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void ClearRes(); bool grad_flag() { return grad_flag_; } void set_grad_flag(bool flag) { grad_flag_ = flag; } - AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask); + AnfNodePtr GetInput(const py::object &obj, bool op_mask); AnfNodePtr GetObjNode(const py::object &obj); FuncGraphPtr curr_g() { return curr_g_; } void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } @@ -95,7 +103,14 @@ class PynativeExecutor : public std::enable_shared_from_this { void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); } - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + abstract::AbstractBasePtrList *args_spec_list); + void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode); + ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); + void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value); + void SaveForwardResult(const CNodePtr &cnode, const py::object &out); + void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out); + py::object Run(const py::tuple &args, const py::object &phase); void Pushp(); @@ -104,6 +119,8 @@ class PynativeExecutor : public std::enable_shared_from_this { size_t arg_size); void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); + py::tuple RunOpInner(const py::args &args); + py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); ~PynativeExecutor(); @@ -115,11 +132,16 @@ class PynativeExecutor : public std::enable_shared_from_this { bool grad_flag_; std::unordered_map graph_map_; std::unordered_map cell_graph_map_; + std::unordered_map cell_resource_map_; std::unordered_map graph_info_map_; + std::unordered_map op_forward_map_; + std::unordered_map op_id_map_; + std::unordered_map node_abs_map_; std::stack graph_p_; FuncGraphPtr top_g_; FuncGraphPtr df_builder_; FuncGraphPtr curr_g_; + std::unordered_map prim_abs_list; }; using PynativeExecutorPtr = std::shared_ptr; @@ -127,4 +149,4 @@ using PynativeExecutorPtr = std::shared_ptr; } // namespace pynative } // namespace mindspore -#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc index 897c21fc90..994306ec2d 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc @@ -23,12 +23,12 @@ #include "utils/any.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "frontend/operator/ops.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/static_analysis/prim.h" #include "backend/session/session_factory.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" const char SINGLE_OP_GRAPH[] = "single_op_graph"; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h index 2978278489..b8459db687 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ +#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ #include #include @@ -27,7 +27,7 @@ #include "transform/graph_ir/convert.h" #include "transform/graph_ir/graph_runner.h" #include "transform/graph_ir/types.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" using GeTensor = ge::Tensor; using GeTensorPtr = std::shared_ptr; @@ -43,4 +43,4 @@ py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat } // namespace pynative } // namespace mindspore -#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ +#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ diff --git a/mindspore/ccsrc/predict/CMakeLists.txt b/mindspore/ccsrc/predict/CMakeLists.txt deleted file mode 100644 index a8cca431e7..0000000000 --- a/mindspore/ccsrc/predict/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -file(GLOB_RECURSE _PREDICT_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "predict.cc" - "generator/utils/ir_model_util.cc" - "converter/*.cc" - "converter/attr_utils/*.cc" - "converter/lite_model/*.cc" - "converter/lite_model/operations/*.cc" -) - -if (ENABLE_D) - file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "generator/ir/*.cc") - list(APPEND _PREDICT_SRC_LIST ${_D_SRC_LIST}) -endif () -add_library(_mindspore_predict_obj OBJECT ${_PREDICT_SRC_LIST}) \ No newline at end of file diff --git a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.cc b/mindspore/ccsrc/predict/converter/attr_utils/convert_util.cc deleted file mode 100644 index ff2e7bab0e..0000000000 --- a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.cc +++ /dev/null @@ -1,229 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/attr_utils/convert_util.h" - -namespace mindspore { -namespace predict { -namespace utils { -TypePtr GetTypePtr(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - TypePtr type_ptr = anf_node->Type(); - MS_EXCEPTION_IF_NULL(type_ptr); - if (type_ptr->isa()) { - auto tensor_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - return elem; - } else if (type_ptr->isa()) { - auto tuple_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tuple_ptr); - auto tuple_i = (*tuple_ptr)[0]; - MS_EXCEPTION_IF_NULL(tuple_i); - if (tuple_i->isa()) { - auto tensor_ptr = tuple_i->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - MS_EXCEPTION_IF_NULL(elem); - return elem; - } else if (tuple_i->isa()) { - return type_ptr; - } else { - MS_LOG(EXCEPTION) << "unsupported type: " << type_ptr->ToString(); - } - } else if (type_ptr->isa()) { - return type_ptr; - } - std::string type_name = type_ptr->ToString(); - MS_LOG(EXCEPTION) - << "The output type of node should be a tensor type a number or a tuple of tensor type, but this is: " - << type_name; -} - -MsDataType GetMSDataType(TypeId ori_data_type) { - MsDataType dst_data_type; - switch (ori_data_type) { - case kNumberTypeFloat16: - dst_data_type = mindspore::predict::DataType_DT_FLOAT16; - return dst_data_type; - case kNumberTypeFloat32: - dst_data_type = mindspore::predict::DataType_DT_FLOAT; - return dst_data_type; - case kNumberTypeInt8: - dst_data_type = mindspore::predict::DataType_DT_INT8; - return dst_data_type; - case kNumberTypeInt32: - dst_data_type = mindspore::predict::DataType_DT_INT32; - return dst_data_type; - case kNumberTypeUInt8: - dst_data_type = mindspore::predict::DataType_DT_UINT8; - return dst_data_type; - case kNumberTypeUInt32: - dst_data_type = mindspore::predict::DataType_DT_UINT32; - return dst_data_type; - case kTypeUnknown: - dst_data_type = mindspore::predict::DataType_DT_UNDEFINED; - return dst_data_type; - default: - MS_LOG(EXCEPTION) << "Ms don't support this DataType"; - } -} - -MsFormat GetMsFormat(const std::string &format_str) { - if (format_str == kOpFormat_DEFAULT) { - MsFormat ms_format = predict::Format_NCHW; - return ms_format; - } else { - // all middle format default to NCHW - return predict::Format_NCHW; - } -} - -TensorPtr GetParaAscendTensor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!anf_node->isa()) { - return nullptr; - } - auto device_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - // device type_ptr - auto device_type_ptr = GetTypePtr(anf_node); - // device shape - auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, 0); - std::vector tensor_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(tensor_shape), SizeToInt); - // device format - auto format = AnfAlgo::GetOutputFormat(anf_node, 0); - // device tensor - TensorPtr device_tensor = std::make_shared(device_type_id, tensor_shape); - // device info - device_tensor->SetDeviceInfo(format, device_type_ptr); - return device_tensor; -} - -TensorPtr GetParaCpuTensor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!(anf_node->isa())) { - return nullptr; - } else { - auto ori_type_id = AnfAlgo::GetOutputInferDataType(anf_node, 0); - auto ori_type_ptr = GetTypePtr(anf_node); - auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); - std::vector tensor_shape; - (void)std::transform(ori_shape.begin(), ori_shape.end(), std::back_inserter(tensor_shape), SizeToInt); - auto ori_format = AnfAlgo::GetOutputFormat(anf_node, 0); - TensorPtr cpu_tensor = std::make_shared(ori_type_id, tensor_shape); - cpu_tensor->SetDeviceInfo(ori_format, ori_type_ptr); - return cpu_tensor; - } -} - -TensorPtr GetValueTensor(const ValueNodePtr &const_node) { - MS_EXCEPTION_IF_NULL(const_node); - auto value_ptr = const_node->value(); - MS_EXCEPTION_IF_NULL(value_ptr); - if (!value_ptr->isa()) { - return nullptr; - } - TensorPtr tensor = value_ptr->cast(); - MS_EXCEPTION_IF_NULL(tensor); - auto data_type = tensor->Dtype(); - MS_EXCEPTION_IF_NULL(data_type); - auto type_id = data_type->type_id(); - auto shape = tensor->shape(); - TensorPtr tensor_constant = std::make_shared(type_id, shape); - tensor_constant->SetDeviceInfo(tensor->device_info().format_, tensor->device_info().data_type_); - return tensor_constant; -} - -TensorPtr GetKernelCpuTensor(const CNodePtr &c_node_ptr, size_t inx) { - if (c_node_ptr == nullptr || inx >= AnfAlgo::GetOutputTensorNum(c_node_ptr)) { - MS_LOG(ERROR) << "GetKernelCpuTensor failed"; - return nullptr; - } - auto ori_shape = AnfAlgo::GetOutputInferShape(c_node_ptr, inx); - auto ori_type_id = AnfAlgo::GetOutputInferDataType(c_node_ptr, inx); - std::vector tensor_shape; - (void)std::transform(ori_shape.begin(), ori_shape.end(), std::back_inserter(tensor_shape), SizeToInt); - auto ori_output_type = GetTypePtr(c_node_ptr); - TensorPtr device_tensor = std::make_shared(ori_type_id, tensor_shape); - auto format = AnfAlgo::GetOutputFormat(c_node_ptr, inx); - device_tensor->SetDeviceInfo(format, ori_output_type); - return device_tensor; -} - -TensorPtr GetKernelAscendTensor(const CNodePtr &c_node_ptr, size_t inx) { - if (c_node_ptr == nullptr || inx >= AnfAlgo::GetOutputTensorNum(c_node_ptr)) { - MS_LOG(ERROR) << "GetKernelAscendTensor failed"; - return nullptr; - } - auto shape = AnfAlgo::GetOutputDeviceShape(c_node_ptr, inx); - std::vector tensor_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(tensor_shape), SizeToInt); - auto format = AnfAlgo::GetOutputFormat(c_node_ptr, inx); - auto type_id = AnfAlgo::GetOutputDeviceDataType(c_node_ptr, inx); - auto output_type_ptr = GetTypePtr(c_node_ptr); - TensorPtr device_tensor = std::make_shared(type_id, tensor_shape); - device_tensor->SetDeviceInfo(format, output_type_ptr); - return device_tensor; -} - -TensorPtr GetOutputTensor(const AnfNodePtr &out_node, size_t inx) { - MS_EXCEPTION_IF_NULL(out_node); - auto shape = AnfAlgo::GetOutputInferShape(out_node, inx); - std::vector tensor_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(tensor_shape), SizeToInt); - auto type_id = AnfAlgo::GetOutputInferDataType(out_node, inx); - auto output_type_ptr = GetTypePtr(out_node); - auto format = AnfAlgo::GetOutputFormat(out_node, inx); - TensorPtr output_tensor = std::make_shared(type_id, tensor_shape); - output_tensor->SetDeviceInfo(format, output_type_ptr); - return output_tensor; -} - -bool FindNodeInMap(const std::unordered_map &node_map, const AnfNodePtr &node) { - return std::any_of(node_map.begin(), node_map.end(), - [node](const std::pair &kernel_key) { return kernel_key.first == node.get(); }); -} - -bool SaveDeviceModelUtil(const std::shared_ptr &new_ms_graph_ptr, const std::string &save_path_name, - SubGraphDefT *sub_graph) { - MS_EXCEPTION_IF_NULL(new_ms_graph_ptr); - MS_EXCEPTION_IF_NULL(sub_graph); - // save mindspore schema to file - new_ms_graph_ptr->name = "default_graph"; - std::unique_ptr sub_graph_ptr(sub_graph); - new_ms_graph_ptr->subgraphs.emplace_back(std::move(sub_graph_ptr)); - // get flatbuffer builder - flatbuffers::FlatBufferBuilder builder(1024); - auto offset = mindspore::predict::GraphDef::Pack(builder, new_ms_graph_ptr.get()); - builder.Finish(offset); - auto size = builder.GetSize(); - if (size == 0) { - MS_LOG(ERROR) << "builder has no size"; - return false; - } - auto content = builder.GetBufferPointer(); - std::ofstream output(save_path_name); - if (!output.is_open()) { - MS_LOG(EXCEPTION) << "mindspore.mindspoire output failed"; - } - (void)output.write((const char *)content, size); - output.close(); - return true; -} -} // namespace utils -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h b/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h deleted file mode 100644 index 612ccde1a5..0000000000 --- a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_ATTR_UTILS_CONVERT_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_ATTR_UTILS_CONVERT_UTIL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "ir/tensor.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "predict/schema/inner/ms_generated.h" - -using TensorPtr = mindspore::tensor::TensorPtr; -using TensorPtrList = std::vector; -using AllOutputTensors = std::unordered_map; -using OpDefT = mindspore::predict::OpDefT; -using GraphDefT = mindspore::predict::GraphDefT; -using TensorDefT = mindspore::predict::TensorDefT; -using SubGraphDefT = mindspore::predict::SubGraphDefT; -using SubGraphPtr = std::unique_ptr; -using MsDataType = mindspore::predict::DataType; -using MsFormat = mindspore::predict::Format; -using MsKernelKey = void *; -namespace mindspore { -namespace predict { -namespace utils { -TypePtr GetTypePtr(const AnfNodePtr &anf_node); -MsDataType GetMSDataType(TypeId ori_data_type); -MsFormat GetMsFormat(const std::string &format_str); -TensorPtr GetParaAscendTensor(const AnfNodePtr &anf_node); -TensorPtr GetParaCpuTensor(const AnfNodePtr &anf_node); -TensorPtr GetValueTensor(const ValueNodePtr &const_node); -TensorPtr GetKernelCpuTensor(const CNodePtr &c_node_ptr, size_t inx); -TensorPtr GetKernelAscendTensor(const CNodePtr &c_node_ptr, size_t inx); -TensorPtr GetOutputTensor(const AnfNodePtr &out_node, size_t inx); -bool FindNodeInMap(const std::unordered_map &Nodemap, const AnfNodePtr &node); -bool SaveDeviceModelUtil(const std::shared_ptr &new_ms_graph_ptr, const std::string &save_path_name, - SubGraphDefT *sub_graph_def_t); -} // namespace utils -} // namespace predict -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_ATTR_UTILS_CONVERT_UTIL_H_ diff --git a/mindspore/ccsrc/predict/converter/attr_utils/op_attr_type.h b/mindspore/ccsrc/predict/converter/attr_utils/op_attr_type.h deleted file mode 100644 index 49504cd3c8..0000000000 --- a/mindspore/ccsrc/predict/converter/attr_utils/op_attr_type.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_CPU_ATTR_UTILS_OP_ATTR_TYPE_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_CPU_ATTR_UTILS_OP_ATTR_TYPE_H_ -namespace mindspore { -namespace predict { -namespace convert { -typedef enum CpuOpType { - CPU_OP_PAD = 0, - CPU_OP_MAXIMUM, - CPU_OP_CONCAT, - CPU_OP_SOFTMAX, - CPU_OP_ACTIVATION, - CPU_OP_CONV2D, - CPU_OP_FUSEDBATCHNORM, - CPU_OP_CAFFEBATCHNORM, - CPU_OP_SQUEEZE, - CPU_OP_BIASADD, - CPU_OP_POOLING, - CPU_OP_DEPTHWISECONV2D, - CPU_OP_DEDEPTHWISECONV2D, - CPU_OP_RESIZE, - CPU_OP_DETECTIONPOSTPROCESS, - CPU_OP_FULLCONNECTION, - CPU_OP_MEAN, - CPU_OP_DECONV2D, - CPU_OP_SCALE, - CPU_OP_ELTWISE, - CPU_OP_ADD, - CPU_OP_SLICE, - CPU_OP_MUL, - CPU_OP_EXP, - CPU_OP_RESHAPE, - CPU_OP_POWER, - CPU_OP_ARGMAX, - CPU_OP_ARGMAX_NETOUTPUT, - CPU_OP_MATMUL, - CPU_OP_CAFFEPRELU, - CPU_OP_STRIDEDSLICE, - CPU_OP_STACK, - CPU_OP_RANGE, - CPU_OP_EXPANDDIMS, - CPU_OP_TILE, - CPU_OP_CAST, - CPU_OP_CAFFECROP, - CPU_OP_PRESERVEED = 37 -} CpuOpType_t; -} // namespace convert -} // namespace predict -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_CPU_ATTR_UTILS_OP_ATTR_TYPE_H_ diff --git a/mindspore/ccsrc/predict/converter/executor_tensor.cc b/mindspore/ccsrc/predict/converter/executor_tensor.cc deleted file mode 100644 index b51496a9b4..0000000000 --- a/mindspore/ccsrc/predict/converter/executor_tensor.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/executor_tensor.h" - -namespace mindspore { -namespace executor { -int TensorCache::addExTensor(int tensor_key, const TensorPtr &tensor, int refCount, const std::vector &host_shape, - ExTensorType stable, bool inc) { - MS_EXCEPTION_IF_NULL(tensor); - TensorPtr tmp_tensor = tensor; - ExTensorPtr ex_tensor_ptr = - std::make_shared(tensor_key, tmp_tensor, refCount, nodeIndex, host_shape, stable); - int pre_index = ex_tensor_ptr->index_; - if (inc) { - nodeIndex++; - } - // no need to judge,just add to map directly - tensors[tensor_key].push_back(ex_tensor_ptr); - return pre_index; -} - -std::vector TensorCache::findTensor(int key) { - std::vector ex_tensors; - auto iter = tensors.find(key); - if (iter != tensors.end()) { - return iter->second; - } else { - MS_LOG(INFO) << "can not find any tensorlist"; - return ex_tensors; - } -} - -void TensorCache::deleteTensor(int key) { (void)tensors.erase(key); } -} // namespace executor -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/executor_tensor.h b/mindspore/ccsrc/predict/converter/executor_tensor.h deleted file mode 100644 index 7b95c409f8..0000000000 --- a/mindspore/ccsrc/predict/converter/executor_tensor.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_EXECUTOR_TENSOR_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_EXECUTOR_TENSOR_H_ - -#include -#include -#include -#include -#include "ir/tensor.h" - -namespace mindspore { -namespace executor { -using TensorPtr = tensor::TensorPtr; -static constexpr int MS_MAX_REFCOUNT = 999; -enum ExTensorType { INPUTDATA, WEIGHTS, CONSTANT, KERNEL, OUTPUT }; -class ExTensor { - public: - int key_; - TensorPtr device_tensor_ptr_; - int ref_count_; - int index_; - std::vector host_shape_; - ExTensorType stable_; - ExTensor(int key, TensorPtr tensor_ptr, int ref_count, int index, std::vector host_shape, - ExTensorType ex_tensor_type) - : key_(key), - device_tensor_ptr_(std::move(tensor_ptr)), - ref_count_(ref_count), - index_(index), - host_shape_(std::move(host_shape)), - stable_(ex_tensor_type) {} - ~ExTensor() { host_shape_.clear(); } -}; -using ExTensorPtr = std::shared_ptr; -class TensorCache { - public: - TensorCache() = default; - - ~TensorCache() { tensors.clear(); } - - int addExTensor(int tensor_key, const TensorPtr &tensor, int refCount, const std::vector &host_shape, - ExTensorType stable, bool inc = true); - // just adjust for dynamic tensor - std::vector findTensor(int key); - void deleteTensor(int key); - const std::unordered_map> &GetCachedTensor() const { return tensors; } - - private: - std::unordered_map> tensors; - int nodeIndex = 0; -}; -using TensorCachePtr = std::shared_ptr; -} // namespace executor -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_EXECUTOR_TENSOR_H_ diff --git a/mindspore/ccsrc/predict/converter/kernel2ms.cc b/mindspore/ccsrc/predict/converter/kernel2ms.cc deleted file mode 100644 index 04aceb62eb..0000000000 --- a/mindspore/ccsrc/predict/converter/kernel2ms.cc +++ /dev/null @@ -1,561 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/kernel2ms.h" -#include -#include "ir/anf.h" -#include "predict/converter/lite_model/op_attr_packer.h" -#include "mindspore/ccsrc/frontend/operator/ops.h" - -namespace mindspore { -namespace executor { -Kernel2Ms &Kernel2Ms::GetInstance() { - static Kernel2Ms instance; - return instance; -} - -bool Kernel2Ms::SetMemResue() const { - MS_LOG(INFO) << "MemResue start"; - return true; -} - -bool Kernel2Ms::SetAllTensors(const TensorCachePtr &tensor_cache, SubGraphDefT *ms_graph) { - if (tensor_cache == nullptr || ms_graph == nullptr) { - return false; - } - const std::unordered_map> &cachedTensors = tensor_cache->GetCachedTensor(); - size_t total_size = 0; - if (cachedTensors.empty()) { - return false; - } - for (auto &iter : cachedTensors) { - auto ex_tensors = iter.second; - total_size += ex_tensors.size(); - } - ms_graph->allTensors.resize(total_size); - for (auto &iter : cachedTensors) { - for (auto &ex_tensor : iter.second) { - std::unique_ptr ms_tensor(new TensorDefT()); - auto device_tensor_tmp = ex_tensor->device_tensor_ptr_; - auto device_d_type = device_tensor_tmp->data_type(); - ms_tensor->dataType = predict::utils::GetMSDataType(device_d_type); - auto device_shape = device_tensor_tmp->shape(); - ms_tensor->dims.clear(); - if (device_shape.empty()) { - ms_tensor->dims.push_back(1); - } else { - ms_tensor->dims.assign(device_shape.begin(), device_shape.end()); - } - std::string format_str = device_tensor_tmp->device_info().format_; - ms_tensor->format = predict::utils::GetMsFormat(format_str); - ms_tensor->offset = 0; - auto stable = ex_tensor->stable_; - if (stable == INPUTDATA || stable == CONSTANT || stable == WEIGHTS) { - ms_tensor->refCount = MS_MAX_REFCOUNT; - } else { - ms_tensor->refCount = 0; - } - ms_graph->allTensors[IntToSize(ex_tensor->index_)] = std::move(ms_tensor); - } - } - return true; -} - -bool Kernel2Ms::SetGraphOutputIdx(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache, - SubGraphDefT *ms_graph, AllOutputTensors *all_output_tensors) { - MS_EXCEPTION_IF_NULL(tensor_cache); - MS_EXCEPTION_IF_NULL(ms_graph); - MS_EXCEPTION_IF_NULL(all_output_tensors); - auto out_nodes = kernel_graph_ptr->outputs(); - if (out_nodes.empty()) { - return false; - } - // maybe need to judge out_nodes is real && output must be CNode - for (size_t i = 0; i < out_nodes.size(); ++i) { - std::vector real_inputs_link; - std::vector real_output_idx_link; - GetRealInpoutsPtr(out_nodes[i], &real_inputs_link, &real_output_idx_link); - if (real_inputs_link.empty()) { - MS_LOG(INFO) << "this graph output node is vitural node, has no real input"; - continue; - } - for (size_t k = 0; k < real_inputs_link.size(); ++k) { - int key = node_indexs_[out_nodes[i].get()]; - auto ex_tensor_list = tensor_cache->findTensor(key); - if (ex_tensor_list.empty()) { - MS_LOG(INFO) << "SetGraphOutputIdx do not add Extensor "; - continue; - } - auto ex_tensor = ex_tensor_list[real_output_idx_link[k]]; - ex_tensor_list.clear(); - ms_graph->outputIndex.push_back(ex_tensor->index_); - } - } - return true; -} - -bool Kernel2Ms::SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &output_tensor, - const TensorCachePtr &tensor_cache, int ref_count, size_t order_index, OpDefT *ms_node) { - MS_EXCEPTION_IF_NULL(c_node_ptr); - MS_EXCEPTION_IF_NULL(output_tensor); - MS_EXCEPTION_IF_NULL(ms_node); - MS_EXCEPTION_IF_NULL(tensor_cache); - if (!predict::utils::FindNodeInMap(node_indexs_, c_node_ptr)) { - MS_LOG(ERROR) << "can not find any pk_key in inited node_indexs map"; - return false; - } - int tensor_key = node_indexs_[c_node_ptr.get()]; - auto host_shape = AnfAlgo::GetOutputInferShape(c_node_ptr, order_index); - std::vector tensor_shape; - (void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(tensor_shape), SizeToInt); - int outputIndex = tensor_cache->addExTensor(tensor_key, output_tensor, ref_count, tensor_shape, KERNEL); - ms_node->outputIndex.push_back(outputIndex); - return true; -} - -void Kernel2Ms::GetRealInpoutsPtr(const AnfNodePtr &node, std::vector *real_inputs, - std::vector *real_output_idx) { - MS_EXCEPTION_IF_NULL(real_inputs); - MS_EXCEPTION_IF_NULL(real_output_idx); - size_t default_idx = 0; - if (node->isa()) { - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - std::string c_node_name = GetCNodeFuncName(c_node); - if (c_node_name == prim::kPrimTupleGetItem->name()) { - auto v_node = c_node->inputs()[kTupleGetItemIndex]->cast(); - MS_EXCEPTION_IF_NULL(v_node); - default_idx = IntToSize(GetValue(v_node->value())); - real_inputs->push_back(c_node->inputs()[1]); - real_output_idx->push_back(default_idx); - return; - } else if (c_node_name == prim::kPrimDepend->name()) { - GetRealInpoutsPtr(c_node->inputs()[1], real_inputs, real_output_idx); - return; - } else if (c_node_name == prim::kPrimMakeTuple->name()) { - for (auto &in : c_node->inputs()) { - GetRealInpoutsPtr(in, real_inputs, real_output_idx); - } - return; - } else { - real_inputs->push_back(node); - real_output_idx->push_back(default_idx); - } - } else if (node->isa()) { - real_inputs->push_back(node); - real_output_idx->push_back(default_idx); - } else if (node->isa()) { - real_inputs->push_back(node); - real_output_idx->push_back(default_idx); - } -} - -bool Kernel2Ms::SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, OpDefT *ms_node) { - MS_EXCEPTION_IF_NULL(c_node_ptr); - MS_EXCEPTION_IF_NULL(tensor_cache); - MS_EXCEPTION_IF_NULL(ms_node); - for (size_t i = 1; i < c_node_ptr->inputs().size(); ++i) { - std::vector real_inputs; - std::vector real_output_idx; - GetRealInpoutsPtr(c_node_ptr->inputs()[i], &real_inputs, &real_output_idx); - if (real_inputs.empty()) { - MS_LOG(INFO) << "kernel has no inputs: " << c_node_ptr.get() << " input size[%lu]" << c_node_ptr->inputs().size(); - continue; - } - for (size_t j = 0; j < real_inputs.size(); ++j) { - int key = node_indexs_[real_inputs[j].get()]; - std::vector ex_tensor_list = tensor_cache->findTensor(key); - if (ex_tensor_list.empty()) { - continue; - } - ExTensorPtr ex_tensor_ptr = ex_tensor_list[real_output_idx[j]]; - ex_tensor_list.clear(); - ms_node->inputIndex.push_back(ex_tensor_ptr->index_); - } - } - return true; -} - -void Kernel2Ms::TransformGraphIndx() { - // transform index && anfnodeptr - if (node_indexs_.empty()) { - MS_LOG(EXCEPTION) << "node_indexs_ not ininted"; - } - for (auto &item : node_indexs_) { - index_nodes_[item.second] = item.first; - } -} - -bool Kernel2Ms::InitGraphInputsIndx(const KernelGraphPtr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - auto input_nodes = kernel_graph_ptr->inputs(); - if (input_nodes.empty()) { - return false; - } - for (const auto &input_node : input_nodes) { - if (input_node->isa()) { - if (!predict::utils::FindNodeInMap(node_indexs_, input_node)) { - // init every parameter node - node_indexs_[input_node.get()] = graph_index_; - graph_index_++; - } - } else { - MS_LOG(INFO) << "This node is anfnode, no need to handle, continue. node info: " << input_node->ToString(); - continue; - } - } - MS_LOG(DEBUG) << "inputs GraphIndex: " << graph_index_; - return true; -} - -bool Kernel2Ms::InitGraphValueNodesIndx(const KernelGraphPtr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - if (kernel_graph_ptr->value_nodes().empty()) { - return false; - } - for (auto &item : kernel_graph_ptr->value_nodes()) { - if (item.first->isa()) { - auto value_node = item.first->cast(); - MS_EXCEPTION_IF_NULL(value_node); - if (value_node == nullptr) { - MS_LOG(WARNING) << "value_node is nullptr"; - return false; - } - if (value_node->value() == nullptr) { - MS_LOG(ERROR) << "Constant value is null."; - return false; - } - if (!value_node->value()->isa()) { - continue; - } - if (!predict::utils::FindNodeInMap(node_indexs_, item.first)) { - // init node - auto node_ptr = item.first; - node_indexs_[node_ptr.get()] = graph_index_; - graph_index_++; - } - } - } - return true; -} - -bool Kernel2Ms::InitGraphOpsIndx(const KernelGraphPtr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - auto kernels = kernel_graph_ptr->execution_order(); - if (kernels.empty()) { - MS_LOG(WARNING) << "this graph has no kernel"; - return false; - } - for (size_t i = 0; i < kernels.size(); ++i) { - // for each kernel's inputs foreach real_input - if (kernels[i]->isa()) { - if (!predict::utils::FindNodeInMap(node_indexs_, kernels[i])) { - // init node - node_indexs_[kernels[i].get()] = graph_index_; - graph_index_++; - } - } - } - return true; -} - -bool Kernel2Ms::InitGraphOutputsIndx(const KernelGraphPtr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - // graph output && their inputs should link together - auto out_nodes = kernel_graph_ptr->outputs(); - if (out_nodes.empty()) { - MS_LOG(ERROR) << "this graph has no outputs"; - return false; - } - for (auto &item : out_nodes) { - if (!predict::utils::FindNodeInMap(node_indexs_, item)) { - node_indexs_[item.get()] = graph_index_; - graph_index_++; - } - } - return true; -} - -bool Kernel2Ms::InitGraphIndx(const KernelGraphPtr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - // only parameter - if (!InitGraphInputsIndx(kernel_graph_ptr)) { - return false; - } - // init value node - if (!InitGraphValueNodesIndx(kernel_graph_ptr)) { - return false; - } - // init op - if (!InitGraphOpsIndx(kernel_graph_ptr)) { - return false; - } - // init Graphoutput attention: out_put nodes have inputs - return InitGraphOutputsIndx(kernel_graph_ptr); -} - -bool Kernel2Ms::SetGraphInputTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache, - SubGraphDefT *ms_graph) { - MS_EXCEPTION_IF_NULL(tensor_cache); - MS_EXCEPTION_IF_NULL(ms_graph); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - if (convert_mode_ == kConvertUnused) { - return false; - } - if (kernel_graph_ptr->inputs().empty()) { - return false; - } - for (const auto &input_node : kernel_graph_ptr->inputs()) { - if (input_node->isa()) { - ParameterPtr pk_node = std::dynamic_pointer_cast(input_node); - TensorPtr device_tensor; - if (convert_mode_ == kConvertCpuMode) { - device_tensor = predict::utils::GetParaCpuTensor(input_node); - } else { - device_tensor = predict::utils::GetParaAscendTensor(input_node); - } - if (device_tensor == nullptr) { - return false; - } - ExTensorType node_type; - if (AnfAlgo::IsParameterWeight(pk_node)) { - node_type = WEIGHTS; - } else { - node_type = INPUTDATA; - } - if (!predict::utils::FindNodeInMap(node_indexs_, input_node)) { - MS_LOG(WARNING) << "can not find any pk_key in inited node_indexs map"; - return false; - } - auto pk_key = node_indexs_[input_node.get()]; - all_output_tensors_[pk_key].push_back(device_tensor); - int nodeRefCount = SizeToInt(AnfAlgo::GetOutputTensorNum(input_node)); - int nodeInputIdx = - tensor_cache->addExTensor(pk_key, device_tensor, nodeRefCount, device_tensor->shape(), node_type); - if (!AnfAlgo::IsParameterWeight(pk_node)) { - ms_graph->inputIndex.push_back(nodeInputIdx); - all_input_idxs_.push_back(nodeInputIdx); - } else { - input_weight_idxs_.push_back(nodeInputIdx); - all_input_idxs_.push_back(nodeInputIdx); - } - } - } - return true; -} - -bool Kernel2Ms::SetGraphValueTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(tensor_cache); - for (auto &item : kernel_graph_ptr->value_nodes()) { - if (item.first->isa()) { - auto const_node = item.first->cast(); - auto tensor_constant = predict::utils::GetValueTensor(const_node); - if (tensor_constant == nullptr) { - continue; - } - if (!predict::utils::FindNodeInMap(node_indexs_, item.first)) { - MS_LOG(WARNING) << "can not find any pk_key in inited node_indexs map"; - return false; - } - int constant_key = node_indexs_[(item.first).get()]; - all_output_tensors_[constant_key].push_back(tensor_constant); - auto shape = tensor_constant->shape(); - (void)tensor_cache->addExTensor(constant_key, tensor_constant, 0, shape, CONSTANT); - } - } - return true; -} - -bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache, - SubGraphDefT *ms_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(tensor_cache); - MS_EXCEPTION_IF_NULL(ms_graph); - auto kernels = kernel_graph_ptr->execution_order(); - if (kernels.empty()) { - MS_LOG(ERROR) << "this graph has no kernels"; - return false; - } - for (auto &kernel : kernels) { - if (!predict::utils::FindNodeInMap(node_indexs_, kernel)) { - MS_LOG(ERROR) << "can not find any pk_key in inited node_indexs map"; - return false; - } - auto kernel_key = node_indexs_[kernel.get()]; - std::unique_ptr ms_node(new OpDefT); - ms_node->name = kernel->fullname_with_scope(); - ms_node->fmkType = mindspore::predict::FmkType_CAFFE; - auto c_name = AnfAlgo::GetCNodeName(kernel); - auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name); - if (fun == nullptr) { - MS_LOG(WARNING) << "get node [" << kernel->fullname_with_scope() << "] attr failed."; - } else if (!fun(kernel, ms_node.get())) { - MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed."; - return false; - } - auto output_size = AnfAlgo::GetOutputTensorNum(kernel); - int nodeRefCount = SizeToInt(output_size); - for (size_t j = 0; j < output_size; ++j) { - TensorPtr device_tensor; - if (convert_mode_ == kConvertCpuMode) { - device_tensor = predict::utils::GetKernelCpuTensor(kernel, j); - } else if (convert_mode_ == kConvertAscendMode) { - device_tensor = predict::utils::GetKernelAscendTensor(kernel, j); - } - if (device_tensor == nullptr) { - return false; - } - all_output_tensors_[kernel_key].push_back(device_tensor); - if (!SetOpOutputIdx(kernel, device_tensor, tensor_cache, nodeRefCount, j, ms_node.get())) { - return false; - } - } - tmp_op_nodes_.emplace_back(ms_node.release()); - } - return true; -} - -bool Kernel2Ms::KernelGraph2MsGraph(const KernelGraphPtr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - graph_index_ = 0; - all_output_tensors_.clear(); - node_indexs_.clear(); - index_nodes_.clear(); - std::unique_ptr sub_ms_graph(new SubGraphDefT()); - if (!InitGraphIndx(kernel_graph_ptr)) { - return false; - } - TransformGraphIndx(); - tensor_cache_ptr_ = std::make_shared(); - // foreach node to init it's real output tensor - if (!SetGraphInputTensors(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get())) { - return false; - } - // Get KernelGraph value node - if (!SetGraphValueTensors(kernel_graph_ptr, tensor_cache_ptr_)) { - return false; - } - // Get KernelGraph apply_kernel && add opNode - if (!SetGraphOpTensors(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get())) { - return false; - } - // Get KernelGraph outputs - if (!SetGraphOutputIdx(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get(), &all_output_tensors_)) { - return false; - } - auto kernels = kernel_graph_ptr->execution_order(); - for (size_t i = 0; i < kernels.size(); ++i) { - auto ms_node = tmp_op_nodes_[i]; - if (!SetOpInputIdx(kernels[i], tensor_cache_ptr_, ms_node)) { - return false; - } - std::unique_ptr ms_node_tmp(ms_node); - sub_ms_graph->nodes.emplace_back(std::move(ms_node_tmp)); - } - if (!SetAllTensors(tensor_cache_ptr_, sub_ms_graph.get())) { - return false; - } - if (!SetMemResue()) { - return false; - } - sub_ms_graph_ = std::move(sub_ms_graph); - sub_ms_graph_->name = "default_sub_graph"; - return true; -} - -bool Kernel2Ms::CheckInputSizes(const std::vector &input_tensors, - const std::vector &all_input_idxs) { - if (input_tensors.size() != all_input_idxs.size()) { - MS_LOG(EXCEPTION) << "real input tensors size:" << input_tensors.size() - << "not equal converted tesnors size:" << all_input_idxs.size() << "the graph has changed"; - } - for (auto in : all_input_idxs) { - if (in < sub_ms_graph_->allTensors.size()) { - auto real_tensor = input_tensors[in]; - auto convert_dims = sub_ms_graph_->allTensors[in]->dims; - auto real_dims = real_tensor->shape(); - if (real_dims.size() != convert_dims.size()) { - return false; - } else { - for (size_t i = 0; i < convert_dims.size(); ++i) { - if (convert_dims[i] != real_dims[i]) { - return false; - } - } - } - } else { - MS_LOG(EXCEPTION) << "index: " << in << "in all_input_idxs is valid"; - } - } - return true; -} - -void Kernel2Ms::ReleaseContextRes() { - tmp_op_nodes_.clear(); - node_indexs_.clear(); - index_nodes_.clear(); - tensor_cache_ptr_ = nullptr; - all_output_tensors_.clear(); -} - -bool Kernel2Ms::KernelInput2MS(const std::vector &input_tensors) { - const std::unordered_map> &cache_tensors = tensor_cache_ptr_->GetCachedTensor(); - if (cache_tensors.empty()) { - return false; - } - auto all_weights_idxs = GetAllInputWeightIdxs(); - auto all_input_idxs = GetAllInputIdxs(); - auto real_input_size = input_tensors.size(); - // check tensor size - bool ret = CheckInputSizes(input_tensors, all_input_idxs); - std::vector match_to_rel_idxs; - // indx order not matched,macth to it - if (!ret) { - for (auto idx : all_weights_idxs) { - auto macth_idx = real_input_size - idx; - match_to_rel_idxs.push_back(macth_idx); - } - } else { - match_to_rel_idxs = all_weights_idxs; - } - if (match_to_rel_idxs.size() == all_weights_idxs.size()) { - for (size_t j = 0; j < all_weights_idxs.size(); ++j) { - auto cache_idx = all_weights_idxs[j]; - auto match_idx = match_to_rel_idxs[j]; - auto real_tensor = input_tensors[match_idx]; - auto real_size = LongToSize(real_tensor->data().nbytes()); - auto real_data = real_tensor->data_c(); - MS_EXCEPTION_IF_NULL(real_data); - if (sub_ms_graph_->allTensors[cache_idx] != nullptr) { - sub_ms_graph_->allTensors[cache_idx]->data.resize(real_size); - } - if (memcpy_s(sub_ms_graph_->allTensors[cache_idx]->data.data(), real_size, real_data, real_size) != 0) { - MS_LOG(ERROR) << "KernelInput2MS memcpy_s failed"; - return false; - } - } - } - ReleaseContextRes(); - return true; -} - -bool Kernel2Ms::SaveDeviceModel(const std::shared_ptr &new_ms_graph_ptr, const std::string &save_path_name) { - MS_EXCEPTION_IF_NULL(new_ms_graph_ptr); - return predict::utils::SaveDeviceModelUtil(new_ms_graph_ptr, save_path_name, sub_ms_graph_.release()); -} -} // namespace executor -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/kernel2ms.h b/mindspore/ccsrc/predict/converter/kernel2ms.h deleted file mode 100644 index 8cbc89ed6a..0000000000 --- a/mindspore/ccsrc/predict/converter/kernel2ms.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_KERNEL_TO_MS_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_KERNEL_TO_MS_H_ - -#include -#include -#include -#include -#include -#include "backend/session/kernel_graph.h" -#include "predict/converter/executor_tensor.h" -#include "predict/schema/inner/ms_generated.h" -#include "predict/converter/attr_utils/convert_util.h" - -static constexpr size_t kTupleGetItemIndex = 2; -namespace mindspore { -namespace executor { -using KernelGraphPtr = std::shared_ptr; -enum ConvertMode { kConvertCpuMode, kConvertAscendMode, kConvertUnused }; -enum TargetMode { kCPUTarget, kGPUTarget, kUnknowTarget }; -class Kernel2Ms { - public: - static Kernel2Ms &GetInstance(); - - Kernel2Ms(const Kernel2Ms &) = delete; - - Kernel2Ms &operator=(const Kernel2Ms &) = delete; - - bool KernelGraph2MsGraph(const KernelGraphPtr &kernel_graph_ptr); - - bool KernelInput2MS(const std::vector &input_tensors); - - ConvertMode convert_mode() const { return convert_mode_; } - - void set_convert_mode(ConvertMode convert_mode) { convert_mode_ = convert_mode; } - - TargetMode device_target() const { return device_target_; } - - void set_device_target(TargetMode device_target) { device_target_ = device_target; } - - bool SaveDeviceModel(const std::shared_ptr &new_ms_graph_ptr, const std::string &save_path_name); - - private: - Kernel2Ms() : graph_index_(0) {} - - void ReleaseContextRes(); - - ~Kernel2Ms() = default; - - bool SetAllTensors(const TensorCachePtr &tensor_cache, SubGraphDefT *sub_graph_def_t); - - bool SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, OpDefT *ms_node); - - bool SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &output_tensor, const TensorCachePtr &tensor_cache, - int ref_count, size_t order_index, OpDefT *ms_node); - - bool SetGraphOutputIdx(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache, - SubGraphDefT *sub_graph_def_t, AllOutputTensors *all_output_tensors); - - void TransformGraphIndx(); - - void GetRealInpoutsPtr(const AnfNodePtr &node, std::vector *real_inputs, - std::vector *real_output_idx); - - bool InitGraphIndx(const KernelGraphPtr &kernel_graph_ptr); - - bool InitGraphInputsIndx(const KernelGraphPtr &kernel_graph_ptr); - - bool InitGraphValueNodesIndx(const KernelGraphPtr &kernel_graph_ptr); - - bool InitGraphOpsIndx(const KernelGraphPtr &kernel_graph_ptr); - - bool InitGraphOutputsIndx(const KernelGraphPtr &kernel_graph_ptr); - - bool SetGraphInputTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache, - SubGraphDefT *sub_graph_def_t); - - bool SetGraphValueTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache); - - bool SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache, - SubGraphDefT *sub_graph_def_t); - std::vector GetAllInputWeightIdxs() const { return input_weight_idxs_; } - std::vector GetAllInputIdxs() const { return all_input_idxs_; } - - bool CheckInputSizes(const std::vector &input_tensors, const std::vector &all_input_idxs); - - bool SetMemResue() const; - SubGraphPtr sub_ms_graph_; - AllOutputTensors all_output_tensors_; - std::vector tmp_op_nodes_; - std::unordered_map node_indexs_; - std::unordered_map index_nodes_; - int graph_index_ = 0; - TensorCachePtr tensor_cache_ptr_ = nullptr; - ConvertMode convert_mode_ = kConvertCpuMode; - TargetMode device_target_ = kCPUTarget; - std::vector input_weight_idxs_; - std::vector all_input_idxs_; -}; -using Kernel2MsPtr = std::shared_ptr; -} // namespace executor -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_KERNEL_TO_MS_H_ diff --git a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc deleted file mode 100644 index 52648812be..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" -#include "./securec.h" - -namespace mindspore { -namespace predict { -namespace convert { -// forward declare -bool Conv2dPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool MatMulPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool BiasAddPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool ReshapePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool ActivationPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool PoolingPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool FusedBatchNormPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool AddPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool CastPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool MeanPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool SoftmaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool ScalePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool AddFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool ArgMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool BatchNormFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool FakeQuantWithMinMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool FakeQuantWithMinMaxPerChannelPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool MulPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool MulFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); -bool SqueezePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op); - -OpAttrFactory::OpAttrFactory() { - pack_funs_ = {{"Conv2D", Conv2dPacker}, - {"MatMul", MatMulPacker}, - {"BiasAdd", BiasAddPacker}, - {"Reshape", ReshapePacker}, - {"Activation", ActivationPacker}, - {"ReLU", ActivationPacker}, - {"ReLU6", ActivationPacker}, - {"EReLU", ActivationPacker}, - {"LeakyReLU", ActivationPacker}, - {"Sigmoid", ActivationPacker}, - {"Softsign", ActivationPacker}, - {"Softplus", ActivationPacker}, - {"Tanh", ActivationPacker}, - {"HSwish", ActivationPacker}, - {"HSigmoid", ActivationPacker}, - {"MaxPool", PoolingPacker}, - {"MaxPool2D", PoolingPacker}, - {"MeanPool", PoolingPacker}, - {"GlobalPool", PoolingPacker}, - {"FusedBatchNorm", FusedBatchNormPacker}, - {"FusedBatchNormGrad", FusedBatchNormPacker}, - {"Cast", CastPacker}, - {"TensorAdd", AddPacker}, - {"SoftMax", SoftmaxPacker}, - {"SimpleMean", MeanPacker}, - {"ReduceMean", MeanPacker}, - {"AddFold", AddFoldPacker}, - {"ArgMax", ArgMaxPacker}, - {"BatchNorm", BatchNormFoldPacker}, - {"FakeQuantPerLayer", FakeQuantWithMinMaxPacker}, - {"FakeQuantPerChannel", FakeQuantWithMinMaxPerChannelPacker}, - {"Mul", MulPacker}, - {"MulFold", MulFoldPacker}, - {"Squeeze", SqueezePacker}}; -} -OpAttrPackFun OpAttrFactory::GetPackFun(const std::string &opType) { - if (pack_funs_.find(opType) == pack_funs_.end()) { - MS_LOG(WARNING) << "Op Attr pack fun [" << opType << "] not found."; - return nullptr; - } - return pack_funs_[opType]; -} - -mindspore::predict::Format GetAttrFormat(const std::string &format) { - if (format == kOpFormat_NCHW) { - return predict::Format::Format_NCHW; - } else if (format == kOpFormat_NHWC) { - return predict::Format::Format_NHWC; - } else { - return predict::Format::Format_NUM_OF_FORMAT; - } -} - -mindspore::predict::PadMode GetAttrPadMode(const std::string &pad_mode) { - if (pad_mode == "same") { - return mindspore::predict::PadMode::PadMode_SAME; - } else if (pad_mode == "valid") { - return mindspore::predict::PadMode::PadMode_VALID; - } else { - return mindspore::predict::PadMode::PadMode_NOTSET; - } -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h deleted file mode 100644 index 31f14ef73a..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_OP_ATTR_PACKER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_OP_ATTR_PACKER_H_ - -#include -#include -#include -#include "backend/session/anf_runtime_algorithm.h" -#include "predict/schema/inner/ms_generated.h" - -static constexpr size_t kNIndex = 0; -static constexpr size_t kCIndex = 1; -static constexpr size_t kHIndex = 2; -static constexpr size_t kWIndex = 3; -static constexpr size_t kNCHWSize = 4; -namespace mindspore { -namespace predict { -namespace convert { -using OpAttrPackFun = bool (*)(const CNodePtr &c_node_ptr, OpDefT *ms_op); -class OpAttrFactory { - public: - static OpAttrFactory *GetInstance() { - static OpAttrFactory instance; - return &instance; - } - OpAttrFactory(const OpAttrFactory &) = delete; - OpAttrFactory &operator=(const OpAttrFactory &) = delete; - OpAttrPackFun GetPackFun(const std::string &op_type); - ~OpAttrFactory() { pack_funs_.clear(); } - OpAttrFactory(); - - private: - std::unordered_map pack_funs_; -}; - -mindspore::predict::Format GetAttrFormat(const std::string &format); - -mindspore::predict::PadMode GetAttrPadMode(const std::string &pad_mode); -} // namespace convert -} // namespace predict -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_CONVERTER_CPU_OP_INFO_OP_ATTR_FACTORY_H_ diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/activation_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/activation_packer.cc deleted file mode 100644 index 3dc09f70b4..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/activation_packer.cc +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool ActivationPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new ActivationT()); - MS_EXCEPTION_IF_NULL(attr); - if (AnfAlgo::GetCNodeName(c_node_ptr) == "ReLU") { - attr->type = predict::ActivationType::ActivationType_RELU; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "Sigmoid") { - attr->type = predict::ActivationType::ActivationType_SIGMOID; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "ReLU6") { - attr->type = predict::ActivationType::ActivationType_RELU6; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "ELU") { - attr->type = predict::ActivationType::ActivationType_ELU; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "Leaky_ReLU") { - attr->type = predict::ActivationType::ActivationType_LEAKY_RELU; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "ABS") { - attr->type = predict::ActivationType::ActivationType_ABS; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "ReLU1") { - attr->type = predict::ActivationType::ActivationType_RELU1; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "Softsign") { - attr->type = predict::ActivationType::ActivationType_SOFTSIGN; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "Softplus") { - attr->type = predict::ActivationType::ActivationType_SOFTPLUS; - } else if (AnfAlgo::GetCNodeName(c_node_ptr) == "Tanh") { - attr->type = predict::ActivationType::ActivationType_TANH; - } else { - attr->type = predict::ActivationType::ActivationType_UNKNOW; - MS_LOG(WARNING) << "unknow Activation"; - } - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Activation; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/add_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/add_packer.cc deleted file mode 100644 index 02a9bda65e..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/add_packer.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool AddPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new AddT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Add; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/addfold_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/addfold_packer.cc deleted file mode 100644 index b6affd5001..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/addfold_packer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool AddFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new AddFoldT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->attr.type = OpT_AddFold; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/argmax_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/argmax_packer.cc deleted file mode 100644 index 4df643704c..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/argmax_packer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool ArgMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new ArgMaxT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->attr.type = OpT_ArgMax; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/batchnormfold_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/batchnormfold_packer.cc deleted file mode 100644 index f05f3894be..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/batchnormfold_packer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool BatchNormFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new BatchNormFoldT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->attr.type = OpT_BatchNormFold; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/biasadd_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/biasadd_packer.cc deleted file mode 100644 index 6fe32c1f6b..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/biasadd_packer.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool BiasAddPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new BiasAddT()); - MS_EXCEPTION_IF_NULL(attr); - attr->axis = {1}; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_BiasAdd; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/cast_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/cast_packer.cc deleted file mode 100644 index d0f3f80f6c..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/cast_packer.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool CastPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new CastT()); - MS_EXCEPTION_IF_NULL(attr); - attr->srcT = 0; - attr->dstT = 0; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Cast; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/conv2d_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/conv2d_packer.cc deleted file mode 100644 index 176b235f5f..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/conv2d_packer.cc +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool Conv2dPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - int kernel_group_value = AnfAlgo::GetNodeAttr(c_node_ptr, "group"); - int kernel_channel_value = AnfAlgo::GetNodeAttr(c_node_ptr, "out_channel"); - std::vector kernel_size_value = AnfAlgo::GetNodeAttr>(c_node_ptr, "kernel_size"); - std::string kernel_pad_mode_value = AnfAlgo::GetNodeAttr(c_node_ptr, "pad_mode"); - int kernel_pad_value = AnfAlgo::GetNodeAttr(c_node_ptr, "pad"); - auto kernel_stride_value = AnfAlgo::GetNodeAttr>(c_node_ptr, "stride"); - auto kernel_dilation_value = AnfAlgo::GetNodeAttr>(c_node_ptr, "dilation"); - std::string kernel_data_format_value = AnfAlgo::GetNodeAttr(c_node_ptr, "data_format"); - std::unique_ptr attr(new Conv2DT()); - MS_EXCEPTION_IF_NULL(attr); - attr->format = GetAttrFormat(kernel_data_format_value); - attr->group = kernel_group_value; - auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(c_node_ptr, 1); - if (in_shape.size() != kNCHWSize) { - return false; - } - attr->channelIn = SizeToInt(in_shape[1]); - attr->channelOut = kernel_channel_value; - attr->kernelW = kernel_size_value[0]; - attr->kernelH = kernel_size_value[1]; - attr->strideW = kernel_stride_value[0]; - attr->strideH = kernel_stride_value[1]; - attr->padMode = GetAttrPadMode(kernel_pad_mode_value); - attr->padUp = kernel_pad_value; - attr->padDown = kernel_pad_value; - attr->padLeft = kernel_pad_value; - attr->padRight = kernel_pad_value; - attr->dilateW = kernel_dilation_value[0]; - attr->dilateH = kernel_dilation_value[1]; - attr->hasBias = false; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Conv2D; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/fakequantwithminmax_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/fakequantwithminmax_packer.cc deleted file mode 100644 index 195a4fde9f..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/fakequantwithminmax_packer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool FakeQuantWithMinMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new FakeQuantWithMinMaxT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->attr.type = OpT_FakeQuantWithMinMax; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/fakequantwithminmaxperchannel_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/fakequantwithminmaxperchannel_packer.cc deleted file mode 100644 index 0074c87646..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/fakequantwithminmaxperchannel_packer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool FakeQuantWithMinMaxPerChannelPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new FakeQuantWithMinMaxPerChannelT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->attr.type = OpT_FakeQuantWithMinMaxPerChannel; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/fusedbatchnorm_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/fusedbatchnorm_packer.cc deleted file mode 100644 index e0092820c2..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/fusedbatchnorm_packer.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool FusedBatchNormPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new FusedBatchNormT()); - MS_EXCEPTION_IF_NULL(attr); - auto kernel_epsilon = AnfAlgo::GetNodeAttr(c_node_ptr, "epsilon"); - attr->epsilon = kernel_epsilon; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_FusedBatchNorm; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/matmul_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/matmul_packer.cc deleted file mode 100644 index a0f82810a7..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/matmul_packer.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool MatMulPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - bool kernel_transpore_a = AnfAlgo::GetNodeAttr(c_node_ptr, "transpose_a"); - bool kernel_transpore_b = AnfAlgo::GetNodeAttr(c_node_ptr, "transpose_b"); - std::unique_ptr attr(new MatMulT()); - MS_EXCEPTION_IF_NULL(attr); - attr->transposeA = kernel_transpore_a; - attr->transposeB = kernel_transpore_b; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_MatMul; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/mean_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/mean_packer.cc deleted file mode 100644 index eac3fa88f1..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/mean_packer.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool MeanPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new MeanT()); - MS_EXCEPTION_IF_NULL(attr); - attr->axis = {1}; - attr->keepDims = false; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Mean; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/mul_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/mul_packer.cc deleted file mode 100644 index 6c430e79e7..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/mul_packer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool MulPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new MulT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->attr.type = OpT_Mul; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/mulflod_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/mulflod_packer.cc deleted file mode 100644 index 1df7204875..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/mulflod_packer.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool MulFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new MulFoldT()); - MS_EXCEPTION_IF_NULL(attr); - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_MulFold; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/pooling_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/pooling_packer.cc deleted file mode 100644 index edfdcda040..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/pooling_packer.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool PoolingPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new PoolingT()); - MS_EXCEPTION_IF_NULL(attr); - std::string kernel_format_value = AnfAlgo::GetNodeAttr(c_node_ptr, "data_format"); - attr->format = GetAttrFormat(kernel_format_value); - auto c_name = AnfAlgo::GetCNodeName(c_node_ptr); - if (c_name == "MaxPool") { - ms_op->name = c_node_ptr->fullname_with_scope(); - attr->poolingMode = mindspore::predict::PoolMode::PoolMode_MAX_POOLING; - } else if (c_name == "MeanPool") { - ms_op->name = c_node_ptr->fullname_with_scope(); - attr->poolingMode = mindspore::predict::PoolMode::PoolMode_MEAN_POOLING; - } else if (c_name == "GlobalPool") { - ms_op->name = c_node_ptr->fullname_with_scope(); - } else { - MS_LOG(ERROR) << "unknowed pooling type."; - return false; - } - std::vector kernel_ksize = AnfAlgo::GetNodeAttr>(c_node_ptr, "ksize"); - attr->windowW = kernel_ksize[kHIndex]; - attr->windowH = kernel_ksize[kWIndex]; - std::vector kernel_strides = AnfAlgo::GetNodeAttr>(c_node_ptr, "strides"); - attr->strideW = kernel_strides[kHIndex]; - attr->strideH = kernel_strides[kWIndex]; - std::string kernel_pad_mode_value = AnfAlgo::GetNodeAttr(c_node_ptr, "padding"); - attr->padMode = GetAttrPadMode(kernel_pad_mode_value); - attr->padUp = 0; - attr->padDown = 0; - attr->padLeft = 0; - attr->padRight = 0; - ms_op->attr.type = OpT_Pooling; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/reshape_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/reshape_packer.cc deleted file mode 100644 index a0a263631d..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/reshape_packer.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool ReshapePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new ReshapeT()); - MS_EXCEPTION_IF_NULL(attr); - attr->format = predict::Format::Format_NCHW; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Reshape; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/scale_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/scale_packer.cc deleted file mode 100644 index 356775247d..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/scale_packer.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool ScalePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new ScaleT()); - MS_EXCEPTION_IF_NULL(attr); - attr->format = predict::Format::Format_NCHW; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_Scale; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/softmax_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/softmax_packer.cc deleted file mode 100644 index fe96bae451..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/softmax_packer.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool SoftmaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new SoftMaxT()); - MS_EXCEPTION_IF_NULL(attr); - attr->axis = {1}; - ms_op->name = c_node_ptr->fullname_with_scope(); - ms_op->attr.type = OpT_SoftMax; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/converter/lite_model/operations/squeeze_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/operations/squeeze_packer.cc deleted file mode 100644 index 7e836fe021..0000000000 --- a/mindspore/ccsrc/predict/converter/lite_model/operations/squeeze_packer.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/converter/lite_model/op_attr_packer.h" - -namespace mindspore { -namespace predict { -namespace convert { -bool SqueezePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) { - if (c_node_ptr == nullptr || ms_op == nullptr) { - return false; - } - std::unique_ptr attr(new SqueezeT()); - MS_EXCEPTION_IF_NULL(attr); - - std::vector kernel_axis_value = AnfAlgo::GetNodeAttr>(c_node_ptr, "axis"); - attr->axis = kernel_axis_value; - - ms_op->attr.type = OpT_Squeeze; - ms_op->attr.value = attr.release(); - return true; -} -} // namespace convert -} // namespace predict -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/generator/ir/ir_model.cc b/mindspore/ccsrc/predict/generator/ir/ir_model.cc deleted file mode 100644 index ff46524577..0000000000 --- a/mindspore/ccsrc/predict/generator/ir/ir_model.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/generator/ir/ir_model.h" - -#include -#include - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace generator { -IRModel::~IRModel() { ir_tasks_.clear(); } -void IRModel::SetIrTaskInfos(const std::vector &ir_tasks) { - (void)std::copy(ir_tasks.begin(), ir_tasks.end(), std::back_inserter(ir_tasks_)); -} -} // namespace generator -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/generator/ir/ir_model.h b/mindspore/ccsrc/predict/generator/ir/ir_model.h deleted file mode 100644 index 82bd2aad3f..0000000000 --- a/mindspore/ccsrc/predict/generator/ir/ir_model.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_EXECUTOR_GENERATOR_IR_IR_MODEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_EXECUTOR_GENERATOR_IR_IR_MODEL_H_ -#include -#include -#include -#include "predict/generator/ir/ir_task_info.h" -namespace mindspore { -namespace generator { -class IRModel { - public: - void SetIrTaskInfos(const std::vector &ir_tasks); - IRModel() = default; - ~IRModel(); - - private: - std::vector ir_tasks_; -}; -using IrModelPtr = std::shared_ptr; -} // namespace generator -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_EXECUTOR_GENERATOR_IR_IR_MODEL_H_ diff --git a/mindspore/ccsrc/predict/generator/ir/ir_task_info.cc b/mindspore/ccsrc/predict/generator/ir/ir_task_info.cc deleted file mode 100644 index 1c275ea8ed..0000000000 --- a/mindspore/ccsrc/predict/generator/ir/ir_task_info.cc +++ /dev/null @@ -1,244 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/generator/ir/ir_task_info.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace generator { -bool CceIRTaskInfo::SerializeIRToProto() { - auto cce_task_def_ptr = std::unique_ptr(); - auto kernel_context_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(cce_task_def_ptr); - MS_EXCEPTION_IF_NULL(kernel_context_ptr); - kernel_context_ptr->set_kernel_type(k_ctx_.kernel_type); - kernel_context_ptr->set_op_id(k_ctx_.op_id); - kernel_context_ptr->set_kernel_func_id(k_ctx_.kernel_func_id); - kernel_context_ptr->set_op_index(k_ctx_.op_index); - kernel_context_ptr->set_is_flowtable(k_ctx_.is_flowtable); - kernel_context_ptr->set_args_count(k_ctx_.args_count); - for (unsigned int i : k_ctx_.origin_op_index) { - kernel_context_ptr->add_origin_op_index(i); - } - void *tmp_args_offset = static_cast((k_ctx_.args_offset).data()); - if (tmp_args_offset == nullptr) { - MS_LOG(WARNING) << "tmp_args_offset have no data"; - return false; - } - kernel_context_ptr->set_args_offset(tmp_args_offset, k_ctx_.args_offset.size()); - cce_task_def_ptr->set_allocated_kernel_context(std::move(kernel_context_ptr).get()); - cce_task_def_ptr->set_stub_func(stub_func_); - cce_task_def_ptr->set_block_dim(block_dim_); - cce_task_def_ptr->set_args_size(args_size_); - void *tmp_sm_desc = static_cast(sm_desc_.data()); - if (tmp_sm_desc == nullptr) { - MS_LOG(WARNING) << "tmp_sm_desc have no data"; - return false; - } - cce_task_def_ptr->set_sm_desc(tmp_sm_desc, sm_desc_.size()); - - void *tmp_flow_table = static_cast(flow_table_.data()); - if (tmp_flow_table == nullptr) { - MS_LOG(WARNING) << "tmp_flow_table have no data"; - return false; - } - cce_task_def_ptr->set_flow_table(tmp_flow_table, flow_table_.size()); - return true; -} - -CceIRTaskInfo::~CceIRTaskInfo() { - args_.clear(); - sm_desc_.clear(); - flow_table_.clear(); -} - -bool TbeIRTaskInfo::SerializeIRToProto() { - auto tbe_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(tbe_task_def_ptr); - tbe_task_def_ptr->set_stub_func(stub_func_); - tbe_task_def_ptr->set_block_dim(block_dim_); - tbe_task_def_ptr->set_args_size(args_size_); - void *tmp_args = static_cast(args_.data()); - if (tmp_args == nullptr) { - MS_LOG(WARNING) << "tmp_args have no data"; - return false; - } - tbe_task_def_ptr->set_args(tmp_args, args_.size()); - void *tmp_sm_desc = static_cast(sm_desc_.data()); - if (tmp_sm_desc == nullptr) { - MS_LOG(WARNING) << "tmp_sm_desc have no data"; - return false; - } - tbe_task_def_ptr->set_sm_desc(tmp_sm_desc, sm_desc_.size()); - void *tmp_meta_data = static_cast(meta_data_.data()); - if (tmp_meta_data == nullptr) { - MS_LOG(WARNING) << "tmp_meta_data have no data"; - return false; - } - tbe_task_def_ptr->set_meta_data(tmp_meta_data, meta_data_.size()); - for (auto &in : input_data_addrs_) { - tbe_task_def_ptr->add_input_addrs(in); - } - for (auto &ou : output_data_addrs_) { - tbe_task_def_ptr->add_output_addrs(ou); - } - for (auto &wk : workspace_addrs_) { - tbe_task_def_ptr->add_workspace_addrs(wk); - } - return true; -} - -TbeIRTaskInfo::~TbeIRTaskInfo() { - args_.clear(); - sm_desc_.clear(); - meta_data_.clear(); - input_data_addrs_.clear(); - output_data_addrs_.clear(); - workspace_addrs_.clear(); -} - -bool AicpuIRTaskInfo::SerializeIRToProto() { - auto aicpu_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(aicpu_task_def_ptr); - aicpu_task_def_ptr->set_op_type(op_type_); - aicpu_task_def_ptr->set_flag(flag_); - for (auto &shape : input_data_shapes_) { - auto in_shape_ptr = aicpu_task_def_ptr->add_input_shapes(); - for (auto &in_sh : shape) { - in_shape_ptr->add_shape(static_cast(in_sh)); - } - } - for (auto &shape : output_data_shapes_) { - auto ou_shape_ptr = aicpu_task_def_ptr->add_output_shapes(); - for (auto &ou_sh : shape) { - ou_shape_ptr->add_shape(static_cast(ou_sh)); - } - } - for (auto &in_type : input_data_types_) { - aicpu_task_def_ptr->add_input_types(in_type); - } - for (auto &ou_type : output_data_types_) { - aicpu_task_def_ptr->add_output_types(ou_type); - } - for (auto &in_addr : input_data_addrs_) { - aicpu_task_def_ptr->add_input_addrs(in_addr); - } - for (auto &ou_addr : output_data_addrs_) { - aicpu_task_def_ptr->add_output_addrs(ou_addr); - } - void *tmp_node_def = static_cast(node_def_.data()); - if (tmp_node_def == nullptr) { - MS_LOG(WARNING) << "tmp_node_def have no data"; - return false; - } - aicpu_task_def_ptr->set_node_def(tmp_node_def, node_def_.size()); - void *tmp_func_def = static_cast(func_def_.data()); - if (tmp_func_def == nullptr) { - MS_LOG(WARNING) << "tmp_func_def have no data"; - return false; - } - aicpu_task_def_ptr->set_func_def(tmp_func_def, func_def_.size()); - return true; -} - -AicpuIRTaskInfo::~AicpuIRTaskInfo() { - input_data_types_.clear(); - input_data_shapes_.clear(); - input_data_addrs_.clear(); - output_data_types_.clear(); - output_data_shapes_.clear(); - output_data_addrs_.clear(); - node_def_.clear(); - func_def_.clear(); -} - -bool LabelIRTaskInfo::SerializeIRToProto() { - auto label_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(label_task_def_ptr); - label_task_def_ptr->set_label_id(label_id_); - return true; -} - -bool EventIRTaskInfo::SerializeIRToProto() { - auto event_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(event_task_def_ptr); - event_task_def_ptr->set_event_id(event_id_); - return true; -} - -bool HcclIRTaskInfo::SerializeIRToProto() { - auto hccl_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(hccl_task_def_ptr); - hccl_task_def_ptr->set_hccl_type(hccl_type_); - hccl_task_def_ptr->set_input_addr(input_data_addr_); - hccl_task_def_ptr->set_output_addr(output_data_addr_); - auto tmp_wk = static_cast(workspace_.data()); - hccl_task_def_ptr->set_workspace(tmp_wk, workspace_.size()); - hccl_task_def_ptr->set_workspace_num(workspace_num_); - auto tmp_pri_def = static_cast(private_def_.data()); - hccl_task_def_ptr->set_private_def(tmp_pri_def, private_def_.size()); - hccl_task_def_ptr->set_ops_kernel_store(ops_kernel_store_); - hccl_task_def_ptr->set_count(count_); - hccl_task_def_ptr->set_root_id(root_id_); - hccl_task_def_ptr->set_op_type(op_type_); - hccl_task_def_ptr->set_data_type(data_type_); - return true; -} - -HcclIRTaskInfo::~HcclIRTaskInfo() { - workspace_.clear(); - private_def_.clear(); -} - -bool ProfilerIRTaskInfo::SerializeIRToProto() { - auto profiler_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(profiler_task_def_ptr); - profiler_task_def_ptr->set_log_id(log_id_); - profiler_task_def_ptr->set_flat(flat_); - profiler_task_def_ptr->set_notify(notify_); - return true; -} - -bool MemcpyAsyncIRTaskInfo::SerializeIRToProto() { - auto mem_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(mem_task_def_ptr); - mem_task_def_ptr->set_dst(dst_); - mem_task_def_ptr->set_dst_max(dst_max_); - mem_task_def_ptr->set_src(src_); - mem_task_def_ptr->set_count(count_); - mem_task_def_ptr->set_kind(kind_); - return true; -} - -bool StreamSwitchIRTaskInfo::SerializeIRToProto() { - auto stream_switch_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(stream_switch_task_def_ptr); - stream_switch_task_def_ptr->set_true_stream_id(true_stream_id_); - stream_switch_task_def_ptr->set_input_addr(input_addr_); - stream_switch_task_def_ptr->set_value_addr(value_addr_); - stream_switch_task_def_ptr->set_cond(cond_); - stream_switch_task_def_ptr->set_data_type(data_type_); - return true; -} - -bool StreamActiveIRTaskInfo::SerializeIRToProto() { - auto stream_active_task_def_ptr = std::unique_ptr(); - MS_EXCEPTION_IF_NULL(stream_active_task_def_ptr); - stream_active_task_def_ptr->set_active_stream_id(active_stream_id_); - return true; -} -} // namespace generator -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/generator/ir/ir_task_info.h b/mindspore/ccsrc/predict/generator/ir/ir_task_info.h deleted file mode 100644 index 4b3ac85ea6..0000000000 --- a/mindspore/ccsrc/predict/generator/ir/ir_task_info.h +++ /dev/null @@ -1,295 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_EXECUTOR_GENERATOR_IR_IR_TASK_H_ -#define MINDSPORE_MINDSPORE_CCSRC_EXECUTOR_GENERATOR_IR_IR_TASK_H_ -#include -#include -#include -#include -#include -#include "proto/ge_runtime_taskinfo.pb.h" - -namespace mindspore { -namespace generator { -using TaskType = ::ge::model_runner::TaskDef_TaskType; -enum TaskTmpType { - CCE_TMP_DEF = 0, - TBE_TMP_DEF = 1, - AICPU_TMP_DEF = 2, - LABEL_TMP_DEF = 3, - EVENT_TMP_DEF = 4, - HCCL_TMP_DEF = 5, - PROFILER_TRACE_TMP_DEF = 6, - MEMCPY_ASYNC_TMP_DEF = 7, - STREAM_SWITCH_TMP_DEF = 8, - STREAM_ACTIVE_TMP_DEF = 9 -}; - -struct KernelContext { - uint32_t kernel_type = 0; - uint32_t op_id = 0; - uint32_t kernel_func_id = 0; - uint32_t op_index = 0; - bool is_flowtable = false; - std::vector args_offset; - uint32_t args_count = 0; - std::vector origin_op_index; -}; - -class IRtaskInfo { - public: - virtual ~IRtaskInfo() = default; - virtual bool SerializeIRToProto() = 0; - - protected: - IRtaskInfo(TaskType task_type, TaskTmpType task_tmp_type, uint64_t stream_id) - : task_type_(task_type), task_tmp_type_(task_tmp_type), stream_id_(stream_id) {} - - public: - uint64_t GetStreamId() const { return stream_id_; } - TaskType GetTaskType() const { return task_type_; } - TaskTmpType GetTaskTmpType() const { return task_tmp_type_; } - - private: - TaskType task_type_; - TaskTmpType task_tmp_type_; - uint64_t stream_id_ = 0; -}; - -using IRtaskInfoPtr = std::shared_ptr; - -class CceIRTaskInfo : public IRtaskInfo { - public: - CceIRTaskInfo(TaskType task_type, uint64_t stream_id, KernelContext k_ctx, std::string stub_func, uint32_t block_dim, - std::vector args, uint32_t args_size, std::vector sm_desc, - std::vector flow_table) - : IRtaskInfo(task_type, CCE_TMP_DEF, stream_id), - k_ctx_(std::move(k_ctx)), - stub_func_(std::move(stub_func)), - block_dim_(block_dim), - args_(std::move(args)), - args_size_(args_size), - sm_desc_(std::move(sm_desc)), - flow_table_(std::move(flow_table)) {} - ~CceIRTaskInfo() override; - bool SerializeIRToProto() override; - - private: - KernelContext k_ctx_; - std::string stub_func_; - uint32_t block_dim_ = 0; - std::vector args_; - // uintptr_t args_addr_; - uint32_t args_size_ = 0; - std::vector sm_desc_; - std::vector flow_table_; -}; - -class TbeIRTaskInfo : public IRtaskInfo { - public: - TbeIRTaskInfo(TaskType task_type, uint64_t stream_id, std::string stub_func, uint32_t block_dim, - std::vector args, uint32_t args_size, std::vector sm_desc, - std::vector meta_data, std::vector input_data_addrs, - std::vector output_data_addrs, std::vector workspace_addrs) - : IRtaskInfo(task_type, TBE_TMP_DEF, stream_id), - stub_func_(std::move(stub_func)), - block_dim_(block_dim), - args_(std::move(args)), - args_size_(args_size), - sm_desc_(std::move(sm_desc)), - meta_data_(std::move(meta_data)), - input_data_addrs_(std::move(input_data_addrs)), - output_data_addrs_(std::move(output_data_addrs)), - workspace_addrs_(std::move(workspace_addrs)) {} - ~TbeIRTaskInfo() override; - bool SerializeIRToProto() override; - - private: - std::string stub_func_; - uint32_t block_dim_ = 0; - std::vector args_; - uint32_t args_size_ = 0; - std::vector sm_desc_; - // uintptr_t binary_; - // uint32_t binary_size_; - std::vector meta_data_; - std::vector input_data_addrs_; - std::vector output_data_addrs_; - std::vector workspace_addrs_; - // std::vector flow_table_; -}; - -class AicpuIRTaskInfo : public IRtaskInfo { - public: - AicpuIRTaskInfo(TaskType task_type, uint64_t stream_id, std::string op_type, uint32_t flag, - std::vector input_data_types, std::vector> input_data_shapes, - std::vector input_data_addrs, std::vector output_data_types, - std::vector> output_data_shapes, std::vector output_data_addrs, - std::vector node_def, std::vector func_def) - : IRtaskInfo(task_type, AICPU_TMP_DEF, stream_id), - op_type_(std::move(op_type)), - flag_(flag), - input_data_types_(std::move(input_data_types)), - input_data_shapes_(std::move(input_data_shapes)), - input_data_addrs_(std::move(input_data_addrs)), - output_data_types_(std::move(output_data_types)), - output_data_shapes_(std::move(output_data_shapes)), - output_data_addrs_(std::move(output_data_addrs)), - node_def_(std::move(node_def)), - func_def_(std::move(func_def)) {} - ~AicpuIRTaskInfo() override; - bool SerializeIRToProto() override; - - private: - std::string op_type_; - uint32_t flag_ = 0; - std::vector input_data_types_; - std::vector> input_data_shapes_; - std::vector input_data_addrs_; - std::vector output_data_types_; - std::vector> output_data_shapes_; - std::vector output_data_addrs_; - std::vector node_def_; - std::vector func_def_; -}; - -class LabelIRTaskInfo : public IRtaskInfo { - public: - LabelIRTaskInfo(TaskType task_type, uint64_t stream_id, uint32_t label_id) - : IRtaskInfo(task_type, LABEL_TMP_DEF, stream_id), label_id_(label_id) {} - ~LabelIRTaskInfo() override {} - bool SerializeIRToProto() override; - - private: - uint32_t label_id_ = 0; -}; - -class EventIRTaskInfo : public IRtaskInfo { - public: - EventIRTaskInfo(TaskType task_type, uint64_t stream_id, uint32_t event_id) - : IRtaskInfo(task_type, EVENT_TMP_DEF, stream_id), event_id_(event_id) {} - ~EventIRTaskInfo() override {} - bool SerializeIRToProto() override; - - private: - uint32_t event_id_ = 0; -}; - -class HcclIRTaskInfo : public IRtaskInfo { - public: - HcclIRTaskInfo(TaskType task_type, uint64_t stream_id, std::string hccl_type, uintptr_t input_data_addr, - uintptr_t output_data_addr, std::vector workspace, int64_t workspace_num, - std::vector private_def, uintptr_t ops_kernel_store, int32_t count, int64_t root_id, - int64_t op_type, int64_t data_type) - : IRtaskInfo(task_type, HCCL_TMP_DEF, stream_id), - hccl_type_(std::move(hccl_type)), - input_data_addr_(input_data_addr), - output_data_addr_(output_data_addr), - workspace_(std::move(workspace)), - workspace_num_(workspace_num), - private_def_(std::move(private_def)), - ops_kernel_store_(ops_kernel_store), - count_(count), - root_id_(root_id), - op_type_(op_type), - data_type_(data_type) {} - ~HcclIRTaskInfo() override; - bool SerializeIRToProto() override; - - private: - std::string hccl_type_; - uintptr_t input_data_addr_ = 0; - uintptr_t output_data_addr_ = 0; - std::vector workspace_; - int64_t workspace_num_ = 0; - std::vector private_def_; - uintptr_t ops_kernel_store_ = 0; - int32_t count_ = 0; - int64_t root_id_ = 0; - int64_t op_type_ = 0; - int64_t data_type_ = 0; -}; - -class ProfilerIRTaskInfo : public IRtaskInfo { - public: - ProfilerIRTaskInfo(TaskType task_type, uint64_t stream_id, uint64_t log_id, bool notify, uint32_t flat) - : IRtaskInfo(task_type, PROFILER_TRACE_TMP_DEF, stream_id), log_id_(log_id), notify_(notify), flat_(flat) {} - ~ProfilerIRTaskInfo() override {} - bool SerializeIRToProto() override; - - private: - uint64_t log_id_ = 0; - bool notify_ = false; - uint32_t flat_ = 0; -}; - -class MemcpyAsyncIRTaskInfo : public IRtaskInfo { - public: - MemcpyAsyncIRTaskInfo(TaskType task_type, uint32_t stream_id, uint64_t dst, uint64_t dst_max, uint64_t src, - uint64_t count, int64_t kind) - : IRtaskInfo(task_type, MEMCPY_ASYNC_TMP_DEF, stream_id), - dst_(dst), - dst_max_(dst_max), - src_(src), - count_(count), - kind_(kind) {} - ~MemcpyAsyncIRTaskInfo() override {} - bool SerializeIRToProto() override; - - private: - uint64_t dst_ = 0; - uint64_t dst_max_ = 0; - uint64_t src_ = 0; - uint64_t count_ = 0; - uint32_t kind_ = 0; -}; - -class StreamSwitchIRTaskInfo : public IRtaskInfo { - public: - StreamSwitchIRTaskInfo(TaskType task_type, uint64_t stream_id, uint32_t true_stream_id, uintptr_t input_addr, - uintptr_t value_addr, uint32_t cond, int64_t data_type) - : IRtaskInfo(task_type, STREAM_SWITCH_TMP_DEF, stream_id), - true_stream_id_(true_stream_id), - input_addr_(input_addr), - value_addr_(value_addr), - cond_(cond), - data_type_(data_type) {} - ~StreamSwitchIRTaskInfo() override {} - bool SerializeIRToProto() override; - - private: - uint32_t true_stream_id_ = 0; - uintptr_t input_addr_ = 0; - uintptr_t value_addr_ = 0; - uint32_t cond_ = 0; - int64_t data_type_ = 0; -}; - -class StreamActiveIRTaskInfo : public IRtaskInfo { - public: - StreamActiveIRTaskInfo(TaskType task_type, uint64_t stream_id, uint32_t active_stream_id) - : IRtaskInfo(task_type, STREAM_ACTIVE_TMP_DEF, stream_id), active_stream_id_(active_stream_id) {} - ~StreamActiveIRTaskInfo() override {} - bool SerializeIRToProto() override; - - private: - uint32_t active_stream_id_ = 0; -}; -}; // namespace generator -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_EXECUTOR_GENERATOR_IR_IR_TASK_H_ diff --git a/mindspore/ccsrc/predict/generator/utils/ir_model_util.cc b/mindspore/ccsrc/predict/generator/utils/ir_model_util.cc deleted file mode 100644 index 8128009472..0000000000 --- a/mindspore/ccsrc/predict/generator/utils/ir_model_util.cc +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/generator/utils/ir_model_util.h" -namespace mindspore { -namespace generator { -IRModelUtil &IRModelUtil::GetInstance() { - static IRModelUtil instance; - return instance; -} - -void IRModelUtil::Init() { - MS_LOG(INFO) << "IRModel init success"; - version_ = "defaultVersion"; - stream_num_ = 0; - event_num_ = 0; - batch_num_ = 0; - memory_size_ = 0; - weight_size_ = 0; - var_size_ = 0; - logic_mem_base_ = 0; - logic_var_base_ = 0; - logic_var_base_ = 0; - priority_ = 0; - is_enable_save_model_ = false; - min_static_offset_ = 0; - max_dynamic_offset_ = 0; -} -} // namespace generator -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/generator/utils/ir_model_util.h b/mindspore/ccsrc/predict/generator/utils/ir_model_util.h deleted file mode 100644 index a654cb980f..0000000000 --- a/mindspore/ccsrc/predict/generator/utils/ir_model_util.h +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_GENERATOR_IR_IR_MODEL_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_GENERATOR_IR_IR_MODEL_UTIL_H_ -#include -#include -#include -#include -#include -#include "utils/log_adapter.h" - -namespace mindspore { -namespace generator { -class IRModelUtil { - public: - static IRModelUtil &GetInstance(); - IRModelUtil(const IRModelUtil &) = delete; - IRModelUtil &operator=(const IRModelUtil &) = delete; - void Init(); - - void set_version(const std::string &version) { version_ = version; } - void set_stream_num(uint32_t stream_num) { stream_num_ = stream_num; } - void set_event_num(uint32_t event_num) { event_num_ = event_num; } - void set_batch_num(uint32_t batch_num) { batch_num_ = batch_num; } - void set_memory_size(uint32_t memory_size) { memory_size_ = memory_size; } - void set_weight_size(uint32_t weight_size) { weight_size_ = weight_size; } - void set_var_size(uint32_t var_size) { var_size_ = var_size; } - void set_logic_mem_base(uint32_t logic_mem_base) { logic_mem_base_ = logic_mem_base; } - void set_logic_weight_base(uint32_t logic_weight_base) { logic_weight_base_ = logic_weight_base; } - void set_logic_var_base(uint32_t logic_var_base) { logic_var_base_ = logic_var_base; } - void set_priority(uint32_t priority) { priority_ = priority; } - void set_is_enable_save_model(bool is_enable_save_model) { is_enable_save_model_ = is_enable_save_model; } - void set_min_static_offset(uint64_t min_static_offset) { min_static_offset_ = min_static_offset; } - void set_max_dynamic_offset(uint64_t max_dynamic_offset) { max_dynamic_offset_ = max_dynamic_offset; } - void set_max_mem_size(uint64_t max_mem_size) { max_mem_size_ = max_mem_size; } - void set_irmodel_mem_base(uint8_t irmodel_mem_base) { irmodel_mem_base_ = irmodel_mem_base; } - - std::string version() const { return version_; } - uint32_t stream_num() const { return stream_num_; } - uint32_t event_num() const { return event_num_; } - uint32_t batch_num() const { return batch_num_; } - uint64_t memory_size() const { return memory_size_; } - uint64_t weight_size() const { return weight_size_; } - uint64_t var_size() const { return var_size_; } - uint64_t logic_mem_base() const { return logic_mem_base_; } - uint64_t logic_weight_base() const { return logic_weight_base_; } - uint64_t logic_var_base() const { return logic_var_base_; } - uint32_t priority() const { return priority_; } - bool is_enable_save_model() const { return is_enable_save_model_; } - uint64_t min_static_offset() const { return min_static_offset_; } - uint64_t max_dynamic_offset() const { return max_dynamic_offset_; } - uint64_t max_mem_size() const { return max_mem_size_; } - uint8_t irmodel_mem_base() const { return irmodel_mem_base_; } - - private: - IRModelUtil() = default; - ~IRModelUtil() = default; - std::string version_; - uint32_t stream_num_ = 0; - uint32_t event_num_ = 0; - uint32_t batch_num_ = 0; - uint64_t memory_size_ = 0; - uint64_t weight_size_ = 0; - uint64_t var_size_ = 0; - uint64_t logic_mem_base_ = 0; - uint64_t logic_weight_base_ = 0; - uint64_t logic_var_base_ = 0; - uint32_t priority_ = 0; - bool is_enable_save_model_ = false; - uint64_t min_static_offset_ = 0; - uint64_t max_dynamic_offset_ = 0; - uint64_t max_mem_size_ = 0; - uint8_t irmodel_mem_base_ = 0; -}; -} // namespace generator -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_GENERATOR_IR_IR_MODEL_UTIL_H_ diff --git a/mindspore/ccsrc/predict/predict.cc b/mindspore/ccsrc/predict/predict.cc deleted file mode 100644 index bbb12c3787..0000000000 --- a/mindspore/ccsrc/predict/predict.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "predict/predict.h" - -#include -#include -#include - -namespace mindspore { -namespace predictmodel { -void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr) { - MS_LOG(INFO) << "start convert_graph step"; - // get kernel_graph. this graph can be origin or device, depends on which steps to persistence - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag(); - if (save_ms_model) { - if (kernel_graph_ptr->inputs().empty()) { - return; - } - // set convert_mode: convert cpu info or convert Davnici - executor::Kernel2Ms::GetInstance().set_convert_mode(executor::kConvertCpuMode); - // convert kernel_graph to sub_ms_graph - bool ret = executor::Kernel2Ms::GetInstance().KernelGraph2MsGraph(kernel_graph_ptr); - if (!ret) { - MS_LOG(WARNING) << "convert to mindsporeGraph failed"; - } else { - MS_LOG(INFO) << "convert to Graph success"; - } - } -} - -void StepConvertWeight(const std::vector &inputs) { - MS_LOG(INFO) << "start convert_input step"; - // get all inputs tensor - bool save_ms_model = MsContext::GetInstance()->save_ms_model_flag(); - std::string save_path = MsContext::GetInstance()->save_ms_model_path(); - if (save_ms_model) { - if (inputs.empty()) { - return; - } - MS_LOG(INFO) << "save ms model is true to path " << save_path; - if (!executor::Kernel2Ms::GetInstance().KernelInput2MS(inputs)) { - MS_LOG(WARNING) << "convert mindspore kernel input failed"; - } - auto new_ms_graph_ptr = std::make_shared(); - bool ret = executor::Kernel2Ms::GetInstance().SaveDeviceModel(new_ms_graph_ptr, save_path); - if (!ret) { - MS_LOG(WARNING) << "convert to mindsporeGraph failed"; - } else { - MS_LOG(INFO) << "save ms model success"; - } - } -} -} // namespace predictmodel -} // namespace mindspore diff --git a/mindspore/ccsrc/predict/predict.h b/mindspore/ccsrc/predict/predict.h deleted file mode 100644 index 9125451492..0000000000 --- a/mindspore/ccsrc/predict/predict.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PREDICT_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PREDICT_H_ - -#include -#include -#include "backend/session/session_basic.h" -#include "predict/converter/kernel2ms.h" - -namespace mindspore { -namespace predictmodel { -using KernelGraphPtr = std::shared_ptr; -void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr); -void StepConvertWeight(const std::vector &inputs); -} // namespace predictmodel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_H_ diff --git a/mindspore/ccsrc/predict/proto/DModel_ir.proto b/mindspore/ccsrc/predict/proto/DModel_ir.proto deleted file mode 100644 index 02bfa94df3..0000000000 --- a/mindspore/ccsrc/predict/proto/DModel_ir.proto +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -import public "Graph_ir.proto"; -import public "ge_runtime_taskinfo.proto"; -package ge.model_runner; -option cc_enable_arenas = true; - -message ModelTaskDef { - - string version = 1; - - repeated TaskDef task = 10; - - uint32 stream_num = 11; - uint32 event_num = 12; - uint32 batch_num_ = 13; - - uint64 memory_size = 14; - uint64 weight_size = 15; - uint64 var_size_ = 16; - - uint64 logic_mem_base_ = 17; - uint64 logic_weight_base_ = 18; - uint64 logic_var_base_ = 19; - - uint32 priority_ = 20; -} diff --git a/mindspore/ccsrc/predict/proto/Graph_ir.proto b/mindspore/ccsrc/predict/proto/Graph_ir.proto deleted file mode 100644 index af91ec0917..0000000000 --- a/mindspore/ccsrc/predict/proto/Graph_ir.proto +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package mindspore; - -// Data type definition -enum DataType { - DT_UNDEFINED = 0; - // Basic types. - DT_BOOL = 1; // bool - - DT_INT8 = 2; // int8_t - DT_INT16 = 3; // int16_t - DT_INT32 = 4; // int32_t - DT_INT64 = 5; // int64_t - - DT_UINT8 = 6; // uint8_t - DT_UINT16 = 7; // uint16_t - DT_UINT32 = 8; // uint32_t - DT_UINT64 = 9; // uint64_t - - DT_FLOAT16 = 10; // float 16 - DT_FLOAT32 = 11; // float 32 - DT_FLOAT64 = 12; // float 64 - - DT_STRING = 13; // string - DT_TENSOR = 14; // tensor - DT_GRAPH = 15; // graph - - // list type - DT_BOOLS = 16; // list of bool - - DT_INTS8 = 17; // list of int8_t - DT_INTS16 = 18; // list of int16_t - DT_INTS32 = 19; // list of int32_t - DT_INTS64 = 20; // list of int64_t - - DT_UINTS8 = 21; // list of uint8_t - DT_UINTS16 = 22; // list of uint16_t - DT_UINTS32 = 23; // list of uint32_t - DT_UINTS64 = 24; // list of uint64_t - - DT_FLOATS16 = 25; // list of float16 - DT_FLOATS32 = 26; // list of float32 - DT_FLOATS64 = 27; // list of float64 - - DT_STRINGS = 28; // list of string - DT_TENSORS = 29; // list of tensor - DT_GRAPHS = 30; // list of graph - - DT_TUPLE = 31; // tuple - DT_LIST = 32; // list - DT_DICT = 33; // dictionary - - // other types - DT_NONE = 34; // None - DT_SYM_INST = 35; // Symbolic Key Instance - - // type related type - DT_BASE_INT = 36; // type generic int - DT_BASE_UINT = 37; // type generate unsigned int - DT_BASE_FLOAT = 38; // type generate float - DT_TYPE = 39; // type type - DT_ANYTHING = 40; // type anything -}; - -enum MSConst { - DEFAULT_REFCOUNT = 0; - WEIGHT_REFCOUNT = 999; -}; - -message TensorDef { - DataType data_type = 1; - - repeated int64 dims = 2; - - string format = 3; - string layout = 4; - uint32 refCount = 5; - uint64 offset = 6; - uint64 size = 7; - uint64 weight_size = 8; - bytes data = 9; -} - -message OpDef { - string name = 1; - string type = 2; - - string fwk_type = 3; - string opAttr = 4; - repeated int64 input_index = 5; - repeated int64 output_index = 6; -} - -message GraphDef { - string name = 1; - - repeated int64 input_index = 2; - - repeated int64 output_index = 3; - uint64 mempool_size = 4; - - repeated OpDef opdefs = 5; - - repeated TensorDef alltensors = 6; -} - - - diff --git a/mindspore/ccsrc/predict/proto/ge_runtime_taskinfo.proto b/mindspore/ccsrc/predict/proto/ge_runtime_taskinfo.proto deleted file mode 100644 index 3429d06544..0000000000 --- a/mindspore/ccsrc/predict/proto/ge_runtime_taskinfo.proto +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package ge.model_runner; -option cc_enable_arenas = true; - -message TaskDef { - enum TaskType { - CCE = 0; - TBE = 1; - AICPU = 2; - LABEL_SET = 3; - LABEL_SWITCH = 4; - LABEL_GOTO = 5; - EVENT_RECORD = 6; - EVENT_WAIT = 7; - FUSION_START = 8; - FUSION_END = 9; - HCCL = 10; - PROFILER_TRACE = 11; - MEMCPY_ASYNC = 12; - STREAM_SWITCH = 13; - STREAM_ACTIVE = 14; - // insert new task type here - REVSERVED = 23; - }; - - TaskType task_type = 1; - uint64 stream_id = 2; - oneof subclass { - CceTaskDef cce_task_def = 3; - TbeTaskDef tbe_task_def = 4; - AicpuTaskDef aicpu_task_def = 5; - LabelTaskDef label_task_def = 6; - EventTaskDef event_task_def = 7; - HcclTaskDef hccl_task_def = 8; - ProfilerTaskDef profiler_task_def = 9; - MemcpyAsyncTaskDef memcpy_async_task_def = 10; - StreamSwitchTaskDef stream_switch_task_def = 11; - StreamActiveTaskDef stream_active_task_def = 12; - } -} - -message CceTaskDef { - KernelContext kernel_context = 1; - string stub_func = 2; - uint32 block_dim = 3; - bytes args = 4; - uint32 args_size = 5; - bytes sm_desc = 6; - bytes flow_table = 7; -} - -message TbeTaskDef { - string stub_func = 1; - uint32 block_dim = 2; - bytes args = 3; - uint32 args_size = 4; - bytes sm_desc = 5; - bytes meta_data = 8; - repeated uint64 input_addrs = 9; - repeated uint64 output_addrs = 10; - repeated uint64 workspace_addrs = 11; -} - -message AicpuTaskDef { - string op_type = 1; - uint32 flag = 2; - repeated uint32 input_types = 3; - repeated Shape input_shapes = 4; - repeated uint64 input_addrs = 5; - repeated uint32 output_types = 6; - repeated Shape output_shapes = 7; - repeated uint64 output_addrs = 8; - bytes node_def = 9; - bytes func_def = 10; -} - -message Shape { - repeated uint32 shape = 1; -} - -message LabelTaskDef { - uint32 label_id = 1; -} - -message EventTaskDef { - uint32 event_id = 1; -} - -message HcclTaskDef { - string hccl_type = 1; - uint64 input_addr = 2; - uint64 output_addr = 3; - bytes workspace = 4; - int64 workspace_num = 5; - bytes private_def = 6; - uint64 ops_kernel_store = 7; - int32 count = 8; - int64 root_id = 9; - int64 op_type = 10; - int64 data_type = 11; -} - -message ProfilerTaskDef { - uint64 log_id = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncTaskDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; -} - -message StreamSwitchTaskDef { - uint32 true_stream_id = 1; - uint64 input_addr = 2; - uint64 value_addr = 3; - int64 cond = 4; - int64 data_type = 5; -} - -message StreamActiveTaskDef { - uint32 active_stream_id = 1; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; - uint32 kernel_func_id = 3; - uint32 op_index = 4; - bool is_flowtable = 5; - bytes args_offset = 6; - uint32 args_count = 7; - repeated uint32 origin_op_index = 8; -} diff --git a/mindspore/ccsrc/predict/readme.txt b/mindspore/ccsrc/predict/readme.txt deleted file mode 100644 index d75abf257b..0000000000 --- a/mindspore/ccsrc/predict/readme.txt +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -this is a dictory for predict including saving model &&& saving taskinfos. diff --git a/mindspore/ccsrc/predict/schema/inner/readme.txt b/mindspore/ccsrc/predict/schema/inner/readme.txt deleted file mode 100644 index 774f71f602..0000000000 --- a/mindspore/ccsrc/predict/schema/inner/readme.txt +++ /dev/null @@ -1 +0,0 @@ -this is a dictory for predict to gen fbs headers \ No newline at end of file diff --git a/mindspore/ccsrc/predict/schema/ms.fbs b/mindspore/ccsrc/predict/schema/ms.fbs deleted file mode 100644 index 7c3dcfb498..0000000000 --- a/mindspore/ccsrc/predict/schema/ms.fbs +++ /dev/null @@ -1,212 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -include "op.fbs"; - -namespace mindspore.predict; - -enum MSCONST: int { - WEIGHT_REFCOUNT = 999 -} - -table QuantParam { - scale: double; - zeroPoint: int; - min: double = 0; - max: double = 0; - narrowRange: bool = true; - numBits: int = 8; -} - -table QuantParamArray { - param: [QuantParam]; //pre-channel -} - -table TensorDef { - // data type - dataType: DataType; - // shape - dims: [int]; - format: Format; - refCount: int; - offset: int; - data: [ubyte]; -} - -union OpT { - Concat, - SoftMax, - Activation, - Conv2D, - FusedBatchNorm, - CaffeBatchNorm, - BiasAdd, - Pooling, - DepthwiseConv2D, - DeDepthwiseConv2D, - Resize, - DetectionPostProcess, - FullConnection, - Mean, - DeConv2D, - Scale, - Reshape, - Eltwise, - NetOutput, - Add, - Sub, - MatMul, - StridedSlice, - Power, - Slice, - Stack, - Mul, - RealDiv, - Pad, - Maximum, - Minimum, - CaffePReLU, - LeakyReLU, - ArgMax, - ArgMin, - Exp, - CaffeCrop, - Range, - Rsqrt, - ExpandDims, - Tile, - Cast, - Shape, - Nchw2Nhwc, - Nhwc2Nchw, - QuantDTypeCast, - Split, - Permute, - FakeQuantWithMinMaxVars, - Equal, - Less, - Greater, - Min, - Floor, - Abs, - Neg, - Cos, - Sin, - Sqrt, - Square, - Constant, - Log, - Tan, - Atan, - Asin, - Clip, - Transpose, - Squeeze, - Unsqueeze, - Upsample, - Dropout, - Broadcast, - Lrn, - Prelu, - ZerosLike, - TopK, - SpaceToDepth, - SpaceToBatch, - SparseToDense, - ReverseSequence, - Rank, - Gather, - GatherNd, - Fill, - Elu, - DepthToSpace, - BatchToSpace, - AddN, - Ceil, - EmbeddingLookup, - EmbeddingLookupSparse, - FloorDiv, - FloorMod, - L2Norm, - LocalResponseNormalization, - MatrixDiag, - Reduce, - Reverse, - Round, - Select, - Scatter, - Unique, - Unstack, - LogicalAnd, - LogicalOr, - LogicalXor, - LogicalNot, - OnnxInt8Quantize, - OnnxInt8Dequantize, - FakeQuantWithMinMax, - FakeQuantWithMinMaxPerChannel, - BatchNormFold, - MulFold, - AddFold, - SquaredDifference -} - -enum QuantType: int { - QUANT_NONE, - AwareTrainning, - WeightQuant, - PostTraining -} - -enum FmkType: int { - TF, - CAFFE, - ONNX, - MS, - TFLITE -} - -table OpDef { - name: string; - fmkType: FmkType; - attr: OpT; - inputIndex: [uint]; - outputIndex: [uint]; - quantType: QuantType = QUANT_NONE; - quantParam: [QuantParamArray]; -} - -table SubGraphDef { - name: string; - inputIndex: [uint]; - outputIndex: [uint]; - mempoolSize: uint; - nodes: [OpDef]; - allTensors: [TensorDef]; // weight + input + output -} - -table MempoolCfg { - size: uint; - shiftFactor: uint; -} - -table GraphDef { - name: string; - mempoolCfg: MempoolCfg; - subgraphs: [SubGraphDef]; -} - -root_type GraphDef; diff --git a/mindspore/ccsrc/predict/schema/op.fbs b/mindspore/ccsrc/predict/schema/op.fbs deleted file mode 100644 index 9286c2b2d3..0000000000 --- a/mindspore/ccsrc/predict/schema/op.fbs +++ /dev/null @@ -1,699 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -namespace mindspore.predict; - -enum ResizeMethod: byte { - UNKNOW = -1, - BILINEAR = 0, - NEAREST_NEIGHBOR = 1 -} - -enum DataType : int { - DT_FLOAT = 0, - DT_FLOAT16 = 1, - DT_INT8 = 2, - DT_INT32 = 3, - DT_UINT8 = 4, - DT_INT16 = 5, - DT_UINT32 = 8, - DT_INT64 = 9, - DT_UINT16 = 10, - DT_UNDEFINED = 16 -} - -enum Format : int { - NCHW = 0, - NHWC, - HWKC, - HWCK, - KCHW, - CKHW, - KHWC, - CHWK, - NC4HW4 = 100, - NUM_OF_FORMAT -} - -enum ActivationType : byte { - NO_ACTIVATION = 0, - RELU = 1, - SIGMOID = 2, - RELU6 = 3, - ELU = 4, - LEAKY_RELU = 5, - ABS = 6, - RELU1 = 7, - SOFTSIGN = 8, - SOFTPLUS = 9, - TANH = 10, - SELU = 11, - HSWISH = 12, - HSIGMOID = 13, - THRESHOLDRELU = 14, - LINEAR = 15, - UNKNOW = 16 -} - -enum ReduceType : byte { - REDUCE_MAX = 0, - REDUCE_MEAN = 1, - REDUCE_ALL = 2, - REDUCE_ANY = 3, - REDUCE_LOG_SUM_EXP = 4, - REDUCE_PROD = 5, - REDUCE_SUM = 6, - UNKNOW = 7 -} - -enum PoolMode : byte { - MAX_POOLING = 0, - MEAN_POOLING = 1, -} - -enum EltwiseMode : byte { - PROD = 0, - SUM = 1, - MAXIMUM = 2, - UNKNOW = 3 -} - -enum PadMode : byte { - NOTSET = 0, - SAME = 1, - VALID = 2, - CAFFE = 4 -} - -enum RoundMode : byte { - FLOOR = 0, - CEIL = 1 -} - -enum PaddingMode : byte { - CONSTANT = 0, - REFLECT = 1, - SYMMETRIC = 2, - MODE_RESERVED = 3 -} - -table Pad { - paddingmode: PaddingMode; - paddings: [int]; -} - -table Maximum { -} - -table Minimum { -} - -table Concat { - axis: int; - n: int; -} - -table SoftMax { - axis: [int]; -} - -table Activation { - type: ActivationType = 0; -} - -table Conv2D { - format: Format = 0; - group: int; - channelIn: int; - channelOut: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - -table FusedBatchNorm { - epsilon: float = 0.00001; // eg. epsilon=0.001 - momentum: float = 0.9; - spatial: int = 1; -} - -table CaffeBatchNorm { - epsilon: float; // eg. epsilon=0.001 -} - -table Shape { -} - -table Nchw2Nhwc { - -} - -table Nhwc2Nchw { - -} - -table FakeQuantWithMinMaxVars { - narrowRange: bool; - numBits: int; -} - -table BiasAdd { - axis: [int]; -} - -table Pooling { - format: Format = 0; - poolingMode: PoolMode; - global: bool = false; - windowW: int; - windowH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - roundMode: RoundMode; -} - -table DepthwiseConv2D { - format: Format = 0; - channelIn: int; - channelMultiplier: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - -table DeDepthwiseConv2D { - format: Format = 0; - channelIn: int; - channelMultiplier: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - - -table Resize { - format: Format = 0; - method: ResizeMethod; - newHeight: long; - newWidth: long; - alignCorners: bool = false; - preserveAspectRatio: bool = false; -} - -table DetectionPostProcess { - format: Format = 0; - inputSize: int; - hScale: float; - wScale: float; - xScale: float; - yScale: float; - NmsIouThreshold: float; - NmsScoreThreshold: float; - MaxDetections: long; - DetectionsPreClass: long; - MaxClassesPreDetection: long; - NumClasses: long; - UseRegularNms: bool; -} - -table FullConnection { - hasBias: bool; - axis: int; -} - -// Mean(input_tensor, axis, keep_dims) -table Mean { - axis: [int]; - keepDims: bool = false; -} - -table DeConv2D { - format: Format = 0; - group: int; - channelIn: int; - channelOut: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - -table Scale { - format: Format = 0; -} - -table Eltwise { - mode: EltwiseMode; -} - -table Add { -} - -table Sub { -} - -table Mul { -} - -table RealDiv { -} - -table Rsqrt { -} - -table Equal { -} - -table Less { -} - -table Greater { -} - -table Min { -} - -table Slice { - format: Format = 0; - begin: [int]; - size: [int]; -} - -table Floor { -} - -table Abs { -} - -table Neg { -} - -table Exp { -} - -table Cos { -} - -table Sin { -} - -table Sqrt { -} - -table Square { -} - -table Ceil { -} - -table Log { -} - -table Tan { -} - -table Atan { -} - -table Asin { -} - -table Reshape { - format: Format = 0; - shape: [long]; -} - -table Power { - power: float; - scale: float; - shift: float; -} - -table ArgMax { - axis: int; - outMaxValue: bool; - topK: int = 1; - keepDims: bool; - axisType: int; -} - -table ArgMin { - axis: int; - outMaxValue: bool; - topK: int = 1; - keepDims: bool; - axisType: int; -} - -table NetOutput { -} - -table MatMul { - transposeA : bool = false; - transposeB : bool = false; -} - -table CaffePReLU { - channelShared : bool = false; -} - -table LeakyReLU { - negativeSlope: float; -} - -table StridedSlice { - beginMask: int; - endMask: int; - ellipsisMask: int; - newAxisMask: int; - shrinkAxisMask: int; - begin: [int]; - end: [int]; - stride: [int]; - isScale: [int]; -} - -table Stack { - axis: int; - n: int; - isScale: [int]; -} - -table Range { - dType: DataType; - start: int; - limit: int; - delta: int; -} - -table ExpandDims { - dim: int; -} - -table Tile { - multiples: [int]; -} - -table Cast { - srcT: int; - dstT: int; -} - -table QuantDTypeCast { - srcT: DataType; - dstT: DataType; -} - -table Split { - numberSplit: int; - sizeSplits: [int]; - splitDim: int; -} - -table CaffeCrop { - axis : long; - offsets : [long]; -} - -table Permute { - order: [long]; -} - -table Clip { - max: float; - min: float; -} - -table Constant { -} - - -table Elu { - alpha: float = 1.0; -} - -table Broadcast { -} - -table Lrn { - alpha: float = 0.0001; - beta: float = 0.75; - bias: float = 1.0; - size: int; -} - -enum ReduceMode : byte { - ReduceMean = 0, - ReduceMax = 1, - ReduceMin = 2, - ReduceProd = 3, - ReduceSum = 4, - ReduceSumSquare = 5 -} - -table Reduce { - axes: [int]; - keepDims: int; - mode: ReduceMode; -} - -table Prelu { - slope: [float]; -} - -table Transpose { - perm: [int]; - conjugate: bool = false; -} - -table Squeeze { - axis: [int]; -} - -table Unsqueeze { - axis: [int]; -} - -table Upsample { - mode: string; - scales: [float]; -} - -table Dropout { - ratio : float = 0.5; -} - -table LocalResponseNormalization { - depth_radius: int; - bias: float; - alpha: float; - beta: float; -} - -table ZerosLike { -} - -table TopK { - k : int; - sorted : bool = true; -} - -table SpaceToDepth { - blockSize : int; - format: Format = 0; -} - -table SpaceToBatch { - blockShape : [int]; - paddings : [int]; -} - -table SparseToDense { - validateIndices: bool; -} - -table ReverseSequence { - seqAxis: int; - batchAxis: int; -} - -table Rank { -} - - -table Gather { - axis: int; - batchDims: int; -} - -table GatherNd { - batchDims: int; -} - -table Fill { - dims: [int]; -} - -table DepthToSpace { - blockSize: int; - format: Format = 0; -} - - -table BatchToSpace { - blockShape: [int]; - crops: [int]; -} - -table AddN { - N: int; -} - - -table EmbeddingLookup { - ids: [int]; - maxNorm: float; -} - -table EmbeddingLookupSparse { - spIds: [int]; - spWeights: [float]; - //combiner: Combiner=0; - maxNortm: float; -} - -table FloorDiv { -} - -table FloorMod { -} - -table L2Norm { - axis: [int]; - epsilon: float; -} - -table LogicalAnd { -} - -table LogicalOr { -} - -table LogicalXor { -} - -table LogicalNot { -} - -table MatrixDiag { - k: int; - numRows: int; - numCols: int; - paddingValue: float; -} - -table Select { -} - -table TfReduce { - type: ReduceType = 7; -} - -table Reverse { - axis: [int]; -} - -table Round { -} - -table Scatter { -} - -table Unique { -} - -table Unstack { - num: int; - axis: int; -} - -table OnnxInt8Quantize { -} - -table OnnxInt8Dequantize { -} - -table FakeQuantWithMinMax { -} - -table FakeQuantWithMinMaxPerChannel { -} - -table BatchNormFold { -} - -table MulFold { -} - -table AddFold { -} - -table SquaredDifference { -} diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index 253e271e52..c8f988d4ef 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -16,23 +16,8 @@ #include "pybind_api/export_flags.h" namespace mindspore { - const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__"; -const char PYTHON_METAFUNCGRAPH_FLAG[] = "__metafuncgraph_flag__"; -const char PYTHON_TENSOR_FLAG[] = "__tensor_flag__"; -const char PYTHON_META_TENSOR_FLAG[] = "__meta_tensor_flag__"; -const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__"; -const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__"; const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__"; const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__"; - -// flag names -const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16"; -const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; -const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; -const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; -const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; -const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; -const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; - +const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__"; } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index 6ea584e66d..56e0a87ead 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -17,24 +17,12 @@ #ifndef PYBIND_API_EXPORT_FLAGS_H_ #define PYBIND_API_EXPORT_FLAGS_H_ +#include "utils/flags.h" namespace mindspore { - extern const char PYTHON_PRIMITIVE_FLAG[]; -extern const char PYTHON_METAFUNCGRAPH_FLAG[]; -extern const char PYTHON_TENSOR_FLAG[]; -extern const char PYTHON_META_TENSOR_FLAG[]; -extern const char PYTHON_ENVINSTANCE_FLAG[]; -extern const char PYTHON_DTYPE_FLAG[]; extern const char PYTHON_CELL_AS_LIST[]; extern const char PYTHON_DATACLASS_FIELDS[]; - -extern const char GRAPH_FLAG_MIX_PRECISION_FP16[]; -extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; -extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; -extern const char GRAPH_FLAG_HAS_EFFECT[]; -extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; -extern const char GRAPH_FLAG_RANDOM_EFFECT[]; -extern const char GRAPH_FLAG_SIDE_EFFECT[]; +extern const char PYTHON_CLASS_MEMBER_NAMESPACE[]; } // namespace mindspore #endif // PYBIND_API_EXPORT_FLAGS_H_ diff --git a/mindspore/ccsrc/pybind_api/pybind_patch.h b/mindspore/ccsrc/pybind_api/pybind_patch.h new file mode 100644 index 0000000000..a71774b26a --- /dev/null +++ b/mindspore/ccsrc/pybind_api/pybind_patch.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PYBIND_API_PYBIND_PATCH_H_ +#define PYBIND_API_PYBIND_PATCH_H_ + +namespace pybind11 { +PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) +} + +#endif // PYBIND_API_PYBIND_PATCH_H_ diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt index 9c95aee0dc..fac738ca69 100644 --- a/mindspore/ccsrc/runtime/device/CMakeLists.txt +++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt @@ -18,17 +18,20 @@ if (ENABLE_CPU) endif () if (ENABLE_MPI) - # _ms_mpi - file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc") - set_property(SOURCE ${MPI_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(mpi_adapter SHARED ${MPI_SRC_LIST}) - target_link_libraries(mpi_adapter PRIVATE mindspore::ompi) + if (ENABLE_CPU) + file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc") + set_property(SOURCE ${MPI_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(mpi_adapter SHARED ${MPI_SRC_LIST}) + target_link_libraries(mpi_adapter PRIVATE mindspore::ompi) + endif () - set_property(SOURCE "gpu/mpi/mpi_initializer.cc" - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc") - target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi) + if (ENABLE_GPU) + set_property(SOURCE "gpu/mpi/mpi_initializer.cc" + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc") + target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi) + endif () endif () # gpu @@ -55,6 +58,7 @@ if (ENABLE_GPU) PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS}) target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl) + target_link_libraries(_ms_mpi PRIVATE gpu_collective) endif () # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST}) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 1a87f3e6af..bfd8de81b0 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -16,16 +16,21 @@ #include "runtime/device/ascend/ascend_device_address.h" #include #include +#include +#include #include #include #include "runtime/mem.h" #include "runtime/device/kernel_runtime_manager.h" +#include "runtime/device/kernel_runtime.h" #include "runtime/device/convert_tensor_utils.h" #include "ir/dtype/type.h" #include "ir/tensor.h" #include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" #include "utils/utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/trans.h" #ifdef ENABLE_DUMP_E2E #include "debug/e2e_dump.h" @@ -34,6 +39,58 @@ #include "debug/tensor_load.h" #endif +namespace { +const std::unordered_map type_id_name_map = { + {mindspore::kNumberTypeBool, "bool"}, {mindspore::kNumberTypeInt8, "int8"}, + {mindspore::kNumberTypeInt16, "int16"}, {mindspore::kNumberTypeInt32, "int32"}, + {mindspore::kNumberTypeInt64, "int64"}, {mindspore::kNumberTypeFloat16, "float16"}, + {mindspore::kNumberTypeFloat32, "float32"}, {mindspore::kNumberTypeUInt8, "uint8"}, + {mindspore::kNumberTypeUInt16, "uint16"}, {mindspore::kNumberTypeUInt32, "uint32"}, + {mindspore::kNumberTypeUInt64, "uint64"}}; +const std::set> use_trans_data = { + std::make_pair("float16", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_NC1HWC0), + std::make_pair("bool", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_FRAC_Z), + std::make_pair("float16", mindspore::kOpFormat_FRAC_Z), std::make_pair("float16", mindspore::kOpFormat_FRAC_NZ), + std::make_pair("float32", mindspore::kOpFormat_FRAC_NZ), std::make_pair("int32", mindspore::kOpFormat_FRAC_NZ), + std::make_pair("float16", mindspore::kOpFormat_NHWC), std::make_pair("float32", mindspore::kOpFormat_NHWC), + std::make_pair("int8", mindspore::kOpFormat_NHWC), std::make_pair("int16", mindspore::kOpFormat_NHWC), + std::make_pair("int32", mindspore::kOpFormat_NHWC), std::make_pair("int64", mindspore::kOpFormat_NHWC), + std::make_pair("uint8", mindspore::kOpFormat_NHWC), std::make_pair("uint16", mindspore::kOpFormat_NHWC), + std::make_pair("uint32", mindspore::kOpFormat_NHWC), std::make_pair("uint64", mindspore::kOpFormat_NHWC), + std::make_pair("float16", mindspore::kOpFormat_HWCN), std::make_pair("float32", mindspore::kOpFormat_HWCN), + std::make_pair("int8", mindspore::kOpFormat_HWCN), std::make_pair("int16", mindspore::kOpFormat_HWCN), + std::make_pair("int32", mindspore::kOpFormat_HWCN), std::make_pair("int64", mindspore::kOpFormat_HWCN), + std::make_pair("uint8", mindspore::kOpFormat_HWCN), std::make_pair("uint16", mindspore::kOpFormat_HWCN), + std::make_pair("uint32", mindspore::kOpFormat_HWCN), std::make_pair("uint64", mindspore::kOpFormat_HWCN)}; +constexpr auto src_format = "src_format"; +constexpr auto dst_format = "dst_format"; +constexpr auto src = "src_0"; +constexpr auto dst = "dst"; +constexpr auto param_type_required = "required"; +constexpr auto gen_model_single = "single"; +constexpr auto trans_data = "trans_data"; +constexpr auto platform_tbe = "TBE"; +constexpr auto name = "name"; +constexpr auto valid = "valid"; +constexpr auto value = "value"; +constexpr auto dtype = "dtype"; +constexpr auto format_str = "format"; +constexpr auto ori_format = "ori_format"; +constexpr auto ori_shape = "ori_shape"; +constexpr auto param_type = "param_type"; +constexpr auto shape_str = "shape"; +constexpr auto process_aicore = "aicore"; +constexpr auto gen_model_str = "gen_model"; +constexpr auto impl_path_str = "impl_path"; +constexpr auto attrs_str = "attrs"; +constexpr auto inputs_str = "inputs"; +constexpr auto outputs_str = "outputs"; +constexpr auto kernel_name_str = "kernel_name"; +constexpr auto op_info_str = "op_info"; +constexpr auto platform_str = "platform"; +constexpr auto fractal_z = "FRACTAL_Z"; +} // namespace + namespace mindspore { namespace device { namespace ascend { @@ -96,11 +153,117 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s return true; } +DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->device_id(); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type); + return address_ptr; +} + +size_t GetCommonAlignSize(size_t input_size) { + return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; +} + +nlohmann::json ConstructAttrs(const std::string &format) { + nlohmann::json real_attr; + nlohmann::json src_attr; + nlohmann::json des_attr; + src_attr[name] = src_format; + src_attr[valid] = true; + if (format == kOpFormat_FRAC_Z) { + src_attr[value] = fractal_z; + } else { + src_attr[value] = format; + } + des_attr[name] = dst_format; + des_attr[valid] = true; + des_attr[value] = kOpFormat_NCHW; + real_attr.push_back(src_attr); + real_attr.push_back(des_attr); + return real_attr; +} + +nlohmann::json ConstructInputs(const std::vector &input_shape, const std::vector &output_shape, + const std::string &format, mindspore::TypeId type) { + nlohmann::json input; + nlohmann::json input_json; + nlohmann::json real_input; + real_input[dtype] = type_id_name_map.at(type); + if (format == kOpFormat_FRAC_Z) { + real_input[format_str] = fractal_z; + } else { + real_input[format_str] = format; + } + real_input[name] = src; + real_input[ori_format] = kOpFormat_NCHW; + for (auto shape : output_shape) { + real_input[ori_shape].push_back(shape); + } + real_input[param_type] = param_type_required; + // obtain inputs shape + for (auto shape : input_shape) { + real_input[shape_str].push_back(shape); + } + real_input[valid] = true; + input_json.push_back(real_input); + input.push_back(input_json); + return input; +} + +nlohmann::json ConstructOutputs(const std::vector &output_shape, mindspore::TypeId type) { + nlohmann::json output; + nlohmann::json output_json; + nlohmann::json real_output; + real_output[dtype] = type_id_name_map.at(type); + real_output[format_str] = kOpFormat_NCHW; + real_output[name] = dst; + real_output[ori_format] = kOpFormat_NCHW; + for (auto shape : output_shape) { + real_output[ori_shape].push_back(shape); + } + real_output[param_type] = param_type_required; + // obtain outputs shape + for (auto shape : output_shape) { + real_output[shape_str].push_back(shape); + } + real_output[valid] = true; + output_json.push_back(real_output); + output.push_back(output_json); + return output; +} + +nlohmann::json ConstructTransDataKernelJson(const std::vector &host_shape, + const std::vector &device_shape, const std::string &format, + mindspore::TypeId type) { + // generate kernel json + nlohmann::json kernel_json; + kernel_json[gen_model_str] = gen_model_single; + kernel_json[impl_path_str] = ""; + // construct op_info + nlohmann::json op_info; + op_info[attrs_str] = ConstructAttrs(format); + op_info[inputs_str] = ConstructInputs(device_shape, host_shape, format, type); + op_info[kernel_name_str] = ""; + op_info[name] = trans_data; + op_info[outputs_str] = ConstructOutputs(host_shape, type); + kernel_json[op_info_str] = op_info; + kernel_json[platform_str] = platform_tbe; + std::string json_str = kernel_json[op_info_str].dump(); + size_t hash_id = std::hash()(json_str); + const std::string op_name = op_info[name]; + const std::string json_name = op_name + "_" + std::to_string(hash_id); + kernel_json[op_info_str][kernel_name_str] = json_name; + return kernel_json; +} + void AscendDeviceAddress::SyncStream() const { MS_LOG(INFO) << "Start!"; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() != kPynativeMode) { + if (ms_context->execution_mode() != kPynativeMode && !ms_context->enable_pynative_infer()) { MS_LOG(INFO) << "Finish!"; return; } @@ -158,31 +321,171 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t return sync_ok; } +void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, + size_t output_size, const std::vector &workspace_size_list) const { + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + auto input_address = std::make_shared(); + MS_EXCEPTION_IF_NULL(input_address); + input_address->addr = ptr_; + input_address->size = size_; + auto output_address = std::make_shared(); + MS_EXCEPTION_IF_NULL(output_address); + output_address->addr = output_address_ptr; + output_address->size = output_size; + AddressPtrList kernel_inputs = {input_address}; + AddressPtrList kernel_outputs = {output_address}; + AddressPtrList kernel_workspaces; + std::vector workspace_address_ptr(workspace_size_list.size()); + if (!workspace_size_list.empty()) { + for (size_t i = 0; i < workspace_size_list.size(); ++i) { + auto workspace_size = GetCommonAlignSize(workspace_size_list[i]); + workspace_address_ptr[i] = AssignLaunchMemory(workspace_size, "", kTypeUnknown); + MS_EXCEPTION_IF_NULL(workspace_address_ptr[i]); + auto workspace_address = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace_address); + workspace_address->addr = workspace_address_ptr[i]->GetMutablePtr(); + workspace_address->size = workspace_address_ptr[i]->GetSize(); + kernel_workspaces.push_back(workspace_address); + } + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->device_id(); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + auto ret = + runtime_instance->LaunchTaskBasedOnSingleKernel(kernel_mod_ptr, kernel_inputs, kernel_outputs, kernel_workspaces); + if (!ret) { + MS_LOG(ERROR) << "Launch kernel failed."; + } + SyncStream(); +} + +kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const { + static std::set constructed_kernel; + auto build_manager = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manager); + std::string processor = process_aicore; + // get size + std::vector input_size_list; + std::vector output_size_list; + (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); + std::string json_name = kernel_json[op_info_str][kernel_name_str]; + // op build + if (constructed_kernel.find(json_name) == constructed_kernel.end()) { + auto task_id = build_manager->StartCompileOp(kernel_json); + build_manager->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list); + } + while (!build_manager->IsAllTaskFinish()) { + int task_id = -1; + std::string task_result; + std::string pre_build_result; + auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result); + if (!ret) { + MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; + } + if (task_result != "Success") { + MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; + } + (void)build_manager->TaskFinishProcess(task_id, false); + } + constructed_kernel.insert(json_name); + // search cache + auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor); + MS_EXCEPTION_IF_NULL(cached_kernel_pack); + auto kernel_mod_ptr = + build_manager->GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack); + return kernel_mod_ptr; +} + +bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector &host_shape, + const std::vector &device_shape, + size_t size, mindspore::TypeId type, + void *host_ptr) const { + bool sync_ok = true; + // construct trans data kernel json + nlohmann::json kernel_json = ConstructTransDataKernelJson(host_shape, device_shape, format_, type_id_); + MS_LOG(INFO) << "Construct trans_data kernel json: " << kernel_json.dump(); + auto kernel_mod_ptr = CompileTransDataAndObtainKernelMod(kernel_json); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + auto host_size = size; + if (type_id_ != type) { + auto device_dtype_size = trans::TypeIdSize(type_id_); + if (device_dtype_size < 1) { + MS_LOG(ERROR) << "Illegal dtype."; + } + auto shape_size = trans::ShapeSize(host_shape); + size = device_dtype_size * shape_size; + } + size = GetCommonAlignSize(size); + auto output_address = AssignLaunchMemory(size, kOpFormat_NCHW, type_id_); + MS_EXCEPTION_IF_NULL(output_address); + auto workspace_size_list = GetWorkspaceSizeList(kernel_json); + // launch + LaunchTransData(kernel_mod_ptr, output_address->GetMutablePtr(), output_address->GetSize(), workspace_size_list); + if (type_id_ == type) { + SyncMemory(host_ptr, output_address->GetPtr(), host_size, RT_MEMCPY_DEVICE_TO_HOST); + } else { + auto host = std::vector(size); + SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST); + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size}; + sync_ok = trans::TransDataType(type_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + } + return sync_ok; +} + +std::vector AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::json &kernel_json) const { + std::string json_name = kernel_json[op_info_str][kernel_name_str]; + std::string processor = process_aicore; + auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor); + MS_EXCEPTION_IF_NULL(cached_kernel_pack); + auto kernel_json_info = cached_kernel_pack->kernel_json_info(); + return kernel_json_info.workspaces; +} + +std::vector AscendDeviceAddress::GetDeviceShape(std::vector *host_shape) const { + std::vector device_shape; + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { + device_shape = trans::TransShapeToDevice(*host_shape, format_); + } else { + if (host_shape_.empty()) { + *host_shape = trans::PaddingShapeTo4d(*host_shape); + } else { + host_shape->clear(); + (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), IntToSize); + } + device_shape = trans::TransShapeToDevice(*host_shape, format_); + } + return device_shape; +} + bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, mindspore::TypeId type, void *host_ptr) const { MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; bool sync_ok = false; - auto host_tmp = std::vector(size_); - SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); std::vector host_shape; (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - std::vector device_shape; if (host_shape.empty()) { host_shape.emplace_back(1); } - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { - device_shape = trans::TransShapeToDevice(host_shape, format_); - } else { - if (host_shape_.empty()) { - host_shape = trans::PaddingShapeTo4d(host_shape); - } else { - host_shape.clear(); - (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize); + std::vector device_shape = GetDeviceShape(&host_shape); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode && type_id_name_map.find(type_id_) != type_id_name_map.end()) { + std::pair type_format = std::make_pair(type_id_name_map.at(type_id_), format_); + if (use_trans_data.find(type_format) != use_trans_data.end()) { + sync_ok = SyncDeviceToHostAndConvertFormatBasedOnTransData(host_shape, device_shape, size, type, host_ptr); + return sync_ok; } - - device_shape = trans::TransShapeToDevice(host_shape, format_); } + auto host_tmp = std::vector(size_); + SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); if (type_id_ != type) { const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; @@ -303,11 +606,6 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(ptr_) - kMemAlignSize; -} - AscendDeviceAddress::~AscendDeviceAddress() { if (ptr_ == nullptr) { return; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h index 78d7006b56..944d4bce7c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h @@ -14,15 +14,17 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ #include #include #include +#include #include "runtime/device/device_address.h" #include "runtime/device/ascend/ascend_memory_pool.h" #include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" namespace mindspore { #ifdef ENABLE_DEBUGGER @@ -39,7 +41,6 @@ class AscendDeviceAddress : public DeviceAddress { bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } - void UpdateCommunicationAddress() override; #ifdef ENABLE_DUMP_E2E bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, const std::vector &host_shape, TypeId host_type) const; @@ -54,11 +55,19 @@ class AscendDeviceAddress : public DeviceAddress { bool SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const; bool ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const; + bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector &host_shape, + const std::vector &device_shape, size_t size, + mindspore::TypeId type, void *host_ptr) const; void SyncStream() const; - uint8_t *communication_ptr_{nullptr}; + + void LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, size_t output_size, + const std::vector &workspace_size_list) const; + std::vector GetDeviceShape(std::vector *host_shape) const; + std::vector GetWorkspaceSizeList(const nlohmann::json &kernel_json) const; + kernel::KernelModPtr CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const; }; using AscendDeviceAddressPtr = std::shared_ptr; } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 3ab3a52d42..0408dac280 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#define PATH_MAX 0x3ffff +#define PATH_MAX 4096 #include "runtime/device/ascend/ascend_kernel_runtime.h" #include #include @@ -23,7 +23,8 @@ #include #include "runtime/device/ascend/ascend_device_address.h" #include "runtime/device/cpu/mpi/mpi_adapter.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" +#include "utils/context/context_extends.h" #include "utils/mpi/mpi_config.h" #include "runtime/device/ascend/profiling/profiling_manager.h" #include "hccl/hcom.h" @@ -37,7 +38,6 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/ascend/profiling/profiling_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" -#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" #include "backend/optimizer/mem_reuse/mem_reuse_checker.h" #include "runtime/device/ascend/ascend_memory_manager.h" #include "debug/tensor_load.h" @@ -49,6 +49,10 @@ using mindspore::device::ascend::tasksink::TaskGenerator; using mindspore::kernel::tbe::TbeUtils; using std::vector; +constexpr uint32_t kTupleTaskId = 0; +constexpr uint32_t kTupleStreamId = 1; +constexpr uint32_t kTupleArgs = 2; + namespace mindspore { namespace device { namespace ascend { @@ -91,13 +95,17 @@ std::string GetRankId() { AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } void AscendKernelRuntime::ClearGraphModelMap() { -#ifdef ENABLE_DATA_DUMP for (auto &iter : graph_data_dumper_) { MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first; - iter.second->UnloadDumpInfo(); + auto &data_dumper = iter.second; + MS_EXCEPTION_IF_NULL(data_dumper); + data_dumper->UnloadDumpInfo(); + data_dumper->OpDebugUnregister(); } graph_data_dumper_.clear(); -#endif + // tell users which dump kernel name not used + DataDumpParser::GetInstance().PrintUnusedKernel(); + for (auto &iter : graph_model_map_) { MS_LOG(INFO) << "Ge UnloadModel " << iter.first; auto ret = ModelRunner::Instance().UnloadModel(iter.first); @@ -108,18 +116,29 @@ void AscendKernelRuntime::ClearGraphModelMap() { } void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { - MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; - auto iter = graph_model_map_.find(graph_id); - if (iter == graph_model_map_.end()) { + MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper"; + if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) { + MS_LOG(DEBUG) << "Unload dump info " << graph_id; + auto &data_dumper = dumper_iter->second; + MS_EXCEPTION_IF_NULL(data_dumper); + data_dumper->UnloadDumpInfo(); + data_dumper->OpDebugUnregister(); + graph_data_dumper_.erase(dumper_iter); + } else { MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; - return; } - MS_LOG(DEBUG) << "Ge UnloadModel " << iter->first; - auto ret = ModelRunner::Instance().UnloadModel(iter->first); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; + + MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; + if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) { + MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id; + auto ret = ModelRunner::Instance().UnloadModel(graph_id); + if (!ret) { + MS_LOG(ERROR) << "UnloadModel failed"; + } + graph_model_map_.erase(model_iter); + } else { + MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; } - graph_model_map_.erase(iter); } bool AscendKernelRuntime::NeedDestroyHccl() { @@ -167,9 +186,7 @@ bool AscendKernelRuntime::Init() { } #endif -#ifdef ENABLE_DATA_DUMP DataDumpParser::GetInstance().ParseDumpConfig(); -#endif // Start up profiling before rtSetDevice ret = ProfilingManager::GetInstance().StartupProfiling(device_id_); @@ -272,7 +289,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p } // namespace #endif -bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { +bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) { MS_EXCEPTION_IF_NULL(graph); #ifdef ENABLE_DUMP_E2E MS_LOG(INFO) << "Start dump step"; @@ -502,17 +519,26 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { bool status = ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener); if (!status) { - MS_LOG(EXCEPTION) << "Load Task Failed"; + MS_LOG(EXCEPTION) << "Load Model Failed"; + } + + std::function model_handle = + std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first); + DistributeDebugTask(NOT_NULL(graph), NOT_NULL(model_handle)); + + status = ModelRunner::Instance().DistributeTask(model_iter->first); + if (!status) { + MS_LOG(EXCEPTION) << "Distribute Task Failed"; } + if (ProfilingManager::GetInstance().IsProfiling()) { auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first); auto stream_ids = ModelRunner::Instance().GetStreamIdList(model_iter->first); ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph)); } -#ifdef ENABLE_DATA_DUMP - LaunchDataDump(NOT_NULL(graph)); -#endif + LaunchDataDump(graph->graph_id()); + if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) { MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed"; return false; @@ -520,35 +546,40 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { return true; } -#ifdef ENABLE_DATA_DUMP -void AscendKernelRuntime::LaunchDataDump(NotNull graph) { +void AscendKernelRuntime::DistributeDebugTask(NotNull graph, + NotNull> model_handle) { if (!DataDumpParser::GetInstance().DumpEnabled()) { return; } - auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph->graph_id()); - auto data_dumper = std::make_shared(graph.get(), runtime_info_map); + auto data_dumper = std::make_shared(graph.get(), model_handle); MS_EXCEPTION_IF_NULL(data_dumper); - data_dumper->LoadDumpInfo(); auto ret = graph_data_dumper_.try_emplace(graph->graph_id(), data_dumper); + data_dumper->OpDebugRegister(); if (!ret.second) { MS_LOG(WARNING) << "[DataDump] Insert graphId:" << graph->graph_id() << " data dumper failed"; } } -#endif + +void AscendKernelRuntime::LaunchDataDump(GraphId graph_id) { + if (!DataDumpParser::GetInstance().DumpEnabled()) { + return; + } + auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph_id); + if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) { + auto &data_dumper = dumper_iter->second; + MS_EXCEPTION_IF_NULL(data_dumper); + data_dumper->set_runtime_info(runtime_info_map); + data_dumper->LoadDumpInfo(); + } else { + MS_LOG(EXCEPTION) << "GraphId:" << graph_id << " not found"; + } +} void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { - auto task_ids = ModelRunner::Instance().GetTaskIdList(graph_id); - auto graph_task_names = ProfilingUtils::graph_kernel_name(); - auto iter = graph_task_names.find(graph_id); - if (iter != graph_task_names.end()) { - const auto &task_names = iter->second; - if (task_ids.size() != task_names.size()) { - MS_LOG(WARNING) << "Task_ids and task_names size not match"; - return; - } - for (size_t i = 0; i < task_ids.size(); ++i) { - MS_LOG(INFO) << "Task_id:" << task_ids[i] << " task_name:" << task_names[i]; - } + auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph_id); + for (auto iter : runtime_info_map) { + MS_LOG(WARNING) << "Task name:" << iter.first << " task_id:" << std::get(*iter.second) + << " stream_id:" << std::get(*iter.second); } } @@ -658,7 +689,7 @@ bool AscendKernelRuntime::ResetDevice() { bool AscendKernelRuntime::HcclInit() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->IsTsdOpened()) { + if (!context::IsTsdOpened(context_ptr)) { MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; } MS_LOG(INFO) << "Do hcom init"; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 4f1663d4d5..995ee96c75 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ #include #include #include @@ -24,10 +24,8 @@ #include "framework/ge_runtime/davinci_model.h" #include "runtime/device/kernel_runtime_manager.h" #include "backend/session/session_basic.h" -#ifdef ENABLE_DATA_DUMP #include "debug/data_dump_parser.h" #include "runtime/device/ascend/dump/data_dumper.h" -#endif using ge::model_runner::TaskInfo; using std::unordered_map; @@ -40,7 +38,7 @@ class AscendKernelRuntime : public KernelRuntime { AscendKernelRuntime() = default; ~AscendKernelRuntime() override; bool Init() override; - bool DumpData(session::KernelGraph *graph) override; + bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override; bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; bool GenTask(const session::KernelGraph *graph) override; bool RunTask(const session::KernelGraph *graph) override; @@ -65,19 +63,18 @@ class AscendKernelRuntime : public KernelRuntime { bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; bool CheckGraphIdValid(GraphId graph_id) const; static void DebugTaskIdName(GraphId graph_id); + void DistributeDebugTask(NotNull graph, NotNull> model_handle); + void LaunchDataDump(GraphId graph_id); rtContext_t rt_context_{nullptr}; bool initialized_{false}; unordered_map>> task_map_; unordered_map> graph_model_map_; -#ifdef ENABLE_DATA_DUMP - void LaunchDataDump(NotNull graph); unordered_map> graph_data_dumper_; -#endif }; MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc index 035f4dd8e3..b15df0d60b 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc @@ -71,8 +71,7 @@ static void AssignLabelForLabelSet(NotNull memo->insert(graph.get()); MS_LOG(INFO) << "Assign label for " << graph->ToString(); - graph->SetExecOrderByDefault(); - auto nodes = graph->execution_order(); + const auto &nodes = graph->execution_order(); for (auto &node : nodes) { if (!node->isa()) { @@ -103,11 +102,7 @@ static void AssignLabelForGotoSwitch(NotNullToString(); - auto nodes = graph->execution_order(); - auto end_goto = graph->get_end_goto(); - if (end_goto != nullptr) { - nodes.push_back(end_goto); - } + const auto &nodes = graph->execution_order(); for (auto &node : nodes) { if (!node->isa()) { continue; @@ -115,20 +110,18 @@ static void AssignLabelForGotoSwitch(NotNullcast(); MS_EXCEPTION_IF_NULL(cnode); - std::string node_name = AnfAlgo::GetCNodeName(node); - if (node_name == kLabelGotoOpName) { + if (IsPrimitiveCNode(cnode, prim::kPrimLabelGoto)) { UpdateLabelGoto(NOT_NULL(cnode)); cnode->set_abstract(nullptr); } - if (node_name == kLabelSwitchOpName) { + if (IsPrimitiveCNode(cnode, prim::kPrimLabelSwitch)) { UpdateLabelSwitch(NOT_NULL(cnode)); } } for (auto &cg : graph->child_graph_order()) { AssignLabelForGotoSwitch(NOT_NULL(cg), memo); } - graph->SetExecOrderByDefault(); } void AscendLabelAssign::AssignLabel(NotNull> graph) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h index 6b09f2940e..43cadb8639 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ #include #include @@ -50,4 +50,4 @@ class AscendLabelAssign { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index f9da0850c6..0970fe01ed 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -16,7 +16,7 @@ #include #include "runtime/device/ascend/ascend_memory_manager.h" #include "runtime/device/ascend/ascend_memory_pool.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "runtime/mem.h" namespace mindspore { namespace device { @@ -24,17 +24,18 @@ namespace ascend { constexpr uint64_t kAscendDeviceMemGB = 30; constexpr uint64_t kMemSizeGB = 30; constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); +constexpr uint64_t kReservedMemorySize = 10 * 1024 * 1024; void AscendMemoryManager::MallocDeviceMemory() { auto context_mem = GetDeviceMemSizeFromContext(); device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem; - dynamic_mem_offset_ = device_mem_size_; - auto ret = rtMalloc(reinterpret_cast(&device_mem_base_), dynamic_mem_offset_, RT_MEMORY_HBM); + auto ret = rtMalloc(reinterpret_cast(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << dynamic_mem_offset_ << "] fail, ret[" << ret << "]"; + MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; } + dynamic_mem_offset_ = device_mem_size_ - kReservedMemorySize; AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_); AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); } @@ -79,7 +80,7 @@ void AscendMemoryManager::FreeDeviceMemory() { void AscendMemoryManager::ResetDynamicMemory() { total_dynamic_size_ = 0; - dynamic_mem_offset_ = device_mem_size_; + dynamic_mem_offset_ = device_mem_size_ - kReservedMemorySize; AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); } @@ -95,6 +96,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_me } else { align_size = GetCommonAlignSize(size); } + + auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); + MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ + << "] memory pool[" << device_mem_pool_offset << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + if (communication_mem) { // create protect area [kMemAlignSize -- data -- kMemAlignSize] uint8_t *alloc_address = reinterpret_cast(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); @@ -111,12 +118,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m } else { align_size = GetCommonAlignSize(size); } + + auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); + MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ + << "] memory pool[" << device_mem_pool_offset << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + if (dynamic_mem_offset_ < align_size) { MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ << "]) malloc [" << align_size << "] failed!"; } auto new_offset = dynamic_mem_offset_ - align_size; - auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); if (new_offset <= device_mem_pool_offset) { MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ << "] memory pool[" << device_mem_pool_offset << "])" @@ -127,9 +139,9 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); if (communication_mem) { // create protect area [kMemAlignSize -- data -- kMemAlignSize] - return device_mem_base_ + new_offset + kMemAlignSize; + return device_mem_base_ + dynamic_mem_offset_ + kMemAlignSize; } else { - return device_mem_base_ + new_offset; + return device_mem_base_ + dynamic_mem_offset_; } } } // namespace ascend diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h index 720f15be00..fc684f3fd8 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ #include "runtime/device/memory_manager.h" namespace mindspore { namespace device { @@ -43,4 +43,4 @@ class AscendMemoryManager : public MemoryManager { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc index fe71ba43fc..4715a72a20 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc @@ -23,7 +23,7 @@ namespace device { namespace ascend { size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (size == 0) { - MS_LOG(EXCEPTION) << "Can not alloc memory size(0) in memory pool !"; + MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!"; } if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) { MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ [" @@ -33,7 +33,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { *addr = device_mem_pool_base_ + device_mem_pool_offset_; device_mem_pool_offset_ += size; if (*addr == nullptr) { - MS_LOG(EXCEPTION) << "Alloc device address is nullptr, failed to alloc memory pool memory!"; + MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!"; } return size; } @@ -50,6 +50,8 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { return size; } +size_t AscendMemoryPool::mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE / 2; } + void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { MS_EXCEPTION_IF_NULL(device_mem_pool_base); device_mem_pool_base_ = device_mem_pool_base; @@ -62,9 +64,9 @@ void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_o uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; } size_t AscendMemoryPool::free_mem_size() { - if (graph_dynamic_mem_offset_ < device_mem_pool_offset_) { + if (graph_dynamic_mem_offset_ <= device_mem_pool_offset_) { MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_ - << "] less than device mem pool offset [" << device_mem_pool_offset_ << "]!"; + << "] less than or equal to device mem pool offset [" << device_mem_pool_offset_ << "]!"; } return graph_dynamic_mem_offset_ - device_mem_pool_offset_; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h index 7a75198ab4..439468fd6f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ #include #include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" @@ -46,6 +46,8 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { protected: // The real size by memory alloc aligned. size_t AlignMemorySize(size_t size) const override; + // Get the minimum memory unit size using for dynamic extend. + size_t mem_alloc_unit_size() const override; private: AscendMemoryPool() = default; @@ -57,4 +59,4 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 7cf5b94d45..57fe3090af 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -20,11 +20,10 @@ #include #include "ir/manager.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" +#include "utils/ms_context.h" +#include "utils/ms_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_adjust.h" -#include "predict/generator/utils/ir_model_util.h" #include "backend/optimizer/common/helper.h" #include "utils/utils.h" @@ -53,13 +52,6 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) GetStreamRelations(); PrintStreamGroups(); FindEventRelations(graph_ptr); - - // Get info for D Model - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num()); - generator::IRModelUtil::GetInstance().set_stream_num(resource_manager.get_cur_stream_num()); - // Init to 1,temporarily - generator::IRModelUtil::GetInstance().set_batch_num(1); } } @@ -672,10 +664,8 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull &graph_ptr) { CNodePtr cur_cnode_ptr = nullptr; auto cnode_ptr_list = graph_ptr->execution_order(); - // 1)first stream 0 should be actived first; - need_first_active_streams_.emplace_back(0); - // 2)stream witch kStreamNeedActivedFirst attr should be actived; + // 1)stream witch kStreamNeedActivedFirst attr should be actived; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); @@ -691,19 +681,25 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull &gra } } - // 3)independent stream:if has not been activate, push to need active vector + // 2)independent stream:if has not been activate, push to need active vector if (!independent_stream_activated_) { for (auto &item : independent_stream_map_) { need_first_active_streams_.emplace_back(item.first); } } - // 4)hcom stream:if has not been activate, push to need active vector + // 3)hcom stream:if has not been activate, push to need active vector if (!hcom_stream_activated_) { for (auto &item : hcom_stream_map_) { need_first_active_streams_.emplace_back(item.first); } } + + // 4)first stream 0 should be actived first; + auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), 0); + if (it == need_first_active_streams_.end()) { + need_first_active_streams_.emplace_back(0); + } } // section8 @@ -958,7 +954,7 @@ void AscendStreamAssign::DFS(uint32_t start, std::vector *group) { if (!IsVecExist(group)) { stream_groups_.emplace_back(*group); } else { - MS_LOG(WARNING) << "DFS should not print this log"; + MS_LOG(WARNING) << "DFS find same stream group, Not expected"; } return; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index 00fca60e8d..1ee3cd6104 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ #include #include @@ -182,4 +182,4 @@ class AscendStreamAssign { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc index ab2c6b2748..0a6ac52268 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifdef ENABLE_DATA_DUMP #include "runtime/device/ascend/dump/data_dumper.h" #include @@ -23,36 +22,53 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/mem.h" #include "runtime/kernel.h" +#include "runtime/rt_model.h" #include "runtime/device/ascend/dump/ge_dump.h" #include "proto/op_mapping_info.pb.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "debug/data_dump_parser.h" -constexpr uint32_t kAicpuLoadFlag = 1; -constexpr uint32_t kAicpuUnloadFlag = 0; -constexpr uint32_t kTupleTaskId = 0; -constexpr uint32_t kTupleStreamId = 1; -constexpr uint32_t kTupleArgs = 2; -constexpr uint32_t kCurrentStepTensorIndex = 0; -constexpr uint32_t kCurrentEpochTensorIndex = 1; -constexpr uint32_t kStepsPerEpochTensorIndex = 2; +static constexpr uint32_t kAicpuLoadFlag = 1; +static constexpr uint32_t kAicpuUnloadFlag = 0; +static constexpr uint32_t kTupleTaskId = 0; +static constexpr uint32_t kTupleStreamId = 1; +static constexpr uint32_t kTupleArgs = 2; +static constexpr uint32_t kCurrentStepTensorIndex = 0; +static constexpr uint32_t kCurrentEpochTensorIndex = 1; +static constexpr uint32_t kStepsPerEpochTensorIndex = 2; +static constexpr uint64_t kOpDebugShape = 2048; +static constexpr uint64_t kOpDebugHostMemSize = 2048; +static constexpr uint64_t kOpDebugDevMemSize = sizeof(void *); +static constexpr uint8_t kNoOverflow = 0; +static constexpr uint8_t kAiCoreOverflow = (0x1 << 0); +static constexpr uint8_t kAtomicOverflow = (0x1 << 1); +static constexpr uint8_t kAllOverflow = (kAiCoreOverflow | kAtomicOverflow); +static const std::map kOverflowModeStr = {{kNoOverflow, "NoOverflow"}, + {kAiCoreOverflow, "AiCoreOverflow"}, + {kAtomicOverflow, "AtomicOverflow"}, + {kAllOverflow, "AllOverflow"}}; +constexpr const char *kNodeNameOpDebug = "Node_OpDebug"; +constexpr const char *kOpTypeOpDebug = "Opdebug"; namespace mindspore { namespace device { namespace ascend { -void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task); -void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task); -void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr); +static void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task); +static void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task); +static void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr); DataDumper::~DataDumper() { ReleaseDevMem(&dev_load_mem_); ReleaseDevMem(&dev_unload_mem_); + ReleaseDevMem(&op_debug_buffer_addr_); + ReleaseDevMem(&op_debug_dump_args_); } void DataDumper::LoadDumpInfo() { MS_LOG(INFO) << "[DataDump] LoadDumpInfo start"; MS_EXCEPTION_IF_NULL(kernel_graph_); aicpu::dump::OpMappingInfo dump_info; + SetOpDebugMappingInfo(NOT_NULL(&dump_info)); SetOpMappingInfo(NOT_NULL(&dump_info)); auto kernels = kernel_graph_->execution_order(); @@ -63,6 +79,7 @@ void DataDumper::LoadDumpInfo() { } MS_LOG(INFO) << "[DataDump] LoadDumpInfo kernel:" << kernel->fullname_with_scope(); dump_kernel_names_.emplace_back(kernel->fullname_with_scope()); + DataDumpParser::GetInstance().MatchKernel(kernel->fullname_with_scope()); aicpu::dump::Task task; ConstructDumpTask(NOT_NULL(kernel), NOT_NULL(&task)); @@ -71,6 +88,8 @@ void DataDumper::LoadDumpInfo() { } RtLoadDumpData(dump_info, &dev_load_mem_); load_flag_ = true; + // graph id may changed in Unload + graph_id_ = kernel_graph_->graph_id(); MS_LOG(INFO) << "[DataDump] LoadDumpInfo end"; } @@ -83,7 +102,7 @@ void DataDumper::SetOpMappingInfo(NotNull dump_inf MS_LOG(EXCEPTION) << "Dump path invalid"; } auto device_id = context_ptr->device_id(); - dump_info->set_dump_path(dump_path.value() + "_" + std::to_string(device_id) + "/"); + dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/"); MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value(); dump_info->set_model_name(DataDumpParser::GetInstance().net_name() + "_" + std::to_string(kernel_graph_->graph_id())); @@ -107,9 +126,9 @@ void DataDumper::SetOpMappingInfo(NotNull dump_inf MS_EXCEPTION_IF_NULL(currnet_epoch_tensor->device_address()); MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor->device_address()); - void *current_step = current_step_tensor->device_address()->ptr_; - void *current_epoch = currnet_epoch_tensor->device_address()->ptr_; - void *steps_per_epoch = steps_per_epoch_tensor->device_address()->ptr_; + void *current_step = current_step_tensor->device_address()->GetMutablePtr(); + void *current_epoch = currnet_epoch_tensor->device_address()->GetMutablePtr(); + void *steps_per_epoch = steps_per_epoch_tensor->device_address()->GetMutablePtr(); if (current_epoch != nullptr && current_step != nullptr && steps_per_epoch != nullptr) { dump_info->set_step_id_addr(reinterpret_cast(current_epoch)); @@ -132,14 +151,13 @@ bool DataDumper::KernelNeedDump(const CNodePtr &kernel) const { void DataDumper::UnloadDumpInfo() { if (!load_flag_) { - MS_LOG(WARNING) << "Load not success, no need to unload"; + MS_LOG(WARNING) << "[DataDump] Load not success, no need to unload"; return; } - MS_EXCEPTION_IF_NULL(kernel_graph_); - MS_LOG(INFO) << "[DataDump] UnloadDumpInfo start. graphId:" << kernel_graph_->graph_id(); + MS_LOG(INFO) << "[DataDump] UnloadDumpInfo start. graphId:" << graph_id_; aicpu::dump::OpMappingInfo op_mapping_info; - op_mapping_info.set_model_id(kernel_graph_->graph_id()); + op_mapping_info.set_model_id(graph_id_); op_mapping_info.set_flag(kAicpuUnloadFlag); for (const auto &kernel_name : dump_kernel_names_) { @@ -193,6 +211,84 @@ void DataDumper::ConstructDumpTask(NotNull kernel, NotNull dump_info) const { + MS_LOG(INFO) << "[DataDump] Add op debug info to OpMappingInfo, task id = " << debug_task_id_ + << ", stream id = " << debug_stream_id_; + aicpu::dump::Task task; + task.set_end_graph(false); + task.set_task_id(debug_task_id_); + task.set_stream_id(debug_stream_id_); + task.mutable_op()->set_op_name(kNodeNameOpDebug); + task.mutable_op()->set_op_type(kOpTypeOpDebug); + + aicpu::dump::Output output; + output.set_data_type(ge::proto::DataType::DT_UINT8); + output.set_format(GeFormat::kFormat_ND); + + output.mutable_shape()->add_dim(kOpDebugShape); + + output.set_original_name(kNodeNameOpDebug); + output.set_original_output_index(0); + output.set_original_output_format(GeFormat::kFormat_ND); + output.set_original_output_data_type(ge::proto::DataType::DT_UINT8); + // due to lhisi virtual addr bug, cannot use args now + output.set_address(static_cast(reinterpret_cast(op_debug_dump_args_))); + output.set_size(kOpDebugHostMemSize); + + task.mutable_output()->Add(std::move(output)); + dump_info->mutable_task()->Add(std::move(task)); +} + +void DataDumper::OpDebugRegister() { + uint32_t op_debug_mode = DataDumpParser::GetInstance().op_debug_mode(); + auto iter = kOverflowModeStr.find(op_debug_mode); + if (iter == kOverflowModeStr.end()) { + MS_LOG(EXCEPTION) << "Invalid op debug mode " << op_debug_mode; + } + MS_LOG(INFO) << "[DataDump] Op debug mode is " << iter->second; + if (op_debug_mode == kNoOverflow) { + return; + } + + rtError_t rt_ret = rtMalloc(&op_debug_buffer_addr_, kOpDebugHostMemSize, RT_MEMORY_DDR); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMalloc failed, ret = " << rt_ret; + } + + rt_ret = rtMalloc(&op_debug_dump_args_, kOpDebugDevMemSize, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMalloc failed, ret = " << rt_ret; + } + + rt_ret = + rtMemcpy(op_debug_dump_args_, sizeof(void *), &op_debug_buffer_addr_, sizeof(void *), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMemcpy failed, ret = " << rt_ret; + } + + rt_ret = rtDebugRegister(model_handle_(), op_debug_mode, op_debug_buffer_addr_, &debug_stream_id_, &debug_task_id_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtDebugRegister failed, ret = " << rt_ret; + } + + MS_LOG(INFO) << "[DataDump] Distribute op debug task, task id = " << debug_task_id_ + << ", stream id = " << debug_stream_id_; +} + +void DataDumper::OpDebugUnregister() { + uint32_t op_debug_mode = DataDumpParser::GetInstance().op_debug_mode(); + if (op_debug_mode == kNoOverflow) { + MS_LOG(INFO) << "[DataDump] Op debug mode is no overflow, no need to unregister."; + return; + } + + MS_LOG(INFO) << "[DataDump] Start."; + rtError_t rt_ret = rtDebugUnRegister(model_handle_()); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtDebugUnRegister failed, ret = " << rt_ret; + } +} + void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr) { std::string proto_str; size_t proto_size = dump_info.ByteSizeLong(); @@ -241,6 +337,11 @@ void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull(reinterpret_cast(args)) + offset); + // device address data size + auto address = AnfAlgo::GetOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + output.set_size(address->GetSize()); + MS_LOG(INFO) << "[DataDump] output " << i << " address size:" << output.size(); MS_EXCEPTION_IF_NULL(task->mutable_output()); task->mutable_output()->Add(std::move(output)); offset += sizeof(void *); @@ -271,6 +372,11 @@ void DumpKernelInput(const CNodePtr &kernel, void *args, NotNulladd_dim(dim); } input.set_address(static_cast(reinterpret_cast(args)) + offset); + // device address data size + auto address = AnfAlgo::GetPrevNodeOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + input.set_size(address->GetSize()); + MS_LOG(INFO) << "[DataDump] input " << i << " address size:" << input.size(); MS_EXCEPTION_IF_NULL(task->mutable_input()); task->mutable_input()->Add(std::move(input)); offset += sizeof(void *); @@ -279,4 +385,3 @@ void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull #include #include #include #include +#include #include "backend/session/kernel_graph.h" namespace aicpu { @@ -37,27 +37,42 @@ namespace ascend { using RuntimeInfo = std::tuple; class DataDumper { public: - DataDumper(const session::KernelGraph *kernel_graph, - const std::map> &runtime_info_map) - : load_flag_(false), + DataDumper(const session::KernelGraph *kernel_graph, NotNull> model_handle) + : model_handle_(model_handle), + debug_task_id_(-1), + debug_stream_id_(-1), + op_debug_buffer_addr_(nullptr), + op_debug_dump_args_(nullptr), + load_flag_(false), dev_load_mem_(nullptr), dev_unload_mem_(nullptr), - kernel_graph_(kernel_graph), - runtime_info_map_(runtime_info_map) {} + graph_id_(UINT32_MAX), + kernel_graph_(kernel_graph) {} ~DataDumper(); + void set_runtime_info(const std::map> &runtime_info) { + runtime_info_map_ = runtime_info; + } void LoadDumpInfo(); - void UnloadDumpInfo(); + void OpDebugRegister(); + void OpDebugUnregister(); private: void ReleaseDevMem(void **ptr) const; bool KernelNeedDump(const CNodePtr &kernel) const; void SetOpMappingInfo(NotNull dump_info) const; + void SetOpDebugMappingInfo(NotNull dump_info) const; void ConstructDumpTask(NotNull kernel, NotNull dump_task) const; + std::function model_handle_; + uint32_t debug_task_id_; + uint32_t debug_stream_id_; + void *op_debug_buffer_addr_; + void *op_debug_dump_args_; bool load_flag_; void *dev_load_mem_; void *dev_unload_mem_; + uint32_t graph_id_; std::vector dump_kernel_names_; const session::KernelGraph *kernel_graph_; std::map> runtime_info_map_; @@ -65,5 +80,4 @@ class DataDumper { } // namespace ascend } // namespace device } // namespace mindspore -#endif -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DUMP_DATADUMP_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h b/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h index eae70c4b0b..60af609a48 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h +++ b/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_GE_DUMP_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_GE_DUMP_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DUMP_GE_DUMP_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DUMP_GE_DUMP_H_ #include #include @@ -117,4 +117,4 @@ static GeFormat GetGeFormat(const std::string &format, size_t shape_size) { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_GE_DUMP_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DUMP_GE_DUMP_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h index 0d2870eb0a..b478f59c14 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ #include "backend/session/kernel_graph.h" @@ -39,4 +39,4 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph); } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index e8fc6c7a98..f4a135ce70 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -24,11 +24,11 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "debug/anf_ir_dump.h" #include "frontend/operator/ops.h" #include "ir/func_graph.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" #include "backend/kernel_compiler/common_utils.h" @@ -492,10 +492,14 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { continue; } + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown && + AnfAlgo::OutputAddrExist(real_input_node, 0)) { + continue; + } if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; builder->SetOutputsDeviceType(output_type); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); } @@ -546,6 +550,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type); auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); // If aicore not find valid kernel info reloading aicpu kernel info list to find it + if (select_status == kNoMatched) { MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h index 8a93b77cec..82bf5e4f75 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ #include "ir/anf.h" #include "backend/kernel_compiler/kernel_build_info.h" namespace mindspore { @@ -35,4 +35,4 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h index bf4977bf9a..0150c5b41b 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PLUGIN_IMPL_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PLUGIN_IMPL_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PLUGIN_IMPL_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PLUGIN_IMPL_H_ #include #include "./prof_engine.h" @@ -42,4 +42,4 @@ class PluginImpl : public PluginIntf { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PLUGIN_IMPL_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PLUGIN_IMPL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h index c7cbc4b7dd..cdb175fe5c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_ENGINE_IMPL_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_ENGINE_IMPL_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_ENGINE_IMPL_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_ENGINE_IMPL_H_ #include "./prof_engine.h" @@ -36,4 +36,4 @@ class ProfilingEngineImpl : public EngineIntf { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_ENGINE_IMPL_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_ENGINE_IMPL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc index 6117fe5ecf..18d67d5253 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc @@ -22,8 +22,8 @@ #include "runtime/device/ascend/profiling/plugin_impl.h" #include "runtime/device/ascend/profiling/profiling_engine_impl.h" #include "utils/log_adapter.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" +#include "utils/ms_context.h" +#include "utils/ms_utils.h" #include "utils/convert_utils.h" #include "runtime/base.h" @@ -144,17 +144,17 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { nlohmann::json startCfg; startCfg["startCfg"] = devices; - if (!ProfStartUp(NOT_NULL(&startCfg))) { + if (!ProfStartUp(startCfg)) { MS_LOG(ERROR) << "ProfMgrStartUp failed."; return false; } return true; } -bool ProfilingManager::ProfStartUp(NotNull startCfg) { +bool ProfilingManager::ProfStartUp(const nlohmann::json &startCfg) { // convert json to string std::stringstream ss; - ss << *startCfg; + ss << startCfg; std::string cfg = ss.str(); MS_LOG(INFO) << "profiling config " << cfg; auto ret = rtProfilerStart(); @@ -181,7 +181,8 @@ bool ProfilingManager::StopProfiling() { } Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); if (reporter != nullptr) { - MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); + auto ret = reporter->Flush(); + MS_LOG(INFO) << "report data end, ret = " << ret; } auto rt_ret = rtProfilerStop(); diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h index 05b5248996..26118bb27c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_MANAGER_H_ #include #include @@ -22,7 +22,7 @@ #include #include #include "utils/contract.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" using std::map; using std::string; @@ -49,7 +49,7 @@ class ProfilingManager { ~ProfilingManager() { prof_handle_ = nullptr; } private: - bool ProfStartUp(NotNull json); + bool ProfStartUp(const nlohmann::json &json); std::shared_ptr engine_0_; uint32_t device_id_; void *prof_handle_; @@ -58,4 +58,4 @@ class ProfilingManager { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc index 5b1db6a404..ed96f6ed2a 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc @@ -19,10 +19,10 @@ #include "backend/kernel_compiler/kernel.h" #include "runtime/device/ascend/profiling/profiling_manager.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/utils.h" #include "runtime/device/ascend/profiling/reporter/task_desc_reporter.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "runtime/device/ascend/profiling/reporter/point_reporter.h" namespace mindspore { @@ -167,7 +167,7 @@ std::string ProfilingUtils::GetGraphLastTbeKernelName(const std::vectorfullname_with_scope(); break; } @@ -319,7 +319,7 @@ void ProfilingUtils::SetGraphProfilingCNode(uint32_t graph_id, const std::vector bool ProfilingUtils::ValidComputeGraph(NotNull graph_ptr) { for (const auto &node : graph_ptr->execution_order()) { - if (AnfAlgo::GetKernelType(node) == TBE_KERNEL) { + if (AnfAlgo::GetKernelType(node) == TBE_KERNEL || AnfAlgo::GetKernelType(node) == AKG_KERNEL) { return true; } } diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h index de8ff2ac39..468e8f8394 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ #include #include @@ -139,4 +139,4 @@ class ProfilingUtils { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h index f25c64ce05..6b6ee9e992 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ #include #include @@ -47,4 +47,4 @@ class DescReporter { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h index 531f122cde..2ae9ff45ff 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ #include #include @@ -38,4 +38,4 @@ class GraphDescReporter : public DescReporter { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h index c24535f4ec..a5c806020d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ #include #include @@ -34,4 +34,4 @@ class PointReporter : public DescReporter { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h index 6d0ed45bef..9222d9e096 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_PROFILING_DESC_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_PROFILING_DESC_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_PROFILING_DESC_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_PROFILING_DESC_H_ #include #include @@ -85,4 +85,4 @@ class PointDesc : public ProfDesc { } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_PROFILING_DESC_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_PROFILING_DESC_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h index 51526735a9..a92aa1cb30 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ #include #include @@ -43,4 +43,4 @@ class TaskDescReporter : public DescReporter { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h index 874de54e8a..b353be681f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_ #include #include "runtime/rt.h" @@ -36,4 +36,4 @@ class RuntimeUtils { } // namespace ascend } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc index 5aeb932105..2b419367f9 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -18,8 +18,8 @@ #include #include "backend/kernel_compiler/task_stream.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" +#include "utils/ms_context.h" +#include "utils/ms_utils.h" #include "runtime/device/ascend/profiling/profiling_utils.h" #include "runtime/device/ascend/profiling/profiling_manager.h" diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h index 134dec48b6..e4206ac5a7 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASK_TASK_BUILD_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASK_TASK_BUILD_H_ #include #include @@ -58,4 +58,4 @@ class TaskGenerator { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASK_TASK_BUILD_H_ diff --git a/mindspore/ccsrc/runtime/device/convert_tensor_utils.h b/mindspore/ccsrc/runtime/device/convert_tensor_utils.h index ba1f221ac4..9636c9f8a4 100644 --- a/mindspore/ccsrc/runtime/device/convert_tensor_utils.h +++ b/mindspore/ccsrc/runtime/device/convert_tensor_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CONVERT_TENSOR_UTILS_H_ -#define MINDSPORE_CCSRC_DEVICE_CONVERT_TENSOR_UTILS_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CONVERT_TENSOR_UTILS_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CONVERT_TENSOR_UTILS_H_ #include #include @@ -30,4 +30,4 @@ void FloatToDouble(void *dst, const void *src, size_t elem_num); } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CONVERT_TENSOR_UTILS_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CONVERT_TENSOR_UTILS_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc index 92269233bd..c2131a541e 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc @@ -52,6 +52,11 @@ bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size bool CPUDeviceAddress::SyncHostToDevice(const std::vector & /*shape*/, size_t size, TypeId type, const void *host_ptr) const { + if (host_ptr == ptr_) { + MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; + return true; + } + if (type == kNumberTypeFloat16) { HalfToFloat(ptr_, host_ptr, size / 2); } else if (type == kNumberTypeFloat64) { diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h index 63cf171fa2..c06e2915e0 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ #include #include @@ -40,4 +40,4 @@ class CPUDeviceAddress : public DeviceAddress { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index d2e41a1fbd..1b44394ce0 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -20,14 +20,14 @@ #include #include #include -#include +#include #include #include "backend/kernel_compiler/kernel.h" #include "runtime/device/cpu/cpu_device_address.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/config_manager.h" #include "utils/profile.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/session_basic.h" #include "frontend/operator/ops.h" @@ -36,29 +36,11 @@ namespace mindspore { namespace device { namespace cpu { const size_t INIT_NODE_REF = 1; -namespace { -TypeId GetCPUSupportOutputTypeId(const TypeId type_id) { - TypeId support_type_id = type_id; - if (type_id == kNumberTypeUInt32) { - support_type_id = kNumberTypeInt32; - } - if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 || - type_id == kNumberTypeFloat64) { - support_type_id = kNumberTypeFloat32; - } - if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) { - MS_LOG(EXCEPTION) << "Check output type failed."; - } - return support_type_id; -} -} // namespace - void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { AssignValueNodeAddress(kernel_graph); AssignInputNodeAddress(kernel_graph); AssignKernelOutputAddress(kernel_graph); - resource_manager_.MemPlan(kernel_graph); - resource_manager_.MemMalloc(kernel_graph); + resource_manager_.AssignMemory(kernel_graph); } void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) { @@ -142,11 +124,10 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t return std::make_shared(device_ptr, device_size, format, type_id); } -tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, +tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, + size_t index, std::vector *need_sync_outputs) { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(bound_addresses); MS_EXCEPTION_IF_NULL(need_sync_outputs); size_t output_size = AnfAlgo::GetOutputTensorNum(node); if (index >= output_size) { @@ -154,28 +135,44 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s } auto address = AnfAlgo::GetMutableOutputAddr(node, index); MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, index); - std::vector temp_shape; - (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); - TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); - type_id = GetCPUSupportOutputTypeId(type_id); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - MS_EXCEPTION_IF_NULL(tensor); - if (bound_addresses->find(address) != bound_addresses->end()) { + + TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index); + TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index); + tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index); + if (tensor == nullptr) { + auto shape = AnfAlgo::GetOutputInferShape(node, index); + std::vector temp_shape; + (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); + tensor = std::make_shared(infer_type_id, temp_shape); + bool is_internal_output = kernel_graph->IsInternalOutput(node, index); + if (is_internal_output) { + kernel_graph->AddInternalOutputTensor(node, index, tensor); + } + } + if (bound_addresses_.find(address) != bound_addresses_.end()) { tensor->set_device_address(address); need_sync_outputs->emplace_back(tensor); } else { - address->ptr_ = tensor->data_c(); + if (infer_type_id != device_type_id) { + size_t type_size = GetTypeByte(TypeIdToType(device_type_id)); + std::vector data_shape = tensor->shape(); + size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); + address->ptr_ = resource_manager_.MemMalloc(tensor_size); + need_sync_outputs->emplace_back(tensor); + tensor->set_device_address(address); + need_sync_outputs->emplace_back(tensor); + } else { + address->ptr_ = tensor->data_c(); + } address->ref_count_ = INIT_NODE_REF; - (void)bound_addresses->insert(address); + (void)bound_addresses_.insert(address); } tensor->set_dirty(false); return tensor; } -BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, +BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, + const session::KernelWithIndex &kernel_with_index, std::vector *need_sync_outputs) { auto &input_node = kernel_with_index.first; auto index = kernel_with_index.second; @@ -187,24 +184,26 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k VectorRef ret; for (size_t i = 1; i < node->inputs().size(); i++) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); - auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs); + auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs); ret.push_back(out); } return ret; } - return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs); - } else if (input_node->isa() || input_node->isa()) { - auto iter = input_map.find(input_node.get()); - if (iter != input_map.end()) { + return CreatTensorForOutput(kernel_graph, node, index, need_sync_outputs); + } else if (input_node->isa()) { + auto iter = input_param_tensor_map_.find(input_node); + if (iter != input_param_tensor_map_.end()) { return iter->second; } + } else if (input_node->isa()) { + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->value(); } return BaseRef(); } - -void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, - const std::vector &inputs, VectorRef *outputs, - std::vector *need_sync_outputs) { +void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector &inputs, + VectorRef *outputs, std::vector *need_sync_outputs) { MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(outputs); // bind input ptr @@ -212,11 +211,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, if (input_nodes.size() != inputs.size()) { MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; } - std::unordered_map input_map; + input_param_tensor_map_.clear(); size_t input_idx = 0; for (auto &item : input_nodes) { MS_EXCEPTION_IF_NULL(item); - input_map[item.get()] = inputs[input_idx]; + input_param_tensor_map_[item] = inputs[input_idx]; if (item->isa()) { auto address = AnfAlgo::GetMutableOutputAddr(item, 0); auto tensor = inputs[input_idx]; @@ -226,12 +225,13 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, if (tensor_address != nullptr && tensor_address != address) { (void)tensor->data_sync(); } - std::vector data_shape = tensor->shape(); - size_t tensor_size = - std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); - if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { + if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 || + tensor->data_type() == kNumberTypeInt32) { address->ptr_ = tensor->data_c(); } else { + std::vector data_shape = tensor->shape(); + size_t tensor_size = + std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); address->ptr_ = resource_manager_.MemMalloc(tensor_size); if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { @@ -245,11 +245,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, input_idx++; } // new output and bind ptr - std::set bound_addresses; + bound_addresses_.clear(); auto output_nodes = kernel_graph->outputs(); for (const auto &item : output_nodes) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); - auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs); + auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs); outputs->push_back(std::move(out)); } } @@ -276,7 +276,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput resource_manager_.DecreaseSummaryRefCount(summary_outputs); } -bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { +bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) { MS_EXCEPTION_IF_NULL(kernel_graph); resource_manager_.IncreaseAddressRefCount(kernel_graph); diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h index a29f840bfd..e391332f85 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ #include #include #include -#include +#include #include #include "runtime/device/kernel_runtime.h" #include "backend/session/kernel_graph.h" @@ -36,9 +36,9 @@ class CPUKernelRuntime : public KernelRuntime { ~CPUKernelRuntime() override = default; bool Init() override { return true; } - bool Run(session::KernelGraph *graph) override; + bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; void AssignKernelAddress(session::KernelGraph *kernel_graph); - void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, + void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector &inputs, VectorRef *outputs, std::vector *need_sync_outputs); void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); @@ -49,22 +49,21 @@ class CPUKernelRuntime : public KernelRuntime { TypeId type_id) override; private: - tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, + tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index, std::vector *need_sync_outputs); - BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, + BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index, std::vector *need_sync_outputs); void AssignValueNodeAddress(session::KernelGraph *kernel_graph); void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); void AddRuntimeAddress(DeviceAddress *address, std::vector *input_list); CPUResourceManager resource_manager_; + std::set bound_addresses_; + std::map input_param_tensor_map_; }; } // namespace cpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc index c607260ab3..f8917893f8 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc @@ -34,11 +34,13 @@ void CPUResourceManager::MemFree() { dynamic_mem_.clear(); } -void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { - mem_plan_.MemPlan(graph); - size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph); +void CPUResourceManager::AssignMemory(const session::KernelGraph *graph) { + size_t graph_mem_size = mem_plan_.MemPlan(graph); if (graph_mem_size > mem_size_) { - MemFree(); + if (mem_size_ > 0) { + dynamic_mem_[mem_ptr_] = mem_size_; + mem_size_ = 0; + } mem_ptr_ = reinterpret_cast(malloc(graph_mem_size)); if (mem_ptr_ != nullptr) { mem_size_ = graph_mem_size; @@ -48,9 +50,6 @@ void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { dynamic_malloc_ = true; } } -} - -void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) { if (dynamic_malloc_) { return; } diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h index d251760dd2..5e476cac69 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ #include -#include +#include #include "backend/session/kernel_graph.h" #include "backend/session/session_basic.h" #include "runtime/device/device_address.h" @@ -30,8 +30,7 @@ class CPUResourceManager { CPUResourceManager() = default; ~CPUResourceManager(); - void MemPlan(const session::KernelGraph *graph); - void MemMalloc(const session::KernelGraph *graph); + void AssignMemory(const session::KernelGraph *graph); void IncreaseAddressRefCount(const session::KernelGraph *graph); void DecreaseAddressRefCount(const AnfNodePtr &kernel); void *MemMalloc(size_t mem_size); @@ -46,10 +45,10 @@ class CPUResourceManager { size_t mem_size_{0}; uint8_t *mem_ptr_{nullptr}; bool dynamic_malloc_{false}; - std::unordered_map dynamic_mem_; + std::map dynamic_mem_; }; } // namespace cpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc index 7838e66984..78c63bac81 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc @@ -19,9 +19,9 @@ namespace mindspore { namespace device { namespace cpu { -void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { +size_t CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); - size_t total_mem_size = 0; + size_t total_mem_size = 32; auto kernels = graph->execution_order(); for (const auto &kernel : kernels) { MS_EXCEPTION_IF_NULL(kernel); @@ -58,15 +58,8 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { } } } - graph_mem_size_[graph] = total_mem_size; -} -size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const { - auto iter = graph_mem_size_.find(graph); - if (iter != graph_mem_size_.end()) { - return iter->second; - } - return 0; + return total_mem_size; } void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h index 123e29fbe5..8fce841807 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h @@ -13,11 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ #include -#include #include "backend/session/kernel_graph.h" #include "runtime/device/device_address.h" @@ -29,15 +28,11 @@ class CPUSimpleMemPlan { CPUSimpleMemPlan() = default; ~CPUSimpleMemPlan() = default; - void MemPlan(const session::KernelGraph *graph); + size_t MemPlan(const session::KernelGraph *graph); void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); - size_t GetGraphMemSize(const session::KernelGraph *graph) const; - - private: - std::unordered_map graph_mem_size_; }; } // namespace cpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc index 9528e61ee9..b9496318dc 100644 --- a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc @@ -141,7 +141,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) { if (kernel_attr.GetAllSame()) { ExpandKernelAttr(kernel_node, &kernel_attr); } - if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { + bool ignore_check = false; + if (index == kernel_attrs.size() - 1 && input_types.size() == input_not_cnode_indexes.size()) { + ignore_check = true; + } + if (ignore_check || IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (kernel_attr.GetOutputSize() != output_num) { MS_LOG(DEBUG) << "Output num is not equal!"; diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h index b707c55e2c..9fd5c55b7d 100644 --- a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_KERNEL_SELECT_CPU_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_KERNEL_SELECT_CPU_H_ #include #include @@ -67,4 +67,4 @@ class KernelAttr { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_KERNEL_SELECT_CPU_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h index 839aa495f7..f0a7daf225 100644 --- a/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h +++ b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_ #ifdef ENABLE_MPI #include #include @@ -71,4 +71,4 @@ class MPIAdapter { } // namespace cpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_ diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h index 32f5fcced9..4775319f51 100644 --- a/mindspore/ccsrc/runtime/device/device_address.h +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -61,10 +61,10 @@ class DeviceAddress : public mindspore::DeviceSync { std::string format() const { return format_; } TypeId type_id() const { return type_id_; } void set_host_shape(const std::vector &shape) { host_shape_ = shape; } - virtual void UpdateCommunicationAddress() {} virtual void set_status(DeviceAddressStatus status) {} virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } + void *GetMutablePtr() const override { return ptr_; } protected: const void *ptr() const { return ptr_; } @@ -76,6 +76,7 @@ class DeviceAddress : public mindspore::DeviceSync { string format_{"DefaultFormat"}; TypeId type_id_{kNumberTypeFloat16}; bool from_mem_pool_{false}; + uint8_t *communication_ptr_{nullptr}; std::vector host_shape_{}; friend class KernelRuntime; friend class MemoryManager; diff --git a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc index 547c2fbe64..5cb97ac1e7 100644 --- a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc +++ b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc @@ -17,7 +17,7 @@ #include "runtime/device/gpu/blocking_queue.h" #include #include "runtime/device/gpu/gpu_common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace device { diff --git a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h index 77744bce31..f5a33d36ca 100644 --- a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h +++ b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_BLOCKING_QUEUE_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_BLOCKING_QUEUE_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_BLOCKING_QUEUE_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_BLOCKING_QUEUE_H_ #include #include @@ -93,4 +93,4 @@ class BlockingQueue { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_BLOCKING_QUEUE_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_BLOCKING_QUEUE_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_common.h b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h index 2689fdbaca..a85631c8f1 100644 --- a/mindspore/ccsrc/runtime/device/gpu/cuda_common.h +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_COMMON_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_COMMON_H_ #include #include "runtime/device/gpu/gpu_device_manager.h" @@ -30,6 +30,7 @@ class CudaCommon { inline int blocks_num(const int total_threads) const { return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_); } + size_t share_memory_size() const { return max_share_memory_; } static CudaCommon &GetInstance() { static CudaCommon instance; @@ -44,6 +45,7 @@ class CudaCommon { threads_per_block_ = prop.maxThreadsPerBlock; max_blocks_ = prop.multiProcessorCount; major_sm_ = prop.major; + max_share_memory_ = prop.sharedMemPerBlock; } ~CudaCommon() = default; CudaCommon(const CudaCommon &) = delete; @@ -52,14 +54,16 @@ class CudaCommon { int max_blocks_; int threads_per_block_; int major_sm_; + size_t max_share_memory_; }; #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() +#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size() #define MINIUM_SM 6 #define RECOMMEND_SM 7 } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_COMMON_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc index 1f5e5e3c22..9eef3d6f10 100644 --- a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc @@ -209,6 +209,16 @@ bool CudaDriver::QueryEvent(const DeviceEvent &event) { } } +bool CudaDriver::ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end) { + auto ret = cudaEventElapsedTime(cost_time, (cudaEvent_t)start, (cudaEvent_t)end); + if (ret == cudaSuccess) { + return true; + } else { + MS_LOG(ERROR) << "cudaEventElapsedTime failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } +} + int CudaDriver::device_count() { int dev_count; auto ret = cudaGetDeviceCount(&dev_count); diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h index fb5d60f6cf..140f1e4c6a 100644 --- a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_CUDA_DRIVER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_CUDA_DRIVER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ #include @@ -57,6 +57,7 @@ class CudaDriver { static bool RecordEvent(DeviceEvent event, DeviceStream stream = 0); static bool SyncEvent(const DeviceEvent &event); static bool QueryEvent(const DeviceEvent &event); + static bool ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end); // Encapsulate the cuda APIs associated with device management. static int device_count(); @@ -76,4 +77,4 @@ class CudaDriver { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_CUDA_DRIVER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h index 5373f21d70..71018e8c78 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h @@ -14,9 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_ +#include #include #include "pybind11/pybind11.h" @@ -25,6 +26,12 @@ namespace device { namespace gpu { constexpr int MAX_HOSTNAME_LEN = 1024; constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; +struct NcclGroupInfo { + int size; + int rank; + ncclUniqueId unique_id; + ncclComm_t comm; +}; #define CHECK_RET(expression, result, message) \ { \ auto ret = (expression); \ @@ -39,4 +46,4 @@ constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h index c8405f12f6..c773f8ac7f 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_ namespace mindspore { namespace device { @@ -33,4 +33,4 @@ class CollectiveFakeInitializer { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h index 464492d50f..d65ac9045e 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ #include #include @@ -53,4 +53,4 @@ class CollectiveInitializer { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc index f427905afa..d74f1ebea0 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc @@ -14,58 +14,37 @@ * limitations under the License. */ -#include -#include -#include -#include #include -#include #include -#include "runtime/device/gpu/distribution/mpi_wrapper.h" -#include "runtime/device/gpu/distribution/nccl_wrapper.h" +#include "runtime/device/gpu/distribution/collective_wrapper.h" -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER __attribute__((visibility("default"))) -#endif +void InitMPI() { MPIWrapper::instance(); } -using MPIWrapper = mindspore::device::gpu::MPIWrapper; -using NCCLWrapper = mindspore::device::gpu::NCCLWrapper; +int local_rank_id() { return MPIWrapper::instance().local_rank_id(); } -extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); } +void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); } -extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); } - -extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); } - -extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector &ranks) { +bool CreateCommGroup(const std::string &group_name, const std::vector &ranks) { return MPIWrapper::instance().CreateCommGroup(group_name, ranks); } -extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) { - return MPIWrapper::instance().GetRankIDByGroup(group_name); -} +int GetRankIDByGroup(const std::string &group_name) { return MPIWrapper::instance().GetRankIDByGroup(group_name); } -extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) { - return MPIWrapper::instance().GetGroupSize(group_name); -} +int GetGroupSize(const std::string &group_name) { return MPIWrapper::instance().GetGroupSize(group_name); } -extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) { - return MPIWrapper::instance().DestroyGroup(group_name); -} +bool DestroyGroup(const std::string &group_name) { return MPIWrapper::instance().DestroyGroup(group_name); } -extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, - cudaStream_t stream) { - return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream); +ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) { + return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream, group); } -extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, cudaStream_t stream) { - return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream); +ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + cudaStream_t stream, const std::string &group) { + return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream, group); } -extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, - cudaStream_t stream) { - return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream); +ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) { + return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream, group); } diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h new file mode 100644 index 0000000000..e76ede4d38 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "runtime/device/gpu/distribution/mpi_wrapper.h" +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif + +using MPIWrapper = mindspore::device::gpu::MPIWrapper; +using NCCLWrapper = mindspore::device::gpu::NCCLWrapper; + +extern "C" EXPORT_WRAPPER void InitMPI(); +extern "C" EXPORT_WRAPPER int local_rank_id(); +extern "C" EXPORT_WRAPPER void InitNCCLComm(); +extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector &ranks); +extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name); +extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name); +extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name); +extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream, + const std::string &group); +extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, cudaStream_t stream, + const std::string &group); +extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, + cudaStream_t stream, const std::string &group); diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc index 08ec320cab..aae35d6c14 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc @@ -58,7 +58,7 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto if (rank_id_ == ranks[0]) { group_unique_id = NCCLWrapper::instance().nccl_unique_id(); } - MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm); + MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, 0, mpi_group_comm); int group_rank[1]; int global_rank[1] = {rank_id_}; @@ -68,9 +68,8 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto return false; } - ncclComm_t nccl_group_comm; - NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]); - NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm); + NcclGroupInfo nccl_group = {static_cast(ranks.size()), group_rank[0], group_unique_id, nullptr}; + NCCLWrapper::instance().AddGroupInfo(group_name, &nccl_group); return true; } @@ -111,7 +110,6 @@ void MPIWrapper::Init() { CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id."); CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size."); - NCCLWrapper::instance().set_rank(rank_id_, rank_size_); AssignLocalRankID(); CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD"); @@ -123,7 +121,9 @@ void MPIWrapper::Init() { } CHECK_RET(MPI_Bcast(reinterpret_cast(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), MPI_SUCCESS, "Failed to broadcast nccl unique id."); - NCCLWrapper::instance().set_nccl_unique_id(unique_id); + + NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr}; + NCCLWrapper::instance().AddGroupInfo(NCCL_WORLD_GROUP, &world_group); return; } diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h index 19d06b32d3..dde200aafb 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ #include #include @@ -58,4 +58,4 @@ class MPIWrapper { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc index bcba538309..519a29a597 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc @@ -30,60 +30,58 @@ ncclUniqueId NCCLWrapper::nccl_unique_id() const { return unique_id; } -void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; } - -void NCCLWrapper::set_rank(int rank_id, int rank_size) { - rank_id_ = rank_id; - rank_size_ = rank_size; -} - void NCCLWrapper::InitNCCLComm() { - CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess, - "Failed to init nccl communicator."); - group_to_comm_map_[NCCL_WORLD_GROUP] = comm_; -} - -void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) { - CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator."); + for (auto group : group_info_) { + std::string group_name = group.first; + NcclGroupInfo group_info = group.second; + CHECK_RET(ncclCommInitRank(&(group_info.comm), group_info.size, group_info.unique_id, group_info.rank), ncclSuccess, + "Failed to init nccl communicator for group " + group_name); + group_info_[group_name].comm = group_info.comm; + } + comm_init_done_ = true; } ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) { - CHECK_RET(group_to_comm_map_.count(group_name), 1, + CHECK_RET(group_info_.count(group_name), 1, "Failed to find NCCL communicator for AllReduce by the group name " + group_name); - ncclComm_t group_comm = group_to_comm_map_[group_name]; + ncclComm_t group_comm = group_info_[group_name].comm; return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream); } ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, cudaStream_t stream, const std::string &group_name) { - CHECK_RET(group_to_comm_map_.count(group_name), 1, + CHECK_RET(group_info_.count(group_name), 1, "Failed to find NCCL communicator for AllGather by the group name " + group_name); - ncclComm_t group_comm = group_to_comm_map_[group_name]; + ncclComm_t group_comm = group_info_[group_name].comm; return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream); } ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) { - CHECK_RET(group_to_comm_map_.count(group_name), 1, + CHECK_RET(group_info_.count(group_name), 1, "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name); - ncclComm_t group_comm = group_to_comm_map_[group_name]; + ncclComm_t group_comm = group_info_[group_name].comm; return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream); } -void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) { - group_to_comm_map_[group_name] = comm; +void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) { + if (comm_init_done_) { + CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess, + "Failed to init nccl communicator for group " + group_name); + } + group_info_[group_name] = *group; } void NCCLWrapper::DestroyGroup(const std::string &group_name) { - auto group_iter = group_to_comm_map_.find(group_name); - if (group_iter == group_to_comm_map_.end()) { + auto group_iter = group_info_.find(group_name); + if (group_iter == group_info_.end()) { return; } - group_to_comm_map_.erase(group_iter); - ncclComm_t group_comm = group_iter->second; + ncclComm_t group_comm = group_iter->second.comm; CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name); + group_info_.erase(group_iter); return; } } // namespace gpu diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h index 9cea338c41..94525ebe46 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ #include #include @@ -33,32 +33,26 @@ class NCCLWrapper { NCCLWrapper &operator=(const NCCLWrapper &) = delete; static NCCLWrapper &instance(); ncclUniqueId nccl_unique_id() const; - void set_nccl_unique_id(ncclUniqueId unique_id); - void set_rank(int rank_id, int rank_size); void InitNCCLComm(); - void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank); ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); - void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm); + void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group); void DestroyGroup(const std::string &group_name); private: - NCCLWrapper() : rank_id_(-1), rank_size_(0) {} + NCCLWrapper() : comm_init_done_(false) {} ~NCCLWrapper() = default; private: - int rank_id_; - int rank_size_; - ncclUniqueId unique_id_; - ncclComm_t comm_; - std::map group_to_comm_map_; + bool comm_init_done_; + std::map group_info_; }; } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc index a1b1fa9b79..7ca7878d56 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc @@ -18,7 +18,7 @@ #include #include #include "utils/log_adapter.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace device { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h index 722a36c4ed..610836bcf6 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUFFER_MGR_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUFFER_MGR_H_ #include #include @@ -136,4 +136,4 @@ class GpuBufferMgr { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUFFER_MGR_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h index c1b6146487..ea2b321714 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_COMMON_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_COMMON_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_COMMON_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_COMMON_H_ #include #include @@ -93,6 +93,22 @@ namespace gpu { } \ } +#define CHECK_CUSOLVER_RET_WITH_EXCEPT(expression, message) \ + { \ + cusolverStatus_t status = (expression); \ + if (status != CUSOLVER_STATUS_SUCCESS) { \ + MS_LOG(EXCEPTION) << "cusolver Error: " << message << " | Error Number: " << status; \ + } \ + } + +#define CHECK_CUSOLVER_RET_WITH_ERROR(expression, message) \ + { \ + cusolverStatus_t status = (expression); \ + if (status != CUSOLVER_STATUS_SUCCESS) { \ + MS_LOG(ERROR) << "cusolver Error: " << message << " | Error Number: " << status; \ + } \ + } + #define CHECK_NCCL_RET_WITH_EXCEPT(expression, message) \ { \ int result = (expression); \ @@ -119,4 +135,4 @@ inline bool CheckNullInput(std::vector input_shape) { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_COMMON_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_COMMON_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc index a20a6a9a3c..c7fbda2dad 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc @@ -16,9 +16,16 @@ #include "runtime/device/gpu/gpu_device_address.h" #include +#include #include "runtime/device/gpu/gpu_device_manager.h" #include "utils/log_adapter.h" #include "runtime/device/gpu/gpu_memory_allocator.h" +#include "ir/tensor.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debug_services.h" +#include "debug/tensor_load.h" +#include "debug/debugger/debugger.h" +#endif namespace mindspore { namespace device { @@ -59,6 +66,34 @@ GPUDeviceAddress::~GPUDeviceAddress() { ptr_ = nullptr; } } +#ifdef ENABLE_DEBUGGER +bool GPUDeviceAddress::LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type, size_t slot, + Debugger *debugger, bool keep_prev) const { + bool ret = false; + if (size_ == 0) { + return true; + } + DebugServices *debug_services = debugger->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + + mindspore::tensor::TensorPtr out_tensor = std::make_shared(type_id_, host_shape); + size_t host_size = out_tensor->data().nbytes(); + auto ret_rt_memcpy = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); + if (!ret_rt_memcpy) { + MS_LOG(ERROR) << "Copy device mem to host failed"; + return ret; + } + auto tensor_data = std::make_shared(); + tensor_data->SetName(tensor_name); + tensor_data->SetExecutionOrder(execution_order); + tensor_data->SetTensor(out_tensor); + tensor_data->SetSlot(slot); + ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev); + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + return ret; +} +#endif } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h index ade738deed..8a3baccb61 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h @@ -14,14 +14,17 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ #include #include #include "runtime/device/device_address.h" namespace mindspore { +#ifdef ENABLE_DEBUGGER +class Debugger; +#endif namespace device { namespace gpu { class GPUDeviceAddress : public DeviceAddress { @@ -37,6 +40,11 @@ class GPUDeviceAddress : public DeviceAddress { DeviceAddressStatus status() const { return status_; } DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } +#ifdef ENABLE_DEBUGGER + bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, + bool keep_prev) const; +#endif private: DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; }; @@ -44,4 +52,4 @@ class GPUDeviceAddress : public DeviceAddress { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc index 8f17fc20b5..5207bdf1b6 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc @@ -32,6 +32,10 @@ void GPUDeviceManager::InitDevice() { CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), "Failed to set stream for cuBLAS handle."); + CHECK_CUSOLVER_RET_WITH_EXCEPT(cusolverDnCreate(&cusolver_dn_handle_), "Failed to create cusolver dn handle."); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cusolver dn handle"); CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") } @@ -47,6 +51,9 @@ void GPUDeviceManager::ReleaseDevice() { if (cublas_handle_ != nullptr) { CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); } + if (cusolver_dn_handle_ != nullptr) { + CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle."); + } CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); } @@ -79,7 +86,7 @@ bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } - +const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; } bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h index 002806675c..b2bb618621 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ #include #include +#include #include #include #include "runtime/device/gpu/cuda_driver.h" @@ -43,6 +44,7 @@ class GPUDeviceManager { const cudnnHandle_t &GetCudnnHandle() const; const cublasHandle_t &GetCublasHandle() const; + const cusolverDnHandle_t &GetCusolverDnHandle() const; bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; @@ -73,6 +75,8 @@ class GPUDeviceManager { // handle used for cuBLAS kernels. cublasHandle_t cublas_handle_{nullptr}; + // handle used for cusolver dn kernels; + cusolverDnHandle_t cusolver_dn_handle_{nullptr}; bool dev_id_init_; uint32_t cur_dev_id_; }; @@ -80,4 +84,4 @@ class GPUDeviceManager { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc index 9d88a205bc..e4b054cb5b 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc @@ -21,13 +21,16 @@ #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "frontend/operator/ops.h" #include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_build_client.h" + namespace mindspore { namespace device { namespace gpu { void GpuBuild(const KernelGraphPtr &kernel_graph) { kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); MS_EXCEPTION_IF_NULL(bin_map); - bin_map->Initialize(); + auto pid = mindspore::kernel::GpuKernelBuildClient::Instance().AkgGetPid(); + bin_map->Initialize(pid); MS_EXCEPTION_IF_NULL(kernel_graph); auto kernels = kernel_graph->execution_order(); for (const auto &kernel : kernels) { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h index 831c4e9511..e4b207c1c7 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPUKERNELBUILD_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPUKERNELBUILD_H_ #include #include "backend/session/kernel_graph.h" @@ -25,4 +25,4 @@ void GpuBuild(const std::shared_ptr &kernel_graph); } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPUKERNELBUILD_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index ddf73841b7..676df13aed 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "runtime/device/gpu/gpu_kernel_runtime.h" +#include #include "runtime/device/gpu/gpu_device_address.h" #include "runtime/device/gpu/cuda_driver.h" #include "runtime/device/gpu/gpu_buffer_mgr.h" @@ -22,19 +22,26 @@ #include "runtime/device/gpu/gpu_memory_allocator.h" #include "runtime/device/gpu/distribution/collective_init.h" #include "utils/convert_utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/gpu/gpu_common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "runtime/device/gpu/gpu_memory_manager.h" #include "backend/kernel_compiler/common_utils.h" #include "runtime/device/gpu/gpu_memory_copy_manager.h" +#include "common/trans.h" +#include "ir/dtype.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debug_services.h" +#endif namespace mindspore { namespace device { namespace gpu { +using mindspore::device::memswap::MemSwapInfoSet; using mindspore::device::memswap::MemSwapManager; using mindspore::device::memswap::SwapKind; +static const size_t PARAMETER_OUTPUT_INDEX = 0; bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } bool GPUKernelRuntime::Init() { @@ -42,7 +49,15 @@ bool GPUKernelRuntime::Init() { GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); return true; } - auto ret = InitDevice(); + bool ret = false; +#ifdef ENABLE_DUMP_E2E + ret = SetDumpConf(); + if (!ret) { + MS_LOG(INFO) << "No dump conf to set!"; + } +#endif + + ret = InitDevice(); if (!ret) { MS_LOG(ERROR) << "InitDevice error."; return ret; @@ -62,6 +77,262 @@ bool GPUKernelRuntime::Init() { return ret; } +#ifdef ENABLE_DUMP_E2E +namespace { +void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf, + Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(dump_conf); + bool trans_flag = dump_conf->trans_flag(); + const auto &apply_kernels = graph->execution_order(); + for (const auto &node : apply_kernels) { + MS_EXCEPTION_IF_NULL(node); + auto node_name = AnfAlgo::GetCNodeName(node); + std::string kernel_name = node->fullname_with_scope(); + if (!dump_conf->IsKernelNeedDump(kernel_name)) { + continue; + } + const std::string strsrc = "/"; + const std::string strdst = "--"; + std::string::size_type pos = 0; + std::string::size_type srclen = strsrc.size(); + std::string::size_type dstlen = strdst.size(); + while ((pos = kernel_name.find(strsrc, pos)) != std::string::npos) { + kernel_name.replace(pos, srclen, strdst); + pos += dstlen; + } + auto output_size = AnfAlgo::GetOutputTensorNum(node); + for (size_t j = 0; j < output_size; ++j) { + auto addr = AnfAlgo::GetOutputAddr(node, j); + TypeId addr_type_id = addr->type_id(); + std::string addr_format = addr->format(); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(node, j); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(node, j); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + + auto type = AnfAlgo::GetOutputInferDataType(node, j); + + auto format = kOpFormat_DEFAULT; + string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j); + + DebugServices *debug_services = debugger->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + std::string original_kernel_name = node->fullname_with_scope(); + size_t slot = j; + auto ret = tensor_loader->DumpTensorToFile(original_kernel_name, trans_flag, filepath, format, int_shapes, type, + addr_type_id, addr_format, slot); + + if (!ret) { + std::string error = "DumpTensorToFile Failed: flag:" + std::to_string(trans_flag) + ", path:" + filepath + + ", host_format:" + format + ".!"; + } + } + } +} + +void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf, + Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(dump_conf); + bool trans_flag = dump_conf->trans_flag(); + const auto ¶meters = graph->inputs(); + for (auto &item : parameters) { + if (!item->isa()) { + continue; + } + std::string parameter_name = item->fullname_with_scope(); + if (!dump_conf->IsKernelNeedDump(parameter_name)) { + continue; + } + auto addr = AnfAlgo::GetOutputAddr(item, PARAMETER_OUTPUT_INDEX); + TypeId addr_type_id = addr->type_id(); + std::string addr_format = addr->format(); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(item, PARAMETER_OUTPUT_INDEX); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(item, PARAMETER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + + auto type = AnfAlgo::GetOutputInferDataType(item, PARAMETER_OUTPUT_INDEX); + + auto format = kOpFormat_DEFAULT; + string filepath = dump_path + '/' + parameter_name + '_' + "output_0"; + + DebugServices *debug_services = debugger->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + std::string original_kernel_name = parameter_name; + size_t slot = 0; + auto ret = tensor_loader->DumpTensorToFile(original_kernel_name, trans_flag, filepath, format, int_shapes, type, + addr_type_id, addr_format, slot); + + if (!ret) { + std::string error = "DumpTensorToFile Failed: flag:" + std::to_string(trans_flag) + ", path:" + filepath + + ", host_format:" + format + ".!"; + } + } +} +} // namespace + +bool GPUKernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Start dump step"; + DumpConfPtr dump_conf = GetDumpConf(); + MS_EXCEPTION_IF_NULL(dump_conf); + dump_conf->UpdataCurIter(); + bool dump_flag = dump_conf->dump_enable(); + if (!dump_flag) { + MS_LOG(INFO) << "Dump flag is disable, pass dump step"; + return true; + } + uint32_t cur_iter = dump_conf->cur_iter(); + if (dump_conf->dump_iter() != 0) { + if (cur_iter != dump_conf->dump_iter()) { + return true; + } + } + MS_LOG(INFO) << "Cur iter is " << cur_iter; + std::string net_name = dump_conf->dump_net_name(); + std::string iterator = std::to_string(cur_iter); + std::string dump_path = dump_conf->dump_path(); + if (dump_path.back() == '/') { + dump_path = dump_path + net_name + '/' + iterator; + } else { + dump_path = dump_path + '/' + net_name + '/' + iterator; + } + + // dump output + DumpOutput(graph, dump_path, dump_conf, debugger); + // dump parameters + DumpParameters(graph, dump_path, dump_conf, debugger); + + return true; +} +#endif + +#ifdef ENABLE_DEBUGGER +namespace { +void LoadKernelData(Debugger *debugger, const CNodePtr &kernel, + const std::vector &kernel_inputs, + const std::vector &kernel_workspaces, + const std::vector &kernel_outputs, int exec_order, void *stream_ptr, + bool dump_enabled) { + // check if we should read the kernel data + bool read_data = false; + std::string kernel_name = kernel->fullname_with_scope(); + if (debugger) { + debugger->SetCurNode(kernel_name); + if (dump_enabled) { + read_data = true; + } else if (debugger->debugger_enabled()) { + read_data = debugger->ReadNodeDataRequired(); + } + } + + if (!read_data) { + return; + } + + // get inputs + if (!dump_enabled) { + auto input_size = AnfAlgo::GetInputTensorNum(kernel); + for (size_t j = 0; j < input_size; ++j) { + auto input_kernel = kernel->input(j + 1); + std::string input_kernel_name = input_kernel->fullname_with_scope(); + auto addr = kernel_inputs[j]; + auto type = AnfAlgo::GetOutputInferDataType(input_kernel, PARAMETER_OUTPUT_INDEX); + auto format = kOpFormat_DEFAULT; + auto gpu_addr = std::make_unique(addr->addr, addr->size, format, type); + string input_tensor_name = input_kernel_name + ':' + "0"; + std::vector int_shapes; + auto shape = AnfAlgo::GetOutputDeviceShape(input_kernel, PARAMETER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + auto ret = gpu_addr->LoadMemToHost(input_tensor_name, exec_order, format, int_shapes, type, 0, debugger, false); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost:" + << ", tensor_name:" << input_tensor_name << ", host_format:" << format << ".!"; + } + } + } + + // get outputs + auto output_size = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t j = 0; j < output_size; ++j) { + auto addr = kernel_outputs[j]; + auto type = AnfAlgo::GetOutputInferDataType(kernel, j); + auto format = kOpFormat_DEFAULT; + auto gpu_addr = std::make_unique(addr->addr, addr->size, format, type); + string tensor_name = kernel_name + ':' + std::to_string(j); + std::vector int_shapes; + auto shape = AnfAlgo::GetOutputDeviceShape(kernel, j); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + auto ret = gpu_addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, debugger, false); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost:" + << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!"; + } + } + + debugger->PostExecuteNode(); +} + +void UpdateStepNum(Debugger *debugger, bool dump_enabled) { + if (debugger && (debugger->debugger_enabled() || dump_enabled)) { + auto cur_step_num = debugger->step_num(); + cur_step_num = cur_step_num + 1; + debugger->SetStepNum(cur_step_num); + } +} + +void LoadParameters(const session::KernelGraph *graph, Debugger *debugger, bool dump_enabled) { + MS_EXCEPTION_IF_NULL(graph); + if (!(debugger && dump_enabled)) { + return; + } + const auto ¶meters = graph->inputs(); + // for parameters, set its execution order to be 0; + int exec_order = 0; + for (auto &item : parameters) { + if (!item->isa()) { + continue; + } + std::string parameter_name = item->fullname_with_scope(); + auto addr = AnfAlgo::GetOutputAddr(item, PARAMETER_OUTPUT_INDEX); + auto type = AnfAlgo::GetOutputInferDataType(item, PARAMETER_OUTPUT_INDEX); + auto format = kOpFormat_DEFAULT; + string tensor_name = parameter_name + ':' + "0"; + auto gpu_addr = dynamic_cast(addr); + std::vector int_shapes; + auto shape = AnfAlgo::GetOutputDeviceShape(item, PARAMETER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + auto ret = gpu_addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, debugger, true); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost:" + << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!"; + } + } +} + +void ClearCurrentData(Debugger *debugger, bool dump_enabled) { + if (debugger && (debugger->debugger_enabled() || dump_enabled)) { + DebugServices *debug_services = debugger->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + tensor_loader->EmptyCurrentTensor(); + } +} +} // namespace +#endif + DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id) { return std::make_shared(device_ptr, device_size, format, type_id); @@ -111,7 +382,7 @@ void GPUKernelRuntime::ReleaseDeviceRes() { auto &mem_swap_manager = item.second; MS_EXCEPTION_IF_NULL(mem_swap_manager); if (mem_swap_manager->trigger_swap()) { - mem_swap_manager->ClearSwapQueue(); + mem_swap_manager->ClearSwapQueue(false); mem_swap_manager->ReleaseHostPinnedMem(); } } @@ -139,12 +410,14 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { InitKernelRefCount(graph); InitMemorySwapInfo(graph); InitKernelOutputAddress(graph); + InitKernelWorkspaceAddress(graph); + SaveGraphOutputNode(graph); } else { AssignDynamicMemory(graph); } } -bool GPUKernelRuntime::Run(session::KernelGraph *graph) { +bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { struct timeval start_time, end_time; (void)gettimeofday(&start_time, nullptr); bool ret = true; @@ -160,12 +433,14 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { } mem_swap_manager_ = iter->second; MS_EXCEPTION_IF_NULL(mem_swap_manager_); - while (!LaunchKernelDynamic(graph)) { - MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; - if (!UpdateMemorySwapInfo(graph)) { - return false; - } + auto mem_reuse_iter = mem_reuse_util_map_.find(graph_id); + if (mem_reuse_iter == mem_reuse_util_map_.end()) { + MS_LOG(EXCEPTION) << "Find memory reuse map failed."; } + mem_reuse_util_ = mem_reuse_iter->second; + MS_EXCEPTION_IF_NULL(mem_reuse_util_); + + ret = RunOneStep(graph, debugger); } else { ret = LaunchKernel(graph); } @@ -177,6 +452,80 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { return ret; } +bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph, Debugger *debugger) { + bool ret = true; + auto graph_id = graph->graph_id(); + if (!is_first_step_map_[graph_id]) { + // Normally run graph + ret = LaunchKernelDynamic(graph, debugger); + } else { + // Mock run first step + ret = LaunchKernelDynamic(graph, debugger, true, false); + if (ret) { + // Normally run graph + ret = LaunchKernelDynamic(graph, debugger); + } else { + // Trigger memory swap + ret = SearchMemSwapScheme(graph, debugger); + } + is_first_step_map_[graph_id] = false; + } + return ret; +} + +bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger) { + MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; + bool ret = false; + ClearKernelOldOutputAndWorkspace(graph); + if (!mem_swap_manager_->mem_swap_init()) { + if (!mem_swap_manager_->Init(graph)) { + return false; + } + } + + while (!ret) { + if (!mem_swap_manager_->RetreatSwapInfo()) { + return false; + } + ret = LaunchKernelDynamic(graph, debugger, true, false); + if (!ret) { + ClearKernelOldOutputAndWorkspace(graph); + } + } + mem_swap_manager_->AssignHostMemory(); + + // Time profiling + ret = LaunchKernelDynamic(graph, debugger, false, true); + if (!ret) { + return ret; + } + return RefineMemSwapScheme(graph, debugger); +} + +bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger) { + MS_LOG(WARNING) << "Refine memory swap scheme, it may take some time, please wait a moment."; + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) { + continue; + } + + size_t swap_in_task_num = mem_swap_manager_->QueryKernelTriggerSwapInTaskNum(kernel); + for (size_t swap_in_task_idx = 0; swap_in_task_idx < swap_in_task_num; swap_in_task_idx++) { + bool ret = false; + while (!ret) { + mem_swap_manager_->AdjustSwapInPos(kernel, swap_in_task_idx); + ret = LaunchKernelDynamic(graph, debugger, true, false); + if (!ret) { + ClearKernelOldOutputAndWorkspace(graph); + ClearSwapInfo(true); + } + } + } + } + return true; +} + void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); @@ -203,6 +552,7 @@ void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(mem_swap_manager); auto graph_id = graph->graph_id(); mem_swap_map_[graph_id] = mem_swap_manager; + is_first_step_map_[graph_id] = true; } void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { @@ -224,10 +574,52 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph } } +void GPUKernelRuntime::InitKernelWorkspaceAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + auto device_address = CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown); + AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); + } + } +} + +void GPUKernelRuntime::SaveGraphOutputNode(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto graph_id = graph->graph_id(); + const auto &output_nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); + for (const auto &node : output_nodes) { + graph_output_map_[graph_id].insert(node); + } +} + +bool GPUKernelRuntime::IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(graph); + auto graph_id = graph->graph_id(); + auto iter = graph_output_map_.find(graph_id); + if (iter == graph_output_map_.end()) { + MS_LOG(EXCEPTION) << "Find graph output info failed."; + } + auto &graph_output_set = iter->second; + return (graph_output_set.find(kernel) != graph_output_set.end()); +} + +void GPUKernelRuntime::ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph) { + ClearKernelOutputAddress(graph); + ClearKernelWorkspaceAddress(graph); +} + void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto &kernels = graph->execution_order(); for (const auto &kernel : kernels) { + if (IsGraphOutput(graph, kernel)) { + continue; + } auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); @@ -236,6 +628,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap continue; } auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); if (device_address->ptr_) { mem_manager_->FreeMemFromMemPool(device_address); } @@ -244,63 +637,149 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap } } -bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { +void GPUKernelRuntime::ClearKernelWorkspaceAddress(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); - auto graph_id = graph->graph_id(); - auto iter = mem_reuse_util_map_.find(graph_id); - if (iter == mem_reuse_util_map_.end()) { - MS_LOG(EXCEPTION) << "Find memory reuse map failed."; + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); + } + } } - auto mem_reuse_util_ptr = iter->second; - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); +} + +bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger, bool mock, + bool profiling) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_reuse_util_); // Reset the reference count. - mem_reuse_util_ptr->ResetDynamicUsedRefCount(); + mem_reuse_util_->ResetDynamicUsedRefCount(); // The inputs and outputs memory of communication kernel need be continuous, so separate processing. AllocCommunicationOpDynamicRes(graph); +#ifdef ENABLE_DEBUGGER + bool dump_enabled = GPUKernelRuntime::DumpDataEnabledIteration(); + if (!mock) { + UpdateStepNum(debugger, dump_enabled); + } +#endif auto &kernels = graph->execution_order(); + int exec_order = 1; + for (const auto &kernel : kernels) { auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); AddressPtrList kernel_inputs; AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; - auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs, mock); if (!ret) { +#ifdef ENABLE_DEBUGGER + if (!mock) { + // invalidate current data collected by the debugger + ClearCurrentData(debugger, dump_enabled); + } +#endif return false; } - if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { - MS_LOG(EXCEPTION) << "Launch kernel failed."; + if (!mock) { + if (!profiling) { + CHECK_OP_RET_WITH_EXCEPT(kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_), + "Launch kernel failed."); + } else { + LaunchKernelWithTimeProfiling(kernel, kernel_inputs, kernel_workspaces, kernel_outputs); + } +#ifdef ENABLE_DEBUGGER + // called once per kernel to collect the outputs to the kernel (does a SyncDeviceToHost) + LoadKernelData(debugger, kernel, kernel_inputs, kernel_workspaces, kernel_outputs, exec_order, stream_, + dump_enabled); +#endif + } + exec_order = exec_order + 1; + FreeKernelDynamicRes(kernel); + if (!UpdateMemorySwapTask(kernel, mock, profiling)) { +#ifdef ENABLE_DEBUGGER + if (!mock) { + // invalidate current data collected by the debugger + ClearCurrentData(debugger, dump_enabled); + } +#endif + return false; } - FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id); - UpdateMemorySwapTask(kernel); } - CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); - ClearSwapQueue(); + if (!mock) { +#ifdef ENABLE_DEBUGGER + // collect weights and bias for dump mode + LoadParameters(graph, debugger, dump_enabled); +#endif + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + } + ClearSwapInfo(mock); return true; } -bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { +void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, + const AddressPtrList &workspace, const AddressPtrList &outputs) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + float cost_time = 0; + DeviceEvent start = nullptr; + DeviceEvent end = nullptr; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create event."); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, stream_), "Failed to record event to stream."); + CHECK_OP_RET_WITH_EXCEPT(kernel_mod->Launch(inputs, workspace, outputs, stream_), "Launch kernel failed."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(end, stream_), "Failed to record event to stream."); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(start), "Failed to sync event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(end), "Failed to sync event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ElapsedTime(&cost_time, start, end), "Failed to record elapsed time."); + + mem_swap_manager_->AddKernelExecutionPerform(kernel, cost_time); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(start), "Failed to destroy event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(end), "Failed to destroy event."); +} + +bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); - for (auto &mem_swap_info : mem_swap_info_list) { - auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); - const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; - auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); + const MemSwapInfoSet &mem_swap_info_set = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); + for (auto &mem_swap_info : mem_swap_info_set) { + auto need_swap_kernel = mem_swap_manager_->QueryKernelByTopoOrder(mem_swap_info.topo_order_); + MS_EXCEPTION_IF_NULL(need_swap_kernel); + const HostAddress &host_address = + mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_); + auto device_address = AnfAlgo::GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false); if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { - mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); + if (mem_swap_manager_->QueryKernelHostAddrIsDirty(need_swap_kernel, mem_swap_info.output_idx_)) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address, mock); + mem_swap_manager_->AddKernelHostAddrIsDirty(need_swap_kernel, mem_swap_info.output_idx_, false); + } else { + mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInHost); + } } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { auto status = device_address->status(); if (status == DeviceAddressStatus::kInDeviceToHost) { - mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); device_address->set_status(DeviceAddressStatus::kInDevice); } else if (status == DeviceAddressStatus::kInHost) { - if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) { + if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_, mock)) { return false; } - if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) { - mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address); + float cost_time = 0; + mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address, mock, profiling, + &cost_time); + if (profiling) { + mem_swap_manager_->AddKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_, + std::make_pair(0, cost_time)); } } } @@ -308,85 +787,81 @@ bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { return true; } -bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - ClearKernelOutputAddress(graph); - if (!mem_swap_manager_->mem_swap_init()) { - mem_swap_manager_->Init(graph); - } - return mem_swap_manager_->RetreatSwapInfo(); -} - -bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) { +bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); if (!mem_swap_manager_->trigger_swap()) { return true; } if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) { - CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); - if (!AddMemorySwapTask(kernel)) { + if (!mock) { + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + } + if (!AddMemorySwapTask(kernel, mock, profiling)) { return false; } + if (!mock) { + CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed."); + } } - CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed."); return true; } -void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) { +void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); if (!mem_swap_manager_->trigger_swap()) { return; } - while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice, mock)) { device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); } + auto status = device_address->status(); switch (status) { case DeviceAddressStatus::kInDevice: break; case DeviceAddressStatus::kInDeviceToHost: { - mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); device_address->set_status(DeviceAddressStatus::kInDevice); break; } case DeviceAddressStatus::kInHostToDevice: { while (device_address->status() != DeviceAddressStatus::kInDevice) { - while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice, mock)) { device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); } } break; } case DeviceAddressStatus::kInHost: - MS_LOG(ERROR) << "Invaild device address status:" << status; + MS_LOG(WARNING) << "Unexpected device address status: " << status; break; default: - MS_LOG(EXCEPTION) << "Invaild device address status:" << status; + MS_LOG(EXCEPTION) << "Invaild device address status: " << status; } } -void GPUKernelRuntime::UpdateDeviceSwapQueue() { +void GPUKernelRuntime::UpdateHostSwapOutQueue(bool mock) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); if (!mem_swap_manager_->trigger_swap()) { return; } - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost, mock)) { + if (device_address_swap_out->status() == DeviceAddressStatus::kInDeviceToHost && device_address_swap_out->ptr_) { device_address_swap_out->set_status(DeviceAddressStatus::kInHost); mem_manager_->FreeMemFromMemPool(device_address_swap_out); } } } -void GPUKernelRuntime::ClearSwapQueue() { +void GPUKernelRuntime::ClearSwapInfo(bool mock) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); if (!mem_swap_manager_->trigger_swap()) { return; } - mem_swap_manager_->ClearSwapQueue(); + mem_swap_manager_->ClearSwapQueue(mock); + mem_swap_manager_->ResetHostAddrIsDirty(); } -bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) { +bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock) { MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_swap_manager_); auto ret = mem_manager_->MallocMemFromMemPool(device_address, size); @@ -394,13 +869,11 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, if (!mem_swap_manager_->trigger_swap()) { return false; } - mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { - device_address_swap_out->set_status(DeviceAddressStatus::kInHost); - mem_manager_->FreeMemFromMemPool(device_address_swap_out); - } + if (!mock) { + mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); } + UpdateHostSwapOutQueue(mock); + ret = mem_manager_->MallocMemFromMemPool(device_address, size); if (!ret) { return false; @@ -409,52 +882,38 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, return true; } -void *GPUKernelRuntime::AttemptMallocMem(size_t size) { - MS_EXCEPTION_IF_NULL(mem_manager_); - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto device_ptr = mem_manager_->MallocMemFromMemPool(size); - if (!device_ptr) { - if (!mem_swap_manager_->trigger_swap()) { - return nullptr; - } - mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { - device_address_swap_out->set_status(DeviceAddressStatus::kInHost); - mem_manager_->FreeMemFromMemPool(device_address_swap_out); - } - } - device_ptr = mem_manager_->MallocMemFromMemPool(size); - if (!device_ptr) { - return nullptr; - } - } - return device_ptr; -} - bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, - AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { - if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) { + AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs, + bool mock) { + if (!AllocKernelInputDynamicRes(kernel, kernel_inputs, mock)) { return false; } - if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) { + if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs, mock)) { return false; } - if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) { + if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces, mock)) { return false; } return true; } -bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) { +bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, + bool mock) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel_inputs); + MS_EXCEPTION_IF_NULL(mem_reuse_util_); for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + DeviceAddressPtr device_address; + if (mem_reuse_util_->is_all_nop_node()) { + // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + } else { + // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node. + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true); + } MS_EXCEPTION_IF_NULL(device_address); - UpdateHostSwapQueue(device_address); + UpdateHostSwapInQueue(device_address, mock); MS_EXCEPTION_IF_NULL(device_address->ptr_); kernel::AddressPtr input = std::make_shared(); MS_EXCEPTION_IF_NULL(input); @@ -466,16 +925,16 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k } bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_outputs) { + const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_outputs, + bool mock) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel_outputs); - UpdateDeviceSwapQueue(); + UpdateHostSwapOutQueue(mock); auto output_sizes = kernel_mod.GetOutputSizeList(); for (size_t i = 0; i < output_sizes.size(); ++i) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { + if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i], mock)) { return false; } kernel::AddressPtr output = std::make_shared(); @@ -489,7 +948,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_workspaces) { + AddressPtrList *kernel_workspaces, bool mock) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel_workspaces); auto workspace_sizes = kernel_mod.GetWorkspaceSizeList(); @@ -498,13 +957,13 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K kernel_workspaces->emplace_back(nullptr); continue; } - auto device_ptr = AttemptMallocMem(workspace_sizes[i]); - if (!device_ptr) { + auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); + if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, workspace_sizes[i], mock)) { return false; } kernel::AddressPtr workspace = std::make_shared(); MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_ptr; + workspace->addr = device_address->ptr_; workspace->size = workspace_sizes[i]; kernel_workspaces->emplace_back(workspace); } @@ -525,13 +984,21 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_reuse_util_); bool is_need_alloc_memory = false; bool is_need_free_memory = false; size_t total_size = 0; std::vector size_list; DeviceAddressPtrList addr_list; for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + DeviceAddressPtr device_address; + if (mem_reuse_util_->is_all_nop_node()) { + // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + } else { + // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node. + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true); + } MS_EXCEPTION_IF_NULL(device_address); if (device_address->ptr_ == nullptr) { is_need_alloc_memory = true; @@ -592,12 +1059,10 @@ void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, boo } } -void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, - const AddressPtrList &kernel_workspaces, uint32_t graph_id) { +void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + MS_EXCEPTION_IF_NULL(mem_reuse_util_); auto cnode = kernel->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::IsCommunicationOp(kernel)) { @@ -605,7 +1070,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } // Free the input of kernel by reference count. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i); + auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i); if (kernel_ref_count_ptr == nullptr) { continue; } @@ -614,14 +1079,21 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; } if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + DeviceAddressPtr device_address; + if (mem_reuse_util_->is_all_nop_node()) { + // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + } else { + // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node. + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true); + } mem_manager_->FreeMemFromMemPool(device_address); device_address->set_status(DeviceAddressStatus::kInDevice); } } // Free the output of kernel, if output has no reference. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i); + auto kernel_ref_count_ptr = mem_reuse_util_->GetRef(cnode, i); if (kernel_ref_count_ptr == nullptr) { continue; } @@ -632,12 +1104,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } } // Free the workspace of kernel. - for (size_t i = 0; i < kernel_workspaces.size(); ++i) { - auto workspace = kernel_workspaces[i]; - if (workspace != nullptr) { - MS_EXCEPTION_IF_NULL(workspace->addr); - mem_manager_->FreeMemFromMemPool(workspace->addr); - workspace->addr = nullptr; + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); } } } diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index 2b1f8198ce..8f3cb9cb25 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -14,12 +14,13 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ #include #include #include +#include #include #include #include "runtime/device/kernel_runtime.h" @@ -37,7 +38,10 @@ class GPUKernelRuntime : public KernelRuntime { bool Init() override; void ReleaseDeviceRes() override; void AssignMemory(session::KernelGraph *graph) override; - bool Run(session::KernelGraph *graph) override; + bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; +#ifdef ENABLE_DUMP_E2E + bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override; +#endif protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, @@ -53,39 +57,52 @@ class GPUKernelRuntime : public KernelRuntime { // The related functions and members for using dynamic memory pool. void InitKernelRefCount(const session::KernelGraph *graph); void InitKernelOutputAddress(const session::KernelGraph *graph); + void InitKernelWorkspaceAddress(const session::KernelGraph *graph); void InitMemorySwapInfo(const session::KernelGraph *graph); + void SaveGraphOutputNode(const session::KernelGraph *graph); + bool IsGraphOutput(const session::KernelGraph *graph, const mindspore::AnfNodePtr &kernel) const; void ClearKernelOutputAddress(const session::KernelGraph *graph); - bool LaunchKernelDynamic(const session::KernelGraph *graph); - bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); - void *AttemptMallocMem(size_t size); + void ClearKernelWorkspaceAddress(const session::KernelGraph *graph); + void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph); + bool RunOneStep(const session::KernelGraph *graph, Debugger *debugger = nullptr); + bool SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr); + bool RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr); + bool LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger = nullptr, bool mock = false, + bool profiling = false); + void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, + const AddressPtrList &workspace, const AddressPtrList &outputs); + bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, - AddressPtrList *kernel_outputs); - bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs); + AddressPtrList *kernel_outputs, bool mock); + bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, bool mock); bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_outputs); + AddressPtrList *kernel_outputs, bool mock); bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces); + const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces, + bool mock); void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, const DeviceAddressPtrList addr_list, size_t total_size, std::vector size_list); - void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, - uint32_t graph_id); - bool AddMemorySwapTask(const AnfNodePtr &kernel); - bool UpdateMemorySwapInfo(const session::KernelGraph *graph); - bool UpdateMemorySwapTask(const AnfNodePtr &kernel); - void UpdateHostSwapQueue(const DeviceAddressPtr device_address); - void UpdateDeviceSwapQueue(); - void ClearSwapQueue(); + void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel); + bool UpdateMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling); + bool AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bool profiling); + void UpdateHostSwapInQueue(const DeviceAddressPtr device_address, bool mock); + void UpdateHostSwapOutQueue(bool mock); + void ClearSwapInfo(bool mock); std::unordered_map mem_reuse_util_map_; std::unordered_map mem_swap_map_; + std::unordered_map is_first_step_map_; + std::unordered_map> graph_output_map_; + + MemReuseUtilPtr mem_reuse_util_{nullptr}; MemSwapManagerPtr mem_swap_manager_{nullptr}; }; MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc index e2395bbaf2..746aeda2cd 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc @@ -18,7 +18,7 @@ #include "runtime/device/gpu/gpu_memory_allocator.h" #include "runtime/device/gpu/cuda_driver.h" #include "utils/log_adapter.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/convert_utils_base.h" namespace mindspore { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h index 4b6eaa4e14..dd66a7d5ee 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ #include #include "runtime/device/gpu/cuda_driver.h" @@ -58,4 +58,4 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc index 0406c0f151..74cae92dea 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc @@ -47,11 +47,20 @@ void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address swap_out_queue_.emplace(device_address, event); } -void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { +void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr, + bool profiling, float *cost_time) { MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(host_addr.addr); - DeviceEvent event = nullptr; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); + DeviceEvent start = nullptr; + DeviceEvent end = nullptr; + if (profiling) { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create CUDA event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create CUDA event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, swap_in_stream_), + "Failed to record CUDA event to swap in stream."); + } else { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end, cudaEventDisableTiming), "Failed to create CUDA event."); + } DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); MS_EXCEPTION_IF_NULL(device_ptr); device_address->set_status(DeviceAddressStatus::kInHostToDevice); @@ -59,9 +68,27 @@ void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, CHECK_OP_RET_WITH_EXCEPT( CudaDriver::CopyHostMemToDeviceAsync(device_ptr, host_addr.addr, host_addr.size, swap_in_stream_), "Failed to copy host memory to device."); - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_in_stream_), + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(end, swap_in_stream_), "Failed to record CUDA event to swap in stream."); - swap_in_queue_.emplace(device_address, event); + if (profiling) { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(start), "Failed to sync event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(end), "Failed to sync event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ElapsedTime(cost_time, start, end), "Failed to record elapsed time."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(start), "Failed to destroy event."); + } + swap_in_queue_.emplace(device_address, end); +} + +void GPUMemCopyManager::AddMemSwapOutTaskMock(const DeviceAddressPtr &device_address) { + MS_EXCEPTION_IF_NULL(device_address); + device_address->set_status(DeviceAddressStatus::kInDeviceToHost); + swap_out_queue_mock_.emplace(device_address); +} + +void GPUMemCopyManager::AddMemSwapInTaskMock(const DeviceAddressPtr &device_address) { + MS_EXCEPTION_IF_NULL(device_address); + device_address->set_status(DeviceAddressStatus::kInHostToDevice); + swap_in_queue_mock_.emplace(device_address); } bool GPUMemCopyManager::SyncMemCopyStream(SwapKind swap_kind) { @@ -104,6 +131,24 @@ DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueue() { return device_address; } +DeviceAddressPtr GPUMemCopyManager::UpdateSwapOutQueueMock() { + if (swap_out_queue_mock_.empty()) { + return nullptr; + } + auto device_address = swap_out_queue_mock_.front(); + swap_out_queue_mock_.pop(); + return device_address; +} + +DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueueMock() { + if (swap_in_queue_mock_.empty()) { + return nullptr; + } + auto device_address = swap_in_queue_mock_.front(); + swap_in_queue_mock_.pop(); + return device_address; +} + bool GPUMemCopyManager::AllocHostPinnedMem(size_t size, void **addr) const { auto alloc_size = CudaDriver::AllocHostPinnedMem(size, addr); return alloc_size == size; @@ -126,6 +171,15 @@ void GPUMemCopyManager::ClearSwapQueue() { swap_in_queue_.pop(); } } + +void GPUMemCopyManager::ClearSwapQueueMock() { + while (!swap_out_queue_mock_.empty()) { + swap_out_queue_mock_.pop(); + } + while (!swap_in_queue_mock_.empty()) { + swap_in_queue_mock_.pop(); + } +} } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h index dc99b7f7d0..067972a38f 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ #include #include @@ -40,7 +40,12 @@ class GPUMemCopyManager : public MemCopyManager { void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; - void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; + void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr, bool profiling, + float *cost_time) override; + + void AddMemSwapOutTaskMock(const DeviceAddressPtr &device_address) override; + + void AddMemSwapInTaskMock(const DeviceAddressPtr &device_address) override; bool SyncMemCopyStream(SwapKind swap_kind) override; @@ -48,21 +53,29 @@ class GPUMemCopyManager : public MemCopyManager { DeviceAddressPtr UpdateSwapInQueue() override; + DeviceAddressPtr UpdateSwapOutQueueMock() override; + + DeviceAddressPtr UpdateSwapInQueueMock() override; + bool AllocHostPinnedMem(size_t size, void **addr) const override; void FreeHostPinnedMem(void *addr) const override; void ClearSwapQueue() override; + void ClearSwapQueueMock() override; + private: DeviceStream swap_out_stream_{nullptr}; DeviceStream swap_in_stream_{nullptr}; std::queue> swap_out_queue_; std::queue> swap_in_queue_; + std::queue swap_out_queue_mock_; + std::queue swap_in_queue_mock_; }; using GPUMemCopyManagerPtr = std::shared_ptr; } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc index ffa07eea0d..7702234ccd 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc @@ -16,7 +16,7 @@ #include "runtime/device/gpu/gpu_memory_manager.h" #include "runtime/device/gpu/gpu_memory_allocator.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/convert_utils.h" namespace mindspore { namespace device { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h index 533116cefc..1b8e28f5e9 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ #include #include "runtime/device/memory_manager.h" namespace mindspore { @@ -39,4 +39,4 @@ class GPUMemoryManager : public MemoryManager { } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h index f22ce8fe38..f31bc7aa1e 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ #include #include @@ -70,4 +70,4 @@ CNodePtr CreateStreamSwitchNode(const std::shared_ptr &ker } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index 4326987784..00bae0f6c0 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -23,7 +23,7 @@ #include "backend/kernel_compiler/kernel_build_info.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/common_utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/oplib/opinfo.h" diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h index b351f74fa3..00a2851f02 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_KERNEL_INFO_SETTER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_KERNEL_INFO_SETTER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_ #include #include @@ -66,4 +66,4 @@ class KernelAttr { } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_KERNEL_INFO_SETTER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc index 4605a0eb4e..d34bc55302 100644 --- a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc +++ b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc @@ -15,45 +15,24 @@ */ #include "runtime/device/gpu/mpi/mpi_initializer.h" - +#include #include #include #include +#include namespace mindspore { namespace device { namespace gpu { -MPIInitializer::MPIInitializer() { - int init_flag = 0; - if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { - return; - } - if (init_flag == 0) { - auto ret = MPI_Init(nullptr, nullptr); - if (ret != MPI_SUCCESS) { - return; - } - } - MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); - MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); -} - -MPIInitializer::~MPIInitializer() { - int finalized_flag = 0; - (void)MPI_Finalized(&finalized_flag); - if (finalized_flag == 0) { - (void)MPI_Finalize(); - } -} MPIInitializer &MPIInitializer::GetInstance() { static MPIInitializer instance; return instance; } -int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; } +int MPIInitializer::get_rank_id(const std::string &group) { return GetRankIDByGroup(group); } -int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; } +int MPIInitializer::get_rank_size(const std::string &group) { return GetGroupSize(group); } PYBIND11_MODULE(_ms_mpi, mpi_initializer) { mpi_initializer.doc() = "mindspore mpi python wrapper"; diff --git a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h index bd0a4aa948..20b0a4fba8 100644 --- a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h +++ b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h @@ -14,8 +14,11 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_MPI_MPI_INITIALIZER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_MPI_MPI_INITIALIZER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_ + +#include +#include "runtime/device/gpu/distribution/collective_wrapper.h" namespace mindspore { namespace device { @@ -25,18 +28,15 @@ class MPIInitializer { MPIInitializer(MPIInitializer const &) = delete; MPIInitializer &operator=(const MPIInitializer &) = delete; static MPIInitializer &GetInstance(); - static int get_rank_id(); - static int get_rank_size(); + static int get_rank_id(const std::string &group); + static int get_rank_size(const std::string &groups); private: - MPIInitializer(); - ~MPIInitializer(); - - int rank_id_; - int rank_size_; + MPIInitializer() = default; + ~MPIInitializer() = default; }; } // namespace gpu } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_MPI_MPI_INITIALIZER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc index bb1f7f723e..513fa68252 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -24,10 +24,10 @@ #include #include "backend/session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "common/trans.h" #include "utils/config_manager.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "runtime/device/ascend/profiling/profiling_manager.h" diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.h b/mindspore/ccsrc/runtime/device/kernel_adjust.h index dbd6f226af..b65dcd0121 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.h +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_ADJUST_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_ADJUST_H_ #include #include @@ -80,4 +80,4 @@ class KernelAdjust { }; } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_ADJUST_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_info.cc b/mindspore/ccsrc/runtime/device/kernel_info.cc index 692532e70b..a7a500ff95 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.cc +++ b/mindspore/ccsrc/runtime/device/kernel_info.cc @@ -73,6 +73,14 @@ DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const { return workspace_address_list_[index].get(); } +DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const { + if (index >= workspace_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return workspace_address_list_[index]; +} + bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { if (workspace_address_list_.empty()) { // parameter and valuenode diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h index baded9d9a3..e9d997cb5e 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.h +++ b/mindspore/ccsrc/runtime/device/kernel_info.h @@ -54,6 +54,7 @@ class KernelInfo : public KernelInfoDevice { bool OutputAddrExist(size_t index) const; bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); DeviceAddress *GetWorkspaceAddr(size_t index) const; + DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); kernel::KernelMod *MutableKernelMod() const; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 3de9af8c23..6173aa4faf 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -19,16 +19,17 @@ #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/trans.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "frontend/operator/ops.h" #include "pipeline/jit/parse/python_adapter.h" #include "backend/session/kernel_graph.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/optimizer/common/helper.h" #include "ir/value.h" using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; @@ -41,7 +42,7 @@ KernelRuntime::~KernelRuntime() { #endif } -bool KernelRuntime::Run(session::KernelGraph *graph) { +bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { bool ret = false; auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -72,7 +73,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph) { } // for D to impl -bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { +bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) { if (graph != nullptr) { return true; } @@ -150,11 +151,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { UpdateRefNodeOutputMem(graph); } -void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, +void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value, + const std::vector &input_tensors, session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); RunOpAssignInputMemory(input_tensors, graph); AssignStaticMemoryValueNode(graph); + RunOpAssignOutputNodeMemory(pre_output_value, graph); for (const auto &cnode : graph->execution_order()) { RunOpAssignOutputMemory(cnode); RunOpAssignWorkSpaceMemory(cnode); @@ -190,6 +193,39 @@ void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) { } } +bool KernelRuntime::DumpDataEnabled() { + bool ret = false; +#ifdef ENABLE_DUMP_E2E + DumpConfPtr dump_conf = GetDumpConf(); + MS_EXCEPTION_IF_NULL(dump_conf); + bool dump_flag = dump_conf->dump_enable(); + if (!dump_flag) { + return ret; + } + ret = true; +#endif + return ret; +} + +bool KernelRuntime::DumpDataEnabledIteration() { + bool ret = false; +#ifdef ENABLE_DUMP_E2E + if (!DumpDataEnabled()) { + return ret; + } + DumpConfPtr dump_conf = GetDumpConf(); + MS_EXCEPTION_IF_NULL(dump_conf); + uint32_t cur_iter = dump_conf->cur_iter() + 1; + if (dump_conf->dump_iter() != 0) { + if (cur_iter != dump_conf->dump_iter()) { + return ret; + } + } + ret = true; +#endif + return ret; +} + void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) { AssignStaticMemoryInput(graph); AssignStaticMemoryValueNode(graph); @@ -289,6 +325,45 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { } } +void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph) { + if (pre_output_value == nullptr) { + return; + } + std::vector pre_output_tensors; + TensorValueToTensor(pre_output_value, &pre_output_tensors); + MS_EXCEPTION_IF_NULL(graph); + auto output_nodes = graph->outputs(); + if (pre_output_tensors.size() != output_nodes.size()) { + MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size() + << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]"; + } + // share output address with pre output tensors + for (size_t i = 0; i < output_nodes.size(); ++i) { + auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0); + if (!output_node_with_index.first->isa()) { + MS_LOG(EXCEPTION) << "The output node should be a cnode , but it is " + << output_node_with_index.first->DebugString(); + } + auto real_output_cnode = output_node_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(real_output_cnode); + MS_EXCEPTION_IF_NULL(pre_output_tensors[i]); + if (pre_output_tensors[i]->device_address() == nullptr) { + MS_LOG(EXCEPTION) << "The address of pre output tensor [" << i << "] is a nullptr!"; + } + if (opt::IsNopNode(real_output_cnode)) { + if (real_output_cnode->inputs().size() < 2) { + MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString() + << " should large than one!"; + } + AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast(pre_output_tensors[i]->device_address()), + output_node_with_index.second, real_output_cnode->input(1).get()); + } else { + AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast(pre_output_tensors[i]->device_address()), + output_node_with_index.second, output_node_with_index.first.get()); + } + } +} + void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(mem_manager_); @@ -335,8 +410,10 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { output_type_id = AnfAlgo::GetOutputInferDataType(item, index); } auto tensor_size = CountNodeDeviceMemorySize(item, index); - auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); - auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { + MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; + } AnfAlgo::SetOutputAddr(address, index, item.get()); } } @@ -353,7 +430,6 @@ void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { continue; } - graph->AddFinalOutputKernel(item_with_index.first); if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { AssignCommunicationNodeMem(kStaticMem, item_with_index.first); } else { @@ -398,12 +474,12 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { } } -void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) { - AssignCommunicationNodeInputMem(node); - AssignCommunicationNodeOutputMem(flag, node); +void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { + AssignCommunicationNodeInputMem(type, node); + AssignCommunicationNodeOutputMem(type, node); } -void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) { +void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); auto kernel_mod = AnfAlgo::GetKernelMod(node); @@ -429,14 +505,22 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr total_size += mem_size; align_size_list.emplace_back(mem_size); } - uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); + + if (type == kReuseDynamicMem) { + // reuse communication op's all outputs' memory + type = kReuseDynamicCommMem; + } + uint8_t *output_ptr = nullptr; for (size_t j = 0; j < align_size_list.size(); ++j) { std::string output_format = AnfAlgo::GetOutputFormat(node, j); auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); - auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type); + auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type); MS_EXCEPTION_IF_NULL(address); - if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { - address->UpdateCommunicationAddress(); + if (output_ptr == nullptr) { + output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address); + MS_EXCEPTION_IF_NULL(output_ptr); + } else { + address->set_ptr(output_ptr); } AnfAlgo::SetOutputAddr(address, j, node.get()); output_ptr += align_size_list[j]; @@ -457,13 +541,13 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, return address; } -void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { +void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); size_t total_size = 0; - std::vector> addr_size; + std::vector> addr_size; for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); auto input_node = input_node_with_index.first; @@ -476,9 +560,12 @@ void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(address); auto mem_size = mem_manager_->GetCommonAlignSize(address->size()); total_size += mem_size; - addr_size.emplace_back(address.get(), mem_size); + addr_size.emplace_back(address, mem_size); } - uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size); + if (addr_size.empty()) { + return; + } + uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first); for (const auto &iter : addr_size) { MS_EXCEPTION_IF_NULL(iter.first); iter.first->set_ptr(input_ptr); @@ -486,14 +573,12 @@ void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { } } -void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); +void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); - if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) { MS_LOG(INFO) << "GetNext disable mem_reuse"; - flag = kDynamicMem; + type = kDynamicMem; } auto kernel_mod = AnfAlgo::GetKernelMod(node); MS_EXCEPTION_IF_NULL(kernel_mod); @@ -510,19 +595,13 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in MS_LOG(INFO) << "Already malloc index:" << i; continue; } - auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]); - if (ptr == nullptr) { - // reused ptr, no need alloc, continue; - continue; - } std::string output_format = AnfAlgo::GetOutputFormat(node, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); - auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); MS_EXCEPTION_IF_NULL(device_address); + uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address); + MS_EXCEPTION_IF_NULL(ptr); device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); - if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { - device_address->UpdateCommunicationAddress(); - } AnfAlgo::SetOutputAddr(device_address, i, node.get()); } } @@ -534,36 +613,40 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const MS_EXCEPTION_IF_NULL(mem_manager_); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - auto tensor = node_value->cast(); - if (tensor == nullptr) { - MS_LOG(WARNING) << "Tensor is null"; - return; - } - size_t tensor_size = tensor->data().nbytes(); - auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); - } - auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); - DeviceAddressPtr address = nullptr; - if (ms_context->enable_pynative_infer()) { + std::vector tensors; + TensorValueToTensor(node_value, &tensors); + for (const auto &tensor : tensors) { + if (tensor == nullptr) { + MS_LOG(WARNING) << "Tensor is null"; + return; + } + if (tensor->device_address() != nullptr) { + AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast(tensor->device_address()), output_idx++, + value_node.get()); + continue; + } + size_t tensor_size = tensor->data().nbytes(); + auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); + } + auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); + DeviceAddressPtr address = nullptr; address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); MS_EXCEPTION_IF_NULL(address); - if (!mem_manager_->MallocMemFromMemPool(address, node_size)) { - MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; + if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) { + MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size; + } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) { + MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size; + } + AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); + if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), + tensor->data_c())) { + MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() + << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx) + << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx); } - } else { - auto ptr = mem_manager_->MallocMem(kStaticMem, node_size); - address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id); - MS_EXCEPTION_IF_NULL(address); - } - AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); - if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), - tensor->data_c())) { - MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" - << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " - << AnfAlgo::GetOutputInferDataType(value_node, output_idx); } } @@ -580,22 +663,18 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { } auto &node_value = value_node->value(); MS_EXCEPTION_IF_NULL(node_value); - if (node_value->isa()) { + if (node_value->isa() || node_value->isa()) { AssignValueNodeTensor(value_node, node_value, 0); } else if (node_value->isa()) { auto value = GetValue(node_value); size_t tensor_size = value.size(); DeviceAddressPtr address = nullptr; - if (ms_context->enable_pynative_infer()) { - address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); - MS_EXCEPTION_IF_NULL(address); - if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) { - MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; - } - } else { - auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); - address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); - MS_EXCEPTION_IF_NULL(address); + address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); + MS_EXCEPTION_IF_NULL(address); + if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { + MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size; + } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { + MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; } AnfAlgo::SetOutputAddr(address, 0, value_node.get()); std::vector shape = {1, SizeToInt(tensor_size)}; @@ -612,10 +691,10 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); - auto mem_flag = kDynamicMem; + auto mem_type = kDynamicMem; if (is_enable_mem_reuse) { mem_manager_->MallocReusedDynamicMem(graph); - mem_flag = kReuseDynamicMem; + mem_type = kReuseDynamicMem; } auto &execution_nodes = graph->execution_order(); std::vector compute_nodes; @@ -623,7 +702,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { for (auto &node : execution_nodes) { if (AnfAlgo::IsCommunicationOp(node)) { // skip if the memory is already alocated - AssignCommunicationNodeMem(mem_flag, node); + AssignCommunicationNodeMem(mem_type, node); } else { compute_nodes.emplace_back(node); } @@ -631,19 +710,19 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { // then compute nodes for (auto &node : compute_nodes) { - AssignNodeOutputMem(mem_flag, node, kGetAllOuts); - AssignWorkSpaceMem(mem_flag, node); + AssignNodeOutputMem(mem_type, node, kGetAllOuts); + AssignWorkSpaceMem(mem_type, node); } } -void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { +void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); auto kernel_mod = AnfAlgo::GetKernelMod(node); MS_EXCEPTION_IF_NULL(kernel_mod); size_t index = 0; for (auto &size : kernel_mod->GetWorkspaceSizeList()) { - auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size); + auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size); AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); index++; } @@ -761,6 +840,28 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; } +bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, + const AddressPtrList &kernel_inputs, + const AddressPtrList &kernel_outputs, + const AddressPtrList &kernel_workspaces) const { + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + if (!ret) { + MS_LOG(ERROR) << "Launch kernel failed."; + return false; + } + return true; +} + +DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type) { + auto device_address = CreateDeviceAddress(nullptr, size, format, type); + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto base_ptr = mem_manager_->MallocMem(kStaticMem, size, device_address); + MS_EXCEPTION_IF_NULL(base_ptr); + return device_address; +} + #ifdef ENABLE_DUMP_E2E bool KernelRuntime::SetDumpConf() { dump_conf_ptr_ = std::make_shared(); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 8320355b82..cc33c5646e 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_ #include #include #include @@ -23,7 +23,7 @@ #include "runtime/device/device_address.h" #include "ir/tensor.h" -#include "predict/generator/utils/ir_model_util.h" +#include "utils/convert_utils.h" #ifdef ENABLE_DUMP_E2E #include "debug/e2e_dump.h" #endif @@ -33,7 +33,7 @@ #include "backend/session/kernel_graph.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/kernel.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "runtime/device/memory_manager.h" using mindspore::tensor::Tensor; @@ -53,14 +53,20 @@ class KernelRuntime { virtual ~KernelRuntime(); virtual bool Init() = 0; virtual void AssignMemory(session::KernelGraph *graph); - void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); + void RunOpAssignMemory(const ValuePtr &pre_output_value, const std::vector &input_tensors, + session::KernelGraph *graph); void RunOpClearMemory(const session::KernelGraph *graph); - virtual bool Run(session::KernelGraph *graph); - virtual bool DumpData(session::KernelGraph *graph); + bool DumpDataEnabled(); + bool DumpDataEnabledIteration(); + virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr); + virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr); virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); virtual bool RunTask(const session::KernelGraph *graph); virtual bool GenTask(const session::KernelGraph *graph); bool LaunchKernel(const session::KernelGraph *graph); + bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs, + const AddressPtrList &kernel_outputs, + const AddressPtrList &kernel_workspaces) const; virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); virtual void ClearGraphRuntimeResource(uint32_t graph_id); @@ -73,6 +79,7 @@ class KernelRuntime { // for GPU and D to impl virtual void ReleaseDeviceRes() {} void set_device_id(uint32_t device_id) { device_id_ = device_id; } + DeviceAddressPtr AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type); protected: virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, @@ -81,15 +88,15 @@ class KernelRuntime { void AssignStaticMemory(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph); - void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); - void AssignWorkSpaceMem(int flag, const AnfNodePtr &node); + void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index); + void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node); void AssignReuseWorkSpaceMem(const AnfNodePtr &node); void UpdateRefNodeOutputMem(const session::KernelGraph *graph); - void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node); - void AssignCommunicationNodeInputMem(const AnfNodePtr &node); - void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node); + void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node); + void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node); + void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node); #ifdef ENABLE_DUMP_E2E bool SetDumpConf(); #endif @@ -104,6 +111,7 @@ class KernelRuntime { void RunOpAssignInputMemory(const std::vector &input_tensors, const session::KernelGraph *graph); void RunOpAssignOutputMemory(const AnfNodePtr &kernel); void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); + void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); @@ -119,4 +127,4 @@ using KernelRuntimePtr = std::shared_ptr; } // namespace device } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h index 7fcb40ae67..bf88f53087 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_MANAGER_H_ #include #include #include #include #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "runtime/device/kernel_runtime.h" namespace mindspore { namespace device { @@ -62,4 +62,4 @@ class KernelRuntimeRegistrar { DEVICE_NAME, []() { return std::make_shared(); }); } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc index 563d5f0f50..cd3cd620b5 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -16,7 +16,7 @@ #include "runtime/device/memory_manager.h" #include "backend/session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" using mindspore::memreuse::BestFitMemReuse; using mindspore::memreuse::MemReuseUtilPtr; namespace mindspore { @@ -29,7 +29,7 @@ size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; } -void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { +void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); @@ -45,8 +45,10 @@ void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { mem_reuse_util_ptr_->set_mem_base(base_ptr); } -uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { +uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, + const DeviceAddressPtr &address) { MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(address); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); uint8_t *ptr = nullptr; @@ -55,40 +57,53 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in if (context_ptr->enable_hccl()) { communication_mem = true; } - if (flag == kStaticMem) { + if (type == kStaticMem) { ptr = MallocStaticMem(size, communication_mem); + address->from_mem_pool_ = true; + if (communication_mem) { + address->communication_ptr_ = ptr - kMemAlignSize; + } + } else if (type == kReuseDynamicCommMem) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); + ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); } else { ptr = MallocDynamicMem(size, communication_mem); } + address->ptr_ = ptr; return ptr; } - if (flag == kStaticMem) { + if (type == kStaticMem) { ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { + address->from_mem_pool_ = true; + } else if (type == kDynamicMem) { ptr = MallocDynamicMem(size, false); - } else if (flag == kReuseDynamicMem) { + } else if (type == kReuseDynamicMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); } + address->ptr_ = ptr; return ptr; } -uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { - if (flag == kReuseDynamicMem) { +uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) { + if (type == kReuseDynamicMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); } return MallocDynamicMem(size, false); } -uint8_t *MemoryManager::MallocMem(int flag, size_t size) { +uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { + MS_EXCEPTION_IF_NULL(address); uint8_t *ptr = nullptr; - if (flag == kStaticMem) { + if (type == kStaticMem) { ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { + address->from_mem_pool_ = true; + } else if (type == kDynamicMem) { ptr = MallocDynamicMem(size, false); } + address->ptr_ = ptr; return ptr; } diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h index 3c6fb1b39a..cb045f8d27 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -14,17 +14,16 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_ #include #include +#include #include "backend/optimizer/mem_reuse/mem_reuse.h" #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" namespace mindspore { namespace device { -const int kStaticMem = 0; -const int kDynamicMem = 1; -const int kReuseDynamicMem = 2; +enum MemType { kStaticMem, kDynamicMem, kReuseDynamicMem, kReuseDynamicCommMem }; const int kGetAllOuts = -1; const uint64_t kMemAlignSize = 512; using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; @@ -41,10 +40,11 @@ class MemoryManager { dynamic_mem_offset_ = 0; } - void MallocReusedDynamicMem(session::KernelGraph *graph); - uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - virtual uint8_t *MallocMem(int flag, size_t size); + void MallocReusedDynamicMem(const session::KernelGraph *graph); + uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, + const DeviceAddressPtr &address); + uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); + virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address); virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); virtual void *MallocMemFromMemPool(size_t size); @@ -70,4 +70,4 @@ class MemoryManager { }; } // namespace device } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/all_ops.h b/mindspore/ccsrc/transform/graph_ir/all_ops.h index 1206b3c7a8..d815784e7b 100644 --- a/mindspore/ccsrc/transform/graph_ir/all_ops.h +++ b/mindspore/ccsrc/transform/graph_ir/all_ops.h @@ -14,9 +14,9 @@ * limitations under the License. */ -#ifndef TRANSFORM_ALL_OPS_H_ -#define TRANSFORM_ALL_OPS_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_ALL_OPS_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_ALL_OPS_H_ // old #include "ops/all_ops.h" -#endif // TRANSFORM_ALL_OPS_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_ALL_OPS_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 1978181fec..f686e13f2e 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -23,12 +23,12 @@ #include "frontend/operator/ops.h" #include "utils/log_adapter.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/symbolic.h" #include "utils/config_manager.h" #include "utils/convert_utils.h" #include "./common.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace transform { @@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum"; const char kNameIsFinite[] = "isFinite"; const char kNameReciprocal[] = "Reciprocal"; const char kNameRsqrt[] = "Rsqrt"; -const char kNameRsqrtGrad[] = "RsqrtGrad"; const char kNameSqrt[] = "Sqrt"; const char kNameSquare[] = "Square"; const char kNameSquaredDifference[] = "SquaredDifference"; @@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad"; const char kNameConvolution[] = "Convolution"; const char kNameBiasAdd[] = "BiasAdd"; const char kNameMaxPoolGrad[] = "MaxPoolGrad"; +const char kNameRsqrtGrad[] = "RsqrtGrad"; +const char kNameSqrtGrad[] = "SqrtGrad"; +const char kNameReciprocalGrad[] = "ReciprocalGrad"; const char kNameAvgPoolGrad[] = "AvgPoolGrad"; const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; const char kNameApplyMomentum[] = "ApplyMomentum"; @@ -201,12 +203,16 @@ const char kNameBatchToSpace[] = "BatchToSpace"; const char kNameAtan2[] = "Atan2"; const char kNameApplyRMSProp[] = "ApplyRMSProp"; const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; +const char kNameBasicLSTMCell[] = "BasicLSTMCell"; +const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad"; +const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad"; +const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad"; const char kNameL2Loss[] = "L2Loss"; const char kNameCTCLoss[] = "CTCLoss"; const char kNameRange[] = "Range"; const char kNameSquareSumAll[] = "SquareSumAll"; -const char kNameAscendQuant[] = "AscendQuant"; -const char kNameAscendDequant[] = "AscendDequant"; +const char kNameAscendQuant[] = "Quant"; +const char kNameAscendDequant[] = "Dequant"; const char kNameCase[] = "Case"; // -----------------OpAdapter initialization-------------- @@ -229,6 +235,9 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, + {string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)}, + {string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)}, + {string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)}, {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, @@ -409,7 +418,11 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, {string(kNameAtan2), ADPT_DESC(Atan2)}, {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, - {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, + {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)}, + {string(kNameBasicLSTMCell), ADPT_DESC(BasicLSTMCell)}, + {string(kNameBasicLSTMCellInputGrad), ADPT_DESC(BasicLSTMCellInputGrad)}, + {string(kNameBasicLSTMCellWeightGrad), ADPT_DESC(BasicLSTMCellWeightGrad)}, + {string(kNameBasicLSTMCellCStateGrad), ADPT_DESC(BasicLSTMCellCStateGrad)}, {string(kNameL2Loss), ADPT_DESC(L2Loss)}, {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, {string(kNameRange), ADPT_DESC(RangeD)}, @@ -1334,9 +1347,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node } for (size_t i = 1; i < input_size; i++) { - auto pred = inputs[i]; + AnfNodePtr pred = nullptr; if (case_flag != 0) { pred = case_input_handle_cache_[node.get()]->at(i - 1); + } else { + pred = inputs[i]; } while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "Depend") { diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h index 6fa27831bf..24c9fbff67 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.h +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ -#define MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ #define DRAW_GE_GRAPH @@ -255,4 +255,4 @@ class DfGraphConvertor { } // namespace transform } // namespace mindspore -#endif // MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h index 8a574b7a04..0a24845f83 100644 --- a/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h +++ b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h @@ -13,8 +13,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_DF_GRAPH_MANAGER_H_ -#define TRANSFORM_DF_GRAPH_MANAGER_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_DF_GRAPH_MANAGER_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_DF_GRAPH_MANAGER_H_ #include #include @@ -83,4 +83,4 @@ class DfGraphManager { } // namespace transform } // namespace mindspore -#endif // TRANSFORM_DF_GRAPH_MANAGER_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_DF_GRAPH_MANAGER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/graph_builder.h b/mindspore/ccsrc/transform/graph_ir/graph_builder.h index 5162674242..1262625673 100644 --- a/mindspore/ccsrc/transform/graph_ir/graph_builder.h +++ b/mindspore/ccsrc/transform/graph_ir/graph_builder.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_GRAPH_BUILDER_H_ -#define TRANSFORM_GRAPH_BUILDER_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_BUILDER_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_BUILDER_H_ #include #include @@ -31,4 +31,4 @@ Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phas } // namespace transform } // namespace mindspore -#endif // TRANSFORM_GRAPH_BUILDER_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_BUILDER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/graph_runner.h b/mindspore/ccsrc/transform/graph_ir/graph_runner.h index 92db9e1413..55f8beddae 100644 --- a/mindspore/ccsrc/transform/graph_ir/graph_runner.h +++ b/mindspore/ccsrc/transform/graph_ir/graph_runner.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_GRAPH_RUNNER_H_ -#define TRANSFORM_GRAPH_RUNNER_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_ #include #include @@ -60,4 +60,4 @@ class GraphRunner { } // namespace transform } // namespace mindspore -#endif // TRANSFORM_GRAPH_RUNNER_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.h b/mindspore/ccsrc/transform/graph_ir/op_adapter.h index 358cbd20a1..52dac483ed 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_OP_ADAPTER_H_ -#define TRANSFORM_OP_ADAPTER_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ #include #include @@ -910,4 +910,4 @@ std::unordered_map> OpAdapter< } // namespace transform } // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h index 77e28dda94..9d59534fb4 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_OP_ADAPTER_BASE_H_ -#define TRANSFORM_OP_ADAPTER_BASE_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ #include #include @@ -195,4 +195,4 @@ struct AnyTraits { using ExtraAttr = std::unordered_map; } // namespace transform } // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_BASE_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h index 0a0d745ba2..d80aa3b5b3 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_OP_ADAPTER_UTIL_H_ -#define TRANSFORM_OP_ADAPTER_UTIL_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_ #include #include @@ -63,4 +63,4 @@ bool IsCustomPrim(const PrimitivePtr &prim); bool IsCustomCNode(const AnfNodePtr &node); } // namespace transform } // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_UTIL_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc index a730093606..bce309f3b3 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc @@ -715,6 +715,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits< {"data_format", ATTR_DESC(data_format, AnyTraits())}}; OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; +// RsqrtGrad +INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}}; + +// SqrtGrad +INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}}; + +// ReciprocalGrad +INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}}; + // avgpoolgrad INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}}; ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, @@ -746,7 +761,7 @@ ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits(), OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}}; // Conv2D -INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; +INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}}; ATTR_MAP(Conv2D) = { {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, @@ -783,7 +798,7 @@ ATTR_MAP(Conv2DBackpropFilterD) = { OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; // DepthwiseConv2D -INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; +INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}}; ATTR_MAP(DepthwiseConv2D) = { {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, @@ -815,7 +830,7 @@ ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; // MatMulV2 -INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(bias)}}; ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits())}, {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits())}}; OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}}; @@ -1273,12 +1288,41 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits())}, ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; -// ApplyCenteredRMSProp -INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, - {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, - {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; -ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; +// ApplyCenteredRMSPropD +INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, + {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; +ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyCenteredRMSPropD) = { + {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; + +// BasicLSTMCell +INPUT_MAP(BasicLSTMCell) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(c)}, {4, INPUT_DESC(w)}, {5, INPUT_DESC(b)}}; +ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits())}, + {"forget_bias", ATTR_DESC(forget_bias, AnyTraits())}, + {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits())}, + {"activation", ATTR_DESC(activation, AnyTraits())}}; +OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)}, + {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}}; + +// BasicLSTMCellInputGrad +INPUT_MAP(BasicLSTMCellInputGrad) = {{1, INPUT_DESC(dgate)}, {2, INPUT_DESC(w)}}; +ATTR_MAP(BasicLSTMCellInputGrad) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits())}}; +OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht)}}; + +// BasicLSTMCellWeightGrad +INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}}; +ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}}; + +// BasicLSTMCellCStateGrad +INPUT_MAP(BasicLSTMCellCStateGrad) = {{1, INPUT_DESC(c)}, {2, INPUT_DESC(dht)}, {3, INPUT_DESC(dct)}, + {4, INPUT_DESC(it)}, {5, INPUT_DESC(jt)}, {6, INPUT_DESC(ft)}, + {7, INPUT_DESC(ot)}, {8, INPUT_DESC(tanhct)}}; +ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyTraits())}, + {"activation", ATTR_DESC(activation, AnyTraits())}}; +OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}}; // L2Loss INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; @@ -1307,7 +1351,8 @@ OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}}; // AscendDequant INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}}; ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits())}, - {"relu_flag", ATTR_DESC(relu_flag, AnyTraits())}}; + {"relu_flag", ATTR_DESC(relu_flag, AnyTraits())}, + {"dtype", ATTR_DESC(dtype, AnyTraits())}}; OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}}; #ifdef ENABLE_GE // Print diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h index 472dd00275..186a6f43c3 100755 --- a/mindspore/ccsrc/transform/graph_ir/op_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_OP_DECLARE_H_ -#define TRANSFORM_OP_DECLARE_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_H_ #include #include @@ -434,6 +434,12 @@ DECLARE_OP_ADAPTER(MaxPool) DECLARE_OP_USE_OUTPUT(MaxPool) DECLARE_OP_ADAPTER(MaxPoolGrad) DECLARE_OP_USE_OUTPUT(MaxPoolGrad) +DECLARE_OP_ADAPTER(SqrtGrad) +DECLARE_OP_USE_OUTPUT(SqrtGrad) +DECLARE_OP_ADAPTER(ReciprocalGrad) +DECLARE_OP_USE_OUTPUT(ReciprocalGrad) +DECLARE_OP_ADAPTER(RsqrtGrad) +DECLARE_OP_USE_OUTPUT(RsqrtGrad) DECLARE_OP_ADAPTER(AvgPool) DECLARE_OP_USE_OUTPUT(AvgPool) DECLARE_OP_ADAPTER(AvgPoolGrad) @@ -481,8 +487,16 @@ DECLARE_OP_USE_OUTPUT(Atan2) DECLARE_OP_ADAPTER(ApplyRMSPropD) DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) -DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) -DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) +DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) +DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) +DECLARE_OP_ADAPTER(BasicLSTMCell) +DECLARE_OP_USE_OUTPUT(BasicLSTMCell) +DECLARE_OP_ADAPTER(BasicLSTMCellInputGrad) +DECLARE_OP_USE_OUTPUT(BasicLSTMCellInputGrad) +DECLARE_OP_ADAPTER(BasicLSTMCellWeightGrad) +DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad) +DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad) +DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad) DECLARE_OP_ADAPTER(L2Loss) DECLARE_OP_USE_OUTPUT(L2Loss) DECLARE_OP_ADAPTER(CTCLoss) @@ -497,4 +511,4 @@ DECLARE_OP_USE_DYN_INPUT(Print) #endif } // namespace transform } // namespace mindspore -#endif // TRANSFORM_OP_DECLARE_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/types.h b/mindspore/ccsrc/transform/graph_ir/types.h index d8c57b6255..0b76393231 100644 --- a/mindspore/ccsrc/transform/graph_ir/types.h +++ b/mindspore/ccsrc/transform/graph_ir/types.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_TYPES_H_ -#define TRANSFORM_TYPES_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_TYPES_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_TYPES_H_ #include #include @@ -56,4 +56,4 @@ using TensorMap = std::unordered_map>; } // namespace transform } // namespace mindspore -#endif // TRANSFORM_TYPES_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_TYPES_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/util.h b/mindspore/ccsrc/transform/graph_ir/util.h index 32d4242c4f..14f251a26e 100644 --- a/mindspore/ccsrc/transform/graph_ir/util.h +++ b/mindspore/ccsrc/transform/graph_ir/util.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef TRANSFORM_UTIL_H_ -#define TRANSFORM_UTIL_H_ +#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_ +#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_ #include #include @@ -238,4 +238,4 @@ class TransformUtil { } // namespace transform } // namespace mindspore -#endif // TRANSFORM_UTIL_H_ +#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_ diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/onnx/ir_exporter.cc index 78858eea8a..60a1cdab28 100644 --- a/mindspore/ccsrc/transform/onnx/ir_exporter.cc +++ b/mindspore/ccsrc/transform/onnx/ir_exporter.cc @@ -187,7 +187,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); initializer_proto->set_name(param_name); SetParamToTensorProto(param, initializer_proto); - auto tensor = std::dynamic_pointer_cast(param->default_param()->value()); + auto tensor = std::dynamic_pointer_cast(param->default_param()); if (tensor) { initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); } diff --git a/mindspore/ccsrc/transform/onnx/onnx_exporter.cc b/mindspore/ccsrc/transform/onnx/onnx_exporter.cc index f69fb81a7e..8f0b72f78e 100644 --- a/mindspore/ccsrc/transform/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/transform/onnx/onnx_exporter.cc @@ -449,7 +449,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP initializer_proto->set_name(param_ptr->ToString()); SetTensorProtoInfo(param_ptr, initializer_proto); // set value for initializer - auto tensor = std::dynamic_pointer_cast(param_ptr->default_param()->value()); + auto tensor = std::dynamic_pointer_cast(param_ptr->default_param()); if (tensor) { initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); } diff --git a/mindspore/core/ir/anf_py.cc b/mindspore/ccsrc/utils/anf_py.cc similarity index 100% rename from mindspore/core/ir/anf_py.cc rename to mindspore/ccsrc/utils/anf_py.cc diff --git a/mindspore/ccsrc/utils/any.h b/mindspore/ccsrc/utils/any.h deleted file mode 100644 index d5da5b2938..0000000000 --- a/mindspore/ccsrc/utils/any.h +++ /dev/null @@ -1,214 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_UTILS_ANY_H_ -#define MINDSPORE_CCSRC_UTILS_ANY_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "utils/overload.h" -#include "utils/log_adapter.h" -#include "utils/misc.h" - -namespace mindspore { -// usage:AnyPtr sp = std::make_shared(aname); -template -std::string type(const T &t) { - return demangle(typeid(t).name()); -} - -class Any { - public: - // constructors - Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {} - Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} - Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} - - Any &operator=(Any &&other); - // right reference constructor - template ::type, Any>::value, T>::type> - Any(T &&t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT - BasePtr new_val(new Derived::type>(std::forward(t))); - std::swap(m_ptr, new_val); - } - - ~Any() = default; - - // judge whether is empty - bool empty() const { return m_ptr == nullptr; } - - // judge the is relation - template - bool is() const { - return m_tpIndex == std::type_index(typeid(T)); - } - - const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); } - - std::size_t Hash() const { - std::stringstream buffer; - buffer << m_tpIndex.name(); - if (m_ptr != nullptr) { - buffer << m_ptr->GetString(); - } - return std::hash()(buffer.str()); - } - - template - bool Apply(const std::function &fn) { - if (type() == typeid(T)) { - T x = cast(); - fn(x); - return true; - } - return false; - } - - std::string GetString() const { - if (m_ptr != nullptr) { - return m_ptr->GetString(); - } else { - return std::string(""); - } - } - - friend std::ostream &operator<<(std::ostream &os, const Any &any) { - os << any.GetString(); - return os; - } - - // type cast - template - T &cast() const { - if (!is() || !m_ptr) { - // Use MS_LOGFATAL replace throw std::bad_cast() - MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name(); - } - auto ptr = static_cast *>(m_ptr.get()); - return ptr->m_value; - } - - bool operator==(const Any &other) const { - if (m_tpIndex != other.m_tpIndex) { - return false; - } - if (m_ptr == nullptr && other.m_ptr == nullptr) { - return true; - } - if (m_ptr == nullptr || other.m_ptr == nullptr) { - return false; - } - return *m_ptr == *other.m_ptr; - } - - bool operator!=(const Any &other) const { return !(operator==(other)); } - - Any &operator=(const Any &other); - - bool operator<(const Any &other) const; - - std::string ToString() const { - std::ostringstream buffer; - if (m_tpIndex == typeid(float)) { - buffer << " " << cast(); - } else if (m_tpIndex == typeid(double)) { - buffer << " " << cast(); - } else if (m_tpIndex == typeid(int)) { - buffer << " " << cast(); - } else if (m_tpIndex == typeid(bool)) { - buffer << " " << cast(); - } else { - buffer << "<" << demangle(m_tpIndex.name()) << "> " << m_ptr->GetString(); - } - return buffer.str(); - } - __attribute__((used)) void dump() const { std::cout << ToString() << std::endl; } - - private: - struct Base; - using BasePtr = std::unique_ptr; - - // type base definition - struct Base { - virtual const std::type_info &type() const = 0; - virtual BasePtr clone() const = 0; - virtual ~Base() = default; - virtual bool operator==(const Base &other) const = 0; - virtual std::string GetString() = 0; - }; - - template - struct Derived : public Base { - template - explicit Derived(Args &&... args) : m_value(std::forward(args)...), serialize_cache_("") {} - - bool operator==(const Base &other) const override { - if (typeid(*this) != typeid(other)) { - return false; - } - return m_value == static_cast &>(other).m_value; - } - - const std::type_info &type() const override { return typeid(T); } - - BasePtr clone() const override { return BasePtr(new Derived(m_value)); } - - ~Derived() override {} - - std::string GetString() override { - std::stringstream buffer; - buffer << m_value; - return buffer.str(); - } - - T m_value; - std::string serialize_cache_; - }; - - // clone method - BasePtr clone() const { - if (m_ptr != nullptr) { - return m_ptr->clone(); - } - return nullptr; - } - - BasePtr m_ptr; // point to real data - std::type_index m_tpIndex; // type info of data -}; - -using AnyPtr = std::shared_ptr; - -struct AnyHash { - std::size_t operator()(const Any &c) const { return c.Hash(); } -}; - -struct AnyLess { - bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); } -}; - -bool AnyIsLiteral(const Any &any); - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_ANY_H_ diff --git a/mindspore/ccsrc/utils/base_ref.cc b/mindspore/ccsrc/utils/base_ref.cc deleted file mode 100644 index b0d3564c1c..0000000000 --- a/mindspore/ccsrc/utils/base_ref.cc +++ /dev/null @@ -1,188 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/base_ref.h" - -namespace mindspore { -iterator ConstIteratorCast(std::vector *v, const const_iterator iter) { - return std::next(v->begin(), std::distance(v->cbegin(), iter)); -} - -BaseRef::BaseRef(const BaseRef &other) : Base(other), m_ptr(other.m_ptr) { - if (!m_ptr) { - m_ptr = other.copy(); - } -} - -bool BaseRef::operator==(const BaseRef &other) const { - if (m_ptr == other.m_ptr) { - return true; - } - if (m_ptr == nullptr && other.m_ptr == nullptr) { - return *this == other; - } - if (m_ptr == nullptr || other.m_ptr == nullptr) { - return false; - } - if (type() != other.type()) { - MS_LOG(DEBUG) << "Type mismatch"; - return false; - } - if (m_ptr->isa()) { - return *(m_ptr->cast()) == *(other.m_ptr->cast()); - } - - // for noderef equal - if (m_ptr->isa()) { - return *std::static_pointer_cast(m_ptr) == *std::static_pointer_cast(other.m_ptr); - } - - // for node equal - return *m_ptr == *other.m_ptr; -} - -// left reference -BaseRef &BaseRef::operator=(const BaseRef &other) { - if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { - return *this; - } - m_ptr = other.copy(); - return *this; -} - -// right reference -BaseRef &BaseRef::operator=(BaseRef &&other) { - if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { - return *this; - } - m_ptr = other.copy(); - other.m_ptr = nullptr; - return *this; -} - -std::string BaseRef::ToString() const { - if (m_ptr != nullptr) { - return std::string(m_ptr->type_name()) + std::string(" value:") + m_ptr->ToString(); - } - return std::string(); -} - -uint32_t BaseRef::type() const { - if (m_ptr != nullptr) { - return m_ptr->tid(); - } - return tid(); -} - -// left reference -SetRef &SetRef::operator=(const SetRef &other) { - if (elements_ == other.elements_ || this == &other) { - return *this; - } - elements_ = other.elements_; - return *this; -} - -std::string SetRef::ToString() const { - std::ostringstream buffer; - bool begin = true; - buffer << "set["; - for (auto &attr : elements_) { - if (!begin) { - buffer << ", "; - } else { - begin = false; - } - buffer << attr.ToString(); - } - buffer << "]"; - return buffer.str(); -} - -// left reference -VectorRef &VectorRef::operator=(const VectorRef &other) { - if (elements_ == other.elements_ || this == &other) { - return *this; - } - elements_ = other.elements_; - return *this; -} - -std::string VectorRef::ToString() const { - std::ostringstream buffer; - bool begin = true; - buffer << "vector["; - for (auto &attr : elements_) { - if (!begin) { - buffer << ", "; - } else { - begin = false; - } - buffer << attr.ToString(); - } - buffer << "]"; - return buffer.str(); -} - -bool VectorRef::operator==(const BaseRef &other) const { - if (!utils::isa(other)) { - return false; - } - return *this == utils::cast(other); -} - -bool VectorRef::operator==(const VectorRef &other) const { - if (elements_.size() != other.elements_.size()) { - return false; - } - for (size_t i = 0; i < elements_.size(); ++i) { - if (elements_[i] != other.elements_[i]) { - return false; - } - } - return true; -} - -bool SetRef::operator==(const BaseRef &other) const { - if (!utils::isa(other)) { - return false; - } - return *this == utils::cast(other); -} - -bool SetRef::operator==(const SetRef &other) const { - if (elements_.size() != other.elements_.size()) { - return false; - } - auto iter = elements_.begin(); - auto oth_iter = other.elements_.begin(); - for (; iter != elements_.end(); iter++, oth_iter++) { - if (*iter != *oth_iter) { - return false; - } - } - return true; -} - -bool RunFunctionRef::operator==(const BaseRef &other) const { - if (!utils::isa(other)) { - return false; - } - return *this == utils::cast(other); -} - -bool RunFunctionRef::operator==(const RunFunctionRef &other) const { return func_ == other.func_; } -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.h b/mindspore/ccsrc/utils/base_ref.h deleted file mode 100644 index 7c0b4b2f1c..0000000000 --- a/mindspore/ccsrc/utils/base_ref.h +++ /dev/null @@ -1,381 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_H_ -#define MINDSPORE_CCSRC_UTILS_BASE_REF_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/value.h" - -namespace mindspore { -class BaseRef; -class VectorRef; -class SetRef; -class RunFunctionRef; - -using iterator = std::vector::iterator; -using const_iterator = std::vector::const_iterator; -using const_reverse_iterator = std::vector::const_reverse_iterator; - -using RunFunc = std::function; -using RunFuncPtr = std::shared_ptr; - -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_const_t = typename std::remove_const::type; -template -using is_base = std::is_base_of>; -template -using is_value = std::is_base_of>; -template -using is_base_ref = std::is_base_of>; - -iterator ConstIteratorCast(std::vector *v, const_iterator iter); - -inline std::shared_ptr MakeNode(const std::vector &elements) { - return std::make_shared(elements); -} - -inline std::shared_ptr MakeNode(std::initializer_list elements) { - return std::make_shared(elements); -} - -// Anfnode, Funcgraph and some not value node class -template >::value && is_base::value, - int>::type = 0> -inline BasePtr MakeNode(const T &v) { - return v; -} - -template >::value && !is_base_ref::value, int>::type = 0> -inline BasePtr MakeNode(const T &v) { - return MakeValue(v); -} - -inline std::shared_ptr MakeNode(const VectorRef &a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const AnfNodePtrList &a) { - std::vector ret; - (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; }); - return std::make_shared(ret); -} -inline std::shared_ptr MakeNode(const SetRef &a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const RunFuncPtr &a) { return std::make_shared(a); } - -class BaseRef : public Base { - public: - BaseRef() : m_ptr(nullptr) {} - BaseRef(const BaseRef &other); - virtual std::shared_ptr copy() const { return m_ptr; } - - BaseRef(BaseRef &&other) : Base(other) { - m_ptr = other.m_ptr; - other.m_ptr = nullptr; - } - - // right reference constructor - template ::type, BaseRef>::value, T>::type> - BaseRef(T &&t) { // NOLINT - m_ptr = MakeNode(t); - } - - ~BaseRef() override { m_ptr = nullptr; } - - MS_DECLARE_PARENT(BaseRef, Base) - - bool operator!=(const BaseRef &other) const { return !(operator==(other)); } - - virtual bool operator==(const BaseRef &other) const; - - // left reference - virtual BaseRef &operator=(const BaseRef &other); - // right reference - virtual BaseRef &operator=(BaseRef &&other); - - std::size_t hash() const override { - if (m_ptr == nullptr) { - MS_LOG(ERROR) << "Invalid m_ptr"; - return 0; - } - return m_ptr->hash(); - } - - std::string ToString() const override; - - bool is_null() const { return m_ptr == nullptr; } - - virtual uint32_t type() const; - - BasePtr m_ptr; // point to real data -}; -using BaseRefPtr = std::shared_ptr; - -struct BaseRefHash { - std::size_t operator()(const BaseRef &c) const { return c.hash(); } -}; - -struct BaseRefLess { - bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); } -}; - -namespace utils { -// judge isa relation -// examples: isa(handle), isa(handle) -template ::value && !is_base_ref::value, int>::type = 0> -bool isa(const BaseRef &handle) { - if (!handle.m_ptr) { - return false; - } - return handle.m_ptr->isa(); -} - -// noderef isa ptr isa(x) or isa() -template ::value, typename T::element_type>::type, - typename std::enable_if::value || is_base_ref::value, int>::type = 0> -bool isa(const BaseRef &handle) { - if (handle.m_ptr == nullptr) { - return typeid(handle.m_ptr) == typeid(T); - } - - if (handle.m_ptr->isa()) { - return true; - } - - // constptr isa can be true - return std::dynamic_pointer_cast(handle.m_ptr) != nullptr; -} - -// isa(handle) -template ::type::element_type> -bool isa(const BaseRef &handle) { - if (handle.m_ptr == nullptr) { - return false; - } - return handle.m_ptr->isa(); -} - -// isa(handle), judge reference or ptr -template ::value, int>::type = 0> -bool isa(const BaseRef &handle) { - static const uint32_t tid = Base::GetTypeId(typeid(T).name()); - return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa()); -} - -// valueref -> C++ type -// cast(handle) -template ::value && !is_shared_ptr::value, int>::type = 0> -T cast(const BaseRef &handle) { - T ret = GetValue(std::static_pointer_cast(handle.m_ptr)); - return std::move(ret); -} - -// valueref -> valueref type -// cast(handle) -template ::value, int>::type = 0> -const T &cast(const BaseRef &handle) { - if (handle.m_ptr) { - return static_cast(*handle.m_ptr); - } - - return std::move(static_cast(handle)); -} - -// valueref -> nodeptr type -// cast(handle) -template ::value, typename T::element_type>::type, - typename std::enable_if::value && std::is_base_of::value, - int>::type = 0> -T cast(const BaseRef &handle) { - if (!handle.m_ptr) { - MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null"; - } - - auto m = handle.m_ptr->cast(); - if (nullptr != m) { - return m; - } - return std::static_pointer_cast(handle.m_ptr); -} -} // namespace utils - -class VectorRef : public BaseRef { - public: - using value_type = BaseRef; - - VectorRef() {} - explicit VectorRef(const std::vector &elements) : elements_(elements) {} - VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} - - // left reference - virtual VectorRef &operator=(const VectorRef &other); - - ~VectorRef() override = default; - - std::shared_ptr copy() const override { return std::make_shared(elements_); } - - bool empty() const { return (elements_.size() == 0); } - - std::size_t size() const { return elements_.size(); } - MS_DECLARE_PARENT(VectorRef, BaseRef) - - const BaseRef &operator[](const std::size_t &dim) const { - if (dim >= size()) { - MS_LOG(EXCEPTION) << "Out of the size of the tuple."; - } - return elements_[dim]; - } - - BaseRef &operator[](const std::size_t &dim) { - if (dim >= size()) { - MS_LOG(EXCEPTION) << "Out of the size of the tuple."; - } - return elements_[dim]; - } - - uint32_t type() const override { return tid(); } - std::string ToString() const override; - std::vector &elements() { return elements_; } - void clear() { elements_.clear(); } - - bool operator==(const BaseRef &other) const override; - bool operator==(const VectorRef &other) const; - - void push_back(const BaseRef &value) { elements_.push_back(value); } - void push_back(BaseRef &&value) { elements_.push_back(value); } - - void emplace_back(const BaseRef &value) { elements_.emplace_back(value); } - void emplace_back(BaseRef &&value) { elements_.emplace_back(value); } - - template - void insert(const iterator pos, const InputIt first, const InputIt last) { - (void)elements_.insert(pos, first, last); - } - - template - void insert(const const_iterator cpos, const InputIt first, const InputIt last) { - auto pos = ConstIteratorCast(&elements_, cpos); - (void)elements_.insert(pos, first, last); - } - - const_iterator begin() const { return elements_.begin(); } - const_iterator end() const { return elements_.end(); } - - const_reverse_iterator rbegin() const { return elements_.rbegin(); } - const_reverse_iterator rend() const { return elements_.rend(); } - - iterator erase(const const_iterator cpos) { - auto pos = ConstIteratorCast(&elements_, cpos); - return elements_.erase(pos); - } - - iterator erase(const const_iterator cfirst, const const_iterator clast) { - auto first = ConstIteratorCast(&elements_, cfirst); - auto last = ConstIteratorCast(&elements_, clast); - return elements_.erase(first, last); - } - - std::size_t hash() const override { - std::stringstream buffer; - buffer << ToString(); - return std::hash()(buffer.str()); - } - - std::vector elements_; -}; - -using VectorRefPtr = std::shared_ptr; - -using set_iterator = std::set::iterator; -using const_set_iterator = std::set::const_iterator; - -struct VectorRefHash { - std::size_t operator()(const VectorRef &c) const { return c.hash(); } -}; - -class SetRef : public BaseRef { - public: - SetRef() {} - explicit SetRef(const std::set &elements) : elements_(elements) {} - SetRef(const std::initializer_list elements) : elements_(elements.begin(), elements.end()) {} - SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {} - - // left reference - virtual SetRef &operator=(const SetRef &other); - - bool operator==(const BaseRef &other) const override; - bool operator==(const SetRef &other) const; - - ~SetRef() override = default; - - std::shared_ptr copy() const override { return std::make_shared(elements_); } - - bool empty() const { return (elements_.size() == 0); } - - std::size_t size() const { return elements_.size(); } - MS_DECLARE_PARENT(SetRef, BaseRef) - - uint32_t type() const override { return tid(); } - std::string ToString() const override; - std::set &elements() { return elements_; } - void clear() { elements_.clear(); } - - void insert(const BaseRef &elem) { (void)elements_.insert(elem); } - - const_set_iterator begin() const { return elements_.begin(); } - const_set_iterator end() const { return elements_.end(); } - - template - void insert(const InputIt first, const InputIt last) { - (void)elements_.insert(first, last); - } - - std::size_t count(const BaseRef &elem) const { return elements_.count(elem); } - const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); } - - std::set elements_; -}; - -using SetRefPtr = std::shared_ptr; - -class RunFunctionRef : public BaseRef { - public: - RunFunctionRef() {} - explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {} - - ~RunFunctionRef() override = default; - MS_DECLARE_PARENT(RunFunctionRef, BaseRef) - - uint32_t type() const override { return tid(); } - std::string ToString() const override { return std::string("RunFunctionRef"); } - bool operator==(const BaseRef &other) const override; - bool operator==(const RunFunctionRef &other) const; - - RunFuncPtr func_; -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_NODE_REF_H_ diff --git a/mindspore/ccsrc/utils/base_ref_extends.h b/mindspore/ccsrc/utils/base_ref_extends.h index 5aa603bfe9..18d51c2611 100644 --- a/mindspore/ccsrc/utils/base_ref_extends.h +++ b/mindspore/ccsrc/utils/base_ref_extends.h @@ -20,7 +20,7 @@ #include #include "utils/base_ref_py.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" namespace mindspore { class PyObjectRef : public BaseRef { diff --git a/mindspore/ccsrc/utils/base_ref_utils.cc b/mindspore/ccsrc/utils/base_ref_utils.cc deleted file mode 100644 index 87089c6266..0000000000 --- a/mindspore/ccsrc/utils/base_ref_utils.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "utils/base_ref_utils.h" -#include "include/ms_tensor.h" -#include "ir/tensor.h" - -namespace mindspore { -void IterateFindTensor(std::vector> *msTensors, const VectorRef &ref_list) { - for (size_t i = 0; i < ref_list.size(); ++i) { - if (utils::isa(ref_list[i])) { - auto tensor_ptr = utils::cast>(ref_list[i]); - MS_EXCEPTION_IF_NULL(tensor_ptr); - auto tensor = new inference::Tensor(tensor_ptr); - msTensors->emplace_back(std::shared_ptr(tensor)); - } else if (utils::isa(ref_list[i])) { - auto ref_iter = utils::cast(ref_list[i]); - IterateFindTensor(msTensors, ref_iter); - } else { - MS_LOG(EXCEPTION) << "The output is not a tensor"; - } - } -} - -std::vector> TransformVectorRefToMultiTensor(const VectorRef &base_ref) { - std::vector> msTensors; - if (utils::isa(base_ref)) { - auto ref_list = utils::cast(base_ref); - IterateFindTensor(&msTensors, ref_list); - } else if (utils::isa(base_ref)) { - auto tensor_ptr = utils::cast>(base_ref); - MS_EXCEPTION_IF_NULL(tensor_ptr); - auto tensor = new inference::Tensor(tensor_ptr); - msTensors.emplace_back(std::shared_ptr(tensor)); - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } - return msTensors; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref_utils.h b/mindspore/ccsrc/utils/base_ref_utils.h deleted file mode 100644 index 2503eab738..0000000000 --- a/mindspore/ccsrc/utils/base_ref_utils.h +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "utils/base_ref.h" -#include "include/ms_tensor.h" - -#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H -#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H -namespace mindspore { -std::vector> TransformVectorRefToMultiTensor(const VectorRef &base_ref); -} // namespace mindspore -#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 6001b295ad..a73df1e257 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -52,7 +52,7 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, if (param_node->name() == param_name) { TensorPtr tensor; if (param_node->has_default()) { - tensor = std::dynamic_pointer_cast(param_node->default_param()->value()); + tensor = std::dynamic_pointer_cast(param_node->default_param()); } if (tensor == nullptr) { shape->push_back(ONE_SHAPE); diff --git a/mindspore/ccsrc/utils/context/context_extends.cc b/mindspore/ccsrc/utils/context/context_extends.cc new file mode 100644 index 0000000000..2588cfddd7 --- /dev/null +++ b/mindspore/ccsrc/utils/context/context_extends.cc @@ -0,0 +1,372 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/context/context_extends.h" +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace context { +#ifdef ENABLE_GE +using mindspore::transform::DfGraphManager; +#endif + +#ifndef NO_DLIB +// Open tdt dataset +bool OpenTsd(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + + if (ms_context_ptr->is_pynative_ge_init()) { + return true; + } + + if (ms_context_ptr->tsd_ref()) { + MS_LOG(DEBUG) << "TDT Dataset client is already opened."; + ms_context_ptr->set_tsd_ref("++"); + return true; + } + + auto role = common::GetEnv("MS_ROLE"); + if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) { + return true; + } + + unsigned int device_id; + unsigned int rank_size = 1; + + device_id = ms_context_ptr->device_id(); + + auto rank_size_env = common::GetEnv("RANK_SIZE"); + if (rank_size_env.empty()) { + MS_LOG(INFO) << "Should config rank size."; + rank_size = 1; + } else { + int rank_env = std::stoi(rank_size_env); + if (rank_env <= 0) { + MS_LOG(EXCEPTION) << "Error rank size " << rank_env << "."; + } + rank_size = IntToUint(rank_env); + } + + MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; + TDT_StatusT status = TsdOpen(device_id, rank_size); + if (status != TDT_OK) { + MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; + return false; + } + ms_context_ptr->set_tsd_ref("++"); +#ifdef ENABLE_TDTQUE + int32_t initStatus = tdt::TdtHostInit(device_id); + if (initStatus != TDT_OK_CODE) { + MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; + return false; + } + ms_context_ptr->tdt_print_ = std::thread(TensorPrint()); +#endif + MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << ms_context_ptr->tsd_ref() << "."; + return true; +} + +bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + if (ms_context_ptr->tsd_ref() == 0) { + return true; + } + ms_context_ptr->set_tsd_ref("--"); + if (force || ms_context_ptr->tsd_ref() == 0) { + ms_context_ptr->set_tsd_ref(" "); +#ifdef ENABLE_TDTQUE + int32_t stopStatus = tdt::TdtHostStop(KNpuLog); + if (stopStatus != TDT_OK_CODE) { + MS_LOG(EXCEPTION) << "Stop tsd failed, status = " << stopStatus << "."; + return false; + } + py::gil_scoped_release gil_release; + int32_t destroyStatus = tdt::TdtHostDestroy(); + if (destroyStatus != TDT_OK_CODE) { + MS_LOG(EXCEPTION) << "Destroy tsd failed, status = " << destroyStatus << "."; + return false; + } + try { + if (ms_context_ptr->tdt_print_.joinable()) { + MS_LOG(INFO) << "join tdt host receive process"; + ms_context_ptr->tdt_print_.join(); + } + } catch (const std::exception &e) { + MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); + } +#endif + auto device_id = ms_context_ptr->device_id(); + TDT_StatusT status = TsdClose(device_id); + if (status != TDT_OK) { + MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << "."; + return false; + } + ms_context_ptr->set_pynative_ge_init(false); + MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << "."; + } else { + MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->tsd_ref() + << "."; + } + + return true; +} +#else +bool OpenTsd(const std::shared_ptr &ms_context_ptr) { return true; } +bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool) { return true; } +#endif + +void SetDisableReuseMemoryFlag(std::map *ge_options) { + auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY"); + if (!env_disable_reuse_memory.empty()) { + (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory; + } else { + (*ge_options)["ge.exec.disableReuseMemory"] = "0"; + MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0"; + } +} + +void GetGeOptions(const std::shared_ptr &ms_context_ptr, std::map *ge_options) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } +#ifdef ENABLE_GE + (*ge_options)["device_id"] = "0"; + (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->enable_dump()); + (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->save_dump_path(); + (*ge_options)["ge.exec.dumpMode"] = "output"; + MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->enable_dump()) + << " and save dump path is " << ms_context_ptr->save_dump_path() << "."; + (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->enable_profiling()); + if (ms_context_ptr->enable_profiling()) { + (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->profiling_options(); + } + + (*ge_options)["rank_table_file"] = ""; + auto env_ddk_version = common::GetEnv("DDK_VERSION"); + if (!env_ddk_version.empty()) { + (*ge_options)["ge.DDK_version"] = env_ddk_version; + } else { + (*ge_options)["ge.DDK_version"] = "1.60.T17.B830"; + } + (*ge_options)["graphType"] = "1"; + + if (ms_context_ptr->graph_memory_max_size() != "0") { + (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->graph_memory_max_size(); + } + + if (ms_context_ptr->variable_memory_max_size() != "0") { + (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->variable_memory_max_size(); + } + +#if ENABLE_TRAIN == 1 + (*ge_options)["ge.graphRunMode"] = "1"; +#endif + SetDisableReuseMemoryFlag(ge_options); + SetHcclOptions(ms_context_ptr, ge_options); + + auto env_job_id = common::GetEnv("JOB_ID"); + if (!env_job_id.empty()) { + (*ge_options)["ge.exec.jobId"] = env_job_id; + } else { + (*ge_options)["ge.exec.jobId"] = "0"; + MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0"; + } + + auto env_fe_flag = common::GetEnv("FE_FLAG"); + if (!env_fe_flag.empty()) { + (*ge_options)["ge.feFlag"] = env_fe_flag; + MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; + } + + auto env_aicpu_flag = common::GetEnv("AICPU_FLAG"); + if (!env_aicpu_flag.empty()) { + (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag; + MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; + } + + auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH"); + if (!proto_lib_path.empty()) { + char real_path[PATH_MAX] = {0}; + if (realpath(proto_lib_path.c_str(), real_path)) { + proto_lib_path = real_path; + (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path; + } + } else { + MS_LOG(WARNING) << "Set proto lib path failed!"; + } + + // Enable auto mixed precision according to the context options + if (ms_context_ptr->auto_mixed_precision_flag()) { + (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; + } else { + (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; + } + // Disable the global variable acc, only enable it whlie adding training graph in pipeline + (*ge_options)["ge.exec.variable_acc"] = "0"; +#endif +} + +void SetHcclOptions(const std::shared_ptr &ms_context_ptr, std::map *ge_options) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); + auto env_rank_id = common::GetEnv("RANK_ID"); + auto env_device_id = std::to_string(ms_context_ptr->device_id()); + if (!(env_table_file.empty() || env_rank_id.empty())) { + MS_LOG(INFO) << "Initialize Ge for distribute parameter"; + MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; + auto env_hccl_flag = common::GetEnv("HCCL_FLAG"); + if (!env_hccl_flag.empty()) { + (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag; + } + (*ge_options)["ge.exec.isUseHcom"] = "1"; + (*ge_options)["ge.exec.deviceId"] = env_device_id; + (*ge_options)["ge.exec.rankId"] = env_rank_id; + (*ge_options)["ge.exec.podName"] = env_rank_id; + (*ge_options)["ge.exec.rankTableFile"] = env_table_file; + (*ge_options)["ge.graphRunMode"] = "1"; + } else { + // device id is still needed for non-distribute case + (*ge_options)["ge.exec.deviceId"] = env_device_id; + MS_LOG(INFO) << "No hccl mode. " + "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV."; + } + + auto env_deploy_mode = common::GetEnv("DEPLOY_MODE"); + if (!env_deploy_mode.empty()) { + (*ge_options)["ge.exec.deployMode"] = env_deploy_mode; + } else { + (*ge_options)["ge.exec.deployMode"] = "0"; + MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0"; + } +} + +bool InitGe(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } +#ifdef ENABLE_GE + if (ms_context_ptr->is_pynative_ge_init()) { + return true; + } + + if (ms_context_ptr->ge_ref()) { + ms_context_ptr->set_ge_ref("++"); + return true; + } + + std::map ge_options; + GetGeOptions(ms_context_ptr, &ge_options); + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) { + MS_LOG(EXCEPTION) << "Initialize GE failed!"; + } + } + ms_context_ptr->set_ge_ref("++"); + MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->ge_ref() << "."; +#endif + return true; +} + +bool PynativeInitGe(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + if (ms_context_ptr->is_pynative_ge_init() || ms_context_ptr->ge_ref() || ms_context_ptr->tsd_ref()) { + return true; + } + (void)OpenTsd(ms_context_ptr); + (void)InitGe(ms_context_ptr); + ms_context_ptr->set_pynative_ge_init(true); + return true; +} + +bool FinalizeGe(const std::shared_ptr &ms_context_ptr, bool force) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } +#ifdef ENABLE_GE + if (ms_context_ptr->ge_ref() == 0) { + return true; + } + ms_context_ptr->set_ge_ref("--"); + if (force || ms_context_ptr->ge_ref() == 0) { + ms_context_ptr->set_ge_ref(" "); + try { + DfGraphManager::GetInstance().DeleteGraphRunner(); + DfGraphManager::GetInstance().DeleteGeSession(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what(); + } catch (...) { + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName; + } + if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { + MS_LOG(WARNING) << "Finalize GE failed!"; + } + ms_context_ptr->set_pynative_ge_init(true); + } else { + MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->ge_ref() << "."; + } +#endif + return true; +} + +bool IsTsdOpened(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + return ms_context_ptr->tsd_ref(); +} + +bool IsGeInited(const std::shared_ptr &ms_context_ptr) { + if (ms_context_ptr == nullptr) { + MS_LOG(EXCEPTION) << "nullptr"; + } + return ms_context_ptr->ge_ref(); +} + +// Register for device type. +struct DeviceTypeSetRegister { + DeviceTypeSetRegister() { + MsContext::device_type_seter([](std::shared_ptr &device_type_seter) { +#ifdef ENABLE_GE + device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice)); +#elif defined(ENABLE_D) + device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice)); +#elif defined(ENABLE_GPU) + device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice)); +#else + device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice)); +#endif + }); + } + ~DeviceTypeSetRegister() = default; +} device_type_set_regsiter; +} // namespace context +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/context/context_extends.h b/mindspore/ccsrc/utils/context/context_extends.h new file mode 100644 index 0000000000..f425042bec --- /dev/null +++ b/mindspore/ccsrc/utils/context/context_extends.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H +#define MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H + +#include +#include +#include +#include "utils/ms_context.h" +#include "utils/tensorprint_utils.h" +#include "utils/convert_utils.h" + +#ifndef NO_DLIB +#include "tdt/tsd_client.h" +#include "tdt/tdt_host_interface.h" +#include "tdt/data_common.h" +#endif +#ifdef ENABLE_GE +#include "transform/graph_ir/df_graph_manager.h" +#endif + +namespace mindspore { +namespace context { +bool OpenTsd(const std::shared_ptr &inst_context); +bool CloseTsd(const std::shared_ptr &inst_context, bool force = false); +void SetHcclOptions(const std::shared_ptr &inst_context, std::map *ge_options); +void GetGeOptions(const std::shared_ptr &inst_context, std::map *ge_options); +void SetDisableReuseMemoryFlag(std::map *ge_options); +bool InitGe(const std::shared_ptr &inst_context); +bool FinalizeGe(const std::shared_ptr &inst_context, bool force = false); +bool PynativeInitGe(const std::shared_ptr &inst_context); +bool IsTsdOpened(const std::shared_ptr &inst_context); +bool IsGeInited(const std::shared_ptr &inst_context); +} // namespace context +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_UTILS_CONTEXT_CONTEXT_EXTENDS_H diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc deleted file mode 100644 index 79f5b147ac..0000000000 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ /dev/null @@ -1,460 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/context/ms_context.h" -#include -#include -#include -#include "./common.h" -#include "utils/convert_utils.h" -#include "utils/tensorprint_utils.h" -#ifndef NO_DLIB -#include "tdt/tsd_client.h" -#include "tdt/tdt_host_interface.h" -#include "tdt/data_common.h" -#endif -#ifdef ENABLE_GE -#include "transform/graph_ir/df_graph_manager.h" -#endif -#include "ir/tensor.h" -#include "common/utils.h" - -namespace mindspore { -#ifdef ENABLE_GE -using mindspore::transform::DfGraphManager; -#endif - -std::atomic thread_1_must_end(false); - -std::shared_ptr MsContext::inst_context_ = nullptr; -std::map MsContext::policy_map_ = {{"ge", kMsBackendGePrior}, - {"vm", kMsBackendVmOnly}, - {"ms", kMsBackendMsPrior}, - {"ge_only", kMsBackendGeOnly}, - {"vm_prior", kMsBackendVmPrior}}; - -MsContext::MsContext(const std::string &policy, const std::string &target) { - save_graphs_flag_ = false; - save_graphs_path_ = "."; - save_ms_model_flag_ = false; - save_ms_model_path_ = "./model.ms"; - enable_dump_ = false; - save_dump_path_ = "."; - tsd_ref_ = 0; - ge_ref_ = 0; - is_multi_graph_sink_ = false; - is_pynative_ge_init_ = false; - enable_reduce_precision_ = true; - auto env_device = common::GetEnv("DEVICE_ID"); - if (!env_device.empty()) { - device_id_ = UlongToUint(std::stoul(env_device.c_str())); - } else { - device_id_ = 0; - } - backend_policy_ = policy_map_[policy]; - device_target_ = target; - execution_mode_ = kPynativeMode; - enable_task_sink_ = true; - ir_fusion_flag_ = true; - enable_hccl_ = false; -#ifdef ENABLE_DEBUGGER - enable_mem_reuse_ = false; -#else - enable_mem_reuse_ = true; -#endif - enable_gpu_summary_ = true; - precompile_only_ = false; - auto_mixed_precision_flag_ = false; - enable_pynative_infer_ = false; - enable_pynative_hook_ = false; - enable_dynamic_mem_pool_ = true; - graph_memory_max_size_ = "0"; - variable_memory_max_size_ = "0"; - enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice; - profiling_mode_ = false; - profiling_options_ = "training_trace"; - check_bprop_flag_ = false; - max_device_memory_ = kDefaultMaxDeviceMemory; - print_file_path_ = ""; - enable_graph_kernel_ = false; - enable_sparse_ = false; -} - -std::shared_ptr MsContext::GetInstance() { - if (inst_context_ == nullptr) { - MS_LOG(DEBUG) << "Create new mindspore context"; -#ifdef ENABLE_GE - inst_context_.reset(new (std::nothrow) MsContext("ge", kAscendDevice)); -#elif defined(ENABLE_D) - inst_context_.reset(new (std::nothrow) MsContext("ms", kAscendDevice)); -#elif defined(ENABLE_GPU) - inst_context_.reset(new (std::nothrow) MsContext("ms", kGPUDevice)); -#else - inst_context_.reset(new (std::nothrow) MsContext("vm", kCPUDevice)); -#endif - } - return inst_context_; -} - -bool MsContext::set_backend_policy(const std::string &policy) { - if (policy_map_.find(policy) == policy_map_.end()) { - MS_LOG(ERROR) << "invalid backend policy name: " << policy; - return false; - } - backend_policy_ = policy_map_[policy]; - MS_LOG(INFO) << "ms set context backend policy:" << policy; - return true; -} - -std::string MsContext::backend_policy() const { - auto res = std::find_if( - policy_map_.begin(), policy_map_.end(), - [&, this](const std::pair &item) { return item.second == backend_policy_; }); - if (res != policy_map_.end()) { - return res->first; - } - return "unknown"; -} - -void MsContext::set_execution_mode(int execution_mode) { - if (execution_mode != kGraphMode && execution_mode != kPynativeMode) { - MS_LOG(EXCEPTION) << "The execution mode is invalid!"; - } - execution_mode_ = execution_mode; -} - -bool MsContext::set_device_target(const std::string &target) { - if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(ERROR) << "invalid device target name: " << target; - return false; - } - if (target == kDavinciDevice) { - device_target_ = kAscendDevice; - } else { - device_target_ = target; - } - MS_LOG(INFO) << "ms set context device target:" << target; - return true; -} - -bool MsContext::set_device_id(uint32_t device_id) { - device_id_ = device_id; - MS_LOG(INFO) << "ms set context device id:" << device_id; - return true; -} - -#ifndef NO_DLIB -// Open tdt dataset -bool MsContext::OpenTsd() { - if (is_pynative_ge_init_) { - return true; - } - - if (tsd_ref_) { - MS_LOG(DEBUG) << "TDT Dataset client is already opened."; - tsd_ref_++; - return true; - } - - auto role = common::GetEnv("MS_ROLE"); - if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) { - return true; - } - - unsigned int device_id; - unsigned int rank_size = 1; - - device_id = device_id_; - - auto rank_size_env = common::GetEnv("RANK_SIZE"); - if (rank_size_env.empty()) { - MS_LOG(INFO) << "Should config rank size."; - rank_size = 1; - } else { - int rank_env = std::stoi(rank_size_env); - if (rank_env <= 0) { - MS_LOG(EXCEPTION) << "Error rank size " << rank_env << "."; - } - rank_size = IntToUint(rank_env); - } - - MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; - TDT_StatusT status = TsdOpen(device_id, rank_size); - if (status != TDT_OK) { - MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; - return false; - } - tsd_ref_++; -#ifdef ENABLE_TDTQUE - int32_t initStatus = tdt::TdtHostInit(device_id); - if (initStatus != TDT_OK_CODE) { - MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; - return false; - } - tdt_print_ = std::thread(TensorPrint()); -#endif - MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << tsd_ref_ << "."; - return true; -} - -bool MsContext::CloseTsd(bool force) { - if (tsd_ref_ == 0) { - return true; - } - tsd_ref_--; - if (force || tsd_ref_ == 0) { - tsd_ref_ = 0; -#ifdef ENABLE_TDTQUE - int32_t stopStatus = tdt::TdtHostStop(KNpuLog); - if (stopStatus != TDT_OK_CODE) { - MS_LOG(EXCEPTION) << "Stop tsd failed, status = " << stopStatus << "."; - return false; - } - py::gil_scoped_release gil_release; - int32_t destroyStatus = tdt::TdtHostDestroy(); - if (destroyStatus != TDT_OK_CODE) { - MS_LOG(EXCEPTION) << "Destroy tsd failed, status = " << destroyStatus << "."; - return false; - } - try { - if (tdt_print_.joinable()) { - MS_LOG(INFO) << "join tdt host receive process"; - tdt_print_.join(); - } - } catch (const std::exception &e) { - MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); - } -#endif - unsigned int device_id; - device_id = device_id_; - TDT_StatusT status = TsdClose(device_id); - if (status != TDT_OK) { - MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << "."; - return false; - } - is_pynative_ge_init_ = false; - MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << "."; - } else { - MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << tsd_ref_ << "."; - } - - return true; -} -#else -bool MsContext::OpenTsd() { return true; } - -bool MsContext::CloseTsd(bool) { return true; } -#endif - -void MsContext::SetHcclOptions(std::map *ge_options) const { - auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); - auto env_rank_id = common::GetEnv("RANK_ID"); - auto env_device_id = std::to_string(device_id_); - if (!(env_table_file.empty() || env_rank_id.empty())) { - MS_LOG(INFO) << "Initialize Ge for distribute parameter"; - MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; - auto env_hccl_flag = common::GetEnv("HCCL_FLAG"); - if (!env_hccl_flag.empty()) { - (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag; - } - (*ge_options)["ge.exec.isUseHcom"] = "1"; - (*ge_options)["ge.exec.deviceId"] = env_device_id; - (*ge_options)["ge.exec.rankId"] = env_rank_id; - (*ge_options)["ge.exec.podName"] = env_rank_id; - (*ge_options)["ge.exec.rankTableFile"] = env_table_file; - (*ge_options)["ge.graphRunMode"] = "1"; - } else { - // device id is still needed for non-distribute case - (*ge_options)["ge.exec.deviceId"] = env_device_id; - MS_LOG(INFO) << "No hccl mode. " - "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV."; - } - - auto env_deploy_mode = common::GetEnv("DEPLOY_MODE"); - if (!env_deploy_mode.empty()) { - (*ge_options)["ge.exec.deployMode"] = env_deploy_mode; - } else { - (*ge_options)["ge.exec.deployMode"] = "0"; - MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0"; - } -} - -void MsContext::GetGeOptions(std::map *ge_options) const { -#ifdef ENABLE_GE - (*ge_options)["device_id"] = "0"; - (*ge_options)["ge.exec.enableDump"] = std::to_string(enable_dump_); - (*ge_options)["ge.exec.dumpPath"] = save_dump_path_; - (*ge_options)["ge.exec.dumpMode"] = "output"; - MS_LOG(INFO) << "The enable dump state is " << std::to_string(enable_dump_) << " and save dump path is " - << save_dump_path_ << "."; - (*ge_options)["ge.exec.profilingMode"] = std::to_string(profiling_mode_); - if (profiling_mode_) { - (*ge_options)["ge.exec.profilingOptions"] = profiling_options_; - } - - (*ge_options)["rank_table_file"] = ""; - auto env_ddk_version = common::GetEnv("DDK_VERSION"); - if (!env_ddk_version.empty()) { - (*ge_options)["ge.DDK_version"] = env_ddk_version; - } else { - (*ge_options)["ge.DDK_version"] = "1.60.T17.B830"; - } - (*ge_options)["graphType"] = "1"; - - if (graph_memory_max_size_ != "0") { - (*ge_options)["ge.graphMemoryMaxSize"] = graph_memory_max_size_; - } - - if (variable_memory_max_size_ != "0") { - (*ge_options)["ge.variableMemoryMaxSize"] = variable_memory_max_size_; - } - -#if ENABLE_TRAIN == 1 - (*ge_options)["ge.graphRunMode"] = "1"; -#endif - SetDisableReuseMemoryFlag(ge_options); - SetHcclOptions(ge_options); - - auto env_job_id = common::GetEnv("JOB_ID"); - if (!env_job_id.empty()) { - (*ge_options)["ge.exec.jobId"] = env_job_id; - } else { - (*ge_options)["ge.exec.jobId"] = "0"; - MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0"; - } - - auto env_fe_flag = common::GetEnv("FE_FLAG"); - if (!env_fe_flag.empty()) { - (*ge_options)["ge.feFlag"] = env_fe_flag; - MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; - } - - auto env_aicpu_flag = common::GetEnv("AICPU_FLAG"); - if (!env_aicpu_flag.empty()) { - (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag; - MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; - } - - auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH"); - if (!proto_lib_path.empty()) { - char real_path[PATH_MAX] = {0}; - if (realpath(proto_lib_path.c_str(), real_path)) { - proto_lib_path = real_path; - (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path; - } - } else { - MS_LOG(WARNING) << "Set proto lib path failed!"; - } - - // Enable auto mixed precision according to the context options - if (auto_mixed_precision_flag_) { - (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; - } else { - (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; - } - // Disable the global variable acc, only enable it whlie adding training graph in pipeline - (*ge_options)["ge.exec.variable_acc"] = "0"; -#endif -} - -void MsContext::SetDisableReuseMemoryFlag(std::map *ge_options) const { - auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY"); - if (!env_disable_reuse_memory.empty()) { - (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory; - } else { - (*ge_options)["ge.exec.disableReuseMemory"] = "0"; - MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0"; - } -} - -bool MsContext::InitGe() { -#ifdef ENABLE_GE - if (is_pynative_ge_init_) { - return true; - } - - if (ge_ref_) { - ge_ref_++; - return true; - } - - std::map ge_options; - GetGeOptions(&ge_options); - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) { - MS_LOG(EXCEPTION) << "Initialize GE failed!"; - } - } - ge_ref_++; - MS_LOG(INFO) << "Init ge successful, ge reference = " << ge_ref_ << "."; -#endif - return true; -} - -bool MsContext::FinalizeGe(bool force) { -#ifdef ENABLE_GE - if (ge_ref_ == 0) { - return true; - } - ge_ref_--; - if (force || ge_ref_ == 0) { - ge_ref_ = 0; - try { - DfGraphManager::GetInstance().DeleteGraphRunner(); - DfGraphManager::GetInstance().DeleteGeSession(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what(); - } catch (...) { - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName; - } - if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { - MS_LOG(WARNING) << "Finalize GE failed!"; - } - is_pynative_ge_init_ = false; - } else { - MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ge_ref_ << "."; - } -#endif - return true; -} - -bool MsContext::PynativeInitGe() { - if (is_pynative_ge_init_ || ge_ref_ || tsd_ref_) { - return true; - } - (void)OpenTsd(); - (void)InitGe(); - is_pynative_ge_init_ = true; - return true; -} - -bool MsContext::IsTsdOpened() { - if (tsd_ref_ > 0) { - return true; - } - return false; -} - -bool MsContext::IsGeInited() { - if (ge_ref_ > 0) { - return true; - } - return false; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h deleted file mode 100644 index 19205cccb8..0000000000 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ /dev/null @@ -1,215 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_CONTEXT_MS_CONTEXT_H_ -#define MINDSPORE_CCSRC_UTILS_CONTEXT_MS_CONTEXT_H_ -#include -#include -#include -#include -#include -#include -#include -#include "utils/log_adapter.h" - -namespace mindspore { - -enum MsBackendPolicy { - kMsBackendGeOnly = 0, - kMsBackendVmOnly = 1, - kMsBackendGePrior = 2, - kMsBackendVmPrior = 3, - kMsBackendMsPrior = 4, - kMsBackendUnknown = 5, -}; - -const int kGraphMode = 0; -const int kPynativeMode = 1; -const char kCPUDevice[] = "CPU"; -const char kGPUDevice[] = "GPU"; -const char kAscendDevice[] = "Ascend"; -const char kDavinciInferenceDevice[] = "AscendInference"; -const char kDavinciDevice[] = "Davinci"; -const char KNpuLog[] = "_npu_log"; -const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; -// The default max available device memory is 1024GB. -const float kDefaultMaxDeviceMemory = 1024; - -class MsContext { - public: - ~MsContext() = default; - MsContext(const MsContext &) = delete; - MsContext &operator=(const MsContext &) = delete; - - static std::shared_ptr GetInstance(); - - std::string backend_policy() const; - bool set_backend_policy(const std::string &policy); - - int execution_mode() const { return execution_mode_; } - void set_execution_mode(int execution_mode); - - bool enable_pynative_infer() const { return enable_pynative_infer_; } - void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } - - bool enable_pynative_hook() const { return enable_pynative_hook_; } - void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; } - - bool enable_task_sink() const { return enable_task_sink_; } - - void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } - bool precompile_only() const { return precompile_only_; } - - std::string device_target() const { return device_target_; } - bool set_device_target(const std::string &target); - - uint32_t device_id() const { return device_id_; } - bool set_device_id(uint32_t device_id); - - bool save_graphs_flag() const { return save_graphs_flag_; } - void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } - - std::string save_graphs_path() const { return save_graphs_path_; } - void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } - - bool OpenTsd(); - bool CloseTsd(bool force = false); - bool IsTsdOpened(); - bool InitGe(); - bool FinalizeGe(bool force = false); - bool IsGeInited(); - void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; } - bool enable_hccl() const { return enable_hccl_; } - bool PynativeInitGe(); - - bool ir_fusion_flag() const { return ir_fusion_flag_; } - - bool loop_sink_flag() const { return enable_loop_sink_; } - void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; } - void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } - bool enable_mem_reuse() const { return enable_mem_reuse_; } - - bool save_ms_model_flag() const { return save_ms_model_flag_; } - void set_save_ms_model_flag(bool save_ms_model_flag) { save_ms_model_flag_ = save_ms_model_flag; } - - std::string save_ms_model_path() const { return save_ms_model_path_; } - void set_save_ms_model_path(const std::string &save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } - - void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } - bool enable_gpu_summary() const { return enable_gpu_summary_; } - - void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) { - auto_mixed_precision_flag_ = auto_mixed_precision_flag; - } - bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; } - - void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; } - bool enable_reduce_precision() const { return enable_reduce_precision_; } - - void set_enable_dump(bool flag) { enable_dump_ = flag; } - bool enable_dump() const { return enable_dump_; } - - void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } - std::string save_dump_path() const { return save_dump_path_; } - - bool IsTsdOpened() const { return tsd_ref_ > 0; } - - bool is_multi_graph_sink() const { return is_multi_graph_sink_; } - void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } - - void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } - bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } - - void set_graph_memory_max_size(const std::string &graph_memory_max_size) { - graph_memory_max_size_ = graph_memory_max_size; - } - - void set_variable_memory_max_size(const std::string &variable_memory_max_size) { - variable_memory_max_size_ = variable_memory_max_size; - } - - const std::string &variable_memory_max_size() const { return variable_memory_max_size_; } - - const std::string &graph_memory_max_size() const { return graph_memory_max_size_; } - - void set_enable_profiling(bool flag) { profiling_mode_ = flag; } - bool enable_profiling() const { return profiling_mode_; } - - void set_profiling_options(const std::string &options) { profiling_options_ = options; } - std::string profiling_options() const { return profiling_options_; } - bool check_bprop_flag() const { return check_bprop_flag_; } - void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } - void set_print_file_path(const std::string &file) { print_file_path_ = file; } - const std::string &print_file_path() const { return print_file_path_; } - - float max_device_memory() const { return max_device_memory_; } - void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } - - void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } - bool enable_graph_kernel() const { return enable_graph_kernel_; } - - bool enable_sparse() const { return enable_sparse_; } - void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } - - private: - MsContext(const std::string &backend_policy, const std::string &target); - void GetGeOptions(std::map *ge_options) const; - void SetDisableReuseMemoryFlag(std::map *ge_options) const; - void SetHcclOptions(std::map *ge_options) const; - - static std::shared_ptr inst_context_; - static std::map policy_map_; - MsBackendPolicy backend_policy_; - std::string device_target_; - uint32_t device_id_; - int execution_mode_; - bool enable_pynative_infer_; - bool enable_pynative_hook_; - bool save_graphs_flag_; - std::string save_graphs_path_; - uint32_t tsd_ref_; - uint32_t ge_ref_; - bool enable_task_sink_; - bool enable_hccl_; - bool precompile_only_; - bool ir_fusion_flag_; - bool auto_mixed_precision_flag_; - bool enable_reduce_precision_; - bool enable_loop_sink_; - bool enable_mem_reuse_; - std::string save_ms_model_path_; - bool save_ms_model_flag_; - bool enable_gpu_summary_; - bool enable_dump_; - std::string save_dump_path_; - bool is_multi_graph_sink_; - bool is_pynative_ge_init_; - bool enable_dynamic_mem_pool_; - std::string graph_memory_max_size_; - std::string variable_memory_max_size_; - std::thread tdt_print_; - bool profiling_mode_; - std::string profiling_options_; - bool check_bprop_flag_; - float max_device_memory_; - std::string print_file_path_; - bool enable_graph_kernel_; - bool enable_sparse_; -}; - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_CONTEXT_MS_CONTEXT_H_ diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index b1847d1df5..70590da753 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -32,6 +32,7 @@ #include "ir/tensor.h" #include "ir/param_value.h" #include "utils/base_ref_extends.h" +#include "utils/ms_context.h" namespace mindspore { py::object BuiltinsToPyData(const Any &value); @@ -48,6 +49,10 @@ py::object ValuePtrToPyData(const ValuePtr &value) { MS_LOG(DEBUG) << "int"; py::int_ v = value->cast()->value(); ret = v; + } else if (value->isa()) { + MS_LOG(DEBUG) << "int64"; + py::int_ v = value->cast()->value(); + ret = v; } else if (value->isa()) { MS_LOG(DEBUG) << "uint64"; py::int_ v = value->cast()->value(); @@ -366,9 +371,9 @@ py::object VectorRefToPyData(const VectorRef &value_list) { return ret; } -AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) { - if ((py::isinstance(shape_obj) || py::isinstance(shape_obj)) && - py::hasattr(type_obj, PYTHON_DTYPE_FLAG)) { +AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, + const py::object &min_shape, const py::object &max_shape) { + if ((py::isinstance(shape_obj) || py::isinstance(shape_obj)) && py::isinstance(type_obj)) { auto ret_vec = shape_obj.cast>(); auto ret_dtype = type_obj.cast(); MS_EXCEPTION_IF_NULL(ret_dtype); @@ -378,12 +383,23 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py return abs_scalar; } AbstractBasePtr tensor = nullptr; + std::vector min_shape_vec; + std::vector max_shape_vec; + if (!min_shape.is_none()) { + min_shape_vec = min_shape.cast>(); + } + if (!max_shape.is_none()) { + max_shape_vec = max_shape.cast>(); + } + auto ret_shape = std::make_shared(ret_vec, min_shape_vec, max_shape_vec); if (ret_dtype->isa()) { auto tensor_type = type_obj.cast(); MS_EXCEPTION_IF_NULL(tensor_type); - tensor = std::make_shared(tensor_type->element(), ret_vec); + auto element = std::make_shared(kAnyValue, tensor_type->element()); + tensor = std::make_shared(element, ret_shape); } else { - tensor = std::make_shared(ret_dtype, ret_vec); + auto element = std::make_shared(kAnyValue, ret_dtype); + tensor = std::make_shared(element, ret_shape); } return tensor; } else if (py::isinstance(shape_obj) && py::isinstance(type_obj)) { @@ -401,6 +417,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py auto abstract_none = std::make_shared(); return abstract_none; } else { + // When sparse enabled, the undetermined might be raised and eliminated in opt passes + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (enable_sparse) { + return std::make_shared(); + } MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj); } } @@ -449,7 +472,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple if (!param->has_default()) { MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")"; } - auto tensor = param->default_param()->value(); + auto tensor = param->default_param(); *ret_val = py::cast(tensor); } return true; @@ -607,4 +630,25 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { MS_EXCEPTION_IF_NULL(tensor); return tensor; } + +void TensorValueToTensor(const ValuePtr &value, std::vector *tensors) { + MS_EXCEPTION_IF_NULL(value); + MS_EXCEPTION_IF_NULL(tensors); + if (value->isa()) { + auto value_tuple = value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + for (size_t i = 0; i < value_tuple->size(); ++i) { + ValuePtr element = value_tuple->value()[i]; + if (element->isa()) { + auto tensor = element->cast(); + MS_EXCEPTION_IF_NULL(tensor); + tensors->push_back(tensor); + } + } + } else if (value->isa()) { + tensor::TensorPtr tensor = value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + tensors->push_back(tensor); + } +} } // namespace mindspore diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h index d4ecbf4408..5597ae4d5e 100644 --- a/mindspore/ccsrc/utils/convert_utils.h +++ b/mindspore/ccsrc/utils/convert_utils.h @@ -21,13 +21,14 @@ #include #include #include +#include #include #include #include "pybind11/pybind11.h" #include "utils/convert_utils_base.h" #include "utils/any.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" #include "base/base.h" #include "ir/anf.h" @@ -46,7 +47,9 @@ bool BaseRefToInt(const ValuePtr &v, int *value); bool ValueToBool(const ValuePtr &in, bool *out); py::object ValuePtrToPyData(const ValuePtr &value); -AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj); +AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, + const py::object &min_shape = py::none(), + const py::object &max_shape = py::none()); bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, const std::shared_ptr &ret_val); @@ -69,6 +72,8 @@ using NodeMapEquiv = std::unordered_map; bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); + +void TensorValueToTensor(const ValuePtr &value, std::vector *tensors); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_ diff --git a/mindspore/ccsrc/utils/convert_utils_base.h b/mindspore/ccsrc/utils/convert_utils_base.h deleted file mode 100644 index b9a38f997f..0000000000 --- a/mindspore/ccsrc/utils/convert_utils_base.h +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_BASE_H_ -#define MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_BASE_H_ - -#include -#include - -#include "utils/log_adapter.h" - -namespace mindspore { -inline int SizeToInt(size_t u) { - if (u > static_cast((std::numeric_limits::max)())) { - MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int."; - } - return static_cast(u); -} - -inline uint32_t SizeToUint(size_t u) { - if (u > static_cast((std::numeric_limits::max)())) { - MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of uint32_t."; - } - return static_cast(u); -} - -inline int64_t SizeToLong(size_t u) { - if (u > static_cast((std::numeric_limits::max)())) { - MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int64_t."; - } - return static_cast(u); -} - -inline size_t IntToSize(int u) { - if (u < 0) { - MS_LOG(EXCEPTION) << "The int value(" << u << ") is less than 0."; - } - return static_cast(u); -} - -inline size_t LongToSize(int64_t u) { - if (u < 0) { - MS_LOG(EXCEPTION) << "The int64_t value(" << u << ") is less than 0."; - } - return static_cast(u); -} - -inline size_t FloatToSize(float u) { - if (u < 0) { - MS_LOG(EXCEPTION) << "The float value(" << u << ") is less than 0."; - } - - if (u > static_cast((std::numeric_limits::max)())) { - MS_LOG(EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of size_t."; - } - return static_cast(u); -} -inline float IntToFloat(int32_t v) { return static_cast(v); } - -inline uint32_t IntToUint(int32_t u) { - if (u < 0) { - MS_LOG(EXCEPTION) << "The int32_t value(" << u << ") is less than 0."; - } - return static_cast(u); -} - -inline int32_t UintToInt(uint32_t u) { - if (u > static_cast((std::numeric_limits::max)())) { - MS_LOG(EXCEPTION) << "The uint32_t value(" << u << ") exceeds the maximum value of int32_t."; - } - return static_cast(u); -} - -inline unsigned int UlongToUint(size_t u) { - if (u > static_cast((std::numeric_limits::max)())) { - MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of unsigned int."; - } - return static_cast(u); -} - -inline int IntMulWithOverflowCheck(int a, int b) { - int out = a * b; - if (a != 0) { - bool overflow = ((out / a) != b); - if (overflow) { - MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; - } - } - return out; -} - -inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) { - int64_t out = a * b; - if (a != 0) { - bool overflow = ((out / a) != b); - if (overflow) { - MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; - } - } - return out; -} - -inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { - size_t out = a * b; - if (a != 0) { - if ((out / a) != b) { - MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; - } - } - return out; -} - -inline uint8_t *AddressOffset(void *address, size_t offset) { - MS_EXCEPTION_IF_NULL(address); - return static_cast(address) + offset; -} -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_BASE_H_ diff --git a/mindspore/ccsrc/utils/counter.h b/mindspore/ccsrc/utils/counter.h deleted file mode 100644 index ead0ad84f2..0000000000 --- a/mindspore/ccsrc/utils/counter.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_COUNTER_H_ -#define MINDSPORE_CCSRC_UTILS_COUNTER_H_ -#include -#include "utils/ordered_map.h" - -namespace mindspore { - -template , class Equal = std::equal_to> -class Counter { - using counter_type = Counter; - - public: - Counter() = default; - ~Counter() = default; - - Counter(const Counter &other) { data = other.data; } - Counter &operator=(const Counter &other) { - if (this != &other) { - data = other.data; - } - return *this; - } - - int &operator[](const T &t) { return data[t]; } - - counter_type operator-(const counter_type &other) { - counter_type new_counter; - for (auto iter = begin(); iter != end(); ++iter) { - auto key = iter->first; - int value = iter->second; - auto item = other.data.find(key); - if (item != other.data.end()) { - int o_value = item->second; - if (value - o_value > 0) { - new_counter[key] = value - o_value; - } - } else { - new_counter[key] = value; - } - } - - return new_counter; - } - - counter_type operator+(const counter_type &other) { - counter_type new_counter; - for (auto iter = begin(); iter != end(); ++iter) { - auto key = iter->first; - int value = iter->second; - auto item = other.data.find(key); - if (item != other.data.end()) { - new_counter[key] = iter->second + item->second; - } else { - new_counter[key] = value; - } - } - - for (auto iter = other.cbegin(); iter != other.cend(); ++iter) { - auto key = iter->first; - int value = iter->second; - if (!new_counter.contains(key)) { - new_counter[key] = value; - } - } - - return new_counter; - } - - std::size_t size() const { return data.size(); } - - bool contains(const T &t) const { return data.find(t) != data.end(); } - - typename OrderedMap::iterator begin() { return data.begin(); } - - typename OrderedMap::iterator end() { return data.end(); } - - typename OrderedMap::const_iterator cbegin() const { return data.cbegin(); } - - typename OrderedMap::const_iterator cend() const { return data.cend(); } - - private: - OrderedMap data; -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_COUNTER_H_ diff --git a/mindspore/ccsrc/utils/dtype_py.cc b/mindspore/ccsrc/utils/dtype_py.cc new file mode 100644 index 0000000000..2b62bc0b84 --- /dev/null +++ b/mindspore/ccsrc/utils/dtype_py.cc @@ -0,0 +1,159 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/dtype.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +// Define python wrapper to handle data types. +REGISTER_PYBIND_DEFINE( + typing, ([](py::module *const m) { + auto m_sub = m->def_submodule("typing", "submodule for dtype"); + py::enum_(m_sub, "TypeId"); + (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); + (void)m_sub.def("load_type", &TypeIdToType, "load type"); + (void)m_sub.def( + "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); + (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); + (void)py::class_>(m_sub, "Type") + .def("__eq__", + [](const TypePtr &t1, const py::object &other) { + if (!py::isinstance(other)) { + return false; + } + auto t2 = py::cast(other); + if (t1 != nullptr && t2 != nullptr) { + return *t1 == *t2; + } + return false; + }) + .def("__hash__", &Type::hash) + .def("__str__", &Type::ToString) + .def("__repr__", &Type::ReprString) + .def("__deepcopy__", [](const TypePtr &t, py::dict) { + if (t == nullptr) { + return static_cast(nullptr); + } + return t->DeepCopy(); + }); + (void)py::class_>(m_sub, "Number").def(py::init()); + (void)py::class_>(m_sub, "Bool") + .def(py::init()) + .def(py::pickle( + [](const Bool &) { // __getstate__ + return py::make_tuple(); + }, + [](const py::tuple &) { // __setstate__ + return std::make_shared(); + })); + (void)py::class_>(m_sub, "Int") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const Int &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Int data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "UInt") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const UInt &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + UInt data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "Float") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const Float &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Float data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "List") + .def(py::init()) + .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "Tuple") + .def(py::init()) + .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "TensorType") + .def(py::init()) + .def(py::init(), py::arg("element")) + .def("element_type", &TensorType::element) + .def(py::pickle( + [](const TensorType &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + TensorType data(TypeIdToType(TypeId(static_cast(t[0].cast())))); + return data; + })); + (void)py::class_>(m_sub, "RowTensorType").def(py::init()); + (void)py::class_>(m_sub, "SparseTensorType") + .def(py::init()); + (void)py::class_>(m_sub, "UndeterminedType") + .def(py::init()); + (void)py::class_>(m_sub, "Function") + .def(py::init()) + .def(py::init, TypePtr>(), py::arg("args"), py::arg("retval")); + (void)py::class_>(m_sub, "Class").def(py::init()); + (void)py::class_>(m_sub, "SymbolicKeyType").def(py::init()); + (void)py::class_>(m_sub, "EnvType").def(py::init()); + (void)py::class_>(m_sub, "TypeNone").def(py::init()); + (void)py::class_>(m_sub, "TypeType").def(py::init()); + (void)py::class_>(m_sub, "String").def(py::init()); + (void)py::class_>(m_sub, "RefKeyType").def(py::init()); + (void)py::class_>(m_sub, "RefType").def(py::init()); + (void)py::class_>(m_sub, "TypeAnything").def(py::init()); + (void)py::class_>(m_sub, "Slice").def(py::init()); + (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/func_graph_py.cc b/mindspore/ccsrc/utils/func_graph_py.cc new file mode 100644 index 0000000000..cdddb7b08d --- /dev/null +++ b/mindspore/ccsrc/utils/func_graph_py.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "ir/meta_func_graph.h" +#include "ir/func_graph.h" + +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { + // Define python "MetaFuncGraph_" class + (void)py::class_>(*m, "MetaFuncGraph_") + .def(py::init()); + // Define python "FuncGraph" class + (void)py::class_(*m, "FuncGraph") + .def(py::init()) + .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") + .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc deleted file mode 100644 index 03ac14573d..0000000000 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ /dev/null @@ -1,272 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/graph_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "debug/label.h" -#include "ir/func_graph.h" -#include "utils/log_adapter.h" - -namespace mindspore { -std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { - size_t seen = NewSeenGeneration(); - std::deque todo(1024); - std::unordered_map rank; - std::vector res; - todo.clear(); - todo.push_back(root); - - while (!todo.empty()) { - AnfNodePtr node = todo.back(); - if (node == nullptr || node->seen_ == seen) { - todo.pop_back(); - continue; - } - if (rank.find(node) != rank.end() && rank[node] != todo.size()) { - MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(); - } - rank[node] = todo.size(); - bool cont = false; - auto incl = include(node); - if (incl == FOLLOW) { - auto succs = succ(node); - for (const auto i : succs) { - if ((i != nullptr && i->seen_ != seen) - // Handle the case for 2 subgraphs calls each other. - // If the ValueNodeGraph's return is already in the todo list, do not follow it. - && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && - (i->func_graph()->get_return() == i))) { - todo.push_back(i); - cont = true; - } - } - } else if (incl == NOFOLLOW) { - // do nothing - } else if (incl == EXCLUDE) { - node->seen_ = seen; - todo.pop_back(); - continue; - } else { - MS_LOG(EXCEPTION) << "include(node) must return one of: \"follow\", \"nofollow\", \"exclude\""; - } - if (cont) { - continue; - } - node->seen_ = seen; - res.push_back(node); - todo.pop_back(); - } - return res; -} - -// search the cnodes inside this graph only -std::vector BroadFirstSearchGraphCNodes(CNodePtr ret) { - std::queue todo; - todo.push(ret); - std::vector sorted_nodes; - auto seen = NewSeenGeneration(); - while (!todo.empty()) { - CNodePtr top = todo.front(); - todo.pop(); - sorted_nodes.push_back(top); - auto inputs = top->inputs(); - for (auto &item : inputs) { - if (item->seen_ == seen) { - continue; - } - - if (item->isa()) { - todo.push(item->cast()); - } - item->seen_ = seen; - } - } - return sorted_nodes; -} - -std::vector SuccDeeper(const AnfNodePtr &node) { - std::vector vecs; - if (node == nullptr) { - return vecs; - } - - if (IsValueNode(node)) { - auto graph = GetValueNode(node); - auto ret = graph->get_return(); - if (ret != nullptr) { - vecs.push_back(ret); - } - return vecs; - } else if (node->func_graph() != nullptr) { - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); - } - auto graph = node->func_graph(); - if (graph->get_return() != nullptr) { - vecs.push_back(graph->get_return()); - } - return vecs; - } - - return vecs; -} - -std::vector SuccDeeperSimple(const AnfNodePtr &node) { - std::vector vecs; - if (node == nullptr) { - return vecs; - } - - if (IsValueNode(node)) { - auto graph = GetValueNode(node); - auto ret = graph->get_return(); - if (ret != nullptr) { - vecs.push_back(ret); - } - return vecs; - } else { - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); - } - return vecs; - } -} - -std::vector SuccIncoming(const AnfNodePtr &node) { - std::vector vecs; - if (node == nullptr) { - return vecs; - } - - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); - } - return vecs; -} - -std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) { - std::vector vecs; - if (node == nullptr) { - return vecs; - } - if (node->isa()) { - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - // Check if free variables used. - for (const auto &input : inputs) { - auto input_fg = GetValueNode(input); - if (input_fg) { - for (auto &fv : input_fg->free_variables_nodes()) { - if (fv->func_graph() == fg && fg->nodes().contains(fv)) { - vecs.push_back(fv); - } - } - } - } - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); - } - return vecs; -} - -IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } - -IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { - if (node->func_graph() == fg) { - return FOLLOW; - } else { - return EXCLUDE; - } -} - -FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { - MS_EXCEPTION_IF_NULL(fg); - Acquire(fg); - - auto vec = search(fg->get_return(), include); - for (auto &node : vec) { - MS_EXCEPTION_IF_NULL(node); - Acquire(node); - if (node->func_graph() != nullptr) { - Acquire(node->func_graph()); - } - } -} - -std::set FuncGraphIndex::GetFuncGraphs(const std::string &key) { - std::set func_graphs; - if (index_func_graph_.find(key) != index_func_graph_.end()) { - func_graphs = index_func_graph_[key]; - } - return func_graphs; -} - -std::set FuncGraphIndex::GetNodes(const std::string &key) { - if (index_node_.find(key) != index_node_.end()) { - return index_node_[key]; - } - - return std::set(); -} - -FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { - if (GetFuncGraphs(key).empty()) { - return nullptr; - } - - auto fg = *GetFuncGraphs(key).begin(); - return fg; -} - -AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { - if (GetNodes(key).empty()) { - return nullptr; - } - - auto node = *GetNodes(key).begin(); - return node; -} - -void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { - std::string name = label_manage::Label(key->debug_info()); - if (!name.empty()) { - (void)index_func_graph_[name].insert(key); - } -} - -void FuncGraphIndex::Acquire(const AnfNodePtr &key) { - std::string name = label_manage::Label(key->debug_info()); - if (!name.empty()) { - (void)index_node_[name].insert(key); - } -} -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h deleted file mode 100644 index 2a9240ac84..0000000000 --- a/mindspore/ccsrc/utils/graph_utils.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ -#define MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/primitive.h" -#include "ir/scalar.h" -#include "ir/tensor.h" -#include "debug/label.h" - -namespace mindspore { - -enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; - -using IncludeFunc = std::function; -using FilterFunc = std::function; -using SuccFunc = std::function(AnfNodePtr)>; -using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; - -std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); -std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); -std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); - -std::vector SuccDeeper(const AnfNodePtr &node); -std::vector SuccDeeperSimple(const AnfNodePtr &node); -std::vector SuccIncoming(const AnfNodePtr &node); -std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); - -IncludeType AlwaysInclude(const AnfNodePtr &node); -IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); - -std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); -std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); -std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); - -std::vector DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include, - const FilterFunc &filter); - -class FuncGraphManager; -using FuncGraphManagerPtr = std::shared_ptr; -std::vector DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, - const FuncGraphManagerPtr &mng); -std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, - const IncludeFunc &include = AlwaysInclude); - -std::vector BroadFirstSearchGraphCNodes(CNodePtr ret); -class FuncGraphIndex { - public: - explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, - const IncludeFunc &include = AlwaysInclude); - FuncGraphIndex(const FuncGraphIndex &) = delete; - FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; - - virtual ~FuncGraphIndex() {} - - std::set GetFuncGraphs(const std::string &key); - std::set GetNodes(const std::string &key); - FuncGraphPtr GetFirstFuncGraph(const std::string &key); - AnfNodePtr GetFirstNode(const std::string &key); - - private: - void Acquire(const FuncGraphPtr &key); - void Acquire(const AnfNodePtr &key); - - std::map> index_func_graph_; - std::map> index_node_; -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ diff --git a/mindspore/ccsrc/utils/graph_utils_extends.cc b/mindspore/ccsrc/utils/graph_utils_extends.cc deleted file mode 100644 index 852dd0e3f2..0000000000 --- a/mindspore/ccsrc/utils/graph_utils_extends.cc +++ /dev/null @@ -1,207 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/graph_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/visitor.h" -#include "ir/manager.h" -#include "ir/func_graph.h" -#include "debug/label.h" -#include "utils/log_adapter.h" -#include "common/utils.h" -#include "pipeline/jit/parse/function_block.h" -#include "pipeline/jit/parse/python_adapter.h" - -namespace mindspore { -namespace { -class DeepFirstSearcher : public AnfVisitor { - public: - explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr) - : include_(include), filter_(filter) {} - ~DeepFirstSearcher() override = default; - - std::vector Search(const AnfNodePtr &root) { - if (root == nullptr) { - return res_; - } - seen_ = NewSeenGeneration(); - Visit(root); - return res_; - } - - void Visit(const AnfNodePtr &node) override { - MS_EXCEPTION_IF_NULL(node); - if (node->seen_ == seen_) { - return; - } - - node->seen_ = seen_; - - auto incl = include_(node); - if (incl == EXCLUDE) { - return; - } - if (filter_ == nullptr || !filter_(node)) { - res_.push_back(node); - } - if (incl == FOLLOW) { - AnfVisitor::Visit(node); - } - } - - private: - size_t seen_{0}; - IncludeFunc include_; - FilterFunc filter_; - std::vector res_{}; -}; - -class DeepScopedGraphSearcher : public DeepFirstSearcher { - public: - explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} - ~DeepScopedGraphSearcher() override = default; - - void Visit(const CNodePtr &cnode) override { - if (cnode->func_graph() == nullptr) { - return; - } - - AnfNodePtr ret = cnode->func_graph()->get_return(); - if (ret != nullptr) { - DeepFirstSearcher::Visit(ret); - } - - auto &inputs = cnode->inputs(); - for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { - DeepFirstSearcher::Visit(*iter); - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (!IsValueNode(vnode)) { - return; - } - - auto graph = GetValueNode(vnode); - AnfNodePtr ret = graph->get_return(); - if (ret != nullptr) { - DeepFirstSearcher::Visit(ret); - } - } - - void Visit(const ParameterPtr ¶m) override { - if (param->func_graph() == nullptr) { - return; - } - - AnfNodePtr ret = param->func_graph()->get_return(); - if (ret != nullptr) { - DeepFirstSearcher::Visit(ret); - } - } -}; - -class DeepUsedGraphSearcher : public DeepFirstSearcher { - public: - explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} - ~DeepUsedGraphSearcher() override = default; - - void Visit(const CNodePtr &cnode) override { - auto &inputs = cnode->inputs(); - for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { - DeepFirstSearcher::Visit(*iter); - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (!IsValueNode(vnode)) { - return; - } - - auto graph = GetValueNode(vnode); - AnfNodePtr ret = graph->get_return(); - if (ret != nullptr) { - DeepFirstSearcher::Visit(ret); - } - } -}; - -class DeepLinkedGraphSearcher : public DeepFirstSearcher { - public: - explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} - ~DeepLinkedGraphSearcher() override = default; - - void Visit(const CNodePtr &cnode) override { - auto &inputs = cnode->inputs(); - for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { - DeepFirstSearcher::Visit(*iter); - } - } - - void Visit(const ValueNodePtr &) override {} -}; - -class DeepUsersSearcher : public DeepFirstSearcher { - public: - explicit DeepUsersSearcher(const IncludeFunc &include, const FuncGraphManagerPtr &mng) - : DeepFirstSearcher(include), mng_(mng) {} - ~DeepUsersSearcher() override = default; - - void Visit(const CNodePtr &cnode) override { - auto &users = mng_->node_users()[cnode]; - for (auto iter = users.begin(); iter != users.end(); ++iter) { - DeepFirstSearcher::Visit(iter->first); - } - } - void Visit(const ValueNodePtr &) override {} - - private: - FuncGraphManagerPtr mng_; -}; -} // namespace - -// include for if expand the node the search, filter for if put the node to results. -std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { - return DeepScopedGraphSearcher(include).Search(root); -} - -std::vector DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include, - const FilterFunc &filter) { - return DeepFirstSearcher(include, filter).Search(root); -} - -std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { - return DeepUsedGraphSearcher(include).Search(root); -} - -std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { - return DeepLinkedGraphSearcher(include).Search(root); -} - -std::vector DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, - const FuncGraphManagerPtr &mng) { - return DeepUsersSearcher(include, mng).Search(root); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/hashing.h b/mindspore/ccsrc/utils/hashing.h deleted file mode 100644 index cc8cc5b991..0000000000 --- a/mindspore/ccsrc/utils/hashing.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_HASHING_H_ -#define MINDSPORE_CCSRC_UTILS_HASHING_H_ - -#include - -namespace mindspore { -inline std::size_t hash_combine(std::size_t hash_sum, std::size_t hash_val) { - // Reference from http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0814r0.pdf - return ((hash_sum << 6) + (hash_sum >> 2) + 0x9e3779b9 + hash_val) ^ hash_sum; -} - -inline std::size_t hash_combine(const std::initializer_list &hash_vals) { - std::size_t hash_sum = 0; - for (auto hash_val : hash_vals) { - hash_sum = hash_combine(hash_sum, hash_val); - } - return hash_sum; -} -} // namespace mindspore -#endif // MINDSPORE_CCSRC_UTILS_HASHING_H_ diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc index fa1137e3f6..b592164df9 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc @@ -124,10 +124,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; } - auto param_value = std::make_shared(); - MS_EXCEPTION_IF_NULL(param_value); - param_value->set_value(tensor_info); - node->set_default_param(param_value); + node->set_default_param(tensor_info); } anfnode_build_map_[value_proto.name()] = node; return true; diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc deleted file mode 100644 index 702deefcb4..0000000000 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ /dev/null @@ -1,533 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/log_adapter.h" - -#include -#include -#include "debug/trace.h" - -// namespace to support utils module definition -namespace mindspore { -#ifdef USE_GLOG -static std::string GetTime() { -#define BUFLEN 80 - static char buf[BUFLEN]; -#if defined(_WIN32) || defined(_WIN64) - time_t time_seconds = time(0); - struct tm now_time; - localtime_s(&now_time, &time_seconds); - sprintf_s(buf, BUFLEN, "%d-%d-%d %d:%d:%d", now_time.tm_year + 1900, now_time.tm_mon + 1, now_time.tm_mday, - now_time.tm_hour, now_time.tm_min, now_time.tm_sec); -#else - struct timeval cur_time; - (void)gettimeofday(&cur_time, nullptr); - - struct tm now; - (void)localtime_r(&cur_time.tv_sec, &now); - (void)strftime(buf, BUFLEN, "%Y-%m-%d-%H:%M:%S", &now); // format date and time - // set micro-second - buf[27] = '\0'; - int idx = 26; - auto num = cur_time.tv_usec; - for (int i = 5; i >= 0; i--) { - buf[idx--] = static_cast(num % 10 + '0'); - num /= 10; - if (i % 3 == 0) { - buf[idx--] = '.'; - } - } -#endif - return std::string(buf); -} - -static std::string GetProcName() { -#if defined(__APPLE__) || defined(__FreeBSD__) - const char *appname = getprogname(); -#elif defined(_GNU_SOURCE) - const char *appname = program_invocation_name; -#else - const char *appname = "?"; -#endif - // some times, the appname is an absolute path, its too long - std::string app_name(appname); - std::size_t pos = app_name.rfind("/"); - if (pos == std::string::npos) { - return app_name; - } - if (pos + 1 >= app_name.size()) { - return app_name; - } - return app_name.substr(pos + 1); -} - -static std::string GetLogLevel(MsLogLevel level) { -#define _TO_STRING(x) #x - static const char *const level_names[] = { - _TO_STRING(DEBUG), - _TO_STRING(INFO), - _TO_STRING(WARNING), - _TO_STRING(ERROR), - }; -#undef _TO_STRING - if (level > ERROR) { - level = ERROR; - } - return std::string(level_names[level]); -} - -// convert MsLogLevel to corresponding glog level -static int GetGlogLevel(MsLogLevel level) { - switch (level) { - case DEBUG: - case INFO: - return google::GLOG_INFO; - case WARNING: - return google::GLOG_WARNING; - case ERROR: - default: - return google::GLOG_ERROR; - } -} -#else - -#undef Dlog -#define Dlog(module_id, level, format, ...) \ - do { \ - DlogInner((module_id), (level), (format), ##__VA_ARGS__); \ - } while (0) - -// convert MsLogLevel to corresponding slog level -static int GetSlogLevel(MsLogLevel level) { - switch (level) { - case DEBUG: - return DLOG_DEBUG; - case INFO: - return DLOG_INFO; - case WARNING: - return DLOG_WARN; - case ERROR: - default: - return DLOG_ERROR; - } -} -#endif - -static std::string ExceptionTypeToString(ExceptionType type) { -#define _TO_STRING(x) #x - // clang-format off - static const char *const type_names[] = { - _TO_STRING(NoExceptionType), - _TO_STRING(UnknownError), - _TO_STRING(ArgumentError), - _TO_STRING(NotSupportError), - _TO_STRING(NotExistsError), - _TO_STRING(AlreadyExistsError), - _TO_STRING(UnavailableError), - _TO_STRING(DeviceProcessError), - _TO_STRING(AbortedError), - _TO_STRING(TimeOutError), - _TO_STRING(ResourceUnavailable), - _TO_STRING(NoPermissionError), - _TO_STRING(IndexError), - _TO_STRING(ValueError), - _TO_STRING(TypeError), - }; - // clang-format on -#undef _TO_STRING - if (type < UnknownError || type > TypeError) { - type = UnknownError; - } - return std::string(type_names[type]); -} - -static const char *GetSubModuleName(SubModuleId module_id) { - static const char *sub_module_names[NUM_SUBMODUES] = { - "UNKNOWN", // SM_UNKNOWN - "BASE", // SM_BASE - "ANALYZER", // SM_ANALYZER - "COMMON", // SM_COMMON - "DEBUG", // SM_DEBUG - "DEVICE", // SM_DEVICE - "GE_ADPT", // SM_GE_ADPT - "IR", // SM_IR - "KERNEL", // SM_KERNEL - "MD", // SM_MD - "ME", // SM_ME - "ONNX", // SM_ONNX - "OPTIMIZER", // SM_OPTIMIZER - "PARALLEL", // SM_PARALLEL - "PARSER", // SM_PARSER - "PIPELINE", // SM_PIPELINE - "PRE_ACT", // SM_PRE_ACT - "PYNATIVE", // SM_PYNATIVE - "SESSION", // SM_SESSION - "UTILS", // SM_UTILS - "VM", // SM_VM - "ABSTRACT" // SM_ABSTRACT - }; - - return sub_module_names[module_id % NUM_SUBMODUES]; -} - -void LogWriter::OutputLog(const std::ostringstream &msg) const { -#ifdef USE_GLOG - auto submodule_name = GetSubModuleName(submodule_); - google::LogMessage("", 0, GetGlogLevel(log_level_)).stream() - << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << GetProcName() - << "):" << GetTime() << " " - << "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl; -#else - auto str_msg = msg.str(); - auto slog_module_id = (submodule_ == SM_MD ? MD : ME); - Dlog(static_cast(slog_module_id), GetSlogLevel(log_level_), "[%s:%d] %s] %s", location_.file_, location_.line_, - location_.func_, str_msg.c_str()); -#endif -} - -void LogWriter::operator<(const LogStream &stream) const noexcept { - std::ostringstream msg; - msg << stream.sstream_->rdbuf(); - OutputLog(msg); -} - -void LogWriter::operator^(const LogStream &stream) const { - std::ostringstream msg; - msg << stream.sstream_->rdbuf(); - OutputLog(msg); - - std::ostringstream oss; - oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; - if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError && - exception_type_ != ValueError) { - oss << ExceptionTypeToString(exception_type_) << " "; - } - oss << msg.str(); - - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - - if (exception_handler_ != nullptr) { - exception_handler_(exception_type_, oss.str()); - } - throw std::runtime_error(oss.str()); -} - -static std::string GetEnv(const std::string &envvar) { - const char *value = ::getenv(envvar.c_str()); - - if (value == nullptr) { - return std::string(); - } - - return std::string(value); -} - -enum LogConfigToken { - INVALID, // indicate invalid token - LEFT_BRACE, // '{' - RIGHT_BRACE, // '}' - VARIABLE, // '[A-Za-z][A-Za-z0-9_]*' - NUMBER, // [0-9]+ - COMMA, // ',' - COLON, // ':' - EOS, // End Of String, '\0' - NUM_LOG_CFG_TOKENS -}; - -static const char *g_tok_names[NUM_LOG_CFG_TOKENS] = { - "invalid", // indicate invalid token - "{", // '{' - "}", // '}' - "variable", // '[A-Za-z][A-Za-z0-9_]*' - "number", // [0-9]+ - ",", // ',' - ":", // ':' - "end-of-string", // End Of String, '\0' -}; - -static inline bool IsAlpha(char ch) { return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z'); } - -static inline bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; } - -class LogConfigLexer { - public: - explicit LogConfigLexer(const std::string &text) : buffer_(text) { - cur_idx_ = 0; - cur_token_ = LogConfigToken::INVALID; - } - ~LogConfigLexer() = default; - - // skip white space, and return the first char after white space - char SkipWhiteSpace() { - while (cur_idx_ < buffer_.size()) { - char ch = buffer_[cur_idx_]; - if (ch == ' ' || ch == '\t') { - ++cur_idx_; - continue; - } - return ch; - } - return '\0'; - } - - LogConfigToken GetNext(std::string *const ptr) { -#ifdef DEBUG - std::string text; - auto tok = GetNextInner(&text); - MS_LOG(DEBUG) << "Got token " << tok << " with value [" << text << "]"; - if (ptr != nullptr) { - *ptr = text; - } - return tok; - } - - LogConfigToken GetNextInner(std::string *ptr) { -#endif - char ch = SkipWhiteSpace(); - // clang-format off - static const std::map single_char_map = { - {'{', LogConfigToken::LEFT_BRACE}, - {'}', LogConfigToken::RIGHT_BRACE}, - {',', LogConfigToken::COMMA}, - {':', LogConfigToken::COLON}, - {'\0', LogConfigToken::EOS}, - }; - // clang-format on - - auto iter = single_char_map.find(ch); - if (iter != single_char_map.end()) { - if (ptr != nullptr) { - *ptr = std::string() + ch; - } - ++cur_idx_; - return iter->second; - } else if (IsAlpha(ch)) { - std::ostringstream oss; - do { - oss << ch; - ch = buffer_[++cur_idx_]; - } while (cur_idx_ < buffer_.size() && (IsAlpha(ch) || IsDigit(ch) || ch == '_')); - if (ptr != nullptr) { - *ptr = std::string(oss.str()); - } - return LogConfigToken::VARIABLE; - } else if (IsDigit(ch)) { - std::ostringstream oss; - do { - oss << ch; - ch = buffer_[++cur_idx_]; - } while (cur_idx_ < buffer_.size() && IsDigit(ch)); - if (ptr != nullptr) { - *ptr = std::string(oss.str()); - } - return LogConfigToken::NUMBER; - } - return LogConfigToken::INVALID; - } - - private: - std::string buffer_; - size_t cur_idx_; - - LogConfigToken cur_token_; - std::string cur_text_; -}; - -class LogConfigParser { - public: - explicit LogConfigParser(const std::string &cfg) : lexer(cfg) {} - ~LogConfigParser() = default; - - bool Expect(LogConfigToken expected, LogConfigToken tok) { - if (expected != tok) { - MS_LOG(WARNING) << "Parse submodule log configuration text error, expect `" << g_tok_names[expected] - << "`, but got `" << g_tok_names[tok] << "`. The whole configuration will be ignored."; - return false; - } - return true; - } - - // The text of config MS_SUBMODULE_LOG_v is in the form {submodule1:log_level1,submodule2:log_level2,...}. - // Valid values of log levels are: 0 - debug, 1 - info, 2 - warning, 3 - error - // e.g. MS_SUBMODULE_LOG_v={PARSER:0, ANALYZER:2, PIPELINE:1} - std::map Parse() { - std::map log_levels; - - bool flag_error = false; - std::string text; - auto tok = lexer.GetNext(&text); - - // empty string - if (tok == LogConfigToken::EOS) { - return log_levels; - } - - if (!Expect(LogConfigToken::LEFT_BRACE, tok)) { - return log_levels; - } - - do { - std::string key, val; - tok = lexer.GetNext(&key); - if (!Expect(LogConfigToken::VARIABLE, tok)) { - flag_error = true; - break; - } - - tok = lexer.GetNext(&text); - if (!Expect(LogConfigToken::COLON, tok)) { - flag_error = true; - break; - } - - tok = lexer.GetNext(&val); - if (!Expect(LogConfigToken::NUMBER, tok)) { - flag_error = true; - break; - } - - log_levels[key] = val; - tok = lexer.GetNext(&text); - } while (tok == LogConfigToken::COMMA); - - if (!flag_error && !Expect(LogConfigToken::RIGHT_BRACE, tok)) { - flag_error = true; - } - - if (flag_error) { - log_levels.clear(); - } - return log_levels; - } - - private: - LogConfigLexer lexer; -}; - -bool ParseLogLevel(const std::string &str_level, MsLogLevel *ptr_level) { - if (str_level.size() == 1) { - int ch = str_level.c_str()[0]; - ch = ch - '0'; // substract ASCII code of '0', which is 48 - if (ch >= DEBUG && ch <= ERROR) { - if (ptr_level != nullptr) { - *ptr_level = static_cast(ch); - } - return true; - } - } - return false; -} - -static MsLogLevel GetGlobalLogLevel() { -#ifdef USE_GLOG - return static_cast(FLAGS_v); -#else - int log_level = WARNING; // set default log level to WARNING - auto str_level = GetEnv("GLOG_v"); - if (str_level.size() == 1) { - int ch = str_level.c_str()[0]; - ch = ch - '0'; // substract ASCII code of '0', which is 48 - if (ch >= DEBUG && ch <= ERROR) { - log_level = ch; - } - } - return static_cast(log_level); -#endif -} - -void InitSubModulesLogLevel() { - // initialize submodule's log level using global - auto global_log_level = GetGlobalLogLevel(); - for (int i = 0; i < NUM_SUBMODUES; ++i) { - g_ms_submodule_log_levels[i] = global_log_level; - } - - // set submodule's log level - auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); - MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; - LogConfigParser parser(submodule); - auto configs = parser.Parse(); - for (const auto &cfg : configs) { - int mod_idx = -1; - for (int i = 0; i < NUM_SUBMODUES; ++i) { - if (cfg.first == GetSubModuleName(static_cast(i))) { - mod_idx = i; - break; - } - } - if (mod_idx < 0) { - MS_LOG(WARNING) << "Undefined module name " << cfg.first << ", ignore it"; - continue; - } - MsLogLevel submodule_log_level; - if (!ParseLogLevel(cfg.second, &submodule_log_level)) { - MS_LOG(WARNING) << "Illegal log level value " << cfg.second << " for " << cfg.first << ", ignore it."; - continue; - } - g_ms_submodule_log_levels[mod_idx] = submodule_log_level; - } -} -} // namespace mindspore - -extern "C" { -#if defined(_WIN32) || defined(_WIN64) -__attribute__((constructor)) void common_log_init(void) { -#else -void common_log_init(void) { -#endif -#ifdef USE_GLOG - // do not use glog predefined log prefix - FLAGS_log_prefix = false; - // set default log level to WARNING - if (mindspore::GetEnv("GLOG_v").empty()) { - FLAGS_v = mindspore::WARNING; - } - - // set default log file mode to 0640 - if (mindspore::GetEnv("GLOG_logfile_mode").empty()) { - FLAGS_logfile_mode = 0640; - } - std::string logtostderr = mindspore::GetEnv("GLOG_logtostderr"); - // default print log to screen - if (logtostderr.empty()) { - FLAGS_logtostderr = true; - } else if (logtostderr == "0" && mindspore::GetEnv("GLOG_log_dir").empty()) { - FLAGS_logtostderr = true; - MS_LOG(WARNING) << "`GLOG_log_dir` is not set, output log to screen."; - } -#endif - mindspore::InitSubModulesLogLevel(); -} - -// shared lib init hook -#if defined(_WIN32) || defined(_WIN64) -__attribute__((constructor)) void mindspore_log_init(void) { -#else -void mindspore_log_init(void) { -#endif -#ifdef USE_GLOG - static bool is_glog_initialzed = false; - if (!is_glog_initialzed) { -#if !defined(_WIN32) && !defined(_WIN64) - google::InitGoogleLogging("mindspore"); -#endif - is_glog_initialzed = true; - } -#endif - common_log_init(); -} -} diff --git a/mindspore/ccsrc/utils/log_adapter.h b/mindspore/ccsrc/utils/log_adapter.h deleted file mode 100644 index a0e9bfc6d6..0000000000 --- a/mindspore/ccsrc/utils/log_adapter.h +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_LOG_ADAPTER_H_ -#define MINDSPORE_CCSRC_UTILS_LOG_ADAPTER_H_ - -#include -#include -#include -#include -#include -#include -#include "./overload.h" -#include "./securec.h" -#ifdef USE_GLOG -#include "glog/logging.h" -#else -#include "toolchain/slog.h" -#endif -// NOTICE: when relative path of 'log_adapter.h' changed, macro 'LOG_HDR_FILE_REL_PATH' must be changed -#define LOG_HDR_FILE_REL_PATH "mindspore/ccsrc/utils/log_adapter.h" - -// Get start index of file relative path in __FILE__ -static constexpr int GetRelPathPos() noexcept { - return sizeof(__FILE__) > sizeof(LOG_HDR_FILE_REL_PATH) ? sizeof(__FILE__) - sizeof(LOG_HDR_FILE_REL_PATH) : 0; -} - -namespace mindspore { -#define FILE_NAME \ - (sizeof(__FILE__) > GetRelPathPos() ? static_cast(__FILE__) + GetRelPathPos() \ - : static_cast(__FILE__)) -enum ExceptionType { - NoExceptionType = 0, - UnknownError, - ArgumentError, - NotSupportError, - NotExistsError, - AlreadyExistsError, - UnavailableError, - DeviceProcessError, - AbortedError, - TimeOutError, - ResourceUnavailable, - NoPermissionError, - IndexError, - ValueError, - TypeError, -}; - -struct LocationInfo { - LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {} - ~LocationInfo() = default; - - const char *file_; - int line_; - const char *func_; -}; - -class LogStream { - public: - LogStream() { sstream_ = std::make_shared(); } - ~LogStream() = default; - - template - LogStream &operator<<(const T &val) noexcept { - (*sstream_) << val; - return *this; - } - - LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept { - (*sstream_) << func; - return *this; - } - - friend class LogWriter; - - private: - std::shared_ptr sstream_; -}; - -template ::value, int>::type = 0> -constexpr std::ostream &operator<<(std::ostream &stream, const T &value) { - return stream << static_cast::type>(value); -} - -enum MsLogLevel : int { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION }; - -enum SubModuleId : int { - SM_UNKNOWN = 0, // unknown submodule - SM_BASE, // base - SM_ANALYZER, // static analyzer - SM_COMMON, // common - SM_DEBUG, // debug - SM_DEVICE, // device - SM_GE_ADPT, // ge adapter - SM_IR, // IR - SM_KERNEL, // kernel - SM_MD, // MindData - SM_ME, // MindExpression - SM_ONNX, // ONNX - SM_OPTIMIZER, // optimzer - SM_PARALLEL, // parallel - SM_PARSER, // parser - SM_PIPELINE, // ME pipeline - SM_PRE_ACT, // pre-activate - SM_PYNATIVE, // PyNative - SM_SESSION, // session - SM_UTILS, // utils - SM_VM, // VM - SM_ABSTRACT, // abstract - NUM_SUBMODUES // number of submodules -}; - -#ifndef SUBMODULE_ID -#define SUBMODULE_ID mindspore::SubModuleId::SM_ME -#endif - -#if defined(_WIN32) || defined(_WIN64) -extern int g_ms_submodule_log_levels[] __attribute__((dllexport)); -#else -extern int g_ms_submodule_log_levels[] __attribute__((visibility("default"))); -#endif - -class LogWriter { - public: - using ExceptionHandler = std::function; - - LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, - ExceptionType excp_type = NoExceptionType) - : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} - ~LogWriter() = default; - - void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); - void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); - - static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; } - - private: - void OutputLog(const std::ostringstream &msg) const; - - LocationInfo location_; - MsLogLevel log_level_; - SubModuleId submodule_; - ExceptionType exception_type_; - - inline static ExceptionHandler exception_handler_ = nullptr; -}; - -#define MSLOG_IF(level, condition, excp_type) \ - static_cast(0), !(condition) \ - ? void(0) \ - : mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), level, \ - SUBMODULE_ID, excp_type) < mindspore::LogStream() -#define MSLOG_THROW(excp_type) \ - mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), mindspore::EXCEPTION, SUBMODULE_ID, \ - excp_type) ^ \ - mindspore::LogStream() - -#define IS_OUTPUT_ON(level) (level) >= mindspore::g_ms_submodule_log_levels[SUBMODULE_ID] - -#define MS_LOG(level) MS_LOG_##level - -#define MS_LOG_DEBUG MSLOG_IF(mindspore::DEBUG, IS_OUTPUT_ON(mindspore::DEBUG), mindspore::NoExceptionType) -#define MS_LOG_INFO MSLOG_IF(mindspore::INFO, IS_OUTPUT_ON(mindspore::INFO), mindspore::NoExceptionType) -#define MS_LOG_WARNING MSLOG_IF(mindspore::WARNING, IS_OUTPUT_ON(mindspore::WARNING), mindspore::NoExceptionType) -#define MS_LOG_ERROR MSLOG_IF(mindspore::ERROR, IS_OUTPUT_ON(mindspore::ERROR), mindspore::NoExceptionType) - -#define MS_LOG_EXCEPTION MSLOG_THROW(mindspore::NoExceptionType) -#define MS_EXCEPTION(type) MSLOG_THROW(type) -} // namespace mindspore - -#define MS_EXCEPTION_IF_NULL(ptr) \ - do { \ - if ((ptr) == nullptr) { \ - MS_LOG(EXCEPTION) << ": The pointer[" << #ptr << "] is null."; \ - } \ - } while (0) - -#ifdef DEBUG -#include -#define MS_ASSERT(f) assert(f) -#else -#define MS_ASSERT(f) ((void)0) -#endif - -#endif // MINDSPORE_CCSRC_UTILS_LOG_ADAPTER_H_ diff --git a/mindspore/ccsrc/utils/log_adapter_py.cc b/mindspore/ccsrc/utils/log_adapter_py.cc index c4793b960b..db086f37a7 100644 --- a/mindspore/ccsrc/utils/log_adapter_py.cc +++ b/mindspore/ccsrc/utils/log_adapter_py.cc @@ -18,6 +18,7 @@ #include #include "pybind11/pybind11.h" +#include "pybind_api/pybind_patch.h" namespace py = pybind11; namespace mindspore { @@ -38,6 +39,9 @@ class PyExceptionInitializer { if (exception_type == TypeError) { throw py::type_error(str); } + if (exception_type == AttributeError) { + throw py::attribute_error(str); + } py::pybind11_fail(str); } }; diff --git a/mindspore/ccsrc/utils/misc.cc b/mindspore/ccsrc/utils/misc.cc deleted file mode 100644 index a9eb8071ef..0000000000 --- a/mindspore/ccsrc/utils/misc.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/misc.h" - -namespace mindspore { - -const int RET_SUCCESS = 0; -const int RET_FAILED = 1; -const int RET_CONTINUE = 2; -const int RET_BREAK = 3; - -std::string demangle(const char *name) { - int status = -1; - std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; - return (status == 0) ? res.get() : name; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/misc.h b/mindspore/ccsrc/utils/misc.h deleted file mode 100644 index e2cdebe98a..0000000000 --- a/mindspore/ccsrc/utils/misc.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_MISC_H_ -#define MINDSPORE_CCSRC_UTILS_MISC_H_ - -#include -#include -#include -#include -#include - -#include "utils/log_adapter.h" - -namespace mindspore { - -extern const int RET_SUCCESS; -extern const int RET_FAILED; -extern const int RET_CONTINUE; -extern const int RET_BREAK; - -// demangle the name to make it human reablable. -extern std::string demangle(const char *name); - -} // namespace mindspore -#endif // MINDSPORE_CCSRC_UTILS_MISC_H_ diff --git a/mindspore/ccsrc/utils/node_strategy.proto b/mindspore/ccsrc/utils/node_strategy.proto index 8ec25f21a6..dc9d65407d 100644 --- a/mindspore/ccsrc/utils/node_strategy.proto +++ b/mindspore/ccsrc/utils/node_strategy.proto @@ -32,7 +32,36 @@ message ParallelStrategyItem { required ParallelStrategys parallel_strategys = 2; } +message DevMatrix { + repeated uint32 dim = 1; +} + +message TensorMap { + repeated int32 dim = 1; +} + +message ParamSplitShape { + repeated int64 dim = 1; +} + +message IndicesOffset { + repeated int64 dim = 1; +} + +message ParallelLayouts { + repeated DevMatrix dev_matrix = 1; + repeated TensorMap tensor_map = 2; + repeated ParamSplitShape param_split_shape = 3; + repeated IndicesOffset indices_offset = 4; +} + +message ParallelLayoutItem { + required string param_name = 1; + required ParallelLayouts parallel_layouts = 2; +} + message ParallelStrategyMap { required uint32 current_stage = 1; repeated ParallelStrategyItem parallel_strategy_item = 2; + repeated ParallelLayoutItem parallel_layout_item = 3; } \ No newline at end of file diff --git a/mindspore/ccsrc/utils/ordered_map.h b/mindspore/ccsrc/utils/ordered_map.h deleted file mode 100644 index 48aa36df31..0000000000 --- a/mindspore/ccsrc/utils/ordered_map.h +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_ORDERED_MAP_H_ -#define MINDSPORE_CCSRC_UTILS_ORDERED_MAP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "utils/log_adapter.h" - -namespace mindspore { -// Implementation of OrderedMap that keeps insertion order -// using unordered_map to improve the performance of find/erase, and use list to keep insertion order -template , class Equal = std::equal_to> -class OrderedMap { - public: - using key_t = KeyT; - using value_t = ValueT; - using hasher = Hash; - using equal = Equal; - using pair_type = std::pair; - using sequential_type = std::list; - using iterator = typename sequential_type::iterator; - using const_iterator = typename sequential_type::const_iterator; - using reverse_iterator = typename sequential_type::reverse_iterator; - using const_reverse_iterator = typename sequential_type::const_reverse_iterator; - using map_type = std::unordered_map; - using value_type = typename sequential_type::value_type; - using size_type = typename sequential_type::size_type; - - iterator begin() { return sequential_data_.begin(); } - iterator end() { return sequential_data_.end(); } - const_iterator begin() const { return sequential_data_.cbegin(); } - const_iterator end() const { return sequential_data_.cend(); } - const_iterator cbegin() const { return sequential_data_.cbegin(); } - const_iterator cend() const { return sequential_data_.cend(); } - - reverse_iterator rbegin() { return sequential_data_.rbegin(); } - reverse_iterator rend() { return sequential_data_.rend(); } - const_reverse_iterator rbegin() const { return sequential_data_.rbegin(); } - const_reverse_iterator rend() const { return sequential_data_.rend(); } - - pair_type &front() { return sequential_data_.front(); } - const pair_type &front() const { return sequential_data_.front(); } - pair_type &back() { return sequential_data_.back(); } - const pair_type &back() const { return sequential_data_.back(); } - - OrderedMap() = default; - ~OrderedMap() = default; - OrderedMap(const OrderedMap &os) { - for (auto &item : os.sequential_data_) { - (void)insert(pair_type(item.first, item.second)); - } - } - - // Explicitly construct OrderedMap use sequential_type - explicit OrderedMap(const sequential_type &other) { - for (auto &item : other) { - (void)insert(pair_type(item.first, item.second)); - } - } - - OrderedMap &operator=(const OrderedMap &os) { - if (this != &os) { - for (auto &item : os.sequential_data_) { - (void)insert(pair_type(item.first, item.second)); - } - } - return *this; - } - - void clear() { - map_data_.clear(); - sequential_data_.clear(); - } - - void swap(OrderedMap &rhs) { - std::swap(map_data_, rhs.map_data_); - std::swap(sequential_data_, rhs.sequential_data_); - } - - void reserve(size_type num_entries) { - map_data_.reserve(num_entries); - sequential_data_.reserve(num_entries); - } - - std::pair add(const key_t &key) { - iterator empty_itr; - std::pair map_pair = std::make_pair(key, empty_itr); - std::pair result = map_data_.insert(map_pair); - auto &seq_itr = result.first->second; - if (result.second) { - auto it = sequential_data_.insert(sequential_data_.end(), std::make_pair(key, ValueT())); - seq_itr = it; - } - return std::pair(seq_itr, result.second); - } - - ValueT &operator[](const key_t &key) { - auto result = add(key); - return (*result.first).second; - } - - std::pair insert(const pair_type &kv) { - auto result = add(kv.first); - if (result.second) { - *(result.first) = kv.second; - return std::make_pair(std::prev(end()), true); - } - return std::make_pair(result.first, false); - } - - std::pair insert(pair_type &&kv) { - iterator empty_itr; - std::pair map_pair = std::make_pair(kv.first, empty_itr); - std::pair result = map_data_.insert(map_pair); - auto &seq_itr = result.first->second; - if (result.second) { - auto it = sequential_data_.insert(sequential_data_.end(), std::move(kv)); - seq_itr = it; - return std::make_pair(std::prev(end()), true); - } - return std::make_pair(seq_itr, false); - } - - bool empty() const { return sequential_data_.empty(); } - - size_type size() const { return sequential_data_.size(); } - - size_type count(const key_t &key) const { - auto pos = map_data_.find(key); - return pos == map_data_.end() ? 0 : 1; - } - - iterator find(const key_t &key) { - typename map_type::const_iterator pos = map_data_.find(key); - return pos == map_data_.end() ? sequential_data_.end() : (pos->second); - } - - const_iterator find(const key_t &key) const { - auto pos = map_data_.find(key); - return pos == map_data_.end() ? sequential_data_.end() : (pos->second); - } - - // Remove the last element from the sequential_data_. - void pop_back() { - typename map_type::iterator pos = map_data_.find(sequential_data_.back().first); - map_data_.erase(pos); - sequential_data_.pop_back(); - } - - // Remove the first element from the sequential_data_. - void pop_front() { - typename map_type::iterator pos = map_data_.find(sequential_data_.first().first); - map_data_.erase(pos); - sequential_data_.pop_front(); - } - - // Remove the element given by Iterator. - typename sequential_type::iterator erase(const typename sequential_type::iterator &itr) { - (void)map_data_.erase(itr->first); - auto next = sequential_data_.erase(itr); - if (next == sequential_data_.end()) return next; - return next; - } - - // Remove the element with the given key - size_type erase(const key_t &key) { - auto itr = find(key); - if (itr == end()) return 0; - (void)erase(itr); - return 1; - } - - private: - map_type map_data_; - sequential_type sequential_data_; -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_ORDERED_MAP_H_ diff --git a/mindspore/ccsrc/utils/ordered_set.h b/mindspore/ccsrc/utils/ordered_set.h deleted file mode 100644 index f393ce74f2..0000000000 --- a/mindspore/ccsrc/utils/ordered_set.h +++ /dev/null @@ -1,281 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_ORDERED_SET_H_ -#define MINDSPORE_CCSRC_UTILS_ORDERED_SET_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "utils/log_adapter.h" - -namespace mindspore { - -// Implementation of OrderedSet that keeps insertion order -// using map as set, and use list as a sequential container to record elements to keep insertion order -template , class KeyEqual = std::equal_to> -class OrderedSet { - public: - using element_type = T; - using hasher = Hash; - using equal = KeyEqual; - using sequential_type = std::list; - using vector_type = std::vector; - using iterator = typename sequential_type::iterator; - using const_iterator = typename sequential_type::const_iterator; - using reverse_iterator = typename sequential_type::reverse_iterator; - using const_reverse_iterator = typename sequential_type::const_reverse_iterator; - using map_type = std::unordered_map; - using ordered_set_type = OrderedSet; - - OrderedSet() = default; - ~OrderedSet() = default; - // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, - // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use - // traversal to build elements. - OrderedSet(const OrderedSet &os) { - for (auto &item : os.ordered_data_) { - add(item); - } - } - - explicit OrderedSet(const sequential_type &other) { - for (auto &item : other) { - add(item); - } - } - - // Explicitly construct an OrderedSet use vector - explicit OrderedSet(const vector_type &other) { - for (auto &item : other) { - add(item); - } - } - - OrderedSet &operator=(const OrderedSet &os) { - if (this != &os) { - for (auto &item : os.ordered_data_) { - add(item); - } - } - return *this; - } - - // Add an element to the OrderedSet, without judging return value - void add(const element_type &e) { (void)insert(e); } - - // insert an element to the OrderedSet - std::pair insert(const element_type &e) { - iterator empty_itr; - std::pair map_pair = std::make_pair(e, empty_itr); - auto result = mapped_data_.insert(map_pair); - auto &seq_idx = result.first->second; - // if insert success; - if (result.second) { - auto it = ordered_data_.insert(ordered_data_.end(), e); - seq_idx = it; - } - return std::pair(seq_idx, result.second); - } - - // Remove an element, if removed return true, otherwise return false - bool erase(const element_type &e) { - auto pos = mapped_data_.find(e); - if (pos == mapped_data_.end()) { - return false; - } - // erase the sequential data first - (void)ordered_data_.erase(pos->second); - (void)mapped_data_.erase(pos); - return true; - } - - // Return the container size - std::size_t size() const { return mapped_data_.size(); } - - bool empty() const { return mapped_data_.size() == 0; } - - // Return the string contents in orderset, using ordered_data - std::string toString() { - std::ostringstream res; - res << "orderset content:\n"; - for (auto &item : ordered_data_) { - res << std::to_string(reinterpret_cast(item.get())) << " "; - } - return res.str(); - } - - // Clear the elements - void clear() { - mapped_data_.clear(); - ordered_data_.clear(); - } - - // Compare two orderedset, if the order is not equal shall return false - bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } - - // Remove and return the first element in the OrderedSet - T pop() { - if (ordered_data_.size() != 0) { - T res = ordered_data_.front(); - (void)mapped_data_.erase(res); - (void)ordered_data_.erase(ordered_data_.begin()); - return res; - } - MS_LOG(EXCEPTION) << "pop() on empty OrderedSet"; - } - - T back() { - if (ordered_data_.size() != 0) { - return ordered_data_.back(); - } - MS_LOG(EXCEPTION) << "back() on empty OrderedSet"; - } - - // Return true if there are no common elements - bool is_disjoint(const OrderedSet &other) { - for (auto &item : other.ordered_data_) { - if (mapped_data_.find(item) != mapped_data_.end()) { - return false; - } - } - return true; - } - - // Test whether this is subset of other - bool is_subset(const OrderedSet &other) { - for (auto &item : ordered_data_) { - if (other.mapped_data_.find(item) == other.mapped_data_.end()) { - return false; - } - } - return true; - } - - // Add elements in other to this orderedset - void update(const OrderedSet &other) { - for (auto &item : other.ordered_data_) { - add(item); - } - } - - void update(const std::shared_ptr &other) { update(*other); } - - void update(const sequential_type &other) { - for (auto &item : other) { - add(item); - } - } - - void update(const vector_type &other) { - for (auto &item : other) { - add(item); - } - } - - ordered_set_type get_union(const OrderedSet &other) { - ordered_set_type res(ordered_data_); - res.update(other); - return res; - } - - // Get the union with other set, this operator may cost time because of copy - ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } - - // Return the intersection of two sets - ordered_set_type intersection(const OrderedSet &other) { - ordered_set_type res(ordered_data_); - for (auto &item : ordered_data_) { - if (other.mapped_data_.find(item) == other.mapped_data_.end()) { - (void)res.erase(item); - } - } - return res; - } - ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } - - // Return the symmetric difference of two sets - ordered_set_type symmetric_difference(const OrderedSet &other) { - ordered_set_type res(ordered_data_); - for (auto &item : other.ordered_data_) { - if (mapped_data_.find(item) != mapped_data_.end()) { - (void)res.erase(item); - } else { - res.add(item); - } - } - return res; - } - - ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } - - // Remove elements which is also in others. - void difference_update(const OrderedSet &other) { - // use vector traversal, to keep ordrer - for (auto &item : other.ordered_data_) { - (void)erase(item); - } - } - - void difference_update(const sequential_type &other) { - for (auto &item : other) { - (void)erase(item); - } - } - - void difference_update(const vector_type &other) { - for (auto &item : other) { - (void)erase(item); - } - } - - // Return the set with elements that are not in the others - ordered_set_type difference(const OrderedSet &other) { - ordered_set_type res(ordered_data_); - res.difference_update(other); - return res; - } - ordered_set_type operator-(const OrderedSet &other) { return difference(other); } - - bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } - - // Return the count of an element in set - std::size_t count(const element_type &e) const { return mapped_data_.count(e); } - - iterator begin() { return ordered_data_.begin(); } - iterator end() { return ordered_data_.end(); } - - const_iterator begin() const { return ordered_data_.cbegin(); } - const_iterator end() const { return ordered_data_.cend(); } - - const_iterator cbegin() const { return ordered_data_.cbegin(); } - const_iterator cend() const { return ordered_data_.cend(); } - - private: - map_type mapped_data_; - sequential_type ordered_data_; -}; - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_ORDERED_SET_H_ diff --git a/mindspore/ccsrc/utils/overload.h b/mindspore/ccsrc/utils/overload.h deleted file mode 100644 index a95e285fc7..0000000000 --- a/mindspore/ccsrc/utils/overload.h +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_OVERLOAD_H_ -#define MINDSPORE_CCSRC_UTILS_OVERLOAD_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace mindspore { - -template -std::ostream &operator<<(std::ostream &out, const std::vector &v) { - out << "[const vector]["; - size_t last = v.size() - 1; - for (size_t i = 0; i < v.size(); ++i) { - out << v[i]; - if (i != last) out << ", "; - } - out << "]"; - return out; -} - -template -std::ostream &operator<<(std::ostream &os, const std::list &vec) { - bool begin = true; - os << "[const list]["; - for (auto &item : vec) { - if (!begin) { - os << ", "; - } else { - begin = false; - } - os << item; - } - os << "]"; - - return os; -} - -template -std::ostream &operator<<(std::ostream &os, const std::initializer_list &vec) { - bool begin = true; - os << "["; - for (auto &item : vec) { - if (!begin) { - os << ", "; - } else { - begin = false; - } - os << item; - } - os << "]"; - - return os; -} - -template -bool operator==(const std::initializer_list &lhs, const std::initializer_list &rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - auto lit = lhs.begin(); - auto rit = rhs.begin(); - while (lit != lhs.end()) { - if (!(*lit == *rit)) { - return false; - } - lit++; - rit++; - } - return true; -} - -template -std::ostream &operator<<(std::ostream &os, const std::pair &pair) { - os << "[const pair]"; - - return os; -} - -template -std::ostream &operator<<(std::ostream &os, const std::unordered_map &map) { - os << "[const unordered_map]"; - return os; -} - -template -std::ostream &operator<<(std::ostream &os, const std::map &map) { - os << "[const map]"; - return os; -} - -template -std::string ToString(const std::vector &vec) { - std::ostringstream buffer; - - buffer << vec; - return buffer.str(); -} - -template -std::string ToString(const std::unordered_map &map) { - std::ostringstream buffer; - - buffer << map; - return buffer.str(); -} - -template -std::string ToString(const std::map &map) { - std::ostringstream buffer; - - buffer << map; - return buffer.str(); -} - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_OVERLOAD_H_ diff --git a/mindspore/ccsrc/utils/param_value_py.cc b/mindspore/ccsrc/utils/param_value_py.cc new file mode 100644 index 0000000000..edf2061dc0 --- /dev/null +++ b/mindspore/ccsrc/utils/param_value_py.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ir/param_value.h" +#include "pybind11/pybind11.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +namespace py = pybind11; + +REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) { + (void)py::class_(*m, "ParamValue") + .def(py::init()) + .def("clone", &ParamValue::Clone) + .def_property("name", &ParamValue::name, &ParamValue::set_name) + .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad) + .def_property("layerwise_parallel", &ParamValue::layerwise_parallel, + &ParamValue::set_layerwise_parallel) + .def(py::pickle( + [](const ParamValue &p) { // __getstate__ + return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel()); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 6) { + std::runtime_error("Invalid state for ParamValue!"); + } + ParamValuePtr p = std::make_shared(); + p->set_name(t[1].cast()); + p->set_requires_grad(t[2].cast()); + p->set_layerwise_parallel(t[3].cast()); + return p; + })); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/primitive_py.cc b/mindspore/ccsrc/utils/primitive_py.cc new file mode 100644 index 0000000000..ea15abdce9 --- /dev/null +++ b/mindspore/ccsrc/utils/primitive_py.cc @@ -0,0 +1,228 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/primitive_py.h" +#include +#include "ir/signature.h" +#include "./common.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pybind11/pytypes.h" +#include "utils/convert_utils_base.h" +#include "utils/primitive_utils.h" +#include "utils/base_ref_py.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +namespace { +constexpr auto kBpropAttrName = "bprop"; +constexpr auto kCellHookAttrName = "cell_hook"; +constexpr auto kCellIDAttrName = "cell_id"; +void SyncData(const py::object &arg) { + if (py::isinstance(arg)) { + py::tuple arg_list = py::cast(arg); + for (size_t i = 0; i < arg_list.size(); i++) { + SyncData(arg_list[i]); + } + } + if (py::isinstance(arg)) { + auto tensor = py::cast(arg); + (void)tensor->data_sync(); + } +} +} // namespace +std::map PrimitivePy::hook_grad_; +static ValuePtr PyArgToValue(const py::object &arg) { + if (py::isinstance(arg) && + py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { + return nullptr; + } + return parse::data_converter::PyDataToValue(arg); +} + +void PrimitivePy::set_signatures( + std::vector> signatures) { + signatures_.clear(); + for (auto &signature : signatures) { + auto [name, rw, kind, arg_default, dtype] = signature; + auto default_value = PyArgToValue(arg_default); + signatures_.emplace_back(name, rw, kind, default_value, dtype); + } + set_has_signature(true); +} + +py::function PrimitivePy::GetBpropFunction() { + static const char *const get_bprop_func_name = "get_bprop"; + if (py::hasattr(python_obj_, get_bprop_func_name)) { + py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); + return fn; + } else { + auto fn = GetBpropFunctionByObj(python_obj_); + return fn; + } +} + +BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { + auto py_args = ConvertDatatoPyTuple(args); + py::object obj; + bool is_bprop = this->HasAttr(kBpropAttrName); + if (is_bprop) { + SyncData(py_args); + obj = hook_(*py_args); + return std::make_shared(obj); + } + SyncData(py_args[2]); + bool is_cell = this->HasAttr(kCellHookAttrName); + if (is_cell) { + auto cell_id = GetValue(this->GetAttr(kCellIDAttrName)); + auto iter = hook_grad_.find(cell_id); + if (iter != hook_grad_.end()) { + auto hook_args = py::tuple(3); + hook_args[0] = cell_id; + hook_args[1] = py::make_tuple(iter->second); + hook_args[2] = py::make_tuple(py_args[2]); + obj = hook_(*hook_args); + if (py::isinstance(obj)) { + obj = py_args[2]; + } + hook_grad_.erase(cell_id); + } else { + hook_grad_[cell_id] = py_args[2]; + obj = py_args[2]; + } + } else { + // Hook operator for execute variable hook function + obj = hook_(py::make_tuple(py_args[2])); + if (py::isinstance(obj)) { + obj = py_args[2]; + } + } + obj = py::make_tuple(obj); + return std::make_shared(obj); +} + +py::function PrimitivePy::GetComputeFunction() const { + static const char *const compute_func_name = "vm_impl"; + + if (py::hasattr(python_obj_, compute_func_name)) { + MS_LOG(INFO) << name() << " compute_func_name"; + py::function fn = python_obj_.attr(compute_func_name).cast(); + return fn; + } + + static const std::string vm_module = "mindspore.ops.vm_impl_registry"; + static const std::string get_vm_impl_fn = "get_vm_impl_fn"; + MS_LOG(INFO) << name() << ": get_vm_impl_fn"; + py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); + py::function vm_fn = get_fn(python_obj_); + if (py::isinstance(vm_fn)) { + MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); + vm_fn = mindspore::GetComputeFunction(Primitive::name()); + } + return vm_fn; +} + +void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { + std::string attr_name = name; + ValuePtr converted_ret = nullptr; + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; + } + bool converted = parse::ConvertData(obj, &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); + } + (void)this->AddAttr(attr_name, converted_ret); +} + +py::dict PrimitivePy::GetAttrDict() { + py::dict attr_dict; + for (auto &attr : attrs_) { + attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); + } + return attr_dict; +} + +void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(primitive); + if (!primitive->isa()) { + MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!"; + } + auto primitive_py = primitive->cast(); + MS_EXCEPTION_IF_NULL(primitive_py); + this->set_hook(primitive_py->hook()); +} + +BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const { + auto py_args = ConvertDatatoPyTuple(args); + auto result = this->RunPyComputeFunction(py_args); + if (py::isinstance(result)) { + return std::make_shared(nullptr); + } + return std::make_shared(result); +} + +py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const { + auto func = this->GetComputeFunction(); + if (py::isinstance(func)) { + return py::none(); + } + auto result = func(*py_args); + return result; +} + +bool PrimitivePy::HasComputeFunction() const { + auto func = GetComputeFunction(); + if (py::isinstance(func)) { + return false; + } + return true; +} + +PrimitivePtr PrimitivePy::Clone() { + auto clone_fn = python_obj_.attr("_clone"); + py::object new_obj = clone_fn(); + auto cloned_prim = new_obj.cast(); + return cloned_prim; +} + +py::dict PrimitivePy::RunInfer(const py::tuple &args) { + if (!HasPyObj()) { + MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty"; + } + auto infer_fuc = python_obj_.attr("__infer__"); + return infer_fuc(*args); +} + +REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { + (void)py::enum_(*m, "prim_type", py::arithmetic()) + .value("unknown", PrimType::kPrimTypeUnknown) + .value("builtin", PrimType::kPrimTypeBuiltIn) + .value("py_infer_shape", PrimType::kPrimTypePyInferShape) + .value("user_custom", PrimType::kPrimTypeUserCustom); + (void)py::class_>(*m, "Primitive_") + .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) + .def(py::init()) + .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") + .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") + .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") + .def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.") + .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") + .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") + .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/primitive_py.h b/mindspore/ccsrc/utils/primitive_py.h new file mode 100644 index 0000000000..f519b1a080 --- /dev/null +++ b/mindspore/ccsrc/utils/primitive_py.h @@ -0,0 +1,80 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ +#define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ + +#include +#include +#include +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "utils/misc.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" +#include "ir/primitive.h" +#include "ir/signature.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace py = pybind11; +namespace mindspore { +class PrimitivePy : public Primitive { + public: + PrimitivePy(const py::str &name, const py::object &python_obj) + : Primitive(name, false), python_obj_(python_obj), signatures_() {} + ~PrimitivePy() override = default; + MS_DECLARE_PARENT(PrimitivePy, Primitive); + py::function GetBpropFunction(); + + void set_signatures( + std::vector> + signatures); + + const std::vector &signatures() const { return signatures_; } + + void CopyHookFunction(const PrimitivePtr &primitive) override; + + void AddPyAttr(const py::str &name, const py::object &obj); + + py::dict GetAttrDict(); + void set_hook(const py::function &hook) { hook_ = hook; } + py::function hook() const { return hook_; } + BaseRef RunHookFunction(const VectorRef &args) const override; + BaseRef RunComputeFunction(const VectorRef &args) const override; + py::object RunPyComputeFunction(const py::tuple &py_args) const; + bool HasComputeFunction() const; + const bool parse_info_ = true; + const py::object &GetPyObj() const { return python_obj_; } + py::dict RunInfer(const py::tuple &args); + bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); } + bool HasPyObj() { return python_obj_ != nullptr; } + PrimitivePtr Clone() override; + bool is_tuple_input_ = false; + + private: + py::function GetComputeFunction() const; + py::object python_obj_; + py::function hook_; + std::vector signatures_; + static std::map hook_grad_; +}; + +using PrimitivePyPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc index 490e2517a9..956041d8fb 100644 --- a/mindspore/ccsrc/utils/primitive_utils.cc +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -15,9 +15,12 @@ */ #include "utils/primitive_utils.h" + +#include + #include "pipeline/jit/parse/python_adapter.h" #include "utils/log_adapter.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { py::function GetBpropFunctionByObj(py::object obj) { @@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) { py::object fn = mod.attr(common::SafeCStr(name)); return fn; } + +py::tuple ConvertDatatoPyTuple(const VectorRef &args) { + auto py_args = py::tuple(args.size()); + size_t i = 0; + for (auto &arg : args) { + py_args[i] = BaseRefToPyData(arg); + MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString(); + i++; + } + return py_args; +} + +BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) { + auto func = GetComputeFunction(prim->name()); + if (py::isinstance(func)) { + MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented"; + } + auto py_args = ConvertDatatoPyTuple(args); + py::object obj = func(*py_args); + return std::make_shared(obj); +} } // namespace mindspore diff --git a/mindspore/ccsrc/utils/primitive_utils.h b/mindspore/ccsrc/utils/primitive_utils.h index b7e2515aea..cb23e535b0 100644 --- a/mindspore/ccsrc/utils/primitive_utils.h +++ b/mindspore/ccsrc/utils/primitive_utils.h @@ -19,6 +19,8 @@ #include #include "pybind11/pybind11.h" +#include "base/base_ref.h" +#include "utils/convert_utils.h" namespace py = pybind11; @@ -28,6 +30,10 @@ py::function GetBpropFunctionByObj(py::object obj); py::function GetBpropFunction(std::string name); py::function GetComputeFunction(std::string name); + +BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args); + +py::tuple ConvertDatatoPyTuple(const VectorRef &args); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ diff --git a/mindspore/ccsrc/utils/profile.cc b/mindspore/ccsrc/utils/profile.cc deleted file mode 100644 index 9fb9dc9f1a..0000000000 --- a/mindspore/ccsrc/utils/profile.cc +++ /dev/null @@ -1,366 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/profile.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "utils/log_adapter.h" - -namespace mindspore { - -namespace { -constexpr size_t TIME_INFO_PREFIX_NUM_LEN = 4; -const char KEY_PROF_TOTAL[] = "__total__"; - -void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent = 0, - std::map *sums = nullptr, const std::string &prefix = ""); - -void PrintTimeInfoMap(std::ostringstream &oss, const TimeInfoMap &dict, int indent = 0, - std::map *sums = nullptr, const std::string &prefix = "") { - size_t count = 0; - for (const auto &iter : dict) { - count++; - if (iter.second == nullptr) { - continue; - } - // indent by multiples of 4 spaces. - if (iter.first.size() < TIME_INFO_PREFIX_NUM_LEN) { - MS_LOG(EXCEPTION) << "In TimeInfoMap, the " << count << "th string key is " << iter.first - << ", but the length is less than " << TIME_INFO_PREFIX_NUM_LEN; - } - auto name = iter.first.substr(TIME_INFO_PREFIX_NUM_LEN); - oss << std::setw(indent * 4) << "" - << "[" << name << "]: " << iter.second->time_; - if (iter.second->dict_ != nullptr) { - oss << ", [" << iter.second->dict_->size() << "]"; - } - oss << "\n"; - - std::string newPrefix = prefix; - if (iter.first.find("Cycle ") == std::string::npos) { - newPrefix = prefix.empty() ? iter.first : prefix + "." + iter.first; - } - PrintProfile(oss, *iter.second, indent + 1, sums, newPrefix); - if (iter.second->dict_ == nullptr) { - (*sums)[newPrefix] += iter.second->time_; - } - } -} - -void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent, std::map *sums, - const std::string &prefix) { - bool need_free = false; - if (sums == nullptr) { - sums = new (std::nothrow) std::map(); - if (sums == nullptr) { - MS_LOG(ERROR) << "memory allocation failed"; - return; - } - need_free = true; - } - - // indent by multiples of 4 spaces. - if (indent == 0) { - oss << "TotalTime = " << time_info.time_; - if (time_info.dict_ != nullptr) { - oss << ", [" << time_info.dict_->size() << "]"; - } - oss << "\n"; - } - - if (time_info.dict_ != nullptr) { - PrintTimeInfoMap(oss, *time_info.dict_, indent, sums, prefix); - } - - // print time percentage info - if (need_free) { - double total = 0.0; - for (auto iter = sums->begin(); iter != sums->end(); ++iter) { - total += iter->second; - } - oss << "Sums\n"; - if (total >= 0.0 + DBL_EPSILON) { - for (auto &iter : *sums) { - std::string name = iter.first; - name.erase(0, TIME_INFO_PREFIX_NUM_LEN); - std::size_t pos = 0; - while ((pos = name.find('.', pos)) != std::string::npos) { - pos++; - name.erase(pos, TIME_INFO_PREFIX_NUM_LEN); - } - oss << " " << std::left << std::setw(36) << name << " : " << std::right << std::setw(12) << std::fixed - << std::setprecision(6) << iter.second << "s : " << std::right << std::setw(5) << std::fixed - << std::setprecision(2) << iter.second / total * 100 << "%\n"; - } - } - delete sums; - } -} -} // namespace - -double GetTime(void) { - struct timeval tv = {0, 0}; - (void)gettimeofday(&tv, nullptr); - return tv.tv_sec + tv.tv_usec * 1.0e-6; -} - -TimeInfo::~TimeInfo() { - if (dict_ == nullptr) { - return; - } - for (auto iter = dict_->begin(); iter != dict_->end(); ++iter) { - delete iter->second; - iter->second = nullptr; - } - delete dict_; - dict_ = nullptr; -} - -ProfileBase::ProfileBase() : context_("", this) { - ctx_ptr_ = &context_; - context_.parent_ = nullptr; -} - -ProfileBase::~ProfileBase() { - context_.parent_ = nullptr; - if (context_.time_info_ != nullptr) { - delete context_.time_info_; - context_.time_info_ = nullptr; - } - ctx_ptr_ = nullptr; -} - -void Profile::Print(void) { - if (ctx_ptr_ == nullptr || ctx_ptr_->time_info_ == nullptr) { - return; - } - std::ostringstream oss; - PrintProfile(oss, *ctx_ptr_->time_info_); - std::string text = oss.str(); - // here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace - (void)printf("%s", text.c_str()); - (void)fflush(stdout); -} - -// Start a step in the current context with the given name. -// Nomes must be unique otherwise the previous record will be overwritten. -ProfContext *Profile::Step(const std::string &name) { - ctx_ptr_ = new (std::nothrow) ProfContext(name, this); - if (ctx_ptr_ == nullptr) { - MS_LOG(ERROR) << "memory allocation failed"; - return nullptr; - } - return ctx_ptr_; -} - -// Creates subcontext for a repeated action. -// Count should be monotonically increasing. -ProfContext *Profile::Lap(int count) { - std::ostringstream oss; - oss << "Cycle " << count; - ctx_ptr_ = new (std::nothrow) ProfContext(oss.str(), this); - if (ctx_ptr_ == nullptr) { - MS_LOG(ERROR) << "memory allocation failed"; - return nullptr; - } - return ctx_ptr_; -} - -void Profile::Pop(void) noexcept { - if (ctx_ptr_ == nullptr) { - return; - } - ctx_ptr_ = ctx_ptr_->parent_; -} - -ProfContext::ProfContext(const std::string &name, ProfileBase *const prof) : name_(name), prof_(prof) { - // Initialize a subcontext. - time_info_ = nullptr; - if (prof == nullptr || IsTopContext()) { - parent_ = nullptr; - } else { - parent_ = prof->ctx_ptr_; - } -} - -ProfContext::~ProfContext() { - // top level context - if (parent_ == nullptr || IsTopContext()) { - if (time_info_ != nullptr) { - delete time_info_; - } - } else { - parent_->Insert(name_, time_info_); - if (prof_ != nullptr) { - prof_->Pop(); - } - } - - time_info_ = nullptr; - prof_ = nullptr; - parent_ = nullptr; -} - -void ProfContext::SetTime(double time) noexcept { - if (time_info_ == nullptr) { - time_info_ = new (std::nothrow) TimeInfo(time); - if (time_info_ == nullptr) { - MS_LOG(ERROR) << "memory allocation failed"; - return; - } - } - time_info_->time_ = time; -} - -void ProfContext::Insert(const std::string &name, const TimeInfo *time) noexcept { - if (time_info_ == nullptr) { - time_info_ = new (std::nothrow) TimeInfo(); - if (time_info_ == nullptr) { - MS_LOG(ERROR) << "memory allocation failed"; - delete time; - time = nullptr; - return; - } - } - - if (time_info_->dict_ == nullptr) { - time_info_->dict_ = new (std::nothrow) TimeInfoMap(); - if (time_info_->dict_ == nullptr) { - MS_LOG(ERROR) << "memory allocation failed"; - delete time; - time = nullptr; - delete time_info_; - time_info_ = nullptr; - return; - } - } - - std::stringstream ss; - ss << std::setw(TIME_INFO_PREFIX_NUM_LEN) << std::setfill('0') << time_info_->actionNum_; - std::string sorted_name(ss.str() + name); - time_info_->actionNum_++; - auto iter = time_info_->dict_->find(sorted_name); - // if contains item with same name, delete it - if (iter != time_info_->dict_->end()) { - delete iter->second; - iter->second = nullptr; - (void)time_info_->dict_->erase(iter); - } - (*time_info_->dict_)[sorted_name] = time; -} - -bool ProfContext::IsTopContext() const noexcept { return (prof_ != nullptr) && (this == &prof_->context_); } - -ProfTransaction::ProfTransaction(const ProfileBase *prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } - -ProfTransaction::~ProfTransaction() { - if (ctx_ != nullptr && !ctx_->IsTopContext()) { - delete ctx_; - } - ctx_ = nullptr; -} - -void DumpTime::Record(const std::string &step_name, const double time, const bool is_start) { - file_ss_ << " {" << std::endl; - file_ss_ << " \"name\": " - << "\"" << step_name << "\"," << std::endl; - file_ss_ << " \"cat\": " - << "\"FUNCTION\"," << std::endl; - if (is_start) { - file_ss_ << " \"ph\": " - << "\"B\"," << std::endl; - } else { - file_ss_ << " \"ph\": " - << "\"E\"," << std::endl; - } - file_ss_ << " \"ts\": " << std::setprecision(16) << time * 1000000 << "," << std::endl; - file_ss_ << " \"pid\": " - << "1" << std::endl; - file_ss_ << " }" << std::endl; - file_ss_ << " ," << std::endl; -} - -void DumpTime::Save() { - try { - file_out_.open(file_path_, std::ios::trunc | std::ios::out); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Cannot open file in " << (file_path_); - } - file_out_ << "{\n"; - file_out_ << " \"traceEvents\": [" << std::endl; - file_ss_ >> file_out_.rdbuf(); - (void)file_out_.seekp(-7, std::ios::end); - file_out_ << " ]" << std::endl << " ,\n"; - file_out_ << " \"displayTimeUnit\": \"ms\"" << std::endl; - file_out_ << "}"; - file_out_.close(); -} - -struct TimeInfoGroup { - double total_time = 0.0; - int total_count = 0; - std::list::const_iterator> items; -}; - -static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, const std::string &prefix) { - oss << "------[" << prefix << "] " << std::setw(10) << std::fixed << std::setprecision(6) << group.total_time - << std::setw(6) << group.total_count << "\n"; - for (const auto &iter : group.items) { - oss << std::setw(5) << std::fixed << std::setprecision(2) << iter->second.time_ / group.total_time * 100 - << "% : " << std::setw(12) << std::fixed << std::setprecision(6) << iter->second.time_ << "s : " << std::setw(6) - << iter->second.count_ << ": " << iter->first << "\n"; - } -} - -void MsProfile::Print() { - GetProfile()->Print(); - std::vector items = {"substitution.", "renormalize.", "replace.", "match.", - "func_graph_cloner_run.", "meta_graph.", "manager."}; - std::vector groups(items.size() + 1); - const auto &stat = GetSingleton().time_stat_; - // group all time infos - for (auto iter = stat.cbegin(); iter != stat.cend(); ++iter) { - auto matched_idx = items.size(); - for (size_t i = 0; i < items.size(); ++i) { - if (iter->first.find(items[i]) != std::string::npos) { - matched_idx = i; - break; - } - } - groups[matched_idx].total_time += iter->second.time_; - groups[matched_idx].total_count += iter->second.count_; - groups[matched_idx].items.push_back(iter); - } - std::ostringstream oss; - for (size_t i = 0; i < groups.size(); ++i) { - std::string prefix = (i < items.size() ? items[i] : std::string("others.")); - PrintTimeStat(oss, groups[i], prefix); - } - std::string text = oss.str(); - // here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace - (void)printf("\nTime group info:\n%s", text.c_str()); - (void)fflush(stdout); -} - -} // namespace mindspore diff --git a/mindspore/ccsrc/utils/profile.h b/mindspore/ccsrc/utils/profile.h deleted file mode 100644 index bd3723d5bb..0000000000 --- a/mindspore/ccsrc/utils/profile.h +++ /dev/null @@ -1,234 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_PROFILE_H_ -#define MINDSPORE_CCSRC_UTILS_PROFILE_H_ - -#include -#include -#include -#include -#include -#include "utils/log_adapter.h" - -namespace mindspore { - -struct TimeInfo; -using TimeInfoMap = std::map; - -extern double GetTime(); - -class ProfileBase; - -struct TimeInfo { - explicit TimeInfo(double time = -1.0) : time_(time), dict_(nullptr), actionNum_(0) {} - TimeInfo(const TimeInfo &) = delete; - ~TimeInfo(); - - double time_; - TimeInfoMap *dict_; - size_t actionNum_; -}; - -// Utility class for Profile. -class ProfContext { - friend class Profile; - friend class ProfileBase; - friend class ProfTransaction; - - public: - ProfContext(const std::string &name, ProfileBase *prof); - ~ProfContext(); - - ProfContext(const ProfContext &) = delete; - ProfContext &operator=(const ProfContext &) = delete; - - void SetTime(double time) noexcept; - void Insert(const std::string &name, const TimeInfo *time) noexcept; - bool IsTopContext() const noexcept; - - private: - std::string name_; - ProfileBase *prof_; - ProfContext *parent_; - TimeInfo *time_info_; -}; - -class ProfileBase { - friend class ProfContext; - friend class ProfTransaction; - - public: - ProfileBase(); - virtual ~ProfileBase(); - - virtual void Print(void) {} - virtual ProfContext *Step(const std::string &) { return nullptr; } - virtual ProfContext *Lap(int) { return nullptr; } - virtual void Pop(void) {} - - // top level profile context - ProfContext context_; - // profile context pointer, act as a stack pointer - ProfContext *ctx_ptr_ = nullptr; -}; - -class Profile : public ProfileBase { - public: - Profile() = default; - ~Profile() override = default; - Profile(const Profile &) = delete; - Profile &operator=(const Profile &) = delete; - - void Print(void) override; - ProfContext *Step(const std::string &name) override; - ProfContext *Lap(int count) override; - void Pop(void) noexcept override; -}; - -class ProfTransaction { - public: - explicit ProfTransaction(const ProfileBase *prof); - explicit ProfTransaction(ProfContext *const ctx) : ctx_(ctx) {} - ProfTransaction(const ProfTransaction &) = delete; - ~ProfTransaction(); - - template - void operator-(const Function &func) { - double start_time = GetTime(); - func(); - double end_time = GetTime(); - if (ctx_ != nullptr) { - ctx_->SetTime(end_time - start_time); - } - } - - private: - ProfContext *ctx_ = nullptr; -}; - -class NoProfTransaction { - public: - explicit NoProfTransaction(ProfileBase *prof) {} - explicit NoProfTransaction(ProfContext *ctx) {} - ~NoProfTransaction() = default; - - template - void operator-(const Function &func) { - func(); - } -}; - -class DumpTime { - public: - ~DumpTime() { - try { - Save(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Cannot save file by profile::DumpTime::save"; - } catch (...) { - MS_LOG(ERROR) << "Uncaught exception"; - } - } - DumpTime(const DumpTime &) = delete; - DumpTime &operator=(const DumpTime &) = delete; - static DumpTime &GetInstance() { - static DumpTime instance; - return instance; - } - void set_file_path(const std::string &save_path) { file_path_ = save_path; } - void Record(const std::string &name, const double time, const bool is_start); - void Save(); - - private: - DumpTime() = default; - std::stringstream file_ss_; - std::ofstream file_out_; - std::string file_path_ = "./timeline.json"; -}; - -struct TimeStat { - TimeStat() { - time_ = 0.0; - count_ = 0; - } - ~TimeStat() = default; - - void operator+=(double t) { - time_ += t; - count_ += 1; - } - - TimeStat operator+(double t) { - TimeStat ts = *this; - ts += t; - return ts; - } - - double time_; - int count_; -}; - -class MsProfile { - public: - ~MsProfile() { Clear(); } - - static void Reset() { GetSingleton().Clear(); } - - static ProfileBase *GetProfile() { - MsProfile &ms_prof = GetSingleton(); - if (ms_prof.profile_ == nullptr) { -#ifdef ENABLE_PROFILE - ms_prof.profile_ = new Profile(); -#else - ms_prof.profile_ = new ProfileBase(); -#endif - } - return ms_prof.profile_; - } - static void StatTime(const std::string &id, double time) { GetSingleton().time_stat_[id] += time; } - - static void Print(); - - private: - MsProfile() = default; - - static MsProfile &GetSingleton() { - static MsProfile profile; - return profile; - } - - void Clear() { - time_stat_.clear(); - if (profile_ != nullptr) { - delete profile_; - profile_ = nullptr; - } - } - - std::map time_stat_; // record time and count info from some activity - ProfileBase *profile_ = nullptr; // record hierarchical profile info -}; - -} // namespace mindspore - -#ifdef ENABLE_PROFILE -#define WITH(x) ProfTransaction(x) - -#else -#define WITH(x) NoProfTransaction(x) - -#endif - -#endif // MINDSPORE_CCSRC_UTILS_PROFILE_H_ diff --git a/mindspore/ccsrc/utils/signal.h b/mindspore/ccsrc/utils/signal.h deleted file mode 100644 index 9a43e23814..0000000000 --- a/mindspore/ccsrc/utils/signal.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_SIGNAL_H_ -#define MINDSPORE_CCSRC_UTILS_SIGNAL_H_ - -#include -#include -#include -#include - -namespace mindspore { -template -std::function bind_member(Type *instance, Return (Type::*method)(Args...)) { - return [=](Args &&... args) -> Return { return (instance->*method)(std::forward(args)...); }; -} - -template -class Slot { - public: - explicit Slot(const std::function &callback) : callback(callback) {} - - ~Slot() {} - - std::function callback = nullptr; -}; - -template -class Signal { - public: - template - void operator()(Args &&... args) { - for (auto &slot : slots_) { - if (slot->callback != nullptr) { - slot->callback(std::forward(args)...); - } - } - } - - void add_slot(const std::function &func) { - auto slot = std::make_shared>(func); - slots_.push_back(slot); - } - - // signal connect to a class member func - template - void connect(InstanceType instance, MemberFuncType func) { - add_slot(bind_member(instance, func)); - } - - private: - std::vector>> slots_; -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_EVENT_H_ diff --git a/mindspore/core/ir/signature_py.cc b/mindspore/ccsrc/utils/signature_py.cc similarity index 100% rename from mindspore/core/ir/signature_py.cc rename to mindspore/ccsrc/utils/signature_py.cc diff --git a/mindspore/ccsrc/utils/symbolic.h b/mindspore/ccsrc/utils/symbolic.h deleted file mode 100644 index ca68b2c877..0000000000 --- a/mindspore/ccsrc/utils/symbolic.h +++ /dev/null @@ -1,177 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_ -#define MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_ - -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "abstract/abstract_value.h" -#include "utils/any.h" - -namespace mindspore { - -class SymbolicKeyInstance : public Value { - public: - SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) - : node_(node), abstract_(abstract) {} - ~SymbolicKeyInstance() override = default; - MS_DECLARE_PARENT(SymbolicKeyInstance, Value); - AnfNodePtr node() const { return node_; } - abstract::AbstractBasePtr abstract() const { return abstract_; } - bool operator==(const SymbolicKeyInstance &other) const { - return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); - } - - std::size_t hash() const override { return std::hash{}(node_); } - friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { - if (inst == nullptr) { - os << "[Key][" - << "Invalid symbolic key instance" - << "]"; - } else { - os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString(); - } - return os; - } - std::string ToString() const override { - return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); - } - bool operator==(const Value &other) const override { - if (other.isa()) { - auto other_ = static_cast(other); - return *this == other_; - } else { - return false; - } - } - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), - std::make_shared()); - } - - private: - AnfNodePtr node_; - abstract::AbstractBasePtr abstract_; -}; - -using SymbolicKeyInstancePtr = std::shared_ptr; - -struct SymbolicKeyInstanceHash { - std::size_t operator()(const SymbolicKeyInstancePtr s) const { - if (s == nullptr) { - return 0; - } - return s->abstract()->hash(); - } -}; - -struct SymbolicKeyInstanceEqual { - bool operator()(const SymbolicKeyInstancePtr lhs, const SymbolicKeyInstancePtr rhs) const { - if (lhs == nullptr || rhs == nullptr) { - return false; - } - MS_EXCEPTION_IF_NULL(lhs->node()); - MS_EXCEPTION_IF_NULL(rhs->node()); - MS_EXCEPTION_IF_NULL(lhs->abstract()); - MS_EXCEPTION_IF_NULL(rhs->abstract()); - return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract()); - } -}; - -using EnvInstanceContentsMap = - std::unordered_map; - -// Environment mapping keys to values. -// Keys are SymbolicKeyInstances, which represent nodes in the graph along -// with inferred properties. -class EnvInstance : public Value { - public: - friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); - - explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} - ~EnvInstance() override = default; - MS_DECLARE_PARENT(EnvInstance, Value); - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } - bool operator==(const EnvInstance &other) const; - bool operator==(const Value &other) const override; - EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} - EnvInstance(EnvInstance &&v) = default; - EnvInstance &operator=(EnvInstance &&src) noexcept { - if (&src != this) { - contents_ = src.contents_; - } - return *this; - }; - - // Get the sensitivity list for the given key - const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { - auto iterator = contents_.find(key); - if (iterator != contents_.end()) { - return iterator->second; - } - return def; - } - - // Set a value for the given key. - EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { - EnvInstance rval(contents_); - rval.contents_[key] = value; - return rval; - } - - // Add two EnvInstances. - EnvInstance Add(const EnvInstance &other) const { - EnvInstance rval(contents_); - for (auto iter_other : other.contents_) { - auto item_self = contents_.find(iter_other.first); - if (item_self != contents_.end()) { - MS_LOG(DEBUG) << "Need to use add"; - } else { - rval.contents_[iter_other.first] = iter_other.second; - } - } - return rval; - } - - size_t Len() const { return contents_.size(); } - std::size_t hash() const override { - // deterministic characteristic of member variables. - return Len(); - } - - const bool parse_info_ = true; - - private: - EnvInstanceContentsMap contents_; -}; - -using EnvInstancePtr = std::shared_ptr; - -extern std::shared_ptr newenv; - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_ diff --git a/mindspore/ccsrc/utils/tensor_py.cc b/mindspore/ccsrc/utils/tensor_py.cc new file mode 100644 index 0000000000..606b95527d --- /dev/null +++ b/mindspore/ccsrc/utils/tensor_py.cc @@ -0,0 +1,386 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tensor_py.h" + +#include +#include +#include +#include +#include + +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace tensor { +static TypeId GetDataType(const py::buffer_info &buf) { + if (buf.format.size() == 1) { + switch (buf.format.front()) { + case 'e': + case 'f': + case 'd': + switch (buf.itemsize) { + case 2: + return TypeId::kNumberTypeFloat16; + case 4: + return TypeId::kNumberTypeFloat32; + case 8: + return TypeId::kNumberTypeFloat64; + } + break; + case 'b': + case 'h': + case 'i': + case 'l': + case 'q': + switch (buf.itemsize) { + case 1: + return TypeId::kNumberTypeInt8; + case 2: + return TypeId::kNumberTypeInt16; + case 4: + return TypeId::kNumberTypeInt32; + case 8: + return TypeId::kNumberTypeInt64; + } + break; + case 'B': + case 'H': + case 'I': + case 'L': + case 'Q': + switch (buf.itemsize) { + case 1: + return TypeId::kNumberTypeUInt8; + case 2: + return TypeId::kNumberTypeUInt16; + case 4: + return TypeId::kNumberTypeUInt32; + case 8: + return TypeId::kNumberTypeUInt64; + } + break; + case '?': + return TypeId::kNumberTypeBool; + } + } + MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; + return TypeId::kTypeUnknown; +} + +static std::string GetPyTypeFormat(TypeId data_type) { + switch (data_type) { + case TypeId::kNumberTypeFloat16: + return "e"; + case TypeId::kNumberTypeFloat32: + return py::format_descriptor::format(); + case TypeId::kNumberTypeFloat64: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt8: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt16: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt32: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt64: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt8: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt16: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt32: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt64: + return py::format_descriptor::format(); + case TypeId::kNumberTypeBool: + return py::format_descriptor::format(); + default: + MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; + return ""; + } +} + +static bool IsCContiguous(const py::array &input) { + auto flags = static_cast(input.flags()); + return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0; +} + +TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { + // Get input buffer info. + py::buffer_info buf = input.request(); + // Check data types. + auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown; + auto buf_type = GetDataType(buf); + if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) { + MS_LOG(EXCEPTION) << "Unsupported tensor type!"; + } + // Use buf type as data type if type_ptr not set. + if (data_type == TypeId::kTypeUnknown) { + data_type = buf_type; + } + // Convert input array to C contiguous if need. + std::unique_ptr tmp_buf; + if (!IsCContiguous(input)) { + Py_buffer pybuf; + if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) { + MS_LOG(EXCEPTION) << "Failed to get buffer from the input!"; + } + tmp_buf = std::make_unique(pybuf.len); + if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) { + MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer."; + } + PyBuffer_Release(&pybuf); + buf.ptr = tmp_buf.get(); + } + // Get tensor shape. + std::vector shape(buf.shape.begin(), buf.shape.end()); + if (data_type == buf_type) { + // Use memory copy if input data type is same as the required type. + return std::make_shared(data_type, shape, buf.ptr, buf.size * buf.itemsize); + } + // Create tensor with data type converted. + return std::make_shared(data_type, shape, buf.ptr, buf_type); +} + +static std::vector GetStrides(const std::vector &shape, ssize_t item_size) { + std::vector strides; + strides.reserve(shape.size()); + const auto ndim = shape.size(); + for (size_t i = 0; i < ndim; ++i) { + auto stride = item_size; + for (size_t j = i + 1; j < ndim; ++j) { + stride *= shape[j]; + } + strides.push_back(stride); + } + return strides; +} + +static py::buffer_info GetPyBufferInfo(const Tensor &tensor) { + std::vector shape(tensor.shape().begin(), tensor.shape().end()); + std::vector strides = GetStrides(shape, tensor.data().itemsize()); + return py::buffer_info{ + tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides}; +} + +py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { + auto &shape = tensor.shape(); + py::tuple dims(shape.size()); + for (size_t i = 0; i < dims.size(); ++i) { + dims[i] = py::int_(shape[i]); + } + return dims; +} + +py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { + tensor.data_sync(); + auto info = GetPyBufferInfo(tensor); + py::object self = py::cast(&tensor); + return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); +} + +py::array TensorPy::AsNumpy(const Tensor &tensor) { + auto info = GetPyBufferInfo(tensor); + py::object self = py::cast(&tensor); + return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); +} + +static std::vector GetShapeFromTuple(const py::tuple &tuple) { + std::vector shape; + const size_t size = tuple.size(); + shape.reserve(tuple.size()); + for (size_t i = 0; i < size; ++i) { + shape.push_back(py::int_(tuple[i])); + } + return shape; +} + +REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { + // Define python MetaTensor class. + (void)py::class_>(*m, "MetaTensor") + .def(py::init>(), py::arg("dtype"), py::arg("shape")) + .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") + .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") + .def(py::pickle( + [](const MetaTensor &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(static_cast(t.data_type()), t.shape()); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 2) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + MetaTensor tensor(TypeId(t[0].cast()), t[1].cast>()); + return tensor; + })); + // Define python Tensor class. + // dtype should define before Tensor, because Tensor init depend dtype + (void)py::class_>(*m, "Tensor") + .def(py::init([](const Tensor &tensor) { return std::make_shared(tensor); }), + py::arg("input")) + .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) { + TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown; + if (data_type == kTypeUnknown || tensor.data_type() == data_type) { + return std::make_shared(tensor); + } + return std::make_shared(tensor, data_type); + }), + py::arg("input"), py::arg("dtype")) + .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) { + auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64; + return std::make_shared(data_type, GetShapeFromTuple(shape)); + }), + py::arg("dtype"), py::arg("shape")) + .def(py::init([](const py::array &input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(input, type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::float_ input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::int_ input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::list input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::tuple input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag) + .def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter( + Get the tensor's data type. + + Returns: + type, the data type of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) + >>> data.dtype + Int32 + )mydelimiter") + .def_property_readonly("_shape", TensorPy::GetPyTupleShape, R"mydelimiter( + Get the tensor's shape. + + Returns: + tuple[int], the shape of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((3, 3))) + >>> data.shape() + (3, 3) + )mydelimiter") + .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter( + Convert tensor to numpy.ndarray. + + Returns: + numpy.ndarray. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> array = data.asnumpy() + >>> array + array([[1., 1., 1.], + [1., 1., 1.]]) + )mydelimiter") + .def("size", &Tensor::DataSize, R"mydelimiter( + Get tensor's data size. + + Returns: + int, the size of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.size() + 6 + )mydelimiter") + .def("is_init", &Tensor::is_init, R"mydelimiter( + Get tensor init_flag. + + Returns: + bool, whether the tensor init. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.is_init() + False + )mydelimiter") + .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter( + Set tensor init_flag. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.set_init_flag(True) + )mydelimiter") + .def("dim", &Tensor::DataDim, R"mydelimiter( + Get tensor's data dimension. + + Returns: + int, the dimension of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.dim() + 2 + )mydelimiter") + .def("assign_value", &Tensor::AssignValue, R"mydelimiter( + Assign another tensor value to this. + + Arg: + value (:class:`mindspore.tensor`): The value tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) + >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) + >>> data.assign_value(data2) + >>> data.shape + (2, 2) + )mydelimiter") + .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( + Set the tensor's data type. + + Arg: + dtype (:class:`mindspore.dtype`): The type of output tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) + >>> data.set_dtype(mindspore.int32) + mindspore.int32 + )mydelimiter") + .def("__str__", &Tensor::ToString) + .def("__repr__", &Tensor::ToStringRepr) + .def(py::pickle( + [](const Tensor &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(TensorPy::SyncAsNumpy(t)); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + return TensorPy::MakeTensor(t[0].cast()); + })); + })); +} // namespace tensor +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/tensor_py.h b/mindspore/ccsrc/utils/tensor_py.h new file mode 100644 index 0000000000..0a54530487 --- /dev/null +++ b/mindspore/ccsrc/utils/tensor_py.h @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_TENSOR_PY_H_ +#define MINDSPORE_CCSRC_UTILS_TENSOR_PY_H_ + +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +#include "ir/tensor.h" + +namespace py = pybind11; + +namespace pybind11 { +namespace detail { +// Similar to enums in `pybind11/numpy.h`. Determined by doing: +// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' +constexpr int NPY_FLOAT16 = 23; + +template +struct npy_scalar_caster { + PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); + using Array = array_t; + + bool load(handle src, bool convert) { + // Taken from Eigen casters. Permits either scalar dtype or scalar array. + handle type = dtype::of().attr("type"); + if (!convert && !isinstance(src) && !isinstance(src, type)) return false; + + Array tmp = Array::ensure(src); + if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { + this->value = *tmp.data(); + return true; + } + + return false; + } + + static handle cast(T src, return_value_policy, handle) { + Array tmp({1}); + tmp.mutable_at(0) = src; + tmp.resize({}); + + // You could also just return the array if you want a scalar array. + object scalar = tmp[tuple()]; + return scalar.release(); + } +}; + +template <> +struct npy_format_descriptor { + static constexpr auto name = "float16"; + static pybind11::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); + return reinterpret_borrow(ptr); + } + virtual ~npy_format_descriptor() {} +}; + +template <> +struct type_caster : public npy_scalar_caster { + static constexpr auto name = "float16"; +}; +} // namespace detail +} // namespace pybind11 + +// brief mindspore namespace. +// +// mindspore namespace is the top level namespace of Mindsporeession project. +// Other namespace should be a sub namespace of mindspore namespace in the ME project. +namespace mindspore { +// brief mindspore::tensor namespace +// +// A sub namespace in ME to support tensor related definition. +namespace tensor { +// Tensor python wrapper and adapter class. +class TensorPy { + public: + // brief Create Tensor from a numpy array object. + // + // param input [py::array] Data value of the tensor. + // param data_type [TypeId] Data type of the tensor. + static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr); + + static py::array SyncAsNumpy(const Tensor &tensor); + + static py::array AsNumpy(const Tensor &tensor); + + static py::tuple GetPyTupleShape(const Tensor &tensor); +}; +} // namespace tensor +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_UTILS_TENSOR_PY_H_ diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index 08cd4e4291..d6dbe970c4 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -21,6 +21,8 @@ #include #include #include "ir/tensor.h" +#include "pybind11/pybind11.h" +#include "utils/ms_utils.h" #include "runtime/device/convert_tensor_utils.h" #include "./securec.h" #ifndef NO_DLIB @@ -29,6 +31,7 @@ #include "tdt/data_common.h" #endif +namespace py = pybind11; namespace mindspore { const char kShapeSeperator[] = ","; const char kShapeScalar[] = "[0]"; diff --git a/mindspore/ccsrc/utils/tensorprint_utils.h b/mindspore/ccsrc/utils/tensorprint_utils.h index 4a40862ea3..b150368f71 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.h +++ b/mindspore/ccsrc/utils/tensorprint_utils.h @@ -24,7 +24,7 @@ #include "tdt/tdt_host_interface.h" #include "tdt/data_common.h" #include "proto/print.pb.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #endif namespace mindspore { class TensorPrint { diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 3e82aaff2d..6be32a3df5 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -14,12 +14,13 @@ * limitations under the License. */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_UTILS_UTILS_H_ -#define MINDSPORE_MINDSPORE_CCSRC_UTILS_UTILS_H_ +#ifndef MINDSPORE_CCSRC_UTILS_UTILS_H_ +#define MINDSPORE_CCSRC_UTILS_UTILS_H_ #include #include #include +#include #include #include #include @@ -228,6 +229,7 @@ constexpr auto kAttrLabelSwitchList = "label_switch_list"; constexpr auto kAttrNewAxisMask = "new_axis_mask"; constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names"; +constexpr auto kAttrDatadumpIsMultiop = "_datadump_is_multiop"; constexpr auto kAttrStreamId = "stream_id"; constexpr auto kAttrRecordEvent = "record_event"; constexpr auto kAttrWaitEvent = "wait_event"; @@ -246,6 +248,10 @@ constexpr auto kAttrOffset = "offset"; constexpr auto kAttrPsKey = "ps_key"; constexpr auto kAttrOptimizerType = "optim_type"; constexpr auto kAttrChildGraph = "child_graph"; +constexpr auto kAttrInputNums = "inputNums"; +constexpr auto kAttrT = "T"; +constexpr auto kAttrNum = "num"; +constexpr auto kAttrRankSize = "rank_size"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; @@ -314,7 +320,6 @@ const std::set kOptOperatorSet = { kApplyProximalAdagradOpName, kApplyProximalGradientDescentOpName, kApplyRMSPropOpName, - kPushOpName, kPullOpName, }; @@ -333,5 +338,44 @@ static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { MS_LOG(DEBUG) << "File `" << file_name << "` change mode failed! May be not exist."; } } + +static inline uint64_t GetCurrentUSec() { + struct timeval tv; + int ret = gettimeofday(&tv, nullptr); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Fail gettimeofday, ret = " << ret; + } + return static_cast(tv.tv_usec + tv.tv_sec * 1000000); +} + +#define PROF_START(stage) uint64_t start_usec_##stage = mindspore::GetCurrentUSec() +#define PROF_END(stage) \ + do { \ + uint64_t end_usec_##stage = mindspore::GetCurrentUSec(); \ + MS_LOG(INFO) << #stage << " costs " << (end_usec_##stage - start_usec_##stage) << " usec."; \ + } while (0) + +#define PROF_MULTI_DEFINE(stage) \ + static uint64_t total_##stage = 0; \ + static uint64_t count_##stage = 0; + +#define PROF_LOCAL_DEFINE(stage) \ + uint64_t total_##stage = 0; \ + uint64_t count_##stage = 0; + +#define PROF_MULTI_START(stage) uint64_t start_usec_##stage = mindspore::GetCurrentUSec() + +#define PROF_MULTI_END(stage) \ + do { \ + ++count_##stage; \ + uint64_t end_usec_##stage = mindspore::GetCurrentUSec(); \ + total_##stage += (end_usec_##stage - start_usec_##stage); \ + } while (0) + +#define PROF_MULTI_PRINT(stage) \ + do { \ + MS_LOG(INFO) << #stage << " called " << count_##stage << " times, costs " << total_##stage << " usec."; \ + } while (0) + } // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_UTILS_UTILS_H_ +#endif // MINDSPORE_CCSRC_UTILS_UTILS_H_ diff --git a/mindspore/core/ir/value_py.cc b/mindspore/ccsrc/utils/value_py.cc similarity index 100% rename from mindspore/core/ir/value_py.cc rename to mindspore/ccsrc/utils/value_py.cc diff --git a/mindspore/ccsrc/utils/visible.h b/mindspore/ccsrc/utils/visible.h deleted file mode 100644 index 96395230f9..0000000000 --- a/mindspore/ccsrc/utils/visible.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_VISIBLE_H_ -#define MINDSPORE_CCSRC_UTILS_VISIBLE_H_ - -namespace mindspore { -// refer to https://gcc.gnu.org/wiki/Visibility -#if defined _WIN32 || defined __CYGWIN__ -#ifdef BUILDING_DLL -#ifdef __GNUC__ -#define MS_EXPORT __attribute__((dllexport)) -#else -#define MS_EXPORT __declspec(dllexport) // Note: actually gcc seems to also supports this syntax. -#endif -#else -#ifdef __GNUC__ -#define MS_EXPORT __attribute__((dllimport)) -#else -#define MS_EXPORT __declspec(dllimport) // Note: actually gcc seems to also supports this syntax. -#endif -#endif -#define MS_LOCAL -#else -#define MS_EXPORT __attribute__((visibility("default"))) -#define MS_LOCAL __attribute__((visibility("hidden"))) -#endif - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_UTILS_VISIBLE_H_ diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 0290ee57fc..42af5541f2 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -21,10 +21,9 @@ #include "utils/log_adapter.h" #include "ir/anf.h" #include "utils/callbacks.h" -#include "utils/graph_utils.h" #include "utils/base_ref_extends.h" #include "backend/session/session_factory.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #ifdef ENABLE_GE #include "utils/callbacks_ge.h" #endif @@ -34,19 +33,6 @@ namespace compile { bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast(c), value); } -LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { - // multi_graph merge to one, big graph have paramters in begin and only have one output - MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size(); - multi_result_.inputs = g->parameters(); - final_output_ = NewValueNode("fake_output"); - multi_result_.outputs = {final_output_}; - GraphId final_g = target_sess_->GetFinalRunGraph(); - - multi_result_.run = std::make_shared( - [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); }); - return multi_result_; -} - LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { MS_LOG(DEBUG) << "MsConvert"; MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); @@ -96,149 +82,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri return result; } -void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) { - GraphId active_g = simu_cond_map_[c].cond_graph_map[cond]; - - GraphId cond_g = kInvalidGraphId; - if (utils::isa(c)) { - cond_g = target_sess_->GetGraphIdByNode(utils::cast(c)); - } else { - MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString(); - } - auto before_cond = curr_switch_; - if (curr_switch_.hash() != c.hash()) { - // invoke while false->before true call - if (simu_cond_map_[before_cond].cond_graph_map.count(false)) { - active_g = simu_cond_map_[before_cond].cond_graph_map[false]; - } else { - active_g = kInvalidGraphId; - } - // while x < y: - // z = y + 1 - // while z < c2: - // out = out + 1 - // z = z + 1 - if (active_g == cond_g) { - active_g = kInvalidGraphId; - simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId; - } - MS_LOG(DEBUG) << "invoke set active:" << active_g; - } - MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g; - target_sess_->SetActive(active_g, cond_g); -} - -void MsBackend::SetSwitchGraph() { - MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString(); - - if (is_switch_call_) { - GraphId false_g = kInvalidGraphId; - GraphId true_g = kInvalidGraphId; - MS_LOG(DEBUG) << "start SetSwitchGraph"; - true_g = simu_cond_map_[curr_switch_].cond_graph_map[true]; - bool curr_cond = simu_cond_map_[curr_switch_].curr_cond; - if (!curr_cond) { - if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) { - // has false branch - false_g = simu_cond_map_[curr_switch_].cond_graph_map[false]; - } - GraphId cond_g = kInvalidGraphId; - if (utils::isa(curr_switch_)) { - cond_g = target_sess_->GetGraphIdByNode(utils::cast(curr_switch_)); - } else { - MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); - } - MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; - target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast(curr_switch_)); - } - is_switch_call_ = false; - MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; - } -} - -// convert node from formal parameter to actual parameter, -// and actual parameter is graph user's formal parameter. -// get top while graph's parameter in recall while. -AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - std::unordered_map params_index; - auto result = node; - auto graph = result->func_graph(); - while (func_graph != graph) { - auto iter = graph_user_inputs_.find(graph); - if (iter == graph_user_inputs_.end()) { - break; - } - - params_index.clear(); - auto ¶ms = graph->parameters(); - for (size_t i = 0; i < params.size(); ++i) { - params_index[params[i]] = i; - } - - graph = iter->second.first; - auto &inputs = iter->second.second; - result = inputs[params_index[result]]; - } - return result; -} - -void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user, - const AnfNodePtrList &inputs) { - if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) { - return; - } - graph_user_inputs_[func_graph] = {user, inputs}; -} - -void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) { - std::unordered_map params_index; - auto ¶ms = func_graph->parameters(); - for (size_t i = 0; i < params.size(); ++i) { - params_index[params[i]] = i; - } - - // recall all child graphs in this while - auto &graph_inputs = graph_inputs_[c]; - for (auto &iter : graph_inputs) { - auto &graph = iter.first; - auto &old_args = iter.second; - auto &result = graph_id_map_[graph]; - auto &inputs = result.inputs; - for (size_t i = 0; i < inputs.size(); ++i) { - auto input = ConvertGraphInput(func_graph, inputs[i]); - auto it = params_index.find(input); - if (it != params_index.end()) { - old_args[i] = args[it->second]; - } - } - target_sess_->SetChildGraphInput(graph, old_args); - } - graph_inputs_.erase(c); -} - // compile set input output VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { MS_LOG(DEBUG) << "set graph input:" << g; - // switch maybe twice - target_sess_->SetChildGraphInput(g, args); - - if (is_switch_call_) { - if (!curr_switch_.is_null()) { - // push this {g, args} to all user while graph_inputs for nest while, - // when current condition recall over delete this cond in graph_inputs. - for (auto &iter : graph_inputs_) { - iter.second.push_back({g, args}); - } - if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) { - graph_inputs_[curr_switch_].push_back({g, args}); - } - } - bool curr_cond = simu_cond_map_[curr_switch_].curr_cond; - MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g; - simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g; - SetSwitchGraph(); - } - std::vector outputs; (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs), [](const AnfNodePtr &v) { return v; }); @@ -290,36 +136,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s return outputs; } -SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) { - MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size(); - - CondGraph cond_graph; - cond_graph.curr_cond = value; - if (simu_cond_map_.find(c) == simu_cond_map_.end()) { - simu_cond_map_[c] = cond_graph; - } - - if (simu_cond_map_[c].cond_graph_map.count(value)) { - return kCondAlreadyRun; - } - simu_cond_map_[c].curr_cond = value; - MS_LOG(DEBUG) << "end set cond "; - return kCondOk; -} - -void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) { - MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size(); - std::vector args; - auto parameters = root->parameters(); - (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args), - [](const AnfNodePtr &v) { return v; }); - MS_LOG(DEBUG) << "Simulate start"; - (void)target_sess_->SetFinalGraphInput(parameters); - BaseRef output = rt->Eval(VectorRef(args)); - target_sess_->SetFinalGraphOutput(output); - MS_LOG(DEBUG) << "Simulate Eval end"; -} - void MsBackend::Link(GraphId graph_id) { if (graph_id == kInvalidGraphId) { graph_id = target_sess_->GetFinalRunGraph(); @@ -330,9 +146,6 @@ void MsBackend::Link(GraphId graph_id) { Backend::Backend(const std::string &name) : name_(name) { MS_LOG(DEBUG) << "select backend:" << name; convert_fn_ = backends[name_]; - is_switch_call_ = false; - is_multi_graph_sink_ = false; - simu_flag_ = false; } MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 208c4010fb..1bb7c2e406 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -43,50 +43,19 @@ class Backend { LinkFuncType convert_fn() { return convert_fn_; } std::string name() { return name_; } - virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} - virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } virtual bool GetCond(const BaseRef &c, bool *value); virtual bool GetIndex(const BaseRef &c, int *value); - virtual void SetSwitchGraph() {} - virtual void SetSwitchActive(const BaseRef &, bool) {} - virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} - virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} virtual GraphId CompileGraph(NotNull fg) { return kInvalidGraphId; } - void set_curr_switch(const BaseRef &value) { - curr_switch_ = value; - is_switch_call_ = true; - } - - BaseRef curr_switch() { return curr_switch_; } virtual void Link(GraphId) {} - virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); } + virtual void SetDebugger() {} - LinConvertResult multi_result() { return multi_result_; } - void set_multi_result(const LinConvertResult &value) { multi_result_ = value; } - AnfNodePtr final_output() const { return final_output_; } bool is_multi_graph_sink() const { return is_multi_graph_sink_; } void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } - bool simu_flag() const { return simu_flag_; } - bool is_switch_call() const { return is_switch_call_; } - void set_simu_flag(bool simu) { simu_flag_ = simu; } - - virtual void SetDebugger() {} protected: std::string name_; LinkFuncType convert_fn_; - BaseRef curr_switch_; // curr switch node bool is_multi_graph_sink_; - bool is_switch_call_; - bool simu_flag_; - LinConvertResult multi_result_; - AnfNodePtr final_output_; - std::unordered_map> graph_user_inputs_; -}; - -struct CondGraph { - bool curr_cond; - std::unordered_map cond_graph_map; }; class MsBackend : public Backend { @@ -98,16 +67,7 @@ class MsBackend : public Backend { VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); - void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override; - SwitchCondStatus SetSimuCond(const BaseRef &c, bool value) override; - - void SetSwitchGraph() override; - void SetSwitchActive(const BaseRef &c, bool cond) override; - void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override; - void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override; void Link(GraphId) override; - AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); - LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; GraphId CompileGraph(NotNull fg) override; VectorRef RunGraph(GraphId graph_id, const VectorRef &args); void CreateOtherSession(const std::string &target); @@ -121,9 +81,7 @@ class MsBackend : public Backend { session::SessionPtr other_sess_; std::string target_device_; std::string other_device_; - std::unordered_map simu_cond_map_; std::unordered_map graph_id_map_; - std::unordered_map>, BaseRefHash> graph_inputs_; }; } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 540b77bcaf..141adc1bff 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -79,6 +79,42 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, c return output; } +namespace { +AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr, + AnfNodePtrToAnfNodePtrMap *eqv_ptr) { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(inputs_ptr); + MS_EXCEPTION_IF_NULL(eqv_ptr); + MS_EXCEPTION_IF_NULL(node); + auto &inputs = *inputs_ptr; + auto &eqv = *eqv_ptr; + if (node->isa() && !IsValueNode(node)) { + eqv[node] = node; + } else if (eqv.find(node) == eqv.end()) { + bool ignore_make_tuple = false; + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + ignore_make_tuple = true; + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &node_inputs = cnode->inputs(); + for (size_t i = 1; i < node_inputs.size(); ++i) { + if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) { + ignore_make_tuple = false; + break; + } + } + } + if (!ignore_make_tuple) { + inputs.push_back(node); + } + eqv[node] = fg->add_parameter(); + eqv[node]->set_abstract(node->abstract()); + eqv[node]->set_kernel_info(node->kernel_info_ptr()); + } + return eqv[node]; +} +} // namespace + std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { auto fg = std::make_shared(); AnfNodePtrList inputs; @@ -86,17 +122,6 @@ std::tuple TransformSegmentToAnfGr if (lst.empty()) { MS_LOG(EXCEPTION) << "Input anf node list is empty"; } - auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { - if (a->isa() && !IsValueNode(a)) { - eqv[a] = a; - } else if (eqv.find(a) == eqv.end()) { - inputs.push_back(a); - eqv[a] = fg->add_parameter(); - eqv[a]->set_abstract(a->abstract()); - eqv[a]->set_kernel_info(a->kernel_info_ptr()); - } - return eqv[a]; - }; // Merge CNodes into a AnfGraph that represents a linear instruction segment for (auto n : lst) { if (!n->isa()) { @@ -117,8 +142,17 @@ std::tuple TransformSegmentToAnfGr eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { args.emplace_back(inps[kRealInputIndexInDepend]); args.emplace_back(inps[kRealInputIndexInDepend]); + } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { + for (size_t i = 1; i < inps.size(); ++i) { + if (inps[i]->isa() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { + args.emplace_back(NewValueNode(MakeValue(i))); + } else { + args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv)); + } + } } else { - (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); + (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), + [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); } eqv[n] = fg->NewCNode(args); eqv[n]->set_abstract(n->abstract()); diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 2cf6ead813..61d96944f7 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -30,8 +30,8 @@ #ifdef ENABLE_GE #include "transform/graph_ir/convert.h" #endif -#include "utils/graph_utils.h" -#include "utils/context/ms_context.h" +#include "ir/graph_utils.h" +#include "utils/ms_context.h" #include "debug/trace.h" #include "debug/anf_ir_dump.h" @@ -69,7 +69,91 @@ bool ContainMultiTarget(const std::vector &nodes) { return false; } -void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref) { +bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, + std::vector *prior_nodes, std::vector *depend_nodes) { + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(behind_node); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + if (prior_node->isa()) { + for (auto &user : node_users[prior_node]) { + auto cnode = user.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + prior_nodes->emplace_back(cnode); + } + } + } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { + prior_nodes->emplace_back(prior_node); + } else { + return false; + } + if (behind_node->isa()) { + for (auto &user : node_users[behind_node]) { + auto cnode = user.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + depend_nodes->emplace_back(cnode); + } + } + } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { + depend_nodes->emplace_back(behind_node); + } else { + return false; + } + return true; +} + +void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, + std::map> *control_edges, + std::map *nodes_ref) { + MS_EXCEPTION_IF_NULL(node); + auto input_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto prior_node = input_cnode->input(kControlDependPriorIndex); + auto depend_node = input_cnode->input(kControlDependBehindIndex); + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(depend_node); + PrimitivePtr prim_ptr = GetValueNode(input_cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim_ptr); + ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); + int depend_mode = 0; + if (mode_ptr != nullptr) { + depend_mode = GetValue(mode_ptr); + } + if ((prior_node->isa() || depend_node->isa()) && depend_mode == 0) { + return; + } + std::vector prior_nodes; + std::vector behind_nodes; + if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { + return; + } + for (auto &first_node : prior_nodes) { + for (auto &second_node : behind_nodes) { + MS_EXCEPTION_IF_NULL(first_node); + MS_EXCEPTION_IF_NULL(second_node); + auto iter = control_edges->find(second_node); + if (iter == control_edges->end()) { + (void)control_edges->insert( + std::pair>(second_node, std::vector{first_node})); + } else { + iter->second.emplace_back(first_node); + } + auto ref_iter = nodes_ref->find(first_node); + if (ref_iter != nodes_ref->end()) { + ref_iter->second++; + } else { + (void)nodes_ref->insert(std::pair(first_node, 1)); + } + } + } +} + +void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref, + std::map> *control_edges) { std::queue queue; queue.push(graph->get_return()); std::set visited; @@ -83,6 +167,9 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *n auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); for (auto &input : cnode->inputs()) { + if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { + AddControlEdge(graph, input, control_edges, nodes_ref); + } auto iter = nodes_ref->find(input); if (iter != nodes_ref->end()) { iter->second++; @@ -142,7 +229,8 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & std::stack to_visit; std::stack next_to_visit; std::map nodes_ref; - CalcNodeRefCount(graph, &nodes_ref); + std::map> control_edges; + CalcNodeRefCount(graph, &nodes_ref, &control_edges); std::string handle_target = default_target; std::string next_target = ""; to_visit.push(graph->get_return()); @@ -162,6 +250,10 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & MS_EXCEPTION_IF_NULL(cnode); auto node_inputs = cnode->inputs(); std::reverse(node_inputs.begin(), node_inputs.end()); + auto ctrl_inputs = control_edges.find(node); + if (ctrl_inputs != control_edges.end()) { + node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); + } for (auto &input : node_inputs) { auto iter = nodes_ref.find(input); if (iter != nodes_ref.end()) { @@ -423,11 +515,7 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no MS_LOG(DEBUG) << "LinConvert start"; LinConvertResult result; - if (backend_->simu_flag()) { - result = backend_->GetMultiGraphRun(graph); - } else { - result = lin_convert_(node_list, target); - } + result = lin_convert_(node_list, target); if (result.run == nullptr) { MS_LOG(ERROR) << "LinConvert failed"; @@ -454,27 +542,6 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no return RET_SUCCESS; } -void CompileGraph::AddSinkSwitch(const CNodePtr &node) { - MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString(); - if (backend_->is_multi_graph_sink()) { - VectorRef args; - args.emplace_back(-1); - MS_LOG(DEBUG) << "call::" << height_; - AddInst(Instruction::kCall, args); - - args.clear(); - args.emplace_back(node->input(1)); - AddInst(Instruction::kSwitchReturn, args); - - args.clear(); - args.emplace_back(false); - args.emplace_back(Ref(node->input(1))); - args.emplace_back(Ref(node->input(2))); - args.emplace_back(Ref(node->input(3))); - AddInst(Instruction::kSwitch, args); - } -} - int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true); @@ -497,7 +564,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) AddPartial(node); } else if (IsPrimitive(fn, prim::kPrimSwitch)) { AddSwitch(node); - AddSinkSwitch(node); } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) { AddSwitchLayer(node); } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { @@ -515,14 +581,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) return RET_SUCCESS; } -void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) { - auto ret = LinConvert(graph, {}); - if (ret == RET_FAILED) { - MS_LOG(EXCEPTION) << "MultiGraphRun failed."; - } - AddReturn(nullptr); -} - bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start split graph"; MS_EXCEPTION_IF_NULL(graph); @@ -567,11 +625,6 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { return true; } -InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) { - InstSet inst = Run(graph); - return inst; -} - InstSet CompileGraph::Run(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); @@ -580,12 +633,8 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) { int param_height = height_; MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); - if (backend_->simu_flag()) { - GenMultiGraphsRun(graph); - } else { - if (!SplitGraph(graph)) { - return inst_; - } + if (!SplitGraph(graph)) { + return inst_; } AddPadStack(param_height); @@ -620,12 +669,6 @@ void CompileGraph::AddPartial(const CNodePtr &node) { if (!IsValueNode(fn)) { MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph"; } - if (backend_->is_multi_graph_sink()) { - auto func_graph = GetValueNode(fn); - args.emplace_back(func_graph); - AnfNodePtrList outs(inputs.begin() + 2, inputs.end()); - backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs); - } for (size_t i = 1; i < inputs.size(); i++) { args.emplace_back(Ref(inputs[i])); } @@ -647,9 +690,6 @@ void CompileGraph::AddSwitch(const CNodePtr &node) { MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4"; } VectorRef args; - if (backend_->is_multi_graph_sink()) { - args.emplace_back(true); - } args.emplace_back(Ref(inputs[1])); args.emplace_back(Ref(inputs[2])); args.emplace_back(Ref(inputs[3])); @@ -669,11 +709,7 @@ void CompileGraph::AddSwitchLayer(const CNodePtr &node) { void CompileGraph::AddReturn(const CNodePtr &node) { VectorRef args; - if (backend_->simu_flag()) { - args.emplace_back(Ref(backend_->final_output())); - } else { - args.emplace_back(Ref(node->input(1))); - } + args.emplace_back(Ref(node->input(1))); args.emplace_back(height_); AddInst(Instruction::kReturn, args); } @@ -691,11 +727,6 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { auto inputs = node->inputs(); AnfNodePtr fn = inputs[0]; - if (backend_->is_multi_graph_sink() && IsValueNode(fn)) { - auto func_graph = GetValueNode(fn); - AnfNodePtrList outs(inputs.begin() + 1, inputs.end()); - backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs); - } (void)Ref(fn); size_t size = inputs.size(); for (size_t i = size - 1; i > 0; i--) { @@ -837,17 +868,6 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) { } FinalVMPtr rt = std::make_shared(insts_, backend_); - if (backend_->is_multi_graph_sink()) { - backend_->set_simu_flag(true); - MS_LOG(DEBUG) << "Start simulate"; - backend_->SimulateRun(rt, graph); - MS_LOG(DEBUG) << "Link graphs"; - insts_ = transform_->GenMultiGraphsSinkInst(graph); - rt->set_insts(insts_); - backend_->set_simu_flag(false); - MS_LOG(DEBUG) << "End start simulate"; - backend_->Link(kInvalidGraphId); - } MS_LOG(DEBUG) << "End"; return rt; } diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index d08a24d188..819ee07eb7 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -54,12 +54,10 @@ class CompileGraph { ~CompileGraph() = default; InstSet Run(const FuncGraphPtr &func_graph); - InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph); bool IsCut(const AnfNodePtr &node); void Push(const AnfNodePtr &node); void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Ret(int nargs); - void GenMultiGraphsRun(const FuncGraphPtr &graph); int Ref(const AnfNodePtr &node); VectorRef SplitNodes(const FuncGraphPtr &func_graph); @@ -84,7 +82,6 @@ class CompileGraph { int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); - void AddSinkSwitch(const CNodePtr &node); void AddPadStack(int param_height); void AddTailCall(const AnfNodePtr &fn, size_t size); void AddPartial(const CNodePtr &node); diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index baa5b0ea11..091e7af7bc 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -17,12 +17,9 @@ */ #include "vm/vm.h" - #include - #include "vm/vmimpl.h" #include "vm/backend.h" -#include "vm/transform.h" #include "pipeline/jit/parse/data_converter.h" #include "utils/base_ref_extends.h" @@ -142,33 +139,10 @@ void FinalVM::Popsp() { } } -void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); } - -bool FinalVM::PopStatus() { - if (ret_status_.empty()) { - return false; - } - bool status = ret_status_.top(); - ret_status_.pop(); - return status; -} - void FinalVM::DoJmp(const BaseRef &jmp_orig) { MS_LOG(DEBUG) << "Start"; BaseRef jmp = jmp_orig; - if (backend_->simu_flag()) { - bool is_switch_call = false; - if (utils::isa(jmp)) { // need to inherit from Base - MS_LOG(DEBUG) << "Start jump StructSwitch"; - auto simu_value = utils::cast>(jmp); - jmp = simu_value->fn_; - backend_->set_curr_switch(simu_value->value_); - is_switch_call = true; - } - PushStatus(is_switch_call); - } - if (utils::isa(jmp)) { // need to inherit from Base MS_LOG(DEBUG) << "Start jump StructPartial"; auto new_jmp = utils::cast>(jmp); @@ -270,13 +244,6 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) { MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; return; } - - auto rv = Ref(-1); - if (utils::isa(rv) || utils::isa(rv)) { - auto &c = args[0]; - cond_out_[c] = rv; - } - Pop(1); Popsp(); } @@ -294,51 +261,12 @@ void FinalVM::InstReturn(const VectorRef &args) { int height = utils::cast(args[1]); auto rv = Ref(rpos); - if (backend_->simu_flag()) { - auto c = backend_->curr_switch(); - auto status = PopStatus(); - if (status) { - auto iter = cond_out_.find(c); - if (iter != cond_out_.end()) { - rv = MergeArgs(rv, iter->second); - cond_out_.erase(iter); - } - } - - if (backend_->is_switch_call()) { - backend_->SetSwitchGraph(); - } - } - Pop(height); Push(rv); Popp(); MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSimuPartial(const VectorRef &args) { - const size_t args_size = 2; - if (args.size() < args_size) { - MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is " - << args.size() << "."; - return; - } - - auto &node = args[0]; - if (!utils::isa(node)) { - MS_LOG(ERROR) << "The type of 1st input of node must be FuncGraph"; - return; - } - auto fg = utils::cast(node); - int fn_ = utils::cast(args[1]); - auto fn = utils::cast(Ref(fn_)); - MS_LOG(DEBUG) << "Partial argssize:" << args.size(); - std::vector outs(args.size() - 2); - (void)std::transform(args.begin() + 2, args.end(), outs.begin(), - [&, this](const BaseRef &a) { return Ref(utils::cast(a)); }); - Push(std::make_shared(fn, VectorRef(outs), fg)); -} - void FinalVM::InstRealPartial(const VectorRef &args) { const size_t args_size = 1; if (args.size() < args_size) { @@ -358,91 +286,10 @@ void FinalVM::InstRealPartial(const VectorRef &args) { void FinalVM::InstPartial(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; - if (backend_->is_multi_graph_sink()) { - InstSimuPartial(args); - } else { - InstRealPartial(args); - } + InstRealPartial(args); MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSimuSwitch(const VectorRef &args) { - const size_t args_size = 4; - if (args.size() != args_size) { - MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size() - << "."; - return; - } - bool cond = utils::cast(args[0]); - int cond_node = utils::cast(args[1]); - int vtrue = utils::cast(args[2]); - int vfalse = utils::cast(args[3]); - - MS_LOG(DEBUG) << "Simu switch cond:" << cond; - BaseRef c = Ref(cond_node); - bool bool_value = cond; - SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value); - - if (cond_stat == kCondAlreadyRun) { - MS_LOG(DEBUG) << "switch alreay run bool while true jmp"; - BaseRef jmp = Ref(vtrue); - if (utils::isa(jmp)) { - auto new_jmp = utils::cast>(jmp); - backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c); - } - cond_jmp_[c] = Ref(vfalse); - Push(static_cast(cond_stat)); - Popp(); - backend_->SetSwitchActive(c, bool_value); - return; - } - if (bool_value) { - Push(std::make_shared(Ref(vtrue), c)); - Pushsp(); - } else { - MergeJmpArgs(Ref(vfalse), c); - Push(std::make_shared(Ref(vfalse), c)); - } -} - -void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) { - auto iter = cond_jmp_.find(c); - if (iter == cond_jmp_.end()) { - return; - } - auto old_jmp = utils::cast>(iter->second); - auto new_jmp = utils::cast>(jmp); - auto &old_args = old_jmp->args_; - auto &new_args = new_jmp->args_; - for (size_t i = 0; i < new_args.size(); ++i) { - auto &old_arg = old_args[i]; - auto &new_arg = new_args[i]; - new_arg = MergeArgs(old_arg, new_arg); - } -} - -BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) { - MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString(); - if (utils::isa(first)) { - auto old_vec_ref = utils::cast(first); - if (utils::isa(second)) { - auto new_vec_ref = utils::cast(second); - std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); - } else { - old_vec_ref.push_back(second); - } - return old_vec_ref; - } - - if (utils::isa(second)) { - auto new_vec_ref = utils::cast(second); - new_vec_ref.push_back(first); - return new_vec_ref; - } - - return VectorRef({first, second}); -} - void FinalVM::InstRealSwitch(const VectorRef &args) { const size_t args_size = 3; if (args.size() != args_size) { @@ -472,11 +319,7 @@ void FinalVM::InstRealSwitch(const VectorRef &args) { void FinalVM::InstSwitch(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; - if (backend_->is_multi_graph_sink()) { - InstSimuSwitch(args); - } else { - InstRealSwitch(args); - } + InstRealSwitch(args); MS_LOG(DEBUG) << "End"; } @@ -580,14 +423,6 @@ void FinalVM::InstExternal(const VectorRef &args) { VectorRef tuple; RunFunctionRef run_ref = utils::cast(args[0]); compile::RunFuncPtr fn = run_ref.func_; - if (backend_->simu_flag()) { - MS_LOG(DEBUG) << "Simu run"; - if (args.size() == 1) { - MS_LOG(EXCEPTION) << "The number of args should be greater than 1, but got 1"; - } - auto simu_run_ref = utils::cast(args[1]); - fn = simu_run_ref.func_; - } for (size_t i = 2; i < args.size(); ++i) { auto index = utils::cast(args[i]); tuple.push_back(Ref(index)); diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index 02a1ad4ddb..9986a3a34f 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -96,7 +96,6 @@ class FinalVM { public: // Create a VM with the specified instructions and backend. explicit FinalVM(const InstSet &insts, const BackendPtr &backend); - virtual ~FinalVM() = default; BaseRef Eval(const VectorRef &args); @@ -104,10 +103,8 @@ class FinalVM { void InstTailCall(const VectorRef &args); void InstReturn(const VectorRef &args); void InstPartial(const VectorRef &args); - void InstSimuPartial(const VectorRef &args); void InstRealPartial(const VectorRef &args); void InstSwitch(const VectorRef &args); - void InstSimuSwitch(const VectorRef &args); void InstRealSwitch(const VectorRef &args); void InstTuple(const VectorRef &args); void InstPush(const VectorRef &args); @@ -129,23 +126,16 @@ class FinalVM { void Popp(); void Pushsp(); void Popsp(); - void PushStatus(bool is_switch_call); - bool PopStatus(); void DoJmp(const BaseRef &jmp); void SyncData(const py::object &args); - void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); - BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); private: InstSet insts_; std::deque insts_stack_; std::stack retp_; std::stack retsp_; - std::stack ret_status_; int pc_; int sp_; - std::unordered_map cond_jmp_; - std::unordered_map cond_out_; BackendPtr backend_; const InstFunctionMap inst_function_map = { {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index 2aebf8ad0d..530e6ba040 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -30,7 +30,7 @@ #include "frontend/operator/ops.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "ir/primitive_py.h" +#include "utils/primitive_py.h" #include "utils/convert_utils.h" #include "utils/primitive_utils.h" #include "debug/draw.h" @@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { } BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { - PrimitivePyPtr operation = dyn_cast(prim); - MS_LOG(DEBUG) << "operation start " << prim->name(); - auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name()); - if (py::isinstance(func)) { - MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented"; - } - - py::tuple py_args = py::tuple(args.size()); - MS_LOG(DEBUG) << "input for operation:"; - size_t i = 0; - for (auto &arg : args) { - py_args[i] = BaseRefToPyData(arg); - MS_LOG(DEBUG) << "arg: " << i << ":"; - i++; - } - py::object obj = func(*py_args); - MS_LOG(DEBUG) << "result:" << py::str(obj); - return obj; + MS_EXCEPTION_IF_NULL(prim); + auto result = prim->RunComputeFunction(args); + if (result.is_null()) { + return RunComputeFunction(prim, args); + } + return result; } } // namespace compile diff --git a/mindspore/common/__init__.py b/mindspore/common/__init__.py index c896805d75..e3f8396c4d 100644 --- a/mindspore/common/__init__.py +++ b/mindspore/common/__init__.py @@ -17,10 +17,10 @@ from . import dtype from .api import ms_function from .dtype import * from .parameter import Parameter, ParameterTuple -from .tensor import MetaTensor, Tensor, IndexedSlices +from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor __all__ = [ - "MetaTensor", "Tensor", "IndexedSlices", # tensor + "MetaTensor", "Tensor", "RowTensor", "SparseTensor", # tensor 'ms_function', # api 'Parameter', 'ParameterTuple', # parameter "dtype" diff --git a/mindspore/common/_register_for_tensor.py b/mindspore/common/_register_for_tensor.py index 8ba2ff7cc4..effd21dd19 100644 --- a/mindspore/common/_register_for_tensor.py +++ b/mindspore/common/_register_for_tensor.py @@ -35,9 +35,11 @@ class Registry(UserDict): new_args = list(args) new_args.append(obj_str) return self["vm_compare"](*new_args) + obj = wrap else: obj = self[obj_str] return obj + tensor_operator_registry = Registry() diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 050baf9f79..b0e2ff5a3f 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -62,6 +62,7 @@ def _wrap_func(fn): Returns: Function, a new function with return suitable format data. """ + @wraps(fn) def wrapper(*arg, **kwargs): results = fn(*arg, **kwargs) @@ -74,6 +75,7 @@ def _wrap_func(fn): if isinstance(data, list): return list(_convert_data(x) for x in data) return data + return _convert_data(results) return wrapper @@ -106,6 +108,7 @@ class _MindSporeFunction: obj (Object): If function is a method, obj is the owner of function, else, obj is none. """ + def __init__(self, fn, input_signature=None, obj=None): self.fn = fn self.save_graphs = context.get_context("save_graphs") @@ -203,23 +206,24 @@ class _MindSporeFunction: def ms_function(fn=None, obj=None, input_signature=None): """ - Creates a callable MindSpore graph from a python function. + Create a callable MindSpore graph from a python function. This allows the MindSpore runtime to apply optimizations based on graph. Args: fn (Function): The Python function that will be run as a graph. Default: None. - obj (Object): The Python Object that provide information for identify compiled function. Default: None. - input_signature (MetaTensor): The MetaTensor to describe the input arguments. The MetaTensor specifies + obj (Object): The Python Object that provides the information for identifying the compiled function.Default: + None. + input_signature (MetaTensor): The MetaTensor which describes the input arguments. The MetaTensor specifies the shape and dtype of the Tensor and they will be supplied to this function. If input_signature - is specified, every input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept - `**kwargs`. The shape and dtype of actual inputs should keep same with input_signature, or TypeError - will be raised. Default: None. + is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept + `**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature. Otherwise, + TypeError will be raised. Default: None. Returns: - Function, if `fn` is not None, returns a callable that will execute the compiled function; If `fn` is None, - returns a decorator and when this decorator invokes with a single `fn` argument, the callable is equal to the - case when `fn` is not None. + Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is + None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is + equal to the case when `fn` is not None. Examples: >>> def tensor_add(x, y): @@ -245,13 +249,13 @@ def ms_function(fn=None, obj=None, input_signature=None): >>> out = tensor_add_with_dec(x, y) >>> out = tensor_add_with_sig(x, y) """ + def wrap_mindspore(func): @wraps(func) def staging_specialize(*args): process_obj = obj if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__): process_obj = args[0] - args = (x.default_input if hasattr(x, 'default_input') else x for x in args) return _MindSporeFunction(func, input_signature, process_obj)(*args) return staging_specialize @@ -275,6 +279,7 @@ def _generate_pip_args(obj, *args, method="construct"): obj.__parse_method__ = parse_method return args_names, args_list + class _PynativeExecutor: """ An pynative executor used to compile/manage/run graph. @@ -304,6 +309,7 @@ class _PynativeExecutor: def __call__(self, *args): return self._executor(args, "") + class _Executor: """ An executor used to compile/manage/run graph. @@ -348,28 +354,8 @@ class _Executor: raise RuntimeError("Failure to init and dataset subgraph!") return True - def _build_data_graph(self, obj, params, phase): - if params is None: - self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict()) - elif isinstance(params, OrderedDict): - self._executor.build_data_graph(params, phase) - else: - raise TypeError('Parameters need OrderedDict type, but got {}'. - format(type(params))) - - def _params_init_data(self, obj, params, auto_parallel_mode=False): - """Init parameters' data.""" - if params is not None: - for key, param in params.items(): - if not auto_parallel_mode: - param.init_data() - elif key not in obj.parameter_layout_dict: - logger.debug("Layout dict does not contain the key %s.", key) - param.init_data(set_sliced=True) - else: - layout = obj.parameter_layout_dict[key] - param.init_data(layout, set_sliced=True) - obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) + def _build_data_graph(self, obj, phase): + self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict()) def _set_dataset_mode(self, args_list): """set dataset mode.""" @@ -380,7 +366,7 @@ class _Executor: else: _set_dataset_mode_config('normal') - def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): + def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False): """ Compiles graph. @@ -388,7 +374,6 @@ class _Executor: obj (Function/Cell): The function or cell instance need compile. args (tuple): Function or cell input arguments. phase (str): The name of compile phase. Default: 'predict'. - params (OrderedDict): The parameters dictionary used for init data graph. Default: None. do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph. auto_parallel_mode: When set to True, use auto parallel mode to compile graph. @@ -429,10 +414,12 @@ class _Executor: if auto_parallel_mode: obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) - self._params_init_data(obj, params, auto_parallel_mode) + replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) if not enable_debug_runtime or enable_ge: if auto_parallel_mode: - obj.load_parameter_slice(params) + obj.load_parameter_slice(None) + + self._updata_param_node_default_input(phase, replace) # set parallel inputs in sink mode if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag): @@ -440,16 +427,20 @@ class _Executor: # the following GE init process is not needed when use vm or ms backend if enable_ge: - self._build_data_graph(obj, params, phase) + self._build_data_graph(obj, phase) if "export" not in phase: init_phase = "init_subgraph" + "." + str(obj.create_time) _exec_init_graph(obj, init_phase) elif not enable_ge and "export" in phase: - self._build_data_graph(obj, params, phase) + self._build_data_graph(obj, phase) return phase, True + def _updata_param_node_default_input(self, phase, replace): + new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])} + return self._executor.updata_param_node_default_input(phase, new_param) + def _get_strategy(self, obj): real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time) return self._executor.get_strategy(real_phase) @@ -515,18 +506,16 @@ class _Executor: return None return self._executor.get_func_graph_proto(exec_id, ir_type) - def export(self, net, file_name, file_format='GEIR'): + def export(self, file_name, graph_id): """ Export graph. Args: - net (Cell): MindSpore network file_name (str): File name of model to export - file_format (str): MindSpore currently support 'GEIR' and 'ONNX' format for exported model + graph_id (str): id of graph to be exported """ from .._c_expression import export_graph - phase = 'export' + '.' + self.phase_prefix + '.' + str(net.create_time) - export_graph(file_name, file_format, phase) + export_graph(file_name, 'GEIR', graph_id) def fetch_info_for_quant_export(self, exec_id): """Get graph proto from pipeline.""" @@ -534,6 +523,7 @@ class _Executor: return None return self._executor.fetch_info_for_quant_export(exec_id) + _executor = _Executor() _pynative_exec = _PynativeExecutor() diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 85bb1c52d6..2b1c692ac1 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -99,6 +99,9 @@ slice_type = typing.Slice ellipsis_type = typing.TypeEllipsis list_type = typing.List tuple_type = typing.Tuple +index_slices = typing.RowTensorType() +sparse_tensor = typing.SparseTensorType() +undetermined = typing.UndeterminedType() number_type = (int8, int16, @@ -163,7 +166,7 @@ def pytype_to_dtype(obj): def get_py_obj_dtype(obj): """ - Get the corresponding MindSpore data type by python type or variable. + Get the MindSpore data type which corresponds to python type or variable. Args: obj: An object of python type, or a variable in python type. @@ -183,7 +186,7 @@ def get_py_obj_dtype(obj): def dtype_to_nptype(type_): """ - Get numpy data type corresponding to MindSpore dtype. + Convert MindSpore dtype to numpy data type. Args: type_ (:class:`mindspore.dtype`): MindSpore's dtype. @@ -210,7 +213,7 @@ def dtype_to_nptype(type_): def dtype_to_pytype(type_): """ - Get python type corresponding to MindSpore dtype. + Convert MindSpore dtype to python data type. Args: type_ (:class:`mindspore.dtype`): MindSpore's dtype. diff --git a/mindspore/common/graph_pattern.py b/mindspore/common/graph_pattern.py new file mode 100644 index 0000000000..487db572f6 --- /dev/null +++ b/mindspore/common/graph_pattern.py @@ -0,0 +1,154 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Patterns for describing graphs""" +from mindspore.ops import Primitive +from mindspore.common.tensor import Tensor +from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_ + +__all__ = [ + "IsIn", + "IsPrimTypeOf", + "CallWith", + "IsNot", + "AnyPattern", + "NewTensor", +] + +class IsIn(IsIn_): + """ + Express a pattern which allows a list of patterns. + """ + def __init__(self, patterns=None, should_replace=True): + r""" + Args: + patterns(list/tuple): list of allowed patterns + should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. + """ + if not should_replace: + raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \ + its sub-pattern instead.") + self.patterns = patterns + if patterns is None: + IsIn_.__init__(self, ()) + elif isinstance(patterns, Pattern): + IsIn_.__init__(self, [patterns]) + elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): + IsIn_.__init__(self, patterns) + else: + raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") + +class IsPrimTypeOf(IsPrimTypeOf_): + r""" + Express a pattern of certain primitive type(s). + NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed, + please refer to CallWith pattern. + """ + def __init__(self, types, name=None, should_replace=True): + r""" + Args: + types (str/(list/tuple of Primitives)): Specify allowed types. + If it is a string, the form could be + 1) a single primitive type, e.g. 'Conv2D' + 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D' + It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)] + name (str): name of the pattern, optional + should_replace + """ + if name is not None and not isinstance(name, str): + raise TypeError(f"Expect string, got : {name}") + self.name = name + if isinstance(types, str): + if self.name is None: + self.name = types + self.types = types.split('|') + elif isinstance(types, Primitive): + if self.name is None: + self.name = types.name + self.types = [types] + elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): + if self.name is None: + self.name = "" + for prim in types: + self.name += prim.name + self.types = types + else: + raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") + IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace) + +class CallWith(CallWith_): + r""" + Express a primitive CNode. + """ + def __init__(self, prim_pattern, inputs=None, should_replace=False): + r""" + Args: + prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode. + inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; + if specified, input patterns should be of right order. + """ + if not isinstance(prim_pattern, (Pattern, str, Primitive)): + raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") + self.prim_pattern = prim_pattern + self.inputs = [] + if inputs is None: + pass + elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): + self.inputs = inputs + else: + raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") + CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace) + + +class IsNot(IsNot_): + r""" + Express a pattern which forbids a list of patterns. + NOTE: IsNot pattern should not be the root pattern. + """ + def __init__(self, patterns=None, should_replace=True): + r""" + Args: + patterns(list/tuple): list of forbiden patterns + should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. + """ + if not should_replace: + raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \ + its sub-pattern instead.") + self.patterns = patterns + if patterns is None: + IsNot_.__init__(self, ()) + elif isinstance(patterns, Pattern): + IsNot_.__init__(self, [patterns]) + elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): + IsNot_.__init__(self, patterns) + else: + raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") + +class NewTensor(NewTensor_): + r""" + New Tensor to be used in the target. + """ + def __init__(self, input_tensor, should_replace=False): + r""" + Args: + input_tensor(Tensor): new tensor to be used in the target + should_replace(bool): added this for interface consistency. NewTensor should only appear in the target. + """ + if should_replace: + raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.") + self.input_tensor = input_tensor + if isinstance(input_tensor, Tensor): + NewTensor_.__init__(self, input_tensor) + else: + raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 48a142a23f..f7bf3fef3d 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -163,7 +163,7 @@ def _calculate_in_and_out(arr): """ dim = len(arr.shape) if dim < 2: - raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") + raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.") n_in = arr.shape[1] n_out = arr.shape[0] diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 1605ee4bc5..091f4bc967 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -14,33 +14,36 @@ # ============================================================================ """Parameter for cell.""" -import numbers from copy import copy -from mindspore import context from .._c_expression import ParamValue from . import dtype as mstype from .initializer import initializer, Initializer from .tensor import Tensor, MetaTensor from .._checkparam import _check_str_by_regular from ..parallel._tensor import _get_slice_index +from ..parallel._auto_parallel_context import auto_parallel_context __all__ = ['Parameter', 'ParameterTuple'] PARAMETER_NAME_DEFAULT = "Parameter" PARAMETER_NAME_PREFIX_MAX_LEN = 1024 +def _is_in_parallel_mode(): + """Get parallel mode.""" + return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"] -def _check_type(x): - """Check input data type""" - if not isinstance(x, Parameter): - raise ValueError("Should be `Parameter` collection.") - return True - -class Parameter: +class Parameter(MetaTensor): """ Parameter types of cell models. + After initialized `Parameter` is a subtype of `Tensor`. + + In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by + a `Initializer`, the type of Parameter will be a `MetaTensor` not a `Tensor`. `MetaTensor` + only save the shape type info of a tensor with no memory usage. The shape can be change while + compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. + Note: Each parameter of Cell is represented by Parameter class. @@ -50,29 +53,100 @@ class Parameter: name (str): Name of the child parameter. requires_grad (bool): True if the parameter requires gradient. Default: True. layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, - broadcast and gradients communication would not be applied on parameters. Default: False. + broadcast and gradients communication would not be applied to parameters. Default: False. """ + __base_type__ = {} + + def __new__(cls, default_input, name, *args, **kwargs): + input_class, *class_init_args = Parameter._get_parameter_new_args(default_input) + new_type = Parameter._get_base_class(input_class) + obj = input_class.__new__(new_type) + input_class.__init__(obj, *class_init_args) + # it's better to make the Initializer a kind of metatensor. + obj.init_mode = None + if not isinstance(obj, Tensor): + obj.init_mode = default_input + return obj + + def __reduce_ex__(self, _): + data = self + if self.init_mode is not None: + data = self.init_mode + else: + # cast to break deep infinit loop while deepcopy + data = Tensor(self) + return ( + Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) + def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): self._value = ParamValue() - self.set_parameter_data(default_input) self.name = name self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel + # this flag for tensor copy data. + self.init_flag = False + # this flag is for ge variable copy data. self._is_init = False + self._inited_param = None self._sliced = False self.is_param_ps = False - if context.get_context("mode") == context.PYNATIVE_MODE: - self.init_data() + self._cast_type = None + self.init_in_server = False + + @staticmethod + def _get_base_class(input_class): + input_class_name = f'Parameter{input_class.__name__}' + if input_class_name in Parameter.__base_type__: + new_type = Parameter.__base_type__[input_class_name] + else: + new_type = type(input_class_name, (Parameter, input_class), {}) + Parameter.__base_type__[input_class_name] = new_type + return new_type + + @staticmethod + def _get_parameter_new_args(data): + """Set `default_input` of current `Parameter`.""" + if isinstance(data, bool): + raise ValueError('Parameter data can not be `bool`') + if isinstance(data, Initializer): + if _is_in_parallel_mode(): + # do not init data while in auto parallel. + return (MetaTensor, data.dtype, data.shape) + data = data.to_tensor() + if isinstance(data, Tensor): + # make a copy of Tensor to init the parameter + return (Tensor, data.asnumpy(),) + if isinstance(data, int): + return (Tensor, data, mstype.int32) + if isinstance(data, float): + return (Tensor, data, mstype.float32) + return (Tensor, data) + + def __str__(self): + value_str = MetaTensor.__str__(self) + if isinstance(self, Tensor): + value_str = Tensor.__str__(self) + return f'Parameter (name={self._value.name}, value={value_str})' def __repr__(self): - format_str = 'Parameter (name={name})' - return format_str.format(name=self._value.name) + value_str = MetaTensor.__repr__(self) + if isinstance(self, Tensor): + value_str = Tensor.__repr__(self) + return f'Parameter (name={self._value.name}, value={value_str})' def __parameter__(self): """For parse check.""" - def set_param_ps(self): + def set_param_ps(self, init_in_server=False): self.is_param_ps = True + self.init_in_server = init_in_server + + + @property + def inited_param(self): + """Get the new parameter after call the init_data.""" + return self._inited_param + @property def name(self): @@ -101,6 +175,16 @@ class Parameter: raise ValueError("The type of the name should be `str` or `None`.") self._value.name = name_ + @property + def cast_type(self): + return self._cast_type + + @cast_type.setter + def cast_type(self, dst_type): + if dst_type not in (mstype.float16, mstype.float32, None): + raise ValueError("The type of the name should be type of [float32, float16] or `None`.") + self._cast_type = dst_type + @property def sliced(self): """Get slice status of the parameter.""" @@ -112,7 +196,7 @@ class Parameter: @property def is_init(self): - """Get init status of the parameter.""" + """Get the initialization status of the parameter.""" return self._is_init @is_init.setter @@ -144,15 +228,9 @@ class Parameter: x._value.name = prefix + '.' + self._value.name x.is_init = False if init != 'same': - shape = self.default_input.shape - dtype = self.default_input.dtype - if isinstance(init, (str, Initializer, numbers.Number)): - x.init_mode = initializer(init, shape=shape, dtype=dtype) - x.default_input = MetaTensor(dtype, shape) - if context.get_context("mode") == context.PYNATIVE_MODE: - x.init_data() - else: - x.default_input = initializer(init, shape=shape, dtype=dtype) + shape = self.shape + dtype = self.dtype + x.default_input = initializer(init, shape=shape, dtype=dtype) return x @property @@ -182,55 +260,70 @@ class Parameter: @property def default_input(self): - return self._data + return self @default_input.setter def default_input(self, data): - self._data = data - self._value.data = data - - def __add__(self, other): - return self.default_input + other - - def __sub__(self, other): - return self.default_input - other - - def __mul__(self, other): - return self.default_input * other - - def __truediv__(self, other): - return self.default_input / other + self.set_parameter_data(data) + + def _update_tensor_data(self, data): + "Update the parameter by a Tensor." + if isinstance(self, Tensor): + # for Tensor same shape: + self.init_flag = False + return self.assign_value(data) + # create a new tensor + return Parameter(data, self.name, self.requires_grad) + + def set_parameter_data(self, data, slice_shape=False): + """ + Set `default_input` of current `Parameter`. - def __setitem__(self, index, value): - default_input = self.default_input - default_input[index] = value - return self + Args: + data (Union[Tensor, Initializer]): new data. + slice_shape (bool): If slice the Parameter. Default: False. - def set_parameter_data(self, data): - """Set `default_input` of current `Parameter`.""" - self.init_mode = None - if isinstance(data, bool): - raise ValueError('Parameter data can not be `bool`') - if isinstance(data, Tensor): - # make a copy of Tensor to init the parameter - data = Tensor(data.asnumpy()) - data.init_flag = False - elif isinstance(data, Initializer): - self.init_mode = data - data = MetaTensor(self.init_mode.dtype, self.init_mode.shape) - elif isinstance(data, int): - data = Tensor(data, dtype=mstype.int32) - elif isinstance(data, float): - data = Tensor(data, dtype=mstype.float32) + Retruns: + Parameter, the parameter after set data. + """ + if not isinstance(data, (MetaTensor, Initializer)): + raise ValueError(f"Parameter data must be `Initializer` or a kind of `MetaTensor` " + f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.") + # both not init. + is_incoming_tensor = isinstance(data, Tensor) + is_current_tensor = isinstance(self, Tensor) + + if is_incoming_tensor and not is_current_tensor: + raise TypeError("Parameter is a `MetaTensor` and not initializered, `data` for `set_parameter_data`" + "should be a Initializer. If you want to update it by Tensor, call method" + "`init_parameters_data` of `Cell` to init and replace all the Parameter of" + "network, then call this method.") + if tuple(self.shape) != tuple(data.shape): + # If Slice create Parameter shape can be change. + if slice_shape: + self._update_tensor_data(data) + self.sliced = True + else: + raise ValueError(f"Can not change the shape of Parameter which has been initialized." + f" Current shape is {self.shape}, and incoming is {data.shape}.") + if self.dtype != data.dtype: + raise ValueError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}" + f", and incoming is {data.dtype}. Use .set_dtype(xxx) to change the dtype.") + if isinstance(data, Initializer): + # The parameter has been initializered, directly update by the data + if is_current_tensor: + self._update_tensor_data(data.to_tensor()) + else: + self.init_mode = data + elif is_incoming_tensor or is_current_tensor: + self._update_tensor_data(data) else: - data = Tensor(data) - data.init_flag = False - - self.default_input = data + raise ValueError(f"Not support to update the Parameter by {data}") + return self def init_data(self, layout=None, set_sliced=False): """ - Init data of the parameter. + Initialize the parameter data. Args: layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape]. @@ -238,26 +331,40 @@ class Parameter: - dev_mat (list[int]): Device matrix. - tensor_map (list[int]): Tensor map. - slice_shape (list[int]): Shape of slice. - set_sliced (bool): True if should set parameter sliced after init the data of initializer. + set_sliced (bool): True if the parameter is set sliced after initializing the data. Default: False. + + Returns: + Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before, + returns the same initialized `Parameter`. """ if self.init_mode is None: - return + return self + if self.inited_param is not None: + return self.inited_param if layout is not None: if not isinstance(layout, list): - raise TypeError("The layout should be list! layout is {}." - .format(layout)) - if len(layout) != 3: - raise ValueError("The length of layout must be 3! layout is {}." - .format(layout)) + raise TypeError("The layout should be list! layout is {}.".format(layout)) + if len(layout) < 3: + raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) slice_index = int(_get_slice_index(layout[0], layout[1])) - self.default_input = self.init_mode.to_tensor(slice_index, layout[2]) + if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): + data = self.init_mode.to_tensor(0, [1]) + else: + data = self.init_mode.to_tensor(slice_index, layout[2]) else: - self.default_input = self.init_mode.to_tensor() + if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): + data = self.init_mode.to_tensor(0, [1]) + else: + data = self.init_mode.to_tensor() - self.init_mode = None + obj = self._update_tensor_data(data) + if id(obj) != id(self): + self._inited_param = obj + obj.init_mode = None if set_sliced: - self.sliced = True + obj.sliced = True + return obj class ParameterTuple(tuple): @@ -265,12 +372,16 @@ class ParameterTuple(tuple): Class for storing tuple of parameters. Note: - Used to store the parameters of the network into the parameter tuple collection. + It is used to store the parameters of the network into the parameter tuple collection. """ def __new__(cls, iterable): """Create instance object of ParameterTuple.""" - g = (x for x in iterable if _check_type(x)) - return tuple.__new__(ParameterTuple, g) + data = tuple(iterable) + for x in data: + if not isinstance(x, Parameter): + raise TypeError(f"ParameterTuple input should be `Parameter` collection." + f"But got a {type(iterable)}, {iterable}") + return tuple.__new__(ParameterTuple, tuple(data)) def clone(self, prefix, init='same'): """ diff --git a/mindspore/common/python_pass_register.py b/mindspore/common/python_pass_register.py index 36eb37adc7..ee4f0f0bc8 100644 --- a/mindspore/common/python_pass_register.py +++ b/mindspore/common/python_pass_register.py @@ -14,6 +14,7 @@ # ============================================================================ """Python pass register""" from inspect import isfunction +from mindspore.common.graph_pattern import Pattern from mindspore._c_expression import PyPassManager_ from mindspore._c_expression import phase @@ -46,10 +47,10 @@ class PyPassManager(PyPassManager_): raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") pattern, target = py_pass() pass_name = py_pass.__name__ - if not isfunction(pattern): - raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}") - if not isfunction(target): - raise TypeError(f"Expecting function target, got : ({type(target)}){target}") + if not isinstance(pattern, Pattern): + raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}") + if not isinstance(target, Pattern): + raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}") super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) def unregiste(self, py_pass, pipeline_phase=phase.opt): diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 64a8eb4637..5730c9f8c8 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -21,23 +21,22 @@ from .._checkparam import check_type, check_typename from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry -__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices'] +__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor'] np_types = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64, np.bool_) - class Tensor(Tensor_): """ - Tensor for data storage. + Tensor is used for data storage. - Tensor inherits tensor object in C++ side, some functions are implemented - in C++ side and some functions are implemented in Python layer. + Tensor inherits tensor object in C++. + Some functions are implemented in C++ and some functions are implemented in Python. Args: input_data (Tensor, float, int, bool, tuple, list, numpy.ndarray): Input data of the tensor. - dtype (:class:`mindspore.dtype`): Should be None, bool or numeric type defined in `mindspore.dtype`. + dtype (:class:`mindspore.dtype`): Input data should be None, bool or numeric type defined in `mindspore.dtype`. The argument is used to define the data type of the output tensor. If it is None, the data type of the output tensor will be as same as the `input_data`. Default: None. @@ -45,13 +44,13 @@ class Tensor(Tensor_): Tensor, with the same shape as `input_data`. Examples: - >>> # init a tensor with input data + >>> # initialize a tensor with input data >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32) >>> assert isinstance(t1, Tensor) >>> assert t1.shape == (1, 2, 3) >>> assert t1.dtype == mindspore.float32 >>> - >>> # init a tensor with a float scalar + >>> # initialize a tensor with a float scalar >>> t2 = Tensor(0.1) >>> assert isinstance(t2, Tensor) >>> assert t2.dtype == mindspore.float64 @@ -75,7 +74,7 @@ class Tensor(Tensor_): self._virtual_flag = False def __repr__(self): - return str(self.__str__()) + return str(Tensor_.__str__(self)) def __add__(self, other): out = tensor_operator_registry.get('__add__')(self, other) @@ -108,6 +107,14 @@ class Tensor(Tensor_): out = tensor_operator_registry.get('__neg__')(self) return out + def __bool__(self): + data = self.asnumpy() + if data.shape == (): + return bool(data) + if data.shape == (1,): + return bool(data[0]) + raise ValueError("The truth value of an array with several elements is ambiguous.") + def __pos__(self): return self @@ -181,6 +188,9 @@ class Tensor(Tensor_): def __imod__(self, other): return self.__mod__(other) + def __rmod__(self, other): + return tensor_operator_registry.get('__mod__')(other, self) + def __pow__(self, other): return tensor_operator_registry.get('__pow__')(self, other) @@ -190,11 +200,24 @@ class Tensor(Tensor_): def __ifloordiv__(self, other): return self.__floordiv__(other) + def __rfloordiv__(self, other): + return tensor_operator_registry.get('__floordiv__')(other, self) + def __str__(self): if self.dtype == mstype.type_none: return "Unknown Tensor type!" return str(self.asnumpy()) + @property + def shape(self): + """The shape of tensor.""" + return self._shape + + @property + def dtype(self): + """The dtype of tensor.""" + return self._dtype + @property def virtual_flag(self): """Mark tensor is virtual.""" @@ -207,7 +230,153 @@ class Tensor(Tensor_): raise TypeError("virtual_flag must be bool.") self._virtual_flag = value + def asnumpy(self): + """Convert tensor to numpy array.""" + return Tensor_.asnumpy(self) + + def all(self, axis=(), keep_dims=False): + """ + Check all array elements along a given axis evaluate to True. + + Args: + axis (Union[None, int, tuple(int)): Dimensions of reduction. + Default: (), reduce all dimensions. + keep_dims (bool): Whether to keep the reduced dimensions. + Default : False, don't keep these reduced dimensions. + + Returns: + Tensor, has the same data type as x. + """ + + return tensor_operator_registry.get('all')(keep_dims)(self, axis) + + def any(self, axis=(), keep_dims=False): + """ + Check any array element along a given axis evaluate to True. + + Args: + axis (Union[None, int, tuple(int)): Dimensions of reduction. + Default: (), reduce all dimensions. + keep_dims (bool): Whether to keep the reduced dimensions. + Default : False, don't keep these reduced dimensions. + + Returns: + Tensor, has the same data type as x. + """ + + return tensor_operator_registry.get('any')(keep_dims)(self, axis) + + +class RowTensor: + """ + A sparse representation of a set of tensor slices at given indices. + + An RowTensor is typically used to represent a subset of a larger + tensor dense of shape [L0, D1, .. , DN] where L0 >> D0. + + The values in indices are the indices in the first dimension of the slices + that have been extracted from the larger tensor. + + The dense tensor dense represented by an RowTensor slices has + `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`. + + RowTensor can only be used in the `Cell`'s contruct method. + + It is not supported in pynative mode at the moment. + + Args: + indices (Tensor): A 1-D integer Tensor of shape [D0]. + values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn]. + dense_shape (tuple): An integer tuple which contains the shape + of the corresponding dense tensor. + + Returns: + RowTensor, composed of `indices`, `values`, and `dense_shape`. + + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self, dense_shape): + >>> super(Net, self).__init__() + >>> self.dense_shape = dense_shape + >>> def construct(self, indices, values): + >>> x = RowTensor(indices, values, self.dense_shape) + >>> return x.values, x.indices, x.dense_shape + >>> + >>> indices = Tensor([0]) + >>> values = Tensor([[1, 2]], dtype=ms.float32) + >>> Net((3, 2))(indices, values) + """ + + def __init__(self, indices, values, dense_shape): + "Init RowTensor" + self.__indices = indices + self.__values = values + self.__dense_shape = dense_shape + + @property + def indices(self): + return self.__indices + + @property + def values(self): + return self.__values + + @property + def dense_shape(self): + return self.__dense_shape + + +class SparseTensor: + """ + A sparse representation of a set of nonzero elememts from a tensor at given indices. + + SparseTensor can only be used in the `Cell`'s construct method. + + Pynative mode not supported at the moment. + + For a tensor dense, its SparseTensor(indices, values, dense_shape) has + `dense[indices[i]] = values[i]`. + + Args: + indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`, + where N and ndims are the number of values and number of dimensions in + the SparseTensor, respectively. + values (Tensor): A 1-D tensor of any type and shape `[N]`, which + supplies the values for each element in indices. + dense_shape (tuple): A integer tuple of size `ndims`, + which specifies the dense_shape of the sparse tensor. + + Returns: + SparseTensor, composed of `indices`, `values`, and `dense_shape`. + + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self, dense_shape): + >>> super(Net, self).__init__() + >>> self.dense_shape = dense_shape + >>> def construct(self, indices, values): + >>> x = SparseTensor(indices, values, self.dense_shape) + >>> return x.values, x.indices, x.dense_shape + >>> + >>> indices = Tensor([[0, 1], [1, 2]]) + >>> values = Tensor([1, 2], dtype=ms.float32) + >>> Net((3, 4))(indices, values) + """ -class IndexedSlices: def __init__(self, indices, values, dense_shape): - raise NotImplementedError + "Init SparseTensor" + self.__indices = indices + self.__values = values + self.__dense_shape = dense_shape + + @property + def indices(self): + return self.__indices + + @property + def values(self): + return self.__values + + @property + def dense_shape(self): + return self.__dense_shape diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 5e1f7d06e7..920488cee4 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -113,6 +113,8 @@ def check_parameter_available(func): Wrapper. If not available, raise Error. """ def wrapper(*args, **kargs): + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + return func(*args, **kargs) group = None if "group" in kargs.keys(): group = kargs.get("group") @@ -161,10 +163,7 @@ def _get_rank_helper(group, backend): else: rank_id = hccl.get_rank_id(group) elif backend == Backend.NCCL: - if group == NCCL_WORLD_COMM_GROUP: - rank_id = mpi.get_rank_id() - else: - raise RuntimeError("Nccl doesn't support get_rank_id by user group now.") + rank_id = mpi.get_rank_id(group) else: raise ValueError("Invalid backend: '{}'".format(backend)) return rank_id @@ -223,10 +222,7 @@ def _get_size_helper(group, backend): else: size = hccl.get_rank_size(group) elif backend == Backend.NCCL: - if group == NCCL_WORLD_COMM_GROUP: - size = mpi.get_rank_size() - else: - raise RuntimeError("Nccl doesn't support get_rank_size by user group now.") + size = mpi.get_rank_size(group) else: raise ValueError("Invalid backend: '{}'".format(backend)) return size diff --git a/mindspore/context.py b/mindspore/context.py index 0de6084caf..eecdc291bf 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -234,22 +234,6 @@ class _Context: if not success: raise RuntimeError("Device id set failed!!!") - @property - def save_ms_model(self): - return self._context_handle.get_save_ms_model_flag() - - @save_ms_model.setter - def save_ms_model(self, save_ms_model_flag): - self._context_handle.set_save_ms_model_flag(save_ms_model_flag) - - @property - def save_ms_model_path(self): - return self._context_handle.get_save_ms_model_path() - - @save_ms_model_path.setter - def save_ms_model_path(self, save_ms_model_path): - self._context_handle.set_save_ms_model_path(save_ms_model_path) - @property def enable_auto_mixed_precision(self): return self._context_handle.get_auto_mixed_precision_flag() @@ -443,7 +427,7 @@ def _context(): @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, - strategy_ckpt_save_file=str, full_batch=bool) + strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) def set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -487,6 +471,9 @@ def set_auto_parallel_context(**kwargs): strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. + enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in + data parallel training in the benefit of time and memory saving. + Raises: ValueError: If input key is not attribute in auto parallel context. @@ -532,12 +519,13 @@ def reset_auto_parallel_context(): - parameter_broadcast: False. - strategy_ckpt_load_file: "". - strategy_ckpt_save_file: "". + - enable_parallel_optimizer: False. """ _reset_auto_parallel_context() @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, - save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, + save_graphs_path=str, enable_dump=bool, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, @@ -565,8 +553,6 @@ def set_context(**kwargs): device_id (int): Id of target device, the value must be in [0, device_num_per_host-1], while device_num_per_host should no more than 4096. Default: 0. save_graphs (bool): Whether to save graphs. Default: False. - save_ms_model (bool): Whether to save lite model converted by graph. Default: False. - save_ms_model_path (str): Path to save converted lite model. Default: "." save_graphs_path (str): Path to save graphs. Default: "." enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True. enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be @@ -611,7 +597,6 @@ def set_context(**kwargs): >>> context.set_context(device_id=0) >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms") >>> context.set_context(enable_reduce_precision=True) - >>> context.set_context(save_ms_model=True, save_ms_model_path=".") >>> context.set_context(enable_dump=True, save_dump_path=".") >>> context.set_context(reserve_class_name_in_scope=True) >>> context.set_context(variable_memory_max_size="6GB") diff --git a/mindspore/core/abstract/abstract_function.cc b/mindspore/core/abstract/abstract_function.cc new file mode 100644 index 0000000000..2d46862af1 --- /dev/null +++ b/mindspore/core/abstract/abstract_function.cc @@ -0,0 +1,337 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/abstract_function.h" + +#include + +namespace mindspore { +namespace abstract { +class Evaluator; +class AnalysisEngine; + +AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { + if (func_list.size() == 1) { + return func_list[0]; + } + return std::make_shared(func_list); +} + +AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { + auto this_func = shared_from_base(); + if (other->isa()) { + if (*this_func == *other) { + return this_func; + } + return std::make_shared(this_func, other); + } + auto other_union = dyn_cast(other); + if (other_union->IsSuperSet(this_func)) { + return other; + } + return std::make_shared(this_func, other); +} + +void AbstractFuncAtom::Visit(std::function visit_func) const { + visit_func(const_cast(this)->shared_from_base()); +} + +bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; } + +AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; } + +AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) { + AbstractFuncAtomPtrList new_func_list; + auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); }; + + first->Visit(build_func_list); + second->Visit(build_func_list); + func_list_ = new_func_list; +} + +std::string AbstractFuncUnion::ToString() const { + std::ostringstream buffer; + buffer << "AbstractFuncUnion({"; + int i = 0; + for (const auto &func : func_list_) { + MS_EXCEPTION_IF_NULL(func); + buffer << "[" << i << "]: " << func->ToString() << ", "; + i++; + } + buffer << "})"; + return buffer.str(); +} + +bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) { + MS_EXCEPTION_IF_NULL(other); + std::vector is_in_list; + auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) { + auto iter = find(func_list_.begin(), func_list_.end(), func); + if (iter == func_list_.end()) { + is_in_list.push_back(false); + } + return true; + }; + other->Visit(build_in_list); + return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; }); +} + +AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { + auto this_func = shared_from_base(); + if (other->isa()) { + if (IsSuperSet(other)) { + return this_func; + } + return std::make_shared(this_func, other); + } + auto other_union = dyn_cast(other); + if (other_union->IsSuperSet(this_func)) { + return other; + } + return std::make_shared(this_func, other); +} + +void AbstractFuncUnion::Visit(std::function visit_func) const { + for (AbstractFuncAtomPtr poss : func_list_) { + visit_func(poss); + } +} + +bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_union = static_cast(&other); + if (func_list_.size() != other_union->func_list_.size()) { + return false; + } + if (func_list_ == other_union->func_list_) { + return true; + } + return false; +} + +std::size_t AbstractFuncUnion::hash() const { + std::size_t hash_sum = 0; + for (auto f : func_list_) { + hash_sum = hash_combine(hash_sum, f->hash()); + } + return hash_sum; +} + +bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_prim = static_cast(&other); + if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { + return true; + } + return false; +} + +std::size_t PrimitiveAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), prim_->hash()); + // Keep in sync with operator==() which compares the prim_ pointer; + hash_value = hash_combine(hash_value, std::hash{}(prim_.get())); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } + return hash_value; +} + +bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_fg = static_cast(&other); + if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && + tracking_id() == other_fg->tracking_id()) { + return true; + } + return false; +} + +std::size_t FuncGraphAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), func_graph_->hash()); + hash_value = hash_combine(hash_value, context_->hash()); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } + return hash_value; +} +std::string FuncGraphAbstractClosure::ToString() const { + std::stringstream ss; + ss << "FuncGraphAbstractClosure: " + << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); + return ss.str(); +} + +bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_meta_fg = static_cast(&other); + if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) { + return true; + } + return false; +} + +std::size_t MetaFuncGraphAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); + if (tracking_id() != nullptr) { + hash_value = hash_combine(hash_value, tracking_id()->hash()); + } + return hash_value; +} + +std::string MetaFuncGraphAbstractClosure::ToString() const { + return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name(); +} + +bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_partial = static_cast(&other); + if (fn_ != other_partial->fn_) { + return false; + } + if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_partial->args_spec_list_) { + return true; + } + return false; +} + +std::size_t PartialAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), fn_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string PartialAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; + for (auto arg : args_spec_list_) { + buffer << arg->ToString() << ", "; + } + buffer << "))"; + return buffer.str(); +} + +bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_transformed = static_cast(&other); + if (fn_ == other_transformed->fn_) { + return true; + } + return false; +} + +std::size_t JTransformedAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), fn_->hash()); + return hash_value; +} + +bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_virtual = static_cast(&other); + if (output_ != other_virtual->output_) { + return false; + } + if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_virtual->args_spec_list_) { + return true; + } + return false; +} + +std::size_t VirtualAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), output_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string VirtualAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "VirtualAbstractClosure(args: {"; + int i = 0; + for (const auto &arg : args_spec_list_) { + MS_EXCEPTION_IF_NULL(arg); + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + buffer << "}, output: " << output_->ToString() << ")"; + return buffer.str(); +} + +bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_typed = static_cast(&other); + if (output_ != other_typed->output_) { + return false; + } + if (prim_ != other_typed->prim_) { + return false; + } + if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_typed->args_spec_list_) { + return true; + } + return false; +} + +std::size_t TypedPrimitiveAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), prim_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string TypedPrimitiveAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {"; + int i = 0; + for (const auto &arg : args_spec_list_) { + MS_EXCEPTION_IF_NULL(arg); + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + buffer << "}, output: " << output_->ToString() << ")"; + return buffer.str(); +} + +bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + return true; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h new file mode 100644 index 0000000000..79a1d6c1d7 --- /dev/null +++ b/mindspore/core/abstract/abstract_function.h @@ -0,0 +1,301 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ +#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ + +#include +#include + +#include "abstract/abstract_value.h" +#include "abstract/analysis_context.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +namespace abstract { +class AbstractFuncAtom : public AbstractFunction { + public: + AbstractFuncAtom() = default; + ~AbstractFuncAtom() override = default; + MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) + + AbstractFunctionPtr GetUnique() override { return shared_from_base(); } + AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; + void Visit(std::function) const final; + bool operator==(const AbstractFunction &other) const override; + + std::size_t hash() const override { return tid(); } +}; + +class AbstractFuncUnion : public AbstractFunction { + public: + explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list); + AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second); + ~AbstractFuncUnion() override = default; + MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction) + + std::string ToString() const override; + + AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } + bool IsSuperSet(const AbstractFunctionPtr &other); + AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; + void Visit(std::function) const final; + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } + + private: + AbstractFuncAtomPtrList func_list_; +}; + +class PrimitiveAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Primitive. + // prim: The primitive + // tracking_id: Identifies different uses of the same primitive. + explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr) + : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {} + ~PrimitiveAbstractClosure() override = default; + MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) + + PrimitivePtr prim() { return prim_; } + + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + + void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } + + AbstractFunctionPtr Copy() const override { return std::make_shared(prim_, tracking_id()); } + + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override { return "Prim: " + prim_->name(); } + + private: + PrimitivePtr prim_; + // store it as weak_ptr to break reference cycle. + // one reference cycle example is Graph::set_output() input0 local variable. + AnfNodeWeakPtr tracking_id_; +}; +using PrimitiveAbstractClosurePtr = std::shared_ptr; + +class FuncGraphAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Graph in a certain Context. + // context: The context, or Context.empty() + FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const AnfNodePtr &tracking_id = nullptr) + : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(context); + } + ~FuncGraphAbstractClosure() override = default; + MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) + + FuncGraphPtr func_graph() { return func_graph_; } + + AnalysisContextPtr context() const override { return context_; } + + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + + AbstractFunctionPtr Copy() const override { + return std::make_shared(func_graph_, context_, tracking_id()); + } + + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + FuncGraphPtr func_graph_; + AnalysisContextPtr context_; + // To discriminate different usage of same graph by using this tracking_id, + // so different tracking_id will produce different FuncGraphAbstractClosure, + // different FuncGraphEvaluator. + // Espcecially usefull for recursive func graph call, so it will not mess up + // the graph_context_ in FuncGraphEvaluator. + // Notes: Be careful to use nullptr for this variable. + // store it as weak_ptr to break reference cycle. + AnfNodeWeakPtr tracking_id_; +}; +using FuncGraphAbstractClosurePtr = std::shared_ptr; + +class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { + public: + explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, + const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope) + : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {} + ~MetaFuncGraphAbstractClosure() override = default; + MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) + + MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; } + + AnalysisContextPtr context() const override { return kDummyAnalysisContext; } + + ScopePtr GetScope() { return scope_; } + + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + + AbstractFunctionPtr Copy() const override { + return std::make_shared(meta_func_graph_, tracking_id()); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + MetaFuncGraphPtr meta_func_graph_; + // refer the comment in FuncGraphAbstractClosure; + // store it as weak_ptr to break reference cycle. + AnfNodeWeakPtr tracking_id_; + ScopePtr scope_; +}; +using MetaFuncGraphAbstractClosurePtr = std::shared_ptr; + +class PartialAbstractClosure : public AbstractFuncAtom { + public: + // Represents a partial application. + // args_spec_list: The first few arguments of that function + PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, + const AnfNodePtr &node = nullptr) + : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} + ~PartialAbstractClosure() override = default; + MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) + + AbstractFunctionPtr fn() { return fn_; } + AbstractBasePtrList args() { return args_spec_list_; } + AnfNodePtr node() { return node_.lock(); } + void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } + AbstractFunctionPtr Copy() const override { + return std::make_shared(fn_, args_spec_list_, node_.lock()); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + AbstractFuncAtomPtr fn_; + AbstractBasePtrList args_spec_list_; + // The CNode which this PartialAbstractClosure evaluated from. + AnfNodeWeakPtr node_; +}; + +class JTransformedAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Function transformed through the application of J. + explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} + ~JTransformedAbstractClosure() override = default; + MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) + + AbstractFuncAtomPtr fn() { return fn_; } + AbstractFunctionPtr Copy() const override { return std::make_shared(fn_); } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override { return "J(" + fn_->ToString() + ")"; } + + private: + AbstractFuncAtomPtr fn_; +}; + +class VirtualAbstractClosure : public AbstractFuncAtom { + public: + // Represents some function with an explicitly fixed type signature. + // args_spec_list: The arguments as abstract value given to the function + // output: The output which is abstract value. + VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec) + : args_spec_list_(args_spec_list), output_(output_spec) {} + VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec) + : args_spec_list_({args_spec}), output_(output_spec) {} + ~VirtualAbstractClosure() override = default; + MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) + + AbstractBasePtrList args_spec_list() { return args_spec_list_; } + + AbstractBasePtr output() { return output_; } + AbstractFunctionPtr Copy() const override { + return std::make_shared(args_spec_list_, output_); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + AbstractBasePtrList args_spec_list_; + AbstractBasePtr output_; +}; +using VirtualAbstractClosurePtr = std::shared_ptr; + +class TypedPrimitiveAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Primitive with an explicitly fixed type signature. + // args_spec_list: The arguments as abstract value given to the Primitive + // output: The output which is abstract value. + TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list, + const AbstractBasePtr &output_spec) + : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {} + ~TypedPrimitiveAbstractClosure() override = default; + MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) + + PrimitivePtr prim() { return prim_; } + AbstractBasePtrList args_spec_list() { return args_spec_list_; } + AbstractBasePtr output() { return output_; } + AbstractFunctionPtr Copy() const override { + return std::make_shared(prim_, args_spec_list_, output_); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + PrimitivePtr prim_; + AbstractBasePtrList args_spec_list_; + AbstractBasePtr output_; +}; + +// Represents a function that can't be called. +class DummyAbstractClosure : public AbstractFuncAtom { + public: + DummyAbstractClosure() = default; + ~DummyAbstractClosure() override = default; + MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) + + AbstractFunctionPtr Copy() const override { return std::make_shared(); } + bool operator==(const AbstractFunction &other) const override; + + std::string ToString() const override { return "DummyAbstractClosure()"; } +}; + +struct AbstractFunctionHasher { + std::size_t operator()(const AbstractFunctionPtr &t) const { + std::size_t hash = t->hash(); + return hash; + } +}; + +struct AbstractFunctionEqual { + bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; } +}; +} // namespace abstract +} // namespace mindspore +#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 7bef3829a6..efdf12452b 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -38,9 +38,24 @@ bool AbstractBase::operator==(const AbstractBase &other) const { << this->ToString() << ", other: " << other.ToString(); } - bool value_equal = *value_ == *other.value_; - bool type_equal = *type_ == *other.type_; - bool shape_equal = *shape_ == *other.shape_; + bool value_equal = false; + if (value_ == other.value_) { + value_equal = true; + } else if (*value_ == *other.value_) { + value_equal = true; + } + bool type_equal = false; + if (type_ == other.type_) { + type_equal = true; + } else if (*type_ == *other.type_) { + type_equal = true; + } + bool shape_equal = false; + if (shape_ == other.shape_) { + shape_equal = true; + } else if (*shape_ == *other.shape_) { + shape_equal = true; + } return value_equal && type_equal && shape_equal; } @@ -1035,16 +1050,75 @@ bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const return AbstractBasePtrListDeepEqual(lhs, rhs); } -// IndexedSlices -TypePtr AbstractIndexedSlices::BuildType() const { +// RowTensor +TypePtr AbstractRowTensor::BuildType() const { + MS_EXCEPTION_IF_NULL(element()); + TypePtr element_type = element()->BuildType(); + return std::make_shared(element_type); +} + +AbstractBasePtr AbstractRowTensor::Clone() const { + MS_EXCEPTION_IF_NULL(element()); + auto clone = std::make_shared(element()->Clone()); + ShapePtr shp = shape(); + clone->set_shape(shp->Clone()); + clone->set_value(GetValueTrack()); + clone->set_indices(indices_->Clone()->cast()); + clone->set_values(values_->Clone()->cast()); + clone->set_dense_shape(dense_shape_->Clone()->cast()); + return clone; +} + +AbstractBasePtr AbstractRowTensor::Broaden() const { + MS_EXCEPTION_IF_NULL(element()); + auto broaden = std::make_shared(element()->Broaden()); + auto shp = shape(); + broaden->set_shape(shp->Clone()); + broaden->set_value(kAnyValue); + broaden->set_indices(indices_->Clone()->cast()); + broaden->set_values(values_->Clone()->cast()); + broaden->set_dense_shape(dense_shape_->Clone()->cast()); + return broaden; +} + +AbstractBasePtr AbstractRowTensor::BroadenWithShape() const { + MS_EXCEPTION_IF_NULL(element()); + auto broaden = std::make_shared(element()->Broaden()); + auto shp = shape()->Clone(); + shp->Broaden(); + broaden->set_shape(shp); + broaden->set_value(kAnyValue); + broaden->set_indices(indices_->Clone()->cast()); + broaden->set_values(values_->Clone()->cast()); + broaden->set_dense_shape(dense_shape_->Clone()->cast()); + return broaden; +} + +std::string AbstractRowTensor::ToString() const { + std::ostringstream buffer; + BaseShapePtr shape_track = GetShapeTrack(); + MS_EXCEPTION_IF_NULL(shape_track); + MS_EXCEPTION_IF_NULL(element()); + auto value_track = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + buffer << type_name() << "(" + << "shape: " << shape_track->ToString() << ", element: " << element()->ToString() + << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")" + << ", indices: " << indices_->ToString() << ", values" << values_->ToString() + << ", dense_shape: " << dense_shape_->ToString(); + return buffer.str(); +} + +// SparseTensor +TypePtr AbstractSparseTensor::BuildType() const { MS_EXCEPTION_IF_NULL(element()); TypePtr element_type = element()->BuildType(); - return std::make_shared(element_type); + return std::make_shared(element_type); } -AbstractBasePtr AbstractIndexedSlices::Clone() const { +AbstractBasePtr AbstractSparseTensor::Clone() const { MS_EXCEPTION_IF_NULL(element()); - auto clone = std::make_shared(element()->Clone()); + auto clone = std::make_shared(element()->Clone()); ShapePtr shp = shape(); clone->set_shape(shp->Clone()); clone->set_value(GetValueTrack()); @@ -1054,9 +1128,9 @@ AbstractBasePtr AbstractIndexedSlices::Clone() const { return clone; } -AbstractBasePtr AbstractIndexedSlices::Broaden() const { +AbstractBasePtr AbstractSparseTensor::Broaden() const { MS_EXCEPTION_IF_NULL(element()); - auto broaden = std::make_shared(element()->Broaden()); + auto broaden = std::make_shared(element()->Broaden()); auto shp = shape(); broaden->set_shape(shp->Clone()); broaden->set_value(kAnyValue); @@ -1066,9 +1140,9 @@ AbstractBasePtr AbstractIndexedSlices::Broaden() const { return broaden; } -AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { +AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const { MS_EXCEPTION_IF_NULL(element()); - auto broaden = std::make_shared(element()->Broaden()); + auto broaden = std::make_shared(element()->Broaden()); auto shp = shape()->Clone(); shp->Broaden(); broaden->set_shape(shp); @@ -1079,7 +1153,7 @@ AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { return broaden; } -std::string AbstractIndexedSlices::ToString() const { +std::string AbstractSparseTensor::ToString() const { std::ostringstream buffer; BaseShapePtr shape_track = GetShapeTrack(); MS_EXCEPTION_IF_NULL(shape_track); diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index d922f93e70..faf80c639b 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_ABSTRACT_ABSTRACT_VALUE_H_ -#define MINDSPORE_CCSRC_ABSTRACT_ABSTRACT_VALUE_H_ +#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ +#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ #include #include @@ -27,6 +27,7 @@ #include "utils/log_adapter.h" #include "utils/hashing.h" +#include "utils/any.h" #include "base/base.h" #include "ir/dtype.h" #include "ir/value.h" @@ -193,7 +194,6 @@ class AbstractFunction : public AbstractBase { static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); - virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0; virtual AnfNodePtr tracking_id() const { return nullptr; } virtual void set_tracking_id(AnfNodePtr) {} virtual AnalysisContextPtr context() const { return nullptr; } @@ -593,21 +593,50 @@ struct AbstractBasePtrListEqual { std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); -// IndexedSlices -class AbstractIndexedSlices : public AbstractUndetermined { +// RowTensor +class AbstractRowTensor : public AbstractUndetermined { public: - explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) : AbstractUndetermined(element, shape) {} - AbstractIndexedSlices(const TypePtr &element_type, const std::vector &shape) + AbstractRowTensor(const TypePtr &element_type, const std::vector &shape) : AbstractUndetermined(element_type, shape) {} - ~AbstractIndexedSlices() override = default; - MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined) + ~AbstractRowTensor() override = default; + MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined) const AbstractTensorPtr indices() const { return indices_; } + void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } const AbstractTensorPtr values() const { return values_; } + void set_values(const AbstractTensorPtr &values) { values_ = values; } const AbstractTuplePtr dense_shape() const { return dense_shape_; } + void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } + TypePtr BuildType() const override; + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + AbstractBasePtr BroadenWithShape() const; + + std::string ToString() const override; + + private: + AbstractTensorPtr indices_; + AbstractTensorPtr values_; + AbstractTuplePtr dense_shape_; +}; + +// SparseTensor +class AbstractSparseTensor : public AbstractUndetermined { + public: + explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + : AbstractUndetermined(element, shape) {} + AbstractSparseTensor(const TypePtr &element_type, const std::vector &shape) + : AbstractUndetermined(element_type, shape) {} + ~AbstractSparseTensor() override = default; + MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined) + + const AbstractTensorPtr indices() const { return indices_; } void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } + const AbstractTensorPtr values() const { return values_; } void set_values(const AbstractTensorPtr &values) { values_ = values; } + const AbstractTuplePtr dense_shape() const { return dense_shape_; } void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } TypePtr BuildType() const override; AbstractBasePtr Clone() const override; @@ -623,4 +652,4 @@ class AbstractIndexedSlices : public AbstractUndetermined { }; } // namespace abstract } // namespace mindspore -#endif // MINDSPORE_CCSRC_ABSTRACT_ABSTRACT_VALUE_H_ +#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ diff --git a/mindspore/core/abstract/analysis_context.cc b/mindspore/core/abstract/analysis_context.cc index 1ae6125838..228ddf0f54 100644 --- a/mindspore/core/abstract/analysis_context.cc +++ b/mindspore/core/abstract/analysis_context.cc @@ -19,7 +19,7 @@ #include #include "utils/symbolic.h" -#include "debug/trace.h" +#include "utils/trace_base.h" namespace mindspore { namespace abstract { diff --git a/mindspore/core/abstract/analysis_context.h b/mindspore/core/abstract/analysis_context.h index c0293d7e91..40161d7df1 100644 --- a/mindspore/core/abstract/analysis_context.h +++ b/mindspore/core/abstract/analysis_context.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_ABSTRACT_ANALYSIS_CONTEXT_H_ -#define MINDSPORE_CCSRC_ABSTRACT_ANALYSIS_CONTEXT_H_ +#ifndef MINDSPORE_CORE_ABSTRACT_ANALYSIS_CONTEXT_H_ +#define MINDSPORE_CORE_ABSTRACT_ANALYSIS_CONTEXT_H_ #include #include @@ -85,4 +85,4 @@ struct ContextEqual { extern const AnalysisContextPtr kDummyAnalysisContext; } // namespace abstract } // namespace mindspore -#endif // MINDSPORE_CCSRC_ABSTRACT_ANALYSIS_CONTEXT_H_ +#endif // MINDSPORE_CORE_ABSTRACT_ANALYSIS_CONTEXT_H_ diff --git a/mindspore/core/abstract/dshape.cc b/mindspore/core/abstract/dshape.cc index 74ea1ff7bf..a2cbe0fe62 100644 --- a/mindspore/core/abstract/dshape.cc +++ b/mindspore/core/abstract/dshape.cc @@ -67,6 +67,9 @@ std::string Shape::DumpText() const { buffer << "["; for (size_t i = 0; i < shape_.size(); i++) { buffer << (i > 0 ? ", " : "") << shape_[i]; + if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) { + buffer << "_" << min_shape_[i] << "^" << max_shape_[i]; + } } buffer << "]"; return buffer.str(); diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h index b9b8e93292..4197f73ac0 100644 --- a/mindspore/core/abstract/dshape.h +++ b/mindspore/core/abstract/dshape.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_ABSTRACT_DSHAPE_H_ -#define MINDSPORE_CCSRC_ABSTRACT_DSHAPE_H_ +#ifndef MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ +#define MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ #include #include @@ -25,6 +25,7 @@ #include #include #include +#include #include "utils/log_adapter.h" #include "base/base.h" @@ -63,17 +64,32 @@ class Shape : public BaseShape { static const int SHP_ANY = -1; Shape() : shape_() {} Shape(const std::initializer_list &list) : shape_(list) {} + Shape(const std::initializer_list &list) { + std::vector list_in(list); + (void)std::transform(list_in.begin(), list_in.end(), std::back_inserter(shape_), + [](const int64_t &value) { return static_cast(value); }); + } explicit Shape(const std::vector &list) : shape_(list) {} + explicit Shape(const std::vector &list) { + (void)std::transform(list.begin(), list.end(), std::back_inserter(shape_), + [](const int64_t &value) { return static_cast(value); }); + } + Shape(const std::vector &list, const std::vector &min_shape, const std::vector &max_shape) + : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {} ~Shape() override = default; MS_DECLARE_PARENT(Shape, BaseShape) std::string ToString() const override; std::string DumpText() const override; bool operator==(const BaseShape &other) const override; - BaseShapePtr Clone() const override { return std::make_shared(shape_); } + BaseShapePtr Clone() const override { return std::make_shared(shape_, min_shape_, max_shape_); } void Broaden() override; std::vector &shape() { return shape_; } + std::vector &min_shape() { return min_shape_; } + std::vector &max_shape() { return max_shape_; } - std::vector shape_; // use SHP_ANY to implement the any shape in python + std::vector shape_; // use SHP_ANY to implement the any shape in python + std::vector min_shape_; // record mininum length for each dynamic dimention + std::vector max_shape_; // record maximum length for each dynamic dimention }; using ShapePtr = std::shared_ptr; using ShapePtrList = std::vector; @@ -132,4 +148,4 @@ using ListShapePtr = std::shared_ptr; } // namespace abstract } // namespace mindspore -#endif // MINDSPORE_CCSRC_ABSTRACT_DSHAPE_H_ +#endif // MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h index 434235abda..0b49a7c30f 100644 --- a/mindspore/core/abstract/param_validator.h +++ b/mindspore/core/abstract/param_validator.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_ABSTRACT_PARAM_VALIDATOR_H_ -#define MINDSPORE_CCSRC_ABSTRACT_PARAM_VALIDATOR_H_ +#ifndef MINDSPORE_CORE_ABSTRACT_PARAM_VALIDATOR_H_ +#define MINDSPORE_CORE_ABSTRACT_PARAM_VALIDATOR_H_ #include #include @@ -66,7 +66,8 @@ ABSTRACT_REPORT_NAME_TRAITS(Function) ABSTRACT_REPORT_NAME_TRAITS(Type) ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) ABSTRACT_REPORT_NAME_TRAITS(Class) -ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices) +ABSTRACT_REPORT_NAME_TRAITS(RowTensor) +ABSTRACT_REPORT_NAME_TRAITS(SparseTensor) ABSTRACT_REPORT_NAME_TRAITS(Sequeue) template @@ -97,4 +98,4 @@ void CheckArgsSpec(const AbstractBasePtrList &args_list) { } // namespace abstract } // namespace mindspore -#endif // MINDSPORE_CCSRC_ABSTRACT_PARAM_VALIDATOR_H_ +#endif // MINDSPORE_CORE_ABSTRACT_PARAM_VALIDATOR_H_ diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 16497c74a9..20eeab0de5 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { return shape1; } std::vector dims; + bool has_dynamic_shape = false; dims.resize(shape1->shape().size()); for (std::size_t i = 0; i < shape1->shape().size(); i++) { if (shape1->shape()[i] == shape2->shape()[i]) { dims[i] = shape1->shape()[i]; + if (shape1->shape()[i] == Shape::SHP_ANY) { + has_dynamic_shape = true; + } } else { dims[i] = Shape::SHP_ANY; + has_dynamic_shape = true; } } - return std::make_shared(dims); + if (!has_dynamic_shape) { + return std::make_shared(dims); + } + // calculate dynamic shape + std::vector min_dims(dims.size()); + std::vector max_dims(dims.size()); + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] != Shape::SHP_ANY) { + min_dims[i] = max_dims[i] = dims[i]; + continue; + } + if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) { + min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]); + max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]); + continue; + } + if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) { + if (shape1->min_shape().empty() || shape1->max_shape().empty()) { + MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString() + << " has dynamic shape, but does not have min/max shape info."; + } + min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]); + max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]); + continue; + } + if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) { + if (shape2->min_shape().empty() || shape2->max_shape().empty()) { + MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString() + << " has dynamic shape, but does not have min/max shape info."; + } + min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]); + max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]); + continue; + } + // both shapes contains dynamic shape + if (shape1->min_shape().empty() || shape1->max_shape().empty()) { + MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString() + << " has dynamic shape, but does not have min/max shape info."; + } + if (shape2->min_shape().empty() || shape2->max_shape().empty()) { + MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString() + << " has dynamic shape, but does not have min/max shape info."; + } + min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]); + max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]); + } + return std::make_shared(dims, min_dims, max_dims); } AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) { diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h index be38ae860d..75ba63aa0b 100644 --- a/mindspore/core/abstract/utils.h +++ b/mindspore/core/abstract/utils.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_ABSTRACT_UTILS_H_ -#define MINDSPORE_CCSRC_ABSTRACT_UTILS_H_ +#ifndef MINDSPORE_CORE_ABSTRACT_UTILS_H_ +#define MINDSPORE_CORE_ABSTRACT_UTILS_H_ #include #include @@ -26,7 +26,6 @@ #include "abstract/abstract_value.h" #include "utils/any.h" #include "utils/misc.h" -#include "utils/convert_utils.h" namespace mindspore { namespace abstract { @@ -53,4 +52,4 @@ int GetPositiveAxis(int axis_value, size_t increment); ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); } // namespace abstract } // namespace mindspore -#endif // MINDSPORE_CCSRC_ABSTRACT_UTILS_H_ +#endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_ diff --git a/mindspore/core/base/base.h b/mindspore/core/base/base.h index 8e1a447c0d..e43b042cfa 100644 --- a/mindspore/core/base/base.h +++ b/mindspore/core/base/base.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BASE_BASE_H_ -#define MINDSPORE_CCSRC_BASE_BASE_H_ +#ifndef MINDSPORE_CORE_BASE_BASE_H_ +#define MINDSPORE_CORE_BASE_BASE_H_ #include #include @@ -131,6 +131,11 @@ class AnfNode; using AnfNodePtr = std::shared_ptr; using AnfNodePtrList = std::vector; using AnfNodeSet = OrderedSet; +using AnfNodeWeakPtr = std::weak_ptr; + +class FuncGraph; +using FuncGraphPtr = std::shared_ptr; +using FuncGraphWeakPtr = std::weak_ptr; namespace abstract { class AbstractBase; @@ -149,4 +154,4 @@ struct MS_EXPORT TypeIdManager { }; } // namespace mindspore -#endif // MINDSPORE_CCSRC_BASE_BASE_H_ +#endif // MINDSPORE_CORE_BASE_BASE_H_ diff --git a/mindspore/core/base/base_ref.cc b/mindspore/core/base/base_ref.cc new file mode 100644 index 0000000000..d7d6ea747c --- /dev/null +++ b/mindspore/core/base/base_ref.cc @@ -0,0 +1,188 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "base/base_ref.h" + +namespace mindspore { +iterator ConstIteratorCast(std::vector *v, const const_iterator iter) { + return std::next(v->begin(), std::distance(v->cbegin(), iter)); +} + +BaseRef::BaseRef(const BaseRef &other) : Base(other), m_ptr(other.m_ptr) { + if (!m_ptr) { + m_ptr = other.copy(); + } +} + +bool BaseRef::operator==(const BaseRef &other) const { + if (m_ptr == other.m_ptr) { + return true; + } + if (m_ptr == nullptr && other.m_ptr == nullptr) { + return *this == other; + } + if (m_ptr == nullptr || other.m_ptr == nullptr) { + return false; + } + if (type() != other.type()) { + MS_LOG(DEBUG) << "Type mismatch"; + return false; + } + if (m_ptr->isa()) { + return *(m_ptr->cast()) == *(other.m_ptr->cast()); + } + + // for noderef equal + if (m_ptr->isa()) { + return *std::static_pointer_cast(m_ptr) == *std::static_pointer_cast(other.m_ptr); + } + + // for node equal + return *m_ptr == *other.m_ptr; +} + +// left reference +BaseRef &BaseRef::operator=(const BaseRef &other) { + if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { + return *this; + } + m_ptr = other.copy(); + return *this; +} + +// right reference +BaseRef &BaseRef::operator=(BaseRef &&other) { + if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { + return *this; + } + m_ptr = other.copy(); + other.m_ptr = nullptr; + return *this; +} + +std::string BaseRef::ToString() const { + if (m_ptr != nullptr) { + return std::string(m_ptr->type_name()) + std::string(" value:") + m_ptr->ToString(); + } + return std::string(); +} + +uint32_t BaseRef::type() const { + if (m_ptr != nullptr) { + return m_ptr->tid(); + } + return tid(); +} + +// left reference +SetRef &SetRef::operator=(const SetRef &other) { + if (elements_ == other.elements_ || this == &other) { + return *this; + } + elements_ = other.elements_; + return *this; +} + +std::string SetRef::ToString() const { + std::ostringstream buffer; + bool begin = true; + buffer << "set["; + for (auto &attr : elements_) { + if (!begin) { + buffer << ", "; + } else { + begin = false; + } + buffer << attr.ToString(); + } + buffer << "]"; + return buffer.str(); +} + +// left reference +VectorRef &VectorRef::operator=(const VectorRef &other) { + if (elements_ == other.elements_ || this == &other) { + return *this; + } + elements_ = other.elements_; + return *this; +} + +std::string VectorRef::ToString() const { + std::ostringstream buffer; + bool begin = true; + buffer << "vector["; + for (auto &attr : elements_) { + if (!begin) { + buffer << ", "; + } else { + begin = false; + } + buffer << attr.ToString(); + } + buffer << "]"; + return buffer.str(); +} + +bool VectorRef::operator==(const BaseRef &other) const { + if (!utils::isa(other)) { + return false; + } + return *this == utils::cast(other); +} + +bool VectorRef::operator==(const VectorRef &other) const { + if (elements_.size() != other.elements_.size()) { + return false; + } + for (size_t i = 0; i < elements_.size(); ++i) { + if (elements_[i] != other.elements_[i]) { + return false; + } + } + return true; +} + +bool SetRef::operator==(const BaseRef &other) const { + if (!utils::isa(other)) { + return false; + } + return *this == utils::cast(other); +} + +bool SetRef::operator==(const SetRef &other) const { + if (elements_.size() != other.elements_.size()) { + return false; + } + auto iter = elements_.begin(); + auto oth_iter = other.elements_.begin(); + for (; iter != elements_.end(); iter++, oth_iter++) { + if (*iter != *oth_iter) { + return false; + } + } + return true; +} + +bool RunFunctionRef::operator==(const BaseRef &other) const { + if (!utils::isa(other)) { + return false; + } + return *this == utils::cast(other); +} + +bool RunFunctionRef::operator==(const RunFunctionRef &other) const { return func_ == other.func_; } +} // namespace mindspore diff --git a/mindspore/core/base/base_ref.h b/mindspore/core/base/base_ref.h new file mode 100644 index 0000000000..05157a4020 --- /dev/null +++ b/mindspore/core/base/base_ref.h @@ -0,0 +1,381 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_UTILS_BASE_REF_H_ +#define MINDSPORE_CORE_UTILS_BASE_REF_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/value.h" + +namespace mindspore { +class BaseRef; +class VectorRef; +class SetRef; +class RunFunctionRef; + +using iterator = std::vector::iterator; +using const_iterator = std::vector::const_iterator; +using const_reverse_iterator = std::vector::const_reverse_iterator; + +using RunFunc = std::function; +using RunFuncPtr = std::shared_ptr; + +template +using remove_reference_t = typename std::remove_reference::type; +template +using remove_const_t = typename std::remove_const::type; +template +using is_base = std::is_base_of>; +template +using is_value = std::is_base_of>; +template +using is_base_ref = std::is_base_of>; + +iterator ConstIteratorCast(std::vector *v, const_iterator iter); + +inline std::shared_ptr MakeNode(const std::vector &elements) { + return std::make_shared(elements); +} + +inline std::shared_ptr MakeNode(std::initializer_list elements) { + return std::make_shared(elements); +} + +// Anfnode, Funcgraph and some not value node class +template >::value && is_base::value, + int>::type = 0> +inline BasePtr MakeNode(const T &v) { + return v; +} + +template >::value && !is_base_ref::value, int>::type = 0> +inline BasePtr MakeNode(const T &v) { + return MakeValue(v); +} + +inline std::shared_ptr MakeNode(const VectorRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const AnfNodePtrList &a) { + std::vector ret; + (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; }); + return std::make_shared(ret); +} +inline std::shared_ptr MakeNode(const SetRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const RunFuncPtr &a) { return std::make_shared(a); } + +class BaseRef : public Base { + public: + BaseRef() : m_ptr(nullptr) {} + BaseRef(const BaseRef &other); + virtual std::shared_ptr copy() const { return m_ptr; } + + BaseRef(BaseRef &&other) : Base(other) { + m_ptr = other.m_ptr; + other.m_ptr = nullptr; + } + + // right reference constructor + template ::type, BaseRef>::value, T>::type> + BaseRef(T &&t) { // NOLINT + m_ptr = MakeNode(t); + } + + ~BaseRef() override { m_ptr = nullptr; } + + MS_DECLARE_PARENT(BaseRef, Base) + + bool operator!=(const BaseRef &other) const { return !(operator==(other)); } + + virtual bool operator==(const BaseRef &other) const; + + // left reference + virtual BaseRef &operator=(const BaseRef &other); + // right reference + virtual BaseRef &operator=(BaseRef &&other); + + std::size_t hash() const override { + if (m_ptr == nullptr) { + MS_LOG(ERROR) << "Invalid m_ptr"; + return 0; + } + return m_ptr->hash(); + } + + std::string ToString() const override; + + bool is_null() const { return m_ptr == nullptr; } + + virtual uint32_t type() const; + + BasePtr m_ptr; // point to real data +}; +using BaseRefPtr = std::shared_ptr; + +struct BaseRefHash { + std::size_t operator()(const BaseRef &c) const { return c.hash(); } +}; + +struct BaseRefLess { + bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); } +}; + +namespace utils { +// judge isa relation +// examples: isa(handle), isa(handle) +template ::value && !is_base_ref::value, int>::type = 0> +bool isa(const BaseRef &handle) { + if (!handle.m_ptr) { + return false; + } + return handle.m_ptr->isa(); +} + +// noderef isa ptr isa(x) or isa() +template ::value, typename T::element_type>::type, + typename std::enable_if::value || is_base_ref::value, int>::type = 0> +bool isa(const BaseRef &handle) { + if (handle.m_ptr == nullptr) { + return typeid(handle.m_ptr) == typeid(T); + } + + if (handle.m_ptr->isa()) { + return true; + } + + // constptr isa can be true + return std::dynamic_pointer_cast(handle.m_ptr) != nullptr; +} + +// isa(handle) +template ::type::element_type> +bool isa(const BaseRef &handle) { + if (handle.m_ptr == nullptr) { + return false; + } + return handle.m_ptr->isa(); +} + +// isa(handle), judge reference or ptr +template ::value, int>::type = 0> +bool isa(const BaseRef &handle) { + static const uint32_t tid = Base::GetTypeId(typeid(T).name()); + return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa()); +} + +// valueref -> C++ type +// cast(handle) +template ::value && !is_shared_ptr::value, int>::type = 0> +T cast(const BaseRef &handle) { + T ret = GetValue(std::static_pointer_cast(handle.m_ptr)); + return std::move(ret); +} + +// valueref -> valueref type +// cast(handle) +template ::value, int>::type = 0> +const T &cast(const BaseRef &handle) { + if (handle.m_ptr) { + return static_cast(*handle.m_ptr); + } + + return std::move(static_cast(handle)); +} + +// valueref -> nodeptr type +// cast(handle) +template ::value, typename T::element_type>::type, + typename std::enable_if::value && std::is_base_of::value, + int>::type = 0> +T cast(const BaseRef &handle) { + if (!handle.m_ptr) { + MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null"; + } + + auto m = handle.m_ptr->cast(); + if (nullptr != m) { + return m; + } + return std::static_pointer_cast(handle.m_ptr); +} +} // namespace utils + +class VectorRef : public BaseRef { + public: + using value_type = BaseRef; + + VectorRef() {} + explicit VectorRef(const std::vector &elements) : elements_(elements) {} + VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} + + // left reference + virtual VectorRef &operator=(const VectorRef &other); + + ~VectorRef() override = default; + + std::shared_ptr copy() const override { return std::make_shared(elements_); } + + bool empty() const { return (elements_.size() == 0); } + + std::size_t size() const { return elements_.size(); } + MS_DECLARE_PARENT(VectorRef, BaseRef) + + const BaseRef &operator[](const std::size_t &dim) const { + if (dim >= size()) { + MS_LOG(EXCEPTION) << "Out of the size of the tuple."; + } + return elements_[dim]; + } + + BaseRef &operator[](const std::size_t &dim) { + if (dim >= size()) { + MS_LOG(EXCEPTION) << "Out of the size of the tuple."; + } + return elements_[dim]; + } + + uint32_t type() const override { return tid(); } + std::string ToString() const override; + std::vector &elements() { return elements_; } + void clear() { elements_.clear(); } + + bool operator==(const BaseRef &other) const override; + bool operator==(const VectorRef &other) const; + + void push_back(const BaseRef &value) { elements_.push_back(value); } + void push_back(BaseRef &&value) { elements_.push_back(value); } + + void emplace_back(const BaseRef &value) { elements_.emplace_back(value); } + void emplace_back(BaseRef &&value) { elements_.emplace_back(value); } + + template + void insert(const iterator pos, const InputIt first, const InputIt last) { + (void)elements_.insert(pos, first, last); + } + + template + void insert(const const_iterator cpos, const InputIt first, const InputIt last) { + auto pos = ConstIteratorCast(&elements_, cpos); + (void)elements_.insert(pos, first, last); + } + + const_iterator begin() const { return elements_.begin(); } + const_iterator end() const { return elements_.end(); } + + const_reverse_iterator rbegin() const { return elements_.rbegin(); } + const_reverse_iterator rend() const { return elements_.rend(); } + + iterator erase(const const_iterator cpos) { + auto pos = ConstIteratorCast(&elements_, cpos); + return elements_.erase(pos); + } + + iterator erase(const const_iterator cfirst, const const_iterator clast) { + auto first = ConstIteratorCast(&elements_, cfirst); + auto last = ConstIteratorCast(&elements_, clast); + return elements_.erase(first, last); + } + + std::size_t hash() const override { + std::stringstream buffer; + buffer << ToString(); + return std::hash()(buffer.str()); + } + + std::vector elements_; +}; + +using VectorRefPtr = std::shared_ptr; + +using set_iterator = std::set::iterator; +using const_set_iterator = std::set::const_iterator; + +struct VectorRefHash { + std::size_t operator()(const VectorRef &c) const { return c.hash(); } +}; + +class SetRef : public BaseRef { + public: + SetRef() {} + explicit SetRef(const std::set &elements) : elements_(elements) {} + SetRef(const std::initializer_list elements) : elements_(elements.begin(), elements.end()) {} + SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {} + + // left reference + virtual SetRef &operator=(const SetRef &other); + + bool operator==(const BaseRef &other) const override; + bool operator==(const SetRef &other) const; + + ~SetRef() override = default; + + std::shared_ptr copy() const override { return std::make_shared(elements_); } + + bool empty() const { return (elements_.size() == 0); } + + std::size_t size() const { return elements_.size(); } + MS_DECLARE_PARENT(SetRef, BaseRef) + + uint32_t type() const override { return tid(); } + std::string ToString() const override; + std::set &elements() { return elements_; } + void clear() { elements_.clear(); } + + void insert(const BaseRef &elem) { (void)elements_.insert(elem); } + + const_set_iterator begin() const { return elements_.begin(); } + const_set_iterator end() const { return elements_.end(); } + + template + void insert(const InputIt first, const InputIt last) { + (void)elements_.insert(first, last); + } + + std::size_t count(const BaseRef &elem) const { return elements_.count(elem); } + const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); } + + std::set elements_; +}; + +using SetRefPtr = std::shared_ptr; + +class RunFunctionRef : public BaseRef { + public: + RunFunctionRef() {} + explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {} + + ~RunFunctionRef() override = default; + MS_DECLARE_PARENT(RunFunctionRef, BaseRef) + + uint32_t type() const override { return tid(); } + std::string ToString() const override { return std::string("RunFunctionRef"); } + bool operator==(const BaseRef &other) const override; + bool operator==(const RunFunctionRef &other) const; + + RunFuncPtr func_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_BASE_REF_H_ diff --git a/mindspore/core/base/base_ref_utils.cc b/mindspore/core/base/base_ref_utils.cc new file mode 100644 index 0000000000..69051fa9fd --- /dev/null +++ b/mindspore/core/base/base_ref_utils.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "base/base_ref_utils.h" +#include +#include + +#include "include/infer_tensor.h" +#include "ir/tensor.h" + +namespace mindspore { +void IterateFindTensor(std::vector *msTensors, const VectorRef &ref_list) { + for (size_t i = 0; i < ref_list.size(); ++i) { + if (utils::isa(ref_list[i])) { + auto tensor_ptr = utils::cast>(ref_list[i]); + MS_EXCEPTION_IF_NULL(tensor_ptr); + msTensors->emplace_back(tensor_ptr); + } else if (utils::isa(ref_list[i])) { + auto ref_iter = utils::cast(ref_list[i]); + IterateFindTensor(msTensors, ref_iter); + } else { + MS_LOG(EXCEPTION) << "The output is not a tensor"; + } + } +} + +std::vector TransformVectorRefToMultiTensor(const VectorRef &base_ref) { + std::vector msTensors; + if (utils::isa(base_ref)) { + auto ref_list = utils::cast(base_ref); + IterateFindTensor(&msTensors, ref_list); + } else if (utils::isa(base_ref)) { + auto tensor_ptr = utils::cast>(base_ref); + MS_EXCEPTION_IF_NULL(tensor_ptr); + msTensors.emplace_back(tensor_ptr); + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } + return msTensors; +} +} // namespace mindspore diff --git a/mindspore/core/base/base_ref_utils.h b/mindspore/core/base/base_ref_utils.h new file mode 100644 index 0000000000..8ced9134b6 --- /dev/null +++ b/mindspore/core/base/base_ref_utils.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_BASE_BASE_REF_UTILS_H +#define MINDSPORE_CORE_BASE_BASE_REF_UTILS_H +#include +#include + +#include "include/infer_tensor.h" +#include "ir/tensor.h" +#include "base/base_ref.h" + +namespace mindspore { +std::vector TransformVectorRefToMultiTensor(const VectorRef &base_ref); +} // namespace mindspore +#endif // MINDSPORE_CORE_BASE_BASE_REF_UTILS_H diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h new file mode 100755 index 0000000000..a04b983a2d --- /dev/null +++ b/mindspore/core/base/core_ops.h @@ -0,0 +1,161 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_ +#define MINDSPORE_CORE_OPERATOR_OPS_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace prim { +// Maths +inline const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); +inline const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); +inline const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); +inline const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); +inline const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); +inline const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); +inline const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); +inline const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); +inline const PrimitivePtr kPrimReduceAny = std::make_shared("ReduceAny"); +inline const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); +inline const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); +inline const PrimitivePtr kPrimNeg = std::make_shared("Neg"); +inline const PrimitivePtr kPrimSub = std::make_shared("Sub"); +inline const PrimitivePtr kPrimMul = std::make_shared("Mul"); +inline const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); +inline const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); +inline const PrimitivePtr kPrimSquare = std::make_shared("Square"); +inline const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); +inline const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); +inline const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); +inline const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); +inline const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); +inline const PrimitivePtr kPrimPow = std::make_shared("Pow"); +inline const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); +inline const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); +inline const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); +inline const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); + +// Statements +inline const PrimitivePtr kPrimReturn = std::make_shared("return"); +inline const PrimitivePtr kPrimSwitch = std::make_shared("switch"); +inline const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); +inline const PrimitivePtr kPrimAssign = std::make_shared("Assign"); +inline const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); +inline const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); +inline const PrimitivePtr kPrimSelect = std::make_shared("Select"); +inline const PrimitivePtr kPrimCall = std::make_shared("call"); + +// Structures +inline const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); +inline const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); +inline const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); +inline const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); +inline const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); +inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); +inline const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); +inline const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); +inline const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); +inline const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); +inline const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); +inline const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); +inline const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); +inline const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); +inline const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); +inline const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); +inline const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); +inline const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); +inline const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); +inline const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); +inline const PrimitivePtr kPrimListLen = std::make_shared("list_len"); +inline const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); +inline const PrimitivePtr kPrimListMap = std::make_shared("list_map"); +inline const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); +inline const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); +inline const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); +inline const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); +inline const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); +inline const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); +inline const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); +inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); +inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); +inline const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); +inline const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); +inline const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); +inline const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); +inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); + +// Debug ops +inline const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); +inline const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); +inline const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); +inline const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); +inline const PrimitivePtr kPrimDebug = std::make_shared("Debug"); + +// Other miscellaneous +inline const PrimitivePtr kPrimJ = std::make_shared("J"); +inline const PrimitivePtr kPrimDepend = std::make_shared("Depend"); +inline const PrimitivePtr kPrimPartial = std::make_shared("Partial"); +inline const PrimitivePtr kPrimIdentity = std::make_shared("identity"); +inline const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); +inline const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); +inline const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); +inline const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); +inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); +inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); +inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); +inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); +inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); +inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); +inline const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); +inline const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); +inline const PrimitivePtr kPrimPrint = std::make_shared("Print"); +inline const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); +inline const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); +inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); +inline const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); +inline const PrimitivePtr kPrimIs_ = std::make_shared("is_"); +inline const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); +inline const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); +inline const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); +inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); +inline const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); +inline const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); + +class DoSignaturePrimitive : public Primitive { + public: + explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) + : Primitive("S-Prim-" + name), function_(function) {} + + ~DoSignaturePrimitive() override = default; + + MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) + + const ValuePtr function() const { return function_; } + + private: + ValuePtr function_; +}; +using DoSignaturePrimitivePtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPERATOR_OPS_H_ diff --git a/mindspore/core/base/user_data.h b/mindspore/core/base/user_data.h new file mode 100644 index 0000000000..6912d0767d --- /dev/null +++ b/mindspore/core/base/user_data.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_USER_DATA_H_ +#define MINDSPORE_CORE_USER_DATA_H_ + +#include +#include +#include + +namespace mindspore { +class UserData { + public: + template + void set(const std::string &key, const std::shared_ptr &value) { + if (value == nullptr) { + data_.erase(key); + } else { + data_.insert_or_assign(key, value); + } + } + + template + std::shared_ptr get(const std::string &key) const { + auto iter = data_.find(key); + if (iter == data_.end()) { + return nullptr; + } + return std::static_pointer_cast(iter->second); + } + + bool has(const std::string &key) const { return data_.find(key) != data_.end(); } + + private: + std::map> data_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_USER_DATA_H_ diff --git a/mindspore/core/ir/CMakeLists.txt b/mindspore/core/ir/CMakeLists.txt index 2a0b81ae04..77bc1b7661 100644 --- a/mindspore/core/ir/CMakeLists.txt +++ b/mindspore/core/ir/CMakeLists.txt @@ -1,7 +1,3 @@ file(GLOB_RECURSE _IR_SRC_LIST ./*.cc dtype/*.cc) -file(GLOB_RECURSE _IR_LITE_SRC_FILES - ./lite/tensor.cc - ) -list(REMOVE_ITEM _IR_SRC_LIST ${_IR_LITE_SRC_FILES}) set_property(SOURCE ${_IR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_IR) add_library(_mindspore_ir_obj OBJECT ${_IR_SRC_LIST}) diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 0d96ddf263..a681644b3d 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -23,10 +23,10 @@ #include #include +#include "base/core_ops.h" #include "ir/func_graph.h" #include "ir/primitive.h" -#include "utils/context/ms_context.h" -#include "frontend/operator/ops.h" +#include "utils/ms_context.h" namespace mindspore { // namespace to support intermediate representation definition @@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const { return buffer.str(); } +std::string Parameter::DebugString(int recursive_level) const { + std::ostringstream buffer; + if (recursive_level > 0) { + if (func_graph() != nullptr) { + buffer << func_graph()->ToString() << ":"; + } + } + buffer << ToString(); + return buffer.str(); +} + std::string ValueNode::ToString() const { MS_EXCEPTION_IF_NULL(value_); if (value_->isa()) { @@ -180,6 +191,41 @@ std::string get_id(const AnfNodePtr &node) { void reset_id() { node_ids.clear(); } } // namespace id_generator +namespace { +std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto users = manager->node_users()[cnode]; + std::string first_user_target = GetCNodeTarget(users.back().first); + bool is_used_by_different_target = + std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair &u) -> bool { + return GetCNodeTarget(u.first) != first_user_target; + }); + if (!is_used_by_different_target) { + return first_user_target; + } + + auto inputs = cnode->inputs(); + std::vector real_inputs; + std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); + std::string first_input_target = GetCNodeTarget(real_inputs[0]); + bool is_from_different_target = + std::any_of(std::begin(real_inputs), std::end(real_inputs), + [&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; }); + if (!is_from_different_target) { + return first_input_target; + } + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + return default_target; +} +} // namespace + std::string GetCNodeTarget(const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -207,15 +253,26 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { auto primitive = value->cast(); auto att_target = primitive->GetAttr("primitive_target"); if (att_target != nullptr) { + if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || + IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || + IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || + IsPrimitive(attr_input, prim::kPrimTupleGetItem) || IsPrimitive(attr_input, prim::kPrimControlDepend) || + IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { + primitive->EraseAttr("primitive_target"); + return default_target; + } if (!att_target->isa()) { MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; } auto target = GetValue(att_target); if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target; } return target; } + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + return GetMaketupleNodeTarget(cnode); + } return default_target; } } // namespace mindspore diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index c1a28d57f1..1595bed6f2 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_ANF_H_ -#define MINDSPORE_CCSRC_IR_ANF_H_ +#ifndef MINDSPORE_CORE_IR_ANF_H_ +#define MINDSPORE_CORE_IR_ANF_H_ #include #include @@ -27,9 +27,10 @@ #include #include "base/base.h" +#include "base/user_data.h" #include "ir/kernel_info_dev.h" #include "ir/scope.h" -#include "debug/info.h" +#include "utils/info.h" // A MindSpore ANF IR defined here. // with BNF followed: @@ -41,12 +42,6 @@ // ANode: Atomic Node // CNode: Complex Node namespace mindspore { -namespace parallel { -class TensorLayout; -class OperatorInfo; -} // namespace parallel -using OperatorInfoPtr = std::shared_ptr; - namespace abstract { class BaseShape; class AbstractBase; @@ -55,8 +50,13 @@ using BaseShapePtr = std::shared_ptr; using AbstractBasePtr = std::shared_ptr; using AbstractBasePtrList = std::vector; +class Value; +using ValuePtr = std::shared_ptr; +using ValuePtrList = std::vector; + class ValueNode; using ValueNodePtr = std::shared_ptr; + class CNode; using CNodePtr = std::shared_ptr; @@ -72,7 +72,7 @@ class BaseRef; class Var; using VarPtr = std::shared_ptr; -class AnfVisitor; +class AnfIrVisitor; class ParamValue; using ParamValuePtr = std::shared_ptr; @@ -105,7 +105,7 @@ class AnfNode : public Base { ~AnfNode() override = default; MS_DECLARE_PARENT(AnfNode, Base); - virtual void accept(AnfVisitor *) {} + virtual void accept(AnfIrVisitor *) {} FuncGraphPtr func_graph() const { return func_graph_.lock(); } void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } @@ -157,6 +157,33 @@ class AnfNode : public Base { } size_t seen_{0}; + template + void set_user_data(const std::string &key, const std::shared_ptr &value) { + user_data_.set(key, value); + } + + template + void set_user_data(const std::shared_ptr &value) { + user_data_.set(T::key, value); + } + + template + std::shared_ptr user_data(const std::string &key) const { + return user_data_.get(key); + } + + template + std::shared_ptr user_data() const { + return user_data_.get(T::key); + } + + bool has_user_data(const std::string &key) const { return user_data_.has(key); } + + template + bool has_user_data() const { + return user_data_.has(T::key); + } + protected: // Hold a weak ref to Graph as Graph also hold ref to AnfNode. // Otherwise, func_graph_ and AnfNode will make a reference cycle. @@ -170,6 +197,7 @@ class AnfNode : public Base { std::hash hash_; ScopePtr scope_; KernelInfoDevicePtr kernel_info_; + UserData user_data_; }; // CNode represents the complex node with a set of arguments. @@ -193,7 +221,7 @@ class CNode : public AnfNode { ~CNode() override = default; MS_DECLARE_PARENT(CNode, AnfNode); - void accept(AnfVisitor *v) override; + void accept(AnfIrVisitor *v) override; // check whether this cnode has some primitive value as the first input. bool IsApply(const PrimitivePtr &) const; @@ -204,6 +232,9 @@ class CNode : public AnfNode { void set_input(size_t i, const AnfNodePtr &input); void set_inputs(const std::vector &inputs) { inputs_ = inputs; } + void set_forward(const ValuePtr &forward) { forward_ = forward; } + const ValuePtr &forward() const { return forward_; } + bool stop_gradient() const { return stop_gradient_; } void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } @@ -212,9 +243,6 @@ class CNode : public AnfNode { std::string DebugString(int recursive_level = 1) const override; std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } - OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info); - OperatorInfoPtr operator_info() { return operator_info_; } - void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } bool in_forward_flag() const { return in_forward_flag_; } @@ -224,8 +252,8 @@ class CNode : public AnfNode { std::vector inputs_; VarPtr func_graph_as_var_; bool stop_gradient_; - OperatorInfoPtr operator_info_ = nullptr; bool in_forward_flag_ = false; + ValuePtr forward_ = nullptr; }; // ANode represents the atomic node. It's derived Parameter and ValueNode. @@ -244,27 +272,22 @@ class ANode : public AnfNode { class Parameter : public ANode { public: explicit Parameter(const FuncGraphPtr &func_graph) - : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} + : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} ~Parameter() override = default; MS_DECLARE_PARENT(Parameter, ANode); - void accept(AnfVisitor *v) override; - + void accept(AnfIrVisitor *v) override; + std::string DebugString(int recursive_level = 1) const override; std::string name() const { return name_; } void set_name(const std::string &name) { name_ = name; } std::string fullname_with_scope() override { return name(); }; bool has_default() const { return has_default_; } - void set_default_param(ParamValuePtr param) { + void set_default_param(ValuePtr param) { default_param_ = param; has_default_ = true; } - ParamValuePtr default_param() const { return default_param_; } - - std::shared_ptr tensor_layout() const { return tensor_layout_; } - void set_tensor_layout(const std::shared_ptr &tensor_layout) { - tensor_layout_ = tensor_layout; - } + ValuePtr default_param() const { return default_param_; } bool operator==(const AnfNode &other) const override { if (!other.isa()) { @@ -280,8 +303,7 @@ class Parameter : public ANode { private: std::string name_; bool has_default_; - ParamValuePtr default_param_; - std::shared_ptr tensor_layout_; + ValuePtr default_param_; }; using ParameterPtr = std::shared_ptr; @@ -310,8 +332,6 @@ class Value : public Base { protected: TypePtr type_{nullptr}; }; -using ValuePtr = std::shared_ptr; -using ValuePtrList = std::vector; // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode // does not belong to any particular function graph. @@ -321,10 +341,14 @@ class ValueNode : public ANode { ~ValueNode() override = default; MS_DECLARE_PARENT(ValueNode, ANode); - void accept(AnfVisitor *v) override; + void accept(AnfIrVisitor *v) override; + void set_value(const ValuePtr &value) { value_ = value; } const ValuePtr &value() const { return value_; } std::string fullname_with_scope() override; + void set_has_new_value(bool flag) { has_new_value_ = flag; } + bool has_new_value() const { return has_new_value_; } + std::string ToString() const override; std::string DebugString(int recursive_level = 1) const override; std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } @@ -344,6 +368,7 @@ class ValueNode : public ANode { private: ValuePtr value_; + bool has_new_value_ = false; }; template @@ -442,4 +467,4 @@ using TaggedGraph = std::pair; std::string GetCNodeTarget(const AnfNodePtr &node); } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_ANF_H_ +#endif // MINDSPORE_CORE_IR_ANF_H_ diff --git a/mindspore/core/ir/anf_extends.cc b/mindspore/core/ir/anf_extends.cc index b70a660aae..08e4a3194c 100644 --- a/mindspore/core/ir/anf_extends.cc +++ b/mindspore/core/ir/anf_extends.cc @@ -23,9 +23,7 @@ #include "ir/visitor.h" #include "ir/func_graph.h" -#include "frontend/operator/ops.h" -#include "frontend/parallel/ops_info/ops_utils.h" -#include "debug/label.h" +#include "base/core_ops.h" namespace mindspore { // namespace to support intermediate representation definition @@ -37,18 +35,6 @@ std::string AnfNode::ToString() const { return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { - if (operator_info_ != nullptr) { - MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() - << ", using the new one: " << operator_info->name(); - auto old_ptr = operator_info_; - operator_info_ = operator_info; - return old_ptr; - } - operator_info_ = operator_info; - return nullptr; -} - std::string CNode::fullname_with_scope() { // if full name is set, return its name immediately if (!fullname_with_scope_.empty()) { @@ -106,7 +92,7 @@ std::string CNode::fullname_with_scope() { return fullname_with_scope_; } -void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void CNode::accept(AnfIrVisitor *v) { v->Visit(shared_from_base()); } +void ValueNode::accept(AnfIrVisitor *v) { v->Visit(shared_from_base()); } +void Parameter::accept(AnfIrVisitor *v) { v->Visit(shared_from_base()); } } // namespace mindspore diff --git a/mindspore/core/ir/device_sync.h b/mindspore/core/ir/device_sync.h index a6bbe92233..d8a0079814 100644 --- a/mindspore/core/ir/device_sync.h +++ b/mindspore/core/ir/device_sync.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ -#define MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ +#ifndef MINDSPORE_CORE_IR_DEVICE_SYNC_H_ +#define MINDSPORE_CORE_IR_DEVICE_SYNC_H_ #include #include @@ -32,7 +32,8 @@ class DeviceSync { virtual bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const = 0; virtual bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const = 0; + virtual void *GetMutablePtr() const = 0; }; using DeviceSyncPtr = std::shared_ptr; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ +#endif // MINDSPORE_CORE_IR_DEVICE_SYNC_H_ diff --git a/mindspore/core/ir/dtype.cc b/mindspore/core/ir/dtype.cc index 71a78bdcf6..b01d12d9ff 100644 --- a/mindspore/core/ir/dtype.cc +++ b/mindspore/core/ir/dtype.cc @@ -179,40 +179,82 @@ bool TensorType::operator==(const Type &other) const { return *element_type_ == *other_elem_type; } -TypePtr IndexedSlicesType::DeepCopy() const { +TypePtr RowTensorType::DeepCopy() const { MS_EXCEPTION_IF_NULL(element_type_); if (IsGeneric()) { - return std::make_shared(); + return std::make_shared(); } - return std::make_shared(element_type_->DeepCopy()); + return std::make_shared(element_type_->DeepCopy()); } -std::string IndexedSlicesType::ToReprString() const { +std::string RowTensorType::ToReprString() const { if (element_type_ == nullptr) { - return "IndexedSlices"; + return "RowTensor"; } - return "IndexedSlices[" + element_type_->ToReprString() + "]"; + return "RowTensor[" + element_type_->ToReprString() + "]"; } -std::string IndexedSlicesType::ToString() const { +std::string RowTensorType::ToString() const { if (element_type_ == nullptr) { - return "IndexedSlices"; + return "RowTensor"; } - return "IndexedSlices[" + element_type_->ToString() + "]"; + return "RowTensor[" + element_type_->ToString() + "]"; } -std::string IndexedSlicesType::DumpText() const { +std::string RowTensorType::DumpText() const { if (element_type_ == nullptr) { - return "IndexedSlices"; + return "RowTensor"; } - return "IndexedSlices[" + element_type_->DumpText() + "]"; + return "RowTensor[" + element_type_->DumpText() + "]"; } -bool IndexedSlicesType::operator==(const Type &other) const { +bool RowTensorType::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_elem_type = static_cast(other).element_type_; + auto other_elem_type = static_cast(other).element_type_; + if (element_type_ == nullptr && other_elem_type == nullptr) { + return true; + } else if (element_type_ == nullptr || other_elem_type == nullptr) { + return false; + } + return *element_type_ == *other_elem_type; +} + +TypePtr SparseTensorType::DeepCopy() const { + MS_EXCEPTION_IF_NULL(element_type_); + if (IsGeneric()) { + return std::make_shared(); + } + return std::make_shared(element_type_->DeepCopy()); +} + +std::string SparseTensorType::ToReprString() const { + if (element_type_ == nullptr) { + return "SparseTensor"; + } + return "SparseTensor[" + element_type_->ToReprString() + "]"; +} + +std::string SparseTensorType::ToString() const { + if (element_type_ == nullptr) { + return "SparseTensor"; + } + return "SparseTensor[" + element_type_->ToString() + "]"; +} + +std::string SparseTensorType::DumpText() const { + if (element_type_ == nullptr) { + return "SparseTensor"; + } + return "SparseTensor[" + element_type_->DumpText() + "]"; +} + +bool SparseTensorType::operator==(const Type &other) const { + if (!IsSameObjectType(*this, other)) { + return false; + } + auto other_elem_type = static_cast(other).element_type_; if (element_type_ == nullptr && other_elem_type == nullptr) { return true; } else if (element_type_ == nullptr || other_elem_type == nullptr) { diff --git a/mindspore/core/ir/dtype.h b/mindspore/core/ir/dtype.h index dc277c031c..969fb4f190 100644 --- a/mindspore/core/ir/dtype.h +++ b/mindspore/core/ir/dtype.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_H_ +#define MINDSPORE_CORE_IR_DTYPE_H_ #include #include @@ -154,15 +154,15 @@ class TensorType : public Object { }; using TensorTypePtr = std::shared_ptr; -class IndexedSlicesType : public Object { +class RowTensorType : public Object { public: - IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {} - explicit IndexedSlicesType(const TypePtr &ele) - : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {} - ~IndexedSlicesType() override = default; - MS_DECLARE_PARENT(IndexedSlicesType, Object) + RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} + explicit RowTensorType(const TypePtr &ele) + : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~RowTensorType() override = default; + MS_DECLARE_PARENT(RowTensorType, Object) - TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; } + TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } const TypePtr element() const { return element_type_; } void set_element(const TypePtr &element_type) { element_type_ = element_type; } @@ -175,7 +175,30 @@ class IndexedSlicesType : public Object { private: TypePtr element_type_; }; -using IndexedSlicesTypePtr = std::shared_ptr; +using RowTensorTypePtr = std::shared_ptr; + +class SparseTensorType : public Object { + public: + SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} + explicit SparseTensorType(const TypePtr &ele) + : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~SparseTensorType() override = default; + MS_DECLARE_PARENT(SparseTensorType, Object) + + TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + private: + TypePtr element_type_; +}; +using SparseTensorTypePtr = std::shared_ptr; class Function : public Object { public: @@ -332,4 +355,4 @@ extern const TypePtr kKeyword; extern const TypePtr kTensorType; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_H_ diff --git a/mindspore/core/ir/dtype/container.h b/mindspore/core/ir/dtype/container.h index 29579fe73c..a6aa07e6f7 100644 --- a/mindspore/core/ir/dtype/container.h +++ b/mindspore/core/ir/dtype/container.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_CONTAINER_H_ +#define MINDSPORE_CORE_IR_DTYPE_CONTAINER_H_ #include #include @@ -147,4 +147,4 @@ class Dictionary : public Object { using DictionaryPtr = std::shared_ptr; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_CONTAINER_H_ diff --git a/mindspore/core/ir/dtype/empty.h b/mindspore/core/ir/dtype/empty.h index e6149a1fce..d2422f8fc3 100644 --- a/mindspore/core/ir/dtype/empty.h +++ b/mindspore/core/ir/dtype/empty.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_EMPTY_H_ +#define MINDSPORE_CORE_IR_DTYPE_EMPTY_H_ #include #include @@ -90,4 +90,4 @@ extern const TypePtr kTypeEllipsis; extern const TypePtr kAnyType; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_EMPTY_H_ diff --git a/mindspore/core/ir/dtype/number.h b/mindspore/core/ir/dtype/number.h index 8997ddc4df..673957c825 100644 --- a/mindspore/core/ir/dtype/number.h +++ b/mindspore/core/ir/dtype/number.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ +#define MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ #include #include @@ -151,4 +151,4 @@ extern const TypePtr kFloat; extern const TypePtr kNumber; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ diff --git a/mindspore/core/ir/dtype/ref.h b/mindspore/core/ir/dtype/ref.h index e798d72af5..79a596a90e 100644 --- a/mindspore/core/ir/dtype/ref.h +++ b/mindspore/core/ir/dtype/ref.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_REF_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_REF_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_ +#define MINDSPORE_CORE_IR_DTYPE_REF_H_ #include #include @@ -72,4 +72,4 @@ extern const TypePtr kRefKeyType; extern const TypePtr kRefType; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_REF_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_REF_H_ diff --git a/mindspore/core/ir/dtype/type.cc b/mindspore/core/ir/dtype/type.cc index 754876a366..ab8d4941f1 100644 --- a/mindspore/core/ir/dtype/type.cc +++ b/mindspore/core/ir/dtype/type.cc @@ -115,8 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) { return "kObjectTypeKeyword"; case kObjectTypeTensorType: return "kObjectTypeTensorType"; - case kObjectTypeIndexedSlicesType: - return "kObjectTypeIndexedSlicesType"; + case kObjectTypeRowTensorType: + return "kObjectTypeRowTensorType"; + case kObjectTypeSparseTensorType: + return "kObjectTypeSparseTensorType"; case kObjectTypeUndeterminedType: return "kObjectTypeUndeterminedType"; case kObjectTypeDictionary: diff --git a/mindspore/core/ir/dtype/type.h b/mindspore/core/ir/dtype/type.h index 2e38e8ffb6..4e6ff01c1d 100644 --- a/mindspore/core/ir/dtype/type.h +++ b/mindspore/core/ir/dtype/type.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_H_ +#define MINDSPORE_CORE_IR_DTYPE_TYPE_H_ #include #include @@ -84,8 +84,6 @@ class Type : public Value { friend std::ostream &operator<<(std::ostream &os, const Type &type); friend std::ostream &operator<<(std::ostream &os, const TypePtr type); - const bool parse_info_ = true; - private: TypeId meta_type_; bool is_generic_; @@ -121,7 +119,15 @@ class Object : public Type { const TypeId parent_type_; }; +// +// TypeId name map +// +const std::unordered_map type_name_map = { + {kNumberTypeBool, "bool_"}, {kNumberTypeInt8, "int8"}, {kNumberTypeUInt8, "uint8"}, + {kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"}, + {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}}; + std::ostream &operator<<(std::ostream &os, const TypePtrList &types); } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_H_ diff --git a/mindspore/core/ir/dtype/type_id.h b/mindspore/core/ir/dtype/type_id.h index 6fb2a354c1..9ff57cb46c 100644 --- a/mindspore/core/ir/dtype/type_id.h +++ b/mindspore/core/ir/dtype/type_id.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ +#ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ +#define MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ #include #include @@ -50,7 +50,8 @@ enum TypeId : int { kObjectTypeSlice, kObjectTypeKeyword, kObjectTypeTensorType, - kObjectTypeIndexedSlicesType, + kObjectTypeRowTensorType, + kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, kObjectTypeClass, kObjectTypeDictionary, @@ -82,12 +83,5 @@ enum TypeId : int { kNumberTypeFloat64, kNumberTypeEnd }; -// -// TypeId name map -// -const std::unordered_map type_name_map = { - {kNumberTypeBool, "bool_"}, {kNumberTypeInt8, "int8"}, {kNumberTypeUInt8, "uint8"}, - {kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"}, - {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}}; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ +#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc index 099748217e..26bdaf00ab 100644 --- a/mindspore/core/ir/dtype_extends.cc +++ b/mindspore/core/ir/dtype_extends.cc @@ -190,9 +190,9 @@ TypePtr TensorStrToType(const std::string &type_name) { return type; } -TypePtr IndexedSlicesStrToType(const std::string &type_name) { - if (type_name == "IndexedSlices") { - return std::make_shared(); +TypePtr RowTensorStrToType(const std::string &type_name) { + if (type_name == "RowTensor") { + return std::make_shared(); } auto start = type_name.find_first_of('[') + 1; auto end = type_name.find_last_of(']'); @@ -204,7 +204,24 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) { if (element_type == nullptr) { return nullptr; } - return std::make_shared(element_type); + return std::make_shared(element_type); +} + +TypePtr SparseTensorStrToType(const std::string &type_name) { + if (type_name == "SparseTensor") { + return std::make_shared(); + } + auto start = type_name.find_first_of('[') + 1; + auto end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + auto element_str = type_name.substr(start, end - start); + auto element_type = StringToType(element_str); + if (element_type == nullptr) { + return nullptr; + } + return std::make_shared(element_type); } TypePtr UndeterminedStrToType(const std::string &type_name) { @@ -347,8 +364,10 @@ TypePtr StringToType(const std::string &type_name) { type = TensorStrToType(type_name); } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) { type = UndeterminedStrToType(type_name); - } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) { - type = IndexedSlicesStrToType(type_name); + } else if (type_name.compare(0, strlen("RowTensor"), "RowTensor") == 0) { + type = RowTensorStrToType(type_name); + } else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) { + type = SparseTensorStrToType(type_name); } else if (type_name.compare(0, strlen("List"), "List") == 0) { type = ListStrToType(type_name); } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { @@ -427,7 +446,8 @@ const TypePtr kTypeExternal = std::make_shared(); const TypePtr kTypeEnv = std::make_shared(); const TypePtr kTypeType = std::make_shared(); const TypePtr kTensorType = std::make_shared(); -const TypePtr kIndexedSlicesType = std::make_shared(); +const TypePtr kRowTensorType = std::make_shared(); +const TypePtr kSparseTensorType = std::make_shared(); const TypePtr kUndeterminedType = std::make_shared(); const TypePtr kString = std::make_shared(); const TypePtr kList = std::make_shared(); diff --git a/mindspore/core/ir/dtype_py.cc b/mindspore/core/ir/dtype_py.cc deleted file mode 100644 index 66bd8ba5f6..0000000000 --- a/mindspore/core/ir/dtype_py.cc +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/dtype.h" -#include -#include -#include -#include "utils/log_adapter.h" -#include "abstract/abstract_value.h" -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" - -namespace mindspore { -// Define python wrapper to handle data types. -REGISTER_PYBIND_DEFINE( - typing, ([](py::module *const m) { - auto m_sub = m->def_submodule("typing", "submodule for dtype"); - py::enum_(m_sub, "TypeId"); - (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); - (void)m_sub.def("load_type", &TypeIdToType, "load type"); - (void)m_sub.def( - "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); - (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); - (void)py::class_>(m_sub, "Type") - .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) - .def("__eq__", - [](const TypePtr &t1, const TypePtr &t2) { - if (t1 != nullptr && t2 != nullptr) { - return *t1 == *t2; - } - return false; - }) - .def("__hash__", &Type::hash) - .def("__str__", &Type::ToString) - .def("__repr__", &Type::ReprString) - .def("__deepcopy__", [](const TypePtr &t, py::dict) { - if (t == nullptr) { - return static_cast(nullptr); - } - return t->DeepCopy(); - }); - (void)py::class_>(m_sub, "Number").def(py::init()); - (void)py::class_>(m_sub, "Bool") - .def(py::init()) - .def(py::pickle( - [](const Bool &) { // __getstate__ - return py::make_tuple(); - }, - [](const py::tuple &) { // __setstate__ - return std::make_shared(); - })); - (void)py::class_>(m_sub, "Int") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const Int &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - Int data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "UInt") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const UInt &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - UInt data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "Float") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const Float &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - Float data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "List") - .def(py::init()) - .def(py::init>(), py::arg("elements")); - (void)py::class_>(m_sub, "Tuple") - .def(py::init()) - .def(py::init>(), py::arg("elements")); - (void)py::class_>(m_sub, "TensorType") - .def(py::init()) - .def(py::init(), py::arg("element")) - .def("element_type", &TensorType::element) - .def(py::pickle( - [](const TensorType &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - TensorType data(TypeIdToType(TypeId(static_cast(t[0].cast())))); - return data; - })); - (void)py::class_>(m_sub, "IndexedSlicesType") - .def(py::init()); - (void)py::class_>(m_sub, "UndeterminedType") - .def(py::init()); - (void)py::class_>(m_sub, "Function") - .def(py::init()) - .def(py::init, TypePtr>(), py::arg("args"), py::arg("retval")); - (void)py::class_>(m_sub, "Class").def(py::init()); - (void)py::class_>(m_sub, "SymbolicKeyType").def(py::init()); - (void)py::class_>(m_sub, "EnvType").def(py::init()); - (void)py::class_>(m_sub, "TypeNone").def(py::init()); - (void)py::class_>(m_sub, "TypeType").def(py::init()); - (void)py::class_>(m_sub, "String").def(py::init()); - (void)py::class_>(m_sub, "RefKeyType").def(py::init()); - (void)py::class_>(m_sub, "RefType").def(py::init()); - (void)py::class_>(m_sub, "TypeAnything").def(py::init()); - (void)py::class_>(m_sub, "Slice").def(py::init()); - (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); - })); -} // namespace mindspore diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index fabdd3e7d3..e8992fda8a 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -22,11 +22,12 @@ #include #include -#include "debug/trace.h" +#include "utils/trace_base.h" #include "ir/manager.h" -#include "frontend/operator/ops.h" +#include "utils/flags.h" #include "utils/ordered_set.h" #include "utils/convert_utils_base.h" +#include "abstract/abstract_function.h" namespace mindspore { /* @@ -49,6 +50,11 @@ FuncGraph::FuncGraph() debug_info_ = std::make_shared(); } +abstract::AbstractBasePtr FuncGraph::ToAbstract() { + auto temp_context = abstract::AnalysisContext::DummyContext(); + return std::make_shared(shared_from_base(), temp_context); +} + AnfNodePtr FuncGraph::output() const { // If return value is set, return should have two inputs. if (return_ != nullptr && return_->inputs().size() == 2) { @@ -417,6 +423,15 @@ std::shared_ptr> FuncGraph::recursive_graphs() { return mng->recursive_graphs(shared_from_base()); } +void FuncGraph::ClearAllManagerInfo() { + ClearNodes(); + ClearValueNodes(); + ClearFuncGraphCNodesIndex(); + ClearFreeVariables(); + ClearFuncGraphsUsed(); + ClearJFuncGraphs(); +} + AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { auto itr = this->parameter_default_value_.find(name); if (itr == parameter_default_value_.end()) { diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 712c75b431..3ce74cfb5b 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ -#define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ +#ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_H_ +#define MINDSPORE_CORE_IR_FUNC_GRAPH_H_ #include #include @@ -32,7 +32,8 @@ #include "ir/manager.h" #include "utils/ordered_set.h" #include "utils/ordered_map.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" +#include "ir/func_graph_cloner.h" namespace mindspore { using BaseRefCounterMap = OrderedMap; @@ -143,12 +144,14 @@ extern const char kFuncGraphFlagUndetermined[]; class FuncGraph : public FuncGraphBase { public: FuncGraph(); + using Drawer = std::function; ~FuncGraph() override = default; MS_DECLARE_PARENT(FuncGraph, FuncGraphBase); // get the graph's abstract abstract::AbstractFunctionPtr abstract(); + abstract::AbstractBasePtr ToAbstract() override; // return the graph's output, or nullptr if not yet deduced AnfNodePtr output() const; @@ -229,7 +232,8 @@ class FuncGraph : public FuncGraphBase { } this->debug_info_ = info; } - + // clear all info from manager + void ClearAllManagerInfo(); // get all nodes belonging to this func graph const AnfNodeSet &nodes(); void CopyNodes(const FuncGraphPtr &source); @@ -328,6 +332,7 @@ class FuncGraph : public FuncGraphBase { std::unordered_map &make_ref_params() { return make_ref_params_; } std::unordered_map attrs_; + std::vector joined_shapes_; std::unordered_map transforms_; // parameter default value std::map parameter_default_value_; @@ -345,6 +350,7 @@ class FuncGraph : public FuncGraphBase { bool stub() const { return stub_; } void set_stub(bool stub) { stub_ = stub; } + static void set_drawer(Drawer drawer) { drawer_ = drawer; } private: // graph is manipulated by manager and others @@ -405,6 +411,7 @@ class FuncGraph : public FuncGraphBase { // CNode order which relates to origin code order std::list order_; bool stub_; + inline static Drawer drawer_ = nullptr; }; inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg) { @@ -420,4 +427,4 @@ std::shared_ptr> FindRoots(const std::vector &seg std::shared_ptr> FindLeaves(const std::vector &segment); } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ +#endif // MINDSPORE_CORE_IR_FUNC_GRAPH_H_ diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 0857770cad..0e6b73201b 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -20,11 +20,12 @@ #include "ir/manager.h" #include "ir/param_value.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" #include "utils/convert_utils_base.h" #include "utils/log_adapter.h" #include "utils/profile.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" +#include "ir/graph_utils.h" // namespace to support intermediate representation definition namespace mindspore { @@ -87,6 +88,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { CNodePtr new_node = std::make_shared(AnfNodePtrList{}, target); auto old_node = node->cast(); new_node->set_abstract(old_node->abstract()); + new_node->set_forward(old_node->forward()); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); new_node->set_scope(scope); new_node->set_kernel_info(old_node->kernel_info_ptr()); @@ -102,6 +104,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) { ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); new_const->set_scope(scope); new_const->set_abstract(node->abstract()); + new_const->set_has_new_value(node->cast()->has_new_value()); repl_node_[node] = new_const; TraceManager::EndTrace(); } @@ -114,6 +117,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); new_const->set_scope(scope); new_const->set_abstract(node->abstract()); + new_const->set_has_new_value(node->cast()->has_new_value()); repl_node_[node] = new_const; TraceManager::EndTrace(); } @@ -180,11 +184,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); - auto return_node = repl_node_[func_graph->get_return()]->cast(); - if (return_node == nullptr) { - MS_LOG(EXCEPTION) << "Can't find replicate node for return."; + + auto old_return = func_graph->get_return(); + if (old_return != nullptr) { + auto return_node = repl_node_[old_return]->cast(); + if (return_node == nullptr) { + MS_LOG(EXCEPTION) << "Can't find replicate node for return."; + } + target_func_graph->set_return(return_node); } - target_func_graph->set_return(return_node); auto &cnodes = func_graph->func_graph_cnodes_index(); for (auto &cnode : cnodes) { @@ -212,6 +220,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); *target_func_graph = std::make_shared(); (*target_func_graph)->set_attrs(func_graph->attrs()); + (*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_; (*target_func_graph)->set_transforms(func_graph->transforms()); (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); @@ -400,11 +409,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph } void Cloner::Lift() { - for (auto &func_graph_params : repl_func_graph_params_) { - auto &func_graph = func_graph_params.first; - auto ¶ms = func_graph_params.second; - for (auto &cnode : func_graph->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph, params); + // lift inner graph first + auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin())); + for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) { + auto func_graph = *r_iter; + auto iter = repl_func_graph_params_.find(func_graph); + if (iter != repl_func_graph_params_.end()) { + auto ¶ms = iter->second; + for (auto &cnode : func_graph->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph, params); + } } } } diff --git a/mindspore/core/ir/func_graph_cloner.h b/mindspore/core/ir/func_graph_cloner.h index 4279ddfa12..6d75c8d13c 100644 --- a/mindspore/core/ir/func_graph_cloner.h +++ b/mindspore/core/ir/func_graph_cloner.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ -#define MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ +#ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ +#define MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ #include #include @@ -119,8 +119,6 @@ class Cloner { std::unordered_map repl_func_graph_params_; }; -FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); - AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); @@ -130,6 +128,7 @@ ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &r FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation = std::make_shared()); +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ +#endif // MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 579409b05e..6e31c8150f 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -18,17 +18,13 @@ #include #include -#include #include "ir/manager.h" -#include "ir/func_graph_cloner.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" #include "utils/ordered_set.h" #include "abstract/abstract_value.h" -#include "debug/anf_ir_dump.h" -#include "debug/trace.h" -#include "debug/draw.h" -#include "debug/label.h" +#include "abstract/abstract_function.h" +#include "utils/flags.h" namespace mindspore { using mindspore::abstract::AbstractFunction; @@ -79,7 +75,12 @@ void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { input0->set_abstract(f); } -void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); } +void FuncGraph::DumpFuncGraph(const std::string &path) { + // draw::Draw(path + ".dot", shared_from_base()); + if (drawer_) { + drawer_(path + ".dot", shared_from_base()); + } +} void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector *specialized_parameter_list, diff --git a/mindspore/core/ir/func_graph_py.cc b/mindspore/core/ir/func_graph_py.cc deleted file mode 100644 index cff25b5aa1..0000000000 --- a/mindspore/core/ir/func_graph_py.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "ir/meta_func_graph.h" -#include "ir/func_graph.h" - -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" - -namespace mindspore { -REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { - // Define python "MetaFuncGraph_" class - (void)py::class_>(*m, "MetaFuncGraph_") - .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) - .def(py::init()); - // Define python "FuncGraph" class - (void)py::class_(*m, "FuncGraph") - .def(py::init()) - .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") - .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); - })); -} // namespace mindspore diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc new file mode 100644 index 0000000000..cde5eaafba --- /dev/null +++ b/mindspore/core/ir/graph_utils.cc @@ -0,0 +1,292 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/graph_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "utils/log_adapter.h" +#include "utils/ms_context.h" + +namespace mindspore { +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { + size_t seen = NewSeenGeneration(); + std::deque todo(1024); + std::unordered_map rank; + std::vector res; + todo.clear(); + todo.push_back(root); + + while (!todo.empty()) { + AnfNodePtr node = todo.back(); + if (node == nullptr || node->seen_ == seen) { + todo.pop_back(); + continue; + } + if (rank.find(node) != rank.end() && rank[node] != todo.size()) { + MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(); + } + rank[node] = todo.size(); + bool cont = false; + auto incl = include(node); + if (incl == FOLLOW) { + auto succs = succ(node); + for (const auto i : succs) { + if ((i != nullptr && i->seen_ != seen) + // Handle the case for 2 subgraphs calls each other. + // If the ValueNodeGraph's return is already in the todo list, do not follow it. + && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && + (i->func_graph()->get_return() == i))) { + todo.push_back(i); + cont = true; + } + } + } else if (incl == NOFOLLOW) { + // do nothing + } else if (incl == EXCLUDE) { + node->seen_ = seen; + todo.pop_back(); + continue; + } else { + MS_LOG(EXCEPTION) << "include(node) must return one of: \"follow\", \"nofollow\", \"exclude\""; + } + if (cont) { + continue; + } + node->seen_ = seen; + res.push_back(node); + todo.pop_back(); + } + return res; +} + +// search the cnodes inside this graph only +std::vector BroadFirstSearchGraphCNodes(CNodePtr ret) { + std::deque todo(1024); + todo.clear(); + todo.push_back(ret); + std::vector sorted_nodes; + auto seen = NewSeenGeneration(); + while (!todo.empty()) { + CNodePtr top = todo.front(); + todo.pop_front(); + sorted_nodes.push_back(top); + auto inputs = top->inputs(); + for (auto &item : inputs) { + if (item->seen_ == seen) { + continue; + } + + if (item->isa()) { + todo.push_back(item->cast()); + } + item->seen_ = seen; + } + } + return sorted_nodes; +} + +std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root) { + std::deque todo; + todo.push_back(root); + std::vector sorted; + auto seen = NewSeenGeneration(); + while (!todo.empty()) { + FuncGraphPtr top = todo.front(); + todo.pop_front(); + sorted.push_back(top); + auto used = top->func_graphs_used(); + for (auto &item : used) { + if (item.first->seen_ == seen) { + continue; + } + todo.push_back(item.first); + item.first->seen_ = seen; + } + } + return sorted; +} + +std::vector SuccDeeper(const AnfNodePtr &node) { + std::vector vecs; + if (node == nullptr) { + return vecs; + } + + if (IsValueNode(node)) { + auto graph = GetValueNode(node); + auto ret = graph->get_return(); + if (ret != nullptr) { + vecs.push_back(ret); + } + return vecs; + } else if (node->func_graph() != nullptr) { + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + } + auto graph = node->func_graph(); + if (graph->get_return() != nullptr) { + vecs.push_back(graph->get_return()); + } + return vecs; + } + + return vecs; +} + +std::vector SuccDeeperSimple(const AnfNodePtr &node) { + std::vector vecs; + if (node == nullptr) { + return vecs; + } + + if (IsValueNode(node)) { + auto graph = GetValueNode(node); + auto ret = graph->get_return(); + if (ret != nullptr) { + vecs.push_back(ret); + } + return vecs; + } else { + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + } + return vecs; + } +} + +std::vector SuccIncoming(const AnfNodePtr &node) { + std::vector vecs; + if (node == nullptr) { + return vecs; + } + + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + } + return vecs; +} + +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) { + std::vector vecs; + if (node == nullptr) { + return vecs; + } + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + // Check if free variables used. + for (const auto &input : inputs) { + auto input_fg = GetValueNode(input); + if (input_fg) { + for (auto &fv : input_fg->free_variables_nodes()) { + if (fv->func_graph() == fg && fg->nodes().contains(fv)) { + vecs.push_back(fv); + } + } + } + } + (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + } + return vecs; +} + +IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } + +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { + if (node->func_graph() == fg) { + return FOLLOW; + } else { + return EXCLUDE; + } +} + +FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { + MS_EXCEPTION_IF_NULL(fg); + Acquire(fg); + + auto vec = search(fg->get_return(), include); + for (auto &node : vec) { + MS_EXCEPTION_IF_NULL(node); + Acquire(node); + if (node->func_graph() != nullptr) { + Acquire(node->func_graph()); + } + } +} + +std::set FuncGraphIndex::GetFuncGraphs(const std::string &key) { + std::set func_graphs; + if (index_func_graph_.find(key) != index_func_graph_.end()) { + func_graphs = index_func_graph_[key]; + } + return func_graphs; +} + +std::set FuncGraphIndex::GetNodes(const std::string &key) { + if (index_node_.find(key) != index_node_.end()) { + return index_node_[key]; + } + + return std::set(); +} + +FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { + if (GetFuncGraphs(key).empty()) { + return nullptr; + } + + auto fg = *GetFuncGraphs(key).begin(); + return fg; +} + +AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { + if (GetNodes(key).empty()) { + return nullptr; + } + + auto node = *GetNodes(key).begin(); + return node; +} + +void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { + std::string name = label_manage::Label(key->debug_info()); + if (!name.empty()) { + (void)index_func_graph_[name].insert(key); + } +} + +void FuncGraphIndex::Acquire(const AnfNodePtr &key) { + std::string name = label_manage::Label(key->debug_info()); + if (!name.empty()) { + (void)index_node_[name].insert(key); + } +} +} // namespace mindspore diff --git a/mindspore/core/ir/graph_utils.h b/mindspore/core/ir/graph_utils.h new file mode 100644 index 0000000000..1e915db47c --- /dev/null +++ b/mindspore/core/ir/graph_utils.h @@ -0,0 +1,97 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_IR_GRAPH_UTILS_H_ +#define MINDSPORE_CORE_IR_GRAPH_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/primitive.h" +#include "ir/scalar.h" +#include "ir/tensor.h" +#include "utils/label.h" + +namespace mindspore { + +enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; + +using IncludeFunc = std::function; +using FilterFunc = std::function; +using SuccFunc = std::function(AnfNodePtr)>; +using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; + +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); + +std::vector SuccDeeper(const AnfNodePtr &node); +std::vector SuccDeeperSimple(const AnfNodePtr &node); +std::vector SuccIncoming(const AnfNodePtr &node); +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); + +IncludeType AlwaysInclude(const AnfNodePtr &node); +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); + +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); + +std::vector DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include, + const FilterFunc &filter); + +class FuncGraphManager; +using FuncGraphManagerPtr = std::shared_ptr; +std::vector DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, + const FuncGraphManagerPtr &mng); +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, + const IncludeFunc &include = AlwaysInclude); + +std::vector BroadFirstSearchGraphCNodes(CNodePtr ret); +std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root); +class FuncGraphIndex { + public: + explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, + const IncludeFunc &include = AlwaysInclude); + FuncGraphIndex(const FuncGraphIndex &) = delete; + FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; + + virtual ~FuncGraphIndex() {} + + std::set GetFuncGraphs(const std::string &key); + std::set GetNodes(const std::string &key); + FuncGraphPtr GetFirstFuncGraph(const std::string &key); + AnfNodePtr GetFirstNode(const std::string &key); + + private: + void Acquire(const FuncGraphPtr &key); + void Acquire(const AnfNodePtr &key); + + std::map> index_func_graph_; + std::map> index_node_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_IR_GRAPH_UTILS_H_ diff --git a/mindspore/core/ir/graph_utils_extends.cc b/mindspore/core/ir/graph_utils_extends.cc new file mode 100644 index 0000000000..1662e6111f --- /dev/null +++ b/mindspore/core/ir/graph_utils_extends.cc @@ -0,0 +1,205 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/graph_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/visitor.h" +#include "ir/manager.h" +#include "ir/func_graph.h" +#include "utils/label.h" +#include "utils/log_adapter.h" +#include "utils/ms_utils.h" + +namespace mindspore { +namespace { +class DeepFirstSearcher : public AnfIrVisitor { + public: + explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr) + : include_(include), filter_(filter) {} + ~DeepFirstSearcher() override = default; + + std::vector Search(const AnfNodePtr &root) { + if (root == nullptr) { + return res_; + } + seen_ = NewSeenGeneration(); + Visit(root); + return res_; + } + + void Visit(const AnfNodePtr &node) override { + MS_EXCEPTION_IF_NULL(node); + if (node->seen_ == seen_) { + return; + } + + node->seen_ = seen_; + + auto incl = include_(node); + if (incl == EXCLUDE) { + return; + } + if (filter_ == nullptr || !filter_(node)) { + res_.push_back(node); + } + if (incl == FOLLOW) { + AnfIrVisitor::Visit(node); + } + } + + private: + size_t seen_{0}; + IncludeFunc include_; + FilterFunc filter_; + std::vector res_{}; +}; + +class DeepScopedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepScopedGraphSearcher() override = default; + + void Visit(const CNodePtr &cnode) override { + if (cnode->func_graph() == nullptr) { + return; + } + + AnfNodePtr ret = cnode->func_graph()->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + + auto &inputs = cnode->inputs(); + for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { + DeepFirstSearcher::Visit(*iter); + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (!IsValueNode(vnode)) { + return; + } + + auto graph = GetValueNode(vnode); + AnfNodePtr ret = graph->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } + + void Visit(const ParameterPtr ¶m) override { + if (param->func_graph() == nullptr) { + return; + } + + AnfNodePtr ret = param->func_graph()->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } +}; + +class DeepUsedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepUsedGraphSearcher() override = default; + + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); + for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { + DeepFirstSearcher::Visit(*iter); + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (!IsValueNode(vnode)) { + return; + } + + auto graph = GetValueNode(vnode); + AnfNodePtr ret = graph->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } +}; + +class DeepLinkedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepLinkedGraphSearcher() override = default; + + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); + for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { + DeepFirstSearcher::Visit(*iter); + } + } + + void Visit(const ValueNodePtr &) override {} +}; + +class DeepUsersSearcher : public DeepFirstSearcher { + public: + explicit DeepUsersSearcher(const IncludeFunc &include, const FuncGraphManagerPtr &mng) + : DeepFirstSearcher(include), mng_(mng) {} + ~DeepUsersSearcher() override = default; + + void Visit(const CNodePtr &cnode) override { + auto &users = mng_->node_users()[cnode]; + for (auto iter = users.begin(); iter != users.end(); ++iter) { + DeepFirstSearcher::Visit(iter->first); + } + } + void Visit(const ValueNodePtr &) override {} + + private: + FuncGraphManagerPtr mng_; +}; +} // namespace + +// include for if expand the node the search, filter for if put the node to results. +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepScopedGraphSearcher(include).Search(root); +} + +std::vector DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include, + const FilterFunc &filter) { + return DeepFirstSearcher(include, filter).Search(root); +} + +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepUsedGraphSearcher(include).Search(root); +} + +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepLinkedGraphSearcher(include).Search(root); +} + +std::vector DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, + const FuncGraphManagerPtr &mng) { + return DeepUsersSearcher(include, mng).Search(root); +} +} // namespace mindspore diff --git a/mindspore/core/ir/kernel_info_dev.h b/mindspore/core/ir/kernel_info_dev.h index 87c717bdcb..70665a1471 100644 --- a/mindspore/core/ir/kernel_info_dev.h +++ b/mindspore/core/ir/kernel_info_dev.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ -#define MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ +#ifndef MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_ +#define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_ #include @@ -29,4 +29,4 @@ class KernelInfoDevice { using KernelInfoDevicePtr = std::shared_ptr; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ +#endif // MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_ diff --git a/mindspore/core/ir/lite/param_value_lite.h b/mindspore/core/ir/lite/param_value_lite.h deleted file mode 100644 index 1da9b915c2..0000000000 --- a/mindspore/core/ir/lite/param_value_lite.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ -#define MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ - -#include - -#include "ir/param_value.h" - -namespace mindspore { -class ParamValueLite : public ParamValue { - public: - ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} - virtual ~ParamValueLite() = default; - - size_t tensor_size() const { return tensor_size_; } - void set_tensor_size(size_t size) { tensor_size_ = size; } - - void *tensor_addr() const { return tensor_addr_; } - void set_tensor_addr(void *addr) { tensor_addr_ = addr; } - - private: - void *tensor_addr_; - size_t tensor_size_; -}; - -using ParamValueLitePtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ diff --git a/mindspore/core/ir/lite/tensor.cc b/mindspore/core/ir/lite/tensor.cc deleted file mode 100644 index 2957495aa4..0000000000 --- a/mindspore/core/ir/lite/tensor.cc +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "ir/lite/tensor.h" -#include "securec/include/securec.h" - -namespace mindspore { -namespace tensor { -#define kMaxMallocSize 1024 * 1024 * 100 -Tensor::Tensor(const TypeId data_type, const std::vector &shape) : MetaTensor(data_type, shape) {} - -Tensor::Tensor(const TypePtr &type_ptr, const std::vector &shape) : MetaTensor(type_ptr, shape) {} - -Tensor::Tensor(const Tensor &tensor) : MetaTensor(tensor) { - this->data_type_ = tensor.data_type_; - this->shape_ = tensor.shape_; - auto ret = CopyTensorData(tensor); - if (0 != ret) { - MS_LOG(EXCEPTION) << "CopyTensorData error"; - } -} - -int Tensor::CopyTensorData(const Tensor &srcTensor) { - if (srcTensor.data_ == nullptr) { - MS_LOG(ERROR) << "data of srcTensor is nullptr"; - return -1; - } - size_t data_size = this->Size(); - MS_ASSERT(data_size == tensor.Size()); - if (this->data_ == nullptr) { - if (data_size > kMaxMallocSize) { - MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes"; - return -1; - } - this->data_ = malloc(data_size); - } - memcpy_s(this->data_, data_size, tensor.data_, tensor.Size()); - return 0; -} - -Tensor::~Tensor() { - if (nullptr != this->data_) { - free(this->data_); - } -} - -Tensor &Tensor::operator=(const Tensor &tensor) { - if (&tensor == this) { - return *this; - } - this->shape_ = tensor.shape_; - this->data_type_ = tensor.data_type_; - auto ret = CopyTensorData(tensor); - if (0 != ret) { - MS_LOG(EXCEPTION) << "CopyTensorData error"; - } - return *this; -} - -bool Tensor::operator==(const Tensor &tensor) { - return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_; -} - -bool Tensor::operator==(const Value &other) const { - if (other.isa()) { - auto other_ = static_cast(other); - return *this == other_; - } else { - return false; - } -} -} // namespace tensor - -namespace inference { -MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector &shape) { - return new Tensor(data_type, shape); -} - -Tensor::Tensor() { this->tensor_impl_ = std::make_shared(); } - -Tensor::Tensor(TypeId data_type, const std::vector &shape) { - this->tensor_impl_ = std::make_shared(data_type, shape); -} - -Tensor::Tensor(std::shared_ptr tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); } - -TypeId Tensor::data_type() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data_type(); -} - -TypeId Tensor::set_data_type(TypeId data_type) { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->set_data_type(data_type); -} - -std::vector Tensor::shape() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->shape(); -} - -size_t Tensor::set_shape(const std::vector &shape) { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->set_shape(shape); -} - -int Tensor::DimensionSize(size_t index) const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->DimensionSize(index); -} - -int Tensor::ElementsNum() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->ElementsNum(); -} - -std::size_t Tensor::hash() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->hash(); -} - -std::shared_ptr Tensor::tensor() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_; -} - -size_t Tensor::Size() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->Size(); -} - -void *Tensor::MutableData() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data(); -} -} // namespace inference -} // namespace mindspore diff --git a/mindspore/core/ir/lite/tensor.h b/mindspore/core/ir/lite/tensor.h deleted file mode 100644 index 0dcf5cc0ee..0000000000 --- a/mindspore/core/ir/lite/tensor.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_LITE_TENSOR_H_ -#define MINDSPORE_CCSRC_IR_LITE_TENSOR_H_ - -#include -#include -#include "ir/meta_tensor.h" -#include "ir/dtype/type.h" - -namespace mindspore { -namespace tensor { -class Tensor : public MetaTensor { - public: - Tensor() : MetaTensor() {} - - Tensor(const TypeId data_type, const std::vector &shape); - - Tensor(const TypePtr &type_ptr, const std::vector &shape); - - Tensor(const Tensor &tensor); - - ~Tensor(); - - int CopyTensorData(const Tensor &srcTensor); - - MS_DECLARE_PARENT(Tensor, MetaTensor) - - virtual Tensor &operator=(const Tensor &tensor); - - virtual bool operator==(const Tensor &tensor); - - bool operator==(const Value &other) const override; - - size_t Size() const { return MetaTensor::ElementsNum() * GetTypeByte(TypeIdToType(this->data_type_)); } - - void *Data() const { return data_; } - - protected: - void *data_; -}; - -using TensorPtr = std::shared_ptr; -} // namespace tensor - -namespace inference { -class Tensor : public MSTensor { - public: - Tensor(); - - Tensor(TypeId data_type, const std::vector &shape); - - explicit Tensor(std::shared_ptr tensor_ptr); - - ~Tensor() = default; - - TypeId data_type() const override; - - TypeId set_data_type(const TypeId data_type) override; - - std::vector shape() const override; - - size_t set_shape(const std::vector &shape) override; - - int DimensionSize(size_t index) const override; - - int ElementsNum() const override; - - std::size_t hash() const override; - - std::shared_ptr tensor() const; - - size_t Size() const override; - - void *MutableData() const override; - - protected: - std::shared_ptr tensor_impl_; -}; -} // namespace inference -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_LITE_TENSOR_H_ diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc index 00c39679cd..2970d22ee2 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -19,14 +19,11 @@ #include "ir/manager.h" #include -#include #include -#include "debug/trace_base.h" #include "ir/func_graph.h" -#include "utils/profile.h" #include "utils/convert_utils_base.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" namespace mindspore { @@ -387,12 +384,17 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector & continue; } AnfNodeIndexSet &users = node_users_[node]; - - std::vector parameters; - if (!users.empty() || - (node->isa() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) { + if (!users.empty()) { continue; } + + if (node->isa() && node->func_graph() != nullptr) { + auto ¶meters = node->func_graph()->parameters(); + if (std::find(parameters.begin(), parameters.end(), node) != parameters.end()) { + continue; + } + } + if (IsValueNode(node)) { auto fg = GetValueNode(node); func_graphs_to_check->add(fg); @@ -520,12 +522,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { target->CopyFuncGraphsUsed(source); target->CopyJFuncGraphs(source); signals_->InvalidateComputer(); - source->ClearNodes(); - source->ClearValueNodes(); - source->ClearFuncGraphCNodesIndex(); - source->ClearFreeVariables(); - source->ClearFuncGraphsUsed(); - source->ClearJFuncGraphs(); + source->ClearAllManagerInfo(); } FuncGraphTransaction FuncGraphManager::Transact() { diff --git a/mindspore/core/ir/manager.h b/mindspore/core/ir/manager.h index a80302d0ac..d961e94ae5 100644 --- a/mindspore/core/ir/manager.h +++ b/mindspore/core/ir/manager.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_MANAGER_H_ -#define MINDSPORE_CCSRC_IR_MANAGER_H_ +#ifndef MINDSPORE_CORE_IR_MANAGER_H_ +#define MINDSPORE_CORE_IR_MANAGER_H_ #include #include @@ -34,10 +34,10 @@ #include "utils/signal.h" #include "utils/ordered_set.h" #include "utils/ordered_map.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/counter.h" #include "utils/hashing.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" #include "ir/anf.h" namespace mindspore { @@ -463,4 +463,4 @@ struct Change { } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_MANAGER_H_ +#endif // MINDSPORE_CORE_IR_MANAGER_H_ diff --git a/mindspore/core/ir/meta_func_graph.cc b/mindspore/core/ir/meta_func_graph.cc index c0cf9d4d2f..141b495a37 100644 --- a/mindspore/core/ir/meta_func_graph.cc +++ b/mindspore/core/ir/meta_func_graph.cc @@ -17,9 +17,53 @@ */ #include "ir/meta_func_graph.h" +#include "base/core_ops.h" +#include "utils/ms_context.h" +#include "abstract/abstract_function.h" // namespace to support intermediate representation definition namespace mindspore { + +abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() { + return std::make_shared(shared_from_base()); +} + +FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + + std::vector parameters; + ParameterPtr undetermined_param = nullptr; + auto stub = std::make_shared(); + for (size_t i = 0; i < types.size(); ++i) { + auto param = stub->add_parameter(); + parameters.push_back(param); + if (types[i]->type_id() == kObjectTypeUndeterminedType) { + undetermined_param = param; + } + } + if (undetermined_param != nullptr) { + std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; + for (size_t i = 0; i < types.size(); ++i) { + if (types[i]->type_id() == kObjectTypeFunction) { + std::vector call_prim{parameters[i], undetermined_param}; + inputs.push_back(stub->NewCNode(call_prim)); + } else { + inputs.push_back(parameters[i]); + } + } + auto stub_output = stub->NewCNode(inputs); + stub->set_output(stub_output); + stub->set_stub(true); + return stub; + } + return nullptr; +} + FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { TypePtrList types; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), diff --git a/mindspore/core/ir/meta_func_graph.h b/mindspore/core/ir/meta_func_graph.h index 933c3f700d..05743bd6a4 100644 --- a/mindspore/core/ir/meta_func_graph.h +++ b/mindspore/core/ir/meta_func_graph.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ -#define MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ +#ifndef MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ +#define MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ #include #include @@ -49,7 +49,7 @@ class MetaFuncGraph : public FuncGraphBase { virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { return args_spec_list; } - + abstract::AbstractBasePtr ToAbstract() override; const std::vector &signatures() const { return signatures_; } void set_signatures(const std::vector &signatures) { signatures_ = signatures; } // Generate a Graph for the given abstract arguments. @@ -72,13 +72,13 @@ class MetaFuncGraph : public FuncGraphBase { return false; } } - const bool parse_info_ = true; protected: template std::shared_ptr shared_from_base() { return std::static_pointer_cast(shared_from_this()); } + FuncGraphPtr GenerateStubFunc(const TypePtrList &types); std::string name_; std::vector signatures_; std::unordered_map cache_; @@ -87,4 +87,4 @@ class MetaFuncGraph : public FuncGraphBase { using MetaFuncGraphPtr = std::shared_ptr; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ +#endif // MINDSPORE_CORE_IR_META_FUNC_GRAPH_H_ diff --git a/mindspore/core/ir/meta_tensor.cc b/mindspore/core/ir/meta_tensor.cc index c0b6b79a64..41b069b770 100644 --- a/mindspore/core/ir/meta_tensor.cc +++ b/mindspore/core/ir/meta_tensor.cc @@ -75,8 +75,6 @@ int MetaTensor::ElementsNum() const { return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies()); } -TypePtr MetaTensor::Dtype() const { return TypeIdToType(data_type_); } - TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { if (type_ptr == nullptr) { MS_LOG(ERROR) << "Dtype to be set is nullptr."; diff --git a/mindspore/core/ir/meta_tensor.h b/mindspore/core/ir/meta_tensor.h index 00106215e8..100c3cc59e 100644 --- a/mindspore/core/ir/meta_tensor.h +++ b/mindspore/core/ir/meta_tensor.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_META_TENSOR_H_ -#define MINDSPORE_CCSRC_IR_META_TENSOR_H_ +#ifndef MINDSPORE_CORE_IR_META_TENSOR_H_ +#define MINDSPORE_CORE_IR_META_TENSOR_H_ #include #include @@ -24,7 +24,7 @@ #include "base/base.h" #include "ir/dtype.h" -#include "utils/convert_utils.h" +#include "utils/convert_utils_base.h" #include "utils/hashing.h" // brief mindspore namespace. @@ -163,7 +163,6 @@ class MetaTensor : public Value { return false; } } - const bool parse_info_ = true; protected: // brief Data type of the tensor. @@ -192,4 +191,4 @@ using MetaTensorPtr = std::shared_ptr; } // namespace tensor } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_META_TENSOR_H_ +#endif // MINDSPORE_CORE_IR_META_TENSOR_H_ diff --git a/mindspore/core/ir/meta_tensor_extends.cc b/mindspore/core/ir/meta_tensor_extends.cc index d73aa19374..53fc58eb78 100644 --- a/mindspore/core/ir/meta_tensor_extends.cc +++ b/mindspore/core/ir/meta_tensor_extends.cc @@ -37,5 +37,7 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { abs_tensor->set_value(shared_from_base()); return abs_tensor; } + +TypePtr MetaTensor::Dtype() const { return TypeIdToType(data_type_); } } // namespace tensor } // namespace mindspore diff --git a/mindspore/core/ir/named.h b/mindspore/core/ir/named.h index 40e544c129..74fbf005a7 100644 --- a/mindspore/core/ir/named.h +++ b/mindspore/core/ir/named.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_NAMED_H_ -#define MINDSPORE_CCSRC_IR_NAMED_H_ +#ifndef MINDSPORE_CORE_IR_NAMED_H_ +#define MINDSPORE_CORE_IR_NAMED_H_ #include #include @@ -89,4 +89,4 @@ class Ellipsis : public Named { }; extern const NamedPtr kEllipsis; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_NAMED_H_ +#endif // MINDSPORE_CORE_IR_NAMED_H_ diff --git a/mindspore/core/ir/optimizer_caller.h b/mindspore/core/ir/optimizer_caller.h deleted file mode 100644 index 036f4ab510..0000000000 --- a/mindspore/core/ir/optimizer_caller.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ -#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ - -#include - -#include "ir/anf.h" - -namespace mindspore { -namespace opt { -class Optimizer; -using OptimizerPtr = std::shared_ptr; -using OptimizerWeakPtr = std::weak_ptr; - -using PredicateFuncType = std::function; -} // namespace opt - -class OptimizerCaller { - public: - virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } -}; -using OptimizerCallerPtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ diff --git a/mindspore/core/ir/param_value.h b/mindspore/core/ir/param_value.h index 00b79ae91c..b1ec94248b 100644 --- a/mindspore/core/ir/param_value.h +++ b/mindspore/core/ir/param_value.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_H_ -#define MINDSPORE_CCSRC_IR_PARAM_VALUE_H_ +#ifndef MINDSPORE_CORE_IR_PARAM_VALUE_H_ +#define MINDSPORE_CORE_IR_PARAM_VALUE_H_ #include #include @@ -25,33 +25,23 @@ #include "ir/tensor.h" namespace mindspore { - class ParamValue { public: ParamValue() {} ParamValue(const ParamValue &other) = default; - ~ParamValue() = default; - - tensor::MetaTensorPtr value() const { return value_; } - void set_value(const tensor::MetaTensorPtr &value) { value_ = value; } + virtual ~ParamValue() = default; const std::string &name() const { return name_; } void set_name(const std::string &name) { name_ = name; } - const std::string &sparse_grad() const { return sparse_grad_; } - void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } - bool requires_grad() const { return requires_grad_; } void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } bool layerwise_parallel() const { return layerwise_parallel_; } void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; } - bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; } - void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; } - // Whether the parameter clone from other parameter. bool cloned() const { return cloned_; } @@ -79,17 +69,13 @@ class ParamValue { } private: - tensor::MetaTensorPtr value_; std::string name_{"Parameter"}; - std::string sparse_grad_; bool requires_grad_{true}; bool layerwise_parallel_{false}; - bool has_indexed_slices_grad_{false}; bool be_cloned_{false}; bool cloned_{false}; std::vector be_cloned_index_; int32_t cloned_index_{0}; }; - } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_H_ +#endif // MINDSPORE_CORE_IR_PARAM_VALUE_H_ diff --git a/mindspore/core/ir/param_value_py.cc b/mindspore/core/ir/param_value_py.cc deleted file mode 100644 index fb4b313c22..0000000000 --- a/mindspore/core/ir/param_value_py.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ir/param_value.h" -#include "pybind11/pybind11.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -namespace py = pybind11; - -REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) { - (void)py::class_(*m, "ParamValue") - .def(py::init()) - .def("clone", &ParamValue::Clone) - .def_property("data", &ParamValue::value, &ParamValue::set_value) - .def_property("name", &ParamValue::name, &ParamValue::set_name) - .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad) - .def_property("layerwise_parallel", &ParamValue::layerwise_parallel, - &ParamValue::set_layerwise_parallel) - .def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad, - &ParamValue::set_has_indexed_slices_grad) - .def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad) - .def(py::pickle( - [](const ParamValue &p) { // __getstate__ - return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(), - p.layerwise_parallel(), p.has_indexed_slices_grad(), - p.sparse_grad()); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 6) { - std::runtime_error("Invalid state for ParamValue!"); - } - ParamValuePtr p = std::make_shared(); - p->set_value(t[0].cast()); - p->set_name(t[1].cast()); - p->set_requires_grad(t[2].cast()); - p->set_layerwise_parallel(t[3].cast()); - p->set_has_indexed_slices_grad(t[4].cast()); - p->set_sparse_grad(t[5].cast()); - return p; - })); - })); -} // namespace mindspore diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 94ba4a381a..7c1a856df6 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -14,17 +14,18 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ -#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ +#ifndef MINDSPORE_CORE_IR_PATTERN_MATCHER_H_ +#define MINDSPORE_CORE_IR_PATTERN_MATCHER_H_ +#include +#include #include #include -#include "ir/anf.h" -#include "frontend/operator/ops.h" +#include "ir/visitor.h" +#include "base/core_ops.h" namespace mindspore { - /// /// Base class for all recognizable patterns. /// We implement an Expression Template approach using static polymorphism based on @@ -39,9 +40,7 @@ namespace mindspore { template class PBase { public: - bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) { - return func(get_object().GetNode(node)); - } + bool CheckFunc(const PredicateFuncType &func, const AnfNodePtr &node) { return func(get_object().GetNode(node)); } const T &get_object() const { return *static_cast(this); } @@ -60,10 +59,10 @@ class PIsEqual { bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } }; -template +template class PatternNode : public PBase > { public: - T GetNode(const AnfNodePtr &node) const { + T GetNode(const AnfNodePtr &) const { if (!captured_) { MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; } @@ -90,12 +89,14 @@ class PatternNode : public PBase > { template class PBinOperation : public PBase > { public: - PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} + PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false) + : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {} + ~PBinOperation() = default; AnfNodePtr GetNode(const AnfNodePtr &node) const { AnfNodePtr lhs = x_.GetNode(node->func_graph()); AnfNodePtr rhs = y_.GetNode(node->func_graph()); - AnfNodePtrList list = {prim_->cast(), lhs, rhs}; + AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; return NewCNode(list, node->func_graph()); } @@ -105,7 +106,15 @@ class PBinOperation : public PBase > { auto inputs = cnode->inputs(); if (inputs.size() == 3) { // Binary Prim assumes only two inputs - if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { + if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) { + // If the operation is commutative, then check with inversed operands + if (is_commutative_) { + Reset(); + if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { + return false; + } + return true; + } return false; } return true; @@ -113,7 +122,6 @@ class PBinOperation : public PBase > { } return false; } - void Reset() const { x_.Reset(); y_.Reset(); @@ -123,6 +131,7 @@ class PBinOperation : public PBase > { const PrimitivePtr prim_; typename T::Internal x_; typename T2::Internal y_; + bool is_commutative_{false}; }; /// @@ -197,43 +206,95 @@ class PCNode : public PBase > { AnfNodePtr GetNode(const AnfNodePtr &node) const { tuple_utils::PTupleGetNode get_node(node); tuple_utils::apply_func_tuple(&get_node, args_); - return NewCNode(get_node.args_, node->func_graph()); + auto prim_cnode = get_node.args_; + // In case this PCNode has captured extra nodes + if (extra_nodes_.size() > 0) { + prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); + } + return NewCNode(prim_cnode, node->func_graph()); } bool TryCapture_(const AnfNodePtr &node) const { if (node->isa()) { auto cnode = node->cast(); auto inputs = cnode->inputs(); - if (inputs.size() != sizeof...(TArgs)) { + + auto pattern_arg_len = sizeof...(TArgs); + // There aren't enough inputs in Node to fill up the Pattern + if (inputs.size() < pattern_arg_len) { return false; } - tuple_utils::PTupleCapture capture_func(inputs); - tuple_utils::apply_func_tuple(&capture_func, args_); - return capture_func.captured_; - } + // Pattern must exactly match the number of Node inputs. + if (!has_min_extra_nodes_) { + // Inputs in Node perfectly match number of tokens in Pattern. + if (inputs.size() == pattern_arg_len) { + AnfNodePtrList tokens(inputs.begin(), inputs.end()); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + return capture_func.captured_; + } + return false; + } + + // Pattern may accept extra (non specified) nodes at the end of the CNode + // There must be at least `min_extra_nodes` additional nodes in the inputs. + if (inputs.size() >= pattern_arg_len + min_extra_nodes_) { + AnfNodePtrList tokens(inputs.begin(), inputs.begin() + pattern_arg_len); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + // If it could capture the initial set of nodes specified in the Pattern + // and there are enough extra inputs to add + if (capture_func.captured_ && inputs.size() > pattern_arg_len) { + extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + pattern_arg_len, inputs.end()); + return true; + } + return capture_func.captured_; + } + return false; + } return false; } + /// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one + /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or + /// more nodes after the last one specified when building the PCNode. + const PCNode &MinExtraNodes(const size_t &min_extra_nodes = 0) const { + has_min_extra_nodes_ = true; + min_extra_nodes_ = min_extra_nodes; + return *this; + } + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); + has_min_extra_nodes_ = false; + extra_nodes_.clear(); } private: std::tuple args_; + mutable AnfNodePtrList extra_nodes_; + mutable bool has_min_extra_nodes_{false}; + mutable size_t min_extra_nodes_{0}; }; template class PPrimitive : public PBase > { public: explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {} + ~PPrimitive() = default; AnfNodePtr GetNode(const AnfNodePtr &node) const { tuple_utils::PTupleGetNode get_node(node); tuple_utils::apply_func_tuple(&get_node, args_); auto prim_cnode = get_node.args_; prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); + + // In case this PPrimitive has captured extra nodes + if (extra_nodes_.size() > 0) { + prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); + } return NewCNode(prim_cnode, node->func_graph()); } @@ -241,70 +302,529 @@ class PPrimitive : public PBase > { if (IsPrimitiveCNode(node, prim_)) { auto cnode = node->cast(); auto inputs = cnode->inputs(); - if ((inputs.size() - 1) != sizeof...(TArgs)) { + // Number of arguments in Primitive Pattern (not including the Primitive node) + auto pattern_arg_len = sizeof...(TArgs); + // There aren't enough inputs in Node to fill up the Pattern + if ((inputs.size() - 1) < pattern_arg_len) { return false; } - AnfNodePtrList rest(inputs.begin() + 1, inputs.end()); - tuple_utils::PTupleCapture capture_func(rest); - tuple_utils::apply_func_tuple(&capture_func, args_); + // Pattern must exactly match the number of Node inputs. + if (!has_min_extra_nodes_) { + // Inputs in Node perfectly match number of tokens in Pattern. + if ((inputs.size() - 1) == pattern_arg_len) { + AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + return capture_func.captured_; + } + return false; + } - return capture_func.captured_; + // Pattern may accept extra (non specified) nodes at the end of the Primitive + // There must be at least `min_extra_nodes` additional nodes in the inputs. + if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { + AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); + tuple_utils::PTupleCapture capture_func(tokens); + tuple_utils::apply_func_tuple(&capture_func, args_); + // If it could capture the initial set of nodes specified in the Pattern + // and there are enough extra inputs to add + if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { + extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); + return true; + } + return capture_func.captured_; + } + return false; } - return false; } + /// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one + /// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or + /// more nodes after the last one specified when building the PPrimitive. + const PPrimitive &MinExtraNodes(const size_t &min_extra_nodes = 0) const { + has_min_extra_nodes_ = true; + min_extra_nodes_ = min_extra_nodes; + return *this; + } + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); + has_min_extra_nodes_ = false; + extra_nodes_.clear(); } private: const PrimitivePtr prim_; std::tuple args_; + mutable AnfNodePtrList extra_nodes_; + mutable bool has_min_extra_nodes_{false}; + mutable size_t min_extra_nodes_{0}; +}; + +/// +/// PConstant class can capture a value node of a specified value (check_value_) +/// or a non-specified one (any_value = true). +/// It can be configured to capture a scalar constant as well (is_scalar_ = true) +/// +template +class PConstant : public PBase > { + public: + explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int check_value = 0, + const bool is_scalar = false) + : as_node_(as_node), + captured_node_(as_node), + any_value_(any_value), + check_value_(check_value), + is_scalar_(is_scalar) {} + + ~PConstant() = default; + // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode + const PConstant &WithShapeAs(const AnfNodePtr &node) const { + if (node == nullptr) { + MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a nullptr node."; + } + as_node_ = node; + changed_shape_ = true; + return *this; + } + + // Sets as_node_ as the node caputred by the received Pattern token to produce a same-shape node with GetNode + const PConstant &WithShapeAs(const PatternNode &pnode) const { + if (captured_node_ == nullptr) { + MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a Pattern token without previously capturing a node."; + } + as_node_ = pnode.GetNode(captured_node_); + changed_shape_ = true; + return *this; + } + + /// Sets captured_node_ as the node captured by the Pattern received as argument + /// to produce a new node with its contents when calling GetNode. + const PConstant &WithValueOf(const PatternNode &pnode) const { + if (!any_value_) { + MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node."; + } + if (captured_node_ == nullptr) { + MS_EXCEPTION(ValueError) << "WithValueOf is trying to use a Pattern token without previously capturing a node."; + } + captured_node_ = pnode.GetNode(captured_node_); + changed_shape_ = true; + return *this; + } + + /// Create a new Value Node filled up with check_value. + /// This function must be used immediately before GetNode to avoid replacing the expected result. + /// Only valid for scalar constants. For tensors use WithShapeAs or WithValueOf. + const PConstant &NewValue() const { + if (!is_scalar_) { + MS_EXCEPTION(ValueError) << "NewValue is valid only for scalar PConstants."; + } + auto value_node_ = MakeValue(check_value_); + captured_node_ = NewValueNode(value_node_); + is_new_value_node_ = true; + return *this; + } + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + // If a NewValueNode was requested (using NewValue function) then return that created node. + if (is_new_value_node_) { + return captured_node_; + } + /// Return a NewTensorFilledWithData if the node was initialized to have a specific value + /// even if it wasn't captured. Usually for zero constants (x - x => zero). + /// If the shape was changed, use the new shape. + if (changed_shape_ || !captured_) { + if (!any_value_) { + return NewTensorFilledWithData(as_node_, check_value_); + } + return NewTensorFilledWithData(as_node_, captured_node_); + } + return captured_node_; + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (node->isa()) { + // If any_value_ is set don't check for the node's value. Just capture it. + if (any_value_) { + captured_node_ = node; + captured_ = true; + return true; + } + + auto value = node->cast()->value(); + if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) { + captured_node_ = node; + captured_ = true; + return true; + } + + auto value_node_ = MakeValue(check_value_); + if (*GetValueNode(node) == *value_node_) { + captured_node_ = node; + captured_ = true; + return true; + } + } + return false; + } + + void Reset() const { + captured_ = false; + changed_shape_ = false; + is_new_value_node_ = false; + } + + // Support function used for checking if all values of a Tensor are equal to `check_value_` + // Supported data types: double, float/float32, int/int32 + bool IsTensorConstant(const ValuePtr &value) const { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + auto threshold = FLT_MIN; + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > threshold) { + return false; + } + } + return true; + } else if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + auto threshold = DBL_MIN; + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > threshold) { + return false; + } + } + return true; + } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] != check_value_) { + return false; + } + } + return true; + } + // Input Data Type is not supported + return false; + } + + bool IsTensorScalarConstant(const ValuePtr &value) const { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { + return false; + } + return IsTensorConstant(value); + } + + void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const { + if (!node->isa()) { + return nullptr; + } + + auto value = node->cast()->value(); + + if (!value->isa()) { + return nullptr; + } + + tensor::TensorPtr tensor_ptr = dyn_cast(value); + return tensor_ptr->data_c(); + } + + // Make a new tensor (when possible) with the same shape as of `node` + // If x is nullptr then fill new tensor will "0" + // If x is a tensor with empty shape then fill new tensor with the single value of x + // If x is a tensor with same shape as `node` then return x as result + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (x == nullptr) { + if (memset_s(data, mem_size, 0, mem_size) != 0) { + return nullptr; + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + // x is not nullptr + if (x->isa()) { + if ((x->abstract() == nullptr) || !x->abstract()->isa()) { + return nullptr; + } + auto x_abstract = x->abstract()->cast(); + std::vector x_shape = x_abstract->shape()->shape(); + + if (x_shape != tensor_shape) { + return nullptr; + } + return x; + } + + if (!x->isa()) { + return nullptr; + } + auto x_value = x->cast()->value(); + if (!x_value->isa()) { + return nullptr; + } + + auto x_tensor_ptr = dyn_cast(x_value); + + if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { + return nullptr; + } + int ret = 0; + char *source_data = reinterpret_cast(GetPointerToTensorData(x)); + if (x_tensor_ptr->DataSize() == 1) { + for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { + ret = memcpy_s(data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr), source_data, + GetTypeByte(tensor_type_ptr)); + } + } else { + ret = memcpy_s(data, mem_size, source_data, mem_size); + } + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size" + << new_tensor_ptr->DataSize(); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (memset_s(data, mem_size, value, mem_size) != 0) { + return nullptr; + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + + // Support function to multiply two constant tensors: partially support broadcasting shapes + template + void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, + int out_data_size) const { + TM *data_1 = reinterpret_cast(in_data_1); + TM *data_2 = reinterpret_cast(in_data_2); + TM *data_out = new TM[out_data_size]; + + if (in_data_1_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[i]; + } + } + if (in_data_2_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[0]; + } + } else { + if (in_data_2_size < out_data_size) { + MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size."; + } + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[i]; + } + } + *out_data = reinterpret_cast(data_out); + return; + } + + AnfNodePtr MulByPatternConst(const PConstant &vpnode_2, const AnfNodePtr &node_3) const { + AnfNodePtr vnode_1 = this->GetNode(captured_node_); + AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); + return MulConstantTensors(vnode_1, vnode_2, node_3); + } + + AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const { + if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || + (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { + return nullptr; + } + + auto value_1 = GetValueNode(vnode_1); + auto value_2 = GetValueNode(vnode_2); + + if (!value_1->isa() || !value_2->isa()) { + return nullptr; + } + + auto tensor_ptr_1 = dyn_cast(value_1); + auto tensor_ptr_2 = dyn_cast(value_2); + + auto tensor_1_abstract = vnode_1->abstract()->cast(); + auto tensor_2_abstract = vnode_1->abstract()->cast(); + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); + TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } + + std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + + int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + + auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); + size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + int ret = 0; + void *data_out = nullptr; + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + ret = memcpy_s(data, mem_size, data_out, mem_size); + delete[] reinterpret_cast(data_out); + } else { + if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + ret = memcpy_s(data, mem_size, data_out, mem_size); + delete[] reinterpret_cast(data_out); + } else { + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + ret = memcpy_s(data, mem_size, data_out, mem_size); + delete[] reinterpret_cast(data_out); + } else { + // Unsupported data types + return nullptr; + } + } + } + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size" + << new_tensor_ptr->DataSize(); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + + using Internal = const PConstant &; + + protected: + mutable AnfNodePtr as_node_; + mutable AnfNodePtr captured_node_; + bool any_value_{true}; + int check_value_{0}; + bool is_scalar_{false}; + mutable bool is_new_value_node_{false}; + mutable bool captured_{false}; + mutable bool changed_shape_{false}; }; // Macro for binary operation functions -#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ - template \ - inline PBinOperation Operator(const PBase &x, const PBase &y) { \ - return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ +#define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative) \ + template \ + inline PBinOperation Operator(const PBase &x, const PBase &y) { \ + return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \ } // Arithmetic operations -BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); -BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); +BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); +BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); // Macros for match and replace #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ if ((CaptureNode).TryCapture(OrigNode)) { \ - return (ReplaceWith).GetNode(OrigNode); \ + auto rep = (ReplaceWith).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return rep; \ + } \ } #define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - return (ReplaceWith).GetNode(OrigNode); \ + auto rep = (ReplaceWith).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return rep; \ + } \ } #define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ if ((CaptureNode).TryCapture(OrigNode)) { \ if ((Condition)) { \ - return (ReplaceWith).GetNode(OrigNode); \ + auto rep = (ReplaceWith).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } \ + } else { \ + auto rep = (ElseNode).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return (ElseNode).GetNode(OrigNode); \ + } \ } \ - return (ElseNode).GetNode(OrigNode); \ } #define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ if ((CaptureNode).TryCapture(OrigNode)) { \ - return (Lambda)(); \ + auto rep = (Lambda)(); \ + if (rep != nullptr) { \ + return rep; \ + } \ } #define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - return (Lambda)(); \ + auto rep = (Lambda)(); \ + if (rep != nullptr) { \ + return rep; \ + } \ } } // namespace mindspore -#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ +#endif // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_ diff --git a/mindspore/core/ir/primitive.cc b/mindspore/core/ir/primitive.cc index 352c0f31ae..07e43cf54e 100644 --- a/mindspore/core/ir/primitive.cc +++ b/mindspore/core/ir/primitive.cc @@ -17,8 +17,39 @@ #include "ir/primitive.h" #include +#include "abstract/abstract_function.h" namespace mindspore { + +static std::string MakeId() { + // Use atomic to make id generator thread safe. + static std::atomic last_id{1}; + return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed)); +} + +Primitive::Primitive(const std::string &name, const bool is_base, const PrimType prim_type) + : Named(name), + is_base_(is_base), + has_signature_(false), + prim_type_(prim_type), + record_evaluate_add_attr_(false), + is_const_value_(false), + id_(MakeId()) {} + +Primitive::Primitive(const Primitive &prim) + : Named(prim), + attrs_(prim.attrs_), + instance_name_(prim.instance_name_), + is_base_(prim.is_base_), + has_signature_(prim.has_signature_), + prim_type_(prim.prim_type_), + record_evaluate_add_attr_(false), + id_(prim.id_) {} + +abstract::AbstractBasePtr Primitive::ToAbstract() { + return std::make_shared(shared_from_base(), nullptr); +} + bool Primitive::operator==(const Value &other) const { if (other.isa()) { auto other_prim = static_cast(other); diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index 5471b58063..2d52872eef 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_H_ -#define MINDSPORE_CCSRC_IR_PRIMITIVE_H_ +#ifndef MINDSPORE_CORE_IR_PRIMITIVE_H_ +#define MINDSPORE_CORE_IR_PRIMITIVE_H_ #include #include @@ -25,7 +25,6 @@ #include "ir/dtype/type.h" #include "abstract/abstract_value.h" -#include "frontend/parallel/ops_info/operator_info.h" #include "utils/base_ref_extends.h" namespace mindspore { @@ -41,24 +40,10 @@ enum PrimType { class Primitive : public Named { public: - explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) - : Named(name), - is_base_(is_base), - has_signature_(false), - prim_type_(prim_type), - record_evaluate_add_attr_(false) {} - - Primitive(const Primitive &prim) - : Named(prim), - attrs_(prim.attrs_), - instance_name_(prim.instance_name_), - is_base_(prim.is_base_), - has_signature_(prim.has_signature_), - prim_type_(prim.prim_type_), - record_evaluate_add_attr_(false) {} - + explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn); + Primitive(const Primitive &prim); MS_DECLARE_PARENT(Primitive, Named); - + abstract::AbstractBasePtr ToAbstract(); abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); std::string ToString() const override { return name(); } void BeginRecordAddAttr() { @@ -83,6 +68,7 @@ class Primitive : public Named { void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } + virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; } ValuePtr GetAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); @@ -91,6 +77,12 @@ class Primitive : public Named { const std::unordered_map &attrs() const { return attrs_; } const std::unordered_map &evaluate_added_attrs() const { return evaluate_added_attrs_; } + void set_evaluate_added_attrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { + MS_LOG(INFO) << " set evalu attrl " << name() << attr.first; + attrs_[attr.first] = attr.second; + } + } // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. bool HasAttr() const { return !attrs_.empty(); } @@ -99,6 +91,7 @@ class Primitive : public Named { return !(iter == attrs_.cend()); } void set_prim_type(const PrimType t) { prim_type_ = t; } + virtual PrimitivePtr Clone() { return std::make_shared(*this); } void set_instance_name(const std::string s) { instance_name_ = s; } bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } @@ -116,6 +109,9 @@ class Primitive : public Named { bool is_base() const { return is_base_; } virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } + void set_is_const_value(bool value) { is_const_value_ = value; } + bool is_const_value() const { return is_const_value_; } + std::string id() const { return id_; } protected: std::unordered_map attrs_; @@ -127,6 +123,8 @@ class Primitive : public Named { bool has_signature_; PrimType prim_type_; bool record_evaluate_add_attr_; + bool is_const_value_; + std::string id_{""}; }; inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { @@ -149,4 +147,4 @@ struct PrimitiveHasher { } }; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ +#endif // MINDSPORE_CORE_IR_PRIMITIVE_H_ diff --git a/mindspore/core/ir/primitive_py.cc b/mindspore/core/ir/primitive_py.cc deleted file mode 100644 index 1a97487ddc..0000000000 --- a/mindspore/core/ir/primitive_py.cc +++ /dev/null @@ -1,195 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/primitive_py.h" -#include -#include -#include "ir/signature.h" -#include "frontend/operator/ops.h" -#include "./common.h" -#include "pipeline/jit/parse/python_adapter.h" -#include "pipeline/jit/parse/data_converter.h" -#include "pybind11/pytypes.h" -#include "utils/convert_utils_base.h" -#include "utils/primitive_utils.h" -#include "utils/base_ref_py.h" -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" - -namespace mindspore { -namespace { -constexpr auto kBpropAttrName = "bprop"; -constexpr auto kCellHookAttrName = "cell_hook"; -constexpr auto kCellIDAttrName = "cell_id"; -void SyncData(const py::object &arg) { - if (py::isinstance(arg)) { - py::tuple arg_list = py::cast(arg); - for (size_t i = 0; i < arg_list.size(); i++) { - SyncData(arg_list[i]); - } - } - if (py::isinstance(arg)) { - auto tensor = py::cast(arg); - (void)tensor->data_sync(); - } -} -} // namespace -std::map PrimitivePy::hook_grad_; -static ValuePtr PyArgToValue(const py::object &arg) { - if (py::isinstance(arg) && - py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { - return nullptr; - } - return parse::data_converter::PyDataToValue(arg); -} - -void PrimitivePy::set_signatures( - std::vector> signatures) { - signatures_.clear(); - for (auto &signature : signatures) { - auto [name, rw, kind, arg_default, dtype] = signature; - auto default_value = PyArgToValue(arg_default); - signatures_.emplace_back(name, rw, kind, default_value, dtype); - } - set_has_signature(true); -} - -py::function PrimitivePy::GetBpropFunction() { - static const char *const get_bprop_func_name = "get_bprop"; - if (py::hasattr(python_obj_, get_bprop_func_name)) { - py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); - return fn; - } else { - auto fn = GetBpropFunctionByObj(python_obj_); - return fn; - } -} - -BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { - auto py_args = py::tuple(args.size()); - size_t i = 0; - for (auto &arg : args) { - py_args[i] = BaseRefToPyData(arg); - MS_LOG(DEBUG) << "arg:" << i << ":"; - i++; - } - py::object obj; - bool is_bprop = this->HasAttr(kBpropAttrName); - if (is_bprop) { - SyncData(py_args); - obj = hook_(*py_args); - return std::make_shared(obj); - } - SyncData(py_args[2]); - bool is_cell = this->HasAttr(kCellHookAttrName); - if (is_cell) { - auto cell_id = GetValue(this->GetAttr(kCellIDAttrName)); - auto iter = hook_grad_.find(cell_id); - if (iter != hook_grad_.end()) { - auto hook_args = py::tuple(3); - hook_args[0] = cell_id; - hook_args[1] = py::make_tuple(iter->second); - hook_args[2] = py::make_tuple(py_args[2]); - obj = hook_(*hook_args); - if (py::isinstance(obj)) { - obj = py_args[2]; - } - hook_grad_.erase(cell_id); - } else { - hook_grad_[cell_id] = py_args[2]; - obj = py_args[2]; - } - } else { - // Hook operator for execute variable hook function - obj = hook_(py::make_tuple(py_args[2])); - if (py::isinstance(obj)) { - obj = py_args[2]; - } - } - obj = py::make_tuple(obj); - return std::make_shared(obj); -} - -py::function PrimitivePy::GetComputeFunction() { - static const char *const compute_func_name = "vm_impl"; - - if (py::hasattr(python_obj_, compute_func_name)) { - MS_LOG(INFO) << name() << " compute_func_name"; - py::function fn = python_obj_.attr(compute_func_name).cast(); - return fn; - } - - static const std::string vm_module = "mindspore.ops.vm_impl_registry"; - static const std::string get_vm_impl_fn = "get_vm_impl_fn"; - MS_LOG(INFO) << name() << ": get_vm_impl_fn"; - py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); - py::function vm_fn = get_fn(python_obj_); - - if (py::isinstance(vm_fn)) { - MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); - vm_fn = mindspore::GetComputeFunction(Primitive::name()); - } - return vm_fn; -} - -void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { - std::string attr_name = name; - ValuePtr converted_ret = nullptr; - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; - } - bool converted = parse::ConvertData(obj, &converted_ret); - if (!converted) { - MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); - } - (void)this->AddAttr(attr_name, converted_ret); -} - -py::dict PrimitivePy::GetAttrDict() { - py::dict attr_dict; - for (auto &attr : attrs_) { - attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); - } - return attr_dict; -} - -void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { - MS_EXCEPTION_IF_NULL(primitive); - if (!primitive->isa()) { - MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!"; - } - auto primitive_py = primitive->cast(); - MS_EXCEPTION_IF_NULL(primitive_py); - this->set_hook(primitive_py->hook()); -} - -REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { - (void)py::enum_(*m, "prim_type", py::arithmetic()) - .value("unknown", PrimType::kPrimTypeUnknown) - .value("builtin", PrimType::kPrimTypeBuiltIn) - .value("py_infer_shape", PrimType::kPrimTypePyInferShape) - .value("user_custom", PrimType::kPrimTypeUserCustom); - (void)py::class_>(*m, "Primitive_") - .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) - .def(py::init()) - .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") - .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") - .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") - .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") - .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") - .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); - })); -} // namespace mindspore diff --git a/mindspore/core/ir/primitive_py.h b/mindspore/core/ir/primitive_py.h deleted file mode 100644 index 2dc45ac341..0000000000 --- a/mindspore/core/ir/primitive_py.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ -#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ - -#include -#include -#include -#include -#include -#include - -#include "abstract/abstract_value.h" -#include "utils/misc.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" -#include "ir/primitive.h" -#include "ir/signature.h" -#include "frontend/parallel/ops_info/operator_info.h" - -namespace py = pybind11; -namespace mindspore { -class PrimitivePy : public Primitive { - public: - PrimitivePy(const py::str &name, const py::object &python_obj) - : Primitive(name, false), python_obj_(python_obj), signatures_() {} - ~PrimitivePy() override = default; - MS_DECLARE_PARENT(PrimitivePy, Primitive); - py::function GetBpropFunction(); - py::function GetComputeFunction(); - - void set_signatures( - std::vector> - signatures); - - const std::vector &signatures() const { return signatures_; } - - void CopyHookFunction(const PrimitivePtr &primitive) override; - - void AddPyAttr(const py::str &name, const py::object &obj); - - py::dict GetAttrDict(); - void set_hook(const py::function &hook) { hook_ = hook; } - py::function hook() const { return hook_; } - BaseRef RunHookFunction(const VectorRef &args) const override; - const bool parse_info_ = true; - const py::object &GetPyObj() const { return python_obj_; } - bool is_tuple_input_ = false; - - private: - py::object python_obj_; - py::function hook_; - std::vector signatures_; - static std::map hook_grad_; -}; - -using PrimitivePyPtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ diff --git a/mindspore/core/ir/scalar.h b/mindspore/core/ir/scalar.h index adae8c65f9..b814a4781d 100644 --- a/mindspore/core/ir/scalar.h +++ b/mindspore/core/ir/scalar.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_SCALAR_H_ -#define MINDSPORE_CCSRC_IR_SCALAR_H_ +#ifndef MINDSPORE_CORE_IR_SCALAR_H_ +#define MINDSPORE_CORE_IR_SCALAR_H_ #include #include @@ -359,4 +359,4 @@ IMM_TRAITS(FP64ImmPtr, double) } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_SCALAR_H_ +#endif // MINDSPORE_CORE_IR_SCALAR_H_ diff --git a/mindspore/core/ir/signature.h b/mindspore/core/ir/signature.h index e9a5a2e1ca..418546e93b 100644 --- a/mindspore/core/ir/signature.h +++ b/mindspore/core/ir/signature.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_ -#define MINDSPORE_CCSRC_IR_SIGNATURE_H_ +#ifndef MINDSPORE_CORE_IR_SIGNATURE_H_ +#define MINDSPORE_CORE_IR_SIGNATURE_H_ #include #include @@ -66,4 +66,4 @@ struct Signature { }; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_SIGNATURE_H_ +#endif // MINDSPORE_CORE_IR_SIGNATURE_H_ diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index c04c2cca96..b0af4411d6 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -28,7 +29,6 @@ #include #include -#include "runtime/device/device_address.h" #include "abstract/abstract_value.h" namespace mindspore { @@ -54,54 +54,80 @@ static size_t SizeOf(const std::vector &shape) { return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); } +template +std::unique_ptr NewData(const U *input, size_t size) { + if (input == nullptr || size == 0) { + return nullptr; + } + auto data = std::make_unique(size); + if constexpr (!std::is_same::value && (std::is_same::value || std::is_same::value)) { + // Because float16 do not support implicit cast from/to other types, + // We can not use std::copy() on array of float16, use a loop here. + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(input[i]); + } + } else { + // otherwise, use std::copy for better performance. + std::copy(input, input + size, data.get()); + } + return data; +} + +template +std::unique_ptr NewData(Scalar scalar) { + auto data = std::make_unique(1); + data[0] = static_cast(scalar); + return data; +} + template -std::vector CopyData(const std::vector &shape, void *data, TypeId data_type) { - const size_t count = SizeOf(shape); +std::unique_ptr CopyData(const std::vector &shape, void *data, TypeId data_type) { + const size_t size = SizeOf(shape); switch (data_type) { case kNumberTypeBool: case kNumberTypeUInt8: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeInt8: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeInt16: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeInt32: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeInt64: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeUInt16: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeUInt32: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeUInt64: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeFloat16: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } case kNumberTypeFloat32: { - const float *buf = static_cast(data); - return std::vector(buf, buf + count); + auto buf = static_cast(data); + return NewData(buf, size); } case kNumberTypeFloat64: { auto buf = static_cast(data); - return std::vector(buf, buf + count); + return NewData(buf, size); } default: break; @@ -110,14 +136,14 @@ std::vector CopyData(const std::vector &shape, void *data, TypeId data_t } template -std::vector CopyData(const std::vector &shape, void *data, size_t data_len) { +std::unique_ptr CopyData(const std::vector &shape, void *data, size_t data_len) { size_t size = SizeOf(shape); if (size * sizeof(T) != data_len) { MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T) << " item size " << sizeof(T); } auto buf = static_cast(data); - return {buf, buf + size}; + return NewData(buf, size); } // Tensor data implementation. @@ -133,13 +159,13 @@ class TensorDataImpl : public TensorData { TensorDataImpl(const std::vector &shape, void *data, TypeId data_type) : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_type)) {} - template - TensorDataImpl(const std::vector &shape, InputIt first, InputIt last) - : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {} + template + TensorDataImpl(const std::vector &shape, const U *input, size_t size) + : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData(input, size)) {} template TensorDataImpl(const std::vector &shape, Scalar scalar) - : ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast(scalar)}) {} + : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData(scalar)) {} ssize_t size() const override { return static_cast(data_size_); } @@ -150,24 +176,26 @@ class TensorDataImpl : public TensorData { ssize_t ndim() const override { return static_cast(ndim_); } void *data() override { - static std::vector empty_data(1); - if (data_size_ == 0) { - // Prevent null pointer for empty shape. - return empty_data.data(); + if (data_ == nullptr) { + // Lazy allocation. + data_ = std::make_unique(data_size_); } - // Lazy allocation. - if (data_.empty()) { - data_.resize(data_size_); - } - return data_.data(); + return data_.get(); } bool equals(const TensorData &other) const override { auto ptr = dynamic_cast *>(&other); - if (ptr) { - return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && (data_ == ptr->data_)); + if (ptr == nullptr) { + return false; + } + if (ptr == this) { + return true; } - return false; + if (data_ == nullptr || ptr->data_ == nullptr) { + return false; + } + return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && + std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get()); } std::string ToString(const TypeId type, const std::vector &shape) const override { @@ -180,7 +208,7 @@ class TensorDataImpl : public TensorData { if (data_size_ == 0) { return ""; } - if (data_.empty()) { + if (data_ == nullptr) { return ""; } @@ -206,8 +234,13 @@ class TensorDataImpl : public TensorData { if (isScalar) { ss << value; } else { - ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right) - << value; + if (std::is_same::value) { + ss << std::setw(11) << std::setprecision(4) << std::setiosflags(std::ios::scientific | std::ios::right) + << value; + } else { + ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right) + << value; + } } linefeedThreshold = kThreshold1DFloat; } else if (type == kNumberTypeBool) { @@ -225,10 +258,21 @@ class TensorDataImpl : public TensorData { ss << ' '; } } + + // Set width and indent for different int type. + // + // int8/uint8 width: 3 + // int16/uint16 width: 5 + // int32/uint32 width: 10 + // int64/uint64 width: NOT SET if constexpr (std::is_same::value) { - ss << static_cast(value); + ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast(value); } else if constexpr (std::is_same::value) { - ss << static_cast(value); + ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast(value); + } else if constexpr (std::is_same::value || std::is_same::value) { + ss << std::setw(5) << std::setiosflags(std::ios::right) << value; + } else if constexpr (std::is_same::value || std::is_same::value) { + ss << std::setw(10) << std::setiosflags(std::ios::right) << value; } else { ss << value; } @@ -299,7 +343,7 @@ class TensorDataImpl : public TensorData { size_t ndim_{0}; size_t data_size_{0}; - std::vector data_; + std::unique_ptr data_; }; template @@ -364,12 +408,12 @@ Tensor::Tensor(TypeId data_type, const std::vector &shape, void *data, Type Tensor::Tensor(const std::vector &input, const TypePtr &data_type) : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast(input.size())}), - data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), + data_(MakeTensorData(data_type_, shape_, input.data(), input.size())), id_(MakeId()) {} Tensor::Tensor(const std::vector &input, const TypePtr &data_type) : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast(input.size())}), - data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), + data_(MakeTensorData(data_type_, shape_, input.data(), input.size())), id_(MakeId()) {} Tensor::Tensor(int64_t input, const TypePtr &data_type) @@ -454,67 +498,4 @@ TypeId Tensor::set_data_type(const TypeId data_type) { return data_type; } } // namespace tensor - -namespace inference { -MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector &shape) { - return new Tensor(data_type, shape); -} - -Tensor::Tensor(TypeId data_type, const std::vector &shape) { - this->tensor_impl_ = std::make_shared(data_type, shape); -} - -Tensor::Tensor(std::shared_ptr tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); } - -TypeId Tensor::data_type() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data_type(); -} - -TypeId Tensor::set_data_type(TypeId data_type) { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->set_data_type(data_type); -} - -std::vector Tensor::shape() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->shape(); -} - -size_t Tensor::set_shape(const std::vector &shape) { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->set_shape(shape); -} - -int Tensor::DimensionSize(size_t index) const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->DimensionSize(index); -} - -int Tensor::ElementsNum() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->ElementsNum(); -} - -std::size_t Tensor::hash() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->hash(); -} - -std::shared_ptr Tensor::tensor() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_; -} - -size_t Tensor::Size() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data().nbytes(); -} - -void *Tensor::MutableData() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data_c(); -} - -} // namespace inference } // namespace mindspore diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 727fb0fdd8..c61add5a23 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_TENSOR_H_ -#define MINDSPORE_CCSRC_IR_TENSOR_H_ +#ifndef MINDSPORE_CORE_IR_TENSOR_H_ +#define MINDSPORE_CORE_IR_TENSOR_H_ #include #include @@ -25,7 +25,6 @@ #include "Eigen/Core" #include "ir/device_sync.h" #include "ir/meta_tensor.h" -#include "include/ms_tensor.h" #include "utils/log_adapter.h" using float16 = Eigen::half; @@ -83,7 +82,7 @@ class Tensor : public MetaTensor { // param data The shared tensor data. Tensor(TypeId data_type, const std::vector &shape, TensorDataPtr data); - // brief Create an all zero tensor. + // brief Create a lazy allocated tensor. // // param data_type [TypeId] Data type of the tensor. // param shape The shape represented by std::vector of the tensor. @@ -225,8 +224,6 @@ class Tensor : public MetaTensor { std::string id() const { return id_; } - const bool parse_info_ = true; - private: bool init_flag_{false}; TensorDataPtr data_{nullptr}; @@ -237,40 +234,6 @@ class Tensor : public MetaTensor { using TensorPtr = std::shared_ptr; using TensorPtrList = std::vector>; } // namespace tensor - -namespace inference { -class Tensor : public MSTensor { - public: - Tensor(TypeId data_type, const std::vector &shape); - - explicit Tensor(std::shared_ptr tensor_ptr); - - ~Tensor() = default; - - TypeId data_type() const override; - - TypeId set_data_type(const TypeId data_type) override; - - std::vector shape() const override; - - size_t set_shape(const std::vector &shape) override; - - int DimensionSize(size_t index) const override; - - int ElementsNum() const override; - - std::size_t hash() const override; - - std::shared_ptr tensor() const; - - size_t Size() const override; - - void *MutableData() const override; - - protected: - std::shared_ptr tensor_impl_; -}; -} // namespace inference } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_TENSOR_H_ +#endif // MINDSPORE_CORE_IR_TENSOR_H_ diff --git a/mindspore/core/ir/tensor_py.cc b/mindspore/core/ir/tensor_py.cc deleted file mode 100644 index ef78d2720e..0000000000 --- a/mindspore/core/ir/tensor_py.cc +++ /dev/null @@ -1,389 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/tensor_py.h" - -#include -#include -#include -#include -#include - -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" -#include "abstract/abstract_value.h" - -namespace mindspore { -namespace tensor { - -static TypeId GetDataType(const py::buffer_info &buf) { - if (buf.format.size() == 1) { - switch (buf.format.front()) { - case 'e': - case 'f': - case 'd': - switch (buf.itemsize) { - case 2: - return TypeId::kNumberTypeFloat16; - case 4: - return TypeId::kNumberTypeFloat32; - case 8: - return TypeId::kNumberTypeFloat64; - } - break; - case 'b': - case 'h': - case 'i': - case 'l': - case 'q': - switch (buf.itemsize) { - case 1: - return TypeId::kNumberTypeInt8; - case 2: - return TypeId::kNumberTypeInt16; - case 4: - return TypeId::kNumberTypeInt32; - case 8: - return TypeId::kNumberTypeInt64; - } - break; - case 'B': - case 'H': - case 'I': - case 'L': - case 'Q': - switch (buf.itemsize) { - case 1: - return TypeId::kNumberTypeUInt8; - case 2: - return TypeId::kNumberTypeUInt16; - case 4: - return TypeId::kNumberTypeUInt32; - case 8: - return TypeId::kNumberTypeUInt64; - } - break; - case '?': - return TypeId::kNumberTypeBool; - } - } - MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; - return TypeId::kTypeUnknown; -} - -static std::string GetPyTypeFormat(TypeId data_type) { - switch (data_type) { - case TypeId::kNumberTypeFloat16: - return "e"; - case TypeId::kNumberTypeFloat32: - return py::format_descriptor::format(); - case TypeId::kNumberTypeFloat64: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt8: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt16: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt32: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt64: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt8: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt16: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt32: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt64: - return py::format_descriptor::format(); - case TypeId::kNumberTypeBool: - return py::format_descriptor::format(); - default: - MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; - return ""; - } -} - -static bool IsCContiguous(const py::array &input) { - auto flags = static_cast(input.flags()); - return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0; -} - -TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { - // Get input buffer info. - py::buffer_info buf = input.request(); - // Check data types. - auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown; - auto buf_type = GetDataType(buf); - if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) { - MS_LOG(EXCEPTION) << "Unsupported tensor type!"; - } - // Use buf type as data type if type_ptr not set. - if (data_type == TypeId::kTypeUnknown) { - data_type = buf_type; - } - // Convert input array to C contiguous if need. - std::unique_ptr tmp_buf; - if (!IsCContiguous(input)) { - Py_buffer pybuf; - if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) { - MS_LOG(EXCEPTION) << "Failed to get buffer from the input!"; - } - tmp_buf = std::make_unique(pybuf.len); - if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) { - MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer."; - } - PyBuffer_Release(&pybuf); - buf.ptr = tmp_buf.get(); - } - // Get tensor shape. - std::vector shape(buf.shape.begin(), buf.shape.end()); - if (data_type == buf_type) { - // Use memory copy if input data type is same as the required type. - return std::make_shared(data_type, shape, buf.ptr, buf.size * buf.itemsize); - } - // Create tensor with data type converted. - return std::make_shared(data_type, shape, buf.ptr, buf_type); -} - -static std::vector GetStrides(const std::vector &shape, ssize_t item_size) { - std::vector strides; - strides.reserve(shape.size()); - const auto ndim = shape.size(); - for (size_t i = 0; i < ndim; ++i) { - auto stride = item_size; - for (size_t j = i + 1; j < ndim; ++j) { - stride *= shape[j]; - } - strides.push_back(stride); - } - return strides; -} - -static py::buffer_info GetPyBufferInfo(const Tensor &tensor) { - std::vector shape(tensor.shape().begin(), tensor.shape().end()); - std::vector strides = GetStrides(shape, tensor.data().itemsize()); - return py::buffer_info{ - tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides}; -} - -py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { - auto &shape = tensor.shape(); - py::tuple dims(shape.size()); - for (size_t i = 0; i < dims.size(); ++i) { - dims[i] = py::int_(shape[i]); - } - return dims; -} - -py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { - tensor.data_sync(); - auto info = GetPyBufferInfo(tensor); - py::object self = py::cast(&tensor); - return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); -} - -py::array TensorPy::AsNumpy(const Tensor &tensor) { - auto info = GetPyBufferInfo(tensor); - py::object self = py::cast(&tensor); - return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); -} - -static std::vector GetShapeFromTuple(const py::tuple &tuple) { - std::vector shape; - const size_t size = tuple.size(); - shape.reserve(tuple.size()); - for (size_t i = 0; i < size; ++i) { - shape.push_back(py::int_(tuple[i])); - } - return shape; -} - -REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { - // Define python MetaTensor class. - (void)py::class_>(*m, "MetaTensor") - .def(py::init>(), py::arg("dtype"), py::arg("shape")) - .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) - .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") - .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") - .def(py::pickle( - [](const MetaTensor &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(static_cast(t.data_type()), t.shape()); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 2) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - MetaTensor tensor(TypeId(t[0].cast()), t[1].cast>()); - return tensor; - })); - // Define python Tensor class. - // dtype should define before Tensor, because Tensor init depend dtype - (void)py::class_>(*m, "Tensor") - .def(py::init([](const Tensor &tensor) { return std::make_shared(tensor); }), - py::arg("input")) - .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) { - TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown; - if (data_type == kTypeUnknown || tensor.data_type() == data_type) { - return std::make_shared(tensor); - } - return std::make_shared(tensor, data_type); - }), - py::arg("input"), py::arg("dtype")) - .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) { - auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64; - return std::make_shared(data_type, GetShapeFromTuple(shape)); - }), - py::arg("dtype"), py::arg("shape")) - .def(py::init([](const py::array &input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(input, type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::float_ input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::int_ input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::list input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::tuple input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) - .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag) - .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( - Get the tensor's data type. - - Returns: - type, the data type of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) - >>> data.dtype - Int32 - )mydelimiter") - .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter( - Get the tensor's shape. - - Returns: - tuple[int], the shape of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((3, 3))) - >>> data.shape() - (3, 3) - )mydelimiter") - .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter( - Convert tensor to numpy.ndarray. - - Returns: - numpy.ndarray. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> array = data.asnumpy() - >>> array - array([[1., 1., 1.], - [1., 1., 1.]]) - )mydelimiter") - .def("size", &Tensor::DataSize, R"mydelimiter( - Get tensor's data size. - - Returns: - int, the size of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.size() - 6 - )mydelimiter") - .def("is_init", &Tensor::is_init, R"mydelimiter( - Get tensor init_flag. - - Returns: - bool, whether the tensor init. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.is_init() - False - )mydelimiter") - .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter( - Set tensor init_flag. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.set_init_flag(True) - )mydelimiter") - .def("dim", &Tensor::DataDim, R"mydelimiter( - Get tensor's data dimension. - - Returns: - int, the dimension of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.dim() - 2 - )mydelimiter") - .def("assign_value", &Tensor::AssignValue, R"mydelimiter( - Assign another tensor value to this. - - Arg: - value (:class:`mindspore.tensor`): The value tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) - >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) - >>> data.assign_value(data2) - >>> data.shape - (2, 2) - )mydelimiter") - .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( - Set the tensor's data type. - - Arg: - dtype (:class:`mindspore.dtype`): The type of output tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) - >>> data.set_dtype(mindspore.int32) - mindspore.int32 - )mydelimiter") - .def("__str__", &Tensor::ToString) - .def("__repr__", &Tensor::ToStringRepr) - .def(py::pickle( - [](const Tensor &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(TensorPy::AsNumpy(t)); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - return TensorPy::MakeTensor(t[0].cast()); - })); - })); -} // namespace tensor -} // namespace mindspore diff --git a/mindspore/core/ir/tensor_py.h b/mindspore/core/ir/tensor_py.h deleted file mode 100644 index f917584977..0000000000 --- a/mindspore/core/ir/tensor_py.h +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_TENSOR_PY_H_ -#define MINDSPORE_CCSRC_IR_TENSOR_PY_H_ - -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pybind11/numpy.h" - -#include "ir/tensor.h" - -namespace py = pybind11; - -namespace pybind11 { -namespace detail { -// Similar to enums in `pybind11/numpy.h`. Determined by doing: -// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' -constexpr int NPY_FLOAT16 = 23; - -template -struct npy_scalar_caster { - PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); - using Array = array_t; - - bool load(handle src, bool convert) { - // Taken from Eigen casters. Permits either scalar dtype or scalar array. - handle type = dtype::of().attr("type"); - if (!convert && !isinstance(src) && !isinstance(src, type)) return false; - - Array tmp = Array::ensure(src); - if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { - this->value = *tmp.data(); - return true; - } - - return false; - } - - static handle cast(T src, return_value_policy, handle) { - Array tmp({1}); - tmp.mutable_at(0) = src; - tmp.resize({}); - - // You could also just return the array if you want a scalar array. - object scalar = tmp[tuple()]; - return scalar.release(); - } -}; - -template <> -struct npy_format_descriptor { - static constexpr auto name = "float16"; - static pybind11::dtype dtype() { - handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); - return reinterpret_borrow(ptr); - } - virtual ~npy_format_descriptor() {} -}; - -template <> -struct type_caster : public npy_scalar_caster { - static constexpr auto name = "float16"; -}; -} // namespace detail -} // namespace pybind11 - -// brief mindspore namespace. -// -// mindspore namespace is the top level namespace of Mindsporeession project. -// Other namespace should be a sub namespace of mindspore namespace in the ME project. -namespace mindspore { -// brief mindspore::tensor namespace -// -// A sub namespace in ME to support tensor related definition. -namespace tensor { -// Tensor python wrapper and adapter class. -class TensorPy { - public: - // brief Create Tensor from a numpy array object. - // - // param input [py::array] Data value of the tensor. - // param data_type [TypeId] Data type of the tensor. - static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr); - - static py::array SyncAsNumpy(const Tensor &tensor); - - static py::array AsNumpy(const Tensor &tensor); - - static py::tuple GetPyTupleShape(const Tensor &tensor); -}; - -} // namespace tensor -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_TENSOR_PY_H_ diff --git a/mindspore/core/ir/value.cc b/mindspore/core/ir/value.cc index 92535bc2e9..560247b8ce 100644 --- a/mindspore/core/ir/value.cc +++ b/mindspore/core/ir/value.cc @@ -130,7 +130,12 @@ bool FP32Imm::operator==(const Value &other) const { return false; } } -bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } +bool FP32Imm::operator==(const FP32Imm &other) const { + if (std::isinf(v_) && std::isinf(other.v_)) { + return true; + } + return fabs(v_ - other.v_) < FLT_EPSILON; +} bool FP64Imm::operator==(const Value &other) const { if (other.isa()) { auto other_ = static_cast(other); @@ -179,7 +184,12 @@ std::string ValueSequeue::DumpText() const { return oss.str(); } -bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } +bool FP64Imm::operator==(const FP64Imm &other) const { + if (std::isinf(v_) && std::isinf(other.v_)) { + return true; + } + return fabs(v_ - other.v_) < DBL_EPSILON; +} bool StringImm::operator==(const Value &other) const { if (other.isa()) { auto other_ = static_cast(other); diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h index 535de81adf..6288aa6c67 100644 --- a/mindspore/core/ir/value.h +++ b/mindspore/core/ir/value.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_VALUE_H_ -#define MINDSPORE_CCSRC_IR_VALUE_H_ +#ifndef MINDSPORE_CORE_IR_VALUE_H_ +#define MINDSPORE_CORE_IR_VALUE_H_ #include #include @@ -31,7 +31,7 @@ #include "ir/scalar.h" #include "ir/dtype/ref.h" #include "utils/hashing.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { class ValueSequeue : public Value { @@ -303,4 +303,4 @@ inline ValueNodePtr NewValueNode(const T &x) { } } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_VALUE_H_ +#endif // MINDSPORE_CORE_IR_VALUE_H_ diff --git a/mindspore/core/ir/visitor.cc b/mindspore/core/ir/visitor.cc index 9e63f4f9c1..3866052e97 100644 --- a/mindspore/core/ir/visitor.cc +++ b/mindspore/core/ir/visitor.cc @@ -18,24 +18,24 @@ #include "ir/visitor.h" namespace mindspore { -void AnfVisitor::Visit(const AnfNodePtr &node) { node->accept(this); } +void AnfIrVisitor::Visit(const AnfNodePtr &node) { node->accept(this); } -void AnfVisitor::Visit(const CNodePtr &cnode) { +void AnfIrVisitor::Visit(const CNodePtr &cnode) { for (auto &input : cnode->inputs()) { Visit(input); } } -void AnfVisitor::Visit(const ValueNodePtr &vnode) { +void AnfIrVisitor::Visit(const ValueNodePtr &vnode) { if (IsValueNode(vnode)) { auto func_graph = GetValueNode(vnode); Visit(func_graph->output()); } } -void AnfVisitor::Visit(const ParameterPtr &) {} +void AnfIrVisitor::Visit(const ParameterPtr &) {} -VisitFuncType AnfVisitor::Match(const PrimitivePtr &prim, const std::vector &funcs) { +VisitFuncType AnfIrVisitor::Match(const PrimitivePtr &prim, const std::vector &funcs) { auto fn = [prim, funcs, this](const AnfNodePtr &node) { if (!IsPrimitiveCNode(node, prim)) { return; diff --git a/mindspore/core/ir/visitor.h b/mindspore/core/ir/visitor.h index 6dcf28249a..c6accbefbb 100644 --- a/mindspore/core/ir/visitor.h +++ b/mindspore/core/ir/visitor.h @@ -14,22 +14,23 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_IR_VISITOR_H_ -#define MINDSPORE_CCSRC_IR_VISITOR_H_ +#ifndef MINDSPORE_CORE_IR_VISITOR_H_ +#define MINDSPORE_CORE_IR_VISITOR_H_ #include -#include "ir/optimizer_caller.h" +#include "ir/anf.h" namespace mindspore { using VisitFuncType = std::function; -class AnfVisitor : public OptimizerCaller { +using PredicateFuncType = std::function; +class AnfIrVisitor { public: virtual void Visit(const AnfNodePtr &); virtual void Visit(const CNodePtr &); virtual void Visit(const ValueNodePtr &); virtual void Visit(const ParameterPtr &); - VisitFuncType Match(const PrimitivePtr &, const std::vector & = {}); - virtual ~AnfVisitor() = default; + VisitFuncType Match(const PrimitivePtr &, const std::vector & = {}); + virtual ~AnfIrVisitor() = default; }; } // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_VISITOR_H_ +#endif // MINDSPORE_CORE_IR_VISITOR_H_ diff --git a/mindspore/core/utils/CMakeLists.txt b/mindspore/core/utils/CMakeLists.txt new file mode 100644 index 0000000000..f90d1b426a --- /dev/null +++ b/mindspore/core/utils/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _UTIL_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_UTIL_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS) +add_library(_mindspore_core_utils_obj OBJECT ${_UTIL_ALL_SRC_FILES}) diff --git a/mindspore/ccsrc/utils/any.cc b/mindspore/core/utils/any.cc similarity index 100% rename from mindspore/ccsrc/utils/any.cc rename to mindspore/core/utils/any.cc diff --git a/mindspore/core/utils/any.h b/mindspore/core/utils/any.h new file mode 100644 index 0000000000..20a20ed006 --- /dev/null +++ b/mindspore/core/utils/any.h @@ -0,0 +1,214 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_UTILS_ANY_H_ +#define MINDSPORE_CORE_UTILS_ANY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils/overload.h" +#include "utils/log_adapter.h" +#include "utils/misc.h" + +namespace mindspore { +// usage:AnyPtr sp = std::make_shared(aname); +template +std::string type(const T &t) { + return demangle(typeid(t).name()); +} + +class Any { + public: + // constructors + Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {} + Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} + Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} + + Any &operator=(Any &&other); + // right reference constructor + template ::type, Any>::value, T>::type> + Any(T &&t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT + BasePtr new_val(new Derived::type>(std::forward(t))); + std::swap(m_ptr, new_val); + } + + ~Any() = default; + + // judge whether is empty + bool empty() const { return m_ptr == nullptr; } + + // judge the is relation + template + bool is() const { + return m_tpIndex == std::type_index(typeid(T)); + } + + const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); } + + std::size_t Hash() const { + std::stringstream buffer; + buffer << m_tpIndex.name(); + if (m_ptr != nullptr) { + buffer << m_ptr->GetString(); + } + return std::hash()(buffer.str()); + } + + template + bool Apply(const std::function &fn) { + if (type() == typeid(T)) { + T x = cast(); + fn(x); + return true; + } + return false; + } + + std::string GetString() const { + if (m_ptr != nullptr) { + return m_ptr->GetString(); + } else { + return std::string(""); + } + } + + friend std::ostream &operator<<(std::ostream &os, const Any &any) { + os << any.GetString(); + return os; + } + + // type cast + template + T &cast() const { + if (!is() || !m_ptr) { + // Use MS_LOGFATAL replace throw std::bad_cast() + MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name(); + } + auto ptr = static_cast *>(m_ptr.get()); + return ptr->m_value; + } + + bool operator==(const Any &other) const { + if (m_tpIndex != other.m_tpIndex) { + return false; + } + if (m_ptr == nullptr && other.m_ptr == nullptr) { + return true; + } + if (m_ptr == nullptr || other.m_ptr == nullptr) { + return false; + } + return *m_ptr == *other.m_ptr; + } + + bool operator!=(const Any &other) const { return !(operator==(other)); } + + Any &operator=(const Any &other); + + bool operator<(const Any &other) const; + + std::string ToString() const { + std::ostringstream buffer; + if (m_tpIndex == typeid(float)) { + buffer << " " << cast(); + } else if (m_tpIndex == typeid(double)) { + buffer << " " << cast(); + } else if (m_tpIndex == typeid(int)) { + buffer << " " << cast(); + } else if (m_tpIndex == typeid(bool)) { + buffer << " " << cast(); + } else { + buffer << "<" << demangle(m_tpIndex.name()) << "> " << m_ptr->GetString(); + } + return buffer.str(); + } + __attribute__((used)) void dump() const { std::cout << ToString() << std::endl; } + + private: + struct Base; + using BasePtr = std::unique_ptr; + + // type base definition + struct Base { + virtual const std::type_info &type() const = 0; + virtual BasePtr clone() const = 0; + virtual ~Base() = default; + virtual bool operator==(const Base &other) const = 0; + virtual std::string GetString() = 0; + }; + + template + struct Derived : public Base { + template + explicit Derived(Args &&... args) : m_value(std::forward(args)...), serialize_cache_("") {} + + bool operator==(const Base &other) const override { + if (typeid(*this) != typeid(other)) { + return false; + } + return m_value == static_cast &>(other).m_value; + } + + const std::type_info &type() const override { return typeid(T); } + + BasePtr clone() const override { return BasePtr(new Derived(m_value)); } + + ~Derived() override {} + + std::string GetString() override { + std::stringstream buffer; + buffer << m_value; + return buffer.str(); + } + + T m_value; + std::string serialize_cache_; + }; + + // clone method + BasePtr clone() const { + if (m_ptr != nullptr) { + return m_ptr->clone(); + } + return nullptr; + } + + BasePtr m_ptr; // point to real data + std::type_index m_tpIndex; // type info of data +}; + +using AnyPtr = std::shared_ptr; + +struct AnyHash { + std::size_t operator()(const Any &c) const { return c.Hash(); } +}; + +struct AnyLess { + bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); } +}; + +bool AnyIsLiteral(const Any &any); + +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_ANY_H_ diff --git a/mindspore/core/utils/convert_utils_base.h b/mindspore/core/utils/convert_utils_base.h new file mode 100644 index 0000000000..ade7c3a967 --- /dev/null +++ b/mindspore/core/utils/convert_utils_base.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_ +#define MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_ + +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +inline int SizeToInt(size_t u) { + if (u > static_cast((std::numeric_limits::max)())) { + MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int."; + } + return static_cast(u); +} + +inline uint32_t SizeToUint(size_t u) { + if (u > static_cast((std::numeric_limits::max)())) { + MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of uint32_t."; + } + return static_cast(u); +} + +inline int64_t SizeToLong(size_t u) { + if (u > static_cast((std::numeric_limits::max)())) { + MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int64_t."; + } + return static_cast(u); +} + +inline size_t IntToSize(int u) { + if (u < 0) { + MS_LOG(EXCEPTION) << "The int value(" << u << ") is less than 0."; + } + return static_cast(u); +} + +inline size_t LongToSize(int64_t u) { + if (u < 0) { + MS_LOG(EXCEPTION) << "The int64_t value(" << u << ") is less than 0."; + } + return static_cast(u); +} + +inline size_t FloatToSize(float u) { + if (u < 0) { + MS_LOG(EXCEPTION) << "The float value(" << u << ") is less than 0."; + } + + if (u > static_cast((std::numeric_limits::max)())) { + MS_LOG(EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of size_t."; + } + return static_cast(u); +} +inline float IntToFloat(int32_t v) { return static_cast(v); } + +inline uint32_t IntToUint(int32_t u) { + if (u < 0) { + MS_LOG(EXCEPTION) << "The int32_t value(" << u << ") is less than 0."; + } + return static_cast(u); +} + +inline int32_t UintToInt(uint32_t u) { + if (u > static_cast((std::numeric_limits::max)())) { + MS_LOG(EXCEPTION) << "The uint32_t value(" << u << ") exceeds the maximum value of int32_t."; + } + return static_cast(u); +} + +inline unsigned int UlongToUint(size_t u) { + if (u > static_cast((std::numeric_limits::max)())) { + MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of unsigned int."; + } + return static_cast(u); +} + +inline int IntMulWithOverflowCheck(int a, int b) { + int out = a * b; + if (a != 0) { + bool overflow = ((out / a) != b); + if (overflow) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + return out; +} + +inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) { + int64_t out = a * b; + if (a != 0) { + bool overflow = ((out / a) != b); + if (overflow) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + return out; +} + +inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { + size_t out = a * b; + if (a != 0) { + if ((out / a) != b) { + MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + return out; +} + +inline uint8_t *AddressOffset(void *address, size_t offset) { + MS_EXCEPTION_IF_NULL(address); + return static_cast(address) + offset; +} +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_ diff --git a/mindspore/core/utils/counter.h b/mindspore/core/utils/counter.h new file mode 100644 index 0000000000..f70271d37f --- /dev/null +++ b/mindspore/core/utils/counter.h @@ -0,0 +1,102 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_COUNTER_H_ +#define MINDSPORE_CORE_UTILS_COUNTER_H_ +#include +#include "utils/ordered_map.h" + +namespace mindspore { + +template , class Equal = std::equal_to> +class Counter { + using counter_type = Counter; + + public: + Counter() = default; + ~Counter() = default; + + Counter(const Counter &other) { data = other.data; } + Counter &operator=(const Counter &other) { + if (this != &other) { + data = other.data; + } + return *this; + } + + int &operator[](const T &t) { return data[t]; } + + counter_type operator-(const counter_type &other) { + counter_type new_counter; + for (auto iter = begin(); iter != end(); ++iter) { + auto key = iter->first; + int value = iter->second; + auto item = other.data.find(key); + if (item != other.data.end()) { + int o_value = item->second; + if (value - o_value > 0) { + new_counter[key] = value - o_value; + } + } else { + new_counter[key] = value; + } + } + + return new_counter; + } + + counter_type operator+(const counter_type &other) { + counter_type new_counter; + for (auto iter = begin(); iter != end(); ++iter) { + auto key = iter->first; + int value = iter->second; + auto item = other.data.find(key); + if (item != other.data.end()) { + new_counter[key] = iter->second + item->second; + } else { + new_counter[key] = value; + } + } + + for (auto iter = other.cbegin(); iter != other.cend(); ++iter) { + auto key = iter->first; + int value = iter->second; + if (!new_counter.contains(key)) { + new_counter[key] = value; + } + } + + return new_counter; + } + + std::size_t size() const { return data.size(); } + + bool contains(const T &t) const { return data.find(t) != data.end(); } + + typename OrderedMap::iterator begin() { return data.begin(); } + + typename OrderedMap::iterator end() { return data.end(); } + + typename OrderedMap::const_iterator cbegin() const { return data.cbegin(); } + + typename OrderedMap::const_iterator cend() const { return data.cend(); } + + private: + OrderedMap data; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_COUNTER_H_ diff --git a/mindspore/core/utils/flags.cc b/mindspore/core/utils/flags.cc new file mode 100644 index 0000000000..a36d0367d6 --- /dev/null +++ b/mindspore/core/utils/flags.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/flags.h" +namespace mindspore { +// flag names +const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16"; +const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; +const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; +const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; +const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; +const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; +} // namespace mindspore diff --git a/mindspore/core/utils/flags.h b/mindspore/core/utils/flags.h new file mode 100644 index 0000000000..89268fbaed --- /dev/null +++ b/mindspore/core/utils/flags.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_FLAGS_H +#define MINDSPORE_CORE_UTILS_FLAGS_H +namespace mindspore { +extern const char GRAPH_FLAG_MIX_PRECISION_FP16[]; +extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; +extern const char GRAPH_FLAG_HAS_EFFECT[]; +extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; +extern const char GRAPH_FLAG_RANDOM_EFFECT[]; +extern const char GRAPH_FLAG_SIDE_EFFECT[]; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_FLAGS_H diff --git a/mindspore/core/utils/hashing.h b/mindspore/core/utils/hashing.h new file mode 100644 index 0000000000..c0eb753136 --- /dev/null +++ b/mindspore/core/utils/hashing.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_HASHING_H_ +#define MINDSPORE_CORE_UTILS_HASHING_H_ + +#include + +namespace mindspore { +inline std::size_t hash_combine(std::size_t hash_sum, std::size_t hash_val) { + // Reference from http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0814r0.pdf + return ((hash_sum << 6) + (hash_sum >> 2) + 0x9e3779b9 + hash_val) ^ hash_sum; +} + +inline std::size_t hash_combine(const std::initializer_list &hash_vals) { + std::size_t hash_sum = 0; + for (auto hash_val : hash_vals) { + hash_sum = hash_combine(hash_sum, hash_val); + } + return hash_sum; +} +} // namespace mindspore +#endif // MINDSPORE_CORE_UTILS_HASHING_H_ diff --git a/mindspore/core/utils/info.cc b/mindspore/core/utils/info.cc new file mode 100644 index 0000000000..06aa1e2de8 --- /dev/null +++ b/mindspore/core/utils/info.cc @@ -0,0 +1,222 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/info.h" +#include +#include +#include +#include +#include "ir/anf.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { + std::string temp_line = line; + if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && + tip != kSourceLineTipDiscard) { + std::string start = temp_line.substr(0, IntToSize(col_begin)); + std::string trimmed = temp_line.substr(IntToSize(col_begin), IntToSize(col_end - col_begin)); + std::string end = temp_line.substr(IntToSize(col_end), IntToSize(SizeToInt(temp_line.length()) - col_end)); + std::stringstream oss; + std::stringstream tip_ss; + std::string start_spaces(start.length(), ' '); + if (tip == kSourceLineTipInLine) { + temp_line = start + "<" + trimmed + ">" + end; + } else if (tip == kSourceLineTipNextLine) { + tip_ss << start_spaces << "^"; + } + oss << temp_line << "\n" << tip_ss.str(); + return oss.str(); + } + return temp_line; +} +// Generate debug information for the location node . +// print the file name, line no and column no, and part of the content +std::string Location::ToString(SourceLineTip tip) { + std::stringstream debug_info_ss; + debug_info_ss << " In file " << file_name_ << "(" << line_ << ")" << std::endl; + if (line_ <= 0) { + return debug_info_ss.str(); + } + + char path[PATH_MAX + 1] = {0x00}; +#if defined(_WIN32) || defined(_WIN64) + if (file_name_.size() > PATH_MAX || _fullpath(path, file_name_.c_str(), PATH_MAX) == nullptr) { + return debug_info_ss.str(); + } +#else + if (file_name_.size() > PATH_MAX || realpath(file_name_.c_str(), path) == nullptr) { + return debug_info_ss.str(); + } +#endif + auto src_path = std::string(path); + std::ifstream file(src_path); + if (!file.is_open()) { + return debug_info_ss.str(); + } + + int line_num = 0; + std::string line; + (void)getline(file, line); + while (line_num != line_ - 1) { + (void)getline(file, line); + line_num++; + } + file.close(); + + debug_info_ss << HighLightLine(line, column_, column_end_, tip) << std::endl; + return debug_info_ss.str(); +} + +void TraceContext::ProcessAttributeFromContext() { + trace_info_ = nullptr; + location_ = nullptr; + func_name_ = ""; + // if there is trace context, get info from previous context + if (!TraceManager::trace_context_stack_.empty()) { + TraceContextPtr top = TraceManager::trace_context_stack_.top(); + trace_info_ = top->trace_info_; + location_ = top->location_; + func_name_ = top->func_name_; + } +} + +DebugInfo::DebugInfo() { + InitValueFromContext(); + unique_id_ = gen_unique_id(); + debug_id_ = -1; + name_ = ""; +} + +DebugInfo::DebugInfo(const std::string &name) { + InitValueFromContext(); + unique_id_ = gen_unique_id(); + debug_id_ = -1; + name_ = name; +} + +DebugInfo::DebugInfo(const LocationPtr &loc) { + InitValueFromContext(); + unique_id_ = gen_unique_id(); + debug_id_ = -1; + location_ = loc; +} + +int64_t DebugInfo::debug_id() { + // cppcheck-suppress variableScope + static int64_t cur_debug_id = 0; + if (debug_id_ == -1) { + debug_id_ = cur_debug_id; + cur_debug_id++; + } + return debug_id_; +} + +int64_t DebugInfo::unique_id_through_copy() const { + auto info = trace_info(); + if (info != nullptr) { + if (info->isa() && info->debug_info() != nullptr) { + return info->debug_info()->unique_id_through_copy(); + } + } + return unique_id(); +} + +std::string DebugInfo::debug_name() { + if (!name_.empty()) { + return name_; + } + std::string debug_name = std::to_string(debug_id()); + name_ = debug_name; + return debug_name; +} + +std::string NodeDebugInfo::debug_name() { + if (!name_.empty()) { + return name_; + } + std::string prefix = ""; + if (node_.lock() != nullptr) { + std::ostringstream oss; + oss << "[" << node_.lock()->type_name() << "]"; + prefix = oss.str(); + } + name_ = prefix + DebugInfo::debug_name(); + return name_; +} + +std::string GraphDebugInfo::debug_name() { + std::string prefix = ""; + return prefix + DebugInfo::debug_name(); +} + +LocationPtr GraphDebugInfo::location() { + // function may have decorator which is included in its location + if (deco_loc_ != nullptr) { + LocationPtr loc = std::make_shared(*DebugInfo::location()); + loc->set_line(loc->line() + (deco_loc_->line_end() - deco_loc_->line() + 1)); + return loc; + } + return DebugInfo::location(); +} +void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } + +TraceContextPtr TraceManager::CurrentContextInfo() { + if (!TraceManager::trace_context_stack_.empty()) { + return TraceManager::trace_context_stack_.top(); + } + return nullptr; +} + +void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { + TraceContextPtr context = std::make_shared(location); + context->set_func_name(func_name); + TraceManager::trace_context_stack_.push(context); +} + +void TraceManager::DebugTrace(const LocationPtr &location) { + TraceContextPtr context = std::make_shared(location); + TraceManager::trace_context_stack_.push(context); +} + +void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { + if (trace_info == nullptr) { + MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; + } + TraceContextPtr context = std::make_shared(trace_info); + if (trace_info->debug_info() == nullptr) { + MS_LOG(EXCEPTION) << "Trace debug info is null"; + } + TraceManager::trace_context_stack_.push(context); +} + +void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { + if (trace_info == nullptr) { + MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; + } + auto cloned_info = trace_info->clone(); + cloned_info->set_debug_info(debug_info); + if (cloned_info->debug_info() == nullptr) { + MS_LOG(EXCEPTION) << "Trace debug info is null with cloned trace"; + } + TraceContextPtr context = std::make_shared(cloned_info); + TraceManager::trace_context_stack_.push(context); +} + +void TraceManager::EndTrace() { TraceManager::trace_context_stack_.pop(); } + +std::stack TraceManager::trace_context_stack_; +} // namespace mindspore diff --git a/mindspore/core/utils/info.h b/mindspore/core/utils/info.h new file mode 100644 index 0000000000..4ab0ad8a28 --- /dev/null +++ b/mindspore/core/utils/info.h @@ -0,0 +1,235 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_INFO_H_ +#define MINDSPORE_CORE_UTILS_INFO_H_ + +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "utils/trace_info.h" + +namespace mindspore { +// namespace to support intermediate representation definition +enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 }; + +// Location class record the location in source code. +class Location { + public: + Location(const std::string &file_name, int line, int column, int line_end, int column_end) + : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} + Location(const Location &loc) + : file_name_(loc.file_name_), + line_(loc.line_), + column_(loc.column_), + line_end_(loc.line_end_), + column_end_(loc.column_end_) {} + std::string ToString(SourceLineTip tip = kSourceLineTipNextLine); + std::string file_name() { return file_name_; } + int line() const { return line_; } + void set_line(int line) { line_ = line; } + int line_end() const { return line_end_; } + void set_line_end(int line) { line_end_ = line; } + int column() const { return column_; } + void set_column(int column) { column_ = column; } + int column_end() const { return column_end_; } + void set_column_end(int column) { column_end_ = column; } + ~Location() = default; + + private: + std::string file_name_; + int line_; + int column_; + int line_end_; + int column_end_; +}; +class TraceContext; +using TraceContextPtr = std::shared_ptr; + +class TraceManager { + public: + TraceManager() = default; + ~TraceManager() = default; + static TraceContextPtr CurrentContextInfo(); + static void DebugTrace(const std::string &func_name, const LocationPtr &location); + static void DebugTrace(const LocationPtr &location); + static void DebugTrace(const TraceInfoPtr &trace_info); + // debug trace with a cloned trace info with debug_info + static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); + static void EndTrace(); + static std::stack trace_context_stack_; +}; + +class TraceGuard { + public: + explicit TraceGuard(const std::string func_name, const LocationPtr &location) { + TraceManager::DebugTrace(func_name, location); + } + explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } + ~TraceGuard() { TraceManager::EndTrace(); } +}; + +class TraceContext { + public: + LocationPtr location_; + TraceInfoPtr trace_info_; + std::string func_name_; + + protected: + void ProcessAttributeFromContext(); + + public: + ~TraceContext() = default; + explicit TraceContext(const LocationPtr &loc) { + ProcessAttributeFromContext(); + location_ = loc; + } + explicit TraceContext(const std::string &func_name) { + ProcessAttributeFromContext(); + func_name_ = func_name; + } + explicit TraceContext(const TraceInfoPtr &trace_info) { + ProcessAttributeFromContext(); + trace_info_ = trace_info; + } + void set_location(const LocationPtr &loc) { location_ = loc; } + LocationPtr location() { return location_; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } + TraceInfoPtr trace_info() const { return trace_info_; } + void set_func_name(const std::string &func_name) { func_name_ = func_name; } + std::string func_name() { return func_name_; } +}; + +class DebugInfo : public Base { + public: + DebugInfo(); + + explicit DebugInfo(const std::string &name); + + explicit DebugInfo(const LocationPtr &loc); + + ~DebugInfo() override = default; + MS_DECLARE_PARENT(DebugInfo, Base); + int64_t debug_id(); + int64_t unique_id() const { return unique_id_; } + int64_t unique_id_through_copy() const; + std::string get_id() { return std::to_string(debug_id()); } + + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } + TraceInfoPtr trace_info() const { return trace_info_; } + void set_location(const LocationPtr &loc) { location_ = loc; } + virtual LocationPtr location() { return location_; } + std::string name() { return name_; } + void set_name(const std::string &name) { name_ = name; } + virtual std::string debug_name(); + + virtual std::string get_python_func_belonged() { return ""; } + + protected: + template + std::shared_ptr shared_from_base() { + return std::static_pointer_cast(shared_from_this()); + } + + private: + void InitValueFromContext() { + if (TraceManager::CurrentContextInfo() != nullptr) { + auto context_info = TraceManager::CurrentContextInfo(); + trace_info_ = context_info->trace_info(); + location_ = context_info->location(); + } + } + static int64_t gen_unique_id() { + static int64_t cur_unique_id = 0; + return cur_unique_id++; + } + + protected: + int64_t unique_id_; + int64_t debug_id_; + TraceInfoPtr trace_info_; + LocationPtr location_; + std::string name_; +}; + +class NodeDebugInfo : public DebugInfo { + public: + NodeDebugInfo() { + if (TraceManager::CurrentContextInfo() != nullptr) { + auto context_info = TraceManager::CurrentContextInfo(); + py_func_belonged_ = context_info->func_name(); + } + } + explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { + if (TraceManager::CurrentContextInfo() != nullptr) { + auto context_info = TraceManager::CurrentContextInfo(); + py_func_belonged_ = context_info->func_name(); + } + } + ~NodeDebugInfo() override = default; + + std::string debug_name() override; + void set_node(const std::shared_ptr &node) { node_ = AnfNodeWeakPtr(node); } + std::shared_ptr get_node() const { return node_.lock(); } + void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } + std::string get_python_func_belonged() override { return py_func_belonged_; } + AnfNodeWeakPtr node_; + std::string py_func_belonged_; +}; +using NodeDebugInfoPtr = std::shared_ptr; + +class GraphDebugInfo : public DebugInfo { + public: + GraphDebugInfo() { + if (TraceManager::CurrentContextInfo() != nullptr) { + auto context_info = TraceManager::CurrentContextInfo(); + py_func_name_ = context_info->func_name(); + deco_loc_ = nullptr; + } + } + + explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { + if (TraceManager::CurrentContextInfo() != nullptr) { + auto context_info = TraceManager::CurrentContextInfo(); + py_func_name_ = context_info->func_name(); + deco_loc_ = nullptr; + } + } + ~GraphDebugInfo() override = default; + std::string debug_name() override; + LocationPtr location() override; + LocationPtr deco_location() { return deco_loc_; } + void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } + FuncGraphPtr get_graph() const { return func_graph_.lock(); } + void set_full_name(const std::string &name) { full_name_ = name; } + std::string get_full_name() { return full_name_; } + void set_deco_location(const LocationPtr &deco_list_loc); + std::string get_python_func_belonged() override { return py_func_name_; } + FuncGraphWeakPtr func_graph_; + LocationPtr deco_loc_; + std::string py_func_name_; + std::string full_name_; +}; + +using GraphDebugInfoPtr = std::shared_ptr; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_INFO_H_ diff --git a/mindspore/core/utils/label.cc b/mindspore/core/utils/label.cc new file mode 100644 index 0000000000..ef4ce9ee3c --- /dev/null +++ b/mindspore/core/utils/label.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/label.h" +#include +#include +#include + +#include "utils/info.h" +#include "ir/func_graph.h" + +namespace mindspore { +namespace label_manage { +static TraceLabelType trace_type = TraceLabelType::kShortSymbol; +TraceLabelType GetGlobalTraceLabelType() { return trace_type; } +void SetGlobalTraceLabelType(TraceLabelType label_type) { trace_type = label_type; } +struct NameWithTrace { + std::string name; + std::vector trace_labels; +}; +static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { + switch (trace_label) { + case TraceLabelType::kShortSymbol: + return trace_info->symbol(); + case TraceLabelType::kFullName: + return "_" + trace_info->full_name() + "_"; + default: + return ""; + } +} + +NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { + NameWithTrace trace_name; + // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node + auto temp_info = debug_info; + while (temp_info != nullptr) { + if (temp_info->trace_info() != nullptr) { + if (temp_info->trace_info()->isa() || temp_info->trace_info()->isa() || + temp_info->trace_info()->isa()) { + break; + } + trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); + temp_info = temp_info->trace_info()->debug_info(); + } else { + break; + } + } + if (!temp_info->name().empty()) { + trace_name.name = temp_info->name(); + } else { + trace_name.name = temp_info->debug_name(); + } + return trace_name; +} + +std::string CombineTraceTypes(const std::string &root_name, const std::vector &trace_labels) { + std::string tags = ""; + for (auto &itr : trace_labels) { + std::string symbol = itr; + tags = tags + symbol; + } + return tags + root_name; +} + +// get the label name of the node debug info +std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { + NameWithTrace trace_name = RootName(debug_info, trace_label); + return CombineTraceTypes(trace_name.name, trace_name.trace_labels); +} + +std::string CombineUniqueID(const DebugInfoPtr &debug_info) { + auto temp_info = debug_info; + std::string label = ""; + while (temp_info != nullptr) { + if (!temp_info->name().empty()) { + label = label + temp_info->name(); + } else { + // the symbol 'U' is for identification of number + label = label + "U" + std::to_string(temp_info->unique_id()); + } + + if (temp_info->trace_info() != nullptr) { + label = label + "_" + temp_info->trace_info()->full_name() + "_"; + temp_info = temp_info->trace_info()->debug_info(); + } else { + temp_info = nullptr; + } + } + return label; +} + +// get trace with unique id chain +std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); } + +std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { + if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { + return LabelStringUnique(debug_info); + } + return LabelString(debug_info, trace_label); +} +} // namespace label_manage +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/label.h b/mindspore/core/utils/label.h similarity index 100% rename from mindspore/ccsrc/debug/label.h rename to mindspore/core/utils/label.h diff --git a/mindspore/core/utils/log_adapter.cc b/mindspore/core/utils/log_adapter.cc new file mode 100644 index 0000000000..8d91a6f776 --- /dev/null +++ b/mindspore/core/utils/log_adapter.cc @@ -0,0 +1,558 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/log_adapter.h" + +#include +#include +#include + +// namespace to support utils module definition +namespace mindspore { +#ifndef USE_ANDROID_LOG +#ifdef USE_GLOG +static std::string GetTime() { +#define BUFLEN 80 + static char buf[BUFLEN]; +#if defined(_WIN32) || defined(_WIN64) + time_t time_seconds = time(0); + struct tm now_time; + localtime_s(&now_time, &time_seconds); + sprintf_s(buf, BUFLEN, "%d-%d-%d %d:%d:%d", now_time.tm_year + 1900, now_time.tm_mon + 1, now_time.tm_mday, + now_time.tm_hour, now_time.tm_min, now_time.tm_sec); +#else + struct timeval cur_time; + (void)gettimeofday(&cur_time, nullptr); + + struct tm now; + (void)localtime_r(&cur_time.tv_sec, &now); + (void)strftime(buf, BUFLEN, "%Y-%m-%d-%H:%M:%S", &now); // format date and time + // set micro-second + buf[27] = '\0'; + int idx = 26; + auto num = cur_time.tv_usec; + for (int i = 5; i >= 0; i--) { + buf[idx--] = static_cast(num % 10 + '0'); + num /= 10; + if (i % 3 == 0) { + buf[idx--] = '.'; + } + } +#endif + return std::string(buf); +} + +static std::string GetProcName() { +#if defined(__APPLE__) || defined(__FreeBSD__) + const char *appname = getprogname(); +#elif defined(_GNU_SOURCE) + const char *appname = program_invocation_name; +#else + const char *appname = "?"; +#endif + // some times, the appname is an absolute path, its too long + std::string app_name(appname); + std::size_t pos = app_name.rfind("/"); + if (pos == std::string::npos) { + return app_name; + } + if (pos + 1 >= app_name.size()) { + return app_name; + } + return app_name.substr(pos + 1); +} + +static std::string GetLogLevel(MsLogLevel level) { +#define _TO_STRING(x) #x + static const char *const level_names[] = { + _TO_STRING(DEBUG), + _TO_STRING(INFO), + _TO_STRING(WARNING), + _TO_STRING(ERROR), + }; +#undef _TO_STRING + if (level > ERROR) { + level = ERROR; + } + return std::string(level_names[level]); +} + +// convert MsLogLevel to corresponding glog level +static int GetGlogLevel(MsLogLevel level) { + switch (level) { + case DEBUG: + case INFO: + return google::GLOG_INFO; + case WARNING: + return google::GLOG_WARNING; + case ERROR: + default: + return google::GLOG_ERROR; + } +} +#else + +#undef Dlog +#define Dlog(module_id, level, format, ...) \ + do { \ + DlogInner((module_id), (level), (format), ##__VA_ARGS__); \ + } while (0) + +// convert MsLogLevel to corresponding slog level +static int GetSlogLevel(MsLogLevel level) { + switch (level) { + case DEBUG: + return DLOG_DEBUG; + case INFO: + return DLOG_INFO; + case WARNING: + return DLOG_WARN; + case ERROR: + default: + return DLOG_ERROR; + } +} +#endif +#endif + +static std::string ExceptionTypeToString(ExceptionType type) { +#define _TO_STRING(x) #x + // clang-format off + static const char *const type_names[] = { + _TO_STRING(NoExceptionType), + _TO_STRING(UnknownError), + _TO_STRING(ArgumentError), + _TO_STRING(NotSupportError), + _TO_STRING(NotExistsError), + _TO_STRING(AlreadyExistsError), + _TO_STRING(UnavailableError), + _TO_STRING(DeviceProcessError), + _TO_STRING(AbortedError), + _TO_STRING(TimeOutError), + _TO_STRING(ResourceUnavailable), + _TO_STRING(NoPermissionError), + _TO_STRING(IndexError), + _TO_STRING(ValueError), + _TO_STRING(TypeError), + _TO_STRING(AttributeError), + }; + // clang-format on +#undef _TO_STRING + if (type < UnknownError || type > AttributeError) { + type = UnknownError; + } + return std::string(type_names[type]); +} + +static const char *GetSubModuleName(SubModuleId module_id) { + static const char *sub_module_names[NUM_SUBMODUES] = { + "UNKNOWN", // SM_UNKNOWN + "BASE", // SM_BASE + "ANALYZER", // SM_ANALYZER + "COMMON", // SM_COMMON + "DEBUG", // SM_DEBUG + "DEVICE", // SM_DEVICE + "GE_ADPT", // SM_GE_ADPT + "IR", // SM_IR + "KERNEL", // SM_KERNEL + "MD", // SM_MD + "ME", // SM_ME + "ONNX", // SM_ONNX + "OPTIMIZER", // SM_OPTIMIZER + "PARALLEL", // SM_PARALLEL + "PARSER", // SM_PARSER + "PIPELINE", // SM_PIPELINE + "PRE_ACT", // SM_PRE_ACT + "PYNATIVE", // SM_PYNATIVE + "SESSION", // SM_SESSION + "UTILS", // SM_UTILS + "VM", // SM_VM + "ABSTRACT" // SM_ABSTRACT + }; + + return sub_module_names[module_id % NUM_SUBMODUES]; +} + +const char *EnumStrForMsLogLevel(MsLogLevel level) { + if (level == DEBUG) { + return "DEBUG"; + } else if (level == INFO) { + return "INFO"; + } else if (level == WARNING) { + return "WARNING"; + } else if (level == ERROR) { + return "ERROR"; + } else if (level == EXCEPTION) { + return "EXCEPTION"; + } else { + return "NO_LEVEL"; + } +} + +void LogWriter::OutputLog(const std::ostringstream &msg) const { +#ifndef USE_ANDROID_LOG +#ifdef USE_GLOG + auto submodule_name = GetSubModuleName(submodule_); + google::LogMessage("", 0, GetGlogLevel(log_level_)).stream() + << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << GetProcName() + << "):" << GetTime() << " " + << "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl; +#else + auto str_msg = msg.str(); + auto slog_module_id = (submodule_ == SM_MD ? MD : ME); + Dlog(static_cast(slog_module_id), GetSlogLevel(log_level_), "[%s:%d] %s] %s", location_.file_, location_.line_, + location_.func_, str_msg.c_str()); +#endif +#else + printf("%s [%s:%d] %s] %s\n:", EnumStrForMsLogLevel(log_level_), location_.file_, location_.line_, location_.func_, + msg.str().c_str()); +#endif +} + +void LogWriter::operator<(const LogStream &stream) const noexcept { + std::ostringstream msg; + msg << stream.sstream_->rdbuf(); + OutputLog(msg); +} + +void LogWriter::operator^(const LogStream &stream) const { + std::ostringstream msg; + msg << stream.sstream_->rdbuf(); + OutputLog(msg); + + std::ostringstream oss; + oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; + if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError && + exception_type_ != ValueError && exception_type_ != AttributeError) { + oss << ExceptionTypeToString(exception_type_) << " "; + } + oss << msg.str(); + + if (trace_provider_ != nullptr) { + trace_provider_(oss); + } + + if (exception_handler_ != nullptr) { + exception_handler_(exception_type_, oss.str()); + } + throw std::runtime_error(oss.str()); +} + +static std::string GetEnv(const std::string &envvar) { + const char *value = ::getenv(envvar.c_str()); + + if (value == nullptr) { + return std::string(); + } + + return std::string(value); +} + +enum LogConfigToken { + INVALID, // indicate invalid token + LEFT_BRACE, // '{' + RIGHT_BRACE, // '}' + VARIABLE, // '[A-Za-z][A-Za-z0-9_]*' + NUMBER, // [0-9]+ + COMMA, // ',' + COLON, // ':' + EOS, // End Of String, '\0' + NUM_LOG_CFG_TOKENS +}; + +static const char *g_tok_names[NUM_LOG_CFG_TOKENS] = { + "invalid", // indicate invalid token + "{", // '{' + "}", // '}' + "variable", // '[A-Za-z][A-Za-z0-9_]*' + "number", // [0-9]+ + ",", // ',' + ":", // ':' + "end-of-string", // End Of String, '\0' +}; + +static inline bool IsAlpha(char ch) { return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z'); } + +static inline bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; } + +class LogConfigLexer { + public: + explicit LogConfigLexer(const std::string &text) : buffer_(text) { + cur_idx_ = 0; + cur_token_ = LogConfigToken::INVALID; + } + ~LogConfigLexer() = default; + + // skip white space, and return the first char after white space + char SkipWhiteSpace() { + while (cur_idx_ < buffer_.size()) { + char ch = buffer_[cur_idx_]; + if (ch == ' ' || ch == '\t') { + ++cur_idx_; + continue; + } + return ch; + } + return '\0'; + } + + LogConfigToken GetNext(std::string *const ptr) { +#ifdef DEBUG + std::string text; + auto tok = GetNextInner(&text); + MS_LOG(DEBUG) << "Got token " << tok << " with value [" << text << "]"; + if (ptr != nullptr) { + *ptr = text; + } + return tok; + } + + LogConfigToken GetNextInner(std::string *ptr) { +#endif + char ch = SkipWhiteSpace(); + // clang-format off + static const std::map single_char_map = { + {'{', LogConfigToken::LEFT_BRACE}, + {'}', LogConfigToken::RIGHT_BRACE}, + {',', LogConfigToken::COMMA}, + {':', LogConfigToken::COLON}, + {'\0', LogConfigToken::EOS}, + }; + // clang-format on + + auto iter = single_char_map.find(ch); + if (iter != single_char_map.end()) { + if (ptr != nullptr) { + *ptr = std::string() + ch; + } + ++cur_idx_; + return iter->second; + } else if (IsAlpha(ch)) { + std::ostringstream oss; + do { + oss << ch; + ch = buffer_[++cur_idx_]; + } while (cur_idx_ < buffer_.size() && (IsAlpha(ch) || IsDigit(ch) || ch == '_')); + if (ptr != nullptr) { + *ptr = std::string(oss.str()); + } + return LogConfigToken::VARIABLE; + } else if (IsDigit(ch)) { + std::ostringstream oss; + do { + oss << ch; + ch = buffer_[++cur_idx_]; + } while (cur_idx_ < buffer_.size() && IsDigit(ch)); + if (ptr != nullptr) { + *ptr = std::string(oss.str()); + } + return LogConfigToken::NUMBER; + } + return LogConfigToken::INVALID; + } + + private: + std::string buffer_; + size_t cur_idx_; + + LogConfigToken cur_token_; + std::string cur_text_; +}; + +class LogConfigParser { + public: + explicit LogConfigParser(const std::string &cfg) : lexer(cfg) {} + ~LogConfigParser() = default; + + bool Expect(LogConfigToken expected, LogConfigToken tok) { + if (expected != tok) { + MS_LOG(WARNING) << "Parse submodule log configuration text error, expect `" << g_tok_names[expected] + << "`, but got `" << g_tok_names[tok] << "`. The whole configuration will be ignored."; + return false; + } + return true; + } + + // The text of config MS_SUBMODULE_LOG_v is in the form {submodule1:log_level1,submodule2:log_level2,...}. + // Valid values of log levels are: 0 - debug, 1 - info, 2 - warning, 3 - error + // e.g. MS_SUBMODULE_LOG_v={PARSER:0, ANALYZER:2, PIPELINE:1} + std::map Parse() { + std::map log_levels; + + bool flag_error = false; + std::string text; + auto tok = lexer.GetNext(&text); + + // empty string + if (tok == LogConfigToken::EOS) { + return log_levels; + } + + if (!Expect(LogConfigToken::LEFT_BRACE, tok)) { + return log_levels; + } + + do { + std::string key, val; + tok = lexer.GetNext(&key); + if (!Expect(LogConfigToken::VARIABLE, tok)) { + flag_error = true; + break; + } + + tok = lexer.GetNext(&text); + if (!Expect(LogConfigToken::COLON, tok)) { + flag_error = true; + break; + } + + tok = lexer.GetNext(&val); + if (!Expect(LogConfigToken::NUMBER, tok)) { + flag_error = true; + break; + } + + log_levels[key] = val; + tok = lexer.GetNext(&text); + } while (tok == LogConfigToken::COMMA); + + if (!flag_error && !Expect(LogConfigToken::RIGHT_BRACE, tok)) { + flag_error = true; + } + + if (flag_error) { + log_levels.clear(); + } + return log_levels; + } + + private: + LogConfigLexer lexer; +}; + +bool ParseLogLevel(const std::string &str_level, MsLogLevel *ptr_level) { + if (str_level.size() == 1) { + int ch = str_level.c_str()[0]; + ch = ch - '0'; // substract ASCII code of '0', which is 48 + if (ch >= DEBUG && ch <= ERROR) { + if (ptr_level != nullptr) { + *ptr_level = static_cast(ch); + } + return true; + } + } + return false; +} + +static MsLogLevel GetGlobalLogLevel() { +#ifdef USE_GLOG + return static_cast(FLAGS_v); +#else + int log_level = WARNING; // set default log level to WARNING + auto str_level = GetEnv("GLOG_v"); + if (str_level.size() == 1) { + int ch = str_level.c_str()[0]; + ch = ch - '0'; // substract ASCII code of '0', which is 48 + if (ch >= DEBUG && ch <= ERROR) { + log_level = ch; + } + } + return static_cast(log_level); +#endif +} + +void InitSubModulesLogLevel() { + // initialize submodule's log level using global + auto global_log_level = GetGlobalLogLevel(); + for (int i = 0; i < NUM_SUBMODUES; ++i) { + g_ms_submodule_log_levels[i] = global_log_level; + } + + // set submodule's log level + auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); + MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; + LogConfigParser parser(submodule); + auto configs = parser.Parse(); + for (const auto &cfg : configs) { + int mod_idx = -1; + for (int i = 0; i < NUM_SUBMODUES; ++i) { + if (cfg.first == GetSubModuleName(static_cast(i))) { + mod_idx = i; + break; + } + } + if (mod_idx < 0) { + MS_LOG(WARNING) << "Undefined module name " << cfg.first << ", ignore it"; + continue; + } + MsLogLevel submodule_log_level; + if (!ParseLogLevel(cfg.second, &submodule_log_level)) { + MS_LOG(WARNING) << "Illegal log level value " << cfg.second << " for " << cfg.first << ", ignore it."; + continue; + } + g_ms_submodule_log_levels[mod_idx] = submodule_log_level; + } +} +} // namespace mindspore + +extern "C" { +#if defined(_WIN32) || defined(_WIN64) +__attribute__((constructor)) void common_log_init(void) { +#else +void common_log_init(void) { +#endif +#ifdef USE_GLOG + // do not use glog predefined log prefix + FLAGS_log_prefix = false; + // set default log level to WARNING + if (mindspore::GetEnv("GLOG_v").empty()) { + FLAGS_v = mindspore::WARNING; + } + + // set default log file mode to 0640 + if (mindspore::GetEnv("GLOG_logfile_mode").empty()) { + FLAGS_logfile_mode = 0640; + } + std::string logtostderr = mindspore::GetEnv("GLOG_logtostderr"); + // default print log to screen + if (logtostderr.empty()) { + FLAGS_logtostderr = true; + } else if (logtostderr == "0" && mindspore::GetEnv("GLOG_log_dir").empty()) { + FLAGS_logtostderr = true; + MS_LOG(WARNING) << "`GLOG_log_dir` is not set, output log to screen."; + } +#endif + mindspore::InitSubModulesLogLevel(); +} + +// shared lib init hook +#if defined(_WIN32) || defined(_WIN64) +__attribute__((constructor)) void mindspore_log_init(void) { +#else +void mindspore_log_init(void) { +#endif +#ifdef USE_GLOG + static bool is_glog_initialzed = false; + if (!is_glog_initialzed) { +#if !defined(_WIN32) && !defined(_WIN64) + google::InitGoogleLogging("mindspore"); +#endif + is_glog_initialzed = true; + } +#endif + common_log_init(); +} +} diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h new file mode 100644 index 0000000000..ce31ce9ab8 --- /dev/null +++ b/mindspore/core/utils/log_adapter.h @@ -0,0 +1,207 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_LOG_ADAPTER_H_ +#define MINDSPORE_CORE_UTILS_LOG_ADAPTER_H_ + +#include +#include +#include +#include +#include +#include +#include "utils/overload.h" +#include "./securec.h" +#ifndef USE_ANDROID_LOG +#ifdef USE_GLOG +#include "glog/logging.h" +#else +#include "toolchain/slog.h" +#endif +#endif +// NOTICE: when relative path of 'log_adapter.h' changed, macro 'LOG_HDR_FILE_REL_PATH' must be changed +#define LOG_HDR_FILE_REL_PATH "mindspore/core/utils/log_adapter.h" + +// Get start index of file relative path in __FILE__ +static constexpr int GetRelPathPos() noexcept { + return sizeof(__FILE__) > sizeof(LOG_HDR_FILE_REL_PATH) ? sizeof(__FILE__) - sizeof(LOG_HDR_FILE_REL_PATH) : 0; +} + +namespace mindspore { +#define FILE_NAME \ + (sizeof(__FILE__) > GetRelPathPos() ? static_cast(__FILE__) + GetRelPathPos() \ + : static_cast(__FILE__)) +enum ExceptionType { + NoExceptionType = 0, + UnknownError, + ArgumentError, + NotSupportError, + NotExistsError, + AlreadyExistsError, + UnavailableError, + DeviceProcessError, + AbortedError, + TimeOutError, + ResourceUnavailable, + NoPermissionError, + IndexError, + ValueError, + TypeError, + AttributeError, +}; + +struct LocationInfo { + LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {} + ~LocationInfo() = default; + + const char *file_; + int line_; + const char *func_; +}; + +class LogStream { + public: + LogStream() { sstream_ = std::make_shared(); } + ~LogStream() = default; + + template + LogStream &operator<<(const T &val) noexcept { + (*sstream_) << val; + return *this; + } + + LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept { + (*sstream_) << func; + return *this; + } + + friend class LogWriter; + + private: + std::shared_ptr sstream_; +}; + +template ::value, int>::type = 0> +constexpr std::ostream &operator<<(std::ostream &stream, const T &value) { + return stream << static_cast::type>(value); +} + +enum MsLogLevel : int { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION }; + +enum SubModuleId : int { + SM_UNKNOWN = 0, // unknown submodule + SM_BASE, // base + SM_ANALYZER, // static analyzer + SM_COMMON, // common + SM_DEBUG, // debug + SM_DEVICE, // device + SM_GE_ADPT, // ge adapter + SM_IR, // IR + SM_KERNEL, // kernel + SM_MD, // MindData + SM_ME, // MindExpression + SM_ONNX, // ONNX + SM_OPTIMIZER, // optimzer + SM_PARALLEL, // parallel + SM_PARSER, // parser + SM_PIPELINE, // ME pipeline + SM_PRE_ACT, // pre-activate + SM_PYNATIVE, // PyNative + SM_SESSION, // session + SM_UTILS, // utils + SM_VM, // VM + SM_ABSTRACT, // abstract + NUM_SUBMODUES // number of submodules +}; + +#ifndef SUBMODULE_ID +#define SUBMODULE_ID mindspore::SubModuleId::SM_ME +#endif + +const char *EnumStrForMsLogLevel(MsLogLevel level); + +#if defined(_WIN32) || defined(_WIN64) +extern int g_ms_submodule_log_levels[] __attribute__((dllexport)); +#else +extern int g_ms_submodule_log_levels[] __attribute__((visibility("default"))); +#endif + +class LogWriter { + public: + using ExceptionHandler = std::function; + using TraceProvider = std::function; + + LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, + ExceptionType excp_type = NoExceptionType) + : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} + ~LogWriter() = default; + + void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); + void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); + + static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; } + static void set_trace_provider(TraceProvider trace_provider) { trace_provider_ = trace_provider; } + + private: + void OutputLog(const std::ostringstream &msg) const; + + LocationInfo location_; + MsLogLevel log_level_; + SubModuleId submodule_; + ExceptionType exception_type_; + + inline static ExceptionHandler exception_handler_ = nullptr; + inline static TraceProvider trace_provider_ = nullptr; +}; + +#define MSLOG_IF(level, condition, excp_type) \ + static_cast(0), !(condition) \ + ? void(0) \ + : mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), level, \ + SUBMODULE_ID, excp_type) < mindspore::LogStream() +#define MSLOG_THROW(excp_type) \ + mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), mindspore::EXCEPTION, SUBMODULE_ID, \ + excp_type) ^ \ + mindspore::LogStream() + +#define IS_OUTPUT_ON(level) (level) >= mindspore::g_ms_submodule_log_levels[SUBMODULE_ID] + +#define MS_LOG(level) MS_LOG_##level + +#define MS_LOG_DEBUG MSLOG_IF(mindspore::DEBUG, IS_OUTPUT_ON(mindspore::DEBUG), mindspore::NoExceptionType) +#define MS_LOG_INFO MSLOG_IF(mindspore::INFO, IS_OUTPUT_ON(mindspore::INFO), mindspore::NoExceptionType) +#define MS_LOG_WARNING MSLOG_IF(mindspore::WARNING, IS_OUTPUT_ON(mindspore::WARNING), mindspore::NoExceptionType) +#define MS_LOG_ERROR MSLOG_IF(mindspore::ERROR, IS_OUTPUT_ON(mindspore::ERROR), mindspore::NoExceptionType) + +#define MS_LOG_EXCEPTION MSLOG_THROW(mindspore::NoExceptionType) +#define MS_EXCEPTION(type) MSLOG_THROW(type) +} // namespace mindspore + +#define MS_EXCEPTION_IF_NULL(ptr) \ + do { \ + if ((ptr) == nullptr) { \ + MS_LOG(EXCEPTION) << ": The pointer[" << #ptr << "] is null."; \ + } \ + } while (0) + +#ifdef DEBUG +#include +#define MS_ASSERT(f) assert(f) +#else +#define MS_ASSERT(f) ((void)0) +#endif + +#endif // MINDSPORE_CORE_UTILS_LOG_ADAPTER_H_ diff --git a/mindspore/core/utils/misc.cc b/mindspore/core/utils/misc.cc new file mode 100644 index 0000000000..ae93c750ee --- /dev/null +++ b/mindspore/core/utils/misc.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/misc.h" + +namespace mindspore { +const int RET_SUCCESS = 0; +const int RET_FAILED = 1; +const int RET_CONTINUE = 2; +const int RET_BREAK = 3; + +std::string demangle(const char *name) { + int status = -1; + std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; + return (status == 0) ? res.get() : name; +} +} // namespace mindspore diff --git a/mindspore/core/utils/misc.h b/mindspore/core/utils/misc.h new file mode 100644 index 0000000000..9fae5abbcf --- /dev/null +++ b/mindspore/core/utils/misc.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_MISC_H_ +#define MINDSPORE_CORE_UTILS_MISC_H_ + +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +extern const int RET_SUCCESS; +extern const int RET_FAILED; +extern const int RET_CONTINUE; +extern const int RET_BREAK; + +// demangle the name to make it human reablable. +extern std::string demangle(const char *name); + +} // namespace mindspore +#endif // MINDSPORE_CORE_UTILS_MISC_H_ diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc new file mode 100644 index 0000000000..a1ad034c00 --- /dev/null +++ b/mindspore/core/utils/ms_context.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/ms_context.h" +#include +#include +#include +#include "ir/tensor.h" +#include "utils/ms_utils.h" + +namespace mindspore { +std::atomic thread_1_must_end(false); + +std::shared_ptr MsContext::inst_context_ = nullptr; +std::map MsContext::policy_map_ = {{"ge", kMsBackendGePrior}, + {"vm", kMsBackendVmOnly}, + {"ms", kMsBackendMsPrior}, + {"ge_only", kMsBackendGeOnly}, + {"vm_prior", kMsBackendVmPrior}}; + +MsContext::MsContext(const std::string &policy, const std::string &target) { + save_graphs_flag_ = false; + save_graphs_path_ = "."; + enable_dump_ = false; + save_dump_path_ = "."; + tsd_ref_ = 0; + ge_ref_ = 0; + is_multi_graph_sink_ = false; + is_pynative_ge_init_ = false; + enable_reduce_precision_ = true; + auto env_device = common::GetEnv("DEVICE_ID"); + if (!env_device.empty()) { + device_id_ = UlongToUint(std::stoul(env_device.c_str())); + } else { + device_id_ = 0; + } + backend_policy_ = policy_map_[policy]; + device_target_ = target; + execution_mode_ = kPynativeMode; + enable_task_sink_ = true; + ir_fusion_flag_ = true; + enable_hccl_ = false; +#ifdef ENABLE_DEBUGGER + enable_mem_reuse_ = false; +#else + enable_mem_reuse_ = true; +#endif + enable_gpu_summary_ = true; + precompile_only_ = false; + auto_mixed_precision_flag_ = false; + enable_pynative_infer_ = false; + enable_pynative_hook_ = false; + enable_dynamic_mem_pool_ = true; + graph_memory_max_size_ = "0"; + variable_memory_max_size_ = "0"; + enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice; + profiling_mode_ = false; + profiling_options_ = "training_trace"; + check_bprop_flag_ = false; + max_device_memory_ = kDefaultMaxDeviceMemory; + print_file_path_ = ""; + enable_graph_kernel_ = false; + enable_sparse_ = false; +} + +std::shared_ptr MsContext::GetInstance() { + if (inst_context_ == nullptr) { + MS_LOG(DEBUG) << "Create new mindspore context"; + if (device_type_seter_) { + device_type_seter_(inst_context_); + } + } + return inst_context_; +} + +bool MsContext::set_backend_policy(const std::string &policy) { + if (policy_map_.find(policy) == policy_map_.end()) { + MS_LOG(ERROR) << "invalid backend policy name: " << policy; + return false; + } + backend_policy_ = policy_map_[policy]; + MS_LOG(INFO) << "ms set context backend policy:" << policy; + return true; +} + +std::string MsContext::backend_policy() const { + auto res = std::find_if( + policy_map_.begin(), policy_map_.end(), + [&, this](const std::pair &item) { return item.second == backend_policy_; }); + if (res != policy_map_.end()) { + return res->first; + } + return "unknown"; +} + +void MsContext::set_execution_mode(int execution_mode) { + if (execution_mode != kGraphMode && execution_mode != kPynativeMode) { + MS_LOG(EXCEPTION) << "The execution mode is invalid!"; + } + execution_mode_ = execution_mode; +} + +bool MsContext::set_device_target(const std::string &target) { + if (kTargetSet.find(target) == kTargetSet.end()) { + MS_LOG(ERROR) << "invalid device target name: " << target; + return false; + } + if (target == kDavinciDevice) { + device_target_ = kAscendDevice; + } else { + device_target_ = target; + } + if (seter_) { + seter_(device_target_); + } + MS_LOG(INFO) << "ms set context device target:" << target; + return true; +} + +bool MsContext::set_device_id(uint32_t device_id) { + device_id_ = device_id; + MS_LOG(INFO) << "ms set context device id:" << device_id; + return true; +} + +void MsContext::set_tsd_ref(const std::string &op) { + if (op == "--") { + tsd_ref_--; + } else if (op == "++") { + tsd_ref_++; + } else { + tsd_ref_ = 0; + } +} + +void MsContext::set_ge_ref(const std::string &op) { + if (op == "--") { + ge_ref_--; + } else if (op == "++") { + ge_ref_++; + } else { + ge_ref_ = 0; + } +} +} // namespace mindspore diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h new file mode 100644 index 0000000000..9ad3259b24 --- /dev/null +++ b/mindspore/core/utils/ms_context.h @@ -0,0 +1,207 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ +#define MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +enum MsBackendPolicy { + kMsBackendGeOnly = 0, + kMsBackendVmOnly = 1, + kMsBackendGePrior = 2, + kMsBackendVmPrior = 3, + kMsBackendMsPrior = 4, + kMsBackendUnknown = 5, +}; + +const int kGraphMode = 0; +const int kPynativeMode = 1; +const char kCPUDevice[] = "CPU"; +const char kGPUDevice[] = "GPU"; +const char kAscendDevice[] = "Ascend"; +const char kDavinciInferenceDevice[] = "AscendInference"; +const char kDavinciDevice[] = "Davinci"; +const char KNpuLog[] = "_npu_log"; +const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; +// The default max available device memory is 1024GB. +const float kDefaultMaxDeviceMemory = 1024; + +class MsContext { + public: + MsContext(const std::string &backend_policy, const std::string &target); + ~MsContext() = default; + MsContext(const MsContext &) = delete; + MsContext &operator=(const MsContext &) = delete; + using DeviceSeter = std::function; + using DeviceTypeSeter = std::function &)>; + static std::shared_ptr GetInstance(); + + std::string backend_policy() const; + bool set_backend_policy(const std::string &policy); + + int execution_mode() const { return execution_mode_; } + void set_execution_mode(int execution_mode); + + bool enable_pynative_infer() const { return enable_pynative_infer_; } + void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } + + bool enable_pynative_hook() const { return enable_pynative_hook_; } + void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; } + + bool enable_task_sink() const { return enable_task_sink_; } + + void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } + bool precompile_only() const { return precompile_only_; } + + std::string device_target() const { return device_target_; } + bool set_device_target(const std::string &target); + + uint32_t device_id() const { return device_id_; } + bool set_device_id(uint32_t device_id); + + bool save_graphs_flag() const { return save_graphs_flag_; } + void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } + + std::string save_graphs_path() const { return save_graphs_path_; } + void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } + + bool IsGeInited() { return ge_ref_ > 0; } + void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; } + bool enable_hccl() const { return enable_hccl_; } + bool ir_fusion_flag() const { return ir_fusion_flag_; } + bool loop_sink_flag() const { return enable_loop_sink_; } + void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; } + void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } + bool enable_mem_reuse() const { return enable_mem_reuse_; } + + void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } + bool enable_gpu_summary() const { return enable_gpu_summary_; } + + void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) { + auto_mixed_precision_flag_ = auto_mixed_precision_flag; + } + bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; } + + void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; } + bool enable_reduce_precision() const { return enable_reduce_precision_; } + + void set_enable_dump(bool flag) { enable_dump_ = flag; } + bool enable_dump() const { return enable_dump_; } + + void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } + std::string save_dump_path() const { return save_dump_path_; } + + bool IsTsdOpened() const { return tsd_ref_ > 0; } + void set_tsd_ref(const std::string &op); + uint32_t tsd_ref() const { return tsd_ref_; } + + void set_ge_ref(const std::string &op); + uint32_t ge_ref() const { return ge_ref_; } + + bool is_pynative_ge_init() { return is_pynative_ge_init_; } + void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; } + + bool is_multi_graph_sink() const { return is_multi_graph_sink_; } + void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } + + void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } + bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } + + void set_graph_memory_max_size(const std::string &graph_memory_max_size) { + graph_memory_max_size_ = graph_memory_max_size; + } + + void set_variable_memory_max_size(const std::string &variable_memory_max_size) { + variable_memory_max_size_ = variable_memory_max_size; + } + + const std::string &variable_memory_max_size() const { return variable_memory_max_size_; } + + const std::string &graph_memory_max_size() const { return graph_memory_max_size_; } + + void set_enable_profiling(bool flag) { profiling_mode_ = flag; } + bool enable_profiling() const { return profiling_mode_; } + + void set_profiling_options(const std::string &options) { profiling_options_ = options; } + std::string profiling_options() const { return profiling_options_; } + bool check_bprop_flag() const { return check_bprop_flag_; } + void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } + void set_print_file_path(const std::string &file) { print_file_path_ = file; } + const std::string &print_file_path() const { return print_file_path_; } + + float max_device_memory() const { return max_device_memory_; } + void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } + + void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } + bool enable_graph_kernel() const { return enable_graph_kernel_; } + + bool enable_sparse() const { return enable_sparse_; } + void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } + static void device_seter(DeviceSeter device) { seter_ = device; } + static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } + + std::thread tdt_print_; + + private: + inline static DeviceSeter seter_ = nullptr; + inline static DeviceTypeSeter device_type_seter_ = nullptr; + static std::shared_ptr inst_context_; + static std::map policy_map_; + MsBackendPolicy backend_policy_; + std::string device_target_; + uint32_t device_id_; + int execution_mode_; + bool enable_pynative_infer_; + bool enable_pynative_hook_; + bool save_graphs_flag_; + std::string save_graphs_path_; + uint32_t tsd_ref_; + uint32_t ge_ref_; + bool enable_task_sink_; + bool enable_hccl_; + bool precompile_only_; + bool ir_fusion_flag_; + bool auto_mixed_precision_flag_; + bool enable_reduce_precision_; + bool enable_loop_sink_; + bool enable_mem_reuse_; + bool enable_gpu_summary_; + bool enable_dump_; + std::string save_dump_path_; + bool is_multi_graph_sink_; + bool is_pynative_ge_init_; + bool enable_dynamic_mem_pool_; + std::string graph_memory_max_size_; + std::string variable_memory_max_size_; + bool profiling_mode_; + std::string profiling_options_; + bool check_bprop_flag_; + float max_device_memory_; + std::string print_file_path_; + bool enable_graph_kernel_; + bool enable_sparse_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ diff --git a/mindspore/core/utils/ms_utils.cc b/mindspore/core/utils/ms_utils.cc new file mode 100644 index 0000000000..f6a1567a7b --- /dev/null +++ b/mindspore/core/utils/ms_utils.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "utils/ms_utils.h" +#include +#include +#include + +namespace mindspore { +namespace common { +const int CACHED_STR_NUM = 1 << 8; +const int CACHED_STR_MASK = CACHED_STR_NUM - 1; +std::vector STR_HOLDER(CACHED_STR_NUM); +const char *SafeCStr(const std::string &&str) { + static std::atomic index{0}; + uint32_t cur_index = index++; + cur_index = cur_index & CACHED_STR_MASK; + STR_HOLDER[cur_index] = str; + return STR_HOLDER[cur_index].c_str(); +} +} // namespace common +} // namespace mindspore diff --git a/mindspore/ccsrc/common/utils.h b/mindspore/core/utils/ms_utils.h similarity index 100% rename from mindspore/ccsrc/common/utils.h rename to mindspore/core/utils/ms_utils.h diff --git a/mindspore/core/utils/ordered_map.h b/mindspore/core/utils/ordered_map.h new file mode 100644 index 0000000000..1152990e46 --- /dev/null +++ b/mindspore/core/utils/ordered_map.h @@ -0,0 +1,201 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ +#define MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +// Implementation of OrderedMap that keeps insertion order +// using unordered_map to improve the performance of find/erase, and use list to keep insertion order +template , class Equal = std::equal_to> +class OrderedMap { + public: + using key_t = KeyT; + using value_t = ValueT; + using hasher = Hash; + using equal = Equal; + using pair_type = std::pair; + using sequential_type = std::list; + using iterator = typename sequential_type::iterator; + using const_iterator = typename sequential_type::const_iterator; + using reverse_iterator = typename sequential_type::reverse_iterator; + using const_reverse_iterator = typename sequential_type::const_reverse_iterator; + using map_type = std::unordered_map; + using value_type = typename sequential_type::value_type; + using size_type = typename sequential_type::size_type; + + iterator begin() { return sequential_data_.begin(); } + iterator end() { return sequential_data_.end(); } + const_iterator begin() const { return sequential_data_.cbegin(); } + const_iterator end() const { return sequential_data_.cend(); } + const_iterator cbegin() const { return sequential_data_.cbegin(); } + const_iterator cend() const { return sequential_data_.cend(); } + + reverse_iterator rbegin() { return sequential_data_.rbegin(); } + reverse_iterator rend() { return sequential_data_.rend(); } + const_reverse_iterator rbegin() const { return sequential_data_.rbegin(); } + const_reverse_iterator rend() const { return sequential_data_.rend(); } + + pair_type &front() { return sequential_data_.front(); } + const pair_type &front() const { return sequential_data_.front(); } + pair_type &back() { return sequential_data_.back(); } + const pair_type &back() const { return sequential_data_.back(); } + + OrderedMap() = default; + ~OrderedMap() = default; + OrderedMap(const OrderedMap &os) { + for (auto &item : os.sequential_data_) { + (void)insert(pair_type(item.first, item.second)); + } + } + + // Explicitly construct OrderedMap use sequential_type + explicit OrderedMap(const sequential_type &other) { + for (auto &item : other) { + (void)insert(pair_type(item.first, item.second)); + } + } + + OrderedMap &operator=(const OrderedMap &os) { + if (this != &os) { + for (auto &item : os.sequential_data_) { + (void)insert(pair_type(item.first, item.second)); + } + } + return *this; + } + + void clear() { + if (!map_data_.empty()) { + map_data_.clear(); + } + sequential_data_.clear(); + } + + void swap(OrderedMap &rhs) { + std::swap(map_data_, rhs.map_data_); + std::swap(sequential_data_, rhs.sequential_data_); + } + + void reserve(size_type num_entries) { + map_data_.reserve(num_entries); + sequential_data_.reserve(num_entries); + } + + std::pair add(const key_t &key) { + iterator empty_itr; + std::pair map_pair = std::make_pair(key, empty_itr); + std::pair result = map_data_.insert(map_pair); + auto &seq_itr = result.first->second; + if (result.second) { + auto it = sequential_data_.insert(sequential_data_.end(), std::make_pair(key, ValueT())); + seq_itr = it; + } + return std::pair(seq_itr, result.second); + } + + ValueT &operator[](const key_t &key) { + auto result = add(key); + return (*result.first).second; + } + + std::pair insert(const pair_type &kv) { + auto result = add(kv.first); + if (result.second) { + *(result.first) = kv; + return std::make_pair(std::prev(end()), true); + } + return std::make_pair(result.first, false); + } + + std::pair insert(pair_type &&kv) { + iterator empty_itr; + std::pair map_pair = std::make_pair(kv.first, empty_itr); + std::pair result = map_data_.insert(map_pair); + auto &seq_itr = result.first->second; + if (result.second) { + auto it = sequential_data_.insert(sequential_data_.end(), std::move(kv)); + seq_itr = it; + return std::make_pair(std::prev(end()), true); + } + return std::make_pair(seq_itr, false); + } + + bool empty() const { return sequential_data_.empty(); } + + size_type size() const { return sequential_data_.size(); } + + size_type count(const key_t &key) const { + auto pos = map_data_.find(key); + return pos == map_data_.end() ? 0 : 1; + } + + iterator find(const key_t &key) { + typename map_type::const_iterator pos = map_data_.find(key); + return pos == map_data_.end() ? sequential_data_.end() : (pos->second); + } + + const_iterator find(const key_t &key) const { + auto pos = map_data_.find(key); + return pos == map_data_.end() ? sequential_data_.end() : (pos->second); + } + + // Remove the last element from the sequential_data_. + void pop_back() { + typename map_type::iterator pos = map_data_.find(sequential_data_.back().first); + map_data_.erase(pos); + sequential_data_.pop_back(); + } + + // Remove the first element from the sequential_data_. + void pop_front() { + typename map_type::iterator pos = map_data_.find(sequential_data_.first().first); + map_data_.erase(pos); + sequential_data_.pop_front(); + } + + // Remove the element given by Iterator. + typename sequential_type::iterator erase(const typename sequential_type::iterator &itr) { + (void)map_data_.erase(itr->first); + auto next = sequential_data_.erase(itr); + if (next == sequential_data_.end()) return next; + return next; + } + + // Remove the element with the given key + size_type erase(const key_t &key) { + auto itr = find(key); + if (itr == end()) return 0; + (void)erase(itr); + return 1; + } + + private: + map_type map_data_; + sequential_type sequential_data_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_ORDERED_MAP_H_ diff --git a/mindspore/core/utils/ordered_set.h b/mindspore/core/utils/ordered_set.h new file mode 100644 index 0000000000..ee3aef46f3 --- /dev/null +++ b/mindspore/core/utils/ordered_set.h @@ -0,0 +1,283 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_ORDERED_SET_H_ +#define MINDSPORE_CORE_UTILS_ORDERED_SET_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { + +// Implementation of OrderedSet that keeps insertion order +// using map as set, and use list as a sequential container to record elements to keep insertion order +template , class KeyEqual = std::equal_to> +class OrderedSet { + public: + using element_type = T; + using hasher = Hash; + using equal = KeyEqual; + using sequential_type = std::list; + using vector_type = std::vector; + using iterator = typename sequential_type::iterator; + using const_iterator = typename sequential_type::const_iterator; + using reverse_iterator = typename sequential_type::reverse_iterator; + using const_reverse_iterator = typename sequential_type::const_reverse_iterator; + using map_type = std::unordered_map; + using ordered_set_type = OrderedSet; + + OrderedSet() = default; + ~OrderedSet() = default; + // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, + // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use + // traversal to build elements. + OrderedSet(const OrderedSet &os) { + for (auto &item : os.ordered_data_) { + add(item); + } + } + + explicit OrderedSet(const sequential_type &other) { + for (auto &item : other) { + add(item); + } + } + + // Explicitly construct an OrderedSet use vector + explicit OrderedSet(const vector_type &other) { + for (auto &item : other) { + add(item); + } + } + + OrderedSet &operator=(const OrderedSet &os) { + if (this != &os) { + for (auto &item : os.ordered_data_) { + add(item); + } + } + return *this; + } + + // Add an element to the OrderedSet, without judging return value + void add(const element_type &e) { (void)insert(e); } + + // insert an element to the OrderedSet + std::pair insert(const element_type &e) { + iterator empty_itr; + std::pair map_pair = std::make_pair(e, empty_itr); + auto result = mapped_data_.insert(map_pair); + auto &seq_idx = result.first->second; + // if insert success; + if (result.second) { + auto it = ordered_data_.insert(ordered_data_.end(), e); + seq_idx = it; + } + return std::pair(seq_idx, result.second); + } + + // Remove an element, if removed return true, otherwise return false + bool erase(const element_type &e) { + auto pos = mapped_data_.find(e); + if (pos == mapped_data_.end()) { + return false; + } + // erase the sequential data first + (void)ordered_data_.erase(pos->second); + (void)mapped_data_.erase(pos); + return true; + } + + // Return the container size + std::size_t size() const { return mapped_data_.size(); } + + bool empty() const { return mapped_data_.size() == 0; } + + // Return the string contents in orderset, using ordered_data + std::string toString() { + std::ostringstream res; + res << "orderset content:\n"; + for (auto &item : ordered_data_) { + res << std::to_string(reinterpret_cast(item.get())) << " "; + } + return res.str(); + } + + // Clear the elements + void clear() { + if (!mapped_data_.empty()) { + mapped_data_.clear(); + } + ordered_data_.clear(); + } + + // Compare two orderedset, if the order is not equal shall return false + bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } + + // Remove and return the first element in the OrderedSet + T pop() { + if (ordered_data_.size() != 0) { + T res = ordered_data_.front(); + (void)mapped_data_.erase(res); + (void)ordered_data_.erase(ordered_data_.begin()); + return res; + } + MS_LOG(EXCEPTION) << "pop() on empty OrderedSet"; + } + + T back() { + if (ordered_data_.size() != 0) { + return ordered_data_.back(); + } + MS_LOG(EXCEPTION) << "back() on empty OrderedSet"; + } + + // Return true if there are no common elements + bool is_disjoint(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { + if (mapped_data_.find(item) != mapped_data_.end()) { + return false; + } + } + return true; + } + + // Test whether this is subset of other + bool is_subset(const OrderedSet &other) { + for (auto &item : ordered_data_) { + if (other.mapped_data_.find(item) == other.mapped_data_.end()) { + return false; + } + } + return true; + } + + // Add elements in other to this orderedset + void update(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { + add(item); + } + } + + void update(const std::shared_ptr &other) { update(*other); } + + void update(const sequential_type &other) { + for (auto &item : other) { + add(item); + } + } + + void update(const vector_type &other) { + for (auto &item : other) { + add(item); + } + } + + ordered_set_type get_union(const OrderedSet &other) { + ordered_set_type res(ordered_data_); + res.update(other); + return res; + } + + // Get the union with other set, this operator may cost time because of copy + ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } + + // Return the intersection of two sets + ordered_set_type intersection(const OrderedSet &other) { + ordered_set_type res(ordered_data_); + for (auto &item : ordered_data_) { + if (other.mapped_data_.find(item) == other.mapped_data_.end()) { + (void)res.erase(item); + } + } + return res; + } + ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } + + // Return the symmetric difference of two sets + ordered_set_type symmetric_difference(const OrderedSet &other) { + ordered_set_type res(ordered_data_); + for (auto &item : other.ordered_data_) { + if (mapped_data_.find(item) != mapped_data_.end()) { + (void)res.erase(item); + } else { + res.add(item); + } + } + return res; + } + + ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } + + // Remove elements which is also in others. + void difference_update(const OrderedSet &other) { + // use vector traversal, to keep ordrer + for (auto &item : other.ordered_data_) { + (void)erase(item); + } + } + + void difference_update(const sequential_type &other) { + for (auto &item : other) { + (void)erase(item); + } + } + + void difference_update(const vector_type &other) { + for (auto &item : other) { + (void)erase(item); + } + } + + // Return the set with elements that are not in the others + ordered_set_type difference(const OrderedSet &other) { + ordered_set_type res(ordered_data_); + res.difference_update(other); + return res; + } + ordered_set_type operator-(const OrderedSet &other) { return difference(other); } + + bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } + + // Return the count of an element in set + std::size_t count(const element_type &e) const { return mapped_data_.count(e); } + + iterator begin() { return ordered_data_.begin(); } + iterator end() { return ordered_data_.end(); } + + const_iterator begin() const { return ordered_data_.cbegin(); } + const_iterator end() const { return ordered_data_.cend(); } + + const_iterator cbegin() const { return ordered_data_.cbegin(); } + const_iterator cend() const { return ordered_data_.cend(); } + + private: + map_type mapped_data_; + sequential_type ordered_data_; +}; + +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_ORDERED_SET_H_ diff --git a/mindspore/core/utils/overload.h b/mindspore/core/utils/overload.h new file mode 100644 index 0000000000..baeb134697 --- /dev/null +++ b/mindspore/core/utils/overload.h @@ -0,0 +1,138 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_OVERLOAD_H_ +#define MINDSPORE_CORE_UTILS_OVERLOAD_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore { +template +std::ostream &operator<<(std::ostream &out, const std::vector &v) { + out << "[const vector]["; + size_t last = v.size() - 1; + for (size_t i = 0; i < v.size(); ++i) { + out << v[i]; + if (i != last) out << ", "; + } + out << "]"; + return out; +} + +template +std::ostream &operator<<(std::ostream &os, const std::list &vec) { + bool begin = true; + os << "[const list]["; + for (auto &item : vec) { + if (!begin) { + os << ", "; + } else { + begin = false; + } + os << item; + } + os << "]"; + + return os; +} + +template +std::ostream &operator<<(std::ostream &os, const std::initializer_list &vec) { + bool begin = true; + os << "["; + for (auto &item : vec) { + if (!begin) { + os << ", "; + } else { + begin = false; + } + os << item; + } + os << "]"; + + return os; +} + +template +bool operator==(const std::initializer_list &lhs, const std::initializer_list &rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + auto lit = lhs.begin(); + auto rit = rhs.begin(); + while (lit != lhs.end()) { + if (!(*lit == *rit)) { + return false; + } + lit++; + rit++; + } + return true; +} + +template +std::ostream &operator<<(std::ostream &os, const std::pair &pair) { + os << "[const pair]"; + + return os; +} + +template +std::ostream &operator<<(std::ostream &os, const std::unordered_map &map) { + os << "[const unordered_map]"; + return os; +} + +template +std::ostream &operator<<(std::ostream &os, const std::map &map) { + os << "[const map]"; + return os; +} + +template +std::string ToString(const std::vector &vec) { + std::ostringstream buffer; + + buffer << vec; + return buffer.str(); +} + +template +std::string ToString(const std::unordered_map &map) { + std::ostringstream buffer; + + buffer << map; + return buffer.str(); +} + +template +std::string ToString(const std::map &map) { + std::ostringstream buffer; + + buffer << map; + return buffer.str(); +} +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_OVERLOAD_H_ diff --git a/mindspore/core/utils/profile.cc b/mindspore/core/utils/profile.cc new file mode 100644 index 0000000000..384c2511a3 --- /dev/null +++ b/mindspore/core/utils/profile.cc @@ -0,0 +1,365 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/profile.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace { +constexpr size_t TIME_INFO_PREFIX_NUM_LEN = 4; +const char KEY_PROF_TOTAL[] = "__total__"; + +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = ""); + +void PrintTimeInfoMap(std::ostringstream &oss, const TimeInfoMap &dict, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = "") { + size_t count = 0; + for (const auto &iter : dict) { + count++; + if (iter.second == nullptr) { + continue; + } + // indent by multiples of 4 spaces. + if (iter.first.size() < TIME_INFO_PREFIX_NUM_LEN) { + MS_LOG(EXCEPTION) << "In TimeInfoMap, the " << count << "th string key is " << iter.first + << ", but the length is less than " << TIME_INFO_PREFIX_NUM_LEN; + } + auto name = iter.first.substr(TIME_INFO_PREFIX_NUM_LEN); + oss << std::setw(indent * 4) << "" + << "[" << name << "]: " << iter.second->time_; + if (iter.second->dict_ != nullptr) { + oss << ", [" << iter.second->dict_->size() << "]"; + } + oss << "\n"; + + std::string newPrefix = prefix; + if (iter.first.find("Cycle ") == std::string::npos) { + newPrefix = prefix.empty() ? iter.first : prefix + "." + iter.first; + } + PrintProfile(oss, *iter.second, indent + 1, sums, newPrefix); + if (iter.second->dict_ == nullptr) { + (*sums)[newPrefix] += iter.second->time_; + } + } +} + +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent, std::map *sums, + const std::string &prefix) { + bool need_free = false; + if (sums == nullptr) { + sums = new (std::nothrow) std::map(); + if (sums == nullptr) { + MS_LOG(ERROR) << "memory allocation failed"; + return; + } + need_free = true; + } + + // indent by multiples of 4 spaces. + if (indent == 0) { + oss << "TotalTime = " << time_info.time_; + if (time_info.dict_ != nullptr) { + oss << ", [" << time_info.dict_->size() << "]"; + } + oss << "\n"; + } + + if (time_info.dict_ != nullptr) { + PrintTimeInfoMap(oss, *time_info.dict_, indent, sums, prefix); + } + + // print time percentage info + if (need_free) { + double total = 0.0; + for (auto iter = sums->begin(); iter != sums->end(); ++iter) { + total += iter->second; + } + oss << "Sums\n"; + if (total >= 0.0 + DBL_EPSILON) { + for (auto &iter : *sums) { + std::string name = iter.first; + name.erase(0, TIME_INFO_PREFIX_NUM_LEN); + std::size_t pos = 0; + while ((pos = name.find('.', pos)) != std::string::npos) { + pos++; + name.erase(pos, TIME_INFO_PREFIX_NUM_LEN); + } + oss << " " << std::left << std::setw(36) << name << " : " << std::right << std::setw(12) << std::fixed + << std::setprecision(6) << iter.second << "s : " << std::right << std::setw(5) << std::fixed + << std::setprecision(2) << iter.second / total * 100 << "%\n"; + } + } + delete sums; + } +} +} // namespace + +double GetTime(void) { + struct timeval tv = {0, 0}; + (void)gettimeofday(&tv, nullptr); + return tv.tv_sec + tv.tv_usec * 1.0e-6; +} + +TimeInfo::~TimeInfo() { + if (dict_ == nullptr) { + return; + } + for (auto iter = dict_->begin(); iter != dict_->end(); ++iter) { + delete iter->second; + iter->second = nullptr; + } + delete dict_; + dict_ = nullptr; +} + +ProfileBase::ProfileBase() : context_("", this) { + ctx_ptr_ = &context_; + context_.parent_ = nullptr; +} + +ProfileBase::~ProfileBase() { + context_.parent_ = nullptr; + if (context_.time_info_ != nullptr) { + delete context_.time_info_; + context_.time_info_ = nullptr; + } + ctx_ptr_ = nullptr; +} + +void Profile::Print(void) { + if (ctx_ptr_ == nullptr || ctx_ptr_->time_info_ == nullptr) { + return; + } + std::ostringstream oss; + PrintProfile(oss, *ctx_ptr_->time_info_); + std::string text = oss.str(); + // here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace + (void)printf("%s", text.c_str()); + (void)fflush(stdout); +} + +// Start a step in the current context with the given name. +// Nomes must be unique otherwise the previous record will be overwritten. +ProfContext *Profile::Step(const std::string &name) { + ctx_ptr_ = new (std::nothrow) ProfContext(name, this); + if (ctx_ptr_ == nullptr) { + MS_LOG(ERROR) << "memory allocation failed"; + return nullptr; + } + return ctx_ptr_; +} + +// Creates subcontext for a repeated action. +// Count should be monotonically increasing. +ProfContext *Profile::Lap(int count) { + std::ostringstream oss; + oss << "Cycle " << count; + ctx_ptr_ = new (std::nothrow) ProfContext(oss.str(), this); + if (ctx_ptr_ == nullptr) { + MS_LOG(ERROR) << "memory allocation failed"; + return nullptr; + } + return ctx_ptr_; +} + +void Profile::Pop(void) noexcept { + if (ctx_ptr_ == nullptr) { + return; + } + ctx_ptr_ = ctx_ptr_->parent_; +} + +ProfContext::ProfContext(const std::string &name, ProfileBase *const prof) : name_(name), prof_(prof) { + // Initialize a subcontext. + time_info_ = nullptr; + if (prof == nullptr || IsTopContext()) { + parent_ = nullptr; + } else { + parent_ = prof->ctx_ptr_; + } +} + +ProfContext::~ProfContext() { + // top level context + if (parent_ == nullptr || IsTopContext()) { + if (time_info_ != nullptr) { + delete time_info_; + } + } else { + parent_->Insert(name_, time_info_); + if (prof_ != nullptr) { + prof_->Pop(); + } + } + + time_info_ = nullptr; + prof_ = nullptr; + parent_ = nullptr; +} + +void ProfContext::SetTime(double time) noexcept { + if (time_info_ == nullptr) { + time_info_ = new (std::nothrow) TimeInfo(time); + if (time_info_ == nullptr) { + MS_LOG(ERROR) << "memory allocation failed"; + return; + } + } + time_info_->time_ = time; +} + +void ProfContext::Insert(const std::string &name, const TimeInfo *time) noexcept { + if (time_info_ == nullptr) { + time_info_ = new (std::nothrow) TimeInfo(); + if (time_info_ == nullptr) { + MS_LOG(ERROR) << "memory allocation failed"; + delete time; + time = nullptr; + return; + } + } + + if (time_info_->dict_ == nullptr) { + time_info_->dict_ = new (std::nothrow) TimeInfoMap(); + if (time_info_->dict_ == nullptr) { + MS_LOG(ERROR) << "memory allocation failed"; + delete time; + time = nullptr; + delete time_info_; + time_info_ = nullptr; + return; + } + } + + std::stringstream ss; + ss << std::setw(TIME_INFO_PREFIX_NUM_LEN) << std::setfill('0') << time_info_->actionNum_; + std::string sorted_name(ss.str() + name); + time_info_->actionNum_++; + auto iter = time_info_->dict_->find(sorted_name); + // if contains item with same name, delete it + if (iter != time_info_->dict_->end()) { + delete iter->second; + iter->second = nullptr; + (void)time_info_->dict_->erase(iter); + } + (*time_info_->dict_)[sorted_name] = time; +} + +bool ProfContext::IsTopContext() const noexcept { return (prof_ != nullptr) && (this == &prof_->context_); } + +ProfTransaction::ProfTransaction(const ProfileBase *prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } + +ProfTransaction::~ProfTransaction() { + if (ctx_ != nullptr && !ctx_->IsTopContext()) { + delete ctx_; + } + ctx_ = nullptr; +} + +void DumpTime::Record(const std::string &step_name, const double time, const bool is_start) { + file_ss_ << " {" << std::endl; + file_ss_ << " \"name\": " + << "\"" << step_name << "\"," << std::endl; + file_ss_ << " \"cat\": " + << "\"FUNCTION\"," << std::endl; + if (is_start) { + file_ss_ << " \"ph\": " + << "\"B\"," << std::endl; + } else { + file_ss_ << " \"ph\": " + << "\"E\"," << std::endl; + } + file_ss_ << " \"ts\": " << std::setprecision(16) << time * 1000000 << "," << std::endl; + file_ss_ << " \"pid\": " + << "1" << std::endl; + file_ss_ << " }" << std::endl; + file_ss_ << " ," << std::endl; +} + +void DumpTime::Save() { + try { + file_out_.open(file_path_, std::ios::trunc | std::ios::out); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Cannot open file in " << (file_path_); + } + file_out_ << "{\n"; + file_out_ << " \"traceEvents\": [" << std::endl; + file_ss_ >> file_out_.rdbuf(); + (void)file_out_.seekp(-7, std::ios::end); + file_out_ << " ]" << std::endl << " ,\n"; + file_out_ << " \"displayTimeUnit\": \"ms\"" << std::endl; + file_out_ << "}"; + file_out_.close(); +} + +struct TimeInfoGroup { + double total_time = 0.0; + int total_count = 0; + std::list::const_iterator> items; +}; + +static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, const std::string &prefix) { + oss << "------[" << prefix << "] " << std::setw(10) << std::fixed << std::setprecision(6) << group.total_time + << std::setw(6) << group.total_count << "\n"; + for (const auto &iter : group.items) { + oss << std::setw(5) << std::fixed << std::setprecision(2) << iter->second.time_ / group.total_time * 100 + << "% : " << std::setw(12) << std::fixed << std::setprecision(6) << iter->second.time_ << "s : " << std::setw(6) + << iter->second.count_ << ": " << iter->first << "\n"; + } +} + +void MsProfile::Print() { + GetProfile()->Print(); + std::vector items = {"substitution.", "renormalize.", "replace.", "match.", + "func_graph_cloner_run.", "meta_graph.", "manager.", "pynative"}; + std::vector groups(items.size() + 1); + const auto &stat = GetSingleton().time_stat_; + // group all time infos + for (auto iter = stat.cbegin(); iter != stat.cend(); ++iter) { + auto matched_idx = items.size(); + for (size_t i = 0; i < items.size(); ++i) { + if (iter->first.find(items[i]) != std::string::npos) { + matched_idx = i; + break; + } + } + groups[matched_idx].total_time += iter->second.time_; + groups[matched_idx].total_count += iter->second.count_; + groups[matched_idx].items.push_back(iter); + } + std::ostringstream oss; + for (size_t i = 0; i < groups.size(); ++i) { + std::string prefix = (i < items.size() ? items[i] : std::string("others.")); + PrintTimeStat(oss, groups[i], prefix); + } + std::string text = oss.str(); + // here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace + (void)printf("\nTime group info:\n%s", text.c_str()); + (void)fflush(stdout); +} + +} // namespace mindspore diff --git a/mindspore/core/utils/profile.h b/mindspore/core/utils/profile.h new file mode 100644 index 0000000000..2e236b7f5e --- /dev/null +++ b/mindspore/core/utils/profile.h @@ -0,0 +1,233 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_PROFILE_H_ +#define MINDSPORE_CCSRC_UTILS_PROFILE_H_ + +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +struct TimeInfo; +using TimeInfoMap = std::map; + +extern double GetTime(); + +class ProfileBase; + +struct TimeInfo { + explicit TimeInfo(double time = -1.0) : time_(time), dict_(nullptr), actionNum_(0) {} + TimeInfo(const TimeInfo &) = delete; + ~TimeInfo(); + + double time_; + TimeInfoMap *dict_; + size_t actionNum_; +}; + +// Utility class for Profile. +class ProfContext { + friend class Profile; + friend class ProfileBase; + friend class ProfTransaction; + + public: + ProfContext(const std::string &name, ProfileBase *prof); + ~ProfContext(); + + ProfContext(const ProfContext &) = delete; + ProfContext &operator=(const ProfContext &) = delete; + + void SetTime(double time) noexcept; + void Insert(const std::string &name, const TimeInfo *time) noexcept; + bool IsTopContext() const noexcept; + + private: + std::string name_; + ProfileBase *prof_; + ProfContext *parent_; + TimeInfo *time_info_; +}; + +class ProfileBase { + friend class ProfContext; + friend class ProfTransaction; + + public: + ProfileBase(); + virtual ~ProfileBase(); + + virtual void Print(void) {} + virtual ProfContext *Step(const std::string &) { return nullptr; } + virtual ProfContext *Lap(int) { return nullptr; } + virtual void Pop(void) {} + + // top level profile context + ProfContext context_; + // profile context pointer, act as a stack pointer + ProfContext *ctx_ptr_ = nullptr; +}; + +class Profile : public ProfileBase { + public: + Profile() = default; + ~Profile() override = default; + Profile(const Profile &) = delete; + Profile &operator=(const Profile &) = delete; + + void Print(void) override; + ProfContext *Step(const std::string &name) override; + ProfContext *Lap(int count) override; + void Pop(void) noexcept override; +}; + +class ProfTransaction { + public: + explicit ProfTransaction(const ProfileBase *prof); + explicit ProfTransaction(ProfContext *const ctx) : ctx_(ctx) {} + ProfTransaction(const ProfTransaction &) = delete; + ~ProfTransaction(); + + template + void operator-(const Function &func) { + double start_time = GetTime(); + func(); + double end_time = GetTime(); + if (ctx_ != nullptr) { + ctx_->SetTime(end_time - start_time); + } + } + + private: + ProfContext *ctx_ = nullptr; +}; + +class NoProfTransaction { + public: + explicit NoProfTransaction(ProfileBase *prof) {} + explicit NoProfTransaction(ProfContext *ctx) {} + ~NoProfTransaction() = default; + + template + void operator-(const Function &func) { + func(); + } +}; + +class DumpTime { + public: + ~DumpTime() { + try { + Save(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Cannot save file by profile::DumpTime::save"; + } catch (...) { + MS_LOG(ERROR) << "Uncaught exception"; + } + } + DumpTime(const DumpTime &) = delete; + DumpTime &operator=(const DumpTime &) = delete; + static DumpTime &GetInstance() { + static DumpTime instance; + return instance; + } + void set_file_path(const std::string &save_path) { file_path_ = save_path; } + void Record(const std::string &name, const double time, const bool is_start); + void Save(); + + private: + DumpTime() = default; + std::stringstream file_ss_; + std::ofstream file_out_; + std::string file_path_ = "./timeline.json"; +}; + +struct TimeStat { + TimeStat() { + time_ = 0.0; + count_ = 0; + } + ~TimeStat() = default; + + void operator+=(double t) { + time_ += t; + count_ += 1; + } + + TimeStat operator+(double t) { + TimeStat ts = *this; + ts += t; + return ts; + } + + double time_; + int count_; +}; + +class MsProfile { + public: + ~MsProfile() { Clear(); } + + static void Reset() { GetSingleton().Clear(); } + + static ProfileBase *GetProfile() { + MsProfile &ms_prof = GetSingleton(); + if (ms_prof.profile_ == nullptr) { +#ifdef ENABLE_PROFILE + ms_prof.profile_ = new Profile(); +#else + ms_prof.profile_ = new ProfileBase(); +#endif + } + return ms_prof.profile_; + } + static void StatTime(const std::string &id, double time) { GetSingleton().time_stat_[id] += time; } + + static void Print(); + + private: + MsProfile() = default; + + static MsProfile &GetSingleton() { + static MsProfile profile; + return profile; + } + + void Clear() { + time_stat_.clear(); + if (profile_ != nullptr) { + delete profile_; + profile_ = nullptr; + } + } + + std::map time_stat_; // record time and count info from some activity + ProfileBase *profile_ = nullptr; // record hierarchical profile info +}; + +} // namespace mindspore + +#ifdef ENABLE_PROFILE +#define WITH(x) ProfTransaction(x) - +#else +#define WITH(x) NoProfTransaction(x) - +#endif + +#endif // MINDSPORE_CCSRC_UTILS_PROFILE_H_ diff --git a/mindspore/core/utils/signal.h b/mindspore/core/utils/signal.h new file mode 100644 index 0000000000..a47e54b117 --- /dev/null +++ b/mindspore/core/utils/signal.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_SIGNAL_H_ +#define MINDSPORE_CORE_UTILS_SIGNAL_H_ + +#include +#include +#include +#include + +namespace mindspore { +template +std::function bind_member(Type *instance, Return (Type::*method)(Args...)) { + return [=](Args &&... args) -> Return { return (instance->*method)(std::forward(args)...); }; +} + +template +class Slot { + public: + explicit Slot(const std::function &callback) : callback(callback) {} + + ~Slot() {} + + std::function callback = nullptr; +}; + +template +class Signal { + public: + template + void operator()(Args &&... args) { + for (auto &slot : slots_) { + if (slot->callback != nullptr) { + slot->callback(std::forward(args)...); + } + } + } + + void add_slot(const std::function &func) { + auto slot = std::make_shared>(func); + slots_.push_back(slot); + } + + // signal connect to a class member func + template + void connect(InstanceType instance, MemberFuncType func) { + add_slot(bind_member(instance, func)); + } + + private: + std::vector>> slots_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_EVENT_H_ diff --git a/mindspore/ccsrc/utils/symbolic.cc b/mindspore/core/utils/symbolic.cc similarity index 100% rename from mindspore/ccsrc/utils/symbolic.cc rename to mindspore/core/utils/symbolic.cc diff --git a/mindspore/core/utils/symbolic.h b/mindspore/core/utils/symbolic.h new file mode 100644 index 0000000000..903d25794e --- /dev/null +++ b/mindspore/core/utils/symbolic.h @@ -0,0 +1,174 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_SYMBOLIC_H_ +#define MINDSPORE_CORE_UTILS_SYMBOLIC_H_ + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "abstract/abstract_value.h" + +namespace mindspore { + +class SymbolicKeyInstance : public Value { + public: + SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) + : node_(node), abstract_(abstract) {} + ~SymbolicKeyInstance() override = default; + MS_DECLARE_PARENT(SymbolicKeyInstance, Value); + AnfNodePtr node() const { return node_; } + abstract::AbstractBasePtr abstract() const { return abstract_; } + bool operator==(const SymbolicKeyInstance &other) const { + return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); + } + + std::size_t hash() const override { return std::hash{}(node_); } + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { + if (inst == nullptr) { + os << "[Key][" + << "Invalid symbolic key instance" + << "]"; + } else { + os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString(); + } + return os; + } + std::string ToString() const override { + return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); + } + bool operator==(const Value &other) const override { + if (other.isa()) { + auto other_ = static_cast(other); + return *this == other_; + } else { + return false; + } + } + abstract::AbstractBasePtr ToAbstract() override { + return std::make_shared(shared_from_base(), + std::make_shared()); + } + + private: + AnfNodePtr node_; + abstract::AbstractBasePtr abstract_; +}; + +using SymbolicKeyInstancePtr = std::shared_ptr; + +struct SymbolicKeyInstanceHash { + std::size_t operator()(const SymbolicKeyInstancePtr s) const { + if (s == nullptr) { + return 0; + } + return s->abstract()->hash(); + } +}; + +struct SymbolicKeyInstanceEqual { + bool operator()(const SymbolicKeyInstancePtr lhs, const SymbolicKeyInstancePtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + MS_EXCEPTION_IF_NULL(lhs->node()); + MS_EXCEPTION_IF_NULL(rhs->node()); + MS_EXCEPTION_IF_NULL(lhs->abstract()); + MS_EXCEPTION_IF_NULL(rhs->abstract()); + return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract()); + } +}; + +using EnvInstanceContentsMap = + std::unordered_map; + +// Environment mapping keys to values. +// Keys are SymbolicKeyInstances, which represent nodes in the graph along +// with inferred properties. +class EnvInstance : public Value { + public: + friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); + + explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} + ~EnvInstance() override = default; + MS_DECLARE_PARENT(EnvInstance, Value); + abstract::AbstractBasePtr ToAbstract() override { + return std::make_shared(shared_from_base(), std::make_shared()); + } + bool operator==(const EnvInstance &other) const; + bool operator==(const Value &other) const override; + EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} + EnvInstance(EnvInstance &&v) = default; + EnvInstance &operator=(EnvInstance &&src) noexcept { + if (&src != this) { + contents_ = src.contents_; + } + return *this; + }; + + // Get the sensitivity list for the given key + const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { + auto iterator = contents_.find(key); + if (iterator != contents_.end()) { + return iterator->second; + } + return def; + } + + // Set a value for the given key. + EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { + EnvInstance rval(contents_); + rval.contents_[key] = value; + return rval; + } + + // Add two EnvInstances. + EnvInstance Add(const EnvInstance &other) const { + EnvInstance rval(contents_); + for (auto iter_other : other.contents_) { + auto item_self = contents_.find(iter_other.first); + if (item_self != contents_.end()) { + MS_LOG(DEBUG) << "Need to use add"; + } else { + rval.contents_[iter_other.first] = iter_other.second; + } + } + return rval; + } + + size_t Len() const { return contents_.size(); } + std::size_t hash() const override { + // deterministic characteristic of member variables. + return Len(); + } + + private: + EnvInstanceContentsMap contents_; +}; + +using EnvInstancePtr = std::shared_ptr; + +extern std::shared_ptr newenv; + +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_SYMBOLIC_H_ diff --git a/mindspore/core/utils/trace_base.cc b/mindspore/core/utils/trace_base.cc new file mode 100644 index 0000000000..aa9fde6f5b --- /dev/null +++ b/mindspore/core/utils/trace_base.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/trace_base.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/graph_utils.h" + +namespace mindspore { +// namespace to support debug trace infomation +namespace trace { +std::vector GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { + std::vector debug_with_loc_vec; + while (debug_info != nullptr) { + if (debug_info->location() != nullptr) { + debug_with_loc_vec.push_back(debug_info); + } + if (debug_info->trace_info() != nullptr) { + debug_info = debug_info->trace_info()->debug_info(); + } else { + break; + } + } + return debug_with_loc_vec; +} + +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { + auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); + if (debug_with_loc_vec.size() > 0) { + return debug_with_loc_vec[0]; + } else { + return info; + } +} + +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { + if (info == nullptr) { + return ""; + } + auto src_info = GetSourceCodeDebugInfo(info); + if (src_info->location() != nullptr) { + return src_info->location()->ToString(tip); + } + return ""; +} + +// a trace info identifies a node transform, so we can trace the node transform through +// a link of trace info and debug info +std::string GetInfoWithAction(const std::vector &info_vec, SourceLineTip tip) { + if (info_vec.size() < 1) { + return ""; + } + if (info_vec.size() == 1) { + return info_vec[0]->location()->ToString(tip); + } + std::string traced_info = info_vec[0]->location()->ToString(tip); + for (size_t i = 1; i < info_vec.size(); i++) { + auto action_name = info_vec[i - 1]->trace_info()->GetActionBetweenNode(info_vec[i]); + if (action_name == "") { + break; + } + traced_info = traced_info + action_name + info_vec[i]->location()->ToString(tip); + } + return traced_info; +} + +std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { + if (info == nullptr) { + return ""; + } + auto info_vec = GetSourceCodeDebugInfoVec(info); + if (info_vec.size() == 0) { + return ""; + } else if (info_vec.size() == 1) { + return info_vec[0]->location()->ToString(tip); + } else if (info_vec.size() > 1) { + return GetInfoWithAction(info_vec, tip); + } + return ""; +} + +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { + std::ostringstream oss; + if (info == nullptr) { + return ""; + } + + auto debug_info = GetTracedDebugInfo(info, tip); + if (tip == kSourceLineTipDiscard) { + std::replace(debug_info.begin(), debug_info.end(), '\r', '/'); + std::replace(debug_info.begin(), debug_info.end(), '\n', '/'); + } + oss << prefix << debug_info; + return oss.str(); +} +} // namespace trace +} // namespace mindspore diff --git a/mindspore/core/utils/trace_base.h b/mindspore/core/utils/trace_base.h new file mode 100644 index 0000000000..0807789faa --- /dev/null +++ b/mindspore/core/utils/trace_base.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_UTILS_TRACE_BASE_H_ +#define MINDSPORE_CORE_UTILS_TRACE_BASE_H_ + +#include +#include +#include +#include +#include +#include + +#include "utils/info.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "utils/any.h" + +namespace mindspore { +namespace trace { +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, + SourceLineTip tip = kSourceLineTipNextLine); +} // namespace trace +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_TRACE_BASE_H_ diff --git a/mindspore/core/utils/trace_info.cc b/mindspore/core/utils/trace_info.cc new file mode 100644 index 0000000000..26bcb8e7ce --- /dev/null +++ b/mindspore/core/utils/trace_info.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/trace_info.h" +#include +#include +#include +#include "ir/anf.h" + +namespace mindspore { +std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { + if (info == nullptr) { + return ""; + } + std::string act_name = action_name(); + if (debug_info() == nullptr) { + MS_LOG(EXCEPTION) << "Traced debug info is null"; + } + if (debug_info() == info) { + return act_name; + } else if (debug_info()->trace_info() != nullptr) { + return act_name + debug_info()->trace_info()->GetActionBetweenNode(info); + } + return "Not in the traced info"; +} +} // namespace mindspore diff --git a/mindspore/core/utils/trace_info.h b/mindspore/core/utils/trace_info.h new file mode 100644 index 0000000000..fea2cb3ea8 --- /dev/null +++ b/mindspore/core/utils/trace_info.h @@ -0,0 +1,417 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_TRACE_INFO_H_ +#define MINDSPORE_CORE_UTILS_TRACE_INFO_H_ + +#include +#include +#include +#include +#include +#include + +#include "base/base.h" + +namespace mindspore { +class TraceInfo; +using TraceInfoPtr = std::shared_ptr; +class Location; +using LocationPtr = std::shared_ptr; +class DebugInfo; +using DebugInfoPtr = std::shared_ptr; + +// namespace to support intermediate representation definition +class TraceInfo : public Base { + public: + TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { + symbol_ = symbol; + full_name_ = full_name; + name_ = full_name_; + debug_info_ = info; + } + TraceInfo(const TraceInfo &info) + : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} + ~TraceInfo() override = default; + MS_DECLARE_PARENT(TraceInfo, Base); + virtual std::string name() { return name_; } + virtual std::string symbol() { return symbol_; } + virtual std::string full_name() { return full_name_; } + virtual TraceInfoPtr clone() { return shared_from_base(); } + virtual std::string action_name() { return ""; } + virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); + void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } + DebugInfoPtr debug_info() { return debug_info_; } + DebugInfoPtr DebugInfoHasLoc(); + std::vector> GetSourceCodeDebugInfo(); + + protected: + DebugInfoPtr debug_info_; + std::string symbol_; + std::string full_name_; + std::string name_; +}; + +class TracePhi : public TraceInfo { + public: + explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} + MS_DECLARE_PARENT(TracePhi, TraceInfo); + ~TracePhi() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceIfStmtTrueBranch : public TraceInfo { + public: + TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; + explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} + MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); + ~TraceIfStmtTrueBranch() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceIfStmtFalseBranch : public TraceInfo { + public: + TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; + explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} + MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); + ~TraceIfStmtFalseBranch() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceIfStmtAfterBranch : public TraceInfo { + public: + explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} + MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); + ~TraceIfStmtAfterBranch() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceIfExpTrueBranch : public TraceInfo { + public: + explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} + MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); + ~TraceIfExpTrueBranch() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceIfExpFalseBranch : public TraceInfo { + public: + explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} + MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); + ~TraceIfExpFalseBranch() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceCopy : public TraceInfo { + public: + TraceCopy() : TraceInfo(nullptr, "copy", "") {} + explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} + MS_DECLARE_PARENT(TraceCopy, TraceInfo); + ~TraceCopy() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceIterator : public TraceInfo { + public: + explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} + MS_DECLARE_PARENT(TraceIterator, TraceInfo); + ~TraceIterator() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceWhileHeader : public TraceInfo { + public: + explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} + MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); + ~TraceWhileHeader() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceWhileBody : public TraceInfo { + public: + explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} + MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); + ~TraceWhileBody() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceWhileAfter : public TraceInfo { + public: + explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} + MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); + ~TraceWhileAfter() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceForHeader : public TraceInfo { + public: + explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {} + MS_DECLARE_PARENT(TraceForHeader, TraceInfo); + ~TraceForHeader() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceForBody : public TraceInfo { + public: + explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} + MS_DECLARE_PARENT(TraceForBody, TraceInfo); + ~TraceForBody() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceForAfter : public TraceInfo { + public: + explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} + MS_DECLARE_PARENT(TraceForAfter, TraceInfo); + ~TraceForAfter() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceLoopEnd : public TraceInfo { + public: + explicit TraceLoopEnd(const DebugInfoPtr &info) : TraceInfo(info, "loop_end", "↓↓") {} + MS_DECLARE_PARENT(TraceLoopEnd, TraceInfo); + ~TraceLoopEnd() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceEquiv : public TraceInfo { + public: + explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} + MS_DECLARE_PARENT(TraceEquiv, TraceInfo); + ~TraceEquiv() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceGradFpropApp : public TraceInfo { + public: + TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} + explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} + MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); + ~TraceGradFpropApp() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceGradBpropApp : public TraceInfo { + public: + TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} + explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} + MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); + ~TraceGradBpropApp() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceGradFprop : public TraceInfo { + public: + TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} + explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} + MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); + ~TraceGradFprop() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceGradBprop : public TraceInfo { + public: + TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} + explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} + MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); + ~TraceGradBprop() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceGradSens : public TraceInfo { + public: + TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} + explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {} + MS_DECLARE_PARENT(TraceGradSens, TraceInfo); + ~TraceGradSens() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceSpecialize : public TraceInfo { + public: + explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } + MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); + std::string name() override { return full_name_ + counter_; } + std::string symbol() override { return counter_ + "_"; } + std::string full_name() override { return full_name_ + counter_ + "_"; } + ~TraceSpecialize() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } + std::string counter_; +}; + +class TraceGradOperation : public TraceInfo { + public: + explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} + MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); + ~TraceGradOperation() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceForceBool : public TraceInfo { + public: + explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} + MS_DECLARE_PARENT(TraceForceBool, TraceInfo); + ~TraceForceBool() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceForceWhileCond : public TraceInfo { + public: + explicit TraceForceWhileCond(const DebugInfoPtr &info) : TraceInfo(info, "force_while_cond", "") {} + MS_DECLARE_PARENT(TraceForceWhileCond, TraceInfo); + ~TraceForceWhileCond() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceExpandJ : public TraceInfo { + public: + explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} + MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); + ~TraceExpandJ() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceGenMetaFuncGraph : public TraceInfo { + public: + explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} + MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); + ~TraceGenMetaFuncGraph() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceEvaluatorGenGraph : public TraceInfo { + public: + explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} + MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); + ~TraceEvaluatorGenGraph() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceResolve : public TraceInfo { + public: + explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} + MS_DECLARE_PARENT(TraceResolve, TraceInfo); + ~TraceResolve() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceTransform : public TraceInfo { + public: + TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } + explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { + transform_name_ = transform_name; + } + + std::string full_name() override { return full_name_ + transform_name_; } + MS_DECLARE_PARENT(TraceTransform, TraceInfo); + std::string symbol() override { + if (transform_name_.empty()) { + return ""; + } + return transform_name_ + "_"; + } + + ~TraceTransform() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } + std::string transform_name_; +}; + +class TraceGenerateVarArg : public TraceInfo { + public: + explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} + MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); + ~TraceGenerateVarArg() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceGenerateKwArg : public TraceInfo { + public: + explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} + MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); + ~TraceGenerateKwArg() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceTrasformK : public TraceInfo { + public: + explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} + MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); + ~TraceTrasformK() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TracePartialTransform : public TraceInfo { + public: + explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} + MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); + ~TracePartialTransform() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; + +class TraceGetEnv : public TraceInfo { + public: + explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} + MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); + ~TraceGetEnv() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceDoSignature : public TraceInfo { + public: + explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} + MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); + ~TraceDoSignature() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + +class TraceCombileLikeGraphs : public TraceInfo { + public: + TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} + explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {} + MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); + ~TraceCombileLikeGraphs() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_ diff --git a/mindspore/core/utils/visible.h b/mindspore/core/utils/visible.h new file mode 100644 index 0000000000..b206876d3e --- /dev/null +++ b/mindspore/core/utils/visible.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_VISIBLE_H_ +#define MINDSPORE_CORE_UTILS_VISIBLE_H_ + +namespace mindspore { +// refer to https://gcc.gnu.org/wiki/Visibility +#if defined _WIN32 || defined __CYGWIN__ +#ifdef BUILDING_DLL +#ifdef __GNUC__ +#define MS_EXPORT __attribute__((dllexport)) +#else +#define MS_EXPORT __declspec(dllexport) // Note: actually gcc seems to also supports this syntax. +#endif +#else +#ifdef __GNUC__ +#define MS_EXPORT __attribute__((dllimport)) +#else +#define MS_EXPORT __declspec(dllimport) // Note: actually gcc seems to also supports this syntax. +#endif +#endif +#define MS_LOCAL +#else +#define MS_EXPORT __attribute__((visibility("default"))) +#define MS_LOCAL __attribute__((visibility("hidden"))) +#endif + +} // namespace mindspore + +#endif // MINDSPORE_CORE_UTILS_VISIBLE_H_ diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index b2d26b41ee..eb9444a05a 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -21,7 +21,7 @@ can also create samplers with this module to sample data. from .core import config from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \ GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\ - TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset + TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler, Sampler from .engine.cache_client import DatasetCache @@ -31,5 +31,5 @@ from .engine.graphdata import GraphData __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset", - "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler", + "CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"] diff --git a/mindspore/dataset/core/datatypes.py b/mindspore/dataset/core/datatypes.py index 292af67e8a..62f59c6f15 100644 --- a/mindspore/dataset/core/datatypes.py +++ b/mindspore/dataset/core/datatypes.py @@ -27,7 +27,7 @@ def mstype_to_detype(type_): Get de data type corresponding to mindspore dtype. Args: - type_ (:class:`mindspore.dtype`): MindSpore's dtype. + type_ (mindspore.dtype): MindSpore's dtype. Returns: The data type of de. @@ -57,7 +57,7 @@ def mstypelist_to_detypelist(type_list): Get list[de type] corresponding to list[mindspore.dtype]. Args: - type_list (:list[mindspore.dtype]): a list of MindSpore's dtype. + type_list (list[mindspore.dtype]): a list of MindSpore's dtype. Returns: The list of de data type. diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index 8806babd63..d9f3267244 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -19,6 +19,8 @@ import inspect from multiprocessing import cpu_count import os import numpy as np + +import mindspore._c_dataengine as cde from ..engine import samplers # POS_INT_MIN is used to limit values from starting from 0 @@ -153,8 +155,8 @@ def parse_user_args(method, *args, **kwargs): Args: method (method): a callable function. - *args: user passed args. - **kwargs: user passed kwargs. + args: user passed args. + kwargs: user passed kwargs. Returns: user_filled_args (list): values of what the user passed in for the arguments. @@ -179,16 +181,18 @@ def type_check_list(args, types, arg_names): Check the type of each parameter in the list. Args: - args (list, tuple): a list or tuple of any variable. + args (Union[list, tuple]): a list or tuple of any variable. types (tuple): tuple of all valid types for arg. - arg_names (list, tuple of str): the names of args. + arg_names (Union[list, tuple of str]): the names of args. Returns: Exception: when the type is not correct, otherwise nothing. """ type_check(args, (list, tuple,), arg_names) - if len(args) != len(arg_names): + if len(args) != len(arg_names) and not isinstance(arg_names, str): raise ValueError("List of arguments is not the same length as argument_names.") + if isinstance(arg_names, str): + arg_names = ["{0}[{1}]".format(arg_names, i) for i in range(len(args))] for arg, arg_name in zip(args, arg_names): type_check(arg, types, arg_name) @@ -198,7 +202,7 @@ def type_check(arg, types, arg_name): Check the type of the parameter. Args: - arg : any variable. + arg (Any) : any variable. types (tuple): tuple of all valid types for arg. arg_name (str): the name of arg. @@ -296,7 +300,6 @@ def check_padding_options(param_dict): """ columns_list = param_dict.get('columns_list') - block_reader = param_dict.get('block_reader') padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') if padded_sample is not None: if num_padded is None: @@ -308,9 +311,6 @@ def check_padding_options(param_dict): for column in columns_list: if column not in padded_sample: raise ValueError("padded_sample cannot match columns_list.") - if block_reader: - raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") - if padded_sample is None and num_padded is not None: raise RuntimeError("num_padded is specified but padded_sample is not.") @@ -342,7 +342,7 @@ def check_gnn_list_or_ndarray(param, param_name): Check if the input parameter is list or numpy.ndarray. Args: - param (list, nd.ndarray): param. + param (Union[list, nd.ndarray]): param. param_name (str): param_name. Returns: @@ -358,3 +358,9 @@ def check_gnn_list_or_ndarray(param, param_name): if not param.dtype == np.int32: raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( param_name, param.dtype)) + + +def check_tensor_op(param, param_name): + """check whether param is a tensor op or a callable python function""" + if not isinstance(param, cde.TensorOp) and not callable(param): + raise TypeError("{0} is not a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name)) diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index b3624e1ca3..15eb66e54e 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -29,7 +29,7 @@ from .samplers import * from ..core import config __all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset", - "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", + "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index 800c0dab1d..d140a0cb55 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -18,21 +18,20 @@ import copy from mindspore._c_dataengine import CacheClient +from ..core.validator_helpers import type_check, check_uint32, check_uint64 + class DatasetCache: """ A client to interface with tensor caching service """ - def __init__(self, session_id=None, size=None, spilling=False): - if session_id is None: - raise RuntimeError("Session generation is not implemented yet. session id required") - self.size = size if size is not None else 0 - if size < 0: - raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size)) - if not isinstance(spilling, bool): - raise ValueError( - "spilling argument for cache should be a boolean value but got: spilling={}".format(spilling)) + def __init__(self, session_id=None, size=0, spilling=False): + check_uint32(session_id, "session_id") + check_uint64(size, "size") + type_check(spilling, (bool,), "spilling") + self.session_id = session_id + self.size = size self.spilling = spilling self.cache_client = CacheClient(session_id, size, spilling) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 846e7e0a56..5491dc4558 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -33,19 +33,20 @@ import copy import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ - MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo + MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo from mindspore._c_expression import typing from mindspore import log as logger from . import samplers -from .iterators import DictIterator, TupleIterator, DummyIterator +from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ - check_rename, check_numpyslicesdataset, \ + check_rename, check_numpyslicesdataset, check_device_send, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 + check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist +from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE try: context = import_module("mindspore.context") @@ -134,6 +135,7 @@ class Dataset: """ def __init__(self, num_parallel_workers=None): + # Note: children and parent are internal variables, not recommand for external using. self.children = [] self.parent = [] self.num_parallel_workers = num_parallel_workers @@ -186,13 +188,13 @@ class Dataset: except for maybe the last batch for each bucket. Args: - column_names (list of string): Columns passed to element_length_function. - bucket_boundaries (list of int): A list consisting of the upper boundaries + column_names (list[str]): Columns passed to element_length_function. + bucket_boundaries (list[int]): A list consisting of the upper boundaries of the buckets. Must be strictly increasing. If there are n boundaries, n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each 0 1: - self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) - else: - self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) - else: - if num_parallel_workers > 1: - self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers)) - else: - self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) - else: - try: - iter(source) - except TypeError: - # Use generator function if input callable - self.source = (lambda: _generator_fn(source, num_samples)) - else: - # Use iterator function if input is iterable - # Random accessible input is also iterable - self.source = (lambda: _iter_fn(source, num_samples)) + self.num_samples = num_samples if column_names is not None and not isinstance(column_names, list): column_names = [column_names] @@ -3228,12 +3268,39 @@ class GeneratorDataset(MappableDataset): memodict[id(self)] = new_op new_op.children = copy.deepcopy(self.children, memodict) new_op.parent = copy.deepcopy(self.parent, memodict) + new_op.ms_role = copy.deepcopy(self.ms_role, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.column_types = copy.deepcopy(self.column_types, memodict) new_op.column_names = copy.deepcopy(self.column_names, memodict) - - new_op.source = self.source - new_op.sampler = self.sampler + new_op.num_samples = copy.deepcopy(self.num_samples, memodict) + + new_op.sampler = copy.deepcopy(self.sampler) + if new_op.sampler is not None and hasattr(self.source, "__getitem__"): + if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, + samplers.RandomSampler, samplers.SubsetRandomSampler, + samplers.WeightedRandomSampler, samplers.Sampler)): + sampler_instance = new_op.sampler.create() + sampler_instance.set_num_rows(len(self.source)) + sampler_instance.initialize() + if new_op.num_parallel_workers > 1: + new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) + else: + new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) + else: + if new_op.num_parallel_workers > 1: + new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) + else: + new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) + else: + try: + iter(self.source) + except TypeError: + # Use generator function if input callable + new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples)) + else: + # Use iterator function if input is iterable + # Random accessible input is also iterable + new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples)) return new_op @@ -3249,9 +3316,9 @@ class TFRecordDataset(SourceDataset): A source dataset that reads and parses datasets stored on disk in TFData format. Args: - dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of - files. The list will be sorted in a lexicographical order. - schema (str or Schema, optional): Path to the json schema file or schema object (default=None). + dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a + pattern of files. The list will be sorted in a lexicographical order. + schema (Union[str, Schema], optional): Path to the json schema file or schema object (default=None). If the schema is not provided, the meta data from the TFData file is considered the schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) num_samples (int, optional): number of samples(rows) to read (default=None). @@ -3260,7 +3327,8 @@ class TFRecordDataset(SourceDataset): If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). - shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch + (default=Shuffle.GLOBAL). If shuffle is False, no shuffling will be performed; If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL Otherwise, there are two levels of shuffling: @@ -3275,7 +3343,8 @@ class TFRecordDataset(SourceDataset): argument should be specified only when num_shards is also specified. shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number of rows of each shard may be not equal. - cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Examples: >>> import mindspore.dataset as ds >>> import mindspore.common.dtype as mstype @@ -3840,13 +3909,14 @@ class RandomDataset(SourceDataset): Args: total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random) - schema (str or Schema, optional): Path to the json schema file or schema object (default=None). + schema (Union[str, Schema], optional): Path to the json schema file or schema object (default=None). If the schema is not provided, the random dataset generates a random schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) num_samples (int): number of samples to draw from the total. (default=None, which means all rows) num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). - cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. shuffle (bool, optional): Whether or not to perform shuffle on the dataset (default=None, expected order behavior shown in the table). num_shards (int, optional): Number of shards that the dataset should be divided @@ -4016,7 +4086,7 @@ class Schema: Parse the columns and add it to self. Args: - columns (dict or list[dict]): dataset attribution information, decoded from schema file. + columns (Union[dict, list[dict]]): dataset attribution information, decoded from schema file. - list[dict], 'name' and 'type' must be in keys, 'shape' optional. @@ -4107,13 +4177,11 @@ class VOCDataset(MappableDataset): """ A source dataset for reading and parsing VOC dataset. - The generated dataset has two columns : - task='Detection' : ['image', 'annotation']; - task='Segmentation' : ['image', 'target']. - The shape of both column 'image' and 'target' is [image_size] if decode flag is False, or [H, W, C] - otherwise. - The type of both tensor 'image' and 'target' is uint8. - The type of tensor 'annotation' is uint32. + The generated dataset has multi-columns : + + - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32], + ['difficult', dtype=uint32], ['truncate', dtype=uint32]]. + - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]]. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table below shows what input args are allowed and their expected behavior. @@ -4631,15 +4699,16 @@ class CLUEDataset(SourceDataset): } Args: - dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of - files. The list will be sorted in a lexicographical order. + dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for + a pattern of files. The list will be sorted in a lexicographical order. task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'. (default=AFQMC). usage (str, optional): Need train, test or eval data (default="train"). num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). - shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch + (default=Shuffle.GLOBAL). If shuffle is False, no shuffling will be performed; If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL Otherwise, there are two levels of shuffling: @@ -4839,18 +4908,122 @@ class CLUEDataset(SourceDataset): return False +class CSVDataset(SourceDataset): + """ + A source dataset that reads and parses CSV datasets. + + Args: + dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search + for a pattern of files. The list will be sorted in a lexicographical order. + field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=','). + column_defaults (list, optional): List of default values for the CSV field (default=None). Each item + in the list is either a valid type (float, int, or string). If this is not provided, treats all + columns as string type. + column_names (list[str], optional): List of column names of the dataset (default=None). If this + is not provided, infers the column_names from the first row of CSV file. + num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset). + num_parallel_workers (int, optional): number of workers to read the data + (default=None, number set in the config). + shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch + (default=Shuffle.GLOBAL). + If shuffle is False, no shuffling will be performed; + If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL + Otherwise, there are two levels of shuffling: + + - Shuffle.GLOBAL: Shuffle both the files and samples. + + - Shuffle.FILES: Shuffle files only. + + num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. + + Examples: + >>> import mindspore.dataset as ds + >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files + >>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4']) + """ + + @check_csvdataset + def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1, + num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.dataset_files = self._find_files(dataset_files) + self.dataset_files.sort() + self.field_delim = field_delim + self.column_defaults = column_defaults + self.column_names = column_names + self.num_samples = num_samples + + if not isinstance(shuffle, (bool, Shuffle)): + raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") + if not isinstance(shuffle, Shuffle): + if shuffle: + self.shuffle_level = Shuffle.GLOBAL + self.shuffle_files = True + else: + self.shuffle_level = None + self.shuffle_files = False + else: + self.shuffle_level = shuffle + self.shuffle_files = True + + self.num_shards = num_shards + self.shard_id = shard_id + + def get_args(self): + args = super().get_args() + args["dataset_files"] = self.dataset_files + args['field_delim'] = self.field_delim + args['column_defaults'] = self.column_defaults + args['column_names'] = self.column_names + args["num_samples"] = self.num_samples + if self.shuffle_files is not None: + args["shuffle_files"] = self.shuffle_files + args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) + args["shuffle"] = self.shuffle_level + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. + """ + if self._dataset_size is None: + num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples == -1: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size + + def is_shuffled(self): + return self.shuffle_files + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return False + + class TextFileDataset(SourceDataset): """ A source dataset that reads and parses datasets stored on disk in text format. The generated dataset has one columns ['text']. Args: - dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of - files. The list will be sorted in a lexicographical order. + dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a + pattern of files. The list will be sorted in a lexicographical order. num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). - shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch + (default=Shuffle.GLOBAL). If shuffle is False, no shuffling will be performed; If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL Otherwise, there are two levels of shuffling: @@ -5031,17 +5204,17 @@ class NumpySlicesDataset(GeneratorDataset): - not allowed Args: - data (list, tuple or dict) Input of Given data, supported data type includes list, tuple, dict and other numpy - format. Input data will be sliced in first dimension and generate many rows, large data is not recommend to - load in this way as data is loading into memory. + data (Union[list, tuple, dict]) Input of Given data, supported data type includes list, tuple, dict and other + numpy format. Input data will be sliced in first dimension and generate many rows, large data is not + recommend to load in this way as data is loading into memory. column_names (list[str], optional): List of column names of the dataset (default=None). If column_names not provided, when data is dict, column_names will be its key, otherwise it will be like column_1, column_2 ... num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images). num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. (default=None, expected order behavior shown in the table). - sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is - required (default=None, expected order behavior shown in the table). + sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible + input is required (default=None, expected order behavior shown in the table). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). When this argument is specified, 'num_samples' will not effect. Random accessible input is required. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only @@ -5082,8 +5255,8 @@ class BuildVocabDataset(DatasetOp): Args: vocab(Vocab): text.vocab object. - columns(str or list, optional): column names to get words from. It can be a list of column names (Default is - None, all columns are used, return error if any column isn't string). + columns(Union[str, list], optional): column names to get words from. It can be a list of column names (Default + is None, all columns are used, return error if any column isn't string). freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency can be None, which corresponds to 0/total_words separately (default=None, all words are included). @@ -5139,3 +5312,59 @@ class BuildVocabDataset(DatasetOp): new_op.special_first = copy.deepcopy(self.special_first) return new_op + + +class BuildSentencePieceVocabDataset(DatasetOp): + """ + Build a SentencePieceVocab from a dataset. + This function is not meant to be called directly by user. To build vocab, please use the function + text.SentencePieceVocab.from_dataset() + + Args: + vocab(SentencePieceVocab): text.SentencePieceVocab object. + col_names(list): The list of the col name. + vocab_size(int): Vocabulary size, the type of uint32_t. + charater_coverage(float): Amount of characters covered by the model, good defaults are: 0.9995 for languages + with rich character set like Japanse or Chinese and 1.0 for other languages with small character set. + model_type(SentencePieceModel): Model type.Choose from unigram (default), bpe, char, or word. + The input sentence must be pretokenized when using word type. + params(dict): A dictionary with no incoming parameters. + """ + + def __init__(self, input_dataset, vocab, col_names, vocab_size, character_coverage, model_type, params): + super().__init__() + self.vocab = vocab + self.col_names = col_names + self.vocab_size = vocab_size + self.children.append(input_dataset) + self.character_coverage = character_coverage + self.model_type = DE_C_INTER_SENTENCEPIECE_MODE[model_type] + self.params = params + input_dataset.parent.append(self) + + def get_args(self): + args = super().get_args() + args["vocab"] = self.vocab + args["col_names"] = self.col_names + args["vocab_size"] = self.vocab_size + args["character_coverage"] = self.character_coverage + args["model_type"] = self.model_type + args["params"] = self.params + return args + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + new_op.children = copy.deepcopy(self.children, memodict) + new_op.col_names = copy.deepcopy(self.col_names, memodict) + new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) + new_op.vocab_size = copy.deepcopy(self.vocab_size, memodict) + new_op.parent = copy.deepcopy(self.parent, memodict) + new_op.character_coverage = copy.deepcopy(self.character_coverage, memodict) + new_op.params = copy.deepcopy(self.params, memodict) + new_op.vocab = self.vocab + new_op.model_type = copy.deepcopy(self.model_type) + return new_op diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 81314b4373..8641761daa 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -91,7 +91,7 @@ class GraphData: Get nodes from the edges. Args: - edge_list (list or numpy.ndarray): The given list of edges. + edge_list (Union[list, numpy.ndarray]): The given list of edges. Returns: numpy.ndarray: array of nodes. @@ -107,7 +107,7 @@ class GraphData: Get `neighbor_type` neighbors of the nodes in `node_list`. Args: - node_list (list or numpy.ndarray): The given list of nodes. + node_list (Union[list, numpy.ndarray]): The given list of nodes. neighbor_type (int): Specify the type of neighbor. Returns: @@ -137,9 +137,9 @@ class GraphData: 2-hop samling result ...] Args: - node_list (list or numpy.ndarray): The given list of nodes. - neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop. - neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop. + node_list (Union[list, numpy.ndarray]): The given list of nodes. + neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop. + neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. Returns: numpy.ndarray: array of nodes. @@ -164,7 +164,7 @@ class GraphData: Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`. Args: - node_list (list or numpy.ndarray): The given list of nodes. + node_list (Union[list, numpy.ndarray]): The given list of nodes. neg_neighbor_num (int): Number of neighbors sampled. neg_neighbor_type (int): Specify the type of negative neighbor. @@ -191,8 +191,8 @@ class GraphData: Get `feature_types` feature of the nodes in `node_list`. Args: - node_list (list or numpy.ndarray): The given list of nodes. - feature_types (list or ndarray): The given list of feature types. + node_list (Union[list, numpy.ndarray]): The given list of nodes. + feature_types (Union[list, numpy.ndarray]): The given list of feature types. Returns: numpy.ndarray: array of features. @@ -220,8 +220,8 @@ class GraphData: Get `feature_types` feature of the edges in `edge_list`. Args: - edge_list (list or numpy.ndarray): The given list of edges. - feature_types (list or ndarray): The given list of feature types. + edge_list (Union[list, numpy.ndarray]): The given list of edges. + feature_types (Union[list, numpy.ndarray]): The given list of feature types. Returns: numpy.ndarray: array of features. @@ -249,7 +249,7 @@ class GraphData: the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. Returns: - Dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, + dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, node_feature_type and edge_feature_type. """ return self._graph.graph_info() diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index a2a23cbb44..e2d810b29a 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -29,7 +29,6 @@ from . import datasets as de ITERATORS_LIST = list() - def _cleanup(): """Release all the Iterator.""" for itr_ref in ITERATORS_LIST: @@ -60,7 +59,6 @@ def _alter_node(node): node.iterator_bootstrap() return node - class Iterator: """ General Iterator over a dataset. @@ -69,10 +67,21 @@ class Iterator: dataset: Dataset to be iterated over """ - def __init__(self, dataset): + def __init__(self, dataset, num_epochs=-1): + self.num_epochs = num_epochs ITERATORS_LIST.append(weakref.ref(self)) # create a copy of tree and work on it. self.dataset = copy.deepcopy(dataset) + self.parent_subtree = [] + + # The dataset passed into the iterator is not the root of the tree. + # Trim the tree by saving the parent subtree into self.parent_subtree and + # restore it after launching our c++ pipeline. + if self.dataset.parent: + logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.") + self.parent_subtree = self.dataset.parent + self.dataset.parent = [] + self.dataset = alter_tree(self.dataset) if not self.__is_tree(): raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") @@ -83,9 +92,17 @@ class Iterator: root = self.__convert_node_postorder(self.dataset) self.depipeline.AssignRootNode(root) - self.depipeline.LaunchTreeExec() + self.depipeline.LaunchTreeExec(self.num_epochs) self._index = 0 + def stop(self): + """ + Manually terminate python iterator instead of relying on out of scope destruction. + """ + logger.info("terminating python iterator. This will also terminate c++ pipeline.") + if hasattr(self, 'depipeline') and self.depipeline: + del self.depipeline + def __is_tree_node(self, node): """Check if a node is tree node.""" if not node.children: @@ -164,8 +181,12 @@ class Iterator: op_type = OpName.TEXTFILE elif isinstance(dataset, de.BuildVocabDataset): op_type = OpName.BUILDVOCAB + elif isinstance(dataset, de.BuildSentencePieceVocabDataset): + op_type = OpName.SENTENCEPIECEVOCAB elif isinstance(dataset, de.CLUEDataset): op_type = OpName.CLUE + elif isinstance(dataset, de.CSVDataset): + op_type = OpName.CSV else: raise ValueError("Unsupported DatasetOp") @@ -173,6 +194,7 @@ class Iterator: # Convert python node into C node and add to C layer execution tree in postorder traversal. def __convert_node_postorder(self, node): + self.check_node_type(node) op_type = self.__get_dataset_type(node) c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args()) @@ -213,9 +235,14 @@ class Iterator: @abstractmethod def get_next(self): - pass + raise RuntimeError("Calling base class Iterator's get_next is invalid.") def __next__(self): + if not self.depipeline: + logger.warning("Iterator does not have a running c++ pipeline." + + "It can be because Iterator stop() had been called, or c++ pipeline crashed silently.") + raise RuntimeError("Iterator does not have a running c++ pipeline.") + data = self.get_next() if not data: if self._index == 0: @@ -224,6 +251,10 @@ class Iterator: self._index += 1 return data + @abstractmethod + def check_node_type(self, node): + pass + def get_output_shapes(self): return [t for t in self.depipeline.GetOutputShapes()] @@ -245,11 +276,27 @@ class Iterator: def __deepcopy__(self, memo): return self +class SaveOp(Iterator): + """ + The derived class of Iterator with dict type. + """ + def get_next(self): + pass + + def check_node_type(self, node): + if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)): + logger.warning("Used shuffle, repeat, batch before save operator.") + + def save(self, file_names, file_type): + return self.depipeline.SaveDataset(file_names, file_type) + class DictIterator(Iterator): """ The derived class of Iterator with dict type. """ + def check_node_type(self, node): + pass def __iter__(self): return self @@ -269,13 +316,15 @@ class TupleIterator(Iterator): """ The derived class of Iterator with list type. """ + def check_node_type(self, node): + pass - def __init__(self, dataset, columns=None): + def __init__(self, dataset, columns=None, num_epochs=-1): if columns is not None: if not isinstance(columns, list): columns = [columns] dataset = dataset.project(columns) - super().__init__(dataset) + super().__init__(dataset, num_epochs) def __iter__(self): return self diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index b74874f9cf..22c0e44d0d 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler): return c_sampler def create_for_minddataset(self): - c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, + self.seed, num_samples) c_child_sampler = self.create_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index 8fd3a2bb9b..5c5704497e 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -30,7 +30,7 @@ def serialize(dataset, json_filepath=None): Args: dataset (Dataset): the starting node. - json_filepath (string): a filepath where a serialized json file will be generated. + json_filepath (str): a filepath where a serialized json file will be generated. Returns: dict containing the serialized dataset graph. @@ -63,7 +63,7 @@ def deserialize(input_dict=None, json_filepath=None): Args: input_dict (dict): a python dictionary containing a serialized dataset graph - json_filepath (string): a path to the json file. + json_filepath (str): a path to the json file. Returns: de.Dataset or None if error occurs. @@ -279,7 +279,7 @@ def create_node(node): sampler = construct_sampler(node.get('sampler')) pyobj = pyclass(node['dataset_file'], node.get('columns_list'), node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'), - node.get('shard_id'), node.get('block_reader'), sampler) + node.get('shard_id'), sampler) elif dataset_op == 'TFRecordDataset': pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 29904f1a9e..2c9b97654f 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -246,7 +246,24 @@ def check_celebadataset(method): return new_method +def check_save(method): + """A wrapper that wrap a parameter checker to the save op.""" + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_files'] + nreq_param_str = ['file_name', 'file_type'] + validate_dataset_param_value(nreq_param_int, param_dict, int) + if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): + raise ValueError("num_files should between {} and {}.".format(1, 1000)) + validate_dataset_param_value(nreq_param_str, param_dict, str) + if param_dict.get('file_type') != 'mindrecord': + raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type'))) + return method(self, *args, **kwargs) + + return new_method def check_minddataset(method): """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" @@ -256,11 +273,12 @@ def check_minddataset(method): nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'] nreq_param_list = ['columns_list'] - nreq_param_bool = ['block_reader'] nreq_param_dict = ['padded_sample'] dataset_file = param_dict.get('dataset_file') if isinstance(dataset_file, list): + if len(dataset_file) > 4096: + raise ValueError("length of dataset_file should less than or equal to {}.".format(4096)) for f in dataset_file: check_file(f) else: @@ -268,7 +286,6 @@ def check_minddataset(method): validate_dataset_param_value(nreq_param_int, param_dict, int) validate_dataset_param_value(nreq_param_list, param_dict, list) - validate_dataset_param_value(nreq_param_bool, param_dict, bool) validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) @@ -635,6 +652,25 @@ def check_positive_int32(method): return new_method +def check_device_send(method): + """check the input argument for to_device and device_que.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + param, param_dict = parse_user_args(method, *args, **kwargs) + para_list = list(param_dict.keys()) + if "prefetch_size" in para_list: + if param[0] is not None: + check_pos_int32(param[0], "prefetch_size") + type_check(param[1], (bool,), "send_epoch_end") + else: + type_check(param[0], (bool,), "send_epoch_end") + + return method(self, *args, **kwargs) + + return new_method + + def check_zip(method): """check the input arguments of zip.""" @@ -669,8 +705,7 @@ def check_concat(method): [ds], _ = parse_user_args(method, *args, **kwargs) type_check(ds, (list, datasets.Dataset), "datasets") if isinstance(ds, list): - dataset_names = ["dataset[{0}]".format(i) for i in range(len(ds)) if isinstance(ds, list)] - type_check_list(ds, (datasets.Dataset,), dataset_names) + type_check_list(ds, (datasets.Dataset,), "dataset") return method(self, *args, **kwargs) return new_method @@ -734,8 +769,7 @@ def check_add_column(method): if shape is not None: type_check(shape, (list,), "shape") - shape_names = ["shape[{0}]".format(i) for i in range(len(shape))] - type_check_list(shape, (int,), shape_names) + type_check_list(shape, (int,), "shape") return method(self, *args, **kwargs) @@ -772,6 +806,53 @@ def check_cluedataset(method): return new_method +def check_csvdataset(method): + """A wrapper that wrap a parameter checker to the original Dataset(CSVDataset).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id'] + + # check dataset_files; required argument + dataset_files = param_dict.get('dataset_files') + type_check(dataset_files, (str, list), "dataset files") + + # check num_samples + num_samples = param_dict.get('num_samples') + check_value(num_samples, [-1, INT32_MAX], "num_samples") + + # check field_delim + field_delim = param_dict.get('field_delim') + type_check(field_delim, (str,), 'field delim') + if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1: + raise ValueError("field_delim is not legal.") + + # check column_defaults + column_defaults = param_dict.get('column_defaults') + if column_defaults is not None: + if not isinstance(column_defaults, list): + raise TypeError("column_defaults should be type of list.") + for item in column_defaults: + if not isinstance(item, (str, int, float)): + raise TypeError("column type is not legal.") + + # check column_names: must be list of string. + column_names = param_dict.get("column_names") + if column_names is not None: + all_string = all(isinstance(item, str) for item in column_names) + if not all_string: + raise TypeError("column_names should be a list of str.") + + validate_dataset_param_value(nreq_param_int, param_dict, int) + check_sampler_shuffle_shard_options(param_dict) + + return method(self, *args, **kwargs) + + return new_method + + def check_textfiledataset(method): """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" @@ -955,6 +1036,7 @@ def check_gnn_random_walk(method): type_check(step_home_param, (float,), "step_home_param") type_check(step_away_param, (float,), "step_away_param") type_check(default_node, (int,), "default_node") + check_value(default_node, (-1, INT32_MAX), "default_node") return method(self, *args, **kwargs) diff --git a/mindspore/dataset/text/__init__.py b/mindspore/dataset/text/__init__.py index 04eb90a0b6..22e426b4db 100644 --- a/mindspore/dataset/text/__init__.py +++ b/mindspore/dataset/text/__init__.py @@ -19,13 +19,16 @@ utils provides some general methods for nlp text processing. """ import platform from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ - ToNumber -from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm + ToNumber, SlidingWindow, SentencePieceTokenizer +from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm, SentencePieceVocab, SentencePieceModel, \ + SPieceTokenizerOutType, SPieceTokenizerLoadType + __all__ = [ "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", "to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber", - "PythonTokenizer" + "PythonTokenizer", "SlidingWindow", "SentencePieceVocab", "SentencePieceTokenizer", "SPieceTokenizerOutType", + "SentencePieceModel", "SPieceTokenizerLoadType" ] if platform.system().lower() != 'windows': diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 30fa2b8f42..ec1120e4d3 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -50,11 +50,11 @@ import numpy as np import mindspore._c_dataengine as cde -from .utils import JiebaMode, NormalizeForm, to_str +from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType from .validators import check_lookup, check_jieba_add_dict, \ check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ - check_to_number, check_bert_tokenizer, check_python_tokenizer + check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow from ..core.datatypes import mstype_to_detype @@ -72,6 +72,34 @@ class Lookup(cde.LookupOp): def __init__(self, vocab, unknown_token=None): super().__init__(vocab, unknown_token) +class SlidingWindow(cde.SlidingWindowOp): + """ + TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis + is a slice of data starting at the corresponding position, with a specified width. + + Args: + width (int): The width of the window. Must be an integer and greater than zero. + axis (int, optional): The axis along which sliding window is computed (default=0). + + Examples: + >>> # Data before + >>> # | col1 | + >>> # +-------------+ + >>> # | [1,2,3,4,5] | + >>> # +-------------+ + >>> data = data.map(operations=SlidingWindow(3, 0)) + >>> # Data after + >>> # | col1 | + >>> # +-------------+ + >>> # | [[1,2,3], | + >>> # | [2,3,4], | + >>> # | [3,4,5]] | + >>> # +--------------+ + """ + + @check_slidingwindow + def __init__(self, width, axis=0): + super().__init__(width=width, axis=axis) class Ngram(cde.NgramOp): """ @@ -80,7 +108,7 @@ class Ngram(cde.NgramOp): Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an overview of what n-gram is and how it works. Args: - n (list of int): n in n-gram, n >= 1. n is a list of positive integers, for e.g. n=[4,3], The result + n (list[int]): n in n-gram, n >= 1. n is a list of positive integers, for e.g. n=[4,3], The result would be a 4-gram followed by a 3-gram in the same tensor. If number of words is not enough to make up for a n-gram, an empty string would be returned. For e.g. 3 grams on ["mindspore","best"] would result in an empty string be produced. @@ -171,7 +199,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp): Add user defined word to JiebaTokenizer's dictionary. Args: - user_dict (str or dict): Dictionary to be added, file path or Python dictionary, + user_dict (Union[str, dict]): Dictionary to be added, file path or Python dictionary, Python Dict format: {word1:freq1, word2:freq2,...}. Jieba dictionary format : word(required), freq(optional), such as: @@ -296,6 +324,36 @@ class WordpieceTokenizer(cde.WordpieceTokenizerOp): super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token, self.unknown_token, self.with_offsets) +DE_C_INTER_SENTENCEPIECE_LOADTYPE = { + SPieceTokenizerLoadType.FILE: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KFILE, + SPieceTokenizerLoadType.MODEL: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KMODEL +} + +DE_C_INTER_SENTENCEPIECE_OUTTYPE = { + SPieceTokenizerOutType.STRING: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KString, + SPieceTokenizerOutType.INT: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KINT +} + +class SentencePieceTokenizer(cde.SentencePieceTokenizerOp): + """ + Tokenize scalar token or 1-D tokens to tokens by sentencepiece. + + Args: + mode(Union[str, SentencePieceVocab]): If the input parameter is a file, then it is of type string, + if the input parameter is a SentencePieceVocab object, then it is of type SentencePieceVocab. + out_type(Union[str, int]): The type of output. + """ + + def __init__(self, mode, out_type): + self.out_type = out_type + if isinstance(mode, str): + model_path, model_filename = os.path.split(mode) + super().__init__(model_path, model_filename, + DE_C_INTER_SENTENCEPIECE_LOADTYPE[SPieceTokenizerLoadType.FILE], + DE_C_INTER_SENTENCEPIECE_OUTTYPE[out_type]) + elif isinstance(mode, cde.SentencePieceVocab): + super().__init__(mode, DE_C_INTER_SENTENCEPIECE_LOADTYPE[SPieceTokenizerLoadType.MODEL], + DE_C_INTER_SENTENCEPIECE_OUTTYPE[out_type]) if platform.system().lower() != 'windows': class WhitespaceTokenizer(cde.WhitespaceTokenizerOp): diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py index ef1d0e6fc5..c6bfe4f14c 100644 --- a/mindspore/dataset/text/utils.py +++ b/mindspore/dataset/text/utils.py @@ -22,10 +22,11 @@ import copy import numpy as np import mindspore._c_dataengine as cde -from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset +from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \ + check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model __all__ = [ - "Vocab", "to_str", "to_bytes" + "Vocab", "SentencePieceVocab", "to_str", "to_bytes" ] @@ -50,7 +51,7 @@ class Vocab(cde.Vocab): Args: dataset(Dataset): dataset to build vocab from. - columns(list of str, optional): column names to get words from. It can be a list of column names. + columns(list[str], optional): column names to get words from. It can be a list of column names. (default=None, where all columns will be used. If any column isn't string type, will return error). freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency=0 is the same as @@ -137,6 +138,77 @@ class Vocab(cde.Vocab): return super().from_dict(word_dict) +class SentencePieceVocab(cde.SentencePieceVocab): + """ + SentencePiece obiect that is used to segmentate words + """ + @classmethod + @check_from_dataset_sentencepiece + def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params): + """ + Build a sentencepiece from a dataset + + Args: + dataset(Dataset): Dataset to build sentencepiece. + col_names(list): The list of the col name. + vocab_size(int): Vocabulary size, the type of uint32_t. + character_coverage(float): Amount of characters covered by the model, good defaults are: 0.9995 for + languages. with rich character set like Japanse or Chinese and 1.0 for other languages with small + character set. + model_type(SentencePieceModel): Choose from unigram (default), bpe, char, or word. The input sentence + must be pretokenized when using word type. + params(dict): A dictionary with no incoming parameters. + + Returns: + SentencePiece, SentencePiece object from dataset. + """ + + vocab = SentencePieceVocab() + root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage, + model_type, params) + for d in root.create_dict_iterator(): + if d is None: + raise ValueError("from_dataset should receive data other than None.") + return vocab + + @classmethod + @check_from_file_sentencepiece + def from_file(cls, file_path, vocab_size, character_coverage, model_type, params): + """ + Build a SentencePiece object from a list of word. + + Args: + file_path(list): Path to the file which contains the sentencepiece list. + vocab_size(int): Vocabulary size, the type of uint32_t. + character_coverage(float): Amount of characters covered by the model, good defaults are: 0.9995 for + languages. with rich character set like Japanse or Chinese and 1.0 for other languages with small + character set. + model_type(SentencePieceModel): Choose from unigram (default), bpe, char, or word. The input sentence + must be pretokenized when using word type. + params(dict): A dictionary with no incoming parameters(The parameters are derived from SentencePiece + library). + + .. code-block:: + + input_sentence_size 0 + max_sentencepiece_length 16 + """ + return super().from_file(file_path, vocab_size, character_coverage, + DE_C_INTER_SENTENCEPIECE_MODE[model_type], params) + + @classmethod + @check_save_model + def save_model(cls, vocab, path, filename): + """ + Save model to filepath + + Args: + vocab(SentencePieceVocab): A sentencepiece object. + path(str): Path to store model. + filename(str): The name of the file. + """ + return super().save_model(vocab, path, filename) + def to_str(array, encoding='utf8'): """ @@ -144,7 +216,7 @@ def to_str(array, encoding='utf8'): Args: array (numpy.ndarray): Array of type `bytes` representing strings. - encoding (string): Indicating the charset for decoding. + encoding (str): Indicating the charset for decoding. Returns: numpy.ndarray, numpy array of `str`. @@ -188,3 +260,27 @@ class NormalizeForm(IntEnum): NFKC = 2 NFD = 3 NFKD = 4 + +class SentencePieceModel(IntEnum): + """An enumeration for SentencePieceModel, effective enumeration types are UNIGRAM, BPE, CHAR, WORD.""" + UNIGRAM = 0 + BPE = 1 + CHAR = 2 + WORD = 3 + +DE_C_INTER_SENTENCEPIECE_MODE = { + SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM, + SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE, + SentencePieceModel.CHAR: cde.SentencePieceModel.DE_SENTENCE_PIECE_CHAR, + SentencePieceModel.WORD: cde.SentencePieceModel.DE_SENTENCE_PIECE_WORD +} + +class SPieceTokenizerOutType(IntEnum): + """An enumeration for SPieceTokenizerOutType, effective enumeration types are STRING, INT.""" + STRING = 0 + INT = 1 + +class SPieceTokenizerLoadType(IntEnum): + """An enumeration for SPieceTokenizerLoadType, effective enumeration types are FILE, MODEL.""" + FILE = 0 + MODEL = 1 diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index b0327f5609..ef04ce7f2c 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ - INT32_MAX, check_value, check_positive + INT32_MAX, check_value, check_positive, check_pos_int32 def check_unique_list_of_words(words, arg_name): @@ -67,7 +67,7 @@ def check_from_file(method): check_unique_list_of_words(special_tokens, "special_tokens") type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) if vocab_size is not None: - check_value(vocab_size, (-1, INT32_MAX), "vocab_size") + check_positive(vocab_size, "vocab_size") type_check(special_first, (bool,), special_first) return method(self, *args, **kwargs) @@ -297,8 +297,7 @@ def check_from_dataset(method): if columns is not None: if not isinstance(columns, list): columns = [columns] - col_names = ["col_{0}".format(i) for i in range(len(columns))] - type_check_list(columns, (str,), col_names) + type_check_list(columns, (str,), "col") if freq_range is not None: type_check(freq_range, (tuple,), "freq_range") @@ -328,6 +327,17 @@ def check_from_dataset(method): return new_method +def check_slidingwindow(method): + """A wrapper that wrap a parameter checker to the original function(sliding window operation).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [width, axis], _ = parse_user_args(method, *args, **kwargs) + check_pos_int32(width, "width") + type_check(axis, (int,), "axis") + return method(self, *args, **kwargs) + + return new_method def check_ngram(method): """A wrapper that wraps a parameter checker to the original function.""" @@ -409,3 +419,81 @@ def check_python_tokenizer(method): return method(self, *args, **kwargs) return new_method + + +def check_from_dataset_sentencepiece(method): + """A wrapper that wraps a parameter checker to the original function (from_dataset).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs) + + if col_names is not None: + type_check(col_names, (list,), "col_names") + + if vocab_size is not None: + check_uint32(vocab_size, "vocab_size") + + if character_coverage is not None: + type_check(character_coverage, (float,), "character_coverage") + + if model_type is not None: + from .utils import SentencePieceModel + type_check(model_type, (str, SentencePieceModel), "model_type") + + if params is not None: + type_check(params, (dict,), "params") + + return method(self, *args, **kwargs) + + return new_method + + +def check_from_file_sentencepiece(method): + """A wrapper that wraps a parameter checker to the original function (from_file).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs) + + if file_path is not None: + type_check(file_path, (list,), "file_path") + + if vocab_size is not None: + check_uint32(vocab_size, "vocab_size") + + if character_coverage is not None: + type_check(character_coverage, (float,), "character_coverage") + + if model_type is not None: + from .utils import SentencePieceModel + type_check(model_type, (str, SentencePieceModel), "model_type") + + if params is not None: + type_check(params, (dict,), "params") + + return method(self, *args, **kwargs) + + return new_method + + +def check_save_model(method): + """A wrapper that wraps a parameter checker to the original function (save_model).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [vocab, path, filename], _ = parse_user_args(method, *args, **kwargs) + + if vocab is not None: + type_check(vocab, (cde.SentencePieceVocab,), "vocab") + + if path is not None: + type_check(path, (str,), "path") + + if filename is not None: + type_check(filename, (str,), "filename") + + return method(self, *args, **kwargs) + + return new_method + \ No newline at end of file diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 62496822e5..41a31e7c50 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype import mindspore._c_dataengine as cde from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \ - check_pad_end, check_concat_type + check_pad_end, check_concat_type, check_random_transform_ops from ..core.datatypes import mstype_to_detype @@ -46,7 +46,7 @@ class Fill(cde.FillOp): The output tensor will have the same shape and type as the input tensor. Args: - fill_value (python types (str, bytes, int, float, or bool)) : scalar value + fill_value (Union[str, bytes, int, float, bool])) : scalar value to fill created tensor with. """ @@ -78,11 +78,11 @@ class Slice(cde.SliceOp): (Currently only rank-1 tensors are supported). Args: - *slices(Variable length argument list, supported types are, int, list(int), slice, None or Ellipses): + slices(Union[int, list(int), slice, None, Ellipses]): Maximum `n` number of arguments to slice a tensor of rank `n`. One object in slices can be one of: 1. :py:obj:`int`: Slice this index only. Negative index is supported. - 2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supdeported. + 2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supported. 3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`. 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in python indexing. 5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in python indexing. @@ -139,9 +139,9 @@ class Mask(cde.MaskOp): Args: operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE - constant (python types (str, int, float, or bool): constant to be compared to. + constant (Union[str, int, float, bool]): constant to be compared to. Constant will be casted to the type of the input tensor - dtype (optional, mindspore.dtype): type of the generated mask. Default to bool + dtype (mindspore.dtype, optional): type of the generated mask. Default to bool Examples: >>> # Data before @@ -171,7 +171,7 @@ class PadEnd(cde.PadEndOp): Args: pad_shape (list(int)): list on integers representing the shape needed. Dimensions that set to `None` will not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values. - pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty + pad_value (Union[str, bytes, int, float, bool]), optional): value used to pad. Default to 0 or empty string in case of Tensors of strings. Examples: @@ -201,8 +201,8 @@ class Concatenate(cde.ConcatenateOp): Args: axis (int, optional): axis to concatenate the tensors along (Default=0). - prepend (np.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None). - append (np.array, optional): numpy array to be appended to the already concatenated tensors (Default=None). + prepend (numpy.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None). + append (numpy.array, optional): numpy array to be appended to the already concatenated tensors (Default=None). """ @check_concat_type @@ -232,3 +232,55 @@ class Duplicate(cde.DuplicateOp): >>> # | [1,2,3] | [1,2,3] | >>> # +---------+---------+ """ + + +class Compose(cde.ComposeOp): + """ + Compose a list of transforms into a single transform. + + Args: + transforms (list): List of transformations to be applied. + + Examples: + >>> compose = Compose([vision.Decode(), vision.RandomCrop()]) + >>> dataset = ds.map(operations=compose) + """ + + @check_random_transform_ops + def __init__(self, transforms): + super().__init__(transforms) + + +class RandomApply(cde.RandomApplyOp): + """ + Randomly performs a series of transforms with a given probability. + + Args: + transforms (list): List of transformations to be applied. + prob (float, optional): The probability to apply the transformation list (default=0.5) + + Examples: + >>> rand_apply = RandomApply([vision.RandomCrop()]) + >>> dataset = ds.map(operations=rand_apply) + """ + + @check_random_transform_ops + def __init__(self, transforms, prob=0.5): + super().__init__(prob, transforms) + + +class RandomChoice(cde.RandomChoiceOp): + """ + Randomly selects one transform from a list of transforms to perform operation. + + Args: + transforms (list): List of transformations to be chosen from to apply. + + Examples: + >>> rand_choice = RandomChoice([vision.CenterCrop(), vision.RandomCrop()]) + >>> dataset = ds.map(operations=rand_choice) + """ + + @check_random_transform_ops + def __init__(self, transforms): + super().__init__(transforms) diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index 9fe0fa5f10..f44fd918ee 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -18,7 +18,8 @@ from functools import wraps import numpy as np from mindspore._c_expression import typing -from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive +from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \ + check_tensor_op # POS_INT_MIN is used to limit values from starting from 0 POS_INT_MIN = 1 @@ -180,3 +181,22 @@ def check_concat_type(method): return method(self, *args, **kwargs) return new_method + + +def check_random_transform_ops(method): + """Wrapper method to check the parameters of RandomChoice, RandomApply and Compose.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + arg_list, _ = parse_user_args(method, *args, **kwargs) + type_check(arg_list[0], (list,), "op_list") + if not arg_list[0]: + raise ValueError("op_list can not be empty.") + for ind, op in enumerate(arg_list[0]): + check_tensor_op(op, "op_list[{0}]".format(ind)) + if len(arg_list) == 2: # random apply takes an additional arg + type_check(arg_list[1], (float, int), "prob") + check_value(arg_list[1], (0, 1), "prob") + return method(self, *args, **kwargs) + + return new_method diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 8e3b7c7214..9a07c58a19 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -47,7 +47,7 @@ from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ - FLOAT_MAX_INTEGER + check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -71,6 +71,38 @@ def parse_padding(padding): return padding +class AutoContrast(cde.AutoContrastOp): + """ + Apply auto contrast on input image. + + Args: + cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). + ignore (Union[int, sequence], optional): Pixel values to ignore (default=None). + """ + + @check_auto_contrast + def __init__(self, cutoff=0.0, ignore=None): + if ignore is None: + ignore = [] + if isinstance(ignore, int): + ignore = [ignore] + super().__init__(cutoff, ignore) + + +class Equalize(cde.EqualizeOp): + """ + Apply histogram equalization on input image. + does not have input arguments. + """ + + +class Invert(cde.InvertOp): + """ + Apply invert on input image in RGB mode. + does not have input arguments. + """ + + class Decode(cde.DecodeOp): """ Decode the input image in RGB mode. @@ -119,10 +151,10 @@ class RandomCrop(cde.RandomCropOp): Crop the input image at a random location. Args: - size (int or sequence): The output size of the cropped image. + size (Union[int, sequence]): The output size of the cropped image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). - padding (int or sequence, optional): The number of pixels to pad the image (default=None). + padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None). If padding is not None, pad image firstly with padding values. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) @@ -131,7 +163,7 @@ class RandomCrop(cde.RandomCropOp): it pads the left, top, right and bottom respectively. pad_if_needed (bool, optional): Pad the image if either side is smaller than the given output size (default=False). - fill_value (int or tuple, optional): The pixel intensity of the borders if + fill_value (Union[int, tuple], optional): The pixel intensity of the borders if the padding_mode is Border.CONSTANT (default=0). If it is a 3-tuple, it is used to fill R, G, B channels respectively. padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT). Can be any of @@ -174,10 +206,10 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp): Crop the input image at a random location and adjust bounding boxes accordingly. Args: - size (int or sequence): The output size of the cropped image. + size (Union[int, sequence]): The output size of the cropped image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). - padding (int or sequence, optional): The number of pixels to pad the image (default=None). + padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None). If padding is not None, pad image firstly with padding values. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) @@ -185,7 +217,7 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp): If 4 values are provided as a list or tuple,it pads the left, top, right and bottom respectively. pad_if_needed (bool, optional): Pad the image if either side is smaller than the given output size (default=False). - fill_value (int or tuple, optional): The pixel intensity of the borders if + fill_value (Union[int, tuple], optional): The pixel intensity of the borders if the padding_mode is Border.CONSTANT (default=0). If it is a 3-tuple, it is used to fill R, G, B channels respectively. padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT). Can be any of @@ -303,7 +335,7 @@ class Resize(cde.ResizeOp): Resize the input image to the given size. Args: - size (int or sequence): The output size of the resized image. + size (Union[int, sequence]): The output size of the resized image. If size is an int, smaller edge of the image will be resized to this value with the same image aspect ratio. If size is a sequence of length 2, it should be (height, width). @@ -319,8 +351,6 @@ class Resize(cde.ResizeOp): @check_resize_interpolation def __init__(self, size, interpolation=Inter.LINEAR): - if isinstance(size, int): - size = (size, size) self.size = size self.interpolation = interpolation interpoltn = DE_C_INTER_MODE[interpolation] @@ -335,7 +365,7 @@ class ResizeWithBBox(cde.ResizeWithBBoxOp): Resize the input image to the given size and adjust bounding boxes accordingly. Args: - size (int or sequence): The output size of the resized image. + size (Union[int, sequence]): The output size of the resized image. If size is an int, smaller edge of the image will be resized to this value with the same image aspect ratio. If size is a sequence of length 2, it should be (height, width). @@ -365,7 +395,7 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): Crop the input image to a random size and aspect ratio and adjust bounding boxes accordingly. Args: - size (int or sequence): The size of the output image. + size (Union[int, sequence]): The size of the output image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). scale (tuple, optional): Range (min, max) of respective size of the original @@ -404,7 +434,7 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp): Crop the input image to a random size and aspect ratio. Args: - size (int or sequence): The size of the output image. + size (Union[int, sequence]): The size of the output image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). scale (tuple, optional): Range (min, max) of respective size of the original @@ -443,7 +473,7 @@ class CenterCrop(cde.CenterCropOp): Crops the input image at the center to the given size. Args: - size (int or sequence): The output size of the cropped image. + size (Union[int, sequence]): The output size of the cropped image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). """ @@ -461,16 +491,16 @@ class RandomColorAdjust(cde.RandomColorAdjustOp): Randomly adjust the brightness, contrast, saturation, and hue of the input image. Args: - brightness (float or tuple, optional): Brightness adjustment factor (default=(1, 1)). Cannot be negative. + brightness (Union[float, tuple], optional): Brightness adjustment factor (default=(1, 1)). Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness]. If it is a sequence, it should be [min, max] for the range. - contrast (float or tuple, optional): Contrast adjustment factor (default=(1, 1)). Cannot be negative. + contrast (Union[float, tuple], optional): Contrast adjustment factor (default=(1, 1)). Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast]. If it is a sequence, it should be [min, max] for the range. - saturation (float or tuple, optional): Saturation adjustment factor (default=(1, 1)). Cannot be negative. + saturation (Union[float, tuple], optional): Saturation adjustment factor (default=(1, 1)). Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation]. If it is a sequence, it should be [min, max] for the range. - hue (float or tuple, optional): Hue adjustment factor (default=(0, 0)). + hue (Union[float, tuple], optional): Hue adjustment factor (default=(0, 0)). If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5. If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5. """ @@ -503,7 +533,7 @@ class RandomRotation(cde.RandomRotationOp): Rotate the input image by a random angle. Args: - degrees (int or float or sequence): Range of random rotation degrees. + degrees (Union[int, float, sequence): Range of random rotation degrees. If degrees is a number, the range will be converted to (-degrees, degrees). If degrees is a sequence, it should be (min, max). resample (Inter mode, optional): An optional resampling filter (default=Inter.NEAREST). @@ -522,7 +552,8 @@ class RandomRotation(cde.RandomRotationOp): Note that the expand flag assumes rotation around the center and no translation. center (tuple, optional): Optional center of rotation (a 2-tuple) (default=None). Origin is the top left corner. None sets to the center of the image. - fill_value (int or tuple, optional): Optional fill color for the area outside the rotated image (default=0). + fill_value (Union[int, tuple], optional): Optional fill color for the area outside the rotated image + (default=0). If it is a 3-tuple, it is used for R, G, B channels respectively. If it is an int, it is used for all RGB channels. """ @@ -565,7 +596,7 @@ class RandomResize(cde.RandomResizeOp): Tensor operation to resize the input image using a randomly selected interpolation mode. Args: - size (int or sequence): The output size of the resized image. + size (Union[int, sequence]): The output size of the resized image. If size is an int, smaller edge of the image will be resized to this value with the same image aspect ratio. If size is a sequence of length 2, it should be (height, width). @@ -586,7 +617,7 @@ class RandomResizeWithBBox(cde.RandomResizeWithBBoxOp): bounding boxes accordingly. Args: - size (int or sequence): The output size of the resized image. + size (Union[int, sequence]): The output size of the resized image. If size is an int, smaller edge of the image will be resized to this value with the same image aspect ratio. If size is a sequence of length 2, it should be (height, width). @@ -612,7 +643,7 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): Equivalent to RandomResizedCrop, but crops before decodes. Args: - size (int or sequence, optional): The size of the output image. + size (Union[int, sequence], optional): The size of the output image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). scale (tuple, optional): Range (min, max) of respective size of the @@ -651,13 +682,13 @@ class Pad(cde.PadOp): Pads the image according to padding parameters. Args: - padding (int or sequence): The number of pixels to pad the image. + padding (Union[int, sequence]): The number of pixels to pad the image. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) with the first value and (right and bottom) with the second value. If 4 values are provided as a list or tuple, it pads the left, top, right and bottom respectively. - fill_value (int or tuple, optional): The pixel intensity of the borders if + fill_value (Union[int, tuple], optional): The pixel intensity of the borders if the padding_mode is Border.CONSTANT (default=0). If it is a 3-tuple, it is used to fill R, G, B channels respectively. padding_mode (Border mode): The method of padding (default=Border.CONSTANT). Can be any of @@ -692,7 +723,7 @@ class UniformAugment(cde.UniformAugOp): Tensor operation to perform randomly selected augmentation. Args: - operations: list of C++ operations (python OPs are not accepted). + transforms: list of C++ operations (python OPs are not accepted). num_ops (int, optional): number of OPs to be selected and applied (default=2). Examples: @@ -700,7 +731,7 @@ class UniformAugment(cde.UniformAugOp): >>> c_transforms.RandomVerticalFlip(), >>> c_transforms.RandomColorAdjust(), >>> c_transforms.RandomRotation(degrees=45)] - >>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2) + >>> uni_aug = c_transforms.UniformAugment(transforms=transforms_list, num_ops=2) >>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]), >>> uni_aug, F.ToTensor()] >>> ds_ua = ds.map(input_columns="image", @@ -708,7 +739,28 @@ class UniformAugment(cde.UniformAugOp): """ @check_uniform_augment_cpp - def __init__(self, operations, num_ops=2): - self.operations = operations + def __init__(self, transforms, num_ops=2): + self.transforms = transforms self.num_ops = num_ops - super().__init__(operations, num_ops) + super().__init__(transforms, num_ops) + + +class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp): + """ + Choose a random sub-policy from a list to be applied on the input image. A sub-policy is a list of tuples + (op, prob), where op is a TensorOp operation and prob is the probability that this op will be applied. Once + a sub-policy is selected, each op within the subpolicy with be applied in sequence according to its probability + + Args: + policy (list(list(tuple(TensorOp,float))): List of sub-policies to choose from. + + Examples: + >>> policy = [[(c_vision.RandomRotation((45, 45)), 0.5), (c_transforms.RandomVerticalFlip(), 1), + >>> (c_transforms.RandomColorAdjust(), 0.8)], + >>> [(c_vision.RandomRotation((90, 90)), 1), (c_transforms.RandomColorAdjust(), 0.2)]] + >>> ds_policy = ds.map(input_columns=["image"], operations=visions.RandomSelectSubpolicy(policy)) + """ + + @check_random_select_subpolicy_op + def __init__(self, policy): + super().__init__(policy) diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index 3bfd6b0644..9e4be238c2 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -33,7 +33,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \ - check_mix_up, check_positive_degrees, check_uniform_augment_py, check_compose_list + check_mix_up, check_positive_degrees, check_uniform_augment_py, check_compose_list, check_auto_contrast from .utils import Inter, Border DE_PY_INTER_MODE = {Inter.NEAREST: Image.NEAREST, @@ -129,7 +129,7 @@ class ToType: Convert the input Numpy image array to desired numpy dtype. Args: - output_type (numpy datatype): The datatype of the numpy output. e.g. np.float32. + output_type (numpy datatype): The datatype of the numpy output, e.g. numpy.float32. Examples: >>> import numpy as np @@ -260,10 +260,10 @@ class RandomCrop: Crop the input PIL Image at a random location. Args: - size (int or sequence): The output size of the cropped image. + size (Union[int, sequence]): The output size of the cropped image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). - padding (int or sequence, optional): The number of pixels to pad the image (default=None). + padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None). If padding is not None, pad image firstly with padding values. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) @@ -385,7 +385,7 @@ class Resize: Resize the input PIL Image to the given size. Args: - size (int or sequence): The output size of the resized image. + size (Union[int, sequence]): The output size of the resized image. If size is an int, smaller edge of the image will be resized to this value with the same image aspect ratio. If size is a sequence of length 2, it should be (height, width). @@ -427,7 +427,7 @@ class RandomResizedCrop: Extract crop from the input image and resize it to a random size and aspect ratio. Args: - size (int or sequence): The size of the output image. + size (Union[int, sequence]): The size of the output image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). scale (tuple, optional): Range (min, max) of respective size of the original size @@ -479,7 +479,7 @@ class CenterCrop: Crop the central reigion of the input PIL Image to the given size. Args: - size (int or sequence): The output size of the cropped image. + size (Union[int, sequence]): The output size of the cropped image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). @@ -511,16 +511,16 @@ class RandomColorAdjust: Perform a random brightness, contrast, saturation, and hue adjustment on the input PIL image. Args: - brightness (float or tuple, optional): Brightness adjustment factor (default=(1, 1)). Cannot be negative. + brightness (Union[float, tuple], optional): Brightness adjustment factor (default=(1, 1)). Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness]. If it is a sequence, it should be [min, max] for the range. - contrast (float or tuple, optional): Contrast adjustment factor (default=(1, 1)). Cannot be negative. + contrast (Union[float, tuple], optional): Contrast adjustment factor (default=(1, 1)). Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast]. If it is a sequence, it should be [min, max] for the range. - saturation (float or tuple, optional): Saturation adjustment factor (default=(1, 1)). Cannot be negative. + saturation (Union[float, tuple], optional): Saturation adjustment factor (default=(1, 1)). Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation]. If it is a sequence, it should be [min, max] for the range. - hue (float or tuple, optional): Hue adjustment factor (default=(0, 0)). + hue (Union[float, tuple], optional): Hue adjustment factor (default=(0, 0)). If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5. If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5. @@ -558,7 +558,7 @@ class RandomRotation: See https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate. Args: - degrees (int or float or sequence): Range of random rotation degrees. + degrees (Union[int, float, sequence]): Range of random rotation degrees. If degrees is a number, the range will be converted to (-degrees, degrees). If degrees is a sequence, it should be (min, max). resample (Inter mode, optional): An optional resampling filter (default=Inter.NEAREST). @@ -743,7 +743,7 @@ class TenCrop: Generate 10 cropped images (first 5 from FiveCrop, second 5 from their flipped version). Args: - size (int or sequence): The output size of the crop. + size (Union[int, sequence]): The output size of the crop. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). use_vertical_flip (bool, optional): Flip the image vertically instead of horizontally @@ -853,13 +853,13 @@ class Pad: Pad the input PIL image according to padding parameters. Args: - padding (int or sequence): The number of pixels to pad the image. + padding (Union[int, sequence]): The number of pixels to pad the image. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) with the first value and (right and bottom) with the second value. If 4 values are provided as a list or tuple, it pads the left, top, right and bottom respectively. - fill_value (int or tuple, optional): Filling value (default=0). The pixel intensity + fill_value (Union[int, tuple], optional): Filling value (default=0). The pixel intensity of the borders if the padding_mode is Border.CONSTANT. If it is a 3-tuple, it is used to fill R, G, B channels respectively. padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT). @@ -961,7 +961,7 @@ class RandomErasing: original image (default=(0.02, 0.33)). ratio (sequence of floats, optional): Range of the aspect ratio of the erase area (default=(0.3, 3.3)). - value (int or sequence): Erasing value (default=0). + value (Union[int, sequence]): Erasing value (default=0). If value is a single int, it is applied to all pixels to be erases. If value is a sequence of length 3, it is applied to R, G, B channels respectively. If value is a str 'random', the erase value will be obtained from a standard normal distribution. @@ -1088,7 +1088,7 @@ class RandomAffine: Apply Random affine transformation to the input PIL image. Args: - degrees (int or float or sequence): Range of the rotation degrees. + degrees (Union[int, float, sequence]): Range of the rotation degrees. If degrees is a number, the range will be (-degrees, degrees). If degrees is a sequence, it should be (min, max). translate (sequence, optional): Sequence (tx, ty) of maximum translation in @@ -1097,7 +1097,7 @@ class RandomAffine: (-tx*width, tx*width) and (-ty*height, ty*height), respectively. If None, no translations gets applied. scale (sequence, optional): Scaling factor interval (default=None, original scale is used). - shear (int or float or sequence, optional): Range of shear factor (default=None). + shear (Union[int, float, sequence], optional): Range of shear factor (default=None). If a number 'shear', then a shear parallel to the x axis in the range of (-shear, +shear) is applied. If a tuple or list of size 2, then a shear parallel to the x axis in the range of (shear[0], shear[1]) is applied. @@ -1114,7 +1114,7 @@ class RandomAffine: - Inter.BICUBIC, means resample method is bicubic interpolation. - fill_value (tuple or int, optional): Optional fill_value to fill the area outside the transform + fill_value (Union[tuple, int], optional): Optional fill_value to fill the area outside the transform in the output image. Used only in Pillow versions > 5.0.0 (default=0, filling is performed). Raises: @@ -1361,6 +1361,10 @@ class AutoContrast: """ Automatically maximize the contrast of the input PIL image. + Args: + cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). + ignore (Union[int, sequence], optional): Pixel values to ignore (default=None). + Examples: >>> py_transforms.ComposeOp([py_transforms.Decode(), >>> py_transforms.AutoContrast(), @@ -1368,6 +1372,11 @@ class AutoContrast: """ + @check_auto_contrast + def __init__(self, cutoff=0.0, ignore=None): + self.cutoff = cutoff + self.ignore = ignore + def __call__(self, img): """ Call method. @@ -1379,7 +1388,7 @@ class AutoContrast: img (PIL Image), Augmented image. """ - return util.auto_contrast(img) + return util.auto_contrast(img, self.cutoff, self.ignore) class Invert: diff --git a/mindspore/dataset/transforms/vision/py_transforms_util.py b/mindspore/dataset/transforms/vision/py_transforms_util.py index d076109ff4..89aa230039 100644 --- a/mindspore/dataset/transforms/vision/py_transforms_util.py +++ b/mindspore/dataset/transforms/vision/py_transforms_util.py @@ -148,7 +148,7 @@ def to_tensor(img, output_type): Change the input image (PIL Image or Numpy image array) to numpy format. Args: - img (PIL Image or numpy.ndarray): Image to be converted. + img (Union[PIL Image, numpy.ndarray]): Image to be converted. output_type: The datatype of the numpy output. e.g. np.float32 Returns: @@ -284,7 +284,7 @@ def resize(img, size, interpolation=Inter.BILINEAR): Args: img (PIL Image): Image to be resized. - size (int or sequence): The output size of the resized image. + size (Union[int, sequence]): The output size of the resized image. If size is an int, smaller edge of the image will be resized to this value with the same image aspect ratio. If size is a sequence of (height, width), this will be the desired output size. @@ -321,7 +321,7 @@ def center_crop(img, size): Args: img (PIL Image): Image to be cropped. - size (int or tuple): The size of the crop box. + size (Union[int, tuple]): The size of the crop box. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). @@ -346,7 +346,7 @@ def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, ma Args: img (PIL Image): Image to be randomly cropped and resized. - size (int or sequence): The size of the output image. + size (Union[int, sequence]): The size of the output image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). scale (tuple): Range (min, max) of respective size of the original size to be cropped. @@ -416,10 +416,10 @@ def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode): Args: img (PIL Image): Image to be randomly cropped. - size (int or sequence): The output size of the cropped image. + size (Union[int, sequence]): The output size of the cropped image. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). - padding (int or sequence, optional): The number of pixels to pad the image. + padding (Union[int, sequence], optional): The number of pixels to pad the image. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) with the first value and (right and bottom) with the second value. @@ -428,7 +428,7 @@ def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode): Default is None. pad_if_needed (bool): Pad the image if either side is smaller than the given output size. Default is False. - fill_value (int or tuple): The pixel intensity of the borders if + fill_value (Union[int, tuple]): The pixel intensity of the borders if the padding_mode is 'constant'. If it is a 3-tuple, it is used to fill R, G, B channels respectively. padding_mode (str): The method of padding. Can be any of @@ -602,7 +602,7 @@ def rotate(img, angle, resample, expand, center, fill_value): Args: img (PIL Image): Image to be rotated. angle (int or float): Rotation angle in degrees, counter-clockwise. - resample (Inter.NEAREST, or Inter.BILINEAR, Inter.BICUBIC, optional): An optional resampling filter. + resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST. expand (bool, optional): Optional expansion flag. If set to True, expand the output image to make it large enough to hold the entire rotated image. @@ -610,7 +610,7 @@ def rotate(img, angle, resample, expand, center, fill_value): Note that the expand flag assumes rotation around the center and no translation. center (tuple, optional): Optional center of rotation (a 2-tuple). Origin is the top left corner. - fill_value (int or tuple): Optional fill color for the area outside the rotated image. + fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image. If it is a 3-tuple, it is used for R, G, B channels respectively. If it is an int, it is used for all RGB channels. @@ -634,16 +634,16 @@ def random_color_adjust(img, brightness, contrast, saturation, hue): Args: img (PIL Image): Image to have its color adjusted randomly. - brightness (float or tuple): Brightness adjustment factor. Cannot be negative. + brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness]. If it is a sequence, it should be [min, max] for the range. - contrast (float or tuple): Contrast adjustment factor. Cannot be negative. + contrast (Union[float, tuple]): Contrast adjustment factor. Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast]. If it is a sequence, it should be [min, max] for the range. - saturation (float or tuple): Saturation adjustment factor. Cannot be negative. + saturation (Union[float, tuple]): Saturation adjustment factor. Cannot be negative. If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation]. If it is a sequence, it should be [min, max] for the range. - hue (float or tuple): Hue adjustment factor. + hue (Union[float, tuple]): Hue adjustment factor. If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5. If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5. @@ -696,10 +696,10 @@ def random_rotation(img, degrees, resample, expand, center, fill_value): Args: img (PIL Image): Image to be rotated. - degrees (int or float or sequence): Range of random rotation degrees. + degrees (Union[int, float, sequence]): Range of random rotation degrees. If degrees is a number, the range will be converted to (-degrees, degrees). If degrees is a sequence, it should be (min, max). - resample (Inter.NEAREST, or Inter.BILINEAR, Inter.BICUBIC, optional): An optional resampling filter. + resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST. expand (bool, optional): Optional expansion flag. If set to True, expand the output image to make it large enough to hold the entire rotated image. @@ -707,7 +707,7 @@ def random_rotation(img, degrees, resample, expand, center, fill_value): Note that the expand flag assumes rotation around the center and no translation. center (tuple, optional): Optional center of rotation (a 2-tuple). Origin is the top left corner. - fill_value (int or tuple): Optional fill color for the area outside the rotated image. + fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image. If it is a 3-tuple, it is used for R, G, B channels respectively. If it is an int, it is used for all RGB channels. @@ -789,7 +789,7 @@ def five_crop(img, size): Args: img (PIL Image): PIL Image to be cropped. - size (int or sequence): The output size of the crop. + size (Union[int, sequence]): The output size of the crop. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). @@ -829,7 +829,7 @@ def ten_crop(img, size, use_vertical_flip=False): Args: img (PIL Image): PIL Image to be cropped. - size (int or sequence): The output size of the crop. + size (Union[int, sequence]): The output size of the crop. If size is an int, a square crop of size (size, size) is returned. If size is a sequence of length 2, it should be (height, width). use_vertical_flip (bool): Flip the image vertically instead of horizontally if set to True. @@ -895,14 +895,14 @@ def pad(img, padding, fill_value, padding_mode): Args: img (PIL Image): Image to be padded. - padding (int or sequence, optional): The number of pixels to pad the image. + padding (Union[int, sequence], optional): The number of pixels to pad the image. If a single number is provided, it pads all borders with this value. If a tuple or list of 2 values are provided, it pads the (left and top) with the first value and (right and bottom) with the second value. If 4 values are provided as a list or tuple, it pads the left, top, right and bottom respectively. Default is None. - fill_value (int or tuple): The pixel intensity of the borders if + fill_value (Union[int, tuple]): The pixel intensity of the borders if the padding_mode is "constant". If it is a 3-tuple, it is used to fill R, G, B channels respectively. padding_mode (str): The method of padding. Can be any of @@ -1137,12 +1137,12 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0 Args: img (PIL Image): Image to be applied affine transformation. - angle (int or float): Rotation angle in degrees, clockwise. + angle (Union[int, float]): Rotation angle in degrees, clockwise. translations (sequence): Translations in horizontal and vertical axis. scale (float): Scale parameter, a single number. - shear (float or sequence): Shear amount parallel to x and y axis. - resample (Inter.NEAREST, or Inter.BILINEAR, Inter.BICUBIC, optional): An optional resampling filter. - fill_value (tuple or int, optional): Optional fill_value to fill the area outside the transform + shear (Union[float, sequence]): Shear amount parallel to x and y axis. + resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter. + fill_value (Union[tuple int], optional): Optional fill_value to fill the area outside the transform in the output image. Used only in Pillow versions > 5.0.0. If None, no filling is performed. @@ -1457,13 +1457,15 @@ def random_sharpness(img, degrees): return ImageEnhance.Sharpness(img).enhance(v) -def auto_contrast(img): +def auto_contrast(img, cutoff, ignore): """ Automatically maximize the contrast of the input PIL image. Args: img (PIL Image): Image to be augmented with AutoContrast. + cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). + ignore (Union[int, sequence], optional): Pixel values to ignore (default=None). Returns: img (PIL Image), Augmented image. @@ -1473,7 +1475,7 @@ def auto_contrast(img): if not is_pil(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - return ImageOps.autocontrast(img) + return ImageOps.autocontrast(img, cutoff, ignore) def invert_color(img): diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index 4cb6613359..f140673f31 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp from .utils import Inter, Border from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ - check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list + check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, check_tensor_op def check_crop_size(size): @@ -78,6 +78,8 @@ def check_fill_value(fill_value): def check_padding(padding): """Parsing the padding arguments and check if it is legal.""" type_check(padding, (tuple, list, numbers.Number), "padding") + if isinstance(padding, numbers.Number): + check_value(padding, (0, INT32_MAX), "padding") if isinstance(padding, (tuple, list)): if len(padding) not in (2, 4): raise ValueError("The size of the padding list or tuple should be 2 or 4.") @@ -92,7 +94,11 @@ def check_degrees(degrees): if isinstance(degrees, numbers.Number): check_value(degrees, (0, float("inf")), "degrees") elif isinstance(degrees, (list, tuple)): - if len(degrees) != 2: + if len(degrees) == 2: + type_check_list(degrees, (numbers.Number,), "degrees") + if degrees[0] > degrees[1]: + raise ValueError("degrees should be in (min,max) format. Got (max,min).") + else: raise TypeError("If degrees is a sequence, the length must be 2.") @@ -104,6 +110,8 @@ def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT raise ValueError("The input value of {} cannot be negative.".format(input_name)) elif isinstance(value, (list, tuple)) and len(value) == 2: check_range(value, bound) + if value[0] > value[1]: + raise ValueError("value should be in (min,max) format. Got (max,min).") def check_erasing_value(value): @@ -163,10 +171,17 @@ def check_random_resize_crop(method): check_crop_size(size) if scale is not None: + type_check(scale, (tuple,), "scale") + type_check_list(scale, (float, int), "scale") check_range(scale, [0, FLOAT_MAX_INTEGER]) + if scale[0] > scale[1]: + raise ValueError("scale should be in (min,max) format. Got (max,min).") if ratio is not None: + type_check(ratio, (tuple,), "ratio") + type_check_list(ratio, (float, int), "ratio") check_range(ratio, [0, FLOAT_MAX_INTEGER]) - check_positive(ratio[0], "ratio[0]") + if ratio[0] > ratio[1]: + raise ValueError("ratio should be in (min,max) format. Got (max,min).") if interpolation is not None: type_check(interpolation, (Inter,), "interpolation") if max_attempts is not None: @@ -450,8 +465,7 @@ def check_random_affine(method): if translate is not None: if type_check(translate, (list, tuple), "translate"): - translate_names = ["translate_{0}".format(i) for i in range(len(translate))] - type_check_list(translate, (int, float), translate_names) + type_check_list(translate, (int, float), "translate") if len(translate) != 2: raise TypeError("translate should be a list or tuple of length 2.") for i, t in enumerate(translate): @@ -473,7 +487,7 @@ def check_random_affine(method): if len(shear) not in (2, 4): raise TypeError("shear must be of length 2 or 4.") - type_check(resample, (Inter,), "resample") + type_check(resample, (Inter,), "resample") if fill_value is not None: check_fill_value(fill_value) @@ -502,14 +516,13 @@ def check_uniform_augment_cpp(method): @wraps(method) def new_method(self, *args, **kwargs): - [operations, num_ops], _ = parse_user_args(method, *args, **kwargs) + [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs) type_check(num_ops, (int,), "num_ops") check_positive(num_ops, "num_ops") - if num_ops > len(operations): - raise ValueError("num_ops is greater than operations list size") - tensor_ops = ["tensor_op_{0}".format(i) for i in range(len(operations))] - type_check_list(operations, (TensorOp,), tensor_ops) + if num_ops > len(transforms): + raise ValueError("num_ops is greater than transforms list size") + type_check_list(transforms, (TensorOp,), "tensor_ops") return method(self, *args, **kwargs) @@ -530,6 +543,27 @@ def check_bounding_box_augment_cpp(method): return new_method +def check_auto_contrast(method): + """Wrapper method to check the parameters of AutoContrast ops (python and cpp).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs) + type_check(cutoff, (int, float), "cutoff") + check_value(cutoff, [0, 100], "cutoff") + if ignore is not None: + type_check(ignore, (list, tuple, int), "ignore") + if isinstance(ignore, int): + check_value(ignore, [0, 255], "ignore") + if isinstance(ignore, (list, tuple)): + for item in ignore: + type_check(item, (int,), "item") + check_value(item, [0, 255], "ignore") + return method(self, *args, **kwargs) + + return new_method + + def check_uniform_augment_py(method): """Wrapper method to check the parameters of python UniformAugment op.""" @@ -588,3 +622,26 @@ def check_compose_list(method): return method(self, *args, **kwargs) return new_method + + +def check_random_select_subpolicy_op(method): + """Wrapper method to check the parameters of RandomSelectSubpolicyOp.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [policy], _ = parse_user_args(method, *args, **kwargs) + type_check(policy, (list,), "policy") + if not policy: + raise ValueError("policy can not be empty.") + for sub_ind, sub in enumerate(policy): + type_check(sub, (list,), "policy[{0}]".format([sub_ind])) + if not sub: + raise ValueError("policy[{0}] can not be empty.".format(sub_ind)) + for op_ind, tp in enumerate(sub): + check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind)) + check_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind)) + check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind)) + + return method(self, *args, **kwargs) + + return new_method diff --git a/mindspore/hub.py b/mindspore/hub.py new file mode 100644 index 0000000000..52a1e7754f --- /dev/null +++ b/mindspore/hub.py @@ -0,0 +1,214 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +hub for loading models: +Users can load pre-trained models using mindspore.hub.load() API. +""" +import os +import re +import shutil +import tarfile +import hashlib +from urllib.request import urlretrieve +import requests +from bs4 import BeautifulSoup + +import mindspore +import mindspore.nn as nn +from mindspore import log as logger +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo" +OFFICIAL_NAME = "official" +DEFAULT_CACHE_DIR = '.cache' +MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo'] +MODEL_TARGET_NLP = ['bert', 'mass', 'transformer'] + + +def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR): + """ + Packing the input filename to filename.tar.gz in source dir. + """ + try: + with tarfile.open(output_filename, "w:gz") as tar: + tar.add(savepath, arcname=os.path.basename(savepath)) + except Exception as e: + raise OSError("Cannot tar file {} for - {}".format(output_filename, e)) + + +def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR): + """ + Unpacking the input filename to dirs. + """ + try: + t = tarfile.open(input_filename) + t.extractall(path=savepath) + except Exception as e: + raise OSError("Cannot untar file {} for - {}".format(input_filename, e)) + + +def _remove_path_if_exists(path): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + + +def _create_path_if_not_exists(path): + if not os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + os.mkdir(path) + + +def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR): + """ + get checkpoint weight from giving url. + + Args: + url(string): checkpoint tar.gz url path. + hash_md5(string): checkpoint file md5. + savepath(string): checkpoint download save path. + + Returns: + string. + """ + + def reporthook(a, b, c): + percent = a * b * 100.0 / c + show_str = ('[%%-%ds]' % 70) % (int(percent * 80) * '#') + print("\rDownloading:", show_str, " %5.1f%%" % (percent), end="") + + def md5sum(file_name, hash_md5): + fp = open(file_name, 'rb') + content = fp.read() + fp.close() + m = hashlib.md5() + m.update(content.encode('utf-8')) + download_md5 = m.hexdigest() + return download_md5 == hash_md5 + + _remove_path_if_exists(os.path.realpath(savepath)) + _create_path_if_not_exists(os.path.realpath(savepath)) + ckpt_name = os.path.basename(url.split("/")[-1]) + # identify file exist or not + file_path = os.path.join(savepath, ckpt_name) + if os.path.isfile(file_path): + if hash_md5 and md5sum(file_path, hash_md5): + print('File already exists!') + return file_path + + file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path + _remove_path_if_exists(file_path_) + + # download the checkpoint file + print('Downloading data from url {}'.format(url)) + try: + urlretrieve(url, file_path, reporthook=reporthook) + except HTTPError as e: + raise Exception(e.code, e.msg, url) + except URLError as e: + raise Exception(e.errno, e.reason, url) + print('\nDownload finished!') + + # untar file_path + _unpacking_targz(file_path, os.path.realpath(savepath)) + + filesize = os.path.getsize(file_path) + # turn the file size to Mb format + print('File size = %.2f Mb' % (filesize / 1024 / 1024)) + return file_path_ + + +def _get_url_paths(url, ext='.tar.gz'): + response = requests.get(url) + if response.ok: + response_text = response.text + else: + return response.raise_for_status() + soup = BeautifulSoup(response_text, 'html.parser') + parent = [url + node.get('href') for node in soup.find_all('a') + if node.get('href').endswith(ext)] + return parent + + +def _get_file_from_url(base_url, base_name): + idx = 0 + urls = _get_url_paths(base_url + "/") + files = [url.split('/')[-1] for url in urls] + for i, name in enumerate(files): + if re.match(base_name + '*', name) is not None: + idx = i + break + return urls[idx] + + +def load_weights(network, network_name=None, force_reload=True, **kwargs): + r""" + Load a model from mindspore, with pretrained weights. + + Args: + network (Cell): Cell network. + network_name (string, optional): Cell network name get from network. Default: None. + force_reload (bool, optional): Whether to force a fresh download unconditionally. Default: False. + kwargs (dict, optional): The corresponding kwargs for download for model. + + - device_target (str, optional): Runtime device target. Default: 'ascend'. + - dataset (str, optional): Dataset to train the network. Default: 'cifar10'. + - version (str, optional): MindSpore version to save the checkpoint. Default: Latest version. + + Example: + >>> hub.load(network, network_name='lenet', + **{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'}) + """ + if not isinstance(network, nn.Cell): + logger.error("Failed to combine the net and the parameters.") + msg = ("Argument net should be a Cell, but got {}.".format(type(network))) + raise TypeError(msg) + + if network_name is None: + if hasattr(network, network_name): + network_name = network.network_name + else: + msg = "Should input network name, but got None." + raise TypeError(msg) + + device_target = kwargs['device_target'] if kwargs['device_target'] else 'ascend' + dataset = kwargs['dataset'] if kwargs['dataset'] else 'imagenet' + version = kwargs['version'] if kwargs['version'] else mindspore.version.__version__ + + if network_name.split("_")[0] in MODEL_TARGET_CV: + model_type = "cv" + elif network_name.split("_")[0] in MODEL_TARGET_NLP: + model_type = "nlp" + else: + raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0])) + + download_base_url = "/".join([DOWNLOAD_BASIC_URL, + OFFICIAL_NAME, model_type, network_name]) + download_file_name = "_".join( + [network_name, device_target, version, dataset, OFFICIAL_NAME]) + download_url = _get_file_from_url(download_base_url, download_file_name) + + if force_reload: + ckpt_path = _get_weights_file(download_url, None, DEFAULT_CACHE_DIR) + else: + raise ValueError("Unsupported not force reload.") + + ckpt_file = os.path.join(ckpt_path, network_name + ".ckpt") + param_dict = load_checkpoint(ckpt_file) + load_param_into_net(network, param_dict) diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt new file mode 100644 index 0000000000..76808d7506 --- /dev/null +++ b/mindspore/lite/CMakeLists.txt @@ -0,0 +1,158 @@ +cmake_minimum_required(VERSION 3.14) +project (Lite) + +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) + message(FATAL_ERROR "GCC vesion ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") +endif () + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") +set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../..) +set(CORE_DIR ${TOP_DIR}/mindspore/core) +set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc) +include_directories(${TOP_DIR}) +include_directories(${CORE_DIR}) +include_directories(${CCSRC_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${TOP_DIR}/third_party) +include_directories(${TOP_DIR}/third_party/flatbuffers/include) + +include(${TOP_DIR}/cmake/utils.cmake) +include(${TOP_DIR}/cmake/dependency_utils.cmake) +include(${TOP_DIR}/cmake/dependency_securec.cmake) + +option(CMAKE_BUILD_TYPE "build type" Release) +option(BUILD_DEVICE "if build device" on) +option(SUPPORT_TRAIN "if build for on-device train" off) +option(PLATFORM_ARM64 "if build device for arm64" off) +option(PLATFORM_ARM32 "if build device for arm32" off) +option(BUILD_CONVERTER "if build converter" on) +option(ENABLE_FP16 "if build fp16 ops" off) +option(SUPPORT_GPU "if support gpu" off) +option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off) +option(BUILD_MINDDATA "" off) + +set(CMAKE_VERBOSE_MAKEFILE on) +add_compile_definitions(USE_ANDROID_LOG) +add_compile_definitions(NO_DLIB) +add_compile_options(-fPIC) +if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") +endif() + +if (BUILD_DEVICE) + add_compile_definitions(BUILD_DEVICE) +endif() +if (SUPPORT_TRAIN) + add_compile_definitions(SUPPORT_TRAIN) +endif() +if (ENABLE_NEON) + add_compile_definitions(ENABLE_NEON) +endif () +if (ENABLE_FP16) + add_compile_definitions(ENABLE_FP16) +endif () +if (SUPPORT_GPU) + add_definitions(-DUSE_OPENCL_WRAPPER) + add_definitions(-DMS_OPENCL_PROFILE=false) + add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200) + add_compile_definitions(SUPPORT_GPU) + if(OFFLINE_COMPILE) + add_compile_definitions(PROGRAM_WITH_IL) + endif() + include_directories(${TOP_DIR}/third_party/OpenCL-Headers) + include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include) +endif() + +set(ANF_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../core/ir/meta_tensor.cc + ${CCSRC_DIR}/gvar/logging_level.cc + ${CCSRC_DIR}/gvar/typeid_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../core/base/base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../core/utils/log_adapter.cc + ) +if (BUILD_CONVERTER) + if (PLATFORM_ARM64 OR PLATFORM_ARM32) + MESSAGE(FATAL_ERROR "Cannot build converter in arm platform") + endif() + find_package(Python3 3.7 COMPONENTS Interpreter Development) + if(Python3_FOUND) + set(PYTHON_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}") + set(PYTHON_LIBRARIES "${Python3_LIBRARIES}") + if (WIN32) + if (Python3_DIR) + message("Python3_DIR set already: " ${Python3_DIR}) + else() + string(LENGTH ${PYTHON_LIBRARIES} PYTHON_LIBRARIES_LEN) + string(LENGTH "libpythonxx.a" Python3_NAME_LEN) + math(EXPR Python3_DIR_LEN ${PYTHON_LIBRARIES_LEN}-${Python3_NAME_LEN}) + string(SUBSTRING ${Python3_LIBRARIES} 0 ${Python3_DIR_LEN} Python3_DIR) + message("Python3_DIR: " ${Python3_DIR}) + endif() + link_directories(${Python3_DIR}) + endif() + else() + find_python_package(py_inc py_lib) + set(PYTHON_INCLUDE_DIRS "${py_inc}") + set(PYTHON_LIBRARIES "${py_lib}") + endif() + include_directories(${PYTHON_INCLUDE_DIRS}) +# include(${TOP_DIR}/cmake/utils.cmake) +# include(${TOP_DIR}/cmake/dependency_utils.cmake) + include(${TOP_DIR}/cmake/external_libs/json.cmake) +# include(${TOP_DIR}/cmake/dependency_securec.cmake) + include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) + include(${TOP_DIR}/cmake/external_libs/eigen.cmake) + include_directories(${TOP_DIR}/third_party/protobuf/build/include) + link_directories(${TOP_DIR}/third_party/protobuf/build/lib) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) + add_subdirectory(src/common/anf_exporter) +endif() + +if (BUILD_DEVICE) + if (PLATFORM_ARM32 OR PLATFORM_ARM64) + if (NOT DEFINED ENV{ANDROID_NDK}) + message(FATAL_ERROR "env ANDROID_NDK should be setted for ARM compile") + endif() + add_compile_definitions(ENABLE_ARM) + endif() + if (PLATFORM_ARM32) + add_definitions(-mfloat-abi=softfp -mfpu=neon) + add_compile_definitions(ENABLE_ARM32) + endif() + if (PLATFORM_ARM64) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + add_compile_definitions(ENABLE_ARM64) + if (ENABLE_FP16) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + endif () + endif() +endif() + +if (BUILD_MINDDATA) + # opencv + set(OpenCV_DIR ${TOP_DIR}/third_party/opencv/build) + find_package(OpenCV REQUIRED) + include_directories(${OpenCV_INCLUDE_DIRS}) + # eigen + include_directories(${TOP_DIR}/third_party/eigen/) + # jpeg-turbo + add_library(jpeg-turbo SHARED IMPORTED) + set_target_properties(jpeg-turbo PROPERTIES + IMPORTED_LOCATION ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so + ) + add_library(jpeg SHARED IMPORTED) + set_target_properties(jpeg PROPERTIES + IMPORTED_LOCATION ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so + ) + include_directories(${TOP_DIR}/third_party/libjpeg-turbo/include) + + add_compile_definitions(ENABLE_ANDROID) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/minddata) +endif() + +if (BUILD_DEVICE) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/test) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/time_profile) +endif() diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h new file mode 100644 index 0000000000..02b6cd04e2 --- /dev/null +++ b/mindspore/lite/include/context.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_CONTEXT_H_ +#define MINDSPORE_LITE_INCLUDE_CONTEXT_H_ + +#include +#include +#include "include/ms_tensor.h" + +namespace mindspore::lite { +/// \brief Allocator defined by MindSpore Lite. +/// +/// \note List public class and interface for reference. +class Allocator; + +/// \brief CpuBindMode defined by MindSpore Lite. +enum CpuBindMode { + MID_CPU = -1, /**< bind mid cpu first */ + HIGHER_CPU = 1, /**< bind higher cpu first */ + NO_BIND = 0 /**< no bind */ +}; + +/// \brief DeviceType defined by MindSpore Lite. +typedef enum { + DT_CPU, /**< CPU device type */ + DT_GPU, /**< GPU device type */ + DT_NPU /**< NPU device type */ +} DeviceType; + +/// \brief DeviceContext defined by MindSpore Lite. +typedef struct { + DeviceType type; /**< device type */ +} DeviceContext; + +/// \brief Context defined by MindSpore Lite +class MS_API Context { + public: + /// \brief Constructor of MindSpore Lite context using default value for parameters. + /// + /// \return Instance of MindSpore Lite Context. + Context(); + + /// \brief Constructor of MindSpore Lite Context using input value for parameters. + /// + /// \param[in] thread_num Define the threadNum during the runtime. + /// \param[in] allocator Define the allocator for malloc. + /// \param[in] device_ctx Define device information during the runtime. + Context(int thread_num, std::shared_ptr allocator, DeviceContext device_ctx); + + /// \brief Destructor of MindSpore Lite Context. + virtual ~Context(); + + public: + DeviceContext device_ctx_{DT_CPU}; + int thread_num_ = 2; /**< thread number config for thread pool */ + std::shared_ptr allocator = nullptr; + CpuBindMode cpu_bind_mode_ = MID_CPU; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_ diff --git a/mindspore/lite/include/errorcode.h b/mindspore/lite/include/errorcode.h new file mode 100644 index 0000000000..2cdd4659de --- /dev/null +++ b/mindspore/lite/include/errorcode.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_ERRORCODE_H_ +#define MINDSPORE_LITE_INCLUDE_ERRORCODE_H_ + +namespace mindspore { +namespace lite { +using STATUS = int; + +/* Success */ +constexpr int RET_OK = 0; /**< No error occurs. */ + +/* Common error code, range: [-1, -100]*/ +constexpr int RET_ERROR = -1; /**< Common error code. */ +constexpr int RET_NULL_PTR = -2; /**< NULL pointer returned.*/ +constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/ +constexpr int RET_NO_CHANGE = -4; /**< No change. */ +constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */ +constexpr int RET_MEMORY_FAILED = -6; /**< Create memory failed. */ + +/* Executor error code, range: [-101,-200] */ +constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to checking range. */ +constexpr int RET_INPUT_TENSOR_ERROR = -102; /**< Failed to checking input tensor. */ +constexpr int RET_REENTRANT_ERROR = -103; /**< Exist executor running. */ + +/* Graph error code, range: [-201,-300] */ +constexpr int RET_GRAPH_FILE_ERR = -201; /**< Failed to verify graph file. */ + +/* Node error code, range: [-301,-400] */ +constexpr int RET_NOT_FIND_OP = -301; /**< Failed to find operator. */ +constexpr int RET_INVALID_OP_NAME = -302; /**< Invalid operator name. */ +constexpr int RET_INVALID_OP_ATTR = -303; /**< Invalid operator attr. */ +constexpr int RET_OP_EXECUTE_FAILURE = -304; /**< Failed to execution operator. */ + +/* Tensor error code, range: [-401,-500] */ +constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */ +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_INCLUDE_ERRORCODE_H_ + diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h new file mode 100644 index 0000000000..ea762f0f60 --- /dev/null +++ b/mindspore/lite/include/lite_session.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H +#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H + +#include +#include +#include +#include "include/ms_tensor.h" +#include "include/model.h" +#include "include/context.h" + +namespace mindspore { +namespace session { +struct CallBackParam { + std::string name_callback_param; + std::string type_callback_param; +}; + +using KernelCallBack = std::function inputs, + std::vector outputs, const CallBackParam &opInfo)>; + +/// \brief LiteSession defined by MindSpore Lite. +class MS_API LiteSession { + public: + /// \brief Static method to create a LiteSession pointer. + /// + /// \param[in] context Define the context of session to be created. + /// + /// \return Pointer of MindSpore Lite LiteSession. + static LiteSession *CreateSession(lite::Context *context); + + /// \brief Destructor of MindSpore Lite LiteSession. + virtual ~LiteSession() = default; + + /// \brief Try to bind or unbind threads in the thread pool to specified cpu core. + /// + /// \param[in] if_bind Define weather to bind or unbind threads. + virtual void BindThread(bool if_bind) = 0; + + /// \brief Compile MindSpore lite model. + /// + /// \note CompileGraph should called before RunGraph. + /// + /// \param[in] model Define the model to be compiled. + /// + /// \return ErrorCode of compile graph. + virtual int CompileGraph(lite::Model *model) = 0; + + /// \brief Get input MindSpore Lite MSTensors of model. + /// + /// \return A vector of MindSpore Lite MSTensor. + virtual std::vector GetInputs() const = 0; + + /// \brief Get input MindSpore Lite MSTensors of model by node name. + /// + /// \param[in] node_name Define node name. + /// + /// \return A vector of MindSpore Lite MSTensor. + virtual std::vector GetInputsByName(const std::string &node_name) const = 0; + + /// \brief Run session with callback. + /// + /// \param[in] before Define a call_back_function called before running each node + /// \param[in] after Define a call_back_function called after running each node + /// + /// \note RunGraph should called after CompileGraph. + /// + /// \return ErrorCode of run graph. + virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0; + + /// \brief Get output MindSpore Lite MSTensors of model. + /// + /// \return A vector of MindSpore Lite MSTensor. + virtual std::vector GetOutputs() const = 0; + + /// \brief Get output MindSpore Lite MSTensors of model by node name. + /// + /// \param[in] node_name Define node name. + /// + /// \return A vector of MindSpore Lite MSTensor. + virtual std::vector GetOutputsByName(const std::string &node_name) const = 0; +}; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h new file mode 100644 index 0000000000..5000d8a1d5 --- /dev/null +++ b/mindspore/lite/include/model.h @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_MODEL_H +#define MINDSPORE_LITE_INCLUDE_MODEL_H + +#include +#include +#include +#include "schema/model_generated.h" + +namespace mindspore { +#define MS_API __attribute__((visibility("default"))) + +/// \brief ModelImpl defined by MindSpore Lite. +/// +/// \note List public class and interface for reference. +class ModelImpl; + +namespace lite { +/// \brief Primitive defined by MindSpore Lite. +/// +/// \note List public class and interface for reference. +class Primitive; + +/// \brief Model defined by MindSpore Lite. +class MS_API Model { + public: + /// \brief Static method to create a Model pointer. + /// + /// \param[in] model_buf Define the buffer read from a model file. + /// \param[in] size Define bytes numbers of model buffer. + /// + /// \return Pointer of MindSpore Lite Model. + static std::shared_ptr Import(const char *model_buf, size_t size); + + /// \brief Constructor of MindSpore Lite Model using default value for parameters. + /// + /// \return Instance of MindSpore Lite Model. + Model() = default; + + /// \brief Destructor of MindSpore Lite Model. + virtual ~Model() = default; + + /// \brief Get MindSpore Lite Primitive by name. + /// + /// \param[in] name Define name of primitive to be returned. + /// + /// \return A pointer of MindSpore Lite Primitive. + lite::Primitive *GetOp(const std::string &name) const; + + /// \brief Get MindSpore Lite MetaGraph. + /// + /// \return A pointer of MindSpore Lite MetaGraph. + const schema::MetaGraph *GetMetaGraph() const; + + /// \brief Get MindSpore Lite ModelImpl. + /// + /// \return A pointer of MindSpore Lite ModelImpl. + std::shared_ptr model_impl(); + + /// \brief Free MetaGraph in MindSpore Lite Model. + void FreeMetaGraph(); + + protected: + std::shared_ptr model_impl_ = nullptr; +}; + +/// \brief ModelBuilder defined by MindSpore Lite. +class MS_API ModelBuilder { + public: + /// \brief OutEdge defined by MindSpore Lite. + struct OutEdge { + std::string nodeId; /**< Id of a node linked by this edge */ + size_t outEdgeIndex; /**< Index of this edge */ + }; + + /// \brief Constructor of MindSpore Lite Model using default value for parameters. + /// + /// \return Instance of MindSpore Lite ModelBuilder. + ModelBuilder() = default; + + /// \brief Destructor of MindSpore Lite ModelBuilder. + virtual ~ModelBuilder() = default; + + /// \brief Add primitive into model builder for model building. + /// + /// \param[in] op Define the primitive to be added. + /// \param[in] inputs Define input edge of primitive to be added. + /// + /// \return Id of the primitive added. + virtual std::string AddOp(const lite::Primitive &op, const std::vector &inputs) = 0; + + /// \brief Finish constructing the model. + /// + /// \return A pointer of MindSpore Lite Model. + virtual Model *Construct(); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_INCLUDE_MODEL_H diff --git a/mindspore/lite/include/ms_tensor.h b/mindspore/lite/include/ms_tensor.h new file mode 100644 index 0000000000..911d4faa1a --- /dev/null +++ b/mindspore/lite/include/ms_tensor.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_ +#define MINDSPORE_INCLUDE_MS_TENSOR_H_ + +#include +#include +#include +#include "ir/dtype/type_id.h" + +namespace mindspore { +#define MS_API __attribute__((visibility("default"))) +namespace tensor { +/// \brief MSTensor defined by MindSpore Lite. +class MS_API MSTensor { + public: + /// \brief Constructor of MindSpore Lite MSTensor. + /// + /// \return Instance of MindSpore Lite MSTensor. + MSTensor() = default; + + /// \brief Static method to create a MSTensor pointer. + /// + /// \param[in] data_type Define data type of tensor to be created. + /// \param[in] shape Define Shape of tensor to be created. + /// + /// \note TypeId is defined in mindspore/mindspore/core/ir/dtype/type_id.h. Only number types in TypeId enum is + /// suitable for MSTensor. + /// + /// \return A pointer of MSTensor. + static MSTensor *CreateTensor(TypeId data_type, const std::vector &shape); + + /// \brief Destructor of MindSpore Lite Model. + virtual ~MSTensor() = default; + + /// \brief Get data type of the MindSpore Lite MSTensor. + /// + /// \note TypeId is defined in mindspore/mindspore/core/ir/dtype/type_id.h. Only number types in TypeId enum is + /// suitable for MSTensor. + /// + /// \return MindSpore Lite TypeId of the MindSpore Lite MSTensor. + virtual TypeId data_type() const = 0; + + /// \brief Set data type for the MindSpore Lite MSTensor. + /// + /// \param[in] data_type Define MindSpore Lite TypeId to be set into the MindSpore Lite MSTensor. + /// + /// \return MindSpore Lite TypeId of the MindSpore Lite MSTensor after set. + virtual TypeId set_data_type(TypeId data_type) = 0; + + /// \brief Get shape of the MindSpore Lite MSTensor. + /// + /// \return A vector of int as the shape of the MindSpore Lite MSTensor. + virtual std::vector shape() const = 0; + + /// \brief Set shape for the MindSpore Lite MSTensor. + /// + /// \param[in] shape Define A vector of int as shape to be set into the MindSpore Lite MSTensor. + /// + /// \return size of shape of the MindSpore Lite MSTensor after set. + virtual size_t set_shape(const std::vector &shape) = 0; + + /// \brief Get size of the dimension of the MindSpore Lite MSTensor index by the parameter index. + /// + /// \param[in] index Define index of dimension returned. + /// + /// \return Size of dimension of the MindSpore Lite MSTensor. + virtual int DimensionSize(size_t index) const = 0; + + /// \brief Get number of element in MSTensor. + /// + /// \return Number of element in MSTensor. + virtual int ElementsNum() const = 0; + + /// \brief Get hash of the MindSpore Lite MSTensor. + /// + /// \return Hash of the MindSpore Lite MSTensor. + virtual std::size_t hash() const = 0; + + /// \brief Get byte size of data in MSTensor. + /// + /// \return Byte size of data in MSTensor. + virtual size_t Size() const = 0; + + /// \brief Get pointer of data in MSTensor. + /// + /// \note The data pointer can be used to both write or read data in MSTensor. + /// + /// \return A pointer points to data in MSTensor. + virtual void *MutableData() const = 0; +}; + +using MultiTensor = std::vector>>; +} // namespace tensor +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_ diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt new file mode 100644 index 0000000000..fcf6d721a9 --- /dev/null +++ b/mindspore/lite/minddata/CMakeLists.txt @@ -0,0 +1,47 @@ +set(MINDDATA_DIR ${CCSRC_DIR}/minddata/dataset) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -Wno-deprecated-declarations") +set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb") +if (CMAKE_BUILD_TYPE EQUAL "DEBUG") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s") +endif() + +AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/core MINDDATA_CORE_SRC_FILES) +list(REMOVE_ITEM MINDDATA_CORE_SRC_FILES "${MINDDATA_DIR}/core/client.cc") + +AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels MINDDATA_KERNELS_SRC_FILES) +list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") + +AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/image MINDDATA_KERNELS_IMAGE_SRC_FILES) + +AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/data MINDDATA_KERNELS_DATA_SRC_FILES) + +add_library(minddata-eager OBJECT + ${MINDDATA_DIR}/api/de_tensor.cc + ${MINDDATA_DIR}/api/execute.cc + ) + +add_library(minddata-lite SHARED + ${MINDDATA_CORE_SRC_FILES} + ${MINDDATA_KERNELS_SRC_FILES} + ${MINDDATA_KERNELS_IMAGE_SRC_FILES} + ${MINDDATA_KERNELS_DATA_SRC_FILES} + ${MINDDATA_DIR}/util/status.cc + ${MINDDATA_DIR}/util/memory_pool.cc + ${MINDDATA_DIR}/util/path.cc + ${MINDDATA_DIR}/api/transforms.cc + ${CORE_DIR}/utils/log_adapter.cc + ${CCSRC_DIR}/gvar/logging_level.cc + ) + +target_link_libraries(minddata-lite + securec + jpeg-turbo + jpeg + opencv_core + opencv_imgcodecs + opencv_imgproc + mindspore::json + ) \ No newline at end of file diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs new file mode 100644 index 0000000000..cb1bad7031 --- /dev/null +++ b/mindspore/lite/schema/model.fbs @@ -0,0 +1,225 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +include "ops.fbs"; + +namespace mindspore.schema; + +enum NodeType: int { + ValueNode, // const + Parameter, // var + CNode // op +} + +table QuantParam { + scale: double; + zeroPoint: int; + min: double = 0; + max: double = 0; + narrowRange: bool = true; + numBits: int = 8; + inited: bool = false; +} + +table Tensor { + nodeType: NodeType; + // data type + dataType: int; + // shape + dims: [int]; + format: Format; + refCount: int; + offset: int; + data: [ubyte]; + quantParams: [QuantParam]; +} + +union PrimitiveType { + Concat, + SoftMax, + Activation, + Conv2D, + FusedBatchNorm, + BatchNorm, + BiasAdd, + Pooling, + DepthwiseConv2D, + DeDepthwiseConv2D, + Resize, + DetectionPostProcess, + FullConnection, + Mean, + DeConv2D, + Scale, + Reshape, + Eltwise, + NetOutput, + Add, + Sub, + MatMul, + StridedSlice, + Power, + Slice, + Stack, + Mul, + RealDiv, + Pad, + Maximum, + Minimum, + CaffePReLU, + LeakyReLU, + ArgMax, + ArgMin, + Exp, + Crop, + Range, + Rsqrt, + ExpandDims, + Tile, + Cast, + Shape, + Nchw2Nhwc, + Nhwc2Nchw, + QuantDTypeCast, + Split, + Permute, + FakeQuantWithMinMaxVars, + Equal, + Less, + Greater, + NotEqual, + LessEqual, + GreaterEqual, + Min, + Floor, + Abs, + Neg, + Cos, + Sin, + Sqrt, + Square, + Constant, + Log, + Tan, + Atan, + Asin, + Clip, + Transpose, + Squeeze, + Unsqueeze, + Upsample, + Dropout, + Broadcast, + BroadcastTo, + Lrn, + Prelu, + ZerosLike, + TopK, + SpaceToDepth, + SpaceToBatch, + SparseToDense, + ReverseSequence, + Rank, + Gather, + GatherNd, + Fill, + Elu, + DepthToSpace, + BatchToSpace, + AddN, + Ceil, + EmbeddingLookup, + EmbeddingLookupSparse, + FloorDiv, + FloorMod, + L2Norm, + LocalResponseNormalization, + MatrixDiag, + Reduce, + Reverse, + Round, + Select, + Scatter, + ScatterND, + Unique, + Unstack, + LogicalAnd, + LogicalOr, + LogicalXor, + LogicalNot, + OnnxInt8Quantize, + OnnxInt8Dequantize, + FakeQuantWithMinMax, + FakeQuantWithMinMaxPerChannel, + BatchNormFold, + MulFold, + AddFold, + SquaredDifference, + Flatten, + TupleGetItem, + Div, + Where, + OneHot, + Lstm, + Conv2DGradFilter, + Conv2DGradInput, + PoolingGrad, + BNGradInput, + OptMomentum, + BiasGrad, + SoftmaxCrossEntropy, + AddGrad, + SubGrad, + MulGrad, + DivGrad, + PowerGrad, + ActivationGrad, + PriorBox, + SpaceToBatchND, + TopKV2 +} + +enum QuantType: int { + QUANT_NONE, + AwareTrainning, + WeightQuant, + PostTraining +} + +table Primitive { + value: PrimitiveType; +} + +table CNode { + name: string; + nodeType: NodeType = CNode; + primitive: Primitive; + inputIndex: [uint]; + outputIndex: [uint]; + quantType: QuantType = QUANT_NONE; +} + +table MetaGraph { + name: string; + fmkType: int; // 0:tf,1:caffe + inputIndex: [uint]; + outputIndex: [uint]; + mempoolSize: uint; + nodes: [CNode]; + allTensors: [Tensor]; // weight + input + output +} + +root_type MetaGraph; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs new file mode 100644 index 0000000000..fe97b0fb9e --- /dev/null +++ b/mindspore/lite/schema/ops.fbs @@ -0,0 +1,865 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace mindspore.schema; + +enum ResizeMethod: byte { + UNKNOW = -1, + BILINEAR = 0, + NEAREST_NEIGHBOR = 1 +} + +enum Format : int { + NCHW = 0, + NHWC, + NHWC4, + HWKC, + HWCK, + KCHW, + CKHW, + KHWC, + CHWK, + NC4HW4 = 100, + NUM_OF_FORMAT +} + +enum ActivationType : byte { + NO_ACTIVATION = 0, + RELU = 1, + SIGMOID = 2, + RELU6 = 3, + ELU = 4, + LEAKY_RELU = 5, + ABS = 6, + RELU1 = 7, + SOFTSIGN = 8, + SOFTPLUS = 9, + TANH = 10, + SELU = 11, + HSWISH = 12, + HSIGMOID = 13, + THRESHOLDRELU = 14, + LINEAR = 15, + UNKNOW = 16 +} +enum ActivationGradType : byte { + NO_ACTIVATION = 0, + RELU = 1, + SIGMOID = 2, + RELU6 = 3, + ELU = 4, + LEAKY_RELU = 5, + ABS = 6, + RELU1 = 7, + SOFTSIGN = 8, + SOFTPLUS = 9, + TANH = 10, + SELU = 11, + HSWISH = 12, + HSIGMOID = 13, + THRESHOLDRELU = 14, + LINEAR = 15, + UNKNOW = 16 +} +enum ReduceType : byte { + REDUCE_MAX = 0, + REDUCE_MEAN = 1, + REDUCE_ALL = 2, + REDUCE_ANY = 3, + REDUCE_LOG_SUM_EXP = 4, + REDUCE_PROD = 5, + REDUCE_SUM = 6, + UNKNOW = 7 +} + +enum PoolMode : byte { + MAX_POOLING = 0, + MEAN_POOLING = 1, +} + +enum EltwiseMode : byte { + PROD = 0, + SUM = 1, + MAXIMUM = 2, + UNKNOW = 3 +} + +enum PadMode : byte { + NOTSET = 0, + SAME = 1, + VALID = 2, + CAFFE = 4 +} + +enum RoundMode : byte { + FLOOR = 0, + CEIL = 1 +} + +enum PaddingMode : byte { + CONSTANT = 0, + REFLECT = 1, + SYMMETRIC = 2, + MODE_RESERVED = 3 +} + +table Pad { + paddings: [int]; + paddingMode: PaddingMode; + constantValue: float; +} + +table Maximum { +} + +table Minimum { +} + +table Flatten { +} + +table Concat { + axis: int; + n: int; +} + +table SoftMax { + axis: int; +} + +table Activation { + type: ActivationType = 0; +} +table ActivationGrad { + type: ActivationGradType = 0; +} + + +table Conv2D { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table Conv2DGradFilter { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table Conv2DGradInput { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +}table FusedBatchNorm { + epsilon: float = 0.00001; // eg. epsilon=0.001 + momentum: float = 0.9; + spatial: int = 1; +} + +table BatchNorm { + epsilon: float = 0.00001; // eg. epsilon=0.001 +} + +table BiasGrad { + axis: [int]; +} + + +table SoftmaxCrossEntropy { + axis: [int]; +} + + +table PoolingGrad { + format: Format = 0; + poolingMode: PoolMode; + global: bool = false; + windowW: int; + windowH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + roundMode: RoundMode; +} +table Shape { +} + +table Nchw2Nhwc { + +} + +table Nhwc2Nchw { + +} + +table FakeQuantWithMinMaxVars { + narrowRange: bool; + numBits: int; +} + +table BiasAdd { + axis: [int]; +} + +table Pooling { + format: Format = 0; + poolingMode: PoolMode; + global: bool = false; + windowW: int; + windowH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + roundMode: RoundMode; +} + +table DepthwiseConv2D { + format: Format = 0; + channelIn: int; + channelMultiplier: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table DeDepthwiseConv2D { + format: Format = 0; + channelIn: int; + channelMultiplier: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + + +table Resize { + format: Format = 0; + method: ResizeMethod; + newHeight: long; + newWidth: long; + alignCorners: bool = false; + preserveAspectRatio: bool = false; +} + +table DetectionPostProcess { + format: Format = 0; + inputSize: int; + hScale: float; + wScale: float; + xScale: float; + yScale: float; + NmsIouThreshold: float; + NmsScoreThreshold: float; + MaxDetections: long; + DetectionsPreClass: long; + MaxClassesPreDetection: long; + NumClasses: long; + UseRegularNms: bool; +} + +table FullConnection { + hasBias: bool; + axis: int; +} + +// Mean(input_tensor, axis, keep_dims) +table Mean { + axis: [int]; + keepDims: bool = false; +} + +table DeConv2D { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} +table BNGradInput { + eps : float; + channels: int; +} +table Scale { + axis: int; +} + +table Eltwise { + mode: EltwiseMode; +} + +table Add { + activationType: ActivationType = 0; +} + +table Sub { + activationType: ActivationType = 0; +} + +table Mul { + activationType: ActivationType = 0; +} + +table Div { + activationType: ActivationType = 0; +} + +table AddGrad { +} + +table SubGrad { +} + +table MulGrad { +} + +table DivGrad { +} +table RealDiv { +} + +table Rsqrt { +} + +table Equal { +} + +table Less { +} + +table Greater { +} + +table NotEqual { +} + +table LessEqual { +} + +table GreaterEqual { +} + +table Min { +} + +table Slice { + format: Format = 0; + begin: [int]; + size: [int]; +} + +table Floor { +} + +table Abs { +} + +table Neg { +} + +table Exp { +} + +table Cos { +} + +table Sin { +} + +table Sqrt { +} + +table Square { +} + +table Ceil { +} + +table Log { +} + +table Tan { +} + +table Atan { +} + +table Asin { +} + +table Reshape { + format: Format = 0; + shape: [long]; +} + +table Power { + power: float; + scale: float; + shift: float; +} +table PowerGrad { + power: float; + scale: float; + shift: float; +} +table ArgMax { + axis: int; + outMaxValue: bool; + topK: int = 1; + keepDims: bool; + axisType: int; +} + +table ArgMin { + axis: int; + outMaxValue: bool; + topK: int = 1; + keepDims: bool; + axisType: int; +} + +table NetOutput { +} + +table MatMul { + transposeA : bool = false; + transposeB : bool = false; +} + +table CaffePReLU { + channelShared : bool = false; +} + +table LeakyReLU { + negativeSlope: float; +} + +table StridedSlice { + beginMask: int; + endMask: int; + ellipsisMask: int; + newAxisMask: int; + shrinkAxisMask: int; + begin: [int]; + end: [int]; + stride: [int]; + isScale: [int]; +} + +table Stack { + axis: int; + n: int; + isScale: [int]; +} + +table Range { + dType: int; + start: int; + limit: int; + delta: int; +} + +table ExpandDims { + dim: int; +} + +table Tile { + multiples: [int]; +} + +table Cast { + srcT: int; + dstT: int; +} + +table QuantDTypeCast { + srcT: int; + dstT: int; +} + +table Split { + numberSplit: int; + sizeSplits: [int]; + splitDim: int; +} + +table Crop { + axis : long; + offsets : [long]; +} + +table Permute { + order: [long]; +} + +table Clip { + max: float; + min: float; +} + +table Constant { +} + + +table Elu { + alpha: float = 1.0; +} + +table Broadcast { +} + +table BroadcastTo { + dst_shape: [int]; +} + +table Lrn { + alpha: float = 0.0001; + beta: float = 0.75; + bias: float = 1.0; + size: int; +} + +enum ReduceMode : byte { + ReduceMean = 0, + ReduceMax = 1, + ReduceMin = 2, + ReduceProd = 3, + ReduceSum = 4, + ReduceSumSquare = 5 +} + +table Reduce { + axes: [int]; + keepDims: int; + mode: ReduceMode; +} + +table Prelu { + slope: [float]; +} + +table Transpose { + perm: [int]; + conjugate: bool = false; +} + +table Squeeze { + axis: [int]; +} + +table Unsqueeze { + axis: [int]; +} + +table Upsample { + mode: string; + scales: [float]; +} + +table Dropout { + ratio : float = 0.5; +} + +table LocalResponseNormalization { + depth_radius: int; + bias: float; + alpha: float; + beta: float; +} + +table ZerosLike { +} + +table TopK { + k : int; + sorted : bool = true; +} + +table SpaceToDepth { + blockSize : int; + format: Format = 0; +} + +table SpaceToBatch { + blockShape : [int]; + paddings : [int]; +} + +table SparseToDense { + outputShape: [int]; + sparseValue: [int]; + defaultValue: [int]; + validateIndices: bool; +} + +table ReverseSequence { + seqAxis: int; + batchAxis: int; + seqLengths: [int]; +} + +table Rank { +} + + +table Gather { + axis: int; + batchDims: int; +} + +table GatherNd { + batchDims: int; +} + +table Fill { + dims: [int]; +} + +table DepthToSpace { + blockSize: int; + format: Format = 0; +} + + +table BatchToSpace { + blockShape: [int]; + crops: [int]; +} + +table AddN { + N: int; +} + + +table EmbeddingLookup { + maxNorm: float = 0.0; +} + +table EmbeddingLookupSparse { + spIds: [int]; + spWeights: [float]; + //combiner: Combiner=0; + maxNortm: float; +} + +table FloorDiv { +} + +table FloorMod { +} + +table L2Norm { + axis: [int]; + epsilon: float; +} + +table LogicalAnd { +} + +table LogicalOr { +} + +table LogicalXor { +} + +table LogicalNot { +} + +table MatrixDiag { + k: int; + numRows: int; + numCols: int; + paddingValue: float; +} + +table Select { +} + +table TfReduce { + type: ReduceType = 7; +} + +table Reverse { + axis: [int]; +} + +table Round { +} + +table Scatter { +} + +table ScatterND { +} + +table Unique { + outType: int; +} + +table Unstack { + num: int; + axis: int; +} + +table OnnxInt8Quantize { +} + +table OnnxInt8Dequantize { +} + +table FakeQuantWithMinMax { +} + +table FakeQuantWithMinMaxPerChannel { +} + +table BatchNormFold { +} + +table MulFold { +} + +table AddFold { +} + +table SquaredDifference { +} + +table TupleGetItem { +} + +table OptMomentum { +} + + +table Where{ + condition: [bool]; +} + +table OneHot { + axis: int; +} + +table Lstm{ + bidirection: bool = false; +} + +table PriorBox { + min_sizes: [int]; + max_sizes: [int]; + aspect_ratios: [float]; + variances: [float]; + image_size_w: int; + image_size_h: int; + step_w: float; + step_h: float; + clip: bool = true; + flip: bool = true; + offset: float; +} + +table SpaceToBatchND { + blockShape : [int]; + paddings : [int]; +} + +table TopKV2 { + k : [int]; + sorted : bool = true; +} + diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt new file mode 100644 index 0000000000..8b96e731ab --- /dev/null +++ b/mindspore/lite/src/CMakeLists.txt @@ -0,0 +1,95 @@ +set(LITE_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/ms_tensor_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/workspace_pool.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_factory.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model.cc + ${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc + ) + +if (SUPPORT_GPU) + list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc) + list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc) +endif() + +if (SUPPORT_TRAIN) + set(ANF_SRC +# ${CCSRC_DIR}/common/trans.cc +# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc +# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc +# ${CCSRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc +# ${CCSRC_DIR}/session/lite/session_basic_extends.cc +# ${CCSRC_DIR}/session/anf_runtime_algorithm.cc +# ${CCSRC_DIR}/session/session_basic.cc +# ${CCSRC_DIR}/session/kernel_graph.cc +# ${CCSRC_DIR}/session/session_factory.cc +# ${CCSRC_DIR}/device/kernel_info.cc +# ${CCSRC_DIR}/device/kernel_runtime.cc +# ${CCSRC_DIR}/device/lite/kernel_runtime_extends.cc + ) + set(PASS_SRC) + set(LITE_SRC + ${LITE_SRC} + ${ANF_SRC} + ${PASS_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc + ) +else () + set(LITE_SRC + ${LITE_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc + ) +endif () + +if (SUPPORT_GPU) + set(LITE_SRC + ${LITE_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc + ) +endif () + +set(ANF_SRC + ${ANF_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc + ) + +add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC}) +target_link_libraries(mindspore-lite + cpu_kernel_mid_ + ops_mid_ + ) + +add_subdirectory(runtime/kernel/arm) +if (BUILD_MINDDATA) + target_link_libraries(mindspore-lite minddata-eager minddata-lite) + if (PLATFORM_ARM32 OR PLATFORM_ARM64) + target_link_libraries(mindspore-lite log) + endif() +endif () + +add_subdirectory(ops) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64 OR PLATFORM_ARM32)) +add_custom_command(TARGET mindspore-lite POST_BUILD + COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip + ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so) +endif() + diff --git a/mindspore/lite/src/common/anf_exporter/CMakeLists.txt b/mindspore/lite/src/common/anf_exporter/CMakeLists.txt new file mode 100644 index 0000000000..352f59947a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +add_library(anf_exporter_mid OBJECT + ${ANF_SRC_LIST} + ) + diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc new file mode 100644 index 0000000000..ec708411a2 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -0,0 +1,396 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_exporter.h" + +#include +#include +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "base/core_ops.h" +#include "mindspore/core/ir/primitive.h" +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/ir/primitive_t_value.h" +#include "src/ir/tensor.h" +#include "src/param_value_lite.h" + +namespace mindspore::lite { +std::set RemoveNodeInAnfExporter{"tuple_getitem", "make_tuple"}; + +void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { + bool hasMakeTuple = false; + std::vector inputs; + inputs.clear(); + + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto makeTupleNode = utils::cast(inputNode); + if (IsPrimitiveCNode(makeTupleNode, prim::kPrimMakeTuple)) { + hasMakeTuple = true; + for (size_t j = 1; j < makeTupleNode->inputs().size(); ++j) { + inputs.emplace_back(makeTupleNode->input(j)); + } + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (hasMakeTuple) { + cnode->set_inputs(inputs); + } +} + +bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { + bool hasTupleGetItem = false; + std::vector inputs; + inputs.clear(); + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto tupleGetItemNode = utils::cast(inputNode); + if (IsPrimitiveCNode(tupleGetItemNode, prim::kPrimTupleGetItem)) { + hasTupleGetItem = true; + inputs.emplace_back(tupleGetItemNode->input(1)); + AnfNodePtr indexNode = tupleGetItemNode->input(2); + if (!utils::isa(indexNode)) { + MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; + return false; + } + ValueNodePtr valueNode = utils::cast(indexNode); + mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = + GetValue(valueNode->value()); + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (hasTupleGetItem) { + cnode->set_inputs(inputs); + } + return true; +} + +bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr &metaGraphT, const CNodePtr &cnode) { + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto inputNode = cnode->input(i); + if (!inputNode->isa()) { + MS_LOG(ERROR) << "Node of Return's input is not CNode"; + return false; + } + auto inputCNode = utils::cast(inputNode); + auto inputPrimitive = GetValueNode(inputCNode->input(0)); + std::string inputName = inputNode->fullname_with_scope(); + auto graphOutput = nodeIdMap[inputName]; + metaGraphT->outputIndex.emplace_back(graphOutput); + } + return true; +} + +schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { + auto cnodes = funcGraph->GetOrderedCnodes(); + auto metaGraphT = std::make_unique(); + for (const auto &cnode : cnodes) { + auto primitive = GetValueNode(cnode->input(0)); + if (primitive != nullptr && + RemoveNodeInAnfExporter.count(primitive->name()) != 0) { + continue; + } + mapRemoveGetItem_.clear(); + RemoveIfMakeTuple(cnode); + RemoveIfTupleGetItem(cnode); + if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) { + AddOutPutIfReturn(metaGraphT, cnode); + continue; + } + + auto node = std::make_unique(); + node->name = cnode->fullname_with_scope(); + node->nodeType = schema::NodeType_CNode; + // populate primitive + if (primitive != nullptr) { + primitive = GetValueNode(cnode->input(0)); + MS_ASSERT(primitive != nullptr); + std::string opType = primitive->name(); + auto nodeParser = + AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + if (nodeParser == nullptr) { + MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; + return nullptr; + } + std::vector outputs; + if (utils::isa(cnode->abstract())) { + auto abstract_cnode = + utils::cast(cnode->abstract()); + outputs.resize(abstract_cnode->size()); + } + + nodeParser->Parse(cnode, node.get(), &outputs); + SetOpInputNode(cnode, metaGraphT.get(), node.get()); + SetOpOutputNode(outputs, metaGraphT.get(), node.get()); + metaGraphT->nodes.emplace_back(std::move(node)); + continue; + } + auto primitiveT_value = + GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + return nullptr; + } + + auto *lite_primitive = primitiveT_value->GetPrimitiveT(); + if (lite_primitive == nullptr) { + MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr"; + return nullptr; + } + + node->primitive = + std::unique_ptr(primitiveT_value->GetPrimitiveT()); + std::vector outputs; + SetOpInputNode(cnode, metaGraphT.get(), node.get()); + SetOpOutputNode(outputs, metaGraphT.get(), node.get()); + + // add quant param + node->quantType = primitiveT_value->GetQuantType(); + if (node->quantType == schema::QuantType_PostTraining) { + MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; + // activation + auto activate_index = node->inputIndex[0]; + auto tensor_input = metaGraphT->allTensors[activate_index].get(); + auto input_quant_params = primitiveT_value->GetInputQuantParams(); + if (input_quant_params.empty()) { + MS_LOG(WARNING) << "node: " << node->name + << " input quant params is empty"; + } else { + std::unique_ptr input_quant_param = + std::make_unique(input_quant_params[0]); + tensor_input->quantParams.emplace_back(std::move(input_quant_param)); + } + tensor_input->dataType = kNumberTypeInt8; + // output + auto output_index = node->outputIndex[0]; + auto tensor_output = metaGraphT->allTensors[output_index].get(); + auto output_quant_params = primitiveT_value->GetOutputQuantParams(); + if (output_quant_params.empty()) { + MS_LOG(WARNING) << "node: " << node->name + << " output quant params is empty"; + } else { + std::unique_ptr output_quant_param = + std::make_unique(output_quant_params[0]); + tensor_output->quantParams.emplace_back(std::move(output_quant_param)); + } + tensor_output->dataType = kNumberTypeInt8; + // // TensorType + // valuePtr = primitive->GetAttr(kInputTensorDataType); + // if (valuePtr != nullptr) { + // MS_LOG(INFO) << "node: " << node->name << " input tensor data + // type: " << GetValue(valuePtr); for (auto input : + // node->inputIndex) { + // auto tensor = subGraph->allTensors[input].get(); + // tensor->dataType = kNumberTypeUInt8; + // } + // } + } + + metaGraphT->nodes.emplace_back(std::move(node)); + } + // set graph input tensors + for (auto node : graphInputNodes) { + for (auto input : node->inputIndex) { + auto tensor = metaGraphT->allTensors[input].get(); + if (tensor->data.empty()) { + tensor->nodeType = schema::NodeType_ValueNode; + tensor->format = schema::Format_NHWC; + // tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT; + metaGraphT->inputIndex.emplace_back(input); + } + } + } + return metaGraphT.release(); +} + +void AnfExporter::SetOpInputNode(const CNodePtr &cnode, + schema::MetaGraphT *meta_graph, + schema::CNodeT *fbNode) { + MS_ASSERT(nullptr != meta_graph); + MS_ASSERT(nullptr != fbNode); + if (cnode->inputs().size() <= 1) { + return; + } + std::string cNodeName = cnode->fullname_with_scope(); + bool isGraphInput = true; + for (int i = 1; i < static_cast(cnode->inputs().size()); i++) { + auto inputNode = cnode->input(i); + if (inputNode->isa()) { + isGraphInput = false; + std::string inputName = inputNode->fullname_with_scope(); + if (!mapRemoveGetItem_.empty()) { + for (auto name : mapRemoveGetItem_) { + if (name.first == inputName) { + inputName = inputName + "_o:" + std::to_string(name.second); + } + } + } + if (nodeIdMap.find(inputName) != nodeIdMap.end()) { + fbNode->inputIndex.emplace_back(nodeIdMap[inputName]); + } + } else if (inputNode->isa()) { + auto paramNode = inputNode->cast(); + if (paramNode->name().empty()) { + paramNode->set_name(cNodeName + "_i:" + std::to_string(i - 1)); + } + if (nodeIdMap.find(paramNode->name()) != nodeIdMap.end()) { + fbNode->inputIndex.emplace_back(nodeIdMap[paramNode->name()]); + continue; + } + auto paramTensor = std::make_unique(); + auto abstractBase = paramNode->abstract(); + if (abstractBase == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " + << paramNode->name(); + MS_ASSERT(false); + return; + } + if (!utils::isa(abstractBase)) { + MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " + << paramNode->name(); + MS_ASSERT(false); + return; + } + auto abstractTensor = + utils::cast(abstractBase); + auto typePtr = abstractTensor->element()->GetTypeTrack(); + MS_ASSERT(typePtr != nullptr); + paramTensor->dataType = typePtr->type_id(); + if (!utils::isa(abstractTensor->BuildShape())) { + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " + << paramNode->name(); + MS_ASSERT(false); + return; + } + paramTensor->dims = + utils::cast(abstractTensor->BuildShape()) + ->shape(); + auto paramValue = + std::dynamic_pointer_cast(paramNode->default_param()); + if (paramValue != nullptr) { + paramTensor->nodeType = schema::NodeType_ValueNode; + paramTensor->data.resize(paramValue->tensor_size()); + memcpy(paramTensor->data.data(), paramValue->tensor_addr(), + paramValue->tensor_size()); + for (auto &ite : paramValue->quant_param()) { + auto quantPar = std::make_unique(); + quantPar->scale = ite->scale; + quantPar->zeroPoint = ite->zeroPoint; + quantPar->min = ite->min; + quantPar->max = ite->max; + quantPar->narrowRange = ite->narrowRange; + quantPar->inited = ite->inited; + quantPar->numBits = ite->numBits; + paramTensor->quantParams.emplace_back(std::move(quantPar)); + paramTensor->dataType = paramValue->tensor_type(); + } + } + nodeIdMap[paramNode->fullname_with_scope()] = + meta_graph->allTensors.size(); + fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(paramTensor)); + } else if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + auto paramTensor = std::make_unique(); + auto value = valueNode->value(); + if (value->isa()) { + auto valueAbstract = valueNode->abstract(); + auto abstractTensor = + utils::cast(valueAbstract); + auto typePtr = abstractTensor->element()->GetTypeTrack(); + paramTensor->dataType = typePtr->type_id(); + paramTensor->dims = + utils::cast(abstractTensor->BuildShape()) + ->shape(); + paramTensor->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + paramTensor->data.resize(data->Size()); + memcpy(paramTensor->data.data(), data->Data(), data->Size()); + nodeIdMap[valueNode->fullname_with_scope()] = + meta_graph->allTensors.size(); + fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + auto valueAbstract = valueNode->abstract(); + auto abstractScalar = utils::cast(valueAbstract); + auto typePtr = abstractScalar->GetTypeTrack(); + paramTensor->dataType = typePtr->type_id(); + paramTensor->dims = {1}; + paramTensor->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + paramTensor->data.emplace_back(data->value()); + nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); + fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + MS_LOG(INFO) << "Value type is ValueSequence."; + break; + } else { + MS_LOG(ERROR) << "Not support value type , need add support."; + } + } + } + if (isGraphInput) { + graphInputNodes.emplace_back(fbNode); + } +} + +void AnfExporter::SetOpOutputNode( + const std::vector &outputTensors, + schema::MetaGraphT *graph, schema::CNodeT *cnode) { + MS_ASSERT(nullptr != graph); + MS_ASSERT(nullptr != cnode); + std::string cnodeName = cnode->name; + if (!outputTensors.empty()) { + int i = 0; + for (auto outputTensor : outputTensors) { + std::string name = cnodeName + "_o:" + std::to_string(i); + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; + nodeIdMap[name] = graph->allTensors.size(); + cnode->outputIndex.emplace_back(graph->allTensors.size()); + graph->allTensors.emplace_back(msTensor); + i++; + } + return; + } + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; + cnode->outputIndex.emplace_back(graph->allTensors.size()); + nodeIdMap[cnodeName] = graph->allTensors.size(); + graph->allTensors.emplace_back(msTensor); +} + +schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) { + AnfExporter anfExporter; + return anfExporter.Export(funcGraph); +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.h b/mindspore/lite/src/common/anf_exporter/anf_exporter.h new file mode 100644 index 0000000000..8cb04e9d72 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.h @@ -0,0 +1,50 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ +#define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ + +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +class AnfExporter { + public: + AnfExporter() = default; + virtual ~AnfExporter() = default; + schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); + void SetOpOutputNode(const std::vector &outputTensors, schema::MetaGraphT *graph, + schema::CNodeT *cnode); + void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); + void RemoveIfMakeTuple(const CNodePtr &cnode); + bool RemoveIfTupleGetItem(const CNodePtr &cnode); + bool AddOutPutIfReturn(const std::unique_ptr &metaGraphT, const CNodePtr &cnode); + private: + std::map nodeIdMap; + std::vector graphInputNodes; + std::map mapRemoveGetItem_; +}; + +schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ + diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc new file mode 100644 index 0000000000..bf6a66e57d --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + if (p->name() == "ReLU") { + attr->type = schema::ActivationType_RELU; + } else if (p->name() == "Sigmoid") { + attr->type = schema::ActivationType_SIGMOID; + } else if (p->name() == "ReLU6") { + attr->type = schema::ActivationType_RELU6; + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Activation; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h new file mode 100644 index 0000000000..daa4add19c --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfActivationPopulater : public AnfNodePopulater { + public: + AnfActivationPopulater() = default; + ~AnfActivationPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc new file mode 100644 index 0000000000..d8013aed14 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + attr->epsilon = GetValue(p->GetAttr("epsilon")); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h new file mode 100644 index 0000000000..1df83a87ac --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H +#define MINDSPORE_ANF_BATCHNORM_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfBatchnormParser : public AnfNodePopulater { + public: + AnfBatchnormParser() = default; + ~AnfBatchnormParser() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_BATCHNORM_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc new file mode 100644 index 0000000000..ad59e89936 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + attr->axis = {0}; + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_BiasAdd; + node->primitive->value.value = attr.release(); + return 0; +} + +AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h new file mode 100644 index 0000000000..6256e20567 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_BIASADD_PARSER_H +#define MINDSPORE_ANF_BIASADD_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfBiasAddPopulater : public AnfNodePopulater { + public: + AnfBiasAddPopulater() = default; + ~AnfBiasAddPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_BIASADD_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc new file mode 100644 index 0000000000..1b4596205a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc @@ -0,0 +1,45 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_concat_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + + auto prim_axis = GetValue(p->GetAttr("axis")); + attr->axis = prim_axis; + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Concat; + node->primitive->value.value = attr.release(); + + return 0; +} + +AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h new file mode 100644 index 0000000000..9a9915dcb5 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h @@ -0,0 +1,32 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_CONCAT_PARSER_H +#define MINDSPORE_ANF_CONCAT_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfConcatPopulater : public AnfNodePopulater { + public: + AnfConcatPopulater() = default; + ~AnfConcatPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_CONCAT_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc new file mode 100644 index 0000000000..c2dcee77ec --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc @@ -0,0 +1,121 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + int group = GetValue(p->GetAttr("group")); + + if (group > 1) { + auto attr = std::make_unique(); + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(p->GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(p->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(p->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(p->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + node->primitive->value.value = attr.release(); + } else { + auto attr = std::make_unique(); + attr->group = group; + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(p->GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(p->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(p->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + attr->channelOut = GetValue(p->GetAttr("out_channel")); + + auto pad_mode = GetValue(p->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Conv2D; + node->primitive->value.value = attr.release(); + } + return 0; +} + +AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h new file mode 100644 index 0000000000..88edda0951 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h @@ -0,0 +1,32 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_CONV_PARSER_H +#define MINDSPORE_ANF_CONV_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfConvPopulater : public AnfNodePopulater { + public: + AnfConvPopulater() = default; + ~AnfConvPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_CONV_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc new file mode 100644 index 0000000000..8bb4c79771 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(p->GetAttr("pads")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(p->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(p->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(p->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + auto channel_multiplier = GetValue(p->GetAttr("channel_multiplier")); + attr->channelMultiplier = channel_multiplier; + + MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); + auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto paramNode = inputNode->cast(); + auto abstractBase = paramNode->abstract(); + MS_ASSERT(abstractBase != nullptr); + if (utils::isa(abstractBase)) { + auto abstractTensor = utils::cast(abstractBase); + MS_ASSERT(abstractTensor != nullptr); + if (utils::isa(abstractTensor->BuildShape())) { + auto dims = utils::cast(abstractTensor->BuildShape())->shape(); + attr->channelIn = dims[kAnfPopulaterOne]; + } + } + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h new file mode 100644 index 0000000000..de96776d6f --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H +#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfDepwiseconv2DPopulater : public AnfNodePopulater { + public: + AnfDepwiseconv2DPopulater() = default; + ~AnfDepwiseconv2DPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc new file mode 100644 index 0000000000..a08bf67d68 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h new file mode 100644 index 0000000000..12017ad60b --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_dequant_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H +#define MINDSPORE_ANF_DEQUANT_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfDequantPopulater : public AnfNodePopulater { + public: + AnfDequantPopulater() = default; + ~AnfDequantPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_DEQUANT_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc new file mode 100644 index 0000000000..8ba27f99a7 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Flatten; + node->primitive->value.value = attr.release(); + return 0; +} + +AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h new file mode 100644 index 0000000000..f2cf48ab02 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_FLATTEN_PARSER_H +#define MINDSPORE_ANF_FLATTEN_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfFlattenPopulater : public AnfNodePopulater { + public: + AnfFlattenPopulater() = default; + ~AnfFlattenPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_FLATTEN_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc new file mode 100644 index 0000000000..909ceec01a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + attr->transposeA = GetValue(p->GetAttr("transpose_a")); + attr->transposeB = GetValue(p->GetAttr("transpose_b")); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_MatMul; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h new file mode 100644 index 0000000000..752e8eff31 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_MATMUL_PARSER_H +#define MINDSPORE_ANF_MATMUL_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfMatmulPopulater : public AnfNodePopulater { + public: + AnfMatmulPopulater() = default; + ~AnfMatmulPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_MATMUL_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc new file mode 100644 index 0000000000..4f5c3beec8 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_mul_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Mul; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h new file mode 100644 index 0000000000..87f526cf7a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfMulPopulater : public AnfNodePopulater { + public: + AnfMulPopulater() = default; + ~AnfMulPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc new file mode 100644 index 0000000000..4045e0e043 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc @@ -0,0 +1,19 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" + +namespace mindspore::lite {} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h new file mode 100644 index 0000000000..3d9accb75e --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_NODE_PARSER_H +#define MINDSPORE_ANF_NODE_PARSER_H + +#include +#include "ir/anf.h" +#include "schema/inner/model_generated.h" +namespace mindspore::lite { +constexpr int kAnfPopulaterOne = 1; +constexpr int kAnfPopulaterTwo = 2; +constexpr int kAnfPopulaterThree = 3; +class AnfNodePopulater { + public: + AnfNodePopulater() = default; + virtual ~AnfNodePopulater() = default; + virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) = 0; +}; + +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_NODE_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc new file mode 100644 index 0000000000..1ac99bd3f1 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include +#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" +namespace mindspore { +namespace lite { +AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { + static AnfNodePopulaterRegistry instance; + return &instance; +} +AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { + if (parsers.find(name) == parsers.end()) { + return nullptr; + } + return parsers[name]; +} +void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *parser) { + parsers[name] = parser; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h new file mode 100644 index 0000000000..321d4b5fb3 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H +#define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +#include +namespace mindspore::lite { +class AnfNodePopulaterRegistry { + public: + AnfNodePopulaterRegistry() = default; + virtual ~AnfNodePopulaterRegistry() = default; + static AnfNodePopulaterRegistry *GetInstance(); + AnfNodePopulater *GetNodePopulater(const std::string &name); + void SetNodePopulater(const std::string &name, AnfNodePopulater *parser); + + private: + std::unordered_map parsers; +}; + +class AnfNodePopulaterRegistrar { + public: + AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *parser) { + AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, parser); + } +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_NODE_PARSER_REGISTRY_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc new file mode 100644 index 0000000000..8c70bb46ae --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + if (p->instance_name() == "MaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + } else if (p->instance_name() == "MeanPool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } + + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + + auto pad_mode = GetValue(p->GetAttr("padding")); + if (pad_mode == "VALID") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "SAME") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + auto kernel_size = GetValue>(p->GetAttr("ksize")); + attr->windowH = kernel_size[2]; + attr->windowW = kernel_size[3]; + + auto stride = GetValue>(p->GetAttr("strides")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Pooling; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfMaxPoolParser("MaxPool", new AnfPoolPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h new file mode 100644 index 0000000000..a677e7baca --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_POOL_PARSER_H +#define MINDSPORE_ANF_POOL_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfPoolPopulater : public AnfNodePopulater { + public: + AnfPoolPopulater() = default; + ~AnfPoolPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_POOL_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc new file mode 100644 index 0000000000..964f00c2a5 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_quant_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h new file mode 100644 index 0000000000..87a593b459 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_quant_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_QUANT_PARSER_H +#define MINDSPORE_ANF_QUANT_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfQuantPopulater : public AnfNodePopulater { + public: + AnfQuantPopulater() = default; + ~AnfQuantPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_QUANT_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc new file mode 100644 index 0000000000..8ec0a93cfe --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_reducemean_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +namespace { + constexpr int kReduceInputNum = 3; + constexpr int kReduceInputIndex = 2; +} +int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + attr->mode = schema::ReduceMode_ReduceMean; + + attr->keepDims = GetValue(p->GetAttr("keep_dims")); + if (cnodePtr->inputs().size() == kReduceInputNum) { + auto inputNode = cnodePtr->input(kReduceInputIndex); + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto value = valueNode->value(); + MS_ASSERT(value != nullptr); + if (value->isa()) { + auto valTuplPtr = dyn_cast(value); + MS_ASSERT(valTuplPtr != nullptr); + for (size_t i = 0; i < valTuplPtr->size(); i++) { + auto elem = dyn_cast((*valTuplPtr)[i]); + MS_ASSERT(elem != nullptr); + attr->axes.emplace_back(elem->value()); + } + } + } + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Reduce; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfReduceMeanParser("ReduceMean", new AnfReduceMeanPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h new file mode 100644 index 0000000000..16ac3b0c7e --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfReduceMeanPopulater : public AnfNodePopulater { + public: + AnfReduceMeanPopulater() = default; + ~AnfReduceMeanPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc new file mode 100644 index 0000000000..de5007ca5f --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Flatten; + node->primitive->value.value = attr.release(); + return 0; +} + +AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h new file mode 100644 index 0000000000..776aab0f94 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H +#define MINDSPORE_ANF_RESHAPE_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfReshapePopulater : public AnfNodePopulater { + public: + AnfReshapePopulater() = default; + ~AnfReshapePopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_RESHAPE_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc new file mode 100644 index 0000000000..e220e45b41 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfTensorAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Add; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfTensorAddParser("TensorAdd", new AnfTensorAddPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h new file mode 100644 index 0000000000..d8ff59bba7 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfTensorAddPopulater : public AnfNodePopulater { + public: + AnfTensorAddPopulater() = default; + ~AnfTensorAddPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc new file mode 100644 index 0000000000..76eafbec64 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_transpose_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + + MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); + auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + if (inputNode->isa()) { + auto valNode = inputNode->cast(); + MS_ASSERT(valNode != nullptr); + auto val = valNode->value(); + MS_ASSERT(val != nullptr); + if (val->isa()) { + auto tuple = val->cast(); + MS_ASSERT(tuple != nullptr); + for (size_t i = 0; i < tuple->size(); i++) { + auto elem = tuple->value()[i]->cast(); + MS_ASSERT(elem != nullptr); + attr->perm.emplace_back(static_cast(elem->value())); + } + } + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Transpose; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfTransposeParser("Transpose", new AnfTransposePopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h new file mode 100644 index 0000000000..eecdbb7593 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H +#define MINDSPORE_ANF_TRANSPOSE_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfTransposePopulater : public AnfNodePopulater { + public: + AnfTransposePopulater() = default; + ~AnfTransposePopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_TRANSPOSE_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc new file mode 100644 index 0000000000..9f6092f4ae --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfTupleGetItemPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_TupleGetItem; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfTupleGetItemParser("tuple_getitem", new AnfTupleGetItemPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h new file mode 100644 index 0000000000..3acf2638c3 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H +#define MINDSPORE_ANF_BATCHNORM_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfTupleGetItemPopulater : public AnfNodePopulater { + public: + AnfTupleGetItemPopulater() = default; + ~AnfTupleGetItemPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_BATCHNORM_PARSER_H diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.cc b/mindspore/lite/src/common/anf_importer/anf_importer.cc new file mode 100644 index 0000000000..eb9f84eca3 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/anf_importer.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "src/common/anf_importer/anf_importer.h" +#include "schema/model_generated.h" +#include "ir/dtype.h" +#include "ir/primitive.h" +#include "src/param_value_lite.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" +#include "src/ir/primitive_value.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +#if 0 +PrimitivePtr SetConv2DAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_Conv2D(); + PrimitivePtr prim; + if (attrs->group() > 1) { + prim = std::make_shared("DepthwiseConv2D"); + prim->set_instance_name("DepthwiseConv2D"); + } else { + prim = std::make_shared("Conv2D"); + prim->set_instance_name("Conv2D"); + } + + prim->set_attr("group", MakeValue(attrs->group())); + prim->set_attr("format", MakeValue(attrs->format())); + prim->set_attr("pad_mode", MakeValue(attrs->padMode())); + std::vector pad_list = {attrs->padUp(), attrs->padDown(), attrs->padLeft(), attrs->padRight()}; + prim->set_attr("pad_list", MakeValue>(pad_list)); + std::vector dilate = {attrs->dilateH(), attrs->dilateW()}; + prim->set_attr("dilation", MakeValue>(dilate)); + std::vector kernel_size = {attrs->kernelH(), attrs->kernelW()}; + prim->set_attr("kernel_size", MakeValue>(kernel_size)); + std::vector stride = {1, 1, attrs->strideH(), attrs->strideW()}; + prim->set_attr("stride", MakeValue>(stride)); + prim->set_attr("out_channel", MakeValue(attrs->channelOut())); + prim->set_attr("group", MakeValue(attrs->group())); + return prim; +} + +PrimitivePtr SetActivationAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_Activation(); + PrimitivePtr prim; + if (attrs->type() == schema::ActivationType_RELU) { + prim = std::make_shared("ReLU"); + prim->set_instance_name("ReLU"); + } + return prim; +} + +PrimitivePtr SetPoolingAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_Pooling(); + PrimitivePtr prim; + if (attrs->poolingMode() == schema::PoolMode_MAX_POOLING) { + prim = std::make_shared("MaxPool"); + prim->set_instance_name("MaxPool"); + } else if (attrs->poolingMode() == schema::PoolMode_MEAN_POOLING) { + prim = std::make_shared("MeanPool"); + prim->set_instance_name("MeanPool"); + } + + prim->set_attr("format", MakeValue(attrs->format())); + prim->set_attr("pad_mode", MakeValue(attrs->padMode())); + prim->set_attr("ksize", MakeValue>(std::vector({1, 1, attrs->windowH(), attrs->windowW()}))); + prim->set_attr("strides", MakeValue>(std::vector({1, 1, attrs->strideH(), attrs->strideW()}))); + return prim; +} + +PrimitivePtr SetFlattenAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("Flatten"); + prim->set_instance_name("Flatten"); + return prim; +} + +PrimitivePtr SetMatmulAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_MatMul(); + auto prim = std::make_shared("Matmul"); + prim->set_instance_name("Matmul"); + prim->set_attr("transpose_a", MakeValue(attrs->transposeA())); + prim->set_attr("transpose_b", MakeValue(attrs->transposeB())); + return prim; +} + +PrimitivePtr SetMulAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + // auto attrs = nodedef->attr_as_Mul(); + auto prim = std::make_shared("Mul"); + prim->set_instance_name("Mul"); + return prim; +} + +PrimitivePtr SetSigmoidAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("Sigmoid"); + prim->set_instance_name("Sigmoid"); + return prim; +} + +PrimitivePtr SetReduceAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("ReduceMean"); + prim->set_instance_name("ReduceMean"); + return prim; +} + +PrimitivePtr SetBatchNormAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive_as_BatchNorm(); + auto prim = std::make_shared("BatchNorm"); + prim->set_attr("is_training", MakeValue(attrs->is_training())); + prim->set_instance_name("BatchNorm"); + return prim; +} + +PrimitivePtr SetBiasAddAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("BiasAdd"); + prim->set_instance_name("BiasAdd"); + return prim; +} + +PrimitivePtr SetAddAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("Add"); + prim->set_instance_name("Add"); + return prim; +} + +void MinnieBuildGraph::FbTest(const GraphDef *graph_def) { + auto node_def = graph_def->subgraphs()->begin()->nodes()->GetAs(3); + PrimitivePtr prim = ConverterOperatorAttr(node_def); + if (prim->GetAttr("format")) MS_LOG(INFO) << "find format"; + if (prim->GetAttr("group")) MS_LOG(INFO) << "find group"; +} +#endif + +int AnfImporter::Import() { + ConverterConstTensor(); + auto ret = ConverterCNode(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "ConverterCNode failed " << ret; + return ret; + } + AddReturnCNode(); + return RET_OK; +} + +AnfNodePtr AnfImporter::GetNode(int tensor_id) { + auto n = nodes_.find(tensor_id); + if (n == nodes_.end()) { + return nullptr; + } + return n->second; +} + +void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.h b/mindspore/lite/src/common/anf_importer/anf_importer.h new file mode 100644 index 0000000000..3281294f40 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/anf_importer.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ + +#include +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "base/base.h" + +namespace mindspore::lite { +class AnfImporter { + public: + AnfImporter() = default; + + virtual ~AnfImporter() = default; + + virtual int Import(); + + virtual FuncGraphPtr GetResult() = 0; + + protected: + // convert const tensor into parameter and save in nodes_ + virtual void ConverterConstTensor() = 0; + // convert other node into cnode and save in nodes_ + virtual int ConverterCNode() = 0; + + virtual void AddReturnCNode() = 0; + + AnfNodePtr GetNode(int tensor_id); + + void AddNode(int tensor_id, AnfNodePtr node); + + protected: + std::unordered_map nodes_; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc new file mode 100644 index 0000000000..6ec0ba8c54 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_importer/import_from_meta_graph.h" +#include +#include +#include +#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" +#include "src/ir/primitive_value.h" +#include "include/errorcode.h" + +namespace mindspore::lite { +void AnfImporterFromMetaGraph::ConverterConstTensor() { + MS_EXCEPTION_IF_NULL(model); + auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + for (size_t i = 0; i < meta_graph->allTensors()->size(); i++) { + auto *tensor = meta_graph->allTensors()->GetAs(i); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor->nodeType() != schema::NodeType_ValueNode) { + continue; + } + MS_ASSERT(tensor->dims() != nullptr); + auto parameter = model->add_parameter(); + std::vector shape; + for (size_t j = 0; j < tensor->dims()->size(); ++j) { + shape.push_back(tensor->dims()->data()[j]); + } + auto type_id = static_cast(tensor->dataType()); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, shape); + parameter->set_abstract(abstract_tensor); + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + if (tensor->data() != nullptr) { + auto size = tensor->data()->size(); + char *tensor_data = new char[size](); + std::memcpy(tensor_data, tensor->data()->data(), size); + MS_EXCEPTION_IF_NULL(tensor_data); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + } + parameter->set_default_param(param_value); + AddNode(i, parameter); + } +} + +int AnfImporterFromMetaGraph::ConverterCNode() { + MS_EXCEPTION_IF_NULL(model); + auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + auto cNodes = meta_graph->nodes(); + for (size_t i = 0; i < cNodes->size(); i++) { + auto cNode = cNodes->GetAs(i); + MS_EXCEPTION_IF_NULL(cNode); + auto tensor_id = cNode->outputIndex()->data()[0]; + if (GetNode(tensor_id)) { + continue; + } + + auto prim = std::make_shared(model->GetOp(cNode->name()->str())); + if (prim == nullptr) { + MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; + return RET_ERROR; + } + auto value_node = NewValueNode(prim); + AddNode(tensor_id, value_node); + + std::vector op_inputs = {value_node}; + MS_EXCEPTION_IF_NULL(cNode->inputIndex()); + for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { + auto node = GetNode(*(cNode->inputIndex()->GetAs(j))); + if (nullptr == node) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_ERROR; + } + // todo: CheckInputNodeType, the first node should be op; + op_inputs.push_back(node); + } + auto cnode = model->NewCNode(op_inputs); + auto node_name = std::string(cNode->name()->c_str()); + cnode->set_fullname_with_scope(node_name); + AddNode(tensor_id, cnode); + } + return RET_OK; +} + +void AnfImporterFromMetaGraph::AddReturnCNode() { + MS_EXCEPTION_IF_NULL(model); + auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + std::vector op_inputs; + auto value_node = NewValueNode(prim::kPrimReturn); + op_inputs.push_back(value_node); + auto tensor_id = meta_graph->outputIndex()->data()[0]; + op_inputs.push_back(GetNode(tensor_id)); + auto cnode = model->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + model->set_return(cnode); +} +FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model; } +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h new file mode 100644 index 0000000000..fd34930f1c --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ + +#include +#include "src/train/model_impl.h" +#include "schema/model_generated.h" +#include "src/common/anf_importer/anf_importer.h" + +namespace mindspore::lite { +class AnfImporterFromMetaGraph : public AnfImporter { + public: + explicit AnfImporterFromMetaGraph(std::shared_ptr model) : model(model) {} + + ~AnfImporterFromMetaGraph() override = default; + + FuncGraphPtr GetResult() override; + + private: + void ConverterConstTensor() override; + + int ConverterCNode() override; + + void AddReturnCNode() override; + + private: + std::shared_ptr model = nullptr; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc new file mode 100644 index 0000000000..c470d6a6e3 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/inner/model_generated.h" +#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" +#include "import_from_meta_graphT.h" +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" +#include "src/ir/primitive_value.h" +#include "src/ir/primitive_t_value.h" +#include "include/errorcode.h" +#include "src/ops/ops.h" + +namespace mindspore::lite { +void AnfImporterFromMetaGraphT::ConverterConstTensor() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { + auto &tensor = meta_graph_->allTensors.at(i); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor->nodeType != schema::NodeType_ValueNode) { + continue; + } + MS_ASSERT(tensor->dims() != nullptr); + auto parameter = func_graph_->add_parameter(); + std::vector shape; + for (int &dim : tensor->dims) { + shape.push_back(dim); + } + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, shape); + parameter->set_abstract(abstract_tensor); + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + if (!tensor->data.empty()) { + auto size = tensor->data.size(); + char *tensor_data = new char[size]; + std::memcpy(tensor_data, tensor->data.data(), size); + MS_EXCEPTION_IF_NULL(tensor_data); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + } + parameter->set_default_param(param_value); + AddNode(i, parameter); + } +} + +int AnfImporterFromMetaGraphT::ConverterCNode() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + for (size_t i = 0; i < meta_graph_->nodes.size(); i++) { + auto &cNode = meta_graph_->nodes.at(i); + MS_EXCEPTION_IF_NULL(cNode); + auto tensor_id = cNode->outputIndex.front(); + if (nullptr != GetNode(tensor_id)) { + continue; + } + + auto primTValue = std::make_shared(cNode->primitive.release()); + cNode->primitive = nullptr; + auto value_node = NewValueNode(primTValue); + + std::vector op_inputs = {value_node}; + for (size_t j = 0; j < cNode->inputIndex.size(); j++) { + auto node = GetNode(cNode->inputIndex.at(j)); + if (nullptr == node) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_ERROR; + } + // todo: CheckInputNodeType, the first node should be op; + op_inputs.push_back(node); + } + auto cnode = func_graph_->NewCNode(op_inputs); + cnode->set_fullname_with_scope(cNode->name); + AddNode(tensor_id, cnode); + } + return RET_OK; +} + +void AnfImporterFromMetaGraphT::AddReturnCNode() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + std::vector make_tuple_inputs; + auto make_tuple_value_node = NewValueNode(prim::kPrimMakeTuple); + make_tuple_inputs.emplace_back(make_tuple_value_node); + for (auto tensor_id : meta_graph_->outputIndex) { + make_tuple_inputs.emplace_back(GetNode(tensor_id)); + } + auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope("return tuple"); + + std::vector op_inputs; + auto value_node = NewValueNode(prim::kPrimReturn); + op_inputs.emplace_back(value_node); + op_inputs.emplace_back(make_tuple_cnode); + auto cnode = func_graph_->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + func_graph_->set_return(cnode); +} + +FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h new file mode 100644 index 0000000000..5b3799a256 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ + +#include + +#include "schema/inner/model_generated.h" +#include "src/common/anf_importer/anf_importer.h" + +namespace mindspore::lite { +class AnfImporterFromMetaGraphT : public AnfImporter { + public: + explicit AnfImporterFromMetaGraphT(schema::MetaGraphT *meta_graph, FuncGraphPtr func_graph) + : meta_graph_(meta_graph), func_graph_(std::move(func_graph)) {} + + ~AnfImporterFromMetaGraphT() override = default; + + FuncGraphPtr GetResult() override; + + private: + void ConverterConstTensor() override; + + int ConverterCNode() override; + + void AddReturnCNode() override; + + private: + schema::MetaGraphT *meta_graph_; + FuncGraphPtr func_graph_; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ + diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc new file mode 100644 index 0000000000..904e0bd0ad --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -0,0 +1,1188 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_importer/import_from_protobuf.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/ops.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "include/errorcode.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "src/ir/tensor.h" +#include "src/param_value_lite.h" +#include "tools/converter/parser/onnx/onnx.pb.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" + +using string = std::string; +using int32 = int32_t; +using int64 = int64_t; +using uint64 = uint64_t; + +namespace mindspore::lite { + +static constexpr char kConstantValueNode[] = "Constant"; +static constexpr char kCNodeShapeAttr[] = "shape"; +static constexpr char kCNodeShape1Attr[] = "shape1"; +static constexpr char kCNodeShape2Attr[] = "shape2"; + +enum ParseForm : int { + FORM_PARSE_TYPE = 0, + FORM_PARSE_SCALAR = 1, + FORM_PARSE_TENSOR = 2, +}; + +static std::map kParseTypeSwitchMap{ + {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; + +static std::unordered_map kDefaultValueSwitchMap{ + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, +}; + +#if 0 +std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, + const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("scalar:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + std::stack rules; + std::stack value; + int num = 0, count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + num++; + } + } + } + return {}; +} + +std::shared_ptr +ParserAttrShape(const std::string &attr_name, const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("shape:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + std::stack rules; + std::stack value; + int num = 0, count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + num++; + } + } + } + return {}; +} + +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ + if (attr_tensor.type##_data_size() == 1) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ + } else { \ + MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ + } \ + return {}; \ + } + +PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) +PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) +PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) +PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) +PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) + +bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto) { + MS_EXCEPTION_IF_NULL(node); + if (!value_proto.has_type() || !value_proto.has_name()) { + MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; + return false; + } + node->set_name(value_proto.name()); + const auto &type_proto = value_proto.type(); + if (!type_proto.has_tensor_type()) { + MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; + return false; + } + const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); + if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; + return false; + } + const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); + std::vector shape; + for (int i = 0; i < tensor_shape.dim_size(); ++i) { + shape.push_back(tensor_shape.dim(i).dim_value()); + } + + if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; + return false; + } + + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape); + node->set_abstract(abstract_tensor); + + if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { + tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); + MS_EXCEPTION_IF_NULL(tensor_info); + tensor_info->MallocData(); + const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; + std::string initial_data = initialize_proto.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + MS_EXCEPTION_IF_NULL(tensor_data_buf); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); + if (EOK != ret) { + MS_LOG(ERROR) << "memcpy_s error"; + return false; + } + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_addr(tensor_data_buf); + param_value->set_tensor_size(tensor_info->Size()); + node->set_default_param(param_value); + } + anfnode_build_map_[value_proto.name()] = node; + return true; +} + +bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); + + for (int i = 0; i < importProto.initializer_size(); ++i) { + const onnx::TensorProto &initializer_proto = importProto.initializer(i); + if (!initializer_proto.has_name()) { + MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; + return false; + } + default_para_map_[initializer_proto.name()] = initializer_proto; + } + + MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); + for (int i = 0; i < importProto.input_size(); ++i) { + const onnx::ValueInfoProto &input_proto = importProto.input(i); + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { + MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; + return false; + } + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; + return false; + } + prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + return true; +} + +ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + switch (attr_tensor_type) { + case onnx::TensorProto_DataType_STRING: { + return ParseAttrInScalar_string_string(attr_tensor); + } + case onnx::TensorProto_DataType_INT32: { + return ParseAttrInScalar_int32_int32(attr_tensor); + } + case onnx::TensorProto_DataType_INT64: { + return ParseAttrInScalar_int64_int64(attr_tensor); + } + case onnx::TensorProto_DataType_UINT64: { + return ParseAttrInScalar_uint64_uint64(attr_tensor); + } + case onnx::TensorProto_DataType_FLOAT: { + return ParseAttrInScalar_float_float(attr_tensor); + } + case onnx::TensorProto_DataType_DOUBLE: { + return ParseAttrInScalar_double_double(attr_tensor); + } + case onnx::TensorProto_DataType_BOOL: { + return ParseAttrInScalar_int32_bool(attr_tensor); + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + return {}; + } + return {}; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; + return false; +} + +bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; + } + } + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + std::unordered_map::iterator iter = kv.begin(); + prim->AddAttr(attr_name, iter->second); + } else { + auto res = ParserScalarAttrValue(ref_attr_name, kv); + prim->AddAttr(attr_name, res); + } + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + std::vector shape; + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor_info->MallocData(); + const std::string &tensor_buf = attr_tensor.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); + if (EOK != ret) { + MS_LOG(ERROR) << "memcpy_s error"; + return false; + } + auto new_value_node = NewValueNode(MakeValue(tensor_info)); + MS_EXCEPTION_IF_NULL(new_value_node); + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); + auto abstract_tensor = std::make_shared(type_ptr, shape); + new_value_node->set_abstract(abstract_tensor); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; + return false; + } + auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); + new_value_node->set_abstract(abs_type); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name, + const onnx::AttributeProto &attr_proto) { + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; + } + } + + ValueNodePtr new_value_node; + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + std::unordered_map::iterator iter = kv.begin(); + new_value_node = NewValueNode(iter->second); + new_value_node->set_abstract(iter->second->ToAbstract()); + } else { + auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); + new_value_node = NewValueNode(value_ptr); + new_value_node->set_abstract(value_ptr->ToAbstract()); + } + anfnode_build_map_[value_node_name] = new_value_node; + } + return true; +} + +bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { + const std::string &value_node_name = node_proto.output(0); + const onnx::AttributeProto &attr_proto = node_proto.attribute(0); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; + return false; + } + return GetAttrValueForValueNode(value_node_name, attr_proto); +} + +std::unordered_map +AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + std::vector shape_vec; + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + for (int j = 0; j < attr_tensor.dims_size(); ++j) { + shape_vec.push_back(attr_tensor.dims(j)); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape_vec); + kv.insert(std::pair(attr_tensor.name(), abstract_tensor)); + } + return kv; +} + +CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + if (!node_proto.has_op_type()) { + MS_LOG(ERROR) << "Get CNode op_type failed!"; + return nullptr; + } + const std::string &node_name = node_proto.output(0); + const std::string &fullname_with_scope = node_proto.domain(); + const std::string &node_type = node_proto.op_type(); + PrimitivePtr prim = std::make_shared(node_type); + MS_EXCEPTION_IF_NULL(prim); + prim->set_instance_name(node_type); + std::unordered_map kv; + string shape_ref_attr_name; + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const onnx::AttributeProto &attr_proto = node_proto.attribute(i); + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + shape_ref_attr_name = attr_proto.ref_attr_name(); + kv = GetAbstractForCNode(attr_proto); + continue; + } + if (!GetAttrValueForCNode(prim, attr_proto)) { + MS_LOG(ERROR) << "Get CNode attr failed!"; + return nullptr; + } + } + + std::vector inputs; + inputs.clear(); + inputs.push_back(NewValueNode(prim)); + for (int i = 0; i < node_proto.input_size(); ++i) { + const std::string &input_name = node_proto.input(i); + if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; + return nullptr; + } + inputs.push_back(anfnode_build_map_[input_name]); + } + CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + if (0 == kv.size()) { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (1 == kv.size()) { + std::unordered_map::iterator iter = kv.begin(); + cnode_ptr->set_abstract(iter->second); + } else { + auto abstract = ParserAttrShape(shape_ref_attr_name, kv); + cnode_ptr->set_abstract(abstract); + } + + cnode_ptr->set_fullname_with_scope(fullname_with_scope); + anfnode_build_map_[node_name] = cnode_ptr; + return cnode_ptr; +} + +bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_EXCEPTION_IF_NULL(cnode_ptr); + std::vector inputs; + if (importProto.output_size() > 1) { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + AbstractBasePtrList elem; + for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { + const onnx::ValueInfoProto &output_node = importProto.output(out_size); + const std::string &out_tuple = output_node.name(); + inputs.push_back(anfnode_build_map_[out_tuple]); + elem.push_back(anfnode_build_map_[out_tuple]->abstract()); + } + auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); + maketuple_ptr->set_abstract(std::make_shared(elem)); + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(maketuple_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success."; + } else { + const onnx::ValueInfoProto &output_node = importProto.output(0); + const onnx::TypeProto &output_typeproto = output_node.type(); + int output_type = output_typeproto.tensor_type().elem_type(); + std::vector output_shape; + for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { + output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); + auto abstract_tensor = std::make_shared(type_ptr, output_shape); + + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(cnode_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + return_node->set_abstract(abstract_tensor); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success!"; + } + return true; +} + +bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); + CNodePtr cnode_ptr = nullptr; + for (int i = 0; i < importProto.node_size(); ++i) { + const onnx::NodeProto &node_proto = importProto.node(i); + const std::string &node_type = node_proto.op_type(); + if (node_type == kConstantValueNode) { + if (!BuildValueNodeForFuncGraph(node_proto)) { + MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; + return false; + } + continue; + } + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; + return false; + } + } + + BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); + return true; +} +#endif + +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ + const onnx::TensorProto &attr_tensor) { \ + MS_EXCEPTION_IF_NULL(prim); \ + std::vector attr_value_vec; \ + for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ + auto value = static_cast(attr_tensor.type##_data(i)); \ + attr_value_vec.push_back(MakeValue(value)); \ + } \ + if (attr_value_vec.size() == 1) { \ + prim->AddAttr(attr_name, attr_value_vec[0]); \ + } else { \ + prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ + } \ + } + +PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) +PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) +PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) +PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) +PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) + +bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto) { + MS_EXCEPTION_IF_NULL(node); + if (!value_proto.has_type() || !value_proto.has_name()) { + MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; + return false; + } + node->set_name(value_proto.name()); + const auto &type_proto = value_proto.type(); + if (!type_proto.has_tensor_type()) { + MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; + return false; + } + const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); + if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; + return false; + } + const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); + std::vector shape; + for (int i = 0; i < tensor_shape.dim_size(); ++i) { + shape.push_back(tensor_shape.dim(i).dim_value()); + } + + if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; + return false; + } + + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape); + node->set_abstract(abstract_tensor); + + if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { + tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); + MS_EXCEPTION_IF_NULL(tensor_info); + tensor_info->MallocData(); + const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; + std::string initial_data = initialize_proto.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + MS_EXCEPTION_IF_NULL(tensor_data_buf); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); + if (EOK != ret) { + MS_LOG(ERROR) << "memcpy_s error"; + return false; + } + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_addr(tensor_data_buf); + param_value->set_tensor_size(tensor_info->Size()); + node->set_default_param(param_value); + } + anfnode_build_map_[value_proto.name()] = node; + return true; +} + +bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); + + for (int i = 0; i < importProto.initializer_size(); ++i) { + const onnx::TensorProto &initializer_proto = importProto.initializer(i); + if (!initializer_proto.has_name()) { + MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; + return false; + } + default_para_map_[initializer_proto.name()] = initializer_proto; + } + + MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); + for (int i = 0; i < importProto.input_size(); ++i) { + const onnx::ValueInfoProto &input_proto = importProto.input(i); + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { + MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; + return false; + } + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; + return false; + } + prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_tensor.data_type(); + switch (attr_tensor_type) { + case onnx::TensorProto_DataType_STRING: { + ParseAttrInScalar_string_string(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_INT32: { + ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_INT64: { + ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_UINT64: { + ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_FLOAT: { + ParseAttrInScalar_float_float(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_DOUBLE: { + ParseAttrInScalar_double_double(prim, attr_name, attr_tensor); + break; + } + case onnx::TensorProto_DataType_BOOL: { + ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor); + auto value = prim->GetAttr(attr_name); + break; + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + return false; + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; + return false; +} + +bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + const onnx::TensorProto &attr_tensor = attr_proto.t(); + switch (kParseTypeSwitchMap[ref_attr_name]) { + case FORM_PARSE_TYPE: { + return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_TENSOR: { + return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; + } +} +bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + std::vector shape; + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor_info->MallocData(); + const std::string &tensor_buf = attr_tensor.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); + if (EOK != ret) { + MS_LOG(ERROR) << "memcpy_s error"; + return false; + } + auto new_value_node = NewValueNode(MakeValue(tensor_info)); + MS_EXCEPTION_IF_NULL(new_value_node); + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); + auto abstract_tensor = std::make_shared(type_ptr, shape); + new_value_node->set_abstract(abstract_tensor); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + ValuePtr value_ptr = nullptr; + switch (attr_tensor_type) { + case onnx::TensorProto_DataType_INT32: { + std::vector add_data; + for (int i = 0; i < attr_tensor.int32_data_size(); ++i) { + add_data.push_back(attr_tensor.int32_data(i)); + } + if (add_data.size() == 1) { + value_ptr = MakeValue(add_data[0]); + } else if (!add_data.empty()) { + value_ptr = MakeValue>(add_data); + } + break; + } + case onnx::TensorProto_DataType_FLOAT: { + std::vector add_data; + for (int i = 0; i < attr_tensor.float_data_size(); ++i) { + add_data.push_back(attr_tensor.float_data(i)); + } + + if (add_data.size() == 1) { + value_ptr = MakeValue(add_data[0]); + } else if (!add_data.empty()) { + value_ptr = MakeValue>(add_data); + } + break; + } + case onnx::TensorProto_DataType_UNDEFINED: { + std::vector elems; + value_ptr = std::make_shared(elems); + break; + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + return false; + } + auto new_value_node = NewValueNode(value_ptr); + MS_EXCEPTION_IF_NULL(new_value_node); + new_value_node->set_abstract(value_ptr->ToAbstract()); + anfnode_build_map_[value_node_name] = new_value_node; + + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; + return false; + } + auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); + new_value_node->set_abstract(abs_type); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, + const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + switch (kParseTypeSwitchMap[ref_attr_name]) { + case FORM_PARSE_SCALAR: { + return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); + } + case FORM_PARSE_TENSOR: { + return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); + } + case FORM_PARSE_TYPE: { + return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; + return false; + } +} + +bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { + const std::string &value_node_name = node_proto.output(0); + const onnx::AttributeProto &attr_proto = node_proto.attribute(0); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + const onnx::TensorProto &attr_tensor = attr_proto.t(); + + return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); +} + +abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { + std::vector shape_vec; + const onnx::TensorProto &attr_tensor = attr_proto.t(); + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape_vec.push_back(attr_tensor.dims(i)); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape_vec); + MS_EXCEPTION_IF_NULL(abstract_tensor); + return abstract_tensor; +} + +CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + if (!node_proto.has_op_type()) { + MS_LOG(ERROR) << "Get CNode op_type failed!"; + return nullptr; + } + const std::string &node_name = node_proto.output(0); + const std::string &fullname_with_scope = node_proto.domain(); + const std::string &node_type = node_proto.op_type(); + PrimitivePtr prim = std::make_shared(node_type); + MS_EXCEPTION_IF_NULL(prim); + prim->set_instance_name(node_type); + + abstract::AbstractTensorPtr abstract = nullptr; + abstract::AbstractTensorPtr abstract_first = nullptr; + abstract::AbstractTensorPtr abstract_second = nullptr; + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const onnx::AttributeProto &attr_proto = node_proto.attribute(i); + if (attr_proto.name() == kCNodeShapeAttr) { + abstract = GetAbstractForCNode(attr_proto); + continue; + } + if (attr_proto.name() == kCNodeShape1Attr) { + abstract_first = GetAbstractForCNode(attr_proto); + continue; + } + if (attr_proto.name() == kCNodeShape2Attr) { + abstract_second = GetAbstractForCNode(attr_proto); + continue; + } + if (!GetAttrValueForCNode(prim, attr_proto)) { + MS_LOG(ERROR) << "Get CNode attr failed!"; + return nullptr; + } + } + + std::vector inputs; + inputs.clear(); + inputs.push_back(NewValueNode(prim)); + for (int i = 0; i < node_proto.input_size(); ++i) { + const std::string &input_name = node_proto.input(i); + if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; + return nullptr; + } + inputs.push_back(anfnode_build_map_[input_name]); + } + CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + if (node_type == "LayerNorm") { + AbstractBasePtrList elem; + elem.push_back(abstract); + elem.push_back(abstract_first); + elem.push_back(abstract_second); + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (node_type == "ArgMaxWithValue") { + AbstractBasePtrList elem; + elem.push_back(abstract); + elem.push_back(abstract_first); + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (nullptr == abstract) { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); + } else { + cnode_ptr->set_abstract(abstract); + } + cnode_ptr->set_fullname_with_scope(fullname_with_scope); + anfnode_build_map_[node_name] = cnode_ptr; + return cnode_ptr; +} + +bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_EXCEPTION_IF_NULL(cnode_ptr); + std::vector inputs; + if (importProto.output_size() > 1) { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + AbstractBasePtrList elem; + for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { + const onnx::ValueInfoProto &output_node = importProto.output(out_size); + const std::string &out_tuple = output_node.name(); + inputs.push_back(anfnode_build_map_[out_tuple]); + elem.push_back(anfnode_build_map_[out_tuple]->abstract()); + } + auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); + maketuple_ptr->set_abstract(std::make_shared(elem)); + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(maketuple_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success."; + } else { + const onnx::ValueInfoProto &output_node = importProto.output(0); + const onnx::TypeProto &output_typeproto = output_node.type(); + int output_type = output_typeproto.tensor_type().elem_type(); + std::vector output_shape; + for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { + output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); + auto abstract_tensor = std::make_shared(type_ptr, output_shape); + + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(cnode_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + return_node->set_abstract(abstract_tensor); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success!"; + } + return true; +} + +bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); + CNodePtr cnode_ptr = nullptr; + for (int i = 0; i < importProto.node_size(); ++i) { + const onnx::NodeProto &node_proto = importProto.node(i); + const std::string &node_type = node_proto.op_type(); + if (node_type == kConstantValueNode) { + if (!BuildValueNodeForFuncGraph(node_proto)) { + MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; + return false; + } + continue; + } + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; + return false; + } + } + + BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); + return true; +} + +bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); + MS_EXCEPTION_IF_NULL(debug_info_ptr); + if (importProto.has_name()) { + debug_info_ptr->set_name(importProto.name()); + } else { + MS_LOG(ERROR) << "FuncGraph under converting has not name!"; + } + + if (!ImportParametersForGraph(outputFuncGraph, importProto)) { + return false; + } + return ImportNodesForGraph(outputFuncGraph, importProto); +} + +bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { + if (!model_proto.has_producer_name()) { + MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; + return false; + } + producer_name_ = model_proto.producer_name(); + + if (!model_proto.has_model_version()) { + MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; + return false; + } + model_version_ = model_proto.model_version(); + + if (!model_proto.has_ir_version()) { + MS_LOG(ERROR) << "Parse model version from pb file failed!"; + return false; + } + ir_version_ = model_proto.ir_version(); + return true; +} + +int AnfImporterFromProtobuf::Import() { + FuncGraphPtr dstGraph = std::make_shared(); + MS_EXCEPTION_IF_NULL(dstGraph); + if (!ParseModelConfigureInfo(*onnx_model_)) { + MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; + } + const onnx::GraphProto &graphBuild = onnx_model_->graph(); + if (!BuildFuncGraph(dstGraph, graphBuild)) { + MS_LOG(ERROR) << "Build funcgraph failed!"; + return RET_ERROR; + } + func_graph_ = dstGraph; + MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; + return RET_OK; +} + +onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { + std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); + if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { + MS_LOG(ERROR) << "open file failed."; + return nullptr; + } + int fd = open(onnx_file.get(), O_RDONLY); + google::protobuf::io::FileInputStream input(fd); + google::protobuf::io::CodedInputStream code_input(&input); + code_input.SetTotalBytesLimit(INT_MAX, 536870912); + auto onnx_model = new onnx::ModelProto; + bool ret = onnx_model->ParseFromCodedStream(&code_input); + if (!ret) { + MS_LOG(ERROR) << "load onnx file failed"; + delete onnx_model; + return nullptr; + } + (void)close(fd); + MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; + return onnx_model; +} + +FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; } +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h new file mode 100644 index 0000000000..4513c79f17 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ + +#include +#include +#include +#include + +#include "tools/converter/parser/onnx/onnx.pb.h" +#include "src/common/anf_importer/anf_importer.h" +#include "abstract/abstract_value.h" + +namespace mindspore::lite { +class AnfImporterFromProtobuf : public AnfImporter { + public: + explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) + : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} + + ~AnfImporterFromProtobuf() override = default; + + static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); + + FuncGraphPtr GetResult() override; + + int Import() override; + + private: + void ConverterConstTensor() override {}; + int ConverterCNode() override {}; + void AddReturnCNode() override {}; + bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); + bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); +#if 0 + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); + bool BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr); + bool GetAttrValueForCNode(const PrimitivePtr &prim, + const onnx::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, + const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); + bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, + const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); + bool ObtainValueNodeInTensorForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const std::string &value_node_name, + const onnx::AttributeProto &attr_tensor); + bool ObtainValueNodeInTypeForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor); + std::unordered_map + GetAbstractForCNode(const onnx::AttributeProto &attr_proto); +#endif + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); + bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr); + bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); + bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + + bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, + const onnx::TensorProto &attr_tensor); + bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); + + + + private: + std::string producer_name_; + int model_version_{}; + int ir_version_{}; + std::unordered_map anfnode_build_map_; + std::map default_para_map_; + onnx::ModelProto *onnx_model_; + FuncGraphPtr func_graph_; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ + diff --git a/mindspore/lite/src/common/common.h b/mindspore/lite/src/common/common.h new file mode 100755 index 0000000000..ed12c49686 --- /dev/null +++ b/mindspore/lite/src/common/common.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_COMMON_H_ +#define MINDSPORE_LITE_COMMON_COMMON_H_ + +#include +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +enum NCHW_SHAPE { NCHW_N = 0, NCHW_C = 1, NCHW_H = 2, NCHW_W = 3 }; +enum NHWC_SHAPE { NHWC_N = 0, NHWC_H = 1, NHWC_W = 2, NHWC_C = 3 }; +enum HWCK_SHAPE { HWCK_H = 0, HWCK_W = 1, HWCK_C = 2, HWCK_K = 3 }; +enum HWKC_SHAPE { HWKC_H = 0, HWKC_W = 1, HWKC_K = 2, HWKC_C = 3 }; +enum KCHW_SHAPE { KCHW_K = 0, KCHW_C = 1, KCHW_H = 2, KCHW_W = 3 }; +enum CKHW_SHAPE { CKHW_C = 0, CKHW_K = 1, CKHW_H = 2, CKHW_W = 3 }; +enum CHWK_SHAPE { CHWK_C = 0, CHWK_H = 1, CHWK_W = 2, CHWK_K = 3 }; +enum KHWC_SHAPE { KHWC_K = 0, KHWC_H = 1, KHWC_W = 2, KHWC_C = 3 }; +enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 }; +enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 }; +static constexpr int kNCHWDimNumber = 4; +static constexpr int kNHWCDimNumber = 4; + +static constexpr int TENSOR_MAX_REFCOUNT = 999; + +static const char *DELIM_COLON = ":"; +static const char *DELIM_COMMA = ","; +static const char *DELIM_SLASH = "/"; +static const char *DELIM_DOUBLE_BACKSLASH = "\\"; + +// quantization relative +static const char QUANTIZED_UINT8[] = "QUANTIZED_UINT8"; +static const char QUANTIZED_INT8[] = "QUANTIZED_INT8"; +static const char QUANTIZED_INT16[] = "QUANTIZED_INT16"; +static const char QUANTIZED_UINT16[] = "QUANTIZED_UINT16"; +static const char QUANTIZED_FLOAT16[] = "FLOAT16"; +static const char QUANTIZED_FLOAT32[] = "FLOAT32"; +static const char QUANTIZATION_TYPE_DYNAMIC[] = "DYNAMIC"; +static const char QUANTIZATION_TYPE_STATIC[] = "STATIC"; +static const char CALIB_NORM[] = "NORM"; + +// dims +static const int32_t DIM_DEFAULT_SIZE = 4; + +static const schema::Format DEFAULT_FORMAT = schema::Format_NCHW; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_COMMON_H_ + diff --git a/mindspore/lite/src/common/file_utils.cc b/mindspore/lite/src/common/file_utils.cc new file mode 100644 index 0000000000..7babb87b47 --- /dev/null +++ b/mindspore/lite/src/common/file_utils.cc @@ -0,0 +1,167 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "src/common/file_utils.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +#define MAX_FILENAME_LEN 1024 +char *ReadFile(const char *file, size_t *size) { + if (file == nullptr) { + MS_LOG(ERROR) << "file is nullptr"; + return nullptr; + } + MS_ASSERT(size != nullptr); + std::string realPath = RealPath(file); + std::ifstream ifs(realPath); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << realPath << " is not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << realPath << " open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + std::unique_ptr buf(new (std::nothrow) char[*size]); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf.get(), *size); + ifs.close(); + + return buf.release(); +} + +std::string RealPath(const char *path) { + if (path == nullptr) { + MS_LOG(ERROR) << "path is nullptr"; + return ""; + } + if ((strlen(path)) >= PATH_MAX) { + MS_LOG(ERROR) << "path is too long"; + return ""; + } + std::shared_ptr resolvedPath(new (std::nothrow) char[PATH_MAX]{0}); + if (resolvedPath == nullptr) { + MS_LOG(ERROR) << "new resolvedPath failed"; + return ""; + } + std::string realPath = realpath(path, resolvedPath.get()); + if (realPath.empty()) { + MS_LOG(ERROR) << "Proto file path is not valid"; + return ""; + } + std::string res = resolvedPath.get(); + + return res; +} + +int WriteToBin(const std::string &file_path, void *data, size_t size) { + std::ofstream out_file; + + out_file.open(file_path.c_str(), std::ios::binary); + if (!out_file.good()) { + return -1; + } + + if (!out_file.is_open()) { + out_file.close(); + return -1; + } + out_file.write(reinterpret_cast(data), size); + return 0; +} + +int CompareOutputData(float *output_data, float *correct_data, int data_size) { + float error = 0; + for (size_t i = 0; i < data_size; i++) { + float abs = fabs(output_data[i] - correct_data[i]); + if (abs > 0.00001) { + error += abs; + } + } + error /= data_size; + if (error > 0.0001) { + printf("has accuracy error!\n"); + printf("%f\n", error); + return 1; + } + return 0; +} + +void CompareOutput(float *output_data, std::string file_path) { + size_t output_size; + auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); + size_t output_num = output_size / sizeof(float); + printf("output num : %zu\n", output_num); + CompareOutputData(output_data, ground_truth, output_num); +} + +// std::string GetAndroidPackageName() { +// static std::string packageName; +// +// if (!packageName.empty()) { +// return packageName; +// } +// +// char cmdline[MAX_FILENAME_LEN] = {0}; +// int fd = open("/proc/self/cmdline", O_RDONLY); +// +// if (fd >= 0) { +// char ch; +// int i = 0; +// while (read(fd, &ch, sizeof(ch)) > 0 && !isspace(ch)) { +// if (':' == ch) { +// break; +// } +// +// if (('/' == ch) || ('\\' == ch)) { +// (void)memset(cmdline, 0, sizeof(cmdline)); +// i = 0; +// } else { +// cmdline[i] = ch; +// i++; +// } +// } +// close(fd); +// } +// packageName = std::string(cmdline); +// return packageName; +//} + +// std::string GetAndroidPackagePath() { +// std::string packageName = GetAndroidPackageName(); +// if (packageName.empty()) { +// return "./"; +// } +// return "/data/data/" + packageName + '/'; +//} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/file_utils.h b/mindspore/lite/src/common/file_utils.h new file mode 100644 index 0000000000..ff1ec03e64 --- /dev/null +++ b/mindspore/lite/src/common/file_utils.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_FILE_UTILS_H_ +#define MINDSPORE_LITE_COMMON_FILE_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "src/common/utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +char *ReadFile(const char *file, size_t *size); + +std::string RealPath(const char *path); + +template +void WriteToTxt(const std::string& file_path, void *data, size_t element_size) { + std::ofstream out_file; + out_file.open(file_path, std::ios::out); + auto real_data = reinterpret_cast(data); + for (size_t i = 0; i < element_size; i++) { + out_file << real_data[i] << " "; + } + out_file.close(); +} + +int WriteToBin(const std::string& file_path, void *data, size_t size); + +int CompareOutputData(float *output_data, float *correct_data, int data_size); +void CompareOutput(float *output_data, std::string file_path); + +std::string GetAndroidPackageName(); +std::string GetAndroidPackagePath(); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_FILE_UTILS_H_ + diff --git a/mindspore/lite/src/common/file_utils_ext.cc b/mindspore/lite/src/common/file_utils_ext.cc new file mode 100644 index 0000000000..cdaa337e23 --- /dev/null +++ b/mindspore/lite/src/common/file_utils_ext.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" + +namespace mindspore { +namespace lite { +static int CompareOutputRelativeData(float *output_data, float *correct_data, int data_size) { + float error = 0; + + // relative error + float diffSum = 0.0f; + float sum = 0.0f; + for (int i = 0; i < data_size; i++) { + sum += std::abs(correct_data[i]); + } + for (int i = 0; i < data_size; i++) { + float diff = std::abs(output_data[i] - correct_data[i]); + diffSum += diff; + } + error = diffSum / sum; + if (error > 1e-4) { + std::cout << "has accuracy error!\n" << error << "\n"; + return 1; + } + return 0; +} + +int CompareRelativeOutput(float *output_data, std::string file_path) { + size_t output_size; + auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); + size_t output_num = output_size / sizeof(float); + std::cout << "output num : " << output_num << "\n"; + return CompareOutputRelativeData(output_data, ground_truth, output_num); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/file_utils_ext.h b/mindspore/lite/src/common/file_utils_ext.h new file mode 100644 index 0000000000..28eea02e41 --- /dev/null +++ b/mindspore/lite/src/common/file_utils_ext.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_ +#define MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_ +#include + + +namespace mindspore { +namespace lite { +int CompareRelativeOutput(float *output_data, std::string file_path); + +} +} // namespace mindspore +#endif // MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_ diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc new file mode 100755 index 0000000000..e094c744ab --- /dev/null +++ b/mindspore/lite/src/common/graph_util.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/common/graph_util.h" +#include "src/common/utils.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +std::vector GetGraphInputNodes(const schema::MetaGraph *meta_graph) { + MS_ASSERT(nullptr != meta_graph); + std::vector ret; + for (size_t i = 0; i < meta_graph->inputIndex()->size(); i++) { + auto input_index = meta_graph->inputIndex()->GetAs(i); + for (size_t j = 0; j < meta_graph->nodes()->size(); j++) { + auto *cNode = meta_graph->nodes()->GetAs(j); + MS_ASSERT(nullptr != cNode); + for (size_t k = 0; k < cNode->inputIndex()->size(); k++) { + if (cNode->inputIndex()->GetAs(k) == input_index) { + if (!IsContain(ret, j)) { + ret.emplace_back(j); + } + break; + } + } + } + } + return std::move(ret); +} + +std::vector GetGraphOutputNodes(const schema::MetaGraph *meta_graph) { + MS_ASSERT(nullptr != meta_graph); + std::vector ret; + for (size_t i = 0; i < meta_graph->outputIndex()->size(); i++) { + auto output_index = meta_graph->outputIndex()->GetAs(i); + for (size_t j = 0; j < meta_graph->nodes()->size(); j++) { + auto *cNode = meta_graph->nodes()->GetAs(j); + MS_ASSERT(nullptr != cNode); + for (size_t k = 0; k < cNode->outputIndex()->size(); k++) { + if (cNode->outputIndex()->GetAs(k) == output_index) { + if (!IsContain(ret, j)) { + ret.emplace_back(j); + } + break; + } + } + } + } + return std::move(ret); +} + +// NODE_ID OpNode::ID() { return id; } +// +// void OpNode::AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); } +// +// void OpNode::AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); } +// +// std::unordered_set OpNode::GetAllInEdges() { return inEdges; } +// +// std::unordered_set OpNode::GetAllOutEdges() { return outEdges; } + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/graph_util.h b/mindspore/lite/src/common/graph_util.h new file mode 100755 index 0000000000..e9a9e994fe --- /dev/null +++ b/mindspore/lite/src/common/graph_util.h @@ -0,0 +1,250 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ +#define MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ + +#include +#include +#include +#include +#include +#include "schema/model_generated.h" +#include "utils//log_adapter.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +using NODE_ID = std::string; + +std::vector GetGraphInputNodes(const schema::MetaGraph *meta_graph); + +std::vector GetGraphOutputNodes(const schema::MetaGraph *meta_graph); + +class OpNode { + public: + explicit OpNode(const NODE_ID &nodeId) : id(nodeId) {} + NODE_ID ID() { return id; }; + void AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); } + void AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); } + std::unordered_set GetAllInEdges() { return inEdges; } + std::unordered_set GetAllOutEdges() { return outEdges; } + + protected: + NODE_ID id; + std::unordered_set inEdges; + std::unordered_set outEdges; +}; + + +template +class OpGraph { + public: + OpGraph() {} + + ~OpGraph(); + + int Build(const schema::MetaGraph *subGraphDef); + NODE_T *GetNode(NODE_ID nodeId); + NODE_T *AddNode(NODE_ID nodeId); + std::unordered_set GetInputNode(); + std::unordered_set GetOutputNode(); + + void AddNodes(std::vector addNodes); + void DeleteNodes(std::vector deleteNodes); + + void AddEdge(NODE_ID nodeId); + int AddEdge(NODE_ID srcId, NODE_ID dstId); + int AddEdge(const schema::CNode *srcNodeDef, const flatbuffers::Vector> *opDefs); + std::unordered_map> GetDepends(); + + protected: + std::unordered_map nodes; +}; + +template +int OpGraph::Build(const schema::MetaGraph *subGraphDef) { + if (subGraphDef == nullptr) { + // MS_LOGE("subGraphDef is nullptr"); + return RET_ERROR; + } + + + auto opDefs = subGraphDef->nodes(); + + uint32_t opCount = opDefs->size(); + for (uint32_t i = 0; i < opCount; i++) { + auto opDef = opDefs->GetAs(i); + auto node = AddNode(std::string(opDef->name()->c_str())); + if (node == nullptr) { + // MS_LOGE("add srcNode failed,name %s", opDef->name()->c_str()); + return RET_ERROR; + } + auto ret = AddEdge(opDef, opDefs); + if (ret != RET_OK) { + // MS_LOGE("%s add edge failed. ret:%d", opDef->name()->c_str(), ret); + return RET_ERROR; + } + } + + return RET_OK; +} +template +int OpGraph::AddEdge(const schema::CNode *srcNodeDef, + const flatbuffers::Vector> *nodeDefs) { + MS_ASSERT(srcNodeDef != nullptr); + MS_ASSERT(nodeDefs != nullptr); + NODE_ID srcId = std::string(srcNodeDef->name()->c_str()); + uint32_t opCount = nodeDefs->size(); + // for single op condition + AddNode(srcId); + for (auto index : *(srcNodeDef->outputIndex())) { + for (uint32_t i = 0; i < opCount; i++) { + auto dstNodeDef = nodeDefs->GetAs(i); + bool find = false; + auto inputIndex = dstNodeDef->inputIndex(); + if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) { + find = true; + } + + if (!find) { + continue; + } + NODE_ID dstId = std::string(dstNodeDef->name()->c_str()); + auto ret = AddEdge(srcId, dstId); + if (ret != RET_OK) { + return ret; + } + } + } + + return RET_OK; +} + +template +int OpGraph::AddEdge(NODE_ID srcId, NODE_ID dstId) { + auto srcNode = AddNode(srcId); + if (srcNode == nullptr) { + // MS_LOGE("add srcNode failed"); + return RET_ERROR; + } + auto dstNode = AddNode(dstId); + if (dstNode == nullptr) { + // MS_LOGE("add dstNode failed"); + return RET_ERROR; + } + + srcNode->AddOutEdge(dstNode); + + dstNode->AddInEdge(srcNode); + return RET_OK; +} + +template +NODE_T *OpGraph::GetNode(NODE_ID nodeId) { + auto node = nodes.find(nodeId); + if (node == nodes.end()) { + return nullptr; + } + return node->second; +} + +template +NODE_T *OpGraph::AddNode(NODE_ID nodeId) { + auto node = GetNode(nodeId); + if (node != nullptr) { + return node; + } + node = new (std::nothrow) NODE_T(nodeId); + if (node == nullptr) { + // MS_LOGE("new node failed"); + return nullptr; + } + nodes[nodeId] = node; + return node; +} + +template +void OpGraph::AddNodes(std::vector addNodes) { + for (auto node : addNodes) { + if (node == nullptr) { + return; + } + + nodes[node->ID()] = node; + } +} + +template +void OpGraph::DeleteNodes(std::vector deleteNodes) { + for (auto deletenode : deleteNodes) { + if (deletenode == nullptr) { + continue; + } + auto node = GetNode(deletenode->ID()); + if (node == nullptr) { + continue; + } + nodes.erase(deletenode->ID()); + } +} + +template +std::unordered_set OpGraph::GetInputNode() { + std::unordered_set inputNodes; + for (const auto &iter : nodes) { + auto node = iter.second; + if (node->GetAllInEdges().empty()) { + inputNodes.insert(node); + } + } + return inputNodes; +} + +template +std::unordered_set OpGraph::GetOutputNode() { + std::unordered_set outputNodes; + for (const auto &iter : nodes) { + auto node = iter.second; + if (node->GetAllOutEdges().empty()) { + outputNodes.insert(node); + } + } + return outputNodes; +} + +template +std::unordered_map> OpGraph::GetDepends() { + std::unordered_map> depends; + for (auto nodeIter : nodes) { + depends[nodeIter.second] = nodeIter.second->GetAllInEdges(); + } + return depends; +} + +template +OpGraph::~OpGraph() { + for (auto iter : nodes) { + delete iter.second; + } + nodes.clear(); +} + +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ + diff --git a/mindspore/lite/src/common/graph_utils_extends.cc b/mindspore/lite/src/common/graph_utils_extends.cc new file mode 100644 index 0000000000..7bf2993de9 --- /dev/null +++ b/mindspore/lite/src/common/graph_utils_extends.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/graph_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/visitor.h" +#include "ir/func_graph.h" + #include "utils/label.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace { +class DeepFirstSearcher { + public: + explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {} + ~DeepFirstSearcher() = default; + + std::vector Search(const AnfNodePtr &root) { + if (root == nullptr) { + return res_; + } + seen_ = NewSeenGeneration(); + Visit(root); + return res_; + } + + void Visit(const AnfNodePtr &node) { + if (node == nullptr) { + return; + } + if (node->seen_ == seen_) { + return; + } + + node->seen_ = seen_; + + auto incl = include_(node); + if (incl == EXCLUDE) { + return; + } + if (filter_ == nullptr || !filter_(node)) { + res_.push_back(node); + } + if (incl == FOLLOW) { + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { + Visit(*iter); + } + return; + } + } + } + + private: + size_t seen_{0}; + IncludeFunc include_; + FilterFunc filter_; + std::vector res_{}; +}; + +class DeepScopedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepScopedGraphSearcher() = default; + + void Visit(const CNodePtr &cnode) { return; } + + void Visit(const ValueNodePtr &vnode) { + if (!IsValueNode(vnode)) { + return; + } + + auto graph = GetValueNode(vnode); + AnfNodePtr ret = graph->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } + + void Visit(const ParameterPtr ¶m) { + if (param->func_graph() == nullptr) { + return; + } + + AnfNodePtr ret = param->func_graph()->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } +}; + +class DeepUsedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepUsedGraphSearcher() = default; + + void Visit(const CNodePtr &cnode) { return; } + + void Visit(const ValueNodePtr &vnode) { return; } +}; + +class DeepLinkedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepLinkedGraphSearcher() = default; + + void Visit(const CNodePtr &cnode) { return; } + + void Visit(const ValueNodePtr &) {} +}; +} // namespace + +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepScopedGraphSearcher(include).Search(root); +} + +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepUsedGraphSearcher(include).Search(root); +} + +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepLinkedGraphSearcher(include).Search(root); +} + +} // namespace mindspore + diff --git a/mindspore/lite/src/common/ms_tensor_utils.cc b/mindspore/lite/src/common/ms_tensor_utils.cc new file mode 100644 index 0000000000..44d04afbc8 --- /dev/null +++ b/mindspore/lite/src/common/ms_tensor_utils.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/ms_tensor_utils.h" + +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace tensor { +using mindspore::lite::tensor::LiteTensor; +using mindspore::lite::tensor::Tensor; + +std::vector PackToMSTensors(const std::vector &in_tensors) { + std::vector ret; + for (auto *lite_tensor : in_tensors) { + MS_ASSERT(lite_tensor != nullptr); + auto *ms_tensor = new (std::nothrow) LiteTensor(lite_tensor); + if (ms_tensor == nullptr) { + MS_LOG(ERROR) << "new LiteTensor failed"; + return ret; + } + ret.emplace_back(ms_tensor); + } + return ret; +} +} // namespace tensor +} // namespace mindspore diff --git a/mindspore/lite/src/common/ms_tensor_utils.h b/mindspore/lite/src/common/ms_tensor_utils.h new file mode 100644 index 0000000000..fc68d0e951 --- /dev/null +++ b/mindspore/lite/src/common/ms_tensor_utils.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MS_TENSOR_UTILS_H +#define LITE_MS_TENSOR_UTILS_H + +#include +#include "include/ms_tensor.h" +#include "src/ir/tensor.h" + +namespace mindspore { +namespace tensor { +std::vector PackToMSTensors(const std::vector &in_tensors); +} +} // namespace mindspore + +#endif // LITE_MS_TENSOR_UTILS_H diff --git a/mindspore/lite/src/common/op_utils.h b/mindspore/lite/src/common/op_utils.h new file mode 100755 index 0000000000..68a4217114 --- /dev/null +++ b/mindspore/lite/src/common/op_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_OP_UTILS_H_ +#define MINDSPORE_LITE_COMMON_OP_UTILS_H_ + +#include +#include +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) { return opDef.primitive()->value_type(); } +inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); } +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_OP_UTILS_H_ + diff --git a/mindspore/lite/src/common/utils.cc b/mindspore/lite/src/common/utils.cc new file mode 100644 index 0000000000..bb2e1e9c2b --- /dev/null +++ b/mindspore/lite/src/common/utils.cc @@ -0,0 +1,262 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef __ANDROID__ +#include +#endif +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +std::vector StringSplit(std::string str, const std::string& pattern) { + std::vector result; + if (str.empty()) { + return result; + } + std::string::size_type pos; + str += pattern; + auto size = str.size(); + + for (size_t i = 0; i < size; i++) { + pos = str.find(pattern, i); + if (pos < size) { + std::string s = str.substr(i, pos - i); + result.push_back(s); + i = pos + pattern.size() - 1; + } + } + return result; +} + +uint64_t GetTimeUs() { + struct timespec ts = {0, 0}; + if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { + return 0; + } + // USECS_IN_SEC *NSECS_IN_USEC; + uint64_t retval = static_cast((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); + return retval; +} + +static const unsigned int FP32_BIT_SIZE = 32; +static const unsigned int FP32_EXPONENT_BIAS = 127; +static const unsigned int FP32_SIGNIFICAND = 23; + +static const unsigned int FP32_EXPONENT_MAX = 255; + +static const unsigned int FP16_BIT_SIZE = 16; +static const unsigned int FP16_EXPONENT_BIAS = 15; +static const unsigned int FP16_SIGNIFICAND = 10; + +static const int FP16_EXPONENT_MAX = 30; +static const int FP16_EXPONENT_MIN = -10; + +// fp16.c +float ShortToFloat32(int16_t srcValue) { + uint16_t expHalf16 = srcValue & 0x7C00; + int exp1 = static_cast(expHalf16); + uint16_t mantissa16 = srcValue & 0x03FF; + int mantissa1 = static_cast(mantissa16); + int sign = static_cast(srcValue & 0x8000); + sign = sign << FP16_BIT_SIZE; + + // nan or inf + if (expHalf16 == 0x7C00) { + // nan + if (mantissa16 > 0) { + int res = (0x7FC00000 | sign); + int *iRes = &res; + auto fres = static_cast(*iRes); + return fres; + } + // inf + int res = (0x7F800000 | sign); + int *iRes = &res; + auto fres = static_cast(*iRes); + return fres; + } + if (expHalf16 != 0) { + exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND); // exponents converted to float32 bias + int res = (exp1 | mantissa1); + res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND); + res = (res | sign); + int *iRes = &res; + auto fres = static_cast(*iRes); + return fres; + } + + int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND); + xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND) + << FP32_SIGNIFICAND); // add the bias difference to xmm1 + xmm1 = xmm1 | sign; // Combine with the sign mask + + auto res = static_cast(mantissa1); // Convert mantissa to float + int *ixmm1 = nullptr; + ixmm1 = &xmm1; + res *= static_cast(*ixmm1); + + return res; +} + +// __gnu_f2h_ieee +int16_t Float32ToShort(float srcValue) { + float *psrcValue = nullptr; + psrcValue = &srcValue; + auto srcValueBit = static_cast(*psrcValue); + int sign = srcValueBit >> (FP32_BIT_SIZE - 1); + int mantissa = srcValueBit & 0x007FFFFF; + // exponent + int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS; + int16_t res; + if (exp > 0 && exp < FP16_EXPONENT_MAX) { + // use rte rounding mode, round the significand, combine sign, exponent and significand into a short. + res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) | + ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } else if (srcValueBit == 0) { + res = 0; + } else { + if (exp <= 0) { + if (exp < FP16_EXPONENT_MIN) { + // value is less than min half float point + res = 0; + } else { + // normalized single, magnitude is less than min normal half float point. + mantissa = (mantissa | 0x00800000) >> (1 - exp); + // round to nearest + if ((mantissa & 0x00001000) > 0) { + mantissa = mantissa + 0x00002000; + } + // combine sign & mantissa (exp is zero to get denormalized number) + res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) { + if (mantissa == 0) { + // input float is infinity, return infinity half + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; + } else { + // input float is NaN, return half NaN + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } else { + // exp > 0, normalized single, round to nearest + if ((mantissa & 0x00001000) > 0) { + mantissa = mantissa + 0x00002000; + if ((mantissa & 0x00800000) > 0) { + mantissa = 0; + exp = exp + 1; + } + } + if (exp > FP16_EXPONENT_MAX) { + // exponent overflow - return infinity half + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; + } else { + // combine sign, exp and mantissa into normalized half + res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) | + (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } + } + return res; +} +std::string Remove(const std::string &from, const std::string &subStr, Mode mode) { + std::string result = from; + if (mode == PREFIX) { + if (from.substr(0, subStr.length()) == subStr) { + result = from.substr(subStr.size()); + } + } else if (mode == SUFFIX) { + if (from.rfind(subStr) == from.size() - subStr.size()) { + result = from.substr(0, from.size() - subStr.size()); + } + } else { + size_t index; + while ((index = result.find(subStr)) != std::string::npos) { + result = result.erase(index, subStr.size()); + } + } + + return result; +} + +std::vector StrSplit(const std::string &str, const std::string &pattern) { + std::string::size_type pos; + std::vector result; + std::string tmpStr(str + pattern); + std::string::size_type size = tmpStr.size(); + + for (std::string::size_type i = 0; i < size; i++) { + pos = tmpStr.find(pattern, i); + if (pos < size) { + std::string s = tmpStr.substr(i, pos - i); + result.push_back(s); + i = pos + pattern.size() - 1; + } + } + return result; +} + +std::vector Tokenize(const std::string &src, const std::string &delimiters, + const Option &maxTokenNum) { + if (maxTokenNum.IsSome() && maxTokenNum.Get() == 0) { + return {}; + } + + std::vector tokens; + size_t offset = 0; + + while (true) { + size_t nonDelimiter = src.find_first_not_of(delimiters, offset); + if (nonDelimiter == std::string::npos) { + break; + } + size_t delimiter = src.find_first_of(delimiters, nonDelimiter); + if (delimiter == std::string::npos || (maxTokenNum.IsSome() && tokens.size() == maxTokenNum.Get() - 1)) { + tokens.push_back(src.substr(nonDelimiter)); + break; + } + + tokens.push_back(src.substr(nonDelimiter, delimiter - nonDelimiter)); + offset = delimiter; + } + return tokens; +} + +void ShortToFloat32(const int16_t *srcdata, float *dstdata, size_t elementSize) { + MS_ASSERT(srcdata != nullptr); + MS_ASSERT(dstdata != nullptr); + for (size_t i = 0; i < elementSize; i++) { + dstdata[i] = ShortToFloat32(srcdata[i]); + } +} + +void Float32ToShort(const float *srcdata, int16_t *dstdata, size_t elementSize) { + MS_ASSERT(srcdata != nullptr); + MS_ASSERT(dstdata != nullptr); + for (size_t i = 0; i < elementSize; i++) { + dstdata[i] = Float32ToShort(srcdata[i]); + } +} + +#if defined(__ANDROID__) +uint32_t getHwCap(int hwcap_type) { + uint32_t ret = getauxval(hwcap_type); + return ret; +} +#endif +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h new file mode 100644 index 0000000000..b6d28d8992 --- /dev/null +++ b/mindspore/lite/src/common/utils.h @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_UTILS_H_ +#define MINDSPORE_LITE_COMMON_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "tools/common/option.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +const int USEC = 1000000; +const int MSEC = 1000; +std::vector StringSplit(std::string str, const std::string& pattern); + +uint64_t GetTimeUs(void); + +int16_t Float32ToShort(float srcValue); + +float ShortToFloat32(int16_t srcValue); + +void ShortToFloat32(const int16_t *srcdata, float *dstdata, size_t elementSize); + +void Float32ToShort(const float *srcdata, int16_t *dstdata, size_t elementSize); + +#if defined(__arm__) || defined(__aarch64__) +uint32_t getHwCap(int hwcap_type); +#endif + +template +bool IsContain(const std::vector &vec, T element) { + for (auto iter = vec.begin(); iter != vec.end(); iter++) { + if (*iter == element) { + return true; + } + } + return false; +} + +template +bool VectorErase(std::vector *vec, T element) { + bool ret = false; + for (auto iter = vec->begin(); iter != vec->end();) { + if (*iter == element) { + iter = vec->erase(iter); + ret = true; + } else { + iter++; + } + } + return ret; +} + +template +bool VectorReplace(std::vector *vec, T srcElement, T dstElement) { + bool ret = false; + for (auto iter = vec->begin(); iter != vec->end(); iter++) { + if (*iter == srcElement) { + if (!IsContain(*vec, dstElement)) { + *iter = std::move(dstElement); + } else { + vec->erase(iter); + } + ret = true; + break; + } + } + return ret; +} + +const char WHITESPACE[] = "\t\n\v\f\r "; +const char STR_TRUE[] = "true"; +const char STR_FALSE[] = "false"; + +template +Option ToString(T t) { + std::ostringstream out; + out << t; + if (!out.good()) { + return Option(None()); + } + + return Option(out.str()); +} + +template <> +inline Option ToString(bool value) { + return value ? Option(STR_TRUE) : Option(STR_FALSE); +} + +// get the file name from a given path +// for example: "/usr/bin", we will get "bin" +inline std::string GetFileName(const std::string &path) { + char delim = '/'; + + size_t i = path.rfind(delim, path.length()); + if (i != std::string::npos) { + return (path.substr(i + 1, path.length() - i)); + } + + return ""; +} + +// trim the white space character in a string +// see also: macro WHITESPACE defined above +inline void Trim(std::string *input) { + if (input == nullptr) { + return; + } + if (input->empty()) { + return; + } + + input->erase(0, input->find_first_not_of(WHITESPACE)); + input->erase(input->find_last_not_of(WHITESPACE) + 1); +} + +// to judge whether a string is starting with prefix +// for example: "hello world" is starting with "hello" +inline bool StartsWithPrefix(const std::string &source, const std::string &prefix) { + if (source.length() < prefix.length()) { + return false; + } + + return (source.compare(0, prefix.length(), prefix) == 0); +} + +// split string +std::vector StrSplit(const std::string &str, const std::string &pattern); + +// tokenize string +std::vector Tokenize(const std::string &src, const std::string &delimiters, + const Option &maxTokenNum = Option(None())); + +enum Mode { PREFIX, SUFFIX, ANY }; + +// remove redundant charactor +std::string Remove(const std::string &from, const std::string &subStr, Mode mode = ANY); + +template +inline Option GenericParseValue(const std::string &value) { + T ret; + std::istringstream input(value); + input >> ret; + + if (input && input.eof()) { + return Option(ret); + } + + return Option(None()); +} + +template <> +inline Option GenericParseValue(const std::string &value) { + return Option(value); +} + +template <> +inline Option GenericParseValue(const std::string &value) { + if (value == "true") { + return Option(true); + } else if (value == "false") { + return Option(false); + } + + return Option(None()); +} +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_UTILS_H_ + diff --git a/mindspore/lite/src/context.cc b/mindspore/lite/src/context.cc new file mode 100644 index 0000000000..5145a738b7 --- /dev/null +++ b/mindspore/lite/src/context.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/context.h" +#include "src/runtime/allocator.h" + +namespace mindspore::lite { +Context::Context() { allocator = Allocator::Create(); } + +Context::~Context() = default; + +Context::Context(int thread_num, std::shared_ptr allocator, DeviceContext device_ctx) { + this->allocator = std::move(allocator); + this->thread_num_ = thread_num; + this->device_ctx_ = device_ctx; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc new file mode 100644 index 0000000000..a82d7cd1ed --- /dev/null +++ b/mindspore/lite/src/executor.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/executor.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "include/errorcode.h" +#include "src/common/ms_tensor_utils.h" + +namespace mindspore::lite { +int Executor::Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator, + const session::KernelCallBack &before, const session::KernelCallBack &after) { + MS_ASSERT(nullptr != allocator); + for (auto &inTensor : inputs) { + if (inTensor == nullptr) { + MS_LOG(ERROR) << "Graph input tensor is nullptr"; + return RET_ERROR; + } + if (inTensor->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "Model input tensor should be NHWC"; + return RET_ERROR; + } + } + kernel::LiteKernelUtil::InitTensorRefCount(kernels); + for (auto *kernel : kernels) { + MS_ASSERT(nullptr != kernel); + auto &outputs = kernel->GetOutputs(); + for (auto *output : outputs) { + MS_ASSERT(nullptr != output); + output->MallocData(); + } + session::CallBackParam callbackParam; + callbackParam.name_callback_param = kernel->Name(); + callbackParam.type_callback_param = kernel->type_str(); + + if (before != nullptr) { + if (!before(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()), callbackParam)) { + MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->Name(); + } + } + auto ret = kernel->Run(); + if (0 != ret) { + MS_LOG(ERROR) << "run kernel failed, name: " << kernel->Name(); + return ret; + } + + if (after != nullptr) { + if (!after(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()), callbackParam)) { + MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->Name(); + } + } + for (auto input_kernel : kernel->GetInKernels()) { + MS_EXCEPTION_IF_NULL(input_kernel); + ret = input_kernel->DecOutTensorRefCount(); + if (0 != ret) { + MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed"; + } + } + } + return RET_OK; +} + +int Executor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator); + MS_ASSERT(4 == tensor->shape().size()); + auto data_type = tensor->data_type(); + switch (data_type) { + case kNumberTypeInt8: + return TransformTensorLayoutUint8(tensor, dst_format, allocator); + case kNumberTypeFloat32: + return TransformTensorLayoutFp32(tensor, dst_format, allocator); + default: + return RET_ERROR; + } + return RET_OK; +} + +int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator); + MS_ASSERT(4 == tensor->shape().size()); + auto src_format = tensor->GetFormat(); + if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC) { + auto *src_data = tensor->Data(); + auto *dst_data = allocator->Malloc(tensor->Size()); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + PackNC4HW4ToNHWCFp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); + tensor->SetData(dst_data); + tensor->SetFormat(dst_format); + allocator->Free(src_data); + return RET_OK; + } else { + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in float32"; + return RET_ERROR; + } +} + +int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator); + MS_ASSERT(4 == tensor->shape().size()); + // auto src_format = tensor->GetFormat(); + // todo + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in uint8"; + return RET_ERROR; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/executor.h b/mindspore/lite/src/executor.h new file mode 100644 index 0000000000..5ee084fa5f --- /dev/null +++ b/mindspore/lite/src/executor.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXECUTOR_H_ +#define MINDSPORE_LITE_SRC_EXECUTOR_H_ + +#include +#include "src/runtime/allocator.h" +#include "src/lite_kernel.h" +#include "include/lite_session.h" + +namespace mindspore::lite { +class Executor { + public: + Executor() = default; + + int Prepare(std::vector &kernels) { return 0; } + + int Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator = nullptr, + const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr); + + protected: + int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr); + + int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr); + + int TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr); + + protected: + Context *context = nullptr; +}; + +} // namespace mindspore::lite +#endif diff --git a/mindspore/lite/src/ir/meta_tensor_extends.cc b/mindspore/lite/src/ir/meta_tensor_extends.cc new file mode 100644 index 0000000000..3e5851ba33 --- /dev/null +++ b/mindspore/lite/src/ir/meta_tensor_extends.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/meta_tensor.h" + +namespace mindspore { +namespace tensor { +abstract::AbstractBasePtr MetaTensor::ToAbstract() { + MS_LOG(ERROR) << "MetaTensor ToAbstract is not implemented"; + return nullptr; +} +TypePtr MetaTensor::Dtype() const { return nullptr; } +} // namespace tensor +} // namespace mindspore + diff --git a/mindspore/lite/src/ir/primitive_t_value.cc b/mindspore/lite/src/ir/primitive_t_value.cc new file mode 100644 index 0000000000..9c27cc66fd --- /dev/null +++ b/mindspore/lite/src/ir/primitive_t_value.cc @@ -0,0 +1,17 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ir/primitive_t_value.h" diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h new file mode 100644 index 0000000000..b13f4606eb --- /dev/null +++ b/mindspore/lite/src/ir/primitive_t_value.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ + +#include +#include "ir/value.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore::lite { + +class PrimitiveTValue : public Value { + public: + explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {} + // not responsible to free primitive, the one created the dynamic memory is responsible to free it. + ~PrimitiveTValue() override = default; + + MS_DECLARE_PARENT(PrimitiveTValue, Value) + + schema::PrimitiveT *GetPrimitiveT() const { return this->primitive; } + + void SetPrimitiveT(schema::PrimitiveT *primIn) { this->primitive = primIn; } + + bool operator==(const Value &rhs) const override { + if (rhs.isa()) { + auto other_prim = static_cast(rhs); + auto a = this->primitive->value.type; + auto b = other_prim.primitive->value.type; + return a == b; + } else { + return false; + } + } + + void AddInputQuantParam(schema::QuantParamT quant_param) { + this->input_quant_param_.emplace_back(quant_param); + } + std::vector GetInputQuantParams() const { + return input_quant_param_; + } + + void AddOutputQuantParam(schema::QuantParamT quant_param) { + this->output_quant_param_.emplace_back(quant_param); + } + std::vector GetOutputQuantParams() const { + return output_quant_param_; + } + + void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } + + schema::QuantType GetQuantType() const { return quant_type_; } + + protected: + schema::PrimitiveT *primitive = nullptr; + std::vector input_quant_param_; + std::vector output_quant_param_; + schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ + diff --git a/mindspore/lite/src/ir/primitive_value.cc b/mindspore/lite/src/ir/primitive_value.cc new file mode 100644 index 0000000000..ebd5d4d615 --- /dev/null +++ b/mindspore/lite/src/ir/primitive_value.cc @@ -0,0 +1,19 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ir/primitive_value.h" + + diff --git a/mindspore/lite/src/ir/primitive_value.h b/mindspore/lite/src/ir/primitive_value.h new file mode 100644 index 0000000000..66202d15e6 --- /dev/null +++ b/mindspore/lite/src/ir/primitive_value.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_ + +#include "ir/value.h" +#include "src/ops/ops.h" + +namespace mindspore::lite { +class PrimitiveValue : public Value { + public: + explicit PrimitiveValue(const lite::Primitive *prim) : primitive(prim) {} + + const lite::Primitive *GetPrimitive() const { + return this->primitive; + } + MS_DECLARE_PARENT(PrimitiveValue, Value) + bool operator==(const Value &rhs) const override { + if (rhs.isa()) { + auto other_prim = static_cast(rhs); + return *this == other_prim; + } else { + return false; + } + } + + protected: + const lite::Primitive *primitive = nullptr; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_ + diff --git a/mindspore/lite/src/ir/tensor.cc b/mindspore/lite/src/ir/tensor.cc new file mode 100644 index 0000000000..cd9fb766d6 --- /dev/null +++ b/mindspore/lite/src/ir/tensor.cc @@ -0,0 +1,322 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ir/tensor.h" +#include "securec/include/securec.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +namespace tensor { +#define kMaxMallocSize 1024 * 1024 * 100 +Tensor::Tensor(const TypeId data_type, const std::vector &shape, const schema::Format &format, + schema::NodeType tensorType) + : MetaTensor(data_type, shape), format_(format), tensorType(tensorType) {} + +Tensor::Tensor(const Tensor &tensor) : MetaTensor(tensor) { + auto ret = CopyTensor(tensor, true); + if (0 != ret) { + MS_LOG(EXCEPTION) << "CopyTensorData error"; + } +} + +int Tensor::CopyTensorData(const Tensor &srcTensor) { + if (srcTensor.data_ == nullptr) { + MS_LOG(ERROR) << "data of srcTensor is nullptr"; + return mindspore::lite::RET_PARAM_INVALID; + } + size_t data_size = this->Size(); + MS_ASSERT(data_size == srcTensor.Size()); + if (this->data_ == nullptr) { + if (data_size > kMaxMallocSize) { + MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes"; + return mindspore::lite::RET_ERROR; + } + this->data_ = malloc(data_size); + } + memcpy(this->data_, srcTensor.data_, data_size); + return 0; +} + +int Tensor::CopyTensor(const Tensor &srcTensor, bool copyData) { + this->data_type_ = srcTensor.data_type_; + this->shape_ = srcTensor.shape_; + this->tensorType = srcTensor.tensorType; + if (copyData) { + auto ret = CopyTensorData(srcTensor); + if (0 != ret) { + MS_LOG(ERROR) << "CopyTensorData error"; + return mindspore::lite::RET_ERROR; + } + } + return 0; +} + +Tensor::~Tensor() { + if (nullptr != this->data_) { + if (this->allocator_ != nullptr) { + this->allocator_->Free(this->data_); + } else { + free(this->data_); + } + } +} + +Tensor &Tensor::operator=(const Tensor &tensor) { + if (&tensor == this) { + return *this; + } + auto ret = CopyTensor(tensor, true); + if (0 != ret) { + MS_LOG(ERROR) << "CopyTensorData error"; + MS_ASSERT(false); + } + return *this; +} + +bool Tensor::operator==(const Tensor &tensor) { + return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_; +} + +bool Tensor::operator==(const Value &other) const { + if (other.isa()) { + auto other_ = static_cast(other); + return *this == other_; + } else { + return false; + } +} + +int32_t Tensor::Batch() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NCHW: + case schema::Format_NC4HW4: + case schema::Format_KCHW: + case schema::Format_KHWC: + return this->shape_[0]; + case schema::Format_HWCK: + case schema::Format_CHWK: + return this->shape_[3]; + case schema::Format_HWKC: + return this->shape_[2]; + case schema::Format_CKHW: + return this->shape_[1]; + default: + MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); + return -1; + } +} + +int32_t Tensor::Channel() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NCHW: + case schema::Format_KCHW: + return this->shape_[1]; + case schema::Format_HWCK: + return this->shape_[2]; + case schema::Format_HWKC: + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NC4HW4: + case schema::Format_KHWC: + return this->shape_[3]; + case schema::Format_CKHW: + case schema::Format_CHWK: + return this->shape_[0]; + default: + return -1; + } +} + +int32_t Tensor::Height() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NCHW: + case schema::Format_KCHW: + case schema::Format_CKHW: + return this->shape_[2]; + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NC4HW4: + case schema::Format_KHWC: + case schema::Format_CHWK: + return this->shape_[1]; + case schema::Format_HWCK: + case schema::Format_HWKC: + return this->shape_[0]; + default: + MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); + return -1; + } +} + +int32_t Tensor::Width() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NCHW: + case schema::Format_KCHW: + case schema::Format_CKHW: + return this->shape_[3]; + case schema::Format_KHWC: + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NC4HW4: + case schema::Format_CHWK: + return this->shape_[2]; + case schema::Format_HWCK: + case schema::Format_HWKC: + return this->shape_[1]; + default: + return -1; + } +} + +std::string Tensor::ToString() const { + std::ostringstream oss; + oss << "Format: " << schema::EnumNameFormat(this->format_); + oss << " DataType: " << this->data_type_; + oss << " NodeType: " << schema::EnumNameNodeType(this->tensorType); + oss << " Shape:"; + for (auto &dim : this->shape()) { + oss << " " << dim; + } + oss << std::endl << "Data:"; + switch (this->data_type_) { + case kNumberTypeFloat32: { + auto data = static_cast(this->data_); + if (data == nullptr) { + return "Data of tensor is nullptr"; + } else { + for (size_t i = 0; i < 40 && i < this->ElementsNum(); i++) { + oss << " " << data[i]; + } + } + } break; + case kNumberTypeInt32: { + auto data = static_cast(this->data_); + if (data == nullptr) { + return "Data of tensor is nullptr"; + } else { + for (size_t i = 0; i < 40 && i < this->ElementsNum(); i++) { + oss << " " << data[i]; + } + } + } break; + default: + oss << "Unsupport data type to print"; + break; + } + return oss.str(); +} + +void Tensor::AddQuantParam(const tensor::QuantArg &quant_arg) { this->quant_params_.push_back(quant_arg); } + +std::vector Tensor::GetQuantParams() const { return this->quant_params_; } + +LiteTensor::LiteTensor() { this->tensor_impl_ = new tensor::Tensor(); } + +LiteTensor::LiteTensor(TypeId data_type, const std::vector &shape) { + this->tensor_impl_ = new tensor::Tensor(data_type, shape); +} + +LiteTensor::LiteTensor(tensor::Tensor *tensor_ptr) { this->tensor_impl_ = tensor_ptr; } + +TypeId LiteTensor::data_type() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->data_type(); +} + +TypeId LiteTensor::set_data_type(TypeId data_type) { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->set_data_type(data_type); +} + +std::vector LiteTensor::shape() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->shape(); +} + +size_t LiteTensor::set_shape(const std::vector &shape) { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->set_shape(shape); +} + +int LiteTensor::DimensionSize(size_t index) const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->DimensionSize(index); +} + +int LiteTensor::ElementsNum() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->ElementsNum(); +} + +std::size_t LiteTensor::hash() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->hash(); +} + +tensor::Tensor *LiteTensor::tensor() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_; +} + +size_t LiteTensor::Size() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->Size(); +} + +void *LiteTensor::MutableData() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + auto data = this->tensor_impl_->Data(); + if (nullptr == data) { + auto ret = tensor_impl_->MallocData(); + if (0 != ret) { + return nullptr; + } + } + return this->tensor_impl_->Data(); +} +LiteTensor::~LiteTensor() { delete this->tensor_impl_; } + +void LiteTensor::SetTensorImpl(tensor::Tensor *tensor) { this->tensor_impl_ = tensor; } +} // namespace tensor +} // namespace lite +namespace tensor { +MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector &shape) { + return new mindspore::lite::tensor::LiteTensor(data_type, shape); +} +} // namespace tensor +} // namespace mindspore diff --git a/mindspore/lite/src/ir/tensor.h b/mindspore/lite/src/ir/tensor.h new file mode 100644 index 0000000000..3585633c0a --- /dev/null +++ b/mindspore/lite/src/ir/tensor.h @@ -0,0 +1,229 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_IR_TENSOR_H_ +#define MINDSPORE_LITE_SRC_IR_TENSOR_H_ + +#include +#include +#include +#include "ir/meta_tensor.h" +#include "include/ms_tensor.h" +#include "ir/dtype/type_id.h" +#include "src/runtime/allocator.h" +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +namespace tensor { + +struct QuantArg { + double scale; + int32_t zeroPoint; +}; + +class Tensor : public mindspore::tensor::MetaTensor { + public: + Tensor() : MetaTensor() {} + + Tensor(const TypeId data_type, const std::vector &shape, const schema::Format &format = schema::Format_NHWC, + schema::NodeType tensorType = schema::NodeType_Parameter); + + Tensor(const Tensor &tensor); + + ~Tensor() override; + + int CopyTensorData(const Tensor &srcTensor); + + int CopyTensor(const Tensor &srcTensor, bool copyData = false); + + MS_DECLARE_PARENT(Tensor, MetaTensor) + + virtual Tensor &operator=(const Tensor &tensor); + + virtual bool operator==(const Tensor &tensor); + + bool operator==(const Value &other) const override; + + int32_t Batch() const; + + int32_t Channel() const; + + int32_t Height() const; + + int32_t Width() const; + + int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); } + + int DataSize() const { return this->ElementsNum(); } + + size_t Size() const { + size_t size = 0; + switch (this->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + size = sizeof(float); + break; + case kNumberTypeInt8: + size = sizeof(int8_t); + break; + case kNumberTypeUInt8: + size = sizeof(uint8_t); + break; + case kNumberTypeFloat16: + size = sizeof(int16_t); + break; + case kNumberTypeInt16: + size = sizeof(int16_t); + break; + case kNumberTypeInt32: + size = sizeof(int32_t); + break; + case kNumberTypeInt64: + size = sizeof(int64_t); + break; + case kNumberTypeUInt16: + size = sizeof(uint16_t); + break; + case kNumberTypeUInt32: + size = sizeof(uint32_t); + break; + case kNumberTypeUInt64: + size = sizeof(uint64_t); + break; + case kNumberTypeBool: + size = sizeof(bool); + break; + default: + MS_LOG(ERROR) << "Not support the type: " << this->data_type_; + return 0; + } + size *= (format_ == schema::Format_NC4HW4 || format_ == schema::Format_NHWC4) ? ElementsC4Num() + : MetaTensor::ElementsNum(); + + return size; + } + + void set_allocator(mindspore::lite::Allocator *allocator) { allocator_ = allocator; } + + int MallocData(mindspore::lite::Allocator *allocator = nullptr) { + if (nullptr != this->data_) { + return 0; + } + if (allocator != nullptr) { + allocator_ = allocator; + } + if (allocator_ == nullptr) { + this->data_ = malloc(this->Size()); + } else { + this->data_ = allocator_->Malloc(this->Size()); + } + if (nullptr == this->data_) { + MS_LOG(ERROR) << "Malloc tensor data failed, size=" << this->Size(); + return -1; + } + + return 0; + } + + int FreeData() { + if (nullptr == this->data_) { + return 0; + } + if (nullptr == allocator_) { + free(this->data_); + } else { + allocator_->Free(this->data_); + this->data_ = nullptr; + } + + return 0; + } + + void *Data() { return data_; } + + void SetData(void *data) { this->data_ = data; } + + schema::NodeType TensorType() { return this->tensorType; } + + void SetFormat(schema::Format format) { this->format_ = format; } + + schema::Format GetFormat() { return this->format_; } + + size_t RefCount() { return this->refCount; } + + void SetRefCount(size_t refCount) { this->refCount = refCount; } + + void decRefCount() { this->refCount--; } + + std::string ToString() const override; + + void AddQuantParam(const tensor::QuantArg &quant_arg); + + std::vector GetQuantParams() const; + + protected: + void *data_ = nullptr; + void *device_data_ = nullptr; + schema::NodeType tensorType; + schema::Format format_; + size_t refCount = 0; + std::vector quant_params_; + mindspore::lite::Allocator *allocator_ = nullptr; +}; + +class LiteTensor : public mindspore::tensor::MSTensor { + public: + LiteTensor(); + + LiteTensor(TypeId data_type, const std::vector &shape); + + explicit LiteTensor(tensor::Tensor *tensor_ptr); + + ~LiteTensor() override; + + TypeId data_type() const override; + + TypeId set_data_type(TypeId data_type) override; + + std::vector shape() const override; + + size_t set_shape(const std::vector &shape) override; + + int DimensionSize(size_t index) const override; + + int ElementsNum() const override; + + std::size_t hash() const override; + + tensor::Tensor *tensor() const; + + size_t Size() const override; + + void *MutableData() const override; + + void SetTensorImpl(tensor::Tensor *tensor); + + protected: + tensor::Tensor *tensor_impl_; +}; + +using TensorPtr = std::shared_ptr; +} // namespace tensor +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_IR_TENSOR_H_ diff --git a/mindspore/lite/src/kernel_factory.cc b/mindspore/lite/src/kernel_factory.cc new file mode 100644 index 0000000000..b835506bec --- /dev/null +++ b/mindspore/lite/src/kernel_factory.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/kernel_factory.h" +#include "utils/log_adapter.h" +#include "src/populate_parameter.h" +#include "schema/model_generated.h" + +using mindspore::kernel::KERNEL_ARCH; +using mindspore::kernel::KernelKey; +using mindspore::kernel::LiteKernel; + +namespace mindspore::lite { +KernelFactory::KernelFactory() = default; + +KernelFactory::~KernelFactory() = default; + +KernelFactory *KernelFactory::GetInstance() { + static KernelFactory instance; + return &instance; +} + +LiteKernel *KernelFactory::GetKernel(const std::vector &inputs, + const std::vector &outputs, const lite::Primitive *primitive, + const Context *ctx, const kernel::KernelKey &key) { + MS_EXCEPTION_IF_NULL(primitive); + MS_EXCEPTION_IF_NULL(ctx); + auto parameter = kernel::PopulateParameter(primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); + return nullptr; + } + auto creator = KernelRegistry::GetInstance()->GetCreator(key); + if (creator != nullptr) { + auto kernel = creator(inputs, outputs, parameter, ctx, key); + return kernel; + } + return nullptr; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/kernel_factory.h b/mindspore/lite/src/kernel_factory.h new file mode 100644 index 0000000000..2065e208d0 --- /dev/null +++ b/mindspore/lite/src/kernel_factory.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ +#define MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ + +#include +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" +#include "mindspore/lite/src/ir/tensor.h" +#include "schema/model_generated.h" + +namespace mindspore::lite { +class KernelFactory { + public: + KernelFactory(); + virtual ~KernelFactory(); + + static KernelFactory *GetInstance(); + kernel::LiteKernel *GetKernel(const std::vector &inputs, + const std::vector &outputs, const lite::Primitive *primitive, + const Context *ctx, const kernel::KernelKey &key); +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc new file mode 100644 index 0000000000..4917d93dcc --- /dev/null +++ b/mindspore/lite/src/kernel_registry.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "ir/dtype/type_id.h" +#ifdef ENABLE_ARM64 +#include +#include "common/utils.h" +#include "utils/log_adapter.h" +#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h" +#endif + +using mindspore::kernel::kCPU; +using mindspore::kernel::KERNEL_ARCH; +using mindspore::kernel::KernelCreator; +using mindspore::kernel::KernelKey; +using mindspore::kernel::kKernelArch_MAX; +using mindspore::kernel::kKernelArch_MIN; +using mindspore::schema::PrimitiveType_MAX; +using mindspore::schema::PrimitiveType_MIN; + +namespace mindspore::lite { +KernelRegistry::KernelRegistry() { + device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1; + data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1; + op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1; + // malloc an array contain creator functions of kernel + auto total_len = device_type_length_ * data_type_length_ * op_type_length_; + creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator)); + if (creator_arrays_ == nullptr) { + MS_LOG(ERROR) << "malloc creator_arrays_ failed."; + } else { + for (int i = 0; i < total_len; ++i) { + creator_arrays_[i] = nullptr; + } + } +} + +KernelRegistry::~KernelRegistry() { FreeCreatorArray(); } + +KernelRegistry *KernelRegistry::GetInstance() { + static KernelRegistry instance; + return &instance; +} + +int KernelRegistry::Init() { +#ifdef ENABLE_ARM64 + void *optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_; + if (optimized_lib_handler != nullptr) { + MS_LOG(INFO) << "load optimize lib success."; + } else { + MS_LOG(INFO) << "load optimize lib failed."; + } +#endif + return RET_OK; +} + +void KernelRegistry::FreeCreatorArray() { + if (creator_arrays_ != nullptr) { + free(creator_arrays_); + creator_arrays_ = nullptr; + } +} + +kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { + if (creator_arrays_ == nullptr) { + MS_LOG(ERROR) << "Creator func array is null."; + return nullptr; + } + int index = GetCreatorFuncIndex(desc); + auto it = creator_arrays_[index]; + if (it != nullptr) { + return it; + } + return nullptr; +} + +int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { + int index; + int device_index = static_cast(desc.arch); + int dType_index = static_cast(desc.data_type); + int op_index = static_cast(desc.type); + index = device_index * data_type_length_ * op_type_length_ + dType_index * op_type_length_ + op_index; + return index; +} + +void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) { + if (creator_arrays_ == nullptr) { + MS_LOG(ERROR) << "Creator func array is null."; + return; + } + int index = GetCreatorFuncIndex(desc); + creator_arrays_[index] = creator; +} + +void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, + kernel::KernelCreator creator) { + if (creator_arrays_ == nullptr) { + MS_LOG(ERROR) << "Creator func array is null."; + return; + } + KernelKey desc = {arch, data_type, op_type}; + int index = GetCreatorFuncIndex(desc); + creator_arrays_[index] = creator; +} + +bool KernelRegistry::Merge(const std::unordered_map &newCreators) { return false; } + +const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; } +} // namespace mindspore::lite diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h new file mode 100644 index 0000000000..eab7d03a53 --- /dev/null +++ b/mindspore/lite/src/kernel_registry.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ +#define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ + +#include +#include +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + +namespace mindspore::lite { +class KernelRegistry { + public: + KernelRegistry(); + virtual ~KernelRegistry(); + + static KernelRegistry *GetInstance(); + int Init(); + void FreeCreatorArray(); + virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); + const kernel::KernelCreator *GetCreatorArrays(); + int GetCreatorFuncIndex(const kernel::KernelKey desc); + void RegKernel(const kernel::KernelKey desc, kernel::KernelCreator creator); + void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type, + kernel::KernelCreator creator); + bool Merge(const std::unordered_map &newCreators); + + protected: + kernel::KernelCreator *creator_arrays_ = nullptr; + int device_type_length_; + int data_type_length_; + int op_type_length_; + std::mutex lock_; +}; + +class KernelRegistrar { + public: + KernelRegistrar(const kernel::KernelKey &desc, kernel::KernelCreator creator) { + KernelRegistry::GetInstance()->RegKernel(desc, creator); + } + + KernelRegistrar(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, + kernel::KernelCreator creator) { + KernelRegistry::GetInstance()->RegKernel(arch, data_type, op_type, creator); + } +}; + +#define REG_KERNEL(arch, data_type, op_type, kernelCreater) \ + static KernelRegistrar g_##arch##data_type##op_type##kernelReg(arch, data_type, op_type, kernelCreater); +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc new file mode 100644 index 0000000000..86e172a315 --- /dev/null +++ b/mindspore/lite/src/lite_kernel.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/lite_kernel.h" +#include +#include "src/common/utils.h" + +namespace mindspore::kernel { +void LiteKernel::InitOutTensorRefCount() { + for (auto *tensor : this->outputs_) { + tensor->SetRefCount(this->out_kernel_.size()); + } +} + +int LiteKernel::DecOutTensorRefCount() { + for (auto *tensor : this->outputs_) { + tensor->decRefCount(); + if (0 >= tensor->RefCount()) { + auto ret = tensor->FreeData(); + if (0 != ret) { + MS_LOG(ERROR) << "Free tensor data failed"; + return ret; + } + } + } + return 0; +} + +std::vector LiteKernelUtil::SubgraphInputKernels( + const std::vector &kernels) { + std::vector input_kernels; + for (const auto kernel : kernels) { + for (auto input : kernel->GetInKernels()) { + auto iter = std::find(kernels.begin(), kernels.end(), input); + if (iter == kernels.end()) { + input_kernels.emplace_back(input); + } + } + } + return input_kernels; +} + +std::vector LiteKernelUtil::SubgraphOutputKernels( + const std::vector &kernels) { + std::vector output_kernels; + for (const auto kernel : kernels) { + for (const auto output : kernel->GetOutKernels()) { + auto iter = std::find(kernels.begin(), kernels.end(), output); + if (iter == kernels.end()) { + output_kernels.emplace_back(output); + } + } + } + return output_kernels; +} + +std::vector LiteKernelUtil::SubgraphInputTensors( + const std::vector &kernels) { + std::vector input_tensors; + std::vector all_output_tensors; + for (const auto &kernel : kernels) { + all_output_tensors.insert(all_output_tensors.end(), kernel->GetOutputs().begin(), kernel->GetOutputs().end()); + } + std::sort(all_output_tensors.begin(), all_output_tensors.end()); + auto end_iter = std::unique(all_output_tensors.begin(), all_output_tensors.end()); + all_output_tensors.erase(end_iter, all_output_tensors.end()); + + std::vector input_kernels = SubgraphInputKernels(kernels); + for (const auto &kernel : input_kernels) { + for (const auto &tensor : kernel->GetInputs()) { + auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); + if (iter == all_output_tensors.end() && tensor->Data() == nullptr) { + input_tensors.emplace_back(tensor); + } + } + } + return input_tensors; +} + +std::vector LiteKernelUtil::SubgraphOutputTensors( + const std::vector &kernels) { + std::vector output_tensors; + std::vector all_input_tensors; + for (const auto &kernel : kernels) { + all_input_tensors.insert(all_input_tensors.end(), kernel->GetInputs().begin(), kernel->GetInputs().end()); + } + std::sort(all_input_tensors.begin(), all_input_tensors.end()); + auto end_iter = std::unique(all_input_tensors.begin(), all_input_tensors.end()); + all_input_tensors.erase(end_iter, all_input_tensors.end()); + + std::vector output_kernels = SubgraphOutputKernels(kernels); + for (const auto &kernel : output_kernels) { + for (const auto &tensor : kernel->GetOutputs()) { + auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor); + if (iter == all_input_tensors.end()) { + output_tensors.emplace_back(tensor); + } + } + } + return output_tensors; +} + +void LiteKernelUtil::TopologicalSortKernels(std::vector &kernels) { + for (auto *kernel : kernels) { + for (auto *search_kernel : kernels) { + if (search_kernel == kernel) { + continue; + } + for (auto *tensor : kernel->GetInputs()) { + if (lite::IsContain(search_kernel->GetOutputs(), tensor)) { + kernel->AddInKernel(search_kernel); + } + } + for (auto *tensor : kernel->GetOutputs()) { + if (lite::IsContain(search_kernel->GetInputs(), tensor)) { + kernel->AddOutKernel(search_kernel); + } + } + } + } +} + +void LiteKernelUtil::InitTensorRefCount(std::vector &kernels) { + for (auto *kernel : kernels) { + kernel->InitOutTensorRefCount(); + } +} + +int LiteKernelUtil::SetInput(LiteKernel &kernelMod, std::vector inputs) { return -1; } +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h new file mode 100644 index 0000000000..1dba44c862 --- /dev/null +++ b/mindspore/lite/src/lite_kernel.h @@ -0,0 +1,167 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_LITE_KERNEL_H_ +#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_ +#include +#include +#ifdef ENABLE_ARM +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "include/context.h" +#include "src/ir/tensor.h" +#include "src/ops/ops.h" + +#ifdef ENABLE_FP16 +using FLOAT_t = float16_t; +#else +using FLOAT_t = float; +#endif + +// using mindspore::kernel::AddressPtr; +namespace mindspore::kernel { +enum KERNEL_ARCH { kCPU, kGPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; +struct KernelKey { + KERNEL_ARCH arch; + TypeId data_type; + schema::PrimitiveType type; + + bool operator<(const KernelKey &dst) const { + if (arch != dst.arch) { + return arch < dst.arch; + } else if (data_type != dst.data_type) { + return data_type < dst.data_type; + } else { + return type < dst.type; + } + } +}; + +class LiteKernel { + public: + LiteKernel() = default; + explicit LiteKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false) { + this->in_kernel_.clear(); + this->out_kernel_.clear(); + } + + virtual ~LiteKernel() { delete opParameter; } + + virtual int Prepare() { return -1; } + virtual int Init() { return -1; } + virtual int ReSize() { return -1; } + virtual int Run() { return -1; } + + std::string Name() { return this->name; } + virtual void train() { train_mode = true; } + virtual bool is_train() { return train_mode == true; } + virtual void eval() { train_mode = false; } + virtual bool is_eval() { return train_mode == false; } + void set_name(const std::string &name) { this->name = name; } + + schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; } + + std::string type_str() { return schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_); } + + void SetInputs(const std::vector &inputs) { this->inputs_ = inputs; } + + void SetOutputs(const std::vector &outputs) { this->outputs_ = outputs; } + + std::vector &GetInputs() { return this->inputs_; } + + std::vector &GetOutputs() { return this->outputs_; } + + void AddInKernel(LiteKernel *kernel) { this->in_kernel_.emplace_back(kernel); } + + void AddOutKernel(LiteKernel *kernel) { this->out_kernel_.emplace_back(kernel); } + + std::vector &GetInKernels() { return this->in_kernel_; } + + std::vector &GetOutKernels() { return this->out_kernel_; } + + void InitOutTensorRefCount(); + + int DecOutTensorRefCount(); + + const KernelKey Desc() const { return desc; } + + void set_desc(const KernelKey kernel_key) { desc = kernel_key; } + + protected: + KernelKey desc; + std::string name; + OpParameter *opParameter = nullptr; + // tensor will free in ~lite_session() + std::vector inputs_; + std::vector outputs_; + std::vector in_kernel_; + std::vector out_kernel_; + bool train_mode; +}; + +class SubGraphKernel : public LiteKernel { + public: + explicit SubGraphKernel(const std::vector &inputs, + const std::vector &outputs, + const std::vector &inKernels, + const std::vector &outKernels, + const std::vector &nodes) + : LiteKernel(nullptr, inputs, outputs), + inputs_(inputs), + outputs_(outputs), + inkernels_(inKernels), + outkernels_(outKernels), + nodes_(nodes) {} + + virtual int Init() { return -1; } + virtual int InferShape() { return -1; } + virtual int ReSize() { return -1; } + virtual int Run() { return -1; } + + protected: + std::vector inputs_; + std::vector outputs_; + std::vector inkernels_; + std::vector outkernels_; + std::vector nodes_; +}; + +typedef LiteKernel *(*KernelCreator)(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc); + +class LiteKernelUtil { + public: + static void TopologicalSortKernels(std::vector &kernels); + + static std::vector SubgraphInputKernels(const std::vector &kernels); + + static std::vector SubgraphOutputKernels(const std::vector &kernels); + + static std::vector SubgraphInputTensors(const std::vector &kernels); + + static std::vector SubgraphOutputTensors(const std::vector &kernels); + + static void InitTensorRefCount(std::vector &kernels); + + static int SetInput(LiteKernel &kernelMod, std::vector inputs); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_LITE_KERNEL_H_ diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc new file mode 100644 index 0000000000..ae529836ab --- /dev/null +++ b/mindspore/lite/src/lite_session.cc @@ -0,0 +1,291 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "include/errorcode.h" +#include "src/lite_session.h" +#include "utils/log_adapter.h" +#include "src/scheduler.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/allocator.h" +#include "src/executor.h" +#include "src/common/utils.h" +#include "src/common/graph_util.h" +#include "src/kernel_registry.h" +#if SUPPORT_GPU +#include "src/runtime/opencl/opencl_runtime.h" +#endif + +namespace mindspore { +namespace lite { +int LiteSession::ConvertTensors(const lite::Model *model) { + MS_EXCEPTION_IF_NULL(model); + auto meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + uint32_t tensorCount = meta_graph->allTensors()->size(); + for (uint32_t i = 0; i < tensorCount; i++) { + auto *srcTensor = meta_graph->allTensors()->GetAs(i); + if (srcTensor == nullptr) { + MS_LOG(ERROR) << i << "th tensor in meta_graph is nullptr"; + return RET_NULL_PTR; + } + std::vector shape; + if (srcTensor->dims() == nullptr) { + MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr"; + } else { + if (srcTensor->nodeType() == schema::NodeType_ValueNode) { + for (size_t j = 0; j < srcTensor->dims()->size(); j++) { + shape.push_back(srcTensor->dims()->data()[j]); + } + } + } + int dataType = srcTensor->dataType(); + auto *dstTensor = new tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType()); + if (srcTensor->nodeType() == schema::NodeType_ValueNode && srcTensor->data() != nullptr && + srcTensor->data()->size() > 0) { + if (shape.empty()) { + shape.push_back(1); + } + MS_ASSERT(dstTensor != nullptr); + MS_ASSERT(dstTensor->Size() == srcTensor->data()->size()); + // no copy data, do copy when call LiteKernel::Init + dstTensor->SetData(const_cast(srcTensor->data()->data())); + } + this->tensors.emplace_back(dstTensor); + } + return RET_OK; +} + +void LiteSession::InitGraphInOutTensor(const lite::Model *model) { + auto meta_graph = model->GetMetaGraph(); + MS_ASSERT(this->input_map.empty()); + MS_ASSERT(meta_graph != nullptr); + auto graph_input_node_indexes = GetGraphInputNodes(meta_graph); + for (auto in_node_index : graph_input_node_indexes) { + auto *in_node = meta_graph->nodes()->GetAs(in_node_index); + MS_ASSERT(nullptr != in_node); + MS_ASSERT(this->input_map.find(in_node->name()->str()) == this->input_map.end()); + for (size_t i = 0; i < in_node->inputIndex()->size(); i++) { + auto in_tensor_index = size_t(in_node->inputIndex()->GetAs(i)); + bool is_graph_input = false; + for (size_t j = 0; j < meta_graph->inputIndex()->size(); j++) { + if (in_tensor_index == size_t(meta_graph->inputIndex()->GetAs(j))) { + is_graph_input = true; + break; + } + } + if (!is_graph_input) { + continue; + } + MS_ASSERT(in_tensor_index < this->tensors.size()); + auto *in_tensor = this->tensors.at(in_tensor_index); + MS_ASSERT(in_tensor != nullptr); + auto *ms_tensor = new tensor::LiteTensor(in_tensor); + MS_ASSERT(nullptr != ms_tensor); + this->input_map[in_node->name()->str()].emplace_back(ms_tensor); + } + } + + auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph); + for (auto out_node_index : graph_output_node_indexes) { + auto *out_node = meta_graph->nodes()->GetAs(out_node_index); + MS_ASSERT(nullptr != out_node); + MS_ASSERT(this->output_map.find(out_node->name()->str()) == this->output_map.end()); + for (size_t i = 0; i < out_node->outputIndex()->size(); i++) { + auto out_tensor_index = size_t(out_node->outputIndex()->GetAs(i)); + bool is_graph_output = false; + for (size_t j = 0; j < meta_graph->outputIndex()->size(); j++) { + if (out_tensor_index == size_t(meta_graph->outputIndex()->GetAs(j))) { + is_graph_output = true; + break; + } + } + if (!is_graph_output) { + continue; + } + MS_ASSERT(out_tensor_index < this->tensors.size()); + auto *out_tensor = this->tensors.at(out_tensor_index); + MS_ASSERT(out_tensor != nullptr); + auto *ms_tensor = new tensor::LiteTensor(out_tensor); + MS_ASSERT(nullptr != ms_tensor); + this->output_map[out_node->name()->str()].emplace_back(ms_tensor); + } + } +} + +int LiteSession::CompileGraph(Model *model) { + // model.MetaGraph ==> kernels + if (model == nullptr) { + MS_LOG(ERROR) << "The input model is nullptr."; + return RET_PARAM_INVALID; + } + + auto ret = ConvertTensors(model); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvertTensors failed: " << ret; + return ret; + } + + InitGraphInOutTensor(model); + + // scheduler kernels + Scheduler scheduler(context_); + ret = scheduler.Schedule(model, &tensors, &kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Schedule kernels failed: " << ret; + return ret; + } + + return RET_OK; +} + +std::vector LiteSession::GetInputs() const { + std::vector ret; + for (auto &iter : this->input_map) { + auto &node_input_tensors = iter.second; + for (auto tensor : node_input_tensors) { + if (!IsContain(ret, tensor)) { + ret.emplace_back(tensor); + } + } + } + return ret; +} + +int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { + MS_EXCEPTION_IF_NULL(this->context_); + SetMaxWokerNum(context_->thread_num_); + Executor executor; + if (before == nullptr && after == nullptr) { + return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get()); + } else { + return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get(), before, after); + } +} + +std::vector LiteSession::GetOutputs() const { + std::vector ret; + for (auto &iter : this->output_map) { + auto &node_output_tensors = iter.second; + for (auto tensor : node_output_tensors) { + if (!IsContain(ret, tensor)) { + ret.emplace_back(tensor); + } + } + } + return ret; +} + +int LiteSession::Init(Context *context) { + MS_EXCEPTION_IF_NULL(context); + this->context_ = new (std::nothrow) Context(context->thread_num_, context->allocator, context->device_ctx_); + if (this->context_ == nullptr) { + MS_LOG(ERROR) << "new context failed"; + return RET_MEMORY_FAILED; + } + this->context_->cpu_bind_mode_ = context->cpu_bind_mode_; + ConfigThreadPool(context->cpu_bind_mode_, context->thread_num_); + auto ret = KernelRegistry::GetInstance()->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "KernelRegistry Init Failed."; + return ret; + } +#if SUPPORT_GPU + if (context_->device_ctx_.type == DT_GPU) { + auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + opencl_runtime->Init(); + } +#endif + return RET_OK; +} + +void LiteSession::BindThread(bool ifBind) { + if (this->context_->cpu_bind_mode_ != NO_BIND) { + DoAllThreadBind(ifBind, static_cast(this->context_->cpu_bind_mode_)); + } +} + +LiteSession::~LiteSession() { + for (auto *tensor : tensors) { + // weight data can not be to free, we will free weight data when freeing meta_graph + if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs, tensor)) { + tensor->SetData(nullptr); + } + delete tensor; + } + // inputs outputs input_map output_map are freed in tensors + for (auto *input : inputs) { + ((tensor::LiteTensor *)input)->SetTensorImpl(nullptr); + delete input; + } + for (auto *output : outputs) { + ((tensor::LiteTensor *)output)->SetTensorImpl(nullptr); + delete output; + } + for (auto iter : this->input_map) { + for (auto *ms_tensor : iter.second) { + ((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr); + delete ms_tensor; + } + iter.second.clear(); + } + input_map.clear(); + for (auto iter : this->output_map) { + for (auto *ms_tensor : iter.second) { + ((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr); + delete ms_tensor; + } + iter.second.clear(); + } + output_map.clear(); + for (auto *kernel : kernels) { + delete kernel; + } + delete this->context_; +} + +std::vector LiteSession::GetInputsByName(const std::string &name) const { + auto ret = input_map.find(name); + if (ret == input_map.end()) { + MS_LOG(WARNING) << "Node " << name << " is not an input node"; + std::vector empty_ret; + return empty_ret; + } + return ret->second; +} + +std::vector LiteSession::GetOutputsByName(const std::string &name) const { + auto ret = output_map.find(name); + if (ret == output_map.end()) { + MS_LOG(WARNING) << "Node " << name << " is not an output node"; + std::vector empty_ret; + return empty_ret; + } + return ret->second; +} +} // namespace lite + +session::LiteSession *session::LiteSession::CreateSession(lite::Context *context) { + auto session = new lite::LiteSession(); + auto ret = session->Init(context); + if (ret != mindspore::lite::RET_OK) { + MS_LOG(ERROR) << "init sesssion failed"; + delete session; + return nullptr; + } + return session; +} +} // namespace mindspore diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h new file mode 100644 index 0000000000..4c45809596 --- /dev/null +++ b/mindspore/lite/src/lite_session.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_LITE_SESSION_H_ +#define MINDSPORE_LITE_SRC_LITE_SESSION_H_ + +#include +#include +#include +#include +#include "include/ms_tensor.h" +#include "include/lite_session.h" +#include "include/model.h" +#include "include/context.h" +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +class LiteSession : public session::LiteSession { + public: + LiteSession() = default; + + ~LiteSession() override; + + int Init(Context *context); + + void BindThread(bool ifBind) override; + + int CompileGraph(Model *model) override; + + std::vector GetInputs() const override; + + std::vector GetInputsByName(const std::string &name) const override; + + int RunGraph(const session::KernelCallBack &before = nullptr, + const session::KernelCallBack &after = nullptr) override; + + std::vector GetOutputs() const override; + + std::vector GetOutputsByName(const std::string &name) const override; + + protected: + int ConvertTensors(const lite::Model *model); + + void InitGraphInOutTensor(const lite::Model *model); + + protected: + Context *context_ = nullptr; + std::vector kernels; + std::vector tensors; + // graph input tensors + std::vector inputs; + // graph output tensors + std::vector outputs; + // graph input node name -- input tensors + std::unordered_map> input_map; + // graph output node name -- output tensors + std::unordered_map> output_map; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_LITE_SESSION_H_ diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc new file mode 100644 index 0000000000..55d10278f7 --- /dev/null +++ b/mindspore/lite/src/model.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef SUPPORT_TRAIN +#include "src/train/model_impl.h" +#else +#include "src/model_impl.h" +#endif +#include "include/model.h" +#include "utils/log_adapter.h" + +namespace mindspore::lite { + +std::shared_ptr Model::Import(const char *model_buf, size_t size) { + auto model = std::make_shared(); + model->model_impl_ = ModelImpl::Import(model_buf, size); + return model; +} + +lite::Primitive *Model::GetOp(const std::string &name) const { + MS_EXCEPTION_IF_NULL(model_impl_); + return const_cast(model_impl_->GetOp(name)); +} + +void Model::FreeMetaGraph() { + MS_EXCEPTION_IF_NULL(model_impl_); + return model_impl_->FreeMetaGraph(); +} + +const schema::MetaGraph *Model::GetMetaGraph() const { + MS_EXCEPTION_IF_NULL(model_impl_); + return model_impl_->GetMetaGraph(); +} + +std::shared_ptr Model::model_impl() { + MS_EXCEPTION_IF_NULL(model_impl_); + return this->model_impl_; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc new file mode 100644 index 0000000000..abead7cba1 --- /dev/null +++ b/mindspore/lite/src/model_impl.cc @@ -0,0 +1,270 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/model_impl.h" +#include "utils/log_adapter.h" + +namespace mindspore::lite { +std::shared_ptr ModelImpl::Import(const char *model_buf, size_t size) { + MS_EXCEPTION_IF_NULL(model_buf); + flatbuffers::Verifier verify((const uint8_t *)model_buf, size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return nullptr; + } + auto *inner_model_buf = new (std::nothrow) char[size]; + if (inner_model_buf == nullptr) { + MS_LOG(ERROR) << "new model buf fail."; + return nullptr; + } + memcpy(inner_model_buf, model_buf, size); + auto model = std::make_shared(inner_model_buf, size); + if (model == nullptr) { + MS_LOG(ERROR) << "Create modelImpl failed"; + return nullptr; + } + auto ret = model->BuildOps(); + if (0 != ret) { + MS_LOG(ERROR) << "BuildOps failed"; + return nullptr; + } + return model; +} + +lite::Primitive *ModelImpl::GetOp(const std::string &name) const { + auto iter = ops.find(name); + if (iter == ops.end()) { + return nullptr; + } else { + return iter->second; + } +} + +ModelImpl::~ModelImpl() { + delete[](this->model_buf_); + for (auto iter : ops) { + delete (iter.second); + } + ops.clear(); +} + +void ModelImpl::FreeMetaGraph() { + delete[](this->model_buf_); + model_buf_ = nullptr; +} + +const schema::MetaGraph *ModelImpl::GetMetaGraph() const { return this->meta_graph; } + +lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { + MS_EXCEPTION_IF_NULL(srcPrim); + auto op_type = srcPrim->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(srcPrim)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(srcPrim)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(srcPrim)); + case schema::PrimitiveType_DeConv2D: + return new lite::DeConv2D(const_cast(srcPrim)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(srcPrim)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(srcPrim)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(srcPrim)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_BatchNorm: + return new lite::BatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(srcPrim)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(srcPrim)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(srcPrim)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(srcPrim)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(srcPrim)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(srcPrim)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(srcPrim)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(srcPrim)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(srcPrim)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(srcPrim)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(srcPrim)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(srcPrim)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(srcPrim)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(srcPrim)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(srcPrim)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(srcPrim)); + case schema::PrimitiveType_Slice: + return new lite::Slice(const_cast(srcPrim)); + case schema::PrimitiveType_Squeeze: + return new lite::Squeeze(const_cast(srcPrim)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(srcPrim)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(srcPrim)); + case schema::PrimitiveType_Flatten: + return new lite::Flatten(const_cast(srcPrim)); + case schema::PrimitiveType_Mean: + return new lite::Mean(const_cast(srcPrim)); + case schema::PrimitiveType_Stack: + return new lite::Stack(const_cast(srcPrim)); + case schema::PrimitiveType_Crop: + return new lite::Crop(const_cast(srcPrim)); + case schema::PrimitiveType_SquaredDifference: + return new lite::SquaredDifference(const_cast(srcPrim)); + case schema::PrimitiveType_AddN: + return new lite::AddN(const_cast(srcPrim)); + case schema::PrimitiveType_Abs: + return new lite::Abs(const_cast(srcPrim)); + case schema::PrimitiveType_Sin: + return new lite::Sin(const_cast(srcPrim)); + case schema::PrimitiveType_Cos: + return new lite::Cos(const_cast(srcPrim)); + case schema::PrimitiveType_Log: + return new lite::Log(const_cast(srcPrim)); + case schema::PrimitiveType_Sqrt: + return new lite::Sqrt(const_cast(srcPrim)); + case schema::PrimitiveType_Rsqrt: + return new lite::Rsqrt(const_cast(srcPrim)); + case schema::PrimitiveType_Square: + return new lite::Square(const_cast(srcPrim)); + case schema::PrimitiveType_Exp: + return new lite::Exp(const_cast(srcPrim)); + case schema::PrimitiveType_Gather: + return new lite::Gather(const_cast(srcPrim)); + case schema::PrimitiveType_LocalResponseNormalization: + return new lite::LocalResponseNormalization(const_cast(srcPrim)); + case schema::PrimitiveType_Maximum: + return new lite::Maximum(const_cast(srcPrim)); + case schema::PrimitiveType_Minimum: + return new lite::Minimum(const_cast(srcPrim)); + case schema::PrimitiveType_Pad: + return new lite::Pad(const_cast(srcPrim)); + case schema::PrimitiveType_StridedSlice: + return new lite::StridedSlice(const_cast(srcPrim)); + case schema::PrimitiveType_Prelu: + return new lite::Prelu(const_cast(srcPrim)); + case schema::PrimitiveType_Round: + return new lite::Round(const_cast(srcPrim)); + case schema::PrimitiveType_ReverseSequence: + return new lite::ReverseSequence(const_cast(srcPrim)); + case schema::PrimitiveType_LogicalAnd: + return new lite::LogicalAnd(const_cast(srcPrim)); + case schema::PrimitiveType_LogicalOr: + return new lite::LogicalOr(const_cast(srcPrim)); + case schema::PrimitiveType_LogicalNot: + return new lite::LogicalNot(const_cast(srcPrim)); + case schema::PrimitiveType_FloorDiv: + return new lite::FloorDiv(const_cast(srcPrim)); + case schema::PrimitiveType_FloorMod: + return new lite::FloorMod(const_cast(srcPrim)); + case schema::PrimitiveType_Equal: + return new lite::Equal(const_cast(srcPrim)); + case schema::PrimitiveType_NotEqual: + return new lite::NotEqual(const_cast(srcPrim)); + case schema::PrimitiveType_Less: + return new lite::Less(const_cast(srcPrim)); + case schema::PrimitiveType_LessEqual: + return new lite::LessEqual(const_cast(srcPrim)); + case schema::PrimitiveType_Greater: + return new lite::Greater(const_cast(srcPrim)); + case schema::PrimitiveType_GreaterEqual: + return new lite::GreaterEqual(const_cast(srcPrim)); + case schema::PrimitiveType_Floor: + return new lite::Floor(const_cast(srcPrim)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(srcPrim)); + case schema::PrimitiveType_Split: + return new lite::Split(const_cast(srcPrim)); + case schema::PrimitiveType_OneHot: + return new lite::OneHot(const_cast(srcPrim)); + case schema::PrimitiveType_MatMul: + return new lite::MatMul(const_cast(srcPrim)); + case schema::PrimitiveType_QuantDTypeCast: + return new lite::QuantDTypeCast(const_cast(srcPrim)); + case schema::PrimitiveType_EmbeddingLookup: + return new lite::EmbeddingLookup(const_cast(srcPrim)); + default: + break; + } + return nullptr; +} + +int ModelImpl::BuildOps() { + if (this->meta_graph == nullptr) { + MS_LOG(ERROR) << "mete_graph is nullptr"; + return -1; + } + MS_EXCEPTION_IF_NULL(meta_graph->nodes()); + for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + auto name = cNode->name()->str(); + auto srcPrim = cNode->primitive(); + + this->ops[name] = CopyPrimitive(srcPrim); + // flatbuffers::FlatBufferBuilder fbb(1024); + // schema::Conv2DBuilder conv2DBuilder(fbb); + // conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode()); + // conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut()); + // conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn()); + // conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH()); + // conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW()); + // conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH()); + // conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW()); + // conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH()); + // conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW()); + // conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp()); + // conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown()); + // conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft()); + // conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight()); + // conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format()); + // conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group()); + // conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType()); + // schema::PrimitiveBuilder primBuilder(fbb); + // primBuilder.add_value_type(srcPrim->value_type()); + // primBuilder.add_value(conv2DBuilder.Finish()); + // + // fbb.Finish(conv2DBuilder.Finish()); + // auto buf = fbb.GetBufferPointer(); + // auto conv2D = flatbuffers::GetRoot(buf); + // fbb.Clear(); + // + // return const_cast(opDef); + } + return 0; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/model_impl.h b/mindspore/lite/src/model_impl.h new file mode 100644 index 0000000000..14e0a1ccb9 --- /dev/null +++ b/mindspore/lite/src/model_impl.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_MODEL_IMPL_H_ +#define MINDSPORE_LITE_SRC_MODEL_IMPL_H_ + +#include +#include +#include +#include "schema/model_generated.h" +#include "src/ops/ops.h" + +namespace mindspore { +namespace lite { +class ModelImpl { + public: + static std::shared_ptr Import(const char *model_buf, size_t size); + ModelImpl() = default; + explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) { + meta_graph = schema::GetMetaGraph(model_buf); + } + virtual ~ModelImpl(); + lite::Primitive *GetOp(const std::string &name) const; + const schema::MetaGraph *GetMetaGraph() const; + void FreeMetaGraph(); + int BuildOps(); + + protected: + lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); + + protected: + const char *model_buf_; + size_t buf_size_; + const schema::MetaGraph *meta_graph = nullptr; + std::map ops; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_INCLUDE_MODEL_H_ + diff --git a/mindspore/lite/src/ops/CMakeLists.txt b/mindspore/lite/src/ops/CMakeLists.txt new file mode 100644 index 0000000000..c468336fca --- /dev/null +++ b/mindspore/lite/src/ops/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) + +add_library(ops_mid_ OBJECT ${OPS_SRC}) \ No newline at end of file diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc new file mode 100644 index 0000000000..a6fab39cfc --- /dev/null +++ b/mindspore/lite/src/ops/addn.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kLeastInputNum = 2; +} +int AddN::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs.front(); + MS_ASSERT(input != nullptr); + auto output = outputs.front(); + MS_ASSERT(output != nullptr); + if (inputs.size() < kLeastInputNum) { + MS_LOG(ERROR) << "input size" << inputs.size() << " is error!"; + return RET_INPUT_TENSOR_ERROR; + } + for (int i = 1; i < inputs.size(); ++i) { + if (inputs.at(i)->shape() != inputs.at(0)->shape()) { + MS_LOG(ERROR) << "AddN inputs shape is not equal!"; + return RET_INPUT_TENSOR_ERROR; + } + if (inputs.at(i)->data_type() != inputs.at(0)->data_type()) { + MS_LOG(ERROR) << "AddN all input data type should be the same!"; + return RET_INPUT_TENSOR_ERROR; + } + } + output->SetFormat(input->GetFormat()); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc new file mode 100644 index 0000000000..d409ba1f0d --- /dev/null +++ b/mindspore/lite/src/ops/argmax.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ArgMax::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "tensor number is error."; + } + auto argmax_prim = this->primitive->value_as_ArgMax(); + + std::vector output_shape(input->shape()); + auto input_shape_size = input->shape().size(); + int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis(); + if (axis >= input_shape_size || axis < 0) { + MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size; + return RET_PARAM_INVALID; + } + if (argmax_prim->topK() == 1) { + output_shape.erase(output_shape.begin() + axis); + } else if (argmax_prim->axisType() == 1) { + output_shape[axis] = argmax_prim->topK(); + } + + output->SetFormat(input->GetFormat()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc new file mode 100644 index 0000000000..4fbd5a3861 --- /dev/null +++ b/mindspore/lite/src/ops/argmin.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ArgMin::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "tensor number is error."; + } + auto argmin_prim = this->primitive->value_as_ArgMin(); + auto input_shape_size = input->shape().size(); + int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis(); + if (axis >= input_shape_size || axis < 0) { + MS_LOG(ERROR) << "Invalid axis " << argmin_prim->axis() << ", input shape size: " << input_shape_size; + return RET_PARAM_INVALID; + } + std::vector output_shape(input->shape()); + if (argmin_prim->topK() == 1) { + output_shape.erase(output_shape.begin() + axis); + } else if (argmin_prim->axisType() == 1) { + output_shape[axis] = argmin_prim->topK(); + } + + output->SetFormat(input->GetFormat()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/arithmetic.cc b/mindspore/lite/src/ops/arithmetic.cc new file mode 100644 index 0000000000..2b15e22608 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Arithmetic::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "The number of input must be " << kDoubleNum; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "The number of output must be " << kSingleNum; + return RET_INPUT_TENSOR_ERROR; + } + auto input0 = inputs_[0]; + MS_ASSERT(input0 != nullptr); + auto input1 = inputs_[1]; + MS_ASSERT(input1 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto input_shape0 = input0->shape(); + auto input_shape1 = input1->shape(); + auto format = input0->GetFormat(); + in_shape0_.resize(5); + in_shape1_.resize(5); + out_shape_.resize(5); + + ndim_ = input_shape0.size(); + if (input_shape0.size() < input_shape1.size()) { + ndim_ = input_shape1.size(); + auto fill_dim_num = input_shape1.size() - input_shape0.size(); + int j = 0; + for (int i = 0; i < input_shape1.size(); i++) { + if (i < fill_dim_num) { + in_shape0_[i] = 1; + } else { + in_shape0_[i] = input_shape0[j++]; + } + in_shape1_[i] = input_shape1[i]; + } + format = input0->GetFormat(); + } else if (input_shape0.size() > input_shape1.size()) { + ndim_ = input_shape0.size(); + auto fill_dim_num = input_shape0.size() - input_shape1.size(); + int j = 0; + for (int i = 0; i < input_shape0.size(); i++) { + if (i < fill_dim_num) { + in_shape1_[i] = 1; + } else { + in_shape1_[i] = input_shape1[j++]; + } + in_shape0_[i] = input_shape0[i]; + } + } else { + for (int i = 0; i < input_shape0.size(); i++) { + in_shape1_[i] = input_shape1[i]; + in_shape0_[i] = input_shape0[i]; + } + } + + std::vector output_shape; + for (size_t i = 0; i < ndim_; i++) { + if (in_shape0_[i] != in_shape1_[i]) { + if (in_shape0_[i] == 1) { + out_shape_[i] = in_shape1_[i]; + } else if (in_shape1_[i] == 1) { + out_shape_[i] = in_shape0_[i]; + } else { + MS_LOG(ERROR) << "shapes of input tensors can not be broadCasted"; + return -1; + } + broadcasting_ = true; + } else { + out_shape_[i] = in_shape0_[i]; + } + output_shape.push_back(out_shape_[i]); + } + output->SetFormat(format); + output->set_shape(output_shape); + output->set_data_type(input0->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/arithmetic_self.cc b/mindspore/lite/src/ops/arithmetic_self.cc new file mode 100644 index 0000000000..3d2210e746 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic_self.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ArithmeticSelf::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + output->SetFormat(input->GetFormat()); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc new file mode 100644 index 0000000000..41412c58c3 --- /dev/null +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kBatchToSpaceOutputNum = 1; +constexpr int kBatchToSpaceInputNum = 1; +constexpr int kBlockShapeSize = 2; +constexpr int kCropsSize = 4; +} // namespace + +int BatchToSpace::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kDimension_4d) { + MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + return RET_PARAM_INVALID; + } + auto prim = this->primitive->value_as_BatchToSpace(); + auto block_shape = prim->blockShape(); + if (block_shape->size() != kBlockShapeSize) { + MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize; + return RET_PARAM_INVALID; + } + auto crops = prim->crops(); + if (crops->size() != kCropsSize) { + MS_LOG(ERROR) << "Crops size should be " << kCropsSize; + return RET_PARAM_INVALID; + } + size_t mul_block_shape = 1; + + for (size_t i = 0; i < kBlockShapeSize; ++i) { + if (block_shape->Get(i) <= 0) { + MS_LOG(ERROR) << "Input block_shape should > 0!"; + return RET_PARAM_INVALID; + } + if (input_shape[kNHWC_n_index] % block_shape->Get(i)) { + MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " can not divide block_shape[" << i << "] " + << block_shape->Get(i); + return RET_PARAM_INVALID; + } + mul_block_shape *= block_shape->Get(i); + } + + if (input_shape[kNHWC_n_index] < mul_block_shape) { + MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " < product of block shape!"; + return RET_PARAM_INVALID; + } + for (size_t i = 0; i < kCropsSize; ++i) { + if (crops->Get(i) < 0) { + MS_LOG(ERROR) << "Input crops should >= 0"; + return RET_PARAM_INVALID; + } + } + std::vector output_shape(input_shape.size()); + output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index] / mul_block_shape; + output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_shape->Get(0) - crops->Get(0) - crops->Get(1); + output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3); + output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index]; + + outputs[0]->SetFormat(input->GetFormat()); + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc new file mode 100644 index 0000000000..51e5914677 --- /dev/null +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kBroadcastToInputNum = 1; +constexpr int kBroadcastToOutputNum = 1; +} // namespace + +int BroadcastTo::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) { + MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + std::vector dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(), + this->primitive->value_as_BroadcastTo()->dst_shape()->end()); + auto input_shape = input->shape(); + std::vector shape(dst_shape.size()); + int input_shape_index = input_shape.size() - 1; + if (input_shape.size() > dst_shape.size()) { + MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size " + << dst_shape.size() << "!"; + return RET_PARAM_INVALID; + } + + for (int i = dst_shape.size() - 1; i >= 0; --i) { + if (dst_shape[i] < 0) { + MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!"; + return RET_PARAM_INVALID; + } + if (input_shape_index >= 0) { + auto dim = input_shape[input_shape_index]; + if (dim != dst_shape[i] && dim != 1) { + MS_LOG(ERROR) << "Invalid broadcast shape!"; + return RET_PARAM_INVALID; + } + } + shape[i] = dst_shape[i]; + --input_shape_index; + } + outputs[0]->SetFormat(input->GetFormat()); + outputs[0]->set_shape(shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc new file mode 100644 index 0000000000..13de84ff5e --- /dev/null +++ b/mindspore/lite/src/ops/cast.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Cast::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "tensor number is error."; + return RET_INPUT_TENSOR_ERROR; + } + auto cast_prim = this->primitive->value_as_Cast(); + MS_ASSERT(cast_prim != nullptr); + if (input->data_type() != cast_prim->srcT()) { + MS_LOG(ERROR) << "input dataType is error"; + return RET_INPUT_TENSOR_ERROR; + } + if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) { + MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); + return RET_INPUT_TENSOR_ERROR; + } + if (cast_prim->dstT() != kNumberTypeFloat || cast_prim->dstT() != kNumberTypeFloat32) { + MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT(); + return RET_INPUT_TENSOR_ERROR; + } + output->SetFormat(input->GetFormat()); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc new file mode 100644 index 0000000000..e69e2707a3 --- /dev/null +++ b/mindspore/lite/src/ops/concat.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kConcatOutputNum = 1; +} +int Concat::InferShape(std::vector inputs_, std::vector outputs_) { + if (this->primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr!"; + return RET_PARAM_INVALID; + } + auto input0 = inputs_.front(); + auto output = outputs_.front(); + if (outputs_.size() != kConcatOutputNum) { + MS_LOG(ERROR) << "output size is error"; + return RET_PARAM_INVALID; + } + auto concat_prim = this->primitive->value_as_Concat(); + MS_ASSERT(concat_prim != nullptr); + auto input0_shape = inputs_.at(0)->shape(); + int axis = concat_prim->axis() < 0 ? concat_prim->axis() + input0_shape.size() : concat_prim->axis(); + if (axis < 0 || axis >= input0_shape.size()) { + MS_LOG(ERROR) << "Invalid axis: " << axis; + return RET_PARAM_INVALID; + } + + auto input0_shape_without_axis = input0_shape; + input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis); + auto input0_data_type = inputs_.at(0)->data_type(); + int output_axis_dim = input0_shape.at(axis); + for (size_t i = 1; i < inputs_.size(); ++i) { + if (inputs_.at(i)->data_type() != input0_data_type) { + MS_LOG(ERROR) << "All inputs should have the same data type!"; + return RET_PARAM_INVALID; + } + + auto shape_tmp = inputs_.at(i)->shape(); + if (shape_tmp.size() != input0_shape.size()) { + MS_LOG(ERROR) << "All inputs should have the same dim num!"; + return RET_PARAM_INVALID; + } + auto axis_tmp = shape_tmp[axis]; + shape_tmp.erase(shape_tmp.begin() + axis); + if (input0_shape_without_axis != shape_tmp) { + MS_LOG(ERROR) << "Inputs should have the same dim except axis!"; + return RET_PARAM_INVALID; + } + output_axis_dim += axis_tmp; + } + auto output_shape = input0_shape; + output_shape[axis] = output_axis_dim; + outputs_[0]->set_shape(output_shape); + output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/conv.cc b/mindspore/lite/src/ops/conv.cc new file mode 100644 index 0000000000..9597e851fc --- /dev/null +++ b/mindspore/lite/src/ops/conv.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { + MS_ASSERT(this->primitive != nullptr); + auto conv2DPrim = this->primitive->value_as_Conv2D(); + int kernel_w = conv2DPrim->kernelW(); + int kernel_h = conv2DPrim->kernelH(); + int stride_w = conv2DPrim->strideW(); + int stride_h = conv2DPrim->strideH(); + int dilate_w = conv2DPrim->dilateW(); + int dilate_h = conv2DPrim->dilateH(); + pad_l_ = conv2DPrim->padLeft(); + pad_u_ = conv2DPrim->padUp(); + pad_d_ = conv2DPrim->padDown(); + pad_r_ = conv2DPrim->padRight(); + + if (conv2DPrim->padMode() == schema::PadMode_SAME) { + *output_w = std::ceil(static_cast(input_w) / static_cast(stride_w)); + *output_h = std::ceil(static_cast(input_h) / static_cast(stride_h)); + auto pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h); + auto pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w); + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } else { + *output_w = std::ceil((static_cast(input_w) + pad_l_ + pad_r_ - + (static_cast(kernel_w) - 1) * static_cast(dilate_w)) / static_cast(stride_w)); + *output_h = std::ceil((static_cast(input_h) + pad_u_ + pad_d_ - + (static_cast(kernel_h) - 1) * static_cast(dilate_h)) / static_cast(stride_h)); + } +} + +int Conv2D::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != 2 && inputs_.size() != 3) { + MS_LOG(ERROR) << "Add should has two or three inputs"; + return RET_ERROR; + } + if (outputs_.size() != 1) { + MS_LOG(ERROR) << "Add should has one outputs"; + return RET_ERROR; + } + auto *input_tensor = inputs_.front(); + auto *weight_tensor = inputs_.at(1); + auto *out_tensor = outputs_.front(); + MS_ASSERT(input_tensor != nullptr); + MS_ASSERT(out_tensor != nullptr); + + auto in_shape = input_tensor->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + int output_w = 0, output_h = 0; + + this->ConvInferShape(input_h, input_w, &output_h, &output_w); + + std::vector out_shape{input_tensor->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; + out_shape.at(3) = weight_tensor->shape()[0]; + out_tensor->set_shape(out_shape); + out_tensor->SetFormat(input_tensor->GetFormat()); + out_tensor->set_data_type(input_tensor->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/convolution_depthwise.cc b/mindspore/lite/src/ops/convolution_depthwise.cc new file mode 100644 index 0000000000..c26bda7b6f --- /dev/null +++ b/mindspore/lite/src/ops/convolution_depthwise.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int DepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + MS_LOG(ERROR) << "inputs number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight = inputs_.at(1); + MS_ASSERT(weight != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto in_shape = input->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + int output_w = 0, output_h = 0; + + auto conv_prim = this->primitive->value_as_DepthwiseConv2D(); + pad_l_ = conv_prim->padLeft(); + pad_u_ = conv_prim->padUp(); + pad_d_ = conv_prim->padDown(); + pad_r_ = conv_prim->padRight(); + if (conv_prim->padMode() == schema::PadMode_SAME) { + output_h = std::ceil(static_cast(input_h) / static_cast(conv_prim->strideH())); + output_w = std::ceil(static_cast(input_w) / static_cast(conv_prim->strideW())); + auto pad_h_all = + ((output_h - 1) * conv_prim->strideH() + (conv_prim->kernelH() - 1) * conv_prim->dilateH() + 1 - input_h); + auto pad_w_all = + ((output_w - 1) * conv_prim->strideW() + (conv_prim->kernelW() - 1) * conv_prim->dilateW() + 1 - input_w); + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } else { + output_h = + std::ceil((static_cast(input_h) + pad_u_ + pad_d_ - (static_cast(conv_prim->kernelH()) - 1) * + static_cast(conv_prim->dilateH())) / static_cast(conv_prim->strideH())); + output_w = + std::ceil((static_cast(input_w) + pad_l_ + pad_r_ - (static_cast(conv_prim->kernelW()) - 1) * + static_cast(conv_prim->dilateW())) / static_cast(conv_prim->strideW())); + } + std::vector out_shape{input->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; + out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel + + output->set_shape(out_shape); + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc new file mode 100644 index 0000000000..dceab29f9b --- /dev/null +++ b/mindspore/lite/src/ops/crop.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kCropOutputNum = 1; +constexpr int kCropInputNum = 2; +} // namespace + +int Crop::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + outputs[0]->set_shape(inputs[1]->shape()); + outputs[0]->SetFormat(inputs[1]->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/deconvolution.cc b/mindspore/lite/src/ops/deconvolution.cc new file mode 100644 index 0000000000..effcef756a --- /dev/null +++ b/mindspore/lite/src/ops/deconvolution.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int DeConv2D::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight = inputs_.at(1); + MS_ASSERT(weight != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + int32_t input_h = input->Height(); + int32_t input_w = input->Width(); + + int32_t output_n = input->Batch(); + int32_t output_h = 0; + int32_t output_w = 0; + int32_t output_c = weight->Channel(); + + auto deconv = GetAttribute(); + int kernel_w = deconv->kernelW(); + int kernel_h = deconv->kernelH(); + int stride_w = deconv->strideW(); + int stride_h = deconv->strideH(); + int dilate_w = deconv->dilateW(); + int dilate_h = deconv->dilateH(); + pad_l_ = deconv->padLeft(); + pad_u_ = deconv->padUp(); + pad_d_ = deconv->padDown(); + pad_r_ = deconv->padRight(); + schema::PadMode pad_mode = deconv->padMode(); + + if (pad_mode == schema::PadMode_CAFFE) { + output_h = (input_h - 1) * stride_h + ((kernel_h - 1) * dilate_h + 1) - pad_u_ - pad_d_; + output_w = (input_w - 1) * stride_w + ((kernel_w - 1) * dilate_w + 1) - pad_l_ - pad_r_; + } else if (pad_mode == schema::PadMode_SAME) { + output_h = input_h * stride_h; + output_w = input_w * stride_w; + } else if (pad_mode == schema::PadMode_VALID) { + output_h = (input_h - 1) * stride_h + kernel_h; + output_w = (input_w - 1) * stride_w + kernel_w; + } else { + MS_LOG(ERROR) << "unsupported pad mode for deconv"; + } + + std::vector out_shape = {output_n, output_h, output_w, output_c}; + output->set_shape(out_shape); + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/deconvolution_depthwise.cc b/mindspore/lite/src/ops/deconvolution_depthwise.cc new file mode 100644 index 0000000000..01e2542482 --- /dev/null +++ b/mindspore/lite/src/ops/deconvolution_depthwise.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int DeconvDepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + MS_LOG(ERROR) << "inputs number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight = inputs_.at(1); + MS_ASSERT(weight != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto in_shape = input->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + int output_w = 0, output_h = 0; + + auto conv_prim = this->primitive->value_as_DeDepthwiseConv2D(); + pad_l_ = conv_prim->padLeft(); + pad_u_ = conv_prim->padUp(); + pad_d_ = conv_prim->padDown(); + pad_r_ = conv_prim->padRight(); + output_h = conv_prim->strideH() * (input_h - 1) * conv_prim->kernelH() - pad_u_ - pad_d_; + output_w = conv_prim->strideW() * (input_w - 1) * conv_prim->kernelW() - pad_l_ - pad_r_; + if ((output_h + conv_prim->padUp() + conv_prim->padDown() - conv_prim->kernelH()) % conv_prim->strideH() != 0) { + output_h += (output_h + conv_prim->padLeft() + conv_prim->padRight() - conv_prim->kernelH()) % conv_prim->strideH(); + } + if ((output_w + conv_prim->padLeft() + conv_prim->padRight() - conv_prim->kernelW()) % conv_prim->strideW() != 0) { + output_w += (output_w + conv_prim->padLeft() + conv_prim->padRight() - conv_prim->kernelW()) % conv_prim->strideW(); + } + std::vector out_shape{input->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; + out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel + + output->set_shape(out_shape); + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc new file mode 100644 index 0000000000..f09fddfb58 --- /dev/null +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kDepthToSpaceOutputNum = 1; +constexpr int kDepthToSpaceInputNum = 1; +} // namespace + +int DepthToSpace::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kDimension_4d) { + MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + return RET_PARAM_INVALID; + } + auto prim = this->primitive->value_as_DepthToSpace(); + int32_t block_size = prim->blockSize(); + if (input_shape[kNHWC_c_index] % (block_size * block_size) != 0 || input_shape[kNHWC_c_index] == 0) { + MS_LOG(ERROR) << "input dimension c size " << input_shape[kNHWC_c_index] << " should be mulitple of block_size(" + << block_size << ") * block_size)!"; + return RET_PARAM_INVALID; + } + std::vector output_shape(input_shape.size()); + output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index]; + output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_size; + output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_size; + output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index] / (block_size * block_size); + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/embedding_lookup.cc b/mindspore/lite/src/ops/embedding_lookup.cc new file mode 100644 index 0000000000..3a25197611 --- /dev/null +++ b/mindspore/lite/src/ops/embedding_lookup.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "src/ir/tensor.h" +#include "utils/log_adapter.h" + +namespace mindspore::lite { +int EmbeddingLookup::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() < kDoubleNum) { + MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs"; + return RET_INPUT_TENSOR_ERROR; + } + + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "Embedding Lookup should have one outputs"; + return RET_INPUT_TENSOR_ERROR; + } + + auto params_ = inputs_.front(); + MS_ASSERT(params_ != nullptr); + auto ids = inputs_.back(); + MS_ASSERT(ids != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto embedding_shape = params_->shape(); + embedding_shape.erase(embedding_shape.begin()); + + std::vector output_shape(ids->shape()); + for (size_t i = 0; i < embedding_shape.size(); ++i) { + output_shape.push_back(embedding_shape.at(i)); + } + + for (int i = 1; i < inputs_.size() - 1; ++i) { + auto embedding_shape_t = inputs_.at(i)->shape(); + embedding_shape_t.erase(embedding_shape_t.begin()); + if (embedding_shape_t != embedding_shape) { + MS_LOG(ERROR) << "The embedded layers should have the same shape"; + return RET_INPUT_TENSOR_ERROR; + } + } + + output->set_shape(output_shape); + output->set_data_type(params_->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc new file mode 100644 index 0000000000..5b0391d654 --- /dev/null +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ExpandDims::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size is invalid"; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output size is invalid"; + } + auto expand_dims_prim = this->primitive->value_as_ExpandDims(); + int dim = expand_dims_prim->dim(); + if (dim < 0) { + dim += input->shape().size() + 1; + } + if (dim > input->shape().size()) { + MS_LOG(ERROR) << "attribute dim out of range"; + return RET_INPUT_TENSOR_ERROR; + } + auto out_shape = input->shape(); + out_shape.insert(out_shape.begin() + dim, 1, 1); + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc new file mode 100644 index 0000000000..f4bd0c1952 --- /dev/null +++ b/mindspore/lite/src/ops/fill.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Fill::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "Fill input or output is null!"; + return RET_ERROR; + } + + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto fill_prim = this->primitive->value_as_Fill(); + if (fill_prim == nullptr) { + MS_LOG(ERROR) << "Fill primitive is null!"; + return RET_ERROR; + } + std::vector output_shape; + (void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc new file mode 100644 index 0000000000..bde0cd16c5 --- /dev/null +++ b/mindspore/lite/src/ops/flatten.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Flatten::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "Flatten input or output is null!"; + return RET_ERROR; + } + + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + + auto input_shape = input->shape(); + std::vector output_shape(2); + output_shape[0] = input_shape[0]; + output_shape[1] = 1; + for (int i = 1; i < input_shape.size(); i++) { + output_shape[1] *= input_shape[i]; + } + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/fullconnection.cc b/mindspore/lite/src/ops/fullconnection.cc new file mode 100644 index 0000000000..7b4b1e051f --- /dev/null +++ b/mindspore/lite/src/ops/fullconnection.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int FullConnection::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input0 = inputs_.front(); + MS_ASSERT(input0 != nullptr); + auto input1 = inputs_.at(1); + MS_ASSERT(input1 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto fc_prim = this->primitive->value_as_FullConnection(); + if ((fc_prim->hasBias() && inputs_.size() != kMultiNum) || (!fc_prim->hasBias() && inputs_.size() != kDoubleNum)) { + MS_LOG(ERROR) << "Input tensors num error"; + return RET_INPUT_TENSOR_ERROR; + } + if (fc_prim->axis() < 1 || fc_prim->axis() > input0->shape().size()) { + MS_LOG(ERROR) << "FullConnection axis invalid"; + return RET_INPUT_TENSOR_ERROR; + } + int new_k = 1; + for (size_t i = fc_prim->axis(); i < input0->shape().size(); ++i) { + new_k *= input0->shape().at(i); + } + if (new_k != input1->shape().at(1)) { + MS_LOG(ERROR) << "Input1 size invalid"; + return RET_PARAM_INVALID; + } + if (fc_prim->hasBias()) { + if (inputs_.at(2)->shape()[0] != input1->shape()[0]) { + MS_LOG(ERROR) << "bias size invalid"; + return RET_PARAM_INVALID; + } + } + std::vector out_shape{inputs_[0]->shape()}; + out_shape.resize(fc_prim->axis() + 1); + out_shape[fc_prim->axis()] = input1->shape()[0]; + output->set_shape(out_shape); + output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc new file mode 100644 index 0000000000..328de9ba2f --- /dev/null +++ b/mindspore/lite/src/ops/gather.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Gather::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "Gather should have two inputs"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "Gather should have one outputs"; + return RET_INPUT_TENSOR_ERROR; + } + + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto indices = inputs_.at(1); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(input != nullptr); + + auto gather_prim = this->primitive->value_as_Gather(); + MS_ASSERT(gather_prim != nullptr); + + int axis = gather_prim->axis(); + int batch_dims = gather_prim->batchDims(); + if (axis < 0) { + axis += input->shape().size(); + } + auto indices_shape = indices->shape(); + int indices_rank = indices_shape.size(); + if (indices_rank < batch_dims + 1) { + MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1"; + return RET_ERROR; + } + if (batch_dims != 0) { + MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support"; + return RET_ERROR; + } + auto in_shape = input->shape(); + int in_rank = in_shape.size(); + if (in_rank < axis + 1) { + MS_LOG(ERROR) << "input[0]'s rank is less than axis + 1"; + return RET_ERROR; + } + + std::vector out_shape{in_shape}; + out_shape.erase(out_shape.begin() + axis); + for (size_t i = 0; i < indices_rank; i++) { + out_shape.insert(out_shape.begin() + axis, indices_shape[i]); + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc new file mode 100644 index 0000000000..681e2d207b --- /dev/null +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int GatherNd::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "GatherNd should have two inputs"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "GatherNd should have one outputs"; + return RET_INPUT_TENSOR_ERROR; + } + + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto indices = inputs_.at(1); + MS_ASSERT(indices != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto in_shape = input->shape(); + int in_rank = in_shape.size(); + auto indices_shape = indices->shape(); + int indices_rank = indices_shape.size(); + + if (indices_shape[indices_rank - 1] > in_rank) { + MS_LOG(ERROR) << "Input of indices data is error!"; + return RET_ERROR; + } + + std::vector out_shape; + int i = 0; + for (i = 0; i < indices_rank - 1; ++i) { + out_shape.emplace_back(indices_shape[i]); + } + for (i = indices_shape[indices_rank - 1]; i < in_rank; ++i) { + out_shape.emplace_back(in_shape[i]); + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc new file mode 100644 index 0000000000..22b52ad8a7 --- /dev/null +++ b/mindspore/lite/src/ops/lstm.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +const int kLstmInputNum = 6; +const int kLstmOutputNum = 3; +int Lstm::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kLstmInputNum || outputs_.size() != kLstmOutputNum) { + MS_LOG(ERROR) << "OpLstm inputs or outputs size error."; + return RET_INPUT_TENSOR_ERROR; + } + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight_i = inputs_.front(); + MS_ASSERT(input0 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + std::vector in_shape = input->shape(); + std::vector w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size + if (in_shape.size() != 3 || w_shape.size() != 3) { + MS_LOG(ERROR) << "OpLstm input dims should be 3."; + return RET_ERROR; + } + + auto lstm_prim = this->primitive->value_as_Lstm(); + int hidden_size = w_shape[1] / 4; + + // set output + std::vector out_shape(in_shape); + out_shape[2] = hidden_size; + if (lstm_prim->bidirection()) { + out_shape.insert(out_shape.begin() + 1, 2); + } + output->set_shape(out_shape); + + // set hidden state, cell state + std::vector state_shape(in_shape); + state_shape[0] = lstm_prim->bidirection() ? 2 : 1; + state_shape[2] = hidden_size; + outputs_[1]->set_shape(state_shape); + outputs_[2]->set_shape(state_shape); + + for (int i = 0; i < kLstmOutputNum; i++) { + outputs_[i]->set_data_type(input->data_type()); + outputs_[i]->SetFormat(input->GetFormat()); + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc new file mode 100644 index 0000000000..619fcade8c --- /dev/null +++ b/mindspore/lite/src/ops/matmul.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int MatMul::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "OpMatMul inputs size: " << inputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto input0 = inputs_.front(); + MS_ASSERT(input0 != nullptr); + auto input1 = inputs_.at(1); + MS_ASSERT(input1 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + std::vector a_shape = input0->shape(); + std::vector b_shape = input1->shape(); + if (a_shape.size() < 3 || b_shape.size() < 3) { + MS_LOG(ERROR) << "inputs shape is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + + for (int i = 0; i < a_shape.size() - 2; ++i) { + if (a_shape[i] != b_shape[i]) { + MS_LOG(ERROR) << "Op MatMul's dimensions must be equal"; + return RET_INPUT_TENSOR_ERROR; + } + } + + auto matmul_prim = this->primitive->value_as_MatMul(); + if (matmul_prim->transposeA()) { + std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]); + } + if (matmul_prim->transposeB()) { + std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]); + } + std::vector c_shape(a_shape); + c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; + output->set_shape(c_shape); + output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/mean.cc b/mindspore/lite/src/ops/mean.cc new file mode 100644 index 0000000000..331198221f --- /dev/null +++ b/mindspore/lite/src/ops/mean.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr size_t kInputSize = 1; +constexpr size_t kOutputSize = 1; +} // namespace +int Mean::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) { + return RET_ERROR; + } + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + return RET_NULL_PTR; + } + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto mean_prim = this->primitive->value_as_Mean(); + bool keep_dims = static_cast(mean_prim->keepDims()); + std::vector in_shape = input->shape(); + std::vector out_shape; + const auto &axes = mean_prim->axis(); + auto num_axes = axes->size(); + // reduce on all axes + if (num_axes == 0) { + if (keep_dims) { + for (auto i = 0; i < in_shape.size(); i++) { + out_shape.push_back(1); + } + } + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; + } + + // reduce on selected axes + for (size_t i = 0; i < in_shape.size(); i++) { + bool reduce_axis = false; + for (int idx = 0; idx < num_axes; ++idx) { + if (static_cast((*axes)[idx]) == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + out_shape.push_back(1); + } + } else { + out_shape.push_back(in_shape[i]); + } + } + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/nchw2nhwc.cc b/mindspore/lite/src/ops/nchw2nhwc.cc new file mode 100644 index 0000000000..5a420ceba8 --- /dev/null +++ b/mindspore/lite/src/ops/nchw2nhwc.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" +#include "src/common/common.h" + +namespace mindspore::lite { +int Nchw2Nhwc::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + std::vector nchw_shape = input->shape(); + if (nchw_shape.size() != 4) { + output->set_shape(nchw_shape); + } else { + std::vector nhwc_shape{nchw_shape}; + nhwc_shape[NHWC_N] = nchw_shape[NCHW_N]; + nhwc_shape[NHWC_H] = nchw_shape[NCHW_H]; + nhwc_shape[NHWC_W] = nchw_shape[NCHW_W]; + nhwc_shape[NHWC_C] = nchw_shape[NCHW_C]; + output->set_shape(nhwc_shape); + } + output->SetFormat(schema::Format_NHWC); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/nhwc2nchw.cc b/mindspore/lite/src/ops/nhwc2nchw.cc new file mode 100644 index 0000000000..579ce71be2 --- /dev/null +++ b/mindspore/lite/src/ops/nhwc2nchw.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" +#include "src/common/common.h" + +namespace mindspore::lite { +int Nhwc2Nchw::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + std::vector nhwc_shape = input->shape(); + if (nhwc_shape.size() != 4) { + output->set_shape(nhwc_shape); + } else { + std::vector nchw_shape{nhwc_shape}; + nchw_shape[NCHW_N] = nhwc_shape[NHWC_N]; + nchw_shape[NCHW_C] = nhwc_shape[NHWC_C]; + nchw_shape[NCHW_H] = nhwc_shape[NHWC_H]; + nchw_shape[NCHW_W] = nhwc_shape[NHWC_W]; + output->set_shape(nchw_shape); + } + output->SetFormat(schema::Format_NCHW); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc new file mode 100644 index 0000000000..878813c995 --- /dev/null +++ b/mindspore/lite/src/ops/one_hot.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr size_t kOneHotInputNum = 4; +} +int OneHot::InferShape(std::vector inputs, std::vector outputs) { + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto one_hot_prim = this->primitive->value_as_OneHot(); + if (one_hot_prim == nullptr) { + return RET_NULL_PTR; + } + int axis = one_hot_prim->axis(); + + // indices, depth, on_value, off_value + if (inputs.size() != kOneHotInputNum) { + MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum; + return RET_ERROR; + } + auto depth_tensor = inputs.at(1); + if (depth_tensor == nullptr) { + return RET_NULL_PTR; + } + const int *depth = static_cast(depth_tensor->Data()); + + auto input = inputs.front(); + if (input == nullptr) { + return RET_NULL_PTR; + } + const auto input_shape = input->shape(); + int input_rank = static_cast(input_shape.size()); + if (axis < 0) { + axis += input_rank + 1; + } + std::vector output_shape(input_shape); + output_shape.insert(output_shape.cbegin() + axis, *depth); + + auto output = outputs.front(); + if (output == nullptr) { + return RET_NULL_PTR; + } + output->set_shape(output_shape); + + auto on_value = inputs.at(2); + if (on_value == nullptr) { + return RET_NULL_PTR; + } + output->set_data_type(on_value->data_type()); + output->SetFormat(on_value->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.cc b/mindspore/lite/src/ops/ops.cc new file mode 100644 index 0000000000..9f771cbc3f --- /dev/null +++ b/mindspore/lite/src/ops/ops.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { + MS_ASSERT(primitive != nullptr); + auto op_type = primitive->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(primitive)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(primitive)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(primitive)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(primitive)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(primitive)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(primitive)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(primitive)); + case schema::PrimitiveType_BatchNorm: + return new lite::BatchNorm(const_cast(primitive)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(primitive)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(primitive)); + case schema::PrimitiveType_Pad: + return new lite::Pad(const_cast(primitive)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(primitive)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(primitive)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(primitive)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(primitive)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(primitive)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(primitive)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(primitive)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(primitive)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(primitive)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(primitive)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(primitive)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(primitive)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(primitive)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(primitive)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(primitive)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(primitive)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(primitive)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(primitive)); + case schema::PrimitiveType_Squeeze: + return new lite::Squeeze(const_cast(primitive)); + case schema::PrimitiveType_SquaredDifference: + return new lite::SquaredDifference(const_cast(primitive)); + case schema::PrimitiveType_Split: + return new lite::Split(const_cast(primitive)); + case schema::PrimitiveType_FloorDiv: + return new lite::FloorDiv(const_cast(primitive)); + case schema::PrimitiveType_FloorMod: + return new lite::FloorMod(const_cast(primitive)); + case schema::PrimitiveType_Reverse: + return new lite::Reverse(const_cast(primitive)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(primitive)); + case schema::PrimitiveType_GatherNd: + return new lite::GatherNd(const_cast(primitive)); + case schema::PrimitiveType_Tile: + return new lite::Tile(const_cast(primitive)); + case schema::PrimitiveType_TopK: + return new lite::TopK(const_cast(primitive)); + case schema::PrimitiveType_Unique: + return new lite::Unique(const_cast(primitive)); + case schema::PrimitiveType_Unstack: + return new lite::Unstack(const_cast(primitive)); + case schema::PrimitiveType_ReverseSequence: + return new lite::ReverseSequence(const_cast(primitive)); + case schema::PrimitiveType_Round: + return new lite::Round(const_cast(primitive)); + case schema::PrimitiveType_ZerosLike: + return new lite::ZerosLike(const_cast(primitive)); + case schema::PrimitiveType_Where: + return new lite::Where(const_cast(primitive)); + case schema::PrimitiveType_Floor: + return new lite::Floor(const_cast(primitive)); + case schema::PrimitiveType_Shape: + return new lite::Shape(const_cast(primitive)); + case schema::PrimitiveType_ScatterND: + return new lite::ScatterND(const_cast(primitive)); + case schema::PrimitiveType_Unsqueeze: + return new lite::Unsqueeze(const_cast(primitive)); + case schema::PrimitiveType_Flatten: + return new lite::Flatten(const_cast(primitive)); + case schema::PrimitiveType_StridedSlice: + return new lite::StridedSlice(const_cast(primitive)); + case schema::PrimitiveType_Resize: + return new lite::Resize(const_cast(primitive)); + case schema::PrimitiveType_OneHot: + return new lite::OneHot(const_cast(primitive)); + case schema::PrimitiveType_PriorBox: + return new lite::PriorBox(const_cast(primitive)); + case schema::PrimitiveType_SpaceToDepth: + return new lite::SpaceToDepth(const_cast(primitive)); + case schema::PrimitiveType_SpaceToBatch: + return new lite::SpaceToBatch(const_cast(primitive)); + case schema::PrimitiveType_QuantDTypeCast: + return new lite::QuantDTypeCast(const_cast(primitive)); + case schema::PrimitiveType_MatMul: + return new lite::MatMul(const_cast(primitive)); + case schema::PrimitiveType_EmbeddingLookup: + return new lite::EmbeddingLookup(const_cast(primitive)); + default: + break; + } + return nullptr; +} + +int Primitive::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h new file mode 100644 index 0000000000..302f085b9d --- /dev/null +++ b/mindspore/lite/src/ops/ops.h @@ -0,0 +1,790 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPS_OPS_H_ +#define MINDSPORE_LITE_SRC_OPS_OPS_H_ + +#include +#include +#include +#include "schema/model_generated.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +namespace lite::tensor { +class Tensor; +} +namespace lite { +constexpr uint32_t kSingleNum = 1; +constexpr uint32_t kDoubleNum = 2; +constexpr uint32_t kMultiNum = 3; +constexpr uint32_t kNHWC_n_index = 0; +constexpr uint32_t kNHWC_h_index = 1; +constexpr uint32_t kNHWC_w_index = 2; +constexpr uint32_t kNHWC_c_index = 3; +constexpr uint32_t kDimension_4d = 4; + +const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32}; + +class Primitive { + public: + explicit Primitive(schema::Primitive *primitive) : primitive(primitive) {} + static Primitive *CreatePrimitive(schema::Primitive *primitive); + virtual ~Primitive() {} + const schema::Primitive *Value() const { return this->primitive; } + schema::PrimitiveType Type() const { return this->primitive->value_type(); } + const void *Attribute() const { return this->primitive->value(); } + virtual int InferShape(std::vector inputs_, std::vector outputs_); + + protected: + schema::Primitive *primitive; +}; + +class Conv2D : public Primitive { + public: + explicit Conv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Conv2D *GetAttribute() const { return this->primitive->value_as_Conv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w); + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class Pooling : public Primitive { + public: + explicit Pooling(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Pooling *GetAttribute() const { return this->primitive->value_as_Pooling(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class BatchNorm : public Primitive { + public: + explicit BatchNorm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BatchNorm *GetAttribute() const { return this->primitive->value_as_BatchNorm(); } +}; + +class FusedBatchNorm : public Primitive { + public: + explicit FusedBatchNorm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::FusedBatchNorm *GetAttribute() const { return this->primitive->value_as_FusedBatchNorm(); } +}; + +class Activation : public Primitive { + public: + explicit Activation(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Activation *GetAttribute() const { return this->primitive->value_as_Activation(); } +}; + +class Prelu : public Activation { + public: + explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {} + const schema::Prelu *GetAttribute() const { return this->primitive->value_as_Prelu(); } +}; + +class Split : public Primitive { + public: + explicit Split(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Split *GetAttribute() const { return this->primitive->value_as_Split(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Reshape : public Primitive { + public: + explicit Reshape(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Reshape *GetAttribute() const { return this->primitive->value_as_Reshape(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + + private: + int CalNewShape(const tensor::Tensor *in_tensor, std::vector *out_shape) const; +}; + +class FullConnection : public Primitive { + public: + explicit FullConnection(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::FullConnection *GetAttribute() const { return this->primitive->value_as_FullConnection(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class SoftMax : public Primitive { + public: + explicit SoftMax(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::SoftMax *GetAttribute() const { return this->primitive->value_as_SoftMax(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Reduce : public Primitive { + public: + explicit Reduce(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Reduce *GetAttribute() const { return this->primitive->value_as_Reduce(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class DepthwiseConv2D : public Primitive { + public: + explicit DepthwiseConv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DepthwiseConv2D *GetAttribute() const { return this->primitive->value_as_DepthwiseConv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class DeConv2D : public Primitive { + public: + explicit DeConv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DeConv2D *GetAttribute() const { return this->primitive->value_as_DeConv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class DeconvDepthwiseConv2D : public Primitive { + public: + explicit DeconvDepthwiseConv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DeDepthwiseConv2D *GetAttribute() const { return this->primitive->value_as_DeDepthwiseConv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class Power : public Primitive { + public: + explicit Power(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Power *GetAttribute() const { return this->primitive->value_as_Power(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Range : public Primitive { + public: + explicit Range(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Range *GetAttribute() const { return this->primitive->value_as_Range(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class AddN : public Primitive { + public: + explicit AddN(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::AddN *GetAttribute() const { return this->primitive->value_as_AddN(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Arithmetic : public Primitive { + public: + explicit Arithmetic(schema::Primitive *primitive) : Primitive(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; + bool Broadcasting() { return this->broadcasting_; } + int NDims() { return this->ndim_; } + std::vector InShape0() { return this->in_shape0_; } + std::vector InShape1() { return this->in_shape1_; } + std::vector OutputShape() { return this->out_shape_; } + + protected: + bool broadcasting_ = false; + int ndim_; + std::vector in_shape0_; + std::vector in_shape1_; + std::vector out_shape_; +}; + +class Add : public Arithmetic { + public: + explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Add *GetAttribute() const { return this->primitive->value_as_Add(); } +}; + +class Mul : public Arithmetic { + public: + explicit Mul(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Mul *GetAttribute() const { return this->primitive->value_as_Mul(); } +}; + +class Sub : public Arithmetic { + public: + explicit Sub(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); } +}; + +class Div : public Arithmetic { + public: + explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Div *GetAttribute() const { return this->primitive->value_as_Div(); } +}; + +class LogicalAnd : public Arithmetic { + public: + explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::LogicalAnd *GetAttribute() const { return this->primitive->value_as_LogicalAnd(); } +}; + +class LogicalOr : public Arithmetic { + public: + explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::LogicalOr *GetAttribute() const { return this->primitive->value_as_LogicalOr(); } +}; + +class Maximum : public Arithmetic { + public: + explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Maximum *GetAttribute() const { return this->primitive->value_as_Maximum(); } +}; + +class Minimum : public Arithmetic { + public: + explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Minimum *GetAttribute() const { return this->primitive->value_as_Minimum(); } +}; + +class FloorDiv : public Arithmetic { + public: + explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::FloorDiv *GetAttribute() const { return this->primitive->value_as_FloorDiv(); } +}; + +class FloorMod : public Arithmetic { + public: + explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::FloorMod *GetAttribute() const { return this->primitive->value_as_FloorMod(); } +}; + +class SquaredDifference : public Arithmetic { + public: + explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::SquaredDifference *GetAttribute() const { return this->primitive->value_as_SquaredDifference(); } +}; + +class Equal : public Arithmetic { + public: + explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Equal *GetAttribute() const { return this->primitive->value_as_Equal(); } +}; + +class NotEqual : public Arithmetic { + public: + explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::NotEqual *GetAttribute() const { return this->primitive->value_as_NotEqual(); } +}; + +class Less : public Arithmetic { + public: + explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Less *GetAttribute() const { return this->primitive->value_as_Less(); } +}; + +class LessEqual : public Arithmetic { + public: + explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::LessEqual *GetAttribute() const { return this->primitive->value_as_LessEqual(); } +}; + +class Greater : public Arithmetic { + public: + explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Greater *GetAttribute() const { return this->primitive->value_as_Greater(); } +}; + +class GreaterEqual : public Arithmetic { + public: + explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::GreaterEqual *GetAttribute() const { return this->primitive->value_as_GreaterEqual(); } +}; + +class Eltwise : public Arithmetic { + public: + explicit Eltwise(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Eltwise *GetAttribute() const { return this->primitive->value_as_Eltwise(); } +}; + +class ArithmeticSelf : public Primitive { + public: + explicit ArithmeticSelf(schema::Primitive *primitive) : Primitive(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Abs : public ArithmeticSelf { + public: + explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Abs *GetAttribute() const { return this->primitive->value_as_Abs(); } +}; + +class Cos : public ArithmeticSelf { + public: + explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Cos *GetAttribute() const { return this->primitive->value_as_Cos(); } +}; + +class Exp : public ArithmeticSelf { + public: + explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Exp *GetAttribute() const { return this->primitive->value_as_Exp(); } +}; + +class Log : public ArithmeticSelf { + public: + explicit Log(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Log *GetAttribute() const { return this->primitive->value_as_Log(); } +}; + +class Square : public ArithmeticSelf { + public: + explicit Square(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Square *GetAttribute() const { return this->primitive->value_as_Square(); } +}; + +class Sqrt : public ArithmeticSelf { + public: + explicit Sqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Sqrt *GetAttribute() const { return this->primitive->value_as_Sqrt(); } +}; + +class Rsqrt : public ArithmeticSelf { + public: + explicit Rsqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Rsqrt *GetAttribute() const { return this->primitive->value_as_Rsqrt(); } +}; + +class Sin : public ArithmeticSelf { + public: + explicit Sin(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Sin *GetAttribute() const { return this->primitive->value_as_Sin(); } +}; + +class LogicalNot : public ArithmeticSelf { + public: + explicit LogicalNot(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::LogicalNot *GetAttribute() const { return this->primitive->value_as_LogicalNot(); } +}; + +class Floor : public ArithmeticSelf { + public: + explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Floor *GetAttribute() const { return this->primitive->value_as_Floor(); } +}; + +class Ceil : public ArithmeticSelf { + public: + explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Ceil *GetAttribute() const { return this->primitive->value_as_Ceil(); } +}; + +class RealDiv : public Arithmetic { + public: + explicit RealDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::RealDiv *GetAttribute() const { return this->primitive->value_as_RealDiv(); } +}; + +class BiasAdd : public Primitive { + public: + explicit BiasAdd(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BiasAdd *GetAttribute() const { return this->primitive->value_as_BiasAdd(); } +}; + +class ExpandDims : public Primitive { + public: + explicit ExpandDims(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ExpandDims *GetAttribute() const { return this->primitive->value_as_ExpandDims(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Unsqueeze : public Primitive { + public: + explicit Unsqueeze(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Unsqueeze *GetAttribute() const { return this->primitive->value_as_Unsqueeze(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Cast : public Primitive { + public: + explicit Cast(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Cast *GetAttribute() const { return this->primitive->value_as_Cast(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Concat : public Primitive { + public: + explicit Concat(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Concat *GetAttribute() const { return this->primitive->value_as_Concat(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Fill : public Primitive { + public: + explicit Fill(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Fill *GetAttribute() const { return this->primitive->value_as_Fill(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Mean : public Primitive { + public: + explicit Mean(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Mean *GetAttribute() const { return this->primitive->value_as_Mean(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class ArgMax : public Primitive { + public: + explicit ArgMax(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ArgMax *GetAttribute() const { return this->primitive->value_as_ArgMax(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class ArgMin : public Primitive { + public: + explicit ArgMin(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ArgMin *GetAttribute() const { return this->primitive->value_as_ArgMin(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class MatMul : public Primitive { + public: + explicit MatMul(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::MatMul *GetAttribute() const { return this->primitive->value_as_MatMul(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Nchw2Nhwc : public Primitive { + public: + explicit Nchw2Nhwc(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Nchw2Nhwc *GetAttribute() const { return this->primitive->value_as_Nchw2Nhwc(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Nhwc2Nchw : public Primitive { + public: + explicit Nhwc2Nchw(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Nhwc2Nchw *GetAttribute() const { return this->primitive->value_as_Nhwc2Nchw(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Rank : public Primitive { + public: + explicit Rank(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Rank *GetAttribute() const { return this->primitive->value_as_Rank(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Pad : public Primitive { + public: + explicit Pad(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Pad *GetAttribute() const { return this->primitive->value_as_Pad(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Gather : public Primitive { + public: + explicit Gather(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Gather *GatherAttribute() const { return this->primitive->value_as_Gather(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class GatherNd : public Primitive { + public: + explicit GatherNd(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::GatherNd *GetAttribute() const { return this->primitive->value_as_GatherNd(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Slice : public Primitive { + public: + explicit Slice(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Slice *GetAttribute() const { return this->primitive->value_as_Slice(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class BroadcastTo : public Primitive { + public: + explicit BroadcastTo(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BroadcastTo *GetAttribute() const { return this->primitive->value_as_BroadcastTo(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Squeeze : public Primitive { + public: + explicit Squeeze(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Squeeze *SqueezeAttribute() const { return this->primitive->value_as_Squeeze(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Transpose : public Primitive { + public: + explicit Transpose(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Transpose *GetAttribute() const { return this->primitive->value_as_Transpose(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class LocalResponseNormalization : public Primitive { + public: + explicit LocalResponseNormalization(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::LocalResponseNormalization *GetAttribute() const { + return this->primitive->value_as_LocalResponseNormalization(); + } +}; + +class Tile : public Primitive { + public: + explicit Tile(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Tile *GetAttribute() const { return this->primitive->value_as_Tile(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Reverse : public Primitive { + public: + explicit Reverse(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Reverse *GetAttribute() const { return this->primitive->value_as_Reverse(); } +}; + +class TopK : public Primitive { + public: + explicit TopK(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::TopK *GetAttribute() const { return this->primitive->value_as_TopK(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Scale : public Primitive { + public: + explicit Scale(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Scale *GetAttribute() const { return this->primitive->value_as_Scale(); } +}; + +class Stack : public Primitive { + public: + explicit Stack(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Stack *GetAttribute() const { return this->primitive->value_as_Stack(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Unstack : public Primitive { + public: + explicit Unstack(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Unstack *GetAttribute() const { return this->primitive->value_as_Unstack(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Unique : public Primitive { + public: + explicit Unique(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Unique *GetAttribute() const { return this->primitive->value_as_Unique(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class ReverseSequence : public Primitive { + public: + explicit ReverseSequence(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ReverseSequence *GetAttribute() const { return this->primitive->value_as_ReverseSequence(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class DepthToSpace : public Primitive { + public: + explicit DepthToSpace(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DepthToSpace *GetAttribute() const { return this->primitive->value_as_DepthToSpace(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Resize : public Primitive { + public: + explicit Resize(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Resize *GetAttrbute() const { return this->primitive->value_as_Resize(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Round : public ArithmeticSelf { + public: + explicit Round(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Round *GetAttribute() const { return this->primitive->value_as_Round(); } +}; + +class ZerosLike : public Primitive { + public: + explicit ZerosLike(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ZerosLike *GetAttribute() const { return this->primitive->value_as_ZerosLike(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Where : public Primitive { + public: + explicit Where(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Where *GetAttribute() const { return this->primitive->value_as_Where(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class BatchToSpace : public Primitive { + public: + explicit BatchToSpace(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BatchToSpace *GetAttribute() const { return this->primitive->value_as_BatchToSpace(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class SpaceToBatch : public Primitive { + public: + explicit SpaceToBatch(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::SpaceToBatch *GetAttribute() const { return this->primitive->value_as_SpaceToBatch(); } + int InferShape(std::vector inputs, std::vector outputs) override; + std::vector BlockSizes() { return block_sizes_; } + std::vector Paddings() { return block_sizes_; } + std::vector InShape() { return block_sizes_; } + std::vector PaddedInShape() { return block_sizes_; } + + private: + std::vector block_sizes_; + std::vector paddings_; + std::vector in_shape_; + std::vector padded_in_shape_; +}; + +class Crop : public Primitive { + public: + explicit Crop(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Crop *GetAttribute() const { return this->primitive->value_as_Crop(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Shape : public Primitive { + public: + explicit Shape(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Shape *GetAttribute() const { return this->primitive->value_as_Shape(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class ScatterND : public Primitive { + public: + explicit ScatterND(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ScatterND *GetAttribute() const { return this->primitive->value_as_ScatterND(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Flatten : public Primitive { + public: + explicit Flatten(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Flatten *GetAttribute() const { return this->primitive->value_as_Flatten(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class OneHot : public Primitive { + public: + explicit OneHot(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::OneHot *GetAttribute() const { return this->primitive->value_as_OneHot(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class StridedSlice : public Primitive { + public: + explicit StridedSlice(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::StridedSlice *GetAttribute() const { return this->primitive->value_as_StridedSlice(); } + int InferShape(std::vector inputs, std::vector outputs) override; + int NDims() { return this->ndim_; } + void ApplyNewAxisMask(); + std::vector ApplyShrinkMask(std::vector out_shape); + void ApplyBeginMask(); + void ApplyEndMask(); + void ApplyEllipsisMask(); + std::vector GetInShape() { return this->in_shape_; } + std::vector GetBegins() { return this->begins_; } + std::vector GetEnds() { return this->ends_; } + std::vector GetStrides() { return this->strides_; } + + protected: + int ndim_; + std::vector in_shape_; + std::vector begins_; + std::vector ends_; + std::vector strides_; + std::vector begins_mask_; + std::vector ends_mask_; + std::vector ellipsis_mask_; + std::vector new_axis_mask_; + std::vector shrink_axis_mask_; +}; + +class PriorBox : public Primitive { + public: + explicit PriorBox(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::PriorBox *GetAttrbute() const { return this->primitive->value_as_PriorBox(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class SpaceToDepth : public Primitive { + public: + explicit SpaceToDepth(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::SpaceToDepth *GetAttribute() const { return this->primitive->value_as_SpaceToDepth(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class QuantDTypeCast : public Primitive { + public: + explicit QuantDTypeCast(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::QuantDTypeCast *GetAttribute() const { return this->primitive->value_as_QuantDTypeCast(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Lstm : public Primitive { + public: + explicit Lstm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class EmbeddingLookup : public Primitive { + public: + explicit EmbeddingLookup(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::EmbeddingLookup *GetAttribute() const { return this->primitive->value_as_EmbeddingLookup(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_OPS_OPS_H_ diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc new file mode 100644 index 0000000000..e697b53df8 --- /dev/null +++ b/mindspore/lite/src/ops/pad.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +const size_t kPaddingsSize = 8; +const size_t kInputRank = 4; +} // namespace +int Pad::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto pad_prim = this->primitive->value_as_Pad(); + if (pad_prim == nullptr) { + return RET_NULL_PTR; + } + auto paddings = pad_prim->paddings(); + if (paddings == nullptr) { + return RET_NULL_PTR; + } + + auto input = inputs.front(); + if (input == nullptr) { + return RET_NULL_PTR; + } + auto input_shape = input->shape(); + std::vector output_shape; + for (size_t i = 0; i < input_shape.size(); i++) { + auto shape = input_shape[i] + (*paddings)[2 * i] + (*paddings)[2 * i + 1]; + output_shape.push_back(shape); + } + + auto output = outputs.front(); + if (output == nullptr) { + return RET_NULL_PTR; + } + output->SetFormat(input->GetFormat()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc new file mode 100644 index 0000000000..20745e7cd0 --- /dev/null +++ b/mindspore/lite/src/ops/pooling.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Pooling::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + int input_h = input->shape().at(1); + int input_w = input->shape().at(2); + + auto pooling_prim = this->primitive->value_as_Pooling(); + MS_ASSERT(pooling_prim != nullptr); + auto window_h = pooling_prim->windowH(); + auto window_w = pooling_prim->windowW(); + if (pooling_prim->global()) { + window_h = input_h; + window_w = input_w; + } + + int output_h = 0; + int output_w = 0; + pad_l_ = pooling_prim->padLeft(); + pad_u_ = pooling_prim->padUp(); + pad_d_ = pooling_prim->padDown(); + pad_r_ = pooling_prim->padRight(); + if (pooling_prim->padMode() == schema::PadMode_SAME) { + output_w = std::ceil(static_cast(input_w) / static_cast(pooling_prim->strideW())); + output_h = std::ceil(static_cast(input_h) / static_cast(pooling_prim->strideH())); + auto pad_h_all = ((output_h - 1) * pooling_prim->strideH() + (window_h - 1) + 1 - input_h); + auto pad_w_all = ((output_w - 1) * pooling_prim->strideW() + (window_w - 1) + 1 - input_w); + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } else { + auto round_mode = pooling_prim->roundMode(); + if (round_mode == schema::RoundMode_FLOOR) { + output_h = std::floor((input_h + pad_u_ + pad_d_ - window_h) / pooling_prim->strideH() + 1); + output_w = std::floor((input_w + pad_l_ + pad_r_ - window_w) / pooling_prim->strideW() + 1); + } else if (round_mode == schema::RoundMode_CEIL) { + output_h = + std::ceil((input_h + pooling_prim->padUp() + pooling_prim->padDown() - window_h) / pooling_prim->strideH() + 1); + output_w = std::ceil( + (input_w + pooling_prim->padLeft() + pooling_prim->padRight() - window_w) / pooling_prim->strideW() + 1); + } else { + MS_LOG(ERROR) << "unsupported round mode."; + } + } + + // todo: fmk type + auto input_shape = input->shape(); + input_shape.at(1) = output_h; + input_shape.at(2) = output_w; + output->set_shape(input_shape); + output->set_data_type(input->data_type()); + + // todo: temp fix + output->SetFormat(schema::Format_NHWC); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc new file mode 100644 index 0000000000..9bc41f02e1 --- /dev/null +++ b/mindspore/lite/src/ops/power.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Power::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + auto x_tensor = inputs[0]; + MS_ASSERT(x_tensor != nullptr); + auto exp_tensor = inputs[1]; + MS_ASSERT(exp_tensor != nullptr); + auto output_tensor = outputs[0]; + MS_ASSERT(output_tensor != nullptr); + if (inputs.size() < 2) { + MS_LOG(ERROR) << "input size" << inputs.size() << " is error!"; + return RET_INPUT_TENSOR_ERROR; + } + if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) { + MS_LOG(ERROR) << "Power inputs shape is not equal!"; + return RET_INPUT_TENSOR_ERROR; + } + + int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1, std::multiplies()); + if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) { + MS_LOG(ERROR) << "Exponent tensor's shape is wrong"; + return RET_INPUT_TENSOR_ERROR; + } + output_tensor->SetFormat(x_tensor->GetFormat()); + output_tensor->set_shape(x_tensor->shape()); + output_tensor->set_data_type(x_tensor->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/prior_box.cc b/mindspore/lite/src/ops/prior_box.cc new file mode 100644 index 0000000000..96e5b65413 --- /dev/null +++ b/mindspore/lite/src/ops/prior_box.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kPriorBoxPoints = 4; +constexpr int kPriorBoxN = 1; +constexpr int kPriorBoxW = 1; +constexpr int kPriorBoxC = 2; +} // namespace + +int PriorBox::InferShape(std::vector inputs_, std::vector outputs_) { + auto param = GetAttrbute(); + MS_ASSERT(param != nullptr); + std::vector different_aspect_ratios{1.0f}; + auto aspect_ratios = param->aspect_ratios(); + MS_ASSERT(aspect_ratios != nullptr); + for (auto i = 0; i < aspect_ratios->size(); i++) { + float ratio = (*aspect_ratios)[i]; + bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(), [&](float v) { + return abs(ratio - v) < 1e-6; + }); + if (!exist) { + different_aspect_ratios.emplace_back(ratio); + if (param->flip()) { + different_aspect_ratios.emplace_back(1.0f / ratio); + } + } + } + int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size(); + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints; + + std::vector output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC}; + auto output = outputs_.at(0); + MS_ASSERT(output != nullptr); + + output->set_shape(output_shape); + output->set_data_type(kNumberTypeFloat32); + output->SetFormat(input->GetFormat()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc new file mode 100644 index 0000000000..93855a89c4 --- /dev/null +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int QuantDTypeCast::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + auto param = primitive->value_as_QuantDTypeCast(); + MS_ASSERT(input->data_type() == param->srcT); + output->set_data_type(static_cast(param->dstT())); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc new file mode 100644 index 0000000000..53180a8d51 --- /dev/null +++ b/mindspore/lite/src/ops/range.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Range::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto range_prim = this->primitive->value_as_Range(); + MS_ASSERT(range_prim != nullptr); + + int shape_size = std::ceil(static_cast(range_prim->limit() - range_prim->start()) / range_prim->delta()); + std::vector in_shape(1); + in_shape.push_back(shape_size); + output->set_shape(in_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc new file mode 100644 index 0000000000..5939396d16 --- /dev/null +++ b/mindspore/lite/src/ops/rank.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Rank::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + std::vector in_shape(1, 1); + output->set_shape(in_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc new file mode 100644 index 0000000000..76ce819977 --- /dev/null +++ b/mindspore/lite/src/ops/reduce.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr size_t kInputSize = 1; +constexpr size_t kOutputSize = 1; +} // namespace +int Reduce::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) { + return RET_ERROR; + } + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + return RET_NULL_PTR; + } + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto reduce_prim = this->primitive->value_as_Reduce(); + bool keep_dims = static_cast(reduce_prim->keepDims()); + std::vector in_shape = input->shape(); + std::vector out_shape; + const auto &axes = reduce_prim->axes(); + auto num_axes = axes->size(); + // reduce on all axes + if (num_axes == 0) { + if (keep_dims) { + for (auto i = 0; i < in_shape.size(); i++) { + out_shape.push_back(1); + } + } + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; + } + + // reduce on selected axes + for (size_t i = 0; i < in_shape.size(); i++) { + bool reduce_axis = false; + for (int idx = 0; idx < num_axes; ++idx) { + if (static_cast((*axes)[idx]) == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + out_shape.push_back(1); + } + } else { + out_shape.push_back(in_shape[i]); + } + } + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc new file mode 100644 index 0000000000..683e3f404e --- /dev/null +++ b/mindspore/lite/src/ops/reshape.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector *out_shape) const { + size_t in_shape_size = 1; + for (size_t i = 0; i < in_tensor->shape().size(); i++) { + in_shape_size *= in_tensor->shape()[i]; + } + + int64_t inferIndex = -1; + size_t out_shapeSize = 1; + for (size_t i = 0; i < out_shape->size(); i++) { + if (out_shape->at(i) == -1) { + if (inferIndex == -1) { + inferIndex = i; + } else { + MS_LOG(ERROR) << "output shape should has no more than one dim which need infer"; + return RET_ERROR; + } + } else if (out_shape->at(i) < 0) { + MS_LOG(ERROR) << "output shape dim should be non-negative"; + return RET_ERROR; + } else if (out_shape->at(i) == 0) { + out_shape->at(i) = in_tensor->shape().at(i); + out_shapeSize *= out_shape->at(i); + } else { + out_shapeSize *= out_shape->at(i); + } + } + + if (inferIndex == -1 && out_shapeSize != in_shape_size) { + MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size; + return RET_ERROR; + } + if (inferIndex != -1) { + out_shape->at(inferIndex) = in_shape_size / out_shapeSize; + } + return RET_OK; +} + +template +void CalShape(const T *data, const std::vector &inputs, std::vector *out_shape, int shape_size) { + int input_count = inputs[0]->ElementsNum(); + + int index = 0; + int size = 1; + for (size_t i = 0; i < shape_size; i++) { + if (data[i] == -1) { + index = i; + } else { + size *= data[i]; + } + out_shape->push_back(data[i]); + } + if (data[index] == -1) { + (*out_shape)[index] = input_count / size; + } +} + +int Reshape::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto reshape_prim = this->primitive->value_as_Reshape(); + MS_ASSERT(reshape_prim != nullptr); + + std::vector out_shape; + if (inputs_.size() == kDoubleNum) { + auto shape_tensor = inputs_.at(1); + size_t shape_size = shape_tensor->ElementsNum(); + switch (shape_tensor->data_type()) { + case kNumberTypeInt8: { + auto data = reinterpret_cast(shape_tensor->Data()); + CalShape(data, inputs_, &out_shape, shape_size); + } break; + case kNumberTypeInt32: { + auto data = reinterpret_cast(shape_tensor->Data()); + CalShape(data, inputs_, &out_shape, shape_size); + } break; + case kNumberTypeFloat: { + auto data = reinterpret_cast(shape_tensor->Data()); + CalShape(data, inputs_, &out_shape, shape_size); + } break; + case kNumberTypeUInt32: { + auto data = reinterpret_cast(shape_tensor->Data()); + CalShape(data, inputs_, &out_shape, shape_size); + } break; + default: { + MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); + return RET_ERROR; + } + } + } else if (inputs_.size() == kSingleNum) { + std::copy(reshape_prim->shape()->begin(), reshape_prim->shape()->end(), std::back_inserter(out_shape)); + } else { + MS_LOG(ERROR) << "inputs tensor size invalid."; + } + + auto ret = CalNewShape(inputs_.front(), &out_shape); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CalNewShape error"; + return ret; + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc new file mode 100644 index 0000000000..7ef277b6fb --- /dev/null +++ b/mindspore/lite/src/ops/resize.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +namespace mindspore::lite { +namespace { +constexpr int kInputRank = 4; +} // namespace +int Resize::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + if (input == nullptr) { + return RET_NULL_PTR; + } + auto output = outputs_.front(); + if (output == nullptr) { + return RET_NULL_PTR; + } + auto resize = GetAttrbute(); + auto new_height = resize->newHeight(); + auto new_width = resize->newWidth(); + + std::vector output_shape; + output_shape.push_back(input->Batch()); + output_shape.push_back(new_height); + output_shape.push_back(new_width); + output_shape.push_back(input->Channel()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/reverse_sequence.cc b/mindspore/lite/src/ops/reverse_sequence.cc new file mode 100644 index 0000000000..e7ff1e7e62 --- /dev/null +++ b/mindspore/lite/src/ops/reverse_sequence.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ReverseSequence::InferShape(std::vector inputs, std::vector outputs) { + auto input = inputs.front(); + auto output = outputs.front(); + MS_ASSERT(input != nullptr); + MS_ASSERT(output != nullptr); + + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc new file mode 100644 index 0000000000..5384edbf91 --- /dev/null +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kScatterNDInputNum = 3; +constexpr int kScatterNDOutputNum = 1; +constexpr int kScatterShapeIndex = 0; +constexpr int kScatterIndicesIndex = 1; +constexpr int kScatterUpdateIndex = 2; +} // namespace + +int ScatterND::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kScatterNDInputNum) { + MS_LOG(ERROR) << "inputs number is not equal to " << kScatterNDInputNum; + return RET_ERROR; + } + if (outputs_.size() != kScatterNDOutputNum) { + MS_LOG(ERROR) << "outputs number is not equal to " << kScatterNDInputNum; + return RET_ERROR; + } + auto shape = inputs_.at(kScatterShapeIndex); + if (shape == nullptr) { + MS_LOG(ERROR) << "shape null pointer dereferencing."; + return RET_ERROR; + } + auto indices = inputs_.at(kScatterIndicesIndex); + if (indices == nullptr) { + MS_LOG(ERROR) << "indices null pointer dereferencing."; + return RET_ERROR; + } + auto update = inputs_.at(kScatterUpdateIndex); + if (update == nullptr) { + MS_LOG(ERROR) << "update null pointer dereferencing."; + return RET_ERROR; + } + auto output = outputs_.front(); + auto shape_data = reinterpret_cast(shape->Data()); + std::vector out_shape(shape_data, shape_data + shape->DataSize()); + output->set_shape(out_shape); + output->set_data_type(update->data_type()); + output->SetFormat(update->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/shape.cc b/mindspore/lite/src/ops/shape.cc new file mode 100644 index 0000000000..6c4c239354 --- /dev/null +++ b/mindspore/lite/src/ops/shape.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kShapeInputNum = 1; +constexpr int kShapeOutputNum = 1; + +} // namespace +int Shape::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kShapeInputNum) { + MS_LOG(ERROR) << "inputs to Shape operator should be 1, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + if (outputs_.size() != kShapeOutputNum) { + MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given."; + return RET_ERROR; + } + + auto in_tensor = inputs_.front(); + auto out_tensor = outputs_.front(); + std::vector out_shape; + out_shape.push_back(static_cast(in_tensor->shape().size())); + + auto ret_shape = out_tensor->set_shape(out_shape); + if (ret_shape != 1 || size_t(out_tensor->shape()[0]) != in_tensor->shape().size()) { + MS_LOG(ERROR) << "Set shape fails."; + return RET_ERROR; + } + auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type()); + if (ret_dtype != in_tensor->data_type()) { + MS_LOG(ERROR) << "Set datatype fails."; + return RET_ERROR; + } + + // todo + // auto ret_data = out_tensor->MallocData(); + // if (ret_data != 0) { + // MS_LOG(ERROR) << "Allocate memory fails."; + // return RET_ERROR; + // } + + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc new file mode 100644 index 0000000000..c1fe1d396d --- /dev/null +++ b/mindspore/lite/src/ops/slice.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSliceInputNum = 1; +constexpr int kSliceOutputNum = 1; +} // namespace + +int Slice::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { + MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + auto input_shape = input->shape(); + auto slice_prim = this->primitive->value_as_Slice(); + std::vector slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end()); + std::vector slice_size(slice_prim->size()->begin(), slice_prim->size()->end()); + std::vector output_shape(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) { + if (slice_size[i] < 0 && slice_size[i] != -1) { + MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i]; + return RET_PARAM_INVALID; + } + if (slice_begin[i] < 0) { + MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0"; + return RET_PARAM_INVALID; + } + if (input_shape[i] <= slice_begin[i]) { + MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i] + << " which should be <= " << input_shape[i]; + return RET_PARAM_INVALID; + } + if (slice_size[i] > (input_shape[i] - slice_begin[i])) { + MS_LOG(ERROR) << "Invalid size input " << slice_size[i] + << " which should be <= " << input_shape[i] - slice_begin[i]; + return RET_PARAM_INVALID; + } + + output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i]; + } + + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc new file mode 100644 index 0000000000..ebce31098c --- /dev/null +++ b/mindspore/lite/src/ops/softmax.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int SoftMax::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc new file mode 100644 index 0000000000..c3cecc75d9 --- /dev/null +++ b/mindspore/lite/src/ops/space_to_batch.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSpaceToBatchNDOutputNum = 1; +constexpr int kSpaceToBatchNDInputNum = 1; +constexpr int kBlockSizesSize = 2; +constexpr int kPaddingsSize = 4; +} // namespace + +int SpaceToBatch::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "space_to_batch only support NHWC now!"; + return RET_FORMAT_ERR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kDimension_4d) { + MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + return RET_PARAM_INVALID; + } + + auto prim = this->primitive->value_as_SpaceToBatch(); + if (prim->blockShape()->size() != kBlockSizesSize) { + MS_LOG(ERROR) << "Block shape size should be " << kBlockSizesSize; + return RET_PARAM_INVALID; + } + if (prim->paddings()->size() != kPaddingsSize) { + MS_LOG(ERROR) << "Crops size should be " << kPaddingsSize; + return RET_PARAM_INVALID; + } + + for (auto iter = prim->blockShape()->begin(); iter != prim->blockShape()->end(); ++iter) { + block_sizes_.emplace_back(*iter); + } + + in_shape_.clear(); + padded_in_shape_.clear(); + paddings_.clear(); + in_shape_.emplace_back(input_shape.at(kNHWC_n_index)); + padded_in_shape_.emplace_back(input_shape.at(kNHWC_n_index)); + for (int i = 0; i < kBlockSizesSize; i++) { + in_shape_.emplace_back(input_shape.at(i + 1)); + padded_in_shape_.emplace_back(input_shape.at(i + 1) + (paddings_.at(2 * i) + paddings_.at(2 * i + 1))); + paddings_.emplace_back(paddings_.at(2 * i)); + paddings_.emplace_back(paddings_.at(2 * i + 1)); + if (paddings_.back() % block_sizes_.at(i)) { + MS_LOG(ERROR) << "Padded shape does not divide block size " << block_sizes_.at(i); + return RET_PARAM_INVALID; + } + } + in_shape_.emplace_back(input_shape.at(kNHWC_c_index)); + padded_in_shape_.emplace_back(input_shape.at(kNHWC_c_index)); + + std::vector output_shape(input_shape.size()); + output_shape[kNHWC_n_index] = + input_shape[kNHWC_n_index] * (block_sizes_[kNHWC_n_index] * block_sizes_[kNHWC_h_index]); + output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] / block_sizes_[kNHWC_n_index]; + output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] / block_sizes_[kNHWC_h_index]; + output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index]; + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc new file mode 100644 index 0000000000..647c64c909 --- /dev/null +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSpaceToDepthOutputNum = 1; +constexpr int kSpaceToDepthInputNum = 1; +} + +int SpaceToDepth::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kSpaceToDepthOutputNum || inputs.size() != kSpaceToDepthInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "space_to_depth only support NHWC now!"; + return RET_FORMAT_ERR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kDimension_4d) { + MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + return RET_PARAM_INVALID; + } + auto prim = this->primitive->value_as_SpaceToDepth(); + int32_t block_size = prim->blockSize(); + if (input_shape[kNHWC_c_index] % (block_size * block_size) != 0 || input_shape[kNHWC_c_index] == 0) { + MS_LOG(ERROR) << "input dimension c size " << input_shape[kNHWC_c_index] << " should be mulitple of block_size(" + << block_size << ") * block_size)!"; + return RET_PARAM_INVALID; + } + std::vector output_shape(input_shape.size()); + output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index]; + output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] / block_size; + output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] / block_size; + output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index] * (block_size * block_size); + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc new file mode 100644 index 0000000000..f8b175fa96 --- /dev/null +++ b/mindspore/lite/src/ops/split.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSplitInputNum = 1; +} // namespace +int Split::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto spilt_prim = this->primitive->value_as_Split(); + MS_ASSERT(spilt_prim != nullptr); + if (inputs_.size() != kSplitInputNum) { + MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum; + return RET_ERROR; + } + auto output = outputs_.front(); + if (output == nullptr) { + MS_LOG(ERROR) << "output null pointer dereferencing."; + return RET_ERROR; + } + int number_split = spilt_prim->numberSplit(); + if (outputs_.size() != number_split) { + MS_LOG(ERROR) << "outputs number is not equal to " << number_split; + return RET_ERROR; + } + int split_dim = spilt_prim->splitDim(); + std::vector input_shape = input->shape(); + std::vector size_split; + size_split.insert(size_split.begin(), spilt_prim->sizeSplits()->begin(), spilt_prim->sizeSplits()->end()); + + for (int i = 0; i < number_split; ++i) { + std::vector output_shape; + output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); + auto split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i]; + output_shape[split_dim] = split_dim_i; + outputs_[i]->set_shape(output_shape); + outputs_[i]->set_data_type(input->data_type()); + outputs_[i]->SetFormat(input->GetFormat()); + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc new file mode 100644 index 0000000000..4446f3ead5 --- /dev/null +++ b/mindspore/lite/src/ops/squeeze.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSqueezeInputNum = 1; +constexpr int kSqueezeOutputNum = 1; +} // namespace +int Squeeze::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (kSqueezeInputNum != inputs_.size()) { + MS_LOG(ERROR) << "Add should has " << kSqueezeInputNum << " inputs"; + return -1; + } + if (kSqueezeOutputNum != outputs_.size()) { + MS_LOG(ERROR) << "Add should has " << kSqueezeOutputNum << " outputs"; + return -1; + } + auto *in_tensor = inputs_.front(); + auto in_shape = in_tensor->shape(); + std::vector out_shape; + + // todo: getAxis + auto squeeze_prim = this->primitive->value_as_Squeeze(); + MS_EXCEPTION_IF_NULL(squeeze_prim); + auto axis = squeeze_prim->axis(); + std::vector axes_; + for (auto iter = axis->begin(); iter != axis->end(); iter++) { + axes_.push_back(*iter); + } + + if (axes_.size() == 0) { + for (int i = 0; i < in_shape.size(); i++) { + if (in_shape[i] != 1) { + out_shape.push_back(in_shape[i]); + } + } + } else { + int axisIdx = 0; + for (int i = 0; i < in_shape.size(); i++) { + if (axisIdx < axes_.size() && axes_[axisIdx] == i) { + MS_ASSERT(in_shape[i] == 1); + axisIdx++; + continue; + } else { + out_shape.push_back(in_shape[i]); + } + } + } + + outputs_.front()->set_shape(out_shape); + outputs_.front()->set_data_type(in_tensor->data_type()); + outputs_.front()->SetFormat(in_tensor->GetFormat()); + + return 0; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc new file mode 100644 index 0000000000..4b727c62b6 --- /dev/null +++ b/mindspore/lite/src/ops/stack.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kStackOutputNum = 1; +constexpr int kStackMinInputNum = 2; +} // namespace + +int Stack::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kStackOutputNum) { + MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + if (inputs.size() < kStackMinInputNum) { + MS_LOG(ERROR) << "Invalid input size " << inputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + auto input_shape = input->shape(); + auto stack_prim = this->primitive->value_as_Stack(); + std::vector output_shape = input_shape; + int axis = stack_prim->axis() < 0 ? stack_prim->axis() + input_shape.size() : stack_prim->axis(); + if (axis < 0 || axis > input_shape.size()) { + MS_LOG(ERROR) << "Invalid axis " << stack_prim->axis(); + return RET_PARAM_INVALID; + } + for (size_t i = 1; i < inputs.size(); ++i) { + auto input_shape_tmp = inputs[i]->shape(); + if (input_shape_tmp.size() != input_shape.size()) { + MS_LOG(ERROR) << "All input shape size should be the same!"; + return RET_PARAM_INVALID; + } + for (size_t j = 0; j < input_shape.size(); ++j) { + if (input_shape_tmp[j] != input_shape[j]) { + MS_LOG(ERROR) << "All input shape should be the same!"; + return RET_PARAM_INVALID; + } + } + } + + output_shape.insert(output_shape.begin() + axis, inputs.size()); + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc new file mode 100644 index 0000000000..6127c8a601 --- /dev/null +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "include/errorcode.h" +#include "src/ops/ops.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kStridedSliceOutputNum = 1; +constexpr int kStridedSliceInputNum = 1; +} // namespace + +void StridedSlice::ApplyNewAxisMask() { + for (int i = 0; i < new_axis_mask_.size(); i++) { + if (new_axis_mask_.at(i)) { + ndim_ += 1; + in_shape_.insert(in_shape_.begin() + i, 1); + begins_.at(i) = 0; + ends_.at(i) = 1; + strides_.at(i) = 1; + + begins_.emplace_back(0); + ends_.emplace_back(in_shape_.at(ndim_ - 1)); + strides_.emplace_back(1); + + begins_mask_.at(i) = false; + ends_mask_.at(i) = false; + ellipsis_mask_.at(i) = false; + shrink_axis_mask_.at(i) = false; + } + } +} + +std::vector StridedSlice::ApplyShrinkMask(std::vector out_shape) { + auto old_out_shape = out_shape; + out_shape.clear(); + for (int i = 0; i < shrink_axis_mask_.size(); i++) { + if (shrink_axis_mask_.at(i)) { + ends_.at(i) = begins_.at(i) + 1; + strides_.at(i) = 1; + } else { + out_shape.emplace_back(old_out_shape.at(i)); + } + } + for (int i = shrink_axis_mask_.size(); i < old_out_shape.size(); i++) { + out_shape.emplace_back(old_out_shape.at(i)); + } + return out_shape; +} + +/*only one bit will be used if multiple bits are true.*/ +void StridedSlice::ApplyEllipsisMask() { + for (int i = 0; i < ellipsis_mask_.size(); i++) { + if (ellipsis_mask_.at(i)) { + begins_.at(i) = 0; + ends_.at(i) = in_shape_.at(i); + break; + } + } +} + +void StridedSlice::ApplyBeginMask() { + for (int i = 0; i < ndim_; i++) { + if (begins_mask_.at(i)) { + begins_.at(i) = 0; + } + } +} + +void StridedSlice::ApplyEndMask() { + for (int i = 0; i < ndim_; i++) { + if (ends_mask_.at(i)) { + ends_.at(i) = in_shape_.at(i); + } + } +} + +int StridedSlice::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kStridedSliceOutputNum) { + MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + if (inputs.size() != kStridedSliceInputNum) { + MS_LOG(ERROR) << "Invalid input size " << inputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + MS_ASSERT(input != nullptr); + auto input_shape = input->shape(); + std::vector output_shape; + auto strided_slice_prim = this->primitive->value_as_StridedSlice(); + ndim_ = static_cast(strided_slice_prim->begin()->size()); + + MS_ASSERT(ndim_ == static_cast(strided_slice_prim->end()->size())); + MS_ASSERT(ndim_ == static_cast(strided_slice_prim->stride()->size())); + MS_ASSERT(ndim_ == static_cast(input_shape.size())); + + for (int i = 0; i < ndim_; i++) { + in_shape_.emplace_back(input_shape.at(i)); + begins_.emplace_back((*(strided_slice_prim->begin()))[i]); + ends_.emplace_back((*(strided_slice_prim->end()))[i]); + strides_.emplace_back((*(strided_slice_prim->stride()))[i]); + } + + // set all mask to original input shape + begins_mask_.resize(ndim_); + ends_mask_.resize(ndim_); + ellipsis_mask_.resize(ndim_); + new_axis_mask_.resize(ndim_); + shrink_axis_mask_.resize(ndim_); + + // convert bit to vector + for (int i = 0; i < ndim_; i++) { + begins_mask_.at(i) = static_cast(strided_slice_prim->beginMask()) & (1 << i); + ends_mask_.at(i) = static_cast(strided_slice_prim->endMask()) & (1 << i); + ellipsis_mask_.at(i) = static_cast(strided_slice_prim->ellipsisMask()) & (1 << i); + new_axis_mask_.at(i) = static_cast(strided_slice_prim->newAxisMask()) & (1 << i); + shrink_axis_mask_.at(i) = static_cast(strided_slice_prim->shrinkAxisMask()) & (1 << i); + } + + ApplyNewAxisMask(); + ApplyBeginMask(); + ApplyEndMask(); + ApplyEllipsisMask(); + + output_shape.clear(); + output_shape.resize(in_shape_.size()); + for (int i = 0; i < in_shape_.size(); i++) { + if (i < ndim_ && new_axis_mask_.at(i)) { + output_shape.at(i) = 1; + } else { + output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); + } + } + + output_shape = ApplyShrinkMask(output_shape); + + outputs.front()->set_shape(output_shape); + outputs.front()->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc new file mode 100644 index 0000000000..42f6bd2071 --- /dev/null +++ b/mindspore/lite/src/ops/tile.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Tile::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto tile_prim = this->primitive->value_as_Tile(); + MS_ASSERT(tile_prim != nullptr); + + std::vector out_shape; + std::vector multiples; + std::copy(tile_prim->multiples()->begin(), tile_prim->multiples()->end(), std::back_inserter(multiples)); + for (size_t i = 0; i < input->shape().size(); ++i) { + int tmp = input->shape()[i] * multiples[i]; + out_shape.push_back(tmp); + } + + output->SetFormat(input->GetFormat()); + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc new file mode 100644 index 0000000000..e3dbee034e --- /dev/null +++ b/mindspore/lite/src/ops/topk.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int TopK::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output0 = outputs_.front(); + MS_ASSERT(output0 != nullptr); + auto output1 = outputs_.at(1); + MS_ASSERT(output1 != nullptr); + auto topk_prim = this->primitive->value_as_TopK(); + MS_ASSERT(topk_prim != nullptr); + + auto out_shape = input->shape(); + out_shape[out_shape.size() - 1] = topk_prim->k(); + + output0->set_shape(out_shape); + output0->set_data_type(input->data_type()); + output0->SetFormat(input->GetFormat()); + + output1->set_shape(out_shape); + output1->set_data_type(kNumberTypeInt32); + output1->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc new file mode 100644 index 0000000000..ba9366c233 --- /dev/null +++ b/mindspore/lite/src/ops/transpose.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Transpose::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + MS_ASSERT(inputs_.size() == kSingleNum); + MS_ASSERT(outputs_.size() == kSingleNum); + auto transpore_prim = this->primitive->value_as_Transpose(); + int conjugate = transpore_prim->conjugate(); + if (conjugate) { + MS_LOG(ERROR) << "Transpose conjugate is not support currently"; + return RET_ERROR; + } + std::vector perm; + perm.insert(perm.begin(), transpore_prim->perm()->begin(), transpore_prim->perm()->end()); + + std::vector in_shape = input->shape(); + std::vector out_shape; + out_shape.resize(perm.size()); + for (int i = 0; i < perm.size(); ++i) { + out_shape[i] = in_shape[perm[i]]; + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/unique.cc b/mindspore/lite/src/ops/unique.cc new file mode 100644 index 0000000000..139e0e7113 --- /dev/null +++ b/mindspore/lite/src/ops/unique.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Unique::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto &input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto &output0 = outputs_.at(0); + MS_ASSERT(output0 != nullptr); + auto &output1 = outputs_.at(1); + MS_ASSERT(output1 != nullptr); + output0->set_shape(input->shape()); + output0->set_data_type(input->data_type()); + output1->set_shape(input->shape()); + output1->set_data_type(kNumberTypeInt32); + output1->SetFormat(input->GetFormat()); + output0->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc new file mode 100644 index 0000000000..586defcbd5 --- /dev/null +++ b/mindspore/lite/src/ops/unsqueeze.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Unsqueeze::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size is invalid"; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output size is invalid"; + } + auto unsqueeze_prim = this->primitive->value_as_Unsqueeze(); + auto dims = unsqueeze_prim->axis()->data(); + auto in_shape = input->shape(); + auto in_rank = in_shape.size(); + auto dim_rank = unsqueeze_prim->axis()->size(); + std::vector out_shape; + + if (dim_rank == 0) { + for (auto d : in_shape) { + if (d != 1) { + out_shape.push_back(d); + } + } + } else { + auto sz = in_rank + dim_rank; + int in_itr = 0; + int ax_itr = 0; + for (int i = 0; i < sz; i++) { + if (ax_itr < dim_rank && dims[ax_itr] == i) { + out_shape.emplace_back(1); + ax_itr++; + } else if (ax_itr < dim_rank && dims[ax_itr] + sz == i) { + out_shape.emplace_back(1); + ax_itr++; + } else { + if (in_shape[in_itr] > 1) { + out_shape.emplace_back(in_shape[in_itr]); + } + in_itr++; + } + } + } + + output->SetFormat(input->GetFormat()); + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc new file mode 100644 index 0000000000..e11f0d71e1 --- /dev/null +++ b/mindspore/lite/src/ops/unstack.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Unstack::InferShape(std::vector inputs, std::vector outputs) { + auto input = inputs.at(0); + MS_ASSERT(input != nullptr); + auto input_shape = input->shape(); + auto prim = this->primitive->value_as_Unstack(); + int axis = prim->axis() < 0 ? prim->axis() + input_shape.size() : prim->axis(); + if (axis < 0 || axis >= input_shape.size()) { + MS_LOG(ERROR) << "Invalid axis " << prim->axis(); + return RET_PARAM_INVALID; + } + + std::vector output_shape; + for (size_t i = 0; i < input_shape.size(); ++i) { + if (i != axis) { + output_shape.push_back(input_shape.at(i)); + } + } + for (auto &out : outputs) { + MS_ASSERT(out != nullptr); + out->set_shape(output_shape); + out->set_data_type(input->data_type()); + out->SetFormat(input->GetFormat()); + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc new file mode 100644 index 0000000000..ef3d3d0d03 --- /dev/null +++ b/mindspore/lite/src/ops/where.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Where::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "where input or output number invalid, Input size:" << inputs_.size() + << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + if (inputs_.size() < 3) { + MS_LOG(ERROR) << "Input shape tensors should b"; + return RET_INPUT_TENSOR_ERROR; + } + auto input0 = inputs_.at(0); + auto input1 = inputs_.at(1); + auto input2 = inputs_.at(2); + int num = input0->ElementsNum(); + int num1 = input1->ElementsNum(); + int num2 = input2->ElementsNum(); + int nummax = num > num1 ? num : (num1 > num2 ? num1 : num2); + + auto shape_tmp = inputs_.at(0)->shape(); + auto shape_tmp1 = inputs_.at(1)->shape(); + auto shape_tmp2 = inputs_.at(2)->shape(); + int axisout = 0; + int temp = 0; + for (int j = 0; j < shape_tmp.size(); j++) { + if (shape_tmp[j] == shape_tmp1[j] && shape_tmp[j] != shape_tmp2[j]) { + axisout = j; + break; + } + if (shape_tmp[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) { + axisout = j; + break; + } + if (shape_tmp1[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) { + axisout = j; + break; + } + temp += 1; + if (temp == shape_tmp.size()) { + outputs_[0]->set_shape(shape_tmp); + output->set_data_type(input->data_type()); + return RET_OK; + } + } + + auto output_shape = shape_tmp; + output_shape[axisout] = nummax; + outputs_[0]->set_shape(output_shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/zeroslike.cc b/mindspore/lite/src/ops/zeroslike.cc new file mode 100644 index 0000000000..d14b96b385 --- /dev/null +++ b/mindspore/lite/src/ops/zeroslike.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ZerosLike::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "zeroslike input or output number invalid, Input size:" << inputs_.size() + << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/param_value_lite.h b/mindspore/lite/src/param_value_lite.h new file mode 100644 index 0000000000..8f806a4f9c --- /dev/null +++ b/mindspore/lite/src/param_value_lite.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ +#define MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ + +#include +#include +#include +#include + +#include "ir/param_value.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +struct AnfQuantParam { + double scale; + int32_t zeroPoint; + double min; + double max; + bool narrowRange; + bool inited; + int32_t numBits; + AnfQuantParam() : scale(1.0), zeroPoint(0), min(0.0), max(0.0), narrowRange(false), numBits(8), inited(false) {} +}; +class ParamValueLite : public Value { + public: + ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} + virtual ~ParamValueLite() = default; + + size_t tensor_size() const { return tensor_size_; } + void set_tensor_size(size_t size) { tensor_size_ = size; } + // todo + void *tensor_addr() const { return tensor_addr_; } + void set_tensor_addr(void *addr) { tensor_addr_ = addr; } + + std::vector tensor_shape() const { return tensor_shape_; } + void set_tensor_shape(std::vector tensor_shape) { tensor_shape_ = std::move(tensor_shape); } + + TypeId tensor_type() const { return type_id_; } + void set_tensor_type(TypeId type_id) { type_id_ = type_id; } + + int tensor_shape_size() const { + int size = 1; + for (auto val : tensor_shape_) { + size *= val; + } + return size; + } + std::vector> &quant_param() { return quant_params_; } + void set_quant_param(std::unique_ptr &quant_param) { + quant_params_.emplace_back(std::move(quant_param)); + } + + bool operator==(const Value &other) const override { + this == &other; + } + + private: + void *tensor_addr_; + size_t tensor_size_; + std::vector tensor_shape_; + TypeId type_id_; + std::vector> quant_params_; +}; + +using ParamValueLitePtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ + diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc new file mode 100644 index 0000000000..22e5b42deb --- /dev/null +++ b/mindspore/lite/src/populate_parameter.cc @@ -0,0 +1,1339 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/populate_parameter.h" +#include +#include "src/ops/ops.h" +#include "utils/log_adapter.h" +#include "schema/ops_generated.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h" +#include "src/runtime/kernel/arm/nnacl/fp32/cast.h" +#include "src/runtime/kernel/arm/nnacl/concat_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/slice.h" +#include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" +#include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/stack.h" +#include "src/runtime/kernel/arm/nnacl/unstack.h" +#include "src/runtime/kernel/arm/nnacl/depth_to_space.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" +#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" +#include "src/runtime/kernel/arm/nnacl/tile.h" +#include "src/runtime/kernel/arm/nnacl/fp32/topk.h" +#include "src/runtime/kernel/arm/nnacl/fp32/reduce.h" +#include "src/runtime/kernel/arm/nnacl/fp32/activation.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/nnacl/fused_batchnorm.h" +#include "src/runtime/kernel/arm/nnacl/fp32/batchnorm.h" +#include "src/runtime/kernel/arm/nnacl/power.h" +#include "src/runtime/kernel/arm/nnacl/fp32/range.h" +#include "src/runtime/kernel/arm/nnacl/fp32/local_response_norm.h" +#include "src/runtime/kernel/arm/nnacl/fp32/expandDims.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h" +#include "src/runtime/kernel/arm/nnacl/pad_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/fill.h" +#include "src/runtime/kernel/arm/nnacl/transpose.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" +#include "src/runtime/kernel/arm/nnacl/squeeze.h" +#include "src/runtime/kernel/arm/nnacl/fp32/gather.h" +#include "src/runtime/kernel/arm/nnacl/fp32/reverse.h" +#include "src/runtime/kernel/arm/nnacl/reverse_sequence.h" +#include "src/runtime/kernel/arm/nnacl/unique.h" +#include "src/runtime/kernel/arm/nnacl/scale.h" +#include "src/runtime/kernel/arm/nnacl/fp32/gatherNd.h" +#include "src/runtime/kernel/arm/nnacl/resize.h" +#include "src/runtime/kernel/arm/nnacl/scatter_nd.h" +#include "src/runtime/kernel/arm/nnacl/batch_to_space.h" +#include "src/runtime/kernel/arm/nnacl/fp32/crop.h" +#include "src/runtime/kernel/arm/fp32/flatten.h" +#include "src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h" +#include "src/runtime/kernel/arm/nnacl/fp32/one_hot.h" +#include "src/runtime/kernel/arm/nnacl/strided_slice.h" +#include "src/runtime/kernel/arm/base/prior_box.h" +#include "src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h" +#include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h" +#include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h" +#include "src/runtime/kernel/arm/nnacl/fp32/lstm.h" +#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" + +namespace mindspore::kernel { +OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) { + BatchNormParameter *batch_norm_param = new (std::nothrow) BatchNormParameter(); + if (batch_norm_param == nullptr) { + MS_LOG(ERROR) << "new BatchNormParameter failed."; + return nullptr; + } + batch_norm_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_BatchNorm(); + batch_norm_param->epsilon_ = param->epsilon(); + return reinterpret_cast(batch_norm_param); +} + +OpParameter *PopulateFillParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_Fill(); + FillParameter *fill_param = new (std::nothrow) FillParameter(); + if (fill_param == nullptr) { + MS_LOG(ERROR) << "new FillParameter failed."; + return nullptr; + } + fill_param->op_parameter_.type_ = primitive->Type(); + auto flatDims = param->dims(); + fill_param->num_dims_ = flatDims->size(); + int i = 0; + for (auto iter = flatDims->begin(); iter != flatDims->end(); iter++) { + fill_param->dims_[i++] = *iter; + } + return reinterpret_cast(fill_param); +} + +OpParameter *PopulateExpandDimsParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_ExpandDims(); + ExpandDimsParameter *expand_dims_param = new (std::nothrow) ExpandDimsParameter(); + if (expand_dims_param == nullptr) { + MS_LOG(ERROR) << "new ExpandDimsParameter failed."; + return nullptr; + } + expand_dims_param->op_parameter_.type_ = primitive->Type(); + expand_dims_param->dim_ = param->dim(); + return reinterpret_cast(expand_dims_param); +} + +OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) { + auto pooling_primitive = primitive->Value()->value_as_Pooling(); + // todo use malloc instead + PoolingParameter *pooling_param = new (std::nothrow) PoolingParameter(); + if (pooling_param == nullptr) { + MS_LOG(ERROR) << "new PoolingParameter failed."; + return nullptr; + } + pooling_param->op_parameter_.type_ = primitive->Type(); + pooling_param->global_ = pooling_primitive->global(); + pooling_param->window_w_ = pooling_primitive->windowW(); + pooling_param->window_h_ = pooling_primitive->windowH(); + // todo format + auto pooling_lite_primitive = (lite::Pooling *)primitive; + MS_ASSERT(nullptr != pooling_lite_primitive); + pooling_param->pad_u_ = pooling_lite_primitive->PadUp(); + pooling_param->pad_d_ = pooling_lite_primitive->PadDown(); + pooling_param->pad_l_ = pooling_lite_primitive->PadLeft(); + pooling_param->pad_r_ = pooling_lite_primitive->PadRight(); + pooling_param->stride_w_ = pooling_primitive->strideW(); + pooling_param->stride_h_ = pooling_primitive->strideH(); + + auto pool_mode = pooling_primitive->poolingMode(); + switch (pool_mode) { + case schema::PoolMode_MAX_POOLING: + pooling_param->max_pooling_ = true; + pooling_param->avg_pooling_ = false; + break; + case schema::PoolMode_MEAN_POOLING: + pooling_param->max_pooling_ = false; + pooling_param->avg_pooling_ = true; + break; + default: + pooling_param->max_pooling_ = false; + pooling_param->avg_pooling_ = false; + break; + } + + auto round_mode = pooling_primitive->roundMode(); + switch (round_mode) { + case schema::RoundMode_FLOOR: + pooling_param->round_floor_ = true; + pooling_param->round_ceil_ = false; + break; + case schema::RoundMode_CEIL: + pooling_param->round_floor_ = false; + pooling_param->round_ceil_ = true; + break; + default: + pooling_param->round_floor_ = false; + pooling_param->round_ceil_ = false; + break; + } + return reinterpret_cast(pooling_param); +} + +OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_FullConnection(); + MatMulParameter *matmul_param = new (std::nothrow) MatMulParameter(); + if (matmul_param == nullptr) { + MS_LOG(ERROR) << "new FullconnectionParameter failed."; + return nullptr; + } + matmul_param->op_parameter_.type_ = primitive->Type(); + matmul_param->b_transpose_ = true; + matmul_param->a_transpose_ = false; + matmul_param->has_bias_ = param->hasBias(); + matmul_param->act_type_ = ActType_No; + return reinterpret_cast(matmul_param); +} + +OpParameter *PopulateMatMulParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_MatMul(); + MatMulParameter *matmul_param = new (std::nothrow) MatMulParameter(); + if (matmul_param == nullptr) { + MS_LOG(ERROR) << "new FullconnectionParameter failed."; + return nullptr; + } + matmul_param->op_parameter_.type_ = primitive->Type(); + matmul_param->b_transpose_ = param->transposeB(); + matmul_param->a_transpose_ = param->transposeA(); + matmul_param->has_bias_ = false; + matmul_param->act_type_ = ActType_No; + return reinterpret_cast(matmul_param); +} + +OpParameter *PopulateConvParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new (std::nothrow) ConvParameter(); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = primitive->Value()->value_as_Conv2D(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); + // todo format + conv_param->group_ = conv_primitive->group(); + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); + + auto conv2d_lite_primitive = (lite::Conv2D *)primitive; + MS_ASSERT(nullptr != conv2d_lite_primitive); + conv_param->pad_u_ = conv2d_lite_primitive->PadUp(); + conv_param->pad_d_ = conv2d_lite_primitive->PadDown(); + conv_param->pad_l_ = conv2d_lite_primitive->PadLeft(); + conv_param->pad_r_ = conv2d_lite_primitive->PadRight(); + conv_param->pad_h_ = conv2d_lite_primitive->PadUp(); + conv_param->pad_w_ = conv2d_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); + conv_param->input_channel_ = conv_primitive->channelIn(); + conv_param->output_channel_ = conv_primitive->channelOut(); + conv_param->group_ = conv_primitive->group(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; + break; + default: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; + break; + } + return reinterpret_cast(conv_param); +} + +OpParameter *PopulateConvDwParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new (std::nothrow) ConvParameter(); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = primitive->Value()->value_as_DepthwiseConv2D(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); + // todo format, group + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); + + auto pad_mode = conv_primitive->padMode(); + auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive; + MS_ASSERT(nullptr != convdw_lite_primitive); + conv_param->pad_u_ = convdw_lite_primitive->PadUp(); + conv_param->pad_d_ = convdw_lite_primitive->PadDown(); + conv_param->pad_l_ = convdw_lite_primitive->PadLeft(); + conv_param->pad_r_ = convdw_lite_primitive->PadRight(); + conv_param->pad_h_ = convdw_lite_primitive->PadUp(); + conv_param->pad_w_ = convdw_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; + break; + default: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; + break; + } + return reinterpret_cast(conv_param); +} + +OpParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new ConvParameter(); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = primitive->Value()->value_as_DeDepthwiseConv2D(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); + // todo format, group + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); + + auto deconvdw_lite_primitive = (lite::DeconvDepthwiseConv2D *)primitive; + MS_ASSERT(nullptr != deconvdw_lite_primitive); + conv_param->pad_u_ = deconvdw_lite_primitive->PadUp(); + conv_param->pad_d_ = deconvdw_lite_primitive->PadDown(); + conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft(); + conv_param->pad_r_ = deconvdw_lite_primitive->PadRight(); + conv_param->pad_h_ = deconvdw_lite_primitive->PadUp(); + conv_param->pad_w_ = deconvdw_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; + break; + default: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; + break; + } + return reinterpret_cast(conv_param); +} + +OpParameter *PopulateDeconvParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new ConvParameter(); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = primitive->Value()->value_as_DeConv2D(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); + + auto deconv_lite_primitive = (lite::DeConv2D *)primitive; + MS_ASSERT(nullptr != deconvdw_lite_primitive); + conv_param->pad_u_ = deconv_lite_primitive->PadUp(); + conv_param->pad_d_ = deconv_lite_primitive->PadDown(); + conv_param->pad_l_ = deconv_lite_primitive->PadLeft(); + conv_param->pad_r_ = deconv_lite_primitive->PadRight(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; + break; + default: + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; + break; + } + + auto pad_mode = conv_primitive->padMode(); + switch (pad_mode) { + case schema::PadMode_SAME: + conv_param->pad_h_ = (conv_param->kernel_h_ - 1) / 2; + conv_param->pad_w_ = (conv_param->kernel_w_ - 1) / 2; + break; + case schema::PadMode_VALID: + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + break; + case schema::PadMode_CAFFE: + conv_param->pad_h_ = conv_param->pad_u_; + conv_param->pad_w_ = conv_param->pad_l_; + break; + default: + MS_LOG(ERROR) << "invalid pad mode!"; + return nullptr; + } + + return reinterpret_cast(conv_param); +} + +OpParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) { + auto softmax_primitive = primitive->Value()->value_as_SoftMax(); + SoftmaxParameter *softmax_param = new (std::nothrow) SoftmaxParameter(); + if (softmax_param == nullptr) { + MS_LOG(ERROR) << "new SoftmaxParameter failed."; + return nullptr; + } + softmax_param->op_parameter_.type_ = primitive->Type(); + softmax_param->axis_ = softmax_primitive->axis(); + return reinterpret_cast(softmax_param); +} + +OpParameter *PopulateReduceParameter(const lite::Primitive *primitive) { + ReduceParameter *reduce_param = new (std::nothrow) ReduceParameter(); + if (reduce_param == nullptr) { + MS_LOG(ERROR) << "new ReduceParameter failed."; + return nullptr; + } + reduce_param->op_parameter_.type_ = primitive->Type(); + auto reduce = primitive->Value()->value_as_Reduce(); + reduce_param->keep_dims_ = reduce->keepDims(); + auto axisVector = reduce->axes(); + if (axisVector->size() > REDUCE_MAX_AXES_NUM) { + MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM; + delete (reduce_param); + return nullptr; + } + reduce_param->num_axes_ = static_cast(axisVector->size()); + int i = 0; + for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) { + reduce_param->axes_[i++] = *iter; + } + reduce_param->mode_ = static_cast(reduce->mode()); + return reinterpret_cast(reduce_param); +} + +OpParameter *PopulateMeanParameter(const lite::Primitive *primitive) { + ReduceParameter *mean_param = new (std::nothrow) ReduceParameter(); + if (mean_param == nullptr) { + MS_LOG(ERROR) << "new ReduceParameter failed."; + return nullptr; + } + mean_param->op_parameter_.type_ = primitive->Type(); + auto mean = primitive->Value()->value_as_Mean(); + mean_param->keep_dims_ = mean->keepDims(); + auto axisVector = mean->axis(); + if (axisVector->size() > REDUCE_MAX_AXES_NUM) { + MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM; + delete (mean_param); + return nullptr; + } + mean_param->num_axes_ = static_cast(axisVector->size()); + int i = 0; + for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) { + mean_param->axes_[i++] = *iter; + } + mean_param->mode_ = static_cast(schema::ReduceMode_ReduceMean); + return reinterpret_cast(mean_param); +} + +OpParameter *PopulatePadParameter(const lite::Primitive *primitive) { + PadParameter *pad_param = new (std::nothrow) PadParameter(); + if (pad_param == nullptr) { + MS_LOG(ERROR) << "new PadParameter failed."; + return nullptr; + } + pad_param->op_parameter_.type_ = primitive->Type(); + auto pad_node = primitive->Value()->value_as_Pad(); + pad_param->pad_mode_ = pad_node->paddingMode(); + if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) { + pad_param->constant_value_ = pad_node->constantValue(); + } else { + MS_LOG(ERROR) << "Invalid padding mode: " << pad_param->pad_mode_; + delete (pad_param); + return nullptr; + } + + auto size = pad_node->paddings()->size(); + if (size > MAX_PAD_SIZE) { + MS_LOG(ERROR) << "Invalid padding size: " << size; + delete (pad_param); + return nullptr; + } + + for (size_t i = 0; i < size; i++) { + pad_param->paddings_[MAX_PAD_SIZE - size + i] = (*(pad_node->paddings()))[i]; + } + return reinterpret_cast(pad_param); +} + +OpParameter *PopulateActivationParameter(const lite::Primitive *primitive) { + ActivationParameter *act_param = new (std::nothrow) ActivationParameter(); + if (act_param == nullptr) { + MS_LOG(ERROR) << "new ActivationParameter failed."; + return nullptr; + } + auto activation = primitive->Value()->value_as_Activation(); + act_param->type_ = static_cast(activation->type()); + return reinterpret_cast(act_param); +} + +OpParameter *PopulateFusedBatchNorm(const lite::Primitive *primitive) { + FusedBatchNormParameter *fuse_batch_norm_param = new (std::nothrow) FusedBatchNormParameter(); + if (fuse_batch_norm_param == nullptr) { + MS_LOG(ERROR) << "new FusedBatchNormParameter failed."; + return nullptr; + } + fuse_batch_norm_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_FusedBatchNorm(); + fuse_batch_norm_param->epsilon_ = param->epsilon(); + return reinterpret_cast(fuse_batch_norm_param); +} + +OpParameter *PopulateArithmetic(const lite::Primitive *primitive) { + ArithmeticParameter *arithmetic_param = new (std::nothrow) ArithmeticParameter(); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "new ArithmeticParameter failed."; + return nullptr; + } + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + switch (primitive->Type()) { + case schema::PrimitiveType_Add: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Add()->activationType(); + break; + case schema::PrimitiveType_Sub: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Sub()->activationType(); + break; + case schema::PrimitiveType_Mul: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Mul()->activationType(); + break; + case schema::PrimitiveType_Div: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Div()->activationType(); + break; + default: + arithmetic_param->activation_type_ = 0; + break; + } + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + (void)memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + (void)memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + (void)memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); +} + +OpParameter *PopulateEltwiseParameter(const lite::Primitive *primitive) { + ArithmeticParameter *arithmetic_param = new (std::nothrow) ArithmeticParameter(); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "new ArithmeticParameter failed."; + return nullptr; + } + auto eltwise = primitive->Value()->value_as_Eltwise(); + switch (eltwise->mode()) { + case schema::EltwiseMode_PROD: + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul; + break; + case schema::EltwiseMode_SUM: + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add; + break; + case schema::EltwiseMode_MAXIMUM: + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum; + break; + default: + delete arithmetic_param; + return nullptr; + } + return reinterpret_cast(arithmetic_param); +} + +OpParameter *PopulateArithmeticSelf(const lite::Primitive *primitive) { + ArithmeticSelfParameter *arithmetic_self_param = new (std::nothrow) ArithmeticSelfParameter(); + if (arithmetic_self_param == nullptr) { + MS_LOG(ERROR) << "new ArithmeticParameter failed."; + return nullptr; + } + arithmetic_self_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(arithmetic_self_param); +} + +OpParameter *PopulatePowerParameter(const lite::Primitive *primitive) { + PowerParameter *power_param = new (std::nothrow) PowerParameter(); + if (power_param == nullptr) { + MS_LOG(ERROR) << "new PowerParameter failed."; + return nullptr; + } + power_param->op_parameter_.type_ = primitive->Type(); + auto power = primitive->Value()->value_as_Power(); + power_param->power_ = power->power(); + power_param->scale_ = power->scale(); + power_param->shift_ = power->shift(); + return reinterpret_cast(power_param); +} + +OpParameter *PopulateArgMaxParameter(const lite::Primitive *primitive) { + ArgMinMaxParameter *arg_param = new (std::nothrow) ArgMinMaxParameter(); + if (arg_param == nullptr) { + MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; + return nullptr; + } + arg_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_ArgMax(); + arg_param->axis_ = param->axis(); + arg_param->topk_ = param->topK(); + arg_param->axis_type_ = param->axisType(); + arg_param->out_value_ = param->outMaxValue(); + arg_param->keep_dims_ = param->keepDims(); + return reinterpret_cast(arg_param); +} + +OpParameter *PopulateArgMinParameter(const lite::Primitive *primitive) { + ArgMinMaxParameter *arg_param = new (std::nothrow) ArgMinMaxParameter(); + if (arg_param == nullptr) { + MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; + return nullptr; + } + arg_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_ArgMin(); + arg_param->axis_ = param->axis(); + arg_param->topk_ = param->topK(); + arg_param->axis_type_ = param->axisType(); + arg_param->out_value_ = param->outMaxValue(); + arg_param->keep_dims_ = param->keepDims(); + return reinterpret_cast(arg_param); +} + +OpParameter *PopulateCastParameter(const lite::Primitive *primitive) { + CastParameter *cast_param = new (std::nothrow) CastParameter(); + if (cast_param == nullptr) { + MS_LOG(ERROR) << "new CastParameter failed."; + return nullptr; + } + cast_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Cast(); + cast_param->src_type_ = param->srcT(); + cast_param->dst_type_ = param->dstT(); + return reinterpret_cast(cast_param); +} + +OpParameter *PopulateLocalResponseNormParameter(const lite::Primitive *primitive) { + auto local_response_norm_attr = primitive->Value()->value_as_LocalResponseNormalization(); + LocalResponseNormParameter *lrn_param = new (std::nothrow) LocalResponseNormParameter(); + if (lrn_param == nullptr) { + MS_LOG(ERROR) << "new LocalResponseNormParameter failed."; + return nullptr; + } + lrn_param->op_parameter_.type_ = primitive->Type(); + lrn_param->depth_radius_ = local_response_norm_attr->depth_radius(); + lrn_param->bias_ = local_response_norm_attr->bias(); + lrn_param->alpha_ = local_response_norm_attr->alpha(); + lrn_param->beta_ = local_response_norm_attr->beta(); + return reinterpret_cast(lrn_param); +} + +OpParameter *PopulateRangeParameter(const lite::Primitive *primitive) { + auto range_attr = primitive->Value()->value_as_Range(); + RangeParameter *range_param = new (std::nothrow) RangeParameter(); + if (range_param == nullptr) { + MS_LOG(ERROR) << "new RangeParameter failed."; + return nullptr; + } + range_param->op_parameter_.type_ = primitive->Type(); + range_param->start_ = range_attr->start(); + range_param->limit_ = range_attr->limit(); + range_param->delta_ = range_attr->delta(); + range_param->dType_ = range_attr->dType(); + return reinterpret_cast(range_param); +} + +OpParameter *PopulateConcatParameter(const lite::Primitive *primitive) { + ConcatParameter *concat_param = new (std::nothrow) ConcatParameter(); + if (concat_param == nullptr) { + MS_LOG(ERROR) << "new ConcatParameter failed."; + return nullptr; + } + concat_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Concat(); + concat_param->axis_ = param->axis(); + return reinterpret_cast(concat_param); +} + +OpParameter *PopulateTileParameter(const lite::Primitive *primitive) { + TileParameter *tile_param = new (std::nothrow) TileParameter(); + if (tile_param == nullptr) { + MS_LOG(ERROR) << "new TileParameter failed."; + return nullptr; + } + tile_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Tile(); + auto multiples = param->multiples(); + tile_param->in_dim_ = multiples->size(); + for (size_t i = 0; i < tile_param->in_dim_; ++i) { + tile_param->multiples_[i] = multiples->Get(i); + } + return reinterpret_cast(tile_param); +} + +OpParameter *PopulateTopKParameter(const lite::Primitive *primitive) { + TopkParameter *topk_param = new (std::nothrow) TopkParameter(); + if (topk_param == nullptr) { + MS_LOG(ERROR) << "new TopkParameter failed."; + return nullptr; + } + topk_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_TopK(); + topk_param->k_ = param->k(); + topk_param->sorted_ = param->sorted(); + return reinterpret_cast(topk_param); +} + +OpParameter *PopulateNhwc2NchwParameter(const lite::Primitive *primitive) { + OpParameter *parameter = new (std::nothrow) OpParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new Nhwc2NchwParameter failed."; + return nullptr; + } + parameter->type_ = primitive->Type(); + return parameter; +} + +OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) { + OpParameter *parameter = new (std::nothrow) OpParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new Nchw2NhwcParameter failed."; + return nullptr; + } + parameter->type_ = primitive->Type(); + return parameter; +} + +OpParameter *PopulateTransposeParameter(const lite::Primitive *primitive) { + TransposeParameter *transpose_param = new (std::nothrow) TransposeParameter(); + if (transpose_param == nullptr) { + MS_LOG(ERROR) << "new TransposeParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Transpose(); + transpose_param->op_parameter_.type_ = primitive->Type(); + auto perm_vector_ = param->perm(); + int i = 0; + for (auto iter = perm_vector_->begin(); iter != perm_vector_->end(); iter++) { + transpose_param->perm_[i++] = *iter; + } + transpose_param->num_axes_ = i; + transpose_param->conjugate_ = param->conjugate(); + return reinterpret_cast(transpose_param); +} + +OpParameter *PopulateSplitParameter(const lite::Primitive *primitive) { + SplitParameter *split_param = new (std::nothrow) SplitParameter(); + if (split_param == nullptr) { + MS_LOG(ERROR) << "new SplitParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Split(); + split_param->op_parameter_.type_ = primitive->Type(); + split_param->num_split_ = param->numberSplit(); + auto split_sizes_vector_ = param->sizeSplits(); + int i = 0; + for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); iter++) { + split_param->split_sizes_[i++] = *iter; + } + split_param->split_dim_ = param->splitDim(); + split_param->num_split_ = param->numberSplit(); + return reinterpret_cast(split_param); +} + +OpParameter *PopulateSqueezeParameter(const lite::Primitive *primitive) { + SqueezeParameter *squeeze_param = new (std::nothrow) SqueezeParameter(); + if (squeeze_param == nullptr) { + MS_LOG(ERROR) << "new SqueezeParameter failed."; + return nullptr; + } + squeeze_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(squeeze_param); +} + +OpParameter *PopulateScaleParameter(const lite::Primitive *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "input primitive is nullptr"; + return nullptr; + } + ScaleParameter *scale_param = new (std::nothrow) ScaleParameter(); + if (scale_param == nullptr) { + MS_LOG(ERROR) << "new ScaleParameter failed."; + return nullptr; + } + scale_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Scale(); + if (param == nullptr) { + MS_LOG(ERROR) << "value_as_Scale return nullptr"; + return nullptr; + } + scale_param->axis_ = param->axis(); + return reinterpret_cast(scale_param); +} + +OpParameter *PopulateGatherParameter(const lite::Primitive *primitive) { + auto gather_attr = primitive->Value()->value_as_Gather(); + GatherParameter *gather_param = new (std::nothrow) GatherParameter(); + if (gather_param == nullptr) { + MS_LOG(ERROR) << "new GatherParameter failed."; + return nullptr; + } + gather_param->op_parameter_.type_ = primitive->Type(); + gather_param->axis_ = gather_attr->axis(); + gather_param->batchDims_ = gather_attr->batchDims(); + return reinterpret_cast(gather_param); +} + +OpParameter *PopulateGatherNdParameter(const lite::Primitive *primitive) { + GatherNdParameter *gather_nd_param = new (std::nothrow) GatherNdParameter(); + if (gather_nd_param == nullptr) { + MS_LOG(ERROR) << "new GatherNDParameter failed."; + return nullptr; + } + gather_nd_param->op_parameter_.type_ = primitive->Type(); + auto gatherNd_attr = primitive->Value()->value_as_GatherNd(); + gather_nd_param->batchDims_ = gatherNd_attr->batchDims(); + return reinterpret_cast(gather_nd_param); +} + +OpParameter *PopulateScatterNDParameter(const lite::Primitive *primitive) { + ScatterNDParameter *scatter_nd_param = new (std::nothrow) ScatterNDParameter(); + if (scatter_nd_param == nullptr) { + MS_LOG(ERROR) << "new ScatterNDParameter failed."; + return nullptr; + } + scatter_nd_param->op_parameter_.type_ = primitive->Type(); + MS_ASSERT(paramter != nullptr); + return reinterpret_cast(scatter_nd_param); +} + +OpParameter *PopulateSliceParameter(const lite::Primitive *primitive) { + SliceParameter *slice_param = new (std::nothrow) SliceParameter(); + if (slice_param == nullptr) { + MS_LOG(ERROR) << "new SliceParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Slice(); + slice_param->op_parameter_.type_ = primitive->Type(); + auto param_begin = param->begin(); + auto param_size = param->size(); + if (param_begin->size() != param_size->size()) { + delete slice_param; + return nullptr; + } + slice_param->param_length_ = static_cast(param_begin->size()); + for (int32_t i = 0; i < slice_param->param_length_; ++i) { + slice_param->begin_[i] = param_begin->Get(i); + slice_param->size_[i] = param_size->Get(i); + } + return reinterpret_cast(slice_param); +} + +OpParameter *PopulateBroadcastToParameter(const lite::Primitive *primitive) { + BroadcastToParameter *broadcast_param = new (std::nothrow) BroadcastToParameter(); + if (broadcast_param == nullptr) { + MS_LOG(ERROR) << "new BroadcastToParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_BroadcastTo(); + broadcast_param->op_parameter_.type_ = primitive->Type(); + auto dst_shape = param->dst_shape(); + broadcast_param->shape_size_ = dst_shape->size(); + for (size_t i = 0; i < broadcast_param->shape_size_; ++i) { + broadcast_param->shape_[i] = dst_shape->Get(i); + } + return reinterpret_cast(broadcast_param); +} + +OpParameter *PopulateReshapeParameter(const lite::Primitive *primitive) { + ReshapeParameter *reshape_param = new (std::nothrow) ReshapeParameter(); + if (reshape_param == nullptr) { + MS_LOG(ERROR) << "new ReshapeParameter failed."; + return nullptr; + } + reshape_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(reshape_param); +} + +OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) { + auto reverse_attr = primitive->Value()->value_as_Reverse(); + ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); + if (reverse_param == nullptr) { + MS_LOG(ERROR) << "new ReverseParameter failed."; + return nullptr; + } + reverse_param->op_parameter_.type_ = primitive->Type(); + auto flatAxis = reverse_attr->axis(); + reverse_param->num_axis_ = flatAxis->size(); + int i = 0; + for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { + reverse_param->axis_[i++] = *iter; + } + return reinterpret_cast(reverse_param); +} + +OpParameter *PopulateUnsqueezeParameter(const lite::Primitive *primitive) { + auto unsqueeze_attr = primitive->Value()->value_as_Unsqueeze(); + UnsqueezeParameter *unsqueeze_param = new (std::nothrow) UnsqueezeParameter(); + if (unsqueeze_param == nullptr) { + MS_LOG(ERROR) << "new ReverseParameter failed."; + return nullptr; + } + unsqueeze_param->op_parameter_.type_ = primitive->Type(); + auto flatAxis = unsqueeze_attr->axis(); + unsqueeze_param->num_dim_ = flatAxis->size(); + int i = 0; + for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { + unsqueeze_param->dims_[i++] = *iter; + } + return reinterpret_cast(unsqueeze_param); +} + +OpParameter *PopulateStackParameter(const lite::Primitive *primitive) { + StackParameter *stack_param = new (std::nothrow) StackParameter(); + if (stack_param == nullptr) { + MS_LOG(ERROR) << "new StackParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Stack(); + stack_param->op_parameter_.type_ = primitive->Type(); + stack_param->axis_ = param->axis(); + return reinterpret_cast(stack_param); +} + +OpParameter *PopulateUnstackParameter(const lite::Primitive *primitive) { + UnstackParameter *unstack_param = new (std::nothrow) UnstackParameter(); + if (unstack_param == nullptr) { + MS_LOG(ERROR) << "new UnstackParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Unstack(); + unstack_param->op_parameter_.type_ = primitive->Type(); + unstack_param->num_ = param->num(); + unstack_param->axis_ = param->axis(); + return reinterpret_cast(unstack_param); +} + +OpParameter *PopulateReverseSequenceParameter(const lite::Primitive *primitive) { + ReverseSequenceParameter *reverse_sequence_param = new (std::nothrow) ReverseSequenceParameter(); + if (reverse_sequence_param == nullptr) { + MS_LOG(ERROR) << "new ReverseSequenceParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_ReverseSequence(); + reverse_sequence_param->op_parameter_.type_ = primitive->Type(); + reverse_sequence_param->seq_axis_ = param->seqAxis(); + reverse_sequence_param->batch_axis_ = param->batchAxis(); + return reinterpret_cast(reverse_sequence_param); +} + +OpParameter *PopulateUniqueParameter(const lite::Primitive *primitive) { + UniqueParameter *unique_param = new (std::nothrow) UniqueParameter(); + if (unique_param == nullptr) { + MS_LOG(ERROR) << "new PopulateUniqueParam failed."; + return nullptr; + } + unique_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(unique_param); +} + +OpParameter *PopulateDepthToSpaceParameter(const lite::Primitive *primitive) { + DepthToSpaceParameter *depth_space_param = new (std::nothrow) DepthToSpaceParameter(); + if (depth_space_param == nullptr) { + MS_LOG(ERROR) << "new DepthToSpaceParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_DepthToSpace(); + depth_space_param->op_parameter_.type_ = primitive->Type(); + depth_space_param->block_size_ = param->blockSize(); + return reinterpret_cast(depth_space_param); +} + +OpParameter *PopulateSpaceToDepthParameter(const lite::Primitive *primitive) { + SpaceToDepthParameter *space_depth_param = new (std::nothrow) SpaceToDepthParameter(); + if (space_depth_param == nullptr) { + MS_LOG(ERROR) << "new SpaceToDepthspace_depth_param failed."; + return nullptr; + } + space_depth_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_DepthToSpace(); + space_depth_param->op_parameter_.type_ = primitive->Type(); + space_depth_param->block_size_ = param->blockSize(); + if (param->format() != schema::Format_NHWC) { + MS_LOG(ERROR) << "Currently only NHWC format is supported."; + return nullptr; + } + return reinterpret_cast(space_depth_param); +} + +OpParameter *PopulateSpaceToBatchParameter(const lite::Primitive *primitive) { + SpaceToBatchParameter *space_batch_param = new (std::nothrow) SpaceToBatchParameter(); + if (space_batch_param == nullptr) { + MS_LOG(ERROR) << "new SpaceToBatchParameter failed."; + return nullptr; + } + space_batch_param->op_parameter_.type_ = primitive->Type(); + space_batch_param->op_parameter_.type_ = primitive->Type(); + auto block_sizes = ((lite::SpaceToBatch *)primitive)->BlockSizes(); + (void)memcpy(space_batch_param->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); + auto paddings = ((lite::SpaceToBatch *)primitive)->Paddings(); + (void)memcpy(space_batch_param->paddings_, (paddings.data()), paddings.size() * sizeof(int)); + auto in_shape = ((lite::SpaceToBatch *)primitive)->InShape(); + (void)memcpy(space_batch_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + auto padded_in_shape = ((lite::SpaceToBatch *)primitive)->PaddedInShape(); + (void)memcpy(space_batch_param->padded_in_shape_, (padded_in_shape.data()), padded_in_shape.size() * sizeof(int)); + return reinterpret_cast(space_batch_param); +} + +OpParameter *PopulateResizeParameter(const lite::Primitive *primitive) { + ResizeParameter *resize_param = new (std::nothrow) ResizeParameter(); + if (resize_param == nullptr) { + MS_LOG(ERROR) << "new ResizeParameter failed."; + return nullptr; + } + resize_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Resize(); + resize_param->method_ = param->method(); + resize_param->new_height_ = param->newHeight(); + resize_param->new_width_ = param->newWidth(); + resize_param->align_corners_ = param->alignCorners(); + resize_param->preserve_aspect_ratio_ = param->preserveAspectRatio(); + return reinterpret_cast(resize_param); +} + +OpParameter *PopulateBatchToSpaceParameter(const lite::Primitive *primitive) { + BatchToSpaceParameter *batch_space_param = new (std::nothrow) BatchToSpaceParameter(); + if (batch_space_param == nullptr) { + MS_LOG(ERROR) << "New BatchToSpaceParameter fail!"; + return nullptr; + } + batch_space_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_BatchToSpace(); + auto block_shape = param->blockShape(); + if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { + MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; + return nullptr; + } + + auto crops = param->crops(); + if (crops->size() != BATCH_TO_SPACE_CROPS_SIZE) { + MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE; + return nullptr; + } + + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + batch_space_param->block_shape_[i] = block_shape->Get(i); + } + + for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { + batch_space_param->crops_[i] = crops->Get(i); + } + return reinterpret_cast(batch_space_param); +} + +OpParameter *PopulateCropParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_Crop(); + auto param_offset = param->offsets(); + if (param_offset->size() > CROP_OFFSET_MAX_SIZE) { + MS_LOG(ERROR) << "crop_param offset size(" << param_offset->size() << ") should <= " << CROP_OFFSET_MAX_SIZE; + return nullptr; + } + CropParameter *crop_param = new (std::nothrow) CropParameter(); + if (crop_param == nullptr) { + MS_LOG(ERROR) << "new CropParameter fail!"; + return nullptr; + } + crop_param->op_parameter_.type_ = primitive->Type(); + crop_param->axis_ = param->axis(); + crop_param->offset_size_ = param_offset->size(); + for (int i = 0; i < param_offset->size(); ++i) { + crop_param->offset_[i] = param_offset->Get(i); + } + return reinterpret_cast(crop_param); +} + +OpParameter *PopulateOneHotParameter(const lite::Primitive *primitive) { + OneHotParameter *one_hot_param = new (std::nothrow) OneHotParameter(); + if (one_hot_param == nullptr) { + MS_LOG(ERROR) << "new OneHotParameter fail!"; + return nullptr; + } + one_hot_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_OneHot(); + if (param == nullptr) { + delete (one_hot_param); + MS_LOG(ERROR) << "get OneHot param nullptr."; + return nullptr; + } + one_hot_param->axis_ = param->axis(); + return reinterpret_cast(one_hot_param); +} + +OpParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { + FlattenParameter *flatten_param = new (std::nothrow) FlattenParameter(); + if (flatten_param == nullptr) { + MS_LOG(ERROR) << "new FlattenParameter fail!"; + return nullptr; + } + flatten_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(flatten_param); +} + +OpParameter *PopulateQuantDTypeCastParameter(const lite::Primitive *primitive) { + QuantDTypeCastParameter *parameter = new (std::nothrow) QuantDTypeCastParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new QuantDTypeCastParameter fail!"; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + auto quant_dtype_cast_param = primitive->Value()->value_as_QuantDTypeCast(); + parameter->srcT = quant_dtype_cast_param->srcT(); + parameter->dstT = quant_dtype_cast_param->dstT(); + return reinterpret_cast(parameter); +} + +OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) { + StridedSliceParameter *strided_slice_param = new (std::nothrow) StridedSliceParameter(); + if (strided_slice_param == nullptr) { + MS_LOG(ERROR) << "new StridedSliceParameter failed."; + return nullptr; + } + strided_slice_param->op_parameter_.type_ = primitive->Type(); + auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); + strided_slice_param->num_axes_ = n_dims; + auto begin = ((lite::StridedSlice *)primitive)->GetBegins(); + (void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int)); + auto end = ((lite::StridedSlice *)primitive)->GetEnds(); + (void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int)); + auto stride = ((lite::StridedSlice *)primitive)->GetStrides(); + (void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); + auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape(); + (void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + return reinterpret_cast(strided_slice_param); +} + +OpParameter *PopulateAddNParameter(const lite::Primitive *primitive) { + auto addn_param = new (std::nothrow) OpParameter(); + if (addn_param == nullptr) { + MS_LOG(ERROR) << "new OpParameter fail!"; + return nullptr; + } + addn_param->type_ = primitive->Type(); + return reinterpret_cast(addn_param); +} + +OpParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) { + PriorBoxParameter *prior_box_param = new (std::nothrow) PriorBoxParameter(); + if (prior_box_param == nullptr) { + MS_LOG(ERROR) << "new PriorBoxParameter failed."; + return nullptr; + } + prior_box_param->op_parameter_.type_ = primitive->Type(); + auto prior_box_attr = primitive->Value()->value_as_PriorBox(); + + if (prior_box_attr->min_sizes()->size() > PRIOR_BOX_MAX_NUM) { + MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " + << prior_box_attr->min_sizes(); + delete (prior_box_param); + return nullptr; + } + prior_box_param->min_sizes_size = prior_box_attr->min_sizes()->size(); + if (prior_box_attr->max_sizes()->size() > PRIOR_BOX_MAX_NUM) { + MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " + << prior_box_attr->max_sizes(); + delete (prior_box_param); + return nullptr; + } + prior_box_param->max_sizes_size = prior_box_attr->max_sizes()->size(); + (void)memcpy(prior_box_param->max_sizes, prior_box_attr->max_sizes()->data(), + prior_box_attr->max_sizes()->size() * sizeof(int32_t)); + (void)memcpy(prior_box_param->min_sizes, prior_box_attr->min_sizes()->data(), + prior_box_attr->min_sizes()->size() * sizeof(int32_t)); + + if (prior_box_attr->aspect_ratios()->size() > PRIOR_BOX_MAX_NUM) { + MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " + << prior_box_attr->aspect_ratios(); + delete (prior_box_param); + return nullptr; + } + prior_box_param->aspect_ratios_size = prior_box_attr->aspect_ratios()->size(); + (void)memcpy(prior_box_param->aspect_ratios, prior_box_attr->aspect_ratios()->data(), + prior_box_attr->aspect_ratios()->size() * sizeof(float)); + if (prior_box_attr->variances()->size() != PRIOR_BOX_VAR_NUM) { + MS_LOG(ERROR) << "PriorBox variances size should be " << PRIOR_BOX_VAR_NUM << ", got " + << prior_box_attr->variances()->size(); + delete (prior_box_param); + return nullptr; + } + (void)memcpy(prior_box_param->variances, prior_box_attr->variances()->data(), PRIOR_BOX_VAR_NUM * sizeof(float)); + prior_box_param->flip = prior_box_attr->flip(); + prior_box_param->clip = prior_box_attr->clip(); + prior_box_param->offset = prior_box_attr->offset(); + prior_box_param->image_size_h = prior_box_attr->image_size_h(); + prior_box_param->image_size_w = prior_box_attr->image_size_w(); + prior_box_param->step_h = prior_box_attr->step_h(); + prior_box_param->step_w = prior_box_attr->step_w(); + return reinterpret_cast(prior_box_param); +} + +OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) { + LstmParameter *lstm_param = new (std::nothrow) LstmParameter(); + if (lstm_param == nullptr) { + MS_LOG(ERROR) << "new LstmParameter fail!"; + return nullptr; + } + lstm_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Lstm(); + if (param == nullptr) { + delete (lstm_param); + MS_LOG(ERROR) << "get Lstm param nullptr."; + return nullptr; + } + lstm_param->bidirectional_ = param->bidirection(); + return reinterpret_cast(lstm_param); +} + +OpParameter *PopulateEmbeddingLookupParameter(const lite::Primitive *primitive) { + EmbeddingLookupParameter *embedding_lookup_parameter = new (std::nothrow) EmbeddingLookupParameter(); + if (embedding_lookup_parameter == nullptr) { + MS_LOG(ERROR) << "new EmbeddingLookupParameter failed"; + return nullptr; + } + embedding_lookup_parameter->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_EmbeddingLookup(); + embedding_lookup_parameter->max_norm_ = param->maxNorm(); + if (embedding_lookup_parameter->max_norm_ < 0) { + MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got " + << embedding_lookup_parameter->max_norm_; + return nullptr; + } + return reinterpret_cast(embedding_lookup_parameter); +} + +PopulateParameterRegistry::PopulateParameterRegistry() { + populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; + populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter; + populate_parameter_funcs_[schema::PrimitiveType_Conv2D] = PopulateConvParameter; + populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter; + populate_parameter_funcs_[schema::PrimitiveType_Mean] = PopulateMeanParameter; + populate_parameter_funcs_[schema::PrimitiveType_Pooling] = PopulatePoolingParameter; + populate_parameter_funcs_[schema::PrimitiveType_DepthwiseConv2D] = PopulateConvDwParameter; + populate_parameter_funcs_[schema::PrimitiveType_DeDepthwiseConv2D] = PopulateDeconvDwParameter; + populate_parameter_funcs_[schema::PrimitiveType_DeConv2D] = PopulateDeconvParameter; + populate_parameter_funcs_[schema::PrimitiveType_FusedBatchNorm] = PopulateFusedBatchNorm; + populate_parameter_funcs_[schema::PrimitiveType_BatchNorm] = PopulateBatchNorm; + populate_parameter_funcs_[schema::PrimitiveType_FullConnection] = PopulateFullconnectionParameter; + populate_parameter_funcs_[schema::PrimitiveType_Power] = PopulatePowerParameter; + populate_parameter_funcs_[schema::PrimitiveType_LocalResponseNormalization] = PopulateLocalResponseNormParameter; + populate_parameter_funcs_[schema::PrimitiveType_Range] = PopulateRangeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Transpose] = PopulateTransposeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Mul] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_FloorDiv] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_FloorMod] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_SquaredDifference] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_BiasAdd] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Eltwise] = PopulateEltwiseParameter; + populate_parameter_funcs_[schema::PrimitiveType_ExpandDims] = PopulateExpandDimsParameter; + populate_parameter_funcs_[schema::PrimitiveType_Abs] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Cos] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Sin] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Exp] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Log] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Sqrt] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Rsqrt] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_LogicalNot] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Floor] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Ceil] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_ArgMax] = PopulateArgMaxParameter; + populate_parameter_funcs_[schema::PrimitiveType_ArgMin] = PopulateArgMinParameter; + populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter; + populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter; + populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter; + populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter; + populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter; + populate_parameter_funcs_[schema::PrimitiveType_Fill] = PopulateFillParameter; + populate_parameter_funcs_[schema::PrimitiveType_Gather] = PopulateGatherParameter; + populate_parameter_funcs_[schema::PrimitiveType_GatherNd] = PopulateGatherNdParameter; + populate_parameter_funcs_[schema::PrimitiveType_Slice] = PopulateSliceParameter; + populate_parameter_funcs_[schema::PrimitiveType_BroadcastTo] = PopulateBroadcastToParameter; + populate_parameter_funcs_[schema::PrimitiveType_Reverse] = PopulateReverseParameter; + populate_parameter_funcs_[schema::PrimitiveType_Stack] = PopulateStackParameter; + populate_parameter_funcs_[schema::PrimitiveType_Unstack] = PopulateUnstackParameter; + populate_parameter_funcs_[schema::PrimitiveType_ReverseSequence] = PopulateReverseSequenceParameter; + populate_parameter_funcs_[schema::PrimitiveType_Unique] = PopulateUniqueParameter; + populate_parameter_funcs_[schema::PrimitiveType_DepthToSpace] = PopulateDepthToSpaceParameter; + populate_parameter_funcs_[schema::PrimitiveType_Nchw2Nhwc] = PopulateNchw2NhwcParameter; + populate_parameter_funcs_[schema::PrimitiveType_Nhwc2Nchw] = PopulateNhwc2NchwParameter; + populate_parameter_funcs_[schema::PrimitiveType_Pad] = PopulatePadParameter; + populate_parameter_funcs_[schema::PrimitiveType_Resize] = PopulateResizeParameter; + populate_parameter_funcs_[schema::PrimitiveType_BatchToSpace] = PopulateBatchToSpaceParameter; + populate_parameter_funcs_[schema::PrimitiveType_SpaceToDepth] = PopulateSpaceToDepthParameter; + populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatch] = PopulateSpaceToBatchParameter; + populate_parameter_funcs_[schema::PrimitiveType_Crop] = PopulateCropParameter; + populate_parameter_funcs_[schema::PrimitiveType_Unsqueeze] = PopulateUnsqueezeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Flatten] = PopulateFlattenParameter; + populate_parameter_funcs_[schema::PrimitiveType_MatMul] = PopulateMatMulParameter; + populate_parameter_funcs_[schema::PrimitiveType_OneHot] = PopulateOneHotParameter; + populate_parameter_funcs_[schema::PrimitiveType_AddN] = PopulateAddNParameter; + populate_parameter_funcs_[schema::PrimitiveType_StridedSlice] = PopulateStridedSliceParameter; + populate_parameter_funcs_[schema::PrimitiveType_ScatterND] = PopulateScatterNDParameter; + populate_parameter_funcs_[schema::PrimitiveType_Squeeze] = PopulateSqueezeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter; + populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter; + populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter; + populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter; + populate_parameter_funcs_[schema::PrimitiveType_EmbeddingLookup] = PopulateEmbeddingLookupParameter; +} + +PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { + static PopulateParameterRegistry populate_parameter_instance; + return &populate_parameter_instance; +} + +PopulateParameterFunc PopulateParameterRegistry::GetParameterFunc(const schema::PrimitiveType &type) { + return populate_parameter_funcs_[type]; +} + +OpParameter *PopulateParameter(const lite::Primitive *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + auto op_type = primitive->Type(); + auto func = PopulateParameterRegistry::GetInstance()->GetParameterFunc(op_type); + if (func == nullptr) { + MS_LOG(ERROR) << "Get nullptr for Op Parameter Func."; + return nullptr; + } + + auto *parameter = func(primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "Get nullptr for Op Parameter."; + return nullptr; + } + return parameter; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/populate_parameter.h b/mindspore/lite/src/populate_parameter.h new file mode 100644 index 0000000000..4bfde44d8d --- /dev/null +++ b/mindspore/lite/src/populate_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ + +#include "schema/model_generated.h" +#include "src/ops/ops.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +namespace mindspore::kernel { +typedef OpParameter *(*PopulateParameterFunc)(const lite::Primitive *); + +class PopulateParameterRegistry { + public: + PopulateParameterRegistry(); + ~PopulateParameterRegistry() = default; + + static PopulateParameterRegistry *GetInstance(); + PopulateParameterFunc GetParameterFunc(const schema::PrimitiveType &type); + + protected: + PopulateParameterFunc populate_parameter_funcs_[schema::PrimitiveType_MAX + 1]; +}; + +OpParameter *PopulateParameter(const lite::Primitive *primitive); +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/allocator.cc b/mindspore/lite/src/runtime/allocator.cc new file mode 100644 index 0000000000..3b047e6690 --- /dev/null +++ b/mindspore/lite/src/runtime/allocator.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/allocator.h" +#include +#include "utils/log_adapter.h" + +namespace mindspore::lite { +std::shared_ptr Allocator::Create() { return std::shared_ptr(new DefaultAllocator()); } + +DefaultAllocator::DefaultAllocator() {} + +DefaultAllocator::~DefaultAllocator() { Clear(); } + +void DefaultAllocator::SetContext(const AllocatorContext &ctx) { + lockFlag = ctx.lockFlag; + shiftFactor = ctx.shiftFactor; +} + +void DefaultAllocator::Lock() { + if (lockFlag) { + lock.lock(); + } +} + +void DefaultAllocator::UnLock() { + if (lockFlag) { + lock.unlock(); + } +} + +void *DefaultAllocator::Malloc(size_t size) { + if (size > MAX_MALLOC_SIZE) { + MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; + return nullptr; + } + Lock(); + auto iter = freeList.lower_bound(size); + if (iter != freeList.end() && (iter->second->size >= size) && (iter->second->size < (size << shiftFactor))) { + auto membuf = iter->second; + freeList.erase(iter); + allocatedList[membuf->buf] = membuf; + UnLock(); + return membuf->buf; + } + + std::unique_ptr membuf(reinterpret_cast(malloc(sizeof(MemBuf) + size))); + if (membuf == nullptr) { + MS_LOG(ERROR) << "malloc membuf return nullptr"; + UnLock(); + return nullptr; + } + membuf->size = size; + membuf->buf = reinterpret_cast(membuf.get()) + sizeof(MemBuf); + auto bufPtr = membuf->buf; + allocatedList[bufPtr] = membuf.release(); + UnLock(); + return bufPtr; +} + +void DefaultAllocator::Free(void *buf) { + if (buf == nullptr) { + return; + } + Lock(); + auto iter = allocatedList.find(buf); + if (iter != allocatedList.end()) { + auto membuf = iter->second; + allocatedList.erase(iter); + freeList.insert(std::make_pair(membuf->size, membuf)); + UnLock(); + return; + } + UnLock(); + free(buf); +} + +size_t DefaultAllocator::GetTotalSize() { + Lock(); + size_t totalSize = 0; + + for (auto it = allocatedList.begin(); it != allocatedList.end(); it++) { + auto membuf = it->second; + totalSize += membuf->size; + } + + for (auto it = freeList.begin(); it != freeList.end(); it++) { + auto membuf = it->second; + totalSize += membuf->size; + } + UnLock(); + return totalSize; +} + +void DefaultAllocator::Clear() { + Lock(); + + for (auto it = allocatedList.begin(); it != allocatedList.end(); it++) { + free(it->second); + } + allocatedList.clear(); + + for (auto it = freeList.begin(); it != freeList.end(); it++) { + free(it->second); + } + freeList.clear(); + UnLock(); +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/runtime/allocator.h b/mindspore/lite/src/runtime/allocator.h new file mode 100644 index 0000000000..9d44bd6f89 --- /dev/null +++ b/mindspore/lite/src/runtime/allocator.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore::lite { +struct AllocatorContext { + int shiftFactor; + bool lockFlag; +}; + +class Allocator { + public: + Allocator() : name("default") {} + virtual ~Allocator() {} + virtual void *Malloc(size_t size) = 0; + virtual void Free(void *ptr) = 0; + virtual void SetContext(const AllocatorContext &ctx) {} + virtual size_t GetTotalSize() { return 0; } + virtual void Clear() {} + static std::shared_ptr Create(); + std::string name; +}; + +class DefaultAllocator : public Allocator { + public: + DefaultAllocator(); + ~DefaultAllocator() override; + void SetContext(const AllocatorContext &ctx) override; + void *Malloc(size_t size) override; + void Free(void *ptr) override; + size_t GetTotalSize() override; + void Clear() override; + + private: + void Lock(); + void UnLock(); + struct MemBuf { + size_t size; + void *buf; + }; + + std::mutex lock; + // buf, membuf> + std::unordered_map allocatedList; + std::multimap freeList; + // 6 is empirical value + int shiftFactor = 6; + bool lockFlag = false; +}; + +#define MAX_MALLOC_SIZE 500 * 1024 * 1024 + +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt new file mode 100644 index 0000000000..2a66748ecf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -0,0 +1,29 @@ +file(GLOB KERNEL_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc + nnacl/*.cc + nnacl/fp32/*.cc + nnacl/int8/*.cc + nnacl/quantization/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc + ) + +if (PLATFORM_ARM64) + # assembly + file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s + nnacl/assembly/arm64/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) +endif() + +if (PLATFORM_ARM32) + # assembly + file(GLOB ASSEMBLY_SRC nnacl/assembly/arm32/*.s + nnacl/assembly/arm32/*.S + ) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) +endif() + +add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC}) +add_subdirectory(nnacl) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc new file mode 100644 index 0000000000..0e356a5315 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/arg_min_max_base.h" +#include "src/runtime/kernel/arm/nnacl/arg_min_max.h" +#include "src/runtime/kernel/arm/fp32/argminmax.h" +#include "src/runtime/kernel/arm/int8/argminmax_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ArgMax; +using mindspore::schema::PrimitiveType_ArgMin; + +namespace mindspore::kernel { +int ArgMinMaxBaseCPUKernel::Init() { + auto param = reinterpret_cast(opParameter); + switch (opParameter->type_) { + case PrimitiveType_ArgMax: + param->get_max_ = true; + break; + case PrimitiveType_ArgMin: + param->get_max_ = false; + break; + default: + MS_LOG(ERROR) << "Unexpected type " << opParameter->type_; + return RET_ERROR; + } + auto in_shape = inputs_.at(0)->shape(); + auto dims_size = in_shape.size(); + int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_; + param->axis_ = axis; + param->dims_size_ = dims_size; + if (param->topk_ <= 0) { + MS_LOG(ERROR) << "Invalid topk " << param->topk_; + return RET_PARAM_INVALID; + } + param->topk_ = MSMIN(param->topk_, in_shape[axis]); + if (param->topk_ > 1) { + if (context_ != nullptr && context_->allocator != nullptr) { + param->arg_elements_ + = reinterpret_cast(context_->allocator->Malloc(sizeof(ArgElement) * in_shape[axis])); + data_from_allocator_ = true; + } else { + param->arg_elements_ = reinterpret_cast(malloc(sizeof(ArgElement) * in_shape[axis])); + } + if (param->arg_elements_ == nullptr) { + MS_LOG(ERROR) << "malloc memroy fail!"; + return RET_ERROR; + } + } + return RET_OK; +} + +int ArgMinMaxBaseCPUKernel::Run() { + auto input = inputs_.at(0); + + auto input_data = reinterpret_cast(inputs_.at(0)->Data()); + auto output_data = outputs_.at(0)->Data(); + + auto shape = input->shape().data(); + auto param = reinterpret_cast(opParameter); + ArgMinMax(input_data, output_data, reinterpret_cast(shape), param); + return RET_OK; +} + +void ArgMinMaxBaseCPUKernel::FreeTmpMemory() { + auto param = reinterpret_cast(opParameter); + if (param->arg_elements_ == nullptr) { + return; + } + if (data_from_allocator_) { + context_->allocator->Free(param->arg_elements_); + } else { + free(param->arg_elements_); + } + param->arg_elements_ = nullptr; +} + +kernel::LiteKernel *CpuArgMinMaxInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + auto kernel = new (std::nothrow) ArgMinMaxInt8CPUKernel(op_parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ArgMinMaxInt8CPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuArgMinMaxFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + auto kernel = new (std::nothrow) ArgMinMaxCPUKernel(op_parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ArgMinMaxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, CpuArgMinMaxFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, CpuArgMinMaxFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ArgMax, CpuArgMinMaxInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ArgMin, CpuArgMinMaxInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h new file mode 100644 index 0000000000..d5ad81f29d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARG_MIN_MAX_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARG_MIN_MAX_BASE_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class ArgMinMaxBaseCPUKernel : public LiteKernel { + public: + ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx), data_from_allocator_(false) { + opParameter->thread_num_ = ctx->thread_num_; + } + + virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); } + + int Init() override; + + int ReSize() override { return 0; } + + int Run() override; + + void FreeTmpMemory(); + + private: + const lite::Context *context_; + bool data_from_allocator_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARG_MIN_MAX_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc new file mode 100644 index 0000000000..e816153411 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/batch_to_space_base.h" +#include "src/runtime/kernel/arm/nnacl/batch_to_space.h" +#include "src/runtime/kernel/arm/fp32/batch_to_space.h" +#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BatchToSpace; + +namespace mindspore::kernel { +int BatchToSpaceBaseCPUKernel::Init() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + BatchToSpaceParameter *param = reinterpret_cast(this->opParameter); + for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { + if (param->crops_[i] != 0) { + no_crop_ = false; + } + } + return RET_OK; +} + +kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace); + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) BatchToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BatchToSpaceInt8CPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace); + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(op_parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BatchToSpaceCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BatchToSpace, CpuBatchToSpaceInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h new file mode 100644 index 0000000000..131b512e76 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_BATCH_TO_SPACE_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_BATCH_TO_SPACE_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/concat_parameter.h" + +namespace mindspore::kernel { +class BatchToSpaceBaseCPUKernel : public LiteKernel { + public: + BatchToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + opParameter->thread_num_ = ctx->thread_num_; + } + + virtual ~BatchToSpaceBaseCPUKernel() = default; + + int Init() override; + + int ReSize() override { return 0; } + + int Run() override { return 0; } + + bool IsNoCrop() const { return no_crop_; } + + private: + bool no_crop_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_BATCH_TO_SPACE_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc new file mode 100644 index 0000000000..ed1e5416b9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/concat_base.h" +#include +#include "src/runtime/kernel/arm/int8/concat_int8.h" +#include "src/runtime/kernel/arm/fp32/concat.h" +#include "src/runtime/kernel/arm/nnacl/fp32/concat.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Concat; + +namespace mindspore::kernel { +int ConcatBaseCPUKernel::Init() { + axis_ = concat_param_->axis_ >= 0 ? concat_param_->axis_ : inputs_.front()->shape().size() + concat_param_->axis_; + return RET_OK; +} + +kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConcatInt32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConcatFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Concat, CpuConcatInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Concat, CpuConcatInt32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Concat, CpuConcatFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.h b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.h new file mode 100644 index 0000000000..9c7f558083 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/concat_parameter.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ConcatBaseCPUKernel : public LiteKernel { + public: + ConcatBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + opParameter->thread_num_ = ctx->thread_num_; + concat_param_ = reinterpret_cast(opParameter); + } + + virtual ~ConcatBaseCPUKernel() = default; + + int Init() override; + + int ReSize() override { return 0; } + + int Run() override { return 0; } + protected: + int thread_count_; + int axis_; + const Context *ctx_; + ConcatParameter *concat_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc new file mode 100644 index 0000000000..43cc5605c6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType; +using mindspore::schema::PadMode; + +namespace mindspore::kernel { +ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { + if (bias_data_ != nullptr) { + free(bias_data_); + bias_data_ = nullptr; + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + nhwc4_input_ = nullptr; + } +} + +void ConvolutionBaseCPUKernel::FreeQuantParam() { + ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_; + if (conv_quant_arg_ == nullptr) { + return; + } + if (conv_quant_arg_->real_multiplier_ != nullptr) { + free(conv_quant_arg_->real_multiplier_); + conv_quant_arg_->real_multiplier_ = nullptr; + } + if (conv_quant_arg_->left_shift_ != nullptr) { + free(conv_quant_arg_->left_shift_); + conv_quant_arg_->left_shift_ = nullptr; + } + if (conv_quant_arg_->right_shift_ != nullptr) { + free(conv_quant_arg_->right_shift_); + conv_quant_arg_->right_shift_ = nullptr; + } + if (conv_quant_arg_->quant_multiplier_ != nullptr) { + free(conv_quant_arg_->quant_multiplier_); + conv_quant_arg_->quant_multiplier_ = nullptr; + } + if (conv_quant_arg_->out_act_min_ != nullptr) { + free(conv_quant_arg_->out_act_min_); + conv_quant_arg_->out_act_min_ = nullptr; + } + if (conv_quant_arg_->out_act_max_ != nullptr) { + free(conv_quant_arg_->out_act_max_); + conv_quant_arg_->out_act_max_ = nullptr; + } + + if (conv_quant_arg_->quant_args_ != nullptr) { + for (int i = 0; i < 3; ++i) { + if (*(conv_quant_arg_->quant_args_ + i) != nullptr) { + free(*(conv_quant_arg_->quant_args_ + i)); + } + } + } +} + +int ConvolutionBaseCPUKernel::Init() { + auto input = this->inputs_.front(); + auto output = this->outputs_.front(); + conv_param_->input_batch_ = input->Batch(); + conv_param_->input_h_ = input->Height(); + conv_param_->input_w_ = input->Width(); + conv_param_->input_channel_ = input->Channel(); + conv_param_->output_batch_ = output->Batch(); + conv_param_->output_h_ = output->Height(); + conv_param_->output_w_ = output->Width(); + conv_param_->output_channel_ = output->Channel(); + conv_param_->thread_num_ = ctx_->thread_num_; + return RET_OK; +} + +int ConvolutionBaseCPUKernel::CheckLayout(lite::tensor::Tensor *input_tensor) { + auto data_type = input_tensor->data_type(); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + convert_func_ = LayoutTransform(data_type, input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetQuantParam() { + ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_; + conv_quant_arg_->quant_args_ = reinterpret_cast(malloc(3 * sizeof(QuantArg *))); + if (conv_quant_arg_->quant_args_ == nullptr) { + MS_LOG(ERROR) << "malloc quant_args_ failed."; + return RET_ERROR; + } + // per-tensor init + for (int j = 0; j < 3; ++j) { + conv_quant_arg_->quant_args_[j] = reinterpret_cast(malloc(sizeof(QuantArg))); + if (conv_quant_arg_->quant_args_[j] == nullptr) { + MS_LOG(ERROR) << "malloc quant_args_ failed."; + return RET_ERROR; + } + } + auto input_tensor = inputs_.at(kInputIndex); + auto weight_tensor = inputs_.at(kWeightIndex); + auto output_tensor = outputs_.at(kOutputIndex); + auto input_quant_arg = input_tensor->GetQuantParams().front(); + auto weight_quant_arg = weight_tensor->GetQuantParams().front(); + auto output_quant_arg = output_tensor->GetQuantParams().front(); + // input + conv_quant_arg_->quant_args_[0][0].zp_ = input_quant_arg.zeroPoint; + conv_quant_arg_->quant_args_[0][0].scale_ = input_quant_arg.scale; + // weight + conv_quant_arg_->quant_args_[1][0].zp_ = weight_quant_arg.zeroPoint; + conv_quant_arg_->quant_args_[1][0].scale_ = weight_quant_arg.scale; + // output + conv_quant_arg_->quant_args_[2][0].zp_ = output_quant_arg.zeroPoint; + conv_quant_arg_->quant_args_[2][0].scale_ = output_quant_arg.scale; + + conv_quant_arg_->real_multiplier_ = reinterpret_cast(malloc(sizeof(double))); + conv_quant_arg_->left_shift_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->right_shift_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->quant_multiplier_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->out_act_min_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->out_act_max_ = reinterpret_cast(malloc(sizeof(int32_t))); + + double real_multiplier = weight_quant_arg.scale * input_quant_arg.scale / output_quant_arg.scale; + conv_quant_arg_->real_multiplier_[0] = real_multiplier; + QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[0], &conv_quant_arg_->left_shift_[0], + &conv_quant_arg_->right_shift_[0]); + + CalculateActivationRangeQuantized( + conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param_->conv_quant_arg_.quant_args_[2][0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0], + &conv_param_->conv_quant_arg_.out_act_max_[0]); + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h new file mode 100644 index 0000000000..89b53dfcad --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ + +#include +#include +#include +#include +#ifdef ENABLE_ARM +#include +#include +#endif +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; +using mindspore::schema::PadMode; +using mindspore::schema::QuantType; + +namespace mindspore::kernel { +class ConvolutionBaseCPUKernel : public LiteKernel { + public: + ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + opParameter->thread_num_ = ctx->thread_num_; + conv_param_ = reinterpret_cast(opParameter); + } + ~ConvolutionBaseCPUKernel() override; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + virtual int CheckLayout(lite::tensor::Tensor *input_tensor); + int SetQuantParam(); + void FreeQuantParam(); + + protected: + int thread_count_; + int tile_num_; + void *bias_data_ = nullptr; + void *nhwc4_input_ = nullptr; + const Context *ctx_; + ConvParameter *conv_param_; + LayoutConvertor convert_func_; +}; +bool CheckSupportFP16(); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc new file mode 100644 index 0000000000..9f66feb208 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/crop_base.h" +#include +#include "src/runtime/kernel/arm/int8/crop_int8.h" +#include "src/runtime/kernel/arm/fp32/crop.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Crop; + +namespace mindspore::kernel { +int CropBaseCPUKernel::Init() { return RET_OK; } + +kernel::LiteKernel *CpuCropInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Crop); + auto *kernel = new (std::nothrow) CropInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new CropCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuCropInt32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Crop); + auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new CropCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuCropFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Crop); + auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new CropCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Crop, CpuCropInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Crop, CpuCropInt32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Crop, CpuCropFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h new file mode 100644 index 0000000000..f4ad763b5f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CROP_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CROP_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/crop_parameter.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class CropBaseCPUKernel : public LiteKernel { + public: + CropBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + opParameter->thread_num_ = ctx->thread_num_; + } + ~CropBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + int thread_count_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CROP_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc new file mode 100644 index 0000000000..b18f7dc9bb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/depth_to_space_base.h" +#include "src/runtime/kernel/arm/nnacl/depth_to_space.h" +#include "src/runtime/kernel/arm/fp32/depth_to_space.h" +#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthToSpace; + +namespace mindspore::kernel { +int DepthToSpaceBaseCPUKernel::Init() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + if (param->block_size_ <= 0) { + MS_LOG(ERROR) << "Input block_size should > 0!"; + return RET_PARAM_INVALID; + } + auto shape_size = inputs_[0]->shape().size(); + if (shape_size != DIMENSION_4D) { + MS_LOG(ERROR) << "Input shape size should be " << DIMENSION_4D; + return RET_PARAM_INVALID; + } + int32_t in_strides[DIMENSION_4D]; + ComputeStrides(const_cast(inputs_[0]->shape().data()), in_strides, shape_size); + param->in_stride_dim0_ = in_strides[0]; + param->in_stride_dim1_ = in_strides[1]; + param->in_stride_dim2_ = in_strides[2]; + int32_t out_strides[DIMENSION_4D]; + ComputeStrides(const_cast(outputs_[0]->shape().data()), out_strides, shape_size); + param->out_stride_dim0_ = out_strides[0]; + param->out_stride_dim1_ = out_strides[1]; + param->out_stride_dim2_ = out_strides[2]; + return RET_OK; +} + +kernel::LiteKernel *CpuDepthToSpaceInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace); + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) DepthToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BatchToSpaceInt8CPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace); + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(op_parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new DepthToSpaceCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthToSpace, CpuDepthToSpaceFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DepthToSpace, CpuDepthToSpaceInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h new file mode 100644 index 0000000000..32974271aa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/depth_to_space.h" + +namespace mindspore::kernel { +class DepthToSpaceBaseCPUKernel : public LiteKernel { + public: + DepthToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + opParameter->thread_num_ = ctx->thread_num_; + } + + virtual ~DepthToSpaceBaseCPUKernel() = default; + + int Init() override; + + int ReSize() override { return 0; } + + int Run() override { return 0; } +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc new file mode 100644 index 0000000000..4f74e94360 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/fullconnection_base.h" +#include "src/runtime/kernel/arm/int8/fullconnection_int8.h" +#include "src/runtime/kernel/arm/fp32/fullconnection.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_FullConnection; + +namespace mindspore::kernel { +int FullconnectionBaseCPUKernel::Init() { + fc_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_FullConnection, CpuFullConnectionInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FullConnection, CpuFullConnectionFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h new file mode 100644 index 0000000000..3d29519370 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_FULLCONNECTION_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_FULLCONNECTION_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FullconnectionBaseCPUKernel : public LiteKernel { + public: + FullconnectionBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + fc_param_ = reinterpret_cast(opParameter); + } + ~FullconnectionBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + MatMulParameter *fc_param_; + int thread_count_; + int thread_stride_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_FULLCONNECTION_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc new file mode 100644 index 0000000000..a97c392bf5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "mindspore/core/utils/log_adapter.h" + +using mindspore::schema::Format; +namespace mindspore::kernel { +LayoutConvertor LayoutTransformFp32(schema::Format src_format, schema::Format dst_format) { + // todo + if (src_format == schema::Format_NHWC && dst_format == schema::Format_NC4HW4) { + return PackNHWCToNC4HW4Fp32; + } else if (src_format == schema::Format_NHWC && dst_format == schema::Format_NHWC4) { + return PackNHWCToNHWC4Fp32; + } else if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC4) { + return PackNC4HW4ToNHWC4Fp32; + } else if (src_format == schema::Format_NCHW && dst_format == schema::Format_NC4HW4) { + return PackNCHWToNC4HW4Fp32; + } else if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC) { + return PackNC4HW4ToNHWCFp32; + } else { + MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(src_format) << " to " + << schema::EnumNameFormat(dst_format); + return nullptr; + } +} + +LayoutConvertor LayoutTransformInt8(schema::Format src_format, schema::Format dst_format) { + // todo + if (src_format == schema::Format_NHWC && dst_format == schema::Format_NHWC4) { + return PackNHWCToNHWC4Int8; + } else { + return nullptr; + } +} + +LayoutConvertor LayoutTransform(TypeId data_type, schema::Format src_format, schema::Format dst_format) { + // todo + switch (data_type) { + case kNumberTypeInt8: + return LayoutTransformInt8(src_format, dst_format); + case kNumberTypeFloat32: + return LayoutTransformFp32(src_format, dst_format); + default: + return nullptr; + } +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h new file mode 100644 index 0000000000..99f37d923f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_LAYOUT_TRANSFORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_LAYOUT_TRANSFORM_H_ + +#ifdef ENABLE_FP16 +#include +#endif +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "ir/dtype/type_id.h" +#include "schema/ops_generated.h" + +namespace mindspore::kernel { +typedef void (*LayoutConvertor)(const void *src, void *dst, int batch, int plane, int channel); +#ifdef ENABLE_FP16 +LayoutConvertor LayoutTransformFp16(schema::Format src_format, schema::Format dst_format); +#endif + +LayoutConvertor LayoutTransformFp32(schema::Format src_format, schema::Format dst_format); + +LayoutConvertor LayoutTransformInt8(schema::Format src_format, schema::Format dst_format); + +LayoutConvertor LayoutTransform(TypeId data_type, schema::Format src_format, schema::Format dst_format); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_LAYOUT_TRANSFORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc new file mode 100644 index 0000000000..eb88cfb4b3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/matmul_base.h" +#include "src/runtime/kernel/arm/fp32/matmul.h" +#include "src/runtime/kernel/arm/int8/matmul_int8.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_MatMul; + +namespace mindspore::kernel { +kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + case kNumberTypeUInt8: { + kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + break; + } + + case kNumberTypeFloat32: { + kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + break; + } + + default: + break; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h new file mode 100644 index 0000000000..d8bc0624a2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATMUL_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATMUL_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class MatmulBaseCPUKernel : public LiteKernel { + public: + MatmulBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + params_ = reinterpret_cast(opParameter); + } + ~MatmulBaseCPUKernel() = default; + + int Init() override { return 0; } + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + MatMulParameter *params_; + int thread_count_; + int thread_stride_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATMUL_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc b/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc new file mode 100644 index 0000000000..a979c70d67 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/matrix.h" +#include "utils/log_adapter.h" + +namespace mindspore::kernel { +Matrix *TransformMatrixGenerator(int m, int k) { + auto matrix = new Matrix; + auto aa = malloc(m * k * sizeof(float)); + matrix->SetData(aa); + matrix->SetNum(m, k); +// matrix->data_ = malloc(m * k * sizeof(float)); +// matrix->m_ = m; +// matrix->k_ = k; +// matrix->row_major_ = true; + return matrix; +} + +void ChooseMatrixG(Matrix *matrix_g, Matrix *matrix_gt) { + int m = matrix_g->GetM(); + int k = matrix_g->GetK(); + auto matrix_g_data = reinterpret_cast(matrix_g->GetData()); + auto matrix_gt_data = reinterpret_cast(matrix_gt->GetData()); + // m represents input unit, only 4 or 8 can be accepted for input unit. + // k represents kernel unit, varies from 2 to 7. + if (m == 4 && k == 2) { + MatrixG4x2(matrix_g_data); + MatrixGT2x4(matrix_gt_data); + } else if (m == 8 && k == 2) { + MatrixG8x2(matrix_g_data); + MatrixGT2x8(matrix_gt_data); + } else if (m == 8 && k == 3) { + MatrixG8x3(matrix_g_data); + MatrixGT3x8(matrix_gt_data); + } else if (m == 8 && k == 4) { + MatrixG8x4(matrix_g_data); + MatrixGT4x8(matrix_gt_data); + } else if (m == 8 && k == 5) { + MatrixG8x5(matrix_g_data); + MatrixGT5x8(matrix_gt_data); + } else if (m == 8 && k == 6) { + MatrixG8x6(matrix_g_data); + MatrixGT6x8(matrix_gt_data); + } else if (m == 8 && k == 7) { + MatrixG8x7(matrix_g_data); + MatrixGT7x8(matrix_gt_data); + } else { + MS_LOG(ERROR) << "Unsupported input unit or kernel unit."; + return; + } +} + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, bool row) { + // row-major implementation + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matrix.h b/mindspore/lite/src/runtime/kernel/arm/base/matrix.h new file mode 100644 index 0000000000..eb0c5e4cb2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/matrix.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_ + +#include +#include +#include "src/runtime/kernel/arm/nnacl/winograd_utils.h" + +namespace mindspore::kernel { +class Matrix { + public: + Matrix() = default; + ~Matrix() { + if (data_ != nullptr) { + free(data_); + } + } + + void SetData(void *data) { this->data_ = data; } + + void *GetData() { return this->data_; } + + void SetNDim(int dim) { this->n_dim_ = dim; } + + int GetNDim() { return this->n_dim_; } + + void SetShape(std::vector shape) { this->shape_ = shape; } + + std::vector GetShape() { return this->shape_; } + + void SetStride(std::vector stride) { this->stride_ = stride; } + + std::vector GetStride() { return this->stride_; } + + void SetNum(int m, int k) { + this->m_ = m; + this->k_ = k; + } + + int GetM() { return this->m_; } + + int GetK() { return this->k_; } + + protected: + void *data_; + std::vector shape_; + std::vector stride_; + int m_; + int k_; + int n_dim_; + bool row_major_; +}; +// struct Matrix { +// void *data_; +// int *shape_; +// int *stride_; +// int m_; +// int k_; +// int n_dim_; +// bool row_major_; +// ~Matrix() { +// if (data_ != nullptr) { +// free(data_); +// } +// if (shape_ != nullptr) { +// free(shape_); +// } +// if (shape_ != nullptr) { +// free(stride_); +// } +// } +//}; + +Matrix *TransformMatrixGenerator(int m, int k); + +// Chinese Remainder Theorem interp: 0.5 +void ChooseMatrixG(Matrix *matrix_g, Matrix *matrix_gt); + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, bool row); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc new file mode 100644 index 0000000000..723657b603 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "src/runtime/kernel/arm/fp32/pad.h" +#include "src/runtime/kernel/arm/int8/pad_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pad; + +namespace mindspore::kernel { + +kernel::LiteKernel *CpuPadInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Pad); + auto *kernel = new (std::nothrow) PadInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PadCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Pad); + auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PadCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Pad, CpuPadInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Pad, CpuPadFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc new file mode 100644 index 0000000000..68dedfe351 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/pooling_base.h" +#include +#include "src/runtime/kernel/arm/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/fp32/pooling.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pooling; + +namespace mindspore::kernel { +int PoolingBaseCPUKernel::SetQuantParam() { + // per tensor init + pooling_quant_arg_ = reinterpret_cast(malloc(2 * sizeof(QuantArg *))); + pooling_quant_arg_[0] = reinterpret_cast(malloc(sizeof(QuantArg))); + pooling_quant_arg_[1] = reinterpret_cast(malloc(sizeof(QuantArg))); + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_arg = input_tensor->GetQuantParams(); + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_arg = out_tensor->GetQuantParams(); + pooling_quant_arg_[0][0].scale_ = in_quant_arg.front().scale; + pooling_quant_arg_[0][0].zp_ = in_quant_arg.front().zeroPoint; + pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale; + pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint; + pooling_param_->quant_args_ = pooling_quant_arg_; + return RET_OK; +} + +void PoolingBaseCPUKernel::FreeQuantParam() { + if (pooling_quant_arg_ != nullptr) { + for (int i = 0; i < 2; ++i) { + if (*(pooling_quant_arg_ + i) != nullptr) { + free(*(pooling_quant_arg_ + i)); + } + } + } +} + +int PoolingBaseCPUKernel::Init() { + MS_ASSERT(inputs_.size() == 1); + MS_ASSERT(outputs_.size() == 1); + pooling_param_->thread_num_ = thread_count_; + MS_ASSERT(this->opParameter != nullptr); + auto in_tensor = this->inputs_.front(); + auto out_tensor = this->outputs_.front(); + MS_ASSERT(in_tensor != nullptr); + MS_ASSERT(out_tensor != nullptr); + pooling_param_->input_batch_ = in_tensor->Batch(); + pooling_param_->input_channel_ = in_tensor->Channel(); + pooling_param_->input_h_ = in_tensor->Height(); + pooling_param_->input_w_ = in_tensor->Width(); + pooling_param_->output_batch_ = out_tensor->Batch(); + pooling_param_->output_channel_ = out_tensor->Channel(); + pooling_param_->output_h_ = out_tensor->Height(); + pooling_param_->output_w_ = out_tensor->Width(); + return RET_OK; +} + +kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Pooling); + auto *kernel = new (std::nothrow) PoolingInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PoolingInt8CPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuPoolingFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Pooling); + auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PoolingCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Pooling, CpuPoolingInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Pooling, CpuPoolingFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h new file mode 100644 index 0000000000..a601db0bb3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POOLING_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POOLING_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" +#include "include/errorcode.h" + +using mindspore::lite::Context; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +namespace mindspore::kernel { +class PoolingBaseCPUKernel : public LiteKernel { + public: + PoolingBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + pooling_param_ = reinterpret_cast(opParameter); + } + ~PoolingBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return RET_OK; } + int Run() override { return RET_OK; } + int SetQuantParam(); + void FreeQuantParam(); + + protected: + int thread_count_; + const Context *ctx_; + PoolingParameter *pooling_param_; + QuantArg **pooling_quant_arg_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POOLING_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc new file mode 100644 index 0000000000..e09ec8889c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/kernel/arm/base/prior_box.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_PriorBox; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 2; +constexpr int kOutputNum = 1; +} // namespace +int PriorBoxCPUKernel::Init() { + if (prior_box_param_ == nullptr) { + MS_LOG(ERROR) << "PriorBoxParameter nullptr"; + return RET_NULL_PTR; + } + MS_ASSERT(inputs_.size() == kInputNum); + MS_ASSERT(outputs_.size() == kOutputNum); + + auto ret = GeneratePriorBox(); + + return ret; +} + +int PriorBoxCPUKernel::GeneratePriorBox() { + const int fmap_w = inputs_[0]->Width(); + const int fmap_h = inputs_[0]->Height(); + + const int image_w = prior_box_param_->image_size_w > 0 ? prior_box_param_->image_size_w : inputs_[1]->Width(); + const int image_h = prior_box_param_->image_size_h > 0 ? prior_box_param_->image_size_h : inputs_[1]->Height(); + + const float step_w = + prior_box_param_->step_w > 0.0f ? prior_box_param_->step_w : static_cast(image_w) / fmap_w; + const float step_h = + prior_box_param_->step_h > 0.0f ? prior_box_param_->step_h : static_cast(image_h) / fmap_h; + + std::vector different_aspect_ratios{1.0f}; + auto aspect_ratios = prior_box_param_->aspect_ratios; + MS_ASSERT(aspect_ratios != nullptr); + for (auto i = 0; i < prior_box_param_->aspect_ratios_size; i++) { + float ratio = aspect_ratios[i]; + bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(), + [&](float v) { return abs(ratio - v) < 1e-6; }); + if (!exist) { + different_aspect_ratios.emplace_back(ratio); + if (prior_box_param_->flip) { + different_aspect_ratios.emplace_back(1.0f / ratio); + } + } + } + + for (int i = 0; i < fmap_h; i++) { + float cy = i + prior_box_param_->offset; + for (int j = 0; j < fmap_w; j++) { + float cx = j + prior_box_param_->offset; + for (auto k = 0; k < prior_box_param_->min_sizes_size; k++) { + float min_size = prior_box_param_->min_sizes[k]; + output_.emplace_back((cx - min_size / step_w * 0.5f) / fmap_w); + output_.emplace_back((cy - min_size / step_h * 0.5f) / fmap_h); + output_.emplace_back((cx + min_size / step_w * 0.5f) / fmap_w); + output_.emplace_back((cy + min_size / step_h * 0.5f) / fmap_h); + + if (prior_box_param_->max_sizes_size > 0) { + float max_size = prior_box_param_->max_sizes[k]; + float prime = sqrt(min_size * max_size); + output_.emplace_back((cx - prime / step_w * 0.5f) / fmap_w); + output_.emplace_back((cy - prime / step_h * 0.5f) / fmap_h); + output_.emplace_back((cx + prime / step_w * 0.5f) / fmap_w); + output_.emplace_back((cy + prime / step_h * 0.5f) / fmap_h); + } + + for (auto v : different_aspect_ratios) { + if (abs(v - 1.0f) < 1e-6) { + continue; + } + float as_square_root = sqrt(v); + float box_w = min_size * as_square_root; + float box_h = min_size / as_square_root; + output_.emplace_back((cx - box_w / step_w * 0.5f) / fmap_w); + output_.emplace_back((cy - box_h / step_h * 0.5f) / fmap_h); + output_.emplace_back((cx + box_w / step_w * 0.5f) / fmap_w); + output_.emplace_back((cy + box_h / step_h * 0.5f) / fmap_h); + } + } + } + } + + // do clip + if (prior_box_param_->clip) { + for (auto item : output_) { + if (item > 1.0f) { + item = 1.0f; + } + if (item < 0.0f) { + item = 0.0f; + } + } + } + + // variance + for (auto i = 0; i < outputs_[0]->Height() / PRIOR_BOX_VAR_NUM; i++) { + for (auto j = 0; j < PRIOR_BOX_VAR_NUM; j++) { + output_.emplace_back(prior_box_param_->variances[j]); + } + } + return RET_OK; +} + +int PriorBoxCPUKernel::PriorBoxImpl(int task_id) { + auto src = output_.data(); + auto output = outputs_.at(0); + if (output == nullptr) { + return RET_NULL_PTR; + } + auto ret = PriorBox(src, reinterpret_cast(output->Data()), output_.size(), task_id, thread_count_); + return ret; +} + +int RunPriorBox(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto prior_box = reinterpret_cast(cdata); + + auto error_code = prior_box->PriorBoxImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Resize Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PriorBoxCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(RunPriorBox, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "PriorBox run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + if (desc.type != schema::PrimitiveType_PriorBox) { + MS_LOG(ERROR) << "PriorBox invalid desc type " << desc.type; + return nullptr; + } + auto *kernel = new (std::nothrow) PriorBoxCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PriorBoxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PriorBox, CpuPriorBoxKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_PriorBox, CpuPriorBoxKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h new file mode 100644 index 0000000000..d4867c620f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PRIOR_BOX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PRIOR_BOX_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" +#include "src/runtime/kernel/arm/nnacl/prior_box.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class PriorBoxCPUKernel : public LiteKernel { + public: + PriorBoxCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + prior_box_param_ = reinterpret_cast(opParameter); + } + ~PriorBoxCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int PriorBoxImpl(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + + private: + std::vector output_; + PriorBoxParameter *prior_box_param_; + int GeneratePriorBox(); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PRIOR_BOX_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc new file mode 100644 index 0000000000..ef5516be2d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/quant_dtype_cast.h" +#include +#include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_QuantDTypeCast; + +namespace mindspore::kernel { +namespace { +constexpr int kQuantDTypeCastInputNum = 1; +constexpr int kQuantDTypeCastOutputNum = 1; +} // namespace + +int QuantDTypeCastCPUKernel::Init() { + if (inputs_.size() != 1) { + MS_LOG(ERROR) << "inputs number should be 1, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + if (outputs_.size() != 1) { + MS_LOG(ERROR) << "outputs number should be 1, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + auto in_tensor = inputs_.front(); + auto out_tensor = outputs_.front(); + auto param = reinterpret_cast(opParameter); + if (param->srcT == kNumberTypeFloat32 && param->dstT == kNumberTypeInt8) { + if (in_tensor->data_type() != kNumberTypeFloat32 || out_tensor->data_type() != kNumberTypeInt8) { + MS_LOG(ERROR) << "param data type and tensor data type do not match."; + return RET_ERROR; + } + inverse_ = false; + } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeFloat32) { + if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat32) { + MS_LOG(ERROR) << "param data type and tensor data type do not match."; + return RET_ERROR; + } + inverse_ = true; + } else { + MS_LOG(ERROR) << "param data type not supported."; + return RET_ERROR; + } + + num_unit_ = static_cast(in_tensor->DataSize()); + thread_n_num_ = MSMIN(thread_num_, num_unit_); + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + + return RET_OK; +} + +int QuantDTypeCastCPUKernel::ReSize() { return RET_OK; } + +int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_n_stride_; + auto quant_arg = inputs_.front()->GetQuantParams().front(); + int ret; + if (inverse_) { + ret = DequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, quant_arg.zeroPoint, + num_unit_thread); + } else { + ret = QuantizeToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, + quant_arg.zeroPoint, num_unit_thread); + } + if (ret != RET_OK) { + MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int QuantDTypeCastRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->QuantDTypeCast(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int QuantDTypeCastCPUKernel::Run() { + if (inverse_) { + int8_ptr_ = reinterpret_cast(inputs_[0]->Data()); + float32_ptr_ = reinterpret_cast(outputs_[0]->Data()); + } else { + float32_ptr_ = reinterpret_cast(inputs_[0]->Data()); + int8_ptr_ = reinterpret_cast(outputs_[0]->Data()); + } + + int ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) QuantDTypeCastCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new QuantDTypeCastCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed! name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h new file mode 100644 index 0000000000..0ea72b2ddc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_QUANTDTYPECAST_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_QUANTDTYPECAST_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class QuantDTypeCastCPUKernel : public LiteKernel { + public: + QuantDTypeCastCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} + ~QuantDTypeCastCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int QuantDTypeCast(int task_id); + + private: + int thread_num_; + int thread_n_num_; + int thread_n_stride_; + int num_unit_; + int8_t *int8_ptr_; + float *float32_ptr_; + bool inverse_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_QUANTDTYPECAST_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc new file mode 100644 index 0000000000..479bfce181 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/reshape_base.h" +#include +#include "src/runtime/kernel/arm/int8/reshape_int8.h" +#include "src/runtime/kernel/arm/fp32/reshape.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reshape; + +namespace mindspore::kernel { +int ReshapeBaseCPUKernel::Init() { + reshape_param_->thread_count_ = thread_count_; + return RET_OK; +} + +kernel::LiteKernel *CpuReshapeInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); + auto *kernel = new (std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuReshapeInt32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); + auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuReshapeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); + auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ReshapeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reshape, CpuReshapeInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Reshape, CpuReshapeInt32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reshape, CpuReshapeFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h new file mode 100644 index 0000000000..2c0ca7ea30 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeBaseCPUKernel : public LiteKernel { + public: + ReshapeBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + reshape_param_ = reinterpret_cast(opParameter); + } + ~ReshapeBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + int thread_count_; + const Context *ctx_; + ReshapeParameter *reshape_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc new file mode 100644 index 0000000000..f5f348e9ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/softmax_base.h" +#include +#include "src/runtime/kernel/arm/int8/softmax_int8.h" +#include "src/runtime/kernel/arm/fp32/softmax.h" +#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::lite::RET_NULL_PTR; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore::kernel { + +int SoftmaxBaseCPUKernel::Init() { + if (softmax_param_ == nullptr) { + MS_LOG(ERROR) << "SoftmaxParameter nullptr"; + return RET_NULL_PTR; + } + + auto input_tensor = inputs_.front(); + auto in_shape = input_tensor->shape(); + auto in_dims = in_shape.size(); + int ele_size = 1; + softmax_param_->n_dim_ = in_dims; + for (size_t i = 0; i < in_dims; i++) { + softmax_param_->input_shape_[i] = in_shape[i]; + ele_size *= in_shape[i]; + } + softmax_param_->element_size_ = ele_size; + return RET_OK; +} + +kernel::LiteKernel *CpuSoftmaxInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); + auto *kernel = new (std::nothrow) SoftmaxInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); + auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SoftMax, CpuSoftmaxInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h new file mode 100644 index 0000000000..4e8873aec0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" + +namespace mindspore::kernel { +class SoftmaxBaseCPUKernel : public LiteKernel { + public: + SoftmaxBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + opParameter->thread_num_ = ctx->thread_num_; + softmax_param_ = reinterpret_cast(opParameter); + } + ~SoftmaxBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + int thread_count_; + const lite::Context *ctx_; + SoftmaxParameter *softmax_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SOFTMAX_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc new file mode 100644 index 0000000000..7e04a8ab2d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/split_base.h" +#include +#include "src/runtime/kernel/arm/int8/split_int8.h" +#include "src/runtime/kernel/arm/fp32/split.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Split; + +namespace mindspore::kernel { +int SplitBaseCPUKernel::Init() { + auto in_tensor = inputs_.front(); + auto input_shape = in_tensor->shape(); + + param->strides_[input_shape.size() - 1] = 1; + for (int i = input_shape.size() - 2; i >= 0; i--) { + param->strides_[i] = param->strides_[i + 1] * input_shape[i + 1]; + } + + param->split_count_ = + param->strides_[0] * input_shape[0] / (input_shape[param->split_dim_] * param->strides_[param->split_dim_]); + param->n_dims_ = input_shape.size(); + + if (param->split_sizes_[0] == 0) { + if (input_shape[param->split_dim_] % param->num_split_ != 0) { + MS_LOG(ERROR) << "Default split size is not usable."; + return RET_ERROR; + } + int split_size = input_shape[param->split_dim_] / param->num_split_; + for (int i = 0; i < param->num_split_; i++) { + param->split_sizes_[i] = split_size; + } + } + + num_unit_ = param->split_count_ * param->num_split_; + thread_n_num_ = MSMIN(thread_count_, num_unit_); + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + return RET_OK; +} + +kernel::LiteKernel *CpuSplitInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Split); + auto *kernel = new (std::nothrow) SplitInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SplitCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuSplitInt32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Split); + auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SplitCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuSplitFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Split); + auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SplitCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Split, CpuSplitInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Split, CpuSplitInt32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Split, CpuSplitFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.h b/mindspore/lite/src/runtime/kernel/arm/base/split_base.h new file mode 100644 index 0000000000..0f90604cfc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class SplitBaseCPUKernel : public LiteKernel { + public: + SplitBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + param = reinterpret_cast(opParameter); + } + ~SplitBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + int thread_count_; + const Context *ctx_; + int thread_n_stride_; + int thread_n_num_; + int num_unit_; + SplitParameter *param; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SPLIT_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc new file mode 100644 index 0000000000..886bec2e68 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/strided_slice.h" +#include +#include "src/runtime/kernel/arm/nnacl/strided_slice.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_StridedSlice; + +namespace mindspore::kernel { + +int StridedSliceCPUKernel::Init() { + auto input = inputs_.at(0); + auto parameter = reinterpret_cast(opParameter); + MS_ASSERT(input); + MS_ASSERT(parameter); + parameter->data_type = input->data_type() == kNumberTypeInt8 ? kDataTypeInt8 : kDataTypeFloat; + return RET_OK; +} + +int StridedSliceCPUKernel::ReSize() { return 0; } + +int StridedSliceCPUKernel::Run() { + auto input = inputs_.at(0); + auto output = outputs_.at(0); + MS_ASSERT(input); + MS_ASSERT(output); + + auto ret = DoStridedSlice(input->Data(), output->Data(), reinterpret_cast(opParameter)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuStridedSliceKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_StridedSlice); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter null pointer dereferencing."; + return nullptr; + } + auto *kernel = new (std::nothrow) StridedSliceCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSlice, CpuStridedSliceKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_StridedSlice, CpuStridedSliceKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h new file mode 100644 index 0000000000..f6d8845ad1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_ + +#include +#include "ir/anf.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class StridedSliceCPUKernel : public LiteKernel { + public: + StridedSliceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} + ~StridedSliceCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + int thread_num_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc new file mode 100644 index 0000000000..0f2e18111c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc @@ -0,0 +1,286 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param) { + auto input_channel = conv_param->input_channel_; + auto output_channel = conv_param->output_channel_; + auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oC8 = UP_DIV(output_channel, C8NUM); + + size_t tmp_size = oC8 * C8NUM * iC4 * C4NUM * kernel_plane * sizeof(float16_t); + auto tmp_addr = reinterpret_cast(malloc(tmp_size)); + memset(tmp_addr, 0, tmp_size); + + PackWeightToC4Fp16(origin_weight, tmp_addr, conv_param); + Conv3x3Fp16FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane); + + free(tmp_addr); +} + +int Convolution3x3FP16CPUKernel::InitWeightBias() { + auto input_channel = conv_param_->input_channel_; + int output_channel = conv_param_->output_channel_; + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oC8 = UP_DIV(output_channel, C8NUM); + // init weight + size_t transformed_size = iC4 * C4NUM * oC8 * C8NUM * 36 * sizeof(float16_t); + transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); + if (transformed_filter_addr_ == nullptr) { + MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; + return RET_ERROR; + } + memset(transformed_filter_addr_, 0, transformed_size); + float *origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t); + fp16_weight_ = reinterpret_cast(malloc(fp16_weight_size)); + if (fp16_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_weight_ failed."; + return RET_ERROR; + } + memset(fp16_weight_, 0, fp16_weight_size); + for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { + fp16_weight_[i] = (float16_t)origin_weight[i]; + } + ProcessFilterFp16(fp16_weight_, transformed_filter_addr_, conv_param_); + + // init bias + size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t); + bias_data_ = malloc(new_bias_size); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, new_bias_size); + auto fp16_bias_data = reinterpret_cast(bias_data_); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + for (int i = 0; i < output_channel; ++i) { + fp16_bias_data[i] = (float16_t)ori_bias_addr[i]; + } + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::InitTmpBuffer() { + int tile_num = 16; + int k_plane = 36; + int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM); + + /*=============================tile_buffer_============================*/ + size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC4 * C4NUM * sizeof(float16_t); + tile_buffer_ = reinterpret_cast(malloc(tile_buffer_size)); + if (tile_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tile_buffer_ failed."; + return RET_ERROR; + } + memset(tile_buffer_, 0, tile_buffer_size); + + /*=============================block_unit_buffer_============================*/ + size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float16_t); + block_unit_buffer_ = reinterpret_cast(malloc(block_unit_buffer_size)); + if (block_unit_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; + return RET_ERROR; + } + memset(block_unit_buffer_, 0, block_unit_buffer_size); + + /*=============================tmp_dst_buffer_============================*/ + size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float16_t); + tmp_dst_buffer_ = reinterpret_cast(malloc(tmp_dst_buffer_size)); + if (tmp_dst_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; + return RET_ERROR; + } + memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); + + /*=============================tmp_out_============================*/ + int new_out_plane = UP_DIV(conv_param_->output_h_, C4NUM) * UP_DIV(conv_param_->output_w_, C4NUM) * C4NUM * C4NUM; + size_t tmp_out_size = + oC8 * C8NUM * conv_param_->output_batch_ * new_out_plane * sizeof(float16_t); + tmp_out_ = reinterpret_cast(malloc(tmp_out_size)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + memset(tmp_out_, 0, tmp_out_size); + + /*=============================fp16_input_============================*/ + size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ * + conv_param_->input_w_ * sizeof(float16_t); + fp16_input_ = reinterpret_cast(malloc(fp16_input_size)); + if (fp16_input_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_input_ failed."; + return RET_ERROR; + } + memset(fp16_input_, 0, fp16_input_size); + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = + iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + /*=============================fp16_out_============================*/ + size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ * + conv_param_->output_w_ * sizeof(float16_t); + fp16_out_ = reinterpret_cast(malloc(fp16_output_size)); + if (fp16_out_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_out_ failed."; + return RET_ERROR; + } + return RET_OK; +} + +void Convolution3x3FP16CPUKernel::ConfigInputOutput() { + auto input_tensor = inputs_.at(kInputIndex); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + convert_func_ = LayoutTransformFp16(input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return; + } +} + +int Convolution3x3FP16CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::ReSize() { + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + if (fp16_out_ != nullptr) { + free(fp16_out_); + } + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::RunImpl(int task_id) { + Conv3x3Fp16(reinterpret_cast(nhwc4_input_), transformed_filter_addr_, + reinterpret_cast(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, + tmp_out_, task_id, conv_param_); + return RET_OK; +} + +int Convolution3x3Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution3x3 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::Run() { + // cast fp32 input data to fp16 + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = reinterpret_cast(input_tensor->Data()); + for (int i = 0; i < input_tensor->ElementsNum(); ++i) { + fp16_input_[i] = (float16_t)ori_input_data[i]; + } + + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(reinterpret_cast(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(Convolution3x3Fp16Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv3x3 fp16 error error_code[" << error_code << "]"; + return RET_ERROR; + } + + // cast fp16 out to fp32 data + auto out_tensor = outputs_.at(kOutputIndex); + auto output_addr = reinterpret_cast(out_tensor->Data()); + for (int j = 0; j < out_tensor->ElementsNum(); ++j) { + output_addr[j] = static_cast(fp16_out_[j]); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h new file mode 100644 index 0000000000..e6f8862e9b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h" + +namespace mindspore::kernel { +class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution3x3FP16CPUKernel() override { + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_weight_ != nullptr) { + free(fp16_weight_); + } + if (fp16_out_ != nullptr) { + free(fp16_out_); + } + if (transformed_filter_addr_ != nullptr) { + free(transformed_filter_addr_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float16_t *fp16_input_; + float16_t *fp16_weight_; + float16_t *fp16_out_; + float16_t *transformed_filter_addr_; + float16_t *tile_buffer_; + float16_t *block_unit_buffer_; + float16_t *tmp_dst_buffer_; + float16_t *tmp_out_; +}; +void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc new file mode 100644 index 0000000000..ec027deacd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int ConvolutionDepthwiseFp16CPUKernel::InitBuffer() { + // malloc pack input buffer + int C8 = UP_DIV(conv_param_->input_channel_, C8NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM * C8; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float16_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_input_, 0, pack_input_size * sizeof(float16_t)); + + // malloc pack output buffer + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM * C8; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float16_t))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() { + // init weight: o, h, w, i; o == group, i == 1 + int OC8 = UP_DIV(conv_param_->output_channel_, C8NUM); + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + int pack_weight_size = C8NUM * OC8 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); + PackNCHWFp32ToNC8HW8Fp16(origin_weight, packed_weight_, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // init bias + bias_data_ = reinterpret_cast(malloc(C8NUM * OC8 * sizeof(float16_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C8NUM * OC8 * sizeof(float16_t)); + auto bias_fp16 = reinterpret_cast(bias_data_); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + for (int i = 0; i < conv_param_->output_channel_; i++) { + bias_fp16[i] = (float16_t)ori_bias[i]; + } + } + + conv_param_->thread_num_ = MSMIN(thread_count_, OC8); + return RET_OK; +} + +int ConvolutionDepthwiseFp16CPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding_ window param + sliding_ = new SlidingWindowParam; + InitSlidingParam(sliding_, conv_param_, C8NUM); + + auto ret = InitWeightBias(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp16 InitWeightBias failed."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp16 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseFp16CPUKernel::ReSize() { + free(packed_input_); + free(packed_output_); + + ConvolutionBaseCPUKernel::Init(); + InitSlidingParam(sliding_, conv_param_, C8NUM); + + auto ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp16 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseFp16CPUKernel::Execute(int task_id) { + ConvDwC8Fp16(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding_, task_id); + return RET_OK; +} + +int ConvDwFp16Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw_fp16 = reinterpret_cast(cdata); + auto ret = conv_dw_fp16->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwiseFp16Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseFp16CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + // pack input: to nhwc8 + PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_, + conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); + + auto ret = LiteBackendParallelLaunch(ConvDwFp16Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDwFp16Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + PackNHWC8Fp16ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + + return RET_OK; +} + +kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DepthwiseConv2D, CpuConvDwFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h new file mode 100644 index 0000000000..5b81404e3e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DEPTHWISE_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DEPTHWISE_FP16_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.h" + +namespace mindspore::kernel { +class ConvolutionDepthwiseFp16CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwiseFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionDepthwiseFp16CPUKernel() override { + delete sliding_; + free(packed_weight_); + free(packed_input_); + free(packed_output_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + int InitBuffer(); + int InitWeightBias(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding_; + float16_t *packed_weight_; + float16_t *packed_input_; + float16_t *packed_output_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DEPTHWISE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc new file mode 100644 index 0000000000..92277efdf6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -0,0 +1,287 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/convolution_fp16.h" +#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +int ConvolutionFP16CPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int oc8 = UP_DIV(out_channel, C8NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; + + // init weight + float *origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); + fp16_weight_ = reinterpret_cast(malloc(fp16_weight_size)); + if (fp16_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_weight_ failed."; + return RET_ERROR; + } + for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { + fp16_weight_[i] = (float16_t)origin_weight[i]; + } + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); + PackWeightFp16(fp16_weight_, conv_param_, packed_weight_); + + // init bias + bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc8 * C8NUM * sizeof(float16_t)); + auto fp16_bias_data = reinterpret_cast(bias_data_); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + for (int i = 0; i < out_channel; ++i) { + fp16_bias_data[i] = (float16_t)ori_bias[i]; + } + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionFP16CPUKernel::InitTmpBuffer() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_batch = conv_param_->input_batch_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int channel_block = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + + // malloc packed_inputs + int cal_num = 16; + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, cal_num); + int unit_size = kernel_plane * channel_block * C4NUM; + int packed_input_size = output_tile_count * cal_num * unit_size; + + /*=============================packed_input_============================*/ + packed_input_ = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_input_ failed."; + return RET_ERROR; + } + memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float16_t)); + + /*=============================fp16_input_============================*/ + size_t fp16_input_size = + in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); + fp16_input_ = reinterpret_cast(malloc(fp16_input_size)); + if (fp16_input_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_input_ failed."; + return RET_ERROR; + } + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * + conv_param_->input_w_ * sizeof(float16_t); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + /*=============================tmp_output_block_============================*/ + tmp_output_block_ = reinterpret_cast(malloc(cal_num * out_channel * sizeof(float16_t))); + if (tmp_output_block_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; + return RET_ERROR; + } + + /*=============================fp16_out_============================*/ + size_t fp16_output_size = + out_channel * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float16_t); + fp16_out_ = reinterpret_cast(malloc(fp16_output_size)); + if (fp16_out_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_out_ failed."; + return RET_ERROR; + } + return RET_OK; +} + +void ConvolutionFP16CPUKernel::ConfigInputOutput() { + auto input_tensor = inputs_.at(kInputIndex); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + convert_func_ = LayoutTransformFp16(input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return; + } +} + +int ConvolutionFP16CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + return RET_OK; +} + +int ConvolutionFP16CPUKernel::ReSize() { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_out_ != nullptr) { + free(fp16_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionFP16CPUKernel::RunImpl(int task_id) { + ConvFp16(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_output_block_, fp16_out_, task_id, conv_param_); + return RET_OK; +} + +int ConvolutionFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ConvolutionFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionFP16CPUKernel::Run() { + // cast fp32 input data to fp16 + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = reinterpret_cast(input_tensor->Data()); + for (int i = 0; i < input_tensor->ElementsNum(); ++i) { + fp16_input_[i] = (float16_t)ori_input_data[i]; + } + + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(reinterpret_cast(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionFp16Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; + return RET_ERROR; + } + + // cast fp16 out to fp32 data + auto out_tensor = outputs_.at(kOutputIndex); + auto output_addr = reinterpret_cast(out_tensor->Data()); + for (int j = 0; j < out_tensor->ElementsNum(); ++j) { + output_addr[j] = static_cast(fp16_out_[j]); + } + return RET_OK; +} + +kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + kernel::LiteKernel *kernel; + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); + } else { + kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create conv fp16 kernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h new file mode 100644 index 0000000000..8be48714da --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" + +namespace mindspore::kernel { +class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionFP16CPUKernel() override { + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_weight_ != nullptr) { + free(fp16_weight_); + } + if (fp16_out_ != nullptr) { + free(fp16_out_); + } + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float16_t *fp16_input_; + float16_t *fp16_weight_; + float16_t *fp16_out_; + float16_t *packed_input_; + float16_t *packed_weight_; + float16_t *tmp_output_block_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc new file mode 100644 index 0000000000..4a0097b75c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; + +namespace mindspore::kernel { +int DeconvolutionDepthwiseFp16CPUKernel::InitSlideParam() { + conv_param_->input_batch_ = outputs_.front()->shape().at(kNHWC_N); + conv_param_->input_h_ = outputs_.front()->shape().at(kNHWC_H); + conv_param_->input_w_ = outputs_.front()->shape().at(kNHWC_W); + conv_param_->input_channel_ = outputs_.front()->shape().at(kNHWC_C); + conv_param_->output_batch_ = inputs_.front()->shape().at(kNHWC_N); + conv_param_->output_h_ = inputs_.front()->shape().at(kNHWC_H); + conv_param_->output_w_ = inputs_.front()->shape().at(kNHWC_W); + conv_param_->output_channel_ = inputs_.front()->shape().at(kNHWC_C); + + // init sliding_ window param + InitSlidingParam(sliding_, conv_param_, C8NUM); + return RET_OK; +} + +int DeconvolutionDepthwiseFp16CPUKernel::InitBuffer() { + // malloc pack input buffer + int C8 = UP_DIV(conv_param_->input_channel_, C8NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM * C8; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float16_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_input_, 0, pack_input_size * sizeof(float16_t)); + + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM * C8; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float16_t))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_output_, 0, pack_output_size * sizeof(float16_t)); + return RET_OK; +} + +int DeconvolutionDepthwiseFp16CPUKernel::InitWeightBias() { + // init weight: o, h, w, i; o == group, i == 1 + int OC8 = UP_DIV(conv_param_->output_channel_, C8NUM); + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + int pack_weight_size = C8NUM * OC8 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); + PackNCHWFp32ToNC8HW8Fp16(origin_weight, packed_weight_, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // init bias + bias_data_ = reinterpret_cast(malloc(C8NUM * OC8 * sizeof(float16_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C8NUM * OC8 * sizeof(float16_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + for (int i = 0; i < conv_param_->output_channel_; i++) { + reinterpret_cast(bias_data_)[i] = (float16_t)ori_bias[i]; + } + } + + conv_param_->thread_num_ = MSMIN(thread_count_, OC8); + return RET_OK; +} + +int DeconvolutionDepthwiseFp16CPUKernel::Init() { + sliding_ = new SlidingWindowParam; + InitSlideParam(); + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitWeightBias(); + if (ret != 0) { + MS_LOG(ERROR) << "Deconvolution depthwise fp16 InitWeightBias failed."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Deconvolution depthwise fp16 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseFp16CPUKernel::ReSize() { + free(packed_input_); + free(packed_output_); + + InitSlideParam(); + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp16 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseFp16CPUKernel::Execute(int task_id) { + DeconvDwC8Fp16(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding_, task_id); + return RET_OK; +} + +int DeconvDwFp16Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv_dw_fp16 = reinterpret_cast(cdata); + auto ret = deconv_dw_fp16->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvolutionDepthwiseFp16Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseFp16CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + // pack input: to nhwc8 + PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_, + conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); + + auto ret = LiteBackendParallelLaunch(DeconvDwFp16Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvDwFp16Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + PackNHWC8Fp16ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + return RET_OK; +} + +kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); + auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h new file mode 100644 index 0000000000..0e3a31682c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_DEPTHWISE_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_DEPTHWISE_FP16_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.h" + +namespace mindspore::kernel { +class DeconvolutionDepthwiseFp16CPUKernel : public ConvolutionBaseCPUKernel { + public: + DeconvolutionDepthwiseFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeconvolutionDepthwiseFp16CPUKernel() override { + delete sliding_; + free(packed_weight_); + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitBuffer(); + int InitWeightBias(); + int InitSlideParam(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding_; + float16_t *packed_weight_; + float16_t *packed_input_; + float16_t *packed_output_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_DEPTHWISE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/layout_transform_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/layout_transform_fp16.cc new file mode 100644 index 0000000000..419c10b06e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/layout_transform_fp16.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "schema/ops_generated.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore::kernel { +LayoutConvertor LayoutTransformFp16(schema::Format src_format, schema::Format dst_format) { + if (src_format == schema::Format_NHWC && dst_format == schema::Format_NC4HW4) { + return PackNHWCToNC4HW4Fp16; + } else if (src_format == schema::Format_NHWC && dst_format == schema::Format_NHWC4) { + return PackNHWCToNHWC4Fp16; + } else if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC4) { + return PackNC4HW4ToNHWC4Fp16; + } else if (src_format == schema::Format_NCHW && dst_format == schema::Format_NC4HW4) { + return PackNCHWToNC4HW4Fp16; + } else if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC) { + return PackNC4HW4ToNHWCFp16; + } else { + MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(src_format) << " to " + << schema::EnumNameFormat(dst_format); + return nullptr; + } +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/layout_transform_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/layout_transform_fp16.h new file mode 100644 index 0000000000..37e11da649 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/layout_transform_fp16.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LAYOUT_TRANSFORM_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LAYOUT_TRANSFORM_FP16_H_ + +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/ops_generated.h" + +namespace mindspore::kernel { +LayoutConvertor LayoutTransformFp16(schema::Format src_format, schema::Format dst_format); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_LAYOUT_TRANSFORM_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc new file mode 100644 index 0000000000..e4fda400a4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/activation.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_HSWISH; +using mindspore::schema::ActivationType_LEAKY_RELU; +using mindspore::schema::ActivationType_RELU; +using mindspore::schema::ActivationType_RELU6; +using mindspore::schema::PrimitiveType_Activation; + +namespace mindspore::kernel { +int ActivationCPUKernel::Init() { return RET_OK; } + +int ActivationCPUKernel::ReSize() { return RET_OK; } + +int ActivationCPUKernel::DoActivation(int task_id) { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto length = inputs_.at(0)->ElementsNum(); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + auto error_code = RET_OK; + + if (type_ == schema::ActivationType_RELU) { + error_code = Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_RELU6) { + error_code = Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_LEAKY_RELU) { + error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); + } else if (type_ == schema::ActivationType_SIGMOID) { + error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_TANH) { + error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_HSWISH) { + error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else { + MS_LOG(ERROR) << "Activation type error"; + return RET_ERROR; + } + if (error_code != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int ActivationRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activation_kernel = reinterpret_cast(cdata); + auto error_code = activation_kernel->DoActivation(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ActivationRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ActivationCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ActivationRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Activation); + auto *kernel = new (std::nothrow) ActivationCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Activation, CpuActivationFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h new file mode 100644 index 0000000000..2e21629d79 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/activation.h" + +namespace mindspore::kernel { +class ActivationCPUKernel : public LiteKernel { + public: + ActivationCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(param, inputs, outputs), thread_count_(ctx->thread_num_) { + type_ = (reinterpret_cast(param))->type_; + alpha_ = (reinterpret_cast(param))->alpha_; + } + ~ActivationCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoActivation(int task_id); + + private: + int thread_count_; + int type_; + float alpha_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc new file mode 100644 index 0000000000..279832aca2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/activation_grad.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationGradType_HSWISH; +using mindspore::schema::ActivationGradType_LEAKY_RELU; +using mindspore::schema::ActivationGradType_RELU; +using mindspore::schema::ActivationGradType_RELU6; +using mindspore::schema::PrimitiveType_ActivationGrad; + +namespace mindspore::kernel { +int ActivationGradCPUKernel::Init() { + outputs_[0]->set_shape(inputs_[0]->shape()); + return RET_OK; +} + +int ActivationGradCPUKernel::ReSize() { return RET_OK; } + +int ActivationGradCPUKernel::DoActivation(int task_id) { + auto yt_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto input_addr = reinterpret_cast(inputs_.at(1)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto length = inputs_.at(0)->ElementsNum(); + + auto error_code = RET_OK; + + if (type_ == schema::ActivationGradType_RELU) { + error_code = ReluGrad(yt_addr, input_addr, length, output_addr); + } else if (type_ == schema::ActivationGradType_RELU6) { + error_code = Relu6Grad(yt_addr, input_addr, length, output_addr); + } else if (type_ == schema::ActivationGradType_LEAKY_RELU) { + error_code = LReluGrad(yt_addr, input_addr, length, output_addr, alpha_); + } else if (type_ == schema::ActivationGradType_SIGMOID) { + error_code = SigmoidGrad(yt_addr, input_addr, length, output_addr); + } else if (type_ == schema::ActivationGradType_TANH) { + error_code = TanhGrad(yt_addr, input_addr, length, output_addr); + } else if (type_ == schema::ActivationGradType_HSWISH) { + error_code = HSwishGrad(yt_addr, input_addr, length, output_addr); + } else if (type_ == schema::ActivationGradType_HSIGMOID) { + error_code = HSigmoidGrad(yt_addr, input_addr, length, output_addr); + } else { + MS_LOG(ERROR) << "Activation type error"; + return RET_ERROR; + } + if (error_code != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int ActivationGradRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activationGrad_kernel = reinterpret_cast(cdata); + auto error_code = activationGrad_kernel->DoActivation(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ActivationGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ActivationGradCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ActivationGradRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_ActivationGrad); + auto *kernel = new (std::nothrow) ActivationGradCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ + << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ActivationGrad, CpuActivationGradFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h new file mode 100644 index 0000000000..c3de590123 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +#include "src/runtime/kernel/arm/nnacl/activation_grad.h" + +namespace mindspore::kernel { +class ActivationGradCPUKernel : public LiteKernel { + public: + explicit ActivationGradCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(param, inputs, outputs) { + ActivationGradParameter *param_act_grad = reinterpret_cast(param); + type_ = param_act_grad->type_; + alpha_ = param_act_grad->alpha_; + } + ~ActivationGradCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoActivation(int task_id); + + private: + int thread_count_; + int type_; + float alpha_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc new file mode 100644 index 0000000000..226843c805 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/addn.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/arithmetic.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_AddN; + +namespace mindspore::kernel { +namespace { +int AddNLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { + if (cdata == nullptr) { + MS_LOG(ERROR) << "Input cdata is nullptr!"; + return RET_NULL_PTR; + } + auto kernel = reinterpret_cast(cdata); + return kernel->AddNParallelRun(thread_id); +} +} + +int AddNCPUKernel::Init() { + elements_num_ = inputs_[0]->ElementsNum(); + return RET_OK; +} + +int AddNCPUKernel::ReSize() { return RET_OK; } + +int AddNCPUKernel::AddNParallelRun(int thread_id) { + int count_per_thread = UP_DIV(elements_num_, opParameter->thread_num_); + int count = MSMIN(count_per_thread, elements_num_ - thread_id * count_per_thread); + auto stride = count_per_thread * thread_id; + auto ret = ElementAdd(in1_addr_ + stride, in2_addr_ + stride, out_addr_ + stride, count); + if (ret != NNACL_OK) { + MS_LOG(ERROR) << "ElementAdd fail! ret: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +int AddNCPUKernel::Run() { + auto input0_data = reinterpret_cast(inputs_[0]->Data()); + auto input1_data = reinterpret_cast(inputs_[1]->Data()); + auto output_data = reinterpret_cast(outputs_[0]->Data()); + if (elements_num_ < opParameter->thread_num_) { + ElementAdd(input0_data, input1_data, output_data, elements_num_); + for (int i = 2; i < inputs_.size(); ++i) { + ElementAdd(reinterpret_cast(inputs_[i]->Data()), output_data, output_data, elements_num_); + } + return RET_OK; + } + in1_addr_ = input0_data; + in2_addr_ = input1_data; + out_addr_ = output_data; + int ret = LiteBackendParallelLaunch(AddNLaunch, this, opParameter->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "addn launch fail!ret: " << ret; + return RET_ERROR; + } + for (size_t i = 2; i < inputs_.size(); ++i) { + in1_addr_ = reinterpret_cast(inputs_[i]->Data()); + in2_addr_ = output_data; + ret = LiteBackendParallelLaunch(AddNLaunch, this, opParameter->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "addn launch fail!ret: " << ret << ", input index: " << i; + return RET_ERROR; + } + } + return RET_OK; +} + +kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "Input context is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_AddN); + op_parameter->thread_num_ = ctx->thread_num_; + auto *kernel = new (std::nothrow) AddNCPUKernel(op_parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new AddNCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed! name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddN, CpuAddNFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h new file mode 100644 index 0000000000..43d27fad02 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDN_H_ + +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + + +namespace mindspore::kernel { +class AddNCPUKernel : public LiteKernel { + public: + AddNCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~AddNCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int AddNParallelRun(int thread_id); + private: + float *in1_addr_; + float *in2_addr_; + float *out_addr_; + size_t elements_num_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDN_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc new file mode 100644 index 0000000000..a89eb715fd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/argminmax.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/arg_min_max.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ArgMax; +using mindspore::schema::PrimitiveType_ArgMin; + +namespace mindspore::kernel { +int ArgMinMaxCPUKernel::Init() { + auto ret = ArgMinMaxBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + auto param = reinterpret_cast(opParameter); + param->data_type_ = kNumberTypeFloat32; + return RET_OK; +} + +int ArgMinMaxCPUKernel::Run() { + auto ret = ArgMinMaxBaseCPUKernel::Run(); + ArgMinMaxBaseCPUKernel::FreeTmpMemory(); + return ret; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h new file mode 100644 index 0000000000..c6c4fb9be7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ + +#include +#include "src/runtime/kernel/arm/base/arg_min_max_base.h" + +namespace mindspore::kernel { +class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel { + public: + ArgMinMaxCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~ArgMinMaxCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc new file mode 100644 index 0000000000..3d3d55f894 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/int8/add_int8.h" +#include "src/runtime/kernel/arm/int8/mul_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Eltwise; + +namespace mindspore::kernel { + +ArithmeticCPUKernel::~ArithmeticCPUKernel() { + if (tile_data0_ != nullptr) { + delete[](tile_data0_); + tile_data0_ = nullptr; + } + if (tile_data1_ != nullptr) { + delete[](tile_data1_); + tile_data1_ = nullptr; + } +} +int ArithmeticCPUKernel::Init() { + auto element_num = outputs_[0]->ElementsNum(); + + tile_data0_ = new float[element_num]; + tile_data1_ = new float[element_num]; + + return RET_OK; +} + +int ArithmeticCPUKernel::ReSize() { return RET_OK; } + +int ArithmeticCPUKernel::DoArithmetic(int task_id) { + auto input0_data = reinterpret_cast(inputs_[0]->Data()); + auto input1_data1 = reinterpret_cast(inputs_[1]->Data()); + auto output_data = reinterpret_cast(outputs_[0]->Data()); + auto element_num = outputs_[0]->ElementsNum(); + + MS_ASSERT(thread_count_ != 0); + int stride = UP_DIV(element_num, thread_count_); + int count = MSMIN(stride, element_num - stride * task_id); + + if (arithmetic_run_ == nullptr) { + MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; + return RET_ERROR; + } + + int error_code = RET_OK; + if (arithmeticParameter_->broadcasting_) { + error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id, + output_data + stride * task_id, count); + + } else { + error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id, + output_data + stride * task_id, count); + } + if (error_code != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto arithmetic_kernel = reinterpret_cast(cdata); + auto error_code = arithmetic_kernel->DoArithmetic(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticCPUKernel::Run() { + if (arithmeticParameter_->broadcasting_) { + auto input_data0 = reinterpret_cast(inputs_[0]->Data()); + auto input_data1 = reinterpret_cast(inputs_[1]->Data()); + TileDimensions(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_); + } + int error_code = LiteBackendParallelLaunch(ArithmeticsRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + auto kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator) + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h new file mode 100644 index 0000000000..9fae6696c8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" +#include "schema/model_generated.h" + +using mindspore::schema::PrimitiveType_Add; +using mindspore::schema::PrimitiveType_Div; +using mindspore::schema::PrimitiveType_Equal; +using mindspore::schema::PrimitiveType_FloorDiv; +using mindspore::schema::PrimitiveType_FloorMod; +using mindspore::schema::PrimitiveType_Greater; +using mindspore::schema::PrimitiveType_GreaterEqual; +using mindspore::schema::PrimitiveType_Less; +using mindspore::schema::PrimitiveType_LessEqual; +using mindspore::schema::PrimitiveType_LogicalAnd; +using mindspore::schema::PrimitiveType_LogicalOr; +using mindspore::schema::PrimitiveType_Maximum; +using mindspore::schema::PrimitiveType_Minimum; +using mindspore::schema::PrimitiveType_Mul; +using mindspore::schema::PrimitiveType_NotEqual; +using mindspore::schema::PrimitiveType_SquaredDifference; +using mindspore::schema::PrimitiveType_Sub; + +namespace mindspore::kernel { +class ArithmeticCPUKernel : public LiteKernel { + typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size); + typedef int (*ArithmeticBroadcastRun)(float *input0, float *input1, float *tile_input0, float *tile_input1, + float *output, int element_size, ArithmeticParameter *param); + + public: + ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) { + arithmeticParameter_ = reinterpret_cast(parameter); + switch (parameter->type_) { + case PrimitiveType_Mul: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementMulRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementMulRelu6; + break; + default: + arithmetic_run_ = ElementMul; + break; + } + break; + case PrimitiveType_Add: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementAddRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementAddRelu6; + break; + default: + arithmetic_run_ = ElementAdd; + break; + } + break; + case PrimitiveType_Sub: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementSubRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementSubRelu6; + break; + default: + arithmetic_run_ = ElementSub; + break; + } + break; + case PrimitiveType_Div: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementDivRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementDivRelu6; + break; + default: + arithmetic_run_ = ElementDiv; + break; + } + break; + case PrimitiveType_LogicalAnd: + arithmetic_run_ = ElementLogicalAnd; + arithmetic_broadcast_run_ = BroadcastLogicalAnd; + break; + case PrimitiveType_LogicalOr: + arithmetic_run_ = ElementLogicalOr; + arithmetic_broadcast_run_ = BroadcastLogicalOr; + break; + case PrimitiveType_Maximum: + arithmetic_run_ = ElementMaximum; + arithmetic_broadcast_run_ = BroadcastMaximum; + break; + case PrimitiveType_Minimum: + arithmetic_run_ = ElementMinimum; + arithmetic_broadcast_run_ = BroadcastMinimum; + break; + case PrimitiveType_FloorDiv: + arithmetic_run_ = ElementFloorDiv; + arithmetic_broadcast_run_ = BroadcastFloorDiv; + break; + case PrimitiveType_FloorMod: + arithmetic_run_ = ElementFloorMod; + arithmetic_broadcast_run_ = BroadcastFloorMod; + break; + case PrimitiveType_Equal: + arithmetic_run_ = ElementEqual; + arithmetic_broadcast_run_ = BroadcastEqual; + break; + case PrimitiveType_NotEqual: + arithmetic_run_ = ElementNotEqual; + arithmetic_broadcast_run_ = BroadcastNotEqual; + break; + case PrimitiveType_Less: + arithmetic_run_ = ElementLess; + arithmetic_broadcast_run_ = BroadcastLess; + break; + case PrimitiveType_LessEqual: + arithmetic_run_ = ElementLessEqual; + arithmetic_broadcast_run_ = BroadcastLessEqual; + break; + case PrimitiveType_Greater: + arithmetic_run_ = ElementGreater; + arithmetic_broadcast_run_ = BroadcastGreater; + break; + case PrimitiveType_GreaterEqual: + arithmetic_run_ = ElementGreaterEqual; + arithmetic_broadcast_run_ = BroadcastGreaterEqual; + break; + case PrimitiveType_SquaredDifference: + arithmetic_run_ = ElementSquaredDifference; + arithmetic_broadcast_run_ = BroadcastSquaredDifference; + break; + default: + MS_LOG(ERROR) << "Error Operator type " << parameter->type_; + arithmetic_run_ = nullptr; + arithmetic_broadcast_run_ = nullptr; + break; + } + } + ~ArithmeticCPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmetic(int task_id); + + private: + int thread_count_; + float *tile_data0_ = nullptr; + float *tile_data1_ = nullptr; + ArithmeticParameter *arithmeticParameter_; + ArithmeticRun arithmetic_run_; + ArithmeticBroadcastRun arithmetic_broadcast_run_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc new file mode 100644 index 0000000000..b3f9075bef --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc @@ -0,0 +1,285 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" +#include "src/runtime/kernel/arm/fp32/arithmetic_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +namespace { +constexpr int kArithGradOpInputNum = 3; +constexpr int kArithGradOpOutputNum = 2; +} // namespace + +int ArithmeticGradCPUKernel::Init() { + auto ret = InferShape(); + return ret; +} + +int ArithmeticGradCPUKernel::InferShape() { + if (inputs_.size() != kArithGradOpInputNum) { + MS_LOG(ERROR) << "The number of input must be " << kArithGradOpInputNum; + return RET_ERROR; + } + if (outputs_.size() != kArithGradOpOutputNum) { + MS_LOG(ERROR) << "The number of output must be " << kArithGradOpOutputNum; + return RET_ERROR; + } + auto dy = inputs_[0]; + auto x1 = inputs_[1]; + auto x2 = inputs_[2]; + auto dx1 = outputs_[0]; + auto dx2 = outputs_[1]; + + MS_ASSERT(dy != nullptr); + MS_ASSERT(x1 != nullptr); + MS_ASSERT(x2 != nullptr); + MS_ASSERT(dx1 != nullptr); + MS_ASSERT(dx2 != nullptr); + + auto inShape0 = x1->shape(); + auto inShape1 = x2->shape(); + auto outShape = dy->shape(); + + if ((type() == PrimitiveType_AddGrad) || (type() == PrimitiveType_SubGrad)) { + arithmeticParameter_->ndim_ = outShape.size(); + auto fillDimNum0 = outShape.size() - inShape0.size(); + auto fillDimNum1 = outShape.size() - inShape1.size(); + int j0 = 0; + int j1 = 0; + for (unsigned int i = 0; i < outShape.size(); i++) { + arithmeticParameter_->in_shape0_[i] = (i < fillDimNum0) ? 1 : inShape0[j0++]; + arithmeticParameter_->in_shape1_[i] = (i < fillDimNum1) ? 1 : inShape1[j1++]; + arithmeticParameter_->out_shape_[i] = outShape[i]; + } + } else { + // if (inShape0.size() < inShape1.size()) + if (dx1->ElementsNum() < dx2->ElementsNum()) { + arithmeticParameter_->ndim_ = inShape1.size(); + if (type() == PrimitiveType_MulGrad) + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul2L; + else if (type() == PrimitiveType_DivGrad) + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv2L; + + auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch! + int j = 0; + for (unsigned int i = 0; i < inShape1.size(); i++) { + if (i < fillDimNum) { + arithmeticParameter_->in_shape1_[i] = 1; + } else { + arithmeticParameter_->in_shape1_[i] = inShape0[j++]; + } + arithmeticParameter_->in_shape0_[i] = inShape1[i]; + arithmeticParameter_->out_shape_[i] = outShape[i]; + } + } else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size()) + arithmeticParameter_->ndim_ = inShape0.size(); + if (type() == PrimitiveType_MulGrad) + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul1L; + else if (type() == PrimitiveType_DivGrad) + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv1L; + arithmeticParameter_->broadcasting_ = true; + arithmeticParameter_->ndim_ = inShape0.size(); + int j = 0; + auto fillDimNum = inShape0.size() - inShape1.size(); + for (unsigned int i = 0; i < inShape0.size(); i++) { + if (i < fillDimNum) { + arithmeticParameter_->in_shape1_[i] = 1; + } else { + arithmeticParameter_->in_shape1_[i] = inShape1[j++]; + } + arithmeticParameter_->in_shape0_[i] = inShape0[i]; + arithmeticParameter_->out_shape_[i] = outShape[i]; + } + } else { + arithmeticParameter_->broadcasting_ = false; + for (unsigned int i = 0; i < inShape0.size(); i++) { + arithmeticParameter_->in_shape1_[i] = inShape1[i]; + arithmeticParameter_->in_shape0_[i] = inShape0[i]; + arithmeticParameter_->out_shape_[i] = outShape[i]; + } + } + tile_data0 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; + MS_ASSERT(tile_data0 != nullptr); + tile_data1 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; + MS_ASSERT(tile_data1 != nullptr); + if (type() == PrimitiveType_DivGrad) { + tile_data2 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; + MS_ASSERT(tile_data2 != nullptr); + } + } + + dx1->set_shape(x1->shape()); + dx2->set_shape(x2->shape()); + // outTensor->set_shape(out_shape); + dx1->set_data_type(dy->data_type()); + dx2->set_data_type(dy->data_type()); + return RET_OK; +} + +void ArithmeticGradCPUKernel::ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + if (dx1_size == dy_size) + memcpy(dx1, dy, dy_size * sizeof(float)); + else + ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_, + arithmeticParameter_->ndim_); + if (dx2_size == dy_size) + memcpy(dx2, dy, dy_size * sizeof(float)); + else + ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx2, arithmeticParameter_->in_shape1_, + arithmeticParameter_->ndim_); +} + +void ArithmeticGradCPUKernel::ArithmeticGradSub(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + if (dx1_size == dy_size) + memcpy(dx1, dy, dy_size * sizeof(float)); + else + ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_, + arithmeticParameter_->ndim_); + if (dx2_size == dy_size) { + for (int i = 0; i < dx2_size; i++) { + dx2[i] = -dy[i]; + } + } else { + ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx2, arithmeticParameter_->in_shape1_, + arithmeticParameter_->ndim_); + for (int i = 0; i < dx2_size; i++) { + dx2[i] = -dx2[i]; + } + } +} + +void ArithmeticGradCPUKernel::ArithmeticGradMul(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + auto x1_data = reinterpret_cast(inputs_[1]->Data()); + auto x2_data = reinterpret_cast(inputs_[2]->Data()); + ElementMul(dy, x1_data, dx2, dy_size); + ElementMul(dy, x2_data, dx1, dy_size); +} + +void ArithmeticGradCPUKernel::ArithmeticGradMul1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + auto x1_data = reinterpret_cast(inputs_[1]->Data()); + auto x2_data = reinterpret_cast(inputs_[2]->Data()); + ElementMul(dy, x1_data, tile_data0, dy_size); + ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_, + arithmeticParameter_->ndim_); + + BroadcastMul(dy, x2_data, tile_data0, tile_data1, dx1, dy_size, arithmeticParameter_); // broadcast directly to dx1 +} + +void ArithmeticGradCPUKernel::ArithmeticGradMul2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + auto x1_data = reinterpret_cast(inputs_[1]->Data()); + auto x2_data = reinterpret_cast(inputs_[2]->Data()); + ElementMul(dy, x2_data, tile_data0, dy_size); + ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx1, arithmeticParameter_->in_shape1_, + arithmeticParameter_->ndim_); + + BroadcastMul(dy, x1_data, tile_data0, tile_data1, dx2, dy_size, arithmeticParameter_); // broadcast directly to dx2 +} + +void ArithmeticGradCPUKernel::ArithmeticGradDiv(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + auto x1 = reinterpret_cast(inputs_[1]->Data()); + auto x2 = reinterpret_cast(inputs_[2]->Data()); + ElementDiv(dy, x2, dx1, dy_size); + ElementMulAndDivNegSquare(dy, x1, x2, dx2, dy_size); +} + +void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + auto x1_data = reinterpret_cast(inputs_[1]->Data()); + auto x2_data = reinterpret_cast(inputs_[2]->Data()); + + ElementMul(x2_data, x2_data, dx2, dx2_size); + ElementMul(x1_data, dy, dx1, dy_size); // use dx1 buffer + BroadcastDiv(dx1, dx2, tile_data0, tile_data1, tile_data2, dy_size, + arithmeticParameter_); // broadcast directly to dx1 + ReduceSumByAxes(tile_data2, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_, + arithmeticParameter_->ndim_); + for (int i = 0; i < dx2_size; i++) dx2[i] = -dx2[i]; + // ReduceNegSumPrefix(tile_data2, dy_size, dx2, dx2_size); //then reduce into dx2 + + // broadcasting x2 + BroadcastDiv(dy, x2_data, tile_data0, tile_data1, dx1, dy_size, arithmeticParameter_); // broadcast directly to dx1 +} + +void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, + int dx2_size) { + auto x1_data = reinterpret_cast(inputs_[1]->Data()); + auto x2_data = reinterpret_cast(inputs_[2]->Data()); + + // dx1 = dy/x2 + ElementDiv(dy, x2_data, tile_data0, dy_size); // first multiply into temp + ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx1, arithmeticParameter_->in_shape1_, + arithmeticParameter_->ndim_); + + // dx2 = -dy*x1/(x2*x2) + BroadcastMul(dy, x1_data, tile_data0, tile_data1, tile_data2, dy_size, arithmeticParameter_); // broadcast numerator + ElementDivNegSquare(tile_data2, x2_data, dx2, dy_size); +} + +int ArithmeticGradCPUKernel::ReSize() { return RET_OK; } + +int ArithmeticGradCPUKernel::Run() { + auto dy = reinterpret_cast(inputs_[0]->Data()); + // auto input1_data1 = reinterpret_cast(inputs_[1]->Data()); + auto dx1 = reinterpret_cast(outputs_[0]->Data()); + auto dx2 = reinterpret_cast(outputs_[1]->Data()); + + size_t dy_size = inputs_.at(0)->ElementsNum(); + size_t dx1_size = outputs_.at(0)->ElementsNum(); + size_t dx2_size = outputs_[1]->ElementsNum(); + (this->*arithmetic_grad_)(dy, dy_size, dx1, dx1_size, dx2, dx2_size); + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_EXCEPTION_IF_NULL(opParameter); + if (opParameter == nullptr) { + return nullptr; + } + auto *kernel = new (std::nothrow) ArithmeticGradCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulGrad, CpuArithmeticGradFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddGrad, CpuArithmeticGradFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SubGrad, CpuArithmeticGradFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DivGrad, CpuArithmeticGradFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h new file mode 100644 index 0000000000..cea0c0e659 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" +#include "schema/model_generated.h" +#include "ir/anf.h" + +using mindspore::schema::PrimitiveType_AddGrad; +using mindspore::schema::PrimitiveType_DivGrad; +using mindspore::schema::PrimitiveType_MulGrad; +using mindspore::schema::PrimitiveType_SubGrad; + +namespace mindspore::kernel { + +class ArithmeticGradCPUKernel; + +class ArithmeticGradCPUKernel : public LiteKernel { + typedef void (ArithmeticGradCPUKernel::*ArithmeticGradOperation)(float *, int, float *, int, float *, int); + + public: + explicit ArithmeticGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) { + switch (type()) { + case PrimitiveType_MulGrad: + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul; // this will be adjusted in InferShape + break; + case PrimitiveType_AddGrad: + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradAdd; + break; + case PrimitiveType_SubGrad: + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradSub; + break; + case PrimitiveType_DivGrad: + arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv; // this will be adjusted in InferShape + break; + default: + MS_LOG(ERROR) << "Error Operator type " << parameter->type_; + break; + } + arithmeticParameter_ = reinterpret_cast(parameter); + } + ~ArithmeticGradCPUKernel() override { + if (tile_data0) delete[] tile_data0; + if (tile_data1) delete[] tile_data1; + if (tile_data2) delete[] tile_data2; + } + void InitKernel(const CNodePtr &kernel_node); + + int Init() override; + int InferShape(); + int ReSize() override; + int Run() override; + + private: + void ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradSub(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradMul(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradMul1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradMul2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradDiv(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradDiv1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + void ArithmeticGradDiv2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); + ArithmeticParameter *arithmeticParameter_; + ArithmeticGradOperation arithmetic_grad_; + float *tile_data0; + float *tile_data1; + float *tile_data2; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc new file mode 100644 index 0000000000..d65ef0d1c9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/arithmetic_self.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int ArithmeticSelfCPUKernel::Init() { + int ret = ReSize(); + return ret; +} + +int ArithmeticSelfCPUKernel::ReSize() { + data_size_ = inputs_[0]->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int ArithmeticSelfRuns(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoArithmeticSelf(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ArithmeticSelfCPUKernel::DoArithmeticSelf(int task_id) { + int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + if (arithmeticSelf_run_) { + auto ret = arithmeticSelf_run_(in_ptr_ + offset, out_ptr_ + offset, size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run failed, illegal input! "; + return ret; + } + } else { + MS_LOG(ERROR) << "Run function is null! "; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticSelfCPUKernel::Run() { + auto input_tensor = inputs_.at(0); + auto out_tensor = outputs_.at(0); + in_ptr_ = reinterpret_cast(input_tensor->Data()); + out_ptr_ = reinterpret_cast(out_tensor->Data()); + int ret = LiteBackendParallelLaunch(ArithmeticSelfRuns, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticSelfFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Creator failed, opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) ArithmeticSelfCPUKernel(opParameter, inputs, outputs, ctx); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Abs, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cos, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Exp, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Log, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Square, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h new file mode 100644 index 0000000000..bcc56820db --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h" +#include "schema/model_generated.h" +#include "include/context.h" + + +using mindspore::lite::Context; +using mindspore::schema::PrimitiveType_Abs; +using mindspore::schema::PrimitiveType_Cos; +using mindspore::schema::PrimitiveType_Exp; +using mindspore::schema::PrimitiveType_Floor; +using mindspore::schema::PrimitiveType_Log; +using mindspore::schema::PrimitiveType_LogicalNot; +using mindspore::schema::PrimitiveType_Rsqrt; +using mindspore::schema::PrimitiveType_Sin; +using mindspore::schema::PrimitiveType_Sqrt; +using mindspore::schema::PrimitiveType_Square; +using mindspore::schema::PrimitiveType_Ceil; + +namespace mindspore::kernel { +class ArithmeticSelfCPUKernel : public LiteKernel { + typedef int (*ArithmeticSelfRun)(float *input, float *output, int element_size); + + public: + explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + switch (parameter->type_) { + case PrimitiveType_Abs: + arithmeticSelf_run_ = ElementAbs; + break; + case PrimitiveType_Cos: + arithmeticSelf_run_ = ElementCos; + break; + case PrimitiveType_Exp: + arithmeticSelf_run_ = ElementExp; + break; + case PrimitiveType_Log: + arithmeticSelf_run_ = ElementLog; + break; + case PrimitiveType_Square: + arithmeticSelf_run_ = ElementSquare; + break; + case PrimitiveType_Sqrt: + arithmeticSelf_run_ = ElementSqrt; + break; + case PrimitiveType_Rsqrt: + arithmeticSelf_run_ = ElementRsqrt; + break; + case PrimitiveType_Sin: + arithmeticSelf_run_ = ElementSin; + break; + case PrimitiveType_LogicalNot: + arithmeticSelf_run_ = ElementLogicalNot; + break; + case PrimitiveType_Floor: + arithmeticSelf_run_ = ElementFloor; + break; + case PrimitiveType_Ceil: + arithmeticSelf_run_ = ElementCeil; + break; + default: + break; + } + arithmeticSelfParameter_ = reinterpret_cast(parameter); + } + ~ArithmeticSelfCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmeticSelf(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + size_t data_size_; + ArithmeticSelfParameter *arithmeticSelfParameter_; + ArithmeticSelfRun arithmeticSelf_run_; + const Context *ctx_; + float *in_ptr_; + float *out_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc new file mode 100644 index 0000000000..a24cae52de --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/batch_to_space.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/batch_to_space.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int BatchToSpaceCPUKernel::Init() { + return BatchToSpaceBaseCPUKernel::Init(); +} + +int BatchToSpaceCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + const float *input_data = reinterpret_cast(input->Data()); + float *output_data = reinterpret_cast(output->Data()); + auto in_shape = input->shape(); + auto out_shape = output->shape(); + BatchToSpaceParameter *param = reinterpret_cast(this->opParameter); + + if (IsNoCrop()) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, + sizeof(float)); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_, + sizeof(float)); + } + + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h new file mode 100644 index 0000000000..2933bee478 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ +#include +#include "src/runtime/kernel/arm/base/batch_to_space_base.h" + +namespace mindspore::kernel { +class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel { + public: + BatchToSpaceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~BatchToSpaceCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc new file mode 100644 index 0000000000..063fd4f6aa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/batchnorm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BatchNorm; + +namespace mindspore::kernel { +int BatchnormCPUKernel::Init() { return RET_OK; } + +int BatchnormCPUKernel::ReSize() { return RET_OK; } + +int BatchnormCPUKernel::DoExecute(int tid) { + int count = MSMIN(thread_unit_, units_ - tid * thread_unit_); + if (count <= 0) { + return RET_OK; + } + int offset = tid * thread_unit_ * channel_; + BatchNorm(in_addr_ + offset, mean_addr_, var_addr_, count, channel_, batchnorm_param_->epsilon_, out_addr_ + offset); + return RET_OK; +} + +int BatchNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoExecute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "BatchnormRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int BatchnormCPUKernel::Run() { + in_addr_ = reinterpret_cast(inputs_.at(0)->Data()); + mean_addr_ = reinterpret_cast(inputs_.at(1)->Data()); + var_addr_ = reinterpret_cast(inputs_.at(2)->Data()); + out_addr_ = reinterpret_cast(outputs_.at(0)->Data()); + auto input_shapes = inputs_[0]->shape(); + channel_ = input_shapes[3]; + units_ = 1; + for (int i = 0; i < 3; i++) { + units_ *= input_shapes[i]; + } + thread_count_ = MSMIN(thread_count_, units_); + thread_unit_ = UP_DIV(units_, thread_count_); + int ret = LiteBackendParallelLaunch(BatchNormRun, this, thread_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuBatchnormKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_BatchNorm); + auto *kernel = new (std::nothrow) BatchnormCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BatchNormCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchNorm, CpuBatchnormKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h new file mode 100644 index 0000000000..e29aa20a2f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCHNORM_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/fp32/batchnorm.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class BatchnormCPUKernel : public LiteKernel { + public: + BatchnormCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + batchnorm_param_ = reinterpret_cast(parameter); + } + ~BatchnormCPUKernel() override { delete batchnorm_param_; } + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tid); + + private: + int thread_count_; + int thread_unit_; + int units_; + int channel_; + float *in_addr_; + float *mean_addr_; + float *var_addr_; + float *out_addr_; + const Context *ctx_; + BatchNormParameter *batchnorm_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCHNORM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc new file mode 100644 index 0000000000..9464b342a2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/bias.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/int8/bias_add_int8.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BiasAdd; + +namespace mindspore::kernel { +int BiasCPUKernel::ReSize() { return RET_OK; } + +int BiasCPUKernel::Run() { + auto in = reinterpret_cast(inputs_.at(0)->Data()); + auto bias = reinterpret_cast(inputs_.at(1)->Data()); + auto out = reinterpret_cast(outputs_.at(0)->Data()); + size_t data_size = inputs_.at(0)->ElementsNum(); + auto tile_in = new float[data_size]; + auto tile_bias = new float[data_size]; + BroadcastAdd(in, bias, tile_in, tile_bias, out, data_size, bias_param_); + delete[] tile_in; + delete[] tile_bias; + return RET_OK; +} + +int BiasCPUKernel::Init() { + auto dims = inputs_[0]->shape(); + MS_ASSERT(dims.size() <= 5); + bias_param_->ndim_ = dims.size(); + for (int i = 0; i < bias_param_->ndim_; i++) { + bias_param_->in_shape0_[i] = dims[i]; + bias_param_->in_shape1_[i] = 1; + bias_param_->out_shape_[i] = dims[i]; + } + bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1]; + return RET_OK; +} + +kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_BiasAdd); + auto kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, CpuBiasFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h new file mode 100644 index 0000000000..a4d88378fd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +namespace mindspore::kernel { +class BiasCPUKernel : public LiteKernel { + public: + BiasCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + bias_param_ = reinterpret_cast(parameter); + } + ~BiasCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + ArithmeticParameter *bias_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc new file mode 100644 index 0000000000..e57fe298ab --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32/bias_grad.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_BiasGrad; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int BiasGradCPUKernel::InferShape() { + if (1 != this->inputs_.size()) { + MS_LOG(ERROR) << "BiasGrad should have one input"; + return RET_ERROR; + } + if (1 != this->outputs_.size()) { + MS_LOG(ERROR) << "BiasGrad should have one output"; + return RET_ERROR; + } + auto *in0 = inputs_.front(); + auto *out = outputs_.front(); + MS_ASSERT(in0 != nullptr); + MS_ASSERT(out != nullptr); + auto inshape = in0->shape(); + int ndim = inshape.size(); + for (int i = 0; i < ndim - 1; i++) { + inshape[i] = 1; + } + out->set_shape(inshape); + out->set_data_type(in0->data_type()); + return RET_OK; +} + +int BiasGradCPUKernel::Init() { + MS_ASSERT(InferShape() == RET_OK); + + auto dims = inputs_[0]->shape(); + bias_param->ndim_ = dims.size(); + for (unsigned int i = 0; i < bias_param->ndim_; i++) { + bias_param->in_shape0_[i] = dims[i]; + bias_param->out_shape_[i] = 1; // 1 dimension for N,H,W, + } + bias_param->out_shape_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1]; + for (int i = bias_param->ndim_; i < 4; i++) { + bias_param->in_shape0_[i] = 0; + bias_param->out_shape_[i] = 0; + } + return RET_OK; +} + + +int BiasGradCPUKernel::ReSize() { return 0; } + +int BiasGradCPUKernel::Run() { + auto in = reinterpret_cast(inputs_.at(0)->Data()); + auto out = reinterpret_cast(outputs_.at(0)->Data()); + // size_t data_size = inputs_.at(0)->ElementsNum(); + + size_t nhw_size = 1; + size_t channels = bias_param->in_shape0_[bias_param->ndim_ - 1]; // C in NHWC + for (unsigned int i = 0; i < bias_param->ndim_ - 1; i++) nhw_size *= bias_param->in_shape0_[i]; + + size_t total_size = channels * nhw_size; + for (size_t c = 0; c < channels; ++c) { + out[c] = 0; + for (size_t offset = 0; offset < total_size; offset += channels) { + out[c] += in[offset + c]; + } + } + + return RET_OK; +} + + +kernel::LiteKernel *CpuBiasGradFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_BiasGrad); + auto *kernel = new (std::nothrow) BiasGradCPUKernel(reinterpret_cast(opParameter), inputs, outputs); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasGrad, CpuBiasGradFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h new file mode 100644 index 0000000000..797abfd162 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +namespace mindspore::kernel { +class BiasGradCPUKernel : public LiteKernel { + public: + explicit BiasGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + bias_param = reinterpret_cast(parameter); + } + ~BiasGradCPUKernel() override = default; + + int Init() override; + int InferShape(); + int ReSize() override; + int Run() override; + + private: + ArithmeticParameter *bias_param; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc new file mode 100644 index 0000000000..2a07167058 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "src/runtime/kernel/arm/fp32/bngrad_input.h" +#include "src/runtime//kernel/arm/nnacl/batch_norm.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +// using mindspore::lite::REG_OP; +using mindspore::schema::PrimitiveType_BNGradInput; + +namespace mindspore::kernel { +int BNGradInputCPUKernel::Init() { + auto bn_param = reinterpret_cast(opParameter); + workspace_size = 5 * bn_param->channels; + workspace = new float[workspace_size]; + + if (2 != this->inputs_.size()) { + MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; + return RET_ERROR; + } + if (1 != this->outputs_.size()) { + MS_LOG(ERROR) << "Conv2d Grad should has one output"; + return RET_ERROR; + } + auto *input_tensor = inputs_.at(0); + // auto *weight_tensor = inputs_.at(1); + auto *out_tensor = outputs_.at(0); + auto in_shape = input_tensor->shape(); + out_tensor->set_shape(in_shape); + out_tensor->set_data_type(input_tensor->data_type()); + return RET_OK; +} + +int BNGradInputCPUKernel::ReSize() { return RET_OK; } + +/* +according to https://wiseodd.github.io/techblog/2016/07/04/batchnorm +*/ + +int BNGradInputCPUKernel::Run() { + // std::cout << "run succ" << std::endl; + auto *input_x = inputs_.at(0); + auto *input_yt = inputs_.at(1); + auto *input_scale = inputs_.at(2); + auto *output_grad = outputs_.at(0); + // Tensor *bias = input[5]; + auto bn_param = reinterpret_cast(opParameter); + int batch = bn_param->batch; + int channels = bn_param->channels; + int spatial = bn_param->spatial; + float eps = bn_param->eps; + std::fill(workspace, workspace + workspace_size, 0.f); + + float *mean = workspace; + float *variance = mean + channels; + float *mean_delta = variance + channels; + float *variance_delta = mean_delta + channels; + float *mean_add_delta = variance_delta + channels; + + float *x = reinterpret_cast(input_x->Data()); + float *yt = reinterpret_cast(input_yt->Data()); + float *scale = reinterpret_cast(input_scale->Data()); + float *out = reinterpret_cast(output_grad->Data()); + + std::copy(yt, yt + batch * channels * spatial, out); + meanVar(x, batch, spatial, channels, mean, variance); + scaleBias(scale, batch, channels, spatial, out); + meanDelta(out, spatial, channels, eps, variance, mean_delta); + varianceDelta(x, out, mean, variance, batch, channels, spatial, eps, variance_delta); + meanAdd(x, mean, variance_delta, batch, channels, spatial, mean_add_delta, mean_delta); + NormalizeDelta(x, mean, variance, mean_delta, variance_delta, batch, channels, eps, spatial, out); + return RET_OK; +} + +kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_BNGradInput); + // parameter->name = opDef.name()->str().data(); + // parameter->type = opDef.attr_type(); + auto *kernel = new (std::nothrow) BNGradInputCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BNGradInput, CpuBNGradInputFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h new file mode 100644 index 0000000000..e4e6d6e746 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class BNGradInputCPUKernel : public LiteKernel { + public: + explicit BNGradInputCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~BNGradInputCPUKernel() override { delete workspace; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float *workspace; + int workspace_size; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc new file mode 100644 index 0000000000..ed5ffd9822 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/broadcast_to.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BroadcastTo; + +namespace mindspore::kernel { + +int BroadcastToCPUKernel::Init() { + auto input_shape = inputs_[0]->shape(); + for (size_t i = 0; i < input_shape.size(); ++i) { + shape_info_.input_shape_[i] = input_shape[i]; + } + + shape_info_.input_shape_size_ = static_cast(input_shape.size()); + auto output_shape = outputs_[0]->shape(); + for (size_t i = 0; i < output_shape.size(); ++i) { + shape_info_.output_shape_[i] = output_shape[i]; + } + shape_info_.output_shape_size_ = static_cast(output_shape.size()); + return RET_OK; +} + +int BroadcastToCPUKernel::Run() { + auto input_data = reinterpret_cast(inputs_.at(0)->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + + return BroadcastTo(input_data, &shape_info_, output_data); +} + +kernel::LiteKernel *CpuBroadcastToFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_BroadcastTo); + auto *kernel = new (std::nothrow) BroadcastToCPUKernel(op_parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BroadcastToCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BroadcastTo, CpuBroadcastToFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h new file mode 100644 index 0000000000..cfb8969448 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" + +namespace mindspore::kernel { +class BroadcastToCPUKernel : public LiteKernel { + public: + BroadcastToCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~BroadcastToCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + } + int Run() override; + private: + BroadcastShapeInfo shape_info_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc new file mode 100644 index 0000000000..b17bc4bb7a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/cast.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/cast.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Cast; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +const std::vector kSupportInputDataType = {kNumberTypeUInt8, kNumberTypeInt32}; +int CastRun(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { + if (cdata == nullptr) { + MS_LOG(ERROR) << "input cdata is nullptr!"; + return RET_ERROR; + } + + return reinterpret_cast(cdata)->DoCast(thread_id); +} +} // namespace + +int CastCPUKernel::Init() { + data_num_ = inputs_[0]->ElementsNum(); + if (data_num_ == 0) { + return RET_OK; + } + thread_num_ = MSMIN(thread_num_, data_num_); + stride_ = UP_DIV(data_num_, thread_num_); + return RET_OK; +} + +int CastCPUKernel::DoCast(int thread_id) { + auto input = inputs_.at(0); + int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); + if (data_num <= 0) { + return RET_OK; + } + + auto offset = thread_id * stride_; + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + switch (input->data_type()) { + case kNumberTypeUInt8: + Uint8ToFloat32(reinterpret_cast(input->Data()) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFloat32(reinterpret_cast(input->Data()) + offset, output_data + offset, data_num); + break; + default: + MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); + return RET_ERROR; + } + return RET_OK; +} + +int CastCPUKernel::Run() { + if (data_num_ == 0) { + return RET_OK; + } + return LiteBackendParallelLaunch(CastRun, this, thread_num_); +} + +kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "Input context is nullptr!"; + return nullptr; + } + if (ctx->thread_num_ == 0) { + MS_LOG(ERROR) << "context thread num is 0!"; + return nullptr; + } + auto *kernel = new (std::nothrow) CastCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new CastCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, CpuCastFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h new file mode 100644 index 0000000000..3d75715319 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class CastCPUKernel : public LiteKernel { + public: + CastCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + if (ctx != nullptr) { + thread_num_ = ctx->thread_num_; + } + } + + ~CastCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + }; + int Run() override; + int DoCast(int thread_id); + private: + uint32_t thread_num_; + uint32_t stride_; + uint32_t data_num_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc new file mode 100644 index 0000000000..df3ba83c5c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/concat.h" +#include +#include "src/runtime/kernel/arm/nnacl/fp32/concat.h" +#include "src/kernel_registry.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Concat; + +namespace mindspore::kernel { + int ConcatCPUKernel::Init() { + ConcatBaseCPUKernel::Init(); + schema::Format input0_format = inputs_[0]->GetFormat(); + bool need_convert_format = false; + for (size_t i = 1; i < inputs_.size(); ++i) { + if (inputs_[i]->GetFormat() != input0_format) { + need_convert_format = true; + } + } + if (!need_convert_format) { + outputs_[0]->SetFormat(input0_format); + return RET_OK; + } + MS_LOG(ERROR) << "All input format should be the same!"; + return RET_ERROR; + } + + int ConcatCPUKernel::ReSize() { return RET_OK; } + + int ConcatCPUKernel::Run() { + auto input_num = inputs_.size(); + std::vector inputs_addr(input_num, nullptr); + std::vector inputs_output_shape(input_num + 1, nullptr); + + std::vector > shapes; + for (size_t i = 0; i < input_num; ++i) { + inputs_addr[i] = inputs_[i]->Data(); + shapes.push_back(inputs_[i]->shape()); + inputs_output_shape[i] = shapes[i].data(); + } + auto output_shape = outputs_.at(0)->shape(); + inputs_output_shape[input_num] = output_shape.data(); + auto output_addr = outputs_.at(0)->Data(); + + Concat(reinterpret_cast(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(), + output_shape.size(), output_addr); + return RET_OK; + } +} // namespace mindspore::kernel + + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h new file mode 100644 index 0000000000..078921a53d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/base/concat_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ConcatCPUKernel : public ConcatBaseCPUKernel { + public: + ConcatCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~ConcatCPUKernel() = default; + + int Init() override; + + int ReSize() override; + + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc new file mode 100644 index 0000000000..bee5e46e75 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -0,0 +1,296 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution.h" +#include "src/runtime/kernel/arm/fp32/convolution_1x1.h" +#include "src/runtime/kernel/arm/fp32/convolution_3x3.h" +#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +int ConvolutionCPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int oc_block, oc_block_num; +#ifdef ENABLE_ARM32 + oc_block = C4NUM; + oc_block_num = UP_DIV(out_channel, C4NUM); +#else + oc_block = C8NUM; + oc_block_num = UP_DIV(out_channel, C8NUM); +#endif + int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane; + + // init weight + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed weight failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); + PackWeightFp32(origin_weight, conv_param_, packed_weight_, oc_block, oc_block_num); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc_block_num * oc_block * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, out_channel * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionCPUKernel::InitTmpBuffer() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_batch = conv_param_->input_batch_; + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_channel = conv_param_->output_channel_; + int kernel_plane = kernel_h * kernel_w; + + // malloc packed_inputs + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * TILE_NUM * unit_size; + /*=============================packed_input============================*/ + packed_input_ = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed input failed."; + return RET_ERROR; + } + memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float)); + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = + ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4 input failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + /*=============================tmp_output_block_============================*/ + tmp_output_block_ = reinterpret_cast(malloc(TILE_NUM * out_channel * sizeof(float))); + if (tmp_output_block_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp output block failed."; + return RET_ERROR; + } + return RET_OK; +} + +void ConvolutionCPUKernel::ConfigInputOutput() { + // set output format + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + + // select trans func for input + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +#ifdef ENABLE_ARM32 + gemm_func_ = IndirectGemmFp32_8x4; +#else + gemm_func_ = IndirectGemmFp32_8x8; +#endif +} + +int ConvolutionCPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + // config input output + ConfigInputOutput(); + return RET_OK; +} + +int ConvolutionCPUKernel::ReSize() { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionCPUKernel::RunImpl(int task_id) { + if (gemm_func_ == nullptr) { + MS_LOG(ERROR) << "gemm_func is nullptr."; + return RET_ERROR; + } + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + ConvFp32(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_output_block_, output_addr, task_id, conv_param_, gemm_func_); + return RET_OK; +} + +int ConvolutionImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionCPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionImpl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) { + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + int input_unit = conv_param->kernel_h_ + *output_unit - 1; + input_trans_func = GetInputTransFunc(input_unit); + if (input_trans_func == nullptr) { + MS_LOG(INFO) << "No matching input trans func. Turn back to common conv."; + *use_winograd = false; + } + output_trans_func = GetOutputTransFunc(input_unit, *output_unit); + if (output_trans_func == nullptr) { + MS_LOG(INFO) << "No matching output trans func. Turn back to common conv."; + *use_winograd = false; + } + } else { + *use_winograd = false; + } + } else { + *use_winograd = false; + } +} + +kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + bool use_winograd; + int out_unit; + InputTransformUnitFunc input_trans_func = nullptr; + OutputTransformUnitFunc output_trans_func = nullptr; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); + kernel::LiteKernel *kernel; + if (kernel_h == 1 && kernel_w == 1) { + kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx); + } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx); + } else if (use_winograd) { + kernel = new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit); + } else { + kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx); + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2D, CpuConvFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h new file mode 100644 index 0000000000..9e79609280 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" + +namespace mindspore::kernel { +class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionCPUKernel() override { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float *packed_input_; + float *packed_weight_; + float *tmp_output_block_; + GEMM_FUNC_FP32 gemm_func_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc new file mode 100644 index 0000000000..e2456bfdf4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_1x1.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { + if (weight_ptr_ != nullptr) { + free(weight_ptr_); + weight_ptr_ = nullptr; + } + if (pack_input_ != nullptr) { + free(pack_input_); + pack_input_ = nullptr; + } + if (pack_output_ != nullptr) { + free(pack_output_); + pack_output_ = nullptr; + } + if (pre_trans_input_ && input_ptr_ != nullptr) { + free(input_ptr_); + input_ptr_ = nullptr; + } + delete matmul_param_; +} + +int Convolution1x1CPUKernel::ReSize() { + if (pack_input_ != nullptr) { + free(pack_input_); + pack_input_ = nullptr; + } + if (pre_trans_input_ && input_ptr_ != nullptr) { + free(input_ptr_); + input_ptr_ = nullptr; + } + InitConv1x1MatmulParam(); + InitConv1x1Param(); + return RET_OK; +} + +void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { + matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; + matmul_param_->col_ = conv_param_->output_channel_; + matmul_param_->deep_ = conv_param_->input_channel_; + matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM); + matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); + matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No; + matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_; + return; +} + +int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { + if (inputs_.size() == 3) { + bias_data_ = malloc(matmul_param_->col_8_ * sizeof(float)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, matmul_param_->col_8_ * sizeof(float)); + memcpy(bias_data_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); + } else { + bias_data_ = nullptr; + } + + weight_ptr_ = reinterpret_cast(malloc(matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float))); + if (weight_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; + return RET_ERROR; + } + memset(weight_ptr_, 0, matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float)); + RowMajor2Col8Major(reinterpret_cast(inputs_[1]->Data()), weight_ptr_, matmul_param_->col_, + matmul_param_->deep_); + return RET_OK; +} + +int Convolution1x1CPUKernel::InitConv1x1Param() { + pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 || + conv_param_->stride_w_ != 1); + if (pre_trans_input_) { + input_ptr_ = reinterpret_cast(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float))); + if (input_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!"; + return RET_MEMORY_FAILED; + } + memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)); + } + + thread_count_ = MSMIN(opParameter->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); + thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; + + pack_input_ = reinterpret_cast(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float))); + if (pack_input_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; + return RET_MEMORY_FAILED; + } + memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)); + + pack_output_ = reinterpret_cast(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float))); + if (pack_output_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!"; + return RET_MEMORY_FAILED; + } + memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)); + return RET_OK; +} + +void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) { + output_ptr_ = src_output; + + if (pre_trans_input_) { + Conv1x1InputPackFp32(src_input, input_ptr_, conv_param_); + } else { + input_ptr_ = src_input; + } + + RowMajor2Col8Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); + return; +} + +int Convolution1x1CPUKernel::Init() { + ConvolutionBaseCPUKernel::Init(); + InitConv1x1MatmulParam(); + + int error_code = InitConv1x1BiasWeight(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution base init failed."; + return error_code; + } + error_code = InitConv1x1Param(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution base init failed."; + return error_code; + } + return RET_OK; +} + +int Convolution1x1CPUKernel::DoConv1x1(int task_id) { + int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast(bias_data_) + thread_stride_ * task_id; + + MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, + pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_, + matmul_param_->deep_, matmul_param_->row_8_, cur_oc); + + return RET_OK; +} + +int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv1x1 = reinterpret_cast(cdata); + auto error_code = conv1x1->DoConv1x1(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution1x1Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution1x1CPUKernel::Run() { + auto src_in = reinterpret_cast(inputs_[0]->Data()); + auto src_out = reinterpret_cast(outputs_[0]->Data()); + + for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { + Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_, + src_out + batch_index * matmul_param_->row_ * matmul_param_->col_); + + int error_code = LiteBackendParallelLaunch(Convolution1x1Run, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]"; + return RET_ERROR; + } + + Row8x8Major2RowMajor(pack_output_, output_ptr_, matmul_param_->row_, matmul_param_->col_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h new file mode 100644 index 0000000000..6d5840e017 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" +#include "src/runtime/kernel/arm/nnacl/fp32/common_func.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" + +namespace mindspore::kernel { +class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution1x1CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) { + matmul_param_ = new MatMulParameter(); + } + ~Convolution1x1CPUKernel(); + int Init() override; + int Run() override; + int ReSize() override; + + public: + int DoConv1x1(int task_id); + + private: + int InitConv1x1Param(); + int InitConv1x1BiasWeight(); + void InitConv1x1MatmulParam(); + void Pre1x1Trans(float *src_input, float *src_output); + + private: + MatMulParameter *matmul_param_ = nullptr; + bool pre_trans_input_ = false; + int thread_count_ = 0; + int thread_stride_ = 0; + float *weight_ptr_ = nullptr; + float *pack_input_ = nullptr; + float *pack_output_ = nullptr; + float *input_ptr_ = nullptr; + float *output_ptr_ = nullptr; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc new file mode 100644 index 0000000000..aa1f363010 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc @@ -0,0 +1,255 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_3x3.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num) { + auto input_channel = conv_param->input_channel_; + auto output_channel = conv_param->output_channel_; + auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int iC4 = UP_DIV(input_channel, C4NUM); + + size_t tmp_size = oc_block_num * oc_block * iC4 * C4NUM * kernel_plane * sizeof(float); + auto tmp_addr = reinterpret_cast(malloc(tmp_size)); + if (tmp_addr == nullptr) { + MS_LOG(ERROR) << "malloc tmp_addr failed."; + return; + } + memset(tmp_addr, 0, tmp_size); + + PackNHWCToNC4HW4Fp32(origin_weight, tmp_addr, output_channel, kernel_plane, input_channel); + Conv3x3Fp32FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane, oc_block); + free(tmp_addr); +} + +int Convolution3x3CPUKernel::InitWeightBias() { + auto input_channel = conv_param_->input_channel_; + auto output_channel = conv_param_->output_channel_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oC4 = UP_DIV(output_channel, C4NUM); + int oc_block, oc_block_num; +#ifdef ENABLE_ARM32 + oc_block = C4NUM; + oc_block_num = UP_DIV(output_channel, C4NUM); +#else + oc_block = C8NUM; + oc_block_num = UP_DIV(output_channel, C8NUM); +#endif + int k_plane = 16; + // init weight + size_t transformed_size = iC4 * C4NUM * oc_block_num * oc_block * k_plane * sizeof(float); + transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); + if (transformed_filter_addr_ == nullptr) { + MS_LOG(ERROR) << "malloc transformed filter addr failed."; + return RET_ERROR; + } + memset(transformed_filter_addr_, 0, transformed_size); + auto weight_data = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + ProcessFilter(weight_data, transformed_filter_addr_, conv_param_, oc_block, oc_block_num); + + // init bias + size_t new_bias_size = oC4 * C4NUM * sizeof(float); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias data failed."; + return RET_ERROR; + } + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int Convolution3x3CPUKernel::InitTmpBuffer() { + int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int k_plane = 16; + + /*=============================tile_buffer_============================*/ + size_t tile_buffer_size = thread_count_ * TILE_NUM * k_plane * iC4 * C4NUM * sizeof(float); + tile_buffer_ = reinterpret_cast(malloc(tile_buffer_size)); + if (tile_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tile buffer failed."; + return RET_ERROR; + } + memset(tile_buffer_, 0, tile_buffer_size); + + /*=============================block_unit_buffer_============================*/ + size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float); + block_unit_buffer_ = reinterpret_cast(malloc(block_unit_buffer_size)); + if (block_unit_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; + return RET_ERROR; + } + memset(block_unit_buffer_, 0, block_unit_buffer_size); + + /*=============================tmp_dst_buffer_============================*/ + size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * k_plane * oC4 * C4NUM * sizeof(float); + tmp_dst_buffer_ = reinterpret_cast(malloc(tmp_dst_buffer_size)); + if (tmp_dst_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; + return RET_ERROR; + } + memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = + iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + /*=============================nc4hw4_out_============================*/ + size_t nc4hw4_out_size = + oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float); + nc4hw4_out_ = reinterpret_cast(malloc(nc4hw4_out_size)); + if (nc4hw4_out_ == nullptr) { + MS_LOG(ERROR) << "malloc nc4hw4_out_ failed."; + return RET_ERROR; + } + memset(nc4hw4_out_, 0, nc4hw4_out_size); + tmp_buffer_address_list_[0] = tile_buffer_; + tmp_buffer_address_list_[1] = block_unit_buffer_; + tmp_buffer_address_list_[2] = tmp_dst_buffer_; + tmp_buffer_address_list_[3] = nc4hw4_out_; + return RET_OK; +} + +void Convolution3x3CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +#ifdef ENABLE_ARM32 + gemm_func_ = IndirectGemmFp32_8x4; +#else + gemm_func_ = IndirectGemmFp32_8x8; +#endif +} + +int Convolution3x3CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + return RET_OK; +} + +int Convolution3x3CPUKernel::ReSize() { + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + if (nc4hw4_out_ != nullptr) { + free(nc4hw4_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3CPUKernel::RunImpl(int task_id) { + if (gemm_func_ == nullptr) { + MS_LOG(ERROR) << "gemm_func is nullptr."; + return RET_ERROR; + } + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Conv3x3Fp32(reinterpret_cast(nhwc4_input_), transformed_filter_addr_, reinterpret_cast(bias_data_), + output_addr, tmp_buffer_address_list_, task_id, conv_param_, gemm_func_); + return RET_OK; +} + +int Convolution3x3Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv3x3 = reinterpret_cast(cdata); + auto error_code = conv3x3->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution3x3 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3CPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(Convolution3x3Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h new file mode 100644 index 0000000000..90dbac5a6f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" + +namespace mindspore::kernel { +class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution3x3CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution3x3CPUKernel() override { + if (transformed_filter_addr_ != nullptr) { + free(transformed_filter_addr_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (nc4hw4_out_ != nullptr) { + free(nc4hw4_out_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float *transformed_filter_addr_; + float *tile_buffer_; + float *block_unit_buffer_; + float *tmp_dst_buffer_; + float *nc4hw4_out_; + TmpBufferAddress tmp_buffer_address_list_[4]; + GEMM_FUNC_FP32 gemm_func_ = nullptr; +}; +void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc new file mode 100644 index 0000000000..6eec3f91aa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -0,0 +1,210 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" +#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int ConvolutionDepthwiseCPUKernel::InitWeightBias() { + // init weight: o, h, w, i; o == group, i == 1 + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); + PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // init bias + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C4NUM * OC4 * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float)); + } + + // init threadNum; + conv_param_->thread_num_ = MSMIN(thread_count_, OC4); + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::InitBuffer() { + // malloc pack input and output buffer + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_input_, 0, pack_input_size * sizeof(float)); + + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + } + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding window param + sliding_ = new SlidingWindowParam; + InitSlidingParam(sliding_, conv_param_, C4NUM); + + auto ret = InitWeightBias(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp32 InitWeightBias failed."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp32 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::ReSize() { + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding window param + sliding_ = new SlidingWindowParam; + InitSlidingParam(sliding_, conv_param_, C4NUM); + + auto ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Convolution depthwise fp32 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::Execute(int task_id) { + ConvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding_, task_id); + return RET_OK; +} + +int ConvDwRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw = reinterpret_cast(cdata); + auto ret = conv_dw->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwiseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + + // pack input: to nhwc4 + if (need_align_) { + PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_, + conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); + } else { + packed_input_ = input_addr; + } + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(ConvDwRun, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} + +kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + kernel::LiteKernel *kernel; + kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + // auto param = reinterpret_cast(opParameter); + // if (param->kernel_h_ == 3 && param->kernel_w_ == 3 && param->stride_h_ == 1 && param->stride_w_ == 1 && + // param->dilation_h_ == 1 && param->dilation_w_ == 1) { + // kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx); + // } else { + // kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + // } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, CpuConvDwFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h new file mode 100644 index 0000000000..e0d742c6ae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwiseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionDepthwiseCPUKernel() override { + delete sliding_; + free(packed_weight_); + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitBuffer(); + int InitWeightBias(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding_; + float *packed_weight_; + float *packed_input_; + float *packed_output_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc new file mode 100644 index 0000000000..6f2322255b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc @@ -0,0 +1,199 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() { + // init weight: o, h, w, i; o == group, i == 1 + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + // o h w 1 -> o/4 h w 1 4 + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int weight_c4_size = OC4 * C4NUM * 9; + auto tmp_weight = reinterpret_cast(malloc(weight_c4_size * sizeof(float))); + if (tmp_weight == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(tmp_weight, 0, weight_c4_size * sizeof(float)); + PackNCHWToNC4HW4Fp32(origin_weight, tmp_weight, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // weight transform + int packed_weight_size = OC4 * C4NUM * 16; + packed_weight_ = reinterpret_cast(malloc(packed_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, packed_weight_size * sizeof(float)); + ConvDw3x3Fp32FilterTrans(packed_weight_, tmp_weight, OC4); + + // init bias + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C4NUM * OC4 * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float)); + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::InitBuffer() { + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_input_, 0, pack_input_size * sizeof(float)); + + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + } + + // malloc transform buffer + trans_size_ = UP_DIV(conv_param_->output_w_, 2) * UP_DIV(conv_param_->output_h_, 2) * 16 * C4NUM; + size_t trans_buffer_size = thread_count_ * trans_size_ * sizeof(float); + trans_buffer_ = reinterpret_cast(malloc(trans_buffer_size)); + if (trans_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc trans buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise3x3 fp32 initWeightBias error!"; + return ret; + } + + // init threadNum; + conv_param_->thread_num_ = MSMIN(thread_count_, UP_DIV(conv_param_->output_channel_, C4NUM)); + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise3x3 fp32 initBuffer error!"; + return ret; + } + + // malloc one block buffer + block_buffer_ = reinterpret_cast(malloc(thread_count_ * 16 * C4NUM * sizeof(float))); + if (block_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::ReSize() { + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + free(trans_buffer_); + + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise3x3 fp32 initBuffer error!"; + return ret; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) { + auto trans_buf = trans_buffer_ + task_id * trans_size_; + auto block_buf = block_buffer_ + task_id * 16 * C4NUM; + ConvDw3x3Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), trans_buf, + block_buf, conv_param_, task_id); + return RET_OK; +} + +int ConvDw3x3Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw_3x3 = reinterpret_cast(cdata); + auto ret = conv_dw_3x3->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwise3x3Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwise3x3CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + + // pack input: to nhwc4 + if (need_align_) { + PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_, + conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); + } else { + packed_input_ = input_addr; + } + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(ConvDw3x3Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h new file mode 100644 index 0000000000..63f3d35cd2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwise3x3CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~ConvolutionDepthwise3x3CPUKernel() override { + free(packed_weight_); + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + free(block_buffer_); + free(trans_buffer_); + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitWeightBias(); + int InitBuffer(); + int Execute(int task_id); + + private: + float *packed_weight_; + float *packed_input_; + float *packed_output_; + float *block_buffer_; + float *trans_buffer_; + int trans_size_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc new file mode 100644 index 0000000000..3deb6a2017 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_grad_filter.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/pack_ext.h" +#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2DGradFilter; + +namespace mindspore::kernel { +int ConvolutionGradFilterCPUKernel::Init() { + // dy is in input 0 + // x is in input 1 + // dw is output 0 + + if (2 != this->inputs_.size()) { + MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; + return RET_ERROR; + } + if (1 != this->outputs_.size()) { + MS_LOG(ERROR) << "Conv2d Grad should has one output"; + return RET_ERROR; + } + + auto *input_tensor = inputs_.at(1); + MS_ASSERT(input_tensor != nullptr); + auto *dy = inputs_.at(0); + MS_ASSERT(dy != nullptr); + auto *weight_tensor = outputs_.at(0); + MS_ASSERT(weight_tensor != nullptr); + + auto conv_param = reinterpret_cast(opParameter); + conv_param->output_batch_ = this->inputs_.at(0)->shape().at(kNHWC_N); + conv_param->input_batch_ = this->inputs_.at(1)->shape().at(kNHWC_N); + conv_param->input_h_ = this->inputs_.at(1)->shape().at(kNHWC_H); + conv_param->input_w_ = this->inputs_.at(1)->shape().at(kNHWC_W); + // assume OutCh|kh|kw|In + conv_param->input_channel_ = this->inputs_.at(1)->shape().at(kNHWC_C); + conv_param->output_channel_ = this->outputs_.at(0)->shape().at(kNHWC_N); + + int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * + conv_param->input_channel_ / conv_param->group_; + + workspace = new float[ws_size]; + + int output_w = 0; + int output_h = 0; + output_h = dy->shape()[kNHWC_H]; + output_w = dy->shape()[kNHWC_W]; + + std::vector out_shape(4); + out_shape.at(0) = conv_param->output_channel_; + out_shape.at(1) = conv_param->kernel_h_; + out_shape.at(2) = conv_param->kernel_w_; + out_shape.at(3) = conv_param->input_channel_ / conv_param->group_; + + // weight is output + weight_tensor->set_shape(out_shape); + weight_tensor->set_data_type(input_tensor->data_type()); + + conv_param->output_h_ = output_h; + conv_param->output_w_ = output_w; + + return RET_OK; +} + +int ConvolutionGradFilterCPUKernel::ReSize() { return 0; } + +int ConvolutionGradFilterCPUKernel::Run() { + auto conv_param = reinterpret_cast(opParameter); + auto *input_dy = inputs_.at(0); + auto *input_x = inputs_.at(1); + auto *out_dw = outputs_.at(0); + + auto x_addr = reinterpret_cast(input_x->Data()); + auto dy_addr = reinterpret_cast(input_dy->Data()); + auto dw_addr = reinterpret_cast(out_dw->Data()); + + int i, j; + int nweights = out_dw->ElementsNum(); + int in_ch = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; // out_dw->shape()[1]; + int k_w = conv_param->kernel_w_; // out_dw->shape()[2]; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int groups = conv_param->group_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int n = k_h * k_w * in_ch / groups; + int k = out_ch / groups; + + // zero out pointer + memset(dw_addr, 0, out_dw->Size()); + + for (i = 0; i < batch; ++i) { + for (j = 0; j < groups; ++j) { + float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups); + float *mat_b = workspace; + float *mat_c = dw_addr + j * nweights / groups; + float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups); + + im2row_hwc(im, mat_b, conv_param); + gemm(1, 1, k, n, m, 1, mat_a, out_ch, mat_b, m, 1, mat_c, n); + } + } + + // std::cout << "run succ" << std::endl; + return RET_OK; +} + +kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradFilter); + + auto *kernel = new (std::nothrow) ConvolutionGradFilterCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradFilter, CpuConvGradFilterFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h new file mode 100644 index 0000000000..c32a798eaf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class ConvolutionGradFilterCPUKernel : public LiteKernel { + public: + explicit ConvolutionGradFilterCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ConvolutionGradFilterCPUKernel() override { delete workspace; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float *workspace; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc new file mode 100644 index 0000000000..6e0683b301 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_grad_input.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/pack_ext.h" +#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Conv2DGradInput; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int ConvolutionGradInputCPUKernel::Init() { + if (2 != this->inputs_.size()) { + MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; + return RET_ERROR; + } + if (1 != this->outputs_.size()) { + MS_LOG(ERROR) << "Conv2d Grad should has one output"; + return RET_ERROR; + } + + auto *dy_tensor = inputs_.at(kInputIndex); + MS_ASSERT(dy_tensor != nullptr); + auto *weight_tensor = inputs_.at(kWeightIndex); + MS_ASSERT(weight_tensor != nullptr); + auto *dx_tensor = outputs_.at(kOutputIndex); + MS_ASSERT(dx_tensor != nullptr); + + auto conv_param = reinterpret_cast(opParameter); + conv_param->output_batch_ = dx_tensor->shape()[(kNHWC_N)]; + conv_param->input_batch_ = dy_tensor->shape()[(kNHWC_N)]; + + conv_param->input_h_ = dx_tensor->shape()[(kNHWC_H)]; + conv_param->input_w_ = dx_tensor->shape()[(kNHWC_W)]; + + // assume OutCh|kh|kw|In + conv_param->input_channel_ = dx_tensor->shape()[(kNHWC_C)]; + conv_param->output_channel_ = weight_tensor->shape()[(kNHWC_N)]; + + // TBD + conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; + conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; + + int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * + conv_param->input_channel_ / conv_param->group_; + + workspace = new float[ws_size]; + return 0; +} + +int ConvolutionGradInputCPUKernel::ReSize() { return 0; } + +int ConvolutionGradInputCPUKernel::Run() { + auto conv_param = reinterpret_cast(opParameter); + auto *input_dy = inputs_.at(0); + auto *input_w = inputs_.at(1); + auto *out_dx = outputs_.at(0); + + auto dy_addr = reinterpret_cast(input_dy->Data()); + auto w_addr = reinterpret_cast(input_w->Data()); + auto dx_addr = reinterpret_cast(out_dx->Data()); + + int i, j; + int nweights = input_w->ElementsNum(); + int in_ch = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; // out_dw->shape()[1]; + int k_w = conv_param->kernel_w_; // out_dw->shape()[2]; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int groups = conv_param->group_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int n = k_w * k_h * in_ch / groups; + int k = out_ch / groups; + + memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); + + for (i = 0; i < batch; ++i) { + for (j = 0; j < groups; ++j) { + float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups); + float *mat_b = w_addr + j * nweights / groups; + float *mat_c = workspace; + gemm(0, 0, m, n, k, 1, mat_a, out_ch, mat_b, n, 0, mat_c, n); + col2im_hwc(mat_c, dx_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups), conv_param); + } + } + + // std::cout << "run succ" << std::endl; + return 0; +} + +kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput); + + auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradInput, CpuConvGradInputFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h new file mode 100644 index 0000000000..86901b37ba --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class ConvolutionGradInputCPUKernel : public LiteKernel { + public: + explicit ConvolutionGradInputCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ConvolutionGradInputCPUKernel() override { delete workspace; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float *workspace; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc new file mode 100644 index 0000000000..3204ee8e05 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc @@ -0,0 +1,357 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit, + ConvParameter *conv_param, int oc_block) { + // original weight format : ohwi + auto channel_in = conv_param->input_channel_; + auto channel_out = conv_param->output_channel_; + int input_unit_square = input_unit * input_unit; + + // generate matrix_G && matrix_GT + auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit); + auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit); + ChooseMatrixG(matrix_g, matrix_gt); + auto matrix_g_data = reinterpret_cast(matrix_g->GetData()); + auto matrix_gt_data = reinterpret_cast(matrix_gt->GetData()); + + // trans_filter = G*g*GT (g represents weight_data) + // separate into two steps ===> tmp = G*g ===> out = tmp * GT + auto tmp_weight_data = reinterpret_cast(malloc(kernel_unit * kernel_unit * sizeof(float))); + auto tmp_data = reinterpret_cast(malloc(input_unit * kernel_unit * sizeof(float))); + auto trans_out_data = reinterpret_cast(malloc(input_unit * input_unit * sizeof(float))); + bool row = true; + auto trans_weight_data = reinterpret_cast(trans_weight->GetData()); + std::vector strides = trans_weight->GetStride(); + + int kernel_plane_stride = channel_in; + for (int i = 0; i < channel_out; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int input_oz_offset = i * kernel_unit * kernel_unit * channel_in; + int output_oz_offset = out_c_block * strides[1] * input_unit * input_unit + out_c_res; + for (int j = 0; j < channel_in; j++) { + int ic4_block = j / C4NUM; + int ic4_res = j % C4NUM; + int input_iz_offset = input_oz_offset + j; + int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3]; + for (int k = 0; k < kernel_unit * kernel_unit; k++) { + int input_xy_offset = input_iz_offset + k * kernel_plane_stride; + tmp_weight_data[k] = *(weight_data + input_xy_offset); + } + // now we only support row-major matrix-multiply + // tmp = G * g + MatrixMultiply(matrix_g_data, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row); + // out = tmp * GT + MatrixMultiply(tmp_data, matrix_gt_data, trans_out_data, input_unit, kernel_unit, input_unit, row); + + for (int z = 0; z < input_unit_square; z++) { + int output_xy_offset = output_iz_offset + z * strides[1]; + *(trans_weight_data + output_xy_offset) = trans_out_data[z]; + } + } + } + free(tmp_weight_data); + free(tmp_data); + free(trans_out_data); + delete matrix_g; + delete matrix_gt; +} + +int ConvolutionWinogradCPUKernel::InitWeightBias() { + int output_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int oc_block, oc_block_num; +#ifdef ENABLE_ARM32 + oc_block = C4NUM; + oc_block_num = UP_DIV(output_channel, C4NUM); +#else + oc_block = C8NUM; + oc_block_num = UP_DIV(output_channel, C8NUM); +#endif + + // init weight + auto ret = MallocFilterMatrix(oc_block, oc_block_num); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Malloc filter matrix failed."; + return RET_ERROR; + } + auto weight_tensor = inputs_.at(kWeightIndex); + auto weight_data = reinterpret_cast(weight_tensor->Data()); + WinogradFilterTransform(weight_data, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block); + + // init bias + size_t new_bias_size = oc4 * C4NUM * sizeof(float); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) { + int channel_in = conv_param_->input_channel_; + int ic4 = UP_DIV(channel_in, BLOCK); + + // set data + auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float); + auto matrix_buffer = malloc(trans_matrix_data_size); + if (matrix_buffer == nullptr) { + MS_LOG(ERROR) << "malloc matrix_buffer failed."; + return RET_ERROR; + } + memset(matrix_buffer, 0, trans_matrix_data_size); + trans_weight_ = new Matrix(); + trans_weight_->SetData(matrix_buffer); + trans_weight_->SetNDim(5); + + std::vector shapes; + std::vector strides; + // set shape + shapes.push_back(input_unit_ * input_unit_); + shapes.push_back(oc_block_num); + shapes.push_back(ic4); + shapes.push_back(C4NUM); + shapes.push_back(oc_block); + // set stride + for (int i = 0; i < 4; i++) { + int stride = 1; + for (int j = i + 1; j < 5; j++) { + stride *= shapes[j]; + } + strides.push_back(stride); + } + trans_weight_->SetShape(shapes); + trans_weight_->SetStride(strides); + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::InitTmpBuffer() { + int channel_in = conv_param_->input_channel_; + int channel_out = conv_param_->output_channel_; + int output_h = conv_param_->output_h_; + int output_w = conv_param_->output_w_; + int ic4 = UP_DIV(channel_in, C4NUM); + int oc4 = UP_DIV(channel_out, C4NUM); + + /*=============================trans_input_============================*/ + size_t tile_buffer_size = thread_count_ * TILE_NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); + trans_input_ = reinterpret_cast(malloc(tile_buffer_size)); + if (trans_input_ == nullptr) { + MS_LOG(ERROR) << "malloc trans_input_ failed."; + return RET_ERROR; + } + memset(trans_input_, 0, tile_buffer_size); + + /*=============================gemm_out_============================*/ + gemm_out_ = reinterpret_cast( + malloc(thread_count_ * TILE_NUM * input_unit_ * input_unit_ * oc4 * C4NUM * sizeof(float))); + if (gemm_out_ == nullptr) { + MS_LOG(ERROR) << "malloc gemm_out_ failed."; + return RET_ERROR; + } + + /*=============================tmp_out_data_============================*/ + int out_w_block = UP_DIV(output_w, output_unit_); + int out_h_block = UP_DIV(output_h, output_unit_); + tmp_out_data_ = reinterpret_cast( + malloc(out_w_block * out_h_block * output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); + if (tmp_out_data_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_data_ failed."; + return RET_ERROR; + } + + /*=============================tmp_data_============================*/ + tmp_data_ = reinterpret_cast(malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float))); + if (tmp_data_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_data_ failed."; + return RET_ERROR; + } + memset(tmp_data_, 0, C4NUM * input_unit_ * input_unit_ * sizeof(float)); + + tmp_buffer_address_list_[0] = trans_input_; + tmp_buffer_address_list_[1] = gemm_out_; + tmp_buffer_address_list_[2] = tmp_out_data_; + tmp_buffer_address_list_[3] = tmp_data_; + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = + ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::ConfigInputOutput() { + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return RET_ERROR; + } + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + + // choose input transformer function (4x4 unit or 8x8 unit) + input_trans_func_ = GetInputTransFunc(input_unit_); + if (input_trans_func_ == nullptr) { + MS_LOG(ERROR) << "Get input_trans_func failed."; + return RET_ERROR; + } + output_trans_func_ = GetOutputTransFunc(input_unit_, output_unit_); + if (output_trans_func_ == nullptr) { + MS_LOG(ERROR) << "Get output_trans_func_ failed."; + return RET_ERROR; + } +#ifdef ENABLE_ARM32 + gemm_func_ = IndirectGemmFp32_8x4; +#else + gemm_func_ = IndirectGemmFp32_8x8; +#endif + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + kernel_unit_ = conv_param_->kernel_h_; + input_unit_ = output_unit_ + kernel_unit_ - 1; + conv_param_->input_unit_ = input_unit_; + conv_param_->output_unit_ = output_unit_; + + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // malloc tmp buffer + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ret = ConfigInputOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConfigInputOutput failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::ReSize() { + if (tmp_data_ != nullptr) { + free(tmp_data_); + } + if (trans_input_ != nullptr) { + free(trans_input_); + } + if (gemm_out_ != nullptr) { + free(gemm_out_); + } + if (tmp_out_data_ != nullptr) { + free(tmp_out_data_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + kernel_unit_ = conv_param_->kernel_h_; + input_unit_ = output_unit_ + kernel_unit_ - 1; + conv_param_->input_unit_ = input_unit_; + conv_param_->output_unit_ = output_unit_; + + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ret = ConfigInputOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConfigInputOutput failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::RunImpl(int task_id) { + if (gemm_func_ == nullptr) { + MS_LOG(ERROR) << "gemm_func is nullptr."; + return RET_ERROR; + } + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + ConvWinogardFp32(reinterpret_cast(nhwc4_input_), reinterpret_cast(trans_weight_->GetData()), + reinterpret_cast(bias_data_), output_addr, tmp_buffer_address_list_, task_id, + conv_param_, input_trans_func_, output_trans_func_, gemm_func_); + return RET_OK; +} + +int ConvolutionWinogradImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ConvolutionWinograd Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionWinogradImpl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h new file mode 100644 index 0000000000..d11cc8ae4b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/base/matrix.h" + +namespace mindspore::kernel { +class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx, int output_unit) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx), output_unit_(output_unit) {} + ~ConvolutionWinogradCPUKernel() override { + if (tmp_data_ != nullptr) { + free(tmp_data_); + } + if (trans_input_ != nullptr) { + free(trans_input_); + } + if (gemm_out_ != nullptr) { + free(gemm_out_); + } + if (tmp_out_data_ != nullptr) { + free(tmp_out_data_); + } + delete trans_weight_; + }; + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int MallocFilterMatrix(int oc_block, int oc_block_num); + int InitTmpBuffer(); + int ConfigInputOutput(); + + private: + int kernel_unit_; + int input_unit_; + int output_unit_; + float *tmp_data_; + float *trans_input_; + float *gemm_out_; + float *tmp_out_data_; + Matrix *trans_weight_; + InputTransformUnitFunc input_trans_func_; + OutputTransformUnitFunc output_trans_func_; + TmpBufferAddress tmp_buffer_address_list_[5]; + GEMM_FUNC_FP32 gemm_func_ = nullptr; +}; +void WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit, + ConvParameter *conv_param, int oc_block); +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc new file mode 100644 index 0000000000..5282c6cada --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/crop.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/crop.h" +#include "src/runtime/kernel/arm/nnacl/crop_parameter.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Crop; + +namespace mindspore::kernel { +namespace { +int CropLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { + if (cdata == nullptr) { + MS_LOG(ERROR) << "Input cdata is nullptr!"; + return RET_NULL_PTR; + } + auto kernel = reinterpret_cast(cdata); + return kernel->CropParallelRun(thread_id); +} +} // namespace + +int CropCPUKernel::Init() { + schema::Format input0_format = inputs_[0]->GetFormat(); + if (input0_format != schema::Format_NCHW && input0_format != schema::Format_NHWC) { + MS_LOG(ERROR) << "Unsupport format " << input0_format; + return RET_FORMAT_ERR; + } + outputs_[0]->SetFormat(input0_format); + return RET_OK; +} + +int CropCPUKernel::CropParallelRun(int thread_id) { + auto input = inputs_[0]; + auto output = outputs_[0]; + float *input_data = reinterpret_cast(input->Data()); + float *output_data = reinterpret_cast(output->Data()); + Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), + reinterpret_cast(opParameter)); + return RET_OK; +} + +int CropCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + auto param = reinterpret_cast(opParameter); + if (output->shape()[1] < param->op_parameter_.thread_num_) { + float *input_data = reinterpret_cast(input->Data()); + float *output_data = reinterpret_cast(output->Data()); + Crop4DNoParallel(input_data, output_data, input->shape().data(), output->shape().data(), param); + return RET_OK; + } + + int ret = LiteBackendParallelLaunch(CropLaunch, this, param->op_parameter_.thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Crop launch fail!ret: " << ret; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h new file mode 100644 index 0000000000..f9656b2355 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CROP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CROP_H_ +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "src/runtime/kernel/arm/base/crop_base.h" + +namespace mindspore::kernel { +class CropCPUKernel : public CropBaseCPUKernel { + public: + CropCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : CropBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~CropCPUKernel() = default; + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int CropParallelRun(int thread_id); +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CROP_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc new file mode 100644 index 0000000000..be571478fe --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc @@ -0,0 +1,240 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/deconvolution.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { +DeConvolutionCPUKernel::~DeConvolutionCPUKernel() { + if (weight_ptr_ != nullptr) { + free(weight_ptr_); + weight_ptr_ = nullptr; + } + if (tmp_buffer_ != nullptr) { + free(tmp_buffer_); + tmp_buffer_ = nullptr; + } + if (pack_input_ != nullptr) { + free(pack_input_); + pack_input_ = nullptr; + } + if (pack_output_ != nullptr) { + free(pack_output_); + pack_output_ = nullptr; + } + return; +} + +int DeConvolutionCPUKernel::ReSize() { + if (tmp_buffer_ != nullptr) { + free(tmp_buffer_); + tmp_buffer_ = nullptr; + } + if (pack_input_ != nullptr) { + free(pack_input_); + pack_input_ = nullptr; + } + if (pack_output_ != nullptr) { + free(pack_output_); + pack_output_ = nullptr; + } + InitParam(); + + return RET_OK; +} + +int DeConvolutionCPUKernel::InitWeightBias() { + if (inputs_.size() == 3) { + bias_data_ = malloc(UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(float)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(float)); + memcpy(bias_data_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); + } else { + bias_data_ = nullptr; + } + + size_t weight_pack_size = conv_param_->input_channel_ * conv_param_->kernel_w_ * conv_param_->kernel_h_ * + UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float); + weight_ptr_ = reinterpret_cast(malloc(weight_pack_size)); + if (weight_ptr_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc weight_ptr_ error!"; + return RET_ERROR; + } + memset(weight_ptr_, 0, weight_pack_size); + PackNHWCToC8HWN8Fp32(reinterpret_cast(inputs_[1]->Data()), weight_ptr_, conv_param_->input_channel_, + kernel_plane_, conv_param_->output_channel_); + return RET_OK; +} + +int DeConvolutionCPUKernel::InitParam() { + input_plane_ = conv_param_->input_h_ * conv_param_->input_w_; + kernel_plane_ = conv_param_->kernel_w_ * conv_param_->kernel_h_; + output_plane_ = conv_param_->output_h_ * conv_param_->output_w_; + + matmul_param_->row_ = input_plane_; + matmul_param_->deep_ = conv_param_->input_channel_; + matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_; + matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM); + matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_; + + thread_count_ = MSMIN(opParameter->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM)); + thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_); + + pack_input_ = reinterpret_cast(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float))); + if (pack_input_ == nullptr) { + MS_LOG(ERROR) << "deconv Malloc pack_input_ error!"; + return RET_ERROR; + } + + pack_output_ = + reinterpret_cast(malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * output_plane_ * sizeof(float))); + if (pack_output_ == nullptr) { + MS_LOG(ERROR) << "deconv Malloc pack_output_ error!"; + return RET_NULL_PTR; + } + + tmp_buffer_ = reinterpret_cast(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float))); + if (tmp_buffer_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvFp32Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoDeconv(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvFp32PostRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoPostFunc(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvFp32PostRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvolutionCPUKernel::DoDeconv(int task_id) { + int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); + if (oc <= 0) { + return RET_OK; + } + + MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, + tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No, + matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_); + + return RET_OK; +} + +int DeConvolutionCPUKernel::DoPostFunc(int task_id) { + int oc = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM); + if (oc <= 0) { + return RET_OK; + } + + float *bias = + (bias_data_ == nullptr) ? nullptr : reinterpret_cast(bias_data_) + thread_stride_ * task_id * C8NUM; + + DeConvPostFp32C8x8(tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, + pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, bias, + output_ptr_ + task_id * thread_stride_ * C8NUM, oc, conv_param_); + return RET_OK; +} + +int DeConvolutionCPUKernel::Init() { + ConvolutionBaseCPUKernel::Init(); + + int error_code = InitParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv InitParam error!"; + return error_code; + } + + error_code = InitWeightBias(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv InitWeightBias error!"; + return error_code; + } + return RET_OK; +} + +int DeConvolutionCPUKernel::Run() { + float *src_in = reinterpret_cast(inputs_[0]->Data()); + float *src_out = reinterpret_cast(outputs_[0]->Data()); + + for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { + input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_; + output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_; + + RowMajor2Col8Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_); + + int error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv fp32 run error! error_code[" << error_code << "]"; + return RET_ERROR; + } + + error_code = LiteBackendParallelLaunch(DeConvFp32PostRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv fp32 postrun error! error_code[" << error_code << "]"; + return RET_ERROR; + } + } + return RET_OK; +} + + +kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, CpuDeConvFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h new file mode 100644 index 0000000000..165ebf5d28 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/deconv.h" +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" + +namespace mindspore::kernel { +class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel { + public: + DeConvolutionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) { + matmul_param_ = new MatMulParameter(); + } + ~DeConvolutionCPUKernel() override; + int Init() override; + int Run() override; + int ReSize() override; + + public: + int DoDeconv(int task_id); + int DoPostFunc(int task_id); + + private: + int InitParam(); + int InitWeightBias(); + + private: + MatMulParameter *matmul_param_; + int input_plane_; + int kernel_plane_; + int output_plane_; + int thread_count_; + int thread_stride_; + float *weight_ptr_; + float *pack_input_; + float *pack_output_; + float *tmp_buffer_; + float *input_ptr_; + float *output_ptr_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc new file mode 100644 index 0000000000..07e5747847 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc @@ -0,0 +1,214 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/deconvolution_depthwise.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; + +namespace mindspore::kernel { +int DeconvolutionDepthwiseCPUKernel::InitSlideParam() { + conv_param_->input_batch_ = outputs_.front()->shape().at(kNHWC_N); + conv_param_->input_h_ = outputs_.front()->shape().at(kNHWC_H); + conv_param_->input_w_ = outputs_.front()->shape().at(kNHWC_W); + conv_param_->input_channel_ = outputs_.front()->shape().at(kNHWC_C); + conv_param_->output_batch_ = inputs_.front()->shape().at(kNHWC_N); + conv_param_->output_h_ = inputs_.front()->shape().at(kNHWC_H); + conv_param_->output_w_ = inputs_.front()->shape().at(kNHWC_W); + conv_param_->output_channel_ = inputs_.front()->shape().at(kNHWC_C); + + // init sliding window param + sliding_ = new SlidingWindowParam; + InitSlidingParam(sliding_, conv_param_, C4NUM); + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::InitWeightBias() { + // init weight: o, h, w, i; o == group, i == 1 + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); + PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // init bias + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C4NUM * OC4 * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float)); + } + + // init threadNum; + conv_param_->thread_num_ = MSMIN(conv_param_->thread_num_, OC4); + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::InitBuffer() { + // malloc pack input and output buffer + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_input_, 0, pack_input_size * sizeof(float)); + + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_output_, 0, pack_output_size * sizeof(float)); + } + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::Init() { + InitSlideParam(); + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitWeightBias(); + if (ret != 0) { + MS_LOG(ERROR) << "Deconvolution depthwise fp32 InitWeightBias failed."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Deconvolution depthwise fp32 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::ReSize() { + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + InitSlideParam(); + + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitBuffer(); + if (ret != 0) { + MS_LOG(ERROR) << "Deconvolution depthwise fp32 InitBuffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::Execute(int task_id) { + DeconvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding_, task_id); + return RET_OK; +} + +int DeconvDwRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv_dw = reinterpret_cast(cdata); + auto ret = deconv_dw->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvolutionDepthwiseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + + // pack input: to nhwc4 + if (need_align_) { + PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_, + conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); + } else { + packed_input_ = input_addr; + } + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (!need_align_) { + memset(output_addr, 0, outputs_.at(kOutputIndex)->ElementsNum() * sizeof(float)); + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(DeconvDwRun, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} + +kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); + auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h new file mode 100644 index 0000000000..f2ed4b5d95 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_DEPTHWISE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class DeconvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { + public: + DeconvolutionDepthwiseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeconvolutionDepthwiseCPUKernel() override { + delete sliding_; + free(packed_weight_); + if (need_align_) { + free(packed_input_); + free(packed_output_); + } + }; + + int Init() override; + int InitSlideParam(); + int ReSize() override; + int Run() override; + + int InitBuffer(); + int InitWeightBias(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding_; + float *packed_weight_; + float *packed_input_; + float *packed_output_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_DEPTHWISE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc new file mode 100644 index 0000000000..c28a9cd4cd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/depth_to_space.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/kernel/arm/nnacl/depth_to_space.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthToSpace; + +namespace mindspore::kernel { + +int DepthToSpaceCPUKernel::Init() { + auto ret = DepthToSpaceBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + param->data_type_size_ = sizeof(float); + return RET_OK; +} + +int DepthToSpaceCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + const float *input_data = reinterpret_cast(input->Data()); + float *output_data = reinterpret_cast(output->Data()); + auto in_shape = input->shape(); + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + if (input->GetFormat() == schema::Format_NHWC) { + DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), param); + return RET_OK; + } else { + MS_LOG(ERROR) << "Depth_to_space only support NHWC now!"; + return RET_ERROR; + } +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h new file mode 100644 index 0000000000..e1676ccb8e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_ + +#include +#include "src/runtime/kernel/arm/base/depth_to_space_base.h" + +namespace mindspore::kernel { +class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel { + public: + DepthToSpaceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DepthToSpaceCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc new file mode 100644 index 0000000000..0904f90b5d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/embedding_lookup.h" +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_EmbeddingLookup; + +namespace mindspore::kernel { +int EmbeddingLookupCPUKernel::Init() { + embedding_lookup_parameter_ = reinterpret_cast(opParameter); + embedding_lookup_parameter_->thread_num = thread_count_; + embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum(); + + embedding_lookup_parameter_->layer_size_ = 1; + auto in_shape = inputs_.front()->shape(); + for (int i = 1; i < in_shape.size(); ++i) { + embedding_lookup_parameter_->layer_size_ *= in_shape[i]; + } + + embedding_lookup_parameter_->layer_num_ = 0; + for (int i = 0; i < inputs_.size() - 1; ++i) { + embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0]; + } + + input_addr_ = reinterpret_cast( + std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_)); + if (input_addr_ == nullptr) { + MS_LOG(ERROR) << "Create memory failed"; + return mindspore::lite::RET_MEMORY_FAILED; + } + + embedding_lookup_parameter_->is_regulated_ = + reinterpret_cast(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_)); + if (embedding_lookup_parameter_->is_regulated_ == nullptr) { + MS_LOG(ERROR) << "Create memory failed"; + return mindspore::lite::RET_MEMORY_FAILED; + } + + for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) { + embedding_lookup_parameter_->is_regulated_[i] = embedding_lookup_parameter_->max_norm_ == 0; + } + + return RET_OK; +} + +int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; } + +int EmbeddingLookupCPUKernel::DoExcute(int task_id) { + int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "embedding lookup error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int EmbeddingLookupRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto EmbeddingLookupData = reinterpret_cast(cdata); + auto ret = EmbeddingLookupData->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "EmbeddingLookupRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int EmbeddingLookupCPUKernel::Run() { + int dest_loc = 0; + for (int i = 0; i < inputs_.size() - 1; i++) { + auto input_t = reinterpret_cast(inputs_.at(i)->Data()); + memcpy(input_addr_ + dest_loc, input_t, sizeof(float) * inputs_.at(i)->ElementsNum()); + dest_loc += inputs_.at(i)->ElementsNum(); + } + output_addr_ = reinterpret_cast(outputs_.front()->Data()); + ids_addr_ = reinterpret_cast(inputs_.back()->Data()); + + auto ret = LiteBackendParallelLaunch(EmbeddingLookupRun, this, embedding_lookup_parameter_->thread_num); + if (ret != RET_OK) { + MS_LOG(ERROR) << "EmbeddingLookup error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuEmbeddingLookupFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_EmbeddingLookup); + auto *kernel = new (std::nothrow) EmbeddingLookupCPUKernel(parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init Kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_EmbeddingLookup, CpuEmbeddingLookupFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h new file mode 100644 index 0000000000..6afa0d5620 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" + +namespace mindspore::kernel { +class EmbeddingLookupCPUKernel : public LiteKernel { + public: + explicit EmbeddingLookupCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + ~EmbeddingLookupCPUKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const lite::Context *ctx_; + EmbeddingLookupParameter *embedding_lookup_parameter_; + + private: + float *input_addr_; + float *output_addr_; + int *ids_addr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EMBEDDING_LOOKUP_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc new file mode 100644 index 0000000000..8f37cb30bb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/expandDims.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ExpandDims; + +namespace mindspore::kernel { +int ExpandDimsCPUKernel::Init() { + int ret = ReSize(); + return ret; +} + +int ExpandDimsCPUKernel::ReSize() { + data_size_ = inputs_.at(0)->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int ExpandDimsCPUKernel::DoExpandDims(int task_id) { + size_t size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size == 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + int ret = ExpandDims(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExpandDimsRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ExpandDimsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoExpandDims(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExpandDimsRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ExpandDimsCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_.at(0)->Data()); + out_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); + int ret = LiteBackendParallelLaunch(ExpandDimsRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExpandDimsRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_ExpandDims); + auto *kernel = new (std::nothrow) ExpandDimsCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ExpandDimsCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, CpuExpandsDimsFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h new file mode 100644 index 0000000000..3ed5b12eae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EXPANDDIMS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EXPANDDIMS_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/expandDims.h" +#include "schema/model_generated.h" + +#include "include/context.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ExpandDimsCPUKernel : public LiteKernel { + public: + ExpandDimsCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + ~ExpandDimsCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExpandDims(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + size_t data_size_; + float *in_ptr_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_EXPANDDIMS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc new file mode 100644 index 0000000000..96a583decc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/fill.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Fill; + +namespace mindspore::kernel { + +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +} // namespace + +int FillCPUKernel::Init() { + data_size_ = outputs_.front()->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int FillCPUKernel::ReSize() { return RET_OK; } + +int FillCPUKernel::DoFill(int task_id) { + int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + int ret = Fill(out_ptr_ + offset, size, src_data_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int FillRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoFill(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int FillCPUKernel::Run() { + auto fillData = inputs_.at(inputs_.size() - 1); + auto output = outputs_.front(); + auto fill_data = reinterpret_cast(fillData->Data()); + src_data_ = fill_data[0]; + out_ptr_ = reinterpret_cast(output->Data()); + int ret = LiteBackendParallelLaunch(FillRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_Fill. "; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Fill); + auto *kernel = new (std::nothrow) FillCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new FillCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Fill, CpuFillFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h new file mode 100644 index 0000000000..2010681f56 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/fp32/fill.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FillCPUKernel : public LiteKernel { + public: + FillCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + ~FillCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoFill(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + float src_data_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc new file mode 100644 index 0000000000..3adddbf91f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/flatten.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/flatten.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Flatten; + +namespace mindspore::kernel { +int FlattenCPUKernel::Init() { + auto output_shape = outputs_[0]->shape(); + flatten_param_->size = sizeof(float); + for (int i = 0; i < output_shape.size(); i++) { + flatten_param_->size *= output_shape[i]; + } + return RET_OK; +} + +int FlattenCPUKernel::ReSize() { return RET_OK; } + +int FlattenCPUKernel::Run() { + auto input = reinterpret_cast(inputs_[0]->Data()); + auto output = reinterpret_cast(outputs_[0]->Data()); + Flatten(input, output, flatten_param_); + return RET_OK; +} + +kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_Flatten. "; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Flatten); + auto *kernel = new (std::nothrow) FlattenCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new FlattenCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Flatten, CpuFlattenFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h new file mode 100644 index 0000000000..e84104f50f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/flatten.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FlattenCPUKernel : public LiteKernel { + public: + FlattenCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + flatten_param_ = reinterpret_cast(parameter); + } + ~FlattenCPUKernel() override { delete flatten_param_; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + FlattenParameter *flatten_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc new file mode 100644 index 0000000000..366167b5f3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/fullconnection.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +FullconnectionCPUKernel::~FullconnectionCPUKernel() { + if (a_c8_ptr_ != nullptr) { + free(a_c8_ptr_); + a_c8_ptr_ = nullptr; + } + if (b_r8_ptr_ != nullptr) { + free(b_r8_ptr_); + b_r8_ptr_ = nullptr; + } + if (c_r8x8_ptr_ != nullptr) { + free(c_r8x8_ptr_); + c_r8x8_ptr_ = nullptr; + } + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } +} + +int FullconnectionCPUKernel::ReSize() { return RET_OK; } + +int FullconnectionCPUKernel::Init() { + fc_param_->row_ = (inputs_[0]->shape())[0]; + fc_param_->col_ = (inputs_[1]->shape())[0]; + fc_param_->deep_ = (inputs_[1]->shape())[1]; + + fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8); + fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8); + + thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); + thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); + + bias_ptr_ = reinterpret_cast(malloc(fc_param_->col_8_ * sizeof(float))); + memset(bias_ptr_, 0, fc_param_->col_8_ * sizeof(float)); + if (inputs_.size() == 3) { + memcpy(bias_ptr_, inputs_[2]->Data(), fc_param_->col_ * sizeof(float)); + } + + a_c8_ptr_ = reinterpret_cast(malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(float))); + if (a_c8_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(float)); + + b_r8_ptr_ = reinterpret_cast(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float))); + if (b_r8_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)); + RowMajor2Col8Major(reinterpret_cast(inputs_[1]->Data()), b_r8_ptr_, fc_param_->col_, fc_param_->deep_); + + c_r8x8_ptr_ = reinterpret_cast(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float))); + if (c_r8x8_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)); + return RET_OK; +} + +int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto fc = reinterpret_cast(cdata); + auto error_code = fc->DoMatmul(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "FcFp32MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int FullconnectionCPUKernel::DoMatmul(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_, + c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_, + bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_, + cur_oc * 8); + return RET_OK; +} + +int FullconnectionCPUKernel::Run() { + auto a_ptr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + + RowMajor2Col8Major(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + + LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_); + + Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_); + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h new file mode 100644 index 0000000000..be4c1f72b5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_ + +#include +#include "include/errorcode.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" +#include "src/runtime/kernel/arm/base/fullconnection_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel { + public: + FullconnectionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~FullconnectionCPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + + public: + int DoMatmul(int task_id); + + private: + float *a_c8_ptr_; + float *b_r8_ptr_; + float *c_r8x8_ptr_; + float *bias_ptr_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc new file mode 100644 index 0000000000..4452655bbc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/fused_batchnorm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_FusedBatchNorm; + +namespace mindspore::kernel { +int FusedBatchnormCPUKernel::Init() { + input_shape_ = reinterpret_cast(malloc(sizeof(int) * inputs_[0]->shape().size())); + memcpy(input_shape_, inputs_[0]->shape().data(), inputs_[0]->shape().size() * sizeof(int)); + return RET_OK; +} + +int FusedBatchnormCPUKernel::ReSize() { return RET_OK; } + +int FusedBatchnormCPUKernel::Run() { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto scale_addr = reinterpret_cast(inputs_.at(1)->Data()); + auto offest_addr = reinterpret_cast(inputs_.at(2)->Data()); + auto mean_addr = reinterpret_cast(inputs_.at(3)->Data()); + auto variance_addr = reinterpret_cast(inputs_.at(4)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + + FusedBatchNorm(input_addr, scale_addr, offest_addr, mean_addr, variance_addr, input_shape_, + fused_batchnorm_param_->epsilon_, output_addr); + return RET_OK; +} + +kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_FusedBatchNorm); + auto *kernel = new (std::nothrow) FusedBatchnormCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new FusedBatchnormCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h new file mode 100644 index 0000000000..bce0660be9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fused_batchnorm.h" + +namespace mindspore::kernel { +class FusedBatchnormCPUKernel : public LiteKernel { + public: + FusedBatchnormCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + fused_batchnorm_param_ = reinterpret_cast(parameter); + } + ~FusedBatchnormCPUKernel() override { delete fused_batchnorm_param_; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + int *input_shape_{}; + FusedBatchNormParameter *fused_batchnorm_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc new file mode 100644 index 0000000000..fd073a9f00 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32/gather.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Gather; + +namespace mindspore::kernel { + +int GatherCPUKernel::Init() { + axis_ = (reinterpret_cast(opParameter))->axis_; + batchDims_ = (reinterpret_cast(opParameter))->batchDims_; + return RET_OK; +} + +int GatherCPUKernel::ReSize() { return RET_OK; } + +int GatherCPUKernel::DoGather(int task_id) { + auto input_tensor = inputs_.at(0); + auto indices_tensor = inputs_.at(1); + auto out_tensor = outputs_.at(0); + + auto input_ptr = reinterpret_cast(input_tensor->Data()); + auto indices_ptr = reinterpret_cast(indices_tensor->Data()); + auto output_ptr = reinterpret_cast(out_tensor->Data()); + + auto in_shape = input_tensor->shape(); + int in_rank = in_shape.size(); + int indices_element_size = indices_tensor->ElementsNum(); + + const int limit = in_shape[axis_]; + for (size_t i = 0; i < indices_element_size; ++i) { + if (indices_ptr[i] >= limit) { + MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]"; + return RET_ERROR; + } + } + + int outer_size = 1; + for (int i = 0; i < axis_; ++i) { + outer_size *= in_shape[i]; + } + + int inner_size = 1; + for (int i = axis_ + 1; i < in_rank; ++i) { + inner_size *= in_shape[i]; + } + + int stride = UP_DIV(outer_size, thread_count_); + int count = MSMIN(stride, outer_size - stride * task_id); + + input_ptr += stride * task_id * limit; + output_ptr += stride * task_id * indices_element_size; + + auto error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr); + if (error_code != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto gather_kernel = reinterpret_cast(cdata); + auto error_code = gather_kernel->DoGather(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int GatherCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(GatherRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Gather); + + auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, CpuGatherFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h new file mode 100644 index 0000000000..3204144c52 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/fp32/gather.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class GatherCPUKernel : public LiteKernel { + public: + GatherCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {} + ~GatherCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoGather(int task_id); + + private: + int thread_count_; + int batchDims_; + int axis_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc new file mode 100644 index 0000000000..109ce338f8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/gatherNd.h" +#include +#include +#include "schema/model_generated.h" +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_GatherNd; + +namespace mindspore::kernel { + +GatherNdCPUKernel::~GatherNdCPUKernel() { + if (in_offset_ != nullptr) { + free(in_offset_); + in_offset_ = nullptr; + } +} + +int GatherNdCPUKernel::Init() { + auto indices_tensor = inputs_.at(1); + auto indices_shape = indices_tensor->shape(); + int indices_rank = indices_shape.size(); + count_ = 1; + for (int i = 0; i < indices_rank - 1; ++i) { + count_ *= indices_shape[i]; + } + + in_offset_ = reinterpret_cast(malloc(count_ * sizeof(int))); + if (in_offset_ == nullptr) { + MS_LOG(ERROR) << "GatherNd Malloc in_offset_ error!"; + return RET_ERROR; + } + (void)memset(in_offset_, 0, count_ * sizeof(int)); + + thread_sz_count_ = MSMIN(thread_count_, count_); + thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); + int ret = ReSize(); + return ret; +} + +int GatherNdCPUKernel::ReSize() { + auto in_shape = inputs_.front()->shape(); + int in_rank = in_shape.size(); + auto indices_tensor = inputs_.at(1); + auto indices_shape = indices_tensor->shape(); + int indices_rank = indices_shape.size(); + int idx_lastshape = indices_shape[indices_rank - 1]; + auto indices_ptr = reinterpret_cast(indices_tensor->Data()); + area_ = 1; + for (int i = idx_lastshape; i < in_rank; ++i) { + area_ *= in_shape[i]; + } + std::vector in_stride(in_rank); + in_stride[in_rank - 1] = 1; + for (int i = in_rank - 2; i >= 0; --i) { + in_stride[i] = in_shape[i + 1] * in_stride[i + 1]; + } + + int idx_stride = idx_lastshape; + for (int j = 0; j < count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + + return RET_OK; +} + +int GatherNdCPUKernel::DoGatherNd(int task_id) { + int count = MSMIN(thread_sz_stride_, count_ - task_id * thread_sz_stride_); + if (count <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + auto ret = GatherNd(in_ptr_, out_ptr_ + offset * area_, in_offset_ + offset, area_, count); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int GatherNdRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoGatherNd(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int GatherNdCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_.front()->Data()); + out_ptr_ = reinterpret_cast(outputs_.front()->Data()); + int ret = LiteBackendParallelLaunch(GatherNdRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_GatherNd); + + auto *kernel = new (std::nothrow) GatherNdCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h new file mode 100644 index 0000000000..6c8a6064a6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/fp32/gatherNd.h" +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class GatherNdCPUKernel : public LiteKernel { + public: + GatherNdCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + ~GatherNdCPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + int DoGatherNd(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int count_; + int area_; + int *in_offset_ = nullptr; + float *in_ptr_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc new file mode 100644 index 0000000000..bb1286eebd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/local_response_norm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_LocalResponseNormalization; + +namespace mindspore::kernel { + +int LocalResponseNormCPUKernel::Init() { return RET_OK; } + +int LocalResponseNormCPUKernel::ReSize() { return RET_OK; } + +int LocalResponseNormCPUKernel::DoLocalResponseNorm(int task_id) { + auto input_tensor = inputs_.front(); + auto out_tensor = outputs_.front(); + auto input_ptr = reinterpret_cast(input_tensor->Data()); + auto output_ptr = reinterpret_cast(out_tensor->Data()); + + auto in_shape = input_tensor->shape(); + MS_ASSERT(in_shape.size() == 4); + + int batch = in_shape[0]; + int height = in_shape[1]; + int width = in_shape[2]; + int channel = in_shape[3]; + + int outer_size = batch * width * height; + int stride = UP_DIV(outer_size, thread_count_); + int count = MSMIN(stride, outer_size - stride * task_id); + + input_ptr += stride * task_id * channel; + output_ptr += stride * task_id * channel; + + auto error_code = LocalResponseNorm(input_ptr, count, channel, output_ptr, + reinterpret_cast(opParameter)); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DoLocalResponseNorm error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int LocalResponseNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto lrn = reinterpret_cast(cdata); + auto error_code = lrn->DoLocalResponseNorm(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "LocalResponseNormRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int LocalResponseNormCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(LocalResponseNormRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "LocalResponseNorm function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuLocalResponseNormFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_LocalResponseNormalization); + + auto *kernel = new (std::nothrow) LocalResponseNormCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new LocalResponseNormCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LocalResponseNormalization, CpuLocalResponseNormFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h new file mode 100644 index 0000000000..ea65a3e923 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/fp32/local_response_norm.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class LocalResponseNormCPUKernel : public LiteKernel { + public: + LocalResponseNormCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {} + ~LocalResponseNormCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoLocalResponseNorm(int task_id); + + private: + int thread_count_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.cc new file mode 100644 index 0000000000..cc779de6e5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/lstm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Lstm; + +namespace mindspore::kernel { +int LstmCPUKernel::InitParam() { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + std::vector in_shape = input->shape(); + lstm_parm_->seq_len_ = in_shape[0]; + lstm_parm_->batch_ = in_shape[1]; + lstm_parm_->input_size_ = in_shape[2]; + + auto weight_i = inputs_[1]; + MS_ASSERT(weight_i != nullptr); + std::vector w_shape = weight_i->shape(); + lstm_parm_->hidden_size_ = w_shape[1] / 4; + + lstm_parm_->input_step_ = lstm_parm_->batch_ * lstm_parm_->input_size_; + lstm_parm_->output_step_ = lstm_parm_->bidirectional_ ? 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ + : lstm_parm_->batch_ * lstm_parm_->hidden_size_; + return RET_OK; +} + +int LstmCPUKernel::InitBuffer() { + gate_buffer_ = reinterpret_cast(malloc(4 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float))); + if (gate_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmCPUKernel::InitWeightBias() { + // copy weight_i and weight_h + auto weight_i = inputs_.at(1); + MS_ASSERT(weight_i != nullptr); + weight_i_ptr_ = reinterpret_cast(malloc(weight_i->ElementsNum() * sizeof(float))); + if (weight_i_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; + return RET_ERROR; + } + memcpy(weight_i_ptr_, weight_i->Data(), weight_i->ElementsNum() * sizeof(float)); + + auto weight_h = inputs_.at(2); + MS_ASSERT(weight_h != nullptr); + weight_h_ptr_ = reinterpret_cast(malloc(weight_h->ElementsNum() * sizeof(float))); + if (weight_h_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error."; + return RET_ERROR; + } + memcpy(weight_h_ptr_, weight_h->Data(), weight_h->ElementsNum() * sizeof(float)); + + // init bias + int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * lstm_parm_->hidden_size_ : 4 * lstm_parm_->hidden_size_; + bias_ptr_ = reinterpret_cast(malloc(bias_num * sizeof(float))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; + return RET_ERROR; + } + + auto bias_data = reinterpret_cast(inputs_.at(3)->Data()); + int state_bias_offset = 4 * lstm_parm_->hidden_size_; + for (int i = 0; i < state_bias_offset; i++) { + bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset]; + } + if (lstm_parm_->bidirectional_) { + bias_data += 4 * lstm_parm_->hidden_size_ * 2; + auto backward_bias = bias_ptr_ + 4 * lstm_parm_->hidden_size_; + for (int i = 0; i < state_bias_offset; i++) { + backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset]; + } + } + return RET_OK; +} + +int LstmCPUKernel::Init() { + auto ret = InitParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitParam error."; + return RET_ERROR; + } + + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmCPUKernel::ReSize() { + free(gate_buffer_); + + auto ret = InitParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitParam error."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmCPUKernel::Run() { + auto input = inputs_.at(kInputIndex); + MS_ASSERT(input != nullptr); + auto hidden_state = inputs_.at(4); + MS_ASSERT(hidden_state != nullptr); + auto cell_state = inputs_.at(5); + MS_ASSERT(cell_state != nullptr); + auto output = outputs_.at(0); + MS_ASSERT(output != nullptr); + + auto input_ptr = reinterpret_cast(input->Data()); + auto output_ptr = reinterpret_cast(output->Data()); + + auto output_hidden_state = outputs_[1]; + memcpy(output_hidden_state->Data(), hidden_state->Data(), hidden_state->ElementsNum() * sizeof(float)); + auto output_cell_state = outputs_[2]; + memcpy(output_cell_state->Data(), cell_state->Data(), cell_state->ElementsNum() * sizeof(float)); + + Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, + reinterpret_cast(output_hidden_state->Data()), reinterpret_cast(output_cell_state->Data()), + gate_buffer_, lstm_parm_); + return RET_OK; +} + +kernel::LiteKernel *CpuLstmKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Lstm); + + auto *kernel = new (std::nothrow) LstmCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Lstm, CpuLstmKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h new file mode 100644 index 0000000000..61488ca2c4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/lstm.h" + +namespace mindspore::kernel { +class LstmCPUKernel : public LiteKernel { + public: + LstmCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + lstm_parm_ = reinterpret_cast(opParameter); + } + + ~LstmCPUKernel() override { + free(gate_buffer_); + free(weight_i_ptr_); + free(weight_h_ptr_); + free(bias_ptr_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + int InitParam(); + int InitBuffer(); + int InitWeightBias(); + + private: + float *gate_buffer_; + float *weight_i_ptr_; + float *weight_h_ptr_; + float *bias_ptr_; + LstmParameter *lstm_parm_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc new file mode 100644 index 0000000000..ca20fb76b6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/matmul.h" +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +MatmulCPUKernel::~MatmulCPUKernel() { + ctx_->allocator->Free(a_c8_ptr_); + ctx_->allocator->Free(b_r8_ptr_); + ctx_->allocator->Free(c_r8x8_ptr_); +} + +int MatmulCPUKernel::ReSize() { return RET_OK; } + +int MatmulCPUKernel::Init() { + int batch = 1; + auto x_shape = inputs_[0]->shape(); + auto o_shape = outputs_[0]->shape(); + for (int i = 0; i < x_shape.size() - 2; ++i) { + batch *= x_shape[i]; + } + params_->batch = batch; + params_->row_ = o_shape[o_shape.size() - 2]; + params_->col_ = o_shape[o_shape.size() - 1]; + params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1]; + params_->row_8_ = UP_ROUND(params_->row_, 8); + params_->col_8_ = UP_ROUND(params_->col_, 8); + thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); + thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); + + a_c8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(float))); + if (!a_c8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(float)); + b_r8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(float))); + if (!b_r8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(float)); + c_r8x8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(float))); + if (!c_r8x8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float)); + return RET_OK; +} + +int MatmulCPUKernel::RunImpl(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; + auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; + MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8); + return RET_OK; +} + +int MatmulFloatRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto op = reinterpret_cast(cdata); + auto error_code = op->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "MatmulFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int MatmulCPUKernel::Run() { + auto a_ptr = reinterpret_cast(inputs_[0]->Data()); + auto b_ptr = reinterpret_cast(inputs_[1]->Data()); + auto c_ptr = reinterpret_cast(outputs_[0]->Data()); + auto a_stride = params_->row_ * params_->deep_; + auto b_stride = params_->deep_ * params_->col_; + auto c_stride = params_->row_ * params_->col_; + for (int i = 0; i < params_->batch; ++i) { + auto cur_a_ptr = a_ptr + i * a_stride; + auto cur_b_ptr = b_ptr + i * b_stride; + auto cur_c_ptr = c_ptr + i * c_stride; + if (params_->a_transpose_) { + RowMajor2Row8Major(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_); + } else { + RowMajor2Col8Major(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_); + } + if (params_->b_transpose_) { + RowMajor2Col8Major(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_); + } else { + RowMajor2Row8Major(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_); + } + LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_); + Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h new file mode 100644 index 0000000000..0f617d0179 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" +#include "src/runtime/kernel/arm/base/matmul_base.h" + +namespace mindspore::kernel { +class MatmulCPUKernel : public MatmulBaseCPUKernel { + public: + explicit MatmulCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : MatmulBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~MatmulCPUKernel() override; + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: + float *a_c8_ptr_; + float *b_r8_ptr_; + float *c_r8x8_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc new file mode 100644 index 0000000000..9ee4f8a577 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Nchw2Nhwc; + +namespace mindspore::kernel { +int Nchw2NhwcCPUKernel::Init() { return RET_OK; } + +int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; } + +int Nchw2NhwcCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + + PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + return RET_OK; +} + +kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc); + auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h new file mode 100644 index 0000000000..df45cdd2d6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ + +#include +#include "src/lite_kernel.h" + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" + +namespace mindspore::kernel { +class Nchw2NhwcCPUKernel : public LiteKernel { + public: + Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~Nchw2NhwcCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc new file mode 100644 index 0000000000..480d58eaa2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/nhwc2nchw.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Nhwc2Nchw; + +namespace mindspore::kernel { +int Nhwc2NchwCPUKernel::Init() { return RET_OK; } + +int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; } + +int Nhwc2NchwCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + + PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + return RET_OK; +} + +kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw); + auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h new file mode 100644 index 0000000000..8165283ecb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ + +#include +#include "src/lite_kernel.h" + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" + +namespace mindspore::kernel { +class Nhwc2NchwCPUKernel : public LiteKernel { + public: + Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~Nhwc2NchwCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc new file mode 100644 index 0000000000..afff5d42de --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/one_hot.h" +#include "src/runtime/kernel/arm/nnacl/fp32/one_hot.h" +#include "schema/model_generated.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_OneHot; + +namespace mindspore::kernel { +namespace { +constexpr size_t kInputNum = 4; +constexpr size_t kOutputNum = 1; +} // namespace + +int OneHotCPUKernel::Init() { + // indices depth on_value off_value + if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << inputs_.size() + << ", output size should be" << kOutputNum << ", got " << outputs_.size(); + return RET_ERROR; + } + + auto indices = inputs_.at(0); + if (indices == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr"; + return RET_NULL_PTR; + } + auto indices_shape = indices->shape(); + outer_size_ = 1; + for (size_t i = 0; i < static_cast(axis_); i++) { + outer_size_ *= indices_shape[i]; + } + inner_size_ = indices->ElementsNum() / outer_size_; + + if (context_ == nullptr) { + MS_LOG(ERROR) << "OneHot context nullptr"; + return RET_NULL_PTR; + } + thread_num_ = context_->thread_num_; + + const int indices_rank = static_cast(inputs_.at(0)->shape().size()); + if (axis_ < 0) { + axis_ += indices_rank + 1; + } + + return RET_OK; +} + +int RunOneHot(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto onehot_kernel = reinterpret_cast(cdata); + if (onehot_kernel == nullptr) { + MS_LOG(ERROR) << "cast OneHotCPUKernel failed"; + return RET_ERROR; + } + auto error_code = onehot_kernel->OneHotImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "RunOneHot error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int OneHotCPUKernel::OneHotImpl(int task_id) { + auto indices_data = static_cast(inputs_.at(0)->Data()); + auto output = outputs_.at(0); + if (output == nullptr) { + MS_LOG(ERROR) << "OneHot output nullptr"; + return RET_NULL_PTR; + } + auto output_data = static_cast(output->Data()); + + auto ret = GetParams(); + if (ret != RET_OK) { + return ret; + } + auto one_hot_param = reinterpret_cast(opParameter); + + ret = OneHot(indices_data, output_data, one_hot_param, task_id, thread_num_); + return ret; +} + +int OneHotCPUKernel::GetParams() { + auto one_hot_param = reinterpret_cast(opParameter); + if (one_hot_param == nullptr) { + MS_LOG(ERROR) << "cast OneHotParameter nullptr"; + return RET_NULL_PTR; + } + + auto depth_tensor = inputs_.at(1); + if (depth_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[1] depth nullptr"; + return RET_NULL_PTR; + } + const int *depth = static_cast(depth_tensor->Data()); + if (depth == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->depth_ = *depth; + + auto on_value_tensor = inputs_.at(2); + if (on_value_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; + return RET_NULL_PTR; + } + const float *on_value = static_cast(on_value_tensor->Data()); + if (on_value == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->on_value_ = *on_value; + + auto off_value_tensor = inputs_.at(3); + if (off_value_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr"; + return RET_NULL_PTR; + } + const float *off_value = static_cast(off_value_tensor->Data()); + if (off_value == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->off_value_ = *off_value; + + one_hot_param->outer_size_ = outer_size_; + one_hot_param->inner_size_ = inner_size_; + + return RET_OK; +} + +int OneHotCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(RunOneHot, this, context_->thread_num_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "OneHot function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuOneHotFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter != nullptr) { + MS_LOG(ERROR) << "OneHot opParameter nullptr."; + return nullptr; + } + if (desc.type != schema::PrimitiveType_OneHot) { + MS_LOG(ERROR) << "OneHot desc type should be " << schema::PrimitiveType_OneHot << " got " << desc.type; + return nullptr; + } + auto *kernel = new (std::nothrow) OneHotCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "OneHot new kernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OneHot, CpuOneHotFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h new file mode 100644 index 0000000000..dd823f4460 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ONE_HOT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ONE_HOT_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class OneHotCPUKernel : public LiteKernel { + public: + OneHotCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx) {} + + ~OneHotCPUKernel() override = default; + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int OneHotImpl(int task_id); + + private: + int GetParams(); + + private: + const lite::Context *context_; + int thread_num_; + int axis_; + int outer_size_; + int inner_size_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ONE_HOT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc new file mode 100644 index 0000000000..84c51509ba --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc @@ -0,0 +1,78 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/opt_momentum.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_OptMomentum; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int OptMomentumCPUKernel::ReSize() { return 0; } + +int OptMomentumCPUKernel::Run() { + if (inputs_.size() != 5 || !outputs_.empty()) { + MS_LOG(ERROR) << "OptMomentumCPUKernel error input output size!"; + return RET_ERROR; + } + + if (inputs_[0]->ElementsNum() != inputs_[1]->ElementsNum() || + inputs_[0]->ElementsNum() != inputs_[3]->ElementsNum()) { + MS_LOG(ERROR) << "error input data size!"; + return RET_ERROR; + } + auto weight = reinterpret_cast(inputs_[0]->Data()); + auto accumulate = reinterpret_cast(inputs_[1]->Data()); + float learning_rate = reinterpret_cast(inputs_[2]->Data())[0]; + auto gradient = reinterpret_cast(inputs_[3]->Data()); + float moment = reinterpret_cast(inputs_[4]->Data())[0]; + size_t elem_num = inputs_[0]->ElementsNum(); + for (size_t i = 0; i < elem_num; ++i) { + accumulate[i] = accumulate[i] * moment + gradient[i]; + weight[i] -= accumulate[i] * learning_rate; + } + return RET_OK; +} + +int OptMomentumCPUKernel::Init() { return 0; } + +kernel::LiteKernel *CpuOptMomentumFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_OptMomentum); + auto *kernel = new (std::nothrow) OptMomentumCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OptMomentum, CpuOptMomentumFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h new file mode 100644 index 0000000000..ccc2871779 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class OptMomentumCPUKernel : public LiteKernel { + public: + explicit OptMomentumCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~OptMomentumCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc new file mode 100644 index 0000000000..511c7f9d87 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/pad.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pad; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +} // namespace + +int PadCPUKernel::Init() { + if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << inputs_.size() << ", output size should be" + << kOutputNum << ", got " << outputs_.size(); + return RET_ERROR; + } + + auto input = inputs_.at(0); + auto output = outputs_.at(0); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "Pad input or output nullptr"; + return RET_NULL_PTR; + } + + auto rank = input->shape().size(); + if (rank > DEFAULT_PAD_NDIMS) { + MS_LOG(ERROR) << "Pad input rank should <= " << DEFAULT_PAD_NDIMS << ", got " << rank; + return RET_ERROR; + } + + for (int i = 0; i < rank; i++) { + in_[DEFAULT_PAD_NDIMS - rank + i] = input->shape()[i]; + out_[DEFAULT_PAD_NDIMS - rank + i] = output->shape()[i]; + } + return RET_OK; +} + +int PadImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto padKernel = reinterpret_cast(cdata); + int error_code = padKernel->RunImpl(task_id); + if (error_code != NNACL_OK) { + MS_LOG(ERROR) << "Pad Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PadCPUKernel::RunImpl(int task_id) { + auto input = inputs_.at(0); + auto output = outputs_.at(0); + + auto input_data = reinterpret_cast(input->Data()); + auto output_data = reinterpret_cast(output->Data()); + + Pad(input_data, output_data, in_, out_, pad_param_->paddings_, task_id, context_->thread_num_); + + return RET_OK; +} + +int PadCPUKernel::Run() { + auto output = outputs_.at(0); + int output_size = output->DataSize(); + + auto output_data = reinterpret_cast(output->Data()); + // todo parallel memset to save time + memset(output_data, 0, output_size * sizeof(float)); + + int error_code = LiteBackendParallelLaunch(PadImpl, this, context_->thread_num_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Pad run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h new file mode 100644 index 0000000000..c48ddf581c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/nnacl/fp32/pad.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +namespace mindspore::kernel { +class PadCPUKernel : public LiteKernel { + public: + PadCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx) { + pad_param_ = reinterpret_cast(parameter); + } + + ~PadCPUKernel() {} + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int RunImpl(int task_id); + + private: + const lite::Context *context_; + const PadParameter *pad_param_; + int in_[4] = {1, 1, 1, 1}; + int out_[4] = {1, 1, 1, 1}; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc new file mode 100644 index 0000000000..960734f994 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/pooling.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pooling; + +namespace mindspore::kernel { +int PoolingCPUKernel::Init() { + auto ret = PoolingBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PoolingBase Init failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingCPUKernel::ReSize() { + auto ret = Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Pooling resize init failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingCPUKernel::RunImpl(int task_id) { + auto input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (pooling_param_->max_pooling_) { + MaxPooling(input_ptr, output_ptr, pooling_param_, task_id); + } else { + AvgPooling(input_ptr, output_ptr, pooling_param_, task_id); + } + return RET_OK; +} + +int PoolingImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto pooling = reinterpret_cast(cdata); + auto error_code = pooling->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(PoolingImpl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h new file mode 100644 index 0000000000..7edd82a537 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_ + +#include +#include "src/runtime/kernel/arm/base/pooling_base.h" +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "include/context.h" + +namespace mindspore::kernel { +using mindspore::lite::Context; +using mindspore::schema::PadMode; +using mindspore::schema::PoolMode; +using mindspore::schema::QuantType; +using mindspore::schema::RoundMode; + +class PoolingCPUKernel : public PoolingBaseCPUKernel { + public: + PoolingCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~PoolingCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc new file mode 100644 index 0000000000..23606e6fa7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/pooling_grad.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_PoolingGrad; + +namespace mindspore::kernel { +#if 0 +int PoolingGradCPUKernel::TfPadding(int input_w, int input_h, int &output_w, int &output_h) { + PoolingParameter *pool_param = reinterpret_cast (opParameter); + + auto stride_w = pool_param->stride_w_; + auto stride_h = pool_param->stride_h_; + auto window_w = pool_param->window_w_; + auto window_h = pool_param->window_h_; + auto pad_up = pool_param->pad_u_; + auto pad_down = pool_param->pad_d_; + auto pad_left = pool_param->pad_l_; + auto pad_right = pool_param->pad_r_; + if (pool_param->pad_mode_ == PADMODE_SAME) { + output_w = ceil(input_w / stride_w); + output_h = ceil(input_h / stride_h); + } else { + output_w = ceil((input_w + pad_left + pad_right - window_w + 1) / stride_w); + output_h = ceil((input_h + pad_up + pad_down - window_h + 1) / stride_h); + } + return RET_OK; +} + +int PoolingGradCPUKernel::CaffePadding(int input_w, int input_h, int &output_w, int &output_h) { + PoolingParameter *pool_param = reinterpret_cast (opParameter); + + auto round_mode = pool_param->round_mode_; + auto stride_w = pool_param->stride_w_; + auto stride_h = pool_param->stride_h_; + auto window_w = pool_param->window_w_; + auto window_h = pool_param->window_h_; + auto pad_up = pool_param->pad_u_; + auto pad_down = pool_param->pad_d_; + auto pad_left = pool_param->pad_l_; + auto pad_right = pool_param->pad_r_; + if (round_mode == ROUNDMODE_FLOOR && false) { + output_w = floor((input_w + pad_left + pad_right - window_w) / stride_w + 1); + output_h = floor((input_h + pad_up + pad_down - window_h) / stride_h + 1); + } else if (round_mode == ROUNDMODE_CEIL || true) { + output_w = ceil((input_w + pad_left + pad_right - window_w) / stride_w + 1); + output_h = ceil((input_h + pad_up + pad_down - window_h) / stride_h + 1); + } else { + MS_LOG(ERROR) << "round mode not support."; + } + + if (pad_left > 0 || pad_up > 0) { + if ((output_w - 1) * stride_w >= input_w + pad_left) { + --output_w; + } + if ((output_h - 1) * stride_h >= input_h + pad_up) { + --output_h; + } + } + return RET_OK; +} + +int PoolingGradCPUKernel::OnnxPadding(int input_w, int input_h, int &output_w, int &output_h) { + PoolingParameter *pool_param = reinterpret_cast (opParameter); + + auto round_mode = pool_param->round_mode_; + auto stride_w = pool_param->stride_w_; + auto stride_h = pool_param->stride_h_; + auto window_w = pool_param->window_w_; + auto window_h = pool_param->window_h_; + auto pad_up = pool_param->pad_u_; + auto pad_down = pool_param->pad_d_; + auto pad_left = pool_param->pad_l_; + auto pad_right = pool_param->pad_r_; + if (round_mode == ROUNDMODE_FLOOR) { + output_w = floor((input_w + pad_left + pad_right - window_w) / stride_w + 1); + output_h = floor((input_h + pad_up + pad_down - window_h) / stride_h + 1); + } else if (round_mode == ROUNDMODE_CEIL) { + MS_LOG(ERROR) << "RoundMode_CEIL mode not support."; + } else { + MS_LOG(ERROR) << "OnnxPadding round mode not support."; + } + return RET_OK; +} +#endif + +int PoolingGradCPUKernel::Init() { + // InferShape(): + // auto *in_tensor = reinterpret_cast(inputs_.at(0)->Data()); + // auto *x_tensor = reinterpret_cast(inputs_.at(1)->Data()); + + PoolingParameter *pool_param = reinterpret_cast(opParameter); + + auto in_shape = inputs_.at(0)->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + + if (pool_param->global_) { + pool_param->window_w_ = input_w; + pool_param->window_h_ = input_h; + } + + // Emir -- here I assume we get the outputshape in the output tensor + auto *out_tensor = outputs_.front(); + auto out_shape = out_tensor->shape(); + +#if 0 + int output_w = 0, output_h = 0; + auto fmk_type = pool_param->fmk_type_; + switch (fmk_type) { + case lite::FmkType_TF: + break; + case lite::FmkType_CAFFE: + CaffePadding(input_w, input_h, output_w, output_h); + break; + case lite::FmkType_ONNX: + OnnxPadding(input_w, input_h, output_w, output_h); + break; + case lite::FmkType_MS: + break; + case lite::FmkType_TFLITE: + TfPadding(input_w, input_h, output_w, output_h); + break; + default: + MS_LOG(ERROR) << "Not support this framework."; + } + std::vector out_shape{in_tensor->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; +#endif + out_tensor->set_shape(out_shape); + out_tensor->set_data_type(inputs_.at(0)->data_type()); + return RET_OK; +} + +int PoolingGradCPUKernel::ReSize() { return RET_OK; } + +int PoolingGradCPUKernel::Run() { + PoolingParameter *pool_param = reinterpret_cast(opParameter); + auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + + if (pool_param->max_pooling_) { + auto ind = reinterpret_cast(inputs_.at(1)->Data()); + MaxPoolingGrad(input_ptr, ind, output_ptr, pool_param); + } else { + AvgPoolingGrad(input_ptr, output_ptr, pool_param); + } + return RET_OK; +} + +kernel::LiteKernel *CpuPoolingGradFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_PoolingGrad); + + auto *kernel = new (std::nothrow) PoolingGradCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PoolingGrad, CpuPoolingGradFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h new file mode 100644 index 0000000000..eec333d860 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +using mindspore::schema::PadMode; +using mindspore::schema::PoolMode; +using mindspore::schema::QuantType; +using mindspore::schema::RoundMode; + +class PoolingGradCPUKernel : public LiteKernel { + public: + explicit PoolingGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~PoolingGradCPUKernel() override = default; + + // int TfPadding(int input_w, int input_h, int &output_w, int &output_h); + // int CaffePadding(int input_w, int input_h, int &output_w, int &output_h); + // int OnnxPadding(int input_w, int input_h, int &output_w, int &output_h); + + int Init() override; + int ReSize() override; + int Run() override; + + private: + uint8_t data_shape_{0}; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc new file mode 100644 index 0000000000..26e9a0c5ea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/power.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Power; + +namespace mindspore::kernel { +int PowerCPUKernel::Init() { return RET_OK; } + +int PowerCPUKernel::ReSize() { return RET_OK; } + +int PowerImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->RunImpl(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PowerImpl error: " << ret; + return ret; + } + return RET_OK; +} + +int PowerCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(PowerImpl, this, thread_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PowerCPUKernel error: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +int PowerCPUKernel::RunImpl(int task_id) { + auto x_addr = reinterpret_cast(inputs_[0]->Data()); + auto exp_addr = reinterpret_cast(inputs_[1]->Data()); + auto output_addr = reinterpret_cast(outputs_[0]->Data()); + auto size = inputs_[0]->ElementsNum(); + int stride = UP_DIV(size, thread_count_); + int len = MSMIN(stride, size - stride * task_id); + bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false; + float *cur_exp; + if (broadcast) { + cur_exp = exp_addr; + } else { + cur_exp = exp_addr + stride * task_id; + } + Power(x_addr + stride * task_id, cur_exp, output_addr + stride * task_id, len, scale_, shift_, broadcast); + return RET_OK; +} + +kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Power); + auto *kernel = + new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PowerCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Power, CpuPowerFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.h b/mindspore/lite/src/runtime/kernel/arm/fp32/power.h new file mode 100644 index 0000000000..b2ee1f6b62 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_ + +#include +#include "include/context.h" +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/power.h" + +namespace mindspore::kernel { +class PowerCPUKernel : public LiteKernel { + public: + PowerCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(param, inputs, outputs), + ctx_(ctx), + thread_count_(ctx->thread_num_), + scale_(reinterpret_cast(opParameter)->scale_), + shift_(reinterpret_cast(opParameter)->shift_) {} + ~PowerCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: + const lite::Context *ctx_; + int thread_count_; + float scale_; + float shift_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc new file mode 100644 index 0000000000..f99bd3e743 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/power_grad.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_PowerGrad; + +namespace mindspore::kernel { +int PowerGradCPUKernel::Init() { return RET_OK; } + +int PowerGradCPUKernel::ReSize() { return RET_OK; } + +int PowerGradCPUKernel::Run() { + auto dy_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto x_addr = reinterpret_cast(inputs_.at(1)->Data()); + auto dx_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto size = inputs_.at(0)->ElementsNum(); + + float exp = power_ - 1; + Power(x_addr, &exp, dx_addr, size, scale_, shift_, true); + ElementMul(dx_addr, dy_addr, dx_addr, size); + float scale = scale_ * power_; + for (int i = 0; i < size; i++) { + dx_addr[i] *= scale; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuPowerGradFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_PowerGrad); + auto *kernel = new (std::nothrow) PowerGradCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PowerGrad, CpuPowerGradFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h new file mode 100644 index 0000000000..00b2e882f7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "src/runtime/kernel/arm/nnacl/power.h" + +namespace mindspore::kernel { +class PowerGradCPUKernel : public LiteKernel { + public: + PowerGradCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(param, inputs, outputs) { + PowerParameter *power_param = reinterpret_cast(param); + power_ = power_param->power_; + scale_ = power_param->scale_; + shift_ = power_param->shift_; + } + ~PowerGradCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float power_; + float scale_; + float shift_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc new file mode 100644 index 0000000000..952392761d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/prelu.h" +#include +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/nnacl/prelu.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Prelu; + +namespace mindspore::kernel { +int PReluCPUKernel::Init() { + prelu_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +int PReluCPUKernel::DoExcute(int task_id) { + PRelu(input_data, output_data, prelu_param_, task_id); + return RET_OK; +} + +int PReluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto PReludata = reinterpret_cast(cdata); + auto ret = PReludata->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PReluRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PReluCPUKernel::Run() { + auto input = inputs_.at(0); + prelu_param_->input_num_ = input->ElementsNum(); + input_data = reinterpret_cast(input->Data()); + output_data = reinterpret_cast(outputs_.at(0)->Data()); + + auto ret = LiteBackendParallelLaunch(PReluRun, this, prelu_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PReluDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Prelu); + auto *kernel = new (std::nothrow) PReluCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PReluCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Prelu, CpuPReluFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h new file mode 100644 index 0000000000..63da37c8f5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/prelu.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class PReluCPUKernel : public LiteKernel { + public: + PReluCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + prelu_param_ = (reinterpret_cast(opParameter)); + } + ~PReluCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + PReluParameter *prelu_param_; + + private: + float *input_data; + float *output_data; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc new file mode 100644 index 0000000000..b965731a35 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/range.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Range; + +namespace mindspore::kernel { + +namespace { +constexpr int kInputNum = 0; +constexpr int kOutputNum = 1; +} // namespace + +int RangeCPUKernel::Init() { return RET_OK; } + +int RangeCPUKernel::ReSize() { return RET_OK; } + +int RangeCPUKernel::Run() { + size_t start = (reinterpret_cast(opParameter))->start_; + size_t limit = (reinterpret_cast(opParameter))->limit_; + size_t delta = (reinterpret_cast(opParameter))->delta_; + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + Range(output_ptr, start, limit, delta); + return RET_OK; +} + +kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Range); + + auto *kernel = new (std::nothrow) RangeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new RangeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Range, CpuRangeFp32KernelCreator) + +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range.h b/mindspore/lite/src/runtime/kernel/arm/fp32/range.h new file mode 100644 index 0000000000..0306440aba --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/range.h" + +namespace mindspore::kernel { +class RangeCPUKernel : public LiteKernel { + public: + explicit RangeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~RangeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc new file mode 100644 index 0000000000..7917f0928d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/rank.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Rank; + +namespace mindspore::kernel { + +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +} // namespace + +int RankCPUKernel::Init() { return RET_OK; } + +int RankCPUKernel::ReSize() { return RET_OK; } + +int RankCPUKernel::Run() { + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + auto in_shape = inputs_[0]->shape(); + auto rank = in_shape.size(); + Rank(output_ptr, rank); + return RET_OK; +} + +kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Rank); + + auto *kernel = new (std::nothrow) RankCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new RankCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rank, CpuRankFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h new file mode 100644 index 0000000000..13ef53dc79 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/rank.h" + +namespace mindspore::kernel { +class RankCPUKernel : public LiteKernel { + public: + explicit RankCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~RankCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc new file mode 100644 index 0000000000..43bdfe9351 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc @@ -0,0 +1,282 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/reduce.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/nnacl/fp32/reduce.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Mean; +using mindspore::schema::PrimitiveType_Reduce; +using mindspore::schema::ReduceMode; +using mindspore::schema::ReduceMode_ReduceMax; +using mindspore::schema::ReduceMode_ReduceMean; +using mindspore::schema::ReduceMode_ReduceMin; +using mindspore::schema::ReduceMode_ReduceProd; +using mindspore::schema::ReduceMode_ReduceSum; +using mindspore::schema::ReduceMode_ReduceSumSquare; + +namespace mindspore::kernel { +namespace { +constexpr size_t kInputNum = 1; +constexpr size_t kOutputNum = 1; +} // namespace + +int ReduceCPUKernel::CheckInputsOutputs() { + if (inputs_.size() != kInputNum) { + MS_LOG(ERROR) << "Reduce inputs size should be " << kInputNum << " but got " << inputs_.size(); + return RET_ERROR; + } + if (outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "Reduce outputs size should be " << kOutputNum << " but got " << outputs_.size(); + return RET_ERROR; + } + auto input = inputs_.at(0); + if (input == nullptr) { + MS_LOG(ERROR) << "Reduce input is nullptr"; + return RET_NULL_PTR; + } + auto output = outputs_.at(0); + if (output == nullptr) { + MS_LOG(ERROR) << "Reduce output is nullptr"; + return RET_NULL_PTR; + } + return RET_OK; +} + +int ReduceCPUKernel::CheckParameters() { + size_t input_rank = inputs_.at(0)->shape().size(); + if (static_cast(num_axes_) > input_rank) { + MS_LOG(ERROR) << "Reduce num of reduce axes " << num_axes_ << " larger than input rank " << input_rank; + return RET_ERROR; + } + for (auto i = 0; i < num_axes_; i++) { + if (axes_[i] < -static_cast(input_rank) || static_cast(axes_[i]) >= input_rank) { + MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in [" + << -static_cast(input_rank) << ", " << input_rank - 1 << "]."; + return RET_ERROR; + } + if (axes_[i] < 0) { + axes_[i] += static_cast(input_rank); + } + } + + if (num_axes_ == 0) { + for (int i = 0; i < input_rank; i++) { + axes_[i] = i; + } + } + + return RET_OK; +} + +int ReduceCPUKernel::Init() { + auto ret = CheckInputsOutputs(); + if (ret != RET_OK) { + return ret; + } + ret = CheckParameters(); + if (ret != RET_OK) { + return ret; + } + ret = MallocTmpBuffer(); + if (ret != RET_OK) { + return ret; + } + + switch (mode_) { + case static_cast(ReduceMode_ReduceSum): { + reducer_ = ReduceSum; + break; + } + case static_cast(ReduceMode_ReduceMean): { + reducer_ = ReduceMean; + break; + } + case static_cast(ReduceMode_ReduceMax): { + reducer_ = ReduceMax; + break; + } + case static_cast(ReduceMode_ReduceMin): { + reducer_ = ReduceMin; + break; + } + case static_cast(ReduceMode_ReduceProd): { + reducer_ = ReduceProd; + break; + } + case static_cast(ReduceMode_ReduceSumSquare): { + reducer_ = ReduceSumSquare; + break; + } + default: + MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; + return RET_ERROR; + } + return RET_OK; +} + +int ReduceCPUKernel::CallReduceUnit(int task_id) { + auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, tmp_shape_.data(), dst_data_, task_id, + context_->thread_num_); + return ret; +} + +int ReduceImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto reduce = reinterpret_cast(cdata); + auto error_code = reduce->CallReduceUnit(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Reduce Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ReduceCPUKernel::Run() { + tmp_shape_ = inputs_.at(0)->shape(); + src_data_ = static_cast(inputs_.at(0)->Data()); + for (int i = 0; i < data_buffers_.size(); ++i) { + dst_data_ = data_buffers_[i]; + int axis = axes_[i]; + outer_size_ = 1; + for (int j = 0; j < axis; j++) { + outer_size_ *= tmp_shape_[j]; + } + inner_size_ = 1; + for (int k = axis + 1; k < static_cast(tmp_shape_.size()); k++) { + inner_size_ *= tmp_shape_[k]; + } + axis_size_ = tmp_shape_[axis]; + auto error_code = LiteBackendParallelLaunch(ReduceImpl, this, context_->thread_num_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + tmp_shape_[axis] = 1; + src_data_ = dst_data_; + } + + int last_reduce_axis = axes_[num_axes_ - 1]; + outer_size_ = 1; + for (int i = 0; i < last_reduce_axis; i++) { + outer_size_ *= tmp_shape_[i]; + } + inner_size_ = 1; + for (int i = last_reduce_axis + 1; i < static_cast(tmp_shape_.size()); i++) { + inner_size_ *= tmp_shape_[i]; + } + axis_size_ = tmp_shape_[last_reduce_axis]; + dst_data_ = reinterpret_cast(outputs_.at(0)->Data()); + auto error_code = LiteBackendParallelLaunch(ReduceImpl, this, context_->thread_num_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +int ReduceCPUKernel::MallocTmpBuffer() { + auto input_shape = inputs_.at(0)->shape(); + for (auto i = 0; i < num_axes_ - 1; i++) { + int axis = axes_[i]; + size_t size = 1; + for (auto j = 0; j < input_shape.size(); j++) { + if (static_cast(axis) != j) { + size *= input_shape[j]; + } + } + float *buffer = reinterpret_cast(malloc(size * sizeof(float))); + if (buffer == nullptr) { + MS_LOG(ERROR) << "Malloc data failed."; + return RET_ERROR; + } + data_buffers_.emplace_back(buffer); + input_shape[axis] = 1; + } + return RET_OK; +} + +kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Reduce); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Reduce opParameter nullptr"; + return nullptr; + } + if (desc.type != schema::PrimitiveType_Reduce) { + MS_LOG(ERROR) << "Reduce op desc.type should be PrimitiveType_Reduce, got " << desc.type; + return nullptr; + } + auto *kernel = + new (std::nothrow) ReduceCPUKernel(reinterpret_cast(opParameter), inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuMeanFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Mean); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Reduce opParameter nullptr"; + return nullptr; + } + if (desc.type != schema::PrimitiveType_Mean) { + MS_LOG(ERROR) << "Reduce op desc.type should be PrimitiveType_Mean, got " << desc.type; + return nullptr; + } + auto *kernel = + new (std::nothrow) ReduceCPUKernel(reinterpret_cast(opParameter), inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reduce, CpuReduceFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mean, CpuMeanFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h new file mode 100644 index 0000000000..2273465c27 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REDUCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REDUCE_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/nnacl/fp32/reduce.h" +#include "ir/anf.h" +using mindspore::schema::ReduceMode; + +namespace mindspore::kernel { +class ReduceCPUKernel : public LiteKernel { + typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); + + public: + ReduceCPUKernel(ReduceParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(reinterpret_cast(param), inputs, outputs), + context_(ctx), + keep_dims_(param->keep_dims_), + num_axes_(param->num_axes_), + mode_(param->mode_) { + memcpy(axes_, param->axes_, sizeof(param->axes_)); + } + ~ReduceCPUKernel() { + for (auto i = 0; i < data_buffers_.size(); i++) { + float *buffer = data_buffers_[i]; + if (buffer != nullptr) { + free(buffer); + buffer = nullptr; + } + } + src_data_ = nullptr; + dst_data_ = nullptr; + } + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int CallReduceUnit(int task_id); + + private: + int CheckInputsOutputs(); + int CheckParameters(); + int MallocTmpBuffer(); + + private: + const lite::Context *context_ = nullptr; + bool keep_dims_; + int axes_[REDUCE_MAX_AXES_NUM]; + int num_axes_; + int mode_; + + private: + std::vector data_buffers_; + int outer_size_; + int inner_size_; + int axis_size_; + std::vector tmp_shape_; + const float *src_data_; + float *dst_data_; + Reducer reducer_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REDUCE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc new file mode 100644 index 0000000000..dc45381b84 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/reshape.h" +#include +#include "src/runtime/kernel/arm/nnacl/reshape.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reshape; + +namespace mindspore::kernel { +int ReshapeCPUKernel::Init() { + ReshapeBaseCPUKernel::Init(); + return RET_OK; +} + +int ReshapeCPUKernel::ReSize() { return RET_OK; } + +int ReshapeCPUKernel::Run() { + auto input_ptr = inputs_.at(kInputIndex)->Data(); + auto output_ptr = outputs_.at(kOutputIndex)->Data(); + size_t data_size = inputs_.at(kInputIndex)->Size(); + Reshape(input_ptr, output_ptr, data_size); + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h new file mode 100644 index 0000000000..f366e739d9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/base/reshape_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeCPUKernel : public ReshapeBaseCPUKernel { + public: + ReshapeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ReshapeCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc new file mode 100644 index 0000000000..ea89861d32 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc @@ -0,0 +1,243 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/resize.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/resize.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_INVALID_OP_ATTR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +constexpr int kRank = 4; +} // namespace + +int ResizeCPUKernel::CheckParameters() { + auto parameter = reinterpret_cast(opParameter); + if (parameter == nullptr) { + MS_LOG(ERROR) << "cast ResizeParameter failed."; + return RET_NULL_PTR; + } + method_ = parameter->method_; + if (method_ != schema::ResizeMethod_BILINEAR && method_ != schema::ResizeMethod_NEAREST_NEIGHBOR) { + MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_; + return RET_INVALID_OP_ATTR; + } + new_height_ = parameter->new_height_; + if (new_height_ < 1) { + MS_LOG(ERROR) << "Resize new_height should >= 1, but got " << new_height_; + return RET_INVALID_OP_ATTR; + } + new_width_ = parameter->new_width_; + if (new_width_ < 1) { + MS_LOG(ERROR) << "Resize new_width should >= 1, but got " << new_width_; + return RET_INVALID_OP_ATTR; + } + align_corners_ = parameter->align_corners_; + preserve_aspect_ratio = parameter->preserve_aspect_ratio_; + if (preserve_aspect_ratio) { + MS_LOG(ERROR) << "Resize currently not support preserve_aspect_ratio true"; + return RET_ERROR; + } + return RET_OK; +} + +int ResizeCPUKernel::CheckInputsOuputs() { + if (inputs_.size() != kInputNum) { + MS_LOG(ERROR) << "Resize input num should be " << kInputNum << ", but got " << inputs_.size(); + return RET_ERROR; + } + auto input = inputs_.at(0); + if (input == nullptr) { + return RET_NULL_PTR; + } + if (outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "Resize output num should be " << kOutputNum << ", but got " << outputs_.size(); + return RET_ERROR; + } + auto output = outputs_.at(0); + if (output == nullptr) { + return RET_NULL_PTR; + } + return RET_OK; +} + +int ResizeCPUKernel::Init() { + auto ret = CheckParameters(); + if (ret != RET_OK) { + return ret; + } + ret = CheckInputsOuputs(); + if (ret != RET_OK) { + return ret; + } + + auto output = outputs_.at(0); + auto input = inputs_.at(0); + auto input_shape = input->shape(); + if (input_shape.size() != kRank) { + return RET_ERROR; + } + schema::Format execute_format; + size_t exec_input_size; + switch (method_) { + case schema::ResizeMethod_BILINEAR: { + execute_format = schema::Format_NC4HW4; + output->SetFormat(schema::Format_NC4HW4); + exec_input_size = input->ElementsC4Num(); + break; + } + case schema::ResizeMethod_NEAREST_NEIGHBOR: { + execute_format = schema::Format_NHWC; + output->SetFormat(schema::Format_NHWC); + exec_input_size = input->ElementsNum(); + break; + } + default: { + MS_LOG(ERROR) << "Resize unknown method " << method_; + return RET_ERROR; + } + } + + auto input_format = input->GetFormat(); + if (input_format != execute_format) { + auto input_type = input->data_type(); + layout_convertor_ = LayoutTransform(input_type, input_format, execute_format); + exec_input_data_ = reinterpret_cast(malloc(exec_input_size * sizeof(float))); + if (exec_input_data_ == nullptr) { + return RET_NULL_PTR; + } + } + + return RET_OK; +} + +int ResizeImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto resize = reinterpret_cast(cdata); + auto error_code = resize->RunImpl(task_id); + if (error_code != NNACL_OK) { + MS_LOG(ERROR) << "Resize Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ResizeCPUKernel::RunImpl(int task_id) { + auto input = inputs_.at(0); + auto input_data = reinterpret_cast(input->Data()); + if (input_data == nullptr) { + return RET_NULL_PTR; + } + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + if (output_data == nullptr) { + return RET_NULL_PTR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kRank) { + return RET_ERROR; + } + if (context_ == nullptr) { + return RET_NULL_PTR; + } + + int ret = 0; + switch (method_) { + case schema::ResizeMethod_BILINEAR: { + if (layout_convertor_ != nullptr) { + layout_convertor_(input_data, exec_input_data_, input->Batch(), input->Height() * input->Width(), + input->Channel()); + ret = ResizeBilinear(exec_input_data_, output_data, inputs_[0]->shape().data(), outputs_[0]->shape().data(), + align_corners_, task_id, context_->thread_num_); + } else { + ret = ResizeBilinear(input_data, output_data, inputs_[0]->shape().data(), outputs_[0]->shape().data(), + align_corners_, task_id, context_->thread_num_); + } + break; + } + case schema::ResizeMethod_NEAREST_NEIGHBOR: { + if (align_corners_) { + MS_LOG(ERROR) << "ResizeNearestNeighbor not support align_corners."; + return RET_ERROR; + } + if (layout_convertor_ != nullptr) { + layout_convertor_(input_data, exec_input_data_, input->Batch(), input->Height() * input->Width(), + input->Channel()); + ret = ResizeNearestNeighbor(exec_input_data_, output_data, input_shape.data(), outputs_[0]->shape().data(), + task_id, context_->thread_num_); + } else { + ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), outputs_[0]->shape().data(), task_id, + context_->thread_num_); + } + break; + } + case schema::ResizeMethod_UNKNOW: + default: { + MS_LOG(ERROR) << "Resize unknown method " << method_; + ret = NNACL_ERR; + } + } + return ret; +} + +int ResizeCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ResizeImpl, this, context_->thread_num_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuResizeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Resize); + auto *kernel = new (std::nothrow) ResizeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ResizeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Resize, CpuResizeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h new file mode 100644 index 0000000000..6d4f681be2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/resize.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::schema::PrimitiveType_Resize; +using mindspore::schema::ResizeMethod; + +namespace mindspore::kernel { +class ResizeCPUKernel : public LiteKernel { + public: + ResizeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx) {} + + ~ResizeCPUKernel() { + if (exec_input_data_ != nullptr) { + free(exec_input_data_); + exec_input_data_ = nullptr; + } + } + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int RunImpl(int task_id); + + protected: + const lite::Context *context_; + + private: + int CheckParameters(); + int CheckInputsOuputs(); + + private: + ResizeMethod method_; + int64_t new_height_; + int64_t new_width_; + bool align_corners_; + bool preserve_aspect_ratio; + LayoutConvertor layout_convertor_ = nullptr; + float *exec_input_data_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc new file mode 100644 index 0000000000..9e3b6bb557 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/reverse.h" +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/reverse.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reverse; + +namespace mindspore::kernel { + +int ReverseCPUKernel::Stride(int index) { + int i, stride = 1; + for (i = index + 1; i < inputs_[0]->shape().size(); ++i) { + stride *= inputs_[0]->shape()[i]; + } + return stride; +} + +int ReverseCPUKernel::ReSize() { + auto *param = reinterpret_cast(opParameter); + auto input_shape = inputs_[0]->shape(); + if (param->num_axis_ > input_shape.size()) { + MS_LOG(ERROR) << "Reverse dims : " << param->num_axis_ + << "is greater than input shape size :" << input_shape.size(); + return RET_ERROR; + } + if (input_shape.size() > REVERSE_SHAPE_MAX_SIZE) { + MS_LOG(ERROR) << "input dimension num should <= " << REVERSE_SHAPE_MAX_SIZE; + return RET_ERROR; + } + + if (tmp_ != nullptr) { + free(tmp_); + tmp_ = nullptr; + } + tmp_ = reinterpret_cast(malloc(data_size_ * sizeof(int))); + if (tmp_ == nullptr) { + MS_LOG(ERROR) << "Reverse Malloc tmp_ error!"; + return RET_ERROR; + } + (void)memset(tmp_, 0, data_size_ * sizeof(int)); + + for (int i = 0; i < param->num_axis_; i++) { + int axis = param->axis_[i]; + int stride = Stride(axis); + strides_[i] = stride; + inCount_[i] = input_shape[axis]; + outCount_[i] = 1; + for (int j = 0; j < axis; j++) { + outCount_[i] *= input_shape[j]; + } + } + + int out, in, C, m; + for (int i = 0; i < data_size_; ++i) { + int tmp = i; + for (int j = 0; j < param->num_axis_; ++j) { + C = inCount_[j]; + out = tmp / (C * strides_[j]); + in = tmp / strides_[j] - out * C; + m = tmp % strides_[j]; + tmp = out * C * strides_[j] + strides_[j] * (C - 1 - in) + m; + } + tmp_[i] = tmp; + } + + return RET_OK; +} + +int ReverseCPUKernel::Init() { + data_size_ = inputs_.at(0)->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + int ret = ReSize(); + return ret; +} + +int ReverseRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoReverse(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "reverseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ReverseCPUKernel::DoReverse(int task_id) { + int count = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (count <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + auto ret = Reverse(in_ptr_ + offset, out_ptr_, thread_sz_stride_, tmp_ + offset); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReverseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ReverseCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_[0]->Data()); + out_ptr_ = reinterpret_cast(outputs_[0]->Data()); + int ret = LiteBackendParallelLaunch(ReverseRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Reverse run error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter is NULL! "; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Reverse); + auto *kernel = new (std::nothrow) ReverseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Kernel is NULL! name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reverse, CpuReverseFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h new file mode 100644 index 0000000000..3d5b6d66e3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" + +#define REVERSE_STRIDE_MAX_SIZE 4 + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReverseCPUKernel : public LiteKernel { + public: + ReverseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + ~ReverseCPUKernel() { + if (tmp_ != nullptr) { + free(tmp_); + tmp_ = nullptr; + } + } + + int Init() override; + int ReSize() override; + int Run() override; + int Stride(int index); + int DoReverse(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + int strides_[REVERSE_STRIDE_MAX_SIZE]; + int inCount_[REVERSE_STRIDE_MAX_SIZE]; + int outCount_[REVERSE_STRIDE_MAX_SIZE]; + const Context *ctx_; + int *tmp_ = nullptr; + float *in_ptr_; + float *out_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc new file mode 100644 index 0000000000..8684adaeeb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/reverse_sequence.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ReverseSequence; + +namespace mindspore::kernel { +int ReverseSequenceCPUKernel::Init() { + auto input0 = inputs_.at(0); + auto input1 = inputs_.at(1); + auto output = outputs_.at(0); + MS_ASSERT(input0 != nullptr); + MS_ASSERT(input1 != nullptr); + MS_ASSERT(output != nullptr); + + auto para = reinterpret_cast(opParameter); + + ConvertAxisToPositive(input0->shape(), &(para->batch_axis_)); + ConvertAxisToPositive(input0->shape(), &(para->seq_axis_)); + + para->ndim_ = input0->shape().size(); + for (int i = 0; i < para->ndim_; i++) { + para->input_shape0_[i] = input0->DimensionSize(i); + para->output_shape_[i] = output->DimensionSize(i); + } + + int less_axis = MSMIN(para->batch_axis_, para->seq_axis_); + int greater_axis = MSMAX(para->batch_axis_, para->seq_axis_); + + para->outer_count_ = CalcCountPreAxis(input0->shape(), less_axis); + para->outer_stride_ = input0->DimensionSize(less_axis) * CalcCountAfterAxis(input0->shape(), less_axis); + + para->inner_count_ = 1; + for (int i = less_axis + 1; i < greater_axis; ++i) { + para->inner_count_ *= input0->DimensionSize(i); + } + + para->inner_stride_ = input0->DimensionSize(greater_axis) * CalcCountAfterAxis(input0->shape(), greater_axis); + + para->copy_byte_size_ = sizeof(float) * CalcCountAfterAxis(input0->shape(), greater_axis); + para->total_data_size_ = input0->Size(); + return RET_OK; +} + +void ReverseSequenceCPUKernel::ConvertAxisToPositive(const std::vector shape, int *axis) { + if (axis != nullptr && *axis < 0) { + *axis += shape.size(); + } +} + +int ReverseSequenceCPUKernel::CalcCountPreAxis(const std::vector shape, int axis) { + int count = 1; + for (int i = 0; i < axis; ++i) { + count *= shape[i]; + } + return count; +} +int ReverseSequenceCPUKernel::CalcCountAfterAxis(const std::vector shape, int axis) { + int count = 1; + for (int i = axis + 1; i < shape.size(); ++i) { + count *= shape[i]; + } + return count; +} + +int ReverseSequenceCPUKernel::ReSize() { return RET_OK; } + +int ReverseSequenceCPUKernel::Run() { + float *input0 = reinterpret_cast(inputs_.at(0)->Data()); + int *input1 = reinterpret_cast(inputs_.at(1)->Data()); + float *output = reinterpret_cast(outputs_.at(0)->Data()); + ReverseSequence(input0, input1, output, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_ReverseSequence); + auto *kernel = new (std::nothrow) ReverseSequenceCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseSequence, CpuReverseSequenceFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h new file mode 100644 index 0000000000..4723745afd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/reverse_sequence.h" + +namespace mindspore::kernel { +class ReverseSequenceCPUKernel : public LiteKernel { + public: + ReverseSequenceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ReverseSequenceCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + void ConvertAxisToPositive(const std::vector shape, int *axis); + int CalcCountPreAxis(const std::vector shape, int axis); + int CalcCountAfterAxis(const std::vector shape, int axis); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc new file mode 100644 index 0000000000..2323e4ce95 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/scale.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/scale.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Scale; + +namespace mindspore::kernel { +int ScaleCPUKernel::InitScaleOffset() { + auto param = reinterpret_cast(opParameter); + auto scale_tensor = inputs_.at(1); + float *scale_ptr = reinterpret_cast(inputs_.at(1)->Data()); + if (scale_ptr != nullptr) { + scale_ = reinterpret_cast(malloc(scale_tensor->ElementsNum() * sizeof(float))); + if (scale_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(float)); + } else { + scale_ = nullptr; + } + + if (inputs_.size() == 3) { + auto offset_tensor = inputs_.at(1); + offset_ = reinterpret_cast(malloc(offset_tensor->ElementsNum() * sizeof(float))); + if (offset_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + param->has_offset_ = true; + } else { + offset_ = nullptr; + param->has_offset_ = false; + } + return RET_OK; +} + +int ScaleCPUKernel::InitParameter() { + auto param = reinterpret_cast(opParameter); + auto in_tensor = inputs_.at(0); + auto in_shape = in_tensor->shape(); + auto scale_tensor = inputs_.at(1); + auto scale_shape = scale_tensor->shape(); + + if (scale_shape.size() + param->axis_ > in_shape.size()) { + MS_LOG(ERROR) << "Scale tensor shape is incorrect."; + return RET_ERROR; + } + param->outer_size_ = 1; + param->axis_size_ = 1; + param->inner_size_ = 1; + for (int i = 0; i < param->axis_; i++) { + param->outer_size_ *= in_shape[i]; + } + for (int i = 0; i < scale_shape.size(); i++) { + if (in_shape[i + param->axis_] != scale_shape[i]) { + MS_LOG(ERROR) << "Scale tensor shape is incorrect."; + return RET_ERROR; + } + param->axis_size_ *= in_shape[i + param->axis_]; + } + for (int i = param->axis_ + scale_shape.size(); i < in_shape.size(); i++) { + param->inner_size_ *= in_shape[i]; + } + return RET_OK; +} + +int ScaleCPUKernel::Init() { + if (inputs_.size() < 2 || inputs_.size() > 3) { + MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + + auto ret = InitParameter(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale fp32 InitParameter failed."; + return RET_ERROR; + } + + ret = InitScaleOffset(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ScaleCPUKernel::ReSize() { return RET_OK; } + +int ScaleCPUKernel::Scale(int task_id) { + auto ret = + DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, reinterpret_cast(opParameter)); + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScaleRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto scale = reinterpret_cast(cdata); + auto ret = scale->Scale(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScaleRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScaleCPUKernel::Run() { + auto in_tensor = inputs_.front(); + input_ptr_ = reinterpret_cast(in_tensor->Data()); + if (scale_ == nullptr) { + auto scale_tensor = inputs_[1]; + scale_ = reinterpret_cast(scale_tensor->Data()); + } + auto out_tensor = outputs_.front(); + output_ptr_ = reinterpret_cast(out_tensor->Data()); + + int ret = LiteBackendParallelLaunch(ScaleRun, this, opParameter->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Scale); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter is nullptr"; + return nullptr; + } + auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Scale, CpuScaleFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h new file mode 100644 index 0000000000..32417bcc26 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { + +class ScaleCPUKernel : public LiteKernel { + public: + explicit ScaleCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + opParameter->thread_num_ = ctx->thread_num_; + } + ~ScaleCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int InitParameter(); + int InitScaleOffset(); + int Scale(int task_id); + + private: + float *input_ptr_; + float *scale_; + float *offset_; + float *output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc new file mode 100644 index 0000000000..1a9ab72866 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc @@ -0,0 +1,186 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/scatter_nd.h" +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ScatterND; + +namespace mindspore::kernel { +namespace { + constexpr int kScatterNDInputNum = 3; + constexpr int kScatterNDOutputNum = 1; + constexpr int kScatterShapeIndex = 0; + constexpr int kScatterIndicesIndex = 1; + constexpr int kScatterUpdateIndex = 2; +} // namespace +int ScatterNDCPUKernel::Init() { + auto shape = inputs_.at(kScatterShapeIndex); + auto indices = inputs_.at(kScatterIndicesIndex); + auto update = inputs_.at(kScatterUpdateIndex); + + update_ptr_ = reinterpret_cast(update->Data()); + output_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); + + // check indices shape + auto shape_rank = shape->ElementsNum(); + auto shape_data = reinterpret_cast(shape->Data()); + auto indice_unit_rank = indices->shape().back(); + if (indice_unit_rank > shape_rank) { + MS_LOG(ERROR) << "Value of last dimension of indices is greater than shape rank."; + return RET_ERROR; + } + + if (indices->shape().size() < 2) { + MS_LOG(ERROR) << "Indices dimension smaller than 2."; + return RET_ERROR; + } + + // check consistency of the shape indices and shape + auto update_rank = static_cast(update->shape().size()); + auto indices_shape = indices->shape(); + if (update_rank != indices->shape().size() - 1 + shape_rank - indice_unit_rank) { + MS_LOG(ERROR) << "Update, shape rank and indices rank inconsistent."; + return RET_ERROR; + } + // check update shape + auto update_shape = update->shape(); + for (size_t i = 0; i < indices_shape.size() - 1; i++) { + if (update_shape[i] != indices_shape[i]) { + MS_LOG(ERROR) << "Value of " << i << " th dimension of indices is not equal to that of update."; + return RET_ERROR; + } + } + for (size_t i = 0; i < shape->ElementsNum() - (indices_shape.size() - 1); i++) { + if (update_shape[i + indices_shape.size() - 1] != shape_data[i + indices_shape.size() - 1]) { + MS_LOG(ERROR) << "Value of " << i + indices_shape.size() - 1 + << " th dimension of indices is not equal to the corresbonding dimension of shape."; + return RET_ERROR; + } + } + // todo check indeices out of range + // for (size_t i = 0; i < static_cast(indice_unit_rank); i++) {} + + // calculate unit_size_ + unit_size_ = 1; + for (int i = indices_shape.size() - 1; i < update_rank; i++) { + unit_size_ *= update_shape[i]; + } + + // calculate offsets + int out_stride = 1; + out_strides_.push_back(1); + for (int i = indice_unit_rank - 2; i >= 0; i--) { + out_stride *= shape_data[i + 1]; + out_strides_.push_back(out_stride); + } + + num_unit_ = 1; + num_unit_ *= update_shape[indices_shape.size() - 2]; + for (int i = indices_shape.size() - 3; i >= 0; i--) { + num_unit_ *= update_shape[i]; + } + + int *indices_ptr = reinterpret_cast(indices->Data()); + for (int i = 0; i < num_unit_; i++) { + int tmp_stride = 0; + for (int j = 0; j < indice_unit_rank; j++) { + tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_[j] * unit_size_; + } + output_unit_offsets_.push_back(tmp_stride); + } + + thread_n_num_ = MSMIN(thread_num_, num_unit_); + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + return RET_OK; +} + +int ScatterNDCPUKernel::ReSize() { return 0; } + +int ScatterNDCPUKernel::ScatterND(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int offset = task_id * thread_n_stride_; + MS_LOG(ERROR) << "offset " << offset << std::endl; + auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset, + unit_size_, num_unit_thread); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScatterND error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScatterNDRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->ScatterND(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScatterNDRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScatterNDCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(ScatterNDRun, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScatterND error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuScatterNDFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_ScatterND); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not scatterND"; + return nullptr; + } + auto *kernel = new (std::nothrow) ScatterNDCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != 0) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterND, CpuScatterNDFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h new file mode 100644 index 0000000000..6877fe8704 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/scatter_nd.h" + +namespace mindspore::kernel { + +class ScatterNDCPUKernel : public LiteKernel { + public: + explicit ScatterNDCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} + ~ScatterNDCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int ScatterND(int task_id); + + private: + int thread_num_; + int thread_n_num_; + int thread_n_stride_; + int num_unit_; + int unit_size_; + float *output_ptr_; + float *update_ptr_; + std::vector out_strides_; + std::vector output_unit_offsets_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc new file mode 100644 index 0000000000..077e174fcf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/shape.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Shape; + +namespace mindspore::kernel { +namespace { + constexpr int kShapeInputNum = 1; + constexpr int kShapeOutputNum = 1; +} // namespace +int ShapeCPUKernel::Init() { return RET_OK; } + +int ShapeCPUKernel::ReSize() { return RET_OK; } + +int ShapeCPUKernel::Run() { + auto out_tensor = outputs_.front(); + auto in_tensor = inputs_.front(); + if (in_tensor == nullptr || out_tensor == nullptr) { + MS_LOG(ERROR) << "null pointer dereferencing."; + return RET_ERROR; + } + if (in_tensor->Data() == nullptr || out_tensor->Data() == nullptr) { + MS_LOG(ERROR) << "null pointer dereferencing."; + return RET_ERROR; + } + + for (int i = 0; i < in_tensor->shape().size(); i++) { + reinterpret_cast(out_tensor->Data())[i] = in_tensor->shape()[i]; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuShapeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Shape); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Shape"; + return nullptr; + } + auto *kernel = new (std::nothrow) ShapeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Shape, CpuShapeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h new file mode 100644 index 0000000000..25634db0b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/shape.h" + +namespace mindspore::kernel { + +class ShapeCPUKernel : public LiteKernel { + public: + explicit ShapeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ShapeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc new file mode 100644 index 0000000000..91d02b5623 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/slice.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/slice.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Slice; + +namespace mindspore::kernel { +namespace { +int SliceLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { + if (cdata == nullptr) { + MS_LOG(ERROR) << "Input cdata is nullptr!"; + return RET_NULL_PTR; + } + auto kernel = reinterpret_cast(cdata); + return kernel->SliceParallelRun(thread_id); +} +} // namespace + +int SliceCPUKernel::Init() { + auto *param = reinterpret_cast(opParameter); + auto input_shape = inputs_[0]->shape(); + if (input_shape.size() != param->param_length_) { + MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size " + << input_shape.size(); + return RET_ERROR; + } + if (input_shape.size() > DIMENSION_4D) { + MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_4D; + return RET_ERROR; + } + + for (size_t i = 0; i < input_shape.size(); ++i) { + param->shape_[i] = input_shape[i]; + } + outputs_[0]->SetFormat(inputs_[0]->GetFormat()); + return RET_OK; +} + +int SliceCPUKernel::SliceParallelRun(int thread_id) { + const float *input_data = reinterpret_cast(inputs_[0]->Data()); + float *output_data = reinterpret_cast(outputs_[0]->Data()); + SliceParameter *param = reinterpret_cast(opParameter); + DoSlice(input_data, output_data, param); + return RET_OK; +} + +int SliceCPUKernel::Run() { + SliceParameter *param = reinterpret_cast(opParameter); + for (int i = 0; i < param->param_length_; ++i) { + if (param->size_[i] < 0) { + param->size_[i] = param->shape_[i] - param->begin_[i]; + } + param->end_[i] = param->begin_[i] + param->size_[i]; + } + + if (param->param_length_ < DIMENSION_4D) { + PadSliceParameterTo4D(param); + } + + const float *input_data = reinterpret_cast(inputs_[0]->Data()); + float *output_data = reinterpret_cast(outputs_[0]->Data()); + if (param->size_[1] < param->op_parameter_.thread_num_) { + DoSliceNoParallel(input_data, output_data, param); + return RET_OK; + } + int ret = LiteBackendParallelLaunch(SliceLaunch, this, param->op_parameter_.thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "slice launch fail!ret: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "Input context is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Slice); + op_parameter->thread_num_ = ctx->thread_num_; + auto *kernel = new (std::nothrow) SliceCPUKernel(op_parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SliceCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h new file mode 100644 index 0000000000..a02baf4918 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class SliceCPUKernel : public LiteKernel { + public: + SliceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~SliceCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + } + int Run() override; + int SliceParallelRun(int thread_id); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc new file mode 100644 index 0000000000..ffdb5080e0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/softmax.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore::kernel { +int SoftmaxCPUKernel::Init() { + SoftmaxBaseCPUKernel::Init(); + + // malloc tmp buffer + auto axis = softmax_param_->axis_; + sum_data = reinterpret_cast(malloc(softmax_param_->input_shape_[axis] * sizeof(float))); + memset(sum_data, 0, softmax_param_->input_shape_[axis] * sizeof(float)); + return RET_OK; +} + +int SoftmaxCPUKernel::ReSize() { return RET_OK; } + +int SoftmaxCPUKernel::Run() { + auto input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Softmax(input_ptr, output_ptr, sum_data, softmax_param_); + return RET_OK; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h new file mode 100644 index 0000000000..0e9c0a7ebf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/softmax_base.h" + +namespace mindspore::kernel { +class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { + public: + SoftmaxCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~SoftmaxCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float *sum_data; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc new file mode 100644 index 0000000000..5ac8d4db97 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/space_to_batch.h" +#include +#include "schema/ops_generated.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::lite::RET_OP_EXECUTE_FAILURE; +using mindspore::schema::PrimitiveType_SpaceToBatch; + +namespace mindspore::kernel { + +int SpaceToBatchCPUKernel::Init() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "space_to_batch only support NHWC now!"; + return RET_FORMAT_ERR; + } + SpaceToBatchParameter *param = reinterpret_cast(this->opParameter); + for (int i = 0; i < SPACE_TO_BATCH_PADDINGS_SIZE; ++i) { + if (param->paddings_[i] != 0) { + param->need_paddings_ = true; + break; + } + } + param->n_dims_ = DIMENSION_4D; + param->n_space_dims_ = SPACE_TO_BATCH_BLOCK_SIZES_SIZE; + param->num_elements_ = EnumElement(param->in_shape_, param->n_dims_); + param->num_elements_padded_ = EnumElement(param->padded_in_shape_, param->n_dims_); + return RET_OK; +} + +int SpaceToBatchCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + input_ptr_ = reinterpret_cast(input->Data()); + output_ptr_ = reinterpret_cast(output->Data()); + SpaceToBatchParameter *param = reinterpret_cast(this->opParameter); + + int ret; + float *tmp_space[3] = {nullptr, nullptr, nullptr}; + if (param->need_paddings_) { + tmp_space[0] = reinterpret_cast(malloc(param->num_elements_padded_ * sizeof(float))); + (void)memset(tmp_space[0], 0, param->num_elements_padded_); + tmp_space[1] = reinterpret_cast(malloc(param->num_elements_padded_ * sizeof(float))); + (void)memset(tmp_space[1], 0, param->num_elements_padded_); + tmp_space[2] = reinterpret_cast(malloc(param->num_elements_padded_ * sizeof(float))); + (void)memset(tmp_space[2], 0, param->num_elements_padded_); + + ret = SpaceToBatch(input_ptr_, output_ptr_, *param, tmp_space); + } else { + ret = SpaceToBatch(input_ptr_, output_ptr_, *param, tmp_space); + } + if (ret != NNACL_OK) { + MS_LOG(ERROR) << "Do space to batch fails!"; + return RET_OP_EXECUTE_FAILURE; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuSpaceToBatchFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) SpaceToBatchCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SpaceToBatchCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatch, CpuSpaceToBatchFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h new file mode 100644 index 0000000000..510649f2c0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPACE_TO_BATCH_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPACE_TO_BATCH_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class SpaceToBatchCPUKernel : public LiteKernel { + public: + SpaceToBatchCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + + ~SpaceToBatchCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + + private: + const float *input_ptr_; + float *output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPACE_TO_BATCH_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc new file mode 100644 index 0000000000..13aa702567 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/space_to_depth.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::schema::PrimitiveType_SpaceToDepth; + +namespace mindspore::kernel { + +int SpaceToDepthCPUKernel::Init() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "space_to_depth only support NHWC now!"; + return RET_FORMAT_ERR; + } + SpaceToDepthParameter *param = reinterpret_cast(opParameter); + if (param->block_size_ <= 0) { + MS_LOG(ERROR) << "Input block_size should > 0!"; + return RET_PARAM_INVALID; + } + + num_unit_ = static_cast(inputs_[0]->shape().at(kNHWC_H)); + thread_h_num_ = MSMIN(thread_num_, num_unit_); + thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); + return RET_OK; +} + +int SpaceToDepthCPUKernel::SpaceToDepth(int task_id) { + int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_h_stride_; + auto in_shape = inputs_[0]->shape(); + auto out_shape = outputs_[0]->shape(); + SpaceToDepthParameter *param = reinterpret_cast(opParameter); + auto ret = SpaceToDepthForNHWC(input_ptr_, output_ptr_, in_shape.data(), out_shape.data(), in_shape.size(), + param->block_size_, thread_offset, thread_offset + num_unit_thread); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SpaceToDepth error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SpaceToDepthRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->SpaceToDepth(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SpaceToDepthRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SpaceToDepthCPUKernel::Run() { + input_ptr_ = reinterpret_cast(inputs_[0]->Data()); + output_ptr_ = reinterpret_cast(outputs_[0]->Data()); + if (inputs_[0]->GetFormat() == schema::Format_NHWC) { + int ret = LiteBackendParallelLaunch(SpaceToDepthRun, this, thread_h_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SpaceToDepth error error_code[" << ret << "]"; + return ret; + } + return RET_OK; + } else { + MS_LOG(ERROR) << "Only support NHWC now!"; + return RET_ERROR; + } +} +kernel::LiteKernel *CpuSpaceToDepthFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) SpaceToDepthCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SpaceToDepthCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SpaceToDepth, CpuSpaceToDepthFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h new file mode 100644 index 0000000000..749de65503 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_SPACE_TO_DEPTH_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_SPACE_TO_DEPTH_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class SpaceToDepthCPUKernel : public LiteKernel { + public: + SpaceToDepthCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} + ~SpaceToDepthCPUKernel() = default; + + int SpaceToDepth(int task_id); + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + + private: + int thread_num_; + int thread_h_stride_; + int thread_h_num_; + int num_unit_; + float *input_ptr_; + float *output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_SPACE_TO_DEPTH_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc new file mode 100644 index 0000000000..fec15c1e00 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h" +#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SoftmaxCrossEntropy; + +namespace mindspore::kernel { + +int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, + float *output) const { + float total_loss = 0; + for (int i = 0; i < param->batch_size_; ++i) { + if (labels[i] < 0) { + MS_LOG(EXCEPTION) << "label value must >= 0"; + } + size_t label = labels[i]; + if (label > param->number_of_classes_) { + MS_LOG(EXCEPTION) << "error label input!"; + } else { + total_loss -= logf(losses[i * param->number_of_classes_ + label]); + } + } + output[0] = total_loss / param->batch_size_; +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, + float *output) const { + size_t row_start = 0; + for (int i = 0; i < param->batch_size_; ++i) { + if (labels[i] < 0) { + MS_LOG(EXCEPTION) << "label value must >= 0"; + } + size_t label = labels[i]; + if (label > param->number_of_classes_) { + MS_LOG(EXCEPTION) << "error label input!"; + } + for (size_t j = 0; j < param->number_of_classes_; ++j) { + size_t index = row_start + j; + if (j == label) { + output[index] = (losses[index] - 1) / param->batch_size_; + } else { + output[index] = losses[index] / param->batch_size_; + } + } + row_start += param->number_of_classes_; + } +} + +int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { + auto ins = reinterpret_cast(inputs_.at(0)->Data()); + auto labels = reinterpret_cast(inputs_.at(1)->Data()); + auto out = reinterpret_cast(outputs_.at(0)->Data()); + float *grads = NULL; + if (is_train()) { // outputs_.size() > 1) + grads = reinterpret_cast(outputs_.at(0)->Data()); + } + size_t data_size = inputs_.at(0)->ElementsNum(); + float *losses = new (std::nothrow) float[data_size]; + MS_ASSERT(losses != nullptr); + std::fill(losses, losses + data_size, 0); + + MS_ASSERT(out != nullptr); + MS_ASSERT(labels != nullptr); + MS_ASSERT(ins != nullptr); + + SoftmaxParameter sm_params; + sm_params.n_dim_ = param->n_dim_; + sm_params.element_size_ = data_size; + sm_params.axis_ = 1; + for (int i = 0; i < 4; i++) // softmax has only 4 params in shape + sm_params.input_shape_[i] = param->input_shape_[i]; + float sum_data[sm_params.input_shape_[sm_params.axis_]]; + Softmax(ins, losses, sum_data, &sm_params); + + if (is_train()) { + GradPostExecute(labels, losses, grads); + } else { + ForwardPostExecute(labels, losses, out); + } + return RET_OK; +} + +int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { + auto dims = inputs_[0]->shape(); + param->n_dim_ = 2; + param->number_of_classes_ = dims[1]; + param->batch_size_ = dims[0]; + for (unsigned int i = 0; i < dims.size(); i++) param->input_shape_[i] = dims[i]; + if (2 != this->inputs_.size()) { + MS_LOG(ERROR) << "softmax entropy loss should have two inputs"; + return RET_ERROR; + } + auto *in0 = inputs_.front(); + if (in0 == nullptr) { + MS_LOG(ERROR) << "softmax etropy loss in0 have no data"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_SoftmaxCrossEntropy); + auto *kernel = new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSoftmaxCrossEntropyFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h new file mode 100644 index 0000000000..a8dd3439cd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +namespace mindspore::kernel { + +class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel { + public: + explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, + const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + param = reinterpret_cast(parameter); + } + ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + + void ForwardPostExecute(const int *labels, const float *losses, float *output) const; + void GradPostExecute(const int *labels, const float *losses, float *output) const; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + SoftmaxCrossEntropyParameter *param; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc new file mode 100644 index 0000000000..359a379985 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/sparse_to_dense.h" +#include +#include "schema/model_generated.h" +#include "schema/ops_generated.h" +#include "src/runtime/kernel/arm/nnacl/sparse_to_dense.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SparseToDense; + +namespace mindspore::kernel { +int SparseToDenseCPUKernel::Init() { + s2d_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +int SparseToDenseCPUKernel::DoExcute(int task_id) { + SparseToDense(input_data_, output_shape_, snum_, dnum_, sp_num_, output_data, s2d_param_, task_id); + return RET_OK; +} + +int SparseToDenseRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto s2ddata = reinterpret_cast(cdata); + auto ret = s2ddata->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SparseToDenseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} +int SparseToDenseCPUKernel::Run() { + auto input = inputs_.at(0); + auto input1 = inputs_.at(1); + auto input2 = inputs_.at(2); + auto input3 = inputs_.at(3); + auto output0 = outputs_.at(0); + + input_data_ = reinterpret_cast(input->Data()); + total_number_ = reinterpret_cast(input1->Data()); + snum_ = reinterpret_cast(input2->Data()); + dnum_ = reinterpret_cast(input3->Data()); + sp_num_ = static_cast(input->ElementsNum() / 2); + + output_data = reinterpret_cast(outputs_.at(0)->Data()); + std::vector temp_shape = output0->shape(); + output_shape_ = reinterpret_cast(temp_shape.data()); + + auto ret = LiteBackendParallelLaunch(SparseToDenseRun, this, s2d_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SparseToDenseRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_SparseToDense); + auto *kernel = new (std::nothrow) SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SparseToDenseCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, CpuSparseToDenseFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h new file mode 100644 index 0000000000..813e572979 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/sparse_to_dense.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class SparseToDenseCPUKernel : public LiteKernel { + public: + SparseToDenseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + s2d_param_ = (reinterpret_cast(opParameter)); + } + ~SparseToDenseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + SparseToDenseParameter *s2d_param_; + + private: + int *input_data_; + int *total_number_; + int sp_num_; + float *snum_; + float *dnum_; + float *output_data; + int *output_shape_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc new file mode 100644 index 0000000000..81942f7a37 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/split.h" +#include "src/runtime/kernel/arm/base/split_base.h" +#include "src/runtime/kernel/arm/nnacl/split.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Split; + +namespace mindspore::kernel { + +int SplitCPUKernel::Init() { + SplitBaseCPUKernel::Init(); + auto in_tensor = inputs_.front(); + input_ptr_ = reinterpret_cast(in_tensor->Data()); + for (int i = 0; i < param->num_split_; i++) { + output_ptr_.push_back(reinterpret_cast(outputs_.at(i)->Data())); + } + return RET_OK; +} + +int SplitCPUKernel::ReSize() { return RET_OK; } + +int SplitCPUKernel::Split(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_n_stride_; + auto ret = DoSplit(input_ptr_, output_ptr_.data(), inputs_.front()->shape().data(), thread_offset, num_unit_thread, + param); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Split error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SplitRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->Split(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SplitRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SplitCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(SplitRun, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/split.h b/mindspore/lite/src/runtime/kernel/arm/fp32/split.h new file mode 100644 index 0000000000..5761367abb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/split.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLIT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLIT_H_ + +#include +#include "src/runtime/kernel/arm/base/split_base.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class SplitCPUKernel : public SplitBaseCPUKernel { + public: + SplitCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : SplitBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~SplitCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int Split(int task_id); + + private: + float *input_ptr_; + std::vector output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLIT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc new file mode 100644 index 0000000000..fbc102eec0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/squeeze.h" +#include +#include "src/runtime/kernel/arm/nnacl/squeeze.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Squeeze; + +namespace mindspore::kernel { +namespace { + constexpr int kSqueezeInputNum = 1; + constexpr int kSqueezeOutputNum = 1; +} // namespace + +int SqueezeCPUKernel::Init() { return RET_OK; } + +int SqueezeCPUKernel::ReSize() { return RET_OK; } + +int SqueezeCPUKernel::Run() { + auto input_ptr = reinterpret_cast(inputs_.front()->Data()); + auto output_ptr = reinterpret_cast(outputs_.front()->Data()); + size_t data_size = inputs_.front()->Size(); + auto ret = DoSqueeze(input_ptr, output_ptr, data_size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Do squeeze failed."; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Squeeze"; + return nullptr; + } + auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h new file mode 100644 index 0000000000..e48fce51c9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { + +class SqueezeCPUKernel : public LiteKernel { + public: + explicit SqueezeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~SqueezeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + std::vector axes_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc new file mode 100644 index 0000000000..344e6762ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/stack.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/fp32/stack.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Stack; + +namespace mindspore::kernel { +int StackCPUKernel::Init() { + StackParameter *param = reinterpret_cast(opParameter); + auto input0_shape = inputs_[0]->shape(); + axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() : param->axis_; + schema::Format input0_format = inputs_[0]->GetFormat(); + bool need_convert_format = false; + for (size_t i = 1; i < inputs_.size(); ++i) { + if (inputs_[i]->GetFormat() != input0_format) { + need_convert_format = true; + } + } + if (!need_convert_format) { + outputs_[0]->SetFormat(input0_format); + return RET_OK; + } + + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i]->GetFormat() != schema::Format_NHWC) { + convert_functions_[i] = LayoutTransform(inputs_[i]->data_type(), inputs_[i]->GetFormat(), schema::Format_NHWC); + if (convert_functions_[i] == nullptr) { + MS_LOG(ERROR) << "Can not convert format " << inputs_[i]->GetFormat() << " to " << schema::Format_NHWC; + return RET_ERROR; + } + size_t packed_input_size = + inputs_[i]->Channel() * inputs_[i]->Batch() * inputs_[i]->Height() * inputs_[i]->Width(); + packed_inputs_[i] = reinterpret_cast(malloc(packed_input_size * sizeof(float))); + if (packed_inputs_[i] == nullptr) { + MS_LOG(ERROR) << "malloc memory fail!"; + return RET_ERROR; + } + memset(packed_inputs_[i], 0, packed_input_size * sizeof(float)); + } else { + convert_functions_[i] = nullptr; + packed_inputs_[i] = nullptr; + } + } + outputs_[0]->SetFormat(schema::Format_NHWC); + return RET_OK; +} + +int StackCPUKernel::Run() { + size_t inputs_num = inputs_.size(); + auto input0_shape = inputs_[0]->shape(); + auto *output_data = reinterpret_cast(outputs_[0]->Data()); + float *inputs[inputs_num]; + for (size_t i = 0; i < inputs_num; ++i) { + inputs[i] = reinterpret_cast(inputs_[i]->Data()); + if (convert_functions_[i] != nullptr) { + convert_functions_[i](inputs[i], packed_inputs_[i], inputs_[i]->Batch(), + inputs_[i]->Height() * inputs_[i]->Width(), inputs_[i]->Channel()); + } else { + packed_inputs_[i] = inputs[i]; + } + } + DoStack(packed_inputs_.data(), inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); + return RET_OK; +} + +kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *op_parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Stack); + auto *kernel = new (std::nothrow) StackCPUKernel(op_parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new StackCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h new file mode 100644 index 0000000000..c1d76ca193 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/base/layout_transform.h" + +namespace mindspore::kernel { +class StackCPUKernel : public LiteKernel { + public: + StackCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs), + convert_functions_(inputs_.size(), nullptr), + packed_inputs_(inputs_.size(), nullptr) {} + + ~StackCPUKernel() { + for (size_t i = 0; i < packed_inputs_.size(); ++i) { + if (packed_inputs_[i] != nullptr) { + free(packed_inputs_[i]); + packed_inputs_[i] = nullptr; + } + } + } + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + + private: + int axis_; + std::vector convert_functions_; + std::vector packed_inputs_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc new file mode 100644 index 0000000000..62a1502774 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/tile.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Tile; + +namespace mindspore::kernel { +int TileCPUKernel::Init() { + auto tile_parameter_ = reinterpret_cast(opParameter); + for (int i = 0; i < tile_parameter_->in_dim_; ++i) { + tile_parameter_->in_shape_[i] = inputs_[0]->shape()[i]; + tile_parameter_->out_shape_[i] = outputs_[0]->shape()[i]; + } + ComputeStrides(tile_parameter_->in_shape_, tile_parameter_->in_strides_, tile_parameter_->in_dim_); + ComputeStrides(tile_parameter_->out_shape_, tile_parameter_->out_strides_, tile_parameter_->in_dim_); + return RET_OK; +} + +void TileCPUKernel::ComputeStrides(int *shape, int *strides, int ndim) { + int stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } +} + +int TileCPUKernel::ReSize() { return RET_OK; } + +int TileCPUKernel::Run() { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + + Tile(input_addr, output_addr, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Tile); + auto *kernel = new (std::nothrow) TileCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Tile, CpuTileFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile.h b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.h new file mode 100644 index 0000000000..bc05c8b8ed --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/tile.h" + +namespace mindspore::kernel { +class TileCPUKernel : public LiteKernel { + public: + explicit TileCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~TileCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + + private: + void ComputeStrides(int *shape, int *strides, int ndim); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc new file mode 100644 index 0000000000..954ec041bc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/topk.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_TopK; + +namespace mindspore::kernel { +int TopKCPUKernel::Init() { + TopkParameter *parameter = reinterpret_cast(opParameter); + lite::tensor::Tensor *input = inputs_.at(0); + parameter->last_dim_size_ = input->shape()[input->shape().size() - 1]; + parameter->loop_num_ = 1; + for (int i = 0; i < input->shape().size() - 1; ++i) { + parameter->loop_num_ *= input->shape()[i]; + } + + parameter->topk_node_list_ = malloc(sizeof(TopkNode) * parameter->last_dim_size_); + if (parameter->topk_node_list_ == nullptr) { + MS_LOG(ERROR) << "malloc fail."; + return RET_ERROR; + } + return RET_OK; +} + +int TopKCPUKernel::ReSize() { return RET_OK; } + +int TopKCPUKernel::Run() { + auto input_data = reinterpret_cast(inputs_.at(0)->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + auto output_index = reinterpret_cast(outputs_.at(1)->Data()); + + Topk(input_data, output_data, output_index, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new TopKCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TopK, CpuTopKFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h new file mode 100644 index 0000000000..9ed7af2f39 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/topk.h" + +namespace mindspore::kernel { +class TopKCPUKernel : public LiteKernel { + public: + explicit TopKCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~TopKCPUKernel() override { + TopkParameter *parameter = reinterpret_cast(opParameter); + free(parameter->topk_node_list_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc new file mode 100644 index 0000000000..a6473114fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/transpose.h" +#include +#include "src/runtime/kernel/arm/nnacl/transpose.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Transpose; + +namespace mindspore::kernel { +namespace { + constexpr int kTransposeInputNum = 1; + constexpr int kTransposeOutputNum = 1; +} // namespace +int TransposeCPUKernel::Init() { + auto &inTensor = inputs_.front(); + auto &outTensor = outputs_.front(); + auto param = reinterpret_cast(opParameter); + auto in_shape = inTensor->shape(); + auto out_shape = outTensor->shape(); + param->strides_[param->num_axes_ - 1] = 1; + param->out_strides_[param->num_axes_ - 1] = 1; + param->data_size_ = inTensor->Size(); + for (int i = param->num_axes_ - 2; i >= 0; i--) { + param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1]; + param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; + } + return RET_OK; +} + +int TransposeCPUKernel::ReSize() { return RET_OK; } + +int TransposeCPUKernel::Run() { + MS_ASSERT(inputs_.size() == TransposeInputNum); + MS_ASSERT(outputs_.size() == TransposeOutputNum); + auto &inTensor = inputs_.front(); + auto &outTensor = outputs_.front(); + if (inTensor == nullptr || outTensor == nullptr) { + MS_LOG(ERROR) << "null pointer dreferencing."; + return RET_ERROR; + } + auto *in_data = static_cast(inTensor->Data()); + auto *out_data = static_cast(outTensor->Data()); + auto in_shape = inTensor->shape(); + auto out_shape = outTensor->shape(); + auto *input_shape = &in_shape.front(); + auto *output_shape = &out_shape.front(); + + auto ret = + DoTranspose(in_data, out_data, input_shape, output_shape, reinterpret_cast(opParameter)); + return ret; +} + +kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Transpose"; + return nullptr; + } + auto *kernel = new (std::nothrow) TransposeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h new file mode 100644 index 0000000000..f13ba70015 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/kernel_factory.h" + + +namespace mindspore::kernel { + +class TransposeCPUKernel : public LiteKernel { + public: + explicit TransposeCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(param, inputs, outputs) {} + ~TransposeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc new file mode 100644 index 0000000000..dc93554fa7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/unique.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Unique; + +namespace mindspore::kernel { +int UniqueCPUKernel::Init() { return RET_OK; } + +int UniqueCPUKernel::ReSize() { return RET_OK; } + +int UniqueCPUKernel::Run() { + auto input = reinterpret_cast(inputs_.at(0)->Data()); + auto output0 = reinterpret_cast(outputs_.at(0)->Data()); + auto output1 = reinterpret_cast(outputs_.at(1)->Data()); + + int output0_len = 0; + Unique(input, inputs_.at(0)->ElementsNum(), output0, &output0_len, output1); + + std::vector out_shape = outputs_.at(0)->shape(); + out_shape[out_shape.size() - 1] = output0_len; + outputs_.at(0)->set_shape(out_shape); + return RET_OK; +} + +kernel::LiteKernel *CpuUniqueFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + MS_ASSERT(parameter); + MS_ASSERT(desc.type == PrimitiveType_Unique); + auto *kernel = new (std::nothrow) UniqueCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unique, CpuUniqueFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unique.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.h new file mode 100644 index 0000000000..349b1e63d3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/unique.h" + +namespace mindspore::kernel { +class UniqueCPUKernel : public LiteKernel { + public: + UniqueCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~UniqueCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc new file mode 100644 index 0000000000..2555605caa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/unsqueeze.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Unsqueeze; + +namespace mindspore::kernel { +int UnsqueezeCPUKernel::Init() { + int ret = ReSize(); + return ret; +} + +int UnsqueezeCPUKernel::ReSize() { + data_size_ = inputs_.at(0)->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int UnsqueezeCPUKernel::DoUnsqueeze(int task_id) { + size_t size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size == 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int UnsqueezeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoUnsqueeze(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int UnsqueezeCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_.at(0)->Data()); + out_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); + int ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze); + auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h new file mode 100644 index 0000000000..77e60e0f57 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class UnsqueezeCPUKernel : public LiteKernel { + public: + UnsqueezeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + ~UnsqueezeCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoUnsqueeze(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + float *in_ptr_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc new file mode 100644 index 0000000000..8facbc7a9d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/unstack.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Unstack; + +namespace mindspore::kernel { +int UnstackCPUKernel::Init() { + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + size_t shape_size = input->shape().size(); + + auto para = reinterpret_cast(opParameter); + para->pre_dims_ = 1; + para->axis_dim_ = 1; + para->after_dims_ = 1; + if (para->axis_ < 0) { + para->axis_ += shape_size; + } + for (size_t i = 0; i < shape_size; i++) { + if (i < para->axis_) { + para->pre_dims_ *= input->DimensionSize(i); + } else if (i > para->axis_) { + para->after_dims_ *= input->DimensionSize(i); + } else { + para->axis_dim_ = input->DimensionSize(i); + } + } + + output_addr_array_ = reinterpret_cast(malloc(sizeof(float *) * outputs_.size())); + if (output_addr_array_ == nullptr) { + MS_LOG(ERROR) << "Failed to malloc memory"; + return lite::RET_ERROR; + } + return RET_OK; +} + +int UnstackCPUKernel::ReSize() { return RET_OK; } + +int UnstackCPUKernel::Run() { + float *input = reinterpret_cast(inputs_.at(0)->Data()); + size_t out_num = outputs_.size(); + for (size_t i = 0; i < out_num; i++) { + output_addr_array_[i] = reinterpret_cast(outputs_.at(i)->Data()); + } + Unistack(input, output_addr_array_, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuUnstackFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + MS_ASSERT(desc.type == PrimitiveType_Unstack); + auto *kernel = new (std::nothrow) UnstackCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unstack, CpuUnstackFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h new file mode 100644 index 0000000000..e652ad6adf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/unstack.h" + +namespace mindspore::kernel { +class UnstackCPUKernel : public LiteKernel { + public: + UnstackCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~UnstackCPUKernel() { + free(output_addr_array_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float **output_addr_array_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc new file mode 100644 index 0000000000..f2bf03fc46 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/where.h" +#include +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/nnacl/where.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Where; + +namespace mindspore::kernel { +int WhereCPUKernel::Init() { + where_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +int WhereCPUKernel::DoExcute(int task_id) { + Where(input_data, input_data1, input_data2, output_data, where_param_, task_id); + return RET_OK; +} + +int WhereRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto wheredata = reinterpret_cast(cdata); + auto ret = wheredata->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "WhereRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} +int WhereCPUKernel::Run() { + auto input = inputs_.at(0); + auto input1 = inputs_.at(1); + auto input2 = inputs_.at(2); + int num = input->ElementsNum(); + int num1_ = input1->ElementsNum(); + int num2_ = input2->ElementsNum(); + + input_data = reinterpret_cast(input->Data()); + input_data1 = reinterpret_cast(input1->Data()); + input_data2 = reinterpret_cast(input2->Data()); + output_data = reinterpret_cast(outputs_.at(0)->Data()); + int num_max = num > num1_ ? num : (num1_ > num2_ ? num1_ : num2_); + where_param_->num_ = num; + where_param_->num1_ = num1_; + where_param_->num2_ = num2_; + where_param_->number_ = num_max; + + if (((num != 1) && (num != num_max)) || ((num1_ != 1) && (num1_ != num_max)) || + ((num2_ != 1) && (num2_ != num_max))) { + MS_LOG(ERROR) << "The length of three inputs are not equal to 1 or length of output, which is unacceptable"; + return RET_ERROR; + } + if (num_max <= 0) { + MS_LOG(ERROR) << "Error, inputs' length are zero !!!"; + return RET_ERROR; + } + auto ret = LiteBackendParallelLaunch(WhereRun, this, where_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "WhereDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuWhereFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Where); + auto *kernel = new (std::nothrow) WhereCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new WhereCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Where, CpuWhereFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where.h b/mindspore/lite/src/runtime/kernel/arm/fp32/where.h new file mode 100644 index 0000000000..ad9c73a9fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/where.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class WhereCPUKernel : public LiteKernel { + public: + WhereCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + where_param_ = reinterpret_cast(opParameter); + } + ~WhereCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + WhereParameter *where_param_; + + private: + bool *input_data; + float *input_data1; + float *input_data2; + float *output_data; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc new file mode 100644 index 0000000000..6fcd25dc48 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/zeroslike.h" +#include +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/nnacl/zeroslike.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ZerosLike; + +namespace mindspore::kernel { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; + +int ZerosLikeCPUKernel::Init() { return RET_OK; } + +int ZerosLikeCPUKernel::Run() { + auto input = inputs_.at(0); + auto input_data = reinterpret_cast(input->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + ApproximateZerosLike(input_data, output_data, input->ElementsNum()); + return RET_OK; +} + +kernel::LiteKernel *CpuZerosLikeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_ZerosLike); + auto *kernel = new (std::nothrow) ZerosLikeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ZerosLikeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ZerosLike, CpuZerosLikeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h new file mode 100644 index 0000000000..cd6aad2667 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class ZerosLikeCPUKernel : public LiteKernel { + public: + ZerosLikeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + + ~ZerosLikeCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc b/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc new file mode 100644 index 0000000000..32d933d5e8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/activation.h" +#include "src/runtime/kernel/arm/int8/relu_int8.h" +#include "src/runtime/kernel/arm/int8/hswish_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Activation; + +namespace mindspore::kernel { +kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "parameter is nullptr"; + return nullptr; + } + MS_ASSERT(inputs.at(0)); + auto type = (reinterpret_cast(parameter))->type_; + kernel::LiteKernel *kernel = nullptr; + switch (static_cast(type)) { + case schema::ActivationType_RELU: + kernel = new (std::nothrow) ReluInt8CPUKernel(parameter, inputs, outputs, ctx); + break; + case schema::ActivationType_HSWISH: + kernel = new (std::nothrow) HswishInt8CPUKernel(parameter, inputs, outputs, ctx); + break; + default: + break; + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Activation, CpuActivationInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc new file mode 100644 index 0000000000..18bb5e80c6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/add_int8.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Add; + +namespace mindspore::kernel { +int QuantizedAddCPUKernel::Init() { + lite::tensor::Tensor *input0 = inputs_.at(0); + lite::tensor::Tensor *input1 = inputs_.at(1); + lite::tensor::Tensor *output = outputs_.at(0); + MS_ASSERT(input0); + MS_ASSERT(input1); + MS_ASSERT(output); + + para_.input0_scale_ = input0->GetQuantParams().front().scale; + para_.input0_offset_ = input0->GetQuantParams().front().zeroPoint * -1; + para_.input1_scale_ = input1->GetQuantParams().front().scale; + para_.input1_offset_ = input1->GetQuantParams().front().zeroPoint * -1; + para_.output_scale_ = output->GetQuantParams().front().scale; + para_.output_offset_ = output->GetQuantParams().front().zeroPoint; + + const int left_shift = 20; // 1 << 20, 2/20 + const double twice_max_input_scale = 2 * std::max(para_.input0_scale_, para_.input1_scale_); + const double real_input0_multiplier = para_.input0_scale_ / twice_max_input_scale; + const double real_input1_multiplier = para_.input1_scale_ / twice_max_input_scale; + const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * para_.output_scale_); + + QuantizeMultiplierSmallerThanOne(real_input0_multiplier, ¶_.input0_multiplier_, ¶_.input0_shift_); + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, ¶_.input1_multiplier_, ¶_.input1_shift_); + QuantizeMultiplierSmallerThanOne(real_output_multiplier, ¶_.output_multiplier_, ¶_.output_shift_); + + para_.output_activation_min_ = std::numeric_limits::min(); + para_.output_activation_max_ = std::numeric_limits::max(); + + int left_shift0 = -para_.input0_shift_ > 0 ? -para_.input0_shift_ : 0; + para_.right_shift0_ = -para_.input0_shift_ > 0 ? 0 : para_.input0_shift_; + + int left_shift1 = -para_.input1_shift_ > 0 ? -para_.input1_shift_ : 0; + para_.right_shift1_ = -para_.input1_shift_ > 0 ? 0 : para_.input1_shift_; + + para_.left_shift_out_ = -para_.output_shift_ > 0 ? -para_.output_shift_ : 0; + para_.right_shift_out_ = -para_.output_shift_ > 0 ? 0 : para_.output_shift_; + + para_.left_shift_result0_ = (1 << left_shift) * ((1 << left_shift0)); + para_.left_shift_result1_ = (1 << left_shift) * ((1 << left_shift1)); + + MS_ASSERT(left_shift + left_shift0 == left_shift); + MS_ASSERT(left_shift + left_shift1 == left_shift); + return 0; +} + +int QuantizedAddCPUKernel::ReSize() { return 0; } + +int QuantizedAddCPUKernel::Run() { + input0_data_ = static_cast(inputs_.at(0)->Data()); + input1_data_ = static_cast(inputs_.at(1)->Data()); + output_data_ = static_cast(outputs_.at(0)->Data()); + + elements_num_ = inputs_.at(0)->ElementsNum(); + count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; + + if (inputs_.at(0)->ElementsNum() != inputs_.at(1)->ElementsNum()) { + input0_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + input1_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + + ArithmeticParameter tile_para = {0}; + tile_para.ndim_ = outputs_.at(0)->shape().size(); + for (size_t i = 0; i < tile_para.ndim_; i++) { + tile_para.in_shape0_[i] = inputs_.at(0)->DimensionSize(i); + tile_para.in_shape1_[i] = inputs_.at(1)->DimensionSize(i); + tile_para.out_shape_[i] = outputs_.at(0)->DimensionSize(i); + } + TileDimensionsUint8(static_cast(inputs_.at(0)->Data()), static_cast(inputs_.at(1)->Data()), + reinterpret_cast(input0_data_), reinterpret_cast(input1_data_), + &tile_para); + auto ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); + ctx_->allocator->Free(input0_data_); + ctx_->allocator->Free(input1_data_); + return ret; + } + + auto ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); + return ret; +} + +int AddInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto add = reinterpret_cast(cdata); + add->DoExecute(task_id); + return lite::RET_OK; +} + +int QuantizedAddCPUKernel::DoExecute(int tId) { + int64_t real_dst_count = MSMIN(elements_num_ - tId * count_unit_, count_unit_); + int8_t *cur_input0_data = input0_data_ + tId * count_unit_; + int8_t *cur_input1_data = input1_data_ + tId * count_unit_; + int8_t *cur_output_data = output_data_ + tId * count_unit_; + + AddInt8(cur_input0_data, cur_input1_data, cur_output_data, real_dst_count, ¶_); + return lite::RET_OK; +} + +kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Add); + auto *kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Add, CpuAddInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h new file mode 100644 index 0000000000..b57aab2279 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ADD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ADD_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/add_int8.h" +#include "src/runtime/runtime_api.h" + +namespace mindspore::kernel { +class QuantizedAddCPUKernel : public LiteKernel { + public: + explicit QuantizedAddCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx_->thread_num_) {} + ~QuantizedAddCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tId); + + private: + const lite::Context *ctx_; + AddQuantParameter para_; + int thread_count_; + int64_t elements_num_; + int64_t count_unit_; + int8_t *input0_data_ = nullptr; + int8_t *input1_data_ = nullptr; + int8_t *output_data_ = nullptr; +}; + +int AddInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ADD_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc new file mode 100644 index 0000000000..03cae3aa86 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/int8/argminmax_int8.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_OK; +using mindspore::lite::RET_ERROR; + +namespace mindspore::kernel { +int ArgMinMaxInt8CPUKernel::Init() { + auto ret = ArgMinMaxBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + auto param = reinterpret_cast(opParameter); + param->data_type_ = kNumberTypeInt8; + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + in_quant_arg_.scale_ = in_quant_args.front().scale; + in_quant_arg_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + out_quant_arg_.scale_ = out_quant_args.front().scale; + out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + return RET_OK; +} + +int ArgMinMaxInt8CPUKernel::Run() { + auto input = inputs_.at(0); + + const int8_t *input_data = reinterpret_cast(inputs_.at(0)->Data()); + int8_t *output_data = reinterpret_cast(outputs_.at(0)->Data()); + + auto in_shape = input->shape().data(); + auto param = reinterpret_cast(opParameter); + if (param->topk_ == 1) { + ArgMinMaxQuant(input_data, output_data, in_shape, param, &in_quant_arg_, &out_quant_arg_); + return RET_OK; + } + + switch (param->axis_) { + case 0: + ArgMinMaxDim0(input_data, output_data, in_shape, param, &in_quant_arg_, &out_quant_arg_); + break; + case 1: + ArgMinMaxDim1(input_data, output_data, in_shape, param, &in_quant_arg_, &out_quant_arg_); + break; + case 2: + ArgMinMaxDim2(input_data, output_data, in_shape, param, &in_quant_arg_, &out_quant_arg_); + break; + case 3: + ArgMinMaxDim3(input_data, output_data, in_shape, param, &in_quant_arg_, &out_quant_arg_); + break; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h new file mode 100644 index 0000000000..919acd2037 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_ + +#include +#include "src/runtime/kernel/arm/base/arg_min_max_base.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +namespace mindspore::kernel { +class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel { + public: + ArgMinMaxInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~ArgMinMaxInt8CPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + private: + QuantArg in_quant_arg_; + QuantArg out_quant_arg_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc new file mode 100644 index 0000000000..3c01ca389b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/arithmetic_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +using mindspore::schema::PrimitiveType_Equal; +using mindspore::schema::PrimitiveType_NotEqual; +using mindspore::schema::PrimitiveType_LessEqual; +using mindspore::schema::PrimitiveType_Greater; +using mindspore::schema::PrimitiveType_GreaterEqual; +using mindspore::schema::PrimitiveType_Less; + +namespace mindspore::kernel { +namespace { +int ArithmeticsInt8Launch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { + auto arithmetic_kernel = reinterpret_cast(cdata); + auto error_code = arithmetic_kernel->DoArithmetic(thread_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRun error thread_id[" << thread_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace + +ArithmeticInt8CPUKernel::~ArithmeticInt8CPUKernel() { + auto param = reinterpret_cast(opParameter); + if (!param->broadcasting_) { + return; + } + if (context_->allocator != nullptr) { + if (tile_data0_ != nullptr) { + context_->allocator->Free(tile_data0_); + } + if (tile_data1_ != nullptr) { + context_->allocator->Free(tile_data1_); + } + } else { + if (tile_data0_ != nullptr) { + free(tile_data0_); + } + if (tile_data1_ != nullptr) { + free(tile_data1_); + } + } + tile_data0_ = nullptr; + tile_data1_ = nullptr; +} + +int ArithmeticInt8CPUKernel::Init() { + switch (opParameter->type_) { + case PrimitiveType_Equal: + arithmetic_run_ = ElementEqual; + break; + case PrimitiveType_NotEqual: + arithmetic_run_ = ElementNotEqual; + break; + case PrimitiveType_Less: + arithmetic_run_ = ElementLess; + break; + case PrimitiveType_LessEqual: + arithmetic_run_ = ElementLessEqual; + break; + case PrimitiveType_Greater: + arithmetic_run_ = ElementGreater; + break; + case PrimitiveType_GreaterEqual: + arithmetic_run_ = ElementGreaterEqual; + break; + default: + MS_LOG(ERROR) << "Error Operator type " << opParameter->type_; + arithmetic_run_ = nullptr; + return RET_PARAM_INVALID; + } + auto data_size = outputs_[0]->Size(); + auto param = reinterpret_cast(opParameter); + if (param->broadcasting_) { + if (context_->allocator != nullptr) { + tile_data0_ = reinterpret_cast(context_->allocator->Malloc(data_size)); + tile_data1_ = reinterpret_cast(context_->allocator->Malloc(data_size)); + } else { + tile_data0_ = reinterpret_cast(malloc(data_size)); + tile_data1_ = reinterpret_cast(malloc(data_size)); + } + } else { + tile_data0_ = nullptr; + tile_data1_ = nullptr; + } + return RET_OK; +} + +int ArithmeticInt8CPUKernel::ReSize() { return RET_OK; } + +int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) { + auto input0_data = reinterpret_cast(inputs_[0]->Data()); + auto input1_data1 = reinterpret_cast(inputs_[1]->Data()); + auto output_data = reinterpret_cast(outputs_[0]->Data()); + auto element_num = outputs_[0]->ElementsNum(); + auto param = reinterpret_cast(opParameter); + if (param->broadcasting_ && arithmetic_run_ != nullptr) { + MS_ASSERT(thread_count_ != 0); + int stride = UP_DIV(element_num, thread_count_); + int count = MSMIN(stride, element_num - stride * thread_id); + + int error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id, + output_data + stride * thread_id, count); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Arithmetic run fail! ret: " << error_code; + return RET_ERROR; + } + } else if (arithmetic_run_ != nullptr) { + int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Arithmetic run fail!ret: " << error_code; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticInt8CPUKernel::Run() { + auto param = reinterpret_cast(opParameter); + if (param->broadcasting_) { + auto input_data0 = reinterpret_cast(inputs_[0]->Data()); + auto input_data1 = reinterpret_cast(inputs_[1]->Data()); + TileDimensionsInt8(input_data0, input_data1, tile_data0_, tile_data1_, param); + } + int error_code = LiteBackendParallelLaunch(ArithmeticsInt8Launch, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Arithmetic launch function fail! ret: " << error_code; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "Input parameter is null!"; + return nullptr; + } + auto kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create ArithmeticInt8CPUKernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Equal, CpuArithmeticInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_NotEqual, CpuArithmeticInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Less, CpuArithmeticInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LessEqual, CpuArithmeticInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Greater, CpuArithmeticInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GreaterEqual, CpuArithmeticInt8KernelCreator) + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h new file mode 100644 index 0000000000..56ebcd7e0b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + +namespace mindspore::kernel { +class ArithmeticInt8CPUKernel : public LiteKernel { + typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + + public: + ArithmeticInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_), context_(ctx) {} + ~ArithmeticInt8CPUKernel(); + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmetic(int thread_id); + + private: + int thread_count_; + int8_t *tile_data0_; + int8_t *tile_data1_; + const lite::Context *context_; + ArithmeticRunInt8 arithmetic_run_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc new file mode 100644 index 0000000000..1688f27bb4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/arithmetic_self_int8.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int ArithmeticSelfInt8CPUKernel::Init() { + int ret = ReSize(); + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + para_->quant_arg_.in_args_.scale_ = in_quant_args.front().scale; + para_->quant_arg_.in_args_.zp_ = in_quant_args.front().zeroPoint * (-1); + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + para_->quant_arg_.out_args_.scale_ = out_quant_args.front().scale; + para_->quant_arg_.out_args_.zp_ = out_quant_args.front().zeroPoint; + + para_->quant_arg_.output_activation_max_ = std::numeric_limits::max(); + para_->quant_arg_.output_activation_min_ = std::numeric_limits::min(); + + if (para_->op_parameter_.type_ == PrimitiveType_Square) { + const double real_multiplier = + (para_->quant_arg_.in_args_.scale_ * para_->quant_arg_.in_args_.scale_) / para_->quant_arg_.out_args_.scale_; + + int right_shift = 0; + QuantizeMultiplierSmallerThanOne(real_multiplier, ¶_->quant_arg_.output_multiplier_, &right_shift); + + para_->quant_arg_.shift_left_ = right_shift < 0 ? -right_shift : 0; + para_->quant_arg_.shift_right_ = right_shift > 0 ? right_shift : 0; + } + + return ret; +} + +int ArithmeticSelfInt8CPUKernel::ReSize() { + data_size_ = inputs_[0]->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int ArithmeticSelfInt8Runs(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoArithmeticSelf(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ArithmeticSelfInt8CPUKernel::DoArithmeticSelf(int task_id) { + int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + if (arithmeticSelf_run_) { + auto ret = arithmeticSelf_run_(in_ptr_ + offset, out_ptr_ + offset, size, para_->quant_arg_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run failed, illegal input! "; + return ret; + } + } else { + MS_LOG(ERROR) << "Run function is null! "; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticSelfInt8CPUKernel::Run() { + auto input_tensor = inputs_.at(0); + auto out_tensor = outputs_.at(0); + in_ptr_ = reinterpret_cast(input_tensor->Data()); + out_ptr_ = reinterpret_cast(out_tensor->Data()); + int ret = LiteBackendParallelLaunch(ArithmeticSelfInt8Runs, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticSelfInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Creator failed, opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) ArithmeticSelfInt8CPUKernel(opParameter, inputs, outputs, ctx); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Round, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Floor, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Ceil, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Abs, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sin, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cos, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Log, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h new file mode 100644 index 0000000000..7dfe7d14ad --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_SELF_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_SELF_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h" +#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.h" +#include "schema/model_generated.h" +#include "include/context.h" + + +using mindspore::lite::Context; +using mindspore::schema::PrimitiveType_Round; +using mindspore::schema::PrimitiveType_Floor; +using mindspore::schema::PrimitiveType_Ceil; +using mindspore::schema::PrimitiveType_Abs; +using mindspore::schema::PrimitiveType_Sin; +using mindspore::schema::PrimitiveType_Cos; +using mindspore::schema::PrimitiveType_Log; +using mindspore::schema::PrimitiveType_Sqrt; +using mindspore::schema::PrimitiveType_Rsqrt; +using mindspore::schema::PrimitiveType_Square; +using mindspore::schema::PrimitiveType_LogicalNot; + +namespace mindspore::kernel { +class ArithmeticSelfInt8CPUKernel : public LiteKernel { + typedef int (*ArithmeticSelfInt8Run)(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + + public: + explicit ArithmeticSelfInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + switch (parameter->type_) { + case PrimitiveType_Round: + arithmeticSelf_run_ = ElementRound; + break; + case PrimitiveType_Floor: + arithmeticSelf_run_ = ElementFloor; + break; + case PrimitiveType_Ceil: + arithmeticSelf_run_ = ElementCeil; + break; + case PrimitiveType_Abs: + arithmeticSelf_run_ = ElementAbs; + break; + case PrimitiveType_Sin: + arithmeticSelf_run_ = ElementSin; + break; + case PrimitiveType_Cos: + arithmeticSelf_run_ = ElementCos; + break; + case PrimitiveType_Log: + arithmeticSelf_run_ = ElementLog; + break; + case PrimitiveType_Sqrt: + arithmeticSelf_run_ = ElementSqrt; + break; + case PrimitiveType_Rsqrt: + arithmeticSelf_run_ = ElementRsqrt; + break; + case PrimitiveType_Square: + arithmeticSelf_run_ = ElementSquare; + break; + case PrimitiveType_LogicalNot: + arithmeticSelf_run_ = ElementLogicalNot; + break; + default: + break; + } + para_ = reinterpret_cast(parameter); + } + ~ArithmeticSelfInt8CPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmeticSelf(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + size_t data_size_; + ArithmeticSelfParameter *para_; + ArithmeticSelfInt8Run arithmeticSelf_run_; + const Context *ctx_; + int8_t *in_ptr_; + int8_t *out_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_SELF_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc new file mode 100644 index 0000000000..bdf49bf14e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/batch_to_space.h" +#include "src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int BatchToSpaceInt8CPUKernel::Init() { + auto ret = BatchToSpaceBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + in_quant_arg_.scale_ = in_quant_args.front().scale; + in_quant_arg_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + out_quant_arg_.scale_ = out_quant_args.front().scale; + out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + return RET_OK; +} + +int BatchToSpaceInt8CPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + const int8_t *input_data = reinterpret_cast(input->Data()); + int8_t *output_data = reinterpret_cast(output->Data()); + auto in_shape = input->shape(); + auto out_shape = output->shape(); + BatchToSpaceParameter *param = reinterpret_cast(this->opParameter); + + if (in_quant_arg_.scale_ == out_quant_arg_.scale_ && in_quant_arg_.zp_ == out_quant_arg_.zp_) { + if (IsNoCrop()) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, + sizeof(int8_t)); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_, + sizeof(int8_t)); + } + } else { + if (IsNoCrop()) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, + &in_quant_arg_, &out_quant_arg_); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_, + &in_quant_arg_, &out_quant_arg_); + } + } + + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h new file mode 100644 index 0000000000..17f30f004f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BATCH_TO_SPACE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BATCH_TO_SPACE_INT8_H_ + +#include +#include "src/runtime/kernel/arm/base/batch_to_space_base.h" + +namespace mindspore::kernel { +class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel { + public: + BatchToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~BatchToSpaceInt8CPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + private: + QuantArg in_quant_arg_; + QuantArg out_quant_arg_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BATCH_TO_SPACE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc new file mode 100644 index 0000000000..f80daade18 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/bias_add_int8.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BiasAdd; + +namespace mindspore::kernel { +int BiasAddInt8CPUKernel::Init() { + auto bias_param = reinterpret_cast(opParameter); + auto dims = inputs_[0]->shape(); + bias_param->ndim_ = dims.size(); + for (int i = 0; i < bias_param->ndim_; i++) { + bias_param->in_shape0_[i] = dims[i]; + bias_param->in_shape1_[i] = 1; + bias_param->out_shape_[i] = dims[i]; + } + bias_param->in_shape1_[3] = dims[3]; + return NNACL_OK; +} + +int BiasAddInt8CPUKernel::ReSize() { return NNACL_OK; } + +int BiasAddInt8CPUKernel::Run() { + auto in = reinterpret_cast(inputs_.at(0)->Data()); + auto bias = reinterpret_cast(inputs_.at(1)->Data()); + auto out = reinterpret_cast(outputs_.at(0)->Data()); + size_t data_size = inputs_.at(0)->ElementsNum(); + auto tile_in = static_cast(ctx_->allocator->Malloc(data_size)); + auto tile_bias = static_cast(ctx_->allocator->Malloc(data_size)); + if (tile_in == nullptr || tile_bias == nullptr) { + MS_LOG(ERROR) << "Failed to malloc momery"; + return NNACL_ERR; + } + BroadcastAddInt8(in, bias, tile_in, tile_bias, out, data_size, reinterpret_cast(opParameter)); + ctx_->allocator->Free(tile_in); + ctx_->allocator->Free(tile_bias); + return NNACL_OK; +} + +kernel::LiteKernel *CpuBiasAddInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or context is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_BiasAdd); + auto *kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BiasAdd, CpuBiasAddInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h new file mode 100644 index 0000000000..3ced965318 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/unique.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +namespace mindspore::kernel { +class BiasAddInt8CPUKernel : public LiteKernel { + public: + BiasAddInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx) {} + ~BiasAddInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + const lite::Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc new file mode 100644 index 0000000000..ea88847f16 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/concat_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/concat_int8.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int ConcatInt8CPUKernel::Init() { + ConcatBaseCPUKernel::Init(); + quant_concat_parm_ = concat_param_->concat_quant_arg_; + quant_concat_parm_ = new (std::nothrow) ConcatQuantArg; + auto input_num = inputs_.size(); + quant_concat_parm_->input_num_ = input_num; + quant_concat_parm_->input_sizes_ = reinterpret_cast(malloc(sizeof(int) * input_num)); + if (quant_concat_parm_->input_sizes_ == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->input_sizes_."; + return RET_ERROR; + } + + for (size_t i = 0; i < input_num; i++) { + quant_concat_parm_->input_sizes_[i] = 1; + } + quant_concat_parm_->input_shapes_ = reinterpret_cast(malloc(sizeof(int *) * input_num)); + if (quant_concat_parm_->input_shapes_ == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->input_shapes_."; + return RET_ERROR; + } + + for (size_t i = 0; i < input_num; i++) { + auto *input_tensor = inputs_.at(i); + MS_ASSERT(input_tensor != nullptr); + auto input_size = input_tensor->shape().size(); + MS_ASSERT(input_size != NULL); + quant_concat_parm_->input_shapes_[i] = reinterpret_cast(malloc(sizeof(int) * input_size)); + if (quant_concat_parm_->input_shapes_[i] == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->input_shapes_[" << i << "]."; + return RET_ERROR; + } + + ::memcpy(quant_concat_parm_->input_shapes_[i], input_tensor->shape().data(), sizeof(int) * input_size); + for (size_t j = 0; j < input_size; j++) { + auto *input_tensor_tmp = inputs_.at(i); + auto input_shape = input_tensor_tmp->shape()[j]; + quant_concat_parm_->input_sizes_[i] *= input_shape; + } + } + + quant_concat_parm_->in_quant_args_ = reinterpret_cast(malloc(sizeof(QuantArg) * input_num)); + if (quant_concat_parm_->in_quant_args_ == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->in_quant_args_."; + return RET_ERROR; + } + + for (size_t i = 0; i < input_num; i++) { + auto *input_tensor = inputs_.at(i); + auto quant_args = input_tensor->GetQuantParams(); + MS_ASSERT(quant_args.size() == 1); + quant_concat_parm_->in_quant_args_[i].scale_ = quant_args.front().scale; + quant_concat_parm_->in_quant_args_[i].zp_ = quant_args.front().zeroPoint; + } + + MS_ASSERT(outputs_.size() == 1); + auto output_tensor = outputs_.at(0); + MS_ASSERT(output_tensor != nullptr); + auto output_shape = output_tensor->shape(); + MS_ASSERT(output_shape != NULL); + auto output_dim = output_shape.size(); + quant_concat_parm_->output_dim_ = output_dim; + int output_size = 1; + for (size_t i = 0; i < output_dim; i++) { + output_size *= output_shape[i]; + } + quant_concat_parm_->output_size_ = output_size; + + quant_concat_parm_->output_shape_ = new int[output_size]; + ::memcpy(quant_concat_parm_->output_shape_, output_shape.data(), sizeof(int) * output_size); + + auto quant_args = output_tensor->GetQuantParams(); + MS_ASSERT(quant_args.size() == 1); + quant_concat_parm_->out_quant_args_.scale_ = quant_args.front().scale; + quant_concat_parm_->out_quant_args_.zp_ = quant_args.front().zeroPoint; + + return RET_OK; +} + +int ConcatInt8CPUKernel::ReSize() { return 0; } + +int ConcatInt8CPUKernel::Run() { + auto input_dim = quant_concat_parm_->input_num_; + int8_t **inputs_array = reinterpret_cast(malloc(sizeof(int8_t *) * input_dim)); + for (size_t i = 0; i < input_dim; i++) { + auto input_size = quant_concat_parm_->input_sizes_[i]; + inputs_array[i] = reinterpret_cast(malloc(sizeof(int8_t) * input_size)); + auto input_type = inputs_[i]->data_type(); + if (input_type == kNumberTypeUInt8) { + uint8_t *input_tmp = reinterpret_cast(inputs_[i]->Data()); + for (size_t j = 0; j < input_size; j++) { + inputs_array[i][j] = (int8_t)(input_tmp[j] - 128); + } + for (size_t j = 0; j < input_dim; j++) { + quant_concat_parm_->in_quant_args_[j].zp_ -= 128; + } + quant_concat_parm_->out_quant_args_.zp_ -= 128; + } else { + ::memcpy(inputs_array[i], inputs_.at(i)->Data(), sizeof(int8_t) * input_size); + } + } + int8_t *output_addr = reinterpret_cast(outputs_.at(0)->Data()); + Concat(inputs_array, output_addr, quant_concat_parm_, axis_); + auto output_type = outputs_[0]->data_type(); + if (output_type == kNumberTypeUInt8) { + auto output_size = quant_concat_parm_->output_size_; + for (size_t i = 0; i < output_size; i++) { + output_addr[i] = (uint8_t)(output_addr[i] + 128); + } + } + + for (int i = 0; i < input_dim; i++) { + free(*(inputs_array + i)); + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h new file mode 100644 index 0000000000..192f50b46c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/concat_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ConcatInt8CPUKernel : public ConcatBaseCPUKernel { + public: + ConcatInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConcatInt8CPUKernel() override { delete quant_concat_parm_; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + ConcatQuantArg *quant_concat_parm_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc new file mode 100644 index 0000000000..e3c9471514 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc @@ -0,0 +1,245 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/conv_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParameter *conv_param) { + auto input_channel = conv_param->input_channel_; + auto output_channel = conv_param->output_channel_; + auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int iC8 = UP_DIV(input_channel, C8NUM); + + size_t tmp_size = output_channel * iC8 * C8NUM * kernel_plane * sizeof(int16_t); + auto tmp_addr = reinterpret_cast(malloc(tmp_size)); + memset(tmp_addr, 0, tmp_size); + PackWeightToC8Int8(origin_weight, tmp_addr, conv_param); + Conv3x3Int8FilterTransform(tmp_addr, dst_weight, iC8, output_channel, kernel_plane); + + free(tmp_addr); +} + +Convolution3x3Int8CPUKernel::~Convolution3x3Int8CPUKernel() { + if (transformed_filter_addr_ != nullptr) { + free(transformed_filter_addr_); + } + if (input_data_ != nullptr) { + free(input_data_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + FreeQuantParam(); +} + +int Convolution3x3Int8CPUKernel::InitWeightBias() { + auto input_channel = conv_param_->input_channel_; + auto output_channel = conv_param_->output_channel_; + int iC8 = UP_DIV(input_channel, C8NUM); + int oC4 = UP_DIV(output_channel, C4NUM); + // init weight + size_t transformed_size = iC8 * C8NUM * oC4 * C4NUM * 16 * sizeof(int16_t); + transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); + if (transformed_filter_addr_ == nullptr) { + MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; + return RET_ERROR; + } + memset(transformed_filter_addr_, 0, transformed_size); + auto weight_data = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + ProcessFilterUint8(weight_data, transformed_filter_addr_, conv_param_); + + // init bias + size_t new_bias_size = oC4 * C4NUM * sizeof(int32_t); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(int32_t)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::InitTmpBuffer() { + int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM); + int oc4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int in_batch = conv_param_->input_batch_; + int input_w = conv_param_->input_w_; + int input_h = conv_param_->input_h_; + int output_batch = conv_param_->output_batch_; + int output_w = conv_param_->output_w_; + int output_h = conv_param_->output_h_; + + /*=============================tile_buffer_============================*/ + size_t tile_buffer_size = thread_count_ * TILE_NUM * 16 * ic8 * C8NUM * sizeof(int16_t); + tile_buffer_ = reinterpret_cast(malloc(tile_buffer_size)); + if (tile_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tile_buffer_ failed."; + return RET_ERROR; + } + memset(tile_buffer_, 0, tile_buffer_size); + + /*=============================block_unit_buffer_============================*/ + size_t block_unit_buffer_size = thread_count_ * 4 * 4 * C8NUM * sizeof(int16_t); + block_unit_buffer_ = reinterpret_cast(malloc(block_unit_buffer_size)); + if (block_unit_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; + return RET_ERROR; + } + memset(block_unit_buffer_, 0, block_unit_buffer_size); + + /*=============================tmp_dst_buffer_============================*/ + size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * 16 * oc4 * C4NUM * sizeof(int32_t); + tmp_dst_buffer_ = reinterpret_cast(malloc(tmp_dst_buffer_size)); + if (tmp_dst_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; + return RET_ERROR; + } + memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); + + /*=============================tmp_out_============================*/ + size_t tmp_out_size = oc4 * C4NUM * output_batch * output_w * output_h * sizeof(uint8_t); + tmp_out_ = reinterpret_cast(malloc(tmp_out_size)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + memset(tmp_out_, 0, tmp_out_size); + + /*=============================input_data_============================*/ + size_t c8_input_size = in_batch * input_h * input_w * ic8 * C8NUM * sizeof(int16_t); + input_data_ = reinterpret_cast(malloc(c8_input_size)); + if (input_data_ == nullptr) { + MS_LOG(ERROR) << "malloc input_data_ failed."; + return RET_ERROR; + } + memset(input_data_, 0, c8_input_size); + return RET_OK; +} + +void Convolution3x3Int8CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); +} + +int Convolution3x3Int8CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + SetQuantParam(); + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + // config input output + ConfigInputOutput(); + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::ReSize() { + if (input_data_ != nullptr) { + free(input_data_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Conv3x3Int8(input_data_, transformed_filter_addr_, reinterpret_cast(bias_data_), output_addr, tile_buffer_, + block_unit_buffer_, tmp_dst_buffer_, tmp_out_, task_id, conv_param_); + return RET_OK; +} + +int Convolution3x3Int8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution3x3 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::Run() { + auto input_addr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + PackInputToC8Int8(input_addr, input_data_, conv_param_); + + int error_code = LiteBackendParallelLaunch(Convolution3x3Int8Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv3x3 int8 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h new file mode 100644 index 0000000000..ef0a8a5560 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" + +namespace mindspore::kernel { +class Convolution3x3Int8CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution3x3Int8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution3x3Int8CPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + int16_t *transformed_filter_addr_; + int16_t *input_data_; + int16_t *tile_buffer_; + int16_t *block_unit_buffer_; + int32_t *tmp_dst_buffer_; + int8_t *tmp_out_; +}; +void ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParameter *conv_param); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc new file mode 100644 index 0000000000..87fa4e1690 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/convolution_depthwise_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() { + // init weight, int8 -> int16 + // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1 + auto origin_weight = reinterpret_cast(inputs_[kWeightIndex]->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(int16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(int16_t)); + PackDepthwiseInt8Weight(origin_weight, packed_weight_, conv_param_); + + // init bias, add output zp + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(int32_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(int32_t)); + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::InitBuffer() { + // malloc packed input buffer + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * + UP_DIV(conv_param_->input_channel_, 4); + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(int16_t))); + memset(packed_input_, 0, pack_input_size * sizeof(int16_t)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * + UP_DIV(conv_param_->output_channel_, C4NUM); + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(int8_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding window param + sliding = new SlidingWindowParam; + InitSlidingParam(sliding, conv_param_, C4NUM); + + // init quant param + ConvolutionBaseCPUKernel::SetQuantParam(); + + // init weight and bias + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!"; + return ret; + } + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise int8 ReSize error!"; + return ret; + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::ReSize() { + free(packed_input_); + if (need_align_) { + free(packed_output_); + } + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding window param + InitSlidingParam(sliding, conv_param_, C4NUM); + + // init quant param + ConvolutionBaseCPUKernel::SetQuantParam(); + + auto ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise int8 ReSize error!"; + return ret; + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::Execute(int task_id) { + ConvDwInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding, task_id); + return RET_OK; +} + +int ConvDwInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw_int8 = reinterpret_cast(cdata); + auto ret = conv_dw_int8->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwiseInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + + // pack input, assume input format: NHWC -> NHWC4 + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_); + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(ConvDwInt8Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDwInt8Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} + +kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DepthwiseConv2D, CpuConvDwInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h new file mode 100644 index 0000000000..e0ce121dad --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class ConvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwiseInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionDepthwiseInt8CPUKernel() override { + delete sliding; + free(packed_weight_); + free(packed_input_); + if (need_align_) { + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitWeightBias(); + int InitBuffer(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding; + int16_t *packed_weight_; + int16_t *packed_input_; + int8_t *packed_output_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc new file mode 100644 index 0000000000..00415cc542 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -0,0 +1,432 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/convolution_int8.h" +#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/conv_int8.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ConvolutionInt8CPUKernel::CheckSupportOptimize() { + tile_num_ = 24; +#ifdef ENABLE_ARM32 + tile_num_ = 2; + support_optimize_ = false; +#endif + +#ifdef ENABLE_ARM64 + void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; + if (optimize_op_handler != nullptr) { + dlerror(); + *(reinterpret_cast(&gemm_func_)) = dlsym(optimize_op_handler, "IndirectGemmInt8_optimize_handler"); + auto dlopen_error = dlerror(); + if (dlopen_error != nullptr) { + MS_LOG(ERROR) << "load gemm func failed! " << dlopen_error << "."; + tile_num_ = 4; + support_optimize_ = false; + gemm_func_ = nullptr; + } else { + // do nothing + } + } else { + tile_num_ = 4; + support_optimize_ = false; + } +#endif + conv_param_->tile_num_ = tile_num_; +} + +int ConvolutionInt8CPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * plane_c4 * C4NUM; + int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_; + int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_; + + // init weight + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + packed_weight_ = reinterpret_cast(malloc(pack_weight_size)); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size); + auto *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + for (int i = 0; i < out_channel; i++) weight_sum[i] = 0; + PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc4 * C4NUM * sizeof(int32_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc4 * C4NUM * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, out_channel * sizeof(int32_t)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + auto *bias_data = reinterpret_cast(bias_data_); + int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM; + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } + free(weight_sum); + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitTmpBuffer() { + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, tile_num_); + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + int unit_size = plane_c4 * C4NUM * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_num_ * unit_size; + + /*=============================packed_input_============================*/ + packed_input_ = reinterpret_cast(malloc(conv_param_->input_batch_ * packed_input_size)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_input_ failed."; + return RET_ERROR; + } + memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); + + /*=============================input_sum_============================*/ + input_sum_ = reinterpret_cast(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); + if (input_sum_ == nullptr) { + MS_LOG(ERROR) << "malloc input_sum_ failed."; + return RET_ERROR; + } + memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); + + /*=============================tmp_dst_============================*/ + size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); + tmp_dst_ = reinterpret_cast(malloc(tmp_dst_size)); + if (tmp_dst_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_ failed."; + return RET_ERROR; + } + memset(tmp_dst_, 0, tmp_dst_size); + + /*=============================tmp_out_============================*/ + tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4 input failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; + int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_; + int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_; + + // init weight + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + packed_weight_ = reinterpret_cast(malloc(pack_weight_size)); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, filter_zp, pack_weight_size); + auto *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane; + PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc4 * C4NUM * sizeof(int32_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc4 * C4NUM * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, out_channel * sizeof(int32_t)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + auto *bias_data = reinterpret_cast(bias_data_); + int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM; + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } + free(weight_sum); + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, tile_num_); + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_num_ * unit_size; + + /*=============================packed_input_============================*/ + packed_input_ = reinterpret_cast(malloc(conv_param_->input_batch_ * packed_input_size)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_input_ failed."; + return RET_ERROR; + } + memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); + + /*=============================input_sum_============================*/ + input_sum_ = reinterpret_cast(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); + if (input_sum_ == nullptr) { + MS_LOG(ERROR) << "malloc input_sum_ failed."; + return RET_ERROR; + } + memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); + + /*=============================tmp_dst_============================*/ + size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); + tmp_dst_ = reinterpret_cast(malloc(tmp_dst_size)); + if (tmp_dst_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_ failed."; + return RET_ERROR; + } + memset(tmp_dst_, 0, tmp_dst_size); + + /*=============================tmp_out_============================*/ + tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4 input failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +void ConvolutionInt8CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +} + +int ConvolutionInt8CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + // config input output + ConfigInputOutput(); + CheckSupportOptimize(); + SetQuantParam(); + // init for opt + if (support_optimize_) { + ret = InitOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Initialization for optimized int8 conv failed."; + return RET_ERROR; + } + return RET_OK; + } + + // init for situation that not support sdot + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitOpt() { + auto ret = InitWeightBiasOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBufferOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::ReSize() { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (input_sum_ != nullptr) { + free(input_sum_); + } + if (tmp_dst_ != nullptr) { + free(tmp_dst_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + if (support_optimize_) { + ret = InitTmpBufferOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer for opt failed."; + return RET_ERROR; + } + return RET_OK; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (support_optimize_) { + ConvInt8Opt(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_dst_, tmp_out_, output_addr, input_sum_, task_id, + conv_param_, gemm_func_); + } else { + ConvInt8(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_dst_, tmp_out_, output_addr, input_sum_, task_id, + conv_param_); + } + return RET_OK; +} + +int ConvolutionInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionInt8Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv int8 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + kernel::LiteKernel *kernel; + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx); + } else { + kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx); + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Conv2D, CpuConvInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h new file mode 100644 index 0000000000..3250cfa112 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h" +#include "src/runtime/kernel/arm/nnacl/int8/conv_int8.h" + +namespace mindspore::kernel { +class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionInt8CPUKernel() override { + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (input_sum_ != nullptr) { + free(input_sum_); + } + if (tmp_dst_ != nullptr) { + free(tmp_dst_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + FreeQuantParam(); + }; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + void CheckSupportOptimize(); + int InitOpt(); + int InitWeightBiasOpt(); + int InitTmpBufferOpt(); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + bool support_optimize_ = true; + int8_t *packed_weight_ = nullptr; + int8_t *packed_input_ = nullptr; + int32_t *input_sum_ = nullptr; + int32_t *tmp_dst_ = nullptr; + int8_t *tmp_out_ = nullptr; + GEMM_FUNC gemm_func_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc new file mode 100644 index 0000000000..4e73001954 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/crop_int8.h" +#include +#include "src/runtime/kernel/arm/nnacl/int8/crop_int8.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int CropInt8CPUKernel::Init() { + CropBaseCPUKernel::Init(); + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + crop_para_->quant_arg.in_args_.scale_ = in_quant_args.front().scale; + crop_para_->quant_arg.in_args_.zp_ = in_quant_args.front().zeroPoint; + auto input_dim = input_tensor->shape().size(); + MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE); + crop_para_->input_dim_ = input_dim; + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + crop_para_->quant_arg.out_args_.scale_ = out_quant_args.front().scale; + crop_para_->quant_arg.out_args_.zp_ = out_quant_args.front().zeroPoint; + + crop_para_->in_shape_ = input_tensor->shape().data(); + crop_para_->out_shape_ = out_tensor->shape().data(); + + crop_para_->quant_arg.output_activation_max_ = std::numeric_limits::max(); + crop_para_->quant_arg.output_activation_min_ = std::numeric_limits::min(); + + PadOffset(input_dim, crop_para_); + return RET_OK; +} + +int CropInt8CPUKernel::ReSize() { return 0; } + +int CropInt8CPUKernel::Run() { + auto ret = LiteBackendParallelLaunch(CropInt8Run, this, thread_count_); + return ret; +} + +void PadOffset(int input_dim, CropParameter *crop_para) { + auto axis = crop_para->axis_; + auto offsets_size = crop_para->offset_size_; + MS_ASSERT(axis <= input_dim); + if (offsets_size > 1) { + MS_ASSERT(axis + offsets_size == input_dim); + } + for (int i = 0; i < input_dim; i++) { + int crop_offset = 0; + if (i >= axis) { + if (offsets_size == 1) { + crop_offset = crop_para->offset_[0]; + } else if (offsets_size > 1) { + crop_offset = crop_para->offset_[i - axis]; + } + } + crop_para->in_offset_[i] = crop_offset; + } +} + +int CropInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto crop = reinterpret_cast(cdata); + crop->DoExecute(task_id); + return RET_OK; +} + +int CropInt8CPUKernel::DoExecute(int task_id) { + auto input_tensor = inputs_.at(kInputIndex); + auto out_tensor = outputs_.at(kOutputIndex); + int8_t *input_data = reinterpret_cast(input_tensor->Data()); + int8_t *output_data = reinterpret_cast(out_tensor->Data()); + Crop(input_data, output_data, task_id, crop_para_); + return RET_OK; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h new file mode 100644 index 0000000000..ebcfd22a65 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CROP_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CROP_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/crop_base.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class CropInt8CPUKernel : public CropBaseCPUKernel { + public: + CropInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : CropBaseCPUKernel(parameter, inputs, outputs, ctx) { + crop_para_ = reinterpret_cast(opParameter); + crop_para_->thread_count_ = opParameter->thread_num_; + } + ~CropInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tId); + + private: + CropParameter *crop_para_; +}; + +int CropInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata); +void PadOffset(int input_dim, CropParameter *crop_para); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CROP_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc new file mode 100644 index 0000000000..6de3c47211 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc @@ -0,0 +1,228 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; + +namespace mindspore::kernel { +int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() { + // init weight: int8 -> int16 + // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1 + auto origin_weight = reinterpret_cast(inputs_[kWeightIndex]->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(int16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(int16_t)); + PackDepthwiseInt8Weight(origin_weight, packed_weight_, conv_param_); + + // init bias, add output zp + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(int32_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(int32_t)); + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::InitSlideParam() { + conv_param_->input_batch_ = outputs_.front()->shape().at(kNHWC_N); + conv_param_->input_h_ = outputs_.front()->shape().at(kNHWC_H); + conv_param_->input_w_ = outputs_.front()->shape().at(kNHWC_W); + conv_param_->input_channel_ = C4NUM; + conv_param_->output_batch_ = inputs_.front()->shape().at(kNHWC_N); + conv_param_->output_h_ = inputs_.front()->shape().at(kNHWC_H); + conv_param_->output_w_ = inputs_.front()->shape().at(kNHWC_W); + conv_param_->output_channel_ = inputs_.front()->shape().at(kNHWC_C); + + // init sliding window param + InitSlidingParam(sliding, conv_param_, C4NUM); + + sliding->in_h_step_ = conv_param_->input_w_ * C4NUM; + sliding->in_sh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->stride_h_; // stride H + sliding->in_sw_step_ = C4NUM * conv_param_->stride_h_; // stride W + sliding->in_kh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->dilation_h_; // kernel H + sliding->in_kw_step_ = C4NUM * conv_param_->dilation_w_; // kernel W + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::InitBuffer() { + // malloc packed input buffer + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * + UP_DIV(conv_param_->input_channel_, 4); + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(int16_t))); + memset(packed_input_, 0, pack_input_size * sizeof(int16_t)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * + UP_DIV(conv_param_->output_channel_, C4NUM); + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(int8_t))); + if (packed_output_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_output_, 0, pack_output_size * sizeof(int8_t)); + } + + // malloc tmp buffer for int32 output + output_buffer = + reinterpret_cast(malloc(conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * sizeof(int32_t))); + if (output_buffer == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::Init() { + sliding = new SlidingWindowParam; + InitSlideParam(); + + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init quant param + ConvolutionBaseCPUKernel::SetQuantParam(); + + // init weight and bias + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Deconv Depthwise int8 InitWeightBias error!"; + return ret; + } + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Deconv Depthwise int8 InitBuffer error!"; + return ret; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::ReSize() { + free(packed_input_); + if (need_align_) { + free(packed_output_); + } + InitSlideParam(); + + // conv base init + ConvolutionBaseCPUKernel::Init(); + + auto ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Deconv Depthwise int8 InitBuffer error!"; + return ret; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::Execute(int task_id) { + DeconvDwInt8(packed_output_, output_buffer, packed_input_, packed_weight_, reinterpret_cast(bias_data_), + conv_param_, sliding, task_id); + return RET_OK; +} + +int DeconvDwInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv_dw_int8 = reinterpret_cast(cdata); + auto ret = deconv_dw_int8->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvolutionDepthwiseInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + + // pack input, assume input format: NHWC -> NHWC4 + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_); + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (!need_align_) { + memset(output_addr, 0, outputs_.at(kOutputIndex)->ElementsNum() * sizeof(int8_t)); + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(DeconvDwInt8Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvDwInt8Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} + +kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); + auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h new file mode 100644 index 0000000000..3e56264aae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_DEPTHWISE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_DEPTHWISE_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class DeconvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + DeconvolutionDepthwiseInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeconvolutionDepthwiseInt8CPUKernel() override { + delete sliding; + free(packed_weight_); + free(packed_input_); + if (need_align_) { + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitSlideParam(); + int InitWeightBias(); + int InitBuffer(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding; + int16_t *packed_weight_; + int16_t *packed_input_; + int8_t *packed_output_; + int32_t *output_buffer; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_DEPTHWISE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc new file mode 100644 index 0000000000..f21c5bc1c4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc @@ -0,0 +1,244 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/deconvolution_int8.h" +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { +DeConvInt8CPUKernel::~DeConvInt8CPUKernel() { + if (weight_ptr_ != nullptr) { + free(weight_ptr_); + weight_ptr_ = nullptr; + } + if (tmp_buffer_ != nullptr) { + free(tmp_buffer_); + tmp_buffer_ = nullptr; + } + if (input_ptr_ != nullptr) { + free(input_ptr_); + input_ptr_ = nullptr; + } + if (tmp_output_ != nullptr) { + free(tmp_output_); + tmp_output_ = nullptr; + } + ConvolutionBaseCPUKernel::FreeQuantParam(); +} + +int DeConvInt8CPUKernel::ReSize() { return RET_OK; } + +int DeConvInt8CPUKernel::InitParam() { + fc_param_ = new MatMulParameter(); + fc_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; + fc_param_->deep_ = conv_param_->input_channel_; + fc_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; + fc_param_->row_8_ = UP_ROUND(fc_param_->row_, C8NUM); + fc_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + size_t oc8 = UP_DIV(conv_param_->output_channel_, C8NUM); + thread_count_ = MSMIN(opParameter->thread_num_, oc8); + thread_stride_ = UP_DIV(oc8, thread_count_) * C8NUM; + return RET_OK; +} + +int DeConvInt8CPUKernel::InitBiasWeight() { + if (inputs_.size() == 3) { + size_t size = UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(int32_t); + bias_data_ = malloc(size); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv int8 malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, size); + memcpy(bias_data_, inputs_[0]->Data(), conv_param_->output_channel_ * sizeof(int32_t)); + } else { + bias_data_ = nullptr; + } + + /* weight: ichwoc(nhwc) -> oc8 * h * w * inc * 8 */ + size_t size = conv_param_->kernel_w_ * conv_param_->kernel_h_ * UP_ROUND(conv_param_->output_channel_, C8NUM) * + conv_param_->input_channel_ * sizeof(int8_t); + weight_ptr_ = reinterpret_cast(malloc(size)); + if (weight_ptr_ == nullptr) { + MS_LOG(ERROR) << "deconv int8 malloc weight_ptr_ error!"; + return RET_ERROR; + } + memset(weight_ptr_, 0, size); + PackNHWCToC8HWN8Int8(inputs_[1]->Data(), weight_ptr_, conv_param_->input_channel_, + conv_param_->kernel_h_ * conv_param_->kernel_w_, conv_param_->output_channel_); + return RET_OK; +} + +int DeConvInt8CPUKernel::InitData() { + int size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * conv_param_->input_channel_; + input_ptr_ = reinterpret_cast(malloc(size * sizeof(int8_t))); + if (input_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(input_ptr_, 0, size * sizeof(int8_t)); + + size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * + UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_; + tmp_buffer_ = reinterpret_cast(malloc(size * sizeof(int32_t))); + if (tmp_buffer_ == nullptr) { + return RET_MEMORY_FAILED; + } + + size = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->output_h_ * conv_param_->output_w_; + tmp_output_ = reinterpret_cast(malloc(size * sizeof(int32_t))); + if (tmp_output_ == nullptr) { + return RET_MEMORY_FAILED; + } + return RET_OK; +} + +int DeConvInt8CPUKernel::Init() { + ConvolutionBaseCPUKernel::Init(); + int error_code = ConvolutionBaseCPUKernel::SetQuantParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 SetQuantParam error!"; + return error_code; + } + + error_code = InitParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 InitParam error!"; + return error_code; + } + + error_code = InitBiasWeight(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 InitBiasWeight error!"; + return error_code; + } + + error_code = InitData(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 InitData error!"; + return error_code; + } + return RET_OK; +} + +int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoDeconv(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvInt8PostFuncRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoPostFunc(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvInt8PostFuncRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvInt8CPUKernel::DoDeconv(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_ROUND(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + int input_plane = conv_param_->input_h_ * conv_param_->input_w_; + int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; + + DeConvInt8(input_ptr_, weight_ptr_ + task_id * thread_stride_ * kernel_plane * conv_param_->input_channel_, + tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, fc_param_->row_8_, + cur_oc * kernel_plane, fc_param_->deep_, conv_param_); + + return RET_OK; +} + +int DeConvInt8CPUKernel::DoPostFunc(int task_id) { + int input_plane = conv_param_->input_h_ * conv_param_->input_w_; + int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; + int output_plane = conv_param_->output_h_ * conv_param_->output_w_; + + int cur_oc = MSMIN(thread_stride_, conv_param_->output_channel_ - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + DeConvPostInt8(tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, + reinterpret_cast(bias_data_) + task_id * thread_stride_, + tmp_output_ + task_id * thread_stride_ * output_plane, output_ptr_ + task_id * thread_stride_, cur_oc, + conv_param_); + return RET_OK; +} + +int DeConvInt8CPUKernel::Run() { + int8_t *src_in = reinterpret_cast(inputs_[0]->Data()); + int8_t *src_out = reinterpret_cast(outputs_[0]->Data()); + + for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { + RowMajor2Col8MajorInt8(src_in + batch_index * fc_param_->row_ * conv_param_->input_channel_, input_ptr_, + fc_param_->row_, fc_param_->deep_); + output_ptr_ = src_out + batch_index * fc_param_->col_; + + int error_code = LiteBackendParallelLaunch(DeConvInt8Run, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 run error! error_code[" << error_code << "]"; + return RET_ERROR; + } + error_code = LiteBackendParallelLaunch(DeConvInt8PostFuncRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 post run error! error_code[" << error_code << "]"; + return RET_ERROR; + } + } + + return RET_OK; +} + +kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + auto kernel = new (std::nothrow) kernel::DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DeConv2D, CpuDeConvInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h new file mode 100644 index 0000000000..b63633f9a5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/int8/deconv.h" +#include "src/runtime/kernel/arm/nnacl/int8/matmul.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +namespace mindspore::kernel { +class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + DeConvInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeConvInt8CPUKernel() override; + + int ReSize() override; + int Init() override; + int Run() override; + + public: + int DoDeconv(int task_id); + int DoPostFunc(int task_id); + + private: + int InitData(); + int InitParam(); + int InitBiasWeight(); + + private: + MatMulParameter *fc_param_; + int8_t *weight_ptr_; + int8_t *input_ptr_; /* record c8 input*/ + int32_t *tmp_buffer_; /* record matmul result */ + int32_t *tmp_output_; /* record post c8 result */ + int8_t *output_ptr_; + size_t thread_count_; + size_t thread_stride_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc new file mode 100644 index 0000000000..b70624aa18 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/depth_to_space.h" +#include "src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_OK; +using mindspore::lite::RET_ERROR; + +namespace mindspore::kernel { +int DepthToSpaceInt8CPUKernel::Init() { + auto ret = DepthToSpaceBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + param->data_type_size_ = sizeof(int8_t); + + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + in_quant_arg_.scale_ = in_quant_args.front().scale; + in_quant_arg_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + out_quant_arg_.scale_ = out_quant_args.front().scale; + out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + return RET_OK; +} + +int DepthToSpaceInt8CPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + const int8_t *input_data = reinterpret_cast(input->Data()); + int8_t *output_data = reinterpret_cast(output->Data()); + auto in_shape = input->shape(); + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + if (in_quant_arg_.scale_ == out_quant_arg_.scale_ && in_quant_arg_.zp_ == out_quant_arg_.zp_) { + DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), param); + } else { + DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), param, &in_quant_arg_, &out_quant_arg_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h new file mode 100644 index 0000000000..427b6d5eb0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEPTH_TO_SPACE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEPTH_TO_SPACE_INT8_H_ + +#include +#include "src/runtime/kernel/arm/base/depth_to_space_base.h" + +namespace mindspore::kernel { +class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel { + public: + DepthToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~DepthToSpaceInt8CPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + private: + QuantArg in_quant_arg_; + QuantArg out_quant_arg_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEPTH_TO_SPACE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc new file mode 100644 index 0000000000..a14438c815 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/fullconnection_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/matmul.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int FullconnectionInt8CPUKernel::Init() { + fc_param_->row_ = (inputs_[0]->shape())[0]; + fc_param_->col_ = (inputs_[1]->shape())[0]; + fc_param_->deep_ = (inputs_[1]->shape())[1]; + fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8); + fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8); + + thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); + thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); + + a_c8_ptr_ = + reinterpret_cast(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t))); + if (!a_c8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)); + b_r8_ptr_ = + reinterpret_cast(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t))); + if (!b_r8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)); + auto weight_data = reinterpret_cast(inputs_[1]->Data()); + RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_); + c_r8x8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int))); + if (!c_r8x8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)); + auto bias_len = fc_param_->col_8_ * sizeof(int); + bias_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(bias_len)); + if (!bias_ptr_) { + return RET_MEMORY_FAILED; + } + memset(bias_ptr_, 0, bias_len); + if (inputs_.size() == 3) { + memcpy(bias_ptr_, inputs_[2]->Data(), bias_len); + } + + auto input_tensor = inputs_[0]; + auto params = input_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.input.zp_ = params.front().zeroPoint; + quant_params_.input.scale_ = params.front().scale; + auto weight_tensor = inputs_[1]; + params = weight_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.weight.zp_ = params.front().zeroPoint; + quant_params_.weight.scale_ = params.front().scale; + auto output_tensor = outputs_[0]; + params = output_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.output.zp_ = params.front().zeroPoint; + quant_params_.output.scale_ = params.front().scale; + + double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; + QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift, + &quant_params_.right_shift); + CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, + quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_max, + &quant_params_.out_act_min); + return RET_OK; +} + +int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; } + +int FullconnectionInt8CPUKernel::RunImpl(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + auto &p = quant_params_; + auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_; + auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_; + MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_); + return RET_OK; +} + +int FcInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto fc = reinterpret_cast(cdata); + auto ret = fc->RunImpl(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FcInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int FullconnectionInt8CPUKernel::Run() { + auto a_ptr = reinterpret_cast(inputs_[0]->Data()); + auto output_ptr = reinterpret_cast(outputs_[0]->Data()); + auto &p = quant_params_; + RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + LiteBackendParallelLaunch(FcInt8Run, this, thread_count_); + PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_, + p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max); + return RET_OK; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h new file mode 100644 index 0000000000..30aef2d45f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_ + +#include +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" +#include "src/runtime/kernel/arm/base/fullconnection_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel { + public: + FullconnectionInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~FullconnectionInt8CPUKernel() override { + ctx_->allocator->Free(a_c8_ptr_); + ctx_->allocator->Free(b_r8_ptr_); + ctx_->allocator->Free(c_r8x8_ptr_); + } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: + MatmulQuantArg quant_params_; + int8_t *a_c8_ptr_; + int8_t *b_r8_ptr_; + int *c_r8x8_ptr_; + int *bias_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc new file mode 100644 index 0000000000..7b5f9795eb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/hswish_int8.h" +#include +#include "src/runtime/kernel/arm/nnacl/int8/hswish_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_HSWISH; + +namespace mindspore::kernel { +int HswishInt8CPUKernel::Init() { + lite::tensor::Tensor *input = inputs_.at(0); + lite::tensor::Tensor *output = outputs_.at(0); + MS_ASSERT(input); + MS_ASSERT(output); + + quant_arg_.input_scale = input->GetQuantParams().front().scale; + quant_arg_.input_zp = input->GetQuantParams().front().zeroPoint; + quant_arg_.output_scale = output->GetQuantParams().front().scale; + quant_arg_.output_zp = output->GetQuantParams().front().zeroPoint; + + const float output_multiplier = (1.0f / 128.0f) * quant_arg_.input_scale / quant_arg_.output_scale; + + int32_t output_multiplier_fixedpoint; + QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint, &quant_arg_.output_multiplier_exponent); + MS_ASSERT(quant_arg_.output_multiplier_exponent <= 0); + MultiplierInt32ToInt16(output_multiplier_fixedpoint, &quant_arg_.output_multiplier_fixedpoint_int16); + + const float relu6_multiplier = (1.0f / 128.0f) * quant_arg_.input_scale / (3.0f / 32768.0f); + int32_t relu6_multiplier_fixedpoint; + QuantizeMultiplier(relu6_multiplier, &relu6_multiplier_fixedpoint, &quant_arg_.relu6_multiplier_exponent); + MultiplierInt32ToInt16(relu6_multiplier_fixedpoint, &quant_arg_.relu6_multiplier_fixedpoint_int16); + + return RET_OK; +} + +void HswishInt8CPUKernel::MultiplierInt32ToInt16(int32_t input, int16_t *output) { + MS_ASSERT(input >= 0); + if (input >= std::numeric_limits::max() - (1 << 15)) { + *output = std::numeric_limits::max(); + return; + } + *output = (input + (1 << 15)) >> 16; +} + +int HswishInt8CPUKernel::ReSize() { return RET_OK; } + +int HswishInt8CPUKernel::DoActivation(int task_id) { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto length = inputs_.at(0)->ElementsNum(); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + HSwishInt8(input_addr + stride * task_id, count, output_addr + stride * task_id, &quant_arg_); + return RET_OK; +} + +int HswishInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activation_kernel = reinterpret_cast(cdata); + auto error_code = activation_kernel->DoActivation(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "HswishInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int HswishInt8CPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(HswishInt8Run, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "HswishInt8Run function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h new file mode 100644 index 0000000000..b907f2c338 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_INT8_HSWISH_INT8_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_INT8_HSWISH_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/int8/hswish_int8.h" + +namespace mindspore::kernel { +class HswishInt8CPUKernel : public LiteKernel { + public: + HswishInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {} + ~HswishInt8CPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoActivation(int task_id); + + private: + int thread_count_; + HswishQuantArg quant_arg_; + void MultiplierInt32ToInt16(int32_t input, int16_t *output); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_INT8_HSWISH_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc new file mode 100644 index 0000000000..5015fa0cea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/matmul_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/matmul.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { + ctx_->allocator->Free(a_c8_ptr_); + ctx_->allocator->Free(b_r8_ptr_); + ctx_->allocator->Free(c_r8x8_ptr_); +} + +int MatmulInt8CPUKernel::Init() { + int batch = 1; + auto x_shape = inputs_[0]->shape(); + auto o_shape = outputs_[0]->shape(); + for (int i = 0; i < x_shape.size() - 2; ++i) { + batch *= x_shape[i]; + } + params_->batch = batch; + params_->row_ = o_shape[o_shape.size() - 2]; + params_->col_ = o_shape[o_shape.size() - 1]; + params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1]; + params_->row_8_ = UP_ROUND(params_->row_, 8); + params_->col_8_ = UP_ROUND(params_->col_, 8); + thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); + thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); + + a_c8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t))); + if (!a_c8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(int8_t)); + b_r8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(int8_t))); + if (!b_r8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(int8_t)); + c_r8x8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(int))); + if (!c_r8x8_ptr_) { + return RET_MEMORY_FAILED; + } + memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int)); + + auto input_tensor = inputs_[0]; + auto params = input_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.input.zp_ = params.front().zeroPoint; + quant_params_.input.scale_ = params.front().scale; + auto weight_tensor = inputs_[1]; + params = weight_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.weight.zp_ = params.front().zeroPoint; + quant_params_.weight.scale_ = params.front().scale; + auto output_tensor = outputs_[0]; + params = output_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.output.zp_ = params.front().zeroPoint; + quant_params_.output.scale_ = params.front().scale; + + double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; + QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift, + &quant_params_.right_shift); + return RET_OK; +} + +int MatmulInt8CPUKernel::ReSize() { return RET_OK; } + +int MatmulInt8CPUKernel::RunImpl(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; + auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; + MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_, + quant_params_.weight.zp_); + return RET_OK; +} + +int MatmulInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto op = reinterpret_cast(cdata); + auto ret = op->RunImpl(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int MatmulInt8CPUKernel::Run() { + auto a_ptr = reinterpret_cast(inputs_[0]->Data()); + auto b_ptr = reinterpret_cast(inputs_[1]->Data()); + auto c_ptr = reinterpret_cast(outputs_[0]->Data()); + auto a_stride = params_->row_ * params_->deep_; + auto b_stride = params_->deep_ * params_->col_; + auto c_stride = params_->row_ * params_->col_; + + for (int i = 0; i < params_->batch; ++i) { + auto cur_a_ptr = a_ptr + i * a_stride; + auto cur_b_ptr = b_ptr + i * b_stride; + auto cur_c_ptr = c_ptr + i * c_stride; + if (params_->a_transpose_) { + RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_); + } else { + RowMajor2Col8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_); + } + if (params_->b_transpose_) { + RowMajor2Col8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_); + } else { + RowMajor2Row8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_); + } + LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_); + auto &q = quant_params_; + SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier, + q.left_shift, q.right_shift, q.output.zp_); + } + + return RET_OK; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h new file mode 100644 index 0000000000..dc8f5ec0b6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_INT8_H_ + +#include +#include "include/context.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" +#include "src/runtime/kernel/arm/base/matmul_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class MatmulInt8CPUKernel : public MatmulBaseCPUKernel { + public: + MatmulInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : MatmulBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~MatmulInt8CPUKernel() override; + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: + MatmulQuantArg quant_params_; + int8_t *a_c8_ptr_; + int8_t *b_r8_ptr_; + int *c_r8x8_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc new file mode 100644 index 0000000000..a9fd7c1378 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/mul_int8.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/kernel/arm/nnacl/int8/mul_int8.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Mul; + +namespace mindspore::kernel { +int MulInt8CPUKernel::Init() { + lite::tensor::Tensor *input0 = inputs_.at(0); + lite::tensor::Tensor *input1 = inputs_.at(1); + lite::tensor::Tensor *output = outputs_.at(0); + MS_ASSERT(input0); + MS_ASSERT(input1); + MS_ASSERT(output); + + para_.mul_quant_arg_.in_quant_args_[0].scale_ = input0->GetQuantParams().front().scale; + para_.mul_quant_arg_.in_quant_args_[0].zp_ = input0->GetQuantParams().front().zeroPoint * -1; + para_.mul_quant_arg_.in_quant_args_[1].scale_ = input1->GetQuantParams().front().scale; + para_.mul_quant_arg_.in_quant_args_[1].zp_ = input1->GetQuantParams().front().zeroPoint * -1; + para_.mul_quant_arg_.out_quant_arg_.scale_ = output->GetQuantParams().front().scale; + para_.mul_quant_arg_.out_quant_arg_.zp_ = output->GetQuantParams().front().zeroPoint; + para_.mul_quant_arg_.output_activation_max_ = std::numeric_limits::max(); + para_.mul_quant_arg_.output_activation_min_ = std::numeric_limits::min(); + + const double real_multiplier = + (para_.mul_quant_arg_.in_quant_args_[0].scale_ * para_.mul_quant_arg_.in_quant_args_[1].scale_) / + para_.mul_quant_arg_.out_quant_arg_.scale_; + + int right_shift = 0; + QuantizeMultiplierSmallerThanOne(real_multiplier, ¶_.mul_quant_arg_.output_multiplier_, &right_shift); + + para_.mul_quant_arg_.shift_left_ = right_shift < 0 ? -right_shift : 0; + para_.mul_quant_arg_.shift_right_ = right_shift > 0 ? right_shift : 0; + + return RET_OK; +} + +int MulInt8CPUKernel::ReSize() { return RET_OK; } + +int MulInt8CPUKernel::Run() { + input0_data_ = static_cast(inputs_.at(0)->Data()); + input1_data_ = static_cast(inputs_.at(1)->Data()); + output_data_ = static_cast(outputs_.at(0)->Data()); + + elements_num_ = inputs_.at(0)->ElementsNum(); + count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; + + if (inputs_.at(0)->ElementsNum() != inputs_.at(1)->ElementsNum()) { + input0_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + input1_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + + ArithmeticParameter tile_para = {0}; + tile_para.ndim_ = outputs_.at(0)->shape().size(); + for (size_t i = 0; i < tile_para.ndim_; i++) { + tile_para.in_shape0_[i] = inputs_.at(0)->DimensionSize(i); + tile_para.in_shape1_[i] = inputs_.at(1)->DimensionSize(i); + tile_para.out_shape_[i] = outputs_.at(0)->DimensionSize(i); + } + TileDimensionsInt8(static_cast(inputs_.at(0)->Data()), static_cast(inputs_.at(1)->Data()), + input0_data_, input1_data_, &tile_para); + auto ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_); + ctx_->allocator->Free(input0_data_); + ctx_->allocator->Free(input1_data_); + return ret; + } + + auto ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_); + return ret; +} + +int MulInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto mul = reinterpret_cast(cdata); + mul->DoExecute(task_id); + return lite::RET_OK; +} + +int MulInt8CPUKernel::DoExecute(int tId) { + int64_t real_dst_count = MSMIN(elements_num_ - tId * count_unit_, count_unit_); + int8_t *cur_input0_data = input0_data_ + tId * count_unit_; + int8_t *cur_input1_data = input1_data_ + tId * count_unit_; + int8_t *cur_output_data = output_data_ + tId * count_unit_; + + Mul(cur_input0_data, cur_input1_data, cur_output_data, real_dst_count, para_.mul_quant_arg_); + return lite::RET_OK; +} + +kernel::LiteKernel *CpuMulInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, const KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Mul); + auto *kernel = new (std::nothrow) MulInt8CPUKernel(opParameter, inputs, outputs, ctx); + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Mul, CpuMulInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h new file mode 100644 index 0000000000..7725591a06 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/mul_parameter.h" +#include "src/runtime/runtime_api.h" + +namespace mindspore::kernel { +class MulInt8CPUKernel : public LiteKernel { + public: + explicit MulInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx_->thread_num_) {} + ~MulInt8CPUKernel() override {}; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tId); + + private: + const lite::Context *ctx_; + MulParameter para_; + int thread_count_; + int64_t elements_num_; + int64_t count_unit_; + int8_t *input0_data_ = nullptr; + int8_t *input1_data_ = nullptr; + int8_t *output_data_ = nullptr; +}; + +int MulInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc new file mode 100644 index 0000000000..42b3f7c16a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/pad_int8.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +void PadInt8CPUKernel::FreeQuantParam() { + if (pad_param_->pad_quant_arg_.in_quant_args_ != nullptr) { + free(pad_param_->pad_quant_arg_.in_quant_args_); + pad_param_->pad_quant_arg_.in_quant_args_ = nullptr; + } + if (pad_param_->pad_quant_arg_.out_quanr_args_ != nullptr) { + free(pad_param_->pad_quant_arg_.out_quanr_args_); + pad_param_->pad_quant_arg_.out_quanr_args_ = nullptr; + } +} + +int PadInt8CPUKernel::SetQuantParam() { + PadQuantArg *pad_quant_args = &pad_param_->pad_quant_arg_; + pad_quant_args->in_quant_args_ = reinterpret_cast(malloc(sizeof(QuantArg))); + if (pad_quant_args->in_quant_args_ == nullptr) { + return RET_MEMORY_FAILED; + } + pad_quant_args->out_quanr_args_ = reinterpret_cast(malloc(sizeof(QuantArg))); + if (pad_quant_args->out_quanr_args_ == nullptr) { + return RET_MEMORY_FAILED; + } + pad_quant_args->constant_value_ = reinterpret_cast(malloc(sizeof(int8_t))); + if (pad_quant_args->constant_value_ == nullptr) { + return RET_MEMORY_FAILED; + } + + auto *input_tensor = inputs_.at(kInputIndex); + auto *out_tensor = outputs_.at(kOutputIndex); + auto in_quant_arg = input_tensor->GetQuantParams(); + auto out_quant_arg = out_tensor->GetQuantParams(); + + pad_quant_args->in_quant_args_->zp_ = in_quant_arg.front().zeroPoint; + pad_quant_args->in_quant_args_->scale_ = in_quant_arg.front().scale; + pad_quant_args->out_quanr_args_->zp_ = out_quant_arg.front().zeroPoint; + pad_quant_args->out_quanr_args_->scale_ = out_quant_arg.front().scale; + + if (pad_quant_args->in_quant_args_->scale_ != pad_quant_args->out_quanr_args_->scale_ || + pad_quant_args->in_quant_args_->zp_ != pad_quant_args->out_quanr_args_->zp_) { + MS_LOG(ERROR) << "Pad int8 op : scale & zp of output and input must be equal."; + return RET_ERROR; + } + + pad_quant_args->constant_value_[0] = QuantizeToInt8( + pad_param_->constant_value_, pad_quant_args->in_quant_args_->scale_, pad_quant_args->in_quant_args_->zp_); + return RET_OK; +} + +int PadInt8CPUKernel::InitPadParam() { + auto in_dims = inputs_[0]->shape(); + auto out_dims = outputs_[0]->shape(); + int ndims = in_dims.size(); + + int in[] = {1, 1, 1, 1}; + int out[] = {1, 1, 1, 1}; + + for (int i = 0; i < ndims; i++) { + in[DEFAULT_PAD_NDIMS - ndims + i] = in_dims[i]; + out[DEFAULT_PAD_NDIMS - ndims + i] = out_dims[i]; + } + + memcpy(in_dims_, in, DEFAULT_PAD_NDIMS * sizeof(int)); + memcpy(out_dims_, out, DEFAULT_PAD_NDIMS * sizeof(int)); + + return RET_OK; +} + +int PadInt8CPUKernel::ReSize() { + InitPadParam(); + return RET_OK; +} + +int PadInt8CPUKernel::Init() { + int error_code = InitPadParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "InitPadParam failed. errorcode: " << error_code; + return error_code; + } + + error_code = SetQuantParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "SetQuantParam failed. errorcode: " << error_code; + return error_code; + } + return RET_OK; +} + +int PadInt8CPUKernel::Run() { + int8_t *in_data = reinterpret_cast(inputs_[0]->Data()); + int8_t *out_data = reinterpret_cast(outputs_[0]->Data()); + + memset(out_data, pad_param_->pad_quant_arg_.constant_value_[0], outputs_[0]->ElementsNum() * sizeof(int8_t)); + PadConstant4D(in_data, out_data, in_dims_, out_dims_, pad_param_->paddings_); + return RET_OK; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h new file mode 100644 index 0000000000..3d9dda740b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ + +#include +#include "include/errorcode.h" +#include "src/lite_kernel.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/nnacl/pad_parameter.h" +#include "src/runtime/kernel/arm/nnacl/int8/pad.h" + +namespace mindspore::kernel { +class PadInt8CPUKernel : public LiteKernel { + public: + explicit PadInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + opParameter->thread_num_ = ctx->thread_num_; + pad_param_ = reinterpret_cast(opParameter); + } + ~PadInt8CPUKernel() override { FreeQuantParam(); }; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + int SetQuantParam(); + int InitPadParam(); + void FreeQuantParam(); + + private: + PadParameter *pad_param_; + int in_dims_[DEFAULT_PAD_NDIMS]; + int out_dims_[DEFAULT_PAD_NDIMS]; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc new file mode 100644 index 0000000000..9a810d27f8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/nnacl/fp32/cast.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int PoolingInt8CPUKernel::Init() { + auto ret = PoolingBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PoolingBase Init failed."; + return RET_ERROR; + } + ret = SetQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set pooling quant param failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingInt8CPUKernel::ReSize() { + FreeQuantParam(); + auto ret = PoolingBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PoolingBase Init failed."; + return RET_ERROR; + } + SetQuantParam(); + ret = SetQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set pooling quant param failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingInt8CPUKernel::RunImpl(int task_id) { + auto input_data = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + auto output_data = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (pooling_param_->max_pooling_) { + MaxPoolingInt8(input_data, output_data, pooling_param_, task_id); + } else { + AvgPoolingInt8(input_data, output_data, pooling_param_, task_id); + } + return RET_OK; +} + +int PoolingInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto pooling = reinterpret_cast(cdata); + auto error_code = pooling->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "PoolingInt8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingInt8CPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(PoolingInt8Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "poolingInt8 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h new file mode 100644 index 0000000000..367c10f59c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POOLING_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POOLING_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/pooling_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class PoolingInt8CPUKernel : public PoolingBaseCPUKernel { + public: + PoolingInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~PoolingInt8CPUKernel() { FreeQuantParam(); } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POOLING_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.cc new file mode 100644 index 0000000000..62178774b9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/relu_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_RELU; + +namespace mindspore::kernel { +int ReluInt8CPUKernel::Init() { + lite::tensor::Tensor *input = inputs_.at(0); + lite::tensor::Tensor *output = outputs_.at(0); + MS_ASSERT(input); + MS_ASSERT(output); + + quant_arg_.input_arg.scale_ = input->GetQuantParams().front().scale; + quant_arg_.input_arg.zp_ = input->GetQuantParams().front().zeroPoint; + quant_arg_.output_arg.scale_ = output->GetQuantParams().front().scale; + quant_arg_.output_arg.zp_ = output->GetQuantParams().front().zeroPoint; + + const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_; + QuantizeRoundParameter(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, &quant_arg_.right_shift_); + + return RET_OK; +} + +int ReluInt8CPUKernel::ReSize() { return RET_OK; } + +int ReluInt8CPUKernel::DoActivation(int task_id) { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto length = inputs_.at(0)->ElementsNum(); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + ReluInt8(input_addr + stride * task_id, count, output_addr + stride * task_id, &quant_arg_); + return RET_OK; +} + +int ReluInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activation_kernel = reinterpret_cast(cdata); + auto error_code = activation_kernel->DoActivation(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ReluInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ReluInt8CPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ReluInt8Run, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ReluInt8Run function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h new file mode 100644 index 0000000000..5f8f81a3a5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_INT8_ACTIVATION_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_INT8_ACTIVATION_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/activation.h" +#include "src/runtime/kernel/arm/nnacl/int8/relu_int8.h" + +namespace mindspore::kernel { +class ReluInt8CPUKernel : public LiteKernel { + public: + ReluInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) { + type_ = (reinterpret_cast(parameter))->type_; + } + ~ReluInt8CPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoActivation(int task_id); + + private: + int thread_count_; + int type_; + ReluQuantArg quant_arg_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_INT8_ACTIVATION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc new file mode 100644 index 0000000000..7969b89a78 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/reshape_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/reshape_int8.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int ReshapeInt8CPUKernel::Init() { + ReshapeBaseCPUKernel::Init(); + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + in_quant_arg_.scale_ = in_quant_args.front().scale; + in_quant_arg_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + out_quant_arg_.scale_ = out_quant_args.front().scale; + out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + return RET_OK; +} + +int ReshapeInt8CPUKernel::ReSize() { return 0; } + +int ReshapeInt8CPUKernel::Run() { + MS_ASSERT(inputs_.size() == 1); + MS_ASSERT(outputs_.size() == 1); + auto input_type = inputs_[kInputIndex]->data_type(); + auto input_num = inputs_[kInputIndex]->ElementsNum(); + auto output_num = outputs_.at(kOutputIndex)->ElementsNum(); + MS_ASSERT(input_num == output_num); + int8_t *input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + int8_t *output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (input_type == kNumberTypeUInt8) { + auto *input_tmp = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + for (size_t i = 0; i < input_num; i++) { + input_ptr[i] = (int8_t)(input_tmp[i] - 128); + } + in_quant_arg_.zp_ -= 128; + out_quant_arg_.zp_ -= 128; + } + + size_t data_size = inputs_.at(kInputIndex)->Size(); + Reshape(input_ptr, output_ptr, data_size, input_num, in_quant_arg_, out_quant_arg_); + + auto output_type = outputs_[kOutputIndex]->data_type(); + if (output_type == kNumberTypeUInt8) { + for (size_t i = 0; i < output_num; i++) { + output_ptr[i] = (uint8_t)(output_ptr[i] + 128); + } + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h new file mode 100644 index 0000000000..cb1065a4c0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/base/reshape_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeInt8CPUKernel : public ReshapeBaseCPUKernel { + public: + ReshapeInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ReshapeInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + QuantArg in_quant_arg_; + QuantArg out_quant_arg_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc new file mode 100644 index 0000000000..2a3ef9d6ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/softmax_int8.h" +#include "src/runtime/kernel/arm/nnacl/int8/softmax_int8.h" +#include "schema/model_generated.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int SoftmaxInt8CPUKernel::Init() { + SoftmaxBaseCPUKernel::Init(); + + auto *input_tensor = inputs_.at(kInputIndex); + MS_ASSERT(input_tensor); + + auto in_quant_args = input_tensor->GetQuantParams(); + quant_params_.in_quant_args_.scale_ = in_quant_args.front().scale; + quant_params_.in_quant_args_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + MS_ASSERT(out_tensor); + + auto out_quant_args = out_tensor->GetQuantParams(); + quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale; + quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + + int inner_size = 1; + for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) { + inner_size *= softmax_param_->input_shape_[i]; + } + + exp_data_ = reinterpret_cast(malloc(softmax_param_->element_size_ * sizeof(float))); + sum_data_ = reinterpret_cast(malloc(inner_size * sizeof(float))); + return RET_OK; +} + +int SoftmaxInt8CPUKernel::ReSize() { return RET_OK; } + +int SoftmaxInt8CPUKernel::DoSoftmax(int task_id) { + MS_ASSERT(inputs_.size() == 1); + MS_ASSERT(outputs_.size() == 1); + + auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + + int outter_size = 1, inner_size = 1; + for (int i = 0; i < softmax_param_->axis_; i++) { + outter_size *= softmax_param_->input_shape_[i]; + } + for (int i = softmax_param_->axis_; i < softmax_param_->n_dim_; i++) { + inner_size *= softmax_param_->input_shape_[i]; + } + + int stride = UP_DIV(outter_size, thread_count_); + int count = MSMIN(stride, outter_size - stride * task_id); + + input_ptr += stride * task_id * inner_size; + output_ptr += stride * task_id * inner_size; + exp_data_ += stride * task_id * inner_size; + + auto error_code = Softmax(input_ptr, output_ptr, count, exp_data_, sum_data_, quant_params_, softmax_param_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DoSoftmax error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SoftmaxRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto softmax_kernel = reinterpret_cast(cdata); + auto error_code = softmax_kernel->DoSoftmax(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "SoftmaxRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SoftmaxInt8CPUKernel::Run() { + auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); + int ele_size = softmax_param_->element_size_; + for (int i = 0; i < ele_size; i++) { + float input_scaled = ((input_ptr[i] - quant_params_.in_quant_args_.zp_) * quant_params_.in_quant_args_.scale_); + exp_data_[i] = exp(input_scaled); + } + int error_code = LiteBackendParallelLaunch(SoftmaxRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Softmax function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h new file mode 100644 index 0000000000..d29da4094a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ + +#include +#include "src/runtime/kernel/arm/base/softmax_base.h" + +namespace mindspore::kernel { +class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel { + public: + SoftmaxInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~SoftmaxInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoSoftmax(int task_id); + + private: + float *sum_data_; + float *exp_data_; + SoftmaxQuantArg quant_params_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SOFTMAX_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc new file mode 100644 index 0000000000..80f9cdf61c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/split_int8.h" +#include +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" +#include "src/runtime/kernel/arm/nnacl/int8/split_int8.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int SplitInt8CPUKernel::Init() { + SplitBaseCPUKernel::Init(); + auto in_tensor = inputs_.at(kInputIndex); + input_ptr_ = reinterpret_cast(in_tensor->Data()); + for (int i = 0; i < param->num_split_; i++) { + output_ptr_.push_back(reinterpret_cast(outputs_.at(i)->Data())); + } + + auto in_quant_args = in_tensor->GetQuantParams(); + param->quant_arg_.in_args_.scale_ = in_quant_args.front().scale; + param->quant_arg_.in_args_.zp_ = in_quant_args.front().zeroPoint; + + MS_ASSERT(param->num_split_ == outputs_.size()); + for (int i = 0; i < param->num_split_; i++) { + auto *out_tensor = outputs_.at(i); + auto out_quant_args = out_tensor->GetQuantParams(); + param->quant_arg_.out_args_[i].scale_ = out_quant_args.front().scale; + param->quant_arg_.out_args_[i].zp_ = out_quant_args.front().zeroPoint; + } + + param->quant_arg_.output_activation_max_ = std::numeric_limits::max(); + param->quant_arg_.output_activation_min_ = std::numeric_limits::min(); + + return RET_OK; +} + +int SplitInt8CPUKernel::ReSize() { return RET_OK; } + +int SplitInt8CPUKernel::Split(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_n_stride_; + auto ret = + DoSplit(input_ptr_, output_ptr_.data(), inputs_.front()->shape().data(), thread_offset, num_unit_thread, param); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Split error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SplitInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->Split(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SplitRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SplitInt8CPUKernel::Run() { + int ret = LiteBackendParallelLaunch(SplitInt8Run, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h new file mode 100644 index 0000000000..501d595634 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SPLIT_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SPLIT_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/split_base.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class SplitInt8CPUKernel : public SplitBaseCPUKernel { + public: + SplitInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : SplitBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~SplitInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int Split(int tId); + + private: + int8_t *input_ptr_; + std::vector output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SPLIT_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc new file mode 100644 index 0000000000..2280d9e078 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/topk_int8.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_TopK; + +namespace mindspore::kernel { +int TopKInt8CPUKernel::Init() { + TopkParameter *parameter = reinterpret_cast(opParameter); + lite::tensor::Tensor *input = inputs_.at(0); + parameter->last_dim_size_ = input->shape()[input->shape().size() - 1]; + parameter->loop_num_ = 1; + for (int i = 0; i < input->shape().size() - 1; ++i) { + parameter->loop_num_ *= input->shape()[i]; + } + + parameter->topk_node_list_ = malloc(sizeof(TopkNodeInt8) * parameter->last_dim_size_); + if (parameter->topk_node_list_ == nullptr) { + MS_LOG(ERROR) << "malloc fail."; + return RET_ERROR; + } + return RET_OK; +} + +int TopKInt8CPUKernel::ReSize() { return RET_OK; } + +int TopKInt8CPUKernel::Run() { + int8_t *input_data = reinterpret_cast(inputs_.at(0)->Data()); + int8_t *output_data = reinterpret_cast(outputs_.at(0)->Data()); + int32_t *output_index = reinterpret_cast(outputs_.at(1)->Data()); + + TopkInt8(input_data, output_data, output_index, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuTopKInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + auto *kernel = new (std::nothrow) TopKInt8CPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new TopKInt8CPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_TopK, CpuTopKInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h new file mode 100644 index 0000000000..1216033e9c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/int8/topk_int8.h" + +namespace mindspore::kernel { +class TopKInt8CPUKernel : public LiteKernel { + public: + explicit TopKInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~TopKInt8CPUKernel() override { + TopkParameter *parameter = reinterpret_cast(opParameter); + free(parameter->topk_node_list_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/nnacl/CMakeLists.txt new file mode 100644 index 0000000000..9836fd3c8d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/CMakeLists.txt @@ -0,0 +1,43 @@ +project(nnacl) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) +include_directories(NNACL_DIR) + +########################### optimized files ########################### +file(GLOB OPTIMIZED_ASSEMBLY + ${NNACL_DIR}/assembly/opt/*.s + ${NNACL_DIR}/assembly/opt/*.S + ) + +file(GLOB FP16_SRC + ${NNACL_DIR}/fp16/*.cc + ${NNACL_DIR}/../fp16/*.cc + ) + +########################### share library build ######################## +set(OPTIMIZED_OPS ${NNACL_DIR}/opt_op_handler.c) + +set_property(SOURCE ${OPTIMIZED_ASSEMBLY} PROPERTY LANGUAGE C) +list(APPEND OPTIMIZED_OPS ${OPTIMIZED_ASSEMBLY} ${FP16_SRC}) + +if (PLATFORM_ARM64) + string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + add_library(optimize SHARED ${OPTIMIZED_OPS}) + target_link_libraries( + optimize + mindspore-lite + ) + set_target_properties(optimize PROPERTIES CLEAN_DIRECT_OUTPUT 1) + + add_custom_command(TARGET optimize POST_BUILD + COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip + ${TOP_DIR}/build/src/runtime/kernel/arm/nnacl/liboptimize.so) + + add_custom_command(TARGET optimize POST_BUILD + COMMAND rm -rf ${TOP_DIR}/output/lib/liboptimize.so + COMMAND mkdir -pv ${TOP_DIR}/output/lib + COMMAND cp ${TOP_DIR}/build/src/runtime/kernel/arm/nnacl/liboptimize.so ${TOP_DIR}/output/lib) +endif () diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/activation_grad.h new file mode 100644 index 0000000000..c0ebfc368a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/activation_grad.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ACTIVATION_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ACTIVATION_GRAD_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +struct ActivationGradParameter { + OpParameter op_parameter{}; + int type_; + float alpha_{0.01}; +}; + +inline int ReluGrad(float *src0, float *src1, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = src1[i] > 0 ? 1.0f : 0.0f; + } + ElementMul(src0, dst, dst, length); + return NNACL_OK; +} + +inline int Relu6Grad(float *src0, float *src1, int length, float *dst) { + for (int i = 0; i < length; ++i) { + if (src1[i] < 0) { + dst[i] = 0; + } else { + dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f; + } + } + ElementMul(src0, dst, dst, length); + return NNACL_OK; +} + +inline int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) { + for (int i = 0; i < length; ++i) { + dst[i] = src1[i] > 0.0f ? 1.0f : alpha; + } + ElementMul(src0, dst, dst, length); + return NNACL_OK; +} + +inline int SigmoidGrad(float *src0, float *src1, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +inline int TanhGrad(float *src0, float *src1, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i]; + } + return NNACL_OK; +} + +inline int HSwishGrad(float *src0, float *src1, int length, float *dst) { + for (int i = 0; i < length; ++i) { + float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +inline int HSigmoidGrad(float *src0, float *src1, int length, float *dst) { + for (int i = 0; i < length; ++i) { + float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ACTIVATION_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.cc new file mode 100644 index 0000000000..a11e76f77c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/add_int8.h" +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +#ifdef ENABLE_NEON +int16x8_t LoadAndAddOffset(int8_t *data, int index, int offset) { + int8x8_t input_s8 = vld1_s8(data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + return vaddq_s16(input_s16, vdupq_n_s16(offset)); +} + +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec) { + int32x4_t shifted_input = vmulq_s32(input, left_shift_result_vec); + shifted_input = vqrdmulhq_s32(shifted_input, input_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(shifted_input, right_shift_vec), 31); + return vrshlq_s32(vqaddq_s32(shifted_input, fixup), right_shift_vec); +} + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, AddQuantParameter *para) { + int32x4_t raw_sum = vaddq_s32(scaled_input0, scaled_input1); + + raw_sum = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_sum, left_shift_out_vec), output_multiplier_vec), + para->right_shift_out_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para->output_offset_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para->output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void AddInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + AddQuantParameter *para, int *index) { + int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_); + int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_); + int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_); + int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_); + int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32((1 << para->left_shift_out_)); + int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_); + int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_); + + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->input0_offset_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->input1_offset_); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int32x4_t scaled_input0_low = + ClacScaledInput(input0_low, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input0_high = + ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input1_low = + ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + int32x4_t scaled_input1_high = + ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + + int16x4_t sum_low = + ClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + ClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data + *index, res_u8_n0); + } +} +#endif + +void AddInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + AddQuantParameter *para) { + int index = 0; +#ifdef ENABLE_NEON + AddInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->input0_offset_ + input0_data[index]; + const int32_t input1_val = para->input1_offset_ + input1_data[index]; + const int32_t shifted_input0_val = input0_val * para->left_shift_result0_; + const int32_t shifted_input1_val = input1_val * para->left_shift_result1_; + const int32_t scaled_input0_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_); + const int32_t scaled_input1_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_); + + const int32_t raw_sum = scaled_input0_val + scaled_input1_val; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_sum * (1 << (unsigned int)para->left_shift_out_), + para->output_multiplier_), + para->right_shift_out_) + + para->output_offset_; + + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h new file mode 100644 index 0000000000..ee667f1713 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/add_int8.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ADD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ADD_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct AddQuantParameter { + int input0_offset_; + int input1_offset_; + int output_offset_; + float input0_scale_; + float input1_scale_; + float output_scale_; + int input0_multiplier_; + int input1_multiplier_; + int output_multiplier_; + int input0_shift_; + int input1_shift_; + int output_shift_; + int output_activation_min_; + int output_activation_max_; + int left_shift_result0_; + int left_shift_result1_; + int right_shift0_; + int right_shift1_; + int left_shift_out_; + int right_shift_out_; +}; + +void AddInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + AddQuantParameter *para); + +#ifdef ENABLE_NEON +#include +int16x8_t LoadAndAddOffset(int8_t *data, int index, int offset); +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ADD_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.cc new file mode 100644 index 0000000000..fd1a6f7c72 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/arg_min_max.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h" + +#define FLOAT_DATA_TYPE 43 + +void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count, + int *after_axis_count) { + *pre_axis_count = 1; + for (int i = 0; i < axis; ++i) { + *pre_axis_count = (*pre_axis_count) * shape[i]; + } + + *axis_count = shape[axis]; + + *after_axis_count = 1; + for (int i = axis + 1; i < dims_number; ++i) { + *after_axis_count = (*after_axis_count) * shape[i]; + } +} + +void ArgMinMaxTopk1(const void *input, void *output, const int *shape, ArgMinMaxParameter *param) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + GetCalcParameter(shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + switch (param->data_type_) { + case FLOAT_DATA_TYPE: { + if (param->get_max_) { + ArgMax(reinterpret_cast(input), reinterpret_cast(output), param, pre_axis_count, + axis_count, after_axis_count); + } else { + ArgMin(reinterpret_cast(input), reinterpret_cast(output), param, pre_axis_count, + axis_count, after_axis_count); + } + break; + } + default: + break; + } +} + +void ArgMinMaxTopknFp32(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->get_max_) { + switch (param->axis_) { + case 0: + ArgMaxDim0(input, output, in_shape, param); + break; + case 1: + ArgMaxDim1(input, output, in_shape, param); + break; + case 2: + ArgMaxDim2(input, output, in_shape, param); + break; + case 3: + ArgMaxDim3(input, output, in_shape, param); + break; + } + } else { + switch (param->axis_) { + case 0: + ArgMinDim0(input, output, in_shape, param); + break; + case 1: + ArgMinDim1(input, output, in_shape, param); + break; + case 2: + ArgMinDim2(input, output, in_shape, param); + break; + case 3: + ArgMinDim3(input, output, in_shape, param); + break; + } + } +} + +void ArgMinMax(const void *input, void *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->topk_ == 1) { + ArgMinMaxTopk1(input, output, in_shape, param); + return; + } + + switch (param->data_type_) { + case FLOAT_DATA_TYPE: { + ArgMinMaxTopknFp32(reinterpret_cast(input), reinterpret_cast(output), in_shape, param); + return; + } + default: + break; + } +} + +#undef FLOAT_DATA_TYPE +#undef INT8_DATA_TYPE diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.h new file mode 100644 index 0000000000..a308f65c51 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARG_MIN_MAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARG_MIN_MAX_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/arg_min_max_parameter.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +void ArgMinMax(const void *input, void *output, const int *in_shape, ArgMinMaxParameter *param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARG_MIN_MAX_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max_parameter.h new file mode 100644 index 0000000000..c403f2a7c5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max_parameter.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARG_MIN_MAX_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARG_MIN_MAX_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ArgElement { + uint32_t index_; + union ArgData { + int8_t i8_data_; + int32_t i_data_; + float f_data_; + } data_; +}; + +struct ArgMinMaxParameter { + OpParameter op_parameter_; + bool out_value_; + bool keep_dims_; + bool get_max_; + int32_t axis_; + int32_t topk_; + int32_t axis_type_; + int32_t dims_size_; + int32_t data_type_; // equals to type_id + int32_t in_strides_[DIMENSION_4D]; + int32_t out_strides_[DIMENSION_4D]; + ArgElement *arg_elements_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARG_MIN_MAX_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.cc new file mode 100644 index 0000000000..b136fd8d42 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +void TileOneDimension(float *inData, float *outData, int dim, size_t ndim, int *inShape, int *inStrides, + int *outStrides, int *multiple) { + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimension(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileOneDimensionUint8(uint8_t *inData, uint8_t *outData, int dim, size_t ndim, int *inShape, int *inStrides, + int *outStrides, int *multiple) { + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(uint8_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionUint8(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, + ndim, inShape, inStrides, outStrides, multiple); + } + } +} + +void ComputeStrides(int *shape, int *strides, int ndim) { + int stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } +} + +void CalcMultiplesAndStrides(ArithmeticParameter *param) { + for (auto i = 0; i < param->ndim_; i++) { + param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; + param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + } + // cal strides + ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_); + ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_); + ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_); +} + +void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimension(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimension(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionUint8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionUint8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionUint8((uint8_t *)(data0), (uint8_t *)(tile_data0), 0, param->ndim_, + param->in_shape0_, param->in_strides0_, param->out_strides_, param->multiples0_); + TileOneDimensionUint8((uint8_t *)(data1), (uint8_t *)(tile_data1), 0, param->ndim_, + param->in_shape1_, param->in_strides1_, param->out_strides_, param->multiples1_); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h new file mode 100644 index 0000000000..b0e52b8694 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +struct ArithmeticParameter { + OpParameter op_parameter_; + bool broadcasting_; + size_t ndim_; + int activation_type_; + int in_shape0_[5]; + int in_shape1_[5]; + int out_shape_[5]; + + int in_strides0_[5]; + int in_strides1_[5]; + int out_strides_[5]; + + int multiples0_[5]; + int multiples1_[5]; +}; +void TileOneDimension(float *inData, float *outData, int dim, size_t ndim, int *inShape, int *inStrides, + int *outStrides, int *multiple); +void ComputeStrides(int *shape, int *strides, int ndim); + +void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param); +void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1, + ArithmeticParameter *param); +void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_parameter.h new file mode 100644 index 0000000000..132ad4a0ab --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARTITHMETIC_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARTITHMETIC_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_attribute.h" + + + + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARTITHMETIC_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h new file mode 100644 index 0000000000..35669c5518 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_SELF_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_SELF_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops. +struct ArithmeticSelfParameter { + OpParameter op_parameter_; + ArithSelfQuantArg quant_arg_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_SELF_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/ConvDwFp32Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/ConvDwFp32Center.S new file mode 100644 index 0000000000..c8398ca03d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/ConvDwFp32Center.S @@ -0,0 +1,161 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global ConvDwFp32Center +#ifndef __APPLE__ +.type ConvDwFp32Center, %function +#endif + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// r0: dst, r1: src, r2: weight, r3: bias, #48: height, #52: weight, #56: kernel_h, #60: kernel_w, +// #64: out_h_step, #68: block_channel, #72: in_sh_step, #76: in_sw_step, #80: in_kh_step,#84: in_kw_step +// #88: relu, #92: relu6 +ConvDwFp32Center: + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r4, [sp, #48] + + vld1.32 {q13}, [r3] + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + LoopH: + ldr r1, [sp, #4] // src_w + ldr r5, [sp, #52] // width + ldr r0, [sp] // dst_w + cmp r5, #4 + blt LoopW + LoopW4: + ldr r11, [sp, #76] // in_sw_step + mov r8, r1 // src_kh + ldr r2, [sp, #8] // weight_kh + ldr r6, [sp, #56] // kernel_h + vmov q0, q13 + LoopKh4: + ldr r12, [sp, #80] //in_kh_step + ldr r7, [sp, #60] // kernel_w + mov lr, r8 // src_kw + LoopKw4: + mov r10, lr + vld1.32 {q12}, [r2]! + vld1.32 {q4}, [r10] + add r10, r10, r11 + vmla.f32 q0, q4, q12 + vld1.32 {q5}, [r10] + add r10, r10, r11 + vmla.f32 q1, q5, q12 + vld1.32 {q6}, [r10] + add r10, r10, r11 + vmla.f32 q2, q6, q12 + vld1.32 {q7}, [r10] + add r10, r10, r11 + vmla.f32 q3, q7, q12 + subs r7, r7, #1 + add lr, lr, r12 + bne LoopKw4 + ldr r12, [sp, #80] + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh4 + ldr r12, [sp, #92] + cmp r12, #0 + bne Relu64 + ldr r12, [sp, #88] + cmp r12, #0 + bne Relu4 + b Write4 + Relu64: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + Relu4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 + Write4: + ldr r12, [sp, #68] + vst1.32 {q0}, [r0] + add r0, r0, r12 + vst1.32 {q1}, [r0] + add r0, r0, r12 + vst1.32 {q2}, [r0] + add r0, r0, r12 + vst1.32 {q3}, [r0] + add r0, r0, r12 + mov r12, #4 + mul r11, r11, r12 + add r1, r1, r11 + sub r5, r5, #4 + cmp r5, #0 + ble LoopWEnd + cmp r5, #4 + bge LoopW + LoopW: + mov r8, r1 // src_kh + ldr r2, [sp, #8] // weight_kh + ldr r6, [sp, #56] // kernel_h + vmov q0, q13 + LoopKh: + ldr r12, [sp, #84] //in_kw_step + ldr r7, [sp, #60] // kernel_w + mov r10, r8 // src_kw + LoopKw: + vld1.32 {q1}, [r10] + add r10, r10, r12 + vld1.32 {q12}, [r2]! + vmla.f32 q0, q1, q12 + subs r7, r7, #1 + bne LoopKw + ldr r12, [sp, #80] + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh + ldr r12, [sp, #92] + cmp r12, #0 + bne Relu6 + ldr r12, [sp, #88] + cmp r12, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q14 + Relu: + vmax.f32 q0, q0, q15 + Write: + ldr r12, [sp, #68] + vst1.32 {q0}, [r0] + add r0, r0, r12 + ldr r12, [sp, #76] + add r1, r1, r12 + subs r5, r5, #1 + bne LoopW + ldr r3, [sp, #64] + ldr r12, [sp] + add r12, r12, r3 + str r12, [sp] + ldr r3, [sp, #72] + ldr r12, [sp, #4] + add r12, r12, r3 + str r12, [sp, #4] + subs r4, r4, #1 + bne LoopH +LoopWEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/ConvDwInt8Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/ConvDwInt8Center.S new file mode 100644 index 0000000000..2f75feaa19 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/ConvDwInt8Center.S @@ -0,0 +1,210 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global ConvDwInt8Center +#ifndef __APPLE__ +.type ConvDwInt8Center, %function +#endif + +// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift, +// int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: src, r2: weight, r3: bias, #48: height, #52: width, #56: kernel_h, #60: kernel_w, +// #64: out_h_step, #68: block_channel, #72: in_sh_step, #76: in_sw_step, #80: in_kh_step,#84: in_kw_step +// #88: out_multiplier, #92: left_shift, #96: right_shift, #100: out_zp, #104: acc_min, #108: acc_max +ConvDwInt8Center: + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r4, [sp, #48] + + ldr r12, [sp, #92] + vdup.32 q9, r12 + + ldr r11, [sp, #88] + vdup.32 q10, r11 + + ldr r10, [sp, #96] + vdup.32 q11, r10 + + ldr r8, [sp, #100] + vdup.32 q12, r8 + + ldr r7, [sp, #104] + vdup.32 q13, r7 + + ldr r6, [sp, #108] + vdup.32 q14, r6 + + vld1.32 {q15}, [r3] + + LoopH: + ldr r1, [sp, #4] // src_w + ldr r5, [sp, #52] // width + ldr r0, [sp] // dst_w + LoopW4: + ldr r11, [sp, #76] // in_sw_step + mov r8, r1 // src_kh + ldr r2, [sp, #8] // weight_kh + ldr r6, [sp, #56] // kernel_h + vmov q0, q15 + LoopKh4: + ldr r12, [sp, #80] //in_kh_step + ldr r7, [sp, #60] // kernel_w + mov r10, r8 // src_kw + LoopKw4: + vld1.16 {d24}, [r2]! + vld1.16 {d8}, [r10] + add r10, r10, r11 + vmlal.s16 q0, d8, d24 + vld1.16 {d10}, [r10] + add r10, r10, r11 + vmlal.s16 q1, d10, d24 + vld1.16 {d12}, [r10] + add r10, r10, r11 + vmlal.s16 q2, d12, d24 + vld1.16 {d14}, [r10] + add r10, r10, r11 + vmlal.s16 q3, d14, d24 + subs r7, r7, #1 + bne LoopKw4 + ldr r12, [sp, #80] + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh4 + + vshl.s32 q0, q0, q9 + vshl.s32 q1, q1, q9 + vshl.s32 q2, q2, q9 + vshl.s32 q3, q3, q9 + vqrdmulh.s32 q0, q0, q10 + vqrdmulh.s32 q1, q1, q10 + vqrdmulh.s32 q2, q2, q10 + vqrdmulh.s32 q3, q3, q10 + vrshl.s32 q0, q0, q11 + vrshl.s32 q1, q1, q11 + vrshl.s32 q2, q2, q11 + vrshl.s32 q3, q3, q11 + vadd.i32 q0, q0, q12 + vadd.i32 q1, q1, q12 + vadd.i32 q2, q2, q12 + vadd.i32 q3, q3, q12 + vmax.s32 q0, q0, q13 + vmax.s32 q1, q1, q13 + vmax.s32 q2, q2, q13 + vmax.s32 q3, q3, q13 + vmin.s32 q0, q0, q14 + vmin.s32 q1, q1, q14 + vmin.s32 q2, q2, q14 + vmin.s32 q3, q3, q14 + + vqmovn.s32 d0, q0 + vqmovn.s32 d2, q1 + vqmovn.s32 d4, q2 + vqmovn.s32 d6, q3 + vqmovn.s16 d0, q0 + vqmovn.s16 d2, q1 + vqmovn.s16 d4, q2 + vqmovn.s16 d6, q3 + + mov r3, r0 + ldr r12, [sp, #68] + vst1.8 {d0[0]}, [r3]! + vst1.8 {d0[1]}, [r3]! + vst1.8 {d0[2]}, [r3]! + vst1.8 {d0[3]}, [r3]! + add r0, r0, r12 + mov r3, r0 + vst1.8 {d2[0]}, [r3]! + vst1.8 {d2[1]}, [r3]! + vst1.8 {d2[2]}, [r3]! + vst1.8 {d2[3]}, [r3]! + add r0, r0, r12 + mov r3, r0 + vst1.8 {d4[0]}, [r3]! + vst1.8 {d4[1]}, [r3]! + vst1.8 {d4[2]}, [r3]! + vst1.8 {d4[3]}, [r3]! + add r0, r0, r12 + mov r3, r0 + vst1.8 {d6[0]}, [r3]! + vst1.8 {d6[1]}, [r3]! + vst1.8 {d6[2]}, [r3]! + vst1.8 {d6[3]}, [r3]! + add r0, r0, r12 + mov r3, r0 + mov r12, #4 + mul r11, r11, r12 + add r1, r1, r11 + sub r5, r5, #4 + cmp r5, #0 + ble LoopWEnd + cmp r5, #4 + bge LoopW4 + LoopW: + mov r8, r1 // src_kh + ldr r2, [sp, #8] // weight_kh + ldr r6, [sp, #56] // kernel_h + vmov q0, q15 + LoopKh: + ldr r12, [sp, #84] //in_kw_step + ldr r7, [sp, #60] // kernel_w + mov r10, r8 // src_kw + LoopKw: + vld1.16 {d2}, [r10] + add r10, r10, r12 + vld1.16 {d24}, [r2]! + vmlal.s16 q0, d2, d24 + subs r7, r7, #1 + bne LoopKw + ldr r12, [sp, #80] + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh + + vshl.s32 q0, q0, q9 + vqrdmulh.s32 q0, q0, q10 + vrshl.s32 q0, q0, q11 + vadd.i32 q0, q0, q12 + vmax.s32 q0, q0, q13 + vmin.s32 q0, q0, q14 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + mov r3, r0 + ldr r12, [sp, #68] + vst1.8 {d0[0]}, [r3]! + vst1.8 {d0[1]}, [r3]! + vst1.8 {d0[2]}, [r3]! + vst1.8 {d0[3]}, [r3]! + add r0, r0, r12 + ldr r12, [sp, #76] + add r1, r1, r12 + subs r5, r5, #1 + bne LoopW + ldr r3, [sp, #64] + ldr r12, [sp] + add r12, r12, r3 + str r12, [sp] + ldr r3, [sp, #72] + ldr r12, [sp, #4] + add r12, r12, r3 + str r12, [sp, #4] + subs r4, r4, #1 + bne LoopH +LoopWEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/DeconvDwFp32Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/DeconvDwFp32Center.S new file mode 100644 index 0000000000..06c38740a5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/DeconvDwFp32Center.S @@ -0,0 +1,69 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global DeconvDwFp32Center +#ifndef __APPLE__ +.type DeconvDwFp32Center, %function +#endif + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +DeconvDwFp32Center: + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.32 {q1}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.32 {q2}, [r2]! + vmla.f32 q0, q1, q2 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/DeconvDwInt8Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/DeconvDwInt8Center.S new file mode 100644 index 0000000000..abae39e13a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/DeconvDwInt8Center.S @@ -0,0 +1,69 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global DeconvDwInt8Center +#ifndef __APPLE__ +.type DeconvDwInt8Center, %function +#endif + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +DeconvDwInt8Center: + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.16 {d2}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.16 {d24}, [r2]! + vmlal.s16 q0, d2, d24 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S new file mode 100644 index 0000000000..215178d35a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S @@ -0,0 +1,302 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global IndirectGemmFp32_8x4 +#ifndef __APPLE__ +.type IndirectGemmFp32_8x4, %function +#endif + +// void IndirectGemmFp32_8x4(float *output, float *input, float *weight, float *bias, +// size_t kSize, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); +// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset +// r8:mode, r10: writeMode, r10: relu, r10:relu6 +// mode = 0 for general convolution, where one conv unit is a row +// mode = 1 for winograd/common gemm, where the total channels of one input is a row +IndirectGemmFp32_8x4: + + .macro INIT_BIAS + veor q8, q8, q8 + cmp r3, #0 + beq InitBias + vld1.32 {q8}, [r3] + InitBias: + vmov q9, q8 + vmov q10, q8 + vmov q11, q8 + vmov q12, q8 + vmov q13, q8 + vmov q14, q8 + vmov q15, q8 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + cmp r8, #0 + bne LoopOc + // step is one for common convolution, where ic8 should multiply by kernel size + // step is (a+b-1) for F(a,b) in winograd + mul r5, r4, r5 + mov r4, #1 + + LoopOc: + mov r8, r4 + mov r12, r1 + + LoopKsize: + + mov r11, r0 + INIT_BIAS + + // load input for output 1-2 + vld1.32 {q0, q1}, [r12]! + vld1.32 {q2, q3}, [r12]! + // load weight + vld1.32 {q4, q5}, [r2]! + // step for output 1-2 + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + + subs r10, r5, #1 + beq LoopIcEnd + + LoopIc: + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + vld1.s32 {q0, q1}, [r12]! + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + vld1.s32 {q2, q3}, [r12]! + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + vld1.s32 {q4, q5}, [r2]! + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vld1.s32 {q0, q1}, [r12]! + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + vld1.s32 {q6, q7}, [r2]! + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.s32 {q2, q3}, [r12]! + + subs r10, r10, #1 + bne LoopIc + + LoopIcEnd: + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + vld1.s32 {q0, q1}, [r12]! + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + vld1.s32 {q2, q3}, [r12]! + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + + ldr r10, [sp, #28] + cmp r10, #0 + bne Relu6 + ldr r10, [sp, #24] + cmp r10, #0 + bne Relu + b WriteStart + Relu6: + vmov.i32 q7, #6 + vcvt.f32.s32 q7, q7 + vmin.f32 q8, q8, q7 + vmin.f32 q9, q9, q7 + vmin.f32 q10, q10, q7 + vmin.f32 q11, q11, q7 + vmin.f32 q12, q12, q7 + vmin.f32 q13, q13, q7 + vmin.f32 q14, q14, q7 + vmin.f32 q15, q15, q7 + Relu: + veor q7, q7, q7 + vmax.f32 q8, q8, q7 + vmax.f32 q9, q9, q7 + vmax.f32 q10, q10, q7 + vmax.f32 q11, q11, q7 + vmax.f32 q12, q12, q7 + vmax.f32 q13, q13, q7 + vmax.f32 q14, q14, q7 + vmax.f32 q15, q15, q7 + + WriteStart: + ldr r10, [sp, #20] + cmp r10, #0 + bne Write4 + cmp r6, #1 + beq Write1 + cmp r6, #2 + beq Write2 + cmp r6, #3 + beq Write3 + b Write4 + Write1: + vst1.32 d16[0], [r11] + add r11, r11, r7 + vst1.32 d18[0], [r11] + add r11, r11, r7 + vst1.32 d20[0], [r11] + add r11, r11, r7 + vst1.32 d22[0], [r11] + add r11, r11, r7 + vst1.32 d24[0], [r11] + add r11, r11, r7 + vst1.32 d26[0], [r11] + add r11, r11, r7 + vst1.32 d28[0], [r11] + add r11, r11, r7 + vst1.32 d30[0], [r11] + add r11, r11, r7 + add r0, r0, #4 + b WriteEnd + Write2: + vst1.32 d16, [r11] + add r11, r11, r7 + vst1.32 d18, [r11] + add r11, r11, r7 + vst1.32 d20, [r11] + add r11, r11, r7 + vst1.32 d22, [r11] + add r11, r11, r7 + vst1.32 d24, [r11] + add r11, r11, r7 + vst1.32 d26, [r11] + add r11, r11, r7 + vst1.32 d28, [r11] + add r11, r11, r7 + vst1.32 d30, [r11] + add r11, r11, r7 + add r0, r0, #8 + b WriteEnd + Write3: + add lr, r11, #8 + vst1.32 d16, [r11] + add r11, r11, r7 + vst1.32 d17[0], [lr] + add lr, lr, r7 + vst1.32 d18, [r11] + add r11, r11, r7 + vst1.32 d19[0], [lr] + add lr, lr, r7 + vst1.32 d20, [r11] + add r11, r11, r7 + vst1.32 d21[0], [lr] + add lr, lr, r7 + vst1.32 d22, [r11] + add r11, r11, r7 + vst1.32 d23[0], [lr] + add lr, lr, r7 + vst1.32 d24, [r11] + add r11, r11, r7 + vst1.32 d25[0], [lr] + add lr, lr, r7 + vst1.32 d26, [r11] + add r11, r11, r7 + vst1.32 d27[0], [lr] + add lr, lr, r7 + vst1.32 d28, [r11] + add r11, r11, r7 + vst1.32 d29[0], [lr] + add lr, lr, r7 + vst1.32 d30, [r11] + add r11, r11, r7 + vst1.32 d31[0], [lr] + add lr, lr, r7 + add r0, r0, #12 + b WriteEnd + Write4: + // prefetching is not prefered while writing results in spite of cache missings + // you could try pld + // there are almost no benefits observed though + vst1.32 {q8}, [r11], r7 + vst1.32 {q9}, [r11], r7 + vst1.32 {q10}, [r11], r7 + vst1.32 {q11}, [r11], r7 + vst1.32 {q12}, [r11], r7 + vst1.32 {q13}, [r11], r7 + vst1.32 {q14}, [r11], r7 + vst1.32 {q15}, [r11], r7 + add r0, r0, #16 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + cmp r6, #4 + ble LoopOcEnd + sub r6, r6, #4 + cmp r3, #0 + beq NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + b LoopOc + +LoopOcEnd: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmInt16to32_8x4.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmInt16to32_8x4.S new file mode 100644 index 0000000000..eaf11da242 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,238 @@ +#ifdef ENABLE_ARM32 + +.text +.align 5 +.global IndirectGemmInt16to32_8x4 +#ifndef __APPLE__ +.type IndirectGemmInt16to32_8x4, %function +#endif + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t kszie, size_t ic8, size_t oc4, size_t offset); +// r0: output, r1: input, r2: weight, r3: kszie, r4: ic8, r5: oc4, r6: offset +IndirectGemmInt16to32_8x4: + + .macro INIT_ZERO + // we could also use "vmov.s32 q12, #0" to initialize q12 by 0 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, lr} + + ldr r4, [sp, #28] + ldr r5, [sp, #32] + ldr r6, [sp, #36] + + vpush {q4-q7} + + LoopOc: + + mov r7, r3 + mov r8, r1 + + LoopKsize: + mov r10, r0 + INIT_ZERO + + // load input + vld1.16 {q0, q1}, [r8]! + // load weight + vld1.16 {q4}, [r2]! + vmull.s16 q8, d8, d0[0] + vmull.s16 q9, d8, d2[0] + // load weight + vld1.16 {q5}, [r2]! + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + // load weight + vld1.16 {q6, q7}, [r2]! + vmull.s16 q10, d8, d4[0] + vmull.s16 q11, d8, d6[0] + + subs r12, r4, #1 + beq LoopIcEnd + + LoopIc: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + // load weight + vld1.16 {q4, q5}, [r2]! + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d8, d0[0] + vmlal.s16 q9, d8, d2[0] + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load weight + vld1.16 {q6, q7}, [r2]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + vmlal.s16 q10, d8, d4[0] + vmlal.s16 q11, d8, d6[0] + + subs r12, r12, #1 + bne LoopIc + + LoopIcEnd: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vst1.32 {q8}, [r10], r6 + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vst1.32 {q9}, [r10], r6 + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.s16 {q2, q3}, [r8]! + vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vst1.32 {q10}, [r10], r6 + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vst1.32 {q11}, [r10], r6 + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + vst1.32 {q12}, [r10], r6 + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vst1.32 {q13}, [r10], r6 + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + + vst1.32 {q14}, [r10], r6 + vst1.32 {q15}, [r10] + + subs r7, r7, #1 + add r0, r0, #16 + bne LoopKsize + + subs r5, r5, #1 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, r10, pc} + +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S new file mode 100644 index 0000000000..49f3e34ff2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S @@ -0,0 +1,243 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt8_2x4 +#ifndef __APPLE__ +.type IndirectGemmInt8_2x4, %function +#endif + +// void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, +// size_t shift_before, size_t shift_after); +// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset +// r8: input_sum, r10: act_min, r11: act_max, r10: out_zp, r11: out_multiplier, r10: shift_before, r11: shift_after +IndirectGemmInt8_2x4: + + .macro INIT_BIAS + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mul r5, r4, r5 + mov r4, #1 + + LoopOc: + + mov r8, r4 + mov r12, r1 + + LoopKsize: + INIT_BIAS + mov r11, r0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-2 + vld1.8 {q0, q1}, [r12]! + // load weight for oc 1-2 + vld1.8 {q2, q3}, [r2]! + vmull.s8 q6, d0, d4 + vmull.s8 q7, d0, d6 + vmlal.s8 q6, d1, d5 + vmlal.s8 q7, d1, d7 + vpaddl.s16 q8, q6 + vpaddl.s16 q9, q7 + // load weight for oc 3-4 + vld1.8 {q4, q5}, [r2]! + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs r10, r5, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1 + vld1.8 {q0}, [r12]! + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vld1.8 {q2, q3}, [r2]! + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vld1.8 {q4, q5}, [r2]! + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + vmull.s8 q6, d0, d4 + vmull.s8 q7, d0, d6 + vmlal.s8 q6, d1, d5 + vmlal.s8 q7, d1, d7 + vld1.8 {q1}, [r12]! + vpadal.s16 q8, q6 + vpadal.s16 q9, q7 + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs r10, r10, #1 + bne LoopIc + + LoopIcEnd: + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + + // load sum + ldr r10, [sp, #16] + vld1.32 q0[], [r10]! + vld1.32 q1[], [r10]! + // pairwise add + vpadd.i32 d16, d16, d17 + vpadd.i32 d18, d18, d19 + vpadd.i32 d20, d20, d21 + vpadd.i32 d22, d22, d23 + vpadd.i32 d24, d24, d25 + vpadd.i32 d26, d26, d27 + vpadd.i32 d28, d28, d29 + vpadd.i32 d30, d30, d31 + + vpadd.i32 d16, d16, d18 + vpadd.i32 d17, d20, d22 + vpadd.i32 d24, d24, d26 + vpadd.i32 d25, d28, d30 + vsub.i32 q8, q8, q0 + vsub.i32 q12, q12, q1 + cmp r3, #0 + beq NoBias + vld1.32 q2, [r3] + vadd.i32 q8, q8, q2 + vadd.i32 q12, q12, q2 + + NoBias: + ldr r10, [sp, #36] + vdup.32 q3, r10 + vshl.s32 q8, q8, q3 + vshl.s32 q12, q12, q3 + + ldr r10, [sp, #32] + vdup.32 q4, r10 + vqrdmulh.s32 q8, q8, q4 + vqrdmulh.s32 q12, q12, q4 + + ldr r10, [sp, #40] + vdup.32 q5, r10 + vrshl.s32 q8, q8, q5 + vrshl.s32 q12, q12, q5 + + ldr r10, [sp, #28] + vdup.32 q6, r10 + vadd.i32 q8, q8, q6 + vadd.i32 q12, q12, q6 + + ldr r10, [sp, #20] + vdup.32 q0, r10 + vmax.s32 q8, q8, q0 + vmax.s32 q12, q12, q0 + + ldr r10, [sp, #24] + vdup.32 q1, r10 + vmin.s32 q8, q8, q1 + vmin.s32 q12, q12, q1 + + vqmovn.s32 d30, q8 + vqmovn.s32 d31, q12 + vqmovn.s16 d0, q14 + + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + WriteStart: + cmp r6, #1 + beq Write1 + cmp r6, #2 + beq Write2 + cmp r6, #3 + beq Write3 + b Write4 + Write1: + vst1.8 {d0[0]}, [r11], r7 + vst1.8 {d0[1]}, [r11] + add r0, r0, #1 + b WriteEnd + Write2: + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + add r0, r0, #2 + b WriteEnd + Write3: + add r14, r11, #2 + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + vst1.8 {d0[0]}, [r14], r7 + vst1.8 {d0[1]}, [r14] + add r0, r0, #3 + b WriteEnd + Write4: + vst1.32 {d0[0]}, [r11], r7 + vst1.32 {d0[1]}, [r11] + add r0, r0, #4 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + cmp r6, #4 + ble LoopOcEnd + sub r6, r6, #4 + cmp r3, #0 + beq NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + b LoopOc + +LoopOcEnd: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAdd.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAdd.S new file mode 100644 index 0000000000..d59b0239f1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAdd.S @@ -0,0 +1,131 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global C4BiasAdd +#ifndef __APPLE__ + .type C4BiasAdd, %function +#endif + +//void C4BiasAdd(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride) +//x0: dst, x1: input, x2: bias, x3: oc, x4: plane_size, x5: stride + +C4BiasAdd: + + LoopOc: + ld1 {v4.4s}, [x2], #16 + mov x6, x4 + mov x7, x0 + cmp x6, #4 + blt Loop1 + + Loop4: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v4.4s + fadd v1.4s, v1.4s, v4.4s + fadd v2.4s, v2.4s, v4.4s + fadd v3.4s, v3.4s, v4.4s + + cmp x3, #4 + bge Write4x4 + cmp x3, #3 + beq Write3x4 + cmp x3, #2 + beq Write2x4 + + Write1x4: + str s0, [x7] + add x7, x7, x5 + str s1, [x7] + add x7, x7, x5 + str s2, [x7] + add x7, x7, x5 + str s3, [x7] + add x7, x7, x5 + b WriteEndx4 + Write2x4: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x5 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x5 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x5 + b WriteEndx4 + Write3x4: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + st1 {v0.s}[2], [x8], x5 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x5 + st1 {v1.s}[2], [x8], x5 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x5 + st1 {v2.s}[2], [x8], x5 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x5 + st1 {v3.s}[2], [x8], x5 + b WriteEndx4 + Write4x4: + st1 {v0.4s}, [x7], x5 + st1 {v1.4s}, [x7], x5 + st1 {v2.4s}, [x7], x5 + st1 {v3.4s}, [x7], x5 + + WriteEndx4: + subs x6, x6, #4 + beq LoopOcEnd + cmp x6, #4 + blt Loop1 + b Loop4 + + Loop1: + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v4.4s + + cmp x3, #4 + bge Write4 + cmp x3, #3 + beq Write3 + cmp x3, #2 + beq Write2 + + Write1: + str s0, [x7] + add x7, x7, x5 + b WriteEnd + Write2: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + b WriteEnd + Write3: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + st1 {v0.s}[2], [x8], x5 + b WriteEnd + Write4: + st1 {v0.4s}, [x7], x5 + WriteEnd: + subs x6, x6, #1 + bne Loop1 + LoopOcEnd: + subs x3, x3, #4 + add x0, x0, #16 + bgt LoopOc + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAddRelu.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAddRelu.S new file mode 100644 index 0000000000..6a464b400b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAddRelu.S @@ -0,0 +1,137 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global C4BiasAddRelu +#ifndef __APPLE__ + .type C4BiasAddRelu, %function +#endif + +//void C4BiasAddRelu(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride) +//x0: dst, x1: input, x2: bias, x3: oc, x4: plane_size, x5: stride + +C4BiasAddRelu: + dup v5.4s, wzr + LoopOc: + ld1 {v4.4s}, [x2], #16 + mov x6, x4 + mov x7, x0 + cmp x6, #4 + blt Loop1 + + Loop4: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v4.4s + fadd v1.4s, v1.4s, v4.4s + fadd v2.4s, v2.4s, v4.4s + fadd v3.4s, v3.4s, v4.4s + + fmax v0.4s, v0.4s, v5.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v5.4s + fmax v3.4s, v3.4s, v5.4s + + cmp x3, #4 + bge Write4x4 + cmp x3, #3 + beq Write3x4 + cmp x3, #2 + beq Write2x4 + + Write1x4: + str s0, [x7] + add x7, x7, x5 + str s1, [x7] + add x7, x7, x5 + str s2, [x7] + add x7, x7, x5 + str s3, [x7] + add x7, x7, x5 + b WriteEndx4 + Write2x4: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x5 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x5 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x5 + b WriteEndx4 + Write3x4: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + st1 {v0.s}[2], [x8], x5 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x5 + st1 {v1.s}[2], [x8], x5 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x5 + st1 {v2.s}[2], [x8], x5 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x5 + st1 {v3.s}[2], [x8], x5 + b WriteEndx4 + Write4x4: + st1 {v0.4s}, [x7], x5 + st1 {v1.4s}, [x7], x5 + st1 {v2.4s}, [x7], x5 + st1 {v3.4s}, [x7], x5 + + WriteEndx4: + subs x6, x6, #4 + beq LoopOcEnd + cmp x6, #4 + blt Loop1 + b Loop4 + + Loop1: + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v4.4s + fmax v0.4s, v0.4s, v5.4s + + cmp x3, #4 + bge Write4 + cmp x3, #3 + beq Write3 + cmp x3, #2 + beq Write2 + + Write1: + str s0, [x7] + add x7, x7, x5 + b WriteEnd + Write2: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + b WriteEnd + Write3: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + st1 {v0.s}[2], [x8], x5 + b WriteEnd + Write4: + st1 {v0.4s}, [x7], x5 + WriteEnd: + subs x6, x6, #1 + bne Loop1 + LoopOcEnd: + subs x3, x3, #4 + add x0, x0, #16 + bgt LoopOc + + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAddRelu6.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAddRelu6.S new file mode 100644 index 0000000000..b8c8b84842 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4BiasAddRelu6.S @@ -0,0 +1,146 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global C4BiasAddRelu6 +#ifndef __APPLE__ + .type C4BiasAddRelu6, %function +#endif + +//void C4BiC4BiasAddRelu6asAdd(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride) +//x0: dst, x1: input, x2: bias, x3: oc, x4: plane_size, x5: stride + +C4BiasAddRelu6: + dup v5.4s, wzr + movi v6.4s, #6 + scvtf v6.4s, v6.4s + + LoopOc: + ld1 {v4.4s}, [x2], #16 + mov x6, x4 + mov x7, x0 + cmp x6, #4 + blt Loop1 + + Loop4: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v4.4s + fadd v1.4s, v1.4s, v4.4s + fadd v2.4s, v2.4s, v4.4s + fadd v3.4s, v3.4s, v4.4s + + fmax v0.4s, v0.4s, v5.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v5.4s + fmax v3.4s, v3.4s, v5.4s + + fmin v0.4s, v0.4s, v6.4s + fmin v1.4s, v1.4s, v6.4s + fmin v2.4s, v2.4s, v6.4s + fmin v3.4s, v3.4s, v6.4s + + cmp x3, #4 + bge Write4x4 + cmp x3, #3 + beq Write3x4 + cmp x3, #2 + beq Write2x4 + + Write1x4: + str s0, [x7] + add x7, x7, x5 + str s1, [x7] + add x7, x7, x5 + str s2, [x7] + add x7, x7, x5 + str s3, [x7] + add x7, x7, x5 + b WriteEndx4 + Write2x4: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x5 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x5 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x5 + b WriteEndx4 + Write3x4: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + st1 {v0.s}[2], [x8], x5 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x5 + st1 {v1.s}[2], [x8], x5 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x5 + st1 {v2.s}[2], [x8], x5 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x5 + st1 {v3.s}[2], [x8], x5 + b WriteEndx4 + Write4x4: + st1 {v0.4s}, [x7], x5 + st1 {v1.4s}, [x7], x5 + st1 {v2.4s}, [x7], x5 + st1 {v3.4s}, [x7], x5 + + WriteEndx4: + subs x6, x6, #4 + beq LoopOcEnd + cmp x6, #4 + blt Loop1 + b Loop4 + + Loop1: + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v4.4s + fmax v0.4s, v0.4s, v5.4s + fmin v0.4s, v0.4s, v6.4s + + cmp x3, #4 + bge Write4 + cmp x3, #3 + beq Write3 + cmp x3, #2 + beq Write2 + + Write1: + str s0, [x7] + add x7, x7, x5 + b WriteEnd + Write2: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + b WriteEnd + Write3: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x5 + st1 {v0.s}[2], [x8], x5 + b WriteEnd + Write4: + st1 {v0.4s}, [x7], x5 + WriteEnd: + subs x6, x6, #1 + bne Loop1 + LoopOcEnd: + subs x3, x3, #4 + add x0, x0, #16 + bgt LoopOc + + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4Relu.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4Relu.S new file mode 100644 index 0000000000..70ebb2f295 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4Relu.S @@ -0,0 +1,132 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global C4Relu +#ifndef __APPLE__ + .type C4Relu, %function +#endif + +//void C4Relu(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride) +//x0: dst, x1: input, x2: oc, x3: plane_size, x4: stride + +C4Relu: + dup v5.4s, wzr + LoopOc: + mov x6, x3 + mov x7, x0 + cmp x6, #4 + blt Loop1 + + Loop4: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + + fmax v0.4s, v0.4s, v5.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v5.4s + fmax v3.4s, v3.4s, v5.4s + + cmp x2, #4 + bge Write4x4 + cmp x2, #3 + beq Write3x4 + cmp x2, #2 + beq Write2x4 + + Write1x4: + str s0, [x7] + add x7, x7, x4 + str s1, [x7] + add x7, x7, x4 + str s2, [x7] + add x7, x7, x4 + str s3, [x7] + add x7, x7, x4 + b WriteEndx4 + Write2x4: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x4 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x4 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x4 + b WriteEndx4 + Write3x4: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + st1 {v0.s}[2], [x8], x4 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x4 + st1 {v1.s}[2], [x8], x4 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x4 + st1 {v2.s}[2], [x8], x4 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x4 + st1 {v3.s}[2], [x8], x4 + b WriteEndx4 + Write4x4: + st1 {v0.4s}, [x7], x4 + st1 {v1.4s}, [x7], x4 + st1 {v2.4s}, [x7], x4 + st1 {v3.4s}, [x7], x4 + + WriteEndx4: + subs x6, x6, #4 + beq LoopOcEnd + cmp x6, #4 + blt Loop1 + b Loop4 + + Loop1: + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v4.4s + fmax v0.4s, v0.4s, v5.4s + + cmp x2, #4 + bge Write4 + cmp x2, #3 + beq Write3 + cmp x2, #2 + beq Write2 + + Write1: + str s0, [x7] + add x7, x7, x4 + b WriteEnd + Write2: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + b WriteEnd + Write3: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + st1 {v0.s}[2], [x8], x4 + b WriteEnd + Write4: + st1 {v0.4s}, [x7], x4 + WriteEnd: + subs x6, x6, #1 + bne Loop1 + LoopOcEnd: + subs x2, x2, #4 + add x0, x0, #16 + bgt LoopOc + + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4Relu6.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4Relu6.S new file mode 100644 index 0000000000..d19c22a6d3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/C4Relu6.S @@ -0,0 +1,140 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global C4Relu6 +#ifndef __APPLE__ + .type C4Relu6, %function +#endif + +//void C4Relu6(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride) +//x0: dst, x1: input, x2: oc, x2: plane_size, x3: stride + +C4Relu6: + dup v5.4s, wzr + movi v6.4s, #6 + scvtf v6.4s, v6.4s + + LoopOc: + mov x6, x3 + mov x7, x0 + cmp x6, #4 + blt Loop1 + + Loop4: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fmax v0.4s, v0.4s, v5.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v5.4s + fmax v3.4s, v3.4s, v5.4s + + fmin v0.4s, v0.4s, v6.4s + fmin v1.4s, v1.4s, v6.4s + fmin v2.4s, v2.4s, v6.4s + fmin v3.4s, v3.4s, v6.4s + + cmp x2, #4 + bge Write4x4 + cmp x2, #3 + beq Write3x4 + cmp x2, #2 + beq Write2x4 + + Write1x4: + str s0, [x7] + add x7, x7, x4 + str s1, [x7] + add x7, x7, x4 + str s2, [x7] + add x7, x7, x4 + str s3, [x7] + add x7, x7, x4 + b WriteEndx4 + Write2x4: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x4 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x4 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x4 + b WriteEndx4 + Write3x4: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + st1 {v0.s}[2], [x8], x4 + dup s17, v1.s[1] + stp s1, s17, [x7] + add x7, x7, x4 + st1 {v1.s}[2], [x8], x4 + dup s18, v2.s[1] + stp s2, s18, [x7] + add x7, x7, x4 + st1 {v2.s}[2], [x8], x4 + dup s19, v3.s[1] + stp s3, s19, [x7] + add x7, x7, x4 + st1 {v3.s}[2], [x8], x4 + b WriteEndx4 + Write4x4: + st1 {v0.4s}, [x7], x4 + st1 {v1.4s}, [x7], x4 + st1 {v2.4s}, [x7], x4 + st1 {v3.4s}, [x7], x4 + + WriteEndx4: + subs x6, x6, #4 + beq LoopOcEnd + cmp x6, #4 + blt Loop1 + b Loop4 + + Loop1: + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v4.4s + fmax v0.4s, v0.4s, v5.4s + fmin v0.4s, v0.4s, v6.4s + + cmp x2, #4 + bge Write4 + cmp x2, #3 + beq Write3 + cmp x2, #2 + beq Write2 + + Write1: + str s0, [x7] + add x7, x7, x4 + b WriteEnd + Write2: + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + b WriteEnd + Write3: + add x8, x7, #8 + dup s16, v0.s[1] + stp s0, s16, [x7] + add x7, x7, x4 + st1 {v0.s}[2], [x8], x4 + b WriteEnd + Write4: + st1 {v0.4s}, [x7], x4 + WriteEnd: + subs x6, x6, #1 + bne Loop1 + LoopOcEnd: + subs x2, x2, #4 + add x0, x0, #16 + bgt LoopOc + + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/ConvDwFp32Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/ConvDwFp32Center.S new file mode 100644 index 0000000000..6b51afbe05 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/ConvDwFp32Center.S @@ -0,0 +1,295 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global ConvDwFp32Center +#ifndef __APPLE__ +.type ConvDwFp32Center, %function +#endif + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +ConvDwFp32Center: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr x14, [sp, #48] + ldr x15, [sp, #56] + + ld1 {v24.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x18, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v8.4s, v16.4s, v25.4s + fmla v9.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v10.4s, v18.4s, v25.4s + fmla v11.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v12.4s, v20.4s, v25.4s + fmla v13.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v14.4s, v22.4s, v25.4s + fmla v15.4s, v23.4s, v25.4s + subs x18, x18, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x18, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + subs x18, x18, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x18, x7 + mov x22, x16 + LoopKw: + ld1 {v16.4s}, [x22], x13 + ld1 {v25.4s}, [x17], #16 + fmla v0.4s, v16.4s, v25.4s + subs x18, x18, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz x14, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + sub sp, sp, #48 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/ConvDwInt8Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/ConvDwInt8Center.S new file mode 100644 index 0000000000..0381b6bdb0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/ConvDwInt8Center.S @@ -0,0 +1,558 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global ConvDwInt8Center +#ifndef __APPLE__ +.type ConvDwInt8Center, %function +#endif + +// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift, +// int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: out_multiplier, #56: left_shift, #64: right_shift, #72:out_zp, #80: acc_min, #88: acc_max +ConvDwInt8Center: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + + ldr w14, [sp, #56] + dup v26.4s, w14 + + ldr x15, [sp, #48] + dup v27.4s, w15 + + ldr w16, [sp, #64] + dup v28.4s, w16 + + ldr w17, [sp, #72] + dup v29.4s, w17 + + ldr w18, [sp, #80] + dup v30.4s, w18 + + ldr w19, [sp, #88] + dup v31.4s, w19 + + ld1 {v24.4s}, [x3] + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x18, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.4h}, [x17], #8 + ld1 {v16.4h}, [x22], x13 + ld1 {v17.4h}, [x22], x13 + smlal v0.4s, v16.4h, v25.4h + smlal v1.4s, v17.4h, v25.4h + ld1 {v18.4h}, [x22], x13 + ld1 {v19.4h}, [x22], x13 + smlal v2.4s, v18.4h, v25.4h + smlal v3.4s, v19.4h, v25.4h + ld1 {v20.4h}, [x22], x13 + ld1 {v21.4h}, [x22], x13 + smlal v4.4s, v20.4h, v25.4h + smlal v5.4s, v21.4h, v25.4h + ld1 {v22.4h}, [x22], x13 + ld1 {v23.4h}, [x22], x13 + smlal v6.4s, v22.4h, v25.4h + smlal v7.4s, v23.4h, v25.4h + ld1 {v16.4h}, [x22], x13 + ld1 {v17.4h}, [x22], x13 + smlal v8.4s, v16.4h, v25.4h + smlal v9.4s, v17.4h, v25.4h + ld1 {v18.4h}, [x22], x13 + ld1 {v19.4h}, [x22], x13 + smlal v10.4s, v18.4h, v25.4h + smlal v11.4s, v19.4h, v25.4h + ld1 {v20.4h}, [x22], x13 + ld1 {v21.4h}, [x22], x13 + smlal v12.4s, v20.4h, v25.4h + smlal v13.4s, v21.4h, v25.4h + ld1 {v22.4h}, [x22], x13 + ld1 {v23.4h}, [x22], x13 + smlal v14.4s, v22.4h, v25.4h + smlal v15.4s, v23.4h, v25.4h + subs x18, x18, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + + sqshl v0.4s, v0.4s ,v26.4s + sqshl v1.4s, v1.4s ,v26.4s + sqshl v2.4s, v2.4s ,v26.4s + sqshl v3.4s, v3.4s ,v26.4s + sqshl v4.4s, v4.4s ,v26.4s + sqshl v5.4s, v5.4s ,v26.4s + sqshl v6.4s, v6.4s ,v26.4s + sqshl v7.4s, v7.4s ,v26.4s + sqshl v8.4s, v8.4s ,v26.4s + sqshl v9.4s, v9.4s ,v26.4s + sqshl v10.4s, v10.4s ,v26.4s + sqshl v11.4s, v11.4s ,v26.4s + sqshl v12.4s, v12.4s ,v26.4s + sqshl v13.4s, v13.4s ,v26.4s + sqshl v14.4s, v14.4s ,v26.4s + sqshl v15.4s, v15.4s ,v26.4s + sqrdmulh v0.4s, v0.4s ,v27.4s + sqrdmulh v1.4s, v1.4s ,v27.4s + sqrdmulh v2.4s, v2.4s ,v27.4s + sqrdmulh v3.4s, v3.4s ,v27.4s + sqrdmulh v4.4s, v4.4s ,v27.4s + sqrdmulh v5.4s, v5.4s ,v27.4s + sqrdmulh v6.4s, v6.4s ,v27.4s + sqrdmulh v7.4s, v7.4s ,v27.4s + sqrdmulh v8.4s, v8.4s ,v27.4s + sqrdmulh v9.4s, v9.4s ,v27.4s + sqrdmulh v10.4s, v10.4s ,v27.4s + sqrdmulh v11.4s, v11.4s ,v27.4s + sqrdmulh v12.4s, v12.4s ,v27.4s + sqrdmulh v13.4s, v13.4s ,v27.4s + sqrdmulh v14.4s, v14.4s ,v27.4s + sqrdmulh v15.4s, v15.4s ,v27.4s + sqrshl v0.4s, v0.4s ,v28.4s + sqrshl v1.4s, v1.4s ,v28.4s + sqrshl v2.4s, v2.4s ,v28.4s + sqrshl v3.4s, v3.4s ,v28.4s + sqrshl v4.4s, v4.4s ,v28.4s + sqrshl v5.4s, v5.4s ,v28.4s + sqrshl v6.4s, v6.4s ,v28.4s + sqrshl v7.4s, v7.4s ,v28.4s + sqrshl v8.4s, v8.4s ,v28.4s + sqrshl v9.4s, v9.4s ,v28.4s + sqrshl v10.4s, v10.4s ,v28.4s + sqrshl v11.4s, v11.4s ,v28.4s + sqrshl v12.4s, v12.4s ,v28.4s + sqrshl v13.4s, v13.4s ,v28.4s + sqrshl v14.4s, v14.4s ,v28.4s + sqrshl v15.4s, v15.4s ,v28.4s + add v0.4s, v0.4s ,v29.4s + add v1.4s, v1.4s ,v29.4s + add v2.4s, v2.4s ,v29.4s + add v3.4s, v3.4s ,v29.4s + add v4.4s, v4.4s ,v29.4s + add v5.4s, v5.4s ,v29.4s + add v6.4s, v6.4s ,v29.4s + add v7.4s, v7.4s ,v29.4s + add v8.4s, v8.4s ,v29.4s + add v9.4s, v9.4s ,v29.4s + add v10.4s, v10.4s ,v29.4s + add v11.4s, v11.4s ,v29.4s + add v12.4s, v12.4s ,v29.4s + add v13.4s, v13.4s ,v29.4s + add v14.4s, v14.4s ,v29.4s + add v15.4s, v15.4s ,v29.4s + smax v0.4s, v0.4s ,v30.4s + smax v1.4s, v1.4s ,v30.4s + smax v2.4s, v2.4s ,v30.4s + smax v3.4s, v3.4s ,v30.4s + smax v4.4s, v4.4s ,v30.4s + smax v5.4s, v5.4s ,v30.4s + smax v6.4s, v6.4s ,v30.4s + smax v7.4s, v7.4s ,v30.4s + smax v8.4s, v8.4s ,v30.4s + smax v9.4s, v9.4s ,v30.4s + smax v10.4s, v10.4s ,v30.4s + smax v11.4s, v11.4s ,v30.4s + smax v12.4s, v12.4s ,v30.4s + smax v13.4s, v13.4s ,v30.4s + smax v14.4s, v14.4s ,v30.4s + smax v15.4s, v15.4s ,v30.4s + smin v0.4s, v0.4s ,v31.4s + smin v1.4s, v1.4s ,v31.4s + smin v2.4s, v2.4s ,v31.4s + smin v3.4s, v3.4s ,v31.4s + smin v4.4s, v4.4s ,v31.4s + smin v5.4s, v5.4s ,v31.4s + smin v6.4s, v6.4s ,v31.4s + smin v7.4s, v7.4s ,v31.4s + smin v8.4s, v8.4s ,v31.4s + smin v9.4s, v9.4s ,v31.4s + smin v10.4s, v10.4s ,v31.4s + smin v11.4s, v11.4s ,v31.4s + smin v12.4s, v12.4s ,v31.4s + smin v13.4s, v13.4s ,v31.4s + smin v14.4s, v14.4s ,v31.4s + smin v15.4s, v15.4s ,v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + sqxtn v4.4h, v4.4s + sqxtn v5.4h, v5.4s + sqxtn v6.4h, v6.4s + sqxtn v7.4h, v7.4s + sqxtn v8.4h, v8.4s + sqxtn v9.4h, v9.4s + sqxtn v10.4h, v10.4s + sqxtn v11.4h, v11.4s + sqxtn v12.4h, v12.4s + sqxtn v13.4h, v13.4s + sqxtn v14.4h, v14.4s + sqxtn v15.4h, v15.4s + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + sqxtn v4.8b, v4.8h + sqxtn v5.8b, v5.8h + sqxtn v6.8b, v6.8h + sqxtn v7.8b, v7.8h + sqxtn v8.8b, v8.8h + sqxtn v9.8b, v9.8h + sqxtn v10.8b, v10.8h + sqxtn v11.8b, v11.8h + sqxtn v12.8b, v12.8h + sqxtn v13.8b, v13.8h + sqxtn v14.8b, v14.8h + sqxtn v15.8b, v15.8h + + add x17, x3, #1 + add x18, x3, #2 + add x21, x3, #3 + st1 {v0.b}[0], [x3], x9 + st1 {v0.b}[1], [x17], x9 + st1 {v0.b}[2], [x18], x9 + st1 {v0.b}[3], [x21], x9 + + st1 {v1.b}[0], [x3], x9 + st1 {v1.b}[1], [x17], x9 + st1 {v1.b}[2], [x18], x9 + st1 {v1.b}[3], [x21], x9 + + st1 {v2.b}[0], [x3], x9 + st1 {v2.b}[1], [x17], x9 + st1 {v2.b}[2], [x18], x9 + st1 {v2.b}[3], [x21], x9 + + st1 {v3.b}[0], [x3], x9 + st1 {v3.b}[1], [x17], x9 + st1 {v3.b}[2], [x18], x9 + st1 {v3.b}[3], [x21], x9 + + st1 {v4.b}[0], [x3], x9 + st1 {v4.b}[1], [x17], x9 + st1 {v4.b}[2], [x18], x9 + st1 {v4.b}[3], [x21], x9 + + st1 {v5.b}[0], [x3], x9 + st1 {v5.b}[1], [x17], x9 + st1 {v5.b}[2], [x18], x9 + st1 {v5.b}[3], [x21], x9 + + st1 {v6.b}[0], [x3], x9 + st1 {v6.b}[1], [x17], x9 + st1 {v6.b}[2], [x18], x9 + st1 {v6.b}[3], [x21], x9 + + st1 {v7.b}[0], [x3], x9 + st1 {v7.b}[1], [x17], x9 + st1 {v7.b}[2], [x18], x9 + st1 {v7.b}[3], [x21], x9 + + st1 {v8.b}[0], [x3], x9 + st1 {v8.b}[1], [x17], x9 + st1 {v8.b}[2], [x18], x9 + st1 {v8.b}[3], [x21], x9 + + st1 {v9.b}[0], [x3], x9 + st1 {v9.b}[1], [x17], x9 + st1 {v9.b}[2], [x18], x9 + st1 {v9.b}[3], [x21], x9 + + st1 {v10.b}[0], [x3], x9 + st1 {v10.b}[1], [x17], x9 + st1 {v10.b}[2], [x18], x9 + st1 {v10.b}[3], [x21], x9 + + st1 {v11.b}[0], [x3], x9 + st1 {v11.b}[1], [x17], x9 + st1 {v11.b}[2], [x18], x9 + st1 {v11.b}[3], [x21], x9 + + st1 {v12.b}[0], [x3], x9 + st1 {v12.b}[1], [x17], x9 + st1 {v12.b}[2], [x18], x9 + st1 {v12.b}[3], [x21], x9 + + st1 {v13.b}[0], [x3], x9 + st1 {v13.b}[1], [x17], x9 + st1 {v13.b}[2], [x18], x9 + st1 {v13.b}[3], [x21], x9 + + st1 {v14.b}[0], [x3], x9 + st1 {v14.b}[1], [x17], x9 + st1 {v14.b}[2], [x18], x9 + st1 {v14.b}[3], [x21], x9 + + st1 {v15.b}[0], [x3], x9 + st1 {v15.b}[1], [x17], x9 + st1 {v15.b}[2], [x18], x9 + st1 {v15.b}[3], [x21], x9 + + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x18, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.4h}, [x17], #8 + ld1 {v16.4h}, [x22], x13 + ld1 {v17.4h}, [x22], x13 + smlal v0.4s, v16.4h, v25.4h + smlal v1.4s, v17.4h, v25.4h + ld1 {v18.4h}, [x22], x13 + ld1 {v19.4h}, [x22], x13 + smlal v2.4s, v18.4h, v25.4h + smlal v3.4s, v19.4h, v25.4h + ld1 {v20.4h}, [x22], x13 + ld1 {v21.4h}, [x22], x13 + smlal v4.4s, v20.4h, v25.4h + smlal v5.4s, v21.4h, v25.4h + ld1 {v22.4h}, [x22], x13 + ld1 {v23.4h}, [x22], x13 + smlal v6.4s, v22.4h, v25.4h + smlal v7.4s, v23.4h, v25.4h + subs x18, x18, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + + sqshl v0.4s, v0.4s ,v26.4s + sqshl v1.4s, v1.4s ,v26.4s + sqshl v2.4s, v2.4s ,v26.4s + sqshl v3.4s, v3.4s ,v26.4s + sqshl v4.4s, v4.4s ,v26.4s + sqshl v5.4s, v5.4s ,v26.4s + sqshl v6.4s, v6.4s ,v26.4s + sqshl v7.4s, v7.4s ,v26.4s + sqrdmulh v0.4s, v0.4s ,v27.4s + sqrdmulh v1.4s, v1.4s ,v27.4s + sqrdmulh v2.4s, v2.4s ,v27.4s + sqrdmulh v3.4s, v3.4s ,v27.4s + sqrdmulh v4.4s, v4.4s ,v27.4s + sqrdmulh v5.4s, v5.4s ,v27.4s + sqrdmulh v6.4s, v6.4s ,v27.4s + sqrdmulh v7.4s, v7.4s ,v27.4s + sqrshl v0.4s, v0.4s ,v28.4s + sqrshl v1.4s, v1.4s ,v28.4s + sqrshl v2.4s, v2.4s ,v28.4s + sqrshl v3.4s, v3.4s ,v28.4s + sqrshl v4.4s, v4.4s ,v28.4s + sqrshl v5.4s, v5.4s ,v28.4s + sqrshl v6.4s, v6.4s ,v28.4s + sqrshl v7.4s, v7.4s ,v28.4s + add v0.4s, v0.4s ,v29.4s + add v1.4s, v1.4s ,v29.4s + add v2.4s, v2.4s ,v29.4s + add v3.4s, v3.4s ,v29.4s + add v4.4s, v4.4s ,v29.4s + add v5.4s, v5.4s ,v29.4s + add v6.4s, v6.4s ,v29.4s + add v7.4s, v7.4s ,v29.4s + smax v0.4s, v0.4s ,v30.4s + smax v1.4s, v1.4s ,v30.4s + smax v2.4s, v2.4s ,v30.4s + smax v3.4s, v3.4s ,v30.4s + smax v4.4s, v4.4s ,v30.4s + smax v5.4s, v5.4s ,v30.4s + smax v6.4s, v6.4s ,v30.4s + smax v7.4s, v7.4s ,v30.4s + smin v0.4s, v0.4s ,v31.4s + smin v1.4s, v1.4s ,v31.4s + smin v2.4s, v2.4s ,v31.4s + smin v3.4s, v3.4s ,v31.4s + smin v4.4s, v4.4s ,v31.4s + smin v5.4s, v5.4s ,v31.4s + smin v6.4s, v6.4s ,v31.4s + smin v7.4s, v7.4s ,v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + sqxtn v4.4h, v4.4s + sqxtn v5.4h, v5.4s + sqxtn v6.4h, v6.4s + sqxtn v7.4h, v7.4s + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + sqxtn v4.8b, v4.8h + sqxtn v5.8b, v5.8h + sqxtn v6.8b, v6.8h + sqxtn v7.8b, v7.8h + + add x17, x3, #1 + add x18, x3, #2 + add x21, x3, #3 + st1 {v0.b}[0], [x3], x9 + st1 {v0.b}[1], [x17], x9 + st1 {v0.b}[2], [x18], x9 + st1 {v0.b}[3], [x21], x9 + + st1 {v1.b}[0], [x3], x9 + st1 {v1.b}[1], [x17], x9 + st1 {v1.b}[2], [x18], x9 + st1 {v1.b}[3], [x21], x9 + + st1 {v2.b}[0], [x3], x9 + st1 {v2.b}[1], [x17], x9 + st1 {v2.b}[2], [x18], x9 + st1 {v2.b}[3], [x21], x9 + + st1 {v3.b}[0], [x3], x9 + st1 {v3.b}[1], [x17], x9 + st1 {v3.b}[2], [x18], x9 + st1 {v3.b}[3], [x21], x9 + + st1 {v4.b}[0], [x3], x9 + st1 {v4.b}[1], [x17], x9 + st1 {v4.b}[2], [x18], x9 + st1 {v4.b}[3], [x21], x9 + + st1 {v5.b}[0], [x3], x9 + st1 {v5.b}[1], [x17], x9 + st1 {v5.b}[2], [x18], x9 + st1 {v5.b}[3], [x21], x9 + + st1 {v6.b}[0], [x3], x9 + st1 {v6.b}[1], [x17], x9 + st1 {v6.b}[2], [x18], x9 + st1 {v6.b}[3], [x21], x9 + + st1 {v7.b}[0], [x3], x9 + st1 {v7.b}[1], [x17], x9 + st1 {v7.b}[2], [x18], x9 + st1 {v7.b}[3], [x21], x9 + + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x18, x7 + mov x22, x16 + LoopKw: + ld1 {v16.4h}, [x22], x13 + ld1 {v25.4h}, [x17], #8 + smlal v0.4s, v16.4h, v25.4h + subs x18, x18, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + + sqshl v0.4s, v0.4s ,v26.4s + sqrdmulh v0.4s, v0.4s ,v27.4s + sqrshl v0.4s, v0.4s ,v28.4s + add v0.4s, v0.4s ,v29.4s + smax v0.4s, v0.4s ,v30.4s + smin v0.4s, v0.4s ,v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + mov x17, x3 + st1 {v0.b}[0], [x17], #1 + st1 {v0.b}[1], [x17], #1 + st1 {v0.b}[2], [x17], #1 + st1 {v0.b}[3], [x17], #1 + add x3, x3, x9 + + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + sub sp, sp, #48 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/DeconvDwFp32Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/DeconvDwFp32Center.S new file mode 100644 index 0000000000..07cd1a5cea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/DeconvDwFp32Center.S @@ -0,0 +1,64 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global DeconvDwFp32Center +#ifndef __APPLE__ +.type DeconvDwFp32Center, %function +#endif + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: weight, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +DeconvDwFp32Center: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x18, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.4s}, [x16], x8 + LoopKh: + mov x21, x18 + mov x13, x6 + LoopKw: + ld1 {v0.4s}, [x21] + ld1 {v2.4s}, [x19], #16 + fmla v0.4s, v1.4s, v2.4s + st1 {v0.4s}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x18, x18, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + sub sp, sp, #32 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/DeconvDwInt8Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/DeconvDwInt8Center.S new file mode 100644 index 0000000000..25433d7a5f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/DeconvDwInt8Center.S @@ -0,0 +1,65 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global DeconvDwInt8Center +#ifndef __APPLE__ +.type DeconvDwInt8Center, %function +#endif + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: weight, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +DeconvDwInt8Center: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x18, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.4h}, [x16], x8 + LoopKh: + mov x21, x18 + mov x13, x6 + LoopKw: + ld1 {v0.4s}, [x21] + ld1 {v2.4h}, [x19], #8 + smlal v0.4s, v1.4h, v2.4h + st1 {v0.4s}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x18, x18, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + add x16, x16, x8 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + sub sp, sp, #32 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S new file mode 100644 index 0000000000..be649b0e58 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S @@ -0,0 +1,730 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmFp32_8x8 +#ifndef __APPLE__ +.type IndirectGemmFp32_8x8, %function +#endif + +// void IndirectGemmFp32_8x8(float *output, float *input, float *weight, float *bias, +// size_t kSize, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); +// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +// x8:mode, x9: writeMode, x10: relu, x11:relu6 +// mode = 0 for general convolution, where one conv unit is a row +// mode = 1 for winograd/common gemm, where the total channels of one input is a row +IndirectGemmFp32_8x8: + + .macro INIT_BIAS + dup v16.4s, wzr + dup v17.4s, wzr + cbz x3, InitBias + ld1 {v16.4s, v17.4s}, [x3] + InitBias: + mov v18.16b, v16.16b + mov v19.16b, v17.16b + mov v20.16b, v16.16b + mov v21.16b, v17.16b + mov v22.16b, v16.16b + mov v23.16b, v17.16b + mov v24.16b, v16.16b + mov v25.16b, v17.16b + mov v26.16b, v16.16b + mov v27.16b, v17.16b + mov v28.16b, v16.16b + mov v29.16b, v17.16b + mov v30.16b, v16.16b + mov v31.16b, v17.16b + .endm + + .macro INIT_BIAS_HALF + dup v16.4s, wzr + cbz x3, InitBiasHalf + ld1 {v16.4s}, [x3] + InitBiasHalf: + mov v18.16b, v16.16b + mov v20.16b, v16.16b + mov v22.16b, v16.16b + mov v24.16b, v16.16b + mov v26.16b, v16.16b + mov v28.16b, v16.16b + mov v30.16b, v16.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + + ldr x8, [sp, #0] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + + cbnz x8, NoStepShuffle + // step is one for common convolution, where ic8 should multiply by kernel size + // step is (a+b-1) for F(a,b) in winograd + mul x5, x4, x5 + mov x4, #1 + +NoStepShuffle: + // x8 is used to store offset now + // only useful for WriteC4 + mov x8, #16 + mul x8, x8, x4 + +IndirectGemmStart: + + cmp x6, #4 + ble LoopOcHalf + + LoopOc: + + mov x14, x4 + mov x12, x1 + + LoopKsize: + + mov x15, x0 + INIT_BIAS + + // load input for output 1-2 + ld1 {v0.4s, v1.4s}, [x12], #32 + // load weight + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + // step for output 1-2 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + fmla v19.4s, v9.4s, v1.s[0] + // load input for output 3-4 + ld1 {v2.4s, v3.4s}, [x12], #32 + // another step for output 1-2 + fmla v16.4s, v10.4s, v0.s[1] + fmla v17.4s, v11.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + fmla v19.4s, v11.4s, v1.s[1] + // load input for output 5-8 + // input cache should be refreshed after loading + // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 3-8 + fmla v20.4s, v8.4s, v2.s[0] + fmla v21.4s, v9.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + fmla v23.4s, v9.4s, v3.s[0] + + subs x13, x5, #1 + beq LoopIcEnd + + LoopIc: + fmla v24.4s, v8.4s, v4.s[0] + fmla v25.4s, v9.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v27.4s, v9.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v29.4s, v9.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + fmla v31.4s, v9.4s, v7.s[0] + // load weight + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v21.4s, v11.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + fmla v23.4s, v11.4s, v3.s[1] + fmla v24.4s, v10.4s, v4.s[1] + fmla v25.4s, v11.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v27.4s, v11.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v29.4s, v11.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + fmla v31.4s, v11.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v17.4s, v13.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v19.4s, v13.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v21.4s, v13.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v23.4s, v13.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v25.4s, v13.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v27.4s, v13.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v29.4s, v13.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + fmla v31.4s, v13.4s, v7.s[2] + // load weight + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v17.4s, v15.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + fmla v19.4s, v15.4s, v1.s[3] + fmla v20.4s, v14.4s, v2.s[3] + fmla v21.4s, v15.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + fmla v23.4s, v15.4s, v3.s[3] + // load input for output 1-4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + fmla v24.4s, v14.4s, v4.s[3] + fmla v25.4s, v15.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v27.4s, v15.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v29.4s, v15.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + fmla v31.4s, v15.4s, v7.s[3] + // load input for output 5-8 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 1-8 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + fmla v19.4s, v9.4s, v1.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v17.4s, v11.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + fmla v19.4s, v11.4s, v1.s[1] + fmla v20.4s, v8.4s, v2.s[0] + fmla v21.4s, v9.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + fmla v23.4s, v9.4s, v3.s[0] + + subs x13, x13, #1 + bne LoopIc + + LoopIcEnd: + fmla v24.4s, v8.4s, v4.s[0] + fmla v25.4s, v9.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v27.4s, v9.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v29.4s, v9.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + fmla v31.4s, v9.4s, v7.s[0] + // load weight + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v21.4s, v11.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + fmla v23.4s, v11.4s, v3.s[1] + fmla v24.4s, v10.4s, v4.s[1] + fmla v25.4s, v11.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v27.4s, v11.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v29.4s, v11.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + fmla v31.4s, v11.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v17.4s, v13.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v19.4s, v13.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v21.4s, v13.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v23.4s, v13.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v25.4s, v13.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v27.4s, v13.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v29.4s, v13.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + fmla v31.4s, v13.4s, v7.s[2] + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v17.4s, v15.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + fmla v19.4s, v15.4s, v1.s[3] + fmla v20.4s, v14.4s, v2.s[3] + fmla v21.4s, v15.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + fmla v23.4s, v15.4s, v3.s[3] + fmla v24.4s, v14.4s, v4.s[3] + fmla v25.4s, v15.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v27.4s, v15.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v29.4s, v15.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + fmla v31.4s, v15.4s, v7.s[3] + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + cbnz x11, Relu6 + cbnz x10, Relu + b WriteStart + Relu6: + movi v1.4s, #6 + scvtf v1.4s, v1.4s + fmin v16.4s, v16.4s ,v1.4s + fmin v17.4s, v17.4s ,v1.4s + fmin v18.4s, v18.4s ,v1.4s + fmin v19.4s, v19.4s ,v1.4s + fmin v20.4s, v20.4s ,v1.4s + fmin v21.4s, v21.4s ,v1.4s + fmin v22.4s, v22.4s ,v1.4s + fmin v23.4s, v23.4s ,v1.4s + fmin v24.4s, v24.4s ,v1.4s + fmin v25.4s, v25.4s ,v1.4s + fmin v26.4s, v26.4s ,v1.4s + fmin v27.4s, v27.4s ,v1.4s + fmin v28.4s, v28.4s ,v1.4s + fmin v29.4s, v29.4s ,v1.4s + fmin v30.4s, v30.4s ,v1.4s + fmin v31.4s, v31.4s ,v1.4s + Relu: + dup v0.4s, wzr + fmax v16.4s, v16.4s ,v0.4s + fmax v17.4s, v17.4s ,v0.4s + fmax v18.4s, v18.4s ,v0.4s + fmax v19.4s, v19.4s ,v0.4s + fmax v20.4s, v20.4s ,v0.4s + fmax v21.4s, v21.4s ,v0.4s + fmax v22.4s, v22.4s ,v0.4s + fmax v23.4s, v23.4s ,v0.4s + fmax v24.4s, v24.4s ,v0.4s + fmax v25.4s, v25.4s ,v0.4s + fmax v26.4s, v26.4s ,v0.4s + fmax v27.4s, v27.4s ,v0.4s + fmax v28.4s, v28.4s ,v0.4s + fmax v29.4s, v29.4s ,v0.4s + fmax v30.4s, v30.4s ,v0.4s + fmax v31.4s, v31.4s ,v0.4s + + WriteStart: + cbnz x9, WriteC4 + cmp x6, #5 + beq Write5 + cmp x6, #6 + beq Write6 + cmp x6, #7 + beq Write7 + b Write8 + Write5: + add x17, x15, #16 + st1 {v16.4s}, [x15], x7 + str s17, [x17] + add x17, x17, x7 + st1 {v18.4s}, [x15], x7 + str s19, [x17] + add x17, x17, x7 + st1 {v20.4s}, [x15], x7 + str s21, [x17] + add x17, x17, x7 + st1 {v22.4s}, [x15], x7 + str s23, [x17] + add x17, x17, x7 + st1 {v24.4s}, [x15], x7 + str s25, [x17] + add x17, x17, x7 + st1 {v26.4s}, [x15], x7 + str s27, [x17] + add x17, x17, x7 + st1 {v28.4s}, [x15], x7 + str s29, [x17] + add x17, x17, x7 + st1 {v30.4s}, [x15] + str s31, [x17] + add x0, x0, #20 + b WriteEnd + Write6: + add x17, x15, #16 + st1 {v16.4s}, [x15], x7 + dup s16, v17.s[1] + stp s17, s16, [x17] + add x17, x17, x7 + st1 {v18.4s}, [x15], x7 + dup s18, v19.s[1] + stp s19, s18, [x17] + add x17, x17, x7 + st1 {v20.4s}, [x15], x7 + dup s20, v21.s[1] + stp s21, s20, [x17] + add x17, x17, x7 + st1 {v22.4s}, [x15], x7 + dup s22, v23.s[1] + stp s23, s22, [x17] + add x17, x17, x7 + st1 {v24.4s}, [x15], x7 + dup s24, v25.s[1] + stp s25, s24, [x17] + add x17, x17, x7 + st1 {v26.4s}, [x15], x7 + dup s26, v27.s[1] + stp s27, s26, [x17] + add x17, x17, x7 + st1 {v28.4s}, [x15], x7 + dup s28, v29.s[1] + stp s29, s28, [x17] + add x17, x17, x7 + st1 {v30.4s}, [x15] + dup s30, v31.s[1] + stp s31, s30, [x17] + add x0, x0, #24 + b WriteEnd + Write7: + add x17, x15, #16 + add x16, x15, #24 + st1 {v16.4s}, [x15], x7 + dup s16, v17.s[1] + stp s17, s16, [x17] + add x17, x17, x7 + st1 {v17.s}[2], [x16], x7 + st1 {v18.4s}, [x15], x7 + dup s18, v19.s[1] + stp s19, s18, [x17] + add x17, x17, x7 + st1 {v19.s}[2], [x16], x7 + st1 {v20.4s}, [x15], x7 + dup s20, v21.s[1] + stp s21, s20, [x17] + add x17, x17, x7 + st1 {v21.s}[2], [x16], x7 + st1 {v22.4s}, [x15], x7 + dup s22, v23.s[1] + stp s23, s22, [x17] + add x17, x17, x7 + st1 {v23.s}[2], [x16], x7 + st1 {v24.4s}, [x15], x7 + dup s24, v25.s[1] + stp s25, s24, [x17] + add x17, x17, x7 + st1 {v25.s}[2], [x16], x7 + st1 {v26.4s}, [x15], x7 + dup s26, v27.s[1] + stp s27, s26, [x17] + add x17, x17, x7 + st1 {v27.s}[2], [x16], x7 + st1 {v28.4s}, [x15], x7 + dup s28, v29.s[1] + stp s29, s28, [x17] + add x17, x17, x7 + st1 {v29.s}[2], [x16], x7 + st1 {v30.4s}, [x15], x7 + dup s30, v31.s[1] + stp s31, s30, [x17] + add x17, x17, x7 + st1 {v31.s}[2], [x16], x7 + add x0, x0, #28 + b WriteEnd + WriteC4: + st1 {v16.4s}, [x15], x7 + st1 {v18.4s}, [x15], x7 + st1 {v20.4s}, [x15], x7 + st1 {v22.4s}, [x15], x7 + st1 {v24.4s}, [x15], x7 + st1 {v26.4s}, [x15], x7 + st1 {v28.4s}, [x15], x7 + st1 {v30.4s}, [x15] + add x15, x8, x0 + st1 {v17.4s}, [x15], x7 + st1 {v19.4s}, [x15], x7 + st1 {v21.4s}, [x15], x7 + st1 {v23.4s}, [x15], x7 + st1 {v25.4s}, [x15], x7 + st1 {v27.4s}, [x15], x7 + st1 {v29.4s}, [x15], x7 + st1 {v31.4s}, [x15] + add x0, x0, #16 + b WriteEnd + Write8: + st1 {v16.4s, v17.4s}, [x15], x7 + st1 {v18.4s, v19.4s}, [x15], x7 + st1 {v20.4s, v21.4s}, [x15], x7 + st1 {v22.4s, v23.4s}, [x15], x7 + st1 {v24.4s, v25.4s}, [x15], x7 + st1 {v26.4s, v27.4s}, [x15], x7 + st1 {v28.4s, v29.4s}, [x15], x7 + st1 {v30.4s, v31.4s}, [x15] + add x0, x0, #32 + + WriteEnd: + + subs x14, x14, #1 + bne LoopKsize + + subs x6, x6, #8 + ble LoopOcEnd + cbz x9, NoStepC4Block + add x0, x0, x8 + NoStepC4Block: + cbz x3, NoStepForward + add x3, x3, #32 + NoStepForward: + cmp x6, #4 + bgt LoopOc + + LoopOcHalf: + mov x18, #32 + + mov x14, x4 + mov x12, x1 + + LoopKsizeHalf: + + mov x15, x0 + INIT_BIAS_HALF + + // load input for output 1-2 + ld1 {v0.4s, v1.4s}, [x12], #32 + // load weight + ld1 {v8.4s}, [x2], x18 + ld1 {v10.4s}, [x2], x18 + // step for output 1-2 + fmla v16.4s, v8.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + // load input for output 3-4 + ld1 {v2.4s, v3.4s}, [x12], #32 + // another step for output 1-2 + fmla v16.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + // load input for output 5-8 + // input cache should be refreshed after loading + // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 3-8 + fmla v20.4s, v8.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + + subs x13, x5, #1 + beq LoopIcEndHalf + + LoopIcHalf: + fmla v24.4s, v8.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + // load weight + ld1 {v12.4s}, [x2], x18 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + // load weight + ld1 {v14.4s}, [x2], x18 + fmla v24.4s, v10.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + // load weight + ld1 {v8.4s}, [x2], x18 + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + // load weight + ld1 {v10.4s}, [x2], x18 + fmla v20.4s, v14.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + // load input for output 1-4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + fmla v24.4s, v14.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + // load input for output 5-8 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 1-8 + fmla v16.4s, v8.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + fmla v20.4s, v8.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + + subs x13, x13, #1 + bne LoopIcHalf + + LoopIcEndHalf: + fmla v24.4s, v8.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + // load weight + ld1 {v12.4s}, [x2], x18 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + // load weight + ld1 {v14.4s}, [x2], x18 + fmla v24.4s, v10.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + fmla v20.4s, v14.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + fmla v24.4s, v14.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + + cbnz x11, Relu6Half + cbnz x10, ReluHalf + b WriteStartHalf + Relu6Half: + movi v1.4s, #6 + scvtf v1.4s, v1.4s + fmin v16.4s, v16.4s ,v1.4s + fmin v18.4s, v18.4s ,v1.4s + fmin v20.4s, v20.4s ,v1.4s + fmin v22.4s, v22.4s ,v1.4s + fmin v24.4s, v24.4s ,v1.4s + fmin v26.4s, v26.4s ,v1.4s + fmin v28.4s, v28.4s ,v1.4s + fmin v30.4s, v30.4s ,v1.4s + ReluHalf: + dup v0.4s, wzr + fmax v16.4s, v16.4s ,v0.4s + fmax v18.4s, v18.4s ,v0.4s + fmax v20.4s, v20.4s ,v0.4s + fmax v22.4s, v22.4s ,v0.4s + fmax v24.4s, v24.4s ,v0.4s + fmax v26.4s, v26.4s ,v0.4s + fmax v28.4s, v28.4s ,v0.4s + fmax v30.4s, v30.4s ,v0.4s + + WriteStartHalf: + cbnz x9, Write4 + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + str s16, [x15] + add x15, x15, x7 + str s18, [x15] + add x15, x15, x7 + str s20, [x15] + add x15, x15, x7 + str s22, [x15] + add x15, x15, x7 + str s24, [x15] + add x15, x15, x7 + str s26, [x15] + add x15, x15, x7 + str s28, [x15] + add x15, x15, x7 + str s30, [x15] + add x0, x0, #4 + b WriteEnd + Write2: + dup s17, v16.s[1] + stp s16, s17, [x15] + add x15, x15, x7 + dup s19, v18.s[1] + stp s18, s19, [x15] + add x15, x15, x7 + dup s21, v20.s[1] + stp s20, s21, [x15] + add x15, x15, x7 + dup s23, v22.s[1] + stp s22, s23, [x15] + add x15, x15, x7 + dup s25, v24.s[1] + stp s24, s25, [x15] + add x15, x15, x7 + dup s27, v26.s[1] + stp s26, s27, [x15] + add x15, x15, x7 + dup s29, v28.s[1] + stp s28, s29, [x15] + add x15, x15, x7 + dup s31, v30.s[1] + stp s30, s31, [x15] + add x0, x0, #8 + b WriteEnd + Write3: + add x17, x15, #8 + dup s17, v16.s[1] + stp s16, s17, [x15] + add x15, x15, x7 + st1 {v16.s}[2], [x17], x7 + dup s19, v18.s[1] + stp s18, s19, [x15] + add x15, x15, x7 + st1 {v18.s}[2], [x17], x7 + dup s21, v20.s[1] + stp s20, s21, [x15] + add x15, x15, x7 + st1 {v20.s}[2], [x17], x7 + dup s23, v22.s[1] + stp s22, s23, [x15] + add x15, x15, x7 + st1 {v22.s}[2], [x17], x7 + dup s25, v24.s[1] + stp s24, s25, [x15] + add x15, x15, x7 + st1 {v24.s}[2], [x17], x7 + dup s27, v26.s[1] + stp s26, s27, [x15] + add x15, x15, x7 + st1 {v26.s}[2], [x17], x7 + dup s29, v28.s[1] + stp s28, s29, [x15] + add x15, x15, x7 + st1 {v28.s}[2], [x17], x7 + dup s31, v30.s[1] + stp s30, s31, [x15] + st1 {v30.s}[2], [x17] + add x0, x0, #12 + b WriteEndHalf + Write4: + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + st1 {v16.4s}, [x15], x7 + st1 {v18.4s}, [x15], x7 + st1 {v20.4s}, [x15], x7 + st1 {v22.4s}, [x15], x7 + st1 {v24.4s}, [x15], x7 + st1 {v26.4s}, [x15], x7 + st1 {v28.4s}, [x15], x7 + st1 {v30.4s}, [x15] + add x0, x0, #16 + + WriteEndHalf: + + subs x14, x14, #1 + bne LoopKsizeHalf + +LoopOcEnd: + + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt16to32_8x4.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt16to32_8x4.S new file mode 100644 index 0000000000..bfad61a362 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,221 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt16to32_8x4 +#ifndef __APPLE__ +.type IndirectGemmInt16to32_8x4, %function +#endif + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset); +// x0: output, x1: input, x2: weight, x3: ksize, x4: ic8, x5: oc4, x6: offset +IndirectGemmInt16to32_8x4: + + .macro INIT_ZERO + dup v28.4s, wzr + mov v29.16b, v28.16b + mov v30.16b, v28.16b + mov v31.16b, v28.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + LoopOc: + mov x7, x3 + mov x8, x1 + + LoopKsize: + mov x9, x0 + INIT_ZERO + + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + // load weight + ld1 {v16.8h}, [x2], #16 + smull v24.4s, v16.4h, v0.h[0] + smull v25.4s, v16.4h, v1.h[0] + // load weight + ld1 {v17.8h}, [x2], #16 + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smull v26.4s, v16.4h, v2.h[0] + smull v27.4s, v16.4h, v3.h[0] + + subs x10, x4, #1 + beq LoopIcEnd + + LoopIc: + + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + // load weight + ld1 {v16.8h, v17.8h}, [x2], #32 + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v16.4h, v0.h[0] + smlal v25.4s, v16.4h, v1.h[0] + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + smlal v26.4s, v16.4h, v2.h[0] + smlal v27.4s, v16.4h, v3.h[0] + + subs x10, x10, #1 + bne LoopIc + + LoopIcEnd: + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + st1 {v24.4s}, [x9], x6 + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + st1 {v25.4s}, [x9], x6 + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + st1 {v26.4s}, [x9], x6 + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + st1 {v27.4s}, [x9], x6 + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + st1 {v28.4s}, [x9], x6 + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + st1 {v29.4s}, [x9], x6 + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + + st1 {v30.4s}, [x9], x6 + st1 {v31.4s}, [x9] + + subs x7, x7, #1 + add x0, x0, #16 + bne LoopKsize + + subs x5, x5, #1 + bne LoopOc + + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S new file mode 100644 index 0000000000..f70495e0e2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S @@ -0,0 +1,326 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt8_4x4 +#ifndef __APPLE__ +.type IndirectGemmInt8_4x4, %function +#endif + +// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, +// size_t shift_before, size_t shift_after); +// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +IndirectGemmInt8_4x4: + + .macro INIT_BIAS + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + ldr x15, [sp] + ldr w8, [sp, #8] + ldr w9, [sp, #16] + ldr w16, [sp, #24] + ldr w17, [sp, #32] + ldr w18, [sp, #40] + ldr w19, [sp, #48] + + mul x5, x4, x5 + mov x4, #1 + + LoopOc: + + mov x10, x4 + mov x12, x1 + + LoopKsize: + INIT_BIAS + mov x11, x0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + // load weight + ld1 {v4.16b, v5.16b}, [x2], #32 + // step for output 1-4 + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v0.8b, v5.8b + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v0.16b, v5.16b + // load input for output 9-16 + ld1 {v6.16b, v7.16b}, [x2], #32 + // another step for output 5-8 + smull v12.8h, v1.8b, v4.8b + smull v13.8h, v1.8b, v5.8b + smlal2 v12.8h, v1.16b, v4.16b + smlal2 v13.8h, v1.16b, v5.16b + ld1 {v2.16b, v3.16b}, [x12], #32 + smull v10.8h, v0.8b, v6.8b + smull v11.8h, v0.8b, v7.8b + smlal2 v10.8h, v0.16b, v6.16b + smlal2 v11.8h, v0.16b, v7.16b + saddlp v16.4s, v8.8h + smull v14.8h, v1.8b, v6.8b + smull v15.8h, v1.8b, v7.8b + smlal2 v14.8h, v1.16b, v6.16b + smlal2 v15.8h, v1.16b, v7.16b + saddlp v17.4s, v9.8h + + subs x13, x5, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + sadalp v18.4s, v10.8h + smull v8.8h, v2.8b, v4.8b + smull v9.8h, v2.8b, v5.8b + sadalp v19.4s, v11.8h + smlal2 v8.8h, v2.16b, v4.16b + smlal2 v9.8h, v2.16b, v5.16b + sadalp v20.4s, v12.8h + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v2.8b, v7.8b + sadalp v21.4s, v13.8h + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v2.16b, v7.16b + sadalp v22.4s, v14.8h + smull v12.8h, v3.8b, v4.8b + smull v13.8h, v3.8b, v5.8b + sadalp v23.4s, v15.8h + smlal2 v12.8h, v3.16b, v4.16b + smlal2 v13.8h, v3.16b, v5.16b + sadalp v24.4s, v8.8h + ld1 {v4.16b, v5.16b}, [x2], #32 + smull v14.8h, v3.8b, v6.8b + smull v15.8h, v3.8b, v7.8b + sadalp v25.4s, v9.8h + smlal2 v14.8h, v3.16b, v6.16b + smlal2 v15.8h, v3.16b, v7.16b + sadalp v26.4s, v10.8h + ld1 {v6.16b, v7.16b}, [x2], #32 + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v0.8b, v5.8b + sadalp v27.4s, v11.8h + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v0.16b, v5.16b + sadalp v28.4s, v12.8h + ld1 {v2.16b, v3.16b}, [x12], #32 + smull v12.8h, v1.8b, v4.8b + smull v13.8h, v1.8b, v5.8b + sadalp v29.4s, v13.8h + smlal2 v12.8h, v1.16b, v4.16b + smlal2 v13.8h, v1.16b, v5.16b + sadalp v30.4s, v14.8h + smull v10.8h, v0.8b, v6.8b + smull v11.8h, v0.8b, v7.8b + sadalp v31.4s, v15.8h + smlal2 v10.8h, v0.16b, v6.16b + smlal2 v11.8h, v0.16b, v7.16b + sadalp v16.4s, v8.8h + smull v14.8h, v1.8b, v6.8b + smull v15.8h, v1.8b, v7.8b + sadalp v17.4s, v9.8h + smlal2 v14.8h, v1.16b, v6.16b + smlal2 v15.8h, v1.16b, v7.16b + + subs x13, x13, #1 + bne LoopIc + + LoopIcEnd: + sadalp v18.4s, v10.8h + smull v8.8h, v2.8b, v4.8b + smull v9.8h, v2.8b, v5.8b + sadalp v19.4s, v11.8h + smlal2 v8.8h, v2.16b, v4.16b + smlal2 v9.8h, v2.16b, v5.16b + sadalp v20.4s, v12.8h + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v2.8b, v7.8b + sadalp v21.4s, v13.8h + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v2.16b, v7.16b + sadalp v22.4s, v14.8h + smull v12.8h, v3.8b, v4.8b + smull v13.8h, v3.8b, v5.8b + sadalp v23.4s, v15.8h + smlal2 v12.8h, v3.16b, v4.16b + smlal2 v13.8h, v3.16b, v5.16b + sadalp v24.4s, v8.8h + smull v14.8h, v3.8b, v6.8b + smull v15.8h, v3.8b, v7.8b + sadalp v25.4s, v9.8h + smlal2 v14.8h, v3.16b, v6.16b + smlal2 v15.8h, v3.16b, v7.16b + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s ,v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + + // load sum + mov x20, x15 + ld1r {v8.4s}, [x20], #4 + ld1r {v9.4s}, [x20], #4 + ld1r {v10.4s}, [x20], #4 + ld1r {v11.4s}, [x20] + // pairwise add + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + cbz x3, NoReadBias + ld1 {v12.4s}, [x3] + NoReadBias: + addp v16.4s, v16.4s, v18.4s + addp v20.4s, v20.4s, v22.4s + addp v24.4s, v24.4s, v26.4s + addp v28.4s, v28.4s, v30.4s + sub v16.4s, v16.4s, v8.4s + sub v20.4s, v20.4s, v9.4s + sub v24.4s, v24.4s, v10.4s + sub v28.4s, v28.4s, v11.4s + add v16.4s, v16.4s, v12.4s + add v20.4s, v20.4s, v12.4s + add v24.4s, v24.4s, v12.4s + add v28.4s, v28.4s, v12.4s + + dup v2.4s, w18 + sqshl v16.4s, v16.4s ,v2.4s + sqshl v20.4s, v20.4s ,v2.4s + sqshl v24.4s, v24.4s ,v2.4s + sqshl v28.4s, v28.4s ,v2.4s + + dup v3.4s, w17 + sqrdmulh v16.4s, v16.4s ,v3.4s + sqrdmulh v20.4s, v20.4s ,v3.4s + sqrdmulh v24.4s, v24.4s ,v3.4s + sqrdmulh v28.4s, v28.4s ,v3.4s + + dup v4.4s, w19 + sqrshl v16.4s, v16.4s ,v4.4s + sqrshl v20.4s, v20.4s ,v4.4s + sqrshl v24.4s, v24.4s ,v4.4s + sqrshl v28.4s, v28.4s ,v4.4s + + dup v5.4s, w16 + add v16.4s, v16.4s ,v5.4s + add v20.4s, v20.4s ,v5.4s + add v24.4s, v24.4s ,v5.4s + add v28.4s, v28.4s ,v5.4s + + dup v0.4s, w8 + smax v16.4s, v16.4s ,v0.4s + smax v20.4s, v20.4s ,v0.4s + smax v24.4s, v24.4s ,v0.4s + smax v28.4s, v28.4s ,v0.4s + + dup v1.4s, w9 + smin v16.4s, v16.4s ,v1.4s + smin v20.4s, v20.4s ,v1.4s + smin v24.4s, v24.4s ,v1.4s + smin v28.4s, v28.4s ,v1.4s + + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v20.4s + sqxtn v15.8b, v13.8h + sqxtn v14.4h, v24.4s + sqxtn2 v14.8h, v28.4s + sqxtn2 v15.16b, v14.8h + + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + WriteStart: + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + st1 {v15.b}[0], [x11], x7 + st1 {v15.b}[4], [x11], x7 + st1 {v15.b}[8], [x11], x7 + st1 {v15.b}[12], [x11] + add x0, x0, #1 + b WriteEnd + Write2: + st1 {v15.h}[0], [x11], x7 + st1 {v15.h}[2], [x11], x7 + st1 {v15.h}[4], [x11], x7 + st1 {v15.h}[6], [x11] + add x0, x0, #2 + b WriteEnd + Write3: + add x14, x11, #2 + st1 {v15.h}[0], [x11], x7 + st1 {v15.b}[2], [x14], x7 + st1 {v15.h}[2], [x11], x7 + st1 {v15.b}[6], [x14], x7 + st1 {v15.h}[4], [x11], x7 + st1 {v15.b}[10], [x14], x7 + st1 {v15.h}[6], [x11] + st1 {v15.b}[14], [x14] + add x0, x0, #3 + b WriteEnd + Write4: + st1 {v15.s}[0], [x11], x7 + st1 {v15.s}[1], [x11], x7 + st1 {v15.s}[2], [x11], x7 + st1 {v15.s}[3], [x11] + add x0, x0, #4 + + WriteEnd: + + subs x10, x10, #1 + bne LoopKsize + + subs x6, x6, #4 + cbz x3, NoStepFowrard + add x3, x3, #16 + NoStepFowrard: + bgt LoopOc + + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add.S new file mode 100644 index 0000000000..181de0de72 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add.S @@ -0,0 +1,82 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global BiasAdd +#ifndef __APPLE__ + .type BiasAdd, %function +#endif + + + +//void BiasAdd(const float* bias, float* data, size_t oc4, size_t plan_size) + +//Auto: x0:bias, x1: data, x2:oc4,x3: plan_size, + +BiasAdd: +cmp x2, #0 +beq BiasAddEnd + +cmp x3, #0 +beq BiasAddEnd + +LoopOc4: +ld1 {v0.4s}, [x0], #16 +mov x6, x3 +mov x5, x1 + +Loop16LineIn: +cmp x6, #4 +blt L4 +sub x6, x6, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fadd v5.4s, v0.4s, v1.4s +fadd v6.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +cmp x6, #4 +blt Loop16LineOut + +Loop16: +st1 {v5.4s, v6.4s}, [x1], #32 +fadd v7.4s, v0.4s, v3.4s +fadd v8.4s, v0.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +st1 {v7.4s, v8.4s}, [x1], #32 +fadd v5.4s, v0.4s, v1.4s +fadd v6.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +sub x6, x6, #4 +cmp x6, #4 +bge Loop16 + +Loop16LineOut: +st1 {v5.4s, v6.4s}, [x1], #32 +fadd v7.4s, v0.4s, v3.4s +fadd v8.4s, v0.4s, v4.4s + +st1 {v7.4s, v8.4s}, [x1], #32 + +L4: +cmp x6, #0 +beq Loop16LineEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fadd v2.4s, v1.4s, v0.4s +subs x6, x6, #1 +st1 {v2.4s}, [x1], #16 +bne Loop4 + +Loop16LineEnd: +subs x2, x2, #1 +bne LoopOc4 + +BiasAddEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add_relu.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add_relu.S new file mode 100644 index 0000000000..f9e4eccc69 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add_relu.S @@ -0,0 +1,94 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global BiasAddRelu +#ifndef __APPLE__ + .type BiasAddRelu, %function +#endif + + +//void BiasAddRelu(const float* bias, float* data, size_t oc4, size_t plan_size) + +//Auto: x0:bias, x1: data, x2:oc4,x3: plan_size, + +BiasAddRelu: +cmp x2, #0 +beq BiasAddEnd + +cmp x3, #0 +beq BiasAddEnd + +dup v16.4s, wzr + +LoopOc4: +ld1 {v0.4s}, [x0], #16 +mov x6, x3 +mov x5, x1 + +Loop16LineIn: +cmp x6, #4 +blt L4 +sub x6, x6, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s + +cmp x6, #4 +blt Loop16LineOut + +Loop16: +st1 {v23.4s, v24.4s}, [x1], #32 +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s + +st1 {v27.4s, v28.4s}, [x1], #32 +ld1 {v3.4s, v4.4s}, [x5], #32 +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s +sub x6, x6, #4 +cmp x6, #4 +bge Loop16 + +Loop16LineOut: +st1 {v23.4s, v24.4s}, [x1], #32 +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s + +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +st1 {v27.4s, v28.4s}, [x1], #32 + +L4: +cmp x6, #0 +beq Loop16LineEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fadd v1.4s, v1.4s, v0.4s +fmax v1.4s, v1.4s, v16.4s + +subs x6, x6, #1 +st1 {v1.4s}, [x1], #16 +bne Loop4 + +Loop16LineEnd: +subs x2, x2, #1 +bne LoopOc4 + +BiasAddEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add_relu6.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add_relu6.S new file mode 100644 index 0000000000..77c563a812 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/bias_add_relu6.S @@ -0,0 +1,113 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global BiasAddRelu6 +#ifndef __APPLE__ + .type BiasAddRelu6, %function +#endif + + + +//void BiasAddRelu6(const float* bias, float* data, size_t oc4, size_t plan_size) + +//Auto: x0:bias, x1: data, x2:oc4,x3: plan_size, + +BiasAddRelu6: +cmp x2, #0 +beq BiasAddEnd + +cmp x3, #0 +beq BiasAddEnd + +dup v16.4s, wzr +movi v17.4s, #6 +scvtf v17.4s, v17.4s + +LoopOc4: +ld1 {v0.4s}, [x0], #16 +mov x6, x3 +mov x5, x1 + +Loop16LineIn: +cmp x6, #4 +blt L4 +sub x6, x6, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s + + + +cmp x6, #4 +blt Loop16LineOut + +Loop16: +fmin v23.4s, v23.4s, v17.4s +fmin v24.4s, v24.4s, v17.4s +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +st1 {v23.4s, v24.4s}, [x1], #32 +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s + +fmin v27.4s, v27.4s, v17.4s +fmin v28.4s, v28.4s, v17.4s +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +st1 {v27.4s, v28.4s}, [x1], #32 + + +sub x6, x6, #4 +cmp x6, #4 +bge Loop16 + +Loop16LineOut: +fmin v23.4s, v23.4s, v17.4s +fmin v24.4s, v24.4s, v17.4s +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s + +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +st1 {v23.4s, v24.4s}, [x1], #32 + +fmin v27.4s, v27.4s, v17.4s +fmin v28.4s, v28.4s, v17.4s + +st1 {v27.4s, v28.4s}, [x1], #32 + +L4: +cmp x6, #0 +beq Loop16LineEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fadd v1.4s, v1.4s, v0.4s +fmax v1.4s, v1.4s, v16.4s +fmin v1.4s, v1.4s, v17.4s + +subs x6, x6, #1 +st1 {v1.4s}, [x1], #16 +bne Loop4 + +Loop16LineEnd: +subs x2, x2, #1 +bne LoopOc4 + +BiasAddEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s new file mode 100644 index 0000000000..b33c71d34e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s @@ -0,0 +1,295 @@ +#ifdef __aarch64__ + .text + .align 5 + .global MatMulFloatNeon64 +#ifndef __APPLE__ + .type MatMulFloatNeon64, %function +#endif + +// A: LM [row_8 * depth] col_8_major +// B: RM [depth * col_8] row_8_major +// C: A*B [row_8 * col_8] col_8x8_major +// A * B -> [8 * depth] * [depth * 8] -> [8 * 4] * [4 * 8] or [8 * 1] * [1 * 8] +/////////////////////////////////////////////////////////////////////////////// +//CommLoopMul RM 1x8 block +// /-----------------------------------------\ +// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| +// \-----------------------------------------/ +// LM 8x1 block +// /---------------------\ /-----------------------------------------\ +// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]| +// | ... | | ... ... | +// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]| +// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]| +// | ... | | ... ... | +// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]| +// \---------------------/ \-----------------------------------------/ +// accumulators 8x8 block +// +/////////////////////////////////////////////////////////////////////////////// +//OptLoopMul4 RM 1x8 block +// /--------------------------------------------\ +// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | +// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]| +// |v12.s[0] ... v12.s[3] v13.s[0] ... v13.s[3]| +// |v14.s[0] ... v14.s[3] v15.s[0] ... v15.s[3]| +// \--------------------------------------------/ +// LM 8x4 block +// /---------------------------------\ /--------------------------------------------\ +// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] | +// | ... ... ... ... | | ... ... | +// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] | +// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] | +// | ... ... ... ... | | ... ... | +// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] | +// \---------------------------------/ \--------------------------------------------/ +// accumulators 8x8 block +///////////////////////////////////////////////////////////////////////////////// +// +// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// v0.s[0]: maxf +// v1.s[0]: minf +// w4: depth +// w5: row +// w6: col + +MatMulFloatNeon64: + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + + mov w7, v0.s[0] + mov w8, v1.s[0] + mov w9, 0 // rm col offset + mov w10, 0 // lm row offset + mov w18, #32 // sizeof(float)*8 + mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth + +L1: + cmp w9, w6 + beq End1 + + mov w10, 0 // reset lm row offset + mov x12, x0 // reload lm ptr + mov x14, x3 // reload bias ptr +L2: + cmp w10, w6 + beq End2 + + mov w13, w4 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + +OptLoopMul4: + cmp w13, #4 + blt CommLoopMul + + ld1 {v0.4s, v1.4s}, [x12], #32 + ld1 {v8.4s, v9.4s}, [x1], #32 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[0] + fmla v18.4s, v8.4s, v0.s[1] + fmla v19.4s, v9.4s, v0.s[1] + fmla v20.4s, v8.4s, v0.s[2] + fmla v21.4s, v9.4s, v0.s[2] + fmla v22.4s, v8.4s, v0.s[3] + fmla v23.4s, v9.4s, v0.s[3] + ld1 {v10.4s, v11.4s}, [x1], #32 + fmla v24.4s, v8.4s, v1.s[0] + fmla v25.4s, v9.4s, v1.s[0] + fmla v26.4s, v8.4s, v1.s[1] + fmla v27.4s, v9.4s, v1.s[1] + ld1 {v2.4s, v3.4s}, [x12], #32 + fmla v28.4s, v8.4s, v1.s[2] + fmla v29.4s, v9.4s, v1.s[2] + fmla v30.4s, v8.4s, v1.s[3] + fmla v31.4s, v9.4s, v1.s[3] + fmla v16.4s, v10.4s, v2.s[0] + fmla v17.4s, v11.4s, v2.s[0] + fmla v18.4s, v10.4s, v2.s[1] + fmla v19.4s, v11.4s, v2.s[1] + fmla v20.4s, v10.4s, v2.s[2] + fmla v21.4s, v11.4s, v2.s[2] + fmla v22.4s, v10.4s, v2.s[3] + fmla v23.4s, v11.4s, v2.s[3] + ld1 {v12.4s, v13.4s}, [x1], #32 + fmla v24.4s, v10.4s, v3.s[0] + fmla v25.4s, v11.4s, v3.s[0] + fmla v26.4s, v10.4s, v3.s[1] + fmla v27.4s, v11.4s, v3.s[1] + ld1 {v4.4s, v5.4s}, [x12], #32 + fmla v28.4s, v10.4s, v3.s[2] + fmla v29.4s, v11.4s, v3.s[2] + fmla v30.4s, v10.4s, v3.s[3] + fmla v31.4s, v11.4s, v3.s[3] + fmla v16.4s, v12.4s, v4.s[0] + fmla v17.4s, v13.4s, v4.s[0] + fmla v18.4s, v12.4s, v4.s[1] + fmla v19.4s, v13.4s, v4.s[1] + fmla v20.4s, v12.4s, v4.s[2] + fmla v21.4s, v13.4s, v4.s[2] + fmla v22.4s, v12.4s, v4.s[3] + fmla v23.4s, v13.4s, v4.s[3] + ld1 {v6.4s,v7.4s}, [x12], #32 + fmla v24.4s, v12.4s, v5.s[0] + fmla v25.4s, v13.4s, v5.s[0] + fmla v26.4s, v12.4s, v5.s[1] + fmla v27.4s, v13.4s, v5.s[1] + ld1 {v14.4s, v15.4s}, [x1], #32 + fmla v28.4s, v12.4s, v5.s[2] + fmla v29.4s, v13.4s, v5.s[2] + fmla v30.4s, v12.4s, v5.s[3] + fmla v31.4s, v13.4s, v5.s[3] + fmla v16.4s, v14.4s, v6.s[0] + fmla v17.4s, v15.4s, v6.s[0] + fmla v18.4s, v14.4s, v6.s[1] + fmla v19.4s, v15.4s, v6.s[1] + fmla v20.4s, v14.4s, v6.s[2] + fmla v21.4s, v15.4s, v6.s[2] + fmla v22.4s, v14.4s, v6.s[3] + fmla v23.4s, v15.4s, v6.s[3] + fmla v24.4s, v14.4s, v7.s[0] + fmla v25.4s, v15.4s, v7.s[0] + fmla v26.4s, v14.4s, v7.s[1] + fmla v27.4s, v15.4s, v7.s[1] + fmla v28.4s, v14.4s, v7.s[2] + fmla v29.4s, v15.4s, v7.s[2] + fmla v30.4s, v14.4s, v7.s[3] + fmla v31.4s, v15.4s, v7.s[3] + subs w13, w13, #4 + b OptLoopMul4 + +CommLoopMul: + cmp w13, #1 + blt Bias + + ld1 {v0.4s, v1.4s}, [x12], #32 + ld1 {v2.4s, v3.4s}, [x1], #32 + fmla v16.4s, v2.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[0] + fmla v18.4s, v2.4s, v0.s[1] + fmla v19.4s, v3.4s, v0.s[1] + fmla v20.4s, v2.4s, v0.s[2] + fmla v21.4s, v3.4s, v0.s[2] + fmla v22.4s, v2.4s, v0.s[3] + fmla v23.4s, v3.4s, v0.s[3] + fmla v24.4s, v2.4s, v1.s[0] + fmla v25.4s, v3.4s, v1.s[0] + fmla v26.4s, v2.4s, v1.s[1] + fmla v27.4s, v3.4s, v1.s[1] + fmla v28.4s, v2.4s, v1.s[2] + fmla v29.4s, v3.4s, v1.s[2] + fmla v30.4s, v2.4s, v1.s[3] + fmla v31.4s, v3.4s, v1.s[3] + subs w13, w13, #1 + b CommLoopMul + +Bias: + cmp x3, #0 + beq Relu + ld1 {v0.4s}, [x14], #16 + ld1 {v1.4s}, [x14], #16 + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + +Relu: + dup v15.4s, w7 + dup v14.4s, w8 + fmax v16.4s, v16.4s, v14.4s + fmax v17.4s, v17.4s, v14.4s + fmax v18.4s, v18.4s, v14.4s + fmax v19.4s, v19.4s, v14.4s + fmax v20.4s, v20.4s, v14.4s + fmax v21.4s, v21.4s, v14.4s + fmax v22.4s, v22.4s, v14.4s + fmax v23.4s, v23.4s, v14.4s + fmax v24.4s, v24.4s, v14.4s + fmax v25.4s, v25.4s, v14.4s + fmax v26.4s, v26.4s, v14.4s + fmax v27.4s, v27.4s, v14.4s + fmax v28.4s, v28.4s, v14.4s + fmax v29.4s, v29.4s, v14.4s + fmax v30.4s, v30.4s, v14.4s + fmax v31.4s, v31.4s, v14.4s + + fmin v16.4s, v16.4s, v15.4s + fmin v17.4s, v17.4s, v15.4s + fmin v18.4s, v18.4s, v15.4s + fmin v19.4s, v19.4s, v15.4s + fmin v20.4s, v20.4s, v15.4s + fmin v20.4s, v20.4s, v15.4s + fmin v21.4s, v21.4s, v15.4s + fmin v22.4s, v22.4s, v15.4s + fmin v23.4s, v23.4s, v15.4s + fmin v24.4s, v24.4s, v15.4s + fmin v25.4s, v25.4s, v15.4s + fmin v26.4s, v26.4s, v15.4s + fmin v27.4s, v27.4s, v15.4s + fmin v28.4s, v28.4s, v15.4s + fmin v29.4s, v29.4s, v15.4s + fmin v30.4s, v30.4s, v15.4s + fmin v31.4s, v31.4s, v15.4s + +TransToOut: + st1 {v16.4s}, [x2], #16 + st1 {v17.4s}, [x2], #16 + st1 {v18.4s}, [x2], #16 + st1 {v19.4s}, [x2], #16 + st1 {v20.4s}, [x2], #16 + st1 {v21.4s}, [x2], #16 + st1 {v22.4s}, [x2], #16 + st1 {v23.4s}, [x2], #16 + st1 {v24.4s}, [x2], #16 + st1 {v25.4s}, [x2], #16 + st1 {v26.4s}, [x2], #16 + st1 {v27.4s}, [x2], #16 + st1 {v28.4s}, [x2], #16 + st1 {v29.4s}, [x2], #16 + st1 {v30.4s}, [x2], #16 + st1 {v31.4s}, [x2], #16 + + add w10, w10, #8 // lhs row offset + 8 + b L2 + +End2: + add w9, w9, #8 // rhs col offset + 8 + b L1 + +End1: + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matrix_add.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matrix_add.S new file mode 100644 index 0000000000..d361190361 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matrix_add.S @@ -0,0 +1,103 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global MatrixAdd +#ifndef __APPLE__ + .type MatrixAdd, %function +#endif + + + +//void MatrixAdd(const float* matDataA, const float* matDataB, float* matDataC, +// size_t aStride, size_t bStride, size_t cStride, size_t width, size_t height) + +//Auto: x0: matDataA, x1:matDataB, x2:matDatac, +//x3:aStride, x4:bStride, x5:cStride, x6:width, x7:height + +MatrixAdd: +mov x12, #4 //sizeof(float) +mul x3, x12, x3 +mul x4, x12, x4 +mul x5, x12, x5 + +loopH: +mov x8, x0 +mov x9, x1 +mov x10, x2 + +mov x11, x6 + +loop16LineIn: +cmp x11, #4 +blt L8 +sub x11, x11, #4 +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +fadd v4.4s, v0.4s, v2.4s +fadd v5.4s, v1.4s, v3.4s + +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +cmp x11, #4 +blt loop16LineOut + +loop16: +st1 {v4.4s, v5.4s}, [x2], #32 +fadd v10.4s, v6.4s, v8.4s +fadd v11.4s, v7.4s, v9.4s +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +st1 {v10.4s, v11.4s}, [x2], #32 +fadd v4.4s, v0.4s, v2.4s +fadd v5.4s, v1.4s, v3.4s +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +sub x11, x11, #4 +cmp x11, #4 +bge loop16 + +loop16LineOut: +st1 {v4.4s, v5.4s}, [x2], #32 +fadd v10.4s, v6.4s, v8.4s +fadd v11.4s, v7.4s, v9.4s +st1 {v10.4s, v11.4s}, [x2], #32 + + +L8: +cmp x11, #2 +blt L4 +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 +fadd v4.4s, v0.4s, v2.4s +fadd v5.4s, v1.4s, v3.4s +sub x11, x11, #2 +st1 {v4.4s, v5.4s}, [x2], #32 + + +cmp x11, #0 +beq loop16EndLine + +L4: +ld1 {v0.4s}, [x0], #16 +ld1 {v1.4s}, [x1], #16 +fadd v0.4s, v0.4s, v1.4s +sub x11, x11, #1 +st1 {v0.4s}, [x2], #16 +//bne L4 + +loop16EndLine: +add x0, x8, x3 +add x1, x9, x4 +add x2, x10, x5 + +subs x7, x7, #1 +bne loopH + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matrix_sub.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matrix_sub.S new file mode 100644 index 0000000000..7ac5f56a39 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matrix_sub.S @@ -0,0 +1,105 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global MatrixSub +#ifndef __APPLE__ + .type MatrixSub, %function +#endif + + + +//void MatrixSub(const float* matDataA, const float* matDataB, float* matDataC, +// size_t aStride, size_t bStride, size_t cStride, size_t width, size_t height) + +//Auto: x0: matDataA, x1:matDataB, x2:matDatac, +//x3:aStride, x4:bStride, x5:cStride, x6:width, x7:height + +MatrixSub: +mov x12, #4 //sizeof(float) +mul x3, x12, x3 +mul x4, x12, x4 +mul x5, x12, x5 + +loopH: +mov x8, x0 +mov x9, x1 +mov x10, x2 + +mov x11, x6 + +loop16LineIn: +cmp x11, #4 +blt L8 +sub x11, x11, #4 +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +fsub v4.4s, v0.4s, v2.4s +fsub v5.4s, v1.4s, v3.4s + +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +cmp x11, #4 +blt loop16LineOut + +loop16: +st1 {v4.4s, v5.4s}, [x2], #32 +fsub v10.4s, v6.4s, v8.4s +fsub v11.4s, v7.4s, v9.4s +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +st1 {v10.4s, v11.4s}, [x2], #32 +fsub v4.4s, v0.4s, v2.4s +fsub v5.4s, v1.4s, v3.4s +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +sub x11, x11, #4 +cmp x11, #4 +bge loop16 + +loop16LineOut: +st1 {v4.4s, v5.4s}, [x2], #32 +fsub v10.4s, v6.4s, v8.4s +fsub v11.4s, v7.4s, v9.4s +st1 {v10.4s, v11.4s}, [x2], #32 + +L8: +cmp x11, #2 +blt L4 + +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +fsub v4.4s, v0.4s, v2.4s +fsub v5.4s, v1.4s, v3.4s + +sub x11, x11, #2 +st1 {v4.4s, v5.4s}, [x2], #32 + + +cmp x11, #0 +beq loop16EndLine + +L4: +ld1 {v0.4s}, [x0], #16 +ld1 {v1.4s}, [x1], #16 +fsub v0.4s, v0.4s, v1.4s +sub x11, x11, #1 +st1 {v0.4s}, [x2], #16 + + +loop16EndLine: +add x0, x8, x3 +add x1, x9, x4 +add x2, x10, x5 + +subs x7, x7, #1 +bne loopH + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/relu.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/relu.S new file mode 100644 index 0000000000..74c40a135b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/relu.S @@ -0,0 +1,73 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global Relu +#ifndef __APPLE__ + .type Relu, %function +#endif + + +//void Relu(float* data, size_t element4) + +//Auto: x0:data, x1: element4 + +Relu: +cmp x1, #0 +beq ReluEnd + +dup v16.4s, wzr + +mov x5, x0 + +Loop16LineIn: +cmp x1, #4 +blt L4 +sub x1, x1, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmax v5.4s, v16.4s, v1.4s +fmax v6.4s, v16.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +cmp x1, #4 +blt Loop16LineOut + +Loop16: +st1 {v5.4s, v6.4s}, [x0], #32 +fmax v7.4s, v16.4s, v3.4s +fmax v8.4s, v16.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +st1 {v7.4s, v8.4s}, [x0], #32 +fmax v5.4s, v16.4s, v1.4s +fmax v6.4s, v16.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +sub x1, x1, #4 +cmp x1, #4 +bge Loop16 + +Loop16LineOut: +st1 {v5.4s, v6.4s}, [x0], #32 +fmax v7.4s, v16.4s, v3.4s +fmax v8.4s, v16.4s, v4.4s + +st1 {v7.4s, v8.4s}, [x0], #32 + +L4: +cmp x1, #0 +beq ReluEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fmax v2.4s, v16.4s, v0.4s +subs x1, x1, #1 +st1 {v2.4s}, [x0], #16 +bne Loop4 + +ReluEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/relu6.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/relu6.S new file mode 100644 index 0000000000..c1789845ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/relu6.S @@ -0,0 +1,89 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global Relu6 +#ifndef __APPLE__ + .type Relu6, %function +#endif + + +//void Relu6(float* data, size_t element4) + +//Auto: x0:data, x1: element4 + +Relu6: +cmp x1, #0 +beq Relu6End + +dup v16.4s, wzr +movi v17.4s, #6 +scvtf v17.4s, v17.4s + +mov x5, x0 + +Loop16LineIn: +cmp x1, #4 +blt L4 +sub x1, x1, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmax v21.4s, v1.4s, v16.4s +fmax v22.4s, v2.4s, v16.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +fmin v23.4s, v21.4s, v17.4s +fmin v24.4s, v22.4s, v17.4s + + +cmp x1, #4 +blt Loop16LineOut + +Loop16: +st1 {v23.4s, v24.4s}, [x0], #32 +fmax v25.4s, v3.4s, v16.4s +fmax v26.4s, v4.4s, v16.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmin v27.4s, v25.4s, v17.4s +fmin v28.4s, v26.4s, v17.4s +fmax v21.4s, v1.4s, v16.4s +fmax v22.4s, v2.4s, v16.4s + +st1 {v27.4s, v28.4s}, [x0], #32 +ld1 {v3.4s, v4.4s}, [x5], #32 +fmin v23.4s, v21.4s, v17.4s +fmin v24.4s, v22.4s, v17.4s + +sub x1, x1, #4 +cmp x1, #4 +bge Loop16 + +Loop16LineOut: +st1 {v23.4s, v24.4s}, [x0], #32 +fmax v25.4s, v3.4s, v16.4s +fmax v26.4s, v4.4s, v16.4s + +fmin v27.4s, v25.4s, v17.4s +fmin v28.4s, v26.4s, v17.4s +st1 {v27.4s, v28.4s}, [x0], #32 + +L4: +cmp x1, #0 +beq Relu6End +Loop4: +ld1 {v1.4s}, [x5], #16 +fmax v1.4s, v1.4s, v16.4s + +fmin v1.4s, v1.4s, v17.4s + +subs x1, x1, #1 +st1 {v1.4s}, [x0], #16 +bne Loop4 + +Relu6End: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/ConvDwFp16Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/ConvDwFp16Center.S new file mode 100644 index 0000000000..6b27af6a6e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/ConvDwFp16Center.S @@ -0,0 +1,294 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global ConvDwFp16Center +#ifndef __APPLE__ +.type ConvDwFp16Center, %function +#endif + +// void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +ConvDwFp16Center: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr x14, [sp, #48] + ldr x15, [sp, #56] + + ld1 {v24.8h}, [x3] + movi v26.8h, #0x46, lsl #8 + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x18, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v8.8h, v16.8h, v25.8h + fmla v9.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v10.8h, v18.8h, v25.8h + fmla v11.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v12.8h, v20.8h, v25.8h + fmla v13.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v14.8h, v22.8h, v25.8h + fmla v15.8h, v23.8h, v25.8h + subs x18, x18, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + fmin v8.8h, v8.8h, v26.8h + fmin v9.8h, v9.8h, v26.8h + fmin v10.8h, v10.8h, v26.8h + fmin v11.8h, v11.8h, v26.8h + fmin v12.8h, v12.8h, v26.8h + fmin v13.8h, v13.8h, v26.8h + fmin v14.8h, v14.8h, v26.8h + fmin v15.8h, v15.8h, v26.8h + Relu16: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + fmax v8.8h, v8.8h, v27.8h + fmax v9.8h, v9.8h, v27.8h + fmax v10.8h, v10.8h, v27.8h + fmax v11.8h, v11.8h, v27.8h + fmax v12.8h, v12.8h, v27.8h + fmax v13.8h, v13.8h, v27.8h + fmax v14.8h, v14.8h, v27.8h + fmax v15.8h, v15.8h, v27.8h + Write16: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + st1 {v8.8h}, [x3], x9 + st1 {v9.8h}, [x3], x9 + st1 {v10.8h}, [x3], x9 + st1 {v11.8h}, [x3], x9 + st1 {v12.8h}, [x3], x9 + st1 {v13.8h}, [x3], x9 + st1 {v14.8h}, [x3], x9 + st1 {v15.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x18, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + subs x18, x18, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + Relu8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + Write8: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x18, x7 + mov x22, x16 + LoopKw: + ld1 {v16.8h}, [x22], x13 + ld1 {v25.8h}, [x17], #16 + fmla v0.8h, v16.8h, v25.8h + subs x18, x18, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz x14, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v26.8h + Relu: + fmax v0.8h, v0.8h, v27.8h + Write: + st1 {v0.8h}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + sub sp, sp, #48 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/DeconvDwFp16Center.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/DeconvDwFp16Center.S new file mode 100644 index 0000000000..1087856cb5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/DeconvDwFp16Center.S @@ -0,0 +1,64 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global DeconvDwFp16Center +#ifndef __APPLE__ +.type DeconvDwFp16Center, %function +#endif + +// void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: weight, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +DeconvDwFp16Center: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x18, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.8h}, [x16], x8 + LoopKh: + mov x21, x18 + mov x13, x6 + LoopKw: + ld1 {v0.8h}, [x21] + ld1 {v2.8h}, [x19], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x18, x18, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + sub sp, sp, #32 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmFp16_16x8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmFp16_16x8.S new file mode 100644 index 0000000000..3c50aa362c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmFp16_16x8.S @@ -0,0 +1,720 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmFp16_16x8 +#ifndef __APPLE__ +.type IndirectGemmFp16_16x8, %function +#endif + +// void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, +// size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); +// x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset, +// x8:mode, x9: writeC4, x10:relu, x11: relu6 +// compute 8 channel for 16 outputs +IndirectGemmFp16_16x8: + + .macro INIT_BIAS + dup v16.4s, wzr + cbz x3, InitBias + ld1 {v16.8h}, [x3] + InitBias: + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + mov v24.16b, v16.16b + mov v25.16b, v16.16b + mov v26.16b, v16.16b + mov v27.16b, v16.16b + mov v28.16b, v16.16b + mov v29.16b, v16.16b + mov v30.16b, v16.16b + mov v31.16b, v16.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + // performance between storing 4 registers at the same time and seperatly storing them on in-order cores + // is not tested yet + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + + ldr x8, [sp, #0] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + + cbnz x8, IndirectGemmStart + // step is one for common convolution, where ic8 should multiply by kernel size + // step is (a+b-1) for F(a,b) in winograd + mul x5, x4, x5 + mov x4, #1 + +IndirectGemmStart: + + LoopOc: + + mov x14, x4 + mov x12, x1 + + LoopKsize: + + mov x15, x0 + INIT_BIAS + // load input for output 1-8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + // load weight + ld1 {v8.8h, v9.8h}, [x2], #32 + // first 2 steps for output 1 and 3 + fmla v16.8h, v8.8h, v0.h[0] + fmla v18.8h, v8.8h, v1.h[0] + fmla v16.8h, v9.8h, v0.h[1] + fmla v18.8h, v9.8h, v1.h[1] + // load weight + ld1 {v10.8h, v11.8h}, [x2], #32 + // first 2 steps for output 2 and 4 + fmla v17.8h, v8.8h, v0.h[4] + fmla v19.8h, v8.8h, v1.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v19.8h, v9.8h, v1.h[5] + // load input for output 9-16 + // input cache should be refreshed after loading + // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + // last 2 steps for output 1 and 3 + fmla v16.8h, v10.8h, v0.h[2] + fmla v18.8h, v10.8h, v1.h[2] + fmla v16.8h, v11.8h, v0.h[3] + fmla v18.8h, v11.8h, v1.h[3] + + // check if ic4=1 + subs x13, x5, #1 + beq LoopIcEnd + + LoopIc: + // last 2 steps for output 2 and 4 + fmla v17.8h, v10.8h, v0.h[6] + fmla v19.8h, v10.8h, v1.h[6] + fmla v17.8h, v11.8h, v0.h[7] + fmla v19.8h, v11.8h, v1.h[7] + // steps for output 5-8 + fmla v20.8h, v8.8h, v2.h[0] + fmla v22.8h, v8.8h, v3.h[0] + fmla v20.8h, v9.8h, v2.h[1] + fmla v22.8h, v9.8h, v3.h[1] + fmla v21.8h, v8.8h, v2.h[4] + fmla v23.8h, v8.8h, v3.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v23.8h, v9.8h, v3.h[5] + fmla v20.8h, v10.8h, v2.h[2] + fmla v22.8h, v10.8h, v3.h[2] + fmla v20.8h, v11.8h, v2.h[3] + fmla v22.8h, v11.8h, v3.h[3] + fmla v21.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v3.h[6] + fmla v21.8h, v11.8h, v2.h[7] + fmla v23.8h, v11.8h, v3.h[7] + // load input for output 1-8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + // steps for output 9-12 + fmla v24.8h, v8.8h, v4.h[0] + fmla v26.8h, v8.8h, v5.h[0] + fmla v24.8h, v9.8h, v4.h[1] + fmla v26.8h, v9.8h, v5.h[1] + fmla v25.8h, v8.8h, v4.h[4] + fmla v27.8h, v8.8h, v5.h[4] + fmla v25.8h, v9.8h, v4.h[5] + fmla v27.8h, v9.8h, v5.h[5] + fmla v24.8h, v10.8h, v4.h[2] + fmla v26.8h, v10.8h, v5.h[2] + fmla v24.8h, v11.8h, v4.h[3] + fmla v26.8h, v11.8h, v5.h[3] + fmla v25.8h, v10.8h, v4.h[6] + fmla v27.8h, v10.8h, v5.h[6] + fmla v25.8h, v11.8h, v4.h[7] + fmla v27.8h, v11.8h, v5.h[7] + // steps for output 13-16 + fmla v28.8h, v8.8h, v6.h[0] + fmla v30.8h, v8.8h, v7.h[0] + fmla v28.8h, v9.8h, v6.h[1] + fmla v30.8h, v9.8h, v7.h[1] + fmla v29.8h, v8.8h, v6.h[4] + fmla v31.8h, v8.8h, v7.h[4] + fmla v29.8h, v9.8h, v6.h[5] + fmla v31.8h, v9.8h, v7.h[5] + // load weight + ld1 {v8.8h, v9.8h}, [x2], #32 + fmla v28.8h, v10.8h, v6.h[2] + fmla v30.8h, v10.8h, v7.h[2] + fmla v28.8h, v11.8h, v6.h[3] + fmla v30.8h, v11.8h, v7.h[3] + fmla v29.8h, v10.8h, v6.h[6] + fmla v31.8h, v10.8h, v7.h[6] + fmla v29.8h, v11.8h, v6.h[7] + fmla v31.8h, v11.8h, v7.h[7] + // load weight + ld1 {v10.8h, v11.8h}, [x2], #32 + // first 2 steps for output 1-4 + fmla v16.8h, v8.8h, v0.h[0] + fmla v18.8h, v8.8h, v1.h[0] + fmla v16.8h, v9.8h, v0.h[1] + fmla v18.8h, v9.8h, v1.h[1] + fmla v17.8h, v8.8h, v0.h[4] + fmla v19.8h, v8.8h, v1.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v19.8h, v9.8h, v1.h[5] + // load input for output 9-16 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + // last 2 steps for output 1 and 3 + fmla v16.8h, v10.8h, v0.h[2] + fmla v18.8h, v10.8h, v1.h[2] + fmla v16.8h, v11.8h, v0.h[3] + fmla v18.8h, v11.8h, v1.h[3] + + subs x13, x13, #1 + bne LoopIc + + LoopIcEnd: + fmla v17.8h, v10.8h, v0.h[6] + fmla v19.8h, v10.8h, v1.h[6] + fmla v17.8h, v11.8h, v0.h[7] + fmla v19.8h, v11.8h, v1.h[7] + // steps for output 5-8 + fmla v20.8h, v8.8h, v2.h[0] + fmla v22.8h, v8.8h, v3.h[0] + fmla v20.8h, v9.8h, v2.h[1] + fmla v22.8h, v9.8h, v3.h[1] + fmla v21.8h, v8.8h, v2.h[4] + fmla v23.8h, v8.8h, v3.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v23.8h, v9.8h, v3.h[5] + fmla v20.8h, v10.8h, v2.h[2] + fmla v22.8h, v10.8h, v3.h[2] + fmla v20.8h, v11.8h, v2.h[3] + fmla v22.8h, v11.8h, v3.h[3] + fmla v21.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v3.h[6] + fmla v21.8h, v11.8h, v2.h[7] + fmla v23.8h, v11.8h, v3.h[7] + // steps for output 9-12 + fmla v24.8h, v8.8h, v4.h[0] + fmla v26.8h, v8.8h, v5.h[0] + fmla v24.8h, v9.8h, v4.h[1] + fmla v26.8h, v9.8h, v5.h[1] + fmla v25.8h, v8.8h, v4.h[4] + fmla v27.8h, v8.8h, v5.h[4] + fmla v25.8h, v9.8h, v4.h[5] + fmla v27.8h, v9.8h, v5.h[5] + fmla v24.8h, v10.8h, v4.h[2] + fmla v26.8h, v10.8h, v5.h[2] + fmla v24.8h, v11.8h, v4.h[3] + fmla v26.8h, v11.8h, v5.h[3] + fmla v25.8h, v10.8h, v4.h[6] + fmla v27.8h, v10.8h, v5.h[6] + fmla v25.8h, v11.8h, v4.h[7] + fmla v27.8h, v11.8h, v5.h[7] + // steps for output 13-16 + fmla v28.8h, v8.8h, v6.h[0] + fmla v30.8h, v8.8h, v7.h[0] + fmla v28.8h, v9.8h, v6.h[1] + fmla v30.8h, v9.8h, v7.h[1] + fmla v29.8h, v8.8h, v6.h[4] + fmla v31.8h, v8.8h, v7.h[4] + fmla v29.8h, v9.8h, v6.h[5] + fmla v31.8h, v9.8h, v7.h[5] + fmla v28.8h, v10.8h, v6.h[2] + fmla v30.8h, v10.8h, v7.h[2] + fmla v28.8h, v11.8h, v6.h[3] + fmla v30.8h, v11.8h, v7.h[3] + fmla v29.8h, v10.8h, v6.h[6] + fmla v31.8h, v10.8h, v7.h[6] + fmla v29.8h, v11.8h, v6.h[7] + fmla v31.8h, v11.8h, v7.h[7] + + cbnz x11, Relu6 + cbnz x10, Relu + b WriteStart + Relu6: + movi v9.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v9.8h + fmin v17.8h, v17.8h, v9.8h + fmin v18.8h, v18.8h, v9.8h + fmin v19.8h, v19.8h, v9.8h + fmin v20.8h, v20.8h, v9.8h + fmin v21.8h, v21.8h, v9.8h + fmin v22.8h, v22.8h, v9.8h + fmin v23.8h, v23.8h, v9.8h + fmin v24.8h, v24.8h, v9.8h + fmin v25.8h, v25.8h, v9.8h + fmin v26.8h, v26.8h, v9.8h + fmin v27.8h, v27.8h, v9.8h + fmin v28.8h, v28.8h, v9.8h + fmin v29.8h, v29.8h, v9.8h + fmin v30.8h, v30.8h, v9.8h + fmin v31.8h, v31.8h, v9.8h + Relu: + dup v8.4s, wzr + fmax v16.8h, v16.8h, v8.8h + fmax v17.8h, v17.8h, v8.8h + fmax v18.8h, v18.8h, v8.8h + fmax v19.8h, v19.8h, v8.8h + fmax v20.8h, v20.8h, v8.8h + fmax v21.8h, v21.8h, v8.8h + fmax v22.8h, v22.8h, v8.8h + fmax v23.8h, v23.8h, v8.8h + fmax v24.8h, v24.8h, v8.8h + fmax v25.8h, v25.8h, v8.8h + fmax v26.8h, v26.8h, v8.8h + fmax v27.8h, v27.8h, v8.8h + fmax v28.8h, v28.8h, v8.8h + fmax v29.8h, v29.8h, v8.8h + fmax v30.8h, v30.8h, v8.8h + fmax v31.8h, v31.8h, v8.8h + + WriteStart: + cbnz x9, Write8 + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + cmp x6, #4 + beq Write4 + cmp x6, #5 + beq Write5 + cmp x6, #6 + beq Write6 + cmp x6, #7 + beq Write7 + b Write8 + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + Write1: + str h16, [x15] + add x15, x15, x7 + str h17, [x15] + add x15, x15, x7 + str h18, [x15] + add x15, x15, x7 + str h19, [x15] + add x15, x15, x7 + str h20, [x15] + add x15, x15, x7 + str h21, [x15] + add x15, x15, x7 + str h22, [x15] + add x15, x15, x7 + str h23, [x15] + add x15, x15, x7 + str h24, [x15] + add x15, x15, x7 + str h25, [x15] + add x15, x15, x7 + str h26, [x15] + add x15, x15, x7 + str h27, [x15] + add x15, x15, x7 + str h28, [x15] + add x15, x15, x7 + str h29, [x15] + add x15, x15, x7 + str h30, [x15] + add x15, x15, x7 + str h31, [x15] + add x0, x0, #2 + b WriteEnd + Write2: + str s16, [x15] + add x15, x15, x7 + str s17, [x15] + add x15, x15, x7 + str s18, [x15] + add x15, x15, x7 + str s19, [x15] + add x15, x15, x7 + str s20, [x15] + add x15, x15, x7 + str s21, [x15] + add x15, x15, x7 + str s22, [x15] + add x15, x15, x7 + str s23, [x15] + add x15, x15, x7 + str s24, [x15] + add x15, x15, x7 + str s25, [x15] + add x15, x15, x7 + str s26, [x15] + add x15, x15, x7 + str s27, [x15] + add x15, x15, x7 + str s28, [x15] + add x15, x15, x7 + str s29, [x15] + add x15, x15, x7 + str s30, [x15] + add x15, x15, x7 + str s31, [x15] + add x0, x0, #4 + b WriteEnd + Write3: + add x17, x15, #4 + str s16, [x15] + add x15, x15, x7 + st1 {v16.h}[2], [x17], x7 + str s17, [x15] + add x15, x15, x7 + st1 {v17.h}[2], [x17], x7 + str s18, [x15] + add x15, x15, x7 + st1 {v18.h}[2], [x17], x7 + str s19, [x15] + add x15, x15, x7 + st1 {v19.h}[2], [x17], x7 + str s20, [x15] + add x15, x15, x7 + st1 {v20.h}[2], [x17], x7 + str s21, [x15] + add x15, x15, x7 + st1 {v21.h}[2], [x17], x7 + str s22, [x15] + add x15, x15, x7 + st1 {v22.h}[2], [x17], x7 + str s23, [x15] + add x15, x15, x7 + st1 {v23.h}[2], [x17], x7 + str s24, [x15] + add x15, x15, x7 + st1 {v24.h}[2], [x17], x7 + str s25, [x15] + add x15, x15, x7 + st1 {v25.h}[2], [x17], x7 + str s26, [x15] + add x15, x15, x7 + st1 {v26.h}[2], [x17], x7 + str s27, [x15] + add x15, x15, x7 + st1 {v27.h}[2], [x17], x7 + str s28, [x15] + add x15, x15, x7 + st1 {v28.h}[2], [x17], x7 + str s29, [x15] + add x15, x15, x7 + st1 {v29.h}[2], [x17], x7 + str s30, [x15] + add x15, x15, x7 + st1 {v30.h}[2], [x17], x7 + str s31, [x15] + st1 {v31.h}[2], [x17] + add x0, x0, #6 + b WriteEnd + Write4: + str d16, [x15] + add x15, x15, x7 + str d17, [x15] + add x15, x15, x7 + str d18, [x15] + add x15, x15, x7 + str d19, [x15] + add x15, x15, x7 + str d20, [x15] + add x15, x15, x7 + str d21, [x15] + add x15, x15, x7 + str d22, [x15] + add x15, x15, x7 + str d23, [x15] + add x15, x15, x7 + str d24, [x15] + add x15, x15, x7 + str d25, [x15] + add x15, x15, x7 + str d26, [x15] + add x15, x15, x7 + str d27, [x15] + add x15, x15, x7 + str d28, [x15] + add x15, x15, x7 + str d29, [x15] + add x15, x15, x7 + str d30, [x15] + add x15, x15, x7 + str d31, [x15] + add x0, x0, #8 + b WriteEnd + Write5: + add x17, x15, #8 + str d16, [x15] + add x15, x15, x7 + st1 {v16.h}[4], [x17], x7 + str d17, [x15] + add x15, x15, x7 + st1 {v17.h}[4], [x17], x7 + str d18, [x15] + add x15, x15, x7 + st1 {v18.h}[4], [x17], x7 + str d19, [x15] + add x15, x15, x7 + st1 {v19.h}[4], [x17], x7 + str d20, [x15] + add x15, x15, x7 + st1 {v20.h}[4], [x17], x7 + str d21, [x15] + add x15, x15, x7 + st1 {v21.h}[4], [x17], x7 + str d22, [x15] + add x15, x15, x7 + st1 {v22.h}[4], [x17], x7 + str d23, [x15] + add x15, x15, x7 + st1 {v23.h}[4], [x17], x7 + str d24, [x15] + add x15, x15, x7 + st1 {v24.h}[4], [x17], x7 + str d25, [x15] + add x15, x15, x7 + st1 {v25.h}[4], [x17], x7 + str d26, [x15] + add x15, x15, x7 + st1 {v26.h}[4], [x17], x7 + str d27, [x15] + add x15, x15, x7 + st1 {v27.h}[4], [x17], x7 + str d28, [x15] + add x15, x15, x7 + st1 {v28.h}[4], [x17], x7 + str d29, [x15] + add x15, x15, x7 + st1 {v29.h}[4], [x17], x7 + str d30, [x15] + add x15, x15, x7 + st1 {v30.h}[4], [x17], x7 + str d31, [x15] + st1 {v31.h}[4], [x17] + add x0, x0, #10 + b WriteEnd + Write6: + add x17, x15, #8 + str d16, [x15] + add x15, x15, x7 + ins v0.s[0], v16.s[2] + str s0, [x17] + add x17, x17, x7 + str d17, [x15] + add x15, x15, x7 + ins v1.s[0], v17.s[2] + str s1, [x17] + add x17, x17, x7 + str d18, [x15] + add x15, x15, x7 + ins v2.s[0], v18.s[2] + str s2, [x17] + add x17, x17, x7 + str d19, [x15] + add x15, x15, x7 + ins v3.s[0], v19.s[2] + str s3, [x17] + add x17, x17, x7 + str d20, [x15] + add x15, x15, x7 + ins v4.s[0], v20.s[2] + str s4, [x17] + add x17, x17, x7 + str d21, [x15] + add x15, x15, x7 + ins v5.s[0], v21.s[2] + str s5, [x17] + add x17, x17, x7 + str d22, [x15] + add x15, x15, x7 + ins v6.s[0], v22.s[2] + str s6, [x17] + add x17, x17, x7 + str d23, [x15] + add x15, x15, x7 + ins v7.s[0], v23.s[2] + str s7, [x17] + add x17, x17, x7 + str d24, [x15] + add x15, x15, x7 + ins v8.s[0], v24.s[2] + str s8, [x17] + add x17, x17, x7 + str d25, [x15] + add x15, x15, x7 + ins v9.s[0], v25.s[2] + str s9, [x17] + add x17, x17, x7 + str d26, [x15] + add x15, x15, x7 + ins v10.s[0], v26.s[2] + str s10, [x17] + add x17, x17, x7 + str d27, [x15] + add x15, x15, x7 + ins v11.s[0], v27.s[2] + str s11, [x17] + add x17, x17, x7 + str d28, [x15] + add x15, x15, x7 + ins v12.s[0], v28.s[2] + str s12, [x17] + add x17, x17, x7 + str d29, [x15] + add x15, x15, x7 + ins v13.s[0], v29.s[2] + str s13, [x17] + add x17, x17, x7 + str d30, [x15] + add x15, x15, x7 + ins v14.s[0], v30.s[2] + str s14, [x17] + add x17, x17, x7 + str d31, [x15] + ins v15.s[0], v31.s[2] + str s15, [x17] + add x0, x0, #12 + b WriteEnd + Write7: + add x17, x15, #8 + add x16, x15, #12 + str d16, [x15] + add x15, x15, x7 + ins v0.s[0], v16.s[2] + str s0, [x17] + add x17, x17, x7 + st1 {v16.h}[6], [x16], x7 + str d17, [x15] + add x15, x15, x7 + ins v1.s[0], v17.s[2] + str s1, [x17] + add x17, x17, x7 + st1 {v17.h}[6], [x16], x7 + str d18, [x15] + add x15, x15, x7 + ins v2.s[0], v18.s[2] + str s2, [x17] + add x17, x17, x7 + st1 {v18.h}[6], [x16], x7 + str d19, [x15] + add x15, x15, x7 + ins v3.s[0], v19.s[2] + str s3, [x17] + add x17, x17, x7 + st1 {v19.h}[6], [x16], x7 + str d20, [x15] + add x15, x15, x7 + ins v4.s[0], v20.s[2] + str s4, [x17] + add x17, x17, x7 + st1 {v20.h}[6], [x16], x7 + str d21, [x15] + add x15, x15, x7 + ins v5.s[0], v21.s[2] + str s5, [x17] + add x17, x17, x7 + st1 {v21.h}[6], [x16], x7 + str d22, [x15] + add x15, x15, x7 + ins v6.s[0], v22.s[2] + str s6, [x17] + add x17, x17, x7 + st1 {v22.h}[6], [x16], x7 + str d23, [x15] + add x15, x15, x7 + ins v7.s[0], v23.s[2] + str s7, [x17] + add x17, x17, x7 + st1 {v23.h}[6], [x16], x7 + str d24, [x15] + add x15, x15, x7 + ins v8.s[0], v24.s[2] + str s8, [x17] + add x17, x17, x7 + st1 {v24.h}[6], [x16], x7 + str d25, [x15] + add x15, x15, x7 + ins v9.s[0], v25.s[2] + str s9, [x17] + add x17, x17, x7 + st1 {v25.h}[6], [x16], x7 + str d26, [x15] + add x15, x15, x7 + ins v10.s[0], v26.s[2] + str s10, [x17] + add x17, x17, x7 + st1 {v26.h}[6], [x16], x7 + str d27, [x15] + add x15, x15, x7 + ins v11.s[0], v27.s[2] + str s11, [x17] + add x17, x17, x7 + st1 {v27.h}[6], [x16], x7 + str d28, [x15] + add x15, x15, x7 + ins v12.s[0], v28.s[2] + str s12, [x17] + add x17, x17, x7 + st1 {v28.h}[6], [x16], x7 + str d29, [x15] + add x15, x15, x7 + ins v13.s[0], v29.s[2] + str s13, [x17] + add x17, x17, x7 + st1 {v29.h}[6], [x16], x7 + str d30, [x15] + add x15, x15, x7 + ins v14.s[0], v30.s[2] + str s14, [x17] + add x17, x17, x7 + st1 {v30.h}[6], [x16], x7 + str d31, [x15] + ins v15.s[0], v31.s[2] + str s15, [x17] + st1 {v31.h}[6], [x16] + add x0, x0, #14 + b WriteEnd + Write8: + st1 {v16.8h}, [x15], x7 + st1 {v17.8h}, [x15], x7 + st1 {v18.8h}, [x15], x7 + st1 {v19.8h}, [x15], x7 + st1 {v20.8h}, [x15], x7 + st1 {v21.8h}, [x15], x7 + st1 {v22.8h}, [x15], x7 + st1 {v23.8h}, [x15], x7 + st1 {v24.8h}, [x15], x7 + st1 {v25.8h}, [x15], x7 + st1 {v26.8h}, [x15], x7 + st1 {v27.8h}, [x15], x7 + st1 {v28.8h}, [x15], x7 + st1 {v29.8h}, [x15], x7 + st1 {v30.8h}, [x15], x7 + st1 {v31.8h}, [x15] + add x0, x0, #16 + + WriteEnd: + subs x14, x14, #1 + bne LoopKsize + + subs x6, x6, #8 + cbz x3, NoStepForward + add x3, x3, #16 + NoStepForward: + bgt LoopOc + + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S new file mode 100644 index 0000000000..278b4376b2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S @@ -0,0 +1,636 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt8_24x4_dp +#ifndef __APPLE__ +.type IndirectGemmInt8_24x4_dp, %function +#endif + +// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, +// size_t shift_before, size_t shift_after); +// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +// we use sdot intrinsic on cores that supports dotprod(Armv8.2-A w/dp or later) +// mrs intrinsic could read system register ID_AA64ISAR0_EL1(or s3_0_c0_c6_0 on Armv8.2-A) +// the 44-48 bits indicates whether dotprod is supported +IndirectGemmInt8_24x4_dp: + + .macro INIT_BIAS + mov x20, x15 + ld1r {v8.4s}, [x20], #4 + ld1r {v9.4s}, [x20], #4 + ld1r {v10.4s}, [x20], #4 + ld1r {v11.4s}, [x20], #4 + ld1r {v12.4s}, [x20], #4 + ld1r {v13.4s}, [x20], #4 + ld1r {v14.4s}, [x20], #4 + ld1r {v15.4s}, [x20], #4 + ld1r {v16.4s}, [x20], #4 + ld1r {v17.4s}, [x20], #4 + ld1r {v18.4s}, [x20], #4 + ld1r {v19.4s}, [x20], #4 + ld1r {v20.4s}, [x20], #4 + ld1r {v21.4s}, [x20], #4 + ld1r {v22.4s}, [x20], #4 + ld1r {v23.4s}, [x20], #4 + ld1r {v24.4s}, [x20], #4 + ld1r {v25.4s}, [x20], #4 + ld1r {v26.4s}, [x20], #4 + ld1r {v27.4s}, [x20], #4 + ld1r {v28.4s}, [x20], #4 + ld1r {v29.4s}, [x20], #4 + ld1r {v30.4s}, [x20], #4 + ld1r {v31.4s}, [x20], #4 + dup v7.4s, wzr + cbz x3, InitBias + ld1 {v7.4s}, [x3] + InitBias: + sub v8.4s, v7.4s, v8.4s + sub v9.4s, v7.4s, v9.4s + sub v10.4s, v7.4s, v10.4s + sub v11.4s, v7.4s, v11.4s + sub v12.4s, v7.4s, v12.4s + sub v13.4s, v7.4s, v13.4s + sub v14.4s, v7.4s, v14.4s + sub v15.4s, v7.4s, v15.4s + sub v16.4s, v7.4s, v16.4s + sub v17.4s, v7.4s, v17.4s + sub v18.4s, v7.4s, v18.4s + sub v19.4s, v7.4s, v19.4s + sub v20.4s, v7.4s, v20.4s + sub v21.4s, v7.4s, v21.4s + sub v22.4s, v7.4s, v22.4s + sub v23.4s, v7.4s, v23.4s + sub v24.4s, v7.4s, v24.4s + sub v25.4s, v7.4s, v25.4s + sub v26.4s, v7.4s, v26.4s + sub v27.4s, v7.4s, v27.4s + sub v28.4s, v7.4s, v28.4s + sub v29.4s, v7.4s, v29.4s + sub v30.4s, v7.4s, v30.4s + sub v31.4s, v7.4s, v31.4s + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + ldr x15, [sp] + ldr w8, [sp, #8] + ldr w9, [sp, #16] + ldr w16, [sp, #24] + ldr w17, [sp, #32] + ldr w18, [sp, #40] + ldr w19, [sp, #48] + + mul x5, x4, x5 + mov x4, #1 + + LoopOc: + + mov x10, x4 + mov x12, x1 + + LoopKsize: + INIT_BIAS + mov x11, x0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + // load weight + ld1 {v6.16b}, [x2], #16 + // step for output 1-4 + .inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0] + .inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1] + .inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2] + .inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3] + // load input for output 9-16 + ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x12], #64 + // another step for output 5-8 + .inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0] + .inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1] + .inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2] + .inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3] + + subs x13, x5, #1 + beq LoopIcEndOne + // load weight + ld1 {v7.16b}, [x2], #16 + cmp x13, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] + .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] + .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] + .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] + .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] + .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] + .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] + .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] + ld1 {v2.16b, v3.16b}, [x12], #32 + .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] + .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] + .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] + .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] + .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] + .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] + .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] + .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] + // load input for output 9-16 + ld1 {v4.4s, v5.4s}, [x12], #32 + .inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0] + .inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1] + .inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2] + .inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3] + // another step for output 5-8 + .inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0] + .inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1] + .inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2] + .inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3] + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + .inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0] + .inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1] + .inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2] + .inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3] + .inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0] + .inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1] + .inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2] + .inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3] + // load weight + ld1 {v6.16b}, [x2], #16 + .inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0] + .inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1] + .inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2] + .inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3] + .inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0] + .inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1] + .inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2] + .inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3] + // load input for output 9-16 + ld1 {v2.4s, v3.4s}, [x12], #32 + .inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0] + .inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1] + .inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2] + .inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3] + // another step for output 5-8 + .inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0] + .inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1] + .inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2] + .inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3] + // load input for output 9-16 + ld1 {v4.4s, v5.4s}, [x12], #32 + + subs x13, x13, #2 + beq LoopIcEndOne + // load weight + ld1 {v7.16b}, [x2], #16 + cmp x13, #1 + beq LoopIcEnd + b LoopIc + + LoopIcEnd: + mov x20, x15 + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] + .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] + .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] + .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] + .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] + .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] + .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] + .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] + ld1 {v2.16b, v3.16b}, [x12], #32 + .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] + .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] + .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] + .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] + .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] + .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] + .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] + .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] + // load input for output 9-16 + ld1 {v4.4s, v5.4s}, [x12], #32 + .inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0] + .inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1] + .inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2] + .inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3] + .inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0] + .inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1] + .inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2] + .inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3] + + .inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0] + .inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1] + .inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2] + .inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3] + .inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0] + .inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1] + .inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2] + .inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3] + + .inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0] + .inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1] + .inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2] + .inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3] + .inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0] + .inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1] + .inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2] + .inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3] + b Quantization + + LoopIcEndOne: + .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] + .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] + .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] + .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] + .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] + .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] + .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] + .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] + + .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] + .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] + .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] + .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] + .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] + .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] + .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] + .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] + + Quantization: + dup v2.4s, w18 + sqshl v8.4s, v8.4s ,v2.4s + sqshl v9.4s, v9.4s ,v2.4s + sqshl v10.4s, v10.4s ,v2.4s + sqshl v11.4s, v11.4s ,v2.4s + sqshl v12.4s, v12.4s ,v2.4s + sqshl v13.4s, v13.4s ,v2.4s + sqshl v14.4s, v14.4s ,v2.4s + sqshl v15.4s, v15.4s ,v2.4s + sqshl v16.4s, v16.4s ,v2.4s + sqshl v17.4s, v17.4s ,v2.4s + sqshl v18.4s, v18.4s ,v2.4s + sqshl v19.4s, v19.4s ,v2.4s + sqshl v20.4s, v20.4s ,v2.4s + sqshl v21.4s, v21.4s ,v2.4s + sqshl v22.4s, v22.4s ,v2.4s + sqshl v23.4s, v23.4s ,v2.4s + sqshl v24.4s, v24.4s ,v2.4s + sqshl v25.4s, v25.4s ,v2.4s + sqshl v26.4s, v26.4s ,v2.4s + sqshl v27.4s, v27.4s ,v2.4s + sqshl v28.4s, v28.4s ,v2.4s + sqshl v29.4s, v29.4s ,v2.4s + sqshl v30.4s, v30.4s ,v2.4s + sqshl v31.4s, v31.4s ,v2.4s + + dup v3.4s, w17 + sqrdmulh v8.4s, v8.4s ,v3.4s + sqrdmulh v9.4s, v9.4s ,v3.4s + sqrdmulh v10.4s, v10.4s ,v3.4s + sqrdmulh v11.4s, v11.4s ,v3.4s + sqrdmulh v12.4s, v12.4s ,v3.4s + sqrdmulh v13.4s, v13.4s ,v3.4s + sqrdmulh v14.4s, v14.4s ,v3.4s + sqrdmulh v15.4s, v15.4s ,v3.4s + sqrdmulh v16.4s, v16.4s ,v3.4s + sqrdmulh v17.4s, v17.4s ,v3.4s + sqrdmulh v18.4s, v18.4s ,v3.4s + sqrdmulh v19.4s, v19.4s ,v3.4s + sqrdmulh v20.4s, v20.4s ,v3.4s + sqrdmulh v21.4s, v21.4s ,v3.4s + sqrdmulh v22.4s, v22.4s ,v3.4s + sqrdmulh v23.4s, v23.4s ,v3.4s + sqrdmulh v24.4s, v24.4s ,v3.4s + sqrdmulh v25.4s, v25.4s ,v3.4s + sqrdmulh v26.4s, v26.4s ,v3.4s + sqrdmulh v27.4s, v27.4s ,v3.4s + sqrdmulh v28.4s, v28.4s ,v3.4s + sqrdmulh v29.4s, v29.4s ,v3.4s + sqrdmulh v30.4s, v30.4s ,v3.4s + sqrdmulh v31.4s, v31.4s ,v3.4s + + dup v4.4s, w19 + sqrshl v8.4s, v8.4s ,v4.4s + sqrshl v9.4s, v9.4s ,v4.4s + sqrshl v10.4s, v10.4s ,v4.4s + sqrshl v11.4s, v11.4s ,v4.4s + sqrshl v12.4s, v12.4s ,v4.4s + sqrshl v13.4s, v13.4s ,v4.4s + sqrshl v14.4s, v14.4s ,v4.4s + sqrshl v15.4s, v15.4s ,v4.4s + sqrshl v16.4s, v16.4s ,v4.4s + sqrshl v17.4s, v17.4s ,v4.4s + sqrshl v18.4s, v18.4s ,v4.4s + sqrshl v19.4s, v19.4s ,v4.4s + sqrshl v20.4s, v20.4s ,v4.4s + sqrshl v21.4s, v21.4s ,v4.4s + sqrshl v22.4s, v22.4s ,v4.4s + sqrshl v23.4s, v23.4s ,v4.4s + sqrshl v24.4s, v24.4s ,v4.4s + sqrshl v25.4s, v25.4s ,v4.4s + sqrshl v26.4s, v26.4s ,v4.4s + sqrshl v27.4s, v27.4s ,v4.4s + sqrshl v28.4s, v28.4s ,v4.4s + sqrshl v29.4s, v29.4s ,v4.4s + sqrshl v30.4s, v30.4s ,v4.4s + sqrshl v31.4s, v31.4s ,v4.4s + + dup v5.4s, w16 + add v8.4s, v8.4s ,v5.4s + add v9.4s, v9.4s ,v5.4s + add v10.4s, v10.4s ,v5.4s + add v11.4s, v11.4s ,v5.4s + add v12.4s, v12.4s ,v5.4s + add v13.4s, v13.4s ,v5.4s + add v14.4s, v14.4s ,v5.4s + add v15.4s, v15.4s ,v5.4s + add v16.4s, v16.4s ,v5.4s + add v17.4s, v17.4s ,v5.4s + add v18.4s, v18.4s ,v5.4s + add v19.4s, v19.4s ,v5.4s + add v20.4s, v20.4s ,v5.4s + add v21.4s, v21.4s ,v5.4s + add v22.4s, v22.4s ,v5.4s + add v23.4s, v23.4s ,v5.4s + add v24.4s, v24.4s ,v5.4s + add v25.4s, v25.4s ,v5.4s + add v26.4s, v26.4s ,v5.4s + add v27.4s, v27.4s ,v5.4s + add v28.4s, v28.4s ,v5.4s + add v29.4s, v29.4s ,v5.4s + add v30.4s, v30.4s ,v5.4s + add v31.4s, v31.4s ,v5.4s + + dup v0.4s, w8 + smax v8.4s, v8.4s ,v0.4s + smax v9.4s, v9.4s ,v0.4s + smax v10.4s, v10.4s ,v0.4s + smax v11.4s, v11.4s ,v0.4s + smax v12.4s, v12.4s ,v0.4s + smax v13.4s, v13.4s ,v0.4s + smax v14.4s, v14.4s ,v0.4s + smax v15.4s, v15.4s ,v0.4s + smax v16.4s, v16.4s ,v0.4s + smax v17.4s, v17.4s ,v0.4s + smax v18.4s, v18.4s ,v0.4s + smax v19.4s, v19.4s ,v0.4s + smax v20.4s, v20.4s ,v0.4s + smax v21.4s, v21.4s ,v0.4s + smax v22.4s, v22.4s ,v0.4s + smax v23.4s, v23.4s ,v0.4s + smax v24.4s, v24.4s ,v0.4s + smax v25.4s, v25.4s ,v0.4s + smax v26.4s, v26.4s ,v0.4s + smax v27.4s, v27.4s ,v0.4s + smax v28.4s, v28.4s ,v0.4s + smax v29.4s, v29.4s ,v0.4s + smax v30.4s, v30.4s ,v0.4s + smax v31.4s, v31.4s ,v0.4s + + dup v1.4s, w9 + smin v8.4s, v8.4s ,v1.4s + smin v9.4s, v9.4s ,v1.4s + smin v10.4s, v10.4s ,v1.4s + smin v11.4s, v11.4s ,v1.4s + smin v12.4s, v12.4s ,v1.4s + smin v13.4s, v13.4s ,v1.4s + smin v14.4s, v14.4s ,v1.4s + smin v15.4s, v15.4s ,v1.4s + smin v16.4s, v16.4s ,v1.4s + smin v17.4s, v17.4s ,v1.4s + smin v18.4s, v18.4s ,v1.4s + smin v19.4s, v19.4s ,v1.4s + smin v20.4s, v20.4s ,v1.4s + smin v21.4s, v21.4s ,v1.4s + smin v22.4s, v22.4s ,v1.4s + smin v23.4s, v23.4s ,v1.4s + smin v24.4s, v24.4s ,v1.4s + smin v25.4s, v25.4s ,v1.4s + smin v26.4s, v26.4s ,v1.4s + smin v27.4s, v27.4s ,v1.4s + smin v28.4s, v28.4s ,v1.4s + smin v29.4s, v29.4s ,v1.4s + smin v30.4s, v30.4s ,v1.4s + smin v31.4s, v31.4s ,v1.4s + + sqxtn v6.4h, v8.4s + sqxtn2 v6.8h, v9.4s + sqxtn v0.8b, v6.8h + sqxtn v7.4h, v10.4s + sqxtn2 v7.8h, v11.4s + sqxtn2 v0.16b, v7.8h + + sqxtn v6.4h, v12.4s + sqxtn2 v6.8h, v13.4s + sqxtn v1.8b, v6.8h + sqxtn v7.4h, v14.4s + sqxtn2 v7.8h, v15.4s + sqxtn2 v1.16b, v7.8h + + sqxtn v6.4h, v16.4s + sqxtn2 v6.8h, v17.4s + sqxtn v2.8b, v6.8h + sqxtn v7.4h, v18.4s + sqxtn2 v7.8h, v19.4s + sqxtn2 v2.16b, v7.8h + + sqxtn v6.4h, v20.4s + sqxtn2 v6.8h, v21.4s + sqxtn v3.8b, v6.8h + sqxtn v7.4h, v22.4s + sqxtn2 v7.8h, v23.4s + sqxtn2 v3.16b, v7.8h + + sqxtn v6.4h, v24.4s + sqxtn2 v6.8h, v25.4s + sqxtn v4.8b, v6.8h + sqxtn v7.4h, v26.4s + sqxtn2 v7.8h, v27.4s + sqxtn2 v4.16b, v7.8h + + sqxtn v6.4h, v28.4s + sqxtn2 v6.8h, v29.4s + sqxtn v5.8b, v6.8h + sqxtn v7.4h, v30.4s + sqxtn2 v7.8h, v31.4s + sqxtn2 v5.16b, v7.8h + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + WriteStart: + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + st1 {v0.b}[0], [x11], x7 + st1 {v0.b}[4], [x11], x7 + st1 {v0.b}[8], [x11], x7 + st1 {v0.b}[12], [x11], x7 + st1 {v1.b}[0], [x11], x7 + st1 {v1.b}[4], [x11], x7 + st1 {v1.b}[8], [x11], x7 + st1 {v1.b}[12], [x11], x7 + st1 {v2.b}[0], [x11], x7 + st1 {v2.b}[4], [x11], x7 + st1 {v2.b}[8], [x11], x7 + st1 {v2.b}[12], [x11], x7 + st1 {v3.b}[0], [x11], x7 + st1 {v3.b}[4], [x11], x7 + st1 {v3.b}[8], [x11], x7 + st1 {v3.b}[12], [x11], x7 + st1 {v4.b}[0], [x11], x7 + st1 {v4.b}[4], [x11], x7 + st1 {v4.b}[8], [x11], x7 + st1 {v4.b}[12], [x11], x7 + st1 {v5.b}[0], [x11], x7 + st1 {v5.b}[4], [x11], x7 + st1 {v5.b}[8], [x11], x7 + st1 {v5.b}[12], [x11] + add x0, x0, #1 + b WriteEnd + Write2: + st1 {v0.h}[0], [x11], x7 + st1 {v0.h}[2], [x11], x7 + st1 {v0.h}[4], [x11], x7 + st1 {v0.h}[6], [x11], x7 + st1 {v1.h}[0], [x11], x7 + st1 {v1.h}[2], [x11], x7 + st1 {v1.h}[4], [x11], x7 + st1 {v1.h}[6], [x11], x7 + st1 {v2.h}[0], [x11], x7 + st1 {v2.h}[2], [x11], x7 + st1 {v2.h}[4], [x11], x7 + st1 {v2.h}[6], [x11], x7 + st1 {v3.h}[0], [x11], x7 + st1 {v3.h}[2], [x11], x7 + st1 {v3.h}[4], [x11], x7 + st1 {v3.h}[6], [x11], x7 + st1 {v4.h}[0], [x11], x7 + st1 {v4.h}[2], [x11], x7 + st1 {v4.h}[4], [x11], x7 + st1 {v4.h}[6], [x11], x7 + st1 {v5.h}[0], [x11], x7 + st1 {v5.h}[2], [x11], x7 + st1 {v5.h}[4], [x11], x7 + st1 {v5.h}[6], [x11] + add x0, x0, #2 + b WriteEnd + Write3: + add x14, x11, #2 + st1 {v0.h}[0], [x11], x7 + st1 {v0.b}[2], [x14], x7 + st1 {v0.h}[2], [x11], x7 + st1 {v0.b}[6], [x14], x7 + st1 {v0.h}[4], [x11], x7 + st1 {v0.b}[10], [x14], x7 + st1 {v0.h}[6], [x11], x7 + st1 {v0.b}[14], [x14], x7 + st1 {v1.h}[0], [x11], x7 + st1 {v1.b}[2], [x14], x7 + st1 {v1.h}[2], [x11], x7 + st1 {v1.b}[6], [x14], x7 + st1 {v1.h}[4], [x11], x7 + st1 {v1.b}[10], [x14], x7 + st1 {v1.h}[6], [x11], x7 + st1 {v1.b}[14], [x14], x7 + st1 {v2.h}[0], [x11], x7 + st1 {v2.b}[2], [x14], x7 + st1 {v2.h}[2], [x11], x7 + st1 {v2.b}[6], [x14], x7 + st1 {v2.h}[4], [x11], x7 + st1 {v2.b}[10], [x14], x7 + st1 {v2.h}[6], [x11], x7 + st1 {v2.b}[14], [x14], x7 + st1 {v3.h}[0], [x11], x7 + st1 {v3.b}[2], [x14], x7 + st1 {v3.h}[2], [x11], x7 + st1 {v3.b}[6], [x14], x7 + st1 {v3.h}[4], [x11], x7 + st1 {v3.b}[10], [x14], x7 + st1 {v3.h}[6], [x11], x7 + st1 {v3.b}[14], [x14], x7 + st1 {v4.h}[0], [x11], x7 + st1 {v4.b}[2], [x14], x7 + st1 {v4.h}[2], [x11], x7 + st1 {v4.b}[6], [x14], x7 + st1 {v4.h}[4], [x11], x7 + st1 {v4.b}[10], [x14], x7 + st1 {v4.h}[6], [x11], x7 + st1 {v4.b}[14], [x14], x7 + st1 {v5.h}[0], [x11], x7 + st1 {v5.b}[2], [x14], x7 + st1 {v5.h}[2], [x11], x7 + st1 {v5.b}[6], [x14], x7 + st1 {v5.h}[4], [x11], x7 + st1 {v5.b}[10], [x14], x7 + st1 {v5.h}[6], [x11], x7 + st1 {v5.b}[14], [x14], x7 + add x0, x0, #3 + b WriteEnd + Write4: + st1 {v0.s}[0], [x11], x7 + st1 {v0.s}[1], [x11], x7 + st1 {v0.s}[2], [x11], x7 + st1 {v0.s}[3], [x11], x7 + st1 {v1.s}[0], [x11], x7 + st1 {v1.s}[1], [x11], x7 + st1 {v1.s}[2], [x11], x7 + st1 {v1.s}[3], [x11], x7 + st1 {v2.s}[0], [x11], x7 + st1 {v2.s}[1], [x11], x7 + st1 {v2.s}[2], [x11], x7 + st1 {v2.s}[3], [x11], x7 + st1 {v3.s}[0], [x11], x7 + st1 {v3.s}[1], [x11], x7 + st1 {v3.s}[2], [x11], x7 + st1 {v3.s}[3], [x11], x7 + st1 {v4.s}[0], [x11], x7 + st1 {v4.s}[1], [x11], x7 + st1 {v4.s}[2], [x11], x7 + st1 {v4.s}[3], [x11], x7 + st1 {v5.s}[0], [x11], x7 + st1 {v5.s}[1], [x11], x7 + st1 {v5.s}[2], [x11], x7 + st1 {v5.s}[3], [x11] + add x0, x0, #4 + + WriteEnd: + + subs x10, x10, #1 + bne LoopKsize + + subs x6, x6, #4 + cbz x3, NoStepFowrard + add x3, x3, #16 + NoStepFowrard: + bgt LoopOc + + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.cc new file mode 100644 index 0000000000..bbe4dc1f3c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "src/runtime/kernel/arm/nnacl/batch_norm.h" + +static void sumSpatialBatch(const float *in, int size, int ch, float *out) { + std::fill(out, out + ch, 0.f); + for (int i = 0; i < size; i++) { + const float *ptr = in + i * ch; + for (int c = 0; c < ch; c++) { + out[c] += ptr[c]; + } + } +} + +void scaleBias(const float *scales, int batch, int n, int size, float *output) { + for (int i = 0; i < batch * size; i++) + for (int c = 0; c < n; c++) output[i * n + c] *= scales[c]; +} + +void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial, + float *out) { + int b, f, i; + for (b = 0; b < batch; ++b) { + for (i = 0; i < spatial; ++i) { + for (f = 0; f < filters; ++f) { + int index = b * filters * spatial + i * filters + f; + out[index] = (x[index] - mean[f]) / (std::sqrt(variance[f]) + eps); + } + } + } +} + +void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates) { + int i, b, f; + std::fill(scale_updates, scale_updates + n, 0.f); + for (b = 0; b < batch; ++b) { + for (i = 0; i < size; ++i) { + for (f = 0; f < n; ++f) { + int index = (b * size + i) * n + f; + scale_updates[f] += delta[index] * x_norm[index]; + } + } + } +} + +void meanVar(const float *in, int batch, int spatial, int ch, float *mean, float *var) { + float N = batch * spatial; + sumSpatialBatch(in, N, ch, mean); + for (int f = 0; f < ch; ++f) mean[f] /= N; + std::fill(var, var + ch, 0.f); + for (int i = 0; i < N; i++) { + for (int f = 0; f < ch; f++) { + float x = in[i * ch + f]; + var[f] += (x - mean[f]) * (x - mean[f]); + } + } + for (int f = 0; f < ch; f++) var[f] /= N; +} + +void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta) { + sumSpatialBatch(yt, size, ch, mean_delta); + for (int i = 0; i < ch; i++) mean_delta[i] *= -1.f / std::sqrt((variance[i] + eps)); +} + +void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial, + float *mean_add, float *mean_delta) { + int i, k; + std::fill(mean_add, mean_add + filters, 0.f); + for (k = 0; k < spatial * batch; ++k) { + for (i = 0; i < filters; ++i) { + int index = k * filters + i; + mean_add[i] += x[index] - mean[i]; + } + } + for (i = 0; i < filters; ++i) { + mean_add[i] *= variance_delta[i] * (-2.f / (spatial * batch)); + mean_delta[i] += mean_add[i]; + } +} + +void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int filters, + int spatial, float eps, float *variance_delta) { + int i, k; + std::fill(variance_delta, variance_delta + filters, 0.f); + for (k = 0; k < batch * spatial; k++) { + for (i = 0; i < filters; i++) { + int index = k * filters + i; + variance_delta[i] += delta[index] * (x[index] - mean[i]); + } + } + for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * pow(variance[i] + eps, (-3.f / 2.f)); +} + +void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, + const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta) { + int f, k; + for (k = 0; k < batch * spatial; k++) { + for (f = 0; f < filters; f++) { + int index = k * filters + f; + delta[index] = delta[index] * 1. / (std::sqrt(variance[f] + eps)) + + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + + mean_delta[f] / (spatial * batch); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.h new file mode 100644 index 0000000000..0d9e8b74bf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ + +struct bnParameter { + int batch; + int channels; + int spatial; + float eps; +}; +void scaleBias(const float *scales, int batch, int n, int size, float *output); +void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial, + float *out); +void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates); +void meanVar(const float *in, int batch, int size, int ch, float *mean, float *var); +void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta); +void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int ch, + int spatial, float eps, float *variance_delta); +void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial, + float *mean_add, float *mean_delta); +void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, + const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); + +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.cc new file mode 100644 index 0000000000..e114f59cbf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/batch_to_space.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + size_t stride_h = block_w * out_n; + size_t output_offset = 0; + size_t copy_size = in_c * data_size; + size_t in_stride_h = in_w * in_c; + size_t in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + size_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + size_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + size_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy(reinterpret_cast(output) + output_offset, + reinterpret_cast(input) + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} + +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size) { + int block_h = block[0]; + int block_w = block[1]; + int in_n = in_shape[0]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + size_t stride_h = block_w * out_n; + size_t output_offset = 0; + size_t copy_size = in_c * data_size; + size_t in_stride_h = in_w * in_c; + size_t in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + size_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + size_t h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + size_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + size_t w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + size_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy(reinterpret_cast(output) + output_offset, + reinterpret_cast(input) + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.h new file mode 100644 index 0000000000..0cee0f071b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_BATCH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_BATCH_TO_SPACE_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +#define BATCH_TO_SPACE_BLOCK_SHAPE_SIZE 2 +#define BATCH_TO_SPACE_CROPS_SIZE 4 + +struct BatchToSpaceParameter { + OpParameter op_parameter_; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; + int32_t crops_[BATCH_TO_SPACE_CROPS_SIZE]; +}; + +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size); +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_TO_SPACE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc new file mode 100644 index 0000000000..7e83fded24 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +#ifndef ENABLE_ARM64 +void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, + int output_channel, size_t offset, size_t relu, size_t relu6) { + for (int i = 0; i < TILE_NUM; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * output_channel; + for (int j = 0; j < output_channel; j++) { + int oc8_block = j / C8NUM; + int oc8_res = j % C8NUM; + int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; + int out_oc_offset = output_tile_offset + j; + + float acc = 0; + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; + for (int m = 0; m < C4NUM; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * C8NUM; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + } + acc += bias[j]; + if (relu) { + acc = acc > 0 ? acc : 0; + } else if (relu6) { + if (acc < 0) { + acc = 0; + } else if (acc > 6) { + acc = 6; + } else { + } + } + (output + out_oc_offset)[0] = acc; + } + } +} + +void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6) { + int oc4 = UP_DIV(output_channel, C4NUM); + if (mode && writeC4) { + for (int i = 0; i < TILE_NUM; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * oc4 * C4NUM * step; + for (int j = 0; j < output_channel; j++) { + int oc4_block = j / 4; + int oc4_res = j % 4; + int oc8_block = oc4_block / 2; + int oc8_res = oc4_block % 2; + int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res * C4NUM + oc4_res; + int out_oc_offset = output_tile_offset + oc4_block * step * C4NUM + oc4_res; + + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; + int output_kw_offset = out_oc_offset + n * C4NUM; + float acc = 0; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; + for (int m = 0; m < 4; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * C8NUM; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + (output + output_kw_offset)[0] = acc; + } + } + } + } else if (mode) { + IndirectGemmFp32_Comm(output, input, weight, ic4, C8NUM, output_channel, offset); + } else { + IndirectGemmFp32(output, input, weight, bias, step, ic4, output_channel, offset, relu, relu6); + } +} +#endif +#ifndef ENABLE_ARM32 +void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6) { + for (int i = 0; i < TILE_NUM; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * output_channel; + for (int j = 0; j < output_channel; j++) { + int oc4_block = j / C4NUM; + int oc4_res = j % C4NUM; + int weight_oc_offset = oc4_block * step * ic4 * C4NUM * C4NUM + oc4_res; + int out_oc_offset = output_tile_offset + j; + + float acc = 0; + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C4NUM; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * C4NUM; + for (int m = 0; m < C4NUM; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * C4NUM; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + } + acc += bias[j]; + if (relu) { + acc = acc > 0 ? acc : 0; + } else if (relu6) { + if (acc < 0) { + acc = 0; + } else if (acc > 6) { + acc = 6; + } else { + } + } + (output + out_oc_offset)[0] = acc; + } + } +} +#endif + +int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } + +int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); } + +void ReluFp32(float *data, int ele_num) { + int four_block = UP_DIV(ele_num, C4NUM); + for (int i = 0; i < four_block - 1; i++) { + int index = i * C4NUM; +#ifdef ENABLE_NEON + float32x4_t relu_data = vld1q_f32(data + index); + float32x4_t zero_data = vdupq_n_f32(0); + relu_data = vmaxq_f32(relu_data, zero_data); + vst1q_f32(data + index, relu_data); +#else + data[index] = data[index] < 0 ? 0 : data[index]; + data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; + data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; + data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; +#endif + } + for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) { + data[j] = data[j] < 0 ? 0 : data[j]; + } +} + +void Relu6Fp32(float *data, int ele_num) { + int four_block = UP_DIV(ele_num, C4NUM); + for (int i = 0; i < four_block - 1; i++) { + int index = i * C4NUM; +#ifdef ENABLE_NEON + float32x4_t relu6_data = vld1q_f32(data + index); + float32x4_t zero_data = vdupq_n_f32(0); + float32x4_t six_data = vdupq_n_f32(6); + relu6_data = vmaxq_f32(relu6_data, zero_data); + relu6_data = vminq_f32(relu6_data, six_data); + vst1q_f32(data + index, relu6_data); +#else + data[index] = data[index] < 0 ? 0 : data[index]; + data[index] = data[index] > 6 ? 6 : data[index]; + data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; + data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1]; + data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; + data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2]; + data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; + data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3]; +#endif + } + for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) { + data[j] = data[j] < 0 ? 0 : data[j]; + data[j] = data[j] > 6 ? 6 : data[j]; + } +} + +void IndirectGemmFp32_Comm(float *output, const float *input, const float *weight, size_t ic4, size_t hw, size_t oc, + size_t offset) { + for (int r = 0; r < hw; r++) { + for (int c = 0; c < oc; c++) { + float value = 0; + for (int deep = 0; deep < ic4; deep++) { + int d4mod = deep % 4; + int d4div = deep / 4; + int a_index = d4div * 4 * 8 + r * 4 + d4mod; + int b_index = 8 * deep + c; + value += input[a_index] * weight[b_index]; + } + output[r * offset + c] = value; + } + } + return; +} + +void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { + /* (int32_t)row8x8-major * multiplier + bias => (int8)relu => (int8_t)row-major */ + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c8div = c / 8, c8mod = c % 8; + int src_index = c8div * plane8 * 8 + r * 8 + c8mod; + int dst_index = r * oc + c; + int32_t value = in[src_index]; + if (bias != nullptr) { + value = in[src_index] + bias[c]; + } + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + out[dst_index] = (int8_t)value; + } + } + return; +} + +void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp) { + /* (int32_t)row8x8-major * multiplier => (int8_t)row-major */ + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c8div = c / 8, c8mod = c % 8; + int src_index = c8div * plane8 * 8 + r * 8 + c8mod; + int dst_index = r * oc + c; + int32_t value = in[src_index]; + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(CHAR_MAX, value); + value = MSMAX(CHAR_MIN, value); + out[dst_index] = (int8_t)value; + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h new file mode 100644 index 0000000000..82daa49f25 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_COMMON_FUNC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_COMMON_FUNC_H_ + +#include +#include +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int8_t MinInt8(int8_t a, int8_t b); +int8_t MaxInt8(int8_t a, int8_t b); +void ReluFp32(float *data, int ele_num); +void Relu6Fp32(float *data, int ele_num); +void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); +void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp); +void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +void IndirectGemmFp32_Comm(float *output, const float *input, const float *weight, size_t ic4, size_t hw, size_t oc, + size_t offset); +void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, + int output_channel, size_t offset, size_t relu, size_t relu6); + +inline int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) { + return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3; +} + +inline int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2) { + return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3]; +} + +inline int offset4d(const int *shape, const int *dims) { return offset(shape, dims[0], dims[1], dims[2], dims[3]); } + +#ifdef ENABLE_ARM64 +void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size); +void Relu6(float *data, size_t element4); +void Relu(float *data, size_t element4); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_COMMON_FUNC_H_ */ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/concat_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/concat_parameter.h new file mode 100644 index 0000000000..5b998a4d2d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/concat_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONCAT_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONCAT_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +struct ConcatParameter { + OpParameter op_parameter_; + ConcatQuantArg *concat_quant_arg_; + int axis_; + int thread_count_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONCAT_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h new file mode 100644 index 0000000000..8a4d0ef843 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONV_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONV_PARAMETER_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +struct ConvParameter { + OpParameter op_parameter_; + ConvQuantArg conv_quant_arg_; + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_h_; + int pad_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int group_; + int tile_num_; + int input_batch_; + int input_h_; + int input_w_; + int input_channel_; + int output_batch_; + int output_h_; + int output_w_; + int output_channel_; + int thread_num_; + int input_unit_; + int output_unit_; + bool is_relu_; + bool is_relu6_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONV_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/crop_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/crop_parameter.h new file mode 100644 index 0000000000..93907046d1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/crop_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CROP_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CROP_PARAMETER_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +#define CROP_OFFSET_MAX_SIZE 4 + +struct CropParameter { + OpParameter op_parameter_; + CropQuantArg quant_arg; + int thread_count_; + int thread_id_; + int offset_size_; + int64_t offset_[CROP_OFFSET_MAX_SIZE]; + int64_t in_offset_[CROP_OFFSET_MAX_SIZE]; + int64_t axis_; + const int *in_shape_; + const int *out_shape_; + int input_dim_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CROP_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.cc new file mode 100644 index 0000000000..6ded162136 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/depth_to_space.h" +#include + +void DepthToSpaceForNHWC(const void *input, void *output, int *in_shape, DepthToSpaceParameter *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = block_size * param->out_stride_dim2_ * param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_offset_n = i * param->in_stride_dim0_; + size_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + size_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + size_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + size_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + size_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + size_t out_offset = (out_offset_w + l * param->out_stride_dim1_) * param->data_type_size_; + size_t in_offset = (in_offset_w + l * block_size * param->out_stride_dim2_) * param->data_type_size_; + memcpy(reinterpret_cast(output) + out_offset, reinterpret_cast(input) + in_offset, + copy_size); + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.h new file mode 100644 index 0000000000..88ae21c840 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.h @@ -0,0 +1,21 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_DEPTH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_DEPTH_TO_SPACE_H_ +#include "src/runtime/kernel/arm/nnacl/depth_to_space_parameter.h" + +void DepthToSpaceForNHWC(const void *input, void *output, int *in_shape, DepthToSpaceParameter *param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_DEPTH_TO_SPACE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space_parameter.h new file mode 100644 index 0000000000..0a44d45dba --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct DepthToSpaceParameter { + OpParameter op_parameter_; + int32_t block_size_; + int32_t in_stride_dim0_; + int32_t in_stride_dim1_; + int32_t in_stride_dim2_; + int32_t out_stride_dim0_; + int32_t out_stride_dim1_; + int32_t out_stride_dim2_; + uint8_t data_type_size_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_DEPTH_TO_SPACE_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/errorcode.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/errorcode.h new file mode 100644 index 0000000000..8c74dcbf39 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/errorcode.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ERRORCODE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ERRORCODE_H_ + +enum ErrorCodeCommonEnum { + NNACL_OK = 0, + NNACL_ERR = 1, + NNACL_NULL_PTR, + NNACL_PARAM_INVALID, + OPLIB_COMMON_END = 9999 +}; + +enum ErrorCodeFp32OpEnum { + NNACL_ERRCODE_OP_FP32_START = 10000, + NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, + NNACL_ERRCODE_REVERSE_MALLOC, + NNACL_ERRCODE_SQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_DIVISOR_ZERO, + NNACL_ERRCODE_INDEX_OUT_OF_RANGE, + NNACL_ERRCODE_OP_FP32_END = 19999 +}; + +enum ErrorCodeFp16OpEnum { NNACL_ERRCODE_OP_FP16_START = 20000, NNACL_ERRCODE_OP_FP16_END = 29999 }; + +enum ErrorCodeUint8OpEnum { NNACL_ERRCODE_OP_UINT8_START = 30000, NNACL_ERRCODE_OP_UINT8_END = 39999 }; + +enum ErrorCodeInt8OpEnum { NNACL_ERRCODE_OP_INT8_START = 40000, NNACL_ERRCODE_OP_INT8_END = 49999 }; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ERRORCODE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/flatten.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/flatten.cc new file mode 100644 index 0000000000..2ca6ad285e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/flatten.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/flatten.h" +#include + +void Flatten(const void *input, void *output, FlattenParameter *flatten_param) { + memcpy(output, input, flatten_param->size); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/flatten.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/flatten.h new file mode 100644 index 0000000000..feb9c0641c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/flatten.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FLATTEN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FLATTEN_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct FlattenParameter { + OpParameter op_parameter_; + int size; +}; + +void Flatten(const void *input, void *output, FlattenParameter *flatten_param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FLATTEN_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/common_func.h new file mode 100644 index 0000000000..ba0c7a35b8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/common_func.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_COMMON_FUNC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_COMMON_FUNC_H_ + +#include +#include +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_ARM64 +void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, + size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, + size_t in_kw_step, size_t relu, size_t relu6); +void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.cc new file mode 100644 index 0000000000..254bbaf115 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.cc @@ -0,0 +1,335 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.h" +#include +#include "src/runtime/kernel/arm/nnacl/fp16/common_func.h" + +/*conv depthwise fp16 begin*/ +void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int in_kh_step, int in_kw_step, int kernel_w, bool is_relu, + bool is_relu6) { + for (int c = 0; c < C8NUM; c++) { + dst[c] = 0; + } + const float16_t *src_kh = src; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst, dst_8); + + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + for (int c = 0; c < C8NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + float16_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float16_t *src_h = src + ih * sliding->in_h_step_; + + float16_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float16_t *src_w = src_h + iw * sliding->block_channel_; + + const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + + DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, conv_param->is_relu_, + conv_param->is_relu6_); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, + int in_sh_step, int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float16_t *src_kh = src_w; + const float16_t *weight_kh = weight; + for (int c = 0; c < C8NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_w); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_w, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add biad relu + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp16: sliding window +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DepthwiseBorderFp16(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + ConvDwFp16Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_ * sizeof(float16_t), sliding->block_channel_ * sizeof(float16_t), + sliding->in_sh_step_ * sizeof(float16_t), sliding->in_sw_step_ * sizeof(float16_t), + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->is_relu_, conv_param->is_relu6_); +#else + DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_); +#endif + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nchwc8 +} +/*conv depthwise fp16 end*/ + +/*deconv depthwise fp16 begin*/ +void DeconvDepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w) { + float16_t *dst_kh = dst; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); + + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop +} + +void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int top, int bottom, + int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + const float16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float16_t *dst_h = dst + oh * sliding->in_h_step_; + + const float16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float16_t *dst_w = dst_h + ow * sliding->block_channel_; + + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + float16_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float16_t *dst_kh = dst_w; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float16x8_t src_8 = vld1q_f16(src_w); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel, + const ConvParameter *conv_param) { + float16_t *dst_k = dst; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C8NUM; c++) { + dst_k[c] += bias[c]; + dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]); + dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]); + } + dst_k += block_channel; + } +} + +// deconv depthwise fp16: sliding window +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_ * sizeof(float16_t), sliding->block_channel_ * sizeof(float16_t), + sliding->in_sh_step_ * sizeof(float16_t), sliding->in_sw_step_ * sizeof(float16_t), + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t)); +#else + DeconvDepthwiseCenterFp16(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, + sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDepthwisePostFuncFp16(dst_data, bias, sliding->block_channel_, conv_param); + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nchwc8 +} +/*deconv depthwise fp16 end*/ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.h new file mode 100644 index 0000000000..dec8c48eeb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_depthwise_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CONV_DEPTHWISE_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CONV_DEPTHWISE_FP16_H_ + +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CONV_DEPTHWISE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc new file mode 100644 index 0000000000..8bd1e2f023 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc @@ -0,0 +1,232 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" +#include +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h" + +extern "C" { +#ifdef ENABLE_ARM64 +void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +#endif +} + +#ifndef ENABLE_NEON +void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6) { + int tile_n = 16; + for (int i = 0; i < out_channel; i++) { + int oc8_block = i / 8; + int oc8_res = i % 8; + int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; + for (int k = 0; k < tile_n; k++) { + int input_tile_offset = k * C4NUM; + int out_tile_offset = i + k * out_channel; + + float16_t tmp_out = 0; + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * tile_n * ic4 * C4NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; + for (int j = 0; j < ic4; j++) { + int input_ic4_offset = input_kw_offset + j * tile_n * C4NUM; + int weight_ic4_offset = weight_kw_offset + j * C4NUM * C8NUM; + for (int m = 0; m < C4NUM; m++) { + int input_c4_offset = input_ic4_offset + m; + int weight_c4_offset = weight_ic4_offset + m * C8NUM; + tmp_out += (input + input_c4_offset)[0] * (weight + weight_c4_offset)[0]; + } + } + } + + (output + out_tile_offset)[0] = tmp_out + bias[i]; + if (relu) { + (output + out_tile_offset)[0] = (output + out_tile_offset)[0] < 0 ? 0 : (output + out_tile_offset)[0]; + } else if (relu6) { + (output + out_tile_offset)[0] = (output + out_tile_offset)[0] < 0 ? 0 : (output + out_tile_offset)[0]; + (output + out_tile_offset)[0] = (output + out_tile_offset)[0] > 6 ? 6 : (output + out_tile_offset)[0]; + } + } + } +} + +void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *weight, const float16_t *bias, + size_t step, size_t ic4, size_t output_channel, size_t offset, size_t mode, + size_t writeC4, size_t relu, size_t relu6) { + int tile_num = 16; + if (mode) { + for (int i = 0; i < tile_num; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * output_channel * 36; + for (int j = 0; j < output_channel; j++) { + int oc8_block = j / 8; + int oc8_res = j % 8; + int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res; + int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res; + + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * 8; + int output_kw_offset = out_oc_offset + n * C8NUM; + float16_t acc = 0; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * 8; + for (int m = 0; m < 4; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * 8; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + + (output + output_kw_offset)[0] = acc; + } + } + } + } else { + } +} +#endif + +// fp16 convolution common (im2col+gemm) +void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, + float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + bool relu = conv_param->is_relu_; + bool relu6 = conv_param->is_relu6_; + // todo + int thread_count = conv_param->thread_num_; + int tile_n = 16; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + + int channel_block = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * channel_block * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + // we accumulate 4 channels per time for input blocks + int ic4 = UP_DIV(in_channel, C4NUM); + int conv_depth = kernel_h * kernel_w; + // bytes from one output's i-th channel to the next output's i-th channel + // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + float16_t *gemm_input = (float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset); + Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + if (real_cal_num == tile_n) { + float16_t *gemm_output = output_data + out_offset; + IndirectGemmFp16_16x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, + out_channel * sizeof(float16_t), 0, 0, relu, relu6); + } else { + // res part + IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, + out_channel * sizeof(float16_t), 0, 0, relu, relu6); + memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t)); + } + } + } +} + +// fp16 conv3x3 +void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, + float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, + int task_id, ConvParameter *conv_param) { + // todo + int thread_count = conv_param->thread_num_; + int tile_num = 16; + int output_unit = 4; + int k_plane = 36; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + + int output_batch = conv_param->output_batch_; + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + + int out_w_block = UP_DIV(conv_param->output_w_, C4NUM); + int out_h_block = UP_DIV(conv_param->output_h_, C4NUM); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, tile_num); + int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM; + int block_unit_buffer_offset = k_plane * C4NUM; + int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM; + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_num; + int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num; + + Conv3x3Fp16InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); + + IndirectGemmFp16_16x8(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, + tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM, + oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0); + + Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index, + real_cal_num, out_w_block, conv_param); + } + } + + // get real output + // todo + bool relu = conv_param->is_relu_; + bool relu6 = conv_param->is_relu6_; + for (int batch = 0; batch < output_batch; batch++) { + int batch_size = batch * output_channel * output_h * output_w; + for (int h = 0; h < output_h; h++) { + for (int w = 0; w < output_w; w++) { + for (int c = 0; c < output_channel; c++) { + int oc8_block = c / C8NUM; + int oc8_res = c % C8NUM; + int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * tile_num + + C8NUM * (h * out_w_block * output_unit + w) + oc8_res; + int dst_offset = (h * output_w + w) * output_channel + c; + (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; + if (relu) { + (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; + } else if (relu6) { + (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; + (output_data + dst_offset)[0] = (output_data + dst_offset)[0] > 6 ? 6 : (output_data + dst_offset)[0]; + } + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h new file mode 100644 index 0000000000..c28ffbdd30 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CONV_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CONV_FP16_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +#ifndef ENABLE_NEON +void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +#endif + +// fp16 convolution common (im2col+gemm) +void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, + float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param); + +// fp16 conv3x3 +void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, + float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, + int task_id, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CONV_FP16_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.cc new file mode 100644 index 0000000000..2d2f7a93cc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.cc @@ -0,0 +1,346 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include +#include + +void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int channel_block = UP_DIV(in_channel, 4); + int kernel_plane = kernel_h * kernel_w; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * channel_block * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * channel_block * C4NUM; + int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM; + for (int m = 0; m < channel_block; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * 16 * C4NUM; +#ifdef ENABLE_ARM64 + vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride)); +#else + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; +#endif + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + } // tile num loop +} + +void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) { + // original weight format : ohwi + int tile_num = 8; + int inchannel_block = 4; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int kernel_block = UP_DIV(out_channel, tile_num); + int channel_block = UP_DIV(in_channel, inchannel_block); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane; + + int unit_size = tile_num * inchannel_block; + int block_size = pack_weight_size / kernel_block; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * channel_block; + for (int i = 0; i < channel_block; i++) { + int channel_block_stride = kernel_plane_stride + i * inchannel_block; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * inchannel_block; + int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h * tile_num; + for (int j = 0; j < kernel_block; j++) { + int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * tile_num; + int real_oc_num = oc_remainder < tile_num ? oc_remainder : tile_num; + for (int k = 0; k < real_oc_num; k++) { + float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k; + *packed_data_ptr = *origin_data_ptr; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = UP_DIV(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; + for (int i = 0; i < input_channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; + } + } + } +} + +void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic4 = UP_DIV(input_channel, C4NUM); + int output_channel = conv_param->output_channel_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C4NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic4 * kernel_plane * C4NUM; + for (int i = 0; i < input_channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c4_block_num * kernel_plane * C4NUM + c4_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; + } + } + } +} + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + nhwc4_batch_offset + i * ic4 * C4NUM, (float16_t *)src + batch_offset + i * channel, + channel * sizeof(float16_t)); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; + ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * channel; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + for (int c = 0; c < channel; c++) { + (dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c]; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } +} + +void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc_batch_unit_offset = channel * plane; + int nhwc_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c8 * C8NUM * plane; + for (int i = 0; i < plane; i++) { + for (int c = 0; c < channel; c++) { + (dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c]; + } + } + nhwc_batch_offset += nhwc_batch_unit_offset; + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h new file mode 100644 index 0000000000..a4ac6201d1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_PACK_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_PACK_FP16_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, + int block_index); + +void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight); + +void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); + +void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_PACK_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc new file mode 100644 index 0000000000..e585509825 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc @@ -0,0 +1,532 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h" + +// for fp16 convolution 3x3 filter/input/output transform F(4,3) +void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { + float16x4_t d00 = vld1_f16(tmp_data); + float16x4_t d01 = vld1_f16(tmp_data + 4); + float16x4_t d02 = vld1_f16(tmp_data + 2 * 4); + float16x4_t d03 = vld1_f16(tmp_data + 3 * 4); + float16x4_t d04 = vld1_f16(tmp_data + 4 * 4); + float16x4_t d05 = vld1_f16(tmp_data + 5 * 4); + + float16x4_t d10 = vld1_f16(tmp_data + 6 * 4); + float16x4_t d11 = vld1_f16(tmp_data + 7 * 4); + float16x4_t d12 = vld1_f16(tmp_data + 8 * 4); + float16x4_t d13 = vld1_f16(tmp_data + 9 * 4); + float16x4_t d14 = vld1_f16(tmp_data + 10 * 4); + float16x4_t d15 = vld1_f16(tmp_data + 11 * 4); + + float16x4_t d20 = vld1_f16(tmp_data + 12 * 4); + float16x4_t d21 = vld1_f16(tmp_data + 13 * 4); + float16x4_t d22 = vld1_f16(tmp_data + 14 * 4); + float16x4_t d23 = vld1_f16(tmp_data + 15 * 4); + float16x4_t d24 = vld1_f16(tmp_data + 16 * 4); + float16x4_t d25 = vld1_f16(tmp_data + 17 * 4); + + float16x4_t d30 = vld1_f16(tmp_data + 18 * 4); + float16x4_t d31 = vld1_f16(tmp_data + 19 * 4); + float16x4_t d32 = vld1_f16(tmp_data + 20 * 4); + float16x4_t d33 = vld1_f16(tmp_data + 21 * 4); + float16x4_t d34 = vld1_f16(tmp_data + 22 * 4); + float16x4_t d35 = vld1_f16(tmp_data + 23 * 4); + + float16x4_t d40 = vld1_f16(tmp_data + 24 * 4); + float16x4_t d41 = vld1_f16(tmp_data + 25 * 4); + float16x4_t d42 = vld1_f16(tmp_data + 26 * 4); + float16x4_t d43 = vld1_f16(tmp_data + 27 * 4); + float16x4_t d44 = vld1_f16(tmp_data + 28 * 4); + float16x4_t d45 = vld1_f16(tmp_data + 29 * 4); + + float16x4_t d50 = vld1_f16(tmp_data + 30 * 4); + float16x4_t d51 = vld1_f16(tmp_data + 31 * 4); + float16x4_t d52 = vld1_f16(tmp_data + 32 * 4); + float16x4_t d53 = vld1_f16(tmp_data + 33 * 4); + float16x4_t d54 = vld1_f16(tmp_data + 34 * 4); + float16x4_t d55 = vld1_f16(tmp_data + 35 * 4); + + float16x4_t t00 = vadd_f16(vsub_f16(vmul_n_f16(d00, 4), vmul_n_f16(d20, 5)), d40); + float16x4_t t01 = vadd_f16(vsub_f16(vmul_n_f16(d01, 4), vmul_n_f16(d21, 5)), d41); + float16x4_t t02 = vadd_f16(vsub_f16(vmul_n_f16(d02, 4), vmul_n_f16(d22, 5)), d42); + float16x4_t t03 = vadd_f16(vsub_f16(vmul_n_f16(d03, 4), vmul_n_f16(d23, 5)), d43); + float16x4_t t04 = vadd_f16(vsub_f16(vmul_n_f16(d04, 4), vmul_n_f16(d24, 5)), d44); + float16x4_t t05 = vadd_f16(vsub_f16(vmul_n_f16(d05, 4), vmul_n_f16(d25, 5)), d45); + + float16x4_t t10 = vadd_f16(vadd_f16(d30, d40), vmul_n_f16(vadd_f16(d10, d20), -4)); + float16x4_t t11 = vadd_f16(vadd_f16(d31, d41), vmul_n_f16(vadd_f16(d11, d21), -4)); + float16x4_t t12 = vadd_f16(vadd_f16(d32, d42), vmul_n_f16(vadd_f16(d12, d22), -4)); + float16x4_t t13 = vadd_f16(vadd_f16(d33, d43), vmul_n_f16(vadd_f16(d13, d23), -4)); + float16x4_t t14 = vadd_f16(vadd_f16(d34, d44), vmul_n_f16(vadd_f16(d14, d24), -4)); + float16x4_t t15 = vadd_f16(vadd_f16(d35, d45), vmul_n_f16(vadd_f16(d15, d25), -4)); + + float16x4_t t20 = vadd_f16(vsub_f16(d40, d30), vmul_n_f16(vsub_f16(d10, d20), 4)); + float16x4_t t21 = vadd_f16(vsub_f16(d41, d31), vmul_n_f16(vsub_f16(d11, d21), 4)); + float16x4_t t22 = vadd_f16(vsub_f16(d42, d32), vmul_n_f16(vsub_f16(d12, d22), 4)); + float16x4_t t23 = vadd_f16(vsub_f16(d43, d33), vmul_n_f16(vsub_f16(d13, d23), 4)); + float16x4_t t24 = vadd_f16(vsub_f16(d44, d34), vmul_n_f16(vsub_f16(d14, d24), 4)); + float16x4_t t25 = vadd_f16(vsub_f16(d45, d35), vmul_n_f16(vsub_f16(d15, d25), 4)); + + float16x4_t t30 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d30, d10), 2)); + float16x4_t t31 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d31, d11), 2)); + float16x4_t t32 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d32, d12), 2)); + float16x4_t t33 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d33, d13), 2)); + float16x4_t t34 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d34, d14), 2)); + float16x4_t t35 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d35, d15), 2)); + + float16x4_t t40 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d10, d30), 2)); + float16x4_t t41 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d11, d31), 2)); + float16x4_t t42 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d12, d32), 2)); + float16x4_t t43 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d13, d33), 2)); + float16x4_t t44 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d14, d34), 2)); + float16x4_t t45 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d15, d35), 2)); + + float16x4_t t50 = vadd_f16(vsub_f16(vmul_n_f16(d10, 4), vmul_n_f16(d30, 5)), d50); + float16x4_t t51 = vadd_f16(vsub_f16(vmul_n_f16(d11, 4), vmul_n_f16(d31, 5)), d51); + float16x4_t t52 = vadd_f16(vsub_f16(vmul_n_f16(d12, 4), vmul_n_f16(d32, 5)), d52); + float16x4_t t53 = vadd_f16(vsub_f16(vmul_n_f16(d13, 4), vmul_n_f16(d33, 5)), d53); + float16x4_t t54 = vadd_f16(vsub_f16(vmul_n_f16(d14, 4), vmul_n_f16(d34, 5)), d54); + float16x4_t t55 = vadd_f16(vsub_f16(vmul_n_f16(d15, 4), vmul_n_f16(d35, 5)), d55); + + float16x4_t m00 = vadd_f16(vsub_f16(vmul_n_f16(t00, 4), vmul_n_f16(t02, 5)), t04); + float16x4_t m01 = vadd_f16(vadd_f16(t03, t04), vmul_n_f16(vadd_f16(t01, t02), -4)); + float16x4_t m02 = vadd_f16(vsub_f16(t04, t03), vmul_n_f16(vsub_f16(t01, t02), 4)); + float16x4_t m03 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t03, t01), 2)); + float16x4_t m04 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t01, t03), 2)); + float16x4_t m05 = vadd_f16(vsub_f16(vmul_n_f16(t01, 4), vmul_n_f16(t03, 5)), t05); + + float16x4_t m10 = vadd_f16(vsub_f16(vmul_n_f16(t10, 4), vmul_n_f16(t12, 5)), t14); + float16x4_t m11 = vadd_f16(vadd_f16(t13, t14), vmul_n_f16(vadd_f16(t11, t12), -4)); + float16x4_t m12 = vadd_f16(vsub_f16(t14, t13), vmul_n_f16(vsub_f16(t11, t12), 4)); + float16x4_t m13 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t13, t11), 2)); + float16x4_t m14 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t11, t13), 2)); + float16x4_t m15 = vadd_f16(vsub_f16(vmul_n_f16(t11, 4), vmul_n_f16(t13, 5)), t15); + + float16x4_t m20 = vadd_f16(vsub_f16(vmul_n_f16(t20, 4), vmul_n_f16(t22, 5)), t24); + float16x4_t m21 = vadd_f16(vadd_f16(t23, t24), vmul_n_f16(vadd_f16(t21, t22), -4)); + float16x4_t m22 = vadd_f16(vsub_f16(t24, t23), vmul_n_f16(vsub_f16(t21, t22), 4)); + float16x4_t m23 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t23, t21), 2)); + float16x4_t m24 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t21, t23), 2)); + float16x4_t m25 = vadd_f16(vsub_f16(vmul_n_f16(t21, 4), vmul_n_f16(t23, 5)), t25); + + float16x4_t m30 = vadd_f16(vsub_f16(vmul_n_f16(t30, 4), vmul_n_f16(t32, 5)), t34); + float16x4_t m31 = vadd_f16(vadd_f16(t33, t34), vmul_n_f16(vadd_f16(t31, t32), -4)); + float16x4_t m32 = vadd_f16(vsub_f16(t34, t33), vmul_n_f16(vsub_f16(t31, t32), 4)); + float16x4_t m33 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t33, t31), 2)); + float16x4_t m34 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t31, t33), 2)); + float16x4_t m35 = vadd_f16(vsub_f16(vmul_n_f16(t31, 4), vmul_n_f16(t33, 5)), t35); + + float16x4_t m40 = vadd_f16(vsub_f16(vmul_n_f16(t40, 4), vmul_n_f16(t42, 5)), t44); + float16x4_t m41 = vadd_f16(vadd_f16(t43, t44), vmul_n_f16(vadd_f16(t41, t42), -4)); + float16x4_t m42 = vadd_f16(vsub_f16(t44, t43), vmul_n_f16(vsub_f16(t41, t42), 4)); + float16x4_t m43 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t43, t41), 2)); + float16x4_t m44 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t41, t43), 2)); + float16x4_t m45 = vadd_f16(vsub_f16(vmul_n_f16(t41, 4), vmul_n_f16(t43, 5)), t45); + + float16x4_t m50 = vadd_f16(vsub_f16(vmul_n_f16(t50, 4), vmul_n_f16(t52, 5)), t54); + float16x4_t m51 = vadd_f16(vadd_f16(t53, t54), vmul_n_f16(vadd_f16(t51, t52), -4)); + float16x4_t m52 = vadd_f16(vsub_f16(t54, t53), vmul_n_f16(vsub_f16(t51, t52), 4)); + float16x4_t m53 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t53, t51), 2)); + float16x4_t m54 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t51, t53), 2)); + float16x4_t m55 = vadd_f16(vsub_f16(vmul_n_f16(t51, 4), vmul_n_f16(t53, 5)), t55); + + vst1_f16(trans_input_data, m00); + vst1_f16(trans_input_data + step, m01); + vst1_f16(trans_input_data + 2 * step, m02); + vst1_f16(trans_input_data + 3 * step, m03); + vst1_f16(trans_input_data + 4 * step, m04); + vst1_f16(trans_input_data + 5 * step, m05); + + vst1_f16(trans_input_data + 6 * step, m10); + vst1_f16(trans_input_data + 7 * step, m11); + vst1_f16(trans_input_data + 8 * step, m12); + vst1_f16(trans_input_data + 9 * step, m13); + vst1_f16(trans_input_data + 10 * step, m14); + vst1_f16(trans_input_data + 11 * step, m15); + + vst1_f16(trans_input_data + 12 * step, m20); + vst1_f16(trans_input_data + 13 * step, m21); + vst1_f16(trans_input_data + 14 * step, m22); + vst1_f16(trans_input_data + 15 * step, m23); + vst1_f16(trans_input_data + 16 * step, m24); + vst1_f16(trans_input_data + 17 * step, m25); + + vst1_f16(trans_input_data + 18 * step, m30); + vst1_f16(trans_input_data + 19 * step, m31); + vst1_f16(trans_input_data + 20 * step, m32); + vst1_f16(trans_input_data + 21 * step, m33); + vst1_f16(trans_input_data + 22 * step, m34); + vst1_f16(trans_input_data + 23 * step, m35); + + vst1_f16(trans_input_data + 24 * step, m40); + vst1_f16(trans_input_data + 25 * step, m41); + vst1_f16(trans_input_data + 26 * step, m42); + vst1_f16(trans_input_data + 27 * step, m43); + vst1_f16(trans_input_data + 28 * step, m44); + vst1_f16(trans_input_data + 29 * step, m45); + + vst1_f16(trans_input_data + 30 * step, m50); + vst1_f16(trans_input_data + 31 * step, m51); + vst1_f16(trans_input_data + 32 * step, m52); + vst1_f16(trans_input_data + 33 * step, m53); + vst1_f16(trans_input_data + 34 * step, m54); + vst1_f16(trans_input_data + 35 * step, m55); +} + +void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int output_unit = 4; + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + int ic4 = UP_DIV(input_channel, C4NUM); + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * output_unit - pad_w; + int origin_y = (x_id / out_w_block) * output_unit - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y); + + int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, 6 * 6 * C4NUM * sizeof(float16_t)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = real_y_start; interval < real_y_end; interval++) { + int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM; + int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM; + for (int j = 0; j < (real_x_end - real_x_start); j++) { + int src_x_offset = src_y_offset + j * ic4 * C4NUM; + int dst_x_offset = dst_y_offset + j * C4NUM; + float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; + float16_t *dst_addr = tmp_data + dst_x_offset; + dst_addr[0] = src_addr[0]; + dst_addr[1] = src_addr[1]; + dst_addr[2] = src_addr[2]; + dst_addr[3] = src_addr[3]; + } + } + + // todo + // input transform + int dst_ic4_offset = dst_plane_offset + ic * 16 * C4NUM; + size_t dst_step = ic4 * C4NUM * 16; + float16_t *trans_input_ptr = trans_input + dst_ic4_offset; + Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step); + } + } +} + +void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel, + int kernel_plane) { + int dst_step = iC4 * C4NUM * 8; + for (int o = 0; o < output_channel; o++) { + int oc8_block_num = o / C8NUM; + int oc8_block_rem = o % C8NUM; + int src_oc_offset = o * iC4 * C4NUM * kernel_plane; + int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem; + for (int i = 0; i < iC4; i++) { + const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; + float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM; + float16x4_t g00 = vld1_f16(src_ic4_ptr); + float16x4_t g01 = vld1_f16(src_ic4_ptr + 4); + float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4); + float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4); + float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4); + float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4); + float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4); + float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4); + float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4); + + float16x4_t dst00 = vmul_n_f16(g00, 0.25); + float16x4_t dst01 = vmul_n_f16(g01, 0.25); + float16x4_t dst02 = vmul_n_f16(g02, 0.25); + + float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667); + float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667); + float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667); + + float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667); + float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667); + float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667); + + float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333), + vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667))); + float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333), + vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667))); + float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333), + vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667))); + + float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)), + vmul_n_f16(g10, 0.08333333333333)); + float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)), + vmul_n_f16(g11, 0.08333333333333)); + float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)), + vmul_n_f16(g12, 0.08333333333333)); + + float16x4_t dst50 = g20; + float16x4_t dst51 = g21; + float16x4_t dst52 = g22; + + float16x4_t m00 = vmul_n_f16(dst00, 0.25); + float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667); + float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667); + float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333), + vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667))); + float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)), + vmul_n_f16(dst01, 0.08333333333333)); + float16x4_t m05 = dst02; + + float16x4_t m10 = vmul_n_f16(dst10, 0.25); + float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667); + float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667); + float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333), + vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667))); + float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)), + vmul_n_f16(dst11, 0.08333333333333)); + float16x4_t m15 = dst12; + + float16x4_t m20 = vmul_n_f16(dst20, 0.25); + float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667); + float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667); + float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333), + vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667))); + float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)), + vmul_n_f16(dst21, 0.08333333333333)); + float16x4_t m25 = dst22; + + float16x4_t m30 = vmul_n_f16(dst30, 0.25); + float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667); + float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667); + float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333), + vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667))); + float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)), + vmul_n_f16(dst31, 0.08333333333333)); + float16x4_t m35 = dst32; + + float16x4_t m40 = vmul_n_f16(dst40, 0.25); + float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667); + float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667); + float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333), + vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667))); + float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)), + vmul_n_f16(dst41, 0.08333333333333)); + float16x4_t m45 = dst42; + + float16x4_t m50 = vmul_n_f16(dst50, 0.25); + float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667); + float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667); + float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333), + vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667))); + float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)), + vmul_n_f16(dst51, 0.08333333333333)); + float16x4_t m55 = dst52; + + for (int j = 0; j < 4; j++) { + dst_ic4_ptr[j * 8] = m00[j]; + dst_ic4_ptr[j * 8 + dst_step] = m01[j]; + dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j]; + dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j]; + dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j]; + dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j]; + dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j]; + dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j]; + dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j]; + dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j]; + dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j]; + dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j]; + dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j]; + dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j]; + dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j]; + dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j]; + dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j]; + dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j]; + dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j]; + dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j]; + dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j]; + dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j]; + dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j]; + dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j]; + dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j]; + dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j]; + dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j]; + dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j]; + dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j]; + dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j]; + dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j]; + dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j]; + dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j]; + dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j]; + dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j]; + dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j]; + } + } + } +} + +void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, + int output_w) { + float16x8_t s00 = vld1q_f16(gemm_out); + float16x8_t s01 = vld1q_f16(gemm_out + 8); + float16x8_t s02 = vld1q_f16(gemm_out + 16); + float16x8_t s03 = vld1q_f16(gemm_out + 24); + float16x8_t s04 = vld1q_f16(gemm_out + 32); + float16x8_t s05 = vld1q_f16(gemm_out + 40); + + float16x8_t s10 = vld1q_f16(gemm_out + 48); + float16x8_t s11 = vld1q_f16(gemm_out + 56); + float16x8_t s12 = vld1q_f16(gemm_out + 64); + float16x8_t s13 = vld1q_f16(gemm_out + 72); + float16x8_t s14 = vld1q_f16(gemm_out + 80); + float16x8_t s15 = vld1q_f16(gemm_out + 88); + + float16x8_t s20 = vld1q_f16(gemm_out + 96); + float16x8_t s21 = vld1q_f16(gemm_out + 104); + float16x8_t s22 = vld1q_f16(gemm_out + 112); + float16x8_t s23 = vld1q_f16(gemm_out + 120); + float16x8_t s24 = vld1q_f16(gemm_out + 128); + float16x8_t s25 = vld1q_f16(gemm_out + 136); + + float16x8_t s30 = vld1q_f16(gemm_out + 144); + float16x8_t s31 = vld1q_f16(gemm_out + 152); + float16x8_t s32 = vld1q_f16(gemm_out + 160); + float16x8_t s33 = vld1q_f16(gemm_out + 168); + float16x8_t s34 = vld1q_f16(gemm_out + 176); + float16x8_t s35 = vld1q_f16(gemm_out + 184); + + float16x8_t s40 = vld1q_f16(gemm_out + 192); + float16x8_t s41 = vld1q_f16(gemm_out + 200); + float16x8_t s42 = vld1q_f16(gemm_out + 208); + float16x8_t s43 = vld1q_f16(gemm_out + 216); + float16x8_t s44 = vld1q_f16(gemm_out + 224); + float16x8_t s45 = vld1q_f16(gemm_out + 232); + + float16x8_t s50 = vld1q_f16(gemm_out + 240); + float16x8_t s51 = vld1q_f16(gemm_out + 248); + float16x8_t s52 = vld1q_f16(gemm_out + 256); + float16x8_t s53 = vld1q_f16(gemm_out + 264); + float16x8_t s54 = vld1q_f16(gemm_out + 272); + float16x8_t s55 = vld1q_f16(gemm_out + 280); + + float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40); + float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41); + float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42); + float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43); + float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44); + float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45); + + float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2)); + float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2)); + float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2)); + float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2)); + float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2)); + float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2)); + + float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4)); + float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4)); + float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4)); + float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4)); + float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4)); + float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4)); + + float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50); + float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51); + float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52); + float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53); + float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); + float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); + + float16x8_t bias_ptr = vld1q_f16(bias_data); + float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr); + float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr); + float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr); + float16x8_t d03 = + vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr); + + float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr); + float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr); + float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr); + float16x8_t d13 = + vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr); + + float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr); + float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), bias_ptr); + float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr); + float16x8_t d23 = + vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr); + + float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr); + float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr); + float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr); + float16x8_t d33 = + vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr); + + vst1q_f16(output_data, d00); + vst1q_f16(output_data + 8, d01); + vst1q_f16(output_data + 16, d02); + vst1q_f16(output_data + 24, d03); + + vst1q_f16(output_data + output_w * 8, d10); + vst1q_f16(output_data + output_w * 8 + 8, d11); + vst1q_f16(output_data + output_w * 8 + 16, d12); + vst1q_f16(output_data + output_w * 8 + 24, d13); + + vst1q_f16(output_data + 2 * output_w * 8, d20); + vst1q_f16(output_data + 2 * output_w * 8 + 8, d21); + vst1q_f16(output_data + 2 * output_w * 8 + 16, d22); + vst1q_f16(output_data + 2 * output_w * 8 + 24, d23); + + vst1q_f16(output_data + 3 * output_w * 8, d30); + vst1q_f16(output_data + 3 * output_w * 8 + 8, d31); + vst1q_f16(output_data + 3 * output_w * 8 + 16, d32); + vst1q_f16(output_data + 3 * output_w * 8 + 24, d33); +} + +void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc8 = UP_DIV(output_channel, C8NUM); +// todo outputw --> out_w_block * out_unit + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc8 * C8NUM * 36; + int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w); + + for (int j = 0; j < oc8; j++) { + int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = out_data + dst_oc8_offset; + + // output transform + Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h new file mode 100644 index 0000000000..235b9979c0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ + +#include +#include +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" + +// for fp16 convolution 3x3 filter/input/output transform +void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step); + +void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); + +void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w); + +void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h new file mode 100644 index 0000000000..249bfacbce --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ACTIVATION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ACTIVATION_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +struct ActivationParameter { + OpParameter op_parameter_; + int type_; + float alpha_{0.01}; +}; + +inline int Relu(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +inline int Relu6(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + if (src[i] < 0) { + dst[i] = 0; + } else { + dst[i] = src[i] > 6.0f ? 6.0f : src[i]; + } + } + return NNACL_OK; +} + +inline int LRelu(const float *src, int length, float *dst, float alpha) { + for (int i = 0; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha); + } + return NNACL_OK; +} + +inline int Sigmoid(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = 1.0f / (1.0f + exp(-src[i])); + } + return NNACL_OK; +} + +inline int Tanh(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = 1.0f - 2.0f / (exp(2 * src[i]) + 1); + } + return NNACL_OK; +} + +inline int HSwish(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + float in = src[i]; + float relu6 = MSMIN(MSMAX(in + 3, 0), 6); + dst[i] = in * relu6 / 6; + } + return NNACL_OK; +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ACTIVATION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.cc new file mode 100644 index 0000000000..2b592dc574 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.cc @@ -0,0 +1,486 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h" +#include +#include + +int ArgCompareAscFp32(const void *a, const void *b) { + return reinterpret_cast(a)->data_.f_data_ + - reinterpret_cast(b)->data_.f_data_; +} + +int ArgCompareDescFp32(const void *a, const void *b) { + return reinterpret_cast(b)->data_.f_data_ + - reinterpret_cast(a)->data_.f_data_; +} + +void ArgMaxDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescFp32); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + output[out_offset] = param->arg_elements_[j].data_.f_data_; + } + } +} + +void ArgMaxDim0OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescFp32); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + output[out_offset] = param->arg_elements_[j].index_; + } + } +} + +void ArgMinDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscFp32); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + output[out_offset] = param->arg_elements_[j].data_.f_data_; + } + } +} + +void ArgMinDim0OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscFp32); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + output[out_offset] = param->arg_elements_[j].index_; + } + } +} + +void ArgMaxDim1OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescFp32); + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + output[out_offset] = param->arg_elements_[k].data_.f_data_; + } + } + } +} + +void ArgMaxDim1OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescFp32); + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + output[out_offset] = param->arg_elements_[k].index_; + } + } + } +} + +void ArgMinDim1OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscFp32); + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + output[out_offset] = param->arg_elements_[k].data_.f_data_; + } + } + } +} + +void ArgMinDim1OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscFp32); + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + output[out_offset] = param->arg_elements_[k].index_; + } + } + } +} + +void ArgMaxDim2OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + output[out_offset] = param->arg_elements_[l].data_.f_data_; + } + } + } + } +} + +void ArgMaxDim2OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + output[out_offset] = param->arg_elements_[l].index_; + } + } + } + } +} + +void ArgMinDim2OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + output[out_offset] = param->arg_elements_[l].data_.f_data_; + } + } + } + } +} + +void ArgMinDim2OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + output[out_offset] = param->arg_elements_[l].index_; + } + } + } + } +} + +void ArgMaxDim3OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + output[out_offset] = param->arg_elements_[l].data_.f_data_; + } + } + } + } +} + +void ArgMaxDim3OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + output[out_offset] = param->arg_elements_[l].index_; + } + } + } + } +} + +void ArgMinDim3OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + output[out_offset] = param->arg_elements_[l].data_.f_data_; + } + } + } + } +} + +void ArgMinDim3OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscFp32); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + output[out_offset] = param->arg_elements_[l].index_; + } + } + } + } +} + +void ArgMaxDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMaxDim0OutValue(input, output, in_shape, param); + } else { + ArgMaxDim0OutIndex(input, output, in_shape, param); + } +} + +void ArgMinDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMinDim0OutValue(input, output, in_shape, param); + } else { + ArgMinDim0OutIndex(input, output, in_shape, param); + } +} + +void ArgMaxDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMaxDim1OutValue(input, output, in_shape, param); + } else { + ArgMaxDim1OutIndex(input, output, in_shape, param); + } +} + +void ArgMinDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMinDim1OutValue(input, output, in_shape, param); + } else { + ArgMinDim1OutIndex(input, output, in_shape, param); + } +} + +void ArgMaxDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMaxDim2OutValue(input, output, in_shape, param); + } else { + ArgMaxDim2OutIndex(input, output, in_shape, param); + } +} + +void ArgMinDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMinDim2OutValue(input, output, in_shape, param); + } else { + ArgMinDim2OutIndex(input, output, in_shape, param); + } +} + +void ArgMaxDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMaxDim3OutValue(input, output, in_shape, param); + } else { + ArgMaxDim3OutIndex(input, output, in_shape, param); + } +} + +void ArgMinDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { + if (param->out_value_) { + ArgMinDim3OutValue(input, output, in_shape, param); + } else { + ArgMinDim3OutIndex(input, output, in_shape, param); + } +} + +void ArgMax(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count, + int after_axis_count) { + bool out_value = param->out_value_; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float value = -FLT_MAX; + float index = 0.0f; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } + output[output_offset + j] = out_value ? value : index; + } + } +} + +void ArgMin(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count, + int after_axis_count) { + bool out_value = param->out_value_; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float value = FLT_MAX; + float index = 0.0f; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + output[output_offset + j] = out_value ? value : index; + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h new file mode 100644 index 0000000000..8812459de2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARG_MIN_MAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARG_MIN_MAX_H_ + +#include "src/runtime/kernel/arm/nnacl/arg_min_max_parameter.h" + +void ArgMax(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count, + int after_axis_count); +void ArgMin(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count, + int after_axis_count); +void ArgMaxDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMinDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMaxDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMinDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMaxDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMinDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMaxDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +void ArgMinDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARG_MIN_MAX_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc new file mode 100644 index 0000000000..0fc3d805a7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc @@ -0,0 +1,759 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +int ElementMul(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmulq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] * input1[0]; + output[1] = input0[1] * input1[1]; + output[2] = input0[2] * input1[2]; + output[3] = input0[3] * input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] * input1[index]; + } + + return NNACL_OK; +} + +int ElementMulRelu(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmulq_f32(vin0, vin1); + vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); + vst1q_f32(output, vout); +#else + float res = input0[0] * input1[0]; + output[0] = res > 0 ? res : 0; + res = input0[1] * input1[1]; + output[1] = res > 0 ? res : 0; + res = input0[2] * input1[2]; + output[2] = res > 0 ? res : 0; + res = input0[3] * input1[3]; + output[3] = res > 0 ? res : 0; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + float res = input0[index] * input1[index]; + output[index] = res > 0 ? res : 0; + } + + return NNACL_OK; +} + +int ElementMulRelu6(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6); + output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6); + output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6); + output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6); + } + + return NNACL_OK; +} + +int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementMul(tile_input0, tile_input1, output, element_size); +} + +int ElementAdd(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vaddq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] + input1[0]; + output[1] = input0[1] + input1[1]; + output[2] = input0[2] + input1[2]; + output[3] = input0[3] + input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] + input1[index]; + } + return NNACL_OK; +} + +int ElementAddRelu(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vaddq_f32(vin0, vin1); + vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); + vst1q_f32(output, vout); +#else + float res = input0[0] + input1[0]; + output[0] = res > 0 ? res : 0; + res = input0[1] + input1[1]; + output[1] = res > 0 ? res : 0; + res = input0[2] + input1[2]; + output[2] = res > 0 ? res : 0; + res = input0[3] + input1[3]; + output[3] = res > 0 ? res : 0; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + float res = input0[index] + input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementAddRelu6(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6); + output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6); + output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6); + output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); + } + + return NNACL_OK; +} + +int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] + input1[i]; + } + return NNACL_OK; +} + +int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementAdd(tile_input0, tile_input1, output, element_size); +} + +int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output, + int element_size, ArithmeticParameter *param) { + TileDimensionsInt8(input0, input1, tile_input0, tile_input1, param); + return ElementAddInt8(tile_input0, tile_input1, output, element_size); +} + +int ElementSub(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vsubq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] - input1[0]; + output[1] = input0[1] - input1[1]; + output[2] = input0[2] - input1[2]; + output[3] = input0[3] - input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] - input1[index]; + } + return NNACL_OK; +} + +int ElementSubRelu(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vsubq_f32(vin0, vin1); + vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); + vst1q_f32(output, vout); +#else + float res = input0[0] - input1[0]; + output[0] = res > 0 ? res : 0; + res = input0[1] - input1[1]; + output[1] = res > 0 ? res : 0; + res = input0[2] - input1[2]; + output[2] = res > 0 ? res : 0; + res = input0[3] - input1[3]; + output[3] = res > 0 ? res : 0; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + float res = input0[index] - input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementSubRelu6(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6); + output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6); + output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6); + output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); + } + + return NNACL_OK; +} + +int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementSub(tile_input0, tile_input1, output, element_size); +} + +// todo c=a/b,if(b==0) +int ElementDiv(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] / input1[i]; + } + return NNACL_OK; +} + +int ElementDivRelu(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + float res = input0[i] / input1[i]; + output[i] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementDivRelu6(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = MSMIN(MSMAX(input0[i] / input1[i], 0), 6); + } + return NNACL_OK; +} + +int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementDiv(tile_input0, tile_input1, output, element_size); +} + +int ElementFloorMod(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + return NNACL_OK; +} + +int BroadcastFloorMod(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementFloorMod(tile_input0, tile_input1, output, element_size); +} + +int ElementFloorDiv(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = floorf(input0[i] / input1[i]); + } + return NNACL_OK; +} + +int BroadcastFloorDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementFloorDiv(tile_input0, tile_input1, output, element_size); +} + +int ElementLogicalAnd(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; + uint32x4_t mask = vmovq_n_u32((uint32_t(1u << 31) - 1)); + uint32x4_t zeros = {0, 0, 0, 0}; +#endif + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input0)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input1)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vandq_u32(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f32(output, vout); +#else + output[0] = (float)((bool)(input0[0]) & (bool)(input1[0])); + output[1] = (float)((bool)(input0[1]) & (bool)(input1[1])); + output[2] = (float)((bool)(input0[2]) & (bool)(input1[2])); + output[3] = (float)((bool)(input0[3]) & (bool)(input1[3])); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)((bool)(input0[index]) & (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementSquaredDifference(float *input0, float *input1, float *output, int element_size) { + ElementSub(input0, input1, output, element_size); + return ElementMul(output, output, output, element_size); +} + +int BroadcastSquaredDifference(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + BroadcastSub(input0, input1, tile_input0, tile_input1, output, element_size, param); + return ElementMul(output, output, output, element_size); +} + +int BroadcastLogicalAnd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLogicalAnd(tile_input0, tile_input1, output, element_size); +} + +int ElementLogicalOr(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; + uint32x4_t mask = vmovq_n_u32((uint32_t(1u << 31) - 1)); + uint32x4_t zeros = {0, 0, 0, 0}; +#endif + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input0)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input1)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f32(output, vout); +#else + output[0] = (float)((bool)(input0[0]) | (bool)(input1[0])); + output[1] = (float)((bool)(input0[1]) | (bool)(input1[1])); + output[2] = (float)((bool)(input0[2]) | (bool)(input1[2])); + output[3] = (float)((bool)(input0[3]) | (bool)(input1[3])); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)((bool)(input0[index]) | (bool)(input1[index])); + } + return NNACL_OK; +} + +int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLogicalOr(tile_input0, tile_input1, output, element_size); +} + +int ElementMaximum(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmaxq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] > input1[0] ? input0[0] : input1[0]; + output[1] = input0[1] > input1[1] ? input0[1] : input1[1]; + output[2] = input0[2] > input1[2] ? input0[2] : input1[2]; + output[3] = input0[3] > input1[3] ? input0[3] : input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] > input1[index] ? input0[index] : input1[index]; + } + return NNACL_OK; +} + +int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementMaximum(tile_input0, tile_input1, output, element_size); +} + +int ElementMinimum(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] > input1[0] ? input1[0] : input0[0]; + output[1] = input0[1] > input1[1] ? input1[1] : input0[1]; + output[2] = input0[2] > input1[2] ? input1[2] : input0[2]; + output[3] = input0[3] > input1[3] ? input1[3] : input0[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; + } + return NNACL_OK; +} + +int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementMinimum(tile_input0, tile_input1, output, element_size); +} + +int ElementNotEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] != input1[0]); + output[1] = (float)(input0[1] != input1[1]); + output[2] = (float)(input0[2] != input1[2]); + output[3] = (float)(input0[3] != input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] != input1[index]); + } + return NNACL_OK; +} + +int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementNotEqual(tile_input0, tile_input1, output, element_size); +} + +int ElementEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] == input1[0]); + output[1] = (float)(input0[1] == input1[1]); + output[2] = (float)(input0[2] == input1[2]); + output[3] = (float)(input0[3] == input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] == input1[index]); + } + return NNACL_OK; +} + +int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementEqual(tile_input0, tile_input1, output, element_size); +} + +int ElementLess(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcltq_f32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] < input1[0]); + output[1] = (float)(input0[1] < input1[1]); + output[2] = (float)(input0[2] < input1[2]); + output[3] = (float)(input0[3] < input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] < input1[index]); + } + return NNACL_OK; +} + +int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLess(tile_input0, tile_input1, output, element_size); +} + +int ElementLessEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcleq_f32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] <= input1[0]); + output[1] = (float)(input0[1] <= input1[1]); + output[2] = (float)(input0[2] <= input1[2]); + output[3] = (float)(input0[3] <= input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] <= input1[index]); + } + return NNACL_OK; +} + +int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLessEqual(tile_input0, tile_input1, output, element_size); +} + +int ElementGreater(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcgtq_f32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] > input1[0]); + output[1] = (float)(input0[1] > input1[1]); + output[2] = (float)(input0[2] > input1[2]); + output[3] = (float)(input0[3] > input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] > input1[index]); + } + return NNACL_OK; +} + +int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementGreater(tile_input0, tile_input1, output, element_size); +} + +int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcgeq_f32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] >= input1[0]); + output[1] = (float)(input0[1] >= input1[1]); + output[2] = (float)(input0[2] >= input1[2]); + output[3] = (float)(input0[3] >= input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] >= input1[index]); + } + return NNACL_OK; +} + +int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementGreaterEqual(tile_input0, tile_input1, output, element_size); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h new file mode 100644 index 0000000000..81f388800a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int ElementMul(float *input0, float *input1, float *output, int element_size); +int ElementMulRelu(float *input0, float *input1, float *output, int element_size); +int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); +int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementAdd(float *input0, float *input1, float *output, int element_size); +int ElementAddRelu(float *input0, float *input1, float *output, int element_size); +int ElementAddRelu6(float *input0, float *input1, float *output, int element_size); +int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); +int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output, + int element_size, ArithmeticParameter *param); + +int ElementSub(float *input0, float *input1, float *output, int element_size); +int ElementSubRelu(float *input0, float *input1, float *output, int element_size); +int ElementSubRelu6(float *input0, float *input1, float *output, int element_size); +int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementDiv(float *input0, float *input1, float *output, int element_size); +int ElementDivRelu(float *input0, float *input1, float *output, int element_size); +int ElementDivRelu6(float *input0, float *input1, float *output, int element_size); +int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementLogicalAnd(float *input0, float *input1, float *output, int element_size); +int BroadcastLogicalAnd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementLogicalOr(float *input0, float *input1, float *output, int element_size); +int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementMaximum(float *input0, float *input1, float *output, int element_size); +int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementMinimum(float *input0, float *input1, float *output, int element_size); +int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementFloorDiv(float *input0, float *input1, float *output, int element_size); +int BroadcastFloorDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementFloorMod(float *input0, float *input1, float *output, int element_size); +int BroadcastFloorMod(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementSquaredDifference(float *input0, float *input1, float *output, int element_size); +int BroadcastSquaredDifference(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementNotEqual(float *input0, float *input1, float *output, int element_size); + +int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementEqual(float *input0, float *input1, float *output, int element_size); + +int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementLess(float *input0, float *input1, float *output, int element_size); +int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementLessEqual(float *input0, float *input1, float *output, int element_size); +int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementGreater(float *input0, float *input1, float *output, int element_size); +int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size); +int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.cc new file mode 100644 index 0000000000..f13dc824f8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" + +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -nom[i] / (denom[i] * denom[i]); + } +} + +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -a[i] * b[i] / (denom[i] * denom[i]); + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h new file mode 100644 index 0000000000..9994b4d66d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ + +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.cc new file mode 100644 index 0000000000..6bb22ad021 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.h" + +// abs: +int ElementAbs(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return NNACL_OK; +} + +// cos: +int ElementCos(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return NNACL_OK; +} + +// exp: +int ElementExp(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = expf(input[i]); + } + return NNACL_OK; +} + +// log: +int ElementLog(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return NNACL_OK; +} + +// Square +int ElementSquare(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return NNACL_OK; +} + +// Sqrt +int ElementSqrt(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return NNACL_OK; +} + +// rsqrt +int ElementRsqrt(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + } + output[i] = 1.f / sqrtf(input[i]); + } + return NNACL_OK; +} + +// sin: +int ElementSin(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return NNACL_OK; +} + +// logical_not: +int ElementLogicalNot(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return NNACL_OK; +} + +// round: +int ElementRound(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = round(input[i]); + } + return NNACL_OK; +} + +// floor: +int ElementFloor(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeil(float *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ceil(input[i]); + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.h new file mode 100644 index 0000000000..62ad6d3cf9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_self.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_SELF_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_SELF_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int ElementAbs(float *input, float *output, int element_size); + +int ElementCos(float *input, float *output, int element_size); + +int ElementExp(float *input, float *output, int element_size); + +int ElementLog(float *input, float *output, int element_size); + +int ElementSquare(float *input, float *output, int element_size); + +int ElementSqrt(float *input, float *output, int element_size); + +int ElementRsqrt(float *input, float *output, int element_size); + +int ElementSin(float *input, float *output, int element_size); + +int ElementLogicalNot(float *input, float *output, int element_size); + +int ElementRound(float *input, float *output, int element_size); + +int ElementFloor(float *input, float *output, int element_size); + +int ElementCeil(float *input, float *output, int number); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_SELF_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.cc new file mode 100644 index 0000000000..e61d5ea596 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/batchnorm.h" + +void BatchNorm(const float *input_ptr, const float *mean_ptr, const float *variance_ptr, int units, int channel, + float epsilon, float *output_ptr) { + for (int u = 0; u < units; u++) { + for (int c = 0; c < channel; c++) { + auto variance_sqrt = sqrt(variance_ptr[c] + epsilon); + output_ptr[u * channel + c] = (input_ptr[u * channel + c] - mean_ptr[c]) / variance_sqrt; + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.h new file mode 100644 index 0000000000..135f7a73e0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCHNORM_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct BatchNormParameter { + OpParameter op_parameter_; + float epsilon_; +}; + +void BatchNorm(const float *input_ptr, const float *mean_ptr, const float *variance_ptr, int count, int channel, + float epsilon, float *output_ptr); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FUSED_BATCHNORM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/broadcast_to.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/broadcast_to.cc new file mode 100644 index 0000000000..b16bb3a9c6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/broadcast_to.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void PadBroadcastShapeInfo(BroadcastShapeInfo *shape_info) { + if (shape_info->input_shape_size_ < DIMENSION_4D) { + int input_shape_tmp[DIMENSION_4D]; + for (int i = 0; i < shape_info->input_shape_size_; ++i) { + input_shape_tmp[i] = shape_info->input_shape_[i]; + } + int input_shape_index = shape_info->input_shape_size_ - 1; + for (int i = DIMENSION_4D - 1; i >= 0; --i) { + if (input_shape_index >= 0) { + shape_info->input_shape_[i] = input_shape_tmp[input_shape_index--]; + } else { + shape_info->input_shape_[i] = 1; + } + } + } + if (shape_info->output_shape_size_ < DIMENSION_4D) { + int output_shape_tmp[DIMENSION_4D]; + for (int i = 0; i < shape_info->output_shape_size_; ++i) { + output_shape_tmp[i] = shape_info->output_shape_[i]; + } + int output_shape_index = shape_info->output_shape_size_ - 1; + for (int i = DIMENSION_4D - 1; i >= 0; --i) { + if (output_shape_index >= 0) { + shape_info->output_shape_[i] = output_shape_tmp[output_shape_index--]; + } else { + shape_info->output_shape_[i] = 1; + } + } + } +} + +int BroadcastTo(const float *input, BroadcastShapeInfo *shape_info, float *output) { + if (shape_info->input_shape_size_ > DIMENSION_4D || shape_info->output_shape_size_ > DIMENSION_4D) { + return -1; + } + PadBroadcastShapeInfo(shape_info); + size_t input_dim_offset[DIMENSION_4D - 1]; + input_dim_offset[2] = shape_info->input_shape_[3] * 4; + input_dim_offset[1] = input_dim_offset[2] * shape_info->input_shape_[2]; + input_dim_offset[0] = input_dim_offset[1] * shape_info->input_shape_[1]; + size_t output_dim_offset[DIMENSION_4D - 1]; + output_dim_offset[2] = shape_info->output_shape_[3] * 4; + output_dim_offset[1] = output_dim_offset[2] * shape_info->output_shape_[2]; + output_dim_offset[0] = output_dim_offset[1] * shape_info->output_shape_[1]; + uint8_t *in_base = (uint8_t *)input; + uint8_t *out_base = (uint8_t *)(output); + for (int32_t dim0 = 0; dim0 < shape_info->input_shape_[0]; ++dim0) { + for (int32_t dim1 = 0; dim1 < shape_info->input_shape_[1]; ++dim1) { + for (int32_t dim2 = 0; dim2 < shape_info->input_shape_[2]; ++dim2) { + if (shape_info->input_shape_[3] == shape_info->output_shape_[3]) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1 + + output_dim_offset[2] * dim2, + in_base + input_dim_offset[0] * dim0 + input_dim_offset[1] * dim1 + + input_dim_offset[2] * dim2, input_dim_offset[2]); + } else { + for (int32_t dim3 = 0; dim3 < shape_info->output_shape_[3]; ++dim3) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1 + + output_dim_offset[2] * dim2 + dim3 * 4, + in_base + input_dim_offset[0] * dim0 + input_dim_offset[1] * dim1 + + input_dim_offset[2] * dim2, 4); + } + } + } + if (shape_info->input_shape_[2] != shape_info->output_shape_[2]) { + for (int32_t dim2 = 0; dim2 < shape_info->output_shape_[2]; ++dim2) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1 + + dim2 * output_dim_offset[2], + out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1, + output_dim_offset[2]); + } + } + } + if (shape_info->input_shape_[1] != shape_info->output_shape_[1]) { + for (int32_t dim1 = 0; dim1 < shape_info->output_shape_[1]; ++dim1) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1, + out_base + output_dim_offset[0] * dim0, output_dim_offset[1]); + } + } + } + if (shape_info->input_shape_[0] != shape_info->output_shape_[0]) { + for (int32_t dim0 = 0; dim0 < shape_info->output_shape_[0]; ++dim0) { + memcpy(out_base + output_dim_offset[0] * dim0, out_base, output_dim_offset[0]); + } + } + return 0; +} + + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h new file mode 100644 index 0000000000..281cfaaa0d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BROADCAST_TO_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BROADCAST_TO_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +#define BROADCAST_TO_SHAPE_MAX_SIZE 4 + +struct BroadcastToParameter { + OpParameter op_parameter_; + int shape_[BROADCAST_TO_SHAPE_MAX_SIZE]; + size_t shape_size_; +}; + +struct BroadcastShapeInfo { + int input_shape_[BROADCAST_TO_SHAPE_MAX_SIZE]; + int input_shape_size_; + int output_shape_[BROADCAST_TO_SHAPE_MAX_SIZE]; + int output_shape_size_; +}; + +int BroadcastTo(const float *input, BroadcastShapeInfo *shape_info, float *output); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BROADCAST_TO_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc new file mode 100644 index 0000000000..4582e31474 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/cast.h" + +void Uint8ToFloat32(const uint8_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Uint8ToInt8(const uint8_t *input, int8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int8_t)(input[i] - 128); + } +} + +void Int8ToUint8(const int8_t *input, uint8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (uint8_t)(input[i] + 128); + } +} + +void Int32ToFloat32(const int32_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +#ifdef ENABLE_FP16 +void Float32ToFloat16(const float *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Float16ToFloat32(const float16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h new file mode 100644 index 0000000000..616e0b89fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +// For cast. +struct CastParameter { + OpParameter op_parameter_; + int src_type_; + int dst_type_; +}; + +void Uint8ToFloat32(const uint8_t *input, float *output, int number); +void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); +void Int8ToUint8(const int8_t *input, uint8_t *output, int number); +void Int32ToFloat32(const int32_t *input, float *output, int number); +#ifdef ENABLE_FP16 +void Float32ToFloat16(const float *input, float16_t *output, int number); +void Float16ToFloat32(const float16_t *input, float *output, int number); +#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.cc new file mode 100644 index 0000000000..f50036bf3c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/common_func.h" + +#ifndef __aarch64__ +void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int a_index = c * a_stride + r * C4NUM; + int b_index = c * b_stride + r * C4NUM; + int c_index = c * c_stride + r * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst[c_index + i] = a_ptr[a_index + i] + b_ptr[b_index + i]; + } + } + } + return; +} + +void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int a_index = c * a_stride + r * C4NUM; + int b_index = c * b_stride + r * C4NUM; + int c_index = c * c_stride + r * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst[c_index + i] = a_ptr[a_index + i] - b_ptr[b_index + i]; + } + } + } + return; +} +#endif + +void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, + size_t c_stride, size_t x_stride) { + /* U2 = P1 + P6 */ + MatrixAdd(x_ptr, c12, c12, x_stride, c_stride, c_stride, row, col); + /* U3 = U2 + P7 */ + MatrixAdd(c12, c21, c21, c_stride, c_stride, c_stride, row, col); + /* U4 = U2 + P5 */ + MatrixAdd(c12, c22, c12, c_stride, c_stride, c_stride, row, col); + /* U7 = U3 + P5 */ + MatrixAdd(c21, c22, c22, c_stride, c_stride, c_stride, row, col); + /* U5 = U4 + P3 */ + MatrixAdd(c12, c11, c12, c_stride, c_stride, c_stride, row, col); + return; +} + +void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6, int size) { + for (int oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size, oc_mod = oc % size; + for (int hw = 0; hw < plane_size; hw++) { + int src_index = oc_div * size * plane_size + hw * size + oc_mod; + int dst_index = hw * stride + oc; + float value = src_ptr_[src_index]; + if (bias_ptr != nullptr) { + value = value + bias_ptr[oc]; + } + value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value); + value = (is_relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } + return; +} + +void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6) { +#ifndef ENABLE_ARM64 + PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C4NUM); +#else + if (bias_ptr != nullptr) { + if (is_relu) { + C4BiasAddRelu(out_ptr, c4_out_ptr, bias_ptr, output_channel, plane_size, stride * sizeof(float)); + } else if (is_relu6) { + C4BiasAddRelu6(out_ptr, c4_out_ptr, bias_ptr, output_channel, plane_size, stride * sizeof(float)); + } else { + C4BiasAdd(out_ptr, c4_out_ptr, bias_ptr, output_channel, plane_size, stride * sizeof(float)); + } + } else { + if (is_relu) { + C4Relu(out_ptr, c4_out_ptr, output_channel, plane_size, stride * sizeof(float)); + } else if (is_relu6) { + C4Relu6(out_ptr, c4_out_ptr, output_channel, plane_size, stride * sizeof(float)); + } else { + // do nothing + } + } +#endif + return; +} + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6) { + PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM); + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h new file mode 100644 index 0000000000..3802a82cf2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_COMMON_FUNC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_COMMON_FUNC_H_ + +#include +#include +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6); +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6); +void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col); +void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col); +void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, + size_t c_stride, size_t x_stride); + +#ifdef ENABLE_ARM +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +#endif + +#ifdef ENABLE_ARM64 +void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size); +void Relu6(float *data, size_t element4); +void Relu(float *data, size_t element4); +void C4BiasAdd(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride); +void C4BiasAddRelu(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride); +void C4BiasAddRelu6(float *dst, const float *input, const float* bias, size_t oc, size_t plane_size, size_t stride); +void C4Relu(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride); +void C4Relu6(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/concat.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/concat.cc new file mode 100644 index 0000000000..8c50292552 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/concat.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/concat.h" +#include + +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) { + int before_axis_size = 1; + for (int i = 0; i < axis; ++i) { + before_axis_size *= inputs_output_shape[0][i]; + } + // sizeof float/int32 + int after_axis_size = 4; + for (size_t i = axis + 1; i < shape_size; ++i) { + after_axis_size *= inputs_output_shape[0][i]; + } + int axis_offset = 0; + uint8_t *dst_base = reinterpret_cast(output); + size_t output_stride = after_axis_size * inputs_output_shape[input_num][axis]; + for (int i = 0; i < input_num; ++i) { + uint8_t *src_base = reinterpret_cast(input[i]); + size_t input_stride = after_axis_size * inputs_output_shape[i][axis]; + for (int j = 0; j < before_axis_size; ++j) { + uint8_t *src = src_base + j * input_stride; + uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size; + memcpy(dst, src, input_stride); + } + axis_offset += inputs_output_shape[i][axis]; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/concat.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/concat.h new file mode 100644 index 0000000000..a692fdc1cb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/concat.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONCAT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONCAT_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONCAT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc new file mode 100644 index 0000000000..624ffd7565 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc @@ -0,0 +1,228 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" +#include +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" + +// fp32 conv common +void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data, + float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param, + GEMM_FUNC_FP32 gemm_func) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int thread_count = conv_param->thread_num_; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * TILE_NUM * unit_size; + + // we accumulate 4 channels per time for input blocks + int conv_depth = kernel_h * kernel_w; + // bytes from one output's i-th channel to the next output's i-th channel + // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward + size_t output_offset = out_channel * sizeof(float); + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + float *gemm_input = packed_input + thread_id * unit_size * TILE_NUM + gemm_in_batch_offset; + Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + + int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset; + if (real_cal_num == TILE_NUM) { + float *gemm_output = output_data + out_offset; + gemm_func(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0, + conv_param->is_relu_, conv_param->is_relu6_); + } else { + // res part + gemm_func(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, + 0, conv_param->is_relu_, conv_param->is_relu6_); + memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float)); + } + } + } +} + +// fp32 conv1x1 strassen matmul +int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr, + StrassenMatMulParameter matmul_param) { + return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr); +} + +// fp32 conv winograd +void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func, + GEMM_FUNC_FP32 gemm_func) { + int thread_num = conv_param->thread_num_; + int input_unit = conv_param->input_unit_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_unit = conv_param->output_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, out_unit); + int out_h_block = UP_DIV(conv_param->output_h_, out_unit); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int out_channel = conv_param->output_channel_; + int out_batch = conv_param->output_batch_; + int oc4 = UP_DIV(out_channel, C4NUM); + int input_unit_square = input_unit * input_unit; + size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float); + bool is_relu = conv_param->is_relu_; + bool is_relu6 = conv_param->is_relu6_; + + float *trans_input = buffer_list[0]; + float *gemm_out = buffer_list[1]; + float *tmp_out_data = buffer_list[2]; + float *tmp_data = buffer_list[3]; + int trans_input_offset = TILE_NUM * input_unit_square * ic4 * C4NUM; + int gemm_out_offset = TILE_NUM * input_unit_square * oc4 * C4NUM; + int tmp_data_offset = input_unit_square * C4NUM; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + for (int b = 0; b < in_batch; b++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { + int out_tile_index = thread_id * TILE_NUM; + int cal_num = output_count - thread_id * TILE_NUM; + cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num; + WinogradInputTransform(input_data, trans_input + task_id * trans_input_offset, + tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, + input_trans_func); + // step 3 : gemm + gemm_func(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, nullptr, + input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0); + + // step 4 : output transform + WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, output_trans_func); + } + } + // get real output + UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel, + out_unit); + int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; + if (is_relu) { + ReluFp32(output_data, output_num); + } else if (is_relu6) { + Relu6Fp32(output_data, output_num); + } else { + // do nothing + } +} + +void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, + int output_unit) { + int out_h_block_num = UP_DIV(height, output_unit); + int out_w_block_num = UP_DIV(width, output_unit); + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_batch_offset = b * c4 * C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit; + int dst_batch_offset = b * height * width * channel; + for (int h = 0; h < height; h++) { + int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit); + int dst_h_offset = dst_batch_offset + h * width * channel; + for (int w = 0; w < width; w++) { + int src_w_offset = src_h_offset + w * C4NUM; + int dst_w_offset = dst_h_offset + w * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c4_offset = src_w_offset + c * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit; + int dst_c4_offset = dst_w_offset + c * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32(dst + dst_c4_offset, vld1q_f32(src + src_c4_offset)); +#else + dst[dst_c4_offset] = src[src_c4_offset]; + dst[dst_c4_offset + 1] = src[src_c4_offset + 1]; + dst[dst_c4_offset + 2] = src[src_c4_offset + 2]; + dst[dst_c4_offset + 3] = src[src_c4_offset + 3]; +#endif + } + int c_res = channel - (c4 - 1) * C4NUM; + int src_c_res_offset = (c4 - 1) * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit; + int dst_c_res_offset = (c4 - 1) * C4NUM; + for (int c = 0; c < c_res; c++) { + int src_c4_res_offset = src_w_offset + src_c_res_offset + c; + int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c; + dst[dst_c4_res_offset] = src[src_c4_res_offset]; + } + } + } + } +} + +// fp32 conv3x3 +void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) { + int thread_count = conv_param->thread_num_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int output_channel = conv_param->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int input_unit_square = 4 * 4; + bool is_relu = conv_param->is_relu_; + bool is_relu6 = conv_param->is_relu6_; + float *tile_buffer = buffer_list[0]; + float *block_unit_buffer = buffer_list[1]; + float *tmp_dst_buffer = buffer_list[2]; + float *nc4hw4_out = buffer_list[3]; + int tile_buffer_offset = TILE_NUM * input_unit_square * ic4 * C4NUM; + int block_unit_buffer_offset = input_unit_square * C4NUM; + int tmp_dst_buffer_offset = TILE_NUM * input_unit_square * oc4 * C4NUM; + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + Conv3x3Fp32InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); + + gemm_func(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, + transed_weight, nullptr, input_unit_square, ic4, oc4 * C4NUM, + oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0); + + Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out, bias_data, start_index, + real_cal_num, out_w_block, conv_param); + } + PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel); + } + int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; + if (is_relu) { + ReluFp32(output_data, output_num); + } else if (is_relu6) { + Relu6Fp32(output_data, output_num); + } else { + // do nothing + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.h new file mode 100644 index 0000000000..97a2b08e70 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h" +#include "src/runtime/kernel/arm/nnacl/winograd_utils.h" + +using TmpBufferAddress = float *; +typedef void (*GEMM_FUNC_FP32)(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, + size_t relu, size_t relu6); + +// fp32 convolution common (im2col+gemm) +void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data, + float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param, + GEMM_FUNC_FP32 gemm_func); + +// fp32 conv1x1 strassen matmul +int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr, + StrassenMatMulParameter matmul_param); + +// fp32 convolution winograd +void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func, + GEMM_FUNC_FP32 gemm_func); + +void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit); + +// fp32 conv3x3 +void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc new file mode 100644 index 0000000000..0b8123c2e1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc @@ -0,0 +1,728 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" +#include "src/runtime/kernel/arm/nnacl/fp32/common_func.h" +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" +#ifdef ENABLE_ARM64 +#include +#endif + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + int left = 0; + int right = conv_param->output_w_; + int top = 0; + int bottom = conv_param->output_h_; + + for (; left * conv_param->stride_w_ < conv_param->pad_w_; left++) { + } + for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_w_ + conv_param->kernel_w_ * conv_param->dilation_w_ > + conv_param->input_w_ && + right > left; + right--) { + } + for (; top * conv_param->stride_h_ < conv_param->pad_h_; top++) { + } + for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_h_ + conv_param->kernel_h_ * conv_param->dilation_h_ > + conv_param->input_h_ && + bottom > top; + bottom--) { + } + sliding->left_ = left; + sliding->right_ = right; + sliding->top_ = top; + sliding->bottom_ = bottom; + sliding->c_block_ = UP_DIV(conv_param->output_channel_, block); + sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block; + + sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_; + sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_; + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_h_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W + sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block; +} + +/*conv depthwise fp32 begin*/ +void DepthwiseBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w, bool is_relu, bool is_relu6) { + const float *src_kh = src; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst[c] = 0; + } + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src_kw); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void DepthwiseBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + float *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sliding->in_h_step_; + + float *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sliding->block_channel_; + + const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + + DepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, conv_param->is_relu_, + conv_param->is_relu6_); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float *src_kh = src_w; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + // add biad relu + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp32: sliding window +void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const float *src = input_data; + float *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + DepthwiseBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + DepthwiseBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_); +#else + DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, + conv_param->is_relu_, conv_param->is_relu6_); +#endif + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*conv depthwise fp32 end*/ + +/*conv depthwise 3x3 fp32 begin*/ +void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) { + for (int c = 0; c < oc4; c++) { + float *src = weight + c * C4NUM * 9; + float *dst = trans_weight + c * C4NUM * 16; +#ifdef ENABLE_ARM + float32x4_t g00 = vld1q_f32(src); + float32x4_t g01 = vld1q_f32(src + 4); + float32x4_t g02 = vld1q_f32(src + 2 * 4); + float32x4_t g10 = vld1q_f32(src + 3 * 4); + float32x4_t g11 = vld1q_f32(src + 4 * 4); + float32x4_t g12 = vld1q_f32(src + 5 * 4); + float32x4_t g20 = vld1q_f32(src + 6 * 4); + float32x4_t g21 = vld1q_f32(src + 7 * 4); + float32x4_t g22 = vld1q_f32(src + 8 * 4); + + float32x4_t dst00 = g00; + float32x4_t dst01 = g01; + float32x4_t dst02 = g02; + + float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5)); + float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5)); + float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5)); + float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5)); + float32x4_t dst22 = vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst30 = g20; + float32x4_t dst31 = g21; + float32x4_t dst32 = g22; + + float32x4_t m00 = dst00; + float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5)); + float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5)); + float32x4_t m03 = dst02; + + float32x4_t m10 = dst10; + float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5)); + float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5)); + float32x4_t m13 = dst12; + + float32x4_t m20 = dst20; + float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5)); + float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5)); + float32x4_t m23 = dst22; + + float32x4_t m30 = dst30; + float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5)); + float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5)); + float32x4_t m33 = dst32; + + vst1q_f32(dst, m00); + vst1q_f32(dst + 4, m01); + vst1q_f32(dst + 8, m02); + vst1q_f32(dst + 12, m03); + vst1q_f32(dst + 16, m10); + vst1q_f32(dst + 20, m11); + vst1q_f32(dst + 24, m12); + vst1q_f32(dst + 28, m13); + vst1q_f32(dst + 32, m20); + vst1q_f32(dst + 36, m21); + vst1q_f32(dst + 40, m22); + vst1q_f32(dst + 44, m23); + vst1q_f32(dst + 48, m30); + vst1q_f32(dst + 52, m31); + vst1q_f32(dst + 56, m32); + vst1q_f32(dst + 60, m33); +#else + for (int j = 0; j < C4NUM; j++) { + float *local_ptr = src + j; + float dst00 = local_ptr[0]; + float dst01 = (local_ptr + 4)[0]; + float dst02 = (local_ptr + 8)[0]; + + float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst30 = (local_ptr + 24)[0]; + float dst31 = (local_ptr + 28)[0]; + float dst32 = (local_ptr + 32)[0]; + + float m00 = dst00; + float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02; + float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02; + float m03 = dst02; + + float m10 = dst10; + float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12; + float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12; + float m13 = dst12; + + float m20 = dst20; + float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22; + float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22; + float m23 = dst22; + + float m30 = dst30; + float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32; + float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32; + float m33 = dst32; + + *(dst + j) = m00; + *(dst + j + 4) = m01; + *(dst + j + 8) = m02; + *(dst + j + 12) = m03; + + *(dst + j + 16) = m10; + *(dst + j + 20) = m11; + *(dst + j + 24) = m12; + *(dst + j + 28) = m13; + + *(dst + j + 32) = m20; + *(dst + j + 36) = m21; + *(dst + j + 40) = m22; + *(dst + j + 44) = m23; + + *(dst + j + 48) = m30; + *(dst + j + 52) = m31; + *(dst + j + 56) = m32; + *(dst + j + 60) = m33; + } +#endif + } +} + +void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float *block_buffer, int out_h_block, + int out_w_block, const ConvParameter *conv_param) { + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int input_unit = 4; + memset(trans_input, 0, out_h_block * out_h_block * 16 * C4NUM * sizeof(float)); + + for (int oh = 0; oh < out_h_block; oh++) { + int ih = oh * 2 - conv_param->pad_h_; + int real_h_start = ih > 0 ? 0 : -ih; + int real_h_end = (ih + input_unit) < conv_param->input_h_ ? input_unit : (conv_param->input_h_ - ih); + for (int ow = 0; ow < out_w_block; ow++) { + int iw = ow * 2 - conv_param->pad_w_; + int real_w_start = iw > 0 ? 0 : -iw; + int real_w_end = (iw + input_unit) < conv_param->input_w_ ? input_unit : (conv_param->input_w_ - iw); + + memset(block_buffer, 0, 16 * C4NUM * sizeof(float)); + int src_plane_offset = ic4 * C4NUM * (ih * conv_param->input_w_ + iw); + for (int h = real_h_start; h < real_h_end; h++) { + int src_h_offset = src_plane_offset + (h * conv_param->input_w_) * ic4 * C4NUM; + int dst_h_offset = (h * input_unit) * C4NUM; + for (int w = real_w_start; w < real_w_end; w++) { + int src_w_offset = src_h_offset + w * ic4 * C4NUM; + int dst_w_offset = dst_h_offset + w * C4NUM; + float *src_addr = (float *)(input_data) + src_w_offset; + float *dst_addr = block_buffer + dst_w_offset; +#ifdef ENABLE_NEON + vst1q_f32(dst_addr, vld1q_f32(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + (dst_addr + k)[0] = (src_addr + k)[0]; + } +#endif + } + } + int trans_offset = (oh * out_w_block + ow) * 16 * C4NUM; + Conv3x3Fp32InputUnit(block_buffer, trans_input + trans_offset, C4NUM); + } + } +} + +// todo yangruoqi: implement assembly +void ConvDw3x3Fp32Winograd(float *trans_buffer, const float *weight, int out_h_block, int out_w_block) { + int unit = 4; + for (int oh = 0; oh < out_h_block; oh++) { + float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM; + for (int ow = 0; ow < out_w_block; ow++) { + float *buf_ow = buf_oh + ow * 16 * C4NUM; + for (int kh = 0; kh < unit; kh++) { + float *buf_kh = buf_ow + kh * unit * C4NUM; + const float *weight_kh = weight + kh * unit * C4NUM; + for (int kw = 0; kw < unit; kw++) { + float *buf_kw = buf_kh + kw * C4NUM; + const float *weight_kw = weight_kh + kw * C4NUM; + for (int c = 0; c < C4NUM; c++) { + buf_kw[c] = buf_kw[c] * weight_kw[c]; + } + } + } + } + } +} + +void ConvDw3x3Fp32OutputUnit(float *src_buf, float *dst_output, const float *bias, int channel, int output_w, + bool h_in_range, bool w_in_range, bool is_relu, bool is_relu6) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias); + + float32x4_t s00 = vld1q_f32(src_buf); + float32x4_t s01 = vld1q_f32(src_buf + 4); + float32x4_t s02 = vld1q_f32(src_buf + 8); + float32x4_t s03 = vld1q_f32(src_buf + 12); + + float32x4_t s10 = vld1q_f32(src_buf + 16); + float32x4_t s11 = vld1q_f32(src_buf + 20); + float32x4_t s12 = vld1q_f32(src_buf + 24); + float32x4_t s13 = vld1q_f32(src_buf + 28); + + float32x4_t s20 = vld1q_f32(src_buf + 32); + float32x4_t s21 = vld1q_f32(src_buf + 36); + float32x4_t s22 = vld1q_f32(src_buf + 40); + float32x4_t s23 = vld1q_f32(src_buf + 44); + + float32x4_t s30 = vld1q_f32(src_buf + 48); + float32x4_t s31 = vld1q_f32(src_buf + 52); + float32x4_t s32 = vld1q_f32(src_buf + 56); + float32x4_t s33 = vld1q_f32(src_buf + 60); + + float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20); + float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21); + float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22); + float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23); + + float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30); + float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31); + float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32); + float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33); + + float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr); + float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr); + float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr); + float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr); + + vst1q_f32(dst_output, d00); + if (w_in_range) { + vst1q_f32(dst_output + channel, d01); + } + if (h_in_range) { + vst1q_f32(dst_output + output_w * channel, d10); + if (w_in_range) { + vst1q_f32(dst_output + output_w * channel + channel, d11); + } + } +#else + for (int i = 0; i < C4NUM; i++) { + const float *local_ptr = src_buf + i; + const float *bias_ptr = bias + i; + + float s00 = local_ptr[0]; + float s01 = (local_ptr + 4)[0]; + float s02 = (local_ptr + 8)[0]; + float s03 = (local_ptr + 12)[0]; + + float s10 = (local_ptr + 16)[0]; + float s11 = (local_ptr + 20)[0]; + float s12 = (local_ptr + 24)[0]; + float s13 = (local_ptr + 28)[0]; + + float s20 = (local_ptr + 32)[0]; + float s21 = (local_ptr + 36)[0]; + float s22 = (local_ptr + 40)[0]; + float s23 = (local_ptr + 44)[0]; + + float s30 = (local_ptr + 48)[0]; + float s31 = (local_ptr + 52)[0]; + float s32 = (local_ptr + 56)[0]; + float s33 = (local_ptr + 60)[0]; + + float t00 = s00 + s10 + s20; + float t01 = s01 + s11 + s21; + float t02 = s02 + s12 + s22; + float t03 = s03 + s13 + s23; + + float t10 = s10 - s20 - s30; + float t11 = s11 - s21 - s31; + float t12 = s12 - s22 - s32; + float t13 = s13 - s23 - s33; + + float d00 = t00 + t01 + t02 + bias_ptr[0]; + float d01 = t01 - t02 - t03 + bias_ptr[0]; + float d10 = t10 + t11 + t12 + bias_ptr[0]; + float d11 = t11 - t12 - t13 + bias_ptr[0]; + + (dst_output + i)[0] = d00; + if (w_in_range) { + (dst_output + i + channel)[0] = d01; + } + if (h_in_range) { + (dst_output + i + output_w * channel)[0] = d10; + if (w_in_range) { + (dst_output + i + output_w * channel + channel)[0] = d11; + } + } + } +#endif +} + +void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const float *bias, int out_h_block, + int out_w_block, const ConvParameter *conv_param) { + int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); + bool h_in_range = true; + for (int oh = 0; oh < out_h_block; oh++) { + int real_oh = 2 * oh; + if ((oh + 1) * 2 > conv_param->output_h_) { + h_in_range = false; + } + bool w_in_range = true; + float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM; + float *output_oh = output_data + real_oh * conv_param->output_w_ * oc4 * C4NUM; + + for (int ow = 0; ow < out_w_block; ow++) { + int real_ow = 2 * ow; + if ((ow + 1) * 2 > conv_param->output_w_) { + w_in_range = false; + } + float *buf_ow = buf_oh + ow * 16 * C4NUM; + float *output_ow = output_oh + real_ow * oc4 * C4NUM; + + ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range, + conv_param->is_relu_, conv_param->is_relu6_); + } + } +} + +void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id) { + int thread_count = conv_param->thread_num_; + int output_channel = conv_param->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int out_h_block = UP_DIV(conv_param->output_h_, 2); + int out_w_block = UP_DIV(conv_param->output_w_, 2); + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + const float *input = input_data + batch * conv_param->input_h_ * conv_param->input_w_ * + UP_DIV(conv_param->input_channel_, C4NUM) * C4NUM; + float *output = output_data + batch * conv_param->output_h_ * conv_param->output_w_ * + UP_DIV(conv_param->output_channel_, C4NUM) * C4NUM; + for (int oc = task_id; oc < oc4; oc += thread_count) { + const float *weight = weight_data + oc * 16 * C4NUM; + const float *bias = bias_data + oc * C4NUM; + + ConvDw3x3Fp32InputTrans(input + oc * C4NUM, trans_buffer, block_buffer, out_h_block, out_w_block, conv_param); + + ConvDw3x3Fp32Winograd(trans_buffer, weight, out_h_block, out_w_block); + + ConvDw3x3Fp32OutputTrans(trans_buffer, output + oc * C4NUM, bias, out_h_block, out_w_block, conv_param); + } + } +} +/*conv depthwise 3x3 fp32 end*/ + +/*deconv depthwise fp32 begin*/ +void DeconvDepthwiseBorderPixel(float *dst, const float *src, const float *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + float *dst_kh = dst; + const float *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst_kw); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst_kw, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right, + const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const float *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float *dst_h = dst + oh * sliding->in_h_step_; + + const float *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float *dst_w = dst_h + ow * sliding->block_channel_; + + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) { + float *dst_k = dst; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + dst_k[c] += bias[c]; + dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]); + dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]); + } + dst_k += block_channel; + } +} + +// deconv depthwise fp32: sliding window +void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const float *src = input_data; + float *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + DeconvDepthwiseBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_, + conv_param, sliding); + DeconvDepthwiseBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + DeconvDepthwiseBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + +#ifdef ENABLE_ARM64 + DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float)); +#else + DeconvDepthwiseCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, + sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, + sliding->in_kw_step_); +#endif + } + DeconvDepthwisePostFunc(dst_data, bias, sliding->block_channel_, conv_param); + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise fp32 end*/ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h new file mode 100644 index 0000000000..e83b6b6dcf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_DEPTHWISE_H_ + +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +struct SlidingWindowParam { + int left_; + int right_; + int top_; + int bottom_; + int c_block_; + int block_channel_; + int out_step_; + int out_h_step_; + int in_step_; + int in_h_step_; + int in_sh_step_; // stride H + int in_sw_step_; // stride W + int in_kh_step_; // kernel H + int in_kw_step_; // kernel W + int kernel_step_; +}; + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4); + +void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id); + +void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CONV_DEPTHWISE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.cc new file mode 100644 index 0000000000..4336b10af8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/fp32/crop.h" +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/crop_parameter.h" + +void Pad4DOffset(CropParameter *crop_param, int64_t *offset) { + int axis = crop_param->axis_; + for (int i = DIMENSION_4D - 1; i >= 0; --i) { + int offset_index = i - axis; + if (offset_index >= 0) { + offset[i] = crop_param->offset_[offset_index]; + } else { + offset[i] = 0; + } + } +} + +void Crop4D(const float *input, float *output, const int *in_shape, const int *out_shape, CropParameter *crop_param) { + int64_t offset_pad[DIMENSION_4D]; + Pad4DOffset(crop_param, offset_pad); + int out_shape1 = out_shape[1]; + int out_shape2 = out_shape[2]; + int out_shape3 = out_shape[3]; + size_t out_stride2 = out_shape3; + size_t out_stride1 = out_stride2 * out_shape2; + size_t out_stride0 = out_stride1 * out_shape1; + size_t in_stride2 = in_shape[3]; + size_t in_stride1 = in_stride2 * in_shape[2]; + size_t in_stride0 = in_stride1 * in_shape[1]; + size_t copy_size = out_shape3 * sizeof(float); + size_t count_per_thread = UP_DIV(out_shape1, crop_param->op_parameter_.thread_num_); + int thread_id = crop_param->thread_id_; + size_t thread_stride = thread_id * count_per_thread; + for (int i = 0; i < out_shape[0]; ++i) { + size_t out_offset0 = i * out_stride0; + size_t in_offset0 = (i + offset_pad[0]) * in_stride0 + offset_pad[3]; + for (size_t j = 0; j < count_per_thread; ++j) { + size_t k = j + thread_stride; + if (k >= out_shape1) { + break; + } + size_t out_offset1 = k * out_stride1 + out_offset0; + size_t in_offset1 = (k + offset_pad[1]) * in_stride1 + in_offset0; + for (int l = 0; l < out_shape2; ++l) { + size_t out_offset = l * out_stride2 + out_offset1; + size_t in_offset = (l + offset_pad[2]) * in_stride2 + in_offset1; + memcpy(output + out_offset, input + in_offset, copy_size); + } + } + } +} + +void Crop4DNoParallel(const float *input, float *output, const int *in_shape, const int *out_shape, + CropParameter *crop_param) { + int64_t offset_pad[DIMENSION_4D]; + Pad4DOffset(crop_param, offset_pad); + size_t in_dim2_stride = in_shape[3]; + size_t in_dim1_stride = in_shape[2] * in_dim2_stride; + size_t in_dim0_stride = in_dim1_stride * in_shape[1]; + size_t offset_3 = offset_pad[3]; + size_t out_offset = 0; + size_t copy_num = out_shape[3]; + size_t copy_size = copy_num * sizeof(float); + size_t in_dim0_end = offset_pad[0] + out_shape[0]; + size_t in_dim1_end = offset_pad[1] + out_shape[1]; + size_t in_dim2_end = offset_pad[2] + out_shape[2]; + for (int i = offset_pad[0]; i < in_dim0_end; ++i) { + size_t dim0_offset = i * in_dim0_stride + offset_3; + for (int j = offset_pad[1]; j < in_dim1_end; ++j) { + size_t dim1_offset = j * in_dim1_stride + dim0_offset; + for (int k = offset_pad[2]; k < in_dim2_end; ++k) { + size_t in_offset = dim1_offset + k * in_dim2_stride; + memcpy(output + out_offset, input + in_offset, copy_size); + out_offset += copy_num; + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.h new file mode 100644 index 0000000000..c728f8aee9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CROP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CROP_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/crop_parameter.h" + +#define CROP_OFFSET_MAX_SIZE 4 + +void Crop4D(const float *input, float *output, const int *in_shape, const int *out_shape, CropParameter *crop_param); +void Crop4DNoParallel(const float *input, float *output, const int *in_shape, const int *out_shape, + CropParameter *crop_param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_CROP_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.cc new file mode 100644 index 0000000000..16a43f94ec --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/deconv.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane) { + /* ichwoc(nhwc) -> oc4 * h * w * incUP4 * 4 */ + int ic_up4 = UP_ROUND(input_channel, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM; + int oc4mod = oc % C4NUM; + for (int ic = 0; ic < input_channel; ic++) { + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * plane * output_channel + hw * output_channel + oc; + int dst_index = oc4div * ic_up4 * plane * C4NUM + hw * ic_up4 * C4NUM + ic * C4NUM + oc4mod; + dst[dst_index] = weight[src_index]; + } + } + } + return; +} + +int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer, + StrassenMatMulParameter matmul_param) { + return StrassenMatmul(input, weight, output, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_buffer); +} + +int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel, + ConvParameter *conv_param) { + /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_DIV(output_channel, C8NUM); + int in_plane8 = UP_ROUND(input_plane, C8NUM); + + for (int c = 0; c < oc8; c++) { + float *dst_ptr = tmp + c * output_plane * C8NUM; + const float *src_ptr = src + c * in_plane8 * kernel_plane * C8NUM; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * conv_param->input_w_ * C8NUM + iw * C8NUM + + kh * in_plane8 * conv_param->kernel_w_ * C8NUM + kw * in_plane8 * C8NUM; + int dst_index = oh * conv_param->output_w_ * C8NUM + ow * C8NUM + + kh * conv_param->dilation_h_ * conv_param->output_w_ * C8NUM + + kw * conv_param->dilation_w_ * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_ptr[dst_index + i] += src_ptr[src_index + i]; + } + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_, + conv_param->is_relu6_); + return NNACL_OK; +} + +int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, + int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) { + int oc4 = UP_DIV(output_channel, C4NUM); + for (int c = 0; c < oc4; c++) { + float *dst_ptr = tmp_c4 + c * output_plane * C4NUM; + const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM; + memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM + + kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM; + int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM + + kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM + + kw * conv_param->dilation_w_ * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst_ptr[dst_index + i] += src_ptr[src_index + i]; + } + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc4*/ + + PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_, + conv_param->is_relu6_); + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h new file mode 100644 index 0000000000..003fb36c1f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_DECONV_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_DECONV_H_ + +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane); + +int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer, + StrassenMatMulParameter matmul_param); + +int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, + int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param); +int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel, + ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_DECONV_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc new file mode 100644 index 0000000000..964041fa3c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" +#include +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "mindspore/core/utils/log_adapter.h" + +void l2_regulate(float *data, int size, float max_norm) { + float sum = 0; + for (int i = 0; i < size; ++i) { + sum += data[i]; + } + if (sum != 0) { + for (int i = 0; i < size; ++i) { + data[i] *= max_norm / sum; + } + } + return; +} + +int CopyData(float *input_data, int *ids, float *output_data, int num, EmbeddingLookupParameter *parameter) { + if (ids[num] >= parameter->layer_num_ || ids[num] < 0) { + MS_LOG(ERROR) << "Embedding lookup index out of range"; + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + float *out_data = output_data + num * parameter->layer_size_; + float *in_data = input_data + ids[num] * parameter->layer_size_; + if (!parameter->is_regulated_[ids[num]]) { + l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_); + parameter->is_regulated_[ids[num]] = true; + } + + memcpy(out_data, in_data, sizeof(float) * parameter->layer_size_); + return NNACL_OK; +} + +int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id) { + for (size_t i = task_id; i < parameter->ids_size_; i += parameter->thread_num) { + int ret = CopyData(input_data, ids, output_data, i, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h new file mode 100644 index 0000000000..fa9f0ce5da --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct EmbeddingLookupParameter { + OpParameter op_parameter_; + bool *is_regulated_; + float max_norm_; + int ids_size_; + int layer_size_; + int layer_num_; + int thread_num; +}; + +int EmbeddingLookup(float *input_data, int *ids, float *output_data, EmbeddingLookupParameter *parameter, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_EMBEDDING_LOOKUP_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/expandDims.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/expandDims.cc new file mode 100644 index 0000000000..65c2a58e90 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/expandDims.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/expandDims.h" +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int ExpandDims(float *input_ptr, float *output_ptr, size_t data_size) { + memcpy(output_ptr, input_ptr, data_size); + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/expandDims.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/expandDims.h new file mode 100644 index 0000000000..21b1fc9789 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/expandDims.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_EXPANDDIMS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_EXPANDDIMS_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ExpandDimsParameter { + OpParameter op_parameter_; + int dim_; +}; + +int ExpandDims(float *input_ptr, float *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_EXPANDDIMS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/fill.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/fill.cc new file mode 100644 index 0000000000..38bc4dbb27 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/fill.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/fill.h" + +int Fill(float *output, int size, float data) { + for (int i = 0; i < size; ++i) { + output[i] = data; + } + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/fill.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/fill.h new file mode 100644 index 0000000000..905afafd3f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/fill.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FILL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FILL_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +#define FILL_DIMS_MAX_SIZE 4 + +struct FillParameter { + OpParameter op_parameter_; + int dims_[FILL_DIMS_MAX_SIZE]; + int num_dims_; +}; + +int Fill(float *output, int size, float data); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FILL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.cc new file mode 100644 index 0000000000..c6bef4c940 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/gather.h" +#include + +inline int Stride(int *shape, int rank, int index) { + int i, stride = 1; + for (i = index + 1; i < rank; ++i) { + stride *= shape[i]; + } + return stride; +} + +int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size, + float *output) { + int i, m; + for (m = 0; m < outer_size; ++m) { + auto inputm = input + inner_size * m * limit; + auto outputm = output + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] > limit) { + return -1; + } + memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(float) * inner_size); + } + } + return 0; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h new file mode 100644 index 0000000000..95a4046c46 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct GatherParameter { + OpParameter op_parameter_; + int axis_; + int batchDims_; +}; + +int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size, + float *output); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gatherNd.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gatherNd.cc new file mode 100644 index 0000000000..fd62a5ab46 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gatherNd.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/gatherNd.h" +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int GatherNd(float *input, float *output, int *in_offset, int area, int count) { + int i = 0; + for (i = 0; i < count; i++) { + (void)memcpy(output + area * i, input + in_offset[i], area * sizeof(float)); + } + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gatherNd.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gatherNd.h new file mode 100644 index 0000000000..93c97002a2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gatherNd.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHERND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHERND_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct GatherNdParameter { + OpParameter op_parameter_; + int batchDims_; +}; + +int GatherNd(float *input, float *output, int *in_offset, int area, int count); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHERND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.cc new file mode 100644 index 0000000000..83705a8fbd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" + +static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, + int ldc) { + int i, j, k; + for (i = 0; i < M; ++i) { + for (k = 0; k < K; ++k) { + float a = alpha * mat_a[i * lda + k]; + for (j = 0; j < N; ++j) { + mat_c[i * ldc + j] += a * mat_B[k * ldb + j]; + } + } + } +} + +static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, + int ldc) { + int i, j, k; + for (i = 0; i < M; ++i) { + for (j = 0; j < N; ++j) { + float sum = 0; + for (k = 0; k < K; ++k) { + sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k]; + } + mat_c[i * ldc + j] += sum; + } + } +} + +static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, + int ldc) { + int i, j, k; + for (i = 0; i < M; ++i) { + for (k = 0; k < K; ++k) { + float a = alpha * mat_a[k * lda + i]; + for (j = 0; j < N; ++j) { + mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; + } + } + } +} + +static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, + int ldc) { + int i, j, k; + for (i = 0; i < M; ++i) { + for (j = 0; j < N; ++j) { + float sum = 0; + for (k = 0; k < K; ++k) { + sum += alpha * mat_a[i + k * lda] * mat_b[k + j * ldb]; + } + mat_c[i * ldc + j] += sum; + } + } +} + +// mat_c = alpha*op( mat_a )*op( mat_b ) + beta*C +// M - number of rows of matrix a +// N - number of cols of matrix b +// K - number of cols of matrix a + +void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, + int ldb, float beta, float *mat_c, int ldc) { + // printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc); + if (beta >= 0.f && beta <= 0.f) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + mat_c[i * ldc + j] = 0; + } + } + } else if (beta < 1.f || beta > 1.f) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + mat_c[i * ldc + j] *= beta; + } + } + } + + int t; + + for (t = 0; t < M; ++t) { + if (!transpose_a && !transpose_b) { + gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); + } else if (transpose_a && !transpose_b) { + gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); + } else if (!transpose_a && transpose_b) { + gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); + } else { + gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.h new file mode 100644 index 0000000000..b3e30d09da --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.h @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ + +void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, + int ldb, float beta, float *mat_c, int ldc); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/local_response_norm.cc new file mode 100644 index 0000000000..0120e93645 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/local_response_norm.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/local_response_norm.h" + +int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, + LocalResponseNormParameter *param) { + int i, j, k; + int left, right; + + float depth_radius = param->depth_radius_; + float bias = param->bias_; + float alpha = param->alpha_; + float beta = param->beta_; + + for (i = 0; i < out_size; i++) { + float *in_data = input_ptr + i * channel; + float *out_data = output_ptr + i * channel; + + for (j = 0; j < channel; j++) { + left = MSMAX(0, j - depth_radius); + right = MSMIN(channel - 1, j + depth_radius); + + float sum = 0.0; + for (k = left; k <= right; k++) { + const float in_val = in_data[k]; + sum += in_val * in_val; + } + out_data[j] = in_data[j] * (float)(pow((double)(sum * alpha + bias), -beta)); + } + } + return 0; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/local_response_norm.h new file mode 100644 index 0000000000..ad10be79e9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/local_response_norm.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LOCAL_RESPONSE_NORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LOCAL_RESPONSE_NORM_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct LocalResponseNormParameter { + OpParameter op_parameter_; + int depth_radius_; + float bias_; + float alpha_; + float beta_; +}; + +int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, + LocalResponseNormParameter *param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_LOCAL_RESPONSE_NORM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.cc new file mode 100644 index 0000000000..7a3a07ed0d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/lstm.h" +#include +#include "src/runtime/kernel/arm/nnacl/fp32/activation.h" +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +void InitGate(float *gate_buffer, const float *bias, LstmParameter *lstm_parm) { + int gate_offest = 0; + for (int l = 0; l < 4; l++) { + int batch_offest = gate_offest; + int bias_offest = l * lstm_parm->hidden_size_; + for (int b = 0; b < lstm_parm->batch_; b++) { + memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float)); + batch_offest += lstm_parm->hidden_size_; + } + gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_; + } +} + +// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] +void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + float res = 0; + for (int i = 0; i < inner_size; i++) { + res += input[r * inner_size + i] * weight[c * inner_size + i]; + } + output[r * cols + c] += res; + } + } +} + +void ElementMulAcc(float *input0, float *input1, float *output, int element_size) { + for (int index = 0; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +void UpdataState(float *cell_state, float *forget_gate, float *input_gate, float *cell_gate, int batch, + int hidden_size) { + ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); +} + +void UpdataOutput(float *cell_state, float *output_gate, float *hidden_state, int batch, int hidden_size) { + Tanh(cell_state, batch * hidden_size, hidden_state); + ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size); +} + +void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, + const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, + const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, + const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, + LstmParameter *lstm_parm) { + InitGate(gate_buffer, bias, lstm_parm); + + float *input_gate = gate_buffer; + float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2; + float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3; + float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1; + + // input * weight + MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); + MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->input_size_); + MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); + MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->input_size_); + + // state * weight + MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + + // update input_gate + Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate); + + // update forget_gate + Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate); + + // update cell_gate + Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); + // update cell state + UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_); + + // update output_gate + Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); + // update output + UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_); + memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); +} + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, + float *hidden_state, float *cell_state, float *gate_buffer, LstmParameter *lstm_parm) { + // forward + const float *input_input_weight = weight_i; + const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; + const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; + const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; + + const float *state_input_weight = weight_h; + const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; + const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; + const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; + + for (int t = 0; t < lstm_parm->seq_len_; t++) { + const float *input_ptr = input + t * lstm_parm->input_step_; + float *output_ptr = output + t * lstm_parm->output_step_; + LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, + state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, + cell_state, gate_buffer, lstm_parm); + } + + // backward + if (lstm_parm->bidirectional_) { + input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; + input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; + input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; + input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; + + state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; + state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; + state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; + state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; + + float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; + const float *backward_bias = bias + 4 * lstm_parm->hidden_size_; + float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; + float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; + for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { + const float *input_ptr = input + t * lstm_parm->input_step_; + float *output_ptr = backward_output + t * lstm_parm->output_step_; + LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, + input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, + backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.h new file mode 100644 index 0000000000..05fff4f1dc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_LSTM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_LSTM_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct LstmParameter { + OpParameter op_parameter_; + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + int input_step_; + int output_step_; + bool bidirectional_; +}; + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, + float *hidden_state, float *cell_state, float *gate_buffer, LstmParameter *lstm_parm); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_LSTM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc new file mode 100644 index 0000000000..6c841f7f39 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" + +void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + float *src = src_ptr + r * col; + for (int c = 0; c < col; c++) { + int cd8 = c / 8; + int cm8 = c % 8; + dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c]; + } + } + return; +} + +void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) { + size_t row8 = row / C8NUM * C8NUM; + size_t col4 = col / C4NUM * C4NUM; + float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C8NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + + /* 8x4 row-major to col-major */ +#ifdef ENABLE_NEON + size_t stride = col * 4; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s}, [x10], %[stride]\n" + "ld1 {v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s}, [x10], %[stride]\n" + "ld1 {v3.4s}, [x10], %[stride]\n" + + "zip1 v4.4s, v0.4s, v1.4s\n" + "zip2 v5.4s, v0.4s, v1.4s\n" + "zip1 v6.4s, v2.4s, v3.4s\n" + "zip2 v7.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x10], %[stride]\n" + "ld1 {v9.4s}, [x10], %[stride]\n" + "ld1 {v10.4s}, [x10], %[stride]\n" + "ld1 {v11.4s}, [x10], %[stride]\n" + + "trn1 v0.2d, v4.2d, v6.2d\n" + "trn2 v1.2d, v4.2d, v6.2d\n" + "trn1 v2.2d, v5.2d, v7.2d\n" + "trn2 v3.2d, v5.2d, v7.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v8.2d, v12.2d, v14.2d\n" + "trn2 v9.2d, v12.2d, v14.2d\n" + "trn1 v10.2d, v13.2d, v15.2d\n" + "trn2 v11.2d, v13.2d, v15.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v8.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v9.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11],#16\n" + "st1 {v10.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11],#16\n" + "st1 {v11.4s}, [x11], #16\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15"); +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + return; +} + +void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col) { + int row8 = UP_ROUND(row, 8); + for (int c = 0; c < col; c++) { + int cd8 = c / 8; + int cm8 = c % 8; + for (int r = 0; r < row; r++) { + dst_ptr[r * col + c] = src_ptr[cd8 * row8 * 8 + r * 8 + cm8]; + } + } + return; +} + +void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, + int col_8_) { + /* col8-major * row8-major => col8x8-major */ + for (int row = 0; row < row_8_; row++) { + for (int col = 0; col < col_8_; col++) { + int r8div = row / 8, r8mod = row % 8; + int c8div = col / 8, c8mod = col % 8; + size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r8div * deep * 8 + d * 8 + r8mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + if (bias != nullptr) value += bias[col]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + c[ci] = value; + } + } + return; +} + +void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, + int col_8_) { + MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_); + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h new file mode 100644 index 0000000000..e92f6004b7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_MATMUL_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" + +void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col); +void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); +void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col); +void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, + int row_8_, int col_8_); +#ifdef __cplusplus +extern "C" { +#endif +#ifdef __aarch64__ +void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, + int row, int col); +#endif +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_MATMUL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/one_hot.cc new file mode 100644 index 0000000000..351e154290 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/one_hot.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/one_hot.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_param, const int tid, + const int thread_num) { + if (indices == nullptr || one_hot_param == nullptr || output == nullptr) { + return NNACL_NULL_PTR; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + float on_value = one_hot_param->on_value_; + float off_value = one_hot_param->off_value_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = indices[i * inner_size + j]; + if (index >= depth) { + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/one_hot.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/one_hot.h new file mode 100644 index 0000000000..9cefa121ec --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/one_hot.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ONE_HOT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ONE_HOT_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct OneHotParameter { + OpParameter op_parameter_; + int axis_; + int depth_; + float on_value_; + float off_value_; + int outer_size_; + int inner_size_; +}; + +int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_param, const int tid, + const int thread_num); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ONE_HOT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pad.cc new file mode 100644 index 0000000000..1857ed6515 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pad.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/pad.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" + +void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + const int *paddings, const int tid, const int thread_num) { + int in[4], out[4]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + float *dst = output_data + offset(output_shape, out[0], out[1], out[2], paddings[6]); + const float *src = input_data + offset(input_shape, in[0], in[1], in[2], 0); + memcpy(dst, src, input_shape[3] * sizeof(float)); + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pad.h new file mode 100644 index 0000000000..1b60946e29 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pad.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_PAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_PAD_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/pad_parameter.h" + +void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + const int *paddings, const int tid, const int thread_num); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_PAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling.cc new file mode 100644 index 0000000000..da45b8b94d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" +#include + +void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int c4 = UP_DIV(channel, C4NUM); + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + // input channel is equal to output channel + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_avg = vdupq_n_f32(0); +#else + float tmp_avg1 = 0; + float tmp_avg2 = 0; + float tmp_avg3 = 0; + float tmp_avg4 = 0; +#endif + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset)); +#else + tmp_avg1 += *(input_ptr + in_offset); + tmp_avg2 += *(input_ptr + in_offset + 1); + tmp_avg3 += *(input_ptr + in_offset + 2); + tmp_avg4 += *(input_ptr + in_offset + 3); +#endif + ++real_count; + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + vst1q_f32(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f32(real_count)); +#else + *(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count; + *(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count; + *(output_ptr + out_channel_offset + 2) = tmp_avg3 / (float)real_count; + *(output_ptr + out_channel_offset + 3) = tmp_avg4 / (float)real_count; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_avg / (float)real_count; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} + +void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c4 = UP_DIV(channel, C4NUM); + // input channel is equal to output channel + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_max = vdupq_n_f32(FLT_MIN); +#else + float tmp_max1 = FLT_MIN; + float tmp_max2 = FLT_MIN; + float tmp_max3 = FLT_MIN; + float tmp_max4 = FLT_MIN; +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset)); +#else + tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset)); + tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1)); + tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2)); + tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3)); +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + vst1q_f32(output_ptr + out_channel_offset, tmp_max); +#else + *(output_ptr + out_channel_offset) = tmp_max1; + *(output_ptr + out_channel_offset + 1) = tmp_max2; + *(output_ptr + out_channel_offset + 2) = tmp_max3; + *(output_ptr + out_channel_offset + 3) = tmp_max4; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_max = FLT_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling.h new file mode 100644 index 0000000000..b460ea490c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct PoolingParameter { + OpParameter op_parameter_; + QuantArg **quant_args_; + bool global_; + bool max_pooling_; + bool avg_pooling_; + bool round_ceil_; + bool round_floor_; + int window_w_; + int window_h_; + int input_w_; + int input_h_; + int input_batch_; + int input_channel_; + int output_w_; + int output_h_; + int output_batch_; + int output_channel_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int stride_w_; + int stride_h_; + int thread_num_; +}; + +void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.cc new file mode 100644 index 0000000000..7c37fd38bc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" + +void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + + const float *inPtr; + for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; + + // int pad_top = padding[2]; + + float kk = static_cast(win_h * win_w); + + for (uint16_t ib = 0; ib < output_batch; ib++) { + // int in_batch_offset = batch * in_h * in_w * channel; + // int out_batch_offset = batch * output_h * output_w * channel; + // out = grads->getData(ib*grads->imgSize()); + // inPtr = in->getData(ib*in->imgSize()); + float *out; + out = &output_ptr[(ib * output_h * output_w)]; + inPtr = reinterpret_cast(&input_ptr[(ib * in_h * in_w)]); + if (1) { // in->layout() == Tensor::nhwc) + // iterate over yt + for (uint16_t yh = 0; yh < in_h; yh++) { + for (uint16_t yw = 0; yw < in_w; yw++) { + for (uint16_t ic = 0; ic < channel; ic++) { + int idx = (yw + yh * in_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; + float delta = inPtr[idx] / kk; + for (int32_t kh = 0; kh < win_h; kh++) { + int xh = yh * stride_h + kh - pad_h; + if ((xh < 0) || (xh >= output_h)) { + continue; + } + for (int32_t kw = 0; kw < win_w; kw++) { + int xw = yw * stride_w + kw - pad_w; + if ((xw < 0) || (xw >= output_w)) { + continue; + } + // out[(ic*output_h*output_w) + (xh*output_w) + xw] += delta; + out[(xw + output_w * xh) * channel + ic] += delta; + } + } + } + } + } + } else { // nchw + for (uint16_t ic = 0; ic < channel; ic++) { + // iterate over yt + for (uint16_t yh = 0; yh < in_h; yh++) { + for (uint16_t yw = 0; yw < in_w; yw++) { + int idx = (ic * in_h * in_w) + (in_w * yh) + yw; + float delta = inPtr[idx] / kk; + for (int32_t kh = 0; kh < win_h; kh++) { + int xh = yh * stride_h + kh - pad_h; + if ((xh < 0) || (xh >= output_h)) { + continue; + } + for (int32_t kw = 0; kw < win_w; kw++) { + int xw = yw * stride_w + kw - pad_w; + if ((xw < 0) || (xw >= output_w)) { + continue; + } + out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta; + } + } + } + } + } + } + } +} + +void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, PoolingParameter *pooling_param) { + // int stride_w = pooling_param->stride_w_; + // int stride_h = pooling_param->stride_h_; + // int pad_w = pooling_param->pad_l_; + // int pad_h = pooling_param->pad_u_; + // int win_w = pooling_param->window_w_; + // int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + + int out_img_size = + output_h * output_w; // Emir -- in original code this varible is calculated according to input size ?? + int ind_img_size = in_h * in_w; + // const int w_pad = (output_w + pad_w + pad_w); + + for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; + + const float *yt = reinterpret_cast(dy); + const int *pos = reinterpret_cast(indices); + float *out; + + if (1) { // grads->layout() == Tensor::nhwc) + for (int ib = 0; ib < output_batch; ib++) { + out = &(output_ptr[ib * output_w * output_w * channel]); + for (int ix = 0; ix < ind_img_size; ix++) { + for (int cix = 0; cix < channel; cix++) { + int idx = (*pos) * channel + cix; + out[idx] += *yt; + pos++; + yt++; + } + } + } + } else { + for (int ib = 0; ib < output_batch; ib++) { + out = &output_ptr[(ib * out_img_size)]; + for (int cix = 0; cix < channel; cix++) { + for (int ix = 0; ix < ind_img_size; ix++) { + int idx = cix * output_h * output_w + *pos; // cord_y*output_w + cord_x; + out[idx] += *yt; + pos++; + yt++; + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h new file mode 100644 index 0000000000..0f6049afd4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ + +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" + +void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); +void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/range.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/range.cc new file mode 100644 index 0000000000..acb6c18cd5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/range.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/range.h" + +void Range(float *output_ptr, int start, int limit, int delta) { + size_t index = 0; + for (size_t i = start; i < limit; i += delta) { + output_ptr[index++] = (float)(i); + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/range.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/range.h new file mode 100644 index 0000000000..bdb40c6d25 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/range.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RANGE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RANGE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct RangeParameter { + OpParameter op_parameter_; + int dType_; + int start_; + int limit_; + int delta_; +}; + +void Range(float *output_ptr, int start, int limit, int delta); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RANGE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/rank.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/rank.cc new file mode 100644 index 0000000000..1cf475b625 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/rank.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/rank.h" + +void Rank(float* output, int rank) { + output[0] = (float)(rank); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/rank.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/rank.h new file mode 100644 index 0000000000..e8133d0ac7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/rank.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RANK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RANK_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Rank(float* output, int rank); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RANK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.cc new file mode 100644 index 0000000000..4fe680b8cd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/nnacl/fp32/reduce.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 0.0f; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size]; + } + *inner_dst = tmp / (float)axis_size; + } + } + return NNACL_OK; +} +int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 0.0f; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} +int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = -FLT_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} +int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = FLT_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} +int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 1.0f; + for (i = 0; i < axis_size; i++) { + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} +int ReduceSumSquare(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 0.0f; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size] * inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h new file mode 100644 index 0000000000..9ebdb9638a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_REDUCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_REDUCE_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#define REDUCE_MAX_AXES_NUM 8 + +struct ReduceParameter { + OpParameter op_parameter_; + bool keep_dims_; + int axes_[REDUCE_MAX_AXES_NUM]; + int num_axes_; + int mode_; +}; + +int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceSumSquare(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_REDUCE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.cc new file mode 100644 index 0000000000..40801c3f35 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" + +static inline bool NextIndex(const int num_dims, const int *dims, int *current) { + int carry = 1; + for (int idx = num_dims - 1; idx >= 0; --idx) { + int current_val = current[idx] + carry; + if (dims[idx] == current_val) { + current[idx] = 0; + } else { + current[idx] = current_val; + carry = 0; + break; + } + } + return (carry == 0); +} + +static inline size_t GetInputOffset(const int num_dims, const int *dims, const int *iter) { + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); + } + + return offset; +} + +static inline size_t GetOutputOffset(const int num_dims, const int *dims, const int *iter, const int num_axis, + const int *axes) { + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + // if we need to skip this axis + bool is_axis = false; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (idx == axes[axis_idx]) { + is_axis = true; + break; + } + } + + if (!is_axis) { + offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); + } + } + return offset; +} + +void ReduceMeanByAxes(const float *input_data, int *input_iter, const int *input_dims, int input_num_dims, + const int *axes, int num_axes, float *output_data, const int *output_dims, int output_num_dims) { + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = (size_t)(output_dims[idx]); + num_outputs *= current; + } + + // Reset input iterator. + for (int idx = 0; idx < input_num_dims; ++idx) { + input_iter[idx] = 0; + } + // Iterate through input_data. + do { + size_t input_offset = GetInputOffset(input_num_dims, input_dims, input_iter); + size_t output_offset = GetOutputOffset(input_num_dims, input_dims, input_iter, num_axes, axes); + output_data[output_offset] += input_data[input_offset]; + } while (NextIndex(input_num_dims, input_dims, input_iter)); + + // Calculate mean by dividing output_data by num of aggregated element. + size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_axes; ++idx) { + size_t current = (size_t)(input_dims[axes[idx]]); + num_elements_in_axis *= current; + } + + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = output_data[idx] / static_cast(num_elements_in_axis); + } +} + +float ReduceMeanAll(const float *src, int size) { + float sum = 0; + for (int i = 0; i < size; ++i) { + sum += src[i]; + } + return sum / size; +} + +void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) { + int num_outputs = 1; + int same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_outputs *= output_dims[idx]; + if (output_dims[idx] != input_dims[idx]) same_shape = false; + } + if (same_shape) { + std::copy(input, input + num_outputs * sizeof(float), output); + // memcpy(output, input, num_outputs*sizeof(float)); + return; + } + + for (int idx = 0; idx < num_outputs; ++idx) output[idx] = 0; // zero output + + int input_iter[8] = {0}; + int axes[5] = {0}; + int num_axes = 0; + for (int i = 0; i < num_dims; i++) + if (output_dims[i] == 1) axes[num_axes++] = i; + + // Iterate through input_data. + do { + size_t input_offset = GetInputOffset(num_dims, input_dims, input_iter); + size_t output_offset = GetOutputOffset(num_dims, input_dims, input_iter, num_axes, axes); + output[output_offset] += input[input_offset]; + } while (NextIndex(num_dims, input_dims, input_iter)); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h new file mode 100644 index 0000000000..334ae736c4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_REDUCE_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_REDUCE_GRAD_H_ + +float ReduceMeanAll(const float *src, int size); +void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_REDUCE_GRAD_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reverse.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reverse.cc new file mode 100644 index 0000000000..6c1867f954 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reverse.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/reverse.h" +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int Reverse(const float *input, float *output, size_t elem_size, int *index) { + for (int i = 0; i < elem_size; i++) { + output[index[i]] = input[i]; + } + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reverse.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reverse.h new file mode 100644 index 0000000000..8e76eb4b94 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reverse.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_REVERSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_REVERSE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#define REVERSE_SHAPE_MAX_SIZE 4 + +// For reverse. +struct ReverseParameter { + OpParameter op_parameter_; + int axis_[REVERSE_SHAPE_MAX_SIZE]; + int num_axis_; +}; + +int Reverse(const float *input, float *output, size_t elem_size, int *index); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_REVERSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/slice.cc new file mode 100644 index 0000000000..02e3b3fcee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/slice.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/slice.h" +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +void PadSliceParameterTo4D(SliceParameter *param) { + int32_t begin[DIMENSION_4D]; + int32_t end[DIMENSION_4D]; + int32_t slice_size[DIMENSION_4D]; + int32_t data_shape[DIMENSION_4D]; + for (int32_t i = 0; i < param->param_length_; ++i) { + begin[i] = param->begin_[i]; + end[i] = param->end_[i]; + slice_size[i] = param->size_[i] < 0 ? param->shape_[i] - begin[i] : param->size_[i]; + data_shape[i] = param->shape_[i]; + } + int32_t real_index = param->param_length_ - 1; + for (int32_t i = DIMENSION_4D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begin_[i] = begin[real_index]; + param->end_[i] = end[real_index]; + param->size_[i] = slice_size[real_index]; + param->shape_[i] = data_shape[real_index--]; + } else { + param->begin_[i] = 0; + param->end_[i] = 1; + param->size_[i] = 1; + param->shape_[i] = 1; + } + } + param->param_length_ = DIMENSION_4D; +} + +void DoSlice(const float *input, float *output, SliceParameter *param) { + int32_t out_dim1 = param->size_[1]; + int32_t out_dim2 = param->size_[2]; + int32_t out_dim3 = param->size_[3]; + size_t out_stride2 = out_dim3; + size_t out_stride1 = out_stride2 * out_dim2; + size_t out_stride0 = out_stride1 * out_dim1; + size_t count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_); + int thread_id = param->thread_id_; + size_t thread_stride = thread_id * count_per_thread; + size_t copy_size = param->size_[3] * sizeof(float); + size_t in_stride2 = param->shape_[3]; + size_t in_stride1 = param->shape_[2] * in_stride2; + size_t in_stride0 = param->shape_[1] * in_stride1; + for (int i = 0; i < param->size_[0]; ++i) { + size_t out_offset0 = i * out_stride0; + size_t in_offset0 = (i + param->begin_[0]) * in_stride0 + param->begin_[3]; + for (size_t j = 0; j < count_per_thread; ++j) { + size_t k = j + thread_stride; + if (k >= out_dim1) { + break; + } + size_t out_offset1 = k * out_stride1 + out_offset0; + size_t in_offset1 = (k + param->begin_[1]) * in_stride1 + in_offset0; + for (int l = 0; l < out_dim2; ++l) { + size_t out_offset = out_offset1 + l * out_stride2; + size_t in_offset = in_offset1 + (l + param->begin_[2]) * in_stride2; + memcpy(output + out_offset, input + in_offset, copy_size); + } + } + } +} + +void DoSliceNoParallel(const float *input, float *output, SliceParameter *param) { + size_t copy_size = param->size_[3] * sizeof(float); + size_t in_stride2 = param->shape_[3]; + size_t in_stride1 = param->shape_[2] * in_stride2; + size_t in_stride0 = param->shape_[1] * in_stride1; + size_t out_offset = 0; + for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) { + size_t in_offset0 = dim0 * in_stride0 + param->begin_[3]; + for (size_t dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) { + size_t in_offset1 = dim1 * in_stride1 + in_offset0; + for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) { + size_t in_offset = in_offset1 + dim2 * in_stride2; + memcpy(output + out_offset, input + in_offset, copy_size); + out_offset += param->size_[3]; + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/slice.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/slice.h new file mode 100644 index 0000000000..07da4625c2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/slice.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SLICE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SLICE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#define SLICE_SHAPE_MAX_SIZE 4 + +struct SliceParameter { + OpParameter op_parameter_; + int32_t begin_[SLICE_SHAPE_MAX_SIZE]; + int32_t end_[SLICE_SHAPE_MAX_SIZE]; + int32_t size_[SLICE_SHAPE_MAX_SIZE]; + int32_t shape_[SLICE_SHAPE_MAX_SIZE]; + int32_t param_length_; + int32_t thread_id_; +}; + +void PadSliceParameterTo4D(SliceParameter *param); +void DoSlice(const float *input, float *output, SliceParameter *param); +void DoSliceNoParallel(const float *input, float *output, SliceParameter *param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SLICE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax.cc new file mode 100644 index 0000000000..2e046f2651 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" +#include + +// output = exp(input) / reduce_sum(exp(input), axis) +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) { + int32_t axis = parameter->axis_; + int n_dim = parameter->n_dim_; + int ele_size = parameter->element_size_; + int *input_shape = parameter->input_shape_; + + for (int i = 0; i < ele_size; i++) { + output_ptr[i] = exp(input_ptr[i]); + } + int inner_size = 1, outter_size = 1; + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + sum_data[j] += output_ptr[inner_offset]; + } + } + } + + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[j]; + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax.h new file mode 100644 index 0000000000..7324b0d4d5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SOFTMAX_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" + +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SOFTMAX_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h new file mode 100644 index 0000000000..e9b0955d9d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SOFTMAX_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SOFTMAX_GRAD_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct SoftmaxCrossEntropyParameter { + OpParameter op_parameter; + int32_t batch_size_; + unsigned int number_of_classes_; + int n_dim_; + int input_shape_[5]; +}; +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_SOFTMAX_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.cc new file mode 100644 index 0000000000..cbc4e247dd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/fp32/concat.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +int EnumElement(int *shape, int n_dims) { + int total = 1; + for (int i = 0; i < n_dims; i++) { + total *= shape[i]; + } + return total; +} + +void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm, + int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + memcpy(out_data + out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n, + in_data + stride0_i + stride1_j + stride2_k + stride3_m + stride4_n, stride4 * sizeof(float)); + } + } + } + } + } +} + +int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes) { + int trans_in_shape[6] = {in_shape[0], in_shape[1] / block_sizes[0], + block_sizes[0], in_shape[2] / block_sizes[1], + block_sizes[1], in_shape[3]}; + int trans_out_shape[6] = { + in_shape[0], block_sizes[0], block_sizes[1], in_shape[1] / block_sizes[0], in_shape[2] / block_sizes[1], + in_shape[3]}; + int in_strides[C4NUM + 2]; + ComputeStrides(trans_in_shape, in_strides, shape_size + 2); + int out_strides[C4NUM + 2]; + ComputeStrides(trans_out_shape, out_strides, shape_size + 2); + + int perm[6] = {0, 2, 4, 1, 3, 5}; + TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape); + return NNACL_OK; +} + +void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]) { + float *tmp = padded_input; + (void)memcpy(tmp, input, param.num_elements_ * sizeof(float)); + float *target = tmp_space[0]; + float *tmp_zeros = tmp_space[1]; + float *tmp2 = nullptr; + int cur_shape[param.n_dims_], cur_start_shape[param.n_dims_], cur_end_shape[param.n_dims_], + cur_target_shape[param.n_dims_]; + float *concat_inputs[3]; + int *concat_shapes[4]; + + for (int i = 0; i < param.n_dims_; i++) { + cur_shape[i] = param.in_shape_[i]; + cur_start_shape[i] = param.in_shape_[i]; + cur_end_shape[i] = param.in_shape_[i]; + cur_target_shape[i] = param.in_shape_[i]; + } + for (int i = 0; i < param.n_space_dims_; ++i) { + if (param.padded_in_shape_[i + 1] > param.in_shape_[i + 1]) { + int concat_idx = 0; + cur_target_shape[i + 1] = 0; + if (param.paddings_[2 * i] != 0) { + cur_start_shape[i + 1] = param.paddings_[2 * i]; + concat_inputs[concat_idx] = tmp_zeros; + concat_shapes[concat_idx++] = cur_start_shape; + cur_target_shape[i + 1] += cur_start_shape[i + 1]; + } + + concat_inputs[concat_idx] = tmp; + concat_shapes[concat_idx++] = cur_shape; + cur_target_shape[i + 1] += cur_shape[i + 1]; + if (param.paddings_[2 * i + 1] != 0) { + cur_end_shape[i + 1] = param.paddings_[2 * i + 1]; + concat_inputs[concat_idx] = tmp_zeros; + concat_shapes[concat_idx++] = cur_end_shape; + cur_target_shape[i + 1] += cur_end_shape[i + 1]; + } + concat_shapes[concat_idx] = cur_target_shape; + Concat((void **)concat_inputs, concat_idx, i + 1, concat_shapes, param.n_dims_, target); + + tmp2 = tmp; + tmp = target; + target = tmp2; + cur_start_shape[i + 1] = cur_end_shape[i + 1] = cur_shape[i + 1] = concat_shapes[concat_idx][i + 1]; + } + } + if (padded_input != tmp) { + memcpy(padded_input, tmp, param.num_elements_padded_ * sizeof(float)); + } +} + +int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, float *tmp_space[3]) { + float *padded_input; + int ret; + if (param.need_paddings_) { + if (tmp_space[0] == nullptr || tmp_space[1] == nullptr || tmp_space[2] == nullptr) { + return NNACL_NULL_PTR; + } + padded_input = tmp_space[0]; + DoPadding(input, padded_input, param, tmp_space + 1); + } + + if (param.need_paddings_) { + ret = SpaceToBatchForNHWC(padded_input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_); + } else { + ret = SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_); + } + return ret; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h new file mode 100644 index 0000000000..4d4ab437f7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +#define SPACE_TO_BATCH_BLOCK_SIZES_SIZE 2 +#define SPACE_TO_BATCH_PADDINGS_SIZE 4 + +struct SpaceToBatchParameter { + OpParameter op_parameter_; + int block_sizes_[8]; + int paddings_[8]; + int n_dims_; + int num_elements_; + int num_elements_padded_; + int n_space_dims_; + int in_shape_[8]; + int padded_in_shape_[8]; + bool need_paddings_ = false; +}; + +int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, float *tmp_space[3]); +int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size); +void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm, + int *output_shape); +void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param); +int EnumElement(int *shape, int n_dims); +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.cc new file mode 100644 index 0000000000..dab1dd7b8e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size, + int block_size, int h_start, int h_end) { + if (input == nullptr || output == nullptr) { + return NNACL_NULL_PTR; + } + if (shape_size != C4NUM) { + return NNACL_PARAM_INVALID; + } + if (h_start < 0 || h_start >= h_end || h_end > out_shape[1]) { + return NNACL_PARAM_INVALID; + } + int in_strides[C4NUM]; + ComputeStrides(in_shape, in_strides, shape_size); + int out_strides[C4NUM]; + ComputeStrides(out_shape, out_strides, shape_size); + for (int i = 0; i < out_shape[0]; ++i) { + size_t in_offset_n = i * in_strides[0]; + size_t out_offset_n = i * out_strides[0]; + for (int j = h_start; j < h_end; ++j) { + size_t in_offset_h = in_offset_n + j * block_size * in_strides[1]; + size_t out_offset_h = out_offset_n + j * out_strides[1]; + for (int k = 0; k < out_shape[2]; ++k) { + size_t in_offset_w = in_offset_h + k * block_size * in_strides[2]; + size_t out_offset_w = out_offset_h + k * out_strides[2]; + for (int l = 0; l < block_size; ++l) { + memcpy(output + out_offset_w + l * block_size * in_strides[2], input + in_offset_w + l * in_strides[1], + block_size * in_strides[2] * sizeof(float)); + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h new file mode 100644 index 0000000000..5bfdacac89 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_DEPTH_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_DEPTH_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct SpaceToDepthParameter { + OpParameter op_parameter_; + int32_t block_size_; +}; + +int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size, + int block_size, int h_start, int h_end); +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_DEPTH_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/stack.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/stack.cc new file mode 100644 index 0000000000..1496ec7d47 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/stack.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/stack.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +void DoStack(const float * const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { + size_t one_input_size = 1; + for (size_t i = 0; i < shape_size; ++i) { + one_input_size *= in_shape[i]; + } + int in_strides[shape_size]; + ComputeStrides(in_shape, in_strides, shape_size); + + size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size; + size_t copy_size = copy_num * sizeof(float); + size_t pre_axis_count = 1; + for (size_t i = 0; i < axis; ++i) { + pre_axis_count *= in_shape[i]; + } + size_t in_offset = 0; + size_t out_offset = 0; + for (size_t i = 0; i < pre_axis_count; ++i) { + for (size_t j = 0; j < input_num; ++j) { + memcpy(output + out_offset, inputs[j] + in_offset, copy_size); + out_offset += copy_num; + } + in_offset += copy_num; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/stack.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/stack.h new file mode 100644 index 0000000000..cf08d4901b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/stack.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_STACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_STACK_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct StackParameter { + OpParameter op_parameter_; + int32_t axis_; +}; + +void DoStack(const float * const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_STACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.cc new file mode 100644 index 0000000000..1baa31f8e7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h" + +bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) { + if (cur_recursion >= max_recursion) { + return false; + } + + if (row % 2 != 0 || col % 2 != 0 || deep % 2 != 0) { + return false; + } + + int row2 = row / 2; + int col2 = col / 2; + int deep2 = deep / 2; + + float save_cost = row * col * 4 * deep * 4 * 2 + row * col * 4 - + 7 * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) - 4 * (row2 * deep2 * 4 * 3) - + 4 * (deep2 * 4 * col2 * 4 * 3) - 7 * (row2 * col2 * 4 * 3); + + return (save_cost > 0.f); +} + +void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride, + int c_stride) { + int row4mod = row % 4; + int row4div = row / 4; + for (int r = 0; r < row; r++) { + int r4mod = r % 4; + int r4div = r / 4; + for (int c = 0; c < col * 4; c++) { + float value = 0; + int ic = c / 4 * c_stride + r * 4 + c % 4; + for (int d = 0; d < deep * 4; d++) { + int d4mod = d % 4; + int d4div = d / 4; + int a_stride = (r < (row4div * 4)) ? 4 : row4mod; + int ai = r4div * 4 * deep * 4 + d4div * a_stride * 4 + r4mod * 4 + d4mod; + int bi = c / 4 * b_stride + d * 4 + c % 4; + value = value + a_ptr[ai] * b_ptr[bi]; + } + dst_ptr[ic] = value; + } + } + return; +} + +void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride, + int c_stride) { + int row4mod = row % 4; + int row4div = row / 4; + + if (row4div > 0) { + GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride); + } + + if (row4mod != 0) { + GemmMatMulComm(a_ptr + row4div * deep * 4 * 4, b_ptr, dst_ptr + row4div * 4 * 4, row4mod, col, deep, b_stride, + c_stride); + } + return; +} + +int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int cur_recursion, float *tmp_a_ptr) { + size_t row2 = matmul_param->row_ / 2; + size_t deep2 = matmul_param->deep_ / 2; + size_t col2 = matmul_param->col_ / 2; + size_t a_stride = matmul_param->a_stride_; + size_t b_stride = matmul_param->b_stride_; + size_t c_stride = matmul_param->c_stride_; + + StrassenMatMulParameter *rec_matmul = new StrassenMatMulParameter(); + rec_matmul->row_ = row2; + rec_matmul->deep_ = deep2; + rec_matmul->col_ = col2; + + float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float))); + if (x_ptr == nullptr) { + free(rec_matmul); + return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC; + } + float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float))); + if (y_ptr == nullptr) { + free(x_ptr); + free(rec_matmul); + return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC; + } + size_t x_stride = row2 * FP32_STRASSEN_UINT; + size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT; + + const float *a11 = a_ptr; + const float *a12 = a_ptr + deep2 * a_stride; + const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT; + const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT; + const float *b11 = b_ptr; + const float *b12 = b_ptr + col2 * b_stride; + const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT; + const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT; + float *c11 = c_ptr; + float *c12 = c_ptr + col2 * c_stride; + float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT; + float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT; + + /* S3 = A11 - A21 */ + MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2); + + /* T3 = B22 - B12 */ + MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2); + + /* P7 = S3T3 */ + rec_matmul->a_stride_ = x_stride; + rec_matmul->b_stride_ = y_stride; + rec_matmul->c_stride_ = c_stride; + StrassenMatmul(x_ptr, y_ptr, c21, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* S1 = A21 + A22 */ + MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2); + + /* T1 = B12 - B11 */ + MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2); + + /* P5 = S1T1 */ + StrassenMatmul(x_ptr, y_ptr, c22, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* S2 = S1 - A11 */ + MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2); + + /* T2 = B22 - T1 */ + MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2); + + /* P6 = S2T2 */ + StrassenMatmul(x_ptr, y_ptr, c12, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* S4 = A12 - S2 */ + MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2); + + /* P3 = S4B22 */ + rec_matmul->b_stride_ = b_stride; + StrassenMatmul(x_ptr, b22, c11, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* P1 = A11B11 */ + rec_matmul->a_stride_ = a_stride; + rec_matmul->c_stride_ = row2 * FP32_STRASSEN_UINT; + StrassenMatmul(a11, b11, x_ptr, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* U2 = P1 + P6 + U3 = U2 + P7 + U4 = U2 + P5 + U7 = U3 + P5 + U5 = U4 + P3 */ + MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride); + + /* T4 = T2 - B21 */ + MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2); + + /* P4 = A22T4 */ + rec_matmul->b_stride_ = y_stride; + rec_matmul->c_stride_ = c_stride; + StrassenMatmul(a22, y_ptr, c11, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* U6 = U3 - P4 */ + MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2); + + /* P2 = A12B21 */ + rec_matmul->b_stride_ = b_stride; + StrassenMatmul(a12, b21, c11, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* U1 = P1 + P2 */ + MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2); + + free(x_ptr); + free(y_ptr); + free(rec_matmul); + return NNACL_OK; +} + +int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + float *tmp_a_ptr) { + MatrixPack(a_ptr, tmp_a_ptr, matmul_param->row_, matmul_param->deep_, matmul_param->a_stride_); + GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, matmul_param->row_, matmul_param->col_, matmul_param->deep_, + matmul_param->b_stride_, matmul_param->c_stride_); + return NNACL_OK; +} + +int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int cur_recursion, float *tmp_a_ptr) { + if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, cur_recursion, max_recursion)) { + return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr); + } + return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h new file mode 100644 index 0000000000..06afd71669 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_STRASSEN_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_STRASSEN_MATMUL_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/strassen_matmul.h" +#include "src/runtime/kernel/arm/nnacl/fp32/common_func.h" + +#define FP32_STRASSEN_UINT C4NUM +#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM) +#define FP32_STRASSEN_MAX_RECURSION 5 + +int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int, float *tmp_a_ptr); +int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param, + float *tmp_a_ptr); + +int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int cur_recursion, float *tmp_a_ptr); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_STRASSEN_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.cc new file mode 100644 index 0000000000..be0c8b3f13 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/topk.h" + +int DescendCmp(const void *a, const void *b) { + return ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; +} + +int AscendCmp(const void *a, const void *b) { + return ((const TopkNode *)a)->element - ((const TopkNode *)b)->element; +} + +void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter) { + int last_dim_size = parameter->last_dim_size_; + int loop_num = parameter->loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + float *cur_input_data = input_data; + float *cur_output_data = output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < loop_num; i++) { + for (int j = 0; j < last_dim_size; j++) { + top_map[j].element = *(cur_input_data + j); + top_map[j].index = j; + } + if (parameter->sorted_) { + qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp); + } else { + qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmp); + } + for (int m = 0; m < k; m++) { + cur_output_data[m] = top_map[m].element; + cur_output_index[m] = top_map[m].index; + } + cur_input_data += last_dim_size; + cur_output_data += k; + cur_output_index += k; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.h new file mode 100644 index 0000000000..a49e5f8ef4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TOPK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TOPK_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct TopkNode { + float element; + int32_t index; +}; + +struct TopkParameter { + OpParameter op_parameter_; + int last_dim_size_; + int loop_num_; + int k_; + bool sorted_; + void *topk_node_list_; +}; + +void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TOPK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.cc new file mode 100644 index 0000000000..4d7d1600cf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h" +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size) { + memcpy(output_ptr, input_ptr, data_size); + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h new file mode 100644 index 0000000000..59c01944f3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNSQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNSQUEEZE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +#define UNSQUEEZE_DIMS_MAX_SIZE 4 + +struct UnsqueezeParameter { + OpParameter op_parameter_; + int dims_[UNSQUEEZE_DIMS_MAX_SIZE]; + int num_dim_; +}; + +int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNSQUEEZE_H_ + + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.cc new file mode 100644 index 0000000000..9c8aa0fa8c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/fused_batchnorm.h" + +void FusedBatchNorm(const float *input_ptr, const float *scale_ptr, const float *offest_ptr, const float *mean_ptr, + const float *variance_ptr, int *input_shapes, float epsilon, float *output_ptr) { + int channel = input_shapes[3]; + int units = 1; + for (int i = 0; i < 3; i++) { + units *= input_shapes[i]; + } + for (int c = 0; c < input_shapes[3]; c++) { + auto variance_sqrt = sqrt(variance_ptr[c] + epsilon); + for (int u = 0; u < units; u++) { + output_ptr[u * channel + c] = + (input_ptr[u * channel + c] - mean_ptr[c]) / variance_sqrt * scale_ptr[c] + offest_ptr[c]; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.h new file mode 100644 index 0000000000..8aac705e46 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FUSED_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FUSED_BATCHNORM_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct FusedBatchNormParameter { + OpParameter op_parameter_; + float epsilon_; +}; + + +void FusedBatchNorm(const float *input_ptr, const float *scale_ptr, const float *offest_ptr, const float *mean_ptr, + const float *variance_ptr, int *input_shapes, float epsilon, float *output_ptr); + + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FUSED_BATCHNORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.cc new file mode 100644 index 0000000000..31ac63ec5b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.cc @@ -0,0 +1,221 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.h" +#include + +void CalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count, + int *after_axis_count) { + *pre_axis_count = 1; + for (int i = 0; i < axis; ++i) { + *pre_axis_count = (*pre_axis_count) * shape[i]; + } + + *axis_count = shape[axis]; + + *after_axis_count = 1; + for (int i = axis + 1; i < dims_number; ++i) { + *after_axis_count = (*after_axis_count) * shape[i]; + } +} + +void ArgMinMaxQuant(const int8_t *input, int8_t *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count, + int after_axis_count, QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float value = -FLT_MAX; + if (!param->get_max_) { + value = FLT_MAX; + } + float index = 0.0f; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j] * in_quant_arg->scale_ + bias; + if (param->get_max_) { + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } else { + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + } + float real_out = out_value ? value : index; + output[output_offset + j] = real_out * output_inverse_scale + output_zp; + } + } +} + +void ArgMinMaxQuant(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + CalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + ArgMinMaxQuant(input, output, param, pre_axis_count, axis_count, after_axis_count, in_quant_arg, out_quant_arg); + return; +} + +int ArgCompareAscInt8(const void *a, const void *b) { + return reinterpret_cast(a)->data_.f_data_ + - reinterpret_cast(b)->data_.f_data_; +} + +int ArgCompareDescInt8(const void *a, const void *b) { + return reinterpret_cast(b)->data_.f_data_ + - reinterpret_cast(a)->data_.f_data_; +} + +int8_t GetInt8Output(float real_out, float output_inverse_scale, int32_t output_zp) { + return real_out * output_inverse_scale + output_zp; +} + +void ArgMinMaxDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscInt8); + } + + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + float real_out = out_value ? param->arg_elements_[j].data_.f_data_ : param->arg_elements_[j].index_; + output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp); + } + } +} + +void ArgMinMaxDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscInt8); + } + + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + float real_out = out_value ? param->arg_elements_[k].data_.f_data_ : param->arg_elements_[k].index_; + output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp); + } + } + } +} + +void ArgMinMaxDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscInt8); + } + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + float real_out = out_value ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_; + output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp); + } + } + } + } +} + +void ArgMinMaxDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscInt8); + } + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + float real_out = out_value ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_; + output[out_offset] = GetInt8Output(real_out, output_inverse_scale, output_zp); + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.h new file mode 100644 index 0000000000..ec4f259715 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARG_MIN_MAX_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARG_MIN_MAX_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/arg_min_max_parameter.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +void ArgMinMaxQuant(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant, QuantArg *out_quant); +void ArgMinMaxDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant, QuantArg *out_quant); +void ArgMinMaxDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant, QuantArg *out_quant); +void ArgMinMaxDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant, QuantArg *out_quant); +void ArgMinMaxDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param, + QuantArg *in_quant, QuantArg *out_quant); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARG_MIN_MAX_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc new file mode 100644 index 0000000000..d949d64179 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h" +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int index = 0; index < element_size; ++index) { + output[index] = (int8_t)(input0[index] != input1[index]); + } + return NNACL_OK; +} + +int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int index = 0; index < element_size; ++index) { + output[index] = (int8_t)(input0[index] == input1[index]); + } + return NNACL_OK; +} + +int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int index = 0; index < element_size; ++index) { + output[index] = (int8_t)(input0[index] < input1[index]); + } + return NNACL_OK; +} + +int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int index = 0; index < element_size; ++index) { + output[index] = (int8_t)(input0[index] <= input1[index]); + } + return NNACL_OK; +} + +int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int index = 0; index < element_size; ++index) { + output[index] = (int8_t)(input0[index] > input1[index]); + } + return NNACL_OK; +} + +int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int index = 0; index < element_size; ++index) { + output[index] = (int8_t)(input0[index] >= input1[index]); + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h new file mode 100644 index 0000000000..9229657f51 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + +int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + +int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + +int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + +int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + +int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.cc new file mode 100644 index 0000000000..b6b98d109e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.cc @@ -0,0 +1,279 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.h" +#ifdef ENABLE_NEON +#include +#include "src/runtime/kernel/arm/nnacl/add_int8.h" +#endif +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +int ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(floorf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementRound(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(round(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementCeil(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(ceil(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementAbs(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(fabsf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementSin(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(sinf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementCos(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(cosf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementLog(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(logf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementSqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + int32_t output_tmp = round(sqrtf(input_f32) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementRsqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 <= 0) { + return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + } + int32_t output_tmp = round(1.f / (sqrtf(input_f32) * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input, int32x4_t left_shift_out_vec, int32x4_t output_multiplier_vec, + ArithSelfQuantArg para) { + int32x4_t input_scale = vmulq_s32(scaled_input, scaled_input); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + para.shift_right_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_args_.zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void SquareInt8NEON(int8_t *input_data, int8_t *output_data, int64_t element_size, ArithSelfQuantArg para, int *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); + + for (; (*index) <= element_size - 8; (*index) += 8) { + int16x8_t input_val = LoadAndAddOffset(input_data, *index, para.in_args_.zp_); + int32x4_t input_low = vmovl_s16(vget_low_s16(input_val)); + int32x4_t input_high = vmovl_s16(vget_high_s16(input_val)); + + int16x4_t sum_low = ClacSumHalfWord(input_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = ClacSumHalfWord(input_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + } +} +#endif + +int ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + int32_t in_zp = para.in_args_.zp_; + int32_t out_zp = para.out_args_.zp_; + + int index = 0; +#ifdef ENABLE_NEON + SquareInt8NEON(input, output, element_size, para, &index); +#endif + for (; index < element_size; index++) { + const int32_t input_val = input[index] + in_zp; + int32_t output_tmp = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input_val * input_val * (1 << para.shift_left_), para.output_multiplier_), + para.shift_right_); + output_tmp += out_zp; + if (output_tmp > para.output_activation_max_) { + output[index] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[index] = para.output_activation_min_; + } else { + output[index] = static_cast(output_tmp); + } + } + return NNACL_OK; +} + +int ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(((float)(!(bool)(input[i] * in_scale + bias))) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = static_cast(output_tmp); + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.h new file mode 100644 index 0000000000..78c9a42f25 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_self_int8.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_SELF_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_SELF_INT8_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int ElementRound(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementCeil(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementAbs(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementSin(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementCos(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementLog(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementSqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementRsqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_SELF_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.cc new file mode 100644 index 0000000000..a10a0f151b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.h" +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +void BatchToSpaceNoCropForNHWC(const int8_t *input, int8_t *output, const int *in_shape, int out_n, const int *block, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + size_t stride_h = block_w * out_n; + size_t output_offset = 0; + size_t in_stride_h = in_w * in_c; + size_t in_stride_n = in_stride_h * in_h; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + size_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + size_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + size_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} + +void BatchToSpaceForNHWC(const int8_t *input, int8_t *output, const int *in_shape, int out_n, const int *block, + const int *crops, QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_n = in_shape[0]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + size_t stride_h = block_w * out_n; + size_t output_offset = 0; + size_t in_stride_h = in_w * in_c; + size_t in_stride_n = in_stride_h * in_h; + + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + size_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + size_t h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + size_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + size_t w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + size_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.h new file mode 100644 index 0000000000..7846e9a1e6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/batch_to_space_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +void BatchToSpaceNoCropForNHWC(const int8_t *input, int8_t *output, const int *in_shape, int out_n, const int *block, + QuantArg *in_quant_arg, QuantArg *out_quant_arg); +void BatchToSpaceForNHWC(const int8_t *input, int8_t *output, const int *in_shape, int out_n, const int *block, + const int *crops, QuantArg *in_quant_arg, QuantArg *out_quant_arg); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h new file mode 100644 index 0000000000..3439fbe1db --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_COMMON_FUNC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_COMMON_FUNC_H_ + +#include +#include +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_ARM +void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, + size_t oc4, size_t offset); + +#ifdef ENABLE_ARM64 +void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, + size_t shift_after); +#elif defined(ENABLE_ARM32) +void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, + size_t shift_after); +#endif +#endif + +#ifdef ENABLE_ARM +void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, + size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier, + int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_COMMON_FUNC_H_ */ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/concat_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/concat_int8.cc new file mode 100644 index 0000000000..804c80bf93 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/concat_int8.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/concat_int8.h" +#include + +void Concat(int8_t **inputs, int8_t *output_ptr, ConcatQuantArg *quant_concat_parm, int axis) { + float output_scale = quant_concat_parm->out_quant_args_.scale_; + float output_inverse_scale = 1.f / output_scale; + int input_num = quant_concat_parm->input_num_; + int *output_shape = quant_concat_parm->output_shape_; + int output_dim = quant_concat_parm->output_dim_; + QuantArg *input_quant = quant_concat_parm->in_quant_args_; + int output_zp = quant_concat_parm->out_quant_args_.zp_; + + int before_axis_size = 1; + for (int i = 0; i < axis; i++) { + before_axis_size *= output_shape[i]; + } + + int after_axis_size = 1; + for (size_t i = axis + 1; i < output_dim; i++) { + after_axis_size *= output_shape[i]; + } + + for (int k = 0; k < before_axis_size; k++) { + for (int i = 0; i < input_num; i++) { + int *input_shape = quant_concat_parm->input_shapes_[i]; + int copy_size = input_shape[axis] * after_axis_size; + int8_t *input_ptr = inputs[i] + k * copy_size; + if (input_quant[i].scale_ == output_scale && input_quant[i].zp_ == output_zp) { + memcpy(output_ptr, input_ptr, copy_size); + } else { + float scale = input_quant[i].scale_ * output_inverse_scale; + float bias = -input_quant[i].zp_ * scale; + for (int j = 0; j < copy_size; j++) { + int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } + } + output_ptr += copy_size; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/concat_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/concat_int8.h new file mode 100644 index 0000000000..6b7edb43d5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/concat_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONCAT_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONCAT_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Concat(int8_t **inputs, int8_t *output_ptr, ConcatQuantArg *quant_concat_parm, int axis); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONCAT_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc new file mode 100644 index 0000000000..784b754cd8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc @@ -0,0 +1,345 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.h" +#include +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" +#include "src/runtime/kernel/arm/nnacl/int8/common_func.h" + +/*conv depthwise int8 begin*/ +void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w, int out_multiplier, + int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int tmp_buffer[C4NUM]; + for (int i = 0; i < C4NUM; i++) { + tmp_buffer[i] = 0; + } + const int16_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int16_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + dst[c] = static_cast(tmp_buffer[c]); + } +} + +void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + int8_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const int16_t *src_h = src + ih * sliding->in_h_step_; + + int8_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const int16_t *src_w = src_h + iw * sliding->block_channel_; + + const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + + DepthwiseBorderPixelInt8( + dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_, + sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0]); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step, int out_multiplier, int left_shift, + int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int tmp_buffer[C4NUM]; + int8_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int8_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const int16_t *src_kh = src_w; + const int16_t *weight_kh = weight; + + for (int i = 0; i < C4NUM; i++) { + tmp_buffer[i] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const int16_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + // add bias relu + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), + -right_shift); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + dst_w[c] = static_cast(tmp_buffer[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const int16_t *src_data = src + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * C4NUM; + int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * C4NUM; +#ifdef ENABLE_ARM64 + ConvDwInt8Center( + out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t), + sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t), + sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t), + sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0]); +#else + DepthwiseCenterInt8( + out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); +#endif + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*conv depthwise int8 end*/ + +/*deconv depthwise int8 begin*/ +void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + int32_t *dst_kh = dst; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDepthwiseBorderInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const int16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int32_t *dst_h = dst + oh * sliding->in_h_step_; + + const int16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + int32_t *dst_w = dst_h + ow * C4NUM; + + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + int32_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDepthwiseCenterInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step) { + int32_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int32_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + int32_t *dst_kh = dst_w; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDepthwisePostFuncInt8(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, + const ConvParameter *conv_param, int out_multiplier, int left_shift, int right_shift, + int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int8_t *dst_k = dst; + int32_t *buffer_k = output_buffer; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + buffer_k[c] += bias[c]; + buffer_k[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer_k[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + buffer_k[c] += out_zp; + buffer_k[c] = MSMAX(buffer_k[c], acc_min); + buffer_k[c] = MSMIN(buffer_k[c], acc_max); + dst_k[c] = static_cast(buffer_k[c]); + } + dst_k += block_channel; + buffer_k += C4NUM; + } +} + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + int buffer_size = conv_param->output_h_ * conv_param->output_w_ * C4NUM; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + memset(output_buffer, 0, buffer_size * sizeof(int32_t)); + const int16_t *src_data = src + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * C4NUM; + const int16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_ * sizeof(int16_t), sliding->block_channel_ * sizeof(int16_t), + sliding->in_sh_step_ * sizeof(int32_t), sliding->in_sw_step_ * sizeof(int32_t), + sliding->in_kh_step_ * sizeof(int32_t), sliding->in_kw_step_ * sizeof(int32_t)); +#else + DeconvDepthwiseCenterInt8(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, + sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDepthwisePostFuncInt8( + dst_data, output_buffer, bias, sliding->block_channel_, conv_param, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise int8 end*/ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.h new file mode 100644 index 0000000000..555c3693ab --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONV_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONV_DEPTHWISE_H_ + +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" + +void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONV_DEPTHWISE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.cc new file mode 100644 index 0000000000..fa02fd99f1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.cc @@ -0,0 +1,320 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/conv_int8.h" +#include +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" +#include "src/runtime/kernel/arm/nnacl/int8/common_func.h" + +void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param) { + int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0]; + int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0]; + int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; + int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; +#ifdef __aarch64__ + IndirectGemmInt8_4x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), + input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); +#elif defined(ENABLE_ARM32) + IndirectGemmInt8_2x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), + input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); +#else + int tile_num = conv_param->tile_num_; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4_block = oc / C4NUM; + int oc4_res = oc % C4NUM; + int weight_oc4_offset = oc4_block * C4NUM * plane_c4 * C4NUM * ic4 * C4NUM + oc4_res * C4NUM * C4NUM; + int dst_oc_offset = oc; + for (int n = 0; n < tile_num; n++) { + int src_tile_offset = n * C4NUM * C4NUM; + int dst_tile_offset = dst_oc_offset + n * output_channel; + + for (int b = 0; b < kernel_plane; b++) { + int plane_c4_block = b / C4NUM; + int plane_c4_res = b % C4NUM; + int src_plane_offset = src_tile_offset + plane_c4_block * tile_num * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM; + int weight_plane_offset = + weight_oc4_offset + plane_c4_block * tile_num * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM; + for (int i = 0; i < ic4; i++) { + int src_ic4_offset = src_plane_offset + i * tile_num * C4NUM * C4NUM; + int weight_ic4_offset = weight_plane_offset + i * C4NUM * C4NUM * C4NUM; + for (int j = 0; j < C4NUM; j++) { + int weight_ic_offset = weight_ic4_offset + j; + tmp_dst[dst_tile_offset] += weight[weight_ic_offset] * src[src_ic4_offset + j]; + } // in c4num loop + } // ic4 loop + } // kernel_plane loop + tmp_dst[dst_tile_offset] -= input_sum[n]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } // tile_num loop + } // output_channel loop +#endif +} + +void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param, GEMM_FUNC gemm_func) { + int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0]; + int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0]; + int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; + int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; + if (gemm_func != nullptr) { +#ifdef __aarch64__ + gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum, + act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); +#endif + } else { + int tile_num = conv_param->tile_num_; + for (int oc = 0; oc < output_channel; oc++) { + int oc4_block = oc / C4NUM; + int oc4_res = oc % C4NUM; + int weight_oc4_offset = oc4_block * C4NUM * kernel_plane * ic4 * C4NUM + oc4_res * C4NUM; + int dst_oc_offset = oc; + for (int n = 0; n < tile_num; n++) { + int src_tile_offset = n * C4NUM; + int dst_tile_offset = dst_oc_offset + n * output_channel; + + for (int b = 0; b < kernel_plane; b++) { + int src_plane_offset = src_tile_offset + b * tile_num * ic4 * C4NUM; + int weight_plane_offset = weight_oc4_offset + b * C4NUM * ic4 * C4NUM; + for (int i = 0; i < ic4; i++) { + int src_ic4_offset = src_plane_offset + i * tile_num * C4NUM; + int weight_ic4_offset = weight_plane_offset + i * C4NUM * C4NUM; + for (int j = 0; j < C4NUM; j++) { + int weight_ic_offset = weight_ic4_offset + j; + tmp_dst[dst_tile_offset] += weight[weight_ic_offset] * src[src_ic4_offset + j]; + } // in c4num loop + } // ic4 loop + } // kernel_plane loop + tmp_dst[dst_tile_offset] -= input_sum[n]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } // tile_num loop + } // output_channel loop + } +} + +void Conv3x3Uint8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { + int oc4 = UP_DIV(oc, C4NUM); + int input_unit_square = 16; +#ifdef ENABLE_ARM + IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); +#else + for (int c = 0; c < oc4; c++) { + int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; + int dst_oc_offset = c * input_unit_square * C4NUM; + for (int n = 0; n < real_cal_num; n++) { + int src_tile_offset = n * C8NUM; + int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; + for (int i = 0; i < 4; i++) { + int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; + int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; + int dst_h_offset = dst_tile_offset + i * 4 * 4; + for (int m = 0; m < 4; m++) { + int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; + int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; + int dst_w_offset = dst_h_offset + m * C4NUM; + + int32_t acc[4] = {0}; + for (int z = 0; z < 4; z++) { + int filter_offset = filter_w_offset + z; + for (int j = 0; j < ic8; j++) { + int filter_c8_offset = filter_offset + j * 4 * 8; + int src_c8_offset = src_w_offset + j * 8 * 8; + + for (int k = 0; k < 8; k++) { + const int16_t *w_ptr = weight + filter_c8_offset + k * 4; + const int16_t *input_ptr = src + src_c8_offset + k; + acc[z] += w_ptr[0] * input_ptr[0]; + } + } + (dst + dst_w_offset + z)[0] = acc[z]; + } + } + } + } + } +#endif +} + +// int8 conv common +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + + int tile_n = conv_param->tile_num_; + int thread_count = conv_param->thread_num_; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + int32_t *tmp_input_sum = input_sum + thread_id * tile_n; + int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset; + // clear tmp buffer before compute + memset(gemm_input, (int8_t)input_zp, unit_size * tile_n); + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + + size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t); + int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_; + memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size); + + Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); + if (real_cal_num == tile_n) { + int8_t *gemm_output = output_data + out_offset; + IndirectGemmInt8(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, + out_channel, input_sum, conv_param); + } else { + // res part + IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, + out_channel, input_sum, conv_param); + memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); + } + } + } +} + +void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param, GEMM_FUNC gemm_func) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + int tile_n = conv_param->tile_num_; + int thread_count = conv_param->thread_num_; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + // todo + int32_t *tmp_input_sum = input_sum + thread_id * tile_n; + int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset; + // clear tmp buffer before compute + memset(gemm_input, (int8_t)input_zp, unit_size * tile_n); + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + + size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t); + int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_; + memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size); + + Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); + if (real_cal_num == tile_n) { + int8_t *gemm_output = output_data + out_offset; + IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, + kernel_plane, out_channel, input_sum, conv_param, gemm_func); + } else { + // res part + IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, + out_channel, input_sum, conv_param, gemm_func); + memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); + } + } + } +} + +// int8 convolution 3x3 +void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, + int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, + int task_id, ConvParameter *conv_param) { + int thread_count = conv_param->thread_num_; + int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); + int output_batch = conv_param->output_batch_; + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int oc4 = UP_DIV(output_channel, C4NUM); + int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; + int block_unit_buffer_offset = 16 * C8NUM; + int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + + Conv3x3Uint8InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); + + Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, + transed_weight, output_channel, ic8, real_cal_num); + + Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index, + real_cal_num, out_w_block, conv_param); + } + } + + // get real output + PackNC4HW4ToNHWCInt8(tmp_out, output_data, output_batch, output_h * output_w, output_channel); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h new file mode 100644 index 0000000000..96f08c0c6b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_int8.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONV_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONV_INT8_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/winograd_utils.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, + size_t shift_after); + +void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param); + +void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param, GEMM_FUNC gemm_func); + +// int8 conv common +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param); + +void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param, GEMM_FUNC gemm_func); + +// int8 convolution 3x3 +void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, + int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, + int task_id, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CONV_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/crop_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/crop_int8.cc new file mode 100644 index 0000000000..e7708fcc2a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/crop_int8.cc @@ -0,0 +1,222 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/crop_parameter.h" +#include "src/runtime/kernel/arm/nnacl/int8/crop_int8.h" +#include + +void Crop(const int8_t *input, int8_t *output, int task_id, CropParameter *para) { + auto input_dim = para->input_dim_; + switch (input_dim) { + case 1: + Crop1D(input, output, task_id, para); + break; + case 2: + Crop2D(input, output, task_id, para); + break; + case 3: + Crop3D(input, output, task_id, para); + break; + case 4: + Crop4D(input, output, task_id, para); + break; + } +} + +void Crop1D(const int8_t *input, int8_t *output, int task_id, CropParameter *para) { + const int out_batch = para->out_shape_[0]; + const int thread_count = para->thread_count_; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_batch, thread_count) : out_batch; + + float in_scale = para->quant_arg.in_args_.scale_; + int32_t in_zp = para->quant_arg.in_args_.zp_; + float out_scale = para->quant_arg.out_args_.scale_; + int32_t out_zp = para->quant_arg.out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + auto n = task_id * task_id_stride; + if (n >= out_batch) { + return; + } + const int8_t *in_ptr = input + n + para->in_offset_[0]; + int8_t *out_ptr = output + n; + int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride); + if (in_scale == out_scale && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > para->quant_arg.output_activation_max_) { + out_ptr[i] = para->quant_arg.output_activation_max_; + } else if (output_tmp < para->quant_arg.output_activation_min_) { + out_ptr[i] = para->quant_arg.output_activation_min_; + } else { + out_ptr[i] = static_cast(output_tmp); + } + } + } + return; +} + +void Crop2D(const int8_t *input, int8_t *output, int task_id, CropParameter *para) { + const int in_height = para->in_shape_[1]; + const int out_batch = para->out_shape_[0]; + const int out_height = para->out_shape_[1]; + const int thread_count = para->thread_count_; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + + float in_scale = para->quant_arg.in_args_.scale_; + int32_t in_zp = para->quant_arg.in_args_.zp_; + float out_scale = para->quant_arg.out_args_.scale_; + int32_t out_zp = para->quant_arg.out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + auto h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const int8_t *in_ptr = input + (n + para->in_offset_[0]) * in_height + h + para->in_offset_[1]; + int8_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + if (in_scale == out_scale && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > para->quant_arg.output_activation_max_) { + out_ptr[i] = para->quant_arg.output_activation_max_; + } else if (output_tmp < para->quant_arg.output_activation_min_) { + out_ptr[i] = para->quant_arg.output_activation_min_; + } else { + out_ptr[i] = static_cast(output_tmp); + } + } + } + } + return; +} + +void Crop3D(const int8_t *input, int8_t *output, int task_id, CropParameter *para) { + const int in_height = para->in_shape_[1]; + const int in_width = para->in_shape_[2]; + + const int out_batch = para->out_shape_[0]; + const int out_height = para->out_shape_[1]; + const int out_width = para->out_shape_[2]; + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = para->quant_arg.in_args_.scale_; + int32_t in_zp = para->quant_arg.in_args_.zp_; + float out_scale = para->quant_arg.out_args_.scale_; + int32_t out_zp = para->quant_arg.out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + const int thread_count = para->thread_count_; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + auto h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const int8_t *in_ptr = + input + (n + para->in_offset_[0]) * in_stride_n + (h + para->in_offset_[1]) * in_stride_h + para->in_offset_[2]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + if (in_scale == out_scale && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_width); + } else { + for (int i = 0; i < out_width; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > para->quant_arg.output_activation_max_) { + out_ptr[i] = para->quant_arg.output_activation_max_; + } else if (output_tmp < para->quant_arg.output_activation_min_) { + out_ptr[i] = para->quant_arg.output_activation_min_; + } else { + out_ptr[i] = static_cast(output_tmp); + } + } + } + } + } + return; +} + +void Crop4D(const int8_t *input, int8_t *output, int task_id, CropParameter *para) { + const int in_height = para->in_shape_[1]; + const int in_width = para->in_shape_[2]; + const int in_channel = para->in_shape_[3]; + + const int out_batch = para->out_shape_[0]; + const int out_height = para->out_shape_[1]; + const int out_width = para->out_shape_[2]; + const int out_channel = para->out_shape_[3]; + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = para->quant_arg.in_args_.scale_; + int32_t in_zp = para->quant_arg.in_args_.zp_; + float out_scale = para->quant_arg.out_args_.scale_; + int32_t out_zp = para->quant_arg.out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + const int thread_count = para->thread_count_; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + auto h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const int8_t *in_ptr = input + (n + para->in_offset_[0]) * in_stride_n + + (h + para->in_offset_[1]) * in_stride_h + (w + para->in_offset_[2]) * in_stride_w + + para->in_offset_[3]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + if (in_scale == out_scale && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_channel); + } else { + for (int i = 0; i < out_channel; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > para->quant_arg.output_activation_max_) { + out_ptr[i] = para->quant_arg.output_activation_max_; + } else if (output_tmp < para->quant_arg.output_activation_min_) { + out_ptr[i] = para->quant_arg.output_activation_min_; + } else { + out_ptr[i] = static_cast(output_tmp); + } + } + } + } + } + } + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/crop_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/crop_int8.h new file mode 100644 index 0000000000..d574b0813b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/crop_int8.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CROP_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CROP_INT8_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/crop_parameter.h" + +void Crop(const int8_t *input, int8_t *output, int task_id, CropParameter *para); +void Crop1D(const int8_t *input, int8_t *output, int task_id, CropParameter *para); +void Crop2D(const int8_t *input, int8_t *output, int task_id, CropParameter *para); +void Crop3D(const int8_t *input, int8_t *output, int task_id, CropParameter *para); +void Crop4D(const int8_t *input, int8_t *output, int task_id, CropParameter *para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_CROP_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.cc new file mode 100644 index 0000000000..32586063a3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/deconv.h" + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, + ConvParameter *conv_param) { + MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.quant_args_[0][0].zp_, + conv_param->conv_quant_arg_.quant_args_[1][0].zp_); + return NNACL_OK; +} + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param) { + /* row8x8-major(ih*iw x oc*kh*kw) -> row8x8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_DIV(output_channel, C8NUM); + int in_plane8 = UP_ROUND(input_plane, 8); + + for (int c = 0; c < oc8; c++) { + int32_t *dst_ptr = tmp + c * output_plane * C8NUM; + const int32_t *src_ptr = src + c * in_plane8 * kernel_plane * C8NUM; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * conv_param->input_w_ * C8NUM + iw * C8NUM + + kh * input_plane * conv_param->kernel_w_ * C8NUM + kw * input_plane * C8NUM; + int dst_index = oh * conv_param->output_w_ * C8NUM + ow * C8NUM + + kh * conv_param->dilation_h_ * conv_param->output_w_ * C8NUM + + kw * conv_param->dilation_w_ * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_ptr[dst_index + i] += src_ptr[src_index + i]; + } + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8), + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h new file mode 100644 index 0000000000..145a0c1ff6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DECONV_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DECONV_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" +#include "src/runtime/kernel/arm/nnacl/int8/matmul.h" + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, + ConvParameter *conv_param); + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DECONV_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.cc new file mode 100644 index 0000000000..4ee38a88c4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.h" +#include + +void DepthToSpaceForNHWC(const int8_t *input, int8_t *output, int *in_shape, DepthToSpaceParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = block_size * param->out_stride_dim2_; + float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_offset_n = i * param->in_stride_dim0_; + size_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + size_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + size_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + size_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + size_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + size_t out_offset = out_offset_w + l * param->out_stride_dim1_; + size_t in_offset = in_offset_w + l * block_size * param->out_stride_dim2_; + for (int m = 0; m < copy_size; ++m) { + int32_t output_tmp = round(input[in_offset + m] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[out_offset + m] = output_tmp; + } + } + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.h new file mode 100644 index 0000000000..b2abb1f068 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/depth_to_space_parameter.h" +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +void DepthToSpaceForNHWC(const int8_t *input, int8_t *output, int *in_shape, DepthToSpaceParameter *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/hswish_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/hswish_int8.cc new file mode 100644 index 0000000000..3e048590c8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/hswish_int8.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/hswish_int8.h" + +int16_t SaturatingLeftShift(int16_t value, int shift_num) { + int32_t result = (int32_t)value * (1 << shift_num); + return MSMAX(MSMIN(result, SHRT_MAX), SHRT_MIN); +} + +int HSwishInt8(const int8_t *src, int length, int8_t *dst, HswishQuantArg *arg) { + for (int i = 0; i < length; i++) { + const int16_t input_value = src[i] - arg->input_zp; + const int16_t input_value_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + SaturatingRoundingDoublingHighMulInt16(input_value_scale, arg->output_multiplier_fixedpoint_int16); + int16_t relu6_value = input_value_scale; + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, arg->relu6_multiplier_exponent - 1); + } + relu6_value = SaturatingRoundingDoublingHighMulInt16(relu6_value, arg->relu6_multiplier_fixedpoint_int16); + + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, 1); + } + if (arg->relu6_multiplier_exponent < 0) { + relu6_value = RoundingDivideByPOT(relu6_value, -arg->relu6_multiplier_exponent); + } + relu6_value = (relu6_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + SaturatingRoundingDoublingHighMulInt16(relu6_value, input_value_on_preshift_output_scale); + + int16_t output = RoundingDivideByPOT(preshift_output_value, -arg->output_multiplier_exponent); + output += arg->output_zp; + output = MSMIN(output, 127); + output = MSMAX(output, -128); + dst[i] = (int8_t)output; + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/hswish_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/hswish_int8.h new file mode 100644 index 0000000000..2d32aee083 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/hswish_int8.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_HSWISH_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_HSWISH_INT8_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +struct HswishQuantArg { + double input_scale; + int32_t input_zp; + double output_scale; + int32_t output_zp; + int16_t relu6_multiplier_fixedpoint_int16; + int32_t relu6_multiplier_exponent; + int16_t output_multiplier_fixedpoint_int16; + int32_t output_multiplier_exponent; +}; + +int HSwishInt8(const int8_t *src, int length, int8_t *dst, HswishQuantArg *arg); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_HSWISH_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul.cc new file mode 100644 index 0000000000..972dfd46c5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/matmul.h" +#include +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + int8_t *src = src_ptr + r * col; + for (int c = 0; c < col; c++) { + int cd8 = c / 8; + int cm8 = c % 8; + dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c]; + } + } +} + +void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + int rd8 = r / 8; + int rm8 = r % 8; + for (int c = 0; c < col; c++) { + dst_ptr[rd8 * col * 8 + c * 8 + rm8] = src_ptr[r * col + c]; + } + } +} + +void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, + const int32_t a_zp, const int32_t b_zp) { + /* col8-major * row8-major => row8x8-major */ + for (int row = 0; row < row8; row++) { + for (int col = 0; col < col8; col++) { + int r8div = row / 8, r8mod = row % 8; + int c8div = col / 8, c8mod = col % 8; + size_t ci = c8div * row8 * 8 + row * 8 + c8mod; + int32_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r8div * deep * 8 + d * 8 + r8mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp); + } + c[ci] = value; + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul.h new file mode 100644 index 0000000000..7e2cc43ca9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" + +void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, + const int32_t a_zp, const int32_t b_zp); +void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_INT8_MATMUL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/mul_int8.cc new file mode 100644 index 0000000000..087e6b0f9d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/mul_int8.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/mul_int8.h" +#include "src/runtime/kernel/arm/nnacl/mul_parameter.h" +#ifdef ENABLE_NEON +#include +#include "src/runtime/kernel/arm/nnacl/add_int8.h" +#endif +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, MulQuantArg para) { + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + para.shift_right_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_quant_arg_.zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void MulInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + MulQuantArg para, int *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); + + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para.in_quant_args_[0].zp_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para.in_quant_args_[1].zp_); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int16x4_t sum_low = + ClacSumHalfWord(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + ClacSumHalfWord(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + } +} +#endif + +void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para) { + int index = 0; +#ifdef ENABLE_NEON + MulInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para.in_quant_args_[0].zp_ + input0_data[index]; + const int32_t input1_val = para.in_quant_args_[1].zp_ + input1_data[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << para.shift_left_), + para.output_multiplier_), para.shift_right_); + + mul_result += para.out_quant_arg_.zp_; + + if (mul_result > para.output_activation_max_) { + output_data[index] = para.output_activation_max_; + } else if (mul_result < para.output_activation_min_) { + output_data[index] = para.output_activation_min_; + } else { + output_data[index] = static_cast(mul_result); + } + } + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/mul_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/mul_int8.h new file mode 100644 index 0000000000..000fa37b8e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/mul_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MUL_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MUL_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/mul_parameter.h" + +void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MUL_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pad.cc new file mode 100644 index 0000000000..55c6d93404 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pad.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/pad.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" + +void PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings) { + int32_t copy_size = in_dims[3]; + for (int n = 0; n < in_dims[0]; n++) { + for (int h = 0; h < in_dims[1]; h++) { + for (int w = 0; w < in_dims[2]; w++) { + const int8_t *in = in_data + offset(in_dims, n, h, w, 0); + int8_t *out = out_data + offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]); + memcpy(out, in, copy_size * sizeof(int8_t)); + } + } + } + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pad.h new file mode 100644 index 0000000000..336af7f852 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pad.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_PAD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_PAD_INT8_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/pad_parameter.h" + +void PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_PAD_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pooling_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pooling_int8.cc new file mode 100644 index 0000000000..854e7e7fc3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pooling_int8.cc @@ -0,0 +1,378 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" + +void AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + float input_scale = pooling_param->quant_args_[0][0].scale_; + int input_zp = pooling_param->quant_args_[0][0].zp_; + float output_scale = pooling_param->quant_args_[1][0].scale_; + int output_zp = pooling_param->quant_args_[1][0].zp_; + double real_multiplier = input_scale / output_scale; + int8_t out_min = INT8_MIN; + int8_t out_max = INT8_MAX; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int16_t tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int8_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = real_out; + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c8 = UP_DIV(channel, C8NUM); + int8_t out_min = INT8_MIN; + int8_t out_max = INT8_MAX; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c8 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C8NUM; + int out_channel_offset = out_plane_offset + j * C8NUM; + int16_t tmp_avg1 = 0; + int16_t tmp_avg2 = 0; + int16_t tmp_avg3 = 0; + int16_t tmp_avg4 = 0; + int16_t tmp_avg5 = 0; + int16_t tmp_avg6 = 0; + int16_t tmp_avg7 = 0; + int16_t tmp_avg8 = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg1 += *(input_ptr + in_offset); + tmp_avg2 += *(input_ptr + in_offset + 1); + tmp_avg3 += *(input_ptr + in_offset + 2); + tmp_avg4 += *(input_ptr + in_offset + 3); + tmp_avg5 += *(input_ptr + in_offset + 4); + tmp_avg6 += *(input_ptr + in_offset + 5); + tmp_avg7 += *(input_ptr + in_offset + 6); + tmp_avg8 += *(input_ptr + in_offset + 7); + ++real_count; + } + } // win_w loop + } // win_h loop + int16_t tmp_out1 = round((float)tmp_avg1 / (float)real_count); + int16_t tmp_out2 = round((float)tmp_avg2 / (float)real_count); + int16_t tmp_out3 = round((float)tmp_avg3 / (float)real_count); + int16_t tmp_out4 = round((float)tmp_avg4 / (float)real_count); + int16_t tmp_out5 = round((float)tmp_avg5 / (float)real_count); + int16_t tmp_out6 = round((float)tmp_avg6 / (float)real_count); + int16_t tmp_out7 = round((float)tmp_avg7 / (float)real_count); + int16_t tmp_out8 = round((float)tmp_avg8 / (float)real_count); + int16_t real_out1 = tmp_out1 < out_min ? out_min : tmp_out1; + int16_t real_out2 = tmp_out2 < out_min ? out_min : tmp_out2; + int16_t real_out3 = tmp_out3 < out_min ? out_min : tmp_out3; + int16_t real_out4 = tmp_out4 < out_min ? out_min : tmp_out4; + int16_t real_out5 = tmp_out5 < out_min ? out_min : tmp_out5; + int16_t real_out6 = tmp_out6 < out_min ? out_min : tmp_out6; + int16_t real_out7 = tmp_out7 < out_min ? out_min : tmp_out7; + int16_t real_out8 = tmp_out8 < out_min ? out_min : tmp_out8; + real_out1 = real_out1 > out_max ? out_max : real_out1; + real_out2 = real_out2 > out_max ? out_max : real_out2; + real_out3 = real_out3 > out_max ? out_max : real_out3; + real_out4 = real_out4 > out_max ? out_max : real_out4; + real_out5 = real_out5 > out_max ? out_max : real_out5; + real_out6 = real_out6 > out_max ? out_max : real_out6; + real_out7 = real_out7 > out_max ? out_max : real_out7; + real_out8 = real_out8 > out_max ? out_max : real_out8; + *(output_ptr + out_channel_offset) = (int8_t)real_out1; + *(output_ptr + out_channel_offset + 1) = (int8_t)real_out2; + *(output_ptr + out_channel_offset + 2) = (int8_t)real_out3; + *(output_ptr + out_channel_offset + 3) = (int8_t)real_out4; + *(output_ptr + out_channel_offset + 4) = (int8_t)real_out5; + *(output_ptr + out_channel_offset + 5) = (int8_t)real_out6; + *(output_ptr + out_channel_offset + 6) = (int8_t)real_out7; + *(output_ptr + out_channel_offset + 7) = (int8_t)real_out8; + } // in_channel loop + int channel_s = (c8 - 1) * C8NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int16_t tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + int16_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = (int8_t)real_out; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + // input channel is equal to output channel + float input_scale = pooling_param->quant_args_[0][0].scale_; + int input_zp = pooling_param->quant_args_[0][0].zp_; + float output_scale = pooling_param->quant_args_[1][0].scale_; + int output_zp = pooling_param->quant_args_[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c16 = UP_DIV(channel, 16); + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c16 - 1; j++) { + int in_channel_offset = in_batch_offset + j * 16; + int out_channel_offset = out_plane_offset + j * 16; +#ifdef ENABLE_NEON + int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); +#else + int8_t tmp_max1 = INT8_MIN; + int8_t tmp_max2 = INT8_MIN; + int8_t tmp_max3 = INT8_MIN; + int8_t tmp_max4 = INT8_MIN; + int8_t tmp_max5 = INT8_MIN; + int8_t tmp_max6 = INT8_MIN; + int8_t tmp_max7 = INT8_MIN; + int8_t tmp_max8 = INT8_MIN; + int8_t tmp_max9 = INT8_MIN; + int8_t tmp_max10 = INT8_MIN; + int8_t tmp_max11 = INT8_MIN; + int8_t tmp_max12 = INT8_MIN; + int8_t tmp_max13 = INT8_MIN; + int8_t tmp_max14 = INT8_MIN; + int8_t tmp_max15 = INT8_MIN; + int8_t tmp_max16 = INT8_MIN; +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); +#else + tmp_max1 = MaxInt8(tmp_max1, *(input_ptr + in_offset)); + tmp_max2 = MaxInt8(tmp_max2, *(input_ptr + in_offset + 1)); + tmp_max3 = MaxInt8(tmp_max3, *(input_ptr + in_offset + 2)); + tmp_max4 = MaxInt8(tmp_max4, *(input_ptr + in_offset + 3)); + tmp_max5 = MaxInt8(tmp_max5, *(input_ptr + in_offset + 4)); + tmp_max6 = MaxInt8(tmp_max6, *(input_ptr + in_offset + 5)); + tmp_max7 = MaxInt8(tmp_max7, *(input_ptr + in_offset + 6)); + tmp_max8 = MaxInt8(tmp_max8, *(input_ptr + in_offset + 7)); + tmp_max9 = MaxInt8(tmp_max9, *(input_ptr + in_offset + 8)); + tmp_max10 = MaxInt8(tmp_max10, *(input_ptr + in_offset + 9)); + tmp_max11 = MaxInt8(tmp_max11, *(input_ptr + in_offset + 10)); + tmp_max12 = MaxInt8(tmp_max12, *(input_ptr + in_offset + 11)); + tmp_max13 = MaxInt8(tmp_max13, *(input_ptr + in_offset + 12)); + tmp_max14 = MaxInt8(tmp_max14, *(input_ptr + in_offset + 13)); + tmp_max15 = MaxInt8(tmp_max15, *(input_ptr + in_offset + 14)); + tmp_max16 = MaxInt8(tmp_max16, *(input_ptr + in_offset + 15)); +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + vst1q_s8(output_ptr + out_channel_offset, tmp_max); +#else + *(output_ptr + out_channel_offset) = tmp_max1; + *(output_ptr + out_channel_offset + 1) = tmp_max2; + *(output_ptr + out_channel_offset + 2) = tmp_max3; + *(output_ptr + out_channel_offset + 3) = tmp_max4; + *(output_ptr + out_channel_offset + 4) = tmp_max5; + *(output_ptr + out_channel_offset + 5) = tmp_max6; + *(output_ptr + out_channel_offset + 6) = tmp_max7; + *(output_ptr + out_channel_offset + 7) = tmp_max8; + *(output_ptr + out_channel_offset + 8) = tmp_max9; + *(output_ptr + out_channel_offset + 9) = tmp_max10; + *(output_ptr + out_channel_offset + 10) = tmp_max11; + *(output_ptr + out_channel_offset + 11) = tmp_max12; + *(output_ptr + out_channel_offset + 12) = tmp_max13; + *(output_ptr + out_channel_offset + 13) = tmp_max14; + *(output_ptr + out_channel_offset + 14) = tmp_max15; + *(output_ptr + out_channel_offset + 15) = tmp_max16; +#endif + } // in_channel loop + int channel_s = (c16 - 1) * 16; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_max; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pooling_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pooling_int8.h new file mode 100644 index 0000000000..007b60f32d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/pooling_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_POOLING_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_POOLING_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" + +void AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_POOLING_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.cc new file mode 100644 index 0000000000..de24938986 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == nullptr || real_values == nullptr) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (quant_values[i] + zp) * scale; + } + return NNACL_OK; +} + +int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == nullptr || real_values == nullptr) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + quant_values[i] = (int8_t)round(real_values[i] / scale + zp); + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h new file mode 100644 index 0000000000..d11ae2c5f6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_QUANTDTYPECAST_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_QUANTDTYPECAST_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct QuantDTypeCastParameter { + OpParameter op_parameter_; + int32_t srcT; + int32_t dstT; +}; + +int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_QUANTDTYPECAST_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relu_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relu_int8.h new file mode 100644 index 0000000000..88d5adcb13 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/relu_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RELU_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RELU_INT8_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +struct ReluQuantArg { + QuantArg input_arg; + QuantArg output_arg; + int input_multiplier_; + int left_shift_; + int right_shift_; +}; + +inline void ReluInt8(const int8_t *src, int length, int8_t *dst, ReluQuantArg *arg) { + for (int i = 0; i < length; ++i) { + if (src[i] <= arg->input_arg.zp_) { + dst[i] = arg->output_arg.zp_; + continue; + } + const int32_t input_val = src[i] - arg->input_arg.zp_; + const int32_t scaled_input = SaturatingRoundingDoublingHighMul(input_val, arg->input_multiplier_); + const int32_t shifted_input = RoundingDivideByPOT(scaled_input * (1 << arg->left_shift_), -arg->right_shift_); + const int32_t output = shifted_input + arg->output_arg.zp_; + dst[i] = (int8_t)output; + } +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RELU_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/reshape_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/reshape_int8.cc new file mode 100644 index 0000000000..3993e3d3d2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/reshape_int8.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/reshape_int8.h" +#include + +void Reshape(int8_t *input_ptr, int8_t *output_ptr, size_t data_size, int input_num, QuantArg in_quant_arg, + QuantArg out_quant_arg) { + if (in_quant_arg.scale_ == out_quant_arg.scale_ && in_quant_arg.zp_ == out_quant_arg.zp_) { + memcpy(output_ptr, input_ptr, data_size); + } else { + float output_inverse_scale = 1.f / out_quant_arg.scale_; + float scale = in_quant_arg.scale_ * output_inverse_scale; + float bias = -in_quant_arg.zp_ * scale; + int32_t output_zp = out_quant_arg.zp_; + for (int i = 0; i < input_num; i++) { + int32_t output_tmp = round(input_ptr[i] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[i] = 127; + } else if (output_tmp < -128) { + output_ptr[i] = -128; + } else { + output_ptr[i] = (int8_t)output_tmp; + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/reshape_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/reshape_int8.h new file mode 100644 index 0000000000..129178086a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/reshape_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RESHAHPE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RESHAHPE_INT8_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Reshape(int8_t *input_ptr, int8_t *output_ptr, size_t data_size, int input_num, QuantArg in_quant_arg, + QuantArg out_quant_arg); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_RESHAHPE_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.cc new file mode 100644 index 0000000000..0428717522 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/softmax_int8.h" +#include + +int Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data, + SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) { + int32_t axis = parameter->axis_; + int n_dim = parameter->n_dim_; + int *input_shape = parameter->input_shape_; + int axis_shape_size = input_shape[axis]; + + double output_scale = quant_param.out_quant_arg_.scale_; + int32_t output_zp = quant_param.out_quant_arg_.zp_; + + int inner_size = 1; + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int o = 0; o < count; o++) { + int outter_offset = o * axis_shape_size * inner_size; + for (int i = 0; i < inner_size; i++) { + float sum = 0; + for (int j = 0; j < axis_shape_size; j++) { + int axis_offset = outter_offset + i + j * inner_size; + sum += exp_data[axis_offset]; + } + sum_data[i] = sum; + } + for (int j = 0; j < axis_shape_size; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int i = 0; i < inner_size; i++) { + int inner_offset = axis_offset + i; + float real_output = exp_data[inner_offset] / sum_data[i]; + int32_t output_scaled = round(real_output / output_scale) + output_zp; + output_ptr[inner_offset] = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, output_scaled)); + } + } + } + return 0; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h new file mode 100644 index 0000000000..546f9e940a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/softmax_int8.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SOFTMAX_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SOFTMAX_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" + +int Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data, + SoftmaxQuantArg quant_param, SoftmaxParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SOFTMAX_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/split_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/split_int8.cc new file mode 100644 index 0000000000..b4d6602124 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/split_int8.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/split_int8.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int DoSplit(int8_t *in_data, int8_t **out_data, const int *input_shape, int offset, int num_unit, + SplitParameter *param) { + if (in_data == nullptr || out_data == nullptr) { + return NNACL_ERR; + } + int num_split = param->num_split_; + int *split_sizes = param->split_sizes_; + int *strides = param->strides_; + int split_dim = param->split_dim_; + int in_stride = strides[split_dim]; + + int stride_per_split = in_stride * input_shape[split_dim]; + int split_which = offset % num_split; + int split_times = offset / num_split; + int8_t *src = in_data + split_times * stride_per_split; + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride; + } + + QuantArg in_quant_arg = param->quant_arg_.in_args_; + float in_scale = in_quant_arg.scale_; + int32_t in_zp = in_quant_arg.zp_; + QuantArg *out_quant_arg = param->quant_arg_.out_args_; + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int copy_size = split_sizes[split_which] * in_stride; + int8_t *dst = out_data[split_which] + split_times * copy_size; + float out_scale = out_quant_arg[split_which].scale_; + int32_t out_zp = out_quant_arg[split_which].zp_; + if (in_scale == out_scale && in_zp == out_zp) { + (void)memcpy(dst, src, copy_size * sizeof(int8_t)); + } else { + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + for (int j = 0; j < copy_size; j++) { + int32_t output_tmp = round(src[j] * scale + bias) + out_zp; + if (output_tmp > param->quant_arg_.output_activation_max_) { + dst[j] = param->quant_arg_.output_activation_max_; + } else if (output_tmp < param->quant_arg_.output_activation_min_) { + dst[j] = param->quant_arg_.output_activation_min_; + } else { + dst[j] = static_cast(output_tmp); + } + } + } + src += copy_size; + } + + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/split_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/split_int8.h new file mode 100644 index 0000000000..2c269af394 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/split_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SPLIT_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SPLIT_INT8_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" + +int DoSplit(int8_t *in_data, int8_t **out_data, const int *input_shape, int offset, int num_unit, + SplitParameter *split_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_SPLIT_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/topk_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/topk_int8.cc new file mode 100644 index 0000000000..3cb482177f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/topk_int8.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/int8/topk_int8.h" + +int DescendCmpInt8(const void *a, const void *b) { + return ((const TopkNodeInt8 *)b)->element - ((const TopkNodeInt8 *)a)->element; +} + +int AscendCmpInt8(const void *a, const void *b) { + return ((const TopkNodeInt8 *)a)->element - ((const TopkNodeInt8 *)b)->element; +} + +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter) { + int last_dim_size = parameter->last_dim_size_; + int loop_num = parameter->loop_num_; + int k = parameter->k_; + TopkNodeInt8 *top_map = (TopkNodeInt8 *)parameter->topk_node_list_; + + int8_t *cur_input_data = input_data; + int8_t *cur_output_data = output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < loop_num; i++) { + for (int j = 0; j < last_dim_size; j++) { + top_map[j].element = *(cur_input_data + j); + top_map[j].index = j; + } + if (parameter->sorted_) { + qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmpInt8); + } else { + qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmpInt8); + } + for (int m = 0; m < k; m++) { + cur_output_data[m] = top_map[m].element; + cur_output_index[m] = top_map[m].index; + } + cur_input_data += last_dim_size; + cur_output_data += k; + cur_output_index += k; + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/topk_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/topk_int8.h new file mode 100644 index 0000000000..5b7bbc554f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/topk_int8.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_TOPK_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_TOPK_INT8_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/fp32/topk.h" + +struct TopkNodeInt8 { + int8_t element; + int32_t index; +}; + +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_TOPK_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul.h new file mode 100644 index 0000000000..058f0371ab --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MATMUL_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +enum ActType { ActType_No, ActType_Relu, ActType_Relu6 }; + +struct MatMulParameter { + OpParameter op_parameter_; + int row_; + int col_; + int row_8_; + int col_8_; + int deep_; + bool has_bias_; + int batch; + bool a_transpose_; /* false : row-major */ + bool b_transpose_; /* true : col-major */ + ActType act_type_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MATMUL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/matrix_table.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matrix_table.h new file mode 100644 index 0000000000..3a21e6c68c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matrix_table.h @@ -0,0 +1,512 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MATRIX_TABLE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MATRIX_TABLE_H_ + +inline void MatrixG4x2(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 0.5f; + matrix_data[4] = 1.0f; + matrix_data[5] = -0.5f; + matrix_data[6] = 0.0f; + matrix_data[7] = 1.0f; +} + +inline void MatrixGT2x4(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 0.5f; + matrix_data[6] = -0.5f; + matrix_data[7] = 1.0f; +} + +inline void MatrixG8x2(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 0.5f; + matrix_data[4] = 1.0f; + matrix_data[5] = -0.5f; + matrix_data[6] = 1.0f; + matrix_data[7] = 1.0f; + matrix_data[8] = 1.0f; + matrix_data[9] = -1.0f; + matrix_data[10] = 1.0f; + matrix_data[11] = 1.5f; + matrix_data[12] = 1.0f; + matrix_data[13] = -1.5f; + matrix_data[14] = 0.0f; + matrix_data[15] = 1.0f; +} + +inline void MatrixGT2x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.5f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 1.0f; +} + +inline void MatrixG8x3(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 0.5f; + matrix_data[5] = 0.25f; + matrix_data[6] = 1.0f; + matrix_data[7] = -0.5f; + matrix_data[8] = 0.25f; + matrix_data[9] = 1.0f; + matrix_data[10] = 1.0f; + matrix_data[11] = 1.0f; + matrix_data[12] = 1.0f; + matrix_data[13] = -1.0f; + matrix_data[14] = 1.0f; + matrix_data[15] = 1.0f; + matrix_data[16] = 1.5f; + matrix_data[17] = 2.25f; + matrix_data[18] = 1.0f; + matrix_data[19] = -1.5f; + matrix_data[20] = 2.25f; + matrix_data[21] = 0.0f; + matrix_data[22] = 0.0f; + matrix_data[23] = 1.0f; +} + +inline void MatrixGT3x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 1.0f; +} + +inline void MatrixG8x4(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 0.5f; + matrix_data[6] = 0.25f; + matrix_data[7] = 0.125f; + matrix_data[8] = 1.0f; + matrix_data[9] = -0.5f; + matrix_data[10] = 0.25f; + matrix_data[11] = -0.125f; + matrix_data[12] = 1.0f; + matrix_data[13] = 1.0f; + matrix_data[14] = 1.0f; + matrix_data[15] = 1.0f; + matrix_data[16] = 1.0f; + matrix_data[17] = -1.0f; + matrix_data[18] = 1.0f; + matrix_data[19] = -1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 1.5f; + matrix_data[22] = 2.25f; + matrix_data[23] = 3.375f; + matrix_data[24] = 1.0f; + matrix_data[25] = -1.5f; + matrix_data[26] = 2.25f; + matrix_data[27] = -3.375f; + matrix_data[28] = 0.0f; + matrix_data[29] = 0.0f; + matrix_data[30] = 0.0f; + matrix_data[31] = 1.0f; +} + +inline void MatrixGT4x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 1.0f; +} + +inline void MatrixG8x5(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 0.5f; + matrix_data[7] = 0.25f; + matrix_data[8] = 0.125f; + matrix_data[9] = 0.0625f; + matrix_data[10] = 1.0f; + matrix_data[11] = -0.5f; + matrix_data[12] = 0.25f; + matrix_data[13] = -0.125f; + matrix_data[14] = 0.0625f; + matrix_data[15] = 1.0f; + matrix_data[16] = 1.0f; + matrix_data[17] = 1.0f; + matrix_data[18] = 1.0f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = -1.0f; + matrix_data[22] = 1.0f; + matrix_data[23] = -1.0f; + matrix_data[24] = 1.0f; + matrix_data[25] = 1.0f; + matrix_data[26] = 1.5f; + matrix_data[27] = 2.25f; + matrix_data[28] = 3.375f; + matrix_data[29] = 5.0625f; + matrix_data[30] = 1.0f; + matrix_data[31] = -1.5f; + matrix_data[32] = 2.25f; + matrix_data[33] = -3.375f; + matrix_data[34] = 5.0625f; + matrix_data[35] = 0.0f; + matrix_data[36] = 0.0f; + matrix_data[37] = 0.0f; + matrix_data[38] = 0.0f; + matrix_data[39] = 1.0f; +} + +inline void MatrixGT5x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 0.0f; + matrix_data[32] = 0.0f; + matrix_data[33] = 0.0625f; + matrix_data[34] = 0.0625f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.0f; + matrix_data[37] = 5.0625f; + matrix_data[38] = 5.0625f; + matrix_data[39] = 1.0f; +} + +inline void MatrixG8x6(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 0.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.5f; + matrix_data[8] = 0.25f; + matrix_data[9] = 0.125f; + matrix_data[10] = 0.0625f; + matrix_data[11] = 0.03125f; + matrix_data[12] = 1.0f; + matrix_data[13] = -0.5f; + matrix_data[14] = 0.25f; + matrix_data[15] = -0.125f; + matrix_data[16] = 0.0625f; + matrix_data[17] = -0.03125f; + matrix_data[18] = 1.0f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 1.0f; + matrix_data[22] = 1.0f; + matrix_data[23] = 1.0f; + matrix_data[24] = 1.0f; + matrix_data[25] = -1.0f; + matrix_data[26] = 1.0f; + matrix_data[27] = -1.0f; + matrix_data[28] = 1.0f; + matrix_data[29] = -1.0f; + matrix_data[30] = 1.0f; + matrix_data[31] = 1.5f; + matrix_data[32] = 2.25f; + matrix_data[33] = 3.375f; + matrix_data[34] = 5.0625f; + matrix_data[35] = 7.59375f; + matrix_data[36] = 1.0f; + matrix_data[37] = -1.5f; + matrix_data[38] = 2.25f; + matrix_data[39] = -3.375f; + matrix_data[40] = 5.0625f; + matrix_data[41] = -7.59375f; + matrix_data[42] = 0.0f; + matrix_data[43] = 0.0f; + matrix_data[44] = 0.0f; + matrix_data[45] = 0.0f; + matrix_data[46] = 0.0f; + matrix_data[47] = 1.0f; +} + +inline void MatrixGT6x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 0.0f; + matrix_data[32] = 0.0f; + matrix_data[33] = 0.0625f; + matrix_data[34] = 0.0625f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.0f; + matrix_data[37] = 5.0625f; + matrix_data[38] = 5.0625f; + matrix_data[39] = 0.0f; + matrix_data[40] = 0.0; + matrix_data[41] = 0.03125f; + matrix_data[42] = -0.03125f; + matrix_data[43] = 1.0f; + matrix_data[44] = -1.0f; + matrix_data[45] = 7.59375f; + matrix_data[46] = -7.59375f; + matrix_data[47] = 0.0f; + matrix_data[48] = 1.0f; +} + +inline void MatrixG8x7(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 0.0f; + matrix_data[6] = 0.0f; + matrix_data[7] = 1.0f; + matrix_data[8] = 0.5f; + matrix_data[9] = 0.25f; + matrix_data[10] = 0.125f; + matrix_data[11] = 0.0625f; + matrix_data[12] = 0.03125f; + matrix_data[13] = 0.015625f; + matrix_data[14] = 1.0f; + matrix_data[15] = -0.5f; + matrix_data[16] = 0.25f; + matrix_data[17] = -0.125f; + matrix_data[18] = 0.0625f; + matrix_data[19] = -0.03125f; + matrix_data[20] = 0.015625f; + matrix_data[21] = 1.0f; + matrix_data[22] = 1.0f; + matrix_data[23] = 1.0f; + matrix_data[24] = 1.0f; + matrix_data[25] = 1.0f; + matrix_data[26] = 1.0f; + matrix_data[27] = 1.0f; + matrix_data[28] = 1.0f; + matrix_data[29] = -1.0f; + matrix_data[30] = 1.0f; + matrix_data[31] = -1.0f; + matrix_data[32] = 1.0f; + matrix_data[33] = -1.0f; + matrix_data[34] = 1.0f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.5f; + matrix_data[37] = 2.25f; + matrix_data[38] = 3.375f; + matrix_data[39] = 5.0625f; + matrix_data[40] = 7.59375f; + matrix_data[41] = 11.390625f; + matrix_data[42] = 1.0f; + matrix_data[43] = -1.5f; + matrix_data[44] = 2.25f; + matrix_data[45] = -3.375f; + matrix_data[46] = 5.0625f; + matrix_data[47] = -7.59375f; + matrix_data[48] = 11.390625f; + matrix_data[49] = 0.0f; + matrix_data[50] = 0.0f; + matrix_data[51] = 0.0f; + matrix_data[52] = 0.0f; + matrix_data[53] = 0.0f; + matrix_data[54] = 0.0f; + matrix_data[55] = 1.0f; +} + +inline void MatrixGT7x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 0.0f; + matrix_data[32] = 0.0f; + matrix_data[33] = 0.0625f; + matrix_data[34] = 0.0625f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.0f; + matrix_data[37] = 5.0625f; + matrix_data[38] = 5.0625f; + matrix_data[39] = 0.0f; + matrix_data[40] = 0.0; + matrix_data[41] = 0.03125f; + matrix_data[42] = -0.03125f; + matrix_data[43] = 1.0f; + matrix_data[44] = -1.0f; + matrix_data[45] = 7.59375f; + matrix_data[46] = -7.59375f; + matrix_data[47] = 0.0f; + matrix_data[48] = 0.0f; + matrix_data[49] = 0.015625f; + matrix_data[50] = 0.015625f; + matrix_data[51] = 1.0f; + matrix_data[52] = 1.0f; + matrix_data[53] = 11.390625f; + matrix_data[54] = 11.390625f; + matrix_data[55] = 1.0f; +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MATRIX_TABLE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/mul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/mul_parameter.h new file mode 100644 index 0000000000..8b70bd09ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/mul_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MUL_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MUL_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct MulParameter { + OpParameter op_parameter_; + int thread_count_; + MulQuantArg mul_quant_arg_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_MUL_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/nnacl_utils.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/nnacl_utils.cc new file mode 100644 index 0000000000..1607d5793d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/nnacl_utils.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/nnacl_utils.h" +#ifdef __ANDROID__ +#include +#endif + +#if defined(__ANDROID__) +uint32_t getHwCap(int hwcap_type) { + uint32_t ret = getauxval(hwcap_type); + return ret; +} +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/nnacl_utils.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/nnacl_utils.h new file mode 100644 index 0000000000..a8d39d669f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/nnacl_utils.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_NNACL_UTILS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_NNACL_UTILS_H_ + +#include + +#if defined(__arm__) || defined(__aarch64__) +uint32_t getHwCap(int hwcap_type); +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_NNACL_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h new file mode 100644 index 0000000000..dcf1e956e7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_OP_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_OP_BASE_H_ + +#include +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +#define C4NUM 4 +#define C8NUM 8 +#define BLOCK 4 +#define TILE_NUM 8 + +#define MSMIN(x, y) ((x) < (y) ? (x) : (y)) +#define MSMAX(x, y) ((x) > (y) ? (x) : (y)) + +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y)) +#define UP_ROUND_DIV(x, y) (x % y == 0 ? (x / y) : (x / y) + 1) +#define DOWN_DIV(x, y) (((x) - (y) + (1)) / (y)) + +#define MSVALID(left, x, right) (MSMIN((MSMAX(left, x)), right)) + +#define DIMENSION_4D 4 + +#define kInputIndex 0 +#define kWeightIndex 1 +#define kBiasIndex 2 +#define kOutputIndex 0 +#define kNHWC_N 0 +#define kNHWC_H 1 +#define kNHWC_W 2 +#define kNHWC_C 3 +#define kInputSize1 2 +#define kInputSize2 3 + +enum LiteDataType { + kDataTypeFloat, + kDataTypeInt8, +}; + +struct OpParameter { + char name_[100]; + int type_; + int thread_num_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_OP_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c new file mode 100644 index 0000000000..82591d4aef --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + size_t ksize, size_t ic4, size_t output_channel, size_t offset, + const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, + size_t out_multiplier, size_t shift_before, size_t shift_after); + +void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + size_t ksize, size_t ic4, size_t output_channel, size_t offset, + const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, + size_t out_multiplier, size_t shift_before, size_t shift_after) { + return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min, + act_max, out_zp, out_multiplier, shift_before, shift_after); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h new file mode 100644 index 0000000000..d5bf245376 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_OPTIMIZED_KERNEL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_OPTIMIZED_KERNEL_H_ + +#include +#ifdef __ANDROID__ +#include +#include "src/runtime/kernel/arm/nnacl/nnacl_utils.h" +#endif + +#define OPTIMIZE_SHARED_LIBRARY_PATH "liboptimize.so" + +class OptimizeModule { + public: + OptimizeModule() { + bool support_optimize_ops = false; + bool support_fp16 = false; +#ifdef __ANDROID__ + int hwcap_type = 16; + uint32_t hwcap = getHwCap(hwcap_type); +#ifdef ENABLE_ARM64 + if (hwcap & HWCAP_FPHP) { +#elif defined(__arm__) + if (hwcap & HWCAP_HALF) { +#endif + MS_LOG(INFO) << "Hw cap support FP16, hwcap: 0x" << hwcap; + support_fp16 = true; +#ifdef ENABLE_ARM64 + } +#elif defined(__arm__) + } +#endif + +#ifdef ENABLE_ARM64 + if (hwcap & HWCAP_ASIMDDP) { + printf("Hw cap support SMID Dot Product, hwcap: 0x%x \n", hwcap); + support_optimize_ops = true; + } else { + printf("Hw cap NOT support SIMD Dot Product, hwcap: 0x%x\n", hwcap); + } +#endif +#endif + if ((!support_optimize_ops) && (!support_fp16)) { + return; + } + optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY); + if (optimized_op_handler_ == nullptr) { + printf("Open optimize shared library failed.\n"); + } + } + + ~OptimizeModule() = default; + + static OptimizeModule *GetInstance() { + static OptimizeModule opt_module; + return &opt_module; + } + void *optimized_op_handler_ = nullptr; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_OPTIMIZED_KERNEL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.cc new file mode 100644 index 0000000000..64dcbaaa4c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.cc @@ -0,0 +1,838 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include +#include + +void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block, + int oc_block_num) { + // original weight format : ohwi + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc_block * oc_block_num * ic4 * C4NUM * kernel_plane; + + int unit_size = oc_block * C4NUM; + int block_size = pack_weight_size / oc_block_num; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * ic4; + for (int i = 0; i < ic4; i++) { + int channel_block_stride = kernel_plane_stride + i * C4NUM; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * C4NUM; + int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h * oc_block; + for (int j = 0; j < oc_block_num; j++) { + int kernel_block_stride = block_stride + j * oc_block * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * oc_block; + int real_oc_num = oc_remainder < oc_block ? oc_remainder : oc_block; + for (int k = 0; k < real_oc_num; k++) { + float *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + float *packed_data_ptr = packed_weight + packed_kernel_block_size + k; + *packed_data_ptr = *origin_data_ptr; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) { + // original weight format : ohwi + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + int pack_weight_size = oc4 * C4NUM * ic4 * C4NUM * plane_c4 * C4NUM; + int block_size = pack_weight_size / oc4; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * C4NUM; + for (int i = 0; i < ic4; i++) { + int channel_block_stride = kernel_plane_stride + i * C4NUM; + int packed_channel_block_size = packed_kernel_plane_stride + i * C4NUM * C4NUM * C4NUM; + int ic_remainder = in_channel - i * C4NUM; + int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h; + for (int j = 0; j < oc4; j++) { + int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * C4NUM; + int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM; + for (int k = 0; k < real_oc_num; k++) { + int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM * C4NUM; + *packed_data_ptr = origin_data_ptr[0]; + // value of weight must between [-127, 127] + if (packed_data_ptr[0] == -128) { + packed_data_ptr[0] = -127; + } + weight_sum[j * C4NUM + k] += (int32_t)packed_data_ptr[0]; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) { + // original weight format : ohwi + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; + int unit_size = C4NUM * C4NUM; + int block_size = pack_weight_size / oc4; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * ic4; + for (int i = 0; i < ic4; i++) { + int channel_block_stride = kernel_plane_stride + i * C4NUM; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * C4NUM; + int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h; + for (int j = 0; j < oc4; j++) { + int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * C4NUM; + int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM; + for (int k = 0; k < real_oc_num; k++) { + int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM; + *packed_data_ptr = origin_data_ptr[0]; + if (packed_data_ptr[0] == -128) { + packed_data_ptr[0] = -127; + } + weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - filter_zp); + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param) { + /* support nhwc */ + for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { + int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_; + if (src_h < 0 || src_h >= conv_param->input_h_) { + continue; + } + const float *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_; + float *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_; + for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { + int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_; + if (src_w < 0 || src_w >= conv_param->input_w_) { + continue; + } + memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_, + conv_param->input_channel_ * sizeof(float)); + } + } + return; +} + +void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param) { + int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < conv_param->output_channel_; oc++) { + int oc4mod = oc % 4; + int oc4div = oc / 4; + int dst_index = oc4div * c4 * C4NUM + ic * C4NUM + oc4mod; + int src_index = oc * conv_param->input_channel_ + ic; + packed_weight[dst_index] = weight_data[src_index]; + } + } + return; +} + +void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int ic4 = UP_DIV(in_channel, C4NUM); + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * ic4 * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * ic4 * C4NUM; + int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM; + for (int m = 0; m < ic4; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * C8NUM * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32(packed_input + channel_block_offset, vld1q_f32(input_data + channel_block_stride)); +#else + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; +#endif + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + } // tile num loop +} + +void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param) { + // input format : nhwc + int tile_num = conv_param->tile_num_; + int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_w = conv_param->output_w_; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_cal_num_offset = i * C4NUM * C4NUM; + int32_t input_accumulator = 0; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * ic4 * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * ic4 * C4NUM; + int plane_c4_block = (j * kernel_w + n) / C4NUM; + int plane_c4_res = (j * kernel_w + n) % C4NUM; + int input_plane_offset = + plane_c4_block * tile_num * C4NUM * C4NUM * ic4 + plane_c4_res * C4NUM + input_cal_num_offset; + for (int m = 0; m < ic4; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * tile_num * C4NUM * C4NUM; + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; + input_accumulator += (packed_input + channel_block_offset)[0]; + input_accumulator += (packed_input + channel_block_offset)[1]; + input_accumulator += (packed_input + channel_block_offset)[2]; + input_accumulator += (packed_input + channel_block_offset)[3]; + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + input_sum[i] = input_accumulator * filter_zp; + } // tile num loop +} + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param) { + // input format : nhwc + int tile_num = conv_param->tile_num_; + int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_w = conv_param->output_w_; + int block_size = kernel_h * kernel_w; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * ic4 * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * ic4 * C4NUM; + int input_plane_offset = (j * kernel_w + n) * tile_num * C4NUM * ic4 + i * C4NUM; + for (int m = 0; m < ic4; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * tile_num * C4NUM; + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + int32_t input_accumulator = 0; + for (int j = 0; j < block_size; j++) { + int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM; + for (int c = 0; c < ic4; c++) { + int ic4_offset = block_offset + c * tile_num * C4NUM; + input_accumulator += (packed_input + ic4_offset)[0]; + input_accumulator += (packed_input + ic4_offset)[1]; + input_accumulator += (packed_input + ic4_offset)[2]; + input_accumulator += (packed_input + ic4_offset)[3]; + } + } + input_sum[i] = input_accumulator * filter_zp; + } // tile num loop +} + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) { + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic8 = UP_DIV(in_channel, C8NUM); + + for (int b = 0; b < in_batch; b++) { + int src_batch_offset = b * in_channel * in_h * in_w; + int dst_batch_offset = b * ic8 * C8NUM * in_h * in_w; + for (int c = 0; c < in_channel; c++) { + int ic8_block = c / C8NUM; + int ic8_res = c % C8NUM; + int src_c_offset = src_batch_offset + c; + int dst_c_offset = dst_batch_offset + ic8_block * C8NUM * in_h * in_w + ic8_res; + for (int k = 0; k < in_w * in_h; k++) { + int src_plane_offset = src_c_offset + k * in_channel; + int dst_plane_offset = dst_c_offset + k * C8NUM; + (packed_input + dst_plane_offset)[0] = (int16_t)(input_data + src_plane_offset)[0]; + } + } + } +} + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = UP_DIV(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + int filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; + for (int i = 0; i < input_channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp); + } + } + } +} + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + nhwc4_batch_offset + i * c4 * C4NUM, (float *)src + batch_offset + i * channel, + channel * sizeof(float)); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM, + channel * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * c4 * C4NUM; + ((float *)dst)[dst_plane_offset] = ((float *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset)); +#else + ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; + ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; + ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; + ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; +#endif + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((float *)dst)[dst_index] = ((float *)src)[src_index]; + } + } + } + return; +} + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc4_batch_offset + i * c4 * C4NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy(reinterpret_cast(dst) + batch_offset + i * channel, + reinterpret_cast(src) + nhwc4_batch_offset + i * c4 * C4NUM, channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * c4 * C4NUM; + ((uint8_t *)dst)[dst_plane_offset] = ((uint8_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; + ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; + ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; + ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; + ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index]; + } + } + } + return; +} + +void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c8 * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C8NUM; + for (int i = 0; i < channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c8_block_num * plane * C8NUM + c8_block_rem; + ((int8_t *)dst + dst_ic_offset)[0] = ((int8_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; + } + } + } + return; +} + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((float *)dst)[nchw_index] = ((float *)src)[nhwc_index]; + } + } + } + return; +} + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((float *)dst)[nhwc_index] = ((float *)src)[nchw_index]; + } + } + } + return; +} + +void MatrixPackUnit(const float *src, float *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride) { + size_t copy_size = row * C4NUM * sizeof(float); + for (int c = 0; c < col; c++) { + memcpy(dst + c * dst_stride, src + c * src_stride, copy_size); + } +} + +void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) { + int row4mod = row % 4; + int row4div = row / 4; + + for (int i = 0; i < row4div; i++) { + MatrixPackUnit(src + i * 4 * 4, dst + i * 4 * ic4 * 4, 4, ic4, stride, 16); + } + + if (row4mod > 0) { + MatrixPackUnit(src + row4div * 4 * 4, dst + row4div * 4 * ic4 * 4, row4mod, ic4, stride, row4mod * 4); + } + return; +} + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { + auto input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int unit = conv_param->input_h_ * conv_param->input_w_; + + for (int b = 0; b < conv_param->input_batch_; b++) { + auto src_b = src + b * unit * conv_param->input_channel_; + auto dst_b = dst + b * unit * ic4 * C4NUM; + for (int k = 0; k < unit; k++) { + auto src_k = src_b + k * conv_param->input_channel_; + auto dst_k = dst_b + k * ic4 * C4NUM; + for (int c = 0; c < conv_param->input_channel_; c++) { + dst_k[c] = (int16_t)(src_k[c] - input_zp); + } + } + } +} + +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, const ConvParameter *conv_param) { + auto weight_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int unit = conv_param->kernel_h_ * conv_param->kernel_w_; + for (int c = 0; c < conv_param->output_channel_; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + auto src_c = origin_weight + c * unit; + auto dst_c = packed_weight_ + c4_block_num * unit * C4NUM; + for (int k = 0; k < unit; k++) { + auto src_kernel = src_c + k; + auto dst_kernel = dst_c + C4NUM * k + c4_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h new file mode 100644 index 0000000000..02ed4c6fba --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PACK_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index); + +void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param); + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param); + +void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param); + +void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); + +void MatrixPack(const float *src, float *dst, int row, int ic4, int stride); + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); + +void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block, + int oc_block_num); + +void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum); + +void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum); + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param); + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); + +void PackDepthwiseInt8Weight(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PACK_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.cc new file mode 100644 index 0000000000..58a52963dd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/nnacl/pack_ext.h" + +static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } + +void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) { + const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; + // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; + const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; + // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_h = conv_param->output_h_; + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int /*channel,*/ kernel_row, kernel_col, output_rows, output_col; + + int row_stride_offset = 0; + + for (output_rows = output_h; output_rows; output_rows--) { + int col_stride_offset = 0; + for (output_col = output_w; output_col; output_col--) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + + if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(data_col, in_data + offset, sizeof(float) * channels); + data_col += channels; + } else { + memset(data_col, 0, sizeof(float) * channels); + data_col += channels; + } + } + } + col_stride_offset += stride_w; + } + row_stride_offset += stride_h; + } +} + +// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w) +void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) { + const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; + // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; + const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; + // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_h = conv_param->output_h_; + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int channel, kernel_row, kernel_col, output_rows, output_col; + + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + for (channel = 0; channel < channels; channel++) { + int input_row = -pad_up + kernel_row * dilation_h; + for (output_rows = output_h; output_rows; output_rows--) { + if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { + for (output_col = output_w; output_col; output_col--) { + *(data_row++) = 0; + } + } else { + int input_col = -pad_left + kernel_col * dilation_w; + for (output_col = output_w; output_col; output_col--) { + if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + const int offset = (input_row * in_width + input_col) * tot_channels + channel; + *(data_row++) = in_data[offset]; + } else { + *(data_row++) = 0; + } + input_col += stride_w; + } + } + input_row += stride_h; + } + } + } + } +} + +void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) { + const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; + // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; + const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; + // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_h = conv_param->output_h_; + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col, output_rows, output_col; + + int row_stride_offset = 0; + + for (output_rows = output_h; output_rows; output_rows--) { + int col_stride_offset = 0; + for (output_col = output_w; output_col; output_col--) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + + if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + col_stride_offset += stride_w; + } + row_stride_offset += stride_h; + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.h new file mode 100644 index 0000000000..d943467f16 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PACK_EXT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PACK_EXT_H_ + +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param); +void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param); +void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PACK_EXT_H diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pad_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/pad_parameter.h new file mode 100644 index 0000000000..230ca70aaf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pad_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PAD_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PAD_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +#define MAX_PAD_SIZE 8 +#define DEFAULT_PAD_NDIMS 4 + +struct PadParameter { + OpParameter op_parameter_; + PadQuantArg pad_quant_arg_; + int paddings_[MAX_PAD_SIZE] = {0}; + int pad_mode_; + float constant_value_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PAD_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/power.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/power.cc new file mode 100644 index 0000000000..b1165f12c1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/power.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/power.h" + +bool CheckInteger(float f) { return floorf(f) == f; } + +float OptimizedPowerImpl(float x, int exponent) { + int exp = abs(exponent); + float result = 1; + float iterator = x; + while (exp) { + if (exp % 2) { + result *= iterator; + } + iterator *= iterator; + exp = exp / 2; + } + return exponent >= 0 ? result : 1 / result; +} + +float StdPowerImpl(float x, float exponent) { return pow(x, exponent); } + +void Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, + bool broadcast) { + if (broadcast) { + if (CheckInteger(*exponent)) { + for (int i = 0; i < len; ++i) { + output[i] = OptimizedPowerImpl(scale * input[i] + shift, (int)(*exponent)); + } + } else { + for (int i = 0; i < len; ++i) { + output[i] = StdPowerImpl(scale * input[i] + shift, *exponent); + } + } + } else { + for (int i = 0; i < len; ++i) { + if (CheckInteger(*exponent)) { + output[i] = OptimizedPowerImpl(scale * input[i] + shift, (int)exponent[i]); + } else { + output[i] = StdPowerImpl(scale * input[i] + shift, exponent[i]); + } + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/power.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/power.h new file mode 100644 index 0000000000..61800b47dd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/power.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_H_ +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct PowerParameter { + OpParameter op_parameter_; + float power_; + float scale_; + float shift_; +}; + +void Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POWER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/prelu.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/prelu.cc new file mode 100644 index 0000000000..6b454e94f9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/prelu.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/prelu.h" + +void PRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id) { + for (int i = task_id; i < prelu_param_->input_num_; i += prelu_param_->op_parameter_.thread_num_) { + if (input[i] <= 0) { + output[i] = input[i] * prelu_param_->negtive_slope_[0]; + } else { + output[i] = input[i]; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/prelu.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/prelu.h new file mode 100644 index 0000000000..208274ee6b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/prelu.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PRELU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PRELU_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct PReluParameter { + OpParameter op_parameter_; + float *negtive_slope_; + int input_num_; + int thread_num_; +}; + +void PRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PRELU_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/prior_box.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/prior_box.cc new file mode 100644 index 0000000000..e149ed90ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/prior_box.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/prior_box.h" + +int PriorBox(const float *input_data, float *output_data, const size_t size, const int tid, const int thread_num) { + size_t unit_size = size / thread_num; + if (tid == thread_num - 1) { + size_t tail_size = size - unit_size * tid; + (void)memcpy(output_data + tid * unit_size, input_data + tid * unit_size, tail_size * sizeof(float)); + } else { + (void)memcpy(output_data + tid * unit_size, input_data + tid * unit_size, unit_size * sizeof(float)); + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/prior_box.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/prior_box.h new file mode 100644 index 0000000000..fd0ea64e28 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/prior_box.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PRIOR_BOX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PRIOR_BOX_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#define PRIOR_BOX_MAX_NUM 8 +#define PRIOR_BOX_VAR_NUM 4 +struct PriorBoxParameter { + OpParameter op_parameter_; + int32_t min_sizes_size; + int32_t min_sizes[PRIOR_BOX_MAX_NUM]; + int32_t max_sizes_size; + int32_t max_sizes[PRIOR_BOX_MAX_NUM]; + int32_t aspect_ratios_size; + float aspect_ratios[PRIOR_BOX_MAX_NUM]; + float variances[PRIOR_BOX_VAR_NUM]; + int32_t image_size_w; + int32_t image_size_h; + float step_w; + float step_h; + bool clip; + bool flip; + float offset; +}; + +int PriorBox(const float *input_data, float *output_data, const size_t size, const int tid, const int thread_num); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_PRIOR_BOX_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h new file mode 100644 index 0000000000..5dfc38012f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_FIXED_POINT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_FIXED_POINT_H_ + +#include +#include "include/infer_log.h" +#ifdef ENABLE_NEON +#include +#endif + +// returns the high-32 bits of a * b with rounding +// assume that a and b is divided by 2^31, who fall into [-1, 1] +// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31= (a * b) / 2^31 +// actually we compute 2 * a * b / 2^32 +// and take 32 bits of mantissa for rounding +inline int SaturatingRoundingDoublingHighMul(int a, int b) { + if (a == INT_MIN && b == INT_MIN) { + return INT_MAX; + } + int64_t ab = ((int64_t)a) * ((int64_t)b); + int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30)); + // do not apply right shift to potential negetive values + int ab_mantissa = (int)((ab + rounding) / (1ll << 31)); + return ab_mantissa; +} + +inline int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) { + if (a == SHRT_MIN && b == SHRT_MIN) { + return SHRT_MAX; + } + int32_t ab = ((int32_t)a) * ((int32_t)b); + int16_t rounding = ab >= 0 ? (1ll << 14) : (1ll - (1ll << 14)); + return (int16_t)((ab + rounding) / (1ll << 15)); +} + +// division by a 2^exponent with rounding +// or arithmetic right shift with rouding +inline int RoundingDivideByPOT(int x, int exponent) { + MS_ASSERT(exponent >= 0); + MS_ASSERT(exponent <= 31); + const int mask = (1ll << exponent) - 1; + const int remainder = x & mask; + const int threshold = (mask >> 1) + (x < 0 ? 1 : 0); + return (x >> exponent) + (remainder > threshold ? 1 : 0); +} + +inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); +} + +#ifdef ENABLE_NEON +inline int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { + const int32x4_t shift_vec = vdupq_n_s32(-exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +inline int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) { return vqrdmulhq_s32(a, b); } +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_FIXED_POINT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.cc new file mode 100644 index 0000000000..aee93ed738 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" + +const uint64_t dSignMask = 1ull << 63; +const uint64_t dExponentMask = 0x7ffull << 52; +const uint64_t dFractionMask = (1ull << 52) - 1; +const int dExponentBias = 1022; +const int dMantissaBits = 52; +const int dInfiniteExponent = 0x7ff; +const double dNormalizer = 0x1p54; +const int dNormalizerBias = 54; +const int iMantissaBits = 31; + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift) { + if (quantized_multiplier == nullptr || shift == nullptr) { + return; + } + // we split a floating number into two parts: exponent and fraction + // since fraction is stored as int32, only 31 bits of mantissa is remained + union { + double d; + uint64_t ul; + } dul; + dul.d = double_multiplier; + if (!(dul.ul & (~dSignMask))) { + // multiplier is 0 + *quantized_multiplier = 0; + *shift = 0; + return; + } + int exponent = (int) ((dul.ul & dExponentMask) >> dMantissaBits); + if (exponent == dInfiniteExponent) { + // multiplier is inf or NaN + *shift = 0; + if (!(dul.ul & dFractionMask)) { + // inf + *quantized_multiplier = (dul.ul & dSignMask) ? INT_MIN : INT_MAX; + } else { + // NaN + *quantized_multiplier = 0; + } + return; + } + if (exponent == 0) { + // multiplier is a subnormal number + dul.d *= dNormalizer; + exponent = (int) ((dul.ul & dExponentMask) >> dMantissaBits); + *shift = exponent - dExponentBias - dNormalizerBias; + } else { + *shift = exponent - dExponentBias; + } + uint64_t fraction = dul.ul & dFractionMask; + fraction += (1ull << dMantissaBits); + uint64_t rounded = ((fraction >> (dMantissaBits - iMantissaBits)) + 1ull) >> 1; + // we get 31 rounded bits now + if (rounded == (1ull << iMantissaBits)) { + // rounding may cause a carry + rounded >>= 1; + ++*shift; + } + *quantized_multiplier = (dul.ul & dSignMask) ? (-(int32_t)(rounded)) : (int32_t)(rounded); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h new file mode 100644 index 0000000000..421d582c5e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_ + +#include +#include +#include +#include +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct QuantArg { + double scale_; + int32_t zp_; +}; + +struct ConvQuantArg { + QuantArg **quant_args_; + double *real_multiplier_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; + int32_t *out_act_min_; + int32_t *out_act_max_; +}; + +struct ConcatQuantArg { + int *input_sizes_; + int output_size_; + int **input_shapes_; + int *output_shape_; + size_t input_num_; + size_t output_dim_; + QuantArg *in_quant_args_; + QuantArg out_quant_args_; +}; + +struct MatmulQuantArg { + QuantArg input; + QuantArg weight; + QuantArg output; + int32_t out_act_min; + int32_t out_act_max; + int32_t left_shift; + int32_t right_shift; + int32_t quant_multiplier; +}; + +struct PadQuantArg { + QuantArg *in_quant_args_ = nullptr; + QuantArg *out_quanr_args_ = nullptr; + int8_t *constant_value_ = nullptr; +}; + +struct MulQuantArg { + QuantArg in_quant_args_[2]; + QuantArg out_quant_arg_; + int output_multiplier_; + int output_activation_min_; + int output_activation_max_; + int shift_left_; + int shift_right_; +}; + +struct CropQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +}; + +struct ArithSelfQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +}; + +struct SplitQuantArg { + QuantArg in_args_; + QuantArg out_args_[20]; + int output_activation_min_; + int output_activation_max_; +}; + +struct SoftmaxQuantArg { + QuantArg in_quant_args_; + QuantArg out_quant_arg_; +}; + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); + +inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, + int *right_shift) { + if (quantized_multiplier == nullptr || right_shift == nullptr) { + return; + } + int shift; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + *right_shift = -shift; +} + +inline void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, + int *right_shift) { + int shift; + QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); + shift = -shift; + if (shift < 0) { + *left_shift = 0; + *right_shift = shift; + } else { + *left_shift = shift; + *right_shift = 0; + } +} + +inline uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +inline int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini, + int *maxi) { + int32_t min = std::numeric_limits::min(); + int32_t max = std::numeric_limits::max(); + int32_t quantized_zero = QuantizeToInt8(0, scale, zp); + int32_t quantized_six = QuantizeToInt8(6, scale, zp); + if (is_relu) { + min = min > quantized_zero ? min : quantized_zero; + } else if (is_relu6) { + min = min > quantized_zero ? min : quantized_zero; + max = max < quantized_six ? max : quantized_six; + } else { + // do nothing + } + *mini = min; + *maxi = max; +} + +// quantize from float to int8 +inline void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) { + for (int i = 0; i < length; ++i) { + int r = (int)round(input_data[i] / scale + zero_point); + int8_t q = r > CHAR_MAX ? (int8_t)CHAR_MAX : (int8_t)r; + q = q < CHAR_MIN ? CHAR_MIN : q; + output_data[i] = q; + } +} + +// dequantize from int8 to float +inline void Dequantize(int8_t *input_data, int length, float scale, int zero_point, float *output_data) { + for (int i = 0; i < length; ++i) { + output_data[i] = scale * (input_data[i] - zero_point); + } +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape.cc new file mode 100644 index 0000000000..5505c8b0f9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/reshape.h" +#include + +void Reshape(void *input_ptr, void *output_ptr, size_t data_size) { memcpy(output_ptr, input_ptr, data_size); } + + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape.h new file mode 100644 index 0000000000..a992a46ba0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_H_ +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void Reshape(void *input_ptr, void *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape_parameter.h new file mode 100644 index 0000000000..0243a60f86 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/reshape_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ReshapeParameter { + OpParameter op_parameter_; + int thread_count_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/resize.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/resize.cc new file mode 100644 index 0000000000..48567ce01b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/resize.cc @@ -0,0 +1,135 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "src/runtime/kernel/arm/nnacl/resize.h" +#include "src/runtime/kernel/arm/nnacl/common_func.h" + +int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + bool align_corners, int tid, int thread_num) { + if (input_data == nullptr || output_data == nullptr || input_shape == nullptr || output_shape == nullptr) { + return NNACL_NULL_PTR; + } + // nhwc (memory layout is nc4hw4) + int n = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int channel = input_shape[3]; + int c4 = UP_DIV(channel, C4NUM); + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + float height_scale = (float)(in_h) / new_height; + float width_scale = (float)(in_w) / new_width; + if (align_corners && new_height > 1) { + height_scale = (float)(in_h - 1) / (new_height - 1); + } + if (align_corners && new_width > 1) { + width_scale = (float)(in_w - 1) / (new_width - 1); + } + + int o[5]; // n c4 h w 4 + for (o[0] = 0; o[0] < n; o[0]++) { + for (o[1] = tid; o[1] < c4; o[1] += thread_num) { + for (o[2] = 0; o[2] < new_height; o[2]++) { + float actual_y = (float)(o[2]) * height_scale; + int y_left = (int)(floor(actual_y)); + int y_right = y_left + 1 < in_h ? (y_left + 1) : (in_h - 1); + float y_right_weight = actual_y - (float)(y_left); + float y_left_weight = 1.0 - y_right_weight; + for (o[3] = 0; o[3] < new_width; o[3]++) { + float actual_x = (float)(o[3]) * width_scale; + int x_left = (int)(floor(actual_x)); + int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1); + float x_right_weight = actual_x - (float)(x_left); + float x_left_weight = 1.0 - x_right_weight; + + auto input_base_offset = (((o[0] * c4 + o[1]) * in_h + y_left) * in_w + x_left) * C4NUM; + auto output_base_offset = (((o[0] * c4 + o[1]) * new_height + o[2]) * new_width + o[3]) * C4NUM; + int in_offset_1_0 = (y_right - y_left) * in_w * C4NUM; + int in_offset_0_1 = (x_right - x_left) * C4NUM; +#ifdef ENABLE_NEON + float32x4_t x_l_weight = vdupq_n_f32(x_left_weight); + float32x4_t x_r_weight = vdupq_n_f32(x_right_weight); + float32x4_t y_l_weight = vdupq_n_f32(y_left_weight); + float32x4_t y_r_weight = vdupq_n_f32(y_right_weight); + + float32x4_t input_yl_xl = vld1q_f32(input_data + input_base_offset); + float32x4_t input_yr_xl = vld1q_f32(input_data + input_base_offset + in_offset_1_0); + float32x4_t input_yl_xr = vld1q_f32(input_data + input_base_offset + in_offset_0_1); + float32x4_t input_yr_xr = vld1q_f32(input_data + input_base_offset + in_offset_0_1 + in_offset_1_0); + + float32x4_t interp_value = vdupq_n_f32(0.0); + float32x4_t interp_value_tmp = vmulq_f32(input_yl_xl, y_l_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_l_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + + interp_value_tmp = vmulq_f32(input_yr_xl, y_r_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_l_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + + interp_value_tmp = vmulq_f32(input_yl_xr, y_l_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_r_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + + interp_value_tmp = vmulq_f32(input_yr_xr, y_r_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_r_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + vst1q_f32(output_base_offset + output_data, interp_value); +#else + // 4 continuous data in a group; + for (o[4] = 0; o[4] < C4NUM; o[4]++) { + auto in_offset = input_base_offset + o[4]; + auto output_offset = output_base_offset + o[4]; + float interp_value = + input_data[in_offset] * y_left_weight * x_left_weight + + input_data[in_offset + in_offset_1_0] * y_right_weight * x_left_weight + + input_data[in_offset + in_offset_0_1] * y_left_weight * x_right_weight + + input_data[in_offset + in_offset_0_1 + in_offset_1_0] * y_right_weight * x_right_weight; + output_data[output_offset] = interp_value; + } +#endif + } + } + } + } + return NNACL_OK; +} + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + int tid, int thread_num) { + int batch, y, x, c; + c = input_shape[3]; + + float height_scale = (float)(input_shape[1]) / (float)(output_shape[1]); + float width_scale = (float)(input_shape[2]) / (float)(output_shape[2]); + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int actual_y = (int)(floor((float)(y) * height_scale)); + int input_y = actual_y < input_shape[1] ? actual_y : input_shape[1] - 1; + for (x = 0; x < output_shape[2]; x++) { + int actual_x = (int)(floor((float)(x) * width_scale)); + int input_x = actual_x < input_shape[2] ? actual_x : input_shape[2] - 1; + int in_offset = offset(input_shape, batch, input_y, input_x, 0); + int out_offset = offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float)); + } + } + } + + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/resize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/resize.h new file mode 100644 index 0000000000..e7fd4a6876 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/resize.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESIZE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "schema/ops_generated.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +using mindspore::schema::ResizeMethod; + +struct ResizeParameter { + OpParameter op_parameter_; + ResizeMethod method_; + int64_t new_height_; + int64_t new_width_; + bool align_corners_; + bool preserve_aspect_ratio_; +}; + +int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + bool align_corners, int tid, int thread_num); + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + int tid, int thread_num); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESIZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/reverse_sequence.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/reverse_sequence.cc new file mode 100644 index 0000000000..7308716fbb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/reverse_sequence.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/reverse_sequence.h" +#include +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) { + (void)memcpy(output, input0, para->total_data_size_); + ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_); + ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_); + for (int i = 0; i < para->outer_count_; ++i) { + auto in = input0 + i * para->outer_stride_; + auto out = output + i * para->outer_stride_; + for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) { + auto in_batch = in + batch * para->input_stride_[para->batch_axis_]; + auto out_batch = out + batch * para->output_stride_[para->batch_axis_]; + for (int n = 0; n < input1[batch]; ++n) { + auto in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_]; + auto out_seq = out_batch + n * para->output_stride_[para->seq_axis_]; + for (int j = 0; j < para->inner_count_; ++j) { + (void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_); + } + } + } + } +} + + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/reverse_sequence.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/reverse_sequence.h new file mode 100644 index 0000000000..582bcd35eb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/reverse_sequence.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_REVERSE_SEQUENCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_REVERSE_SEQUENCE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ReverseSequenceParameter { + OpParameter op_parameter_; + int ndim_; + int input_shape0_[5]; + int output_shape_[5]; + int input_stride_[5]; + int output_stride_[5]; + int seq_axis_; + int batch_axis_; + int outer_count_; + int outer_stride_; + int inner_count_; + int inner_stride_; + int copy_byte_size_; + int total_data_size_; +}; + +void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_REVERSE_SEQUENCE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/scale.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/scale.cc new file mode 100644 index 0000000000..37405b67c1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/scale.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/scale.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param) { + if (in_data == nullptr || out_data == nullptr || scale == nullptr || offset == nullptr || scale_param == nullptr) { + return NNACL_ERR; + } + + if (scale_param->has_offset_) { + for (int out = task_id; out < scale_param->outer_size_; out += scale_param->op_parameter_.thread_num_) { + int out_offset = out * scale_param->axis_size_ * scale_param->inner_size_; + for (int i = 0; i < scale_param->axis_size_; i++) { + int axis_offset = out_offset + i * scale_param->inner_size_; + for (int in = 0; in < scale_param->inner_size_; in++) { + int in_offset = axis_offset + in; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } + } else { + for (int out = task_id; out < scale_param->outer_size_; out += scale_param->op_parameter_.thread_num_) { + int out_offset = out * scale_param->axis_size_ * scale_param->inner_size_; + for (int i = 0; i < scale_param->axis_size_; i++) { + int axis_offset = out_offset + i * scale_param->inner_size_; + for (int in = 0; in < scale_param->inner_size_; in++) { + int in_offset = axis_offset + in; + out_data[in_offset] = in_data[in_offset] * scale[i]; + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/scale.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/scale.h new file mode 100644 index 0000000000..aaa77b7dda --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/scale.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SCALE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SCALE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ScaleParameter { + OpParameter op_parameter_; + int outer_size_; + int axis_size_; + int inner_size_; + int axis_; + bool has_offset_; + // todo yangruoqi: axis +}; + +int DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SCALE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/scatter_nd.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/scatter_nd.cc new file mode 100644 index 0000000000..0194e928af --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/scatter_nd.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/scatter_nd.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int DoScatterND(float *output_ptr, float *update, int *output_unit_offsets, int unit_size, int num_units) { + if (output_ptr == nullptr || update == nullptr || output_unit_offsets == nullptr || unit_size <= 0 || num_units < 0) { + return NNACL_ERR; + } + for (int i = 0; i < num_units; i++) { + (void)memcpy(output_ptr + output_unit_offsets[i], update + unit_size * i, unit_size * sizeof(float)); + } + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/scatter_nd.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/scatter_nd.h new file mode 100644 index 0000000000..adbe179bed --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/scatter_nd.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SCATTER_ND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SCATTER_ND_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ScatterNDParameter { + OpParameter op_parameter_; +}; + +int DoScatterND(float *output_ptr, float *update, int *output_unit_offsets, int unit_size, int num_units); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SCATTER_ND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/shape.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/shape.h new file mode 100644 index 0000000000..bfcbc4f8dc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/shape.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARM_NNACL_SHAPE_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ARM_NNACL_SHAPE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct ShapeParameter { + OpParameter op_parameter_; +}; + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_NNACL_SHAPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/softmax_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/softmax_parameter.h new file mode 100644 index 0000000000..b0fc83ce3e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/softmax_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SOFTMAX_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SOFTMAX_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct SoftmaxParameter { + OpParameter op_parameter_; + int32_t axis_; + int element_size_; + int n_dim_; + int input_shape_[4]; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SOFTMAX_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/sparse_to_dense.cc new file mode 100644 index 0000000000..17b80564b3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/sparse_to_dense.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/sparse_to_dense.h" + +void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, + SparseToDenseParameter *s2d_param_, int task_id) { + int m; + for (int i = task_id; i < output_shape_[0]; i += s2d_param_->op_parameter_.thread_num_) { + for (int j = 0; j < output_shape_[1]; j++) { + m = i * output_shape_[1] + j; + output[m] = dnum[0]; + } + } + + for (int j = 0; j < sp_num; j++) { + int temp = j * 2; + int temp1 = j * 2 + 1; + int tempout1 = input[temp] * output_shape_[1] + input[temp1]; + output[tempout1] = snum[j]; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/sparse_to_dense.h new file mode 100644 index 0000000000..3f048dd817 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/sparse_to_dense.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPARSETODENSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPARSETODENSE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct SparseToDenseParameter { + OpParameter op_parameter_; + int thread_num_; + int count_ = 0; +}; + +void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, + SparseToDenseParameter *s2d_param_, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPARSETODENCE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/split.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/split.cc new file mode 100644 index 0000000000..712abebe19 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/split.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/split.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset, int num_unit, + SplitParameter *split_param) { + if (in_data == nullptr || out_data == nullptr) { + return NNACL_ERR; + } + int num_split = split_param->num_split_; + int *split_sizes = split_param->split_sizes_; + int *strides = split_param->strides_; + int split_dim = split_param->split_dim_; + int in_stride = strides[split_dim]; + + float *src; + int size_float = (int)(sizeof(float)); + int in_stride_bytes = in_stride * size_float; + + int split_which; + int split_times; + int stride_per_split = in_stride * input_shape[split_dim]; + + split_which = offset % num_split; + split_times = offset / num_split; + src = in_data + split_times * stride_per_split; + + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride; + } + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int split_size = split_sizes[split_which]; + float *dst = out_data[split_which] + split_times * in_stride * split_size; + (void)memcpy(dst, src, split_size * in_stride_bytes); + src += split_size * in_stride; + } + + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/split.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/split.h new file mode 100644 index 0000000000..ff861ff4ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/split.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPLIT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPLIT_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" +#include "src/runtime/kernel/arm/nnacl/split_parameter.h" + +int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset, int num_unit, + SplitParameter *split_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPLIT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/split_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/split_parameter.h new file mode 100644 index 0000000000..30456df654 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/split_parameter.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPLIT_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPLIT_PARAMETER_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct SplitParameter { + OpParameter op_parameter_; + SplitQuantArg quant_arg_; + int num_split_; + int split_sizes_[20] = {0}; + int strides_[20]; + int split_dim_; + int n_dims_; + int split_count_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SPLIT_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.cc new file mode 100644 index 0000000000..43e7ffad7c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/squeeze.h" +#include + +int DoSqueeze(float *in_data, float *out_data, size_t data_size) { + if (in_data == nullptr || out_data == nullptr) { + return -1; + } + (void)memcpy(out_data, in_data, data_size); + return 0; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h new file mode 100644 index 0000000000..5da0637e8a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SQUEEZE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct SqueezeParameter { + OpParameter op_parameter_; + int axes_[8]; +}; + +int DoSqueeze(float *input_ptr, float *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_SQUEEZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/strassen_matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/strassen_matmul.h new file mode 100644 index 0000000000..c385d3b45f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/strassen_matmul.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_STRASSEN_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_STRASSEN_MATMUL_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +/* hw*inc4 X inc4*oc4 */ +struct StrassenMatMulParameter { + OpParameter op_parameter; + int row_{}; /* h * w */ + int col_{}; /* oc4 / 4 */ + int deep_{}; /* inc4 / 4 */ + int a_stride_{}; /* h * w * 4 */ + int b_stride_{}; /* inc4 * 4 */ + int c_stride_{}; /* h * w * 4 */ +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_STRASSEN_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.cc new file mode 100644 index 0000000000..0c14eb2f71 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/strided_slice.h" +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +void PadStridedSliceParameterTo4D(StridedSliceParameter *param) { + int32_t begins[DIMENSION_4D]; + int32_t ends[DIMENSION_4D]; + int32_t strides[DIMENSION_4D]; + int32_t input_shape[DIMENSION_4D]; + for (int32_t i = 0; i < param->num_axes_; ++i) { + begins[i] = param->begins_[i]; + ends[i] = param->ends_[i]; + strides[i] = param->strides_[i]; + input_shape[i] = param->in_shape_[i]; + } + int32_t real_index = param->num_axes_ - 1; + for (int32_t i = DIMENSION_4D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begins_[i] = begins[real_index]; + param->ends_[i] = ends[real_index]; + param->strides_[i] = strides[real_index]; + param->in_shape_[i] = input_shape[real_index--]; + } else { + param->begins_[i] = 0; + param->ends_[i] = 1; + param->strides_[i] = 1; + param->in_shape_[i] = 1; + } + } + param->num_axes_ = DIMENSION_4D; +} + +int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) { + if (in_data == nullptr || out_data == nullptr || param == nullptr) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_4D) { + return NNACL_PARAM_INVALID; + } + + int *begins = param->begins_; + int *ends = param->ends_; + int *strides = param->strides_; + int *in_shape = param->in_shape_; + + if (param->num_axes_ < DIMENSION_4D) { + PadStridedSliceParameterTo4D(param); + } + + size_t dim_offset[DIMENSION_4D - 1]; + dim_offset[2] = in_shape[3]; + dim_offset[1] = dim_offset[2] * in_shape[2]; + dim_offset[0] = dim_offset[1] * in_shape[1]; + size_t out_offset = 0; + for (int32_t dim0 = begins[0]; dim0 < ends[0]; dim0 += strides[0]) { + for (int32_t dim1 = begins[1]; dim1 < ends[1]; dim1 += strides[1]) { + for (int32_t dim2 = begins[2]; dim2 < ends[2]; dim2 += strides[2]) { + for (int32_t dim3 = begins[3]; dim3 < ends[3]; dim3 += strides[3]) { + int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + dim3; + if (param->data_type == kDataTypeFloat) { + *((float *)out_data + out_offset) = *((float *)in_data + in_offset); + } else { + *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); + } + out_offset++; + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.h new file mode 100644 index 0000000000..0e9bcac593 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_STRIDED_SLICE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_STRIDED_SLICE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct StridedSliceParameter { + OpParameter op_parameter_; + int begins_[8] = {0}; + int ends_[8] = {0}; + int strides_[8] = {1}; + int isScale; + int num_axes_; + int in_shape_[8]; + LiteDataType data_type; +}; + +int DoStridedSlice(const void *inputs, void *output, StridedSliceParameter *param); +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_STRIDED_SLICE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/tile.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/tile.cc new file mode 100644 index 0000000000..747a4a51f7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/tile.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/tile.h" +#include + +void CopyData(float *input_data, float *output_data, size_t size, size_t multiple) { + float *out_data = output_data; + for (size_t i = 0; i < multiple; ++i) { + (void)memcpy(out_data, input_data, size * sizeof(float)); + out_data += size; + } +} + +int TileOneDimension(float *input_data, float *output_data, size_t dim, TileParameter *parameter) { + size_t src_dim_size = parameter->in_shape_[dim]; + if (dim == parameter->in_dim_ - 1) { + CopyData(input_data, output_data, src_dim_size, parameter->multiples_[dim]); + return 0; + } + for (size_t i = 0; i < src_dim_size; ++i) { + for (size_t j = 0; j < parameter->multiples_[dim]; ++j) { + size_t in_pos = parameter->in_strides_[dim] * i; + size_t out_pos = parameter->out_strides_[dim] * (i + j * src_dim_size); + TileOneDimension(input_data + in_pos, output_data + out_pos, dim + 1, parameter); + } + } + return 0; +} + +void Tile(float *input_data, float *output_data, TileParameter *parameter) { + TileOneDimension(input_data, output_data, 0, parameter); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/tile.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/tile.h new file mode 100644 index 0000000000..da6d09602e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/tile.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TILE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TILE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct TileParameter { + OpParameter op_parameter_; + int in_dim_; + int in_shape_[5]; + int out_shape_[5]; + int multiples_[5]; + int in_strides_[5]; + int out_strides_[5]; +}; + +void Tile(float *input_data, float *output_data, TileParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TILE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc new file mode 100644 index 0000000000..0e1b75db9a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/transpose.h" +#include +#include "src/runtime/kernel/arm/nnacl/errorcode.h" + +void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; j++) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; j++) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; k++) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; j++) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; k++) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; m++) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape, + TransposeParameter *transpose_param) { + if (in_data == nullptr || out_data == nullptr) { + return NNACL_ERR; + } + int *perm = transpose_param->perm_; + int *strides = transpose_param->strides_; + int *out_strides = transpose_param->out_strides_; + int data_size = transpose_param->data_size_; + int num_axes = transpose_param->num_axes_; + + if (num_axes < 2 || num_axes > 4) { + return NNACL_ERR; + } + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; i++) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + if (num_axes == 2) { + TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); + } + return NNACL_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h new file mode 100644 index 0000000000..80e19b8ae8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TRANSPOSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TRANSPOSE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct TransposeParameter { + OpParameter op_parameter_; + int perm_[8]; + bool conjugate_; + int num_axes_; + int strides_[8]; + int out_strides_[8]; + int data_size_; +}; + +int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape, + TransposeParameter *transpose_param); +void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); +void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); +void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TRANSPOSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/unique.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/unique.cc new file mode 100644 index 0000000000..9db29931ae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/unique.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/unique.h" + +int Find(float *array, int len, float target) { + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void Unique(float *input, int input_len, float *output0, int *output0_len, int *output1) { + output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = Find(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/unique.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/unique.h new file mode 100644 index 0000000000..dce4b59319 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/unique.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNIQUE_H +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNIQUE_H + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct UniqueParameter { + OpParameter op_parameter_; +}; + +void Unique(float *input, int input_len, float *output0, int *output0_len, int *output1); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNIQUE_H + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/unstack.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/unstack.cc new file mode 100644 index 0000000000..5a7fd07b91 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/unstack.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/unstack.h" +#include + +void Unistack(float *input, float **output, UnstackParameter *para) { + for (int j = 0; j < para->num_; j++) { + float *out_addr = output[j]; + int out_offset = 0; + for (int i = 0; i < para->pre_dims_; i++) { + int in_offset = i * para->axis_dim_ * para->after_dims_ + j * para->after_dims_; + (void)memcpy(out_addr + out_offset, input + in_offset, para->after_dims_ * sizeof(float)); + out_offset += para->after_dims_; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/unstack.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/unstack.h new file mode 100644 index 0000000000..ee57a810af --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/unstack.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNSTACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNSTACK_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct UnstackParameter { + OpParameter op_parameter_; + int num_; + int axis_; + int pre_dims_; + int axis_dim_; + int after_dims_; +}; + +void Unistack(float *input, float **output, UnstackParameter *para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_UNSTACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/where.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/where.cc new file mode 100644 index 0000000000..40a4802ab6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/where.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/where.h" + +void Where(bool *input, float *input1, float *input2, float *output, WhereParameter *where_param_, int task_id) { + for (int i = task_id; i < where_param_->number_; i += where_param_->op_parameter_.thread_num_) { + if (input[where_param_->num_ > 1 ? i : 0] == true) { + output[i] = input1[where_param_->num1_ > 1 ? i : 0]; + } else { + output[i] = input2[where_param_->num2_ > 1 ? i : 0]; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/where.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/where.h new file mode 100644 index 0000000000..73145fd9ec --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/where.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WHERE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WHERE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +struct WhereParameter { + OpParameter op_parameter_; + int num_; + int num1_; + int num2_; + int number_; + int thread_num_; +}; + +void Where(bool *input, float *input1, float *input2, float *output, WhereParameter *where_param_, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WHERE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.cc new file mode 100644 index 0000000000..23b40f6639 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.cc @@ -0,0 +1,1370 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/winograd_transform.h" + +// fp32 conv winograd +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * ic4 * C4NUM; + int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * ic4 * C4NUM; + int dst_x_offset = dst_y_offset + j * C4NUM; + float *src_addr = (float *)(input_data) + src_x_offset; + float *dst_addr = tmp_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f32(dst_addr, vld1q_f32(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } + } + // input transform + int dst_ic4_offset = dst_plane_offset + ic * TILE_NUM * C4NUM; + size_t dst_step = ic4 * C4NUM * TILE_NUM; + float *trans_input_ptr = trans_input + dst_ic4_offset; + input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, ConvParameter *conv_param, + OutputTransformUnitFunc output_trans_func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_unit_block = UP_DIV(output_w, output_unit); + int output_channel = conv_param->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int input_unit = conv_param->input_unit_; + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = + dst_tile_offset + j * C4NUM * output_unit_block * output_unit_block * output_unit * output_unit; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = tmp_out_data + dst_oc4_offset; + output_trans_func(src_ptr, dst_ptr, bias_ptr, C4NUM, output_unit_block * output_unit); + } + out_tile_index++; + } +} + +// fp32 conv3x3 +void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step) { +#ifdef ENABLE_ARM + float32x4_t d00 = vld1q_f32(tmp_data); + float32x4_t d01 = vld1q_f32(tmp_data + 4); + float32x4_t d02 = vld1q_f32(tmp_data + 2 * 4); + float32x4_t d03 = vld1q_f32(tmp_data + 3 * 4); + + float32x4_t d10 = vld1q_f32(tmp_data + 4 * 4); + float32x4_t d11 = vld1q_f32(tmp_data + 5 * 4); + float32x4_t d12 = vld1q_f32(tmp_data + 6 * 4); + float32x4_t d13 = vld1q_f32(tmp_data + 7 * 4); + + float32x4_t d20 = vld1q_f32(tmp_data + 8 * 4); + float32x4_t d21 = vld1q_f32(tmp_data + 9 * 4); + float32x4_t d22 = vld1q_f32(tmp_data + 10 * 4); + float32x4_t d23 = vld1q_f32(tmp_data + 11 * 4); + + float32x4_t d30 = vld1q_f32(tmp_data + 12 * 4); + float32x4_t d31 = vld1q_f32(tmp_data + 13 * 4); + float32x4_t d32 = vld1q_f32(tmp_data + 14 * 4); + float32x4_t d33 = vld1q_f32(tmp_data + 15 * 4); + + float32x4_t t00 = vsubq_f32(d00, d20); + float32x4_t t01 = vsubq_f32(d01, d21); + float32x4_t t02 = vsubq_f32(d02, d22); + float32x4_t t03 = vsubq_f32(d03, d23); + + float32x4_t t10 = vaddq_f32(d10, d20); + float32x4_t t11 = vaddq_f32(d11, d21); + float32x4_t t12 = vaddq_f32(d12, d22); + float32x4_t t13 = vaddq_f32(d13, d23); + + float32x4_t t20 = vsubq_f32(d20, d10); + float32x4_t t21 = vsubq_f32(d21, d11); + float32x4_t t22 = vsubq_f32(d22, d12); + float32x4_t t23 = vsubq_f32(d23, d13); + + float32x4_t t30 = vsubq_f32(d10, d30); + float32x4_t t31 = vsubq_f32(d11, d31); + float32x4_t t32 = vsubq_f32(d12, d32); + float32x4_t t33 = vsubq_f32(d13, d33); + + float32x4_t m00 = vsubq_f32(t00, t02); + float32x4_t m01 = vaddq_f32(t01, t02); + float32x4_t m02 = vsubq_f32(t02, t01); + float32x4_t m03 = vsubq_f32(t01, t03); + + float32x4_t m10 = vsubq_f32(t10, t12); + float32x4_t m11 = vaddq_f32(t11, t12); + float32x4_t m12 = vsubq_f32(t12, t11); + float32x4_t m13 = vsubq_f32(t11, t13); + + float32x4_t m20 = vsubq_f32(t20, t22); + float32x4_t m21 = vaddq_f32(t21, t22); + float32x4_t m22 = vsubq_f32(t22, t21); + float32x4_t m23 = vsubq_f32(t21, t23); + + float32x4_t m30 = vsubq_f32(t30, t32); + float32x4_t m31 = vaddq_f32(t31, t32); + float32x4_t m32 = vsubq_f32(t32, t31); + float32x4_t m33 = vsubq_f32(t31, t33); + + vst1q_f32(trans_input_data, m00); + vst1q_f32(trans_input_data + step, m01); + vst1q_f32(trans_input_data + 2 * step, m02); + vst1q_f32(trans_input_data + 3 * step, m03); + + vst1q_f32(trans_input_data + 4 * step, m10); + vst1q_f32(trans_input_data + 5 * step, m11); + vst1q_f32(trans_input_data + 6 * step, m12); + vst1q_f32(trans_input_data + 7 * step, m13); + + vst1q_f32(trans_input_data + 8 * step, m20); + vst1q_f32(trans_input_data + 9 * step, m21); + vst1q_f32(trans_input_data + 10 * step, m22); + vst1q_f32(trans_input_data + 11 * step, m23); + + vst1q_f32(trans_input_data + 12 * step, m30); + vst1q_f32(trans_input_data + 13 * step, m31); + vst1q_f32(trans_input_data + 14 * step, m32); + vst1q_f32(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C4NUM; i++) { + const float *local_ptr = tmp_data + i; + float d00 = local_ptr[0]; + float d01 = (local_ptr + C4NUM)[0]; + float d02 = (local_ptr + 2 * C4NUM)[0]; + float d03 = (local_ptr + 3 * C4NUM)[0]; + + float d10 = (local_ptr + 4 * C4NUM)[0]; + float d11 = (local_ptr + 5 * C4NUM)[0]; + float d12 = (local_ptr + 6 * C4NUM)[0]; + float d13 = (local_ptr + 7 * C4NUM)[0]; + + float d20 = (local_ptr + 8 * C4NUM)[0]; + float d21 = (local_ptr + 9 * C4NUM)[0]; + float d22 = (local_ptr + 10 * C4NUM)[0]; + float d23 = (local_ptr + 11 * C4NUM)[0]; + + float d30 = (local_ptr + 12 * C4NUM)[0]; + float d31 = (local_ptr + 13 * C4NUM)[0]; + float d32 = (local_ptr + 14 * C4NUM)[0]; + float d33 = (local_ptr + 15 * C4NUM)[0]; + + float t00 = d00 - d20; + float t01 = d01 - d21; + float t02 = d02 - d22; + float t03 = d03 - d23; + + float t10 = d10 + d20; + float t11 = d11 + d21; + float t12 = d12 + d22; + float t13 = d13 + d23; + + float t20 = d20 - d10; + float t21 = d21 - d11; + float t22 = d22 - d12; + float t23 = d23 - d13; + + float t30 = d10 - d30; + float t31 = d11 - d31; + float t32 = d12 - d32; + float t33 = d13 - d33; + + float m00 = t00 - t02; + float m01 = t01 + t02; + float m02 = t02 - t01; + float m03 = t01 - t03; + + float m10 = t10 - t12; + float m11 = t11 + t12; + float m12 = t12 - t11; + float m13 = t11 - t13; + + float m20 = t20 - t22; + float m21 = t21 + t22; + float m22 = t22 - t21; + float m23 = t21 - t23; + + float m30 = t30 - t32; + float m31 = t31 + t32; + float m32 = t32 - t31; + float m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + int ic4 = UP_DIV(input_channel, C4NUM); + int input_unit = 4; + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); + + int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = real_y_start; interval < real_y_end; interval++) { + int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM; + int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM; + for (int j = 0; j < (real_x_end - real_x_start); j++) { + int src_x_offset = src_y_offset + j * ic4 * C4NUM; + int dst_x_offset = dst_y_offset + j * C4NUM; + float *src_addr = (float *)(input_data) + src_x_offset; + float *dst_addr = tmp_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f32(dst_addr, vld1q_f32(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + (dst_addr + k)[0] = (src_addr + k)[0]; + } +#endif + } + } + + // input transform + int dst_ic4_offset = dst_plane_offset + ic * TILE_NUM * C4NUM; + size_t dst_step = ic4 * C4NUM * TILE_NUM; + float *trans_input_ptr = trans_input + dst_ic4_offset; + Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step); + } + } +} + +void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane, + int oc_block) { + int input_unit = 4; + int dst_step = iC4 * C4NUM * oc_block; + for (int o = 0; o < output_channel; o++) { + int oc_block_num = o / oc_block; + int oc_block_rem = o % oc_block; + int src_oc_offset = o * iC4 * C4NUM * kernel_plane; + int dst_oc_offset = oc_block_num * oc_block * iC4 * C4NUM * input_unit * input_unit + oc_block_rem; + for (int i = 0; i < iC4; i++) { + float *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; + float *dst_ic4_ptr = trans_weight + dst_oc_offset + i * oc_block * C4NUM; +#ifdef ENABLE_ARM + float32x4_t g00 = vld1q_f32(src_ic4_ptr); + float32x4_t g01 = vld1q_f32(src_ic4_ptr + 4); + float32x4_t g02 = vld1q_f32(src_ic4_ptr + 2 * 4); + float32x4_t g10 = vld1q_f32(src_ic4_ptr + 3 * 4); + float32x4_t g11 = vld1q_f32(src_ic4_ptr + 4 * 4); + float32x4_t g12 = vld1q_f32(src_ic4_ptr + 5 * 4); + float32x4_t g20 = vld1q_f32(src_ic4_ptr + 6 * 4); + float32x4_t g21 = vld1q_f32(src_ic4_ptr + 7 * 4); + float32x4_t g22 = vld1q_f32(src_ic4_ptr + 8 * 4); + + float32x4_t dst00 = g00; + float32x4_t dst01 = g01; + float32x4_t dst02 = g02; + + float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5)); + float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5)); + float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5)); + float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5)); + float32x4_t dst22 = vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst30 = g20; + float32x4_t dst31 = g21; + float32x4_t dst32 = g22; + + float32x4_t m00 = dst00; + float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5)); + float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5)); + float32x4_t m03 = dst02; + + float32x4_t m10 = dst10; + float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5)); + float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5)); + float32x4_t m13 = dst12; + + float32x4_t m20 = dst20; + float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5)); + float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5)); + float32x4_t m23 = dst22; + + float32x4_t m30 = dst30; + float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5)); + float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5)); + float32x4_t m33 = dst32; + + dst_ic4_ptr[0] = m00[0]; + dst_ic4_ptr[8] = m00[1]; + dst_ic4_ptr[16] = m00[2]; + dst_ic4_ptr[24] = m00[3]; + + dst_ic4_ptr[0 + dst_step] = m01[0]; + dst_ic4_ptr[8 + dst_step] = m01[1]; + dst_ic4_ptr[16 + dst_step] = m01[2]; + dst_ic4_ptr[24 + dst_step] = m01[3]; + + dst_ic4_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic4_ptr[8 + 2 * dst_step] = m02[1]; + dst_ic4_ptr[16 + 2 * dst_step] = m02[2]; + dst_ic4_ptr[24 + 2 * dst_step] = m02[3]; + + dst_ic4_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic4_ptr[8 + 3 * dst_step] = m03[1]; + dst_ic4_ptr[16 + 3 * dst_step] = m03[2]; + dst_ic4_ptr[24 + 3 * dst_step] = m03[3]; + + dst_ic4_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic4_ptr[8 + 4 * dst_step] = m10[1]; + dst_ic4_ptr[16 + 4 * dst_step] = m10[2]; + dst_ic4_ptr[24 + 4 * dst_step] = m10[3]; + + dst_ic4_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic4_ptr[8 + 5 * dst_step] = m11[1]; + dst_ic4_ptr[16 + 5 * dst_step] = m11[2]; + dst_ic4_ptr[24 + 5 * dst_step] = m11[3]; + + dst_ic4_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic4_ptr[8 + 6 * dst_step] = m12[1]; + dst_ic4_ptr[16 + 6 * dst_step] = m12[2]; + dst_ic4_ptr[24 + 6 * dst_step] = m12[3]; + + dst_ic4_ptr[0 + 7 * dst_step] = m13[0]; + dst_ic4_ptr[8 + 7 * dst_step] = m13[1]; + dst_ic4_ptr[16 + 7 * dst_step] = m13[2]; + dst_ic4_ptr[24 + 7 * dst_step] = m13[3]; + + dst_ic4_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic4_ptr[8 + 8 * dst_step] = m20[1]; + dst_ic4_ptr[16 + 8 * dst_step] = m20[2]; + dst_ic4_ptr[24 + 8 * dst_step] = m20[3]; + + dst_ic4_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic4_ptr[8 + 9 * dst_step] = m21[1]; + dst_ic4_ptr[16 + 9 * dst_step] = m21[2]; + dst_ic4_ptr[24 + 9 * dst_step] = m21[3]; + + dst_ic4_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic4_ptr[8 + 10 * dst_step] = m22[1]; + dst_ic4_ptr[16 + 10 * dst_step] = m22[2]; + dst_ic4_ptr[24 + 10 * dst_step] = m22[3]; + + dst_ic4_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic4_ptr[8 + 11 * dst_step] = m23[1]; + dst_ic4_ptr[16 + 11 * dst_step] = m23[2]; + dst_ic4_ptr[24 + 11 * dst_step] = m23[3]; + + dst_ic4_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic4_ptr[8 + 12 * dst_step] = m30[1]; + dst_ic4_ptr[16 + 12 * dst_step] = m30[2]; + dst_ic4_ptr[24 + 12 * dst_step] = m30[3]; + + dst_ic4_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic4_ptr[8 + 13 * dst_step] = m31[1]; + dst_ic4_ptr[16 + 13 * dst_step] = m31[2]; + dst_ic4_ptr[24 + 13 * dst_step] = m31[3]; + + dst_ic4_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic4_ptr[8 + 14 * dst_step] = m32[1]; + dst_ic4_ptr[16 + 14 * dst_step] = m32[2]; + dst_ic4_ptr[24 + 14 * dst_step] = m32[3]; + + dst_ic4_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic4_ptr[8 + 15 * dst_step] = m33[1]; + dst_ic4_ptr[16 + 15 * dst_step] = m33[2]; + dst_ic4_ptr[24 + 15 * dst_step] = m33[3]; +#else + for (int j = 0; j < C4NUM; j++) { + float *local_ptr = src_ic4_ptr + j; + float dst00 = local_ptr[0]; + float dst01 = (local_ptr + 4)[0]; + float dst02 = (local_ptr + 8)[0]; + + float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst30 = (local_ptr + 24)[0]; + float dst31 = (local_ptr + 28)[0]; + float dst32 = (local_ptr + 32)[0]; + + float m00 = dst00; + float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02; + float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02; + float m03 = dst02; + + float m10 = dst10; + float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12; + float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12; + float m13 = dst12; + + float m20 = dst20; + float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22; + float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22; + float m23 = dst22; + + float m30 = dst30; + float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32; + float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32; + float m33 = dst32; + + *(dst_ic4_ptr + j * 8) = m00; + *(dst_ic4_ptr + j * 8 + dst_step) = m01; + *(dst_ic4_ptr + j * 8 + 2 * dst_step) = m02; + *(dst_ic4_ptr + j * 8 + 3 * dst_step) = m03; + + *(dst_ic4_ptr + j * 8 + 4 * dst_step) = m10; + *(dst_ic4_ptr + j * 8 + 5 * dst_step) = m11; + *(dst_ic4_ptr + j * 8 + 6 * dst_step) = m12; + *(dst_ic4_ptr + j * 8 + 7 * dst_step) = m13; + + *(dst_ic4_ptr + j * 8 + 8 * dst_step) = m20; + *(dst_ic4_ptr + j * 8 + 9 * dst_step) = m21; + *(dst_ic4_ptr + j * 8 + 10 * dst_step) = m22; + *(dst_ic4_ptr + j * 8 + 11 * dst_step) = m23; + + *(dst_ic4_ptr + j * 8 + 12 * dst_step) = m30; + *(dst_ic4_ptr + j * 8 + 13 * dst_step) = m31; + *(dst_ic4_ptr + j * 8 + 14 * dst_step) = m32; + *(dst_ic4_ptr + j * 8 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound, + bool w_not_bound, int output_w) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias_data); + + float32x4_t s00 = vld1q_f32(gemm_out); + float32x4_t s01 = vld1q_f32(gemm_out + 4); + float32x4_t s02 = vld1q_f32(gemm_out + 8); + float32x4_t s03 = vld1q_f32(gemm_out + 12); + + float32x4_t s10 = vld1q_f32(gemm_out + 16); + float32x4_t s11 = vld1q_f32(gemm_out + 20); + float32x4_t s12 = vld1q_f32(gemm_out + 24); + float32x4_t s13 = vld1q_f32(gemm_out + 28); + + float32x4_t s20 = vld1q_f32(gemm_out + 32); + float32x4_t s21 = vld1q_f32(gemm_out + 36); + float32x4_t s22 = vld1q_f32(gemm_out + 40); + float32x4_t s23 = vld1q_f32(gemm_out + 44); + + float32x4_t s30 = vld1q_f32(gemm_out + 48); + float32x4_t s31 = vld1q_f32(gemm_out + 52); + float32x4_t s32 = vld1q_f32(gemm_out + 56); + float32x4_t s33 = vld1q_f32(gemm_out + 60); + + float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20); + float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21); + float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22); + float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23); + + float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30); + float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31); + float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32); + float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33); + + float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr); + float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr); + float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr); + float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr); + + vst1q_f32(output_data, d00); + if (w_not_bound) { + vst1q_f32(output_data + 4, d01); + } + if (h_not_bound) { + vst1q_f32(output_data + output_w * 4, d10); + if (w_not_bound) { + vst1q_f32(output_data + output_w * 4 + 4, d11); + } + } +#else + for (int i = 0; i < C4NUM; i++) { + const float *local_ptr = gemm_out + i; + const float *bias_ptr = bias_data + i; + + float s00 = local_ptr[0]; + float s01 = (local_ptr + 4)[0]; + float s02 = (local_ptr + 8)[0]; + float s03 = (local_ptr + 12)[0]; + + float s10 = (local_ptr + 16)[0]; + float s11 = (local_ptr + 20)[0]; + float s12 = (local_ptr + 24)[0]; + float s13 = (local_ptr + 28)[0]; + + float s20 = (local_ptr + 32)[0]; + float s21 = (local_ptr + 36)[0]; + float s22 = (local_ptr + 40)[0]; + float s23 = (local_ptr + 44)[0]; + + float s30 = (local_ptr + 48)[0]; + float s31 = (local_ptr + 52)[0]; + float s32 = (local_ptr + 56)[0]; + float s33 = (local_ptr + 60)[0]; + + float t00 = s00 + s10 + s20; + float t01 = s01 + s11 + s21; + float t02 = s02 + s12 + s22; + float t03 = s03 + s13 + s23; + + float t10 = s10 - s20 - s30; + float t11 = s11 - s21 - s31; + float t12 = s12 - s22 - s32; + float t13 = s13 - s23 - s33; + + float d00 = t00 + t01 + t02 + bias_ptr[0]; + float d01 = t01 - t02 - t03 + bias_ptr[0]; + float d10 = t10 + t11 + t12 + bias_ptr[0]; + float d11 = t11 - t12 - t13 + bias_ptr[0]; + + (output_data + i)[0] = d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = d11; + } + } + } +#endif +} + +void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int input_unit = 4; + + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + + // output transform + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Fp32OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w); + } + } +} + +// int8 conv3x3 +void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { +#ifdef ENABLE_ARM + int16x8_t zp = vdupq_n_s16(input_zp); + + int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); + int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); + int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); + int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); + + int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); + int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); + int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); + int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); + + int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); + int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); + int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); + int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); + + int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); + int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); + int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); + int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); + + int16x8_t t00 = vsubq_s16(d00, d20); + int16x8_t t01 = vsubq_s16(d01, d21); + int16x8_t t02 = vsubq_s16(d02, d22); + int16x8_t t03 = vsubq_s16(d03, d23); + + int16x8_t t10 = vaddq_s16(d10, d20); + int16x8_t t11 = vaddq_s16(d11, d21); + int16x8_t t12 = vaddq_s16(d12, d22); + int16x8_t t13 = vaddq_s16(d13, d23); + + int16x8_t t20 = vsubq_s16(d20, d10); + int16x8_t t21 = vsubq_s16(d21, d11); + int16x8_t t22 = vsubq_s16(d22, d12); + int16x8_t t23 = vsubq_s16(d23, d13); + + int16x8_t t30 = vsubq_s16(d10, d30); + int16x8_t t31 = vsubq_s16(d11, d31); + int16x8_t t32 = vsubq_s16(d12, d32); + int16x8_t t33 = vsubq_s16(d13, d33); + + int16x8_t m00 = vsubq_s16(t00, t02); + int16x8_t m01 = vaddq_s16(t01, t02); + int16x8_t m02 = vsubq_s16(t02, t01); + int16x8_t m03 = vsubq_s16(t01, t03); + + int16x8_t m10 = vsubq_s16(t10, t12); + int16x8_t m11 = vaddq_s16(t11, t12); + int16x8_t m12 = vsubq_s16(t12, t11); + int16x8_t m13 = vsubq_s16(t11, t13); + + int16x8_t m20 = vsubq_s16(t20, t22); + int16x8_t m21 = vaddq_s16(t21, t22); + int16x8_t m22 = vsubq_s16(t22, t21); + int16x8_t m23 = vsubq_s16(t21, t23); + + int16x8_t m30 = vsubq_s16(t30, t32); + int16x8_t m31 = vaddq_s16(t31, t32); + int16x8_t m32 = vsubq_s16(t32, t31); + int16x8_t m33 = vsubq_s16(t31, t33); + + vst1q_s16(trans_input_data, m00); + vst1q_s16(trans_input_data + step, m01); + vst1q_s16(trans_input_data + 2 * step, m02); + vst1q_s16(trans_input_data + 3 * step, m03); + + vst1q_s16(trans_input_data + 4 * step, m10); + vst1q_s16(trans_input_data + 5 * step, m11); + vst1q_s16(trans_input_data + 6 * step, m12); + vst1q_s16(trans_input_data + 7 * step, m13); + + vst1q_s16(trans_input_data + 8 * step, m20); + vst1q_s16(trans_input_data + 9 * step, m21); + vst1q_s16(trans_input_data + 10 * step, m22); + vst1q_s16(trans_input_data + 11 * step, m23); + + vst1q_s16(trans_input_data + 12 * step, m30); + vst1q_s16(trans_input_data + 13 * step, m31); + vst1q_s16(trans_input_data + 14 * step, m32); + vst1q_s16(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C8NUM; i++) { + int16_t *local_ptr = tmp_data + i; + int16_t d00 = local_ptr[0] - input_zp; + int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; + int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; + int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; + + int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp; + int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; + int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; + int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; + + int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; + int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; + int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; + int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; + + int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; + int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; + int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; + int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; + + int16_t t00 = d00 - d20; + int16_t t01 = d01 - d21; + int16_t t02 = d02 - d22; + int16_t t03 = d03 - d23; + + int16_t t10 = d10 + d20; + int16_t t11 = d11 + d21; + int16_t t12 = d12 + d22; + int16_t t13 = d13 + d23; + + int16_t t20 = d20 - d10; + int16_t t21 = d21 - d11; + int16_t t22 = d22 - d12; + int16_t t23 = d23 - d13; + + int16_t t30 = d10 - d30; + int16_t t31 = d11 - d31; + int16_t t32 = d12 - d32; + int16_t t33 = d13 - d33; + + int16_t m00 = t00 - t02; + int16_t m01 = t01 + t02; + int16_t m02 = t02 - t01; + int16_t m03 = t01 - t03; + + int16_t m10 = t10 - t12; + int16_t m11 = t11 + t12; + int16_t m12 = t12 - t11; + int16_t m13 = t11 - t13; + + int16_t m20 = t20 - t22; + int16_t m21 = t21 + t22; + int16_t m22 = t22 - t21; + int16_t m23 = t21 - t23; + + int16_t m30 = t30 - t32; + int16_t m31 = t31 + t32; + int16_t m32 = t32 - t31; + int16_t m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + ConvQuantArg quant_arg = conv_param->conv_quant_arg_; + int input_zp = quant_arg.quant_args_[0][0].zp_; + int ic8 = UP_DIV(input_channel, C8NUM); + int input_unit = 4; + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); + + int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + // copy data from origin input to tmp buffer + for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; + + int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; + for (int j = real_y_start; j < real_y_end; j++) { + const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); + int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); + memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); + } + // input transform + int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; + size_t dst_step = ic8 * C8NUM * TILE_NUM; + int16_t *trans_input_ptr = trans_input + dst_ic8_offset; + Conv3x3Uint8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); + } + } +} + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane) { + int input_unit = 4; + int dst_step = iC8 * C8NUM * C4NUM; + for (int o = 0; o < output_channel; o++) { + int oc4_block_num = o / C4NUM; + int oc4_block_rem = o % C4NUM; + int src_oc_offset = o * iC8 * C8NUM * kernel_plane; + int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; + for (int i = 0; i < iC8; i++) { + auto src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; + auto dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; +#ifdef ENABLE_ARM + int16x8_t g00 = vld1q_s16(src_ic8_ptr); + int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); + int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); + int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); + int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); + int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); + int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); + int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); + int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); + + int16x8_t dst00 = vmulq_n_s16(g00, 2); + int16x8_t dst01 = vmulq_n_s16(g01, 2); + int16x8_t dst02 = vmulq_n_s16(g02, 2); + + int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20); + int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); + int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); + + int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); + int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); + int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); + + int16x8_t dst30 = vmulq_n_s16(g20, 2); + int16x8_t dst31 = vmulq_n_s16(g21, 2); + int16x8_t dst32 = vmulq_n_s16(g22, 2); + + int16x8_t m00 = vmulq_n_s16(dst00, 2); + int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); + int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); + int16x8_t m03 = vmulq_n_s16(dst02, 2); + + int16x8_t m10 = vmulq_n_s16(dst10, 2); + int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); + int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); + int16x8_t m13 = vmulq_n_s16(dst12, 2); + + int16x8_t m20 = vmulq_n_s16(dst20, 2); + int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); + int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); + int16x8_t m23 = vmulq_n_s16(dst22, 2); + + int16x8_t m30 = vmulq_n_s16(dst30, 2); + int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); + int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); + int16x8_t m33 = vmulq_n_s16(dst32, 2); + + dst_ic8_ptr[0] = m00[0]; + dst_ic8_ptr[4] = m00[1]; + dst_ic8_ptr[8] = m00[2]; + dst_ic8_ptr[12] = m00[3]; + dst_ic8_ptr[16] = m00[4]; + dst_ic8_ptr[20] = m00[5]; + dst_ic8_ptr[24] = m00[6]; + dst_ic8_ptr[28] = m00[7]; + + dst_ic8_ptr[0 + dst_step] = m01[0]; + dst_ic8_ptr[4 + dst_step] = m01[1]; + dst_ic8_ptr[8 + dst_step] = m01[2]; + dst_ic8_ptr[12 + dst_step] = m01[3]; + dst_ic8_ptr[16 + dst_step] = m01[4]; + dst_ic8_ptr[20 + dst_step] = m01[5]; + dst_ic8_ptr[24 + dst_step] = m01[6]; + dst_ic8_ptr[28 + dst_step] = m01[7]; + + dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; + dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; + dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; + dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; + dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; + dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; + dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; + + dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; + dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; + dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; + dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; + dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; + dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; + dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; + + dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; + dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; + dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; + dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; + dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; + dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; + dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; + + dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; + dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; + dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; + dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; + dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; + dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; + dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; + + dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; + dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; + dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; + dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; + dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; + dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; + dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; + + dst_ic8_ptr[0 + 7 * dst_step] = m13[0]; + dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; + dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; + dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; + dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; + dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; + dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; + dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; + + dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; + dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; + dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; + dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; + dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; + dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; + dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; + + dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; + dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; + dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; + dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; + dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; + dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; + dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; + + dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; + dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; + dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; + dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; + dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; + dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; + dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; + + dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; + dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; + dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; + dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; + dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; + dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; + dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; + + dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; + dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; + dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; + dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; + dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; + dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; + dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; + + dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; + dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; + dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; + dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; + dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; + dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; + dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; + + dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; + dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; + dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; + dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; + dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; + dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; + dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; + + dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; + dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; + dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; + dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; + dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; + dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; + dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; +#else + for (int j = 0; j < C8NUM; j++) { + auto local_ptr = src_ic8_ptr + j; + int16_t dst00 = local_ptr[0] * 2; + int16_t dst01 = (local_ptr + 8)[0] * 2; + int16_t dst02 = (local_ptr + 16)[0] * 2; + + int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst30 = (local_ptr + 48)[0] * 2; + int16_t dst31 = (local_ptr + 56)[0] * 2; + int16_t dst32 = (local_ptr + 64)[0] * 2; + + int16_t m00 = dst00 * 2; + int16_t m01 = dst00 + dst01 + dst02; + int16_t m02 = dst00 - dst01 + dst02; + int16_t m03 = dst02 * 2; + + int16_t m10 = dst10 * 2; + int16_t m11 = dst10 + dst11 + dst12; + int16_t m12 = dst10 - dst11 + dst12; + int16_t m13 = dst12 * 2; + + int16_t m20 = dst20 * 2; + int16_t m21 = dst20 + dst21 + dst22; + int16_t m22 = dst20 - dst21 + dst22; + int16_t m23 = dst22 * 2; + + int16_t m30 = dst30 * 2; + int16_t m31 = dst30 + dst31 + dst32; + int16_t m32 = dst30 - dst31 + dst32; + int16_t m33 = dst32 * 2; + + *(dst_ic8_ptr + j * 4) = m00; + *(dst_ic8_ptr + j * 4 + dst_step) = m01; + *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; + *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; + + *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; + *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; + *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; + *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; + + *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; + *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; + *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; + *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; + + *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; + *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; + *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; + *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param) { + int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; + int right_shift = conv_param->conv_quant_arg_.right_shift_[0]; + int quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; + int output_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; + +#ifdef ENABLE_ARM + int32x4_t bias_ptr = vld1q_s32(bias_data); + + int32x4_t s00 = vld1q_s32(gemm_out); + int32x4_t s01 = vld1q_s32(gemm_out + 4); + int32x4_t s02 = vld1q_s32(gemm_out + 8); + int32x4_t s03 = vld1q_s32(gemm_out + 12); + + int32x4_t s10 = vld1q_s32(gemm_out + 16); + int32x4_t s11 = vld1q_s32(gemm_out + 20); + int32x4_t s12 = vld1q_s32(gemm_out + 24); + int32x4_t s13 = vld1q_s32(gemm_out + 28); + + int32x4_t s20 = vld1q_s32(gemm_out + 32); + int32x4_t s21 = vld1q_s32(gemm_out + 36); + int32x4_t s22 = vld1q_s32(gemm_out + 40); + int32x4_t s23 = vld1q_s32(gemm_out + 44); + + int32x4_t s30 = vld1q_s32(gemm_out + 48); + int32x4_t s31 = vld1q_s32(gemm_out + 52); + int32x4_t s32 = vld1q_s32(gemm_out + 56); + int32x4_t s33 = vld1q_s32(gemm_out + 60); + + int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); + int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); + int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); + int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); + + int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); + int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); + int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); + int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); + + int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); + int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); + + int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); + int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); + + int32x4_t out_multiplier = vdupq_n_s32(quant_multiplier); + int32x4_t out_zp = vdupq_n_s32(output_zp); + int32x4_t output_min = vdupq_n_s32(out_min); + int32x4_t output_max = vdupq_n_s32(out_max); + int32x4_t ls = vdupq_n_s32(left_shift); + int32x4_t rs = vdupq_n_s32(right_shift); + + d00 = vqshlq_s32(d00, ls); + d00 = vqrdmulhq_s32(d00, out_multiplier); + d00 = vqrshlq_s32(d00, rs); + d00 = vaddq_s32(d00, out_zp); + d00 = vmaxq_s32(d00, output_min); + d00 = vminq_s32(d00, output_max); + + d01 = vqshlq_s32(d01, ls); + d01 = vqrdmulhq_s32(d01, out_multiplier); + d01 = vqrshlq_s32(d01, rs); + d01 = vaddq_s32(d01, out_zp); + d01 = vmaxq_s32(d01, output_min); + d01 = vminq_s32(d01, output_max); + + d10 = vqshlq_s32(d10, ls); + d10 = vqrdmulhq_s32(d10, out_multiplier); + d10 = vqrshlq_s32(d10, rs); + d10 = vaddq_s32(d10, out_zp); + d10 = vmaxq_s32(d10, output_min); + d10 = vminq_s32(d10, output_max); + + d11 = vqshlq_s32(d11, ls); + d11 = vqrdmulhq_s32(d11, out_multiplier); + d11 = vqrshlq_s32(d11, rs); + d11 = vaddq_s32(d11, out_zp); + d11 = vmaxq_s32(d11, output_min); + d11 = vminq_s32(d11, output_max); + + (output_data)[0] = (uint8_t)d00[0]; + (output_data + 1)[0] = (uint8_t)d00[1]; + (output_data + 2)[0] = (uint8_t)d00[2]; + (output_data + 3)[0] = (uint8_t)d00[3]; + + if (w_not_bound) { + *(output_data + 4) = (uint8_t)d01[0]; + *(output_data + 5) = (uint8_t)d01[1]; + *(output_data + 6) = (uint8_t)d01[2]; + *(output_data + 7) = (uint8_t)d01[3]; + } + if (h_not_bound) { + *(output_data + output_w * 4) = (uint8_t)d10[0]; + *(output_data + output_w * 4 + 1) = (uint8_t)d10[1]; + *(output_data + output_w * 4 + 2) = (uint8_t)d10[2]; + *(output_data + output_w * 4 + 3) = (uint8_t)d10[3]; + if (w_not_bound) { + *(output_data + output_w * 4 + 4) = (uint8_t)d11[0]; + *(output_data + output_w * 4 + 5) = (uint8_t)d11[1]; + *(output_data + output_w * 4 + 6) = (uint8_t)d11[2]; + *(output_data + output_w * 4 + 7) = (uint8_t)d11[3]; + } + } +#else + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } +#endif +} + +void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int input_unit = 4; + + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const int32_t *src_ptr = gemm_out + src_oc4_offset; + const int32_t *bias_ptr = bias_data + j * C4NUM; + int8_t *dst_ptr = out_data + dst_oc4_offset; + + // output transform + int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, conv_param); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h new file mode 100644 index 0000000000..e31e2e1930 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_transform.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_TRANSFORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_TRANSFORM_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv.h" +#include "src/runtime/kernel/arm/nnacl/winograd_utils.h" +#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h" + +#define OUPUT_UNIT 2 + +// for fp32 winograd input/output transform +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func); + +void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, ConvParameter *conv_param, + OutputTransformUnitFunc output_trans_func); + +// for fp32 convolution 3x3 filter/input/output transform +void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step); + +void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane, + int oc_block); + +void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound, + bool w_not_bound, int output_w); + +void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +// for int8 convolution 3x3 filter/input/output transform +void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp); + +void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param); + +void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_TRANSFORM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.cc new file mode 100644 index 0000000000..0c26437018 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.cc @@ -0,0 +1,4710 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/nnacl/winograd_utils.h" +#include + +#define MIN_UNIT 2 +#define MAX_UNIT 8 + +static OutputTransformUnitFunc outputTransformUnit[] = { + nullptr, // 0 + nullptr, // 1 + OutputTransform8x2Unit, + OutputTransform8x3Unit, + OutputTransform8x4Unit, + OutputTransform8x5Unit, + OutputTransform8x6Unit, + OutputTransform8x7Unit, +}; + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 15 * src_step); + + float32x4_t t00 = vsubq_f32(src_data_00, vmulq_n_f32(src_data_20, 4)); + float32x4_t t01 = vsubq_f32(src_data_01, vmulq_n_f32(src_data_21, 4)); + float32x4_t t02 = vsubq_f32(src_data_02, vmulq_n_f32(src_data_22, 4)); + float32x4_t t03 = vsubq_f32(src_data_03, vmulq_n_f32(src_data_23, 4)); + + float32x4_t t10 = vaddq_f32(src_data_10, vmulq_n_f32(src_data_20, 2)); + float32x4_t t11 = vaddq_f32(src_data_11, vmulq_n_f32(src_data_21, 2)); + float32x4_t t12 = vaddq_f32(src_data_12, vmulq_n_f32(src_data_22, 2)); + float32x4_t t13 = vaddq_f32(src_data_13, vmulq_n_f32(src_data_23, 2)); + + float32x4_t t20 = vsubq_f32(vmulq_n_f32(src_data_20, 2), src_data_10); + float32x4_t t21 = vsubq_f32(vmulq_n_f32(src_data_21, 2), src_data_11); + float32x4_t t22 = vsubq_f32(vmulq_n_f32(src_data_22, 2), src_data_12); + float32x4_t t23 = vsubq_f32(vmulq_n_f32(src_data_23, 2), src_data_13); + + float32x4_t t30 = vsubq_f32(src_data_30, vmulq_n_f32(src_data_10, 0.25)); + float32x4_t t31 = vsubq_f32(src_data_31, vmulq_n_f32(src_data_11, 0.25)); + float32x4_t t32 = vsubq_f32(src_data_32, vmulq_n_f32(src_data_12, 0.25)); + float32x4_t t33 = vsubq_f32(src_data_33, vmulq_n_f32(src_data_13, 0.25)); + + float32x4_t m00 = vsubq_f32(t00, vmulq_n_f32(t02, 4)); + float32x4_t m01 = vaddq_f32(t01, vmulq_n_f32(t02, 2)); + float32x4_t m02 = vsubq_f32(vmulq_n_f32(t02, 2), t01); + float32x4_t m03 = vsubq_f32(t03, vmulq_n_f32(t01, 0.25)); + + float32x4_t m10 = vsubq_f32(t10, vmulq_n_f32(t12, 4)); + float32x4_t m11 = vaddq_f32(t11, vmulq_n_f32(t12, 2)); + float32x4_t m12 = vsubq_f32(vmulq_n_f32(t12, 2), t11); + float32x4_t m13 = vsubq_f32(t13, vmulq_n_f32(t11, 0.25)); + + float32x4_t m20 = vsubq_f32(t20, vmulq_n_f32(t22, 4)); + float32x4_t m21 = vaddq_f32(t21, vmulq_n_f32(t22, 2)); + float32x4_t m22 = vsubq_f32(vmulq_n_f32(t22, 2), t21); + float32x4_t m23 = vsubq_f32(t23, vmulq_n_f32(t21, 0.25)); + + float32x4_t m30 = vsubq_f32(t30, vmulq_n_f32(t32, 4)); + float32x4_t m31 = vaddq_f32(t31, vmulq_n_f32(t32, 2)); + float32x4_t m32 = vsubq_f32(vmulq_n_f32(t32, 2), t31); + float32x4_t m33 = vsubq_f32(t33, vmulq_n_f32(t31, 0.25)); + + vst1q_f32(dst_data + 0 * dst_step, m00); + vst1q_f32(dst_data + 1 * dst_step, m01); + vst1q_f32(dst_data + 2 * dst_step, m02); + vst1q_f32(dst_data + 3 * dst_step, m03); + vst1q_f32(dst_data + 4 * dst_step, m10); + vst1q_f32(dst_data + 5 * dst_step, m11); + vst1q_f32(dst_data + 6 * dst_step, m12); + vst1q_f32(dst_data + 7 * dst_step, m13); + vst1q_f32(dst_data + 8 * dst_step, m20); + vst1q_f32(dst_data + 9 * dst_step, m21); + vst1q_f32(dst_data + 10 * dst_step, m22); + vst1q_f32(dst_data + 11 * dst_step, m23); + vst1q_f32(dst_data + 12 * dst_step, m30); + vst1q_f32(dst_data + 13 * dst_step, m31); + vst1q_f32(dst_data + 14 * dst_step, m32); + vst1q_f32(dst_data + 15 * dst_step, m33); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_10 = src_data[i + 4 * src_step]; + float src_data_11 = src_data[i + 5 * src_step]; + float src_data_12 = src_data[i + 6 * src_step]; + float src_data_13 = src_data[i + 7 * src_step]; + float src_data_20 = src_data[i + 8 * src_step]; + float src_data_21 = src_data[i + 9 * src_step]; + float src_data_22 = src_data[i + 10 * src_step]; + float src_data_23 = src_data[i + 11 * src_step]; + float src_data_30 = src_data[i + 12 * src_step]; + float src_data_31 = src_data[i + 13 * src_step]; + float src_data_32 = src_data[i + 14 * src_step]; + float src_data_33 = src_data[i + 15 * src_step]; + + float t00 = src_data_00 - 4 * src_data_20; + float t01 = src_data_01 - 4 * src_data_21; + float t02 = src_data_02 - 4 * src_data_22; + float t03 = src_data_03 - 4 * src_data_23; + + float t10 = src_data_10 + 2 * src_data_20; + float t11 = src_data_11 + 2 * src_data_21; + float t12 = src_data_12 + 2 * src_data_22; + float t13 = src_data_13 + 2 * src_data_23; + + float t20 = 2 * src_data_20 - src_data_10; + float t21 = 2 * src_data_21 - src_data_11; + float t22 = 2 * src_data_22 - src_data_12; + float t23 = 2 * src_data_23 - src_data_13; + + float t30 = src_data_30 - 0.25f * src_data_10; + float t31 = src_data_31 - 0.25f * src_data_11; + float t32 = src_data_32 - 0.25f * src_data_12; + float t33 = src_data_33 - 0.25f * src_data_13; + + float m00 = t00 - 4 * t02; + float m01 = t01 + 2 * t02; + float m02 = 2 * t02 - t01; + float m03 = t03 - 0.25f * t01; + + float m10 = t10 - 4 * t12; + float m11 = t11 + 2 * t12; + float m12 = 2 * t12 - t11; + float m13 = t13 - 0.25f * t11; + + float m20 = t20 - 4 * t22; + float m21 = t21 + 2 * t22; + float m22 = 2 * t22 - t21; + float m23 = t23 - 0.25f * t21; + + float m30 = t30 - 4 * t32; + float m31 = t31 + 2 * t32; + float m32 = 2 * t32 - t31; + float m33 = t33 - 0.25f * t31; + + (dst_data + i)[0] = m00; + (dst_data + i + dst_step)[0] = m01; + (dst_data + i + 2 * dst_step)[0] = m02; + (dst_data + i + 3 * dst_step)[0] = m03; + + (dst_data + i + 4 * dst_step)[0] = m10; + (dst_data + i + 5 * dst_step)[0] = m11; + (dst_data + i + 6 * dst_step)[0] = m12; + (dst_data + i + 7 * dst_step)[0] = m13; + + (dst_data + i + 8 * dst_step)[0] = m20; + (dst_data + i + 9 * dst_step)[0] = m21; + (dst_data + i + 10 * dst_step)[0] = m22; + (dst_data + i + 11 * dst_step)[0] = m23; + + (dst_data + i + 12 * dst_step)[0] = m30; + (dst_data + i + 13 * dst_step)[0] = m31; + (dst_data + i + 14 * dst_step)[0] = m32; + (dst_data + i + 15 * dst_step)[0] = m33; + } +#endif +} + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t t00 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_00, vmulq_n_f32(src_data_20, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_40, 6.222222222222)), + vmulq_n_f32(src_data_60, 1.7777777777777)); + float32x4_t t01 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_01, vmulq_n_f32(src_data_21, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_41, 6.222222222222)), + vmulq_n_f32(src_data_61, 1.7777777777777)); + float32x4_t t02 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_02, vmulq_n_f32(src_data_22, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_42, 6.222222222222)), + vmulq_n_f32(src_data_62, 1.7777777777777)); + float32x4_t t03 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_03, vmulq_n_f32(src_data_23, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_43, 6.222222222222)), + vmulq_n_f32(src_data_63, 1.7777777777777)); + float32x4_t t04 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_04, vmulq_n_f32(src_data_24, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_44, 6.222222222222)), + vmulq_n_f32(src_data_64, 1.7777777777777)); + float32x4_t t05 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_05, vmulq_n_f32(src_data_25, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_45, 6.222222222222)), + vmulq_n_f32(src_data_65, 1.7777777777777)); + float32x4_t t06 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_06, vmulq_n_f32(src_data_26, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_46, 6.222222222222)), + vmulq_n_f32(src_data_66, 1.7777777777777)); + float32x4_t t07 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_07, vmulq_n_f32(src_data_27, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_47, 6.222222222222)), + vmulq_n_f32(src_data_67, 1.7777777777777)); + + float32x4_t t10 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_10, 1.5), vmulq_n_f32(src_data_20, 3)), + vmulq_n_f32(src_data_30, 2.166666666666666667)), + vmulq_n_f32(src_data_40, 4.333333333333)), + vmulq_n_f32(src_data_50, 0.66666666666)), + vmulq_n_f32(src_data_60, 1.333333333333)); + float32x4_t t11 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_11, 1.5), vmulq_n_f32(src_data_21, 3)), + vmulq_n_f32(src_data_31, 2.166666666666666667)), + vmulq_n_f32(src_data_41, 4.333333333333)), + vmulq_n_f32(src_data_51, 0.66666666666)), + vmulq_n_f32(src_data_61, 1.333333333333)); + float32x4_t t12 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_12, 1.5), vmulq_n_f32(src_data_22, 3)), + vmulq_n_f32(src_data_32, 2.166666666666666667)), + vmulq_n_f32(src_data_42, 4.333333333333)), + vmulq_n_f32(src_data_52, 0.66666666666)), + vmulq_n_f32(src_data_62, 1.333333333333)); + float32x4_t t13 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_13, 1.5), vmulq_n_f32(src_data_23, 3)), + vmulq_n_f32(src_data_33, 2.166666666666666667)), + vmulq_n_f32(src_data_43, 4.333333333333)), + vmulq_n_f32(src_data_53, 0.66666666666)), + vmulq_n_f32(src_data_63, 1.333333333333)); + float32x4_t t14 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_14, 1.5), vmulq_n_f32(src_data_24, 3)), + vmulq_n_f32(src_data_34, 2.166666666666666667)), + vmulq_n_f32(src_data_44, 4.333333333333)), + vmulq_n_f32(src_data_54, 0.66666666666)), + vmulq_n_f32(src_data_64, 1.333333333333)); + float32x4_t t15 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_15, 1.5), vmulq_n_f32(src_data_25, 3)), + vmulq_n_f32(src_data_35, 2.166666666666666667)), + vmulq_n_f32(src_data_45, 4.333333333333)), + vmulq_n_f32(src_data_55, 0.66666666666)), + vmulq_n_f32(src_data_65, 1.333333333333)); + float32x4_t t16 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_16, 1.5), vmulq_n_f32(src_data_26, 3)), + vmulq_n_f32(src_data_36, 2.166666666666666667)), + vmulq_n_f32(src_data_46, 4.333333333333)), + vmulq_n_f32(src_data_56, 0.66666666666)), + vmulq_n_f32(src_data_66, 1.333333333333)); + float32x4_t t17 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_17, 1.5), vmulq_n_f32(src_data_27, 3)), + vmulq_n_f32(src_data_37, 2.166666666666666667)), + vmulq_n_f32(src_data_47, 4.333333333333)), + vmulq_n_f32(src_data_57, 0.66666666666)), + vmulq_n_f32(src_data_67, 1.333333333333)); + + float32x4_t t20 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_10, -1.5), vmulq_n_f32(src_data_20, 3)), + vmulq_n_f32(src_data_30, 2.166666666666666667)), + vmulq_n_f32(src_data_40, 4.333333333333)), + vmulq_n_f32(src_data_50, 0.66666666666)), + vmulq_n_f32(src_data_60, 1.333333333333)); + float32x4_t t21 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_11, -1.5), vmulq_n_f32(src_data_21, 3)), + vmulq_n_f32(src_data_31, 2.166666666666666667)), + vmulq_n_f32(src_data_41, 4.333333333333)), + vmulq_n_f32(src_data_51, 0.66666666666)), + vmulq_n_f32(src_data_61, 1.333333333333)); + float32x4_t t22 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_12, -1.5), vmulq_n_f32(src_data_22, 3)), + vmulq_n_f32(src_data_32, 2.166666666666666667)), + vmulq_n_f32(src_data_42, 4.333333333333)), + vmulq_n_f32(src_data_52, 0.66666666666)), + vmulq_n_f32(src_data_62, 1.333333333333)); + float32x4_t t23 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_13, -1.5), vmulq_n_f32(src_data_23, 3)), + vmulq_n_f32(src_data_33, 2.166666666666666667)), + vmulq_n_f32(src_data_43, 4.333333333333)), + vmulq_n_f32(src_data_53, 0.66666666666)), + vmulq_n_f32(src_data_63, 1.333333333333)); + float32x4_t t24 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_14, -1.5), vmulq_n_f32(src_data_24, 3)), + vmulq_n_f32(src_data_34, 2.166666666666666667)), + vmulq_n_f32(src_data_44, 4.333333333333)), + vmulq_n_f32(src_data_54, 0.66666666666)), + vmulq_n_f32(src_data_64, 1.333333333333)); + float32x4_t t25 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_15, -1.5), vmulq_n_f32(src_data_25, 3)), + vmulq_n_f32(src_data_35, 2.166666666666666667)), + vmulq_n_f32(src_data_45, 4.333333333333)), + vmulq_n_f32(src_data_55, 0.66666666666)), + vmulq_n_f32(src_data_65, 1.333333333333)); + float32x4_t t26 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_16, -1.5), vmulq_n_f32(src_data_26, 3)), + vmulq_n_f32(src_data_36, 2.166666666666666667)), + vmulq_n_f32(src_data_46, 4.333333333333)), + vmulq_n_f32(src_data_56, 0.66666666666)), + vmulq_n_f32(src_data_66, 1.333333333333)); + float32x4_t t27 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_17, -1.5), vmulq_n_f32(src_data_27, 3)), + vmulq_n_f32(src_data_37, 2.166666666666666667)), + vmulq_n_f32(src_data_47, 4.333333333333)), + vmulq_n_f32(src_data_57, 0.66666666666)), + vmulq_n_f32(src_data_67, 1.333333333333)); + + float32x4_t t30 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_30, src_data_40), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_10, src_data_20), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_50, src_data_60), 0.53333333333)); + float32x4_t t31 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_31, src_data_41), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_11, src_data_21), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_51, src_data_61), 0.53333333333)); + float32x4_t t32 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_32, src_data_42), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_12, src_data_22), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_52, src_data_62), 0.53333333333)); + float32x4_t t33 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_33, src_data_43), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_13, src_data_23), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_53, src_data_63), 0.53333333333)); + float32x4_t t34 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_34, src_data_44), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_14, src_data_24), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_54, src_data_64), 0.53333333333)); + float32x4_t t35 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_35, src_data_45), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_15, src_data_25), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_55, src_data_65), 0.53333333333)); + float32x4_t t36 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_36, src_data_46), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_16, src_data_26), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_56, src_data_66), 0.53333333333)); + float32x4_t t37 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_37, src_data_47), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_17, src_data_27), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_57, src_data_67), 0.53333333333)); + + float32x4_t t40 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_40, src_data_30), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_10, src_data_20), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_50, src_data_60), 0.53333333333)); + float32x4_t t41 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_41, src_data_31), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_11, src_data_21), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_51, src_data_61), 0.53333333333)); + float32x4_t t42 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_42, src_data_32), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_12, src_data_22), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_52, src_data_62), 0.53333333333)); + float32x4_t t43 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_43, src_data_33), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_13, src_data_23), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_53, src_data_63), 0.53333333333)); + float32x4_t t44 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_44, src_data_34), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_14, src_data_24), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_54, src_data_64), 0.53333333333)); + float32x4_t t45 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_45, src_data_35), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_15, src_data_25), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_55, src_data_65), 0.53333333333)); + float32x4_t t46 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_46, src_data_36), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_16, src_data_26), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_56, src_data_66), 0.53333333333)); + float32x4_t t47 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_47, src_data_37), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_17, src_data_27), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_57, src_data_67), 0.53333333333)); + + float32x4_t t50 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_10, 0.03333333), vmulq_n_f32(src_data_20, 0.022222222)), + vmulq_n_f32(src_data_30, 0.1666666666)), + vmulq_n_f32(src_data_40, 0.11111111111)), + vmulq_n_f32(src_data_50, 0.133333333)), + vmulq_n_f32(src_data_60, 0.088888888)); + float32x4_t t51 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_11, 0.03333333), vmulq_n_f32(src_data_21, 0.022222222)), + vmulq_n_f32(src_data_31, 0.1666666666)), + vmulq_n_f32(src_data_41, 0.11111111111)), + vmulq_n_f32(src_data_51, 0.133333333)), + vmulq_n_f32(src_data_61, 0.088888888)); + float32x4_t t52 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_12, 0.03333333), vmulq_n_f32(src_data_22, 0.022222222)), + vmulq_n_f32(src_data_32, 0.1666666666)), + vmulq_n_f32(src_data_42, 0.11111111111)), + vmulq_n_f32(src_data_52, 0.133333333)), + vmulq_n_f32(src_data_62, 0.088888888)); + float32x4_t t53 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_13, 0.03333333), vmulq_n_f32(src_data_23, 0.022222222)), + vmulq_n_f32(src_data_33, 0.1666666666)), + vmulq_n_f32(src_data_43, 0.11111111111)), + vmulq_n_f32(src_data_53, 0.133333333)), + vmulq_n_f32(src_data_63, 0.088888888)); + float32x4_t t54 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_14, 0.03333333), vmulq_n_f32(src_data_24, 0.022222222)), + vmulq_n_f32(src_data_34, 0.1666666666)), + vmulq_n_f32(src_data_44, 0.11111111111)), + vmulq_n_f32(src_data_54, 0.133333333)), + vmulq_n_f32(src_data_64, 0.088888888)); + float32x4_t t55 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_15, 0.03333333), vmulq_n_f32(src_data_25, 0.022222222)), + vmulq_n_f32(src_data_35, 0.1666666666)), + vmulq_n_f32(src_data_45, 0.11111111111)), + vmulq_n_f32(src_data_55, 0.133333333)), + vmulq_n_f32(src_data_65, 0.088888888)); + float32x4_t t56 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_16, 0.03333333), vmulq_n_f32(src_data_26, 0.022222222)), + vmulq_n_f32(src_data_36, 0.1666666666)), + vmulq_n_f32(src_data_46, 0.11111111111)), + vmulq_n_f32(src_data_56, 0.133333333)), + vmulq_n_f32(src_data_66, 0.088888888)); + float32x4_t t57 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_17, 0.03333333), vmulq_n_f32(src_data_27, 0.022222222)), + vmulq_n_f32(src_data_37, 0.1666666666)), + vmulq_n_f32(src_data_47, 0.11111111111)), + vmulq_n_f32(src_data_57, 0.133333333)), + vmulq_n_f32(src_data_67, 0.088888888)); + + float32x4_t t60 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_10, -0.03333333), vmulq_n_f32(src_data_20, 0.022222222)), + vmulq_n_f32(src_data_30, 0.1666666666)), + vmulq_n_f32(src_data_40, 0.11111111111)), + vmulq_n_f32(src_data_50, -0.133333333)), + vmulq_n_f32(src_data_60, 0.088888888)); + float32x4_t t61 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_11, -0.03333333), vmulq_n_f32(src_data_21, 0.022222222)), + vmulq_n_f32(src_data_31, 0.1666666666)), + vmulq_n_f32(src_data_41, 0.11111111111)), + vmulq_n_f32(src_data_51, -0.133333333)), + vmulq_n_f32(src_data_61, 0.088888888)); + float32x4_t t62 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_12, -0.03333333), vmulq_n_f32(src_data_22, 0.022222222)), + vmulq_n_f32(src_data_32, 0.1666666666)), + vmulq_n_f32(src_data_42, 0.11111111111)), + vmulq_n_f32(src_data_52, -0.133333333)), + vmulq_n_f32(src_data_62, 0.088888888)); + float32x4_t t63 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_13, -0.03333333), vmulq_n_f32(src_data_23, 0.022222222)), + vmulq_n_f32(src_data_33, 0.1666666666)), + vmulq_n_f32(src_data_43, 0.11111111111)), + vmulq_n_f32(src_data_53, -0.133333333)), + vmulq_n_f32(src_data_63, 0.088888888)); + float32x4_t t64 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_14, -0.03333333), vmulq_n_f32(src_data_24, 0.022222222)), + vmulq_n_f32(src_data_34, 0.1666666666)), + vmulq_n_f32(src_data_44, 0.11111111111)), + vmulq_n_f32(src_data_54, -0.133333333)), + vmulq_n_f32(src_data_64, 0.088888888)); + float32x4_t t65 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_15, -0.03333333), vmulq_n_f32(src_data_25, 0.022222222)), + vmulq_n_f32(src_data_35, 0.1666666666)), + vmulq_n_f32(src_data_45, 0.11111111111)), + vmulq_n_f32(src_data_55, -0.133333333)), + vmulq_n_f32(src_data_65, 0.088888888)); + float32x4_t t66 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_16, -0.03333333), vmulq_n_f32(src_data_26, 0.022222222)), + vmulq_n_f32(src_data_36, 0.1666666666)), + vmulq_n_f32(src_data_46, 0.11111111111)), + vmulq_n_f32(src_data_56, -0.133333333)), + vmulq_n_f32(src_data_66, 0.088888888)); + float32x4_t t67 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_17, -0.03333333), vmulq_n_f32(src_data_27, 0.022222222)), + vmulq_n_f32(src_data_37, 0.1666666666)), + vmulq_n_f32(src_data_47, 0.11111111111)), + vmulq_n_f32(src_data_57, -0.133333333)), + vmulq_n_f32(src_data_67, 0.088888888)); + + float32x4_t t70 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_30, 3.0625), vmulq_n_f32(src_data_10, -0.5625)), + vmulq_n_f32(src_data_50, 3.5)), + src_data_70); + float32x4_t t71 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_31, 3.0625), vmulq_n_f32(src_data_11, -0.5625)), + vmulq_n_f32(src_data_51, 3.5)), + src_data_71); + float32x4_t t72 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_32, 3.0625), vmulq_n_f32(src_data_12, -0.5625)), + vmulq_n_f32(src_data_52, 3.5)), + src_data_72); + float32x4_t t73 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_33, 3.0625), vmulq_n_f32(src_data_13, -0.5625)), + vmulq_n_f32(src_data_53, 3.5)), + src_data_73); + float32x4_t t74 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_34, 3.0625), vmulq_n_f32(src_data_14, -0.5625)), + vmulq_n_f32(src_data_54, 3.5)), + src_data_74); + float32x4_t t75 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_35, 3.0625), vmulq_n_f32(src_data_15, -0.5625)), + vmulq_n_f32(src_data_55, 3.5)), + src_data_75); + float32x4_t t76 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_36, 3.0625), vmulq_n_f32(src_data_16, -0.5625)), + vmulq_n_f32(src_data_56, 3.5)), + src_data_76); + float32x4_t t77 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_37, 3.0625), vmulq_n_f32(src_data_17, -0.5625)), + vmulq_n_f32(src_data_57, 3.5)), + src_data_77); + + float32x4_t m00 = + vsubq_f32(vaddq_f32(vsubq_f32(t00, vmulq_n_f32(t02, 5.444444444444444)), vmulq_n_f32(t04, 6.22222222222)), + vmulq_n_f32(t06, 1.77777777777777777778)); + float32x4_t m01 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 1.5), vmulq_n_f32(t02, 3)), + vmulq_n_f32(t03, 2.16666666666666667)), + vmulq_n_f32(t04, 4.3333333333)), + vmulq_n_f32(t05, 0.66666666667)), + vmulq_n_f32(t06, 1.333333333333)); + float32x4_t m02 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t01, -1.5), vmulq_n_f32(t02, 3)), + vmulq_n_f32(t03, 2.16666666666666667)), + vmulq_n_f32(t04, 4.3333333333)), + vmulq_n_f32(t05, 0.66666666667)), + vmulq_n_f32(t06, 1.333333333333)); + float32x4_t m03 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t01, t02), -0.3), vmulq_n_f32(vaddq_f32(t03, t04), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t05, t06), -0.533333333333)); + float32x4_t m04 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t04, t03), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t05, t06), 0.533333333333)); + float32x4_t m05 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 0.03333333), vmulq_n_f32(t02, 0.0222222)), + vmulq_n_f32(t03, 0.16666666666666667)), + vmulq_n_f32(t04, 0.11111111111)), + vmulq_n_f32(t05, 0.1333333333)), + vmulq_n_f32(t06, 0.08888888888)); + float32x4_t m06 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t01, -0.03333333), vmulq_n_f32(t02, 0.0222222)), + vmulq_n_f32(t03, 0.16666666666666667)), + vmulq_n_f32(t04, 0.11111111111)), + vmulq_n_f32(t05, 0.1333333333)), + vmulq_n_f32(t06, 0.08888888888)); + float32x4_t m07 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, -0.5625), vmulq_n_f32(t03, 3.0625)), vmulq_n_f32(t05, 3.5)), t07); + + float32x4_t m10 = + vsubq_f32(vaddq_f32(vsubq_f32(t10, vmulq_n_f32(t12, 5.444444444444444)), vmulq_n_f32(t14, 6.22222222222)), + vmulq_n_f32(t16, 1.77777777777777777778)); + float32x4_t m11 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 1.5), vmulq_n_f32(t12, 3)), + vmulq_n_f32(t13, 2.16666666666666667)), + vmulq_n_f32(t14, 4.3333333333)), + vmulq_n_f32(t15, 0.66666666667)), + vmulq_n_f32(t16, 1.333333333333)); + float32x4_t m12 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t11, -1.5), vmulq_n_f32(t12, 3)), + vmulq_n_f32(t13, 2.16666666666666667)), + vmulq_n_f32(t14, 4.3333333333)), + vmulq_n_f32(t15, 0.66666666667)), + vmulq_n_f32(t16, 1.333333333333)); + float32x4_t m13 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t11, t12), -0.3), vmulq_n_f32(vaddq_f32(t13, t14), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t15, t16), -0.533333333333)); + float32x4_t m14 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t14, t13), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t15, t16), 0.533333333333)); + float32x4_t m15 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 0.03333333), vmulq_n_f32(t12, 0.0222222)), + vmulq_n_f32(t13, 0.16666666666666667)), + vmulq_n_f32(t14, 0.11111111111)), + vmulq_n_f32(t15, 0.1333333333)), + vmulq_n_f32(t16, 0.08888888888)); + float32x4_t m16 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t11, -0.03333333), vmulq_n_f32(t12, 0.0222222)), + vmulq_n_f32(t13, 0.16666666666666667)), + vmulq_n_f32(t14, 0.11111111111)), + vmulq_n_f32(t15, 0.1333333333)), + vmulq_n_f32(t16, 0.08888888888)); + float32x4_t m17 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, -0.5625), vmulq_n_f32(t13, 3.0625)), vmulq_n_f32(t15, 3.5)), t17); + + float32x4_t m20 = + vsubq_f32(vaddq_f32(vsubq_f32(t20, vmulq_n_f32(t22, 5.444444444444444)), vmulq_n_f32(t24, 6.22222222222)), + vmulq_n_f32(t26, 1.77777777777777777778)); + float32x4_t m21 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 1.5), vmulq_n_f32(t22, 3)), + vmulq_n_f32(t23, 2.16666666666666667)), + vmulq_n_f32(t24, 4.3333333333)), + vmulq_n_f32(t25, 0.66666666667)), + vmulq_n_f32(t26, 1.333333333333)); + float32x4_t m22 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t21, -1.5), vmulq_n_f32(t22, 3)), + vmulq_n_f32(t23, 2.16666666666666667)), + vmulq_n_f32(t24, 4.3333333333)), + vmulq_n_f32(t25, 0.66666666667)), + vmulq_n_f32(t26, 1.333333333333)); + float32x4_t m23 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t21, t22), -0.3), vmulq_n_f32(vaddq_f32(t23, t24), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t25, t26), -0.533333333333)); + float32x4_t m24 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t24, t23), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t25, t26), 0.533333333333)); + float32x4_t m25 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 0.03333333), vmulq_n_f32(t22, 0.0222222)), + vmulq_n_f32(t23, 0.16666666666666667)), + vmulq_n_f32(t24, 0.11111111111)), + vmulq_n_f32(t25, 0.1333333333)), + vmulq_n_f32(t26, 0.08888888888)); + float32x4_t m26 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t21, -0.03333333), vmulq_n_f32(t22, 0.0222222)), + vmulq_n_f32(t23, 0.16666666666666667)), + vmulq_n_f32(t24, 0.11111111111)), + vmulq_n_f32(t25, 0.1333333333)), + vmulq_n_f32(t26, 0.08888888888)); + float32x4_t m27 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, -0.5625), vmulq_n_f32(t23, 3.0625)), vmulq_n_f32(t25, 3.5)), t27); + + float32x4_t m30 = + vsubq_f32(vaddq_f32(vsubq_f32(t30, vmulq_n_f32(t32, 5.444444444444444)), vmulq_n_f32(t34, 6.22222222222)), + vmulq_n_f32(t36, 1.77777777777777777778)); + float32x4_t m31 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 1.5), vmulq_n_f32(t32, 3)), + vmulq_n_f32(t33, 2.16666666666666667)), + vmulq_n_f32(t34, 4.3333333333)), + vmulq_n_f32(t35, 0.66666666667)), + vmulq_n_f32(t36, 1.333333333333)); + float32x4_t m32 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t31, -1.5), vmulq_n_f32(t32, 3)), + vmulq_n_f32(t33, 2.16666666666666667)), + vmulq_n_f32(t34, 4.3333333333)), + vmulq_n_f32(t35, 0.66666666667)), + vmulq_n_f32(t36, 1.333333333333)); + float32x4_t m33 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t31, t32), -0.3), vmulq_n_f32(vaddq_f32(t33, t34), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t35, t36), -0.533333333333)); + float32x4_t m34 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t34, t33), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t35, t36), 0.533333333333)); + float32x4_t m35 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 0.03333333), vmulq_n_f32(t32, 0.0222222)), + vmulq_n_f32(t33, 0.16666666666666667)), + vmulq_n_f32(t34, 0.11111111111)), + vmulq_n_f32(t35, 0.1333333333)), + vmulq_n_f32(t36, 0.08888888888)); + float32x4_t m36 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t31, -0.03333333), vmulq_n_f32(t32, 0.0222222)), + vmulq_n_f32(t33, 0.16666666666666667)), + vmulq_n_f32(t34, 0.11111111111)), + vmulq_n_f32(t35, 0.1333333333)), + vmulq_n_f32(t36, 0.08888888888)); + float32x4_t m37 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, -0.5625), vmulq_n_f32(t33, 3.0625)), vmulq_n_f32(t35, 3.5)), t37); + + float32x4_t m40 = + vsubq_f32(vaddq_f32(vsubq_f32(t40, vmulq_n_f32(t42, 5.444444444444444)), vmulq_n_f32(t44, 6.22222222222)), + vmulq_n_f32(t46, 1.77777777777777777778)); + float32x4_t m41 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 1.5), vmulq_n_f32(t42, 3)), + vmulq_n_f32(t43, 2.16666666666666667)), + vmulq_n_f32(t44, 4.3333333333)), + vmulq_n_f32(t45, 0.66666666667)), + vmulq_n_f32(t46, 1.333333333333)); + float32x4_t m42 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t41, -1.5), vmulq_n_f32(t42, 3)), + vmulq_n_f32(t43, 2.16666666666666667)), + vmulq_n_f32(t44, 4.3333333333)), + vmulq_n_f32(t45, 0.66666666667)), + vmulq_n_f32(t46, 1.333333333333)); + float32x4_t m43 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t41, t42), -0.3), vmulq_n_f32(vaddq_f32(t43, t44), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t45, t46), -0.533333333333)); + float32x4_t m44 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t44, t43), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t45, t46), 0.533333333333)); + float32x4_t m45 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 0.03333333), vmulq_n_f32(t42, 0.0222222)), + vmulq_n_f32(t43, 0.16666666666666667)), + vmulq_n_f32(t44, 0.11111111111)), + vmulq_n_f32(t45, 0.1333333333)), + vmulq_n_f32(t46, 0.08888888888)); + float32x4_t m46 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t41, -0.03333333), vmulq_n_f32(t42, 0.0222222)), + vmulq_n_f32(t43, 0.16666666666666667)), + vmulq_n_f32(t44, 0.11111111111)), + vmulq_n_f32(t45, 0.1333333333)), + vmulq_n_f32(t46, 0.08888888888)); + float32x4_t m47 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, -0.5625), vmulq_n_f32(t43, 3.0625)), vmulq_n_f32(t45, 3.5)), t47); + + float32x4_t m50 = + vsubq_f32(vaddq_f32(vsubq_f32(t50, vmulq_n_f32(t52, 5.444444444444444)), vmulq_n_f32(t54, 6.22222222222)), + vmulq_n_f32(t56, 1.77777777777777777778)); + float32x4_t m51 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 1.5), vmulq_n_f32(t52, 3)), + vmulq_n_f32(t53, 2.16666666666666667)), + vmulq_n_f32(t54, 4.3333333333)), + vmulq_n_f32(t55, 0.66666666667)), + vmulq_n_f32(t56, 1.333333333333)); + float32x4_t m52 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t51, -1.5), vmulq_n_f32(t52, 3)), + vmulq_n_f32(t53, 2.16666666666666667)), + vmulq_n_f32(t54, 4.3333333333)), + vmulq_n_f32(t55, 0.66666666667)), + vmulq_n_f32(t56, 1.333333333333)); + float32x4_t m53 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t51, t52), -0.3), vmulq_n_f32(vaddq_f32(t53, t54), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t55, t56), -0.533333333333)); + float32x4_t m54 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t54, t53), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t55, t56), 0.533333333333)); + float32x4_t m55 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 0.03333333), vmulq_n_f32(t52, 0.0222222)), + vmulq_n_f32(t53, 0.16666666666666667)), + vmulq_n_f32(t54, 0.11111111111)), + vmulq_n_f32(t55, 0.1333333333)), + vmulq_n_f32(t56, 0.08888888888)); + float32x4_t m56 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t51, -0.03333333), vmulq_n_f32(t52, 0.0222222)), + vmulq_n_f32(t53, 0.16666666666666667)), + vmulq_n_f32(t54, 0.11111111111)), + vmulq_n_f32(t55, 0.1333333333)), + vmulq_n_f32(t56, 0.08888888888)); + float32x4_t m57 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, -0.5625), vmulq_n_f32(t53, 3.0625)), vmulq_n_f32(t55, 3.5)), t57); + + float32x4_t m60 = + vsubq_f32(vaddq_f32(vsubq_f32(t60, vmulq_n_f32(t62, 5.444444444444444)), vmulq_n_f32(t64, 6.22222222222)), + vmulq_n_f32(t66, 1.77777777777777777778)); + float32x4_t m61 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 1.5), vmulq_n_f32(t62, 3)), + vmulq_n_f32(t63, 2.16666666666666667)), + vmulq_n_f32(t64, 4.3333333333)), + vmulq_n_f32(t65, 0.66666666667)), + vmulq_n_f32(t66, 1.333333333333)); + float32x4_t m62 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t61, -1.5), vmulq_n_f32(t62, 3)), + vmulq_n_f32(t63, 2.16666666666666667)), + vmulq_n_f32(t64, 4.3333333333)), + vmulq_n_f32(t65, 0.66666666667)), + vmulq_n_f32(t66, 1.333333333333)); + float32x4_t m63 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t61, t62), -0.3), vmulq_n_f32(vaddq_f32(t63, t64), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t65, t66), -0.533333333333)); + float32x4_t m64 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t64, t63), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t65, t66), 0.533333333333)); + float32x4_t m65 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 0.03333333), vmulq_n_f32(t62, 0.0222222)), + vmulq_n_f32(t63, 0.16666666666666667)), + vmulq_n_f32(t64, 0.11111111111)), + vmulq_n_f32(t65, 0.1333333333)), + vmulq_n_f32(t66, 0.08888888888)); + float32x4_t m66 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t61, -0.03333333), vmulq_n_f32(t62, 0.0222222)), + vmulq_n_f32(t63, 0.16666666666666667)), + vmulq_n_f32(t64, 0.11111111111)), + vmulq_n_f32(t65, 0.1333333333)), + vmulq_n_f32(t66, 0.08888888888)); + float32x4_t m67 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, -0.5625), vmulq_n_f32(t63, 3.0625)), vmulq_n_f32(t65, 3.5)), t67); + + float32x4_t m70 = + vsubq_f32(vaddq_f32(vsubq_f32(t70, vmulq_n_f32(t72, 5.444444444444444)), vmulq_n_f32(t74, 6.22222222222)), + vmulq_n_f32(t76, 1.77777777777777777778)); + float32x4_t m71 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 1.5), vmulq_n_f32(t72, 3)), + vmulq_n_f32(t73, 2.16666666666666667)), + vmulq_n_f32(t74, 4.3333333333)), + vmulq_n_f32(t75, 0.66666666667)), + vmulq_n_f32(t76, 1.333333333333)); + float32x4_t m72 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t71, -1.5), vmulq_n_f32(t72, 3)), + vmulq_n_f32(t73, 2.16666666666666667)), + vmulq_n_f32(t74, 4.3333333333)), + vmulq_n_f32(t75, 0.66666666667)), + vmulq_n_f32(t76, 1.333333333333)); + float32x4_t m73 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t71, t72), -0.3), vmulq_n_f32(vaddq_f32(t73, t74), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t75, t76), -0.533333333333)); + float32x4_t m74 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t74, t73), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t75, t76), 0.533333333333)); + float32x4_t m75 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 0.03333333), vmulq_n_f32(t72, 0.0222222)), + vmulq_n_f32(t73, 0.16666666666666667)), + vmulq_n_f32(t74, 0.11111111111)), + vmulq_n_f32(t75, 0.1333333333)), + vmulq_n_f32(t76, 0.08888888888)); + float32x4_t m76 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t71, -0.03333333), vmulq_n_f32(t72, 0.0222222)), + vmulq_n_f32(t73, 0.16666666666666667)), + vmulq_n_f32(t74, 0.11111111111)), + vmulq_n_f32(t75, 0.1333333333)), + vmulq_n_f32(t76, 0.08888888888)); + float32x4_t m77 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, -0.5625), vmulq_n_f32(t73, 3.0625)), vmulq_n_f32(t75, 3.5)), t77); + + vst1q_f32(dst_data + 0 * dst_step, m00); + vst1q_f32(dst_data + 1 * dst_step, m01); + vst1q_f32(dst_data + 2 * dst_step, m02); + vst1q_f32(dst_data + 3 * dst_step, m03); + vst1q_f32(dst_data + 4 * dst_step, m04); + vst1q_f32(dst_data + 5 * dst_step, m05); + vst1q_f32(dst_data + 6 * dst_step, m06); + vst1q_f32(dst_data + 7 * dst_step, m07); + vst1q_f32(dst_data + 8 * dst_step, m10); + vst1q_f32(dst_data + 9 * dst_step, m11); + vst1q_f32(dst_data + 10 * dst_step, m12); + vst1q_f32(dst_data + 11 * dst_step, m13); + vst1q_f32(dst_data + 12 * dst_step, m14); + vst1q_f32(dst_data + 13 * dst_step, m15); + vst1q_f32(dst_data + 14 * dst_step, m16); + vst1q_f32(dst_data + 15 * dst_step, m17); + vst1q_f32(dst_data + 16 * dst_step, m20); + vst1q_f32(dst_data + 17 * dst_step, m21); + vst1q_f32(dst_data + 18 * dst_step, m22); + vst1q_f32(dst_data + 19 * dst_step, m23); + vst1q_f32(dst_data + 20 * dst_step, m24); + vst1q_f32(dst_data + 21 * dst_step, m25); + vst1q_f32(dst_data + 22 * dst_step, m26); + vst1q_f32(dst_data + 23 * dst_step, m27); + vst1q_f32(dst_data + 24 * dst_step, m30); + vst1q_f32(dst_data + 25 * dst_step, m31); + vst1q_f32(dst_data + 26 * dst_step, m32); + vst1q_f32(dst_data + 27 * dst_step, m33); + vst1q_f32(dst_data + 28 * dst_step, m34); + vst1q_f32(dst_data + 29 * dst_step, m35); + vst1q_f32(dst_data + 30 * dst_step, m36); + vst1q_f32(dst_data + 31 * dst_step, m37); + vst1q_f32(dst_data + 32 * dst_step, m40); + vst1q_f32(dst_data + 33 * dst_step, m41); + vst1q_f32(dst_data + 34 * dst_step, m42); + vst1q_f32(dst_data + 35 * dst_step, m43); + vst1q_f32(dst_data + 36 * dst_step, m44); + vst1q_f32(dst_data + 37 * dst_step, m45); + vst1q_f32(dst_data + 38 * dst_step, m46); + vst1q_f32(dst_data + 39 * dst_step, m47); + vst1q_f32(dst_data + 40 * dst_step, m50); + vst1q_f32(dst_data + 41 * dst_step, m51); + vst1q_f32(dst_data + 42 * dst_step, m52); + vst1q_f32(dst_data + 43 * dst_step, m53); + vst1q_f32(dst_data + 44 * dst_step, m54); + vst1q_f32(dst_data + 45 * dst_step, m55); + vst1q_f32(dst_data + 46 * dst_step, m56); + vst1q_f32(dst_data + 47 * dst_step, m57); + vst1q_f32(dst_data + 48 * dst_step, m60); + vst1q_f32(dst_data + 49 * dst_step, m61); + vst1q_f32(dst_data + 50 * dst_step, m62); + vst1q_f32(dst_data + 51 * dst_step, m63); + vst1q_f32(dst_data + 52 * dst_step, m64); + vst1q_f32(dst_data + 53 * dst_step, m65); + vst1q_f32(dst_data + 54 * dst_step, m66); + vst1q_f32(dst_data + 55 * dst_step, m67); + vst1q_f32(dst_data + 56 * dst_step, m70); + vst1q_f32(dst_data + 57 * dst_step, m71); + vst1q_f32(dst_data + 58 * dst_step, m72); + vst1q_f32(dst_data + 59 * dst_step, m73); + vst1q_f32(dst_data + 60 * dst_step, m74); + vst1q_f32(dst_data + 61 * dst_step, m75); + vst1q_f32(dst_data + 62 * dst_step, m76); + vst1q_f32(dst_data + 63 * dst_step, m77); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float t00 = src_data_00 - 5.444444444444444445125f * src_data_20 + 6.222222222222222222223f * src_data_40 - + 1.77777777777777778f * src_data_60; + float t01 = src_data_01 - 5.444444444444444445125f * src_data_21 + 6.222222222222222222223f * src_data_41 - + 1.77777777777777778f * src_data_61; + float t02 = src_data_02 - 5.444444444444444445125f * src_data_22 + 6.222222222222222222223f * src_data_42 - + 1.77777777777777778f * src_data_62; + float t03 = src_data_03 - 5.444444444444444445125f * src_data_23 + 6.222222222222222222223f * src_data_43 - + 1.77777777777777778f * src_data_63; + float t04 = src_data_04 - 5.444444444444444445125f * src_data_24 + 6.222222222222222222223f * src_data_44 - + 1.77777777777777778f * src_data_64; + float t05 = src_data_05 - 5.444444444444444445125f * src_data_25 + 6.222222222222222222223f * src_data_45 - + 1.77777777777777778f * src_data_65; + float t06 = src_data_06 - 5.444444444444444445125f * src_data_26 + 6.222222222222222222223f * src_data_46 - + 1.77777777777777778f * src_data_66; + float t07 = src_data_07 - 5.444444444444444445125f * src_data_27 + 6.222222222222222222223f * src_data_47 - + 1.77777777777777778f * src_data_67; + + float t10 = 1.5f * src_data_10 + 3.0f * src_data_20 - 2.1666666666666667f * src_data_30 - + 4.333333333333333333f * src_data_40 + 0.66666666666666667f * src_data_50 + + 1.333333333333333f * src_data_60; + float t11 = 1.5f * src_data_11 + 3.0f * src_data_21 - 2.1666666666666667f * src_data_31 - + 4.333333333333333333f * src_data_41 + 0.66666666666666667f * src_data_51 + + 1.333333333333333f * src_data_61; + float t12 = 1.5f * src_data_12 + 3.0f * src_data_22 - 2.1666666666666667f * src_data_32 - + 4.333333333333333333f * src_data_42 + 0.66666666666666667f * src_data_52 + + 1.333333333333333f * src_data_62; + float t13 = 1.5f * src_data_13 + 3.0f * src_data_23 - 2.1666666666666667f * src_data_33 - + 4.333333333333333333f * src_data_43 + 0.66666666666666667f * src_data_53 + + 1.333333333333333f * src_data_63; + float t14 = 1.5f * src_data_14 + 3.0f * src_data_24 - 2.1666666666666667f * src_data_34 - + 4.333333333333333333f * src_data_44 + 0.66666666666666667f * src_data_54 + + 1.333333333333333f * src_data_64; + float t15 = 1.5f * src_data_15 + 3.0f * src_data_25 - 2.1666666666666667f * src_data_35 - + 4.333333333333333333f * src_data_45 + 0.66666666666666667f * src_data_55 + + 1.333333333333333f * src_data_65; + float t16 = 1.5f * src_data_16 + 3.0f * src_data_26 - 2.1666666666666667f * src_data_36 - + 4.333333333333333333f * src_data_46 + 0.66666666666666667f * src_data_56 + + 1.333333333333333f * src_data_66; + float t17 = 1.5f * src_data_17 + 3.0f * src_data_27 - 2.1666666666666667f * src_data_37 - + 4.333333333333333333f * src_data_47 + 0.66666666666666667f * src_data_57 + + 1.333333333333333f * src_data_67; + + float t20 = -1.5f * src_data_10 + 3.0f * src_data_20 + 2.1666666666666667f * src_data_30 - + 4.333333333333333333f * src_data_40 - 0.66666666666666667f * src_data_50 + + 1.333333333333333f * src_data_60; + float t21 = -1.5f * src_data_11 + 3.0f * src_data_21 + 2.1666666666666667f * src_data_31 - + 4.333333333333333333f * src_data_41 - 0.66666666666666667f * src_data_51 + + 1.333333333333333f * src_data_61; + float t22 = -1.5f * src_data_12 + 3.0f * src_data_22 + 2.1666666666666667f * src_data_32 - + 4.333333333333333333f * src_data_42 - 0.66666666666666667f * src_data_52 + + 1.333333333333333f * src_data_62; + float t23 = -1.5f * src_data_13 + 3.0f * src_data_23 + 2.1666666666666667f * src_data_33 - + 4.333333333333333333f * src_data_43 - 0.66666666666666667f * src_data_53 + + 1.333333333333333f * src_data_63; + float t24 = -1.5f * src_data_14 + 3.0f * src_data_24 + 2.1666666666666667f * src_data_34 - + 4.333333333333333333f * src_data_44 - 0.66666666666666667f * src_data_54 + + 1.333333333333333f * src_data_64; + float t25 = -1.5f * src_data_15 + 3.0f * src_data_25 + 2.1666666666666667f * src_data_35 - + 4.333333333333333333f * src_data_45 - 0.66666666666666667f * src_data_55 + + 1.333333333333333f * src_data_65; + float t26 = -1.5f * src_data_16 + 3.0f * src_data_26 + 2.1666666666666667f * src_data_36 - + 4.333333333333333333f * src_data_46 - 0.66666666666666667f * src_data_56 + + 1.333333333333333f * src_data_66; + float t27 = -1.5f * src_data_17 + 3.0f * src_data_27 + 2.1666666666666667f * src_data_37 - + 4.333333333333333333f * src_data_47 - 0.66666666666666667f * src_data_57 + + 1.333333333333333f * src_data_67; + + float t30 = -0.3f * (src_data_10 + src_data_20) + 1.33333333333333f * (src_data_30 + src_data_40) - + 0.53333333333f * (src_data_50 + src_data_60); + float t31 = -0.3f * (src_data_11 + src_data_21) + 1.33333333333333f * (src_data_31 + src_data_41) - + 0.53333333333f * (src_data_51 + src_data_61); + float t32 = -0.3f * (src_data_12 + src_data_22) + 1.33333333333333f * (src_data_32 + src_data_42) - + 0.53333333333f * (src_data_52 + src_data_62); + float t33 = -0.3f * (src_data_13 + src_data_23) + 1.33333333333333f * (src_data_33 + src_data_43) - + 0.53333333333f * (src_data_53 + src_data_63); + float t34 = -0.3f * (src_data_14 + src_data_24) + 1.33333333333333f * (src_data_34 + src_data_44) - + 0.53333333333f * (src_data_54 + src_data_64); + float t35 = -0.3f * (src_data_15 + src_data_25) + 1.33333333333333f * (src_data_35 + src_data_45) - + 0.53333333333f * (src_data_55 + src_data_65); + float t36 = -0.3f * (src_data_16 + src_data_26) + 1.33333333333333f * (src_data_36 + src_data_46) - + 0.53333333333f * (src_data_56 + src_data_66); + float t37 = -0.3f * (src_data_17 + src_data_27) + 1.33333333333333f * (src_data_37 + src_data_47) - + 0.53333333333f * (src_data_57 + src_data_67); + + float t40 = 0.3f * (src_data_10 - src_data_20) + 1.33333333333333f * (src_data_40 - src_data_30) + + 0.53333333333f * (src_data_50 - src_data_60); + float t41 = 0.3f * (src_data_11 - src_data_21) + 1.33333333333333f * (src_data_41 - src_data_31) + + 0.53333333333f * (src_data_51 - src_data_61); + float t42 = 0.3f * (src_data_12 - src_data_22) + 1.33333333333333f * (src_data_42 - src_data_32) + + 0.53333333333f * (src_data_52 - src_data_62); + float t43 = 0.3f * (src_data_13 - src_data_23) + 1.33333333333333f * (src_data_43 - src_data_33) + + 0.53333333333f * (src_data_53 - src_data_63); + float t44 = 0.3f * (src_data_14 - src_data_24) + 1.33333333333333f * (src_data_44 - src_data_34) + + 0.53333333333f * (src_data_54 - src_data_64); + float t45 = 0.3f * (src_data_15 - src_data_25) + 1.33333333333333f * (src_data_45 - src_data_35) + + 0.53333333333f * (src_data_55 - src_data_65); + float t46 = 0.3f * (src_data_16 - src_data_26) + 1.33333333333333f * (src_data_46 - src_data_36) + + 0.53333333333f * (src_data_56 - src_data_66); + float t47 = 0.3f * (src_data_17 - src_data_27) + 1.33333333333333f * (src_data_47 - src_data_37) + + 0.53333333333f * (src_data_57 - src_data_67); + + float t50 = 0.0333333333f * src_data_10 + 0.02222222f * src_data_20 - 0.1666666666f * src_data_30 - + 0.1111111111f * src_data_40 + 0.1333333f * src_data_50 + 0.0888888f * src_data_60; + float t51 = 0.0333333333f * src_data_11 + 0.02222222f * src_data_21 - 0.1666666666f * src_data_31 - + 0.1111111111f * src_data_41 + 0.1333333f * src_data_51 + 0.0888888f * src_data_61; + float t52 = 0.0333333333f * src_data_12 + 0.02222222f * src_data_22 - 0.1666666666f * src_data_32 - + 0.1111111111f * src_data_42 + 0.1333333f * src_data_52 + 0.0888888f * src_data_62; + float t53 = 0.0333333333f * src_data_13 + 0.02222222f * src_data_23 - 0.1666666666f * src_data_33 - + 0.1111111111f * src_data_43 + 0.1333333f * src_data_53 + 0.0888888f * src_data_63; + float t54 = 0.0333333333f * src_data_14 + 0.02222222f * src_data_24 - 0.1666666666f * src_data_34 - + 0.1111111111f * src_data_44 + 0.1333333f * src_data_54 + 0.0888888f * src_data_64; + float t55 = 0.0333333333f * src_data_15 + 0.02222222f * src_data_25 - 0.1666666666f * src_data_35 - + 0.1111111111f * src_data_45 + 0.1333333f * src_data_55 + 0.0888888f * src_data_65; + float t56 = 0.0333333333f * src_data_16 + 0.02222222f * src_data_26 - 0.1666666666f * src_data_36 - + 0.1111111111f * src_data_46 + 0.1333333f * src_data_56 + 0.0888888f * src_data_66; + float t57 = 0.0333333333f * src_data_17 + 0.02222222f * src_data_27 - 0.1666666666f * src_data_37 - + 0.1111111111f * src_data_47 + 0.1333333f * src_data_57 + 0.0888888f * src_data_67; + + float t60 = -0.0333333333f * src_data_10 + 0.02222222f * src_data_20 + 0.1666666666f * src_data_30 - + 0.1111111111f * src_data_40 - 0.1333333f * src_data_50 + 0.0888888f * src_data_60; + float t61 = -0.0333333333f * src_data_11 + 0.02222222f * src_data_21 + 0.1666666666f * src_data_31 - + 0.1111111111f * src_data_41 - 0.1333333f * src_data_51 + 0.0888888f * src_data_61; + float t62 = -0.0333333333f * src_data_12 + 0.02222222f * src_data_22 + 0.1666666666f * src_data_32 - + 0.1111111111f * src_data_42 - 0.1333333f * src_data_52 + 0.0888888f * src_data_62; + float t63 = -0.0333333333f * src_data_13 + 0.02222222f * src_data_23 + 0.1666666666f * src_data_33 - + 0.1111111111f * src_data_43 - 0.1333333f * src_data_53 + 0.0888888f * src_data_63; + float t64 = -0.0333333333f * src_data_14 + 0.02222222f * src_data_24 + 0.1666666666f * src_data_34 - + 0.1111111111f * src_data_44 - 0.1333333f * src_data_54 + 0.0888888f * src_data_64; + float t65 = -0.0333333333f * src_data_15 + 0.02222222f * src_data_25 + 0.1666666666f * src_data_35 - + 0.1111111111f * src_data_45 - 0.1333333f * src_data_55 + 0.0888888f * src_data_65; + float t66 = -0.0333333333f * src_data_16 + 0.02222222f * src_data_26 + 0.1666666666f * src_data_36 - + 0.1111111111f * src_data_46 - 0.1333333f * src_data_56 + 0.0888888f * src_data_66; + float t67 = -0.0333333333f * src_data_17 + 0.02222222f * src_data_27 + 0.1666666666f * src_data_37 - + 0.1111111111f * src_data_47 - 0.1333333f * src_data_57 + 0.0888888f * src_data_67; + + float t70 = -0.5625f * src_data_10 + 3.0625f * src_data_30 - 3.5f * src_data_50 + src_data_70; + float t71 = -0.5625f * src_data_11 + 3.0625f * src_data_31 - 3.5f * src_data_51 + src_data_71; + float t72 = -0.5625f * src_data_12 + 3.0625f * src_data_32 - 3.5f * src_data_52 + src_data_72; + float t73 = -0.5625f * src_data_13 + 3.0625f * src_data_33 - 3.5f * src_data_53 + src_data_73; + float t74 = -0.5625f * src_data_14 + 3.0625f * src_data_34 - 3.5f * src_data_54 + src_data_74; + float t75 = -0.5625f * src_data_15 + 3.0625f * src_data_35 - 3.5f * src_data_55 + src_data_75; + float t76 = -0.5625f * src_data_16 + 3.0625f * src_data_36 - 3.5f * src_data_56 + src_data_76; + float t77 = -0.5625f * src_data_17 + 3.0625f * src_data_37 - 3.5f * src_data_57 + src_data_77; + + float m00 = t00 - 5.444444444444444445125f * t02 + 6.222222222222222222223f * t04 - 1.77777777777777778f * t06; + float m01 = 1.5f * t01 + 3.0f * t02 - 2.1666666666666667f * t03 - 4.333333333333333333f * t04 + + 0.66666666666666667f * t05 + 1.333333333333333f * t06; + float m02 = -1.5f * t01 + 3.0f * t02 + 2.1666666666666667f * t03 - 4.333333333333333333f * t04 - + 0.66666666666666667f * t05 + 1.333333333333333f * t06; + float m03 = -0.3f * (t01 + t02) + 1.33333333333333f * (t03 + t04) - 0.53333333333f * (t05 + t06); + float m04 = 0.3f * (t01 - t02) + 1.33333333333333f * (t04 - t03) + 0.53333333333f * (t05 - t06); + float m05 = 0.0333333333f * t01 + 0.02222222f * t02 - 0.1666666666f * t03 - 0.1111111111f * t04 + 0.1333333f * t05 + + 0.0888888f * t06; + float m06 = -0.0333333333f * t01 + 0.02222222f * t02 + 0.1666666666f * t03 - 0.1111111111f * t04 - + 0.1333333f * t05 + 0.0888888f * t06; + float m07 = -0.5625f * t01 + 3.0625f * t03 - 3.5f * t05 + t07; + + float m10 = t10 - 5.444444444444444445125f * t12 + 6.222222222222222222223f * t14 - 1.77777777777777778f * t16; + float m11 = 1.5f * t11 + 3.0f * t12 - 2.1666666666666667f * t13 - 4.333333333333333333f * t14 + + 0.66666666666666667f * t15 + 1.333333333333333f * t16; + float m12 = -1.5f * t11 + 3.0f * t12 + 2.1666666666666667f * t13 - 4.333333333333333333f * t14 - + 0.66666666666666667f * t15 + 1.333333333333333f * t16; + float m13 = -0.3f * (t11 + t12) + 1.33333333333333f * (t13 + t14) - 0.53333333333f * (t15 + t16); + float m14 = 0.3f * (t11 - t12) + 1.33333333333333f * (t14 - t13) + 0.53333333333f * (t15 - t16); + float m15 = 0.0333333333f * t11 + 0.02222222f * t12 - 0.1666666666f * t13 - 0.1111111111f * t14 + 0.1333333f * t15 + + 0.0888888f * t16; + float m16 = -0.0333333333f * t11 + 0.02222222f * t12 + 0.1666666666f * t13 - 0.1111111111f * t14 - + 0.1333333f * t15 + 0.0888888f * t16; + float m17 = -0.5625f * t11 + 3.0625f * t13 - 3.5f * t15 + t17; + + float m20 = t20 - 5.444444444444444445125f * t22 + 6.222222222222222222223f * t24 - 1.77777777777777778f * t26; + float m21 = 1.5f * t21 + 3.0f * t22 - 2.1666666666666667f * t23 - 4.333333333333333333f * t24 + + 0.66666666666666667f * t25 + 1.333333333333333f * t26; + float m22 = -1.5f * t21 + 3.0f * t22 + 2.1666666666666667f * t23 - 4.333333333333333333f * t24 - + 0.66666666666666667f * t25 + 1.333333333333333f * t26; + float m23 = -0.3f * (t21 + t22) + 1.33333333333333f * (t23 + t24) - 0.53333333333f * (t25 + t26); + float m24 = 0.3f * (t21 - t22) + 1.33333333333333f * (t24 - t23) + 0.53333333333f * (t25 - t26); + float m25 = 0.0333333333f * t21 + 0.02222222f * t22 - 0.1666666666f * t23 - 0.1111111111f * t24 + 0.1333333f * t25 + + 0.0888888f * t26; + float m26 = -0.0333333333f * t21 + 0.02222222f * t22 + 0.1666666666f * t23 - 0.1111111111f * t24 - + 0.1333333f * t25 + 0.0888888f * t26; + float m27 = -0.5625f * t21 + 3.0625f * t23 - 3.5f * t25 + t27; + + float m30 = t30 - 5.444444444444444445125f * t32 + 6.222222222222222222223f * t34 - 1.77777777777777778f * t36; + float m31 = 1.5f * t31 + 3.0f * t32 - 2.1666666666666667f * t33 - 4.333333333333333333f * t34 + + 0.66666666666666667f * t35 + 1.333333333333333f * t36; + float m32 = -1.5f * t31 + 3.0f * t32 + 2.1666666666666667f * t33 - 4.333333333333333333f * t34 - + 0.66666666666666667f * t35 + 1.333333333333333f * t36; + float m33 = -0.3f * (t31 + t32) + 1.33333333333333f * (t33 + t34) - 0.53333333333f * (t35 + t36); + float m34 = 0.3f * (t31 - t32) + 1.33333333333333f * (t34 - t33) + 0.53333333333f * (t35 - t36); + float m35 = 0.0333333333f * t31 + 0.02222222f * t32 - 0.1666666666f * t33 - 0.1111111111f * t34 + 0.1333333f * t35 + + 0.0888888f * t36; + float m36 = -0.0333333333f * t31 + 0.02222222f * t32 + 0.1666666666f * t33 - 0.1111111111f * t34 - + 0.1333333f * t35 + 0.0888888f * t36; + float m37 = -0.5625f * t31 + 3.0625f * t33 - 3.5f * t35 + t37; + + float m40 = t40 - 5.444444444444444445125f * t42 + 6.222222222222222222223f * t44 - 1.77777777777777778f * t46; + float m41 = 1.5f * t41 + 3.0f * t42 - 2.1666666666666667f * t43 - 4.333333333333333333f * t44 + + 0.66666666666666667f * t45 + 1.333333333333333f * t46; + float m42 = -1.5f * t41 + 3.0f * t42 + 2.1666666666666667f * t43 - 4.333333333333333333f * t44 - + 0.66666666666666667f * t45 + 1.333333333333333f * t46; + float m43 = -0.3f * (t41 + t42) + 1.33333333333333f * (t43 + t44) - 0.53333333333f * (t45 + t46); + float m44 = 0.3f * (t41 - t42) + 1.33333333333333f * (t44 - t43) + 0.53333333333f * (t45 - t46); + float m45 = 0.0333333333f * t41 + 0.02222222f * t42 - 0.1666666666f * t43 - 0.1111111111f * t44 + 0.1333333f * t45 + + 0.0888888f * t46; + float m46 = -0.0333333333f * t41 + 0.02222222f * t42 + 0.1666666666f * t43 - 0.1111111111f * t44 - + 0.1333333f * t45 + 0.0888888f * t46; + float m47 = -0.5625f * t41 + 3.0625f * t43 - 3.5f * t45 + t47; + + float m50 = t50 - 5.444444444444444445125f * t52 + 6.222222222222222222223f * t54 - 1.77777777777777778f * t56; + float m51 = 1.5f * t51 + 3.0f * t52 - 2.1666666666666667f * t53 - 4.333333333333333333f * t54 + + 0.66666666666666667f * t55 + 1.333333333333333f * t56; + float m52 = -1.5f * t51 + 3.0f * t52 + 2.1666666666666667f * t53 - 4.333333333333333333f * t54 - + 0.66666666666666667f * t55 + 1.333333333333333f * t56; + float m53 = -0.3f * (t51 + t52) + 1.33333333333333f * (t53 + t54) - 0.53333333333f * (t55 + t56); + float m54 = 0.3f * (t51 - t52) + 1.33333333333333f * (t54 - t53) + 0.53333333333f * (t55 - t56); + float m55 = 0.0333333333f * t51 + 0.02222222f * t52 - 0.1666666666f * t53 - 0.1111111111f * t54 + 0.1333333f * t55 + + 0.0888888f * t56; + float m56 = -0.0333333333f * t51 + 0.02222222f * t52 + 0.1666666666f * t53 - 0.1111111111f * t54 - + 0.1333333f * t55 + 0.0888888f * t56; + float m57 = -0.5625f * t51 + 3.0625f * t53 - 3.5f * t55 + t57; + + float m60 = t60 - 5.444444444444444445125f * t62 + 6.222222222222222222223f * t64 - 1.77777777777777778f * t66; + float m61 = 1.5f * t61 + 3.0f * t62 - 2.1666666666666667f * t63 - 4.333333333333333333f * t64 + + 0.66666666666666667f * t65 + 1.333333333333333f * t66; + float m62 = -1.5f * t61 + 3.0f * t62 + 2.1666666666666667f * t63 - 4.333333333333333333f * t64 - + 0.66666666666666667f * t65 + 1.333333333333333f * t66; + float m63 = -0.3f * (t61 + t62) + 1.33333333333333f * (t63 + t64) - 0.53333333333f * (t65 + t66); + float m64 = 0.3f * (t61 - t62) + 1.33333333333333f * (t64 - t63) + 0.53333333333f * (t65 - t66); + float m65 = 0.0333333333f * t61 + 0.02222222f * t62 - 0.1666666666f * t63 - 0.1111111111f * t64 + 0.1333333f * t65 + + 0.0888888f * t66; + float m66 = -0.0333333333f * t61 + 0.02222222f * t62 + 0.1666666666f * t63 - 0.1111111111f * t64 - + 0.1333333f * t65 + 0.0888888f * t66; + float m67 = -0.5625f * t61 + 3.0625f * t63 - 3.5f * t65 + t67; + + float m70 = t70 - 5.444444444444444445125f * t72 + 6.222222222222222222223f * t74 - 1.77777777777777778f * t76; + float m71 = 1.5f * t71 + 3.0f * t72 - 2.1666666666666667f * t73 - 4.333333333333333333f * t74 + + 0.66666666666666667f * t75 + 1.333333333333333f * t76; + float m72 = -1.5f * t71 + 3.0f * t72 + 2.1666666666666667f * t73 - 4.333333333333333333f * t74 - + 0.66666666666666667f * t75 + 1.333333333333333f * t76; + float m73 = -0.3f * (t71 + t72) + 1.33333333333333f * (t73 + t74) - 0.53333333333f * (t75 + t76); + float m74 = 0.3f * (t71 - t72) + 1.33333333333333f * (t74 - t73) + 0.53333333333f * (t75 - t76); + float m75 = 0.0333333333f * t71 + 0.02222222f * t72 - 0.1666666666f * t73 - 0.1111111111f * t74 + 0.1333333f * t75 + + 0.0888888f * t76; + float m76 = -0.0333333333f * t71 + 0.02222222f * t72 + 0.1666666666f * t73 - 0.1111111111f * t74 - + 0.1333333f * t75 + 0.0888888f * t76; + float m77 = -0.5625f * t71 + 3.0625f * t73 - 3.5f * t75 + t77; + + (dst_data + i)[0] = m00; + (dst_data + i + dst_step)[0] = m01; + (dst_data + i + 2 * dst_step)[0] = m02; + (dst_data + i + 3 * dst_step)[0] = m03; + (dst_data + i + 4 * dst_step)[0] = m04; + (dst_data + i + 5 * dst_step)[0] = m05; + (dst_data + i + 6 * dst_step)[0] = m06; + (dst_data + i + 7 * dst_step)[0] = m07; + + (dst_data + i + 8 * dst_step)[0] = m10; + (dst_data + i + 9 * dst_step)[0] = m11; + (dst_data + i + 10 * dst_step)[0] = m12; + (dst_data + i + 11 * dst_step)[0] = m13; + (dst_data + i + 12 * dst_step)[0] = m14; + (dst_data + i + 13 * dst_step)[0] = m15; + (dst_data + i + 14 * dst_step)[0] = m16; + (dst_data + i + 15 * dst_step)[0] = m17; + + (dst_data + i + 16 * dst_step)[0] = m20; + (dst_data + i + 17 * dst_step)[0] = m21; + (dst_data + i + 18 * dst_step)[0] = m22; + (dst_data + i + 19 * dst_step)[0] = m23; + (dst_data + i + 20 * dst_step)[0] = m24; + (dst_data + i + 21 * dst_step)[0] = m25; + (dst_data + i + 22 * dst_step)[0] = m26; + (dst_data + i + 23 * dst_step)[0] = m27; + + (dst_data + i + 24 * dst_step)[0] = m30; + (dst_data + i + 25 * dst_step)[0] = m31; + (dst_data + i + 26 * dst_step)[0] = m32; + (dst_data + i + 27 * dst_step)[0] = m33; + (dst_data + i + 28 * dst_step)[0] = m34; + (dst_data + i + 29 * dst_step)[0] = m35; + (dst_data + i + 30 * dst_step)[0] = m36; + (dst_data + i + 31 * dst_step)[0] = m37; + + (dst_data + i + 32 * dst_step)[0] = m40; + (dst_data + i + 33 * dst_step)[0] = m41; + (dst_data + i + 34 * dst_step)[0] = m42; + (dst_data + i + 35 * dst_step)[0] = m43; + (dst_data + i + 36 * dst_step)[0] = m44; + (dst_data + i + 37 * dst_step)[0] = m45; + (dst_data + i + 38 * dst_step)[0] = m46; + (dst_data + i + 39 * dst_step)[0] = m47; + + (dst_data + i + 40 * dst_step)[0] = m50; + (dst_data + i + 41 * dst_step)[0] = m51; + (dst_data + i + 42 * dst_step)[0] = m52; + (dst_data + i + 43 * dst_step)[0] = m53; + (dst_data + i + 44 * dst_step)[0] = m54; + (dst_data + i + 45 * dst_step)[0] = m55; + (dst_data + i + 46 * dst_step)[0] = m56; + (dst_data + i + 47 * dst_step)[0] = m57; + + (dst_data + i + 48 * dst_step)[0] = m60; + (dst_data + i + 49 * dst_step)[0] = m61; + (dst_data + i + 50 * dst_step)[0] = m62; + (dst_data + i + 51 * dst_step)[0] = m63; + (dst_data + i + 52 * dst_step)[0] = m64; + (dst_data + i + 53 * dst_step)[0] = m65; + (dst_data + i + 54 * dst_step)[0] = m66; + (dst_data + i + 55 * dst_step)[0] = m67; + + (dst_data + i + 56 * dst_step)[0] = m70; + (dst_data + i + 57 * dst_step)[0] = m71; + (dst_data + i + 58 * dst_step)[0] = m72; + (dst_data + i + 59 * dst_step)[0] = m73; + (dst_data + i + 60 * dst_step)[0] = m74; + (dst_data + i + 61 * dst_step)[0] = m75; + (dst_data + i + 62 * dst_step)[0] = m76; + (dst_data + i + 63 * dst_step)[0] = m77; + } +#endif +} + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias_data); + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 15 * src_step); + + float32x4_t t00 = vaddq_f32(src_data_00, vaddq_f32(src_data_10, src_data_20)); + float32x4_t t01 = vaddq_f32(src_data_01, vaddq_f32(src_data_11, src_data_21)); + float32x4_t t02 = vaddq_f32(src_data_02, vaddq_f32(src_data_12, src_data_22)); + float32x4_t t03 = vaddq_f32(src_data_03, vaddq_f32(src_data_13, src_data_23)); + + float32x4_t t10 = vsubq_f32(src_data_30, vmulq_n_f32(vsubq_f32(src_data_10, src_data_20), 0.5)); + float32x4_t t11 = vsubq_f32(src_data_31, vmulq_n_f32(vsubq_f32(src_data_11, src_data_21), 0.5)); + float32x4_t t12 = vsubq_f32(src_data_32, vmulq_n_f32(vsubq_f32(src_data_12, src_data_22), 0.5)); + float32x4_t t13 = vsubq_f32(src_data_33, vmulq_n_f32(vsubq_f32(src_data_13, src_data_23), 0.5)); + + float32x4_t m00 = vaddq_f32(vaddq_f32(t00, vaddq_f32(t01, t02)), bias_ptr); + float32x4_t m01 = vaddq_f32(vaddq_f32(t03, vmulq_n_f32(vsubq_f32(t01, t02), 0.5)), bias_ptr); + float32x4_t m10 = vaddq_f32(vaddq_f32(t10, vaddq_f32(t11, t12)), bias_ptr); + float32x4_t m11 = vaddq_f32(vaddq_f32(t13, vmulq_n_f32(vsubq_f32(t11, t12), 0.5)), bias_ptr); + + vst1q_f32(dst_data, m00); + vst1q_f32(dst_data + C4NUM, m01); + vst1q_f32(dst_data + dst_step * C4NUM, m10); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m11); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_10 = src_data[i + 4 * src_step]; + float src_data_11 = src_data[i + 5 * src_step]; + float src_data_12 = src_data[i + 6 * src_step]; + float src_data_13 = src_data[i + 7 * src_step]; + float src_data_20 = src_data[i + 8 * src_step]; + float src_data_21 = src_data[i + 9 * src_step]; + float src_data_22 = src_data[i + 10 * src_step]; + float src_data_23 = src_data[i + 11 * src_step]; + float src_data_30 = src_data[i + 12 * src_step]; + float src_data_31 = src_data[i + 13 * src_step]; + float src_data_32 = src_data[i + 14 * src_step]; + float src_data_33 = src_data[i + 15 * src_step]; + + float t00 = src_data_00 + src_data_10 + src_data_20; + float t01 = src_data_01 + src_data_11 + src_data_21; + float t02 = src_data_02 + src_data_12 + src_data_22; + float t03 = src_data_03 + src_data_13 + src_data_23; + + float t10 = 0.5f * (src_data_10 - src_data_20) + src_data_30; + float t11 = 0.5f * (src_data_11 - src_data_21) + src_data_31; + float t12 = 0.5f * (src_data_12 - src_data_22) + src_data_32; + float t13 = 0.5f * (src_data_13 - src_data_23) + src_data_33; + + float m00 = t00 + t01 + t02 + bias_data[i]; + float m01 = 0.5f * (t01 - t02) + t03 + bias_data[i]; + float m10 = t10 + t11 + t12 + bias_data[i]; + float m11 = 0.5f * (t11 - t12) + t13 + bias_data[i]; + + (dst_data + i)[0] = m00; + (dst_data + i + C4NUM)[0] = m01; + (dst_data + i + dst_step * C4NUM)[0] = m10; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11; + } +#endif +} + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias_data); + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 15 * src_step); + + float32x4_t t00 = vaddq_f32(src_data_00, vaddq_f32(src_data_10, src_data_20)); + float32x4_t t01 = vaddq_f32(src_data_01, vaddq_f32(src_data_11, src_data_21)); + float32x4_t t02 = vaddq_f32(src_data_02, vaddq_f32(src_data_12, src_data_22)); + float32x4_t t03 = vaddq_f32(src_data_03, vaddq_f32(src_data_13, src_data_23)); + + float32x4_t t10 = vmulq_n_f32(vsubq_f32(src_data_10, src_data_20), 0.5); + float32x4_t t11 = vmulq_n_f32(vsubq_f32(src_data_11, src_data_21), 0.5); + float32x4_t t12 = vmulq_n_f32(vsubq_f32(src_data_12, src_data_22), 0.5); + float32x4_t t13 = vmulq_n_f32(vsubq_f32(src_data_13, src_data_23), 0.5); + + float32x4_t t20 = vaddq_f32(src_data_30, vmulq_n_f32(vaddq_f32(src_data_10, src_data_20), 0.25)); + float32x4_t t21 = vaddq_f32(src_data_31, vmulq_n_f32(vaddq_f32(src_data_11, src_data_21), 0.25)); + float32x4_t t22 = vaddq_f32(src_data_32, vmulq_n_f32(vaddq_f32(src_data_12, src_data_22), 0.25)); + float32x4_t t23 = vaddq_f32(src_data_33, vmulq_n_f32(vaddq_f32(src_data_13, src_data_23), 0.25)); + + float32x4_t m00 = vaddq_f32(vaddq_f32(t00, vaddq_f32(t01, t02)), bias_ptr); + float32x4_t m01 = vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.5), bias_ptr); + float32x4_t m02 = vaddq_f32(vaddq_f32(t03, vmulq_n_f32(vaddq_f32(t01, t02), 0.25)), bias_ptr); + float32x4_t m10 = vaddq_f32(vaddq_f32(t10, vaddq_f32(t11, t12)), bias_ptr); + float32x4_t m11 = vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.5), bias_ptr); + float32x4_t m12 = vaddq_f32(vaddq_f32(t13, vmulq_n_f32(vaddq_f32(t11, t12), 0.25)), bias_ptr); + float32x4_t m20 = vaddq_f32(vaddq_f32(t20, vaddq_f32(t21, t22)), bias_ptr); + float32x4_t m21 = vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.5), bias_ptr); + float32x4_t m22 = vaddq_f32(vaddq_f32(t23, vmulq_n_f32(vaddq_f32(t21, t22), 0.25)), bias_ptr); + + vst1q_f32(dst_data, m00); + vst1q_f32(dst_data + C4NUM, m01); + vst1q_f32(dst_data + 2 * C4NUM, m02); + vst1q_f32(dst_data + dst_step * C4NUM, m10); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m11); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m12); + vst1q_f32(dst_data + 2 * dst_step * C4NUM, m20); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m21); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m22); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_10 = src_data[i + 4 * src_step]; + float src_data_11 = src_data[i + 5 * src_step]; + float src_data_12 = src_data[i + 6 * src_step]; + float src_data_13 = src_data[i + 7 * src_step]; + float src_data_20 = src_data[i + 8 * src_step]; + float src_data_21 = src_data[i + 9 * src_step]; + float src_data_22 = src_data[i + 10 * src_step]; + float src_data_23 = src_data[i + 11 * src_step]; + float src_data_30 = src_data[i + 12 * src_step]; + float src_data_31 = src_data[i + 13 * src_step]; + float src_data_32 = src_data[i + 14 * src_step]; + float src_data_33 = src_data[i + 15 * src_step]; + + float t00 = src_data_00 + src_data_10 + src_data_20; + float t01 = src_data_01 + src_data_11 + src_data_21; + float t02 = src_data_02 + src_data_12 + src_data_22; + float t03 = src_data_03 + src_data_13 + src_data_23; + + float t10 = 0.5f * (src_data_10 - src_data_20); + float t11 = 0.5f * (src_data_11 - src_data_21); + float t12 = 0.5f * (src_data_12 - src_data_22); + float t13 = 0.5f * (src_data_13 - src_data_23); + + float t20 = 0.25f * (src_data_10 + src_data_20) + src_data_30; + float t21 = 0.25f * (src_data_11 + src_data_21) + src_data_31; + float t22 = 0.25f * (src_data_12 + src_data_22) + src_data_32; + float t23 = 0.25f * (src_data_13 + src_data_23) + src_data_33; + + float m00 = t00 + t01 + t02 + bias_data[i]; + float m01 = 0.5f * (t01 - t02) + bias_data[i]; + float m02 = 0.25f * (t01 + t02) + t03 + bias_data[i]; + + float m10 = t10 + t11 + t12 + bias_data[i]; + float m11 = 0.5f * (t11 - t12) + bias_data[i]; + float m12 = 0.25f * (t11 + t12) + t13 + bias_data[i]; + + float m20 = t20 + t21 + t22 + bias_data[i]; + float m21 = 0.5f * (t21 - t22) + bias_data[i]; + float m22 = 0.25f * (t21 + t22) + t23 + bias_data[i]; + + (dst_data + i)[0] = m00; + (dst_data + i + C4NUM)[0] = m01; + (dst_data + i + 2 * C4NUM)[0] = m02; + + (dst_data + i + dst_step * C4NUM)[0] = m10; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22; + } +#endif +} + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)), src_data_70); + float32x4_t t11 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)), src_data_71); + float32x4_t t12 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)), src_data_72); + float32x4_t t13 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)), src_data_73); + float32x4_t t14 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)), src_data_74); + float32x4_t t15 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)), src_data_75); + float32x4_t t16 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)), src_data_76); + float32x4_t t17 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)), t17); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21 + src_data_70; + float t11 = 0.5f * d02 + d12 + 1.5f * d22 + src_data_71; + float t12 = 0.5f * d03 + d13 + 1.5f * d23 + src_data_72; + float t13 = 0.5f * d04 + d14 + 1.5f * d24 + src_data_73; + float t14 = 0.5f * d05 + d15 + 1.5f * d25 + src_data_74; + float t15 = 0.5f * d06 + d16 + 1.5f * d26 + src_data_75; + float t16 = 0.5f * d07 + d17 + 1.5f * d27 + src_data_76; + float t17 = 0.5f * d08 + d18 + 1.5f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s21 = t03 - t04; + float s22 = t13 - t14; + float s31 = t05 - t06; + float s32 = t15 - t16; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31 + t07; + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32 + t17; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + } +#endif +} + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)), src_data_70); + float32x4_t t21 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)), src_data_71); + float32x4_t t22 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)), src_data_72); + float32x4_t t23 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)), src_data_73); + float32x4_t t24 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)), src_data_74); + float32x4_t t25 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)), src_data_75); + float32x4_t t26 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)), src_data_76); + float32x4_t t27 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)), t27); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51 + src_data_70; + float t21 = 0.25f * d32 + d42 + 2.25f * d52 + src_data_71; + float t22 = 0.25f * d33 + d43 + 2.25f * d53 + src_data_72; + float t23 = 0.25f * d34 + d44 + 2.25f * d54 + src_data_73; + float t24 = 0.25f * d35 + d45 + 2.25f * d55 + src_data_74; + float t25 = 0.25f * d36 + d46 + 2.25f * d56 + src_data_75; + float t26 = 0.25f * d37 + d47 + 2.25f * d57 + src_data_76; + float t27 = 0.25f * d38 + d48 + 2.25f * d58 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63 + t27; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + } +#endif +} + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)), src_data_70); + float32x4_t t31 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)), src_data_71); + float32x4_t t32 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)), src_data_72); + float32x4_t t33 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)), src_data_73); + float32x4_t t34 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)), src_data_74); + float32x4_t t35 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)), src_data_75); + float32x4_t t36 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)), src_data_76); + float32x4_t t37 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)), t37); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21 + src_data_70; + float t31 = 0.125f * d02 + d12 + 3.375f * d22 + src_data_71; + float t32 = 0.125f * d03 + d13 + 3.375f * d23 + src_data_72; + float t33 = 0.125f * d04 + d14 + 3.375f * d24 + src_data_73; + float t34 = 0.125f * d05 + d15 + 3.375f * d25 + src_data_74; + float t35 = 0.125f * d06 + d16 + 3.375f * d26 + src_data_75; + float t36 = 0.125f * d07 + d17 + 3.375f * d27 + src_data_76; + float t37 = 0.125f * d08 + d18 + 3.375f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34 + t37; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + } +#endif +} + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)); + float32x4_t t31 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)); + float32x4_t t32 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)); + float32x4_t t33 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)); + float32x4_t t34 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)); + float32x4_t t35 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)); + float32x4_t t36 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)); + float32x4_t t37 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)); + + float32x4_t t40 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.0625), d41), vmulq_n_f32(d51, 5.0625)), src_data_70); + float32x4_t t41 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.0625), d42), vmulq_n_f32(d52, 5.0625)), src_data_71); + float32x4_t t42 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.0625), d43), vmulq_n_f32(d53, 5.0625)), src_data_72); + float32x4_t t43 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.0625), d44), vmulq_n_f32(d54, 5.0625)), src_data_73); + float32x4_t t44 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.0625), d45), vmulq_n_f32(d55, 5.0625)), src_data_74); + float32x4_t t45 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.0625), d46), vmulq_n_f32(d56, 5.0625)), src_data_75); + float32x4_t t46 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.0625), d47), vmulq_n_f32(d57, 5.0625)), src_data_76); + float32x4_t t47 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.0625), d48), vmulq_n_f32(d58, 5.0625)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + float32x4_t s15 = vsubq_f32(t41, t42); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + float32x4_t s25 = vsubq_f32(t43, t44); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + float32x4_t s35 = vsubq_f32(t45, t46); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + float32x4_t s45 = vaddq_f32(t41, t42); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + float32x4_t s55 = vaddq_f32(t43, t44); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + float32x4_t s65 = vaddq_f32(t45, t46); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)); + float32x4_t m04 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.0625), s51), vmulq_n_f32(s61, 5.0625)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)); + float32x4_t m14 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.0625), s52), vmulq_n_f32(s62, 5.0625)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)); + float32x4_t m24 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.0625), s53), vmulq_n_f32(s63, 5.0625)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)); + float32x4_t m34 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.0625), s54), vmulq_n_f32(s64, 5.0625)), t37); + + float32x4_t m40 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t40, t41), t42), t43), t44), t45), t46); + float32x4_t m41 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.5), s25), vmulq_n_f32(s35, 1.5)); + float32x4_t m42 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.25), s55), vmulq_n_f32(s65, 2.25)); + float32x4_t m43 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.125), s25), vmulq_n_f32(s35, 3.375)); + float32x4_t m44 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.0625), s55), vmulq_n_f32(s65, 5.0625)), t47); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + vst1q_f32(dst_data + 4 * C4NUM, vaddq_f32(m04, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m14, bias_ptr)); + + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m24, bias_ptr)); + + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m34, bias_ptr)); + + vst1q_f32(dst_data + 4 * dst_step * C4NUM, vaddq_f32(m40, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, vaddq_f32(m41, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m42, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m43, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m44, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21; + float t31 = 0.125f * d02 + d12 + 3.375f * d22; + float t32 = 0.125f * d03 + d13 + 3.375f * d23; + float t33 = 0.125f * d04 + d14 + 3.375f * d24; + float t34 = 0.125f * d05 + d15 + 3.375f * d25; + float t35 = 0.125f * d06 + d16 + 3.375f * d26; + float t36 = 0.125f * d07 + d17 + 3.375f * d27; + float t37 = 0.125f * d08 + d18 + 3.375f * d28; + + float t40 = 0.0625f * d31 + d41 + 5.0625f * d51 + src_data_70; + float t41 = 0.0625f * d32 + d42 + 5.0625f * d52 + src_data_71; + float t42 = 0.0625f * d33 + d43 + 5.0625f * d53 + src_data_72; + float t43 = 0.0625f * d34 + d44 + 5.0625f * d54 + src_data_73; + float t44 = 0.0625f * d35 + d45 + 5.0625f * d55 + src_data_74; + float t45 = 0.0625f * d36 + d46 + 5.0625f * d56 + src_data_75; + float t46 = 0.0625f * d37 + d47 + 5.0625f * d57 + src_data_76; + float t47 = 0.0625f * d38 + d48 + 5.0625f * d58 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + float s15 = t41 - t42; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + float s25 = t43 - t44; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + float s35 = t45 - t46; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + float s45 = t41 + t42; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + float s55 = t43 + t44; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + float s65 = t45 + t46; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31; + float m04 = 0.0625f * s41 + s51 + 5.0625f * s61 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32; + float m14 = 0.0625f * s42 + s52 + 5.0625f * s62 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33; + float m24 = 0.0625f * s43 + s53 + 5.0625f * s63 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34; + float m34 = 0.0625f * s44 + s54 + 5.0625f * s64 + t37; + + float m40 = t40 + t41 + t42 + t43 + t44 + t45 + t46; + float m41 = 0.5f * s15 + s25 + 1.5f * s35; + float m42 = 0.25f * s45 + s55 + 2.25f * s65; + float m43 = 0.125f * s15 + s25 + 3.375f * s35; + float m44 = 0.0625f * s45 + s55 + 5.0625f * s65 + t47; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + (dst_data + i + 4 * C4NUM)[0] = m04 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 4 * C4NUM)[0] = m14 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 4 * C4NUM)[0] = m24 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 4 * C4NUM)[0] = m34 + bias_data[i]; + + (dst_data + i + 4 * dst_step * C4NUM)[0] = m40 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + C4NUM)[0] = m41 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 2 * C4NUM)[0] = m42 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 3 * C4NUM)[0] = m43 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 4 * C4NUM)[0] = m44 + bias_data[i]; + } +#endif +} + +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)); + float32x4_t t31 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)); + float32x4_t t32 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)); + float32x4_t t33 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)); + float32x4_t t34 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)); + float32x4_t t35 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)); + float32x4_t t36 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)); + float32x4_t t37 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)); + + float32x4_t t40 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.0625), d41), vmulq_n_f32(d51, 5.0625)); + float32x4_t t41 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.0625), d42), vmulq_n_f32(d52, 5.0625)); + float32x4_t t42 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.0625), d43), vmulq_n_f32(d53, 5.0625)); + float32x4_t t43 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.0625), d44), vmulq_n_f32(d54, 5.0625)); + float32x4_t t44 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.0625), d45), vmulq_n_f32(d55, 5.0625)); + float32x4_t t45 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.0625), d46), vmulq_n_f32(d56, 5.0625)); + float32x4_t t46 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.0625), d47), vmulq_n_f32(d57, 5.0625)); + float32x4_t t47 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.0625), d48), vmulq_n_f32(d58, 5.0625)); + + float32x4_t t50 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.03125), d11), vmulq_n_f32(d21, 7.59375)), src_data_70); + float32x4_t t51 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.03125), d12), vmulq_n_f32(d22, 7.59375)), src_data_71); + float32x4_t t52 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.03125), d13), vmulq_n_f32(d23, 7.59375)), src_data_72); + float32x4_t t53 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.03125), d14), vmulq_n_f32(d24, 7.59375)), src_data_73); + float32x4_t t54 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.03125), d15), vmulq_n_f32(d25, 7.59375)), src_data_74); + float32x4_t t55 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.03125), d16), vmulq_n_f32(d26, 7.59375)), src_data_75); + float32x4_t t56 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.03125), d17), vmulq_n_f32(d27, 7.59375)), src_data_76); + float32x4_t t57 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.03125), d18), vmulq_n_f32(d28, 7.59375)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + float32x4_t s15 = vsubq_f32(t41, t42); + float32x4_t s16 = vsubq_f32(t51, t52); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + float32x4_t s25 = vsubq_f32(t43, t44); + float32x4_t s26 = vsubq_f32(t53, t54); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + float32x4_t s35 = vsubq_f32(t45, t46); + float32x4_t s36 = vsubq_f32(t55, t56); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + float32x4_t s45 = vaddq_f32(t41, t42); + float32x4_t s46 = vaddq_f32(t51, t52); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + float32x4_t s55 = vaddq_f32(t43, t44); + float32x4_t s56 = vaddq_f32(t53, t54); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + float32x4_t s65 = vaddq_f32(t45, t46); + float32x4_t s66 = vaddq_f32(t55, t56); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)); + float32x4_t m04 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.0625), s51), vmulq_n_f32(s61, 5.0625)); + float32x4_t m05 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.03125), s21), vmulq_n_f32(s31, 7.59375)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)); + float32x4_t m14 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.0625), s52), vmulq_n_f32(s62, 5.0625)); + float32x4_t m15 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.03125), s22), vmulq_n_f32(s32, 7.59375)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)); + float32x4_t m24 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.0625), s53), vmulq_n_f32(s63, 5.0625)); + float32x4_t m25 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.03125), s23), vmulq_n_f32(s33, 7.59375)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)); + float32x4_t m34 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.0625), s54), vmulq_n_f32(s64, 5.0625)); + float32x4_t m35 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.03125), s24), vmulq_n_f32(s34, 7.59375)), t37); + + float32x4_t m40 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t40, t41), t42), t43), t44), t45), t46); + float32x4_t m41 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.5), s25), vmulq_n_f32(s35, 1.5)); + float32x4_t m42 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.25), s55), vmulq_n_f32(s65, 2.25)); + float32x4_t m43 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.125), s25), vmulq_n_f32(s35, 3.375)); + float32x4_t m44 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.0625), s55), vmulq_n_f32(s65, 5.0625)); + float32x4_t m45 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.03125), s25), vmulq_n_f32(s35, 7.59375)), t47); + + float32x4_t m50 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t50, t51), t52), t53), t54), t55), t56); + float32x4_t m51 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.5), s26), vmulq_n_f32(s36, 1.5)); + float32x4_t m52 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.25), s56), vmulq_n_f32(s66, 2.25)); + float32x4_t m53 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.125), s26), vmulq_n_f32(s36, 3.375)); + float32x4_t m54 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.0625), s56), vmulq_n_f32(s66, 5.0625)); + float32x4_t m55 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.03125), s26), vmulq_n_f32(s36, 7.59375)), t57); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + vst1q_f32(dst_data + 4 * C4NUM, vaddq_f32(m04, bias_ptr)); + vst1q_f32(dst_data + 5 * C4NUM, vaddq_f32(m05, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m14, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m15, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m24, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m25, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m34, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m35, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM, vaddq_f32(m40, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, vaddq_f32(m41, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m42, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m43, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m44, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m45, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM, vaddq_f32(m50, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + C4NUM, vaddq_f32(m51, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m52, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m53, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m54, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m55, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21; + float t31 = 0.125f * d02 + d12 + 3.375f * d22; + float t32 = 0.125f * d03 + d13 + 3.375f * d23; + float t33 = 0.125f * d04 + d14 + 3.375f * d24; + float t34 = 0.125f * d05 + d15 + 3.375f * d25; + float t35 = 0.125f * d06 + d16 + 3.375f * d26; + float t36 = 0.125f * d07 + d17 + 3.375f * d27; + float t37 = 0.125f * d08 + d18 + 3.375f * d28; + + float t40 = 0.0625f * d31 + d41 + 5.0625f * d51; + float t41 = 0.0625f * d32 + d42 + 5.0625f * d52; + float t42 = 0.0625f * d33 + d43 + 5.0625f * d53; + float t43 = 0.0625f * d34 + d44 + 5.0625f * d54; + float t44 = 0.0625f * d35 + d45 + 5.0625f * d55; + float t45 = 0.0625f * d36 + d46 + 5.0625f * d56; + float t46 = 0.0625f * d37 + d47 + 5.0625f * d57; + float t47 = 0.0625f * d38 + d48 + 5.0625f * d58; + + float t50 = 0.03125f * d01 + d11 + 7.59375f * d21 + src_data_70; + float t51 = 0.03125f * d02 + d12 + 7.59375f * d22 + src_data_71; + float t52 = 0.03125f * d03 + d13 + 7.59375f * d23 + src_data_72; + float t53 = 0.03125f * d04 + d14 + 7.59375f * d24 + src_data_73; + float t54 = 0.03125f * d05 + d15 + 7.59375f * d25 + src_data_74; + float t55 = 0.03125f * d06 + d16 + 7.59375f * d26 + src_data_75; + float t56 = 0.03125f * d07 + d17 + 7.59375f * d27 + src_data_76; + float t57 = 0.03125f * d08 + d18 + 7.59375f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + float s15 = t41 - t42; + float s16 = t51 - t52; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + float s25 = t43 - t44; + float s26 = t53 - t54; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + float s35 = t45 - t46; + float s36 = t55 - t56; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + float s45 = t41 + t42; + float s46 = t51 + t52; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + float s55 = t43 + t44; + float s56 = t53 + t54; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + float s65 = t45 + t46; + float s66 = t55 + t56; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31; + float m04 = 0.0625f * s41 + s51 + 5.0625f * s61; + float m05 = 0.03125f * s11 + s21 + 7.59375f * s31 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32; + float m14 = 0.0625f * s42 + s52 + 5.0625f * s62; + float m15 = 0.03125f * s12 + s22 + 7.59375f * s32 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33; + float m24 = 0.0625f * s43 + s53 + 5.0625f * s63; + float m25 = 0.03125f * s13 + s23 + 7.59375f * s33 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34; + float m34 = 0.0625f * s44 + s54 + 5.0625f * s64; + float m35 = 0.03125f * s14 + s24 + 7.59375f * s34 + t37; + + float m40 = t40 + t41 + t42 + t43 + t44 + t45 + t46; + float m41 = 0.5f * s15 + s25 + 1.5f * s35; + float m42 = 0.25f * s45 + s55 + 2.25f * s65; + float m43 = 0.125f * s15 + s25 + 3.375f * s35; + float m44 = 0.0625f * s45 + s55 + 5.0625f * s65; + float m45 = 0.03125f * s15 + s25 + 7.59375f * s35 + t47; + + float m50 = t50 + t51 + t52 + t53 + t54 + t55 + t56; + float m51 = 0.5f * s16 + s26 + 1.5f * s36; + float m52 = 0.25f * s46 + s56 + 2.25f * s66; + float m53 = 0.125f * s16 + s26 + 3.375f * s36; + float m54 = 0.0625f * s46 + s56 + 5.0625f * s66; + float m55 = 0.03125f * s16 + s26 + 7.59375f * s36 + t57; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + (dst_data + i + 4 * C4NUM)[0] = m04 + bias_data[i]; + (dst_data + i + 5 * C4NUM)[0] = m05 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 4 * C4NUM)[0] = m14 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 5 * C4NUM)[0] = m15 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 4 * C4NUM)[0] = m24 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 5 * C4NUM)[0] = m25 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 4 * C4NUM)[0] = m34 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 5 * C4NUM)[0] = m35 + bias_data[i]; + + (dst_data + i + 4 * dst_step * C4NUM)[0] = m40 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + C4NUM)[0] = m41 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 2 * C4NUM)[0] = m42 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 3 * C4NUM)[0] = m43 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 4 * C4NUM)[0] = m44 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 5 * C4NUM)[0] = m45 + bias_data[i]; + + (dst_data + i + 5 * dst_step * C4NUM)[0] = m50 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + C4NUM)[0] = m51 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 2 * C4NUM)[0] = m52 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 3 * C4NUM)[0] = m53 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 4 * C4NUM)[0] = m54 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 5 * C4NUM)[0] = m55 + bias_data[i]; + } +#endif +} + +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)); + float32x4_t t31 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)); + float32x4_t t32 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)); + float32x4_t t33 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)); + float32x4_t t34 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)); + float32x4_t t35 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)); + float32x4_t t36 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)); + float32x4_t t37 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)); + + float32x4_t t40 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.0625), d41), vmulq_n_f32(d51, 5.0625)); + float32x4_t t41 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.0625), d42), vmulq_n_f32(d52, 5.0625)); + float32x4_t t42 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.0625), d43), vmulq_n_f32(d53, 5.0625)); + float32x4_t t43 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.0625), d44), vmulq_n_f32(d54, 5.0625)); + float32x4_t t44 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.0625), d45), vmulq_n_f32(d55, 5.0625)); + float32x4_t t45 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.0625), d46), vmulq_n_f32(d56, 5.0625)); + float32x4_t t46 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.0625), d47), vmulq_n_f32(d57, 5.0625)); + float32x4_t t47 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.0625), d48), vmulq_n_f32(d58, 5.0625)); + + float32x4_t t50 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.03125), d11), vmulq_n_f32(d21, 7.59375)); + float32x4_t t51 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.03125), d12), vmulq_n_f32(d22, 7.59375)); + float32x4_t t52 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.03125), d13), vmulq_n_f32(d23, 7.59375)); + float32x4_t t53 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.03125), d14), vmulq_n_f32(d24, 7.59375)); + float32x4_t t54 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.03125), d15), vmulq_n_f32(d25, 7.59375)); + float32x4_t t55 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.03125), d16), vmulq_n_f32(d26, 7.59375)); + float32x4_t t56 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.03125), d17), vmulq_n_f32(d27, 7.59375)); + float32x4_t t57 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.03125), d18), vmulq_n_f32(d28, 7.59375)); + + float32x4_t t60 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.015625), d41), vmulq_n_f32(d51, 11.390625)), src_data_70); + float32x4_t t61 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.015625), d42), vmulq_n_f32(d52, 11.390625)), src_data_71); + float32x4_t t62 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.015625), d43), vmulq_n_f32(d53, 11.390625)), src_data_72); + float32x4_t t63 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.015625), d44), vmulq_n_f32(d54, 11.390625)), src_data_73); + float32x4_t t64 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.015625), d45), vmulq_n_f32(d55, 11.390625)), src_data_74); + float32x4_t t65 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.015625), d46), vmulq_n_f32(d56, 11.390625)), src_data_75); + float32x4_t t66 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.015625), d47), vmulq_n_f32(d57, 11.390625)), src_data_76); + float32x4_t t67 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.015625), d48), vmulq_n_f32(d58, 11.390625)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + float32x4_t s15 = vsubq_f32(t41, t42); + float32x4_t s16 = vsubq_f32(t51, t52); + float32x4_t s17 = vsubq_f32(t61, t62); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + float32x4_t s25 = vsubq_f32(t43, t44); + float32x4_t s26 = vsubq_f32(t53, t54); + float32x4_t s27 = vsubq_f32(t63, t64); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + float32x4_t s35 = vsubq_f32(t45, t46); + float32x4_t s36 = vsubq_f32(t55, t56); + float32x4_t s37 = vsubq_f32(t65, t66); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + float32x4_t s45 = vaddq_f32(t41, t42); + float32x4_t s46 = vaddq_f32(t51, t52); + float32x4_t s47 = vaddq_f32(t61, t62); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + float32x4_t s55 = vaddq_f32(t43, t44); + float32x4_t s56 = vaddq_f32(t53, t54); + float32x4_t s57 = vaddq_f32(t63, t64); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + float32x4_t s65 = vaddq_f32(t45, t46); + float32x4_t s66 = vaddq_f32(t55, t56); + float32x4_t s67 = vaddq_f32(t65, t66); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)); + float32x4_t m04 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.0625), s51), vmulq_n_f32(s61, 5.0625)); + float32x4_t m05 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.03125), s21), vmulq_n_f32(s31, 7.59375)); + float32x4_t m06 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.015625), s51), vmulq_n_f32(s61, 11.390625)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)); + float32x4_t m14 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.0625), s52), vmulq_n_f32(s62, 5.0625)); + float32x4_t m15 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.03125), s22), vmulq_n_f32(s32, 7.59375)); + float32x4_t m16 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.015625), s52), vmulq_n_f32(s62, 11.390625)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)); + float32x4_t m24 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.0625), s53), vmulq_n_f32(s63, 5.0625)); + float32x4_t m25 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.03125), s23), vmulq_n_f32(s33, 7.59375)); + float32x4_t m26 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.015625), s53), vmulq_n_f32(s63, 11.390625)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)); + float32x4_t m34 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.0625), s54), vmulq_n_f32(s64, 5.0625)); + float32x4_t m35 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.03125), s24), vmulq_n_f32(s34, 7.59375)); + float32x4_t m36 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.015625), s54), vmulq_n_f32(s64, 11.390625)), t37); + + float32x4_t m40 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t40, t41), t42), t43), t44), t45), t46); + float32x4_t m41 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.5), s25), vmulq_n_f32(s35, 1.5)); + float32x4_t m42 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.25), s55), vmulq_n_f32(s65, 2.25)); + float32x4_t m43 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.125), s25), vmulq_n_f32(s35, 3.375)); + float32x4_t m44 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.0625), s55), vmulq_n_f32(s65, 5.0625)); + float32x4_t m45 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.03125), s25), vmulq_n_f32(s35, 7.59375)); + float32x4_t m46 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.015625), s55), vmulq_n_f32(s65, 11.390625)), t47); + + float32x4_t m50 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t50, t51), t52), t53), t54), t55), t56); + float32x4_t m51 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.5), s26), vmulq_n_f32(s36, 1.5)); + float32x4_t m52 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.25), s56), vmulq_n_f32(s66, 2.25)); + float32x4_t m53 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.125), s26), vmulq_n_f32(s36, 3.375)); + float32x4_t m54 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.0625), s56), vmulq_n_f32(s66, 5.0625)); + float32x4_t m55 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.03125), s26), vmulq_n_f32(s36, 7.59375)); + float32x4_t m56 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.015625), s56), vmulq_n_f32(s66, 11.390625)), t57); + + float32x4_t m60 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t60, t61), t62), t63), t64), t65), t66); + float32x4_t m61 = vaddq_f32(vaddq_f32(vmulq_n_f32(s17, 0.5), s27), vmulq_n_f32(s37, 1.5)); + float32x4_t m62 = vaddq_f32(vaddq_f32(vmulq_n_f32(s47, 0.25), s57), vmulq_n_f32(s67, 2.25)); + float32x4_t m63 = vaddq_f32(vaddq_f32(vmulq_n_f32(s17, 0.125), s27), vmulq_n_f32(s37, 3.375)); + float32x4_t m64 = vaddq_f32(vaddq_f32(vmulq_n_f32(s47, 0.0625), s57), vmulq_n_f32(s67, 5.0625)); + float32x4_t m65 = vaddq_f32(vaddq_f32(vmulq_n_f32(s17, 0.03125), s27), vmulq_n_f32(s37, 7.59375)); + float32x4_t m66 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s47, 0.015625), s57), vmulq_n_f32(s67, 11.390625)), t67); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + vst1q_f32(dst_data + 4 * C4NUM, vaddq_f32(m04, bias_ptr)); + vst1q_f32(dst_data + 5 * C4NUM, vaddq_f32(m05, bias_ptr)); + vst1q_f32(dst_data + 6 * C4NUM, vaddq_f32(m06, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m14, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m15, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m16, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m24, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m25, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m26, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m34, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m35, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m36, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM, vaddq_f32(m40, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, vaddq_f32(m41, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m42, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m43, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m44, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m45, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m46, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM, vaddq_f32(m50, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + C4NUM, vaddq_f32(m51, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m52, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m53, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m54, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m55, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m56, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM, vaddq_f32(m60, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + C4NUM, vaddq_f32(m61, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m62, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m63, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m64, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m65, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m66, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21; + float t31 = 0.125f * d02 + d12 + 3.375f * d22; + float t32 = 0.125f * d03 + d13 + 3.375f * d23; + float t33 = 0.125f * d04 + d14 + 3.375f * d24; + float t34 = 0.125f * d05 + d15 + 3.375f * d25; + float t35 = 0.125f * d06 + d16 + 3.375f * d26; + float t36 = 0.125f * d07 + d17 + 3.375f * d27; + float t37 = 0.125f * d08 + d18 + 3.375f * d28; + + float t40 = 0.0625f * d31 + d41 + 5.0625f * d51; + float t41 = 0.0625f * d32 + d42 + 5.0625f * d52; + float t42 = 0.0625f * d33 + d43 + 5.0625f * d53; + float t43 = 0.0625f * d34 + d44 + 5.0625f * d54; + float t44 = 0.0625f * d35 + d45 + 5.0625f * d55; + float t45 = 0.0625f * d36 + d46 + 5.0625f * d56; + float t46 = 0.0625f * d37 + d47 + 5.0625f * d57; + float t47 = 0.0625f * d38 + d48 + 5.0625f * d58; + + float t50 = 0.03125f * d01 + d11 + 7.59375f * d21; + float t51 = 0.03125f * d02 + d12 + 7.59375f * d22; + float t52 = 0.03125f * d03 + d13 + 7.59375f * d23; + float t53 = 0.03125f * d04 + d14 + 7.59375f * d24; + float t54 = 0.03125f * d05 + d15 + 7.59375f * d25; + float t55 = 0.03125f * d06 + d16 + 7.59375f * d26; + float t56 = 0.03125f * d07 + d17 + 7.59375f * d27; + float t57 = 0.03125f * d08 + d18 + 7.59375f * d28; + + float t60 = 0.015625f * d31 + d41 + 11.390625f * d51 + src_data_70; + float t61 = 0.015625f * d32 + d42 + 11.390625f * d52 + src_data_71; + float t62 = 0.015625f * d33 + d43 + 11.390625f * d53 + src_data_72; + float t63 = 0.015625f * d34 + d44 + 11.390625f * d54 + src_data_73; + float t64 = 0.015625f * d35 + d45 + 11.390625f * d55 + src_data_74; + float t65 = 0.015625f * d36 + d46 + 11.390625f * d56 + src_data_75; + float t66 = 0.015625f * d37 + d47 + 11.390625f * d57 + src_data_76; + float t67 = 0.015625f * d38 + d48 + 11.390625f * d58 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + float s15 = t41 - t42; + float s16 = t51 - t52; + float s17 = t61 - t62; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + float s25 = t43 - t44; + float s26 = t53 - t54; + float s27 = t63 - t64; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + float s35 = t45 - t46; + float s36 = t55 - t56; + float s37 = t56 - t66; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + float s45 = t41 + t42; + float s46 = t51 + t52; + float s47 = t61 + t62; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + float s55 = t43 + t44; + float s56 = t53 + t54; + float s57 = t63 + t64; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + float s65 = t45 + t46; + float s66 = t55 + t56; + float s67 = t65 + t66; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31; + float m04 = 0.0625f * s41 + s51 + 5.0625f * s61; + float m05 = 0.03125f * s11 + s21 + 7.59375f * s31; + float m06 = 0.015625f * s41 + s51 + 11.390625f * s61 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32; + float m14 = 0.0625f * s42 + s52 + 5.0625f * s62; + float m15 = 0.03125f * s12 + s22 + 7.59375f * s32; + float m16 = 0.015625f * s42 + s52 + 11.390625f * s62 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33; + float m24 = 0.0625f * s43 + s53 + 5.0625f * s63; + float m25 = 0.03125f * s13 + s23 + 7.59375f * s33; + float m26 = 0.015625f * s43 + s53 + 11.390625f * s63 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34; + float m34 = 0.0625f * s44 + s54 + 5.0625f * s64; + float m35 = 0.03125f * s14 + s24 + 7.59375f * s34; + float m36 = 0.015625f * s44 + s54 + 11.390625f * s64 + t37; + + float m40 = t40 + t41 + t42 + t43 + t44 + t45 + t46; + float m41 = 0.5f * s15 + s25 + 1.5f * s35; + float m42 = 0.25f * s45 + s55 + 2.25f * s65; + float m43 = 0.125f * s15 + s25 + 3.375f * s35; + float m44 = 0.0625f * s45 + s55 + 5.0625f * s65; + float m45 = 0.03125f * s15 + s25 + 7.59375f * s35; + float m46 = 0.015625f * s45 + s55 + 11.390625f * s65 + t47; + + float m50 = t50 + t51 + t52 + t53 + t54 + t55 + t56; + float m51 = 0.5f * s16 + s26 + 1.5f * s36; + float m52 = 0.25f * s46 + s56 + 2.25f * s66; + float m53 = 0.125f * s16 + s26 + 3.375f * s36; + float m54 = 0.0625f * s46 + s56 + 5.0625f * s66; + float m55 = 0.03125f * s16 + s26 + 7.59375f * s36; + float m56 = 0.015625f * s46 + s56 + 11.390625f * s66 + t57; + + float m60 = t60 + t61 + t62 + t63 + t64 + t65 + t66; + float m61 = 0.5f * s17 + s27 + 1.5f * s37; + float m62 = 0.25f * s47 + s57 + 2.25f * s67; + float m63 = 0.125f * s17 + s27 + 3.375f * s37; + float m64 = 0.0625f * s47 + s57 + 5.0625f * s67; + float m65 = 0.03125f * s17 + s27 + 7.59375f * s37; + float m66 = 0.015625f * s47 + s57 + 11.390625f * s67 + t67; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + (dst_data + i + 4 * C4NUM)[0] = m04 + bias_data[i]; + (dst_data + i + 5 * C4NUM)[0] = m05 + bias_data[i]; + (dst_data + i + 6 * C4NUM)[0] = m06 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 4 * C4NUM)[0] = m14 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 5 * C4NUM)[0] = m15 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 6 * C4NUM)[0] = m16 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 4 * C4NUM)[0] = m24 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 5 * C4NUM)[0] = m25 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 6 * C4NUM)[0] = m26 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 4 * C4NUM)[0] = m34 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 5 * C4NUM)[0] = m35 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 6 * C4NUM)[0] = m36 + bias_data[i]; + + (dst_data + i + 4 * dst_step * C4NUM)[0] = m40 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + C4NUM)[0] = m41 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 2 * C4NUM)[0] = m42 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 3 * C4NUM)[0] = m43 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 4 * C4NUM)[0] = m44 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 5 * C4NUM)[0] = m45 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 6 * C4NUM)[0] = m46 + bias_data[i]; + + (dst_data + i + 5 * dst_step * C4NUM)[0] = m50 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + C4NUM)[0] = m51 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 2 * C4NUM)[0] = m52 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 3 * C4NUM)[0] = m53 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 4 * C4NUM)[0] = m54 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 5 * C4NUM)[0] = m55 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 6 * C4NUM)[0] = m56 + bias_data[i]; + + (dst_data + i + 6 * dst_step * C4NUM)[0] = m60 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + C4NUM)[0] = m61 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 2 * C4NUM)[0] = m62 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 3 * C4NUM)[0] = m63 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 4 * C4NUM)[0] = m64 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 5 * C4NUM)[0] = m65 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 6 * C4NUM)[0] = m66 + bias_data[i]; + } +#endif +} + +// Reference to the paper "Fast Algorithms for Convolutional Neural Networks" +// Utilize cost model to compute performance gain. +// If the gain is greater than got from Im2col, winograd algorithm will be chosen. +int SelectOutputUnit(ConvParameter *conv_param) { + auto input_batch = conv_param->input_batch_; + auto kernel_h = conv_param->kernel_h_; + auto kernel_w = conv_param->kernel_w_; + auto in_channel = conv_param->input_channel_; + auto out_h = conv_param->output_h_; + auto out_w = conv_param->output_w_; + auto out_channel = conv_param->output_channel_; + int out_plane = out_h * out_w; + + int max_unit = ::sqrt((float)(out_plane)); + max_unit = max_unit > MIN_UNIT ? max_unit : MIN_UNIT; + max_unit = max_unit < MAX_UNIT ? max_unit : MAX_UNIT; + int output_unit = 1; + float ratio = 0.0f; + // cost of conventional convolution multiplications + float ori_cost = out_plane * out_channel * in_channel * kernel_h * kernel_w; + + for (int u = MIN_UNIT; u < max_unit; u++) { + auto input_unit = u + kernel_h - 1; + if (input_unit != 4 && input_unit != 8) { + continue; + } + // don't count filter transform cost, because it can be processed once offline. + float input_trans_unit_cost = 2 * input_unit * input_unit * input_unit * in_channel; + float gemm_unit_cost = input_unit * input_unit * in_channel * out_channel; + float output_trans_unit_cost = input_unit * u * (u + input_unit) * out_channel; + // equation (23) in papar + float winograd_cost = (input_trans_unit_cost + gemm_unit_cost + output_trans_unit_cost) * + (UP_DIV(out_w, u) * (UP_DIV(out_h, u))) * input_batch; + float reduce_rate = ori_cost / winograd_cost; + if (reduce_rate > ratio && reduce_rate > 1) { + ratio = reduce_rate; + output_unit = u; + } + } + // If output_unit is 1, then it is conventional convolution + return output_unit; +} + +InputTransformUnitFunc GetInputTransFunc(int input_unit) { + if (input_unit == 4) { + return InputTransform4x4Unit; + } else if (input_unit == 8) { + return InputTransform8x8Unit; + } else { + printf("Only support 4 or 8 for input unit."); + return nullptr; + } +} + +OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit) { + if (input_unit == 4 && output_unit == 2) { + return OutputTransform4x2Unit; + } else if (input_unit == 4 && output_unit == 3) { + return OutputTransform4x3Unit; + } else if (input_unit == 8) { + return outputTransformUnit[output_unit]; + } else { + printf("."); + return nullptr; + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h new file mode 100644 index 0000000000..d7a7b7a69c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_UTILS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_UTILS_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include "src/runtime/kernel/arm/nnacl/matrix_table.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +using InputTransformUnitFunc = void (*)(const float *src_data, float *dst_data, int src_step, int dst_step); +using OutputTransformUnitFunc = void (*)(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step); + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step); + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step); + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +int SelectOutputUnit(ConvParameter *conv_param); + +InputTransformUnitFunc GetInputTransFunc(int input_unit); + +OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/zeroslike.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/zeroslike.cc new file mode 100644 index 0000000000..697d77e6ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/zeroslike.cc @@ -0,0 +1,21 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/nnacl/zeroslike.h" +#include +#include + +void ApproximateZerosLike(float *input, float *output, int number) { memset(output, 0.0, number * sizeof(float)); } + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/zeroslike.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/zeroslike.h new file mode 100644 index 0000000000..8948ad39fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/zeroslike.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ZEROSLIKE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ZEROSLIKE_H_ + +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +void ApproximateZerosLike(float *input, float *output, int number); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ZEROSLIKE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt new file mode 100644 index 0000000000..b090065ca1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt @@ -0,0 +1,13 @@ +set(OPENCL_KERNEL_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_opencl_kernel.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/arithmetic.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/convolution.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/depthwise_conv2d.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/pooling2d.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/matmul.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/softmax.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/concat.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/conv2d_transpose.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/transpose.cc + ) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl new file mode 100644 index 0000000000..e166e699c5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl @@ -0,0 +1,61 @@ +#define FLT half +#define FLT4 half4 +#define FLT16 half16 +#define READ_IMAGE read_imageh +#define WRITE_IMAGE write_imageh +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, + __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, + int4 src_size, int4 dst_size) { + int h = get_global_id(0); + int kh = h % 2; + int src_h = h / 2; + src_h = src_h * 2; + int w = get_global_id(1); + int kw = w % 2; + int src_w = w / 2; + src_w = src_w * 2; + int co = get_global_id(2); + if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; + FLT4 r0 = (FLT4)(0.f); + FLT4 r1 = (FLT4)(0.f); + FLT4 r2 = (FLT4)(0.f); + FLT4 r3 = (FLT4)(0.f); + int base_w = (co * 4 + kh + kw * 2) * src_size.z; + for (int ci = 0; ci < src_size.z; ++ci) { + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)); + FLT16 weight_cache = weight[base_w++]; + r0 += x0.x * weight_cache.s0123; + r0 += x0.y * weight_cache.s4567; + r0 += x0.z * weight_cache.s89ab; + r0 += x0.w * weight_cache.scdef; + + r1 += x1.x * weight_cache.s0123; + r1 += x1.y * weight_cache.s4567; + r1 += x1.z * weight_cache.s89ab; + r1 += x1.w * weight_cache.scdef; + + r2 += x2.x * weight_cache.s0123; + r2 += x2.y * weight_cache.s4567; + r2 += x2.z * weight_cache.s89ab; + r2 += x2.w * weight_cache.scdef; + + r3 += x3.x * weight_cache.s0123; + r3 += x3.y * weight_cache.s4567; + r3 += x3.z * weight_cache.s89ab; + r3 += x3.w * weight_cache.scdef; + } + FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(co, 0)); + r0 += bias_val; + r1 += bias_val; + r2 += bias_val; + r3 += bias_val; + + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl new file mode 100644 index 0000000000..2725ca9261 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl @@ -0,0 +1,76 @@ +#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define ACCUM_FLT4 half4 +#define FLT half +#define FLT2 half2 +#define FLT3 half3 +#define FLT4 half4 +#define TO_FLT4 convert_half4 +#define TO_ACCUM_TYPE convert_half4 +#define TO_ACCUM_FLT convert_half +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef +__constant sampler_t smp_edge = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST; +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void DepthwiseConv2d_NC4HW4(__global FLT4 *src_data, __global FLT4 *filters, __global FLT4 *biases, + float relu_clip1, __global FLT4 *dst_data, int2 kernel_size, int2 stride, + int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filters[fx_c]; + FLT4 src_final = src_data[(((Z)*src_size.y + (y_c)) * src_size.x + (x_c))]; + r += TO_ACCUM_TYPE(src_final * f); + } + fx_c++; + } + } + FLT4 bias_val = biases[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[(((Z)*dst_size.y + (Y)) * dst_size.x + (X))] = res0; +} + +__kernel void DepthwiseConv2d_NHWC4(__global FLT4 *src_data, __global FLT4 *filters, __global FLT4 *biases, + float relu_clip1, __global FLT4 *dst_data, int2 kernel_size, int2 stride, + int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filters[fx_c]; + FLT4 src_final = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_ACCUM_TYPE(src_final * f); + } + fx_c++; + } + } + FLT4 bias_val = biases[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl new file mode 100644 index 0000000000..c121f824bd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl @@ -0,0 +1,32 @@ +#define FLT4 half4 +#define FLT16 half16 +#define READ_IMAGE read_imageh +#define WRITE_IMAGE write_imageh +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias, + __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) { + int2 gid = (int2)(get_global_id(0), get_global_id(1)); + int2 lid = (int2)(get_local_id(0), get_local_id(1)); + FLT4 result = (FLT4)(0.0f); + bool inside = gid.x < offset_co.y; + for (uint i = lid.y; i < offset_ci.y && inside; i += 4) { + FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, 0)); + FLT16 w = weight[gid.x + i * offset_co.y]; + result.x += dot(v, w.s0123); + result.y += dot(v, w.s4567); + result.z += dot(v, w.s89ab); + result.w += dot(v, w.scdef); + } + __local FLT4 temp[64][4]; + temp[lid.x][lid.y] = result; + barrier(CLK_LOCAL_MEM_FENCE); + if (lid.y == 0 && inside) { + result += temp[lid.x][1]; + result += temp[lid.x][2]; + result += temp[lid.x][3]; + if (has_bias != 0) { + result += READ_IMAGE(bias, smp_zero, (int2)(gid.x, 0)); + } + WRITE_IMAGE(output, (int2)(gid.x, 0), result); + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl new file mode 100644 index 0000000000..ebc3db633f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl @@ -0,0 +1,44 @@ +#define FLT half +#define FLT4 half4 +#define READ_IMAGE read_imageh +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void transpose(__read_only image2d_t src_data, __global float4 *dst_data, int2 HW, int2 C) { + int X = get_global_id(0); + int Y = get_global_id(1); + if (X >= HW.y || Y >= C.y) { + return; + } + FLT4 result[4]; + result[0] = (FLT4)(0.0f); + result[1] = (FLT4)(0.0f); + result[2] = (FLT4)(0.0f); + result[3] = (FLT4)(0.0f); + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3)); + result[0].x = x0.x; + result[0].y = x1.x; + result[0].z = x2.x; + result[0].w = x3.x; + + result[1].x = x0.y; + result[1].y = x1.y; + result[1].z = x2.y; + result[1].w = x3.y; + + result[2].x = x0.z; + result[2].y = x1.z; + result[2].z = x2.z; + result[2].w = x3.z; + + result[3].x = x0.w; + result[3].y = x1.w; + result[3].z = x2.w; + result[3].w = x3.w; + + dst_data[4 * Y * HW.y + X] = result[0]; + dst_data[(4 * Y + 1) * HW.y + X] = result[1]; + dst_data[(4 * Y + 2) * HW.y + X] = result[2]; + dst_data[(4 * Y + 3) * HW.y + X] = result[3]; +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl new file mode 100644 index 0000000000..4e9a8422be --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl @@ -0,0 +1,51 @@ +__kernel void ElementAdd(__global float *input_a, __global float *input_b, __global float *output, + const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] + input_b[idx]; +} + +__kernel void ElementSub(__global float *input_a, __global float *input_b, __global float *output, + const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] - input_b[idx]; +} + +__kernel void ElementMul(__global float *input_a, __global float *input_b, __global float *output, + const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] * input_b[idx]; +} + +__kernel void ElementDiv(__global float *input_a, __global float *input_b, __global float *output, + const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] * input_b[idx]; +} + +__kernel void BoardcastAdd(__global float *input_a, float input_b, __global float *output, const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] + input_b; +} + +__kernel void BoardcastSub(__global float *input_a, float input_b, __global float *output, const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] - input_b; +} + +__kernel void BoardcastMul(__global float *input_a, float input_b, __global float *output, const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] * input_b; +} + +__kernel void BoardcastDiv(__global float *input_a, float input_b, __global float *output, const unsigned int n) { + int idx = get_global_id(0); + if (idx >= n) return; + output[idx] = input_a[idx] * input_b; +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl new file mode 100644 index 0000000000..12f75438f9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl @@ -0,0 +1,15 @@ +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; + +__kernel void ElementAdd(__read_only image2d_t *input_a, __read_only image2d_t *input_b, __write_only image2d_t *output, + const int4 output_shape) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) return; + + if (idx >= n) return; + float4 a = read_imagef(input_a, smp_none, (int2)(X, Y * output_shape.w + Z)); + float4 b = read_imagef(input_b, smp_none, (int2)(X, Y * output_shape.w + Z)); + src = a + b; + write_imagef(output, (int2)(0, 0), src); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl new file mode 100644 index 0000000000..0e60a4ca1e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl @@ -0,0 +1,65 @@ +__kernel void AvgPooling2d(__global float4 *input, __global float4 *output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 r = (float4)(0.0f); + float window_size = 0.0f; + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + bool outside_x = x_c < 0 || x_c >= input_shape.x; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + bool outside = outside_x || y_c < 0 || y_c >= input_shape.y; + r += !outside ? input[(input_shape.y * x_c + y_c) * output_shape.w + Z] : (float4)(0.0f); + window_size += !outside ? 1.0f : 0.0f; + } + } + float4 result = convert_float4(r / window_size); + output[(output_shape.y * X + Y) * output_shape.w + Z] = result; +} + +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void AvgPooling2dImage2d(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, + const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 r = (float4)(0.0f); + float window_size = 0.0f; + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + bool outside_x = x_c < 0 || x_c >= input_shape.x; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + bool outside = outside_x || y_c < 0 || y_c >= input_shape.y; + + r += read_imagef(input, smp_zero, (int2)(x_c, y_c * input_shape.w + Z)); + window_size += !outside ? 1.0f : 0.0f; + } + } + float4 result = convert_float4(r / window_size); + write_imagef(output, (int2)(X, Y * output_shape.w + Z), result); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl new file mode 100644 index 0000000000..d457a3e4fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl @@ -0,0 +1,54 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void Concat(__global float *input0, __global float *input1, __global float *output, const int4 input_shape0, + const int4 input_shape1, const int4 output_shape, const int axis) { + uint oh = get_global_id(0); + uint ow = get_global_id(1); + uint oc = get_global_id(2); + uint index_output; + uint input_idx; + if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) { + return; + } + if (axis == 3) { + index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc; + if (oc < input_shape0.w) { + input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc; + output[index_output] = input0[input_idx]; + } else if ((input_shape0.w <= oc) && oc < (input_shape0.w + input_shape1.w)) { + input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w); + output[index_output] = input1[input_idx]; + } else { + output[index_output] = 0; + } + } +} + +__kernel void Concat3input(__global float *input0, __global float *input1, __global float *input2, + __global float *output, const int4 input_shape0, const int4 input_shape1, + const int4 input_shape2, const int4 output_shape, const int axis) { + uint oh = get_global_id(0); + uint ow = get_global_id(1); + uint oc = get_global_id(2); + uint index_output; + uint input_idx; + if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) { + return; + } + index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc; + if (oc < (input_shape0.w + input_shape1.w)) { + if (oc < input_shape0.w) { + input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc; + output[index_output] = input0[input_idx]; + } else { + input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w); + output[index_output] = input1[input_idx]; + } + } else { + if ((input_shape0.w + input_shape1.w + input_shape2.w) <= oc) { + output[index_output] = 0; + } else { + input_idx = (input_shape2.z * oh + ow) * input_shape2.w + (oc - input_shape0.w - input_shape1.w); + output[index_output] = input2[input_idx]; + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl new file mode 100644 index 0000000000..c4b9057980 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl @@ -0,0 +1,61 @@ +#define FLT float +#define FLT4 float4 +#define FLT16 float16 +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, + __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, + int4 src_size, int4 dst_size) { + int h = get_global_id(0); + int kh = h % 2; + int src_h = h / 2; + src_h = src_h * 2; + int w = get_global_id(1); + int kw = w % 2; + int src_w = w / 2; + src_w = src_w * 2; + int co = get_global_id(2); + if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; + FLT4 r0 = (FLT4)(0.f); + FLT4 r1 = (FLT4)(0.f); + FLT4 r2 = (FLT4)(0.f); + FLT4 r3 = (FLT4)(0.f); + int base_w = (co * 4 + kh + kw * 2) * src_size.z; + for (int ci = 0; ci < src_size.z; ++ci) { + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)); + FLT16 weight_cache = weight[base_w++]; + r0 += x0.x * weight_cache.s0123; + r0 += x0.y * weight_cache.s4567; + r0 += x0.z * weight_cache.s89ab; + r0 += x0.w * weight_cache.scdef; + + r1 += x1.x * weight_cache.s0123; + r1 += x1.y * weight_cache.s4567; + r1 += x1.z * weight_cache.s89ab; + r1 += x1.w * weight_cache.scdef; + + r2 += x2.x * weight_cache.s0123; + r2 += x2.y * weight_cache.s4567; + r2 += x2.z * weight_cache.s89ab; + r2 += x2.w * weight_cache.scdef; + + r3 += x3.x * weight_cache.s0123; + r3 += x3.y * weight_cache.s4567; + r3 += x3.z * weight_cache.s89ab; + r3 += x3.w * weight_cache.scdef; + } + FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(co, 0)); + r0 += bias_val; + r1 += bias_val; + r2 += bias_val; + r3 += bias_val; + + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl new file mode 100644 index 0000000000..43a5a0306b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl @@ -0,0 +1,150 @@ +#define CI_TILE 4 +#define CO_TILE 4 +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +// #define __global +// #pragma OPENCL EXTENSION cl_arm_printf : enable +__kernel void convolution_NHWC_OHWI(__global float *input, __global float *weight, __global float *bias, + __global float *output, + const int4 input_shape, // NHWC + const int4 output_shape, // NHWC + const int4 kernel_stride, // kernelHW_strideHW + const int4 pad) { + int ow = get_global_id(0); + int oh = get_global_id(1); + int co_slice = get_global_id(2); + + int CI = input_shape.w, IH = input_shape.y, IW = input_shape.z; + int CO = output_shape.w, OH = output_shape.y, OW = output_shape.z; + int KH = kernel_stride.x, KW = kernel_stride.y; + int strideH = kernel_stride.z, strideW = kernel_stride.w; + int padTop = pad.x, padLeft = pad.z; + int CI_SLICES = UP_DIV(CI, CI_TILE); + int CO_SLICES = UP_DIV(CO, CO_TILE); + + if (oh >= OH || ow >= OW || co_slice >= CO_SLICES) return; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + for (int kh = 0; kh < KH; ++kh) { + int ih = kh + oh * strideH - padTop; + for (int kw = 0; kw < KW; ++kw) { + int iw = kw + ow * strideW - padLeft; + for (int ci_slice = 0; ci_slice < CI_SLICES; ++ci_slice) { + for (int ci_inner = 0; ci_inner < CI_TILE; ++ci_inner) { + int ci = ci_slice * CI_TILE + ci_inner; + if (ci >= CI) break; + + int input_idx = ih * IW * CI + iw * CI + ci; + float value = 0; + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + value = 0; + else + value = input[input_idx]; + + int CO_OFFSET = KH * KW * CI; + int weight_idx = (co_slice * CO_TILE) * CO_OFFSET + kh * KW * CI + kw * CI + ci; + acc.x += weight[weight_idx + 0 * CO_OFFSET] * value; + acc.y += weight[weight_idx + 1 * CO_OFFSET] * value; + acc.z += weight[weight_idx + 2 * CO_OFFSET] * value; + acc.w += weight[weight_idx + 3 * CO_OFFSET] * value; + } + } + } + } + int output_idx = oh * OW * CO + ow * CO + (co_slice * CO_TILE); + if (co_slice < CO_SLICES - 1 || CO % CO_TILE == 0) { + output[output_idx + 0] = acc.x + bias[co_slice * CO_TILE + 0]; + output[output_idx + 1] = acc.y + bias[co_slice * CO_TILE + 1]; + output[output_idx + 2] = acc.z + bias[co_slice * CO_TILE + 2]; + output[output_idx + 3] = acc.w + bias[co_slice * CO_TILE + 3]; + } else if (CO % CO_TILE == 1) { + output[output_idx + 0] = acc.x + bias[co_slice * CO_TILE + 0]; + } else if (CO % CO_TILE == 2) { + output[output_idx + 0] = acc.x + bias[co_slice * CO_TILE + 0]; + output[output_idx + 1] = acc.y + bias[co_slice * CO_TILE + 1]; + } else if (CO % CO_TILE == 3) { + output[output_idx + 0] = acc.x + bias[co_slice * CO_TILE + 0]; + output[output_idx + 1] = acc.y + bias[co_slice * CO_TILE + 1]; + output[output_idx + 2] = acc.z + bias[co_slice * CO_TILE + 2]; + } +} + +// #pragma OPENCL EXTENSION cl_khr_fp16 : enable +// #define FLT4 half4 +#define FLT4 float4 +__kernel void convolution_NHWC4_OHWIIO_float8(__global FLT4 *input, __global FLT4 *weight, __global FLT4 *bias, + __global FLT4 *output, + const int4 input_shape, // NHWC + const int4 output_shape, // NHWC + const int4 kernel_stride, // kernelHW_strideHW + const int4 pad) { + int oh = get_global_id(0); // [0, OH) + int ow = get_global_id(1); // [0, OW) + int co_slice = get_global_id(2); // [0, UP_DIV(CO, CO_TILE) ) + + int CI = input_shape.w, IH = input_shape.y, IW = input_shape.z; + int CO = output_shape.w, OH = output_shape.y, OW = output_shape.z; + int CI_SLICES = UP_DIV(CI, CI_TILE); + int CO_SLICES = UP_DIV(CO, CO_TILE); + int KH = kernel_stride.x, KW = kernel_stride.y; + int strideH = kernel_stride.z, strideW = kernel_stride.w; + int padTop = pad.x, padLeft = pad.z; + + if (oh >= OH || ow >= OW || 2 * co_slice >= CO_SLICES) return; + if (2 * co_slice + 1 >= CO_SLICES) { + FLT4 out0_c4 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + __global FLT4 *w0_ic1_oc4 = weight + (2 * co_slice + 0) * KH * KW * CI_SLICES * CI_TILE; + for (int kh = 0; kh < KH; ++kh) { + int ih = kh + oh * strideH - padTop; + for (int kw = 0; kw < KW; ++kw) { + int iw = kw + ow * strideW - padLeft; + if (ih >= 0 && ih < IH && iw >= 0 && iw < IW) { + for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) { + FLT4 in_c4 = input[ih * IW * CI_SLICES + iw * CI_SLICES + ci_slice]; + out0_c4 += w0_ic1_oc4[0] * in_c4.x; + out0_c4 += w0_ic1_oc4[1] * in_c4.y; + out0_c4 += w0_ic1_oc4[2] * in_c4.z; + out0_c4 += w0_ic1_oc4[3] * in_c4.w; + w0_ic1_oc4 += 4; + } + } else { + w0_ic1_oc4 += 4 * CI_SLICES; + } + } + } + output[oh * OW * CO_SLICES + ow * CO_SLICES + 2 * co_slice + 0] = out0_c4 + bias[2 * co_slice + 0]; + } else { + FLT4 out0_c4 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + FLT4 out1_c4 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + __global FLT4 *w0_ic1_oc4 = weight + (2 * co_slice + 0) * KH * KW * CI_SLICES * CI_TILE; + __global FLT4 *w1_ic1_oc4 = weight + (2 * co_slice + 1) * KH * KW * CI_SLICES * CI_TILE; + for (int kh = 0; kh < KH; ++kh) { + int ih = kh + oh * strideH - padTop; + for (int kw = 0; kw < KW; ++kw) { + int iw = kw + ow * strideW - padLeft; + if (ih >= 0 && ih < IH && iw >= 0 && iw < IW) { + int idx = ih * IW * CI_SLICES + iw * CI_SLICES; + for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) { + FLT4 in_c4 = input[idx + ci_slice]; + + out0_c4 += w0_ic1_oc4[0] * in_c4.x; + out0_c4 += w0_ic1_oc4[1] * in_c4.y; + out0_c4 += w0_ic1_oc4[2] * in_c4.z; + out0_c4 += w0_ic1_oc4[3] * in_c4.w; + w0_ic1_oc4 += 4; + + out1_c4 += w1_ic1_oc4[0] * in_c4.x; + out1_c4 += w1_ic1_oc4[1] * in_c4.y; + out1_c4 += w1_ic1_oc4[2] * in_c4.z; + out1_c4 += w1_ic1_oc4[3] * in_c4.w; + w1_ic1_oc4 += 4; + } + } else { + w0_ic1_oc4 += 4 * CI_SLICES; + w1_ic1_oc4 += 4 * CI_SLICES; + } + } + } + output[oh * OW * CO_SLICES + ow * CO_SLICES + 2 * co_slice + 0] = out0_c4 + bias[2 * co_slice + 0]; + output[oh * OW * CO_SLICES + ow * CO_SLICES + 2 * co_slice + 1] = out1_c4 + bias[2 * co_slice + 1]; + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl new file mode 100644 index 0000000000..f7944e9a2b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl @@ -0,0 +1,198 @@ +#ifdef ENABLE_FP16 +#define FLT half +#define FLT4 half4 +#define TO_FLT4 convert_half4 +#else +#define FLT float +#define FLT4 float4 +#define TO_FLT4 convert_float4 +#endif +__constant sampler_t sampler_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void DepthwiseConv2d_IMG_NC4HW4(__read_only image2d_t src_data, __global FLT4 *filter, __global FLT4 *bias, + float relu_clip1, __write_only image2d_t dst_data, int2 kernel_size, + int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + // FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))]; + FLT4 src_final = read_imagef(src_data, sampler_zero, (int2)(x_c, (Z * src_size.y + y_c))); + r += TO_FLT4(src_final * f); + } + fx_c++; + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + // dst_data[(((Z) * dst_size.y + (Y)) * dst_size.x + (X))] = res0; + write_imagef(dst_data, (int2)(X, (Z * dst_size.y + Y)), res0); +} + +__kernel void DepthwiseConv2d_IMG_NHWC4(__read_only image2d_t src_data, __global FLT4 *filter, __global FLT4 *bias, + float relu_clip1, __write_only image2d_t dst_data, int2 kernel_size, + int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + // FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + FLT4 src_final = read_imagef(src_data, sampler_zero, (int2)(Z + x_c * src_size.z, y_c)); + r += TO_FLT4(src_final * f); + } + fx_c++; + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + // dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; + write_imagef(dst_data, (int2)(X * dst_size.z + Z, Y), res0); +} + +__kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__read_only image2d_t src_data, __global FLT4 *filter, __global FLT4 *bias, + float relu_clip1, __write_only image2d_t dst_data, int2 kernel_size, + int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z; + { + int y_c = y_offseted; + bool outside_y = y_c < 0 || y_c >= src_size.y; + { + int x_c = x_offseted; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + // FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + FLT4 src_final = read_imagef(src_data, sampler_zero, (int2)(Z, (y_c * src_size.x + x_c) * src_size.z)); + r += TO_FLT4(src_final * f); + } + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + // dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; + write_imagef(dst_data, (int2)(Z, (Y * dst_size.x + X) * dst_size.z), res0); +} +__kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *src_data, __global FLT4 *filter, __global FLT4 *bias, + float relu_clip1, __global FLT4 *dst_data, int2 kernel_size, int2 stride, + int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + FLT4 src_final = src_data[(((Z)*src_size.y + (y_c)) * src_size.x + (x_c))]; + r += TO_FLT4(src_final * f); + } + fx_c++; + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[(((Z)*dst_size.y + (Y)) * dst_size.x + (X))] = res0; +} + +__kernel void DepthwiseConv2d_BUF_NHWC4(__global FLT4 *src_data, __global FLT4 *filter, __global FLT4 *bias, + float relu_clip1, __global FLT4 *dst_data, int2 kernel_size, int2 stride, + int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + FLT4 src_final = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_FLT4(src_final * f); + } + fx_c++; + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; +} + +__kernel void DepthwiseConv2d_BUF_NHWC4_1x1(__global FLT4 *src_data, __global FLT4 *filter, __global FLT4 *bias, + float relu_clip1, __global FLT4 *dst_data, int2 kernel_size, int2 stride, + int2 padding, int2 dilation, int4 src_size, int4 dst_size) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z; + { + int y_c = y_offseted; + bool outside_y = y_c < 0 || y_c >= src_size.y; + { + int x_c = x_offseted; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filter[fx_c]; + FLT4 src_final = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_FLT4(src_final * f); + } + } + } + FLT4 bias_val = bias[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl new file mode 100644 index 0000000000..1dcc884e0e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl @@ -0,0 +1,32 @@ +#define FLT4 float4 +#define FLT16 float16 +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias, + __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) { + int2 gid = (int2)(get_global_id(0), get_global_id(1)); + int2 lid = (int2)(get_local_id(0), get_local_id(1)); + FLT4 result = (FLT4)(0.0f); + bool inside = gid.x < offset_co.y; + for (uint i = lid.y; i < offset_ci.y && inside; i += 4) { + FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, 0)); + FLT16 w = weight[gid.x + i * offset_co.y]; + result.x += dot(v, w.s0123); + result.y += dot(v, w.s4567); + result.z += dot(v, w.s89ab); + result.w += dot(v, w.scdef); + } + __local FLT4 temp[64][4]; + temp[lid.x][lid.y] = result; + barrier(CLK_LOCAL_MEM_FENCE); + if (lid.y == 0 && inside) { + result += temp[lid.x][1]; + result += temp[lid.x][2]; + result += temp[lid.x][3]; + if (has_bias != 0) { + result += READ_IMAGE(bias, smp_zero, (int2)(gid.x, 0)); + } + WRITE_IMAGE(output, (int2)(gid.x, 0), result); + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl new file mode 100644 index 0000000000..f65e3e06d6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl @@ -0,0 +1,67 @@ +__kernel void MaxPooling2d_BUF(__global float4 *input, __global float4 *output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 maximum = (float4)(-10000.0f); + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + if (x_c < 0 || x_c >= input_shape.x) { + continue; + } + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + if (y_c < 0 || y_c >= input_shape.y) { + continue; + } + float4 src = input[(input_shape.y * x_c + y_c) * input_shape.w + Z]; + maximum = max(src, maximum); + } + } + output[(output_shape.y * X + Y) * output_shape.w + Z] = maximum; +} + +__constant sampler_t sample_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; + +__kernel void MaxPooling2d_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 maximum = (float4)(-10000.0f); + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + if (x_c < 0 || x_c >= input_shape.x) { + continue; + } + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + if (y_c < 0 || y_c >= input_shape.y) { + continue; + } + float4 src = read_imagef(input, sample_none, (int2)(x_c, y_c * input_shape.w + Z)); + maximum = max(src, maximum); + } + } + write_imagef(output, (int2)(X, Y * output_shape.w + Z), maximum); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl new file mode 100644 index 0000000000..f1a5c69d94 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl @@ -0,0 +1,32 @@ +#define SLICES 4 + +int DivideRoundUp(int n, int div) { + int q = n / div; + return n % div == 0 ? q : q + 1; +} + +__kernel void SoftMax(__global float4 *input, __global float4 *output, const int4 input_shape) { + int X = get_global_id(0); // width + int Y = get_global_id(1); // height + int H = input_shape.y; + int W = input_shape.z; + int C = input_shape.w; + + if (X >= W || Y >= H) return; + + float sum = 0.0f; + for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { + float4 t = input[(Y * W + X * H) * C + d]; + sum += exp(t.x); + if (d * 4 + 1 < C) sum += exp(t.y); + if (d * 4 + 2 < C) sum += exp(t.z); + if (d * 4 + 3 < C) sum += exp(t.w); + } + + for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { + float4 t = input[(Y * W + X * H) * C + d]; + t = exp(t) / sum; + float4 result = convert_float4(t); + output[(Y * W + X * H) * C + d] = result; + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl new file mode 100644 index 0000000000..08069ee80f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl @@ -0,0 +1,44 @@ +#define FLT float +#define FLT4 float4 +#define READ_IMAGE read_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void transpose(__read_only image2d_t src_data, __global float4 *dst_data, int2 HW, int2 C) { + int X = get_global_id(0); + int Y = get_global_id(1); + if (X >= HW.y || Y >= C.y) { + return; + } + FLT4 result[4]; + result[0] = (FLT4)(0.0f); + result[1] = (FLT4)(0.0f); + result[2] = (FLT4)(0.0f); + result[3] = (FLT4)(0.0f); + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3)); + result[0].x = x0.x; + result[0].y = x1.x; + result[0].z = x2.x; + result[0].w = x3.x; + + result[1].x = x0.y; + result[1].y = x1.y; + result[1].z = x2.y; + result[1].w = x3.y; + + result[2].x = x0.z; + result[2].y = x1.z; + result[2].z = x2.z; + result[2].w = x3.z; + + result[3].x = x0.w; + result[3].y = x1.w; + result[3].z = x2.w; + result[3].w = x3.w; + + dst_data[4 * Y * HW.y + X] = result[0]; + dst_data[(4 * Y + 1) * HW.y + X] = result[1]; + dst_data[(4 * Y + 2) * HW.y + X] = result[2]; + dst_data[(4 * Y + 3) * HW.y + X] = result[3]; +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/image_format.h b/mindspore/lite/src/runtime/kernel/opencl/image_format.h new file mode 100644 index 0000000000..4987afe49d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/image_format.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_IMAGE_FORMAT_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_IMAGE_FORMAT_H_ + +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore { +namespace kernel { + +/** + * MindSpore to OpenCL channel order. + * @param num_channels + * @return opencl_channels + */ +cl_channel_order ToChannelOrder(int num_channels) { + switch (num_channels) { + case 1: + return CL_R; + case 2: + return CL_RG; + case 3: + return CL_RGB; + case 4: + return CL_RGBA; + default: + return -1; + } +} + +/** + * MindSpore image channel type to OpenCL channel data type. + * @param data_type + * @return opencl_data_type + */ +cl_channel_type ToImageChannelType(TypeId data_type) { + switch (data_type) { + case kNumberTypeFloat32: + return CL_FLOAT; + case kNumberTypeFloat16: + return CL_HALF_FLOAT; + default: + return -1; + } +} +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_IMAGE_FORMAT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc new file mode 100644 index 0000000000..8cd3113797 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/arithmetic.h" +#include +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/opencl/utils.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; + +namespace mindspore::kernel { + +std::vector ArithmeticOpenCLKernel::InitGlobalSize() const { + const size_t global_x = outputs_[0]->Width(); + const size_t global_y = outputs_[0]->Height(); + const size_t global_z = UP_ROUND_DIV(outputs_[0]->Channel(), 4); + std::vector global = {global_x, global_y, global_z}; + return global; +} + +void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() { + global_size_ = InitGlobalSize(); + int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())()); + local_size_ = GetCommonLocalSize(global_size_, max_work_group_size); + global_size_ = GetCommonGlobalSize(local_size_, global_size_); +} + +void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() { + uint32_t element_num = outputs_[0]->ElementsC4Num(); + global_size_ = {element_num}; +} + +int ArithmeticOpenCLKernel::Init() { + runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); + std::string element_name; + std::string boardcast_name; + + if (inputs_[1]->TensorType() == schema::NodeType_ValueNode && inputs_[1]->Data() != nullptr) { + element_flag_ = false; + } else { + element_flag_ = true; + } + + switch (opParameter->type_) { + case PrimitiveType_Mul: + element_name = "ElementMul"; + boardcast_name = "BoardcastMul"; + break; + case PrimitiveType_Add: + element_name = "ElementAdd"; + boardcast_name = "BoardcastAdd"; + break; + case PrimitiveType_Sub: + element_name = "ElementSub"; + boardcast_name = "BoardcastSub"; + break; + case PrimitiveType_Div: + element_name = "ElementDiv"; + boardcast_name = "BoardcastDiv"; + break; + default: + MS_LOG(ERROR) << "Error Operator type " << opParameter->type_; + break; + } + +#ifdef PROGRAM_WITH_IL + runtime_->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::string program_name = "Arithmetic"; + std::set build_options; + std::string source = arithmetic_buffer_source_fp32; + runtime_->LoadSource(program_name, source); + + if (element_flag_) { + runtime_->BuildKernel(kernel_, program_name, element_name, build_options); + MS_LOG(DEBUG) << element_name << " Init Done!"; + } else { + runtime_->BuildKernel(kernel_, program_name, boardcast_name, build_options); + MS_LOG(DEBUG) << boardcast_name << " Init Done!"; + } +#endif + outputs_[0]->SetFormat(schema::Format_NHWC4); + return 0; +} + +int ArithmeticOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + auto runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); + BufferGetWorkGroupSize(); + + int arg_idx = 0; + uint32_t element_num = outputs_[0]->ElementsC4Num(); + + runtime_->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data()); + if (element_flag_) { + runtime_->SetKernelArg(kernel_, arg_idx++, inputs_[1]->Data()); + } else { + runtime_->SetKernelArg(kernel_, arg_idx++, static_cast(inputs_[1]->Data())[0]); + } + runtime_->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data()); + runtime_->SetKernelArg(kernel_, arg_idx++, element_num); + + runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr); + return 0; +} + +kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new ArithmeticOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!"; + return nullptr; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: Arithmetic"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Mul, OpenCLArithmeticKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Add, OpenCLArithmeticKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Sub, OpenCLArithmeticKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Div, OpenCLArithmeticKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h new file mode 100644 index 0000000000..37143775b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_ARITHMETIC_H_ + +#include +#include "src/runtime/kernel/arm/fp32/arithmetic.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" + +namespace mindspore::kernel { + +class ArithmeticOpenCLKernel : public ArithmeticCPUKernel { + public: + explicit ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {} + ~ArithmeticOpenCLKernel() override{}; + + int Init() override; + int Run() override; + + private: + std::vector InitGlobalSize() const; + void Image2dGetWorkGroupSize(); + void BufferGetWorkGroupSize(); + + cl::Kernel kernel_; + lite::opencl::OpenCLRuntime *runtime_; + bool element_flag_{true}; + + std::vector local_size_; + std::vector global_size_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_ARITHMETIC_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc new file mode 100644 index 0000000000..4823f1eca0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc @@ -0,0 +1,217 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/kernel/concat.h" +#include "src/runtime/kernel/opencl/cl/fp32/concat.cl.inc" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Concat; + +namespace mindspore::kernel { + +int ConcatOpenCLKernel::Init() { + if (inputs_[0]->shape().size() != 4) { + MS_LOG(ERROR) << "only support dim=4"; + } + + auto param = reinterpret_cast(this->opParameter); + MS_LOG(INFO) << "concat at axis=: " << param->axis_; + if (param->axis_ != 0 && param->axis_ != 3) { + MS_LOG(ERROR) << "only support axis=0 or axis=3"; + } + + if (param->axis_ == 0) { + return 0; + } + if (inputs_.size() == 2) { + std::set build_options; + std::string source = concat_source_fp32; + std::string program_name = "Concat"; + std::string kernel_name = "Concat"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); + } + + if (inputs_.size() == 3) { + std::set build_options; + std::string source = concat_source_fp32; + std::string program_name = "Concat3input"; + std::string kernel_name = "Concat3input"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); + } + + return 0; +} + +int ConcatOpenCLKernel::ReSize() { return 0; } + +int ConcatOpenCLKernel::Run_axis0() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator_ = ocl_runtime->GetAllocator(); + cl::CommandQueue *command_queue = ocl_runtime->GetDefaultCommandQueue(); + + for (auto &tensor : inputs_) { + auto buffer = static_cast(allocator_->GetDeviceBuffer(tensor->Data())); + ocl_runtime->MapBuffer(*buffer, CL_MAP_READ, tensor->Size(), command_queue, true); + } + for (auto &tensor : outputs_) { + auto buffer = static_cast(allocator_->GetDeviceBuffer(tensor->Data())); + ocl_runtime->MapBuffer(*buffer, CL_MAP_WRITE, tensor->Size(), command_queue, true); + } + + memcpy(outputs_[0]->Data(), inputs_[0]->Data(), inputs_[0]->Size()); + memcpy(reinterpret_cast(outputs_[0]->Data()) + inputs_[0]->Size(), inputs_[1]->Data(), inputs_[1]->Size()); + + for (auto tensors : {&inputs_, &outputs_}) { + for (auto &tensor : *tensors) { + auto buffer = static_cast(allocator_->GetDeviceBuffer(tensor->Data())); + ocl_runtime->UnmapBuffer(*buffer, tensor->Data()); + } + } + return 0; +} +int DivideRoundUp(int n, int div) { + int q = n / div; + return n % div == 0 ? q : q + 1; +} + +int GetBiggestDividerWithPriority(int number, int max_divider) { + if (number % 8 == 0 && 8 <= max_divider) { + return number / 8; + } + if (number % 4 == 0 && 4 <= max_divider) { + return number / 4; + } + if (number % 2 == 0 && 2 <= max_divider) { + return number / 2; + } + for (int i = max_divider; i != 0; i--) { + if (number % i == 0) { + return i; + } + } + return 1; +} + +void ConcatGetWorkGroup(const std::vector &global, std::vector *local, int max_size) { + int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4); + int yz = max_size / x; + int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8); + int z = std::min(yz / y, DivideRoundUp(global[2], 2)); + + local->clear(); + local->push_back(x); + local->push_back(y); + local->push_back(z); +} +int ConcatOpenCLKernel::Run() { + auto param = reinterpret_cast(this->opParameter); + if (param->axis_ == 0) { + return Run_axis0(); + } + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::vector local; + std::vector global; + if (inputs_.size() == 2) { + auto input0_shape = inputs_[0]->shape(); + auto input1_shape = inputs_[1]->shape(); + auto output_shape = outputs_[0]->shape(); + + cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; + cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; + cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; + + uint32_t OH = output_shape[0] * output_shape[1]; // N*H + uint32_t OW = output_shape[2]; + uint32_t OC = output_shape[3]; + global = {OH, OW, OC}; // HWC + ConcatGetWorkGroup(global, &local, 384); + std::cout << "local size=:" << std::endl; + for (int i = 0; i < local.size(); i++) { + std::cout << local[i] << " "; + } + std::cout << std::endl; + int arg_cn = 0; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); + } + if (inputs_.size() == 3) { + auto input0_shape = inputs_[0]->shape(); + auto input1_shape = inputs_[1]->shape(); + auto input2_shape = inputs_[2]->shape(); + auto output_shape = outputs_[0]->shape(); + + cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; + cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; + cl_int4 input2_shape_ = {input2_shape[0], input2_shape[1], input2_shape[2], input2_shape[3]}; + cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; + + uint32_t OH = output_shape[0] * output_shape[1]; // N*H + uint32_t OW = output_shape[2]; + uint32_t OC = output_shape[3]; + global = {OH, OW, OC}; // HWC + ConcatGetWorkGroup(global, &local, 384); + std::cout << "local size=:" << std::endl; + for (int i = 0; i < local.size(); i++) { + std::cout << local[i] << " "; + } + std::cout << std::endl; + int arg_cn = 0; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[2]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input2_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); + } + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + + return 0; +} + +kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Concat, OpenCLConcatKernelCreator); +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h new file mode 100644 index 0000000000..1f2c115f87 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ + +#include +#include "ir/anf.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/base/concat_base.h" + +namespace mindspore::kernel { + +class ConcatOpenCLKernel : public LiteKernel { + public: + explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + + ~ConcatOpenCLKernel() override{}; + + int Init() override; + + int ReSize() override; + + int Run_axis0(); + + int Run() override; + + private: + cl::Kernel kernel_; +}; + +} // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc new file mode 100644 index 0000000000..6aefa36a69 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { + +int Conv2dTransposeOpenCLKernel::Init() { + ConvParameter *param = reinterpret_cast(opParameter); + if (param->kernel_h_ != 2 || param->kernel_w_ != 2 || param->stride_h_ != 2 || param->stride_w_ != 2) { + MS_LOG(ERROR) << "only support kh=kw=2 and stride_h=stride_w=2."; + return 1; + } + if (param->pad_h_ >= 2 || param->pad_w_ >= 2) { + MS_LOG(ERROR) << "only support pad in {0,1}."; + return 1; + } + std::string kernel_name = "conv2d_transpose2x2"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else +#ifdef ENABLE_FP16 + std::string source = conv2d_transpose2x2_source_fp16; +#else + std::string source = conv2d_transpose2x2_source_fp32; +#endif + std::set build_options; + std::string program_name = "conv2d_transpose2x2"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + int ci = param->input_channel_; + int co = param->output_channel_; + int kh = param->kernel_h_; + int kw = param->kernel_w_; + int div_ci = UP_DIV(ci, 4); + int div_co = UP_DIV(co, 4); + auto allocator = ocl_runtime->GetAllocator(); + padWeight_ = reinterpret_cast(allocator->Malloc(div_ci * div_co * 16 * kh * kw * sizeof(FLOAT_T))); + padWeight_ = reinterpret_cast(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true)); + PadWeight(); + allocator->UnmapBuffer(padWeight_); + outputs_[0]->SetFormat(schema::Format_NHWC4); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return 0; +} + +int Conv2dTransposeOpenCLKernel::ReSize() { return 0; } + +void Conv2dTransposeOpenCLKernel::PadWeight() { + // OHWI to OHWI4(I)4(O) + ConvParameter *param = reinterpret_cast(opParameter); + int ci = param->input_channel_; + int co = param->output_channel_; + int kh = param->kernel_h_; + int kw = param->kernel_w_; + int div_ci = UP_DIV(ci, 4); + int div_co = UP_DIV(co, 4); + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + int index = 0; + for (int co_i = 0; co_i < div_co; co_i++) { + for (int kw_i = 0; kw_i < kw; kw_i++) { + for (int kh_i = 0; kh_i < kh; kh_i++) { + for (int ci_i = 0; ci_i < div_ci; ci_i++) { + for (int ci4_i = 0; ci4_i < 4; ci4_i++) { + for (int co4_i = 0; co4_i < 4; co4_i++) { + int co_offset = co_i * 4 + co4_i; + int ci_offset = ci_i * 4 + ci4_i; + if (co_offset < co && ci_offset < ci) { + int ori_index = ((co_offset * kh + kh_i) * kw + kw_i) * ci + ci_offset; + padWeight_[index++] = origin_weight[ori_index]; + } else { + padWeight_[index++] = 0.; + } + } + } + } + } + } + } +} + +int Conv2dTransposeOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + std::vector shapex = inputs_[0]->shape(); + int n = shapex[0]; + if (n > 1) { + MS_LOG(ERROR) << "Conv2dTranspose n > 1 not supported!"; + return 1; + } + ConvParameter *param = reinterpret_cast(opParameter); + int ci = param->input_channel_; + int co = param->output_channel_; + int kh = param->kernel_h_; + int kw = param->kernel_w_; + int pad = param->pad_h_; + int oh = outputs_[0]->shape()[1]; + int ow = outputs_[0]->shape()[2]; + int h = inputs_[0]->shape()[1]; + int w = inputs_[0]->shape()[2]; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + cl::ImageFormat image_format; + { + image_format.image_channel_order = CL_RGBA; +#ifdef ENABLE_FP16 + image_format.image_channel_data_type = CL_HALF_FLOAT; +#else + image_format.image_channel_data_type = CL_FLOAT; +#endif + } + cl_int in_error_code, in_error_code_weight, in_error_code_bias, out_error_code; + cl::Image2D img_x(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, w * ci / 4, h, 0, + inputs_[0]->Data(), &in_error_code); + cl::Image2D img_bias(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, co / 4, 1, 0, + inputs_[2]->Data(), &in_error_code_bias); + cl::Image2D out_mem(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, ow * co / 4, oh, 0, nullptr, + &out_error_code); + // local size should less than MAX_GROUP_SIZE + std::vector local = {16, 1, 16}; + std::vector global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]), + UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND((size_t)co / 4, local[2])}; + + cl_int2 kernel_size = {kh, kw}; + cl_int2 stride = {2, 2}; + cl_int2 padding = {pad, pad}; + cl_int4 src_size = {h, w, UP_DIV(ci, 4), 1}; + cl_int4 dst_size = {oh, ow, UP_DIV(co, 4), 1}; + ocl_runtime->SetKernelArg(kernel_, 0, img_x); + ocl_runtime->SetKernelArg(kernel_, 1, padWeight_); + ocl_runtime->SetKernelArg(kernel_, 2, img_bias); + ocl_runtime->SetKernelArg(kernel_, 3, out_mem); + ocl_runtime->SetKernelArg(kernel_, 4, kernel_size); + ocl_runtime->SetKernelArg(kernel_, 5, stride); + ocl_runtime->SetKernelArg(kernel_, 6, padding); + ocl_runtime->SetKernelArg(kernel_, 7, src_size); + ocl_runtime->SetKernelArg(kernel_, 8, dst_size); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + auto origin = cl::array{0, 0, 0}; + auto region = cl::array{(size_t)(ow * co / 4), (size_t)(oh), 1}; + ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(out_mem, CL_TRUE, origin, region, 0, 0, outputs_[0]->Data()); + return 0; +} + +kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new Conv2dTransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str() + // << ", type: " << lite::EnumNameOpT(opDef.attr_type()); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, OpenCLConv2dTransposeKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h new file mode 100644 index 0000000000..b3299d2a53 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +#ifdef ENABLE_FP16 +using FLOAT_T = float16_t; +#else +using FLOAT_T = float; +#endif + +namespace mindspore::kernel { + +class Conv2dTransposeOpenCLKernel : public LiteKernel { + public: + explicit Conv2dTransposeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~Conv2dTransposeOpenCLKernel() override {}; + + int Init() override; + int ReSize() override; + int Run() override; + void PadWeight(); + + private: + ConvParameter *parameter_; + cl::Kernel kernel_; + FLOAT_T *padWeight_; + FLOAT_T *bias_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc new file mode 100644 index 0000000000..c2e5e9a447 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc @@ -0,0 +1,238 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/runtime/kernel/opencl/kernel/convolution.h" +#include "src/runtime/kernel/opencl/cl/fp32/convolution.cl.inc" +#include "src/kernel_registry.h" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { + +int ConvolutionOpenCLKernel::Init() { + MS_LOG(INFO) << "ConvolutionOpenCLKernel::Init()"; + + if (inputs_[0]->Batch() != 1 || outputs_[0]->Batch() != 1) { + MS_LOG(ERROR) << "ConvolutionOpenCLKernel only support Batch=1!"; + } + + auto io_NHWC = inputs_[0]->GetFormat() == schema::Format_NHWC && outputs_[0]->GetFormat() == schema::Format_NHWC; + auto io_NHWC4 = inputs_[0]->GetFormat() == schema::Format_NHWC4 && outputs_[0]->GetFormat() == schema::Format_NHWC4; + if (!io_NHWC && !io_NHWC4) { + MS_LOG(ERROR) << "input and output data_format is invalid!"; + } + io_dataformat_ = inputs_[0]->GetFormat(); + + if (inputs_[1]->GetFormat() != schema::Format_KHWC) { + MS_LOG(ERROR) << "weight data_format is invalid!"; + } + + std::set build_options; + std::string source = convolution_source_fp32; + std::string program_name = "convolution"; + std::string kernel_name = io_NHWC4 ? "convolution_NHWC4_OHWIIO_float8" : "convolution_NHWC_OHWI"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); + this->InitBuffer(); + + return 0; +} +int ConvolutionOpenCLKernel::InitBuffer() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator = ocl_runtime->GetAllocator(); + + auto weight_tensor = inputs_[1]; + auto bias_tensor = inputs_[2]; + if (io_dataformat_ == schema::Format_NHWC) { + packed_weight_ = reinterpret_cast(allocator->Malloc(weight_tensor->Size())); + packed_weight_ = reinterpret_cast(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true)); + memcpy(packed_weight_, weight_tensor->Data(), weight_tensor->Size()); + allocator->UnmapBuffer(packed_weight_); + + packed_bias_ = reinterpret_cast(allocator->Malloc(bias_tensor->Size())); + packed_bias_ = reinterpret_cast(allocator->MapBuffer(packed_bias_, CL_MAP_WRITE, nullptr, true)); + memcpy(packed_bias_, bias_tensor->Data(), bias_tensor->Size()); + allocator->UnmapBuffer(packed_bias_); + } else if (io_dataformat_ == schema::Format_NHWC4) { + // OHWI -> OHWIIO + auto weight_shape = weight_tensor->shape(); + size_t CO = weight_shape[0]; + size_t KH = weight_shape[1]; + size_t KW = weight_shape[2]; + size_t CI = weight_shape[3]; + size_t CI_SLICES = UP_DIV(CI, C4NUM); + size_t CO_SLICES = UP_DIV(CO, C4NUM); + constexpr size_t CI_TILE = C4NUM; + constexpr size_t CO_TILE = C4NUM; + size_t packed_weight_size = CO_SLICES * KH * KW * CI_SLICES * CI_TILE * CO_TILE * sizeof(float); + + packed_weight_ = reinterpret_cast(allocator->Malloc(packed_weight_size)); + packed_weight_ = reinterpret_cast(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true)); + memset(packed_weight_, 0x00, packed_weight_size); + auto weight_data = reinterpret_cast(weight_tensor->Data()); + for (int co = 0; co < CO; ++co) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + for (int ci = 0; ci < CI; ++ci) { + auto co_outer = co / CO_TILE; + auto co_inner = co % CO_TILE; + auto ci_outer = ci / CI_TILE; + auto ci_inner = ci % CI_TILE; + packed_weight_[((((co_outer * KH + kh) * KW + kw) * CI_SLICES + ci_outer) * CI_TILE + ci_inner) * CO_TILE + + co_inner] = *(weight_data++); + } + } + } + } + allocator->UnmapBuffer(packed_weight_); + size_t packed_bias_size = CO_SLICES * CO_TILE * sizeof(float); + packed_bias_ = reinterpret_cast(allocator->Malloc(packed_bias_size)); + packed_bias_ = reinterpret_cast(allocator->MapBuffer(packed_bias_, CL_MAP_WRITE, nullptr, true)); + memset(packed_bias_, 0x00, packed_bias_size); + auto bias_data = reinterpret_cast(bias_tensor->Data()); + for (int co = 0; co < CO; ++co) { + packed_bias_[co] = bias_data[co]; + } + allocator->UnmapBuffer(packed_bias_); + } + + return 0; +} // namespace mindspore::kernel + +int ConvolutionOpenCLKernel::ReSize() { return 0; } + +static int GetBiggestDivider(int x, int y) { + for (int i = y; i != 0; i--) { + if (x % i == 0) { + return i; + } + } + return 1; +} + +static void GetLocalSize(const ConvParameter *param, std::vector *global, std::vector *local) { + constexpr size_t work_group_size[] = {4, 4, 1}; + constexpr size_t max_work_item_sizes[] = {512, 512, 512}; + constexpr size_t max_work_group_size = 512; + const size_t max_z_size = std::min(16, max_work_item_sizes[2]); + + // 先用OH OW CO_SLICES初始化global,并且441对齐 + size_t global_h = UP_DIV(param->output_h_, work_group_size[0]) * work_group_size[0]; + size_t global_w = UP_DIV(param->output_w_, work_group_size[1]) * work_group_size[1]; + size_t global_c = UP_DIV(UP_DIV(param->output_channel_, C4NUM), work_group_size[2]) * work_group_size[2]; + + // 使用策略计算local + size_t local_c = GetBiggestDivider(global_c, max_z_size); + size_t local_hw_size = std::min(256, max_work_group_size) / local_c; + size_t local_w = std::min(global_w, local_hw_size); + size_t local_h = std::min(local_hw_size / local_w, global_h); + if (local_h == global_h && global_h % 2 == 0) { + local_h = global_h / 2; + } + + global->clear(); + global->push_back(UP_DIV(param->output_h_, local_h) * local_h); + global->push_back(UP_DIV(param->output_w_, local_w) * local_w); + global->push_back(UP_DIV(UP_DIV(param->output_channel_, C4NUM), local_c) * local_c); + local->clear(); + local->push_back(local_h); + local->push_back(local_w); + local->push_back(local_c); +} + +int ConvolutionOpenCLKernel::Run() { + MS_LOG(INFO) << "ConvolutionOpenCLKernel::Run()"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + auto param = reinterpret_cast(opParameter); + auto input0_shape = inputs_[0]->shape(); // NHWC + auto input1_shape = inputs_[1]->shape(); // OHWI + auto outpu0_shape = outputs_[0]->shape(); // NHWC + cl_int N = input0_shape[0]; + cl_int CI = input0_shape[3]; + cl_int IH = input0_shape[1]; + cl_int IW = input0_shape[2]; + cl_int CO = outpu0_shape[3]; + cl_int OH = outpu0_shape[1]; + cl_int OW = outpu0_shape[2]; + cl_int KH = input1_shape[1]; + cl_int KW = input1_shape[2]; + cl_int CI_ALIGN = UP_DIV(CI, C4NUM) * C4NUM; + cl_int CO_ALIGN = UP_DIV(CO, C4NUM) * C4NUM; + + cl_int4 input_shape; + cl_int4 output_shape; + if (io_dataformat_ == schema::Format_NHWC) { + input_shape = {N, IH, IW, CI}; + output_shape = {N, OH, OW, CO}; + } else if (io_dataformat_ == schema::Format_NHWC4) { + input_shape = {N, IH, IW, CI_ALIGN}; + output_shape = {N, OH, OW, CO_ALIGN}; + } + cl_int4 kernel_stride = {KH, KW, param->stride_h_, param->stride_w_}; + cl_int4 pad = {param->pad_u_, param->pad_d_, param->pad_l_, param->pad_r_}; + + int arg_cn = 0; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, packed_weight_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, packed_bias_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, kernel_stride); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, pad); + + std::vector global; + std::vector local; + GetLocalSize(reinterpret_cast(this->opParameter), &global, &local); + // float8 per thread + if (io_dataformat_ == schema::Format_NHWC4) { + local[2] = UP_DIV(local[2], 2); + global[2] = UP_DIV(global[2], 2); + global[2] = UP_DIV(global[2], global[2]) * global[2]; + } + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + + return 0; +} + +kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new ConvolutionOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create OpenCL Convolution kernel failed!"; + return nullptr; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Conv2D, OpenCLConvolutionKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h new file mode 100644 index 0000000000..f757b1968a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CONVOLUTION_H_ + +#include +#include "src/ir/tensor.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "schema/model_generated.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +namespace mindspore::kernel { + +class ConvolutionOpenCLKernel : public LiteKernel { + public: + explicit ConvolutionOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ConvolutionOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + + private: + schema::Format io_dataformat_ = schema::Format_NHWC4; + float *packed_weight_ = nullptr; + float *packed_bias_ = nullptr; + cl::Kernel kernel_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CONVOLUTION_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc new file mode 100644 index 0000000000..e7a89821c0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc @@ -0,0 +1,206 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" +#include +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" + +#ifndef PROGRAM_WITH_IL + +#include "src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl.inc" + +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { + +int DepthwiseConv2dOpenCLKernel::Init() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::string kernel_name = "DepthwiseConv2d"; + auto in_format = inputs_[0]->GetFormat(); + outputs_[0]->SetFormat(in_format); + if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) { + MS_LOG(ERROR) << "input format(" << in_format << ") " + << "format not support!"; + } + if (mem_type_ == MEM_TYPE::BUF) { + kernel_name += "_BUF"; + } else { + kernel_name += "_IMG"; + } + if (in_format == schema::Format_NC4HW4) { + kernel_name += "_NC4HW4"; + } else if (in_format == schema::Format_NHWC4) { + kernel_name += "_NHWC4"; + } + auto parameter = reinterpret_cast(opParameter); + if (parameter->kernel_h_ == 1) { + kernel_name += "_1x1"; + } +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::string program_name = "DepthwiseConv2d"; + std::set build_options; +#ifdef ENABLE_FP16 + std::string source = depthwise_conv2d_source_fp16; +#else + std::string source = depthwise_conv2d_source_fp32; +#endif + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + this->InitBuffer(); + MS_LOG(DEBUG) << kernel_name << " Init Done! mem type=" << static_cast(mem_type_); + return RET_OK; +} + +int DepthwiseConv2dOpenCLKernel::InitBuffer() { + auto parameter = reinterpret_cast(opParameter); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator = ocl_runtime->GetAllocator(); + + // weight: o, h, w, i; o == group, i == 1 + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + int CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_; + + packed_weight_ = reinterpret_cast(allocator->Malloc(pack_weight_size * sizeof(FLOAT_t))); + packed_weight_ = reinterpret_cast(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true)); + int plane = parameter->kernel_h_ * parameter->kernel_w_; +#ifdef ENABLE_FP16 + PackNCHWToNC4HW4Fp16(origin_weight, packed_weight_, 1, plane, outputs_[0]->Channel()); +#else + PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, plane, outputs_[0]->Channel()); +#endif + + allocator->UnmapBuffer(packed_weight_); + + // init bias + if (inputs_.size() == kInputSize2) { + bias_data_ = reinterpret_cast(allocator->Malloc(C4NUM * CO4 * sizeof(FLOAT_t))); + bias_data_ = reinterpret_cast(allocator->MapBuffer(bias_data_, CL_MAP_WRITE, nullptr, true)); + size_t up_co_size = C4NUM * CO4 * sizeof(FLOAT_t); + memset(bias_data_, 0, up_co_size); + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, outputs_[0]->Channel() * sizeof(FLOAT_t)); + allocator->UnmapBuffer(bias_data_); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int DepthwiseConv2dOpenCLKernel::ReSize() { return RET_OK; } + +int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + size_t im_dst_x, im_dst_y; + if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { + im_dst_x = outputs_[0]->Width() * CO4; + im_dst_y = outputs_[0]->Height(); + } else { + im_dst_y = outputs_[0]->Height() * CO4; + im_dst_x = outputs_[0]->Width(); + } +#ifdef ENABLE_FP16 + size_t img_dtype = CL_HALF_FLOAT; +#else + size_t img_dtype = CL_FLOAT; +#endif + img_size->clear(); + std::vector vec{im_dst_x, im_dst_y, img_dtype}; + *img_size = vec; + return RET_OK; +} + +int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector *global_size) { + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + std::vector global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; + *global_size = std::move(global); + return RET_OK; +} + +int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector &global_size, + std::vector *local_size) { + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + std::vector local = {1, 1, CO4}; + *local_size = std::move(local); + return RET_OK; +} + +int DepthwiseConv2dOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + auto parameter = reinterpret_cast(opParameter); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM); + std::vector global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; + std::vector local; + GetLocalSize(0, global, &local); + + float relu_clip1 = 6.0; + cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_}; + cl_int2 stride = {parameter->stride_h_, parameter->stride_w_}; + cl_int2 padding = {-parameter->pad_h_, -parameter->pad_w_}; + cl_int2 dilation = {parameter->dilation_h_, parameter->dilation_w_}; + cl_int4 src_size = {inputs_[0]->Width(), inputs_[0]->Height(), (cl_int)CI4, inputs_[0]->Batch()}; + cl_int4 dst_size = {(cl_int)outputs_[0]->Width(), (cl_int)outputs_[0]->Height(), (cl_int)CO4, + (cl_int)outputs_[0]->Batch()}; + + ocl_runtime->SetKernelArg(kernel_, 1, packed_weight_); + ocl_runtime->SetKernelArg(kernel_, 2, bias_data_); + ocl_runtime->SetKernelArg(kernel_, 3, relu_clip1); + ocl_runtime->SetKernelArg(kernel_, 5, kernel_size); + ocl_runtime->SetKernelArg(kernel_, 6, stride); + ocl_runtime->SetKernelArg(kernel_, 7, padding); + ocl_runtime->SetKernelArg(kernel_, 8, dilation); + ocl_runtime->SetKernelArg(kernel_, 9, src_size); + ocl_runtime->SetKernelArg(kernel_, 10, dst_size); + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data()); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return RET_OK; +} + +kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init DepthwiseConv2dOpenCLKernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, OpenCLDepthwiseConv2dKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h new file mode 100644 index 0000000000..adad41f6c8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ + +#include +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" + +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { + public: + explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs), packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} + + ~DepthwiseConv2dOpenCLKernel() override{}; + + int Init() override; + + int ReSize() override; + + int Run() override; + + int InitBuffer(); + + int GetImageSize(size_t idx, std::vector *img_size) override; + int GetGlobalSize(size_t idx, std::vector *global_size) override; + int GetLocalSize(size_t idx, const std::vector &global_size, std::vector *local_size) override; + + private: + FLOAT_t *packed_weight_; + FLOAT_t *bias_data_; + cl::Kernel kernel_; + enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc new file mode 100644 index 0000000000..9a73ef48d7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" +#include "src/runtime/kernel/opencl/kernel/matmul.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp16/matmul.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/matmul.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_FullConnection; +using mindspore::schema::PrimitiveType_MatMul; + +namespace mindspore::kernel { + +int MatMulOpenCLKernel::Init() { + std::string kernel_name = "MatMul"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; +// build_options.emplace("-DPOOL_AVG"); +#ifdef ENABLE_FP16 + std::string source = matmul_source_fp16; +#else + std::string source = matmul_source_fp32; +#endif + std::string program_name = "MatMul"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + auto weight_format = inputs_[1]->GetFormat(); + if (weight_format != schema::Format_NHWC) { + MS_LOG(ERROR) << "weight format(" << weight_format << ") " + << "format not support!"; + return 1; + } + int ci = inputs_[1]->shape()[3]; + int co = inputs_[1]->shape()[0]; + sizeCI = {ci, UP_DIV(ci, 4)}; + sizeCO = {co, UP_DIV(co, 4)}; + auto allocator = ocl_runtime->GetAllocator(); + padWeight_ = reinterpret_cast(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * 16 * sizeof(FLOAT_T))); + padWeight_ = reinterpret_cast(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true)); + bias_ = reinterpret_cast(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T))); + bias_ = reinterpret_cast(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true)); + PadWeight(); + allocator->UnmapBuffer(padWeight_); + allocator->UnmapBuffer(bias_); + outputs_[0]->SetFormat(schema::Format_NHWC4); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return 0; +} + +int MatMulOpenCLKernel::ReSize() { return 0; } + +void MatMulOpenCLKernel::PadWeight() { + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + int divCI = sizeCI.s[1]; + int divCO = sizeCO.s[1]; + int index = 0; + for (int i = 0; i < divCI; ++i) { + for (int j = 0; j < divCO; ++j) { + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 4; ++l) { + int src_x = i * 4 + l; + int src_y = j * 4 + k; + if (src_x < sizeCI.s[0] && src_y < sizeCO.s[0]) { + padWeight_[index++] = origin_weight[src_y * sizeCI.s[0] + src_x]; + } else { + padWeight_[index++] = 0; + } + } + } + } + } + if (hasBias_) { + memcpy(bias_, inputs_[2]->Data(), sizeof(FLOAT_T) * sizeCO.s[0]); + for (int i = sizeCO.s[0]; i < sizeCO.s[1] * 4; i++) { + bias_[i] = 0; + } + } else { + for (int i = 0; i < sizeCO.s[1] * 4; i++) { + bias_[i] = 0; + } + } +} + +int MatMulOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + std::vector shapex = inputs_[0]->shape(); + int n = shapex[0]; + if (n > 1) { + MS_LOG(ERROR) << "MatMul n > 1 not supported!"; + return 1; + } + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + // local size should less than MAX_GROUP_SIZE + std::vector local = {64, 4}; + std::vector global = {UP_ROUND(sizeCO.s[1], local[0]), 4}; + + cl::ImageFormat image_format; + { + image_format.image_channel_order = CL_RGBA; +#ifdef ENABLE_FP16 + image_format.image_channel_data_type = CL_HALF_FLOAT; +#else + image_format.image_channel_data_type = CL_FLOAT; +#endif + } + cl_int in_error_code, in_error_code_weight, in_error_code_bias, out_error_code; + cl::Image2D img_input(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCI.s[1], 1, + 0, inputs_[0]->Data(), &in_error_code); + cl::Image2D img_bias(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCO.s[1], 1, + 0, bias_, &in_error_code_bias); + cl::Image2D img_out(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, sizeCO.s[1], 1, 0, nullptr, + &out_error_code); + + ocl_runtime->SetKernelArg(kernel_, 0, img_input); + ocl_runtime->SetKernelArg(kernel_, 1, padWeight_); + ocl_runtime->SetKernelArg(kernel_, 2, img_bias); + ocl_runtime->SetKernelArg(kernel_, 3, img_out); + ocl_runtime->SetKernelArg(kernel_, 4, sizeCI); + ocl_runtime->SetKernelArg(kernel_, 5, sizeCO); + ocl_runtime->SetKernelArg(kernel_, 6, hasBias_ ? 1 : 0); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + auto origin = cl::array{0, 0, 0}; + auto region = cl::array{(size_t)(sizeCO.s[1]), 1, 1}; + ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(img_out, CL_TRUE, origin, region, 0, 0, outputs_[0]->Data()); + return 0; +} + +kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + bool hasBias = false; + if (opParameter->type_ == PrimitiveType_FullConnection) { + hasBias = (reinterpret_cast(opParameter))->has_bias_; + } + auto *kernel = new MatMulOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, hasBias); + auto ret = kernel->Init(); + if (0 != ret) { + // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str() + // << ", type: " << lite::EnumNameOpT(opDef.attr_type()); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_MatMul, OpenCLMatMulKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_FullConnection, OpenCLMatMulKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h new file mode 100644 index 0000000000..beb874cc67 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_MATMUL_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_MATMUL_H_ + +#include + + +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +#ifdef ENABLE_FP16 +using FLOAT_T = float16_t; +#else +using FLOAT_T = float; +#endif + +namespace mindspore::kernel { + +class MatMulOpenCLKernel : public LiteKernel { + public: + explicit MatMulOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, bool hasBias) + : LiteKernel(parameter, inputs, outputs) { + hasBias_ = hasBias; + } + ~MatMulOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + void PadWeight(); + + private: + cl::Kernel kernel_; + FLOAT_T *padWeight_; + FLOAT_T *bias_; + bool hasBias_ = false; + cl_int2 sizeCI; + cl_int2 sizeCO; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc new file mode 100644 index 0000000000..476ff23dbc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/pooling2d.h" +#include +#include +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/opencl/opencl_wrapper.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/image_format.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_INVALID_OP_NAME; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pooling; + +namespace mindspore { +namespace kernel { +int PoolingOpenCLKernel::Init() { + std::string kernel_name; +#ifndef PROGRAM_WITH_IL + std::string source; + std::string program_name; +#endif + if (parameter_->max_pooling_) { + kernel_name = "MaxPooling2d"; +#ifndef PROGRAM_WITH_IL + source = max_pool2d_source_fp32; + program_name = "MaxPooling2d"; +#endif + } else if (parameter_->avg_pooling_) { + kernel_name = "AvgPooling2d"; +#ifndef PROGRAM_WITH_IL + source = avg_pool2d_source_fp32; + program_name = "AvgPooling2d"; +#endif + } else { + MS_LOG(ERROR) << "Init `Pooling2d` kernel failed!"; + return RET_INVALID_OP_NAME; + } + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + if (mem_type_ == MEM_TYPE::BUF) { + kernel_name += "_BUF"; + } else { + kernel_name += "_IMG"; + } + std::set build_options; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + outputs_[0]->SetFormat(schema::Format_NHWC4); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + + return RET_OK; +} + +std::vector PoolingOpenCLKernel::InitGlobalSize() const { + const size_t global_x = outputs_[0]->Height(); + const size_t global_y = outputs_[0]->Width(); + const size_t global_z = UP_ROUND_DIV(outputs_[0]->Channel(), 4); + std::vector global = {global_x, global_y, global_z}; + return global; +} + +int PoolingOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + size_t im_dst_x, im_dst_y; + if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { + im_dst_x = outputs_[0]->Height(); + im_dst_y = outputs_[0]->Width() * CO4; + } else { + im_dst_y = outputs_[0]->Width(); + im_dst_x = outputs_[0]->Height() * CO4; + } +#ifdef ENABLE_FP16 + size_t img_dtype = CL_HALF_FLOAT; +#else + size_t img_dtype = CL_FLOAT; +#endif + img_size->clear(); + std::vector vec{im_dst_x, im_dst_y, img_dtype}; + *img_size = vec; + return RET_OK; +} + +int PoolingOpenCLKernel::InitBuffer() { return RET_OK; } + +int PoolingOpenCLKernel::ReSize() { return RET_OK; } + +int PoolingOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + // attribute + int slices = UP_ROUND_DIV(outputs_[0]->Channel(), 4); + cl_int4 input_shape = {inputs_[0]->Height(), inputs_[0]->Width(), inputs_[0]->Channel(), slices}; + cl_int4 output_shape = {outputs_[0]->Height(), outputs_[0]->Width(), outputs_[0]->Channel(), slices}; + cl_int2 stride = {parameter_->stride_h_, parameter_->stride_w_}; + cl_int2 kernel_size = {parameter_->window_h_, parameter_->window_w_}; + cl_int2 padding = {parameter_->pad_u_, parameter_->pad_l_}; + + // binding parameters + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, output_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, stride); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, kernel_size); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, padding); + + // set work group size + std::vector local_size; + std::vector global_size = InitGlobalSize(); + int max_work_group_size = ocl_runtime->GetKernelMaxWorkGroupSize(kernel_(), (*ocl_runtime->Device())()); + local_size = GetCommonLocalSize(global_size, max_work_group_size); + global_size = GetCommonGlobalSize(local_size, global_size); + + // run opengl kernel + ocl_runtime->RunKernel(kernel_, global_size, local_size, nullptr); + return RET_OK; +} + +kernel::LiteKernel *OpenCLPooling2dKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new PoolingOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create OpenCL Pooling kernel failed!"; + return nullptr; + } + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init OpenCL Pooling kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Pooling, OpenCLPooling2dKernelCreator) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h new file mode 100644 index 0000000000..b08b000308 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ + +#include + +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +class PoolingOpenCLKernel : public OpenCLKernel { + public: + explicit PoolingOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) { + parameter_ = reinterpret_cast(parameter); + } + ~PoolingOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + int GetImageSize(size_t idx, std::vector *img_size) override; + + private: + std::vector InitGlobalSize() const; + enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; + PoolingParameter *parameter_; + cl::Kernel kernel_; +}; + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc new file mode 100644 index 0000000000..4bdf5db2c4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/softmax.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore { +namespace kernel { +int SoftmaxOpenCLKernel::Init() { + std::string kernel_name = "SoftMax"; + if (parameter_->axis_ != -1 && parameter_->axis_ != 3) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; + return -1; + } + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; + std::string source = softmax_source_fp32; + std::string program_name = "SoftMax"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + outputs_[0]->SetFormat(schema::Format_NHWC4); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return 0; +} + +int SoftmaxOpenCLKernel::InitBuffer() { return 0; } +int SoftmaxOpenCLKernel::ReSize() { return 0; } + +int SoftmaxOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator = ocl_runtime->GetAllocator(); + + // global and local workers + const uint32_t grid_x = inputs_[0]->shape()[2]; // W + const uint32_t grid_y = inputs_[0]->shape()[1]; // H + const uint32_t grid_z = 1; + std::vector global = {grid_x, grid_y, grid_z}; + std::vector local = {1, 1, 1}; + + // input and output + cl::Buffer *input = reinterpret_cast(allocator->GetDeviceBuffer(inputs_[0]->Data())); + cl::Buffer *output = reinterpret_cast(allocator->GetDeviceBuffer(outputs_[0]->Data())); + cl_int4 input_size = {inputs_[0]->shape()[0], inputs_[0]->shape()[1], inputs_[0]->shape()[2], inputs_[0]->shape()[3]}; + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size); + + // run opengl kernel + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + + return 0; +} + +kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new SoftmaxOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (inputs[0]->shape()[0] > 1) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported multi-batch."; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator) +} // namespace kernel +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h new file mode 100644 index 0000000000..e4e3ce70ac --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ + +#include + +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore { +namespace kernel { +class SoftmaxOpenCLKernel : public LiteKernel { + public: + explicit SoftmaxOpenCLKernel(OpParameter *parameter, + const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + parameter_ = reinterpret_cast(parameter); + } + ~SoftmaxOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + + private: + SoftmaxParameter *parameter_; + cl::Kernel kernel_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc new file mode 100644 index 0000000000..428c071800 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/kernel/transpose.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp16/transpose.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/transpose.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Transpose; + +namespace mindspore::kernel { + +int TransposeOpenCLKernel::Init() { + std::string kernel_name = "transpose"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; +#ifdef ENABLE_FP16 + std::string source = transpose_source_fp16; +#else + std::string source = transpose_source_fp32; +#endif + std::string program_name = "transpose"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + auto input_format = inputs_[0]->GetFormat(); + if (input_format != schema::Format_NHWC4) { + MS_LOG(ERROR) << "input format(" << input_format << ") " + << "format not support!"; + return RET_ERROR; + } + if ((inputs_[0]->Height() * inputs_[0]->Width()) % 4 != 0) { + MS_LOG(ERROR) << "input H * W % 4 != 0 not support!"; + return RET_ERROR; + } + outputs_[0]->SetFormat(schema::Format_NCHW); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return RET_OK; +} + +int TransposeOpenCLKernel::ReSize() { return 0; } + +int TransposeOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->Name() << " Running!"; + std::vector shapex = inputs_[0]->shape(); + int h = shapex[1]; + int w = shapex[2]; + int c = shapex[3]; + int c4 = UP_DIV(c, 4); + int hw4 = UP_DIV(h * w, 4); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + // local size should less than MAX_GROUP_SIZE + std::vector local = {16, 16}; + std::vector global = {UP_ROUND(hw4, local[0]), UP_ROUND(c4, local[1])}; + + cl_int2 HW = {h * w, hw4}; + cl_int2 C = {c, c4}; + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 1, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 2, HW); + ocl_runtime->SetKernelArg(kernel_, 3, C); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return 0; +} + +kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new TransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Transpose, OpenCLTransposeKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h new file mode 100644 index 0000000000..c16557b5b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" + + +namespace mindspore::kernel { +class TransposeOpenCLKernel : public OpenCLKernel { + public: + explicit TransposeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + ~TransposeOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + cl::Kernel kernel_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h new file mode 100644 index 0000000000..cf76286e35 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_ +#define MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class OpenCLKernel : public LiteKernel { + public: + explicit OpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + + virtual int Init() { return -1; } + virtual int Prepare() { return -1; } + virtual int InferShape() { return -1; } + virtual int ReSize() { return -1; } + virtual int Run() { return -1; } + virtual int GetImageSize(size_t idx, std::vector* img_size) { return -1; } + virtual int GetGlobalSize(size_t idx, std::vector* global_size) { return -1; } + virtual int GetLocalSize(size_t idx, const std::vector& global_size, + std::vector* local_size) { return -1; } +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc new file mode 100644 index 0000000000..dcd2fe8943 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "src/runtime/opencl/opencl_executor.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +SubGraphOpenCLKernel::~SubGraphOpenCLKernel() { UnInit(); } + +int SubGraphOpenCLKernel::Init() { + allocator_ = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); + for (const auto tensor : inputs_) { + tensor->set_allocator(allocator_); + } + for (const auto tensor : outputs_) { + tensor->set_allocator(allocator_); + } + // Map buffer for write, it is not necessary for fine-grained + for (auto &tensor : inputs_) { + void *data = tensor->Data(); + // It is required with coarse-grained SVM + if (data != nullptr) { + data = allocator_->MapBuffer(data, CL_MAP_WRITE, nullptr, true); + tensor->SetData(data); + } else { + MS_LOG(ERROR) << "OpenCL kernel must use GPU buffer pointer, " + << "please make sure that this buffer allocate by OpenCLAllocator!"; + } + } + return 0; +} + +int SubGraphOpenCLKernel::UnInit() { + for (auto &tensor : outputs_) { + allocator_->UnmapBuffer(tensor->Data()); + } + for (const auto tensor : inputs_) { + if (tensor != nullptr) { + tensor->FreeData(); + } + } + for (const auto tensor : outputs_) { + if (tensor != nullptr) { + tensor->FreeData(); + } + } + return 0; +} + +int SubGraphOpenCLKernel::InferShape() { return 0; } + +int SubGraphOpenCLKernel::ReSize() { return 0; } + +int SubGraphOpenCLKernel::Run() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + for (auto &tensor : inputs_) { + allocator_->UnmapBuffer(tensor->Data()); + } + + lite::opencl::OpenCLExecutor executor; + executor.Run(inputs_, outputs_, nodes_, allocator_); + ocl_runtime->SyncCommandQueue(); + for (auto &tensor : outputs_) { + void *data = allocator_->MapBuffer(tensor->Data(), CL_MAP_READ, nullptr, true); + tensor->SetData(data); + } + return 0; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h new file mode 100644 index 0000000000..7f7d5a343e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KENEL_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KENEL_H_ + +#include +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/runtime/opencl/opencl_allocator.h" + +namespace mindspore::kernel { + +struct SubGraphOpenCLParameter { + OpParameter op_parameter; + int input_size; + int output_size; +}; + +class SubGraphOpenCLKernel : public SubGraphKernel { + public: + explicit SubGraphOpenCLKernel(const std::vector inputs, + const std::vector outputs, + const std::vector inKernels, + const std::vector outKernels, + const std::vector nodes) + : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes) {} + ~SubGraphOpenCLKernel() override; + + int Init() override; + int InferShape() override; + int ReSize() override; + int Run() override; + int UnInit(); + + private: + SubGraphOpenCLParameter *subgraph_ocl_parameter_; + lite::opencl::OpenCLAllocator *allocator_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KERNEL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc new file mode 100644 index 0000000000..1ce176ca22 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/utils.h" +#include +#include +#include + +namespace mindspore { +namespace kernel { + +std::vector GetCommonGlobalSize(const std::vector &local, const std::vector &global) { + std::vector result(3, 1); + for (int i = 0; i < 3; ++i) { + result[i] = AlignByN(global[i], local[i]); + } + return result; +} + +std::vector GetCommonLocalSize(const std::vector &global, int max_size) { + size_t wg_z = GetBiggestDividerWithPriority(global[2], 8); + size_t wg_xy_size = max_size / wg_z; + size_t wg_x = std::min(DivideRoundUp(global[0], 2), wg_xy_size); + size_t wg_y = std::min(wg_xy_size / wg_x, global[1]); + std::vector local = {wg_x, wg_y, wg_z}; + return local; +} + +std::string CLErrorCode(cl_int error_code) { + switch (error_code) { + case CL_SUCCESS: + return "Success"; + case CL_DEVICE_NOT_FOUND: + return "Device not found"; + case CL_DEVICE_NOT_AVAILABLE: + return "Device not available"; + case CL_COMPILER_NOT_AVAILABLE: + return "Compiler not available"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: + return "Memory object allocation failure"; + case CL_OUT_OF_RESOURCES: + return "Out of resources"; + case CL_OUT_OF_HOST_MEMORY: + return "Out of host memory"; + case CL_PROFILING_INFO_NOT_AVAILABLE: + return "Profiling information not available"; + case CL_MEM_COPY_OVERLAP: + return "Memory copy overlap"; + case CL_IMAGE_FORMAT_MISMATCH: + return "Image format mismatch"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: + return "Image format not supported"; + case CL_BUILD_PROGRAM_FAILURE: + return "Build program failure"; + case CL_MAP_FAILURE: + return "Mapping failure"; + case CL_MISALIGNED_SUB_BUFFER_OFFSET: + return "Misaligned sub-buffer offset"; + case CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST: + return "Execution status error for events in wait list"; + case CL_COMPILE_PROGRAM_FAILURE: + return "Compile program failure"; + case CL_LINKER_NOT_AVAILABLE: + return "Linker not available"; + case CL_LINK_PROGRAM_FAILURE: + return "Link program failure"; + case CL_DEVICE_PARTITION_FAILED: + return "Device partition failed"; + case CL_KERNEL_ARG_INFO_NOT_AVAILABLE: + return "Kernel argument information not available"; + case CL_INVALID_VALUE: + return "Invalid value"; + case CL_INVALID_DEVICE_TYPE: + return "Invalid device type"; + case CL_INVALID_PLATFORM: + return "Invalid platform"; + case CL_INVALID_DEVICE: + return "Invalid device"; + case CL_INVALID_CONTEXT: + return "Invalid context"; + case CL_INVALID_QUEUE_PROPERTIES: + return "Invalid queue properties"; + case CL_INVALID_COMMAND_QUEUE: + return "Invalid command queue"; + case CL_INVALID_HOST_PTR: + return "Invalid host pointer"; + case CL_INVALID_MEM_OBJECT: + return "Invalid memory object"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: + return "Invalid image format descriptor"; + case CL_INVALID_IMAGE_SIZE: + return "Invalid image size"; + case CL_INVALID_SAMPLER: + return "Invalid sampler"; + case CL_INVALID_BINARY: + return "Invalid binary"; + case CL_INVALID_BUILD_OPTIONS: + return "Invalid build options"; + case CL_INVALID_PROGRAM: + return "Invalid program"; + case CL_INVALID_PROGRAM_EXECUTABLE: + return "Invalid program executable"; + case CL_INVALID_KERNEL_NAME: + return "Invalid kernel name"; + case CL_INVALID_KERNEL_DEFINITION: + return "Invalid kernel definition"; + case CL_INVALID_KERNEL: + return "Invalid kernel"; + case CL_INVALID_ARG_INDEX: + return "Invalid argument index"; + case CL_INVALID_ARG_VALUE: + return "Invalid argument value"; + case CL_INVALID_ARG_SIZE: + return "Invalid argument size"; + case CL_INVALID_KERNEL_ARGS: + return "Invalid kernel arguments"; + case CL_INVALID_WORK_DIMENSION: + return "Invalid work dimension"; + case CL_INVALID_WORK_GROUP_SIZE: + return "Invalid work group size"; + case CL_INVALID_WORK_ITEM_SIZE: + return "Invalid work item size"; + case CL_INVALID_GLOBAL_OFFSET: + return "Invalid global offset"; + case CL_INVALID_EVENT_WAIT_LIST: + return "Invalid event wait list"; + case CL_INVALID_EVENT: + return "Invalid event"; + case CL_INVALID_OPERATION: + return "Invalid operation"; + case CL_INVALID_GL_OBJECT: + return "Invalid GL object"; + case CL_INVALID_BUFFER_SIZE: + return "Invalid buffer size"; + case CL_INVALID_MIP_LEVEL: + return "Invalid mip-level"; + case CL_INVALID_GLOBAL_WORK_SIZE: + return "Invalid global work size"; + case CL_INVALID_PROPERTY: + return "Invalid property"; + case CL_INVALID_IMAGE_DESCRIPTOR: + return "Invalid image descriptor"; + case CL_INVALID_COMPILER_OPTIONS: + return "Invalid compiler options"; + case CL_INVALID_LINKER_OPTIONS: + return "Invalid linker options"; + case CL_INVALID_DEVICE_PARTITION_COUNT: + return "Invalid device partition count"; + case CL_INVALID_PIPE_SIZE: + return "Invalid pipe size"; + case CL_INVALID_DEVICE_QUEUE: + return "Invalid device queue"; + case CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR: + return "Invalid GL share group reference KHR"; + default: + return "Unknown OpenCL error code"; + } +} +} // namespace kernel +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.h b/mindspore/lite/src/runtime/kernel/opencl/utils.h new file mode 100644 index 0000000000..1593a93ee1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ + +#include +#include +#include "CL/cl2.hpp" +#include "utils/log_adapter.h" +#include "src/runtime/kernel/arm/nnacl/op_base.h" + +namespace mindspore::kernel { + +/** + * GetLocalSize + * @param number + * @param max_divider + * @return + */ +template +T GetBiggestDividerWithPriority(T number, N max_divider) { + if (number % 8 == 0 && 8 <= max_divider) { + return (T)8; + } + if (number % 4 == 0 && 4 <= max_divider) { + return (T)4; + } + if (number % 2 == 0 && 2 <= max_divider) { + return (T)2; + } + for (int i = max_divider; i != 0; i--) { + if (number % i == 0) { + return (T)i; + } + } + return (T)1; +} + +/** + * GetLocalSize + * @param n must be non negative + * @param divisor must be greater than zero + * @return + */ +template +T DivideRoundUp(T n, N divisor) { + const T div = static_cast(divisor); + const T q = n / div; + return n % div == 0 ? q : q + 1; +} + +/** + * GetLocalSize + * @param number + * @param n + * @return + */ +template +T AlignByN(T number, N n) { + return DivideRoundUp(number, n) * n; +} + +// GetGlobalSize +std::vector GetCommonGlobalSize(const std::vector &local, const std::vector &global); + +// GetLocalSize +std::vector GetCommonLocalSize(const std::vector &global, int max_size); + +std::string CLErrorCode(cl_int error_code); + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/opencl/CMakeLists.txt new file mode 100644 index 0000000000..5f5e73f867 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/CMakeLists.txt @@ -0,0 +1,11 @@ +set(OPENCL_RUNTIME_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_allocator.h + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_kernel.h + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_runtime.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_runtime.h + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_wrapper.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_wrapper.h + + ) diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc new file mode 100644 index 0000000000..ed579540b4 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc @@ -0,0 +1,375 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/opencl/opencl_allocator.h" +#include +#include "utils/log_adapter.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "include/errorcode.h" + +namespace mindspore::lite::opencl { + +OpenCLAllocator::OpenCLAllocator() {} +OpenCLAllocator::~OpenCLAllocator() {} + +void OpenCLAllocator::SetContext(const AllocatorContext &ctx) { + lock_flag_ = ctx.lockFlag; + shift_factor_ = ctx.shiftFactor; +} + +void OpenCLAllocator::Lock() { + if (lock_flag_) { + lock.lock(); + } +} + +void OpenCLAllocator::UnLock() { + if (lock_flag_) { + lock.unlock(); + } +} + +void *OpenCLAllocator::Malloc(size_t size) { + if (size > MAX_MALLOC_SIZE) { + MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; + return nullptr; + } + Lock(); + auto iter = free_list_.lower_bound(size); + if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { + auto mem_buf = iter->second; + free_list_.erase(iter); + allocated_list_[mem_buf->host_ptr_] = mem_buf; + UnLock(); + MS_LOG(DEBUG) << "Malloc buffer from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + return mem_buf->host_ptr_; + } + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + void *host_ptr = nullptr; + void *device_ptr = nullptr; + if (svm_capabilities && svm_on_) { + cl_svm_mem_flags flags = (svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) ? CL_MEM_SVM_FINE_GRAIN_BUFFER : 0; + flags |= (svm_capabilities & CL_DEVICE_SVM_ATOMICS) ? CL_MEM_SVM_ATOMICS : 0; + flags = flags | CL_MEM_READ_WRITE; + host_ptr = clSVMAlloc((*ocl_runtime->Context())(), flags, size, 0); + } else { + cl_int ret = CL_SUCCESS; + cl::Buffer *buffer = + new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create OpenCL buffer failed! (ERROR CODE: " << ret << ")"; + UnLock(); + return nullptr; + } + device_ptr = static_cast(buffer); + host_ptr = ocl_runtime->MapBuffer(*buffer, CL_MAP_READ | CL_MAP_WRITE, size); + if (host_ptr == nullptr) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr; + UnLock(); + return nullptr; + } + cl::Memory *mem = buffer; + ocl_runtime->UnmapBuffer(*mem, host_ptr); + } + std::unique_ptr mem_buf = std::make_unique(); + mem_buf->size_ = size; + mem_buf->device_ptr_ = device_ptr; + mem_buf->host_ptr_ = host_ptr; + MS_LOG(DEBUG) << "Malloc a new buffer. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + allocated_list_[host_ptr] = mem_buf.release(); + UnLock(); + return host_ptr; +} + +void *OpenCLAllocator::Malloc(size_t size, const std::vector& img_size) { + if (size > MAX_MALLOC_SIZE) { + MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; + return nullptr; + } + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + Lock(); + auto iter = free_list_.lower_bound(size); + if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { + auto mem_buf = iter->second; + bool is_match{mem_buf->img_size.size() == img_size.size()}; + for (int i = 0; i < img_size.size() && is_match; ++i) { + is_match = img_size[i] == mem_buf->img_size[i]; + } + if (is_match) { + free_list_.erase(iter); + allocated_list_[mem_buf->host_ptr_] = mem_buf; + UnLock(); + MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ + << ", host addr: " << mem_buf->host_ptr_ << ", device addr: " << mem_buf->device_ptr_; + return mem_buf->host_ptr_; + } + } + void *host_ptr = nullptr; + void *device_ptr = nullptr; + cl_int ret = CL_SUCCESS; + // CL_HALF_FLOAT, CL_FLOAT + cl::ImageFormat image_format(CL_RGBA, img_size[2]); + cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, + image_format, img_size[0], img_size[1], 0, nullptr, &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; + UnLock(); + return nullptr; + } + device_ptr = static_cast(buffer); + std::vector region{img_size[0], img_size[1], 1}; + host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region); + if (host_ptr == nullptr) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr; + UnLock(); + return nullptr; + } + cl::Memory *mem = buffer; + ocl_runtime->UnmapBuffer(*mem, host_ptr); + std::unique_ptr mem_buf = std::make_unique(); + mem_buf->size_ = size; + mem_buf->device_ptr_ = device_ptr; + mem_buf->host_ptr_ = host_ptr; + mem_buf->img_size = img_size; + MS_LOG(DEBUG) << "Malloc a new Image2D. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + allocated_list_[host_ptr] = mem_buf.release(); + UnLock(); + return host_ptr; +} + +void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::vector& img_size) { + if (size > MAX_MALLOC_SIZE) { + MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; + return nullptr; + } + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + Lock(); + auto iter = free_list_.lower_bound(size); + if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { + auto mem_buf = iter->second; + free_list_.erase(iter); + allocated_list_[mem_buf->host_ptr_] = mem_buf; + UnLock(); + MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + return mem_buf->host_ptr_; + } + void *host_ptr = nullptr; + void *device_ptr = nullptr; + cl_int ret = CL_SUCCESS; + // CL_HALF_FLOAT, CL_FLOAT + cl::ImageFormat image_format(CL_RGBA, img_size[2]); + cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, + img_size[0], img_size[1], 0, data, &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; + UnLock(); + return nullptr; + } + device_ptr = static_cast(buffer); + std::vector region{img_size[0], img_size[1], 1}; + host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region); + if (host_ptr == nullptr) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr; + UnLock(); + return nullptr; + } + cl::Memory *mem = buffer; + ocl_runtime->UnmapBuffer(*mem, host_ptr); + std::unique_ptr mem_buf = std::make_unique(); + mem_buf->size_ = size; + mem_buf->device_ptr_ = device_ptr; + mem_buf->host_ptr_ = host_ptr; + mem_buf->img_size = img_size; + MS_LOG(DEBUG) << "Malloc a new Image2D. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + allocated_list_[host_ptr] = mem_buf.release(); + UnLock(); + return host_ptr; +} +void OpenCLAllocator::Free(void *buf) { + if (buf == nullptr) { + return; + } + Lock(); + auto iter = allocated_list_.find(buf); + if (iter != allocated_list_.end()) { + auto mem_buf = iter->second; + allocated_list_.erase(iter); + free_list_.insert(std::make_pair(mem_buf->size_, mem_buf)); + UnLock(); + return; + } + UnLock(); + free(buf); +} + +size_t OpenCLAllocator::GetTotalSize() { + Lock(); + size_t totalSize = 0; + + for (auto it = allocated_list_.begin(); it != allocated_list_.end(); it++) { + totalSize += it->second->size_; + } + + for (auto it = free_list_.begin(); it != free_list_.end(); it++) { + totalSize += it->second->size_; + } + UnLock(); + return totalSize; +} + +void *OpenCLAllocator::GetDeviceBuffer(void *buffer) { + auto it = allocated_list_.find(buffer); + if (it != allocated_list_.end()) { + return it->second->device_ptr_; + } + return nullptr; +} + +void OpenCLAllocator::Clear() { + Lock(); + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + for (auto it = allocated_list_.begin(); it != allocated_list_.end(); it++) { + if (svm_capabilities) { + clSVMFree((*ocl_runtime->Context())(), it->second->host_ptr_); + MS_LOG(DEBUG) << "OpenCL free svm buffer : " << it->second->host_ptr_; + } else { + cl::Buffer *buff = static_cast(it->second->device_ptr_); + MS_LOG(DEBUG) << "OpenCL free device buffer : " << buff; + delete buff; + } + } + allocated_list_.clear(); + + for (auto it = free_list_.begin(); it != free_list_.end(); it++) { + if (svm_capabilities) { + clSVMFree((*ocl_runtime->Context())(), it->second->host_ptr_); + MS_LOG(DEBUG) << "OpenCL free svm buffer : " << it->second->host_ptr_; + } else { + cl::Buffer *buff = static_cast(it->second->device_ptr_); + MS_LOG(DEBUG) << "OpenCL free device buffer : " << buff; + delete buff; + } + } + free_list_.clear(); + UnLock(); +} + +void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue, bool sync) { + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + if (svm_capabilities && svm_on_) { + if (!(svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER)) { + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; + return nullptr; + } + ocl_runtime->MapBuffer(host_ptr, flags, it->second->size_, static_cast(command_queue), sync); + } + return host_ptr; + } + Lock(); + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; + UnLock(); + return nullptr; + } + MemBuf *mem_buf = it->second; + void *new_host_ptr{nullptr}; + if (mem_buf->img_size.empty()) { + cl::Buffer *buffer = static_cast(mem_buf->device_ptr_); + new_host_ptr = ocl_runtime->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync); + } else { + cl::ImageFormat image_format(CL_RGBA, mem_buf->img_size[2]); + std::vector region{mem_buf->img_size[0], mem_buf->img_size[1], 1}; + cl::Image2D *buffer = static_cast(mem_buf->device_ptr_); + new_host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region); + } + if (new_host_ptr == nullptr) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << mem_buf->device_ptr_ << ", host_ptr=" << host_ptr; + UnLock(); + return nullptr; + } + mem_buf->host_ptr_ = new_host_ptr; + allocated_list_.erase(it); + allocated_list_[new_host_ptr] = mem_buf; + UnLock(); + return new_host_ptr; +} + +int OpenCLAllocator::UnmapBuffer(void *host_ptr, void *command_queue) { + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + if (svm_capabilities) { + if (!(svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER)) { + return ocl_runtime->UnmapBuffer(host_ptr); + } + return 0; + } + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; + return 1; + } + cl::Buffer *buffer = static_cast(it->second->device_ptr_); + return ocl_runtime->UnmapBuffer(*buffer, it->second->host_ptr_, static_cast(command_queue)); +} + +MEM_TYPE OpenCLAllocator::GetMemType(void *host_ptr) { + MEM_TYPE mem_type{MEM_TYPE::BUF}; + Lock(); + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Can not found buffer :" << host_ptr; + UnLock(); + return mem_type; + } + MemBuf *mem_buf = it->second; + if (mem_buf->img_size.empty()) { + mem_type = MEM_TYPE::BUF; + } else { + mem_type = MEM_TYPE::IMG; + } + UnLock(); + return mem_type; +} + +int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector* img_size) { + Lock(); + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Can not found buffer :" << host_ptr; + UnLock(); + return RET_OK; + } + MemBuf *mem_buf = it->second; + if (!mem_buf->img_size.empty()) { + *img_size = mem_buf->img_size; + } + UnLock(); + return RET_OK; +} + +} // namespace mindspore::lite::opencl + diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.h b/mindspore/lite/src/runtime/opencl/opencl_allocator.h new file mode 100644 index 0000000000..0664020096 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_ALLOCATOR_H_ +#define MINDSPORE_LITE_SRC_OPENCL_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "src/runtime/allocator.h" + +namespace mindspore::lite::opencl { + +#define MS_HOST_BUFFER 0 +#define MS_CL_BUFFER (1 << 1) +#define MS_CL_IMAGE2D (1 << 2) +typedef int32_t OpenCLMemoryType; + +struct OpenclMemory { + void *host_ptr{nullptr}; + void *device_ptr{nullptr}; + OpenCLMemoryType mem_type{MS_HOST_BUFFER | MS_CL_BUFFER}; +}; + +enum class MEM_TYPE : char { + BUF, IMG +}; + +class OpenCLAllocator : public Allocator { + public: + OpenCLAllocator(); + ~OpenCLAllocator() override; + void SetContext(const AllocatorContext &ctx) override; + void *Malloc(size_t size) override; + void *Malloc(size_t size, const std::vector& img_size); + void *CreateImageFromHost(void *host_ptr, size_t size, const std::vector& img_size); + void Free(void *ptr) override; + size_t GetTotalSize() override; + + void Clear() override; + void *GetDeviceBuffer(void *buffer); + void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true); + int UnmapBuffer(void *host_ptr, void *command_queue = nullptr); + MEM_TYPE GetMemType(void *host_ptr); + int GetImageSize(void *host_ptr, std::vector* img_size); + + private: + void Lock(); + void UnLock(); + struct MemBuf { + size_t size_; + void *device_ptr_; + void *host_ptr_; + std::vector img_size; + }; + + std::mutex lock; + // buf, membuf> + std::unordered_map allocated_list_; + std::multimap free_list_; + // 6 is empirical value + int shift_factor_ = 6; + bool lock_flag_ = false; + bool svm_on_{false}; +}; + +} // namespace mindspore::lite::opencl + +#endif // MINDSPORE_LITE_SRC_OPENCL_ALLOCATOR_H_ + diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc new file mode 100644 index 0000000000..be273ea523 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/opencl/opencl_executor.h" +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "src/common/ms_tensor_utils.h" +#include "include/errorcode.h" + +namespace mindspore::lite::opencl { +int OpenCLExecutor::Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator, + const session::KernelCallBack &before, const session::KernelCallBack &after) { + MS_ASSERT(nullptr != allocator); + for (auto &inTensor : inputs) { + if (inTensor == nullptr) { + MS_LOG(ERROR) << "Graph input tensor is nullptr"; + return RET_ERROR; + } + if (inTensor->GetFormat() != schema::Format_NHWC4 && inTensor->GetFormat() != schema::Format_NC4HW4 && + inTensor->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "input should be NHWC/NHWC4/NC4HW4, actual is " << schema::EnumNameFormat(inTensor->GetFormat()); + return RET_ERROR; + } else { + TransformTensorLayout(inTensor, inTensor->GetFormat(), schema::Format_NHWC4, true); + // TransformTensorLayout(inTensor, inTensor->GetFormat(), schema::Format_NC4HW4, true); + } + } + kernel::LiteKernelUtil::InitTensorRefCount(kernels); + OpenCLAllocator* op_allocator = reinterpret_cast(allocator); + for (auto *kernel : kernels) { + MS_ASSERT(nullptr != kernel); + kernel::OpenCLKernel *op_kernel = reinterpret_cast(kernel); + auto &outputs = kernel->GetOutputs(); + for (auto i = 0; i < outputs.size(); ++i) { + auto *output = outputs.at(i); + MS_ASSERT(nullptr != output); + if (is_image2d_out_) { + std::vector img_size; + op_kernel->GetImageSize(i, &img_size); + auto data_ptr = op_allocator->Malloc(output->Size(), img_size); + output->SetData(data_ptr); + } else { + output->MallocData(allocator); + } + } + session::CallBackParam callbackParam; + callbackParam.name_callback_param = kernel->Name(); + + if (before != nullptr) { + if (!before(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()), callbackParam)) { + MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->Name(); + } + } + auto ret = kernel->Run(); + if (0 != ret) { + MS_LOG(ERROR) << "run kernel failed, name: " << kernel->Name(); + return ret; + } + + if (after != nullptr) { + if (!after(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()), callbackParam)) { + MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->Name(); + } + } + for (auto input_kernel : kernel->GetInKernels()) { + MS_EXCEPTION_IF_NULL(input_kernel); + ret = input_kernel->DecOutTensorRefCount(); + if (0 != ret) { + MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed"; + } + } + } + // output format transform + for (auto &outTensor : outputs) { + if (outTensor == nullptr) { + MS_LOG(ERROR) << "Graph output tensor is nullptr"; + return RET_ERROR; + } + if (outTensor->GetFormat() != schema::Format_NHWC) { + TransformTensorLayout(outTensor, outTensor->GetFormat(), schema::Format_NHWC, false); + } + } + return RET_OK; +} + +int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format src_format, + schema::Format dst_format, bool trans_dir) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(4 == tensor->shape().size()); + auto data_type = tensor->data_type(); + switch (data_type) { + case kNumberTypeInt8: + return TransformTensorLayoutUint8(tensor, src_format, dst_format, trans_dir); + case kNumberTypeFloat32: + return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir); + default: + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format); + return RET_ERROR; + } + return RET_OK; +} + +int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format src_format, + schema::Format dst_format, bool trans_dir) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator_); + MS_ASSERT(4 == tensor->shape().size()); + if (trans_dir) { + if (is_image2d_out_) { + return TransformTensorLayoutToImage(tensor, src_format, dst_format); + } else { + return TransformTensorLayoutToBuffer(tensor, src_format, dst_format); + } + } else { + if (is_image2d_out_) { + return TransformTensorLayoutFromImage(tensor, src_format, dst_format); + } else { + return TransformTensorLayoutToBuffer(tensor, src_format, dst_format); + } + } +} + +int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema::Format src_format, + schema::Format dst_format) { + if (dst_format == schema::Format_NHWC4) { + auto *src_data = tensor->Data(); + size_t C4 = UP_DIV(tensor->Channel(), C4NUM); + std::vector img_size{tensor->Width() * C4, (size_t) tensor->Height(), CL_FLOAT}; + if (src_format == schema::Format_NHWC) { + auto *dst_data = allocator_->Malloc(tensor->Size(), img_size); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + dst_data = reinterpret_cast(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true)); + PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); + tensor->SetData(dst_data); + allocator_->Free(src_data); + allocator_->UnmapBuffer(dst_data); + } + tensor->SetFormat(dst_format); + return RET_OK; + } else if (dst_format == schema::Format_NHWC) { + // TODO(wandongdong): add support !! + return RET_OK; + } else { + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in float32"; + return RET_ERROR; + } +} + +int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema::Format src_format, + schema::Format dst_format) { + if (dst_format == schema::Format_NHWC4) { + // convert to nhwc4 + auto *src_data = tensor->Data(); + auto *dst_data{src_data}; + if (src_format == schema::Format_NHWC) { + dst_data = allocator_->Malloc(tensor->Size()); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + dst_data = reinterpret_cast(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true)); + PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); + tensor->SetData(dst_data); + allocator_->Free(src_data); + allocator_->UnmapBuffer(dst_data); + } + // copy to image2d + src_data = dst_data; + size_t C4 = UP_DIV(tensor->Channel(), C4NUM); + std::vector img_size{tensor->Width() * C4, (size_t)tensor->Height(), CL_FLOAT}; + dst_data = allocator_->CreateImageFromHost(src_data, tensor->Size(), img_size); + tensor->SetData(dst_data); + allocator_->Free(src_data); + tensor->SetFormat(schema::Format_NHWC4); + return RET_OK; + } else { + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in float32"; + return RET_ERROR; + } +} + +int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schema::Format src_format, + schema::Format dst_format) { + if (dst_format == schema::Format_NHWC) { + auto src_data = tensor->Data(); + auto dst_data = allocator_->Malloc(tensor->Size()); + cl::Image2D *out_mem = reinterpret_cast(allocator_->GetDeviceBuffer(src_data)); + std::vector img_size; + allocator_->GetImageSize(src_data, &img_size); + auto origin = cl::array < cl::size_type, 3U > {0, 0, 0}; + auto region = cl::array < cl::size_type, 3U > {img_size[0], img_size[1], 1}; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(*out_mem, CL_TRUE, origin, region, 0, 0, dst_data); + tensor->SetData(dst_data); + allocator_->Free(src_data); + return RET_OK; + } else { + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in float32"; + return RET_ERROR; + } +} + +int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format src_format, + schema::Format dst_format, bool is_image) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(4 == tensor->shape().size()); + // auto src_format = tensor->GetFormat(); + // todo + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in uint8"; + return RET_ERROR; +} +} // namespace mindspore::lite::opencl + diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.h b/mindspore/lite/src/runtime/opencl/opencl_executor.h new file mode 100644 index 0000000000..d40a13574f --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_EXECUTOR_H_ +#define MINDSPORE_LITE_SRC_OPENCL_EXECUTOR_H_ + +#include +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/allocator.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "src/executor.h" +#include "include/lite_session.h" + +namespace mindspore::lite::opencl { +class OpenCLExecutor : Executor { + public: + OpenCLExecutor() : Executor() { + allocator_ = OpenCLRuntime::GetInstance()->GetAllocator(); + } + + int Prepare(const std::vector &kernels) { return 0; } + + int Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator = nullptr, + const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr); + + protected: + int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format, + bool trans_dir = false); + + int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format, + bool trans_dir = false); + + int TransformTensorLayout(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format, + bool trans_dir = false); + + int TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format); + + int TransformTensorLayoutToImage(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format); + + int TransformTensorLayoutFromImage(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format); + + protected: + Context *context = nullptr; + OpenCLAllocator *allocator_; + bool is_image2d_out_{true}; +}; + +} // namespace mindspore::lite::opencl +#endif + diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc new file mode 100644 index 0000000000..d503e2a32e --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc @@ -0,0 +1,626 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/opencl/opencl_runtime.h" +#include +#include +#ifdef SHARING_MEM_WITH_OPENGL +#include +#endif +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/opencl/opencl_allocator.h" +#ifdef PROGRAM_WITH_IL +#include "src/backend/opencl/cl/program.inc" +#endif + +#ifndef ROUND_UP +#define ROUND_UP(x, y) ((static_cast(x) + static_cast(y) - (1)) / static_cast(y) * static_cast(y)) +#endif + +using mindspore::kernel::CLErrorCode; + +namespace mindspore::lite::opencl { + +std::map g_opencl_program_map; + +static std::mutex g_mtx; +static std::mutex g_init_mtx; + +// magic number +static std::map AdrenoSubGroup{ + {640, 128}, {630, 128}, {616, 128}, {612, 64}, {610, 64}, {540, 32}, {530, 32}, + {512, 32}, {510, 32}, {509, 32}, {506, 32}, {505, 32}, {405, 32}, {330, 16}, +}; + +#ifdef USE_OPENCL_WRAPPER +std::shared_ptr OpenCLWrapper::opencl_wrapper_singleton_ = nullptr; +#endif +std::shared_ptr OpenCLRuntime::opencl_runtime_singleton_ = nullptr; +bool OpenCLRuntime::init_done_ = false; + +OpenCLRuntime *OpenCLRuntime::GetInstance() { + std::unique_lock lck(g_mtx); + if (opencl_runtime_singleton_.get() == nullptr) { + opencl_runtime_singleton_.reset(new OpenCLRuntime()); + opencl_runtime_singleton_->Init(); + } + return opencl_runtime_singleton_.get(); +} + +void OpenCLRuntime::DeleteInstance() { + std::unique_lock lck(g_mtx); + init_done_ = false; + if (opencl_runtime_singleton_ != nullptr) { + opencl_runtime_singleton_.reset(); + opencl_runtime_singleton_ = nullptr; + } +} + +OpenCLRuntime::OpenCLRuntime() { default_build_opts_ = " -cl-mad-enable -cl-fast-relaxed-math -Werror"; } + +// Init will get platforms info, get devices info, create opencl context. +int OpenCLRuntime::Init() { + std::unique_lock lck(g_init_mtx); + + if (init_done_) { + return 0; + } + MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION; + MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION; + MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION; + +#ifdef USE_OPENCL_WRAPPER + if (false == OpenCLWrapper::GetInstance()->LoadOpenCLLibrary()) { + MS_LOG(ERROR) << "Load OpenCL symbols failed!"; + return 1; + } +#endif // USE_OPENCL_WRAPPER + + std::vector platforms; + cl::Platform::get(&platforms); + if (platforms.size() == 0) { + MS_LOG(ERROR) << "OpenCL Platform not found!"; + return 1; + } + + // search GPU + std::vector devices; + for (auto it = platforms.begin(); it != platforms.end(); ++it) { + std::string platform_name; + it->getInfo(CL_PLATFORM_NAME, &platform_name); + it->getDevices(CL_DEVICE_TYPE_GPU, &devices); + MS_LOG(INFO) << "Platform (" << platform_name << ") has " << devices.size() << " GPUs"; + + if (devices.size() > 0) { + std::string device_name = devices[0].getInfo(); + MS_LOG(INFO) << "Find GPU: " << device_name.c_str(); + cl::Platform::setDefault(*it); + break; + } + } + + // not found, return error code. + if (devices.size() == 0) { + MS_LOG(ERROR) << "OpenCL Device not found!"; + return 1; + } + + device_ = std::make_shared(); + *device_ = devices[0]; + max_work_item_sizes_ = device_->getInfo(); + const std::string device_name = device_->getInfo(); + const std::string device_version = device_->getInfo(); + const std::string opencl_version = device_->getInfo(); + cl_uint align; + size_t ret; + clGetDeviceInfo((*device_)(), CL_DEVICE_IMAGE_PITCH_ALIGNMENT, sizeof(cl_uint), &align, &ret); + MS_LOG(INFO) << "Device name:\t" << device_name; + MS_LOG(INFO) << "Opencl version:\t" << device_version; + MS_LOG(INFO) << "Image alignment:\t" << align; + MS_LOG(INFO) << "Image ret:\t" << ret; + MS_LOG(INFO) << "Highest OpenCL c version:\t" << opencl_version; + MS_LOG(INFO) << "Max work item size:\t" + << max_work_item_sizes_[0] << " : " + << max_work_item_sizes_[1] << " : " + << max_work_item_sizes_[2]; + + gpu_info_ = ParseGpuInfo(device_name, device_version); + cl_int err; +#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120) + // create context from glcontext + MS_LOG(INFO) << "Create special opencl context to share with OpenGL"; + cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(), + CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0}; + context_ = std::make_shared(std::vector{*device_}, context_prop, nullptr, nullptr, &err); + + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Create special OpenCL context falied, Create common OpenCL context then."; + context_ = std::make_shared(std::vector{*device_}, nullptr, nullptr, nullptr, &err); + } +#else + MS_LOG(INFO) << "Create common opencl context"; + context_ = std::make_shared(std::vector{*device_}, nullptr, nullptr, nullptr, &err); +#endif + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(err); + return 1; + } + + // get cache size, compute units and frequency. + device_->getInfo(CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, &global_memery_cachesize_); + device_->getInfo(CL_DEVICE_MAX_COMPUTE_UNITS, &compute_units_); + device_->getInfo(CL_DEVICE_MAX_CLOCK_FREQUENCY, &max_freq_); + cl_device_fp_config fp_config; + auto success = device_->getInfo(CL_DEVICE_HALF_FP_CONFIG, &fp_config); + support_fp16_ = CL_SUCCESS == success && fp_config > 0; + + err = device_->getInfo(CL_DEVICE_SVM_CAPABILITIES, &svm_capabilities_); + svm_capabilities_ = 0; + if (err != CL_SUCCESS || svm_capabilities_ == 0) { + svm_capabilities_ = 0; + MS_LOG(INFO) << "SVM capalibilties: " + << "NONE"; + } else { + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_FINE_GRAIN_BUFFER"; + } + if (svm_capabilities_ & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_COARSE_GRAIN_BUFFER"; + } + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_COARSE_GRAIN_SYSTEM"; + } + if (svm_capabilities_ & CL_DEVICE_SVM_ATOMICS) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_ATOMICS"; + } + } + + MS_LOG(INFO) << "Global Mem Cache Size: " << global_memery_cachesize_; + MS_LOG(INFO) << "Compute Unit: " << compute_units_; + MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz"; + + cl_command_queue_properties properties = 0; +#if MS_OPENCL_PROFILE + properties |= CL_QUEUE_PROFILING_ENABLE; +#endif + + default_command_queue_ = std::make_shared(*context_, *device_, properties, &err); + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Command Queue create failed: " << CLErrorCode(err); + return 1; + } + + allocator_ = std::make_shared(); +#ifdef PROGRAM_WITH_IL + std::string flag = ""; + CreateProgramFromIL(g_program_binary, flag); +#endif + init_done_ = true; + MS_LOG(INFO) << "OpenCLRuntime init done!"; + + return 0; +} + +OpenCLRuntime::~OpenCLRuntime() { + program_map_.clear(); + // allocator_->Clear(); + allocator_.reset(); + default_command_queue_.reset(); + context_.reset(); + device_.reset(); +} + +cl::Context *OpenCLRuntime::Context() { return context_.get(); } + +cl::Device *OpenCLRuntime::Device() { return device_.get(); } + +uint64_t OpenCLRuntime::DeviceGlobalMemoryCacheSize() const { return global_memery_cachesize_; } + +int OpenCLRuntime::DeviceMaxWorkGroupSize() const { return max_work_group_size; } + +uint32_t OpenCLRuntime::DeviceComputeUnits() const { return compute_units_; } + +uint32_t OpenCLRuntime::DeviceMaxFreq() const { return max_freq_; } + +// get kernel enqueue max work group size +uint64_t OpenCLRuntime::GetMaxWorkGroupSize(const cl::Kernel &kernel) { + uint64_t max_workgroup_size = 0; + int ret = kernel.getWorkGroupInfo(*device_, CL_KERNEL_WORK_GROUP_SIZE, &max_workgroup_size); + if (ret != 0) max_workgroup_size = 0; + return max_workgroup_size; +} + +// opencl 2.0 can get SubGroupSize. +uint32_t OpenCLRuntime::GetSubGroupSize(const cl::Kernel &kernel, const cl::NDRange &range) { + uint32_t sub_group_size = 0; + + if (ADRENO == gpu_info_.type) { +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 && CL_TARGET_OPENCL_VERSION >= 210 && defined(CL_HPP_USE_CL_SUB_GROUPS_KHR) + cl_int cl_ret; + sub_group_size = kernel.getSubGroupInfo(*device_, range, &cl_ret); + if (cl_ret != CL_SUCCESS) { + CHECK_CL_SUCCESS(cl_ret) + sub_group_size = 0; + } +#else + if (AdrenoSubGroup.find(gpu_info_.model_num) != AdrenoSubGroup.end()) { + sub_group_size = AdrenoSubGroup[gpu_info_.model_num]; + } +#endif + } + + return sub_group_size; +} + +GpuInfo OpenCLRuntime::GetGpuInfo() { return gpu_info_; } + +bool OpenCLRuntime::GetFp16Enable() const { return fp16_enable_; } + +// if support fp16, set fp16 will success. +bool OpenCLRuntime::SetFp16Enable(bool enable) { + fp16_enable_ = enable && support_fp16_; + return fp16_enable_ == enable; +} + +int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name, + const std::set &build_options) { + std::string build_options_str; + // set default macro + if (fp16_enable_) { + // fp16 enable, kernel will use half and read_imageh and write_imageh. + build_options_str = + "-DFLOAT=half -DFLOAT4=half4 -DRI_F=read_imageh " + "-DWI_F=write_imageh"; + } else { + // fp16 not enable, kernel will use float and read_imagef and write_imagef. + build_options_str = + "-DFLOAT=float -DFLOAT4=float4 -DRI_F=read_imagef " + "-DWI_F=write_imagef"; + } + + build_options_str = std::accumulate( + build_options.begin(), build_options.end(), build_options_str, + [](const std::string &options, const std::string &option) -> std::string { return options + " " + option; }); + build_options_str += default_build_opts_; + // program identifier = program_name + build_options + std::string build_program_key = program_name + build_options_str; + + auto build_program_it = program_map_.find(build_program_key); + cl::Program program; + // if search program identifier exist, then use it. + if (build_program_it != program_map_.end()) { + program = build_program_it->second; + } else { + // load program and build program + auto status = this->LoadProgram(program_name, &program); + if (!status) { + MS_LOG(ERROR) << "load program (" << program_name << ") failed!"; + return 1; + } + status = this->BuildProgram(build_options_str, &program); + if (!status) { + MS_LOG(ERROR) << program_name << " build failed!"; + return 1; + } + program_map_.emplace(build_program_key, program); + } + + cl_int err; + kernel = cl::Kernel(program, kernel_name.c_str(), &err); + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << kernel_name << " Kernel create failed:" << CLErrorCode(err); + return 1; + } + return 0; +} + +// Run Kernel with 1D, 2D, 3D group size, and local size can be empty. +int OpenCLRuntime::RunKernel(const cl_kernel &kernel, const std::vector &global, + const std::vector &local, cl::CommandQueue *command_queue) { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + MS_ASSERT(local.size() == 0 || local.size() == global.size()); + std::vector internal_global_ws = global; + for (size_t i = 0; i < local.size(); ++i) { + internal_global_ws[i] = ROUND_UP(global[i], local[i]); + } + + MS_LOG(INFO) << "global size: " << global.size() << ", local size: " << local.size(); + for (size_t i = 0; i < global.size(); i++) { + MS_LOG(DEBUG) << "global[" << i << "] = " << global[i]; + } + for (size_t i = 0; i < local.size(); i++) { + MS_LOG(DEBUG) << "local[" << i << "] = " << local[i]; + } + + cl::Event event; + cl_int error = CL_SUCCESS; + if (local.size() == 0) { + error = + clEnqueueNDRangeKernel((*command_queue)(), kernel, global.size(), 0, global.data(), nullptr, 0, nullptr, nullptr); + } else { + error = clEnqueueNDRangeKernel((*command_queue)(), kernel, global.size(), 0, global.data(), local.data(), 0, + nullptr, nullptr); + } + + if (error != CL_SUCCESS) { + MS_LOG(ERROR) << "Kernel execute failed:" << CLErrorCode(error); + return 1; + } + MS_LOG(INFO) << "RunKernel success!"; + return 0; +} + +// Run Kernel with 1D, 2D, 3D group size, and local size can be empty. +int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const std::vector &global, + const std::vector &local, cl::CommandQueue *command_queue) { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + MS_ASSERT(local.size() == 0 || local.size() == global.size()); + std::vector internal_global_ws = global; + for (size_t i = 0; i < local.size(); ++i) { + internal_global_ws[i] = ROUND_UP(global[i], local[i]); + } + + MS_LOG(INFO) << "global size: " << global.size() << ", local size: " << local.size(); + for (size_t i = 0; i < global.size(); i++) { + MS_LOG(DEBUG) << "global[" << i << "] = " << global[i]; + } + for (size_t i = 0; i < local.size(); i++) { + MS_LOG(DEBUG) << "local[" << i << "] = " << local[i]; + } + + cl::Event event; + cl_int err = CL_SUCCESS; + + cl::NDRange global_range = cl::NullRange; + cl::NDRange local_range = cl::NullRange; + if (global.size() == 1) { + global_range = cl::NDRange(internal_global_ws[0]); + if (!local.empty()) { + local_range = cl::NDRange(local[0]); + } + } else if (global.size() == 2) { + global_range = cl::NDRange(internal_global_ws[0], internal_global_ws[1]); + if (!local.empty()) { + local_range = cl::NDRange(local[0], local[1]); + } + } else if (global.size() == 3) { + global_range = cl::NDRange(internal_global_ws[0], internal_global_ws[1], internal_global_ws[2]); + if (!local.empty()) { + local_range = cl::NDRange(local[0], local[1], local[2]); + } + } else { + MS_LOG(INFO) << "Not supported NDRange!"; + return 1; + } + + err = command_queue->enqueueNDRangeKernel(kernel, cl::NullRange, global_range, local_range, nullptr, &event); + + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Kernel execute failed:" << CLErrorCode(err); + return 1; + } + MS_LOG(INFO) << "RunKernel success!"; +#if MS_OPENCL_PROFILE + event.wait(); + cl_ulong time_start; + cl_ulong time_end; + event.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start); + event.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end); + double nanoSeconds = time_end - time_start; + MS_LOG(INFO) << "OpenCl Execution time is: " << nanoSeconds / 1000000.0 << "ms"; +#endif + return 0; +} + +// get gpu divce type +GpuInfo OpenCLRuntime::ParseGpuInfo(std::string device_name, std::string device_version) { + GpuInfo info; + + if (device_name == "QUALCOMM Adreno(TM)") { + info.type = ADRENO; + sscanf(device_version.c_str(), "%*s%f%*s%d", &info.opencl_version, &info.model_num); + + } else if (device_name.find("Mali") != std::string::npos) { + info.type = MALI; + + // Mali type MALI-G or MALI_T + if (device_name.find("Mali-G") != std::string::npos) { + info.type = MALI_G; + sscanf(device_name.c_str(), "Mali-G%d", &info.model_num); + } else if (device_name.find("Mali-T") != std::string::npos) { + info.type = MALI_T; + sscanf(device_name.c_str(), "Mali-T%d", &info.model_num); + } + sscanf(device_version.c_str(), "%*s%f%*s", &info.opencl_version); + } + + return info; +} + +bool OpenCLRuntime::LoadSource(const std::string &program_name, const std::string &source) { + auto it_source = g_opencl_program_map.find(program_name); + if (it_source != g_opencl_program_map.end()) { + it_source->second = source; + } else { + g_opencl_program_map.emplace(program_name, source); + } + return true; +} + +// load program with program name. +bool OpenCLRuntime::LoadProgram(const std::string &program_name, cl::Program *program) { + auto it_source = g_opencl_program_map.find(program_name); + if (it_source != g_opencl_program_map.end()) { + cl::Program::Sources sources; + sources.push_back(it_source->second); + *program = cl::Program(*context_, sources); + return true; + } else { + MS_LOG(ERROR) << "Can't find kernel source !"; + return false; + } +} + +// build program with build options +bool OpenCLRuntime::BuildProgram(const std::string &build_options, cl::Program *program) { + cl_int ret = program->build({*device_}, build_options.c_str()); + if (ret != CL_SUCCESS) { + if (program->getBuildInfo(*device_) == CL_BUILD_ERROR) { + std::string build_log = program->getBuildInfo(*device_); + MS_LOG(ERROR) << "Program build log: " << build_log; + } + MS_LOG(ERROR) << "Build program failed: " << CLErrorCode(ret); + return false; + } + return true; +} + +bool OpenCLRuntime::CopyDeviceMemToHost(void *dst, const void *src, size_t size, cl::CommandQueue *command_queue, + bool sync) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl_int cl_ret = CL_SUCCESS; + const cl::Buffer *buffer = static_cast(src); + if (command_queue != nullptr) { + cl_ret = command_queue->enqueueReadBuffer(*buffer, sync, 0, size, dst); + } + return cl_ret == CL_SUCCESS; +} + +bool OpenCLRuntime::CopyHostMemToDevice(const void *dst, const void *src, size_t size, cl::CommandQueue *command_queue, + bool sync) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl_int cl_ret = CL_SUCCESS; + const cl::Buffer *buffer = static_cast(dst); + if (command_queue != nullptr) { + cl_ret = command_queue->enqueueWriteBuffer(*buffer, sync, 0, size, src); + } + return cl_ret == CL_SUCCESS; +} + +void *OpenCLRuntime::MapBuffer(const cl::Buffer buffer, int flags, size_t size, cl::CommandQueue *command_queue, + bool sync) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueMapBuffer(buffer, sync, flags, 0, size); +} + +int OpenCLRuntime::MapBuffer(void *host_ptr, int flags, size_t size, cl::CommandQueue *command_queue, bool sync) const { + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) { + return 0; + } + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueMapSVM(host_ptr, sync, flags, size); +} + +void *OpenCLRuntime::MapBuffer(const cl::Image2D buffer, bool sync, int flags, + const std::vector& region, cl::CommandQueue *command_queue) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl::size_type row_pitch; + cl::size_type slice_pitch; + cl::array origin_{0, 0, 0}; + cl::array region_{region[0], region[1], region[2]}; + return command_queue->enqueueMapImage(buffer, sync, flags, origin_, region_, &row_pitch, &slice_pitch); +} + +int OpenCLRuntime::UnmapBuffer(const cl::Memory buffer, void *host_ptr, cl::CommandQueue *command_queue) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueUnmapMemObject(buffer, host_ptr); +} + +int OpenCLRuntime::UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue) const { + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) { + return 0; + } + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueUnmapSVM(host_ptr); +} + +bool OpenCLRuntime::SyncCommandQueue(cl::CommandQueue *command_queue) { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl_int ret = command_queue->finish(); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Command queue sync failed: " << CLErrorCode(ret); + return 1; + } + return ret == CL_SUCCESS; +} + +int OpenCLRuntime::GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id) { + size_t max_work_group_size; + cl_int err = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE, sizeof(size_t), + &max_work_group_size, nullptr); + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Failed to get info CL_KERNEL_WORK_GROUP_SIZE " << CLErrorCode(err); + } + return static_cast(max_work_group_size); +} + +bool OpenCLRuntime::CreateKernelFromIL(cl_kernel &kernel, const std::string kernel_name) { + cl_int ret = CL_SUCCESS; + kernel = clCreateKernel(il_program_, kernel_name.c_str(), &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create kernel with IL failed: " << CLErrorCode(ret); + } + return ret == CL_SUCCESS; +} + +// build program with IL +bool OpenCLRuntime::CreateProgramFromIL(const std::vector program_binary, const std::string flag) { +#if CL_HPP_TARGET_OPENCL_VERSION >= 210 + size_t program_length = program_binary.size(); + cl_int ret = CL_SUCCESS; + il_program_ = clCreateProgramWithIL((*context_)(), program_binary.data(), program_length, &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create program with IL failed: " << CLErrorCode(ret); + return false; + } + + ret = clBuildProgram(il_program_, 1, &(*device_)(), flag.c_str(), NULL, NULL); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Build program with IL failed: " << CLErrorCode(ret); + } + return ret == CL_SUCCESS; +#else + MS_LOG(ERROR) << "Create program with IL failed! The compute capabitity of device should be 2.1 and higher."; + return false; +#endif +} + +} // namespace mindspore::lite::opencl + diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.h b/mindspore/lite/src/runtime/opencl/opencl_runtime.h new file mode 100644 index 0000000000..173d0416d6 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.h @@ -0,0 +1,164 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); +j* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ +#define MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "src/runtime/opencl/opencl_wrapper.h" +#include "src/runtime/opencl/opencl_allocator.h" + +namespace mindspore::lite::opencl { + +enum GpuType { OTHER = 0, ADRENO = 1, MALI = 2, MALI_T = 3, MALI_G = 4 }; + +struct GpuInfo { + GpuType type = OTHER; + int model_num = 0; + float opencl_version = 0; +}; + +// Base GPU cache size used for computing local work group size. +const int32_t g_base_gpu_mem_cachesize = 16384; + +class OpenCLRuntime { + public: + static OpenCLRuntime *GetInstance(); + static void DeleteInstance(); + + ~OpenCLRuntime(); + OpenCLRuntime(const OpenCLRuntime &) = delete; + OpenCLRuntime &operator=(const OpenCLRuntime &) = delete; + + int Init(); + + cl::Context *Context(); + cl::Device *Device(); + OpenCLAllocator *GetAllocator() { return allocator_.get(); } + cl::CommandQueue *GetDefaultCommandQueue() { return default_command_queue_.get(); } + uint64_t DeviceGlobalMemoryCacheSize() const; + int DeviceMaxWorkGroupSize() const; + uint32_t DeviceComputeUnits() const; + uint32_t DeviceMaxFreq() const; + uint64_t GetMaxWorkGroupSize(const cl::Kernel &kernel); + uint32_t GetSubGroupSize(const cl::Kernel &kernel, const cl::NDRange &range = cl::NullRange); + GpuInfo GetGpuInfo(); + bool GetFp16Enable() const; + bool SetFp16Enable(bool enable); + const std::vector &GetWorkItemSize() { return max_work_item_sizes_; } + cl_device_svm_capabilities GetSVMCapabilities() const { return svm_capabilities_; } + + template + typename std::enable_if::value, cl_int>::type SetKernelArg(cl_kernel &kernel, uint32_t index, + const T value) { + if (svm_capabilities_) { + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value; + return clSetKernelArgSVMPointer(kernel, index, value); + } else { + MEM_TYPE mem_type = allocator_->GetMemType(value); + if (mem_type == MEM_TYPE::BUF) { + cl::Buffer *buffer = reinterpret_cast(allocator_->GetDeviceBuffer(value)); + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << value; + return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)()); + } else { + cl::Image2D *buffer = reinterpret_cast(allocator_->GetDeviceBuffer(value)); + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Image2D " << value; + return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)()); + } + } + } + + template + typename std::enable_if::value, cl_int>::type SetKernelArg(cl_kernel &kernel, uint32_t index, + const T value) { + return clSetKernelArg(kernel, index, sizeof(value), &value); + } + + template + int SetKernelArg(cl::Kernel &kernel, uint32_t index, const T &value) { + return SetKernelArg(kernel(), index, value); + } + + bool CreateProgramFromIL(const std::vector program_binary, const std::string flag); + bool CreateKernelFromIL(cl_kernel &kernel, const std::string kernel_name); + bool LoadSource(const std::string &program_name, const std::string &source); + int BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name, + const std::set &build_options); + int RunKernel(const cl_kernel &kernel, const std::vector &global, const std::vector &local, + cl::CommandQueue *command_queue); + int RunKernel(const cl::Kernel &kernel, const std::vector &global, const std::vector &local, + cl::CommandQueue *command_queue); + bool CopyDeviceMemToHost(void *dst, const void *src, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + bool CopyHostMemToDevice(const void *dst, const void *src, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + void *MapBuffer(const cl::Buffer buffer, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + void *MapBuffer(const cl::Image2D buffer, bool sync, int flags, + const std::vector& region, cl::CommandQueue *command_queue = nullptr) const; + int MapBuffer(void *host_ptr, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + int UnmapBuffer(const cl::Memory buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; + int UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; + bool SyncCommandQueue(cl::CommandQueue *command_queue = nullptr); + + /** + * Get kernel max worker group size. + * @param kernel + * @param device_id + * @return max_work_group_size + */ + int GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id); + + private: + OpenCLRuntime(); + GpuInfo ParseGpuInfo(std::string device_name, std::string device_version); + + bool LoadProgram(const std::string &program_name, cl::Program *program); + bool BuildProgram(const std::string &build_options, cl::Program *program); + + private: + static std::shared_ptr opencl_runtime_singleton_; + static bool init_done_; + std::shared_ptr default_command_queue_{nullptr}; + std::shared_ptr context_{nullptr}; + std::shared_ptr device_{nullptr}; + std::shared_ptr allocator_{nullptr}; + std::map program_map_{}; + cl_program il_program_{0}; + uint64_t global_memery_cachesize_{0}; + int max_work_group_size; + uint32_t compute_units_{0}; + uint32_t max_freq_{0}; + std::string default_build_opts_{""}; + GpuInfo gpu_info_; + bool support_fp16_{false}; + bool fp16_enable_{false}; + cl_device_svm_capabilities svm_capabilities_{0}; + std::vector max_work_item_sizes_; +}; + +} // namespace mindspore::lite::opencl + +#endif // MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ + diff --git a/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc b/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc new file mode 100644 index 0000000000..084afc344a --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc @@ -0,0 +1,683 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef USE_OPENCL_WRAPPER + +#include "src/runtime/opencl/opencl_wrapper.h" +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::lite::opencl { + +// default opencl library path +static const std::vector g_opencl_library_paths = { +#if defined(__APPLE__) || defined(__MACOSX) + "libOpenCL.so", "/System/Library/Frameworks/OpenCL.framework/OpenCL" +#elif defined(__ANDROID__) +#if defined(__aarch64__) + // Mali + "/system/vendor/lib64/egl/libGLES_mali.so", + "/system/lib64/egl/libGLES_mali.so", + // Qualcomm Adreno + "/system/vendor/lib64/libOpenCL.so", + "/system/lib64/libOpenCL.so", +#else + // Qualcomm Adreno + "/system/vendor/lib/libOpenCL.so", "/system/lib/libOpenCL.so", + // Mali + "/system/vendor/lib/egl/libGLES_mali.so", "/system/lib/egl/libGLES_mali.so", + // other + "/system/vendor/lib/libPVROCL.so", "/data/data/org.pocl.libs/files/lib/libpocl.so" +#endif + "libOpenCL.so", + "libGLES_mali.so", + "libmali.so", +#elif defined(__linux__) + "/usr/lib/libOpenCL.so", + "/usr/local/lib/libOpenCL.so", + "/usr/local/lib/libpocl.so", + "/usr/lib64/libOpenCL.so", + "/usr/lib32/libOpenCL.so", + "libOpenCL.so", + // intel + "/opt/intel/system_studio_2020/opencl/SDK/lib64/libOpenCL.so", +#endif +}; + +OpenCLWrapper *OpenCLWrapper::GetInstance() { + static std::once_flag opencl_wrapper_once; + std::call_once(opencl_wrapper_once, + []() { opencl_wrapper_singleton_ = std::shared_ptr(new OpenCLWrapper()); }); + + return opencl_wrapper_singleton_.get(); +} + +OpenCLWrapper::OpenCLWrapper() {} + +OpenCLWrapper::~OpenCLWrapper() { + if (nullptr == opencl_wrapper_singleton_.get()) return; + opencl_wrapper_singleton_->UnLoadOpenCLLibrary(); +} + +// load default library path +bool OpenCLWrapper::LoadOpenCLLibrary() { + if (handle_ != nullptr) { + return true; + } + for (const auto &lib_path : g_opencl_library_paths) { + if (LoadLibraryFromPath(lib_path)) { + MS_LOG(DEBUG) << "Find a OpenCL dynamic library : " << lib_path; + return true; + } + } + return false; +} + +bool OpenCLWrapper::UnLoadOpenCLLibrary() { + if (handle_ != nullptr) { + if (dlclose(handle_) != 0) { + return false; + } + handle_ = nullptr; + return true; + } + return true; +} + +bool OpenCLWrapper::LoadLibraryFromPath(const std::string &library_path) { + handle_ = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (handle_ == nullptr) { + return false; + } + +// load function ptr use dlopen and dlsym. +#define LOAD_OPENCL_FUNCTION_PTR(func_name) \ + func_name = reinterpret_cast(dlsym(handle_, #func_name)); \ + if (func_name == nullptr) { \ + MS_LOG(ERROR) << "load func (" << #func_name << ") from (" << library_path << ") failed!"; \ + return false; \ + } + + LOAD_OPENCL_FUNCTION_PTR(clGetPlatformIDs); + LOAD_OPENCL_FUNCTION_PTR(clGetPlatformInfo); + LOAD_OPENCL_FUNCTION_PTR(clBuildProgram); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueNDRangeKernel); + LOAD_OPENCL_FUNCTION_PTR(clSetKernelArg); + LOAD_OPENCL_FUNCTION_PTR(clReleaseKernel); + LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithSource); + LOAD_OPENCL_FUNCTION_PTR(clCreateBuffer); + LOAD_OPENCL_FUNCTION_PTR(clCreateImage2D); + LOAD_OPENCL_FUNCTION_PTR(clCreateImage3D); + LOAD_OPENCL_FUNCTION_PTR(clRetainKernel); + LOAD_OPENCL_FUNCTION_PTR(clCreateKernel); + LOAD_OPENCL_FUNCTION_PTR(clGetProgramInfo); + LOAD_OPENCL_FUNCTION_PTR(clFlush); + LOAD_OPENCL_FUNCTION_PTR(clFinish); + LOAD_OPENCL_FUNCTION_PTR(clReleaseProgram); + LOAD_OPENCL_FUNCTION_PTR(clRetainContext); + LOAD_OPENCL_FUNCTION_PTR(clGetContextInfo); + LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithBinary); + LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueue); + LOAD_OPENCL_FUNCTION_PTR(clGetCommandQueueInfo); + LOAD_OPENCL_FUNCTION_PTR(clReleaseCommandQueue); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueMapBuffer); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueMapImage); + LOAD_OPENCL_FUNCTION_PTR(clRetainProgram); + LOAD_OPENCL_FUNCTION_PTR(clGetProgramBuildInfo); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueReadBuffer); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueWriteBuffer); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueReadImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueWriteImage); + LOAD_OPENCL_FUNCTION_PTR(clWaitForEvents); + LOAD_OPENCL_FUNCTION_PTR(clReleaseEvent); + LOAD_OPENCL_FUNCTION_PTR(clCreateContext); + LOAD_OPENCL_FUNCTION_PTR(clCreateContextFromType); + LOAD_OPENCL_FUNCTION_PTR(clReleaseContext); + LOAD_OPENCL_FUNCTION_PTR(clRetainCommandQueue); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueUnmapMemObject); + LOAD_OPENCL_FUNCTION_PTR(clRetainMemObject); + LOAD_OPENCL_FUNCTION_PTR(clReleaseMemObject); + LOAD_OPENCL_FUNCTION_PTR(clGetDeviceInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetDeviceIDs); + LOAD_OPENCL_FUNCTION_PTR(clRetainEvent); + LOAD_OPENCL_FUNCTION_PTR(clGetKernelWorkGroupInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetEventInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetEventProfilingInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetImageInfo); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueCopyImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueCopyBufferToImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueCopyImageToBuffer); +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + LOAD_OPENCL_FUNCTION_PTR(clRetainDevice); + LOAD_OPENCL_FUNCTION_PTR(clReleaseDevice); + LOAD_OPENCL_FUNCTION_PTR(clCreateImage); +#endif +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 + // LOAD_OPENCL_FUNCTION_PTR(clGetKernelSubGroupInfoKHR); + LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueueWithProperties); + LOAD_OPENCL_FUNCTION_PTR(clGetExtensionFunctionAddress); + LOAD_OPENCL_FUNCTION_PTR(clSVMAlloc); + LOAD_OPENCL_FUNCTION_PTR(clSVMFree); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueSVMMap); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueSVMUnmap); + LOAD_OPENCL_FUNCTION_PTR(clSetKernelArgSVMPointer); +#ifdef PROGRAM_WITH_IL + LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithIL); +#endif +#endif + +#undef LOAD_OPENCL_FUNCTION_PTR + + return true; +} + +} // namespace mindspore::lite::opencl + +// clGetPlatformIDs wrapper, use OpenCLWrapper function. use OpenCLWrapper function. +cl_int clGetPlatformIDs(cl_uint num_entries, cl_platform_id *platforms, cl_uint *num_platforms) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetPlatformIDs; + MS_ASSERT(func != nullptr); + return func(num_entries, platforms, num_platforms); +} + +// clGetPlatformInfo wrapper, use OpenCLWrapper function. use OpenCLWrapper function. +cl_int clGetPlatformInfo(cl_platform_id platform, cl_platform_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetPlatformInfo; + MS_ASSERT(func != nullptr); + return func(platform, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clGetDeviceIDs wrapper, use OpenCLWrapper function. +cl_int clGetDeviceIDs(cl_platform_id platform, cl_device_type device_type, cl_uint num_entries, cl_device_id *devices, + cl_uint *num_devices) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetDeviceIDs; + MS_ASSERT(func != nullptr); + return func(platform, device_type, num_entries, devices, num_devices); +} + +// clGetDeviceInfo wrapper, use OpenCLWrapper function. +cl_int clGetDeviceInfo(cl_device_id device, cl_device_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetDeviceInfo; + MS_ASSERT(func != nullptr); + return func(device, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clCreateContext wrapper, use OpenCLWrapper function. +cl_context clCreateContext(const cl_context_properties *properties, cl_uint num_devices, const cl_device_id *devices, + void(CL_CALLBACK *pfn_notify)(const char *, const void *, size_t, void *), void *user_data, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateContext; + MS_ASSERT(func != nullptr); + return func(properties, num_devices, devices, pfn_notify, user_data, errcode_ret); +} + +// clCreateContextFromType wrapper, use OpenCLWrapper function. +cl_context clCreateContextFromType(const cl_context_properties *properties, cl_device_type device_type, + void(CL_CALLBACK *pfn_notify)(const char *, const void *, size_t, void *), + void *user_data, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateContextFromType; + MS_ASSERT(func != nullptr); + return func(properties, device_type, pfn_notify, user_data, errcode_ret); +} + +// clRetainContext wrapper, use OpenCLWrapper function. +cl_int clRetainContext(cl_context context) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainContext; + MS_ASSERT(func != nullptr); + return func(context); +} + +// clReleaseContext wrapper, use OpenCLWrapper function. +cl_int clReleaseContext(cl_context context) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseContext; + MS_ASSERT(func != nullptr); + return func(context); +} + +// clGetContextInfo wrapper, use OpenCLWrapper function. +cl_int clGetContextInfo(cl_context context, cl_context_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetContextInfo; + MS_ASSERT(func != nullptr); + return func(context, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clCreateProgramWithSource wrapper, use OpenCLWrapper function. +cl_program clCreateProgramWithSource(cl_context context, cl_uint count, const char **strings, const size_t *lengths, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateProgramWithSource; + MS_ASSERT(func != nullptr); + return func(context, count, strings, lengths, errcode_ret); +} + +// clGetProgramInfo wrapper, use OpenCLWrapper function. +cl_int clGetProgramInfo(cl_program program, cl_program_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetProgramInfo; + MS_ASSERT(func != nullptr); + return func(program, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clGetProgramBuildInfo wrapper, use OpenCLWrapper function. +cl_int clGetProgramBuildInfo(cl_program program, cl_device_id device, cl_program_build_info param_name, + size_t param_value_size, void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetProgramBuildInfo; + MS_ASSERT(func != nullptr); + return func(program, device, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clRetainProgram wrapper, use OpenCLWrapper function. +cl_int clRetainProgram(cl_program program) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainProgram; + MS_ASSERT(func != nullptr); + return func(program); +} + +// clReleaseProgram wrapper, use OpenCLWrapper function. +cl_int clReleaseProgram(cl_program program) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseProgram; + MS_ASSERT(func != nullptr); + return func(program); +} + +// clBuildProgram wrapper, use OpenCLWrapper function. +cl_int clBuildProgram(cl_program program, cl_uint num_devices, const cl_device_id *device_list, const char *options, + void(CL_CALLBACK *pfn_notify)(cl_program program, void *user_data), void *user_data) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clBuildProgram; + MS_ASSERT(func != nullptr); + return func(program, num_devices, device_list, options, pfn_notify, user_data); +} + +// clCreateKernel wrapper, use OpenCLWrapper function. +cl_kernel clCreateKernel(cl_program program, const char *kernelName, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateKernel; + MS_ASSERT(func != nullptr); + return func(program, kernelName, errcode_ret); +} + +// clRetainKernel wrapper, use OpenCLWrapper function. +cl_int clRetainKernel(cl_kernel kernel) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainKernel; + MS_ASSERT(func != nullptr); + return func(kernel); +} + +// clReleaseKernel wrapper, use OpenCLWrapper function. +cl_int clReleaseKernel(cl_kernel kernel) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseKernel; + MS_ASSERT(func != nullptr); + return func(kernel); +} + +// clSetKernelArg wrapper, use OpenCLWrapper function. +cl_int clSetKernelArg(cl_kernel kernel, cl_uint arg_index, size_t arg_size, const void *arg_value) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSetKernelArg; + MS_ASSERT(func != nullptr); + return func(kernel, arg_index, arg_size, arg_value); +} + +// clCreateBuffer wrapper, use OpenCLWrapper function. +cl_mem clCreateBuffer(cl_context context, cl_mem_flags flags, size_t size, void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateBuffer; + MS_ASSERT(func != nullptr); + return func(context, flags, size, host_ptr, errcode_ret); +} + +// clRetainMemObject wrapper, use OpenCLWrapper function. +cl_int clRetainMemObject(cl_mem memobj) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainMemObject; + MS_ASSERT(func != nullptr); + return func(memobj); +} + +// clReleaseMemObject wrapper, use OpenCLWrapper function. +cl_int clReleaseMemObject(cl_mem memobj) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseMemObject; + MS_ASSERT(func != nullptr); + return func(memobj); +} + +// clGetImageInfo wrapper, use OpenCLWrapper function. +cl_int clGetImageInfo(cl_mem image, cl_image_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetImageInfo; + MS_ASSERT(func != nullptr); + return func(image, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clRetainCommandQueue wrapper, use OpenCLWrapper function. +cl_int clRetainCommandQueue(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainCommandQueue; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clReleaseCommandQueue wrapper, use OpenCLWrapper function. +cl_int clReleaseCommandQueue(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseCommandQueue; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clEnqueueReadBuffer wrapper, use OpenCLWrapper function. +cl_int clEnqueueReadBuffer(cl_command_queue command_queue, cl_mem buffer, cl_bool blocking_read, size_t offset, + size_t size, void *ptr, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, + cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueReadBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, buffer, blocking_read, offset, size, ptr, num_events_in_wait_list, event_wait_list, event); +} + +// clEnqueueWriteBuffer wrapper, use OpenCLWrapper function. +cl_int clEnqueueWriteBuffer(cl_command_queue command_queue, cl_mem buffer, cl_bool blocking_write, size_t offset, + size_t size, const void *ptr, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueWriteBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, buffer, blocking_write, offset, size, ptr, num_events_in_wait_list, event_wait_list, + event); +} + +// clEnqueueWriteImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueWriteImage(cl_command_queue command_queue, cl_mem image, cl_bool blocking_write, const size_t *origin, + const size_t *region, size_t input_row_pitch, size_t input_slice_pitch, const void *ptr, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueWriteImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, blocking_write, origin, region, input_row_pitch, input_slice_pitch, ptr, + num_events_in_wait_list, event_wait_list, event); +} + +// clEnqueueReadImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueReadImage(cl_command_queue command_queue, cl_mem image, cl_bool blocking_read, const size_t *origin, + const size_t *region, size_t row_pitch, size_t slice_pitch, void *ptr, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueReadImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, blocking_read, origin, region, row_pitch, slice_pitch, ptr, num_events_in_wait_list, + event_wait_list, event); +} + +// clEnqueueMapBuffer wrapper, use OpenCLWrapper function. +void *clEnqueueMapBuffer(cl_command_queue command_queue, cl_mem buffer, cl_bool blocking_map, cl_map_flags map_flags, + size_t offset, size_t size, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, + cl_event *event, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueMapBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, buffer, blocking_map, map_flags, offset, size, num_events_in_wait_list, event_wait_list, + event, errcode_ret); +} + +// clEnqueueMapImage wrapper, use OpenCLWrapper function. +void *clEnqueueMapImage(cl_command_queue command_queue, cl_mem image, cl_bool blocking_map, cl_map_flags map_flags, + const size_t *origin, const size_t *region, size_t *image_row_pitch, size_t *image_slice_pitch, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueMapImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, blocking_map, map_flags, origin, region, image_row_pitch, image_slice_pitch, + num_events_in_wait_list, event_wait_list, event, errcode_ret); +} + +// clEnqueueUnmapMemObject wrapper, use OpenCLWrapper function. +cl_int clEnqueueUnmapMemObject(cl_command_queue command_queue, cl_mem memobj, void *mapped_ptr, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueUnmapMemObject; + MS_ASSERT(func != nullptr); + return func(command_queue, memobj, mapped_ptr, num_events_in_wait_list, event_wait_list, event); +} + +// clGetKernelWorkGroupInfo wrapper, use OpenCLWrapper function. +cl_int clGetKernelWorkGroupInfo(cl_kernel kernel, cl_device_id device, cl_kernel_work_group_info param_name, + size_t param_value_size, void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetKernelWorkGroupInfo; + MS_ASSERT(func != nullptr); + return func(kernel, device, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clGetEventProfilingInfo wrapper, use OpenCLWrapper function. +cl_int clGetEventProfilingInfo(cl_event event, cl_profiling_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetEventProfilingInfo; + MS_ASSERT(func != nullptr); + return func(event, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clEnqueueNDRangeKernel wrapper, use OpenCLWrapper function. +cl_int clEnqueueNDRangeKernel(cl_command_queue command_queue, cl_kernel kernel, cl_uint work_dim, + const size_t *global_work_offset, const size_t *global_work_size, + const size_t *local_work_size, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueNDRangeKernel; + MS_ASSERT(func != nullptr); + return func(command_queue, kernel, work_dim, global_work_offset, global_work_size, local_work_size, + num_events_in_wait_list, event_wait_list, event); +} + +// clWaitForEvents wrapper, use OpenCLWrapper function. +cl_int clWaitForEvents(cl_uint num_events, const cl_event *event_list) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clWaitForEvents; + MS_ASSERT(func != nullptr); + return func(num_events, event_list); +} + +// clRetainEvent wrapper, use OpenCLWrapper function. +cl_int clRetainEvent(cl_event event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainEvent; + MS_ASSERT(func != nullptr); + return func(event); +} + +// clReleaseEvent wrapper, use OpenCLWrapper function. +cl_int clReleaseEvent(cl_event event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseEvent; + MS_ASSERT(func != nullptr); + return func(event); +} + +// clGetEventInfo wrapper, use OpenCLWrapper function. +cl_int clGetEventInfo(cl_event event, cl_event_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetEventInfo; + MS_ASSERT(func != nullptr); + return func(event, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clFlush wrapper, use OpenCLWrapper function. +cl_int clFlush(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clFlush; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clFinish wrapper, use OpenCLWrapper function. +cl_int clFinish(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clFinish; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clCreateImage2D wrapper, use OpenCLWrapper function. +cl_mem clCreateImage2D(cl_context context, cl_mem_flags flags, const cl_image_format *image_format, size_t imageWidth, + size_t imageHeight, size_t image_row_pitch, void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateImage2D; + MS_ASSERT(func != nullptr); + return func(context, flags, image_format, imageWidth, imageHeight, image_row_pitch, host_ptr, errcode_ret); +} + +// clCreateImage3D wrapper, use OpenCLWrapper function. +cl_mem clCreateImage3D(cl_context context, cl_mem_flags flags, const cl_image_format *image_format, size_t imageWidth, + size_t imageHeight, size_t imageDepth, size_t image_row_pitch, size_t image_slice_pitch, + void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateImage3D; + MS_ASSERT(func != nullptr); + return func(context, flags, image_format, imageWidth, imageHeight, imageDepth, image_row_pitch, image_slice_pitch, + host_ptr, errcode_ret); +} + +// clCreateCommandQueue wrapper, use OpenCLWrapper function. +cl_command_queue clCreateCommandQueue(cl_context context, cl_device_id device, cl_command_queue_properties properties, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateCommandQueue; + MS_ASSERT(func != nullptr); + return func(context, device, properties, errcode_ret); +} + +// clGetCommandQueueInfo wrapper, use OpenCLWrapper function. +cl_int clGetCommandQueueInfo(cl_command_queue command_queue, cl_command_queue_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetCommandQueueInfo; + MS_ASSERT(func != nullptr); + return func(command_queue, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clEnqueueCopyImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueCopyImage(cl_command_queue queue, cl_mem src_image, cl_mem dst_image, const size_t *src_origin, + const size_t *dst_origin, const size_t *region, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueCopyImage; + MS_ASSERT(func != nullptr); + return func(queue, src_image, dst_image, src_origin, dst_origin, region, num_events_in_wait_list, event_wait_list, + event); +} + +// clEnqueueCopyBufferToImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueCopyBufferToImage(cl_command_queue command_queue, cl_mem src_buffer, cl_mem dst_image, + size_t src_offset, const size_t *dst_origin, const size_t *region, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueCopyBufferToImage; + MS_ASSERT(func != nullptr); + return func(command_queue, src_buffer, dst_image, src_offset, dst_origin, region, num_events_in_wait_list, + event_wait_list, event); +} + +// clEnqueueCopyImageToBuffer wrapper, use OpenCLWrapper function. +cl_int clEnqueueCopyImageToBuffer(cl_command_queue command_queue, cl_mem src_image, cl_mem dst_buffer, + const size_t *src_origin, const size_t *region, size_t dst_offset, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueCopyImageToBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, src_image, dst_buffer, src_origin, region, dst_offset, num_events_in_wait_list, + event_wait_list, event); +} + +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + +// clRetainDevice wrapper, use OpenCLWrapper function. +cl_int clRetainDevice(cl_device_id device) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainDevice; + MS_ASSERT(func != nullptr); + return func(device); +} + +// clReleaseDevice wrapper, use OpenCLWrapper function. +cl_int clReleaseDevice(cl_device_id device) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseDevice; + MS_ASSERT(func != nullptr); + return func(device); +} + +// clCreateImage wrapper, use OpenCLWrapper function. +cl_mem clCreateImage(cl_context context, cl_mem_flags flags, const cl_image_format *image_format, + const cl_image_desc *image_desc, void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateImage; + MS_ASSERT(func != nullptr); + return func(context, flags, image_format, image_desc, host_ptr, errcode_ret); +} + +#endif + +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 +#if 0 +// clGetKernelSubGroupInfoKHR wrapper, use OpenCLWrapper function. +cl_int clGetKernelSubGroupInfoKHR(cl_kernel kernel, cl_device_id device, cl_kernel_sub_group_info param_name, + size_t input_value_size, const void *input_value, size_t param_value_size, + void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetKernelSubGroupInfoKHR; + MS_ASSERT(func != nullptr); + return func(kernel, device, param_name, input_value_size, input_value, param_value_size, param_value, + param_value_size_ret); +} +#endif + +// clCreateCommandQueueWithProperties wrapper, use OpenCLWrapper function. +cl_command_queue clCreateCommandQueueWithProperties(cl_context context, cl_device_id device, + const cl_queue_properties *properties, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateCommandQueueWithProperties; + MS_ASSERT(func != nullptr); + return func(context, device, properties, errcode_ret); +} + +// clGetExtensionFunctionAddress wrapper, use OpenCLWrapper function. +void *clGetExtensionFunctionAddress(const char *func_name) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetExtensionFunctionAddress; + MS_ASSERT(func != nullptr); + return func(func_name); +} +// clCreateProgramWithIL wrapper, use OpenCLWrapper function. +cl_program clCreateProgramWithIL(cl_context context, const void *il, size_t length, cl_int *ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateProgramWithIL; + MS_ASSERT(func != nullptr); + return func(context, il, length, ret); +} + +// clSVMAlloc wrapper, use OpenCLWrapper function. +void *clSVMAlloc(cl_context context, cl_mem_flags flags, size_t size, cl_uint align) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSVMAlloc; + MS_ASSERT(func != nullptr); + return func(context, flags, size, align); +} + +// clSVMFree wrapper, use OpenCLWrapper function. +void clSVMFree(cl_context context, void *buffer) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSVMFree; + MS_ASSERT(func != nullptr); + func(context, buffer); +} + +// clEnqueueSVMMap wrapper, use OpenCLWrapper function. +cl_int clEnqueueSVMMap(cl_command_queue command_queue, cl_bool blocking, cl_map_flags flags, void *host_ptr, + size_t size, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueSVMMap; + MS_ASSERT(func != nullptr); + return func(command_queue, blocking, flags, host_ptr, size, num_events_in_wait_list, event_wait_list, event); +} + +// clEnqueueSVMUnmap wrapper, use OpenCLWrapper function. +cl_int clEnqueueSVMUnmap(cl_command_queue command_queue, void *host_ptr, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueSVMUnmap; + MS_ASSERT(func != nullptr); + return func(command_queue, host_ptr, num_events_in_wait_list, event_wait_list, event); +} + +// clSetKernelArgSVMPointer wrapper, use OpenCLWrapper function. +cl_int clSetKernelArgSVMPointer(cl_kernel kernel, cl_uint index, const void *host_ptr) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSetKernelArgSVMPointer; + MS_ASSERT(func != nullptr); + return func(kernel, index, host_ptr); +} +#endif + +#endif // USE_OPENCL_WRAPPER + diff --git a/mindspore/lite/src/runtime/opencl/opencl_wrapper.h b/mindspore/lite/src/runtime/opencl/opencl_wrapper.h new file mode 100644 index 0000000000..d4f0d98f9a --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_wrapper.h @@ -0,0 +1,240 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_WRAPPER_H_ +#define MINDSPORE_LITE_SRC_OPENCL_WRAPPER_H_ + +#include +#include +#include + +// support opencl min version is 1.1 +#ifndef CL_TARGET_OPENCL_VERSION +#define CL_TARGET_OPENCL_VERSION 210 +#endif +#ifndef CL_HPP_TARGET_OPENCL_VERSION +#define CL_HPP_TARGET_OPENCL_VERSION 210 +#endif +#ifndef CL_HPP_MINIMUM_OPENCL_VERSION +#define CL_HPP_MINIMUM_OPENCL_VERSION 110 +#endif + +#include "CL/cl2.hpp" + +#ifdef USE_OPENCL_WRAPPER + +namespace mindspore::lite::opencl { + +// This is a opencl function wrapper. +class OpenCLWrapper { + public: + static OpenCLWrapper *GetInstance(); + + ~OpenCLWrapper(); + OpenCLWrapper(const OpenCLWrapper &) = delete; + OpenCLWrapper &operator=(const OpenCLWrapper &) = delete; + + bool LoadOpenCLLibrary(); + bool UnLoadOpenCLLibrary(); + // get platfrom id + using clGetPlatformIDsFunc = cl_int (*)(cl_uint, cl_platform_id *, cl_uint *); + // get platform info + using clGetPlatformInfoFunc = cl_int (*)(cl_platform_id, cl_platform_info, size_t, void *, size_t *); + // build program + using clBuildProgramFunc = cl_int (*)(cl_program, cl_uint, const cl_device_id *, const char *, + void (*pfn_notify)(cl_program, void *), void *); + // enqueue run kernel + using clEnqueueNDRangeKernelFunc = cl_int (*)(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, + const size_t *, cl_uint, const cl_event *, cl_event *); + // set kernel parameter + using clSetKernelArgFunc = cl_int (*)(cl_kernel, cl_uint, size_t, const void *); + using clRetainMemObjectFunc = cl_int (*)(cl_mem); + using clReleaseMemObjectFunc = cl_int (*)(cl_mem); + using clEnqueueUnmapMemObjectFunc = cl_int (*)(cl_command_queue, cl_mem, void *, cl_uint, const cl_event *, + cl_event *); + using clRetainCommandQueueFunc = cl_int (*)(cl_command_queue command_queue); + // create context + using clCreateContextFunc = cl_context (*)(const cl_context_properties *, cl_uint, const cl_device_id *, + void(CL_CALLBACK *)( // NOLINT(readability/casting) + const char *, const void *, size_t, void *), + void *, cl_int *); + using clEnqueueCopyImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_mem, const size_t *, const size_t *, + const size_t *, cl_uint, const cl_event *, cl_event *); + + using clCreateContextFromTypeFunc = cl_context (*)(const cl_context_properties *, cl_device_type, + void(CL_CALLBACK *)( // NOLINT(readability/casting) + const char *, const void *, size_t, void *), + void *, cl_int *); + using clReleaseContextFunc = cl_int (*)(cl_context); + using clWaitForEventsFunc = cl_int (*)(cl_uint, const cl_event *); + using clReleaseEventFunc = cl_int (*)(cl_event); + using clEnqueueWriteBufferFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, size_t, size_t, const void *, cl_uint, + const cl_event *, cl_event *); + using clEnqueueWriteImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, const size_t *, const size_t *, size_t, + size_t, const void *, cl_uint, const cl_event *, cl_event *); + using clEnqueueReadImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, const size_t *, const size_t *, size_t, + size_t, void *, cl_uint, const cl_event *, cl_event *); + using clEnqueueReadBufferFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, size_t, size_t, void *, cl_uint, + const cl_event *, cl_event *); + using clGetProgramBuildInfoFunc = cl_int (*)(cl_program, cl_device_id, cl_program_build_info, size_t, void *, + size_t *); + using clRetainProgramFunc = cl_int (*)(cl_program program); + using clEnqueueMapBufferFunc = void *(*)(cl_command_queue, cl_mem, cl_bool, cl_map_flags, size_t, size_t, cl_uint, + const cl_event *, cl_event *, cl_int *); + using clEnqueueMapImageFunc = void *(*)(cl_command_queue, cl_mem, cl_bool, cl_map_flags, const size_t *, + const size_t *, size_t *, size_t *, cl_uint, const cl_event *, cl_event *, + cl_int *); + using clCreateCommandQueueFunc = cl_command_queue(CL_API_CALL *)(cl_context, cl_device_id, + cl_command_queue_properties, cl_int *); + using clGetCommandQueueInfoFunc = cl_int (*)(cl_command_queue, cl_command_queue_info, size_t, void *, size_t *); + using clReleaseCommandQueueFunc = cl_int (*)(cl_command_queue); + using clCreateProgramWithBinaryFunc = cl_program (*)(cl_context, cl_uint, const cl_device_id *, const size_t *, + const unsigned char **, cl_int *, cl_int *); + using clRetainContextFunc = cl_int (*)(cl_context context); + using clGetContextInfoFunc = cl_int (*)(cl_context, cl_context_info, size_t, void *, size_t *); + using clReleaseProgramFunc = cl_int (*)(cl_program program); + using clFlushFunc = cl_int (*)(cl_command_queue command_queue); + using clFinishFunc = cl_int (*)(cl_command_queue command_queue); + using clGetProgramInfoFunc = cl_int (*)(cl_program, cl_program_info, size_t, void *, size_t *); + using clCreateKernelFunc = cl_kernel (*)(cl_program, const char *, cl_int *); + using clRetainKernelFunc = cl_int (*)(cl_kernel kernel); + using clCreateBufferFunc = cl_mem (*)(cl_context, cl_mem_flags, size_t, void *, cl_int *); + using clCreateImage2DFunc = cl_mem(CL_API_CALL *)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, + size_t, void *, cl_int *); + using clCreateImage3DFunc = cl_mem(CL_API_CALL *)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, + size_t, size_t, size_t, void *, cl_int *); + using clCreateProgramWithSourceFunc = cl_program (*)(cl_context, cl_uint, const char **, const size_t *, cl_int *); + using clReleaseKernelFunc = cl_int (*)(cl_kernel kernel); + using clGetDeviceInfoFunc = cl_int (*)(cl_device_id, cl_device_info, size_t, void *, size_t *); + using clGetDeviceIDsFunc = cl_int (*)(cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *); + using clRetainEventFunc = cl_int (*)(cl_event); + using clGetKernelWorkGroupInfoFunc = cl_int (*)(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, + size_t *); + using clGetEventInfoFunc = cl_int (*)(cl_event event, cl_event_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret); + using clGetEventProfilingInfoFunc = cl_int (*)(cl_event event, cl_profiling_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret); + using clGetImageInfoFunc = cl_int (*)(cl_mem, cl_image_info, size_t, void *, size_t *); + using clEnqueueCopyBufferToImageFunc = cl_int(CL_API_CALL *)(cl_command_queue, cl_mem, cl_mem, size_t, const size_t *, + const size_t *, cl_uint, const cl_event *, cl_event *); + using clEnqueueCopyImageToBufferFunc = cl_int(CL_API_CALL *)(cl_command_queue, cl_mem, cl_mem, const size_t *, + const size_t *, size_t, cl_uint, const cl_event *, + cl_event *); +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + using clRetainDeviceFunc = cl_int (*)(cl_device_id); + using clReleaseDeviceFunc = cl_int (*)(cl_device_id); + using clCreateImageFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *, + cl_int *); +#endif +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 + using clCreateProgramWithILFunc = cl_program (*)(cl_context, const void *, size_t, cl_int *); + using clSVMAllocFunc = void *(*)(cl_context, cl_mem_flags, size_t size, cl_uint); + using clSVMFreeFunc = void (*)(cl_context, void *); + using clEnqueueSVMMapFunc = cl_int (*)(cl_command_queue, cl_bool, cl_map_flags, void *, size_t, cl_uint, + const cl_event *, cl_event *); + using clEnqueueSVMUnmapFunc = cl_int (*)(cl_command_queue, void *, cl_uint, const cl_event *, cl_event *); + using clSetKernelArgSVMPointerFunc = cl_int (*)(cl_kernel, cl_uint, const void *); + // opencl 2.0 can get sub group info and wave size. + using clGetKernelSubGroupInfoKHRFunc = cl_int(CL_API_CALL *)(cl_kernel, cl_device_id, cl_kernel_sub_group_info, + size_t, const void *, size_t, void *, size_t *); + using clCreateCommandQueueWithPropertiesFunc = cl_command_queue(CL_API_CALL *)(cl_context, cl_device_id, + const cl_queue_properties *, cl_int *); + using clGetExtensionFunctionAddressFunc = void *(CL_API_CALL *)(const char *); +#endif + +#define CL_DEFINE_FUNC_PTR(func) func##Func func = nullptr + + CL_DEFINE_FUNC_PTR(clGetPlatformIDs); + CL_DEFINE_FUNC_PTR(clGetPlatformInfo); + CL_DEFINE_FUNC_PTR(clBuildProgram); + CL_DEFINE_FUNC_PTR(clEnqueueNDRangeKernel); + CL_DEFINE_FUNC_PTR(clSetKernelArg); + CL_DEFINE_FUNC_PTR(clReleaseKernel); + CL_DEFINE_FUNC_PTR(clCreateProgramWithSource); + CL_DEFINE_FUNC_PTR(clCreateBuffer); + CL_DEFINE_FUNC_PTR(clCreateImage2D); + CL_DEFINE_FUNC_PTR(clCreateImage3D); + CL_DEFINE_FUNC_PTR(clRetainKernel); + CL_DEFINE_FUNC_PTR(clCreateKernel); + CL_DEFINE_FUNC_PTR(clGetProgramInfo); + CL_DEFINE_FUNC_PTR(clFlush); + CL_DEFINE_FUNC_PTR(clFinish); + CL_DEFINE_FUNC_PTR(clReleaseProgram); + CL_DEFINE_FUNC_PTR(clRetainContext); + CL_DEFINE_FUNC_PTR(clGetContextInfo); + CL_DEFINE_FUNC_PTR(clCreateProgramWithBinary); + CL_DEFINE_FUNC_PTR(clCreateCommandQueue); + CL_DEFINE_FUNC_PTR(clGetCommandQueueInfo); + CL_DEFINE_FUNC_PTR(clReleaseCommandQueue); + CL_DEFINE_FUNC_PTR(clEnqueueMapBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueMapImage); + CL_DEFINE_FUNC_PTR(clEnqueueCopyImage); + CL_DEFINE_FUNC_PTR(clRetainProgram); + CL_DEFINE_FUNC_PTR(clGetProgramBuildInfo); + CL_DEFINE_FUNC_PTR(clEnqueueReadBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueWriteBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueWriteImage); + CL_DEFINE_FUNC_PTR(clEnqueueReadImage); + CL_DEFINE_FUNC_PTR(clWaitForEvents); + CL_DEFINE_FUNC_PTR(clReleaseEvent); + CL_DEFINE_FUNC_PTR(clCreateContext); + CL_DEFINE_FUNC_PTR(clCreateContextFromType); + CL_DEFINE_FUNC_PTR(clReleaseContext); + CL_DEFINE_FUNC_PTR(clRetainCommandQueue); + CL_DEFINE_FUNC_PTR(clEnqueueUnmapMemObject); + CL_DEFINE_FUNC_PTR(clRetainMemObject); + CL_DEFINE_FUNC_PTR(clReleaseMemObject); + CL_DEFINE_FUNC_PTR(clGetDeviceInfo); + CL_DEFINE_FUNC_PTR(clGetDeviceIDs); + CL_DEFINE_FUNC_PTR(clRetainEvent); + CL_DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo); + CL_DEFINE_FUNC_PTR(clGetEventInfo); + CL_DEFINE_FUNC_PTR(clGetEventProfilingInfo); + CL_DEFINE_FUNC_PTR(clGetImageInfo); + CL_DEFINE_FUNC_PTR(clEnqueueCopyBufferToImage); + CL_DEFINE_FUNC_PTR(clEnqueueCopyImageToBuffer); +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + CL_DEFINE_FUNC_PTR(clRetainDevice); + CL_DEFINE_FUNC_PTR(clReleaseDevice); + CL_DEFINE_FUNC_PTR(clCreateImage); +#endif +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 + CL_DEFINE_FUNC_PTR(clGetKernelSubGroupInfoKHR); + CL_DEFINE_FUNC_PTR(clCreateCommandQueueWithProperties); + CL_DEFINE_FUNC_PTR(clGetExtensionFunctionAddress); + CL_DEFINE_FUNC_PTR(clCreateProgramWithIL); + CL_DEFINE_FUNC_PTR(clSVMAlloc); + CL_DEFINE_FUNC_PTR(clSVMFree); + CL_DEFINE_FUNC_PTR(clEnqueueSVMMap); + CL_DEFINE_FUNC_PTR(clEnqueueSVMUnmap); + CL_DEFINE_FUNC_PTR(clSetKernelArgSVMPointer); +#endif + +#undef TNN_CL_DEFINE_FUNC_PTR + + private: + OpenCLWrapper(); + bool LoadLibraryFromPath(const std::string &path); + + private: + static std::shared_ptr opencl_wrapper_singleton_; + void *handle_ = nullptr; +}; + +} // namespace mindspore::lite::opencl +#endif // USE_OPENCL_WRAPPER +#endif // MINDSPORE_LITE_SRC_OPENCL_WRAPPER_H_ + diff --git a/mindspore/lite/src/runtime/runtime_api.cc b/mindspore/lite/src/runtime/runtime_api.cc new file mode 100644 index 0000000000..460ae4b07a --- /dev/null +++ b/mindspore/lite/src/runtime/runtime_api.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/runtime_api.h" +#include "src/runtime/workspace_pool.h" +#include "src/runtime/thread_pool.h" +#include "utils/log_adapter.h" + +static std::mutex gWorkspaceMutex; +#ifdef __cplusplus +extern "C" { +#endif +void LiteAPISetLastError(const char *msg) { + MS_LOG(ERROR) << "The lite api set last error is " << msg; +} + +void *LiteBackendAllocWorkspace(int deviceType, + int deviceId, + uint64_t size, + int dtypeCode, + int dtypeBits) { + std::lock_guard lock(gWorkspaceMutex); + auto p = mindspore::predict::WorkspacePool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return nullptr; + } + return p->AllocWorkSpaceMem(size); +} + +int LiteBackendFreeWorkspace(int deviceType, int deviceId, void *ptr) { + std::lock_guard lock(gWorkspaceMutex); + auto p = mindspore::predict::WorkspacePool::GetInstance(); + if (p == nullptr) { + return -1; + } + p->FreeWorkSpaceMem(ptr); + return 0; +} + +void SetMaxWokerNum(int num) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return; + } + if (num < 0) { + LiteAPISetLastError("The number of work thread is less than 0"); + return; + } + p->ConfigMaxThreadNum(num); +} + +void ConfigThreadPool(int mode, int nthreads) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return; + } + p->ConfigThreadPool(mode, nthreads); +} + +int LiteBackendParallelLaunch(FTVMParallelLambda flambda, void *cdata, int num_task) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return -1; + } + if (!p->LaunchWork(flambda, cdata, num_task)) { + MS_LOG(ERROR) << "launch thread pool work failed"; + return -1; + } + return 0; +} + +void DoAllThreadBind(bool ifBind, int mode) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return; + } + if (!p->BindAllThreads(ifBind, mode)) { + MS_LOG(ERROR) << "do thread cpu bind failed"; + } +} + +#ifdef __cplusplus +} +#endif + diff --git a/mindspore/lite/src/runtime/runtime_api.h b/mindspore/lite/src/runtime/runtime_api.h new file mode 100644 index 0000000000..cd3942d79e --- /dev/null +++ b/mindspore/lite/src/runtime/runtime_api.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_ +#include + +#ifndef INTERNAL_API_DLL +#ifdef _WIN32 +#ifdef LITE_EXPORTS +#define INTERNAL_API_DLL __declspec(dllexport) +#else +#define INTERNAL_API_DLL __declspec(dllimport) +#endif +#else +#define INTERNAL_API_DLL __attribute__((visibility("default"))) +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct { + void *sync_handle; + int32_t num_task; +} LiteParallelGroupEnv; +typedef int (*FTVMParallelLambda)(int task_id, LiteParallelGroupEnv *penv, void *cdata); +INTERNAL_API_DLL void LiteAPISetLastError(const char *msg); +INTERNAL_API_DLL void *LiteBackendAllocWorkspace(int deviceType, int deviceId, uint64_t size, int dtypeCode, + int dtypeBits); +INTERNAL_API_DLL int LiteBackendFreeWorkspace(int deviceType, int deviceId, void *ptr); +INTERNAL_API_DLL void SetMaxWokerNum(int num); +INTERNAL_API_DLL void ConfigThreadPool(int mode, int nthreads); +INTERNAL_API_DLL inline void CfgThreadPool(int nthread) { ConfigThreadPool(-1, nthread); } +INTERNAL_API_DLL int LiteBackendParallelLaunch(FTVMParallelLambda flambda, void *cdata, int num_task); +INTERNAL_API_DLL int LiteBackendRegisterSystemLibSymbol(const char *name, void *ptr); +INTERNAL_API_DLL void DoAllThreadBind(bool ifBind, int mode); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_ + diff --git a/mindspore/lite/src/runtime/thread_pool.cc b/mindspore/lite/src/runtime/thread_pool.cc new file mode 100644 index 0000000000..933a49664b --- /dev/null +++ b/mindspore/lite/src/runtime/thread_pool.cc @@ -0,0 +1,456 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/thread_pool.h" +#include +#include "utils/log_adapter.h" +#ifdef MS_COMPILE_IOS +#include +#include +#include +#endif // MS_COMPILE_IOS + +namespace mindspore { +namespace predict { +constexpr int kDefaultBigCount = 2; +constexpr int kDefaultMidCount = 2; +constexpr int kSmallCpuNum = 4; +constexpr int kBigMidCpuNum = 4; +constexpr int kDefaultThreadNum = 1; +static unsigned int kDefaultMaxThreadNums = 8; +static unsigned int localMaxThreadNums = 1; + +bool LiteQueue::Enqueue(ThreadPoolTask *task) { + const int tailIndex = tail.load(std::memory_order_relaxed); + // queue full + auto next = (tailIndex + 1) % kSingleThreadMaxTask; + if (next == head.load(std::memory_order_acquire)) { + return false; + } + buffer[tailIndex] = task; + tail.store(next, std::memory_order_release); + ++taskSize; + return true; +} + +bool LiteQueue::Dequeue(ThreadPoolTask **out) { + if (taskSize == 0) { + return false; + } + // queue empty + const int headIndex = head.load(std::memory_order_relaxed); + if (headIndex == tail.load(std::memory_order_acquire)) { + return false; + } + *out = buffer[headIndex]; + head.store((headIndex + 1) % kSingleThreadMaxTask, std::memory_order_release); + return true; +} + +bool LiteThreadBind::Bind(bool ifBind, int numThreads, bool master) { + if (master) { + if (!BindMasterThread(ifBind, bindModel)) { + MS_LOG(ERROR) << "bind msater thread failed"; + return false; + } + MS_LOG(DEBUG) << "bind master thread successful"; + } + if (numThreads > static_cast(sortedCpuIds.size())) { + MS_LOG(ERROR) << "thread num " << numThreads << " is larger than cores " << static_cast(sortedCpuIds.size()) + << " in the system"; + return true; + } + + if (!BindThreads(ifBind)) { + MS_LOG(ERROR) << "action " << ifBind << " thread failed"; + return false; + } + MS_LOG(DEBUG) << "action " << ifBind << " thread successful"; + return true; +} + +void LiteThreadBind::InitSortedCpuId() { + // mate10(970)|p20(970): 4big, 4small + // mate20(980)|p30(980)|mate30(990): 2big, 2mid, 4small + // note: p30's core 7 not allowed to be bind + int numCores = 0; +#ifdef MS_COMPILE_IOS + size_t len = sizeof(numCores); + sysctlbyname("hw.ncpu", &numCores, &len, NULL, 0); + numCores = numCores > 1 ? numCores : 1; +#else + numCores = static_cast(std::thread::hardware_concurrency()); +#endif // MS_COMPILE_IOS + if (numCores < kBigMidCpuNum) { + bigCore = 0; + midCore = numCores; + } else { + bigCore = kDefaultBigCount; + midCore = kDefaultMidCount; + } + sortedCpuIds.clear(); + for (int i = numCores - 1; i >= 0; --i) { + sortedCpuIds.emplace_back(i); + } + if (sortedCpuIds.size() > kSmallCpuNum) { + sortedCpuIds.resize(bigCore + midCore); + } +} + +bool LiteThreadBind::BindMasterThread(bool bindFlag, int mode) { + std::vector cpu; + if (bindFlag) { + size_t cpuIndex; + if (mode == MID_CORE) { + cpuIndex = sortedCpuIds.size() - 1; + } else { + cpuIndex = 0; + } + cpu.emplace_back(sortedCpuIds[cpuIndex]); + } else { + // unbind master + cpu.assign(sortedCpuIds.begin(), sortedCpuIds.end()); + } + cpu_set_t cpuSet; +#ifndef CPU_SET + (void)memset(&cpuSet, 0, sizeof(cpu_set_t)); +#else + CPU_ZERO(&cpuSet); +#endif + for (auto coreId : cpu) { +#ifndef CPU_SET + CPU_SET_LOCAL(coreId, &cpuSet); +#else + CPU_SET(coreId, &cpuSet); +#endif + } + if (!SetCPUBind(pthread_self(), &cpuSet)) { + MS_LOG(ERROR) << "do master bind failed. mode: " << mode; + return false; + } + return true; +} + +bool LiteThreadBind::BindThreads(bool bindFlag) { + if (bindFlag && bindModel != NO_BIND) { + size_t bindNums = std::min(sortedCpuIds.size(), threadIdList.size()); + cpu_set_t cpuSet; + size_t coreIndex; + for (size_t i = 0; i < bindNums; ++i) { +#ifndef CPU_SET + (void)memset(&cpuSet, 0, sizeof(cpu_set_t)); +#else + CPU_ZERO(&cpuSet); +#endif + if (bindModel == MID_CORE) { + coreIndex = sortedCpuIds.size() - 2 - i; + } else { + coreIndex = i + 1; + } +#ifndef CPU_SET + CPU_SET_LOCAL(sortedCpuIds[coreIndex], &cpuSet); +#else + CPU_SET(sortedCpuIds[coreIndex], &cpuSet); +#endif + if (!SetCPUBind(threadIdList[i], &cpuSet)) { + MS_LOG(ERROR) << "do SetCPUBind failed"; + return false; + } + } + } else { + // unbind + size_t bindNums = std::min(sortedCpuIds.size(), threadIdList.size()); + cpu_set_t cpuSet; +#ifndef CPU_SET + (void)memset(&cpuSet, 0, sizeof(cpu_set_t)); +#else + CPU_ZERO(&cpuSet); +#endif + for (auto coreId : sortedCpuIds) { +#ifndef CPU_SET + CPU_SET_LOCAL(coreId, &cpuSet); +#else + CPU_SET(coreId, &cpuSet); +#endif + } + for (size_t i = 0; i < bindNums; ++i) { + if (!SetCPUBind(threadIdList[i], &cpuSet)) { + MS_LOG(ERROR) << "do SetCPUBind failed"; + return false; + } + } + } + return true; +} + +bool LiteThreadBind::SetCPUBind(pthread_t threadId, cpu_set_t *cpuSet) { +#if defined(__ANDROID__) +#if __ANDROID_API__ >= 21 + int ret = sched_setaffinity(pthread_gettid_np(threadId), sizeof(cpu_set_t), cpuSet); + if (ret != 0) { + MS_LOG(ERROR) << "bind thread " << threadId << "to cpu failed.ERROR " << ret; + } +#endif +#else +#ifdef __APPLE__ + MS_LOG(ERROR) << "not bind thread to apple's cpu."; + return false; +#else + int ret = pthread_setaffinity_np(threadId, sizeof(cpuSet), cpuSet); + if (ret != 0) { + MS_LOG(ERROR) << "bind thread " << threadId << " to cpu failed.ERROR " << ret; + return false; + } +#endif // __APPLE__ +#endif + return true; +} + +bool ThreadPool::SetThreadPool() { + std::lock_guard Lock(poolMutex); + if (configThreadNums <= 0) { + MS_LOG(WARNING) << "numThreads " << configThreadNums << ", must be greater than 0"; + configThreadNums = curThreadRunNums; + } + if (localMaxThreadNums == 0) { + localMaxThreadNums = 1; + } else if (localMaxThreadNums > kDefaultMaxThreadNums) { + localMaxThreadNums = kDefaultMaxThreadNums; + } + if (configThreadNums > kDefaultMaxThreadNums) { + configThreadNums = kDefaultMaxThreadNums; + } + int addNum = 0; + if (configThreadNums > kDefaultMaxThreadNums) { + addNum = configThreadNums - curThreadRunNums; + } else if (localMaxThreadNums > curThreadNums) { + addNum = localMaxThreadNums - curThreadNums; + } + AddNewThread(addNum); + if (curThreadRunNums > localMaxThreadNums) { + SubRunThread(localMaxThreadNums); + } else { + AddRunThread(localMaxThreadNums); + } + MS_LOG(DEBUG) << "configThreadNums=" << configThreadNums << ", curThreadNums=" << curThreadNums + << ", curThreadRunNums=" << curThreadRunNums << ", localMaxThreadNums=" << localMaxThreadNums; + return true; +} + +void ThreadPool::AddNewThread(int newNums) { + for (int i = curThreadNums - 1, j = 0; j < newNums; ++i, ++j) { + auto active = new std::atomic_bool{true}; + auto queue = std::make_shared(); + threadList.emplace_back([this, i, active, queue]() { + ThreadPoolTask *task = nullptr; + while (!exitRun) { + while (*active) { + if (queue->Dequeue(&task)) { + auto ret = task->first(i + 1, task->second.tvmParam, task->second.cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(i + 1, std::make_pair(false, ret))); + } + queue->taskSize--; + } + std::this_thread::yield(); + } + std::unique_lock queueLock(tMutex); + queueReady.wait(queueLock, [active, this] { return exitRun || *active; }); + } + }); + activateList.emplace_back(active); + queueList.emplace_back(queue); + } + curThreadNums += newNums; + curThreadRunNums += newNums; + MS_LOG(DEBUG) << "add " << newNums << " thread"; +} + +bool ThreadPool::SetThreadCpuBind(bool ifBind, int mode, bool master) { + if (curThreadRunNums <= 0) { + MS_LOG(ERROR) << "no threads need to be bind, totalThreadNum : " << curThreadRunNums; + return false; + } + if (threadBind == nullptr) { + threadBind = std::unique_ptr(new LiteThreadBind()); + if (threadBind == nullptr) { + MS_LOG(ERROR) << "create threadBind failed"; + return false; + } + threadBind->threadIdList.resize(kDefaultMaxThreadNums); + threadBind->InitSortedCpuId(); + } + threadBind->threadIdList.clear(); + for (auto &it : threadList) { + threadBind->threadIdList.emplace_back(it.native_handle()); + } + threadBind->bindModel = static_cast(mode); + if (!threadBind->Bind(ifBind, curThreadRunNums, master)) { + MS_LOG(ERROR) << "bind failed"; + return false; + } + return true; +} + +bool ThreadPool::AddTask(WorkFun &&worker, void *cdata, int numTask) { + if (numTask <= 0) { + numTask = curThreadRunNums; + } + TvmEnv env{}; + env.num_task = numTask; + errorInfo.clear(); + // single task, run master thread + if (curThreadRunNums <= 1) { + for (int i = 0; i < numTask; ++i) { + int ret = worker(i, &env, cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(0, std::make_pair(false, ret))); + } + } + return CheckResult(); + } + ThreadPoolTask task; + task.first = std::move(worker); + task.second.cdata = cdata; + task.second.tvmParam = &env; + return DistributeTask(&task, numTask); +} + +bool ThreadPool::DistributeTask(ThreadPoolTask *task, int numTask) { + MS_LOG(DEBUG) << "numTask = " << numTask << ", curThreadRunNums = " << curThreadRunNums; + auto taskOri = *task; + if (numTask > curThreadRunNums) { + task->first = [taskOri, numTask, this](int task_id, TvmEnv *penv, void *cdata) -> int { + for (int i = task_id; i < numTask; i += curThreadRunNums) { + int ret = taskOri.first(i, penv, cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(i + 1, std::make_pair(false, ret))); + } + } + return 0; + }; + } + bool kSuccFlag; + auto size = std::min(curThreadRunNums, numTask); + for (int i = 0; i < size - 1; ++i) { + do { + kSuccFlag = true; + if (!queueList[i]->Enqueue(task)) { + std::this_thread::yield(); + kSuccFlag = false; + } + } while (!kSuccFlag); + } + // master thread + int ret = task->first(0, task->second.tvmParam, task->second.cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(0, std::make_pair(false, ret))); + } + kSuccFlag = false; + while (!kSuccFlag) { + std::this_thread::yield(); + kSuccFlag = true; + for (int i = 0; i < curThreadRunNums - 1; ++i) { + if (queueList[i]->taskSize != 0) { + kSuccFlag = false; + break; + } + } + } + MS_LOG(DEBUG) << "finish " << numTask << " task successful"; + return CheckResult(); +} + +void ThreadPool::AddRunThread(int num) { + MS_LOG(DEBUG) << "num=" << num << ", curThreadRunNums=" << curThreadRunNums; + int activeNums = num - curThreadRunNums; + if (activeNums <= 0 || activateList.size() < activeNums) { + return; + } + for (int i = curThreadRunNums - 1, j = 0; j < activeNums; ++i, ++j) { + *activateList[i] = true; + } + std::lock_guard queueLock(tMutex); + queueReady.notify_all(); + curThreadRunNums = num; +} + +void ThreadPool::SubRunThread(int num) { + MS_LOG(DEBUG) << "num=" << num << ", curThreadRunNums=" << curThreadRunNums; + int deactiveNums = curThreadRunNums - num; + if (deactiveNums <= 0) { + return; + } + for (int i = num - 1, j = 0; j < deactiveNums; ++i, ++j) { + *activateList[i] = false; + } + curThreadRunNums = num; +} + +bool ThreadPool::CheckResult() { + bool kSuccFlag = true; + for (auto result : errorInfo) { + if (result.second.first) { + MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second; + kSuccFlag = false; + } + } + return kSuccFlag; +} + +bool ThreadPool::LaunchWork(WorkFun worker, void *cdata, int numTask) { + if (!SetThreadPool()) { + return false; + } + return AddTask(std::move(worker), cdata, numTask); +} + +bool ThreadPool::BindAllThreads(bool ifBind, int mode, bool master) { + if (!SetThreadPool()) { + return false; + } + return SetThreadCpuBind(ifBind, mode, master); +} + +void ThreadPool::ConfigThreadPool(int mode, int numThreads) { + configBindMode = mode; + configThreadNums = numThreads; +} + +void ThreadPool::ConfigMaxThreadNum(unsigned int num) { localMaxThreadNums = num; } + +ThreadPool *ThreadPool::GetInstance() { + static ThreadPool instance; + return &instance; +} + +ThreadPool::~ThreadPool() { + curThreadRunNums = static_cast(threadList.size() + 1); + exitRun = true; + SubRunThread(kDefaultThreadNum); + queueReady.notify_all(); + for (auto &it : threadList) { + if (it.joinable()) { + it.join(); + } + } + for (const auto &it : activateList) { + delete it; + } +} +} // namespace predict +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/thread_pool.h b/mindspore/lite/src/runtime/thread_pool.h new file mode 100644 index 0000000000..f9a26bff65 --- /dev/null +++ b/mindspore/lite/src/runtime/thread_pool.h @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_THREAD_POOL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/runtime/runtime_api.h" + +namespace mindspore { +namespace predict { +#ifndef CPU_SET +const int CPU_SETSIZE = 1024; +#define __NCPUBITS (8 * sizeof(uint64_t)) +typedef struct { + uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; +} cpu_set_t; + +#define CPU_SET_LOCAL(cpu, cpusetp) ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#endif + +constexpr int kSingleThreadMaxTask = 2; +using TvmEnv = LiteParallelGroupEnv; +using WorkFun = std::function; +using TaskParam = struct Param { + void *cdata; + TvmEnv *tvmParam; +}; +using ThreadPoolTask = std::pair; +enum AffinityMode : int { BIG_CORE = 1, MID_CORE = -1, NO_BIND = 0 }; + +class LiteQueue { + public: + LiteQueue() = default; + ~LiteQueue() = default; + bool Enqueue(ThreadPoolTask *task); + bool Dequeue(ThreadPoolTask **out); + std::atomic_int taskSize = {0}; + + private: + std::atomic_int head = {0}; + std::atomic_int tail = {0}; + ThreadPoolTask *buffer[kSingleThreadMaxTask]{}; +}; + +class LiteThreadBind { + public: + LiteThreadBind() = default; + ~LiteThreadBind() = default; + void InitSortedCpuId(); + bool Bind(bool ifBind, int numThreads, bool master); + AffinityMode bindModel = MID_CORE; + std::vector threadIdList; + + private: + bool BindMasterThread(bool bindFlag, int mode); + bool BindThreads(bool bindFlag); + bool SetCPUBind(pthread_t threadId, cpu_set_t *cpuSet); + int bigCore = 0; + int midCore = 0; + std::vector sortedCpuIds{}; +}; + +class ThreadPool { + public: + ThreadPool() = default; + ~ThreadPool(); + static ThreadPool *GetInstance(); + bool LaunchWork(WorkFun worker, void *cdata, int numTask); + void ConfigThreadPool(int mode, int numThreads); + void ConfigMaxThreadNum(unsigned int num); + bool BindAllThreads(bool ifBind, int mode, bool master = true); + ThreadPool(const ThreadPool &) = delete; + ThreadPool &operator=(const ThreadPool &) = delete; + + private: + bool SetThreadPool(); + void AddNewThread(int newNums); + bool SetThreadCpuBind(bool ifBind, int mode, bool master); + bool AddTask(WorkFun &&worker, void *cdata, int numTask); + bool DistributeTask(ThreadPoolTask *task, int numTask); + void AddRunThread(int num); + void SubRunThread(int num); + bool CheckResult(); + + std::mutex poolMutex; + std::mutex tMutex; + std::condition_variable queueReady; + std::atomic_bool exitRun = {false}; + std::vector activateList{}; + int curThreadNums = 1; + int curThreadRunNums = 1; + int configThreadNums = 1; + int configBindMode = -1; + std::vector threadList{}; + std::vector> queueList{}; + std::unique_ptr threadBind{nullptr}; + std::vector>> errorInfo{}; +}; +} // namespace predict +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_RUNTIME_THREAD_POOL_H_ + diff --git a/mindspore/lite/src/runtime/workspace_pool.cc b/mindspore/lite/src/runtime/workspace_pool.cc new file mode 100644 index 0000000000..b1cd76cb1d --- /dev/null +++ b/mindspore/lite/src/runtime/workspace_pool.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/workspace_pool.h" +#ifdef __APPLE__ +#include +#else +#include +#endif +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace predict { +static constexpr size_t kWorkspacePageSize = 4096; +static constexpr int kTempAllocaAlignment = 64; +WorkspacePool *WorkspacePool::GetInstance() { + static WorkspacePool instance; + return &instance; +} + +void *WorkspacePool::AllocWorkSpaceMem(size_t size) { + size_t nbytes = (size + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize; + if (nbytes == 0) { + nbytes = kWorkspacePageSize; + } + std::pair alloc; + // fist alloc + if (freeList.empty()) { + alloc.first = nbytes; +#ifdef __APPLE__ + int err = posix_memalign(&alloc.second, kTempAllocaAlignment, nbytes); + if (err != 0) { + MS_LOGE("posix_memalign failed, error code:%d", err); + return alloc.second; + } +#else + alloc.second = memalign(kTempAllocaAlignment, nbytes); +#endif + } else if (freeList.size() == 1) { // one element + alloc = *(freeList.begin()); + freeList.erase(freeList.begin()); + if (alloc.first < nbytes) { + free(alloc.second); + alloc.first = nbytes; +#ifdef __APPLE__ + int err = posix_memalign(&alloc.second, kTempAllocaAlignment, nbytes); + if (err != 0) { + MS_LOGE("posix_memalign failed, error code:%d", err); + return alloc.second; + } +#else + alloc.second = memalign(kTempAllocaAlignment, nbytes); +#endif + } + } else { + if ((*(freeList.begin())).first >= nbytes) { + auto iter = freeList.begin(); + for (; iter != freeList.end(); ++iter) { + if ((*iter).first < size) { + alloc = *(--iter); + freeList.erase(iter); + break; + } + } + if (iter == freeList.end()) { + alloc = *(freeList.rbegin()); + freeList.erase(--freeList.end()); + } + } else { + alloc = *(freeList.begin()); + freeList.erase(freeList.begin()); + free(alloc.second); + alloc.first = nbytes; +#ifdef __APPLE__ + int err = posix_memalign(&alloc.second, kTempAllocaAlignment, nbytes); + if (err != 0) { + MS_LOGE("posix_memalign failed, error code:%d", err); + return alloc.second; + } +#else + alloc.second = memalign(kTempAllocaAlignment, nbytes); +#endif + } + } + allocList.emplace_back(alloc); + return alloc.second; +} + +void WorkspacePool::FreeWorkSpaceMem(void *ptr) { + if (ptr == nullptr) { + return; + } + std::pair alloc; + if (allocList.empty()) { + MS_LOG(ERROR) << "no mem have been alloc"; + return; + } else if (allocList.back().second == ptr) { + alloc = allocList.back(); + allocList.pop_back(); + } else { + auto iter = allocList.begin(); + for (; iter != allocList.end(); ++iter) { + if ((*iter).second == ptr) { + alloc = *iter; + allocList.erase(iter); + break; + } + } + if (iter == allocList.end()) { + MS_LOG(ERROR) << "no value ptr have been alloc"; + return; + } + } + freeList.insert(alloc); +} + +WorkspacePool::~WorkspacePool() { + for (auto &a : allocList) { + free(a.second); + } + allocList.clear(); + for (auto &f : freeList) { + free(f.second); + } + freeList.clear(); +} +} // namespace predict +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/workspace_pool.h b/mindspore/lite/src/runtime/workspace_pool.h new file mode 100644 index 0000000000..9342200b28 --- /dev/null +++ b/mindspore/lite/src/runtime/workspace_pool.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_WORKSPACE_POOL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_WORKSPACE_POOL_H_ +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace predict { +class WorkspacePool { + public: + WorkspacePool() = default; + ~WorkspacePool(); + WorkspacePool(const WorkspacePool &) = delete; + WorkspacePool &operator=(const WorkspacePool &) = delete; + static WorkspacePool *GetInstance(); + void *AllocWorkSpaceMem(size_t size); + void FreeWorkSpaceMem(void *ptr); + + private: + std::vector> allocList{}; + std::set, std::greater>> freeList{}; +}; +} // namespace predict +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_RUNTIME_WORKSPACE_POOL_H_ + diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc new file mode 100644 index 0000000000..d3a7ca7d69 --- /dev/null +++ b/mindspore/lite/src/scheduler.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/scheduler.h" +#include +#include +#include "include/errorcode.h" +#include "src/kernel_factory.h" +#if SUPPORT_GPU +#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#endif + +namespace mindspore::lite { +int Scheduler::Schedule(const lite::Model *model, std::vector *tensors, + std::vector *kernels) { + // 1. op ---> kernel + // 2. sub graph + // 3. kernels (kernels --> subGraph) + int ret = InitOp2Kernel(model, tensors, kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "init op to kernel failed."; + return RET_ERROR; + } + + kernel::LiteKernelUtil::TopologicalSortKernels(*kernels); + + ConstructSubgraphs(kernels); + + MS_LOG(DEBUG) << "schedule kernels success."; + return RET_OK; +} + +int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector *tensors, + std::vector *kernels) { + MS_EXCEPTION_IF_NULL(model); + MS_EXCEPTION_IF_NULL(tensors); + MS_EXCEPTION_IF_NULL(kernels); + auto meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + uint32_t kernelCount = meta_graph->nodes()->size(); + for (uint32_t i = 0; i < kernelCount; i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + std::vector inputs; + std::vector outputs; + auto inIndexes = cNode->inputIndex(); + for (size_t j = 0; j < inIndexes->size(); j++) { + inputs.emplace_back(tensors->at(size_t(inIndexes->GetAs(j)))); + } + auto outIndexes = cNode->outputIndex(); + for (size_t j = 0; j < outIndexes->size(); j++) { + outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs(j)))); + } + auto *primitive = model->GetOp(cNode->name()->str()); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Op " << cNode->name()->str() << " should exist in model, type: " + << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + return RET_ERROR; + } + auto ret = primitive->InferShape(inputs, outputs); + if (0 != ret) { + MS_LOG(ERROR) << "InferShape failed, name: " << cNode->name()->str() + << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + return ret; + } + + auto *kernel = this->ScheduleNode(inputs, outputs, primitive); + if (nullptr == kernel) { + MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << cNode->name()->str() + << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + return RET_ERROR; + } + kernel->set_name(cNode->name()->str()); + kernels->emplace_back(kernel); + } + return RET_OK; +} + +void Scheduler::ConstructSubgraphs(std::vector *kernels) { + uint32_t kernel_count = kernels->size(); + std::vector sub_kernels; + std::vector> sub_kernels_list; + + kernel::KERNEL_ARCH prev_arch = kernels->front()->Desc().arch; + for (uint32_t i = 0; i < kernel_count; ++i) { + auto curr_kernel = kernels->at(i); + auto curr_arch = curr_kernel->Desc().arch; + if (curr_arch == prev_arch) { + sub_kernels.emplace_back(curr_kernel); + } + if ((curr_arch != prev_arch) || (i == kernel_count - 1)) { + sub_kernels_list.emplace_back(sub_kernels); + sub_kernels.clear(); + sub_kernels.emplace_back(curr_kernel); + } + prev_arch = curr_arch; + } + + std::vector subgraph_kernels; + for (auto temp_kernels : sub_kernels_list) { + kernel::KERNEL_ARCH arch = temp_kernels.front()->Desc().arch; + if (arch == kernel::KERNEL_ARCH::kCPU) { + for (auto kernel : temp_kernels) { + for (auto tensor : kernel->GetOutputs()) { + tensor->set_allocator(context_->allocator.get()); + } + } + std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels)); + } else { + auto subgraph_kernel = CreateSubKernel(temp_kernels, arch); + subgraph_kernels.emplace_back(subgraph_kernel); + } + } + kernels->clear(); + kernels->insert(kernels->begin(), subgraph_kernels.begin(), subgraph_kernels.end()); +} + +kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector &kernels, + kernel::KERNEL_ARCH arch) { + kernel::LiteKernel *sub_kernel = nullptr; +#if SUPPORT_GPU + if (arch == kernel::KERNEL_ARCH::kGPU) { + std::vector input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels); + std::vector output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels); + std::vector input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels); + std::vector output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels); + sub_kernel = + new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels); + sub_kernel->Init(); + } else if (arch == kernel::KERNEL_ARCH::kNPU) { + MS_LOG(ERROR) << "NPU kernel is not supported"; + } else { + MS_LOG(ERROR) << "unsupported kernel arch: " << arch; + } +#endif + return sub_kernel; +} + +int Scheduler::MarkKernels(const std::vector &kernels) { return 0; } + +int Scheduler::MergeKernels(std::vector *kernels) { return 0; } + +kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector &inputs, + const std::vector &outputs, + const lite::Primitive *primitive) { + // todo: support NPU, APU + MS_ASSERT(nullptr != primitive); + auto data_type = inputs.front()->data_type(); + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()}; + if (context_->device_ctx_.type == DT_GPU) { + desc.arch = kernel::KERNEL_ARCH::kGPU; + auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc); + if (nullptr != kernel) { + kernel->set_desc(desc); + return kernel; + } + } + + desc.arch = kernel::KERNEL_ARCH::kCPU; + kernel::LiteKernel *kernel; + if (data_type == kNumberTypeFloat32) { + // check if support fp16 + kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; + kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, key); + if (kernel != nullptr) { + MS_LOG(DEBUG) << "Get fp16 op success."; + kernel->set_desc(desc); + return kernel; + } + MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; + kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc); + } else { + kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc); + } + if (kernel != nullptr) { + kernel->set_desc(desc); + return kernel; + } + return nullptr; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h new file mode 100644 index 0000000000..f86ca2b035 --- /dev/null +++ b/mindspore/lite/src/scheduler.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_ +#define MINDSPORE_LITE_SRC_SCHEDULER_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "include/model.h" + +namespace mindspore::lite { +class Scheduler { + public: + explicit Scheduler(const Context *ctx) : context_(ctx) {} + int Schedule(const lite::Model *model, std::vector *tensors, + std::vector *kernels); + + protected: + kernel::LiteKernel *ScheduleNode(const std::vector &inputs, + const std::vector &outputs, const lite::Primitive *primitive); + // find schedule able kernels and save in markedKernelGroup + int MarkKernels(const std::vector &kernels); + // use SubGraphKernel to replace group in kernels + int MergeKernels(std::vector *kernels); + + private: + int InitOp2Kernel(const lite::Model *model, std::vector *tensors, + std::vector *kernels); + + // construct SubGraphKernel for each kernel-group in markedKernelGroup + void ConstructSubgraphs(std::vector *kernels); + + kernel::LiteKernel *CreateSubKernel(const std::vector &kernels, kernel::KERNEL_ARCH arch); + + protected: + std::vector> markedKernelGroup; + const Context *context_ = nullptr; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_SCHEDULER_H_ diff --git a/mindspore/lite/src/train/base_ref_utils.cc b/mindspore/lite/src/train/base_ref_utils.cc new file mode 100644 index 0000000000..5df16a9552 --- /dev/null +++ b/mindspore/lite/src/train/base_ref_utils.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/train/base_ref_utils.h" +#include +#include +// #include "utils/base_ref_utils.h" +#include "include/ms_tensor.h" +#include "src/ir/tensor.h" + +namespace mindspore { +std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref) { + std::vector> msTensors; + if (utils::isa(base_ref)) { + auto ref_list = utils::cast(base_ref); + for (size_t i = 0; i < ref_list.size(); ++i) { + if (utils::isa(ref_list[i])) { + auto tensor_ptr = utils::cast>(ref_list[i]); + MS_EXCEPTION_IF_NULL(tensor_ptr); + auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); + msTensors.emplace_back(std::shared_ptr(tensor)); + } else { + MS_LOG(EXCEPTION) << "The output is not a tensor!"; + } + } + } else if (utils::isa(base_ref)) { + auto tensor_ptr = utils::cast>(base_ref); + MS_EXCEPTION_IF_NULL(tensor_ptr); + auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); + msTensors.emplace_back(std::shared_ptr(tensor)); + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } + return msTensors; +} + +std::vector>> TransformVectorRefToMultiTensor( + const VectorRef &vector_ref) { + std::vector>> multiTensor; + for (size_t i = 0; i < vector_ref.size(); ++i) { + auto tensors = TransformBaseRefToMSTensor(vector_ref[i]); + multiTensor.emplace_back(tensors); + } + return multiTensor; +} +} // namespace mindspore + diff --git a/mindspore/lite/src/train/base_ref_utils.h b/mindspore/lite/src/train/base_ref_utils.h new file mode 100644 index 0000000000..63370efeb9 --- /dev/null +++ b/mindspore/lite/src/train/base_ref_utils.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "base/base_ref.h" +#include "include/ms_tensor.h" + +#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H +#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H +namespace mindspore { +std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref); + +std::vector>> TransformVectorRefToMultiTensor( + const VectorRef &vector_ref); +} // namespace mindspore +#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H + diff --git a/mindspore/lite/src/train/import.hpp b/mindspore/lite/src/train/import.hpp new file mode 100644 index 0000000000..e8153beacb --- /dev/null +++ b/mindspore/lite/src/train/import.hpp @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/common/anf_importer/import_from_meta_graph.h" +namespace mindspore::lite::train { +std::shared_ptr Import(const char *model_buf, size_t size) { + MS_EXCEPTION_IF_NULL(model_buf); + flatbuffers::Verifier verify((const uint8_t *) model_buf, size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return nullptr; + } + // todo hangangqiang remove when copy primitive done + auto *inner_buf = new char[size]; + memcpy(inner_buf, model_buf, size); + auto meta_graph = schema::GetMetaGraph(inner_buf); + auto model = std::make_shared(meta_graph); + auto ret = model->BuildOps(); + if (0 != ret) { + MS_LOG(ERROR) << "BuildOps failed"; + return nullptr; + } + MS_EXCEPTION_IF_NULL(meta_graph); + auto importer = new AnfImporterFromMetaGraph(model); + auto ret2 = importer->Import(); + if (0 != ret2) { + MS_LOG(ERROR) << "Import anf_graph from meta_graph failed, ret2: " << ret2; + return nullptr; + } + return model; +} +} // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/lite_kernel_runtime.cc b/mindspore/lite/src/train/lite_kernel_runtime.cc new file mode 100644 index 0000000000..e1eb95d442 --- /dev/null +++ b/mindspore/lite/src/train/lite_kernel_runtime.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/train/lite_kernel_runtime.h" +namespace mindspore::lite { +std::vector LiteInferKernelRuntime::GetGraphInputs(const std::vector &execution_order) { + std::vector graph_inputs; + for (const auto &cnode : execution_order) { + bool is_graph_inputs = true; + for (const auto &input : cnode->inputs()) { + if (input->isa()) { + is_graph_inputs = false; + break; + } + } + if (is_graph_inputs) { + graph_inputs.emplace_back(cnode); + } + } + return graph_inputs; +} + +void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, + const std::vector &inputs, VectorRef *outputs) { + MS_EXCEPTION_IF_NULL(graph); + auto execution_order = graph->execution_order(); + auto graph_inputs = GetGraphInputs(execution_order); + int input_count = 0; + for (const auto &graph_input : graph_inputs) { + auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(graph_input)); + for (auto input_tensor : liteKernel->GetInputs()) { + if (schema::NodeType_ValueNode == input_tensor->TensorType() && input_tensor->Data() != nullptr) { + continue; + } + input_tensor->SetData(inputs[input_count]->Data()); + input_count++; + } + } + + auto return_node = graph->get_return(); + for (const auto &return_input : return_node->inputs()) { + if (return_input->isa()) { + auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(return_input)); + auto output_tensors = liteKernel->GetOutputs(); + for (auto output_tensor : output_tensors) { + tensor::TensorPtr output_tensor_ptr(output_tensor); + outputs->push_back(output_tensor_ptr); + } + } + } +} + +bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + std::vector kernels; + auto nodes = graph->execution_order(); + for (const auto &node : nodes) { + auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(node)); + if (liteKernel == nullptr) { + continue; + } + kernels.emplace_back(liteKernel); + } + kernel::LiteKernelUtil::TopologicalSortKernels(kernels); + Executor executor; + auto ret = executor.Run(kernels); + return 0 == ret; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/train/lite_kernel_runtime.h b/mindspore/lite/src/train/lite_kernel_runtime.h new file mode 100644 index 0000000000..27b4ec867b --- /dev/null +++ b/mindspore/lite/src/train/lite_kernel_runtime.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ +#define MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ + +#include +#include +#include +#include +#include "src/runtime/allocator.h" +#include "src/executor.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/device_address.h" +#include "src/lite_kernel.h" +#include "backend/session/kernel_graph.h" +namespace mindspore::lite { +class LiteInferKernelRuntime : public device::KernelRuntime { + public: + LiteInferKernelRuntime() = default; + ~LiteInferKernelRuntime() override = default; + + bool Init() override { return true; } + + void BindInputOutput(const session::KernelGraph *graph, const std::vector &inputs, + VectorRef *outputs); + + bool Run(session::KernelGraph *graph); + + void AssignKernelAddress(session::KernelGraph *graph) {} + + protected: + std::vector GetGraphInputs(const std::vector &execution_order); + bool SyncStream() override { return true; }; + device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override { + return nullptr; + }; +}; + +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ + diff --git a/mindspore/lite/src/train/model_impl.cc b/mindspore/lite/src/train/model_impl.cc new file mode 100644 index 0000000000..30d60f7709 --- /dev/null +++ b/mindspore/lite/src/train/model_impl.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/train/model_impl.h" +#include "schema/model_generated.h" +#include "ir/func_graph.h" + +namespace mindspore::lite::train { + +const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { + auto iter = ops.find(name); + if (iter == ops.end()) { + return nullptr; + } else { + return iter->second; + } +} + +void ModelImpl::FreeMetaGraph() { delete this->meta_graph; } + +const schema::MetaGraph *ModelImpl::GetMetaGraph() const { return this->meta_graph; } + +lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { + MS_EXCEPTION_IF_NULL(srcPrim); + auto op_type = srcPrim->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(srcPrim)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(srcPrim)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(srcPrim)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(srcPrim)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(srcPrim)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(srcPrim)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_CaffeBatchNorm: + return new lite::CaffeBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(srcPrim)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(srcPrim)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(srcPrim)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(srcPrim)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(srcPrim)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(srcPrim)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(srcPrim)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(srcPrim)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(srcPrim)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(srcPrim)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(srcPrim)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(srcPrim)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(srcPrim)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(srcPrim)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(srcPrim)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(srcPrim)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(srcPrim)); + case schema::PrimitiveType_Slice: + return new lite::Slice(const_cast(srcPrim)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(srcPrim)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(srcPrim)); + default: + break; + } + return nullptr; +} + +int ModelImpl::BuildOps() { + if (this->meta_graph == nullptr) { + MS_LOG(ERROR) << "mete_graph is nullptr"; + return -1; + } + for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + auto name = cNode->name()->str(); + auto srcPrim = cNode->primitive(); + this->ops[name] = CopyPrimitive(srcPrim); + } +} +} // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/model_impl.h b/mindspore/lite/src/train/model_impl.h new file mode 100644 index 0000000000..496fed2ac3 --- /dev/null +++ b/mindspore/lite/src/train/model_impl.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ +#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H + +#include +#include +#include +#include "schema/model_generated.h" +#include "src/ops/ops.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +namespace train { +class ModelImpl : public FuncGraph { + public: + static std::shared_ptr Import(const char *model_buf, size_t size); + ModelImpl() = default; + explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} + ~ModelImpl() override = default; + const lite::Primitive *GetOp(const std::string &name) const; + const schema::MetaGraph *GetMetaGraph() const; + void FreeMetaGraph(); + int BuildOps(); + + protected: + lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); + + protected: + const schema::MetaGraph *meta_graph = nullptr; + std::map ops; +}; +} // namespace train +using ModelImpl = mindspore::lite::train::ModelImpl; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_INCLUDE_MODEL_H + diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc new file mode 100644 index 0000000000..c92473e177 --- /dev/null +++ b/mindspore/lite/src/train/train_session.cc @@ -0,0 +1,232 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/src/train/train_session.h" +#include "mindspore/lite/src/kernel_factory.h" +#include "mindspore/lite/src/param_value_lite.h" +#include "utils/ms_utils.h" +#include "mindspore/lite/src/ops/ops.h" +#include "ir/anf.h" +#include "mindspore/lite/src/ir/tensor.h" +#include "abstract/abstract_value.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "src/ir/primitive_value.h" + +namespace mindspore { +namespace session { +static std::vector GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { + auto nodeAbstract = anfNodePtr->abstract(); + if (nodeAbstract != nullptr) { + auto shape = nodeAbstract->GetShapeTrack(); + if (!shape->isa()) { + MS_LOG(EXCEPTION) << "Not a Shape"; + return {}; + } + auto dims = dyn_cast(shape)->shape(); + return dims; + } else { + MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; + return {}; + } +} + +static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { + auto nodeAbstract = anfNodePtr->abstract(); + if (nodeAbstract != nullptr) { + return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode + } else { + MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; + return schema::Format_NHWC; + } +} + +static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { + auto nodeAbstract = anfNodePtr->abstract(); + if (nodeAbstract != nullptr) { + return nodeAbstract->GetTypeTrack()->type_id(); + } else { + MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; + return TypeId::kTypeUnknown; + } +} + +int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { + auto return_node = kernel_graph->get_return(); + auto node_list = TopoSort(return_node); + for (auto &node : node_list) { + if (!node->isa()) { + continue; + } + KernelRelation kernel_relation; + auto cnode = node->cast(); + kernel_relation.node_full_name = cnode->fullname_with_scope(); + kernel_relation.cnode = cnode; + auto *out_tensor = + new tensor::Tensor(GetAnfNodeOutTypeId(cnode), GetAnfNodeOutDims(cnode), GetAnfNodeFormat(cnode), + schema::NodeType_Parameter); + kernel_relation.output_tensor.push_back(out_tensor); + tensor::Tensor *tensor_ptr = nullptr; + for (size_t index = 1; index < cnode->inputs().size(); ++index) { + if (cnode->input(index)->isa()) { + auto input_cnode = cnode->input(index)->cast(); + auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; + // todo not support multi-outputs kernel sudo as spilt + tensor_ptr = input_kernel_relation.output_tensor.front(); + } else if (cnode->input(index)->isa()) { + auto input_parameter = cnode->input(index)->cast(); + auto para = input_parameter->default_param(); + auto param_value = std::dynamic_pointer_cast(para); + auto dims = param_value->tensor_shape(); + tensor_ptr = new tensor::Tensor(param_value->tensor_type(), dims, schema::Format_NHWC, + schema::NodeType_ValueNode); // XXX TODO -- extract Format from AnfNode + if (param_value->tensor_size() != 0) { + tensor_ptr->SetData(param_value->tensor_addr()); + } + } else if (cnode->input(index)->isa()) { + auto input_valuenode = cnode->input(index)->cast(); + tensor_ptr = new tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), GetAnfNodeOutDims(input_valuenode), + schema::Format_NHWC, + schema::NodeType_Parameter); // XXX TODO -- extract Format from AnfNode + // todo(yankai) + } else { + MS_ASSERT(false); + } + kernel_relation.input_tensor.push_back(tensor_ptr); + } + kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; + } + return 0; +} + +GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + auto graph_id = graph_sum_; + auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); + MS_EXCEPTION_IF_NULL(graph); + + BuildKernel(graph.get()); + MS_LOG(INFO) << "Assign kernel address"; + runtime_.AssignKernelAddress(graph.get()); + return graph_id; +} + +GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } + +std::shared_ptr TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto graph = NewKernelGraph(); + graph->set_return(func_graph->get_return()); + auto node_list = TopoSort(func_graph->get_return()); + std::vector cnode_order; + for (const auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cn_node = node->cast(); + cnode_order.push_back(cn_node); + } + } + graph->set_execution_order(cnode_order); + return graph; +} + +GraphId TrainSession::CompileGraph(NotNull func_graph) { + auto graph = ConstructKernelGraph(func_graph); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Set kernel info"; + SetKernelInfo(graph.get()); + + (void) BuildKernelInputAndOutputFromFuncGraph(graph); + MS_LOG(INFO) << "Build kernel"; + auto ret = BuildKernel(graph.get()); + if (0 != ret) { + MS_LOG(EXCEPTION) << "BuildKernel failed"; + } + + // return the graph id to backend + auto graph_id = graph->graph_id(); + graphs_[graph_id] = graph; + MS_LOG(INFO) << "Compile graph " << graph_id << " success"; + return graph_id; +} + +void TrainSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, + std::vector &outputs) { + auto &kernel_graph = graphs_[graph_id]; + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "Bind input output address"; + runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); + // auto execution_order = kernel_graph->execution_order(); + // Todo : hangangqiang + // Reorder(&execution_order); + // kernel_graph->set_execution_order(execution_order); + MS_LOG(INFO) << "Run graph start"; + auto ret = runtime_.Run(kernel_graph.get(), (std::vector &) inputs, outputs); + if (!ret) { + MS_LOG(EXCEPTION) << "Run graph failed"; + } + MS_LOG(INFO) << "Run graph end"; +} + +void TrainSession::SetKernelInfo(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &kernel_nodes = kernel_graph->execution_order(); + for (const auto &kernel_node : kernel_nodes) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto kernel_info = std::make_shared(); + kernel_node->set_kernel_info(kernel_info); + } +} + +int TrainSession::BuildKernel(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { + std::string kernel_name = iter->first; + KernelRelation anf_register = iter->second; + MS_EXCEPTION_IF_NULL(anf_register.cnode); + if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { + continue; + } + lite::Context context; + context.deviceCtx.type = lite::DeviceType::DT_CPU; + auto value_node_prim = anf_register.cnode->input(0); + MS_EXCEPTION_IF_NULL(value_node_prim); + auto prim = GetValueNode>(value_node_prim); + MS_EXCEPTION_IF_NULL(prim); + auto node_primitive = (lite::Primitive *) (prim->GetPrimitive()); + MS_EXCEPTION_IF_NULL(node_primitive); + auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); + if (0 != ret) { + MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; + return ret; + } + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, node_primitive->Type()}; + + auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, + node_primitive, &context, desc); + if (nullptr == kernel) { + MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; + return -1; + } + kernel->train(); + auto *kernel_info = anf_register.cnode->kernel_info(); + std::shared_ptr kernel_mod(kernel); + kernel_info->set_kernel_mod(kernel_mod); + } + return 0; +} +} // namespace session +} // namespace mindspore + diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h new file mode 100644 index 0000000000..d9b026d55f --- /dev/null +++ b/mindspore/lite/src/train/train_session.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ +#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ +#include +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "mindspore/lite/src/train/lite_kernel_runtime.h" +#include "backend/session/session_factory.h" +namespace mindspore { +namespace lite::tensor { +class Tensor; +} +namespace session { +struct KernelRelation { + std::string node_full_name; + std::vector input_tensor; + std::vector output_tensor; + CNodePtr cnode; +}; + +class TrainSession : public SessionBasic { + public: + TrainSession() : SessionBasic() {} + ~TrainSession() override = default; + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kCPUDevice, device_id); + } + + GraphId CompileGraph(NotNull func_graph) override; + + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + + private: + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + GraphId CompileGraph(const char *model_buf, size_t size); + std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); + int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); + void SetKernelInfo(const KernelGraph *kernel_graph); + int BuildKernel(const KernelGraph *kernel_graph); + lite::LiteInferKernelRuntime runtime_; + std::map kernel_relation_infos_; +}; +MS_REG_SESSION(kCPUDevice, TrainSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ + diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt new file mode 100644 index 0000000000..8a9c47ce4a --- /dev/null +++ b/mindspore/lite/test/CMakeLists.txt @@ -0,0 +1,337 @@ +set(TEST_DIR ${TOP_DIR}/mindspore/lite/test) +set(LITE_DIR ${TOP_DIR}/mindspore/lite) +include_directories(${TOP_DIR}) +include_directories(${TEST_DIR}) +include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/dependency_gtest.cmake) + +### anf src +set(ANF_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/meta_tensor.cc + ${CCSRC_DIR}/gvar/logging_level.cc + ${CCSRC_DIR}/gvar/typeid_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/base/base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/log_adapter.cc + ) +if(BUILD_CONVERTER) + set(ANF_SRC + ${ANF_SRC} + # core/base + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/base/base_ref.cc + # core/ir + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/anf.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/anf_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/meta_func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/graph_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/func_graph_cloner.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/func_graph_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/primitive.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/visitor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/meta_tensor_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/named.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/any.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/symbolic.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/misc.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/trace_base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/trace_info.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/label.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/info.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/profile.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/ms_context.cc + # core/abstract + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/abstract_function.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/analysis_context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/param_validator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/abstract_value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/dshape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/utils.cc + ## ccsrc + ${CCSRC_DIR}/debug/draw.cc + ${CCSRC_DIR}/pybind_api/export_flags.cc + ${CCSRC_DIR}/utils/context/context_extends.cc + ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc + ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc + ${CCSRC_DIR}/backend/optimizer/common/visit.cc + ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc + ) +else() + set(ANF_SRC + ${ANF_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/../src/ir/meta_tensor_extends.cc + ) +endif() +### cpu kernel +file(GLOB KERNEL_OP_SRC + ${LITE_DIR}/src/runtime/kernel/arm/base/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/fp32/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc + ) +if (PLATFORM_ARM64) + # assembly + file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.S) + + set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${TEST_ASSEMBLY_SRC} + ) +endif() +if (PLATFORM_ARM32) + # assembly + file(GLOB TEST_ASSEMBLY_SRC + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm32/*.S + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm32/*.s) + set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${TEST_ASSEMBLY_SRC} + ) +endif() +if (ENABLE_FP16) + file(GLOB KERNEL_OP_FP16_SRC + ${LITE_DIR}/src/runtime/kernel/arm/fp16/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/fp16/*.cc + ) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${KERNEL_OP_FP16_SRC} + ) +endif () +### gpu kernel +if (SUPPORT_GPU) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${LITE_DIR}/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc + ${LITE_DIR}/src/runtime/kernel/opencl/utils.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/arithmetic.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/convolution.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/pooling2d.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc + ) +endif() +### minddata lite +if (BUILD_MINDDATA) + include_directories(${CCSRC_DIR}/minddata) + set(DATASET_TEST_DIR ${TEST_DIR}/ut/src/dataset) + set(TEST_MINDDATA_SRC + ${DATASET_TEST_DIR}/de_tensor_test.cc + ${DATASET_TEST_DIR}/eager_test.cc + ) +endif() +### runtime framework +file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc) +set(TEST_LITE_SRC + ${ANF_SRC} + ${OPS_SRC} + ${KERNEL_OP_SRC} + ${LITE_DIR}/src/runtime/allocator.cc + ${LITE_DIR}/src/runtime/runtime_api.cc + ${LITE_DIR}/src/runtime/thread_pool.cc + ${LITE_DIR}/src/runtime/workspace_pool.cc + ${LITE_DIR}/src/ir/tensor.cc + ${LITE_DIR}/src/context.cc + ${LITE_DIR}/src/executor.cc + ${LITE_DIR}/src/kernel_factory.cc + ${LITE_DIR}/src/kernel_registry.cc + ${LITE_DIR}/src/lite_kernel.cc + ${LITE_DIR}/src/lite_session.cc + ${LITE_DIR}/src/model.cc + ${LITE_DIR}/src/model_impl.cc + ${LITE_DIR}/src/populate_parameter.cc + ${LITE_DIR}/src/scheduler.cc + ${LITE_DIR}/src/common/graph_util.cc + ${LITE_DIR}/src/common/file_utils.cc + ${LITE_DIR}/src/common/file_utils_ext.cc + ${LITE_DIR}/src/common/utils.cc + ${LITE_DIR}/src/common/ms_tensor_utils.cc + ${LITE_DIR}/tools/common/graph_util.cc + ${LITE_DIR}/tools/common/tensor_util.cc + ${LITE_DIR}/tools/common/node_util.cc + ${LITE_DIR}/tools/common/flag_parser.cc + ${LITE_DIR}/tools/common/storage.cc + ${LITE_DIR}/tools/benchmark/benchmark.cc + ${LITE_DIR}/test/st/benchmark_test.cc + ) +### gpu runtime +if (SUPPORT_GPU) + include_directories(${TOP_DIR}/third_party/OpenCL-Headers) + include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include) + set(OPENCL_RUNTIME_SRC + ${LITE_DIR}/src/runtime/opencl/opencl_allocator.cc + ${LITE_DIR}/src/runtime/opencl/opencl_executor.cc + ${LITE_DIR}/src/runtime/opencl/opencl_runtime.cc + ${LITE_DIR}/src/runtime/opencl/opencl_wrapper.cc + ) + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + ${OPENCL_RUNTIME_SRC} + ) +endif() +### converter +if(BUILD_CONVERTER) + file(GLOB_RECURSE TEST_CASE_TFLITE_PARSERS_SRC + ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc + ) + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + ${TEST_CASE_TFLITE_PARSERS_SRC} + ${TOP_DIR}/mindspore/core/utils/flags.cc + ${LITE_DIR}/tools/converter/optimizer.cc + ${LITE_DIR}/src/common/anf_importer/anf_importer.cc + ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc + ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc + ${LITE_DIR}/tools/converter/anf_transform.cc + ${LITE_DIR}/tools/converter/graphdef_transform.cc + ${LITE_DIR}/tools/converter/converter_flags.cc + ${LITE_DIR}/tools/converter/converter.cc + ${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc + ${LITE_DIR}/test/st/converter_test.cc + ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc + ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc + ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc + ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc + ${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc + ${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc + ${LITE_DIR}/tools/optimizer/common/gllo_utils.cc + ${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/conv_activation_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/conv_transform_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc + ) +endif() +### train +if (SUPPORT_TRAIN) + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + # ${SRC_DIR}/common/trans.cc + # ${SRC_DIR}/common/lite/trans_extends.cc + # ${SRC_DIR}/kernel/kernel_build_info.cc + # ${SRC_DIR}/utils/lite/base_ref_utils.cc + # ${SRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc + # ${SRC_DIR}/session/lite/session_basic_extends.cc + # ${SRC_DIR}/session/anf_runtime_algorithm.cc + # ${SRC_DIR}/session/anf_runtime_algorithm.cc + # ${SRC_DIR}/session/session_basic.cc + # ${SRC_DIR}/session/kernel_graph.cc + # ${SRC_DIR}/session/session_factory.cc + # ${SRC_DIR}/device/kernel_info.cc + # ${SRC_DIR}/device/kernel_runtime.cc + # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc + ${LITE_DIR}/src/common/anf_importer/anf_importer.cc + ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc + ${LITE_DIR}/src/ir/primitive_value.cc + ${LITE_DIR}/src/train/lite_kernel_runtime.cc + ${LITE_DIR}/src/train/train_session.cc + ${LITE_DIR}/src/train/model_impl.cc + ) +else() + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + ${LITE_DIR}/src/lite_session.cc + ) +endif() +### test src +file(GLOB_RECURSE TEST_CASE_KERNEL_SRC + ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc + ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc + ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc +) + +set(TEST_SRC + ${TEST_LITE_SRC} + ${TEST_MINDDATA_SRC} + ${TEST_CASE_KERNEL_SRC} + ${TEST_DIR}/common/common_test.cc + ${TEST_DIR}/main.cc + ${TEST_DIR}/ut/src/runtime/kernel/arm/common/pack_tests.cc + ${TEST_DIR}/ut/src/infer_test.cc +) + +if (SUPPORT_TRAIN) + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/ut/src/train_test.cc + ) +else() + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/ut/src/infer_test.cc + ) +endif() + +if (SUPPORT_GPU) + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc + ) +endif() + +if (ENABLE_FP16) + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc) +endif () + + +add_executable(lite-test ${TEST_SRC}) + +target_link_libraries(lite-test dl ${GTEST_LIBRARY}) +if (BUILD_MINDDATA) + target_link_libraries(lite-test + minddata-lite + minddata-eager + ) + if (PLATFORM_ARM32 OR PLATFORM_ARM64) + target_link_libraries(lite-test log) + endif() +endif() +if (BUILD_CONVERTER) + target_link_libraries(lite-test + anf_exporter_mid + tflite_parser_mid + caffe_parser_mid + node_mid + graph_pass_mid + fusion_mid + quantizer_mid + pthread + protobuf + mindspore::eigen + mindspore::json + ${SECUREC_LIBRARY} + ) +endif() diff --git a/mindspore/lite/test/common/common_test.cc b/mindspore/lite/test/common/common_test.cc new file mode 100644 index 0000000000..1c4b699619 --- /dev/null +++ b/mindspore/lite/test/common/common_test.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common_test.h" +#include "mindspore/core/utils/log_adapter.h" + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif +#endif + +namespace mindspore { + +void Common::SetUpTestCase() {} + +void Common::TearDownTestCase() {} + +void Common::SetUp() {} + +void Common::TearDown() {} + +} // namespace mindspore + +#ifdef __cplusplus +#if __cplusplus +} +#endif +#endif diff --git a/mindspore/lite/test/common/common_test.h b/mindspore/lite/test/common/common_test.h new file mode 100644 index 0000000000..7aafb9780c --- /dev/null +++ b/mindspore/lite/test/common/common_test.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TESTS_UT_COMMON_UT_COMMON_H_ +#define TESTS_UT_COMMON_UT_COMMON_H_ + +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +namespace mindspore { +class Common : public testing::Test { + public: + // TestCase only enter once + static void SetUpTestCase(); + static void TearDownTestCase(); + + // every TEST_F macro will enter one + virtual void SetUp(); + virtual void TearDown(); + + template + void PrintData(std::string name, T *output_data, int size) { + std::cout << "The " << name << " is as follows:" << std::endl; + if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { + for (size_t i = 0; i < std::min(size, 100); i++) { + std::cout << static_cast(output_data[i]) << " "; + } + } else { + for (size_t i = 0; i < std::min(size, 100); i++) { + std::cout << output_data[i] << " "; + } + } + std::cout << std::endl; + } + + template + static void CompareOutputData(T *output_data, T *correct_data, int size, float err_bound) { + for (size_t i = 0; i < size; i++) { + T abs = fabs(output_data[i] - correct_data[i]); + ASSERT_LE(abs, err_bound); + } + } + + void ReadFile(const char *file, size_t *size, char **buf) { + ASSERT_NE(nullptr, file); + ASSERT_NE(nullptr, size); + ASSERT_NE(nullptr, buf); + std::string path = std::string(file); + std::ifstream ifs(path); + ASSERT_EQ(true, ifs.good()); + ASSERT_EQ(true, ifs.is_open()); + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + *buf = new char[*size]; + + ifs.seekg(0, std::ios::beg); + ifs.read(*buf, *size); + ifs.close(); + } +}; +} // namespace mindspore +#endif // TESTS_UT_COMMON_UT_COMMON_H_ diff --git a/mindspore/lite/test/main.cc b/mindspore/lite/test/main.cc new file mode 100644 index 0000000000..40ede795c7 --- /dev/null +++ b/mindspore/lite/test/main.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "gtest/gtest.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +extern void InitSubModulesLogLevel(); +} + +GTEST_API_ int main(int argc, char** argv) { + mindspore::InitSubModulesLogLevel(); + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/mindspore/lite/test/models_tflite.cfg b/mindspore/lite/test/models_tflite.cfg new file mode 100644 index 0000000000..708784b158 --- /dev/null +++ b/mindspore/lite/test/models_tflite.cfg @@ -0,0 +1,27 @@ +hiai_model_0909_kd_rot_ps_softmax.tflite +hiai_chinese_english_recognize_model_float32.tflite +hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite +hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite +hiai_cn_recognize_modify_padv2.tflite +hiai_model_normalize_object_scene_ps_20200519.tflite +hiai_detectmodel_06_23_960_480_1180700.tflite +hiai_detect_curve_model_float32.tflite +hiai_detectmodel_desnet_256_128_64_32.tflite +mtk_AADB_HADB_MBV2_model_fp32.tflite +mobilenet_v1_0.25_128.tflite +mobilenet_v1_0.25_160.tflite +mobilenet_v1_0.25_192.tflite +mobilenet_v1_0.25_224.tflite +mobilenet_v1_0.5_128.tflite +mobilenet_v1_0.5_160.tflite +mobilenet_v1_0.5_192.tflite +mobilenet_v1_0.5_224.tflite +mobilenet_v1_0.75_128.tflite +mobilenet_v1_0.75_160.tflite +mobilenet_v1_0.75_192.tflite +mobilenet_v1_0.75_224.tflite +mobilenet_v1_1.0_128.tflite +mobilenet_v1_1.0_160.tflite +mobilenet_v1_1.0_192.tflite +mobilenet_v1_1.0_224.tflite +mobilenet_v2_1.0_224.tflite diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh new file mode 100644 index 0000000000..4fc0891c3a --- /dev/null +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -0,0 +1,97 @@ +#!/bin/bash +basepath=$(pwd) +echo $basepath +set -e +#example:sh run_benchmark_nets.sh -a /home/temp_test -c /home/temp_test -m /home/temp_test/models -d "8KE5T19620002408" +while getopts "a:c:m:d:" opt +do + case $opt in + a) + arm_path=$OPTARG + echo "arm_path is $OPTARG" + ;; + c) + convertor_path=$OPTARG + echo "convertor_path is $OPTARG" + ;; + m) + models_path=$OPTARG + echo "models_path is $OPTARG" + ;; + d) + device_id=$OPTARG + echo "device_id is $OPTARG" + ;; + ?) + echo "unknown para" + exit 1;; + esac +done + + + +#unzip arm +cd $arm_path +tar -zxf MSLite-*-linux_arm64.tar.gz + + +#unzip convertor +cd $convertor_path +tar -zxf MSLite-*-linux_x86_64.tar.gz +cd $convertor_path/MSLite-*-linux_x86_64 +cp converter/converter_lite ./ +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib/:./third_party/protobuf/lib + +#the original model's path: $models_path/ + + +#convert the models +cd $convertor_path/MSLite-*-linux_x86_64 + +#models_config_filename=/home/workspace/mindspore_dataset/mslite/models/models_config.txt +models_tflite_config=${basepath}/models_tflite.cfg +rm -rf ${basepath}/ms_models +mkdir -p ${basepath}/ms_models +ms_models_path=${basepath}/ms_models + +while read line;do + model_name=$line + echo $model_name + echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' + ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name} +done < ${models_tflite_config} + +#push to the arm and run benchmark: + +#first:copy to the server which connected to the phone +rm -rf ${basepath}/benchmark_test +mkdir -p ${basepath}/benchmark_test +benchmark_test_path=${basepath}/benchmark_test +cd ${benchmark_test_path} +cp $arm_path/MSLite-0.6.0-linux_arm64/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so +cp $arm_path/MSLite-0.6.0-linux_arm64/benchmark/benchmark ${benchmark_test_path}/benchmark + +#copy the models: +cp ${ms_models_path}/*.ms ${benchmark_test_path} + +#second:adb push to the phone +adb -s $device_id push ${benchmark_test_path} /data/local/tmp/ + +#third:run adb ,run session ,check the result: +echo 'cd /data/local/tmp/benchmark_test' > adb_cmd.txt +echo 'cp /data/local/tmp/libc++_shared.so ./' >> adb_cmd.txt +echo 'chmod 777 benchmark' >> adb_cmd.txt + +adb -s $device_id shell < adb_cmd.txt + +#run models: +while read line;do + model_name=$line + echo $model_name + echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelPath='${model_name}'.ms --inDataPath=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --calibDataPath=/data/local/tmp/input_output/output/'${model_name}'.ms.out --warmUpLoopCount=1 --loopCount=1' + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelPath='${model_name}'.ms --inDataPath=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --calibDataPath=/data/local/tmp/input_output/output/'${model_name}'.ms.out --warmUpLoopCount=1 --loopCount=1' >> adb_run_cmd.txt + adb -s $device_id shell < adb_run_cmd.txt +done < ${models_tflite_config} + + diff --git a/mindspore/lite/test/run_test.sh b/mindspore/lite/test/run_test.sh new file mode 100755 index 0000000000..adfb48130f --- /dev/null +++ b/mindspore/lite/test/run_test.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +set -e +CUR_DIR=$(cd "$(dirname $0)"; pwd) +BUILD_DIR=${CUR_DIR}/../build +mkdir -pv ${CUR_DIR}/do_test +cd ${CUR_DIR}/do_test +cp ${BUILD_DIR}/test/lite-test ./ +cp -r ${CUR_DIR}/ut/src/runtime/kernel/arm/test_data/* ./ +cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./ +## prepare data for dataset +TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ +cp -fr $TEST_DATA_DIR/testPK ./data + +./lite-test --gtest_filter="*MindDataTestTensorDE*" +./lite-test --gtest_filter="*MindDataTestEager*" + +./lite-test --gtest_filter="TestTfliteParser*" + +./lite-test --gtest_filter="*TestHebing*" + +./lite-test --gtest_filter=TestFcFp32* +./lite-test --gtest_filter=TestConv1x1Fp32* +./lite-test --gtest_filter=TestStrassenFp32* +./lite-test --gtest_filter=TestDeConvolutionFp32* + +./lite-test --gtest_filter=TestPadInt8.* +./lite-test --gtest_filter=TestDeconvInt8.* diff --git a/mindspore/lite/test/st/benchmark_test.cc b/mindspore/lite/test/st/benchmark_test.cc new file mode 100644 index 0000000000..86b468677b --- /dev/null +++ b/mindspore/lite/test/st/benchmark_test.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "tools/benchmark/benchmark.h" + +namespace mindspore { +namespace lite { +class BenchmarkTest : public mindspore::Common { + public: + BenchmarkTest() {} +}; + +TEST_F(BenchmarkTest, TestVideo) { + const char *argv[] = {"./benchmark", "--modelPath=./hiai/hiai_label_and_video.ms" + "--inDataPath=./hiai/hiai_label_and_video.bin" + "--calibDataPath=./hiai/hiai_label_and_video.txt"}; + auto status = RunBenchmark(2, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(BenchmarkTest, TestOCR_02) { + const char *argv[] = {"./benchmark", "--modelPath=./hiai/hiai_cv_focusShootOCRMOdel_02.ms" + "--inDataPath=./hiai/hiai_cv_focusShootOCRMOdel_02.bin" + "--calibDataPath=./hiai/hiai_cv_focusShootOCRMOdel_02.txt"}; + auto status = RunBenchmark(2, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(BenchmarkTest, TestOCR_02_GPU) { +const char *argv[] = {"./benchmark", "--modelPath=./hiai/model_02.ms", + "--inDataPath=./hiai/model_02_in.bin", + "--calibDataPath=./hiai/model_02_out.bin", + "--device=GPU"}; +auto status = RunBenchmark(5, argv); +ASSERT_EQ(status, RET_OK); +} + +TEST_F(BenchmarkTest, TestHebing) { + const char *argv[] = {"./benchmark", "--modelPath=./hiai/model_hebing_3branch.ms" + "--inDataPath=./hiai/model_hebing_3branch.bin" + "--calibDataPath=./hiai/model_hebing_3branch.txt"}; + auto status = RunBenchmark(2, argv); + ASSERT_EQ(status, RET_OK); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/test/st/converter_test.cc b/mindspore/lite/test/st/converter_test.cc new file mode 100644 index 0000000000..25754f6da4 --- /dev/null +++ b/mindspore/lite/test/st/converter_test.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "tools/converter/converter.h" +#include "common/common_test.h" + +namespace mindspore { +namespace lite { +class ConverterTest : public mindspore::Common { + public: + ConverterTest() {} +}; + +TEST_F(ConverterTest, TestLenet) { + const char *argv[] = {"./converter", "--fmk=MS", "--modelFile=./common/lenet_bin.pb", + "--outputFile=./models/lenet_bin"}; + auto status = RunConverter(4, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(ConverterTest, TestVideo) { + const char *argv[] = {"./converter", "--fmk=TFLITE", "--modelFile=./hiai/hiai_label_and_video.tflite", + "--outputFile=./models/hiai_label_and_video"}; + auto status = RunConverter(4, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(ConverterTest, TestOCR_02) { + const char *argv[] = {"./converter", "--fmk=TFLITE", "--modelFile=./hiai/hiai_cv_focusShootOCRMOdel_02.tflite", + "--outputFile=./models/hiai_cv_focusShootOCRMOdel_02"}; + auto status = RunConverter(4, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(ConverterTest, TestHebing) { + const char *argv[] = {"./converter", "--fmk=CAFFE", "--modelFile=./hiai/model_hebing_3branch.prototxt", + "--weightFile=./models/model_hebing_3branch.caffemodel", + "--outputFile=./models/model_hebing_3branch"}; + auto status = RunConverter(5, argv); + ASSERT_EQ(status, RET_OK); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc b/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc new file mode 100644 index 0000000000..a2b6bd5f89 --- /dev/null +++ b/mindspore/lite/test/ut/src/dataset/de_tensor_test.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "gtest/gtest.h" +#include "./securec.h" +#include "dataset/core/tensor.h" +#include "dataset/core/cv_tensor.h" +#include "dataset/core/data_type.h" +#include "mindspore/lite/src/ir/tensor.h" + +using MSTensor = mindspore::tensor::MSTensor; +using DETensor = mindspore::tensor::DETensor; +using LiteTensor = mindspore::lite::tensor::LiteTensor; +using Tensor = mindspore::dataset::Tensor; +using DataType = mindspore::dataset::DataType; +using TensorShape = mindspore::dataset::TensorShape; + +class MindDataTestTensorDE : public mindspore::Common { + public: + MindDataTestTensorDE() {} +}; + +TEST_F(MindDataTestTensorDE, MSTensorBasic) { + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT32)); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + ASSERT_EQ(t == std::dynamic_pointer_cast(ms_tensor)->tensor(), true); +} + +TEST_F(MindDataTestTensorDE, MSTensorConvertToLiteTensor) { + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT32)); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + std::shared_ptr lite_ms_tensor = std::shared_ptr( + std::dynamic_pointer_cast(ms_tensor)->ConvertToLiteTensor()); + // check if the lite_ms_tensor is the derived LiteTensor + LiteTensor * lite_tensor = static_cast(lite_ms_tensor.get()); + ASSERT_EQ(lite_tensor != nullptr, true); +} + +TEST_F(MindDataTestTensorDE, MSTensorShape) { + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT32)); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + ASSERT_EQ(ms_tensor->DimensionSize(0) == 2, true); + ASSERT_EQ(ms_tensor->DimensionSize(1) == 3, true); + ms_tensor->set_shape(std::vector{3, 2}); + ASSERT_EQ(ms_tensor->DimensionSize(0) == 3, true); + ASSERT_EQ(ms_tensor->DimensionSize(1) == 2, true); + ms_tensor->set_shape(std::vector{6}); + ASSERT_EQ(ms_tensor->DimensionSize(0) == 6, true); +} + +TEST_F(MindDataTestTensorDE, MSTensorSize) { + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT32)); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + ASSERT_EQ(ms_tensor->ElementsNum() == 6, true); + ASSERT_EQ(ms_tensor->Size() == 24, true); +} + +TEST_F(MindDataTestTensorDE, MSTensorDataType) { + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT32)); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + ASSERT_EQ(ms_tensor->data_type() == mindspore::TypeId::kNumberTypeFloat32, true); + ms_tensor->set_data_type(mindspore::TypeId::kNumberTypeInt32); + ASSERT_EQ(ms_tensor->data_type() == mindspore::TypeId::kNumberTypeInt32, true); + ASSERT_EQ(std::dynamic_pointer_cast(ms_tensor)->tensor()->type() == DataType::DE_INT32, true); +} + +TEST_F(MindDataTestTensorDE, MSTensorMutableData) { + std::vector x = {2.5, 2.5, 2.5, 2.5}; + std::shared_ptr t; + Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + float *data = static_cast(ms_tensor->MutableData()); + std::vector tensor_vec(data, data + ms_tensor->ElementsNum()); + ASSERT_EQ(x == tensor_vec, true); +} + +TEST_F(MindDataTestTensorDE, MSTensorHash) { + std::vector x = {2.5, 2.5, 2.5, 2.5}; + std::shared_ptr t; + Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); + auto ms_tensor = std::shared_ptr(new DETensor(t)); + ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); +} diff --git a/mindspore/lite/test/ut/src/dataset/eager_test.cc b/mindspore/lite/test/ut/src/dataset/eager_test.cc new file mode 100644 index 0000000000..ffc271a981 --- /dev/null +++ b/mindspore/lite/test/ut/src/dataset/eager_test.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/common_test.h" +#include "gtest/gtest.h" +#include "./securec.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/execute.h" +#include "minddata/dataset/util/path.h" + +using MSTensor = mindspore::tensor::MSTensor; +using DETensor = mindspore::tensor::DETensor; +using mindspore::dataset::api::vision::Decode; +using mindspore::dataset::api::vision::Normalize; +using mindspore::dataset::api::vision::Resize; +using Execute = mindspore::dataset::api::Execute; +using Path = mindspore::dataset::Path; + +class MindDataTestEager : public mindspore::Common { + public: + MindDataTestEager() {} +}; + +TEST_F(MindDataTestEager, Test1) { +#if defined(ENABLE_ARM64) || defined(ENABLE_ARM32) + std::string in_dir = "/sdcard/data/testPK/data/class1"; +#else + std::string in_dir = "data/testPK/data/class1"; +#endif + Path base_dir = Path(in_dir); + MS_LOG(WARNING) << base_dir.toString() << "."; + if (!base_dir.IsDirectory() || !base_dir.Exists()) { + MS_LOG(INFO) << "Input dir is not a directory or doesn't exist" << "."; + } + auto t_start = std::chrono::high_resolution_clock::now(); + // check if output_dir exists and create it if it does not exist + + // iterate over in dir and create json for all images + auto dir_it = Path::DirIterator::OpenDirectory(&base_dir); + while (dir_it->hasNext()) { + Path v = dir_it->next(); + MS_LOG(WARNING) << v.toString() << "."; + std::shared_ptr image = std::shared_ptr(DETensor::CreateTensor(v.toString())); + + image = Execute(Decode())(image); + EXPECT_TRUE(image != nullptr); + image = Execute(Normalize({121.0, 115.0, 100.0}, {70.0, 68.0, 71.0}))(image); + EXPECT_TRUE(image != nullptr); + image = Execute(Resize({224, 224}))(image); + EXPECT_TRUE(image != nullptr); + EXPECT_EQ(image->DimensionSize(0), 224); + EXPECT_EQ(image->DimensionSize(1), 224); + } + auto t_end = std::chrono::high_resolution_clock::now(); + double elapsed_time_ms = std::chrono::duration(t_end-t_start).count(); + MS_LOG(INFO) << "duration: " << elapsed_time_ms << " ms\n"; +} diff --git a/mindspore/lite/test/ut/src/graph_test.cc b/mindspore/lite/test/ut/src/graph_test.cc new file mode 100644 index 0000000000..c1fa4cbd8e --- /dev/null +++ b/mindspore/lite/test/ut/src/graph_test.cc @@ -0,0 +1,246 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/common_test.h" +#include "mindspore/core/utils/log_adapter.h" +#include "mindspore/lite/include/lite_session.h" +#include "mindspore/lite/src/executor.h" +#include "mindspore/lite/schema/inner/anf_ir_generated.h" + +namespace mindspore { +class TestLiteInference : public mindspore::Common { + public: + TestLiteInference() {} +}; + +std::string RealPath(const char *path) { + if (path == nullptr) { + return ""; + } + if ((strlen(path)) >= PATH_MAX) { + return ""; + } + + std::shared_ptr resolvedPath(new (std::nothrow) char[PATH_MAX]{0}); + if (resolvedPath == nullptr) { + return ""; + } + + auto ret = realpath(path, resolvedPath.get()); + if (ret == nullptr) { + return ""; + } + return resolvedPath.get(); +} + +char *ReadModelFile(const char *file, size_t *size) { + if (file == nullptr) { + return nullptr; + } + MS_ASSERT(size != nullptr); + std::ifstream ifs(RealPath(file)); + if (!ifs.good()) { + return nullptr; + } + + if (!ifs.is_open()) { + return nullptr; + } + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + std::unique_ptr buf(new (std::nothrow) char[*size]); + if (buf == nullptr) { + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf.get(), *size); + ifs.close(); + + return buf.release(); +} + +// TEST_F(TestLiteInference, Net) { +// auto msGraph = std::make_shared(); +// msGraph->name = "graph"; +// auto msSubgraph = std::make_unique(); +// msSubgraph->name = "subGraph"; +// +// auto node = std::make_unique(); +// node->inputIndex = {0, 1}; +// node->outputIndex = {2}; +// node->attr.type = lite::OpT_Add; +// node->attr.value = new lite::AddT; +// node->name = "Add"; +// node->fmkType = lite::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(node)); +// +// msSubgraph->inputIndex = {0}; +// msSubgraph->outputIndex = {2}; +// +// auto input0 = std::make_unique(); +// input0->refCount = lite::MSCONST_WEIGHT_REFCOUNT; +// input0->format = lite::Format_NCHW; +// input0->dataType = TypeId::kNumberTypeFloat; +// input0->dims = {1, 1, 2, 2}; +// input0->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(input0)); +// +// auto input1 = std::make_unique(); +// input1->refCount = lite::MSCONST_WEIGHT_REFCOUNT; +// input1->format = lite::Format_NCHW; +// input1->dataType = TypeId::kNumberTypeFloat; +// input1->dims = {1, 1, 2, 2}; +// input1->offset = -1; +// input1->data.resize(16); +// msSubgraph->allTensors.emplace_back(std::move(input1)); +// +// auto output = std::make_unique(); +// output->refCount = 0; +// output->format = lite::Format_NCHW; +// output->dims = {1, 1, 2, 2}; +// output->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(output)); +// msGraph->subgraphs.emplace_back(std::move(msSubgraph)); +// +// flatbuffers::FlatBufferBuilder builder(1024); +// auto offset = lite::GraphDef::Pack(builder, msGraph.get()); +// builder.Finish(offset); +// int size = builder.GetSize(); +// auto *content = builder.GetBufferPointer(); +// mindspore::lite::Context context; +// context.allocator = nullptr; +// context.deviceCtx.type = mindspore::lite::DeviceType::DT_CPU; +// #if 0 +// auto graph = mindspore::lite::inference::LoadModel((char *)content, size); +// +// auto session = mindspore::lite::inference::Session::CreateSession(&context); +// +// std::vector z1 = {1.1, 2.1, 3.1, 4.1}; +// std::vector inputs; +// auto t1 = inference::MSTensor::CreateTensor(TypeId::kNumberTypeFloat32, std::vector({1, 1, 2, 2})); +// memcpy_s(t1->MutableData(), z1.size() * sizeof(float), z1.data(), z1.size() * sizeof(float)); +// +// auto t2 = inference::MSTensor::CreateTensor(TypeId::kNumberTypeFloat32, std::vector({1, 1, 2, 2})); +// memcpy_s(t2->MutableData(), z1.size() * sizeof(float), z1.data(), z1.size() * sizeof(float)); +// +// inputs.push_back(t1); +// inputs.push_back(t1); +// // VectorRef *outputs = new VectorRef(); +// auto outputs = session->RunGraph(inputs); +// #else +// auto file = "./efficientnet_b0.ms"; +// size_t model_size; +// +// char *modelbuf = ReadModelFile(file, &model_size); +// auto graph = mindspore::lite::inference::LoadModel(modelbuf, model_size); +// auto session = mindspore::lite::inference::Session::CreateSession(&context); +// session->CompileGraph(graph); +// std::vector inputs; +// auto t1 = inference::MSTensor::CreateTensor(TypeId::kNumberTypeFloat32, std::vector({1, 244, 244, 3})); +// +// inputs.push_back(t1); +// auto outputs = session->RunGraph(inputs); +// #endif +// } + +// TEST_F(TestLiteInference, Conv) { +// auto msGraph = std::make_shared(); +// msGraph->name = "graph"; +// auto msSubgraph = std::make_unique(); +// msSubgraph->name = "subGraph"; +// +// auto node = std::make_unique(); +// node->inputIndex = {0, 1}; +// node->outputIndex = {2}; +// node->attr.type = lite::OpT_Conv2D; +// auto attr = new lite::Conv2DT; +// attr->padMode = lite::PadMode_SAME; +// attr->channelIn = 1; +// attr->channelOut = 1; +// attr->format = lite::Format_NHWC; +// attr->strideH = 1; +// attr->strideW = 1; +// attr->kernelH = 2; +// attr->kernelW = 2; +// +// node->attr.value = attr; +// node->name = "Conv2D"; +// node->fmkType = lite::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(node)); +// +// msSubgraph->inputIndex = {0}; +// msSubgraph->outputIndex = {2}; +// // MS_LOG(ERROR) << "OutData"; +// +// auto input0 = std::make_unique(); +// input0->refCount = lite::MSCONST_WEIGHT_REFCOUNT; +// input0->format = lite::Format_NCHW; +// input0->dataType = TypeId::kNumberTypeFloat; +// input0->dims = {1, 1, 5, 5}; +// // input0->data.resize(sizeof(float) * 25); +// // std::vector input_data = {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; +// // memcpy(input0->data.data(), input_data.data(), sizeof(int) * 25); +// input0->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(input0)); +// +// auto weight = std::make_unique(); +// weight->refCount = lite::MSCONST_WEIGHT_REFCOUNT; +// weight->format = lite::Format_KHWC; +// weight->dataType = TypeId::kNumberTypeFloat; +// weight->dims = {1, 2, 2, 1}; +// weight->data.resize(sizeof(float) * 4); +// std::vector weight_data = {1, 2, 3, 4}; +// memcpy(weight->data.data(), weight_data.data(), sizeof(int) * 4); +// weight->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(weight)); +// +// auto output = std::make_unique(); +// output->refCount = 0; +// output->format = lite::Format_NCHW; +// output->dims = {1, 1, 5, 5}; +// output->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(output)); +// msGraph->subgraphs.emplace_back(std::move(msSubgraph)); +// +// flatbuffers::FlatBufferBuilder builder(1024); +// auto offset = lite::GraphDef::Pack(builder, msGraph.get()); +// builder.Finish(offset); +// int size = builder.GetSize(); +// auto *content = builder.GetBufferPointer(); +// mindspore::lite::Context context; +// context.allocator = nullptr; +// context.deviceCtx.type = mindspore::lite::DeviceType::DT_CPU; +// auto graph = mindspore::lite::inference::LoadModel((char *)content, size); +// auto session = mindspore::lite::inference::Session::CreateSession(&context); +// session->CompileGraph(graph); +// std::vector inputs; +// auto t1 = inference::MSTensor::CreateTensor(TypeId::kNumberTypeFloat32, std::vector({1, 3, 244, 244})); +// +// inputs.push_back(t1); +// auto outputs = session->RunGraph(inputs); +// } + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/infer_test.cc b/mindspore/lite/test/ut/src/infer_test.cc new file mode 100644 index 0000000000..5931d89034 --- /dev/null +++ b/mindspore/lite/test/ut/src/infer_test.cc @@ -0,0 +1,409 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/schema/inner/model_generated.h" +#include "mindspore/lite/include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +class InferTest : public mindspore::Common { + public: + InferTest() {} +}; + +TEST_F(InferTest, TestConvNode) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + + auto node = std::make_unique(); + node->inputIndex = {0, 1}; + node->outputIndex = {2}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Conv2D; + auto primitive = new schema::Conv2DT; + primitive->padMode = schema::PadMode_SAME; + primitive->channelIn = 3; + primitive->channelOut = 32; + primitive->format = schema::Format_NHWC; + primitive->strideH = 1; + primitive->strideW = 1; + primitive->kernelH = 3; + primitive->kernelW = 3; + primitive->dilateH = 1; + primitive->dilateW = 1; + node->primitive->value.value = primitive; + node->name = "Conv2D"; + meta_graph->nodes.emplace_back(std::move(node)); + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {2}; + + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 28, 28, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + auto weight = std::make_unique(); + weight->nodeType = schema::NodeType::NodeType_ValueNode; + weight->format = schema::Format_KHWC; + weight->dataType = TypeId::kNumberTypeFloat32; + weight->dims = {32, 3, 3, 3}; + + auto buf = new char *[1]; + //================================================================ + size_t weight_size; + std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin"; + ReadFile(weight_path.c_str(), &weight_size, buf); + ASSERT_NE(nullptr, buf[0]); + auto weight_data_temp = reinterpret_cast(buf[0]); + ASSERT_NE(nullptr, weight_data_temp); + weight->data.resize(sizeof(float) * 32 * 3 * 3 * 3); + + //================================================================ + memcpy(weight->data.data(), weight_data_temp, weight_size); + weight->offset = -1; + meta_graph->allTensors.emplace_back(std::move(weight)); + + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 28, 28, 32}; + output->offset = -1; + meta_graph->allTensors.emplace_back(std::move(output)); + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); + builder.Finish(offset); + size_t size = builder.GetSize(); + const char *content = reinterpret_cast(builder.GetBufferPointer()); + + auto model = lite::Model::Import(content, size); + ASSERT_NE(nullptr, model); + meta_graph.reset(); + content = nullptr; + auto context = new lite::Context; + context->cpu_bind_mode_ = lite::NO_BIND; + context->device_ctx_.type = lite::DT_CPU; + context->thread_num_ = 4; + auto session = session::LiteSession::CreateSession(context); + ASSERT_NE(nullptr, session); + auto ret = session->CompileGraph(model.get()); + ASSERT_EQ(lite::RET_OK, ret); + auto inputs = session->GetInputs(); + ASSERT_EQ(inputs.size(), 1); + auto inTensor = inputs.front(); + ASSERT_NE(nullptr, inTensor); + auto data = inTensor->MutableData(); + //=================================================== + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; + ReadFile(input_path.c_str(), &input_size, buf); + ASSERT_NE(nullptr, buf[0]); + auto input_data = reinterpret_cast(buf[0]); + ASSERT_NE(nullptr, input_data); + //=================================================== + ASSERT_EQ(input_size, inTensor->Size()); + memcpy(data, input_data, input_size); + ret = session->RunGraph(); + ASSERT_EQ(lite::RET_OK, ret); + auto outputs = session->GetOutputs(); + ASSERT_EQ(outputs.size(), 1); + auto outTensor = outputs.front(); + ASSERT_NE(nullptr, outTensor); + ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum()); + ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); + auto *outData = reinterpret_cast(outTensor->MutableData()); + ASSERT_NE(nullptr, outData); + //=================================================== + size_t output_size; + std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin"; + ReadFile(output_path.c_str(), &output_size, buf); + ASSERT_NE(nullptr, buf[0]); + auto output_data = reinterpret_cast(buf[0]); + ASSERT_NE(nullptr, output_data); + //=================================================== + ASSERT_EQ(output_size, outTensor->Size()); + for (size_t i = 0; i < outTensor->ElementsNum(); i++) { + ASSERT_LE((output_data[i]- outData[i]), 0.001); + } + MS_LOG(INFO) << "Passed"; +} +TEST_F(InferTest, TestAddNode) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + + auto node = std::make_unique(); + node->inputIndex = {0, 1}; + node->outputIndex = {2}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Add; + auto primitive = new schema::AddT; + node->primitive->value.value = primitive; + node->name = "Add"; + meta_graph->nodes.emplace_back(std::move(node)); + meta_graph->inputIndex = {0, 1}; + meta_graph->outputIndex = {2}; + + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 28, 28, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + auto weight = std::make_unique(); + weight->nodeType = schema::NodeType::NodeType_ValueNode; + weight->format = schema::Format_KHWC; + weight->dataType = TypeId::kNumberTypeFloat32; + weight->dims = {1, 28, 28, 3}; + + weight->offset = -1; + meta_graph->allTensors.emplace_back(std::move(weight)); + + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->offset = -1; + meta_graph->allTensors.emplace_back(std::move(output)); + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); + builder.Finish(offset); + size_t size = builder.GetSize(); + const char *content = reinterpret_cast(builder.GetBufferPointer()); + + auto model = lite::Model::Import(content, size); + ASSERT_NE(nullptr, model); + meta_graph.reset(); + content = nullptr; + auto context = new lite::Context; + context->cpu_bind_mode_ = lite::NO_BIND; + context->device_ctx_.type = lite::DT_GPU; + context->thread_num_ = 4; + auto session = session::LiteSession::CreateSession(context); + ASSERT_NE(nullptr, session); + auto ret = session->CompileGraph(model.get()); + ASSERT_EQ(lite::RET_OK, ret); + auto inputs = session->GetInputs(); + ASSERT_EQ(inputs.size(), 2); + auto inTensor = inputs.front(); + ASSERT_NE(nullptr, inTensor); + (void)inTensor->MutableData(); + auto inTensor1 = inputs.back(); + ASSERT_NE(nullptr, inTensor1); + (void)inTensor1->MutableData(); + ret = session->RunGraph(); + ASSERT_EQ(lite::RET_OK, ret); + auto outputs = session->GetOutputs(); + ASSERT_EQ(outputs.size(), 1); + auto outTensor = outputs.front(); + ASSERT_NE(nullptr, outTensor); + ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum()); + ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); + auto *outData = reinterpret_cast(outTensor->MutableData()); + ASSERT_NE(nullptr, outData); + // //=================================================== + // size_t output_size; + // std::string output_path = "./convfp32_out_1_28_28_32.bin"; + // ReadFile(output_path.c_str(), &output_size, buf); + // ASSERT_NE(nullptr, buf[0]); + // auto output_data = reinterpret_cast(buf[0]); + // ASSERT_NE(nullptr, output_data); + // //=================================================== + // ASSERT_EQ(output_size, outTensor->Size()); + // for (size_t i = 0; i < outTensor->ElementsNum(); i++) { + // ASSERT_EQ(output_data[i], outData[i]); + // } + MS_LOG(INFO) << "Passed"; +} + +TEST_F(InferTest, TestModel) { + auto buf = new char *[1]; + size_t model_size; + std::string model_path = "./model.ms"; + ReadFile(model_path.c_str(), &model_size, buf); + ASSERT_NE(nullptr, buf[0]); + + auto model = lite::Model::Import(buf[0], model_size); + ASSERT_NE(nullptr, model); + delete[] buf[0]; + auto context = new lite::Context; + context->cpu_bind_mode_ = lite::NO_BIND; + context->device_ctx_.type = lite::DT_CPU; + context->thread_num_ = 4; + auto session = session::LiteSession::CreateSession(context); + ASSERT_NE(nullptr, session); + auto ret = session->CompileGraph(model.get()); + ASSERT_EQ(lite::RET_OK, ret); + auto inputs = session->GetInputs(); + ASSERT_EQ(inputs.size(), 1); + auto inTensor = inputs.front(); + ASSERT_NE(nullptr, inTensor); + (void)inTensor->MutableData(); + ret = session->RunGraph(); + ASSERT_EQ(lite::RET_OK, ret); + auto outputs = session->GetOutputs(); + MS_LOG(INFO) << "Passed"; +} + +// TEST_F(TrainTest, TestMultiNode) { +// auto msGraph = std::make_shared(); +// msGraph->name = "graph"; +// auto msSubgraph = std::make_unique(); +// msSubgraph->name = "subGraph"; +// +// auto conv = std::make_unique(); +// conv->inputIndex = {0, 1}; +// conv->outputIndex = {2}; +// conv->attr.type = schema::OpT_Conv2D; +// auto conv_attr = new schema::Conv2DT; +// conv_attr->padMode = schema::PadMode_SAME; +// conv_attr->format = schema::Format_NHWC; +// conv_attr->strideH = 1; +// conv_attr->strideW = 1; +// conv_attr->kernelH = 3; +// conv_attr->kernelW = 3; +// conv_attr->dilateH = 1; +// conv_attr->dilateW = 1; +// +// conv->attr.value = conv_attr; +// conv->name = "Conv2D"; +// conv->fmkType = schema::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(conv)); +// +// auto matMul1 = std::make_unique(); +// matMul1->inputIndex = {2, 3}; +// matMul1->outputIndex = {4}; +// matMul1->attr.type = schema::OpT_MatMul; +// auto matMul_attr1 = new schema::MatMulT; +// matMul_attr1->transposeA = false; +// matMul_attr1->transposeB = true; +// matMul1->attr.value = matMul_attr1; +// matMul1->name = "matmul1"; +// matMul1->fmkType = schema::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(matMul1)); +// +// auto matMul2 = std::make_unique(); +// matMul2->inputIndex = {4, 5}; +// matMul2->outputIndex = {6}; +// matMul2->attr.type = schema::OpT_MatMul; +// auto matMul_attr2 = new schema::MatMulT; +// matMul_attr2->transposeA = false; +// matMul_attr2->transposeB = true; +// matMul2->attr.value = matMul_attr2; +// matMul2->name = "matmul2"; +// matMul2->fmkType = schema::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(matMul2)); +// +// msSubgraph->inputIndex = {0}; +// msSubgraph->outputIndex = {6}; +// +// auto input0 = std::make_unique(); +// input0->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// input0->format = schema::Format_NHWC; +// input0->dataType = TypeId::kNumberTypeFloat32; +// input0->dims = {1, 5, 5, 3}; +// input0->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(input0)); +// +// auto conv_weight = std::make_unique(); +// conv_weight->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// conv_weight->format = schema::Format_KHWC; +// conv_weight->dataType = TypeId::kNumberTypeFloat32; +// conv_weight->dims = {8, 3, 3, 3}; +// conv_weight->data.resize(8*3*3*3*sizeof(float)); +// msSubgraph->allTensors.emplace_back(std::move(conv_weight)); +// +// auto conv_output = std::make_unique(); +// conv_output->refCount = 0; +// conv_output->format = schema::Format_NHWC; +// conv_output->dataType = TypeId::kNumberTypeFloat32; +// conv_output->dims = {1, 5, 5, 8}; +// msSubgraph->allTensors.emplace_back(std::move(conv_output)); +// +// auto add_weight = std::make_unique(); +// add_weight->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// add_weight->format = schema::Format_NHWC; +// add_weight->dataType = TypeId::kNumberTypeFloat32; +// add_weight->dims = {1, 5, 5, 8}; +// add_weight->data.resize(5*5*8*sizeof(float)); +// msSubgraph->allTensors.emplace_back(std::move(add_weight)); +// +// auto add_output = std::make_unique(); +// add_output->refCount = 0; +// add_output->format = schema::Format_NHWC; +// add_output->dataType = TypeId::kNumberTypeFloat32; +// add_output->dims = {1, 5, 5, 8}; +// msSubgraph->allTensors.emplace_back(std::move(add_output)); +// +// auto mul_weight = std::make_unique(); +// mul_weight->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// mul_weight->format = schema::Format_NHWC; +// mul_weight->dataType = TypeId::kNumberTypeFloat32; +// mul_weight->dims = {1, 5, 5, 8}; +// mul_weight->data.resize(5*5*8*sizeof(float)); +// msSubgraph->allTensors.emplace_back(std::move(mul_weight)); +// +// auto mul_output = std::make_unique(); +// mul_output->refCount = 0; +// mul_output->format = schema::Format_NHWC; +// mul_output->dataType = TypeId::kNumberTypeFloat32; +// mul_output->dims = {1, 5, 5, 8}; +// msSubgraph->allTensors.emplace_back(std::move(mul_output)); +// msGraph->subgraphs.emplace_back(std::move(msSubgraph)); +// +// flatbuffers::FlatBufferBuilder builder(1024); +// auto offset = schema::GraphDef::Pack(builder, msGraph.get()); +// builder.Finish(offset); +// size_t size = builder.GetSize(); +// const char *content = (char *)builder.GetBufferPointer(); +// const std::string strstub = ""; +// +// auto func_graph = inference::LoadModel(content, size, strstub); +// ASSERT_NE(nullptr, func_graph); +// auto session = inference::MSSession::CreateSession(kCPUDevice, 0); +// ASSERT_NE(nullptr, session); +// auto graphId = session->CompileGraph(func_graph); +// +// auto inTensor = +// std::shared_ptr(inference::MSTensor::CreateTensor(TypeId::kNumberTypeFloat32, {1, 5, 5, 3})); +// ASSERT_NE(nullptr, inTensor); +// ASSERT_EQ(sizeof(float) * (5 * 5 * 3), inTensor->Size()); +// (void)inTensor->MutableData(); +// +// std::vector> inputs; +// inputs.emplace_back(inTensor); +// auto outputs = session->RunGraph(graphId, inputs); +// ASSERT_EQ(1, outputs.size()); +// ASSERT_EQ(1, outputs.front().size()); +// auto runOutput = outputs.front().front(); +// ASSERT_NE(nullptr, runOutput); +// ASSERT_EQ(5 * 5 * 8, runOutput->ElementsNum()); +// ASSERT_EQ(TypeId::kNumberTypeFloat32, runOutput->data_type()); +// MS_LOG(INFO) << "Passed"; +//} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc new file mode 100644 index 0000000000..c3283f335f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc @@ -0,0 +1,302 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h" + +namespace mindspore { +class TestPack : public mindspore::Common { + public: + TestPack() {} +}; + +void InitConvParamPack(ConvParameter *conv_param) { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 28; + conv_param->input_w_ = 28; + conv_param->input_channel_ = 3; + + conv_param->output_batch_ = 1; + conv_param->output_h_ = 28; + conv_param->output_w_ = 28; + conv_param->output_channel_ = 32; + + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; +} + +TEST_F(TestPack, PackInputFp32) { + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + auto conv_param = new ConvParameter; + InitConvParamPack(conv_param); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int thread_count = 1; + int tile_n = 8; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + + int inchannel_block = 4; + int channel_block = UP_DIV(in_channel, inchannel_block); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * channel_block * inchannel_block; + int packed_input_size = output_tile_count * tile_n * unit_size; + + auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float))); + memset(packed_input, 0, in_batch * packed_input_size * sizeof(float)); + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = 0; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - tile_n) : tile_n; + float *gemm_input = + reinterpret_cast(packed_input) + thread_id * unit_size * tile_n + gemm_in_batch_offset; + Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + } + } + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << packed_input[i] << " ,"; + } + std::cout << std::endl; + + std::string file_path = "./test_data/conv/convfp32_packinput.txt"; + // mindspore::lite::WriteToTxt(file_path, packed_data, in_batch * packed_input_size); + + delete input_data; + delete conv_param; + free(packed_input); + MS_LOG(INFO) << "TestPackInputFp32 passed"; +} + +TEST_F(TestPack, PackWeightFp32) { + auto conv_param = new ConvParameter; + InitConvParamPack(conv_param); + + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int oc8 = UP_DIV(out_channel, C8NUM); + + size_t weight_size; + std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float))); + PackWeightFp32(weight_data, conv_param, packed_weight, C8NUM, oc8); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << packed_weight[i] << " ,"; + } + std::cout << std::endl; + + free(packed_weight); + delete conv_param; + + MS_LOG(INFO) << "TestPackWeightFp32 passed"; +} + +#ifdef ENABLE_FP16 +TEST_F(TestPack, PackInputFp16) { + // todo + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + int input_ele_size = input_size / sizeof(float); + auto fp16_input_data = new float16_t[input_ele_size]; + for (int i = 0; i < input_ele_size; i++) { + fp16_input_data[i] = (float16_t)input_data[i]; + } + + auto conv_param = new ConvParameter; + InitConvParamPack(conv_param); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int thread_count = 1; + int tile_n = 16; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + + int inchannel_block = 8; + int channel_block = UP_DIV(in_channel, inchannel_block); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * channel_block * inchannel_block; + int packed_input_size = output_tile_count * tile_n * unit_size; + + auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); + memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t)); + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = 0; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - tile_n) : tile_n; + float16_t *gemm_input = + reinterpret_cast(packed_input) + thread_id * unit_size * tile_n + gemm_in_batch_offset; + Im2ColPackUnitFp16(fp16_input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + } + } + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << packed_input[i] << " ,"; + } + std::cout << std::endl; + + delete input_data; + delete[] fp16_input_data; + delete conv_param; + delete packed_input; + MS_LOG(INFO) << "TestPackInputFp16 passed"; +} +#endif + +TEST_F(TestPack, PackInputUint8) { + auto conv_param = new ConvParameter; + InitConvParamPack(conv_param); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int thread_count = 1; + int tile_n = 8; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + + int inchannel_block = 4; + int channel_block = UP_DIV(in_channel, inchannel_block); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * channel_block * inchannel_block; + int packed_input_size = output_tile_count * tile_n * unit_size; + + // input + size_t input_size; + std::string input_path = "./test_data/conv/convuint8_input_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + auto int8_input = reinterpret_cast(malloc(input_size)); + for (int i = 0; i < input_size; i++) { + int8_input[i] = (int8_t)(input_data[i] - 128); + } + auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size)); + memset(packed_input, 0, in_batch * packed_input_size); + int32_t *input_sum = reinterpret_cast(malloc(tile_n * thread_count * sizeof(int32_t))); + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = 0; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - tile_n) : tile_n; + int8_t *gemm_input = + reinterpret_cast(packed_input) + thread_id * unit_size * tile_n + gemm_in_batch_offset; + memset(input_sum, 0, tile_n * thread_count * sizeof(int32_t)); + Im2ColPackUnitInt8(int8_input + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); + } + } + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << static_cast(packed_input[i]) << " ,"; + } + std::cout << std::endl; + + delete input_data; + delete conv_param; + free(int8_input); + free(packed_input); + free(input_sum); + MS_LOG(INFO) << "TestPackInputUint8 passed"; +} + +TEST_F(TestPack, PackWeightUint8) { + auto conv_param = new ConvParameter; + InitConvParamPack(conv_param); + + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int oc4 = UP_DIV(out_channel, C4NUM); + + size_t weight_size; + std::string weight_path = "./test_data/conv/convuint8_weight_32_3_3_3.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + auto int8_weight = reinterpret_cast(malloc(weight_size)); + for (int i = 0; i < weight_size; i++) { + int8_weight[i] = (int8_t)(weight_data[i] - 128); + } + int32_t filter_zp = 20; + + int32_t *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * k_h * k_w; + auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc4 * C4NUM)); + PackWeightInt8(int8_weight, conv_param, packed_weight, weight_sum); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << static_cast(packed_weight[i]) << " ,"; + } + std::cout << std::endl; + + free(weight_sum); + free(int8_weight); + free(packed_weight); + delete conv_param; + + MS_LOG(INFO) << "TestPackWeightUint8 passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc new file mode 100644 index 0000000000..16ea9e1f75 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/strided_slice.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestStridedSlice : public mindspore::Common { + public: + TestStridedSlice() {} +}; + +void InitStridedSliceParam(StridedSliceParameter *strided_slice_param) { + strided_slice_param->begins_[0] = 0; + strided_slice_param->begins_[1] = 0; + strided_slice_param->begins_[2] = 0; + + strided_slice_param->ends_[0] = 1; + strided_slice_param->ends_[1] = 2; + strided_slice_param->ends_[2] = 4; + + strided_slice_param->strides_[0] = 1; + strided_slice_param->strides_[1] = 2; + strided_slice_param->strides_[2] = 2; + + strided_slice_param->in_shape_[0] = 1; + strided_slice_param->in_shape_[1] = 2; + strided_slice_param->in_shape_[2] = 4; + strided_slice_param->num_axes_ = 3; +} + +TEST_F(TestStridedSlice, StridedSlice) { + lite::tensor::Tensor in_tensor(kNumberTypeFloat32, {1, 2, 4}); + lite::tensor::Tensor out_tensor(kNumberTypeFloat32, {1, 1, 2}); + float input_data[] = {0.2390374, 0.92039955, 0.05051243, 0.49574447, 0.8355223, 0.02647042, 0.08811307, 0.4566604}; + float output_data[2] = {0}; + in_tensor.SetData(input_data); + out_tensor.SetData(output_data); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor}; + + StridedSliceParameter parameter = {0}; + InitStridedSliceParam(¶meter); + parameter.op_parameter_.type_ = schema::PrimitiveType_StridedSlice; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_StridedSlice}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + float expect[2] = {0.2390374, 0.05051243}; + CompareOutputData(output_data, expect, 2, 0.000001); + + in_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); +} + +TEST_F(TestStridedSlice, StridedSliceInt8) { + lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 3, 4}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {2, 3, 4}); + int8_t input_data[] = {-12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + int8_t output_data[4] = {0}; + in_tensor.SetData(input_data); + out_tensor.SetData(output_data); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor}; + + StridedSliceParameter parameter = {0}; + parameter.begins_[0] = 0; + parameter.begins_[1] = 1; + parameter.begins_[2] = 2; + parameter.ends_[0] = 2; + parameter.ends_[1] = 3; + parameter.ends_[2] = 4; + parameter.strides_[0] = 1; + parameter.strides_[1] = 2; + parameter.strides_[2] = 1; + parameter.in_shape_[0] = 2; + parameter.in_shape_[1] = 3; + parameter.in_shape_[2] = 4; + parameter.num_axes_ = 3; + + parameter.op_parameter_.type_ = schema::PrimitiveType_StridedSlice; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_StridedSlice}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect[4] = {-6, -5, 7, 8}; + for (int i = 0; i < sizeof(expect); ++i) { + EXPECT_EQ(output_data[i], expect[i]); + } + + in_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc new file mode 100644 index 0000000000..552240f730 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc @@ -0,0 +1,593 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/utils.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" + +namespace mindspore { +class TestConvolutionFp16 : public mindspore::Common { + public: + TestConvolutionFp16() {} +}; + +void InitConvParamGroup1Fp16(ConvParameter *conv_param) { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 28; + conv_param->input_w_ = 28; + conv_param->input_channel_ = 3; + + conv_param->output_batch_ = 1; + conv_param->output_h_ = 28; + conv_param->output_w_ = 28; + conv_param->output_channel_ = 32; + + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + conv_param->thread_num_ = 1; +} + +void InitConvParamGroup2Fp16(ConvParameter *conv_param) { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 128; + conv_param->input_w_ = 128; + conv_param->input_channel_ = 32; + + conv_param->output_batch_ = 1; + conv_param->output_h_ = 128; + conv_param->output_w_ = 128; + conv_param->output_channel_ = 32; + + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + conv_param->thread_num_ = 1; +} + +TEST_F(TestConvolutionFp16, ConvTest1) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup1Fp16(conv_param); + + int tile_num = 16; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int kernel_plane = k_h * k_w; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int i_h = conv_param->input_h_; + int i_w = conv_param->input_w_; + int out_channel = conv_param->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int oc8 = UP_DIV(out_channel, C8NUM); + + size_t weight_size; + std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + std::cout << "==============fp32 weight data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << weight_data[i] << ", "; + } + std::cout << std::endl; + + std::cout << "weight data size: " << weight_size / sizeof(float) << std::endl; + + int weight_ele_size = weight_size / sizeof(float); + auto fp16_weight_data = new float16_t[weight_ele_size]; + for (int i = 0; i < weight_ele_size; i++) { + fp16_weight_data[i] = static_cast(weight_data[i]); + } + + std::cout << "==============fp16 weight data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << fp16_weight_data[i] << ", "; + } + std::cout << std::endl; + + auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float16_t))); + PackWeightFp16(fp16_weight_data, conv_param, packed_weight); + + std::cout << "==============fp16 packed weight data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::cout << "==============fp32 input data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << input_data[i] << ", "; + } + std::cout << std::endl; + + int input_ele_size = input_size / sizeof(float); + auto fp16_input_data = new float16_t[input_ele_size]; + for (int i = 0; i < input_ele_size; i++) { + fp16_input_data[i] = static_cast(input_data[i]); + } + + auto nhwc4_input_data = reinterpret_cast(malloc(i_h * i_w * ic4 * C4NUM* sizeof(float16_t))); + PackNHWCToNHWC4Fp32(fp16_input_data, nhwc4_input_data, 1, i_h * i_w, in_channel); + + std::cout << "==============fp16 input data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << fp16_input_data[i] << ", "; + } + std::cout << std::endl; + + int output_count = conv_param->output_h_ * conv_param->output_w_; + int output_tile_count = UP_DIV(output_count, tile_num); + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_num * unit_size; + auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); + memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t)); + + auto bias_data = reinterpret_cast(malloc(conv_param->output_channel_ * sizeof(float16_t))); + memset(bias_data, 0, conv_param->output_channel_ * sizeof(float16_t)); + + size_t output_data_size = + conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + auto output_data = new float16_t[output_data_size]; + auto tmp_output_block = reinterpret_cast(malloc(tile_num * out_channel * sizeof(float16_t))); + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + // warmup + for (int i = 0; i < 3; i++) { + ConvFp16(nhwc4_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + ConvFp16(nhwc4_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::cout << "==============fp16 output data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << ", "; + } + std::cout << std::endl; + + auto fp32_output_data = new float[output_data_size]; + for (int i = 0; i < output_data_size; i++) { + fp32_output_data[i] = static_cast(output_data[i]); + } + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << fp32_output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin"; + lite::CompareOutput(fp32_output_data, output_path); + + free(nhwc4_input_data); + free(packed_input); + free(bias_data); + free(packed_weight); + free(tmp_output_block); + delete conv_param; + delete input_data; + delete weight_data; + delete[] fp16_weight_data; + delete[] fp16_input_data; + delete[] fp32_output_data; + delete[] output_data; + MS_LOG(INFO) << "TestConvolutionFp16 passed"; +} + +TEST_F(TestConvolutionFp16, ConvTest2) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup2Fp16(conv_param); + + // parameter + int tile_num = 16; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int kernel_plane = k_h * k_w; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int oc8 = UP_DIV(out_channel, C8NUM); + + // weight + size_t weight_size; + std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_32.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + int weight_ele_size = weight_size / sizeof(float); + auto fp16_weight_data = new float16_t[weight_ele_size]; + for (int i = 0; i < weight_ele_size; i++) { + fp16_weight_data[i] = static_cast(weight_data[i]); + } + auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float16_t))); + PackWeightFp16(fp16_weight_data, conv_param, packed_weight); + + // input + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_128_128_32.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + int input_ele_size = input_size / sizeof(float); + auto fp16_input_data = new float16_t[input_ele_size]; + for (int i = 0; i < input_ele_size; i++) { + fp16_input_data[i] = static_cast(input_data[i]); + } + int output_count = conv_param->output_h_ * conv_param->output_w_; + int output_tile_count = UP_DIV(output_count, tile_num); + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_num * unit_size; + auto packed_input = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); + memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t)); + + // bias + auto bias_data = reinterpret_cast(malloc(conv_param->output_channel_ * sizeof(float16_t))); + memset(bias_data, 0, conv_param->output_channel_ * sizeof(float16_t)); + + // output + auto tmp_output_block = reinterpret_cast(malloc(tile_num * out_channel * sizeof(float16_t))); + size_t output_data_size = + conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + auto output_data = new float16_t[output_data_size]; + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + // warmup + for (int i = 0; i < 3; i++) { + ConvFp16(fp16_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + ConvFp16(fp16_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::cout << "==============fp16 output data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << ", "; + } + std::cout << std::endl; + + auto fp32_output_data = new float[output_data_size]; + for (int i = 0; i < output_data_size; i++) { + fp32_output_data[i] = static_cast(output_data[i]); + } + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << fp32_output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/conv/convfp32_out_1_128_128_32.bin"; + lite::CompareOutput(fp32_output_data, output_path); + + free(packed_input); + free(bias_data); + free(packed_weight); + free(tmp_output_block); + delete conv_param; + delete input_data; + delete weight_data; + delete[] fp16_weight_data; + delete[] fp16_input_data; + delete[] fp32_output_data; + delete[] output_data; + MS_LOG(INFO) << "TestConvolutionFp16 passed"; +} + +TEST_F(TestConvolutionFp16, Conv3x3Test1) { + auto conv_param = new ConvParameter(); + InitConvParamGroup1Fp16(conv_param); + // todo + int thread_count = 1; + int tile_num = 16; + int output_batch = conv_param->output_batch_; + int output_h = conv_param->output_h_; + int output_w = conv_param->output_w_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + + // tmp buffer + int k_plane = 36; + size_t tile_buffer_size = thread_count * tile_num * k_plane * ic4 * C4NUM * sizeof(float16_t); + float16_t *tile_buffer = reinterpret_cast(malloc(tile_buffer_size)); + memset(tile_buffer, 0, tile_buffer_size); + + size_t block_unit_buffer_size = thread_count * k_plane * C4NUM * sizeof(float16_t); + float16_t *block_unit_buffer = reinterpret_cast(malloc(block_unit_buffer_size)); + memset(block_unit_buffer, 0, block_unit_buffer_size); + + size_t tmp_dst_buffer_size = thread_count * tile_num * k_plane * oc8 * C8NUM * sizeof(float16_t); + float16_t *tmp_dst_buffer = reinterpret_cast(malloc(tmp_dst_buffer_size)); + memset(tmp_dst_buffer, 0, tmp_dst_buffer_size); + + size_t tmp_out_size = oc8 * C8NUM * output_batch * output_h * output_w * tile_num * sizeof(float16_t); + float16_t *tmp_out = reinterpret_cast(malloc(tmp_out_size)); + memset(tmp_out, 0, tmp_out_size); + + // weight + size_t weight_size; + std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + std::cout << "==============fp32 weight data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << weight_data[i] << ", "; + } + std::cout << std::endl; + + std::cout << "weight data size: " << weight_size / sizeof(float) << std::endl; + + int weight_ele_size = weight_size / sizeof(float); + auto fp16_weight_data = new float16_t[weight_ele_size]; + for (int i = 0; i < weight_ele_size; i++) { + fp16_weight_data[i] = (float16_t)weight_data[i]; + } + + std::cout << "==============fp16 weight data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << fp16_weight_data[i] << ", "; + } + std::cout << std::endl; + + size_t transformed_size = ic4 * C4NUM * oc8 * C8NUM * 36; + auto transformed_weight_data = new float16_t[transformed_size]; + memset(transformed_weight_data, 0, transformed_size * sizeof(float16_t)); + kernel::ProcessFilterFp16(fp16_weight_data, transformed_weight_data, conv_param); + + // bias + auto bias_data = + reinterpret_cast(malloc(UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t))); + memset(bias_data, 0, UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t)); + + // input + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::cout << "==============fp32 input data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << input_data[i] << ", "; + } + std::cout << std::endl; + + int input_ele_size = input_size / sizeof(float); + auto fp16_input_data = new float16_t[input_ele_size]; + for (int i = 0; i < input_ele_size; i++) { + fp16_input_data[i] = static_cast(input_data[i]); + } + + std::cout << "==============fp16 input data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << fp16_input_data[i] << ", "; + } + std::cout << std::endl; + + // output + size_t output_data_size = + conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + auto output_data = new float16_t[output_data_size]; + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + // warmup + for (int i = 0; i < 3; i++) { + Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, + tmp_dst_buffer, tmp_out, 0, conv_param); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, + tmp_dst_buffer, tmp_out, 0, conv_param); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::cout << "==============fp16 output data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << ", "; + } + std::cout << std::endl; + + auto fp32_output_data = new float[output_data_size]; + for (int i = 0; i < output_data_size; i++) { + fp32_output_data[i] = static_cast(output_data[i]); + } + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << fp32_output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin"; + lite::CompareOutput(fp32_output_data, output_path); + + free(bias_data); + free(tile_buffer); + free(block_unit_buffer); + free(tmp_dst_buffer); + free(tmp_out); + delete input_data; + delete weight_data; + delete conv_param; + delete[] fp16_weight_data; + delete[] fp16_input_data; + delete[] fp32_output_data; + delete[] output_data; + delete[] transformed_weight_data; + MS_LOG(INFO) << "TestConvolutionFp16 Conv3x3 passed"; +} + +TEST_F(TestConvolutionFp16, Conv3x3Test2) { + auto conv_param = new ConvParameter(); + InitConvParamGroup2Fp16(conv_param); + // todo + int thread_count = 1; + int tile_num = 16; + int output_batch = conv_param->output_batch_; + int output_h = conv_param->output_h_; + int output_w = conv_param->output_w_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + + // tmp buffer + int k_plane = 36; + size_t tile_buffer_size = thread_count * tile_num * k_plane * ic4 * C4NUM * sizeof(float16_t); + float16_t *tile_buffer = reinterpret_cast(malloc(tile_buffer_size)); + memset(tile_buffer, 0, tile_buffer_size); + + size_t block_unit_buffer_size = thread_count * k_plane * C4NUM * sizeof(float16_t); + float16_t *block_unit_buffer = reinterpret_cast(malloc(block_unit_buffer_size)); + memset(block_unit_buffer, 0, block_unit_buffer_size); + + size_t tmp_dst_buffer_size = thread_count * tile_num * k_plane * oc8 * C8NUM * sizeof(float16_t); + float16_t *tmp_dst_buffer = reinterpret_cast(malloc(tmp_dst_buffer_size)); + memset(tmp_dst_buffer, 0, tmp_dst_buffer_size); + + size_t tmp_out_size = oc8 * C8NUM * output_batch * output_h * output_w * tile_num * sizeof(float16_t); + float16_t *tmp_out = reinterpret_cast(malloc(tmp_out_size)); + memset(tmp_out, 0, tmp_out_size); + + // weight + size_t weight_size; + std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_32.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + int weight_ele_size = weight_size / sizeof(float); + auto fp16_weight_data = new float16_t[weight_ele_size]; + for (int i = 0; i < weight_ele_size; i++) { + fp16_weight_data[i] = static_cast(weight_data[i]); + } + size_t transformed_size = ic4 * C4NUM * oc8 * C8NUM * 36; + auto transformed_weight_data = new float16_t[transformed_size]; + memset(transformed_weight_data, 0, transformed_size * sizeof(float16_t)); + kernel::ProcessFilterFp16(fp16_weight_data, transformed_weight_data, conv_param); + + // bias + auto bias_data = + reinterpret_cast(malloc(UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t))); + memset(bias_data, 0, UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t)); + + // input + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_input_1_128_128_32.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + int input_ele_size = input_size / sizeof(float); + auto fp16_input_data = new float16_t[input_ele_size]; + for (int i = 0; i < input_ele_size; i++) { + fp16_input_data[i] = static_cast(input_data[i]); + } + + // output + size_t output_data_size = + conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + auto output_data = new float16_t[output_data_size]; + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + // warmup + for (int i = 0; i < 3; i++) { + Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, + tmp_dst_buffer, tmp_out, 0, conv_param); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer, + tmp_dst_buffer, tmp_out, 0, conv_param); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::cout << "==============fp16 output data===========" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << ", "; + } + std::cout << std::endl; + + auto fp32_output_data = new float[output_data_size]; + for (int i = 0; i < output_data_size; i++) { + fp32_output_data[i] = static_cast(output_data[i]); + } + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << fp32_output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/conv/convfp32_out_1_128_128_32.bin"; + lite::CompareOutput(fp32_output_data, output_path); + + free(bias_data); + free(tile_buffer); + free(block_unit_buffer); + free(tmp_dst_buffer); + free(tmp_out); + delete input_data; + delete weight_data; + delete conv_param; + delete[] fp16_weight_data; + delete[] fp16_input_data; + delete[] fp32_output_data; + delete[] output_data; + delete[] transformed_weight_data; + MS_LOG(INFO) << "TestConvolutionFp16 Conv3x3 passed"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc new file mode 100644 index 0000000000..0fdfe5f29e --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { + +class TestActivationFp32 : public mindspore::Common { + public: + TestActivationFp32() {} +}; + +TEST_F(TestActivationFp32, ReluFp32) { + float input[8] = {-3, -2, -1, 0, 1, 5, 6, 7}; + float output[8] = {0}; + Relu(input, 8, output); + float expect[8] = {0, 0, 0, 0, 1, 5, 6, 7}; + for (int i = 0; i < 8; ++i) { + ASSERT_EQ(output[i], expect[i]); + } +} + +TEST_F(TestActivationFp32, Relu6Fp32) { + float input[8] = {-3, -2, -1, 0, 1, 5, 6, 7}; + float output[8] = {0}; + Relu6(input, 8, output); + float expect[8] = {0, 0, 0, 0, 1, 5, 6, 6}; + for (int i = 0; i < 8; ++i) { + ASSERT_EQ(output[i], expect[i]); + } + MS_LOG(INFO) << "TestActivationFp32 passed"; +} + +TEST_F(TestActivationFp32, LReluFp32) { + float input[8] = {-3, -2, -1, 0, 1, 5, 6, 7}; + float output[8] = {0}; + LRelu(input, 8, output, 0.01); + float expect[8] = {-0.03, -0.02, -0.01, 0, 1, 5, 6, 7}; + for (int i = 0; i < 8; ++i) { + ASSERT_EQ(output[i], expect[i]); + } + MS_LOG(INFO) << "TestActivationFp32 passed"; +} + +TEST_F(TestActivationFp32, SigmoidFp32) { + float input[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + float output[8] = {0}; + Sigmoid(input, 8, output); + + // expect output {0.5, 0.731059, 0.880797, 0.952574, 0.982014, 0.993307, 0.997527, 0.999089}; + printf("==================output data=================\n"); + for (int i = 0; i < 8; ++i) { + std::cout << output[i] << " "; + } + std::cout << std::endl; + MS_LOG(INFO) << "TestSigmoidFp32 passed"; +} + +TEST_F(TestActivationFp32, TanhFp32) { + float input[7] = {-3, -2, -1, 0, 1, 2, 3}; + float output[7] = {0}; + Tanh(input, 7, output); + float expect[8] = {-0.995055, -0.964028, -0.761594, 0.000000, 0.761594, 0.964028, 0.995055}; + for (int i = 0; i < 8; ++i) { + EXPECT_NEAR(output[i], expect[i], 0.00001); + } + MS_LOG(INFO) << "TanhFp32 passed"; +} + +TEST_F(TestActivationFp32, HSwishFp32) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + ActivationParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Activation; + op_param.type_ = schema::ActivationType_HSWISH; + op_param.alpha_ = 0.01; + + std::vector input = {-3.0, -2.0, -1.0, 0.0, 1.0, 5.0, 6.0, 7.0}; + std::vector in_shape = {8}; + + lite::tensor::Tensor input0_tensor; + inputs_tensor.push_back(&input0_tensor); + input0_tensor.SetData(input.data()); + input0_tensor.set_shape(in_shape); + + std::vector output(8); + std::vector output_shape = {8}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Activation}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 7; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector expect_output = {-0, -0.33333334, -0.33333334, 0, 0.6666667, 5, 6, 7}; + CompareOutputData(output.data(), expect_output.data(), 8, 0.00001); + + input0_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_grad_fp32_tests.cc new file mode 100644 index 0000000000..1badd29a26 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_grad_fp32_tests.cc @@ -0,0 +1,312 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/ir/tensor.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h" + +namespace mindspore { +class TestActGradFp32 : public mindspore::Common { + public: + TestActGradFp32() {} +}; + +TEST_F(TestActGradFp32, ReluGradFp32) { + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = 50; + + size_t input_size; + std::string input_path = "./test_data/activationGrad/relu_y_50.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin"; + auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + ReluGrad(yt_data, input_data, 50, output_data); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + ReluGrad(yt_data, input_data, 50, output_data); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/activationGrad/relu_out_50.bin"; + + int res = lite::CompareRelativeOutput(output_data, output_path); + + EXPECT_EQ(res, 0); + + delete input_data; + delete[] output_data; + delete yt_data; + + MS_LOG(INFO) << "ReluGradFp32 passed"; +} + +TEST_F(TestActGradFp32, Relu6GradFp32) { + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = 50; + + size_t input_size; + std::string input_path = "./test_data/activationGrad/relu6_y_50.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::string yt_path = "./test_data/activationGrad/relu6_yt_50.bin"; + auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + Relu6Grad(yt_data, input_data, 50, output_data); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + Relu6Grad(yt_data, input_data, 50, output_data); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/activationGrad/relu6_out_50.bin"; + int res = lite::CompareRelativeOutput(output_data, output_path); + + EXPECT_EQ(res, 0); + + delete input_data; + delete[] output_data; + delete yt_data; + + MS_LOG(INFO) << "Relu6GradFp32 passed"; +} + +TEST_F(TestActGradFp32, LReluGradFp32) { + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = 50; + + size_t input_size; + std::string input_path = "./test_data/activationGrad/lrelu_y_50.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::string yt_path = "./test_data/activationGrad/lrelu_yt_50.bin"; + auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + LReluGrad(yt_data, input_data, 50, output_data, 0.1); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + LReluGrad(yt_data, input_data, 50, output_data, 0.1); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/activationGrad/lrelu_out_50.bin"; + int res = lite::CompareRelativeOutput(output_data, output_path); + + EXPECT_EQ(res, 0); + + delete input_data; + delete[] output_data; + delete yt_data; + + MS_LOG(INFO) << "LReluGradFp32 passed"; +} + +TEST_F(TestActGradFp32, SigmoidGradFp32) { + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = 50; + + size_t input_size; + std::string input_path = "./test_data/activationGrad/sigmoid_y_50.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::string yt_path = "./test_data/activationGrad/sigmoid_yt_50.bin"; + auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + SigmoidGrad(yt_data, input_data, 50, output_data); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + SigmoidGrad(yt_data, input_data, 50, output_data); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/activationGrad/sigmoid_out_50.bin"; + int res = lite::CompareRelativeOutput(output_data, output_path); + + EXPECT_EQ(res, 0); + // lite::CompareOutput(output_data, output_path); + + delete input_data; + delete[] output_data; + delete yt_data; + + MS_LOG(INFO) << "SigmoidGradFp32 passed"; +} + +TEST_F(TestActGradFp32, tanhGradFp32) { + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = 50; + + size_t input_size; + std::string input_path = "./test_data/activationGrad/tanh_y_50.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::string yt_path = "./test_data/activationGrad/tanh_yt_50.bin"; + auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + TanhGrad(yt_data, input_data, 50, output_data); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + TanhGrad(yt_data, input_data, 50, output_data); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/activationGrad/tanh_out_50.bin"; + int res = lite::CompareRelativeOutput(output_data, output_path); + + EXPECT_EQ(res, 0); + + delete input_data; + delete[] output_data; + delete yt_data; + MS_LOG(INFO) << "TanhGradFp32 passed"; +} + +TEST_F(TestActGradFp32, hswishGradFp32) { + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = 50; + + size_t input_size; + std::string input_path = "./test_data/activationGrad/hswish_x_50.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::string yt_path = "./test_data/activationGrad/hswish_yt_50.bin"; + auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + HSwishGrad(yt_data, input_data, 50, output_data); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + HSwishGrad(yt_data, input_data, 50, output_data); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/activationGrad/hswish_out_50.bin"; + int res = lite::CompareRelativeOutput(output_data, output_path); + + EXPECT_EQ(res, 0); + + delete input_data; + delete[] output_data; + delete yt_data; + MS_LOG(INFO) << "hswishGradFp32 passed"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc new file mode 100644 index 0000000000..4a56abd804 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc @@ -0,0 +1,328 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arg_min_max.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +namespace mindspore { + +class TestArgMinMaxTestFp32 : public mindspore::Common { + public: + TestArgMinMaxTestFp32() = default; +}; + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {2, 2, 0, 2, 0}; + std::vector shape = {3, 5}; + float out[5]; + ArgMinMaxParameter param; + param.topk_ = 1; + param.out_value_ = false; + param.axis_ = 0; + param.data_type_ = 43; + param.dims_size_ = 2; + param.get_max_ = true; + ArgMinMax(in.data(), out, shape.data(), ¶m); + for (size_t i = 0; i < except_out.size(); ++i) { + std::cout << out[i] << " "; + } + std::cout << "\n"; + CompareOutputData(out, except_out.data(), except_out.size(), 0.000001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest2) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {30, 45, 30, 50, 90}; + std::vector shape = {3, 5}; + float out[5]; + ArgMinMaxParameter param; + param.topk_ = 1; + param.out_value_ = true; + param.axis_ = 0; + param.data_type_ = 43; + param.dims_size_ = 2; + param.get_max_ = true; + ArgMinMax(in.data(), out, shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.000001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMinTest2) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {10, 11, 15, 1, 30}; + std::vector shape = {3, 5}; + float out[5]; + ArgMinMaxParameter param; + param.topk_ = 1; + param.out_value_ = true; + param.axis_ = 0; + param.data_type_ = 43; + param.dims_size_ = 2; + param.get_max_ = false; + ArgMinMax(in.data(), out, shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.000001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest3_axis2_out_data) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {30, 45, 30, 50, 90, 20, 20, 25, 40, 50}; + ArgMinMaxParameter param; + param.axis_ = 2; + std::vector in_shape = {1, 1, 3, 5}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = true; + param.topk_ = 2; + std::vector out_shape = {1, 1, 2, 5}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[10]; + ArgMaxDim2(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest3_axis2_out_index) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {2, 2, 0, 2, 0, 1, 0, 2, 0, 1}; + ArgMinMaxParameter param; + param.axis_ = 2; + std::vector in_shape = {1, 1, 3, 5}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = false; + param.topk_ = 2; + std::vector out_shape = {1, 1, 2, 5}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[10]; + ArgMaxDim2(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest4_axis3_out_data) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {90, 40, + 50, 20, + 50, 45}; + ArgMinMaxParameter param; + param.axis_ = 3; + std::vector in_shape = {1, 1, 3, 5}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = true; + param.topk_ = 2; + std::vector out_shape = {1, 1, 3, 2}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[6]; + ArgMaxDim3(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest4_axis3_out_index) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {4, 3, + 4, 0, + 3, 1}; + ArgMinMaxParameter param; + param.axis_ = 3; + std::vector in_shape = {1, 1, 3, 5}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = false; + param.topk_ = 2; + std::vector out_shape = {1, 1, 3, 2}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[6]; + ArgMaxDim3(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest5_axis1_out_index) { + std::vector in = {100, 2, 300, + 4, 50, 6, + 11, 12, 13, + 34, 35, 36, + 9, 6, 17, + 10, 20, 30, + 10, 20, 30, + 40, 5, 60, + 7, 80, 90, + 10, 11, 120, + 18, 5, 16, + 9, 22, 23}; + std::vector except_out = {0, 1, 0, + 1, 0, 1, + 1, 2, 2, + 2, 1, 2, + 2, 1, 1, + 0, 2, 1, + 0, 0, 0, + 1, 1, 0}; + ArgMinMaxParameter param; + param.axis_ = 1; + std::vector in_shape = {2, 3, 2, 3}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = false; + param.topk_ = 2; + std::vector out_shape = {2, 2, 2, 3}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[24]; + ArgMaxDim1(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest5_axis1_out_data) { + std::vector in = {100, 2, 300, + 4, 50, 6, + 11, 12, 13, + 34, 35, 36, + 9, 6, 17, + 10, 20, 30, + 10, 20, 30, + 40, 5, 60, + 7, 80, 90, + 10, 11, 120, + 18, 5, 16, + 9, 22, 23}; + std::vector except_out = {100, 12, 300, + 34, 50, 36, + 11, 6, 17, + 10, 35, 30, + 18, 80, 90, + 40, 22, 120, + 10, 20, 30, + 10, 11, 60}; + ArgMinMaxParameter param; + param.axis_ = 1; + std::vector in_shape = {2, 3, 2, 3}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = true; + param.topk_ = 2; + std::vector out_shape = {2, 2, 2, 3}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[24]; + ArgMaxDim1(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest6_axis0_out_index) { + std::vector in = {100, 2, + 4, 50, + 11, 12, + 34, 35, + 10, 20, + 40, 5, + 7, 80, + 10, 11, + 55, 25, + 5, 15, + 18, 8, + 15, 16}; + std::vector except_out = {0, 2, + 1, 0, + 2, 1, + 0, 0, + 2, 1, + 2, 2, + 0, 0, + 2, 2}; + ArgMinMaxParameter param; + param.axis_ = 1; + std::vector in_shape = {3, 2, 2, 2}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = false; + param.topk_ = 2; + std::vector out_shape = {2, 2, 2, 2}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[16]; + ArgMaxDim0(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMaxTest6_axis0_out_data) { + std::vector in = {100, 2, + 4, 50, + 11, 12, + 34, 35, + 10, 20, + 40, 5, + 7, 80, + 10, 11, + 55, 25, + 5, 15, + 18, 8, + 15, 16}; + std::vector except_out = {100, 25, + 40, 50, + 18, 80, + 34, 35, + 55, 20, + 5, 15, + 11, 12, + 15, 16}; + ArgMinMaxParameter param; + param.axis_ = 1; + std::vector in_shape = {3, 2, 2, 2}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = true; + param.topk_ = 2; + std::vector out_shape = {2, 2, 2, 2}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[16]; + ArgMaxDim0(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +TEST_F(TestArgMinMaxTestFp32, ArgMinTest1_axis3_out_data) { + std::vector in = {10, 20, 30, 40, 90, + 20, 11, 15, 1, 50, + 30, 45, 25, 50, 30}; + std::vector except_out = {10, 20, + 1, 11, + 25, 30}; + ArgMinMaxParameter param; + param.axis_ = 3; + std::vector in_shape = {1, 1, 3, 5}; + param.arg_elements_ = reinterpret_cast(malloc(in_shape[param.axis_] * sizeof(ArgElement))); + param.out_value_ = true; + param.topk_ = 2; + std::vector out_shape = {1, 1, 3, 2}; + ComputeStrides(in_shape.data(), param.in_strides_, in_shape.size()); + ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); + float out[6]; + ArgMinDim3(in.data(), out, in_shape.data(), ¶m); + CompareOutputData(out, except_out.data(), except_out.size(), 0.00001); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc new file mode 100644 index 0000000000..fe4a3b7cf2 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc @@ -0,0 +1,497 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { + +class TestArithmeticGradFp32 : public mindspore::Common { + public: + TestArithmeticGradFp32() {} +}; + +std::vector GenerateTensorsForTest(const char *test, int test_id) { + size_t input_size; + std::vector large_dim({4, 6}); + std::vector small_dim({6}); + int large_size = (4 * 6); + int small_size = (1 * 6); + char *dx1_file = const_cast("./test_data/operators/arithmetic_fp32_1_x1_4_6.bin"); + char *dx2_file = const_cast("./test_data/operators/arithmetic_fp32_1_x2_1_6.bin"); + + if (test_id == 7) { + large_dim = std::vector({4, 5, 6}); + small_dim = std::vector({6}); + large_size = (4 * 5 * 6); + small_size = (6); + dx1_file = const_cast("./test_data/operators/arithmetic_fp32_7_x1_4_5_6.bin"); + dx2_file = const_cast("./test_data/operators/arithmetic_fp32_7_x2_1_1_6.bin"); + } + if (test_id >= 8) { + large_dim = std::vector({5, 4, 6}); + small_dim = std::vector({5, 1, 6}); + large_size = (4 * 5 * 6); + small_size = (5 * 6); + dx1_file = const_cast("./test_data/operators/arithmetic_fp32_8_x1_5_4_6.bin"); + dx2_file = const_cast("./test_data/operators/arithmetic_fp32_8_x2_5_1_6.bin"); + } + + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(test, &input_size)); + lite::tensor::Tensor *dy_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, large_dim); + dy_tensor->SetData(dy_data); + + auto x1_data = reinterpret_cast(mindspore::lite::ReadFile(dx1_file, &input_size)); + lite::tensor::Tensor *x1_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, large_dim); + x1_tensor->SetData(x1_data); + + auto x2_data = reinterpret_cast(mindspore::lite::ReadFile(dx2_file, &input_size)); + lite::tensor::Tensor *x2_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, small_dim); + x2_tensor->SetData(x2_data); + + auto dx1_data = new float[large_size]; + lite::tensor::Tensor *dx1_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, large_dim); + dx1_tensor->SetData(dx1_data); + + auto dx2_data = new float[small_size]; + lite::tensor::Tensor *dx2_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, small_dim); + dx2_tensor->SetData(dx2_data); + + std::vector ret_vector = {dy_tensor, x1_tensor, x2_tensor, dx1_tensor, dx2_tensor}; + return ret_vector; +} + +TEST_F(TestArithmeticGradFp32, TestAddGradFp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_AddGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_1_dy_4_6.bin", 1); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_1_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestAddGradFp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestAddGrad2Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_AddGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_1_dy_4_6.bin", 1); + + std::vector inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; + std::vector outputs = {all_tensors[4], all_tensors[3]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[0]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_1_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[1]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestAddGrad2Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestAddGrad3Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_AddGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_8_dy_5_4_6.bin", 8); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[0]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_8_dx2_5_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[1]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestAddGrad3Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestSubGradFp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_SubGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_2_dy_4_6.bin", 2); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_2_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_2_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestSubGradFp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestSubGrad2Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_SubGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_3_dy_4_6.bin", 3); + + std::vector inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; + std::vector outputs = {all_tensors[4], all_tensors[3]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[0]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_3_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[1]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_3_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestSubGrad2Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestMulGradFp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_MulGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_4_dy_4_6.bin", 4); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + + int loop_count = 1000; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel_obj->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + printf("total cost (for %d loops): %lu us\n", loop_count, cost); + // auto time_avg = cost / loop_count; + // printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_4_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestMulGradFp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestMulGrad2Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_MulGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_4_dy_4_6.bin", 4); + + std::vector inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; + std::vector outputs = {all_tensors[4], all_tensors[3]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[0]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_4_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[1]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestMulGrad2Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestMulGrad3Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_MulGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin", 9); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestMulGrad3Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestMulGrad4Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_MulGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin", 9); + + std::vector inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; + std::vector outputs = {all_tensors[4], all_tensors[3]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[0]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[1]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestMulGrad4Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestDivGradFp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_DivGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_5_dy_4_6.bin", 5); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string output_path = "./test_data/operators/arithmetic_fp32_5_dx1_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), output_path)); + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_5_dx2_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestDivGradFp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestDivGrad2Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_DivGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_6_dy_4_6.bin", 6); + + std::vector inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; + std::vector outputs = {all_tensors[4], all_tensors[3]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[0]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string dx2_path = "./test_data/operators/arithmetic_fp32_6_dx2_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[1]->Data()), dx2_path)); + + std::string output_path = "./test_data/operators/arithmetic_fp32_6_dx1_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestDivGrad2Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, TestDivGrad3Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_DivGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_10_dy_5_4_6.bin", 10); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string dx1_path = "./test_data/operators/arithmetic_fp32_10_dx1_5_4_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), dx1_path)); + + std::string output_path = "./test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestDivGrad3Fp32 passed"; +} + +TEST_F(TestArithmeticGradFp32, Test3DDivGrad2Fp32) { + auto param = new ArithmeticParameter(); + param->op_parameter_.type_ = PrimitiveType_DivGrad; + std::vector all_tensors = + GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_7_dy_4_5_6.bin", 7); + + std::vector inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; + std::vector outputs = {all_tensors[3], all_tensors[4]}; + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(param), NULL, desc); + kernel_obj->Run(); + + float *output_ptr = reinterpret_cast(outputs[1]->Data()); + printf("==================output data=================\n"); + for (int i = 0; i < 6; i++) { + std::cout << output_ptr[i] << " ,"; + } + std::cout << std::endl; + + std::string dx1_path = "./test_data/operators/arithmetic_fp32_7_dx1_4_5_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast(outputs[0]->Data()), dx1_path)); + + std::string output_path = "./test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin"; + EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); + + for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete param; + MS_LOG(INFO) << "TestDivGrad2Fp32 passed"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc new file mode 100644 index 0000000000..e5b75c5ed4 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc @@ -0,0 +1,197 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/batch_to_space.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +namespace mindspore { + +class BatchToSpaceTestFp32 : public mindspore::Common { + public: + BatchToSpaceTestFp32() = default; +}; + + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest1) { + float input[12] = {10, 30, 90, 2, 20, 120, 5, 50, 150, 6, 16, 160}; + constexpr int kOutSize = 12; + float expect_out[kOutSize] = {10, 30, 90, 2, 20, 120, 5, 50, 150, 6, 16, 160}; + + float output[kOutSize]; + int in_shape[4] = {4, 1, 1, 3}; + int out_n = 1; + int block[2] = {2, 2}; + BatchToSpaceNoCropForNHWC(input, output, in_shape, out_n, block, sizeof(float)); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest_crop_1) { + float input[12] = {10, 30, 90, 2, 20, 120, 5, 50, 150, 6, 16, 160}; + constexpr int kOutSize = 3; + float expect_out[kOutSize] = {5, 50, 150}; + + float output[kOutSize]; + int in_shape[4] = {4, 1, 1, 3}; + int out_n = 1; + int block[2] = {2, 2}; + int crops[4] = {1, 0, 0, 1}; + BatchToSpaceForNHWC(input, output, in_shape, out_n, block, crops, sizeof(float)); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest2) { + float input[32] = {1, 10, 3, 30, 9, 90, 11, 110, 2, 20, 4, 40, 10, 100, 12, 120, + 5, 50, 7, 70, 13, 130, 15, 150, 6, 60, 8, 80, 14, 140, 16, 160}; + constexpr int kOutSize = 32; + float expect_out[kOutSize] = {1, 10, 2, 20, 3, 30, 4, 40, 5, 50, 6, 60, 7, 70, 8, 80, + 9, 90, 10, 100, 11, 110, 12, 120, 13, 130, 14, 140, 15, 150, 16, 160}; + + float output[kOutSize]; + int in_shape[4] = {4, 2, 2, 2}; + int out_n = 1; + int block[2] = {2, 2}; + BatchToSpaceNoCropForNHWC(input, output, in_shape, out_n, block, sizeof(float)); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest_crop_2) { + float input[32] = {1, 10, 3, 30, 9, 90, 11, 110, 2, 20, 4, 40, 10, 100, 12, 120, + 5, 50, 7, 70, 13, 130, 15, 150, 6, 60, 8, 80, 14, 140, 16, 160}; + constexpr int kOutSize = 12; + float expect_out[kOutSize] = {6, 60, 7, 70, 8, 80, + 10, 100, 11, 110, 12, 120}; + + float output[kOutSize]; + int in_shape[4] = {4, 2, 2, 2}; + int out_n = 1; + int block[2] = {2, 2}; + int crops[4] = {1, 1, 1, 0}; + BatchToSpaceForNHWC(input, output, in_shape, out_n, block, crops, sizeof(float)); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest3) { + float input[64] = {1, 10, 3, 30, 9, 90, 11, 110, 2, 20, 4, 40, 10, 100, 12, 120, + 5, 50, 7, 70, 13, 130, 15, 150, 6, 60, 8, 80, 14, 140, 16, 160, + 21, 10, 23, 30, 29, 90, 211, 110, 22, 20, 24, 40, 210, 100, 212, 120, + 25, 50, 27, 70, 213, 130, 215, 150, 26, 60, 28, 80, 214, 140, 216, 160}; + constexpr int kOutSize = 64; + float expect_out[kOutSize] = {1, 10, 5, 50, 3, 30, 7, 70, 21, 10, 25, 50, 23, 30, 27, 70, + 9, 90, 13, 130, 11, 110, 15, 150, 29, 90, 213, 130, 211, 110, 215, 150, + 2, 20, 6, 60, 4, 40, 8, 80, 22, 20, 26, 60, 24, 40, 28, 80, + 10, 100, 14, 140, 12, 120, 16, 160, 210, 100, 214, 140, 212, 120, 216, 160}; + + float output[kOutSize]; + int in_shape[4] = {8, 2, 2, 2}; + int out_n = 2; + int block[2] = {2, 2}; + BatchToSpaceNoCropForNHWC(input, output, in_shape, out_n, block, sizeof(float)); + for (int i = 0; i < kOutSize && i < 32; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest_crop_3) { + float input[64] = {1, 10, 3, 30, 9, 90, 11, 110, 2, 20, 4, 40, 10, 100, 12, 120, + 5, 50, 7, 70, 13, 130, 15, 150, 6, 60, 8, 80, 14, 140, 16, 160, + 21, 10, 23, 30, 29, 90, 211, 110, 22, 20, 24, 40, 210, 100, 212, 120, + 25, 50, 27, 70, 213, 130, 215, 150, 26, 60, 28, 80, 214, 140, 216, 160}; + constexpr int kOutSize = 16; + float expect_out[kOutSize] = {9, 90, 13, 130, 29, 90, 213, 130, + 10, 100, 14, 140, 210, 100, 214, 140}; + + float output[kOutSize]; + int in_shape[4] = {8, 2, 2, 2}; + int out_n = 2; + int block[2] = {2, 2}; + int crops[4] = {2, 0, 0, 2}; + BatchToSpaceForNHWC(input, output, in_shape, out_n, block, crops, sizeof(float)); + for (int i = 0; i < kOutSize && i < 32; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest4) { + float input[96] = {1, 10, 3, 30, 9, 90, 11, 110, 2, 20, 4, 40, 10, 100, 12, 120, 5, 50, 7, 70, + 13, 130, 15, 150, 6, 60, 8, 80, 14, 140, 16, 160, 21, 10, 23, 30, 29, 90, 211, 110, + 22, 20, 24, 40, 210, 100, 212, 120, 25, 50, 27, 70, 213, 130, 215, 150, 26, 60, 28, 80, + 214, 140, 216, 160, 31, 10, 33, 30, 39, 90, 311, 110, 32, 20, 34, 40, 310, 100, 312, 120, + 35, 50, 37, 70, 313, 130, 315, 150, 36, 60, 38, 80, 314, 140, 316, 160}; + constexpr int kOutSize = 96; + float expect_out[kOutSize] = { + 1, 10, 5, 50, 3, 30, 7, 70, 21, 10, 25, 50, 23, 30, 27, 70, 31, 10, 35, 50, 33, 30, 37, 70, + 9, 90, 13, 130, 11, 110, 15, 150, 29, 90, 213, 130, 211, 110, 215, 150, 39, 90, 313, 130, 311, 110, 315, 150, + 2, 20, 6, 60, 4, 40, 8, 80, 22, 20, 26, 60, 24, 40, 28, 80, 32, 20, 36, 60, 34, 40, 38, 80, + 10, 100, 14, 140, 12, 120, 16, 160, 210, 100, 214, 140, 212, 120, 216, 160, 310, 100, 314, 140, 312, 120, 316, 160}; + + float output[kOutSize]; + int in_shape[4] = {12, 2, 2, 2}; + int out_n = 2; + int block[2] = {3, 2}; + BatchToSpaceNoCropForNHWC(input, output, in_shape, out_n, block, sizeof(float)); + for (int i = 0; i < kOutSize && i < 32; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(BatchToSpaceTestFp32, BatchToSpaceTest_crop_4) { + float input[96] = {1, 10, 3, 30, 9, 90, 11, 110, 2, 20, 4, 40, 10, 100, 12, 120, 5, 50, 7, 70, + 13, 130, 15, 150, 6, 60, 8, 80, 14, 140, 16, 160, 21, 10, 23, 30, 29, 90, 211, 110, + 22, 20, 24, 40, 210, 100, 212, 120, 25, 50, 27, 70, 213, 130, 215, 150, 26, 60, 28, 80, + 214, 140, 216, 160, 31, 10, 33, 30, 39, 90, 311, 110, 32, 20, 34, 40, 310, 100, 312, 120, + 35, 50, 37, 70, 313, 130, 315, 150, 36, 60, 38, 80, 314, 140, 316, 160}; + constexpr int kOutSize = 24; + float expect_out[kOutSize] = { + 25, 50, 23, 30, 35, 50, 33, 30, + 13, 130, 11, 110, 26, 60, 24, 40, 36, 60, 34, 40, 14, 140, 12, 120}; + + float output[kOutSize]; + int in_shape[4] = {12, 2, 2, 2}; + int out_n = 2; + int block[2] = {3, 2}; + int crops[4] = {1, 2, 1, 1}; + BatchToSpaceForNHWC(input, output, in_shape, out_n, block, crops, sizeof(float)); + for (int i = 0; i < kOutSize && i < 32; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc new file mode 100644 index 0000000000..8e8c191866 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/batchnorm.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fused_batchnorm.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/common/file_utils.h" + +namespace mindspore { + +class TestBatchnormFp32 : public mindspore::Common { + public: + TestBatchnormFp32() {} +}; + +TEST_F(TestBatchnormFp32, BNTest) { + std::vector in_data = {0.0669681, 0.959215, 0.252686, 0.613594, 0.811776, 0.139469, 0.322848, 0.118354, + 0.082978, 0.399467, 0.961267, 0.0247456, 0.0714259, 0.0791484, 0.0648625, 0.561612, + 0.412069, 0.311492, 0.46109, 0.377125, 0.369283, 0.0332446, 0.696142, 0.715973, + 0.525524, 0.477265, 0.0336351, 0.751577, 0.377548, 0.964603, 0.0196834, 0.174865}; + std::vector in_data1 = {0.855446, 0.821765, 0.281008, 0.0798653, 0.22294, 0.793782, 0.963222, 0.17851, + 0.667549, 0.274381, 0.592842, 0.216552, 0.190274, 0.237873, 0.610063, 0.307559, + 0.830007, 0.760957, 0.583265, 0.763793, 0.456372, 0.391378, 0.547915, 0.862198, + 0.510794, 0.826776, 0.515894, 0.30071, 0.404987, 0.184773}; + std::vector in_data2 = {0.712438, 0.4927, 0.078419, 0.310429, 0.546871, 0.0667141, 0.874321, 0.0265647, + 0.685165, 0.732586, 0.952889, 0.506402, 0.540784, 0.131119, 0.357713, 0.678992, + 0.960839, 0.340706, 0.697678, 0.398146, 0.313321, 0.6485, 0.739153, 0.00190134, + 0.536842, 0.996873, 0.445276, 0.371212, 0.420397, 0.0930115}; + std::vector in_data3(32, 1); + std::vector in_data4(32, 0); + std::vector inputs_tensor; + std::vector outputs_tensor; + + BatchNormParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_BatchNorm; + op_param.epsilon_ = 0.001f; + + std::vector in_shape = {1, 2, 4, 4}; + + lite::tensor::Tensor input0_tensor; + lite::tensor::Tensor input1_tensor; + lite::tensor::Tensor input2_tensor; + inputs_tensor.push_back(&input0_tensor); + inputs_tensor.push_back(&input1_tensor); + inputs_tensor.push_back(&input2_tensor); + input0_tensor.SetData(in_data.data()); + input1_tensor.SetData(in_data1.data()); + input2_tensor.SetData(in_data2.data()); + input0_tensor.set_shape(in_shape); + + std::vector output(32); + std::vector corr_out(32); + std::vector output_shape = {1, 2, 4, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_BatchNorm}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + lite::Context ctx; + ctx.thread_num_ = 7; + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + FusedBatchNorm(in_data.data(), in_data3.data(), in_data4.data(), in_data1.data(), in_data2.data(), in_shape.data(), + 0.001f, corr_out.data()); + + printf("==================output data=================\n"); + for (int i = 0; i < 1 * 28; i++) { + std::cout << output[i] << " ,"; + } + std::cout << std::endl; + CompareOutputData(output.data(), corr_out.data(), 32, 0.00001); + + input0_tensor.SetData(nullptr); + input1_tensor.SetData(nullptr); + input2_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc new file mode 100644 index 0000000000..7c26e95022 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { + +class TestBiasGradFp32 : public mindspore::Common { + public: + TestBiasGradFp32() {} +}; + +TEST_F(TestBiasGradFp32, BiasGradFp32) { + // prepare stage + auto bias_param = new ArithmeticParameter(); + + size_t input_size; + std::string input_path = "./test_data/operators/biasgradfp32_1_dy_10_28_28_7.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_dy({10, 28, 28, 7}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(input_data); + + std::vector inputs = {&dy_tensor}; + + auto output_data = new float[7]; + std::vector dim_dw({7}); + lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); + dw_tensor.SetData(output_data); + std::vector outputs = {&dw_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BiasGrad}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(bias_param), NULL, desc); + + kernel_obj->Run(); + + printf("==================output data=================\n"); + for (int i = 0; i < 7; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin"; + lite::CompareOutput(output_data, output_path); + + // delete input_data; + // delete[] output_data; + delete bias_param; + MS_LOG(INFO) << "BiasGradFp32 passed"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc new file mode 100644 index 0000000000..da3c661697 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -0,0 +1,395 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/runtime/kernel/arm/fp32/convolution_1x1.h" +#include "src/runtime/kernel/arm/nnacl/matmul.h" +#include "src/runtime/kernel/arm/nnacl/strassen_matmul.h" + +namespace mindspore { +using mindspore::lite::tensor::Tensor; + +class TestConv1x1Fp32 : public mindspore::Common { + public: + TestConv1x1Fp32() {} +}; + +TEST_F(TestConv1x1Fp32, Input1x1PrePack1) { + auto conv_param = new ConvParameter(); + float in[] = {-0.59, -0.63, -7.26, -0.64, -6.403, 4.87, 9.612, 9.36, 12.84, -0.838, 6.588, 2.02, 13.756, + 15.92, 16.0, -7.82, 9.53, 1.77, 10.521, 13.45, 17.991, 17.063, 4.6859, 13.57, -6.31, 5.27, + 7.54, -7.418, 15.12, 0.6195, 1.5475, -5.925, -7.59, 18.13, 15.8, 19.86, -7.766, 13.25, 7.141, + -0.34, 16.254, -5.78, 16.13, -7.1, 6.259, 10.771, -5.54, 10.477, 9.2366, 12.258, -9.86, -8.29, + -4.9, 18.14, -5.400, 0.829, 7.4575, 12.075, 13.734, 16.51, -9.82, -4.9, 18.44, -0.808, 8.066, + 6.914, 2.5098, 10.985, 16.96, 1.721, -1.0, 2.096, 9.2553, 8.635, 9.2136, 13.558, 7.7505, -0.55, + 15.68, -7.3, 0.429, -0.560, 17.98, 19.068, 9.2764, 17.939, -6.51, -2.04, 7.29, -0.87, 10.311, + -6.74, -6.424, 18.708, -0.368, 9.725, 9.129, 6.99, 3.11, -1.573, -8.25, 10.427, 17.427, -9.739, + 17.32, 6.076, -3.5, 7.43, -2.659, -0.89, -9.157, 1.9951, -3.463, 15.22, 13.99, 4.39, 18.12}; + float correct[] = {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 15.12, -7.59, -7.766, 0.000, + 0.000, 0.429, 9.2764, 7.29, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000}; + + conv_param->input_h_ = 9; + conv_param->input_w_ = 13; + conv_param->input_channel_ = 1; + conv_param->output_h_ = 4; + conv_param->output_w_ = 5; + conv_param->stride_h_ = conv_param->stride_w_ = 4; + conv_param->pad_h_ = conv_param->pad_w_ = 2; + + float out[20] = {0}; + Conv1x1InputPackFp32(in, out, conv_param); + EXPECT_EQ(0, lite::CompareOutputData(out, correct, 20)); + delete conv_param; +} + +TEST_F(TestConv1x1Fp32, Input1x1PrePack2) { + auto conv_param = new ConvParameter(); + float in[] = { + 12.755477, 7.647509, 14.670943, -8.03628, -1.815172, 7.7517915, 5.6838546, 0.9693578, 10.86119, 10.960915, + 17.758, -4.800611, -8.743361, 1.6797531, -0.234721, 7.7575417, 10.19116, 11.744166, -2.674233, 8.977257, + 1.5364298, 14.600166, 16.625568, -4.820712, 10.050005, 4.114301, 10.436717, -7.443196, -2.669484, 5.3399734, + 7.5060234, 12.705402, -2.203446, 19.582493, 8.716431, 11.463841, 2.1704009, -7.740846, 0.6420606, 15.4524, + 1.9975507, -4.6742086, -0.425350, 7.120687, -9.663703, 18.799034, -4.425679, 10.846515, -1.993019, 0.2714671, + -8.511215, 16.797249, 18.438688, 8.391737, 15.632475, 16.98368, -5.901906, -2.718238, -3.131561, -3.707477, + -8.04332, 13.010143, 3.187699, 7.6656003, 9.344805, 2.100789, -7.123898, 10.088698, 7.8578715, -8.320831, + 6.821173, -2.263130, -2.886815, 2.285673, 10.664816, -4.747543, -4.9607406, 1.0546302, 15.628643, 1.7381196, + 18.267065, 11.504781, -0.193673, 16.431538, 8.011203, -3.3506372, 16.546675, -3.983052, 4.8116174, -9.49816, + 11.714877, 12.401133, -3.799531, 5.109032, 11.657709, 1.9226302, 0.9720376, 14.517606, 7.712793, 17.820406, + 17.644344, 15.314725, 17.884249, -3.6718662, -2.053803, 10.629432, 16.67133, -3.929358, 3.3747706, 8.818307, + -0.371532, 18.14205, 5.9272094, 12.691162, 6.816437, 8.310599, 17.566565, 16.581955, -7.433713, 2.5550082, + 9.1433325, -2.9258926, 5.7442937, -2.9434314, -9.864248, -0.122141, 11.5717945, -4.174809, -6.192147, 8.390994, + -7.4617224, 17.419308, 7.0560303, 11.58972, 17.671894, 6.2352304, 13.778206, 3.4766717, -6.687946, -7.887233, + -1.150991, -3.1441534, 17.288366, 13.669407, -4.997481, -6.147624, -5.6006193, -8.15764, 9.595266, 8.296087, + -0.9590447, -3.6464965, -8.155689, 4.8459644, 19.75259, 5.5307946, -6.934994, -9.928046, 4.02548, -9.45412, + 13.605555, 10.22008, -3.067481, 8.114803, 2.4563003, 0.4125615, 6.076172, -1.875376, 19.553644, -9.809106, + 17.235031, -4.222316, -9.534478, 18.639902, 1.7095382, 18.821035, -8.177748, -2.9353676, 2.064462, 12.190292, + -1.475221, -1.842325, -3.664825, 10.538533, -4.255415, 3.4860964, 11.418711, -2.348281, -4.527373, 19.534836}; + float correct[] = {12.755477, -8.03628, 5.6838546, 10.960915, 7.5060234, 19.582493, 2.1704009, + 15.4524, -8.04332, 7.6656003, -7.123898, -8.320831, 11.714877, 5.109032, + 0.9720376, 17.820406, 9.1433325, -2.9434314, 11.5717945, 8.390994, -0.9590447, + 4.8459644, -6.934994, -9.45412, -1.4752215, 10.538533, 11.418711, 19.534836}; + + conv_param->input_h_ = 19; + conv_param->input_w_ = 10; + conv_param->input_channel_ = 1; + conv_param->output_h_ = 7; + conv_param->output_w_ = 4; + conv_param->stride_h_ = conv_param->stride_w_ = 3; + conv_param->pad_h_ = conv_param->pad_w_ = 0; + + float out[28] = {0}; + Conv1x1InputPackFp32(in, out, conv_param); + CompareOutputData(out, correct, 28, 0.0001); + delete conv_param; +} + +TEST_F(TestConv1x1Fp32, Input1x1PrePack3) { + auto conv_param = new ConvParameter(); + conv_param->input_channel_ = 2; + conv_param->input_h_ = conv_param->input_w_ = 3; + conv_param->output_h_ = conv_param->output_w_ = 3; + conv_param->stride_h_ = conv_param->stride_w_ = 2; + conv_param->pad_h_ = conv_param->pad_w_ = 1; + + float in[] = {1.6767339, 12.25904, 19.018835, 3.0790641, -9.252135, -8.685675, 3.6115494, 3.2282279, 17.025112, + -5.052577, 12.750252, 12.701241, -8.9477215, -9.080522, 19.03931, -6.501229, -4.122992, 9.540845}; + float out[18] = {0}; + float correct[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.025112, + -5.052577, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + + Conv1x1InputPackFp32(in, out, conv_param); + EXPECT_EQ(0, lite::CompareOutputData(out, correct, 18)); + delete conv_param; +} + +TEST_F(TestConv1x1Fp32, Input1x1PrePack4) { + auto conv_param = new ConvParameter(); + conv_param->input_channel_ = 6; + conv_param->input_h_ = conv_param->input_w_ = 3; + conv_param->output_h_ = conv_param->output_w_ = 3; + conv_param->stride_h_ = conv_param->stride_w_ = 2; + conv_param->pad_h_ = conv_param->pad_w_ = 1; + float in[] = {4.1795, 13.142, -3.593, 16.505, 19.899, 8.5562, 19.969, -6.235, -2.380, -9.027, 9.5542, + 18.974, 23.622, 8.3608, 47.325, -14.36, 15.370, 4.3049, -0.784, 37.925, -0.081, 6.1298, + 0.6721, -1.517, 37.998, 13.719, 11.029, 1.7127, -1.770, 41.903, 9.0560, 14.988, 3.1866, + 0.0562, 8.1381, 9.1391, 14.530, -14.10, -8.115, -8.071, -8.158, 7.7566, 19.250, 17.923, + 13.584, 3.3293, 9.7341, 18.834, -1.514, -0.293, 18.686, 0.0873, 4.2010, -2.253}; + float correct[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127, + -1.770, 41.903, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + float out[54] = {0}; + Conv1x1InputPackFp32(in, out, conv_param); + EXPECT_EQ(0, lite::CompareOutputData(out, correct, 54)); + delete conv_param; +} + +TEST_F(TestConv1x1Fp32, Conv1x1WeightTest1) { + ConvParameter *conv_param = new ConvParameter(); + float in[] = {0.214637, 0.3815, 0.811557, 0.982146, 0.09123, 0.687198, 0.02742, 0.3360, 0.853275, + 0.674123, 0.81337, 0.57188, 0.706416, 0.2740942, 0.9045, 0.07155, 0.130864, 0.037712, + 0.5369175, 0.97283, 0.92133, 0.3588165, 0.7432479, 0.7886823, 0.870324, 0.230946, 0.343969, + 0.095415, 0.50036, 0.396918, 0.09029, 0.934583, 0.91616, 0.206713, 0.9756054, 0.614025, + 0.432057, 0.1493, 0.6787, 0.10642, 0.736823, 0.377668, 0.2464896, 0.93152, 0.315917, + 0.35745, 0.52233, 0.0263, 0.339392, 0.99447, 0.49129, 0.675686, 0.75703, 0.6665356, + 0.0491, 0.1070, 0.18899, 0.929156, 0.4633427, 0.08585, 0.040709, 0.2478724, 0.5238441, + 0.0579918, 0.531636, 0.085524, 0.640923, 0.336395, 0.218651, 0.630491}; + float co[] = {0.214637, 0.81337, 0.92133, 0.09029, 0.3815, 0.57188, 0.3588165, 0.934583, 0.811557, + 0.706416, 0.7432479, 0.91616, 0.982146, 0.2740942, 0.7886823, 0.206713, 0.09123, 0.9045, + 0.870324, 0.9756054, 0.687198, 0.07155, 0.230946, 0.614025, 0.02742, 0.130864, 0.343969, + 0.432057, 0.3360, 0.037712, 0.095415, 0.1493, 0.853275, 0.5369175, 0.50036, 0.6787, + 0.674123, 0.97283, 0.396918, 0.10642, 0, 0, 0, 0, 0, + 0, 0, 0, 0.736823, 0.49129, 0.040709, 0, 0.377668, 0.675686, + 0.2478724, 0, 0.2464896, 0.75703, 0.5238441, 0, 0.93152, 0.6665356, 0.0579918, + 0, 0.315917, 0.0491, 0.531636, 0, 0.35745, 0.1070, 0.085524, 0, + 0.52233, 0.18899, 0.640923, 0, 0.0263, 0.929156, 0.336395, 0, 0.339392, + 0.4633427, 0.218651, 0, 0.99447, 0.08585, 0.630491, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + + conv_param->input_channel_ = 10; + conv_param->output_channel_ = 7; + float out[96] = {0}; + Pack1x1WeightFp32(in, out, conv_param); + EXPECT_EQ(0, lite::CompareOutputData(out, co, 96)); + delete conv_param; +} + +TEST_F(TestConv1x1Fp32, PostConvFuncC4Test1) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964, + -2.6300175, 0, 0, 0, -7.2690716, 0, 0, 0, + 11.1863365, 0, 0, 0, -3.4595785, 0, 0, 0, + -8.344107, 0, 0, 0, -3.792715, 0, 0, 0, + -7.0394287, 0, 0, 0, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float out[40] = {0}; + + float no[] = {-8.646674, -5.3524485, 8.56133, -1.2702886, -2.6201365, -4.7133026, 1.2270198, 17.954533, + 11.086085, -7.2591906, -0.11849791, -3.9182835, 11.90631, 0.3088621, 11.196218, -4.530405, + -0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027, + -8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942, + -4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402}; + PostConvFuncFp32C4(in, out, bias, 5, 8, 5, false, false); + CompareOutputData(out, no, 40, 0.0001); + + float relu[] = {0, 0, 8.56133, 0, 0, 0, 1.2270198, 17.954533, 11.086085, 0, + 0, 0, 11.90631, 0.3088621, 11.196218, 0, 0, 0, 0, 0, + 0, 0, 0, 9.464027, 0, 14.387108, 8.693133, 8.080041, 0, 0, + 2.8319538, 7.177942, 0, 12.194644, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C4(in, out, bias, 5, 8, 5, true, false); + CompareOutputData(out, relu, 40, 0.0001); + + float corr_relu6[] = {0, 0, 6, 0, 0, 0, 1.2270198, 6, 6, 0, 0, 0, 6, 0.3088621, 6, 0, 0, 0, 0, 0, + 0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C4(in, out, bias, 5, 8, 5, false, true); + CompareOutputData(out, corr_relu6, 40, 0.0001); + + float nob_relu[] = {0, 0, 7.5724425, 0, 0, 0, 0.7406984, 16.965645, + 10.888806, 0, 0, 0, 10.917422, 0.11158327, 11.1863365, 0, + 0, 0, 0, 0, 0, 0, 0, 9.266748, + 0, 13.644127, 8.206812, 7.091153, 0, 0, 2.0889723, 6.6916203, + 0, 11.997365, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C4(in, out, nullptr, 5, 8, 5, true, false); + CompareOutputData(out, nob_relu, 40, 0.0001); +} + +TEST_F(TestConv1x1Fp32, PostConvFuncC4Test2) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806, + -0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815, + -6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964, + -2.6300175, 0, 0, 0, -7.2690716, 0, 0, 0, + 11.1863365, 0, 0, 0, -3.4595785, 0, 0, 0, + -8.344107, 0, 0, 0, -3.792715, 0, 0, 0, + -7.0394287, 0, 0, 0, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float corr[] = {-8.646674, -5.3524485, 8.56133, -1.2702886, -2.6201365, -4.7133026, 1.2270198, 17.954533, + 11.086085, -7.2591906, -0.11849791, -3.9182835, 11.90631, 0.3088621, 11.196218, -4.530405, + -0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027, + -8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942, + -4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402}; + float out[40] = {0}; + + int thread_count_ = 2; + int thread_oc4_stride_ = 1; + int output_channel = 5; + int plane_size = 8; + + for (int i = 0; i < thread_count_; i++) { + int cur_oc = MSMIN(thread_oc4_stride_ * 4, output_channel - i * thread_oc4_stride_ * 4); + if (cur_oc <= 0) break; + PostConvFuncFp32C4(in + thread_oc4_stride_ * i * 8 * 4, out + i * i * thread_oc4_stride_ * 4, + bias + i * thread_oc4_stride_ * 4, cur_oc, plane_size, output_channel, false, false); + } + + CompareOutputData(out, corr, 40, 0.0001); +} + +int Conv1x1TestInit1(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, float **correct) { + lite::tensor::Tensor *in_t = + new lite::tensor::Tensor(kNumberTypeFloat, {1, 2, 3, 4}, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in[] = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, 14.727955, -8.080715, + 13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352, + 6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505}; + memcpy(in_t->Data(), in, sizeof(float) * 24); + inputs_->push_back(in_t); + + lite::tensor::Tensor *weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, {3, 1, 1, 4}, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + float weight[] = {-0.7308652, 0.5257509, -0.87825793, -1.123181, -1.2206168, 0.562695, + 1.5382664, -0.5020635, 0.8591602, -0.26410004, 1.1262615, 0.073132955}; /* nhwc */ + memcpy(weight_t->Data(), weight, sizeof(float) * 12); + inputs_->push_back(weight_t); + + lite::tensor::Tensor *bias_t = + new lite::tensor::Tensor(kNumberTypeFloat, {3}, schema::Format_NHWC, static_cast(1)); + bias_t->MallocData(); + float bias[] = {2, 2, 2}; + memcpy(bias_t->Data(), bias, sizeof(float) * 3); + inputs_->push_back(bias_t); + + lite::tensor::Tensor *out_t = + new lite::tensor::Tensor(kNumberTypeFloat, {1, 2, 3, 3}, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + float co[] = {2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.3731456, 1.6877825, 12.427691, 2., 2., 2.}; + memcpy(*correct, co, out_t->ElementsNum() * sizeof(float)); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 1; + conv_param->stride_h_ = conv_param->stride_w_ = 2; + conv_param->dilation_h_ = conv_param->dilation_w_ = 1; + conv_param->pad_h_ = conv_param->pad_w_ = 1; + conv_param->is_relu_ = conv_param->is_relu6_ = false; + return out_t->ElementsNum(); +} + +TEST_F(TestConv1x1Fp32, Conv1x1Test1) { + std::vector inputs_; + std::vector outputs_; + auto conv_param = new ConvParameter(); + lite::Context *ctx = new lite::Context(); + ctx->thread_num_ = 1; + float *correct; + int total_size = Conv1x1TestInit1(&inputs_, &outputs_, conv_param, &correct); + kernel::Convolution1x1CPUKernel *conv1x1 = + new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx); + + conv1x1->Init(); + conv1x1->Run(); + + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete conv_param; + delete conv1x1; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +int Conv1x1TestInit2(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, float **correct) { + size_t buffer_size; + lite::tensor::Tensor *in_t = new lite::tensor::Tensor(kNumberTypeFloat, {1, 300, 300, 24}, schema::Format_NHWC, + static_cast(1)); + in_t->MallocData(); + std::string input_path = "./conv/conv1x1fp32_input1_nhwc.bin"; + auto in = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &buffer_size)); + memcpy(in_t->Data(), in, buffer_size); + inputs_->push_back(in_t); + + lite::tensor::Tensor *weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, {40, 1, 1, 24}, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + std::string weight_path = "./conv/conv1x1fp32_weight1_nhwc.bin"; + auto weight = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size)); + memcpy(weight_t->Data(), weight, buffer_size); + inputs_->push_back(weight_t); + + lite::tensor::Tensor *bias_t = + new lite::tensor::Tensor(kNumberTypeFloat, {40}, schema::Format_NHWC, static_cast(1)); + bias_t->MallocData(); + std::string bias_path = "./conv/conv1x1fp32_bias1_nhwc.bin"; + auto bias = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size); + memcpy(bias_t->Data(), bias, buffer_size); + inputs_->push_back(bias_t); + + lite::tensor::Tensor *out_t = new lite::tensor::Tensor(kNumberTypeFloat, {1, 300, 300, 40}, schema::Format_NHWC, + static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + std::string out_path = "./conv/conv1x1fp32_output1_nhwc.bin"; + auto out_nhwc = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size); + *correct = reinterpret_cast(malloc(buffer_size)); + memcpy(*correct, out_nhwc, buffer_size); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 1; + conv_param->stride_h_ = conv_param->stride_w_ = 1; + conv_param->dilation_h_ = conv_param->dilation_w_ = 1; + conv_param->pad_h_ = conv_param->pad_w_ = 0; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; + return out_t->ElementsNum(); +} + +TEST_F(TestConv1x1Fp32, Conv1x1Test2) { + std::vector inputs_; + std::vector outputs_; + auto conv_param = new ConvParameter(); + lite::Context *ctx = new lite::Context(); + ctx->thread_num_ = 2; + float *correct; + int total_size = Conv1x1TestInit2(&inputs_, &outputs_, conv_param, &correct); + kernel::Convolution1x1CPUKernel *conv1x1 = + new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx); + + conv1x1->Init(); + conv1x1->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + + /* running warm up */ + for (int i = 0; i < 0; i++) { + conv1x1->Run(); + } + + /* running time cost */ + int loop_count = 1; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + conv1x1->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + uint64_t time_avg = cost / loop_count; + printf("1x1 average time : %f ms\n", time_avg / 1000.0f); + + delete conv_param; + delete conv1x1; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc new file mode 100644 index 0000000000..3394ecb5af --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/ops/ops.h" + +namespace mindspore { +class TestConvolutionDwFp32 : public mindspore::Common { + public: + TestConvolutionDwFp32() {} +}; + +void InitConvDwParam(ConvParameter *conv_param) { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 288; + conv_param->input_w_ = 288; + conv_param->input_channel_ = 25; + + conv_param->output_batch_ = 1; + conv_param->output_h_ = 288; + conv_param->output_w_ = 288; + conv_param->output_channel_ = 25; + + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; +} + +void InitConvDwCreator(std::vector *inputs, std::vector *outputs, + const ConvParameter *conv_param) { + // prepare input, format NHWC + size_t input_size; + std::string input_path = "./test_data/convDw/convDwfp32_input.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + auto *input = new lite::tensor::Tensor; + input->set_data_type(kNumberTypeFloat32); + input->SetFormat(schema::Format_NHWC); + input->set_shape({conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, conv_param->input_channel_}); + input->MallocData(); + memcpy(input->Data(), input_data, input_size); + + // prepare weight, format co kh kw ci, ci = 1 + size_t weight_size; + std::string weight_path = "./test_data/convDw/convDwfp32_weight.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + + auto *weight = new lite::tensor::Tensor; + weight->set_data_type(kNumberTypeFloat32); + weight->set_shape({conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_, 1}); + weight->MallocData(); + memcpy(weight->Data(), weight_data, weight_size); + + // prepare bias + auto *bias = new lite::tensor::Tensor; + bias->set_data_type(kNumberTypeFloat32); + bias->set_shape({conv_param->output_channel_}); + bias->MallocData(); + memset(bias->Data(), 0, bias->ElementsNum() * sizeof(float)); + + inputs->push_back(input); + inputs->push_back(weight); + inputs->push_back(bias); + + auto *output = new lite::tensor::Tensor; + output->set_data_type(kNumberTypeFloat32); + output->set_shape( + {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, conv_param->output_channel_}); + output->SetFormat(schema::Format_NHWC); + output->MallocData(); + memset(output->Data(), 0, output->ElementsNum() * sizeof(float)); + outputs->push_back(output); +} + +TEST_F(TestConvolutionDwFp32, ConvDwFp32Accuracy) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvDwParam(conv_param); + + // init ctx + auto ctx = new Context(); + ctx->thread_num_ = 4; + + // init tensor + std::vector inputs; + std::vector outputs; + InitConvDwCreator(&inputs, &outputs, conv_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + // op run + kernel->Run(); + + std::cout << "==================output data=================" << std::endl; + auto output_ptr = reinterpret_cast(outputs[0]->Data()); + for (int i = 0; i < 20; i++) { + std::cout << output_ptr[i] << ", "; + } + std::cout << std::endl; + + // read output data, format NHWC + size_t output_size; + std::string output_path = "./test_data/convDw/convDwfp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + + // compare + CompareOutputData(output_ptr, correct_data, outputs[0]->ElementsNum(), 0.0001); + + delete conv_param; + for (int i = 0; i < inputs.size(); i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + delete correct_data; + MS_LOG(INFO) << "TestConvolutionDwFp32 accuracy passed"; +} + +TEST_F(TestConvolutionDwFp32, ConvDwFp32Performance) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvDwParam(conv_param); + + // init ctx + auto ctx = new Context(); + ctx->thread_num_ = 1; + + // init tensor + std::vector inputs; + std::vector outputs; + InitConvDwCreator(&inputs, &outputs, conv_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + + /* running warm up */ + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + /* running time cost */ + int loop_count = 10; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + uint64_t time_avg = cost / loop_count; + printf("Convolution_depthwise fp32 average time : %f ms\n", time_avg / 1000.0f); + + delete conv_param; + for (int i = 0; i < inputs.size(); i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + MS_LOG(INFO) << "TestConvolutionDwFp32 performance passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc new file mode 100644 index 0000000000..7847e013a3 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc @@ -0,0 +1,521 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestConvolutionGradFp32 : public mindspore::Common { + public: + TestConvolutionGradFp32() {} +}; + +void InitConvParamGroup1FP32(ConvParameter *conv_param) { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 28; + conv_param->input_w_ = 28; + conv_param->input_channel_ = 3; + + conv_param->output_batch_ = 1; + conv_param->output_h_ = 28; + conv_param->output_w_ = 28; + conv_param->output_channel_ = 32; + + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + + conv_param->group_ = 1; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; + conv_param->thread_num_ = 1; +} + +void InitConvParamGroup3FP32(ConvParameter *conv_param) { + InitConvParamGroup1FP32(conv_param); + conv_param->group_ = 3; + conv_param->output_channel_ = 18; +} + +void InitConvParamGroup3Dilation2FP32(ConvParameter *conv_param) { + InitConvParamGroup3FP32(conv_param); + conv_param->dilation_h_ = 2; + conv_param->dilation_w_ = 2; + conv_param->output_h_ = 26; + conv_param->output_w_ = 26; +} + +TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup1FP32(conv_param); + + size_t dy_size; + std::string dy_path = "./test_data/conv/convfp32_dy_1_28_28_32.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); + std::vector dim_dy({1, 28, 28, 32}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = + conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_x_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_x({1, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(input_data); + + auto dw_data = new float[output_data_size]; + std::vector dim_dw({32, 3, 3, 3}); + lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); + dw_tensor.SetData(dw_data); + std::vector inputs = {&dy_tensor, &x_tensor}; + std::vector outputs = {&dw_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_dw_32_3_3_3.bin"; + auto res = lite::CompareRelativeOutput(dw_data, output_path); + + EXPECT_EQ(res, 0); + + // delete input_data; + // delete dy_data; + // delete [] dw_data; + delete kernel; + delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; +} + +TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup1FP32(conv_param); + + size_t dy_size; + std::string dy_path = "./test_data/conv/convfp32_dy_1_28_28_32.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); + std::vector dim_dy({1, 28, 28, 32}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + + size_t w_size; + std::string w_path = "./test_data/conv/convfp32_w_32_3_3_3.bin"; + auto w_data = reinterpret_cast(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); + std::vector dim_dw({32, 3, 3, 3}); + lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_dw); + w_tensor.SetData(w_data); + + size_t output_data_size = + conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + auto dx_data = new float[output_data_size]; + std::vector dim_dx({1, 28, 28, 3}); + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); + dx_tensor.SetData(dx_data); + + std::vector inputs = {&dy_tensor, &w_tensor}; + std::vector outputs = {&dx_tensor}; + // runtime part + + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_dx_1_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(dx_data, output_path); + EXPECT_EQ(res, 0); + + delete kernel; + delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; +} + +TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup3FP32(conv_param); + + size_t dy_size; + std::string dy_path = "./test_data/conv/convfp32_dy_g3_1_28_28_18.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); + std::vector dim_dy({1, 28, 28, 18}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * + conv_param->input_channel_ / conv_param->group_; + + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_x_g3_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_x({1, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(input_data); + + auto dw_data = new float[output_data_size]; + std::vector dim_dw({18, 3, 3, 1}); + lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); + dw_tensor.SetData(dw_data); + std::vector inputs = {&dy_tensor, &x_tensor}; + std::vector outputs = {&dw_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_dw_g3_18_3_3_3.bin"; + auto res = lite::CompareRelativeOutput(dw_data, output_path); + EXPECT_EQ(res, 0); + + // delete input_data; + // delete dy_data; + // delete [] dw_data; + delete kernel; + delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; +} + +TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup3FP32(conv_param); + + size_t dy_size; + std::string dy_path = "./test_data/conv/convfp32_dy_g3_1_28_28_18.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); + std::vector dim_dy({1, 28, 28, 18}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + + size_t w_size; + std::string w_path = "./test_data/conv/convfp32_w_g3_18_3_3_3.bin"; + auto w_data = reinterpret_cast(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); + std::vector dim_dw({18, 3, 3, 1}); + lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_dw); + w_tensor.SetData(w_data); + + size_t output_data_size = + conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + auto dx_data = new float[output_data_size]; + std::vector dim_dx({1, 28, 28, 3}); + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); + dx_tensor.SetData(dx_data); + + std::vector inputs = {&dy_tensor, &w_tensor}; + std::vector outputs = {&dx_tensor}; + // runtime part + + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_dx_g3_1_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(dx_data, output_path); + EXPECT_EQ(res, 0); + + delete kernel; + delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; +} + +TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { + // prepare stage + auto conv_param = new ConvParameter(); + + InitConvParamGroup3Dilation2FP32(conv_param); + + size_t dy_size; + std::string dy_path = "./test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); + std::vector dim_dy({1, 26, 26, 18}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * + conv_param->input_channel_ / conv_param->group_; + + size_t input_size; + std::string input_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_x({1, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(input_data); + + auto dw_data = new float[output_data_size]; + std::vector dim_dw({18, 3, 3, 1}); + lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); + dw_tensor.SetData(dw_data); + std::vector inputs = {&dy_tensor, &x_tensor}; + std::vector outputs = {&dw_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin"; + auto res = lite::CompareRelativeOutput(dw_data, output_path); + EXPECT_EQ(res, 0); + // delete input_data; + // delete dy_data; + // delete [] dw_data; + delete kernel; + delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; +} + +TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup3Dilation2FP32(conv_param); + + size_t dy_size; + std::string dy_path = "./test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); + std::vector dim_dy({1, 26, 26, 18}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + + size_t w_size; + std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin"; + auto w_data = reinterpret_cast(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); + std::vector dim_w({18, 3, 3, 1}); + lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); + w_tensor.SetData(w_data); + + size_t output_data_size = + conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + auto dx_data = new float[output_data_size]; + std::vector dim_dx({1, 28, 28, 3}); + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); + dx_tensor.SetData(dx_data); + + std::vector inputs = {&dy_tensor, &w_tensor}; + std::vector outputs = {&dx_tensor}; + // runtime part + + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(dx_data, output_path); + EXPECT_EQ(res, 0); + + delete kernel; + delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; +} + +// TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { +// // prepare stage +// auto conv_param = new ConvParameter(); +// InitConvParamGroup3Dilation2FP32(conv_param); + +// size_t x_size; +// std::string x_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin"; +// auto x_data = reinterpret_cast(mindspore::lite::ReadFile(x_path.c_str(), &x_size)); +// std::vector dim_x({1, 28, 28, 3}); +// tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); +// x_tensor.SetData(x_data); + +// size_t w_size; +// std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin"; +// auto w_data = reinterpret_cast(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); +// std::vector dim_w({18, 3, 3, 1}); +// tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); +// w_tensor.SetData(w_data); + +// size_t output_data_size = +// conv_param->output_batch_ * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; +// auto y_data = new float[output_data_size]; +// std::vector dim_y({1, 26, 26, 18}); +// tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); +// y_tensor.SetData(y_data); + +// std::vector inputs = {&x_tensor, &w_tensor}; +// std::vector outputs = {&y_tensor}; +// // runtime part + +// printf("Calculating runtime cost...\n"); +// uint64_t time_avg = 0; + +// lite::Context context; +// ; +// context.deviceCtx.type = lite::DT_CPU; +// context.threadNum = 1; + +// kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D}; +// auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc); +// auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc); + +// kernel->train(); +// EXPECT_EQ(kernel->is_train(), 1); + +// // warm up loop +// for (int i = 0; i < 3; i++) { +// kernel->Run(); +// } + +// int loop_count = 100; +// auto time_start = mindspore::lite::GetTimeUs(); +// for (int i = 0; i < loop_count; i++) { +// kernel->Run(); +// } +// auto time_end = mindspore::lite::GetTimeUs(); +// auto cost = time_end - time_start; +// time_avg = cost / loop_count; +// printf("single thread running time : %f ms\n", time_avg / 1000.0f); + +// std::string output_path = "./test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin"; +// auto res = lite::CompareRelativeOutput(y_data, output_path); +// EXPECT_EQ(res, 0); + +// delete kernel; +// delete conv_param; + +// MS_LOG(INFO) << "TestConvolutionFp32 Filter Grad passed"; +// } + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc new file mode 100644 index 0000000000..89704cfa9f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc @@ -0,0 +1,234 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/crop.h" + +namespace mindspore { +class CropTestFp32 : public mindspore::Common { + public: + CropTestFp32() = default; +}; + +TEST_F(CropTestFp32, CropTest1) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 2; + float expect_out[kOutSize] = {8, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {2, 1, 1, 1}; + CropParameter crop_param; + crop_param.axis_ = 1; + crop_param.offset_[0] = 1; + crop_param.offset_[1] = 1; + crop_param.offset_[2] = 1; + crop_param.op_parameter_.thread_num_ = 1; + crop_param.thread_id_ = 0; + Crop4D(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest2) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 4; + float expect_out[kOutSize] = {13, 14, 15, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {1, 1, 2, 2}; + CropParameter crop_param; + crop_param.axis_ = 0; + crop_param.offset_[0] = 1; + crop_param.offset_[1] = 1; + crop_param.offset_[2] = 0; + crop_param.offset_[3] = 0; + crop_param.op_parameter_.thread_num_ = 1; + crop_param.thread_id_ = 0; + Crop4D(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest3) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 8; + float expect_out[kOutSize] = {2, 4, 6, 8, 10, 12, 14, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {2, 2, 2, 1}; + CropParameter crop_param; + crop_param.axis_ = 3; + crop_param.offset_[0] = 1; + crop_param.op_parameter_.thread_num_ = 1; + crop_param.thread_id_ = 0; + Crop4D(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest4) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 8; + float expect_out[kOutSize] = {2, 4, 6, 8, 10, 12, 14, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {2, 2, 2, 1}; + CropParameter crop_param; + crop_param.axis_ = 3; + crop_param.offset_[0] = 1; + crop_param.op_parameter_.thread_num_ = 2; + crop_param.thread_id_ = 0; + Crop4D(input, output, in_shape, out_shape, &crop_param); + crop_param.thread_id_ = 1; + Crop4D(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest5) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 2; + float expect_out[kOutSize] = {8, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {2, 1, 1, 1}; + CropParameter crop_param; + crop_param.axis_ = 1; + crop_param.offset_[0] = 1; + crop_param.offset_[1] = 1; + crop_param.offset_[2] = 1; + Crop4DNoParallel(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest6) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 4; + float expect_out[kOutSize] = {13, 14, 15, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {1, 1, 2, 2}; + CropParameter crop_param; + crop_param.axis_ = 0; + crop_param.offset_[0] = 1; + crop_param.offset_[1] = 1; + crop_param.offset_[2] = 0; + crop_param.offset_[3] = 0; + Crop4DNoParallel(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest7) { + float input[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const int kOutSize = 8; + float expect_out[kOutSize] = {2, 4, 6, 8, 10, 12, 14, 16}; + + float output[kOutSize]; + int in_shape[4] = {2, 2, 2, 2}; + int out_shape[4] = {2, 2, 2, 1}; + CropParameter crop_param; + crop_param.axis_ = 3; + crop_param.offset_[0] = 1; + Crop4DNoParallel(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest8) { + float input[27] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 11, 12, 13, 14, 15, 16, 17, 18, 19, + 21, 22, 23, 24, 25, 26, 27, 28, 29}; + const int kOutSize = 4; + float expect_out[kOutSize] = {15, 16, 18, 19}; + + float output[kOutSize]; + int in_shape[4] = {1, 3, 3, 3}; + int out_shape[4] = {1, 1, 2, 2}; + CropParameter crop_param; + crop_param.axis_ = 1; + crop_param.offset_[0] = 1; + crop_param.offset_[1] = 1; + crop_param.offset_[2] = 1; + crop_param.op_parameter_.thread_num_ = 2; + crop_param.thread_id_ = 0; + Crop4D(input, output, in_shape, out_shape, &crop_param); + crop_param.thread_id_ = 1; + Crop4D(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(CropTestFp32, CropTest9) { + float input[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 110, 111, 112, 113, 114, 115, 116, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 210, 211, 212, 213, 214, 215, 216, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 310, 311, 312, 313, 314, 315, 316}; + const int kOutSize = 8; + float expect_out[kOutSize] = {16, 17, 110, 111, 26, 27, 210, 211}; + + float output[kOutSize]; + int in_shape[4] = {1, 4, 4, 4}; + int out_shape[4] = {1, 2, 2, 2}; + CropParameter crop_param; + crop_param.axis_ = 1; + crop_param.offset_[0] = 1; + crop_param.offset_[1] = 1; + crop_param.offset_[2] = 1; + crop_param.op_parameter_.thread_num_ = 2; + crop_param.thread_id_ = 0; + Crop4D(input, output, in_shape, out_shape, &crop_param); + crop_param.thread_id_ = 1; + Crop4D(input, output, in_shape, out_shape, &crop_param); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +} // namespace mindspore + diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc new file mode 100644 index 0000000000..727f86f3b9 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -0,0 +1,548 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/ops/ops.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/deconv.h" + +namespace mindspore { +class TestDeConvolutionFp32 : public mindspore::Common { + public: + TestDeConvolutionFp32() {} +}; + +TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack1) { + float in[] = {0.43005997, -0.01335099, -0.43214464, -0.2569654, -0.14664753, -0.09249142, 0.42330834, 0.17678244, + -0.26901904, 0.29920393, -0.25139654, 0.04580693, 0.08898365, -0.29335496, 0.1332809, 0.06561925, + 0.50099367, -0.45963442, -0.17191549, -0.1517635, -0.54385597, 0.20007996, 0.3174582, -0.13803318, + -0.10295965, 0.03531377, -0.05687982, 0.09801699, -0.1504936, 0.27094424, -0.15454058, 0.25500196, + 0.03428256, 0.1711275, -0.28639716, 0.05972834, 0.1301975, 0.09662235, -0.26297596, 0.25723842, + 0.37723106, -0.49640322, 0.21951586, -0.25885767, -0.44244745, 0.04153876, 0.41899854, 0.07920247, + 0.31681255, 0.3300002, 0.23956111, 0.13012694, 0.26047292, 0.0851135, -0.185474, 0.306445, + 0.20750166, -0.13887969, -0.15064844, -0.08100204, 0.08206631, 0.3151005, 0.26807567, -0.6340778, + 0.1019667, 0.14200483, -0.56623703, 0.47877932, 0.13249867, 0.3862773, 0.7469436, 0.14524518, + 0.42495733, 0.08011179, 0.19647601, -0.03030056, 0.12770538, -0.32460797, -0.2103409, 0.33223677, + -0.47110182, -0.5424416, 0.18340437, 0.3781465, 0.04931778, 0.17888185, 0.04547426, -0.01483545, + 0.29989168, 0.12018301, 0.00213889, 0.21470474, -0.4031554, -0.10013647, -0.12780161, -0.28953925, + 0.05002394, 0.5460746, -0.7209624, 0.32692385, -0.09215609, -0.07226299, 0.47478926, -0.6297518, + 0.22869332, -0.33726704, -0.24732, 0.07623845, 0.38042688, -0.18950662, -0.16825019, 0.49407697, + -0.10242693, 0.59533256, -0.11732046, 0.7062394, 0.35063574, -0.17253993, -0.14738934, 0.26435736}; + float co[] = { + 0.43005997, -0.01335099, -0.43214464, -0.2569654, -0.10295965, 0.03531377, -0.05687982, 0.09801699, 0.31681255, + 0.3300002, 0.23956111, 0.13012694, 0.42495733, 0.08011179, 0.19647601, -0.03030056, 0.05002394, 0.5460746, + -0.7209624, 0.32692385, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.42330834, 0.17678244, -0.26901904, 0.29920393, + -0.15454058, 0.25500196, 0.03428256, 0.1711275, -0.185474, 0.306445, 0.20750166, -0.13887969, -0.2103409, + 0.33223677, -0.47110182, -0.5424416, 0.47478926, -0.6297518, 0.22869332, -0.33726704, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.08898365, -0.29335496, 0.1332809, 0.06561925, 0.1301975, 0.09662235, -0.26297596, 0.25723842, + 0.08206631, 0.3151005, 0.26807567, -0.6340778, 0.04931778, 0.17888185, 0.04547426, -0.01483545, 0.38042688, + -0.18950662, -0.16825019, 0.49407697, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, -0.17191549, -0.1517635, -0.54385597, + 0.20007996, 0.21951586, -0.25885767, -0.44244745, 0.04153876, -0.56623703, 0.47877932, 0.13249867, 0.3862773, + 0.00213889, 0.21470474, -0.4031554, -0.10013647, -0.11732046, 0.7062394, 0.35063574, -0.17253993, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, -0.14664753, -0.09249142, 0.000, 0.000, -0.1504936, 0.27094424, 0.000, + 0.000, 0.26047292, 0.0851135, 0.000, 0.000, 0.12770538, -0.32460797, 0.000, 0.000, + -0.09215609, -0.07226299, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, -0.25139654, 0.04580693, + 0.000, 0.000, -0.28639716, 0.05972834, 0.000, 0.000, -0.15064844, -0.08100204, 0, + 0, 0.18340437, 0.3781465, 0.000, 0.000, -0.24732, 0.07623845, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.50099367, -0.45963442, 0.000, 0.000, 0.37723106, -0.49640322, + 0.000, 0.000, 0.1019667, 0.14200483, 0.000, 0.000, 0.29989168, 0.12018301, 0.000, + 0.000, -0.10242693, 0.59533256, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.3174582, + -0.13803318, 0.000, 0.000, 0.41899854, 0.07920247, 0.000, 0.000, 0.7469436, 0.14524518, + 0.000, 0.000, -0.12780161, -0.28953925, 0.000, 0.000, -0.14738934, 0.26435736, 0.000, + 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, + 0.000, 0.000, 0.000, 0.00}; + float dst[256] = {0}; + PackDeConvWeightFp32(in, dst, 5, 6, 2 * 2); + EXPECT_EQ(0, lite::CompareOutputData(dst, co, 256)); +} + +TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { + float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, + -0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562, + 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873}; + + float co[] = {4.1795, 13.142, -3.593, 0, -2.380, -9.027, 23.622, 0, -0.784, 37.925, -0.081, 0, 11.029, + 1.7127, 9.0560, 0, 14.530, -14.10, -8.115, 0, 13.584, 3.3293, -1.514, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 16.505, 19.969, -6.235, 0, 8.3608, 47.325, -14.36, + 0, 6.1298, 37.998, 13.719, 0, 14.988, 3.1866, 0.0562, 0, -8.071, 19.250, 17.923, 0, + -0.293, 18.686, 0.0873, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float dst[64] = {0}; + PackDeConvWeightFp32(in, dst, 6, 3, 2 * 1); + EXPECT_EQ(0, lite::CompareOutputData(dst, co, 64)); +} + +TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { + float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0, + -5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0, + -0.8614793, -4.404605, 10.917422, 0.11158327, 11.1863365, 0, 0, 0, + -5.2733865, -0.96367484, -4.731118, -7.576815, -3.4595785, 0, 0, 0, + -6.1621623, -0.6315082, -9.140878, 9.266748, -8.344107, 0, 0, 0, + 13.644127, 8.206812, 7.091153, -0.50162584, -3.792715, 0, 0, 0, + 2.0889723, 6.6916203, -5.3981733, 11.997365, -7.0394287, 0, 0, 0, + -9.254076, -5.5964484, -5.981469, -0.51114964, -2.7693212, 0, 0, 0}; + float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0}; + float out[40] = {0}; + + float no[] = {-8.646674, -5.3524485, 8.56133, -1.2702886, -2.6201365, -4.7133026, 1.2270198, 17.954533, + 11.086085, -7.2591906, -0.11849791, -3.9182835, 11.90631, 0.3088621, 11.196218, -4.530405, + -0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027, + -8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942, + -4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402}; + PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, false); + CompareOutputData(out, no, 40, 0.0001); + + float relu[] = {0, 0, 8.56133, 0, 0, 0, 1.2270198, 17.954533, 11.086085, 0, + 0, 0, 11.90631, 0.3088621, 11.196218, 0, 0, 0, 0, 0, + 0, 0, 0, 9.464027, 0, 14.387108, 8.693133, 8.080041, 0, 0, + 2.8319538, 7.177942, 0, 12.194644, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C8(in, out, bias, 5, 8, 5, true, false); + CompareOutputData(out, relu, 40, 0.0001); + + float corr_relu6[] = {0, 0, 6, 0, 0, 0, 1.2270198, 6, 6, 0, 0, 0, 6, 0.3088621, 6, 0, 0, 0, 0, 0, + 0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, true); + CompareOutputData(out, corr_relu6, 40, 0.0001); + + float nob_relu[] = {0, 0, 7.5724425, 0, 0, 0, 0.7406984, 16.965645, + 10.888806, 0, 0, 0, 10.917422, 0.11158327, 11.1863365, 0, + 0, 0, 0, 0, 0, 0, 0, 9.266748, + 0, 13.644127, 8.206812, 7.091153, 0, 0, 2.0889723, 6.6916203, + 0, 11.997365, 0, 0, 0, 0, 0, 0}; + PostConvFuncFp32C8(in, out, nullptr, 5, 8, 5, true, false); + CompareOutputData(out, nob_relu, 40, 0.0001); +} + +int DeConvTestInit1(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, float **correct) { + std::vector in_dims_nhwc = {1, 5, 7, 2}; + lite::tensor::Tensor *in_t = + new lite::tensor::Tensor(kNumberTypeFloat, in_dims_nhwc, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in_nchw[] = { + 0.39451003, 0.15045597, 0.5367726, 0.62690735, 0.113554195, 0.5402554, 0.5522764, 0.044319753, 0.25721782, + 0.41789535, 0.6717553, 0.72254324, 0.15164013, 0.93585724, 0.33732107, 0.14599903, 0.20070823, 0.640386, + 0.74077445, 0.088589266, 0.08755991, 0.4489046, 0.7409207, 0.7373529, 0.8887349, 0.045393247, 0.6483991, + 0.7542141, 0.8730748, 0.5480396, 0.19493233, 0.41220096, 0.77443165, 0.9909433, 0.8081086, 0.91432786, + 0.97605807, 0.48640794, 0.7690306, 0.9381521, 0.44073114, 0.27656683, 0.0725352, 0.53911537, 0.994353, + 0.2642501, 0.29840338, 0.38820496, 0.37829784, 0.105839334, 0.07713295, 0.45629853, 0.9290373, 0.56323594, + 0.59976774, 0.48325357, 0.102543674, 0.35449505, 0.3158472, 0.02927611, 0.44739273, 0.0516185, 0.12340133, + 0.13908496, 0.54970616, 0.74672216, 0.673308, 0.6400629, 0.26790652, 0.98673576}; /* nhwc */ + PackNCHWToNHWCFp32(in_nchw, in_t->Data(), in_t->Batch(), in_t->Width() * in_t->Height(), in_t->Channel()); + inputs_->push_back(in_t); + + std::vector weight_dims_nhwc = {2, 3, 3, 6}; + lite::tensor::Tensor *weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, weight_dims_nhwc, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + float weight_nchw[] = { + 0.061163727, -0.06261389, 0.07708351, -0.019354159, -0.3859104, -0.082844816, -0.21268463, -0.15746808, + -0.096376516, 0.016681675, 0.1364329, -0.007941234, -0.10095563, 0.32489842, -0.042597733, 0.2701167, + -0.1415933, 0.007270595, -0.34188282, -0.3374504, -0.26375315, -0.075536035, 0.11136466, -0.2239981, + -0.07840504, -0.23905717, -0.10171707, -0.11058277, 0.363706, -0.09807812, -0.05729029, 0.0018888254, + -0.29443327, 0.13365538, 0.0453783, -0.31048688, 0.07062391, 0.16674924, 0.2268152, -0.18341774, + 0.10190555, 0.08567296, 0.13261533, -0.40412605, 0.13981377, -0.08217087, -0.050615843, -0.05403921, + -0.028555218, 0.2651543, 0.10668221, -0.013095176, 0.09588115, 0.044287443, -0.009692867, 0.06717065, + -0.29928264, -0.09110823, -0.07987715, -0.15888898, 0.041994736, 0.086504236, -0.19046812, 0.20323305, + 0.08014105, 0.009099235, 0.2525443, -0.010155359, 0.039532702, 0.20266832, 0.0045211455, -0.14146733, + -0.07135475, -0.011584315, 0.1640728, 0.13032198, 0.18829331, -0.27231383, -0.15681058, -0.14862166, + -0.084803745, -0.020582702, -0.0681792, 0.06789135, 0.13603394, 0.090862036, -0.08380498, -0.16875166, + -0.2570391, -0.013280135, 0.24033138, -0.08921211, 0.2722501, 0.24916205, -0.20001566, -0.11610521, + 0.06060236, 0.10848369, -0.4512424, 0.023834296, 0.1643943, -0.25290534, 0.066953085, -0.11685201, + -0.4159784, 0.37839416, -0.11141268, -0.15986018}; /* nhwc */ + PackNCHWToNHWCFp32(weight_nchw, weight_t->Data(), weight_t->Batch(), weight_t->Width() * weight_t->Height(), + weight_t->Channel()); + inputs_->push_back(weight_t); + + lite::tensor::Tensor *bias_t = + new lite::tensor::Tensor(kNumberTypeFloat, {6}, schema::Format_NHWC, static_cast(1)); + bias_t->MallocData(); + float bias[] = {-0.19064677, -0.0034778118, 0.63741624, -1.0311537, -1.0288948, 0.71384084}; + memcpy(bias_t->Data(), bias, sizeof(float) * 6); + inputs_->push_back(bias_t); + + std::vector output_nhwc_dims = {1, 9, 13, 6}; + lite::tensor::Tensor *out_t = + new lite::tensor::Tensor(kNumberTypeFloat, output_nhwc_dims, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + float nchw_co[] = { + -0.4159262, -0.46044537, -0.32667404, -0.4129007, -0.43664578, -0.39459872, -0.49400482, -0.4524444, + -0.30940545, -0.3997266, -0.4343413, -0.3413178, -0.42586732, -0.17157906, -0.4016143, -0.1097983, + -0.61039054, -0.19246969, -0.6629166, -0.24715163, -0.36829865, -0.1525711, -0.50477314, -0.22101344, + -0.4834266, -0.2868756, -0.21354413, -0.25993955, -0.33297282, -0.3962972, -0.43134302, -0.4203356, + -0.47099167, -0.32945585, -0.4933193, -0.3362223, -0.28017497, -0.31746963, -0.5820211, -0.2053628, + -0.23829184, -0.1884751, -0.36922038, -0.15235345, -0.6430171, -0.25126106, -0.63569427, -0.28716096, + -0.44492853, -0.14620401, -0.63435787, -0.27831206, -0.32927662, -0.24526191, -0.25315046, -0.2604547, + -0.30455, -0.37681228, -0.5119872, -0.4569657, -0.521509, -0.39786643, -0.27274203, -0.33900544, + -0.26303798, -0.25582826, -0.22533318, -0.2295449, -0.2498781, -0.20773302, -0.3777015, -0.2648021, + -0.50503045, -0.23136339, -0.45421264, -0.18984585, -0.23228307, -0.20156652, -0.3720746, -0.29076657, + -0.5048918, -0.35140067, -0.5004279, -0.32178527, -0.5359573, -0.3105652, -0.24390095, -0.28274524, + -0.44499388, -0.27840495, -0.49156278, -0.29778862, -0.34227157, -0.27404356, -0.5907216, -0.24148186, + -0.69942933, -0.3086446, -0.40131485, -0.16459012, -0.48982328, -0.33233505, -0.38212818, -0.2830558, + -0.5386851, -0.34576517, -0.4460499, -0.39519656, -0.3255192, -0.39476353, -0.40350133, -0.4050802, + -0.5406344, -0.40009072, -0.5944617, -0.42084867, -0.58132195, 0.11541255, 0.24717134, 0.035492875, + 0.09734866, 0.16597912, 0.12381038, 0.1923936, 0.22568025, 0.023888497, 0.085535035, 0.16757454, + 0.0050217994, 0.17314728, -0.043344263, 0.22266465, 0.057929777, 0.315026, 0.059421062, 0.3274499, + 0.02406001, 0.18286264, 0.107178226, 0.17828721, -0.026181899, 0.23815396, 0.07757285, 0.010184985, + 0.10768472, 0.07461695, 0.21580729, 0.12219772, 0.016947635, 0.21209088, -0.019231271, 0.22824496, + 0.060270205, 0.041847467, 0.006466368, 0.29673898, 0.04507852, 0.18171927, -0.0113601275, 0.332155, + 0.005798064, 0.29595143, 0.0644246, 0.349865, 0.04176835, 0.20181134, 0.036958598, 0.37659368, + -0.0836041, 0.105042435, -0.008922746, 0.04317373, 0.08832521, 0.057098098, 0.1759837, 0.19514789, + 0.07342724, 0.23147877, 0.12975746, 0.019213844, 0.1296622, 0.020062651, 0.01870161, 0.1208442, + 0.105693996, 0.20719647, 0.096077755, 0.3124894, 0.033647023, 0.26888633, -0.06377239, 0.09272936, + 0.07928991, 0.06689171, 0.09909828, 0.14132921, -0.0038207127, 0.23364612, -0.015699724, 0.23287944, + -0.10473035, 0.28497344, 0.06822525, 0.0067269485, -0.0401484, 0.20666184, -0.074035384, 0.24031198, + 0.06368647, 0.37245232, 0.012040168, 0.3706034, -0.020015769, 0.35215783, -0.018986963, 0.24762997, + 0.14907081, 0.18981782, 0.061614163, 0.43125582, 0.07961907, 0.27877036, 0.048327066, 0.16899693, + 0.16380924, 0.052272163, 0.14616457, 0.12360795, 0.08904207, 0.24163374, -0.043546468, 0.31575742, + 0.1325127, 0.24905476, 0.8535125, 0.4158996, 0.8379569, 0.36076424, 0.7887811, 0.4375921, + 0.85203487, 0.40125692, 0.8267099, 0.37313673, 0.78056836, 0.39070883, 0.750996, 0.39142087, + 0.22870147, 0.36334175, 0.22776845, 0.2842683, 0.17623127, 0.14350969, -0.049721956, 0.22356126, + 0.21368039, 0.38709402, 0.13516903, 0.14409906, 0.6560098, 0.65856576, 0.76757306, 0.5310113, + 0.87118506, 0.25672826, 0.76198256, 0.39929584, 0.77406937, 0.43344593, 0.7274, 0.47634512, + 0.8128686, 0.50098574, 0.39502823, 0.44564128, 0.24981359, 0.31671798, 0.15317863, 0.21069425, + 0.13331234, 0.16383857, 0.28979823, 0.50662756, 0.46699578, 0.32232434, 0.6949107, 0.5320594, + 0.668199, 0.6280134, 0.745686, 0.54090333, 0.88366413, 0.25842816, 0.8259659, 0.38957846, + 0.7602142, 0.510612, 0.7381607, 0.38837627, 0.1904087, 0.33691993, 0.11685282, 0.26914072, + -0.06617683, 0.046009183, 0.0700444, 0.356119, 0.24937916, 0.30769932, 0.06569201, 0.28872308, + 0.70671666, 0.4991707, 0.78667766, 0.36038262, 0.7790032, 0.32292485, 0.7419024, 0.48524532, + 0.7267125, 0.46316653, 0.7193444, 0.4372312, 0.7446447, 0.2186315, 0.03533274, 0.216304, + 0.25036755, 0.33977476, 0.3434924, 0.27370954, 0.16213486, 0.29132545, 0.078781545, 0.13724238, + -0.07549429, 0.1546486, 0.7608347, 0.43421644, 0.8019545, 0.44755372, 0.7997276, 0.44701982, + 0.81010026, 0.3866497, 0.8441801, 0.24970922, 0.7982173, 0.4100442, 0.9132067, -0.94733083, + -1.0997784, -0.9421829, -1.1218354, -0.9859438, -1.1612623, -0.96009386, -1.1590697, -0.9456968, + -1.1142067, -0.9900875, -1.2211759, -1.004981, -1.2370956, -1.349351, -1.2184161, -1.1564747, + -1.0476248, -1.3034617, -0.9740715, -1.5131376, -1.0246942, -1.1564014, -1.091238, -1.2773981, + -0.76259595, -1.0244793, -0.9916798, -0.9816827, -1.0407434, -0.94001544, -1.2400658, -1.0058745, + -1.251888, -1.0026754, -1.2247806, -0.99559414, -1.1104892, -0.9950131, -0.93231726, -1.1461066, + -1.1102134, -1.2707901, -1.2258892, -1.2075629, -0.899022, -1.2902625, -0.8440441, -1.3612556, + -1.1327276, -1.0097463, -1.0870252, -1.0208998, -1.1372137, -1.0238695, -1.0300313, -0.9893144, + -1.0387962, -0.9455299, -1.2633826, -0.97857773, -1.2199508, -0.97649026, -1.0467783, -0.9870789, + -0.8867735, -1.2570912, -0.7990466, -1.2643247, -0.89268696, -1.3204725, -0.9196508, -1.3377675, + -1.1563053, -1.4048479, -0.9489901, -1.2825038, -0.8854966, -1.0209885, -1.166144, -0.99754405, + -1.278291, -1.0010624, -1.3216578, -1.0268149, -1.2370203, -0.99041694, -1.1121378, -1.0252388, + -1.2528121, -1.0185167, -0.72908103, -1.2807931, -0.9268043, -1.2740122, -1.0588918, -1.1783062, + -0.89433515, -1.4704434, -0.90606475, -1.1208334, -0.67285204, -1.341852, -0.80200857, -1.016867, + -1.2564906, -0.9801711, -1.1481711, -0.96293676, -1.0831497, -0.969197, -1.1662431, -0.9715335, + -1.3331397, -1.0049394, -1.2574395, -0.9399705, -1.171572, -0.88565385, -1.2087893, -1.1065894, + -1.0714839, -0.9627551, -1.1188276, -0.8515502, -1.2049681, -1.1173695, -1.0619929, -1.066168, + -1.0279324, -1.0882176, -1.129684, -0.9890163, -0.8740333, -1.2120758, -0.56714463, -1.1103767, + -0.86929953, -0.8791485, -0.98886544, -1.2087606, -0.76514137, -1.0997763, -1.0388865, -0.9463707, + -1.1105144, -0.89834666, -1.1851951, -1.1659127, -1.0132934, -1.0602008, -1.014949, -0.9327261, + -1.0910889, -1.1383713, -1.0091913, -0.99213076, -0.8544737, -1.056894, -0.94257253, -1.0971456, + -0.8758079, -1.2477993, -0.35445136, -1.2152452, -0.5471301, -1.086797, -0.73012817, -1.3945714, + -1.0156894, -1.0198442, -1.0294445, -0.9484633, -1.0997083, -0.95065546, -1.1494579, -1.0774312, + -1.0660617, -0.89763457, -1.13983, -0.9865928, -1.1166302, -1.0880268, -0.7381968, -0.9876064, + -0.5964719, -0.9657296, -0.74247324, -1.041322, -0.9059322, -1.2995027, -0.94108796, -0.8961159, + -1.0022087, -0.89709914, -1.0036592, -1.0499129, -1.0242954, -1.0631231, -1.0169288, -1.1581104, + -0.94418347, -0.853006, -1.1137545, -1.183017, -0.9731438, -1.086927, -0.97671837, -1.066008, + -0.48595423, -1.2475185, -0.50115275, -1.326726, -0.5102552, -1.3762127, -0.39939296, -0.9266701, + -0.6510342, -1.1439915, -0.2621194, -1.2735826, -0.9677428, -0.9337987, -1.0829964, -0.8954656, + -1.1583862, -1.0067348, -1.1215614, -1.05432, -1.0779985, -1.151866, -0.98149765, -0.8774674, + -1.1439066, 0.71160585, 0.43664122, 0.63968056, 0.3411116, 0.79933065, 0.6023572, 0.79020524, + 0.5203902, 0.63432527, 0.34978527, 0.8055916, 0.5908885, 0.8279619, 0.6594803, 0.9234866, + 0.6951297, 0.580612, 0.8534291, 0.61968267, 0.69770944, 0.8167807, 0.6326902, 0.6108708, + 0.7726814, 0.5904738, 0.7508015, 0.71711653, 0.7171464, 0.71904653, 0.57166296, 0.70845544, + 0.3433037, 0.8610815, 0.6749295, 0.87055725, 0.6884554, 0.70868635, 0.56713784, 0.91778255, + 0.71033454, 0.8496836, 0.68372923, 0.9768204, 0.70797944, 0.5078603, 0.86912346, 0.48779017, + 0.80497104, 0.66758573, 0.7792437, 0.63723993, 0.8364369, 0.7909154, 0.7067954, 0.74354, + 0.72215, 0.7137401, 0.5893581, 0.77508205, 0.4122566, 0.8444451, 0.59620094, 0.6672466, + 0.5036563, 0.6805886, 0.72852767, 0.63650995, 0.74002045, 0.6952553, 0.6968493, 0.8008863, + 0.631564, 0.7486131, 0.79336673, 0.71474713, 0.6311797, 0.69647217, 0.6505069, 0.8208874, + 0.7216524, 0.8688757, 0.6455133, 0.87244576, 0.6376998, 0.94607174, 0.8251329, 0.6735983, + 0.51751864, 0.87973493, 0.74826664, 0.8994043, 0.72413105, 0.72747874, 0.808015, 0.6329842, + 0.8622399, 0.47823763, 0.8856161, 0.6762785, 0.73437214, 0.3766058, 0.764144, 0.60693324, + 0.89371794, 0.92908806, 0.7702812, 0.79492164, 0.58807003, 0.678272, 0.4573259, 0.7444603, + 0.49847388, 0.84439206, 0.51984715, 0.9452883, 0.7511028, 0.81281227}; + PackNCHWToNHWCFp32(nchw_co, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel()); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 3; + conv_param->stride_h_ = conv_param->stride_w_ = 2; + conv_param->dilation_h_ = conv_param->dilation_w_ = 1; + conv_param->pad_h_ = conv_param->pad_w_ = 1; + return out_t->ElementsNum(); +} + +TEST_F(TestDeConvolutionFp32, DeConvTest1) { + std::vector inputs_; + std::vector outputs_; + ConvParameter *deconv_param = new ConvParameter(); + lite::Context *ctx = new lite::Context(); + ctx->thread_num_ = 1; + float *correct; + int total_size = DeConvTestInit1(&inputs_, &outputs_, deconv_param, &correct); + kernel::DeConvolutionCPUKernel *deconv = + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + + deconv->Init(); + deconv->Run(); + + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete deconv_param; + delete deconv; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +int DeConvTestInit2(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, float **correct) { + auto *in_t = + new lite::tensor::Tensor(kNumberTypeFloat, {1, 4, 2, 3}, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in[] = {7.7566547, 19.250782, 17.923292, 13.584222, 3.3293908, 9.734102, 18.83455, -1.5142503, + -0.29382008, 18.686155, 0.087307654, 4.2010098, -2.2539594, 4.1795673, 13.142356, -3.5939367, + 16.505789, 19.899279, 8.556229, 19.969376, -6.2355065, -2.3804698, -9.027744, 9.5542}; /* nhwc */ + memcpy(in_t->Data(), in, sizeof(float) * in_t->ElementsNum()); + inputs_->push_back(in_t); + + auto *weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, {3, 3, 3, 2}, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + float weight[] = {-0.39557076, 0.15087655, 0.35216075, -0.20893791, 0.28683448, 0.08006268, 0.9830812, + 0.27212173, 0.5171944, -0.0014505, 0.78694165, 0.25425306, 0.16605458, -0.06127124, + 0.07637237, -0.5596424, -0.26599348, 0.223331, -0.45220536, -0.17021523, 0.20895825, + -0.07697097, 0.17581257, 0.09553282, 0.5369023, -0.6631143, 0.51170826, -0.5332868, + -0.19414032, -0.7109704, -0.05779554, -0.05178713, 0.3592201, -0.05532698, 0.06928781, + -0.5730523, -0.21037689, -0.01435696, 0.33056936, 0.51348346, -0.28136733, -0.36971128, + -0.10048455, 0.09297352, -0.27097073, -0.08646037, -0.06631696, -0.1684566, 0.31797925, + -0.06270258, 0.00119315, -0.2821196, -0.5166795, -0.09961014}; /* nhwc */ + memcpy(weight_t->Data(), weight, sizeof(float) * weight_t->ElementsNum()); + inputs_->push_back(weight_t); + + std::vector out_nhwc_dims = {1, 7, 3, 2}; + auto *out_t = + new lite::tensor::Tensor(kNumberTypeFloat, out_nhwc_dims, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); /* nc4hw4 */ + float nchw_co[] = {9.005795, 15.341887, 6.091704, 13.748293, -7.92756, 10.232557, 9.045886, + 33.1299, 8.5707, 5.318199, -14.367487, 10.22495, -2.5882099, -0.12742424, + 1.195263, 6.469591, 9.609164, 6.112072, 16.333368, -4.87735, -8.439645, + -11.827093, -12.340071, -2.6368382, -14.432123, -8.483799, -12.28651, 0.80561405, + 11.332421, -0.43688506, -3.476327, -4.587028, -1.9491882, -3.3619316, -15.831648, + -10.517606, -9.204161, -0.15148449, 1.5822954, -10.122691, -4.7448387, 3.99177}; + PackNCHWToNHWCFp32(nchw_co, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel()); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 3; + conv_param->stride_h_ = conv_param->stride_w_ = 2; + conv_param->dilation_h_ = conv_param->dilation_w_ = 1; + conv_param->pad_h_ = conv_param->pad_w_ = 1; + return out_t->ElementsNum(); +} + +TEST_F(TestDeConvolutionFp32, DeConvTest2) { + std::vector inputs_; + std::vector outputs_; + auto deconv_param = new ConvParameter(); + float *correct; + int total_size = DeConvTestInit2(&inputs_, &outputs_, deconv_param, &correct); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 4; + kernel::DeConvolutionCPUKernel *deconv = + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + + deconv->Init(); + deconv->Run(); + EXPECT_EQ(0, lite::CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size)); + delete deconv_param; + delete deconv; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +int DeConvTestInit3(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, float **correct) { + std::vector in_dims_nhwc = {1, 3, 3, 2}; + auto *in_t = + new lite::tensor::Tensor(kNumberTypeFloat, in_dims_nhwc, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in_nchw[] = {0.10411751, 0.24034509, 0.71456534, 0.75286126, 0.9778457, 0.21043599, + 0.26498786, 0.6701024, 0.9744634, 0.49075702, 0.03877404, 0.48646277, + 0.5473929, 0.32438126, 0.87553847, 0.75820315, 0.86666644, 0.4852329}; + PackNCHWToNHWCFp32(in_nchw, reinterpret_cast(in_t->Data()), in_t->Batch(), in_t->Width() * in_t->Height(), + in_t->Channel()); + inputs_->push_back(in_t); + + std::vector w_dims_nhwc = {2, 2, 2, 2}; + auto *weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, w_dims_nhwc, schema::Format_NHWC, schema::NodeType_Parameter); + weight_t->MallocData(); + float w_nchw[] = {-0.108016446, -0.44254777, 0.29249913, 0.18764605, 1.1250675, 0.29441583, + -0.34362152, 0.7557833, 0.16503833, 0.2418737, -0.26612744, 0.5072577, + -0.4284475, 0.2215941, 0.9273913, 0.34634787}; + PackNCHWToNHWCFp32(w_nchw, weight_t->Data(), weight_t->Batch(), weight_t->Width() * weight_t->Height(), + weight_t->Channel()); + inputs_->push_back(weight_t); + + std::vector out_dims_nhwc = {1, 9, 9, 2}; + auto *out_t = + new lite::tensor::Tensor(kNumberTypeFloat, out_dims_nhwc, schema::Format_NC4HW4, schema::NodeType_Parameter); + out_t->MallocData(); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + float nchw_co[] = {0.069747314, 0.0, 0.072624244, -0.019562019, 0.0, -0.096985765, 0.0031001933, 0.0, -0.19856673, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -0.100149624, 0.0, 0.26847753, 0.059981894, 0.0, 0.06476824, 0.07954865, 0.0, 0.38084733, + 0.009019416, 0.0, -0.20077711, -0.05208808, 0.0, -0.35428414, 0.12176686, 0.0, 0.11864175, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.074535, 0.0, 0.4189407, 0.19969228, 0.0, 0.3480338, -0.17145246, 0.0, 0.4836111, + 0.09650954, 0.0, 0.06611961, 0.0706511, 0.0, -0.08692852, -0.02517605, 0.0, -0.31388155, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + -0.12426994, 0.0, 0.43432832, -0.034639344, 0.0, 0.5653653, 0.15589589, 0.0, 0.42899233, + -0.0931244, 0.0, 0.1394027, 0.2537918, 0.0, 0.0793535, 0.5955104, 0.0, 0.31817663, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.41934675, 0.0, 0.24866292, -0.04662904, 0.0, 0.1950781, 0.2056013, 0.0, 0.7085419, + 0.6124906, 0.0, 0.34295332, 0.96116215, 0.0, 0.35977423, -0.1383676, 0.0, 0.25596985, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.24894807, 0.0, 0.7585884, -0.03518048, 0.0, 0.8513882, 0.73965645, 0.0, 0.46228492, + -0.026721025, 0.0, 0.24602996, 0.38258934, 0.0, 0.38933694, 0.88844025, 0.0, 0.3944222, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.6120955, 0.0, 0.46287543, 0.57347727, 0.0, 0.80662024, 0.11515418, 0.0, 0.90454257}; + PackNCHWToNHWCFp32(nchw_co, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel()); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 2; + conv_param->stride_h_ = conv_param->stride_w_ = 3; + conv_param->dilation_h_ = conv_param->dilation_w_ = 2; + conv_param->pad_h_ = conv_param->pad_w_ = 0; + return out_t->ElementsNum(); +} + +TEST_F(TestDeConvolutionFp32, DeConvTest3) { + std::vector inputs_; + std::vector outputs_; + auto deconv_param = new ConvParameter(); + float *correct; + int total_size = DeConvTestInit3(&inputs_, &outputs_, deconv_param, &correct); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::DeConvolutionCPUKernel *deconv = + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + + deconv->Init(); + deconv->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + + delete deconv_param; + delete deconv; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +int DeConvTestInit4(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, float **correct) { + size_t buffer_size; + std::vector in_nhwc_dims = {1, 300, 300, 30}; + auto *in_t = + new lite::tensor::Tensor(kNumberTypeFloat, in_nhwc_dims, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + std::string in_nhwc_path = "./deconv/deconv_fp32_nhwc_input1.bin"; + auto in_nhwc = reinterpret_cast(mindspore::lite::ReadFile(in_nhwc_path.c_str(), &buffer_size)); + memcpy(in_t->Data(), in_nhwc, buffer_size); + inputs_->push_back(in_t); + + std::vector w_nhwc_dims = {30, 3, 3, 40}; + auto *weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, w_nhwc_dims, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + std::string weight_path = "./deconv/deconv_fp32_nchw_weight1.bin"; + auto weight_nchw = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size)); + PackNCHWToNHWCFp32(weight_nchw, weight_t->Data(), weight_t->Batch(), weight_t->Width() * weight_t->Height(), + weight_t->Channel()); + inputs_->push_back(weight_t); + + auto *bias_t = + new lite::tensor::Tensor(kNumberTypeFloat, {40}, schema::Format_NHWC, static_cast(1)); + bias_t->MallocData(); + std::string bias_path = "./deconv/deconv_fp32_nchw_bias1.bin"; + auto bias = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size); + memcpy(bias_t->Data(), bias, buffer_size); + inputs_->push_back(bias_t); + + std::vector out_nhwc_dims = {1, 302, 302, 40}; + auto *out_t = + new lite::tensor::Tensor(kNumberTypeFloat, out_nhwc_dims, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + std::string out_path = "./deconv/deconv_fp32_nchw_output1.bin"; + auto out_nchw = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size); + *correct = reinterpret_cast(malloc(buffer_size)); + PackNCHWToNHWCFp32(out_nchw, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel()); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 3; + conv_param->stride_h_ = conv_param->stride_w_ = 1; + conv_param->dilation_h_ = conv_param->dilation_w_ = 1; + conv_param->pad_h_ = conv_param->pad_w_ = 0; + conv_param->is_relu_ = conv_param->is_relu6_ = false; + + return out_t->ElementsNum(); +} + +TEST_F(TestDeConvolutionFp32, DeConvTest4) { + std::vector inputs_; + std::vector outputs_; + auto deconv_param = new ConvParameter(); + float *correct; + int total_size = DeConvTestInit4(&inputs_, &outputs_, deconv_param, &correct); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::DeConvolutionCPUKernel *deconv = + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + + deconv->Init(); + deconv->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + + /* running warm up */ + for (int i = 0; i < 0; i++) { + deconv->Run(); + } + + /* running time cost */ + int loop_count = 1; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + deconv->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + uint64_t time_avg = cost / loop_count; + printf("deconv fp32 average time : %f ms\n", time_avg / 1000.0f); + + delete deconv_param; + delete deconv; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc new file mode 100644 index 0000000000..2c22d10249 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/depth_to_space.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h" + +namespace mindspore { + +class DepthToSpaceTestFp32 : public mindspore::Common { + public: + DepthToSpaceTestFp32() = default; +}; + +TEST_F(DepthToSpaceTestFp32, DepthToSpaceTest2) { + float input[16] = {1, 2, 10, 20, 5, 6, 3, 8, 18, 10, 11, 55, 3, 4, 15, 25}; + constexpr int kOutSize = 16; + float expect_out[kOutSize] = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25}; + + float output[kOutSize]; + int in_shape[4] = {1, 2, 2, 4}; + int out_shape[4] = {1, 4, 4, 1}; + DepthToSpaceParameter param; + param.block_size_ = 2; + int in_strides[4]; + ComputeStrides(in_shape, in_strides, 4); + int out_strides[4]; + ComputeStrides(out_shape, out_strides, 4); + param.in_stride_dim0_ = in_strides[0]; + param.in_stride_dim1_ = in_strides[1]; + param.in_stride_dim2_ = in_strides[2]; + param.out_stride_dim0_ = out_strides[0]; + param.out_stride_dim1_ = out_strides[1]; + param.out_stride_dim2_ = out_strides[2]; + param.data_type_size_ = sizeof(float); + DepthToSpaceForNHWC((const void *)input, output, in_shape, ¶m); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + +TEST_F(DepthToSpaceTestFp32, DepthToSpaceTest3) { + float input[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + constexpr int kOutSize = 8; + float expect_out[kOutSize] = {1, 2, 3, 4, 5, 6, 7, 8}; + + float output[kOutSize]; + int in_shape[4] = {1, 1, 1, 8}; + int out_shape[4] = {1, 2, 2, 2}; + DepthToSpaceParameter param; + param.block_size_ = 2; + int in_strides[4]; + ComputeStrides(in_shape, in_strides, 4); + int out_strides[4]; + ComputeStrides(out_shape, out_strides, 4); + param.in_stride_dim0_ = in_strides[0]; + param.in_stride_dim1_ = in_strides[1]; + param.in_stride_dim2_ = in_strides[2]; + param.out_stride_dim0_ = out_strides[0]; + param.out_stride_dim1_ = out_strides[1]; + param.out_stride_dim2_ = out_strides[2]; + param.data_type_size_ = sizeof(float); + DepthToSpaceForNHWC((const void *)input, output, in_shape, ¶m); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc new file mode 100644 index 0000000000..1f6d2997c4 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32/embedding_lookup.h" +#include "src/runtime/kernel/arm/nnacl/fp32/embedding_lookup.h" +#include "src/common/file_utils.h" +#include "common/common_test.h" +#include "utils/log_adapter.h" + +namespace mindspore { +using mindspore::lite::tensor::Tensor; + +class TestEmbeddingLookupFp32 : public mindspore::Common { + public: + TestEmbeddingLookupFp32() {} +}; + +void ElTestInit(std::vector *inputs_, std::vector *outputs_, + EmbeddingLookupParameter *embedding_lookup_param) { + Tensor *in_t_first = new Tensor(kNumberTypeFloat32, {6, 2}, schema::Format_NHWC, static_cast(1)); + in_t_first->MallocData(); + float in_first[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + memcpy(in_t_first->Data(), in_first, sizeof(float) * in_t_first->ElementsNum()); + inputs_->push_back(in_t_first); + + Tensor *in_t_second = new Tensor(kNumberTypeFloat32, {4, 2}, schema::Format_NHWC, static_cast(1)); + in_t_second->MallocData(); + float in_second[] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}; + memcpy(in_t_second->Data(), in_second, sizeof(float) * in_t_second->ElementsNum()); + inputs_->push_back(in_t_second); + + Tensor *ids_t = new Tensor(kNumberTypeFloat32, {2, 3}, schema::Format_NHWC, static_cast(1)); + ids_t->MallocData(); + int ids[] = {1, 9, 2, 4, 6, 7}; + memcpy(ids_t->Data(), ids, sizeof(int) * ids_t->ElementsNum()); + inputs_->push_back(ids_t); + + Tensor *outputs_t = new Tensor(kNumberTypeInt32, {2, 3, 2}, schema::Format_NHWC, static_cast(1)); + outputs_t->MallocData(); + outputs_->push_back(outputs_t); + + embedding_lookup_param->max_norm_ = 1; +} + +TEST_F(TestEmbeddingLookupFp32, ElTest) { + std::vector inputs_; + std::vector outputs_; + auto embedding_lookup_param_ = new EmbeddingLookupParameter(); + ElTestInit(&inputs_, &outputs_, embedding_lookup_param_); + + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::EmbeddingLookupCPUKernel *el = new kernel::EmbeddingLookupCPUKernel( + reinterpret_cast(embedding_lookup_param_), inputs_, outputs_, ctx); + + el->Init(); + el->Run(); + + std::cout << "output shape:" << std::endl; + for (int i = 0; i < outputs_.front()->shape().size(); ++i) { + std::cout << outputs_.front()->shape()[i] << ' '; + } + std::cout << std::endl; + float *out = reinterpret_cast(outputs_.front()->Data()); + for (int i = 0; i < outputs_.front()->ElementsNum(); ++i) { + std::cout << out[i] << ' '; + } + std::cout << std::endl; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc new file mode 100644 index 0000000000..9a226fef65 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/runtime/kernel/arm/fp32/fullconnection.h" +#include "src/runtime/kernel/arm/nnacl/fp32/matmul.h" + +namespace mindspore { +using mindspore::lite::tensor::Tensor; + +class TestFcFp32 : public mindspore::Common { + public: + TestFcFp32() {} +}; + +int FcTestInit1(std::vector *inputs_, std::vector *outputs_, + MatMulParameter *matmal_param, float **correct) { + Tensor *in_t = new Tensor(kNumberTypeFloat, {2, 2, 2, 2}, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383, + 17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792}; + memcpy(in_t->Data(), in, sizeof(float) * in_t->ElementsNum()); + inputs_->push_back(in_t); + + Tensor *weight_t = new Tensor(kNumberTypeFloat, {3, 8}, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + float weight[] = {-0.0024438887, 0.0006738146, -0.008169129, 0.0021510671, -0.012470592, -0.0053063435, + 0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552, + -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459, + 0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343}; + memcpy(weight_t->Data(), weight, sizeof(float) * weight_t->ElementsNum()); + inputs_->push_back(weight_t); + + Tensor *bias_t = new Tensor(kNumberTypeFloat, {3}, schema::Format_NHWC, static_cast(1)); + bias_t->MallocData(); + float bias[] = {1.6103756, -0.9872417, 0.546849}; + memcpy(bias_t->Data(), bias, sizeof(float) * bias_t->ElementsNum()); + inputs_->push_back(bias_t); + + Tensor *out_t = new Tensor(kNumberTypeFloat, {2, 3}, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + float nchw_co[] = {1.6157111, -0.98469573, 0.6098231, 1.1649342, -1.2334653, 0.404779}; + memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float)); + + matmal_param->b_transpose_ = true; + matmal_param->a_transpose_ = false; + matmal_param->has_bias_ = true; + matmal_param->act_type_ = ActType_No; + return out_t->ElementsNum(); +} + +TEST_F(TestFcFp32, FcTest1) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + float *correct; + int total_size = FcTestInit1(&inputs_, &outputs_, matmul_param, &correct); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::FullconnectionCPUKernel *fc = + new kernel::FullconnectionCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + + fc->Init(); + fc->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); +} + +int FcTestInit2(std::vector *inputs_, std::vector *outputs_, + MatMulParameter *matmal_param, float **correct) { + size_t buffer_size; + + Tensor *in_t = new Tensor(kNumberTypeFloat, {20, 4, 2, 10}, schema::Format_NCHW, static_cast(1)); + in_t->MallocData(); + std::string in_path = "./matmul/FcFp32_input1.bin"; + auto in_data = mindspore::lite::ReadFile(in_path.c_str(), &buffer_size); + memcpy(in_t->Data(), in_data, buffer_size); + inputs_->push_back(in_t); + + Tensor *weight_t = new Tensor(kNumberTypeFloat, {30, 80}, schema::Format_NCHW, static_cast(1)); + weight_t->MallocData(); + std::string weight_path = "./matmul/FcFp32_weight1.bin"; + auto w_data = mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size); + memcpy(weight_t->Data(), w_data, buffer_size); + inputs_->push_back(weight_t); + + Tensor *bias_t = new Tensor(kNumberTypeFloat, {30}, schema::Format_NCHW, static_cast(1)); + bias_t->MallocData(); + std::string bias_path = "./matmul/FcFp32_bias1.bin"; + auto bias_data = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size); + memcpy(bias_t->Data(), bias_data, buffer_size); + inputs_->push_back(bias_t); + + Tensor *out_t = new Tensor(kNumberTypeFloat, {20, 30}, schema::Format_NCHW, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + std::string out_path = "./matmul/FcFp32_output1.bin"; + auto out_data = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size); + memcpy(*correct, out_data, out_t->ElementsNum() * sizeof(float)); + + matmal_param->b_transpose_ = true; + matmal_param->a_transpose_ = false; + matmal_param->has_bias_ = true; + matmal_param->act_type_ = ActType_No; + return out_t->ElementsNum(); +} + +TEST_F(TestFcFp32, FcTest2) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + float *correct; + int total_size = FcTestInit2(&inputs_, &outputs_, matmul_param, &correct); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + kernel::FullconnectionCPUKernel *fc = + new kernel::FullconnectionCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + + fc->Init(); + fc->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc new file mode 100644 index 0000000000..f591548c00 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc @@ -0,0 +1,330 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/lstm.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/ops/ops.h" + +namespace mindspore { +class LstmFp32 : public mindspore::Common { + public: + LstmFp32() {} +}; + +void InitLstmParam(LstmParameter *lstm_param) { + lstm_param->seq_len_ = 4; + lstm_param->batch_ = 1; + lstm_param->input_size_ = 2; + lstm_param->hidden_size_ = 3; + lstm_param->bidirectional_ = false; +} + +void InitLstmForwardCreator(std::vector *inputs, std::vector *outputs, + const LstmParameter *lstm_param) { + // prepare input + std::vector input_data = {1.3889, -0.3006, -0.1787, 2.1504, -0.3181, 0.4945, -0.4758, -0.8187}; + auto *input = new lite::tensor::Tensor; + input->set_data_type(kNumberTypeFloat32); + input->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->input_size_}); + input->MallocData(); + memcpy(input->Data(), input_data.data(), input_data.size() * sizeof(float)); + + // prepare weight_i + std::vector weight_i_data = {0.21368974, -0.3778776, 0.05025542, 0.09011161, 0.18355745, 0.5491228, + -0.14186832, -0.4655916, 0.49541366, -0.44039622, 0.5625571, 0.23325664, + 0.3449825, -0.42750397, 0.01911497, -0.4125802, -0.56690466, 0.50593233, + -0.29129684, -0.27841482, 0.01964372, -0.42543447, 0.41720617, -0.30054367}; + auto *weight_i = new lite::tensor::Tensor; + weight_i->set_data_type(kNumberTypeFloat32); + weight_i->SetFormat(schema::Format_NHWC); + weight_i->set_shape({1, lstm_param->hidden_size_ * 4, lstm_param->input_size_}); + weight_i->MallocData(); + memcpy(weight_i->Data(), weight_i_data.data(), weight_i_data.size() * sizeof(float)); + + // prepare weight_r + std::vector weight_h_data = { + -0.03424168, 0.00643545, 0.36867607, -0.08598137, 0.19804275, -0.11319417, -0.0244593, -0.16440144, -0.07268238, + 0.09828371, 0.33358777, 0.53381383, -0.39431244, -0.06005383, -0.3520246, 0.42687547, 0.5772828, 0.5380008, + -0.16130409, -0.24737108, 0.42409766, -0.50648475, 0.48223662, -0.5221103, -0.49216837, -0.29084128, 0.3408438, + 0.34080023, 0.49467337, 0.23473483, 0.01759732, 0.04691631, 0.45574808, -0.29481018, 0.29442167, -0.36718}; + auto *weight_h = new lite::tensor::Tensor; + weight_h->set_data_type(kNumberTypeFloat32); + weight_h->SetFormat(schema::Format_NHWC); + weight_h->set_shape({1, lstm_param->hidden_size_ * 4, lstm_param->hidden_size_}); + weight_h->MallocData(); + memcpy(weight_h->Data(), weight_h_data.data(), weight_h_data.size() * sizeof(float)); + + // prepare bias + std::vector bias_data = {-0.00207639, 0.16391152, -0.00069344, -0.32945693, -0.367423, 0.28301108, + -0.17930457, 0.5278388, 0.12598747, -0.53130764, 0.1479364, 0.16695255, + -0.00708795, -0.46417096, -0.23966661, -0.17496741, -0.19166365, -0.50466555, + -0.23593256, -0.3911457, 0.51128435, 0.5128727, 0.253451, -0.51891875}; + auto *bias = new lite::tensor::Tensor; + bias->set_data_type(kNumberTypeFloat32); + bias->SetFormat(schema::Format_NHWC); + bias->set_shape({1, lstm_param->hidden_size_ * 4 * 2}); + bias->MallocData(); + memcpy(bias->Data(), bias_data.data(), bias_data.size() * sizeof(float)); + + // prepare state + std::vector state_data = {0, 0, 0}; + auto *state = new lite::tensor::Tensor; + state->set_data_type(kNumberTypeFloat32); + state->SetFormat(schema::Format_NHWC); + state->set_shape({1, lstm_param->batch_, lstm_param->hidden_size_}); + state->MallocData(); + memcpy(state->Data(), state_data.data(), state_data.size() * sizeof(float)); + + inputs->push_back(input); + inputs->push_back(weight_i); + inputs->push_back(weight_h); + inputs->push_back(bias); + inputs->push_back(state); + inputs->push_back(state); + + // malloc output buffer, for arm cpu, format: N C4 H W 4 + auto *output = new lite::tensor::Tensor; + output->set_data_type(kNumberTypeFloat32); + output->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->hidden_size_}); + output->SetFormat(schema::Format_NHWC); + output->MallocData(); + memset(output->Data(), 0, output->ElementsNum() * sizeof(float)); + + auto *cell_state = new lite::tensor::Tensor; + cell_state->set_data_type(kNumberTypeFloat32); + cell_state->set_shape({1, lstm_param->batch_, lstm_param->hidden_size_}); + cell_state->SetFormat(schema::Format_NHWC); + cell_state->MallocData(); + memset(cell_state->Data(), 0, cell_state->ElementsNum() * sizeof(float)); + + auto *hidden_state = new lite::tensor::Tensor; + hidden_state->set_data_type(kNumberTypeFloat32); + hidden_state->set_shape({1, lstm_param->batch_, lstm_param->hidden_size_}); + hidden_state->SetFormat(schema::Format_NHWC); + hidden_state->MallocData(); + memset(hidden_state->Data(), 0, hidden_state->ElementsNum() * sizeof(float)); + + outputs->push_back(output); + outputs->push_back(cell_state); + outputs->push_back(hidden_state); +} + +void CompareOutput(lite::tensor::Tensor *output, std::vector data) { + for (int i = 0; i < output->ElementsNum(); i++) { + std::cout << reinterpret_cast(output->Data())[i] << ", "; + } + std::cout << std::endl; + + Common::CompareOutputData(reinterpret_cast(output->Data()), data.data(), output->ElementsNum(), 0.0001); +} + +TEST_F(LstmFp32, LstmForwardFp32Accuracy) { + // prepare stage + auto lstm_param = new LstmParameter(); + InitLstmParam(lstm_param); + + // init ctx + auto ctx = new lite::Context(); + ctx->thread_num_ = 1; + + // init tensor + std::vector inputs; + std::vector outputs; + InitLstmForwardCreator(&inputs, &outputs, lstm_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + // op run + kernel->Run(); + + std::cout << "==================output data=================" << std::endl; + std::vector output0_data = {-0.0702, 0.1225, 0.0876, -0.0357, -0.0227, -0.2294, + -0.0345, -0.0108, -0.2002, 0.0451, 0.0853, -0.1205}; + CompareOutput(outputs[0], output0_data); + + std::vector output1_data = {0.0451, 0.0853, -0.1205}; + CompareOutput(outputs[1], output1_data); + + std::vector output2_data = {0.0989, 0.2094, -0.4132}; + CompareOutput(outputs[2], output2_data); + + delete lstm_param; + for (int i = 0; i < inputs.size() - 1; i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + MS_LOG(INFO) << "LstmFp32 forward accuracy passed"; +} + +void InitLstmBackwardCreator(std::vector *inputs, std::vector *outputs, + const LstmParameter *lstm_param) { + // prepare input + std::vector input_data = {1.4305, 0.5342, -0.9221, 0.0527, 2.3770, -0.3697, -0.2833, -2.1285}; + auto *input = new lite::tensor::Tensor; + input->set_data_type(kNumberTypeFloat32); + input->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->input_size_}); + input->MallocData(); + memcpy(input->Data(), input_data.data(), input_data.size() * sizeof(float)); + + // prepare weight_i + std::vector weight_i_data = { + -0.19253477, -0.007966279, -0.06039094, 0.27697134, -0.5071223, 0.18996351, 0.20472168, -0.1007814, + 0.04282999, 0.20836472, -0.4654655, 0.050321221, -0.3431457, 0.22256428, 0.29294532, 0.45042896, + 0.20468240, 0.13078391, -0.20987969, -0.3173505, -0.3813517, 0.10205835, 0.21858131, -0.0386473, + 0.5512280, -0.2763766, -0.3593936, -0.5181975, 0.3469863, -0.38533931, 0.010202527, -0.46598294, + -0.5740513, 0.06127524, -0.03960543, 0.2478809, -0.17296993, 0.19159525, -0.4976995, 0.05985528, + 0.3653409, 0.386924, 0.3170289, -0.08830952, -0.31105759, 0.3110240, 0.15174299, 0.287579894}; + auto *weight_i = new lite::tensor::Tensor; + weight_i->set_data_type(kNumberTypeFloat32); + weight_i->SetFormat(schema::Format_NHWC); + weight_i->set_shape({2, lstm_param->hidden_size_ * 4, lstm_param->input_size_}); + weight_i->MallocData(); + memcpy(weight_i->Data(), weight_i_data.data(), weight_i_data.size() * sizeof(float)); + + // prepare weight_r + std::vector weight_h_data = { + 0.106934666, -0.50430017, 0.33296257, -0.288117021, -0.38019785, -0.147071093, 0.422707557, 0.41497004, + -0.5329730, -0.430150926, -0.032713949, 0.35401260, 0.179495036, -0.14158579, 0.380428612, -0.175597071, + 0.54088723, -0.403292059, -0.287720531, -0.51250511, -0.15405902, -0.440592586, 0.16726928, -0.0163397789, + 0.51673841, 0.5094323, -0.137105107, -0.181070089, -0.47221425, -0.38046866, -0.206725060, 0.248537719, + -0.23961094, -0.117781728, 0.426800847, 0.0266208052, -0.197408229, 0.54831492, -0.280048757, -0.125062286, + -0.29929456, 0.42354834, -0.401066303, 0.356340110, 0.54629492, -0.15852552, 0.131406366, -0.101815432, + 0.0121276974, -0.53553336, 0.121099889, 0.060554087, 0.46259057, -0.49666053, 0.090806663, 0.20542401, + -0.38674920, -0.23874849, -0.5222138, 0.57537007, 0.113343358, -0.35233467, -0.25532332, 0.159506142, + 0.35996592, -0.201961308, -0.16323345, 0.119177639, -0.12677872, -0.175229549, -0.160024613, -0.21058899}; + auto *weight_h = new lite::tensor::Tensor; + weight_h->set_data_type(kNumberTypeFloat32); + weight_h->SetFormat(schema::Format_NHWC); + weight_h->set_shape({2, lstm_param->hidden_size_ * 4, lstm_param->hidden_size_}); + weight_h->MallocData(); + memcpy(weight_h->Data(), weight_h_data.data(), weight_h_data.size() * sizeof(float)); + + // prepare bias + std::vector bias_data = { + 0.57061123, -0.25357073, -0.146834075, 0.412972748, -0.27809411, -0.0542128682, -0.45384609, -0.53261917, + 0.222133636, -0.18093895, -0.045559883, 0.09109061, 0.080319643, 0.455167174, 0.36235427, -0.00164419412, + -0.135566502, 0.41905909, -0.450117409, 0.50565385, -0.077815443, -0.47051778, -0.141349375, -0.338519752, + 0.48683023, 0.282384872, 0.13399660, -0.382526844, -0.23370727, -0.184681564, 0.45679104, -0.339453905, + 0.452010273, 0.0552094578, 0.328843057, 0.127738714, -0.127084732, -0.334061294, -0.46742400, -0.401568055, + 0.23712641, -0.052937567, 0.272351622, 0.42767739, 0.303884744, -0.46025499, -0.43985402, 0.256422877}; + auto *bias = new lite::tensor::Tensor; + bias->set_data_type(kNumberTypeFloat32); + bias->SetFormat(schema::Format_NHWC); + bias->set_shape({2, lstm_param->hidden_size_ * 4 * 2}); + bias->MallocData(); + memcpy(bias->Data(), bias_data.data(), bias_data.size() * sizeof(float)); + + // prepare state + std::vector state_data = {0, 0, 0, 0, 0, 0}; + auto *state = new lite::tensor::Tensor; + state->set_data_type(kNumberTypeFloat32); + state->SetFormat(schema::Format_NHWC); + state->set_shape({2, lstm_param->batch_, lstm_param->hidden_size_}); + state->MallocData(); + memcpy(state->Data(), state_data.data(), state_data.size() * sizeof(float)); + + inputs->push_back(input); + inputs->push_back(weight_i); + inputs->push_back(weight_h); + inputs->push_back(bias); + inputs->push_back(state); + inputs->push_back(state); + + // malloc output buffer, for arm cpu, format: N C4 H W 4 + auto *output = new lite::tensor::Tensor; + output->set_data_type(kNumberTypeFloat32); + output->set_shape({lstm_param->seq_len_, 2, lstm_param->batch_, lstm_param->hidden_size_}); + output->SetFormat(schema::Format_NHWC); + output->MallocData(); + memset(output->Data(), 0, output->ElementsNum() * sizeof(float)); + + auto *cell_state = new lite::tensor::Tensor; + cell_state->set_data_type(kNumberTypeFloat32); + cell_state->set_shape({2, lstm_param->batch_, lstm_param->hidden_size_}); + cell_state->SetFormat(schema::Format_NHWC); + cell_state->MallocData(); + memset(cell_state->Data(), 0, cell_state->ElementsNum() * sizeof(float)); + + auto *hidden_state = new lite::tensor::Tensor; + hidden_state->set_data_type(kNumberTypeFloat32); + hidden_state->set_shape({2, lstm_param->batch_, lstm_param->hidden_size_}); + hidden_state->SetFormat(schema::Format_NHWC); + hidden_state->MallocData(); + memset(hidden_state->Data(), 0, hidden_state->ElementsNum() * sizeof(float)); + + outputs->push_back(output); + outputs->push_back(cell_state); + outputs->push_back(hidden_state); +} + +TEST_F(LstmFp32, LstmBackwardFp32Accuracy) { + // prepare stage + auto lstm_param = new LstmParameter(); + InitLstmParam(lstm_param); + lstm_param->bidirectional_ = true; + + // init ctx + auto ctx = new lite::Context(); + ctx->thread_num_ = 1; + + // init tensor + std::vector inputs; + std::vector outputs; + InitLstmBackwardCreator(&inputs, &outputs, lstm_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + // op run + kernel->Run(); + + std::cout << "==================output data=================" << std::endl; + std::vector output0_data = {-0.2922, -0.1416, 0.0077, -0.0422, -0.0585, 0.2061, -0.2385, -0.0146, + -0.1796, -0.0554, -0.0973, 0.1013, -0.3062, -0.1516, -0.0310, 0.0459, + -0.0784, 0.0949, 0.0249, -0.0653, -0.0869, -0.1113, -0.2155, -0.0500}; + CompareOutput(outputs[0], output0_data); + + std::vector output1_data = {0.0249, -0.0653, -0.0869, -0.0422, -0.0585, 0.2061}; + CompareOutput(outputs[1], output1_data); + + std::vector output2_data = {0.0373, -0.2322, -0.1477, -0.1621, -0.1808, 0.5146}; + CompareOutput(outputs[2], output2_data); + + delete lstm_param; + for (int i = 0; i < inputs.size() - 1; i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + MS_LOG(INFO) << "LstmFp32 backward accuracy passed"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc new file mode 100644 index 0000000000..8ce0da32d0 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc @@ -0,0 +1,308 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h" +#include "src/kernel_registry.h" +#include "src/lite_kernel.h" + +namespace mindspore { +class TestMatMulFp32 : public mindspore::Common { + public: + TestMatMulFp32() {} +}; + +TEST_F(TestMatMulFp32, Row2Col8Test1) { + float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, + 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, + 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, + 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, + 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, + 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52}; + float co[] = {0.21, 0.67, 0.53, 0.09, 0.43, 0.35, 0.04, 0.43, 0.38, 0.81, 0.97, 0.50, 0.14, 0.52, 0.10, 0.14, + 0.81, 0.57, 0.92, 0.39, 0.67, 0.02, 0.18, 0.67, 0.98, 0.70, 0.35, 0.09, 0.10, 0.33, 0.92, 0.10, + 0.09, 0.27, 0.74, 0.93, 0.73, 0.99, 0.46, 0.73, 0.68, 0.90, 0.78, 0.91, 0.37, 0.49, 0.08, 0.37, + 0.02, 0.07, 0.87, 0.20, 0.24, 0.67, 0.04, 0.24, 0.33, 0.13, 0.23, 0.97, 0.93, 0.75, 0.24, 0.93, + 0.85, 0.03, 0.34, 0.61, 0.31, 0.66, 0.52, 0.31, 0.35, 0.04, 0, 0, 0, 0, 0, 0, + 0.52, 0.10, 0, 0, 0, 0, 0, 0, 0.02, 0.18, 0, 0, 0, 0, 0, 0, + 0.33, 0.92, 0, 0, 0, 0, 0, 0, 0.99, 0.46, 0, 0, 0, 0, 0, 0, + 0.49, 0.08, 0, 0, 0, 0, 0, 0, 0.67, 0.04, 0, 0, 0, 0, 0, 0, + 0.75, 0.24, 0, 0, 0, 0, 0, 0, 0.66, 0.52, 0, 0, 0, 0, 0, 0}; + float out[144] = {0}; + RowMajor2Col8Major(in, out, 10, 9); + CompareOutputData(out, co, 144, 0.0001); +} + +TEST_F(TestMatMulFp32, Row2Col8Test2) { + float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, + 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, + 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, + 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, + 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, + 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52}; + float co[] = {0.21, 0.68, 0.81, 0.07, 0.92, 0.23, 0.09, 0.61, 0.38, 0.02, 0.57, 0.13, 0.35, 0.34, 0.93, + 0.43, 0.81, 0.33, 0.70, 0.03, 0.74, 0.09, 0.91, 0.14, 0.98, 0.85, 0.27, 0.53, 0.78, 0.50, + 0.20, 0.67, 0.09, 0.67, 0.90, 0.97, 0.87, 0.39, 0.97, 0.10, 0.73, 0.35, 0.49, 0.10, 0.04, + 0.67, 0.93, 0.33, 0.37, 0.52, 0.67, 0.18, 0.24, 0.10, 0.31, 0.99, 0.24, 0.02, 0.75, 0.92, + 0.52, 0.73, 0.35, 0.49, 0.93, 0.33, 0.66, 0.46, 0.43, 0.37, 0.52, 0.67, 0.31, 0.99, 0.04, + 0.08, 0.14, 0.24, 0.02, 0.75, 0.66, 0.46, 0, 0, 0, 0, 0, 0, 0.04, 0.08, + 0, 0, 0, 0, 0, 0, 0.10, 0.04, 0, 0, 0, 0, 0, 0, 0.18, + 0.24, 0, 0, 0, 0, 0, 0, 0.92, 0.52, 0, 0, 0, 0, 0, 0}; + float out[120] = {0}; + RowMajor2Col8Major(in, out, 18, 5); + CompareOutputData(out, co, 120, 0.0001); +} + +TEST_F(TestMatMulFp32, Row8x82RowTest1) { + float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0, + 0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0, + 0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0, + 0.09, 0.93, 0.91, 0.20, 0.97, 0, 0, 0, 0.61, 0.43, 0.14, 0.67, 0.10, 0, 0, 0, + 0.73, 0.37, 0.24, 0.93, 0.31, 0, 0, 0, 0.35, 0.52, 0.02, 0.33, 0.99, 0, 0, 0, + 0.49, 0.67, 0.75, 0.66, 0.04, 0, 0, 0, 0.10, 0.18, 0.92, 0.46, 0.08, 0, 0, 0, + 0.04, 0.24, 0.52, 0.43, 0.14, 0, 0, 0, 0.67, 0.10, 0.73, 0.37, 0.24, 0, 0, 0, + 0.93, 0.31, 0.35, 0.52, 0.02, 0, 0, 0, 0.33, 0.99, 0.49, 0.67, 0.75, 0, 0, 0, + 0.66, 0.04, 0.10, 0.18, 0.92, 0, 0, 0, 0.46, 0.08, 0.04, 0.24, 0.52, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, + 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, + 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, + 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, + 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, + 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52}; + float out[90] = {0}; + Row8x8Major2RowMajor(in, out, 18, 5); + CompareOutputData(out, co, 90, 0.0001); +} + +TEST_F(TestMatMulFp32, Row8x82RowTest2) { + float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0, + 0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0, + 0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, + 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39}; + float out[30] = {0}; + Row8x8Major2RowMajor(in, out, 6, 5); + CompareOutputData(out, co, 30, 0.0001); +} + +TEST_F(TestMatMulFp32, Row8x82RowTest3) { + float in[] = { + 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73, + 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, + 0.10, 0.18, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.93, + 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.75, 0.66, 0.04, 0.10, + 0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, + 0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, + 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50, + 0.39, 0.09, 0.93, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, + 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81, + 0.98, 0.09, 0.68, 0.02, 0.33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52, + 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04, + 0.24, 0.52, 0.21, 0.38, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, + 0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.04, 0.24, + 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85, 0.67, 0.81, 0.57, 0.70, + 0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, + 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97, + 0.61, 0.43, 0.14, 0.67, 0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, + 0.99, 0.49, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85, + 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0, + 0, 0.04, 0.10, 0.18, 0, 0, 0, 0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.81, 0.98, + 0.09, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73, 0.37, 0.24, 0, 0, + 0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0, 0, 0, 0, 0, + 0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0, 0, 0.13, 0.03, 0.53, + 0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0, 0, 0.04, 0.10, 0.18, 0, 0, 0, + 0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73, + 0.37, 0.24, 0, 0, 0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0, + 0, 0, 0, 0, 0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0, + 0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + float co[] = { + 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, + 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, + 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, + 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, + 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.21, 0.38, 0.81, 0.98, 0.09, + 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, + 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, + 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, + 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67, + 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, + 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, + 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, + 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, + 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, + 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, + 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, + 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, + 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, + 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67, + 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, + 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, + 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53}; + float out[418] = {0}; + Row8x8Major2RowMajor(in, out, 22, 19); + CompareOutputData(out, co, 418, 0.0001); +} + +int MMTestInit(std::vector *inputs_, std::vector *outputs_, + float *a_ptr, float *b_ptr, std::vector a_shape, std::vector b_shape, + std::vector c_shape) { + auto in_t = + new lite::tensor::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + memcpy(in_t->Data(), a_ptr, sizeof(float) * in_t->ElementsNum()); + inputs_->push_back(in_t); + + auto weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, b_shape, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + memcpy(weight_t->Data(), b_ptr, sizeof(float) * weight_t->ElementsNum()); + inputs_->push_back(weight_t); + + auto out_t = + new lite::tensor::Tensor(kNumberTypeFloat, c_shape, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + return out_t->ElementsNum(); +} + +TEST_F(TestMatMulFp32, simple) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = false; + matmul_param->has_bias_ = false; + float a[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383, + 17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792}; + float b[] = {-0.0024438887, 0.0006738146, -0.008169129, 0.0021510671, -0.012470592, -0.0053063435, + 0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552, + -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459, + 0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343}; + std::vector a_shape = {1, 2, 8}; + std::vector b_shape = {1, 8, 3}; + std::vector c_shape = {1, 2, 3}; + int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); + auto ctx = new lite::Context; + ctx->thread_num_ = 2; + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + mm->Init(); + mm->Run(); + float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779, + -0.3049793541431427, -0.027687929570674896, -0.18109679222106934}; + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete matmul_param; + delete mm; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; +} + +TEST_F(TestMatMulFp32, simple_transb) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = true; + matmul_param->has_bias_ = false; + float a[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383, + 17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792}; + float b[] = {-0.0024438887, 0.0006738146, -0.008169129, 0.0021510671, -0.012470592, -0.0053063435, + 0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552, + -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459, + 0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343}; + std::vector a_shape = {1, 2, 8}; + std::vector b_shape = {1, 3, 8}; + std::vector c_shape = {1, 2, 3}; + int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); + auto ctx = new lite::Context; + ctx->thread_num_ = 2; + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + mm->Init(); + mm->Run(); + float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031}; + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete matmul_param; + delete mm; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; +} + +TEST_F(TestMatMulFp32, batch) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = true; + matmul_param->has_bias_ = false; + float a[] = {-4.946672525326248, 11.154420027909701, -7.831129637356922, 17.309845099949953, -10.46177877610444, + 2.5412751480833897, 2.700113860276929, -12.616715572097341, -15.513316568881574, -9.513294738065516, + 17.931148376418896, -10.83801964632579, -14.023733862948017, -14.50805001403956, 0.7952221556310306, + 6.619720423569035, -19.277904230909357, -13.450479287024839, 19.914652156692625, 16.542571697048878, + -2.9715041389268926, 4.949555349889412, -1.9408110276290103, -15.062828261031868, 0.20012569643335, + 8.260383531209776, 3.1092344458607357, 16.742272486091487, 17.31277252415167, -16.60303202099434, + -8.980314693173042, -11.735087989358268, -14.918976184088514, -11.347592686892733, 11.808756029220604, + -18.76179414554809, 7.579758962360987, 3.13240880962163, 6.528181981442103, -16.802624652419794, + -14.323146919914901, -16.197579076296144, 9.738053920125779, -12.245780062949866, 8.817905278096319, + 0.5261391331275007, -18.26152522535471, -2.400461208771226}; + float b[] = { + -0.895183867395529, -0.8146900207660068, -0.27931593219652817, 0.783554361201179, -0.05080215007779798, + -0.9879631271568501, 0.07710949009001333, -0.9562579726211344, 0.29505553318356825, -0.26651960351085124, + -0.12755456259718279, -0.8221417897250098, -0.5094334041431876, -0.9117373380256013, 0.991501784215064, + 0.20131976450979394, 0.07889260559412059, -0.8138407752750305, -0.047622075866657454, -0.2778043115153188, + -0.6269973420163957, -0.44345812666611617, -0.8571568605933642, 0.020192166011526735, 0.4860054298402434, + 0.41525925469513614, -0.40270506445219967, -0.8716538067535347, 0.5276448387223114, 0.6064500154192936, + -0.9553204135772526, 0.3253219646257437, -0.7237956595774822, 0.3271284879679077, -0.534543967339336, + -0.4076498484281894, 0.01574797075171963, -0.37322004720586244, 0.16425071396119928, -0.5328652244800547, + 0.7389336170615435, -0.6552069958923377, -0.042305872596973604, -0.6714941466767734, -0.9281411415119043, + -0.7748558258281224, -0.6209799945964443, 0.02526428593887675, -0.44984776800225856, 0.6281401952319337, + 0.9907258228680276, 0.6288646615999687, -0.82076880150175, 0.3065944740797497, -0.29201038744043584, + -0.025685501802048982, -0.07273175145419652, 0.9370449239208709, -0.8233807408078093, -0.4195634619023012, + 0.9799555630257346, -0.23461882935715228, -0.8884793313829993, -0.4760267734754635, -0.2874539543614072, + -0.8795685985480997, -0.08099698251915255, -0.1626521023321741, -0.9337167240793414, 0.40924842916829207, + -0.7375713045221615, -0.0065659291539015285}; + std::vector a_shape = {3, 2, 8}; + std::vector b_shape = {3, 3, 8}; + std::vector c_shape = {3, 2, 3}; + int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); + auto ctx = new lite::Context; + ctx->thread_num_ = 1; + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + mm->Init(); + mm->Run(); + float correct[] = {21.38518524169922, -14.514888763427734, -11.040614128112793, 16.91403579711914, + 27.07421112060547, 23.35394287109375, -39.006141662597656, -2.021998405456543, + -17.63555145263672, -8.490625381469727, 5.317771911621094, -14.561882019042969, + -7.251564025878906, -2.508212089538574, 5.86458683013916, -3.466249465942383, + 8.869029998779297, 25.034008026123047}; + + float *output = reinterpret_cast(outputs_[0]->Data()); + for (int i = 0; i < 18; ++i) printf("%f ", output[i]); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete matmul_param; + delete mm; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc new file mode 100644 index 0000000000..fb9e262e28 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc @@ -0,0 +1,332 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/include/context.h" +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "src/common/utils.h" +#include "src/common/file_utils.h" +#include "src/runtime/kernel/arm/fp32/pooling_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" + +namespace mindspore { +class TestPoolingGradFp32 : public mindspore::Common { + public: + TestPoolingGradFp32() {} +}; + +void InitPoolingParamFP32(PoolingParameter *pooling_param) { + pooling_param->input_batch_ = 1; + pooling_param->input_h_ = 28; + pooling_param->input_w_ = 28; + pooling_param->input_channel_ = 3; + + pooling_param->output_batch_ = 1; + pooling_param->output_h_ = 28; + pooling_param->output_w_ = 28; + pooling_param->output_channel_ = 32; + + pooling_param->window_h_ = 3; + pooling_param->window_w_ = 3; + + pooling_param->stride_h_ = 1; + pooling_param->stride_w_ = 1; + + pooling_param->pad_u_ = 1; + pooling_param->pad_d_ = 1; + pooling_param->pad_l_ = 1; + pooling_param->pad_r_ = 1; + pooling_param->thread_num_ = 1; +} + +TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { + // prepare stage + auto pooling_param = new PoolingParameter(); + InitPoolingParamFP32(pooling_param); + pooling_param->output_channel_ = 3; + + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = + pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; + + size_t input_size; + std::string input_path = "./test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + AvgPoolingGrad(input_data, output_data, pooling_param); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + AvgPoolingGrad(input_data, output_data, pooling_param); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; + lite::CompareOutput(output_data, output_path); + + delete input_data; + delete[] output_data; + delete pooling_param; + MS_LOG(INFO) << "TestAvgPoolingGradFp32 passed"; +} + +TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { + // prepare stage + auto pooling_param = new PoolingParameter(); + InitPoolingParamFP32(pooling_param); + + pooling_param->output_channel_ = 3; + + // runtime part + printf("Calculating runtime cost...\n"); + // uint64_t time_avg = 0; + size_t output_data_size = + pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; + + size_t input_size; + std::string input_path = "./test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_dy({1, 28, 28, 3}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(input_data); + + std::string input1_path = "./test_data/pooling/avgpoolgradfp32_1_x_1_28_28_3.bin"; + input_data = reinterpret_cast(mindspore::lite::ReadFile(input1_path.c_str(), &input_size)); + std::vector dim_x({1, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(input_data); + + std::vector inputs = {&dy_tensor, &x_tensor}; + + auto output_data = new float[output_data_size]; + std::vector dim_dx({1, 28, 28, 3}); + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); + dx_tensor.SetData(output_data); + std::vector outputs = {&dx_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(pooling_param), NULL, desc); + + kernel_obj->Run(); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; + lite::CompareOutput(output_data, output_path); + + // delete input_data; + // delete[] output_data; + delete pooling_param; + MS_LOG(INFO) << "TestAvgPoolingGradFp32 passed"; +} + +TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { + // prepare stage + auto pooling_param = new PoolingParameter(); + InitPoolingParamFP32(pooling_param); + pooling_param->output_channel_ = 3; + pooling_param->avg_pooling_ = false; + pooling_param->max_pooling_ = true; + // runtime part + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + size_t output_data_size = + pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; + + size_t input_size; + std::string i_path = "./test_data/pooling/maxpoolgradfp32_1_i_1_28_28_3.bin"; + auto ill_data = reinterpret_cast(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); + auto i_data = new int[output_data_size]; + for (uint32_t i = 0; i < output_data_size; i++) { + i_data[i] = static_cast(ill_data[i]); + } + + std::string dy_path = "./test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin"; + auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &input_size)); + + auto output_data = new float[output_data_size]; + // warm up loop + for (int i = 0; i < 3; i++) { + MaxPoolingGrad(dy_data, i_data, output_data, pooling_param); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + MaxPoolingGrad(dy_data, i_data, output_data, pooling_param); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; + lite::CompareOutput(output_data, output_path); + + // delete input_data; + delete pooling_param; + delete[] output_data; + MS_LOG(INFO) << "TestMaxPoolingGradFp32 passed"; +} + +#if 0 +TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) { + // prepare stage + auto maxpool = new PoolingParameter(); + InitPoolingParamFP32(maxpool); + maxpool->avg_pooling_ = false; + maxpool->max_pooling_ = true; + maxpool->input_h_ = 30; + maxpool->input_w_ = 30; + maxpool->input_channel_ = 3; + + maxpool->output_batch_ = 1; + maxpool->output_h_ = 10; + maxpool->output_w_ = 10; + maxpool->output_channel_ = 3; + maxpool->stride_h_ = 3; + maxpool->stride_w_ = 3; + + maxpool->pad_u_ = 0; + maxpool->pad_d_ = 0; + maxpool->pad_l_ = 0; + maxpool->pad_r_ = 0; + + size_t input_size; + size_t y_data_size = maxpool->output_batch_ * maxpool->output_channel_ * maxpool->output_h_ * maxpool->output_w_; + + auto x_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_2_x_1_30_30_3.bin", &input_size)); + std::vector dim_x({1, 30, 30, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + std::vector maxpool_inputs = {&x_tensor}; + + auto y_data = new float[y_data_size]; + std::vector dim_y({1, 10, 10, 3}); + lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); + y_tensor.SetData(y_data); + + auto ind_data = new int[y_data_size]; + lite::tensor::Tensor ind_tensor(TypeId::kNumberTypeInt32, dim_y); + ind_tensor.SetData(ind_data); + + std::vector maxpool_outputs = {&y_tensor, &ind_tensor}; + + kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Pooling}; + auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); + auto maxpoolobj = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast(maxpool), + NULL, maxpool_desc); + maxpoolobj->Run(); + + printf("==================indices data=================\n"); + for (int i = 0; i < 10; i++) { + std::cout << ind_data[i] << " ,"; + } + std::cout << std::endl; + + auto pooling_param = new PoolingParameter(); + InitPoolingParamFP32(pooling_param); + pooling_param->avg_pooling_ = false; + pooling_param->max_pooling_ = true; + pooling_param->input_h_ = 10; + pooling_param->input_w_ = 10; + pooling_param->input_channel_ = 3; + + pooling_param->output_batch_ = 1; + pooling_param->output_h_ = 30; + pooling_param->output_w_ = 30; + pooling_param->output_channel_ = 3; + + // runtime part + printf("Calculating runtime cost...\n"); + // uint64_t time_avg = 0; + size_t output_data_size = + pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; + + auto dy_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_2_dy_1_10_10_3.bin", &input_size)); + std::vector dim_dy({1, 3, 10, 10}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(dy_data); + +#if 0 + std::string i_path = "./test_data/pooling/maxpoolgradfp32_2_i_1_3_10_10.bin"; + auto ill_data = reinterpret_cast(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); + auto i_data = new int[output_data_size]; + for (int i=0; i < output_data_size; i++) + i_data[i] = static_cast(ill_data[i]); + std::vector dim_ind({1, 3, 10, 10}); + lite::tensor::Tensor ind_tensor(TypeId::kNumberTypeInt32, dim_ind); + ind_tensor.SetData(i_data); +#endif + + std::vector inputs = {&dy_tensor, &ind_tensor}; + + auto output_data = new float[output_data_size]; + std::vector dim_dx({1, 3, 30, 30}); + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); + dx_tensor.SetData(output_data); + std::vector outputs = {&dx_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(pooling_param), NULL, desc); + kernel_obj->Run(); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + std::string output_path = "./test_data/pooling/maxpoolgradfp32_2_dx_1_30_30_3.bin"; + lite::CompareOutput(output_data, output_path); + + // delete input_data; + // delete[] output_data; + delete pooling_param; + MS_LOG(INFO) << "TestMaxPoolingKernelGradFp32 passed"; +} +#endif // if 0 before MaxPoolingKernelGradFp32 +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc new file mode 100644 index 0000000000..46423d89a4 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/power.h" +#include "src/kernel_registry.h" +#include "src/lite_kernel.h" + +namespace mindspore { +class TestPowerFp32 : public mindspore::Common { + public: + TestPowerFp32() {} +}; + +int PowerTestInit(std::vector *inputs_, std::vector *outputs_, + float *a_ptr, float *b_ptr, std::vector a_shape, std::vector b_shape, + std::vector c_shape) { + auto in_t = + new lite::tensor::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + memcpy(in_t->Data(), a_ptr, sizeof(float) * in_t->ElementsNum()); + inputs_->push_back(in_t); + + auto weight_t = + new lite::tensor::Tensor(kNumberTypeFloat, b_shape, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + memcpy(weight_t->Data(), b_ptr, sizeof(float) * weight_t->ElementsNum()); + inputs_->push_back(weight_t); + + auto out_t = + new lite::tensor::Tensor(kNumberTypeFloat, c_shape, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + outputs_->push_back(out_t); + + return out_t->ElementsNum(); +} + +TEST_F(TestPowerFp32, Simple) { + std::vector inputs_; + std::vector outputs_; + auto param = new PowerParameter(); + param->scale_ = 1; + param->shift_ = 0; + float a[] = {1, 2, 3, 4}; + float b[] = {5, 6, 7, 8}; + std::vector a_shape = {2, 2}; + std::vector b_shape = {2, 2}; + std::vector c_shape = {2, 2}; + int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); + auto ctx = new lite::Context; + ctx->thread_num_ = 1; + auto op = new kernel::PowerCPUKernel(reinterpret_cast(param), inputs_, outputs_, ctx); + op->Init(); + op->Run(); + float correct[] = {1, 64, 2187, 65536}; + float *output = reinterpret_cast(outputs_[0]->Data()); + for (int i = 0; i < 4; ++i) printf("%f ", output[i]); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete op; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; +} + +TEST_F(TestPowerFp32, Broadcast) { + std::vector inputs_; + std::vector outputs_; + auto param = new PowerParameter(); + param->scale_ = 1; + param->shift_ = 0; + float a[] = {1, 2, 3, 4}; + float b[] = {2}; + std::vector a_shape = {2, 2}; + std::vector b_shape = {1}; + std::vector c_shape = {2, 2}; + int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); + auto ctx = new lite::Context; + ctx->thread_num_ = 2; + auto op = new kernel::PowerCPUKernel(reinterpret_cast(param), inputs_, outputs_, ctx); + op->Init(); + op->Run(); + float correct[] = {1, 4, 9, 16}; + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); + delete op; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc new file mode 100644 index 0000000000..c6b1e5c7e3 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc @@ -0,0 +1,165 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { + +class SpaceToBatchTestFp32 : public mindspore::Common { + public: + SpaceToBatchTestFp32() {} +}; + +void InitSpaceToBatchParameter(SpaceToBatchParameter *param) { + param->n_dims_ = 4; + param->n_space_dims_ = 2; + + param->block_sizes_[0] = 2; + param->block_sizes_[1] = 2; + + param->paddings_[0] = 2; + param->paddings_[1] = 0; + param->paddings_[2] = 2; + param->paddings_[3] = 2; + + param->in_shape_[0] = 1; + param->in_shape_[1] = 4; + param->in_shape_[2] = 4; + param->in_shape_[3] = 1; + + param->padded_in_shape_[0] = 1; + param->padded_in_shape_[1] = 6; + param->padded_in_shape_[2] = 8; + param->padded_in_shape_[3] = 1; + + param->num_elements_ = 16; + param->num_elements_padded_ = 48; + + param->need_paddings_ = true; +} + +void InitSpaceToBatchParameter2(SpaceToBatchParameter *param) { + param->block_sizes_[0] = 2; + param->block_sizes_[1] = 2; + + param->paddings_[0] = 2; + param->paddings_[1] = 0; + param->paddings_[2] = 2; + param->paddings_[3] = 2; + + param->in_shape_[0] = 1; + param->in_shape_[1] = 4; + param->in_shape_[2] = 4; + param->in_shape_[3] = 1; + + param->padded_in_shape_[0] = 1; + param->padded_in_shape_[1] = 6; + param->padded_in_shape_[2] = 8; + param->padded_in_shape_[3] = 1; +} + +TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest1) { + float input[16] = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25}; + const int out_size = 16; + float expect_out[16] = {1, 5, 18, 3, 2, 6, 10, 4, 10, 3, 11, 15, 20, 8, 55, 25}; + + float output[16]; + int in_shape[4] = {1, 4, 4, 1}; + int out_shape[4] = {4, 2, 2, 1}; + int block_sizes[2] = {2, 2}; + SpaceToBatchForNHWC((const float *)input, output, in_shape, 4, block_sizes); + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, out_size, 0.000001); +} + +TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest2) { + SpaceToBatchParameter param; + InitSpaceToBatchParameter(¶m); + float input[16] = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25}; + const int out_size = 48; + float expect_out[48] = {0, 0, 0, 0, 0, 1, 5, 0, 0, 18, 3, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 10, 4, 0, + 0, 0, 0, 0, 0, 10, 3, 0, 0, 11, 15, 0, 0, 0, 0, 0, 0, 20, 8, 0, 0, 55, 25, 0}; + float output[48]; + int in_shape[4] = {1, 4, 4, 1}; + int out_shape[4] = {4, 3, 4, 1}; + int block_sizes[2] = {2, 2}; + + float padded_input[48]{}, tmp[48]{}, tmp_zero[48]{}; + float *tmp_space[3] = {padded_input, tmp, tmp_zero}; + auto ret = SpaceToBatch((const float *)input, output, param, tmp_space); + std::cout << "return " << ret << std::endl; + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, out_size, 0.000001); +} + +TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest3) { + SpaceToBatchParameter param; + InitSpaceToBatchParameter2(¶m); + param.op_parameter_.type_ = schema::PrimitiveType_SpaceToBatch; + + std::vector input = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25}; + std::vector in_shape = {1, 4, 4, 1}; + lite::tensor::Tensor input_tensor; + input_tensor.SetData(input.data()); + input_tensor.set_shape(in_shape); + input_tensor.SetFormat(schema::Format_NHWC); + input_tensor.set_data_type(kNumberTypeFloat32); + std::vector inputs_tensor; + inputs_tensor.emplace_back(&input_tensor); + + const int out_size = 48; + float expect_out[48] = {0, 0, 0, 0, 0, 1, 5, 0, 0, 18, 3, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 10, 4, 0, + 0, 0, 0, 0, 0, 10, 3, 0, 0, 11, 15, 0, 0, 0, 0, 0, 0, 20, 8, 0, 0, 55, 25, 0}; + std::vector output(48); + std::vector out_shape = {4, 3, 4, 1}; + lite::tensor::Tensor output_tensor; + output_tensor.SetData(output.data()); + output_tensor.set_shape(out_shape); + output_tensor.SetFormat(schema::Format_NHWC); + output_tensor.set_data_type(kNumberTypeFloat32); + std::vector outputs_tensor; + outputs_tensor.emplace_back(&output_tensor); + + lite::Context ctx; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToBatch}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc); + ASSERT_NE(kernel, nullptr); + kernel->Run(); + + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output.data(), expect_out, out_size, 0.000001); + input_tensor.SetData(nullptr); + output_tensor.SetData(nullptr); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc new file mode 100644 index 0000000000..a758336be2 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_depth.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { + +class SpaceToDepthTestFp32 : public mindspore::Common { + public: + SpaceToDepthTestFp32() {} +}; + +TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest1) { + float input[16] = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25}; + const int out_size = 16; + float expect_out[16] = {1, 2, 10, 20, 5, 6, 3, 8, 18, 10, 11, 55, 3, 4, 15, 25}; + + float output[16]; + int in_shape[4] = {1, 4, 4, 1}; + int out_shape[4] = {1, 2, 2, 4}; + int h_start = 0; + int h_end = 2; + SpaceToDepthForNHWC((const float *)input, output, in_shape, out_shape, 4, 2, h_start, h_end); + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, out_size, 0.000001); +} + +TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest2) { + std::vector input = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25}; + std::vector in_shape = {1, 4, 4, 1}; + lite::tensor::Tensor input_tensor; + input_tensor.SetData(input.data()); + input_tensor.set_shape(in_shape); + input_tensor.SetFormat(schema::Format_NHWC); + input_tensor.set_data_type(kNumberTypeFloat32); + std::vector inputs_tensor; + inputs_tensor.push_back(&input_tensor); + + const int out_size = 16; + float expect_out[16] = {1, 2, 10, 20, 5, 6, 3, 8, 18, 10, 11, 55, 3, 4, 15, 25}; + std::vector output(16); + std::vector out_shape = {1, 2, 2, 4}; + lite::tensor::Tensor output_tensor; + output_tensor.SetData(output.data()); + output_tensor.set_shape(out_shape); + output_tensor.SetFormat(schema::Format_NHWC); + output_tensor.set_data_type(kNumberTypeFloat32); + std::vector outputs_tensor; + outputs_tensor.push_back(&output_tensor); + + SpaceToDepthParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth; + op_param.block_size_ = 2; + + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToDepth}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc); + ASSERT_NE(kernel, nullptr); + kernel->Run(); + + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output.data(), expect_out, out_size, 0.000001); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/strassen_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/strassen_fp32_tests.cc new file mode 100644 index 0000000000..442f490c0a --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/strassen_fp32_tests.cc @@ -0,0 +1,369 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/strassen_matmul.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h" + +namespace mindspore { +class TestStrassenFp32 : public mindspore::Common { + public: + TestStrassenFp32() {} +}; + +TEST_F(TestStrassenFp32, MatrixAdd1) { + float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137, + 0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607, + 0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928, + 0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327, + 0.37747085, 0, 0, 0, 0.590335, 0, 0, 0, + 0.7578682, 0, 0, 0, 0.81001425, 0, 0, 0, + 0.9487712, 0, 0, 0, 0.11742989, 0, 0, 0, + 0.60004807, 0, 0, 0, 0.05973052, 0, 0, 0}; + float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781, + 0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596, + 0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268, + 0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878, + 0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0, + 0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0, + 0.2873559, 0, 0, 0, 0.612321, 0, 0, 0, + 0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0}; + float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918, + 1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347, + 1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554, + 0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211, + 1.2371662, 0, 0, 0, 1.4244826, 0, 0, 0, + 1.4808793, 0, 0, 0, 1.2173516, 0, 0, 0, + 1.2361271, 0, 0, 0, 0.72975093, 0, 0, 0, + 1.1009188, 0, 0, 0, 0.31835714, 0, 0, 0}; + float out[64] = {0}; + MatrixAdd(a, b, out, 32, 32, 32, 8, 2); + EXPECT_EQ(0, lite::CompareOutputData(out, add, 64)); +} + +TEST_F(TestStrassenFp32, MatrixAdd2) { + float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137, + 0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607, + 0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928, + 0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0.37747085, 0, 0, 0, + 0.590335, 0, 0, 0, 0.7578682, 0, 0, 0, + 0.81001425, 0, 0, 0, 0.9487712, 0, 0, 0, + 0.11742989, 0, 0, 0, 0.60004807, 0, 0, 0, + 0.05973052, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781, + 0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596, + 0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268, + 0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0, + 0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0, + 0.2873559, 0, 0, 0, 0.612321, 0, 0, 0, + 0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918, + 1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347, + 1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554, + 0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211, + 0, 0, 0, 0, 1.2371662, 0, 0, 0, + 1.4244826, 0, 0, 0, 1.4808793, 0, 0, 0, + 1.2173516, 0, 0, 0, 1.2361271, 0, 0, 0, + 0.72975093, 0, 0, 0, 1.1009188, 0, 0, 0, + 0.31835714, 0, 0, 0, 0, 0, 0, 0}; + float out[72] = {0}; + MatrixAdd(a, b, out, 44, 56, 36, 8, 2); + EXPECT_EQ(0, lite::CompareOutputData(out, add, 72)); +} + +TEST_F(TestStrassenFp32, MatrixSub1) { + float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548, + 0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536, + 0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007, + 0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054, + 0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0, + 0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0, + 0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0, + 0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0}; + float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214, + 0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907, + 0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234, + 0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138, + 0.85054415, 0, 0, 0, 0.18541782, 0, 0, 0, + 0.72714496, 0, 0, 0, 0.43221787, 0, 0, 0, + 0.7200413, 0, 0, 0, 0.15780604, 0, 0, 0, + 0.30473796, 0, 0, 0, 0.37719592, 0, 0, 0}; + float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459, + 0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548, + 0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226, + 0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675, + -0.6046069, 0., 0., 0., 0.31704146, 0., 0., 0., + -0.3044341, 0., 0., 0., 0.05446747, 0., 0., 0., + -0.2826118, 0., 0., 0., 0.07041438, 0., 0., 0., + 0.57706296, 0., 0., 0., 0.3733264, 0., 0., 0.}; + float out[64] = {0}; + MatrixSub(a, b, out, 32, 32, 32, 8, 2); + EXPECT_EQ(0, lite::CompareOutputData(out, s, 64)); +} + +TEST_F(TestStrassenFp32, MatrixSub2) { + float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548, + 0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536, + 0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007, + 0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054, + 0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0, + 0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0, + 0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0, + 0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0}; + float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214, + 0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907, + 0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234, + 0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0.85054415, 0, 0, 0, + 0.18541782, 0, 0, 0, 0.72714496, 0, 0, 0, + 0.43221787, 0, 0, 0, 0.7200413, 0, 0, 0, + 0.15780604, 0, 0, 0, 0.30473796, 0, 0, 0, + 0.37719592, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459, + 0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548, + 0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226, + 0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675, + 0, 0, 0, 0, -0.6046069, 0., 0., 0., + 0.31704146, 0., 0., 0., -0.3044341, 0., 0., 0., + 0.05446747, 0., 0., 0., -0.2826118, 0., 0., 0., + 0.07041438, 0., 0., 0., 0.57706296, 0., 0., 0., + 0.3733264, 0., 0., 0, 0, 0, 0, 0.}; + float out[72] = {0}; + MatrixSub(a, b, out, 32, 44, 36, 8, 2); + EXPECT_EQ(0, lite::CompareOutputData(out, s, 72)); +} + +TEST_F(TestStrassenFp32, MatrixPack1) { + float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, + -0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562, + 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873, + 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0, + 0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, + -8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0}; + float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, + -0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, + 15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127, + 9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, -1.770, 41.903, 0.0, 0.0, + 8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0}; + float out[56] = {0}; + + MatrixPack(in, out, 7, 2, 36); + EXPECT_EQ(0, lite::CompareOutputData(out, correct, 56)); +} + +TEST_F(TestStrassenFp32, MatrixPack2) { + float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, + -0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562, + 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873, + 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0, + 0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, + -8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0}; + float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, + -0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, + 15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127, + 9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, + -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0, + 9.7341, 18.834, 0.0, 0.0, -1.514, -0.293, 18.686, 0.0873, 4.2010, -2.253, 0.0, 0.0}; + float out[72] = {0}; + MatrixPack(in, out, 9, 2, 36); + EXPECT_EQ(0, lite::CompareOutputData(out, correct, 72)); +} + +TEST_F(TestStrassenFp32, CommonMatmul1) { + float a_ptr[] = {7.756654, 19.250782, 17.923292, 0, 13.584222, 3.3293908, 9.734102, 0, + 18.83455, -1.51425, -0.29382, 0, 18.686155, 0.0873076, 4.2010098, 0, + -2.2539594, 4.1795673, 13.14235, 0, -3.59393, 16.50578, 19.899279, 0, + 8.556229, 19.969376, -6.2355065, 0, -2.380469, -9.027744, 9.5542, 0}; + float b_ptr[] = {0.2674241, 0.089372, -0.081915, 2.0580146, -0.295045, 1.377944, 0.703658, 1.055378, + 1.204049, -0.256505, -0.309640, 0.560465, 0, 0, 0, 0, + 0.646906, 0, 0, 0, -0.168206, 0, 0, 0, + -0.95630, 0, 0, 0, 0, 0, 0, 0}; + float correct[] = {17.97499, 22.622334, 7.360805, 46.325558, 14.37076, 3.304931, -1.784072, 36.925926, + 5.129812, -0.3278886, -2.517368, 36.99899, 10.029593, 0.7127603, -2.77004, 40.90305, + 13.988123, 2.186689, -0.943787, 7.138184, 18.128653, 17.31859, 5.7472067, 21.176342, + -11.11159, 29.880829, 15.281498, 35.1893, 13.530734, -15.10318, -9.11581, -9.071925, + -15.36046, 0, 0, 0, -1.081104, 0, 0, 0, + 12.719885, 0, 0, 0, 8.056052, 0, 0, 0, + -14.72927, 0, 0, 0, -24.1311, 0, 0, 0, + 8.139168, 0, 0, 0, -9.158176, 0, 0, 0}; + StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter(); + matmul_param->row_ = 8; + matmul_param->deep_ = 1; + matmul_param->col_ = 2; + matmul_param->a_stride_ = 32; + matmul_param->b_stride_ = 16; + matmul_param->c_stride_ = 32; + + float c_ptr[64] = {0}; + float tmp_ptr[32]; + CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_ptr); + + EXPECT_EQ(0, lite::CompareOutputData(c_ptr, correct, 64)); + delete matmul_param; +} + +TEST_F(TestStrassenFp32, CommonMatmul2) { + StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter(); + float a[] = {4.864725, 6.830073, 0.76780415, 8.922394, 5.096872, 2.4946148, 4.2148714, 1.7762588, 0.89195687, + 9.703938, 2.0654619, 9.048538, 2.358036, 5.643526, 2.5152204, 3.512572, 3.7913973, 3.7136157, + 8.820186, 1.5324963, 3.135459, 7.5792265, 7.1820426, 0.267987, 8.737802, 4.064117, 2.7232447, + 0.27355433, 0, 0, 0, 0, 0, 0, 0, 0, + 6.320409, 9.479354, 0, 0, 1.6220464, 0.57753897, 0, 0, 9.786372, + 6.0404425, 0, 0, 2.1067812, 4.8034563, 0, 0, 2.1140356, 8.204062, + 0, 0, 3.29985, 1.2034118, 0, 0, 7.6059656, 4.162436, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + float b[] = { + 4.4558744, 0.6383263, 0.05037839, 9.730914, 8.1542015, 4.3625517, 8.654026, 3.805875, 9.845131, 4.08051, + 9.667656, 7.73955, 9.283867, 8.465257, 2.292051, 9.853942, 0.13320169, 3.8789113, 9.460265, 4.2616735, + 0.23831692, 4.420147, 0.5355651, 7.829217, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1.9866786, 0, 0, 0, 6.0188327, 0, + 0, 0, 6.6249146, 0, 0, 0, 3.5639563, 0, 0, 0, + 0.14810833, 0, 0, 0, 7.4168983, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + float c[] = {170.86482, 177.98166, 152.0957, 268.3473, 101.39282, 55.216248, 82.31873, 120.65008, 190.18558, + 192.58974, 220.54767, 239.75931, 115.32386, 95.52758, 103.82857, 145.08948, 150.4757, 112.04814, + 145.50496, 207.63342, 149.6962, 84.76027, 167.65851, 141.06763, 103.42963, 84.63687, 136.74927, + 189.26935, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 158.90288, 0, 0, 0, 63.917973, + 0, 0, 0, 152.3613, 0, 0, 0, 103.77265, 0, + 0, 0, 154.94044, 0, 0, 0, 109.79707, 0, 0, + 0, 92.83551, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + + matmul_param->row_ = 7; + matmul_param->deep_ = 2; + matmul_param->col_ = 2; + matmul_param->a_stride_ = 36; + matmul_param->b_stride_ = 64; + matmul_param->c_stride_ = 40; + float out[80] = {0}; + float tmp_ptr[1000]; + CommonMatMul(a, b, out, matmul_param, tmp_ptr); + EXPECT_EQ(0, lite::CompareOutputData(out, c, 80)); + delete (matmul_param); +} + +TEST_F(TestStrassenFp32, RecMatmul1) { + StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter(); + matmul_param->row_ = 4; + matmul_param->deep_ = 2; + matmul_param->col_ = 2; + matmul_param->a_stride_ = 16; + matmul_param->b_stride_ = 32; + matmul_param->c_stride_ = 16; + + float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592, + 8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887, + 4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0, + 6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0}; + float b[] = {1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393, + 0.26377398, 3.9010656, 4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786, + 0.97356945, 1.0714527, 6.447588, 6.161091, 3.332229, 2.8775468, 6.558747, 2.6986659, + 0, 0, 0, 0, 0, 0, 0, 0, + 1.9830805, 0, 0, 0, 8.44718, 0, 0, 0, + 9.360418, 0, 0, 0, 6.220693, 0, 0, 0, + 1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806, + 65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036, + 136.34613, 0, 0, 0, 230.64507, 0, 0, 0, + 204.15103, 0, 0, 0, 104.86488, 0, 0, 0}; + float out[32] = {0}; + + float tmp_ptr[1000]; + RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr); + EXPECT_EQ(0, lite::CompareOutputData(out, c, 32)); + delete (matmul_param); +} + +TEST_F(TestStrassenFp32, RecMatmul2) { + StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter(); + matmul_param->row_ = 4; + matmul_param->deep_ = 2; + matmul_param->col_ = 2; + matmul_param->a_stride_ = 32; + matmul_param->b_stride_ = 64; + matmul_param->c_stride_ = 32; + + float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592, + 8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887, + 1, 2, 3, 4, 1, 2, 3, 4, + 3, 2, 3, 4, 4, 2, 3, 4, + 4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0, + 6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0, + 1, 2, 3, 4, 1, 2, 3, 4, + 3, 2, 3, 4, 4, 2, 3, 4}; + float b[] = { + 1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393, 0.26377398, 3.9010656, + 4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786, 0.97356945, 1.0714527, 6.447588, 6.161091, + 3.332229, 2.8775468, 6.558747, 2.6986659, 0, 0, 0, 0, 0, 0, + 0, 0, 11, 2, 3, 4, 22, 2, 3, 4, + 33, 3, 3, 4, 44, 2, 3, 4, 11, 2, + 3, 4, 22, 2, 3, 4, 33, 3, 3, 4, + 44, 2, 3, 4, 1.9830805, 0, 0, 0, 8.44718, 0, + 0, 0, 9.360418, 0, 0, 0, 6.220693, 0, 0, 0, + 1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 11, 2, 3, 4, + 22, 2, 3, 4, 33, 3, 3, 4, 44, 2, + 3, 4, 11, 2, 3, 4, 22, 2, 3, 4, + 33, 3, 3, 4, 44, 2, 3, 4}; + float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806, + 65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 136.34613, 0, 0, 0, 230.64507, 0, 0, 0, + 204.15103, 0, 0, 0, 104.86488, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + float out[64] = {0}; + + float tmp_ptr[1000]; + RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr); + EXPECT_EQ(0, lite::CompareOutputData(out, c, 64)); + delete (matmul_param); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc new file mode 100644 index 0000000000..b45037d36c --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestTopKFp32 : public mindspore::Common { + public: + TestTopKFp32() {} +}; + +TEST_F(TestTopKFp32, TopK) { + lite::tensor::Tensor in_tensor(kNumberTypeFloat32, {2, 2, 3}); + lite::tensor::Tensor out_tensor0(kNumberTypeFloat32, {2, 2, 2}); + lite::tensor::Tensor out_tensor1(kNumberTypeInt32, {2, 2, 2}); + float input_data[] = {1, 2, 3, 6, 5, 4, 9, 8, 7, 10, 12, 11}; + float output_data0[8] = {0}; + int32_t output_data1[8] = {0}; + in_tensor.SetData(input_data); + out_tensor0.SetData(output_data0); + out_tensor1.SetData(output_data1); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor0, &out_tensor1}; + + TopkParameter parameter = {{}, 3, 4, 2, true}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_TopK}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + float expect0[] = {3, 2, 6, 5, 9, 8, 12, 11}; + int32_t expect1[] = {2, 1, 0, 1, 0, 1, 1, 2}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(output_data0[i], expect0[i]); + EXPECT_EQ(output_data1[i], expect1[i]); + } + + in_tensor.SetData(nullptr); + out_tensor0.SetData(nullptr); + out_tensor1.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc new file mode 100644 index 0000000000..a8d738ddaa --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" + +namespace mindspore { +class TestQuantizedAdd : public mindspore::Common { + public: + TestQuantizedAdd() {} +}; + +TEST_F(TestQuantizedAdd, Add) { + lite::tensor::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5}); + lite::tensor::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 2, 5}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5}); + + int8_t input_data0[] = {-102, 25, -51, 89, -102, 25, -51, 89, -102, 25}; // -0.8 0.2 -0.4 0.7 + int8_t input_data1[] = {38, 51, 64, -102, 38, 51, 64, -102, 38, 51}; // 0.3 0.4 0.5 -0.8 + int8_t output_data[10] = {0}; + in_tensor0.SetData(input_data0); + in_tensor1.SetData(input_data1); + out_tensor.SetData(output_data); + + const lite::tensor::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255 + const lite::tensor::QuantArg quant_in1 = {0.00784314f, 0}; + const lite::tensor::QuantArg quant_out = {0.00784314f, 0}; + in_tensor0.AddQuantParam(quant_in0); + in_tensor1.AddQuantParam(quant_in1); + out_tensor.AddQuantParam(quant_out); + + std::vector inputs = {&in_tensor0, &in_tensor1}; + std::vector outputs = {&out_tensor}; + + OpParameter parameter = {}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Add}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect0[10] = {-64, 76, 13, -13, -64, 76, 13, -13, -64, 76}; // -0.5 0.6 0.1 -0.1 + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(output_data[i], expect0[i]); + } + + in_tensor0.SetData(nullptr); + in_tensor1.SetData(nullptr); + out_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc new file mode 100644 index 0000000000..3ab4e5f26b --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc @@ -0,0 +1,976 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/ir/tensor.h" + +namespace mindspore { + +class TestArithmeticSelfInt8 : public mindspore::Common { + public: + TestArithmeticSelfInt8() {} +}; + +TEST_F(TestArithmeticSelfInt8, floor_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Floor; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Floor}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, floor_quant1_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.8; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.5; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Floor; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Floor}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {0, 1, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, round_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Round; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Floor}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, round_quant1_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.8; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.5; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Round; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Floor}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, ceil_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Ceil; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Floor}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, ceil_quant1_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.8; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.5; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Ceil; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Floor}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 7}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, abs_quant0_thread0) { + std::vector input1 = {-1, -2, -3, -4, -5, -6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Abs; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Abs}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, abs_quant1_thread2) { + std::vector input1 = {-1, -2, -3, -4, -5, -6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.8; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.5; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Abs; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Abs}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, sin_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4}; + std::vector shape1 = {2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 4; + int8_t output[4]; + std::vector output_shape = {2, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Sin; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Sin}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 1, 0, -1}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, cos_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4}; + std::vector shape1 = {2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 4; + int8_t output[4]; + std::vector output_shape = {2, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Cos; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Cos}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 0, -1, -1}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, log_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Log; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Log}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, sqrt_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Sqrt; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Sqrt}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, rsqrt_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Rsqrt; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Rsqrt}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, square_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Square; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Square}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 127}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, square_quant1_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.8; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.5; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Square; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Square}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {1, 2, 4, 7, 11, 16, 21, 28, 35, 43, 52, 62}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestArithmeticSelfInt8, logical_not_quant0_thread2) { + std::vector input1 = {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 12; + int8_t output[12]; + std::vector output_shape = {2, 3, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(1); + outputs_tensor[0] = output0_tensor; + + ArithmeticSelfParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_LogicalNot; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_LogicalNot}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc new file mode 100644 index 0000000000..d94ae6eefe --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc @@ -0,0 +1,672 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/crop_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/ir/tensor.h" + +namespace mindspore { + +class TestCropInt8 : public mindspore::Common { + public: + TestCropInt8() {} +}; + +TEST_F(TestCropInt8, crop_1d_axis0_offset0_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector shape1 = {8}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 7; + int8_t output[7]; + std::vector output_shape = {7}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.axis_ = 0; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {2, 3, 4, 5, 6, 7, 8}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_2d_axis1_offset0_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector shape1 = {2, 8}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 14; + int8_t output[14]; + std::vector output_shape = {2, 7}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.axis_ = 1; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_3d_axis1_offset0_quant0_thread0) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector shape1 = {2, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 2; + int8_t output[2]; + std::vector output_shape = {2, 1, 1}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + op_param.axis_ = 1; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {4, 8}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_3d_axis1_offset0_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}; + std::vector shape1 = {2, 8, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 14; + int8_t output[14]; + std::vector output_shape = {2, 7, 1}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.axis_ = 1; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {4, 6, 8, 10, 12, 14, 16, 20, 22, 24, 26, 28, 30, 32}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_4d_axis0_offset0_quant0_thread0) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector shape1 = {2, 2, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 1; + int8_t output[1]; + std::vector output_shape = {1, 1, 1, 1}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + op_param.axis_ = 0; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {16}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_4d_axis1_offset0_quant0_thread0) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector shape1 = {2, 2, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 2; + int8_t output[2]; + std::vector output_shape = {2, 1, 1, 1}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + op_param.axis_ = 1; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {8, 16}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_4d_axis1_offset1_quant0_thread0) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector shape1 = {2, 2, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 4; + int8_t output[4]; + std::vector output_shape = {1, 1, 2, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + op_param.axis_ = 0; + op_param.offset_[0] = 1; + op_param.offset_[1] = 1; + op_param.offset_[2] = 0; + op_param.offset_[3] = 0; + op_param.offset_size_ = 4; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {13, 14, 15, 16}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_4d_axis1_offset1_quant1_thread0) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector shape1 = {2, 2, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 4; + int8_t output[4]; + std::vector output_shape = {1, 1, 2, 2}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 2.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 1; + op_param.axis_ = 0; + op_param.offset_[0] = 1; + op_param.offset_[1] = 1; + op_param.offset_[2] = 0; + op_param.offset_[3] = 0; + op_param.offset_size_ = 4; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {7, 7, 8, 8}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_4d_axis0_offset0_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}; + std::vector shape1 = {2, 8, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 7; + int8_t output[7]; + std::vector output_shape = {1, 7, 1, 1}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + op_param.axis_ = 0; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {40, 44, 48, 52, 56, 60, 64}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} + +TEST_F(TestCropInt8, crop_4d_axis0_offset0_quant0_thread3) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}; + std::vector shape1 = {2, 8, 2, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output_size = 7; + int8_t output[7]; + std::vector output_shape = {1, 7, 1, 1}; + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + TypeId tid_int8 = kNumberTypeInt8; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + std::vector outputs_tensor(1); + lite::tensor::Tensor *output0_tensor = new lite::tensor::Tensor; + output0_tensor->SetData(output); + output0_tensor->set_shape(output_shape); + output0_tensor->AddQuantParam(output_quant_arg); + output0_tensor->set_data_type(tid_int8); + outputs_tensor[0] = output0_tensor; + + CropParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Crop; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 3; + op_param.axis_ = 0; + op_param.offset_[0] = 1; + op_param.offset_size_ = 1; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Crop}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor->shape(); + ASSERT_EQ(output_tensor_shape, output_shape); + kernel->Run(); + + std::vector except_result = {40, 44, 48, 52, 56, 60, 64}; + PrintData("output data", output, output_size); + PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); + CompareOutputData(output, except_result.data(), output_size, 0.000001); + + input_tensor1->SetData(nullptr); + output0_tensor->SetData(nullptr); + delete input_tensor1; + delete output0_tensor; + delete ctx; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc new file mode 100644 index 0000000000..632abad686 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc @@ -0,0 +1,266 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/pack.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h" + +using mindspore::lite::DeviceType; + +namespace mindspore { +using mindspore::lite::tensor::QuantArg; +using mindspore::lite::tensor::Tensor; +using mindspore::schema::Format_NHWC; +using mindspore::schema::NodeType_Parameter; +class TestDeconvInt8 : public mindspore::Common { + public: + TestDeconvInt8() {} +}; + +void FloatToInt8(float *fptr, int8_t *iptr, size_t size, int32_t zp, double scale) { + for (int i = 0; i < size; i++) { + int32_t value = round(fptr[i] / scale + zp); + value = MSMIN(value, INT8_MAX); + value = MSMAX(value, INT8_MIN); + iptr[i] = (int8_t)value; + } +} + +TEST_F(TestDeconvInt8, PackWeight1) { + int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100, + 72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90, + 4, -106, 105, 28, -61, -79, 55, -54, 47, -38, 114, 125, -65, 100, 6, -72, -33, 60, 109, -68}; + int8_t co[] = {-8, 11, 99, -80, 8, -12, 0, 0, 112, 124, -109, 85, -24, 28, 0, 0, -110, 37, -72, 65, + -124, 91, 0, 0, -14, -81, 67, 90, 4, -106, 0, 0, 47, -38, 114, 125, -65, 100, 0, 0, + 37, -45, 31, -69, -66, 26, 0, 0, -46, 100, 72, -36, -82, 64, 0, 0, -43, 99, 3, 100, + 19, 51, 0, 0, 105, 28, -61, -79, 55, -54, 0, 0, 6, -72, -33, 60, 109, -68, 0, 0}; + int8_t dst[80] = {0}; + /*5*1*2*6 nhwc*/ + PackNHWCToC8HWN8Int8(in, dst, 5, 2, 6); + CompareOutputData(dst, co, 80, 1); +} + +TEST_F(TestDeconvInt8, PackWeight2) { + int8_t in[] = { + 40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103, + -22, 32, 26, 112, -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 121, -62, 119, -93, + 21, -92, -1, -82, -71, -54, 63, -93, 92, -93, 99, 122, -104, -16, -8, -32, 90, -126, 51, 91, + 4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, -97, 99, 88, -15, 54, 26, 77, -25, + 113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, 80, -115, -98, -8, -17, 31, 88, 65, + -1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, 92, -34, -17, -52, + 104, -52, -91, 76, 79, 105, 102, -65, 43, 32, 13, 15, -38, 95, -18, -82, -7, 118, -79, -85, + 120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54, 63, 111, -16, 92, 82, -23, 111, 53, + 1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31, 94, 101, -10, 18, 0, -49, 108, 28, + -36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66, 99, -121, -107, 31, -38, 56, -30, 109, + -7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71, -54, 20, -45, 109, -42, 78, -79, 98, + -10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92, 24, 55, 4, -110, -37, 112, -18, 10, + -42, 16, -9, 31, 39, -70, 108, -3, -90, -60, -121, 11, 50, -88, -104, -29, -89, 94, 64, -91, + -101, -7, 23, -57, 93, 16, 17, 35, -48, -25, 13, -121, 73, -68, -54, -122, -20, 12, 64, 20, + -11, -6, -71, -52, -97, 109, 116, -107, 117, -124, 56, 80, -108, 30, 123, 56, -80, 39, -18, -97, + -103, 122, 114, -10, -31, 97, -92, 105, -61, -25, 10, -119, -106, 41, 77, -117, 55, -83, -29, 14, + 27, -106, -86, 41, 43, 23, 11, -76, -34, 121, 94, 18, 69, 73, 100, 54, 43, 32, 13, 15, + -38, 95, -18, -82, -7, 118, -79, -85, 120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54, + 63, 111, -16, 92, 82, -23, 111, 53, 1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31, + 94, 101, -10, 18, 0, -49, 108, 28, -36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66, + 99, -121, -107, 31, -38, 56, -30, 109, -7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71, + -54, 20, -45, 109, -42, 78, -79, 98, -10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92}; + int8_t co[] = { + 40, 24, 94, 122, 67, 34, -89, 31, -22, 32, 26, 112, -92, -23, 43, 9, 21, -92, -1, -82, + -71, -54, 63, -93, 4, 70, -7, 116, 99, 81, -79, 124, 113, 119, 119, -75, -17, 7, 7, 1, + -1, -15, -98, 77, 56, 119, -20, -32, 104, -52, -91, 76, 79, 105, 102, -65, 120, -15, 2, 32, + -94, 111, 115, 102, 1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84, + -7, 28, -22, -17, -3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -42, 16, -9, 31, + 39, -70, 108, -3, -101, -7, 23, -57, 93, 16, 17, 35, -11, -6, -71, -52, -97, 109, 116, -107, + -103, 122, 114, -10, -31, 97, -92, 105, 27, -106, -86, 41, 43, 23, 11, -76, -38, 95, -18, -82, + -7, 118, -79, -85, 63, 111, -16, 92, 82, -23, 111, 53, 94, 101, -10, 18, 0, -49, 108, 28, + 99, -121, -107, 31, -38, 56, -30, 109, -54, 20, -45, 109, -42, 78, -79, 98, -43, 121, 48, -54, + 44, -91, 35, 89, 81, 118, -73, -54, 65, -99, 51, -90, 92, -93, 99, 122, -104, -16, -8, -32, + -14, 28, 97, 9, -97, 99, 88, -15, 69, 66, 40, -13, 80, -115, -98, -8, -54, -58, -16, 52, + 121, 126, -33, 43, 43, 32, 13, 15, -38, 95, -18, -82, -18, 121, -106, 54, 63, 111, -16, 92, + 80, -51, 116, 31, 94, 101, -10, 18, 74, -114, -107, 66, 99, -121, -107, 31, 108, -84, -23, -71, + -54, 20, -45, 109, 46, 121, 66, 92, 24, 55, 4, -110, -90, -60, -121, 11, 50, -88, -104, -29, + -48, -25, 13, -121, 73, -68, -54, -122, 117, -124, 56, 80, -108, 30, 123, 56, -61, -25, 10, -119, + -106, 41, 77, -117, -34, 121, 94, 18, 69, 73, 100, 54, 120, -15, 2, 32, -94, 111, 115, 102, + 1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84, -7, 28, -22, -17, + -3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -37, 114, -8, 103, 0, 0, 0, 0, + 121, -62, 119, -93, 0, 0, 0, 0, 90, -126, 51, 91, 0, 0, 0, 0, 54, 26, 77, -25, + 0, 0, 0, 0, -17, 31, 88, 65, 0, 0, 0, 0, 92, -34, -17, -52, 0, 0, 0, 0, + -7, 118, -79, -85, 0, 0, 0, 0, 82, -23, 111, 53, 0, 0, 0, 0, 0, -49, 108, 28, + 0, 0, 0, 0, -38, 56, -30, 109, 0, 0, 0, 0, -42, 78, -79, 98, 0, 0, 0, 0, + -37, 112, -18, 10, 0, 0, 0, 0, -89, 94, 64, -91, 0, 0, 0, 0, -20, 12, 64, 20, + 0, 0, 0, 0, -80, 39, -18, -97, 0, 0, 0, 0, 55, -83, -29, 14, 0, 0, 0, 0, + 43, 32, 13, 15, 0, 0, 0, 0, -18, 121, -106, 54, 0, 0, 0, 0, 80, -51, 116, 31, + 0, 0, 0, 0, 74, -114, -107, 66, 0, 0, 0, 0, 108, -84, -23, -71, 0, 0, 0, 0, + 46, 121, 66, 92, 0, 0, 0, 0}; + int8_t dst[528] = {0}; + PackNHWCToC8HWN8Int8(in, dst, 22, 1, 20); + CompareOutputData(dst, co, 528, 1); +} + +TEST_F(TestDeconvInt8, MatMulTest1) { + int8_t a_row_major_10_12[] = { + -6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105, + 68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, 57, -41, -51, 77, + 1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68, 57, 61, 55, -20, 3, 2, + 17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15, -7, -16, -47, 112, 114, -26, -98, 53, + 15, -49, 26, 19, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 124, 113, -5, 15, + -8, 107, -65, -88, 50, -47, -80, -84, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10}; + int32_t zp_a = 15; + int8_t a_col8_major[16 * 12] = {0}; + int8_t b_col_major_12_18[] = { + 92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99, + 77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103, + -3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80, + -75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33, + -77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46, + -53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42, + 43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46, + 35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7, + -8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81, + 40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51, + 40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49}; + int32_t zp_b = -20; + int8_t b_row8_major[12 * 24] = {0}; + int32_t co_row_major_10_18[] = { + 32005, 3597, 16595, -3458, 6627, -6663, 818, -3910, 10228, 15079, -19205, -10203, -3178, -10046, + 10374, -6199, 5330, 12163, 1819, 20533, 17382, 18283, 9778, 9185, -12623, -26234, -11987, 7904, + 8144, -1603, 27611, -10190, -20053, 4999, -28389, 21852, 24680, 25858, 23506, 17944, 11768, 24378, + -6102, -4675, -23460, 10434, -47579, 1986, 12018, -19418, -7248, 4938, -32613, -941, 8171, -4788, + 3325, -11310, -8351, -14786, 6909, 16401, 2017, -6456, 11242, 7393, -9119, 17312, 2646, -14402, + 7201, -9949, 23986, 17607, 27461, -1547, 2783, 7558, 19487, 11158, -2686, 6328, -8225, -11668, + 21858, -2079, -8671, -639, -1544, 1235, 1156, 6582, 2829, -10311, -2692, 5154, 1527, 10870, + 106, -8189, -24174, -1846, -15399, -3598, 14874, -5591, -619, -13667, -6053, -31103, -24499, 13008, + 9143, -17982, 28437, 2176, -2114, -11631, 10779, -1032, -24690, -3112, 2125, 432, 20270, -33859, + 8907, 10063, 1603, 3761, 4805, 4904, -15594, 10786, 4287, -13591, -18777, -1679, 2109, -2243, + 12051, -8504, -6558, 4209, 13606, -25803, 27922, 12092, 7140, 27142, -12267, 2339, -26224, 23674, + -26579, -11398, -1823, -18976, 3641, 4415, -24878, -2045, 15937, 41465, 12601, -14513, -17619, -5728, + 334, -424, 8147, -1369, 5984, 11000, 19016, 4456, -25920, 4506, 5930, 15458}; + int32_t c_row8x8_major[16 * 24] = {0}; + + int32_t out_row_major[180] = {0}; + RowMajor2Col8MajorInt8(a_row_major_10_12, a_col8_major, 10, 12); + RowMajor2Col8MajorInt8(b_col_major_12_18, b_row8_major, 18, 12); + MatMulInt8(a_col8_major, b_row8_major, c_row8x8_major, 16, 24, 12, zp_a, zp_b); + Row8x8Major2RowMajor(reinterpret_cast(c_row8x8_major), reinterpret_cast(out_row_major), 10, 18); + CompareOutputData(out_row_major, co_row_major_10_18, 180, 1); +} + +TEST_F(TestDeconvInt8, PostAddTest1) { + int32_t in[] = { + -4956, -3923, 868, -8880, -4089, -5179, -4526, -4527, -10464, 99, -5826, -2995, -4519, -4519, -10509, -2505, + -11272, 434, -4522, -4523, -5287, -8936, -878, 373, -4528, -4529, -1960, -6589, 1688, 2287, -8059, 926, + -2506, -6972, -2834, -8281, -8118, -3110, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519, + -4520, -4521, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519, + 1578, 2231, -4522, -4523, -4524, -4525, -4526, -4527, -8449, -990, -4519, -4519, -4519, -4519, -4519, -4519, + -4303, -10293, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519, + -7025, 924, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519, + -4520, -4521, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519}; + int8_t co[] = {-8, 11, 99, -80, 8, -12, 0, 0, 112, 124, -109, 85, -24, 28, 0, 0, -110, + 37, -72, 65, -124, 91, 0, 0, -14, -81, 67, 90, 4, -106, 0, 0, 47, -38, + 114, 125, -65, 100, 0, 0, 37, -45, 31, -69, -66, 26, 0, 0, -46, 100}; + int32_t bias[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + int8_t out[50] = {0}; + double multiplier = 0.0183649725490196; + int32_t quant_multiplier; + int32_t left_shift; + int32_t right_shift; + QuantizeRoundParameter(multiplier, &quant_multiplier, &left_shift, &right_shift); + int32_t zp = 83; + PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, -128, 127); + CompareOutputData(out, co, 50, 1); + + int8_t co_relu[] = {0, 11, 99, 0, 8, 0, 0, 0, 112, 124, 0, 85, 0, 28, 0, 0, 0, 37, 0, 65, 0, 91, 0, 0, 0, + 0, 67, 90, 4, 0, 0, 0, 47, 0, 114, 125, 0, 100, 0, 0, 37, 0, 31, 0, 0, 26, 0, 0, 0, 100}; + PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 127); + CompareOutputData(out, co_relu, 50, 1); + + int8_t co_relu6[] = {0, 6, 6, 0, 6, 0, 0, 0, 6, 6, 0, 6, 0, 6, 0, 0, 0, 6, 0, 6, 0, 6, 0, 0, 0, + 0, 6, 6, 4, 0, 0, 0, 6, 0, 6, 6, 0, 6, 0, 0, 6, 0, 6, 0, 0, 6, 0, 0, 0, 6}; + PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 6); + CompareOutputData(out, co_relu6, 50, 1); +} + +int DeConvInt8TestInit1(std::vector *inputs_, std::vector *outputs_, + ConvParameter *conv_param, int8_t **correct) { + /* float data from deconv fp32 testcase : DeConvTestInit2 */ + /* vq = (vi - zp) * s vi = vq / s + zp */ + Tensor *in_t = new Tensor(kNumberTypeInt8, {1, 4, 2, 3}, Format_NHWC, NodeType_Parameter); + in_t->MallocData(); + int8_t in[] = {6, 43, 38, 24, -8, 12, 41, -24, -20, 41, -19, -6, -26, -6, 23, -31, 34, 45, 8, 45, -39, -27, -48, 12}; + memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum()); + QuantArg *in_quant_arg = new QuantArg(); + in_quant_arg->zeroPoint = -19, in_quant_arg->scale = 0.31228156; + in_t->AddQuantParam(*in_quant_arg); + inputs_->push_back(in_t); + + Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 3, 3, 2}, Format_NHWC, NodeType_Parameter); + weight_t->MallocData(); + int8_t weight[] = {66, 89, 98, 74, 95, 86, 125, 95, 105, 83, 116, 94, 90, 80, 86, 59, 72, 92, + 64, 76, 92, 80, 90, 87, 106, 55, 105, 60, 75, 53, 81, 81, 98, 81, 86, 59, + 74, 82, 97, 105, 71, 67, 79, 87, 72, 79, 80, 76, 96, 80, 83, 71, 61, 79}; + memcpy(weight_t->Data(), weight, sizeof(int8_t) * weight_t->ElementsNum()); + QuantArg *w_quant_arg = new QuantArg(); + w_quant_arg->zeroPoint = 83, w_quant_arg->scale = 0.023649725490196; + weight_t->AddQuantParam(*w_quant_arg); + inputs_->push_back(weight_t); + + Tensor *out_t = new Tensor(kNumberTypeInt8, {1, 7, 3, 2}, Format_NHWC, NodeType_Parameter); + out_t->MallocData(); + QuantArg *out_quant_arg = new QuantArg(); + out_quant_arg->zeroPoint = 31, out_quant_arg->scale = 0.3439215686275; + out_t->AddQuantParam(*out_quant_arg); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); + int8_t co_nchw[] = {57, 76, 49, 71, 8, 61, 57, 127, 56, 46, -11, 61, 23, 31, 34, 50, 59, 49, 78, 17, 6, + -3, -5, 23, -11, 6, -5, 33, 64, 30, 21, 18, 25, 21, -15, 0, 4, 31, 36, 2, 17, 43}; + PackNCHWToNHWCInt8(co_nchw, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel()); + + conv_param->kernel_h_ = conv_param->kernel_w_ = 3; + conv_param->pad_h_ = conv_param->pad_w_ = 1; + conv_param->stride_h_ = conv_param->stride_w_ = 2; + conv_param->dilation_h_ = conv_param->dilation_w_ = 1; + return out_t->ElementsNum(); +} + +TEST_F(TestDeconvInt8, DeConvInt8Test1) { + std::vector inputs_; + std::vector outputs_; + auto deconv_param = new ConvParameter(); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + int8_t *correct; + int total_size = DeConvInt8TestInit1(&inputs_, &outputs_, deconv_param, &correct); + mindspore::kernel::DeConvInt8CPUKernel *deconv = + new mindspore::kernel::DeConvInt8CPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + + deconv->Init(); + deconv->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 3); + + delete deconv_param; + delete deconv; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc new file mode 100644 index 0000000000..d5ca51733f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { +using lite::tensor::Tensor; +class TestFcInt8 : public mindspore::Common { + public: + TestFcInt8() {} +}; + +int FcInt8TestInit(std::vector *inputs_, std::vector *outputs_, + MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) { + float input_max = 20; + float input_min = -20; + float weight_max = 1; + float weight_min = -1; + float output_max = 20; + float output_min = -20; + + double input_scale = + (input_max - input_min) / (std::numeric_limits::max() - std::numeric_limits::min()); + int input_zp = std::numeric_limits::max() - input_max / input_scale; + double weight_scale = + (weight_max - weight_min) / (std::numeric_limits::max() - std::numeric_limits::min()); + int weight_zp = std::numeric_limits::max() - weight_max / weight_scale; + double output_scale = + (output_max - output_min) / (std::numeric_limits::max() - std::numeric_limits::min()); + int output_zp = std::numeric_limits::max() - output_max / output_scale; + *scale = output_scale; + *zeropoint = output_zp; + + Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383, + 17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792}; + Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast(in_t->Data())); + auto in_quant_arg = new mindspore::lite::tensor::QuantArg(); + in_quant_arg->zeroPoint = input_zp; + in_quant_arg->scale = input_scale; + in_t->AddQuantParam(*in_quant_arg); + inputs_->push_back(in_t); + + Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435, + 0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552, + -0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459, + 0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343}; + Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast(weight_t->Data())); + auto weight_quant_arg = new mindspore::lite::tensor::QuantArg(); + weight_quant_arg->zeroPoint = weight_zp; + weight_quant_arg->scale = weight_scale; + weight_t->AddQuantParam(*weight_quant_arg); + inputs_->push_back(weight_t); + + Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast(1)); + bias_t->MallocData(); + memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum()); + inputs_->push_back(bias_t); + + Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + auto output_quant_arg = new mindspore::lite::tensor::QuantArg(); + output_quant_arg->zeroPoint = output_zp; + output_quant_arg->scale = output_scale; + out_t->AddQuantParam(*output_quant_arg); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108}; + memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float)); + + matmal_param->b_transpose_ = true; + matmal_param->a_transpose_ = false; + matmal_param->has_bias_ = true; + matmal_param->act_type_ = ActType_No; + return out_t->ElementsNum(); +} + +TEST_F(TestFcInt8, fcint8) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + float *correct; + double output_scale; + int output_zp; + int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp); + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::FullconnectionInt8CPUKernel *fc = + new kernel::FullconnectionInt8CPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + + fc->Init(); + fc->Run(); + float fout[6] = {0}; + Dequantize(reinterpret_cast(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp, + fout); + CompareOutputData(fout, correct, 6, 0.2); + delete matmul_param; + delete fc; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc new file mode 100644 index 0000000000..ef2f1b0e77 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32/activation.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" + +namespace mindspore { +class TestHSwishInt8 : public mindspore::Common { + public: + TestHSwishInt8() {} +}; + +TEST_F(TestHSwishInt8, HSwish) { + lite::tensor::Tensor in_tensor(kNumberTypeInt8, {4, 4}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {4, 4}); + + int8_t input_data[] = {-116, -105, -93, -35, 23, 35, 46, 104}; // -3.5f, -3.0f, -2.5f, 0.f, 2.5f, 3.0f, 3.5f, 6.0f + int8_t output_data[8] = {0}; + in_tensor.SetData(input_data); + out_tensor.SetData(output_data); + + const lite::tensor::QuantArg quant_in = {0.0431373f, -35}; // -4.0 -- 7.0 + const lite::tensor::QuantArg quant_out = {0.0392157f, -52}; // -3.0 -- 7.0 + in_tensor.AddQuantParam(quant_in); + out_tensor.AddQuantParam(quant_out); + + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor}; + + ActivationParameter parameter = {0}; + parameter.op_parameter_.type_ = schema::PrimitiveType_Activation; + parameter.type_ = schema::ActivationType_HSWISH; + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Activation}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect[8] = {-52, -52, -57, -52, 7, 25, 37, 101}; // 0, 0, -0.208333, 0, 2.29167, 3, 3.5, 6 + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(output_data[i], expect[i]); + } + + in_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc new file mode 100644 index 0000000000..36d7ecd23e --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { +class TestMatmulInt8 : public mindspore::Common { + public: + TestMatmulInt8() {} +}; + +int MMInt8TestInit(std::vector *inputs_, std::vector *outputs_, + MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) { + float input_max = 20; + float input_min = -20; + float weight_max = 1; + float weight_min = -1; + float output_max = 30; + float output_min = -30; + + double input_scale = + (input_max - input_min) / (std::numeric_limits::max() - std::numeric_limits::min()); + int input_zp = std::numeric_limits::max() - input_max / input_scale; + double weight_scale = + (weight_max - weight_min) / (std::numeric_limits::max() - std::numeric_limits::min()); + int weight_zp = std::numeric_limits::max() - weight_max / weight_scale; + double output_scale = + (output_max - output_min) / (std::numeric_limits::max() - std::numeric_limits::min()); + int output_zp = std::numeric_limits::max() - output_max / output_scale; + *scale = output_scale; + *zeropoint = output_zp; + + auto in_t = + new lite::tensor::Tensor(kNumberTypeInt8, {1, 2, 8}, schema::Format_NHWC, static_cast(1)); + in_t->MallocData(); + float in[] = {6.583835634764597, 11.337275140963907, -4.125256949459629, 10.994337291530833, + 19.086065139532636, 3.620842999158455, 13.167624585590346, -18.326739299407755, + 14.877693740734841, -17.092677920571653, 19.24147072807235, -15.14805323833401, + -18.075654829688737, -0.9164404591894204, -3.836646280336332, -10.870298671273918}; + Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast(in_t->Data())); + auto in_quant_arg = new mindspore::lite::tensor::QuantArg(); + in_quant_arg->zeroPoint = input_zp; + in_quant_arg->scale = input_scale; + in_t->AddQuantParam(*in_quant_arg); + inputs_->push_back(in_t); + + auto weight_t = + new lite::tensor::Tensor(kNumberTypeInt8, {1, 3, 8}, schema::Format_NHWC, static_cast(1)); + weight_t->MallocData(); + float weight[] = {0.3651070698591563, -0.5856943921727129, -0.7472032663840145, 0.9489992871641959, + -0.8179490270358738, -0.873058811259344, 0.39876672713807215, -0.1816769383004213, + -0.13584645926733696, -0.7614673836659709, -0.2535825872616164, -0.05265760030895916, + 0.28558728305658754, 0.15404213943520118, -0.1634824450738006, -0.5068199082730189, + -0.026961256849111326, -0.1508441942453307, 0.9375335677537737, 0.3304690744194263, + -0.5091563780251127, 0.029887336278646925, -0.39540496207319276, 0.46094065001445084}; + Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast(weight_t->Data())); + auto weight_quant_arg = new mindspore::lite::tensor::QuantArg(); + weight_quant_arg->zeroPoint = weight_zp; + weight_quant_arg->scale = weight_scale; + weight_t->AddQuantParam(*weight_quant_arg); + inputs_->push_back(weight_t); + + auto out_t = + new lite::tensor::Tensor(kNumberTypeInt8, {1, 2, 3}, schema::Format_NHWC, static_cast(1)); + out_t->MallocData(); + auto output_quant_arg = new mindspore::lite::tensor::QuantArg(); + output_quant_arg->zeroPoint = output_zp; + output_quant_arg->scale = output_scale; + out_t->AddQuantParam(*output_quant_arg); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); + float nchw_co[] = {-0.912632942, 4.08398056, -25.385608673, 2.720281124, 7.745952606, 20.893184662}; + memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float)); + + matmal_param->b_transpose_ = true; + matmal_param->a_transpose_ = false; + matmal_param->has_bias_ = false; + return out_t->ElementsNum(); +} + +TEST_F(TestMatmulInt8, mmint8) { + std::vector inputs_; + std::vector outputs_; + auto matmul_param = new MatMulParameter(); + float *correct; + double output_scale; + int output_zp; + int total_size = MMInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp); + auto ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::MatmulInt8CPUKernel *mm = + new kernel::MatmulInt8CPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + + mm->Init(); + mm->Run(); + float fout[6] = {0}; + Dequantize(reinterpret_cast(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp, + fout); + CompareOutputData(fout, correct, 6, 0.3); + delete matmul_param; + delete mm; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc new file mode 100644 index 0000000000..4230ec4ae4 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc @@ -0,0 +1,201 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "include/context.h" +#include "src/ir/tensor.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/runtime/kernel/arm/nnacl/pad_parameter.h" +#include "src/runtime/kernel/arm/int8/pad_int8.h" + +namespace mindspore { +using mindspore::lite::tensor::QuantArg; +using mindspore::lite::tensor::Tensor; + +class TestPadInt8 : public mindspore::Common { + public: + TestPadInt8() {} +}; + +int PadInt8TestInit1(std::vector *inputs_, std::vector *outputs_, PadParameter *pad_param, + int8_t **correct) { + Tensor *in_t = new Tensor(kNumberTypeInt8, {3}, schema::Format_NHWC, schema::NodeType_Parameter); + in_t->MallocData(); + int8_t in[] = {1, 1, 1}; + memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum()); + QuantArg *in_quant_arg = new QuantArg(); + in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156; + in_t->AddQuantParam(*in_quant_arg); + inputs_->push_back(in_t); + + Tensor *out_t = new Tensor(kNumberTypeInt8, {7}, schema::Format_NHWC, schema::NodeType_Parameter); + out_t->MallocData(); + QuantArg *out_quant_arg = new QuantArg(); + out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156; + out_t->AddQuantParam(*out_quant_arg); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); + int8_t co[] = {10, 10, 1, 1, 1, 10, 10}; + memcpy(*correct, co, out_t->ElementsNum() * sizeof(int8_t)); + + int padding[] = {0, 0, 0, 0, 0, 0, 2, 2}; + memcpy(pad_param->paddings_, padding, MAX_PAD_SIZE * sizeof(int)); + pad_param->constant_value_ = 0; + + return out_t->ElementsNum(); +} + +TEST_F(TestPadInt8, PadInt8Test1) { + std::vector inputs_; + std::vector outputs_; + auto pad_param = new PadParameter(); + lite::Context *ctx = new lite::Context; + int8_t *correct; + int total_size = PadInt8TestInit1(&inputs_, &outputs_, pad_param, &correct); + kernel::PadInt8CPUKernel *pad = + new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx); + + pad->Init(); + pad->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0); + + delete pad_param; + delete pad; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +int PadInt8TestInit2(std::vector *inputs_, std::vector *outputs_, PadParameter *pad_param, + int8_t **correct) { + Tensor *in_t = new Tensor(kNumberTypeInt8, {6, 2}, schema::Format_NHWC, schema::NodeType_Parameter); + in_t->MallocData(); + int8_t in[] = {18, 71, 99, -6, 5, -119, 86, 13, 15, -85, -41, -77}; + memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum()); + QuantArg *in_quant_arg = new QuantArg(); + in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156; + in_t->AddQuantParam(*in_quant_arg); + inputs_->push_back(in_t); + + Tensor *out_t = new Tensor(kNumberTypeInt8, {10, 5}, schema::Format_NHWC, schema::NodeType_Parameter); + out_t->MallocData(); + QuantArg *out_quant_arg = new QuantArg(); + out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156; + out_t->AddQuantParam(*out_quant_arg); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); + int8_t co[] = {10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 18, + 71, 10, 10, 10, 99, -6, 10, 10, 10, 5, -119, 10, 10, 10, 86, 13, 10, + 10, 10, 15, -85, 10, 10, 10, -41, -77, 10, 10, 10, 10, 10, 10, 10}; + memcpy(*correct, co, out_t->ElementsNum() * sizeof(int8_t)); + + int padding[] = {0, 0, 0, 0, 3, 1, 1, 2}; + memcpy(pad_param->paddings_, padding, MAX_PAD_SIZE * sizeof(int)); + pad_param->constant_value_ = 0; + + return out_t->ElementsNum(); +} + +TEST_F(TestPadInt8, PadInt8Test2) { + std::vector inputs_; + std::vector outputs_; + auto pad_param = new PadParameter(); + lite::Context *ctx = new lite::Context; + int8_t *correct; + int total_size = PadInt8TestInit2(&inputs_, &outputs_, pad_param, &correct); + kernel::PadInt8CPUKernel *pad = + new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx); + + pad->Init(); + pad->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0); + + delete pad_param; + delete pad; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} + +int PadInt8TestInit4(std::vector *inputs_, std::vector *outputs_, PadParameter *pad_param, + int8_t **correct) { + Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 3, 2, 1}, schema::Format_NHWC, schema::NodeType_Parameter); + in_t->MallocData(); + int8_t in[] = {73, 24, 7, -31, -109, -2, 69, -64, 51, -45, 38, 53}; + memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum()); + QuantArg *in_quant_arg = new QuantArg(); + in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156; + in_t->AddQuantParam(*in_quant_arg); + inputs_->push_back(in_t); + + Tensor *out_t = new Tensor(kNumberTypeInt8, {6, 6, 4, 3}, schema::Format_NHWC, schema::NodeType_Parameter); + out_t->MallocData(); + QuantArg *out_quant_arg = new QuantArg(); + out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156; + out_t->AddQuantParam(*out_quant_arg); + outputs_->push_back(out_t); + + *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); + int8_t co[] = { + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 73, 10, 10, 24, 10, 10, 10, 10, + 10, 10, 10, 10, 7, 10, 10, -31, 10, 10, 10, 10, 10, 10, 10, 10, -109, 10, 10, -2, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 69, 10, 10, -64, 10, 10, 10, 10, 10, 10, 10, 10, 51, 10, 10, -45, 10, + 10, 10, 10, 10, 10, 10, 10, 38, 10, 10, 53, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}; + memcpy(*correct, co, out_t->ElementsNum() * sizeof(int8_t)); + + int padding[] = {3, 1, 1, 2, 2, 0, 1, 1}; + memcpy(pad_param->paddings_, padding, MAX_PAD_SIZE * sizeof(int)); + pad_param->constant_value_ = 0; + + return out_t->ElementsNum(); +} + +TEST_F(TestPadInt8, PadInt8TestInit4) { + std::vector inputs_; + std::vector outputs_; + auto pad_param = new PadParameter(); + lite::Context *ctx = new lite::Context; + int8_t *correct; + int total_size = PadInt8TestInit2(&inputs_, &outputs_, pad_param, &correct); + kernel::PadInt8CPUKernel *pad = + new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx); + + pad->Init(); + pad->Run(); + CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0); + + delete pad_param; + delete pad; + for (auto t : inputs_) delete t; + for (auto t : outputs_) delete t; + free(correct); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc new file mode 100644 index 0000000000..89b2385b3f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/int8/quant_dtype_cast.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { + +class QuantDTypeCastTestFp32 : public mindspore::Common { + public: + QuantDTypeCastTestFp32() {} +}; + +TEST_F(QuantDTypeCastTestFp32, QuantDTypeCastTest1) { + const lite::tensor::QuantArg quant_arg{0.21176, 5}; + QuantDTypeCastParameter param; + param.srcT = kNumberTypeInt8; + param.dstT = kNumberTypeFloat32; + param.op_parameter_.type_ = schema::PrimitiveType_QuantDTypeCast; + + std::vector input = {10, 14, 29, 33, 52, 99, 19, 43, 90, 52, 19, 24, 57, 127, 76, 123}; + std::vector in_shape = {1, 4, 4, 1}; + lite::tensor::Tensor input_tensor; + input_tensor.SetData(input.data()); + input_tensor.set_shape(in_shape); + input_tensor.set_data_type(kNumberTypeInt8); + input_tensor.SetFormat(schema::Format_NHWC); + + input_tensor.AddQuantParam(quant_arg); + std::vector inputs_tensor; + inputs_tensor.emplace_back(&input_tensor); + + const int out_size = 16; + float expect_out[16] = {3.1764, 4.02344, 7.19984, 8.04688, 12.07032, 22.02304, 5.08224, 10.16448, + 20.1172, 12.07032, 5.082240, 6.14104, 13.12912, 27.95232, 17.15256, 27.10528}; + std::vector output(16); + std::vector out_shape = {1, 4, 4, 1}; + lite::tensor::Tensor output_tensor; + output_tensor.SetData(output.data()); + output_tensor.set_shape(out_shape); + output_tensor.set_data_type(kNumberTypeFloat32); + // output_tensor.SetFormat(schema::Format_NHWC); + std::vector outputs_tensor; + outputs_tensor.emplace_back(&output_tensor); + + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_QuantDTypeCast}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc); + ASSERT_NE(kernel, nullptr); + kernel->Run(); + + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output.data(), expect_out, out_size, 0.000001); +} + +TEST_F(QuantDTypeCastTestFp32, QuantDTypeCastTest2) { + const lite::tensor::QuantArg quant_arg = {0.3515625, -57}; + QuantDTypeCastParameter param; + param.op_parameter_.type_ = schema::PrimitiveType_QuantDTypeCast; + param.dstT = kNumberTypeInt8; + param.srcT = kNumberTypeFloat32; + std::vector input = {1, 2, 5, 6, 10, -20, 3, 8, 18, 10, 3, 4, 11, 16, 15, 25}; + std::vector in_shape = {1, 4, 4, 1}; + lite::tensor::Tensor input_tensor; + input_tensor.SetData(input.data()); + input_tensor.set_shape(in_shape); + // input_tensor.SetFormat(schema::Format_NHWC); + input_tensor.set_data_type(kNumberTypeFloat32); + input_tensor.AddQuantParam(quant_arg); + std::vector inputs_tensor; + inputs_tensor.emplace_back(&input_tensor); + + const int out_size = 16; + int8_t expect_out[16] = {-54, -51, -43, -40, -29, -114, -48, -34, -6, -29, -48, -46, -26, -11, -14, 14}; + std::vector output(16); + std::vector out_shape = {1, 4, 4, 1}; + lite::tensor::Tensor output_tensor; + output_tensor.SetData(output.data()); + output_tensor.set_shape(out_shape); + output_tensor.SetFormat(schema::Format_NHWC); + output_tensor.set_data_type(kNumberTypeInt8); + std::vector outputs_tensor; + outputs_tensor.emplace_back(&output_tensor); + + lite::Context ctx; + ctx.thread_num_ = 3; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_QuantDTypeCast}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc); + ASSERT_NE(kernel, nullptr); + kernel->Run(); + + for (int i = 0; i < out_size; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output.data(), expect_out, out_size, 0.000001); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relu_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relu_int8_tests.cc new file mode 100644 index 0000000000..99ca0098a8 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relu_int8_tests.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/relu_int8.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" + +namespace mindspore { +class TestReluInt8 : public mindspore::Common { + public: + TestReluInt8() {} +}; + +TEST_F(TestReluInt8, Relu) { + lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 2}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {2, 2}); + + int8_t input_data[] = {-102, 25, -51, 89}; // -0.8 0.2 -0.4 0.7 + int8_t output_data[4] = {0}; + in_tensor.SetData(input_data); + out_tensor.SetData(output_data); + + const lite::tensor::QuantArg quant_in = {0.00784314f, 0}; // -1.0--1.0 -> + const lite::tensor::QuantArg quant_out = {0.00784314f, 0}; + in_tensor.AddQuantParam(quant_in); + out_tensor.AddQuantParam(quant_out); + + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor}; + + ActivationParameter parameter = {0}; + parameter.op_parameter_.type_ = schema::PrimitiveType_Activation; + parameter.type_ = schema::ActivationType_RELU; + + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Activation}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect0[4] = {0, 26, 0, 90}; // + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(output_data[i], expect0[i]); + } + + in_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc new file mode 100644 index 0000000000..1c22a52dee --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/softmax_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { + +class TestSoftmaxInt8 : public mindspore::Common { + public: + TestSoftmaxInt8() {} +}; + +TEST_F(TestSoftmaxInt8, SoftmaxInt8) { + std::vector inputs_tensor; + std::vector outputs_tensor; + + SoftmaxParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_SoftMax; + op_param.axis_ = 2; + op_param.element_size_ = 24; + op_param.input_shape_[0] = 1; + op_param.input_shape_[1] = 2; + op_param.input_shape_[2] = 3; + op_param.input_shape_[3] = 4; + + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 0.0352941; + input_quant_arg.zeroPoint = -128; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 0.00392157; + output_quant_arg.zeroPoint = -128; + + std::vector input = {-71, -43, -15, 14, -43, -15, 14, 42, 70, 99, 99, 127, + -100, -71, -43, -15, 14, 42, 70, 99, 42, 70, 99, 127}; + std::vector in_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor input0_tensor; + TypeId tid_int8 = kNumberTypeInt8; + inputs_tensor.push_back(&input0_tensor); + input0_tensor.SetData(input.data()); + input0_tensor.set_shape(in_shape); + input0_tensor.AddQuantParam(input_quant_arg); + input0_tensor.set_data_type(tid_int8); + + std::vector output(24); + std::vector output_shape = {1, 2, 3, 4}; + + lite::tensor::Tensor output0_tensor; + outputs_tensor.push_back(&output0_tensor); + output0_tensor.SetData(output.data()); + output0_tensor.AddQuantParam(output_quant_arg); + output0_tensor.set_data_type(tid_int8); + + auto ctx = std::make_shared(); + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_SoftMax}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx.get(), desc); + ASSERT_NE(kernel, nullptr); + auto output_tensor_shape = output0_tensor.shape(); + kernel->Run(); + + std::vector except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 121, 121, 111, 111, + -127, -127, -127, -127, -59, -59, -61, -59, 57, 57, 59, 57}; + + CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001); + + input0_tensor.SetData(nullptr); + output0_tensor.SetData(nullptr); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc new file mode 100644 index 0000000000..8437edaa76 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/split_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/ir/tensor.h" + +namespace mindspore { + +class TestSplitInt8 : public mindspore::Common { + public: + TestSplitInt8() {} +}; + +TEST_F(TestSplitInt8, Split_quant0_thread2) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output1_size = 4; + int8_t output1[4]; + const int output2_size = 8; + int8_t output2[8]; + std::vector output1_shape = {2, 1, 2}; + std::vector output2_shape = {2, 2, 2}; + + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output1_tensor = new lite::tensor::Tensor; + output1_tensor->SetData(output1); + output1_tensor->set_shape(output1_shape); + output1_tensor->AddQuantParam(output_quant_arg); + output1_tensor->set_data_type(tid_int8); + lite::tensor::Tensor *output2_tensor = new lite::tensor::Tensor; + output2_tensor->SetData(output2); + output2_tensor->set_shape(output2_shape); + output2_tensor->AddQuantParam(output_quant_arg); + output2_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(2); + outputs_tensor[0] = output1_tensor; + outputs_tensor[1] = output2_tensor; + + SplitParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Split; + op_param.num_split_ = 2; + op_param.split_dim_ = 1; + op_param.split_sizes_[0] = 1; + op_param.split_sizes_[1] = 2; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Split}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output1_tensor_shape = output1_tensor->shape(); + auto output2_tensor_shape = output2_tensor->shape(); + ASSERT_EQ(output1_tensor_shape, output1_shape); + ASSERT_EQ(output2_tensor_shape, output2_shape); + kernel->Run(); + + std::vector except_result1 = {1, 2, 7, 8}; + std::vector except_result2 = {3, 4, 5, 6, 9, 10, 11, 12}; + PrintData("output data", output1, output1_size); + PrintData("output data shape", output1_tensor_shape.data(), output1_tensor_shape.size()); + PrintData("output data", output2, output2_size); + PrintData("output data shape", output2_tensor_shape.data(), output2_tensor_shape.size()); + CompareOutputData(output1, except_result1.data(), output1_size, 0.000001); + CompareOutputData(output2, except_result2.data(), output2_size, 0.000001); + + input_tensor1->SetData(nullptr); + output1_tensor->SetData(nullptr); + output2_tensor->SetData(nullptr); + delete input_tensor1; + delete output1_tensor; + delete output2_tensor; + delete ctx; +} + +TEST_F(TestSplitInt8, Split_quant0_thread2_num) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output1_size = 4; + int8_t output1[4]; + const int output2_size = 4; + int8_t output2[4]; + const int output3_size = 4; + int8_t output3[4]; + std::vector output1_shape = {2, 1, 2}; + std::vector output2_shape = {2, 1, 2}; + std::vector output3_shape = {2, 1, 2}; + + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 1.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output1_tensor = new lite::tensor::Tensor; + output1_tensor->SetData(output1); + output1_tensor->set_shape(output1_shape); + output1_tensor->AddQuantParam(output_quant_arg); + output1_tensor->set_data_type(tid_int8); + lite::tensor::Tensor *output2_tensor = new lite::tensor::Tensor; + output2_tensor->SetData(output2); + output2_tensor->set_shape(output2_shape); + output2_tensor->AddQuantParam(output_quant_arg); + output2_tensor->set_data_type(tid_int8); + lite::tensor::Tensor *output3_tensor = new lite::tensor::Tensor; + output3_tensor->SetData(output3); + output3_tensor->set_shape(output3_shape); + output3_tensor->AddQuantParam(output_quant_arg); + output3_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(3); + outputs_tensor[0] = output1_tensor; + outputs_tensor[1] = output2_tensor; + outputs_tensor[2] = output3_tensor; + + SplitParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Split; + op_param.num_split_ = 3; + op_param.split_dim_ = 1; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Split}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output1_tensor_shape = output1_tensor->shape(); + auto output2_tensor_shape = output2_tensor->shape(); + auto output3_tensor_shape = output3_tensor->shape(); + ASSERT_EQ(output1_tensor_shape, output1_shape); + ASSERT_EQ(output2_tensor_shape, output2_shape); + ASSERT_EQ(output3_tensor_shape, output3_shape); + kernel->Run(); + + std::vector except_result1 = {1, 2, 7, 8}; + std::vector except_result2 = {3, 4, 9, 10}; + std::vector except_result3 = {5, 6, 11, 12}; + PrintData("output data", output1, output1_size); + PrintData("output data shape", output1_tensor_shape.data(), output1_tensor_shape.size()); + PrintData("output data", output2, output2_size); + PrintData("output data shape", output2_tensor_shape.data(), output2_tensor_shape.size()); + PrintData("output data", output3, output3_size); + PrintData("output data shape", output3_tensor_shape.data(), output3_tensor_shape.size()); + CompareOutputData(output1, except_result1.data(), output1_size, 0.000001); + CompareOutputData(output2, except_result2.data(), output2_size, 0.000001); + CompareOutputData(output3, except_result3.data(), output3_size, 0.000001); + + input_tensor1->SetData(nullptr); + output1_tensor->SetData(nullptr); + output2_tensor->SetData(nullptr); + output3_tensor->SetData(nullptr); + delete input_tensor1; + delete output1_tensor; + delete output2_tensor; + delete output3_tensor; + delete ctx; +} + +TEST_F(TestSplitInt8, Split_quant1_thread2_num) { + std::vector input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::vector shape1 = {2, 3, 2}; + std::vector input(1, nullptr); + input[0] = input1.data(); + + const int output1_size = 4; + int8_t output1[4]; + const int output2_size = 4; + int8_t output2[4]; + const int output3_size = 4; + int8_t output3[4]; + std::vector output1_shape = {2, 1, 2}; + std::vector output2_shape = {2, 1, 2}; + std::vector output3_shape = {2, 1, 2}; + + lite::tensor::QuantArg input_quant_arg; + input_quant_arg.scale = 1.0; + input_quant_arg.zeroPoint = 0; + lite::tensor::QuantArg output_quant_arg; + output_quant_arg.scale = 2.0; + output_quant_arg.zeroPoint = 0; + + TypeId tid_int8 = kNumberTypeInt8; + lite::tensor::Tensor *input_tensor1 = new lite::tensor::Tensor; + input_tensor1->SetData(input1.data()); + input_tensor1->set_shape(shape1); + input_tensor1->AddQuantParam(input_quant_arg); + input_tensor1->set_data_type(tid_int8); + std::vector inputs_tensor(1); + inputs_tensor[0] = input_tensor1; + + lite::tensor::Tensor *output1_tensor = new lite::tensor::Tensor; + output1_tensor->SetData(output1); + output1_tensor->set_shape(output1_shape); + output1_tensor->AddQuantParam(output_quant_arg); + output1_tensor->set_data_type(tid_int8); + lite::tensor::Tensor *output2_tensor = new lite::tensor::Tensor; + output2_tensor->SetData(output2); + output2_tensor->set_shape(output2_shape); + output2_tensor->AddQuantParam(output_quant_arg); + output2_tensor->set_data_type(tid_int8); + lite::tensor::Tensor *output3_tensor = new lite::tensor::Tensor; + output3_tensor->SetData(output3); + output3_tensor->set_shape(output3_shape); + output3_tensor->AddQuantParam(output_quant_arg); + output3_tensor->set_data_type(tid_int8); + std::vector outputs_tensor(3); + outputs_tensor[0] = output1_tensor; + outputs_tensor[1] = output2_tensor; + outputs_tensor[2] = output3_tensor; + + SplitParameter op_param; + op_param.op_parameter_.type_ = schema::PrimitiveType_Split; + op_param.num_split_ = 3; + op_param.split_dim_ = 1; + lite::Context *ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Split}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + auto output1_tensor_shape = output1_tensor->shape(); + auto output2_tensor_shape = output2_tensor->shape(); + auto output3_tensor_shape = output3_tensor->shape(); + ASSERT_EQ(output1_tensor_shape, output1_shape); + ASSERT_EQ(output2_tensor_shape, output2_shape); + ASSERT_EQ(output3_tensor_shape, output3_shape); + kernel->Run(); + + std::vector except_result1 = {1, 1, 4, 4}; + std::vector except_result2 = {2, 2, 5, 5}; + std::vector except_result3 = {3, 3, 6, 6}; + PrintData("output data", output1, output1_size); + PrintData("output data shape", output1_tensor_shape.data(), output1_tensor_shape.size()); + PrintData("output data", output2, output2_size); + PrintData("output data shape", output2_tensor_shape.data(), output2_tensor_shape.size()); + PrintData("output data", output3, output3_size); + PrintData("output data shape", output3_tensor_shape.data(), output3_tensor_shape.size()); + CompareOutputData(output1, except_result1.data(), output1_size, 0.000001); + CompareOutputData(output2, except_result2.data(), output2_size, 0.000001); + CompareOutputData(output3, except_result3.data(), output3_size, 0.000001); + + input_tensor1->SetData(nullptr); + output1_tensor->SetData(nullptr); + output2_tensor->SetData(nullptr); + output3_tensor->SetData(nullptr); + delete input_tensor1; + delete output1_tensor; + delete output2_tensor; + delete output3_tensor; + delete ctx; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc new file mode 100644 index 0000000000..1334f72b35 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestTopKInt8 : public mindspore::Common { + public: + TestTopKInt8() {} +}; + +TEST_F(TestTopKInt8, TopK) { + lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 2, 3}); + lite::tensor::Tensor out_tensor0(kNumberTypeInt8, {2, 2, 2}); + lite::tensor::Tensor out_tensor1(kNumberTypeInt32, {2, 2, 2}); + int8_t input_data[] = {1, 2, 3, 6, 5, 4, 9, 8, 7, 10, 12, 11}; + int8_t output_data0[8] = {0}; + int32_t output_data1[8] = {0}; + in_tensor.SetData(input_data); + out_tensor0.SetData(output_data0); + out_tensor1.SetData(output_data1); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor0, &out_tensor1}; + + TopkParameter parameter = {{}, 3, 4, 2, true}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_TopK}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect0[] = {3, 2, 6, 5, 9, 8, 12, 11}; + int32_t expect1[] = {2, 1, 0, 1, 0, 1, 1, 2}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(output_data0[i], expect0[i]); + EXPECT_EQ(output_data1[i], expect1[i]); + } + + in_tensor.SetData(nullptr); + out_tensor0.SetData(nullptr); + out_tensor1.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_out_50.bin new file mode 100644 index 0000000000..9418b5866b --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_out_50.bin @@ -0,0 +1 @@ +"x>#>K9>pR >)J >4>K>Z>>>L>=Q>*^>M>&>6>S>*>N>-=+L>vK>+A}>w^>$Q>s>/W>=M'>9[*>#%<#>C>>>$=Gj>>7*>2>6> >1p>s#>Y)>k>9==lQ0>w> \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_x_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_x_50.bin new file mode 100644 index 0000000000..d216d74d9d --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_x_50.bin @@ -0,0 +1 @@ +M?Ƿ? H2|7>0?dyX?C.\fT@?ͳg?Lw񾫘žE9&7A?T?XF4??ҹ?(k?0??VH?-Tz@&"-1w?F?羢D>Y> _p?] ?%R5 Ks=? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_yt_50.bin new file mode 100644 index 0000000000..23f04a2015 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hsigmoid_yt_50.bin @@ -0,0 +1 @@ +?6V?U?S?=M?;?3P?;?E?Ln?u?!?V??sW?9_?e?}H?h??X=? ?%??Y1?[s??c??t{?Ո?7=DK?eW???>?kcY?S???_fQ?u%?-u?}??k9??=?? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hswish_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hswish_out_50.bin new file mode 100644 index 0000000000..6ff3dd84c9 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hswish_out_50.bin @@ -0,0 +1 @@ +v=qٽBs>Q=@"\=`ο;?廿? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hswish_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hswish_yt_50.bin new file mode 100644 index 0000000000..b25a5c7787 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/hswish_yt_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_out_50.bin new file mode 100644 index 0000000000..b13b02ddbc Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_out_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_y_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_y_50.bin new file mode 100644 index 0000000000..6031302671 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_y_50.bin @@ -0,0 +1 @@ +&5Sa?t?@W,2ս&8?;V?橡?$?5pNF7:?5V:΄?m ,!@`|>Vؚ ?_B?0Խ"?q!>%=,? >Ѓ?;?qGh?7<U>=?-ap?g?>r@X> \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_yt_50.bin new file mode 100644 index 0000000000..7cd15ef7f9 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/lrelu_yt_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_out_50.bin new file mode 100644 index 0000000000..0e9de34e83 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_out_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_y_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_y_50.bin new file mode 100644 index 0000000000..095be3ca3a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_y_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_yt_50.bin new file mode 100644 index 0000000000..0e9de34e83 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu6_yt_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_out_50.bin new file mode 100644 index 0000000000..dc7098eca0 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_out_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_y_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_y_50.bin new file mode 100644 index 0000000000..00d6139a77 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_y_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_yt_50.bin new file mode 100644 index 0000000000..dc7098eca0 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/relu_yt_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_out_50.bin new file mode 100644 index 0000000000..8fcc885497 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_out_50.bin @@ -0,0 +1 @@ +]>E >Jn>bK>8=

&>g>];>Q>I>=\>S> ŀ=C*>K=n>Iy>>l>/=>rp>>>( >[>->{j=4>C>e>D>B==x>/m>vj>P>v=Pʕ>=3>vN= >ӂ> \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_y_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_y_50.bin new file mode 100644 index 0000000000..9e3724e2c0 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_y_50.bin @@ -0,0 +1 @@ +w>>XO>?>h>=%?o:?9>"q=7> >?? >9?{>t?D2\?J>n>1>>OF?/?7y?J0?eT?A?F$>'>Ab>#?"m@>$i>8?*C?)>r3?ᆒ>X?9y>>^2>S??w!'? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_yt_50.bin new file mode 100644 index 0000000000..ca66395f8f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/sigmoid_yt_50.bin @@ -0,0 +1 @@ +wa?"?XOo??Q?>E>?o?9i?"qv>> u??? W?9?>t?D2?>n8?>o?O?/Ǚ?7y?J?e?A?F$&?'?Ab_??"m>˼?? ??>$>?*?)l?r??X?9>??^>SԿ?w!? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_out_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_out_50.bin new file mode 100644 index 0000000000..9e3e18c43c --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_out_50.bin @@ -0,0 +1,2 @@ +@?^*Su>?(1?Y?O]>8>yͽh>Y:ן<e +?@?C?C?6GUp>_=I 0`>0>ݎ9?;?Gs*>e3>?ʑ>;?(,?&3*?C?Cw<2?=K>%HC%8?M>~>'u>JI>^4YuZ? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_y_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_y_50.bin new file mode 100644 index 0000000000..0a9a479053 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_y_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_yt_50.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_yt_50.bin new file mode 100644 index 0000000000..7323533d2f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/activationGrad/tanh_yt_50.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_0.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_0.bin new file mode 100644 index 0000000000..b22edaef0e Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_0.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_1.bin new file mode 100644 index 0000000000..437a6958ad --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_1.bin @@ -0,0 +1 @@ +L[?-"R>q>{B>?yx?_>JSD>G0? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_2.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_2.bin new file mode 100644 index 0000000000..4708330c95 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_2.bin @@ -0,0 +1 @@ +J[q?P?>g?A?>oo? 7G?x<"? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_3.bin new file mode 100644 index 0000000000..ca38daf512 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_3.bin @@ -0,0 +1 @@ +WU>X8?* ?!v>F>0 ?.<C?d? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_4.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_4.bin new file mode 100644 index 0000000000..dd1fa36149 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_input_4.bin @@ -0,0 +1 @@ +R?]?>c~?um?z1->??'?U? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_out.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_out.bin new file mode 100644 index 0000000000..9bc4e21395 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/batchnorm/fusedBatchnorm_out.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/2conv1x1conv1_input_nc4hwc4.txt b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/2conv1x1conv1_input_nc4hwc4.txt new file mode 100644 index 0000000000..297107068f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/2conv1x1conv1_input_nc4hwc4.txt differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_bias1_nhwc.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_bias1_nhwc.bin new file mode 100644 index 0000000000..5b135816a9 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_bias1_nhwc.bin @@ -0,0 +1,2 @@ +:eQݿc?p @E(o=ű*΢=͕^C?-?=@$ ?(W!=+>@@? -JP ?k ?M + wq>3=Rj @E%@!H￸l==\j/m2>b @B \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_input1_nhwc.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_input1_nhwc.bin new file mode 100644 index 0000000000..44d1e21ade Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_input1_nhwc.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin new file mode 100644 index 0000000000..66813ca607 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_weight1_nhwc.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_weight1_nhwc.bin new file mode 100644 index 0000000000..5e31050ba9 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_weight1_nhwc.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32.tflite b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32.tflite new file mode 100644 index 0000000000..8763b51b0f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32.tflite differ diff --git a/mindspore/_akg/utils/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_32_3_3_3.bin similarity index 100% rename from mindspore/_akg/utils/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_32_3_3_3.bin diff --git a/model_zoo/Transformer/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_18_3_3_3.bin similarity index 100% rename from model_zoo/Transformer/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_18_3_3_3.bin diff --git a/model_zoo/alexnet/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin similarity index 100% rename from model_zoo/alexnet/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin diff --git a/model_zoo/gat/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_1_28_28_3.bin similarity index 100% rename from model_zoo/gat/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_1_28_28_3.bin diff --git a/model_zoo/lenet/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_1_28_28_3.bin similarity index 100% rename from model_zoo/lenet/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_1_28_28_3.bin diff --git a/model_zoo/mass/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin similarity index 100% rename from model_zoo/mass/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin diff --git a/model_zoo/mass/scripts/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_1_28_28_32.bin similarity index 100% rename from model_zoo/mass/scripts/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_1_28_28_32.bin diff --git a/model_zoo/resnext50/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_1_28_28_18.bin similarity index 100% rename from model_zoo/resnext50/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_1_28_28_18.bin diff --git a/model_zoo/resnext50/src/utils/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin similarity index 100% rename from model_zoo/resnext50/src/utils/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_128_128_24.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_128_128_24.bin new file mode 100644 index 0000000000..b13d74559f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_128_128_24.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_128_128_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_128_128_32.bin new file mode 100644 index 0000000000..1c37f7cc16 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_128_128_32.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_28_28_3.bin new file mode 100644 index 0000000000..a67b9fb715 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_3_28_28.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_3_28_28.bin new file mode 100644 index 0000000000..d937821330 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_input_1_3_28_28.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_127_127_24.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_127_127_24.bin new file mode 100644 index 0000000000..4d197e365e Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_127_127_24.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_128_128_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_128_128_32.bin new file mode 100644 index 0000000000..8bef61d977 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_128_128_32.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_28_28_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_28_28_32.bin new file mode 100644 index 0000000000..1979eef911 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_out_1_28_28_32.bin differ diff --git a/model_zoo/ssd/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_32_3_3_3.bin similarity index 100% rename from model_zoo/ssd/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_32_3_3_3.bin diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_18_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_18_3_3_3.bin new file mode 100644 index 0000000000..1dc4bf74d8 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_18_3_3_3.bin @@ -0,0 +1,4 @@ +F .N2?󻾩`?ͽؿSį2xR=}%T?9>R?Eÿ?>?@<*Fs?h>۾i) W>+ ;=y@\?V=~?)оϬ?HF}?տի?F꓿E +?Gÿ#μ>PD>>J?gNY, <ֈu?Y_"4?fx Y7;¡̾?)???]@-/zb?Ye?M /6?"?t?T?; -1?,6?.>n>8D?Ǿ +F+j?~B? +P??t ek?I?WJ>&? ?;;隿j =sg?[k?r?ݖc>.ljy?D>S¿?l?rSq?#m(@_?>l %6?h%k>=4?ŋoJs>fW?8c;??k:?bQ1Y>yp>nW= z|S:?P?r?K?kw>->V?~> \ No newline at end of file diff --git a/model_zoo/wide_and_deep/src/__init__.py b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin similarity index 100% rename from model_zoo/wide_and_deep/src/__init__.py rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_24_3_3_24.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_24_3_3_24.bin new file mode 100644 index 0000000000..d37e077886 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_24_3_3_24.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_32_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_32_3_3_3.bin new file mode 100644 index 0000000000..a9cf75b584 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_32_3_3_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_32_3_3_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_32_3_3_32.bin new file mode 100644 index 0000000000..b129a056ba Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_weight_32_3_3_32.bin differ diff --git a/predict/benchmark/README.md b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_1_28_28_3.bin similarity index 100% rename from predict/benchmark/README.md rename to mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_1_28_28_3.bin diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_1_28_28_3.bin new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/inception_v1_quant.tflite b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/inception_v1_quant.tflite new file mode 100644 index 0000000000..8e979ef111 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/inception_v1_quant.tflite differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/mv1_quant.tflite b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/mv1_quant.tflite new file mode 100644 index 0000000000..98ccc2c34e Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/mv1_quant.tflite differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_1_224_224_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_1_224_224_3.bin new file mode 100644 index 0000000000..986cce3525 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_1_224_224_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_1_28_28_16.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_1_28_28_16.bin new file mode 100644 index 0000000000..80294e0994 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_1_28_28_16.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_out_1_112_112_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_out_1_112_112_32.bin new file mode 100644 index 0000000000..e0c9a81244 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_out_1_112_112_32.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_out_1_28_28_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_out_1_28_28_32.bin new file mode 100644 index 0000000000..b903346cc5 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/uint8_out_1_28_28_32.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_input.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_input.bin new file mode 100644 index 0000000000..6f76847a88 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_input.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_output.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_output.bin new file mode 100644 index 0000000000..1142ffa404 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_output.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_weight.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_weight.bin new file mode 100644 index 0000000000..850ca15ce9 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/convDw/convDwfp32_weight.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_bias1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_bias1.bin new file mode 100644 index 0000000000..12fdbb0315 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_bias1.bin @@ -0,0 +1 @@ +ꜿX>+ ?6@?ܕe?;ܿ?ƄR??ýP`ڄ=??Cm?H?<<?ڈ3?:v ?-? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_output1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_output1.bin new file mode 100644 index 0000000000..f68fc47547 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_output1.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_weight1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_weight1.bin new file mode 100644 index 0000000000..6c06207ffa Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nchw_weight1.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nhwc_input1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nhwc_input1.bin new file mode 100644 index 0000000000..75e3206931 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconv/deconv_fp32_nhwc_input1.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_input.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_input.bin new file mode 100644 index 0000000000..bf8a032ce6 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_input.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_output.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_output.bin new file mode 100644 index 0000000000..c7b022f23b Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_output.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_weight.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_weight.bin new file mode 100644 index 0000000000..78b6aa2c1c Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/deconvDw/deconvDw_weight.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_bias1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_bias1.bin new file mode 100644 index 0000000000..d4ad33ce06 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_bias1.bin @@ -0,0 +1 @@ +3: ?i߾te??6\X`^j6@ h'>ݼ>6?%Fm?|C)@l>F@vC(*n6@{?@¹,~ @d6>@(R@g?8@ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_input1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_input1.bin new file mode 100644 index 0000000000..73aac58120 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_input1.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_output1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_output1.bin new file mode 100644 index 0000000000..620de8d1ae Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_output1.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_weight1.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_weight1.bin new file mode 100644 index 0000000000..71340f0703 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/FcFp32_weight1.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_a_10x4.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_a_10x4.bin new file mode 100644 index 0000000000..9d152d6f6c Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_a_10x4.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_a_4x10.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_a_4x10.bin new file mode 100644 index 0000000000..2352a3989a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_a_4x10.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_b_10x5.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_b_10x5.bin new file mode 100644 index 0000000000..a6df1ed831 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_b_10x5.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_b_5x10.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_b_5x10.bin new file mode 100644 index 0000000000..eb28b692f6 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_b_5x10.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_c_4x5.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_c_4x5.bin new file mode 100644 index 0000000000..183a9f0ba5 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/matmul/matmulfp32_c_4x5.bin @@ -0,0 +1 @@ +.@8|A-=,fAQ>2@dui}?t4@@2zN@ԣx@&(ӂe@g \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/inception_v1_quant/inception_v1_224_quant.tflite b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/inception_v1_quant/inception_v1_224_quant.tflite new file mode 100644 index 0000000000..54efd7031d Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/inception_v1_quant/inception_v1_224_quant.tflite differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.data-00000-of-00001 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.data-00000-of-00001 new file mode 100644 index 0000000000..4744fdb5c2 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.data-00000-of-00001 differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.index b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.index new file mode 100644 index 0000000000..cd90c18ea6 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.index differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.meta b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.meta new file mode 100644 index 0000000000..322fdc426a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.ckpt.meta differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.tflite b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.tflite new file mode 100644 index 0000000000..437640b069 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant.tflite differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant_eval.pbtxt b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant_eval.pbtxt new file mode 100644 index 0000000000..796b3f14c2 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant_eval.pbtxt @@ -0,0 +1,50681 @@ +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\003\000\000\000 \000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_0/weights" + input: "MobilenetV1/Conv2d_0/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_0/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_0/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D" + op: "Conv2D" + input: "input" + input: "MobilenetV1/Conv2d_0/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_0/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_0/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000 \000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 32 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000 \000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 32 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 32 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000 \000\000\000@\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 32 + } + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_pointwise/weights" + input: "MobilenetV1/Conv2d_1_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_1_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000@\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 64 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000@\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 64 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 64 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000@\000\000\000\200\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 64 + } + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_pointwise/weights" + input: "MobilenetV1/Conv2d_2_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_2_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\200\000\000\000\200\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 128 + } + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_pointwise/weights" + input: "MobilenetV1/Conv2d_3_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_3_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 128 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\200\000\000\000\000\001\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 128 + } + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_pointwise/weights" + input: "MobilenetV1/Conv2d_4_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_4_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 256 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\001\000\000\000\001\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 256 + } + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_pointwise/weights" + input: "MobilenetV1/Conv2d_5_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_5_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 256 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 256 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 256 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\001\000\000\000\002\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 256 + } + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_pointwise/weights" + input: "MobilenetV1/Conv2d_6_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_6_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 512 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\002\000\000\000\002\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 512 + } + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_pointwise/weights" + input: "MobilenetV1/Conv2d_7_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_7_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 512 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\002\000\000\000\002\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 512 + } + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_pointwise/weights" + input: "MobilenetV1/Conv2d_8_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_8_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 512 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\002\000\000\000\002\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 512 + } + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_pointwise/weights" + input: "MobilenetV1/Conv2d_9_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_9_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 512 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\002\000\000\000\002\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 512 + } + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_pointwise/weights" + input: "MobilenetV1/Conv2d_10_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_10_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 512 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\002\000\000\000\002\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 512 + } + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_pointwise/weights" + input: "MobilenetV1/Conv2d_11_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_11_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 512 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 512 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 512 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\002\000\000\000\004\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 512 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_pointwise/weights" + input: "MobilenetV1/Conv2d_12_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_12_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\004\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 1024 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/depthwise_weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\004\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\004\000\000\000\004\000\000" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 1024 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_pointwise/weights" + input: "MobilenetV1/Conv2d_13_pointwise/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/weights/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_pointwise/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/weights" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Conv2D" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/Conv2d_13_pointwise/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean" + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1024 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Initializer/ones" + op: "Fill" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Initializer/ones/shape_as_tensor" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Initializer/ones/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Assign" + op: "Assign" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/Initializer/ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/read" + op: "Identity" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm/FusedBatchNorm" + op: "FusedBatchNorm" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Conv2D" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.0010000000475 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm/Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.000300000014249 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6" + op: "Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/add_fold" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Logits/AvgPool_1a/AvgPool" + op: "AvgPool" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "ksize" + value { + list { + i: 1 + i: 7 + i: 7 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "VALID" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/Logits/Dropout_1b/Identity" + op: "Identity" + input: "MobilenetV1/Logits/AvgPool_1a/AvgPool" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\000\004\000\000\351\003\000\000" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/mean" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/stddev" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0900000035763 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/mul" + op: "Mul" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/TruncatedNormal" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal" + op: "Add" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/mul" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 1024 + } + dim { + size: 1001 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Assign" + op: "Assign" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/Initializer/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/read" + op: "Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Initializer/zeros/shape_as_tensor" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/biases" + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1001 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Initializer/zeros/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/biases" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Initializer/zeros" + op: "Fill" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Initializer/zeros/shape_as_tensor" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Initializer/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/biases" + } + } + } + attr { + key: "index_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/biases" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/biases" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1001 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Assign" + op: "Assign" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/biases" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/read" + op: "Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/biases" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/Conv2D" + op: "Conv2D" + input: "MobilenetV1/Logits/Dropout_1b/Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd" + op: "BiasAdd" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/Conv2D" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } +} +node { + name: "MobilenetV1/Logits/SpatialSqueeze" + op: "Squeeze" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "squeeze_dims" + value { + list { + i: 1 + i: 2 + } + } + } +} +node { + name: "MobilenetV1/Predictions/Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\377\377\377\377\351\003\000\000" + } + } + } +} +node { + name: "MobilenetV1/Predictions/Reshape" + op: "Reshape" + input: "MobilenetV1/Logits/SpatialSqueeze" + input: "MobilenetV1/Predictions/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Predictions/Softmax" + op: "Softmax" + input: "MobilenetV1/Predictions/Reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/Predictions/Shape" + op: "Shape" + input: "MobilenetV1/Logits/SpatialSqueeze" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/Predictions/Reshape_1" + op: "Reshape" + input: "MobilenetV1/Predictions/Softmax" + input: "MobilenetV1/Predictions/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_0/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_0/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D_Fold" + op: "Conv2D" + input: "input" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: " \000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_1_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000 \000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_1_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "@\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_2_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000@\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_2_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_3_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_3_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_4_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_4_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_5_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_5_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_6_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\001\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_6_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_7_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_7_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_8_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_8_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_9_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_9_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_10_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_10_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_11_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_11_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_12_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\002\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 2 + i: 2 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_12_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/scale_reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\004\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/scale_reshape" + op: "Reshape" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/mul" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/scale_reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_13_depthwise/depthwise_weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/scale_reshape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise_Fold/Shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\003\000\000\000\003\000\000\000\000\004\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise_Fold/dilation_rate" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise_Fold" + op: "DepthwiseConv2dNative" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/depthwise_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000475 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/add" + op: "Add" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_variance/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/Rsqrt" + op: "Rsqrt" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/mul" + op: "Mul" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/mul_1" + op: "Mul" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/moving_mean/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/bias" + op: "Sub" + input: "MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/mul_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/mul_fold" + op: "Mul" + input: "MobilenetV1/Conv2d_13_pointwise/weights/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Conv2D_Fold" + op: "Conv2D" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/FakeQuantWithMinMaxVars" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "dilations" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "padding" + value { + s: "SAME" + } + } + attr { + key: "strides" + value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } + attr { + key: "use_cudnn_on_gpu" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/add_fold" + op: "Add" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Conv2D_Fold" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/BatchNorm_Fold/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_0/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_0/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_0/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_1_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_2_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_3_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_4_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_5_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_6_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_7_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_8_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_9_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_10_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_11_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_12_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_depthwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/mul_fold" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/min/read" + input: "MobilenetV1/MobilenetV1/Conv2d_13_pointwise/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -6.0 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min/read" + op: "Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max/read" + op: "Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights/read" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/min/read" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/max/read" + attr { + key: "narrow_range" + value { + b: true + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min/Assign" + op: "Assign" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min/read" + op: "Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max/Initializer/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 6.0 + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max/Assign" + op: "Assign" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max/Initializer/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max/read" + op: "Identity" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max" + } + } + } +} +node { + name: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/min/read" + input: "MobilenetV1/Logits/Conv2d_1c_1x1/act_quant/max/read" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +versions { + producer: 26 +} diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant_info.txt b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant_info.txt new file mode 100644 index 0000000000..96e6eddb30 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/models/mobilenet_quant/mobilenet_v1_1.0_224_quant_info.txt @@ -0,0 +1,3 @@ +Model: mobilenet_v1_1.0_224_quant +Input: input +Output: MobilenetV1/Predictions/Reshape_1 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dx1_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dx1_5_4_6.bin new file mode 100644 index 0000000000..fcb5df3926 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dx1_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin new file mode 100644 index 0000000000..aa7bc323f7 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dy_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dy_5_4_6.bin new file mode 100644 index 0000000000..c68b9c80b0 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_10_dy_5_4_6.bin @@ -0,0 +1 @@ +&JvBAoL?̓I?̿A6ԽS?d>5?iӿ`@u@G@`M>Av>B)>c@$/AwA˿^ 0kܾȁAfr>0x˿cR?vu=`,>pŔ?aK@y?׾db3?@ڤeK?Ч9)?uu@?"=P>b>v@Nl@ÐU>ot ?*@y ; @Av_ 俶\q?w@0ݻj?Aq;bo,3@I`?3sfl@@I?? AC_>=L@? .@xy`?῿A3 3ˑ?n?.=\B@/A>B_KGF?-^.;?V]?K@z:} QPqH ſ?,@~?LP>7>Pq@P>@ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_dx2_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_dx2_1_6.bin new file mode 100644 index 0000000000..1676dc671f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_dx2_1_6.bin @@ -0,0 +1 @@ +L&h>)A[7?.O2@ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_dy_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_dy_4_6.bin new file mode 100644 index 0000000000..37947edcb8 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_dy_4_6.bin @@ -0,0 +1 @@ +wӿ?8>GF?-^.;?V]?K@z:} QPqH ſ?,@~?LP>7>Pq@P>@ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_x1_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_x1_4_6.bin new file mode 100644 index 0000000000..a093c882f6 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_1_x1_4_6.bin @@ -0,0 +1 @@ +.%?.s?d-f=<Vnߔ?`Dz?t"?e|>~ ޾I? ??ƽY?3B?Xf \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dx1_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dx1_4_6.bin new file mode 100644 index 0000000000..6737f96700 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dx1_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dx2_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dx2_1_6.bin new file mode 100644 index 0000000000..94bd26e133 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dx2_1_6.bin @@ -0,0 +1 @@ +2)A%dȑ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dy_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dy_4_6.bin new file mode 100644 index 0000000000..6737f96700 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_2_dy_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dx1_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dx1_4_6.bin new file mode 100644 index 0000000000..6737f96700 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dx1_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dx2_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dx2_1_6.bin new file mode 100644 index 0000000000..94bd26e133 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dx2_1_6.bin @@ -0,0 +1 @@ +2)A%dȑ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dy_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dy_4_6.bin new file mode 100644 index 0000000000..a4130257de Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_3_dy_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dx1_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dx1_4_6.bin new file mode 100644 index 0000000000..b8cc5229db --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dx1_4_6.bin @@ -0,0 +1,2 @@ +;p +?y=xJɷi<%.ve!?N[@I(*8Iኽ}> $!ÿi>u?JWy?R8]?z{?5> \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dx2_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dx2_1_6.bin new file mode 100644 index 0000000000..40d97286c8 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dx2_1_6.bin @@ -0,0 +1 @@ +b殿Q.* (@0˸ok@ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dy_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dy_4_6.bin new file mode 100644 index 0000000000..3cac9bfc20 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_4_dy_4_6.bin @@ -0,0 +1 @@ +ao=XD<ԟо?$_`=US>]%@?,?SG+b>&e+ѿ#Nv1>b?\i׿׺ H|@N?& ?ǖb ?9A{d@uh?wN?3(cvi5Ϙ>t +6?ngZ?> /TA \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dx1_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dx1_1_6.bin new file mode 100644 index 0000000000..4fe396e1bc --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dx1_1_6.bin @@ -0,0 +1 @@ +YA:ۢA93*B%B \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dx2_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dx2_4_6.bin new file mode 100644 index 0000000000..35a86f48b5 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dx2_4_6.bin @@ -0,0 +1 @@ +gE|: J"@۽>'BtpB=?C?[^9Ano*Bļ-hVCDBYD\ۿ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dy_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dy_4_6.bin new file mode 100644 index 0000000000..4bfe040d8c Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_6_dy_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dx1_4_5_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dx1_4_5_6.bin new file mode 100644 index 0000000000..32e73ebd1a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dx1_4_5_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin new file mode 100644 index 0000000000..06a8bb8181 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin @@ -0,0 +1 @@ +MA=@TY$¦&+f \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dy_4_5_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dy_4_5_6.bin new file mode 100644 index 0000000000..3a75ce88ed Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_dy_4_5_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_x1_4_5_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_x1_4_5_6.bin new file mode 100644 index 0000000000..e5104b8e39 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_x1_4_5_6.bin @@ -0,0 +1,2 @@ +5^x|>?NWւpټR4=p?v֥?RݾQ>(v?pj?H?G)>I޿ʟ>u{`Խ>Ԁ!O?9;>c>?\7K?cb_?.kuS=j?1=y? p)P?=V ?.>ј?H?*wMH ?\q>(z࿣?-=? y?ƿ׻+ _?dGp X?Qt>F?N=N +Yn`%G?)$?Jdi?7"8){oϾ w>(L?Pԝ'2?u?2?">!פ???<߾f?9F&QEE}<:G>.}.S?\̿K?*V%?Q꿹uyC7:?T4>`>o>!?>;y_ \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_x2_1_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_x2_1_1_6.bin new file mode 100644 index 0000000000..b0921a2d80 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_7_x2_1_1_6.bin @@ -0,0 +1 @@ +<>o8?#?d@ٌ< \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin new file mode 100644 index 0000000000..199d582173 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dx2_5_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dx2_5_1_6.bin new file mode 100644 index 0000000000..ea5dbc6a93 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dx2_5_1_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dy_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dy_5_4_6.bin new file mode 100644 index 0000000000..199d582173 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_dy_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_x1_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_x1_5_4_6.bin new file mode 100644 index 0000000000..d14440793e Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_x1_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_x2_5_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_x2_5_1_6.bin new file mode 100644 index 0000000000..9b6169666c Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_8_x2_5_1_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin new file mode 100644 index 0000000000..5824f414cf Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin new file mode 100644 index 0000000000..bee4fa98bc --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin @@ -0,0 +1,2 @@ + @UpQ@(A,@N5@(q?O)X@!g@}@+??B>+?M@,@ @eAOF@ۉ@%AN* ++| \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin new file mode 100644 index 0000000000..dde1dd8cbc Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_x1_5_4_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_x1_5_4_6.bin new file mode 100644 index 0000000000..d14440793e Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_x1_5_4_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_x2_5_1_6.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_x2_5_1_6.bin new file mode 100644 index 0000000000..9b6169666c Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/arithmetic_fp32_9_x2_5_1_6.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/biasgradfp32_1_db_7.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/biasgradfp32_1_db_7.bin new file mode 100644 index 0000000000..46853507f8 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/biasgradfp32_1_db_7.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/biasgradfp32_1_dy_10_28_28_7.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/biasgradfp32_1_dy_10_28_28_7.bin new file mode 100644 index 0000000000..c079b77a36 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/biasgradfp32_1_dy_10_28_28_7.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/sce_fp32_1_dy_6_4.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/sce_fp32_1_dy_6_4.bin new file mode 100644 index 0000000000..d135477618 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/operators/sce_fp32_1_dy_6_4.bin @@ -0,0 +1,2 @@ +,Q= dt=*<ʮ<'=䜶=_<ὥ +=i=ӭ=a:<=AL#wDh?>?>>a>S?>>*D^? [?Q=?56>ȭ? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avg_pool_1_128_128_24.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avg_pool_1_128_128_24.bin new file mode 100644 index 0000000000..12ec891f83 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avg_pool_1_128_128_24.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin new file mode 100644 index 0000000000..e5d086ef06 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin new file mode 100644 index 0000000000..ad87262c97 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_x_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_x_1_28_28_3.bin new file mode 100644 index 0000000000..d32ffbe755 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_x_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolingfp32_out_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolingfp32_out_1_28_28_3.bin new file mode 100644 index 0000000000..0c2971cb5a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolingfp32_out_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/max_pool_1_128_128_24.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/max_pool_1_128_128_24.bin new file mode 100644 index 0000000000..b0d69e70a1 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/max_pool_1_128_128_24.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin new file mode 100644 index 0000000000..cca67a85df Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin new file mode 100644 index 0000000000..15c810365e Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_i_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_i_1_28_28_3.bin new file mode 100644 index 0000000000..c50b6145ab Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_i_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dx_1_30_30_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dx_1_30_30_3.bin new file mode 100644 index 0000000000..e2ee3307f9 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dx_1_30_30_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dy_1_10_10_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dy_1_10_10_3.bin new file mode 100644 index 0000000000..0985d76b03 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dy_1_10_10_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dy_1_3_10_10.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dy_1_3_10_10.bin new file mode 100644 index 0000000000..c95d4ce0fa Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_dy_1_3_10_10.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_i_1_10_10_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_i_1_10_10_3.bin new file mode 100644 index 0000000000..d5647f32dc Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_i_1_10_10_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_i_1_3_10_10.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_i_1_3_10_10.bin new file mode 100644 index 0000000000..89157fd9dd Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_i_1_3_10_10.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_x_1_30_30_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_x_1_30_30_3.bin new file mode 100644 index 0000000000..61cfb07fdb Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_2_x_1_30_30_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolingfp32_out_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolingfp32_out_1_28_28_3.bin new file mode 100644 index 0000000000..7105a06b1c Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolingfp32_out_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_dx_scale5_shift2_power3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_dx_scale5_shift2_power3.bin new file mode 100644 index 0000000000..0bb3ced765 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_dx_scale5_shift2_power3.bin @@ -0,0 +1,2 @@ +NG^aFN.cJ5G +WExeء?wDRGTEAC_>YHÜ*IIM.IEB:g$IHpñ BFCBtHjC{@FIW>tFӦJE1F x3E~#)JIE4'nIB \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_dy_scale5_shift2_power3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_dy_scale5_shift2_power3.bin new file mode 100644 index 0000000000..b1e404b742 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_dy_scale5_shift2_power3.bin @@ -0,0 +1 @@ +N#LSRDi] Bo;Eg9WNA{CB@I4~D>PD DSpBH  VNLDCm B*l A֓\@iC AB^tB+3ECOAs`D3DeB~D`v@: \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_x_scale5_shift2_power3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_x_scale5_shift2_power3.bin new file mode 100644 index 0000000000..d902be1588 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/power/powerfp32_x_scale5_shift2_power3.bin @@ -0,0 +1 @@ +h)jr_O%?"=|;9$>=Dw9u??PҴ?~q<= BͿ9Ċ?,?6>8$du5?gb>")H`>q?I>.=,%?$!Y>H?| \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/softmax/softmaxfp32_out_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/softmax/softmaxfp32_out_1_28_28_3.bin new file mode 100644 index 0000000000..2682cdf83f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/softmax/softmaxfp32_out_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_bias_10.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_bias_10.bin new file mode 100644 index 0000000000..527f9f7399 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_bias_10.bin @@ -0,0 +1 @@ +>SY@+[Kc߾&O?-?Qjt? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_input_32_1000.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_input_32_1000.bin new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_weight_10_1000.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_weight_10_1000.bin new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/common_utils_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/common_utils_test.cc new file mode 100644 index 0000000000..7b47ee3fe5 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/common_utils_test.cc @@ -0,0 +1,134 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/common_test.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +class CommonUtilTest : public mindspore::Common { + public: + CommonUtilTest() = default; +}; + +TEST_F(CommonUtilTest, BucketReduceSparseGradient1) { + // The indices is a vector and the grad is a tensor with shape (6, 2) + /* 0 + * 0 + * 1 + * 1 + * 0 + * 3 + */ + std::vector indices{0, 0, 1, 1, 0, 3}; + /* 0 1 + * 2 3 + * 4 5 + * 6 7 + * 8 9 + * 10 11 + */ + std::vector grad; + for (int i = 0; i < 6 * 2; i++) { + grad.push_back(i); + } + std::vector unique_indices(6); + std::vector summed_grad(12); + std::vector tmp_indices(6); + std::vector tmp_grad(12); + + SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6}); + SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6}); + SparseGradient input_grad({grad.data(), indices.data(), 6}); + + ReduceSparseGradientParam param; + param.input_grad_ = &input_grad; + param.workspace_grad_ = &workspace_grad; + param.output_grad_ = &unique_grad; + param.max_index_ = 6; + param.value_stride_ = 2; + BucketReduceSparseGradient(param); + + EXPECT_EQ(unique_grad.indices_size_, 3); + std::vector expect_indices({0, 1, 3}); + for (size_t i = 0; i < unique_grad.indices_size_; ++i) { + EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]); + } + /* 10 13 + * 10 12 + * 10 11 + */ + std::vector expect_value({10, 13, 10, 12, 10, 11}); + for (size_t i = 0; i < unique_grad.indices_size_ * 2; ++i) { + EXPECT_EQ(unique_grad.value_[i], expect_value[i]); + } +} + +TEST_F(CommonUtilTest, BucketReduceSparseGradient2) { + // The indices is a vector and the grad is a tensor with shape (6, 2) + /* 0 + * 0 + * 1 + * 1 + * 0 + * 6 + */ + std::vector indices{0, 0, 1, 1, 0, 6}; + /* 0 1 + * 2 3 + * 4 5 + * 6 7 + * 8 9 + * 10 11 + */ + std::vector grad; + for (int i = 0; i < 6 * 2; i++) { + grad.push_back(i); + } + std::vector unique_indices(6); + std::vector summed_grad(12); + std::vector tmp_indices(6); + std::vector tmp_grad(12); + SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6}); + SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6}); + SparseGradient input_grad({grad.data(), indices.data(), 6}); + + ReduceSparseGradientParam param; + param.input_grad_ = &input_grad; + param.workspace_grad_ = &workspace_grad; + param.output_grad_ = &unique_grad; + param.max_index_ = 6; + param.value_stride_ = 2; + BucketReduceSparseGradient(param); + + EXPECT_EQ(unique_grad.indices_size_, 2); + + std::vector expect_indices({0, 1}); + for (size_t i = 0; i < unique_grad.indices_size_; ++i) { + EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]); + } + + /* 10 13 + * 10 12 + */ + std::vector expect_value({10, 13, 10, 12}); + for (size_t i = 0; i < unique_grad.indices_size_ * 2; ++i) { + EXPECT_EQ(unique_grad.value_[i], expect_value[i]); + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc new file mode 100644 index 0000000000..e7fccfe1a3 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h" + +namespace mindspore { + +void BoardcaseAdd(const float *a, const float b, float *c, const int size) { + for (int i = 0; i < size; i++) { + c[i] = a[i] + b; + } +} + +void ElementAdd(const float *a, const float *b, float *c, const int size) { + for (int i = 0; i < size; i++) { + c[i] = a[i] + b[i]; + } +} + +bool DataCompare(const float *a, const float *b, const int size, const float accuracy = 1e-4) { + for (int i = 0; i < size; i++) { + auto diff = fabs(a[i] - b[i]); + if (diff > accuracy) { + MS_LOG(ERROR) << "compare failed at " << i << " exp " << a[i] << " bug got " << b[i]; + return false; + } + } + return true; +} + +void InitData(void *data, const int size) { + float *data_float = reinterpret_cast(data); + static unsigned int seed = 123; + for (int i = 0; i < size; i++) { + data_float[i] = static_cast(rand_r(&seed)) % 100; + } +} + +void LogData(void *data, const int size, const std::string prefix) { + std::cout << prefix; + float *data_float = reinterpret_cast(data); + for (int i = 0; i < size; i++) { + std::cout << data_float[i] << ","; + } + std::cout << std::endl; +} + +void TestCase(const std::vector &shape_a, const std::vector &shape_b) { + std::cout << "TestCase" << std::endl; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + bool is_bias_add = shape_b.empty(); + auto tensorType = schema::NodeType_ValueNode; + + std::cout << "TestCase tensor" << std::endl; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(kNumberTypeFloat32, shape_a, schema::Format_NHWC4, tensorType); + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(kNumberTypeFloat32, shape_b, schema::Format_NHWC4, tensorType); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(kNumberTypeFloat32, shape_a, schema::Format_NHWC4, tensorType); + int64_t element_num = tensor_a->ElementsC4Num(); + int64_t element_num_b = is_bias_add ? 1 : tensor_b->ElementsC4Num(); + + std::cout << "TestCase new data" << std::endl; + float *data_a = new float[element_num]; + float *data_b = new float[element_num_b]; + float *data_c_cpu = new float[element_num]; + float *data_c_ocl = new float[element_num]; + + InitData(data_a, element_num); + InitData(data_b, element_num_b); + memset(data_c_ocl, 0, sizeof(float) * element_num); + + std::cout << "TestCase run cpu" << std::endl; + if (is_bias_add) { + BoardcaseAdd(data_a, static_cast(data_b)[0], data_c_cpu, element_num); + } else { + ElementAdd(data_a, data_b, data_c_cpu, element_num); + } + + std::cout << "TestCase set data" << std::endl; + std::vector inputs = {tensor_a}; + if (!is_bias_add) { + inputs.push_back(tensor_b); + } else { + tensor_b->MallocData(); + memcpy(tensor_b->Data(), data_b, sizeof(float)); + } + std::vector outputs = {tensor_c}; + + ArithmeticParameter *param = new ArithmeticParameter(); + param->ndim_ = 4; + param->op_parameter_.type_ = PrimitiveType_Add; + + std::vector arithmetic_inputs = {tensor_a, tensor_b}; + lite::Context ctx; + auto *arith_kernel = + new kernel::ArithmeticOpenCLKernel(reinterpret_cast(param), arithmetic_inputs, outputs, &ctx); + arith_kernel->Init(); + + std::vector kernels{arith_kernel}; + auto *kernel = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + std::cout << "TestCase Init" << std::endl; + kernel->Init(); + + memcpy(inputs[0]->Data(), data_a, sizeof(float) * element_num); + if (!is_bias_add) { + memcpy(inputs[1]->Data(), data_b, sizeof(float) * element_num_b); + } + + std::cout << "TestCase Run" << std::endl; + kernel->Run(); + + memcpy(data_c_ocl, outputs[0]->Data(), sizeof(float) * element_num); + + // ocl_runtime->SyncCommandQueue(); + LogData(data_a, 10, "Data A : "); + LogData(data_b, tensor_b->shape().empty() ? 1 : 10, "Data B : "); + LogData(data_c_cpu, 10, "Expect compute : "); + LogData(outputs[0]->Data(), 10, "OpenCL compute : "); + bool cmp = DataCompare(data_c_cpu, data_c_ocl, element_num); + MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!"); + std::cout << "TestCase End" << std::endl; + + // free + delete[] data_a; + delete[] data_b; + delete[] data_c_cpu; + delete[] data_c_ocl; + + delete kernel; + delete arith_kernel; + + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +class TestArithmeticOpenCL : public mindspore::Common { + public: + TestArithmeticOpenCL() {} +}; + +TEST_F(TestArithmeticOpenCL, AddElementwiseTest) { + const std::vector &shape_a = {1, 32, 32, 4}; + const std::vector &shape_b = {1, 32, 32, 4}; + TestCase(shape_a, shape_b); +} + +// TEST_F(TestOpenCLKernel, AddBoardcaseTest) { +// const std::vector &shape_a = {1, 4, 128, 128}; +// const std::vector &shape_b = {}; +// TestCase(shape_a, shape_b); +//} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc new file mode 100644 index 0000000000..075cce119e --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h" + +namespace mindspore { + +class TestAvgPoolingOpenCL : public mindspore::Common {}; + +void InitAvgPoolingParam(PoolingParameter *param) { + param->input_batch_ = 1; + param->input_h_ = 2; + param->input_w_ = 2; + param->input_channel_ = 4; + + param->output_batch_ = 1; + param->output_h_ = 1; + param->output_w_ = 1; + param->output_channel_ = 4; + + param->window_h_ = 2; + param->window_w_ = 2; + + param->stride_h_ = 2; + param->stride_w_ = 2; + + param->pad_u_ = 0; + param->pad_d_ = 0; + param->pad_l_ = 0; + param->pad_r_ = 0; + + param->max_pooling_ = false; + param->avg_pooling_ = true; +} + +TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) { + MS_LOG(INFO) << "start TEST_F TestPoolingOpenCL"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + + MS_LOG(INFO) << "create PoolingParameter"; + auto param = new PoolingParameter(); + InitAvgPoolingParam(param); + + MS_LOG(INFO) << "create Tensors"; + std::vector shape_in = { + param->input_batch_, + param->input_h_, + param->input_w_, + param->input_channel_, + }; + std::vector shape_out = { + param->output_batch_, + param->output_h_, + param->output_w_, + param->output_channel_, + }; + auto data_type = kNumberTypeFloat32; + auto tensorType = schema::NodeType_ValueNode; + lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NHWC, tensorType); + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NHWC, tensorType); + std::vector inputs{tensor_in}; + std::vector outputs{tensor_out}; + + MS_LOG(INFO) << "create OpenCL Kernel"; + auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast(param), inputs, outputs); + pooling_kernel->Init(); + std::vector kernels{pooling_kernel}; + + MS_LOG(INFO) << "create SubGraphOpenCLKernel"; + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + pGraph->Init(); + + MS_LOG(INFO) << "initialize data"; + std::vector tensor_map = {tensor_in}; + for (auto &tensor_file : tensor_map) { + auto tensor = tensor_file; + size_t size = tensor->Size(); + const float data[16] = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + memcpy(tensor->Data(), data, size); + } + + MS_LOG(INFO) << "pGraph->Run()"; + pGraph->Run(); + + MS_LOG(INFO) << "==================output data================="; + float *output_data = reinterpret_cast(tensor_out->Data()); + printf("output:"); + for (int i = 0; i < 4; i++) { + printf("%.3f ", output_data[i]); + } + printf("\n"); + size_t output_size = tensor_out->Size(); + float expect[4] = {2.0f, 3.0f, 4.0f, 5.0f}; + + for (int i = 0; i < tensor_out->ElementsNum(); ++i) + if (std::fabs(output_data[i] - expect[i]) > 1e-5) { + printf("idx[%d] except=%.3f output=%.3f, ", i, expect[i], output_data[i]); + } + printf("test all close OK!\n"); + lite::CompareOutputData(output_data, expect, 4); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc new file mode 100644 index 0000000000..7c039d47ca --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" + + +int DivideRoundUp(int n, int div) { + int q = n / div; + return n % div == 0 ? q : q + 1; +} +void printfNode(float *result, const std::vector &tempNode) { + for (int i = 0; i < tempNode[0]; i++) { + for (int j = 0; j < tempNode[1]; j++) { + for (int k = 0; k < tempNode[2]; k++) { + for (int w = 0; w < tempNode[3]; w++) { + std::cout + << result[i * tempNode[2] * tempNode[1] * tempNode[3] + j * tempNode[2] * tempNode[3] + k * tempNode[3] + w] + << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +void ConcatComputeByCPU_2input_dim4_axis3(float *input0, float *input1, float *output, std::vector input_shape0, + std::vector input_shape1, std::vector output_shape, + const int axis) { + int postion, index0 = 0, index1 = 0; + for (int i = 0; i < output_shape[0]; i++) { + for (int j = 0; j < output_shape[1]; j++) { + for (int k = 0; k < output_shape[2]; k++) { + postion = i * output_shape[1] * output_shape[2] * output_shape[3] + j * output_shape[2] * output_shape[3] + + k * output_shape[3]; + for (int w = 0; w < output_shape[3]; w++) { + if (w < input_shape0[3] + input_shape1[3]) { + output[postion++] = (w < input_shape0[3]) ? input0[index0++] : input1[index1++]; + } else { + for (int ind = input_shape0[3] + input_shape1[3]; ind < output_shape[3]; ind++) { + output[postion++] = 0; + } + } + } + } + } + } +} +void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *input2, float *output, + std::vector input_shape0, std::vector input_shape1, + std::vector input_shape2, std::vector output_shape, + const int axis) { + int postion, index0 = 0, index1 = 0, index2 = 0; + for (int i = 0; i < output_shape[0]; i++) { + for (int j = 0; j < output_shape[1]; j++) { + for (int k = 0; k < output_shape[2]; k++) { + postion = i * output_shape[1] * output_shape[2] * output_shape[3] + j * output_shape[2] * output_shape[3] + + k * output_shape[3]; + for (int w = 0; w < output_shape[3]; w++) { + if (w < input_shape0[3] + input_shape1[3]) { + output[postion++] = (w < input_shape0[3]) ? input0[index0++] : input1[index1++]; + } else if ((input_shape0[3] + input_shape1[3]) <= w && + w < (input_shape0[3] + input_shape1[3] + input_shape2[3])) { + output[postion++] = input2[index2++]; + } else { + for (int ind = input_shape0[3] + input_shape1[3]; ind < output_shape[3]; ind++) { + output[postion++] = 0; + } + } + } + } + } + } +} + +namespace mindspore { +class TestConcatOpenCL : public mindspore::Common { + public: + TestConcatOpenCL(){} +}; +TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { + MS_LOG(INFO) << "begin test"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + + MS_LOG(INFO) << "init tensors"; + constexpr int INPUT_NUM = 3; + std::array, INPUT_NUM> input_shapes = { + std::vector{1, 240, 240, 16}, std::vector{1, 240, 240, 16}, std::vector{1, 240, 240, 64}}; + std::vector output_shape = {1, 240, 240, 96}; + output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; + auto data_type = kNumberTypeFloat32; + auto tensor_type = schema::NodeType_ValueNode; + std::vector inputs; + for (auto &shape : input_shapes) { + inputs.push_back(new lite::tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type)); + } + auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); + std::vector outputs{output_tensor}; + std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; + MS_LOG(INFO) << "initialize tensors"; + auto param = new ConcatParameter(); + param->axis_ = 3; + auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast(param), inputs, outputs); + concat_kernel->Init(); + + MS_LOG(INFO) << "initialize sub_graph"; + std::vector kernels{concat_kernel}; + auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + sub_graph->Init(); + + MS_LOG(INFO) << "initialize input data"; + srand(time(NULL)); + for (auto &input_tensor : inputs) { + auto input_data = reinterpret_cast(input_tensor->Data()); + static unsigned int seed = 123; + for (int i = 0; i < input_tensor->ElementsNum(); ++i) { + input_data[i] = static_cast(rand_r(&seed) % 10 + 1); + } + printf("\n"); + } + + MS_LOG(INFO) << "==================output data================"; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor->Data()); + printf("\n"); + auto *input_data0 = reinterpret_cast(inputs[0]->Data()); + auto *input_data1 = reinterpret_cast(inputs[1]->Data()); + std::vector output_data_cpu(output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]); + if (inputs.size() == 2) { + ConcatComputeByCPU_2input_dim4_axis3(input_data0, input_data1, output_data_cpu.data(), input_shapes[0], + input_shapes[1], output_shape, param->axis_); + } + if (inputs.size() == 3) { + auto *input_data2 = reinterpret_cast(inputs[2]->Data()); + ConcatComputeByCPU_3input_dim4_axis3(input_data0, input_data1, input_data2, output_data_cpu.data(), input_shapes[0], + input_shapes[1], input_shapes[2], output_shape, param->axis_); + } + printf("\n"); + CompareOutputData(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); + MS_LOG(INFO) << "Testconcat passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc new file mode 100644 index 0000000000..a16485d2f1 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +class TestConv2dTransposeOpenCL : public mindspore::Common { + public: + TestConv2dTransposeOpenCL() {} +}; + +TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { + // setbuf(stdout, NULL); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + int pad = 0; + int n = 1; + int h = 240; + int w = 240; + int kh = 2; + int kw = 2; + int ci = 128; + int co = 128; + int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1; + int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1; + + size_t input_size; + std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + size_t weight_size; + std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + + size_t bias_size; + std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin"; + auto bias_data = reinterpret_cast(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size)); + + lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, ci}); + tensor_x->SetData(input_data); + + lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, kh, kw, ci}); + tensor_w->SetData(weight_data); + + lite::tensor::Tensor *tensor_bias = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co}); + tensor_bias->SetData(bias_data); + + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, oh, ow, co}); + std::vector inputs{tensor_x, tensor_w, tensor_bias}; + std::vector outputs{tensor_out}; + ConvParameter *opParameter = new ConvParameter(); + opParameter->kernel_h_ = kh; + opParameter->kernel_w_ = kw; + opParameter->stride_h_ = 2; + opParameter->stride_w_ = 2; + opParameter->pad_h_ = pad; + opParameter->pad_w_ = pad; + opParameter->input_channel_ = ci; + opParameter->output_channel_ = co; + auto *arith_kernel = + new kernel::Conv2dTransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + arith_kernel->Init(); + + std::vector kernels{arith_kernel}; + auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels); + pGraph->Init(); + pGraph->Run(); + + printf("==================output data=================\n"); + float *output_data = reinterpret_cast(tensor_out->Data()); + std::cout << std::endl; + size_t output_size; + std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + int size_n = oh * ow * co; + size_n = size_n > 100 ? 100 : size_n; + for (int i = 0; i < size_n; i++) { + std::cout << output_data[i] << ", "; + if ((i + 1) % co == 0) { + std::cout << std::endl; + } + } + std::cout << std::endl; + + // compare + CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001); + + MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc new file mode 100755 index 0000000000..59e4d93288 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc @@ -0,0 +1,864 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" + + +#define SAFE_DELETE_ARRAY(a) \ + if (a != nullptr) { \ + delete[] a; \ + a = nullptr; \ + } +#define SAFE_DELETE_PTR(a) \ + if (a != nullptr) { \ + delete a; \ + a = nullptr; \ + } + +bool IMAGE2D_OPEN = true; + +namespace mindspore { +class TestConvolutionDwOpenCL : public mindspore::Common { + public: + TestConvolutionDwOpenCL(){} +}; + +void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data, + schema::Format format, bool is_compare = true) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + // pack input + int IC4 = UP_DIV(conv_param->input_channel_, C4NUM); + int pack_input_size = C4NUM * IC4 * conv_param->input_h_ * conv_param->input_w_; + float *packed_input = new float[pack_input_size]; + memset(packed_input, 0, pack_input_size * sizeof(float)); + int plane = conv_param->input_w_ * conv_param->input_h_; + if (format == schema::Format_NHWC4) { + PackNHWCToNHWC4Fp32(input_data, packed_input, 1, plane, conv_param->input_channel_); + } else { + PackNHWCToNC4HW4Fp32(input_data, packed_input, 1, plane, conv_param->input_channel_); + } + + // pack weight + int OC4 = UP_DIV(conv_param->output_channel_, C4NUM); + int pack_weight_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_; + float *packed_weight = weight_data; + + // float bias_data[] = {0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.0, 0.0, 0.0}; + float bias_data[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + size_t packed_output_size = conv_param->output_batch_ * C4NUM * UP_DIV(conv_param->output_channel_, C4NUM) * + conv_param->output_h_ * conv_param->output_w_; + + std::vector shape_in = {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, + conv_param->input_channel_}; // Note!!!actual is NHWC4 + std::vector shape_filter = {1, conv_param->kernel_h_, conv_param->kernel_w_, conv_param->output_channel_}; + std::vector shape_bias = {conv_param->output_channel_}; + std::vector shape_out = {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, + conv_param->output_channel_}; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_in, format); // Note!!!actual is NHWC4 + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_filter, schema::Format_NHWC); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_bias, schema::Format_NHWC); + lite::tensor::Tensor *tensor_d = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_out, format); + std::vector inputs{tensor_a, tensor_b, tensor_c}; + std::vector outputs{tensor_d}; + + // freamework to do!!! + inputs[1]->SetData(packed_weight); + inputs[2]->SetData(bias_data); + + OpParameter * parameter = reinterpret_cast(conv_param); + auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); + pKernel->Init(); + + std::vector kernels{pKernel}; + std::vector inputs_{tensor_a}; + size_t C4 = UP_DIV(inputs[0]->Channel(), C4NUM); + // if (IMAGE2D_OPEN && format == schema::Format_NHWC4) { + // std::vector img_size{inputs[0]->Width() * C4, (size_t)inputs[0]->Height(), CL_FLOAT}; + // auto in_data = allocator->Malloc(inputs[0]->Size(), img_size); + // inputs[0]->SetData(in_data); + // } else if (IMAGE2D_OPEN && format == schema::Format_NC4HW4) { + // std::vector img_size{(size_t)inputs[0]->Width(), inputs[0]->Height() * C4, CL_FLOAT}; + // auto in_data = allocator->Malloc(inputs[0]->Size(), img_size); + // inputs[0]->SetData(in_data); + // } else { + inputs[0]->MallocData(allocator); + // } + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); + pGraph->Init(); + + // freamework to do!!! + memcpy(inputs[0]->Data(), packed_input, sizeof(float) * pack_input_size); + + pGraph->Run(); + if (is_compare) { + float_t* packed_output = reinterpret_cast(outputs[0]->Data()); + float_t *packed_correct_data = new float_t[packed_output_size]; + memset(packed_correct_data, 0, packed_output_size * sizeof(float_t)); + if (format == schema::Format_NC4HW4) { + PackNHWCToNC4HW4Fp32(gnd_data, packed_correct_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + } else { + PackNHWCToNHWC4Fp32(gnd_data, packed_correct_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + } + + printf("==================input_data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_input_size; i++) { + std::cout << packed_input[i] << ", "; + } + std::cout << std::endl; + printf("==================weight data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_weight_size; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + printf("==================output data=================\n"); + std::cout << std::endl; + for (int i = 0; i < 80/*packed_output_size*/; i++) { + std::cout << packed_output[i] << ", "; + } + std::cout << std::endl; + printf("==================expected output data=================\n"); + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_correct_data[i] << ", "; + } + std::cout << std::endl; + // compare + Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); + SAFE_DELETE_ARRAY(packed_correct_data) + } + + inputs[1]->SetData(nullptr); + inputs[2]->SetData(nullptr); + SAFE_DELETE_ARRAY(packed_input); + for (auto tensor : inputs) { + SAFE_DELETE_PTR(tensor) + } + for (auto tensor : outputs) { + SAFE_DELETE_PTR(tensor) + } + SAFE_DELETE_PTR(pKernel) + SAFE_DELETE_PTR(pGraph) + return; +} + +TEST_F(TestConvolutionDwOpenCL, NoPadNC4HW4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 4; + conv_param->input_w_ = 4; + conv_param->input_channel_ = 4; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 2; + conv_param->output_w_ = 2; + conv_param->output_channel_ = 4; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + } + + // nhwc + float input_data[] = {0.5488135, 0.0202184, 0.45615032, 0.31542835, 0.71518934, 0.83261985, 0.56843394, 0.36371076, + 0.60276335, 0.77815676, 0.0187898, 0.57019675, 0.5448832, 0.87001216, 0.6176355, 0.43860152, + 0.4236548, 0.9786183, 0.6120957, 0.9883738, 0.6458941, 0.7991586, 0.616934, 0.10204481, + 0.4375872, 0.46147937, 0.94374806, 0.20887676, 0.891773, 0.7805292, 0.6818203, 0.16130951, + 0.96366274, 0.11827443, 0.3595079, 0.6531083, 0.3834415, 0.639921, 0.43703195, 0.2532916, + 0.79172504, 0.14335328, 0.6976312, 0.46631077, 0.5288949, 0.9446689, 0.06022547, 0.2444256, + 0.56804454, 0.5218483, 0.6667667, 0.15896958, 0.92559665, 0.41466194, 0.67063785, 0.11037514, + 0.07103606, 0.2645556, 0.21038257, 0.6563296, 0.0871293, 0.7742337, 0.12892629, 0.13818295}; + + // co h w ci + float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + + // pack correct data, nhwc + float gnd_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, + 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 3; + conv_param->input_w_ = 3; + conv_param->input_channel_ = 5; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 3; + conv_param->output_w_ = 3; + conv_param->output_channel_ = 5; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + } + + // nhwc + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + // float input_data[]={ + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 }; + // co h w ci + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + // float weight_data[]={ + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 }; + // pack correct data, nhwc + float gnd_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4); + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 4; + conv_param->input_w_ = 4; + conv_param->input_channel_ = 4; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 2; + conv_param->output_w_ = 2; + conv_param->output_channel_ = 4; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + } + + // nhwc + float input_data[] = {0.5488135, 0.0202184, 0.45615032, 0.31542835, 0.71518934, 0.83261985, 0.56843394, 0.36371076, + 0.60276335, 0.77815676, 0.0187898, 0.57019675, 0.5448832, 0.87001216, 0.6176355, 0.43860152, + 0.4236548, 0.9786183, 0.6120957, 0.9883738, 0.6458941, 0.7991586, 0.616934, 0.10204481, + 0.4375872, 0.46147937, 0.94374806, 0.20887676, 0.891773, 0.7805292, 0.6818203, 0.16130951, + 0.96366274, 0.11827443, 0.3595079, 0.6531083, 0.3834415, 0.639921, 0.43703195, 0.2532916, + 0.79172504, 0.14335328, 0.6976312, 0.46631077, 0.5288949, 0.9446689, 0.06022547, 0.2444256, + 0.56804454, 0.5218483, 0.6667667, 0.15896958, 0.92559665, 0.41466194, 0.67063785, 0.11037514, + 0.07103606, 0.2645556, 0.21038257, 0.6563296, 0.0871293, 0.7742337, 0.12892629, 0.13818295}; + + // co h w ci + float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + + // pack correct data, nhwc + float gnd_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, + 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) { + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 3; + conv_param->input_w_ = 3; + conv_param->input_channel_ = 5; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 3; + conv_param->output_w_ = 3; + conv_param->output_channel_ = 5; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + } + + // nhwc + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + // float input_data[]={ + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 }; + // co h w ci + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + // float weight_data[]={ + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 }; + // pack correct data, nhwc + float gnd_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4); + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + + +TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 4; + conv_param->input_w_ = 4; + conv_param->input_channel_ = 4; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 2; + conv_param->output_w_ = 2; + conv_param->output_channel_ = 4; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 0; + conv_param->pad_w_ = 0; + } + + // nhwc + float input_data[] = {0.5488135, 0.0202184, 0.45615032, 0.31542835, 0.71518934, 0.83261985, 0.56843394, 0.36371076, + 0.60276335, 0.77815676, 0.0187898, 0.57019675, 0.5448832, 0.87001216, 0.6176355, 0.43860152, + 0.4236548, 0.9786183, 0.6120957, 0.9883738, 0.6458941, 0.7991586, 0.616934, 0.10204481, + 0.4375872, 0.46147937, 0.94374806, 0.20887676, 0.891773, 0.7805292, 0.6818203, 0.16130951, + 0.96366274, 0.11827443, 0.3595079, 0.6531083, 0.3834415, 0.639921, 0.43703195, 0.2532916, + 0.79172504, 0.14335328, 0.6976312, 0.46631077, 0.5288949, 0.9446689, 0.06022547, 0.2444256, + 0.56804454, 0.5218483, 0.6667667, 0.15896958, 0.92559665, 0.41466194, 0.67063785, 0.11037514, + 0.07103606, 0.2645556, 0.21038257, 0.6563296, 0.0871293, 0.7742337, 0.12892629, 0.13818295}; + + // pack input + int IC4 = UP_DIV(conv_param->input_channel_, C4NUM); + int pack_input_size = C4NUM * IC4 * conv_param->input_h_ * conv_param->input_w_; + float *packed_input = input_data; + + // co h w ci + float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + + // pack weight + int OC4 = UP_DIV(conv_param->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param->kernel_h_ * conv_param->kernel_w_; + float *packed_weight = weight_data; + + // float bias_data[] = {0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.0, 0.0, 0.0}; + float bias_data[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + size_t packed_output_size = conv_param->output_batch_ * C4NUM * UP_DIV(conv_param->output_channel_, C4NUM) * + conv_param->output_h_ * conv_param->output_w_; + + std::vector shape_in = {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, + IC4 * C4NUM}; // Note!!!actual is NHWC4 + std::vector shape_filter = {1, conv_param->kernel_h_, conv_param->kernel_w_, conv_param->output_channel_}; + std::vector shape_bias = {conv_param->output_channel_}; + std::vector shape_out = {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, + conv_param->output_channel_}; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_in, schema::Format_NC4HW4); // Note!!!actual is NHWC4 + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_filter, schema::Format_NHWC); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_bias, schema::Format_NHWC); + lite::tensor::Tensor *tensor_d = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_out, schema::Format_NC4HW4); + std::vector inputs{tensor_a, tensor_b, tensor_c}; + std::vector outputs{tensor_d}; + + // freamework to do!!! + inputs[1]->SetData(packed_weight); + inputs[2]->SetData(bias_data); + + OpParameter * parameter = reinterpret_cast(conv_param); + auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); + pKernel->Init(); + + std::vector kernels{pKernel}; + std::vector inputs_{tensor_a}; + inputs[0]->MallocData(); + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); + pGraph->Init(); + + // freamework to do!!! + memcpy(inputs[0]->Data(), packed_input, sizeof(float) * pack_input_size); + + pGraph->Run(); + float *packed_output = reinterpret_cast(outputs[0]->Data()); + + // pack correct data, nhwc + float packed_correct_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, + 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; + + printf("==================input_data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_input_size; i++) { + std::cout << packed_input[i] << ", "; + } + std::cout << std::endl; + printf("==================packed_weight data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_weight_size; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + printf("==================output data=================\n"); + std::cout << std::endl; + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_output[i] << ", "; + } + std::cout << std::endl; + printf("==================expected output data=================\n"); + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_correct_data[i] << ", "; + } + std::cout << std::endl; + // compare + Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); + + inputs[1]->SetData(nullptr); + inputs[2]->SetData(nullptr); + for (auto tensor : inputs) { + SAFE_DELETE_PTR(tensor) + } + for (auto tensor : outputs) { + SAFE_DELETE_PTR(tensor) + } + SAFE_DELETE_PTR(pKernel) + SAFE_DELETE_PTR(pGraph) + MS_LOG(INFO) << "TestConvolutionDwNoPadFp32 passed"; + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = 3; + conv_param->input_w_ = 3; + conv_param->input_channel_ = 5; + conv_param->output_batch_ = 1; + conv_param->output_h_ = 3; + conv_param->output_w_ = 3; + conv_param->output_channel_ = 5; + conv_param->kernel_h_ = 3; + conv_param->kernel_w_ = 3; + conv_param->stride_h_ = 1; + conv_param->stride_w_ = 1; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + conv_param->pad_h_ = 1; + conv_param->pad_w_ = 1; + } + + // nhwc + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + // float input_data[]={ + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 , + // 1 , 1 , 1 , 1 , 1 }; + + // pack input + int IC4 = UP_DIV(conv_param->input_channel_, C4NUM); + int pack_input_size = C4NUM * IC4 * conv_param->input_h_ * conv_param->input_w_; + float *packed_input = new float[pack_input_size]; + memset(packed_input, 0, pack_input_size * sizeof(float)); + int plane = conv_param->input_w_ * conv_param->input_h_; + PackNHWCToNC4HW4Fp32(input_data, packed_input, 1, plane, conv_param->input_channel_); + + // co h w ci + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + // float weight_data[]={ + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 , + // 1 , 1 , 1 }; + + // pack weight + int OC4 = UP_DIV(conv_param->output_channel_, C4NUM); + int pack_weight_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_; + float *packed_weight = weight_data; + + // float bias_data[] = {0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.0, 0.0, 0.0}; + float bias_data[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + size_t packed_output_size = conv_param->output_batch_ * C4NUM * UP_DIV(conv_param->output_channel_, C4NUM) * + conv_param->output_h_ * conv_param->output_w_; + + std::vector shape_in = {conv_param->input_batch_, conv_param->input_h_, conv_param->input_w_, + IC4 * C4NUM}; // Note!!!actual is NHWC4 + std::vector shape_filter = {1, conv_param->kernel_h_, conv_param->kernel_w_, conv_param->output_channel_}; + std::vector shape_bias = {conv_param->output_channel_}; + std::vector shape_out = {conv_param->output_batch_, conv_param->output_h_, conv_param->output_w_, + conv_param->output_channel_}; + lite::tensor::Tensor *tensor_a = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_in, schema::Format_NC4HW4); // Note!!!actual is NHWC4 + lite::tensor::Tensor *tensor_b = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_filter, schema::Format_NHWC); + lite::tensor::Tensor *tensor_c = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_bias, schema::Format_NHWC); + lite::tensor::Tensor *tensor_d = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), shape_out, schema::Format_NC4HW4); + std::vector inputs{tensor_a, tensor_b, tensor_c}; + std::vector outputs{tensor_d}; + + // freamework to do!!! + inputs[1]->SetData(packed_weight); + inputs[2]->SetData(bias_data); + + OpParameter * parameter = reinterpret_cast(conv_param); + auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs); + pKernel->Init(); + + std::vector kernels{pKernel}; + std::vector inputs_{tensor_a}; + inputs[0]->MallocData(); + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); + pGraph->Init(); + + // freamework to do!!! + memcpy(inputs[0]->Data(), packed_input, sizeof(float) * pack_input_size); + + pGraph->Run(); + float *packed_output = reinterpret_cast(outputs[0]->Data()); + + // pack correct data, nhwc + float correct_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + float *packed_correct_data = new float[packed_output_size]; + memset(packed_correct_data, 0, packed_output_size * sizeof(float)); + PackNHWCToNC4HW4Fp32(correct_data, packed_correct_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + + printf("==================input_data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_input_size; i++) { + std::cout << packed_input[i] << ", "; + } + std::cout << std::endl; + printf("==================weight data=================\n"); + std::cout << std::endl; + for (int i = 0; i < pack_weight_size; i++) { + std::cout << packed_weight[i] << ", "; + } + std::cout << std::endl; + printf("==================output data=================\n"); + std::cout << std::endl; + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_output[i] << ", "; + } + std::cout << std::endl; + printf("==================expected output data=================\n"); + for (int i = 0; i < packed_output_size; i++) { + std::cout << packed_correct_data[i] << ", "; + } + std::cout << std::endl; + // compare + Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); + + inputs[1]->SetData(nullptr); + inputs[2]->SetData(nullptr); + SAFE_DELETE_ARRAY(packed_input); + SAFE_DELETE_ARRAY(packed_correct_data) + for (auto tensor : inputs) { + SAFE_DELETE_PTR(tensor) + } + for (auto tensor : outputs) { + SAFE_DELETE_PTR(tensor) + } + SAFE_DELETE_PTR(pKernel) + SAFE_DELETE_PTR(pGraph) + MS_LOG(INFO) << "TestConvolutionDwPadFp32 passed"; + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { + std::vector> src_shape{ + {1, 32, 112, 112}, + {1, 96, 112, 112}, + {1, 144, 56, 56}, + {1, 144, 56, 56}, + {1, 192, 28, 28}, + {1, 192, 28, 28}, + {1, 384, 14, 14}, + {1, 576, 14, 14}, + {1, 576, 14, 14}, + {1, 960, 7, 7}, + }; + std::vector> dst_shape{ + {1, 32, 112, 112}, + {1, 96, 56, 56}, + {1, 144, 56, 56}, + {1, 144, 28, 28}, + {1, 192, 28, 28}, + {1, 192, 14, 14}, + {1, 384, 14, 14}, + {1, 576, 14, 14}, + {1, 576, 7, 7}, + {1, 960, 7, 7}, + }; + std::vector> filter_shape{ + {32, 1, 1, 1}, + {96, 3, 3, 1}, + {144, 1, 1, 1}, + {144, 3, 3, 1}, + {192, 1, 1, 1}, + {192, 3, 3, 1}, + {384, 1, 1, 1}, + {576, 1, 1, 1}, + {576, 3, 3, 1}, + {960, 1, 1, 1}, + }; + + // nhwc + size_t in_size = 96*112*112; + float_t *input_data = new float_t[in_size]; + memset(input_data, 0, in_size); + for (auto i = 0; i < in_size; ++i) { + input_data[i] = 1; + } + // co h w ci + size_t wt_size = 576*3*3; + float_t *weight_data = new float_t[wt_size]; + memset(weight_data, 0, wt_size); + for (auto i = 0; i < wt_size; ++i) { + weight_data[i] = 1; + } + size_t out_size = 96*112*112; + float_t *gnd_data = new float_t[out_size]; + memset(gnd_data, 0, out_size); +// for (auto i = 0; i < in_size; ++i) { +// gnd_data[i] = 1; +// } + for (size_t i = 0; i < src_shape.size(); ++i) { + const int MAX_RUN_TIMES = 1; + for (int j = 0; j < MAX_RUN_TIMES; ++j) { + printf("========profiling depthwise, in shape(%d,%d,%d,%d), out shape(%d,%d,%d,%d), iter%d========\n", + src_shape[i][0], src_shape[i][1], src_shape[i][2], src_shape[i][3], + dst_shape[i][0], dst_shape[i][1], dst_shape[i][2], dst_shape[i][3], j); + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = src_shape[i][2]; + conv_param->input_w_ = src_shape[i][3]; + conv_param->input_channel_ = src_shape[i][1]; + conv_param->output_batch_ = 1; + conv_param->output_h_ = dst_shape[i][2]; + conv_param->output_w_ = dst_shape[i][3]; + conv_param->output_channel_ = dst_shape[i][1]; + conv_param->kernel_h_ = filter_shape[i][1]; + conv_param->kernel_w_ = filter_shape[i][2]; + conv_param->stride_h_ = conv_param->output_h_/conv_param->input_h_; + conv_param->stride_w_ = conv_param->output_w_/conv_param->input_w_; + conv_param->pad_h_ = (conv_param->kernel_h_-1)/2; + conv_param->pad_w_ = (conv_param->kernel_w_-1)/2; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + } +// DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4, false); + DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NHWC4, false); + } + } + SAFE_DELETE_ARRAY(input_data); + SAFE_DELETE_ARRAY(weight_data); + lite::opencl::OpenCLRuntime::DeleteInstance(); +} + +TEST_F(TestConvolutionDwOpenCL, Buffer2Image) { + std::vector src_shape{1, 96, 64, 64}; + std::vector dst_shape{1, 96, 32, 32}; + std::vector filter_shape{96, 3, 3, 1}; + + // nhwc + size_t in_size = 96*112*112; + float_t *input_data = new float_t[in_size]; + memset(input_data, 0, in_size); + for (auto i = 0; i < in_size; ++i) { + input_data[i] = 1; + } + // co h w ci + size_t wt_size = 576*3*3; + float_t *weight_data = new float_t[wt_size]; + memset(weight_data, 0, wt_size); + for (auto i = 0; i < wt_size; ++i) { + weight_data[i] = 1; + } + size_t out_size = 96*112*112; + float_t *gnd_data = new float_t[out_size]; + memset(gnd_data, 0, out_size); +// for (auto i = 0; i < in_size; ++i) { +// gnd_data[i] = 1; +// } + ConvParameter *conv_param = new ConvParameter(); + { + conv_param->input_batch_ = 1; + conv_param->input_h_ = src_shape[2]; + conv_param->input_w_ = src_shape[3]; + conv_param->input_channel_ = src_shape[1]; + conv_param->output_batch_ = 1; + conv_param->output_h_ = dst_shape[2]; + conv_param->output_w_ = dst_shape[3]; + conv_param->output_channel_ = dst_shape[1]; + conv_param->kernel_h_ = filter_shape[1]; + conv_param->kernel_w_ = filter_shape[2]; + conv_param->stride_h_ = conv_param->output_h_/conv_param->input_h_; + conv_param->stride_w_ = conv_param->output_w_/conv_param->input_w_; + conv_param->pad_h_ = (conv_param->kernel_h_-1)/2; + conv_param->pad_w_ = (conv_param->kernel_w_-1)/2; + conv_param->dilation_h_ = 1; + conv_param->dilation_w_ = 1; + } +// DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4, true); + DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4, true); + SAFE_DELETE_ARRAY(input_data); + SAFE_DELETE_ARRAY(weight_data); + lite::opencl::OpenCLRuntime::DeleteInstance(); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc new file mode 100644 index 0000000000..50d46b2a68 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h" + +namespace mindspore { +class TestMatMulOpenCL : public mindspore::Common { + public: + TestMatMulOpenCL() {} +}; + +TEST_F(TestMatMulOpenCL, MatMulFp32) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + size_t input_size; + int ci = 1280; + int co = 1001; + std::string input_path = "./test_data/matmul/matmul_fp32_input.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + size_t weight_size; + std::string weight_path = "./test_data/matmul/matmul_fp32_weight.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + + lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, ci}); + tensor_x->SetData(input_data); + + lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, 1, 1, ci}); + tensor_w->SetData(weight_data); + + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, co}); + std::vector inputs{tensor_x, tensor_w}; + std::vector outputs{tensor_out}; + auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false); + arith_kernel->Init(); + + std::vector kernels{arith_kernel}; + auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels); + pGraph->Init(); + pGraph->Run(); + + size_t output_size; + std::string output_path = "./test_data/matmul/matmul_fp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + printf("==================output data=================\n"); + float *output_data = reinterpret_cast(tensor_out->Data()); + std::cout << std::endl; + int size_n = co; + size_n = size_n > 100 ? 100 : size_n; + for (int i = 0; i < size_n; i++) { + std::cout << output_data[i] << " "; + } + std::cout << std::endl; + + + // compare + CompareOutputData(output_data, correct_data, co, 0.00001); + + delete input_data; + delete weight_data; + delete tensor_x; + delete tensor_w; + delete tensor_out; + delete correct_data; + MS_LOG(INFO) << "TestMatMulFp32 passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc new file mode 100644 index 0000000000..d90be66a4c --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h" +#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" + +namespace mindspore { + +class TestMaxPoolingOpenCL : public mindspore::Common {}; + +void InitParameter(PoolingParameter *param) { + param->window_h_ = 2; + param->window_w_ = 2; + param->stride_h_ = 2; + param->stride_w_ = 2; + param->pad_u_ = 0; + param->pad_d_ = 0; + param->pad_l_ = 0; + param->pad_r_ = 0; + param->avg_pooling_ = false; + param->max_pooling_ = true; +} + +TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { + MS_LOG(INFO) << "ocl runtime"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << "PoolingParameter"; + auto param = new PoolingParameter; + InitParameter(param); + + // define tensor + MS_LOG(INFO) << "define tensor1"; + std::vector input_shape = {1, 16, 256, 192}; + std::vector output_shape = {1, 8, 128, 192}; + auto data_type = kNumberTypeFloat32; + auto tensorType = schema::NodeType_ValueNode; + MS_LOG(INFO) << "define tensor2"; + auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensorType); + auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensorType); + MS_LOG(INFO) << "define input"; + std::vector inputs{input_tensor}; + std::vector outputs{output_tensor}; + + // run + MS_LOG(INFO) << "pooling_kernel"; + auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast(param), inputs, outputs); + MS_LOG(INFO) << "pooling_kernel init"; + pooling_kernel->Init(); + + std::vector kernels{pooling_kernel}; + inputs[0]->MallocData(allocator); + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + MS_LOG(INFO) << "pGraph init"; + pGraph->Init(); + + // load data + MS_LOG(INFO) << "load data1"; + std::string input_file = "maxpool_in.bin"; + std::string expect_file = "maxpool_out.bin"; + MS_LOG(INFO) << "load data2"; + LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); + auto *input_data = reinterpret_cast(input_tensor->Data()); + printf("input[0:10]:"); + for (int i = 0; i < 10; i++) { + printf("[%d]:%.3f ", i, input_data[i]); + } + printf("\n"); + + pGraph->Run(); + + MS_LOG(INFO) << "compare result"; + std::cout << "compare result" << std::endl; + CompareOutput(output_tensor, expect_file); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/opencl_kernel_tests.h b/mindspore/lite/test/ut/src/runtime/kernel/opencl/opencl_kernel_tests.h new file mode 100644 index 0000000000..bcde788efe --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/opencl_kernel_tests.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/common_test.h" +#include "mindspore/core/utils/log_adapter.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" + +#ifndef TESTS_UT_OPENCL_KERNLE_TESTS_H +#define TESTS_UT_OPENCL_KERNLE_TESTS_H + +namespace mindspore { + +class TestOpenCLKernel : public mindspore::Common { + public: + TestOpenCLKernel() {} +}; + +} // namespace mindspore +#endif // TESTS_UT_OPENCL_KERNLE_TESTS_H diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc new file mode 100644 index 0000000000..684190bde2 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" + +namespace mindspore { + +class TestSoftmaxOpenCL : public mindspore::Common {}; + +void InitSoftaxParam(SoftmaxParameter *param) { param->axis_ = -1; } + +TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) { + std::cout << "======" << std::endl; + MS_LOG(INFO) << "start TEST_F TestSoftmaxOpenCL"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + + MS_LOG(INFO) << "create SoftmaxParameter"; + auto param = new SoftmaxParameter(); + InitSoftaxParam(param); + + MS_LOG(INFO) << "create Tensors"; + std::vector shape_in = {1, 2, 2, 1}; + std::vector shape_out = {1, 2, 2, 1}; + auto data_type = kNumberTypeFloat32; + auto tensorType = schema::NodeType_ValueNode; + lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NCHW, tensorType); + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NCHW, tensorType); + std::vector inputs{tensor_in}; + std::vector outputs{tensor_out}; + + MS_LOG(INFO) << "create OpenCL Kernel"; + auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast(param), inputs, outputs); + Softmax_kernel->Init(); + std::vector kernels{Softmax_kernel}; + + MS_LOG(INFO) << "create SubGraphOpenCLKernel"; + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + pGraph->Init(); + + MS_LOG(INFO) << "initialize data"; + std::vector tensor_map = {tensor_in}; + for (auto &tensor_file : tensor_map) { + auto tensor = tensor_file; + size_t size = tensor->Size(); + const float data[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)}; + memcpy(tensor->Data(), data, size); + } + + MS_LOG(INFO) << "pGraph->Run()"; + pGraph->Run(); + + MS_LOG(INFO) << "==================output data================="; + float *output_data = reinterpret_cast(tensor_out->Data()); + size_t output_size = tensor_out->Size(); + + printf("output:"); + for (int i = 0; i < 4; i++) { + printf("%.3f ", output_data[i]); + } + printf("\n"); + float expect[4] = {1.0f, 2.0f, 3.0f, 4.0f}; + + for (int i = 0; i < tensor_out->ElementsNum(); ++i) { + if (std::fabs(output_data[i] - expect[i]) > 1e-5) { + printf("idx[%d] except=%.3f output=%.3f .", i, expect[i], output_data[i]); + } + } + printf("\nTest all close OK for %zu!\n", output_size); + lite::CompareOutputData(output_data, expect, 4); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc new file mode 100644 index 0000000000..20324e0cdb --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "mindspore/core/utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h" + +namespace mindspore { +class TestTransposeOpenCL : public mindspore::Common { + public: + TestTransposeOpenCL() {} +}; + +TEST_F(TestTransposeOpenCL, TransposeFp32) { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + int h = 64; + int w = 1; + int c = 7360; + size_t input_size; + std::string input_path = "./test_data/transpose/transpose_fp32_input.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + lite::tensor::Tensor *tensor_x = + new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4); + + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w}); + std::vector inputs{tensor_x}; + std::vector outputs{tensor_out}; + auto *arith_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs); + arith_kernel->Init(); + + inputs[0]->MallocData(allocator); + + std::vector kernels{arith_kernel}; + auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + pGraph->Init(); + memcpy(inputs[0]->Data(), input_data, input_size); + pGraph->Run(); + + size_t output_size; + std::string output_path = "./test_data/transpose/transpose_fp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + printf("==================output data=================\n"); + float *output_data = reinterpret_cast(tensor_out->Data()); + std::cout << std::endl; + int size_n = h * w * c; + size_n = size_n > 100 ? 100 : size_n; + for (int i = 0; i < size_n; i++) { + std::cout << output_data[i] << " "; + } + std::cout << std::endl; + + // compare + CompareOutputData(output_data, correct_data, h * w * c, 0.00001); + MS_LOG(INFO) << "TestMatMulFp32 passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc new file mode 100644 index 0000000000..e834e29b07 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "utils/log_adapter.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" + +namespace mindspore { + +void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) { + if (file_path.empty()) { + memset(dst, 0x00, dst_size); + } else { + auto src_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); + if (src_data != nullptr) { + memcpy(dst, src_data, dst_size); + } else { + MS_LOG(ERROR) << "read file empty."; + } + } +} + +void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path) { + float *output_data = reinterpret_cast(output_tensor->Data()); + size_t output_size = output_tensor->Size(); + float *expect_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); + + printf("output[0:10]:"); + for (int i = 0; i < 10; i++) { + printf("[%d]:%.3f ", i, output_data[i]); + } + printf("\n"); + printf("expect[0:10]:"); + for (int i = 0; i < 10; i++) { + printf("[%d]:%.3f ", i, expect_data[i]); + } + printf("\n"); + + constexpr float atol = 1e-5; + for (int i = 0; i < output_tensor->ElementsNum(); ++i) { + if (std::fabs(output_data[i] - expect_data[i]) > atol) { + printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); + return; + } + } + printf("compare success!\n"); + printf("compare success!\n"); + printf("compare success!\n\n\n"); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h b/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h new file mode 100644 index 0000000000..90038c2ab3 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tests/ut/cpp/common/common_test.h" +#include "utils/log_adapter.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" + +#ifndef TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ +#define TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ + +namespace mindspore { + +void LoadTestData(void *dst, size_t dst_size, const std::string &file_path); + +void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path); + +} // namespace mindspore + +#endif // TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ diff --git a/mindspore/lite/test/ut/src/train_test.cc b/mindspore/lite/test/ut/src/train_test.cc new file mode 100644 index 0000000000..e64d5a9dfd --- /dev/null +++ b/mindspore/lite/test/ut/src/train_test.cc @@ -0,0 +1,287 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "utils/base_ref_utils.h" +#include "mindspore/lite/schema/inner/model_generated.h" +#include "mindspore/lite/src/train/model_impl.h" +#include "mindspore/lite/include/model.h" +#include "mindspore/lite/src/train/train_session.h" +#include "common/common_test.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +class TrainTest : public mindspore::Common { + public: + TrainTest() {} +}; + +TEST_F(TrainTest, TestConvNode) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + + auto node = std::make_unique(); + node->inputIndex = {0, 1}; + node->outputIndex = {2}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Conv2D; + auto primitive = new schema::Conv2DT; + primitive->padMode = schema::PadMode_SAME; + primitive->channelIn = 3; + primitive->channelOut = 32; + primitive->format = schema::Format_NHWC; + primitive->strideH = 1; + primitive->strideW = 1; + primitive->kernelH = 3; + primitive->kernelW = 3; + primitive->dilateH = 1; + primitive->dilateW = 1; + node->primitive->value.value = primitive; + node->name = "Conv2D"; + meta_graph->nodes.emplace_back(std::move(node)); + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {2}; + + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_Parameter; // todo use ValueNode? + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 28, 28, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + auto weight = std::make_unique(); + weight->nodeType = schema::NodeType::NodeType_ValueNode; + weight->format = schema::Format_KHWC; + weight->dataType = TypeId::kNumberTypeFloat32; + weight->dims = {32, 3, 3, 3}; + + auto buf = new char *[1]; + //================================================================ + size_t weight_size; + std::string weight_path = "./convfp32_weight_32_3_3_3.bin"; + ReadFile(weight_path.c_str(), &weight_size, buf); + ASSERT_NE(nullptr, buf[0]); + auto weight_data_temp = reinterpret_cast(buf[0]); + ASSERT_NE(nullptr, weight_data_temp); + weight->data.resize(sizeof(float) * 32 * 3 * 3 * 3); + + //================================================================ + memcpy(weight->data.data(), weight_data_temp, weight_size); + weight->offset = -1; + meta_graph->allTensors.emplace_back(std::move(weight)); + + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 28, 28, 32}; + output->offset = -1; + meta_graph->allTensors.emplace_back(std::move(output)); + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); + builder.Finish(offset); + size_t size = builder.GetSize(); + const char *content = reinterpret_cast(builder.GetBufferPointer()); + + auto model = lite::Model::Import(content, size); + ASSERT_NE(nullptr, model); + auto session = new session::TrainSession(); // inference::MSSession::CreateSession(kCPUDevice, 0); + ASSERT_NE(nullptr, session); + auto graphId = session->CompileGraph(NOT_NULL(model->GetModelImpl())); + + auto inTensor = new tensor::Tensor(TypeId::kNumberTypeFloat32, {1, 28, 28, 3}); + ASSERT_NE(nullptr, inTensor); + ASSERT_EQ(sizeof(float) * (28 * 28 * 3), inTensor->Size()); + auto ret = inTensor->MallocData(); + ASSERT_EQ(0, ret); + auto data = inTensor->Data(); + //=================================================== + size_t input_size; + std::string input_path = "./convfp32_input_1_28_28_3.bin"; + ReadFile(input_path.c_str(), &input_size, buf); + ASSERT_NE(nullptr, buf[0]); + auto input_data = reinterpret_cast(buf[0]); + ASSERT_NE(nullptr, input_data); + //=================================================== + memcpy(data, input_data, input_size); + std::vector> inputs; + inputs.emplace_back(inTensor); + VectorRef outputsRef; + session->RunGraph(graphId, inputs, &outputsRef); + auto outputs = TransformVectorRefToMultiTensor(outputsRef); + ASSERT_EQ(1, outputs.size()); + ASSERT_EQ(1, outputs.front().size()); + auto runOutput = outputs.front().front(); + ASSERT_NE(nullptr, runOutput); + ASSERT_EQ(28 * 28 * 32, runOutput->ElementsNum()); + ASSERT_EQ(TypeId::kNumberTypeFloat32, runOutput->data_type()); + auto *outData = reinterpret_cast(runOutput->MutableData()); + //=================================================== + size_t output_size; + std::string output_path = "./convfp32_out_1_28_28_32.bin"; + ReadFile(output_path.c_str(), &output_size, buf); + ASSERT_NE(nullptr, buf[0]); + auto output_data = reinterpret_cast(buf[0]); + ASSERT_NE(nullptr, output_data); + //=================================================== + ASSERT_EQ(output_size, runOutput->Size()); + for (size_t i = 0; i < runOutput->ElementsNum(); i++) { + ASSERT_EQ(output_data[i], outData[i]); + } + MS_LOG(INFO) << "Passed"; +} + +// TEST_F(TrainTest, TestMultiNode) { +// auto msGraph = std::make_shared(); +// msGraph->name = "graph"; +// auto msSubgraph = std::make_unique(); +// msSubgraph->name = "subGraph"; +// +// auto conv = std::make_unique(); +// conv->inputIndex = {0, 1}; +// conv->outputIndex = {2}; +// conv->attr.type = schema::OpT_Conv2D; +// auto conv_attr = new schema::Conv2DT; +// conv_attr->padMode = schema::PadMode_SAME; +// conv_attr->format = schema::Format_NHWC; +// conv_attr->strideH = 1; +// conv_attr->strideW = 1; +// conv_attr->kernelH = 3; +// conv_attr->kernelW = 3; +// conv_attr->dilateH = 1; +// conv_attr->dilateW = 1; +// +// conv->attr.value = conv_attr; +// conv->name = "Conv2D"; +// conv->fmkType = schema::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(conv)); +// +// auto matMul1 = std::make_unique(); +// matMul1->inputIndex = {2, 3}; +// matMul1->outputIndex = {4}; +// matMul1->attr.type = schema::OpT_MatMul; +// auto matMul_attr1 = new schema::MatMulT; +// matMul_attr1->transposeA = false; +// matMul_attr1->transposeB = true; +// matMul1->attr.value = matMul_attr1; +// matMul1->name = "matmul1"; +// matMul1->fmkType = schema::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(matMul1)); +// +// auto matMul2 = std::make_unique(); +// matMul2->inputIndex = {4, 5}; +// matMul2->outputIndex = {6}; +// matMul2->attr.type = schema::OpT_MatMul; +// auto matMul_attr2 = new schema::MatMulT; +// matMul_attr2->transposeA = false; +// matMul_attr2->transposeB = true; +// matMul2->attr.value = matMul_attr2; +// matMul2->name = "matmul2"; +// matMul2->fmkType = schema::FmkType_CAFFE; +// msSubgraph->nodes.emplace_back(std::move(matMul2)); +// +// msSubgraph->inputIndex = {0}; +// msSubgraph->outputIndex = {6}; +// +// auto input0 = std::make_unique(); +// input0->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// input0->format = schema::Format_NHWC; +// input0->dataType = TypeId::kNumberTypeFloat32; +// input0->dims = {1, 5, 5, 3}; +// input0->offset = -1; +// msSubgraph->allTensors.emplace_back(std::move(input0)); +// +// auto conv_weight = std::make_unique(); +// conv_weight->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// conv_weight->format = schema::Format_KHWC; +// conv_weight->dataType = TypeId::kNumberTypeFloat32; +// conv_weight->dims = {8, 3, 3, 3}; +// conv_weight->data.resize(8*3*3*3*sizeof(float)); +// msSubgraph->allTensors.emplace_back(std::move(conv_weight)); +// +// auto conv_output = std::make_unique(); +// conv_output->refCount = 0; +// conv_output->format = schema::Format_NHWC; +// conv_output->dataType = TypeId::kNumberTypeFloat32; +// conv_output->dims = {1, 5, 5, 8}; +// msSubgraph->allTensors.emplace_back(std::move(conv_output)); +// +// auto add_weight = std::make_unique(); +// add_weight->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// add_weight->format = schema::Format_NHWC; +// add_weight->dataType = TypeId::kNumberTypeFloat32; +// add_weight->dims = {1, 5, 5, 8}; +// add_weight->data.resize(5*5*8*sizeof(float)); +// msSubgraph->allTensors.emplace_back(std::move(add_weight)); +// +// auto add_output = std::make_unique(); +// add_output->refCount = 0; +// add_output->format = schema::Format_NHWC; +// add_output->dataType = TypeId::kNumberTypeFloat32; +// add_output->dims = {1, 5, 5, 8}; +// msSubgraph->allTensors.emplace_back(std::move(add_output)); +// +// auto mul_weight = std::make_unique(); +// mul_weight->refCount = schema::MSCONST_WEIGHT_REFCOUNT; +// mul_weight->format = schema::Format_NHWC; +// mul_weight->dataType = TypeId::kNumberTypeFloat32; +// mul_weight->dims = {1, 5, 5, 8}; +// mul_weight->data.resize(5*5*8*sizeof(float)); +// msSubgraph->allTensors.emplace_back(std::move(mul_weight)); +// +// auto mul_output = std::make_unique(); +// mul_output->refCount = 0; +// mul_output->format = schema::Format_NHWC; +// mul_output->dataType = TypeId::kNumberTypeFloat32; +// mul_output->dims = {1, 5, 5, 8}; +// msSubgraph->allTensors.emplace_back(std::move(mul_output)); +// msGraph->subgraphs.emplace_back(std::move(msSubgraph)); +// +// flatbuffers::FlatBufferBuilder builder(1024); +// auto offset = schema::GraphDef::Pack(builder, msGraph.get()); +// builder.Finish(offset); +// size_t size = builder.GetSize(); +// const char *content = (char *)builder.GetBufferPointer(); +// const std::string strstub = ""; +// +// auto func_graph = inference::LoadModel(content, size, strstub); +// ASSERT_NE(nullptr, func_graph); +// auto session = inference::MSSession::CreateSession(kCPUDevice, 0); +// ASSERT_NE(nullptr, session); +// auto graphId = session->CompileGraph(func_graph); +// +// auto inTensor = +// std::shared_ptr(inference::MSTensor::CreateTensor(TypeId::kNumberTypeFloat32, {1, 5, 5, 3})); +// ASSERT_NE(nullptr, inTensor); +// ASSERT_EQ(sizeof(float) * (5 * 5 * 3), inTensor->Size()); +// (void)inTensor->MutableData(); +// +// std::vector> inputs; +// inputs.emplace_back(inTensor); +// auto outputs = session->RunGraph(graphId, inputs); +// ASSERT_EQ(1, outputs.size()); +// ASSERT_EQ(1, outputs.front().size()); +// auto runOutput = outputs.front().front(); +// ASSERT_NE(nullptr, runOutput); +// ASSERT_EQ(5 * 5 * 8, runOutput->ElementsNum()); +// ASSERT_EQ(TypeId::kNumberTypeFloat32, runOutput->data_type()); +// MS_LOG(INFO) << "Passed"; +// } +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/abs.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/abs.tflite new file mode 100644 index 0000000000..29768a273a Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/abs.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/addn.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/addn.tflite new file mode 100644 index 0000000000..d704d5d80d Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/addn.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/batch_to_space_nd.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/batch_to_space_nd.tflite new file mode 100644 index 0000000000..87f8e477c3 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/batch_to_space_nd.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/cast.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/cast.tflite new file mode 100644 index 0000000000..0369c36111 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/cast.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/cos.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/cos.tflite new file mode 100644 index 0000000000..2f90af4092 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/cos.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depth_to_space.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depth_to_space.tflite new file mode 100644 index 0000000000..2f9dbe01c4 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depth_to_space.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/equal.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/equal.tflite new file mode 100644 index 0000000000..b0325f43ec Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/equal.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/greater.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/greater.tflite new file mode 100644 index 0000000000..6bf638e157 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/greater.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/greater_equal.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/greater_equal.tflite new file mode 100644 index 0000000000..eabe745325 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/greater_equal.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/less.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/less.tflite new file mode 100644 index 0000000000..9d76ebcefd Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/less.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/less_equal.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/less_equal.tflite new file mode 100644 index 0000000000..7d9898d627 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/less_equal.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/log.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/log.tflite new file mode 100644 index 0000000000..734fa6e775 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/log.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_and.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_and.tflite new file mode 100644 index 0000000000..a26b1bebf0 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_and.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_not.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_not.tflite new file mode 100644 index 0000000000..9cbbe05d44 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_not.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_or.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_or.tflite new file mode 100644 index 0000000000..ac539be1b7 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logical_or.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/maximum.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/maximum.tflite new file mode 100644 index 0000000000..3d3b6a3eef Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/maximum.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/minimum.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/minimum.tflite new file mode 100644 index 0000000000..691b3c25d6 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/minimum.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/not_equal.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/not_equal.tflite new file mode 100644 index 0000000000..580f650e23 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/not_equal.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/one_hot.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/one_hot.tflite new file mode 100644 index 0000000000..3e81d98f45 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/one_hot.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/prelu.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/prelu.tflite new file mode 100644 index 0000000000..dd432b62ac Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/prelu.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_max.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_max.tflite new file mode 100644 index 0000000000..96fde48b3a Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_max.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_min.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_min.tflite new file mode 100644 index 0000000000..e60aaf4756 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_min.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_prod.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_prod.tflite new file mode 100644 index 0000000000..7eaa9f4f24 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reduce_prod.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reverse_sequence.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reverse_sequence.tflite new file mode 100644 index 0000000000..171aaa8cae Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/reverse_sequence.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/round.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/round.tflite new file mode 100644 index 0000000000..1c77775737 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/round.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/rsqrt.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/rsqrt.tflite new file mode 100644 index 0000000000..6275509186 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/rsqrt.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sin.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sin.tflite new file mode 100644 index 0000000000..a06748d47c Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sin.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/space_to_batch_nd.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/space_to_batch_nd.tflite new file mode 100644 index 0000000000..f05f6fef85 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/space_to_batch_nd.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/space_to_depth.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/space_to_depth.tflite new file mode 100644 index 0000000000..3ed48958f0 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/space_to_depth.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sparse_to_dense.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sparse_to_dense.tflite new file mode 100644 index 0000000000..b3ef9112d3 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sparse_to_dense.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/split.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/split.tflite new file mode 100644 index 0000000000..def7d53523 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/split.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/split_v.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/split_v.tflite new file mode 100644 index 0000000000..61405534dc Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/split_v.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sqrt.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sqrt.tflite new file mode 100644 index 0000000000..984f61353d Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sqrt.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/square.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/square.tflite new file mode 100644 index 0000000000..4c907b4a1f Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/square.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/squared_difference.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/squared_difference.tflite new file mode 100644 index 0000000000..e0f454a5e6 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/squared_difference.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/strided_slice.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/strided_slice.tflite new file mode 100644 index 0000000000..d586c1faab Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/strided_slice.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sum.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sum.tflite new file mode 100644 index 0000000000..d84b2fedbc Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sum.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/tile.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/tile.tflite new file mode 100644 index 0000000000..b959b4bfdc Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/tile.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/topk_v2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/topk_v2.tflite new file mode 100644 index 0000000000..4f8e658707 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/topk_v2.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/unique.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/unique.tflite new file mode 100644 index 0000000000..9bd16c5ddf Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/unique.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/unstack.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/unstack.tflite new file mode 100644 index 0000000000..93703929f9 Binary files /dev/null and b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/unstack.tflite differ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_abs_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_abs_parser_test.cc new file mode 100644 index 0000000000..c90ff924a1 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_abs_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserAbs : public TestTfliteParser { + public: + TestTfliteParserAbs() {} + void SetUp() override { meta_graph = LoadAndConvert("./abs.tflite", ""); } +}; + +TEST_F(TestTfliteParserAbs, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Abs) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc new file mode 100644 index 0000000000..89519cea10 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserAddN : public TestTfliteParser { + public: + TestTfliteParserAddN() {} + void SetUp() override { + meta_graph = LoadAndConvert("./addn.tflite"); + } +}; + +TEST_F(TestTfliteParserAddN, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_AddN) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserAddN, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc new file mode 100644 index 0000000000..4d37fa506a --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserBatchToSpaceNd : public TestTfliteParser { + public: + TestTfliteParserBatchToSpaceNd() {} + void SetUp() override { meta_graph = LoadAndConvert("./batch_to_space_nd.tflite"); } +}; + +TEST_F(TestTfliteParserBatchToSpaceNd, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_BatchToSpace) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { + const std::vector blockShape{2, 2}; + const std::vector crops{0, 0, 2, 0}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc new file mode 100644 index 0000000000..c986d595d7 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserCast : public TestTfliteParser { + public: + TestTfliteParserCast() {} + void SetUp() override { + meta_graph = LoadAndConvert("./cast.tflite"); + } +}; + +TEST_F(TestTfliteParserCast, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Cast) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserCast, AttrValue) { + // float32 --> int32 + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cos_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cos_parser_test.cc new file mode 100644 index 0000000000..785cd35e44 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cos_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserCos : public TestTfliteParser { + public: + TestTfliteParserCos() {} + void SetUp() override { meta_graph = LoadAndConvert("./cos.tflite", ""); } +}; + +TEST_F(TestTfliteParserCos, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Cos) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc new file mode 100644 index 0000000000..d1983e0b84 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserDepthToSpace : public TestTfliteParser { + public: + TestTfliteParserDepthToSpace() {} + void SetUp() override { + meta_graph = LoadAndConvert("./depth_to_space.tflite"); + } +}; + +TEST_F(TestTfliteParserDepthToSpace, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DepthToSpace) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDepthToSpace, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_equal_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_equal_parser_test.cc new file mode 100644 index 0000000000..7e816a9d39 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_equal_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserEqual : public TestTfliteParser { + public: + TestTfliteParserEqual() {} + void SetUp() override { + meta_graph = LoadAndConvert("./equal.tflite"); + } +}; + +TEST_F(TestTfliteParserEqual, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Equal) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_greater_equal_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_greater_equal_parser_test.cc new file mode 100644 index 0000000000..3b5dba47c2 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_greater_equal_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserGreaterEqual : public TestTfliteParser { + public: + TestTfliteParserGreaterEqual() {} + void SetUp() override { + meta_graph = LoadAndConvert("./greater_equal.tflite"); + } +}; + +TEST_F(TestTfliteParserGreaterEqual, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GreaterEqual) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_greater_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_greater_parser_test.cc new file mode 100644 index 0000000000..055dbab7da --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_greater_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserGreater : public TestTfliteParser { + public: + TestTfliteParserGreater() {} + void SetUp() override { + meta_graph = LoadAndConvert("./greater.tflite"); + } +}; + +TEST_F(TestTfliteParserGreater, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Greater) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_less_equal_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_less_equal_parser_test.cc new file mode 100644 index 0000000000..6c1d29877f --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_less_equal_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserLessEqual : public TestTfliteParser { + public: + TestTfliteParserLessEqual() {} + void SetUp() override { + meta_graph = LoadAndConvert("./less_equal.tflite"); + } +}; + +TEST_F(TestTfliteParserLessEqual, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LessEqual) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_less_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_less_parser_test.cc new file mode 100644 index 0000000000..795acfce7d --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_less_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserLess : public TestTfliteParser { + public: + TestTfliteParserLess() {} + void SetUp() override { + meta_graph = LoadAndConvert("./less.tflite"); + } +}; + +TEST_F(TestTfliteParserLess, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Less) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_log_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_log_parser_test.cc new file mode 100644 index 0000000000..fe869019ce --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_log_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserLog : public TestTfliteParser { + public: + TestTfliteParserLog() {} + void SetUp() override { meta_graph = LoadAndConvert("./log.tflite", ""); } +}; + +TEST_F(TestTfliteParserLog, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Log) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_and_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_and_parser_test.cc new file mode 100644 index 0000000000..415dc874ee --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_and_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteLogicalParserAnd : public TestTfliteParser { + public: + TestTfliteLogicalParserAnd() {} + void SetUp() override { meta_graph = LoadAndConvert("./logical_and.tflite", ""); } +}; + +TEST_F(TestTfliteLogicalParserAnd, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LogicalAnd) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_not_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_not_parser_test.cc new file mode 100644 index 0000000000..a34ae07afe --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_not_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserLogicalNot : public TestTfliteParser { + public: + TestTfliteParserLogicalNot() {} + void SetUp() override { meta_graph = LoadAndConvert("./logical_not.tflite", ""); } +}; + +TEST_F(TestTfliteParserLogicalNot, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LogicalNot) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_or_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_or_parser_test.cc new file mode 100644 index 0000000000..19141d6125 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_logical_or_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserLogicalOr : public TestTfliteParser { + public: + TestTfliteParserLogicalOr() {} + void SetUp() override { meta_graph = LoadAndConvert("./logical_or.tflite", ""); } +}; + +TEST_F(TestTfliteParserLogicalOr, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LogicalOr) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_maximum_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_maximum_parser_test.cc new file mode 100644 index 0000000000..a21dba4d8b --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_maximum_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserMaximum : public TestTfliteParser { + public: + TestTfliteParserMaximum() {} + void SetUp() override { meta_graph = LoadAndConvert("./maximum.tflite"); } +}; + +TEST_F(TestTfliteParserMaximum, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Maximum) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_minimum_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_minimum_parser_test.cc new file mode 100644 index 0000000000..2999d33b16 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_minimum_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserMinimum : public TestTfliteParser { + public: + TestTfliteParserMinimum() {} + void SetUp() override { meta_graph = LoadAndConvert("./minimum.tflite"); } +}; + +TEST_F(TestTfliteParserMinimum, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Minimum) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_not_equal_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_not_equal_parser_test.cc new file mode 100644 index 0000000000..e7260a1262 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_not_equal_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserNotEqual : public TestTfliteParser { + public: + TestTfliteParserNotEqual() {} + void SetUp() override { + meta_graph = LoadAndConvert("./not_equal.tflite"); + } +}; + +TEST_F(TestTfliteParserNotEqual, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_NotEqual) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc new file mode 100644 index 0000000000..f1f007e849 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserOneHot : public TestTfliteParser { + public: + TestTfliteParserOneHot() {} + void SetUp() override { meta_graph = LoadAndConvert("./one_hot.tflite"); } +}; + +TEST_F(TestTfliteParserOneHot, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_OneHot) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserOneHot, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); + // in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size() + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_p_relu_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_p_relu_parser_test.cc new file mode 100644 index 0000000000..11e5fe7c14 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_p_relu_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserPrelu : public TestTfliteParser { + public: + TestTfliteParserPrelu() {} + void SetUp() override { + meta_graph = LoadAndConvert("./prelu.tflite"); + } +}; + +TEST_F(TestTfliteParserPrelu, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Prelu) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserPrelu, AttrValue) { + std::vector slope(20, 0); + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc new file mode 100644 index 0000000000..881d95b89f --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/parser/tflite/tflite_model_parser.h" + +namespace mindspore { + +schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) { + schema::MetaGraphT *meta_graph = nullptr; + lite::TfliteModelParser parser; + meta_graph = parser.Parse(model_path, weight_path); + return meta_graph; +} + +void TestTfliteParser::TearDown() { free(meta_graph); } + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h new file mode 100644 index 0000000000..cf592570c3 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TEST_UT_TOOLS_CONVERTER_PARSER_TFLITE_TFLITE_PARSERS_TEST_H_ +#define MINDSPORE_LITE_TEST_UT_TOOLS_CONVERTER_PARSER_TFLITE_TFLITE_PARSERS_TEST_H_ + +#include +#include "common/common_test.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +class TestTfliteParser : public Common { + public: + TestTfliteParser() {} + void TearDown() override; + schema::MetaGraphT *LoadAndConvert(const std::string &model_path, const std::string &weight_path = ""); + schema::MetaGraphT *meta_graph; +}; + +} // namespace mindspore + +#endif // MINDSPORE_LITE_TEST_UT_TOOLS_CONVERTER_PARSER_TFLITE_TFLITE_PARSERS_TEST_H_ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_max_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_max_parser_test.cc new file mode 100644 index 0000000000..6c2ff57114 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_max_parser_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserReduceMax : public TestTfliteParser { + public: + TestTfliteParserReduceMax() {} + void SetUp() override { meta_graph = LoadAndConvert("./reduce_max.tflite"); } +}; + +TEST_F(TestTfliteParserReduceMax, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserReduceMax, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceMax) + << "wrong reduce mode"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_min_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_min_parser_test.cc new file mode 100644 index 0000000000..dce6a95fb0 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_min_parser_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserReduceMin : public TestTfliteParser { + public: + TestTfliteParserReduceMin() {} + void SetUp() override { meta_graph = LoadAndConvert("./reduce_min.tflite"); } +}; + +TEST_F(TestTfliteParserReduceMin, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserReduceMin, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceMin) + << "wrong reduce mode"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_prod_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_prod_parser_test.cc new file mode 100644 index 0000000000..dc19ad983f --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_prod_parser_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserReduceProd : public TestTfliteParser { + public: + TestTfliteParserReduceProd() {} + void SetUp() override { meta_graph = LoadAndConvert("./reduce_prod.tflite"); } +}; + +TEST_F(TestTfliteParserReduceProd, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserReduceProd, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceProd) + << "wrong reduce mode"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc new file mode 100644 index 0000000000..f1a972f6e4 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserReverseSequence : public TestTfliteParser { + public: + TestTfliteParserReverseSequence() {} + void SetUp() override { + meta_graph = LoadAndConvert("./reverse_sequence.tflite"); + } +}; + +TEST_F(TestTfliteParserReverseSequence, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ReverseSequence) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserReverseSequence, AttrValue) { + std::vector seq_length{7, 2, 3, 5}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_round_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_round_parser_test.cc new file mode 100644 index 0000000000..ffb6a26ffe --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_round_parser_test.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserRound : public TestTfliteParser { + public: + TestTfliteParserRound() {} + void SetUp() override { + meta_graph = LoadAndConvert("./round.tflite"); + } +}; + +TEST_F(TestTfliteParserRound, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Round) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_rsqrt_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_rsqrt_parser_test.cc new file mode 100644 index 0000000000..c1db7380cb --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_rsqrt_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserRsqrt : public TestTfliteParser { + public: + TestTfliteParserRsqrt() {} + void SetUp() override { meta_graph = LoadAndConvert("./rsqrt.tflite", ""); } +}; + +TEST_F(TestTfliteParserRsqrt, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Rsqrt) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sin_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sin_parser_test.cc new file mode 100644 index 0000000000..a859cabc87 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sin_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSin : public TestTfliteParser { + public: + TestTfliteParserSin() {} + void SetUp() override { meta_graph = LoadAndConvert("./sin.tflite", ""); } +}; + +TEST_F(TestTfliteParserSin, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sin) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc new file mode 100644 index 0000000000..a9d66ceaf4 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSpaceToBatchND : public TestTfliteParser { + public: + TestTfliteParserSpaceToBatchND() {} + void SetUp() override { + meta_graph = LoadAndConvert("./space_to_batch_nd.tflite"); + } +}; + +TEST_F(TestTfliteParserSpaceToBatchND, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SpaceToBatchND) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { + std::vector blockshape{2, 2}; + std::vector padding{0, 0, 2, 0}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc new file mode 100644 index 0000000000..76245204a7 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSpaceToDepth : public TestTfliteParser { + public: + TestTfliteParserSpaceToDepth() {} + void SetUp() override { + meta_graph = LoadAndConvert("./space_to_depth.tflite"); + } +}; + +TEST_F(TestTfliteParserSpaceToDepth, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SpaceToDepth) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc new file mode 100644 index 0000000000..043fdbfe63 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSparseToDense : public TestTfliteParser { + public: + TestTfliteParserSparseToDense() {} + void SetUp() override { + meta_graph = LoadAndConvert("./sparse_to_dense.tflite"); + } +}; + +TEST_F(TestTfliteParserSparseToDense, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SparseToDense) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSparseToDense, AttrValue) { + std::vector outputShape{5, 5}; + std::vector sparseValue{1}; + std::vector defaultValue{0}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc new file mode 100644 index 0000000000..aae2b629d1 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSplit : public TestTfliteParser { + public: + TestTfliteParserSplit() {} + + void SetUp() override { meta_graph = LoadAndConvert("./split.tflite"); } +}; + +TEST_F(TestTfliteParserSplit, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Split) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSplit, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); + const std::vector sizeSplits{2, 2}; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc new file mode 100644 index 0000000000..30a7049a42 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSplitV : public TestTfliteParser { + public: + TestTfliteParserSplitV() {} + + void SetUp() override { meta_graph = LoadAndConvert("./split_v.tflite"); } +}; + +TEST_F(TestTfliteParserSplitV, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Split) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSplitV, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); + const std::vector sizeSplits{1, 3}; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sqrt_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sqrt_parser_test.cc new file mode 100644 index 0000000000..427907dc70 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sqrt_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSqrt : public TestTfliteParser { + public: + TestTfliteParserSqrt() {} + void SetUp() override { meta_graph = LoadAndConvert("./sqrt.tflite", ""); } +}; + +TEST_F(TestTfliteParserSqrt, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sqrt) << "wrong Op Type"; +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_square_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_square_parser_test.cc new file mode 100644 index 0000000000..96e0e71f0c --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_square_parser_test.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSquare : public TestTfliteParser { + public: + TestTfliteParserSquare() {} + void SetUp() override { meta_graph = LoadAndConvert("./square.tflite", ""); } +}; + +TEST_F(TestTfliteParserSquare, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Square) << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_squared_difference_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_squared_difference_parser_test.cc new file mode 100644 index 0000000000..f6be7a992d --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_squared_difference_parser_test.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSquaredDifference : public TestTfliteParser { + public: + TestTfliteParserSquaredDifference() {} + void SetUp() override { + meta_graph = LoadAndConvert("./squared_difference.tflite"); + } +}; + +TEST_F(TestTfliteParserSquaredDifference, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SquaredDifference) + << "wrong Op Type"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc new file mode 100644 index 0000000000..bbb327b998 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserStridedSlice : public TestTfliteParser { + public: + TestTfliteParserStridedSlice() {} + void SetUp() override { + meta_graph = LoadAndConvert("./strided_slice.tflite"); + } +}; + +TEST_F(TestTfliteParserStridedSlice, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_StridedSlice) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserStridedSlice, AttrValue) { + std::vector begin{1, -1, 0}; + std::vector end{2, -3, 3}; + std::vector stride{1, -1, 1}; + std::vector isscale{3, 2, 3}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sum_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sum_parser_test.cc new file mode 100644 index 0000000000..f028516da2 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sum_parser_test.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSum : public TestTfliteParser { + public: + TestTfliteParserSum() {} + + void SetUp() override { meta_graph = LoadAndConvert("./sum.tflite"); } +}; + +TEST_F(TestTfliteParserSum, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSum, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceSum) + << "wrong reduce mode"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc new file mode 100644 index 0000000000..19401218f7 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserTile : public TestTfliteParser { + public: + TestTfliteParserTile() {} + void SetUp() override { + meta_graph = LoadAndConvert("./tile.tflite"); + } +}; + +TEST_F(TestTfliteParserTile, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Tile) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserTile, AttrValue) { + std::vector multiply{2, 3, 4}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc new file mode 100644 index 0000000000..d358fdb2d5 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserTopKV2 : public TestTfliteParser { + public: + TestTfliteParserTopKV2() {} + void SetUp() override { + meta_graph = LoadAndConvert("./topk_v2.tflite"); + } +}; + +TEST_F(TestTfliteParserTopKV2, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopKV2) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserTopKV2, AttrValue) { + // attr->sorted default is true + std::vector k{3}; + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc new file mode 100644 index 0000000000..244abf8ee2 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserUnique : public TestTfliteParser { + public: + TestTfliteParserUnique() {} + void SetUp() override { + meta_graph = LoadAndConvert("./unique.tflite"); + } +}; + +TEST_F(TestTfliteParserUnique, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Unique) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserUnique, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32 +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc new file mode 100644 index 0000000000..e2c03a5bec --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserUnstack : public TestTfliteParser { + public: + TestTfliteParserUnstack() {} + void SetUp() override { + meta_graph = LoadAndConvert("./unstack.tflite"); + } +}; + +TEST_F(TestTfliteParserUnstack, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Unstack) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserUnstack, AttrValue) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc new file mode 100644 index 0000000000..6fac6dedc4 --- /dev/null +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/inner/model_generated.h" +#include "include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/anf_transform.h" +#include "src/common/anf_exporter/anf_exporter.h" + +namespace mindspore { +class ConvActivationFusionTest : public mindspore::Common { + public: + ConvActivationFusionTest() = default; +}; +using MetaGraphTptr = std::shared_ptr; +using CNodeTptr = std::unique_ptr; + +namespace { +CNodeTptr BuildConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_Conv2D; + auto prim1 = new schema::Conv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelOut = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} +CNodeTptr BuildDepthwiseConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + auto prim1 = new schema::DepthwiseConv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelIn = 1; + prim1->channelMultiplier = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} + +MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, + schema::ActivationType activation_type) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // conv node + CNodeTptr convNode; + if (conv_type == schema::PrimitiveType_Conv2D) { + convNode = BuildConv2D(); + } else { + convNode = BuildDepthwiseConv2D(); + } + meta_graph->nodes.emplace_back(std::move(convNode)); + + // relu node + auto next_node = std::make_unique(); + next_node->inputIndex = {2}; + next_node->outputIndex = {3}; + next_node->primitive = std::make_unique(); + next_node->primitive->value.type = schema::PrimitiveType_Activation; + auto prim2 = new schema::ActivationT; + prim2->type = activation_type; + next_node->primitive->value.value = prim2; + next_node->name = "activation"; + meta_graph->nodes.emplace_back(std::move(next_node)); + + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {3}; + + // input 0: data + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 5, 5, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + // input 1: weight + auto input1 = std::make_unique(); + input1->nodeType = schema::NodeType::NodeType_ValueNode; + input1->format = schema::Format_KHWC; + input1->dataType = TypeId::kNumberTypeFloat32; + input1->dims = {8, 3, 3, 3}; + input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input1)); + + // conv output + auto conv_output = std::make_unique(); + conv_output->nodeType = schema::NodeType::NodeType_Parameter; + conv_output->format = schema::Format_NHWC; + conv_output->dataType = TypeId::kNumberTypeFloat32; + conv_output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(conv_output)); + + // final output + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(output)); + return meta_graph; +} +} // namespace +TEST_F(ConvActivationFusionTest, TestConvReluNode) { + auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU); + } +} + +TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { + auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU6); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU6); + } +} + +TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) { + auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::ActivationType_LEAKY_RELU); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 2); + for (auto &cnode : new_meta_graph->nodes) { + if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->activationType, schema::ActivationType_NO_ACTIVATION); + } + } +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc new file mode 100644 index 0000000000..ef9fd87115 --- /dev/null +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/inner/model_generated.h" +#include "include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/anf_transform.h" +#include "src/common/anf_exporter/anf_exporter.h" + +namespace mindspore { +class ConvBiasAddFusionTest : public mindspore::Common { + public: + ConvBiasAddFusionTest() = default; +}; +using MetaGraphTptr = std::shared_ptr; +using CNodeTptr = std::unique_ptr; + +namespace { +CNodeTptr BuildConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_Conv2D; + auto prim1 = new schema::Conv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelOut = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} +CNodeTptr BuildDepthwiseConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + auto prim1 = new schema::DepthwiseConv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelIn = 1; + prim1->channelMultiplier = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} + +MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, + schema::PrimitiveType add_type) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // conv node + CNodeTptr convNode; + if (conv_type == schema::PrimitiveType_Conv2D) { + convNode = BuildConv2D(); + } else { + convNode = BuildDepthwiseConv2D(); + } + + meta_graph->nodes.emplace_back(std::move(convNode)); + + // biasadd node + auto biasadd_node = std::make_unique(); + biasadd_node->inputIndex = {2, 3}; + biasadd_node->outputIndex = {4}; + biasadd_node->primitive = std::make_unique(); + biasadd_node->primitive->value.type = add_type; + auto prim2 = new schema::BiasAddT; + biasadd_node->primitive->value.value = prim2; + biasadd_node->name = "BiasAdd"; + meta_graph->nodes.emplace_back(std::move(biasadd_node)); + + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {4}; + + // input 0: data + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 5, 5, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + // input 1: weight + auto input1 = std::make_unique(); + input1->nodeType = schema::NodeType::NodeType_ValueNode; + input1->format = schema::Format_KHWC; + input1->dataType = TypeId::kNumberTypeFloat32; + input1->dims = {8, 3, 3, 3}; + input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input1)); + + // conv output + auto conv_output = std::make_unique(); + conv_output->nodeType = schema::NodeType::NodeType_Parameter; + conv_output->format = schema::Format_NHWC; + conv_output->dataType = TypeId::kNumberTypeFloat32; + conv_output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(conv_output)); + + // input2: bias + auto input2 = std::make_unique(); + input2->nodeType = schema::NodeType::NodeType_ValueNode; + input2->format = schema::Format_NHWC; + input2->dataType = TypeId::kNumberTypeFloat32; + input2->dims = {1, 5, 5, 8}; + input2->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input2)); + + // final output + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(output)); + return meta_graph; +} +} // namespace +TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { + auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::PrimitiveType_BiasAdd); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true); + } + MS_LOG(INFO) << "Passed"; +} + +TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { + auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true); + } +} + +TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) { + auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_MatMul); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 2); + for (auto &cnode : new_meta_graph->nodes) { + if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, false); + } + } +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc new file mode 100644 index 0000000000..e2ce0d9e90 --- /dev/null +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc @@ -0,0 +1,296 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/inner/model_generated.h" +#include "include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "mindspore/core/utils/log_adapter.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/anf_transform.h" +#include "src/common/anf_exporter/anf_exporter.h" + +namespace mindspore { +class ConvBNFusionTest : public mindspore::Common { + public: + ConvBNFusionTest() = default; +}; +using MetaGraphTptr = std::shared_ptr; +using CNodeTptr = std::unique_ptr; + +namespace { +CNodeTptr BuildConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_Conv2D; + auto prim1 = new schema::Conv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelOut = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} +CNodeTptr BuildDepthwiseConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1, 2}; + convNode->outputIndex = {3}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + auto prim1 = new schema::DepthwiseConv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelIn = 1; + prim1->channelMultiplier = 3; + + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} +// caffe bn op has 3 inputs +MetaGraphTptr BuildCaffeGraph(schema::PrimitiveType conv_type) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // conv node + CNodeTptr convNode; + if (conv_type == schema::PrimitiveType_Conv2D) { + convNode = BuildConv2D(); + } else { + convNode = BuildDepthwiseConv2D(); + } + + meta_graph->nodes.emplace_back(std::move(convNode)); + + // bn_node + auto bn_node = std::make_unique(); + bn_node->inputIndex = {2, 3, 4}; + bn_node->outputIndex = {5}; + bn_node->primitive = std::make_unique(); + bn_node->primitive->value.type = schema::PrimitiveType_BatchNorm; + auto prim2 = new schema::BatchNormT; + bn_node->primitive->value.value = prim2; + bn_node->name = "bn"; + meta_graph->nodes.emplace_back(std::move(bn_node)); + + // input 0: data + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 5, 5, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + // input 1: weight + auto input1 = std::make_unique(); + input1->nodeType = schema::NodeType::NodeType_ValueNode; + input1->format = schema::Format_KHWC; + input1->dataType = TypeId::kNumberTypeFloat32; + input1->dims = {8, 3, 3, 3}; + input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input1)); + + // conv output + auto conv_output = std::make_unique(); + conv_output->nodeType = schema::NodeType::NodeType_Parameter; + conv_output->format = schema::Format_NHWC; + conv_output->dataType = TypeId::kNumberTypeFloat32; + conv_output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(conv_output)); + + // caffe bn : mean + auto input2 = std::make_unique(); + input2->nodeType = schema::NodeType::NodeType_ValueNode; + input2->format = schema::Format_NHWC; + input2->dataType = TypeId::kNumberTypeFloat32; + input2->dims = {1, 5, 5, 8}; + input2->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input2)); + + // caffe bn : var + auto input3 = std::make_unique(); + input3->nodeType = schema::NodeType::NodeType_ValueNode; + input3->format = schema::Format_NHWC; + input3->dataType = TypeId::kNumberTypeFloat32; + input3->dims = {1, 5, 5, 8}; + input3->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input3)); + + + // final bn output + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(output)); + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {5}; + return meta_graph; +} + +// tf bn op has 4 inputs +MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // conv node + CNodeTptr convNode; + if (conv_type == schema::PrimitiveType_Conv2D) { + convNode = BuildConv2D(); + } else { + convNode = BuildDepthwiseConv2D(); + } + + meta_graph->nodes.emplace_back(std::move(convNode)); + + // bn_node + auto bn_node = std::make_unique(); + bn_node->inputIndex = {3, 4, 5, 6, 7}; + bn_node->outputIndex = {8}; + bn_node->primitive = std::make_unique(); + bn_node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + auto prim2 = new schema::FusedBatchNormT; + bn_node->primitive->value.value = prim2; + bn_node->name = "bn"; + meta_graph->nodes.emplace_back(std::move(bn_node)); + + // input 0: data + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 5, 5, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + + // input 1: conv_bias + auto input11 = std::make_unique(); + input11->nodeType = schema::NodeType::NodeType_ValueNode; + input11->format = schema::Format_KHWC; + input11->dataType = TypeId::kNumberTypeFloat32; + input11->dims = {8, 3, 3, 3}; + input11->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input11)); + + // input 1: weight + auto input1 = std::make_unique(); + input1->nodeType = schema::NodeType::NodeType_ValueNode; + input1->format = schema::Format_KHWC; + input1->dataType = TypeId::kNumberTypeFloat32; + input1->dims = {8, 3, 3, 3}; + input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input1)); + + // conv output + auto conv_output = std::make_unique(); + conv_output->nodeType = schema::NodeType::NodeType_Parameter; + conv_output->format = schema::Format_NHWC; + conv_output->dataType = TypeId::kNumberTypeFloat32; + conv_output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(conv_output)); + + // tflite bn : scale + auto input2 = std::make_unique(); + input2->nodeType = schema::NodeType::NodeType_ValueNode; + input2->format = schema::Format_NHWC; + input2->dataType = TypeId::kNumberTypeFloat32; + input2->dims = {1, 5, 5, 8}; + input2->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input2)); + + // tflite bn : bias + auto input3 = std::make_unique(); + input3->nodeType = schema::NodeType::NodeType_ValueNode; + input3->format = schema::Format_NHWC; + input3->dataType = TypeId::kNumberTypeFloat32; + input3->dims = {1, 5, 5, 8}; + input3->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input3)); + + // tflite bn : mean + auto input4 = std::make_unique(); + input4->nodeType = schema::NodeType::NodeType_ValueNode; + input4->format = schema::Format_NHWC; + input4->dataType = TypeId::kNumberTypeFloat32; + input4->dims = {1, 5, 5, 8}; + input4->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input4)); + + // tflite bn : var + auto input5 = std::make_unique(); + input5->nodeType = schema::NodeType::NodeType_ValueNode; + input5->format = schema::Format_NHWC; + input5->dataType = TypeId::kNumberTypeFloat32; + input5->dims = {1, 5, 5, 8}; + input5->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input5)); + + // final output + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(output)); + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {8}; + return meta_graph; +} +} // namespace +TEST_F(ConvBNFusionTest, TestConvAddNode) { + auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2D); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true); + } +} + +TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) { + auto meta_graph = BuildTFGraph(schema::PrimitiveType_DepthwiseConv2D); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true); + } +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc new file mode 100644 index 0000000000..1eca3a469c --- /dev/null +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc @@ -0,0 +1,221 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/inner/model_generated.h" +#include "include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/anf_transform.h" +#include "src/common/anf_exporter/anf_exporter.h" + +namespace mindspore { +class ConvScaleFusionTest : public mindspore::Common { + public: + ConvScaleFusionTest() = default; +}; +using MetaGraphTptr = std::shared_ptr; +using CNodeTptr = std::unique_ptr; + +namespace { +// conv has 2 inputs +CNodeTptr BuildConv2D(int with_bias_flag) { + auto convNode = std::make_unique(); + if (with_bias_flag) { + convNode->inputIndex = {0, 1, 2}; + convNode->outputIndex = {3}; + } else { + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + } + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_Conv2D; + auto prim1 = new schema::Conv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelOut = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} +// conv2d has 3 inputs +CNodeTptr BuildDepthwiseConv2D(int with_bias_flag) { + auto convNode = std::make_unique(); + if (with_bias_flag) { + convNode->inputIndex = {0, 1, 2}; + convNode->outputIndex = {3}; + } else { + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + } + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + auto prim1 = new schema::DepthwiseConv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelIn = 1; + prim1->channelMultiplier = 3; + + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} + +MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // conv node + CNodeTptr convNode; + if (conv_type == schema::PrimitiveType_Conv2D) { + convNode = BuildConv2D(conv_with_bias); + } else { + convNode = BuildDepthwiseConv2D(conv_with_bias); + } + + meta_graph->nodes.emplace_back(std::move(convNode)); + + // scale_node weight bias + auto scale_node = std::make_unique(); + if (conv_with_bias) { + scale_node->inputIndex = {3, 4, 5}; + scale_node->outputIndex = {6}; + } else { + scale_node->inputIndex = {2, 3, 4}; + scale_node->outputIndex = {5}; + } + + scale_node->primitive = std::make_unique(); + scale_node->primitive->value.type = schema::PrimitiveType_Scale; + auto prim2 = new schema::ScaleT; + scale_node->primitive->value.value = prim2; + scale_node->name = "scale"; + meta_graph->nodes.emplace_back(std::move(scale_node)); + + // input 0: data + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 5, 5, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + // input 1: weight + auto input1 = std::make_unique(); + input1->nodeType = schema::NodeType::NodeType_ValueNode; + input1->format = schema::Format_KHWC; + input1->dataType = TypeId::kNumberTypeFloat32; + input1->dims = {8, 3, 3, 3}; + input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input1)); + + if (conv_with_bias) { + // input 00: bias + auto input00 = std::make_unique(); + input00->nodeType = schema::NodeType::NodeType_ValueNode; + input00->format = schema::Format_NHWC; + input00->dataType = TypeId::kNumberTypeFloat32; + input00->dims = {1, 5, 5, 3}; + input00->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input00)); + } + + // conv output + auto conv_output = std::make_unique(); + conv_output->nodeType = schema::NodeType::NodeType_Parameter; + conv_output->format = schema::Format_NHWC; + conv_output->dataType = TypeId::kNumberTypeFloat32; + conv_output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(conv_output)); + + // scale weight input + auto input2 = std::make_unique(); + input2->nodeType = schema::NodeType::NodeType_ValueNode; + input2->format = schema::Format_NHWC; + input2->dataType = TypeId::kNumberTypeFloat32; + input2->dims = {1, 5, 5, 8}; + input2->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input2)); + + // scale bias input + auto input3 = std::make_unique(); + input3->nodeType = schema::NodeType::NodeType_ValueNode; + input3->format = schema::Format_NHWC; + input3->dataType = TypeId::kNumberTypeFloat32; + input3->dims = {1, 5, 5, 8}; + input3->data.resize(sizeof(float) * 8 * 5 * 5); + meta_graph->allTensors.emplace_back(std::move(input3)); + + // final scale output + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(output)); + if (conv_with_bias) { + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {6}; + } else { + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {5}; + } + return meta_graph; +} +} // namespace +TEST_F(ConvScaleFusionTest, TestConvScaleNode) { + auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, true); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true); + } +} + +TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) { + auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, false); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true); + ASSERT_EQ(cnode->inputIndex.size(), 3); + } +} +} // namespace mindspore diff --git a/mindspore/lite/tools/benchmark/CMakeLists.txt b/mindspore/lite/tools/benchmark/CMakeLists.txt new file mode 100644 index 0000000000..18d34bab56 --- /dev/null +++ b/mindspore/lite/tools/benchmark/CMakeLists.txt @@ -0,0 +1,17 @@ +# add shared link library +set(COMMON_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc + ) + +add_executable(benchmark + ${CMAKE_CURRENT_SOURCE_DIR}/main.cc + ${CMAKE_CURRENT_SOURCE_DIR}/benchmark.cc + ${COMMON_SRC}) + +if (PLATFORM_ARM32 OR PLATFORM_ARM64) + target_link_libraries(benchmark mindspore-lite) +else() + target_link_libraries(benchmark mindspore-lite pthread) +endif() diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc new file mode 100644 index 0000000000..c36fe113fd --- /dev/null +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -0,0 +1,561 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/benchmark/benchmark.h" +#define __STDC_FORMAT_MACROS +#include +#undef __STDC_FORMAT_MACROS +#include +#include +#include +#include +#include "src/common/common.h" +#include "include/ms_tensor.h" +#include "include/context.h" + +namespace mindspore { +namespace lite { +int Benchmark::GenerateRandomData(size_t size, void *data) { + MS_ASSERT(data != nullptr); + char *castedData = static_cast(data); + for (size_t i = 0; i < size; i++) { + castedData[i] = static_cast(i); + } + return RET_OK; +} + +int Benchmark::GenerateInputData() { + for (auto tensor : msInputs) { + MS_ASSERT(tensor != nullptr); + auto inputData = tensor->MutableData(); + if (inputData == nullptr) { + MS_LOG(ERROR) << "MallocData for inTensor failed"; + return RET_ERROR; + } + MS_ASSERT(tensor->GetData() != nullptr); + auto tensorByteSize = tensor->Size(); + auto status = GenerateRandomData(tensorByteSize, inputData); + if (status != 0) { + MS_LOG(ERROR) << "GenerateRandomData for inTensor failed %d" << status; + return status; + } + } + return RET_OK; +} + +int Benchmark::LoadInput() { + if (_flags->inDataPath.empty()) { + auto status = GenerateInputData(); + if (status != 0) { + MS_LOG(ERROR) << "Generate input data error " << status; + return status; + } + } else { + auto status = ReadInputFile(); + if (status != 0) { + MS_LOG(ERROR) << "ReadInputFile error, " << status; + return status; + } + } + return RET_OK; +} + +int Benchmark::ReadInputFile() { + if (msInputs.empty()) { + return RET_OK; + } + + if (this->_flags->inDataType == kImage) { + // int cvFlags; + // if (inTensor->Channel() == 3) { + // cvFlags = 0; // cv::IMREAD_COLOR; + // } else if (inTensor->Channel() == 1) { + // cvFlags = 1; // cv::IMREAD_GRAYSCALE; + // } else { + // MS_LOG(ERROR) << "Image mode only support imgChannel == 1 or 3, imgChannel : %lld", (long + // long)inTensor->Channel(); return RET_PARAM_INVALID; + // } + // todo fill inTensor->GetData() + } else { + for (auto i = 0; i < _flags->input_data_list.size(); i++) { + auto cur_tensor = msInputs.at(i); + MS_ASSERT(cur_tensor != nullptr); + size_t size; + char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size); + auto tensorDataSize = cur_tensor->Size(); + if (size != tensorDataSize) { + MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size; + return RET_ERROR; + } + auto inputData = cur_tensor->MutableData(); + memcpy(inputData, binBuf, tensorDataSize); + } + } + return RET_OK; +} + +// calibData is FP32 +int Benchmark::ReadCalibData() { + const char *calibDataPath = _flags->calibDataPath.c_str(); + // read calib data + std::ifstream inFile(calibDataPath); + if (!inFile.good()) { + MS_LOG(ERROR) << "file: " << calibDataPath << " is not exist"; + return RET_ERROR; + } + + if (!inFile.is_open()) { + MS_LOG(ERROR) << "file: " << calibDataPath << " open failed"; + inFile.close(); + return RET_ERROR; + } + + std::string line; + + MS_LOG(INFO) << "Start reading calibData file"; + std::string tensorName; + while (!inFile.eof()) { + getline(inFile, line); + std::stringstream stringLine1(line); + size_t dim = 0; + stringLine1 >> tensorName >> dim; + std::vector dims; + size_t shapeSize = 1; + for (size_t i = 0; i < dim; i++) { + size_t tmpDim; + stringLine1 >> tmpDim; + dims.push_back(tmpDim); + shapeSize *= tmpDim; + } + + getline(inFile, line); + std::stringstream stringLine2(line); + std::vector tensorData; + for (size_t i = 0; i < shapeSize; i++) { + float tmpData; + stringLine2 >> tmpData; + tensorData.push_back(tmpData); + } + + auto *checkTensor = new CheckTensor(dims, tensorData); + this->calibData.insert(std::make_pair(tensorName, checkTensor)); + } + inFile.close(); + MS_LOG(INFO) << "Finish reading calibData file"; + return RET_OK; +} + +// tensorData need to be converter first +float Benchmark::CompareData(const std::string &nodeName, std::vector msShape, float *msTensorData) { + auto iter = this->calibData.find(nodeName); + if (iter != this->calibData.end()) { + std::vector castedMSShape; + size_t shapeSize = 1; + for (int64_t dim : msShape) { + castedMSShape.push_back(size_t(dim)); + shapeSize *= dim; + } + + CheckTensor *calibTensor = iter->second; + if (calibTensor->shape != castedMSShape) { + std::ostringstream oss; + oss << "Shape of mslite output("; + for (auto dim : castedMSShape) { + oss << dim << ","; + } + oss << ") and shape source model output("; + for (auto dim : calibTensor->shape) { + oss << dim << ","; + } + oss << ") are different"; + MS_LOG(ERROR) << "%s", oss.str().c_str(); + return RET_ERROR; + } + size_t errorCount = 0; + float meanError = 0; + std::cout << "Data of node " << nodeName << " : "; + for (size_t j = 0; j < shapeSize; j++) { + if (j < 50) { + std::cout << msTensorData[j] << " "; + } + + if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { + MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; + return RET_ERROR; + } + + auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j)); + auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j)); + if (absoluteError > tolerance) { + // just assume that atol = rtol + meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN); + errorCount++; + } + } + std::cout << std::endl; + if (meanError > 0.0f) { + meanError /= errorCount; + } + + if (meanError <= 0.0000001) { + std::cout << "Mean bias of node " << nodeName << " : 0%" << std::endl; + } else { + std::cout << "Mean bias of node " << nodeName << " : " << meanError * 100 << "%" << std::endl; + } + return meanError; + } else { + MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str(); + return RET_ERROR; + } +} + +int Benchmark::CompareOutput() { + std::cout << "================ Comparing Output data ================" << std::endl; + float totalBias = 0; + int totalSize = 0; + bool hasError = false; + for (const auto &calibTensor : calibData) { + std::string nodeName = calibTensor.first; + auto tensors = session->GetOutputsByName(nodeName); + if (tensors.empty()) { + MS_LOG(ERROR) << "Cannot find output node: " << nodeName.c_str() << " , compare output data fail."; + return RET_ERROR; + } + // make sure tensor size is 1 + if (tensors.size() != 1) { + MS_LOG(ERROR) << "Only support 1 tensor with a name now."; + return RET_ERROR; + } + auto &tensor = tensors.front(); + MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT); + MS_ASSERT(tensor->GetData() != nullptr); + float bias = CompareData(nodeName, tensor->shape(), static_cast(tensor->MutableData())); + if (bias >= 0) { + totalBias += bias; + totalSize++; + } else { + hasError = true; + break; + } + } + + if (!hasError) { + float meanBias; + if (totalSize != 0) { + meanBias = totalBias / totalSize * 100; + } else { + meanBias = 0; + } + + std::cout << "Mean bias of all nodes: " << meanBias << "%" << std::endl; + std::cout << "=======================================================" << std::endl << std::endl; + + if (meanBias > this->_flags->accuracyThreshold) { + MS_LOG(ERROR) << "Mean bias of all nodes is too big: " << meanBias << "%%"; + return RET_ERROR; + } else { + return RET_OK; + } + } else { + MS_LOG(ERROR) << "Error in CompareData"; + std::cout << "=======================================================" << std::endl << std::endl; + return RET_ERROR; + } +} + +int Benchmark::MarkPerformance() { + MS_LOG(INFO) << "Running warm up loops..."; + for (int i = 0; i < _flags->warmUpLoopCount; i++) { + auto status = session->RunGraph(); + if (status != 0) { + MS_LOG(ERROR) << "Inference error %d" << status; + return status; + } + } + + MS_LOG(INFO) << "Running benchmark loops..."; + uint64_t timeMin = 1000000; + uint64_t timeMax = 0; + uint64_t timeAvg = 0; + + for (int i = 0; i < _flags->loopCount; i++) { + session->BindThread(true); + auto start = GetTimeUs(); + auto status = session->RunGraph(); + if (status != 0) { + MS_LOG(ERROR) << "Inference error %d" << status; + return status; + } + + auto end = GetTimeUs(); + auto time = end - start; + timeMin = std::min(timeMin, time); + timeMax = std::max(timeMax, time); + timeAvg += time; + + session->BindThread(false); + } + if (_flags->loopCount > 0) { + timeAvg /= _flags->loopCount; + MS_LOG(INFO) << "Model = " << _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1).c_str() + << ", NumThreads = " << _flags->numThreads << ", MinRunTime = " << timeMin / 1000.0f + << ", MaxRuntime = " << timeMax / 1000.0f << ", AvgRunTime = " << timeAvg / 1000.0f; + printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n", + _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1).c_str(), _flags->numThreads, + timeMin / 1000.0f, timeMax / 1000.0f, timeAvg / 1000.0f); + } + return RET_OK; +} + +int Benchmark::MarkAccuracy() { + MS_LOG(INFO) << "MarkAccuracy"; + for (size_t i = 0; i < msInputs.size(); i++) { + MS_ASSERT(msInputs.at(i) != nullptr); + MS_ASSERT(msInputs.at(i)->data_type() == TypeId::kNumberTypeFloat32); + auto inData = reinterpret_cast(msInputs.at(i)->MutableData()); + std::cout << "InData" << i << ": "; + for (size_t j = 0; j < 20; j++) { + std::cout << inData[j] << " "; + } + std::cout << std::endl; + } + auto status = session->RunGraph(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Inference error " << status; + return status; + } + + status = ReadCalibData(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Read calib data error " << status; + return status; + } + + status = CompareOutput(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Compare output error " << status; + return status; + } + return RET_OK; +} + +int Benchmark::RunBenchmark(const std::string &deviceType) { + auto startPrepareTime = GetTimeUs(); + // Load graph + std::string modelName = _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1); + + MS_LOG(INFO) << "start reading model file"; + size_t size = 0; + char *graphBuf = ReadFile(_flags->modelPath.c_str(), &size); + if (graphBuf == nullptr) { + MS_LOG(ERROR) << "Read model file failed while running %s", modelName.c_str(); + return RET_ERROR; + } + auto model = lite::Model::Import(graphBuf, size); + if (model == nullptr) { + MS_LOG(ERROR) << "Import model file failed while running %s", modelName.c_str(); + delete[](graphBuf); + return RET_ERROR; + } + delete[](graphBuf); + auto context = new(std::nothrow) lite::Context; + if (context == nullptr) { + MS_LOG(ERROR) << "New context failed while running %s", modelName.c_str(); + return RET_ERROR; + } + if (_flags->device == "CPU") { + context->device_ctx_.type = lite::DT_CPU; + } else if (_flags->device == "GPU") { + context->device_ctx_.type = lite::DT_GPU; + } else { + context->device_ctx_.type = lite::DT_NPU; + } + + if (_flags->cpuBindMode == -1) { + context->cpu_bind_mode_ = MID_CPU; + } else if (_flags->cpuBindMode == 0) { + context->cpu_bind_mode_ = HIGHER_CPU; + } else { + context->cpu_bind_mode_ = NO_BIND; + } + context->thread_num_ = _flags->numThreads; + session = session::LiteSession::CreateSession(context); + delete(context); + if (session == nullptr) { + MS_LOG(ERROR) << "CreateSession failed while running %s", modelName.c_str(); + return RET_ERROR; + } + auto ret = session->CompileGraph(model.get()); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CompileGraph failed while running %s", modelName.c_str(); + delete(session); + return ret; + } + msInputs = session->GetInputs(); + auto endPrepareTime = GetTimeUs(); +#if defined(__arm__) + MS_LOG(INFO) << "PrepareTime = " << (endPrepareTime - startPrepareTime) / 1000 << " ms"; + printf("PrepareTime = %lld ms, ", (endPrepareTime - startPrepareTime) / 1000); +#else + MS_LOG(INFO) << "PrepareTime = " << (endPrepareTime - startPrepareTime) / 1000 << " ms "; + printf("PrepareTime = %ld ms, ", (endPrepareTime - startPrepareTime) / 1000); +#endif + + // Load input + MS_LOG(INFO) << "start generate input data"; + auto status = LoadInput(); + if (status != 0) { + MS_LOG(ERROR) << "Generate input data error"; + delete(session); + return status; + } + if (!_flags->calibDataPath.empty()) { + status = MarkAccuracy(); + if (status != 0) { + MS_LOG(ERROR) << "Run MarkAccuracy error: %d" << status; + delete(session); + return status; + } + } else { + status = MarkPerformance(); + if (status != 0) { + MS_LOG(ERROR) << "Run MarkPerformance error: %d" << status; + delete(session); + return status; + } + } + + if (cleanData) { + for (auto &data : calibData) { + data.second->shape.clear(); + data.second->data.clear(); + delete data.second; + } + calibData.clear(); + } + + delete(session); + return RET_OK; +} + +void BenchmarkFlags::InitInputDataList() { + char *input_list = new char[this->inDataPath.length() + 1]; + snprintf(input_list, this->inDataPath.length() + 1, "%s", this->inDataPath.c_str()); + char *cur_input; + const char *split_c = ","; + cur_input = strtok(input_list, split_c); + while (cur_input) { + input_data_list.emplace_back(cur_input); + cur_input = strtok(nullptr, split_c); + } + delete[] input_list; +} + +void BenchmarkFlags::InitResizeDimsList() { + std::string content; + content = this->resizeDimsIn; + std::vector shape; + auto shapeStrs = StringSplit(content, std::string(DELIM_COLON)); + for (const auto &shapeStr : shapeStrs) { + shape.clear(); + auto dimStrs = StringSplit(shapeStr, std::string(DELIM_COMMA)); + std::cout << "Resize Dims: "; + for (const auto &dimStr : dimStrs) { + std::cout << dimStr << " "; + shape.emplace_back(static_cast(std::stoi(dimStr))); + } + std::cout << std::endl; + this->resizeDims.emplace_back(shape); + } +} + +int Benchmark::Init() { + if (this->_flags == nullptr) { + return 1; + } + MS_LOG(INFO) << "ModelPath = " << this->_flags->modelPath; + MS_LOG(INFO) << "InDataPath = " << this->_flags->inDataPath; + MS_LOG(INFO) << "InDataType = " << this->_flags->inDataTypeIn; + MS_LOG(INFO) << "LoopCount = " << this->_flags->loopCount; + MS_LOG(INFO) << "DeviceType = " << this->_flags->device; + MS_LOG(INFO) << "AccuracyThreshold = " << this->_flags->accuracyThreshold; + MS_LOG(INFO) << "WarmUpLoopCount = " << this->_flags->warmUpLoopCount; + MS_LOG(INFO) << "NumThreads = " << this->_flags->numThreads; + MS_LOG(INFO) << "calibDataPath = " << this->_flags->calibDataPath; + if (this->_flags->cpuBindMode == -1) { + MS_LOG(INFO) << "cpuBindMode = MID_CPU"; + } else if (this->_flags->cpuBindMode == 1) { + MS_LOG(INFO) << "cpuBindMode = HIGHER_CPU"; + } else { + MS_LOG(INFO) << "cpuBindMode = NO_BIND"; + } + + this->_flags->inDataType = this->_flags->inDataTypeIn == "img" ? kImage : kBinary; + + if (_flags->modelPath.empty()) { + MS_LOG(ERROR) << "modelPath is required"; + return 1; + } + _flags->InitInputDataList(); + _flags->InitResizeDimsList(); + if (!_flags->resizeDims.empty() && _flags->resizeDims.size() != _flags->input_data_list.size()) { + MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; + return RET_ERROR; + } + + return RET_OK; +} + +int RunBenchmark(int argc, const char **argv) { + BenchmarkFlags flags; + Option err = flags.ParseFlags(argc, argv); + + if (err.IsSome()) { + std::cerr << err.Get() << std::endl; + std::cerr << flags.Usage() << std::endl; + return RET_ERROR; + } + + if (flags.help) { + std::cerr << flags.Usage() << std::endl; + return RET_OK; + } + + Benchmark mBenchmark(&flags); + auto status = mBenchmark.Init(); + if (status != 0) { + MS_LOG(ERROR) << "Benchmark init Error : " << status; + return RET_ERROR; + } + + if (flags.device == "NPU") { + status = mBenchmark.RunBenchmark("NPU"); + } else { + status = mBenchmark.RunBenchmark("CPU"); + } + + if (status != 0) { + MS_LOG(ERROR) << "Run Benchmark " << flags.modelPath.substr(flags.modelPath.find_last_of(DELIM_SLASH) + 1).c_str() + << " Failed : " << status; + return RET_ERROR; + } + + MS_LOG(INFO) << "Run Benchmark " << flags.modelPath.substr(flags.modelPath.find_last_of(DELIM_SLASH) + 1).c_str() + << " Success."; + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h new file mode 100644 index 0000000000..74ac170928 --- /dev/null +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINNIE_BENCHMARK_BENCHMARK_H_ +#define MINNIE_BENCHMARK_BENCHMARK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "tools/common/flag_parser.h" +#include "src/common/file_utils.h" +#include "src/common/utils.h" +#include "schema/model_generated.h" +#include "include/model.h" +#include "include/lite_session.h" +#include "include/inference.h" + +namespace mindspore::lite { +enum MS_API InDataType { kImage = 0, kBinary = 1 }; + +constexpr float relativeTolerance = 1e-5; +constexpr float absoluteTolerance = 1e-8; + +struct MS_API CheckTensor { + CheckTensor(const std::vector &shape, const std::vector &data) { + this->shape = shape; + this->data = data; + } + std::vector shape; + std::vector data; +}; + +class MS_API BenchmarkFlags : public virtual FlagParser { + public: + BenchmarkFlags() { + // common + AddFlag(&BenchmarkFlags::modelPath, "modelPath", "Input model path", ""); + AddFlag(&BenchmarkFlags::inDataPath, "inDataPath", "Input data path, if not set, use random input", ""); + AddFlag(&BenchmarkFlags::inDataTypeIn, "inDataType", "Input data type. img | bin", "bin"); + AddFlag(&BenchmarkFlags::omModelPath, "omModelPath", "OM model path, only required when device is NPU", ""); + AddFlag(&BenchmarkFlags::device, "device", "CPU | NPU | GPU", "CPU"); + AddFlag(&BenchmarkFlags::cpuBindMode, "cpuBindMode", + "Input -1 for MID_CPU, 1 for HIGHER_CPU, 0 for NO_BIND, defalut value: 1", 1); + // MarkPerformance + AddFlag(&BenchmarkFlags::loopCount, "loopCount", "Run loop count", 10); + AddFlag(&BenchmarkFlags::numThreads, "numThreads", "Run threads number", 2); + AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3); + // MarkAccuracy + AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", ""); + AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5); + // Resize + AddFlag(&BenchmarkFlags::resizeDimsIn, "resizeDims", "Dims to resize to", ""); + } + + ~BenchmarkFlags() override = default; + + void InitInputDataList(); + + void InitResizeDimsList(); + + public: + // common + std::string modelPath; + std::string inDataPath; + std::vector input_data_list; + InDataType inDataType; + std::string inDataTypeIn; + int cpuBindMode = 1; + // MarkPerformance + int loopCount; + int numThreads; + int warmUpLoopCount; + // MarkAccuracy + std::string calibDataPath; + float accuracyThreshold; + // Resize + std::string resizeDimsIn; + std::vector> resizeDims; + + std::string omModelPath; + std::string device; +}; + +class MS_API Benchmark { + public: + explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {} + + virtual ~Benchmark() = default; + + int Init(); + int RunBenchmark(const std::string &deviceType = "NPU"); + // int RunNPUBenchmark(); + + private: + // call GenerateInputData or ReadInputFile to init inputTensors + int LoadInput(); + + // call GenerateRandomData to fill inputTensors + int GenerateInputData(); + + int GenerateRandomData(size_t size, void *data); + + int ReadInputFile(); + + int ReadCalibData(); + + int CompareOutput(); + + float CompareData(const std::string &nodeName, std::vector msShape, float *msTensorData); + + int MarkPerformance(); + + int MarkAccuracy(); + + private: + BenchmarkFlags *_flags; + session::LiteSession *session; + std::vector msInputs; + std::unordered_map> msOutputs; + std::unordered_map calibData; + bool cleanData = true; +}; + +int MS_API RunBenchmark(int argc, const char **argv); +} // namespace mindspore::lite +#endif // MINNIE_BENCHMARK_BENCHMARK_H_ diff --git a/mindspore/lite/tools/benchmark/main.cc b/mindspore/lite/tools/benchmark/main.cc new file mode 100644 index 0000000000..10d2204783 --- /dev/null +++ b/mindspore/lite/tools/benchmark/main.cc @@ -0,0 +1,20 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/benchmark/benchmark.h" + +int main(int argc, const char **argv) { return mindspore::lite::RunBenchmark(argc, argv); } + diff --git a/mindspore/lite/tools/common/CMakeLists.txt b/mindspore/lite/tools/common/CMakeLists.txt new file mode 100755 index 0000000000..0250fa6378 --- /dev/null +++ b/mindspore/lite/tools/common/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_library(converter_common_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/storage.cc + ) +set_target_properties(converter_common_mid PROPERTIES COMPILE_FLAGS "-Wno-unused-function") diff --git a/mindspore/lite/tools/common/converter_op_utils.h b/mindspore/lite/tools/common/converter_op_utils.h new file mode 100644 index 0000000000..20356b5ade --- /dev/null +++ b/mindspore/lite/tools/common/converter_op_utils.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_CONVERTER_COMMON_OP_UTILS_H_ +#define PREDICT_CONVERTER_COMMON_OP_UTILS_H_ + +#include +#include +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; } +inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) { + return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT)); +} +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_CONVERTER_COMMON_OP_UTILS_H_ + diff --git a/mindspore/lite/tools/common/flag_parser.cc b/mindspore/lite/tools/common/flag_parser.cc new file mode 100755 index 0000000000..3ea4baac9b --- /dev/null +++ b/mindspore/lite/tools/common/flag_parser.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/flag_parser.h" + +namespace mindspore { +namespace lite { +// parse flags read from command line +Option FlagParser::ParseFlags(int argc, const char *const *argv, bool supportUnknown, + bool supportDuplicate) { + MS_ASSERT(argv != nullptr); + const int FLAG_PREFIX_LEN = 2; + // Get binary name + binName = GetFileName(argv[0]); + + std::multimap> keyValues; + for (int i = 1; i < argc; i++) { + std::string tmp = argv[i]; + Trim(&tmp); + const std::string flagItem(tmp); + + if (flagItem == "--") { + break; + } + + if (flagItem.find("--") == std::string::npos) { + continue; + } + + std::string key; + Option value = Option(None()); + + size_t pos = flagItem.find_first_of("="); + if (pos == std::string::npos && flagItem.find("--no-") != std::string::npos) { + key = flagItem.substr(FLAG_PREFIX_LEN); + } else if (pos == std::string::npos) { + key = flagItem.substr(FLAG_PREFIX_LEN); + } else { + key = flagItem.substr(FLAG_PREFIX_LEN, pos - FLAG_PREFIX_LEN); + value = Option(flagItem.substr(pos + 1)); + } + + keyValues.insert(std::pair>(key, value)); + } + + Option ret = Option(InnerParseFlags(&keyValues)); + if (ret.IsSome()) { + return Option(ret.Get()); + } + + return Option(None()); +} + +bool FlagParser::GetRealFlagName(std::string *flagName, const std::string &oriFlagName) { + MS_ASSERT(flagName != nullptr); + const int BOOL_TYPE_FLAG_PREFIX_LEN = 3; + bool opaque = false; + if (StartsWithPrefix(oriFlagName, "no-")) { + *flagName = oriFlagName.substr(BOOL_TYPE_FLAG_PREFIX_LEN); + opaque = true; + } else { + *flagName = oriFlagName; + } + return opaque; +} + +// Inner parse function +Option FlagParser::InnerParseFlags(std::multimap> *keyValues) { + MS_ASSERT(keyValues != nullptr); + for (auto it = keyValues->begin(); it != keyValues->end(); ++it) { + std::string flagName; + bool opaque = GetRealFlagName(&flagName, (*it).first); + Option flagValue = (*it).second; + + auto item = flags.find(flagName); + if (item == flags.end()) { + return Option(std::string(flagName + " is not a valid flag")); + } + FlagInfo *flag = &(item->second); + if (flag == nullptr) { + return Option("Failed: flag is nullptr"); + } + if (flag->isParsed) { + return Option("Failed: already parsed flag: " + flagName); + } + std::string tmpValue; + if (!flag->isBoolean) { + if (opaque) { + return Option(flagName + " is not a boolean type"); + } + if (flagValue.IsNone()) { + return Option("No value provided for non-boolean type: " + flagName); + } + tmpValue = flagValue.Get(); + } else { + if (flagValue.IsNone() || flagValue.Get().empty()) { + tmpValue = !opaque ? "true" : "false"; + } else if (!opaque) { + tmpValue = flagValue.Get(); + } else { + return Option(std::string("Boolean flag can not have non-empty value")); + } + } + // begin to parse value + Option ret = flag->parse(this, tmpValue); + if (ret.IsNone()) { + return Option("Failed to parse value for: " + flag->flagName); + } + flag->isParsed = true; + } + + // to check flags not given in command line but added as in constructor + for (auto &flag : flags) { + if (flag.second.isRequired && !flag.second.isParsed) { + return Option("Error, value of '" + flag.first + "' not provided"); + } + } + + return Option(None()); +} + +void Replaceall(std::string *str, const std::string &oldValue, const std::string &newValue) { + if (str == nullptr) { + // MS_LOG(ERROR)("Input str is nullptr"); + return; + } + while (true) { + std::string::size_type pos(0); + if ((pos = str->find(oldValue)) != std::string::npos) { + str->replace(pos, oldValue.length(), newValue); + } else { + break; + } + } +} + +std::string FlagParser::Usage(const Option &usgMsg) const { + // first line, brief of the usage + std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; + // usage of bin name + usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; + // help line of help message, usageLine:message of parametors + std::string helpLine = ""; + std::string usageLine = ""; + uint32_t i = 0; + for (auto flag = flags.begin(); flag != flags.end(); flag++) { + std::string flagName = flag->second.flagName; + std::string helpInfo = flag->second.helpInfo; + // parameter line + std::string thisLine = flag->second.isBoolean ? " --[no-]" + flagName : " --" + flagName + "=VALUE"; + if (++i <= flags.size()) { + // add parameter help message of each line + thisLine += " " + helpInfo; + Replaceall(&helpInfo, "\n\r", "\n"); + usageLine += thisLine + "\n"; + } else { + // breif help message + helpLine = thisLine + " " + helpInfo + "\n"; + } + } + // total usage is brief of usage+ brief of bin + help message + brief of + // parameters + return usageString + helpLine + usageLine; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/common/flag_parser.h b/mindspore/lite/tools/common/flag_parser.h new file mode 100755 index 0000000000..e8e87ac699 --- /dev/null +++ b/mindspore/lite/tools/common/flag_parser.h @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_COMMON_FLAG_PARSER_H_ +#define PREDICT_COMMON_FLAG_PARSER_H_ + +#include +#include +#include +#include + +#include "src/common/utils.h" +#include "tools/common/option.h" + +namespace mindspore { +namespace lite { +struct FlagInfo; + +struct Nothing {}; + +class FlagParser { + public: + FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", false); } + + virtual ~FlagParser() {} + + // only support read flags from command line + virtual Option ParseFlags(int argc, const char *const *argv, bool supportUnknown = false, + bool supportDuplicate = false); + std::string Usage(const Option &usgMsg = Option(None())) const; + + template + void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2); + template + void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); + + // non-Option type fields in class + template + void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2); + + template + void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); + + template + void AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo); + + // Option-type fields + template + void AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo); + bool help; + + protected: + template + void AddFlag(std::string Flags::*t1, const std::string &flagName, const std::string &helpInfo, const char *t2) { + AddFlag(t1, flagName, helpInfo, std::string(t2)); + } + + std::string binName; + Option usageMsg; + + private: + struct FlagInfo { + std::string flagName; + bool isRequired; + bool isBoolean; + std::string helpInfo; + bool isParsed; + std::function(FlagParser *, const std::string &)> parse; + }; + + inline void AddFlag(const FlagInfo &flag); + + // construct a temporary flag + template + void ConstructFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); + + // construct a temporary flag + template + void ConstructFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); + + Option InnerParseFlags(std::multimap> *values); + + bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName); + + std::map flags; +}; + +// convert to std::string +template +Option ConvertToString(T Flags::*t, const FlagParser &baseFlag) { + const Flags *flag = dynamic_cast(&baseFlag); + if (flag != nullptr) { + return std::to_string(flag->*t); + } + + return Option(None()); +} + +// construct for a Option-type flag +template +void FlagParser::ConstructFlag(Option Flags::*t1, const std::string &flagName, const std::string &helpInfo, + FlagInfo *flag) { + if (flag == nullptr) { + // MS_LOGE("FlagInfo is nullptr"); + return; + } + flag->flagName = flagName; + flag->helpInfo = helpInfo; + + flag->isBoolean = typeid(T) == typeid(bool); + flag->isParsed = false; +} + +// construct a temporary flag +template +void FlagParser::ConstructFlag(T Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) { + if (flag == nullptr) { + // MS_LOGE("FlagInfo is nullptr"); + return; + } + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + flag->flagName = flagName; + flag->helpInfo = helpInfo; + flag->isBoolean = typeid(T) == typeid(bool); + flag->isParsed = false; +} + +inline void FlagParser::AddFlag(const FlagInfo &flagItem) { flags[flagItem.flagName] = flagItem; } + +template +void FlagParser::AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo) { + if (t == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + AddFlag(t, flagName, helpInfo, static_cast(nullptr)); +} + +template +void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + AddFlag(t1, flagName, helpInfo, &t2); +} + +// just for test +template +void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + AddFlag(t1, flagName, helpInfo, &t2); +} + +template +void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + + FlagInfo flagItem; + + // flagItem is as a output parameter + ConstructFlag(t1, flagName, helpInfo, flagItem); + flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { + if (base != nullptr) { + Option ret = Option(GenericParseValue(value)); + if (ret.IsNone()) { + return Option(None()); + } else { + *t1 = ret.Get(); + } + } + + return Option(Nothing()); + }; + + if (t2 != nullptr) { + flagItem.isRequired = false; + *t1 = *t2; + } + + flagItem.helpInfo += + !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; + if (t2 != nullptr) { + flagItem.helpInfo += ToString(*t2).Get(); + } + flagItem.helpInfo += ")"; + + // add this flag to a std::map + AddFlag(flagItem); +} + +template +void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + + Flags *flag = dynamic_cast(this); + if (flag == nullptr) { + return; + } + + FlagInfo flagItem; + + // flagItem is as a output parameter + ConstructFlag(t1, flagName, helpInfo, &flagItem); + flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { + Flags *flag = dynamic_cast(base); + if (base != nullptr) { + Option ret = Option(GenericParseValue(value)); + if (ret.IsNone()) { + return Option(None()); + } else { + flag->*t1 = ret.Get(); + } + } + + return Option(Nothing()); + }; + + if (t2 != nullptr) { + flagItem.isRequired = false; + flag->*t1 = *t2; + } else { + flagItem.isRequired = true; + } + + flagItem.helpInfo += + !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; + if (t2 != nullptr) { + flagItem.helpInfo += ToString(*t2).Get(); + } + flagItem.helpInfo += ")"; + + // add this flag to a std::map + AddFlag(flagItem); +} + +// option-type add flag +template +void FlagParser::AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo) { + if (t == nullptr) { + // MS_LOGE("t is nullptr"); + return; + } + + Flags *flag = dynamic_cast(this); + if (flag == nullptr) { + // MS_LOGE("dynamic_cast failed"); + return; + } + + FlagInfo flagItem; + // flagItem is as a output parameter + ConstructFlag(t, flagName, helpInfo, &flagItem); + flagItem.isRequired = false; + flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option { + Flags *flag = dynamic_cast(base); + if (base != nullptr) { + Option ret = Option(GenericParseValue(value)); + if (ret.IsNone()) { + return Option(None()); + } else { + flag->*t = Option(Some(ret.Get())); + } + } + + return Option(Nothing()); + }; + + // add this flag to a std::map + AddFlag(flagItem); +} +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_COMMON_FLAG_PARSER_H_ + diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc new file mode 100755 index 0000000000..1b84029174 --- /dev/null +++ b/mindspore/lite/tools/common/graph_util.cc @@ -0,0 +1,671 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/graph_util.h" +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "tools/common/tensor_util.h" +#include "tools/common/node_util.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +OpDefCopyer GetSimpleOpCopyer() { + return [](std::unique_ptr &inCNode) -> std::unique_ptr { + std::unique_ptr newCNode(new CNodeT); + + newCNode->name = inCNode->name; + newCNode->quantType = inCNode->quantType; + newCNode->primitive = std::make_unique(); + newCNode->primitive->value.type = inCNode->primitive->value.type; + // newCNode->quantParam.clear(); + // for (size_t i = 0; i < inCNode->quantParam.size(); i++) { + // auto &quantParam = inCNode->quantParam.at(i); + // auto quantParamCopy = CopyQuantParamArrayT(quantParam); + // if (quantParamCopy == nullptr) { + // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", inOpDef->name.c_str()); + // return nullptr; + // } + // newCNode->quantParam.emplace_back(std::move(quantParamCopy)); + // } + return std::move(newCNode); + }; +} + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) { + return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx); +} + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) { + std::vector inputIndexes; + if (inputIndexIdx == -1) { + inputIndexes = node.inputIndex; + } else { + MS_ASSERT(node.inputIndex.size() > inputIndexIdx); + inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx)); + } + std::set inputNodeIdx; + for (uint32_t inputIdx : inputIndexes) { + auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx); + inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end()); + } + std::vector ret; + ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end()); + return ret; +} + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, + const int outputIndexIdx) { + return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx); +} + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) { + std::vector outputIndexes; + if (outputIndexIdx == -1) { + outputIndexes = node.outputIndex; + } else { + MS_ASSERT(node.outputIndex.size() > outputIndexIdx); + outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx)); + } + std::set outputNodeIdx; + for (uint32_t outputIdx : outputIndexes) { + auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx); + outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end()); + } + std::vector ret; + ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end()); + return ret; +} + +std::vector GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { + std::vector preNodeIdx; + for (size_t i = 0; i < graphT.nodes.size(); i++) { + auto &oldNode = graphT.nodes.at(i); + if (oldNode == nullptr) { + continue; + } + auto outputIndexes = oldNode->outputIndex; + if (IsContain(outputIndexes, tensorIdx)) { + preNodeIdx.emplace_back(i); + } + } + return std::move(preNodeIdx); +} + +std::vector GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { + std::vector postNodeIdx; + for (size_t i = 0; i < graphT.nodes.size(); i++) { + auto &oldNode = graphT.nodes.at(i); + if (oldNode == nullptr) { + continue; + } + auto inputIndexes = oldNode->inputIndex; + if (IsContain(inputIndexes, tensorIdx)) { + postNodeIdx.emplace_back(i); + } + } + return std::move(postNodeIdx); +} + +STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(node != nullptr); + size_t nodeIdx = 0; + for (size_t i = 0; i < graphT->nodes.size(); i++) { + auto &inNode = graphT->nodes.at(i); + MS_ASSERT(inNode != nullptr); + if (inNode->name == node->name) { + nodeIdx = i; + break; + } + } + auto inputTensorIdxes = node->inputIndex; + auto outputTensorIdxes = node->outputIndex; + if (inputTensorIdxes.empty()) { + // MS_LOG(ERROR)("Node %s should has no inputs", node->name.c_str()); + return RET_ERROR; + } + if (outputTensorIdxes.size() != 1) { + // MS_LOG(ERROR)("FakeQuantNode %s should has 1 output, in fact: %zu", node->name.c_str(), + // outputTensorIdxes.size()); + return RET_ERROR; + } + auto inDataTensorIdx = inputTensorIdxes.front(); + auto outDataTensorIdx = outputTensorIdxes.front(); + + MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); + const auto &inDataTensor = graphT->allTensors.at(inDataTensorIdx); + MS_ASSERT(inDataTensor != nullptr); + auto &gOutTensorIdx = graphT->outputIndex; + for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + + // find poseNode + auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + } + + // todo whether need to remove weightInputTensores + // remove all node's outputTensors + RemoveTensor(graphT, outputTensorIdxes); + node->inputIndex.clear(); + node->outputIndex.clear(); + + return RET_OK; +} + +STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) { + MS_ASSERT(graph != nullptr); + /* + if (graph->subgraphs.size() <= subGraphIdx) { + //MS_LOG(ERROR)("subGraphIdx out of range: %zu", subGraphIdx); + return RET_PARAM_INVALID; + } + */ + // return IsolateOneWayNode(graph->subgraphs.at(subGraphIdx).get(), nodeIdx, removeTensor); + return IsolateOneWayNode(graph, nodeIdx, removeTensor); +} + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) { + MS_ASSERT(graphT != nullptr); + if (graphT->nodes.size() <= nodeIdx) { + // MS_LOG(ERROR)("nodeIdx out of range: %zu", nodeIdx); + return RET_PARAM_INVALID; + } + + CNodeT *node = graphT->nodes.at(nodeIdx).get(); + auto inputTensorIdxes = node->inputIndex; + auto outputTensorIdxes = node->outputIndex; + auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx); + if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) { + // MS_LOG(ERROR)("Only support node who has no more than one input and one output"); + return RET_ERROR; + } + if (inputTensorIdxes.empty()) { + // MS_LOG(ERROR)("Error, %zuth node has no input tensor", nodeIdx); + return RET_ERROR; + } + auto inDataTensorIdx = inputTensorIdxes.front(); + if (!outputTensorIdxes.empty()) { + auto outDataTensorIdx = outputTensorIdxes.front(); + MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); + MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr); + auto &gOutTensorIdx = graphT->outputIndex; + for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + // find poseNode + auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + } + } + + if (removeTensor) { + // now all node's outputTensors are useless + // remove all node's outputTensors + auto status = RemoveTensor(graphT, outputTensorIdxes); + if (status != RET_OK) { + // MS_LOG(ERROR)("RemoveOutputTensors of node %s failed", node->name.c_str()); + return RET_ERROR; + } + } + node->inputIndex.clear(); + node->outputIndex.clear(); + return RET_OK; +} + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(node != nullptr); + bool isSubNode = false; + size_t nodeIdx = 0; + for (size_t i = 0; i < graphT->nodes.size(); i++) { + auto &inNode = graphT->nodes.at(i); + if (inNode->name == node->name) { + isSubNode = true; + nodeIdx = i; + break; + } + } + if (!isSubNode) { + // MS_LOG(ERROR)("Node %s is not in graphT %s", node->name.c_str(), graphT->name.c_str()); + return RET_PARAM_INVALID; + } else { + return IsolateOneWayNode(graphT, nodeIdx, removeTensor); + } +} + +STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector toDeleteTensorIdxes, bool forceDelete) { + for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) { + uint32_t deleteIdx = *iter; + if (!forceDelete) { + if (GetRefCount(graphT, deleteIdx) > 1) { + iter++; + continue; + } + } + // update graph input indexes + for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { + if (*gInIdx > deleteIdx) { + (*gInIdx)--; + } + } + // update graph output indexes + for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { + if (*gOutIdx > deleteIdx) { + (*gOutIdx)--; + } + } + // update nodes indexes + for (auto nodeIter = graphT->nodes.begin(); nodeIter != graphT->nodes.end(); nodeIter++) { + // update nodes input indexes + UpdateNodeIndex((*nodeIter).get(), deleteIdx); + } + // update deleteTensorIdx + for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) { + if (*selfIt > deleteIdx) { + (*selfIt)--; + } + } + graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx); + iter = toDeleteTensorIdxes.erase(iter); + } + return RET_OK; +} + +STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) { + for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) { + if (*inIdxIt == deleteIdx) { + inIdxIt = node->inputIndex.erase(inIdxIt); + } else { + if (*inIdxIt > deleteIdx) { + (*inIdxIt)--; + } + inIdxIt++; + } + } + // update nodes output indexes + for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) { + if (*outIdxIt == deleteIdx) { + outIdxIt = node->outputIndex.erase(outIdxIt); + } else { + if (*outIdxIt > deleteIdx) { + (*outIdxIt)--; + } + outIdxIt++; + } + } + return RET_OK; +} + +STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr tensor, + InsertPlace place) { + if (nodeIdx >= graphT->nodes.size()) { + // MS_LOG(ERROR)("nodeIdx out of range: %du", nodeIdx); + return RET_PARAM_INVALID; + } + graphT->allTensors.emplace_back(std::move(tensor)); + uint32_t newTensorIdx = graphT->allTensors.size() - 1; + auto node = graphT->nodes.at(nodeIdx).get(); + if (place == kBefore) { + node->inputIndex.emplace_back(newTensorIdx); + } else { + node->outputIndex.emplace_back(newTensorIdx); + } + return RET_OK; +} + +STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx, + std::unique_ptr tensor) { + if (nodeIdx >= graphT->nodes.size()) { + // MS_LOG(ERROR)("nodeIdx out of range: %du", nodeIdx); + return RET_PARAM_INVALID; + } + auto node = graphT->nodes.at(nodeIdx).get(); + if (inTensorIdx >= graphT->allTensors.size()) { + // MS_LOG(ERROR)("inTensorIdx out of range: %du", nodeIdx); + return RET_PARAM_INVALID; + } + if (!IsContain(node->inputIndex, inTensorIdx)) { + // MS_LOG(ERROR)("inTensorIdx(%du) is not a inputIdx of node(%du)", inTensorIdx, nodeIdx); + return RET_PARAM_INVALID; + } + graphT->allTensors.at(inTensorIdx).swap(tensor); + return RET_OK; +} + +NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) { + if (existNodeIdx >= graphT->nodes.size()) { + // MS_LOG(ERROR)("nodeIdx out of range: %du", existNodeIdx); + return graphT->nodes.end(); + } + auto nodeIter = graphT->nodes.begin() + existNodeIdx; + MS_ASSERT(nodeIter != graphT->nodes.begin()); + MS_ASSERT((*nodeIter) != nullptr); + return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode); +} + +NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) { + if (place == kBefore) { + return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer); + } else if (place == kAfter) { + return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer); + } else { + // MS_LOG(ERROR)("Invalid InsertPlace : %d", place); + return graphT->nodes.end(); + } +} + +NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, + std::unique_ptr toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) { + auto &existNode = *existNodeIter; + MS_ASSERT(existNode != nullptr); + MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx); + MS_ASSERT(toAddNodeIn != nullptr); + auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx); + MS_ASSERT(graphT->allTensors.size() > preTensorIdx); + + auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx); + if (preNodeIdxes.empty()) { + auto &preTensor = graphT->allTensors.at(preTensorIdx); + MS_ASSERT(preTensor != nullptr); + auto toAddTensor = CopyTensorDefT(preTensor); + if (toAddTensor == nullptr) { + MS_LOG(ERROR) << "Copy TensorT failed"; + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + preTensor->refCount = 0; + preTensor->data.clear(); + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + MS_LOG(ERROR) << "copy toAddNodeIn failed"; + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(toAddTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(preTensorIdx); + for (auto iter = graphT->inputIndex.begin(); iter != graphT->inputIndex.end(); iter++) { + if (*iter == preTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } else { + std::vector> toAddNodes; + int i = 0; + for (size_t preNodeIdx : preNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > preNodeIdx); + auto &preNode = graphT->nodes.at(preNodeIdx); + MS_ASSERT(preNode != nullptr); + auto &preTensor = graphT->allTensors.at(preTensorIdx); + MS_ASSERT(preTensor != nullptr); + auto toAddTensor = CopyTensorDefT(preTensor); + if (toAddTensor == nullptr) { + *errorCode = RET_NULL_PTR; + // MS_LOG(ERROR)("Copy TensorT failed"); + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + // MS_LOG(ERROR)("copy toAddNodeIn failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++); + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(preTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) { + if (*iter == preTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + toAddNodes.emplace_back(std::move(toAddNode)); + } + for (auto &toAddNode : toAddNodes) { + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } + } + *errorCode = RET_OK; + return existNodeIter; +} + +NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, + std::unique_ptr toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) { + auto &existNode = *existNodeIter; + MS_ASSERT(existNode != nullptr); + MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx); + MS_ASSERT(toAddNodeIn != nullptr); + auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx); + MS_ASSERT(graphT->allTensors.size() > postTensorIdx); + + auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx); + if (postNodeIdxes.empty()) { + auto &postTensor = graphT->allTensors.at(postTensorIdx); + MS_ASSERT(postTensor != nullptr); + auto toAddTensor = CopyTensorDefT(postTensor); + if (toAddTensor == nullptr) { + // MS_LOG(ERROR)("Copy TensorT failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + // MS_LOG(ERROR)("copy toAddNodeIn failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(postTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) { + if (*iter == postTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } else { + std::vector> toAddNodes; + int i = 0; + for (size_t postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + auto &postTensor = graphT->allTensors.at(postTensorIdx); + MS_ASSERT(postTensor != nullptr); + auto toAddTensor = CopyTensorDefT(postTensor); + if (toAddTensor == nullptr) { + // MS_LOG(ERROR)("Copy TensorT failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + // MS_LOG(ERROR)("copy toAddNodeIn failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++); + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(postTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + MS_ASSERT(IsContain(postNode->inputIndex, postTensorIdx)); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == postTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + toAddNodes.emplace_back(std::move(toAddNode)); + } + for (auto &toAddNode : toAddNodes) { + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } + } + *errorCode = RET_OK; + return existNodeIter; +} + +STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) { + if (modelFile.size() > fileType.size()) { + if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) { + return RET_OK; + } else { + return RET_ERROR; + } + } else { + return RET_ERROR; + } +} + +std::string GetModelName(const std::string &modelFile) { + std::string modelName = modelFile; + modelName = modelName.substr(modelName.find_last_of('/') + 1); + modelName = modelName.substr(0, modelName.find_last_of('.')); + + srand((unsigned)time(NULL)); + modelName = modelName + std::to_string(rand()); + + return modelName; +} + +OpGraphT *OpGraphT::Build(const schema::MetaGraphT *subGraphDef) { + if (subGraphDef == nullptr) { + // MS_LOG(ERROR)("subGraphDef is nullptr"); + return nullptr; + } + auto graph = std::unique_ptr(new OpGraphT()); + if (graph == nullptr) { + // MS_LOG(ERROR)("malloc opgraph failed"); + return nullptr; + } + + auto &opDefs = subGraphDef->nodes; + + for (auto &opDef : opDefs) { + auto ret = graph->AddEdge(opDef.get(), &opDefs); + if (ret != RET_OK) { + // MS_LOG(ERROR)("%s add edge failed. ret:%d", opDef->name.c_str(), ret); + return nullptr; + } + } + + return graph.release(); +} + +int OpGraphT::AddEdge(const schema::CNodeT *srcNodeDef, const std::vector> *nodeDefs) { + MS_ASSERT(srcNodeDef != nullptr); + MS_ASSERT(nodeDefs != nullptr); + NODE_ID srcId = std::string(srcNodeDef->name); + // for single op condition + AddNode(srcId); + for (auto index : srcNodeDef->outputIndex) { + for (auto &dstNodeDef : *nodeDefs) { + bool find = false; + auto inputIndex = dstNodeDef->inputIndex; + if (std::any_of(inputIndex.begin(), inputIndex.end(), [&index](int i) { return i == index; })) { + find = true; + } + + if (!find) { + continue; + } + NODE_ID dstId = std::string(dstNodeDef->name.c_str()); + auto ret = AddEdge(srcId, dstId); + if (ret != RET_OK) { + return ret; + } + } + } + return RET_OK; +} + +int OpGraphT::AddEdge(NODE_ID srcId, NODE_ID dstId) { + auto srcNode = AddNode(srcId); + if (srcNode == nullptr) { + // MS_LOG(ERROR)("add srcNode failed"); + return RET_ERROR; + } + srcNode->AddOutEdge(dstId); + auto dstNode = AddNode(dstId); + if (dstNode == nullptr) { + // MS_LOG(ERROR)("add dstNode failed"); + return RET_ERROR; + } + dstNode->AddInEdge(srcId); + return RET_OK; +} + +OpGraphT::~OpGraphT() { + for (auto iter : nodes) { + delete iter.second; + } + nodes.clear(); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h new file mode 100644 index 0000000000..818c53502b --- /dev/null +++ b/mindspore/lite/tools/common/graph_util.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_GRAPH_UTIL_H +#define MINDSPORE_PREDICT_GRAPH_UTIL_H + +#include +#include +#include +#include +#include +#include + +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" +#include "src/common/graph_util.h" + +namespace mindspore { +namespace lite { +using STATUS = int; +enum InsertPlace { kBefore, kAfter }; + +using NodeIter = std::vector>::iterator; + +using OpDefCopyer = std::function(std::unique_ptr &)>; + +OpDefCopyer GetSimpleOpCopyer(); + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1); + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, + int inputIndexIdx = -1); + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1); + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, + int outputIndexIdx = -1); + +std::vector GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); + +std::vector GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); + +STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node); + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true); + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true); + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true); + +STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx); + +STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector toDeleteTensorIdxes, bool forceDelete = false); + +STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr tensor, + InsertPlace place = kBefore); + +STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx, + std::unique_ptr tensor); + +NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, + std::unique_ptr toAddNode, STATUS *errorCode, + OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); + +NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, + OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); + +NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); + +NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); + +STATUS ValidateFileStr(const std::string &modelFile, std::string fileType); +std::string GetModelName(const std::string &modelFile); + +class OpGraphT : public OpGraph { + public: + OpGraphT() {} + ~OpGraphT(); + static OpGraphT *Build(const schema::MetaGraphT *subGraphDef); + + private: + int AddEdge(NODE_ID srcId, NODE_ID dstId); + int AddEdge(const schema::CNodeT *srcNodeDef, const std::vector> *nodeDefs); +}; + +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_GRAPH_UTIL_H + diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc new file mode 100644 index 0000000000..2f3312a8cb --- /dev/null +++ b/mindspore/lite/tools/common/node_util.cc @@ -0,0 +1,178 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/node_util.h" +#include +#include +#include "src/common/common.h" +#include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(node != nullptr); + // set quantParam to preNode + for (size_t i = 0; i < node->inputIndex.size(); i++) { + auto preNodeIdexes = GetInputNodeIdx(*graphT, *(node.get()), i); + for (auto preNodeIdx : preNodeIdexes) { + MS_ASSERT(graphT->nodes.size() > preNodeIdx); + auto &preNode = graphT->nodes.at(preNodeIdx); + MS_ASSERT(preNode != nullptr); + // if preNode is not init, it maybe not a quantNode, so skip + // if (preNode->inputIndex.size() + preNode->outputIndex.size() != preNode->quantParam.size()) { + // continue; + // } + auto preNodeOutputIndexes = preNode->outputIndex; + int32_t currentNodeIndexInPre = -1; + for (auto index : preNodeOutputIndexes) { + currentNodeIndexInPre++; + if (index == node->inputIndex.at(i)) { + break; + } + } + MS_ASSERT(currentNodeIndexInPre != -1); + MS_ASSERT(node->quantParam.size() > i); + MS_ASSERT(node->quantParam.at(i) != nullptr); + // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(i)); + // if (quantParamArrayCopy == nullptr) { + // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str()); + // return RET_ERROR; + // } + // preNode->quantParam.at(preNode->inputIndex.size() + currentNodeIndexInPre) = + // std::move(CopyQuantParamArrayT(quantParamArrayCopy)); + } + } + + // set quantParam to postNode + for (size_t i = 0; i < node->outputIndex.size(); i++) { + auto postNodeIdexes = GetOutputNodeIdx(*graphT, *(node.get()), i); + for (auto postNodeIdx : postNodeIdexes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + // if postNode is not init, it maybe not a quantNode, so skip + // if (postNode->inputIndex.size() + postNode->outputIndex.size() != postNode->quantParam.size()) { + // continue; + // } + auto postNodeInputIndexes = postNode->inputIndex; + int32_t currentNodeIndexInPost = -1; + for (auto index : postNodeInputIndexes) { + currentNodeIndexInPost++; + if (index == node->outputIndex.at(i)) { + break; + } + } + MS_ASSERT(currentNodeIndexInPost != -1); + MS_ASSERT(node->quantParam.size() > node->inputIndex.size() + i); + MS_ASSERT(node->quantParam.at(node->inputIndex.size() + i) != nullptr); + // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(node->inputIndex.size() + i)); + // if (quantParamArrayCopy == nullptr) { + // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str()); + // return RET_ERROR; + // } + // postNode->quantParam.at(currentNodeIndexInPost) = std::move(CopyQuantParamArrayT(quantParamArrayCopy)); + } + } + return RET_OK; +} + +static const std::vector nhwcOpList = { + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize, + schema::PrimitiveType_FusedBatchNorm}; + +static const std::vector fp32FullOpList = { + schema::PrimitiveType_Concat, schema::PrimitiveType_Add, + schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32 + +static const std::vector uint8NeedNhwcOpList = {}; + +static const std::vector uint8OpList = { + schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, + schema::PrimitiveType_Activation}; + +std::vector Getfp32FullOpList() { return fp32FullOpList; } + +std::vector GetNhwcOpList() { return nhwcOpList; } + +std::vector GetUint8NhwcOpList() { return uint8NeedNhwcOpList; } + +std::vector GetUint8OpList() { return uint8OpList; } + +STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vector &src_dims, + mindspore::lite::Format dst_format, std::vector *dst_dims) { + if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) { + // MS_LOG(ERROR)("Convert format , src size %lu <3 or src format is equal to dst format,not need convert", + // src_dims.size()); + *dst_dims = src_dims; + return RET_PARAM_INVALID; + } + + std::vector nchw_dim; + switch (src_format) { + case Format_NCHW: + nchw_dim = src_dims; + break; + case Format_NHWC: + if (src_dims.size() == DIM_DEFAULT_SIZE) { + nchw_dim.push_back(src_dims[NHWC_N]); + nchw_dim.push_back(src_dims[NHWC_C]); + nchw_dim.push_back(src_dims[NHWC_H]); + nchw_dim.push_back(src_dims[NHWC_W]); + } else { + nchw_dim.push_back(src_dims[HWC_C]); + nchw_dim.push_back(src_dims[HWC_H]); + nchw_dim.push_back(src_dims[HWC_W]); + } + break; + default: + // MS_LOG(ERROR)("Not support src format: %d", src_format); + return RET_ERROR; + } + + if (nchw_dim.size() == 0) { + // MS_LOG(ERROR)("Param nchw_dim is empty!"); + return RET_ERROR; + } + + switch (dst_format) { + case Format_NCHW: + *dst_dims = nchw_dim; + break; + case Format_NHWC: + if (src_dims.size() == DIM_DEFAULT_SIZE) { + dst_dims->push_back(nchw_dim[NCHW_N]); + dst_dims->push_back(nchw_dim[NCHW_H]); + dst_dims->push_back(nchw_dim[NCHW_W]); + dst_dims->push_back(nchw_dim[NCHW_C]); + } + break; + default: + // MS_LOG(ERROR)("Not support dst format: %d", dst_format); + return RET_ERROR; + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h new file mode 100644 index 0000000000..43e9aef558 --- /dev/null +++ b/mindspore/lite/tools/common/node_util.h @@ -0,0 +1,373 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_NODE_UTIL_H +#define MINDSPORE_PREDICT_NODE_UTIL_H + +#include +#include +#include "schema/inner/model_generated.h" +#include "src/common/common.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +using STATUS = int; +STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node); + +std::vector GetNhwcOpList(); + +std::vector Getfp32FullOpList(); + +std::vector GetUint8NhwcOpList(); + +std::vector GetUint8OpList(); + +class NodeUtils { + public: + static STATUS ConvertDims(schema::Format src_format, const std::vector &src_dims, schema::Format dst_format, + std::vector *dst_dims); + + static void SliceData(std::vector &input, int64_t chunk_size, std::vector &output, int64_t begin, + int64_t out_dim, int64_t stride); + + static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector &input_dims, + std::vector &begin, std::vector &output_dims, + schema::TensorT *output, std::vector &stride); +}; + +// todo check this +enum kTransFilterType { + kKCHW2HWCK, + kKCHW2KHWC, + kCKHW2KHWC, + kCKHW2HWCK, + kKCHW2HWKC, + kCKHW2HWKC, + kHWCK2KCHW, + kHWCK2CKHW, + kHWKC2KCHW, + kHWKC2CKHW, + kNHWC2KCHW, + kNHWC2CKHW, + kNHWC2HWCK, + kKHWC2HWCK, + kCHWK2HWCK, + kKHWC2CHWK, + kCHWK2KHWC +}; + +static STATUS GetFilterDim(std::vector &oriDims, kTransFilterType type, int32_t &filterK, int32_t &filterC, + int32_t &filterH, int32_t &filterW) { + MS_ASSERT(oriDims.size() == 4); + if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC) { + filterK = oriDims.at(KCHW_K); + filterC = oriDims.at(KCHW_C); + filterH = oriDims.at(KCHW_H); + filterW = oriDims.at(KCHW_W); + } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { + filterC = oriDims.at(CKHW_C); + filterK = oriDims.at(CKHW_K); + filterH = oriDims.at(CKHW_H); + filterW = oriDims.at(CKHW_W); + } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { + filterH = oriDims.at(HWCK_H); + filterW = oriDims.at(HWCK_W); + filterC = oriDims.at(HWCK_C); + filterK = oriDims.at(HWCK_K); + } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { + filterH = oriDims.at(HWKC_H); + filterW = oriDims.at(HWKC_W); + filterK = oriDims.at(HWKC_K); + filterC = oriDims.at(HWKC_C); + } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { + filterK = oriDims.at(NHWC_N); + filterH = oriDims.at(NHWC_H); + filterW = oriDims.at(NHWC_W); + filterC = oriDims.at(NHWC_C); + } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { + filterC = oriDims.at(CHWK_C); + filterH = oriDims.at(CHWK_H); + filterW = oriDims.at(CHWK_W); + filterK = oriDims.at(CHWK_K); + } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { + filterK = oriDims.at(KHWC_K); + filterH = oriDims.at(KHWC_H); + filterW = oriDims.at(KHWC_W); + filterC = oriDims.at(KHWC_C); + } else { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + return RET_OK; +} + +static STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW) { + MS_ASSERT(tensor != nullptr); + if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) { + tensor->dims = {filterH, filterW, filterC, filterK}; + } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) { + tensor->dims = {filterH, filterW, filterK, filterC}; + } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) { + tensor->dims = {filterK, filterC, filterH, filterW}; + } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW) { + tensor->dims = {filterC, filterK, filterH, filterW}; + } else if (type == kKHWC2CHWK) { + tensor->dims = {filterC, filterH, filterW, filterK}; + } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) { + tensor->dims = {filterK, filterH, filterW, filterC}; + } else { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + return RET_OK; +} + +template +static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW) { + MS_ASSERT(tensor != nullptr); + int count = filterH * filterW * filterC * filterK; + if (count <= 0) { + MS_LOG(ERROR) << "Dim size invalid"; + return RET_ERROR; + } + std::unique_ptr buf(new (std::nothrow) T[count]); + if (buf == nullptr) { + MS_LOG(ERROR) << "new buf failed"; + return RET_ERROR; + } + + void *originWeightDate = tensor->data.data(); + T *weightData = static_cast(originWeightDate); + + if (weightData == nullptr) { + MS_LOG(ERROR) << "weightData is nullptr"; + return RET_ERROR; + } + T *p1Buff = nullptr; + T *p2Buff = nullptr; + switch (type) { + case kCHWK2HWCK: + case kCHWK2KHWC: { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); + if (type == kCHWK2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kCHWK2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kKHWC2HWCK: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kKCHW2HWCK: + case kKCHW2KHWC: + case kKCHW2HWKC: { + for (int k = 0; k < filterK; ++k) { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + if (type == kKCHW2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kKCHW2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } else { + p2Buff = + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kCKHW2HWCK: + case kCKHW2KHWC: + case kCKHW2HWKC: { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + if (type == kCKHW2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kKCHW2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c)); + } else { + p2Buff = + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kHWCK2KCHW: + case kHWCK2CKHW: { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + if (type == kHWCK2KCHW) { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kHWKC2KCHW: + case kHWKC2CKHW: { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kHWKC2KCHW) { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kNHWC2HWCK: + case kNHWC2KCHW: + case kNHWC2CKHW: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kNHWC2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kNHWC2CKHW) { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kKHWC2CHWK: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + } + + auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +template +static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) { + MS_ASSERT(tensor != nullptr); + std::vector oriDims = tensor->dims; + if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) { + MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); + return RET_ERROR; + } + + int32_t filterH; + int32_t filterW; + int32_t filterC; + int32_t filterK; + auto status = GetFilterDim(oriDims, type, filterK, filterC, filterH, filterW); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetFilterDim failed: " << status; + return status; + } + status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetFilterDim failed: " << status; + return status; + } + status = TransFilterData(tensor, type, filterK, filterC, filterH, filterW); + if (status != RET_OK) { + MS_LOG(ERROR) << "TransFilterData failed: " << status; + return status; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_PREDICT_NODE_UTIL_H + diff --git a/mindspore/lite/tools/common/option.h b/mindspore/lite/tools/common/option.h new file mode 100644 index 0000000000..8b323b7336 --- /dev/null +++ b/mindspore/lite/tools/common/option.h @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_COMMON_OPTION_H_ +#define PREDICT_COMMON_OPTION_H_ + +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +template +struct InnerSome { + explicit InnerSome(const T &t) : _t(std::move(t)) {} + + T _t; +}; + +template +InnerSome::type> Some(T &&t) { + return InnerSome::type>(std::forward(t)); +} + +struct None {}; + +template +class Option { + public: + Option() : state(NONE) {} + + explicit Option(const T &t) : data(t), state(SOME) {} + + explicit Option(T &&t) : data(std::move(t)), state(SOME) {} + + explicit Option(const InnerSome &some) : data(some._t), state(SOME) {} + + explicit Option(const None &none) : state(NONE) {} + + Option(const Option &that) : state(that.state) { + if (that.IsSome()) { + new (&data) T(that.data); + } + } + + virtual ~Option() {} + + bool IsNone() const { return state == NONE; } + + bool IsSome() const { return state == SOME; } + + const T &Get() const & { + MS_ASSERT(IsSome()); + return data; + } + + T &Get() & { + MS_ASSERT(IsSome()); + return data; + } + + T &&Get() && { + MS_ASSERT(IsSome()); + return std::move(data); + } + + const T &&Get() const && { + MS_ASSERT(IsSome()); + return std::move(data); + } + + // oprerator override + Option &operator=(const Option &that) { + if (&that != this) { + if (IsSome()) { + data.~T(); + } + state = that.state; + if (that.IsSome()) { + new (&data) T(that.data); + } + } + + return *this; + } + + bool operator==(const Option &that) const { + return (IsNone() && that.IsNone()) || (IsSome() && that.IsSome() && data == that.data); + } + + bool operator!=(const Option &that) const { return !(*this == that); } + + bool operator==(const T &that) const { return IsSome() && data == that; } + + bool operator!=(const T &that) const { return !(*this == that); } + + private: + enum State { NONE = 0, SOME = 1 }; + + T data; + State state; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_COMMON_OPTION_H_ + diff --git a/mindspore/lite/tools/common/storage.cc b/mindspore/lite/tools/common/storage.cc new file mode 100755 index 0000000000..dac5071502 --- /dev/null +++ b/mindspore/lite/tools/common/storage.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/storage.h" +#include "flatbuffers/flatbuffers.h" +#include "utils/log_adapter.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath) { + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, &graph); + builder.Finish(offset); + int size = builder.GetSize(); + auto content = builder.GetBufferPointer(); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer nullptr"; + return RET_ERROR; + } + + std::ofstream output(outputPath + ".ms", std::ofstream::binary); + if (!output.is_open()) { + MS_LOG(ERROR) << "ofstream open failed"; + return RET_ERROR; + } + + output.write((const char *)content, size); + output.close(); + return RET_OK; +} + +schema::MetaGraphT *Storage::Load(const std::string &inputPath) { + size_t size; + auto buf = ReadFile(inputPath.c_str(), &size); + if (buf == nullptr) { + // MS_LOG(ERROR)("the file buffer is nullptr"); + return nullptr; + } + + flatbuffers::Verifier verify((const uint8_t *)buf, size); + // if (false == VerifyGraphDefBuffer(verify)) { + // //MS_LOG(ERROR)("the buffer is invalid and fail to create graph"); + // return nullptr; + // } + + auto graphDefT = schema::UnPackMetaGraph(buf); + return graphDefT.release(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/common/storage.h b/mindspore/lite/tools/common/storage.h new file mode 100644 index 0000000000..c1cdfa27ad --- /dev/null +++ b/mindspore/lite/tools/common/storage.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_COMMON_STORAGE_H_ +#define PREDICT_COMMON_STORAGE_H_ + +#include +#include +#include "include/errorcode.h" +#include "flatbuffers/flatbuffers.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +class Storage { + public: + int Save(const schema::MetaGraphT &graph, const std::string &outputPath); + + schema::MetaGraphT *Load(const std::string &inputPath); +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_COMMON_STORAGE_H_ + diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc new file mode 100644 index 0000000000..e41d5efd2c --- /dev/null +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/common/utils.h" +#include "tools/common/tensor_util.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +std::unique_ptr CopyQuantParamArrayT(const std::unique_ptr &srcQuantParamArray) { + MS_ASSERT(srcQuantParamArray != nullptr); + auto dstQuantParamArrayT = std::unique_ptr(new (std::nothrow) QuantParamT()); + if (dstQuantParamArrayT == nullptr) { + // MS_LOG(ERROR)("new dstQuantParamArrayT failed"); + return nullptr; + } + /* + for (size_t i = 0; i < srcQuantParamArray->param.size(); i++) { + auto &srcQuantParam = srcQuantParamArray->param.at(i); + MS_ASSERT(srcQuantParam != nullptr); + std::unique_ptr dstQuantParam(new (std::nothrow) QuantParamT()); + if (dstQuantParam == nullptr) { + //MS_LOG(ERROR)("new dstQuantParam failed"); + dstQuantParamArrayT.release(); + return nullptr; + } + dstQuantParam->scale = srcQuantParam->scale; + dstQuantParam->zeroPoint = srcQuantParam->zeroPoint; + dstQuantParam->min = srcQuantParam->min; + dstQuantParam->max = srcQuantParam->max; + dstQuantParam->narrowRange = srcQuantParam->narrowRange; + dstQuantParam->numBits = srcQuantParam->numBits; + dstQuantParamArrayT->param.emplace_back(std::move(dstQuantParam)); + } + */ + return std::move(dstQuantParamArrayT); +} + +std::unique_ptr GetInTensorQuantParamArray(const MetaGraphT &graphT, size_t tensorIdx) { + auto preNodeIdxes = GetLinkedPreIdx(graphT, tensorIdx); + MS_ASSERT(preNodeIdxes.size() <= 1); + if (preNodeIdxes.empty()) { + // MS_LOGD("the %zuth tensor has no preNode", tensorIdx); + return nullptr; + } + auto preNodeIdx = preNodeIdxes.front(); + MS_ASSERT(preNodeIdx < graphT.nodes.size()); + auto &preNode = graphT.nodes.at(preNodeIdx); + MS_ASSERT(preNode != nullptr); + MS_ASSERT(preNode->inputIndex.size() + preNode->outputIndex.size() == preNode->quantParam.size()); + /* + for (size_t i = 0; i < preNode->outputIndex.size(); i++) { + if (preNode->outputIndex.at(i) == tensorIdx) { + auto &quantPArray = preNode->quantParam.at(preNode->inputIndex.size() + i); + MS_ASSERT(quantPArray->param.size() == 1); // only support prelayer + MS_ASSERT(quantPArray->param.front() != nullptr); + if (quantPArray->param.front()->min == FLT_MAX) { + //MS_LOGD("the %zuth tensor's preNode's relative quantParam has not be inited", tensorIdx); + return nullptr; + } else { + return std::move(CopyQuantParamArrayT(quantPArray)); + } + } + } + */ + MS_ASSERT(false); + return nullptr; +} + +std::unique_ptr GetOutTensorQuantParamArray(const MetaGraphT &graphT, size_t tensorIdx) { + auto postNodeIdxes = GetLinkedPostIdx(graphT, tensorIdx); + if (postNodeIdxes.empty()) { + // MS_LOGD("the %zuth tensor has no postNode", tensorIdx); + return nullptr; + } + // find one postNode which can give valid quantParamArray + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(postNodeIdx < graphT.nodes.size()); + auto &postNode = graphT.nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + MS_ASSERT(postNode->inputIndex.size() + postNode->outputIndex.size() == postNode->quantParam.size()); + /* + for (size_t i = 0; i < postNode->inputIndex.size(); i++) { + if (postNode->inputIndex.at(i) == tensorIdx) { + auto &quantPArray = postNode->quantParam.at(i); + MS_ASSERT(quantPArray->param.size() == 1); // only support prelayer + MS_ASSERT(quantPArray->param.front() != nullptr); + // check if postNode has valid quantParam + if (quantPArray->param.front()->min == FLT_MAX) { + continue; + } + MS_ASSERT(graphT.allTensors.size() > postNode->inputIndex.at(i)); + auto &tensor = graphT.allTensors.at(postNode->inputIndex.at(i)); + MS_ASSERT(tensor != nullptr); + if (tensor->refCount == schema::NodeType_ValueNode) { + continue; + } + // find valid quantParam return + auto paramArray = CopyQuantParamArrayT(quantPArray); + if (paramArray == nullptr) { + //MS_LOG(ERROR)("CopyQuantParamArrayT return nullptr"); + return nullptr; + } + return std::move(paramArray); + } + }*/ + } + return nullptr; +} + +size_t GetElementSize(const TensorT &tensor) { return GetElementSize(TypeId(tensor.dataType)); } + +size_t GetElementSize(const TypeId &dataType) { + switch (dataType) { + case kNumberTypeUInt8: + return sizeof(uint8_t); + case kNumberTypeInt32: + return sizeof(int32_t); + case kNumberTypeFloat: + return sizeof(float); + case kNumberTypeInt16: + return sizeof(int16_t); + case kNumberTypeInt8: + return sizeof(int8_t); + case kNumberTypeUInt32: + return sizeof(uint32_t); + default: + return sizeof(float); + } +} + +size_t GetShapeSize(const TensorT &tensor) { + auto shape = tensor.dims; + size_t shapeSize = 1; + for (auto dim : shape) { + shapeSize *= dim; + } + return shapeSize; +} + +std::unique_ptr CopyTensorDefT(const std::unique_ptr &oldTensor) { + auto newTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (newTensor == nullptr) { + // MS_LOG(ERROR)("new TensorT failed"); + return nullptr; + } + newTensor->dims = oldTensor->dims; + newTensor->format = oldTensor->format; + newTensor->dataType = oldTensor->dataType; + newTensor->refCount = oldTensor->refCount; + newTensor->nodeType = oldTensor->nodeType; + newTensor->data = oldTensor->data; + return std::move(newTensor); +} + +size_t GetRefCount(MetaGraphT *graphT, uint32_t tensorIdx) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(graphT->allTensors.size() > tensorIdx); + size_t refCount = 0; + for (auto &node : graphT->nodes) { + MS_ASSERT(node != nullptr); + if (IsContain(node->inputIndex, tensorIdx)) { + refCount++; + } + } + return refCount; +} +size_t GetShapeSize(const std::vector &shape) { + size_t shapeSize = 1; + for (auto dim : shape) { + shapeSize *= dim; + } + return shapeSize; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h new file mode 100644 index 0000000000..93cb3520a1 --- /dev/null +++ b/mindspore/lite/tools/common/tensor_util.h @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TENSOR_UTIL_H +#define MINDSPORE_PREDICT_TENSOR_UTIL_H + +#include +#include +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +namespace lite { +using schema::TensorT; +using schema::MetaGraphT; +using schema::CNodeT; +using schema::QuantParamT; +using schema::Format; +using schema::FusedBatchNormT; +using schema::Format_NCHW; +using schema::Format_NHWC; +using STATUS = int; +size_t GetElementSize(const TensorT &tensor); + +size_t GetElementSize(const TypeId &dataType); + +size_t GetShapeSize(const TensorT &tensor); + +size_t GetShapeSize(const std::vector &shape); + +std::unique_ptr CopyTensorDefT(const std::unique_ptr &); + +size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx); + +std::unique_ptr \ + CopyQuantParamArrayT(const std::unique_ptr &srcQuantParamArray); + +std::unique_ptr GetInTensorQuantParamArray(const schema::MetaGraphT &graphT, size_t tensorIdx); + +std::unique_ptr GetOutTensorQuantParamArray(const schema::MetaGraphT &graphT, size_t tensorIdx); + +using MSGraphDefTPtr = std::shared_ptr; + +enum TensorType { CONST = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 }; + +class TensorCache { + public: + TensorCache() {} + + ~TensorCache() { tensors.clear(); } + + int AddTensor(const std::string &name, TensorT *tensor, int TensorType) { + index++; + if (TensorType == CONST || TensorType == TF_CONST || TensorType == GRAPH_INPUT) { + tensor->refCount = 1; + tensor->nodeType = schema::NodeType_ValueNode; + } else { + tensor->nodeType = schema::NodeType_Parameter; + } + tensors.push_back(tensor); + + if (TensorType == GRAPH_INPUT) { + graphInputs.push_back(index); + } + + if (TensorType == GRAPH_INPUT || TensorType == OP_OUTPUT || TensorType == TF_CONST) { + UpdateTensorIndex(name, index); + } + return index; + } + + // find the name index + int FindTensor(const std::string &name) { + auto iter = tensorIndex.find(name); + if (iter != tensorIndex.end()) { + return iter->second; + } + return -1; + } + + void UpdateTensorIndex(const std::string &name, int index) { + auto iter = tensorIndex.find(name); + if (iter != tensorIndex.end()) { + tensorIndex[name] = index; + } else { + tensorIndex.insert(make_pair(name, index)); + } + } + + // return allTensors + const std::vector &GetCachedTensor() const { return tensors; } + + const std::vector &GetGraphInputs() const { return graphInputs; } + + private: + std::vector tensors; + std::unordered_map tensorIndex; + std::vector graphInputs; + int index = -1; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TENSOR_UTIL_H + diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt new file mode 100644 index 0000000000..f8cc8bcfa5 --- /dev/null +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -0,0 +1,114 @@ +set(ANF_SRC + ${ANF_SRC} +#core / abstract + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/abstract_function.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/analysis_context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/param_validator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/abstract_value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/dshape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/utils.cc +#core / base + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/base/base_ref.cc +#core / ir + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/anf.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/anf_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/meta_func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/graph_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/func_graph_cloner.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/func_graph_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/primitive.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/visitor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/meta_tensor_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/named.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/type.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/type_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/any.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/symbolic.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/misc.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/flags.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/trace_base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/trace_info.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/label.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/info.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/profile.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/ms_context.cc + ## ccsrc + ${CCSRC_DIR}/debug/draw.cc + ${CCSRC_DIR}/pybind_api/export_flags.cc + ${CCSRC_DIR}/utils/context/context_extends.cc + ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc + ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc + ${CCSRC_DIR}/backend/optimizer/common/visit.cc + ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc + ) + +file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc) + +file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/converter_flags.cc + ${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc + ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc + ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/anf_importer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_meta_graphT.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc + ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc + + ../optimizer/common/node_pass_extends.cc + ../optimizer/common/pass_manager_extends.cc + ../optimizer/common/gllo_utils.cc + ../optimizer/fusion/conv_biasadd_fusion.cc + ../optimizer/fusion/conv_activation_fusion.cc + ../optimizer/fusion/conv_transform_fusion.cc + ../optimizer/fusion/conv_scale_fusion.cc + ../optimizer/fusion/conv_bn_fusion.cc + ) + +add_subdirectory(parser/caffe) +add_subdirectory(parser/tflite) +add_subdirectory(legacy_optimizer) +add_subdirectory(quantizer) + +add_executable(converter_lite + main.cc + ${ANF_SRC} + ${CONVERTER_SRC} + ${OPS_SRC} + ) +target_link_libraries(converter_lite PRIVATE + tflite_parser_mid + caffe_parser_mid + anf_exporter_mid + node_mid + graph_pass_mid + fusion_mid + protobuf + quantizer_mid + pthread + mindspore-lite + ${SECUREC_LIBRARY} + mindspore::json + mindspore::eigen + ) + diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc new file mode 100644 index 0000000000..5c8fcf2c7f --- /dev/null +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/anf_transform.h" +#include +#include +#include "utils/log_adapter.h" +#include "mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h" +#include "mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h" +#include "mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h" +#include "mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h" + +using std::string; +namespace mindspore { +namespace lite { +AnfTransform::AnfTransform() = default; + +AnfTransform::~AnfTransform() = default; + +void AnfTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } + +FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) { + // return old_graph; + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared(true, "conv_relu", schema::PrimitiveType_Activation, + schema::ActivationType_RELU)); + pm->AddPass(std::make_shared(true, "conv_relu6", schema::PrimitiveType_Activation, + schema::ActivationType_RELU6)); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(old_graph); + return new_graph; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h new file mode 100644 index 0000000000..3b393a15bc --- /dev/null +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ANF_TRANSFORM_H +#define MS_ANF_TRANSFORM_H + +#include "schema/inner/model_generated.h" +#include "tools/common/storage.h" +#include "tools/converter/converter_flags.h" +#include "ir/anf.h" + + +namespace mindspore { +namespace lite { +class AnfTransform { + public: + AnfTransform(); + virtual ~AnfTransform(); + FuncGraphPtr Transform(const FuncGraphPtr &old_graph); + void SetGraphDef(schema::MetaGraphT *dstDef); + inline schema::MetaGraphT *GetOutput() { return graphDefT; } + + protected: + schema::MetaGraphT *graphDefT = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc new file mode 100644 index 0000000000..3f3f5a03a3 --- /dev/null +++ b/mindspore/lite/tools/converter/converter.cc @@ -0,0 +1,203 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/converter.h" +#include +#include +#include +#include "tools/converter/converter_flags.h" +#include "src/common/common.h" +#include "src/common/file_utils.h" +#include "ir/func_graph.h" + +#include "utils/log_adapter.h" +#include "tools/common/storage.h" +#include "parser/caffe/caffe_converter.h" +#include "parser/tflite/tflite_converter.h" +#include "src/common/anf_exporter/anf_exporter.h" +#include "src/common/anf_importer/import_from_protobuf.h" +#include "tools/converter/parser/onnx/onnx.pb.h" +#include "tools/converter/quantizer/weight_quantizer.h" +#include "tools/converter/quantizer/post_training.h" +#include "tools/converter/quantizer/quant_cast.h" + +namespace mindspore { +namespace lite { +using FmkType = converter::FmkType; +Converter::Converter() { + this->transform = new GraphDefTransform; + this->anfTransform = new AnfTransform; +} + +Converter::~Converter() { + if (nullptr != modelParser) { + delete modelParser; + } + if (nullptr != modelImporter) { + delete modelImporter; + } + if (nullptr != transform) { + delete transform; + } + if (nullptr != anfTransform) { + delete anfTransform; + } +} + +class MindsporeImporter : public Converter { + public: + MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) { + modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph)); + } + + ~MindsporeImporter() override = default; +}; + +MetaGraphT *Converter::Convert(const converter::Flags *flag) { + // parse the model and weight file to generate inference data structure + FuncGraphPtr graph = nullptr; + if (flag->fmk == converter::FmkType_MS) { + MS_ASSERT(nullptr != modelImporter); + modelImporter->Import(); + graph = modelImporter->GetResult(); + } else { + MS_ASSERT(nullptr != modelParser); + const std::string modelFile = flag->modelFile; + const std::string weightFile = flag->weightFile; + auto meta_graph = modelParser->Parse(modelFile, weightFile); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; + return nullptr; + } + // todo hangangqiang + graph = ModelParser::Fb2Anf(meta_graph); + } + if (graph == nullptr) { + MS_LOG(ERROR) << "Parser/Import model return nullptr"; + return nullptr; + } + + // graph = anfTransform->Transform(graph); + + CreateQuantizer(graph, flag); + if (mQuantizer != nullptr) { + auto status = mQuantizer->DoQuantize(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "Quant failed " << status; + return nullptr; + } + quant::QuantCast quant_cast; + quant_cast.SetInputDataDType(kNumberTypeFloat32); + status = quant_cast.Run(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "add QuantCast error"; + return nullptr; + } + } + + // anf -- fb + auto meta_graph = Export(graph); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Export to meta_graph return nullptr"; + return nullptr; + } + + // transform + transform->SetGraphDef(meta_graph); + auto status = transform->Transform(*flag); + if (status != 0) { + MS_LOG(ERROR) << "FBTransform model failed " << status; + return nullptr; + } + return meta_graph; +} +void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { + auto type = flags->quantType; + switch (type) { + case mindspore::schema::QuantType_AwareTrainning: { + // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); + break; + } + case mindspore::schema::QuantType_WeightQuant: { + MS_LOG(INFO) << "create WeightQuantizer!"; + mQuantizer.reset( + new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum)); + break; + } + case mindspore::schema::QuantType_PostTraining: { + MS_LOG(INFO) << "create PostTrainningQuantizer!"; + mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); + break; + } + case mindspore::schema::QuantType_QUANT_NONE: + MS_LOG(INFO) << "Not do quantization for model!"; + break; + default: + MS_LOG(INFO) << "will support quntizer type " << flags->quantTypeIn.c_str() << " in the future!"; + break; + } +} +int RunConverter(int argc, const char **argv) { + auto flags = new converter::Flags; + auto status = flags->Init(argc, argv); + if (status != 0) { + MS_LOG(ERROR) << "converter::Flags Init failed: " << status; + return 1; + } + // Load graph + std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1); + MS_LOG(INFO) << "start reading model file"; + + MetaGraphT *fb_graph = nullptr; + switch (flags->fmk) { + case FmkType::FmkType_MS: { + auto graph = std::make_shared(); + auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile); + MindsporeImporter mindsporeImporter(onnx_graph, graph); + fb_graph = mindsporeImporter.Convert(flags); + break; + } + case FmkType::FmkType_CAFFE: { + CaffeConverter caffeConverter; + fb_graph = caffeConverter.Convert(flags); + } break; + case FmkType::FmkType_TFLITE: { + TfliteConverter tfLiteConverter; + fb_graph = tfLiteConverter.Convert(flags); + } break; + default: { + MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; + return 1; + } + } + if (fb_graph == nullptr) { + MS_LOG(ERROR) << "Convert model return nullptr"; + return 1; + } + + // save graph to file + Storage storage; + status = storage.Save(*fb_graph, flags->outputFile); + if (status != 0) { + MS_LOG(ERROR) << "Save graph failed"; + return 1; + } + MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!"; + + return 0; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h new file mode 100644 index 0000000000..54e3560a87 --- /dev/null +++ b/mindspore/lite/tools/converter/converter.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_CONVERTER_H +#define MS_CONVERTER_H + +#include +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/graphdef_transform.h" +#include "tools/converter/model_parser.h" +#include "src/common/anf_importer/anf_importer.h" +#include "tools/converter/converter_flags.h" +#include "tools/converter/anf_transform.h" +#include "tools/converter/quantizer/quantizer.h" + +namespace mindspore { +namespace lite { +class Converter { + public: + Converter(); + virtual ~Converter(); + virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags); + void CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags); + + protected: + ModelParser *modelParser = nullptr; + AnfImporter *modelImporter = nullptr; + GraphDefTransform *transform = nullptr; + AnfTransform *anfTransform = nullptr; + std::unique_ptr mQuantizer = nullptr; +}; + +int RunConverter(int argc, const char **argv); +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc new file mode 100644 index 0000000000..3fd7a9fc45 --- /dev/null +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/converter_flags.h" + + +namespace mindspore { +namespace lite { +namespace converter { +Flags::Flags() { + AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | CAFFE | ONNX | MS | TFLITE", ""); + AddFlag(&Flags::modelFile, "modelFile", + "Input model file path. TF: *.pb | CAFFE: *.prototxt | ONNX: *.onnx | MS: *.ms", ""); + AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); + AddFlag(&Flags::weightFile, "weightFile", + "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); + AddFlag(&Flags::inferenceType, "inferenceType", + "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); + AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", ""); + AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT"); + AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); + AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); + AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0"); + AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); + AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); +} + +int Flags::Init(int argc, const char **argv) { + Option err = this->ParseFlags(argc, argv); + + if (err.IsSome()) { + MS_LOG(ERROR) << err.Get(); + std::cerr << this->Usage() << std::endl; + return 1; + } + + if (this->help) { + std::cerr << this->Usage() << std::endl; + return 0; + } + if (this->modelFile.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: model file path is necessary"; + return 1; + } + if (this->outputFile.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: output file path is necessary"; + return 1; + } + + if (this->outputFile.rfind('/') == this->outputFile.length() - 1) { + MS_LOG(ERROR) << "INPUT ILLEGAL: outputFile must be a valid file path"; + return 1; + } + + if (this->fmkIn.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: fmk is necessary"; + return 1; + } + if (this->inputInferenceTypeIn == "FLOAT") { + this->inputInferenceType = 0; + } else if (this->inputInferenceTypeIn == "UINT8") { + this->inputInferenceType = 1; + } else { + MS_LOG(ERROR) << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str(); + return 1; + } + if (this->fmkIn == "TF") { + this->fmk = FmkType_TF; + } else if (this->fmkIn == "CAFFE") { + this->fmk = FmkType_CAFFE; + } else if (this->fmkIn == "ONNX") { + this->fmk = FmkType_ONNX; + } else if (this->fmkIn == "MS") { + this->fmk = FmkType_MS; + } else if (this->fmkIn == "TFLITE") { + this->fmk = FmkType_TFLITE; + } else { + MS_LOG(ERROR) << "INPUT ILLEGAL: fmk must be TF|CAFFE|ONNX|MS"; + return 1; + } + + if (this->fmk != FmkType_CAFFE && !weightFile.empty()) { + MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile is not a valid flag"; + return 1; + } + if (this->quantTypeIn == "AwareTrainning") { + this->quantType = QuantType_AwareTrainning; + } else if (this->quantTypeIn == "WeightQuant") { + this->quantType = QuantType_WeightQuant; + } else if (this->quantTypeIn == "PostTraining") { + this->quantType = QuantType_PostTraining; + } else if (this->quantTypeIn.empty()) { + this->quantType = QuantType_QUANT_NONE; + } else { + MS_LOG(ERROR) << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining"; + return 1; + } + + // auto status = ValidateAwareQuantizerCLI(); + // if (status != RET_OK) { + // MS_PRINT_ERROR("Parse aware quantization command line failed: %d", status); + // return status; + // } + // status = ValidateWeighQuantCLI(); + // if (status != RET_OK) { + // MS_PRINT_ERROR("ValidateWeighQuantCLI failed: %d", status); + // return status; + // } + return 0; +} + +// bool Flags::ValidateString(const string pattern, const string input) { +// std::regex repPattern(pattern, std::regex_constants::extended); +// std::match_results regResult; +// return regex_match(input, regResult, repPattern); +//} + +// int Flags::ValidateAwareQuantizerCLI() { +// // check input inference type +// if (this->inputInferenceType == DataType_DT_FLOAT) { +// if (this->mean.empty()) { +// MS_PRINT_ERROR("mean value shound not be null!") +// return RET_PARAM_INVALID; +// } +// if (this->stdDev.empty()) { +// MS_PRINT_ERROR("standard deviation value shound not be null!") +// return RET_PARAM_INVALID; +// } +// const std::string pattern = "^[+-]?([0-9]*\.?[0-9]+|[0-9]+\.?[0-9]*)([eE][+-]?[0-9]+)?$"; +// if (!ValidateString(pattern, this->mean)) { +// MS_PRINT_ERROR("invalid input mean values: %s", this->mean.c_str()); +// return RET_PARAM_INVALID; +// } +// if (!ValidateString(pattern, this->stdDev)) { +// MS_PRINT_ERROR("invalid input standard deviation value: %s", this->stdDev.c_str()); +// return RET_PARAM_INVALID; +// } +// } else { +// if (!this->mean.empty()) { +// MS_PRINT_INFO("useless mean value: %s", this->mean.c_str()); +// } +// if (!this->stdDev.empty()) { +// MS_PRINT_INFO("useless stdDev value: %s", this->stdDev.c_str()); +// } +// } +// return RET_OK; +//} + +// int Flags::ValidateWeighQuantCLI() { +// if (!this->quantSize.empty()) { +// if (!ValidateString("^[0-9]*$", this->quantSize)) { +// MS_PRINT_ERROR("invalid input quantSize: %s, only support positive integer type!", this->quantSize.c_str()); +// return RET_PARAM_INVALID; +// } +// } +// return RET_OK; +//} +} // namespace converter +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h new file mode 100644 index 0000000000..b97d777ae1 --- /dev/null +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CONVERTER_FLAGS_H +#define CONVERTER_FLAGS_H + +#include +#include "tools/common/flag_parser.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +using mindspore::schema::QuantType; +using mindspore::schema::QuantType_PostTraining; +using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::schema::QuantType_AwareTrainning; +using mindspore::schema::QuantType_WeightQuant; +using mindspore::schema::QuantType_PostTraining; +using mindspore::schema::QuantType_PostTraining; +namespace converter { +enum FmkType { + FmkType_TF = 0, + FmkType_CAFFE = 1, + FmkType_ONNX = 2, + FmkType_MS = 3, + FmkType_TFLITE = 4 +}; + +class Flags : public virtual mindspore::lite::FlagParser { + public: + Flags(); + + ~Flags() override = default; + + int Init(int argc, const char **argv); + + private: + bool ValidateString(std::string pattern, std::string input); + + // int ValidateAwareQuantizerCLI(); + // + // int ValidateWeighQuantCLI(); + + public: + std::string modelFile; + std::string outputFile; + std::string fmkIn; + FmkType fmk; + std::string weightFile; + std::string inputArrays; + std::string outputArrays; + std::string inputShapes; + // used for quantization + std::string quantTypeIn; + QuantType quantType; + std::string inferenceType; + // used for parse aware trainning + std::string inputInferenceTypeIn; + // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT; + int inputInferenceType = 0; + std::string stdDev; + std::string mean; + // used for post-trainning-weight + std::string quantSize; + std::string bitNum; + std::string configFile; + bool formatTrans = true; + std::string convWeightQuantChannelThreshold; +}; +} // namespace converter +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc new file mode 100644 index 0000000000..ec92083597 --- /dev/null +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -0,0 +1,183 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/graphdef_transform.h" +#include +#include +#include "schema/model_generated.h" +#include "utils/log_adapter.h" +#include "src/common/op_utils.h" +#include "tools/converter/converter_flags.h" +#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" +// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" +// #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" +// #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" +// +// #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h" +// #include "tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h" +// +#include "tools/converter/legacy_optimizer/node/weight_format_pass.h" +#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" +#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" +#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" +#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" + +#include "tools/converter/converter.h" + +using std::string; +namespace mindspore { +namespace lite { +GraphDefTransform::GraphDefTransform() = default; + +GraphDefTransform::~GraphDefTransform() = default; + +void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } + +int GraphDefTransform::Transform(const converter::Flags &ctx) { + STATUS status; + // // constant folding + // { + // Optimizer topologicalSortOptimizer; + // topologicalSortOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + // status = topologicalSortOptimizer.Run(graphDefT); + // if (status != RET_OK) { + // MS_LOG(ERROR)<<"Run topologicalSortOptimizer graphPasses Failed"; + // return status; + // } + // Optimizer constFoldOptimizer; + // constFoldOptimizer.AddPass(new (std::nothrow) AddConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) CastConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ConcatV2ConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ExpandDimsConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) MulConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) RangeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ReshapeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) RsqrtConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ShapeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) SliceConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) StackConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) StridedSliceConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) SubConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) TileConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) TransposeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + // status = constFoldOptimizer.Run(graphDefT); + // if (status != RET_OK && status != RET_NO_CHANGE) { + // MS_LOG(ERROR) << "Run constFoldOptimizer graphPasses Failed"; + // return status; + // } + // } + + // fusion + { + Optimizer fusionOptimizer; + fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; + return status; + } + } + + // weight format trans + if (ctx.formatTrans) { + Optimizer weightFormatOptimizer; + auto weightFormatPass = new (std::nothrow) WeightFormatPass(); + if (weightFormatPass == nullptr) { + MS_LOG(ERROR) << "new weightFormatPass failed"; + return RET_ERROR; + } + weightFormatPass->SetQuantType(ctx.quantType); + weightFormatPass->SetFmkType(ctx.fmk); + weightFormatOptimizer.AddPass(weightFormatPass); + status = weightFormatOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run weightFormatOptimizer graphPasses Failed"; + return status; + } + } + + // format transform + if (ctx.formatTrans) { + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_ERROR; + } + formatTransPass->SetQuantType(ctx.quantType); + formatTransPass->SetFmk(ctx.fmk); + formatTransOptimizer.AddPass(formatTransPass); + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + // if (ctx.quantType == QuantType_AwareTrainning) { + // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); + // } + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; + } + } + + { + Optimizer unusedOpRemoveOptimizer; + unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); + unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); + status = unusedOpRemoveOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; + return status; + } + } + // topological sorting + { + Optimizer topologicalOptimizer; + topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + status = topologicalOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/graphdef_transform.h b/mindspore/lite/tools/converter/graphdef_transform.h new file mode 100644 index 0000000000..b50579ac99 --- /dev/null +++ b/mindspore/lite/tools/converter/graphdef_transform.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_GRAPHDEF_TRANSFORM_H +#define MS_GRAPHDEF_TRANSFORM_H + +#include "tools/converter/optimizer.h" +// #include "quantizer/quantizer.h" +#include "schema/inner/model_generated.h" +#include "tools/common/storage.h" +#include "tools/converter/converter_flags.h" + +namespace mindspore { +namespace lite { +/* + * transform GraphDef by fusion legacy_optimizer and quantizer + * */ + +class GraphDefTransform { + public: + GraphDefTransform(); + virtual ~GraphDefTransform(); + virtual int Transform(const converter::Flags &ctx); + void SetGraphDef(schema::MetaGraphT *dstDef); + inline schema::MetaGraphT *GetOutput() { return graphDefT; } + void CreateQuantizer(const converter::Flags *flags); + + protected: + schema::MetaGraphT *graphDefT = nullptr; + Optimizer *optimizer = nullptr; + + // std::unique_ptr mQuantizer; +}; +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt new file mode 100755 index 0000000000..898d060738 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt @@ -0,0 +1,6 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(fusion) +#add_subdirectory(const_fold) +add_subdirectory(node) +add_subdirectory(graph) \ No newline at end of file diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/CMakeLists.txt new file mode 100644 index 0000000000..fdd03aca27 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/CMakeLists.txt @@ -0,0 +1,50 @@ +set(OP_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/runtime/allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op_common.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op_factory.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/common/op_func_comm.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/common/op_nc4hw4_comm.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/add.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/cast.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/concat.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/fp32/add_fp32.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/fp32/concat_fp32.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/uint8/add_uint8.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/uint8/concat_uint8.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/expand_dim.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/mul.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/range.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/reshape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/uint8/reshape_uint8.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/rsqrt.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/shape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/slice.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/stack.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/strided_slice.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/sub.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/tile.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/transpose.cc + ) + +add_library(const_fold_mid OBJECT + ${OP_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/add_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cast_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/concat_v2_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expand_dims_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/mul_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/range_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/reshape_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rsqrt_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/shape_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/slice_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/stack_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/strided_slice_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/sub_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tile_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/transpose_const_fold_pass.cc) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.cc new file mode 100644 index 0000000000..b4df8f946f --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/add.h" + +namespace mindspore { +namespace lite { + +STATUS AddConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS AddConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Add; + desc.arch = kCPU; + MS_ASSERT(inputs.size() == kArithOpInputNum); + auto inTensor0 = inputs.at(kArithOpInputTensorIndex0); + auto inTensor1 = inputs.at(kArithOpInputTensorIndex1); + MS_ASSERT(inTensor0 != nullptr); + MS_ASSERT(inTensor1 != nullptr); + DataType dataType; + if (inTensor0->GetNDim() > 1) { + dataType = inTensor0->GetDataType(); + } else { + dataType = inTensor1->GetDataType(); + } + switch (dataType) { + case DataType_DT_UINT8: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT32: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_FLOAT: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT8: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_UINT32: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + default: { + MS_LOGE("Unsupported dataType: %d", dataType); + return RET_ERROR; + } + } + if (op == nullptr) { + MS_LOGE("new OpAdd return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS AddConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kArithOpOutputNum) { + MS_LOGE("The number of output for add must be %u, nodeName: %s", kArithOpOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h new file mode 100644 index 0000000000..28f758a755 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ADD_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_ADD_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class AddConstFoldPass : public ConstFoldPass { + public: + AddConstFoldPass() : ConstFoldPass(OpT_Add) {} + + ~AddConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_ADD_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.cc new file mode 100644 index 0000000000..1d3da9fda8 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/cast.h" + +#define CAST_OUTPUT_NUM 1 + +namespace mindspore { +namespace lite { +STATUS CastConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS CastConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Cast; + desc.arch = kCPU; + op = new (std::nothrow) OpCast(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpCast return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpCast InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpCast Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS CastConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpCast Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != CAST_OUTPUT_NUM) { + MS_LOGE("The number of output for cast must be %u, nodeName: %s", CAST_OUTPUT_NUM, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h new file mode 100644 index 0000000000..65c07dca05 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CAST_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_CAST_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class CastConstFoldPass : public ConstFoldPass { + public: + CastConstFoldPass() : ConstFoldPass(OpT_Cast) {} + + ~CastConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CAST_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.cc new file mode 100644 index 0000000000..17cc102a50 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h" +#include "src/operator/cpu/creator/concat.h" + +namespace mindspore { +namespace lite { + +STATUS ConcatV2ConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ConcatV2ConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Concat; + desc.arch = kCPU; + op = new (std::nothrow) OpConcat(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpConcat return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpConcat InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpConcat Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ConcatV2ConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpConcat Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kConcatOutputNum) { + MS_LOGE("The number of output for concat must be %u, nodeName: %s", kConcatOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h new file mode 100644 index 0000000000..5833892e80 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONCAT_V2_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_CONCAT_V2_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +class ConcatV2ConstFoldPass : public ConstFoldPass { + public: + ConcatV2ConstFoldPass() : ConstFoldPass(OpT_Concat) {} + + ~ConcatV2ConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; + + private: + template + STATUS DoConcat(SubGraphDefT *subGraph, const std::vector &inTensorIdxes, int axis) { + MS_ASSERT(this->outputTensor != nullptr); + std::vector inTensors; + std::vector inDatas; + for (size_t i = 0; i < inTensorIdxes.size(); i++) { + auto &inTensor = subGraph->allTensors.at(inTensorIdxes.at(i)); + MS_ASSERT(inTensor != nullptr); + inTensors.emplace_back(inTensor.get()); + void *inData = inTensor->data.data(); + MS_ASSERT(inData != nullptr); + T *castedInData = static_cast(inData); + MS_ASSERT(castedInData != nullptr); + inDatas.emplace_back(castedInData); + } + auto &inShape = subGraph->allTensors.at(inTensorIdxes.at(0))->dims; + std::vector outputDims; + for (size_t i = 0; i < inShape.size(); i++) { + if (i == axis) { + int32_t axisDim = 0; + for (size_t j = 0; j < inTensors.size(); j++) { + axisDim += inTensors.at(j)->dims.at(i); + } + outputDims.push_back(axisDim); + continue; + } + outputDims.push_back(inShape.at(i)); + } + + size_t outShapeSize = 1; + for (auto dim : outputDims) { + outShapeSize *= dim; + } + size_t elementSize = GetElementSize(subGraph->allTensors.at(inTensorIdxes.at(0))->dataType); + + this->outputTensor->dims = outputDims; + this->outputTensor->data.clear(); + this->outputTensor->data.resize(outShapeSize * elementSize); + + void *outData = this->outputTensor->data.data(); + MS_ASSERT(outData != nullptr); + T *castedOutData = static_cast(outData); + + size_t copyBlockTile = 1; + for (int i = axis + 1; i < inShape.size(); i++) { + copyBlockTile *= inShape[i]; + } + std::vector inCopyBlocks; + size_t outCopyBlock = 0; + for (size_t i = 0; i < inTensors.size(); i++) { + inCopyBlocks.emplace_back(copyBlockTile * (inTensors.at(i)->dims.at(axis))); + outCopyBlock += inCopyBlocks.back(); + } + + size_t outIndex = 0; + while (outIndex < outShapeSize) { + for (size_t i = 0; i < inDatas.size(); i++) { + ::memcpy_s(castedOutData + outIndex, inCopyBlocks.at(i), inDatas.at(i), inCopyBlocks.at(i)); + outIndex += inCopyBlocks.at(i); + inDatas.at(i) += inCopyBlocks.at(i); + } + } + + return RET_OK; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONCAT_V2_CONST_FOLD_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/const_fold_pass.cc new file mode 100644 index 0000000000..8d12c1d4c9 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/const_fold_pass.cc @@ -0,0 +1,207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/const_fold_pass.h" +#include +#include "utils/log_adapter.h" +#include "converter/common/graph_util.h" + +namespace mindspore { +namespace lite { +STATUS ConstFoldPass::Run(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto subGraph = graphNode->subGraph; + auto node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + if (GetOpType(*node) != opType) { + return RET_OK; + } + if (!IsFoldable(subGraph, node)) { + MS_LOGD("All input should be ConstTensor, node : %s"); + return RET_OK; + } + + for (uint32_t i : node->inputIndex) { + TensorDefT *tensorDefT = subGraph->allTensors.at(i).get(); + MS_ASSERT(tensorDefT != nullptr); + auto tensor = CopyTensorDefT2Tensor(tensorDefT); + if (tensor == nullptr) { + MS_LOGE("Pack TensorDefT return nullptr"); + FreeTensors(); + return RET_ERROR; + } + inputs.emplace_back(tensor); + } + for (uint32_t i : node->outputIndex) { + TensorDefT *tensorDefT = subGraph->allTensors.at(i).get(); + MS_ASSERT(tensorDefT != nullptr); + auto tensor = CopyTensorDefT2Tensor(tensorDefT, false); + if (tensor == nullptr) { + MS_LOGE("Pack TensorDefT return nullptr"); + FreeTensors(); + return RET_ERROR; + } + outputs.emplace_back(tensor); + } + + auto status = CreateOp(subGraph, node); + if (status != RET_OK) { + MS_LOGE("CreateOp error: %d, node: %s", status, node->name.c_str()); + FreeTensors(); + return status; + } + for (auto &outputTensor : outputs) { + auto statusTmp = outputTensor->MallocData(); + if (statusTmp != RET_OK) { + MS_LOGE("OutTensor MallocData error: %d, nodeName: %s", statusTmp, node->name.c_str()); + FreeTensors(); + return RET_ERROR; + } + } + status = DoFold(subGraph, node); + if (status != RET_OK) { + MS_LOGE("DoFold error: %d, node: %s", status, node->name.c_str()); + FreeTensors(); + return status; + } + + if (this->outputTensor->data.empty()) { + MS_LOGI("outputTensor's data has not been set, node : %s", node->name.c_str()); + FreeTensors(); + return RET_OK; + } + this->outputTensor->refCount = schema::NodeType_ValueNode; + bool isSubNode = false; + for (auto &inNode : subGraph->nodes) { + if (inNode->name == node->name) { + isSubNode = true; + break; + } + } + if (!isSubNode) { + MS_LOGE("Node %s is not in subGraph %s", node->name.c_str(), subGraph->name.c_str()); + return RET_PARAM_INVALID; + } else { + status = RemoveTensor(subGraph, node->inputIndex); + if (status != RET_OK) { + MS_LOGE("RemoveTensor failed, node : %s", node->name.c_str()); + FreeTensors(); + return status; + } + // we can not erase nodes in iter loop, so just isolate the node + node->inputIndex.clear(); + node->outputIndex.clear(); + } + + FreeTensors(); + return RET_OK; +} + +OpDef *ConstFoldPass::PackOpDefT(const OpDefT *opDefT) { + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = OpDef::Pack(builder, opDefT); + builder.Finish(offset); + auto buf = builder.GetBufferPointer(); + auto opDef = flatbuffers::GetRoot(buf); + return const_cast(opDef); +} + +Tensor *ConstFoldPass::CopyTensorDefT2Tensor(const TensorDefT *tensorDefT, bool needCopyData) { + if (tensorDefT == nullptr) { + MS_LOGE("tensorDefT is null"); + return nullptr; + } + std::vector dims; + for (size_t i = 0; i < tensorDefT->dims.size(); i++) { + dims.emplace_back(tensorDefT->dims.at(i)); + } + + auto tensor = new (std::nothrow) Tensor(tensorDefT->dataType, dims, tensorDefT->format, nullptr); + if (tensor == nullptr) { + MS_LOGE("new tensor error"); + return nullptr; + } + if (needCopyData) { + auto status = tensor->MallocData(); + if (status != RET_OK) { + MS_LOGE("malloc tensor data error: %d", status); + delete (tensor); + return nullptr; + } + size_t dataLength = tensor->GetDataSize(); + status = ::memcpy_s(tensor->GetData(), dataLength, tensorDefT->data.data(), dataLength); + if (status != 0) { + MS_LOGE("memcpy_s error: %d", status); + delete (tensor); + return nullptr; + } + } + return tensor; +} + +STATUS ConstFoldPass::CopyTensor2TensorDefT(const Tensor *tensor, TensorDefT *tensorDefT) { + MS_ASSERT(tensorDefT != nullptr); + if (tensor == nullptr) { + MS_LOGE("tensor is null"); + return RET_ERROR; + } + + tensorDefT->dims.clear(); + for (size_t i = 0; i < tensor->GetNDim(); i++) { + tensorDefT->dims.emplace_back(tensor->GetDims().at(i)); + } + tensorDefT->dataType = tensor->GetDataType(); + tensorDefT->format = tensor->GetFormat(); + size_t dataLength = tensor->GetDataSize(); + tensorDefT->data.resize(dataLength); + auto ret = ::memcpy_s(tensorDefT->data.data(), dataLength, tensor->GetData(), dataLength); + if (ret != 0) { + MS_LOGE("memcpy_s error: %d", ret); + return RET_ERROR; + } + return RET_OK; +} + +bool ConstFoldPass::IsFoldable(SubGraphDefT *subGraph, OpDefT *node) { + bool isFoldable = true; + for (auto tensorIdx : node->inputIndex) { + auto &tensor = subGraph->allTensors.at(tensorIdx); + if (tensor->refCount != schema::NodeType_ValueNode || tensor->data.empty()) { + isFoldable = false; + break; + } + } + return isFoldable; +} + +void ConstFoldPass::FreeTensors() { + for (auto tensor : inputs) { + if (tensor != nullptr) { + delete (tensor); + } + } + inputs.clear(); + for (auto tensor : outputs) { + if (tensor != nullptr) { + delete (tensor); + } + } + outputs.clear(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/const_fold_pass.h new file mode 100644 index 0000000000..cbc0103598 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/const_fold_pass.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_CONST_FOLD_PASS_H + +#include +#include "mindspore/lite/tools/converter/optimizer.h" +#include "include/tensor.h" +#include "utils/log_adapter.h" +#include "converter/common/converter_op_utils.h" +#include "securec/include/securec.h" +#include "src/op.h" + +namespace mindspore { +namespace lite { +class ConstFoldPass : public NodePass { + public: + explicit ConstFoldPass(schema::PrimitiveType opType) : opType(opType) {} + + ~ConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + protected: + bool IsFoldable(SubGraphDefT *subGraph, OpDefT *node); + + virtual STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) = 0; + + virtual STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) = 0; + + protected: + OpDef *PackOpDefT(const OpDefT *opDefT); + + Tensor *CopyTensorDefT2Tensor(const TensorDefT *tensorDefT, bool needCopyData = true); + + STATUS CopyTensor2TensorDefT(const Tensor *tensor, TensorDefT *tensorDefT); + + void FreeTensors(); + + protected: + schema::PrimitiveType opType; + TensorDefT *outputTensor = nullptr; + std::vector inputs; + std::vector outputs; + OpBase *op = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONST_FOLD_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.cc new file mode 100644 index 0000000000..4e7e80cacb --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/expand_dim.h" + +namespace mindspore { +namespace lite { +STATUS ExpandDimsConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ExpandDimsConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_ExpandDims; + desc.arch = kCPU; + op = new (std::nothrow) OpExpandDim(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpExpandDim return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpExpandDim InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpExpandDim Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ExpandDimsConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpExpandDim Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kExpandDimsOutputNum) { + MS_LOGE("The number of output for expandDim must be %u, nodeName: %s", kExpandDimsOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h new file mode 100644 index 0000000000..12e9d979a1 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_EXPANDDIMS_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_EXPANDDIMS_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class ExpandDimsConstFoldPass : public ConstFoldPass { + public: + ExpandDimsConstFoldPass() : ConstFoldPass(OpT_ExpandDims) {} + + ~ExpandDimsConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_EXPANDDIMS_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.cc new file mode 100644 index 0000000000..0c5be36892 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "converter/common/tensor_util.h" +#include "converter/common/converter_op_utils.h" +#include "src/operator/cpu/creator/mul.h" + +namespace mindspore { +namespace lite { +STATUS MulConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS MulConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Mul; + desc.arch = kCPU; + MS_ASSERT(inputs.size() == kArithOpInputNum); + auto inTensor0 = inputs.at(kArithOpInputTensorIndex0); + auto inTensor1 = inputs.at(kArithOpInputTensorIndex1); + MS_ASSERT(inTensor0 != nullptr); + MS_ASSERT(inTensor1 != nullptr); + DataType dataType; + if (inTensor0->GetNDim() > 1) { + dataType = inTensor0->GetDataType(); + } else { + dataType = inTensor1->GetDataType(); + } + op = nullptr; + switch (dataType) { + case DataType_DT_UINT8: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT32: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_FLOAT: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT8: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_UINT32: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + default: { + MS_LOGE("Unsupported dataType: %d", dataType); + return RET_ERROR; + } + } + if (op == nullptr) { + MS_LOGE("new OpMul return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpMul InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpMul Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS MulConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpMul Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kArithOpOutputNum) { + MS_LOGE("The number of output for mul must be %u, nodeName: %s", kArithOpOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h new file mode 100644 index 0000000000..6c47b60c63 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_MUL_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_MUL_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class MulConstFoldPass : public ConstFoldPass { + public: + MulConstFoldPass() : ConstFoldPass(OpT_Mul) {} + + ~MulConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_MUL_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.cc new file mode 100644 index 0000000000..d79b60007c --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/range.h" + +namespace mindspore { +namespace lite { +#define kRangeOutputNum 1 + +STATUS RangeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS RangeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Range; + desc.arch = kCPU; + op = new (std::nothrow) OpRange(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpAdd return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS RangeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kRangeOutputNum) { + MS_LOGE("The number of range for range must be %u, nodeName: %s", kRangeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h new file mode 100644 index 0000000000..e8b48e4004 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_RANGE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_RANGE_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class RangeConstFoldPass : public ConstFoldPass { + public: + RangeConstFoldPass() : ConstFoldPass(OpT_Range) {} + + ~RangeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_RANGE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.cc new file mode 100644 index 0000000000..e40c09adea --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/reshape.h" + +namespace mindspore { +namespace lite { +STATUS ReshapeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ReshapeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Reshape; + desc.arch = kCPU; + op = new (std::nothrow) OpReshape(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpReshape return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpReshape InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpReshape Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ReshapeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpReshape Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kReshapeOutputNum) { + MS_LOGE("The number of output for Reshape must be %u, nodeName: %s", kReshapeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h new file mode 100644 index 0000000000..13a7950395 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_RESHAPE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_RESHAPE_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class ReshapeConstFoldPass : public ConstFoldPass { + public: + ReshapeConstFoldPass() : ConstFoldPass(OpT_Reshape) {} + + ~ReshapeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; + + private: + STATUS CalNewShape(const TensorDefT &inTensor, std::vector &outShape); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_RESHAPE_CONST_FOLD_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.cc new file mode 100644 index 0000000000..849eca570e --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/fp32/rsqrt_fp32.h" + +namespace mindspore { +namespace lite { + +STATUS RsqrtConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS RsqrtConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Rsqrt; + desc.arch = kCPU; + op = new (std::nothrow) RsqrtFp32(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpRsqrt return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpRsqrt InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpRsqrt Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS RsqrtConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpRsqrt Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kRsqrtOutputNum) { + MS_LOGE("The number of output for Rsqrt must be %u, nodeName: %s", kRsqrtOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h new file mode 100644 index 0000000000..7ce1fc1611 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_RSQRT_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_RSQRT_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class RsqrtConstFoldPass : public ConstFoldPass { + public: + RsqrtConstFoldPass() : ConstFoldPass(OpT_Rsqrt) {} + + ~RsqrtConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_RSQRT_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.cc new file mode 100644 index 0000000000..7b69eb0af6 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h" +#include "src/operator/cpu/creator/shape.h" + +namespace mindspore { +namespace lite { +STATUS ShapeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ShapeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Shape; + desc.arch = kCPU; + op = new (std::nothrow) OpShape(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpShape return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpShape InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpShape Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ShapeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpShape Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kShapeOutputNum) { + MS_LOGE("The number of output for shape must be %u, nodeName: %s", kShapeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h new file mode 100644 index 0000000000..7f05a9b9e2 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_SHAPE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_SHAPE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class ShapeConstFoldPass : public ConstFoldPass { + public: + ShapeConstFoldPass() : ConstFoldPass(OpT_Shape) {} + + ~ShapeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_SHAPE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.cc new file mode 100644 index 0000000000..8c3c4ba345 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h" +#include "src/operator/cpu/creator/slice.h" + +namespace mindspore { +namespace lite { +// todo if slice op has placeholder tensor +STATUS SliceConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS SliceConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Slice; + desc.arch = kCPU; + op = new (std::nothrow) OpSlice(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpSlice return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSlice InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSlice Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS SliceConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSlice Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kSliceOutputNum) { + MS_LOGE("The number of output for slice must be %u, nodeName: %s", kSliceOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h new file mode 100644 index 0000000000..c5d7ca3470 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_SLICE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_SLICE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +// This Op only supports 1-4D cases +class SliceConstFoldPass : public ConstFoldPass { + public: + SliceConstFoldPass() : ConstFoldPass(OpT_Slice) {} + + ~SliceConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_SLICE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.cc new file mode 100644 index 0000000000..6cf4d9d060 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h" +#include "src/operator/cpu/creator/stack.h" + +namespace mindspore { +namespace lite { +STATUS StackConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS StackConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Stack; + desc.arch = kCPU; + op = new (std::nothrow) OpStack(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpStack return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStack InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStack Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS StackConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStack Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kStackOutputNum) { + MS_LOGE("The number of output for stack must be %u, nodeName: %s", kStackOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h new file mode 100644 index 0000000000..2f36669616 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_STACK_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_STACK_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "securec/include/securec.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +class StackConstFoldPass : public ConstFoldPass { + public: + StackConstFoldPass() : ConstFoldPass(OpT_Stack) {} + + ~StackConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_STACK_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.cc new file mode 100644 index 0000000000..2a42b1e5b0 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h" +#include "src/operator/cpu/creator/strided_slice.h" + +namespace mindspore { +namespace lite { +STATUS StridedSliceConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS StridedSliceConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Slice; + desc.arch = kCPU; + op = new (std::nothrow) OpStridedSlice(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpStridedSlice return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStridedSlice InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStridedSlice Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS StridedSliceConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStridedSlice Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kStridedSliceOutputNum) { + MS_LOGE("The number of output for slice must be %u, nodeName: %s", kStridedSliceOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h new file mode 100644 index 0000000000..bbb3387141 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_STRIDED_SLICE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_STRIDED_SLICE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +// This Op only supports 1-4D cases +class StridedSliceConstFoldPass : public ConstFoldPass { + public: + StridedSliceConstFoldPass() : ConstFoldPass(OpT_StridedSlice) {} + + ~StridedSliceConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_STRIDED_SLICE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.cc new file mode 100644 index 0000000000..8575c8d4ff --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h" + +#include "utils/log_adapter.h" +#include "converter/common/tensor_util.h" +#include "converter/common/converter_op_utils.h" +#include "src/operator/cpu/creator/sub.h" + +namespace mindspore { +namespace lite { + +STATUS SubConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS SubConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Sub; + desc.arch = kCPU; + MS_ASSERT(inputs.size() == kArithOpInputNum); + auto inTensor0 = inputs.at(kArithOpInputTensorIndex0); + auto inTensor1 = inputs.at(kArithOpInputTensorIndex1); + MS_ASSERT(inTensor0 != nullptr); + MS_ASSERT(inTensor1 != nullptr); + DataType dataType; + if (inTensor0->GetNDim() > 1) { + dataType = inTensor0->GetDataType(); + } else { + dataType = inTensor1->GetDataType(); + } + switch (dataType) { + case DataType_DT_UINT8: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT32: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_FLOAT: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT8: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_UINT32: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + default: { + MS_LOGE("Unsupported dataType: %d", dataType); + return RET_ERROR; + } + } + if (op == nullptr) { + MS_LOGE("new OpSub return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSub InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSub Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS SubConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSub Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kArithOpOutputNum) { + MS_LOGE("The number of output for sub must be %u, nodeName: %s", kArithOpOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h new file mode 100644 index 0000000000..2feecb2954 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_SUB_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_SUB_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class SubConstFoldPass : public ConstFoldPass { + public: + SubConstFoldPass() : ConstFoldPass(OpT_Sub) {} + + ~SubConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_SUB_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.cc new file mode 100644 index 0000000000..efa50c7b14 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/tile.h" + +namespace mindspore { +namespace lite { +STATUS TileConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS TileConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Tile; + desc.arch = kCPU; + op = new (std::nothrow) OpTile(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpTile return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTile InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTile Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS TileConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTile Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kTileOutputNum) { + MS_LOGE("The number of output for tile must be %u, nodeName: %s", kTileOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h new file mode 100644 index 0000000000..df7404b8cd --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TILE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_TILE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +class TileConstFoldPass : public ConstFoldPass { + public: + TileConstFoldPass() : ConstFoldPass(OpT_Tile) {} + + ~TileConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TILE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.cc new file mode 100644 index 0000000000..d76a2ca581 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/transpose.h" + +namespace mindspore { +namespace lite { + +STATUS TransposeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS TransposeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Transpose; + desc.arch = kCPU; + op = new (std::nothrow) OpTranspose(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpTranspose return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTranspose InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTranspose Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS TransposeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTranspose Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kTransposeOutputNum) { + MS_LOGE("The number of output for transpose must be %u, nodeName: %s", kTransposeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h new file mode 100644 index 0000000000..902b564c89 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TRANSPOSE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_TRANSPOSE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class TransposeConstFoldPass : public ConstFoldPass { + public: + TransposeConstFoldPass() : ConstFoldPass(OpT_Transpose) {} + + ~TransposeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TRANSPOSE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt new file mode 100755 index 0000000000..32aa9d4dac --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(fusion_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc + ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_bias_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_bn_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_activation_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_relu_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_relu6_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_biasadd_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc + ) + +target_link_libraries(fusion_mid securec) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc new file mode 100644 index 0000000000..ec1caa520f --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc @@ -0,0 +1,500 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define kBatchNormFoldFusionPathLen6 6 +#define kBatchNormFoldFusionPathLen7 7 + +STATUS BatchNormFoldFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS BatchNormFoldFusionPass::DefinePattern() { + // with preNode + { + auto inputOp = std::make_shared(); + inputOp->id = inputOpName; + inputOp->types = {schema::PrimitiveType_NONE}; + inputOp->isPlaceHold = true; + + auto convOp1 = std::make_shared(); + convOp1->id = convPatternOpName1; + convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + convOp1->left = inputOp; + + auto bnFoldOp = std::make_shared(); + bnFoldOp->id = bnFoldOpName; + bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; + bnFoldOp->left = convOp1; + + auto mulFoldOp = std::make_shared(); + mulFoldOp->id = mulFoldOpName; + mulFoldOp->types = {schema::PrimitiveType_MulFold}; + mulFoldOp->left = bnFoldOp; + + auto fakeQuantOp = std::make_shared(); + fakeQuantOp->id = fakeQuantOpName; + fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; + fakeQuantOp->left = mulFoldOp; + + auto convOp2 = std::make_shared(); + convOp2->id = convPatternOpName2; + convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + convOp2->left = fakeQuantOp; + convOp2->right = inputOp; + + auto addFoldOp = std::make_shared(); + addFoldOp->id = addFoldOpName; + addFoldOp->types = {schema::PrimitiveType_AddFold}; + addFoldOp->left = convOp2; + addFoldOp->right = bnFoldOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(withPrePatternName)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(inputOp); + fusionPattern->AddPatternOp(convOp1); + fusionPattern->AddPatternOp(bnFoldOp); + fusionPattern->AddPatternOp(mulFoldOp); + fusionPattern->AddPatternOp(fakeQuantOp); + fusionPattern->AddPatternOp(convOp2); + fusionPattern->AddPatternOp(addFoldOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + // no preNode + { + auto convOp1 = std::make_shared(); + convOp1->id = convPatternOpName1; + convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + + auto bnFoldOp = std::make_shared(); + bnFoldOp->id = bnFoldOpName; + bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; + bnFoldOp->left = convOp1; + + auto mulFoldOp = std::make_shared(); + mulFoldOp->id = mulFoldOpName; + mulFoldOp->types = {schema::PrimitiveType_MulFold}; + mulFoldOp->left = bnFoldOp; + + auto fakeQuantOp = std::make_shared(); + fakeQuantOp->id = fakeQuantOpName; + fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; + fakeQuantOp->left = mulFoldOp; + + auto convOp2 = std::make_shared(); + convOp2->id = convPatternOpName2; + convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + convOp2->left = fakeQuantOp; + + auto addFoldOp = std::make_shared(); + addFoldOp->id = addFoldOpName; + addFoldOp->types = {schema::PrimitiveType_AddFold}; + addFoldOp->left = convOp2; + addFoldOp->right = bnFoldOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(noPrePatternName)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp1); + fusionPattern->AddPatternOp(bnFoldOp); + fusionPattern->AddPatternOp(mulFoldOp); + fusionPattern->AddPatternOp(fakeQuantOp); + fusionPattern->AddPatternOp(convOp2); + fusionPattern->AddPatternOp(addFoldOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (patternName == withPrePatternName) { + if (matchedPath.size() != kBatchNormFoldFusionPathLen7) { + MS_LOG(ERROR) << "BatchNormFold-Fusion should have seven NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + } else if (patternName == noPrePatternName) { + if (matchedPath.size() != kBatchNormFoldFusionPathLen6) { + MS_LOG(ERROR) << "BatchNormFold-Fusion should have six NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + } + + auto status = FindNodes(graph, matchedPath); + if (status != RET_OK) { + MS_LOG(ERROR) << "FindNodes failed: " << status; + return status; + } + status = CheckPath(graph, matchedPath); + if (status != RET_OK) { + MS_LOG(ERROR) << "CheckPath failed: " << status; + return status; + } + status = FindTensors(); + if (status != RET_OK) { + MS_LOG(ERROR) << "FindTensors failed: " << status; + return status; + } + status = GenNewWeightTensor(); + if (status != RET_OK) { + MS_LOG(ERROR) << "GenNewWeightTensor failed: " << status; + return status; + } + status = GenNewBiasTensor(); + if (status != RET_OK) { + MS_LOG(ERROR) << "GenNewBiasTensor failed: " << status; + return status; + } + status = IsolateNodes(graph, matchedPath); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateNodes failed: " << status; + return status; + } + UpdateConvWeights(); + status = DeleteConstTensors(); + if (status != RET_OK) { + MS_LOG(ERROR) << "DeleteConstTensors failed: " << status; + return status; + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::FindNodes(MetaGraphT *graph, + const std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + auto preConvPath = matchedPath.at(convPatternOpName1); + auto bnFoldPath = matchedPath.at(bnFoldOpName); + auto mulFoldPath = matchedPath.at(mulFoldOpName); + auto fakeQuantPath = matchedPath.at(fakeQuantOpName); + auto convPath = matchedPath.at(convPatternOpName2); + auto addFoldPath = matchedPath.at(addFoldOpName); + MS_ASSERT(preConvPath != nullptr); + MS_ASSERT(bnFoldPath != nullptr); + MS_ASSERT(mulFoldPath != nullptr); + MS_ASSERT(fakeQuantPath != nullptr); + MS_ASSERT(convPath != nullptr); + MS_ASSERT(addFoldPath != nullptr); + if (preConvPath->subGraphIdx != bnFoldPath->subGraphIdx || preConvPath->subGraphIdx != mulFoldPath->subGraphIdx || + preConvPath->subGraphIdx != fakeQuantPath->subGraphIdx || preConvPath->subGraphIdx != convPath->subGraphIdx || + preConvPath->subGraphIdx != addFoldPath->subGraphIdx) { + MS_LOG(ERROR) << "matched nodes should from same subGraph"; + return RET_ERROR; + } + MS_ASSERT(graph->nodes.size() > preConvPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > bnFoldPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > mulFoldPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > fakeQuantPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > convPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > addFoldPath->nodeIdx); + preConv = graph->nodes.at(preConvPath->nodeIdx).get(); + bnFold = graph->nodes.at(bnFoldPath->nodeIdx).get(); + mulFold = graph->nodes.at(mulFoldPath->nodeIdx).get(); + fakeNode = graph->nodes.at(fakeQuantPath->nodeIdx).get(); + convNode = graph->nodes.at(convPath->nodeIdx).get(); + addFold = graph->nodes.at(addFoldPath->nodeIdx).get(); + MS_ASSERT(preConv != nullptr); + MS_ASSERT(bnFold != nullptr); + MS_ASSERT(mulFold != nullptr); + MS_ASSERT(fakeNode != nullptr); + MS_ASSERT(convNode != nullptr); + MS_ASSERT(addFold != nullptr); + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::FindTensors() { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnFold != nullptr); + MS_ASSERT(addFold != nullptr); + if (bnFold->inputIndex.size() != 4) { + MS_LOG(ERROR) << "BatchNormFold node should have 4 inputTensor, got " << bnFold->inputIndex.size() + << " input tensors"; + return RET_ERROR; + } + if (addFold->inputIndex.size() != 5) { + MS_LOG(ERROR) << "AddFold node should have 5 inputTensor, got " << addFold->inputIndex.size() << " input tensors"; + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(1)); + muTensor = graph->allTensors.at(bnFold->inputIndex.at(1)).get(); + MS_ASSERT(muTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(2)); + sigmaTensor = graph->allTensors.at(bnFold->inputIndex.at(2)).get(); + MS_ASSERT(sigmaTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(1)); + betaTensor = graph->allTensors.at(addFold->inputIndex.at(1)).get(); + MS_ASSERT(betaTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(2)); + gammaTensor = graph->allTensors.at(addFold->inputIndex.at(2)).get(); + MS_ASSERT(gammaTensor != nullptr); + + if (betaTensor->dims.size() != 1) { + MS_LOG(ERROR) << "ConstTensor should have only one dim, got " << betaTensor->dims.size(); + return RET_ERROR; + } + if (betaTensor->dims != gammaTensor->dims || betaTensor->dims != sigmaTensor->dims || + betaTensor->dims != muTensor->dims) { + MS_LOG(ERROR) << "All ConstTensor should have same dims"; + return RET_ERROR; + } + channelOut = betaTensor->dims.front(); + + MS_ASSERT(mulFold != nullptr); + if (mulFold->inputIndex.size() != 3) { + MS_LOG(ERROR) << "MulFold node should have 3 outputTensor, got " << addFold->inputIndex.size() << " output tensors"; + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > mulFold->inputIndex.front()); + oldWeightTensor = graph->allTensors.at(mulFold->inputIndex.front()).get(); + MS_ASSERT(oldWeightTensor != nullptr); + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::CheckPath(MetaGraphT *graph, + const std::unordered_map> &matchedPath) { + MS_ASSERT(preConv != nullptr); + MS_ASSERT(convNode != nullptr); + MS_ASSERT(mulFold != nullptr); + MS_ASSERT(preConv->inputIndex.size() == 2); + MS_ASSERT(convNode->inputIndex.size() == 2); + MS_ASSERT(mulFold->inputIndex.size() == 3); + MS_ASSERT(preConv->inputIndex.front() == convNode->inputIndex.front()); + MS_ASSERT(preConv->inputIndex.at(1) == mulFold->inputIndex.front()); + // todo + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::GenNewWeightTensor() { + MS_ASSERT(oldWeightTensor != nullptr); + MS_ASSERT(oldWeightTensor->dataType == DataType_DT_FLOAT); + MS_ASSERT(oldWeightTensor->refCount == schema::NodeType_ValueNode); + auto weightShape = oldWeightTensor->dims; + if (weightShape.size() != 4) { + MS_LOG(ERROR) << "shape of weight should be 4 dims, got " << weightShape.size() << " dims"; + return RET_ERROR; + } + if (weightShape.front() != channelOut) { + MS_LOG(ERROR) << "weight should be in KCHW format, and outputChannel should be " << channelOut; + return RET_ERROR; + } + auto weightShapeSize = GetShapeSize(*oldWeightTensor); + newWeightTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (newWeightTensor == nullptr) { + MS_LOG(ERROR) << "new weightTensor failed"; + return RET_ERROR; + } + newWeightTensor->dataType = oldWeightTensor->dataType; + newWeightTensor->format = oldWeightTensor->format; + newWeightTensor->refCount = schema::NodeType_ValueNode; + newWeightTensor->dims = weightShape; + newWeightTensor->data.resize(weightShapeSize * sizeof(float)); + void *oldWeightData = oldWeightTensor->data.data(); + auto castedOldWeightData = static_cast(oldWeightData); + void *newWeightData = newWeightTensor->data.data(); + auto castedNewWeightData = static_cast(newWeightData); + MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); + void *gammaData = gammaTensor->data.data(); + auto *castedGammaData = static_cast(gammaData); + MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); + void *miData = muTensor->data.data(); + auto *castedMiData = static_cast(miData); + size_t stride = weightShapeSize / channelOut; + for (size_t i = 0; i < channelOut; i++) { + for (size_t j = 0; j < stride; j++) { + castedNewWeightData[i * stride + j] = castedOldWeightData[i * stride + j] * castedGammaData[i] / castedMiData[i]; + } + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::GenNewBiasTensor() { // bias has no quant + std::vector biasShape = {channelOut}; + newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (newBiasTensor == nullptr) { + MS_LOG(ERROR) << "new BiasTensor failed"; + return RET_ERROR; + } + newBiasTensor->dataType = 0; // todo is float + newBiasTensor->format = Format_NUM_OF_FORMAT; + newBiasTensor->refCount = schema::NodeType_ValueNode; + newBiasTensor->dims = biasShape; + newBiasTensor->data.resize(channelOut * sizeof(float)); + void *newBiasData = newBiasTensor->data.data(); + auto castedNewBiasData = static_cast(newBiasData); + MS_ASSERT(betaTensor->dataType == DataType_DT_FLOAT); + void *betaData = betaTensor->data.data(); + auto *castedBetaData = static_cast(betaData); + MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); + void *gammaData = gammaTensor->data.data(); + auto *castedGammaData = static_cast(gammaData); + MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); + void *miData = muTensor->data.data(); + auto *castedMiData = static_cast(miData); + MS_ASSERT(sigmaTensor->dataType == DataType_DT_FLOAT); + void *sigmaData = sigmaTensor->data.data(); + auto *castedSigmaData = static_cast(sigmaData); + for (size_t i = 0; i < channelOut; i++) { + castedNewBiasData[i] = castedBetaData[i] - castedGammaData[i] * castedMiData[i] / castedSigmaData[i]; + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::IsolateNodes( + MetaGraphT *graph, const std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + auto preConvPath = matchedPath.at(convPatternOpName1); + auto bnFoldPath = matchedPath.at(bnFoldOpName); + auto mulFoldPath = matchedPath.at(mulFoldOpName); + auto fakeQuantPath = matchedPath.at(fakeQuantOpName); + auto convPath = matchedPath.at(convPatternOpName2); + auto addFoldPath = matchedPath.at(addFoldOpName); + MS_ASSERT(preConvPath != nullptr); + MS_ASSERT(bnFoldPath != nullptr); + MS_ASSERT(mulFoldPath != nullptr); + MS_ASSERT(fakeQuantPath != nullptr); + MS_ASSERT(convPath != nullptr); + MS_ASSERT(addFoldPath != nullptr); + auto status = IsolateOneWayNode(graph, preConvPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << preConv->name.c_str() << " failed, error: " << status; + return status; + } + std::vector toDeleteTensorIdxes; + toDeleteTensorIdxes.emplace_back(bnFold->inputIndex.at(3)); + toDeleteTensorIdxes.insert(toDeleteTensorIdxes.end(), bnFold->outputIndex.begin(), bnFold->outputIndex.end()); + status = RemoveTensor(graph, toDeleteTensorIdxes, true); + if (status != RET_OK) { + MS_LOG(ERROR) << "Remove Tensors of BnFold " << bnFold->name.c_str() << " failed, error: " << status; + return RET_ERROR; + } + status = IsolateOneWayNode(graph, bnFoldPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << bnFold->name.c_str() << " failed, error: " << status; + return status; + } + status = IsolateOneWayNode(graph, mulFoldPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << mulFold->name.c_str() << " failed, error: " << status; + return status; + } + status = IsolateOneWayNode(graph, addFoldPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << addFold->name.c_str() << " failed, error: " << status; + return status; + } + return RET_OK; +} + +void BatchNormFoldFusionPass::UpdateConvWeights() { + MS_ASSERT(graph != nullptr); + MS_ASSERT(convNode != nullptr); + MS_ASSERT(newWeightTensor != nullptr); + MS_ASSERT(newBiasTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > fakeNode->inputIndex.at(0)); + graph->allTensors.at(fakeNode->inputIndex.at(0)).reset(); + graph->allTensors.at(fakeNode->inputIndex.at(0)) = std::move(this->newWeightTensor); + graph->allTensors.emplace_back(std::move(this->newBiasTensor)); + convNode->inputIndex.emplace_back(graph->allTensors.size() - 1); + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; + } else { + MS_ASSERT(false); + } + + this->oldWeightTensor = nullptr; + this->newWeightTensor = nullptr; + this->newBiasTensor = nullptr; +} + +STATUS BatchNormFoldFusionPass::DeleteConstTensors() { + MS_ASSERT(graph != nullptr); + bool muFind = false; + bool sigmaFind = false; + bool betaFind = false; + bool gammaFind = false; + std::vector toDeleteTensorIdxes; + for (size_t i = 0; i < graph->allTensors.size(); i++) { + auto &tensor = graph->allTensors.at(i); + if (tensor.get() == muTensor) { + toDeleteTensorIdxes.emplace_back(i); + muFind = true; + this->muTensor = nullptr; + } + if (tensor.get() == sigmaTensor) { + toDeleteTensorIdxes.emplace_back(i); + sigmaFind = true; + this->sigmaTensor = nullptr; + } + if (tensor.get() == gammaTensor) { + toDeleteTensorIdxes.emplace_back(i); + gammaFind = true; + this->gammaTensor = nullptr; + } + if (tensor.get() == betaTensor) { + toDeleteTensorIdxes.emplace_back(i); + betaFind = true; + this->betaTensor = nullptr; + } + } + if (!muFind || !sigmaFind || !betaFind || !gammaFind) { + MS_LOG(ERROR) << "Can not find muTensor or sigmaTensor or betaTensor or gammaTensor in graph"; + return RET_ERROR; + } + auto status = RemoveTensor(graph, toDeleteTensorIdxes); + if (status != RET_OK) { + MS_LOG(ERROR) << "Remove ConstTensors failed" << bnFold->name.c_str(); + return RET_ERROR; + } + return RET_OK; +} + +BatchNormFoldFusionPass::~BatchNormFoldFusionPass() { + if (newWeightTensor == nullptr) { + newWeightTensor.reset(); + newWeightTensor = nullptr; + } + if (newBiasTensor == nullptr) { + newBiasTensor.reset(); + newBiasTensor = nullptr; + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h new file mode 100644 index 0000000000..3fb2ab8ce0 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H +#define MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +// input = input +// weight = SimQuantPerChannel(weight * gamma / sigma) +// bias = beta - gamma * mi / sigma +// MulFold: gamma sigma +// BatchNormFold: mi sigma +// AddFold: gamma beta mi sigma +class BatchNormFoldFusionPass : public FusionPass { + public: + BatchNormFoldFusionPass() = default; + + ~BatchNormFoldFusionPass() override; + + STATUS DefinePattern() override; + + STATUS DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(MetaGraphT *graph) override; + + protected: + STATUS FindNodes(MetaGraphT *graph, const std::unordered_map> &matchedPath); + STATUS CheckPath(MetaGraphT *graph, const std::unordered_map> &matchedPath); + STATUS FindTensors(); + STATUS GenNewWeightTensor(); + STATUS GenNewBiasTensor(); + STATUS IsolateNodes(MetaGraphT *graph, const std::unordered_map> &matchedPath); + void UpdateConvWeights(); + STATUS DeleteConstTensors(); + + protected: + MetaGraphT *graph = nullptr; + CNodeT *preConv = nullptr; + CNodeT *bnFold = nullptr; + CNodeT *mulFold = nullptr; + CNodeT *fakeNode = nullptr; + CNodeT *convNode = nullptr; + CNodeT *addFold = nullptr; + TensorT *muTensor = nullptr; + TensorT *sigmaTensor = nullptr; + TensorT *gammaTensor = nullptr; + TensorT *betaTensor = nullptr; + TensorT *oldWeightTensor = nullptr; + int32_t channelOut = 0; + + std::unique_ptr newWeightTensor = nullptr; + std::unique_ptr newBiasTensor = nullptr; + + std::string inputOpName = "Input"; + std::string convPatternOpName1 = "Convolution1"; + std::string bnFoldOpName = "BatchNormFold"; + std::string mulFoldOpName = "MulFold"; + std::string fakeQuantOpName = "FakeQuant"; + std::string convPatternOpName2 = "Convolution2"; + std::string addFoldOpName = "AddFold"; + std::string withPrePatternName = "BNFoldFusionWithPre"; + std::string noPrePatternName = "BNFoldFusionNoPre"; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc new file mode 100644 index 0000000000..d132aa8981 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "tools/common/graph_util.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define CONV_ACTIVATION_MATCH_PATH_LEN 2 + +STATUS ConvActivationFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + auto actOp = std::make_shared(); + actOp->id = ACTIVATION_NAME; + actOp->types = {schema::PrimitiveType_Activation}; + actOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvActivationFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(actOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +// 1. change attr of conv +// 2. delete Activation node +STATUS ConvActivationFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != CONV_ACTIVATION_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "Conv-Activation-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto convPath = matchedPath[kConvName]; + auto actPath = matchedPath[ACTIVATION_NAME]; + auto &convNode = graph->nodes.at(convPath->nodeIdx); + auto &actNode = graph->nodes.at(actPath->nodeIdx); + + // todo if combine conv_relu_fusion and conv_relu6_fusion to conv_activation_fusion + if (actNode->primitive->value.AsActivation()->type != this->activationType) { + return RET_NO_CHANGE; + } + + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->activationType = this->activationType; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->activationType = this->activationType; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + + // remove activation node + MergeNodeAttrFromPost(convNode, actNode); + auto status = IsolateOneWayNode(graph, actPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << actPath->subGraphIdx << ", node: " << actPath->nodeIdx + << ", error: " << status; + return status; + } + + return RET_OK; +} + +STATUS ConvActivationFusionPass::Run(schema::MetaGraphT *graph) { + SetActivationType(); + return FusionPass::Run(graph); +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h new file mode 100644 index 0000000000..b760b503d5 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +class ConvActivationFusionPass : public FusionPass { + public: + ConvActivationFusionPass() = default; + + ~ConvActivationFusionPass() override = default; + + STATUS DefinePattern() override; + + virtual STATUS SetActivationType() = 0; + + // 1. change attr of conv + // 2. delete Activation node + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + schema::ActivationType activationType = schema::ActivationType_RELU; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc new file mode 100644 index 0000000000..093af0ed54 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc @@ -0,0 +1,295 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +// #include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define CONV_BIASADD_MATCH_PATH_LEN 2 +#define BIASADD_OP_BIAS_INDEX_IN_WEIGHT 0 +#define BIASADD_OP_INPUT_NUM 2 +#define BIASADD_OP_CONST_TENSOR_INDEX 1 + +STATUS ConvBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS ConvBiasAddFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeConv2D}; + auto baOp = std::make_shared(); + baOp->id = BIASADD_NAME; + baOp->types = {schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Add}; + baOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvBiasAddFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(baOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != CONV_BIASADD_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "Conv-BiasAdd-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto convPath = matchedPath[kConvName]; + auto baPath = matchedPath[BIASADD_NAME]; + auto &convNode = graph->nodes.at(convPath->nodeIdx); + auto &baNode = graph->nodes.at(baPath->nodeIdx); + // add/biasadd node the second tensor is not constant tensor, don't fusion + auto baNodeInputIndex = baNode->inputIndex; + if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! "; + return RET_ERROR; + } + auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); + MS_ASSERT(baNodeBiasTensor != nullptr); + if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) { + // dont fusion, return + return RET_OK; + } + + // 1. generate newBiasTensor for conv + auto status = GenConvBiasTensor(convPath, baPath, graph); + if (RET_OK != status) { + MS_LOG(ERROR) << "GenConvBiasTensor failed, " << status; + return status; + } + if (this->newBiasTensor != nullptr) { + status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); + this->newBiasTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "AddTensor2Node failed, node: " << convPath->nodeIdx << ", error: " << status; + return status; + } + // add bias quantParam + // todo add quantParam for tensors + + // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { + // std::unique_ptr quantParamArray(new QuantParamArrayT()); + // if (quantParamArray == nullptr) { + // MS_LOG(ERROR) << "new QuantParamArrayT failed"); + // return RET_ERROR; + // } + // std::unique_ptr quantParam(new QuantParamT()); + // if (quantParam == nullptr) { + // MS_LOG(ERROR) << "new QuantParamT failed"); + // return RET_ERROR; + // } + // quantParam->numBits = -1; + // quantParam->scale = FLT_MAX; + // quantParam->zeroPoint = 0; + // quantParam->narrowRange = true; + // quantParam->min = FLT_MAX; + // quantParam->max = FLT_MAX; + // quantParamArray->param.emplace_back(quantParam.release()); + // convNode->quantParam.emplace_back(quantParamArray.release()); + // } + } + + // 2. change attr of conv + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) { + convNode->primitive->value.AsDeConv2D()->hasBias = true; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + + // 5. delete BiasAdd node + MergeNodeAttrFromPost(convNode, baNode); + status = IsolateOneWayNode(graph, baPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, graph: %zu, node: %zu, error: %d"; + //, baPath->subGraphIdx, baPath->nodeIdx, status); + return status; + } + + return RET_OK; +} + +#define BIASADD_WEIGHT_SHAPE_SIZE 1 +#define BIASADD_BIAS_DIM_INDEX 0 + +STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr convPath, std::shared_ptr baPath, + MetaGraphT *graph) { + MS_ASSERT(convPath != nullptr); + MS_ASSERT(baPath != nullptr); + MS_ASSERT(graph != nullptr); + + auto convNode = graph->nodes.at(convPath->nodeIdx).get(); + MS_ASSERT(convNode != nullptr); + auto baNode = graph->nodes.at(baPath->nodeIdx).get(); + MS_ASSERT(baNode != nullptr); + int32_t kernelNum = 0; + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + kernelNum = convNode->primitive->value.AsConv2D()->channelOut; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelIn * + convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) { + kernelNum = convNode->primitive->value.AsDeConv2D()->channelOut; + } + auto convWeightTensorIdxes = convNode->inputIndex; + if (convWeightTensorIdxes.size() < CONV_OP_NO_BIAS_INPUT_NUM) { + MS_LOG(ERROR) << convNode->name.c_str() << " node tensors number is invalid! "; + return RET_ERROR; + } + convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); + auto baWeightTensorIdxes = baNode->inputIndex; + if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! "; + return RET_ERROR; + } + baWeightTensorIdxes.erase(baWeightTensorIdxes.begin()); + + if (convWeightTensorIdxes.empty()) { + MS_LOG(ERROR) << "Conv2D should has one weight tensors at least, current number of weight tensors " + << convWeightTensorIdxes.size(); + return RET_ERROR; + } + + if (baWeightTensorIdxes.empty()) { + MS_LOG(ERROR) << "BiasAdd should has one weight tensors at least, current number of weight tensors " + << baWeightTensorIdxes.size(); + return RET_ERROR; + } + + TensorT *oldBiasTensor = nullptr; + TensorT *biasTensor = nullptr; + + if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { + oldBiasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); + MS_ASSERT(oldBiasTensor != nullptr); + } + biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX_IN_WEIGHT)).get(); + MS_ASSERT(biasTensor != nullptr); + auto biasDims = biasTensor->dims; + // if biasTensor is a scaler + if (biasDims.empty() && biasTensor->data.data() == nullptr) { + MS_LOG(ERROR) << "BiasAdd node %s bias tensor is invalid" << baNode->name.c_str(); + return RET_ERROR; + } + if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) { + MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size() + << ". or bias tensor is a scaler"; + return RET_ERROR; + } + + bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1; + if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { + MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" + << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; + return RET_ERROR; + } + + // cal new biasData + this->newBiasData = new (std::nothrow) float[kernelNum]; + if (newBiasData == nullptr) { + MS_LOG(ERROR) << "new newBiasData failed"; + return RET_ERROR; + } + + if (biasDims.empty() && biasTensor->data.data() != nullptr) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + if (0 != memset_s(newBiasData, kernelNum * sizeof(float), *biasData, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset_s newBiasData failed"; + return RET_ERROR; + } + } else if (bias_const) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + for (size_t i = 0; i < kernelNum; i++) { + newBiasData[i] = *biasData; + } + } else { + if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s newBiasData failed"; + return RET_ERROR; + } + } + if (oldBiasTensor != nullptr) { + auto oldBiasDims = oldBiasTensor->dims; + if (oldBiasDims.size() != 1) { + MS_LOG(ERROR) + << "Conv bias tensor should has one dimension, current number of dimension %zu"; // oldBiasDims.size()); + return RET_ERROR; + } + if (oldBiasDims.at(0) != kernelNum) { + MS_LOG(ERROR) + << "Size(%zu) of Conv bias tensor should be equal to kernelNum(%d), current number of dimension %zu"; + // oldBiasDims.size(), kernelNum); + return RET_ERROR; + } + auto *oldBiasData = reinterpret_cast(oldBiasTensor->data.data()); + for (size_t i = 0; i < kernelNum; i++) { + oldBiasData[i] += newBiasData[i]; + } + } else { + auto *newCharBiasData = reinterpret_cast(newBiasData); + std::vector tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); + + auto weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); + this->newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); + // todo biasShape + this->newBiasTensor->dims = {kernelNum}; + this->newBiasTensor->dataType = weightTensor->dataType; + this->newBiasTensor->format = weightTensor->format; + this->newBiasTensor->refCount = weightTensor->refCount; + this->newBiasTensor->data.swap(tmpBiasVec); + newCharBiasData = nullptr; + } + + delete (this->newBiasData); + newBiasData = nullptr; + + return RET_OK; +} + +ConvBiasAddFusionPass::~ConvBiasAddFusionPass() { + if (this->newBiasData != nullptr) { + delete (this->newBiasData); + } +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h new file mode 100644 index 0000000000..1f104af2c3 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +class ConvBiasAddFusionPass : public FusionPass { + public: + ConvBiasAddFusionPass() = default; + + ~ConvBiasAddFusionPass() override; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + // gen this->newBiasTensor if conv has no bias before + STATUS GenConvBiasTensor(std::shared_ptr convPath, std::shared_ptr dstPath, schema::MetaGraphT *graph); + + protected: + float *newBiasData = nullptr; + std::unique_ptr newBiasTensor = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc new file mode 100644 index 0000000000..ae63da7a79 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" +#include "securec/include/securec.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +#define CAFFE_BATCHNORM_OP_WEIGHT_NUM 2 +#define TF_BATCHNORM_OP_WEIGHT_NUM 4 +#define CAFFE_BATCHNORM_MEAN_INDEX 0 +#define CAFFE_BATCHNORM_VARIANCE_INDEX 1 +#define TF_BATCHNORM_SCALE_INDEX 0 +#define TF_BATCHNORM_BIAS_INDEX 1 +#define TF_BATCHNORM_MEAN_INDEX 2 +#define TF_BATCHNORM_VARIANCE_INDEX 3 + +constexpr const float EPS = 1e-8; +constexpr const float EPS_DEFAULT_FLOAT = 1e-5; +constexpr const float POW_NUM = 0.5; + +STATUS ConvBNFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvBNFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + auto bnOp = std::make_shared(); + bnOp->id = DST_NAME; + bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm}; + bnOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvBatchNormFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(bnOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS ConvBNFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); } + +STATUS ConvBNFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnPath != nullptr); + + BNWeightTensors bnWeightTensors; + + auto status = GetBnWeightTensors(graph, bnPath, kernelNum, bnWeightTensors); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetBnWeightTensors error " << status; + return status; + } + schema::TensorT *meanTensor = bnWeightTensors.meanTensor; + schema::TensorT *varianceTensor = bnWeightTensors.varianceTensor; + schema::TensorT *scaleTensor = bnWeightTensors.scaleTensor; + schema::TensorT *biasTensor = bnWeightTensors.biasTensor; + + auto *meanData = reinterpret_cast(meanTensor->data.data()); + auto *varianceData = reinterpret_cast(varianceTensor->data.data()); + + float eps = EPS_DEFAULT_FLOAT; + status = GetBnEpsilon(graph, bnPath, eps); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetBnEpsilon failed " << status; + return status; + } + + // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) + if (memcpy_s(transScale, kernelNum * sizeof(float), varianceData, kernelNum * sizeof(float)) != 0) { + MS_LOG(ERROR) << "memcpy_s transScale error"; + return RET_ERROR; + } + // 1/sqrt(variance + eps) + for (int32_t i = 0; i < kernelNum; i++) { + float tmp = transScale[i] + eps; + tmp = pow(tmp, POW_NUM); + transScale[i] = 1 / tmp; + } + + if (scaleTensor != nullptr) { + auto *scaleData = reinterpret_cast(scaleTensor->data.data()); + // scale/sqrt(variance + eps) + for (int32_t i = 0; i < kernelNum; i++) { + transScale[i] *= scaleData[i]; + } + } + + // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps) + // -mean/sqrt(variance + eps) + for (int32_t i = 0; i < kernelNum; i++) { + transBias[i] = -meanData[i] * transScale[i]; + } + + if (biasTensor != nullptr) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + // -scale*mean/sqrt(variance + eps) + bias + for (int32_t i = 0; i < kernelNum; i++) { + transBias[i] += biasData[i]; + } + } + + return RET_OK; +} + +// BatchNorm weight Tensor definition: +// caffe +// estimated_mean --0 +// estimated_variance --1 +// tensorflow +// scale -- 0 +// bias --1 +// estimated_mean --2 +// estimated_variance --3 +STATUS ConvBNFusionPass::GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum, + BNWeightTensors &bnWeightTensors) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnPath != nullptr); + auto bnNode = graph->nodes.at(bnPath->nodeIdx).get(); + auto bnWeightTensorIdxes = bnNode->inputIndex; + bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin()); + if (bnWeightTensorIdxes.size() == CAFFE_BATCHNORM_OP_WEIGHT_NUM) { + bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get(); + bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get(); + } else if (bnWeightTensorIdxes.size() == TF_BATCHNORM_OP_WEIGHT_NUM) { + bnWeightTensors.scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_SCALE_INDEX]).get(); + bnWeightTensors.biasTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_BIAS_INDEX]).get(); + bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_MEAN_INDEX]).get(); + bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_VARIANCE_INDEX]).get(); + } else { + MS_LOG(ERROR) << "BatchNorm should has " << CAFFE_BATCHNORM_OP_WEIGHT_NUM << " or " << TF_BATCHNORM_OP_WEIGHT_NUM + << " weight tensors, current number of weight tensors " << bnWeightTensorIdxes.size(); + return RET_ERROR; + } + + if (bnWeightTensors.meanTensor == nullptr) { + MS_LOG(ERROR) << "BatchNorm's mean tensor is nullptr"; + return RET_ERROR; + } + + if (bnWeightTensors.varianceTensor == nullptr) { + MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr"; + return RET_ERROR; + } + + if (kernelNum != bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + + if (kernelNum != bnWeightTensors.varianceTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + + if (bnWeightTensors.scaleTensor != nullptr) { + if (kernelNum != bnWeightTensors.scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + } + + if (bnWeightTensors.biasTensor != nullptr) { + if (kernelNum != bnWeightTensors.biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS ConvBNFusionPass::GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr bnPath, float &eps) { + MS_ASSERT(graph != nullptr); + auto bnNode = graph->nodes.at(bnPath->nodeIdx).get(); + MS_ASSERT(bnNode != nullptr); + if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) { + eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon; + } else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) { + eps = bnNode->primitive->value.AsBatchNorm()->epsilon; + } else { + MS_LOG(ERROR) << "match pattern has error, " << bnNode->name.c_str() << " not BatchNorm node"; + return RET_ERROR; + } + + if (eps < EPS) { + eps = EPS_DEFAULT_FLOAT; + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h new file mode 100644 index 0000000000..b7eb7c1d26 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#ifndef MINDSPORE_CONV_BN_FUSION_PASS_H +#define MINDSPORE_CONV_BN_FUSION_PASS_H + +#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" + +namespace mindspore { +namespace lite { +class ConvBNFusionPass : public ConvScaleBiasFusionPass { + public: + ConvBNFusionPass() = default; + + ~ConvBNFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum) override; + + // Get and check BNNode weight tensor + STATUS GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum, + BNWeightTensors &bnWeightTensors); + + STATUS GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr bnPath, float &eps); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CONV_BN_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc new file mode 100644 index 0000000000..8c22b52772 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS ConvRelu6FusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); } + +STATUS ConvRelu6FusionPass::SetActivationType() { + this->activationType = ActivationType_RELU6; + return RET_OK; +} + +STATUS ConvRelu6FusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvRelu6FusionPass::Run(MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); } + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h new file mode 100644 index 0000000000..a500633351 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H + +#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +class ConvRelu6FusionPass : public ConvActivationFusionPass { + public: + ConvRelu6FusionPass() = default; + + ~ConvRelu6FusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS SetActivationType() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc new file mode 100644 index 0000000000..05a7880bd4 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS ConvReluFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvReluFusionPass::Run(schema::MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); } + +STATUS ConvReluFusionPass::SetActivationType() { + this->activationType = schema::ActivationType_RELU; + return RET_OK; +} + +STATUS ConvReluFusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h new file mode 100644 index 0000000000..e7c87cd197 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H + +#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +class ConvReluFusionPass : public ConvActivationFusionPass { + public: + ConvReluFusionPass() = default; + + ~ConvReluFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS SetActivationType() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.cc new file mode 100644 index 0000000000..618fd9c5ac --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.cc @@ -0,0 +1,361 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. + * Description: mslite + * Author: mslite + * Create: 2019-12-13 + */ + +#include +#include +#include +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" +#include "securec/include/securec.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { + +#define CONV_SCALE_BIAS_MATCH_PATH_LEN 2 + +// 1. generate biasTensor according to BN weightTensor +// 2. change attr of conv +// 3. delete BN node +STATUS ConvScaleBiasFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != CONV_SCALE_BIAS_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "Conv-Scale-Bias-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto convPath = matchedPath[kConvName]; + MS_ASSERT(convPath != nullptr); + auto dstPath = matchedPath[DST_NAME]; + MS_ASSERT(dstPath != nullptr); + MS_ASSERT(subGraph != nullptr); + auto &convNode = graph->nodes.at(convPath->nodeIdx); + MS_ASSERT(convNode != nullptr); + auto &dstNode = graph->nodes.at(dstPath->nodeIdx); + MS_ASSERT(dstNode != nullptr); + + // 1. generate new weightTensor and biasTensor for conv + auto status = GenConvWeightTensors(graph, convPath, dstPath); + if (RET_OK != status) { + MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; + return status; + } + if (convNode->inputIndex.size() == CONV_OP_HAS_BIAS_INPUT_NUM) { + status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), + std::move(this->newWeightTensor)); + this->newWeightTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_BIAS_INDEX_IN_INPUT), + std::move(this->newBiasTensor)); + this->newBiasTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + } else if (convNode->inputIndex.size() == CONV_OP_NO_BIAS_INPUT_NUM) { + status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), + std::move(this->newWeightTensor)); + this->newWeightTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); + this->newBiasTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + // if (convNode->name == "Conv_461") { + // } + // add bias quantParam + // todo use tensor quant param + // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { + // std::unique_ptr quantParamArray(new QuantParamArrayT()); + // if (quantParamArray == nullptr) { + // MS_LOG(ERROR) << "new QuantParamArrayT failed"; + // return RET_ERROR; + // } + // std::unique_ptr quantParam(new QuantParamT()); + // if (quantParam == nullptr) { + // MS_LOG(ERROR) << "new QuantParamT failed"; + // return RET_ERROR; + // } + // quantParam->numBits = -1; + // quantParam->scale = FLT_MAX; + // quantParam->zeroPoint = 0; + // quantParam->narrowRange = true; + // quantParam->min = FLT_MAX; + // quantParam->max = FLT_MAX; + // quantParamArray->param.emplace_back(quantParam.release()); + // convNode->quantParam.emplace_back(quantParamArray.release()); + // } + } else { + MS_LOG(ERROR) << "Conv node should has 2 or 3 weight tensors rather than " << convNode->inputIndex.size(); + return RET_ERROR; + } + + // 2. change attr of conv + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + + // 3. delete DST node + MergeNodeAttrFromPost(convNode, dstNode); + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstPath->nodeIdx << ", error: " << status; + return status; + } + + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, + std::shared_ptr dstPath) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(convPath != nullptr); + MS_ASSERT(dstPath != nullptr); + MS_ASSERT(subGraph != nullptr); + auto &convNode = graph->nodes.at(convPath->nodeIdx); + MS_ASSERT(convNode != nullptr); + int32_t kernelNum = -1; + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + kernelNum = convNode->primitive->value.AsConv2D()->channelOut; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier * + convNode->primitive->value.AsDepthwiseConv2D()->channelIn; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + if (kernelNum <= 0) { + MS_LOG(ERROR) << "KernelNum should be positive, " << kernelNum; + return RET_ERROR; + } + + this->transScale = new (std::nothrow) float[kernelNum]; + this->transBias = new (std::nothrow) float[kernelNum]; + + if (transScale == nullptr) { + MS_LOG(ERROR) << "new transScale failed"; + return RET_ERROR; + } + + if (transBias == nullptr) { + MS_LOG(ERROR) << "new transBias failed"; + return RET_ERROR; + } + + if (0 != memset_s(transScale, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset transScale failed"; + return RET_ERROR; + } + + if (0 != memset_s(transBias, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset transBias failed"; + return RET_ERROR; + } + + auto status = GetTransParam(graph, dstPath, kernelNum); + if (RET_OK != status) { + MS_LOG(ERROR) << "GetTransParam failed, " << status; + return status; + } + + status = CalConvWeightTensors(graph, convPath, kernelNum); + if (RET_OK != status) { + MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; + return status; + } + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::CalNewWeightTensor(TensorT *oldWeightTensor, const int32_t kernelNum, + const size_t kernelSize) { + MS_ASSERT(oldWeightTensor != nullptr); + auto weightData = reinterpret_cast(oldWeightTensor->data.data()); + size_t kernelDataCount = kernelNum * kernelSize; + if (kernelDataCount == 0) { + MS_LOG(ERROR) << "KernelDataCount should be positive, " << kernelDataCount; + return RET_ERROR; + } + this->newWeightData = new (std::nothrow) float[kernelDataCount]; + if (newWeightData == nullptr) { + MS_LOG(ERROR) << "new newWeightData failed"; + return RET_ERROR; + } + + if (0 != memset_s(newWeightData, kernelDataCount * sizeof(float), 0, kernelDataCount * sizeof(float))) { + MS_LOG(ERROR) << "memset newWeightData failed"; + return RET_ERROR; + } + + for (size_t i = 0; i < kernelNum; i++) { + for (size_t j = 0; j < kernelSize; j++) { + newWeightData[i * kernelSize + j] = weightData[i * kernelSize + j] * transScale[i]; + } + } + auto newCharWeightData = reinterpret_cast(newWeightData); + std::vector tmpWeightVec(newCharWeightData, + newCharWeightData + kernelDataCount * sizeof(float) / sizeof(uint8_t)); + + this->newWeightTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (this->newWeightTensor == nullptr) { + MS_LOG(ERROR) << "new newWeightTensor failed"; + return RET_ERROR; + } + this->newWeightTensor->dims.insert(this->newWeightTensor->dims.begin(), oldWeightTensor->dims.begin(), + oldWeightTensor->dims.end()); + this->newWeightTensor->dataType = oldWeightTensor->dataType; + this->newWeightTensor->format = oldWeightTensor->format; + this->newWeightTensor->refCount = oldWeightTensor->refCount; + this->newWeightTensor->data.swap(tmpWeightVec); + delete (this->newWeightData); + newWeightData = nullptr; + + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::CalNewBiasTensor(TensorT *oldWeightTensor, TensorT *oldBiasTensor, + const int32_t kernelNum) { + MS_ASSERT(oldWeightTensor != nullptr); + this->newBiasData = new (std::nothrow) float[kernelNum]; + if (newBiasData == nullptr) { + MS_LOG(ERROR) << "new newBiasData failed"; + return RET_ERROR; + } + if (0 != memset_s(newBiasData, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset newBiasData failed"; + return RET_ERROR; + } + + if (oldBiasTensor != nullptr) { + auto *biasData = reinterpret_cast(oldBiasTensor->data.data()); + + for (size_t i = 0; i < kernelNum; i++) { + this->newBiasData[i] = biasData[i] * transScale[i] + transBias[i]; + } + } else { + if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), transBias, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s newBiasData failed"; + return RET_ERROR; + } + } + auto *newCharBiasData = reinterpret_cast(newBiasData); + std::vector tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); + + this->newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (this->newBiasTensor == nullptr) { + MS_LOG(ERROR) << "new newBiasTensor failed"; + return RET_ERROR; + } + // todo biasShape + this->newBiasTensor->dims = {kernelNum}; + this->newBiasTensor->dataType = oldWeightTensor->dataType; + this->newBiasTensor->format = oldWeightTensor->format; + this->newBiasTensor->refCount = oldWeightTensor->refCount; + this->newBiasTensor->data.swap(tmpBiasVec); + delete (this->newBiasData); + newCharBiasData = nullptr; + newBiasData = nullptr; + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, + int32_t kernelNum) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(convPath != nullptr); + + auto convNode = graph->nodes.at(convPath->nodeIdx).get(); + MS_ASSERT(convNode != nullptr); + auto convWeightTensorIdxes = convNode->inputIndex; + convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); + + TensorT *weightTensor = nullptr; + TensorT *biasTensor = nullptr; + if (convWeightTensorIdxes.size() == CONV_OP_NO_BIAS_WEIGHT_NUM) { + weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); + } else if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { + weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); + biasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); + } else { + MS_LOG(ERROR) << "Conv2D should has " << CONV_OP_NO_BIAS_WEIGHT_NUM << " or " << CONV_OP_HAS_BIAS_WEIGHT_NUM + << " weight tensors, current number of weight tensors " << convWeightTensorIdxes.size(); + return RET_ERROR; + } + if (weightTensor == nullptr) { + MS_LOG(ERROR) << "Conv2D's weight tensor is nullptr"; + return RET_ERROR; + } + + auto weightShape = weightTensor->dims; + if (weightShape.size() != CONV_FILTER_SHAPE_SIZE) { + MS_LOG(ERROR) << "Size of dims of weight tensor should be " << CONV_FILTER_SHAPE_SIZE << " rather than " + << weightShape.size(); + return RET_ERROR; + } + size_t kernelSize = GetShapeSize(*weightTensor) / kernelNum; + + // cal new weightData + auto status = CalNewWeightTensor(weightTensor, kernelNum, kernelSize); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalNewWeightTensor error " << status; + return status; + } + // cal new biasData + status = CalNewBiasTensor(weightTensor, biasTensor, kernelNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalNewBiasTensor error " << status; + return status; + } + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } + +ConvScaleBiasFusionPass::~ConvScaleBiasFusionPass() { + if (this->transScale != nullptr) { + delete (this->transScale); + } + if (this->transBias != nullptr) { + delete (this->transBias); + } + if (this->newWeightData != nullptr) { + delete (this->newWeightData); + } + if (this->newBiasData != nullptr) { + delete (this->newBiasData); + } +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h new file mode 100644 index 0000000000..1a9ce07b06 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. + * Description: mslite + * Author: mslite + * Create: 2019-12-13 + */ + +#ifndef MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +struct BNWeightTensors { + schema::TensorT *meanTensor = nullptr; + schema::TensorT *varianceTensor = nullptr; + schema::TensorT *scaleTensor = nullptr; + schema::TensorT *biasTensor = nullptr; +}; + +class ConvScaleBiasFusionPass : public FusionPass { + public: + ConvScaleBiasFusionPass() = default; + + ~ConvScaleBiasFusionPass() override; + + STATUS DefinePattern() override = 0; + + // 1. generate biasTensor according to BN weightTensor + // 2. change attr of conv + // 3. delete BN node + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + // call GetTransParam() and CalConvWeightTensors() + STATUS GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, + std::shared_ptr dstPath); + + // fill this->transScale and this->transBias + virtual STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr dstPath, int32_t kernelNum) = 0; + + // fill this->newWeightTensor and this->newBiasTensor according to this->transScale and this->transBias + STATUS CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, int32_t kernelNum); + + STATUS CalNewWeightTensor(schema::TensorT *oldWeightTensor, int32_t kernelNum, size_t kernelSize); + + STATUS CalNewBiasTensor(schema::TensorT *oldWeightTensor, schema::TensorT *oldBiasTensor, int32_t kernelNum); + + protected: + float *transScale = nullptr; + float *transBias = nullptr; + float *newWeightData = nullptr; + float *newBiasData = nullptr; + std::unique_ptr newWeightTensor = nullptr; + std::unique_ptr newBiasTensor = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc new file mode 100644 index 0000000000..56d5b3d262 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" +#include "securec/include/securec.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define SCALE_OP_NO_BIAS_WEIGHT_NUM 1 +#define SCALE_OP_HAS_BIAS_WEIGHT_NUM 2 + +#define SCALE_OP_SCALE_INDEX_IN_WEIGHT 0 +#define SCALE_OP_BIAS_INDEX_IN_WEIGHT 1 + +STATUS ConvScaleFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + auto scaleOp = std::make_shared(); + scaleOp->id = DST_NAME; + scaleOp->types = {schema::PrimitiveType_Scale}; + scaleOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvScaleFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(scaleOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS ConvScaleFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvScaleFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); } + +STATUS ConvScaleFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr scalePath, + int32_t kernelNum) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(scalePath != nullptr); + + auto scaleNode = graph->nodes.at(scalePath->nodeIdx).get(); + MS_ASSERT(scaleNode != nullptr); + auto scaleWeightTensorIdxes = scaleNode->inputIndex; + scaleWeightTensorIdxes.erase(scaleWeightTensorIdxes.begin()); + + schema::TensorT *scaleTensor = nullptr; + schema::TensorT *biasTensor = nullptr; + + if (scaleWeightTensorIdxes.size() == SCALE_OP_NO_BIAS_WEIGHT_NUM) { + scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); + } else if (scaleWeightTensorIdxes.size() == SCALE_OP_HAS_BIAS_WEIGHT_NUM) { + scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); + biasTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_BIAS_INDEX_IN_WEIGHT]).get(); + } else { + MS_LOG(ERROR) << "Scale should has %d or %d weight tensors, current number of weight tensors %zu"; + // SCALE_OP_NO_BIAS_WEIGHT_NUM, SCALE_OP_HAS_BIAS_WEIGHT_NUM, scaleWeightTensorIdxes.size()); + return RET_ERROR; + } + + if (scaleTensor == nullptr) { + MS_LOG(ERROR) << "Scale's scale tensor is nullptr"; + return RET_ERROR; + } + + if (kernelNum != scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to scale size(%lu)"; + //, kernelNum, scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)); + return RET_ERROR; + } + + const float *scaleData = reinterpret_cast(scaleTensor->data.data()); + + if (0 != memcpy_s(transScale, kernelNum * sizeof(float), scaleData, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s transScale failed"; + return RET_ERROR; + } + + if (biasTensor != nullptr) { + if (kernelNum != biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to bias size(%lu)"; + //, kernelNum, biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)); + return RET_ERROR; + } + + const float *biasData = reinterpret_cast(biasTensor->data.data()); + + if (0 != memcpy_s(transBias, kernelNum * sizeof(float), biasData, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s transBias failed"; + return RET_ERROR; + } + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h new file mode 100644 index 0000000000..8c2ed2808c --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H + +#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +class ConvScaleFusionPass : public ConvScaleBiasFusionPass { + public: + ConvScaleFusionPass() = default; + + ~ConvScaleFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + private: + STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr scalePath, int32_t kernelNum) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc new file mode 100644 index 0000000000..156c79103a --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +// #include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define kFormatTransMatchPathLen2 2 +#define kFormatTransMatchPathLen3 3 + +STATUS FormatTransFusionPass::DefinePattern() { + // nchw2nhwc + nhwc2nchw + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + nh2ncOp->left = nc2nhOp; + std::unique_ptr nc2NhAndNh2NcFusionPattern(new (std::nothrow) + FusionPattern(kNc2NhAndNh2NcFusionPattern)); + if (nc2NhAndNh2NcFusionPattern == nullptr) { + // MS_LOG(ERROR) << "new %s failed", kNc2NhAndNh2NcFusionPattern); + return RET_ERROR; + } + nc2NhAndNh2NcFusionPattern->AddPatternOp(nc2nhOp); + nc2NhAndNh2NcFusionPattern->AddPatternOp(nh2ncOp); + nc2NhAndNh2NcFusionPattern->Finish(); + this->patterns.emplace_back(nc2NhAndNh2NcFusionPattern.release()); + } + // nchw2nhwc + QuantDtypeCast + nhwc2nchw + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto passOp = std::make_shared(); + passOp->id = kFormatTransPassOp; + passOp->types = {PrimitiveType_QuantDTypeCast}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + passOp->left = nc2nhOp; + nh2ncOp->left = passOp; + std::unique_ptr nc2NhAndNh2NcPassFusionPattern(new FusionPattern(kNc2NhAndNh2NcPassFusionPattern)); + if (nc2NhAndNh2NcPassFusionPattern == nullptr) { + // MS_LOG(ERROR) << "new %s failed", kNc2NhAndNh2NcPassFusionPattern); + return RET_ERROR; + } + nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nc2nhOp); + nc2NhAndNh2NcPassFusionPattern->AddPatternOp(passOp); + nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nh2ncOp); + nc2NhAndNh2NcPassFusionPattern->Finish(); + this->patterns.emplace_back(nc2NhAndNh2NcPassFusionPattern.release()); + } + // nhwc2nchw + nchw2nhwc + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + nc2nhOp->left = nh2ncOp; + std::unique_ptr nh2NcAndNc2NhFusionPattern(new (std::nothrow) + FusionPattern(kNh2NcAndNc2NhFusionPattern)); + if (nh2NcAndNc2NhFusionPattern == nullptr) { + // MS_LOG(ERROR) << "new %s failed", kNh2NcAndNc2NhFusionPattern); + return RET_ERROR; + } + nh2NcAndNc2NhFusionPattern->AddPatternOp(nh2ncOp); + nh2NcAndNc2NhFusionPattern->AddPatternOp(nc2nhOp); + nh2NcAndNc2NhFusionPattern->Finish(); + this->patterns.emplace_back(nh2NcAndNc2NhFusionPattern.release()); + } + // nhwc2nchw + QuantDtypeCast + nchw2nhwc + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto passOp = std::make_shared(); + passOp->id = kFormatTransPassOp; + passOp->types = {PrimitiveType_QuantDTypeCast}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + passOp->left = nh2ncOp; + nc2nhOp->left = passOp; + std::unique_ptr nh2NcAndNc2NhPassFusionPattern(new (std::nothrow) + FusionPattern(kNh2NcAndNc2NhPassFusionPattern)); + if (nh2NcAndNc2NhPassFusionPattern == nullptr) { + MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed"; + return RET_ERROR; + } + nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nh2ncOp); + nh2NcAndNc2NhPassFusionPattern->AddPatternOp(passOp); + nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nc2nhOp); + nh2NcAndNc2NhPassFusionPattern->Finish(); + this->patterns.emplace_back(nh2NcAndNc2NhPassFusionPattern.release()); + } + return RET_OK; +} + +STATUS FormatTransFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != kFormatTransMatchPathLen2 && matchedPath.size() != kFormatTransMatchPathLen3) { + MS_LOG(ERROR) << "Format-Transform-Fusion should have " << kFormatTransMatchPathLen2 << " or " + << kFormatTransMatchPathLen3 << " NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + std::shared_ptr srcPath; + std::shared_ptr dstPath; + if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) { + srcPath = matchedPath[kFormatTransNc2NhOp]; + dstPath = matchedPath[kFormatTransNh2NcOp]; + } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) { + srcPath = matchedPath[kFormatTransNh2NcOp]; + dstPath = matchedPath[kFormatTransNc2NhOp]; + } else { + MS_ASSERT(false); + } + MS_ASSERT(srcPath != nullptr); + MS_ASSERT(dstPath != nullptr); + auto srcNode = graph->nodes.at(srcPath->nodeIdx).get(); + auto dstNode = graph->nodes.at(dstPath->nodeIdx).get(); + MS_ASSERT(srcNode != nullptr); + MS_ASSERT(dstNode != nullptr); + if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) { + MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nchw2Nhwc); + MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nhwc2Nchw); + } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) { + MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nhwc2Nchw); + MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nchw2Nhwc); + } else { + MS_ASSERT(false); + } + + auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status; + return status; + } + + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status; + return status; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h new file mode 100644 index 0000000000..1b77c74d6b --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H +#define MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +constexpr const char *kFormatTransNc2NhOp = "FormatTransNc2NhOp"; +constexpr const char *kFormatTransNh2NcOp = "FormatTransNh2NcOp"; +constexpr const char *kFormatTransPassOp = "FormatTransPassOp"; +constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern"; +constexpr const char *kNc2NhAndNh2NcPassFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; +constexpr const char *kNh2NcAndNc2NhFusionPattern = "Nh2NcAndNc2NhFusionPattern"; +constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern"; + +class FormatTransFusionPass : public FusionPass { + public: + FormatTransFusionPass() = default; + + ~FormatTransFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc new file mode 100644 index 0000000000..5f081137ff --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc @@ -0,0 +1,349 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS FusionPass::Run(schema::MetaGraphT *graph) { + auto ret = DefinePattern(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DefinePattern Error " << ret; + return ret; + } + for (auto pattern : patterns) { + if (pattern == nullptr) { + MS_LOG(ERROR) << "FusionPattern has not been set"; + return RET_PARAM_INVALID; + } + + if (!pattern->Check()) { + MS_LOG(ERROR) << "FusionPattern is invaild"; + return RET_PARAM_INVALID; + } + } + + ret = MatchPatterns(graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MatchPattern Error " << ret; + return ret; + } + + if (this->matchedPaths.empty()) { + return RET_NO_CHANGE; + } else { + ret = Fuse(graph); + if (ret != RET_OK && ret != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Fuse Error " << ret; + } + return ret; + } +} + +STATUS FusionPass::MatchPatterns(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + this->matchedPaths.clear(); + STATUS status; + for (auto pattern : patterns) { + status = MatchOnePattern(graph, pattern); + if (status != RET_OK) { + MS_LOG(ERROR) << "MatchOnePatternInSubGraph failed: " << status; + return status; + } + } + this->mapedMatchedPaths.clear(); + for (auto iter = matchedPaths.begin(); iter != matchedPaths.end(); iter++) { + auto patternName = iter->first; + auto patternOps = iter->second; + std::vector>> mapedPaths; + for (const auto &patternOp : patternOps) { + std::queue> opQueue; + std::unordered_map> mapedPath; + opQueue.push(patternOp); + while (!opQueue.empty()) { + auto curPatternOp = opQueue.front(); + opQueue.pop(); + MS_ASSERT(curPatternOp != nullptr); + mapedPath.insert(std::make_pair(curPatternOp->id, curPatternOp->path)); + if (curPatternOp->left != nullptr) { + opQueue.push(curPatternOp->left); + } + if (curPatternOp->right != nullptr) { + opQueue.push(curPatternOp->right); + } + } + mapedPaths.emplace_back(mapedPath); + } + this->mapedMatchedPaths.insert(std::make_pair(patternName, mapedPaths)); + } + return RET_OK; +} + +// assume that all nodes have only one output. if node has multi-outputs, +// some errors may happen +STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pattern) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(pattern != nullptr); + // std::vector> patternMatchPaths; + auto outputOp = pattern->GetPatternOp(pattern->GetOutput()); + if (outputOp == nullptr) { + MS_LOG(ERROR) << "Can not find the output of the pattern"; + return RET_NULL_PTR; + } + MS_ASSERT(outputOp->isTail); + if (graph->nodes.empty()) { + return RET_OK; + } + // find all matched entries + std::vector entries; + std::queue nodeQueue; + std::vector sinkIdes; + for (auto index : graph->outputIndex) { + auto subGraphOutputNodeIdxes = GetLinkedPreIdx(*graph, index); + for (auto subGraphOutputNodeIdx : subGraphOutputNodeIdxes) { + MS_ASSERT((subGraph->nodes.size() > subGraphOutputNodeIdx)); + nodeQueue.push(subGraphOutputNodeIdx); + } + } + while (!nodeQueue.empty()) { + auto nodeIdx = nodeQueue.front(); + nodeQueue.pop(); + if (IsContain(sinkIdes, nodeIdx)) { + continue; + } + MS_ASSERT(subGraph->nodes.size() > nodeIdx); + auto &node = graph->nodes.at(nodeIdx); + sinkIdes.emplace_back(nodeIdx); + + MS_ASSERT(nullptr != node->primitive); + if (IsContain(outputOp->types, node->primitive->value.type)) { + entries.emplace_back(nodeIdx); + } + auto preNodeIdxes = GetInputNodeIdx(*graph, nodeIdx); + for (auto preNodeIdx : preNodeIdxes) { + MS_ASSERT((subGraph->nodes.size() > preNodeIdx)); + nodeQueue.push(preNodeIdx); + } + } + + // check each entry + std::vector> paths; + sinkIdes.clear(); + std::vector pathSinkIdes; + for (auto nodeIdx : entries) { + if (IsContain(sinkIdes, nodeIdx)) { + continue; + } + pathSinkIdes.clear(); + auto path = PatternOp::Copy(outputOp); + auto ret = MatchTree(graph, nodeIdx, path, sinkIdes, pathSinkIdes); + if (ret && CheckMatch(graph, path)) { + paths.emplace_back(path); + } + } + auto patternName = pattern->GetName(); + this->matchedPaths.insert(std::make_pair(patternName, paths)); + return RET_OK; +} + +bool FusionPass::CheckMatch(schema::MetaGraphT *graph, const std::shared_ptr& patternOp) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(patternOp != nullptr); + // find included nodes + std::queue> opQueue; + std::vector matchedNodeIdxes; + std::vector> inputNodes; + std::shared_ptr outputNode = nullptr; + opQueue.push(patternOp); + while (!opQueue.empty()) { + auto curPatternOp = opQueue.front(); + opQueue.pop(); + matchedNodeIdxes.push_back(curPatternOp->path->nodeIdx); + if (curPatternOp->isHead) { + inputNodes.emplace_back(curPatternOp); + } + if (curPatternOp->isTail) { + if (outputNode != nullptr && outputNode != curPatternOp) { + return false; + } + outputNode = curPatternOp; + } + if (curPatternOp->left != nullptr) { + opQueue.push(curPatternOp->left); + } + if (curPatternOp->right != nullptr) { + opQueue.push(curPatternOp->right); + } + } + // all post node of input node should be in path except input node is placeHold + for (const auto& inputNode : inputNodes) { + if (inputNode->isPlaceHold) { + continue; + } + auto inputNodePostNodeIdxes = GetOutputNodeIdx(*graph, inputNode->path->nodeIdx); + for (auto inputNodePostNodeIdx : inputNodePostNodeIdxes) { + if (!IsContain(matchedNodeIdxes, inputNodePostNodeIdx)) { + return false; + } + } + } + // all pre node of output node should be in path + auto outputNodePreNodeIdxes = GetInputNodeIdx(*graph, outputNode->path->nodeIdx); + for (auto outputNodePreNodeIdx : outputNodePreNodeIdxes) { + if (!IsContain(matchedNodeIdxes, outputNodePreNodeIdx)) { + return false; + } + } + return true; +} + +bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std::shared_ptr &target, + std::vector &sinkIdes, std::vector &pathSinkIdes) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(nodeIdx < subGraph->nodes.size()); + auto &scope = graph->nodes.at(nodeIdx); + MS_ASSERT(scope != nullptr); + // if target(except target is marked head) is nullptr, it means the preNode + // has no left or right, but scope is not nullptr + if (target == nullptr) { + return false; + } + // if node is sinked and not in the pathSinkId, then return false + if (IsContain(sinkIdes, nodeIdx) && !IsContain(pathSinkIdes, nodeIdx)) { + return false; + } + // type not match + if (!target->isPlaceHold && !IsContain(target->types, scope->primitive->value.type)) { + return false; + } + // path is setted and not pointer to this node + if (target->pathSetted) { + MS_ASSERT(target->path != nullptr); + if (target->path->nodeIdx != nodeIdx) { + return false; + } + } + target->SetPath(-1, nodeIdx); + sinkIdes.push_back(nodeIdx); + pathSinkIdes.push_back(nodeIdx); + // target is marked head, no need to check left and right. head-target's left + // and right is always nullptr + if (target->isHead) { + return true; + } + auto preNodeIdxes = GetInputNodeIdx(*graph, nodeIdx); + if (preNodeIdxes.empty() && target->left == nullptr && target->right == nullptr) { + return true; + } + for (auto preNodeIdx : preNodeIdxes) { + MS_ASSERT(subGraph->nodes.size() > preNodeIdx); + // match left + if (MatchTree(graph, preNodeIdx, target->left, sinkIdes, pathSinkIdes)) { + // match right + if (preNodeIdxes.size() == 1 && target->right == nullptr) { + return true; + } + for (auto preNodeIdxInner : preNodeIdxes) { + if (preNodeIdxInner == preNodeIdx) { + continue; + } + MS_ASSERT(subGraph->nodes.size() > preNodeIdxInner); + if (MatchTree(graph, preNodeIdxInner, target->right, sinkIdes, pathSinkIdes)) { + return true; // ignore follow match, pick the first match + } + } + } + } + sinkIdes.erase((sinkIdes.end() - 1)); + pathSinkIdes.erase((pathSinkIdes.end() - 1)); + target->UnSetPath(); + return false; +} + +STATUS FusionPass::Fuse(schema::MetaGraphT *graph) { + STATUS ret; + bool isChange = false; + for (auto iter = mapedMatchedPaths.begin(); iter != mapedMatchedPaths.end(); iter++) { + for (auto &matchedPath : iter->second) { + ret = DoFusion(graph, iter->first, matchedPath); + if (ret != RET_OK && ret != RET_NO_CHANGE) { + MS_LOG(ERROR) << "DoFusion Error " << ret; + return ret; + } else { + if (ret == RET_OK) { + isChange = true; + } + } + } + } + return isChange ? RET_OK : RET_NO_CHANGE; +} + +FusionPass::~FusionPass() { + for (auto pattern : patterns) { + if (pattern != nullptr) { + delete (pattern); + } + } +} + +void FusionPass::MergeNodeAttrFromPost(std::unique_ptr &dstOp, std::unique_ptr &postOp, + size_t dstOpOutIdx) { + // // merge quantParam + // if (dstOp->quantParam.empty()) { // not awareing quant + // return; + // } + // MS_ASSERT(postOp->outputIndex.size() == 1); + // if (dstOp->quantParam.size() != dstOp->inputIndex.size() + dstOp->outputIndex.size()) { + // int a = 1; + // } + // MS_ASSERT(dstOp->quantParam.size() == dstOp->inputIndex.size() + dstOp->outputIndex.size()); + // auto &dstQuantParamArray = dstOp->quantParam.at(dstOp->inputIndex.size() + dstOpOutIdx); + // auto &postQuantParamArray = postOp->quantParam.back(); + // if (!(postQuantParamArray != nullptr && postQuantParamArray->param.size() == 1 && + // postQuantParamArray->param.front() != nullptr && postQuantParamArray->param.front()->min != FLT_MAX)) { + // return; // postNode has no quantParam, no need merge + // } + // + // if ((dstQuantParamArray != nullptr && dstQuantParamArray->param.size() != 1) || + // (dstQuantParamArray->param.front() != nullptr && dstQuantParamArray->param.front()->min != FLT_MAX)) { + // return; // dstNode has quantParam, no need merge + // } + // + // dstQuantParamArray->param.front()->min = postQuantParamArray->param.front()->min; + // dstQuantParamArray->param.front()->max = postQuantParamArray->param.front()->max; + // dstQuantParamArray->param.front()->scale = postQuantParamArray->param.front()->scale; + // dstQuantParamArray->param.front()->zeroPoint = postQuantParamArray->param.front()->zeroPoint; + // MS_LOGD("merge quantParam from %s to %s", postOp->name.c_str(), dstOp->name.c_str()); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.h new file mode 100644 index 0000000000..058b933a65 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FUSION_PASS_H +#define MINDSPORE_PREDICT_FUSION_PASS_H + +#include +#include +#include +#include +#include +#include "tools/common/converter_op_utils.h" +#include "tools/converter/optimizer.h" +#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h" + +namespace mindspore { +namespace lite { +#define CONV_OP_NO_BIAS_WEIGHT_NUM 1 +#define CONV_OP_HAS_BIAS_WEIGHT_NUM 2 +#define CONV_OP_NO_BIAS_INPUT_NUM 2 +#define CONV_OP_HAS_BIAS_INPUT_NUM 3 + +#define CONV_OP_FILTER_INDEX_IN_WEIGHT 0 +#define CONV_OP_BIAS_INDEX_IN_WEIGHT 1 +#define CONV_OP_FILTER_INDEX_IN_INPUT 1 +#define CONV_OP_BIAS_INDEX_IN_INPUT 2 + +#define CONV_FILTER_SHAPE_SIZE 4 + +// PatternOp Ids +constexpr const char *kConvName = "CONVOLUTION"; +constexpr const char *DST_NAME = "DESTINATION"; +constexpr const char *ACTIVATION_NAME = "ACTIVATION"; +constexpr const char *BIASADD_NAME = "BIASADD"; + +class FusionPass : public GraphPass { + public: + FusionPass() = default; + + ~FusionPass() override; + + virtual STATUS DefinePattern() = 0; + + STATUS Run(schema::MetaGraphT *graph) override; + + virtual STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) = 0; + + protected: + STATUS MatchPatterns(schema::MetaGraphT *graph); + + STATUS MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pattern); + + bool MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std::shared_ptr &target, + std::vector &sinkIdes, std::vector &pathSinkIdes); + + static bool CheckMatch(schema::MetaGraphT *graph, const std::shared_ptr& patternOp); + + void MergeNodeAttrFromPost(std::unique_ptr &dstOp, std::unique_ptr &postOp, + size_t dstOpOutIdx = 0); + + STATUS Fuse(schema::MetaGraphT *graph); + + protected: + std::vector patterns; + std::map>> matchedPaths; + // {name of pattern, vector<{name of pattern node, path}>} + std::map>>> mapedMatchedPaths; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.cc new file mode 100644 index 0000000000..6779d84ed2 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +// using namespace std; + +FusionPattern::FusionPattern(std::string name) { this->name = std::move(name); } + +FusionPattern::~FusionPattern() = default; + +FusionPattern &FusionPattern::SetName(const std::string &name) { + this->name = name; + return *this; +} + +FusionPattern &FusionPattern::AddPatternOp(const std::string &id, + const std::initializer_list &types) { + return AddPatternOp(id, std::vector(types)); +} + +FusionPattern &FusionPattern::AddPatternOp(const std::string &id, const std::vector &types) { + if (id.empty()) { + // MS_LOG(ERROR) << "Id cannot be empty"); + hasError = true; + } + + if (GetPatternOp(id) != nullptr) { + // MS_LOG(ERROR) << "Id repeated. (id:%s)", id.c_str()); + hasError = true; + } + + std::shared_ptr op(new PatternOp()); + if (op == nullptr) { + // MS_LOG(ERROR) << "new an object failed"); + hasError = true; + } else { + op->id = id; + op->types = types; + ops.push_back(op); + opMap[id] = op; + } + + return *this; +} + +FusionPattern &FusionPattern::RemovePatternOp(const std::string &id) { + for (uint32_t loop = 0; loop < ops.size(); loop++) { + std::shared_ptr op = ops.at(loop); + if (op->id == id) { + ops.erase(ops.begin() + loop); + opMap.erase(id); + break; + } + } + return *this; +} + +bool FusionPattern::Check() { + if (hasError) { + // MS_LOG(ERROR) << "Has Error in previous Func"); + return false; + } + + if (GetPatternOp(this->outputOpId) == nullptr) { + // MS_LOG(ERROR) << "Can not find the output of the pattern"); + return false; + } + + return true; +} + +void FusionPattern::Dump() const { + std::ostringstream oss; + oss << std::endl << "Pattern " << name << std::endl; + for (const auto op : ops) { + oss << " " << op->id << ": {"; + for (auto &type : op->types) { + oss << schema::EnumNamePrimitiveType(type) << ", "; + } + oss << "} {"; + if (op->left != nullptr) { + oss << "leftPreNode: " << op->left->id << ", "; + } + if (op->right != nullptr) { + oss << "rightPreNode: " << op->right->id << ", "; + } + oss << "}"; + + oss << std::endl; + } +} + +std::shared_ptr FusionPattern::GetPatternOp(const std::string &id) const { + auto it = opMap.find(id); + if (it != opMap.end()) return it->second; + + return nullptr; +} + +std::string FusionPattern::GetOutput() const { return this->outputOpId; } + +FusionPattern &FusionPattern::AddPatternOp(const std::shared_ptr &patternOp) { + ops.push_back(patternOp); + opMap[patternOp->id] = patternOp; + return *this; +} + +FusionPattern &FusionPattern::Finish() { + std::vector ids; + std::set nodeInputIds; + std::vector inputNodeIds; + for (auto patternOp : ops) { + if (IsContain(ids, patternOp->id)) { + // MS_LOG(ERROR) << "Duplicate id find: %s", patternOp->id.c_str()); + hasError = true; + return *this; + } + ids.emplace_back(patternOp->id); + if (patternOp->left != nullptr) { + nodeInputIds.insert(patternOp->left->id); + } + if (patternOp->right != nullptr) { + nodeInputIds.insert(patternOp->right->id); + } + if (patternOp->left == nullptr && patternOp->right == nullptr) { + inputNodeIds.emplace_back(patternOp->id); + } + } + for (auto iter = ids.begin(); iter != ids.end();) { + if (nodeInputIds.find(*iter) != nodeInputIds.end()) { + iter = ids.erase(iter); + } else { + iter++; + } + } + if (ids.size() > 1) { + // MS_LOG(ERROR) << "Multi-output node find, only support pattern with one output"); + hasError = true; + return *this; + } + if (ids.empty()) { + // MS_LOG(ERROR) << "No output node find, only support pattern with one output"); + hasError = true; + return *this; + } + this->outputOpId = ids.front(); + auto outputNode = GetPatternOp(this->outputOpId); + MS_ASSERT(outputNode != nullptr); + outputNode->isTail = true; + + for (auto inputNodeId : inputNodeIds) { + auto inputNode = GetPatternOp(inputNodeId); + MS_ASSERT(inputNode != nullptr); + inputNode->isHead = true; + } + return *this; +} + +std::string FusionPattern::GetName() { return this->name; } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h new file mode 100644 index 0000000000..334185e736 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h @@ -0,0 +1,141 @@ +#include + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FUSION_PATTERN_H +#define MINDSPORE_PREDICT_FUSION_PATTERN_H + +#include +#include +#include +#include +// #include +#include "utils/log_adapter.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +struct Path { + public: + Path(int32_t subGraphIdx, int32_t nodeIdx) : subGraphIdx(subGraphIdx), nodeIdx(nodeIdx) {} + int32_t subGraphIdx = -1; + int32_t nodeIdx = -1; +}; + +// Op description in pattern +struct PatternOp { + std::string id; // id of op in pattern + std::vector types; // type of matchable op + // TODO(...): only support node with no more than two preNode now + // avoid loop reference + std::shared_ptr left; // left input patternOp of this patternOp + std::shared_ptr right; // right input patternOp of this patternOp + std::shared_ptr path = std::make_shared(-1, -1); + bool pathSetted = false; + bool isHead = false; + bool isTail = false; + bool isPlaceHold = false; + + PatternOp() = default; + explicit PatternOp(const std::string &inId) : id(inId) {} + ~PatternOp() = default; + void SetPath(size_t subGraphIdx, size_t nodeIdx) { + MS_ASSERT(this->path != nullptr); + this->path->subGraphIdx = subGraphIdx; + this->path->nodeIdx = nodeIdx; + this->pathSetted = true; + } + void UnSetPath() { + MS_ASSERT(this->path != nullptr); + this->path->subGraphIdx = -1; + this->path->nodeIdx = -1; + this->pathSetted = false; + } + static std::shared_ptr Copy(const std::shared_ptr& src) { + if (src == nullptr) { + return nullptr; + } + auto dst = std::make_shared(); + dst->id = src->id; + dst->types = src->types; + if (src->path != nullptr) { + dst->path = std::make_shared(src->path->subGraphIdx, src->path->nodeIdx); + } + dst->pathSetted = src->pathSetted; + dst->isTail = src->isTail; + dst->isHead = src->isHead; + dst->isPlaceHold = src->isPlaceHold; + dst->left = PatternOp::Copy(src->left); + dst->right = PatternOp::Copy(src->right); + return dst; + } +}; + +class FusionPattern { + public: + explicit FusionPattern(std::string name = ""); + + ~FusionPattern(); + + std::string GetName(); + + FusionPattern &SetName(const std::string &name); + + FusionPattern &AddPatternOp(const std::string &id, const std::initializer_list &types = {}); + + FusionPattern &AddPatternOp(const std::string &id, const std::vector &types); + + FusionPattern &AddPatternOp(const std::shared_ptr& patternOp); + + FusionPattern &RemovePatternOp(const std::string &id); + + // set id of patternOp + // set isTail and isHead for patternOps + FusionPattern &Finish(); + + bool Check(); + // get the id of the output Op of th pattern + std::string GetOutput() const; + + void Dump() const; + + // return nullptr if not find + std::shared_ptr GetPatternOp(const std::string &id) const; + + private: + FusionPattern(const FusionPattern &) = default; + + FusionPattern &operator=(const FusionPattern &) = default; + + private: + std::string name; + + std::vector> ops; + + // same with ops, just for search + std::map> opMap; + + // output PatternOp id of pattern + std::string outputOpId; + + bool hasError = false; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FUSION_PATTERN_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc new file mode 100644 index 0000000000..aebd15db8d --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc @@ -0,0 +1,225 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +// #include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define MATMUL_BIASADD_MATCH_PATH_LEN 2 +#define BIASADD_OP_BIAS_INDEX 1 +#define BIASADD_OP_INPUT_NUM 2 + +STATUS MatMulBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS MatMulBiasAddFusionPass::DefinePattern() { + auto matMulOp = std::make_shared(); + matMulOp->id = MATMUL_NAME; + matMulOp->types = {schema::PrimitiveType_MatMul}; + auto baOp = std::make_shared(); + baOp->id = BIASADD_NAME; + baOp->types = {schema::PrimitiveType_BiasAdd}; + baOp->left = matMulOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("MatMulBiasAddFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(matMulOp); + fusionPattern->AddPatternOp(baOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != MATMUL_BIASADD_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "MatMul-BiasAdd-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto matMulPath = matchedPath[MATMUL_NAME]; + auto baPath = matchedPath[BIASADD_NAME]; + auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx); + auto &baNode = graph->nodes.at(baPath->nodeIdx); + // can not check shape because there is now shape infer in converter + MS_ASSERT(matMulNode != nullptr); + MS_ASSERT(matMulNode->inputIndex.size() == 2); + // biasadd node the second tensor is not constant tensor, don't fusion + auto baNodeInputIndex = baNode->inputIndex; + if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << "%s node tensors number is invalid! "; // baNode->name.c_str()); + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX)); + const auto &baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX)); + MS_ASSERT(baNodeBiasTensor != nullptr); + if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) { + // dont fusion, return + return RET_OK; + } + + // 1. add biasTensor for matMul + auto status = AddFullConnectionBiasTensor(matMulPath, baPath, graph); + if (RET_OK != status) { + MS_LOG(ERROR) << "AddFullConnectionBiasTensor failed, %d"; // status); + return status; + } + + // 2. change matmul to full connection op + matMulNode->name += "-fc"; + std::unique_ptr fcAttr(new FullConnectionT()); + if (fcAttr == nullptr) { + MS_LOG(ERROR) << "new FullConnectionT node failed"; + return RET_ERROR; + } + fcAttr->hasBias = true; + fcAttr->axis = 1; + MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr); + transA = matMulNode->primitive->value.AsMatMul()->transposeA; + transB = matMulNode->primitive->value.AsMatMul()->transposeB; + MS_ASSERT(matMulNode->primitive->value.value != nullptr); + delete (matMulNode->primitive->value.value); + matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection; + matMulNode->primitive->value.value = fcAttr.release(); + + // 3. delete BiasAdd node + MergeNodeAttrFromPost(matMulNode, baNode); + status = IsolateOneWayNode(graph, baPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: %zu, node: %zu, error: %d"; + // baPath->subGraphIdx, baPath->nodeIdx, status); + return status; + } + + // 4. addTranspose node + status = InsertTransposeNode(graph, matMulPath); + if (status != RET_OK) { + MS_LOG(ERROR) + << "InsertTransposeNode failed, subGraph: %zu, node: %zu, error: %d"; // matMulPath->subGraphIdx, + // matMulPath->nodeIdx, status); + return status; + } + return RET_OK; +} + +STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std::shared_ptr &matMulPath) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(matMulPath != nullptr); + auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx); + MS_ASSERT(graph->allTensors.size() > matMulNode->inputIndex.at(0)); + MS_ASSERT(graph->allTensors.size() > matMulNode->inputIndex.at(2)); + const auto &tensorA = graph->allTensors.at(matMulNode->inputIndex.at(0)); + const auto &tensorB = graph->allTensors.at(matMulNode->inputIndex.at(1)); + + std::vector insertNodeIdxList; + if (transA) { + insertNodeIdxList.emplace_back(0); + } + if (!transB) { + insertNodeIdxList.emplace_back(1); + } + + auto matmulOpIter = graph->nodes.begin() + matMulPath->nodeIdx; + STATUS errorCode = RET_OK; + for (auto needInsertIdx : insertNodeIdxList) { + auto transNode = std::unique_ptr(new (std::nothrow) CNodeT); + if (transNode == nullptr) { + MS_LOG(ERROR) << "new TransNode failed"; + return RET_ERROR; + } + transNode->name = "transpose" + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Transpose; + std::unique_ptr transposeParam(new TransposeT()); + if (transposeParam == nullptr) { + MS_LOG(ERROR) << "new transposeParam failed"; + return RET_ERROR; + } + transposeParam->conjugate = false; + transposeParam->perm = {1, 0}; + transNode->primitive->value.value = transposeParam.release(); + matmulOpIter = + InsertNode(graph, matmulOpIter, kBefore, needInsertIdx, std::move(transNode), &errorCode, TransposeOpCopyer); + if (errorCode != RET_OK) { + MS_LOG(ERROR) << "InsertNode failed: %d"; // errorCode); + return errorCode; + } + } + return RET_OK; +} + +#define BIASADD_WEIGHT_SHAPE_SIZE 1 +#define BIASADD_BIAS_DIM_INDEX 0 + +STATUS MatMulBiasAddFusionPass::AddFullConnectionBiasTensor(const std::shared_ptr &matMulPath, + const std::shared_ptr &baPath, MetaGraphT *graph) { + MS_ASSERT(matMulPath != nullptr); + MS_ASSERT(baPath != nullptr); + MS_ASSERT(graph != nullptr); + + MS_ASSERT(graph->nodes.size() > matMulPath->nodeIdx); + auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx); + MS_ASSERT(matMulNode != nullptr); + auto baNode = graph->nodes.at(baPath->nodeIdx).get(); + MS_ASSERT(baNode != nullptr); + + // check biasTensor + auto baWeightTensorIdxes = baNode->inputIndex; + if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << "%s node tensors number is invalid! "; // baNode->name.c_str()); + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX)); + auto &biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX)); + MS_ASSERT(biasTensor != nullptr); + auto biasDims = biasTensor->dims; + // if biasTensor is a scaler + if (biasDims.empty() && biasTensor->data.data() == nullptr) { + MS_LOG(ERROR) << "BiasAdd node %s bias tensor is invalid"; // baNode->name.c_str()); + return RET_ERROR; + } + if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) { + MS_LOG(ERROR) + << "BiasAdd bias tensor should has one dimension, current number of dimension %zu. or bias tensor is a scaler"; + // biasDims.size()); + return RET_ERROR; + } + // add biasTensor to matmul + matMulNode->inputIndex.emplace_back(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX)); + baNode->inputIndex.erase(baNode->inputIndex.begin() + BIASADD_OP_BIAS_INDEX); + + return RET_OK; +} + +MatMulBiasAddFusionPass::~MatMulBiasAddFusionPass() = default; +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h new file mode 100644 index 0000000000..cc8ad536d0 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H +#define MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H + +#include +#include +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +constexpr const char *MATMUL_NAME = "MATMUL"; + +class MatMulBiasAddFusionPass : public FusionPass { + public: + MatMulBiasAddFusionPass() = default; + + ~MatMulBiasAddFusionPass() override; + + STATUS DefinePattern() override; + + STATUS DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(MetaGraphT *graph) override; + + protected: + static STATUS AddFullConnectionBiasTensor(const std::shared_ptr& matMulPath, + const std::shared_ptr& dstPath, + MetaGraphT *subGraph); + STATUS InsertTransposeNode(MetaGraphT *subGraph, const std::shared_ptr& matMulPath); + + protected: + bool transA = false; + bool transB = false; + size_t id = 0; + + OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr &inOpDef) -> std::unique_ptr { + std::unique_ptr newOpDef(new (std::nothrow) CNodeT); + if (newOpDef == nullptr) { + MS_LOG(ERROR) << "new OpDefT failed"; + return nullptr; + } + newOpDef->name = inOpDef->name; + newOpDef->quantType = inOpDef->quantType; + newOpDef->primitive->value.type = schema::PrimitiveType_Transpose; + auto transposeParam = new (std::nothrow) TransposeT; + if (transposeParam == nullptr) { + MS_LOG(ERROR) << "new transposeParam failed"; + return nullptr; + } + auto inParam = inOpDef->primitive->value.AsTranspose(); + MS_ASSERT(inParam != nullptr); + transposeParam->conjugate = inParam->conjugate; + transposeParam->perm.resize(inParam->perm.size()); + std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(), + [](const int32_t ele) { return ele; }); + newOpDef->primitive->value.value = transposeParam; + return std::move(newOpDef); + }; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc new file mode 100644 index 0000000000..3bd0712259 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define kQuantCastMatchPathLen2 2 +#define kQuantCastMatchPathLen3 3 + +STATUS QuantCastFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != kQuantCastMatchPathLen2 && matchedPath.size() != kQuantCastMatchPathLen3) { + MS_LOG(ERROR) << "QuantDtypeCastFusion should have " << kQuantCastMatchPathLen2 << " or " << + kQuantCastMatchPathLen3 << " NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto srcPath = matchedPath[kQuantCastSrcOp]; + MS_ASSERT(srcPath != nullptr); + auto dstPath = matchedPath[kQuantCastDstOp]; + MS_ASSERT(dstPath != nullptr); + auto srcNode = graph->nodes.at(srcPath->nodeIdx).get(); + MS_ASSERT(srcNode != nullptr); + auto dstNode = graph->nodes.at(dstPath->nodeIdx).get(); + MS_ASSERT(dstNode != nullptr); + + // todo check + if (srcNode->inputIndex.empty() && srcNode->outputIndex.empty()) { + MS_LOG(DEBUG) << "srcNode " << srcNode->name.c_str() << " has been removed"; + return RET_NO_CHANGE; + } + if (dstNode->inputIndex.empty() && dstNode->outputIndex.empty()) { + MS_LOG(DEBUG) << "dstNode " << dstNode->name.c_str() << " has been removed"; + return RET_NO_CHANGE; + } + + auto srcAttr = srcNode->primitive->value.AsQuantDTypeCast(); + auto dstAttr = dstNode->primitive->value.AsQuantDTypeCast(); + MS_ASSERT(srcAttr != nullptr); + MS_ASSERT(dstAttr != nullptr); + if (srcAttr->dstT != dstAttr->srcT || srcAttr->srcT != dstAttr->dstT) { + MS_LOG(ERROR) << "srcNode and dstNode can not been fused"; + return RET_ERROR; + } + + auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name.c_str() << ", error: " << status; + return status; + } + + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name.c_str() << ", error: " << status; + return status; + } + + return RET_OK; +} + +STATUS QuantCastFusionPass::DefinePattern() { + // quantCast + quantCast + { + auto srcOp = std::make_shared(); + srcOp->id = kQuantCastSrcOp; + srcOp->types = {schema::PrimitiveType_QuantDTypeCast}; + auto dstOp = std::make_shared(); + dstOp->id = kQuantCastDstOp; + dstOp->types = {schema::PrimitiveType_QuantDTypeCast}; + dstOp->left = srcOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(kQuantCastFusionPattern)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failde"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(srcOp); + fusionPattern->AddPatternOp(dstOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + // quantCast + formatTrans + quantCast + { + auto srcOp = std::make_shared(); + srcOp->id = kQuantCastSrcOp; + srcOp->types = {schema::PrimitiveType_QuantDTypeCast}; + auto formatOp = std::make_shared(); + formatOp->id = kFormatTransOp; + formatOp->types = {schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc}; + formatOp->left = srcOp; + auto dstOp = std::make_shared(); + dstOp->id = kQuantCastDstOp; + dstOp->types = {schema::PrimitiveType_QuantDTypeCast}; + dstOp->left = formatOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(kQuantCastPassFusionPattern)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failde"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(srcOp); + fusionPattern->AddPatternOp(formatOp); + fusionPattern->AddPatternOp(dstOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h new file mode 100644 index 0000000000..dad09cfd02 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_QUANT_CAST_FUSION_PASS_H +#define MINDSPORE_PREDICT_QUANT_CAST_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +constexpr const char *kQuantCastSrcOp = "QuantCastSrcOp"; +constexpr const char *kFormatTransOp = "FormatTransOp"; +constexpr const char *kQuantCastDstOp = "QuantCastDstOp"; + +constexpr const char *kQuantCastFusionPattern = "QuantCastFusionPattern"; +constexpr const char *kQuantCastPassFusionPattern = "QuantCastPassFusionPattern"; + +class QuantCastFusionPass : public FusionPass { + public: + QuantCastFusionPass() = default; + + ~QuantCastFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_QUANT_CAST_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt new file mode 100755 index 0000000000..e5d2ceac19 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(graph_pass_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc + ) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc new file mode 100644 index 0000000000..5caaf046b4 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" +#include "utils/log_adapter.h" +#include "src/common/common.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +#define kMinInputNum 1 +#define kOutputNum 1 + +STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { + if (fmkType == converter::FmkType_TF) { + return RET_OK; + } + MS_ASSERT(graph != nullptr); + auto status = DoModelInputFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; + return status; + } + status = DoNodeInoutFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; + return status; + } + return RET_OK; +} + +STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { + if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) { + return RET_OK; + } + MS_ASSERT(graph != nullptr); + // insert trans node in model input tensor + if (graph->nodes.empty()) { + return RET_OK; + } + auto graphInputIdxes = graph->inputIndex; + for (size_t i = 0; i < graphInputIdxes.size(); i++) { + auto inputIdx = graphInputIdxes.at(i); + MS_ASSERT(inputIdx < subGraph->allTensors.size()); + auto &tensor = graph->allTensors.at(inputIdx); + if (tensor->dims.size() != kNCHWDimNumber) { + continue; + } + + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { + if (node->inputIndex.at(inputIndexIdx) == inputIdx) { + STATUS status = RET_OK; + iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; + return status; + } + // set first tensor format to nhwc + auto &transNode = *(iter - 1); + MS_ASSERT(transNode != nullptr); + MS_ASSERT(transNode->inputIndex.size() == 1); + MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front()); + auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front()); + graphInTensor->format = schema::Format_NHWC; + // assume parser not reformat shape + auto oldDims = graphInTensor->dims; + graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; + break; + } + } + } + } + return RET_OK; +} + +// inference needed inputFormat: +// conv deconv depth dedepth +// fp32 NCHW NCHW NCHW NCHW +// uint8 NCHW ? NCHW ? +STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + // insert before and after the op cal by nchw/nc4hw4 + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + FormatTransNodeType beforeNodeType, afterNodeType; + if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc + // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use + // nhwc + // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only + // support nhwc + // continue; + // } + // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + // continue; + // } + // } else { + // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + // } + // } + // beforeNodeType = kNCHW2NHWC; + // afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw + // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc + // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc + // continue; + // } + // } else { + // continue; + // } + if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + } + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_MS) { + if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + } + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else { + MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; + return RET_ERROR; + } + auto &node = *iter; + auto nodeName = node->name; + if (node->inputIndex.size() < kMinInputNum) { + MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; + return RET_ERROR; + } + if (node->outputIndex.size() != kOutputNum) { + MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; + return RET_ERROR; + } + STATUS status; + iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed"; + return RET_ERROR; + } + + iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, + size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) { + MS_ASSERT((*existNodeIter) != nullptr); + auto existNodeName = (*existNodeIter)->name; + std::string tileName; + if (place == kBefore) { + tileName = existNodeName + "_pre"; + } else { + tileName = existNodeName + "_post"; + } + auto transNode = std::make_unique(); + transNode->primitive = std::make_unique(); + + if (nodeType == kNCHW2NHWC) { + transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; + } else { + transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; + } + return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); +} + +void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h new file mode 100644 index 0000000000..2fc754a36d --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H +#define MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H + +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" +#include "tools/converter/converter_flags.h" + +namespace mindspore { +namespace lite { +enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW }; + +class FormatTransPass : public GraphPass { + public: + FormatTransPass() : id(0) {} + + ~FormatTransPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + void SetQuantType(QuantType quantType); + + void SetFmk(converter::FmkType fmkType); + + private: + STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); + + STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); + + NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, + FormatTransNodeType nodeType, STATUS *errorCode); + + private: + size_t id; + QuantType quantType = QuantType_QUANT_NONE; + converter::FmkType fmkType = converter::FmkType_TF; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc new file mode 100644 index 0000000000..b6f0113bfe --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS IsolatedNodeRemovePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + bool ifChanged = false; + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end();) { + if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) { + ifChanged = true; + iter = graph->nodes.erase(iter); + } else { + iter++; + } + } + return ifChanged ? RET_OK : RET_NO_CHANGE; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h new file mode 100644 index 0000000000..293ccd8920 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H +#define MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H + +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class IsolatedNodeRemovePass : public GraphPass { + public: + IsolatedNodeRemovePass() = default; + + ~IsolatedNodeRemovePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.cc new file mode 100644 index 0000000000..b6979b7dbc --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +STATUS ModelInputFormatPreProcessPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto inputIndex : graph->inputIndex) { + if (graph->allTensors[inputIndex]->dims.size() == 4) { + std::vector tmpDims(graph->allTensors[inputIndex]->dims); + auto status = + NodeUtils::ConvertDims(schema::Format_NCHW, tmpDims, schema::Format_NHWC, &graph->allTensors[inputIndex]->dims); + if (status == RET_OK) { + graph->allTensors[inputIndex]->format = schema::Format_NHWC; + } else { + MS_LOG(ERROR) << "ConvertDims from NHWC to NCHW error: " << status; + return RET_ERROR; + } + } else { + graph->allTensors[inputIndex]->format = schema::Format_NHWC; + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.h new file mode 100644 index 0000000000..187c93079e --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/model_input_format_preprocess_pass.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_MODEL_FORMAT_PREPROCESS_PASS_H +#define MINDSPORE_PREDICT_MODEL_FORMAT_PREPROCESS_PASS_H + +#include +#include "tools/converter/optimizer.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class ModelInputFormatPreProcessPass : public GraphPass { + public: + ModelInputFormatPreProcessPass() = default; + + ~ModelInputFormatPreProcessPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_MODEL_FORMAT_PREPROCESS_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc new file mode 100644 index 0000000000..0478584f22 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" +#include "tools/common/converter_op_utils.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + std::vector> newNodes; + std::vector sinkedTensorIdxes; + // put all const tensor index into sinkedTensorIdxes + for (size_t i = 0; i < graph->allTensors.size(); i++) { + if (graph->allTensors.at(i)->nodeType == schema::NodeType_ValueNode) { + sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i); + } + } + auto &oldNodes = graph->nodes; + std::queue> opQueue; + // put all non depend node into queue + for (auto &node : graph->nodes) { + if (IsNodeNonDepend(node, sinkedTensorIdxes)) { + sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end()); + opQueue.push(std::move(node)); + } + } + // bfs + while (!opQueue.empty()) { + auto &node = opQueue.front(); + auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get())); + for (auto postNodeIdx : postNodeIdxes) { + auto &postNode = oldNodes.at(postNodeIdx); + // check if postNode is non-depended + if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) { + sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end()); + opQueue.push(std::move(postNode)); + } + } + newNodes.emplace_back(std::move(node)); + opQueue.pop(); + } + if (newNodes.size() != oldNodes.size()) { + MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size() + << ", newNodesSize: " << newNodes.size(); + return RET_ERROR; + } + graph->nodes.swap(newNodes); + return RET_OK; +} + +bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr &node, + const std::vector &sinkedTensorIdxes) { + for (auto inputIdx : node->inputIndex) { + if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) { + return false; + } + } + return true; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.h new file mode 100644 index 0000000000..994648ab57 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H +#define MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H + +#include +#include +#include "mindspore/lite/tools/converter/optimizer.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +class TopologicalSortPass : public GraphPass { + public: + TopologicalSortPass() = default; + + ~TopologicalSortPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + private: + bool IsNodeNonDepend(const std::unique_ptr &node, const std::vector &sinkedTensorIdxes); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc new file mode 100644 index 0000000000..01968311a6 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + bool ifChanged = false; + for (size_t i = 0; i < graph->nodes.size(); i++) { + auto &node = graph->nodes.at(i); + if (node->primitive->value.type == schema::PrimitiveType_TupleGetItem) { + ifChanged = true; + auto status = IsolateOneWayNode(graph, i); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << graph->name << ", node: " << node->name + << ", error: " << status; + return status; + } + } + } + return ifChanged ? RET_OK : RET_NO_CHANGE; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h new file mode 100644 index 0000000000..7716592a24 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H +#define MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H + +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class UnusedNodeRemovePass : public GraphPass { + public: + UnusedNodeRemovePass() = default; + + ~UnusedNodeRemovePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H + diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt new file mode 100755 index 0000000000..6288071c81 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(node_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_pass.cc + ) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc new file mode 100644 index 0000000000..5432f58ae3 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -0,0 +1,403 @@ +/** + * Copyright 201+ Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/node/weight_format_pass.h" +#include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +int WeightFormatPass::Run(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto status = ShapeFormatTrans(graphNode); + if (status != 0) { + MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; + return status; + } + if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) { + status = QuantDataFormatTrans(graphNode); + if (status != 0) { + MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; + return status; + } + } else { + status = NonQuantDataFormatTrans(graphNode); + if (status != 0) { + MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; + return status; + } + } + return 0; +} + +void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } + +// pre set tensor format +// non quant, filterFormat: +// conv deconv depth dedepth +// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp +// tf HWCK HWKC HWCK HWKC +// onnx K(C/g)HW C(K/g)HW / / + +// awareing quant, filterFormat: +// conv deconv depth dedepth +// onnx KHWC ? CHWK ? +// tf HWCK ? HWCK ? +int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto &subGraph = graphNode->subGraph; + auto &node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto opType = node->primitive->value.type; + if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && + opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { + return 0; + } + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(subGraph->allTensors.size() > weightIndex); + auto &weightTensor = subGraph->allTensors[weightIndex]; + auto &shape = weightTensor->dims; + MS_ASSERT(shape.size() == 4); + if (fmkType == converter::FmkType_CAFFE) { + switch (node->quantType) { + case QuantType_QUANT_NONE: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_KCHW; + } else { + MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) + << ", node: " << node->name.c_str(); + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) + << ", node: " << node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_MS) { + switch (node->quantType) { + case QuantType_AwareTrainning: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_HWCK; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + case QuantType_QUANT_NONE: { + // conv [filter_height, filter_width, in_channels, out_channels] + // depthwise [filter_height, filter_width, in_channels, channel_multiplier] + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KCHW; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_KCHW; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_TF) { + switch (node->quantType) { + case QuantType_AwareTrainning: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_HWCK; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + case QuantType_QUANT_NONE: { + // conv [filter_height, filter_width, in_channels, out_channels] + // depthwise [filter_height, filter_width, in_channels, channel_multiplier] + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_HWCK; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_TFLITE) { + switch (node->quantType) { + case QuantType_QUANT_NONE: + case QuantType_AwareTrainning: + case QuantType_PostTraining: { + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KHWC; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CHWK; + } else if (opType == schema::PrimitiveType_DeConv2D) { + weightTensor->format = schema::Format_CHWK; + } else { + MS_LOG(ERROR) << "unsupport format"; + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_ONNX) { + switch (node->quantType) { + case QuantType_AwareTrainning: { + // sum up from current onnx quant models + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KHWC; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CHWK; + } else { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } break; + case QuantType_QUANT_NONE: { + // conv (K x C/group x kH x kW) group = 1 + // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) + // deconv (C x K/group x kH x kW) group = 1 + // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KCHW; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CKHW; + } else if (opType == schema::PrimitiveType_DeConv2D) { + weightTensor->format = schema::Format_CKHW; + } else { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: %d, node: " << node->quantType, node->name.c_str(); + return -1; + } + } + } else { + MS_LOG(ERROR) << "Invalid fmkType: %d, node: " << fmkType, node->name.c_str(); + return -1; + } + return 0; +} + +int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto &subGraph = graphNode->subGraph; + auto &node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto opType = node->primitive->value.type; + if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && + opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { + return 0; + } + + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(subGraph->allTensors.size() > weightIndex); + auto &weightTensor = subGraph->allTensors[weightIndex]; + MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT + STATUS status; + if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK + if (weightTensor->format == schema::Format_KCHW) { // from caffe + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format + << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + } else { + MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format + << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + } + } else if (weightTensor->format == schema::Format_KHWC) { // from onnx + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); + } else { + status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); + } + } else if (weightTensor->format == schema::Format_HWCK) { // from tf + return 0; + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_HWCK; + } else { + MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : " + << (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"), + node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK + if (weightTensor->format == schema::Format_CKHW) { // from caffe + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, + weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); + } else { + MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, + weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); + } + + } else if (weightTensor->format == schema::Format_HWCK) { // from tf + return 0; + } else if (weightTensor->format == schema::Format_CHWK) { // from onnx + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); + } else { + status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); + } + } else if (weightTensor->format == schema::Format_KCHW) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + } else { + status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + } + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_HWCK; + } else { + MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : " + << (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"), + node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK + node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_CKHW; + } + return 0; +} + +// inference needed filterFormat: +// conv deconv depth dedepth +// fp32 KCHW CKHW CKHW CKHW +int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto &subGraph = graphNode->subGraph; + auto &node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto opType = node->primitive->value.type; + if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && + opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { + return 0; + } + + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(subGraph->allTensors.size() > weightIndex); + auto &weightTensor = subGraph->allTensors[weightIndex]; + if (weightTensor->dataType != TypeId::kNumberTypeFloat32) { + MS_LOG(ERROR) << "weight tensor data should be float"; + // return -1; + } + STATUS status = RET_OK; + if (opType == schema::PrimitiveType_Conv2D) { // weight should be KCHW + if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); + } else if (weightTensor->format == schema::Format_KHWC) { + status = RET_OK; + } else if (weightTensor->format == schema::Format_CHWK) { + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_KHWC; + } else { + MS_LOG(WARNING) << "TransFilter " << ((weightTensor->format == schema::Format_HWCK) ? "HWCK" : "NHWC") + << "ToKCHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW + if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms + status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); + } else if (weightTensor->format == schema::Format_KCHW) { + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); + } else if (weightTensor->format == schema::Format_CHWK) { + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_CKHW; + } else { + MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC + if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); + } else if (weightTensor->format == schema::Format_CHWK) { // from tf + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_KHWC; + } else { + MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be CKHW + if (weightTensor->format == schema::Format_CKHW) { // from caffe + return 0; + } else if (weightTensor->format == schema::Format_HWKC) { // from tf or onnx + status = TransFilterFormat(weightTensor.get(), kHWKC2CKHW); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_CKHW; + } else { + MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } + return 0; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h new file mode 100644 index 0000000000..cf7d3d462d --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H +#define MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H + +#include "tools/converter/optimizer.h" +#include "tools/converter/converter_flags.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +class WeightFormatPass : public NodePass { + public: + WeightFormatPass() = default; + + ~WeightFormatPass() override = default; + + void SetQuantType(QuantType quantType); + + void SetFmkType(converter::FmkType fmkType); + + int Run(GraphNode *graphNode) override; + + private: + // correct weightTensor->Format + int ShapeFormatTrans(GraphNode *graphNode); + + // transform weightTensor data and format + // if quant : conv transform dataFormat to NHWC, weight format to HWCK + // if quant : depth transform dataFormat to NCHW, weight format to CKHW + int QuantDataFormatTrans(GraphNode *graphNode); + + // if no quant : transform dataFormat to NCHW, weight format to KCHW/CKHW + int NonQuantDataFormatTrans(GraphNode *graphNode); + + private: + QuantType quantType = QuantType_QUANT_NONE; + converter::FmkType fmkType = converter::FmkType_TF; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H + diff --git a/mindspore/lite/tools/converter/main.cc b/mindspore/lite/tools/converter/main.cc new file mode 100644 index 0000000000..6923ed75c1 --- /dev/null +++ b/mindspore/lite/tools/converter/main.cc @@ -0,0 +1,20 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/converter.h" + +int main(int argc, const char **argv) { return mindspore::lite::RunConverter(argc, argv); } + diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h new file mode 100644 index 0000000000..f9014fbc4c --- /dev/null +++ b/mindspore/lite/tools/converter/model_parser.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_MODEL_PARSER_H +#define MS_MODEL_PARSER_H +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "src/common/anf_importer/import_from_meta_graphT.h" +#include "ir/anf.h" +#include "include/errorcode.h" + +namespace mindspore::lite { +using namespace schema; +class ModelParser { + public: + ModelParser() {} + + virtual ~ModelParser() {} + + virtual FuncGraphPtr ParseToAnf(const std::string &modelFile, const std::string &weightFile) { + auto *meta_graph = Parse(modelFile, weightFile); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; + return nullptr; + } + return Fb2Anf(Parse(modelFile, weightFile)); + } + virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0; + + public: + static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { + MS_EXCEPTION_IF_NULL(meta_graph); + auto func_graph = std::make_shared(); + auto importer = new AnfImporterFromMetaGraphT(meta_graph, func_graph); + auto ret = importer->Import(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << ret; + return nullptr; + } + return func_graph; + } +}; +} // namespace mindspore::lite + +#endif + + diff --git a/mindspore/lite/tools/converter/optimizer.cc b/mindspore/lite/tools/converter/optimizer.cc new file mode 100644 index 0000000000..d043138931 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +Optimizer::~Optimizer() { + for (auto pass : graphPasses) { + if (pass != nullptr) { + delete (pass); + } + } + + for (auto pass : nodePasses) { + if (pass != nullptr) { + delete (pass); + } + } +} + +void Optimizer::AddPass(GraphPass *graphPass) { + if (graphPass != nullptr) { + this->graphPasses.emplace_back(graphPass); + } +} + +void Optimizer::AddPass(NodePass *nodePass) { + if (nodePass != nullptr) { + this->nodePasses.emplace_back(nodePass); + } +} + +STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) { + STATUS status; + bool ifNotChanged = true; + // each node should go through all node pass not each node pass go through all node + for (auto &opDef : graphDefT->nodes) { + for (auto pass : this->nodePasses) { + status = pass->Run(new GraphNode(graphDefT, opDef.get())); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run NodePass failed"; + return status; + } else { + if (status == RET_OK) { + ifNotChanged = false; + } + } + } + } + + for (auto pass : this->graphPasses) { + status = pass->Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run GraphPass failed"; + return status; + } else { + if (status == RET_OK) { + ifNotChanged = false; + } + } + } + return ifNotChanged ? RET_NO_CHANGE : RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer.h b/mindspore/lite/tools/converter/optimizer.h new file mode 100644 index 0000000000..346e8c6016 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_OPTIMIZER_H +#define MS_OPTIMIZER_H +#include +#include "schema/inner/model_generated.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +using namespace schema; +template +class Pass { + public: + Pass() = default; + virtual ~Pass() = default; + virtual STATUS Run(T *t) = 0; +}; + +class GraphPass : public Pass { + public: + GraphPass() = default; + + ~GraphPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override = 0; + + // protected: + // GraphDefT *graphDefT = nullptr; +}; + +struct GraphNode { + GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : subGraph(subGraph), opDef(opDefT) {} + ~GraphNode() = default; + schema::MetaGraphT *subGraph = nullptr; + schema::CNodeT *opDef = nullptr; +}; + +class NodePass : public Pass { + public: + NodePass() = default; + + ~NodePass() override = default; + + STATUS Run(GraphNode *graphNode) override = 0; + + // protected: + // GraphNode *graphNode = nullptr; +}; + +class Optimizer { + public: + Optimizer() = default; + + virtual ~Optimizer(); + + void AddPass(GraphPass *graphPass); + + void AddPass(NodePass *nodePass); + + STATUS Run(schema::MetaGraphT *graphDefT); + + private: + std::vector graphPasses; + std::vector nodePasses; +}; +} // namespace lite +} // namespace mindspore + +#endif + + diff --git a/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt new file mode 100644 index 0000000000..97406e5bc7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt @@ -0,0 +1,52 @@ +add_library(caffe_parser_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/caffe.pb.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_argmax_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_argmax_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_batchnorm_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_batchnorm_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_concat_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_concat_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_conv_base_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_conv_base_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_converter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_converter.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_convolution_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_convolution_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_crop_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_crop_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_deconvolution_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_deconvolution_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_eltwise_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_eltwise_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_innerproduct_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_innerproduct_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_relu_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_relu_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_reshape_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_reshape_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_scale_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_scale_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_sigmoid_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_sigmoid_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_softmax_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_softmax_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_interp_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_interp_parser.h) diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe.proto b/mindspore/lite/tools/converter/parser/caffe/caffe.proto new file mode 100755 index 0000000000..75ae1aa357 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe.proto @@ -0,0 +1,1675 @@ +syntax = "proto2"; + +package caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. + enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. + // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). + repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. + optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional ProposalParameter proposal_param = 900; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ROIPoolingParameter roi_pooling_param = 147; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. + enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. + optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. + optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) + // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... } + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + +// Message that stores parameters used by ROIPoolingLayer +message ROIPoolingParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; +} + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). + // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. +message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. + optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} + +// Message that stores parameters used by ProposalLayer +message ProposalParameter { + optional float feat_stride = 1; + optional float base_size = 2; + optional float min_size = 3; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6; + optional int32 post_nms_topn = 7; + optional float nms_thresh = 8; +} + +// Message that stores parameters used by DetectionOutputLayer +//message DetectionOutputParameter { +// optional int32 num_classes = 1 [default = 21]; +// optional float nms_threshold = 2 [default = 0.3]; +// optional int32 top_k = 3; +// optional float confidence_threshold = 4 [default = 0.8]; +//} + +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. + optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; +} + +// Message that stores parameters used by PermutetLayer +message PermuteParameter { + // The new orders of the axes of data. Notice it should be with + // in the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + +// needed by ssd +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. + optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1 ; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; //分类的种类 + optional uint32 coords = 2 [default = 4]; //box的坐标数 + optional uint32 boxes = 3 [default = 1]; //每个grid预测的boxes数 + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + optional int32 axis = 1 [default = 1]; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional int32 scale = 1[default = 1]; +} diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc new file mode 100644 index 0000000000..a1035eec52 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + op->name = proto.name(); + std::unique_ptr attr(new schema::ArgMaxT()); + const caffe::ArgMaxParameter argmaxParam = proto.argmax_param(); + + int32_t axisType = 0; + int32_t axis = 0; + if (!argmaxParam.has_axis()) { + axisType = 2; + } else { + axisType = 1; + axis = (int64_t)argmaxParam.axis(); + if (axis == -1) { + // MS_LOGE("axis with -1 may lead to calculation errors when input less than 4 dims."); + return RET_ERROR; + } + } + + attr->axis = axis; + attr->axisType = axisType; + attr->outMaxValue = argmaxParam.out_max_val(); + attr->topK = argmaxParam.top_k(); + attr->keepDims = true; + + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_ArgMax; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeArgMaxParser("ArgMax", new CaffeArgMaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h new file mode 100644 index 0000000000..b539c49687 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ARGMAX_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ARGMAX_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeArgMaxParser : public CaffeNodeParser { + public: + CaffeArgMaxParser() : CaffeNodeParser("argmax") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ARGMAX_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc new file mode 100644 index 0000000000..7b22fd2d95 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h" +#include "tools/common/tensor_util.h" + +#define CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT 0.00001 +#define CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT 0.000000001 + +static const int CAFFE_BATCHNORMAL_BOTTOM_SIZE = 1; +static const int CAFFE_BATCHNORMAL_TOP_SIZE = 1; + +namespace mindspore { +namespace lite { +using STATUS = int; +STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + op->name = proto.name(); + // caffe batch norm attr + std::unique_ptr attr(new schema::BatchNormT()); + const caffe::BatchNormParameter batchNormParam = proto.batch_norm_param(); + + // check bottom size + if (proto.bottom_size() != CAFFE_BATCHNORMAL_BOTTOM_SIZE) { + // MS_LOGE("Layer %s bottom numbers is error, it must be %d, but is %d", proto.name().c_str(), + // CAFFE_BATCHNORMAL_BOTTOM_SIZE, proto.bottom_size()); + return RET_ERROR; + } + + // check top size + if (proto.top_size() != CAFFE_BATCHNORMAL_TOP_SIZE) { + // MS_LOGE("Layer %s top numbers is error, it must be %d, but is %d", \ + proto.name().c_str(), CAFFE_BATCHNORMAL_TOP_SIZE, + // proto.top_size()); + return RET_ERROR; + } + + if (batchNormParam.has_eps()) { + if (fabs(CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT - batchNormParam.eps()) < CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT) { + attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; + } else { + auto tmpAuto = batchNormParam.eps(); + attr->epsilon = tmpAuto; + } + } else { + attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; + } + + const float blob2Data = + (weight.blobs(2).double_data_size() > 0) ? weight.blobs(2).double_data(0) : weight.blobs(2).data(0); + const float scaleFactor = blob2Data == 0 ? 0 : 1 / blob2Data; + + // parse weight gamma + auto gamma = ConvertWeight(weight.blobs(0)); + if (gamma == nullptr) { + // MS_LOGE("Convert blobs(0) for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + + auto estimatedMean = reinterpret_cast(gamma->data.data()); + auto estimatedMeanShapeSize = GetShapeSize(*gamma); + for (size_t i = 0; i < estimatedMeanShapeSize; i++) { + estimatedMean[i] = estimatedMean[i] * scaleFactor; + } + estimatedMean = nullptr; + weightVec->push_back(gamma); + + // parse weight beta + auto beta = ConvertWeight(weight.blobs(1)); + if (beta == nullptr) { + // MS_LOGE("Convert blobs(1) for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + + auto estimatedVariance = reinterpret_cast(beta->data.data()); + size_t estimatedVarianceShapeSize = GetShapeSize(*beta); + for (size_t i = 0; i < estimatedVarianceShapeSize; i++) { + estimatedVariance[i] = estimatedVariance[i] * scaleFactor; + } + estimatedVariance = nullptr; + weightVec->push_back(beta); + + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_BatchNorm; + op->primitive->value.value = attr.release(); + + return RET_OK; +} + +CaffeNodeRegistrar g_caffeBatchNormParser("BatchNorm", new CaffeBatchNormParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h new file mode 100644 index 0000000000..aca0e41dbd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_BATCHNORM_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_BATCHNORM_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeBatchNormParser : public CaffeNodeParser { + public: + CaffeBatchNormParser() : CaffeNodeParser("batchnorm") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_BATCHNORM_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc new file mode 100644 index 0000000000..cd4ce95c7a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h" + +const int32_t CONCAT_DEFAULT_AXIS = 1; + +namespace mindspore { +namespace lite { +STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + op->name = proto.name(); + std::unique_ptr attr(new schema::ConcatT()); + const caffe::ConcatParameter concatParam = proto.concat_param(); + + if (concatParam.has_axis() && concatParam.has_concat_dim()) { + // MS_LOGE("Concat param in caffe have concat_dim and axis simultaneously,return fail"); + return RET_ERROR; + } + + if (concatParam.has_concat_dim()) { + // MS_LOGD("Concat dim , set axis:%d", concatParam.concat_dim()); + int32_t concat_dim_value = (int32_t)concatParam.concat_dim(); + + if (concat_dim_value < 0) { + // MS_LOGE("concat_dim value in model is smaller than 0:%d", concat_dim_value); + return RET_ERROR; + } + attr->axis = concat_dim_value; + } else if (concatParam.has_axis()) { + // MS_LOGD("axis , set axis:%d", concatParam.axis()); + int32_t tmpInt = (int32_t)concatParam.axis(); + attr->axis = tmpInt; + } else { + // MS_LOGD("default , set axis:%d", CONCAT_DEFAULT_AXIS); + attr->axis = CONCAT_DEFAULT_AXIS; + } + + attr->n = proto.bottom_size(); + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeConcatParser("Concat", new CaffeConcatParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h new file mode 100644 index 0000000000..10ae7013d2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONCAT_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONCAT_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeConcatParser : public CaffeNodeParser { + public: + CaffeConcatParser() : CaffeNodeParser("concat") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONCAT_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc new file mode 100644 index 0000000000..b3bc5adee2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h" + +const uint32_t PAD_DEFAULT_VALUE = 0; +const uint32_t STRIDE_DEFAULT_VALUE = 1; +const uint32_t DILATION_DEFAULT_VALUE = 1; +const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; +const uint32_t DEFAULT_CONV_GROUP = 1; +static const int CAFFE_CONV_BIAS_DIM_NUM = 1; + +namespace mindspore { +namespace lite { +STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convParam, std::vector *pad) { + /** + * padUp = padH; + * padDown = padH; + * padLeft = padW; + * padRight = padW; + */ + if (convParam.has_pad_h() || convParam.has_pad_w()) { + if (convParam.pad_size() != 0) { + // MS_LOGE("Either pad or pad_h/w should be specified; not both"); + return RET_ERROR; + } + + if (!convParam.has_pad_h()) { + (*pad)[0] = PAD_DEFAULT_VALUE; + (*pad)[1] = PAD_DEFAULT_VALUE; + (*pad)[2] = convParam.pad_w(); + (*pad)[3] = convParam.pad_w(); + } else if (!convParam.has_pad_w()) { + (*pad)[0] = convParam.pad_h(); + (*pad)[1] = convParam.pad_h(); + (*pad)[2] = PAD_DEFAULT_VALUE; + (*pad)[3] = PAD_DEFAULT_VALUE; + } else { + (*pad)[0] = convParam.pad_h(); + (*pad)[1] = convParam.pad_h(); + (*pad)[2] = convParam.pad_w(); + (*pad)[3] = convParam.pad_w(); + } + } else { + // default 2D + const int num_pad_dims = convParam.pad_size(); + int num_spatial_dims = std::max(num_pad_dims, SPATIAL_DIM_DEFAULT_SIZE); + + std::vector vec; + for (int i = 0; i < num_spatial_dims; ++i) { + vec.push_back((num_pad_dims == 0) ? PAD_DEFAULT_VALUE : convParam.pad((num_pad_dims == 1) ? 0 : i)); + } + // default 2D + (*pad)[0] = vec[0]; + (*pad)[1] = vec[0]; + (*pad)[2] = vec[1]; + (*pad)[3] = vec[1]; + } + return RET_OK; +} + +STATUS CaffeConvBaseParser::ParseStrides(const caffe::ConvolutionParameter &convParam, std::vector *stride) { + if (convParam.has_stride_h() || convParam.has_stride_w()) { + if (convParam.stride_size() != 0) { + // MS_LOGE("Either stride or stride_h/w should be specified; not both"); + return RET_ERROR; + } + if (!convParam.has_stride_h() || !convParam.has_stride_w()) { + // MS_LOGE("stride_h/w must appear at the same time!"); + return RET_ERROR; + } + (*stride)[0] = convParam.stride_h(); + (*stride)[1] = convParam.stride_w(); + } else { + const int num_stride_dims = convParam.stride_size(); + int num_spatial_dims = std::max(num_stride_dims, SPATIAL_DIM_DEFAULT_SIZE); + + std::vector vec; + for (int i = 0; i < num_spatial_dims; ++i) { + vec.push_back((num_stride_dims == 0) ? STRIDE_DEFAULT_VALUE : convParam.stride((num_stride_dims == 1) ? 0 : i)); + } + // default 2D + (*stride)[0] = vec[0]; + (*stride)[1] = vec[1]; + } + return RET_OK; +} + +STATUS CaffeConvBaseParser::ParseDilations(const caffe::ConvolutionParameter &convParam, + std::vector *dilation) { + const int num_dilation_dims = convParam.dilation_size(); + int num_spatial_dims = std::max(num_dilation_dims, SPATIAL_DIM_DEFAULT_SIZE); + + std::vector vec; + for (int i = 0; i < num_spatial_dims; ++i) { + vec.push_back((num_dilation_dims == 0) ? DILATION_DEFAULT_VALUE + : convParam.dilation((num_dilation_dims == 1) ? 0 : i)); + } + // default 2D + (*dilation)[0] = vec[0]; + (*dilation)[1] = vec[1]; + + return RET_OK; +} + +STATUS CaffeConvBaseParser::ParseKernels(const caffe::ConvolutionParameter &convParam, std::vector *kernel) { + if (convParam.has_kernel_h() || convParam.has_kernel_w()) { + if (convParam.kernel_size_size() != 0) { + // MS_LOGE("Either kernel_size or kernel_h/w should be specified; not both.") + return RET_ERROR; + } + if (convParam.has_kernel_h() && convParam.has_kernel_w()) { + (*kernel)[0] = convParam.kernel_h(); + (*kernel)[1] = convParam.kernel_w(); + } else { + // MS_LOGE("kernel_h/w must appear at the same time!"); + return RET_ERROR; + } + } else if (convParam.kernel_size_size() != 0) { + int kernel_size = convParam.kernel_size_size(); + int num_spatial_dims = std::max(kernel_size, SPATIAL_DIM_DEFAULT_SIZE); + std::vector vec; + for (int i = 0; i < num_spatial_dims; i++) { + vec.push_back(convParam.kernel_size((kernel_size == 1) ? 0 : i)); + } + // default 2D + (*kernel)[0] = vec[0]; + (*kernel)[1] = vec[1]; + } else { + return RET_ERROR; + } + return RET_OK; +} + +int CaffeConvBaseParser::ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType) { + // group default 1 + int group = 0; + if (convParam.has_group()) { + group = convParam.group(); + } else { + layerType == "ConvolutionDepthwise" ? (group = convParam.num_output()) : (group = DEFAULT_CONV_GROUP); + } + return group; +} + +int CaffeConvBaseParser::ParseChannelIn(const caffe::LayerParameter &proto, const int &group) { + int res = 0; + auto &weightBlob = proto.blobs(0); + if (weightBlob.has_shape()) { + res = weightBlob.shape().dim(1) * group; + } else { + // get shape information from Blob parameters(caffe proto v1) + if (proto.type() == "Deconvolution") { + res = weightBlob.num() * group; + } else { + res = weightBlob.channels() * group; + } + } + return res; +} + +int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &convParam) { + if (!convParam.has_num_output()) { + // MS_LOGE("Parse num_output for failed."); + } + return convParam.num_output(); +} + +STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, + std::vector *weightVec) { + // Layer must have Filter + if (weight.blobs_size() == 0) { + // MS_LOGE("No filter data in layer %s", weight.name().c_str()); + return RET_ERROR; + } + + auto filter = ConvertWeight(weight.blobs(0)); + if (filter == nullptr) { + // MS_LOGE("Convert weight for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(filter); + + // parse bias + const caffe::ConvolutionParameter convParam = weight.convolution_param(); + if (convParam.bias_term() && weight.blobs_size() > 1) { + auto bias = ConvertWeight(weight.blobs(1)); + if (bias == nullptr) { + // MS_LOGE("Convert bias for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + + std::vector shape = bias->dims; + if (shape.size() != CAFFE_CONV_BIAS_DIM_NUM) { + // MS_LOGE("Bias dim-num of layer %s is not supported"); + return RET_ERROR; + } + weightVec->push_back(bias); + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h new file mode 100644 index 0000000000..d1e2886879 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONV_BASE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONV_BASE_PARSER_H_ + +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeConvBaseParser { + public: + CaffeConvBaseParser() {} + + virtual ~CaffeConvBaseParser() {} + + STATUS ParsePads(const caffe::ConvolutionParameter &conv_param, std::vector *pad); + + STATUS ParseStrides(const caffe::ConvolutionParameter &conv_param, std::vector *stride); + + STATUS ParseDilations(const caffe::ConvolutionParameter &conv_param, std::vector *dilation); + + STATUS ParseKernels(const caffe::ConvolutionParameter &conv_param, std::vector *kernel); + + int ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType); + + int ParseChannelOut(const caffe::ConvolutionParameter &convParam); + + int ParseChannelIn(const caffe::LayerParameter &proto, const int &group); + + STATUS ParseWeight(const caffe::LayerParameter &weight, std::vector *weightVec); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONV_BASE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc new file mode 100644 index 0000000000..16056fa39d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" + +namespace mindspore { +namespace lite { +CaffeConverter::CaffeConverter() { + modelParser = new CaffeModelParser(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h new file mode 100644 index 0000000000..889c5afefd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ + +#include +#include +#include "mindspore/lite/tools/converter/converter.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" +#include "mindspore/lite/tools/converter/graphdef_transform.h" + +namespace mindspore::lite { +class CaffeConverter : public Converter { + public: + CaffeConverter(); + + ~CaffeConverter() override = default; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc new file mode 100644 index 0000000000..5c622f13d0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h" + +namespace mindspore { +namespace lite { +void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { + if (attr == nullptr || attr->group == 1 || attr->group != attr->channelOut) { + return; + } + std::unique_ptr depthwiseConv2DParam(new schema::DepthwiseConv2DT()); + if (depthwiseConv2DParam == nullptr) { + // MS_LOGW("new DepthwiseConv2DT failed"); + return; + } + depthwiseConv2DParam->format = attr->format; + depthwiseConv2DParam->channelIn = attr->channelIn; + depthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn; + depthwiseConv2DParam->kernelW = attr->kernelW; + depthwiseConv2DParam->kernelH = attr->kernelH; + depthwiseConv2DParam->strideW = attr->strideW; + depthwiseConv2DParam->strideH = attr->strideH; + depthwiseConv2DParam->padMode = attr->padMode; + depthwiseConv2DParam->padUp = attr->padUp; + depthwiseConv2DParam->padDown = attr->padDown; + depthwiseConv2DParam->padLeft = attr->padLeft; + depthwiseConv2DParam->padRight = attr->padRight; + depthwiseConv2DParam->dilateW = attr->dilateW; + depthwiseConv2DParam->dilateH = attr->dilateH; + depthwiseConv2DParam->hasBias = attr->hasBias; + depthwiseConv2DParam->activationType = attr->activationType; + delete attr; + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + op->primitive->value.value = depthwiseConv2DParam.release(); +} + +STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + op->name = proto.name(); + schema::Conv2DT *attr = new schema::Conv2DT(); + + attr->format = schema::Format_NCHW; + const caffe::ConvolutionParameter convParam = proto.convolution_param(); + + CaffeConvBaseParser convParser; + // parse pad + std::vector pad(4, 0); + auto status = convParser.ParsePads(convParam, &pad); + if (status != RET_OK) { + // MS_LOGE("ParsePads for %s failed", proto.name().c_str()); + } + attr->padUp = pad[0]; + attr->padDown = pad[1]; + attr->padLeft = pad[2]; + attr->padRight = pad[3]; + + // parse stride + std::vector stride(2, 0); + status = convParser.ParseStrides(convParam, &stride); + if (status != RET_OK) { + // MS_LOGE("ParseStrides for %s failed", proto.name().c_str()); + } + attr->strideH = stride[0]; + attr->strideW = stride[1]; + + // parse dilation + std::vector dilation(2, 0); + status = convParser.ParseDilations(convParam, &dilation); + if (status != RET_OK) { + // MS_LOGE("ParseDilations for %s failed", proto.name().c_str()); + } + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + // parse kernel + std::vector kernel(2, 0); + status = convParser.ParseKernels(convParam, &kernel); + if (status != RET_OK) { + // MS_LOGE("ParseKernels for %s failed", proto.name().c_str()); + } + attr->kernelH = kernel[0]; + attr->kernelW = kernel[1]; + + attr->hasBias = convParam.bias_term(); + attr->group = convParser.ParseGroup(convParam, proto.type()); + attr->channelOut = convParser.ParseChannelOut(convParam); + attr->channelIn = convParser.ParseChannelIn(weight, attr->group); + attr->padMode = schema::PadMode_CAFFE; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = attr; + + ParseGroupConvolution(op, attr); + status = convParser.ParseWeight(weight, weightVec); + if (status != RET_OK) { + // MS_LOGE("ParseWeight for %s failed", proto.name().c_str()); + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffeConvolutionParser("Convolution", new CaffeConvolutionParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h new file mode 100644 index 0000000000..297c6242ba --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVOLUTION_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVOLUTION_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h" + +namespace mindspore { +namespace lite { +class CaffeConvolutionParser : public CaffeNodeParser { + public: + CaffeConvolutionParser() : CaffeNodeParser("convolution") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + private: + void ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVOLUTION_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc new file mode 100644 index 0000000000..106fb3ad71 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h" + +const int32_t CROP_AXIS = 2; + +namespace mindspore { +namespace lite { +STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::CropT()); + if (!proto.has_crop_param()) { + attr->axis = CROP_AXIS; + std::vector offsets(2, 0); + attr->offsets = offsets; + } else { + const caffe::CropParameter cropParam = proto.crop_param(); + + if (cropParam.has_axis()) { + if (cropParam.axis() == -1) { + // MS_LOGW("axis with -1 may lead to calculation errors when input less than 4 dims."); + } + attr->axis = cropParam.axis(); + } else { + attr->axis = CROP_AXIS; + } + + if (cropParam.offset_size() != 0) { + std::vector offsets; + for (int i = 0; i < cropParam.offset_size(); i++) { + offsets.push_back(cropParam.offset(i)); + } + attr->offsets = offsets; + } + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Crop; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeCropParser("Crop", new CaffeCropParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h new file mode 100644 index 0000000000..7de30b5cc0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CROP_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CROP_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeCropParser : public CaffeNodeParser { + public: + CaffeCropParser() : CaffeNodeParser("crop") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CROP_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc new file mode 100644 index 0000000000..be9682f2fd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h" + +namespace mindspore { +namespace lite { +void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr) { + if (attr == nullptr || attr->group == 1 || attr->group != attr->channelIn) { + return; + } + + std::unique_ptr deDepthwiseConv2DParam(new schema::DeDepthwiseConv2DT()); + if (deDepthwiseConv2DParam == nullptr) { + // MS_LOGW("new DeDepthwiseConv2DT failed"); + return; + } + deDepthwiseConv2DParam->format = attr->format; + deDepthwiseConv2DParam->channelIn = attr->channelOut; + deDepthwiseConv2DParam->channelMultiplier = attr->channelIn / attr->channelOut; + deDepthwiseConv2DParam->kernelW = attr->kernelW; + deDepthwiseConv2DParam->kernelH = attr->kernelH; + deDepthwiseConv2DParam->strideW = attr->strideW; + deDepthwiseConv2DParam->strideH = attr->strideH; + deDepthwiseConv2DParam->padMode = attr->padMode; + deDepthwiseConv2DParam->padUp = attr->padUp; + deDepthwiseConv2DParam->padDown = attr->padDown; + deDepthwiseConv2DParam->padLeft = attr->padLeft; + deDepthwiseConv2DParam->padRight = attr->padRight; + deDepthwiseConv2DParam->dilateW = attr->dilateW; + deDepthwiseConv2DParam->dilateH = attr->dilateH; + deDepthwiseConv2DParam->hasBias = attr->hasBias; + deDepthwiseConv2DParam->activationType = attr->activationType; + delete attr; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; + op->primitive->value.value = deDepthwiseConv2DParam.release(); +} +STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + op->name = proto.name(); + schema::DeConv2DT *attr = new schema::DeConv2DT(); + attr->format = schema::Format_NCHW; + const caffe::ConvolutionParameter convParam = proto.convolution_param(); + + CaffeConvBaseParser convParser; + // parse pad + std::vector pad(4, 0); + auto status = convParser.ParsePads(convParam, &pad); + if (status != RET_OK) { + // MS_LOGE("ParsePads for %s failed", proto.name().c_str()); + } + attr->padUp = pad[0]; + attr->padDown = pad[1]; + attr->padLeft = pad[2]; + attr->padRight = pad[3]; + + // parse stride + std::vector stride(2, 0); + status = convParser.ParseStrides(convParam, &stride); + if (status != RET_OK) { + // MS_LOGE("ParseStrides for %s failed", proto.name().c_str()); + } + attr->strideH = stride[0]; + attr->strideW = stride[1]; + + // parse dilation + std::vector dilation(2, 0); + status = convParser.ParseDilations(convParam, &dilation); + if (status != RET_OK) { + // MS_LOGE("ParseDilations for %s failed", proto.name().c_str()); + } + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + // parse kernel + std::vector kernel(2, 0); + status = convParser.ParseKernels(convParam, &kernel); + if (status != RET_OK) { + // MS_LOGE("ParseKernels for %s failed", proto.name().c_str()); + } + attr->kernelH = kernel[0]; + attr->kernelW = kernel[1]; + + attr->hasBias = convParam.bias_term(); + attr->group = convParser.ParseGroup(convParam, proto.type()); + attr->channelOut = convParser.ParseChannelOut(convParam); + attr->channelIn = convParser.ParseChannelIn(weight, attr->group); + attr->padMode = schema::PadMode_CAFFE; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeConv2D; + op->primitive->value.value = attr; + ParseGroupDeconvolution(op, attr); + status = convParser.ParseWeight(weight, weightVec); + if (status != RET_OK) { + // MS_LOGE("ParseWeight for %s failed", proto.name().c_str()); + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffeDeconvolutionParser("Deconvolution", new CaffeDeconvolutionParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h new file mode 100644 index 0000000000..834dd0a9c6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_DECONVOLUTION_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_DECONVOLUTION_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h" + +namespace mindspore { +namespace lite { +class CaffeDeconvolutionParser : public CaffeNodeParser { + public: + CaffeDeconvolutionParser() : CaffeNodeParser("deconvolution") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + private: + void ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_DECONVOLUTION_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc new file mode 100644 index 0000000000..c750a2c5d8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h" + +const int ELTWISE_MIN_INPUT_SIZE = 2; + +namespace mindspore { +namespace lite { +STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + std::unique_ptr attr(new schema::EltwiseT()); + if (proto.bottom_size() < ELTWISE_MIN_INPUT_SIZE) { + // MS_LOGE("Eltwise Op '%s' need at least 2 inputs,but input size is %d", proto.name().c_str(), + // proto.bottom_size()); + return RET_ERROR; + } + + const caffe::EltwiseParameter eltwiseParam = proto.eltwise_param(); + + if (eltwiseParam.coeff_size() != 0 && eltwiseParam.coeff_size() != proto.bottom_size()) { + // MS_LOGE("Coeff size(%d) check fail, Eltwise Layer takes one coefficient per bottom blob.", + // eltwiseParam.coeff_size()); + return RET_PARAM_INVALID; + } + + if (eltwiseParam.operation() == caffe::EltwiseParameter::PROD && eltwiseParam.coeff_size() != 0) { + // MS_LOGE("Eltwise layer only takes coefficients for summation."); + return RET_ERROR; + } + + if (proto.has_eltwise_param() && eltwiseParam.has_operation()) { + switch (eltwiseParam.operation()) { + case caffe::EltwiseParameter::PROD: + attr->mode = schema::EltwiseMode_PROD; + break; + case caffe::EltwiseParameter::SUM: + attr->mode = schema::EltwiseMode_SUM; + break; + case caffe::EltwiseParameter::MAX: + attr->mode = schema::EltwiseMode_MAXIMUM; + break; + default: + // MS_LOGE("Eltwise parse params fail, unsupported opration %d.", eltwiseParam.operation()); + return RET_PARAM_INVALID; + } + } else { + attr->mode = schema::EltwiseMode_SUM; + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Eltwise; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeEltwiseParser("Eltwise", new CaffeEltwiseParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h new file mode 100644 index 0000000000..efae240ecd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ELTWISE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ELTWISE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeEltwiseParser : public CaffeNodeParser { + public: + CaffeEltwiseParser() : CaffeNodeParser("eltwise") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ELTWISE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc new file mode 100644 index 0000000000..81c0a71c0d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + const caffe::InnerProductParameter innerProductParam = proto.inner_product_param(); + std::unique_ptr attr(new schema::FullConnectionT()); + + if (!innerProductParam.has_num_output()) { + // MS_LOGE("InnerProduct Parse num_output for %s failed.", proto.name().c_str()); + return RET_ERROR; + } + + if (innerProductParam.axis() == 1) { + attr->axis = 1; + } else { + // MS_LOGE("InnerProduct Parse axis only support default 1, but actually %d.", innerProductParam.axis()); + return RET_ERROR; + } + + if (innerProductParam.bias_term()) { + attr->hasBias = true; + } + + // parse weight + if (weight.blobs_size() == 0) { + // MS_LOGE("InnerProduct No filter data in layer %s", weight.name().c_str()); + return RET_ERROR; + } + + // parse filter + auto filter = ConvertWeight(weight.blobs(0)); + if (filter == nullptr) { + // MS_LOGE("InnerProduct parse weight for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(filter); + + // parse bias + if (innerProductParam.bias_term() && weight.blobs_size() > 1) { + auto bias = ConvertWeight(weight.blobs(1)); + if (bias == nullptr) { + // MS_LOGE("InnerProduct parse bias for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(bias); + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeInnerProductParser("InnerProduct", new CaffeInnerProductParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h new file mode 100644 index 0000000000..548c4535f7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INNERPRODUCT_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INNERPRODUCT_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeInnerProductParser : public CaffeNodeParser { + public: + CaffeInnerProductParser() : CaffeNodeParser("innerproduct") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INNERPRODUCT_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc new file mode 100644 index 0000000000..18c4337b8f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +STATUS CaffeInspector::InspectModel(const caffe::NetParameter &proto) { + net = proto; + + if (proto.layer_size() == 0) { + // MS_LOGE("net layer num is zero, prototxt file may be invalid."); + return RET_ERROR; + } + + ParseInput(); + + SetTopsAndBottoms(); + + FindInputAndOutput(); +} + +STATUS CaffeInspector::ParseInput() { + if (net.input_size() > 0) { + // MS_LOGI("This net exist input."); + for (int i = 0; i < net.input_size(); i++) { + graphInput.insert(net.input(i)); + } + } + return RET_OK; +} + +STATUS CaffeInspector::FindInputAndOutput() { + for (auto iter : layerBottoms) { + if (layerTops.find(iter) == layerTops.end()) { + graphInput.insert(iter); + } + } + for (auto iter : layerTops) { + if (layerBottoms.find(iter) == layerBottoms.end()) { + graphOutput.insert(iter); + } + } +} + +STATUS CaffeInspector::SetTopsAndBottoms() { + for (int32_t i = 0; i < net.layer_size(); i++) { + caffe::LayerParameter &layer = const_cast(net.layer(i)); + if (layer.top_size() == 1 && layer.bottom_size() == 1 && layer.top(0) == layer.bottom(0)) { + continue; + } + if (layer.top_size() == 1 && layer.bottom_size() == 0) { + graphInput.insert(layer.top(0)); + } + for (int j = 0; j < layer.top_size(); j++) { + layerTops.insert(layer.top(j)); + } + for (int j = 0; j < layer.bottom_size(); j++) { + layerBottoms.insert(layer.bottom(j)); + } + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h new file mode 100644 index 0000000000..94bda8ddec --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INSPECTOR_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INSPECTOR_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class CaffeInspector { + public: + CaffeInspector() = default; + + STATUS InspectModel(const caffe::NetParameter &proto); + STATUS ParseInput(); + STATUS FindInputAndOutput(); + STATUS SetTopsAndBottoms(); + + std::set GetGraphInput() { return graphInput; } + std::set GetGraphOutput() { return graphOutput; } + + private: + caffe::NetParameter net; + + std::set layerTops; + std::set layerBottoms; + + std::set graphInput; + std::set graphOutput; +}; + +using CaffeInspectorPtr = std::shared_ptr; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INSPECTOR_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc new file mode 100644 index 0000000000..2ade9ec54d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + std::unique_ptr attr(new schema::ResizeT()); + const caffe::InterpParameter interpParam = proto.interp_param(); + + if (interpParam.has_height()) { + int64_t height = interpParam.height(); + if (height < 0) { + // MS_LOGE("Interp height must be > 0"); + return RET_ERROR; + } + attr->newHeight = height; + } + + if (interpParam.has_width()) { + int64_t width = interpParam.width(); + if (width < 0) { + // MS_LOGE("Interp width must be > 0"); + return RET_ERROR; + } + attr->newWidth = width; + } + + attr->alignCorners = true; + attr->method = schema::ResizeMethod_BILINEAR; + + op->name = proto.name(); + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Resize; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeInterpParser("Interp", new CaffeInterpParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h new file mode 100644 index 0000000000..675cc9ff88 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INTERP_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INTERP_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeInterpParser : public CaffeNodeParser { + public: + CaffeInterpParser() : CaffeNodeParser("Interp") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INTERP_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc new file mode 100644 index 0000000000..7378bb1f99 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -0,0 +1,304 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +CaffeModelParser::CaffeModelParser() {} + +CaffeModelParser::~CaffeModelParser() {} + +const std::set CaffeModelParser::skipedLayerType = {"Dropout"}; + +schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { + std::unique_ptr graph(new schema::MetaGraphT()); + + if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; + return nullptr; + } + + if (weightFile.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary"; + return nullptr; + } + + if (ValidateFileStr(weightFile, ".caffemodel") != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel"; + return nullptr; + } + + std::unique_ptr subGraphDef(new schema::MetaGraphT()); + TensorCache tensorCache; + + caffe::NetParameter proto; + if (ReadProtoFromText((const char *)modelFile.c_str(), &proto) != RET_OK) { + MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile; + return nullptr; + } + subGraphDef->name = proto.name(); + + caffe::NetParameter weight; + if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) { + MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; + return nullptr; + } + + auto status = GetModelInput(proto, &tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetModelInput failed " << status; + return nullptr; + } + + status = ParseLayer(proto, weight, &tensorCache, subGraphDef.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "ParseLayer failed " << status; + return nullptr; + } + + // set inputTensor index and outputTensor index for the whole graph + status = SetGraphTensorIndex(proto, &tensorCache, subGraphDef.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set inputTensor index and outputTensor index for graph failed!"; + return nullptr; + } + subGraphDef->name = GetModelName(modelFile); + // set all tensors to graph + SetAllTensors(tensorCache, subGraphDef.get()); + graph = move(subGraphDef); + + // ConvertCaffeBatchNorm(graph.get()); + + return graph.release(); + // return Fb2Anf(graph.release()); +} + +STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, + TensorCache *tensorCache) { + for (int i = 0; i < layer.bottom_size(); i++) { + int index = tensorCache->FindTensor(layer.bottom(i)); + if (index >= 0) { + op->inputIndex.emplace_back(index); + } else { + // MS_LOGE("Can't find input layer for %s.", layer.name().c_str()); + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, + TensorCache *tensorCache) { + for (int i = 0; i < layer.top_size(); i++) { + std::unique_ptr msTensor(new schema::TensorT()); + op->outputIndex.emplace_back(tensorCache->AddTensor(layer.top(i), msTensor.release(), OP_OUTPUT)); + } + return RET_OK; +} + +STATUS CaffeModelParser::SetWeightTensor(const std::vector &weightVec, schema::CNodeT *op, + TensorCache *tensorCache) { + for (auto iter : weightVec) { + op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST)); + } + return RET_OK; +} + +STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef) { + std::vector tensors = tensorCache.GetCachedTensor(); + for (auto iter : tensors) { + std::unique_ptr temp(iter); + subGraphDef->allTensors.emplace_back(move(temp)); + } + return RET_OK; +} + +STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, TensorCache *tensorCache, + schema::MetaGraphT *subGraphDef) { + CaffeInspector caffeInspector; + caffeInspector.InspectModel(proto); + for (auto iter : caffeInspector.GetGraphInput()) { + int index = tensorCache->FindTensor(iter); + if (index >= 0) { + subGraphDef->inputIndex.emplace_back(index); + } else { + // MS_LOGE("Can't find input tensor layer for graph."); + return RET_ERROR; + } + } + + for (auto iter : caffeInspector.GetGraphOutput()) { + int index = tensorCache->FindTensor(iter); + if (index >= 0) { + subGraphDef->outputIndex.emplace_back(index); + } else { + // MS_LOGE("Can't find output tensor layer for graph."); + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, + TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) { + for (int i = 0; i < proto.layer_size(); i++) { + auto layer = proto.layer(i); + + caffe::LayerParameter layerP; + for (int j = 0; j < weight.layer_size(); j++) { + auto tempLayer = weight.layer(j); + if (tempLayer.name() == layer.name()) { + layerP = tempLayer; + break; + } + } + // todo y00520784 : layer.input_param().shape(0) + if (layer.type() == "Input") { + std::unique_ptr msTensor(new schema::TensorT()); + for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) { + msTensor->dims.push_back(layer.input_param().shape(0).dim(j)); + } + msTensor->nodeType = schema::NodeType_ValueNode; + msTensor->refCount = 1; + msTensor->dataType = kNumberTypeFloat32; + tensorCache->AddTensor(layer.top(0), msTensor.release(), GRAPH_INPUT); + } else { + if (skipedLayerType.find(layer.type()) != skipedLayerType.end()) { + MS_LOG(INFO) << "Skip layer " << layer.name(); + continue; + } + + std::unique_ptr op(new schema::CNodeT()); + op->name = layer.name(); + + // set op input index + auto status = SetOpInputIdx(layer, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; + return status; + } + + auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); + if (nodeParser == nullptr) { + MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name(); + return RET_ERROR; + } + + std::vector weightVec; + status = nodeParser->Parse(layer, layerP, op.get(), &weightVec); + if (status != RET_OK) { + MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; + return status; + } + // set op weight tensor to tensorcache + SetWeightTensor(weightVec, op.get(), tensorCache); + + // set op output index + status = SetOpOutputIdx(layer, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; + return status; + } + + // op->fmkType = FmkType_CAFFE; + subGraphDef->nodes.emplace_back(move(op)); + } + } + return RET_OK; +} + +STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { + for (int i = 0; i < proto.input_size(); i++) { + if (proto.input_dim_size() <= 0) { + continue; + } + std::unique_ptr msTensor(new schema::TensorT()); + for (int j = 0; j < proto.input_dim_size(); j++) { + msTensor->dims.push_back(proto.input_dim(j)); + } + msTensor->refCount = schema::NodeType_ValueNode; + msTensor->dataType = kNumberTypeFloat32; + tensorCache->AddTensor(proto.input(i), msTensor.release(), GRAPH_INPUT); + } + + for (int i = 0; i < proto.input_shape_size(); i++) { + auto shape = proto.input_shape(i); + std::unique_ptr msTensor(new schema::TensorT()); + for (int j = 0; j < shape.dim_size(); j++) { + msTensor->dims.push_back(shape.dim(j)); + } + msTensor->refCount = schema::NodeType_ValueNode; + msTensor->dataType = kNumberTypeFloat32; + tensorCache->AddTensor(proto.input(i), msTensor.release(), GRAPH_INPUT); + } + return RET_OK; +} + +void CaffeModelParser::ConvertCaffeBatchNorm(schema::MetaGraphT *meta_graph) { + MS_ASSERT(meta_graph != nullptr); + auto &nodes = meta_graph->nodes; + for (auto &node : nodes) { + if (node->primitive->value.type != schema::PrimitiveType_FusedBatchNorm) { + continue; + } + MS_ASSERT(node->inputIndex.size() == 2); + MS_ASSERT(node->inputIndex.back() < meta_graph->allTensors.size()); + auto &meanTensor = meta_graph->allTensors.at(node->inputIndex.back()); + MS_ASSERT(nullptr != meanTensor); + auto shape = meanTensor->dims; + auto shapeSize = GetShapeSize(shape); + + auto scaleTensor = std::make_unique(); + scaleTensor->dims = shape; + scaleTensor->nodeType = NodeType_ValueNode; + scaleTensor->refCount = 1; + scaleTensor->format = schema::Format_NUM_OF_FORMAT; + scaleTensor->dataType = TypeId::kNumberTypeFloat32; + scaleTensor->data.resize(shapeSize * sizeof(float)); + auto scaleData = reinterpret_cast(scaleTensor->data.data()); + for (size_t i = 0; i < shapeSize; i++) { + scaleData[i] = 1; + } + + auto biasTensor = std::make_unique(); + biasTensor->dims = shape; + biasTensor->nodeType = NodeType_ValueNode; + biasTensor->refCount = 1; + biasTensor->format = schema::Format_NUM_OF_FORMAT; + biasTensor->dataType = TypeId::kNumberTypeInt32; + biasTensor->data.resize(shapeSize * sizeof(int32_t)); + auto biasData = reinterpret_cast(biasTensor->data.data()); + for (size_t i = 0; i < shapeSize; i++) { + biasData[i] = 0; + } + + node->inputIndex.insert(node->inputIndex.begin() + 1, meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(biasTensor)); + + node->inputIndex.insert(node->inputIndex.begin() + 1, meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(scaleTensor)); + } +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h new file mode 100644 index 0000000000..52297d3018 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_MODEL_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_MODEL_PARSER_H_ + +#include +#include +#include +#include +#include "mindspore/lite/tools/converter/model_parser.h" +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class CaffeModelParser : public ModelParser { + public: + CaffeModelParser(); + + virtual ~CaffeModelParser(); + + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; + + private: + void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT); + + STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); + + STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); + + STATUS SetWeightTensor(const std::vector &weightVec, schema::CNodeT *op, TensorCache *tensorCache); + + STATUS SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); + + STATUS SetGraphTensorIndex(const caffe::NetParameter &proto, + TensorCache *tensorCache, + schema::MetaGraphT *subGraphDef); + + STATUS ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, TensorCache *tensorCache, + schema::MetaGraphT *subGraphDef); + + STATUS GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache); + + static const std::set skipedLayerType; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_MODEL_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc new file mode 100644 index 0000000000..fbee603b14 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "securec/include/securec.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +namespace lite { +schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { + std::unique_ptr weight(new schema::TensorT()); + weight->format = schema::Format_NCHW; + std::vector shapeVec; + ConvertShape(proto, &shapeVec); + weight->dims = shapeVec; + weight->dataType = kNumberTypeFloat32; + weight->nodeType = schema::NodeType_ValueNode; + + // cal Weight num + int count = 1; + for (size_t i = 0; i < shapeVec.size(); ++i) { + int dim = shapeVec[i]; + if (dim <= 0) { + // MS_LOGE("Convert weight fail, Blob size invalid"); + return nullptr; + } + if (dim >= INT_MAX / count) { + // MS_LOGE("Convert weight fail, Blob size exceeds INT_MAX, dim:%d, count:%d", dim, count); + return nullptr; + } + count *= dim; + } + + // get weight + std::unique_ptr buf(new (std::nothrow) float[count]()); + if (buf == nullptr) { + return nullptr; + } + if (proto.double_data_size() > 0) { + // datatype double + if (count != proto.double_data_size()) { + // MS_LOGE("Convert weight fail, Blob size does not match shape size, shape size:%d, blob size:%d", count, + // proto.double_data_size()); + return nullptr; + } + + for (int i = 0; i < count; ++i) { + buf[i] = proto.double_data(i); + } + weight->data.resize(count * sizeof(float)); + ::memcpy_s(weight->data.data(), count * sizeof(float), + reinterpret_cast(buf.get()), + count * sizeof(float)); + } else { + // datatype float + if (count != proto.data_size()) { + // MS_LOGE("Convert weight fail, Blob size does not match shape size, shape size:%d, blob.data_size:%d", count, + // proto.data_size()); + return nullptr; + } + weight->data.resize(count * sizeof(float)); + const float *data_ptr = proto.data().data(); + ::memcpy_s(weight->data.data(), count * sizeof(float), (uint8_t *)data_ptr, count * sizeof(float)); + } + weight->refCount = 1; + + return weight.release(); +} + +STATUS ConvertShape(const caffe::BlobProto &proto, std::vector *shape) { + shape->clear(); + + if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { + // num, channels, height, width + shape->push_back(proto.num()); + shape->push_back(proto.channels()); + shape->push_back(proto.height()); + shape->push_back(proto.width()); + } else { + for (int i = 0; i < proto.shape().dim_size(); ++i) { + shape->push_back(proto.shape().dim(i)); + } + } +} +} // namespace lite +} // namespace mindspore +// + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h new file mode 100644 index 0000000000..e7b4f3d82b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_H_ + +#include +#include +#include "google/protobuf/message.h" +#include "mindspore/lite/schema/inner/model_generated.h" +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { + +class CaffeNodeParser { + public: + explicit CaffeNodeParser(const std::string &nodeName) : name(nodeName) {} + + virtual ~CaffeNodeParser() {} + + virtual int Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) = 0; + + protected: + const std::string &name; +}; + +schema::TensorT *ConvertWeight(const caffe::BlobProto &proto); + +STATUS ConvertShape(const caffe::BlobProto &proto, std::vector *shape); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc new file mode 100644 index 0000000000..33085e104b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +CaffeNodeParserRegistry::CaffeNodeParserRegistry() {} + +CaffeNodeParserRegistry::~CaffeNodeParserRegistry() {} + +CaffeNodeParserRegistry *CaffeNodeParserRegistry::GetInstance() { + static CaffeNodeParserRegistry instance; + return &instance; +} + +CaffeNodeParser *CaffeNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + return nullptr; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h new file mode 100644 index 0000000000..75e45bcdfa --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_REGISTRY_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_REGISTRY_H_ + +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "tools/converter/parser/caffe/caffe.pb.h" + +namespace mindspore::lite { +class CaffeNodeParserRegistry { + public: + CaffeNodeParserRegistry(); + + virtual ~CaffeNodeParserRegistry(); + + static CaffeNodeParserRegistry *GetInstance(); + + CaffeNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class CaffeNodeRegistrar { + public: + CaffeNodeRegistrar(const std::string &name, CaffeNodeParser *parser) { + CaffeNodeParserRegistry::GetInstance()->parsers[name] = parser; + } +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_REGISTRY_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc new file mode 100644 index 0000000000..d543578a79 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" +#include +#include +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/coded_stream.h" +#include "securec/include/securec.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +static const int PROTO_READ_BYTES_LIMIT = INT_MAX; // Max size of 2 GB minus 1 byte. +static const int WARNING_THRESHOLD = 536870912 * 2; + +bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, + google::protobuf::Message *proto) { + if (proto == nullptr) { + // MS_LOGE("incorrect parameter. nullptr == proto"); + return false; + } + coded_stream->SetTotalBytesLimit(PROTO_READ_BYTES_LIMIT, WARNING_THRESHOLD); + return proto->ParseFromCodedStream(coded_stream); +} + +STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message) { + if (file == nullptr || message == nullptr) { + return RET_ERROR; + } + + std::string realPath = RealPath(file); + if (realPath.empty()) { + // MS_LOGE("Proto file path is '%s' not valid", file); + return RET_ERROR; + } + + std::ifstream fs(realPath.c_str(), std::ifstream::in); + + if (!fs.is_open()) { + // MS_LOGE("Open proto file '%s' failed.", file); + return RET_ERROR; + } + + google::protobuf::io::IstreamInputStream input(&fs); + bool status = google::protobuf::TextFormat::Parse(&input, message); + if (status != true) { + // MS_LOGE("call [google::protobuf::TextFormat::Parse] func status fail, please check your text file."); + return RET_ERROR; + } + + fs.close(); + + return RET_OK; +} + +STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message) { + if (file == nullptr || message == nullptr) { + return RET_ERROR; + } + + std::string realPath = RealPath(file); + if (realPath.empty()) { + // MS_LOGE("Weight file path is '%s' not valid", file); + return RET_ERROR; + } + + std::ifstream fs(realPath, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) { + // MS_LOGE("Open weight file '%s' failed.", file); + return RET_ERROR; + } + + google::protobuf::io::IstreamInputStream istream(&fs); + google::protobuf::io::CodedInputStream coded_stream(&istream); + + bool success = ReadProtoFromCodedInputStream(&coded_stream, message); + fs.close(); + + if (!success) { + // MS_LOGE("Parse %s failed.", file); + return RET_ERROR; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h new file mode 100644 index 0000000000..3ee51440df --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ + +#include +#include +#include "google/protobuf/message.h" + +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, + google::protobuf::Message *proto); + +STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message); + +STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc new file mode 100644 index 0000000000..c80dd6b21d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h" +#include "utils/log_adapter.h" + +const uint32_t INNERPRODUCT_WINDOW_DEFAULT_VALUE = 0; +const uint32_t INNERPRODUCT_PAD_DEFAULT_VALUE = 0; + +namespace mindspore { +namespace lite { +STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::PoolingT()); + attr->format = schema::Format_NCHW; + + const caffe::PoolingParameter poolingParam = proto.pooling_param(); + + auto status = ParsePads(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParsePads for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + status = ParseStrides(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParseStrides for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + status = ParseWindows(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParseWindows for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + status = ParsePoolingMode(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParsePoolingMode for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + if (poolingParam.has_round_mode()) { + if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_FLOOR) { + attr->roundMode = schema::RoundMode_FLOOR; + } else if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_CEIL) { + attr->roundMode = schema::RoundMode_CEIL; + } else { + MS_ASSERT(false); + } + } + + attr->padMode = schema::PadMode_CAFFE; + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + return RET_OK; +} + +STATUS CaffePoolingParser::ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.has_pad_h() && poolingParam.has_pad_w()) { + if (poolingParam.has_pad()) { + // MS_LOGE("Either pad or pad_h/w should be specified; not both"); + return RET_ERROR; + } + attr->padLeft = poolingParam.pad_w(); + attr->padRight = poolingParam.pad_w(); + attr->padUp = poolingParam.pad_h(); + attr->padDown = poolingParam.pad_h(); + } else { + attr->padLeft = poolingParam.pad(); + attr->padRight = poolingParam.pad(); + attr->padUp = poolingParam.pad(); + attr->padDown = poolingParam.pad(); + } + return RET_OK; +} + +STATUS CaffePoolingParser::ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.has_stride_h() && poolingParam.has_stride_w()) { + if (poolingParam.has_stride()) { + // MS_LOGE("Either stride or stride_h/w should be specified; not both"); + return RET_ERROR; + } + attr->strideH = poolingParam.stride_h(); + attr->strideW = poolingParam.stride_w(); + } else { + attr->strideH = poolingParam.stride(); + attr->strideW = poolingParam.stride(); + } + return RET_OK; +} + +STATUS CaffePoolingParser::ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.has_global_pooling() && poolingParam.global_pooling()) { + if (poolingParam.has_kernel_size() || poolingParam.has_kernel_h() || poolingParam.has_kernel_w()) { + // MS_LOGE("With Global_pooling: true Filter size cannot specified"); + return RET_ERROR; + } + attr->windowH = INNERPRODUCT_WINDOW_DEFAULT_VALUE; + attr->windowW = INNERPRODUCT_WINDOW_DEFAULT_VALUE; + attr->global = true; + } else { + if (poolingParam.has_kernel_size() == (poolingParam.has_kernel_h() || poolingParam.has_kernel_w())) { + // MS_LOGE("Filter size is kernel_size OR kernel_h and kernel_w; not both"); + return RET_ERROR; + } + if (!poolingParam.has_kernel_size() && !(poolingParam.has_kernel_h() && poolingParam.has_kernel_w())) { + // MS_LOGE("For non-square filters both kernel_h and kernel_w are required."); + return RET_ERROR; + } + + if (poolingParam.has_kernel_h() && poolingParam.has_kernel_w()) { + attr->windowH = poolingParam.kernel_h(); + attr->windowW = poolingParam.kernel_w(); + } else { + attr->windowH = poolingParam.kernel_size(); + attr->windowW = poolingParam.kernel_size(); + } + } + return RET_OK; +} + +STATUS CaffePoolingParser::ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.pool() == caffe::PoolingParameter::MAX) { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + } else if (poolingParam.pool() == caffe::PoolingParameter::AVE) { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } else { + // MS_LOGE("Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."); + return RET_ERROR; + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffePoolingParser("Pooling", new CaffePoolingParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h new file mode 100644 index 0000000000..97d042f9ee --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POOLING_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POOLING_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffePoolingParser : public CaffeNodeParser { + public: + CaffePoolingParser() : CaffeNodeParser("pooling") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + STATUS ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); + + STATUS ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); + + STATUS ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); + + STATUS ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POOLING_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc new file mode 100644 index 0000000000..0336bd1c4f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h" + +static const float CAFFE_POWER_DEFAULT_POWER = 1.0; +static const float CAFFE_POWER_DEFAULT_SCALE = 1.0; +static const float CAFFE_POWER_DEFAULT_SHIFT = 0.0; + +namespace mindspore { +namespace lite { +STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::PowerT()); + const caffe::PowerParameter powerParam = proto.power_param(); + if (proto.has_power_param()) { + attr->power = powerParam.has_power() ? powerParam.power() : CAFFE_POWER_DEFAULT_POWER; + attr->scale = powerParam.has_scale() ? powerParam.scale() : CAFFE_POWER_DEFAULT_SCALE; + attr->shift = powerParam.has_shift() ? powerParam.shift() : CAFFE_POWER_DEFAULT_SHIFT; + } else { + attr->power = CAFFE_POWER_DEFAULT_POWER; + attr->scale = CAFFE_POWER_DEFAULT_SCALE; + attr->shift = CAFFE_POWER_DEFAULT_SHIFT; + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Power; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffePowerParser("Power", new CaffePowerParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h new file mode 100644 index 0000000000..c68b9dd9af --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POWER_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POWER_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffePowerParser : public CaffeNodeParser { + public: + CaffePowerParser() : CaffeNodeParser("power") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POWER_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc new file mode 100644 index 0000000000..1458d9415d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0f + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::CaffePReLUT()); + const caffe::PReLUParameter pReluParam = proto.prelu_param(); + + if (pReluParam.has_channel_shared()) { + attr->channelShared = pReluParam.channel_shared(); + } else { + attr->channelShared = false; + } + + if (weight.blobs_size() == 0) { + // MS_LOGE("PRelu No blobs data in layer %s", proto.name().c_str()); + return RET_ERROR; + } + + auto slope = ConvertWeight(weight.blobs(0)); + if (slope == nullptr) { + // MS_LOGE("CaffePRelu convert slope for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(slope); + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_CaffePReLU; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffePReluParser("PReLU", new CaffePReluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h new file mode 100644 index 0000000000..cfb53972fb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PRELU_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PRELU_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffePReluParser : public CaffeNodeParser { + public: + CaffePReluParser() : CaffeNodeParser("pRelu") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PRELU_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc new file mode 100644 index 0000000000..49ea560a5d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeReluParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_RELU; + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Activation; + // relu: negative_slope = 0, no parameter; + // leakyrelu: negative_slope != 0; + if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) { + float negative_slope = proto.relu_param().negative_slope(); + + if (0 != negative_slope) { + std::unique_ptr attrLeakyReLu(new schema::LeakyReLUT()); + attrLeakyReLu->negativeSlope = negative_slope; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LeakyReLU; + op->primitive->value.value = attrLeakyReLu.release(); + } + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffeReluParser("ReLU", new CaffeReluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h new file mode 100644 index 0000000000..618a53d694 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RELU_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RELU_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeReluParser : public CaffeNodeParser { + public: + CaffeReluParser() : CaffeNodeParser("relu") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RELU_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc new file mode 100644 index 0000000000..ee0e461e98 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ReshapeT()); + attr->format = schema::Format_NCHW; + + const caffe::ReshapeParameter reshapeParam = proto.reshape_param(); + + if (!reshapeParam.has_shape()) { + // MS_LOGE("Reshape has no shape info, ret fail"); + return RET_ERROR; + } + + const caffe::BlobShape &blob_shape = reshapeParam.shape(); + for (int i = 0; i < blob_shape.dim_size(); i++) { + attr->shape.push_back(blob_shape.dim(i)); + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeReshapeParser("Reshape", new CaffeReshapeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h new file mode 100644 index 0000000000..142751e457 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeReshapeParser : public CaffeNodeParser { + public: + CaffeReshapeParser() : CaffeNodeParser("reshape") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc new file mode 100644 index 0000000000..83198c2702 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h" + +const int32_t NCHW_DIM_C = 1; +const int32_t DIM_DEFAULT_SIZE = 4; + +namespace mindspore { +namespace lite { +STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + std::unique_ptr attr(new schema::ScaleT()); + + if (weight.blobs_size() + weight.bottom_size() < 2) { + // MS_LOGE("Scale bottom size:%d, blobs size:%d invalid in layer %s", weight.bottom_size(), weight.blobs_size(), + // weight.name().c_str()); + return RET_ERROR; + } + + const caffe::ScaleParameter scaleParam = weight.scale_param(); + int axis = NCHW_DIM_C; + if (scaleParam.has_axis()) { + uint32_t axis_index = NCHW_DIM_C; + if (GetAxisIndex(scaleParam.axis(), &axis_index)) { + // MS_LOGE("scale get axis failed for layer %s.", weight.name().c_str()); + } + } + attr->axis = axis; + + // parse scale + // todo expect only weight as scale not bias + if (weight.blobs().size() == 1) { + auto scale = ConvertWeight(weight.blobs(0)); + if (scale == nullptr) { + // MS_LOGE("Scale Convert blobs(0) for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(scale); + } else if (weight.blobs().size() >= 2) { + auto scale = ConvertWeight(weight.blobs(0)); + if (scale == nullptr) { + // MS_LOGE("Scale Convert blobs(0) for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(scale); + + // parse bias + bool scaleBias = scaleParam.bias_term(); + if (scaleBias) { + auto bias = ConvertWeight(weight.blobs_size() > 1 ? weight.blobs(1) : weight.blobs(0)); + if (bias == nullptr) { + // MS_LOGE("Scale Convert blobs(1) for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(bias); + } + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Scale; + return RET_OK; +} + +STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) { + if (axis < -DIM_DEFAULT_SIZE || axis >= DIM_DEFAULT_SIZE) { + // MS_LOGE("Scale axis value(%d) is not correct, ", axis); + return RET_PARAM_INVALID; + } + + if (axis == -1) { + // MS_LOGW("axis with -1 may lead to calculation errors when input less than 4 dims."); + } + + *axis_index = (axis + DIM_DEFAULT_SIZE) % DIM_DEFAULT_SIZE; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeScaleParser("Scale", new CaffeScaleParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h new file mode 100644 index 0000000000..cdd5c70726 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SCALE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SCALE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeScaleParser : public CaffeNodeParser { + public: + CaffeScaleParser() : CaffeNodeParser("scale") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SCALE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc new file mode 100644 index 0000000000..20c2590ff8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeSigmoidParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Activation; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeSigmoidParser("Sigmoid", new CaffeSigmoidParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h new file mode 100644 index 0000000000..5f795b11d3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SIGMOID_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SIGMOID_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeSigmoidParser : public CaffeNodeParser { + public: + CaffeSigmoidParser() : CaffeNodeParser("sigmoid") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SIGMOID_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc new file mode 100644 index 0000000000..399be822c1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h" +#include "utils/log_adapter.h" + +static const int32_t CAFFE_SOFTMAX_DEFAULT_AXIS = 1; + +namespace mindspore { +namespace lite { +STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::SoftMaxT()); + if (proto.has_softmax_param() && proto.softmax_param().has_axis()) { + if (proto.softmax_param().axis() == -1) { + MS_LOG(ERROR) << "axis with -1 may lead to calculation errors when input less than 4 dims."; + } + attr->axis = proto.softmax_param().axis(); + } else { + attr->axis = CAFFE_SOFTMAX_DEFAULT_AXIS; + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_SoftMax; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeSoftmaxParser("Softmax", new CaffeSoftmaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h new file mode 100644 index 0000000000..f8675d4fd5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SOFTMAX_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SOFTMAX_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeSoftmaxParser : public CaffeNodeParser { + public: + CaffeSoftmaxParser() : CaffeNodeParser("softmax") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SOFTMAX_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt b/mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt new file mode 100644 index 0000000000..d0b6cd3d52 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB_RECURSE ONNX_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) + +add_library(onnx_parser_mid OBJECT + ${ONNX_SRC_LIST} + ) diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx.proto b/mindspore/lite/tools/converter/parser/onnx/onnx.proto new file mode 100644 index 0000000000..093fcf99c0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx.proto @@ -0,0 +1,569 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + optional SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +message TensorAnnotation { + optional string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + optional TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + optional TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + optional TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + optional int32 key_type = 1; + // This field MUST be present for this version of the IR. + optional TypeProto value_type = 2; + }; + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +} \ No newline at end of file diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc new file mode 100644 index 0000000000..c6990ad085 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ArgMaxT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "keepdims") { + attr->keepDims = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ArgMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h new file mode 100644 index 0000000000..609aa53956 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ARGMAX_PARSER_H +#define MS_ONNX_ARGMAX_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxArgMaxParser : public OnnxNodeParser { + public: + OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc new file mode 100644 index 0000000000..44ee34f8fe --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -0,0 +1,270 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_RealDiv; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxMeanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mean; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Power; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Equal; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Less; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Greater; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Min; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + std::unique_ptr attr(new schema::EltwiseT()); + if (onnx_node.op_type() == "Prod") { + attr->mode = schema::EltwiseMode_PROD; + } else if (onnx_node.op_type() == "Prod") { + attr->mode = schema::EltwiseMode_SUM; + } else if (onnx_node.op_type() == "Sum") { + attr->mode = schema::EltwiseMode_MAXIMUM; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Eltwise; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Floor; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Abs; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Neg; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Exp; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Cos; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sin; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sqrt; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Ceil; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Log; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Tan; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Atan; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Asin; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); +OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); +OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); +OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); +OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); +OnnxNodeRegistrar g_onnxMeanParser("Mean", new OnnxMeanParser()); +OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser()); +OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); +OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); +OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); +OnnxNodeRegistrar g_onnxMinParser("Min", new OnnxMinParser()); +OnnxNodeRegistrar g_onnxProdParser("Prod", new OnnxEltwiseParser()); +OnnxNodeRegistrar g_onnxSumParser("Sum", new OnnxEltwiseParser()); +OnnxNodeRegistrar g_onnxMaxParser("Max", new OnnxEltwiseParser()); +OnnxNodeRegistrar g_onnxFloorParser("Floor", new OnnxFloorParser()); +OnnxNodeRegistrar g_onnxAbsParser("Abs", new OnnxAbsParser()); +OnnxNodeRegistrar g_onnxNegParser("Neg", new OnnxNegParser()); +OnnxNodeRegistrar g_onnxExpParser("Exp", new OnnxExpParser()); +OnnxNodeRegistrar g_onnxCosParser("Cos", new OnnxCosParser()); +OnnxNodeRegistrar g_onnxSinParser("Sin", new OnnxSinParser()); +OnnxNodeRegistrar g_onnxSqrtParser("Sqrt", new OnnxSqrtParser()); +OnnxNodeRegistrar g_onnxCeilParser("Ceil", new OnnxCeilParser()); +OnnxNodeRegistrar g_onnxLogParser("Log", new OnnxLogParser()); +OnnxNodeRegistrar g_onnxTanParser("Tan", new OnnxTanParser()); +OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser()); +OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser()); +OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h new file mode 100644 index 0000000000..bbb62e908f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -0,0 +1,171 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H +#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxAddParser : public OnnxNodeParser { + public: + OnnxAddParser() : OnnxNodeParser("Add") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxSubParser : public OnnxNodeParser { + public: + OnnxSubParser() : OnnxNodeParser("Sub") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxMulParser : public OnnxNodeParser { + public: + OnnxMulParser() : OnnxNodeParser("Mul") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxDivParser : public OnnxNodeParser { + public: + OnnxDivParser() : OnnxNodeParser("Div") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxMeanParser : public OnnxNodeParser { + public: + OnnxMeanParser() : OnnxNodeParser("Mean") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxPowParser : public OnnxNodeParser { + public: + OnnxPowParser() : OnnxNodeParser("Power") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxEqualParser : public OnnxNodeParser { + public: + OnnxEqualParser() : OnnxNodeParser("Equal") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxLessParser : public OnnxNodeParser { + public: + OnnxLessParser() : OnnxNodeParser("Less") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxGreaterParser : public OnnxNodeParser { + public: + OnnxGreaterParser() : OnnxNodeParser("Greater") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxMinParser : public OnnxNodeParser { + public: + OnnxMinParser() : OnnxNodeParser("Min") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxEltwiseParser : public OnnxNodeParser { + public: + OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxFloorParser : public OnnxNodeParser { + public: + OnnxFloorParser() : OnnxNodeParser("Floor") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxAbsParser : public OnnxNodeParser { + public: + OnnxAbsParser() : OnnxNodeParser("Abs") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxNegParser : public OnnxNodeParser { + public: + OnnxNegParser() : OnnxNodeParser("Neg") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxExpParser : public OnnxNodeParser { + public: + OnnxExpParser() : OnnxNodeParser("Exp") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxCosParser : public OnnxNodeParser { + public: + OnnxCosParser() : OnnxNodeParser("Cos") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxSinParser : public OnnxNodeParser { + public: + OnnxSinParser() : OnnxNodeParser("Sin") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxSqrtParser : public OnnxNodeParser { + public: + OnnxSqrtParser() : OnnxNodeParser("Sqrt") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxCeilParser : public OnnxNodeParser { + public: + OnnxCeilParser() : OnnxNodeParser("Ceil") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxLogParser : public OnnxNodeParser { + public: + OnnxLogParser() : OnnxNodeParser("Log") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxTanParser : public OnnxNodeParser { + public: + OnnxTanParser() : OnnxNodeParser("Tan") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxAtanParser : public OnnxNodeParser { + public: + OnnxAtanParser() : OnnxNodeParser("Atan") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxAsinParser : public OnnxNodeParser { + public: + OnnxAsinParser() : OnnxNodeParser("Asin") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxTanhParser : public OnnxNodeParser { + public: + OnnxTanhParser() : OnnxNodeParser("Tanh") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc new file mode 100644 index 0000000000..d4ea5cdde5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::FusedBatchNormT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "epsilon") { + attr->epsilon = onnx_node_attr.f(); + } else if (onnx_node_attr.name() == "momentum") { + attr->momentum = onnx_node_attr.f(); + } else if (onnx_node_attr.name() == "spatial") { + attr->spatial = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h new file mode 100644 index 0000000000..c6b6fdb70c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ADD_PARSER_H +#define MS_ONNX_ADD_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxBatchNormParser : public OnnxNodeParser { + public: + OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc new file mode 100644 index 0000000000..be5229ed07 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h" + +// using namespace mindspore::predict; +// using namespace onnx; +// using namespace std; +namespace mindspore { +namespace lite { +STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::BiasAddT()); + // use channel dim as axis + attr->axis = {1}; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_BiasAdd; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h new file mode 100644 index 0000000000..06497be3f0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_BIASADD_PARSER_H +#define MS_ONNX_BIASADD_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxBiasAddParser : public OnnxNodeParser { + public: + OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_BIASADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc new file mode 100644 index 0000000000..66ef7542eb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::CastT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "to") { + attr->dstT = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Cast; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h new file mode 100644 index 0000000000..8d028379aa --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CAST_PARSER_H +#define MS_ONNX_CAST_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxCastParser : public OnnxNodeParser { + public: + OnnxCastParser() : OnnxNodeParser("Cast") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CAST_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc new file mode 100644 index 0000000000..d27994b365 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::ClipT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "max") { + attr->max = onnx_node_attr.f(); + } else if (attribute_name == "min") { + attr->min = onnx_node_attr.f(); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Clip; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h new file mode 100644 index 0000000000..00532e73eb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CLIP_PARSER_H +#define MS_ONNX_CLIP_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxClipParser : public OnnxNodeParser { + public: + OnnxClipParser() : OnnxNodeParser("Clip") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc new file mode 100644 index 0000000000..2054999192 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ConcatT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h new file mode 100644 index 0000000000..b38039cd7b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONCAT_PARSER_H +#define MS_ONNX_CONCAT_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConcatParser : public OnnxNodeParser { + public: + OnnxConcatParser() : OnnxNodeParser("Concat") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CONCAT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc new file mode 100644 index 0000000000..ff618aba0a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Constant; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h new file mode 100644 index 0000000000..0356057b28 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONSTANT_PARSER_H +#define MS_ONNX_CONSTANT_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConstantParser : public OnnxNodeParser { + public: + OnnxConstantParser() : OnnxNodeParser("Constant") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CONSTANT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc new file mode 100644 index 0000000000..89ed7efb86 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h" + +namespace mindspore { +namespace lite { +bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { + if (attr == nullptr || attr->group != attr->channelIn) { + return false; + } + std::unique_ptr depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT()); + if (depthwiseConv2DParam == nullptr) { + // MS_LOGW("new DepthwiseConv2DT failed"); + return false; + } + depthwiseConv2DParam->format = attr->format; + depthwiseConv2DParam->channelIn = attr->channelIn; + depthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn; + depthwiseConv2DParam->kernelW = attr->kernelW; + depthwiseConv2DParam->kernelH = attr->kernelH; + depthwiseConv2DParam->strideW = attr->strideW; + depthwiseConv2DParam->strideH = attr->strideH; + depthwiseConv2DParam->padMode = attr->padMode; + depthwiseConv2DParam->padUp = attr->padUp; + depthwiseConv2DParam->padDown = attr->padDown; + depthwiseConv2DParam->padLeft = attr->padLeft; + depthwiseConv2DParam->padRight = attr->padRight; + depthwiseConv2DParam->dilateW = attr->dilateW; + depthwiseConv2DParam->dilateH = attr->dilateH; + depthwiseConv2DParam->hasBias = attr->hasBias; + depthwiseConv2DParam->activationType = attr->activationType; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + delete (op->primitive->value.value); + op->primitive->value.value = depthwiseConv2DParam.release(); + return true; +} + +STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + auto attr = new schema::Conv2DT(); + // set opdef each attr params + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "group") { + attr->group = static_cast(onnx_node_attr.i()); + } else if (onnx_node_attr.name() == "dilations") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->dilateW = static_cast(onnx_node_attr.ints(0)); + attr->dilateH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernels") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelH = static_cast(onnx_node_attr.ints(0)); + attr->kernelW = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernel_shape") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelW = static_cast(onnx_node_attr.ints(0)); + attr->kernelH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "auto_pad") { + attr->padMode = GetOnnxPadMode(onnx_node_attr); + } else if (onnx_node_attr.name() == "pads") { + if (onnx_node_attr.ints().size() != 4) { + // MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->padUp = static_cast(onnx_node_attr.ints(0)); + attr->padLeft = static_cast(onnx_node_attr.ints(1)); + attr->padDown = static_cast(onnx_node_attr.ints(2)); + attr->padRight = static_cast(onnx_node_attr.ints(3)); + } else if (onnx_node_attr.name() == "strides") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->strideW = static_cast(onnx_node_attr.ints(0)); + attr->strideH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "order") { + if (onnx_node_attr.s() == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + return RET_ERROR; + } + } + } + + const auto &onnx_conv_weight = onnx_node.input(1); + if (onnx_node.op_type() == "Conv") { + auto nodeIter = + std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); + if (nodeIter == onnx_graph.initializer().end()) { + // MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) + return RET_ERROR; + } + std::vector weight_shape; + auto size = (*nodeIter).dims_size(); + for (int i = 0; i < size; ++i) { + weight_shape.emplace_back((*nodeIter).dims(i)); + } + attr->channelOut = weight_shape[0]; + attr->channelIn = weight_shape[1] * attr->group; + } else { + auto nodeIter = + std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), + [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); + if (nodeIter == onnx_graph.node().end()) { + // MS_LOGE("can not find node: %s", onnx_conv_weight.c_str()) + return RET_ERROR; + } + std::vector dims; + auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); + if (iter != (*nodeIter).attribute().end()) { + dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); + } + attr->channelOut = dims[0]; + attr->channelIn = dims[3] * attr->group; + } + attr->format = schema::Format_NCHW; + attr->hasBias = onnx_node.input().size() == 3; + if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") { + attr->activationType = schema::ActivationType_RELU; + } else { + attr->activationType = schema::ActivationType_NO_ACTIVATION; + } + + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = attr; + + if (attr->group != 1) { + if (!ParseGroupConvolution(op, attr)) { + delete attr; + // MS_LOGE("Convert Convolution to Depthwise failed"); + return RET_ERROR; + } + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); +OnnxNodeRegistrar g_onnxInt8ConvParser("Int8Conv", new OnnxConvParser()); +OnnxNodeRegistrar g_onnxConvReluParser("ConvRelu", new OnnxConvParser()); +OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h new file mode 100644 index 0000000000..73fa7e531c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONV_PARSER_H +#define MS_ONNX_CONV_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConvParser : public OnnxNodeParser { + public: + OnnxConvParser() : OnnxNodeParser("Conv") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + private: + bool ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc new file mode 100755 index 0000000000..2e7ecb90d4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_converter.h" + +namespace mindspore { +namespace lite { +OnnxConverter::OnnxConverter() { + modelParser = new OnnxModelParser(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h new file mode 100755 index 0000000000..a6fbc75172 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONVERTER_H +#define MS_ONNX_CONVERTER_H +#include +#include +#include "mindspore/lite/tools/converter/converter.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h" +#include "mindspore/lite/tools/converter/graphdef_transform.h" + +namespace mindspore { +namespace lite { +class OnnxConverter : public Converter { + public: + OnnxConverter(); + + ~OnnxConverter() override = default; +}; +} // namespace lite +} // namespace mindspore + +#endif // MS_ONNX_CONVERTER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc new file mode 100644 index 0000000000..7e6e021e98 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h" + +namespace mindspore { +namespace lite { +bool OnnxDeConvParser::ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr) { + if (attr == nullptr || attr->group != attr->channelOut) { + return false; + } + auto deDepthwiseConv2DParam(new (std::nothrow) schema::DeDepthwiseConv2DT()); + if (deDepthwiseConv2DParam == nullptr) { + // MS_LOGW("new DeDepthwiseConv2DT failed"); + return false; + } + deDepthwiseConv2DParam->format = attr->format; + deDepthwiseConv2DParam->channelIn = attr->channelIn; + deDepthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn; + deDepthwiseConv2DParam->kernelW = attr->kernelW; + deDepthwiseConv2DParam->kernelH = attr->kernelH; + deDepthwiseConv2DParam->strideW = attr->strideW; + deDepthwiseConv2DParam->strideH = attr->strideH; + deDepthwiseConv2DParam->padMode = attr->padMode; + deDepthwiseConv2DParam->padUp = attr->padUp; + deDepthwiseConv2DParam->padDown = attr->padDown; + deDepthwiseConv2DParam->padLeft = attr->padLeft; + deDepthwiseConv2DParam->padRight = attr->padRight; + deDepthwiseConv2DParam->dilateW = attr->dilateW; + deDepthwiseConv2DParam->dilateH = attr->dilateH; + deDepthwiseConv2DParam->hasBias = attr->hasBias; + deDepthwiseConv2DParam->activationType = attr->activationType; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + delete (op->primitive->value.value); + op->primitive->value.value = deDepthwiseConv2DParam; + } + return true; +} + +STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + auto attr = new schema::DeConv2DT(); + // set opdef each attr params + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "group") { + attr->group = static_cast(onnx_node_attr.i()); + } else if (onnx_node_attr.name() == "dilations") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->dilateW = static_cast(onnx_node_attr.ints(0)); + attr->dilateH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernels") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelH = static_cast(onnx_node_attr.ints(0)); + attr->kernelW = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernel_shape") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelW = static_cast(onnx_node_attr.ints(0)); + attr->kernelH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "auto_pad") { + attr->padMode = GetOnnxPadMode(onnx_node_attr); + } else if (onnx_node_attr.name() == "pads") { + if (onnx_node_attr.ints().size() != 4) { + // MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->padUp = static_cast(onnx_node_attr.ints(0)); + attr->padLeft = static_cast(onnx_node_attr.ints(1)); + attr->padDown = static_cast(onnx_node_attr.ints(2)); + attr->padRight = static_cast(onnx_node_attr.ints(3)); + } else if (onnx_node_attr.name() == "strides") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->strideW = static_cast(onnx_node_attr.ints(0)); + attr->strideH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "order") { + if (onnx_node_attr.s() == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + return RET_ERROR; + } + } + } + + const auto &onnx_conv_weight = onnx_node.input(1); + auto nodeIter = + std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); + if (nodeIter == onnx_graph.initializer().end()) { + // MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) + return RET_ERROR; + } + std::vector weight_shape; + auto size = (*nodeIter).dims_size(); + for (int i = 0; i < size; ++i) { + weight_shape.emplace_back((*nodeIter).dims(i)); + } + MS_ASSERT(weight_shape.size() == 4); + attr->channelIn = weight_shape[0]; + attr->channelOut = weight_shape[1] * attr->group; + + attr->format = schema::Format_NCHW; + attr->hasBias = onnx_node.input().size() == 3; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeConv2D; + op->primitive->value.value = attr; + } + + if (attr->group != 1) { + if (!ParseGroupDeConvolution(op, attr)) { + delete attr; + // MS_LOGE("Convert DeConvolution to DeDepthwise failed"); + return RET_ERROR; + } + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h new file mode 100644 index 0000000000..b4fba8bf4a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_DECONV_PARSER_H +#define MS_ONNX_DECONV_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxDeConvParser : public OnnxNodeParser { + public: + OnnxDeConvParser() : OnnxNodeParser("DeConv") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + private: + bool ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr); +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_DECONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc new file mode 100644 index 0000000000..ee819cb02d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::DepthToSpaceT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "blocksize") { + attr->blockSize = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthToSpace; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h new file mode 100644 index 0000000000..834d71ccc9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H +#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxDepthToSpaceParser : public OnnxNodeParser { + public: + OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc new file mode 100644 index 0000000000..451a4c4a69 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::DropoutT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "ratio") { + attr->ratio = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Dropout; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h new file mode 100644 index 0000000000..14898f4616 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ARGMAX_PARSER_H +#define MS_ONNX_ARGMAX_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxDropoutParser : public OnnxNodeParser { + public: + OnnxDropoutParser() : OnnxNodeParser("Dropout") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc new file mode 100644 index 0000000000..b1497c68a3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::EluT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "alpha") { + attr->alpha = onnx_node_attr.f(); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Elu; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h new file mode 100644 index 0000000000..4267609791 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ELU_PARSER_H +#define MS_ONNX_ELU_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxEluParser : public OnnxNodeParser { + public: + OnnxEluParser() : OnnxNodeParser("Elu") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ELU_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc new file mode 100644 index 0000000000..ad155fbf4e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Broadcast; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h new file mode 100644 index 0000000000..604281dbfb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_EXPAND_PARSER_H +#define MS_ONNX_EXPAND_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxExpandParser : public OnnxNodeParser { + public: + OnnxExpandParser() : OnnxNodeParser("Expand") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_EXPAND_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc new file mode 100644 index 0000000000..d06c620731 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ReshapeT()); + int axis = 1; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + axis = static_cast(onnx_node_attr.i()); + } + } + for (int i = 0; i < axis; ++i) { + attr->shape.emplace_back(0); + } + attr->shape.emplace_back(-1); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h new file mode 100644 index 0000000000..cacecc3758 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_FLATTEN_PARSER_H +#define MS_ONNX_FLATTEN_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxFlattenParser : public OnnxNodeParser { + public: + OnnxFlattenParser() : OnnxNodeParser("Fatten") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_FLATTEN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc new file mode 100644 index 0000000000..b9afef1a25 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::GatherT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Gather; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h new file mode 100644 index 0000000000..ef2d306f59 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_GATHER_PARSER_H +#define MS_ONNX_GATHER_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxGatherParser : public OnnxNodeParser { + public: + OnnxGatherParser() : OnnxNodeParser("Gather") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_GATHER_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc new file mode 100644 index 0000000000..0590a291e6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::LrnT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "size") { + attr->size = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "alpha") { + attr->alpha = onnx_node_attr.f(); + } else if (attribute_name == "beta") { + attr->beta = onnx_node_attr.f(); + } else if (attribute_name == "bias") { + attr->bias = onnx_node_attr.f(); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Lrn; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h new file mode 100644 index 0000000000..e3b15045a2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_LRN_PARSER_H +#define MS_ONNX_LRN_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxLrnParser : public OnnxNodeParser { + public: + OnnxLrnParser() : OnnxNodeParser("Lrn") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_LRN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc new file mode 100644 index 0000000000..857d38e207 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::MatMulT()); + float alpha = 1.0f; + float beta = 1.0f; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "transA") { + attr->transposeA = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "transB") { + attr->transposeB = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "alpha") { + alpha = onnx_node_attr.f(); + } else if (attribute_name == "beta") { + beta = onnx_node_attr.f(); + } + } + if (alpha != 1 || beta != 1) { + // MS_LOGE("not support alpha * A * B + beta * C"); + return RET_PARAM_INVALID; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_MatMul; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h new file mode 100644 index 0000000000..9c7565ded0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_MATMUL_PARSER_H +#define MS_ONNX_MATMUL_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxMatmulParser : public OnnxNodeParser { + public: + OnnxMatmulParser() : OnnxNodeParser("MatMul") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_MATMUL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc new file mode 100755 index 0000000000..2f06beb0a5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -0,0 +1,512 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h" +#include "tools/common/graph_util.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +OnnxModelParser::OnnxModelParser() = default; +OnnxModelParser::~OnnxModelParser() = default; + +static const std::unordered_map TYPE_MAP = { + {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8}, + {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8}, + {onnx::TensorProto_DataType_INT16, mindspore::kNumberTypeInt16}, + {onnx::TensorProto_DataType_INT32, mindspore::kNumberTypeInt32}, + {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32}, + {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64}, + {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat}}; + +TypeId OnnxModelParser::GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { + auto iter = TYPE_MAP.find(onnx_type); + if (iter == TYPE_MAP.end()) { + return kTypeUnknown; + } + return iter->second; +} + +std::vector OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value) { + std::vector dims; + const auto shape_info = onnx_value.type().tensor_type().shape(); + for (const auto &it : onnx_value.type().tensor_type().shape().dim()) { + dims.emplace_back(it.dim_value()); + } + return dims; +} + +STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { + std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); + if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) { + // MS_LOGE("get realpath %s fail", modelFile.c_str()); + return RET_ERROR; + } + int fd = open(onnx_file.get(), O_RDONLY); + google::protobuf::io::FileInputStream input(fd); + google::protobuf::io::CodedInputStream code_input(&input); + code_input.SetTotalBytesLimit(INT_MAX, 536870912); + bool ret = onnx_model->ParseFromCodedStream(&code_input); + if (!ret) { + // MS_LOGE("load onnx file failed"); + return RET_ERROR; + } + (void)close(fd); + return RET_OK; +} + +STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { + // MS_LOGD("set onnx constant tensors"); + for (const auto &onnx_const_value : onnx_graph.initializer()) { + std::vector dims; + std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(dims)); + auto data_type = GetDateTypeFromOnnx(static_cast(onnx_const_value.data_type())); + if (data_type == kTypeUnknown) { + // MS_LOGE("not support onnx type %d", static_cast(onnx_const_value.data_type())); + return RET_ERROR; + } + std::unique_ptr tensor(new (std::nothrow) schema::TensorT); + if (tensor == nullptr) { + // MS_LOGE("new tensor failed"); + return RET_ERROR; + } + tensor->dataType = data_type; + tensor->format = schema::Format_NCHW; + for (const auto &it : dims) { + tensor->dims.emplace_back(it); + } + tensor->nodeType = schema::NodeType_ValueNode; + if (CopyOnnxTensorData(onnx_const_value, tensor.get())) { + return RET_ERROR; + } + const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT); + // MS_LOGD("add const tensor: %s, index %d", onnx_const_value.name().c_str(), index) + } + return RET_OK; +} + +STATUS OnnxModelParser::AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor) { + auto data_type = GetDateTypeFromOnnx(static_cast(proto.type().tensor_type().elem_type())); + if (data_type == kTypeUnknown) { + // MS_LOGE("not support onnx type %d", + // static_cast(proto.type().tensor_type().elem_type())); + return RET_ERROR; + } + tensor->dataType = data_type; + tensor->dims = GetDimsFromOnnxValue(proto); + tensor->format = schema::Format_NCHW; + tensor->nodeType = schema::NodeType_ValueNode; + return RET_OK; +} + +STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, + TensorCache *tensor_cache) { + for (const auto &input_value : onnx_graph.input()) { + auto ret = tensor_cache->FindTensor(input_value.name()); + if (ret < 0) { + std::unique_ptr tensor(new schema::TensorT); + if (AddTensorCache(input_value, tensor.get())) { + return RET_ERROR; + } + auto tensor_index = tensor_cache->AddTensor(input_value.name(), tensor.release(), GRAPH_INPUT); + graph->inputIndex.emplace_back(static_cast(tensor_index)); + // MS_LOGD("input_value name: %s, graph input index: %d", input_value.name().c_str(), tensor_index); + } + } + return RET_OK; +} + +STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, + TensorCache *tensor_cache) { + for (const auto &output_value : onnx_graph.output()) { + std::unique_ptr tensor(new schema::TensorT); + if (AddTensorCache(output_value, tensor.get())) { + return RET_ERROR; + } + auto tensor_index = tensor_cache->AddTensor(output_value.name(), tensor.release(), OP_OUTPUT); + graph->outputIndex.emplace_back(tensor_index); + // MS_LOGD("output_value name: %s, graph output index: %d", output_value.name().c_str(), tensor_index); + } + return RET_OK; +} + +void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, TensorCache *tensor_cache) { + std::unique_ptr dst_op_1(new schema::CNodeT); + dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); + // dst_op_1->fmkType = FmkType_ONNX; + ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); + auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0); + std::vector matmul_inputs{onnx_node.input(0), onnx_node.input(1)}; + std::vector matmul_outputs{matmul_output_id}; + SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache); + SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache); + graph->nodes.emplace_back(std::move(dst_op_1)); + + std::unique_ptr dst_op_2(new schema::CNodeT); + dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0); + // dst_op_2->fmkType = FmkType_ONNX; + ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get()); + std::vector biasadd_inputs{matmul_output_id, onnx_node.input(2)}; + std::vector biasadd_outputs{onnx_node.output(0)}; + SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache); + SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache); + graph->nodes.emplace_back(std::move(dst_op_2)); +} + +STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { + // convert GivenTensorFill node to a weight/bias tensor + auto ret = tensor_cache->FindTensor(onnx_node.output(0)); + if (ret < 0) { + std::unique_ptr tensor(new schema::TensorT); + std::vector shape; + auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); + if (iter != onnx_node.attribute().end()) { + (void)shape.insert(shape.begin(), iter->ints().begin(), iter->ints().end()); + std::for_each(shape.begin(), shape.end(), [](int sh) { /*MS_LOGD("shape: %d", sh);*/ }); + } + tensor->dims = shape; + tensor->format = schema::Format_NUM_OF_FORMAT; + tensor->nodeType = schema::NodeType_ValueNode; + iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); + // copy GivenIntTensorFill node value to tensor + if (iter != onnx_node.attribute().end()) { + size_t data_count = 1; + std::for_each(shape.begin(), shape.end(), [&data_count](int dim) { data_count *= dim; }); + size_t data_size = 0; + if (onnx_node.op_type() == "Int8GivenIntTensorFill") { + // todo how to read onnx-ori-dataType + tensor->dataType = kNumberTypeInt32; + data_size = data_count * sizeof(int32_t) / sizeof(uint8_t); + tensor->data.resize(data_size); + void *tensorData = tensor->data.data(); + auto castedTensorData = static_cast(tensorData); + MS_ASSERT(castedTensorData != nullptr); + for (size_t i = 0; i < data_count; i++) { + castedTensorData[i] = int32_t(iter->ints().data()[i]); + } + } else if (onnx_node.op_type() == "Int8GivenTensorFill") { + // todo how to read onnx-ori-dataType + tensor->dataType = kNumberTypeUInt8; + // todo: add * sizof(string) + data_size = data_count; + tensor->data.resize(data_size); + // MS_LOGD("tensor data size %lu, s: %lu", data_size, sizeof(iter->s().data())); + if (memcpy_s(tensor->data.data(), data_size, iter->s().data(), data_size) != 0) { + // MS_LOGE("memcpy_s failed") + return RET_ERROR; + } + } else { + // MS_LOGE("unsupported data type %d", tensor->dataType); + return RET_ERROR; + } + } + auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); + // MS_LOGD("add given tensor: %d", index); + } + return RET_OK; +} + +STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, + TensorCache *tensor_cache) { + // change op_type() to name(), that is unique + dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); + // dst_op->fmkType = FmkType_ONNX; + // MS_LOGD("onnx op name %s, dst op name: %s, input size %d", onnx_node.op_type().c_str(), dst_op->name.c_str(), + // onnx_node.input_size()); + // get the real op type + SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); + auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op); + if (status != RET_OK) { + // MS_LOGE("parser onnx node attr failed"); + return status; + } + // set op input index + std::vector node_inputs; + (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); + if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { + // MS_LOGE("SetOpInputIndex failed"); + return RET_ERROR; + } + // set op output index + std::vector node_outputs; + (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); + if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { + // MS_LOGE("SetOpOutputIndex failed"); + return RET_ERROR; + } + return RET_OK; +} + +void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { + MS_ASSERT(dst_op != nullptr); + MS_ASSERT(tensor_cache != nullptr); + std::vector quant_node_name; + quant_node_name.insert(quant_node_name.begin(), onnx_node.input().begin(), onnx_node.input().end()); + quant_node_name.insert(quant_node_name.end(), onnx_node.output().begin(), onnx_node.output().end()); + std::vector quant_node; + for (const auto &str : quant_node_name) { + for (auto &node : onnx_graph.node()) { + if (node.output(0) == str) { + quant_node.emplace_back(node); + break; + } + } + } + auto needQuantParams = size_t(onnx_node.input().size() + onnx_node.output().size()); + for (auto iter = onnx_node.input().begin(); iter != onnx_node.input().end(); iter++) { + if (IsContain(this->graphInputNames, *iter)) { + needQuantParams--; + } + } + size_t findQuantParams = 0; + for (const auto &node : quant_node) { + std::unique_ptr quant_param(new (std::nothrow) schema::QuantParamT()); + if (quant_param == nullptr) { + // MS_LOGE("new QuantParamT failed, node: %s", dst_op->name.c_str()); + return; + } + // std::unique_ptr quant_param_array(new (std::nothrow) QuantParamArrayT()); + if (quant_param == nullptr) { + // MS_LOGE("new QuantParamArrayT failed, node: %s", dst_op->name.c_str()); + return; + } + int argNum = 0; + for (const auto &onnx_node_attr : node.attribute()) { + if (onnx_node_attr.name() == "Y_scale") { + quant_param->scale = onnx_node_attr.f(); + argNum++; + } else if (onnx_node_attr.name() == "Y_zero_point") { + quant_param->zeroPoint = static_cast(onnx_node_attr.i()); + argNum++; + } + } + if (argNum != 2) { + quant_param->scale = FLT_MAX; + quant_param->zeroPoint = 0; + quant_param->min = FLT_MAX; + quant_param->max = FLT_MAX; + } + // quant_param_array->param.emplace_back(std::move(quant_param)); + dst_tensor->quantParams.emplace_back(std::move(quant_param)); + if (argNum == 2) { + findQuantParams++; + } + } + if (findQuantParams == needQuantParams) { + dst_op->quantType = schema::QuantType_AwareTrainning; + } else { + dst_op->quantType = schema::QuantType_QUANT_NONE; + } +} + +STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + const string &onnx_op_type, schema::CNodeT *dst_op) { + auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); + if (node_parser == nullptr) { + // MS_LOGE("not find %s, node parser is nullptr", onnx_op_type.c_str()); + return RET_NULL_PTR; + } + return node_parser->Parse(onnx_graph, onnx_node, dst_op); +} + +STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { + schema::Format format = schema::Format_MAX; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "order") { + if (onnx_node_attr.s() == "NHWC") { + format = schema::Format_NHWC; + } else { + // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + return RET_ERROR; + } + } + } + for (const auto &onnx_node_input : node_inputs) { + auto index = tensor_cache->FindTensor(onnx_node_input); + if (index < 0) { + std::unique_ptr tensor(new schema::TensorT); + index = tensor_cache->AddTensor(onnx_node_input, tensor.release(), OP_OUTPUT); + } + if (format != schema::Format_MAX) { + auto inTensor = tensor_cache->GetCachedTensor().at(index); + inTensor->format = format; + } + // MS_LOGD("node: %s, input index: %d", onnx_node_input.c_str(), index); + dst_op->inputIndex.emplace_back(index); + } + return RET_OK; +} + +STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, + TensorCache *tensor_cache) { + for (const auto &onnx_node_output : node_outputs) { + auto index = tensor_cache->FindTensor(onnx_node_output); + if (index < 0) { + std::unique_ptr tensor(new schema::TensorT); + index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); + } + // MS_LOGD("node: %s, input index: %d", onnx_node_output.c_str(), index); + dst_op->outputIndex.emplace_back(index); + } + return RET_OK; +} + +STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, + schema::TensorT *tensor) { + size_t data_count = 1; + std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); + size_t data_size = 0; + const void *tensor_data = nullptr; + switch (tensor->dataType) { + case kNumberTypeFloat: + data_size = data_count * sizeof(float); + if (onnx_const_value.float_data_size() == 0) { + tensor_data = onnx_const_value.raw_data().data(); + } else { + tensor_data = onnx_const_value.float_data().data(); + } + break; + case kNumberTypeInt32: + data_size = data_count * sizeof(int); + if (onnx_const_value.int32_data_size() == 0) { + tensor_data = onnx_const_value.raw_data().data(); + } else { + tensor_data = onnx_const_value.int32_data().data(); + } + break; + case kNumberTypeInt64: + data_size = data_count * sizeof(int64_t); + if (onnx_const_value.int64_data_size() == 0) { + tensor_data = onnx_const_value.raw_data().data(); + } else { + tensor_data = onnx_const_value.int64_data().data(); + } + break; + case kNumberTypeUInt8: + case kNumberTypeInt8: + data_size = data_count * sizeof(uint8_t); + tensor_data = onnx_const_value.raw_data().data(); + break; + default: + // MS_LOGE("unsupported data type %d", tensor->dataType); + return RET_ERROR; + } + tensor->data.resize(data_size); + if (memcpy_s(static_cast(tensor->data.data()), data_size, tensor_data, data_size) != 0) { + // MS_LOGE("memcpy_s failed") + return RET_ERROR; + } + return RET_OK; +} + +STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { + std::vector tensors = tensor_cache.GetCachedTensor(); + for (auto iter : tensors) { + std::unique_ptr temp(iter); + graphDef->allTensors.emplace_back(move(temp)); + } + return RET_OK; +} + +void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) { + this->graphInputNames.clear(); + this->graphConstNames.clear(); + for (auto &onnx_const : onnx_graph.initializer()) { + this->graphConstNames.emplace_back(onnx_const.name()); + } + for (auto &onnx_input : onnx_graph.input()) { + if (!IsContain(this->graphConstNames, onnx_input.name())) { + this->graphInputNames.emplace_back(onnx_input.name()); + } + } +} + +MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { + if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { + // MS_LOGE("Input illegal: modelFile must be *.onnx"); + return nullptr; + } + std::unique_ptr dst_graph(new schema::MetaGraphT()); + onnx::ModelProto onnx_model; + if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { + // MS_LOGE("read onnx model fail"); + return nullptr; + } + const onnx::GraphProto &onnx_graph = onnx_model.graph(); + // MS_LOGI("model producer name: %s, graph name: %s", onnx_model.producer_name().c_str(), onnx_graph.name().c_str()); + TensorCache tensor_cache; + dst_graph->name = onnx_graph.name(); + // find out input names and const names + FindGraphInputAndConst(onnx_graph); + // set const tensor + if (SetGraphConstTensor(onnx_graph, &tensor_cache)) { + // MS_LOGE("SetGraphConstTensor failed"); + return nullptr; + } + // init onnx model graph input tensor + if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { + // MS_LOGE("SetGraphInputTensor failed"); + return nullptr; + } + // init onnx model graph output tensor + if (SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { + // MS_LOGE("SetGraphOutputTensor failed"); + return nullptr; + } + // init op node input/output tensor, and dst_op attr + for (const auto &onnx_node : onnx_graph.node()) { + if (onnx_node.op_type() == "Gemm") { + ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); + continue; + } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { + auto status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); + if (status != RET_OK) { + // MS_LOGE("ParseOnnxGivenFillNode failed: %d", status); + return nullptr; + } + continue; + } + + std::unique_ptr dst_op(new schema::CNodeT); + std::unique_ptr dst_tensor(new schema::TensorT); + if (ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache)) { + // MS_LOGE("parse node %s failed", onnx_node.op_type().c_str()) + return nullptr; + } + dst_graph->nodes.emplace_back(std::move(dst_op)); + } + SetAllTensors(tensor_cache, dst_graph.get()); + dst_graph->mempoolSize = 0; + dst_graph->name = GetModelName(modelFile); + return dst_graph.release(); +// return Fb2Anf(dst_graph.release()); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h new file mode 100644 index 0000000000..f179082a70 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_MODEL_PARSER_H +#define MS_ONNX_MODEL_PARSER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "securec/include/securec.h" +#include "mindspore/lite/tools/converter/model_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class OnnxModelParser : public ModelParser { + public: + OnnxModelParser(); + virtual ~OnnxModelParser(); + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; + + private: + TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type); + std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); + STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto); + STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); + STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); + STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); + STATUS AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor); + STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, TensorCache *tensor_cache); + void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, + TensorCache *tensor_cache); + STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); + STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + const string &onnx_op_type, schema::CNodeT *dst_op); + void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, TensorCache *tensor_cache); + STATUS SetOpInputIndex(const std::vector &node_inputs, + schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, + TensorCache *tensor_cache); + STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); + STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); + STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); + void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); + + private: + std::vector graphInputNames; + std::vector graphConstNames; +}; +} // namespace lite +} // namespace mindspore + +#endif // MS_ONNX_MODEL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc new file mode 100644 index 0000000000..cd225232cf --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" + +namespace mindspore { +namespace lite { +schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) { + if (onnx_node_attr.s() == "NOTSET") { + return schema::PadMode_NOTSET; + } else if (onnx_node_attr.s() == "SAME_UPPER" || onnx_node_attr.s() == "SAME_LOWER") { + return schema::PadMode_SAME; + } else if (onnx_node_attr.s() == "VALID") { + return schema::PadMode_VALID; + } else { + // MS_LOGE("unsupported padMode"); + return schema::PadMode_NOTSET; + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h new file mode 100644 index 0000000000..a479d9b033 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_NODE_PARSER_H +#define MS_ONNX_NODE_PARSER_H + +#include +#include "google/protobuf/message.h" +#include "mindspore/lite/tools/converter/proto/onnx.pb.h" +#include "tools/common/node_util.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +// using namespace std; + +namespace mindspore { +namespace lite { +class OnnxNodeParser { + public: + explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {} + virtual ~OnnxNodeParser() = default; + virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; + + protected: + schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); + const std::string &name; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_NODE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc new file mode 100644 index 0000000000..daefc9964b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" +#include + +namespace mindspore { +namespace lite { +OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default; + +OnnxNodeParserRegistry::~OnnxNodeParserRegistry() = default; + +OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() { + static OnnxNodeParserRegistry instance; + return &instance; +} + +OnnxNodeParser *OnnxNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + for (auto const &i : parsers) { + if (name.find(i.first) != std::string::npos) { + return i.second; + } + } + return nullptr; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h new file mode 100644 index 0000000000..f4781467df --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_OP_REGISTRY_H +#define MS_ONNX_OP_REGISTRY_H + +#include +#include +#include "mindspore/lite/tools/converter/proto/onnx.pb.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" + +namespace mindspore { +namespace lite { +class OnnxNodeParserRegistry { + public: + OnnxNodeParserRegistry(); + + virtual ~OnnxNodeParserRegistry(); + + static OnnxNodeParserRegistry *GetInstance(); + OnnxNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class OnnxNodeRegistrar { + public: + OnnxNodeRegistrar(const std::string &name, OnnxNodeParser *parser) { + OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MS_ONNX_OP_REGISTRY_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc new file mode 100644 index 0000000000..c200d14f3a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::PadT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "pads") { + const int size = onnx_node_attr.ints_size(); + attr->paddings.resize(size); + for (int i = 0; i < size / 2; ++i) { + attr->paddings[i * 2] = static_cast(onnx_node_attr.ints(i)); + attr->paddings[i * 2 + 1] = static_cast(onnx_node_attr.ints(i + size / 2)); + } + } else if (attribute_name == "mode") { + const auto &mode = onnx_node_attr.s(); + if (mode == "constant") { + attr->paddingmode = schema::PaddingMode_CONSTANT; + } else if (mode == "reflect") { + attr->paddingmode = schema::PaddingMode_REFLECT; + } else if (mode == "edge") { + attr->paddingmode = schema::PaddingMode_SYMMETRIC; + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pad; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h new file mode 100644 index 0000000000..ba2e54bc59 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_LRN_PARSER_H +#define MS_ONNX_LRN_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxPadParser : public OnnxNodeParser { + public: + OnnxPadParser() : OnnxNodeParser("Pad") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_LRN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc new file mode 100644 index 0000000000..70e6136690 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::PoolingT()); + + const auto &pool_type = onnx_node.op_type(); + if (pool_type == "MaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + attr->global = false; + } else if (pool_type == "AveragePool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + attr->global = false; + } else if (pool_type == "GlobalMaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + attr->global = true; + } else if (pool_type == "GlobalAveragePool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + attr->global = true; + } else { + // MS_LOGE("Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."); + return RET_ERROR; + } + + attr->roundMode = schema::RoundMode_FLOOR; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "kernel_shape") { + if (onnx_node_attr.ints_size() == 2) { + attr->windowW = static_cast(onnx_node_attr.ints(0)); + attr->windowH = static_cast(onnx_node_attr.ints(1)); + } + } + if (attribute_name == "strides") { + if (onnx_node_attr.ints_size() == 2) { + attr->strideW = static_cast(onnx_node_attr.ints(0)); + attr->strideH = static_cast(onnx_node_attr.ints(1)); + } + } + if (attribute_name == "auto_pad") { + MS_ASSERT(false); + } + if (attribute_name == "pads") { + if (onnx_node_attr.ints_size() == 4) { + attr->padMode = schema::PadMode_CAFFE; + attr->padUp = static_cast(onnx_node_attr.ints(0)); + attr->padDown = static_cast(onnx_node_attr.ints(1)); + attr->padLeft = static_cast(onnx_node_attr.ints(0)); + attr->padRight = static_cast(onnx_node_attr.ints(1)); + } + } + if (attribute_name == "ceil_mode") { + MS_ASSERT(false); // todo (h00500767) + attr->roundMode = schema::RoundMode_CEIL; + } + if (attribute_name == "dilations") { + MS_ASSERT(false); // todo pooling op not support dilations now + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser()); +OnnxNodeRegistrar g_onnxAveragePoolParser("AveragePool", new OnnxPoolParser()); +OnnxNodeRegistrar g_onnxGlobalAveragePoolParser("GlobalAveragePool", new OnnxPoolParser()); +OnnxNodeRegistrar g_onnxGlobalMaxPoolParser("GlobalMaxPool", new OnnxPoolParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h new file mode 100644 index 0000000000..ce439cf3f1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_POOL_PARSER_H +#define MS_ONNX_POOL_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxPoolParser : public OnnxNodeParser { + public: + OnnxPoolParser() : OnnxNodeParser("Pool") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_POOL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc new file mode 100644 index 0000000000..c445373903 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ReduceT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + const int &size = onnx_node_attr.ints_size(); + for (int i = 0; i < size; ++i) { + attr->axes.push_back(onnx_node_attr.ints(i)); + } + } else if (attribute_name == "keepdims") { + attr->keepDims = static_cast(onnx_node_attr.i()); + } + } + const auto &type = onnx_node.op_type(); + if (type == "ReduceMean") { + attr->mode = schema::ReduceMode_ReduceMean; + } else if (type == "ReduceMax") { + attr->mode = schema::ReduceMode_ReduceMax; + } else if (type == "ReduceMin") { + attr->mode = schema::ReduceMode_ReduceMin; + } else if (type == "ReduceSum") { + attr->mode = schema::ReduceMode_ReduceSum; + } else { + // MS_LOGE("unsupoort type"); + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceMaxParser("ReduceMax", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceMinParser("ReduceMin", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceProdParser("ReduceProd", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceSumParser("ReduceSum", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceSumSquareParser("ReduceSumSquare", new OnnxReduceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h new file mode 100644 index 0000000000..9b19d37062 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_REDUCE_PARSER_H +#define MS_ONNX_REDUCE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxReduceParser : public OnnxNodeParser { + public: + OnnxReduceParser() : OnnxNodeParser("Reduce") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_REDUCE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc new file mode 100644 index 0000000000..9947b5fa8c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h" +#include "securec/include/securec.h" +namespace mindspore { +namespace lite { +STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::ActivationT()); + const auto &relu_type = onnx_node.op_type(); + if (relu_type == "Relu") { + attr->type = schema::ActivationType_RELU; + } else if (relu_type == "LeakyRelu") { + attr->type = schema::ActivationType_LEAKY_RELU; + } + + if (op != nullptr) { + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (onnx_node.input_size() != 2) { + // MS_LOGE("input num is not 2") + return RET_PARAM_INVALID; + } + unique_ptr attr(new schema::PreluT()); + std::vector params; + for (int i = 0; i < onnx_node.input_size(); ++i) { + const auto &input_name = onnx_node.input(i); + for ( const auto &it : onnx_graph.initializer() ) { + if (it.name() == "input_name") { + params.push_back(it); + break; + } + } + } + const onnx::TensorProto *slope = ¶ms[0]; + if (slope == nullptr) { + // MS_LOGE("input error") + return RET_PARAM_INVALID; + } + const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); + const int64_t slope_size = slope->raw_data().size() / sizeof(float); + if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) { + // MS_LOGE("memcpy_s failed") + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Prelu; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); +OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxLeakeyReluParser()); +OnnxNodeRegistrar g_onnxPReluParser("Prelu", new OnnxPReluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h new file mode 100644 index 0000000000..a3750c21be --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_RELU_PARSER_H +#define MS_ONNX_RELU_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxReluParser : public OnnxNodeParser { + public: + OnnxReluParser() : OnnxNodeParser("Relu") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxLeakeyReluParser : public OnnxReluParser { + public: + OnnxLeakeyReluParser() : OnnxReluParser() {} +}; + +class OnnxPReluParser : public OnnxNodeParser { + public: + OnnxPReluParser() : OnnxNodeParser("Prelu") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_RELU_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc new file mode 100644 index 0000000000..c02c428494 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ReshapeT()); + attr->format = schema::Format_NHWC; + + std::vector params; + for (int i = 0; i < onnx_node.input_size(); ++i) { + const auto &input_name = onnx_node.input(i); + for (const auto &it : onnx_graph.initializer()) { + if (it.name() == input_name) { + params.emplace_back(it); + break; + } + } + } + if (params.empty()) { + return RET_OK; + } + if (params.size() != 1) { + // MS_LOGE("input num is ,not equal 1", params.size()) + return RET_PARAM_INVALID; + } + + auto pre_shape = params[0]; + for (int i = 0; i < pre_shape.dims_size(); ++i) { + attr->shape.emplace_back(params[0].dims(i)); + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h new file mode 100644 index 0000000000..5c0d673dfc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_RESHAPE_PARSER_H +#define MS_ONNX_RESHAPE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxReshapeParser : public OnnxNodeParser { + public: + OnnxReshapeParser() : OnnxNodeParser("Reshape") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_RESHAPE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc new file mode 100644 index 0000000000..d85740b3b5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Shape; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h new file mode 100644 index 0000000000..27073aa66d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SHAPE_PARSER_H +#define MS_ONNX_SHAPE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxShapeParser : public OnnxNodeParser { + public: + OnnxShapeParser() : OnnxNodeParser("Shape") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SHAPE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc new file mode 100644 index 0000000000..e275e092b9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h new file mode 100644 index 0000000000..55f8664965 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SIGMOID_PARSER_H +#define MS_ONNX_SIGMOID_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSigmoidParser : public OnnxNodeParser { + public: + OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SIGMOID_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc new file mode 100644 index 0000000000..83f9c49f9f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SliceT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "starts") { + const int size = onnx_node_attr.ints_size(); + for (int i = 0; i < size; ++i) { + attr->begin.emplace_back(static_cast(onnx_node_attr.ints(i))); + } + } else if (attribute_name == "ends") { + const int size = onnx_node_attr.ints_size(); + for (int i = 0; i < size; ++i) { + attr->size.emplace_back(static_cast(onnx_node_attr.ints(i))); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Slice; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h new file mode 100644 index 0000000000..6a45db1f31 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SLICE_PARSER_H +#define MS_ONNX_SLICE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSliceParser : public OnnxNodeParser { + public: + OnnxSliceParser() : OnnxNodeParser("Slice") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SLICE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc new file mode 100644 index 0000000000..229cc7848a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SoftMaxT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SoftMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h new file mode 100644 index 0000000000..822944ea5e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SOFTMAX_PARSER_H +#define MS_ONNX_SOFTMAX_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSoftMaxParser : public OnnxNodeParser { + public: + OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SOFTMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc new file mode 100644 index 0000000000..549e20329f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSPaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SpaceToDepthT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "blocksize") { + attr->blockSize = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSPaceToDepthParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h new file mode 100644 index 0000000000..2a47a96758 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H +#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSPaceToDepthParser : public OnnxNodeParser { + public: + OnnxSPaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SPACE_TO_DEPTH_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc new file mode 100644 index 0000000000..f462d6091e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SqueezeT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + for (int i = 0; i < onnx_node_attr.ints().size(); ++i) { + attr->axis.emplace_back(onnx_node_attr.ints(i)); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Squeeze; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h new file mode 100644 index 0000000000..f8e3050809 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SQUEEZE_PARSER_H +#define MS_ONNX_SQUEEZE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSqueezeParser : public OnnxNodeParser { + public: + OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SQUEEZE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc new file mode 100644 index 0000000000..b6839b958a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Tile; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h new file mode 100644 index 0000000000..f09811e099 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_TILE_PARSER_H +#define MS_ONNX_TILE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxTileParser : public OnnxNodeParser { + public: + OnnxTileParser() : OnnxNodeParser("Tile") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc new file mode 100644 index 0000000000..f16a2d0278 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::TransposeT()); + attr->conjugate = false; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + attr->perm.resize(onnx_node_attr.ints_size()); + for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { + attr->perm[i] = onnx_node_attr.ints(i); + } + } + if (attribute_name == "perm") { + attr->perm.resize(onnx_node_attr.ints_size()); + for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { + attr->perm[i] = onnx_node_attr.ints(i); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Transpose; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h new file mode 100644 index 0000000000..e87279b43b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_TRANSPOSE_PARSER_H +#define MS_ONNX_TRANSPOSE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxTransposeParser : public OnnxNodeParser { + public: + OnnxTransposeParser() : OnnxNodeParser("Transpose") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_TRANSPOSE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc new file mode 100644 index 0000000000..90b572d565 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::UpsampleT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "mode") { + attr->mode = onnx_node_attr.s(); + } else if (attribute_name == "scales") { + for (int i = 0; i < onnx_node_attr.floats_size(); ++i) { + attr->scales[i] = onnx_node_attr.floats(i); + } + } + } + // to do + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Upsample; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxUpsampleParser("Upsample", new OnnxUpsampleParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h new file mode 100644 index 0000000000..2b2a6a3a6f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_UPSAMPLE_PARSER_H +#define MS_ONNX_UPSAMPLE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxUpsampleParser : public OnnxNodeParser { + public: + OnnxUpsampleParser() : OnnxNodeParser("Upsample") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_UPSAMPLE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc new file mode 100644 index 0000000000..8ed288e97b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::UnsqueezeT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + for (int i = 0; i < onnx_node_attr.ints().size(); ++i) { + attr->axis.emplace_back(onnx_node_attr.ints(i)); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Unsqueeze; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxUnsqueezeParser("Unsqueeze", new OnnxUnSqueezeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h new file mode 100644 index 0000000000..231d3ef2a9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_UNSQUEEZE_PARSER_H +#define MS_ONNX_UNSQUEEZE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxUnSqueezeParser : public OnnxNodeParser { + public: + OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_UNSQUEEZE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc new file mode 100644 index 0000000000..bb92abe567 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + if (onnx_node.op_type() == "Int8Quantize") { + op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; + op->primitive->value.value = new (std::nothrow) schema::OnnxInt8QuantizeT; + } else if (onnx_node.op_type() == "Int8Dequantize") { + op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; + op->primitive->value.value = new (std::nothrow) schema::OnnxInt8DequantizeT; + } else { + // MS_LOGE("Unsupported nodeType: %s", onnx_node.op_type().c_str()); + return RET_ERROR; + } + if (op->primitive->value.value == nullptr) { + // MS_LOGE("new %s attr value failed", onnx_node.op_type().c_str()); + return RET_ERROR; + } + } else { + // MS_LOGE("Input opDef is nullptr"); + return RET_PARAM_INVALID; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxUnusefulNodeParser()); +OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxUnusefulNodeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h new file mode 100644 index 0000000000..6e002254f0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_UNUSEFUL_PARSER_H +#define MS_ONNX_UNUSEFUL_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxUnusefulNodeParser : public OnnxNodeParser { + public: + OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_UNUSEFUL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt new file mode 100644 index 0000000000..03f9b3670b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE TFLITE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +add_library(tflite_parser_mid OBJECT + ${TFLITE_SRC_LIST} + ) diff --git a/mindspore/lite/tools/converter/parser/tflite/schema.fbs b/mindspore/lite/tools/converter/parser/tflite/schema.fbs new file mode 100644 index 0000000000..b7f41c756e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/schema.fbs @@ -0,0 +1,1094 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. + // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. + +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. + // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126 +} + + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. + SYMMETRIC = 1, +} + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; +} + +root_type Model; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_abs_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_abs_parser.cc new file mode 100644 index 0000000000..2ec1f0257d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_abs_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_abs_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteAbsParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteAbsParser"; + std::unique_ptr attr(new schema::AbsT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Abs; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_abs_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_abs_parser.h new file mode 100644 index 0000000000..7d4493f954 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_abs_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ABS_PARSER_H +#define PREDICT_TFLITE_ABS_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteAbsParser : public TfliteNodeParser { + public: + TfliteAbsParser() : TfliteNodeParser("Abs") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ABS_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc new file mode 100644 index 0000000000..3981b4b6ae --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_add_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteAddParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteAddParser"; + std::unique_ptr attr(new schema::AddT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + return RET_ERROR; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h new file mode 100644 index 0000000000..cb5c04d1f8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ADD_PARSER_H +#define PREDICT_TFLITE_ADD_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteAddParser : public TfliteNodeParser { + public: + TfliteAddParser() : TfliteNodeParser("Add") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc new file mode 100644 index 0000000000..5dc9ee5bfd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -0,0 +1,44 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_addn_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteAddNParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteAddNParser"; + std::unique_ptr attr(new schema::AddNT()); + attr->N = tflite_tensors.size() - 1; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_AddN; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h new file mode 100644 index 0000000000..bdc51bfc48 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_ADDN_PARSER_H +#define LITE_TFLITE_ADDN_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteAddNParser : public TfliteNodeParser { + public: + TfliteAddNParser() : TfliteNodeParser("AddN") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_ADDN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc new file mode 100644 index 0000000000..30706902b8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_argmax_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; + std::unique_ptr attr(new schema::ArgMaxT()); + // These are caffe attributes, set to default value. + attr->axisType = 1; + attr->outMaxValue = false; + attr->topK = 1; + attr->keepDims = false; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ArgMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h new file mode 100644 index 0000000000..0665a6b028 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H +#define PREDICT_TFLITE_ARGMAX_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteArgmaxParser : public TfliteNodeParser { + public: + TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ARGMAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc new file mode 100644 index 0000000000..2ff86ae045 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_argmin_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteArgminParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteArgminParser"; + std::unique_ptr attr(new schema::ArgMinT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsArgMinOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + // get axis + auto axis_idx = tfliteOp->inputs[1]; + std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); + auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; + auto data_ptr = buf_data->data.data(); + attr->axis = *(static_cast(static_cast(data_ptr))); + + // the following use default values + attr->outMaxValue = false; + attr->topK = 1; + attr->keepDims = false; + attr->axisType = 0; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ArgMin; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteArgminParser("Argmin", new TfliteArgminParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h new file mode 100644 index 0000000000..a02d4fe5e2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H +#define PREDICT_TFLITE_ARGMIN_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteArgminParser : public TfliteNodeParser { + public: + TfliteArgminParser() : TfliteNodeParser("Argmin") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ARGMIN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.cc new file mode 100644 index 0000000000..6ed52bfb12 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteBatchToSpaceNDParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteBatchToSpaceNDParser"; + std::unique_ptr attr(new schema::BatchToSpaceT()); + + // in tflite + // blockShape should be a 1D tensor with dimension [spatial_dims_num] + // crops should be a 2D tensor with dimension [spatial_dims_num, 2] + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { + MS_LOG(ERROR) << "BatchToSpaceNd get blockShape attr failed"; + return RET_ERROR; + } + if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) { + MS_LOG(ERROR) << "BatchToSpaceNd get crops attr failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_BatchToSpace; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h new file mode 100644 index 0000000000..59269fe454 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H +#define PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteBatchToSpaceNDParser : public TfliteNodeParser { + public: + TfliteBatchToSpaceNDParser() : TfliteNodeParser("BatchToSpaceND") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc new file mode 100644 index 0000000000..6fe391f4b6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -0,0 +1,51 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_batch_to_space_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; + std::unique_ptr attr(new schema::BatchToSpaceT()); + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { + MS_LOG(ERROR) << "batchToSpace -> blockShape get failed"; + return RET_ERROR; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) { + MS_LOG(ERROR) << "batchToSpace -> crops get failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_BatchToSpace; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h new file mode 100644 index 0000000000..37f20766a9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H +#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteBatchToSpaceParser : public TfliteNodeParser { + public: + TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc new file mode 100644 index 0000000000..bc508ef7d6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -0,0 +1,47 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; + std::unique_ptr attr(new schema::BroadcastToT()); + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) { + MS_LOG(ERROR) << "broadCastTo -> dst_shape get failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_BroadcastTo; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h new file mode 100644 index 0000000000..0bbebd449b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H +#define LITE_TFLITE_BROADCAST_TO_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteBroadcastToParser : public TfliteNodeParser { + public: + TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc new file mode 100644 index 0000000000..a7cce134a6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -0,0 +1,46 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_cast_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteCastParser"; + std::unique_ptr attr(new schema::CastT()); + + attr->srcT = dtype_map[tflite_tensors[tflite_op->inputs[0]]->type]; + attr->dstT = dtype_map[tflite_tensors[tflite_op->outputs[0]]->type]; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Cast; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + + TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h new file mode 100644 index 0000000000..843a144929 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -0,0 +1,55 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_CAST_PARSER_H +#define LITE_TFLITE_CAST_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteCastParser : public TfliteNodeParser { + public: + TfliteCastParser() : TfliteNodeParser("Cast") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; + + private: + std::map dtype_map = { + {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, + {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, + {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, + }; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_CAST_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_ceil_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_ceil_parser.cc new file mode 100644 index 0000000000..c8902c5f41 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_ceil_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_ceil_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteCeilParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteCeilParser"; + std::unique_ptr attr(new schema::CeilT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Ceil; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_ceil_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_ceil_parser.h new file mode 100644 index 0000000000..289c7b40c7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_ceil_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CEIL_PARSER_H +#define PREDICT_TFLITE_CEIL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteCeilParser : public TfliteNodeParser { + public: + TfliteCeilParser() : TfliteNodeParser("Ceil") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CEIL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc new file mode 100644 index 0000000000..5dbe6b2fea --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_concat_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteConcatParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteConcatParser"; + std::unique_ptr attr(new schema::ConcatT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + attr->axis = tfliteAttr->axis; + attr->n = tfliteOp->inputs.size(); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h new file mode 100644 index 0000000000..d2a1acff77 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CONCAT_PARSER_H +#define PREDICT_TFLITE_CONCAT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteConcatParser : public TfliteNodeParser { + public: + TfliteConcatParser() : TfliteNodeParser("Concat") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONCAT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc new file mode 100644 index 0000000000..66c2c13492 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_conv_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteConvParser"; + std::unique_ptr attr(new schema::Conv2DT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsConv2DOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + attr->group = 1; + attr->strideW = tfliteAttr->stride_w; + attr->strideH = tfliteAttr->stride_h; + attr->dilateH = tfliteAttr->dilation_h_factor; + attr->dilateW = tfliteAttr->dilation_w_factor; + attr->padMode = GetPadMode(tfliteAttr->padding); + attr->format = schema::Format_NHWC; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + // get the conv op weight tensor + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + auto weight_shape = weight_tensor->shape; + attr->channelIn = weight_shape[KHWC_C]; + attr->channelOut = weight_shape[KHWC_K]; + attr->kernelW = weight_shape[KHWC_W]; + attr->kernelH = weight_shape[KHWC_H]; + if (tflite_op->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tflite_op->inputs[2]; + const auto &bias_tensor = tflite_tensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + MS_LOG(ERROR) << "parse bias failed"; + return RET_ERROR; + } + } + // calculate pad params + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h new file mode 100644 index 0000000000..d2f523a0c3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CONV_PARSER_H +#define PREDICT_TFLITE_CONV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteConvParser : public TfliteNodeParser { + public: + TfliteConvParser() : TfliteNodeParser("Conv2D") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc new file mode 100644 index 0000000000..825deec6f9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_converter.h" + +namespace mindspore { +namespace lite { +TfliteConverter::TfliteConverter() { + modelParser = new TfliteModelParser(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h new file mode 100644 index 0000000000..88f0710851 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ + +#include +#include +#include "tools/converter/converter.h" +#include "tools/converter/parser/tflite/tflite_model_parser.h" +#include "tools/converter/graphdef_transform.h" + +namespace mindspore { +namespace lite { +class TfliteConverter : public Converter { + public: + TfliteConverter(); + + ~TfliteConverter() override = default; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cos_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cos_parser.cc new file mode 100644 index 0000000000..a61d95aaf7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cos_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_cos_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteCosParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteCosParser"; + std::unique_ptr attr(new schema::CosT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Cos; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cos_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cos_parser.h new file mode 100644 index 0000000000..f2dedf2c22 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cos_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_COS_PARSER_H +#define PREDICT_TFLITE_COS_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteCosParser : public TfliteNodeParser { + public: + TfliteCosParser() : TfliteNodeParser("Cos") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_COS_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc new file mode 100644 index 0000000000..50ca889d87 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_deconv_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_op_set, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; + std::unique_ptr attr(new schema::DeConv2DT()); + const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + return RET_NULL_PTR; + } + attr->group = 1; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = 1; + attr->dilateW = 1; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format_NHWC; + // get the conv op weight tensor + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tflite_model_buffer, tensor_cache, schema::Format_KHWC)) { + return RET_ERROR; + } + auto weight_shape = weight_tensor->shape; + attr->channelIn = weight_shape[CHWK_K]; + attr->channelOut = weight_shape[CHWK_C]; + attr->kernelW = weight_shape[CHWK_W]; + attr->kernelH = weight_shape[CHWK_H]; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeConv2D; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h new file mode 100644 index 0000000000..46e7e1b8b6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_DECONV_PARSER_H +#define PREDICT_TFLITE_DECONV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteDeConvParser : public TfliteNodeParser { + public: + TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_op_set, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_DECONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc new file mode 100644 index 0000000000..0c441396c8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -0,0 +1,50 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; + std::unique_ptr attr(new schema::DepthToSpaceT()); + const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + return RET_NULL_PTR; + } + attr->blockSize = tflite_attr->block_size; + attr->format = schema::Format_NHWC; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthToSpace; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h new file mode 100644 index 0000000000..3be9968d8d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H +#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteDepthToSpaceParser : public TfliteNodeParser { + public: + TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc new file mode 100644 index 0000000000..0d231d73dd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" +#include "tools/common/node_util.h" + +namespace mindspore { +namespace lite { +STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op, + const std::unique_ptr &attr, + const std::unique_ptr &weightTensor, + TensorCache *tensor_cache) { + std::unique_ptr convAttr(new schema::Conv2DT); + convAttr->format = attr->format; + convAttr->channelIn = attr->channelIn; + convAttr->channelOut = attr->channelIn * attr->channelMultiplier; + convAttr->kernelH = attr->kernelH; + convAttr->kernelW = attr->kernelW; + convAttr->strideH = attr->strideH; + convAttr->strideW = attr->strideW; + convAttr->padMode = attr->padMode; + convAttr->padUp = attr->padUp; + convAttr->padDown = attr->padDown; + convAttr->padLeft = attr->padLeft; + convAttr->padRight = attr->padRight; + convAttr->dilateH = attr->dilateH; + convAttr->dilateW = attr->dilateW; + convAttr->hasBias = attr->hasBias; + convAttr->activationType = attr->activationType; + + auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name); + if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) { + auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex]; + if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + + if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = convAttr.release(); + return RET_OK; +} + +STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; + std::unique_ptr attr(new schema::DepthwiseConv2DT()); + const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = tflite_attr->dilation_h_factor; + attr->dilateW = tflite_attr->dilation_w_factor; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format_NHWC; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + // get the conv op weight tensor + auto input_index = tflite_op->inputs[0]; + const auto &input_tenosr = tflite_tensors[input_index]; + auto input_shape = input_tenosr->shape; + + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; + auto weight_shape = weight_tensor->shape; + attr->channelIn = input_shape[KHWC_C]; + attr->channelMultiplier = tflite_attr->depth_multiplier; + attr->kernelH = weight_shape[KHWC_H]; + attr->kernelW = weight_shape[KHWC_W]; + + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + + if (tflite_op->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tflite_op->inputs[2]; + const auto &bias_tensor = tflite_tensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + MS_LOG(ERROR) << "parse bias failed"; + return RET_ERROR; + } + } + + if (attr->channelMultiplier > 1) { + if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { + // MS_LOGE("Parse Group DepthwiseConv failed"); + return RET_ERROR; + } + } else { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h new file mode 100644 index 0000000000..2e0b1a0d02 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H +#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteDepthwiseConv2DParser : public TfliteNodeParser { + public: + TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override; + + private: + STATUS ParseGroupDepthwiseConv(schema::CNodeT *op, + const std::unique_ptr &attr, + const std::unique_ptr &weightTensor, + TensorCache *tensor_cache); +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_div_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_div_parser.cc new file mode 100644 index 0000000000..b7b6efe123 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_div_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_div_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteDivParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteDivParser"; + std::unique_ptr attr(new schema::DivT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Div; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_div_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_div_parser.h new file mode 100644 index 0000000000..1a82a0f806 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_div_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_DIV_PARSER_H +#define PREDICT_TFLITE_DIV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteDivParser : public TfliteNodeParser { + public: + TfliteDivParser() : TfliteNodeParser("Div") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_DIV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_equal_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_equal_parser.cc new file mode 100644 index 0000000000..3ef4ddd286 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_equal_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_equal_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteEqualParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteEqualParser"; + std::unique_ptr attr(new schema::EqualT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Equal; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_equal_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_equal_parser.h new file mode 100644 index 0000000000..3435285b2a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_equal_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_EQUAL_PARSER_H +#define LITE_TFLITE_EQUAL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteEqualParser : public TfliteNodeParser { + public: + TfliteEqualParser() : TfliteNodeParser("Equal") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_EQUAL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_exp_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_exp_parser.cc new file mode 100644 index 0000000000..2d89b47dca --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_exp_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_exp_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteExpParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteExpParser"; + std::unique_ptr attr(new schema::ExpT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Exp; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_exp_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_exp_parser.h new file mode 100644 index 0000000000..ec27390ace --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_exp_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_EXP_PARSER_H +#define PREDICT_TFLITE_EXP_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteExpParser : public TfliteNodeParser { + public: + TfliteExpParser() : TfliteNodeParser("Exp") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_EXP_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc new file mode 100644 index 0000000000..3c42bf4d3b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; + std::unique_ptr attr(new schema::ExpandDimsT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + // get axis + auto axis_idx = tfliteOp->inputs[1]; + std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); + auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; + auto data_ptr = buf_data->data.data(); + attr->dim = *(static_cast(static_cast(data_ptr))); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ExpandDims; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h new file mode 100644 index 0000000000..aa867bc315 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H +#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteExpandDimsParser : public TfliteNodeParser { + public: + TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc new file mode 100644 index 0000000000..9d2dfddcfb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tflite/tflite_fakequant_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; + std::unique_ptr attr(new schema::FullConnectionT()); + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + if (tfliteOp->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tfliteOp->inputs[2]; + const auto &bias_tensor = tfliteTensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + MS_LOG(ERROR) << "parse bias failed"; + return RET_ERROR; + } + } + attr->axis = 1; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h new file mode 100644 index 0000000000..101c6cfec1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H +#define LITE_TFLITE_FAKEQUANT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFakeQuantParser : public TfliteNodeParser { + public: + TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_FAKEQUANT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc new file mode 100644 index 0000000000..e41141b560 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_fill_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteFillParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteFillParser"; + std::unique_ptr attr(new schema::FillT()); + + if (tfliteOp->inputs.size() > 1) { + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) { + return RET_ERROR; + } + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Fill; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteFillParser("Fill", new TfliteFillParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h new file mode 100644 index 0000000000..5d8fdee06d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_FILL_PARSER_H +#define PREDICT_TFLITE_FILL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFillParser : public TfliteNodeParser { + public: + TfliteFillParser() : TfliteNodeParser("Fill") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_FILL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_floor_div_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_div_parser.cc new file mode 100644 index 0000000000..2f38e6a35e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_div_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_floor_div_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteFloorDivParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; + std::unique_ptr attr(new schema::FloorDivT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FloorDiv; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_floor_div_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_div_parser.h new file mode 100644 index 0000000000..3ee5f51305 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_div_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_FLOOR_DIV_PARSER_H +#define PREDICT_TFLITE_FLOOR_DIV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFloorDivParser : public TfliteNodeParser { + public: + TfliteFloorDivParser() : TfliteNodeParser("FloorDiv") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_FLOOR_DIV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_floor_mod_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_mod_parser.cc new file mode 100644 index 0000000000..ea99cef833 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_mod_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_floor_mod_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteFloorModParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteFloorModParser"; + std::unique_ptr attr(new schema::FloorModT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FloorMod; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_floor_mod_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_mod_parser.h new file mode 100644 index 0000000000..b0ed989508 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_mod_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_FLOOR_MOD_PARSER_H +#define PREDICT_TFLITE_FLOOR_MOD_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFloorModParser : public TfliteNodeParser { + public: + TfliteFloorModParser() : TfliteNodeParser("FloorMod") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_FLOOR_MOD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_floor_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_parser.cc new file mode 100644 index 0000000000..70abaef920 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_floor_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteFloorParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteFloorParser"; + std::unique_ptr attr(new schema::FloorT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Floor; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFloorParser("Floor", new TfliteFloorParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_floor_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_parser.h new file mode 100644 index 0000000000..7db0e83324 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_floor_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_FLOOR_PARSER_H +#define PREDICT_TFLITE_FLOOR_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFloorParser : public TfliteNodeParser { + public: + TfliteFloorParser() : TfliteNodeParser("Floor") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_FLOOR_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc new file mode 100644 index 0000000000..8ee363335f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_fullyconnected_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; + std::unique_ptr attr(new schema::FullConnectionT()); + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + if (tfliteOp->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tfliteOp->inputs[2]; + const auto &bias_tensor = tfliteTensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + MS_LOG(ERROR) << "parse bias failed"; + return RET_ERROR; + } + } + attr->axis = 1; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h new file mode 100644 index 0000000000..f41ab2e3c0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ADD_PARSER_H +#define PREDICT_TFLITE_ADD_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFullyConnectedParser : public TfliteNodeParser { + public: + TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc new file mode 100644 index 0000000000..41841f4a28 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_gather_nd_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; + std::unique_ptr attr(new schema::GatherNdT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsGatherNdOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->batchDims = 0; // default + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_GatherNd; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteGatherNdParser("GatherNd", new TfliteGatherNdParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h new file mode 100644 index 0000000000..18b8b5531d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_GATHER_ND_PARSER_H +#define PREDICT_TFLITE_GATHER_ND_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteGatherNdParser : public TfliteNodeParser { + public: + TfliteGatherNdParser() : TfliteNodeParser("GatherNd") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_GATHER_ND_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc new file mode 100644 index 0000000000..4c9efde2b4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_gather_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteGatherParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteGatherParser"; + std::unique_ptr attr(new schema::GatherT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsGatherOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->axis = tflite_attr->axis; + attr->batchDims = 0; // default + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Gather; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h new file mode 100644 index 0000000000..5dd842414a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_GATHER_PARSER_H +#define PREDICT_TFLITE_GATHER_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteGatherParser : public TfliteNodeParser { + public: + TfliteGatherParser() : TfliteNodeParser("Gather") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_GATHER_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_v2_parser.cc new file mode 100644 index 0000000000..514fe5280c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_v2_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_gather_v2_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteGatherV2Parser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteGatherV2Parser"; + std::unique_ptr attr(new schema::GatherT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsGatherOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->axis = tflite_attr->axis; + attr->batchDims = 0; // default + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Gather; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteGatherV2Parser("GatherV2", new TfliteGatherV2Parser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_v2_parser.h new file mode 100644 index 0000000000..d9acc3721d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_v2_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_GATHER_V2_PARSER_H +#define PREDICT_TFLITE_GATHER_V2_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteGatherV2Parser : public TfliteNodeParser { + public: + TfliteGatherV2Parser() : TfliteNodeParser("GatherV2") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_GATHER_V2_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_greater_equal_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_equal_parser.cc new file mode 100644 index 0000000000..0ca417a180 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_equal_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_greater_equal_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteGreaterEqualParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; + std::unique_ptr attr(new schema::GreaterEqualT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_GreaterEqual; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_greater_equal_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_equal_parser.h new file mode 100644 index 0000000000..aacd8e3ab0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_equal_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_GREATER_EQUAL_PARSER_H +#define LITE_TFLITE_GREATER_EQUAL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteGreaterEqualParser : public TfliteNodeParser { + public: + TfliteGreaterEqualParser() : TfliteNodeParser("GreaterEqual") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_GREATER_EQUAL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_greater_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_parser.cc new file mode 100644 index 0000000000..c706140dc6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_greater_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteGreaterParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteGreaterParser"; + std::unique_ptr attr(new schema::GreaterT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Greater; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteGreaterParser("Greater", new TfliteGreaterParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_greater_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_parser.h new file mode 100644 index 0000000000..99ab59c1f6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_greater_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_GREATER_PARSER_H +#define LITE_TFLITE_GREATER_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteGreaterParser : public TfliteNodeParser { + public: + TfliteGreaterParser() : TfliteNodeParser("Greater") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_GREATER_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc new file mode 100644 index 0000000000..4dd7fe9b89 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_hard_swish_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteHardSwishParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteHardSwishParser"; + std::unique_ptr attr(new schema::ActivationT()); + + attr->type = schema::ActivationType_HSWISH; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h new file mode 100644 index 0000000000..00de1d2458 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_HARD_SWISH_PARSER_H +#define PREDICT_TFLITE_HARD_SWISH_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteHardSwishParser : public TfliteNodeParser { + public: + TfliteHardSwishParser() : TfliteNodeParser("HardSwish") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_HARD_SWISH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_inner_product_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_inner_product_parser.cc new file mode 100644 index 0000000000..aaf9366e81 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_inner_product_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_inner_product_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteInnerProductParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteInnerProductParser"; + std::unique_ptr attr(new schema::FullConnectionT()); + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + if (tfliteOp->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tfliteOp->inputs[2]; + const auto &bias_tensor = tfliteTensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + MS_LOG(ERROR) << "parse bias failed"; + return RET_ERROR; + } + } + attr->axis = 1; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteInnerProductParser("InnerProduct", new TfliteInnerProductParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_inner_product_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_inner_product_parser.h new file mode 100644 index 0000000000..0505e8ce23 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_inner_product_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_INNER_PRODUCT_PARSER_H +#define PREDICT_TFLITE_INNER_PRODUCT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteInnerProductParser : public TfliteNodeParser { + public: + TfliteInnerProductParser() : TfliteNodeParser("InnerProduct") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_INNER_PRODUCT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.cc new file mode 100644 index 0000000000..f04c65c4d3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_leaky_relu_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; + std::unique_ptr attr(new schema::LeakyReLUT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->negativeSlope = tflite_attr->alpha; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.h new file mode 100644 index 0000000000..bce9c90a9f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LEAKY_RELU_PARSER_H +#define PREDICT_TFLITE_LEAKY_RELU_PARSER_H + +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include +#include + +namespace mindspore { +namespace lite { +class TfliteLeakyReluParser : public TfliteNodeParser { + public: + TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_LEAKY_RELU_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_less_equal_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_less_equal_parser.cc new file mode 100644 index 0000000000..72f26ebc6e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_less_equal_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_less_equal_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLessEqualParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; + std::unique_ptr attr(new schema::LessEqualT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LessEqual; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_less_equal_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_less_equal_parser.h new file mode 100644 index 0000000000..87fa8cedd8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_less_equal_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_LESS_EQUAL_PARSER_H +#define LITE_TFLITE_LESS_EQUAL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLessEqualParser : public TfliteNodeParser { + public: + TfliteLessEqualParser() : TfliteNodeParser("LessEqual") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_LESS_EQUAL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_less_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_less_parser.cc new file mode 100644 index 0000000000..250272aa1f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_less_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_less_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLessParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteLessParser"; + std::unique_ptr attr(new schema::LessT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Less; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_less_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_less_parser.h new file mode 100644 index 0000000000..7cbe1c38da --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_less_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_LESS_PARSER_H +#define LITE_TFLITE_LESS_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLessParser : public TfliteNodeParser { + public: + TfliteLessParser() : TfliteNodeParser("Less") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_LESS_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_log_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_log_parser.cc new file mode 100644 index 0000000000..032c57225b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_log_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_log_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLogParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteLogParser"; + std::unique_ptr attr(new schema::LogT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Log; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_log_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_log_parser.h new file mode 100644 index 0000000000..cb828b6706 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_log_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LOG_PARSER_H +#define PREDICT_TFLITE_LOG_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLogParser : public TfliteNodeParser { + public: + TfliteLogParser() : TfliteNodeParser("Log") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_LOG_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_and_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_and_parser.cc new file mode 100644 index 0000000000..c7df7acc8c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_and_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_logical_and_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLogicalAndParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteLogicalAndParser"; + std::unique_ptr attr(new schema::LogicalAndT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LogicalAnd; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_and_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_and_parser.h new file mode 100644 index 0000000000..6e28d75f39 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_and_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LOGICAL_AND_PARSER_H +#define PREDICT_TFLITE_LOGICAL_AND_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLogicalAndParser : public TfliteNodeParser { + public: + TfliteLogicalAndParser() : TfliteNodeParser("LogicalAnd") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_LOGICAL_AND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_not_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_not_parser.cc new file mode 100644 index 0000000000..396decce53 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_not_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_logical_not_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLogicalNotParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteLogicalNotParser"; + std::unique_ptr attr(new schema::LogicalNotT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LogicalNot; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteLogicalNotParser("LogicalNot", new TfliteLogicalNotParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_not_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_not_parser.h new file mode 100644 index 0000000000..41f8f1bd2f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_not_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LOGICAL_NOT_PARSER_H +#define PREDICT_TFLITE_LOGICAL_NOT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLogicalNotParser : public TfliteNodeParser { + public: + TfliteLogicalNotParser() : TfliteNodeParser("LogicalNot") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_LOGICAL_NOT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_or_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_or_parser.cc new file mode 100644 index 0000000000..24582be41d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_or_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_logical_or_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLogicalOrParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteLogicalOrParser"; + std::unique_ptr attr(new schema::LogicalOrT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LogicalOr; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteLogicalOrParser("LogicalOr", new TfliteLogicalOrParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_or_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_or_parser.h new file mode 100644 index 0000000000..55f74c174b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_or_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LOGICAL_OR_PARSER_H +#define PREDICT_TFLITE_LOGICAL_OR_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLogicalOrParser : public TfliteNodeParser { + public: + TfliteLogicalOrParser() : TfliteNodeParser("LogicalOr") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_LOGICAL_OR_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc new file mode 100644 index 0000000000..bb254f2c3e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_logistic_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLogisticParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteLogisticParser"; + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h new file mode 100644 index 0000000000..6c5402faa8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LOGISTIC_PARSER_H +#define PREDICT_TFLITE_LOGISTIC_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLogisticParser : public TfliteNodeParser { + public: + TfliteLogisticParser() : TfliteNodeParser("Logistic") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONCAT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc new file mode 100644 index 0000000000..3ac19a2920 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_lrn_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteLRNParser"; + std::unique_ptr attr(new schema::LocalResponseNormalizationT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsLocalResponseNormalizationOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->depth_radius = tflite_attr->radius; + attr->alpha = tflite_attr->alpha; + attr->beta = tflite_attr->beta; + attr->bias = tflite_attr->bias; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h new file mode 100644 index 0000000000..b7eae4f978 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LRN_PARSER_H +#define PREDICT_TFLITE_ADD_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLRNParser : public TfliteNodeParser { + public: + TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_LRN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc new file mode 100644 index 0000000000..387b0302e2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_max_pooling_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMaxPoolingParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; + std::unique_ptr attr(new schema::PoolingT()); + const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + attr->format = schema::Format_NHWC; + // attr->global + attr->poolingMode = schema::PoolMode_MAX_POOLING; + attr->windowW = tflite_attr->filter_width; + attr->windowH = tflite_attr->filter_height; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->padMode = GetPadMode(tflite_attr->padding); + // calculate pad params + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TfliteMaxPoolingParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h new file mode 100644 index 0000000000..0893b580ca --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MAX_POOLING_PARSER_H +#define PREDICT_TFLITE_MAX_POOLING_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMaxPoolingParser : public TfliteNodeParser { + public: + TfliteMaxPoolingParser() : TfliteNodeParser("MaxPooling") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_maximum_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_maximum_parser.cc new file mode 100644 index 0000000000..dc55c07713 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_maximum_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_maximum_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMaximumParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteMaximumParser"; + std::unique_ptr attr(new schema::MaximumT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Maximum; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_maximum_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_maximum_parser.h new file mode 100644 index 0000000000..5f3587d656 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_maximum_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MAXIMUM_PARSER_H +#define PREDICT_TFLITE_MAXIMUM_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMaximumParser : public TfliteNodeParser { + public: + TfliteMaximumParser() : TfliteNodeParser("Maximum") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_MAXIMUM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc new file mode 100644 index 0000000000..caaf43a09f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_mean_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMeanParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteMeanParser"; + std::unique_ptr attr(new schema::MeanT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_ERROR; + } + + attr->keepDims = tflite_attr->keep_dims; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) { + MS_LOG(ERROR) << "Mean get axis attr failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mean; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteMeanParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h new file mode 100644 index 0000000000..09e926fc62 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MEAN_PARSER_H +#define PREDICT_TFLITE_MEAN_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMeanParser : public TfliteNodeParser { + public: + TfliteMeanParser() : TfliteNodeParser("Mean") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_MEAN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc new file mode 100644 index 0000000000..2ec3e0221f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_mean_pooling_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMeanPoolingParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; + std::unique_ptr attr(new schema::PoolingT()); + const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + attr->format = schema::Format_NHWC; + // attr->global + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + attr->windowW = tflite_attr->filter_width; + attr->windowH = tflite_attr->filter_height; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->padMode = GetPadMode(tflite_attr->padding); + // calculate pad params + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TfliteMeanPoolingParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h new file mode 100644 index 0000000000..9f1dca30a2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MEAN_POOLING_PARSER_H +#define PREDICT_TFLITE_MEAN_POOLING_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMeanPoolingParser : public TfliteNodeParser { + public: + TfliteMeanPoolingParser() : TfliteNodeParser("MeanPooling") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_minimum_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_minimum_parser.cc new file mode 100644 index 0000000000..245b69fd24 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_minimum_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_minimum_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMinimumParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteMinimumParser"; + std::unique_ptr attr(new schema::MinimumT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Minimum; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_minimum_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_minimum_parser.h new file mode 100644 index 0000000000..0e0bda8a4b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_minimum_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MINIMUM_PARSER_H +#define PREDICT_TFLITE_MINIMUM_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMinimumParser : public TfliteNodeParser { + public: + TfliteMinimumParser() : TfliteNodeParser("Minimum") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_MINIMUM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc new file mode 100644 index 0000000000..eebe4941f5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_model_parser.h" +#include +#include +#include +#include "tools/common/graph_util.h" +#include "tools/common/storage.h" +#include "flatbuffers/flatbuffers.h" +#include "utils/log_adapter.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +TfliteModelParser::TfliteModelParser() {} + +TfliteModelParser::~TfliteModelParser() {} + +std::unique_ptr TfliteModelParser::ReadTfliteModelFromFlat(const char *model_path) { + size_t size; + auto buf = ReadFile(model_path, &size); + if (buf == nullptr) { + MS_LOG(ERROR) << "the file buffer is nullptr"; + return nullptr; + } + flatbuffers::Verifier verify((const uint8_t *)buf, size); + if (!tflite::VerifyModelBuffer(verify)) { + MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; + return nullptr; + } + return tflite::UnPackModel(buf); +} + +std::string TfliteModelParser::GetTfliteNodeType(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto msOpType = GetMSOpType(tflite_op_type); + return msOpType; +} + +STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graphDef) { + std::vector tensors = tensor_cache.GetCachedTensor(); + for (auto iter : tensors) { + std::unique_ptr temp(iter); + temp->format = schema::Format_NHWC; + sub_graphDef->allTensors.emplace_back(move(temp)); + } + return RET_OK; +} + +STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op) { + auto dst_op = tfliteOpMap.at(tflite_op.get()); + + std::vector quant_params_index; + quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end()); + quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end()); + for (const auto &index : quant_params_index) { + const auto &tflite_tensor = tflite_subgraph->tensors[index]; + if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && + tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { + continue; + } + std::unique_ptr quant_param(new schema::QuantParamT()); + if (!tflite_tensor->quantization->scale.empty()) { + quant_param->scale = tflite_tensor->quantization->scale[0]; + } + + if (!tflite_tensor->quantization->zero_point.empty()) { + quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; + } + + if (!tflite_tensor->quantization->min.empty()) { + quant_param->min = tflite_tensor->quantization->min[0]; + } + + if (!tflite_tensor->quantization->max.empty()) { + quant_param->max = tflite_tensor->quantization->max[0]; + } + } + dst_op->quantType = schema::QuantType_AwareTrainning; + return RET_OK; +} + +STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, schema::CNodeT *op, + TensorCache *tensorCache) { + for (const auto &index : tflite_op->outputs) { + const auto &tflite_tensor = tflite_subgraph->tensors[index]; + std::unique_ptr tensor(new schema::TensorT()); + tensor->dataType = GetTfliteDataType(tflite_tensor->type); + tensor->dims = tflite_tensor->shape; + tensor->nodeType = schema::NodeType_Parameter; + auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); + op->outputIndex.emplace_back(opOutputIndex); + } + + return RET_OK; +} + +STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, TensorCache *tensorCache) { + auto op_type = GetTfliteNodeType(tflite_op, tflite_model); + std::vector op_inputs(tflite_op->inputs); + if (op_type == "DeConv2D") { + reverse(op_inputs.begin(), op_inputs.end()); + } + + for (const auto &tflite_index : op_inputs) { + const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; + auto tensor_name = tflite_tensor->name; + auto op = tfliteOpMap[tflite_op.get()]; + unsigned int index = tensorCache->FindTensor(tensor_name); + if (index != -1) { + op->inputIndex.push_back(index); + } + } + + return RET_OK; +} + +STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::MetaGraphT *subGraph, + mindspore::lite::TensorCache *tensorCache) { + auto i = 0; + for (const auto &tflite_op : tflite_subgraph->operators) { + auto opType = GetTfliteNodeType(tflite_op, tflite_model); + + std::unique_ptr op(new schema::CNodeT); + op->name = opType + "-" + std::to_string(i++); + + MS_LOG(INFO) << "parse op: [%s]" << op->name.c_str(); + + // 1. init op attr params + auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); + if (node_parser == nullptr) { + MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str(); + continue; + // return RET_NULL_PTR; + } + + auto status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, + tflite_model->operator_codes, op.get(), tensorCache, false); + if (status != RET_OK) { + MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed"; + return RET_ERROR; + } + + status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set Op " << op->name.c_str() << " Output Index Failed!"; + return RET_ERROR; + } + + subGraph->nodes.emplace_back(std::move(op)); + opMap[subGraph->nodes.back()->name] = subGraph->nodes.back().get(); + tfliteOpMap[tflite_op.get()] = subGraph->nodes.back().get(); + } + return RET_OK; +} + +void TfliteModelParser::SetInputTensor(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + TensorCache *tensor_cache) { + for (const auto &index : tflite_subgraph->inputs) { + const auto &tflite_tensor = tflite_subgraph->tensors[index]; + std::unique_ptr tensor(new schema::TensorT()); + tensor->format = schema::Format_NHWC; + tensor->dataType = GetTfliteDataType(tflite_tensor->type); + tensor->nodeType = schema::NodeType_ValueNode; + tensor->dims = tflite_tensor->shape; + tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); + } +} + +void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, + schema::MetaGraphT *subGraphDef) { + auto opGraph = OpGraphT::Build(subGraphDef); + auto graphInputs = tensorCache.GetGraphInputs(); + auto graphOutputs = opGraph->GetOutputNode(); + + subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); + + for (const auto &output : graphOutputs) { + auto op = opMap[output->ID()]; + for (auto outputIndex : op->outputIndex) { + subGraphDef->outputIndex.emplace_back(outputIndex); + } + } +} + +MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { + std::unique_ptr subGraph(new schema::MetaGraphT); + if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { + // MS_LOGE("INPUT ILLEGAL: modelFile must be *.tflite"); + return nullptr; + } + std::unique_ptr tflite_model(new tflite::ModelT()); + tflite_model = ReadTfliteModelFromFlat(modelFile.c_str()); + if (tflite_model == nullptr) { + // MS_LOGE("read tflite model failed"); + return nullptr; + } + TensorCache tensorCache; + if (tflite_model->subgraphs.size() != 1) { + MS_LOG(ERROR) << "read tflite model subgraphs failed"; + return nullptr; + } + + const auto &tflite_subgraph = tflite_model->subgraphs[0]; + subGraph->name = "MS_model converted by TF-Lite"; + + // set dst subGraph input/output tensor + SetInputTensor(tflite_model, tflite_subgraph, &tensorCache); + // set dst subGraph op attr etc. + auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "ParseOp failed."; + return nullptr; + } + + for (const auto &tflite_op : tflite_subgraph->operators) { + auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache); + if (status_tmp != RET_OK) { + // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); + } + } + + for (const auto &tflite_op : tflite_subgraph->operators) { + auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op); + if (statusTmp != RET_OK) { + // MS_LOGE("ParseTfliteQuantParams %s Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); + } + } + + SetGraphTensorIndex(tensorCache, subGraph.get()); + SetAllTensors(tensorCache, subGraph.get()); + return subGraph.release(); +// return Fb2Anf(subGraph.release()); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h new file mode 100644 index 0000000000..0ebd9a7199 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "securec/include/securec.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include "tools/common/tensor_util.h" + +#include "mindspore/lite/schema/inner/model_generated.h" + +// using namespace tflite; + +namespace mindspore { +namespace lite { +class TfliteModelParser : public ModelParser { + public: + TfliteModelParser(); + + virtual ~TfliteModelParser(); + + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile); + + private: + std::unique_ptr ReadTfliteModelFromFlat(const char *buf); + + void SetInputTensor(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, TensorCache *tensor_cache); + + void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, + schema::MetaGraphT *subGraphDef); + + STATUS ParseOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph, + TensorCache *tensor_cache); + + STATUS ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op); + + std::string GetTfliteNodeType(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model); + + STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); + + STATUS SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, + schema::CNodeT *op, + TensorCache *tensorCache); + + STATUS SetOpInputIdx(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, TensorCache *tensorCache); + + std::map opMap; + std::map tfliteOpMap; +}; +} // namespace lite +} // namespace mindspore +#endif // PREDICT_CONV +// ERTER_PARSER_TFLITE_MODEL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc new file mode 100644 index 0000000000..efdefbc743 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_mul_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMulParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteMulParser"; + std::unique_ptr attr(new schema::MulT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h new file mode 100644 index 0000000000..5514c1af4e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MUL_PARSER_H +#define PREDICT_TFLITE_MUL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMulParser : public TfliteNodeParser { + public: + TfliteMulParser() : TfliteNodeParser("Mul") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_MUL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc new file mode 100644 index 0000000000..54167499b6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "securec/include/securec.h" +#include "tools/converter/parser/tflite/tflite_node_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, + const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) { + auto count = 1; + std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); + auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); + auto buffer_idx = tflite_tensor->buffer; + if (!tfliteModelBuffer[buffer_idx]->data.empty()) { + tensor->data.resize(data_size); + auto ret = memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size); + if (ret) { + MS_LOG(ERROR) << "memcpy tensor data failed, error code: %d" << ret; + return ret; + } + } else { + MS_LOG(ERROR) << "src tensor data is empty."; + return RET_ERROR; + } + return RET_OK; +} + +STATUS TfliteNodeParser::ParseWeight(const std::vector &weight_tenosrs, + const std::vector> &tfliteModelBuffer, + mindspore::lite::TensorCache *tensor_cache, schema::Format format) { + for (const auto &weight_tensor : weight_tenosrs) { + auto idx = tensor_cache->FindTensor(weight_tensor->name); + if (idx < 0) { + std::unique_ptr tensor(new schema::TensorT); + tensor->dataType = GetTfliteDataType(weight_tensor->type); + tensor->dims = weight_tensor->shape; + tensor->nodeType = schema::NodeType_ValueNode; + // memcpy tensor data + // buffer is 0 (which refers to an always existent empty buffer) + if (weight_tensor->buffer > 0) { + CopyTfliteTensorData(tfliteModelBuffer, weight_tensor, tensor.get()); + } + MS_LOG(DEBUG) << "add weight tensor name: %s", weight_tensor->name.c_str(); + tensor_cache->AddTensor(weight_tensor->name, tensor.release(), TF_CONST); + } + } + return RET_OK; +} + +STATUS TfliteNodeParser::ParseBias(const std::vector &bias_tensors, + const std::vector> &tfliteModelBuffer, + TensorCache *tensor_cache) { + for (const auto &bias_tensor : bias_tensors) { + auto idx = tensor_cache->FindTensor(bias_tensor->name); + if (idx < 0) { + std::unique_ptr tensor(new schema::TensorT); + tensor->dataType = GetTfliteDataType(bias_tensor->type); + tensor->dims = bias_tensor->shape; + tensor->nodeType = schema::NodeType_ValueNode; + // memcpy tensor data + // buffer is 0 (which refers to an always existent empty buffer) + if (bias_tensor->buffer > 0) { + CopyTfliteTensorData(tfliteModelBuffer, bias_tensor, tensor.get()); + } + // MS_LOGD("add weight tensor name: %s", bias_tensor->name.c_str()); + tensor_cache->AddTensor(bias_tensor->name, tensor.release(), TF_CONST); + } + } + return RET_OK; +} + +TypeId TfliteNodeParser::GetTfliteDataType(const tflite::TensorType &tflite_data_type) { + static std::unordered_map type_map = { + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + }; + auto iter = type_map.find(tflite_data_type); + if (iter == type_map.end()) { + return kTypeUnknown; + } + return iter->second; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h new file mode 100644 index 0000000000..3eeea81f83 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_NODE_PARSER_H +#define PREDICT_TFLITE_NODE_PARSER_H + +#include +#include +#include +#include "utils/log_adapter.h" +#include "schema/inner/model_generated.h" +#include "tools/converter/parser/tflite/tflite_util.h" +#include "tools/converter/parser/tflite/schema_generated.h" +#include "tools/common/tensor_util.h" +#include "ir/dtype/type_id.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class TfliteNodeParser { + public: + explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {} + + virtual ~TfliteNodeParser() {} + + virtual STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) = 0; + + STATUS ParseWeight(const std::vector &weight_tenosr, + const std::vector> &tfliteModelBuffer, TensorCache *tensor_cache, + schema::Format format); + + STATUS ParseBias(const std::vector &weight_tenosr, + const std::vector> &tfliteModelBuffer, TensorCache *tensor_cache); + + STATUS ParseAttr(const std::vector &attr_tenosrs, + const std::vector> &tfliteModelBuffer, + mindspore::lite::TensorCache *tensor_cache, schema::Format format); + + STATUS CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, + const tflite::TensorT *tflite_tensor, schema::TensorT *tensor); + + TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); + + template + STATUS GetTfliteData(const int32_t tensor_index, const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + std::vector &attr_data) { + int32_t count = 1; + std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(), + [&](int32_t sha) { count *= sha; }); + auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer]; + auto data_ptr = buf_data->data.data(); + switch (tfliteTensors[tensor_index]->type) { + case tflite::TensorType_UINT8: { + for (int i = 0; i < count; i++) { + uint8_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(uint8_t); + } + break; + } + case tflite::TensorType_INT8: { + for (int i = 0; i < count; i++) { + int8_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int8_t); + } + break; + } + case tflite::TensorType_INT16: { + for (int i = 0; i < count; i++) { + int16_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int16_t); + } + break; + } + case tflite::TensorType_INT32: { + for (int i = 0; i < count; i++) { + int32_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int32_t); + } + break; + } + case tflite::TensorType_INT64: { + for (int i = 0; i < count; i++) { + int64_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int64_t); + } + break; + } + case tflite::TensorType_FLOAT32: { + for (int i = 0; i < count; i++) { + float data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(float); + } + break; + } + } + return RET_OK; + } + + protected: + bool isQuantizedModel(); + + protected: + const std::string &name; + bool quantizedModel; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_NODE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc new file mode 100644 index 0000000000..93e3974a5b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +TfliteNodeParserRegistry::TfliteNodeParserRegistry() {} + +TfliteNodeParserRegistry::~TfliteNodeParserRegistry() {} + +TfliteNodeParserRegistry *TfliteNodeParserRegistry::GetInstance() { + static TfliteNodeParserRegistry instance; + return &instance; +} + +TfliteNodeParser *TfliteNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + return nullptr; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h new file mode 100644 index 0000000000..c2b533e241 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H + +#include +#include +#include "tools/common/node_util.h" +#include "tools/converter/parser/tflite/tflite_node_parser.h" + +namespace mindspore { +namespace lite { +class TfliteNodeParserRegistry { + public: + TfliteNodeParserRegistry(); + + virtual ~TfliteNodeParserRegistry(); + + static TfliteNodeParserRegistry *GetInstance(); + + TfliteNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class TfliteNodeRegister { + public: + TfliteNodeRegister(const std::string &name, TfliteNodeParser *parser) { + TfliteNodeParserRegistry::GetInstance()->parsers[name] = parser; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_not_equal_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_not_equal_parser.cc new file mode 100644 index 0000000000..c2dafab666 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_not_equal_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_not_equal_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteNotEqualParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; + std::unique_ptr attr(new schema::NotEqualT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_NotEqual; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_not_equal_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_not_equal_parser.h new file mode 100644 index 0000000000..bf69218ae4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_not_equal_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_NOT_EQUAL_PARSER_H +#define LITE_TFLITE_NOT_EQUAL_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteNotEqualParser : public TfliteNodeParser { + public: + TfliteNotEqualParser() : TfliteNodeParser("NotEqual") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_NOT_EQUAL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc new file mode 100644 index 0000000000..8694a71840 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_one_hot_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteOneHotParser"; + std::unique_ptr attr(new schema::OneHotT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsOneHotOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + auto axis = tflite_attr->axis; + const auto tensor_shape = tfliteTensors[tfliteOp->inputs[0]].get()->shape; + if (axis < 0) { + axis += tensor_shape.size(); + } + attr->axis = axis; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_OneHot; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteOneHotParser("OneHot", new TfliteOneHotParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h new file mode 100644 index 0000000000..f21659714a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ONE_HOT_PARSER_H +#define PREDICT_TFLITE_ONE_HOT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteOneHotParser : public TfliteNodeParser { + public: + TfliteOneHotParser() : TfliteNodeParser("OneHot") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ONE_HOT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_p_relu_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_p_relu_parser.cc new file mode 100644 index 0000000000..92ec67de8f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_p_relu_parser.cc @@ -0,0 +1,45 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_p_relu_parser.h" + +namespace mindspore { +namespace lite { +STATUS TflitePreluParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "paser TflitePreluParser"; + std::unique_ptr attr(new schema::PreluT()); + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { + MS_LOG(ERROR) << "pRelu -> slope get failed"; + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Prelu; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_p_relu_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_p_relu_parser.h new file mode 100644 index 0000000000..5e79ce6914 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_p_relu_parser.h @@ -0,0 +1,40 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_P_RELU_PARSER_H +#define LITE_TFLITE_P_RELU_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TflitePreluParser : public TfliteNodeParser { + public: + TflitePreluParser() : TfliteNodeParser("Prelu") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_P_RELU_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc new file mode 100644 index 0000000000..a49e032894 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_pad_parser.h" + +namespace mindspore { +namespace lite { +STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TflitePadParser"; + std::unique_ptr attr(new schema::PadT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsPadOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->paddingMode = schema::PaddingMode_CONSTANT; + if (tfliteOp->inputs.size() > 1) { + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->paddings)) { + return RET_ERROR; + } + } + // attr->constantValue = 0.0f; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pad; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h new file mode 100644 index 0000000000..e2f0c29c7b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_PAD_PARSER_H +#define PREDICT_TFLITE_PAD_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TflitePadParser : public TfliteNodeParser { + public: + TflitePadParser() : TfliteNodeParser("Pad") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_PAD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pow_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pow_parser.cc new file mode 100644 index 0000000000..aac929ff7a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pow_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_pow_parser.h" + +namespace mindspore { +namespace lite { +STATUS TflitePowParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TflitePowParser"; + std::unique_ptr attr(new schema::PowerT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsPowOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + // the following use default values. This op is doing... + attr->power = 0.0f; + attr->scale = 0.0f; + attr->shift = 0.0f; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Power; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pow_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pow_parser.h new file mode 100644 index 0000000000..e3a6b07bf9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pow_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_POW_PARSER_H +#define PREDICT_TFLITE_POW_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TflitePowParser : public TfliteNodeParser { + public: + TflitePowParser() : TfliteNodeParser("Pow") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_POW_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc new file mode 100644 index 0000000000..ee7e66d2db --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_range_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteRangeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteRangeParser"; + std::unique_ptr attr(new schema::RangeT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsRangeOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + auto start_idx = tfliteOp->inputs[1]; + std::for_each(tfliteTensors[start_idx]->shape.begin(), tfliteTensors[start_idx]->shape.end(), [&](int32_t sha){}); + auto &start_buf_data = tfliteModelBuffer[tfliteTensors[start_idx]->buffer]; + auto start_data_ptr = start_buf_data->data.data(); + attr->start = *(static_cast(static_cast(start_data_ptr))); + + auto limit_idx = tfliteOp->inputs[2]; + std::for_each(tfliteTensors[limit_idx]->shape.begin(), tfliteTensors[limit_idx]->shape.end(), [&](int32_t sha){}); + auto &limit_buf_data = tfliteModelBuffer[tfliteTensors[limit_idx]->buffer]; + auto limit_data_ptr = limit_buf_data->data.data(); + attr->limit = *(static_cast(static_cast(limit_data_ptr))); + + if (tfliteOp->inputs.size() > 2) { + auto delta_idx = tfliteOp->inputs[3]; + std::for_each(tfliteTensors[delta_idx]->shape.begin(), tfliteTensors[delta_idx]->shape.end(), [&](int32_t sha){}); + auto &delta_buf_data = tfliteModelBuffer[tfliteTensors[delta_idx]->buffer]; + auto delta_data_ptr = delta_buf_data->data.data(); + attr->delta = *(static_cast(static_cast(delta_data_ptr))); + } else { + attr->delta = 0; // default + } + + attr->dType = 0; // default + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Range; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h new file mode 100644 index 0000000000..2701590151 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RANGE_PARSER_H +#define PREDICT_TFLITE_RANGE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteRangeParser : public TfliteNodeParser { + public: + TfliteRangeParser() : TfliteNodeParser("Range") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RANGE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc new file mode 100644 index 0000000000..bb64242230 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_rank_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteRankParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteRankParser"; + std::unique_ptr attr(new schema::RankT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsRankOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Rank; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h new file mode 100644 index 0000000000..11257b6f2b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RANK_PARSER_H +#define PREDICT_TFLITE_RANK_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteRankParser : public TfliteNodeParser { + public: + TfliteRankParser() : TfliteNodeParser("Rank") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RANK_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_real_div_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_real_div_parser.cc new file mode 100644 index 0000000000..12dbd47878 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_real_div_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_real_div_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteRealDivParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteRealDivParser"; + std::unique_ptr attr(new schema::RealDivT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_RealDiv; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_real_div_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_real_div_parser.h new file mode 100644 index 0000000000..110e813d0e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_real_div_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_REAL_DIV_PARSER_H +#define LITE_TFLITE_REAL_DIV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteRealDivParser : public TfliteNodeParser { + public: + TfliteRealDivParser() : TfliteNodeParser("RealDiv") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_REAL_DIV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_any_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_any_parser.cc new file mode 100644 index 0000000000..10f6489575 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_any_parser.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_reduce_any_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReduceAnyParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteReduceAnyParser"; + std::unique_ptr attr(new schema::ReduceT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + // attr->mode = schema::; + MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; + return RET_NOT_FIND_OP; + + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { + MS_LOG(ERROR) << "REDUCE_ANY get axes attr failed"; + return RET_ERROR; + } + attr->keepDims = tflite_attr->keep_dims; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceAnyParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_any_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_any_parser.h new file mode 100644 index 0000000000..daa9b84e0b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_any_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_REDUCE_ANY_PARSER_H +#define PREDICT_TFLITE_REDUCE_ANY_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReduceAnyParser : public TfliteNodeParser { + public: + TfliteReduceAnyParser() : TfliteNodeParser("ReduceAny") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_REDUCE_ANY_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_max_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_max_parser.cc new file mode 100644 index 0000000000..85e9dc70d6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_max_parser.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_reduce_max_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReduceMaxParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteReduceMaxParser"; + std::unique_ptr attr(new schema::ReduceT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + attr->mode = schema::ReduceMode_ReduceMax; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { + MS_LOG(ERROR) << "REDUCE_MAX get axes attr failed"; + return RET_ERROR; + } + attr->keepDims = tflite_attr->keep_dims; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteReduceMaxParser("ReduceMax", new TfliteReduceMaxParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_max_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_max_parser.h new file mode 100644 index 0000000000..9372b73584 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_max_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_REDUCE_MAX_PARSER_H +#define PREDICT_TFLITE_REDUCE_MAX_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReduceMaxParser : public TfliteNodeParser { + public: + TfliteReduceMaxParser() : TfliteNodeParser("ReduceMax") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_REDUCE_MAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_min_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_min_parser.cc new file mode 100644 index 0000000000..5c385b43a4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_min_parser.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_reduce_min_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReduceMinParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteReduceMinParser"; + std::unique_ptr attr(new schema::ReduceT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + attr->mode = schema::ReduceMode_ReduceMin; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { + MS_LOG(ERROR) << "REDUCE_MIN get axes attr failed"; + return RET_ERROR; + } + attr->keepDims = tflite_attr->keep_dims; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteReduceMinParser("ReduceMin", new TfliteReduceMinParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_min_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_min_parser.h new file mode 100644 index 0000000000..38d6598c6a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_min_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_REDUCE_MIN_PARSER_H +#define PREDICT_TFLITE_REDUCE_MIN_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReduceMinParser : public TfliteNodeParser { + public: + TfliteReduceMinParser() : TfliteNodeParser("ReduceMin") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_REDUCE_MIN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_prod_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_prod_parser.cc new file mode 100644 index 0000000000..4d40540bf8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_prod_parser.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_reduce_prod_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReduceProdParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteReduceProdParser"; + std::unique_ptr attr(new schema::ReduceT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + attr->mode = schema::ReduceMode_ReduceProd; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { + MS_LOG(ERROR) << "REDUCE_PROD get axes attr failed"; + return RET_ERROR; + } + attr->keepDims = tflite_attr->keep_dims; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteReduceProdParser("ReduceProd", new TfliteReduceProdParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_prod_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_prod_parser.h new file mode 100644 index 0000000000..cadfe9b707 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_prod_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_REDUCE_PROD_PARSER_H +#define PREDICT_TFLITE_REDUCE_PROD_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReduceProdParser : public TfliteNodeParser { + public: + TfliteReduceProdParser() : TfliteNodeParser("ReduceProd") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_REDUCE_PROD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc new file mode 100644 index 0000000000..657c3eaf17 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_relu6_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteRelu6Parser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteRelu6Parser"; + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_RELU6; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h new file mode 100644 index 0000000000..3d1f84ff0c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RELU6_PARSER_H +#define PREDICT_TFLITE_RELU6_PARSER_H + +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include +#include + +namespace mindspore { +namespace lite { +class TfliteRelu6Parser : public TfliteNodeParser { + public: + TfliteRelu6Parser() : TfliteNodeParser("Relu6") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RELU6_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_relu_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_relu_parser.cc new file mode 100644 index 0000000000..31877c8d8f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_relu_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_relu_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReluParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteReluParser"; + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_RELU; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_relu_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_relu_parser.h new file mode 100644 index 0000000000..7b67e0c4ee --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_relu_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RELU_PARSER_H +#define PREDICT_TFLITE_RELU_PARSER_H + +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include +#include + +namespace mindspore { +namespace lite { +class TfliteReluParser : public TfliteNodeParser { + public: + TfliteReluParser() : TfliteNodeParser("Relu") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RELU_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc new file mode 100644 index 0000000000..968989f0c2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_reshape_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteReshapeParser"; + std::unique_ptr attr(new schema::ReshapeT()); + + const auto &tfliteAttr = tfliteOp->builtin_options.AsReshapeOptions(); + if (tfliteAttr == nullptr) { + if (tfliteOp->inputs.size() < 2) { + MS_LOG(ERROR) << "expected two input tensors, but got: " << tfliteOp->inputs.size(); + return RET_ERROR; + } + auto shape_tensor_index = tfliteOp->inputs[1]; + const auto & shape_tensor = tfliteTensors[shape_tensor_index]; + std::vector shape_tensors{shape_tensor.get()}; + if (RET_OK != ParseWeight(shape_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse shape tensor error"; + return RET_ERROR; + } + } else { + attr->format = schema::Format_NHWC; + attr->shape.resize(tfliteAttr->new_shape.size()); + for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) { + attr->shape[i] = tfliteAttr->new_shape[i]; + } + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h new file mode 100644 index 0000000000..a122d9512f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RESHAPE_PARSER_H +#define PREDICT_TFLITE_RESHAPE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReshapeParser : public TfliteNodeParser { + public: + TfliteReshapeParser() : TfliteNodeParser("Reshape") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc new file mode 100644 index 0000000000..13c7eb48c1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_resize_bilinear_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteResizeBilinearParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser"; + std::unique_ptr attr(new schema::ResizeT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeBilinearOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + attr->format = schema::Format_NHWC; + attr->method = schema::ResizeMethod_BILINEAR; + attr->alignCorners = tfliteAttr->align_corners; + auto tfliteResizeTensorIndex = tfliteOp->inputs[1]; + auto resizeTensorBufferIndex = tfliteTensors.at(tfliteResizeTensorIndex)->buffer; + auto buffData = reinterpret_cast(tfliteModelBuffer.at(resizeTensorBufferIndex)->data.data()); + auto height = buffData[0]; + auto width = buffData[1]; + attr->newWidth = width; + attr->newHeight = height; + // attr->preserveAspectRatio + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Resize; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeBilinearParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h new file mode 100644 index 0000000000..bbe48edd51 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RESIZE_PARSER_H +#define PREDICT_TFLITE_RESIZE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteResizeBilinearParser : public TfliteNodeParser { + public: + TfliteResizeBilinearParser() : TfliteNodeParser("ResizeBilinear") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RESIZE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.cc new file mode 100644 index 0000000000..20438d63a4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteResizeNearestNeighborParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; + std::unique_ptr attr(new schema::ResizeT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeNearestNeighborOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + attr->format = schema::Format_NHWC; + attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR; + attr->alignCorners = tfliteAttr->align_corners; + auto tfliteResizeTensorIndex = tfliteOp->inputs[1]; + auto resizeTensorBufferIndex = tfliteTensors.at(tfliteResizeTensorIndex)->buffer; + auto buffData = reinterpret_cast(tfliteModelBuffer.at(resizeTensorBufferIndex)->data.data()); + auto height = buffData[0]; + auto width = buffData[1]; + attr->newWidth = width; + attr->newHeight = height; + // attr->preserveAspectRatio + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Resize; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", + new TfliteResizeNearestNeighborParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.h new file mode 100644 index 0000000000..4657a67d96 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_nearest_neighbor_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RESIZE_NN_PARSER_H +#define PREDICT_TFLITE_RESIZE_NN_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteResizeNearestNeighborParser : public TfliteNodeParser { + public: + TfliteResizeNearestNeighborParser() : TfliteNodeParser("NearestNeighbor") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RESIZE_NN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc new file mode 100644 index 0000000000..f516d7e2d1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_reverse_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteReverseParser"; + std::unique_ptr attr(new schema::ReverseT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsReverseV2Options(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) { + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reverse; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteReverseParser("Reverse", new TfliteReverseParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h new file mode 100644 index 0000000000..1db4301566 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_REVERSE_PARSER_H +#define PREDICT_TFLITE_REVERSE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReverseParser : public TfliteNodeParser { + public: + TfliteReverseParser() : TfliteNodeParser("Reverse") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_REVERSE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc new file mode 100644 index 0000000000..3658541d98 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -0,0 +1,49 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_reverse_sequence_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; + std::unique_ptr attr(new schema::ReverseSequenceT()); + const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions(); + + attr->seqAxis = tflite_attr->seq_dim; + attr->batchAxis = tflite_attr->batch_dim; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->seqLengths)) { + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ReverseSequence; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h new file mode 100644 index 0000000000..20cac753e1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H +#define LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReverseSequenceParser : public TfliteNodeParser { + public: + TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_round_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_round_parser.cc new file mode 100644 index 0000000000..3e385c8e44 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_round_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_round_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteRoundParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteRoundParser"; + std::unique_ptr attr(new schema::RoundT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Round; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_round_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_round_parser.h new file mode 100644 index 0000000000..060f4d0991 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_round_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_ROUND_PARSER_H +#define LITE_TFLITE_ROUND_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteRoundParser : public TfliteNodeParser { + public: + TfliteRoundParser() : TfliteNodeParser("Round") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_ROUND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc new file mode 100644 index 0000000000..15eac6e874 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_rsqrt_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteRsqrtParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "paser TfliteRsqrtParser"; + std::unique_ptr attr(new schema::RsqrtT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Rsqrt; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteRsqrtParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h new file mode 100644 index 0000000000..8a81b42f99 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RSQRT_PARSER_H +#define PREDICT_TFLITE_RSQRT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteRsqrtParser : public TfliteNodeParser { + public: + TfliteRsqrtParser() : TfliteNodeParser("Rsqrt") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RSQRT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc new file mode 100644 index 0000000000..a113266496 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_scatter_nd_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteScatterNdParser"; + std::unique_ptr attr(new schema::ScatterNDT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsScatterNdOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + /* + MS_LOG(DEBUG) << "op->inputIndex"; + for (auto &i : op->inputIndex) { + MS_LOG(DEBUG) << i; + } + */ + // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 + // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; + std::swap(op->inputIndex[0], op->inputIndex[2]); + std::swap(op->inputIndex[1], op->inputIndex[2]); + /* + MS_LOG(DEBUG) << "op->inputIndex after resort"; + for (auto &i : op->inputIndex) { + MS_LOG(DEBUG) << i; + } + */ + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ScatterND; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h new file mode 100644 index 0000000000..3823296885 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SCATTER_ND_PARSER_H +#define PREDICT_TFLITE_SCATTER_ND_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteScatterNdParser : public TfliteNodeParser { + public: + TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SCATTER_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc new file mode 100644 index 0000000000..eab6694a7a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_shape_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteShapeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteShapeParser"; + std::unique_ptr attr(new schema::ShapeT()); + + // tflite_attr->out_type; // this attr is dropped + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Shape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteShapeParser("Shape", new TfliteShapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h new file mode 100644 index 0000000000..b0f0fee85c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SHAPE_PARSER_H +#define PREDICT_TFLITE_SHAPE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteShapeParser : public TfliteNodeParser { + public: + TfliteShapeParser() : TfliteNodeParser("Shape") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sigmoid_parser.cc new file mode 100644 index 0000000000..fcfe4110d1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sigmoid_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_sigmoid_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSigmoidParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteSigmoidParser"; + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSigmoidParser("Sigmoid", new TfliteSigmoidParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sigmoid_parser.h new file mode 100644 index 0000000000..f291125964 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sigmoid_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SIGMOID_PARSER_H +#define PREDICT_TFLITE_SIGMOID_PARSER_H + +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include +#include + +namespace mindspore { +namespace lite { +class TfliteSigmoidParser : public TfliteNodeParser { + public: + TfliteSigmoidParser() : TfliteNodeParser("Sigmoid") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SIGMOID_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sin_parser.cc new file mode 100644 index 0000000000..d02173496e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sin_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_sin_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSinParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSinParser"; + std::unique_ptr attr(new schema::SinT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sin; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSinParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sin_parser.h new file mode 100644 index 0000000000..3b02203635 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sin_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SIN_PARSER_H +#define PREDICT_TFLITE_SIN_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSinParser : public TfliteNodeParser { + public: + TfliteSinParser() : TfliteNodeParser("Sin") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SIN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc new file mode 100644 index 0000000000..9f74bddd8f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_slice_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteSliceParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteSliceParser"; + std::unique_ptr attr(new schema::SliceT()); + + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->begin)) { + return RET_ERROR; + } + if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->size)) { + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Slice; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h new file mode 100644 index 0000000000..70c1b96da7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SLICE_PARSER_H +#define PREDICT_TFLITE_SLICE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSliceParser : public TfliteNodeParser { + public: + TfliteSliceParser() : TfliteNodeParser("Slice") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SLICE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc new file mode 100644 index 0000000000..07bb3a9507 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_softmax_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; + std::unique_ptr attr(new schema::SoftMaxT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsSoftmaxOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + // attr->axis + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SoftMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSoftmaxParser("Softmax", new TfliteSoftmaxParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h new file mode 100644 index 0000000000..685898c429 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CONV_PARSER_H +#define PREDICT_TFLITE_CONV_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSoftmaxParser : public TfliteNodeParser { + public: + TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc new file mode 100644 index 0000000000..ea3c2c8a3f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -0,0 +1,51 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; + std::unique_ptr attr(new schema::SpaceToBatchNDT()); + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { + MS_LOG(ERROR) << "spaceToBatchND -> blockShape get failed"; + return RET_ERROR; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->paddings)) { + MS_LOG(ERROR) << "spaceToBatchND -> paddings get failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SpaceToBatchND; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSpaceToBatchNDParser("SpaceToBatchND", new TfliteSpaceToBatchNDParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h new file mode 100644 index 0000000000..287f492bc6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#define LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSpaceToBatchNDParser : public TfliteNodeParser { + public: + TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc new file mode 100644 index 0000000000..e3435bc5f3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -0,0 +1,51 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_space_to_depth_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; + std::unique_ptr attr(new schema::SpaceToDepthT()); + const auto &tflite_attr = tflite_op->builtin_options.AsSpaceToDepthOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op:" << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + attr->blockSize = tflite_attr->block_size; + attr->format = schema::Format_NHWC; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSpaceToDepthParser("SpaceToDepth", new TfliteSpaceToDepthParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h new file mode 100644 index 0000000000..3adf534253 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H +#define LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSpaceToDepthParser : public TfliteNodeParser { + public: + TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc new file mode 100644 index 0000000000..86ed3d0a2a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -0,0 +1,57 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; + std::unique_ptr attr(new schema::SparseToDenseT()); + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->outputShape)) { + MS_LOG(ERROR) << "sparseToDense -> outputShape get failed"; + return RET_ERROR; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->sparseValue)) { + MS_LOG(ERROR) << "sparseToDense -> sparseValue get failed"; + return RET_ERROR; + } + if (GetTfliteData(tflite_op->inputs[3], tflite_tensors, tflite_model_buffer, attr->defaultValue)) { + MS_LOG(ERROR) << "sparseToDense -> defaultValue get failed"; + return RET_ERROR; + } + attr->validateIndices = false; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SparseToDense; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSparseToDenseParser("SparseToDense", new TfliteSparseToDenseParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h new file mode 100644 index 0000000000..ad4a6e02ce --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H +#define LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSparseToDenseParser : public TfliteNodeParser { + public: + TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc new file mode 100644 index 0000000000..eeb5d7c226 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_split_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSplitParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSplitParser"; + std::unique_ptr attr(new schema::SplitT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsSplitOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + + const auto tensor_shape = tfliteTensors[tfliteOp->inputs[1]].get()->shape; + auto axis = + *(reinterpret_cast(tfliteModelBuffer[tfliteTensors[tfliteOp->inputs[0]]->buffer]->data.data())); + if (axis < 0) { + axis += tensor_shape.size(); + } + if (axis >= tensor_shape.size()) { + MS_LOG(ERROR) << "axis value too large"; + return RET_ERROR; + } + attr->splitDim = axis; + + auto num_splits = tflite_attr->num_splits; + if (tensor_shape[axis] % num_splits != 0) { + MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; + return RET_ERROR; + } + attr->numberSplit = num_splits; + + for (int i = 0; i < num_splits; i++) { + attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Split; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSplitParser("Split", new TfliteSplitParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h new file mode 100644 index 0000000000..39210e5086 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SPLIT_PARSER_H +#define PREDICT_TFLITE_SPLIT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSplitParser : public TfliteNodeParser { + public: + TfliteSplitParser() : TfliteNodeParser("Split") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SPLIT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc new file mode 100644 index 0000000000..938ce612c2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_split_v_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSplitVParser"; + std::unique_ptr attr(new schema::SplitT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsSplitVOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + + attr->numberSplit = tflite_attr->num_splits; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->sizeSplits)) { + MS_LOG(ERROR) << "SPLIT_V get sizeSplits attr failed"; + return RET_ERROR; + } + + auto axis = + *(reinterpret_cast(tfliteModelBuffer[tfliteTensors[tfliteOp->inputs[2]]->buffer]->data.data())); + const auto tensor_shape = tfliteTensors[tfliteOp->inputs[0]].get()->shape; + if (axis < 0) { + axis += tensor_shape.size(); + } + if (axis >= tensor_shape.size()) { + MS_LOG(ERROR) << "axis value too large"; + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Split; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSplitVParser("SplitV", new TfliteSplitVParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h new file mode 100644 index 0000000000..c3eefcdec2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SPLIT_V_PARSER_H +#define PREDICT_TFLITE_SPLIT_V_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSplitVParser : public TfliteNodeParser { + public: + TfliteSplitVParser() : TfliteNodeParser("SplitV") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SPLIT_V_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sqrt_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sqrt_parser.cc new file mode 100644 index 0000000000..e27ebfd47a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sqrt_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_sqrt_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSqrtParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSqrtParser"; + std::unique_ptr attr(new schema::SqrtT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sqrt; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSqrtParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sqrt_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sqrt_parser.h new file mode 100644 index 0000000000..0b0ce97b78 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sqrt_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SQRT_PARSER_H +#define PREDICT_TFLITE_SQRT_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSqrtParser : public TfliteNodeParser { + public: + TfliteSqrtParser() : TfliteNodeParser("Sqrt") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SQRT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_square_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_square_parser.cc new file mode 100644 index 0000000000..e577ce5f9a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_square_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_square_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSquareParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSquareParser"; + std::unique_ptr attr(new schema::SquareT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Square; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSquareParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_square_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_square_parser.h new file mode 100644 index 0000000000..7e349a6b89 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_square_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_TFLITE_SQUARE_PARSER_H +#define LITE_TFLITE_SQUARE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSquareParser : public TfliteNodeParser { + public: + TfliteSquareParser() : TfliteNodeParser("Square") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_SQUARE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squared_difference_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squared_difference_parser.cc new file mode 100644 index 0000000000..18488fb734 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squared_difference_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_squared_difference_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSquaredDifferenceParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; + std::unique_ptr attr(new schema::SquaredDifferenceT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SquaredDifference; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteSquaredDifferenceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squared_difference_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squared_difference_parser.h new file mode 100644 index 0000000000..67dd05b109 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squared_difference_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_TFLITE_SQUARED_DIFFERENCE_PARSER_H +#define LITE_TFLITE_SQUARED_DIFFERENCE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSquaredDifferenceParser : public TfliteNodeParser { + public: + TfliteSquaredDifferenceParser() : TfliteNodeParser("SquaredDifference") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_SQUARED_DIFFERENCE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc new file mode 100644 index 0000000000..932acb90b0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_squeeze_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSqueezeParser"; + std::unique_ptr attr(new schema::SqueezeT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsSqueezeOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + + attr->axis = tflite_attr->squeeze_dims; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Squeeze; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSqueezeParser("Squeeze", new TfliteSqueezeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h new file mode 100644 index 0000000000..7738773856 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SQUEEZE_PARSER_H +#define PREDICT_TFLITE_SQUEEZE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSqueezeParser : public TfliteNodeParser { + public: + TfliteSqueezeParser() : TfliteNodeParser("Squeeze") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SQUEEZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc new file mode 100644 index 0000000000..43d76b3c56 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_stack_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteStackParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteStackParser"; + std::unique_ptr attr(new schema::StackT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsPackOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + attr->axis = tflite_attr->axis; + attr->n = tflite_attr->values_count; + attr->isScale.assign(tfliteTensors[tfliteOp->inputs[0]]->shape.begin(), + tfliteTensors[tfliteOp->inputs[0]]->shape.end()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Stack; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteStackParser("Stack", new TfliteStackParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h new file mode 100644 index 0000000000..db85b07828 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_STACK_PARSER_H +#define PREDICT_TFLITE_STACK_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteStackParser : public TfliteNodeParser { + public: + TfliteStackParser() : TfliteNodeParser("Stack") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_STACK_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc new file mode 100644 index 0000000000..44843a550f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_strided_slice_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteStridedSliceParser"; + std::unique_ptr attr(new schema::StridedSliceT()); + const auto &tflite_attr = tflite_op->builtin_options.AsStridedSliceOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + return RET_NULL_PTR; + } + + attr->beginMask = tflite_attr->begin_mask; + attr->endMask = tflite_attr->end_mask; + attr->ellipsisMask = tflite_attr->ellipsis_mask; + attr->newAxisMask = tflite_attr->new_axis_mask; + attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->begin)) { + MS_LOG(ERROR) << "stridedSlice -> begin get failed"; + return RET_ERROR; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->end)) { + MS_LOG(ERROR) << "stridedSlice -> end get failed"; + return RET_ERROR; + } + if (GetTfliteData(tflite_op->inputs[3], tflite_tensors, tflite_model_buffer, attr->stride)) { + MS_LOG(ERROR) << "stridedSlice -> stride get failed"; + return RET_ERROR; + } + attr->isScale.assign(tflite_tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_tensors[tflite_op->inputs[0]]->shape.end()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_StridedSlice; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteStridedSliceParser("StridedSlice", new TfliteStridedSliceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h new file mode 100644 index 0000000000..4a2b1814db --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef LITE_TFLITE_STRIDED_SLICE_PARSER_H +#define LITE_TFLITE_STRIDED_SLICE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteStridedSliceParser : public TfliteNodeParser { + public: + TfliteStridedSliceParser() : TfliteNodeParser("StridedSlice") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_STRIDED_SLICE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc new file mode 100644 index 0000000000..1a517f1490 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_sub_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSubParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteSubParser"; + std::unique_ptr attr(new schema::SubT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + } + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteSubParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h new file mode 100644 index 0000000000..f84c30c781 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SUB_PARSER_H +#define PREDICT_TFLITE_SUB_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSubParser : public TfliteNodeParser { + public: + TfliteSubParser() : TfliteNodeParser("Sub") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SUB_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.cc new file mode 100644 index 0000000000..9844d45467 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSumParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(INFO) << "parse TfliteSumParser"; + std::unique_ptr attr(new schema::ReduceT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; + return RET_NULL_PTR; + } + attr->mode = schema::ReduceMode_ReduceSum; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { + MS_LOG(ERROR) << "SUM get axes attr failed"; + return RET_ERROR; + } + attr->keepDims = tflite_attr->keep_dims; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteSumParser("Sum", new TfliteSumParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.h new file mode 100644 index 0000000000..6457be43f1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sum_parser.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SUM_PARSER_H +#define PREDICT_TFLITE_SUM_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSumParser : public TfliteNodeParser { + public: + TfliteSumParser() : TfliteNodeParser("Sum") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SUM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc new file mode 100644 index 0000000000..f78b8fa71e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_tanh_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteTanhParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteTanhParser"; + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_TANH; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h new file mode 100644 index 0000000000..38a003d87a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_TANH_PARSER_H +#define PREDICT_TFLITE_TANH_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteTanhParser : public TfliteNodeParser { + public: + TfliteTanhParser() : TfliteNodeParser("Tanh") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_TANH_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc new file mode 100644 index 0000000000..c3e259fa27 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -0,0 +1,48 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_tile_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteTileParser"; + std::unique_ptr attr(new schema::TileT()); + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->multiples)) { + MS_LOG(ERROR) << "tile -> multiples get failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Tile; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteTileParser("Tile", new TfliteTileParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h new file mode 100644 index 0000000000..48ba12053a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_TILE_PARSER_H +#define LITE_TFLITE_TILE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteTileParser : public TfliteNodeParser { + public: + TfliteTileParser() : TfliteNodeParser("Tile") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_TILE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc new file mode 100644 index 0000000000..c086eab7d4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -0,0 +1,47 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_topk_v2_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; + std::unique_ptr attr(new schema::TopKV2T()); + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) { + MS_LOG(ERROR) << "topKV2 -> k get failed"; + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_TopKV2; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteTopKV2Parser("TopKV2", new TfliteTopKV2Parser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h new file mode 100644 index 0000000000..3bf9c1dabf --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_TOPK_V2_PARSER_H +#define LITE_TFLITE_TOPK_V2_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteTopKV2Parser : public TfliteNodeParser { + public: + TfliteTopKV2Parser() : TfliteNodeParser("TopKV2") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_TOPK_V2_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc new file mode 100644 index 0000000000..829a813701 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/parser/tflite/tflite_transpose_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteTransposeParser"; + std::unique_ptr attr(new schema::TransposeT()); + + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->perm)) { + MS_LOG(ERROR) << "parse Transpose attr perm failed"; + return RET_ERROR; + } + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Transpose; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteTransposeParser("Transpose", new TfliteTransposeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h new file mode 100644 index 0000000000..f92eed0c54 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_TRANSPOSE_PARSER_H +#define PREDICT_TFLITE_TRANSPOSE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteTransposeParser : public TfliteNodeParser { + public: + TfliteTransposeParser() : TfliteNodeParser("Transpose") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_TRANSPOSE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc new file mode 100644 index 0000000000..12131d606f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -0,0 +1,49 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_unique_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteUniqueParser"; + std::unique_ptr attr(new schema::UniqueT()); + const auto &tflite_attr = tflite_op->builtin_options.AsUniqueOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + return RET_NULL_PTR; + } + + attr->outType = dtype_map[tflite_attr->idx_out_type]; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Unique; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteUniqueParser("Unique", new TfliteUniqueParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h new file mode 100644 index 0000000000..331aa5e48f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h @@ -0,0 +1,55 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_UNIQUE_PARSER_H +#define LITE_TFLITE_UNIQUE_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteUniqueParser : public TfliteNodeParser { + public: + TfliteUniqueParser() : TfliteNodeParser("Unique") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; + + private: + std::map dtype_map = { + {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, + {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, + {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, + }; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_UNIQUE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc new file mode 100644 index 0000000000..eca413ba7a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -0,0 +1,50 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_unstack_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "paser TfliteUnstackParser"; + std::unique_ptr attr(new schema::UnstackT()); + const auto &tflite_attr = tflite_op->builtin_options.AsUnpackOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + return RET_NULL_PTR; + } + + attr->num = tflite_attr->num; + attr->axis = tflite_attr->axis; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Unstack; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteUnstackParser("Unstack", new TfliteUnstackParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h new file mode 100644 index 0000000000..82729e7f38 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_UNSTACK_PARSER_H +#define LITE_TFLITE_UNSTACK_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteUnstackParser : public TfliteNodeParser { + public: + TfliteUnstackParser() : TfliteNodeParser("Unstack") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_UNSTACK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc new file mode 100644 index 0000000000..0ead307e90 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -0,0 +1,181 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_util.h" +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +std::map tfMsActivationFunctionMap{ + {tflite::ActivationFunctionType_NONE, schema::ActivationType_NO_ACTIVATION}, + {tflite::ActivationFunctionType_RELU, schema::ActivationType_RELU}, + {tflite::ActivationFunctionType_RELU6, schema::ActivationType_RELU6}, +}; + +schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { + return tfMsActivationFunctionMap.at(tfliteAFType); +} + +std::map tfMsOpTypeMap{ + {tflite::BuiltinOperator_CONV_2D, "Conv2D"}, + {tflite::BuiltinOperator_DEPTHWISE_CONV_2D, "DepthwiseConv2D"}, + {tflite::BuiltinOperator_AVERAGE_POOL_2D, "MeanPooling"}, + {tflite::BuiltinOperator_MAX_POOL_2D, "MaxPooling"}, + {tflite::BuiltinOperator_ADD, "Add"}, + {tflite::BuiltinOperator_CONCATENATION, "Concat"}, + {tflite::BuiltinOperator_RESIZE_BILINEAR, "ResizeBilinear"}, + {tflite::BuiltinOperator_RESHAPE, "Reshape"}, + {tflite::BuiltinOperator_LOGISTIC, "Logistic"}, + {tflite::BuiltinOperator_MUL, "Mul"}, + {tflite::BuiltinOperator_SOFTMAX, "Softmax"}, + {tflite::BuiltinOperator_FULLY_CONNECTED, "FullyConnected"}, + {tflite::BuiltinOperator_SLICE, "Slice"}, + {tflite::BuiltinOperator_SUB, "Sub"}, + {tflite::BuiltinOperator_TRANSPOSE, "Transpose"}, + {tflite::BuiltinOperator_PACK, "Stack"}, + {tflite::BuiltinOperator_MEAN, "Mean"}, + {tflite::BuiltinOperator_RELU6, "Relu6"}, + {tflite::BuiltinOperator_TANH, "Tanh"}, + {tflite::BuiltinOperator_RSQRT, "Rsqrt"}, + {tflite::BuiltinOperator_ARG_MAX, "Argmax"}, + {tflite::BuiltinOperator_SQUARED_DIFFERENCE, "SquaredDifference"}, + {tflite::BuiltinOperator_FAKE_QUANT, "FakeQuant"}, + {tflite::BuiltinOperator_TRANSPOSE_CONV, "DeConv2D"}, + {tflite::BuiltinOperator_PAD, "Pad"}, + {tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, "NearestNeighbor"}, + {tflite::BuiltinOperator_RELU, "Relu"}, + {tflite::BuiltinOperator_LEAKY_RELU, "LeakyRelu"}, + {tflite::BuiltinOperator_SQUEEZE, "Squeeze"}, + {tflite::BuiltinOperator_POW, "Pow"}, + {tflite::BuiltinOperator_ARG_MIN, "Argmin"}, + {tflite::BuiltinOperator_CEIL, "Ceil"}, + {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"}, + {tflite::BuiltinOperator_FILL, "Fill"}, + {tflite::BuiltinOperator_DIV, "Div"}, + {tflite::BuiltinOperator_FLOOR, "flOOR"}, + {tflite::BuiltinOperator_FLOOR_DIV, "FloorDiv"}, + {tflite::BuiltinOperator_FLOOR_MOD, "FloorMod"}, + {tflite::BuiltinOperator_GATHER, "Gather"}, + {tflite::BuiltinOperator_GATHER_ND, "GatherND"}, + {tflite::BuiltinOperator_REVERSE_V2, "reverse"}, + {tflite::BuiltinOperator_RANGE, "Range"}, + {tflite::BuiltinOperator_RANK, "Rank"}, + {tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, "LocalResponseNorm"}, + {tflite::BuiltinOperator_GATHER, "GatherV2"}, + {tflite::BuiltinOperator_EXP, "Exp"}, + {tflite::BuiltinOperator_SPLIT_V, "SplitV"}, + {tflite::BuiltinOperator_SPLIT, "Split"}, + {tflite::BuiltinOperator_BATCH_TO_SPACE_ND, "BatchToSpaceND"}, + {tflite::BuiltinOperator_STRIDED_SLICE, "StridedSlice"}, + {tflite::BuiltinOperator_ONE_HOT, "OneHot"}, + {tflite::BuiltinOperator_SHAPE, "Shape"}, + {tflite::BuiltinOperator_SQUEEZE, "Squeeze"}, + {tflite::BuiltinOperator_ABS, "Abs"}, + {tflite::BuiltinOperator_SIN, "Sin"}, + {tflite::BuiltinOperator_COS, "Cos"}, + {tflite::BuiltinOperator_LOG, "Log"}, + {tflite::BuiltinOperator_SQRT, "Sqrt"}, + {tflite::BuiltinOperator_SQUARE, "Square"}, + {tflite::BuiltinOperator_LOGICAL_NOT, "LogicalNot"}, + {tflite::BuiltinOperator_LOGICAL_AND, "LogicalAnd"}, + {tflite::BuiltinOperator_LOGICAL_OR, "LogicalOr"}, + {tflite::BuiltinOperator_HARD_SWISH, "HardSwish"}, + {tflite::BuiltinOperator_SUM, "Sum"}, + {tflite::BuiltinOperator_REDUCE_PROD, "ReduceProd"}, + {tflite::BuiltinOperator_REDUCE_MAX, "ReduceMax"}, + {tflite::BuiltinOperator_REDUCE_MIN, "ReduceMin"}, + // {tflite::BuiltinOperator_REDUCE_ANY, "ReduceAny"}, + {tflite::BuiltinOperator_SCATTER_ND, "ScatterNd"}, + {tflite::BuiltinOperator_MAXIMUM, "Maximum"}, + {tflite::BuiltinOperator_MINIMUM, "Minimum"}, + {tflite::BuiltinOperator_ADD_N, "AddN"}, + {tflite::BuiltinOperator_CAST, "Cast"}, + {tflite::BuiltinOperator_EQUAL, "Equal"}, + {tflite::BuiltinOperator_NOT_EQUAL, "NotEqual"}, + {tflite::BuiltinOperator_GREATER, "Greater"}, + {tflite::BuiltinOperator_GREATER_EQUAL, "GreaterEqual"}, + {tflite::BuiltinOperator_LESS, "Less"}, + {tflite::BuiltinOperator_LESS_EQUAL, "LessEqual"}, + {tflite::BuiltinOperator_DEPTH_TO_SPACE, "DepthToSpace"}, + {tflite::BuiltinOperator_SPACE_TO_BATCH_ND, "SpaceToBatchND"}, + {tflite::BuiltinOperator_SPACE_TO_DEPTH, "SpaceToDepth"}, + {tflite::BuiltinOperator_PRELU, "Prelu"}, + {tflite::BuiltinOperator_ROUND, "Round"}, + {tflite::BuiltinOperator_WHERE, "Where"}, + {tflite::BuiltinOperator_SPARSE_TO_DENSE, "SparseToDense"}, + {tflite::BuiltinOperator_ZEROS_LIKE, "ZerosLike"}, + {tflite::BuiltinOperator_TILE, "Tile"}, + {tflite::BuiltinOperator_TOPK_V2, "TopKV2"}, + {tflite::BuiltinOperator_REVERSE_SEQUENCE, "ReverseSequence"}, + {tflite::BuiltinOperator_UNIQUE, "Unique"}, + {tflite::BuiltinOperator_UNPACK, "Unstack"}, +}; + +std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { + auto iter = tfMsOpTypeMap.find(tfliteOpType); + if (iter == tfMsOpTypeMap.end()) { + // return "unsupported_op_type"; + return tflite::EnumNameBuiltinOperator(tfliteOpType); + } + return iter->second; +} + +std::map type_map = { + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, +}; + +TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type) { + auto iter = type_map.find(tflite_data_type); + if (iter == type_map.end()) { + return kTypeUnknown; + } + return iter->second; +} + +schema::PadMode GetPadMode(tflite::Padding tflite_padmode) { + if (tflite_padmode == tflite::Padding_SAME) { + return schema::PadMode_SAME; + } else if (tflite_padmode == tflite::Padding_VALID) { + return schema::PadMode_VALID; + } else { + return schema::PadMode_NOTSET; + } +} + +size_t GetDataTypeSize(const TypeId &data_type) { + switch (data_type) { + case TypeId::kNumberTypeFloat32: + return sizeof(float); + case TypeId::kNumberTypeFloat16: + return sizeof(float) >> 1; + case TypeId::kNumberTypeInt8: + return sizeof(int8_t); + case TypeId::kNumberTypeInt32: + return sizeof(int); + case TypeId::kNumberTypeUInt8: + return sizeof(uint8_t); + case TypeId::kNumberTypeUInt32: + return sizeof(uint32_t); + default: + MS_LOG(ERROR) << "unsupport datatype"; + } +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h new file mode 100644 index 0000000000..3598c2f3a2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_TFLITE_UTIL_H +#define MS_TFLITE_UTIL_H + + +#include +#include "utils/log_adapter.h" +#include "schema/inner/model_generated.h" +#include "tools/converter/parser/tflite/schema_generated.h" +#include "schema/inner/ops_generated.h" +#include "ir/dtype/type_id.h" + +// using namespace std; + +namespace mindspore { +namespace lite { +schema::PadMode GetPadMode(tflite::Padding tflite_padmode); + +size_t GetDataTypeSize(const TypeId &data_type); + +schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType); + +std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType); + +TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); +} // namespace lite +} // namespace mindspore + +#endif // MS_TFLITE_UTIL_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc new file mode 100644 index 0000000000..e05e7cb507 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -0,0 +1,47 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_where_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteWhereParser"; + std::unique_ptr attr(new schema::WhereT()); + + if (GetTfliteData(tflite_op->inputs[0], tflite_tensors, tflite_model_buffer, attr->condition)) { + MS_LOG(ERROR) << "where -> condition get failed"; + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Where; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteWhereParser("Where", new TfliteWhereParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h new file mode 100644 index 0000000000..3a707d0a8b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_WHERE_PARSER_H +#define LITE_TFLITE_WHERE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteWhereParser : public TfliteNodeParser { + public: + TfliteWhereParser() : TfliteNodeParser("Where") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_WHERE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc new file mode 100644 index 0000000000..1393ab5701 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -0,0 +1,43 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* distributed under the License is distributed on an AS +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include "tools/converter/parser/tflite/tflite_zeros_like_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantized_model) { + MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; + std::unique_ptr attr(new schema::ZerosLikeT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ZerosLike; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteZerosLikeParser("ZerosLike", new TfliteZerosLikeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h new file mode 100644 index 0000000000..0c656137d5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h @@ -0,0 +1,41 @@ +/** +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#ifndef LITE_TFLITE_ZEROS_LIKE_PARSER_H +#define LITE_TFLITE_ZEROS_LIKE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteZerosLikeParser : public TfliteNodeParser { + public: + TfliteZerosLikeParser() : TfliteNodeParser("ZerosLike") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + const std::vector> &tflite_opset, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantized_model) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_ZEROS_LIKE_PARSER_H diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt new file mode 100644 index 0000000000..f22f952164 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -0,0 +1,20 @@ +set(3RD_DIR ../../../third_party) +include_directories(${3RD_DIR}/protobuf/build/include) +include_directories(${3RD_DIR}/flatbuffers/include) +include_directories(${3RD_DIR}/opencv/build/include/opencv4) + +add_library(quantizer_mid OBJECT + #${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc + #${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc + ${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc + #${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc + ) + +if(ENABLE_ASAN) + target_link_libraries(quantizer_mid libasan libSecodefuzz) +endif() diff --git a/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc b/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc new file mode 100644 index 0000000000..3893b21c76 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/general_bitpacking.h" + +namespace mindspore { +namespace lite { +BitPack::BitPack(const uint8_t& bitnum) {this->bitnum = bitnum;} +void BitPack::UnPackFromUint8ToOrigin(uint8_t& n, std::queue& unpackBitData) { + int bitCount = 0; + while (bitCount < 8) { + bool a = n % 2; + n = n >> 1; + bitCount++; + unpackBitData.push(a); + } +} +void BitPack::UnPack(uint8_t bitnum, uint8_t& packedData, + std::vector &originData, std::queue& unpackBitData) { + UnPackFromUint8ToOrigin(packedData, unpackBitData); + // std::queue unpackBitTmpData; + + while (unpackBitData.size() > bitnum) { + uint32_t result = 0; + for (int k = 0; k < bitnum; k++) { + bool bitTmp = unpackBitData.front(); + result = (result << 1) + static_cast(bitTmp); + unpackBitData.pop(); + } + originData.push_back(result); + } +} +void BitPack::PackFromOriginToUint8(std::stack& ans, std::vector& packedDataVec) { + uint32_t result = 0; + for (size_t i = 0; i < 8; i++) { + bool bit_tmp = ans.top(); + result = (result << 1) + static_cast(bit_tmp); + ans.pop(); + } + packedDataVec.push_back(result); +} +void BitPack::DoBinary(uint8_t& n, std::stack& ans, std::vector& packedDataVec) { + int bitCount = 0; + while (bitCount < bitnum) { + bool a = n / (1 << (unsigned int)(bitnum - bitCount - 1)); + n = n - a * (1 << (unsigned int)(bitnum - bitCount - 1)); + bitCount++; + ans.push(a); + if (ans.size() == 8) { + PackFromOriginToUint8(ans, packedDataVec); + } + } +} + +void BitPack::BitPacking(const std::vector& originDataVec, std::vector& packedDataVec) { + std::stack bitDataVec; + for (size_t i = 0; i < originDataVec.size(); i++) { + uint8_t tmp = originDataVec[i]; + DoBinary(tmp, bitDataVec, packedDataVec); + } + + size_t remainBitData = bitDataVec.size(); + if ( 8 > remainBitData && remainBitData > 0 ) { + for ( int i = 0; i < 8 - remainBitData; i++ ) { + bitDataVec.push(0); + } + PackFromOriginToUint8(bitDataVec, packedDataVec); + } +} + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/general_bitpacking.h b/mindspore/lite/tools/converter/quantizer/general_bitpacking.h new file mode 100644 index 0000000000..284c6e028d --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/general_bitpacking.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GENERAL_BITPACKING_H +#define MINDSPORE_GENERAL_BITPACKING_H +#include +#include +#include +#include +#include + +namespace mindspore { +namespace lite { +class BitPack { + public: + explicit BitPack(const uint8_t &bitbum = 8); + ~BitPack() = default; + void BitPacking(const std::vector &originDataVec, std::vector &packedDataVec); + void UnPack(uint8_t bitnum, uint8_t &packedData, std::vector &originData, std::queue &unpackBitData); + + private: + void UnPackFromUint8ToOrigin(uint8_t &n, std::queue &unpackBitData); + void PackFromOriginToUint8(std::stack &ans, std::vector &packedDataVec); + void DoBinary(uint8_t &n, std::stack &ans, std::vector &packed_data_vec); + uint8_t bitnum; +}; +} // namespace lite +} // namespace mindspore + +#endif diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training.cc new file mode 100644 index 0000000000..10e0609add --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/post_training.cc @@ -0,0 +1,958 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "src/ir/tensor.h" +#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/converter/quantizer/post_training.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "src/common/common.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +#include "tools/common/tensor_util.h" +#include "src/common/file_utils.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { + +struct DivergInfo { + std::vector histogram; + CNodePtr cnode; + int bin_num; + float interval = 0; + float max; + float min; + float best_T = 0.0f; + size_t bit_num; + int quant_max = 255; + int quant_min = 0; + DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) { + this->cnode = cnode; + this->bin_num = bins; + this->bit_num = bits; + histogram.resize(bin_num); + max = FLT_MIN; + min = FLT_MAX; + this->quant_max = quant_max; + this->quant_min = quant_min; + std::fill(histogram.begin(), histogram.end(), 1.0e-7); + } + + STATUS RecordMaxValue(const std::vector &datas) { + for (float data : datas) { + max = std::max(data, max); + min = std::min(data, min); + } + return RET_OK; + } + + void UpdateInterval() { + auto max_value = std::max(fabs(this->max), fabs(this->min)); + this->interval = max_value / static_cast(bin_num); + } + + STATUS UpdateHistogram(const std::vector &data, const std::vector &shape) { + for (auto value : data) { + int bin_index = std::min(static_cast(std::fabs(value) / this->interval), bin_num - 1); + this->histogram[bin_index]++; + } + return RET_OK; + } + + void DumpHistogram() { + MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram"; + for (float item : this->histogram) { + std::cout << item << " "; + } + std::cout << std::endl; + } + + STATUS ComputeThreshold() { + constexpr int quant_bint_nums = 128; + int threshold = quant_bint_nums; + float min_kl = FLT_MAX; + float after_threshold_sum = std::accumulate(this->histogram.begin() + quant_bint_nums, this->histogram.end(), 0.0f); + + for (int i = quant_bint_nums; i < this->bin_num; ++i) { + std::vector quantized_histogram(quant_bint_nums, 0); + std::vector reference_histogram(this->histogram.begin(), this->histogram.begin() + i); + std::vector expanded_histogram(i, 0); + reference_histogram[i - 1] += after_threshold_sum; + after_threshold_sum -= this->histogram[i]; + + const float bin_interval = static_cast(i) / static_cast(quant_bint_nums); + + // merge i bins to target bins + for (int j = 0; j < quant_bint_nums; ++j) { + const float start = j * bin_interval; + const float end = start + bin_interval; + const int left_upper = static_cast(std::ceil(start)); + if (left_upper > start) { + const double left_scale = left_upper - start; + quantized_histogram[j] += left_scale * this->histogram[left_upper - 1]; + } + const int right_lower = static_cast(std::floor(end)); + if (right_lower < end) { + const double right_scale = end - right_lower; + quantized_histogram[j] += right_scale * this->histogram[right_lower]; + } + std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, + [&quantized_histogram, j](float item) { quantized_histogram[j] += item; }); + } + // expand target bins to i bins in order to calculate KL with reference_histogram + for (int j = 0; j < quant_bint_nums; ++j) { + const float start = j * bin_interval; + const float end = start + bin_interval; + float count = 0; + const int left_upper = static_cast(std::ceil(start)); + float left_scale = 0.0f; + if (left_upper > start) { + left_scale = left_upper - start; + if (this->histogram[left_upper - 1] != 0) { + count += left_scale; + } + } + const int right_lower = static_cast(std::floor(end)); + double right_scale = 0.0f; + if (right_lower < end) { + right_scale = end - right_lower; + if (this->histogram[right_lower] != 0) { + count += right_scale; + } + } + std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, + [&count](float item) { + if (item != 0) { + count += 1; + } + }); + if (count == 0) { + continue; + } + const float average_num = quantized_histogram[j] / count; + if (left_upper > start && this->histogram[left_upper - 1] != 0) { + expanded_histogram[left_upper - 1] += average_num * left_scale; + } + if (right_lower < end && this->histogram[right_lower] != 0) { + expanded_histogram[right_lower] += average_num * right_scale; + } + for (int k = left_upper; k < right_lower; ++k) { + if (this->histogram[k] != 0) { + expanded_histogram[k] += average_num; + } + } + } + auto KLDivergence = [](std::vector p, std::vector q) { + auto sum = 0.0f; + std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; }); + std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; }); + sum = 0.0f; + std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; }); + std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; }); + + float result = 0.0f; + const int size = p.size(); + for (int i = 0; i < size; ++i) { + if (p[i] != 0) { + if (q[i] == 0) { + result += 1.0f; + } else { + result += (p[i] * std::log((p[i]) / (q[i]))); + } + } + } + return result; + }; + const float kl = KLDivergence(reference_histogram, expanded_histogram); + if (kl < min_kl) { + min_kl = kl; + threshold = i; + } + } + MS_LOG(DEBUG) << "Best threshold bin index: " << threshold; + this->best_T = (static_cast(threshold) + 0.5f) * this->interval; + return RET_OK; + } + + std::pair GetScale() { + float max_value = this->best_T; + float min_value = -max_value; + MS_ASSERT(quant_max - quant_min != 0); + double scale = (max_value - min_value) / (quant_max - quant_min); + MS_ASSERT(scale != 0); + return std::make_pair(this->cnode, scale); + } + + std::pair GetZeropoint() { + float max_value = this->best_T; + float min_value = -max_value; + MS_ASSERT(quant_max - quant_min != 0); + float scale = (max_value - min_value) / (quant_max - quant_min); + + auto quant_min_float = static_cast(quant_min); + auto quant_max_float = static_cast(quant_max); + MS_ASSERT(scale != 0); + const float zero_point_from_min = quant_min_float - min_value / scale; + // const float zero_point_from_max = quant_max_float - max_value / scale; + int zero_point; + if (zero_point_from_min < quant_min_float) { + zero_point = quant_min; + } else if (zero_point_from_min > quant_max_float) { + zero_point = quant_max; + } else { + zero_point = static_cast(std::round(zero_point_from_min)); + } + MS_LOG(DEBUG) << "zero point:" << zero_point; + if (quant_min == 0 && quant_max == 255) { + zero_point = 128; + } else if (quant_min == -128 && quant_max == 127) { + zero_point = 0; + } + + return std::make_pair(this->cnode, zero_point); + } +}; +std::unordered_map Calibrator::GetResult( + std::unordered_map> *diverg_info) { + std::unordered_map result; + for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) { + DivergInfo *info = iter->second.get(); + auto item = info->GetScale(); + result.insert(item); + } + return result; +} +std::unordered_map Calibrator::GetZeropoint( + std::unordered_map> *mDivergInfo) { + std::unordered_map result; + for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) { + DivergInfo *info = iter->second.get(); + auto zeropoint = info->GetZeropoint(); + result.insert(zeropoint); + } + return result; +} + +std::map Calibrator::GetMinMax( + std::unordered_map> *mDivergInfo) { + std::map result; + for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) { + DivergInfo *info = iter->second.get(); + mindspore::lite::quant::MaxMin input_maxmin{}; + input_maxmin.min = info->min; + input_maxmin.max = info->max; + result[info->cnode] = input_maxmin; + } + return result; +} + +void Calibrator::Dump() { + for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { + DivergInfo *info = iter->second.get(); + info->DumpHistogram(); + } +} + +std::unordered_map> *Calibrator::GetInputDivergInfo() { + return &this->input_diverg_info_; +} + +std::unordered_map> *Calibrator::GetOutputDivergInfo() { + return &this->output_diverg_info_; +} + +STATUS Calibrator::RecordMaxValue(std::string opName, vector data, + std::unordered_map> *mDivergInfo) { + auto got = (*mDivergInfo).find(opName); + if (got != (*mDivergInfo).end()) { + ((*got).second)->RecordMaxValue(data); + } + return RET_OK; +} + +STATUS Calibrator::ComputeThreshold() { + for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) { + DivergInfo *info = iter->second.get(); + info->ComputeThreshold(); + } + // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as + for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { + DivergInfo *info = iter->second.get(); + auto cnode = info->cnode; + + bool already_computed = false; + auto input = cnode->input(1); + if (input->isa()) { + auto input_cnode = std::dynamic_pointer_cast(input); + for (const auto &output_diverg_info : output_diverg_info_) { + auto output_diverg_cnode = output_diverg_info.second->cnode; + if (output_diverg_cnode == input_cnode) { + *info = *(output_diverg_info.second); + info->cnode = cnode; + already_computed = true; + break; + } + } + } + if (!already_computed) { + info->ComputeThreshold(); + } + } + return RET_OK; +} + +STATUS Calibrator::UpdateDivergInverval(std::unordered_map> *diverg_info) { + for (auto iter = (*diverg_info).begin(); iter != (*diverg_info).end(); iter++) { + DivergInfo *info = iter->second.get(); + info->UpdateInterval(); + } + return RET_OK; +} + +STATUS Calibrator::UpdateDataFrequency(std::string op_name, vector data, vector shape, + std::unordered_map> *diverg_info) { + auto got = (*diverg_info).find(op_name); + if (got != (*diverg_info).end()) { + ((*got).second)->UpdateHistogram(data, shape); + } + return RET_OK; +} + +STATUS Calibrator::AddQuantizedOp(CNodePtr node) { + if (node == nullptr) { + MS_LOG(ERROR) << "To be quantized node is null"; + return RET_ERROR; + } + string node_name = node->fullname_with_scope(); + std::unique_ptr input_diverg = + std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); + std::unique_ptr output_diverg = + std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); + + input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); + output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); + return RET_OK; +} + +void Calibrator::AddImage(const string file) { + auto exist = [](const string file) { + struct stat buf; + return stat(file.c_str(), &buf) == 0; + }; + if (exist(file)) { + MS_LOG(INFO) << "load image: " << file; + this->images_.push_back(file); + } else { + MS_LOG(WARNING) << "Invaild image file path: " << file; + } +} + +STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTensor *tensor) const { + string path = images_[index]; + MS_LOG(INFO) << "read image: " << path; + size_t size; + char *binBuf = ReadFile(path.c_str(), &size); + + // auto *rawinputDatas = reinterpret_cast(binBuf); + // auto mobilenet_input = const_cast(rawinputDatas); + auto data = tensor->MutableData(); + memcpy(data, binBuf, size); + + // tensor->SetData(mobilenet_input); + return RET_OK; +} + +STATUS Calibrator::CollectImages() { + // check image file path + DIR *root = opendir(config_param_.image_path.c_str()); + if (root == nullptr) { + MS_LOG(ERROR) << "invalid image path: " << config_param_.image_path; + return RET_PARAM_INVALID; + } + struct dirent *image_dir = readdir(root); + int count = 0; + while (image_dir != nullptr) { + if (image_dir->d_name[0] != '.') { + const std::string file_name = config_param_.image_path + "/" + image_dir->d_name; + if (config_param_.batch_count == 0) { + this->AddImage(file_name); + count++; + } else if (count < config_param_.batch_count) { + this->AddImage(file_name); + count++; + } else { + break; + } + } + image_dir = readdir(root); + } + closedir(root); + return RET_OK; +} + +STATUS Calibrator::ReadConfig() { + if (config_path_.empty() || config_path_.length() > PATH_MAX) { + MS_LOG(ERROR) << "invalid config path!"; + return RET_PARAM_INVALID; + } + // check whether config file path is valid + char *resolved_path = new (std::nothrow) char[PATH_MAX]{0}; + if (resolved_path == nullptr) { + MS_LOG(ERROR) << "New an object failed."; + return RET_ERROR; + } + if (nullptr != realpath(config_path_.c_str(), resolved_path)) { + config_path_ = string(resolved_path); + } + std::ifstream fs(config_path_.c_str(), std::ifstream::in); + if (!fs.is_open()) { + MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_; + delete[] resolved_path; + return RET_PARAM_INVALID; + } + std::string line; + while (std::getline(fs, line)) { + auto index = line.find('='); + if (index == std::string::npos) { + MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; + delete[] resolved_path; + return RET_PARAM_INVALID; + } + auto key = line.substr(0, index); + auto value = line.substr(index + 1); + if (key == "image_path") { + config_param_.image_path = value; + } else if (key == "batch_count") { + config_param_.batch_count = std::stoul(value); + } else if (key == "thread_num") { + config_param_.thread_num = std::stoul(value); + } else { + MS_LOG(WARNING) << "unsupported parameter"; + } + } + MS_LOG(INFO) << "image_path: " << config_param_.image_path << " " + << "batch_count: " << config_param_.batch_count << " " + << "thread_num: " << config_param_.thread_num; + + delete[] resolved_path; + fs.close(); + return RET_OK; +} + +Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) + : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} + +PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type) + : Quantizer(graph) { + this->bit_num = bit_num; + this->target_type_ = target_type; + if (target_type == kNumberTypeInt8) { + quant_max = (1 << (this->bit_num - 1)) - 1; // 127 + quant_min = -(1 << (this->bit_num - 1)); // -128 + } else if (target_type == kNumberTypeUInt8) { + quant_max = (1 << this->bit_num) - 1; // 255 + quant_min = 0; + } else { + MS_LOG(ERROR) << "unsupported quant value type: " << target_type; + } + calibrator_ = std::unique_ptr(new Calibrator(path, this->bit_num, quant_max, quant_min)); + if (calibrator_ == nullptr) { + MS_LOG(ERROR) << "creat calibrator failed!"; + return; + } +} + +STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, + std::shared_ptr lite_primitive) { + if (!lite_primitive->GetInputQuantParams().empty()) { + return RET_OK; + } + schema::QuantParamT quant_param; + quant_param.scale = scale; + quant_param.zeroPoint = zeropoint; + quant_param.max = max_min->max; + quant_param.min = max_min->min; + quant_param.numBits = bit_num; + quant_param.narrowRange = false; + lite_primitive->AddInputQuantParam(quant_param); + // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT)); + return RET_OK; +} + +STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, + std::shared_ptr lite_primitive) { + if (!lite_primitive->GetOutputQuantParams().empty()) { + return RET_OK; + } + schema::QuantParamT quant_param; + quant_param.scale = scale; + quant_param.zeroPoint = zeropoint; + quant_param.max = max_min->max; + quant_param.min = max_min->min; + quant_param.numBits = bit_num; + quant_param.narrowRange = false; + lite_primitive->AddOutputQuantParam(quant_param); + // p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT)); + return RET_OK; +} + +STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) { + // const vector dims = filter->dims; + // perlayer + if (!node->isa()) { + MS_LOG(ERROR) << "not a parameter"; + return RET_PARAM_INVALID; + } + auto parameter = std::dynamic_pointer_cast(node); + ParamValueLitePtr paramValue = std::dynamic_pointer_cast(parameter->default_param()); + auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed: " << status; + return status; + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr input, AnfNodePtr weight, AnfNodePtr bias) { + if (input == nullptr || weight == nullptr || bias == nullptr) { + MS_LOG(ERROR) << "null pointer!"; + return RET_NULL_PTR; + } + + ParameterPtr weightParameterPtr = std::dynamic_pointer_cast(weight); + auto default_param = weightParameterPtr->default_param(); + auto weight_param = std::dynamic_pointer_cast(default_param); + // std::vector> weight_quant_params = weight_param->get_quant_params(); + + ParameterPtr biasParameterPtr = std::dynamic_pointer_cast(bias); + auto bias_default_param = biasParameterPtr->default_param(); + auto bias_param = std::dynamic_pointer_cast(bias_default_param); + + vector input_scales; + vector filter_scales; + vector bias_scales; + auto quant_params = input->GetInputQuantParams(); + size_t sizeX = quant_params.size(); + for (size_t i = 0; i < sizeX; i++) { + input_scales.emplace_back(quant_params[i].scale); + } + size_t sizeY = weight_param->quant_param().size(); + if (sizeX != sizeY) { + if (sizeX > 1 && sizeY > 1) { + MS_LOG(ERROR) << "input and filter's scale count cannot match!"; + return RET_ERROR; + } + } + for (size_t i = 0; i < sizeY; i++) { + auto scale = weight_param->quant_param()[i]->scale; + filter_scales.push_back(scale); + } + size_t size = std::max(sizeX, sizeY); + for (size_t i = 0; i < size; i++) { + auto scaleX = sizeX > 1 ? input_scales[i] : input_scales[0]; + auto scaleY = sizeY > 1 ? filter_scales[i] : filter_scales[0]; + bias_scales.push_back(scaleX * scaleY); + } + MS_ASSERT(!bias_scales.empty()); + size_t shape_size = bias_param->tensor_shape_size(); + + // set bias quant param + bias_param->quant_param().clear(); + for (size_t i = 0; i < bias_scales.size(); i++) { + std::unique_ptr param(new (std::nothrow) AnfQuantParam()); + param->scale = bias_scales[i]; + param->zeroPoint = 0; + bias_param->quant_param().emplace_back(std::move(param)); + } + // quant bias data + int32_t *quant_datas = new (std::nothrow) int32_t[shape_size]; + if (quant_datas == nullptr) { + MS_LOG(ERROR) << "null pointer dereferencing."; + return RET_NULL_PTR; + } + float *raw_datas = reinterpret_cast(bias_param->tensor_addr()); + double bias_scale_tmp; + for (size_t i = 0; i < shape_size; i++) { + if (bias_scales.size() == 1) { + bias_scale_tmp = bias_scales[0]; + } else { + bias_scale_tmp = bias_scales[i]; + } + auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); + quant_datas[i] = quant_data; + } + auto ret = + memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed."; + delete[] quant_datas; + return RET_ERROR; + } + delete[] quant_datas; + bias_param->set_tensor_type(kNumberTypeInt32); + return RET_OK; +} + +// STATUS PostTrainingQuantizer::reformatConvWeight(GraphDefT *graph) { +// for (auto &subGraph : graphDefT->subgraphs) { +// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) { +// OpDefT *node = (*iter).get(); +// bool isConv = false; +// kTransFilterType tansType; +// if ((*node).attr.type == OpT_Conv2D) { +// tansType = kKCHW2HWCK; +// isConv = true; +// } +// else if ((*node).attr.type == OpT_DepthwiseConv2D) { +// tansType = kCKHW2HWCK; +// isConv = true; +// } +// if (isConv) { +// auto status = TransFilterFormat(&(*subGraph.get()->allTensors.at(node->inputIndex[1])), +// tansType); +// if (status != RET_OK) { +// return status; +// } +// TensorDefT *weight = subGraph->allTensors.at(node->inputIndex[1]).get(); +// weight->format = Format_HWCK; +// PostBitPack(weight, bitNum); +// } +// } +// } +//} + +STATUS PostTrainingQuantizer::QuantNode() { + auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo()); + auto input_scale = this->calibrator_->GetResult(this->calibrator_->GetInputDivergInfo()); + auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo()); + + auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo()); + auto output_scale = this->calibrator_->GetResult(this->calibrator_->GetOutputDivergInfo()); + auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo()); + + auto cnodes = funcGraph->GetOrderedCnodes(); + for (auto &cnode : cnodes) { + auto cnode_name = cnode->fullname_with_scope(); + if (this->calibrator_->GetInputDivergInfo()->find(cnode_name) == this->calibrator_->GetInputDivergInfo()->end()) { + MS_LOG(INFO) << cnode_name << " can not do quant"; + continue; + } + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + continue; + } + + if (input_scale.find(cnode) == input_scale.end()) { + primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); + continue; + } + auto input_vec = cnode->inputs(); + auto op_name = cnode->fullname_with_scope(); + MS_LOG(INFO) << "OpName: " << op_name; + if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") { + MS_LOG(INFO) << "todo(x): "; + // int32_t qnodeOutputZeropoint = outputZeropoint[cnode]; + // p->AddAttr(kInputTensorDataType, MakeValue((int)targetType)); + } else { + // do input quant + double scale = input_scale[cnode]; + int32_t convInputzeropoint = input_zero_point[cnode]; + DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value); + // do weight quant + auto weight = cnode->input(2); + DoWeightQuant(weight); + // do bias quant + if (cnode->inputs().size() == 4) { + auto bias = cnode->input(3); + DoBiasQuant(primitiveT_value, weight, bias); + } + } + // do output quant + double OutputScale = output_scale[cnode]; + int32_t OutputZeropoint = output_zeropoint[cnode]; + DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitiveT_value); + primitiveT_value->SetQuantType(schema::QuantType_PostTraining); + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::UpdateDivergInverval() { + this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo()); + this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo()); + return RET_OK; +} + +/** + * Pre Process + * 1. generate config param + * 1.1 read config file + * 1.2 parse txt + * 2. collect image files + * 2.1 parse image files to input tensor + * 3. save quantied node + **/ +STATUS PostTrainingQuantizer::PreProcess() { + if (this->calibrator_ == nullptr) { + MS_LOG(ERROR) << "calibrator is null!"; + return RET_ERROR; + } + // 1. generate config param + STATUS status = calibrator_->ReadConfig(); + if (status != RET_OK) { + MS_LOG(ERROR) << "read proto text failed!"; + return status; + } + // 2. collect image files + status = calibrator_->CollectImages(); + if (status != RET_OK) { + MS_LOG(ERROR) << "collect images failed!"; + return status; + } + // 3. collect to be quantized operators + // from user input + QuantStrategy strategy(10); + auto cnodes = funcGraph->GetOrderedCnodes(); + for (auto cnode : cnodes) { + AnfNodePtr anf = std::dynamic_pointer_cast(cnode); + if (strategy.CanOpPostQuantized(anf)) { + MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized"; + calibrator_->AddQuantizedOp(cnode); + } + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &nodeName, + const std::vector &tensorVec) const { + if (tensorVec.size() < 1) { + MS_LOG(ERROR) << "node: " << nodeName << " input tensors is 0"; + return RET_ERROR; + } + auto *tensor = tensorVec[0]; + if (tensor->data_type() != kNumberTypeFloat32) { + //&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT + MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize" + << " tensor data_type: " << tensor->data_type(); + return RET_ERROR; + } + return RET_OK; +} + +/** + * 1. create input tensor + * 2. insert callback to session + * 3. run session + **/ +STATUS PostTrainingQuantizer::DoInference() { + for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) { + // TODO(x) when model has inputs count > 1 + // get input tensor + vector inputs = session_->GetInputs(); + if (inputs.size() > 1) { + MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " >1"; + return RET_ERROR; + } + STATUS status = calibrator_->GenerateInputData(i, inputs.front()); + if (status != RET_OK) { + MS_LOG(ERROR) << "generate input data from images failed!"; + return RET_ERROR; + } + /** + * struct CallBackParam { + std::string nodeType; + NODE_ID nodeName; + std::unordered_set depends; + int opExecResult; + }; + */ + mindspore::session::KernelCallBack beforeCallBack = + [&](const std::vector &beforeInputs, + const std::vector &beforeOutputs, + const mindspore::session::CallBackParam &callParam) -> bool { + if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) { + return false; + } + auto tensor = beforeInputs[0]; + const float *tData = static_cast(tensor->MutableData()); + size_t shapeSize = tensor->ElementsNum(); + vector data(tData, tData + shapeSize); + this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetInputDivergInfo()); + return true; + }; + // func + mindspore::session::KernelCallBack afterCallBack = [&]( + const std::vector &afterInputs, + const std::vector &afterOutputs, + const mindspore::session::CallBackParam &callParam) -> bool { + if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) { + return false; + } + auto tensor = afterOutputs[0]; + const float *tensor_data = static_cast(tensor->MutableData()); + size_t shape_size = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + shape_size); + this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetOutputDivergInfo()); + return true; + }; + status = session_->RunGraph(beforeCallBack, afterCallBack); + if (status != RET_OK) { + MS_LOG(ERROR) << "run model failed!"; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::CollectDataFrequency() { + for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) { + // TODO(x) when model has inputs count > 1 + // get input tensor + vector inputs = session_->GetInputs(); + if (inputs.size() > 1) { + MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " > 1"; + return RET_ERROR; + } + STATUS status = calibrator_->GenerateInputData(i, inputs.front()); + if (status != RET_OK) { + MS_LOG(ERROR) << "generate input data from images failed!"; + return RET_ERROR; + } + + mindspore::session::KernelCallBack beforeCallBack = + [&](const std::vector &beforeInputs, + const std::vector &beforeOutputs, + const mindspore::session::CallBackParam &callParam) { + if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) { + return false; + } + auto tensor = beforeInputs[0]; + const float *tensor_data = static_cast(tensor->MutableData()); + size_t shape_size = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + shape_size); + this->calibrator_->UpdateDataFrequency(callParam.name_callback_param, data, tensor->shape(), + this->calibrator_->GetInputDivergInfo()); + return true; + }; + + mindspore::session::KernelCallBack afterCallBack = + [&](const std::vector &after_inputs, + const std::vector &after_outputs, + const mindspore::session::CallBackParam &call_param) { + if (PostTrainingQuantizer::CheckTensorVec(call_param.name_callback_param, after_outputs) != RET_OK) { + return false; + } + auto tensor = after_outputs[0]; + const float *tenosr_data = static_cast(tensor->MutableData()); + size_t shape_size = tensor->ElementsNum(); + vector data(tenosr_data, tenosr_data + shape_size); + this->calibrator_->UpdateDataFrequency(call_param.name_callback_param, data, tensor->shape(), + this->calibrator_->GetOutputDivergInfo()); + return true; + }; + status = session_->RunGraph(beforeCallBack, afterCallBack); + if (status != RET_OK) { + MS_LOG(ERROR) << "run model failed!"; + return RET_ERROR; + } + } + + return RET_OK; +} + +STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); } + +STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) { + MS_LOG(INFO) << "start to parse config file"; + STATUS status = PreProcess(); + if (status != RET_OK) { + MS_LOG(ERROR) << "do pre process failed!"; + return status; + } + MS_LOG(INFO) << "start create session"; + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph)); + builder.Finish(offset); + size_t size = builder.GetSize(); + auto *content = reinterpret_cast(builder.GetBufferPointer()); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer nullptr"; + return RET_ERROR; + } + auto model = lite::Model::Import(content, size); + + Context ctx; + ctx.device_ctx_.type = DT_CPU; + ctx.thread_num_ = calibrator_->GetThreadNum(); + ctx.cpu_bind_mode_ = MID_CPU; + + session_ = dynamic_cast(session::LiteSession::CreateSession(&ctx)); + if (session_ == nullptr) { + MS_LOG(ERROR) << "create session failed!"; + return RET_ERROR; + } + + auto ret = session_->CompileGraph(model.get()); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "compile graph error"; + return RET_ERROR; + } + + MS_LOG(INFO) << "start to update divergence's max value"; + status = DoInference(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "start to update divergence's interval"; + status = UpdateDivergInverval(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "start to collect data's distribution"; + status = CollectDataFrequency(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "compute the best threshold"; + status = ComputeThreshold(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "start to generate quant param and quantize tensor's data"; + status = QuantNode(); + if (status != RET_OK) { + return status; + } + return RET_OK; +} +} // namespace quant +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/post_training.h b/mindspore/lite/tools/converter/quantizer/post_training.h new file mode 100644 index 0000000000..d000df53ef --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/post_training.h @@ -0,0 +1,160 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POSTRAINING_QUANTIZER_H +#define POSTRAINING_QUANTIZER_H + +#include +#include +#include +#include +#include +#include +#include "src/lite_session.h" +#include "tools/converter/quantizer/quantizer.h" +#include "src/ir/primitive_t_value.h" +#include "tools/converter/converter.h" +#include "include/ms_tensor.h" + +namespace mindspore { +namespace lite { +namespace quant { +class Calibrator; + +struct MaxMin { + public: + float min; + float max; +}; + +enum ImageFormat { + RGB = 0, + GRAY = 1, + BGR = 2, +}; + +struct ConfigParam { + // ImageFormat imageFormat; + std::string image_path; + uint32_t batch_count; + uint32_t thread_num; +}; + +class PostTrainingQuantizer : public Quantizer { + public: + PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8); + + STATUS DoQuantize(FuncGraphPtr funcGraph) override; + + size_t bit_num; + int quant_max{255}; + int quant_min{0}; + + private: + TypeId target_type_{kNumberTypeInt8}; + + std::unique_ptr calibrator_; + + mindspore::lite::LiteSession *session_; + + STATUS PreProcess(); + + STATUS CheckTensorVec(const std::string &nodeName, const std::vector &tensorVec) const; + + STATUS DoInference(); + + STATUS UpdateDivergInverval(); + + STATUS CollectDataFrequency(); + + STATUS ComputeThreshold(); + + STATUS QuantNode(); + + // STATUS reformatConvWeight(GraphDefT *graph); + + STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + + STATUS DoWeightQuant(AnfNodePtr node); + + STATUS DoBiasQuant(std::shared_ptr input, AnfNodePtr weight, AnfNodePtr bias); +}; + +struct DivergInfo; + +class Calibrator { + public: + explicit Calibrator(std::string path, size_t quant_size, int quant_max, int quant_msin); + + ~Calibrator() = default; + + STATUS ReadConfig(); + + STATUS CollectImages(); + + STATUS GenerateInputData(int index, mindspore::tensor::MSTensor *tensor) const; + + size_t GetBatchNum() const { return images_.size(); } + + uint32_t GetThreadNum() const { return config_param_.thread_num; } + + STATUS AddQuantizedOp(CNodePtr node); + + STATUS RecordMaxValue(std::string opName, std::vector data, + std::unordered_map> *diverg_info); + + STATUS UpdateDivergInverval(std::unordered_map> *diverg_info); + + STATUS UpdateDataFrequency(std::string op_name, std::vector data, std::vector shape, + std::unordered_map> *diverg_info); + void Dump(); + + STATUS ComputeThreshold(); + + std::unordered_map GetResult( + std::unordered_map> *diverg_info); + + std::unordered_map GetZeropoint( + std::unordered_map> *diverg_info); + + std::map GetMinMax(std::unordered_map> *diverg_info); + + std::unordered_map> *GetInputDivergInfo(); + + std::unordered_map> *GetOutputDivergInfo(); + + private: + std::vector images_; + + std::string config_path_; + + ConfigParam config_param_; + + std::unordered_map> input_diverg_info_; + + std::unordered_map> output_diverg_info_; + + size_t bit_num_; + int quant_max_; + int quant_min_; + + void AddImage(std::string file); +}; +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif // POSTRAINING_QUANTIZER_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc new file mode 100644 index 0000000000..0d4857c182 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "mindspore/lite/tools/converter/quantizer/quant_cast.h" +#include +#include +#include "mindspore/lite/src/ir/primitive_t_value.h" + +namespace mindspore::lite::quant { + +ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector &quant_params) { + std::unique_ptr primitive = std::make_unique(); + schema::QuantDTypeCastT quant_dtype_cast; + quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8; + quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; + primitive->value.Set(quant_dtype_cast); + auto primTValue = std::make_shared(primitive.release()); + for (auto &quant_param : quant_params) { + primTValue->AddInputQuantParam(quant_param); + } + return NewValueNode(primTValue); +} + +STATUS QuantCast::Run(FuncGraphPtr graph) { + MS_ASSERT(graph != nullptr); + + auto cnodes = graph->GetOrderedCnodes(); + bool first = true; + + for (auto &cnode : cnodes) { + auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto curnode_quant_type = schema::QuantType_QUANT_NONE; + if (primitiveT_value == nullptr) { + MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); + } else { + curnode_quant_type = primitiveT_value->GetQuantType(); + } + if (first) { + if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { + auto value_node = + NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams()); + std::vector op_inputs = {value_node, cnode->input(1)}; + auto quant_cast_cnode = graph->NewCNode(op_inputs); + quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); + cnode->set_input(1, quant_cast_cnode); + MS_LOG(DEBUG) << "Add quant cast at front. " + << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type; + } + first = false; + continue; + } + + for (int i = 1; i < cnode->inputs().size(); i++) { + auto input_node = cnode->input(i); + if (!input_node->isa()) { + continue; + } + auto input_cnode = std::dynamic_pointer_cast(input_node); + auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); + if (input_cnode_primitiveT_value == nullptr) { + MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " + << " PrimitiveTValue is null"; + continue; + } + auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType(); + + if (curnode_quant_type != input_cnode_quant_type) { + ValueNodePtr value_node = nullptr; + if (curnode_quant_type == schema::QuantType_PostTraining && + input_cnode_quant_type == schema::QuantType_QUANT_NONE) { + value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, + input_cnode_primitiveT_value->GetInputQuantParams()); + } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && + input_cnode_quant_type == schema::QuantType_PostTraining) { + value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32, + input_cnode_primitiveT_value->GetInputQuantParams()); + } + if (value_node == nullptr) { + MS_LOG(WARNING) << "value_node is null! " + << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " + << " input_" << i << ": " << input_cnode->fullname_with_scope() + << " quant_type:" << input_cnode_quant_type; + continue; + } + std::vector op_inputs = {value_node, input_cnode}; + auto quant_cast_cnode = graph->NewCNode(op_inputs); + quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); + cnode->set_input(i, quant_cast_cnode); + MS_LOG(DEBUG) << "Add quant cast. " + << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type + << " input_" << i << ": " << input_cnode->fullname_with_scope() + << " quant_type:" << input_cnode_quant_type; + } else { + MS_LOG(DEBUG) << "No need to add quant cast. " + << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type + << " input_" << i << ": " << input_cnode->fullname_with_scope() + << " quant_type:" << input_cnode_quant_type; + } + } + } + return RET_OK; +} + +} // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.h b/mindspore/lite/tools/converter/quantizer/quant_cast.h new file mode 100644 index 0000000000..2349dc37bb --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_QUANT_CAST_H +#define LITE_QUANT_CAST_H + +#include "mindspore/core/ir/anf.h" +#include "mindspore/lite/include/errorcode.h" +#include "mindspore/core/ir/dtype/type_id.h" +#include "mindspore/core/ir/func_graph.h" + +namespace mindspore::lite::quant { + +class QuantCast { + public: + QuantCast() = default; + STATUS Run(FuncGraphPtr graph); + void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; } + + private: + TypeId inputDataDType = kNumberTypeFloat32; +}; + +} // namespace mindspore::lite::quant + +#endif // LITE_QUANT_CAST_H diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc new file mode 100644 index 0000000000..0e0ba14f93 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -0,0 +1,333 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include "src/ir/primitive_t_value.h" +#include "mindspore/lite/tools/converter/quantizer/quantize_util.h" +#include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h" +#include "src/common/utils.h" +#include "abstract/abstract_value.h" +#include "securec/include/securec.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { +const std::array QuantStrategy::mConvTypes = { + {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}}; +const std::array QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}}; + +QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold) + : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {} + +bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { + size_t i = 0; + for (i = 0; i < mConvTypes.size(); i++) { + if (node->fullname_with_scope().find(mConvTypes[i]) == 0) { + break; + } + } + + if ((i == mConvTypes.size()) || (node->size() < 3)) { + return false; + } + + auto inputNode = node->input(2); + if (!inputNode->isa()) { + return false; + } + auto paramNode = inputNode->cast(); + auto abstract_base = paramNode->abstract(); + if (abstract_base == nullptr) { + return false; + } + + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + size_t shapeSize = 1; + for (auto dim : weight_shape) { + shapeSize = shapeSize * dim; + } + if (shapeSize < mWeightSize) { + MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; + return false; + } + if (weight_shape[0] <= mConvWeightQuantChannelThreshold) { + MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0]; + return false; + } + + return true; +} + +bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { + if (!node->isa()) { + return false; + } + auto cnode = std::dynamic_pointer_cast(node); + + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); + return false; + } + + auto type = primitiveT_value->GetPrimitiveT()->value.type; + MS_LOG(INFO) << "Primitive type: " << type; + static const std::vector uint8OpList = { + schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, + schema::PrimitiveType_Activation}; + return IsContain(uint8OpList, type); +} + +bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const { + size_t i = 0; + for (i = 0; i < mMulTypes.size(); i++) { + if (node->fullname_with_scope().find(mMulTypes[i]) == 0) { + break; + } + } + if (i == mMulTypes.size()) { + return false; + } + + if (node->size() < 3) { + MS_LOG(INFO) << "input size less!"; + return false; + } + + auto inputNode1 = node->input(1); + auto inputNode2 = node->input(2); + if (inputNode1 == nullptr || inputNode2 == nullptr) { + MS_LOG(INFO) << "mul input is nullptr!"; + return false; + } + + ParameterPtr paramNode = nullptr; + if (inputNode1->isa()) { + paramNode = inputNode1->cast(); + } else if (inputNode2->isa()) { + paramNode = inputNode2->cast(); + } + + if (paramNode == nullptr) { + MS_LOG(INFO) << "invalid paramNode!"; + return false; + } + + auto abstract_base = paramNode->abstract(); + if (abstract_base == nullptr) { + MS_LOG(INFO) << "abstract is nullptr"; + return false; + } + + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + size_t shapeSize = 1; + for (auto dim : weight_shape) { + shapeSize = shapeSize * dim; + } + if (shapeSize < mWeightSize) { + MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; + return false; + } + + return true; +} + +void CalFakeNode(const AnfNodePtr &inTensor) { + // MS_ASSERT(inTensor != nullptr); + // MS_ASSERT(inTensor->dataType == DataType_DT_FLOAT); + // auto quantParam = GetTensorQuantParams(inTensor); + // if (quantParam == nullptr || !quantParam->inited) { + // MS_LOGW("tensor quantParam has not been inited"); + // return; + // } + + // float quantMin = quantParam->narrowRange ? 1 : 0; + // float quantMax = (1 << (unsigned int)(quantParam->numBits)) - 1; + // const float scale = quantParam->scale; + // const float nudgedMin = (quantMin - quantParam->zeroPoint) * scale; + // const float nudgedMax = (quantMax - quantParam->zeroPoint) * scale; + // // cal output + // float invNudgeScale = 1.0f / scale; + // void *inData = inTensor->data.data(); + // if(inData == nullptr) { + // MS_LOGE("null pointer dereferencing."); + // return; + // } + // auto *data = static_cast(inData); + // for (size_t i = 0; i < GetShapeSize(*inTensor); i++) { + // float clamped = std::min(nudgedMax, std::max(nudgedMin, data[i])); + // float clampedShifted = clamped - nudgedMin; + // data[i] = std::round(clampedShifted * invNudgeScale) * scale + nudgedMin; + // } +} + +STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, + double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) { + MS_ASSERT(quantParam != nullptr); + if (mMin > 0.0f) { + MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; + mMin = 0.0f; + } + if (mMax < 0.0f) { + MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; + mMax = 0.0f; + } + if (mMin > mMax) { + MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; + return RET_PARAM_INVALID; + } + if (mMin == mMax) { + if (mMin != 0.0f) { + MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; + return RET_ERROR; + } + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = 0.0f; + quantParam->zeroPoint = 0; + quantParam->narrowRange = narrowRange; + quantParam->numBits = num_bits; + return RET_OK; + } + + auto quantMinFloat = static_cast(quant_min); + auto quantMaxFloat = static_cast(quant_max); + double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); + const double zeroPointFromMin = quantMinFloat - mMin / scale; + // const double zeroPointFromMax = quantMaxFloat - mMax / scale; + int zeroPoint = static_cast(std::round(zeroPointFromMin)); + + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + MS_ASSERT(zeroPoint >= quantMin); + MS_ASSERT(zeroPoint <= quantMax); + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = scale; + quantParam->zeroPoint = zeroPoint; + quantParam->narrowRange = narrowRange; + quantParam->numBits = num_bits; + + return RET_OK; +} + +STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) { + auto dims = weightPtr->tensor_shape(); + if (dims.size() < 1) { + MS_LOG(ERROR) << "weight dims size error"; + return RET_ERROR; + } + uint32_t channels = dims[0]; + if (channels == 0) { + MS_LOG(ERROR) << "channels error 0"; + return RET_ERROR; + } + + size_t shapeSize = weightPtr->tensor_shape_size(); + size_t oneFilterSize = shapeSize / channels; + auto *rawDatas = reinterpret_cast(weightPtr->tensor_addr()); + if (rawDatas == nullptr) { + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; + } + + weightPtr->quant_param().clear(); + vector qDatas(shapeSize); + for (uint32_t i = 0; i < channels; i++) { + float min = 0; + float max = 0; + // find min and max + for (uint32_t j = 0; j < oneFilterSize; j++) { + min = std::min(min, rawDatas[j + i * oneFilterSize]); + max = std::max(max, rawDatas[j + i * oneFilterSize]); + } + + std::unique_ptr quantParam = std::unique_ptr(new AnfQuantParam); + STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + // update data and datatype + for (uint32_t j = 0; j < oneFilterSize; j++) { + float rawData = rawDatas[j + i * oneFilterSize]; + auto qData = QuantizeData(rawData, quantParam.get()); + qDatas[j + i * oneFilterSize] = qData; + } + + weightPtr->set_quant_param(quantParam); + } + auto ret = memcpy_s(const_cast(rawDatas), weightPtr->tensor_size(), + qDatas.data(), shapeSize * sizeof(uint8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + if (quantType == QuantType_WeightQuant) { + PostBitPack(const_cast(rawDatas), shapeSize, bitNum); + } + + weightPtr->set_tensor_type(kNumberTypeInt8); + weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); + + return RET_OK; +} + +STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { + auto *rawDatas = reinterpret_cast(weight); + vector qDatas(rawDatas, rawDatas + shapeSize); + vector qDatas_packed; + if (bitNum < 8 && bitNum > 1) { + BitPack weight_bitpack(bitNum); + weight_bitpack.BitPacking(qDatas, qDatas_packed); + if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; + return RET_ERROR; + } + } else if (bitNum == 8) { + if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; + return RET_ERROR; + } + + return RET_OK; +} +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h new file mode 100644 index 0000000000..b310d586ae --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef QUANTIZER_UTIL_H +#define QUANTIZER_UTIL_H + +#include +#include +#include +#include +#include "include/errorcode.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "ir/primitive.h" +#include "abstract/dshape.h" +#include "mindspore/lite/tools/converter/quantizer/quantizer.h" + +namespace mindspore { +namespace lite { +namespace quant { + +static constexpr size_t UINT8_QUANTIZATION = 8; + +/** + * 1. when op's weight size > mWeightSize just skip + * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization + * 3. when conv/deconv/convdepthwise/deconvdepthwise ops' weight channel size > covWeightQuantChannelThreshold just skip + * */ +class QuantStrategy { + public: + explicit QuantStrategy(size_t weightSize, size_t covWeightQuantChannelThreshold = 16); + + ~QuantStrategy() = default; + + bool CanConvOpQuantized(const CNodePtr &node) const; + bool CanMulOpQuantized(const CNodePtr &node) const; + bool CanOpPostQuantized(AnfNodePtr &node) const; + + private: + size_t mWeightSize; + size_t mConvWeightQuantChannelThreshold; + + static const std::array mConvTypes; + static const std::array mMulTypes; +}; + +STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, double mMax, + bool narrowRange, int quant_max, int quant_min, int num_bits); + +template +T QuantizeData(const float originData, const AnfQuantParam *quantParam) { + MS_ASSERT(quantParam != nullptr); + MS_ASSERT(quantParam->inited); + const auto scale = quantParam->scale; + const auto zeroPoint = quantParam->zeroPoint; + const auto numBit = quantParam->numBits; + const auto narrowRange = quantParam->narrowRange; + const double maxLimit = static_cast((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; + double minLimit; + if (narrowRange) { + minLimit = static_cast(1 - zeroPoint) * scale; + } else { + minLimit = static_cast(0 - zeroPoint) * scale; + } + return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { + double tmp = 0.0f; + if (originData > maxLimit) { + tmp = maxLimit; + } else if (originData < minLimit) { + tmp = minLimit; + } else { + tmp = originData; + } + auto quantData = static_cast(std::round(tmp / scale + zeroPoint)); + if (quantData == 0 && narrowRange) { + quantData++; + } + return quantData; + }(); +} + +void CalFakeNode(const AnfNodePtr &inTensor); + +STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, + size_t bitNum = UINT8_QUANTIZATION); + +STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); + +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.cc b/mindspore/lite/tools/converter/quantizer/quantizer.cc new file mode 100644 index 0000000000..3480705c62 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantizer.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/quantizer/quantizer.h" + +namespace mindspore { +namespace lite { +namespace quant { +Quantizer::Quantizer(FuncGraphPtr graph) : funcGraph(graph) { + if (funcGraph == nullptr) { + return; + } +} + +STATUS Quantizer::GenerateQuantParam() { return RET_OK; } + +STATUS Quantizer::RemoveFakeQuant() { return RET_OK; } + +STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; } +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h new file mode 100644 index 0000000000..19284052f3 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_QUANTIZER_H +#define MS_QUANTIZER_H + +#include +#include "include/errorcode.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "src/param_value_lite.h" + +namespace mindspore { +namespace lite { +namespace quant { +using STATUS = int; +enum QuantType { + QuantType_QUANT_NONE = 0, + QuantType_AwareTraining = 1, + QuantType_WeightQuant = 2, + QuantType_PostTraining = 3, + QuantType_MIN = QuantType_QUANT_NONE, + QuantType_MAX = QuantType_PostTraining +}; + +class Quantizer { + public: + explicit Quantizer(FuncGraphPtr graph); + + ~Quantizer() = default; + + virtual STATUS RemoveFakeQuant(); + + virtual STATUS GenerateQuantParam(); + + virtual STATUS DetermineNodeQuantType(); + + virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; + + protected: + FuncGraphPtr funcGraph = nullptr; +}; +} // namespace quant +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc new file mode 100644 index 0000000000..362ff693f5 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/weight_quantizer.h" +#include +#include +#include "src/common/common.h" +#include "ir/dtype/type_id.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { + +WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, + const std::string &convWeightChannelThreshold, const std::string &bitNum) + : Quantizer(graph) { + auto quantSize = static_cast(std::stoull(weightSize)); + this->bitNum = static_cast(std::stoull(bitNum)); + auto convQuantWeightChannelThreshold = static_cast(std::stoull(convWeightChannelThreshold)); + // TODO(...): update stractory + mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold)); +} + +// uint32_t GetConvChannel(TensorDefT *weight) { +// uint32_t channel = 0; +// const vector dims = weight->dims; + +// switch (weight->format) { +// case Format_NCHW: +// case Format_KCHW: +// case Format_NC4HW4: +// channel = static_cast(dims[NCHW_N]); +// break; +// case Format_NHWC: +// case Format_HWKC: +// channel = static_cast(dims[NHWC_N]); +// break; +// case Format_HWCK: +// channel = static_cast(dims[HWCK_K]); +// break; +// case Format_CKHW: +// channel = static_cast(dims[CKHW_K]); +// break; +// default: +// MS_LOGE("Unsupported format: %d", weight->format); +// return 0; +// } +// return channel; +// } + +STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { + for (auto &cnode : nodes) { + if (!mStrategy->CanConvOpQuantized(cnode)) { + continue; + } + + auto inputNode = cnode->input(2); + if (!inputNode->isa()) { + return RET_ERROR; + } + + auto paramNode = inputNode->cast(); + if (!paramNode->has_default()) { + return RET_ERROR; + } + + ParamValueLitePtr paramValue = std::static_pointer_cast(paramNode->default_param()); + auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed : " << status; + return status; + } + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { + for (auto &node : nodes) { + if (!mStrategy->CanMulOpQuantized(node)) { + continue; + } + + ParamValueLitePtr paramValue = nullptr; + for (size_t i = 1; i < node->size(); i++) { + auto inputNode = node->input(i); + if (inputNode->isa() == true) { + auto paramNode = inputNode->cast(); + if ((paramNode != nullptr) && (paramNode->has_default() == true)) { + paramValue = std::static_pointer_cast(paramNode->default_param()); + if ((paramValue == nullptr) || (paramValue->tensor_size() == 0) + || (paramValue->tensor_shape().size() != 4) + || (paramValue->tensor_addr() == nullptr) + || (paramValue->tensor_type() != mindspore::kNumberTypeFloat32)) { + paramValue = nullptr; + continue; + } else { + break; + } + } + } + } + if (paramValue == nullptr) { + MS_LOG(ERROR) << "No valid input param node !"; + continue; + } + auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "QunatFilter failed" << status; + return RET_ERROR; + } + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) { + auto ret = RET_OK; + auto cnodes = funcGraph->GetOrderedCnodes(); + ret = DoConvQuantize(cnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; + return ret; + } + ret = DoMulQuantize(cnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoMulQuantize failed :" << ret; + return ret; + } + return ret; +} +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h new file mode 100644 index 0000000000..0726dd3df1 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef WEIGHT_QUANTIZER_H +#define WEIGHT_QUANTIZER_H + +#include +#include +#include +#include "tools/converter/quantizer/quantizer.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace lite { +namespace quant { +class WeightQuantizer : public Quantizer { + public: + WeightQuantizer(FuncGraphPtr graph, const std::string& weightSize, + const std::string& covWeightChannelThreshold, const std::string& bitNum); + + ~WeightQuantizer() = default; + + STATUS DoQuantize(FuncGraphPtr funcGraph) override; + STATUS DoConvQuantize(const std::list &nodes); + STATUS DoMulQuantize(const std::list &nodes); + + private: + std::unique_ptr mStrategy; + size_t bitNum; +}; +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif + diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc new file mode 100644 index 0000000000..fe091a4463 --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -0,0 +1,333 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/common/gllo_utils.h" +#include +#include "src/ir/primitive_t_value.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr auto kAnfPrimitiveIndex = 0; +bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); +} + +bool IsRealKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real kernel too + if (!node->isa()) { + return true; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString(); + } + auto input = cnode->inputs()[0]; + bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || + IsPrimitive(input, prim::kPrimTensorSummary) || + IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || + IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || + IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || + IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); + return !is_virtual_node; +} + +ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + return nullptr; +} + +CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) { + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + return nullptr; +} + +VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { + if (utils::isa(graph)) { + MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); + return std::make_shared(utils::cast(sexp), nullptr); + } + if (utils::isa(graph)) { + MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); + return std::make_shared(utils::cast(sexp), utils::cast(graph)); + } + MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); + return nullptr; +} + +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph) { + MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + std::vector input_nodes; + const auto &tuple = utils::cast(sexp); + if (multigraph && utils::isa(graph)) { + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); + input_nodes.push_back(node); + } + VarPtr var_ptr = utils::cast(graph); + return std::make_shared(input_nodes, var_ptr); + } + + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); + input_nodes.push_back(node); + } + return CreateCNodeWithGraph(input_nodes, graph); +} +} // namespace + +bool AnfEqual(const BaseRef &a, const BaseRef &b) { + if (utils::isa(a) && utils::isa(b)) { + auto a_node = utils::cast(a); + auto b_node = utils::cast(b); + MS_EXCEPTION_IF_NULL(a_node); + MS_EXCEPTION_IF_NULL(b_node); + if (IsValueNode(a_node) && IsValueNode(b_node)) { + auto a_value_node = a_node->cast(); + MS_EXCEPTION_IF_NULL(a_value_node); + auto a_value = a_value_node->value(); + MS_EXCEPTION_IF_NULL(a_value); + auto a_prim = a_value->cast(); + MS_EXCEPTION_IF_NULL(a_prim); + + auto b_value_node = b_node->cast(); + MS_EXCEPTION_IF_NULL(b_value_node); + auto b_value = b_value_node->value(); + MS_EXCEPTION_IF_NULL(b_value); + auto b_prim = b_value->cast(); + MS_EXCEPTION_IF_NULL(b_prim); + + return a_prim->name() == b_prim->name(); + } else if (a_node->isa() && b_node->isa()) { + auto a_value_node_ptr = a_node->cast(); + if (a_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto a_value_ptr = a_value_node_ptr->value(); + if (a_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + auto b_value_node_ptr = b_node->cast(); + if (b_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto b_value_ptr = b_value_node_ptr->value(); + if (b_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + if (utils::isa(a_value_ptr) && utils::isa(b_value_ptr)) { + auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get()); + auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get()); + return (*a_obj) == (*b_obj); + } else { + return (*a_value_ptr) == (*b_value_ptr); + } + } + } + if (a.m_ptr->isa() && b.m_ptr->isa()) { + auto a_value_node_ptr = a.m_ptr->cast(); + auto b_value_node_ptr = b.m_ptr->cast(); + return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type; + } + + return a == b; +} + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { + // To matchCNode and Kernel's type + if (utils::isa(a) && utils::isa(b)) { + return true; + } + return a.type() == b.type(); +} + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { + MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_EXCEPTION_IF_NULL(primitive_vars); + if (utils::isa(sexp)) { + return HandleSexpVector(sexp, graph, primitive_vars, multigraph); + } + if (utils::isa(sexp)) { + auto var_ptr = utils::cast(sexp); + MS_EXCEPTION_IF_NULL(var_ptr); + if (var_ptr->primitive()) { + (*primitive_vars)[var_ptr->primitive()] = var_ptr; + return NewValueNode(var_ptr->primitive()); + } + return CreateVarNodeWithSexp(sexp, graph); + } + if (utils::isa(sexp)) { + return utils::cast(sexp); + } + auto value_node = CreateValueNodeWithSexp(sexp); + if (value_node == nullptr) { + MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); + } + return value_node; +} + +bool IsRealCNodeKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real cnode kernel + if (!node->isa()) { + return false; + } + // return considered as a real node + if (CheckPrimitiveType(node, prim::kPrimReturn)) { + return true; + } + return IsRealKernel(node); +} +bool IsGraphKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // graph kernel should be a real cnode kernel. + if (!IsRealCNodeKernel(node)) { + return false; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input = cnode->input(kAnfPrimitiveIndex); + // graph kernel should has func_graph as first input. + if (!IsValueNode(input)) { + return false; + } + + auto func_graph = GetValueNode(input); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); +} + +void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) { + if (graph == nullptr) { + MS_LOG(EXCEPTION) << "The graph is null."; + } +} + +void CheckIfAnfNodeIsNull(const AnfNodePtr &node) { + if (node == nullptr) { + MS_LOG(EXCEPTION) << "The AnfNode is null."; + } +} + +void CheckIfCNodeIsNull(const CNodePtr &node) { + if (node == nullptr) { + MS_LOG(EXCEPTION) << "The CNode is null."; + } +} + +void CheckIfVarIsNull(const VarPtr &var) { + if (var == nullptr) { + MS_LOG(EXCEPTION) << "The Var is null."; + } +} + +void CheckIfNodeIsParam(const AnfNodePtr &node) { + if (node != nullptr && !utils::isa(node)) { + MS_LOG(EXCEPTION) << "The Node is not param."; + } +} + +void CheckInputSize(const CNodePtr &node, const int size) { + if (node->inputs().size() != size) { + MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size(); + } +} + +void CheckLeastInputSize(const CNodePtr &node, const int size) { + if (node->inputs().size() < size) { + MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size(); + } +} + +AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, + const ParamValueLitePtr &weight_tensor) { + auto bias_parameter = func_graph->add_parameter(); + MS_ASSERT(bias_parameter != nullptr); + std::vector shape = {kernel_num}; + auto abstract_tensor = std::make_shared(TypeIdToType(weight_tensor->tensor_type()), shape); + bias_parameter->set_abstract(abstract_tensor); + + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(param_value != nullptr); + param_value->set_tensor_addr(bias_data); + param_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t)); + bias_parameter->set_default_param(param_value); + return bias_parameter; +} + +schema::PrimitiveType GetCNodeType(const BaseRef &n) { + ValueNodePtr value_node; + if (utils::isa(n)) { + auto in = utils::cast(n); + value_node = in->input(0)->cast(); + } else if (utils::isa(n)) { + value_node = utils::cast(n); + } else { + MS_LOG(EXCEPTION) << "only value node or cnode has type"; + return schema::PrimitiveType_NONE; + } + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_ASSERT(value != nullptr); + if (utils::isa(value)) { + auto primitive = value->cast(); + MS_ASSERT(primitive != nullptr); + return primitive->GetPrimitiveT()->value.type; + } + return schema::PrimitiveType_NONE; +} + +bool IsParamNode(const BaseRef &n) { + return utils::isa(n); +} + +bool IsConvNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D; + } + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h new file mode 100644 index 0000000000..d96190e750 --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ + +#include +#include "src/ir/primitive_t_value.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "src/common/utils.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "schema/inner/model_generated.h" +#include "src/param_value_lite.h" + +using PrimitiveTValuePtr = std::shared_ptr; +namespace mindspore { +namespace opt { +bool AnfEqual(const BaseRef &a, const BaseRef &b); + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph = false); + +bool IsRealCNodeKernel(const AnfNodePtr &node); + +bool IsGraphKernel(const AnfNodePtr &node); + +void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph); + +void CheckIfAnfNodeIsNull(const AnfNodePtr &node); + +void CheckIfCNodeIsNull(const CNodePtr &node); + +void CheckIfVarIsNull(const VarPtr &var); + +void CheckInputSize(const CNodePtr &node, int size); + +void CheckIfNodeIsParam(const AnfNodePtr &node); + +void CheckLeastInputSize(const CNodePtr &node, int size); + +AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, + const ParamValueLitePtr &weight_tensor); + +schema::PrimitiveType GetCNodeType(const BaseRef &node); + +bool IsParamNode(const BaseRef &n); + +bool IsConvNode(const BaseRef &n); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ diff --git a/mindspore/lite/tools/optimizer/common/node_pass_extends.cc b/mindspore/lite/tools/optimizer/common/node_pass_extends.cc new file mode 100644 index 0000000000..e5ee11d62f --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/node_pass_extends.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/node_pass.h" + +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore { +namespace opt { +bool NodePass::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + + std::unordered_set seen_node; + std::deque todo{func_graph->output()}; + bool changes = false; + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + AnfNodePtr new_node = Run(func_graph, node); + bool change = (new_node != nullptr); + if (new_node != nullptr && new_node != node) { + (void)manager->Replace(node, new_node); + (void)seen_node.erase(node); + } else if (new_node == nullptr) { + new_node = node; + } + if (new_node && IsValueNode(new_node)) { + auto const_func_graph = GetValueNode(new_node); + MS_EXCEPTION_IF_NULL(const_func_graph); + todo.push_back(const_func_graph->output()); + } else if (new_node && new_node->isa()) { + if (IsGraphKernel(new_node)) { + todo.push_back(new_node); + } + auto cnode = new_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); + } + changes = changes || change; + } + return changes; +} +} // namespace opt +} // namespace mindspore + diff --git a/mindspore/lite/tools/optimizer/common/optimizer.h b/mindspore/lite/tools/optimizer/common/optimizer.h new file mode 100644 index 0000000000..62191a2053 --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/optimizer.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/graph_utils.h" +#include "src/common/utils.h" + +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore { +namespace opt { +using PatternListType = std::initializer_list; + +class PatternProcessPass : public NodePass { + public: + explicit PatternProcessPass(const std::string &name = "", bool multigraph = true); + ~PatternProcessPass() override = default; + virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; + virtual const BaseRef DefinePattern() const; + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; + + private: + void Build(); + + AnfNodePtr pattern_ = nullptr; + bool multigraph_ = true; + PatternEngine pattern_engine_; + PrimitiveVarMapPtr primitive_vars_; +}; + +class MultipleOutputPatternProcessPass : public PatternProcessPass { + public: + explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true) + : PatternProcessPass(name, multigraph), + child_pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + child_primitive_vars_(std::make_shared()) {} + ~MultipleOutputPatternProcessPass() override = default; + virtual BaseRef DefineAnotherPattern() const = 0; + // check two patterns whether share the same nodes or not + virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0; + + protected: + bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; + PatternEngine child_pattern_engine_; + PrimitiveVarMapPtr child_primitive_vars_; +}; + +class GraphOptimizer { + public: + explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {} + virtual ~GraphOptimizer() = default; + + void AddPassManager(const PassManagerPtr &pass_manager); + FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true); + + private: + const std::string name_ = "graph_optimizer"; + std::vector pass_managers_{}; + bool run_only_once_ = true; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_ + diff --git a/mindspore/lite/tools/optimizer/common/pass.h b/mindspore/lite/tools/optimizer/common/pass.h new file mode 100644 index 0000000000..3a3b692744 --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_ +#include +#include + +#include "ir/anf.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +// @brief ANF Graph level optimization base pass +class Pass { + public: + explicit Pass(const std::string &name = "pass") : name_(name) {} + virtual ~Pass() = default; + virtual bool Run(const FuncGraphPtr &func_graph) = 0; + virtual std::string name() const { return name_; } + + private: + const std::string name_; +}; +using PassPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/common/pass_manager_extends.cc b/mindspore/lite/tools/optimizer/common/pass_manager_extends.cc new file mode 100644 index 0000000000..5907700b6d --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/pass_manager_extends.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/pass_manager.h" + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +const std::vector &PassManager::Passes() const { return passes_; } + +void PassManager::AddPass(const PassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + bool changed = false; + size_t num = 0; + for (const auto &pass : passes) { + if (pass != nullptr) { +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time {}; + struct timeval end_time {}; + (void)gettimeofday(&start_time, nullptr); +#endif + if (pass->Run(func_graph)) { + MS_LOG(DEBUG) << "Run pass and find change"; + changed = true; + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us"; +#endif + num++; + } + } + return changed; +} + +bool PassManager::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc new file mode 100644 index 0000000000..dcd2fc90d6 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/fusion/conv_activation_fusion.h" +#include +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kActivationInputsLength = 2; +} +const BaseRef ConvActivationFusion::DefinePattern() const { + auto conv_var = std::make_shared(IsConvNode); + auto prim = new schema::PrimitiveT(); + prim->value.type = primitive_type; + auto prim_value = std::make_shared(prim); + + return VectorRef({prim_value, conv_var}); +} + +const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type]; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + auto act_node = node->cast(); + CheckIfCNodeIsNull(act_node); + CheckInputSize(act_node, kActivationInputsLength); + + auto act_primitive = GetValueNode>(act_node->input(0)); + if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) { + return node; + } + AnfNodePtr pre_node = act_node->input(1); + CheckIfAnfNodeIsNull(pre_node); + if (pre_node != nullptr && pre_node->isa()) { + auto conv_node = pre_node->cast(); + auto node_type = GetCNodeType(conv_node); + auto primitiveT_value = GetValueNode>(conv_node->input(0)); + MS_ASSERT(primitiveT_value); + if (node_type == schema::PrimitiveType_Conv2D) { + primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type; + return pre_node; + } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) { + primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->activationType = activation_type; + return pre_node; + } else { + MS_LOG(EXCEPTION) << "conv activation pass match only conv2d or depthwise_conv2d "; + } + } + return node; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h new file mode 100644 index 0000000000..27150fcdda --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ + +#include +#include "tools/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvActivationFusion : public PatternProcessPass { + public: + explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion", + schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU, + schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) : primitive_type( + primitive), activation_type(activation), PatternProcessPass(name, multigraph) {} + ~ConvActivationFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + schema::PrimitiveType primitive_type; + schema::ActivationType activation_type; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc new file mode 100644 index 0000000000..2d8d6a8678 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -0,0 +1,159 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/fusion/conv_biasadd_fusion.h" +#include +#include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "securec/include/securec.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kAddInputsLength = 3; +constexpr size_t kAddWEIGHTINDEX = 2; +constexpr size_t kConvWeightIndex = 2; +constexpr size_t kConvBiasIndex = 3; +constexpr size_t kConvNoBiasLen = 3; +constexpr size_t kConvWithBiasLen = 4; +bool IsConvExtendNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D + || type == schema::PrimitiveType_DeConv2D; + } + return false; +} +bool IsAddNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_Add || type == schema::PrimitiveType_BiasAdd; + } + return false; +} + +int Get_Kenrnel_nums(const CNodePtr &conv_node) { + MS_ASSERT(conv_node != nullptr); + auto value_primitive = conv_node->input(0); + auto value_node = value_primitive->cast(); + MS_ASSERT(value_node != nullptr); + auto value = value_node->value(); + MS_ASSERT(value != nullptr); + auto primitive = value->cast(); + MS_ASSERT(primitive != nullptr); + auto type = primitive->GetPrimitiveT()->value.type; + if (type == schema::PrimitiveType_Conv2D) { + return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; + } else if (type == schema::PrimitiveType_DepthwiseConv2D) { + return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier + * primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + } else if (type == schema::PrimitiveType_DeConv2D) { + return primitive->GetPrimitiveT()->value.AsDeConv2D()->channelOut; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << type; + return 0; + } +} +void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) { + AnfNodePtr conv_bias_node = nullptr; + AnfNodePtr conv_weight_node = nullptr; + if (conv_node->inputs().size() == kConvNoBiasLen) { + conv_weight_node = conv_node->input(kConvWeightIndex); + } else if (conv_node->inputs().size() == kConvWithBiasLen) { + conv_weight_node = conv_node->input(kConvWeightIndex); + conv_bias_node = conv_node->input(kConvBiasIndex); + } else { + MS_LOG(EXCEPTION) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4"; + } + auto kernel_nums = Get_Kenrnel_nums(conv_node); + if (kernel_nums <= 0) { + MS_LOG(EXCEPTION) << "kernel num less than 0"; + } + auto add_bias_data = new(std::nothrow) float[kernel_nums]; + auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); + CheckIfNodeIsParam(bias_add_weight); + auto add_weight_param = bias_add_weight->cast()->default_param(); + auto add_weight_tensor = std::dynamic_pointer_cast(add_weight_param); + auto add_weight_data = reinterpret_cast(add_weight_tensor->tensor_addr()); + + if (add_weight_tensor->tensor_shape().empty()) { + if (EOK != memset_s(add_bias_data, kernel_nums * sizeof(float), *add_weight_data, kernel_nums * sizeof(float))) { + MS_LOG(EXCEPTION) << "memset_s conv_bias_data failed"; + } + } else { + if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { + MS_LOG(EXCEPTION) << "memset_s conv_bias_data failed"; + } + } + if (conv_bias_node != nullptr) { + CheckIfNodeIsParam(conv_bias_node); + auto conv_bias_param = conv_bias_node->cast()->default_param(); + auto conv_bias_tensor = std::dynamic_pointer_cast(conv_bias_param); + if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) { + MS_LOG(EXCEPTION) << "conv_bias_node shape error"; + } + auto conv_bias_data = reinterpret_cast(conv_bias_tensor->tensor_addr()); + for (size_t i = 0; i < kernel_nums; i++) { + conv_bias_data[i] += add_bias_data[i]; + } + delete[] add_bias_data; + } else { + auto conv_weight_param = conv_weight_node->cast()->default_param(); + auto conv_weight_tensor = std::dynamic_pointer_cast(conv_weight_param); + auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); + conv_node->add_input(conv_new_bias); + } +} +} // namespace +const BaseRef ConvBiasaddFusion::DefinePattern() const { + auto conv_var = std::make_shared(IsConvExtendNode); + auto add_var = std::make_shared(IsAddNode); + auto weight_var = std::make_shared(IsParamNode); + return VectorRef({add_var, conv_var, weight_var}); +} + +const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "Enter pass process"; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + auto add_node = node->cast(); + CheckIfCNodeIsNull(add_node); + CheckInputSize(add_node, kAddInputsLength); + + AnfNodePtr conv_node_anf = add_node->input(1); + CheckIfAnfNodeIsNull(conv_node_anf); + auto conv_node = conv_node_anf->cast(); + CheckIfCNodeIsNull(conv_node); + GenConvNewBias(func_graph, conv_node, add_node); + auto primitiveT_value = GetValueNode>(conv_node->input(0)); + MS_ASSERT(primitiveT_value != nullptr); + auto type = primitiveT_value->GetPrimitiveT()->value.type; + if (type == schema::PrimitiveType_Conv2D) { + primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; + } else if (type == schema::PrimitiveType_DepthwiseConv2D) { + primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true; + } else if (type == schema::PrimitiveType_DeConv2D) { + primitiveT_value->GetPrimitiveT()->value.AsDeConv2D()->hasBias = true; + } else { + MS_LOG(EXCEPTION) << "Unsupported opType, " << type; + } + return conv_node; +} +} // namespace mindspore::opt + diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h new file mode 100644 index 0000000000..b630c492dc --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ + +#include "tools/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvBiasaddFusion : public PatternProcessPass { + public: + explicit ConvBiasaddFusion(bool multigraph = true) : PatternProcessPass("conv_biasadd_fusion", multigraph) {} + ~ConvBiasaddFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ + diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc new file mode 100644 index 0000000000..5a39798d58 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/fusion/conv_bn_fusion.h" +#include +#include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "securec/include/securec.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kCaffeBNMeanIndex = 2; +constexpr size_t kCaffeBNVarIndex = 3; +constexpr size_t kTFBNScaleIndex = 2; +constexpr size_t kTFBNBiasIndex = 3; +constexpr size_t kTFBNMeanIndex = 4; +constexpr size_t kTFBNVarIndex = 5; +constexpr const float EPS = 1e-8; +constexpr const float EPS_DEFAULT_FLOAT = 1e-5; +constexpr const float POW_NUM = 0.5; +bool IsBatchNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_BatchNorm || type == schema::PrimitiveType_FusedBatchNorm; + } + return false; +} +void CalTransale(const AnfNodePtr &bn_scale_node, const AnfNodePtr &bn_var_node, float *trans_scale, float eps, + int kernel_num) { + auto bn_var_param = bn_var_node->cast()->default_param(); + auto bn_var_tensor = std::dynamic_pointer_cast(bn_var_param); + auto bn_var_data = reinterpret_cast(bn_var_tensor->tensor_addr()); + // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) + if (memcpy_s(trans_scale, kernel_num * sizeof(float), bn_var_data, kernel_num * sizeof(float)) != EOK) { + MS_LOG(EXCEPTION) << "memcpy_s transScale error"; + return; + } + // 1/sqrt(variance + eps) + for (int32_t i = 0; i < kernel_num; i++) { + float tmp = trans_scale[i] + eps; + tmp = pow(tmp, POW_NUM); + trans_scale[i] = 1 / tmp; + } + if (bn_scale_node != nullptr) { + auto bn_scale_param = bn_scale_node->cast()->default_param(); + auto bn_scale_tensor = std::dynamic_pointer_cast(bn_scale_param); + auto bn_scale_data = reinterpret_cast(bn_scale_tensor->tensor_addr()); + // scale/sqrt(variance + eps) + for (int32_t i = 0; i < kernel_num; i++) { + trans_scale[i] *= bn_scale_data[i]; + } + } +} +void CalTransBias(const AnfNodePtr &bn_mean_node, const AnfNodePtr &bn_bias_node, const float *trans_scale, + float *trans_bias, int kernel_num) { + auto bn_mean_param = bn_mean_node->cast()->default_param(); + auto bn_mean_tensor = std::dynamic_pointer_cast(bn_mean_param); + auto bn_mean_data = reinterpret_cast(bn_mean_tensor->tensor_addr()); + // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps) + // -mean/sqrt(variance + eps) + for (int32_t i = 0; i < kernel_num; i++) { + trans_bias[i] = -bn_mean_data[i] * trans_scale[i]; + } + + if (bn_bias_node != nullptr) { + auto bn_bias_param = bn_bias_node->cast()->default_param(); + auto bn_bias_tensor = std::dynamic_pointer_cast(bn_bias_param); + auto bn_bias_data = reinterpret_cast(bn_bias_tensor->tensor_addr()); + // -scale*mean/sqrt(variance + eps) + bias + for (int32_t i = 0; i < kernel_num; i++) { + trans_bias[i] += bn_bias_data[i]; + } + } +} +} // namespace +const BaseRef ConvBatchNormFusion::DefinePattern() const { + auto conv_var = std::make_shared(IsConvNode); + auto bn_var = std::make_shared(IsBatchNode); + auto bn_mean_var = std::make_shared(IsParamNode); + auto bn_variable_var = std::make_shared(IsParamNode); + auto bn_other_var = std::make_shared(); + return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});; +} +// BatchNorm weight Tensor definition: +// caffe +// estimated_mean --0 +// estimated_variance --1 +// tensorflow +// scale -- 0 +// bias --1 +// estimated_mean --2 +// estimated_variance --3 +const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale, + float *trans_bias) const { + MS_ASSERT(bn_node != nullptr); + AnfNodePtr bn_mean_node = nullptr; + AnfNodePtr bn_variance_node = nullptr; + AnfNodePtr bn_scale_node = nullptr; + AnfNodePtr bn_bias_node = nullptr; + float eps = 0; + auto primitiveT_value = GetValueNode>(bn_node->input(0)); + if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) { + bn_mean_node = bn_node->input(kCaffeBNMeanIndex); + bn_variance_node = bn_node->input(kCaffeBNVarIndex); + CheckIfNodeIsParam(bn_mean_node); + CheckIfNodeIsParam(bn_variance_node); + eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon; + } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) { + bn_scale_node = bn_node->input(kTFBNScaleIndex); + bn_bias_node = bn_node->input(kTFBNBiasIndex); + bn_mean_node = bn_node->input(kTFBNMeanIndex); + bn_variance_node = bn_node->input(kTFBNVarIndex); + eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon; + } else { + MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op."; + } + CheckIfNodeIsParam(bn_mean_node); + CheckIfNodeIsParam(bn_variance_node); + if (eps < EPS) { + eps = EPS_DEFAULT_FLOAT; + } + + CalTransale(bn_scale_node, bn_variance_node, trans_scale, eps, kernel_num); + CalTransBias(bn_mean_node, bn_bias_node, trans_scale, trans_bias, kernel_num); +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h new file mode 100644 index 0000000000..201e582fcd --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ + +#include "tools/optimizer/fusion/conv_transform_fusion.h" + +namespace mindspore::opt { +class ConvBatchNormFusion : public ConvTransformFusion { + public: + explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {} + ~ConvBatchNormFusion() override = default; + const BaseRef DefinePattern() const override; + const void InitTransParam(const CNodePtr &, int, float *, float *) const override; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc new file mode 100644 index 0000000000..10aadeda4c --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/fusion/conv_scale_fusion.h" +#include +#include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "include/errorcode.h" +#include "securec/include/securec.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kScaleWeightIndex = 2; +constexpr size_t kScaleBiasIndex = 3; +constexpr size_t kScaleNoBiasLen = 3; +constexpr size_t kScaleWithBiasLen = 4; +bool IsScaleNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_Scale; + } + return false; +} +} // namespace + +const BaseRef ConvScaleFusion::DefinePattern() const { + auto conv_var = std::make_shared(IsConvNode); + auto bn_var = std::make_shared(IsScaleNode); + auto weight_var = std::make_shared(IsParamNode); + auto bias_var = std::make_shared(); + + return VectorRef({bn_var, conv_var, weight_var, bias_var}); +} +const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale, + float *trans_bias) const { + MS_ASSERT(scale_node != nullptr); + AnfNodePtr scale_weight_node; + AnfNodePtr scale_bias_node; + if (scale_node->inputs().size() == kScaleNoBiasLen) { + scale_weight_node = scale_node->input(kScaleWeightIndex); + } else if (scale_node->inputs().size() == kScaleWithBiasLen) { + scale_weight_node = scale_node->input(kScaleWeightIndex); + scale_bias_node = scale_node->input(kScaleBiasIndex); + } else { + MS_LOG(EXCEPTION) << "Scale should has 2 or 3 input tensors, current inputs is" << scale_node->inputs().size(); + } + if (!scale_weight_node->isa()) { + MS_LOG(EXCEPTION) << "scale weight node not paramter node"; + } + if (scale_bias_node != nullptr && !scale_bias_node->isa()) { + MS_LOG(EXCEPTION) << "scale bias node not paramter node"; + } + auto scale_weight_param = scale_weight_node->cast()->default_param(); + auto weight_value = std::dynamic_pointer_cast(scale_weight_param); + auto weight_data = reinterpret_cast(weight_value->tensor_addr()); + + if (EOK != memcpy_s(trans_scale, kernel_num * sizeof(float), weight_data, kernel_num * sizeof(float))) { + MS_LOG(EXCEPTION) << "memcpy_s transScale failed"; + } + + if (scale_bias_node != nullptr) { + auto scale_bias_param = scale_bias_node->cast()->default_param(); + auto bias_value = std::dynamic_pointer_cast(scale_bias_param); + auto bias_data = reinterpret_cast(bias_value->tensor_addr()); + if (EOK != memcpy_s(trans_bias, kernel_num * sizeof(float), bias_data, kernel_num * sizeof(float))) { + MS_LOG(EXCEPTION) << "memcpy_s transScale failed"; + } + } +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h new file mode 100644 index 0000000000..969490dd59 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ + +#include "tools/optimizer/fusion/conv_transform_fusion.h" + +namespace mindspore::opt { +class ConvScaleFusion : public ConvTransformFusion { + public: + explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {} + ~ConvScaleFusion() override = default; + const BaseRef DefinePattern() const override; + const void InitTransParam(const CNodePtr &, int, float *, float *) const override; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc new file mode 100644 index 0000000000..ee7e56de6c --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/fusion/conv_transform_fusion.h" +#include +#include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "utils/utils.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "include/errorcode.h" +#include "securec/include/securec.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kConvWeightIndex = 2; +constexpr size_t kConvBiasIndex = 3; +constexpr size_t kConvNoBiasLen = 3; +constexpr size_t kConvWithBiasLen = 4; + +int Get_Kenrnel_nums(const CNodePtr &conv_node) { + MS_ASSERT(conv_node != nullptr); + auto value_primitive = conv_node->input(0); + auto value_node = value_primitive->cast(); + MS_ASSERT(value_node != nullptr); + auto value = value_node->value(); + MS_ASSERT(value != nullptr); + auto primitive = value->cast(); + MS_ASSERT(primitive != nullptr); + auto type = primitive->GetPrimitiveT()->value.type; + if (type == schema::PrimitiveType_Conv2D) { + return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; + } else if (type == schema::PrimitiveType_DepthwiseConv2D) { + return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier + * primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << type; + return 0; + } +} +} // namespace + +const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "conv activation pass process"; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + // transform node means scale,bn + auto transform_node = node->cast(); + CheckIfCNodeIsNull(transform_node); + CheckLeastInputSize(transform_node, 2); + + auto pre_node = transform_node->input(1); + auto conv_node = pre_node->cast(); + + int kernel_nums = Get_Kenrnel_nums(conv_node); + if (kernel_nums <= 0) { + MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString(); + return node; + } + auto trans_scale = new(std::nothrow) float[kernel_nums]; + auto trans_bias = new(std::nothrow) float[kernel_nums]; + GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias); + GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); + delete[] trans_bias; + delete[] trans_scale; + auto primitiveT_value = GetValueNode>(conv_node->input(0)); + MS_ASSERT(primitiveT_value != nullptr); + auto type = primitiveT_value->GetPrimitiveT()->value.type; + if (type == schema::PrimitiveType_Conv2D) { + primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; + } else if (type == schema::PrimitiveType_DepthwiseConv2D) { + primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true; + } else { + MS_LOG(EXCEPTION) << "Unsupported opType, " << type; + } + return pre_node; +} + +const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, + float *trans_scale, float *trans_bias) const { + if (trans_scale == nullptr) { + MS_LOG(EXCEPTION) << "new transScale failed"; + } + if (trans_bias == nullptr) { + MS_LOG(EXCEPTION) << "new transBias failed"; + } + if (0 != memset_s(trans_scale, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float))) { + MS_LOG(EXCEPTION) << "memset transScale failed"; + } + if (0 != memset_s(trans_bias, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float))) { + MS_LOG(EXCEPTION) << "memset transBias failed"; + } + + InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias); +} + +const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, + int kernel_num, const float *trans_scale, const float *trans_bias) +const { + MS_ASSERT(conv_node != nullptr); + AnfNodePtr conv_weight_node = nullptr; + AnfNodePtr conv_bias_node = nullptr; + if (conv_node->inputs().size() == kConvNoBiasLen) { + conv_weight_node = conv_node->input(kConvWeightIndex); + } else if (conv_node->inputs().size() == kConvWithBiasLen) { + conv_weight_node = conv_node->input(kConvWeightIndex); + conv_bias_node = conv_node->input(kConvBiasIndex); + } else { + MS_LOG(ERROR) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4"; + return; + } + if (!conv_weight_node->isa()) { + MS_LOG(EXCEPTION) << "scale weight node not paramter node"; + } + if (conv_bias_node != nullptr && !conv_bias_node->isa()) { + MS_LOG(EXCEPTION) << "scale bias node not paramter node"; + } + + auto conv_weight_param = conv_weight_node->cast()->default_param(); + auto weight_tensor = std::dynamic_pointer_cast(conv_weight_param); + auto weight_data = reinterpret_cast(weight_tensor->tensor_addr()); + if (kernel_num <= 0) { + MS_LOG(EXCEPTION) << "kernel num less than 0"; + } + auto kernel_size = weight_tensor->tensor_shape_size() / kernel_num; + + CalNewWeightTensor(weight_data, kernel_num, kernel_size, trans_scale); + + float *bias_data = nullptr; + // conv has bias,bias_flag true + bool bias_flag = false; + if (conv_bias_node != nullptr) { + auto bias_weight_param = conv_weight_node->cast()->default_param(); + auto bias_tensor = std::dynamic_pointer_cast(bias_weight_param); + bias_data = reinterpret_cast(bias_tensor->tensor_addr()); + bias_flag = true; + } else { + bias_data = new(std::nothrow) float[kernel_num]; + } + CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias); + if (!bias_flag) { + auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor); + conv_node->add_input(bias_node); + } +} +const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, + const float *trans_scale) const { + MS_ASSERT(weight_data != nullptr); + auto tmp_weight_data = new(std::nothrow) float[kernel_num * kernel_size]; + MS_ASSERT(new_weight_data != nullptr); + auto data_size = kernel_num * kernel_size * sizeof(float); + if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) { + MS_LOG(EXCEPTION) << "memset newWeightData failed"; + return; + } + + for (size_t i = 0; i < kernel_num; i++) { + for (size_t j = 0; j < kernel_size; j++) { + tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i]; + } + } + + auto ret = memcpy_s(weight_data, data_size, tmp_weight_data, data_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy error: " << ret; + } + + delete[] tmp_weight_data; +} +const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, + const float *trans_scale, const float *trans_bias) const { + MS_ASSERT(bias_data != nullptr); + if (bias_flag) { + auto tmp_bias_data = new(std::nothrow) float[kernel_num]; + if (EOK != memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) { + MS_LOG(EXCEPTION) << "memset bias data failed"; + } + for (size_t i = 0; i < kernel_num; i++) { + tmp_bias_data[i] = bias_data[i] * trans_scale[i] + trans_bias[i]; + } + + auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), tmp_bias_data, kernel_num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy error: " << ret; + } + delete[] tmp_bias_data; + } else { + if (EOK != memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) { + MS_LOG(EXCEPTION) << "memset bias data failed"; + } + auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), trans_bias, kernel_num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy error: " << ret; + } + } +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h new file mode 100644 index 0000000000..38ff34a554 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + *conv_activation_fusion.h + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_ + +#include +#include "tools/optimizer/common/optimizer.h" + +namespace mindspore::opt { +class ConvTransformFusion : public PatternProcessPass { + public: + explicit ConvTransformFusion(bool multigraph = true, const std::string &name = "conv_transform_fusion") + : PatternProcessPass(name, multigraph) {} + ~ConvTransformFusion() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + const void GenTransParam(const CNodePtr &, int, float *, float *) const; + virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0; + const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const; + const void CalNewWeightTensor(float *, int, int, const float *) const; + const void CalNewBiasTensor(float *, int, bool, const float *, const float *) const; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_ diff --git a/mindspore/lite/tools/time_profile/CMakeLists.txt b/mindspore/lite/tools/time_profile/CMakeLists.txt new file mode 100644 index 0000000000..410ec8f2b1 --- /dev/null +++ b/mindspore/lite/tools/time_profile/CMakeLists.txt @@ -0,0 +1,18 @@ +# add shared link library + +set(COMMON_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc + ) + +add_executable(timeprofile + ${CMAKE_CURRENT_SOURCE_DIR}/main.cc + ${CMAKE_CURRENT_SOURCE_DIR}/time_profile.cc + ${COMMON_SRC}) + +if (PLATFORM_ARM32 OR PLATFORM_ARM64) + target_link_libraries(timeprofile mindspore-lite) +else() + target_link_libraries(timeprofile mindspore-lite pthread) +endif() diff --git a/mindspore/lite/tools/time_profile/main.cc b/mindspore/lite/tools/time_profile/main.cc new file mode 100644 index 0000000000..73b537a1b7 --- /dev/null +++ b/mindspore/lite/tools/time_profile/main.cc @@ -0,0 +1,19 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/time_profile/time_profile.h" + +int main(int argc, const char **argv) { return mindspore::lite::RunTimeProfile(argc, argv); } diff --git a/mindspore/lite/tools/time_profile/time_profile.cc b/mindspore/lite/tools/time_profile/time_profile.cc new file mode 100644 index 0000000000..65b375c284 --- /dev/null +++ b/mindspore/lite/tools/time_profile/time_profile.cc @@ -0,0 +1,372 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/time_profile/time_profile.h" +#define __STDC_FORMAT_MACROS +#include +#undef __STDC_FORMAT_MACROS +#include +#include +#include +#include "include/ms_tensor.h" +#include "utils/log_adapter.h" +#include "include/context.h" + +namespace mindspore { +namespace lite { +int TimeProfile::GenerateRandomData(size_t size, void *data) { + MS_ASSERT(data != nullptr); + char *castedData = static_cast(data); + for (size_t i = 0; i < size; i++) { + castedData[i] = static_cast(i); + } + return RET_OK; +} + +int TimeProfile::GenerateInputData() { + for (auto tensor : ms_inputs_) { + MS_ASSERT(tensor != nullptr); + auto input_data = tensor->MutableData(); + if (input_data == nullptr) { + MS_LOG(ERROR) << "MallocData for inTensor failed"; + } + MS_ASSERT(tensor->GetData() != nullptr); + auto tensor_byte_size = tensor->Size(); + auto status = GenerateRandomData(tensor_byte_size, input_data); + if (status != RET_OK) { + MS_LOG(ERROR) << "Generate RandomData for inTensor failed %d" << status; + } + } + return RET_OK; +} + +int TimeProfile::ReadInputFile() { + if (ms_inputs_.empty()) { + return RET_OK; + } + + auto inTensor = ms_inputs_.at(0); + MS_ASSERT(inTensor != nullptr); + + size_t size; + char *bin_buf = ReadFile(_flags->in_data_path_.c_str(), &size); + + auto tensor_data_size = inTensor->Size(); + if (size != tensor_data_size) { + MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensor_data_size << size; + } + auto input_data = inTensor->MutableData(); + memcpy(input_data, bin_buf, tensor_data_size); + return RET_OK; +} + +int TimeProfile::LoadInput() { + ms_inputs_ = session_->GetInputs(); + if (_flags->in_data_path_.empty()) { + auto status = GenerateInputData(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Generate input data error " << status; + } + } else { + auto status = ReadInputFile(); + if (status != RET_OK) { + MS_LOG(ERROR) << "ReadInputFile error, " << status; + } + } + return RET_OK; +} + +int TimeProfile::InitSession() { + size_t size = 0; + char *graph_buf = ReadFile(_flags->model_path_.c_str(), &size); + if (graph_buf == nullptr) { + MS_LOG(ERROR) << "Load graph failed, path %s" << _flags->model_path_; + } + + auto ctx = new lite::Context; + ctx->cpu_bind_mode_ = static_cast(_flags->cpu_bind_mode_); + ctx->device_ctx_.type = lite::DT_CPU; + ctx->thread_num_ = _flags->num_threads_; + + session_ = session::LiteSession::CreateSession(ctx); + if (session_ == nullptr) { + MS_LOG(ERROR) << "New session failed while running."; + } + + return RET_OK; +} + +int TimeProfile::InitCallbackParameter() { + // before callback + before_call_back_ = [&](const std::vector &before_inputs, + const std::vector &before_outputs, + const session::CallBackParam &callParam) { + if (before_inputs.empty()) { + MS_LOG(INFO) << "The num of beforeInputs is empty"; + } + if (before_outputs.empty()) { + MS_LOG(INFO) << "The num of beforeOutputs is empty"; + } + if (op_times_by_type_.find(callParam.type_callback_param) == op_times_by_type_.end()) { + op_times_by_type_.insert(std::make_pair(callParam.type_callback_param, std::make_pair(0, 0.0f))); + } + if (op_times_by_name_.find(callParam.name_callback_param) == op_times_by_name_.end()) { + op_times_by_name_.insert(std::make_pair(callParam.name_callback_param, std::make_pair(0, 0.0f))); + } + + op_call_times_total_++; + op_begin_ = GetTimeUs(); + return true; + }; + + // after callback + after_call_back_ = [&](const std::vector &after_inputs, + const std::vector &after_outputs, + const session::CallBackParam &call_param) { + uint64_t opEnd = GetTimeUs(); + + if (after_inputs.empty()) { + MS_LOG(INFO) << "The num of beforeInputs is empty"; + } + if (after_outputs.empty()) { + MS_LOG(INFO) << "The num of beforeOutputs is empty"; + } + + float cost = static_cast(opEnd - op_begin_) / 1000.0f; + op_cost_total_ += cost; + op_times_by_type_[call_param.type_callback_param].first++; + op_times_by_type_[call_param.type_callback_param].second += cost; + op_times_by_name_[call_param.name_callback_param].first++; + op_times_by_name_[call_param.name_callback_param].second += cost; + return true; + }; + + return RET_OK; +} + +int TimeProfile::Init() { + if (this->_flags == nullptr) { + return 1; + } + MS_LOG(INFO) << "ModelPath = " << _flags->model_path_; + MS_LOG(INFO) << "InDataPath = " << _flags->in_data_path_; + MS_LOG(INFO) << "LoopCount = " << _flags->loop_count_; + MS_LOG(INFO) << "NumThreads = " << _flags->num_threads_; + if (_flags->cpu_bind_mode_ == -1) { + MS_LOG(INFO) << "cpuBindMode = MID_CPU"; + } else if (_flags->cpu_bind_mode_ == 1) { + MS_LOG(INFO) << "cpuBindMode = HIGHER_CPU"; + } else { + MS_LOG(INFO) << "cpuBindMode = NO_BIND"; + } + + if (_flags->model_path_.empty()) { + MS_LOG(ERROR) << "modelPath is required"; + return 1; + } + + auto status = InitSession(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Init session failed."; + return RET_ERROR; + } + + status = this->LoadInput(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Load input failed."; + return RET_ERROR; + } + + status = InitCallbackParameter(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Init callback Parameter failed."; + return RET_ERROR; + } + + return RET_OK; +} + +int TimeProfile::PrintResult(const std::vector &title, + const std::map> &result) { + std::vector columnLenMax(5); + std::vector> rows; + + for (auto &iter : result) { + char stringBuf[5][100] = {}; + std::vector columns; + int len; + + len = iter.first.size(); + if (len > columnLenMax.at(0)) { + columnLenMax.at(0) = len + 4; + } + columns.push_back(iter.first); + + len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / _flags->loop_count_); + if (len > columnLenMax.at(1)) { + columnLenMax.at(1) = len + 4; + } + columns.emplace_back(stringBuf[1]); + + len = snprintf(stringBuf[2], sizeof(stringBuf[2]), "%f", iter.second.second / op_cost_total_); + if (len > columnLenMax.at(2)) { + columnLenMax.at(2) = len + 4; + } + columns.emplace_back(stringBuf[2]); + + len = snprintf(stringBuf[3], sizeof(stringBuf[3]), "%d", iter.second.first); + if (len > columnLenMax.at(3)) { + columnLenMax.at(3) = len + 4; + } + columns.emplace_back(stringBuf[3]); + + len = snprintf(stringBuf[4], sizeof(stringBuf[4]), "%f", iter.second.second); + if (len > columnLenMax.at(4)) { + columnLenMax.at(4) = len + 4; + } + columns.emplace_back(stringBuf[4]); + + rows.push_back(columns); + } + + printf("-------------------------------------------------------------------------\n"); + for (int i = 0; i < 5; i++) { + auto printBuf = title[i]; + if (printBuf.size() > columnLenMax.at(i)) { + columnLenMax.at(i) = printBuf.size(); + } + printBuf.resize(columnLenMax.at(i), ' '); + printf("%s", printBuf.c_str()); + } + printf("\n"); + for (int i = 0; i < rows.size(); i++) { + for (int j = 0; j < 5; j++) { + auto printBuf = rows[i][j]; + printBuf.resize(columnLenMax.at(j), ' '); + printf("%s\t", printBuf.c_str()); + } + printf("\n"); + } + return RET_OK; +} + +int TimeProfile::RunTimeProfile() { + uint64_t time_avg = 0; + + // Load graph + std::string modelName = _flags->model_path_.substr(_flags->model_path_.find_last_of("/") + 1); + + MS_LOG(INFO) << "start reading model file"; + size_t size = 0; + char *graphBuf = ReadFile(_flags->model_path_.c_str(), &size); + if (graphBuf == nullptr) { + MS_LOG(ERROR) << "Load graph failed while running %s", modelName.c_str(); + return 1; + } + auto model = lite::Model::Import(graphBuf, size); + + auto ret = session_->CompileGraph(model.get()); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Compile graph failed."; + return RET_ERROR; + } + + // load input + MS_LOG(INFO) << "start generate input data"; + auto status = LoadInput(); + if (status != 0) { + MS_LOG(ERROR) << "Generate input data error"; + return status; + } + + // run graph and test + for (int i = 0; i < _flags->loop_count_; i++) { + session_->BindThread(true); + uint64_t run_begin = GetTimeUs(); + + ret = session_->RunGraph(before_call_back_, after_call_back_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run graph failed."; + } + auto outputs = session_->GetOutputs(); + + uint64_t run_end = GetTimeUs(); + uint64_t time = run_end - run_begin; + time_avg += time; + session_->BindThread(false); + /* + for(auto &output : outputs) { + for (auto &outputTensor : output.second) { + delete outputTensor; + } + }*/ + outputs.clear(); + } + + time_avg /= _flags->loop_count_; + float runCost = static_cast(time_avg) / 1000.0f; + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run session failed."; + } + + const std::vector per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; + const std::vector per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; + PrintResult(per_op_name, op_times_by_name_); + PrintResult(per_op_type, op_times_by_type_); + + printf("\n total time: %5.5f ms, kernel cost: %5.5f ms \n\n", runCost, op_cost_total_ / _flags->loop_count_); + printf("-------------------------------------------------------------------------\n"); + + for (auto &msInput : ms_inputs_) { + delete msInput; + } + ms_inputs_.clear(); + delete graphBuf; + return ret; +} + +int RunTimeProfile(int argc, const char **argv) { + TimeProfileFlags flags; + Option err = flags.ParseFlags(argc, argv); + + if (err.IsSome()) { + std::cerr << err.Get() << std::endl; + std::cerr << flags.Usage() << std::endl; + return -1; + } + + if (flags.help) { + std::cerr << flags.Usage() << std::endl; + return 0; + } + + TimeProfile time_profile(&flags); + auto ret = time_profile.Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init TimeProfile failed."; + } + + ret = time_profile.RunTimeProfile(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run TimeProfile failed."; + } + + return RET_OK; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/time_profile/time_profile.h b/mindspore/lite/tools/time_profile/time_profile.h new file mode 100644 index 0000000000..eaad720d34 --- /dev/null +++ b/mindspore/lite/tools/time_profile/time_profile.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINNIE_TIMEPROFILE_TIMEPROFILE_H_ +#define MINNIE_TIMEPROFILE_TIMEPROFILE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tools/common/flag_parser.h" +#include "src/common/file_utils.h" +#include "src/common/utils.h" +#include "schema/model_generated.h" +#include "include/model.h" +#include "include/lite_session.h" + + +namespace mindspore { +namespace lite { + +class MS_API TimeProfileFlags : public virtual FlagParser { + public: + TimeProfileFlags() { + AddFlag(&TimeProfileFlags::model_path_, "modelPath", "Input model path", ""); + AddFlag(&TimeProfileFlags::in_data_path_, "inDataPath", "Input data path, if not set, use random input", ""); + AddFlag(&TimeProfileFlags::cpu_bind_mode_, "cpuBindMode", + "Input -1 for MID_CPU, 1 for HIGHER_CPU, 0 for NO_BIND, defalut value: 1", 1); + AddFlag(&TimeProfileFlags::loop_count_, "loopCount", "Run loop count", 10); + AddFlag(&TimeProfileFlags::num_threads_, "numThreads", "Run threads number", 2); + } + + ~TimeProfileFlags() override = default; + + public: + std::string model_path_; + std::string in_data_path_; + int cpu_bind_mode_ = 1; + int loop_count_; + int num_threads_; +}; + +class MS_API TimeProfile { + public: + explicit TimeProfile(TimeProfileFlags *flags) : _flags(flags) {} + ~TimeProfile() = default; + + int Init(); + int RunTimeProfile(); + + private: + int GenerateRandomData(size_t size, void *data); + int GenerateInputData(); + int LoadInput(); + int ReadInputFile(); + int InitCallbackParameter(); + int InitSession(); + int PrintResult(const std::vector& title, const std::map>& result); + + private: + TimeProfileFlags *_flags; + std::vector ms_inputs_; + session::LiteSession *session_; + + // callback parameters + uint64_t op_begin_ = 0; + int op_call_times_total_ = 0; + float op_cost_total_ = 0.0f; + std::map> op_times_by_type_; + std::map> op_times_by_name_; + + session::KernelCallBack before_call_back_; + session::KernelCallBack after_call_back_; +}; + +int MS_API RunTimeProfile(int argc, const char **argv); +} // namespace lite +} // namespace mindspore +#endif // MINNIE_TIMEPROFILE_TIMEPROFILE_H_ diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0913201861..49391d13ce 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -29,6 +29,7 @@ class ShardWriter: The class would write MindRecord File series. """ + def __init__(self): self._writer = ms.ShardWriter() self._header = None @@ -161,7 +162,7 @@ class ShardWriter: if row_blob: blob_data.append(list(row_blob)) # filter raw data according to schema - row_raw = {field: item[field] + row_raw = {field: self._convert_np_types(item[field]) for field in self._header.schema.keys() - self._header.blob_fields if field in item} if row_raw: raw_data.append(row_raw) @@ -172,6 +173,12 @@ class ShardWriter: raise MRMWriteDatasetError return ret + def _convert_np_types(self, val): + """convert numpy type to python primitive type""" + if isinstance(val, (np.int32, np.int64, np.float32, np.float64)): + return val.item() + return val + def _merge_blob(self, blob_data): """ Merge multiple blob data whose type is bytes or ndarray diff --git a/mindspore/mindrecord/tools/imagenet_to_mr.py b/mindspore/mindrecord/tools/imagenet_to_mr.py index 0aa870384e..59695c8734 100644 --- a/mindspore/mindrecord/tools/imagenet_to_mr.py +++ b/mindspore/mindrecord/tools/imagenet_to_mr.py @@ -35,10 +35,10 @@ class ImageNetToMR: .. code-block:: - n02119789 1 pen - n02100735 2 notebook - n02110185 3 mouse - n02096294 4 orange + n02119789 0 + n02100735 1 + n02110185 2 + n02096294 3 image_dir (str): image directory contains n02119789, n02100735, n02110185, n02096294 dir. destination (str): the MindRecord file path to transform into. diff --git a/mindspore/nn/__init__.py b/mindspore/nn/__init__.py index e5c133a9a6..7f69bc86d1 100644 --- a/mindspore/nn/__init__.py +++ b/mindspore/nn/__init__.py @@ -17,14 +17,14 @@ Neural Networks Cells. Pre-defined building blocks or computing units to construct Neural Networks. """ -from . import layer, loss, optim, metrics, wrap, distribution +from . import layer, loss, optim, metrics, wrap, probability, sparse from .cell import Cell, GraphKernel from .layer import * from .loss import * from .optim import * from .metrics import * from .wrap import * -from .distribution import * +from .sparse import * __all__ = ["Cell", "GraphKernel"] @@ -33,7 +33,6 @@ __all__.extend(loss.__all__) __all__.extend(optim.__all__) __all__.extend(metrics.__all__) __all__.extend(wrap.__all__) -__all__.extend(distribution.__all__) - +__all__.extend(sparse.__all__) __all__.sort() diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 3eec96f0b5..93375d15dd 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple from .._c_expression import init_backend from ..ops.primitive import Primitive from ..ops.operations import HookBackward +from ..ops.functional import cast from ..parallel._tensor import _load_tensor_by_layout from ..common.tensor import Tensor @@ -57,9 +58,11 @@ class Cell: >>> def construct(self, x): >>> return self.relu(x) """ + def __init__(self, auto_prefix=True, flags=None): self._params = OrderedDict() self._cells = OrderedDict() + self._params_list = OrderedDict() self.training = False self.requires_grad = False self.pynative = False @@ -81,16 +84,16 @@ class Cell: self._backward_hook = None self.enable_hook = False self._bprop_debug = False - self._is_run = False + self._already_run = False self.cell_type = None @property - def is_run(self): - return self._is_run + def already_run(self): + return self._already_run - @is_run.setter - def is_run(self, value): - self._is_run = value + @already_run.setter + def already_run(self, value): + self._already_run = value @property def create_time(self): @@ -188,11 +191,22 @@ class Cell: if '_params' in self.__dict__: params = self.__dict__['_params'] if name in params: + if context.get_context("mode") == context.PYNATIVE_MODE: + return self.cast_param(params[name]) return params[name] if '_cells' in self.__dict__: cells = self.__dict__['_cells'] if name in cells: return cells[name] + if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__: + params_list = self.__dict__['_params_list'] + if name in params_list: + para_list = params_list[name] + cast_list = list() + for para in para_list: + cast_list.append(self.cast_param(para)) + para_list = ParameterTuple(cast_list) + return para_list raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) def __del__(self): @@ -215,7 +229,6 @@ class Cell: for item in inputs: if isinstance(item, numpy.ndarray): raise TypeError("cell inputs should not be numpy array.") - self.init_parameters_data() orign_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True) @@ -225,22 +238,35 @@ class Cell: cell.set_grad(True) else: _pynative_exec.set_grad_flag(False) + cast_inputs = list() + if hasattr(self, "_mindspore_flags"): + if self._mindspore_flags.get('fp16'): + for item in inputs: + cast_inputs.append(cast(item, mstype.float16)) + if self._mindspore_flags.get('fp32'): + for item in inputs: + cast_inputs.append(cast(item, mstype.float32)) + if cast_inputs: + cast_inputs = tuple(cast_inputs) + else: + cast_inputs = inputs if self.enable_hook: - output = self._hook_construct(*inputs) + output = self._hook_construct(*cast_inputs) else: - output = self.construct(*inputs) + output = self.construct(*cast_inputs) if isinstance(output, Parameter): output = output.data if self.requires_grad is True: _pynative_exec.end_graph(self, output, *inputs) for i, cell in enumerate(self.cells()): cell.set_grad(orign_grad[i]) - self._is_run = True + self._already_run = True return output def __setattr__(self, name, value): cells = self.__dict__.get('_cells') params = self.__dict__.get('_params') + params_list = self.__dict__.get('_params_list') if isinstance(value, Parameter): if params is None: raise AttributeError("Can not assign params before Cell.__init__() call.") @@ -256,7 +282,14 @@ class Cell: raise AttributeError("Can not assign params before Cell.__init__() call.") for item in value: self.insert_param_to_cell(item.name, item, check_name=False) - object.__setattr__(self, name, value) + if context.get_context("mode") == context.PYNATIVE_MODE: + if name in self.__dict__: + del self.__dict__[name] + if name in params: + del params[name] + params_list[name] = value + else: + object.__setattr__(self, name, value) elif isinstance(value, Cell): if cells is None: raise AttributeError("Can not assign cells before Cell.__init__() call.") @@ -316,19 +349,8 @@ class Cell: params (dict): The parameters dictionary used for init data graph. """ if params is None: - for key in self.parameters_dict(): - tensor = self.parameters_dict()[key].data - if key not in self.parameter_layout_dict: - logger.info("layout dict does not contain the key %s", key) - continue - if self.parameters_dict()[key].sliced: - logger.debug("Param %s is already sliced.", key) - continue - layout = self.parameter_layout_dict[key] - new_tensor = _load_tensor_by_layout(tensor, layout) - self.parameters_dict()[key].set_parameter_data(new_tensor) - self.parameters_dict()[key].sliced = True - elif isinstance(params, OrderedDict): + params = self.parameters_dict() + if isinstance(params, OrderedDict): for key in params: tensor = params[key].data if key not in self.parameter_layout_dict: @@ -339,8 +361,7 @@ class Cell: continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) - params[key].set_parameter_data(new_tensor) - params[key].sliced = True + params[key].set_parameter_data(new_tensor, True) else: raise TypeError('Parameters need OrderedDict type, but got {}'. format(type(params))) @@ -458,6 +479,22 @@ class Cell: raise TypeError("The type of parameter should be 'Parameter' if not None.") self._params[param_name] = param + def cast_param(self, param): + """ + Cast parameter according to auto mix precison level in pynative mode. + + Args: + param (Parameter): The parameter to cast. + """ + if hasattr(self, "_mindspore_flags"): + if self._mindspore_flags.get('fp16'): + param.cast_type = mstype.float16 + elif self._mindspore_flags.get('fp32'): + param.cast_type = mstype.float32 + else: + param.cast_type = None + return param + def insert_child_to_cell(self, child_name, child): """ Adds a child cell to the current cell. @@ -495,17 +532,50 @@ class Cell: """ raise NotImplementedError - def init_parameters_data(self, recurse=True, auto_parallel_mode=False): - """Init parameters' data.""" - for param in self.get_parameters(expand=recurse): - if not auto_parallel_mode: - param.init_data() - elif param.name not in self.parameter_layout_dict: - logger.debug("Layout dict does not contain the key %s.", param.name) - param.init_data(set_sliced=True) - else: - layout = self.parameter_layout_dict[param.name] - param.init_data(layout, set_sliced=True) + def init_parameters_data(self, auto_parallel_mode=False): + """ + Init all parameters' data and replace the original saved parameters in cell. + + Notes: + trainable_params() and other similar interfaces may return different parameter instance after + `init_parameters_data`, do not save these result. + + Args: + auto_parallel_mode (bool): If running in auto_parallel_mode. + + Returns: + Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter. + """ + replace = dict() + def _updata(param): + if param in replace: + return replace[param] + layout = None + set_sliced = False + if auto_parallel_mode: + set_sliced = True + if param.name not in self.parameter_layout_dict: + logger.debug("Layout dict does not contain the key %s.", param.name) + else: + layout = self.parameter_layout_dict[param.name] + new_p = param.init_data(layout, set_sliced=set_sliced) + replace[param] = new_p + return new_p + # replace all original usage. + cells = self.cells_and_names() + for _, cell in cells: + params = cell._params.items() + for param_name, param in params: + cell._params[param_name] = _updata(param) + cell_dict = cell.__dict__ + for key in cell_dict: + if isinstance(cell_dict[key], ParameterTuple): + param_tuple = cell_dict[key] + new_param_tuple = [] + for param in param_tuple: + new_param_tuple.append(_updata(param)) + cell.__dict__[key] = ParameterTuple(new_param_tuple) + return replace def parameters_dict(self, recurse=True): """ @@ -632,9 +702,10 @@ class Cell: for cell_name, cell in cells: params = cell._params.items() for par_name, par in params: - if par and par not in params_set: + if par.inited_param is not None: + par = par.inited_param + if par is not None and par not in params_set: params_set.add(par) - par_new_name = par_name if cell_name: par_new_name = cell_name + '.' + par_new_name @@ -816,7 +887,7 @@ class Cell: def register_backward_hook(self, fn): """ - Set the cell backward hook function. + Set the cell backward hook function. Note that this function is only supported in Pynative Mode. Note: fn should be defined as following code shows, `cell_name` is the name of registered cell, @@ -831,7 +902,7 @@ class Cell: self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") self.enable_hook = True - def set_param_ps(self, recurse=True): + def set_param_ps(self, recurse=True, init_in_server=False): """ Set whether the trainable parameter is updated by parameter server. @@ -843,7 +914,8 @@ class Cell: """ params = self.trainable_params(recurse) for param in params: - param.set_param_ps() + param.set_param_ps(init_in_server) + class GraphKernel(Cell): """ @@ -861,6 +933,7 @@ class GraphKernel(Cell): >>> def construct(self, x): >>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) """ + def __init__(self, auto_prefix=True, pips=None): super(GraphKernel, self).__init__(auto_prefix, pips) class_name = self.__class__.__name__ diff --git a/mindspore/nn/distribution/__init__.py b/mindspore/nn/distribution/__init__.py deleted file mode 100644 index 55b4b03ef7..0000000000 --- a/mindspore/nn/distribution/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Distribution. - -The high-level components(Distributions) used to construct the probabilistic network. -""" - -from .distribution import Distribution -from .normal import Normal -from .bernoulli import Bernoulli - -__all__ = ['Distribution', - 'Normal', - 'Bernoulli',] diff --git a/mindspore/nn/distribution/_utils/__init__.py b/mindspore/nn/distribution/_utils/__init__.py deleted file mode 100644 index 816485643a..0000000000 --- a/mindspore/nn/distribution/_utils/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Distribution operation utility functions. -""" -from .utils import * - -__all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor', - 'calc_batch_size', 'check_greater', - 'check_greater_equal_zero', - 'calc_broadcast_shape_from_param', - 'check_scalar_from_param', 'check_prob'] diff --git a/mindspore/nn/distribution/_utils/utils.py b/mindspore/nn/distribution/_utils/utils.py deleted file mode 100644 index c790a66f25..0000000000 --- a/mindspore/nn/distribution/_utils/utils.py +++ /dev/null @@ -1,199 +0,0 @@ - -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Utitly functions to help distribution class.""" -import numpy as np -from mindspore.ops import _utils as utils -from ....common.tensor import Tensor -from ....common.parameter import Parameter -from ....common import dtype as mstype - - -def check_scalar(value): - """ - Check if input value is a scalar. - """ - return np.isscalar(value) - - -def cast_to_tensor(t, dtype=mstype.float32): - """ - Cast an user input value into a Tensor of dtype. - - Args: - t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. - dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32. - - Raises: - RuntimeError: if t cannot be cast to Tensor. - - Returns: - Tensor. - """ - if isinstance(t, Parameter): - return t - if isinstance(t, Tensor): - #check if the Tensor in shape of Tensor(4) - if t.dim() == 0: - value = t.asnumpy() - return Tensor([t], dtype=dtype) - #convert the type of tensor to dtype - t.set_dtype(dtype) - return t - if isinstance(t, (list, np.ndarray)): - return Tensor(t, dtype=dtype) - if check_scalar(t): - return Tensor([t], dtype=dtype) - raise RuntimeError("Input type is not supported.") - -def calc_batch_size(batch_shape): - """ - Calculate the size of a given batch_shape. - - Args: - batch_shape (tuple): batch shape to be calculated. - - Returns: - int. - """ - return int(np.prod(batch_shape)) - -def convert_to_batch(t, batch_shape, dtype): - """ - Convert a Tensor to a given batch shape. - - Args: - t (Tensor, Parameter): Tensor to be converted. - batch_shape (tuple): desired batch shape. - dtype (mindspore.dtype): desired dtype. - - Raises: - RuntimeError: if the converison cannot be done. - - Returns: - Tensor, with shape of batch_shape. - """ - if isinstance(t, Parameter): - return t - t = cast_to_tensor(t, dtype) - if t.shape != batch_shape: - mul = calc_batch_size(batch_shape) // t.size() - if (calc_batch_size(batch_shape) % t.size()) != 0: - raise RuntimeError("Cannot cast the tensor to the given batch shape.") - temp = list(t.asnumpy()) * mul - temp = np.reshape(temp, batch_shape) - return Tensor(temp, dtype) - return t - -def check_scalar_from_param(params): - """ - Check if params are all scalars. - - Args: - params (dict): parameters used to initialize distribution. - - Notes: String parameters are excluded. - """ - for value in params.values(): - if isinstance(value, (str, type(params['dtype']))): - continue - elif check_scalar(value): - continue - else: - return False - return True - - -def calc_broadcast_shape_from_param(params): - """ - Calculate the broadcast shape from params. - - Args: - params (dict): parameters used to initialize distribution. - - Returns: - tuple. - """ - broadcast_shape = [] - for value in params.values(): - if isinstance(value, (str, type(params['dtype']))): - continue - if value is None: - return None - if isinstance(value, Parameter): - value_t = value.default_input - else: - value_t = cast_to_tensor(value, params['dtype']) - broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) - return tuple(broadcast_shape) - -def check_greater_equal_zero(value, name): - """ - Check if the given Tensor is greater zero. - - Args: - value (Tensor, Parameter): value to be checked. - name (str) : name of the value. - - Raises: - ValueError: if the input value is less than zero. - - """ - if isinstance(value, Parameter): - if not isinstance(value.default_input, Tensor): - return - value = value.default_input - comp = np.less(value.asnumpy(), np.zeros(value.shape)) - if comp.any(): - raise ValueError(f'{name} should be greater than zero.') - -def check_greater(a, b, name_a, name_b): - """ - Check if Tensor b is strictly greater than Tensor a. - - Args: - a (Tensor): input tensor a. - b (Tensor): input tensor b. - name_a (str): name of Tensor_a. - name_b (str): name of Tensor_b. - - Raises: - ValueError: if b is less than or equal to a - """ - comp = np.less(a.asnumpy(), b.asnumpy()) - if not comp.all(): - raise ValueError(f'{name_a} should be less than {name_b}') - - -def check_prob(p): - """ - Check if p is a proper probability, i.e. 0 <= p <=1. - - Args: - p (Tensor, Parameter): value to be checked. - - Raises: - ValueError: if p is not a proper probability. - """ - if isinstance(p, Parameter): - if not isinstance(p.default_input, Tensor): - return - p = p.default_input - comp = np.less(p.asnumpy(), np.zeros(p.shape)) - if comp.any(): - raise ValueError('Probabilities should be greater than or equal to zero') - comp = np.greater(p.asnumpy(), np.ones(p.shape)) - if comp.any(): - raise ValueError('Probabilities should be less than or equal to one') diff --git a/mindspore/nn/distribution/bernoulli.py b/mindspore/nn/distribution/bernoulli.py deleted file mode 100644 index 9aa20d668f..0000000000 --- a/mindspore/nn/distribution/bernoulli.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Bernoulli Distribution""" -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from .distribution import Distribution -from ._utils.utils import cast_to_tensor, check_prob -from ...common import dtype as mstype - -class Bernoulli(Distribution): - """ - Example class: Bernoulli Distribution. - - Args: - probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. - seed (int): seed to use in sampling. Default: 0. - dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. - name (str): name of the distribution. Default: Bernoulli. - - Note: - probs should be proper probabilities (0 <= p <= 1). - - Examples: - >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 - >>> b = nn.Bernoulli(0.5, dtype = mstype.int32) - >>> # The following create two independent Bernoulli distributions - >>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32) - """ - - def __init__(self, - probs=None, - seed=0, - dtype=mstype.int32, - name="Bernoulli"): - """ - Constructor of Bernoulli distribution. - """ - param = dict(locals()) - super(Bernoulli, self).__init__(dtype, name, param) - if probs is not None: - self._probs = cast_to_tensor(probs) - check_prob(self._probs) - else: - self._probs = probs - self.seed = seed - - # ops needed for the class - self.log = P.Log() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.shape = P.Shape() - self.const = P.ScalarToArray() - self.less = P.Less() - self.cast = P.Cast() - self.erf = P.Erf() - self.sqrt = P.Sqrt() - - def extend_repr(self): - str_info = f'probs = {self._probs}' - return str_info - - def probs(self): - """ - Returns the probability for the outcome is 1. - """ - return self._probs - - def _mean(self, name='mean', probs1=None): - r""" - .. math:: - MEAN(B) = probs1 - """ - if name == 'mean': - return self._probs if probs1 is None else probs1 - return None - - def _var(self, name='var', probs1=None): - r""" - .. math:: - VAR(B) = probs1 * probs0 - """ - if name in ('sd', 'var'): - probs1 = self._probs if probs1 is None else probs1 - probs0 = self.add(1, -1 * probs1) - return self.mul(probs0, probs1) - return None - - def _prob(self, name, value, probs=None): - r""" - pmf of Bernoulli distribution. - - Args: - name (str): name of the function. Should be "prob" when passed in from construct. - value (Tensor): a Tensor composed of only zeros and ones. - probs (Tensor): probability of outcome is 1. Default: self._probs. - - .. math:: - pmf(k) = probs1 if k = 1; - pmf(k) = probs0 if k = 0; - """ - if name in ('prob', 'log_prob'): - probs1 = self._probs if probs is None else probs - probs0 = self.add(1, -1 * probs1) - return self.add(self.mul(probs1, value), - self.mul(probs0, self.add(1, -1 * value))) - return None - - def _kl_loss(self, name, dist, probs1_b, probs1_a=None): - r""" - Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). - - Args: - name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. - dist (str): type of the distributions. Should be "Bernoulli" in this case. - probs1_b (Tensor): probs1 of distribution b. - probs1_a (Tensor): probs1 of distribution a. Default: self._probs. - - .. math:: - KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + - probs0_a * \log(\fract{probs0_a}{probs0_b}) - """ - if name == 'kl_loss' and dist == 'Bernoulli': - probs1_a = self._probs if probs1_a is None else probs1_a - probs0_a = self.add(1, -1 * probs1_a) - probs0_b = self.add(1, -1 * probs1_b) - return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)), - probs0_a * self.log(self.realdiv(probs0_a, probs0_b))) - return None - - def _sample(self, name, shape=(), probs=None): - """ - Sampling. - - Args: - name (str): name of the function. Should always be 'sample' when passed in from construct. - shape (tuple): shape of the sample. Default: (). - probs (Tensor): probs1 of the samples. Default: self._probs. - - Returns: - Tensor, shape is shape + batch_shape. - """ - if name == 'sample': - probs1 = self._probs if probs is None else probs - batch_shape = self.shape(probs1) - sample_shape = shape + batch_shape - mean_zero = self.const(0.0) - sd_one = self.const(1.0) - sqrt_two = self.sqrt(self.const(2.0)) - sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) - sample = self.less(sample_uniform, probs1) - sample = self.cast(sample, self._dtype) - return sample - return None diff --git a/mindspore/nn/distribution/distribution.py b/mindspore/nn/distribution/distribution.py deleted file mode 100644 index 1ed7906a9e..0000000000 --- a/mindspore/nn/distribution/distribution.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""basic""" -from ..cell import Cell -from ._utils.utils import calc_broadcast_shape_from_param - - -class Distribution(Cell): - """ - Base class for all mathematical distributions. - - Args: - dtype (mindspore.dtype): type of the distribution. - name (str): name of the distribution. - param (dict): parameters used to initialize the distribution. - - Note: - Derived class should override operations such as ,_mean, _prob, - and _log_prob. Functions should be called through construct when - used inside a network in the form of function name followed by - arguments. - - Examples: - >>> class MyNormalDistribution(Distribution): - >>> def __init__(self): - >>> super(MyDistribution, self).__init__() - >>> self._mean_value = Tensor([2.0,3.0]) - >>> self._sd_value = Tensor([2.0,3.0]) - >>> - >>> def _mean(self): - >>> return self._mean_value - - """ - def __init__(self, - dtype, - name, - param): - - """ - Constructor of distribution class. - """ - super(Distribution, self).__init__() - self._name = name - self._dtype = dtype - self._parameters = {} - # parsing parameters - for k in param.keys(): - if not(k == 'self' or k.startswith('_')): - self._parameters[k] = param[k] - # some attributes - self._broadcast_shape = calc_broadcast_shape_from_param( - self._parameters) - - # set the function to call according to the derived class's attributes - self._set_prob() - self._set_log_prob() - self._set_sd() - - def _set_prob(self): - """ - Set probability funtion based on the availability of _prob and _log_likehood. - """ - if hasattr(self, '_prob'): - self._call_prob = self._prob - elif hasattr(self, '_log_likelihood'): - self._call_prob = self._calc_prob_from_log_likelihood - - def _set_sd(self): - """ - Set standard deviation based on the availability of _sd and _var. - """ - if hasattr(self, '_sd'): - self._call_sd = self._sd - elif hasattr(self, '_var'): - self._call_sd = self._calc_sd_from_var - - def _set_log_prob(self): - """ - Set log probability based on the availability of _prob and _log_likelihood. - """ - if hasattr(self, '_log_likelihood'): - self._call_log_prob = self._log_likelihood - if hasattr(self, '_prob'): - self._call_log_prob = self._calc_log_prob_from_prob - - def log_likelihood(self, *args): - """ - Evaluate the log probability at the given value. - - Note: - value is casted to Tensor for further calculation. - - Returns: - Tensor, shape is the broadcast_shape of the distribution. - """ - return self._call_log_prob(*args) - - def _calc_prob_from_log_likelihood(self, *args): - r""" - Evaluate prob from log probability. - - .. math:: - probability(x) = \exp(log_likehood(x)) - """ - return self.exp(self._log_likelihood(*args)) - - def prob(self, *args): - """ - Evaluate the prob (pdf or pmf) at given value. - - Note: - value is casted to Tensor for further calculation. - - Returns: - Tensor, shape is the broadcast_shape of the distribution. - """ - return self._call_prob(*args) - - def _calc_log_prob_from_prob(self, *args): - r""" - Evaluate log probability from probability. - - .. math:: - log_prob(x) = \log(prob(x)) - """ - return self.log(self._prob(*args)) - - def kl_loss(self, **kwargs): - """ - Evaluate the KL divergence. Parameters of the second distribution should be - passed in through **kwargs. - - Returns: - Tensor, shape is the broadcast_shape of the distribution and input distribution. - """ - return self._kl_loss(**kwargs) - - def mean(self, **kwargs): - """ - Evaluate the mean. - - Returns: - Tensor, shape is the broadcast_shape of the distribution. - """ - return self._mean(**kwargs) - - def sd(self, **kwargs): - """ - Evaluate the standard deviation. - - Returns: - Tensor, shape is the broadcast_shape of the distribution. - """ - return self._call_sd(**kwargs) - - def _calc_sd_from_var(self, *args): - r""" - Evaluate log probability from probability. - - .. math:: - STD(x) = \sqrt(VAR(x)) - """ - return self.sqrt(self._var(*args)) - - def construct(self, *inputs): - """ - Override construct in Cell. - - Args: - *inputs: inputs[0] is always the name of the function. - - Notes: - Always raise RuntimeError as Distribution should not be called directly. - """ - - if inputs[0] == 'log_prob': - return self._call_log_prob(*inputs) - if inputs[0] == 'prob': - return self._call_prob(*inputs) - if inputs[0] == 'kl_loss': - return self._kl_loss(*inputs) - if inputs[0] == 'mean': - return self._mean(*inputs) - if inputs[0] == 'sd': - return self._call_sd(*inputs) - if inputs[0] == 'sample': - return self._sample(*inputs) - return None diff --git a/mindspore/nn/distribution/normal.py b/mindspore/nn/distribution/normal.py deleted file mode 100644 index 61cec6d810..0000000000 --- a/mindspore/nn/distribution/normal.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Normal Distribution""" -import numpy as np -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from .distribution import Distribution -from ._utils.utils import convert_to_batch, check_greater_equal_zero -from ...common import dtype as mstype -from ...context import get_context - -class Normal(Distribution): - """ - Example class: Normal distribution. - - Args: - mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution. - sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution. - seed (int): seed to use in sampling. Default: 0. - dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. - name (str): name of the distribution. Default: Normal. - - - Note: - Standard deviation should be greater than zero. - - Examples: - >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 - >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) - >>> # The following create two independent normal distributions - >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) - """ - - def __init__(self, - mean=None, - sd=None, - seed=0, - dtype=mstype.float32, - name="Normal"): - """ - Constructor of normal distribution. - """ - param = dict(locals()) - super(Normal, self).__init__(dtype, name, param) - if mean is not None and sd is not None: - self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype) - self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype) - check_greater_equal_zero(self._sd_value, "Standard deviation") - else: - self._mean_value = mean - self._sd_value = sd - self.seed = seed - - #ops needed for the class - self.exp = P.Exp() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sq = P.Square() - self.log = P.Log() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step - self.shape = P.Shape() - self.zeroslike = P.ZerosLike() - self.const = P.ScalarToArray() - - def extend_repr(self): - str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' - return str_info - - def _expm1_by_step(self, x): - """ - Expm1 ops under GPU context. - """ - return self.add(self.exp(x), -1) - - def _mean(self, name='mean', mean=None, sd=None): - """ - Mean of the distribution. - """ - if name == 'mean': - mean = self._mean_value if mean is None or sd is None else mean - return mean - return None - - def _sd(self, name='sd', mean=None, sd=None): - """ - Standard deviation of the distribution. - """ - if name in ('sd', 'var'): - sd = self._sd_value if mean is None or sd is None else sd - return sd - return None - - def _log_likelihood(self, name, value, mean=None, sd=None): - r""" - Evaluate log probability. - - .. math:: - L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) - """ - if name in ('prob', 'log_prob'): - mean = self._mean_value if mean is None else mean - sd = self._sd_value if sd is None else sd - unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), - 2. * self.sq(sd)) - neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) - return self.add(unnormalized_log_prob, neg_normalization) - return None - - def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): - r""" - Evaluate Normal-Normal kl divergence, i.e. KL(a||b). - - Args: - name (str): name of the funtion passed in from construct. Should always be "kl_loss". - dist (str): type of the distributions. Should be "Normal" in this case. - mean_b (Tensor): mean of distribution b. - sd_b (Tensor): standard deviation distribution b. - mean_a (Tensor): mean of distribution a. Default: self._mean_value. - sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. - - .. math:: - KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + - 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) - """ - if name == 'kl_loss' and dist == 'Normal': - mean_a = self._mean_value if mean_a is None else mean_a - sd_a = self._sd_value if sd_a is None else sd_a - diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b)) - squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b))) - return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale) - return None - - def _sample(self, name, shape=(), mean=None, sd=None): - """ - Sampling. - - Args: - name (str): name of the function. Should always be 'sample' when passed in from construct. - shape (tuple): shape of the sample. Default: (). - mean (Tensor): mean of the samples. Default: self._mean_value. - sd (Tensor): standard deviation of the samples. Default: self._sd_value. - - Returns: - Tensor, shape is shape + batch_shape. - """ - if name == 'sample': - mean = self._mean_value if mean is None else mean - sd = self._sd_value if sd is None else sd - batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd))) - sample_shape = shape + batch_shape - mean_zero = self.const(0.0) - sd_one = self.const(1.0) - sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample = self.add(mean, self.mul(sample_norm, sd)) - return sample - return None diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 6eeba415a7..aca241419c 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -231,8 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): >>> cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ - validator.check_float_positive('min_lr', min_lr, None) - validator.check_float_legal_value('min_lr', min_lr, None) + if not isinstance(min_lr, float): + raise TypeError("min_lr must be float.") + validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_float_positive('max_lr', max_lr, None) validator.check_float_legal_value('max_lr', max_lr, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) @@ -288,8 +289,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e """ validator.check_float_positive('learning_rate', learning_rate, None) validator.check_float_legal_value('learning_rate', learning_rate, None) - validator.check_float_positive('end_learning_rate', end_learning_rate, None) - validator.check_float_legal_value('end_learning_rate', end_learning_rate, None) + if not isinstance(end_learning_rate, float): + raise TypeError("end_learning_rate must be float.") + validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_float_positive('power', power, None) validator.check_float_legal_value('power', power, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) @@ -311,11 +313,58 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e return lr +def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch): + r""" + Get learning rate warming up. + + For the i-th step, the formula of computing warmup_learning_rate[i] is: + + .. math:: + warmup\_learning\_rate[i] = learning\_rate * tmp\_epoch / tmp\_warmup\_epoch + + Where :math:`tmp\_epoch=min(current\_epoch, warmup\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})` + + Args: + learning_rate (float): The initial value of learning rate. + warmup_steps (int): The warm up steps of learning rate. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> total_step = 6 + >>> step_per_epoch = 2 + >>> warmup_epoch = 2 + >>> warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch) + [0.0, 0.0, 0.05, 0.05, 0.1, 0.1] + """ + if not isinstance(learning_rate, float): + raise TypeError("learning_rate must be float.") + validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) + validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + + function = lambda x, y: (x, min(x, y)) + + lr = [] + for i in range(total_step): + current_epoch = math.floor(i / step_per_epoch) + warmup_epoch, tmp_epoch = function(warmup_epoch, current_epoch) + lr.append(learning_rate * tmp_epoch/ warmup_epoch) + return lr + + __all__ = [ 'piecewise_constant_lr', 'exponential_decay_lr', 'natural_exp_decay_lr', 'inverse_decay_lr', 'cosine_decay_lr', - 'polynomial_decay_lr' + 'polynomial_decay_lr', + 'warmup_lr' ] diff --git a/mindspore/nn/graph_kernels/graph_kernels.py b/mindspore/nn/graph_kernels/graph_kernels.py index 21cc4f8710..21a4c38ac5 100644 --- a/mindspore/nn/graph_kernels/graph_kernels.py +++ b/mindspore/nn/graph_kernels/graph_kernels.py @@ -1020,7 +1020,7 @@ class LayerNorm(Cell): Examples: >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) - >>> shape1 = x.shape()[1:] + >>> shape1 = x.shape[1:] >>> m = G.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) >>> m(x) """ diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 384f625133..8135e61a3b 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -454,7 +454,7 @@ class HSigmoid(Cell): Hard sigmoid is defined as: .. math:: - \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})), + \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{x_{i} + 3}{6})), where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index ecff453fab..f41268b6bf 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -33,7 +33,6 @@ from .activation import get_activation from ..._checkparam import Validator as validator from ..._checkparam import Rel - __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] @@ -251,6 +250,10 @@ def _is_equal_one(x): return False return bool(x.asnumpy().mean() == 1.0) +@constexpr +def _dtype_check(x_dtype): + if x_dtype not in [mstype.float32, mstype.float16]: + raise TypeError("The input type must be float32 or float16.") class ClipByNorm(Cell): r""" @@ -265,12 +268,11 @@ class ClipByNorm(Cell): where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`. Inputs: - - **input** (Tensor) - Tensor of shape N-D. - - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)` and of - the same type as the input Tensor. + - **input** (Tensor) - Tensor of shape N-D. The type should be float32 or float16. + - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`. Outputs: - Tensor, clipped tensor with the same shape as the input. + Tensor, clipped tensor with the same shape as the input, whose type is float32. Examples: >>> net = nn.ClipByNorm() @@ -286,7 +288,6 @@ class ClipByNorm(Cell): self.select_ = P.Select() self.greater_ = P.Greater() self.cast = P.Cast() - self.zero = Tensor(np.array([0.0]).astype(np.float32)) self.sqrt = P.Sqrt() self.max_op = P.Maximum() self.shape = P.Shape() @@ -300,12 +301,12 @@ class ClipByNorm(Cell): """add ms_function decorator for pynative mode""" mul_x = F.square(x) l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32) - cond = self.greater_(l2sum, self.zero) + cond = self.greater_(l2sum, 0) ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) - l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) + _dtype_check(self.dtype(x)) if _is_equal_one(clip_norm): intermediate = x else: @@ -407,11 +408,13 @@ class OneHot(Cell): super(OneHot, self).__init__() self.onehot = P.OneHot(axis) self.depth = depth - self.on_value = Tensor(on_value, dtype) - self.off_value = Tensor(off_value, dtype) + self.dtype = dtype + self.on_value = on_value + self.off_value = off_value def construct(self, indices): - return self.onehot(indices, self.depth, self.on_value, self.off_value) + return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype)) + class Pad(Cell): @@ -591,7 +594,7 @@ class MatrixDiagPart(Cell): Tensor, same type as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. Examples: - >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) >>> matrix_diag_part = nn.MatrixDiagPart() >>> result = matrix_diag_part(x) [[-1., 1.], [-1., 1.], [-1., 1.]] @@ -622,11 +625,11 @@ class MatrixSetDiag(Cell): Tensor, same type as input `x`. The shape same as `x`. Examples: - >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) >>> matrix_set_diag = nn.MatrixSetDiag() >>> result = matrix_set_diag(x, diagonal) - [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] + [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]] """ def __init__(self): super(MatrixSetDiag, self).__init__() diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index 77c6ace75d..2a6c21c0ab 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -13,10 +13,13 @@ # limitations under the License. # ============================================================================ """conv""" +import numpy as np from mindspore import log as logger from mindspore.ops import operations as P +from mindspore.ops.primitive import constexpr from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer +from mindspore.common.tensor import Tensor from mindspore._checkparam import ParamValidator as validator, Rel from mindspore._checkparam import Validator from mindspore._checkparam import check_bool, twice, check_int_positive @@ -254,6 +257,11 @@ class Conv2d(_Conv): return s +@constexpr +def _check_input_3d(input_shape): + if len(input_shape) != 3: + raise ValueError(f"Input should be 3d, but got shape {input_shape}") + class Conv1d(_Conv): r""" 1D convolution layer. @@ -359,6 +367,15 @@ class Conv1d(_Conv): kernel_size = (1, kernel_size) stride = (1, stride) dilation = (1, dilation) + get_shape = P.Shape() + get_dtype = P.DType() + if isinstance(weight_init, Tensor): + weight_init_shape = get_shape(weight_init) + Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name) + weight_init_dtype = get_dtype(weight_init) + weight_init_value = weight_init.asnumpy() + weight_init_value = np.expand_dims(weight_init_value, 2) + weight_init = Tensor(weight_init_value, weight_init_dtype) super(Conv1d, self).__init__( in_channels, @@ -391,13 +408,13 @@ class Conv1d(_Conv): def construct(self, x): x_shape = self.shape(x) - if len(x_shape) == 3: - x = self.expand_dims(x, 2) + _check_input_3d(x_shape) + x = self.expand_dims(x, 2) output = self.conv2d(x, self.weight) if self.has_bias: output = self.bias_add(output, self.bias) - if len(x_shape) == 3: - output = self.squeeze(output) + + output = self.squeeze(output) return output def extend_repr(self): @@ -669,6 +686,15 @@ class Conv1dTranspose(_Conv): kernel_size = (1, kernel_size) stride = (1, stride) dilation = (1, dilation) + get_shape = P.Shape() + get_dtype = P.DType() + if isinstance(weight_init, Tensor): + weight_init_shape = get_shape(weight_init) + Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name) + weight_init_dtype = get_dtype(weight_init) + weight_init_value = weight_init.asnumpy() + weight_init_value = np.expand_dims(weight_init_value, 2) + weight_init = Tensor(weight_init_value, weight_init_dtype) # out_channels and in_channels swap. # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel, # then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel. @@ -733,8 +759,8 @@ class Conv1dTranspose(_Conv): def construct(self, x): x_shape = self.shape(x) - if len(x_shape) == 3: - x = self.expand_dims(x, 2) + _check_input_3d(x_shape) + x = self.expand_dims(x, 2) n, _, h, w = self.shape(x) @@ -746,8 +772,7 @@ class Conv1dTranspose(_Conv): if self.has_bias: output = self.bias_add(output, self.bias) - if len(x_shape) == 3: - output = self.squeeze(output) + output = self.squeeze(output) return output def extend_repr(self): @@ -880,6 +905,8 @@ class DepthwiseConv2d(Cell): self.dilation = dilation self.group = group self.has_bias = has_bias + self.weight_init = weight_init + self.bias_init = bias_init self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=self.kernel_size, pad_mode=self.pad_mode, diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 63ae7a94ac..88ab386c1a 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -99,6 +99,10 @@ def _check_input_filter_size(input_shape, param_name, filter_size, func_name): validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name) validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name) +@constexpr +def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): + validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) + def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0): return Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, weight_init=weight, padding=padding, pad_mode="valid") @@ -211,6 +215,7 @@ class SSIM(Cell): self.concat = P.Concat(axis=1) def construct(self, img1, img2): + _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name) _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name) P.SameTypeShape()(img1, img2) max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index 71c2920850..c640f89557 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -133,7 +133,8 @@ class LSTM(Cell): self.transpose2 = P.Transpose() num_directions = 2 if self.bidirectional else 1 self.cpu_target = False - if context.get_context("device_target") == "CPU": + enable_debug = context.get_context("enable_debug_runtime") + if context.get_context("device_target") == "CPU" and not enable_debug: self.cpu_target = True if not self.cpu_target: self.lstm = P.LSTM(input_size=self.input_size, diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 05e5e54b96..c09a02b3c2 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -44,7 +44,8 @@ class _BatchNorm(Cell): moving_mean_init='zeros', moving_var_init='ones', use_batch_statistics=None, - device_num_each_group=1): + device_num_each_group=1, + input_dims='2d'): super(_BatchNorm, self).__init__() if num_features < 1: raise ValueError("num_features must be at least 1") @@ -55,6 +56,7 @@ class _BatchNorm(Cell): self.use_batch_statistics = use_batch_statistics self.num_features = num_features self.eps = eps + self.input_dims = input_dims self.moving_mean = Parameter(initializer( moving_mean_init, num_features), name="mean", requires_grad=False) self.moving_variance = Parameter(initializer( @@ -99,6 +101,9 @@ class _BatchNorm(Cell): epsilon=self.eps, momentum=self.momentum) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps) + self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend)) + self.enable_default_train = self.is_graph_mode and not self.is_global and \ + (self.is_ge_backend or self.is_ascend) data_parallel_strategy = ((1,), (1,)) data_parallel_strategy_one = ((1,), ()) @@ -145,45 +150,43 @@ class _BatchNorm(Cell): return y def construct(self, x): + _shape_check_bn(self.shape(x), self.input_dims) if self.use_batch_statistics is None: flag = self.training else: flag = self.use_batch_statistics + if flag: - if self.is_ge_backend and self.is_global: + if self.enable_global_sync: axes, re_shape = _shape_infer(F.shape(x), self.num_features) - y = self._global_sync(x, axes, re_shape) - elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend): - if self.is_global: - axes, re_shape = _shape_infer(F.shape(x), self.num_features) - y = self._global_sync(x, axes, re_shape) - else: - y, batch_mean, batch_var, _, _ = \ - self.bn_train(x, - self.gamma, - self.beta, - None, - None) - - mean_sub = self.sub_mean(self.moving_mean, batch_mean) - temp_mean = self.mul_mean(mean_sub, self.momentum) - mean_sub2 = self.sub_var(self.moving_variance, batch_var) - temp_variance = self.mul_var(mean_sub2, self.momentum) - y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) - y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) - else: - y = self.bn_train(x, - self.gamma, - self.beta, - self.moving_mean, - self.moving_variance)[0] - else: - y = self.bn_infer(x, - self.gamma, - self.beta, - self.moving_mean, - self.moving_variance)[0] - return y + return self._global_sync(x, axes, re_shape) + + if self.enable_default_train: + y, batch_mean, batch_var, _, _ = self.bn_train(x, + self.gamma, + self.beta, + None, + None) + + mean_sub = self.sub_mean(self.moving_mean, batch_mean) + temp_mean = self.mul_mean(mean_sub, self.momentum) + mean_sub2 = self.sub_var(self.moving_variance, batch_var) + temp_variance = self.mul_var(mean_sub2, self.momentum) + y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) + y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) + return y + + return self.bn_train(x, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] + + return self.bn_infer(x, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] def extend_repr(self): return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( @@ -199,7 +202,18 @@ def _channel_check(channel, num_channel): @constexpr def _shape_check(in_shape): if len(in_shape) != 4: - raise ValueError("The input must has 4 dims") + raise ValueError("The input must has 4 dims.") + + +@constexpr +def _shape_check_bn(in_shape, in_dims): + dim = len(in_shape) + if in_dims == '1d' and dim != 2: + raise ValueError("The input must has 2 dims.") + if in_dims == '2d' and dim != 4: + raise ValueError("The input must has 4 dims.") + if in_dims == 'both' and dim != 2 and dim != 4: + raise ValueError("The input must has 2 dims or 4 dims.") @constexpr @@ -253,10 +267,10 @@ class BatchNorm1d(_BatchNorm): mean and variance. Default: None. Inputs: - - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`. Outputs: - Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`. Examples: >>> net = nn.BatchNorm1d(num_features=16) @@ -282,7 +296,8 @@ class BatchNorm1d(_BatchNorm): beta_init, moving_mean_init, moving_var_init, - use_batch_statistics) + use_batch_statistics, + input_dims='1d') def _check_data_dim(self, x): if x.dim() != 2: @@ -357,7 +372,8 @@ class BatchNorm2d(_BatchNorm): beta_init, moving_mean_init, moving_var_init, - use_batch_statistics) + use_batch_statistics, + input_dims='2d') def _check_data_dim(self, x): if x.dim() != 4: @@ -435,7 +451,8 @@ class GlobalBatchNorm(_BatchNorm): moving_mean_init, moving_var_init, use_batch_statistics, - device_num_each_group) + device_num_each_group, + input_dims='both') self.group = check_int_positive(device_num_each_group) if self.group <= 1: raise ValueError("the number of group must be greater than 1.") diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 63cdedbfe9..0b94de155d 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Quantization aware.""" +"""Quantization aware training.""" from functools import partial import numpy as np @@ -28,8 +28,8 @@ from mindspore._checkparam import check_int_positive, check_bool, twice from mindspore._checkparam import Rel import mindspore.context as context -from .normalization import BatchNorm2d -from .activation import get_activation +from .normalization import BatchNorm2d, BatchNorm1d +from .activation import get_activation, ReLU from ..cell import Cell from . import conv, basic from ..._checkparam import ParamValidator as validator @@ -39,10 +39,12 @@ __all__ = [ 'Conv2dBnAct', 'DenseBnAct', 'FakeQuantWithMinMax', - 'Conv2dBatchNormQuant', + 'Conv2dBnFoldQuant', + 'Conv2dBnWithoutFoldQuant', 'Conv2dQuant', 'DenseQuant', 'ActQuant', + 'LeakyReLUQuant', 'HSwishQuant', 'HSigmoidQuant', 'TensorAddQuant', @@ -83,7 +85,7 @@ class Conv2dBnAct(Cell): Initializer and string are the same as 'weight_init'. Refer to the values of Initializer for more details. Default: 'zeros'. has_bn (bool): Specifies to used batchnorm or not. Default: False. - activation (string): Specifies activation type. The optional values are as following: + activation (Cell): Specifies activation type. The optional values are as following: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. @@ -170,7 +172,7 @@ class DenseBnAct(Cell): bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + activation (Cell): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. has_bn (bool): Specifies to used batchnorm or not. Default: False. activation (string): Specifies activation type. The optional values are as following: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', @@ -206,7 +208,7 @@ class DenseBnAct(Cell): self.has_bn = validator.check_bool("has_bn", has_bn) self.has_act = activation is not None if has_bn: - self.batchnorm = BatchNorm2d(out_channels) + self.batchnorm = BatchNorm1d(out_channels) self.activation = get_activation(activation) def construct(self, x): @@ -349,7 +351,7 @@ class FakeQuantWithMinMax(Cell): self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) # init fake quant relative op - if per_channel: + if self.per_channel: quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) else: @@ -369,7 +371,7 @@ class FakeQuantWithMinMax(Cell): num_bits=self.num_bits, symmetric=self.symmetric, narrow_range=self.narrow_range, - quant_delay=quant_delay) + quant_delay=self.quant_delay) self.fake_quant_train = quant_fun(training=True) self.fake_quant_infer = quant_fun(training=False) @@ -392,7 +394,7 @@ class FakeQuantWithMinMax(Cell): return out -class Conv2dBatchNormQuant(Cell): +class Conv2dBnFoldQuant(Cell): r""" 2D convolution with BatchNormal op folded layer. @@ -403,8 +405,8 @@ class Conv2dBatchNormQuant(Cell): out_channels (int): The number of output channel :math:`C_{out}`. kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. stride (int): Specifies stride for all spatial dimensions with the same value. - pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". - padding: (int): Implicit paddings on both sides of the input. Default: 0. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. eps (float): Parameters for BatchNormal. Default: 1e-5. momentum (float): Parameters for BatchNormal op. Default: 0.997. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the @@ -417,7 +419,7 @@ class Conv2dBatchNormQuant(Cell): mean vector. Default: 'zeros'. var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the variance vector. Default: 'ones'. - fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. + fake (bool): Conv2dBnFoldQuant Cell add FakeQuantWithMinMax op or not. Default: True. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. @@ -432,7 +434,7 @@ class Conv2dBatchNormQuant(Cell): Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> batchnorm_quant = nn.Conv2dBatchNormQuant(1, 6, kernel_size= (2, 2), stride=(1, 1), pad_mode="valid", + >>> batchnorm_quant = nn.Conv2dBnFoldQuant(1, 6, kernel_size= (2, 2), stride=(1, 1), pad_mode="valid", >>> dilation=(1, 1)) >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32) >>> result = batchnorm_quant(input_x) @@ -461,8 +463,8 @@ class Conv2dBatchNormQuant(Cell): narrow_range=False, quant_delay=0, freeze_bn=100000): - """init Conv2dBatchNormQuant layer""" - super(Conv2dBatchNormQuant, self).__init__() + """init Conv2dBnFoldQuant layer""" + super(Conv2dBnFoldQuant, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = twice(kernel_size) @@ -579,6 +581,132 @@ class Conv2dBatchNormQuant(Cell): return out +class Conv2dBnWithoutFoldQuant(Cell): + r""" + 2D convolution + batchnorm without fold with fake quant op layer. + + For a more Detailed overview of Conv2d op. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. + stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + has_bn (bool): Specifies to used batchnorm or not. Default: False. + eps (float): Parameters for BatchNormal. Default: 1e-5. + momentum (float): Parameters for BatchNormal op. Default: 0.997. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid", + >>> dilation=(1, 1)) + >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mstype.float32) + >>> result = conv2d_quant(input_x) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + has_bn=True, + eps=1e-5, + momentum=0.997, + weight_init='normal', + bias_init='zeros', + per_channel=False, + num_bits=8, + symmetric=False, + narrow_range=False, + quant_delay=0): + super(Conv2dBnWithoutFoldQuant, self).__init__() + if isinstance(kernel_size, int): + self.kernel_size = (kernel_size, kernel_size) + else: + self.kernel_size = kernel_size + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + self.has_bias = has_bias + self.stride = twice(stride) + self.dilation = twice(dilation) + self.pad_mode = pad_mode + self.padding = padding + self.group = group + self.quant_delay = quant_delay + + weight_shape = [out_channels, in_channels // group, *self.kernel_size] + self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') + + self.bias_add = P.BiasAdd() + if check_bool(has_bias): + self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') + else: + self.bias = None + + self.conv = P.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation, + group=self.group) + self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=False, + per_channel=per_channel, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, + symmetric=symmetric, + narrow_range=narrow_range, + quant_delay=quant_delay) + self.has_bn = validator.check_bool("has_bn", has_bn) + if has_bn: + self.batchnorm = BatchNorm2d(out_channels) + + def construct(self, x): + weight = self.fake_quant_weight(self.weight) + out = self.conv(x, weight) + if self.has_bias: + out = self.bias_add(out, self.bias) + if self.has_bn: + out = self.batchnorm(out) + return out + + def extend_repr(self): + s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ + 'pad_mode={}, padding={}, dilation={}, group={}, ' \ + 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias, self.quant_delay) + return s + + class Conv2dQuant(Cell): r""" 2D convolution with fake quant op layer. @@ -590,8 +718,8 @@ class Conv2dQuant(Cell): out_channels (int): The number of output channel :math:`C_{out}`. kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1. - pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". - padding: (int): Implicit paddings on both sides of the input. Default: 0. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1. group (int): Split filter into groups, `in_ channels` and `out_channels` should be divisible by the number of groups. Default: 1. @@ -746,8 +874,8 @@ class DenseQuant(Cell): self.has_bias = check_bool(has_bias) if isinstance(weight_init, Tensor): - if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ - weight_init.shape()[1] != in_channels: + if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ + weight_init.shape[1] != in_channels: raise ValueError("weight_init shape error") self.weight = Parameter(initializer( @@ -755,7 +883,7 @@ class DenseQuant(Cell): if self.has_bias: if isinstance(bias_init, Tensor): - if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: + if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: raise ValueError("bias_init shape error") self.bias = Parameter(initializer( @@ -832,7 +960,7 @@ class ActQuant(_QuantActivation): Tensor, with the same type and shape as the `x`. Examples: - >>> act_quant = nn.ActQuant(4, 1) + >>> act_quant = nn.ActQuant(nn.ReLU()) >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32) >>> result = act_quant(input_x) """ @@ -855,7 +983,7 @@ class ActQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.act = activation() + self.act = activation def construct(self, x): x = self.act(x) @@ -865,6 +993,75 @@ class ActQuant(_QuantActivation): def get_origin(self): return self.act +class LeakyReLUQuant(_QuantActivation): + r""" + LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP. + + For a more Detailed overview of HSwish op. + + Args: + activation (Cell): Activation cell class. + ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + + Inputs: + - **x** (Tensor) - The input of LeakyReLUQuant. + + Outputs: + Tensor, with the same type and shape as the `x`. + + Examples: + >>> activation = nn.LeakyReLUQuant(nn.LeakyReLU()) + >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) + >>> result = activation(input) + """ + + def __init__(self, + activation, + ema_decay=0.999, + per_channel=False, + num_bits=8, + symmetric=False, + narrow_range=False, + quant_delay=0): + super(LeakyReLUQuant, self).__init__() + self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, + symmetric=symmetric, + narrow_range=narrow_range, + quant_delay=quant_delay) + self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, + symmetric=symmetric, + narrow_range=narrow_range, + quant_delay=quant_delay) + if issubclass(activation.__class__, nn.LeakyReLU): + self.act = activation + else: + raise ValueError("Activation should be `nn.LeakyReLU`") + + def construct(self, x): + x = self.fake_quant_act_before(x) + x = self.act(x) + x = self.fake_quant_act_after(x) + return x + + def get_origin(self): + return self.act + + class HSwishQuant(_QuantActivation): r""" @@ -888,9 +1085,9 @@ class HSwishQuant(_QuantActivation): Tensor, with the same type and shape as the `x`. Examples: - >>> hswish_quant = nn.HSwishQuant(4, 1) - >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) - >>> result = hswish_quant(input_x) + >>> activation = nn.HSwishQuant(nn.HSwish()) + >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) + >>> result = activation(input) """ def __init__(self, @@ -920,8 +1117,8 @@ class HSwishQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - if issubclass(activation, nn.HSwish): - self.act = activation() + if issubclass(activation.__class__, nn.HSwish): + self.act = activation else: raise ValueError("Activation should be `nn.HSwish`") @@ -957,9 +1154,9 @@ class HSigmoidQuant(_QuantActivation): Tensor, with the same type and shape as the `x`. Examples: - >>> hsigmoid_quant = nn.HSigmoidQuant(4, 1) - >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) - >>> result = hsigmoid_quant(input_x) + >>> activation = nn.HSigmoidQuant(nn.HSigmoid()) + >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) + >>> result = activation(input) """ def __init__(self, @@ -989,8 +1186,8 @@ class HSigmoidQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - if issubclass(activation, nn.HSwish): - self.act = activation() + if issubclass(activation.__class__, nn.HSigmoid): + self.act = activation else: raise ValueError("Activation should be `nn.HSigmoid`") @@ -1107,7 +1304,7 @@ class QuantBlock(Cell): r""" A quant block of Conv/Dense, activation layer for Ascend deploy. - Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant. + Calculate Conv or Dense in Int8, with Quant and DeQuant. Notes: This block is only for deploy, and not trainable. @@ -1156,13 +1353,18 @@ class QuantBlock(Cell): self.has_bias = bias is not None self.activation = activation self.has_act = activation is not None + if isinstance(activation, ReLU): + self.activation = None + self.has_act = False + self.dequant.add_prim_attr("relu_flag", True) self.bias_add = P.BiasAdd() def construct(self, x): x = self.quant(x) - x = self.core_op(x, self.weight) if self.has_bias: - x = self.bias_add(x, self.bias) + x = self.core_op(x, self.weight, self.bias) + else: + x = self.core_op(x, self.weight) if self.has_act: x = self.activation(x) x = self.dequant(x, self.dequant_scale) diff --git a/mindspore/nn/learning_rate_schedule.py b/mindspore/nn/learning_rate_schedule.py new file mode 100644 index 0000000000..118dde4738 --- /dev/null +++ b/mindspore/nn/learning_rate_schedule.py @@ -0,0 +1,379 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning rate schedule.""" + +import math + +from ..common import dtype as mstype +from ..ops import operations as P +from .cell import Cell +from .._checkparam import Validator as validator +from .._checkparam import Rel + + +class LearningRateSchedule(Cell): + """Basic class of learning rate schedule.""" + def __init__(self): + super(LearningRateSchedule, self).__init__() + + def construct(self, global_step): + """ + Defines the computation to get the current learning rate. + + This method should be overridden by all subclasses. + + Note: + The output should be a Tensor of scalar. + + Inputs: + Tensor. The current step number. + """ + raise NotImplementedError + + +def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): + validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, cls_name) + validator.check_float_positive('learning_rate', learning_rate, cls_name) + validator.check_float_legal_value('learning_rate', learning_rate, cls_name) + validator.check_float_positive('decay_rate', decay_rate, cls_name) + validator.check_float_legal_value('decay_rate', decay_rate, cls_name) + validator.check_value_type('is_stair', is_stair, [bool], cls_name) + + +class ExponentialDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on exponential decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{p} + + Where :math:`p = \frac{current\_step}{decay\_steps}`, if `is_stair` is True, The formula + is :math:`p = floor(\frac{current\_step}{decay\_steps})`. + + Args: + learning_rate (float): The initial value of learning rate. + decay_rate (float): The decay rate. + decay_steps (int): A value used to calculate decayed learning rate. + is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> decay_rate = 0.9 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> exponential_decay_lr = ExponentialDecayLR(learning_rate, decay_rate, decay_steps) + >>> exponential_decay_lr(global_step) + """ + def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False): + super(ExponentialDecayLR, self).__init__() + _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name) + self.learning_rate = learning_rate + self.decay_rate = decay_rate + self.decay_steps = decay_steps + self.is_stair = is_stair + self.pow = P.Pow() + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(global_step, mstype.float32) / self.decay_steps + if self.is_stair: + p = P.Floor()(p) + return self.learning_rate * self.pow(self.decay_rate, p) + + +class NaturalExpDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on natural exponential decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * p} + + Where :math:`p = \frac{current\_step}{decay\_steps}`, if `is_stair` is True, The formula + is :math:`p = floor(\frac{current\_step}{decay\_steps})`. + + Args: + learning_rate (float): The initial value of learning rate. + decay_rate (float): The decay rate. + decay_steps (int): A value used to calculate decayed learning rate. + is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> decay_rate = 0.9 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> natural_exp_decay_lr = NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True) + >>> natural_exp_decay_lr(global_step) + """ + def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False): + super(NaturalExpDecayLR, self).__init__() + _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name) + self.learning_rate = learning_rate + self.decay_rate = decay_rate + self.decay_steps = decay_steps + self.is_stair = is_stair + self.math_e = math.e + self.pow = P.Pow() + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(global_step, mstype.float32) + if self.is_stair: + p = P.FloorDiv()(p, self.decay_steps) * self.decay_steps + return self.learning_rate * self.pow(self.math_e, -self.decay_rate * p) + + +class InverseDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on inverse-time decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * p) + + Where :math:`p = \frac{current\_step}{decay\_steps}`, if `is_stair` is True, The formula + is :math:`p = floor(\frac{current\_step}{decay\_steps})`. + + Args: + learning_rate (float): The initial value of learning rate. + decay_rate (float): The decay rate. + decay_steps (int): A value used to calculate decayed learning rate. + is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> decay_rate = 0.9 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> inverse_decay_lr = InverseDecayLR(learning_rate, decay_rate, decay_steps, True) + >>> inverse_decay_lr(global_step) + """ + def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False): + super(InverseDecayLR, self).__init__() + _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name) + self.learning_rate = learning_rate + self.decay_rate = decay_rate + self.decay_steps = decay_steps + self.is_stair = is_stair + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(global_step, mstype.float32) / self.decay_steps + if self.is_stair: + p = P.Floor()(p) + return self.learning_rate / (1 + self.decay_rate * p) + + +class CosineDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on cosine decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) * + (1 + cos(\frac{current\_step}{decay\_steps}\pi)) + + + Args: + min_lr (float): The minimum value of learning rate. + max_lr (float): The maximum value of learning rate. + decay_steps (int): A value used to calculate decayed learning rate. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> min_lr = 0.01 + >>> max_lr = 0.1 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> cosine_decay_lr = CosineDecayLR(min_lr, max_lr, decay_steps) + >>> cosine_decay_lr(global_steps) + """ + def __init__(self, min_lr, max_lr, decay_steps): + super(CosineDecayLR, self).__init__() + if not isinstance(min_lr, float): + raise TypeError("min_lr must be float.") + validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + validator.check_float_positive('max_lr', max_lr, self.cls_name) + validator.check_float_legal_value('max_lr', max_lr, self.cls_name) + validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name) + if min_lr >= max_lr: + raise ValueError('`max_lr` should be greater than `min_lr`.') + self.min_lr = min_lr + self.max_lr = max_lr + self.decay_steps = decay_steps + self.math_pi = math.pi + self.delta = 0.5 * (max_lr - min_lr) + self.cos = P.Cos() + self.min = P.Minimum() + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(self.min(global_step, self.decay_steps), mstype.float32) + return self.min_lr + self.delta * (1.0 + self.cos(self.math_pi * p / self.decay_steps)) + + +class PolynomialDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on polynomial decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) * + (1 - tmp\_step / tmp\_decay\_steps)^{power} + end\_learning\_rate + + Where :math:`tmp\_step=min(current\_step, decay\_steps). + If `update_decay_steps` is true, update the value of `tmp_decay_step` every `decay_steps`. The formula + is :math:`tmp\_decay\_steps = decay\_steps * ceil(current\_step / decay\_steps)` + + Args: + learning_rate (float): The initial value of learning rate. + end_learning_rate (float): The end value of learning rate. + decay_steps (int): A value used to calculate decayed learning rate. + power (float): A value used to calculate decayed learning rate. This parameter should be greater than 0. + update_decay_steps (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> end_learning_rate = 0.01 + >>> decay_steps = 4 + >>> power = 0.5 + >>> global_step = Tenosr(2, mstype.int32) + >>> polynomial_decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + >>> polynomial_decay_lr(global_step) + """ + def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False): + super(PolynomialDecayLR, self).__init__() + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) + if not isinstance(end_learning_rate, float): + raise TypeError("end_learning_rate must be float.") + validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, + self.cls_name) + validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name) + validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) + validator.check_float_positive('power', power, self.cls_name) + validator.check_float_legal_value('power', power, self.cls_name) + + self.decay_steps = decay_steps + self.start_learning_rate = learning_rate + self.end_learning_rate = end_learning_rate + self.diff_learning_rate = learning_rate - end_learning_rate + self.power = power + self.update_decay_steps = update_decay_steps + self.pow = P.Pow() + self.ceil = P.Ceil() + self.min = P.Minimum() + self.max = P.Maximum() + + def construct(self, global_step): + tmp_global_step = P.Cast()(global_step, mstype.float32) + tmp_decay_step = self.decay_steps + if self.update_decay_steps: + tmp_decay_step = tmp_decay_step * self.max(self.ceil(tmp_global_step / tmp_decay_step), 1) + else: + tmp_global_step = self.min(tmp_global_step, tmp_decay_step) + p = tmp_global_step / tmp_decay_step + lr = self.diff_learning_rate * self.pow(1.0 - p, self.power) + self.end_learning_rate + return lr + + +class WarmUpLR(LearningRateSchedule): + r""" + Get learning rate warming up. + + For the i-th step, the formula of computing warmup_learning_rate[i] is: + + .. math:: + warmup\_learning\_rate[i] = learning\_rate * tmp\_step / warmup\_steps + + Where :math:`tmp\_step=min(current\_step, warmup\_steps)`. + + Args: + learning_rate (float): The initial value of learning rate. + warmup_steps (int): The warm up steps of learning rate. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> warmup_steps = 2 + >>> global_step = Tenosr(2, mstype.int32) + >>> warmup_lr = WarmUpLR(learning_rate, warmup_steps) + >>> warmup_lr(global_step) + """ + def __init__(self, learning_rate, warmup_steps): + super(WarmUpLR, self).__init__() + if not isinstance(learning_rate, float): + raise TypeError("learning_rate must be float.") + validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GT, self.cls_name) + self.warmup_steps = warmup_steps + self.learning_rate = learning_rate + self.min = P.Minimum() + self.cast = P.Cast() + + def construct(self, global_step): + warmup_percent = self.cast(self.min(global_step, self.warmup_steps), mstype.float32)/ self.warmup_steps + return self.learning_rate * warmup_percent + + +__all__ = [ + 'ExponentialDecayLR', + 'NaturalExpDecayLR', + 'InverseDecayLR', + 'CosineDecayLR', + 'PolynomialDecayLR', + 'WarmUpLR' +] diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 3f97fbf83c..5f17baf64a 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -218,7 +218,8 @@ class SoftmaxCrossEntropyWithLogits(_Loss): sparse (bool): Specifies whether labels use sparse format or not. Default: False. reduction (Union[str, None]): Type of reduction to apply to loss. Support 'sum' or 'mean' If None, do not reduction. Default: None. - smooth_factor (float): Label smoothing factor. It is a optional input. Default: 0. + smooth_factor (float): Label smoothing factor. It is a optional input which should be in range [0, 1]. + Default: 0. num_classes (int): The number of classes in the task. It is a optional input Default: 2. Inputs: diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py index 538c400067..e31e6345d4 100644 --- a/mindspore/nn/optim/__init__.py +++ b/mindspore/nn/optim/__init__.py @@ -20,14 +20,14 @@ The optimizer is used to calculate and update the gradients. """ from .optimizer import Optimizer from .momentum import Momentum -from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR +from .adam import Adam, AdamWeightDecay from .lamb import Lamb from .sgd import SGD from .lars import LARS -from .ftrl import FTRL, PSFTRL +from .ftrl import FTRL from .rmsprop import RMSProp from .proximal_ada_grad import ProximalAdagrad from .lazyadam import LazyAdam -__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam', - 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad'] +__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', + 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad'] diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 794d2513e3..56de737176 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -27,12 +27,11 @@ from mindspore._checkparam import Rel from .optimizer import Optimizer _adam_opt = C.MultitypeFuncGraph("adam_opt") -_adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt") -@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") -def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter): +def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter): """ Update parameters. @@ -41,7 +40,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0. + weight_decay (Number): Weight decay. Should be equal to or greater than 0. param (Tensor): Parameters. m (Tensor): m value of parameters. v (Tensor): v value of parameters. @@ -71,10 +70,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)) - update = next_m / (eps + op_sqrt(next_v)) if decay_flag: - update = op_mul(weight_decay_tensor, param_fp32) + update + update = op_mul(weight_decay, param_fp32) + update update_with_lr = op_mul(lr, update) next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) @@ -86,75 +84,51 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad return gradient -def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): - """Check the type of inputs.""" - validator.check_value_type("beta1", beta1, [float], prim_name) - validator.check_value_type("beta2", beta2, [float], prim_name) - validator.check_value_type("eps", eps, [float], prim_name) - validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - - -def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name): - """Check the type of inputs.""" - validator.check_value_type("learning_rate", learning_rate, [float], prim_name) - validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) - validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - validator.check_float_positive('power', power, prim_name) - validator.check_float_legal_value('power', power, prim_name) - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) - - -@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple", - "Tensor", "Tensor", "Tensor") -def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, - moment1, moment2): +@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") +def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, + gradient, params, moment1, moment2, ps_parameter): """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" success = True - success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, - eps, gradient[1], gradient[0])) + indices = gradient.indices + values = gradient.values + if ps_parameter: + op_shape = P.Shape() + shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), + op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), + op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, + eps, values, indices), shapes), params)) + else: + success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, + eps, values, indices)) return success -@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor") -def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, - moment1, moment2): +@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, + params, moment1, moment2, ps_parameter): """Apply adam optimizer to the weight parameter using Tensor.""" success = True - success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, - eps, gradient)) - return success - -@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tuple", "Tensor", "Tensor", "Tensor") -def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, - moment1, moment2): - """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse.""" - success = True - op_shape = P.Shape() - shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), - op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), - op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0])) - success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, - eps, gradient[1], gradient[0]), shapes), params)) + if ps_parameter: + op_shape = P.Shape() + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), + (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) + else: + success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, + eps, gradient)) return success +def _check_param_value(beta1, beta2, eps, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) -@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, - moment1, moment2): - """Apply adam optimizer by push and pull to the weight parameter using Tensor.""" - success = True - op_shape = P.Shape() - success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), - (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) - return success class Adam(Optimizer): r""" @@ -179,12 +153,9 @@ class Adam(Optimizer): :math:`\epsilon` represents `eps`. Note: - The Adam optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. @@ -209,14 +180,14 @@ class Adam(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a - Tensor but the dims of the Tensor is 0, use fixed learning - rate. Other cases are not supported. It should be equal to - or greater than 0. Default: 1e-3. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default: 0.9. beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default: @@ -249,9 +220,9 @@ class Adam(Optimizer): >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] - >>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. - >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. + >>> optm = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() @@ -261,7 +232,7 @@ class Adam(Optimizer): def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, use_nesterov=False, weight_decay=0.0, loss_scale=1.0): super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale) - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + _check_param_value(beta1, beta2, eps, self.cls_name) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) @@ -269,7 +240,7 @@ class Adam(Optimizer): self.beta2 = Tensor(beta2, mstype.float32) self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power") - self.eps = eps + self.eps = Tensor(eps, mstype.float32) self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') @@ -278,51 +249,9 @@ class Adam(Optimizer): self.opt = P.Adam(use_locking, use_nesterov) self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov) - def construct(self, gradients): - params = self.parameters - moment1 = self.moment1 - moment2 = self.moment2 - gradients = self.decay_weight(gradients) - gradients = self.scale_grad(gradients) - lr = self.get_lr() - - beta1_power = self.beta1_power * self.beta1 - self.beta1_power = beta1_power - beta2_power = self.beta2_power * self.beta2 - self.beta2_power = beta2_power - if self.is_group_lr: - success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power, - self.beta1, self.beta2, self.eps), - lr, gradients, params, moment1, moment2) - else: - success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power, - self.beta1, self.beta2, self.eps, lr), - gradients, params, moment1, moment2) - return success - -class PSAdam(Optimizer): - '''The same usage as Adam optimizer except the parameters are set PS mode.''' - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, - use_nesterov=False, weight_decay=0.0, loss_scale=1.0): - super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale) - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) - validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) - validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) - - self.beta1 = Tensor(beta1, mstype.float32) - self.beta2 = Tensor(beta2, mstype.float32) - self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") - self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power") - self.eps = Tensor(eps, mstype.float32) - - self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') - self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') - - self.hyper_map = C.HyperMap() - self.push = P.Push("Adam", [0, 1, 2]) - self.push.add_prim_attr("primitive_target", "CPU") - self.pull = P.Pull() - self.pull.add_prim_attr("primitive_target", "CPU") + self._ps_pull = P.Pull() + self._ps_push = P.Push("Adam", [0, 1, 2]) + self._ps_push.add_prim_attr("use_nesterov", use_nesterov) def construct(self, gradients): params = self.parameters @@ -337,95 +266,52 @@ class PSAdam(Optimizer): beta2_power = self.beta2_power * self.beta2 self.beta2_power = beta2_power if self.is_group_lr: - success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power, - self.beta1, self.beta2, self.eps), - lr, gradients, params, moment1, moment2) + success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps), + lr, gradients, params, moment1, moment2, self.ps_parameters) else: - success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power, - self.beta1, self.beta2, self.eps, lr), - gradients, params, moment1, moment2) + success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), + gradients, params, moment1, moment2, self.ps_parameters) return success + class AdamWeightDecay(Optimizer): """ Implements Adam algorithm weight decay fix. - Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. It should be equal to or - greater than 0. Default: 1e-3. - beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. - Should be in range (0.0, 1.0). - beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. - Should be in range (0.0, 1.0). - eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. - Should be greater than 0. - weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. - - Inputs: - - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. - - Outputs: - tuple[bool], all elements are True. + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is posigive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. - Examples: - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - """ - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(AdamWeightDecay, self).__init__(learning_rate, params) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) - self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) - self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) - self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) + To improve parameter groups performance, the customized order of parameters can be supported. - self.params = self.parameters - self.moments1 = self.params.clone(prefix="adam_m", init='zeros') - self.moments2 = self.params.clone(prefix="adam_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) + Args: + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. - self.hyper_map = C.HyperMap() + - params: Required. The value should be a list of `Parameter`. - def construct(self, gradients): - lr = self.get_lr() - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, - self.decay_flag, self.optim_filter) - if self.use_parallel: - optim_result = self.broadcast_params(optim_result) - return optim_result + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. -class AdamWeightDecayDynamicLR(Optimizer): - """ - Adam Weight Decay Dynamic Learning Rate (LR). + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. - Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - decay_steps (int): The steps of the decay. It must be int and positive. - warmup_steps (int): The steps of lr warm up. Default: 0. - learning_rate (float): A floating point value for the learning rate. It should be equal to or - greater than 0. Default: 0.001. - end_learning_rate (float): A floating point value for the end learning rate. It should be equal - to or greater than 0. Default: 0.0001. - power (float): The Power of the polynomial. It must be positive. Default: 10.0. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. Should be in range (0.0, 1.0). beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. @@ -433,8 +319,6 @@ class AdamWeightDecayDynamicLR(Optimizer): eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. Should be greater than 0. weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. @@ -444,71 +328,48 @@ class AdamWeightDecayDynamicLR(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - """ - def __init__(self, - params, - decay_steps, - warmup_steps=0, - learning_rate=0.001, - end_learning_rate=0.0001, - power=10.0, - beta1=0.9, - beta2=0.999, - eps=1e-6, - weight_decay=0.0, - decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): - super(AdamWeightDecayDynamicLR, self).__init__(0.0, params) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) - _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name) - validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, self.cls_name) - # turn them to scalar when me support scalar/tensor mix operations - self.global_step = Parameter(initializer(0, [1]), name="global_step") - self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) - self.warmup_flag = False - if warmup_steps > 0: - self.warmup_flag = True - self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32)) - self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32)) - self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32)) - self.power = power + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) - self.params = self.parameters - self.moments1 = self.params.clone(prefix="adam_m", init='zeros') - self.moments2 = self.params.clone(prefix="adam_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) + self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') self.hyper_map = C.HyperMap() - self.min = P.Minimum() - self.pow = P.Pow() - self.greater = P.Greater() - self.one = Tensor(np.array([1.0]).astype(np.float32)) - self.cast = P.Cast() - self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32)) def construct(self, gradients): - step = self.min(self.global_step, self.decay_steps) - p = step / self.decay_steps - lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate - if self.warmup_flag: - warmup_percent = self.global_step / self.warmup_steps - warmup_lr = self.start_learning_rate * warmup_percent - is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) - lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, - self.decay_flag, self.optim_filter) + lr = self.get_lr() + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), + self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) if self.use_parallel: - optim_result = self.broadcast_params(optim_result) - added_global_step = self.global_step + self.one - F.control_depend(lr, added_global_step) - self.global_step = added_global_step - + self.broadcast_params(optim_result) return optim_result diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index c0b11d6fa2..d00107dfb7 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -21,48 +21,41 @@ from mindspore._checkparam import Rel from .optimizer import Optimizer, _apply_decay, _grad_scale _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") -_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt") -@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor", - "Tensor") -def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment): +@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", + "RowTensor", "Tensor", "Tensor", "Bool") +def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, + gradient, weight, moment, ps_parameter): """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" success = True - success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0])) + indices = gradient.indices + values = gradient.values + if ps_parameter: + op_shape = P.Shape() + shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) + success = F.depend(success, pull(push((values, indices), shapes), weight)) + else: + success = F.depend(success, spars_opt(weight, moment, linear, values, indices)) return success -@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", - "Tensor") -def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment): +@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Bool") +def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, + gradient, weight, moment, ps_parameter): """Apply ftrl optimizer to the weight parameter.""" success = True - success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) + if ps_parameter: + op_shape = P.Shape() + success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), + (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) + else: + success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) return success -@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", - "Tensor", "Tensor") -def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient, - weight, moment): - success = True - op_shape = P.Shape() - shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0])) - success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight)) - return success - - -@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", - "Tensor", "Tensor") -def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, lr_power, linear, gradient, - weight, moment): - success = True - op_shape = P.Shape() - success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), - (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) - return success -def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None): +def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None): """Check param.""" validator.check_value_type("initial_accum", initial_accum, [float], prim_name) validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name) @@ -78,9 +71,6 @@ def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, validator.check_value_type("use_locking", use_locking, [bool], prim_name) - validator.check_value_type("weight_decay", weight_decay, [float], prim_name) - validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name) - class FTRL(Optimizer): """ @@ -92,22 +82,41 @@ class FTRL(Optimizer): `_ for engineering document. Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on all of the parameters. + + To improve parameter groups performance, the customized order of parameters can be supported. + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. - The sparse feature is under continuous development. The sparse - behavior is currently performed on the CPU. + The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Using different learning rate by separating parameters is currently not supported. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1. - learning_rate (float): The learning rate value, should be positive. Default: 0.001. + learning_rate (float): The learning rate value, should be zero or positive, dynamic learning rate is currently + not supported. Default: 0.001. lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5. l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0. l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0. use_locking (bool): If True use locks for update operation. Default: False. loss_scale (float): Value for the loss scale. It should be equal to or greater than 1.0. Default: 1.0. - wegith_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0. + weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0. Inputs: - **grads** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is as same as the `params` @@ -118,70 +127,55 @@ class FTRL(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.FTRL(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use weight decay of 0.01. + >>> # The no_conv_params's parameters will use default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> opt = nn.FTRL(net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, use_locking=False, loss_scale=1.0, weight_decay=0.0): - super(FTRL, self).__init__(learning_rate, params, loss_scale=loss_scale) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name) + super(FTRL, self).__init__(learning_rate, params, weight_decay, loss_scale=loss_scale) + if self.dynamic_lr or self.is_group_lr: + raise ValueError('Dynamic learning rate or group learning rate is currently not supported.') + _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.linear = self.parameters.clone(prefix="linear", init='zeros') self.l1 = l1 self.l2 = l2 self.lr_power = lr_power - self.weight_decay = weight_decay - self.decay_tf = tuple((lambda: True)() for x in self.parameters) + if not self.is_group: + self.decay_flags = tuple((lambda: True)() for x in self.parameters) self.hyper_map = C.HyperMap() self.opt = P.ApplyFtrl(use_locking=use_locking) self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking) + self._ps_pull = P.Pull() + self._ps_push = P.Push("Ftrl", [0, 1, 2]) + self._ps_push.add_prim_attr("lr", learning_rate) + self._ps_push.add_prim_attr("l1", l1) + self._ps_push.add_prim_attr("l2", l2) + self._ps_push.add_prim_attr("lr_power", lr_power) def construct(self, grads): params = self.parameters moments = self.moments linear = self.linear - lr = self.learning_rate - if self.weight_decay > 0.0: - grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads) - + grads = self.decay_weight(grads) grads = self.scale_grad(grads) - success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power), - linear, grads, params, moments) - return success + lr = self.get_lr() -class PSFTRL(Optimizer): - def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, - use_locking=False, loss_scale=1.0, weight_decay=0.0): - super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name) - self.moments = self.parameters.clone(prefix="moments", init=initial_accum) - self.linear = self.parameters.clone(prefix="linear", init='zeros') - self.l1 = l1 - self.l2 = l2 - self.lr_power = lr_power - self.weight_decay = weight_decay - self.decay_tf = tuple((lambda: True)() for x in self.parameters) - - self.hyper_map = C.HyperMap() - self.push = P.Push("Ftrl", [0, 1, 2]) - self.push.add_prim_attr("primitive_target", "CPU") - self.pull = P.Pull() - self.pull.add_prim_attr("primitive_target", "CPU") - - def construct(self, grads): - params = self.parameters - moments = self.moments - linear = self.linear - lr = self.learning_rate - if self.weight_decay > 0.0: - grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads) - - grads = self.scale_grad(grads) - success = self.map_(F.partial(_ftrl_push_pull_opt, self.push, self.pull, lr, self.l1, self.l2, self.lr_power), - linear, grads, params, moments) + success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, + self.l1, self.l2, self.lr_power, lr), + linear, grads, params, moments, self.ps_parameters) return success diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 93c7edbce8..e17e590c21 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -32,10 +32,9 @@ num_one = Tensor(np.ones([1]), mstype.float32) _lamb_opt = C.MultitypeFuncGraph("lamb_opt") -@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", +@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") -def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, - gradient, decay_flag, optim_filter): +def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter): """ Update parameters. @@ -44,7 +43,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0. + weight_decay (Number): Weight decay. Should be equal to or greater than 0. global_step (Tensor): Global step. param (Tensor): Parameters. m (Tensor): m value of parameters. @@ -87,7 +86,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para w_norm = op_norm(param_fp32) g_norm = op_norm(gradient_fp32) - g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32) + g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32) zeros = F.zeros_like(w_norm) ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) trust_ratio = op_select( @@ -99,7 +98,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para update = next_mm / (op_sqrt(next_vv) + eps) if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) + update = update + op_mul(weight_decay, param_fp32) update_with_lr = op_mul(op_mul(trust_ratio, lr), update) @@ -116,10 +115,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") -@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", +@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, - global_step, param, m, v, gradient, decay_flag): +def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag): """ Update parameters. @@ -128,7 +126,7 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0. + weight_decay (Number): Weight decay. Should be equal to or greater than 0. global_step (Tensor): Global step. param (Tensor): Parameters. m (Tensor): m value of parameters. @@ -157,11 +155,10 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex) i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex) i1 = op_square(gradient_fp32) - add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, - i9, beta2, x1, weight_decay_tensor, eps) + add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, i9, beta2, x1, weight_decay, eps) if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) + update = update + op_mul(weight_decay, param_fp32) w_norm = op_norm(param_fp32) g_norm = op_norm(gradient_fp32) @@ -171,38 +168,18 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0) - next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update, - param, zeros, ones, tens) + next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update, param, zeros, ones, tens) next_v = F.control_depend(add3, next_param) return next_v -def _check_param_value(decay_steps, warmup_steps, start_learning_rate, - end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): - """Check the type of inputs.""" - validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name) - validator.check_number_range("start_learning_rate rate", start_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, - prim_name) - validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) - validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, - prim_name) - validator.check_float_positive('power', power, prim_name) - validator.check_float_legal_value('power', power, prim_name) - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) - validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name) +def _check_param_value(beta1, beta2, eps, prim_name): validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name) - validator.check_value_type( - "weight_dacay", weight_decay, [float], prim_name) - validator.check_number_range( - "beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range( - "beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range( - "eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) - validator.check_number_range( - "weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) class Lamb(Optimizer): @@ -213,16 +190,37 @@ class Lamb(Optimizer): optimization technique. Refer to the paper `LARGE BATCH OPTIMIZATION FOR DEEP LEARNING: TRAINING BERT IN 76 MINUTES `_. + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - decay_steps (int): The steps of the lr decay. Should be equal to or greater than 1. - warmup_steps (int): The steps of lr warm up. Should be equal to or greater than 0. Default: 0. - start_learning_rate (float): A floating point value for the learning rate. Should be equal to - or greater than 0. Default: 0.1. - end_learning_rate (float): A floating point value for the end learning rate. Should be equal to - or greater than 0. Default: 0.0001. - power (float): The power of the polynomial. It must be positive. Default: 1.0. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. Should be in range (0.0, 1.0). beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. @@ -230,8 +228,6 @@ class Lamb(Optimizer): eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. Should be greater than 0. weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. @@ -241,90 +237,84 @@ class Lamb(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.Lamb(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR() + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': poly_decay_lr}, + >>> {'order_params': net.trainable_params(0.01, 0.0001, 10, 0.5)}] + >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default + >>> # weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> optim = nn.Lamb(params=net.trainable_params(), decay_steps=10) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ - def __init__(self, - params, - decay_steps, - warmup_steps=0, - start_learning_rate=0.1, - end_learning_rate=0.0001, - power=1.0, - beta1=0.9, - beta2=0.999, - eps=1e-6, - weight_decay=0.0, - decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): - super(Lamb, self).__init__(0.0, params) - if self.is_group: - raise RuntimeError( - f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, - power, beta1, beta2, eps, weight_decay, self.cls_name) + def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(Lamb, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations - self.global_step = Parameter(initializer(0, [1]), name="global_step") - - self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) - self.warmup_flag = False - if warmup_steps > 0: - self.warmup_flag = True - self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32)) - self.start_learning_rate = Tensor( - np.array([start_learning_rate]).astype(np.float32)) - self.end_learning_rate = Tensor( - np.array([end_learning_rate]).astype(np.float32)) - self.diff_learning_rate = Tensor( - np.array([start_learning_rate - end_learning_rate]).astype(np.float32)) - self.power = power self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor( - np.array([weight_decay]).astype(np.float32)) self.params = self.parameters self.moments1 = self.params.clone(prefix="lamb_m", init='zeros') self.moments2 = self.params.clone(prefix="lamb_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) + if not self.dynamic_lr: + self.global_step = Parameter(initializer(0, [1]), name='global_step') + self.assignadd = P.AssignAdd() self.hyper_map = C.HyperMap() - self.min = P.Minimum() - self.pow = P.Pow() - self.greater = P.Greater() - self.one = Tensor(np.array([1.0]).astype(np.float32)) - self.cast = P.Cast() self.enable_graph_kernel = context.get_context("enable_graph_kernel") def construct(self, gradients): - step = self.min(self.global_step, self.decay_steps) - p = step / self.decay_steps - lr = self.diff_learning_rate * \ - self.pow(self.one - p, self.power) + self.end_learning_rate - if self.warmup_flag: - warmup_percent = self.global_step / self.warmup_steps - warmup_lr = self.start_learning_rate * warmup_percent - is_warmup = self.cast(self.greater( - self.warmup_steps, self.global_step), mstype.float32) - lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr + lr = self.get_lr() if self.enable_graph_kernel: - optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, + self.global_step), + lr, self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags) + else: + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, + self.global_step, lr), + self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags) + else: + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, + self.global_step, lr, self.weight_decay), + self.params, self.moments1, self.moments2, gradients, self.decay_flags) else: - optim_result = self.hyper_map(F.partial(_lamb_opt, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, - self.decay_flag, self.optim_filter) + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step), + lr, self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr), + self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr, self.weight_decay), + self.params, self.moments1, self.moments2, gradients, + self.decay_flags, self.optim_filter) + if self.use_parallel: - optim_result = self.broadcast_params(optim_result) + self.broadcast_params(optim_result) - added_global_step = self.global_step + self.one - F.control_depend(lr, added_global_step) - self.global_step = added_global_step + if not self.dynamic_lr: + F.control_depend(lr, self.assignadd(self.global_step, 1)) return optim_result diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index 7b05b372eb..91ca9a4b22 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -38,14 +38,14 @@ def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_f return gradient + def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name): validator.check_value_type("optimizer", optimizer, Optimizer, prim_name) - if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name: - raise TypeError("LARS can not be used with ", optimizer.cls_name) validator.check_value_type("epsilon", epsilon, [float], prim_name) validator.check_value_type("coefficient", coefficient, [float], prim_name) validator.check_value_type("use_clip", use_clip, [bool], prim_name) + class LARS(Optimizer): """ Implements the LARS algorithm with LARSUpdate Operator. @@ -81,45 +81,71 @@ class LARS(Optimizer): super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")]) _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name) self.opt = optimizer + self.parameters = optimizer.parameters + self.use_clip = use_clip + self.lars_flag = tuple(lars_filter(x) for x in self.parameters) + self.is_group = optimizer.is_group + self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr") + self.decay_flags = optimizer.decay_flags + self.reciprocal_scale = optimizer.reciprocal_scale + self.hyper_map = C.HyperMap() self.lars = P.LARSUpdate(epsilon, coefficient, use_clip) self.cast = P.Cast() - self.parameters = optimizer.parameters - if use_clip is True: - self.learning_rate = optimizer.learning_rate + + if use_clip: + self.is_group_lr = optimizer.is_group_lr self.dynamic_lr = optimizer.dynamic_lr - self.gather = optimizer.gather - self.assignadd = optimizer.assignadd + self.origin_learning_rate = optimizer.learning_rate self.global_step = optimizer.global_step - else: - self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr") - self.reciprocal_scale = optimizer.reciprocal_scale - optimizer.reciprocal_scale = 1.0 - self.is_group = optimizer.is_group + if self.is_group_lr and self.dynamic_lr: + raise ValueError('Grouped dynamic learning rate is currently not supported for the inputs optimizer ' \ + 'of lars.') + if self.is_group: self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay)) + optimizer.weight_decay = tuple(map(lambda x: 0.0, optimizer.weight_decay)) else: self.weight_decay = optimizer.weight_decay / optimizer.loss_scale + optimizer.weight_decay = 0.0 + + optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags)) + optimizer.reciprocal_scale = 1.0 optimizer.exec_weight_decay = False - optimizer.weight_decay = 0.0 - self.decay_flags = optimizer.decay_flags - self.lars_flag = tuple(lars_filter(x) for x in self.parameters) - self.hyper_map = C.HyperMap() + + def _get_lr(self): + """Get the learning rate of current step.""" + lr = self.origin_learning_rate + if self.dynamic_lr: + if self.is_group_lr: + lr = () + for learning_rate in self.origin_learning_rate: + current_dynamic_lr = learning_rate(self.global_step) + lr += (current_dynamic_lr,) + else: + lr = self.origin_learning_rate(self.global_step) + + return lr def construct(self, gradients): params = self.parameters - if self.dynamic_lr: - lr = self.gather(self.learning_rate, self.global_step, 0) - F.control_depend(lr, self.assignadd(self.global_step, 1)) + if self.use_clip: + lr = self._get_lr() else: lr = self.learning_rate + if self.reciprocal_scale != 1.0: gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients) + if self.is_group: - grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay, - gradients, params, self.decay_flags, self.lars_flag) + if self.is_group_lr: + gradients = self.hyper_map(F.partial(_lars_opt, self.lars), lr, self.weight_decay, + gradients, params, self.decay_flags, self.lars_flag) + else: + gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay, + gradients, params, self.decay_flags, self.lars_flag) else: - grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay), - gradients, params, self.decay_flags, self.lars_flag) - success = self.opt(grad_t) + gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay), + gradients, params, self.decay_flags, self.lars_flag) + success = self.opt(gradients) return success diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py index 7df86bc277..80dfd9d38c 100644 --- a/mindspore/nn/optim/lazyadam.py +++ b/mindspore/nn/optim/lazyadam.py @@ -27,14 +27,14 @@ from .optimizer import Optimizer _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt") -@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple", - "Tensor", "Tensor", "Tensor") +@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", + "RowTensor", "Tensor", "Tensor", "Tensor") def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2): """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" success = True success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, - eps, gradient[1], gradient[0])) + eps, gradient.values, gradient.indices)) return success @@ -84,12 +84,11 @@ class LazyAdam(Optimizer): :math:`\epsilon` represents `eps`. Note: - The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. The sparse behavior, to be notice, is not equivalent to the @@ -113,13 +112,14 @@ class LazyAdam(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. Default: 1e-3. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default: 0.9. beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default: @@ -154,8 +154,8 @@ class LazyAdam(Optimizer): >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] >>> opt = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. - >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 1e8ce85570..014cc8f823 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """momentum""" -from mindspore.ops import functional as F, composite as C +from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import _selected_ops from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor @@ -25,11 +25,18 @@ from .optimizer import Optimizer _momentum_opt = C.MultitypeFuncGraph("momentum_opt") -@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment): +@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter): """Apply momentum optimizer to the weight parameter using Tensor.""" success = True - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) + if ps_parameter: + op_shape = P.Shape() + _ps_pull = P.Pull() + _ps_push = P.Push("ApplyMomentum", []) + shapes = (op_shape(learning_rate), op_shape(gradient), op_shape(momentum)) + success = F.depend(success, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight)) + else: + success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) return success @@ -40,15 +47,25 @@ class Momentum(Optimizer): Refer to the paper on the importance of initialization and momentum in deep learning for more details. Note: - The Momentum optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. + .. math:: + v_{t} = v_{t-1} \ast u + gradients + + If use_nesterov is True: + .. math:: + p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr) + + If use_nesterov is Flase: + .. math:: + p_{t} = p_{t-1} - lr \ast v_{t} + + Here: where grad, lr, p, v and u denote the gradients, learning_rate, params, moments, and momentum respectively. + Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", @@ -66,14 +83,13 @@ class Momentum(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a - Tensor but the dims of the Tensor is 0, use fixed learning - rate. Other cases are not supported. It should be equal to - or greater than 0.0. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. momentum (float): Hyperparameter of type float, means momentum for the moving average. It should be at least 0.0. weight_decay (int, float): Weight decay (L2 penalty). It should be equal to or greater than 0.0. Default: 0.0. @@ -100,7 +116,7 @@ class Momentum(Optimizer): >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] - >>> opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0) + >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0) >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. @@ -127,7 +143,9 @@ class Momentum(Optimizer): gradients = self.scale_grad(gradients) lr = self.get_lr() if self.is_group_lr: - success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments) + success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments, + self.ps_parameters) else: - success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments) + success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments, + self.ps_parameters) return success diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 4d6471f424..e499060813 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -20,16 +20,18 @@ import numpy as np import mindspore from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell +from mindspore.nn.layer.container import CellList from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer -from mindspore.common.tensor import Tensor +from mindspore.common.tensor import Tensor, RowTensor import mindspore.common.dtype as mstype from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore import log as logger from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.parallel_utils import ParallelMode +from mindspore import context +from mindspore.nn.learning_rate_schedule import LearningRateSchedule __all__ = ['Optimizer'] @@ -44,25 +46,22 @@ class Optimizer(Cell): This class defines the API to add Ops to train a model. Never use this class directly, but instead instantiate one of its subclasses. - Some optimizers support separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. + Different parameter groups can set different `learning_rate` and `weight_decay`. When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight_decay is positive. For most optimizer, when not separating parameters, the `weight_decay` in the API will + be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. Args: - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. It should be equal to or greater - than 0. If the type of `learning_rate` input is int, it will be - converted to float. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning + rate. When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`, the "params", "lr", "weight_decay" and "order_params" are the keys can be parsed. @@ -91,7 +90,7 @@ class Optimizer(Cell): def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0): super(Optimizer, self).__init__(auto_prefix=False) - if parameters and not isinstance(parameters, list): + if parameters is not None and not isinstance(parameters, list): parameters = list(parameters) if not parameters: @@ -104,32 +103,17 @@ class Optimizer(Cell): loss_scale = float(loss_scale) validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) + self.loss_scale = loss_scale - if isinstance(weight_decay, int): - weight_decay = float(weight_decay) - validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + weight_decay = self._preprocess_weight_decay(weight_decay) + self.dynamic_lr = False + self.assignadd = None + self.global_step = None self.is_group = False self.is_group_lr = False self.is_group_params_ordered = False - self.loss_scale = loss_scale - if isinstance(learning_rate, int): - learning_rate = float(learning_rate) - if isinstance(learning_rate, float): - self.dynamic_lr = False - self.gather = None - self.assignadd = None - self.global_step = None - self.scalar_lr = learning_rate - else: - self.dynamic_lr = True - self.gather = P.GatherV2() - self.assignadd = P.AssignAdd() - self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') - self.scalar_lr = None - - learning_rate = self._get_single_lr(learning_rate) + learning_rate = self._preprocess_single_lr(learning_rate) if isinstance(parameters[0], dict): self.is_group = True self.group_params = [] @@ -137,33 +121,42 @@ class Optimizer(Cell): self.group_weight_decay = [] self._init_group_params(parameters, learning_rate, weight_decay) + # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params + if self.dynamic_lr: + self.assignadd = P.AssignAdd() + self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') + if self.is_group_lr: - self.learning_rate = ParameterTuple(self.group_lr) + if self.dynamic_lr: + self.learning_rate = CellList(self.group_lr) + else: + self.learning_rate = ParameterTuple(self.group_lr) else: - self.learning_rate = Parameter(learning_rate, name="learning_rate") - + self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate') if self.is_group: self.parameters = ParameterTuple(self.group_params) self.weight_decay = tuple(self.group_weight_decay) decay_filter = lambda x: x > 0 self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) + self.exec_weight_decay = any(self.decay_flags) else: self.parameters = ParameterTuple(parameters) self.weight_decay = weight_decay * loss_scale decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name self.decay_flags = tuple(decay_filter(x) for x in self.parameters) + self.exec_weight_decay = self.weight_decay > 0 + ps_filter = lambda x: x.is_param_ps + self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) self.reciprocal_scale = 1.0 / loss_scale - self.exec_weight_decay = any(self.decay_flags) self.param_length = len(self.parameters) self.map_ = C.Map() - use_parallel = auto_parallel_context().get_enable_parallel_optimizer() + use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer") self.use_parallel = use_parallel if use_parallel: - if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: + if self.cls_name not in ["Lamb", "AdamWeightDecay"]: raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) - if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL, - ParallelMode.AUTO_PARALLEL]: + if _get_parallel_mode() != ParallelMode.DATA_PARALLEL: raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format (_get_parallel_mode())) self.dev_num = _get_device_num() @@ -175,6 +168,7 @@ class Optimizer(Cell): self.param_names = [] for param in self.parameters: self.param_names.append(param.name) + else: self.optim_filter = (True,) * self.param_length @@ -191,13 +185,12 @@ class Optimizer(Cell): Returns: tuple[Tensor], The gradients after weight decay. """ - params = self.parameters - if self.is_group: - if self.exec_weight_decay: + if self.exec_weight_decay: + params = self.parameters + if self.is_group: gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags, params, gradients) - else: - if self.weight_decay > 0: + else: gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags, params, gradients) @@ -223,24 +216,53 @@ class Optimizer(Cell): return gradients - def _get_single_lr(self, learning_rate): - """Get learning rate in Tensor type.""" - if isinstance(learning_rate, float): + def _preprocess_weight_decay(self, weight_decay): + """Check weight decay, and convert int to float.""" + if isinstance(weight_decay, (float, int)): + weight_decay = float(weight_decay) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + return weight_decay + raise TypeError("Weight decay should be int or float.") + + def _preprocess_single_lr(self, learning_rate): + """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.""" + if isinstance(learning_rate, (float, int)): + learning_rate = float(learning_rate) validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) - lr = Tensor(learning_rate, mstype.float32) - elif isinstance(learning_rate, Iterable): - lr = Tensor(np.array(list(learning_rate)).astype(np.float32)) - elif isinstance(learning_rate, Tensor): + return learning_rate + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: + return learning_rate + + self.dynamic_lr = True + if isinstance(learning_rate, Iterable): + return Tensor(np.array(list(learning_rate)).astype(np.float32)) + if isinstance(learning_rate, Tensor): if learning_rate.dim() > 1: - raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," + raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1," f"but got {learning_rate.dim()}.") if learning_rate.dim() == 1 and learning_rate.size() < 2: - logger.warning("If want to use the dynamic learning rate, please make sure that the number " - "of elements in the list, tuple or tensor passed is greater than 1.") - lr = learning_rate - else: - raise TypeError("Learning rate should be float, Tensor or Iterable.") - return lr + logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number" + "of elements in the tensor passed is greater than 1.") + return learning_rate + if isinstance(learning_rate, LearningRateSchedule): + return learning_rate + raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.") + + def _build_single_lr(self, learning_rate, name): + """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule.""" + if isinstance(learning_rate, float): + learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name) + if self.is_group_lr and self.dynamic_lr: + learning_rate = _ConvertToCell(learning_rate) + return learning_rate + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: + learning_rate = Parameter(learning_rate, name) + if self.is_group_lr and self.dynamic_lr: + learning_rate = _ConvertToCell(learning_rate) + return learning_rate + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1: + return _IteratorLearningRate(learning_rate, name) + return learning_rate def _check_group_params(self, parameters): """Check group params.""" @@ -268,13 +290,12 @@ class Optimizer(Cell): def _parse_group_params(self, parameters, learning_rate): """Parse group params.""" self._check_group_params(parameters) - if self.dynamic_lr: - dynamic_lr_length = learning_rate.size() + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1: + tensor_lr_length = learning_rate.size() else: - dynamic_lr_length = 0 + tensor_lr_length = 0 for group_param in parameters: - lr_length = dynamic_lr_length if 'order_params' in group_param.keys(): if len(group_param.keys()) > 1: raise ValueError("The order params dict in group parameters should " @@ -286,53 +307,38 @@ class Optimizer(Cell): if 'lr' in group_param.keys(): self.is_group_lr = True - self._get_single_lr(group_param['lr']) - if isinstance(group_param['lr'], Iterable): - lr_length = len(group_param['lr']) - self.dynamic_lr = True - elif isinstance(group_param['lr'], Tensor): - lr_length = group_param['lr'].size() - self.dynamic_lr = True + group_lr = self._preprocess_single_lr(group_param['lr']) - if dynamic_lr_length not in (lr_length, 0): - raise ValueError("The dynamic learning rate in group should be the same size.") - - dynamic_lr_length = lr_length - self.dynamic_lr_length = dynamic_lr_length + if isinstance(group_lr, Tensor) and group_lr.dim() == 1: + group_lr_length = group_lr.size() + if tensor_lr_length == 0: + tensor_lr_length = group_lr_length + elif group_lr_length != tensor_lr_length: + raise ValueError("The Tensor type dynamic learning rate in group should be the same size.") def _init_group_params(self, parameters, learning_rate, weight_decay): """Init learning rate or weight decay in group params.""" - origin_dynamic_lr = self.dynamic_lr self._parse_group_params(parameters, learning_rate) - if self.dynamic_lr and not origin_dynamic_lr: - self.gather = P.GatherV2() - self.assignadd = P.AssignAdd() - self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') + default_lr = self._build_single_lr(learning_rate, 'learning_rate') params_store = [] - for group_param in parameters: + for group_num, group_param in enumerate(parameters): if 'order_params' in group_param.keys(): ordered_parameters = group_param['order_params'] continue self.group_params += group_param['params'] + if 'lr' in group_param.keys(): - params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor)) - if self.dynamic_lr and not params_dynamic_lr: - lr = Tensor(np.array([group_param['lr']] * self.dynamic_lr_length).astype(np.float32)) - else: - lr = self._get_single_lr(group_param['lr']) + lr_param_name = 'learning_rate_group_' + str(group_num) + lr = self._preprocess_single_lr(group_param['lr']) + lr = self._build_single_lr(lr, lr_param_name) else: - if self.dynamic_lr and not origin_dynamic_lr: - lr = Tensor(np.array([self.scalar_lr] * self.dynamic_lr_length).astype(np.float32)) - else: - lr = learning_rate + lr = default_lr if 'weight_decay' in group_param.keys(): - validator.check_float_legal_value('weight_decay', group_param['weight_decay'], None) - validator.check_number_range('weight_decay', group_param['weight_decay'], 0.0, float("inf"), - Rel.INC_LEFT, self.cls_name) - weight_decay_ = group_param['weight_decay'] * self.loss_scale + cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay']) + weight_decay_ = cur_weight_decay * self.loss_scale else: weight_decay_ = weight_decay * self.loss_scale @@ -346,7 +352,7 @@ class Optimizer(Cell): raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") params_store.append(param.name) - self.group_lr.append(Parameter(lr, name="lr_" + param.name)) + self.group_lr.append(lr) self.group_weight_decay.append(weight_decay_) if self.is_group_params_ordered: @@ -382,19 +388,17 @@ class Optimizer(Cell): Returns: float, the learning rate of current step. """ - if self.is_group_lr: - lr = self.learning_rate - if self.dynamic_lr: + lr = self.learning_rate + if self.dynamic_lr: + if self.is_group_lr: lr = () - for i in range(self.param_length): - current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) + for learning_rate in self.learning_rate: + current_dynamic_lr = learning_rate(self.global_step) lr += (current_dynamic_lr,) - F.control_depend(lr, self.assignadd(self.global_step, 1)) - else: - lr = self.learning_rate - if self.dynamic_lr: - lr = self.gather(self.learning_rate, self.global_step, 0) - F.control_depend(lr, self.assignadd(self.global_step, 1)) + else: + lr = self.learning_rate(self.global_step) + + F.control_depend(lr, self.assignadd(self.global_step, 1)) return lr def get_lr_parameter(self, param): @@ -407,29 +411,32 @@ class Optimizer(Cell): Returns: Parameter, single `Parameter` or `list[Parameter]` according to the input type. """ - if not isinstance(param, (Parameter, list)): - raise TypeError(f"The parameter only support 'Parameter' or 'list' type.") + def get_lr_value(learning_rate): + if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)): + return learning_rate.learning_rate - if isinstance(param, list): - lr = [] - for p in param: - validator.check_value_type("parameter", p, [Parameter], self.cls_name) - if p not in self.parameters: - raise ValueError(f"The parameter {p.name} is not in optimizer.") - if self.is_group_lr: - index = self.parameters.index(p) - lr.append(self.learning_rate[index]) - else: - lr.append(self.learning_rate) + return learning_rate + + if isinstance(param, Parameter): + param_list = [param] + elif isinstance(param, list): + param_list = param else: - if param not in self.parameters: - raise ValueError(f"The parameter {param.name} is not in optimizer.") + raise TypeError(f"The parameter only support 'Parameter' or 'list' type.") + + lr = [] + ids = [id(p) for p in self.parameters] + for p in param_list: + validator.check_value_type("parameter", p, [Parameter], self.cls_name) + if id(p) not in ids: + raise ValueError(f"The parameter {p.name} is not in optimizer.") if self.is_group_lr: - index = self.parameters.index(param) - lr = self.learning_rate[index] + index = ids.index(id(p)) + lr.append(get_lr_value(self.learning_rate[index])) else: - lr = self.learning_rate - return lr + lr.append(get_lr_value(self.learning_rate)) + + return lr if isinstance(param, list) else lr[0] def _get_parameter_group_id(self): """ @@ -460,7 +467,7 @@ class Optimizer(Cell): param_group.append(F.make_tuple()) key_group.append(F.make_tuple()) for i in range(self.param_length): - param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],) + param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],) key = P.MakeRefKey(self.param_names[i])() key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) new_param_group = [] @@ -470,9 +477,9 @@ class Optimizer(Cell): new_param_group.append(next_params) for i in range(F.tuple_len(next_params)): F.assign(key_group[root][i], next_params[i]) - status = True + status = F.control_depend(optim_result, new_param_group[0][0]) for i in range(self.dev_num - 1): - status = F.control_depend(new_param_group[i][0], new_param_group[i+1]) + status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status) return status @@ -486,12 +493,14 @@ op_gather = P.GatherV2() _apply_decay = C.MultitypeFuncGraph("apply_decay") -@_apply_decay.register("Number", "Bool", "Tensor", "Tuple") +@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor") def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): """Get grad with weight_decay.""" if if_apply: - weight = op_gather(weight, gradient[0], 0) - return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2] + indices = gradient.indices + values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values)) + shape = gradient.dense_shape + return RowTensor(indices, values, shape) return gradient @@ -514,9 +523,39 @@ def tensor_grad_scale(scale, grad): return grad * scale -@_grad_scale.register("Number", "Tuple") +@_grad_scale.register("Number", "RowTensor") def tensor_grad_scale_with_sparse(scale, grad): """Get grad with scale.""" if scale == 1.0: return grad - return grad[0], grad[1] * scale, grad[2] + return RowTensor(grad.indices, grad.values * scale, grad.dense_shape) + + +class _ConvertToCell(LearningRateSchedule): + """Inner api, convert learning rate of scalar to LearningRateSchedule.""" + def __init__(self, learning_rate): + super(_ConvertToCell, self).__init__() + if not isinstance(learning_rate, Parameter): + raise TypeError('Learning rate must be Parameter.') + self.learning_rate = learning_rate + + def construct(self, global_step): + return self.learning_rate + 1.0 - 1.0 + + +class _IteratorLearningRate(LearningRateSchedule): + """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule.""" + def __init__(self, learning_rate, name): + super(_IteratorLearningRate, self).__init__() + if isinstance(learning_rate, Tensor): + if learning_rate.dim() != 1: + raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1," + f"but got {learning_rate.dim()}.") + else: + raise TypeError("Learning rate should be Tensor.") + + self.learning_rate = Parameter(learning_rate, name) + self.gather = P.GatherV2() + + def construct(self, global_step): + return self.gather(self.learning_rate, global_step, 0) diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py index eba2b890df..c21874c8fd 100644 --- a/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/nn/optim/proximal_ada_grad.py @@ -22,16 +22,17 @@ from .optimizer import Optimizer _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") -@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tuple", "Tensor", "Tensor") -def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum): +@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", + "Tensor") +def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): """Apply sparse proximal_ada_grad optimizer to the weight parameter.""" success = True - success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient[1], gradient[0])) + success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices)) return success @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum): +def _tensor_run_opt(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): """Apply proximal_ada_grad optimizer to the weight parameter.""" success = True success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient)) @@ -58,20 +59,47 @@ class ProximalAdagrad(Optimizer): `_. Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1. - learning_rate (float): The learning rate value, must be greater than or equal to zero. Default: 0.001. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 0.001. l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0. l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0. use_locking (bool): If True use locks for update operation. Default: False. loss_scale (float): Value for the loss scale. It should be greater than 0.0. Default: 1.0. - wegith_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0. + weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0. Inputs: - **grads** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is as same as the `params` @@ -82,21 +110,31 @@ class ProximalAdagrad(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.ProximalAdagrad(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> opt = nn.ProximalAdagrad(net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0, use_locking=False, loss_scale=1.0, weight_decay=0.0): super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") _check_param_value(accum, l1, l2, use_locking, self.cls_name) self.accum = self.parameters.clone(prefix="accum", init=accum) self.l1 = Tensor(l1, mstype.float32) self.l2 = Tensor(l2, mstype.float32) - self.weight_decay = weight_decay self.hyper_map = C.HyperMap() self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking) @@ -106,7 +144,11 @@ class ProximalAdagrad(Optimizer): accum = self.accum grads = self.decay_weight(grads) grads = self.scale_grad(grads) - lr = self.learning_rate - success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2), - grads, params, accum) + lr = self.get_lr() + if self.is_group_lr: + success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr, + grads, params, accum) + else: + success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr), + grads, params, accum) return success diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index c4d3347038..fc5ebc8df9 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -44,12 +44,9 @@ class RMSProp(Optimizer): Implements Root Mean Squared Propagation (RMSProp) algorithm. Note: - The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. @@ -109,13 +106,14 @@ class RMSProp(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. Default: 0.1. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 0.1. decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9. momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or greater than 0. Default: 0.0. @@ -143,7 +141,7 @@ class RMSProp(Optimizer): >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] - >>> opt = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0) + >>> optim = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0) >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 382f095627..1d79cfad42 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -40,14 +40,26 @@ class SGD(Optimizer): momentum in deep learning `_. Note: - The SGD optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + + .. math:: + v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening) + + If nesterov is True: + .. math:: + p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1}) - To improve parameter groups performance, the customized order of parameters can be supported. + If nesterov is Flase: + .. math:: + p_{t+1} = p_{t} - lr \ast v_{t+1} + + To be noticed, for the first step, v_{t+1} = gradient + + Here : where p, v and u denote the parameters, accum, and momentum respectively. Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, @@ -66,18 +78,19 @@ class SGD(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. It should be equal to or - greater than 0. Default: 0.1. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 0.1. momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0. dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0. weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0. - nesterov (bool): Enables the Nesterov momentum. Default: False. + nesterov (bool): Enables the Nesterov momentum. If use nesterov, momentum must be positive, + and dampening must equal to 0.0. Default: False. loss_scale (float): A floating point value for the loss scale, which should be larger than 0.0. Default: 1.0. @@ -101,7 +114,7 @@ class SGD(Optimizer): >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] - >>> opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0) + >>> optim = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0) >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. @@ -135,6 +148,10 @@ class SGD(Optimizer): weight_decay = float(weight_decay) validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) + + if nesterov and (momentum <= 0.0 or dampening != 0.0): + raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," + "but got momentum {}, dampening {}".format(momentum, dampening)) self.nesterov = nesterov self.opt = P.SGD(dampening, weight_decay, nesterov) diff --git a/mindspore/nn/probability/__init__.py b/mindspore/nn/probability/__init__.py new file mode 100644 index 0000000000..5bc8a54c40 --- /dev/null +++ b/mindspore/nn/probability/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Probability. + +The high-level components used to construct the probabilistic network. +""" + +from . import bijector +from . import distribution diff --git a/mindspore/nn/probability/bijector/__init__.py b/mindspore/nn/probability/bijector/__init__.py new file mode 100644 index 0000000000..3108742aea --- /dev/null +++ b/mindspore/nn/probability/bijector/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Bijector. + +The high-level components(Bijectors) used to construct the probabilistic network. +""" + +from .bijector import Bijector +from .power_transform import PowerTransform +from .exp import Exp + +__all__ = ['Bijector', + 'PowerTransform', + 'Exp'] diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py new file mode 100644 index 0000000000..22777231f6 --- /dev/null +++ b/mindspore/nn/probability/bijector/bijector.py @@ -0,0 +1,130 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bijector""" +from mindspore.nn.cell import Cell +from ..distribution import Distribution +from ..distribution import TransformedDistribution + +class Bijector(Cell): + """ + Bijecotr class. + + Args: + is_constant_jacobian (bool): if the bijector has constant derivative. Default: False. + is_injective (bool): if the bijector is an one-to-one mapping. Default: True. + name (str): name of the bijector. Default: None. + dtype (mstype): type of the distribution the bijector can operate on. Default: None. + param (dict): parameters used to initialize the bijector. Default: None. + """ + def __init__(self, + is_constant_jacobian=False, + is_injective=True, + name=None, + dtype=None, + param=None): + + """ + Constructor of bijector class. + """ + super(Bijector, self).__init__() + self._name = name + self._dtype = dtype + self._parameters = {} + # parsing parameters + for k in param.keys(): + if not(k == 'self' or k.startswith('_')): + self._parameters[k] = param[k] + self._is_constant_jacobian = is_constant_jacobian + self._is_injective = is_injective + + @property + def name(self): + return self._name + + @property + def dtype(self): + return self._dtype + + @property + def parameters(self): + return self._parameters + + @property + def is_constant_jacobian(self): + return self._is_constant_jacobian + + @property + def is_injective(self): + return self._is_injective + + def forward(self, *args): + """ + Forward transformation: transform the input value to another distribution. + """ + return self._forward(*args) + + def inverse(self, *args): + """ + Inverse transformation: transform the input value back to the original distribution. + """ + return self._inverse(*args) + + def forward_log_jacobian(self, *args): + """ + Logarithm of the derivative of forward transformation. + """ + return self._forward_log_jacobian(*args) + + def inverse_log_jacobian(self, *args): + """ + Logarithm of the derivative of forward transformation. + """ + return self._inverse_log_jacobian(*args) + + def __call__(self, *args): + """ + Call Bijector directly. + This __call__ may go into two directions: + If args[0] is a distribution instance, the call will generate a new distribution derived from + the input distribution. + Otherwise, input[0] should be the name of a bijector function, e.g. "forward", then this call will + go in the construct and invoke the correstpoding bijector function. + + Args: + *args: args[0] shall be either a distribution or the name of a bijector function. + """ + if isinstance(args[0], Distribution): + return TransformedDistribution(self, args[0]) + return super(Bijector, self).__call__(*args) + + def construct(self, name, *args): + """ + Override construct in Cell. + + Args: + *inputs: inputs[0] is always the name of a function. + + Notes: + Always raise RuntimeError as Distribution should not be called directly. + """ + if name == 'forward': + return self.forward(*args) + if name == 'inverse': + return self.inverse(*args) + if name == 'forward_log_jacobian': + return self.forward_log_jacobian(*args) + if name == 'inverse_log_jacobian': + return self.inverse_log_jacobian(*args) + return None diff --git a/mindspore/nn/probability/bijector/exp.py b/mindspore/nn/probability/bijector/exp.py new file mode 100644 index 0000000000..0f79a1abf2 --- /dev/null +++ b/mindspore/nn/probability/bijector/exp.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Power Bijector""" +from .power_transform import PowerTransform + +class Exp(PowerTransform): + r""" + Exponential Bijector. + This Bijector performs the operation: Y = exp(x). + + Examples: + >>> # To initialize a Exp bijector + >>> import mindspore.nn.probability.bijector as msb + >>> n = msb.Exp() + >>> + >>> # To use Exp distribution in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.e1 = msb.Exp() + >>> + >>> def construct(self, value): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'forward' with the name of the function + >>> ans1 = self.e1.forward(value) + >>> ans2 = self.e1.backward(value) + """ + def __init__(self, + name='Exp'): + param = dict(locals()) + super(Exp, self).__init__(name=name, param=param) diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py new file mode 100644 index 0000000000..456f635818 --- /dev/null +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Power Bijector""" +from mindspore.ops import operations as P +from mindspore._checkparam import Validator as validator +from .bijector import Bijector + +class PowerTransform(Bijector): + r""" + Power Bijector. + This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c is power. + + The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`. + + This bijector is equivalent to the `Exp` bijector when `c=0` + + Args: + power (int or float): scale factor. Default: 0. + + Examples: + >>> # To initialize a PowerTransform bijector of power 0.5 + >>> import mindspore.nn.probability.bijector as msb + >>> n = msb.PowerTransform(0.5) + >>> + >>> # To use PowerTransform distribution in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.p1 = msb.PowerTransform(0.5) + >>> + >>> def construct(self, value): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'forward' with the name of the function + >>> ans = self.p1.forward(, value) + """ + def __init__(self, + power=0, + name='PowerTransform', + param=None): + param = dict(locals()) if param is None else param + super(PowerTransform, self).__init__(name=name, param=param) + validator.check_value_type('power', power, [int, float], self.name) + self._power = power + self.pow = P.Pow() + self.exp = P.Exp() + self.log = P.Log() + self.log1p = self._log1p_by_step + self.expm1 = self._expm1_by_step + + def _log1p_by_step(self, x): + """ + Log1p ops on GPU device or when device_target == GPU. + """ + return self.log(x + 1.0) + + def _expm1_by_step(self, x): + """ + Expm1 ops on GPU device or when device_target == GPU. + """ + return self.exp(x) - 1.0 + + @property + def power(self): + return self._power + + def extend_repr(self): + str_info = f'power = {self.power}' + return str_info + + def shape_mapping(self, shape): + return shape + + def _forward(self, x): + if self.power == 0: + return self.exp(x) + return self.exp(self.log1p(x * self.power) / self.power) + + def _inverse(self, y): + if self.power == 0: + return self.log(y) + return self.expm1(self.log(y) * self.power) / self.power + + def _forward_log_jacobian(self, x): + r""" + .. math: + if c == 0: + f(x) = e^x + f'(x) = e^x + \log(f'(x)) = \log(e^x) = x + else: + f(x) = e^\frac{\log(xc + 1)}{c} + f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} + \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) + """ + if self.power == 0: + return x + return (1. / self.power - 1) * self.log1p(x * self.power) + + def _inverse_log_jacobian(self, y): + r""" + .. math: + if c == 0: + f(x) = \log(x) + f'(x) = \frac{1}{x} + \log(f'(x)) = \log(\frac{1}{x}) = -\log(x) + else: + f(x) = \frac{e^\log(y)*c + 1}{c} + f'(x) = \frac{e^c\log(y)}{y} + \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) + """ + return (self.power - 1) * self.log(y) diff --git a/mindspore/nn/probability/distribution/__init__.py b/mindspore/nn/probability/distribution/__init__.py new file mode 100644 index 0000000000..ea6b743e29 --- /dev/null +++ b/mindspore/nn/probability/distribution/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Distribution. + +The high-level components(Distributions) used to construct the probabilistic network. +""" + +from .distribution import Distribution +from .transformed_distribution import TransformedDistribution +from .normal import Normal +from .bernoulli import Bernoulli +from .exponential import Exponential +from .uniform import Uniform +from .geometric import Geometric + +__all__ = ['Distribution', + 'TransformedDistribution', + 'Normal', + 'Bernoulli', + 'Exponential', + 'Uniform', + 'Geometric',] diff --git a/mindspore/nn/probability/distribution/_utils/__init__.py b/mindspore/nn/probability/distribution/_utils/__init__.py new file mode 100644 index 0000000000..f9cd3d3c2e --- /dev/null +++ b/mindspore/nn/probability/distribution/_utils/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Distribution operation utility functions. +""" +from .utils import * + +__all__ = ['convert_to_batch', + 'cast_to_tensor', + 'check_greater', + 'check_greater_equal_zero', + 'check_greater_zero', + 'calc_broadcast_shape_from_param', + 'check_scalar_from_param', + 'check_prob'] diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py new file mode 100644 index 0000000000..aeccfc2b8f --- /dev/null +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -0,0 +1,253 @@ + +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utitly functions to help distribution class.""" +import numpy as np +from mindspore.ops import _utils as utils +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import composite as C +import mindspore.nn as nn + +def cast_to_tensor(t, dtype=mstype.float32): + """ + Cast an user input value into a Tensor of dtype. + If the input t is of type Parameter, t is directly returned as a Parameter. + + Args: + t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. + dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32. + + Raises: + RuntimeError: if t cannot be cast to Tensor. + + Returns: + Tensor. + """ + if isinstance(t, Parameter): + return t + if isinstance(t, Tensor): + #check if the Tensor in shape of Tensor(4) + if t.dim() == 0: + value = t.asnumpy() + return Tensor([t], dtype=dtype) + #convert the type of tensor to dtype + t.set_dtype(dtype) + return t + if isinstance(t, (list, np.ndarray)): + return Tensor(t, dtype=dtype) + if np.isscalar(t): + return Tensor([t], dtype=dtype) + raise RuntimeError("Input type is not supported.") + +def convert_to_batch(t, batch_shape, dtype): + """ + Convert a Tensor to a given batch shape. + + Args: + t (Tensor, Parameter): Tensor to be converted. + batch_shape (tuple): desired batch shape. + dtype (mindspore.dtype): desired dtype. + + Raises: + RuntimeError: if the converison cannot be done. + + Returns: + Tensor, with shape of batch_shape. + """ + if isinstance(t, Parameter): + return t + if isinstance(t, Tensor): + return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=dtype) + return Tensor(np.broadcast_to(t, batch_shape), dtype=dtype) + +def check_scalar_from_param(params): + """ + Check if params are all scalars. + + Args: + params (dict): parameters used to initialize distribution. + + Notes: String parameters are excluded. + """ + for value in params.values(): + if isinstance(value, Parameter): + return False + if isinstance(value, (str, type(params['dtype']))): + continue + elif np.isscalar(value): + continue + else: + return False + return True + + +def calc_broadcast_shape_from_param(params): + """ + Calculate the broadcast shape from params. + + Args: + params (dict): parameters used to initialize distribution. + + Returns: + tuple. + """ + broadcast_shape = [] + for value in params.values(): + if isinstance(value, (str, type(params['dtype']))): + continue + if value is None: + return None + if isinstance(value, Parameter): + value_t = value.default_input + else: + value_t = cast_to_tensor(value, params['dtype']) + broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) + return tuple(broadcast_shape) + +def check_greater_equal_zero(value, name): + """ + Check if the given Tensor is greater zero. + + Args: + value (Tensor, Parameter): value to be checked. + name (str) : name of the value. + + Raises: + ValueError: if the input value is less than zero. + + """ + if isinstance(value, Parameter): + if not isinstance(value.default_input, Tensor): + return + value = value.default_input + comp = np.less(value.asnumpy(), np.zeros(value.shape)) + if comp.any(): + raise ValueError(f'{name} should be greater than ot equal to zero.') + +def check_greater_zero(value, name): + """ + Check if the given Tensor is strictly greater than zero. + + Args: + value (Tensor, Parameter): value to be checked. + name (str) : name of the value. + + Raises: + ValueError: if the input value is less than or equal to zero. + + """ + if isinstance(value, Parameter): + if isinstance(value.default_input, MetaTensor): + return + value = value.default_input + comp = np.less(np.zeros(value.shape), value.asnumpy()) + if not comp.all(): + raise ValueError(f'{name} should be greater than zero.') + +def check_greater(a, b, name_a, name_b): + """ + Check if Tensor b is strictly greater than Tensor a. + + Args: + a (Tensor, Parameter): input tensor a. + b (Tensor, Parameter): input tensor b. + name_a (str): name of Tensor_a. + name_b (str): name of Tensor_b. + + Raises: + ValueError: if b is less than or equal to a + """ + if isinstance(a, Parameter) or isinstance(b, Parameter): + return + comp = np.less(a.asnumpy(), b.asnumpy()) + if not comp.all(): + raise ValueError(f'{name_a} should be less than {name_b}') + + +def check_prob(p): + """ + Check if p is a proper probability, i.e. 0 <= p <=1. + + Args: + p (Tensor, Parameter): value to be checked. + + Raises: + ValueError: if p is not a proper probability. + """ + if isinstance(p, Parameter): + if not isinstance(p.default_input, Tensor): + return + p = p.default_input + comp = np.less(p.asnumpy(), np.zeros(p.shape)) + if comp.any(): + raise ValueError('Probabilities should be greater than or equal to zero') + comp = np.greater(p.asnumpy(), np.ones(p.shape)) + if comp.any(): + raise ValueError('Probabilities should be less than or equal to one') + + +def logits_to_probs(logits, is_binary=False): + """ + converts logits into probabilities. + Args: + logits (Tensor) + is_binary (bool) + """ + if is_binary: + return nn.sigmoid()(logits) + return nn.softmax(axis=-1)(logits) + + +def clamp_probs(probs): + """ + clamp probs boundary + Args: + probs (Tensor) + """ + eps = P.Eps()(probs) + return C.clip_by_value(probs, eps, 1-eps) + + +def probs_to_logits(probs, is_binary=False): + """ + converts probabilities into logits. + Args: + probs (Tensor) + is_binary (bool) + """ + ps_clamped = clamp_probs(probs) + if is_binary: + return P.Log()(ps_clamped) - P.Log()(1-ps_clamped) + return P.Log()(ps_clamped) + +def check_tensor_type(name, inputs, valid_type): + """ + Check if inputs is proper. + + Args: + inputs: Tensor to be checked. + name: inputs name + + Raises: + ValueError: if inputs is not a proper Tensor. + """ + if not isinstance(inputs, Tensor): + raise TypeError(f"{name} should be a Tensor") + inputs = P.DType()(inputs) + if inputs not in valid_type: + raise TypeError(f"{name} dtype is invalid") diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py new file mode 100644 index 0000000000..0aaeabf9a2 --- /dev/null +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -0,0 +1,261 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bernoulli Distribution""" +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from .distribution import Distribution +from ._utils.utils import cast_to_tensor, check_prob + +class Bernoulli(Distribution): + """ + Bernoulli Distribution. + + Args: + probs (float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. + name (str): name of the distribution. Default: Bernoulli. + + Note: + probs should be proper probabilities (0 <= p <= 1). + Dist_spec_args is probs. + + Examples: + >>> # To initialize a Bernoulli distribution of prob 0.5 + >>> import mindspore.nn.probability.distribution as msd + >>> b = msd.Bernoulli(0.5, dtype=mstype.int32) + >>> + >>> # The following creates two independent Bernoulli distributions + >>> b = msd.Bernoulli([0.5, 0.5], dtype=mstype.int32) + >>> + >>> # A Bernoulli distribution can be initilized without arguments + >>> # In this case, probs must be passed in through args during function calls. + >>> b = msd.Bernoulli(dtype=mstype.int32) + >>> + >>> # To use Bernoulli in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.b1 = msd.Bernoulli(0.5, dtype=mstype.int32) + >>> self.b2 = msd.Bernoulli(dtype=mstype.int32) + >>> + >>> # All the following calls in construct are valid + >>> def construct(self, value, probs_b, probs_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.b1.prob(value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.b1.prob(value, probs_b) + >>> + >>> # probs must be passed in during function calls + >>> ans = self.b2.prob(value, probs_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return 0.5 + >>> ans = self.b1.mean() + >>> # Will return probs_b + >>> ans = self.b1.mean(probs_b) + >>> + >>> # probs must be passed in during function calls + >>> ans = self.b2.mean(probs_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.b1.kl_loss('Bernoulli', probs_b) + >>> ans = self.b1.kl_loss('Bernoulli', probs_b, probs_a) + >>> + >>> # Additional probs_a must be passed in through + >>> ans = self.b2.kl_loss('Bernoulli', probs_b, probs_a) + >>> + >>> # Sample + >>> ans = self.b1.sample() + >>> ans = self.b1.sample((2,3)) + >>> ans = self.b1.sample((2,3), probs_b) + >>> ans = self.b2.sample((2,3), probs_a) + """ + + def __init__(self, + probs=None, + seed=0, + dtype=mstype.int32, + name="Bernoulli"): + """ + Constructor of Bernoulli distribution. + """ + param = dict(locals()) + super(Bernoulli, self).__init__(dtype, name, param) + if probs is not None: + self._probs = cast_to_tensor(probs, dtype=mstype.float32) + check_prob(self.probs) + else: + self._probs = probs + self.seed = seed + + # ops needed for the class + self.cast = P.Cast() + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.erf = P.Erf() + self.fill = P.Fill() + self.log = P.Log() + self.less = P.Less() + self.shape = P.Shape() + self.select = P.Select() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'probs = {self.probs}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def probs(self): + """ + Returns the probability for the outcome is 1. + """ + return self._probs + + def _mean(self, probs1=None): + r""" + .. math:: + MEAN(B) = probs1 + """ + return self.probs if probs1 is None else probs1 + + def _mode(self, probs1=None): + r""" + .. math:: + MODE(B) = 1 if probs1 > 0.5 else = 0 + """ + probs1 = self.probs if probs1 is None else probs1 + prob_type = self.dtypeop(probs1) + zeros = self.fill(prob_type, self.shape(probs1), 0.0) + ones = self.fill(prob_type, self.shape(probs1), 1.0) + comp = self.less(0.5, probs1) + return self.select(comp, ones, zeros) + + def _var(self, probs1=None): + r""" + .. math:: + VAR(B) = probs1 * probs0 + """ + probs1 = self.probs if probs1 is None else probs1 + probs0 = 1.0 - probs1 + return probs0 * probs1 + + def _entropy(self, probs=None): + r""" + .. math:: + H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) + """ + probs1 = self.probs if probs is None else probs + probs0 = 1 - probs1 + return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) + + def _cross_entropy(self, dist, probs1_b, probs1_a=None): + """ + Evaluate cross_entropy between Bernoulli distributions. + + Args: + dist (str): type of the distributions. Should be "Bernoulli" in this case. + probs1_b (Tensor): probs1 of distribution b. + probs1_a (Tensor): probs1 of distribution a. Default: self.probs. + """ + if dist == 'Bernoulli': + return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) + return None + + def _prob(self, value, probs=None): + r""" + pmf of Bernoulli distribution. + + Args: + value (Tensor): a Tensor composed of only zeros and ones. + probs (Tensor): probability of outcome is 1. Default: self.probs. + + .. math:: + pmf(k) = probs1 if k = 1; + pmf(k) = probs0 if k = 0; + """ + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + return (probs1 * value) + (probs0 * (1.0 - value)) + + def _cdf(self, value, probs=None): + r""" + cdf of Bernoulli distribution. + + Args: + value (Tensor): value to be evaluated. + probs (Tensor): probability of outcome is 1. Default: self.probs. + + .. math:: + cdf(k) = 0 if k < 0; + cdf(k) = probs0 if 0 <= k <1; + cdf(k) = 1 if k >=1; + """ + probs1 = self.probs if probs is None else probs + prob_type = self.dtypeop(probs1) + value = value * self.fill(prob_type, self.shape(probs1), 1.0) + probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) + comp_zero = self.less(value, 0.0) + comp_one = self.less(value, 1.0) + zeros = self.fill(prob_type, self.shape(value), 0.0) + ones = self.fill(prob_type, self.shape(value), 1.0) + less_than_zero = self.select(comp_zero, zeros, probs0) + return self.select(comp_one, less_than_zero, ones) + + def _kl_loss(self, dist, probs1_b, probs1_a=None): + r""" + Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). + + Args: + dist (str): type of the distributions. Should be "Bernoulli" in this case. + probs1_b (Tensor): probs1 of distribution b. + probs1_a (Tensor): probs1 of distribution a. Default: self.probs. + + .. math:: + KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + + probs0_a * \log(\fract{probs0_a}{probs0_b}) + """ + if dist == 'Bernoulli': + probs1_a = self.probs if probs1_a is None else probs1_a + probs0_a = 1.0 - probs1_a + probs0_b = 1.0 - probs1_b + return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) + return None + + def _sample(self, shape=(), probs=None): + """ + Sampling. + + Args: + shape (tuple): shape of the sample. Default: (). + probs (Tensor): probs1 of the samples. Default: self.probs. + + Returns: + Tensor, shape is shape + batch_shape. + """ + probs1 = self.probs if probs is None else probs + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) + sample = self.less(sample_uniform, probs1) + sample = self.cast(sample, self.dtype) + return sample diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py new file mode 100644 index 0000000000..dd3d39f0d7 --- /dev/null +++ b/mindspore/nn/probability/distribution/distribution.py @@ -0,0 +1,463 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""basic""" +from mindspore.nn.cell import Cell +from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param + +class Distribution(Cell): + """ + Base class for all mathematical distributions. + + Args: + dtype (mindspore.dtype): type of the distribution. + name (str): name of the distribution. + param (dict): parameters used to initialize the distribution. + + Note: + Derived class should override operations such as ,_mean, _prob, + and _log_prob. Arguments should be passed in through *args. + + Dist_spec_args are unique for each type of distribution. For example, mean and sd + are the dist_spec_args for a Normal distribution. + + For all functions, passing in dist_spec_args, are optional. + Passing in the additional dist_spec_args will make the result to be evaluated with + new distribution specified by the dist_spec_args. But it won't change the + original distribuion. + """ + def __init__(self, + dtype, + name, + param): + + """ + Constructor of distribution class. + """ + super(Distribution, self).__init__() + self._name = name + self._dtype = dtype + self._parameters = {} + # parsing parameters + for k in param.keys(): + if not(k == 'self' or k.startswith('_')): + self._parameters[k] = param[k] + # some attributes + self._broadcast_shape = calc_broadcast_shape_from_param( + self.parameters) + self._is_scalar_batch = check_scalar_from_param(self.parameters) + + # set the function to call according to the derived class's attributes + self._set_prob() + self._set_log_prob() + self._set_sd() + self._set_var() + self._set_cdf() + self._set_survival() + self._set_log_cdf() + self._set_log_survival() + self._set_cross_entropy() + + @property + def name(self): + return self._name + + @property + def dtype(self): + return self._dtype + + @property + def parameters(self): + return self._parameters + + @property + def is_scalar_batch(self): + return self._is_scalar_batch + + def _set_prob(self): + """ + Set probability funtion based on the availability of _prob and _log_likehood. + """ + if hasattr(self, '_prob'): + self._call_prob = self._prob + elif hasattr(self, '_log_prob'): + self._call_prob = self._calc_prob_from_log_prob + + def _set_sd(self): + """ + Set standard deviation based on the availability of _sd and _var. + """ + if hasattr(self, '_sd'): + self._call_sd = self._sd + elif hasattr(self, '_var'): + self._call_sd = self._calc_sd_from_var + + def _set_var(self): + """ + Set variance based on the availability of _sd and _var. + """ + if hasattr(self, '_var'): + self._call_var = self._var + elif hasattr(self, '_sd'): + self._call_var = self._calc_var_from_sd + + def _set_log_prob(self): + """ + Set log probability based on the availability of _prob and _log_prob. + """ + if hasattr(self, '_log_prob'): + self._call_log_prob = self._log_prob + elif hasattr(self, '_prob'): + self._call_log_prob = self._calc_log_prob_from_prob + + def _set_cdf(self): + """ + Set cdf based on the availability of _cdf and _log_cdf and survival_functions. + """ + if hasattr(self, '_cdf'): + self._call_cdf = self._cdf + elif hasattr(self, '_log_cdf'): + self._call_cdf = self._calc_cdf_from_log_cdf + elif hasattr(self, '_survival_function'): + self._call_cdf = self._calc_cdf_from_survival + elif hasattr(self, '_log_survival'): + self._call_cdf = self._calc_cdf_from_log_survival + + def _set_survival(self): + """ + Set survival function based on the availability of _survival function and _log_survival + and _call_cdf. + """ + if hasattr(self, '_survival_function'): + self._call_survival = self._survival_function + elif hasattr(self, '_log_survival'): + self._call_survival = self._calc_survival_from_log_survival + elif hasattr(self, '_call_cdf'): + self._call_survival = self._calc_survival_from_call_cdf + + def _set_log_cdf(self): + """ + Set log cdf based on the availability of _log_cdf and _call_cdf. + """ + if hasattr(self, '_log_cdf'): + self._call_log_cdf = self._log_cdf + elif hasattr(self, '_call_cdf'): + self._call_log_cdf = self._calc_log_cdf_from_call_cdf + + def _set_log_survival(self): + """ + Set log survival based on the availability of _log_survival and _call_survival. + """ + if hasattr(self, '_log_survival'): + self._call_log_survival = self._log_survival + elif hasattr(self, '_call_survival'): + self._call_log_survival = self._calc_log_survival_from_call_survival + + def _set_cross_entropy(self): + """ + Set log survival based on the availability of _cross_entropy. + """ + if hasattr(self, '_cross_entropy'): + self._call_cross_entropy = self._cross_entropy + + def log_prob(self, *args): + """ + Evaluate the log probability(pdf or pmf) at the given value. + + Note: + Args must include value. + Dist_spec_args are optional. + """ + return self._call_log_prob(*args) + + def _calc_prob_from_log_prob(self, *args): + r""" + Evaluate prob from log probability. + + .. math:: + probability(x) = \exp(log_likehood(x)) + """ + return self.exp(self._log_prob(*args)) + + def prob(self, *args): + """ + Evaluate the probability (pdf or pmf) at given value. + + Note: + Args must include value. + Dist_spec_args are optional. + """ + return self._call_prob(*args) + + def _calc_log_prob_from_prob(self, *args): + r""" + Evaluate log probability from probability. + + .. math:: + log_prob(x) = \log(prob(x)) + """ + return self.log(self._prob(*args)) + + def cdf(self, *args): + """ + Evaluate the cdf at given value. + + Note: + Args must include value. + Dist_spec_args are optional. + """ + return self._call_cdf(*args) + + def _calc_cdf_from_log_cdf(self, *args): + r""" + Evaluate cdf from log_cdf. + + .. math:: + cdf(x) = \exp(log_cdf(x)) + """ + return self.exp(self._log_cdf(*args)) + + def _calc_cdf_from_survival(self, *args): + r""" + Evaluate cdf from survival function. + + .. math:: + cdf(x) = 1 - (survival_function(x)) + """ + return 1.0 - self._survival_function(*args) + + def _calc_cdf_from_log_survival(self, *args): + r""" + Evaluate cdf from log survival function. + + .. math:: + cdf(x) = 1 - (\exp(log_survival(x))) + """ + return 1.0 - self.exp(self._log_survival(*args)) + + def log_cdf(self, *args): + """ + Evaluate the log cdf at given value. + + Note: + Args must include value. + Dist_spec_args are optional. + """ + return self._call_log_cdf(*args) + + def _calc_log_cdf_from_call_cdf(self, *args): + r""" + Evaluate log cdf from cdf. + + .. math:: + log_cdf(x) = \log(cdf(x)) + """ + return self.log(self._call_cdf(*args)) + + def survival_function(self, *args): + """ + Evaluate the survival function at given value. + + Note: + Args must include value. + Dist_spec_args are optional. + """ + return self._call_survival(*args) + + def _calc_survival_from_call_cdf(self, *args): + r""" + Evaluate survival function from cdf. + + .. math:: + survival_function(x) = 1 - (cdf(x)) + """ + return 1.0 - self._call_cdf(*args) + + def _calc_survival_from_log_survival(self, *args): + r""" + Evaluate survival function from log survival function. + + .. math:: + survival(x) = \exp(survival_function(x)) + """ + return self.exp(self._log_survival(*args)) + + def log_survival(self, *args): + """ + Evaluate the log survival function at given value. + + Note: + Args must include value. + Dist_spec_args are optional. + """ + return self._call_log_survival(*args) + + def _calc_log_survival_from_call_survival(self, *args): + r""" + Evaluate log survival function from survival function. + + .. math:: + log_survival(x) = \log(survival_function(x)) + """ + return self.log(self._call_survival(*args)) + + def kl_loss(self, *args): + """ + Evaluate the KL divergence, i.e. KL(a||b). + + Note: + Args must include type of the distribution, parameters of distribution b. + Parameters for distribution a are optional. + """ + return self._kl_loss(*args) + + def mean(self, *args): + """ + Evaluate the mean. + + Note: + Dist_spec_args are optional. + """ + return self._mean(*args) + + def mode(self, *args): + """ + Evaluate the mode. + + Note: + Dist_spec_args are optional. + """ + return self._mode(*args) + + def sd(self, *args): + """ + Evaluate the standard deviation. + + Note: + Dist_spec_args are optional. + """ + return self._call_sd(*args) + + def var(self, *args): + """ + Evaluate the variance. + + Note: + Dist_spec_args are optional. + """ + return self._call_var(*args) + + def _calc_sd_from_var(self, *args): + r""" + Evaluate log probability from probability. + + .. math:: + STD(x) = \sqrt(VAR(x)) + """ + return self.sqrt(self._var(*args)) + + def _calc_var_from_sd(self, *args): + r""" + Evaluate log probability from probability. + + .. math:: + VAR(x) = STD(x) ^ 2 + """ + return self.sq(self._sd(*args)) + + def entropy(self, *args): + """ + Evaluate the entropy. + + Note: + Dist_spec_args are optional. + """ + return self._entropy(*args) + + def cross_entropy(self, *args): + """ + Evaluate the cross_entropy between distribution a and b. + + Note: + Args must include type of the distribution, parameters of distribution b. + Parameters for distribution a are optional. + """ + return self._call_cross_entropy(*args) + + def _calc_cross_entropy(self, *args): + r""" + Evaluate cross_entropy from entropy and kl divergence. + + .. math:: + H(X, Y) = H(X) + KL(X||Y) + """ + return self._entropy(*args) + self._kl_loss(*args) + + def sample(self, *args): + """ + Sampling function. + + Args: + *args (list): arguments passed in through construct. + + Note: + Shape of the sample is default to (). + Dist_spec_args are optional. + """ + return self._sample(*args) + + + def construct(self, name, *args): + """ + Override construct in Cell. + + Note: + Names of supported functions: + 'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival' + 'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'. + + Args: + name (str): name of the function. + *args (list): list of arguments needed for the function. + """ + + if name == 'log_prob': + return self._call_log_prob(*args) + if name == 'prob': + return self._call_prob(*args) + if name == 'cdf': + return self._call_cdf(*args) + if name == 'log_cdf': + return self._call_log_cdf(*args) + if name == 'survival_function': + return self._call_survival(*args) + if name == 'log_survival': + return self._call_log_survival(*args) + if name == 'kl_loss': + return self._kl_loss(*args) + if name == 'mean': + return self._mean(*args) + if name == 'mode': + return self._mode(*args) + if name == 'sd': + return self._call_sd(*args) + if name == 'var': + return self._call_var(*args) + if name == 'entropy': + return self._entropy(*args) + if name == 'cross_entropy': + return self._call_cross_entropy(*args) + if name == 'sample': + return self._sample(*args) + return None diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py new file mode 100644 index 0000000000..74c6a40ab0 --- /dev/null +++ b/mindspore/nn/probability/distribution/exponential.py @@ -0,0 +1,253 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Exponential Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import cast_to_tensor, check_greater_zero + +class Exponential(Distribution): + """ + Example class: Exponential Distribution. + + Args: + rate (float, list, numpy.ndarray, Tensor, Parameter): inverse scale. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Exponential. + + Note: + rate should be strictly greater than 0. + Dist_spec_args is rate. + + Examples: + >>> # To initialize an Exponential distribution of rate 0.5 + >>> import mindspore.nn.probability.distribution as msd + >>> e = msd.Exponential(0.5, dtype=mstype.float32) + >>> + >>> # The following creates two independent Exponential distributions + >>> e = msd.Exponential([0.5, 0.5], dtype=mstype.float32) + >>> + >>> # An Exponential distribution can be initilized without arguments + >>> # In this case, rate must be passed in through args during function calls + >>> e = msd.Exponential(dtype=mstype.float32) + >>> + >>> # To use Exponential in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.e1 = msd.Exponential(0.5, dtype=mstype.float32) + >>> self.e2 = msd.Exponential(dtype=mstype.float32) + >>> + >>> # All the following calls in construct are valid + >>> def construct(self, value, rate_b, rate_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.e1.prob(value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.e1.prob(value, rate_b) + >>> + >>> # Rate must be passed in during function calls + >>> ans = self.e2.prob(value, rate_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as'mean' + >>> # Will return 2 + >>> ans = self.e1.mean() + >>> # Will return 1 / rate_b + >>> ans = self.e1.mean(rate_b) + >>> + >>> # Rate must be passed in during function calls + >>> ans = self.e2.mean(rate_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.e1.kl_loss('Exponential', rate_b) + >>> ans = self.e1.kl_loss('Exponential', rate_b, rate_a) + >>> + >>> # Additional rate must be passed in + >>> ans = self.e2.kl_loss('Exponential', rate_b, rate_a) + >>> + >>> # Sample + >>> ans = self.e1.sample() + >>> ans = self.e1.sample((2,3)) + >>> ans = self.e1.sample((2,3), rate_b) + >>> ans = self.e2.sample((2,3), rate_a) + """ + + def __init__(self, + rate=None, + seed=0, + dtype=mstype.float32, + name="Exponential"): + """ + Constructor of Exponential distribution. + """ + param = dict(locals()) + super(Exponential, self).__init__(dtype, name, param) + if rate is not None: + self._rate = cast_to_tensor(rate, mstype.float32) + check_greater_zero(self._rate, "rate") + else: + self._rate = rate + + self.minval = np.finfo(np.float).tiny + + # ops needed for the class + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.exp = P.Exp() + self.fill = P.Fill() + self.less = P.Less() + self.log = P.Log() + self.select = P.Select() + self.shape = P.Shape() + self.sqrt = P.Sqrt() + self.sq = P.Square() + self.uniform = P.UniformReal(seed=seed) + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'rate = {self.rate}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def rate(self): + """ + Return rate of the distribution. + """ + return self._rate + + def _mean(self, rate=None): + r""" + .. math:: + MEAN(EXP) = \fract{1.0}{\lambda}. + """ + rate = self.rate if rate is None else rate + return 1.0 / rate + + + def _mode(self, rate=None): + r""" + .. math:: + MODE(EXP) = 0. + """ + rate = self.rate if rate is None else rate + return self.fill(self.dtype, self.shape(rate), 0.) + + def _sd(self, rate=None): + r""" + .. math:: + sd(EXP) = \fract{1.0}{\lambda}. + """ + rate = self.rate if rate is None else rate + return 1.0 / rate + + def _entropy(self, rate=None): + r""" + .. math:: + H(Exp) = 1 - \log(\lambda). + """ + rate = self.rate if rate is None else rate + return 1.0 - self.log(rate) + + + def _cross_entropy(self, dist, rate_b, rate_a=None): + """ + Evaluate cross_entropy between Exponential distributions. + + Args: + dist (str): type of the distributions. Should be "Exponential" in this case. + rate_b (Tensor): rate of distribution b. + rate_a (Tensor): rate of distribution a. Default: self.rate. + """ + if dist == 'Exponential': + return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a) + return None + + def _prob(self, value, rate=None): + r""" + pdf of Exponential distribution. + + Args: + Args: + value (Tensor): value to be evaluated. + rate (Tensor): rate of the distribution. Default: self.rate. + + Note: + Value should be greater or equal to zero. + + .. math:: + pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 + """ + rate = self.rate if rate is None else rate + prob = rate * self.exp(-1. * rate * value) + zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, prob) + + def _cdf(self, value, rate=None): + r""" + cdf of Exponential distribution. + + Args: + value (Tensor): value to be evaluated. + rate (Tensor): rate of the distribution. Default: self.rate. + + Note: + Value should be greater or equal to zero. + + .. math:: + cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 + """ + rate = self.rate if rate is None else rate + cdf = 1.0 - self.exp(-1. * rate * value) + zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, cdf) + + + def _kl_loss(self, dist, rate_b, rate_a=None): + """ + Evaluate exp-exp kl divergence, i.e. KL(a||b). + + Args: + dist (str): type of the distributions. Should be "Exponential" in this case. + rate_b (Tensor): rate of distribution b. + rate_a (Tensor): rate of distribution a. Default: self.rate. + """ + if dist == 'Exponential': + rate_a = self.rate if rate_a is None else rate_a + return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 + return None + + def _sample(self, shape=(), rate=None): + """ + Sampling. + + Args: + shape (tuple): shape of the sample. Default: (). + rate (Tensor): rate of the distribution. Default: self.rate. + + Returns: + Tensor, shape is shape + batch_shape. + """ + rate = self.rate if rate is None else rate + minval = self.const(self.minval) + maxval = self.const(1.0) + sample = self.uniform(shape + self.shape(rate), minval, maxval) + return -self.log(sample) / rate diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py new file mode 100644 index 0000000000..59bc8f0c99 --- /dev/null +++ b/mindspore/nn/probability/distribution/geometric.py @@ -0,0 +1,271 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Geometric Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import cast_to_tensor, check_prob + +class Geometric(Distribution): + """ + Geometric Distribution. + It represents k+1 Bernoulli trials needed to get one success, k is the number of failures. + + Args: + probs (float, list, numpy.ndarray, Tensor, Parameter): probability of success. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. + name (str): name of the distribution. Default: Geometric. + + Note: + probs should be proper probabilities (0 <= p <= 1). + Dist_spec_args is probs. + + Examples: + >>> # To initialize a Geometric distribution of prob 0.5 + >>> import mindspore.nn.probability.distribution as msd + >>> n = msd.Geometric(0.5, dtype=mstype.int32) + >>> + >>> # The following creates two independent Geometric distributions + >>> n = msd.Geometric([0.5, 0.5], dtype=mstype.int32) + >>> + >>> # A Geometric distribution can be initilized without arguments + >>> # In this case, probs must be passed in through args during function calls. + >>> n = msd.Geometric(dtype=mstype.int32) + >>> + >>> # To use Geometric in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.g1 = msd.Geometric(0.5, dtype=mstype.int32) + >>> self.g2 = msd.Geometric(dtype=mstype.int32) + >>> + >>> # Tthe following calls are valid in construct + >>> def construct(self, value, probs_b, probs_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.g1.prob(value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.g1.prob(value, probs_b) + >>> + >>> # Probs must be passed in during function calls + >>> ans = self.g2.prob(value, probs_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return 1.0 + >>> ans = self.g1.mean() + >>> # Another possible usage + >>> ans = self.g1.mean(probs_b) + >>> + >>> # Probs must be passed in during function calls + >>> ans = self.g2.mean(probs_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.g1.kl_loss('Geometric', probs_b) + >>> ans = self.g1.kl_loss('Geometric', probs_b, probs_a) + >>> + >>> # Additional probs must be passed in + >>> ans = self.g2.kl_loss('Geometric', probs_b, probs_a) + >>> + >>> # Sample + >>> ans = self.g1.sample() + >>> ans = self.g1.sample((2,3)) + >>> ans = self.g1.sample((2,3), probs_b) + >>> ans = self.g2.sample((2,3), probs_a) + """ + + def __init__(self, + probs=None, + seed=0, + dtype=mstype.int32, + name="Geometric"): + """ + Constructor of Geometric distribution. + """ + param = dict(locals()) + super(Geometric, self).__init__(dtype, name, param) + if probs is not None: + self._probs = cast_to_tensor(probs, dtype=mstype.float32) + check_prob(self._probs) + else: + self._probs = probs + + self.minval = np.finfo(np.float).tiny + + # ops needed for the class + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.fill = P.Fill() + self.floor = P.Floor() + self.issubclass = P.IsSubClass() + self.less = P.Less() + self.log = P.Log() + self.pow = P.Pow() + self.select = P.Select() + self.shape = P.Shape() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'probs = {self.probs}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def probs(self): + """ + Returns the probability for the outcome is 1. + """ + return self._probs + + def _mean(self, probs1=None): + r""" + .. math:: + MEAN(Geo) = \fratc{1 - probs1}{probs1} + """ + probs1 = self.probs if probs1 is None else probs1 + return (1. - probs1) / probs1 + + def _mode(self, probs1=None): + r""" + .. math:: + MODE(Geo) = 0 + """ + probs1 = self.probs if probs1 is None else probs1 + return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) + + def _var(self, probs1=None): + r""" + .. math:: + VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}} + """ + probs1 = self.probs if probs1 is None else probs1 + return (1.0 - probs1) / self.sq(probs1) + + def _entropy(self, probs=None): + r""" + .. math:: + H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} + """ + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 + + def _cross_entropy(self, dist, probs1_b, probs1_a=None): + r""" + Evaluate cross_entropy between Geometric distributions. + + Args: + dist (str): type of the distributions. Should be "Geometric" in this case. + probs1_b (Tensor): probability of success of distribution b. + probs1_a (Tensor): probability of success of distribution a. Default: self.probs. + """ + if dist == 'Geometric': + return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) + return None + + def _prob(self, value, probs=None): + r""" + pmf of Geometric distribution. + + Args: + value (Tensor): a Tensor composed of only natural numbers. + probs (Tensor): probability of success. Default: self.probs. + + .. math:: + pmf(k) = probs0 ^k * probs1 if k >= 0; + pmf(k) = 0 if k < 0. + """ + probs1 = self.probs if probs is None else probs + dtype = self.dtypeop(value) + if self.issubclass(dtype, mstype.int_): + pass + elif self.issubclass(dtype, mstype.float_): + value = self.floor(value) + else: + return None + pmf = self.pow((1.0 - probs1), value) * probs1 + zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, pmf) + + def _cdf(self, value, probs=None): + r""" + cdf of Geometric distribution. + + Args: + value (Tensor): a Tensor composed of only natural numbers. + probs (Tensor): probability of success. Default: self.probs. + + .. math:: + cdf(k) = 1 - probs0 ^ (k+1) if k >= 0; + cdf(k) = 0 if k < 0. + + """ + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + dtype = self.dtypeop(value) + if self.issubclass(dtype, mstype.int_): + pass + elif self.issubclass(dtype, mstype.float_): + value = self.floor(value) + else: + return None + cdf = 1.0 - self.pow(probs0, value + 1.0) + zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, cdf) + + + def _kl_loss(self, dist, probs1_b, probs1_a=None): + r""" + Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). + + Args: + dist (str): type of the distributions. Should be "Geometric" in this case. + probs1_b (Tensor): probability of success of distribution b. + probs1_a (Tensor): probability of success of distribution a. Default: self.probs. + + .. math:: + KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b}) + """ + if dist == 'Geometric': + probs1_a = self.probs if probs1_a is None else probs1_a + probs0_a = 1.0 - probs1_a + probs0_b = 1.0 - probs1_b + return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) + return None + + def _sample(self, shape=(), probs=None): + """ + Sampling. + + Args: + shape (tuple): shape of the sample. Default: (). + probs (Tensor): probability of success. Default: self.probs. + + Returns: + Tensor, shape is shape + batch_shape. + """ + probs = self.probs if probs is None else probs + minval = self.const(self.minval) + maxval = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) + return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py new file mode 100644 index 0000000000..f243a2bc31 --- /dev/null +++ b/mindspore/nn/probability/distribution/normal.py @@ -0,0 +1,263 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Normal Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import convert_to_batch, check_greater_equal_zero + + +class Normal(Distribution): + """ + Normal distribution. + + Args: + mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Normal distribution. + sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Normal distribution. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Normal. + + Note: + Standard deviation should be greater than zero. + Dist_spec_args are mean and sd. + + Examples: + >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 + >>> import mindspore.nn.probability.distribution as msd + >>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32) + >>> + >>> # The following creates two independent Normal distributions + >>> n = msd.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + >>> + >>> # A Normal distribution can be initilize without arguments + >>> # In this case, mean and sd must be passed in through args. + >>> n = msd.Normal(dtype=mstype.float32) + >>> + >>> # To use Normal in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.n1 = msd.Nomral(0.0, 1.0, dtype=mstype.float32) + >>> self.n2 = msd.Normal(dtype=mstype.float32) + >>> + >>> # The following calls are valid in construct + >>> def construct(self, value, mean_b, sd_b, mean_a, sd_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.n1.prob(value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.n1.prob(value, mean_b, sd_b) + >>> + >>> # mean and sd must be passed in during function calls + >>> ans = self.n2.prob(value, mean_a, sd_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # will return [0.0] + >>> ans = self.n1.mean() + >>> # will return mean_b + >>> ans = self.n1.mean(mean_b, sd_b) + >>> + >>> # mean and sd must be passed during function calls + >>> ans = self.n2.mean(mean_a, sd_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.n1.kl_loss('Normal', mean_b, sd_b) + >>> ans = self.n1.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a) + >>> + >>> # Additional mean and sd must be passed + >>> ans = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a) + >>> + >>> # Sample + >>> ans = self.n1.sample() + >>> ans = self.n1.sample((2,3)) + >>> ans = self.n1.sample((2,3), mean_b, sd_b) + >>> ans = self.n2.sample((2,3), mean_a, sd_a) + """ + + def __init__(self, + mean=None, + sd=None, + seed=0, + dtype=mstype.float32, + name="Normal"): + """ + Constructor of normal distribution. + """ + param = dict(locals()) + super(Normal, self).__init__(dtype, name, param) + if mean is not None and sd is not None: + self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype) + self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype) + check_greater_equal_zero(self._sd_value, "Standard deviation") + else: + self._mean_value = mean + self._sd_value = sd + self.seed = seed + + #ops needed for the class + self.const = P.ScalarToArray() + self.erf = P.Erf() + self.exp = P.Exp() + self.expm1 = self._expm1_by_step + self.fill = P.Fill() + self.log = P.Log() + self.shape = P.Shape() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.zeroslike = P.ZerosLike() + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + def _expm1_by_step(self, x): + """ + Expm1 ops under GPU context. + """ + return self.exp(x) - 1.0 + + def _mean(self, mean=None, sd=None): + """ + Mean of the distribution. + """ + mean = self._mean_value if mean is None or sd is None else mean + return mean + + def _mode(self, mean=None, sd=None): + """ + Mode of the distribution. + """ + mean = self._mean_value if mean is None or sd is None else mean + return mean + + def _sd(self, mean=None, sd=None): + """ + Standard deviation of the distribution. + """ + sd = self._sd_value if mean is None or sd is None else sd + return sd + + def _entropy(self, sd=None): + r""" + Evaluate entropy. + + .. math:: + H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) + """ + sd = self._sd_value if sd is None else sd + return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) + + def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): + r""" + Evaluate cross_entropy between normal distributions. + + Args: + dist (str): type of the distributions. Should be "Normal" in this case. + mean_b (Tensor): mean of distribution b. + sd_b (Tensor): standard deviation distribution b. + mean_a (Tensor): mean of distribution a. Default: self._mean_value. + sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. + """ + if dist == 'Normal': + return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a) + return None + + def _log_prob(self, value, mean=None, sd=None): + r""" + Evaluate log probability. + + Args: + value (Tensor): value to be evaluated. + mean (Tensor): mean of the distribution. Default: self._mean_value. + sd (Tensor): standard deviation the distribution. Default: self._sd_value. + + .. math:: + L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) + """ + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) + return unnormalized_log_prob + neg_normalization + + def _cdf(self, value, mean=None, sd=None): + r""" + Evaluate cdf of given value. + + Args: + value (Tensor): value to be evaluated. + mean (Tensor): mean of the distribution. Default: self._mean_value. + sd (Tensor): standard deviation the distribution. Default: self._sd_value. + + .. math:: + cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) + """ + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + sqrt2 = self.sqrt(self.const(2.0)) + adjusted = (value - mean) / (sd * sqrt2) + return 0.5 * (1.0 + self.erf(adjusted)) + + def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): + r""" + Evaluate Normal-Normal kl divergence, i.e. KL(a||b). + + Args: + dist (str): type of the distributions. Should be "Normal" in this case. + mean_b (Tensor): mean of distribution b. + sd_b (Tensor): standard deviation distribution b. + mean_a (Tensor): mean of distribution a. Default: self._mean_value. + sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. + + .. math:: + KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + + 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) + """ + if dist == 'Normal': + mean_a = self._mean_value if mean_a is None else mean_a + sd_a = self._sd_value if sd_a is None else sd_a + diff_log_scale = self.log(sd_a) - self.log(sd_b) + squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) + return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale + return None + + def _sample(self, shape=(), mean=None, sd=None): + """ + Sampling. + + Args: + shape (tuple): shape of the sample. Default: (). + mean (Tensor): mean of the samples. Default: self._mean_value. + sd (Tensor): standard deviation of the samples. Default: self._sd_value. + + Returns: + Tensor, shape is shape + batch_shape. + """ + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) + sample_shape = shape + batch_shape + mean_zero = self.const(0.0) + sd_one = self.const(1.0) + sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) + sample = mean + sample_norm * sd + return sample diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py new file mode 100644 index 0000000000..37f474e943 --- /dev/null +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -0,0 +1,94 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformed Distribution""" +from mindspore.ops import operations as P +from .distribution import Distribution + +class TransformedDistribution(Distribution): + """ + Transformed Distribution. + This class contains a bijector and a distribution and transforms the original distribution + to a new distribution through the operation defined by the bijector. + + Args: + bijector (Bijector): transformation to perform. + distribution (Distribution): The original distribution. + name (str): name of the transformed distribution. Default: transformed_distribution. + + Note: + The arguments used to initialize the original distribution cannot be None. + For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a + TransformedDistribution since mean and sd are not specified. + """ + def __init__(self, + bijector, + distribution, + name="transformed_distribution"): + """ + Constructor of transformed_distribution class. + """ + param = dict(locals()) + super(TransformedDistribution, self).__init__(distribution.dtype, name, param) + self._bijector = bijector + self._distribution = distribution + self._is_linear_transformation = bijector.is_constant_jacobian + self.exp = P.Exp() + + @property + def bijector(self): + return self._bijector + + @property + def distribution(self): + return self._distribution + + @property + def is_linear_transformation(self): + return self._is_linear_transformation + + def _cdf(self, value): + r""" + .. math:: + Y = g(X) + P(Y <= a) = P(X <= g^{-1}(a)) + """ + inverse_value = self.bijector.inverse(value) + return self.distribution.cdf(inverse_value) + + def _log_prob(self, value): + r""" + .. math:: + Y = g(X) + Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a) + \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a)) + """ + inverse_value = self.bijector.inverse(value) + unadjust_prob = self.distribution.log_prob(inverse_value) + log_jacobian = self.bijector.inverse_log_jacobian(value) + return unadjust_prob + log_jacobian + + def _prob(self, value): + return self.exp(self._log_prob(value)) + + def _sample(self, shape): + org_sample = self.distribution.sample(shape) + return self.bijector.forward(org_sample) + + def _mean(self): + """ + Note: + This function maybe overridden by derived class. + """ + return self.bijector.forward(self.distribution.mean()) diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py new file mode 100644 index 0000000000..2fc459f56d --- /dev/null +++ b/mindspore/nn/probability/distribution/uniform.py @@ -0,0 +1,287 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Uniform Distribution""" +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import convert_to_batch, check_greater + +class Uniform(Distribution): + """ + Example class: Uniform Distribution. + + Args: + low (int, float, list, numpy.ndarray, Tensor, Parameter): lower bound of the distribution. + high (int, float, list, numpy.ndarray, Tensor, Parameter): upper bound of the distribution. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Uniform. + + Note: + low should be stricly less than high. + Dist_spec_args are high and low. + + Examples: + >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0 + >>> import mindspore.nn.probability.distribution as msd + >>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32) + >>> + >>> # The following creates two independent Uniform distributions + >>> u = msd.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) + >>> + >>> # A Uniform distribution can be initilized without arguments + >>> # In this case, high and low must be passed in through args during function calls. + >>> u = msd.Uniform(dtype=mstype.float32) + >>> + >>> # To use Uniform in a network + >>> class net(Cell): + >>> def __init__(self) + >>> super(net, self).__init__(): + >>> self.u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32) + >>> self.u2 = msd.Uniform(dtype=mstype.float32) + >>> + >>> # All the following calls in construct are valid + >>> def construct(self, value, low_b, high_b, low_a, high_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.u1.prob(value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.u1.prob(value, low_b, high_b) + >>> + >>> # High and low must be passed in during function calls + >>> ans = self.u2.prob(value, low_a, high_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return 0.5 + >>> ans = self.u1.mean() + >>> # Will return (low_b + high_b) / 2 + >>> ans = self.u1.mean(low_b, high_b) + >>> + >>> # High and low must be passed in during function calls + >>> ans = self.u2.mean(low_a, high_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.u1.kl_loss('Uniform', low_b, high_b) + >>> ans = self.u1.kl_loss('Uniform', low_b, high_b, low_a, high_a) + >>> + >>> # Additional high and low must be passed + >>> ans = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a) + >>> + >>> # Sample + >>> ans = self.u1.sample() + >>> ans = self.u1.sample((2,3)) + >>> ans = self.u1.sample((2,3), low_b, high_b) + >>> ans = self.u2.sample((2,3), low_a, high_a) + """ + + def __init__(self, + low=None, + high=None, + seed=0, + dtype=mstype.float32, + name="Uniform"): + """ + Constructor of Uniform distribution. + """ + param = dict(locals()) + super(Uniform, self).__init__(dtype, name, param) + if low is not None and high is not None: + self._low = convert_to_batch(low, self._broadcast_shape, dtype) + self._high = convert_to_batch(high, self._broadcast_shape, dtype) + check_greater(self.low, self.high, "low value", "high value") + else: + self._low = low + self._high = high + + # ops needed for the class + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.exp = P.Exp() + self.fill = P.Fill() + self.less = P.Less() + self.lessequal = P.LessEqual() + self.log = P.Log() + self.logicaland = P.LogicalAnd() + self.select = P.Select() + self.shape = P.Shape() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) + self.zeroslike = P.ZerosLike() + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'low = {self.low}, high = {self.high}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def low(self): + """ + Return lower bound of the distribution. + """ + return self._low + + @property + def high(self): + """ + Return upper bound of the distribution. + """ + return self._high + + def _range(self, low=None, high=None): + r""" + Return the range of the distribution. + .. math:: + range(U) = high -low + """ + low = self.low if low is None else low + high = self.high if high is None else high + return high - low + + def _mean(self, low=None, high=None): + r""" + .. math:: + MEAN(U) = \fract{low + high}{2}. + """ + low = self.low if low is None else low + high = self.high if high is None else high + return (low + high) / 2. + + + def _var(self, low=None, high=None): + r""" + .. math:: + VAR(U) = \fract{(high -low) ^ 2}{12}. + """ + low = self.low if low is None else low + high = self.high if high is None else high + return self.sq(high - low) / 12.0 + + def _entropy(self, low=None, high=None): + r""" + .. math:: + H(U) = \log(high - low). + """ + low = self.low if low is None else low + high = self.high if high is None else high + return self.log(high - low) + + def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None): + """ + Evaluate cross_entropy between Uniform distributoins. + + Args: + dist (str): type of the distributions. Should be "Uniform" in this case. + low_b (Tensor): lower bound of distribution b. + high_b (Tensor): upper bound of distribution b. + low_a (Tensor): lower bound of distribution a. Default: self.low. + high_a (Tensor): upper bound of distribution a. Default: self.high. + """ + if dist == 'Uniform': + return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a) + return None + + def _prob(self, value, low=None, high=None): + r""" + pdf of Uniform distribution. + + Args: + value (Tensor): value to be evaluated. + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + .. math:: + pdf(x) = 0 if x < low; + pdf(x) = \fract{1.0}{high -low} if low <= x <= high; + pdf(x) = 0 if x > high; + """ + low = self.low if low is None else low + high = self.high if high is None else high + ones = self.fill(self.dtype, self.shape(value), 1.0) + prob = ones / (high - low) + broadcast_shape = self.shape(prob) + zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) + comp_lo = self.less(value, low) + comp_hi = self.lessequal(value, high) + less_than_low = self.select(comp_lo, zeros, prob) + return self.select(comp_hi, less_than_low, zeros) + + def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None): + """ + Evaluate uniform-uniform kl divergence, i.e. KL(a||b). + + Args: + dist (str): type of the distributions. Should be "Uniform" in this case. + low_b (Tensor): lower bound of distribution b. + high_b (Tensor): upper bound of distribution b. + low_a (Tensor): lower bound of distribution a. Default: self.low. + high_a (Tensor): upper bound of distribution a. Default: self.high. + """ + if dist == 'Uniform': + low_a = self.low if low_a is None else low_a + high_a = self.high if high_a is None else high_a + kl = self.log(high_b - low_b) / self.log(high_a - low_a) + comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) + return self.select(comp, kl, self.log(self.zeroslike(kl))) + return None + + def _cdf(self, value, low=None, high=None): + r""" + cdf of Uniform distribution. + + Args: + value (Tensor): value to be evaluated. + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + .. math:: + cdf(x) = 0 if x < low; + cdf(x) = \fract{x - low}{high -low} if low <= x <= high; + cdf(x) = 1 if x > high; + """ + low = self.low if low is None else low + high = self.high if high is None else high + prob = (value - low) / (high - low) + broadcast_shape = self.shape(prob) + zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) + ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0) + comp_lo = self.less(value, low) + comp_hi = self.less(value, high) + less_than_low = self.select(comp_lo, zeros, prob) + return self.select(comp_hi, less_than_low, ones) + + def _sample(self, shape=(), low=None, high=None): + """ + Sampling. + + Args: + shape (tuple): shape of the sample. Default: (). + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + Returns: + Tensor, shape is shape + batch_shape. + """ + low = self.low if low is None else low + high = self.high if high is None else high + broadcast_shape = self.shape(low + high) + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) + sample = (high - low) * sample_uniform + low + return sample diff --git a/mindspore/nn/sparse/__init__.py b/mindspore/nn/sparse/__init__.py new file mode 100644 index 0000000000..1a4c350a0d --- /dev/null +++ b/mindspore/nn/sparse/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Sparse related transformation. +""" +from .sparse import SparseToDense + +__all__ = [ + "SparseToDense", + ] diff --git a/mindspore/nn/sparse/sparse.py b/mindspore/nn/sparse/sparse.py new file mode 100644 index 0000000000..2b9b5fa686 --- /dev/null +++ b/mindspore/nn/sparse/sparse.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Sparse related tools.""" +from mindspore.ops import operations as P +from ..cell import Cell + + +class SparseToDense(Cell): + """ + Convert a sparse tensor into dense. + + Not yet supported by any backend at the moment. + + Args: + sparse_tensor (SparseTensor): the sparse tensor to convert. + + Returns: + Tensor, the tensor converted. + + Examples: + >>> class SparseToDenseCell(nn.Cell): + >>> def __init__(self, dense_shape): + >>> super(SparseToDenseCell, self).__init__() + >>> self.dense_shape = dense_shape + >>> self.sparse_to_dense = nn.SparseToDense() + >>> def construct(self, indices, values): + >>> sparse = SparseTensor(indices, values, self.dense_shape) + >>> return self.sparse_to_dense(sparse) + >>> + >>> indices = Tensor([[0, 1], [1, 2]]) + >>> values = Tensor([1, 2], dtype=ms.float32) + >>> dense_shape = (3, 4) + >>> SparseToDenseCell(dense_shape)(indices, values) + """ + def __init__(self): + super(SparseToDense, self).__init__() + self.sparse_to_dense = P.SparseToDense() + + def construct(self, sparse_tensor): + return self.sparse_to_dense(sparse_tensor.indices, + sparse_tensor.values, + sparse_tensor.dense_shape) diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 9354b42e55..66543a1625 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -13,35 +13,48 @@ # limitations under the License. # ============================================================================ """grad reducer cell for distributed training""" +from mindspore import context from mindspore.nn.cell import Cell from mindspore.communication.management import GlobalComm, get_group_size +from mindspore.common.tensor import RowTensor from mindspore.ops import functional as F, composite as C, operations as P -from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather +from mindspore.ops.operations.comm_ops import AllReduce, AllGather +from mindspore.parallel._auto_parallel_context import auto_parallel_context import mindspore.common.dtype as mstype reduce_opt = C.MultitypeFuncGraph("reduce_opt") -_all_reduce = AllReduce() -_all_gather = None - -def _init_optimizer_communication(): - global _all_reduce - global _all_gather - - _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) - _all_reduce.add_prim_attr('fusion', 1) - _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP) - - -@reduce_opt.register("Function", "Number", "Bool", "Tensor") -def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad): +def _init_allreduce_operators(length, split_indices): + """ initialize allreduce communication operators""" + group = 1 + fusion = () + for i in range(length): + fusion = fusion + (group,) + if split_indices[group - 1] <= i + 1: + if group >= len(split_indices): + continue + group = group + 1 + index = tuple(range(1, length + 1)) + op_list = () + for i in range(length): + op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) + op.add_prim_attr('fusion', fusion[i]) + op.add_prim_attr('index', index[i]) + op_list = op_list + (op,) + return op_list + + +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor") +def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad): """ - Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning. + Apply allreduce on gradient. Args: - mul (Primitive): Div operation. degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. allreduce_filter (bool): When it is true, allreduce would apply. grad (Tensor): The gradient tensor before operation. @@ -49,72 +62,106 @@ def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad): Tensor, the gradient tensor after operation. """ if allreduce_filter: - degree = F.scalar_cast(degree, F.dtype(grad)) - grad = _all_reduce(grad) - cast_op = P.Cast() - return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad))) + grad = allreduce(grad) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad)) + cast_op = P.Cast() + mul_op = P.Mul() + grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) + return grad return grad -@reduce_opt.register("Function", "Number", "Bool", "Tuple") -def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad): +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") +def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): """ - Apply mean and allgather on gradient instead of allreduce for sparse feature. - Allgather is a communication operation used for distributed deep learning. + Apply allreduce on gradient. Args: - mul (Primitive): Div operation. degree (int): The mean coefficient. - allreduce_filter (bool): When it is true, allgather would apply. - grad (Tuple): The indices, gradient tensor and tensor_shape before operation. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. + allreduce_filter (bool): When it is true, allreduce would apply. + grad (Tensor): The gradient tensor before operation. + ps_parameter (bool): Use parameter server or not. Returns: - Tuple, include indices, the gradient tensor and tensor_shape after operation. + Tensor, the gradient tensor after operation. """ + if ps_parameter: + return grad + if allreduce_filter: - indices = _all_gather(grad[0]) - degree = F.scalar_cast(degree, F.dtype(grad[1])) - dout = _all_gather(grad[1]) - cast_op = P.Cast() - dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) - grad = (indices, dout, grad[2]) + grad = allreduce(grad) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad)) + cast_op = P.Cast() + mul_op = P.Mul() + grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad))) + return grad return grad -@reduce_opt.register("Bool", "Tensor") -def _tensors_allreduce(allreduce_filter, grad): +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor") +def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): """ - Apply allreduce on gradient. + Apply allgather on gradient instead of allreduce for sparse feature. + Allgather is a communication operation used for distributed deep learning. Args: - allreduce_filter (bool): When it is true, allreduce would apply. - grad (Tensor): The gradient tensor before operation. + degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. + allreduce_filter (bool): When it is true, allgather would apply. + grad (tuple): The indices, gradient tensor and tensor_shape before operation. Returns: - Tensor, the gradient tensor after operation. + RowTensor, the gradient after operation. """ if allreduce_filter: - return _all_reduce(grad) + indices = allgather(grad.indices) + dout = allgather(grad.values) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad.values)) + cast_op = P.Cast() + mul_op = P.Mul() + dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) + grad = RowTensor(indices, dout, grad.dense_shape) return grad -@reduce_opt.register("Bool", "Tuple") -def _tensors_allreduce_with_sparse(allreduce_filter, grad): +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool") +def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): """ - Apply mean and allgather on gradient instead of allreduce for sparse feature. + Apply allgather on gradient instead of allreduce for sparse feature. Allgather is a communication operation used for distributed deep learning. Args: + degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. allreduce_filter (bool): When it is true, allgather would apply. - grad (Tuple): The indices, gradient tensor and tensor_shape before operation. + grad (tuple): The indices, gradient tensor and tensor_shape before operation. + ps_parameter (bool): Use parameter server or not. Returns: - Tuple, include indices, the gradient tensor and tensor_shape after operation. + RowTensor, the gradient after operation. """ + if ps_parameter: + return grad + if allreduce_filter: - indices = _all_gather(grad[0]) - dout = _all_gather(grad[1]) - grad = (indices, dout, grad[2]) + indices = allgather(grad.indices) + dout = allgather(grad.values) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad.values)) + cast_op = P.Cast() + mul_op = P.Mul() + dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) + grad = RowTensor(indices, dout, grad.dense_shape) return grad @@ -135,18 +182,18 @@ def _tensors_get_datatype(grad): return F.dtype(grad) -@_get_datatype.register("Tuple") +@_get_datatype.register("RowTensor") def _tensors_get_datatype_with_sparse(grad): """ Acquire gradient datatype. Args: - grad (Tuple): The gradient tensor before operation. + grad (RowTensor): The gradient before operation. Returns: mstype, the datatype of gradient. """ - return F.dtype(grad[1]) + return F.dtype(grad.values) _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") @@ -167,20 +214,20 @@ def _tensors_cast_datatype(datatype, grad): return F.cast(grad, datatype) -@_cast_datatype.register("TypeType", "Tuple") +@_cast_datatype.register("TypeType", "RowTensor") def _tensors_cast_datatype_with_sparse(datatype, grad): """ Cast gradient to datatype. Args: datatype (mstype): the destination datatype of gradient. - grad (Tuple): The gradient tensor before operation. + grad (RowTensor): The gradient before operation. Returns: - Tuple, the gradient tuple after operation. + RowTensor, the gradient after operation. """ - dout = F.cast(grad[1], datatype) - return (grad[0], dout, grad[2]) + dout = F.cast(grad.values, datatype) + return RowTensor(grad.indices, dout, grad.dense_shape) class DistributedGradReducer(Cell): @@ -259,7 +306,6 @@ class DistributedGradReducer(Cell): def __init__(self, parameters, mean=True, degree=None): super(DistributedGradReducer, self).__init__(auto_prefix=False) self.map_ = C.Map() - self.mul = P.Mul() if degree is None: self.degree = get_group_size() else: @@ -268,7 +314,18 @@ class DistributedGradReducer(Cell): self.degree = degree self.mean = mean self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) - _init_optimizer_communication() + is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") + split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() + if is_parallel_optimizer and split_indices: + self.split_fusion = True + self.op_list = _init_allreduce_operators(len(parameters), split_indices) + else: + self.split_fusion = False + self.allreduce = AllReduce().add_prim_attr('fusion', 1) + self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) + ps_filter = lambda x: x.is_param_ps + self.ps_parameters = tuple(ps_filter(x) for x in parameters) + self.enable_parameter_server = any(self.ps_parameters) def construct(self, grads): """ @@ -284,11 +341,19 @@ class DistributedGradReducer(Cell): """ datatypes = self.map_(F.partial(_get_datatype), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) - - if self.mean: - new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads) + if self.split_fusion: + if self.enable_parameter_server: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), + self.op_list, self.allreduce_filter, grads, self.ps_parameters) + else: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), + self.op_list, self.allreduce_filter, grads) else: - new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads) - + if self.enable_parameter_server: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, + self.allreduce), self.allreduce_filter, grads, self.ps_parameters) + else: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, + self.allreduce), self.allreduce_filter, grads) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) return new_grad diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index a9aa4d781b..08ff30b4b4 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.train.parallel_utils import ParallelMode from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from ..cell import Cell -from ...common import Tensor +from ...common import Tensor, RowTensor from ...common.parameter import Parameter from ...ops import functional as F from ...ops import composite as C @@ -35,6 +35,12 @@ reciprocal = P.Reciprocal() def tensor_grad_scale(scale, grad): return grad * F.cast(reciprocal(scale), F.dtype(grad)) +@_grad_scale.register("Tensor", "RowTensor") +def tensor_grad_scale_row_tensor(scale, grad): + return RowTensor(grad.indices, + grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)), + grad.dense_shape) + _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") grad_overflow = P.FloatStatus() @@ -200,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell): def __init__(self, network, optimizer, scale_update_cell=None): super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network + self.network.set_grad() self.network.add_flags(defer_inline=True) self.weights = optimizer.parameters self.optimizer = optimizer diff --git a/mindspore/ops/__init__.py b/mindspore/ops/__init__.py index 7265b3c98b..aa4c5662e3 100644 --- a/mindspore/ops/__init__.py +++ b/mindspore/ops/__init__.py @@ -24,8 +24,8 @@ Examples: Note: - The Primitive operators in operations need to be used after instantiation. - - The composite operators are pre-defined combination of operator. - - The functional operators are the pre-instantiated Primitive operators, which can be used directly like a function. + - The composite operators are the pre-defined combination of operators. + - The functional operators are the pre-instantiated Primitive operators, which can be used directly as a function. - For functional operators usage, please refer to https://gitee.com/mindspore/mindspore/blob/master/mindspore/ops/functional.py """ diff --git a/mindspore/ops/_grad/__init__.py b/mindspore/ops/_grad/__init__.py index de9e3ae8d0..c9db4094ee 100644 --- a/mindspore/ops/_grad/__init__.py +++ b/mindspore/ops/_grad/__init__.py @@ -15,7 +15,7 @@ """grad impl.""" from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ - grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops + grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse from .grad_base import get_bprop_fn __all__ = ['get_bprop_fn'] diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index b1a3e1d98b..d4079f38e2 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -15,6 +15,8 @@ """array_ops""" +import mindspore as ms +from mindspore.ops import composite as C from .. import operations as P from ..operations import _grad_ops as G from ..operations import _inner_ops as inner @@ -25,6 +27,7 @@ from .grad_base import bprop_getters from ..primitive import constexpr from ... import context from ...common import dtype as mstype +from ...common.tensor import RowTensor reduce_sum = P.ReduceSum() unsorted_segment_sum = P.UnsortedSegmentSum() @@ -34,6 +37,7 @@ reshape = P.Reshape() size_op = P.Size() invert_permutation = P.InvertPermutation() logical_and = P.LogicalAnd() +is_sub_class = P.IsSubClass() @bprop_getters.register(P.Fill) @@ -56,6 +60,29 @@ def get_bprop_dtype(self): return bprop +dout_cast = C.MultitypeFuncGraph("dout_cast") +@dout_cast.register("Tensor", "Tensor") +def dout_cast_tensor(dout, x): + cast = P.Cast() + get_dtype = P.DType() + dx = cast(dout, get_dtype(x)) + return dx + +@dout_cast.register("Number", "Number") +def dout_cast_number(dout, x): + cast = P.Cast() + get_dtype = P.DType() + dx = cast(dout, get_dtype(x)) + return dx + +@dout_cast.register("RowTensor", "Tensor") +def dout_cast_row_tensor(dout, x): + cast = P.Cast() + get_dtype = P.DType() + values = cast(dout.values, get_dtype(x)) + return RowTensor(dout.indices, values, dout.dense_shape) + + @bprop_getters.register(P.Cast) def get_bprop_cast(self): """Generate bprop for Cast""" @@ -66,6 +93,13 @@ def get_bprop_cast(self): dx = cast(dout, get_dtype(x)) return dx, zeros_like(t) + def bprop_sparse(x, t, out, dout): + dx = dout_cast(dout, x) + return dx, zeros_like(t) + + if context.get_context('enable_sparse'): + return bprop_sparse + return bprop @@ -206,29 +240,10 @@ def get_bprop_embedding_lookup(self): actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail # Reshape the 'actual_dout' on device actual_dout = reshape_op(dout, actual_dout_shape_changed) - return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) + return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) return bprop_sparse -@bprop_getters.register(P.EmbeddingLookup) -def get_bprop_embedding_look_up(self): - """Generate bprop for EmbeddingLookup""" - sub_op = P.Sub() - reshape_op = P.Reshape() - def bprop(x, indices, offset, out, dout): - x_shp = shape_op(x) - new_indices = sub_op(indices, offset) - # Reshape the 'new_indices' - new_indices_shape_changed = (size_op(new_indices),) - new_indices = reshape_op(new_indices, new_indices_shape_changed) - actual_dout_shape_changed = new_indices_shape_changed - if len(x_shp) > 1: - actual_dout_shape_changed += x_shp[1:] - actual_dout = reshape_op(dout, actual_dout_shape_changed) - return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) - return bprop - - @bprop_getters.register(P.Transpose) def get_bprop_transpose(self): """Generate bprop for Transpose""" @@ -354,7 +369,7 @@ def get_bprop_sparse_gather_v2(self): values_shape = indices_size + x_tail_shp values = reshape(dout, values_shape) indices = reshape(indices, indices_size) - return (indices, values, x_shp), zeros_like(indices), zeros_like(axis) + return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis) if F.rank(dout) == 0: dout = P.ExpandDims()(dout, -1) if F.rank(indices) == 0: @@ -390,11 +405,28 @@ def get_bprop_pack(self): def bprop(x, out, dout): pack_grad = P.Unpack(axis) out = pack_grad(dout) + if is_sub_class(F.typeof(x), ms.list_): + ret = [] + for item in out: + ret.append(item) + return (ret,) return (out,) return bprop +@bprop_getters.register(P.ReverseV2) +def get_bprop_reverse_v2(self): + """Generate bprop for ReverseV2""" + axis = self.axis + + def bprop(x, out, dout): + reverse_grad = P.ReverseV2(axis) + dx = reverse_grad(dout) + return (dx,) + + return bprop + @bprop_getters.register(P.Unpack) def get_bprop_unpack(self): """Generate bprop for Unpack""" @@ -513,6 +545,16 @@ def get_bprop_scatter_nd_update(self): return bprop +@bprop_getters.register(P.ScatterNonAliasingAdd) +def get_bprop_scatter_non_aliasing_add_update(self): + """Generate bprop for ScatterNonAliasingAdd""" + op = P.GatherNd() + + def bprop(x, indices, update, out, dout): + return dout, zeros_like(indices), op(dout, indices) + + return bprop + @bprop_getters.register(P.TensorScatterUpdate) def get_bprop_tensor_scatter_update(self): """Generate bprop for TensorScatterUpdate""" @@ -527,6 +569,7 @@ def get_bprop_tensor_scatter_update(self): return bprop + @bprop_getters.register(P.ScatterMax) def get_bprop_scatter_max(self): """Generate bprop for ScatterMax""" @@ -630,6 +673,16 @@ def _GatherDropNegatives(params, return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) +@bprop_getters.register(P.UnsortedSegmentSum) +def get_bprop_unsorted_segment_sum(self): + """Generate bprop for UnsortedSegmentSum""" + + def bprop(x, segment_ids, num_segments, out, dout): + return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments) + + return bprop + + @bprop_getters.register(P.UnsortedSegmentMin) def get_bprop_unsorted_segment_min(self): """Generate bprop for UnsortedSegmentMin""" @@ -764,3 +817,13 @@ def get_bprop_trans_shape(self): dx = op(dout, shape_op(x)) return (dx, zeros_like(shape)) return bprop + + +@bprop_getters.register(P.Unique) +def get_bprop_unique(self): + """Generate bprop for Unique""" + op = G.UniqueGrad() + def bprop(x, out, dout): + dx = op(dout, out) + return (dx,) + return bprop diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 34df18beba..fb54eab172 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -17,6 +17,7 @@ import mindspore.common.dtype as mstype from mindspore.ops import functional as F from .. import operations as P +from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, _GetTensorSlice, _MirrorOperator, ReduceOp, @@ -46,9 +47,9 @@ def get_bprop_all_reduce(self): if F.issubclass_(F.typeof(dout), mstype.tensor): dx = all_reduce_grad(dout) else: - indices = all_gather(dout[0]) - grad = all_gather(dout[1]) - dx = (indices, grad, dout[2]) + indices = all_gather(dout.indices) + grad = all_gather(dout.values) + dx = RowTensor(indices, grad, dout.dense_shape) return (dx,) else: @@ -59,12 +60,12 @@ def get_bprop_all_reduce(self): z = cast(z, dtype(dx)) dx = mul(dx, z) else: - indices = all_gather(dout[0]) - grad = all_gather(dout[1]) + indices = all_gather(dout.indices) + grad = all_gather(dout.values) z = equal(x, out) z = cast(z, dtype(grad)) grad = mul(grad, z) - dx = (indices, grad, dout[2]) + dx = RowTensor(indices, grad, dout.dense_shape) return (dx,) return bprop @@ -194,19 +195,19 @@ def get_bprop_mirror_operator(self): num = F.scalar_cast(dev_num, F.dtype(dx)) dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) else: - indices = all_gather(dout[0]) - grad = all_gather(dout[1]) + indices = all_gather(dout.indices) + grad = all_gather(dout.values) float_one = F.scalar_cast(1.0, F.dtype(grad)) num = F.scalar_cast(dev_num, F.dtype(grad)) grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) - dx = (indices, grad, dout[2]) + dx = RowTensor(indices, grad, dout.dense_shape) else: if F.issubclass_(F.typeof(dout), mstype.tensor): dx = all_reduce(dout) else: - indices = all_gather(dout[0]) - grad = all_gather(dout[1]) - dx = (indices, grad, dout[2]) + indices = all_gather(dout.indices) + grad = all_gather(dout.values) + dx = RowTensor(indices, grad, dout.dense_shape) return (dx,) return bprop diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py index 87566b1110..0ada82c847 100644 --- a/mindspore/ops/_grad/grad_implementations.py +++ b/mindspore/ops/_grad/grad_implementations.py @@ -117,6 +117,12 @@ def bprop_tuple_getitem(data, idx, out, dout): return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx) +@bprops.register("list_getitem") +def bprop_list_getitem(data, idx, out, dout): + """Backpropagator for primitive `list_getitem`.""" + return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx) + + @bprops.register("identity") def bprop_identity(x, out, dout): """Backpropagator for primitive `identity`.""" diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 975e918817..aee67c0f9b 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -17,6 +17,7 @@ from functools import reduce import numpy as np +import mindspore as ms from mindspore.ops import _selected_grad_ops as SG from .. import functional as F from .. import operations as P @@ -33,6 +34,7 @@ shape_op = P.Shape() reduce_sum = P.ReduceSum() reshape = P.Reshape() tile = P.Tile() +is_sub_class = P.IsSubClass() def binop_grad_common(x, y, dx, dy): @@ -250,6 +252,21 @@ def get_bprop_div_no_nan(self): return bprop +@bprop_getters.register(P.Xdivy) +def get_bprop_xdivy(self): + """Grad definition for `Xdivy` operation.""" + div_op = P.Xdivy() + + def bprop(x, y, out, dout): + x_dtype = F.dtype(x) + not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype) + bc_x = div_op(not_zero_x, y) * dout + bc_y = div_op(-x, F.square(y)) * dout + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + + @bprop_getters.register(P.Floor) def get_bprop_floor(self): """Grad definition for `floor` operation.""" @@ -282,14 +299,9 @@ def get_bprop_ceil(self): @bprop_getters.register(P.FloorDiv) def get_bprop_floordiv(self): """Grad definition for `FloorDiv` operation.""" - div_op = P.FloorDiv() - neg = P.Neg() - mul_op = P.Mul() def bprop(x, y, out, dout): - bc_x = div_op(dout, y) - bc_y = neg(mul_op(bc_x, out)) - return binop_grad_common(x, y, bc_x, bc_y) + return zeros_like(x), zeros_like(y) return bprop @@ -306,6 +318,29 @@ def get_bprop_floormod(self): return bprop +@bprop_getters.register(P.TruncateDiv) +def get_bprop_truncate_div(self): + """Grad definition for `TruncateDiv` operation.""" + + def bprop(x, y, out, dout): + return zeros_like(x), zeros_like(y) + + return bprop + + +@bprop_getters.register(P.TruncateMod) +def get_bprop_truncate_mod(self): + """Grad definition for `TruncateMod` operation.""" + div_op = P.TruncateDiv() + + def bprop(x, y, out, dout): + bc_x = dout + bc_y = -dout * div_op(x, y) + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + + @bprop_getters.register(P.Mod) def get_bprop_mod(self): """Grad definition for `Mod` operation.""" @@ -333,18 +368,59 @@ def get_bprop_square(self): return bprop -@bprop_getters.register(P.Sqrt) -def get_bprop_sqrt(self): - """Grad definition for `Sqrt` operation.""" +@bprop_getters.register(P.SquaredDifference) +def get_bprop_squared_difference(self): + """Grad definition for `SquaredDifference` operation.""" + neg = P.Neg() + + def bprop(x, y, out, dout): + x_grad = 2 * dout * (x - y) + bc_x = x_grad + bc_y = neg(x_grad) + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + + +@bprop_getters.register(P.Xlogy) +def get_bprop_xlogy(self): + """Grad definition for `Xlogy` operation.""" + log_op = P.Xlogy() + div_op = P.Xdivy() + + def bprop(x, y, out, dout): + x_dtype = F.dtype(x) + not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype) + bc_x = log_op(not_zero_x, y) * dout + bc_y = div_op(x, y) * dout + return binop_grad_common(x, y, bc_x, bc_y) + + return bprop + +@bprop_getters.register(P.SquareSumAll) +def get_bprop_square_sum_all(self): + """Grad definition for `Square` operation.""" mul_func = P.Mul() fill_func = P.Fill() - div_op = P.RealDiv() - sqrt = P.Sqrt() dtype = P.DType() + def bprop(x, y, out, dout): + temp_x = mul_func(dout[0], x) + temp_y = mul_func(dout[1], y) + dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x) + dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y) + return (dx, dy) + + return bprop + + +@bprop_getters.register(P.Sqrt) +def get_bprop_sqrt(self): + """Grad definition for `Sqrt` operation.""" + sqrt_grad = G.SqrtGrad() + def bprop(x, out, dout): - temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x)) - dx = mul_func(dout, temp) + dx = sqrt_grad(out, dout) return (dx,) return bprop @@ -353,10 +429,10 @@ def get_bprop_sqrt(self): @bprop_getters.register(P.Rsqrt) def get_bprop_rsqrt(self): """Grad definition for `Rsqrt` operation.""" + rsqrt_grad = G.RsqrtGrad() def bprop(x, out, dout): - grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x) - dx = dout * grad + dx = rsqrt_grad(out, dout) return (dx,) return bprop @@ -365,15 +441,22 @@ def get_bprop_rsqrt(self): @bprop_getters.register(P.Reciprocal) def get_bprop_reciprocal(self): """Grad definition for `Reciprocal` operation.""" - neg = P.Neg() - mul = P.Mul() - square = P.Square() - reciprocal = P.Reciprocal() + if self.target == "GPU": + neg = P.Neg() + mul = P.Mul() + square = P.Square() + reciprocal = P.Reciprocal() + + def bprop(x, out, dout): + g = neg(reciprocal(square(x))) + dx = mul(dout, g) + return (dx,) + else: + reciprocal_grad = G.ReciprocalGrad() - def bprop(x, out, dout): - g = neg(reciprocal(square(x))) - dx = mul(dout, g) - return (dx,) + def bprop(x, out, dout): + dx = reciprocal_grad(out, dout) + return (dx,) return bprop @@ -611,6 +694,16 @@ def get_bprop_reduceall(self): return bprop +@bprop_getters.register(P.ReduceAny) +def get_bprop_reduceany(self): + """Grad definition for `ReduceAny` operation.""" + + def bprop(x, axis, out, dout): + return zeros_like(x), zeros_like(axis) + + return bprop + + @bprop_getters.register(P.ReduceMax) def get_bprop_reducemax(self): """Grad definition for `Max` operation.""" @@ -952,7 +1045,8 @@ def get_bprop_scalar_accumulatenv2(self): dx = () for _ in range(len(x)): dx = dx + (dout,) - return dx + return (dx,) + return bprop @@ -961,10 +1055,16 @@ def get_bprop_scalar_addn(self): """Generate bprop for AddN""" def bprop(x, out, dout): + if is_sub_class(F.typeof(x), ms.list_): + dx = [] + for _ in range(len(x)): + dx.append(dout) + return (dx,) + dx = () for _ in range(len(x)): dx = dx + (dout,) - return dx + return (dx,) return bprop @@ -1027,6 +1127,22 @@ def get_bprop_atan(self): return bprop +@bprop_getters.register(P.Tan) +def get_bprop_tan(self): + """Grad definition for `Tan` operation.""" + reciprocal = P.Reciprocal() + square = P.Square() + cos = P.Cos() + + def bprop(x, out, dout): + cosx = cos(x) + secx2 = square(reciprocal(cosx)) + dx = secx2 * dout + return (dx,) + + return bprop + + @bprop_getters.register(P.BesselI1e) def get_bprop_bessel_i1e(self): """Generate bprop for BesselI1e""" diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 61c7e40960..8f4cf8496d 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -14,7 +14,10 @@ # ============================================================================ """Define the grad rules of neural network related operations.""" +import numpy as np from mindspore.ops import _selected_grad_ops as SG +from mindspore.ops.primitive import constexpr +from mindspore.common.tensor import Tensor from .grad_base import bprop_getters from .. import functional as F from .. import operations as P @@ -24,7 +27,6 @@ from ..operations import _inner_ops as inner from ... import context - @bprop_getters.register(P.BiasAdd) def get_bprop_bias_add(self): """Grad definition for `BiasAdd` operation.""" @@ -195,33 +197,133 @@ def get_bprop_max_pool_grad(self): return bprop +def _windowed_output_size(input_size, ksize, stride, padding): + """ + helper func for AvgPoolGrad + """ + + tmp_output = 0 + tmp_pad_need = 0 + tmp_pad_before = 0 + tmp_pad_after = 0 + if padding == 'VALID': + tmp_output = (input_size - ksize + stride) // stride + tmp_pad_before = 0 + tmp_pad_after = 0 + elif padding == 'SAME': + tmp_output = (input_size + stride - 1) // stride + tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size) + tmp_pad_before = tmp_pad_need // 2 + tmp_pad_after = tmp_pad_need - tmp_pad_before + return tmp_output, tmp_pad_before, tmp_pad_after + + +@constexpr +def _get_mean_matrix(x_shape, ksize, stride, padding, x_dtype): + """ + helper func for AvgPoolGrad. + + `assist_input_matrix` is a 2d matrix with input_shape after padding, + the value of element which is padded is 0, else are 1. + For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize, + w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the + number of input that assosiate with output element. + """ + + n_input, c_input, h_input, w_input = x_shape + h_ksize, w_ksize = ksize[2], ksize[3] + h_stride, w_stride = stride[2], stride[3] + n_output = n_input + c_output = c_input + h_output, w_output = 0, 0 + pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 + h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize, + h_stride, padding) + w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize, + w_stride, padding) + + output_size = n_output * c_output * h_output * w_output + output_shape = (n_output, c_output, h_output, w_output) + output = np.array([0.0] * output_size) + output = np.reshape(output, output_shape) + + in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right) + assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32) + if pad_top > 0: + assist_input_matrix[:pad_top, :] = 0 + if pad_bottom > 0: + assist_input_matrix[-pad_bottom:, :] = 0 + if pad_left > 0: + assist_input_matrix[:, :pad_left] = 0 + if pad_right > 0: + assist_input_matrix[:, -pad_right:] = 0 + + for h in range(h_output): + for w in range(w_output): + curr_input = assist_input_matrix[h*h_stride : h*h_stride + h_ksize, w*w_stride : w*w_stride + w_ksize] + curr_sum = np.sum(curr_input) + if curr_sum > 0: + output[:, :, h, w] = 1. / curr_sum + return Tensor(output, x_dtype) + + +@constexpr +def _get_kernel_matrix(kernel_matrix_shape, x_dtype): + kernel_matrix = np.ones(kernel_matrix_shape) + return Tensor(kernel_matrix, x_dtype) + + @bprop_getters.register(P.AvgPool) def get_bprop_avg_pool_grad(self): """Grad definition for `AvgPool` operation.""" - avgpool_grad = G.AvgPoolGrad( - ksize=self.ksize, - strides=self.strides, - padding=self.padding) - shape_op = P.Shape() - - avgpool_grad_gpu = G.AvgPoolGradGpu( - ksize=self.ksize, - strides=self.strides, - padding=self.padding) - - def bprop(x, out, dout): - dx = avgpool_grad(shape_op(x), dout) - return (dx,) - - def bprop_gpu(x, out, dout): - dx = avgpool_grad_gpu(x, out, dout) - return (dx,) # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same if self.target == "GPU": + avgpool_grad_gpu = G.AvgPoolGradGpu( + ksize=self.ksize, + strides=self.strides, + padding=self.padding) + + def bprop_gpu(x, out, dout): + dx = avgpool_grad_gpu(x, out, dout) + return (dx,) + bprop_fn = bprop_gpu + + elif self.target == "GE": + avgpool_grad_ge = G.AvgPoolGrad( + ksize=self.ksize, + strides=self.strides, + padding=self.padding) + shape_op = P.Shape() + + def bprop_ge(x, out, dout): + dx = avgpool_grad_ge(shape_op(x), dout) + return (dx,) + + bprop_fn = bprop_ge + else: - bprop_fn = bprop + avgpool_grad_vm = G.AvgPoolGradVm( + ksize=self.ksize, + strides=self.strides, + padding=self.padding) + k_size_nchw = avgpool_grad_vm.ksize + stride_nchw = avgpool_grad_vm.strides + padding = self.padding + + def bprop_vm(x, out, dout): + x_shape_nchw = F.shape(x) + x_dtype = F.dtype(x) + kernel_matrix_shape = (1, x_shape_nchw[1], + k_size_nchw[2], + k_size_nchw[3]) + mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, padding, x_dtype) + kernel_matrix = _get_kernel_matrix(kernel_matrix_shape, x_dtype) + dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix) + return (dx,) + + bprop_fn = bprop_vm return bprop_fn @@ -339,6 +441,7 @@ def get_bprop_softmax(self): sub = P.Sub() mul = P.Mul() axis = self.axis + def bprop(x, out, dout): dx = mul(out, sub(dout, sum_func(mul(out, dout), axis))) return (dx,) @@ -526,19 +629,59 @@ def get_bprop_onehot(self): return bprop +@constexpr +def _range_op(start, limit, delta, dtype): + """helper function for Grad TopK""" + output_tensor = Tensor(list(range(start, limit, delta)), dtype) + return output_tensor + +@constexpr +def _get_1d_shape(in_shape): + """helper function for Grad TopK""" + out_shape = 1 + for i in in_shape: + out_shape *= i + return (out_shape,) + @bprop_getters.register(P.TopK) def get_bprop_top_kv2(self): """Grad definition for `TopK` operation.""" scatter = P.ScatterNd() expand_dims = P.ExpandDims() shape_op = P.Shape() + reshape_op = P.Reshape() + dtype = P.DType() def bprop(input_x, k, out, dout): + + # (n1, n2, ...., n_p), in_lastdim = n_p + in_shape = shape_op(input_x) + in_lastdim = in_shape[-1] + + # (n_1, ... n_(p-1), k), ind_lastdim = k indices = out[1] - indices = expand_dims(indices, -1) - updates = dout[0] - shapes = shape_op(input_x) - return scatter(indices, updates, shapes), zeros_like(k) + ind_shape = shape_op(indices) + ind_lastdim = ind_shape[-1] + + # (n_1*n_2..*n_(p-1), k), outerdim = n_1*n_2..*n_(p-1) + ind_2d = reshape_op(indices, (-1, ind_lastdim)) + outerdim = shape_op(ind_2d)[0] + + # [0, outterdim, 2*outerdim, ..., (k-1)*outerdim] + indices_dtype = dtype(indices) + range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype) + + # expand_dims to (k, 1), then broadcast + ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,)) + in_shape_1d = _get_1d_shape(in_shape) + + out_grad = reshape_op( + scatter( + expand_dims(ind, -1), + reshape_op(dout[0], (-1,)), + in_shape_1d), + in_shape) + return out_grad, zeros_like(k) return bprop @@ -732,6 +875,17 @@ def get_bprop_binary_cross_entropy(self): return bprop +@bprop_getters.register(P.KLDivLoss) +def get_bprop_kl_div_loss(self): + """Grad definition for `KLDivLoss` operation.""" + grad = G.KLDivLossGrad(self.reduction) + + def bprop(x, y, out, dout): + dx, dy = grad(x, y, dout) + return dx, dy + + return bprop + @bprop_getters.register(P.Dropout) def get_bprop_dropout(self): diff --git a/mindspore/ops/_grad/grad_sparse.py b/mindspore/ops/_grad/grad_sparse.py new file mode 100644 index 0000000000..03b68cc40a --- /dev/null +++ b/mindspore/ops/_grad/grad_sparse.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""bprop primitives""" +from .. import functional as F +from .. import operations as P +from ..composite.multitype_ops.zeros_like_impl import zeros_like +from .grad_base import bprops, bprop_getters + +# Unused parameters are placeholders. + + +@bprops.register("MakeSparseTensor") +def bprop_make_sparse_tensor(indices, values, dense_shape, out, dout): + """Backpropagator for primitive `MakeSparseTensor`.""" + return zeros_like(indices), F.sparse_tensor_get_values(dout), () + + +@bprops.register("SparseTensorGetIndices") +def bprop_sparse_tensor_get_indices(sparse_tensor, out, dout): + """Backpropagator for primitive `SparseTensorGetIndices`.""" + return (zeros_like(sparse_tensor),) + + +@bprops.register("SparseTensorGetValues") +def bprop_sparse_tensor_get_values(sparse_tensor, out, dout): + """Backpropagator for primitive `SparseTensorGetValues`.""" + return F.make_sparse_tensor(F.sparse_tensor_get_indices(sparse_tensor), + dout, + F.sparse_tensor_get_dense_shape(sparse_tensor)) + + +@bprops.register("SparseTensorGetDenseShape") +def bprop_sparse_tensor_get_dense_shape(sparse_tensor, out, dout): + """Backpropagator for primitive `SparseTensorGetDenseShape`.""" + return (zeros_like(sparse_tensor),) + + +@bprop_getters.register(P.SparseToDense) +def get_bprop_sparse_to_dense(self): + """Generate bprop for SparseToDense""" + + def bprop(indices, values, dense_shape, out, dout): + return zeros_like(indices), dout, zeros_like(dense_shape) + + return bprop diff --git a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py index c512989906..0421de2dab 100644 --- a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +++ b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py @@ -14,10 +14,11 @@ # ============================================================================ """batch_matmul_impl""" -from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from te import tik from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \ .fusion_type("OPAQUE") \ .async_flag(False) \ @@ -114,7 +115,8 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr ((1, 64, 64), (1, 64, 64), "float32", False, True), ((1, 128, 128), (1, 128, 128), "float32", False, True), ((4, 128, 128), (4, 128, 128), "float32", False, True), - ((2, 128, 128), (2, 128, 128), "float32", False, True)] + ((2, 128, 128), (2, 128, 128), "float32", False, True), + ((32, 128, 128), (32, 128, 128), 'float32', False, True)] if input_shape not in support_shape: raise RuntimeError("input_shape %s is not supported" % str(input_shape)) @@ -232,7 +234,8 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr ((2, 128, 128), (2, 128, 128), "float32", False, True), ((4, 128, 128), (4, 128, 128), "float32", False, True), ((8, 128, 128), (8, 128, 128), "float32", False, True), - ((16, 128, 128), (16, 128, 128), "float32", False, True) + ((16, 128, 128), (16, 128, 128), "float32", False, True), + ((32, 128, 128), (32, 128, 128), 'float32', False, True) ] if input_shape in input_shape_list: block_num = 32 diff --git a/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py index f4b8d44063..23fdd1a2a7 100644 --- a/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +++ b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py @@ -13,10 +13,11 @@ # limitations under the License. # ============================================================================ """CusFusedAbsMax1""" -from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from te import tik from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \ .fusion_type("OPAQUE") \ .async_flag(False) \ @@ -36,152 +37,86 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m """CusFusedAbsMax1""" input_x_shape = input_x.get("shape") output_shape = output.get("shape") + dtype = input_x.get("dtype") if util.get_product_version() == util.VERSION_MINI: tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) else: tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) - if len(input_x_shape) > 2: - if (input_x_shape[0] == 1 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 4 and input_x_shape[1] == 16) or (input_x_shape[0] == 16 and input_x_shape[1] == 4): - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val + support_shape = [((1, 128, 128), "float32"), + ((2, 128, 128), "float32"), + ((4, 128, 128), "float32"), + ((8, 128, 128), "float32"), + ((16, 128, 128), "float32"), + ((5, 128, 128), "float32"), + ((9, 128, 128), "float32"), + ((18, 128, 128), "float32"), + ((36, 128, 128), "float32"), + ((32, 128, 128), "float32"), + ((1, 64, 64), "float32"), + ((32, 64), "float32") + ] + ori_shape = tuple(origin_shape) + input_info = (tuple(input_x_shape), dtype) + if input_info not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_info)) + if input_info == ((1, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((2, 128, 128), "float32"): + if ori_shape == (147, 147): + phase_1 = 16384 + phase_2 = 1216 blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 2 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 16 and input_x_shape[1] == 8): - if origin_shape[0] == 147 and ( - input_x_shape[0] == 2 and input_x_shape[1] == 128 and input_x_shape[2] == 128): - assert origin_shape[0] == 147 - assert origin_shape[1] == 147 - phase_1 = 16384 - phase_2 = 1216 - blocks = 32 - each_block_element = phase_1 // blocks + 64 - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[512 * block_index], 0, 1, 512 // 8, 0, 0) - line_id = block_index % 19 - tik_instance.data_move(input_x_ub[512], input_x[16384 + 128 * line_id], 0, 1, 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(19, input_x_ub, input_x_ub, input_x_ub[512], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, - 1, 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - else: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, - 1, 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 4 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 8 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 8): + each_block_element = phase_1 // blocks + 64 input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", scope=tik.scope_ubuf) broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) + tik_instance.data_move(input_x_ub, input_x[512 * block_index], 0, 1, 512 // 8, 0, 0) + line_id = block_index % 19 + tik_instance.data_move(input_x_ub[512], input_x[16384 + 128 * line_id], 0, 1, 8, 0, 0) repeat_time = each_block_element // 64 tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(19, input_x_ub, input_x_ub, input_x_ub[512], 1, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) @@ -189,169 +124,20 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m data_temp = tik_instance.Scalar("float32") data_temp.set_as(input_x_ub[cc0]) tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 32 and input_x_shape[1] == 16) or ( - input_x_shape[0] == 16 and input_x_shape[1] == 32): - if (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ - 0] == 1000: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - blocks = 32 - each_block_element = 7 * 128 * 128 // 32 + 4 * 128 - phase_1 = 7 * 128 * 128 // 32 - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) - tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, - 0) - move_idx = block_index % 8 - tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, - 128 // 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - vmask = 1000 - 7 * 128 - 64 - with tik_instance.for_range(0, 4) as loop_idx: - tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], - input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - - with tik_instance.for_range(0, 4) as loop_idx: - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, - 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, - 1, 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - - elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ - 0] == 1001: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - blocks = 32 - each_block_element = 7 * 128 * 128 // 32 + 4 * 128 - phase_1 = 7 * 128 * 128 // 32 - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) - tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, - 0) - tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, - 0) - move_idx = block_index % 9 - tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, - 128 // 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - vmask = 1001 - 7 * 128 - 64 - with tik_instance.for_range(0, 4) as loop_idx: - tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], - input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 4) as loop_idx: - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, - 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, - 1, 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - else: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, - 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, - 1, 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 16 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 16 and input_x_shape[1] == 64) or ( - input_x_shape[0] == 64 and input_x_shape[1] == 16): + elif ori_shape in ((256, 256), None, (-1, -1)): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) total_elements = 1 @@ -368,9 +154,6 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m each_block_element // 8, 0, 0) repeat_time = each_block_element // 64 tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) @@ -379,108 +162,163 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m data_temp = tik_instance.Scalar("float32") data_temp.set_as(input_x_ub[cc0]) tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 5 and input_x_shape[1] == 128 and input_x_shape[2] == 128 and origin_shape[0] == 576: + else: + raise RuntimeError("origin shape %s is not supported" % str(ori_shape)) + elif input_info == ((4, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((8, 128, 128), "float32"): + if ori_shape == (1000, 1000): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 69632 blocks = 32 - each_block_element = total_elements // blocks - phase_1 = 2048 - phase_2 = 128 + each_block_element = 7 * 128 * 128 // 32 + 4 * 128 + phase_1 = 7 * 128 * 128 // 32 with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", scope=tik.scope_ubuf) broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", scope=tik.scope_ubuf) tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) - tik_instance.data_move(input_x_ub[phase_1], input_x[65536 + phase_2 * block_index * 2], 0, 1, 8, 0, 0) - tik_instance.data_move(input_x_ub[phase_1 + 64], input_x[65536 + 128 + phase_2 * block_index * 2], 0, 1, - 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + move_idx = block_index % 8 + tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, + 128 // 8, 0, 0) repeat_time = each_block_element // 64 tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub[2048], input_x_ub[2048], input_x_ub[2048 + 64], 1, 1, 1, 1, 8, 8, 8) + vmask = 1000 - 7 * 128 - 64 + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], + input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, + 8, 8) with tik_instance.for_range(0, 64) as cc0: data_temp = tik_instance.Scalar("float32") data_temp.set_as(input_x_ub[cc0]) tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 9 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 72 and input_x_shape[1] == 8): + elif ori_shape == (1001, 1001): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val blocks = 32 - each_block_element = total_elements // blocks + each_block_element = 7 * 128 * 128 // 32 + 4 * 128 + phase_1 = 7 * 128 * 128 // 32 with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", scope=tik.scope_ubuf) broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + move_idx = block_index % 9 + tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, + 128 // 8, 0, 0) repeat_time = each_block_element // 64 tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + vmask = 1001 - 7 * 128 - 64 + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], + input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, + 8, 8) with tik_instance.for_range(0, 64) as cc0: data_temp = tik_instance.Scalar("float32") data_temp.set_as(input_x_ub[cc0]) tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 18 and input_x_shape[1] == 128 and input_x_shape[2] == 128: + elif ori_shape in ((1024, 1024), None, (-1, -1)): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) total_elements = 1 @@ -497,390 +335,6 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m each_block_element // 8, 0, 0) repeat_time = each_block_element // 64 tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 36 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( - input_x_shape[0] == 144 and input_x_shape[1] == 16): - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time_1 = 255 - repeat_time_2 = each_block_element // 64 - 255 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, - 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 1024], 16, 1, 1, 1, 8, 8, - 8) - tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 512], 8, 1, 1, 1, 8, 8, - 8) - tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 256], 4, 1, 1, 1, 8, 8, - 8) - tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 128], 2, 1, 1, 1, 8, 8, - 8) - tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 128 and input_x_shape[1] == 63: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time_1 = 255 - repeat_time_2 = each_block_element // 64 - 255 * 3 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, - 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], - repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 3 * 64], input_x_ub[repeat_time_1 * 3 * 64], - repeat_time_2, 1, 1, 8, 8) - loop_size = each_block_element // 16384 - with tik_instance.for_range(0, loop_size) as loop_idx: - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, loop_size - 1) as loop_idx: - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * (loop_idx + 1)], 1, 1, 1, 1, 8, 8, - 8) - tail_element = each_block_element - 16384 * loop_size - repeats = tail_element // 64 - with tik_instance.for_range(0, repeats) as i: - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * loop_size + i * 64], 1, 1, 1, 1, 8, - 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, input_x_ub[64 + cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[2048 + 64], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[1024 + 64], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[512 + 64], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[256 + 64], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[128 + 64], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[64 + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.data_move(res[block_index, 0], input_x_ub[64], 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 32 and input_x_shape[1] == 128) or ( - input_x_shape[0] == 128 and input_x_shape[1] == 32): - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time_1 = 255 - repeat_time_2 = each_block_element // 64 - 255 * 2 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, - 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], - repeat_time_2, 1, 1, 8, 8) - loop_size = each_block_element // 16384 - with tik_instance.for_range(0, loop_size) as loop_idx: - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], - input_x_ub[16384 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, loop_size - 1) as loop_idx: - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * (loop_idx + 1)], 1, 1, 1, 1, 8, 8, - 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 288 and input_x_shape[1] == 32: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - assist_ub = tik_instance.Tensor("float32", (64,), name="assist_ub", scope=tik.scope_ubuf) - zero = tik_instance.Scalar("float32") - zero.set_as(0) - tik_instance.vector_dup(64, assist_ub, zero, 1, 1, 8) - input_x_ub = tik_instance.Tensor("float32", (32768,), name="input_x_ub", scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - repeat_time_1 = 255 - repeat_time_2 = 32768 // 64 - 255 * 2 - - tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 0], 0, 1, 4096, 0, 0) - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, - 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], - repeat_time_2, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) - - tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 32768], 0, 1, 4096, 0, - 0) - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, - 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], - repeat_time_2, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) - tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 65536], 0, 1, 1024, 0, - 0) - tik_instance.vabs(64, input_x_ub, input_x_ub, 128, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) - - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(assist_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 64 and input_x_shape[1] == 128: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - assist_ub = tik_instance.Tensor("float32", (64,), name="assist_ub", scope=tik.scope_ubuf) - zero = tik_instance.Scalar("float32") - zero.set_as(0) - tik_instance.vector_dup(64, assist_ub, zero, 1, 1, 8) - input_x_ub = tik_instance.Tensor("float32", (32768,), name="input_x_ub", scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - repeat_time_1 = 255 - repeat_time_2 = 32768 // 64 - 255 * 2 - - tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 0], 0, 1, 4096, 0, 0) - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, - 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], - repeat_time_2, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) - - tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 32768], 0, 1, 4096, 0, - 0) - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, - 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], - repeat_time_2, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) - - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(assist_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif (input_x_shape[0] == 64 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 64): - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time_1 = 255 - repeat_time_2 = each_block_element // 64 - 255 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) - tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, - 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) @@ -891,159 +345,131 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m data_temp = tik_instance.Scalar("float32") data_temp.set_as(input_x_ub[cc0]) tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 36 and input_x_shape[1] == 4: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 4 and input_x_shape[1] == 4: + else: + raise RuntimeError("origin shape %s is not supported" % str(ori_shape)) + elif input_info == ((16, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((32, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, 255, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[255 * 64], input_x_ub[255 * 64], 1, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((5, 128, 128), "float32"): + if ori_shape == (576, 576): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val + total_elements = 69632 blocks = 32 each_block_element = total_elements // blocks + phase_1 = 2048 + phase_2 = 128 with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", scope=tik.scope_ubuf) broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[65536 + phase_2 * block_index * 2], 0, 1, 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1 + 64], input_x[65536 + 128 + phase_2 * block_index * 2], 0, 1, + 8, 0, 0) repeat_time = each_block_element // 64 tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 49 and input_x_shape[1] == 4: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, 24, 1, 1, 8, 8) - tik_instance.vabs(32, input_x_ub[1536], input_x_ub[1536], 1, 1, 1, 8, 8) - tik_instance.vmax(32, input_x_ub[1504], input_x_ub[1504], input_x_ub[1536], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[2048], input_x_ub[2048], input_x_ub[2048 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 256], 4, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 128], 2, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 64], 1, 1, 1, 1, 8, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 1, 1, 1, 1, 8, 8, 8) - with tik_instance.for_range(0, 64) as cc0: - data_temp = tik_instance.Scalar("float32") - data_temp.set_as(input_x_ub[cc0]) - tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, - 1, 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, - 8, 8, 8) - tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, - 8, 8, 8) - tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - elif input_x_shape[0] == 1 and input_x_shape[1] == 64 and input_x_shape[2] == 64: - input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) - res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) - total_elements = 1 - for val in input_x_shape: - total_elements *= val - blocks = 32 - each_block_element = total_elements // blocks - with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: - input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", - scope=tik.scope_ubuf) - broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", - scope=tik.scope_ubuf) - tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, - each_block_element // 8, 0, 0) - repeat_time = each_block_element // 64 - tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) - tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 1, 1, 1, 1, 8, 8, 8) with tik_instance.for_range(0, 64) as cc0: data_temp = tik_instance.Scalar("float32") data_temp.set_as(input_x_ub[cc0]) @@ -1061,10 +487,189 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, 8, 8, 8) tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) - else: - raise RuntimeError("UnSupportedShape") - elif len(input_x_shape) == 2 and (input_x_shape[0] == 32 and input_x_shape[1] == 64): + raise RuntimeError("origin shape %s is not supported" % str(ori_shape)) + elif input_info == ((9, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((18, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((36, 128, 128), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, + 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 1024], 16, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 512], 8, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 256], 4, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 128], 2, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((1, 64, 64), "float32"): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_info == ((32, 64), "float32"): input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) input_x_ub = tik_instance.Tensor("float32", (32 * 64,), name="input_x_ub", scope=tik.scope_ubuf) diff --git a/mindspore/ops/_op_impl/akg/ascend/__init__.py b/mindspore/ops/_op_impl/akg/ascend/__init__.py index a4d7aec7d0..d1029e0bcb 100644 --- a/mindspore/ops/_op_impl/akg/ascend/__init__.py +++ b/mindspore/ops/_op_impl/akg/ascend/__init__.py @@ -14,17 +14,34 @@ """__init__""" +from .abs import _abs_akg from .add import _add_akg +from .add_n import _addn_akg from .batchmatmul import _batchmatmul_akg from .cast import _cast_akg +from .equal import _equal_akg +from .exp import _exp_akg from .expand_dims import _expand_dims_akg from .greater import _greater_akg +from .greater_equal import _greater_equal_akg from .inplace_assign import _inplace_assign_akg +from .less import _less_akg +from .less_equal import _less_equal_akg +from .log import _log_akg from .maximum import _maximum_akg from .minimum import _minimum_akg from .mul import _mul_akg +from .neg import _neg_akg +from .pow import _power_akg from .real_div import _real_div_akg +from .reciprocal import _reciprocal_akg +from .reduce_max import _reduce_max_akg +from .reduce_min import _reduce_min_akg +from .reduce_sum import _reduce_sum_akg from .rsqrt import _rsqrt_akg from .select import _select_akg from .sqrt import _sqrt_akg +from .square import _square_akg from .sub import _sub_akg + +# Please insert op register in lexicographical order of the filename. diff --git a/mindspore/ops/_op_impl/akg/ascend/abs.py b/mindspore/ops/_op_impl/akg/ascend/abs.py new file mode 100644 index 0000000000..4128d8bc24 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/abs.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Abs op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Abs") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _abs_akg(): + """Abs Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/add_n.py b/mindspore/ops/_op_impl/akg/ascend/add_n.py new file mode 100644 index 0000000000..efd1fa474b --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/add_n.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""AddN op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("AddN") \ + .fusion_type("ELEMWISE") \ + .input(0, "inputs", "dynamic") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracZ, DT.F16_FracZ) \ + .dtype_format(DT.F32_FracZ, DT.F32_FracZ) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _addn_akg(): + """AddN Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/equal.py b/mindspore/ops/_op_impl/akg/ascend/equal.py new file mode 100644 index 0000000000..bda0c700dc --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/equal.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Equal op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Equal") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.BOOL_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.BOOL_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.BOOL_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _equal_akg(): + """Equal Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/exp.py b/mindspore/ops/_op_impl/akg/ascend/exp.py new file mode 100644 index 0000000000..bcea4b16da --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/exp.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Exp op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Exp") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _exp_akg(): + """Exp Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/greater_equal.py b/mindspore/ops/_op_impl/akg/ascend/greater_equal.py new file mode 100644 index 0000000000..a5d63ba660 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/greater_equal.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""GreaterEqual op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("GreaterEqual") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.BOOL_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.BOOL_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.BOOL_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _greater_equal_akg(): + """Equal Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/less.py b/mindspore/ops/_op_impl/akg/ascend/less.py new file mode 100644 index 0000000000..33699d07af --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/less.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Less op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Less") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _less_akg(): + """Less Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/less_equal.py b/mindspore/ops/_op_impl/akg/ascend/less_equal.py new file mode 100644 index 0000000000..75e1b9a42a --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/less_equal.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LessEqual op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("LessEqual") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.BOOL_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.BOOL_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.BOOL_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _less_equal_akg(): + """Equal Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/log.py b/mindspore/ops/_op_impl/akg/ascend/log.py new file mode 100644 index 0000000000..0a6b04df87 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/log.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Log op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Log") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + +@op_info_register(op_info) +def _log_akg(): + """Log Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/neg.py b/mindspore/ops/_op_impl/akg/ascend/neg.py new file mode 100644 index 0000000000..dc1a7c7fcc --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/neg.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Neg op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Neg") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ) \ + .get_op_info() + +@op_info_register(op_info) +def _neg_akg(): + """Neg Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/pow.py b/mindspore/ops/_op_impl/akg/ascend/pow.py new file mode 100644 index 0000000000..b6c1fdc0a6 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/pow.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Pow op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Pow") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "power") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _power_akg(): + """Pow Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/reciprocal.py b/mindspore/ops/_op_impl/akg/ascend/reciprocal.py new file mode 100644 index 0000000000..0d3823148a --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/reciprocal.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Reciprocal op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Reciprocal") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _reciprocal_akg(): + """Reciprocal Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/reduce_max.py b/mindspore/ops/_op_impl/akg/ascend/reduce_max.py new file mode 100644 index 0000000000..a21fd33494 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/reduce_max.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReduceMax op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("ReduceMax") \ + .fusion_type("COMMREDUCE") \ + .input(0, "x") \ + .output(0, "output") \ + .attr("axis", "required", "listInt") \ + .attr("keep_dims", "required", "bool") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _reduce_max_akg(): + """ReduceMax Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/reduce_min.py b/mindspore/ops/_op_impl/akg/ascend/reduce_min.py new file mode 100644 index 0000000000..cc42c7bd26 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/reduce_min.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReduceMin op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("ReduceMin") \ + .fusion_type("COMMREDUCE") \ + .input(0, "x") \ + .output(0, "output") \ + .attr("axis", "required", "listInt") \ + .attr("keep_dims", "required", "bool") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .get_op_info() + +@op_info_register(op_info) +def _reduce_min_akg(): + """ReduceMin Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/reduce_sum.py b/mindspore/ops/_op_impl/akg/ascend/reduce_sum.py new file mode 100644 index 0000000000..cf5ab6eaba --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/reduce_sum.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReduceSum op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("ReduceSum") \ + .fusion_type("COMMREDUCE") \ + .input(0, "x") \ + .output(0, "output") \ + .attr("axis", "required", "listInt") \ + .attr("keep_dims", "required", "bool") \ + .attr("atomic_add", "optional", "str") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + +@op_info_register(op_info) +def _reduce_sum_akg(): + """ReduceSum Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/square.py b/mindspore/ops/_op_impl/akg/ascend/square.py new file mode 100644 index 0000000000..60853bc0ee --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/square.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Square op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Square") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _square_akg(): + """Square Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/cast.py b/mindspore/ops/_op_impl/akg/gpu/cast.py index c8aef249cd..68c280f348 100644 --- a/mindspore/ops/_op_impl/akg/gpu/cast.py +++ b/mindspore/ops/_op_impl/akg/gpu/cast.py @@ -21,10 +21,42 @@ cast_op_info = AkgGpuRegOp("Cast") \ .output(0, "output") \ .attr("dst_type", "required", "str") \ .dtype_format(DataType.F16_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F64_Default) \ + .dtype_format(DataType.I32_Default, DataType.F16_Default) \ .dtype_format(DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.F64_Default) \ + .dtype_format(DataType.I8_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.F16_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I16_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default) \ .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.F16_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I8_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I16_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.F32_Default) \ + .dtype_format(DataType.U8_Default, DataType.F16_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default) \ + .dtype_format(DataType.I16_Default, DataType.F64_Default) \ + .dtype_format(DataType.I16_Default, DataType.F32_Default) \ + .dtype_format(DataType.I16_Default, DataType.F16_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default) \ + .dtype_format(DataType.I64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.I16_Default, DataType.F32_Default) \ + .dtype_format(DataType.I16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.F16_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/akg/gpu/equal.py b/mindspore/ops/_op_impl/akg/gpu/equal.py index 40a3590f61..c63988f202 100644 --- a/mindspore/ops/_op_impl/akg/gpu/equal.py +++ b/mindspore/ops/_op_impl/akg/gpu/equal.py @@ -22,6 +22,7 @@ equal_op_info = AkgGpuRegOp("Equal") \ .output(0, "output") \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 317509b5a9..4146089ffa 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -81,6 +81,9 @@ from .sub import _sub_tbe from .reduce_mean_d import _reduce_mean_d_tbe from .scatter_nd import _scatter_nd_tbe from .scatter_nd_d import _scatter_nd_d_tbe +from .scatter_nd_add import _scatter_nd_add_tbe +from .scatter_nd_sub import _scatter_nd_sub_tbe +from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe from .reduce_mean import _reduce_mean_tbe from .tile import _tile_tbe from .atomic_addr_clean import _atomic_addr_clean_tbe @@ -92,7 +95,10 @@ from .bn_training_update import _bn_training_update_tbe from .bn_training_update_grad import _bn_training_update_grad_tbe from .bn_infer import _bn_infer_tbe from .bn_infer_grad import _bn_infer_grad_tbe +from .bn_inference import _bn_inference_tbe from .reciprocal import _reciprocal_tbe +from .reverse_v2_d import _reverse_v2_d_tbe +from .rint import _rint_tbe from .strided_slice_d import _strided_slice_d_tbe from .strided_slice_grad_d import _strided_slice_grad_d_tbe from .split_d import _split_d_tbe @@ -102,6 +108,8 @@ from .elu import _elu_tbe from .elu_grad import _elu_grad_tbe from .div import _div_tbe from .log import _log_tbe +from .xdivy import _xdivy_tbe +from .xlogy import _xlogy_tbe from .floor_div import _floor_div_tbe from .zeros_like import _zeros_like_tbe from .neg import _neg_tbe @@ -127,11 +135,14 @@ from .softplus import _softplus_tbe from .softplus_grad import _softplus_grad_tbe from .softmax_grad_ext import _softmax_grad_ext_tbe from .square import _square_tbe +from .squared_difference import _squared_difference_tbe from .sqrt import _sqrt_tbe from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad from .apply_proximal_adagrad import _apply_proximal_adagrad from .transpose_d import _transpose_d_tbe +from .truncate_div import _truncate_div_tbe +from .truncate_mod import _truncate_mod_tbe from .unsorted_segment_sum import _unsorted_segment_sum_tbe from .unsorted_segment_prod import _unsorted_segment_prod_tbe from .logsoftmax_grad import _logsoftmax_grad_tbe @@ -188,6 +199,7 @@ from .floor_mod import _floor_mod_tbe from .scatter_nd_update import _scatter_nd_update_tbe from .avg_pool import _avg_pool_tbe from .avg_pool_grad import _avg_pool_grad_tbe +from .avg_pool_grad_vm import _avg_pool_grad_vm_tbe from .ones_like import _ones_like_tbe from .batch_to_space import _batch_to_space_tbe from .space_to_batch import _space_to_batch_tbe @@ -222,10 +234,14 @@ from .binary_cross_entropy import _binary_cross_entropy_tbe from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe from .sin import _sin_tbe from .cos import _cos_tbe +from .tan import _tan_tbe from .cum_sum import _cum_sum_tbe from .apply_rms_prop import _apply_rms_prop_tbe from .cumprod import _cumprop_tbe from .reduce_prod import _reduce_prod_tbe +from .reciprocal_grad import _reciprocal_grad_tbe +from .sqrt_grad import _sqrt_grad_tbe +from .rsqrt_grad import _rsqrt_grad_tbe from .flatten_grad import _flatten_grad_tbe from .scatter_add import _scatter_add_tbe from .atan2 import _atan2_tbe @@ -237,6 +253,7 @@ from .bitwise_and import _bitwise_and_tbe from .bitwise_or import _bitwise_or_tbe from .bitwise_xor import _bitwise_xor_tbe from .reduce_all import _reduce_all_tbe +from .reduce_any import _reduce_any_tbe from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe from .unsorted_segment_min import _unsorted_segment_min_tbe from .asin import _asin_tbe diff --git a/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py b/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py index 8499614324..f372583a85 100644 --- a/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +++ b/mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py @@ -13,15 +13,15 @@ # limitations under the License. # ============================================================================ -"""ApplyCenteredRMSProp op""" +"""ApplyCenteredRMSPropD op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("apply_centered_rms_prop.so") \ + .binfile_name("apply_centered_rms_prop_d.so") \ .compute_cost(10) \ - .kernel_name("apply_centered_rms_prop") \ + .kernel_name("apply_centered_rms_prop_d") \ .partial_flag(True) \ .input(0, "var", False, "required", "all") \ .input(1, "mg", False, "required", "all") \ @@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \ .input(7, "epsilon", False, "required", "all") \ .input(8, "grad", False, "required", "all") \ .output(0, "var", False, "required", "all") \ + .output(1, "mg", False, "required", "all") \ + .output(2, "ms", False, "required", "all") \ + .output(3, "mom", False, "required", "all") \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_5HD, DataType.F16_5HD) \ + DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, + DataType.F16_5HD) \ .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_FracZ, DataType.F16_FracZ) \ + DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, + DataType.F16_FracZ) \ .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, + DataType.F16_C1HWNCoC0) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default) \ + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, + DataType.F32_C1HWNCoC0) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ .get_op_info() @op_info_register(apply_centered_rms_prop_op_info) def _apply_centered_rms_prop_tbe(): - """ApplyCenteredRMSProp TBE register""" + """ApplyCenteredRMSPropD TBE register""" return diff --git a/mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py b/mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py new file mode 100644 index 0000000000..d15613224a --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""AvgPoolGradVm op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +avg_pool_grad_vm_op_info = TBERegOp("AvgPoolGradVm") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("avg_pool_grad_d.so") \ + .compute_cost(10) \ + .kernel_name("avg_pool_grad_d") \ + .partial_flag(True) \ + .attr("x_origin", "required", "listInt", "all") \ + .attr("ksize", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("padding", "required", "str", "all") \ + .attr("data_format", "optional", "str", "all") \ + .input(0, "input_grad", False, "required", "all") \ + .input(1, "mean_matrix", False, "optional", "all") \ + .input(2, "kernel_matrix", False, "optional", "all") \ + .output(0, "out_grad", True, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(avg_pool_grad_vm_op_info) +def _avg_pool_grad_vm_tbe(): + """AvgPoolGradVm TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/bn_inference.py b/mindspore/ops/_op_impl/tbe/bn_inference.py new file mode 100644 index 0000000000..5d4a4a9120 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/bn_inference.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BNInference op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +bn_inference_op_info = TBERegOp("BNInference") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("bninference_d.so") \ + .compute_cost(10) \ + .kernel_name("bninference_d") \ + .partial_flag(True) \ + .attr("momentum", "optional", "float", "all", "0.999") \ + .attr("epsilon", "optional", "float", "all", "0.00001") \ + .attr("use_global_stats", "optional", "bool", "true,false", "true") \ + .attr("mode", "optional", "int", "all", "1") \ + .input(0, "x", False, "required", "all") \ + .input(1, "mean", False, "required", "all") \ + .input(2, "variance", False, "required", "all") \ + .input(3, "scale", False, "optional", "all") \ + .input(4, "offset", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, + DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(bn_inference_op_info) +def _bn_inference_tbe(): + """BNInference TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/matmul.py b/mindspore/ops/_op_impl/tbe/matmul.py index 7784d5e222..0f68fa4c9d 100644 --- a/mindspore/ops/_op_impl/tbe/matmul.py +++ b/mindspore/ops/_op_impl/tbe/matmul.py @@ -17,7 +17,7 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType matmul_op_info = TBERegOp("MatMul") \ - .fusion_type("DYNAMIC") \ + .fusion_type("ELEMWISE") \ .async_flag(False) \ .binfile_name("matmul.so") \ .compute_cost(10) \ diff --git a/mindspore/ops/_op_impl/tbe/reciprocal.py b/mindspore/ops/_op_impl/tbe/reciprocal.py index c620fb17a6..eacfdd6bce 100644 --- a/mindspore/ops/_op_impl/tbe/reciprocal.py +++ b/mindspore/ops/_op_impl/tbe/reciprocal.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""Add op""" +"""Reciprocal op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType reciprocal_op_info = TBERegOp("Reciprocal") \ @@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \ @op_info_register(reciprocal_op_info) def _reciprocal_tbe(): - """Add TBE register""" + """Reciprocal TBE register""" return diff --git a/mindspore/ops/_op_impl/tbe/reciprocal_grad.py b/mindspore/ops/_op_impl/tbe/reciprocal_grad.py new file mode 100644 index 0000000000..48c7169861 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/reciprocal_grad.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReciprocalGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("reciprocal_grad.so") \ + .compute_cost(10) \ + .kernel_name("reciprocal_grad") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "dy", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(reciprocal_grad_op_info) +def _reciprocal_grad_tbe(): + """ReciprocalGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/reduce_any.py b/mindspore/ops/_op_impl/tbe/reduce_any.py new file mode 100644 index 0000000000..101a5e0506 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/reduce_any.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReduceAny op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +reduce_any_op_info = TBERegOp("ReduceAny") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("reduce_any_d.so") \ + .compute_cost(10) \ + .kernel_name("reduce_any_d") \ + .partial_flag(True) \ + .attr("axis", "required", "listInt", "all") \ + .attr("keep_dims", "optional", "bool", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("reduce") \ + .dtype_format(DataType.BOOL_None, DataType.BOOL_None) \ + .get_op_info() + + +@op_info_register(reduce_any_op_info) +def _reduce_any_tbe(): + """ReduceAny TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/reverse_v2_d.py b/mindspore/ops/_op_impl/tbe/reverse_v2_d.py new file mode 100644 index 0000000000..035af55f20 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/reverse_v2_d.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReverseV2D op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +reverse_v2_d_op_info = TBERegOp("ReverseV2") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("reverse_v2_d.so") \ + .compute_cost(10) \ + .kernel_name("reverse_v2_d") \ + .partial_flag(True) \ + .op_pattern("dynamicFormat") \ + .attr("axis", "required", "listInt", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.None_None, DataType.None_None) \ + .get_op_info() + + +@op_info_register(reverse_v2_d_op_info) +def _reverse_v2_d_tbe(): + """ReverseV2D TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/rint.py b/mindspore/ops/_op_impl/tbe/rint.py new file mode 100644 index 0000000000..eabeea1c94 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/rint.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Rint op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +rint_op_info = TBERegOp("Rint") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("rint.so") \ + .compute_cost(10) \ + .kernel_name("rint") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(rint_op_info) +def _rint_tbe(): + """Rint TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/rsqrt_grad.py b/mindspore/ops/_op_impl/tbe/rsqrt_grad.py new file mode 100644 index 0000000000..914c7eca8b --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/rsqrt_grad.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RsqrtGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("rsqrt_grad.so") \ + .compute_cost(10) \ + .kernel_name("rsqrt_grad") \ + .partial_flag(True) \ + .op_pattern("broadcast") \ + .input(0, "x", False, "required", "all") \ + .input(1, "dy", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(rsqrt_grad_op_info) +def _rsqrt_grad_tbe(): + """RsqrtGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_nd_add.py b/mindspore/ops/_op_impl/tbe/scatter_nd_add.py new file mode 100644 index 0000000000..abe609a91c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_nd_add.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterNdAdd op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_nd_add_op_info = TBERegOp("ScatterNdAdd") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_nd_add.so") \ + .compute_cost(10) \ + .kernel_name("scatter_nd_add") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(scatter_nd_add_op_info) +def _scatter_nd_add_tbe(): + """ScatterNdAdd TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_nd_sub.py b/mindspore/ops/_op_impl/tbe/scatter_nd_sub.py new file mode 100644 index 0000000000..f985c412c5 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_nd_sub.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterNdSub op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_nd_sub_op_info = TBERegOp("ScatterNdSub") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_nd_sub.so") \ + .compute_cost(10) \ + .kernel_name("scatter_nd_sub") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(scatter_nd_sub_op_info) +def _scatter_nd_sub_tbe(): + """ScatterNdSub TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py b/mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py new file mode 100644 index 0000000000..d1e278b102 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterNonAliasingAdd op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_non_aliasing_add_op_info = TBERegOp("ScatterNonAliasingAdd") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_non_aliasing_add.so") \ + .compute_cost(10) \ + .kernel_name("scatter_non_aliasing_add") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(scatter_non_aliasing_add_op_info) +def _scatter_non_aliasing_add_tbe(): + """ScatterNonAliasingAdd TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/sqrt_grad.py b/mindspore/ops/_op_impl/tbe/sqrt_grad.py new file mode 100644 index 0000000000..a951bb0f8a --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sqrt_grad.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SqrtGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sqrt_grad_op_info = TBERegOp("SqrtGrad") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("sqrt_grad.so") \ + .compute_cost(10) \ + .kernel_name("sqrt_grad") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "dy", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .get_op_info() + + +@op_info_register(sqrt_grad_op_info) +def _sqrt_grad_tbe(): + """SqrtGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/squared_difference.py b/mindspore/ops/_op_impl/tbe/squared_difference.py new file mode 100644 index 0000000000..f567b91964 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/squared_difference.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SquaredDifference op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +squared_difference_op_info = TBERegOp("SquaredDifference") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("squared_difference.so") \ + .compute_cost(10) \ + .kernel_name("squared_difference") \ + .partial_flag(True) \ + .op_pattern("broadcast") \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(squared_difference_op_info) +def _squared_difference_tbe(): + """SquaredDifference TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/tan.py b/mindspore/ops/_op_impl/tbe/tan.py new file mode 100644 index 0000000000..2287e4bc07 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/tan.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tan op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +tan_op_info = TBERegOp("Tan") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("tan.so") \ + .compute_cost(10) \ + .kernel_name("tan") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None) \ + .get_op_info() + + +@op_info_register(tan_op_info) +def _tan_tbe(): + """Tan TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/truncate_div.py b/mindspore/ops/_op_impl/tbe/truncate_div.py new file mode 100644 index 0000000000..583d96b7f3 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/truncate_div.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TruncateDiv op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +truncate_div_op_info = TBERegOp("TruncateDiv") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("truncate_div.so") \ + .compute_cost(10) \ + .kernel_name("truncate_div") \ + .partial_flag(True) \ + .op_pattern("broadcast") \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ + .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ + .get_op_info() + + +@op_info_register(truncate_div_op_info) +def _truncate_div_tbe(): + """TruncateDiv TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/truncate_mod.py b/mindspore/ops/_op_impl/tbe/truncate_mod.py new file mode 100644 index 0000000000..b8cfa991e2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/truncate_mod.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TruncateMod op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +truncate_mod_op_info = TBERegOp("TruncateMod") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("truncate_mod.so") \ + .compute_cost(10) \ + .kernel_name("truncate_mod") \ + .partial_flag(True) \ + .op_pattern("broadcast") \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ + .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ + .get_op_info() + + +@op_info_register(truncate_mod_op_info) +def _truncate_mod_tbe(): + """TruncateMod TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/xdivy.py b/mindspore/ops/_op_impl/tbe/xdivy.py new file mode 100644 index 0000000000..1624576c2e --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/xdivy.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Xdivy op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +xdivy_op_info = TBERegOp("Xdivy") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("xdivy.so") \ + .compute_cost(10) \ + .kernel_name("xdivy") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(xdivy_op_info) +def _xdivy_tbe(): + """Xdivy TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/xlogy.py b/mindspore/ops/_op_impl/tbe/xlogy.py new file mode 100644 index 0000000000..7a997f216b --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/xlogy.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Xlogy op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +xlogy_op_info = TBERegOp("Xlogy") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("xlogy.so") \ + .compute_cost(10) \ + .kernel_name("xlogy") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(xlogy_op_info) +def _xlogy_tbe(): + """Xlogy TBE register""" + return diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index 0e6850dcb1..9ee599e6ef 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -82,5 +82,8 @@ def get_concat_offset(x_shp, x_type, axis, prim_name): if j != axis and v[j] != x_shp[0][j]: raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element") offset.append(all_shp) - all_shp += v[axis] + if all_shp == -1 or v[axis] == -1: + all_shp = -1 + else: + all_shp += v[axis] return offset, all_shp, axis diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index ab35dd65fb..6656dafdb4 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -27,7 +27,7 @@ from .clip_ops import clip_by_value from .multitype_ops.add_impl import hyper_add from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like -from .random_ops import set_seed, normal, uniform +from .random_ops import set_seed, normal, multinomial, uniform __all__ = [ @@ -51,4 +51,5 @@ __all__ = [ 'set_seed', 'uniform', 'normal', + 'multinomial', 'clip_by_value',] diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 0f28d9572f..766bedc5d0 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -129,14 +129,14 @@ class GradOperation(GradOperation_): output = fn(*args) _pynative_exec.end_graph(fn, output, *args) else: - if fn.is_run and not fn.requires_grad: + if fn.already_run and not fn.requires_grad: raise ValueError("obj must set_grad.") - if not fn.is_run: + if not fn.already_run: self.need_forward = True - print("already has forward run before grad by user") if self.need_forward: fn.set_grad() fn(*args) + fn.already_run = False def __call__(self, fn, weights=None): grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) diff --git a/mindspore/ops/composite/multitype_ops/ones_like_impl.py b/mindspore/ops/composite/multitype_ops/ones_like_impl.py index d88193f845..840571c8b1 100644 --- a/mindspore/ops/composite/multitype_ops/ones_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/ones_like_impl.py @@ -42,6 +42,16 @@ def _ones_like_tensor(x): return P.Fill()(P.DType()(x), P.Shape()(x), 1.0) +@ones_like_leaf.register("SparseTensor") +def _ones_like_sparse_tensor(x): + """Returns a tensor with the same shape and dtype as x and all elements are 1.""" + values_ = F.sparse_tensor_get_values(x) + values = P.Fill()(P.DType()(values_), + P.Shape()(values_), + 1.0) + return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x)) + + ones_like = base.HyperMap(ones_like_leaf) """ `ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype. diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index 88037aefb7..70313e72e0 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -20,17 +20,46 @@ from .. import functional as F from ..primitive import constexpr from .multitype_ops import _constexpr_utils as const_utils from ...common import dtype as mstype +from ...common.tensor import Tensor +from ..._checkparam import Validator as validator +from ..._checkparam import check_int_positive +from ..._checkparam import Rel # set graph-level RNG seed _GRAPH_SEED = 0 @constexpr def set_seed(seed): + """ + Set the graph-level seed. + Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. + If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a + random seed. + + Args: + seed(Int): the graph-level seed value that to be set. + + Examples: + >>> C.set_seed(10) + """ + check_int_positive(seed) global _GRAPH_SEED _GRAPH_SEED = seed @constexpr def get_seed(): + """ + Get the graph-level seed. + Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. + If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a + random seed. + + Returns: + Interger. The current graph-level seed. + + Examples: + >>> C.get_seed(10) + """ return _GRAPH_SEED @@ -69,6 +98,54 @@ def normal(shape, mean, stddev, seed=0): value = rnd * stddev + mean return value + +def multinomial(inputs, num_sample=None, replacement=True, seed=0): + r""" + Returns a tensor sampled from the multinomial probability distribution located in the corresponding + row of tensor input. + + Note: + The rows of input do not need to sum to one (in which case we use the values as weights), + but must be non-negative, finite and have a non-zero sum. + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims. + - **num_samples** (int) - number of samples to draw, default None. + - **replacement** (bool, optional) - whether to draw with replacement or not, default True. + + Outputs: + Tensor. have the same rows with input, each row has num_samples sampled indices. + + Examples: + >>> input = Tensor([0, 9, 4, 0], mstype.float32) + >>> output = C.multinomial(input, 2, True) + """ + shape = P.Shape() + reshape = P.Reshape() + validator.check_value_type('replacement', replacement, (bool,), None) + validator.check_value_type('num_sample', num_sample, (int,), None) + validator.check_integer("num_sample", num_sample, 0, Rel.GT, None) + if inputs.dim() != 1 and inputs.dim() != 2: + raise ValueError("inputs dim must be 1d or 2d") + if not replacement: + if shape(inputs)[-1] < num_sample: + raise ValueError("num_sample must be less than shape(input)[-1] without replacement") + n_dist = 1 + if len(shape(inputs)) > 1: + n_dist = shape(inputs)[-2] + a = Tensor(0.0, mstype.float32) + b = Tensor(1.0, mstype.float32) + uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b) + if n_dist != 1: + uniform = reshape(uniform, (n_dist, num_sample)) + vals = P.RealDiv()(P.Log()(uniform), inputs + 1e-6) + _, indices = P.TopK()(vals, num_sample) + return indices + return P.Multinomial(seed=seed)(inputs, num_sample) + def uniform(shape, a, b, seed=0, dtype=mstype.float32): """ Generates random numbers according to the Uniform (or Gaussian) random number distribution. diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 2be011cb77..226a570a09 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -28,7 +28,7 @@ hastype = Primitive('hastype') cast = P.Cast() dtype = P.DType() isconstant = Primitive('is_constant') - +isconstant.set_is_const_value(True) issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() @@ -66,6 +66,7 @@ assign_sub = P.AssignSub() assign = P.Assign() square = P.Square() sqrt = P.Sqrt() + scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() tuple_to_array = P.TupleToArray() @@ -82,7 +83,6 @@ partial = P.Partial() # depend: mount a node to another node depend = P.Depend() - tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') list_getitem = Primitive('list_getitem') @@ -101,7 +101,6 @@ tuple_equal = Primitive("tuple_equal") list_equal = Primitive("list_equal") make_ref = Primitive("make_ref") - scalar_add = Primitive('scalar_add') scalar_mul = Primitive('scalar_mul') scalar_sub = Primitive('scalar_sub') @@ -153,12 +152,15 @@ shape_mul = Primitive("shape_mul") # a primitive to compare between tuple. stop_gradient = Primitive("stop_gradient") +make_row_tensor = Primitive('MakeRowTensor') +row_tensor_get_values = Primitive('RowTensorGetValues') +row_tensor_get_indices = Primitive('RowTensorGetIndices') +row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape') -make_indexed_slices = Primitive('MakeIndexedSlices') -indexed_slices_get_values = Primitive('IndexedSlicesGetValues') -indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') -indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') - +make_sparse_tensor = Primitive('MakeSparseTensor') +sparse_tensor_get_values = Primitive('SparseTensorGetValues') +sparse_tensor_get_indices = Primitive('SparseTensorGetIndices') +sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape') tensor_operator_registry.register('__add__', tensor_add) tensor_operator_registry.register('__sub__', tensor_sub) @@ -167,7 +169,9 @@ tensor_operator_registry.register('__truediv__', tensor_div) tensor_operator_registry.register('__mod__', tensor_mod) tensor_operator_registry.register('__pow__', tensor_pow) tensor_operator_registry.register('__floordiv__', tensor_floordiv) -#ms cannot support Tensor(True) compare +tensor_operator_registry.register('all', P.ReduceAll) +tensor_operator_registry.register('any', P.ReduceAny) +# ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__ne__', not_equal) tensor_operator_registry.register('__neg__', neg_tensor) @@ -176,5 +180,6 @@ tensor_operator_registry.register('__le__', tensor_le) tensor_operator_registry.register('__gt__', tensor_gt) tensor_operator_registry.register('__ge__', tensor_ge) tensor_operator_registry.register('shape', shape) -#support GE backend for no compare operators +# support GE backend for no compare operators tensor_operator_registry.register('vm_compare', BP.vm_compare) +tensor_operator_registry.register('cast', cast) diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 6ab915e369..65d6d2cdb8 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -28,13 +28,14 @@ BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op" def op_info_register(op_info): """ - A decorator used as register of operator implementation. + A decorator which is used to register an operator. Note: - 'op_info' must be a str of json format represent the op info, the op info will be added into oplib. + 'op_info' should represent the operator information by string with json format. + The 'op_info' will be added into oplib. Args: - op_info (str or dict): op info of json format. + op_info (str or dict): operator information in json format. Returns: Function, returns a decorator for op info register. @@ -220,18 +221,19 @@ class AkgRegOp(RegOp): self.imply_type = "AKG" self.processor = processor - def input(self, index=None, name=None, **kwargs): + def input(self, index=None, name=None, param_type=None, **kwargs): """ Register Akg op input information. Args: index (int): Order of the input. Default: None. name (str): Name of the input. Default: None. - kwargs (dict): Other information for the input. + param_type (str): Param type of the input. Default: None. + kwargs (dict): Other information of the input. """ - param_list = [index, name] - key_list = ["index", "name"] - fn_list = [self._is_int, self._is_string] + param_list = [index, name, param_type] + key_list = ["index", "name", "param_type"] + fn_list = [self._is_int, self._is_string, self._is_string] input_dict = self._check_param(param_list, key_list, fn_list, kwargs) self.inputs.append(input_dict) return self @@ -243,7 +245,7 @@ class AkgRegOp(RegOp): Args: index (int): Order of the output. Default: None. name (str): Name of the output. Default: None. - kwargs (dict): Other information for the output. + kwargs (dict): Other information of the output. """ param_list = [index, name] key_list = ["index", "name"] @@ -260,7 +262,7 @@ class AkgRegOp(RegOp): name (str): Name of the attribute. Default: None. param_type (str): Param type of the attribute. Default: None. value_type (str): Value type of the attribute. Default: None. - kwargs (dict): Other information for the attribute. + kwargs (dict): Other information of the attribute. """ param_list = [name, param_type, value_type] key_list = ["name", "param_type", "type"] @@ -295,7 +297,7 @@ class AiCPURegOp(RegOp): index (int): Order of the input. Default: None. name (str): Name of the input. Default: None. param_type (str): Param type of the input. Default: None. - kwargs (dict): Other information for the input. + kwargs (dict): Other information of the input. """ param_list = [index, name, param_type] key_list = ["index", "name", "param_type"] @@ -312,7 +314,7 @@ class AiCPURegOp(RegOp): index (int): Order of the output. Default: None. name (str): Name of the output. Default: None. param_type (str): Param type of the output. Default: None. - kwargs (dict): Other information for the output. + kwargs (dict): Other information of the output. """ param_list = [index, name, param_type] key_list = ["index", "name", "param_type"] @@ -328,8 +330,8 @@ class AiCPURegOp(RegOp): Args: name (str): Name of the attribute. Default: None. value_type (str): Value type of the attribute. Default: None. - value (str): Value type of the attribute. Default: None. - kwargs (dict): Other information for the attribute. + value (str): Value of the attribute. Default: None. + kwargs (dict): Other information of the attribute. """ param_list = [name, value_type, value] key_list = ["name", "type", "value"] @@ -356,7 +358,7 @@ class TBERegOp(RegOp): def async_flag(self, async_flag): """ - Define the calculation efficiency of operator, whether to support asynchronous calculation. + Define the calculation efficiency of the operator, whether the asynchronous calculation is supported. Args: async_flag (bool): Value of async flag. Default: false. @@ -367,10 +369,10 @@ class TBERegOp(RegOp): def binfile_name(self, binfile_name): """ - Binary file name of operator. The option is optional. + Set the binary file name of the operator, it is optional. Args: - binfile_name (str): File name of operator binary. + binfile_name (str): The binary file name of the operator. """ self._is_string(binfile_name) self.binfile_name_ = binfile_name @@ -378,7 +380,7 @@ class TBERegOp(RegOp): def compute_cost(self, compute_cost): """ - Define the calculation efficiency of operator, which refers to cost model value of the tiling module. + Define the calculation efficiency of operator, which refers to the value of the cost model in the tiling module. Args: compute_cost (int): Value of compute cost. Default: 10. @@ -400,7 +402,7 @@ class TBERegOp(RegOp): def partial_flag(self, partial_flag): """ - Define the calculation efficiency of operator, whether to support partial calculation. + Define the calculation efficiency of operator, whether the partial calculation is supported. Args: partial_flag (bool): Value of partial flag. Default: true. @@ -422,7 +424,7 @@ class TBERegOp(RegOp): def dynamic_format(self, dynamic_format): """ - Whether the operator supports dynamic selection of format and dtype. + Whether the operator supports dynamic selection of format and dtype or not. Args: dynamic_format (bool): Value of dynamic format. Default: false. @@ -452,7 +454,7 @@ class TBERegOp(RegOp): value_type (str): Type of the attribute. Default: None. value (str): Value of the attribute. Default: None. default_value (str): Default value of attribute. Default: None. - kwargs (dict): Other information for the attribute. + kwargs (dict): Other information of the attribute. """ param_list = [name, param_type, value_type, value, default_value] key_list = ["name", "param_type", "type", "value", "default_value"] @@ -468,10 +470,10 @@ class TBERegOp(RegOp): Args: index (int): Order of the input. Default: None. name (str): Name of the input. Default: None. - need_compile (bool): The input need compile whether or not. Default: None. + need_compile (bool): Whether the input needs to be compiled or not. Default: None. param_type (str): Type of the input. Default: None. shape (str): Shape of the input. Default: None. - kwargs (dict): Other information for the input. + kwargs (dict): Other information of the input. """ param_list = [index, name, need_compile, param_type, shape] key_list = ["index", "name", "need_compile", "param_type", "shape"] @@ -487,10 +489,10 @@ class TBERegOp(RegOp): Args: index (int): Order of the output. Default: None. name (str): Name of the output. Default: None. - need_compile (bool): The output need compile whether or not. Default: None. + need_compile (bool): Whether the output needs to be compiled or not. Default: None. param_type (str): Type of the output. Default: None. shape (str): Shape of the output. Default: None. - kwargs (dict): Other information for the output. + kwargs (dict): Other information of the output. """ param_list = [index, name, need_compile, param_type, shape] key_list = ["index", "name", "need_compile", "param_type", "shape"] @@ -504,7 +506,7 @@ class DataType: """ Various combinations of dtype and format. - The current list below maybe not completed. If necessary, please add it. + The current list below may be incomplete. Please add it if necessary. """ None_None = ("", "") diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 9def0186c9..ca03ad2edf 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -28,10 +28,12 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding, + ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, Squeeze, StridedSlice, Tile, TensorScatterUpdate, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, - SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup) + SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, + Unique) from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, @@ -43,19 +45,19 @@ from .inner_ops import ScalarCast from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, - ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, + ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny, Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil, Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod, LogicalNot, LogicalOr, MatMul, Maximum, Minimum, Mul, Neg, NMSWithMask, NotEqual, NPUAllocFloatStatus, NPUClearFloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, - Reciprocal, CumSum, HistogramFixedWidth, - Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, - Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) + Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, + Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, + Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, - RandomCategorical, Laplace) + RandomCategorical, Laplace, Multinomial) from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, @@ -72,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, - TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, + TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2, FusedSparseFtrl, FusedSparseProximalAdagrad, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, @@ -82,7 +84,11 @@ from . import _quant_ops from ._quant_ops import * from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull) -from .thor_ops import * +from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, + CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, + CusMatMulCubeDenseRight, + CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky) +from .sparse_ops import SparseToDense __all__ = [ 'ReverseSequence', @@ -105,6 +111,9 @@ __all__ = [ 'Rsqrt', 'Sqrt', 'Square', + 'SquaredDifference', + 'Xdivy', + 'Xlogy', 'Conv2D', 'Flatten', 'MaxPoolWithArgmax', @@ -176,6 +185,7 @@ __all__ = [ 'Tanh', 'RandomChoiceWithMask', 'StandardNormal', + 'Multinomial', 'Gamma', 'Poisson', 'UniformInt', @@ -215,6 +225,7 @@ __all__ = [ 'CTCLoss', 'RNNTLoss', 'ReduceAll', + 'ReduceAny', 'ScalarToArray', 'ScalarToTensor', 'TupleToArray', @@ -234,6 +245,11 @@ __all__ = [ 'ScatterNd', 'ScatterMax', 'ScatterMin', + 'ScatterNdAdd', + 'ScatterNdSub', + 'ScatterNonAliasingAdd', + 'ReverseV2', + 'Rint', 'ResizeNearestNeighbor', 'HistogramFixedWidth', 'Pad', @@ -279,6 +295,8 @@ __all__ = [ 'SigmoidCrossEntropyWithLogits', 'FloorDiv', 'FloorMod', + 'TruncateDiv', + 'TruncateMod', 'Ceil', 'Acosh', 'Asinh', @@ -298,6 +316,7 @@ __all__ = [ "LSTM", "Abs", "BinaryCrossEntropy", + "KLDivLoss", "SparseApplyAdagrad", "SparseApplyAdagradV2", "SpaceToDepth", @@ -337,6 +356,7 @@ __all__ = [ "BesselI1e", "Atan", "Atanh", + "Tan", "BasicLSTMCell", "BroadcastTo", "DataFormatDimMap", @@ -348,7 +368,8 @@ __all__ = [ "PopulationCount", "ParallelConcat", "Push", - "Pull" + "Pull", + 'SparseToDense', ] __all__.sort() diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 5e5e56f708..7940662f48 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -23,7 +23,6 @@ from .._utils import get_concat_offset from ...common import dtype as mstype from .. import functional as F - class AbsGrad(PrimitiveWithInfer): """Computes gradients for abs operation.""" @@ -116,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer): return x +class ReciprocalGrad(PrimitiveWithInfer): + """Performs grad of Reciprocal operation.""" + + @prim_attr_register + def __init__(self): + """init ReciprocalGrad""" + + def infer_shape(self, x_shape, dout_shape): + validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) + return x_shape + + def infer_dtype(self, x_dtype, dout_dtype): + args = {"x": x_dtype, "dout": dout_dtype} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + return x_dtype + + +class RsqrtGrad(PrimitiveWithInfer): + """Performs grad of Rsqrt operation.""" + + @prim_attr_register + def __init__(self): + """init RsqrtGrad""" + + def infer_shape(self, x_shape, dout_shape): + validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) + return x_shape + + def infer_dtype(self, x_dtype, dout_dtype): + args = {"x": x_dtype, "dout": dout_dtype} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) + return x_dtype + + +class SoftmaxGrad(PrimitiveWithInfer): + """Performs grad of Softmax operation.""" + + @prim_attr_register + def __init__(self): + """init SoftmaxGrad""" + + def infer_shape(self, x_shape, dout_shape): + validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) + return x_shape + + def infer_dtype(self, x_dtype, dout_dtype): + args = {"x": x_dtype, "dout": dout_dtype} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + return x_dtype + + +class SqrtGrad(PrimitiveWithInfer): + """Performs grad of Sqrt operation.""" + + @prim_attr_register + def __init__(self): + """init SqrtGrad""" + + def infer_shape(self, x_shape, dout_shape): + validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) + return x_shape + + def infer_dtype(self, x_dtype, dout_dtype): + args = {"x": x_dtype, "dout": dout_dtype} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + return x_dtype + + class BatchNormGrad(PrimitiveWithInfer): """Performs grad of BatchNorm operation.""" @@ -145,6 +212,23 @@ class BiasAddGrad(Primitive): raise NotImplementedError +class KLDivLossGrad(PrimitiveWithInfer): + """Computes gradients for `KLDivLoss` operation.""" + + @prim_attr_register + def __init__(self, reduction='mean'): + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) + + def infer_shape(self, x_shape, y_shape, doutput_shape): + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) + return x_shape, y_shape + + def infer_dtype(self, x_type, y_type, doutput_type): + args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + return x_type, y_type + + class BinaryCrossEntropyGrad(PrimitiveWithInfer): """Computes gradients for `BinaryCrossEntropy` operation.""" @@ -406,6 +490,18 @@ class FusedBatchNormGrad(Primitive): def __call__(self, dy, x, scale, save_mean, save_inv_variance): raise NotImplementedError + +class UniqueGrad(Primitive): + """Gradients of Unique operation.""" + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx']) + + def __call__(self, dy, x, scale, save_mean, save_inv_variance): + raise NotImplementedError + + class BNTrainingReduceGrad(PrimitiveWithInfer): """Gradients of FusedBatchNorm operation.""" @@ -420,6 +516,7 @@ class BNTrainingReduceGrad(PrimitiveWithInfer): def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance): return grads + class BNTrainingUpdateGrad(PrimitiveWithInfer): """Gradients of FusedBatchNorm operation.""" @@ -434,6 +531,7 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer): def infer_dtype(self, grads, x, batch_mean, batch_variance): return (batch_mean, batch_variance) + class GeluGrad(PrimitiveWithInfer): """Gradients of Gelu operation.""" @@ -492,7 +590,7 @@ class _PoolGrad(PrimitiveWithInfer): class AvgPoolGrad(_PoolGrad): - """Gradients of the avg pool operation.""" + """Gradients of the avg pool operation for ge.""" @prim_attr_register def __init__(self, ksize=1, strides=1, padding="VALID"): @@ -508,6 +606,24 @@ class AvgPoolGrad(_PoolGrad): return out +class AvgPoolGradVm(_PoolGrad): + """Gradients of the avg pool operation for vm.""" + + @prim_attr_register + def __init__(self, ksize=1, strides=1, padding="VALID"): + super(AvgPoolGradVm, self).__init__(ksize, strides, padding) + self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output']) + + def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix): + out = { + 'value': None, + 'shape': tuple(origin_input['value']), + 'dtype': dout['dtype'], + } + + return out + + class AvgPoolGradGpu(_PoolGrad): """Gradients of the avg pool operation for gpu.""" @@ -1319,6 +1435,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer): This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking, this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host. """ + @prim_attr_register def __init__(self): self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output']) @@ -1519,6 +1636,7 @@ class InvGrad(PrimitiveWithInfer): class LRNGrad(PrimitiveWithInfer): """Computes gradients for LRN operation.""" + @prim_attr_register def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z']) diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index e70e5b32d5..a691725719 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -288,7 +288,7 @@ class Range(PrimitiveWithInfer): return x_dtype -class AscendQuant(PrimitiveWithInfer): +class Quant(PrimitiveWithInfer): r""" Returns the quantized value of input_x. @@ -320,7 +320,7 @@ class AscendQuant(PrimitiveWithInfer): Examples: >>> input_x = Tensor([100.0, 150.0], mstype.float32) - >>> quant = P.AscendQuant(80.0, 0.0, False, "Round") + >>> quant = P.Quant(80.0, 0.0, False, "Round") >>> y = quant(input_x) """ @@ -341,7 +341,7 @@ class AscendQuant(PrimitiveWithInfer): return mstype.int8 -class AscendDequant(PrimitiveWithInfer): +class Dequant(PrimitiveWithInfer): r""" Returns the dequantized value of input_x. This operation will do ReLU to the dequantized value if `relu_flag` is True. @@ -373,13 +373,14 @@ class AscendDequant(PrimitiveWithInfer): Examples: >>> input_x = Tensor([100.0, 150.0], mstype.float32) - >>> dequant = P.AscendDequant(False, False) + >>> dequant = P.Dequant(False, False) >>> y = dequant(input_x) """ @prim_attr_register def __init__(self, sqrt_mode=False, relu_flag=False): self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name) + self.add_prim_attr("dtype", mstype.float16) def infer_shape(self, x_shape, deq_scale_shape): return x_shape diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 1f4de03d3c..d34e322bb5 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -386,6 +386,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer): raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") if not self.is_ascend: validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + if len(x_shape) == 1: + self.channel_axis = 0 validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check_integer( "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) diff --git a/mindspore/ops/operations/_thor_ops.py b/mindspore/ops/operations/_thor_ops.py new file mode 100644 index 0000000000..a4f2335c9b --- /dev/null +++ b/mindspore/ops/operations/_thor_ops.py @@ -0,0 +1,640 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""thor_ops""" +import math + +from ..primitive import prim_attr_register, PrimitiveWithInfer +from ...common import dtype as mstype +from ..._checkparam import Validator as validator +from ..._checkparam import Rel + +__all__ = ["CusBatchMatMul", + "CusCholeskyTrsm", + "CusFusedAbsMax1", + "CusImg2Col", + "CusMatMulCubeDenseLeft", + "CusMatMulCubeFraczRightMul", + "CusMatMulCube", + "CusMatrixCombine", + "CusTranspose02314", + "CusMatMulCubeDenseRight", + "CusMatMulCubeFraczLeftCast", + ] + + +def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): + """ + Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements. + """ + + def _raise_message(): + raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two " + f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}") + + def _get_return_value(): + if isinstance(arg_value, int): + ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value) + elif len(arg_value) == 2: + ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value + elif len(arg_value) == 4: + if not allow_four: + _raise_message() + ret = arg_value if ret_four else (arg_value[2], arg_value[3]) + else: + _raise_message() + return ret + + validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) + ret_value = _get_return_value() + for item in ret_value: + if isinstance(item, int) and item > 0: + continue + _raise_message() + return ret_value + + +class CusBatchMatMul(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b` in batch. + + The rank of input tensors must be `3`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. + - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If + `transpose_b` is True. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, D, D)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> cus_batch_matmul = P.CusBatchMatMul() + >>> output = cus_batch_matmul(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CusBatchMatMul""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.batch_matmul_impl import CusBatchMatMul + + def infer_shape(self, data1_shape, data2_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return data1_dtype + + +class CusCholeskyTrsm(PrimitiveWithInfer): + """ + L * LT = A. + LT * (LT)^-1 = I. + return (LT)^-1. + Only compute the res of the diag part of input matrix with dim 128. + The rank of input tensors must be `2`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, N)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N // Split_dim, Split_dim, Split_dim)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32) + >>> cus_choleskytrsm = P.CusCholeskyTrsm() + >>> output = matmul(input_x) + """ + + @prim_attr_register + def __init__(self): + """init CusCholeskyTrsm""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import CusCholeskyTrsm + + def infer_shape(self, data1_shape): + ll = [] + m, _ = data1_shape + if m >= 128: + ll = [m // 128, 128, 128] + else: + ll = [1, 64, 64] + return ll + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusFusedAbsMax1(PrimitiveWithInfer): + """ + Compute the abs max of Tensor input. + + The rank of input tensors must be `4` or `2`. + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N0, M0, N1, M1)` + or math:`(32, 64)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(32, 64)` or math:`(1, )`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32) + >>> cus_fused_abs_max1 = P.CusFusedAbsMax1() + >>> output = cus_fused_abs_max1(input_x) + """ + + @prim_attr_register + def __init__(self, origin_shape=[-1, -1]): + """init CusFusedAbsMax1""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.origin_shape = origin_shape + from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1 + + def infer_shape(self, data1_shape): + ll = [] + if len(data1_shape) == 2: + ll = [1,] + else: + ll = [32, 64] + return ll + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusImg2Col(PrimitiveWithInfer): + """ + Img2col the feature map and the result in reorganized in NC1HWC0. + + Args: + - **strides** (listInt) - the stride of the ops. + - **ksizes** (listInt) - the kernel size of the ops. + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`. + Examples: + >>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16) + >>> cusimg2col = P.CusImg2Col() + >>> output = cusimg2col(input_x) + """ + + @prim_attr_register + def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"): + """init CusImg2Col""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.ksizes = ksizes + self.strides = strides + self.dilates = dilates + self.mode = mode + from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col + + def infer_shape(self, data1_shape): + bs, c, h, w = data1_shape + _, stride_h, stride_w, _ = self.strides + _, k_w, k_h, _ = self.ksizes + # assert m == n + c0 = 16 + c1 = c // 16 + if c1 == 0: + c1 = 1 + shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0] + return shape + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusMatMulCubeDenseLeft(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 must be `4`, the fractal format of the normal matrix. + The rank of input_x2 must be `2`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. + The shape of the tensor is :math:`(N0, M0, N1, M1)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N, C)`. + Examples: + >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> matmulcubedenseleft = P.CusMatMulCubeDenseLeft() + >>> output = matmulcubedenseleft(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeDenseLeft""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import CusMatMulCubeDenseLeft + + def infer_shape(self, data1_shape, data2_shape): + return data2_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return mstype.float16 + + +class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b` and muls the result by scalar `c`. + + The rank of input_x1 tensors must be `2`. + The rank of input_x2 tensors must be `4`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. + The shape of the tensor is :math:`(C1, M1, C0, M0)`. + - **input_x3** (Tensor) - The third tensor to be multiplied. The shape of the tensor if :math`(1, )`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + Examples: + >>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16) + >>> cusmatmulfraczrightmul = P.CusMatMulCubeFraczRightMul() + >>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3) + """ + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeFraczRightMul""" + self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul + + def infer_shape(self, data1_shape, data2_shape, data3_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): + return mstype.float32 + + +class CusMatMulCube(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input tensors must be `2`. + + Args: + transpose_a (bool): If True, `a` is transposed before multiplication. Default: False. + transpose_b (bool): If True, `b` is transposed before multiplication. Default: False. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If + `transpose_a` is True, its shape should be :math:`(N, C)` after transposing. + - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If + `transpose_b` is True, its shape should be :math:`(C, M)` after transpose. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> cusmatmulcube = P.CusMatMulCube() + >>> output = matmul(input_x, input_y) + """ + + @prim_attr_register + def __init__(self, transpose_a=False, transpose_b=False): + """init CusMatMulCube""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + self.transpose_a = transpose_a + self.transpose_b = transpose_b + from mindspore.ops._op_impl._custom_op.matmul_cube_impl import CusMatMulCube + + def infer_shape(self, data1_shape, data2_shape): + # shape = [1, data1_shape[1], data2_shape[2], 16, 16] + # return shape + if self.transpose_a: + k1, m = data1_shape + else: + m, k1 = data1_shape + if self.transpose_b: + n, k2 = data2_shape + else: + k2, n = data2_shape + assert k1 == k2 + shape = [m, n] + return shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return mstype.float32 + + +class CusMatrixCombine(PrimitiveWithInfer): + """ + move the batch matrix to result matrix diag part. + The rank of input tensors must be `3`. + + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N * D, N * D)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> cusmatrixcombine = P.CusMatrixCombine() + >>> output = cusmatrixcombine(input_x) + """ + + @prim_attr_register + def __init__(self): + """init CusMatrixCombine""" + self.init_prim_io_names(inputs=['x'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matrix_combine_impl import CusMatrixCombine + + def infer_shape(self, data_shape): + a, b, c = data_shape + shape = [a * b, a * c] + + return shape + + def infer_dtype(self, data_dtype): + return data_dtype + + +class CusTranspose02314(PrimitiveWithInfer): + """ + Permute input tensor with perm (0, 2, 3, 1, 4) + + The rank of input tensors must be `5` with format NC1HWC0. + + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, H, W, C1, C0)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16) + >>> custranspose02314 = P.CusTranspose02314() + >>> output = custranspose02314(input_x) + """ + + @prim_attr_register + def __init__(self): + """init CusTranspose02314""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314 + + def get_bprop(self): + def bprop(x, out, dout): + return (C.zeros_like(x),) + + return bprop + + def infer_shape(self, data1_shape): + assert len(data1_shape) == 4 + n, c, h, w = data1_shape + c0 = 16 + c1 = c // 16 + shape = (n * h * w, c1 * c0) + return shape + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusMatMulCubeDenseRight(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 tensor must be `2`. + The rank of input_x2 tensor must be `4`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. + - **input_y** (Tensor) - The second tensor to be multiplied. + The shape of the tensor is :math:`(C1, M1, M0, C0)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> cusmatmulcubedenseright = P.CusMatMulCubeDenseRight() + >>> output = cusmatmulcubedenseright(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeDenseRight""" + self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import CusMatMulCubeDenseRight + + def infer_shape(self, data1_shape, data2_shape, data3_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): + return mstype.float32 + + +class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 tensor must be `4`. + The rank of input_x2 tensors must be `2`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. + The shape of the tensor is :math:`(C1, N1, N0, C0)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> cusmatmulcubefraczleftcast = P.CusMatMulCubeFraczLeftCast() + >>> output = cusmatmulcubefraczleftcast(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeFraczLeftCast""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast + + def infer_shape(self, data1_shape, data2_shape): + return data2_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return mstype.float16 + + +class Im2Col(PrimitiveWithInfer): + """ + extract image pathes from image. + + The rank of input_x1 must be `4`, data_format is "NCHW". + + Inputs: + - **input_x1** (Tensor) - The feature map. + The shape of the tensor is :math:`(N, C, H, W)`. + Outputs: + Tensor. + Examples: + >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16)) + >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) + >>> output = img2col(input_x) + """ + @prim_attr_register + def __init__(self, + kernel_size, + pad_mode="valid", + pad=0, + stride=1, + dilation=1): + """init Im2Col""" + self.init_prim_io_names(inputs=['x'], outputs=['output']) + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.add_prim_attr('kernel_size', self.kernel_size) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) + self.add_prim_attr('stride', self.stride) + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) + self.add_prim_attr('dilation', self.dilation) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + if self.pad_mode == 'pad': + validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) + self.add_prim_attr('data_format', "NCHW") + + def infer_shape(self, x_shape): + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + kernel_size_h = self.kernel_size[0] + kernel_size_w = self.kernel_size[1] + stride_h = self.stride[2] + stride_w = self.stride[3] + dilation_h = self.dilation[2] + dilation_w = self.dilation[3] + if self.pad_mode == "valid": + h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h) + w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w) + pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 + elif self.pad_mode == "same": + h_out = math.ceil(x_shape[2] / stride_h) + w_out = math.ceil(x_shape[3] / stride_w) + pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]) + pad_top = math.floor(pad_needed_h / 2) + pad_bottom = pad_needed_h - pad_top + pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]) + pad_left = math.floor(pad_needed_w / 2) + pad_right = pad_needed_w - pad_left + elif self.pad_mode == 'pad': + pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad + h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h + w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w + h_out = math.floor(h_out) + w_out = math.floor(w_out) + self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] + self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) + batch_size = x_shape[0] + channel = x_shape[1] + k_h = kernel_size_h + k_w = kernel_size_w + out_shape = [channel, k_h, k_w, batch_size, h_out, w_out] + return out_shape + + def infer_dtype(self, x_dtype): + args = {'x': x_dtype} + valid_types = [mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) + return x_dtype + + +class UpdateThorGradient(PrimitiveWithInfer): + """ + Update Thor Gradient with Approximate Fisher info matrix(for GPU backend). + + The rank of input_x1 must be `3`, which indicates the A matrix. + The rank of input_x2 must be `2`, which indicates the 1st-order gradient. + The rank of input_x3 must be `4`, which indicates the G matrix. + + Inputs: + - **input_x1** (Tensor) - The first input is the diag part of the cov matrix of feature map. + Supported dtype [float32]. + - **input_x2** (Tensor) - The second input is the corresponding 1st-order grad. Supported dtype [float32]. + - **input_x3** (Tensor) - The third input is the diag part of the cov matrix of dout. Supported dtype [float32]. + + Outputs: + Tensor, the shape is the same as the shape of input_x2, it will be used to update the weights. + + Examples: + >>> input_x1 = Tensor(np.random.rand(16, 128, 128).astype(np.float32)) + >>> input_x2 = Tensor(np.random.rand(2048, 1024).astype(np.float32)) + >>> temp_x3 = np.random.rand(8, 128, 128).astype(np.float32) + >>> input_x3 = np.zeros(16,8,128,128).astype(np.float32) + >>> for i in range(16): + >>> input_x3[i,:,:,:] = temp_x3 + >>> input_x3 = Tensor(input_x3) + >>> update_thor_gradient = P.UpdateThorGradient(split_dim=128) + >>> output = update_thor_gradient(input_x1, input_x2, input_x3) + """ + + @prim_attr_register + def __init__(self, split_dim=0): + """init UpdateThorGradient""" + self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) + self.split_dim = split_dim + self.add_prim_attr('split_dim', self.split_dim) + + def infer_shape(self, x1_shape, x2_shape, x3_shape): + return x2_shape + + def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): + validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, + [mstype.float32], self.name) + return x2_dtype + +class Cholesky(PrimitiveWithInfer): + """ + Inner API for resnet50 THOR GPU backend + """ + @prim_attr_register + def __init__(self, split_dim=0): + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.split_dim = split_dim + self.add_prim_attr('split_dim', self.split_dim) + + def infer_shape(self, x1_shape): + if self.split_dim != 0: + assert len(x1_shape) == 2 + height = x1_shape[0] + width = x1_shape[1] + assert height == width + if height <= self.split_dim: + out_shape = [1, height, width] + else: + batch = height // self.split_dim + if height != batch * self.split_dim: + batch += 1 + out_shape = [batch, self.split_dim, self.split_dim] + else: + out_shape = x1_shape + return out_shape + + def infer_dtype(self, x1_dtype): + validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) + return x1_dtype diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 918ad6e0e6..e76ab49ee7 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -22,20 +22,21 @@ import copy import functools import itertools import numbers + import numpy as np -from ..._checkparam import Validator as validator -from ..._checkparam import Rel -from ...common import dtype as mstype -from ...common.tensor import Tensor -from ...common.parameter import Parameter -from ..operations.math_ops import _infer_shape_reduce from .._utils import get_concat_offset +from ..operations.math_ops import _infer_shape_reduce from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_dtype as sig_dtype +from ..._c_expression import signature_kind as sig_kind +from ..._c_expression import signature_rw as sig_rw from ..._c_expression import typing +from ..._checkparam import Rel +from ..._checkparam import Validator as validator +from ...common import dtype as mstype +from ...common.parameter import Parameter +from ...common.tensor import Tensor class _ScatterOp(PrimitiveWithInfer): @@ -47,10 +48,10 @@ class _ScatterOp(PrimitiveWithInfer): ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), ('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) ) - @staticmethod - def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): + + def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): if updates_shape and updates_shape != indices_shape + x_shape[1:]: - raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " + raise ValueError(f"For '{prim_name}', " f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") @@ -61,7 +62,7 @@ class _ScatterOp(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) def infer_shape(self, x_shape, indices_shape, updates_shape): - _ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) + self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): @@ -71,6 +72,19 @@ class _ScatterOp(PrimitiveWithInfer): return x_dtype +class _ScatterNdOp(_ScatterOp): + """ + Define _ScatterNd operators + """ + def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): + validator.check('the dimension of x', len(x_shape), + 'the dimension of indices', indices_shape[-1], Rel.GE) + if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape: + raise ValueError(f"For '{prim_name}', updates_shape = " + f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, " + f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") + + def _check_infer_attr_reduce(axis, keep_dims, prim_name): validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) validator.check_value_type('axis', axis, [int, tuple], prim_name) @@ -402,7 +416,7 @@ class Reshape(PrimitiveWithInfer): return out -class Shape(Primitive): +class Shape(PrimitiveWithInfer): """ Returns the shape of input tensor. @@ -423,6 +437,13 @@ class Shape(Primitive): def __init__(self): """init Shape""" + def __infer__(self, x): + validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) + out = {'shape': (), + 'dtype': mstype.tuple_, + 'value': tuple(x['shape'])} + return out + class Squeeze(PrimitiveWithInfer): """ @@ -435,7 +456,7 @@ class Squeeze(PrimitiveWithInfer): ValueError: If the corresponding dimension of the specified axis does not equal to 1. Args: - axis (int): Specifies the dimension indexes of shape to be removed, which will remove + axis (Union[int, tuple(int)]): Specifies the dimension indexes of shape to be removed, which will remove all the dimensions that are equal to 1. If specified, it must be int32 or int64. Default: (), an empty tuple. @@ -535,6 +556,28 @@ class Transpose(PrimitiveWithInfer): return out +class Unique(Primitive): + """ + Returns the unique elements of input tensor and also return a tensor containing the index of each value of input + tensor corresponding to the output unique tensor. + + Inputs: + - **x** (Tensor) - The input tensor. + + Outputs: + Tuple, containing tensor objects `(y, idx)`, `y` is a tensor has the same type as `x`, `idx` is a tensor + containing indices of elements in the input coressponding to the output tensor. + + Examples: + >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32) + >>> out = P.Unique()(x) + (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.float32)) + """ + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + class GatherV2(PrimitiveWithInfer): """ Returns a slice of input tensor based on the specified indices and axis. @@ -1046,7 +1089,7 @@ class InvertPermutation(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init InvertPermutation""" - self.const_value = True + self.set_is_const_value(True) def __infer__(self, x): x_shp = x['shape'] @@ -1212,7 +1255,8 @@ class ArgMaxWithValue(PrimitiveWithInfer): """init ArgMaxWithValue""" self.axis = axis self.keep_dims = keep_dims - _check_infer_attr_reduce(axis, keep_dims, self.name) + validator.check_value_type('keep_dims', keep_dims, [bool], self.name) + validator.check_value_type('axis', axis, [int], self.name) def infer_shape(self, x_shape): axis = self.axis @@ -1259,7 +1303,8 @@ class ArgMinWithValue(PrimitiveWithInfer): """init ArgMinWithValue""" self.axis = axis self.keep_dims = keep_dims - _check_infer_attr_reduce(axis, keep_dims, self.name) + validator.check_value_type('keep_dims', keep_dims, [bool], self.name) + validator.check_value_type('axis', axis, [int], self.name) def infer_shape(self, x_shape): axis = self.axis @@ -1467,7 +1512,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer): Inputs: - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. With float16, float32 or int32 data type. - - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32. + - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value should be >= 0. + Data type must be int32. - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`, should be greater than 0. @@ -1799,6 +1845,75 @@ class Slice(PrimitiveWithInfer): 'value': None} +class ReverseV2(PrimitiveWithInfer): + """ + Reverse specific dimensions of a tensor. + + Args: + axis (Union[tuple(int), list(int)): The indices of the dimensions to reverse. + + Inputs: + - **input_x** (Tensor) - The target tensor. + + Outputs: + Tensor, has the same shape and type as `input_x`. + + Examples: + >>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32) + >>> op = P.ReverseV2(axis=[1]) + >>> output = op(input_x) + [[4, 3, 2, 1], [8, 7, 6, 5]] + """ + + @prim_attr_register + def __init__(self, axis): + validator.check_value_type('axis', axis, [list, tuple], self.name) + for i, each in enumerate(axis): + validator.check_value_type(f'axis[{i}]', each, [int], self.name) + self.axis = axis + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def infer_shape(self, x_shape): + dim = len(x_shape) + for i, each in enumerate(self.axis): + validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name) + return x_shape + + def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name) + return x_dtype + + +class Rint(PrimitiveWithInfer): + """ + Return element-wise integer closest to x. + + Inputs: + - **input_x** (Tensor) - The target tensor, which must be one of the following types: + float16, float32. + + Outputs: + Tensor, has the same shape and type as `input_x`. + + Examples: + >>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32) + >>> op = P.Rint() + >>> output = op(input_x) + [-2., 0., 2., 2.] + """ + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) + return x_dtype + + class Select(PrimitiveWithInfer): r""" @@ -1914,7 +2029,7 @@ def _compute_slicing_length(begin, end, stride, x_shape, i): if begin >= x_dim: # When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element. begin = -1 - if 0 < end < x_dim: + if 0 <= end < x_dim: end += -x_dim if end < -x_dim - 1: # When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means @@ -2296,10 +2411,10 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): and output tensors are aligned. Default: False. Inputs: - - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`. + - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`. Outputs: - Tensor, the shape of the output tensor is :math:`(N, NEW\_C, NEW\_H, W)`. + Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`. Examples: >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) @@ -2318,10 +2433,12 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) def infer_shape(self, x): - validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name) + validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) return tuple(x)[:-2] + tuple(self.size) def infer_dtype(self, x): + validator.check_subclass("x", x, mstype.tensor, self.name) + validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name) return x @@ -2444,7 +2561,7 @@ class ScatterUpdate(_ScatterOp): return x_dtype -class ScatterNdUpdate(PrimitiveWithInfer): +class ScatterNdUpdate(_ScatterNdOp): """ Update tensor value by using input indices and value. @@ -2469,11 +2586,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): >>> op = P.ScatterNdUpdate() >>> output = op(input_x, indices, update) """ - __mindspore_signature__ = ( - ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) - ) + @prim_attr_register def __init__(self, use_locking=True): @@ -2481,13 +2594,6 @@ class ScatterNdUpdate(PrimitiveWithInfer): validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) - def infer_shape(self, x_shape, indices_shape, value_shape): - validator.check('the dimension of x', len(x_shape), - 'the dimension of indices', indices_shape[-1], Rel.GE) - if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape: - raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.") - return x_shape - def infer_dtype(self, x_dtype, indices_dtype, value_dtype): validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) args = {"x": x_dtype, "value": value_dtype} @@ -2675,6 +2781,101 @@ class ScatterDiv(_ScatterOp): """ +class ScatterNdAdd(_ScatterNdOp): + """ + Applies sparse addition to individual values or slices in a Tensor. + + Using given values to update tensor value through the add operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the add operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32) + >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32) + >>> scatter_nd_add = P.ScatterNdAdd() + >>> output = scatter_nd_add(input_x, indices, updates) + [1, 10, 9, 4, 12, 6, 7, 17] + """ + + +class ScatterNdSub(_ScatterNdOp): + """ + Applies sparse subtraction to individual values or slices in a Tensor. + + Using given values to update tensor value through the sub operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the sub operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32) + >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32) + >>> scatter_nd_sub = P.ScatterNdSub() + >>> output = scatter_nd_sub(input_x, indices, updates) + [1, -6, -3, 4, -2, 6, 7, -1] + """ + + +class ScatterNonAliasingAdd(_ScatterNdOp): + """ + Applies sparse addition to input using individual values or slices. + + Using given values to update tensor value through the add operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the add operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32) + >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32) + >>> scatter_non_aliasing_add = P.ScatterNonAliasingAdd() + >>> output = scatter_non_aliasing_add(input_x, indices, updates) + [1, 10, 9, 4, 12, 6, 7, 17] + """ + + @prim_attr_register + def __init__(self): + """Init ScatterNonAliasingAdd""" + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + + def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + args = {"x": x_dtype, "updates": updates_dtype} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name) + return x_dtype + + class SpaceToDepth(PrimitiveWithInfer): r""" Rearrange blocks of spatial data into depth. @@ -3296,7 +3497,7 @@ class EmbeddingLookup(PrimitiveWithInfer): Inputs: - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - The Tensor slice, instead of the entire Tensor. + This represents a Tensor slice, instead of the entire Tensor. Currently, the dimension is restricted to be 2. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, and the exceeding part will be filled with 0 in the output. diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index f8b47a28c3..2b84741fea 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -100,11 +100,6 @@ class AllReduce(PrimitiveWithInfer): self.add_prim_attr('fusion', 0) self.add_prim_attr('index', 0) - def vm_impl(self, x): - """Implement by vm mode.""" - x = x.asnumpy() - return Tensor(x) - def infer_shape(self, x_shape): return x_shape diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 066791d4df..03d9adfe8b 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -251,7 +251,8 @@ class InsertGradientOf(PrimitiveWithInfer): class HookBackward(PrimitiveWithInfer): """ - Used as tag to hook gradient in intermediate variables. + Used as tag to hook gradient in intermediate variables. Note that this function + is only supported in Pynative Mode. Note: The hook function should be defined like `hook_fn(grad) -> Tensor or None`, diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index b4a684d2f7..9bfa078560 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -405,6 +405,42 @@ class ReduceAll(_Reduce): return self.do_infer(input_x, axis, (mstype.bool_,)) +class ReduceAny(_Reduce): + """ + Reduce a dimension of a tensor by the "logical or" of all elements in the dimension. + + The dtype of the tensor to be reduced is bool. + + Args: + keep_dims (bool): If True, keep these reduced dimensions and the length is 1. + If False, don't keep these dimensions. + Default : False, don't keep these reduced dimensions. + + Inputs: + - **input_x** (Tensor[bool]) - The input tensor. + - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. + Only constant value is allowed. + + Outputs: + Tensor, the dtype is bool. + + - If axis is (), and keep_dims is false, + the output is a 0-D tensor representing the "logical or" of of all elements in the input tensor. + - If axis is int, set as 2, and keep_dims is false, + and keep_dims is false, the shape of output is :math:`(x_1, x_3, ..., x_R)`. + - If axis is tuple(int), set as (2, 3), and keep_dims is false, + the shape of output is :math:`(x_1, x_4, ..., x_R)`. + + Examples: + >>> input_x = Tensor(np.array([[True, False], [True, True]])) + >>> op = P.ReduceAny(keep_dims=True) + >>> output = op(input_x, 1) + """ + + def __infer__(self, input_x, axis): + return self.do_infer(input_x, axis, (mstype.bool_,)) + + class ReduceMax(_Reduce): """ Reduce a dimension of a tensor by the maximum value in this dimension. @@ -596,7 +632,7 @@ class MatMul(PrimitiveWithInfer): raise ValueError('MatMul input x, y should be the same dimension size and should be ' + f'equal to 2, while x size = {len(x)}, y size= {len(y)}') - def infer_shape(self, x, y): + def infer_shape(self, x, y, bias=None): self.check_shape_size(x, y) cls_name = self.name # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two @@ -621,7 +657,7 @@ class MatMul(PrimitiveWithInfer): ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]] return ret_dims - def infer_dtype(self, x, y): + def infer_dtype(self, x, y, bias=None): args = {"x": x, "y": y} validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) if x.element_type() == mstype.int8: @@ -778,9 +814,13 @@ class AddN(PrimitiveWithInfer): validator.check_value_type("inputs", inputs, [tuple, list], cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) args = {} + contains_undetermined = False for i, dtype in enumerate(inputs): args[f"inputs[{i}]"] = dtype - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) + if dtype == mstype.undetermined: + contains_undetermined = True + if not contains_undetermined: + validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) return inputs[0] def infer_value(self, inputs): @@ -1075,6 +1115,7 @@ class Mul(_MathBinaryOp): >>> mul(input_x, input_y) [4, 10, 18] """ + def infer_value(self, x, y): if x is not None and y is not None: x = x.asnumpy() @@ -1085,6 +1126,40 @@ class Mul(_MathBinaryOp): return None +class SquaredDifference(_MathBinaryOp): + """ + Subtracts the second input tensor from the first input tensor element-wise and returns square of it. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and the shapes of them could be broadcast. + When the inputs are one tensor and one scalar, + the scalar only could be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is float16, float32, int32 or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is + float16, float32, int32 or bool. + + Outputs: + Tensor, the shape is same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) + >>> input_y = Tensor(np.array([2.0, 4.0, 6.0]), mindspore.float32) + >>> squared_difference = P.SquaredDifference() + >>> squared_difference(input_x, input_y) + [1.0, 4.0, 9.0] + """ + + def infer_dtype(self, x_dtype, y_dtype): + valid_type = [mstype.float16, mstype.float32, mstype.int32] + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, valid_type, self.name) + + class Square(PrimitiveWithInfer): """ Returns square of a tensor element-wise. @@ -1219,6 +1294,10 @@ class Reciprocal(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init Reciprocal""" + if context.get_context("device_target") == "GPU": + self.target = "GPU" + else: + self.target = "OTHER" self.init_prim_io_names(inputs=['x'], outputs=['y']) def infer_shape(self, x): @@ -1744,6 +1823,65 @@ class FloorDiv(_MathBinaryOp): """ +class TruncateDiv(_MathBinaryOp): + """ + Divide the first input tensor by the second input tensor element-wise for integer types, negative numbers will + round fractional quantities towards zero. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and the shapes of them could be broadcast. + When the inputs are one tensor and one scalar, + the scalar only could be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is number or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is number or bool. + + Outputs: + Tensor, the shape is same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32) + >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32) + >>> truncate_div = P.TruncateDiv() + >>> truncate_div(input_x, input_y) + [0, 1, 0] + """ + + +class TruncateMod(_MathBinaryOp): + """ + Returns element-wise remainder of division. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and the shapes of them could be broadcast. + When the inputs are one tensor and one scalar, + the scalar only could be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is number or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is number or bool. + + Outputs: + Tensor, the shape is same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32) + >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32) + >>> truncate_mod = P.TruncateMod() + >>> truncate_mod(input_x, input_y) + [2, 1, -1] + """ + + class Mod(_MathBinaryOp): """ Computes the remainder of dividing the first input tensor by the second input tensor element-wise. @@ -1867,6 +2005,72 @@ class Ceil(PrimitiveWithInfer): return x_dtype +class Xdivy(_MathBinaryOp): + """ + Divide the first input tensor by the second input tensor element-wise. Returns zero when `x` is zero. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and the shapes of them could be broadcast. + When the inputs are one tensor and one scalar, + the scalar only could be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is float16, float32 or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool. + + Outputs: + Tensor, the shape is same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.float32) + >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32) + >>> xdivy = P.Xdivy() + >>> xdivy(input_x, input_y) + [1.0, 2.0, -0.5] + """ + + def infer_dtype(self, x_dtype, y_dtype): + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name) + + +class Xlogy(_MathBinaryOp): + """ + Computes first input tensor multiplied by the logarithm of second input tensor element-wise. + Returns zero when `x` is zero. + + The inputs must be two tensors or one tensor and one scalar. + When the inputs are two tensors, + both dtypes cannot be bool, and the shapes of them could be broadcast. + When the inputs are one tensor and one scalar, + the scalar only could be a constant. + + Inputs: + - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or + a bool or a tensor whose data type is float16, float32 or bool. + - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or + a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool. + The value must be positive. + + Outputs: + Tensor, the shape is same as the shape after broadcasting, + and the data type is the one with high precision or high digits among the two inputs. + + Examples: + >>> input_x = Tensor(np.array([-5, 0, 4]), mindspore.float32) + >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32) + >>> xlogy = P.Xlogy() + >>> Xlogy(input_x, input_y) + [-3.465736, 0.0, 2.7725887] + """ + + def infer_dtype(self, x_dtype, y_dtype): + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name) + + class Acosh(PrimitiveWithInfer): """ Compute inverse hyperbolic cosine of x element-wise. @@ -2870,6 +3074,36 @@ class Round(PrimitiveWithInfer): return x_type +class Tan(PrimitiveWithInfer): + """ + Computes tangent of `input_x` element-wise. + + Inputs: + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. Data type should be + float16, float32 or int32. + + Outputs: + Tensor, has the same shape as `input_x`. + + Examples: + >>> tan = P.Tan() + >>> input_x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float32) + >>> output = tan(input_x) + """ + + @prim_attr_register + def __init__(self): + """init Tan""" + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + valid_types = [mstype.float16, mstype.float32, mstype.int32] + validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) + return x_type + + class Atan(PrimitiveWithInfer): """ Computes the trignometric inverse tangent of x element-wise. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1c8b24a112..1ebba14064 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -498,7 +498,7 @@ class HSigmoid(PrimitiveWithInfer): Hard sigmoid is defined as: .. math:: - \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})), + \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{x_{i} + 3}{6})), where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. @@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive): self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) + self._update_parameter = True class BNTrainingReduce(PrimitiveWithInfer): @@ -842,7 +843,7 @@ class Conv2D(PrimitiveWithInfer): self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.add_prim_attr('offset_a', 0) - def infer_shape(self, x_shape, w_shape): + def infer_shape(self, x_shape, w_shape, b_shape=None): validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) @@ -887,7 +888,7 @@ class Conv2D(PrimitiveWithInfer): out_shape = [x_shape[0], out_channel, h_out, w_out] return out_shape - def infer_dtype(self, x_dtype, w_dtype): + def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): args = {'x': x_dtype, 'w': w_dtype} valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] validator.check_tensor_type_same(args, valid_types, self.name) @@ -968,7 +969,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): self.group = validator.check_integer("group", group, 0, Rel.GT, self.name) self.add_prim_attr('offset_a', 0) - def infer_shape(self, x_shape, w_shape): + def infer_shape(self, x_shape, w_shape, b_shape=None): validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) @@ -1011,7 +1012,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): out_shape = [x_shape[0], out_channel, h_out, w_out] return out_shape - def infer_dtype(self, x_dtype, w_dtype): + def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): args = {'x': x_dtype, 'w': w_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) if x_dtype.element_type() == mstype.int8: @@ -1180,6 +1181,7 @@ class MaxPoolWithArgmax(_Pool): def __init__(self, ksize=1, strides=1, padding="valid"): super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding) self.is_tbe = context.get_context("device_target") == "Ascend" + self.is_gpu = context.get_context("device_target") == "GPU" def infer_shape(self, x_shape): out_shape = _Pool.infer_shape(self, x_shape) @@ -1206,6 +1208,8 @@ class MaxPoolWithArgmax(_Pool): out_dtype = x_dtype validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) argmax_dtype = mstype.uint16 + if self.is_gpu: + argmax_dtype = mstype.int32 return out_dtype, argmax_dtype @@ -1276,6 +1280,8 @@ class AvgPool(_Pool): def __init__(self, ksize=1, strides=1, padding="valid"): if context.get_context("device_target") == "GPU": self.target = "GPU" + elif context.get_context("enable_ge"): + self.target = "GE" else: self.target = "OTHER" super(AvgPool, self).__init__(ksize, strides, padding) @@ -1332,10 +1338,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer): validator.check_value_type('pad', pad, (int, tuple), self.name) if isinstance(pad, int): pad = (pad,) * 4 - self.padding = pad else: validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) - + self.padding = pad self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) if pad_mode != 'pad' and pad != (0, 0, 0, 0): raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") @@ -1691,7 +1696,9 @@ class L2Loss(PrimitiveWithInfer): Set `input_x` as x and output as loss. .. math:: - loss = sum(x ** 2) / 2 + loss = sum(x ** 2) / nelement(x) + + :math:`nelement(x)` represents the number of `input_x`. Inputs: - **input_x** (Tensor) - A input Tensor. @@ -2042,6 +2049,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.is_ascend = context.get_context("device_target") == "Ascend" def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): @@ -2049,6 +2057,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) + if self.is_ascend: + return var_shape, mean_gradient_shape, mean_square_shape, moment_shape return var_shape def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, @@ -2062,6 +2072,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): validator.check_type_same(args_rho, valid_types, self.name) args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) + if self.is_ascend: + return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype return var_dtype @@ -2749,12 +2761,17 @@ class MirrorPad(PrimitiveWithInfer): paddings_value = paddings['value'].asnumpy() paddings_size = paddings_value.size validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) - if not np.all(paddings_size >= 0): + if not np.all(paddings_value >= 0): raise ValueError('All elements of paddings must be >= 0.') + adjust = 0 + if self.mode == 'SYMMETRIC': + adjust = 1 + for i in range(0, int(paddings_size / 2)): + if (paddings_value[i, 0] >= x_shape[i] + adjust) or (paddings_value[i, 1] >= x_shape[i] + adjust): + raise ValueError('At least one dim has too high a padding value for this input and mode') y_shape = () for i in range(0, int(paddings_size / 2)): y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),) - return {'shape': y_shape, 'dtype': input_x['dtype'], 'value': None} @@ -3193,11 +3210,11 @@ class FusedSparseFtrl(PrimitiveWithInfer): use_locking (bool): Use locks for update operation if True . Default: False. Inputs: - - **var** (Parameter): The variable to be updated. The data type must be float32. - - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`. - - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`. - - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. - - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape + - **var** (Parameter) - The variable to be updated. The data type must be float32. + - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`. + - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. + - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The shape of `indices` must be the same as `grad` in first dimension. The type must be int32. Outputs: @@ -3288,9 +3305,9 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): Inputs: - **var** (Parameter) - Variable tensor to be updated. The data type must be float32. - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - - **lr** (Tensor): The learning rate value. The data type must be float32. - - **l1** (Tensor): l1 regularization strength. The data type must be float32. - - **l2** (Tensor): l2 regularization strength. The data type must be float32. + - **lr** (Tensor) - The learning rate value. The data type must be float32. + - **l1** (Tensor) - l1 regularization strength. The data type must be float32. + - **l2** (Tensor) - l2 regularization strength. The data type must be float32. - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32. - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type must be int32. @@ -3356,6 +3373,78 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): return var_dtype, accum_dtype +class KLDivLoss(PrimitiveWithInfer): + r""" + Computes the Kullback-Leibler divergence between the target and the output. + + Note: + Sets input as :math:`x`, input label as :math:`y`, output as :math:`\ell(x, y)`. + Let, + + .. math:: + L = \{l_1,\dots,l_N\}^\top, \quad + l_n = y_n \cdot (\log y_n - x_n) + + Then, + + .. math:: + \ell(x, y) = \begin{cases} + L, & \text{if reduction} = \text{'none';}\\ + \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \end{cases} + + Args: + reduction (str): Specifies the reduction to apply to the output. + Its value should be one of 'none', 'mean', 'sum'. Default: 'mean'. + + Inputs: + - **input_x** (Tensor) - The input Tensor. The data type must be float32. + - **input_y** (Tensor) - The label Tensor which has same shape as `input_x`. The data type must be float32. + + Outputs: + Tensor or Scalar, if `reduction` is 'none', then output is a tensor and same shape as `input_x`. + Otherwise it is a scalar. + + Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.kldiv_loss = P.KLDivLoss() + >>> def construct(self, x, y): + >>> result = self.kldiv_loss(x, y) + >>> return result + >>> + >>> net = Net() + >>> input_x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32) + >>> input_y = Tensor(np.array([0., 1., 0.]), mindspore.float32) + >>> result = net(input_x, input_y) + """ + + @prim_attr_register + def __init__(self, reduction='mean'): + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) + + def infer_shape(self, x_shape, y_shape): + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) + if self.reduction in ('mean', 'sum'): + shape = [] + else: + shape = x_shape + return shape + + def infer_dtype(self, x_type, y_type): + args = {'x': x_type, 'y': y_type} + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) + return x_type + + class BinaryCrossEntropy(PrimitiveWithInfer): r""" Computes the Binary Cross Entropy between the target and the output. @@ -3755,12 +3844,12 @@ class ApplyAdagradV2(PrimitiveWithInfer): update_slots (bool): If `True`, `accum` will be updated. Default: True. Inputs: - - **var** (Parameter) - Variable to be updated. With float32 or float16 data type. + - **var** (Parameter) - Variable to be updated. With float32 data type. - **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`. - With float32 or float16 data type. - - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 or float16 data type. + With float32 data type. + - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 data type. - **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`. - With float32 or float16 data type. + With float32 data type. Outputs: Tuple of 2 Tensor, the updated parameters. @@ -3812,9 +3901,8 @@ class ApplyAdagradV2(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name) + validator.check_tensor_type_same(args, [mstype.float32], self.name) + validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float32], self.name) return var_dtype, accum_dtype @@ -4278,6 +4366,7 @@ class ApplyPowerSign(PrimitiveWithInfer): Inputs: - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type. + If data type of `var` is float16, all inputs must have the same data type as `var`. - **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar. With float32 or float16 data type. @@ -4320,11 +4409,11 @@ class ApplyPowerSign(PrimitiveWithInfer): __mindspore_signature__ = ( ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), + ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T3), - ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), + sig_dtype.T), + ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) ) @@ -4588,16 +4677,16 @@ class ApplyFtrl(PrimitiveWithInfer): use_locking (bool): Use locks for update operation if True . Default: False. Inputs: - - **var** (Tensor): The variable to be updated. - - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`. - - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`. - - **grad** (Tensor): Gradient. - - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. Default: 0.001. - - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. + - **var** (Tensor) - The variable to be updated. + - **accum** (Tensor) - The accum to be updated, must be same type and shape as `var`. + - **linear** (Tensor) - The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor) - Gradient. + - **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. Default: 0.001. + - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero. Default: 0.0. - - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. + - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero. Default: 0.0. - - **lr_power** (Union[Number, Tensor]): Learning rate power controls how the learning rate decreases + - **lr_power** (Union[Number, Tensor]) - Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5. @@ -4678,17 +4767,17 @@ class SparseApplyFtrl(PrimitiveWithInfer): use_locking (bool): Use locks for update operation if True . Default: False. Inputs: - - **var** (Parameter): The variable to be updated. The data type must be float32. - - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`. - - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`. - - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. - - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. + - **var** (Parameter) - The variable to be updated. The data type must be float32. + - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`. + - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. + - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The shape of `indices` must be the same as `grad` in first dimension. The type must be int32. Outputs: - - **var** (Tensor): Tensor, has the same shape and type as `var`. - - **accum** (Tensor): Tensor, has the same shape and type as `accum`. - - **linear** (Tensor): Tensor, has the same shape and type as `linear`. + - **var** (Tensor) - Tensor, has the same shape and type as `var`. + - **accum** (Tensor) - Tensor, has the same shape and type as `accum`. + - **linear** (Tensor) - Tensor, has the same shape and type as `linear`. Examples: >>> import mindspore @@ -4776,9 +4865,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): Outputs: Tuple of 3 Tensor, the updated parameters. - - **var** (Tensor): Tensor, has the same shape and type as `var`. - - **accum** (Tensor): Tensor, has the same shape and type as `accum`. - - **linear** (Tensor): Tensor, has the same shape and type as `linear`. + - **var** (Tensor) - Tensor, has the same shape and type as `var`. + - **accum** (Tensor) - Tensor, has the same shape and type as `accum`. + - **linear** (Tensor) - Tensor, has the same shape and type as `linear`. Examples: >>> import mindspore diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index a58403f883..ef3751f3e2 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -93,14 +93,20 @@ class BoundingBoxEncode(PrimitiveWithInfer): @prim_attr_register def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): - validator.check_value_type('means', means, [tuple], self.name) - validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_value_type('means', means, [tuple, list], self.name) + validator.check_value_type('stds', stds, [tuple, list], self.name) + for i, value in enumerate(means): + validator.check_value_type("means[%d]" % i, value, [float], self.name) + for i, value in enumerate(stds): + validator.check_value_type("stds[%d]" % i, value, [float], self.name) validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) def infer_shape(self, anchor_box, groundtruth_box): validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ, self.name) + validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name) + validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name) validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name) return anchor_box @@ -141,8 +147,12 @@ class BoundingBoxDecode(PrimitiveWithInfer): @prim_attr_register def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016): - validator.check_value_type('means', means, [tuple], self.name) - validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_value_type('means', means, [tuple, list], self.name) + validator.check_value_type('stds', stds, [tuple, list], self.name) + for i, value in enumerate(means): + validator.check_value_type("means[%d]" % i, value, [float], self.name) + for i, value in enumerate(stds): + validator.check_value_type("stds[%d]" % i, value, [float], self.name) validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) @@ -152,6 +162,8 @@ class BoundingBoxDecode(PrimitiveWithInfer): def infer_shape(self, anchor_box, deltas): validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name) + validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name) + validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name) validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name) return anchor_box @@ -262,8 +274,6 @@ class IOU(PrimitiveWithInfer): return iou def infer_dtype(self, anchor_boxes, gt_boxes): - args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes} - validator.check_tensor_type_same(args, (mstype.float16,), self.name) return anchor_boxes @@ -379,8 +389,8 @@ class CheckBprop(PrimitiveWithInfer): validator.check_value_type('grads', xshapes, (tuple,), tips) validator.check_value_type('params', yshapes, (tuple,), tips) if len(xshapes) < len(yshapes): - raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," - f" but got {len(xshapes)}.") + raise ValueError(f"{tips}, the size of output should be {len(yshapes)}," + f" but got {len(xshapes)}.") checking_range = len(yshapes) for i in range(checking_range): xshape = xshapes[i] @@ -388,8 +398,8 @@ class CheckBprop(PrimitiveWithInfer): if not xshape or not yshape: continue if xshape != yshape: - raise TypeError(f"{tips}, the shape of {i}th output should be {yshape}," - f" but got {xshape}.") + raise ValueError(f"{tips}, the shape of {i}th output should be {yshape}," + f" but got {xshape}.") return xshapes def infer_dtype(self, xdtypes, ydtypes): @@ -397,8 +407,8 @@ class CheckBprop(PrimitiveWithInfer): validator.check_value_type('grads', xdtypes, (tuple,), tips) validator.check_value_type('params', ydtypes, (tuple,), tips) if len(xdtypes) < len(ydtypes): - raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," - f" but got {len(xdtypes)}.") + raise ValueError(f"{tips}, the size of output should be {len(ydtypes)}," + f" but got {len(xdtypes)}.") checking_range = len(ydtypes) for i in range(checking_range): xdtype = xdtypes[i] @@ -511,6 +521,7 @@ class Push(PrimitiveWithInfer): @prim_attr_register def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None): """init Push""" + self.add_prim_attr("primitive_target", "CPU") self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key']) def infer_shape(self, inputs, shapes): @@ -534,6 +545,7 @@ class Pull(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init Pull""" + self.add_prim_attr("primitive_target", "CPU") self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output']) def infer_shape(self, key_shape, weight_shape): diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 59b28cf09d..a303f58dc7 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -337,13 +337,13 @@ class RandomChoiceWithMask(PrimitiveWithInfer): seed2 (int): Random seed2. Default: 0. Inputs: - - **input_x** (Tensor[bool]) - The input tensor. + - **input_x** (Tensor[bool]) - The input tensor. The input tensor rank should be >= 1 and <= 5. Outputs: Two tensors, the first one is the index tensor and the other one is the mask tensor. - - **index** (Tensor) - The output has shape between 2-D and 5-D. - - **mask** (Tensor) - The output has shape 1-D. + - **index** (Tensor) - The output shape is 2-D. + - **mask** (Tensor) - The output shape is 1-D. Examples: >>> rnd_choice_mask = P.RandomChoiceWithMask() @@ -361,6 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name) return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): @@ -379,7 +380,7 @@ class RandomCategorical(PrimitiveWithInfer): Inputs: - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes]. - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed. - - **seed** (int) - Random seed. Default: 0. + - **seed** (int) - Random seed. Default: 0. Only constant values is allowed. Outputs: - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples]. @@ -397,6 +398,7 @@ class RandomCategorical(PrimitiveWithInfer): >>> net = Net(8) >>> output = net(Tensor(x)) """ + @prim_attr_register def __init__(self, dtype=mstype.int64): """Init RandomCategorical""" @@ -424,3 +426,54 @@ class RandomCategorical(PrimitiveWithInfer): return {'shape': (x_shape), 'dtype': (self.dtype), 'value': None} + + +class Multinomial(PrimitiveWithInfer): + r""" + Returns a tensor sampled from the multinomial probability distribution located in the corresponding + row of tensor input. + + Note: + The rows of input do not need to sum to one (in which case we use the values as weights), + but must be non-negative, finite and have a non-zero sum. + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims. + - **num_samples** (int) - number of samples to draw. + + Outputs: + Tensor. have the same rows with input, each row has num_samples sampled indices. + + Examples: + >>> input = Tensor([0., 9., 4., 0.], mstype.float32) + >>> multinomial = P.Multinomial(seed=10) + >>> output = multinomial(input, 2) + """ + + @prim_attr_register + def __init__(self, seed=0): + """init""" + validator.check_value_type("seed", seed, [int], self.name) + self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) + + def __infer__(self, inputs, num_samples): + input_shape = inputs["shape"] + if len(input_shape) != 1 and len(input_shape) != 2: + raise ValueError("input dim must be 1 or 2") + validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) + num_samples_value = num_samples["value"] + if num_samples_value is None: + raise ValueError(f"For {self.name}, shape nust be const") + validator.check_value_type("num_samples", num_samples_value, [int], self.name) + validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None) + y_shape = (num_samples_value,) + if len(input_shape) == 2: + y_shape = (input_shape[0], num_samples_value) + out = { + "shape": y_shape, + "dtype": mstype.int32, + "value": None} + return out diff --git a/mindspore/ops/operations/sparse_ops.py b/mindspore/ops/operations/sparse_ops.py new file mode 100644 index 0000000000..6aaa7f271f --- /dev/null +++ b/mindspore/ops/operations/sparse_ops.py @@ -0,0 +1,55 @@ +# coding: utf-8 + +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. +# ============================================================================ + +"""Operators for sparse operators.""" + +from ..._checkparam import Validator as validator +from ...common import dtype as mstype +from ..primitive import PrimitiveWithInfer, prim_attr_register + +class SparseToDense(PrimitiveWithInfer): + """ + Convert a sparse representation into a dense tensor. + + Inputs: + - **indices** (Tensor) - The indices of sparse representation. + - **values** (Tensor) - Values corresponding to each row of indices. + - **dense_shape** (tuple) - A int tuple which specifies the shape of dense tensor. + + Returns: + Tensor, the shape of tensor is dense_shape. + + Examples: + >>> indices = Tensor([[0, 1], [1, 2]]) + >>> values = Tensor([1, 2], dtype=ms.float32) + >>> dense_shape = (3, 4) + >>> out = P.SparseToDense()(indices, values, dense_shape) + """ + + @prim_attr_register + def __init__(self): + """init index_select""" + self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output']) + + def __infer__(self, indices, values, dense_shape): + validator.check_subclass("indices", indices['dtype'], mstype.tensor, self.name) + validator.check_subclass("values", values['dtype'], mstype.tensor, self.name) + out = {'shape': dense_shape['value'], + 'dtype': values['dtype'], + 'value': None} + return out diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py deleted file mode 100644 index d2de0190a6..0000000000 --- a/mindspore/ops/operations/thor_ops.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""thor_ops""" -from ..primitive import prim_attr_register, PrimitiveWithInfer -from ...common import dtype as mstype - - -__all__ = ["CusBatchMatMul", - "CusCholeskyTrsm", - "CusFusedAbsMax1", - "CusImg2Col", - "CusMatMulCubeDenseLeft", - "CusMatMulCubeFraczRightMul", - "CusMatMulCube", - "CusMatrixCombine", - "CusTranspose02314", - "CusMatMulCubeDenseRight", - "CusMatMulCubeFraczLeftCast", - ] - - -class CusBatchMatMul(PrimitiveWithInfer): - """ - Multiplies matrix `a` by matrix `b` in batch. - - The rank of input tensors must be `3`. - - Inputs: - - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. - - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If - `transpose_b` is True. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N, D, D)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) - >>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) - >>> cus_batch_matmul = P.CusBatchMatMul() - >>> output = cus_batch_matmul(input_x, input_y) - """ - - @prim_attr_register - def __init__(self): - """init CusBatchMatMul""" - self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.batch_matmul_impl import CusBatchMatMul - - def infer_shape(self, data1_shape, data2_shape): - return data1_shape - - def infer_dtype(self, data1_dtype, data2_dtype): - return data1_dtype - - -class CusCholeskyTrsm(PrimitiveWithInfer): - """ - L * LT = A. - LT * (LT)^-1 = I. - return (LT)^-1. - Only compute the res of the diag part of input matrix with dim 128. - The rank of input tensors must be `2`. - - Inputs: - - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, N)`. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N // Split_dim, Split_dim, Split_dim)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32) - >>> cus_choleskytrsm = P.CusCholeskyTrsm() - >>> output = matmul(input_x) - """ - - @prim_attr_register - def __init__(self): - """init CusCholeskyTrsm""" - self.init_prim_io_names(inputs=['x1'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import CusCholeskyTrsm - - def infer_shape(self, data1_shape): - ll = [] - m, _ = data1_shape - if m >= 128: - ll = [m // 128, 128, 128] - else: - ll = [1, 64, 64] - return ll - - def infer_dtype(self, data1_dtype): - return data1_dtype - - -class CusFusedAbsMax1(PrimitiveWithInfer): - """ - Compute the abs max of Tensor input. - - The rank of input tensors must be `4` or `2`. - Inputs: - - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N0, M0, N1, M1)` - or math:`(32, 64)`. - Outputs: - Tensor, the shape of the output tensor is :math:`(32, 64)` or math:`(1, )`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32) - >>> cus_fused_abs_max1 = P.CusFusedAbsMax1() - >>> output = cus_fused_abs_max1(input_x) - """ - - @prim_attr_register - def __init__(self, origin_shape=[-1, -1]): - """init CusFusedAbsMax1""" - self.init_prim_io_names(inputs=['x1'], outputs=['y']) - self.origin_shape = origin_shape - from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1 - - def infer_shape(self, data1_shape): - ll = [] - if len(data1_shape) == 2: - ll = [1,] - else: - ll = [32, 64] - return ll - - def infer_dtype(self, data1_dtype): - return data1_dtype - - -class CusImg2Col(PrimitiveWithInfer): - """ - Img2col the feature map and the result in reorganized in NC1HWC0. - - Args: - - **strides** (listInt) - the stride of the ops. - - **ksizes** (listInt) - the kernel size of the ops. - Inputs: - - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`. - Outputs: - Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`. - Examples: - >>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16) - >>> cusimg2col = P.CusImg2Col() - >>> output = cusimg2col(input_x) - """ - - @prim_attr_register - def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"): - """init CusImg2Col""" - self.init_prim_io_names(inputs=['x1'], outputs=['y']) - self.ksizes = ksizes - self.strides = strides - self.dilates = dilates - self.mode = mode - from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col - - def infer_shape(self, data1_shape): - bs, c, h, w = data1_shape - _, stride_h, stride_w, _ = self.strides - _, k_w, k_h, _ = self.ksizes - # assert m == n - c0 = 16 - c1 = c // 16 - if c1 == 0: - c1 = 1 - shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0] - return shape - - def infer_dtype(self, data1_dtype): - return data1_dtype - - -class CusMatMulCubeDenseLeft(PrimitiveWithInfer): - """ - Multiplies matrix `a` by matrix `b`. - - The rank of input_x1 must be `4`, the fractal format of the normal matrix. - The rank of input_x2 must be `2`. - - Inputs: - - **input_x1** (Tensor) - The first tensor to be multiplied. - The shape of the tensor is :math:`(N0, M0, N1, M1)`. - - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`. - Outputs: - Tensor, the shape of the output tensor is :math:`(N, C)`. - Examples: - >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) - >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) - >>> matmulcubedenseleft = P.CusMatMulCubeDenseLeft() - >>> output = matmulcubedenseleft(input_x, input_y) - """ - - @prim_attr_register - def __init__(self): - """init CusMatMulCubeDenseLeft""" - self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import CusMatMulCubeDenseLeft - - def infer_shape(self, data1_shape, data2_shape): - return data2_shape - - def infer_dtype(self, data1_dtype, data2_dtype): - return mstype.float16 - - -class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): - """ - Multiplies matrix `a` by matrix `b` and muls the result by scalar `c`. - - The rank of input_x1 tensors must be `2`. - The rank of input_x2 tensors must be `4`. - - Inputs: - - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. - - **input_x2** (Tensor) - The second tensor to be multiplied. - The shape of the tensor is :math:`(C1, M1, C0, M0)`. - - **input_x3** (Tensor) - The third tensor to be multiplied. The shape of the tensor if :math`(1, )`. - Outputs: - Tensor, the shape of the output tensor is :math:`(N, M)`. - Examples: - >>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16) - >>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) - >>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16) - >>> cusmatmulfraczrightmul = P.CusMatMulCubeFraczRightMul() - >>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3) - """ - - @prim_attr_register - def __init__(self): - """init CusMatMulCubeFraczRightMul""" - self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul - - def infer_shape(self, data1_shape, data2_shape, data3_shape): - return data1_shape - - def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): - return mstype.float32 - - -class CusMatMulCube(PrimitiveWithInfer): - """ - Multiplies matrix `a` by matrix `b`. - - The rank of input tensors must be `2`. - - Args: - transpose_a (bool): If True, `a` is transposed before multiplication. Default: False. - transpose_b (bool): If True, `b` is transposed before multiplication. Default: False. - - Inputs: - - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If - `transpose_a` is True, its shape should be :math:`(N, C)` after transposing. - - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If - `transpose_b` is True, its shape should be :math:`(C, M)` after transpose. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N, M)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) - >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) - >>> cusmatmulcube = P.CusMatMulCube() - >>> output = matmul(input_x, input_y) - """ - - @prim_attr_register - def __init__(self, transpose_a=False, transpose_b=False): - """init CusMatMulCube""" - self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - self.transpose_a = transpose_a - self.transpose_b = transpose_b - from mindspore.ops._op_impl._custom_op.matmul_cube_impl import CusMatMulCube - - def infer_shape(self, data1_shape, data2_shape): - # shape = [1, data1_shape[1], data2_shape[2], 16, 16] - # return shape - if self.transpose_a: - k1, m = data1_shape - else: - m, k1 = data1_shape - if self.transpose_b: - n, k2 = data2_shape - else: - k2, n = data2_shape - assert k1 == k2 - shape = [m, n] - return shape - - def infer_dtype(self, data1_dtype, data2_dtype): - return mstype.float32 - - -class CusMatrixCombine(PrimitiveWithInfer): - """ - move the batch matrix to result matrix diag part. - The rank of input tensors must be `3`. - - Inputs: - - **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N * D, N * D)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) - >>> cusmatrixcombine = P.CusMatrixCombine() - >>> output = cusmatrixcombine(input_x) - """ - - @prim_attr_register - def __init__(self): - """init CusMatrixCombine""" - self.init_prim_io_names(inputs=['x'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.matrix_combine_impl import CusMatrixCombine - - def infer_shape(self, data_shape): - a, b, c = data_shape - shape = [a * b, a * c] - - return shape - - def infer_dtype(self, data_dtype): - return data_dtype - - -class CusTranspose02314(PrimitiveWithInfer): - """ - Permute input tensor with perm (0, 2, 3, 1, 4) - - The rank of input tensors must be `5` with format NC1HWC0. - - Inputs: - - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N, H, W, C1, C0)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16) - >>> custranspose02314 = P.CusTranspose02314() - >>> output = custranspose02314(input_x) - """ - - @prim_attr_register - def __init__(self): - """init CusTranspose02314""" - self.init_prim_io_names(inputs=['x1'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314 - def get_bprop(self): - def bprop(x, out, dout): - return (C.zeros_like(x),) - - return bprop - - def infer_shape(self, data1_shape): - assert len(data1_shape) == 4 - n, c, h, w = data1_shape - c0 = 16 - c1 = c // 16 - shape = (n * h * w, c1 * c0) - return shape - - def infer_dtype(self, data1_dtype): - return data1_dtype - - -class CusMatMulCubeDenseRight(PrimitiveWithInfer): - """ - Multiplies matrix `a` by matrix `b`. - - The rank of input_x1 tensor must be `2`. - The rank of input_x2 tensor must be `4`. - - Inputs: - - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. - - **input_y** (Tensor) - The second tensor to be multiplied. - The shape of the tensor is :math:`(C1, M1, M0, C0)`. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N, M)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) - >>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) - >>> cusmatmulcubedenseright = P.CusMatMulCubeDenseRight() - >>> output = cusmatmulcubedenseright(input_x, input_y) - """ - - @prim_attr_register - def __init__(self): - """init CusMatMulCubeDenseRight""" - self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import CusMatMulCubeDenseRight - - def infer_shape(self, data1_shape, data2_shape, data3_shape): - return data1_shape - - def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): - return mstype.float32 - - -class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): - """ - Multiplies matrix `a` by matrix `b`. - - The rank of input_x1 tensor must be `4`. - The rank of input_x2 tensors must be `2`. - - Inputs: - - **input_x1** (Tensor) - The first tensor to be multiplied. - The shape of the tensor is :math:`(C1, N1, N0, C0)`. - - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. - - Outputs: - Tensor, the shape of the output tensor is :math:`(N, M)`. - - Examples: - >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) - >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) - >>> cusmatmulcubefraczleftcast = P.CusMatMulCubeFraczLeftCast() - >>> output = cusmatmulcubefraczleftcast(input_x, input_y) - """ - - @prim_attr_register - def __init__(self): - """init CusMatMulCubeFraczLeftCast""" - self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast - - def infer_shape(self, data1_shape, data2_shape): - return data2_shape - - def infer_dtype(self, data1_dtype, data2_dtype): - return mstype.float16 diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index cb34e9ff24..21924bb5a3 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -18,6 +18,9 @@ import inspect import copy from mindspore.common.api import _wrap_func +from mindspore.common import Parameter +from mindspore.common._register_for_tensor import tensor_operator_registry +from mindspore import context from .._c_expression import Primitive_, real_run_op, prim_type from .._c_expression import signature_rw as sig_rw from .._c_expression import signature_kind as sig_kind @@ -26,10 +29,10 @@ from .._c_expression import signature_dtype as sig_dtype class Primitive(Primitive_): """ - Primitive is base class for primitives in python. + Primitive is the base class of primitives in python. Args: - name (str): Name for current Primitive. + name (str): Name for the current Primitive. Examples: >>> add = Primitive('add') @@ -49,6 +52,7 @@ class Primitive(Primitive_): self.name = name self.attrs = {} self.init_attrs = {"name": name} + self._update_parameter = False Primitive_.__init__(self, name, self) if hasattr(self.__class__, '__mindspore_signature__'): sig = self._fill_signature(self.__class__.__mindspore_signature__) @@ -108,11 +112,11 @@ class Primitive(Primitive_): def set_strategy(self, strategy): """ - Adds strategy to primitive attribute. + Add strategies to primitive attribute. Note: - Valid only in semi auto parallel or auto parallel mode. - In other parallel modes, strategies will be ignored if set. + It is valid only in semi auto parallel or auto parallel mode. + In other parallel modes, strategies set here will be ignored. Args: strategy (tuple): Strategy describes the distributed parallel mode of the current primitive. @@ -122,10 +126,10 @@ class Primitive(Primitive_): def set_prim_instance_name(self, instance_name): """ - Sets instance name to primitive operator. + Set instance name to primitive operator. Note: - Will be called by default when user defines primitive operator. + It will be called by default when user defines primitive operator. Args: instance_name (str): Instance name of primitive operator set by user. @@ -135,6 +139,8 @@ class Primitive(Primitive_): return self def __getattr__(self, item): + if item == 'infer_dynamic_shape': + return None if item in super().get_attr_dict(): return super().get_attr_dict()[item] if item in self.attrs: @@ -143,14 +149,14 @@ class Primitive(Primitive_): def check_elim(self, *args): """ - Check whether or not certain inputs should go into backend. Subclass in need should override this method. + Check if certain inputs should go to the backend. Subclass in need should override this method. Args: *args(Primitive args): Same as arguments of current Primitive. Returns: - A tuple of two elements, first element indicates whether or not we should filter out current arguments; - seconde element is the output in case where we should filter out the arguments. + A tuple consisting of two elements. The first element indicates whether we should filter out current + arguments; the seconde element is the output if we need to filter out the arguments. """ return (False, None) @@ -178,7 +184,7 @@ class Primitive(Primitive_): def init_prim_io_names(self, inputs, outputs): """ - Initializes inputs and outpus name of Tensor or attributes. + Initializes the name of inputs and outpus of Tensor or attributes. Args: inputs (list[str]): list of inputs names. @@ -189,18 +195,23 @@ class Primitive(Primitive_): # for checking output number with kernel implementation self.add_prim_attr("output_names", outputs) + @property + def update_parameter(self): + """ Whether the primitive will update the value of parameter.""" + return self._update_parameter + class PrimitiveWithInfer(Primitive): """ - PrimitiveWithInfer is base class for primitives in python and defines functions for infer of tracks in python. + PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python. There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority - to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape - and type infer logic. The infer_value() is used for constant propagation. + to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer + logic of the shape and type. The infer_value() is used for constant propagation. Args: - name (str): Name for current Primitive. + name (str): Name of the current Primitive. Examples: >>> # init a Primitive class with infer @@ -268,27 +279,63 @@ class PrimitiveWithInfer(Primitive): args (Any): value of inputs. Return: - Value of outputs. Return `None` for, cat not infer the value at compile time. + Value of outputs. Return `None`, the value can not be inferred at compile time in this case. """ return None def __infer__(self, *args): """Infer shape, type, and value at the same time by using dictionary as arguments.""" + is_graph_mode = context.get_context("mode") == context.GRAPH_MODE + fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None) + if is_graph_mode and fn_infer_dynamic_shape is not None: + out = fn_infer_dynamic_shape(*args) + tracks = ['dtype', 'value'] + for track in tracks: + fn = getattr(self, 'infer_' + track) + # fn may return None + out[track] = fn(*(x[track] for x in args)) + return out + tracks = ['dtype', 'shape', 'value'] out = {} for track in tracks: fn = getattr(self, 'infer_' + track) # fn may return None out[track] = fn(*(x[track] for x in args)) - return out + + # in non-graph_mode, it is not necessary to infer min/max shape + if not is_graph_mode: + return out + + def get_specified_shape(elems, attr): + has_specified_shape = False + ret_vals = [] + for elem in elems: + if attr in elem: + has_specified_shape = True + ret_vals.append(elem[attr]) + else: + ret_vals.append(elem['shape']) + return has_specified_shape, tuple(ret_vals) + + has_min_shape, min_shapes = get_specified_shape(args, 'min_shape') + has_max_shape, max_shapes = get_specified_shape(args, 'max_shape') + if not (has_min_shape or has_max_shape): + return out + if has_min_shape and has_max_shape: + fn_infer_shape = getattr(self, 'infer_shape') + out['min_shape'] = fn_infer_shape(*min_shapes) + out['max_shape'] = fn_infer_shape(*max_shapes) + return out + raise ValueError('Input args has invalid dynamic shape, args info: {args}') def prim_attr_register(fn): """ Primitive attributes register. - Registering the decorator of the built-in operator primitive __init__ - function will add all the parameters of __init__ as operator attributes. + Register the decorator of the built-in operator primitive '__init__'. + The function will add all the parameters of '__init__' as operator attributes. Args: fn (function): __init__ function of primitive. @@ -317,17 +364,17 @@ def prim_attr_register(fn): def constexpr(fn=None, get_instance=True, name=None): """ - Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function - to compute between constant variable and used in constructß. + Make a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function to + compute constant value using the constants in the constructor. Args: fn (function): A `fn` use as the infer_value of the output operator. - get_instance (bool): If true, returns the instance of operator, else returns the operator class. + get_instance (bool): If true, return the instance of operator, otherwise return the operator class. name (str): Defines the operator name. If `name` is None, use the function name as op name. Examples: >>> a = (1, 2) - >>> # make a operator to calculate tuple len + >>> # make an operator to calculate tuple len >>> @constexpr >>> def tuple_len(x): >>> return len(x) @@ -344,7 +391,7 @@ def constexpr(fn=None, get_instance=True, name=None): def __init__(self): op_name = name if name else fn.__name__ PrimitiveWithInfer.__init__(self, op_name) - self.const_value = True + self.set_is_const_value(True) def infer_value(self, *args): return fn(*args) @@ -359,7 +406,20 @@ def constexpr(fn=None, get_instance=True, name=None): @_wrap_func def _run_op(obj, op_name, args): """Single op execution function supported by ge in PyNative mode.""" - output = real_run_op(obj, op_name, args) + cast = tensor_operator_registry.get("cast") + if op_name == "Cast" or obj.update_parameter: + cast_args = args + else: + cast_args = list() + for arg in args: + if isinstance(arg, Parameter): + if arg.cast_type: + cast_args.append(cast(arg, arg.cast_type)) + else: + cast_args.append(arg) + else: + cast_args.append(arg) + output = real_run_op(obj, op_name, tuple(cast_args)) if not output: raise RuntimeError("Pynative run op %s failed!" % op_name) if len(output) == 1: diff --git a/mindspore/ops/vm_impl_registry.py b/mindspore/ops/vm_impl_registry.py index 6217616483..3265d539b2 100644 --- a/mindspore/ops/vm_impl_registry.py +++ b/mindspore/ops/vm_impl_registry.py @@ -37,10 +37,10 @@ Examples: def get_vm_impl_fn(prim): """ - Gets vm function by primitive obj or primitive name for c++ + Get the virtual implementation function by a primitive object or primitive name. Args: - prim (Union[Primitive, str]): primitive obj or primitive name for operator register by name. + prim (Union[Primitive, str]): primitive object or name for operator register. Returns: function, vm function diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 93fe233855..7fd366c2f0 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext from mindspore._checkparam import args_type_check _MAX_GROUP_NAME_LEN = 127 +_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" +_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" class _AutoParallelContext: @@ -267,7 +269,7 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_parameter_broadcast_is_set() - def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): + def set_all_reduce_fusion_split_indices(self, indices, group=""): """ Set allreduce fusion strategy by parameters indices. @@ -294,11 +296,17 @@ class _AutoParallelContext: else: raise TypeError('Group must be a python str') + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME + self._context_handle.set_all_reduce_fusion_split_indices(indices, group) if context.get_context("device_target") == "Ascend": _set_fusion_strategy_by_idx(indices) - def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): + def get_all_reduce_fusion_split_indices(self, group=""): """ Get allreduce fusion split indices. @@ -318,9 +326,15 @@ class _AutoParallelContext: raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') else: raise TypeError('Group must be a python str') + + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME return self._context_handle.get_all_reduce_fusion_split_indices(group) - def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): + def set_all_reduce_fusion_split_sizes(self, sizes, group=""): """ Set allreduce fusion strategy by parameters data sizes. @@ -347,11 +361,17 @@ class _AutoParallelContext: else: raise TypeError('Group must be a python str') + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME + self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) if context.get_context("device_target") == "Ascend": _set_fusion_strategy_by_size(sizes) - def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): + def get_all_reduce_fusion_split_sizes(self, group=""): """ Get allreduce fusion split sizes. @@ -371,6 +391,12 @@ class _AutoParallelContext: raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') else: raise TypeError('Group must be a python str') + + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME return self._context_handle.get_all_reduce_fusion_split_sizes(group) def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): @@ -513,7 +539,7 @@ def _set_auto_parallel_context(**kwargs): strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. - enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False. + enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. diff --git a/mindspore/parallel/_dp_allreduce_fusion.py b/mindspore/parallel/_dp_allreduce_fusion.py index 3c7039dbd6..ad78595d13 100644 --- a/mindspore/parallel/_dp_allreduce_fusion.py +++ b/mindspore/parallel/_dp_allreduce_fusion.py @@ -16,8 +16,6 @@ import ctypes -from mindspore import log as logger - _MAX_GROUP_NAME_LEN = 127 _HCCL_LIB = 'libhccl.so' @@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so' def _load_lib(): try: hccl_lib = ctypes.CDLL(_HCCL_LIB) - except RuntimeError: - logger.error('Get hccl lib error') + except Exception: + raise RuntimeError('Get hccl lib error') return hccl_lib @@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): try: lib_ctype = _load_lib() except RuntimeError: - logger.error('Load HCCL lib failed') - + import hccl_test.manage.api as hccl + hccl.set_fusion_strategy_by_idx() + return if isinstance(group, (str)): group_len = len(group) if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0): @@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): try: lib_ctype = _load_lib() except RuntimeError: - logger.error('Load HCCL lib failed') + import hccl_test.manage.api as hccl + hccl.set_fusion_strategy_by_size() + return if isinstance(group, (str)): group_len = len(group) if group_len > _MAX_GROUP_NAME_LEN or group_len == 0: diff --git a/mindspore/parallel/_tensor.py b/mindspore/parallel/_tensor.py index fca8b88920..598046f66a 100644 --- a/mindspore/parallel/_tensor.py +++ b/mindspore/parallel/_tensor.py @@ -229,8 +229,8 @@ def _load_tensor_by_layout(tensor, layout): """ if not isinstance(layout, list): raise TypeError("The layout should be list! layout is {}".format(layout)) - if len(layout) != 3: - raise ValueError("The length of layout must be 3! layout is {}".format(layout)) + if len(layout) < 3: + raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout)) dev_mat = layout[0] tensor_map = layout[1] if tensor.size() == 1: @@ -290,3 +290,37 @@ def _reshape_param_data(param_data, dev_mat, tensor_map): tensor_slices_new = tensor_slices_new_inner return Tensor(tensor_slices_new[0]) + +def _reshape_param_data_with_weight(param_data, dev_mat, field_size): + """ + Combine param slice by the device matrix, used in model parallel scenario. + + Args: + param_data (Tensor): The tensor to be reshaped and rearrangement, + generated from all the device from AllGatherParamNet. + dev_mat (list): The device matrix of devices. + Returns: + Tensor, the combined tensor which with the whole data value. + + Examples: + >>> param_data = _allgather_param_net(param_data) + >>> dev_mat = [2, 2] + >>> field_size = [39] + >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size) + """ + device_count = 1 + for dim in dev_mat: + device_count *= dim + + tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0) + tensor_slices_col = [] + for i in range(len(tensor_slices[0][0])): + tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1) + for j in range(1, device_count): + tensor_slices_new = np.concatenate((tensor_slices_new,\ + np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1) + tensor_slices_col.append(tensor_slices_new) + new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1) + for i in range(1, len(tensor_slices_col)): + new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1) + return Tensor(new_tensor) diff --git a/mindspore/profiler/__init__.py b/mindspore/profiler/__init__.py new file mode 100644 index 0000000000..a77d94f3c8 --- /dev/null +++ b/mindspore/profiler/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Profiler Module Introduction. + +This module provides Python APIs to enable the profiling of MindSpore neural networks. +Users can import the mindspore.profiler.Profiler, initialize the Profiler object to start profiling, +and use Profiler.analyse() to stop profiling and analyse the results. +To visualize the profiling results, users can open mindspore Web, find the corresponding run +and click the profile link. +Now, Profiler supports the AICore operator analysis. +""" +from mindspore.profiler.profiling import Profiler + +__all__ = ["Profiler"] diff --git a/model_zoo/deeplabv3/src/utils/__init__.py b/mindspore/profiler/common/__init__.py similarity index 100% rename from model_zoo/deeplabv3/src/utils/__init__.py rename to mindspore/profiler/common/__init__.py diff --git a/mindspore/profiler/common/exceptions/__init__.py b/mindspore/profiler/common/exceptions/__init__.py new file mode 100644 index 0000000000..e30774307c --- /dev/null +++ b/mindspore/profiler/common/exceptions/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindspore/profiler/common/exceptions/error_code.py b/mindspore/profiler/common/exceptions/error_code.py new file mode 100644 index 0000000000..0514f52dab --- /dev/null +++ b/mindspore/profiler/common/exceptions/error_code.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Profiler error code and messages.""" +from enum import unique, Enum + + +_GENERAL_MASK = 0b00001 << 7 +_PARSER_MASK = 0b00010 << 7 +_ANALYSER_MASK = 0b00011 << 7 + + +class ProfilerMgrErrors(Enum): + """Enum definition for profiler errors""" + +@unique +class ProfilerErrors(ProfilerMgrErrors): + """Profiler error codes.""" + # general error code + PARAM_VALUE_ERROR = 0 | _GENERAL_MASK + PATH_ERROR = 1 | _GENERAL_MASK + PARAM_TYPE_ERROR = 2 | _GENERAL_MASK + DIR_NOT_FOUND_ERROR = 3 | _GENERAL_MASK + FILE_NOT_FOUND_ERROR = 4 | _GENERAL_MASK + IO_ERROR = 5 | _GENERAL_MASK + + # parser error code + DEVICE_ID_MISMATCH_ERROR = 0 | _PARSER_MASK + RAW_FILE_ERROR = 1 | _PARSER_MASK + STEP_NUM_NOT_SUPPORTED_ERROR = 2 | _PARSER_MASK + JOB_ID_MISMATCH_ERROR = 3 | _PARSER_MASK + + # analyser error code + COLUMN_NOT_EXIST_ERROR = 0 | _ANALYSER_MASK + ANALYSER_NOT_EXIST_ERROR = 1 | _ANALYSER_MASK + DEVICE_ID_ERROR = 2 | _ANALYSER_MASK + OP_TYPE_ERROR = 3 | _ANALYSER_MASK + GROUP_CONDITION_ERROR = 4 | _ANALYSER_MASK + SORT_CONDITION_ERROR = 5 | _ANALYSER_MASK + FILTER_CONDITION_ERROR = 6 | _ANALYSER_MASK + COLUMN_NOT_SUPPORT_SORT_ERROR = 7 | _ANALYSER_MASK + PIPELINE_OP_NOT_EXIST_ERROR = 8 | _ANALYSER_MASK + + + + +@unique +class ProfilerErrorMsg(Enum): + """Profiler error messages.""" + # general error msg + PARAM_VALUE_ERROR = 'Param value error. {}' + PATH_ERROR = 'Path error. {}' + PARAM_TYPE_ERROR = 'Param type error. {}' + DIR_NOT_FOUND_ERROR = 'The dir <{}> not found.' + FILE_NOT_FOUND_ERROR = 'The file <{}> not found.' + IO_ERROR = 'Read or write file fail.' + + # parser error msg + DEVICE_ID_MISMATCH_ERROR = 'The device ID mismatch.' + RAW_FILE_ERROR = 'Raw file error. {}' + STEP_NUM_NOT_SUPPORTED_ERROR = 'The step num must be in {}' + JOB_ID_MISMATCH_ERROR = 'The job id in the parameter is not the same as ' \ + 'in the training trace file. ' + + # analyser error msg + COLUMN_NOT_EXIST_ERROR = 'The column {} does not exist.' + ANALYSER_NOT_EXIST_ERROR = 'The analyser {} does not exist.' + DEIVICE_ID_ERROR = 'The device_id in search_condition error, {}' + FILTER_CONDITION_ERROR = 'The filter_condition in search_condition error, {}' + OP_TYPE_ERROR = 'The op_type in search_condition error, {}' + GROUP_CONDITION_ERROR = 'The group_condition in search_condition error, {}' + SORT_CONDITION_ERROR = 'The sort_condition in search_condition error, {}' + COLUMN_NOT_SUPPORT_SORT_ERROR = 'The column {} does not support to sort.' + PIPELINE_OP_NOT_EXIST_ERROR = 'The minddata pipeline operator {} does not exist.' diff --git a/mindspore/profiler/common/exceptions/exceptions.py b/mindspore/profiler/common/exceptions/exceptions.py new file mode 100644 index 0000000000..d5821d5954 --- /dev/null +++ b/mindspore/profiler/common/exceptions/exceptions.py @@ -0,0 +1,287 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Definition of error code and relative messages in profiler module.""" +from mindspore.profiler.common.exceptions.error_code import ProfilerErrors, \ + ProfilerErrorMsg + + +class ProfilerException(Exception): + """ + Base class for Profilier exception. + + Examples: + >>> raise ProfilerException(GeneralErrors.PATH_NOT_EXISTS_ERROR, 'path not exists') + """ + + RUNTIME = 1 + TYPE = 1 + LEVEL = 0 + SYSID = 42 + + def __init__(self, error, message, http_code=500): + """ + Initialization of ProfilerException. + + Args: + error (Enum): Error value for specified case. + message (str): Description for exception. + http_code (int): Http code for exception. Default is 500. + """ + if isinstance(message, str): + message = ' '.join(message.split()) + super(ProfilerException, self).__init__(message) + self.error = error + self.message = message + self.http_code = http_code + + + @property + def error_code(self): + """ + Transform exception no to Profiler error code. + + code compose(4bytes): + runtime 2bits, type 2bits, level 3bits, sysid 8bits, modid 5bits, value 12bits. + + num = ((0xFF & runtime) << 30) \ + | ((0xFF & type) << 28) \ + | ((0xFF & level) << 25) \ + | ((0xFF & sysid) << 17) \ + | ((0xFF & modid) << 12) \ + | (0x0FFF & value) + + Returns: + str, Hex string representing the composed Profiler error code. + """ + num = (((0xFF & self.RUNTIME) << 30) + | ((0xFF & self.TYPE) << 28) + | ((0xFF & self.LEVEL) << 25) + | ((0xFF & self.SYSID) << 17) + | ((0xFF & 6) << 12) + | (0x0FFF & self.error.value)) + + return hex(num)[2:].zfill(8).upper() + + def __str__(self): + return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code, self.message) + + +class ProfilerParamValueErrorException(ProfilerException): + """The parameter value error in profiler module.""" + + def __init__(self, msg): + super(ProfilerParamValueErrorException, self).__init__( + error=ProfilerErrors.PARAM_VALUE_ERROR, + message=ProfilerErrorMsg.PARAM_VALUE_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerPathErrorException(ProfilerException): + """The path error in profiler module.""" + + def __init__(self, msg): + super(ProfilerPathErrorException, self).__init__( + error=ProfilerErrors.PATH_ERROR, + message=ProfilerErrorMsg.PATH_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerParamTypeErrorException(ProfilerException): + """The parameter type error in profiler module.""" + + def __init__(self, msg): + super(ProfilerParamTypeErrorException, self).__init__( + error=ProfilerErrors.PARAM_TYPE_ERROR, + message=ProfilerErrorMsg.PARAM_TYPE_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerDirNotFoundException(ProfilerException): + """The dir not found exception in profiler module.""" + + def __init__(self, msg): + super(ProfilerDirNotFoundException, self).__init__( + error=ProfilerErrors.DIR_NOT_FOUND_ERROR, + message=ProfilerErrorMsg.DIR_NOT_FOUND_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerFileNotFoundException(ProfilerException): + """The file not found exception in profiler module.""" + + def __init__(self, msg): + super(ProfilerFileNotFoundException, self).__init__( + error=ProfilerErrors.FILE_NOT_FOUND_ERROR, + message=ProfilerErrorMsg.FILE_NOT_FOUND_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerIOException(ProfilerException): + """The IO exception in profiler module.""" + + def __init__(self): + super(ProfilerIOException, self).__init__( + error=ProfilerErrors.IO_ERROR, + message=ProfilerErrorMsg.IO_ERROR.value, + http_code=400 + ) + + +class ProfilerDeviceIdMismatchException(ProfilerException): + """The device id mismatch exception in profiler module.""" + + def __init__(self): + super(ProfilerDeviceIdMismatchException, self).__init__( + error=ProfilerErrors.DEVICE_ID_MISMATCH_ERROR, + message=ProfilerErrorMsg.DEVICE_ID_MISMATCH_ERROR.value, + http_code=400 + ) + + +class ProfilerRawFileException(ProfilerException): + """The raw file exception in profiler module.""" + + def __init__(self, msg): + super(ProfilerRawFileException, self).__init__( + error=ProfilerErrors.RAW_FILE_ERROR, + message=ProfilerErrorMsg.RAW_FILE_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerColumnNotExistException(ProfilerException): + """The column does not exist exception in profiler module.""" + + def __init__(self, msg): + super(ProfilerColumnNotExistException, self).__init__( + error=ProfilerErrors.COLUMN_NOT_EXIST_ERROR, + message=ProfilerErrorMsg.COLUMN_NOT_EXIST_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerAnalyserNotExistException(ProfilerException): + """The analyser in profiler module.""" + + def __init__(self, msg): + super(ProfilerAnalyserNotExistException, self).__init__( + error=ProfilerErrors.ANALYSER_NOT_EXIST_ERROR, + message=ProfilerErrorMsg.ANALYSER_NOT_EXIST_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerDeviceIdException(ProfilerException): + """The parameter device_id error in profiler module.""" + + def __init__(self, msg): + super(ProfilerDeviceIdException, self).__init__( + error=ProfilerErrors.DEVICE_ID_ERROR, + message=ProfilerErrorMsg.DEIVICE_ID_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerOpTypeException(ProfilerException): + """The parameter op_type error in profiler module.""" + + def __init__(self, msg): + super(ProfilerOpTypeException, self).__init__( + error=ProfilerErrors.OP_TYPE_ERROR, + message=ProfilerErrorMsg.OP_TYPE_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerSortConditionException(ProfilerException): + """The parameter sort_condition error in profiler module.""" + + def __init__(self, msg): + super(ProfilerSortConditionException, self).__init__( + error=ProfilerErrors.SORT_CONDITION_ERROR, + message=ProfilerErrorMsg.SORT_CONDITION_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerFilterConditionException(ProfilerException): + """The parameter filer_condition error in profiler module.""" + + def __init__(self, msg): + super(ProfilerFilterConditionException, self).__init__( + error=ProfilerErrors.FILTER_CONDITION_ERROR, + message=ProfilerErrorMsg.FILTER_CONDITION_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerGroupConditionException(ProfilerException): + """The parameter group_condition error in profiler module.""" + + def __init__(self, msg): + super(ProfilerGroupConditionException, self).__init__( + error=ProfilerErrors.GROUP_CONDITION_ERROR, + message=ProfilerErrorMsg.GROUP_CONDITION_ERROR.value.format(msg), + http_code=400 + ) + + +class ProfilerColumnNotSupportSortException(ProfilerException): + """The column does not support to sort error in profiler module.""" + + def __init__(self, msg): + super(ProfilerColumnNotSupportSortException, self).__init__( + error=ProfilerErrors.COLUMN_NOT_SUPPORT_SORT_ERROR, + message=ProfilerErrorMsg.COLUMN_NOT_SUPPORT_SORT_ERROR.value.format(msg), + http_code=400 + ) + + +class StepNumNotSupportedException(ProfilerException): + """The step number error in profiler module.""" + + def __init__(self, msg): + super(StepNumNotSupportedException, self).__init__( + error=ProfilerErrors.STEP_NUM_NOT_SUPPORTED_ERROR, + message=ProfilerErrorMsg.STEP_NUM_NOT_SUPPORTED_ERROR.value.format(msg), + http_code=400 + ) + + +class JobIdMismatchException(ProfilerException): + """The Job ID mismatch error in profiler module.""" + + def __init__(self): + super(JobIdMismatchException, self).__init__( + error=ProfilerErrors.JOB_ID_MISMATCH_ERROR, + message=ProfilerErrorMsg.JOB_ID_MISMATCH_ERROR.value, + http_code=400 + ) + + +class ProfilerPipelineOpNotExistException(ProfilerException): + """The minddata pipeline operator does not exist error in profiler module.""" + + def __init__(self, msg): + super(ProfilerPipelineOpNotExistException, self).__init__( + error=ProfilerErrors.PIPELINE_OP_NOT_EXIST_ERROR, + message=ProfilerErrorMsg.PIPELINE_OP_NOT_EXIST_ERROR.value.format(msg), + http_code=400 + ) diff --git a/mindspore/profiler/common/util.py b/mindspore/profiler/common/util.py new file mode 100644 index 0000000000..180d163ff2 --- /dev/null +++ b/mindspore/profiler/common/util.py @@ -0,0 +1,295 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Profiler util. + +This module provides the utils. +""" +import os + + +# one sys count takes 10 ns, 1 ms has 100000 system count +import re + +PER_MS_SYSCNT = 100000 + + +def to_int(param, param_name): + """ + Transfer param to int type. + + Args: + param (Any): A param transformed. + param_name (str): Param name. + + Returns: + int, value after transformed. + + """ + try: + param = int(param) + except ValueError: + raise TypeError('Must be Integer: ' + param_name) + return param + + +def fwrite_format(output_data_path, data_source=None, is_print=False, is_start=False): + """ + Write data to the output file. + + Args: + output_data_path (str): The output file path of the data. + data_source (str, list, tuple): The data to write. + is_print (bool): whether to print the data to stdout. + is_start (bool): Whether is the first line of the output file, will remove the old file if True." + """ + + if is_start is True and os.path.exists(output_data_path): + os.remove(output_data_path) + + if isinstance(data_source, str) and data_source.startswith("title:"): + title_label = '=' * 20 + data_source = title_label + data_source[6:] + title_label + + with open(output_data_path, 'a+') as f: + if isinstance(data_source, (list, tuple)): + for raw_data in data_source: + if isinstance(raw_data, (list, tuple)): + raw_data = map(str, raw_data) + raw_data = " ".join(raw_data) + f.write(raw_data) + f.write("\n") + else: + f.write(data_source) + f.write("\n") + + if is_print: + if isinstance(data_source, (list, tuple)): + for raw_data in data_source: + if isinstance(raw_data, (list, tuple)): + raw_data = map(str, raw_data) + raw_data = " ".join(raw_data) + print(raw_data) + else: + print(data_source) + + +def get_log_slice_id(file_name): + pattern = re.compile(r'(?<=slice_)\d+') + slice_list = pattern.findall(file_name) + index = re.findall(r'\d+', slice_list[0]) + return int(index[0]) + + +def get_file_join_name(input_path, file_name): + """ + Search files under the special path, and will join all the files to one file. + + Args: + input_path (str): The source path, will search files under it. + file_name (str): The target of the filename, such as 'hwts.log.data.45.dev'. + + Returns: + str, the join file name. + """ + name_list = [] + file_join_name = '' + input_path = os.path.realpath(input_path) + if os.path.exists(input_path): + files = os.listdir(input_path) + for f in files: + if file_name in f and not f.endswith('.done') and not f.endswith('.join') \ + and not f.endswith('.zip'): + name_list.append(f) + + # resort name_list + name_list.sort(key=get_log_slice_id) + + if len(name_list) == 1: + file_join_name = os.path.join(input_path, name_list[0]) + elif len(name_list) > 1: + file_join_name = os.path.join(input_path, '%s.join' % file_name) + if os.path.exists(file_join_name): + os.remove(file_join_name) + with open(file_join_name, 'ab') as bin_data: + for i in name_list: + file = input_path + os.sep + i + with open(file, 'rb') as txt: + bin_data.write(txt.read()) + return file_join_name + +def get_file_names(input_path, file_name): + """ + Search files under the special path. + + Args: + input_path (str): The source path, will search files under it. + file_name (str): The target of the filename, such as 'host_start_log'. + + Returns: + list, file name list. + """ + + input_path = os.path.realpath(input_path) + name_list = [] + if os.path.exists(input_path): + files = os.listdir(input_path) + for f in files: + if file_name in f and not f.endswith('.done') \ + and not f.endswith('.zip'): + name_list.append(f) + break + + return name_list + + +def analyse_device_list_from_profiler_dir(profiler_dir): + """ + Analyse device list from profiler dir. + + Args: + profiler_dir (str): The profiler data dir. + + Returns: + list, the device_id list. + """ + profiler_file_prefix = ["timeline_display", "output_op_compute_time"] + + device_id_list = set() + for _, _, filenames in os.walk(profiler_dir): + for filename in filenames: + if filename.startswith("step_trace_raw"): + items = filename.split("_") + device_num = "" + if len(items) > 3: + device_num = items[3] + else: + items = filename.split("_") + device_num = items[-1].split(".")[0] if items[-1].split(".") else "" + + if device_num.isdigit() and '_'.join(items[:-1]) in profiler_file_prefix: + device_id_list.add(device_num) + + return sorted(list(device_id_list)) + + +def query_latest_trace_time_file(profiler_dir, device_id=0): + """ + Query the latest trace time file. + + Args: + profiler_dir (str): The profiler directory. + device_id (int): The id of device. + + Returns: + str, the latest trace time file path. + """ + files = os.listdir(profiler_dir) + target_file = f'step_trace_raw_{device_id}_detail_time.csv' + try: + latest_file = max( + filter( + lambda file: file == target_file, + files + ), + key=lambda file: os.stat(os.path.join(profiler_dir, file)).st_mtime + ) + except ValueError: + return None + return os.path.join(profiler_dir, latest_file) + + +def query_step_trace_file(profiler_dir): + """ + Query for all step trace file. + + Args: + profiler_dir (str): The directory that contains all step trace files. + + Returns: + str, the file path of step trace time. + """ + files = os.listdir(profiler_dir) + training_trace_file = list( + filter( + lambda file: file.startswith('training_trace') and not file.endswith('.done'), + files + ) + ) + if training_trace_file: + return os.path.join(profiler_dir, training_trace_file[0]) + return None + + +def get_summary_for_step_trace(average_info, header): + """The property of summary info.""" + if not average_info or not header: + return {} + total_time = get_field_value(average_info, 'total', header) + iteration_interval = get_field_value(average_info, 'iteration_interval', + header) + fp_and_bp = get_field_value(average_info, 'fp_and_bp', header) + tail = get_field_value(average_info, 'tail', header) + summary = { + 'total_time': total_time, + 'iteration_interval': iteration_interval, + 'iteration_interval_percent': calculate_percent(iteration_interval, total_time), + 'fp_and_bp': fp_and_bp, + 'fp_and_bp_percent': calculate_percent(fp_and_bp, total_time), + 'tail': tail, + 'tail_percent': calculate_percent(tail, total_time) + } + return summary + + +def calculate_percent(partial, total): + """Calculate percent value.""" + if total: + percent = round(partial / total * 100, 2) + else: + percent = 0 + return f'{percent}%' + + +def to_millisecond(sys_count, limit=4): + """Translate system count to millisecond.""" + return round(sys_count / PER_MS_SYSCNT, limit) + + +def get_field_value(row_info, field_name, header, time_type='realtime'): + """ + Extract basic info through row_info. + + Args: + row_info (list): The list of data info in one row. + field_name (str): The name in header. + header (list[str]): The list of field names. + time_type (str): The type of value, `realtime` or `systime`. Default: `realtime`. + + Returns: + dict, step trace info in dict format. + """ + field_index = header.index(field_name) + value = row_info[field_index] + value = to_int(value, field_name) + if time_type == 'realtime': + value = to_millisecond(value) + + return value + +def get_options(options): + if options is None: + options = {} + return options diff --git a/mindspore/profiler/common/validator/__init__.py b/mindspore/profiler/common/validator/__init__.py new file mode 100644 index 0000000000..e30774307c --- /dev/null +++ b/mindspore/profiler/common/validator/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindspore/profiler/common/validator/checkparam.py b/mindspore/profiler/common/validator/checkparam.py new file mode 100644 index 0000000000..ebe8cc1673 --- /dev/null +++ b/mindspore/profiler/common/validator/checkparam.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Profiler check parameters.""" +def check_bool(input_param, param_name): + """Bool type judgment.""" + if isinstance(input_param, bool): + return input_param + raise TypeError("Parameter {}: input type must be bool!".format(param_name)) + +def check_subgraph(subgraph): + """Check subgraph.""" + if subgraph in ("all", "Default", "Gradients"): + return subgraph + raise ValueError("subgraph must be all or Default or Gradients, but got {}.".format(subgraph)) diff --git a/mindspore/profiler/common/validator/validate.py b/mindspore/profiler/common/validator/validate.py new file mode 100644 index 0000000000..f883b027af --- /dev/null +++ b/mindspore/profiler/common/validator/validate.py @@ -0,0 +1,307 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Validate the profiler parameters.""" +import os +import sys + + +from mindspore.profiler.common.exceptions.exceptions import ProfilerParamTypeErrorException, \ + ProfilerDeviceIdException, ProfilerOpTypeException, \ + ProfilerSortConditionException, ProfilerFilterConditionException, \ + ProfilerGroupConditionException, ProfilerParamValueErrorException +from mindspore import log +from mindspore.profiler.common.util import to_int + +AICORE_TYPE_COL = ["op_type", "execution_time", "execution_frequency", "precent"] +AICORE_DETAIL_COL = ["op_name", "op_type", "avg_execution_time", "subgraph", "full_op_name"] +AICPU_COL = ["serial_number", "op_type", "total_time", "dispatch_time", "run_start", + "run_end"] +MINDDATA_PIPELINE_COL = [ + 'op_id', 'op_type', 'num_workers', 'output_queue_average_size', + 'output_queue_length', 'output_queue_usage_rate', 'sample_interval', + 'parent_id' +] + + +def validate_condition(search_condition): + """ + Verify the param in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + + Raises: + ProfilerParamTypeErrorException: If the type of the param in search_condition is invalid. + ProfilerDeviceIdException: If the device_id param in search_condition is invalid. + ProfilerOpTypeException: If the op_type param in search_condition is invalid. + ProfilerGroupConditionException: If the group_condition param in search_condition is invalid. + ProfilerSortConditionException: If the sort_condition param in search_condition is invalid. + ProfilerFilterConditionException: If the filter_condition param in search_condition is invalid. + """ + if not isinstance(search_condition, dict): + log.error("Invalid search_condition type, it should be dict.") + raise ProfilerParamTypeErrorException( + "Invalid search_condition type, it should be dict.") + + if "device_id" in search_condition: + device_id = search_condition.get("device_id") + if not isinstance(device_id, str): + raise ProfilerDeviceIdException("Invalid device_id type, it should be str.") + + if "op_type" in search_condition: + op_type = search_condition.get("op_type") + if op_type == "aicpu": + search_scope = AICPU_COL + elif op_type == "aicore_type": + search_scope = AICORE_TYPE_COL + elif op_type == "aicore_detail": + search_scope = AICORE_DETAIL_COL + else: + raise ProfilerOpTypeException("The op_type must in ['aicpu', 'aicore_type', 'aicore_detail']") + else: + raise ProfilerOpTypeException("The op_type must in ['aicpu', 'aicore_type', 'aicore_detail']") + + if "group_condition" in search_condition: + validate_group_condition(search_condition) + + if "sort_condition" in search_condition: + validate_sort_condition(search_condition, search_scope) + + if "filter_condition" in search_condition: + validate_filter_condition(search_condition) + + +def validate_group_condition(search_condition): + """ + Verify the group_condition in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + + Raises: + ProfilerGroupConditionException: If the group_condition param in search_condition is invalid. + """ + group_condition = search_condition.get("group_condition") + if not isinstance(group_condition, dict): + raise ProfilerGroupConditionException("The group condition must be dict.") + if "limit" in group_condition: + limit = group_condition.get("limit", 10) + if isinstance(limit, bool) \ + or not isinstance(group_condition.get("limit"), int): + log.error("The limit must be int.") + raise ProfilerGroupConditionException("The limit must be int.") + if limit < 1 or limit > 100: + raise ProfilerGroupConditionException("The limit must in [1, 100].") + + if "offset" in group_condition: + offset = group_condition.get("offset", 0) + if isinstance(offset, bool) \ + or not isinstance(group_condition.get("offset"), int): + log.error("The offset must be int.") + raise ProfilerGroupConditionException("The offset must be int.") + if offset < 0: + raise ProfilerGroupConditionException("The offset must ge 0.") + + if offset > 1000000: + raise ProfilerGroupConditionException("The offset must le 1000000.") + + +def validate_sort_condition(search_condition, search_scope): + """ + Verify the sort_condition in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + search_scope (list): The search scope. + + Raises: + ProfilerSortConditionException: If the sort_condition param in search_condition is invalid. + """ + sort_condition = search_condition.get("sort_condition") + if not isinstance(sort_condition, dict): + raise ProfilerSortConditionException("The sort condition must be dict.") + if "name" in sort_condition: + sorted_name = sort_condition.get("name", "") + err_msg = "The sorted_name must be in {}".format(search_scope) + if not isinstance(sorted_name, str): + log.error("Wrong sorted name type.") + raise ProfilerSortConditionException("Wrong sorted name type.") + if sorted_name not in search_scope: + log.error(err_msg) + raise ProfilerSortConditionException(err_msg) + + if "type" in sort_condition: + sorted_type_param = ['ascending', 'descending'] + sorted_type = sort_condition.get("type") + if sorted_type and sorted_type not in sorted_type_param: + err_msg = "The sorted type must be ascending or descending." + log.error(err_msg) + raise ProfilerSortConditionException(err_msg) + + +def validate_op_filter_condition(op_condition, value_type=str, value_type_msg='str'): + """ + Verify the op_condition in filter_condition is valid or not. + + Args: + op_condition (dict): The op_condition in search_condition. + value_type (type): The value type. Default: str. + value_type_msg (str): The value type message. Default: 'str'. + + Raises: + ProfilerFilterConditionException: If the filter_condition param in search_condition is invalid. + """ + filter_key = ["in", "not_in", "partial_match_str_in"] + if not isinstance(op_condition, dict): + raise ProfilerFilterConditionException("The filter condition value must be dict.") + for key, value in op_condition.items(): + if not isinstance(key, str): + raise ProfilerFilterConditionException("The filter key must be str") + if not isinstance(value, list): + raise ProfilerFilterConditionException("The filter value must be list") + if key not in filter_key: + raise ProfilerFilterConditionException("The filter key must in {}.".format(filter_key)) + for item in value: + if not isinstance(item, value_type): + raise ProfilerFilterConditionException( + "The item in filter value must be {}.".format(value_type_msg) + ) + + +def validate_filter_condition(search_condition): + """ + Verify the filter_condition in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + + Raises: + ProfilerFilterConditionException: If the filter_condition param in search_condition is invalid. + """ + filter_condition = search_condition.get("filter_condition") + if not isinstance(filter_condition, dict): + raise ProfilerFilterConditionException("The filter condition must be dict.") + if filter_condition: + if "op_type" in filter_condition: + op_type_condition = filter_condition.get("op_type") + validate_op_filter_condition(op_type_condition) + if "op_name" in filter_condition: + op_name_condition = filter_condition.get("op_name") + validate_op_filter_condition(op_name_condition) + if "op_type" not in filter_condition and "op_name" not in filter_condition: + raise ProfilerFilterConditionException("The key of filter_condition is not support") + + +def validate_and_set_job_id_env(job_id_env): + """ + Validate the job id and set it in environment. + + Args: + job_id_env (str): The id that to be set in environment parameter `JOB_ID`. + + Returns: + int, the valid job id env. + """ + if job_id_env is None: + return job_id_env + # get job_id_env in int type + valid_id = to_int(job_id_env, 'job_id_env') + # check the range of valid_id + if valid_id and 255 < valid_id < sys.maxsize: + os.environ['JOB_ID'] = job_id_env + else: + log.warning("Invalid job_id_env %s. The value should be int and between 255 and %s. Use" + "default job id env instead.", + job_id_env, sys.maxsize) + return valid_id + + +def validate_ui_proc(proc_name): + """ + Validate proc name in restful request. + + Args: + proc_name (str): The proc name to query. Acceptable value is in + [`iteration_interval`, `fp_and_bp`, `tail`]. + + Raises: + ProfilerParamValueErrorException: If the proc_name is invalid. + """ + accept_names = ['iteration_interval', 'fp_and_bp', 'tail'] + if proc_name not in accept_names: + log.error("Invalid proc_name. The proc_name for restful api is in %s", accept_names) + raise ProfilerParamValueErrorException(f'proc_name should be in {accept_names}.') + + +def validate_minddata_pipeline_condition(condition): + """ + Verify the minddata pipeline search condition is valid or not. + + Args: + condition (dict): The minddata pipeline search condition. + + Raises: + ProfilerParamTypeErrorException: If the type of the search condition is + invalid. + ProfilerDeviceIdException: If the device_id param in the search + condition is invalid. + ProfilerGroupConditionException: If the group_condition param in the + search condition is invalid. + ProfilerSortConditionException: If the sort_condition param in the + search condition is invalid. + ProfilerFilterConditionException: If the filter_condition param in the + search condition is invalid. + """ + if not isinstance(condition, dict): + log.error("Invalid condition type, it should be dict.") + raise ProfilerParamTypeErrorException( + "Invalid condition type, it should be dict." + ) + + if "device_id" in condition: + device_id = condition.get("device_id") + if not isinstance(device_id, str): + raise ProfilerDeviceIdException( + "Invalid device_id type, it should be str." + ) + + if "group_condition" in condition: + validate_group_condition(condition) + + if "sort_condition" in condition: + validate_sort_condition(condition, MINDDATA_PIPELINE_COL) + + if "filter_condition" in condition: + filter_condition = condition.get('filter_condition') + if not isinstance(filter_condition, dict): + raise ProfilerFilterConditionException( + "The filter condition must be dict." + ) + for key, value in filter_condition.items(): + if key == 'op_id': + validate_op_filter_condition( + value, value_type=int, value_type_msg='int' + ) + elif key == 'op_type': + validate_op_filter_condition(value) + elif key == 'is_display_op_detail': + if not isinstance(value, bool): + raise ProfilerFilterConditionException( + "The condition must be bool." + ) + else: + raise ProfilerFilterConditionException( + "The key {} of filter_condition is not support.".format(key) + ) diff --git a/mindspore/profiler/common/validator/validate_path.py b/mindspore/profiler/common/validator/validate_path.py new file mode 100644 index 0000000000..95d0049203 --- /dev/null +++ b/mindspore/profiler/common/validator/validate_path.py @@ -0,0 +1,60 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Validate the input path.""" +import os + + +def validate_and_normalize_path( + path, + check_absolute_path=False, + allow_parent_dir=False, +): + """ + Validates path and returns its normalized form. + + If path has a valid scheme, treat path as url, otherwise consider path a + unix local path. + + Note: + File scheme (rfc8089) is currently not supported. + + Args: + path (str): Path to be normalized. + check_absolute_path (bool): Whether check path scheme is supported. + allow_parent_dir (bool): Whether allow parent dir in path. + + Returns: + str, normalized path. + """ + if not path: + raise RuntimeError("The path is invalid!") + + path_str = str(path) + if not allow_parent_dir: + path_components = path_str.split("/") + if ".." in path_components: + raise RuntimeError("The path is invalid!") + + # path does not have valid schema, treat it as unix local path. + if check_absolute_path: + if not path_str.startswith("/"): + raise RuntimeError("The path is invalid!") + try: + # most unix systems allow + normalized_path = os.path.realpath(path) + except ValueError: + raise RuntimeError("The path is invalid!") + + return normalized_path diff --git a/mindspore/profiler/parser/__init__.py b/mindspore/profiler/parser/__init__.py new file mode 100644 index 0000000000..e30774307c --- /dev/null +++ b/mindspore/profiler/parser/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindspore/profiler/parser/aicpu_data_parser.py b/mindspore/profiler/parser/aicpu_data_parser.py new file mode 100644 index 0000000000..32304edfc3 --- /dev/null +++ b/mindspore/profiler/parser/aicpu_data_parser.py @@ -0,0 +1,175 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +The parser for AI CPU preprocess data. +""" +import os + +from mindspore.profiler.common.util import fwrite_format, get_file_join_name +from mindspore import log as logger + + +class DataPreProcessParser: + """ + The Parser for AI CPU preprocess data. + + Args: + input_path(str): The profiling job path. + output_filename(str): The output data path and name. + + """ + + _source_file_target = 'DATA_PREPROCESS.dev.AICPU.' + _dst_file_title = 'title:DATA_PREPROCESS AICPU' + _dst_file_column_title = ['serial_number', 'node_type_name', 'total_time(ms)', + 'dispatch_time(ms)', 'run_start', 'run_end'] + _ms_unit = 1000 + + def __init__(self, input_path, output_filename): + self._input_path = input_path + self._output_filename = output_filename + self._source_file_name = self._get_source_file() + self._ms_kernel_flag = 3 + self._other_kernel_flag = 6 + self._thread_flag = 7 + self._ms_kernel_run_end_index = 2 + self._other_kernel_run_end_index = 5 + self._result_list = [] + self._min_cycle_counter = float('inf') + + def _get_source_file(self): + """Get log file name, which was created by ada service.""" + file_name = get_file_join_name(self._input_path, self._source_file_target) + if not file_name: + data_path = os.path.join(self._input_path, "data") + file_name = get_file_join_name(data_path, self._source_file_target) + return file_name + + def _get_kernel_result(self, number, node_list, thread_list): + """Get the profiling data form different aicpu kernel""" + try: + if len(node_list) == self._ms_kernel_flag and len(thread_list) == self._thread_flag: + node_type_name = node_list[0].split(':')[-1] + run_end_index = self._ms_kernel_run_end_index + elif len(node_list) == self._other_kernel_flag and len(thread_list) == self._thread_flag: + node_type_name = node_list[0].split(':')[-1].split('/')[-1].split('-')[0] + run_end_index = self._other_kernel_run_end_index + else: + logger.warning("the data format can't support 'node_list':%s", str(node_list)) + return None + + run_start = node_list[1].split(':')[-1].split(' ')[0] + run_end = node_list[run_end_index].split(':')[-1].split(' ')[0] + total_time = float(thread_list[-1].split('=')[-1].split()[0]) / self._ms_unit + dispatch_time = float(thread_list[-2].split('=')[-1].split()[0]) / self._ms_unit + + return [number, node_type_name, total_time, dispatch_time, + run_start, run_end] + except IndexError as e: + logger.error(e) + return None + + def execute(self): + """Execute the parser, get result data, and write it to the output file.""" + + if not os.path.exists(self._source_file_name): + logger.info("Did not find the aicpu profiling source file") + return + + with open(self._source_file_name, 'rb') as ai_cpu_data: + ai_cpu_str = str(ai_cpu_data.read().replace(b'\n\x00', b' ___ ') + .replace(b'\x00', b' ___ '))[2:-1] + ai_cpu_lines = ai_cpu_str.split(" ___ ") + + result_list = list() + ai_cpu_total_time_summary = 0 + # Node serial number. + serial_number = 1 + for i in range(len(ai_cpu_lines) - 1): + node_line = ai_cpu_lines[i] + thread_line = ai_cpu_lines[i + 1] + if "Node" in node_line and "Thread" in thread_line: + # Get the node data from node_line + node_list = node_line.split(',') + thread_list = thread_line.split(',') + result = self._get_kernel_result(serial_number, node_list, thread_list) + + if result is None: + continue + + result_list.append(result) + # Calculate the total time. + total_time = result[2] + ai_cpu_total_time_summary += total_time + # Increase node serial number. + serial_number += 1 + elif "Node" in node_line and "Thread" not in thread_line: + node_type_name = node_line.split(',')[0].split(':')[-1] + logger.warning("The node type:%s cannot find thread data", node_type_name) + + if result_list: + ai_cpu_total_time = format(ai_cpu_total_time_summary, '.6f') + result_list.append(["AI CPU Total Time(ms):", ai_cpu_total_time]) + fwrite_format(self._output_filename, " ".join(self._dst_file_column_title), is_start=True, is_print=True) + fwrite_format(self._output_filename, result_list, is_print=True) + + # For timeline display. + self._result_list = result_list + + def query_aicpu_data(self): + """ + Get execution time of AI CPU operator. + + Returns: + a dict, the metadata of AI CPU operator execution time. + """ + stream_id = 0 # Default stream id for AI CPU. + pid = 9000 # Default pid for AI CPU. + factor = 1000 # Convert time unit from 1us to 1ms + total_time = 0 + min_cycle_counter = float('inf') + aicpu_info = [] + op_count_list = [] + for aicpu_item in self._result_list: + if "AI CPU Total Time(ms):" in aicpu_item: + total_time = aicpu_item[-1] + continue + + op_name = aicpu_item[1] + start_time = float(aicpu_item[4]) / factor + min_cycle_counter = min(min_cycle_counter, start_time) + end_time = float(aicpu_item[5]) / factor + duration = end_time - start_time + aicpu_info.append([op_name, stream_id, start_time, duration, pid]) + + # Record the number of operator types. + if op_name not in op_count_list: + op_count_list.append(op_name) + + self._min_cycle_counter = min_cycle_counter + aicpu_dict = { + 'info': aicpu_info, + 'total_time': float(total_time), + 'op_exe_times': len(aicpu_info), + 'num_of_ops': len(op_count_list), + 'num_of_streams': 1 + } + + return aicpu_dict + + @property + def min_cycle_counter(self): + """Get minimum cycle counter in AI CPU.""" + return self._min_cycle_counter diff --git a/mindspore/profiler/parser/container.py b/mindspore/profiler/parser/container.py new file mode 100644 index 0000000000..62f054ea7b --- /dev/null +++ b/mindspore/profiler/parser/container.py @@ -0,0 +1,113 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The container of metadata used in profiler parser.""" + + +class HWTSContainer: + """ + HWTS output container. + + Args: + split_list (list): The split list of metadata in HWTS output file. + """ + def __init__(self, split_list): + self._op_name = '' + self._duration = None + self._status = split_list[0] + self._task_id = split_list[6] + self._cycle_counter = float(split_list[7]) + self._stream_id = split_list[8] + + @property + def status(self): + """Get the status of the operator, i.e. Start or End.""" + return self._status + + @property + def task_id(self): + """Get the task id of the operator.""" + return self._task_id + + @property + def cycle_counter(self): + """Get the cycle counter.""" + return self._cycle_counter + + @property + def stream_id(self): + """Get the stream id of the operator.""" + return self._stream_id + + @property + def op_name(self): + """Get the name of the operator.""" + return self._op_name + + @op_name.setter + def op_name(self, name): + """Set the name of the operator.""" + self._op_name = name + + @property + def duration(self): + """Get the duration of the operator execution.""" + return self._duration + + @duration.setter + def duration(self, value): + """Set the duration of the operator execution.""" + self._duration = value + + +class TimelineContainer: + """ + A container of operator computation metadata. + + Args: + split_list (list): The split list of metadata in op_compute output file. + """ + def __init__(self, split_list): + self._op_name = split_list[0] + self._stream_id = int(split_list[1]) + self._start_time = float(split_list[2]) + self._duration = float(split_list[3]) + self._pid = None + if len(split_list) == 5: + self._pid = int(split_list[4]) + + @property + def op_name(self): + """Get the name of the operator.""" + return self._op_name + + @property + def stream_id(self): + """Get the stream id of the operator.""" + return self._stream_id + + @property + def start_time(self): + """Get the execution start time of the operator.""" + return self._start_time + + @property + def duration(self): + """Get the duration of the operator execution.""" + return self._duration + + @property + def pid(self): + """Get the pid of the operator execution.""" + return self._pid diff --git a/mindspore/profiler/parser/framework_parser.py b/mindspore/profiler/parser/framework_parser.py new file mode 100644 index 0000000000..8299f8f6fa --- /dev/null +++ b/mindspore/profiler/parser/framework_parser.py @@ -0,0 +1,595 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Thr parser for parsing framework files.""" +import csv +import enum +import json +import os +import re + +from mindspore.profiler.common.exceptions.exceptions import \ + ProfilerPathErrorException, ProfilerDirNotFoundException, \ + ProfilerFileNotFoundException, ProfilerDeviceIdMismatchException, \ + ProfilerRawFileException, ProfilerParamValueErrorException +from mindspore.profiler.common.validator.validate_path import \ + validate_and_normalize_path + + +class VmDataType(enum.IntEnum): + """Definition of vm data type.""" + NUMBER_TYPE_BEGIN = 26 + NUMBER_TYPE_BOOL = 27 + NUMBER_TYPE_INT = 28 + NUMBER_TYPE_INT8 = 29 + NUMBER_TYPE_INT16 = 30 + NUMBER_TYPE_INT32 = 31 + NUMBER_TYPE_INT64 = 32 + NUMBER_TYPE_UINT = 33 + NUMBER_TYPE_UINT8 = 34 + NUMBER_TYPE_UINT16 = 35 + NUMBER_TYPE_UINT32 = 36 + NUMBER_TYPE_UINT64 = 37 + NUMBER_TYPE_FLOAT = 38 + NUMBER_TYPE_FLOAT16 = 39 + NUMBER_TYPE_FLOAT32 = 40 + NUMBER_TYPE_FLOAT64 = 41 + NUMBER_TYPE_END = 42 + + @classmethod + def get_data_type_name(cls, num): + """ + Get the name of data type by enum number. + + Args: + num (int): Enum number. + + Returns: + str, the name of data type. + """ + data_type = cls._value2member_map_.get(num) + return 'UNKNOWN' if data_type is None else data_type.name + + +class GeDataType(enum.IntEnum): + """Definition of ge data type.""" + DT_FLOAT = 0 + DT_FLOAT16 = 1 + DT_INT8 = 2 + DT_INT16 = 6 + DT_UINT16 = 7 + DT_UINT8 = 4 + DT_INT32 = 3 + DT_INT64 = 9 + DT_UINT32 = 8 + DT_UINT64 = 10 + DT_BOOL = 12 + DT_DOUBLE = 11 + DT_STRING = 13 + DT_DUAL_SUB_INT8 = 14 + DT_DUAL_SUB_UINT8 = 15 + DT_COMPLEX64 = 16 + DT_COMPLEX128 = 17 + DT_QINT8 = 18 + DT_QINT16 = 19 + DT_QINT32 = 20 + DT_QUINT8 = 21 + DT_QUINT16 = 22 + DT_RESOURCE = 23 + DT_STRING_REF = 24 + DT_DUAL = 25 + DT_UNDEFINED = 26 + + @classmethod + def get_data_type_name(cls, num): + """ + Get the name of data type by enum number. + + Args: + num (int): Enum number. + + Returns: + str, the name of data type. + """ + data_type = cls._value2member_map_.get(num) + return 'UNKNOWN' if data_type is None else data_type.name + + +class GeFormat(enum.IntEnum): + """Definition of ge format type.""" + FORMAT_NCHW = 0 + FORMAT_NHWC = 1 + FORMAT_ND = 2 + FORMAT_NC1HWC0 = 3 + FORMAT_FRACTAL_Z = 4 + FORMAT_NC1C0HWPAD = 5 + FORMAT_NHWC1C0 = 6 + FORMAT_FSR_NCHW = 7 + FORMAT_FRACTAL_DECONV = 8 + FORMAT_C1HWNC0 = 9 + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10 + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11 + FORMAT_NC1HWC0_C04 = 12 + FORMAT_FRACTAL_Z_C04 = 13 + FORMAT_CHWN = 14 + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15 + FORMAT_HWCN = 16 + FORMAT_NC1KHKWHWC0 = 17 + FORMAT_BN_WEIGHT = 18 + FORMAT_FILTER_HWCK = 19 + FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20 + FORMAT_HASHTABLE_LOOKUP_KEYS = 21 + FORMAT_HASHTABLE_LOOKUP_VALUE = 22 + FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23 + FORMAT_HASHTABLE_LOOKUP_HITS = 24 + FORMAT_C1HWNCOC0 = 25 + FORMAT_MD = 26 + FORMAT_NDHWC = 27 + FORMAT_FRACTAL_ZZ = 28 + FORMAT_FRACTAL_NZ = 29 + FORMAT_NCDHW = 30 + FORMAT_DHWCN = 31 + FORMAT_NDC1HWC0 = 32 + FORMAT_FRACTAL_Z_3D = 33 + FORMAT_CN = 34 + FORMAT_NC = 35 + FORMAT_DHWNC = 36 + FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37 + FORMAT_RESERVED = 38 + FORMAT_ALL = 39 + + @classmethod + def get_format_name(cls, num): + """ + Get the name of format type by enum number. + + Args: + num (int): Enum number. + + Returns: + str, the name of format type. + """ + format_type = cls._value2member_map_.get(num) + return 'UNKNOWN' if format_type is None else format_type.name + + +class FrameworkParser: + """ + Thr parser for parsing framework files. + + Args: + profiling_id (str): The profiling ID. + device_id (str): The device ID. + output_path (str): The directory of the parsed file. Default: `./`. + """ + _raw_data_dir = '/var/log/npu/profiling' + _regex_framework = r'Framework\.host\.(?P.+)\.(?P\d).+' + _regex_framework_in_data = r'Framework\.host\.(?P.+)\.' \ + r'(?P\d)\.(?P[a-zA-Z0-9]+).+' + _col_names = [ + 'task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name', + 'op_type', 'subgraph', 'op_info' + ] + _graph_attr_name = [ + 'input_format', 'input_data_type', 'input_shape', 'output_format', + 'output_data_type', 'output_shape' + ] + + # if the task id is less than the task id threshold, The combination of + # task id and Stream id represents one operator, else the task id represents + # one operator + _task_id_threshold = 25000 + + def __init__(self, profiling_id, device_id, output_path='./'): + self._profiling_path = self._get_raw_profiling_path(profiling_id) + self._backend_type = None + self._framework_path = {'graph': [], 'task': [], 'point': []} + self._search_file(profiling_id, device_id) + self._device_id = device_id + self._save_path = self._get_save_path(device_id, output_path) + self._task_id_full_op_name_dict = {} + self._task_cache = {} + self._point_info = {} + self._parse_task_files() + self._parse_point_files() + + @property + def save_path(self): + """ + The property of save path. + + Returns: + str, the save path. + """ + return self._save_path + + @property + def point_info(self): + """ + The property of the framework point information. + + Returns: + dict, the framework point information. + """ + return self._point_info + + def to_task_id_full_op_name_dict(self): + """ + Get the task id and full operator name dict. + + Returns: + dict, the task id and full operator name dict. + """ + return self._task_id_full_op_name_dict + + def parse(self): + """Parse the framework files.""" + self._parse_graph_files_and_save(self._task_cache) + del self._task_cache + + def check_op_name(self, op_name, is_prefix=True): + """ + Check whether the operator name exists. + + Args: + op_name (str): The operator name or operator name prefix. + is_prefix (bool): `True` if the op_name is prefix, else `False`. + Default: True. + + Returns: + bool, `True` if the operator name does exist in framework file, else + `False`. + """ + if not op_name: + raise ProfilerParamValueErrorException('The op_name should exist.') + for full_op_name in self._task_id_full_op_name_dict.values(): + if full_op_name: + if is_prefix and full_op_name.startswith(op_name): + return True + if not is_prefix and op_name == full_op_name: + return True + return False + + def _get_raw_profiling_path(self, profiling_id): + """ + Get raw profiling path. + + Args: + profiling_id (str): The profiling ID. + + Returns: + str, the raw profiling path. + + Raises: + ProfilerPathErrorException: If the profiling path is invalid. + ProfilerDirNotFoundException: If the profiling dir is not found. + """ + profiling_path = os.path.join(self._raw_data_dir, profiling_id) + try: + profiling_path = validate_and_normalize_path(profiling_path) + except RuntimeError: + raise ProfilerPathErrorException('Profiling path is invalid.') + if not os.path.isdir(profiling_path): + raise ProfilerDirNotFoundException(profiling_path) + return profiling_path + + def _search_file(self, profiling_id, device_id): + """ + Search all framework files in raw profiling path. + + Args: + profiling_id (str): The profiling ID. + device_id (str): The device ID. + + Raises: + ProfilerFileNotFoundException: If the framework files are not found. + """ + # first search in the JOB dir, and if not, search in the sub directory + # in the JOB + self._search_file_from_job_path(device_id, search_in_sub_path=False) + if self._backend_type is None: + self._search_file_from_job_path(device_id, search_in_sub_path=True) + self._search_file_from_data_path(profiling_id, device_id) + + if self._backend_type is None: + raise ProfilerFileNotFoundException('Framework') + self._framework_path['graph'].sort() + self._framework_path['task'].sort() + + def _search_file_from_job_path(self, device_id, search_in_sub_path=False): + """ + Search framework files from job path. + + Args: + device_id (str): The device ID. + search_in_sub_path (bool): `True` if search file in profiling dir, + else search in profiling sub dir. Default: False. + + Raises: + ProfilerRawFileException: If the framework file type is inconsistent. + ProfilerDeviceIdMismatchException: If the device id is mismatch + with framework in the raw dir. + """ + profiling_dir = os.path.join(self._profiling_path, 'data') \ + if search_in_sub_path else self._profiling_path + if not os.path.isdir(profiling_dir): + return + + files = os.listdir(profiling_dir) + for file in files: + pattern = re.search(self._regex_framework, file) + if not pattern or file.endswith('.done'): + continue + attrs = pattern.groupdict() + + device_id_in_path = attrs.get('device_id') + if device_id_in_path != device_id: + raise ProfilerDeviceIdMismatchException() + + data_type = attrs.get('data_type') + if data_type.startswith('vm.'): + if self._backend_type and self._backend_type != 'vm': + raise ProfilerRawFileException('Backend type is inconsistent.') + self._backend_type = 'vm' + data_type = data_type.split('.')[1] + else: + if self._backend_type and self._backend_type != 'ge': + raise ProfilerRawFileException('Backend type is inconsistent.') + self._backend_type = 'ge' + if data_type.startswith('graph_desc_info'): + self._framework_path['graph'].append( + os.path.join(profiling_dir, file) + ) + elif data_type.startswith('task_desc_info'): + self._framework_path['task'].append( + os.path.join(profiling_dir, file) + ) + elif data_type.startswith('point'): + self._framework_path['point'].append( + os.path.join(profiling_dir, file) + ) + + def _search_file_from_data_path(self, profiling_id, device_id): + """ + Search framework files from data path. + + Args: + profiling_id (str): The profiling ID. + device_id (str): The device ID. + + Raises: + ProfilerRawFileException: If the framework file type is inconsistent. + ProfilerDeviceIdMismatchException: If the device id is mismatch + with framework in the raw dir. + """ + profiling_data_path = os.path.join( + self._raw_data_dir, 'container', device_id, 'data' + ) + if not os.path.isdir(profiling_data_path): + return + + files = os.listdir(profiling_data_path) + for file in files: + pattern = re.search(self._regex_framework_in_data, file) + if not pattern or file.endswith('.done') or file.endswith('.zip'): + continue + attrs = pattern.groupdict() + + profiling_id_in_path = attrs.get('profiling_id') + if profiling_id_in_path != profiling_id: + continue + + device_id_in_path = attrs.get('device_id') + if device_id_in_path != device_id: + raise ProfilerDeviceIdMismatchException() + + data_type = attrs.get('data_type') + if data_type.startswith('vm.'): + if self._backend_type and self._backend_type != 'vm': + raise ProfilerRawFileException('Backend type is inconsistent.') + self._backend_type = 'vm' + data_type = data_type.split('.')[1] + else: + if self._backend_type and self._backend_type != 'ge': + raise ProfilerRawFileException('Backend type is inconsistent.') + self._backend_type = 'ge' + if data_type.startswith('graph_desc_info'): + self._framework_path['graph'].append( + os.path.join(profiling_data_path, file) + ) + elif data_type.startswith('task_desc_info'): + self._framework_path['task'].append( + os.path.join(profiling_data_path, file) + ) + elif data_type.startswith('point'): + self._framework_path['point'].append( + os.path.join(profiling_data_path, file) + ) + + def _get_save_path(self, device_id, output_path): + """ + Get the save path. + + Args: + device_id (str): The device ID. + output_path (str): The output dir. + + Returns: + str, the save path. + + Raises: + ProfilerPathErrorException: If the output path is invalid. + ProfilerDirNotFoundException: If the output dir is not found. + """ + try: + output_dir = validate_and_normalize_path(output_path) + except RuntimeError: + raise ProfilerPathErrorException('Output path is invalid.') + if not os.path.isdir(output_dir): + raise ProfilerDirNotFoundException(output_dir) + return os.path.join( + output_dir, '_'.join(['framework', 'raw', device_id]) + '.csv' + ) + + def _parse_task_files(self): + """Parse the framework task files.""" + for path in self._framework_path['task']: + with open(path, 'r') as file: + for task_info in file: + infos = task_info.strip('\n').split(' ') + infos = infos[1:] if len(infos) == 5 else infos + # key is op name, values is task id, stream id, block_dim + self._task_cache[infos[0]] = [infos[2], infos[3], infos[1]] + + # if the task id is less than the task id threshold, the + # stream id and task id correspond to an operator + task_id = infos[2] + if int(task_id) < self._task_id_threshold: + task_id = '_'.join([infos[3], task_id]) + self._task_id_full_op_name_dict[task_id] = infos[0] + + def _parse_graph_files_and_save(self, task_cache): + """ + Parse the framework graph files and save the framework information. + + Args: + task_cache (dict): The task information cache. + """ + with open(self._save_path, 'w') as save_file: + csv_writer = csv.writer(save_file) + csv_writer.writerow(self._col_names) + for path in self._framework_path['graph']: + with open(path, 'r') as graph_file: + for graph_info in graph_file: + result = self._parse_one_row_graph_info(graph_info) + task_info = task_cache.get(result[0]) + if task_info: + task_info.extend(result) + csv_writer.writerow(task_info) + del task_cache[result[0]] + else: + save_info = [None, None, None] + save_info.extend(result) + csv_writer.writerow(save_info) + + none_list = [None, None, None, None] + for key, value in task_cache.items(): + value.append(key) + value.extend(none_list) + csv_writer.writerow(value) + + def _parse_one_row_graph_info(self, row_info): + """ + Parse the graph information in one row. + + Args: + row_info (str): One row graph information. + + Returns: + list[str], the parsed graph information. + """ + full_op_name = None + op_name = None + subgraph_name = None + op_type = None + op_info = dict() + cur_op_info_key = None + + infos = row_info.strip('\n').split(' ') + for info in infos: + attr_name, attr_value = info.split(':', 1) + if attr_name == 'op_name': + full_op_name = attr_value + subgraph_name = self._get_subgraph_name(full_op_name) + op_name = self._get_op_name(full_op_name, subgraph_name) + elif attr_name == 'op_type': + op_type = attr_value + elif attr_name in ['input_id', 'output_id']: + cur_op_info_key = '{}_{}'.format( + attr_name.split('_')[0], attr_value + ) + op_info[cur_op_info_key] = dict() + elif attr_name in self._graph_attr_name: + op_attr = attr_name.split('_', 1)[1] + if op_attr == 'shape': + attr_value = attr_value.strip('"') + if self._backend_type == 'vm': + if op_attr == 'data_type': + attr_value = VmDataType.get_data_type_name( + int(attr_value) + ) + else: + if op_attr == 'data_type': + attr_value = GeDataType.get_data_type_name( + int(attr_value) + ) + elif op_attr == 'format': + attr_value = GeFormat.get_format_name(int(attr_value)) + + op_info[cur_op_info_key][op_attr] = attr_value + + # the list info are full_op_name, op_name, op_type, subgraph, op_info + return [full_op_name, op_name, op_type, subgraph_name, + json.dumps(op_info)] + + def _get_subgraph_name(self, full_op_name): + """ + Get subgraph name. + + Args: + full_op_name (str): The full operator name. + + Returns: + str, the subgraph name. + """ + subgraph_name = full_op_name.split('/', 1)[0] + if subgraph_name in ['Default', 'Gradients']: + return subgraph_name + return None + + def _get_op_name(self, full_op_name, subgraph_name): + """ + Get operator name. + + Args: + full_op_name (str): The full operator name. + subgraph_name (str): The subgraph name. + + Returns: + str, the operator name. + """ + if subgraph_name is None: + return full_op_name + + if self._backend_type == 'vm': + return full_op_name.split('/')[-1] + + strs = full_op_name.split(subgraph_name + '/') + op_name = None + for name_str in strs: + if not name_str: + continue + if op_name is None: + op_name = name_str.split('/')[-1] + else: + op_name = '+'.join([op_name, name_str.split('/')[-1]]) + return op_name + + def _parse_point_files(self): + """Parse the framework point files.""" + for path in self._framework_path['point']: + with open(path, 'r') as file: + for point_info in file: + infos = point_info.strip('\n').split(' ') + self._point_info[int(infos[0])] = infos[1] diff --git a/mindspore/profiler/parser/hwts_log_parser.py b/mindspore/profiler/parser/hwts_log_parser.py new file mode 100644 index 0000000000..29550b96c1 --- /dev/null +++ b/mindspore/profiler/parser/hwts_log_parser.py @@ -0,0 +1,109 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The parser for hwts log file.""" +import os +import struct +from mindspore.profiler.common.util import fwrite_format, get_file_join_name +from mindspore import log as logger + + +class HWTSLogParser: + """ + The Parser for hwts log files. + + Args: + input_path (str): The profiling job path. Such as: '/var/log/npu/profiling/JOBAIFGJEJFEDCBAEADIFJAAAAAAAAAA". + output_filename (str): The output data path and name. Such as: './output_format_data_hwts_0.txt'. + """ + + _source_file_target = 'hwts.log.data.45.dev.profiler_default_tag' + _dst_file_title = 'title:45 HWTS data' + _dst_file_column_title = 'Type cnt Core_ID Block_ID Task_ID Cycle_counter Stream_ID' + + def __init__(self, input_path, output_filename): + self._input_path = input_path + self._output_filename = output_filename + self._source_flie_name = self._get_source_file() + + def _get_source_file(self): + """Get hwts log file name, which was created by ada service.""" + + file_name = get_file_join_name(self._input_path, self._source_file_target) + if not file_name: + data_path = os.path.join(self._input_path, "data") + file_name = get_file_join_name(data_path, self._source_file_target) + if not file_name: + msg = "Fail to find hwts log file, under profiling directory" + raise RuntimeError(msg) + + return file_name + + def execute(self): + """ + Execute the parser, get result data, and write it to the output file. + + Returns: + bool, whether succeed to analyse hwts log. + """ + + content_format = ['QIIIIIIIIIIII', 'QIIQIIIIIIII', 'IIIIQIIIIIIII'] + log_type = ['Start of task', 'End of task', 'Start of block', 'End of block', 'Block PMU'] + + result_data = "" + + with open(self._source_flie_name, 'rb') as hwts_data: + while True: + line = hwts_data.read(64) + if line: + if not line.strip(): + continue + else: + break + byte_first_four = struct.unpack('BBHHH', line[0:8]) + byte_first = bin(byte_first_four[0]).replace('0b', '').zfill(8) + ms_type = byte_first[-3:] + is_warn_res0_ov = byte_first[4] + cnt = int(byte_first[0:4], 2) + core_id = byte_first_four[1] + blk_id, task_id = byte_first_four[3], byte_first_four[4] + if ms_type in ['000', '001', '010']: # log type 0,1,2 + result = struct.unpack(content_format[0], line[8:]) + syscnt = result[0] + stream_id = result[1] + elif ms_type == '011': # log type 3 + result = struct.unpack(content_format[1], line[8:]) + syscnt = result[0] + stream_id = result[1] + elif ms_type == '100': # log type 4 + result = struct.unpack(content_format[2], line[8:]) + stream_id = result[2] + if is_warn_res0_ov == '0': + syscnt = result[4] + else: + syscnt = None + else: + logger.info("Profiling: invalid hwts log record type %s", ms_type) + continue + + if int(task_id) < 25000: + task_id = str(stream_id) + "_" + str(task_id) + result_data += ("%-14s %-4s %-8s %-9s %-8s %-15s %s\n" %(log_type[int(ms_type, 2)], cnt, core_id, + blk_id, task_id, syscnt, stream_id)) + + fwrite_format(self._output_filename, data_source=self._dst_file_title, is_start=True) + fwrite_format(self._output_filename, data_source=self._dst_file_column_title) + fwrite_format(self._output_filename, data_source=result_data) + + return True diff --git a/mindspore/profiler/parser/integrator.py b/mindspore/profiler/parser/integrator.py new file mode 100644 index 0000000000..fa8e208586 --- /dev/null +++ b/mindspore/profiler/parser/integrator.py @@ -0,0 +1,720 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The integrator for integrating parsed profiling files.""" +import csv +import json +import os +from decimal import Decimal + +from mindspore import log as logger +from mindspore.profiler.common.exceptions.exceptions import ProfilerIOException, \ + ProfilerFileNotFoundException, ProfilerRawFileException +from mindspore.profiler.common.util import query_latest_trace_time_file, to_int, to_millisecond +from mindspore.profiler.common.validator.validate_path import validate_and_normalize_path +from mindspore.profiler.parser.container import TimelineContainer + +SIZE_LIMIT = 20 * 1024 * 1024 # 20MB + + +class Integrator: + """ + The integrator for integrating parsed profiling files. + + Args: + profiling_dir (str): The directory where the parsed profiling files are + located. + device_id (str): The device ID. + """ + _file_name_aicore_detail_time = 'output_op_compute_time_{}.txt' + _file_name_aicpu_time = 'output_data_preprocess_aicpu_{}.txt' + _file_name_framework = 'framework_raw_{}.csv' + _header_aicore_type = ['op_type', 'execution_time', 'execution_frequency', + 'percent'] + _header_aicore_detail = ['full_op_name', 'execution_time'] + _header_aicpu = ['serial_number', 'op_type', 'total_time', 'dispatch_time', + 'run_start', 'run_end'] + + _file_name_aicore_type_time = 'aicore_intermediate_{}_type.csv' + _file_name_aicore_detail_info = 'aicore_intermediate_{}_detail.csv' + _col_names_detail = ['op_name', 'op_type', 'avg_execution_time', 'subgraph', 'full_op_name', 'op_info'] + _none_filter_condition_key = ['is_display_detail', 'is_display_full_op_name'] + _none_sort_col_names = ['op_info'] + _aicore_data = [] + _aicore_detail_data = [] + _aicore_trace_data = [] + + def __init__(self, profiling_dir, device_id): + self._profiling_dir = profiling_dir + self._device_id = device_id + self._op_time_cache = {} + self._total_time = Decimal('0.0') + + def integrate(self): + """Integrate the parsed profiling files.""" + self._parse_aicore_detail_time() + self._parse_aicore_type_time() + self._parse_aicpu_time() + + def get_aicore_data(self): + self._aicore_data_load() + return self._aicore_data + + def get_aicore_detail_data(self): + self._aicore_detail_data_load() + return self._aicore_detail_data + + def get_aicore_trace_data(self): + self._aicore_trace_data_load() + return self._aicore_trace_data + + def query_for_all_reduce(self): + return self._query_for_all_reduce() + + def query_and_sort_by_op_type(self, filter_condition, op_type_order): + return self._query_and_sort_by_op_type(filter_condition, op_type_order) + + def _parse_aicore_type_time(self): + """Parse the parsed AICORE operator type file.""" + framework_file = os.path.join( + self._profiling_dir, + self._file_name_framework.format(self._device_id) + ) + if not os.path.isfile(framework_file): + return + + op_name_type_cache = {} + with open(framework_file, 'r') as src_file: + csv_reader = csv.reader(src_file) + _ = next(csv_reader) + + for row in csv_reader: + op_name_type_cache[row[3]] = row[5] + + op_type_time_cache = {} + for full_op_name, op_time in self._op_time_cache.items(): + op_type = op_name_type_cache.get(full_op_name) + if op_type_time_cache.get(op_type) is None: + op_type_time_cache[op_type] = [op_time, 1] + else: + op_type_time_cache[op_type][0] += op_time + op_type_time_cache[op_type][1] += 1 + + op_type_file_name = 'aicore_intermediate_' + self._device_id + '_type.csv' + op_type_file_path = os.path.join(self._profiling_dir, op_type_file_name) + with open(op_type_file_path, 'w') as type_file: + csv_writer = csv.writer(type_file) + csv_writer.writerow(self._header_aicore_type) + + for op_type, op_type_time_info in op_type_time_cache.items(): + type_info = [ + op_type, op_type_time_info[0], op_type_time_info[1], + round((op_type_time_info[0] / self._total_time) * 100, 2) + ] + csv_writer.writerow(type_info) + + def _parse_aicore_detail_time(self): + """Parse the parsed AICORE operator time file.""" + aicore_detail_file = os.path.join( + self._profiling_dir, + self._file_name_aicore_detail_time.format(self._device_id) + ) + if not os.path.isfile(aicore_detail_file): + return + + op_detail_file_name = 'aicore_intermediate_' + self._device_id + '_detail.csv' + op_detail_file_path = os.path.join( + self._profiling_dir, op_detail_file_name + ) + with open(aicore_detail_file, 'r') as src_file: + row = src_file.readline() + if row.startswith('op_name'): + _ = src_file.readline() + elif row.startswith('====='): + _ = src_file.readline() + _ = src_file.readline() + else: + return + + with open(op_detail_file_path, 'w') as detail_file: + csv_writer = csv.writer(detail_file) + csv_writer.writerow(self._header_aicore_detail) + + while True: + row = src_file.readline() + if not row: + break + + op_infos = row.split() + if op_infos[0] == 'total': + self._total_time = Decimal(op_infos[2]) + continue + self._op_time_cache[op_infos[0]] = Decimal(op_infos[1]) + csv_writer.writerow([op_infos[0], op_infos[1]]) + + def _parse_aicpu_time(self): + """Parse the parsed AICPU operator time file.""" + aicpu_file = os.path.join( + self._profiling_dir, + self._file_name_aicpu_time.format(self._device_id) + ) + if not os.path.isfile(aicpu_file): + return + + save_file_name = 'aicpu_intermediate_' + self._device_id + '.csv' + save_file_path = os.path.join(self._profiling_dir, save_file_name) + with open(aicpu_file, 'r') as src_file: + row = src_file.readline() + if not row.startswith('serial_number'): + return + with open(save_file_path, 'w') as save_file: + csv_writer = csv.writer(save_file) + csv_writer.writerow(self._header_aicpu) + + while True: + row = src_file.readline() + if not row: + break + infos = row.split() + if infos[0] == 'AI': + continue + csv_writer.writerow(infos) + + def _aicore_data_load(self): + """Load data according to the parsed AICORE operator types file.""" + op_type_file_path = os.path.join( + self._profiling_dir, + self._file_name_aicore_type_time.format(self._device_id) + ) + if not os.path.isfile(op_type_file_path): + logger.warning('The file <%s> does not exist.', op_type_file_path) + return + + with open(op_type_file_path, 'r') as file: + csv_reader = csv.reader(file) + _ = next(csv_reader) + for info in csv_reader: + self._aicore_data.append([info[0], float(info[1]), int(info[2]), float(info[3])]) + + def _aicore_detail_data_load(self): + """Load data according to the parsed AICORE operator file.""" + op_detail_file_path = os.path.join( + self._profiling_dir, + self._file_name_aicore_detail_info.format(self._device_id) + ) + framework_file_path = os.path.join( + self._profiling_dir, + self._file_name_framework.format(self._device_id) + ) + if not os.path.isfile(op_detail_file_path): + logger.warning('The file <%s> does not exist.', op_detail_file_path) + return + if not os.path.isfile(framework_file_path): + logger.warning('The file <%s> does not exist.', framework_file_path) + return + + framework_infos = dict() + with open(framework_file_path, 'r') as file: + csv_reader = csv.reader(file) + _ = next(csv_reader) + for info in csv_reader: + framework_infos[info[3]] = [ + info[3], info[4], info[5], info[6], json.loads(info[7]) if info[7] else None] + + with open(op_detail_file_path, 'r') as file: + csv_reader = csv.reader(file) + _ = next(csv_reader) + for info in csv_reader: + framework_info = framework_infos.get(info[0]) + self._aicore_detail_data.append( + [ + framework_info[1], framework_info[2], float(info[1]), + framework_info[3], framework_info[0], framework_info[4] + ] + ) + del framework_infos + + def _aicore_trace_data_load(self): + """Load data according to the parsed AICORE operator types file.""" + file_path = query_latest_trace_time_file(self._profiling_dir, int(self._device_id)) + if not file_path: + logger.error("Failed to find parsed trace time file.") + raise ProfilerFileNotFoundException('parsed step trace time file') + with open(file_path, 'r') as handle: + csv_reader = csv.reader(handle) + self.__column__ = next(csv_reader) + self._aicore_trace_data = list(csv_reader) + self._size = len(self._aicore_trace_data) - 1 + self._load_point_info() + + def _load_point_info(self): + """Load point info.""" + file_path = os.path.join(self._profiling_dir, 'step_trace_point_info.json') + if os.path.isfile(file_path): + with open(file_path, 'r', encoding='utf-8') as file: + try: + self._point_info = json.load(file) + except (json.JSONDecodeError, TypeError) as err: + logger.warning(err) + raise ProfilerRawFileException('Fail to parse point info file.') + + def _query_for_all_reduce(self): + """ + Query for all reduce info. + + Returns: + list[dict], reduce information. Each item is the reduce info for one step. + The reduce info is format like: + {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}. + """ + self._aicore_trace_data_load() + reduce_infos = [] + for row_info in self._aicore_trace_data[:-1]: + row_info_dict = self._get_info_dict_from_row_data(row_info, 'systime') + reduce_info = self._sort_reduce_by_time(row_info_dict) + if reduce_info: + reduce_infos.extend(reduce_info) + + return reduce_infos + + def _get_info_dict_from_row_data(self, row_info, time_type): + """ + Get step info in dict format. + + Args: + row_info (list[str]): Step info, the value is corresponding to `__column__`. + time_type (str): The value type. `systime` keeps the original value. + `realtime` transforms the value in millisecond. Default: `realtime`. + + Returns: + dict, step trace information. The key is in `__column__`. + """ + row_info_dict = {} + for key, value in zip(self.__column__, row_info): + if key == 'step_num': + continue + value = to_int(value, key) + row_info_dict[key] = to_millisecond(value) if time_type == 'realtime' else value + return row_info_dict + + def _sort_reduce_by_time(self, row_info_dict): + """ + Sort reduce info by time. + + Args: + row_info_dict (dict): Step trace information. + + Returns: + list, including the all reduce info sorted by start time only. + [ + [reduce_field, stream_id, reduce_start, reduce_duration], + [...], + [...] + ] + """ + factor = 1e5 # convert time unit from 10ns to 1ms + reduce_pid = 10000 + reduce_info = [] + reduce_fields = [field_name for field_name in self.__column__ + if field_name.startswith('stream_') and not field_name.endswith('point')] + for reduce_field in reduce_fields: + reduce_start = row_info_dict.get(reduce_field + '_start_point') + reduce_start = reduce_start / factor \ + if reduce_start else 0 + reduce_duration = row_info_dict.get(reduce_field) + reduce_duration = reduce_duration / factor if reduce_duration else 0 + if not (reduce_start and reduce_duration): + logger.info("Reduce event missing value.") + continue + cur_stream_id = reduce_field.split('_', 2)[1] + reduce_meta = [reduce_field, int(cur_stream_id), reduce_start, + reduce_duration, reduce_pid] + reduce_info.append(reduce_meta) + + return reduce_info + + def _query_and_sort_by_op_type(self, filter_condition, op_type_order: list): + """ + Query the AICORE operator detail information by `filter_condition`, + and sort by `op_type_order` and execution time. + + Args: + filter_condition (dict): The filter condition. + op_type_order (list[str]): The name of the operator type in order. + + Returns: + dict, The results are filtered and sorted. + """ + self._aicore_detail_data_load() + if filter_condition is None: + filter_condition = {} + self._filter(filter_condition) + + type_detail_cache = {} + for detail_info in self._result: + op_type = detail_info[1] + if op_type not in op_type_order: + continue + infos = type_detail_cache.get(op_type) + if infos: + infos.append(detail_info) + else: + type_detail_cache[op_type] = [detail_info] + + result = [] + for op_type in op_type_order: + detail_infos = type_detail_cache.get(op_type) + if detail_infos is None: + continue + detail_infos.sort(key=lambda item: item[2], reverse=True) + result.extend(detail_infos) + + return { + 'col_name_detail': self._display_col_names_detail, + 'object': result + } + + def _filter(self, filter_condition): + """ + Filter the profiling data according to the filter condition. + + Args: + filter_condition (dict): The filter condition. + """ + def _inner_filter(item: list): + return self._default_filter(item, filter_condition) + + def _inner_map(item: list): + inner_item = item[0:4] + if is_display_full_op_name: + inner_item.append(item[4]) + if is_display_detail: + inner_item.append(item[5]) + return inner_item + + is_display_detail = filter_condition.get('is_display_detail', True) + is_display_full_op_name = filter_condition.get( + 'is_display_full_op_name', True + ) + self._set_display_col_name(is_display_detail, is_display_full_op_name) + if is_display_detail and is_display_full_op_name: + self._result = list(filter(_inner_filter, self._aicore_detail_data)) + else: + self._result = list( + map(_inner_map, filter(_inner_filter, self._aicore_detail_data)) + ) + + def _default_filter(self, item, condition): + """ + The default filter method. + + Args: + item (list[Union[str, float, int]]): A piece of data to be filtered. + condition (dict): The filter condition. + + Returns: + bool, `True` if the item is satisfied. + """ + for condition_key, condition_value in condition.items(): + if condition_key in self._none_filter_condition_key: + continue + if condition_key in self._col_names_detail: + index = self._col_names_detail.index(condition_key) + actual_value = item[index] + for exp_key, exp_value in condition_value.items(): + if not self._is_match_condition( + exp_key, exp_value, actual_value): + return False + return True + + def _is_match_condition(self, exp_key, exp_value, actual_value): + """ + Check whether the actual value meets the expect condition. + + Args: + exp_key (str): Expect key of the condition. + exp_value (str): Expect value. + actual_value (str): Actual value. + + Returns: + bool, `True` if the actual meets the expect condition, else `False`. + """ + if exp_key == 'in': + if actual_value not in exp_value: + return False + elif exp_key == 'not_in': + if actual_value in exp_value: + return False + elif exp_key == 'partial_match_str_in': + for partial_match_str in exp_value: + if partial_match_str in actual_value: + return True + return False + else: + return False + + return True + + def _set_display_col_name(self, is_display_detail, is_display_full_op_name): + """ + Set the display column name according to the filter condition. + + Args: + is_display_detail (bool): Whether to display the detailed operator + information. + is_display_full_op_name (bool): Whether to display the operator full + name. + """ + self._display_col_names_detail = self._col_names_detail[0:4] + if is_display_full_op_name: + self._display_col_names_detail.append(self._col_names_detail[4]) + if is_display_detail: + self._display_col_names_detail.append(self._col_names_detail[5]) + + +class TimelineAnalyser: + """ + Analyse timeline data from file. + """ + __col_names__ = ['op_name', 'stream_id', 'start_time', 'duration'] + _output_timeline_data_file_path = 'output_timeline_data_{}.txt' + _min_cycle_counter_file_path = 'min_cycle_counter_{}.txt' + _display_filename = 'timeline_display_{}.json' + _timeline_summary_filename = 'timeline_summary_{}.json' + _timeline_meta = [] + _timeline_summary = { + 'total_time': 0, + 'num_of_streams': 0, + 'num_of_ops': 0, + 'op_exe_times': 0 + } + + def __init__(self, profiling_dir, device_id): + self._profiling_dir = profiling_dir + self._device_id = device_id + + def write_timeline(self): + """Load data according to the parsed profiling files.""" + # Write timeline to file. + logger.info('Writing timeline file...') + self.write_timeline_to_json_by_limitation() + logger.info('Finished file writing!') + + def write_timeline_to_json_by_limitation(self): + """Write timeline to json by limitation.""" + display_filename = self._display_filename.format(self._device_id) + display_file_path = os.path.join( + self._profiling_dir, + display_filename + ) + display_file_path = validate_and_normalize_path(display_file_path) + + length = len(self._timeline_meta) + try: + with open(display_file_path, 'w') as json_file: + json_file.write('[') + for index, item in enumerate(self._timeline_meta): + json.dump(item, json_file) + file_size = os.path.getsize(display_file_path) + if file_size > SIZE_LIMIT: + break + if index == length - 1: + break + json_file.write(',') + json_file.write(']') + except (IOError, OSError) as err: + logger.error('Error occurred when write timeline display file: %s', err) + raise ProfilerIOException + + def write_timeline_summary(self): + """Write timeline summary to json.""" + timeline_summary_file_path = os.path.join( + self._profiling_dir, + self._timeline_summary_filename.format(self._device_id) + ) + + timeline_summary_file_path = validate_and_normalize_path(timeline_summary_file_path) + + try: + with open(timeline_summary_file_path, 'w') as json_file: + json.dump(self._timeline_summary, json_file) + except (IOError, OSError) as err: + logger.error('Error occurred when write timeline summary file: %s', err) + raise ProfilerIOException + + def _load_timeline_data(self): + """Load timeline data from file.""" + file_path = os.path.join( + self._profiling_dir, + self._output_timeline_data_file_path.format(self._device_id) + ) + file_path = validate_and_normalize_path(file_path) + if not os.path.exists(file_path): + logger.error("Failed to find parsed timeline file.") + raise ProfilerFileNotFoundException('parsed timeline file') + + timeline_list = [] + try: + with open(file_path, 'r') as f_obj: + for line in f_obj: + if not line.startswith('op_name'): + line_list = line.strip('\n').split(',') + timeline_list.append(line_list) + except (IOError, OSError) as err: + logger.error('Error occurred when read timeline intermediate file: %s', err) + raise ProfilerIOException + + return timeline_list + + def _parse_timeline_data(self, timeline, min_cycle_counter): + """Parse timeline data.""" + # factor to convert the time unit from 1ms to 1us for timeline display + factor = 1000 + op_meta = TimelineContainer(timeline) + timeline_dict = {} + timeline_dict['name'] = op_meta.op_name + timeline_dict['ph'] = 'X' + timeline_dict['tid'] = op_meta.stream_id + timeline_dict['ts'] = (op_meta.start_time - min_cycle_counter) * factor + dur = op_meta.duration * factor + timeline_dict['dur'] = dur + if op_meta.pid is None: + timeline_dict['pid'] = int(self._device_id) + # Update total time of operator execution. + self._timeline_summary['total_time'] += dur + else: # AllReduce and AI CPU pid + timeline_dict['pid'] = op_meta.pid + self._timeline_meta.append(timeline_dict) + + @staticmethod + def _update_num_of_streams(timeline, stream_count_dict): + """Update number of streams.""" + stream_id = timeline[1] + if stream_id not in stream_count_dict.keys(): + stream_count_dict[stream_id] = 1 + else: + stream_count_dict[stream_id] += 1 + + def get_min_cycle_counter(self): + """ + Get minimum cycle counter. + + Returns: + float, the minimum value of the cycle counter. + """ + file_path = os.path.join( + self._profiling_dir, + self._min_cycle_counter_file_path.format(self._device_id) + ) + + file_path = validate_and_normalize_path(file_path) + + if os.path.exists(file_path): + try: + with open(file_path, 'r') as f_obj: + min_cycle_counter = f_obj.read() + min_cycle_counter = float(min_cycle_counter) \ + if not min_cycle_counter == 'inf' else 0 + except (IOError, OSError) as err: + logger.error('Error occurred when read minimum cycle counter: %s', err) + raise ProfilerIOException + else: + min_cycle_counter = 0 + logger.info("No min cycle counter recorded.") + + return min_cycle_counter + + def init_timeline(self, all_reduce_info, framework_info, aicpu_info, min_cycle_counter): + """ + Init timeline metadata, adding all collected info. + + Args: + all_reduce_info (list[list]): The metadata of AllReduce operator. + framework_info (dict): The framework metadata. + aicpu_info (dict): The metadata of AI CPU operator. + min_cycle_counter (float): The minimum cycle counter of the timeline. + """ + if min_cycle_counter == float('inf'): + min_cycle_counter = 0 + + logger.info('Initiating timeline...') + timeline_list = self._load_timeline_data() + self._timeline_summary['op_exe_times'] = len(timeline_list) + + # Add AllReduce info to timeline temp list and sort by start time. + if all_reduce_info: + logger.debug('AllReduce info found. Start adding info into timeline...') + timeline_list.extend(all_reduce_info) + timeline_list.sort(key=lambda x: float(x[2])) + + # Add AI CPU data into timeline temp list and sort by start time. + aicpu_data = aicpu_info.get('info') + if aicpu_data: + timeline_list.extend(aicpu_data) + timeline_list.sort(key=lambda x: float(x[2])) + self._timeline_summary['op_exe_times'] += aicpu_info.get('op_exe_times', 0) + self._timeline_summary['num_of_streams'] += aicpu_info.get('num_of_streams', 0) + self._timeline_summary['num_of_ops'] += aicpu_info.get('num_of_ops', 0) + self._timeline_summary['total_time'] += aicpu_info.get('total_time', 0) + + # Init a dict for counting the num of streams. + stream_count_dict = {} + for timeline in timeline_list: + self._parse_timeline_data(timeline, min_cycle_counter) + # Updating the collection of streams. + if len(timeline) == 4: + self._update_num_of_streams(timeline, stream_count_dict) + + # Get framework metadata. + framework_obj_list = framework_info.get('object') + # The length of list is the number of operators. + self._timeline_summary['num_of_ops'] += len(framework_obj_list) + self._add_framework_info(framework_obj_list) + logger.info('Finished adding info into timeline...') + + # Update timeline summary info + self._timeline_summary['num_of_streams'] += len(stream_count_dict.keys()) + + def _add_framework_info(self, framework_obj_list): + """ + Add framework info into timeline metadata. + + Args: + framework_obj_list (list): The framework metadata. + """ + logger.debug('Start adding framework info into timeline...') + # Get the framework info that will be written into timeline. + framework_info_dict = {} + for framework_obj in framework_obj_list: + op_name = framework_obj[0] + op_type = framework_obj[1] + op_full_name = framework_obj[4] + op_info = framework_obj[5] + framework_info_dict[op_full_name] = { + 'name': op_name, + 'args': { + 'type': op_type, + 'fullname': op_full_name + } + } + framework_info_dict[op_full_name]['args'].update(op_info) + + # Insert framework info into timeline. + for timeline_item in self._timeline_meta: + op_full_name = timeline_item.get('name') + framework_item = framework_info_dict.get(op_full_name) + if framework_item: + timeline_item['name'] = framework_item.get('name') + timeline_item['args'] = framework_item.get('args') + logger.debug('Finished adding framework info into timeline...') diff --git a/mindspore/profiler/parser/minddata_parser.py b/mindspore/profiler/parser/minddata_parser.py new file mode 100644 index 0000000000..27ab95f705 --- /dev/null +++ b/mindspore/profiler/parser/minddata_parser.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Minddata aicpu parser.""" +import os + +from mindspore.profiler.common.util import get_file_join_name, fwrite_format +from mindspore import log as logger + + +class MinddataParser: + """Minddata Aicpu Parser.""" + @staticmethod + def parse_minddata_aicpu_data(minddata_aicpu_source_path): + """ + Parse minddata get_next info which contains queue size and execute time. + + Args: + minddata_aicpu_source_path (str): the source file path. + + Returns: + list[Union[str, float]], the converted data. + """ + result = list() + try: + with open(minddata_aicpu_source_path) as source_data_file: + source_data = source_data_file.read() + step_data = source_data.split("\x00") + for one_step in step_data: + if one_step: + node_info = one_step.split(", ") + node_name, node_start, node_end, queue_size = "", 0, 0, 0 + if node_info: + node_name = node_info[0].replace("Node:", "") + if len(node_info) > 2: + node_start = node_info[1].replace("Run start:", "") + if node_start.isdigit(): + node_start = int(node_start) + node_end = node_info[2].replace("Run end:", "") + if node_end.isdigit(): + node_end = int(node_end) + if len(node_info) > 3: + queue_size = node_info[3].replace("queue size:", "") + if queue_size.isdigit(): + queue_size = int(queue_size) + + one_step_list = [node_name, node_start, node_end, queue_size] + result.append(one_step_list) + except OSError: + logger.error("Open get_next profiling file error.") + + return result + + @staticmethod + def execute(source_path, output_path, device_id): + """ + Execute the parser. + + Args: + source_path (str): the source file path. + output_path (str): the output file path. + device_id (str): the device id. + """ + col_names = ["node_name", "start_time", "end_time", "queue_size"] + minddata_aicpu_source_path = get_file_join_name( + input_path=source_path, file_name='DATA_PREPROCESS.dev.AICPUMI') + if not minddata_aicpu_source_path: + minddata_aicpu_source_path = get_file_join_name( + input_path=os.path.join(source_path, "data"), file_name='DATA_PREPROCESS.dev.AICPUMI') + if not minddata_aicpu_source_path: + return + minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + device_id + ".txt") + + minddata_aicpu_data = MinddataParser.parse_minddata_aicpu_data(minddata_aicpu_source_path) + if minddata_aicpu_data: + fwrite_format(minddata_aicpu_output_path, " ".join(col_names), is_start=True) + fwrite_format(minddata_aicpu_output_path, minddata_aicpu_data, is_start=True) diff --git a/mindspore/profiler/parser/minddata_pipeline_parser.py b/mindspore/profiler/parser/minddata_pipeline_parser.py new file mode 100644 index 0000000000..ea0c9ae366 --- /dev/null +++ b/mindspore/profiler/parser/minddata_pipeline_parser.py @@ -0,0 +1,287 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Thr parser for parsing minddata pipeline files.""" +import csv +import json +import os +from queue import Queue + +from mindspore.profiler.common.exceptions.exceptions import \ + ProfilerPathErrorException, ProfilerFileNotFoundException, \ + ProfilerDirNotFoundException, ProfilerRawFileException +from mindspore import log as logger +from mindspore.profiler.common.validator.validate_path import \ + validate_and_normalize_path + + +class MinddataPipelineParser: + """ + Thr parser for parsing minddata pipeline files. + + Args: + source_dir (str): The minddata pipeline source dir. + device_id (str): The device ID. + output_path (str): The directory of the parsed file. Default: `./`. + + Raises: + ProfilerPathErrorException: If the minddata pipeline file path or + the output path is invalid. + ProfilerFileNotFoundException: If the minddata pipeline file or + the output dir does not exist. + """ + _raw_pipeline_file_name = 'pipeline_profiling_{}.json' + _parsed_pipeline_file_name = 'minddata_pipeline_raw_{}.csv' + _col_names = [ + 'op_id', 'op_type', 'num_workers', 'output_queue_size', + 'output_queue_average_size', 'output_queue_length', + 'output_queue_usage_rate', 'sample_interval', 'parent_id', 'children_id' + ] + + def __init__(self, source_dir, device_id, output_path='./'): + self._device_id = device_id + self._pipeline_path = self._get_pipeline_path(source_dir) + self._save_path = self._get_save_path(output_path) + + @property + def save_path(self): + """ + The property of save path. + + Returns: + str, the save path. + """ + return self._save_path + + def parse(self): + """ + Parse the minddata pipeline files. + + Raises: + ProfilerRawFileException: If fails to parse the raw file of + minddata pipeline or the file is empty. + """ + with open(self._pipeline_path, 'r') as file: + try: + pipeline_info = json.load(file) + except (json.JSONDecodeError, TypeError) as err: + logger.warning(err) + raise ProfilerRawFileException( + 'Fail to parse minddata pipeline file.' + ) + if not pipeline_info: + logger.warning('The minddata pipeline file is empty.') + raise ProfilerRawFileException( + 'The minddata pipeline file is empty.' + ) + + self._parse_and_save(pipeline_info) + + def _get_pipeline_path(self, source_dir): + """ + Get the minddata pipeline file path. + + Args: + source_dir (str): The minddata pipeline source dir. + + Returns: + str, the minddata pipeline file path. + """ + pipeline_path = os.path.join( + source_dir, + self._raw_pipeline_file_name.format(self._device_id) + ) + + try: + pipeline_path = validate_and_normalize_path(pipeline_path) + except RuntimeError: + logger.warning('Minddata pipeline file is invalid.') + raise ProfilerPathErrorException('Minddata pipeline file is invalid.') + if not os.path.isfile(pipeline_path): + logger.warning( + 'The minddata pipeline file <%s> not found.', pipeline_path + ) + raise ProfilerFileNotFoundException(pipeline_path) + + return pipeline_path + + def _get_save_path(self, output_path): + """ + Get the save path. + + Args: + output_path (str): The output dir. + + Returns: + str, the save path. + """ + try: + output_dir = validate_and_normalize_path(output_path) + except ValidationError: + logger.warning('Output path is invalid.') + raise ProfilerPathErrorException('Output path is invalid.') + if not os.path.isdir(output_dir): + logger.warning('The output dir <%s> not found.', output_dir) + raise ProfilerDirNotFoundException(output_dir) + return os.path.join( + output_dir, self._parsed_pipeline_file_name.format(self._device_id) + ) + + def _parse_and_save(self, pipeline_info): + """ + Parse and save the parsed minddata pipeline file. + + Args: + pipeline_info (dict): The pipeline info reads from the raw file of + the minddata pipeline. + + Raises: + ProfilerRawFileException: If the format of minddata pipeline raw + file is wrong. + """ + sample_interval = pipeline_info.get('sampling_interval') + op_info = pipeline_info.get('op_info') + if sample_interval is None or not op_info: + raise ProfilerRawFileException( + 'The format of minddata pipeline raw file is wrong.' + ) + + op_id_info_cache = {} + for item in op_info: + op_id_info_cache[item.get('op_id')] = item + + with open(self._save_path, 'w') as save_file: + csv_writer = csv.writer(save_file) + csv_writer.writerow(self._col_names) + self._parse_and_save_op_info( + csv_writer, op_id_info_cache, sample_interval + ) + + def _parse_and_save_op_info(self, csv_writer, op_id_info_cache, + sample_interval): + """ + Parse and save the minddata pipeline operator information. + + Args: + csv_writer (csv.writer): The csv writer. + op_id_info_cache (dict): The operator id and information cache. + sample_interval (int): The sample interval. + + Raises: + ProfilerRawFileException: If the operator that id is 0 does not exist. + """ + queue = Queue() + root_node = op_id_info_cache.get(0) + if not root_node: + raise ProfilerRawFileException( + 'The format of minddata pipeline raw file is wrong, ' + 'the operator that id is 0 does not exist.' + ) + root_node['parent_id'] = None + queue.put_nowait(root_node) + + while not queue.empty(): + node = queue.get_nowait() + self._update_child_node(node, op_id_info_cache) + csv_writer.writerow(self._get_op_info(node, sample_interval)) + + op_id = node.get('op_id') + children_ids = node.get('children') + if not children_ids: + continue + for child_op_id in children_ids: + sub_node = op_id_info_cache.get(child_op_id) + sub_node['parent_id'] = op_id + queue.put_nowait(sub_node) + + def _update_child_node(self, node, op_id_info_cache): + """ + Updates the child node information of the operator. + + Args: + node (dict): The node represents an operator. + op_id_info_cache (dict): The operator id and information cache. + """ + child_op_ids = node.get('children') + if not child_op_ids: + return + + queue = Queue() + self._cp_list_item_to_queue(child_op_ids, queue) + + new_child_op_ids = [] + while not queue.empty(): + child_op_id = queue.get_nowait() + child_node = op_id_info_cache.get(child_op_id) + if child_node is None: + continue + metrics = child_node.get('metrics') + if not metrics or not metrics.get('output_queue'): + op_ids = child_node.get('children') + if op_ids: + self._cp_list_item_to_queue(op_ids, queue) + else: + new_child_op_ids.append(child_op_id) + + node['children'] = new_child_op_ids + + def _get_op_info(self, op_node, sample_interval): + """ + Get the operator information. + + Args: + op_node (dict): The node represents an operator. + sample_interval (int): The sample interval. + + Returns: + list[str, int, float], the operator information. + """ + queue_size = None + queue_average_size = None + queue_length = None + queue_usage_rate = None + metrics = op_node.get('metrics') + if metrics: + output_queue = metrics.get('output_queue') + if output_queue: + queue_size = output_queue.get('size') + queue_average_size = sum(queue_size) / len(queue_size) + queue_length = output_queue.get('length') + queue_usage_rate = queue_average_size / queue_length + + children_id = op_node.get('children') + op_info = [ + op_node.get('op_id'), + op_node.get('op_type'), + op_node.get('num_workers'), + queue_size, + queue_average_size, + queue_length, + queue_usage_rate, + sample_interval, + op_node.get('parent_id'), + children_id if children_id else None + ] + return op_info + + def _cp_list_item_to_queue(self, inner_list, queue): + """ + Copy the contents of a list to a queue. + + Args: + inner_list (list): The list. + queue (Queue): The target queue. + """ + for item in inner_list: + queue.put_nowait(item) diff --git a/mindspore/profiler/parser/optime_parser.py b/mindspore/profiler/parser/optime_parser.py new file mode 100644 index 0000000000..842376fcf3 --- /dev/null +++ b/mindspore/profiler/parser/optime_parser.py @@ -0,0 +1,245 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Op compute time files parser.""" +import os +from mindspore.profiler.common.util import fwrite_format +from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \ + ProfilerIOException +from mindspore import log as logger +from mindspore.profiler.common.validator.validate_path import validate_and_normalize_path +from mindspore.profiler.parser.container import HWTSContainer + +TIMELINE_FILE_COLUMN_TITLE = 'op_name, stream_id, start_time(ms), duration(ms)' + +class OPComputeTimeParser: + """ + Join hwts info and framework info, get op time info, and output to the result file. + + Args: + hwts_output_file (str): The file path of hwts_output_file. Such as: './output_format_data_hwts_0.txt". + output_filename (str): The output data file path and name. Such as: './output_op_compute_time_0.txt'. + op_task_info (dict): The task and op relation info. The format: {task_id, [opname, stream_id, block dim]}. + """ + + _dst_file_title = 'title:op compute time' + _dst_file_column_title = 'op_name compute_time(ms) stream_id' + _dst_file_column_title += '\n------------ --------------- ---------' + + def __init__(self, hwts_output_file, output_filename, op_task_info, + output_path, device_id): + hwts_output_file = validate_and_normalize_path(hwts_output_file) + self._hwts_output_file = hwts_output_file + self._output_filename = output_filename + self._op_task_info = op_task_info + self._output_path = output_path + self._device_id = device_id + self._min_cycle_counter = float("inf") + + def _get_op_task_id_map(self): + """ + Read hwts data file, get the task time info. + + Returns: + list: all hwts task time info. + """ + + op_map_result = [] + hwts_list = [] + + if not os.path.exists(self._hwts_output_file): + logger.error('The hwts output file does not exist.') + raise ProfilerFileNotFoundException('hwts output file') + + with open(self._hwts_output_file, 'r') as data_file: + lines = data_file.readlines() + for line in lines: + if line.startswith("Start of task") or line.startswith("End of task"): + line_split = line.split() + container = HWTSContainer(line_split) + hwts_list.append(container) + + # hwts op map by taskId + for hwts in hwts_list: + if hwts.task_id in self._op_task_info.keys(): + hwts.op_name = self._op_task_info[hwts.task_id] + op_map_result.append(hwts) + + return op_map_result + + def execute(self): + """Execute the parser, compute all op, get op time, and write it to the output file.""" + # Calculate the execution time of operators, + # and update the minimum cycle counter. + tmp_result_data = self._calculate_op_execution_time() + + # Convert time units from nanoseconds to milliseconds. + # The unit of the cycle counter is 10 nanoseconds. + op_name_time_dict = {} + op_name_stream_dict = {} + op_name_count_dict = {} + op_name_task_dict = {} + op_name_start_time = {} + self._convert_op_time_unit( + tmp_result_data, op_name_time_dict, op_name_stream_dict, + op_name_count_dict, op_name_task_dict, op_name_start_time + ) + + result_data = "" + total_time = 0 + for op_name, time in op_name_time_dict.items(): + if op_name in op_name_stream_dict.keys(): + stream_id = op_name_stream_dict[op_name] + avg_time = time / op_name_count_dict[op_name] + total_time += avg_time + result_data += ("%s %s %s\n" %(op_name, str(avg_time), stream_id)) + result_data += ("total op %s 0" %(str(total_time))) + + timeline_data = [] + for op_name, time in op_name_time_dict.items(): + if op_name in op_name_stream_dict.keys(): + stream_id = op_name_stream_dict[op_name] + start_time_list = op_name_start_time.get(op_name) + for (start_time, duration) in start_time_list: + timeline_data.append([op_name, stream_id, start_time, duration]) + + # Write the metadata of operators into the file, + # including operator name, average time, and stream id. + self._write_op_time_into_file(result_data) + # Write the timeline data into file, + # including operator name, stream id, start time, and duration. + self._write_timeline_data_into_file(timeline_data) + + def _write_op_time_into_file(self, result_data): + """ + Write the metadata of operators into the file, including + op name, average time, and stream id. + + Args: + result_data (str): The metadata to be written into the file. + 'op_name_1', 'avg_time_1', 'stream_id_1', + 'op_name_2', 'avg_time_2', 'stream_id_2', + ... + """ + + fwrite_format(self._output_filename, data_source=self._dst_file_title, is_start=True) + fwrite_format(self._output_filename, data_source=self._dst_file_column_title) + fwrite_format(self._output_filename, data_source=result_data) + + def _write_timeline_data_into_file(self, timeline_data): + """ + Write the timeline information into the file, including + operator name, stream id, start time and duration. + + Args: + timeline_data (list): The metadata to be written into the file. + [ + ['op_name_1', 'stream_id_1', 'start_time_1', 'durarion_1'], + ['op_name_2', 'stream_id_2', 'start_time_2', 'durarion_2'], + [...] + ] + """ + # sorted by start times + timeline_data.sort(key=lambda x: float(x[2])) + filename = 'output_timeline_data_{}.txt'.format(self._device_id) + file_path = os.path.join(self._output_path, filename) + file_path = validate_and_normalize_path(file_path) + + # write to file + try: + with open(file_path, 'w') as f_obj: + f_obj.write(TIMELINE_FILE_COLUMN_TITLE + '\n') + for timeline in timeline_data: + timeline = [str(item) for item in timeline] + f_obj.write(','.join(timeline) + '\n') + except (IOError, OSError) as err: + logger.error('Error occurred when writing intermediate timeline file: %s', err) + raise ProfilerIOException + + def _calculate_op_execution_time(self): + """ + Calculate the execution time of each operator. + + Returns: + list, including the intermediate data of op execution time. + """ + tmp_result_data = [] + op_map_list = self._get_op_task_id_map() + + cur_index = 0 + length = len(op_map_list) + min_cycle_counter = float("inf") + while cur_index < length: + if cur_index + 1 == length: + break + + op_start = op_map_list[cur_index] + op_end = op_map_list[cur_index + 1] + if op_start.status == "Start" and op_end.status == "End" \ + and op_start.op_name == op_end.op_name: + op_start.duration = op_end.cycle_counter - op_start.cycle_counter + tmp_result_data.append(op_start) + cur_index += 2 + if not op_start.op_name.startswith("assign"): + min_cycle_counter = min(min_cycle_counter, op_start.cycle_counter) + else: + cur_index += 1 + + # Update the value of minimum cycle counter. + self._min_cycle_counter = min_cycle_counter / 1e5 # Convert the time unit from 10ns to 1ms + + return tmp_result_data + + def _convert_op_time_unit(self, op_data_list, op_name_time_dict, op_name_stream_dict, + op_name_count_dict, op_name_task_dict, op_name_start_time): + """ + Calculate the execution time of operator and convert it into millisecond. + + Args: + op_data_list (list): The list of operator metadata. + op_name_time_dict (dict): The mapping relation of operator name and its execution time. + op_name_stream_dict (dict): The mapping relation of operator name and its stream id. + op_name_count_dict (dict): The mapping relation of operator name and its count. + op_name_task_dict (dict): The mapping relation of operator name and its task id. + op_name_start_time (dict): The mapping relation of operator name and its start time. + """ + factor = 1e5 + for item in op_data_list: + op_name = item.op_name + # Unit conversion: converting the cycle counter into ms. + op_start_time_str = str(item.cycle_counter / factor) + op_duration = item.duration / factor + op_duration_str = str(item.duration / factor) + if op_name in op_name_time_dict.keys(): + op_name_time_dict[op_name] += op_duration + if item.task_id == op_name_task_dict[op_name]: + op_name_count_dict[op_name] += 1 + op_name_start_time[op_name].append( + (op_start_time_str, op_duration_str) + ) + + else: + op_name_time_dict[op_name] = op_duration + op_name_stream_dict[op_name] = item.stream_id + op_name_task_dict[op_name] = item.task_id + op_name_count_dict[op_name] = 1 + op_name_start_time[op_name] = [] + op_name_start_time[op_name].append( + (op_start_time_str, op_duration_str) + ) + + @property + def min_cycle_counter(self): + """Get minimum cycle counter.""" + return self._min_cycle_counter diff --git a/mindspore/profiler/parser/step_trace_parser.py b/mindspore/profiler/parser/step_trace_parser.py new file mode 100644 index 0000000000..b39820d4bc --- /dev/null +++ b/mindspore/profiler/parser/step_trace_parser.py @@ -0,0 +1,382 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The parser for step trace data.""" +import csv +import json +import os +import stat +import struct +from collections import namedtuple +from decimal import Decimal + +from mindspore.profiler.common.exceptions.exceptions import ProfilerPathErrorException, \ + JobIdMismatchException, ProfilerIOException +from mindspore import log +from mindspore.profiler.common.util import get_summary_for_step_trace + +StepTraceStruct = namedtuple( + 'TrainingTraceStruct', ['tag_id', 'task_id', 'stream_id', 'sys_count'] +) + + +class StepTraceParser: + """ + The parser for step trace data. + + Args: + input_dir (str): The directory that contains original step trace data. + output_file_path (str): The output file path. + job_id (int): The job id used to define the start of new step. Default: 0. + skip_first_step (bool): Whether skip the first step or not. + """ + _event_size = 20 + _fp_tag = 1 + _bp_tag = 2 + _end_tag = 255 + + def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False): + self._input_dir = input_dir + self._output_path = output_file_path + self._job_id = job_id + self._skip_first_step = skip_first_step + self._result = [] + self._header = [] + self._step_num = 0 + self._tag_map = {} + + @property + def output_file(self): + """The property of step trace header.""" + file_name = self._output_path.rsplit('/', 2) + return file_name[-1] if len(file_name) == 3 else '' + + def show(self): + """The property of step trace info.""" + summary_info = {} + if self._result: + summary_info = get_summary_for_step_trace(self._result[-1], self._header) + summary_info['total_steps'] = len(self._result) - 1 + print('\nStep trace summary info (unit: syscnt):') + print(summary_info) + print('\nThe step trace parse result saves under ${summary_dir}/profiler/%s' + % self.output_file) + + def parse_and_save(self): + """Parse step trace files and save the result.""" + try: + source_files = self._get_step_trace_files() + self._parse(source_files) + self._save() + except IOError as err: + log.warning(err) + raise ProfilerIOException() + else: + log.info("Finish to save intermediate result for step trace file.") + + def record_point_info(self, point_info, output_path): + """ + Record point info into json. + + Args: + point_info (dict): The point info about tag id and relative op name. + output_path (str): The output path for saving point info. + + Returns: + dict, parsed point info. + """ + points = { + 'fp_start': point_info.get(self._fp_tag, ''), + 'bp_end': point_info.get(self._bp_tag, '') + } + try: + with open(output_path, 'w') as json_file: + json.dump(points, json_file) + os.chmod(output_path, stat.S_IREAD) + except (IOError, OSError) as err: + log.warning('Failed to save point info. %s', err) + raise ProfilerIOException + return points + + def update_tag_op_type_map(self, point_info): + """ + update the map from tag id to op type. + + Args: + point_info (dict): The point info about tag id and relative op name. + """ + tag_map = {} + for tag, op_name in point_info.items(): + op_type = self._get_op_type(tag, op_name) + tag_map[tag] = op_type + log.info("Get tag types for step trace analysis: %s", tag_map) + self._tag_map = tag_map + + def _get_op_type(self, tag, name): + """ + Get op type from tag and name. + + Args: + tag (int): The tag id. + name (str): The op name. + + Returns: + str, the op type. + """ + tag_map = {self._fp_tag: 'fp', self._bp_tag: 'bp', self._end_tag: 'end'} + # get solid tag type + op_type = tag_map.get(tag, '') + if op_type: + return op_type + # check if the tag is step tag. + if tag > self._end_tag or tag == 0: + return 'start' + # analyze the reduce tag + op_type = name.rsplit('/', 1)[-1].split('-')[0] + if not op_type: + log.warning("Unexpected op name:%s", name) + + return op_type + + def _get_step_trace_files(self): + """Get step trace files.""" + # step trace files may under $profiler_dir or $profiler_dir/data + profiler_dir = self._input_dir + step_trace_files = self._search_file(profiler_dir) + if not step_trace_files: + # try to find step trace files under $profiler_dir/data + profiler_dir = os.path.join(profiler_dir, 'data') + step_trace_files = self._search_file(profiler_dir) + if not step_trace_files: + raise ProfilerPathErrorException('Training trace file does not exist.') + + return step_trace_files + + @staticmethod + def _search_file(input_dir): + """Search step trace file under specific input directory.""" + # validate input_dir + if not os.path.isdir(input_dir): + raise ProfilerPathErrorException( + '{} does not exist or is not a dir'.format(input_dir) + ) + # get step trace files + files = os.listdir(input_dir) + step_trace_files = list( + filter( + lambda file: file.startswith('training_trace') and not file.endswith('.done'), + files + ) + ) + # validate result + if len(step_trace_files) > 1: + # the format of file name is like + # `training_trace.46.dev.profiler_default_tag.$id.slice_$number` + # use the $number as the sorted key + try: + step_trace_files.sort(key=lambda path: int(path.rsplit('_', 1)[-1])) + except ValueError as err: + log.warning("Unable to parse file names: %s. %s", step_trace_files, err) + step_trace_files = [] + + file_paths = [os.path.join(input_dir, file) for file in step_trace_files] + log.info("Find %d step trace files.", len(file_paths)) + return file_paths + + def _parse(self, source_files): + """Parse source step trace files.""" + log.info("Start to parse step trace file.") + event_info = {} + for source_file in source_files: + with open(source_file, 'rb') as handler: + content = handler.read() + for step_trace in self._get_next_step_trace(content, event_info): + if self._skip_first_step: + self._skip_first_step = False + continue + self._record_trace_event(step_trace) + self._record_average_info() + log.info("Finish to parse step trace file.") + + def _get_next_step_trace(self, content, event_info): + """ + Get next step trace info. + + Args: + content (bytes): The input step trace info. + event_info (dict): The event info. + + Returns: + Generator, return the step trace one by one. + """ + for pos in range(0, len(content), 20): + next_event = self._get_trace_struct(content[pos:pos + self._event_size]) + self._construct_event_info(next_event, event_info) + if event_info.get('end'): + yield event_info + + def _get_trace_struct(self, bin_info): + """Translate event info to StepTraceStruct.""" + if len(bin_info) == self._event_size: + parsed_info = struct.unpack('=QHHQ', bin_info) + return StepTraceStruct(*parsed_info) + return None + + def _construct_event_info(self, next_event, event_info): + """Construct event info according to next_event.""" + min_job_id = 255 + step_flag: bool = lambda tag: tag > min_job_id or tag == 0 + end_flag: bool = lambda tag: tag == min_job_id + fp_flag: bool = lambda tag: tag == self._fp_tag + bp_flag: bool = lambda tag: tag == self._bp_tag + + def _on_step_event(): + """Handle step event.""" + self._validate_tag_id(tag_id) + start_time = event_info.get('end', '-') + event_info.clear() + event_info['start'] = start_time + event_info['reduce'] = {} + + def _on_reduce_event(reduce_tag_id): + """Handle reduce event.""" + stream_id = next_event.stream_id + if event_info['reduce'].get(stream_id): + event_info['reduce'][stream_id].append((reduce_tag_id, sys_count)) + else: + event_info['reduce'][stream_id] = [(reduce_tag_id, sys_count)] + + tag_id = next_event.tag_id + sys_count = next_event.sys_count + if end_flag(tag_id): + event_info['end'] = sys_count + elif step_flag(tag_id): + _on_step_event() + elif fp_flag(tag_id): + event_info['fp'] = sys_count + elif bp_flag(tag_id): + event_info['bp'] = sys_count + else: + _on_reduce_event(tag_id) + + def _validate_tag_id(self, job_id): + """Check the job id in source step trace file is same as user set.""" + if not self._job_id: + self._job_id = job_id + elif self._job_id != job_id: + raise JobIdMismatchException() + + def _record_trace_event(self, step_trace): + """Record trace event.""" + self._step_num += 1 + start_time = step_trace.get('start') + end_time = step_trace.get('end') + fp_time = step_trace.get('fp') + bp_time = step_trace.get('bp') + if not (start_time and end_time and fp_time and bp_time): + log.warning("The step %d lacks basic time.", self._step_num) + return + if start_time == '-': + start_time = fp_time + row_data = { + 'step_num': self._step_num, + 'start_point': start_time, + 'end_point': end_time, + 'total': end_time - start_time, + 'fp_point': fp_time, + 'bp_point': bp_time, + 'iteration_interval': fp_time - start_time, + 'fp_and_bp': bp_time - fp_time, + 'tail': end_time - bp_time + } + # update reduce info + self._update_reduce_info(step_trace, row_data) + # save the row data + if not self._header: + self._header = list(row_data.keys()) + row_data_list = [row_data.get(header_name, 0) for header_name in self._header] + self._result.append(row_data_list) + + def _update_reduce_info(self, step_trace, row_data): + """Extract reduce info.""" + reduce_time = step_trace.get('reduce', {}) + for stream_id, time_points in reduce_time.items(): + time_point_num = len(time_points) + if time_point_num % 2: + log.warning("Stream %d has %d reduce time points.", stream_id, time_point_num) + continue + for index, point_id in enumerate(range(0, time_point_num, 2)): + field_name = f'stream_{stream_id}_{index}' + reduce_info = self._get_single_reduce_event_info( + field_name, time_points[point_id], time_points[point_id + 1]) + row_data.update(reduce_info) + + def _get_single_reduce_event_info(self, field_name, start_point, end_point): + """ + Get single reduce info. + + Args: + field_name (str): The field name. + start_point (Tuple[int, int]): Start point time info, including (tag_id, sys_count). + end_point (Tuple[int, int]): End point time info, including (tag_id, sys_count). + + Returns: + dict, reduce info. + """ + reduce_info = {} + if end_point[0] - start_point[0] != 1 or end_point[0] % 2: + log.warning("Unmatched reduce event <%s, %s>.", start_point, end_point) + return reduce_info + op_type = self._tag_map.get(start_point[0]) + # append field name with op type. + if not op_type: + log.warning("Can't recognize the inner type for point tag: %d.", start_point[0]) + field_name += '_parallel' + else: + field_name += '_' + op_type + reduce_info[field_name] = end_point[1] - start_point[1] + reduce_info[field_name + '_start_point'] = start_point[1] + reduce_info[field_name + '_end_point'] = end_point[1] + + return reduce_info + + def _record_average_info(self): + """Calculate average info.""" + result_size = len(self._result) + # calculate average data for each column in result data + average_data = [0] * len(self._header) + if result_size >= 2: + for row_info in self._result[1:]: + average_data = [ + Decimal(i) + Decimal(j) for i, j in zip(row_info, average_data) + ] + average_data = [ + round((item / (result_size - 1))) for item in average_data + ] + # change step num info in average_data to None + step_num_index = self._header.index('step_num') + average_data[step_num_index] = '-' + self._result.append(average_data) + log.info("Finish add average info for step trace.") + + def _save(self): + log.info("Start to save step trace file.") + if not self._header: + return + with open(self._output_path, 'w') as file_handle: + csv_writer = csv.writer(file_handle) + csv_writer.writerow(self._header) + for row_data in self._result: + csv_writer.writerow(row_data) + os.chmod(self._output_path, stat.S_IREAD) diff --git a/mindspore/profiler/profiling.py b/mindspore/profiler/profiling.py new file mode 100644 index 0000000000..ae908edf24 --- /dev/null +++ b/mindspore/profiler/profiling.py @@ -0,0 +1,441 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Profiling api file.""" +import os +import time + +from mindspore import log as logger, context +from mindspore.communication.management import release +from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \ + ProfilerIOException, ProfilerException +from mindspore.profiler.common.util import get_file_names, fwrite_format +from mindspore.profiler.common.validator.checkparam import \ + check_bool, check_subgraph +from mindspore.profiler.common.validator.validate_path import \ + validate_and_normalize_path +from mindspore.profiler.parser.aicpu_data_parser import DataPreProcessParser +from mindspore.profiler.parser.framework_parser import FrameworkParser +from mindspore.profiler.parser.hwts_log_parser import HWTSLogParser +from mindspore.profiler.parser.integrator import Integrator +from mindspore.profiler.parser.integrator import TimelineAnalyser +from mindspore.profiler.parser.minddata_parser import MinddataParser +from mindspore.profiler.parser.minddata_pipeline_parser import \ + MinddataPipelineParser +from mindspore.profiler.parser.optime_parser import OPComputeTimeParser +from mindspore.profiler.parser.step_trace_parser import StepTraceParser +from mindspore.nn.cell import Cell + +PROFILING_LOG_BASE_PATH = "/var/log/npu/profiling" +INIT_OP_NAME = 'Default/InitDataSetQueue' + + +class Profiler: + """ + Performance profiling API. + + Enable MindSpore users to profile the performance of neural network. + + Args: + subgraph (str): Define which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'. + is_detail (bool): Whether to show profiling data for op_instance level, only show optype level if False. + is_show_op_path (bool): Whether to save the full path for each op instance. + output_path (str): Output data path. + optypes_to_deal (str): Op type names, the data of which optype should be collected and analysed, + will deal with all op if null; Different op types should be seperated by comma. + optypes_not_deal (str): Op type names, the data of which optype will not be collected and analysed; + Different op types should be seperated by comma. + + Examples: + >>> from mindspore.profiler import Profiler + >>> import mindspore.context + >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", + >>> device_id=int(os.environ["DEVICE_ID"])) + >>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data') + >>> model = Model() + >>> model.train() + >>> profiler.analyse() + """ + + _base_profiling_container_path = "/var/log/npu/profiling/container" + _hwts_output_filename_target = "output_format_data_hwts_" + _opcompute_output_filename_target = "output_op_compute_time_" + _aicpu_op_output_filename_target = "output_data_preprocess_aicpu_" + + def __init__(self, subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data', + optypes_to_deal='', optypes_not_deal='Variable', job_id=""): + # get device_id and device_target + self._get_devid_and_devtarget() + self._container_path = os.path.join(self._base_profiling_container_path, self._dev_id) + data_path = os.path.join(self._container_path, "data") + if not os.path.exists(data_path): + os.makedirs(data_path, exist_ok=True) + self._output_path = validate_and_normalize_path(output_path) + self._output_path = os.path.join(self._output_path, "profiler") + if not os.path.exists(self._output_path): + os.makedirs(self._output_path, exist_ok=True) + + os.environ['PROFILING_MODE'] = 'true' + os.environ['PROFILING_OPTIONS'] = 'training_trace:task_trace' + os.environ['MINDDATA_PROFILING_DIR'] = self._output_path + os.environ['DEVICE_ID'] = self._dev_id + os.environ['AICPU_PROFILING_MODE'] = 'true' + os.environ['PROFILING_DIR'] = str(self._container_path) + + # use context interface to open profiling, for the new mindspore version(after 2020.5.21) + context.set_context(enable_profiling=True, profiling_options="training_trace:task_trace") + + self._subgraph = check_subgraph(subgraph) + self._valid_optype_name = optypes_to_deal.split(",") if optypes_to_deal else [] + self._filt_optype_names = optypes_not_deal.split(",") if optypes_not_deal else [] + self._detail = check_bool(is_detail, 'is_detail') + self._withfullpath = check_bool(is_show_op_path, 'is_show_op_path') + self._profiling_job_id = job_id + # add job id env through user input later + self._job_id_env = 0 + self._start_time = int(time.time() * 10000000) + logger.info("Profiling: profiling start time: %d", self._start_time) + + def analyse(self): + """ + Collect and analyse performance data, called after training or during training. + + Examples: + >>> from mindspore.profiler import Profiler + >>> import mindspore.context + >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", + >>> device_id=int(os.environ["DEVICE_ID"])) + >>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data') + >>> model = Model() + >>> model.train() + >>> profiler.analyse() + """ + release() + + job_id = self._get_profiling_job_id() + logger.info("Profiling: job id is %s ", job_id) + + source_path = os.path.join(PROFILING_LOG_BASE_PATH, job_id) + # parse hwts.log.data.45.dev file, and get task profiling data + hwts_output_filename = self._hwts_output_filename_target + self._dev_id + ".txt" + hwts_output_filename = os.path.join(self._output_path, hwts_output_filename) + hwtslog_parser = HWTSLogParser(source_path, hwts_output_filename) + result = hwtslog_parser.execute() + if not result: + logger.error("Profiling: fail to parse hwts log file.") + return + + # parse Framework file, and get the relation of op and tasks + framework_parser = FrameworkParser(job_id, self._dev_id, self._output_path) + framework_parser.parse() + op_task_dict = framework_parser.to_task_id_full_op_name_dict() + if not op_task_dict: + logger.error("Profiling: fail to parse framework files.") + return + + # get op compute time from hwts data and framework data, write output_op_compute_time.txt + opcompute_output_filename = self._opcompute_output_filename_target + self._dev_id + ".txt" + opcompute_output_filename = os.path.join(self._output_path, opcompute_output_filename) + optime_parser = OPComputeTimeParser( + hwts_output_filename, opcompute_output_filename, + op_task_dict, self._output_path, self._dev_id + ) + optime_parser.execute() + + # parse DATA_PREPROCESS.dev.AICPU file, write output_data_preprocess_aicpu_x.txt + output_data_preprocess_aicpu = self._aicpu_op_output_filename_target + self._dev_id + ".txt" + output_data_preprocess_aicpu = os.path.join(self._output_path, output_data_preprocess_aicpu) + aicpu_data_parser = DataPreProcessParser(source_path, output_data_preprocess_aicpu) + aicpu_data_parser.execute() + + # Parsing minddata AICPU profiling + MinddataParser.execute(source_path, self._output_path, self._dev_id) + + # parse minddata pipeline operator and queue + try: + pipeline_parser = MinddataPipelineParser(self._output_path, self._dev_id, self._output_path) + pipeline_parser.parse() + except ProfilerException as err: + logger.warning(err.message) + + # analyse op compute time info + try: + self._analyser_op_info() + except ProfilerException as err: + logger.warning(err.message) + + # analyse step trace info + try: + self._analyse_step_trace(source_path, framework_parser) + except ProfilerException as err: + logger.warning(err.message) + + # analyse timeline info + try: + self._analyse_timeline(aicpu_data_parser, optime_parser) + except (ProfilerIOException, ProfilerFileNotFoundException, RuntimeError) as err: + logger.warning('Fail to write timeline data: %s', err) + + def _analyse_step_trace(self, source_path, framework_parser): + """ + Analyse step trace data and save the result. + + Args: + source_path (str): The directory that contains the step trace original data. + framework_parser (FrameworkParser): The framework parse instance. + """ + logger.info("Begin to parse step trace.") + # construct output path + step_trace_intermediate_file_path = os.path.join( + self._output_path, + f'step_trace_raw_{self._dev_id}_detail_time.csv' + ) + point_info_file_path = os.path.join( + self._output_path, + 'step_trace_point_info.json' + ) + # whether keep the first step + skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME) + point_info = framework_parser.point_info + # parser the step trace files and save the result to disk + parser = StepTraceParser(input_dir=source_path, + output_file_path=step_trace_intermediate_file_path, + job_id=self._job_id_env, + skip_first_step=skip_first_step_flag) + parser.update_tag_op_type_map(point_info) + parser.parse_and_save() + point_info = parser.record_point_info(point_info, point_info_file_path) + # print parser result + parser.show() + logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path) + logger.info("The point info is: %s", point_info) + + def _analyse_timeline(self, aicpu_parser, optime_parser): + """ + Analyse and parse timeline info. + + Args: + aicpu_parser (DataPreProcessParser): The parser instance for AI CPU operator + execution time calculation. + optime_parser (OPComputeTimeParserParser): The parser instance for AI Core + operator execution time calculation. + """ + timeline_analyser = TimelineAnalyser(self._output_path, self._dev_id) + # Get framework info + integrator = Integrator(self._output_path, self._dev_id) + aicore_detail_data = integrator.get_aicore_detail_data() + aicore_detail_data_size = len(aicore_detail_data) + col_names = ['op_name', 'op_type', 'avg_execution_time', 'subgraph', + 'full_op_name', 'op_info'] + framework_info = { + 'col_name': col_names, + 'object': aicore_detail_data, + 'size': aicore_detail_data_size + } + + all_reduce_info = integrator.query_for_all_reduce() + + # Get timeline info + logger.info('Start writing timeline info...') + logger.info('Warm Prompt: It could take a few minutes if you are training ' + 'with a complex network or more than 10 steps.') + # Add info into timeline, such as AI CPU, AllReduce, framework info. + aicpu_info = aicpu_parser.query_aicpu_data() + min_cycle_counter = min(aicpu_parser.min_cycle_counter, optime_parser.min_cycle_counter) + timeline_analyser.init_timeline(all_reduce_info, framework_info, aicpu_info, min_cycle_counter) + timeline_analyser.write_timeline() + timeline_analyser.write_timeline_summary() + + def __del__(self): + """Disable the profiling collection service, called after training.""" + os.environ['PROFILING_MODE'] = str("false") + context.set_context(enable_profiling=False) + + def _get_profiling_job_id(self): + """Get profiling job id, which was generated by ada service. + + Returns: + str: profiling jon id. + """ + + if self._profiling_job_id: + return self._profiling_job_id + + job_id = "" + cmd = "ls -t " + PROFILING_LOG_BASE_PATH + "|grep JOB|awk '{print $1}'" + r = os.popen(cmd) + profiling_job_dirs = r.readlines() + r.close() + for item in profiling_job_dirs: + path = os.path.join(PROFILING_LOG_BASE_PATH, item.strip()) + log_file = get_file_names(path, "host_start.log") + if not log_file: + logger.error("Profiling: job path %s, host_start.log not exist.", path) + continue + + log_file = os.path.join(path, log_file[0]) + item_dict = self._parse_host_start_log(log_file) + + if not item_dict: + logger.error("Profiling: job path %s, fail to get job start info.", path) + continue + if self._start_time > int(item_dict["start_time"]): + logger.info("Profiling: job path %s, start_time %s, training start_time %d.", + path, item_dict["start_time"], self._start_time) + break + + if self._dev_id != item_dict["device_id"]: + logger.info("Profiling: job path %s, dev id %s, training device id %s.", + path, item_dict["device_id"], self._dev_id) + continue + + job_id = item.strip() + break + + if not job_id: + msg = "Fail to get profiling job, please check whether job dir was generated" + raise RuntimeError(msg) + + return job_id + + def _parse_host_start_log(self, input_file): + """ + Parse host start log file, get the device id and start time of the job. + + Args: + input_file (str): The file path of the host start log file. + + Returns: + dict, job start time and device id. + """ + + item_dict = {} + for line in open(input_file): + if "Device" in line: + item_dict["device_id"] = line[7:len(line)-2] + elif "clock_realtime" in line: + item_dict["start_time"] = line[16:len(line)-3] + + return item_dict + + def _analyser_op_info(self): + """Analyse the operator information.""" + integrator = Integrator(self._output_path, self._dev_id) + integrator.integrate() + + aicore_type_result = self._query_op_type_info() + detail_file_path = os.path.join( + self._output_path, + 'output_op_compute_time_detail_{}.txt'.format(self._dev_id) + ) + fwrite_format(detail_file_path, data_source='title:op compute time') + display_names = [ + 'optype_name', 'compute_time(ms, per-step)', + 'called_times(per-step)', 'percent' + ] + fwrite_format(detail_file_path, data_source=" ".join(display_names), is_print=True) + fwrite_format(detail_file_path, data_source=aicore_type_result, is_print=True) + + if self._detail: + op_type_order = [item[0] for item in aicore_type_result] + aicore_detail_result = self._query_op_detail_info(op_type_order) + + fwrite_format(detail_file_path, data_source='', is_print=True) + fwrite_format(detail_file_path, data_source='Detail:', is_print=True) + fwrite_format(detail_file_path, data_source=" ".join(aicore_detail_result.get('col_name_detail')), + is_print=True) + fwrite_format(detail_file_path, data_source=aicore_detail_result.get('object'), is_print=True) + + def _query_op_type_info(self): + """ + Query AICORE operator type information. + + Returns: + list[list], the AICORE operator type and execution time information. + """ + integrator = Integrator(self._output_path, self._dev_id) + return integrator.get_aicore_data() + + def _query_op_detail_info(self, op_type_order): + """ + Query AICORE operator detail information. + + Args: + op_type_order(list): The name of the op type in order. + + Returns: + dict, the AICORE operator detail information. + """ + + op_type_condition = {} + if self._valid_optype_name: + op_type_condition['in'] = self._valid_optype_name + if self._filt_optype_names: + op_type_condition['not_in'] = self._filt_optype_names + + subgraph_condition = {} + if self._subgraph != 'all': + subgraph_condition['in'] = [self._subgraph] + + filter_condition = { + 'op_type': op_type_condition, + 'subgraph': subgraph_condition, + 'is_display_detail': False, + 'is_display_full_op_name': self._withfullpath + } + integrator = Integrator(self._output_path, self._dev_id) + return integrator.query_and_sort_by_op_type(filter_condition, op_type_order) + + def _get_devid_and_devtarget(self): + """Get device id and target of this training.""" + + device_target = "" + dev_id = "" + try: + dev_id = str(context.get_context("device_id")) + device_target = context.get_context("device_target") + except ValueError as err: + logger.error("Profiling: fail to get context, %s", err) + + if not dev_id or not dev_id.isdigit(): + dev_id = os.getenv('DEVICE_ID') + if not dev_id or not dev_id.isdigit(): + dev_id = "0" + logger.error("Fail to get DEVICE_ID, use 0 instead.") + + if device_target and device_target != "Davinci" \ + and device_target != "Ascend": + msg = "Profiling: unsupport backend: %s" % device_target + raise RuntimeError(msg) + + self._dev_id = dev_id + + @staticmethod + def trainable_parameters(network): + """ + Get the number of trainable parameters in the training network. + + Args: + network(Cell): The training network. + + Returns: + an integer,the network of trainable parameters. + """ + if not isinstance(network, Cell): + msg = "Profiling: The network should be an object of nn.Cell" + raise ValueError(msg) + + param_nums = len(network.parameters_dict()) + + return param_nums diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 85fd6fa189..5d0ae10081 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -57,7 +57,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): # transform data format dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) - exec_dataset = exec_dataset.device_que() + send_epoch_end = bool(dataset_size == -1) + exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end) _executor.init_dataset(exec_dataset.queue_name, dataset_size, @@ -126,7 +127,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1): def _to_tensor(elem, scaling_sens=None): - """Conver numpy to tensor, adapt to minddata feed solution.""" + """Convert numpy to tensor, adapt to feed the data from host solution.""" lst = [] if not isinstance(elem, (tuple, list)): elem = [elem] @@ -145,7 +146,8 @@ def _to_tensor(elem, scaling_sens=None): def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None): - """Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution.""" + """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data + from host solution.""" lst = [] if not isinstance(elem, (tuple, list)): elem = [elem] diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index a47b16d0e0..e2da1618bf 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -102,7 +102,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): def construct(self, data, label): out = self._backbone(data) label = F.mixed_precision_cast(mstype.float32, label) - return self._loss_fn(F.cast(out, mstype.float32), label) + return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label) validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) if cast_model_type == mstype.float16: @@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`. If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting. keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. + Only `cast_model_type` is `float16`, `keep_batchnorm_fp32` will take effect. loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else scale the loss by LossScaleManager. If set, overwrite the level setting. """ diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index e0048ad713..b9e235aed5 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -15,10 +15,10 @@ """Checkpoint related classes and functions.""" import os -import shutil import stat import time +import threading import mindspore.context as context from mindspore import log as logger from mindspore._checkparam import check_bool, check_int_non_negative @@ -86,6 +86,7 @@ class CheckpointConfig: Can't be used with keep_checkpoint_max at the same time. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. + async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False Raises: ValueError: If the input_param is None or 0. @@ -100,21 +101,22 @@ class CheckpointConfig: save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, - integrated_save=True): + integrated_save=True, + async_save=False): - if not save_checkpoint_steps and not save_checkpoint_seconds and \ - not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: - raise ValueError("The input_param can't be all None or 0") - - if save_checkpoint_steps: + if save_checkpoint_steps is not None: save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) - if save_checkpoint_seconds: + if save_checkpoint_seconds is not None: save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds) - if keep_checkpoint_max: + if keep_checkpoint_max is not None: keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) - if keep_checkpoint_per_n_minutes: + if keep_checkpoint_per_n_minutes is not None: keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) + if not save_checkpoint_steps and not save_checkpoint_seconds and \ + not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: + raise ValueError("The input_param can't be all None or 0") + self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_seconds = save_checkpoint_seconds if self._save_checkpoint_steps and self._save_checkpoint_steps > 0: @@ -129,6 +131,7 @@ class CheckpointConfig: self._keep_checkpoint_max = 1 self._integrated_save = check_bool(integrated_save) + self._async_save = check_bool(async_save) @property def save_checkpoint_steps(self): @@ -155,6 +158,11 @@ class CheckpointConfig: """Get the value of _integrated_save.""" return self._integrated_save + @property + def async_save(self): + """Get the value of _async_save.""" + return self._async_save + def get_checkpoint_policy(self): """Get the policy of checkpoint.""" checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, @@ -165,7 +173,6 @@ class CheckpointConfig: return checkpoint_policy - class ModelCheckpoint(Callback): """ The checkpoint callback class. @@ -195,7 +202,7 @@ class ModelCheckpoint(Callback): raise ValueError("Prefix {} for checkpoint file name invalid, " "please check and correct it and then continue.".format(prefix)) - if directory: + if directory is not None: self._directory = _make_directory(directory) else: self._directory = _cur_dir @@ -238,6 +245,12 @@ class ModelCheckpoint(Callback): _to_save_last_ckpt = True self._save_ckpt(cb_params, _to_save_last_ckpt) + thread_list = threading.enumerate() + if len(thread_list) > 1: + for thread in thread_list: + if thread.getName() == "asyn_save_ckpt": + thread.join() + from mindspore.parallel._cell_wrapper import destroy_allgather_cell destroy_allgather_cell() @@ -282,8 +295,6 @@ class ModelCheckpoint(Callback): global _save_dir _save_dir = self._directory cur_file = os.path.join(self._directory, cur_ckpoint_file) - tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt' - gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process) self._last_time_for_keep = time.time() self._last_triggered_step = cb_params.cur_step_num @@ -291,10 +302,9 @@ class ModelCheckpoint(Callback): set_cur_net(cb_params.train_network) cb_params.train_network.exec_checkpoint_graph() - _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) + _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, + self._config.async_save) - if os.path.exists(gen_file): - shutil.move(gen_file, cur_file) self._latest_ckpt_file_name = cur_file @property diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index ded0e9a650..6ac0883268 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -73,7 +73,8 @@ class SummaryCollector(Callback): summary_dir (str): The collected data will be persisted to this directory. If the directory does not exist, it will be created automatically. collect_freq (int): Set the frequency of data collection, it should be greater then zero, - and the unit is `step`. Default: 10. The first step will be recorded at any time. + and the unit is `step`. Default: 10. If a frequency is set, we will collect data + at (current steps % freq) == 0, and the first step will be collected at any time. It is important to note that if the data sink mode is used, the unit will become the `epoch`. It is not recommended to collect data too frequently, which can affect performance. collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None. @@ -108,6 +109,18 @@ class SummaryCollector(Callback): custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight lineage page. In the custom data, the key type support str, and the value type support str/int/float. Default: None, it means there is no custom data. + collect_tensor_freq (Optional[int]): Same semantic as the `collect_freq`, but controls TensorSummary only. + Because TensorSummary data is too large compared to other summary data, this parameter is used to reduce + its collection. By default, TensorSummary data will be collected at most 21 steps, but not more than how + many steps other summary data will be collected. + Default: None, which means to follow the behavior as described above. For example, given `collect_freq=10`, + when the total steps is 600, TensorSummary will be collected 21 steps, while other summary data 61 steps, + but when the total steps is 20, both TensorSummary and other summary will be collected 3 steps. + Also note that when in parallel mode, the total steps will be splitted evenly, which will + affect how many steps TensorSummary will be collected. + max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. + Default: None, which means no limit. For example, to write not larger than 4GB, + specify `max_file_size=4 * 1024**3`. Raises: ValueError: If the parameter value is not expected. @@ -145,16 +158,28 @@ class SummaryCollector(Callback): 'histogram_regular': None } - def __init__(self, summary_dir, collect_freq=10, collect_specified_data=None, - keep_default_action=True, custom_lineage_data=None): + def __init__(self, + summary_dir, + collect_freq=10, + collect_specified_data=None, + keep_default_action=True, + custom_lineage_data=None, + collect_tensor_freq=None, + max_file_size=None): super(SummaryCollector, self).__init__() self._summary_dir = self._process_summary_dir(summary_dir) self._record = None - self._check_collect_freq(collect_freq) + self._check_positive('collect_freq', collect_freq) self._collect_freq = collect_freq + self._check_positive('collect_tensor_freq', collect_tensor_freq, allow_none=True) + self._collect_tensor_freq = collect_tensor_freq + + self._check_positive('max_file_size', max_file_size, allow_none=True) + self._max_file_size = max_file_size + self._check_action(keep_default_action) self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action) @@ -165,16 +190,15 @@ class SummaryCollector(Callback): self._custom_lineage_data = custom_lineage_data self._temp_optimizer = None - self._has_saved_train_network = False + self._has_saved_graph = False self._has_saved_custom_data = False self._is_parse_loss_success = True self._first_step = True self._dataset_sink_mode = True def __enter__(self): - self._first_step = True - self._dataset_sink_mode = True - self._record = SummaryRecord(log_dir=self._summary_dir) + self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size) + self._first_step, self._dataset_sink_mode = True, True return self def __exit__(self, *err): @@ -198,11 +222,13 @@ class SummaryCollector(Callback): return summary_dir @staticmethod - def _check_collect_freq(freq): - """Check collect freq type and value.""" - check_value_type('collect_freq', freq, int) - if freq <= 0: - raise ValueError(f'For `collect_freq` the value should be greater than 0, but got `{freq}`.') + def _check_positive(name, value, allow_none=False): + """Check if the value to be int type and positive.""" + if allow_none and value is None: + return + check_value_type(name, value, int) + if value <= 0: + raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.') @staticmethod def _check_custom_lineage_data(custom_lineage_data): @@ -269,42 +295,49 @@ class SummaryCollector(Callback): 'but got `{cb_params.mode}` mode.') self._record.set_mode(cb_params.mode) + if cb_params.mode == ModeEnum.TRAIN.value: - # Note: if model.init is not executed then the computed graph will not be obtained here - # The purpose of recording the graph here was to collect_freq if it was set to a large size, - # but also want to see the graph as soon after compilation. - self._collect_graphs(cb_params) + if self._collect_tensor_freq is None: + default_tensor_summary_limit = 20 + total_step = cb_params.epoch_num * cb_params.batch_num + self._collect_tensor_freq = max(self._collect_freq, total_step // default_tensor_summary_limit) + def step_end(self, run_context): + cb_params = run_context.original_args() + if cb_params.mode != ModeEnum.TRAIN.value: + return + + if not self._has_saved_graph: + self._collect_graphs(cb_params) self._collect_dataset_graph(cb_params) + self._has_saved_graph = True + self._record.record(cb_params.cur_step_num) if self._custom_lineage_data and not self._has_saved_custom_data: packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data) self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data) self._has_saved_custom_data = True + self._record.record(cb_params.cur_step_num) - # There's nothing special about setting step to 0 here, just to satisfy the interface call - self._record.record(step=0) - - def step_end(self, run_context): - cb_params = run_context.original_args() if self._first_step: # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario - self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) - - if cb_params.mode == ModeEnum.TRAIN.value: - - if not self._is_collect_this_step(cb_params): - return - - if not self._has_saved_train_network: - self._collect_graphs(cb_params) - - self._collect_input_data(cb_params) - self._collect_metric(cb_params) - self._collect_histogram(cb_params) - - self._first_step = False - self._record.record(cb_params.cur_step_num) + self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num + self._collect_at_step_end(cb_params, plugin_filter=None) + self._first_step = False + else: + current = cb_params.cur_epoch_num if self._dataset_sink_mode else cb_params.cur_step_num + if current % self._collect_freq == 0 and current % self._collect_tensor_freq == 0: + self._collect_at_step_end(cb_params, plugin_filter=None) + elif current % self._collect_tensor_freq == 0: + self._collect_at_step_end(cb_params, lambda plugin: plugin == PluginEnum.TENSOR.value) + elif current % self._collect_freq == 0: + self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value) + + def _collect_at_step_end(self, cb_params, plugin_filter): + self._collect_input_data(cb_params) + self._collect_metric(cb_params) + self._collect_histogram(cb_params) + self._record.record(cb_params.cur_step_num, plugin_filter=plugin_filter) def end(self, run_context): cb_params = run_context.original_args() @@ -331,18 +364,6 @@ class SummaryCollector(Callback): raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," f"but expected only one {self.__class__.__name__} instance.") - def _is_collect_this_step(self, cb_params): - """Decide whether to collect data for the current step.""" - # Make sure the first step data is recorded - if not self._first_step: - if self._dataset_sink_mode: - if cb_params.cur_epoch_num % self._collect_freq: - return False - else: - if cb_params.cur_step_num % self._collect_freq: - return False - return True - @staticmethod def _package_custom_lineage_data(custom_lineage_data): """ @@ -411,7 +432,6 @@ class SummaryCollector(Callback): if graph_proto is None: return - self._has_saved_train_network = True self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto) def _collect_metric(self, cb_params): @@ -582,7 +602,7 @@ class SummaryCollector(Callback): else: train_lineage[LineageMetadata.learning_rate] = None train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None - train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network) + train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__ loss_fn = self._get_loss_fn(cb_params) train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None @@ -739,30 +759,6 @@ class SummaryCollector(Callback): return ckpt_file_path - @staticmethod - def _get_backbone(network): - """ - Get the name of backbone network. - - Args: - network (Cell): The train network. - - Returns: - Union[str, None], If parse success, will return the name of the backbone network, else return None. - """ - backbone_name = None - backbone_key = '_backbone' - - for _, cell in network.cells_and_names(): - if hasattr(cell, backbone_key): - backbone_network = getattr(cell, backbone_key) - backbone_name = type(backbone_network).__name__ - - if backbone_name is None and network is not None: - backbone_name = type(network).__name__ - - return backbone_name - @staticmethod def _get_loss_fn(cb_params): """ diff --git a/mindspore/train/callback/_time_monitor.py b/mindspore/train/callback/_time_monitor.py index 9fbdf83aa8..f5a5815041 100644 --- a/mindspore/train/callback/_time_monitor.py +++ b/mindspore/train/callback/_time_monitor.py @@ -16,13 +16,19 @@ import time +from mindspore import log as logger from ._callback import Callback class TimeMonitor(Callback): - """Time Monitor.""" + """ + Monitor the time in training. - def __init__(self, data_size): + Args: + data_size (int): Dataset size. Default: None. + """ + + def __init__(self, data_size=None): super(TimeMonitor, self).__init__() self.data_size = data_size @@ -30,6 +36,17 @@ class TimeMonitor(Callback): self.epoch_time = time.time() def epoch_end(self, run_context): - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / self.data_size - print("Epoch time: {:5.3f}, per step time: {:5.3f}".format(epoch_mseconds, per_step_mseconds), flush=True) + epoch_seconds = (time.time() - self.epoch_time) * 1000 + step_size = self.data_size + cb_params = run_context.original_args() + if hasattr(cb_params, "batch_num"): + batch_num = cb_params.batch_num + if isinstance(batch_num, int) and batch_num > 0: + step_size = cb_params.batch_num + + if not isinstance(step_size, int) or step_size < 1: + logger.error("data_size must be positive int.") + return + + step_seconds = epoch_seconds / step_size + print("Epoch time: {:5.3f}, per step time: {:5.3f}".format(epoch_seconds, step_seconds), flush=True) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 75e1deabc4..077463ac9f 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -16,7 +16,7 @@ import math import os -from mindspore._checkparam import check_bool +from mindspore._checkparam import check_bool, check_int from .. import context from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ _construct_tensor_list, _to_full_shapes, _to_full_tensor @@ -24,35 +24,48 @@ from ..nn.wrap import GetNextSingleOp from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full -def _send_data(dataset): +def _send_data(dataset, epoch_num): """Engine dataset to write data to tdt queue.""" if not hasattr(dataset, '__has_sent__'): exec_dataset = dataset.__TRANSFER_DATASET__ - exec_dataset.send() + exec_dataset.send(epoch_num) dataset.__has_sent__ = True +def _send_data_no_flag(dataset, epoch_num): + """Engine dataset to write data to tdt queue directly.""" + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send(epoch_num) + class DatasetHelper: """ - Help function to use the Minddata dataset. + Help function to use the MindData dataset. - According to different context, change the iter of dataset, to use the same for loop in different context. + According to different contexts, change the iterations of dataset and use the same iteration for loop in different + contexts. Note: - The iter of DatasetHelper will give one epoch data. + The iteration of DatasetHelper will provide one epoch data. Args: - dataset (DataSet): The dataset. - dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. - Default: True. + dataset (DataSet): The training dataset iterator. + dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True. + sink_size (int): Control the amount of data in each sink. + If sink_size=-1, sink the complete dataset for each epoch. + If sink_size>0, sink sink_size data for each epoch. Default: -1. + epoch_num (int): Control the number of epoch data to send. Default: 1. Examples: >>> dataset_helper = DatasetHelper(dataset) >>> for inputs in dataset_helper: >>> outputs = network(*inputs) """ - def __init__(self, dataset, dataset_sink_mode=True): + + def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1): check_bool(dataset_sink_mode) + check_int(sink_size) + if sink_size < -1 or sink_size == 0: + raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) if dataset_sink_mode: if context.get_context("enable_ge"): @@ -68,77 +81,110 @@ class DatasetHelper: iterclass = _DatasetIterMS elif context.get_context("device_target") == "CPU": raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") + self.iter = iterclass(dataset, sink_size, epoch_num) else: - iterclass = _DatasetIterFeed - self.iter = iterclass(dataset) + iterclass = _DatasetIterNormal + self.iter = iterclass(dataset) def __iter__(self): return self.iter.__iter__() # A temp solution for loop sink. Delete later def types_shapes(self): - """Get the types and shapes from dataset on current config.""" + """Get the types and shapes from dataset on the current configuration.""" return self.iter.types_shapes() - def loop_size(self): - """Get loop_size for every iteration.""" - return self.iter.loop_size + def sink_size(self): + """Get sink_size for each iteration.""" + return self.iter.get_sink_size() + + def stop_send(self): + """Free up resources about data sink.""" + self.iter.stop_send() class _DatasetIter: - """Base iter for dataset help""" - def __init__(self, dataset): - if not hasattr(dataset, '__loop_size__'): - self.loop_size = dataset.get_dataset_size() - else: - self.loop_size = dataset.__loop_size__ + """Base iter for dataset helper""" + def __init__(self, dataset, sink_size, epoch_num): + self.dataset = dataset + self.sink_size = sink_size + self.sink_count = 1 - if not hasattr(dataset, '__ME_INITED__'): - dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + if not hasattr(dataset, '__TRANSFER_DATASET__'): + if hasattr(dataset, '__loop_size__'): + self.sink_size = dataset.__loop_size__ + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name if not hasattr(dataset, '__no_send__'): - _send_data(dataset) + _send_data(dataset, epoch_num) else: - _send_data(dataset) + _send_data_no_flag(dataset, epoch_num) - self.ind = 0 - self.dataset = dataset - dataset_types, dataset_shapes = _get_types_and_shapes(dataset) - self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes + self.stop_send = dataset.__TRANSFER_DATASET__.stop_send + self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) def __iter__(self): - self.ind = 0 + self.index = 0 return self def __next__(self): - if self.ind >= self.loop_count: + if self.index >= self.sink_count: raise StopIteration() - self.ind += 1 + self.index += 1 return self.op() def types_shapes(self): return self.dataset_types, self.dataset_shapes - def get_loop_count(self, dataset): - loop_count = 1 + def get_sink_count(self, dataset): + sink_count = 1 if hasattr(dataset, '__loop_size__'): loop_size = dataset.__loop_size__ if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0: raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' - f'loop_size {loop_size} are not matched.') - loop_count = math.ceil(dataset.get_dataset_size() / loop_size) - return loop_count + f'sink_size {loop_size} are not matched.') + sink_count = math.ceil(dataset.get_dataset_size() / loop_size) + return sink_count + + def get_sink_size(self): + """get sink_size to device""" + sink_size = 1 + if hasattr(self.dataset, '__loop_size__'): + sink_size = self.dataset.__loop_size__ + else: + if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend": + if self.sink_size > 0: + sink_size = self.sink_size + else: + sink_size = self.dataset.get_dataset_size() + return sink_size + + +class _DatasetIterGE(_DatasetIter): + """Iter for GE.""" + def __init__(self, dataset, sink_size, epoch_num): + super().__init__(dataset, sink_size, epoch_num) + self.sink_count = self.get_sink_count(dataset) + batch_expand_num = 1 + if _need_to_full(): + batch_expand_num = _get_device_num() + tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) + + def op(): + return tensor_list_run + + self.op = op class _DatasetIterMSLoopSink(_DatasetIter): """Iter for context (device_target=Ascend)""" - def __init__(self, dataset): - super(_DatasetIterMSLoopSink, self).__init__(dataset) - self.loop_count = self.get_loop_count(dataset) + def __init__(self, dataset, sink_size, epoch_num): + super().__init__(dataset, sink_size, epoch_num) + self.sink_count = self.get_sink_count(dataset) ms_role = os.getenv("MS_ROLE") if ms_role in ("MS_PSERVER", "MS_SCHED"): - self.loop_count = 1 + self.sink_count = 1 # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. @@ -153,66 +199,42 @@ class _DatasetIterMSLoopSink(_DatasetIter): class _DatasetIterMS(_DatasetIter): - """Iter for context (device_target=GPU)""" - def __init__(self, dataset): - super(_DatasetIterMS, self).__init__(dataset) - self.loop_count = dataset.get_dataset_size() - self.loop_size = 1 + """Iter for MS(enable_loop_sink=False).""" + def __init__(self, dataset, sink_size, epoch_num): + super().__init__(dataset, sink_size, epoch_num) + if sink_size > 0: + self.sink_count = sink_size + else: + self.sink_count = dataset.get_dataset_size() + queue_name = dataset.__ME_INITED__ self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) class _DatasetIterPSLite(_DatasetIter): """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" - def __init__(self, dataset): - super(_DatasetIterPSLite, self).__init__(dataset) - self.loop_count = 1 - self.loop_size = 1 + def __init__(self, dataset, sink_size, epoch_num): + super().__init__(dataset, sink_size, epoch_num) + self.sink_count = 1 + self.sink_size = 1 self.op = None def op(): return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1) self.op = op -class _DatasetIterGE(_DatasetIter): - """Iter for ge""" - def __init__(self, dataset): - super(_DatasetIterGE, self).__init__(dataset) - self.loop_count = self.get_loop_count(dataset) - batch_expand_num = 1 - if _need_to_full(): - batch_expand_num = _get_device_num() - tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) - - def op(): - return tensor_list_run - - self.op = op - - -class _DatasetIterFeed: +class _DatasetIterNormal: """Iter for normal(non sink) mode, feed the data from host.""" def __init__(self, dataset): self.dataset = dataset self.device_num = _get_device_num() self.global_rank = _get_global_rank() - self.repeat_count = dataset.get_repeat_count() - self.repeat_ind = 0 - self.loop_count = dataset.get_dataset_size() - self.ind = 0 + self.iter = self.dataset.create_tuple_iterator() def __iter__(self): - if self.repeat_ind % self.repeat_count == 0: - self.iter = self.dataset.__iter__() - - self.repeat_ind += 1 - self.ind = 0 return self def __next__(self): - if self.ind >= self.loop_count: - raise StopIteration() - self.ind += 1 data = self.iter.__next__() if _need_to_full(): return _to_full_tensor(data, self.device_num, self.global_rank) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 74fd668e82..04ee2fe40e 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -16,12 +16,13 @@ from collections.abc import Iterable import os +import math import numpy as np from mindspore import log as logger from ..common.tensor import Tensor from ..nn.metrics import get_metrics -from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool +from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int from .callback import _InternalCallbackParam, RunContext, _CallbackManager from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ @@ -30,6 +31,8 @@ from ..nn.metrics import Loss from .. import nn from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from .parallel_utils import ParallelMode +from ._utils import _to_full_tensor +from ..parallel._utils import _need_to_full from ..common import dtype as mstype from .dataset_helper import DatasetHelper from . import amp @@ -42,20 +45,20 @@ class Model: `Model` groups layers into an object with training and inference features. Args: - network (Cell): The training or testing network. + network (Cell): A training or testing network. loss_fn (Cell): Objective function, if loss_fn is None, the network should contain the logic of loss and grads calculation, and the logic of parallel if needed. Default: None. optimizer (Cell): Optimizer for updating the weights. Default: None. - metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during + metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during training and testing. eg: {'accuracy', 'recall'}. Default: None. eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as `eval_network`. Default: None. - eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of + eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three - elements, representing the positions of loss value, predict value and label, the loss - value would be passed to `Loss` metric, predict value and label would be passed to other - metric. Default: None. + elements, including the positions of loss value, predicted value and label. The loss + value would be passed to the `Loss` metric, the predicted value and label would be passed + to other metric. Default: None. amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed precision training. Supports [O0, O2, O3]. Default: "O0". @@ -65,10 +68,11 @@ class Model: O2 is recommended on GPU, O3 is recommended on Ascend. - loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else - scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument. + loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise, + scale the loss by LossScaleManager. It is a key argument. e.g. Use `loss_scale_manager=None` to set the value. - keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. + keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before + will be overwritten. Default: True. Examples: >>> class Net(nn.Cell): @@ -171,7 +175,7 @@ class Model: else: if self._loss_fn is None: raise ValueError("loss_fn can not be None.") - self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"]) self._eval_indexes = [0, 1, 2] if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): @@ -225,7 +229,7 @@ class Model: scaling_sens /= self._device_number return scaling_sens - def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode): + def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1): """Initializes dataset.""" need_wrap = False if dataset_sink_mode: @@ -237,7 +241,7 @@ class Model: if not is_train: dataset.__loop_size__ = 1 - dataset_helper = DatasetHelper(dataset, dataset_sink_mode) + dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num) # remove later to deal with loop sink if need_wrap: @@ -252,16 +256,16 @@ class Model: def init(self, train_dataset=None, valid_dataset=None): """ - Initializes compute graphs and data graphs with sink mode. + Initialize compute graphs and data graphs with the sink mode. Note: Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. Args: - train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be + train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be initialized. Default: None. - valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will - be initialized, and `metrics` in `Model` can not be None. Default: None. + valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs + will be initialized, and `metrics` in `Model` can not be None. Default: None. Examples: >>> train_dataset = get_train_dataset() @@ -317,21 +321,23 @@ class Model: self._eval_network.compile(*inputs) break - def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): + def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1): """ Training. Args: epoch (int): Total number of iterations on the data. train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be + loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be returned and passed to the network. Otherwise, a tuple (data, label) will - be returned, and the data and label are passed to the network and loss + be returned. The data and label would be passed to the network and loss function respectively. - callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None. - dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + callbacks (list): List of callback objects which should be executed while training. Default: None. + dataset_sink_mode (bool): Determine whether the data should be passed through the dataset channel. + Default: True. Configure pynative mode, the training process will be performed with dataset not sink. + sink_size (int): Control the amount of data in each sink. Default: -1. """ epoch = check_int_positive(epoch) self._train_network.set_train() @@ -342,7 +348,10 @@ class Model: cb_params = _InternalCallbackParam() cb_params.train_network = self._train_network cb_params.epoch_num = epoch - cb_params.batch_num = train_dataset.get_dataset_size() + if dataset_sink_mode and sink_size > 0: + cb_params.batch_num = sink_size + else: + cb_params.batch_num = train_dataset.get_dataset_size() cb_params.mode = "train" cb_params.loss_fn = self._loss_fn cb_params.optimizer = self._optimizer @@ -351,6 +360,7 @@ class Model: cb_params.train_dataset = train_dataset cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.train_dataset_element = None + cb_params.network = self._network ms_role = os.getenv("MS_ROLE") if ms_role in ("MS_PSERVER", "MS_SCHED"): epoch = 1 @@ -364,7 +374,7 @@ class Model: "So the training process will be performed with dataset not sink.") self._train_process(epoch, train_dataset, list_callback, cb_params) else: - self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) + self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) @staticmethod def _transform_callbacks(callbacks): @@ -377,30 +387,37 @@ class Model: return [callbacks] - def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): + def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): """ Training process. The data would be passed to network through dataset channel. Args: epoch (int): Total number of iterations on the data. train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be returned and passed to the network. Otherwise, a tuple (data, label) should - be returned, and the data and label are passed to the network and loss + be returned. The data and label would be passed to the network and loss function respectively. list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. + sink_size (int): Control the amount of data in each sink. Default: -1. """ + if sink_size == -1: + epoch_num = epoch + else: + epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) + dataset_helper, train_network = self._exec_preprocess(self._train_network, is_train=True, phase='train', dataset=train_dataset, - dataset_sink_mode=True) + dataset_sink_mode=True, + sink_size=sink_size, + epoch_num=epoch_num) self._train_network = train_network cb_params.train_network = self._train_network cb_params.cur_step_num = 0 - loop_size = dataset_helper.loop_size() run_context = RunContext(cb_params) list_callback.begin(run_context) @@ -412,9 +429,11 @@ class Model: # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: - cb_params.cur_step_num += loop_size + if _need_to_full() and context.get_context("device_target") == "GPU": + inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) list_callback.step_begin(run_context) outputs = self._train_network(*inputs) + cb_params.cur_step_num += dataset_helper.sink_size() cb_params.net_outputs = outputs list_callback.step_end(run_context) @@ -422,6 +441,7 @@ class Model: should_stop = should_stop or run_context.get_stop_requested() if should_stop: break + dataset_helper.stop_send() list_callback.end(run_context) @@ -432,9 +452,9 @@ class Model: Args: epoch (int): Total number of iterations on the data. train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be returned and passed to the network. Otherwise, a tuple (data, label) should - be returned, and the data and label are passed to the network and loss + be returned. The data and label would be passed to the network and loss function respectively. list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. @@ -490,7 +510,7 @@ class Model: list_callback.end(run_context) - def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): + def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1): """ Training API where the iteration is controlled by python front-end. @@ -500,22 +520,27 @@ class Model: CPU is not supported when dataset_sink_mode is true. If dataset_sink_mode is True, epoch of training should be equal to the count of repeat operation in dataset processing. Otherwise, errors could occur since the amount of data - is not the amount training requires. + is not equal to the required amount of training . If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features of data will be transferred one by one. The limitation of data transmission per time is 256M. Args: - epoch (int): Total number of iterations on the data. + epoch (int): Generally, total number of iterations on the data per epoch. + When dataset_sink_mode is set to true and sink_size>0, each epoch sink sink_size + steps on the data instead of total number of iterations. train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be returned and passed to the network. Otherwise, a tuple (data, label) should - be returned, and the data and label are passed to the network and loss + be returned. The data and label would be passed to the network and loss function respectively. - callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None. + callbacks (list): List of callback objects which should be executed while training. Default: None. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. Configure pynative mode, the training process will be performed with dataset not sink. - + sink_size (int): Control the amount of data in each sink. + If sink_size=-1, sink the complete dataset for each epoch. + If sink_size>0, sink sink_size data for each epoch. + If dataset_sink_mode is False, set sink_size as invalid. Default: -1. Examples: >>> dataset = get_dataset() @@ -526,17 +551,19 @@ class Model: >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) >>> model.train(2, dataset) """ - repeat_count = train_dataset.get_repeat_count() - if epoch != repeat_count and dataset_sink_mode is True: - logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}") check_bool(dataset_sink_mode) + check_int(sink_size) + if sink_size < -1 or sink_size == 0: + raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) + _device_number_check(self._parallel_mode, self._device_number) _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) self._train(epoch, train_dataset, callbacks=callbacks, - dataset_sink_mode=dataset_sink_mode) + dataset_sink_mode=dataset_sink_mode, + sink_size=sink_size) def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None): """ @@ -548,7 +575,7 @@ class Model: cb_params (_InternalCallbackParam): Callback parameters. Default: None. Returns: - Dict, returns the loss value & metrics values for the model in test mode. + Dict, which returns the loss value and metrics values for the model in the test mode. """ run_context = RunContext(cb_params) @@ -587,7 +614,7 @@ class Model: cb_params (_InternalCallbackParam): Callback parameters. Default: None. Returns: - Dict, returns the loss value & metrics values for the model in test mode. + Dict, which returns the loss value and metrics values for the model in the test mode. """ run_context = RunContext(cb_params) list_callback.begin(run_context) @@ -605,6 +632,8 @@ class Model: list_callback.step_end(run_context) self._update_metrics(outputs) + valid_dataset.reset() + metrics = self._get_metrics() cb_params.metrics = metrics list_callback.end(run_context) @@ -623,12 +652,11 @@ class Model: Args: valid_dataset (Dataset): Dataset to evaluate the model. - callbacks (list): List of callback object. Callbacks which should be excuted - while training. Default: None. + callbacks (list): List of callback objects which should be executed while training. Default: None. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. Returns: - Dict, returns the loss value & metrics values for the model in test mode. + Dict, which returns the loss value and metrics values for the model in the test mode. Examples: >>> dataset = get_dataset() @@ -649,6 +677,7 @@ class Model: cb_params.mode = "eval" cb_params.cur_step_num = 0 cb_params.list_callback = self._transform_callbacks(callbacks) + cb_params.network = self._network self._eval_network.set_train(mode=False) self._eval_network.phase = 'eval' @@ -662,9 +691,9 @@ class Model: def predict(self, *predict_data): """ - Generates output predictions for the input samples. + Generate output predictions for the input samples. - Data could be single tensor, or list of tensor, tuple of tensor. + Data could be a single tensor, a list of tensor, or a tuple of tensor. Note: Batch data should be put together in one tensor. diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index b553373f10..b947811030 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -29,14 +29,15 @@ from ...common import dtype as mstype from ...common.api import _executor from ...nn.layer import quant from ...ops import functional as F +from ...ops import operations as P from ...ops.operations import _inner_ops as inner from ...train import serialization from . import quant_utils _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant, nn.ReLU6: quant.ActQuant, - nn.LeakyReLU: quant.ActQuant, nn.Sigmoid: quant.ActQuant, + nn.LeakyReLU: quant.LeakyReLUQuant, nn.HSigmoid: quant.HSigmoidQuant, nn.HSwish: quant.HSwishQuant} @@ -167,32 +168,60 @@ class ConvertToQuantNetwork: convert Conv2d cell to quant cell """ conv_inner = subcell.conv - if subcell.has_bn and self.bn_fold: - bn_inner = subcell.batchnorm - conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - quant_delay=self.weight_qdelay, - freeze_bn=self.freeze_bn, - per_channel=self.weight_channel, - num_bits=self.weight_bits, - fake=True, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range) - # change original network BatchNormal OP parameters to quant network - conv_inner.gamma = subcell.batchnorm.gamma - conv_inner.beta = subcell.batchnorm.beta - conv_inner.moving_mean = subcell.batchnorm.moving_mean - conv_inner.moving_variance = subcell.batchnorm.moving_variance - del subcell.batchnorm - subcell.batchnorm = None - subcell.has_bn = False + if subcell.has_bn: + if self.bn_fold: + bn_inner = subcell.batchnorm + conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=bn_inner.momentum, + quant_delay=self.weight_qdelay, + freeze_bn=self.freeze_bn, + per_channel=self.weight_channel, + num_bits=self.weight_bits, + fake=True, + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) + # change original network BatchNormal OP parameters to quant network + conv_inner.gamma = subcell.batchnorm.gamma + conv_inner.beta = subcell.batchnorm.beta + conv_inner.moving_mean = subcell.batchnorm.moving_mean + conv_inner.moving_variance = subcell.batchnorm.moving_variance + del subcell.batchnorm + subcell.batchnorm = None + subcell.has_bn = False + else: + bn_inner = subcell.batchnorm + conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=bn_inner.momentum, + has_bn=True, + quant_delay=self.weight_qdelay, + per_channel=self.weight_channel, + num_bits=self.weight_bits, + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) + # change original network BatchNormal OP parameters to quant network + conv_inner.batchnorm.gamma = subcell.batchnorm.gamma + conv_inner.batchnorm.beta = subcell.batchnorm.beta + conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean + conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance + del subcell.batchnorm + subcell.batchnorm = None + subcell.has_bn = False else: conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels, @@ -259,7 +288,7 @@ class ConvertToQuantNetwork: act_class = activation.__class__ if act_class not in _ACTIVATION_MAP: raise ValueError("Unsupported activation in auto quant: ", act_class) - return _ACTIVATION_MAP[act_class](activation=act_class, + return _ACTIVATION_MAP[act_class](activation=activation, num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, @@ -278,7 +307,7 @@ class ExportToQuantInferNetwork: std_dev (int, float): Input data variance. Default: 127.5. Returns: - Cell, GEIR backend Infer network. + Cell, Infer network. """ __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] @@ -329,17 +358,15 @@ class ExportToQuantInferNetwork: return None # Build the `Quant` `Dequant` op. - # AscendQuant only support perlayer version. Need check here. - quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in)) + # Quant only support perlayer version. Need check here. + quant_op = inner.Quant(float(scale_a_in), float(zp_a_in)) sqrt_mode = False scale_deq = scale_a_out * scale_w if (scale_deq < 2 ** -14).all(): scale_deq = np.sqrt(scale_deq) sqrt_mode = True - dequant_op = inner.AscendDequant(sqrt_mode) + dequant_op = inner.Dequant(sqrt_mode) - # get op - op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv if isinstance(activation, _AddFakeQuantAfterSubCell): activation = activation.subcell elif hasattr(activation, "get_origin"): @@ -351,14 +378,21 @@ class ExportToQuantInferNetwork: if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): if cell_core.has_bias: bias = cell_core.bias.data.asnumpy() - elif isinstance(cell_core, quant.Conv2dBatchNormQuant): + elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant)): weight, bias = quant_utils.fold_batchnorm(weight, cell_core) # apply the quant - weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type) + weight = quant_utils.weight2int(weight, scale_w, zp_w) if bias is not None: bias = Tensor(scale_a_in * scale_w * bias, mstype.int32) scale_deq = Tensor(scale_deq, mstype.float16) + # get op + if isinstance(cell_core, quant.DenseQuant): + op_core = P.MatMul() + weight = np.transpose(weight) + else: + op_core = cell_core.conv + weight = Tensor(weight, self.data_type) block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) return block @@ -411,11 +445,15 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' file_name (str): File name of model to export. mean (int): Input data mean. Default: 127.5. std_dev (int, float): Input data variance. Default: 127.5. - file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model. - - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model. + file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'BINARY' format for exported + quantization aware model. Default: 'GEIR'. + + - GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of + Ascend model. + - BINARY: Binary format for model. An intermidiate representation format for models. """ - supported_device = ["Ascend"] - supported_formats = ['GEIR'] + supported_device = ["Ascend", "GPU"] + supported_formats = ['GEIR', 'BINARY'] mean = validator.check_type("mean", mean, (int, float)) std_dev = validator.check_type("std_dev", std_dev, (int, float)) @@ -428,10 +466,9 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' network.set_train(False) - if file_format == 'GEIR': - exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs) - deploy_net = exporter.run() - serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) + exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs) + deploy_net = exporter.run() + serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) def convert_quant_network(network, diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index 69505970fd..1e2481ceaa 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -104,7 +104,7 @@ def weight2int(data, scale, zero_point): raise ValueError("`scale` and `zero_point` should have the same shape.") if scale.shape[0] < 0: raise ValueError("`scale` and `zero_point` shape should greater than zero.") - if len(scale.shape) > 1: + if len(scale.shape) >= 1 and scale.shape[0] > 1: # for perchannel if scale.shape[0] == data.shape[0]: # `Conv2d` or `Dense` op weight @@ -176,13 +176,13 @@ def scale_zp_from_data(op, minq, maxq, data_type): def fold_batchnorm(weight, cell_quant): r""" - Fold the batchnorm in `Conv2dBatchNormQuant` to weight. + Fold the batchnorm in `Conv2dBnFoldQuant` to weight. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. Args: weight (numpy.ndarray): Weight of `cell_quant`. - cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`. + cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`. Returns: weight (numpy.ndarray): Folded weight. diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index bc74986321..e07cfa94c5 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -15,10 +15,11 @@ """Model and parameters serialization.""" import os import stat +import math +from threading import Thread, Lock import numpy as np import mindspore.nn as nn -import mindspore.context as context from mindspore import log as logger from mindspore.train.checkpoint_pb2 import Checkpoint from mindspore.train.print_pb2 import Print @@ -40,6 +41,9 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} +_ckpt_mutex = Lock() +SLICE_SIZE = 512 * 1024 * 1024 + def _special_process_par(par, new_par): """ @@ -101,7 +105,41 @@ def _update_param(param, new_param): param.set_parameter_data(type(param.data)(new_param.data)) -def save_checkpoint(parameter_list, ckpt_file_name): +def _exec_save(ckpt_file_name, data_list): + """Execute save checkpoint into file process.""" + + try: + with _ckpt_mutex: + if os.path.exists(ckpt_file_name): + os.remove(ckpt_file_name) + with open(ckpt_file_name, "ab") as f: + for name, value in data_list.items(): + data_size = value[2].nbytes + if data_size > SLICE_SIZE: + slice_count = math.ceil(data_size / SLICE_SIZE) + param_slice_list = np.array_split(value[2], slice_count) + else: + param_slice_list = [value[2]] + + for param_slice in param_slice_list: + checkpoint_list = Checkpoint() + param_value = checkpoint_list.value.add() + param_value.tag = name + param_tensor = param_value.tensor + param_tensor.dims.extend(value[0]) + param_tensor.tensor_type = value[1] + param_tensor.tensor_content = param_slice.tostring() + + f.write(checkpoint_list.SerializeToString()) + + os.chmod(ckpt_file_name, stat.S_IRUSR) + + except BaseException as e: + logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) + raise RuntimeError(e.__str__()) + + +def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): """ Saves checkpoint info to a specified file. @@ -109,37 +147,37 @@ def save_checkpoint(parameter_list, ckpt_file_name): parameter_list (list): Parameters list, each element is a dict like {"name":xx, "type":xx, "shape":xx, "data":xx}. ckpt_file_name (str): Checkpoint file name. + async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False Raises: RuntimeError: Failed to save the Checkpoint file. """ logger.info("Execute save checkpoint process.") - checkpoint_list = Checkpoint() - try: + data_list = {} + with _ckpt_mutex: for param in parameter_list: - param_value = checkpoint_list.value.add() - param_value.tag = param["name"] - param_tensor = param_value.tensor + key = param["name"] + data_list[key] = [] if isinstance(param["data"], Parameter): param["data"].init_data() - param_data = param["data"].asnumpy().reshape(-1) - param_tensor.tensor_content = param_data.tostring() - param_tensor.tensor_type = str(param["data"].dtype) - + dims = [] if param['data'].shape == (): - param_tensor.dims.append(0) + dims.append(0) else: for dim in param['data'].shape: - param_tensor.dims.append(dim) - - with open(ckpt_file_name, "wb") as f: - f.write(checkpoint_list.SerializeToString()) - os.chmod(ckpt_file_name, stat.S_IRUSR) - - except BaseException as e: - logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) - raise RuntimeError(e.__str__()) + dims.append(dim) + data_list[key].append(dims) + tensor_type = str(param["data"].dtype) + data_list[key].append(tensor_type) + data = param["data"].asnumpy().reshape(-1) + data_list[key].append(data) + + if async_save: + thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt") + thr.start() + else: + _exec_save(ckpt_file_name, data_list) logger.info("Save checkpoint process finish.") @@ -182,28 +220,37 @@ def load_checkpoint(ckpt_file_name, net=None): parameter_dict = {} try: + element_id = 0 + param_data_list = [] for element in checkpoint_list.value: data = element.tensor.tensor_content data_type = element.tensor.tensor_type np_type = tensor_to_np_type[data_type] ms_type = tensor_to_ms_type[data_type] - param_data = np.fromstring(data, np_type) - dims = element.tensor.dims - - if dims == [0]: - if 'Float' in data_type: - param_data = float(param_data[0]) - elif 'Int' in data_type: - param_data = int(param_data[0]) - parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) - elif dims == [1]: - parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) - else: - param_dim = [] - for dim in dims: - param_dim.append(dim) - param_value = param_data.reshape(param_dim) - parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) + element_data = np.frombuffer(data, np_type) + param_data_list.append(element_data) + if (element_id == len(checkpoint_list.value) - 1) or \ + (element.tag != checkpoint_list.value[element_id + 1].tag): + param_data = np.concatenate((param_data_list), axis=0) + param_data_list.clear() + dims = element.tensor.dims + + if dims == [0]: + if 'Float' in data_type: + param_data = float(param_data[0]) + elif 'Int' in data_type: + param_data = int(param_data[0]) + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) + elif dims == [1]: + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) + else: + param_dim = [] + for dim in dims: + param_dim.append(dim) + param_value = param_data.reshape(param_dim) + parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) + + element_id += 1 logger.info("Load checkpoint process finish.") @@ -211,7 +258,7 @@ def load_checkpoint(ckpt_file_name, net=None): logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) raise RuntimeError(e.__str__()) - if net: + if net is not None: load_param_into_net(net, parameter_dict) return parameter_dict @@ -248,7 +295,6 @@ def load_param_into_net(net, parameter_dict): logger.error("Failed to combine the net and the parameters.") msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param))) raise TypeError(msg) - param.init_data() _update_param(param, new_param) else: param_not_load.append(param.name) @@ -305,7 +351,7 @@ def _save_graph(network, file_name): os.chmod(file_name, stat.S_IRUSR) -def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): +def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False): """ Saves checkpoint for 'ms' backend. @@ -313,16 +359,15 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): train_network (Network): The train network for training. ckpt_file_name (str): The name of checkpoint file. integrated_save (bool): Whether to integrated save in automatic model parallel scene. + async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False. """ - + train_network.init_parameters_data() param_dict = {} for _, param in train_network.parameters_and_names(): param_dict[param.name] = param - param_list = [] for (key, value) in param_dict.items(): each_param = {"name": key} - value.init_data() if isinstance(value.data, Tensor): param_data = value.data else: @@ -336,7 +381,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): each_param["data"] = param_data param_list.append(each_param) - save_checkpoint(param_list, ckpt_file_name) + save_checkpoint(param_list, ckpt_file_name, async_save) def _get_merged_param_data(net, param_name, param_data): @@ -359,14 +404,17 @@ def _get_merged_param_data(net, param_name, param_data): dev_mat = layout[0] tensor_map = layout[1] + field_size = layout[3] from mindspore.parallel._cell_wrapper import get_allgather_cell - from mindspore.parallel._tensor import _reshape_param_data + from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight # while any dim is not equal to -1, means param is splited and needs to be merged for dim in tensor_map: if dim != -1: allgather_net = get_allgather_cell() param_data = allgather_net(param_data) + if field_size[0]: + return _reshape_param_data_with_weight(param_data, dev_mat, field_size) return _reshape_param_data(param_data, dev_mat, tensor_map) return param_data @@ -405,18 +453,17 @@ def export(net, *inputs, file_name, file_format='GEIR'): net (Cell): MindSpore network. inputs (Tensor): Inputs of the `net`. file_name (str): File name of model to export. - file_format (str): MindSpore currently supports 'GEIR', 'ONNX' 'LITE' and 'BINARY' format for exported model. + file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'BINARY' format for exported model. - GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of Ascend model. - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. - - LITE: Huawei model format for mobile. A lite model only for the MindSpore Lite - BINARY: Binary format for model. An intermidiate representation format for models. """ logger.info("exporting model file:%s format:%s.", file_name, file_format) check_input_data(*inputs, data_class=Tensor) - supported_formats = ['GEIR', 'ONNX', 'LITE', 'BINARY'] + supported_formats = ['GEIR', 'ONNX', 'BINARY'] if file_format not in supported_formats: raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') # switch network mode to infer when it is training @@ -426,27 +473,25 @@ def export(net, *inputs, file_name, file_format='GEIR'): # export model net.init_parameters_data() if file_format == 'GEIR': - _executor.compile(net, *inputs, phase='export') - _executor.export(net, file_name, file_format) + phase_name = 'export.geir' + graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) + _executor.export(file_name, graph_id) elif file_format == 'ONNX': # file_format is 'ONNX' # NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline, # do not change it to other values. - phase_name = 'export_onnx' + phase_name = 'export.onnx' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(graph_id) with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) elif file_format == 'BINARY': # file_format is 'BINARY' - phase_name = 'export_binary' + phase_name = 'export.binary' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(graph_id, 'binary_ir') with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) - elif file_format == 'LITE': # file_format is 'LITE' - context.set_context(save_ms_model=True, save_ms_model_path=file_name) - net(*inputs) # restore network training mode if is_training: net.set_train(mode=True) diff --git a/mindspore/train/summary/_summary_writer.py b/mindspore/train/summary/_summary_writer.py index 36d020819a..2c288e16c6 100644 --- a/mindspore/train/summary/_summary_writer.py +++ b/mindspore/train/summary/_summary_writer.py @@ -15,6 +15,7 @@ """Writes events to disk in a logdir.""" import os import stat +from shutil import disk_usage from ..._c_expression import EventWriter_ from ._summary_adapter import package_init_event @@ -23,8 +24,8 @@ from ._summary_adapter import package_init_event class BaseWriter: """BaseWriter to be subclass.""" - def __init__(self, filepath) -> None: - self._filepath = filepath + def __init__(self, filepath, max_file_size=None) -> None: + self._filepath, self._max_file_size = filepath, max_file_size self._writer: EventWriter_ = None def init_writer(self): @@ -42,9 +43,23 @@ class BaseWriter: self.init_writer() return self._writer - def write(self, plugin, mode, data): + def write(self, plugin, data): """Write data to file.""" - raise NotImplementedError() + if self.writer and disk_usage(self._filepath).free < len(data) * 32: + raise RuntimeError(f"The disk space may be soon exhausted by the '{self._filepath}'.") + # 8: data length + # 4: crc32 of data length + # 4: crc32 of data + metadata_length = 8 + 4 + 4 + required_length = len(data) + metadata_length + if self._max_file_size is None: + self.writer.Write(data) + elif self._max_file_size >= required_length: + self._max_file_size -= required_length + self.writer.Write(data) + else: + raise RuntimeError(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " + f"but the '{self._filepath}' requires to write {required_length} bytes.") def flush(self): """Flush the writer.""" @@ -64,16 +79,16 @@ class SummaryWriter(BaseWriter): """Write some metadata etc.""" self.writer.Write(package_init_event().SerializeToString()) - def write(self, plugin, mode, data): + def write(self, plugin, data): """Write data to file.""" if plugin in ('summary', 'graph'): - self.writer.Write(data) + super().write(plugin, data) class LineageWriter(BaseWriter): """LineageWriter for write lineage.""" - def write(self, plugin, mode, data): + def write(self, plugin, data): """Write data to file.""" if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): - self.writer.Write(data) + super().write(plugin, data) diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index d9cdfd3c8c..888c5a90a4 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -18,6 +18,8 @@ import time from collections import deque from multiprocessing import Pool, Process, Queue, cpu_count +import mindspore.log as logger + from ._lineage_adapter import serialize_to_lineage_event from ._summary_adapter import package_graph_event, package_summary_event from ._summary_writer import LineageWriter, SummaryWriter @@ -25,20 +27,18 @@ from ._summary_writer import LineageWriter, SummaryWriter def _pack_data(datadict, wall_time): """Pack data according to which plugin.""" - result = [] - summaries, step, mode = [], None, None + result, summaries, step = [], [], None for plugin, datalist in datadict.items(): for data in datalist: if plugin == 'graph': - result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()]) + result.append([plugin, package_graph_event(data.get('value')).SerializeToString()]) elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'): - result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))]) + result.append([plugin, serialize_to_lineage_event(plugin, data.get('value'))]) elif plugin in ('scalar', 'tensor', 'histogram', 'image'): summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) step = data.get('step') - mode = data.get('mode') if summaries: - result.append(['summary', mode, package_summary_event(summaries, step, wall_time).SerializeToString()]) + result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()]) return result @@ -51,49 +51,68 @@ class WriterPool(Process): filelist (str): The mapping from short name to long filename. """ - def __init__(self, base_dir, **filedict) -> None: + def __init__(self, base_dir, max_file_size, **filedict) -> None: super().__init__() self._base_dir, self._filedict = base_dir, filedict - self._queue = Queue(cpu_count() * 2) + self._queue, self._writers_ = Queue(cpu_count() * 2), None + self._max_file_size = max_file_size self.start() def run(self): - writers = self._get_writers() - with Pool(min(cpu_count(), 32)) as pool: deq = deque() while True: while deq and deq[0].ready(): - for plugin, mode, data in deq.popleft().get(): - for writer in writers: - writer.write(plugin, mode, data) + for plugin, data in deq.popleft().get(): + self._write(plugin, data) if not self._queue.empty(): action, data = self._queue.get() if action == 'WRITE': deq.append(pool.apply_async(_pack_data, (data, time.time()))) elif action == 'FLUSH': - for writer in writers: - writer.flush() + self._flush() elif action == 'END': break for result in deq: - for plugin, mode, data in result.get(): - for writer in writers: - writer.write(plugin, mode, data) + for plugin, data in result.get(): + self._write(plugin, data) - for writer in writers: - writer.close() + self._close() - def _get_writers(self): - writers = [] + @property + def _writers(self): + """Get the writers in the subprocess.""" + if self._writers_ is not None: + return self._writers_ + self._writers_ = [] for plugin, filename in self._filedict.items(): filepath = os.path.join(self._base_dir, filename) if plugin == 'summary': - writers.append(SummaryWriter(filepath)) + self._writers_.append(SummaryWriter(filepath, self._max_file_size)) elif plugin == 'lineage': - writers.append(LineageWriter(filepath)) - return writers + self._writers_.append(LineageWriter(filepath, self._max_file_size)) + return self._writers_ + + def _write(self, plugin, data): + """Write the data in the subprocess.""" + for writer in self._writers[:]: + try: + writer.write(plugin, data) + except RuntimeError as e: + logger.warning(e.args[0]) + self._writers.remove(writer) + writer.close() + + def _flush(self): + """Flush the writers in the subprocess.""" + for writer in self._writers: + writer.flush() + + def _close(self): + """Close the writers in the subprocess.""" + for writer in self._writers: + writer.close() def write(self, data) -> None: """ diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 21c8c58d3b..18cecb2914 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -75,14 +75,17 @@ class SummaryRecord: Args: log_dir (str): The log_dir is a directory location to save the summary. - queue_max_size (int): The capacity of event queue.(reserved). Default: 0. - flush_time (int): Frequency to flush the summaries to disk, the unit is second. Default: 120. + queue_max_size (int): Deprecated. The capacity of event queue.(reserved). Default: 0. + flush_time (int): Deprecated. Frequency to flush the summaries to disk, the unit is second. Default: 120. file_prefix (str): The prefix of file. Default: "events". file_suffix (str): The suffix of file. Default: "_MS". network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. + max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. \ + Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. Raises: - TypeError: If `queue_max_size` and `flush_time` is not int, or `file_prefix` and `file_suffix` is not str. + TypeError: If `max_file_size`, `queue_max_size` or `flush_time` is not int, \ + or `file_prefix` and `file_suffix` is not str. RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname. Examples: @@ -103,7 +106,8 @@ class SummaryRecord: flush_time=120, file_prefix="events", file_suffix="_MS", - network=None): + network=None, + max_file_size=None): self._closed, self._event_writer = False, None self._mode, self._data_pool = 'train', _dictlist() @@ -113,11 +117,18 @@ class SummaryRecord: self.log_path = _make_directory(log_dir) + if not isinstance(max_file_size, (int, type(None))): + raise TypeError("The 'max_file_size' should be int type.") + if not isinstance(queue_max_size, int) or not isinstance(flush_time, int): raise TypeError("`queue_max_size` and `flush_time` should be int") if not isinstance(file_prefix, str) or not isinstance(file_suffix, str): raise TypeError("`file_prefix` and `file_suffix` should be str.") + if max_file_size is not None and max_file_size < 0: + logger.warning("The 'max_file_size' should be greater than 0.") + max_file_size = None + self.queue_max_size = queue_max_size if queue_max_size < 0: # 0 is not limit @@ -142,6 +153,7 @@ class SummaryRecord: raise RuntimeError(ex) self._event_writer = WriterPool(log_dir, + max_file_size, summary=self.full_file_name, lineage=get_event_file_name('events', '_lineage')) atexit.register(self.close) @@ -152,7 +164,7 @@ class SummaryRecord: raise ValueError('SummaryRecord has been closed.') return self - def __exit__(self, extype, exvalue, traceback): + def __exit__(self, *err): """Exit the context manager.""" self.close() @@ -218,24 +230,26 @@ class SummaryRecord: if name in {item['tag'] for item in self._data_pool[plugin]}: entry = repr(f'{name}/{plugin}') logger.warning(f'{entry} has duplicate values. Only the newest one will be recorded.') - self._data_pool[plugin].append(dict(tag=name, mode=self._mode, value=np_value)) + self._data_pool[plugin].append(dict(tag=name, value=np_value)) elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'): _check_lineage_value(plugin, value) - self._data_pool[plugin].append(dict(mode=self._mode, value=value.SerializeToString())) + self._data_pool[plugin].append(dict(value=value.SerializeToString())) elif plugin == 'graph': package_graph_event(value) - self._data_pool[plugin].append(dict(mode=self._mode, value=value)) + self._data_pool[plugin].append(dict(value=value)) else: raise ValueError(f'No such plugin of {repr(plugin)}') - def record(self, step, train_network=None): + def record(self, step, train_network=None, plugin_filter=None): """ Record the summary. Args: step (int): Represents training step number. train_network (Cell): The network that called the callback. + plugin_filter (Optional[Callable[[str], bool]]): The filter function, \ + which is used to filter out plugins from being written by return False. Returns: bool, whether the record process is successful or not. @@ -266,7 +280,14 @@ class SummaryRecord: if self._mode == 'train': self._add_summary_tensor_data() - self._event_writer.write(self._consume_data_pool(step)) + if not plugin_filter: + self._event_writer.write(self._consume_data_pool(step)) + else: + filtered = {} + for plugin, datalist in self._consume_data_pool(step).items(): + if plugin_filter(plugin): + filtered[plugin] = datalist + self._event_writer.write(filtered) return True def _add_summary_tensor_data(self): diff --git a/model_zoo/README.md b/model_zoo/README.md deleted file mode 100644 index 1e392445af..0000000000 --- a/model_zoo/README.md +++ /dev/null @@ -1,325 +0,0 @@ -![](https://www.mindspore.cn/static/img/logo.a3e472c9.png) - - -# Welcome to the Model Zoo for MindSpore - -In order to facilitate developers to enjoy the benefits of MindSpore framework and Huawei chips, we will continue to add typical networks and models . If you have needs for the model zoo, you can file an issue on [gitee](https://gitee.com/mindspore/mindspore/issues) or [MindSpore](https://bbs.huaweicloud.com/forum/forum-1076-1.html), We will consider it in time. - -- SOTA models using the latest MindSpore APIs - -- The best benefits from MindSpore and Huawei chips - -- Officially maintained and supported - - - -# Table of Contents - -- [Models and Implementations](#models-and-implementations) - - [Computer Vision](#computer-vision) - - [Image Classification](#image-classification) - - [GoogleNet](#googlenet) - - [ResNet50[benchmark]](#resnet50) - - [ResNet101](#resnet101) - - [VGG16](#vgg16) - - [AlexNet](#alexnet) - - [LeNet](#lenet) - - [Object Detection and Segmentation](#object-detection-and-segmentation) - - [YoloV3](#yolov3) - - [MobileNetV2](#mobilenetv2) - - [MobileNetV3](#mobilenetv3) - - [SSD](#ssd) - - [Natural Language Processing](#natural-language-processing) - - [BERT](#bert) - - [MASS](#mass) - - [Transformer](#transformer) - - -# Announcements -| Date | News | -| ------------ | ------------------------------------------------------------ | -| May 31, 2020 | Support [MindSpore v0.3.0-alpha](https://www.mindspore.cn/news/newschildren?id=215) | - - -# Models and Implementations - -## Computer Vision - -### Image Classification - -#### [GoogleNet](#table-of-contents) -| Parameters | GoogleNet | -| -------------------------- | ------------------------------------------------------------ | -| Published Year | 2014 | -| Paper | [Going Deeper with Convolutions](https://arxiv.org/abs/1409.4842) | -| Resource | Ascend 910 | -| Features | • Mixed Precision • Multi-GPU training support with Ascend | -| MindSpore Version | 0.3.0-alpha | -| Dataset | CIFAR-10 | -| Training Parameters | epoch=125, batch_size = 128, lr=0.1 | -| Optimizer | Momentum | -| Loss Function | Softmax Cross Entropy | -| Accuracy | 1pc: 93.4%; 8pcs: 92.17% | -| Speed | 79 ms/Step | -| Loss | 0.0016 | -| Params (M) | 6.8 | -| Checkpoint for Fine tuning | 43.07M (.ckpt file) | -| Model for inference | 21.50M (.onnx file), 21.60M(.geir file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/googlenet | - -#### [ResNet50](#table-of-contents) - -| Parameters | ResNet50 | -| -------------------------- | -------- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Accuracy | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [ResNet101](#table-of-contents) - -| Parameters | ResNet101 | -| -------------------------- | --------- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Accuracy | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [VGG16](#table-of-contents) - -| Parameters | VGG16 | -| -------------------------- | ----- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Accuracy | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [AlexNet](#table-of-contents) - -| Parameters | AlexNet | -| -------------------------- | ------- | -| Published Year | 2012 | -| Paper | [ImageNet Classification with Deep Convolutional Neural Networks](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-) | -| Resource | Ascend 910 | -| Features | support with Ascend, GPU | -| MindSpore Version | 0.5.0-beta | -| Dataset | CIFAR10 | -| Training Parameters | epoch=30, batch_size=32 | -| Optimizer | Momentum | -| Loss Function | SoftmaxCrossEntropyWithLogits | -| Accuracy | 88.23% | -| Speed | 1481fps | -| Loss | 0.108 | -| Params (M) | 61.10 | -| Checkpoint for Fine tuning | 445MB(.ckpt file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/alexnet| - -#### [LeNet](#table-of-contents) - -| Parameters | LeNet | -| -------------------------- | ----- | -| Published Year | 1998 | -| Paper | [Gradient-Based Learning Applied to Document Recognition](https://ieeexplore.ieee.org/abstract/document/726791) | -| Resource | Ascend 910 | -| Features | support with Ascend, GPU, CPU | -| MindSpore Version | 0.5.0-beta | -| Dataset | MNIST | -| Training Parameters | epoch=10, batch_size=32 | -| Optimizer | Momentum | -| Loss Function | SoftmaxCrossEntropyWithLogits | -| Accuracy | 98.52% | -| Speed | 18680fps | -| Loss | 0.004 | -| Params (M) | 0.06 | -| Checkpoint for Fine tuning | 483KB(.ckpt file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/lenet| - -### Object Detection and Segmentation - -#### [YoloV3](#table-of-contents) - -| Parameters | YoLoV3 | -| -------------------------------- | ------ | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Mean Average Precision (mAP@0.5) | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [MobileNetV2](#table-of-contents) - -| Parameters | MobileNetV2 | -| -------------------------------- | ----------- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Mean Average Precision (mAP@0.5) | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [MobileNetV3](#table-of-contents) - -| Parameters | MobileNetV3 | -| -------------------------------- | ----------- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Mean Average Precision (mAP@0.5) | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [SSD](#table-of-contents) - -| Parameters | SSD | -| -------------------------------- | ---- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Mean Average Precision (mAP@0.5) | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -## Natural Language Processing - -#### [BERT](#table-of-contents) - -| Parameters | BERT | -| -------------------------- | ---- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| GLUE Score | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [MASS](#table-of-contents) - -| Parameters | MASS | -| -------------------------- | ---- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| ROUGE Score | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | - -#### [Transformer](#table-of-contents) - -| Parameters | Transformer | -| -------------------------- | -------------------------------------------------------------- | -| Published Year | 2017 | -| Paper | [Attention Is All You Need ](https://arxiv.org/abs/1706.03762) | -| Resource | Ascend 910 | -| Features | • Multi-GPU training support with Ascend | -| MindSpore Version | 0.5.0-beta | -| Dataset | WMT Englis-German | -| Training Parameters | epoch=52, batch_size=96 | -| Optimizer | Adam | -| Loss Function | Softmax Cross Entropy | -| BLEU Score | 28.7 | -| Speed | 410ms/step (8pcs) | -| Loss | 2.8 | -| Params (M) | 213.7 | -| Checkpoint for inference | 2.4G (.ckpt file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/Transformer | - -#### License - -[Apache License 2.0](https://github.com/mindspore-ai/mindspore/blob/master/LICENSE) diff --git a/model_zoo/Transformer/README.md b/model_zoo/Transformer/README.md deleted file mode 100644 index 7ba0c8eb3d..0000000000 --- a/model_zoo/Transformer/README.md +++ /dev/null @@ -1,176 +0,0 @@ -# Transformer Example -## Description -This example implements training and evaluation of Transformer Model, which is introduced in the following paper: -- Ashish Vaswani, Noam Shazeer, Niki Parmar, JakobUszkoreit, Llion Jones, Aidan N Gomez, Ł ukaszKaiser, and Illia Polosukhin. 2017. Attention is all you need. In NIPS 2017, pages 5998–6008. - -## Requirements -- Install [MindSpore](https://www.mindspore.cn/install/en). -- Download and preprocess the WMT English-German dataset for training and evaluation. - -> Notes:If you are running an evaluation task, prepare the corresponding checkpoint file. - -## Example structure - -```shell -. -└─Transformer - ├─README.md - ├─scripts - ├─process_output.sh - ├─replace-quote.perl - ├─run_distribute_train.sh - └─run_standalone_train.sh - ├─src - ├─__init__.py - ├─beam_search.py - ├─config.py - ├─dataset.py - ├─eval_config.py - ├─lr_schedule.py - ├─process_output.py - ├─tokenization.py - ├─transformer_for_train.py - ├─transformer_model.py - └─weight_init.py - ├─create_data.py - ├─eval.py - └─train.py -``` - ---- - -## Prepare the dataset -- You may use this [shell script](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh) to download and preprocess WMT English-German dataset. Assuming you get the following files: - - train.tok.clean.bpe.32000.en - - train.tok.clean.bpe.32000.de - - vocab.bpe.32000 - - newstest2014.tok.bpe.32000.en - - newstest2014.tok.bpe.32000.de - - newstest2014.tok.de - -- Convert the original data to mindrecord for training: - - ``` bash - paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all - python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128 - ``` -- Convert the original data to mindrecord for evaluation: - - ``` bash - paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all - python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True - ``` - -## Running the example - -### Training -- Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#mindspore) for more information about dataset. - -- Run `run_standalone_train.sh` for non-distributed training of Transformer model. - - ``` bash - sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH - ``` -- Run `run_distribute_train.sh` for distributed training of Transformer model. - - ``` bash - sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_PATH MINDSPORE_HCCL_CONFIG_PATH - ``` - -### Evaluation -- Set options in `eval_config.py`. Make sure the 'data_file', 'model_file' and 'output_file' are set to your own path. - -- Run `eval.py` for evaluation of Transformer model. - - ```bash - python eval.py - ``` - -- Run `process_output.sh` to process the output token ids to get the real translation results. - - ```bash - sh scripts/process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE - ``` - You will get two files, REF_DATA.forbleu and EVAL_OUTPUT.forbleu, for BLEU score calculation. - -- Calculate BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run following command to get the BLEU score. - - ```bash - perl multi-bleu.perl REF_DATA.forbleu < EVAL_OUTPUT.forbleu - ``` - ---- - -## Usage - -### Training -``` -usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] - [--enable_save_ckpt ENABLE_SAVE_CKPT] - [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] - [--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N] - [--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH] - [--data_path DATA_PATH] - -options: - --distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false" - --epoch_size epoch size: N, default is 52 - --device_num number of used devices: N, default is 1 - --device_id device id: N, default is 0 - --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" - --enable_lossscale enable lossscale: "true" | "false", default is "true" - --do_shuffle enable shuffle: "true" | "false", default is "true" - --enable_data_sink enable data sink: "true" | "false", default is "false" - --checkpoint_path path to load checkpoint files: PATH, default is "" - --save_checkpoint_steps steps for saving checkpoint files: N, default is 2500 - --save_checkpoint_num number for saving checkpoint files: N, default is 30 - --save_checkpoint_path path to save checkpoint files: PATH, default is "./checkpoint/" - --data_path path to dataset file: PATH, default is "" -``` - -## Options and Parameters -It contains of parameters of Transformer model and options for training and evaluation, which is set in file `config.py` and `evaluation_config.py` respectively. -### Options: -``` -config.py: - transformer_network version of Transformer model: base | large, default is large - init_loss_scale_value initial value of loss scale: N, default is 2^10 - scale_factor factor used to update loss scale: N, default is 2 - scale_window steps for once updatation of loss scale: N, default is 2000 - optimizer optimizer used in the network: Adam, default is "Adam" - -eval_config.py: - transformer_network version of Transformer model: base | large, default is large - data_file data file: PATH - model_file checkpoint file to be loaded: PATH - output_file output file of evaluation: PATH -``` - -### Parameters: -``` -Parameters for dataset and network (Training/Evaluation): - batch_size batch size of input dataset: N, default is 96 - seq_length length of input sequence: N, default is 128 - vocab_size size of each embedding vector: N, default is 36560 - hidden_size size of Transformer encoder layers: N, default is 1024 - num_hidden_layers number of hidden layers: N, default is 6 - num_attention_heads number of attention heads: N, default is 16 - intermediate_size size of intermediate layer: N, default is 4096 - hidden_act activation function used: ACTIVATION, default is "relu" - hidden_dropout_prob dropout probability for TransformerOutput: Q, default is 0.3 - attention_probs_dropout_prob dropout probability for TransformerAttention: Q, default is 0.3 - max_position_embeddings maximum length of sequences: N, default is 128 - initializer_range initialization value of TruncatedNormal: Q, default is 0.02 - label_smoothing label smoothing setting: Q, default is 0.1 - input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True - beam_width beam width setting: N, default is 4 - max_decode_length max decode length in evaluation: N, default is 80 - length_penalty_weight normalize scores of translations according to their length: Q, default is 1.0 - compute_type compute type in Transformer: mstype.float16 | mstype.float32, default is mstype.float16 - -Parameters for learning rate: - learning_rate value of learning rate: Q - warmup_steps steps of the learning rate warm up: N - start_decay_step step of the learning rate to decay: N - min_lr minimal learning rate: Q -``` \ No newline at end of file diff --git a/model_zoo/Transformer/scripts/run_distribute_train.sh b/model_zoo/Transformer/scripts/run_distribute_train.sh deleted file mode 100644 index 772e690dc2..0000000000 --- a/model_zoo/Transformer/scripts/run_distribute_train.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH MINDSPORE_HCCL_CONFIG_PATH" -echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json" -echo "It is better to use absolute path." -echo "==============================================================================================================" - -rm -rf run_distribute_train -mkdir run_distribute_train -cd run_distribute_train || exit - -EPOCH_SIZE=$2 -DATA_PATH=$3 - -export MINDSPORE_HCCL_CONFIG_PATH=$4 -export RANK_TABLE_FILE=$4 -export RANK_SIZE=$1 -export HCCL_FLAG=1 -export DEPLOY_MODE=0 - -for((i=0;i env.log - python train.py \ - --distribute="true" \ - --epoch_size=$EPOCH_SIZE \ - --device_id=$DEVICE_ID \ - --device_num=$RANK_SIZE \ - --enable_save_ckpt="true" \ - --enable_lossscale="true" \ - --do_shuffle="true" \ - --enable_data_sink="false" \ - --checkpoint_path="" \ - --save_checkpoint_steps=2500 \ - --save_checkpoint_num=30 \ - --data_path=$DATA_PATH > log.txt 2>&1 & - cd ../ -done -cd .. \ No newline at end of file diff --git a/model_zoo/Transformer/src/dataset.py b/model_zoo/Transformer/src/dataset.py deleted file mode 100644 index 5b006046a5..0000000000 --- a/model_zoo/Transformer/src/dataset.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Data operations, will be used in train.py.""" - -import mindspore.common.dtype as mstype -import mindspore.dataset.engine.datasets as de -import mindspore.dataset.transforms.c_transforms as deC -from mindspore import log as logger -from .config import transformer_net_cfg - -def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true", - dataset_path=None): - """create dataset""" - repeat_count = epoch_count - ds = de.MindDataset(dataset_path, - columns_list=["source_eos_ids", "source_eos_mask", - "target_sos_ids", "target_sos_mask", - "target_eos_ids", "target_eos_mask"], - shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id) - - type_cast_op = deC.TypeCast(mstype.int32) - ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op) - ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op) - ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op) - ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op) - ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op) - ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op) - - # apply batch operations - ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) - ds = ds.repeat(repeat_count) - - ds.channel_name = 'transformer' - logger.info("data size: {}".format(ds.get_dataset_size())) - logger.info("repeatcount: {}".format(ds.get_repeat_count())) - return ds, repeat_count diff --git a/model_zoo/Transformer/src/transformer_model.py b/model_zoo/Transformer/src/transformer_model.py deleted file mode 100644 index 409f8965eb..0000000000 --- a/model_zoo/Transformer/src/transformer_model.py +++ /dev/null @@ -1,1158 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Transformer model.""" - -import math -import copy -import numpy as np -import mindspore.common.dtype as mstype -import mindspore.nn as nn -import mindspore.ops.functional as F -from mindspore.ops import operations as P -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter -from .beam_search import BeamSearchDecoder, TileBeam -from .weight_init import normal_weight, weight_variable - -class TransformerConfig: - """ - Configuration for `Transformer`. - - Args: - batch_size (int): Batch size of input dataset. - seq_length (int): Length of input sequence. Default: 128. - vocab_size (int): The shape of each embedding vector. Default: 36560. - hidden_size (int): Size of the layers. Default: 1024. - num_hidden_layers (int): Number of hidden layers in the Transformer encoder/decoder - cell. Default: 6. - num_attention_heads (int): Number of attention heads in the Transformer - encoder/decoder cell. Default: 16. - intermediate_size (int): Size of intermediate layer in the Transformer - encoder/decoder cell. Default: 4096. - hidden_act (str): Activation function used in the Transformer encoder/decoder - cell. Default: "relu". - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.3. - attention_probs_dropout_prob (float): The dropout probability for - MultiheadAttention. Default: 0.3. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 128. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - label_smoothing (float): label smoothing setting. Default: 0.1 - input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from - dataset. Default: True. - beam_width (int): beam width setting. Default: 4 - max_decode_length (int): max decode length in evaluation. Default: 80 - length_penalty_weight (float): normalize scores of translations according to their length. Default: 1.0 - dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. - compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32. - """ - def __init__(self, - batch_size, - seq_length=128, - vocab_size=36560, - hidden_size=1024, - num_hidden_layers=6, - num_attention_heads=16, - intermediate_size=4096, - hidden_act="relu", - hidden_dropout_prob=0.3, - attention_probs_dropout_prob=0.3, - max_position_embeddings=128, - initializer_range=0.02, - label_smoothing=0.1, - input_mask_from_dataset=True, - beam_width=4, - max_decode_length=80, - length_penalty_weight=1.0, - dtype=mstype.float32, - compute_type=mstype.float32): - self.batch_size = batch_size - self.seq_length = seq_length - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.initializer_range = initializer_range - self.label_smoothing = label_smoothing - self.input_mask_from_dataset = input_mask_from_dataset - self.beam_width = beam_width - self.max_decode_length = max_decode_length - self.length_penalty_weight = length_penalty_weight - self.dtype = dtype - self.compute_type = compute_type - - -class EmbeddingLookup(nn.Cell): - """ - A embeddings lookup table with a fixed dictionary and size. - - Args: - vocab_size (int): Size of the dictionary of embeddings. - embedding_size (int): The size of each embedding vector. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - """ - def __init__(self, - vocab_size, - embedding_size, - use_one_hot_embeddings=False, - initializer_range=0.02): - super(EmbeddingLookup, self).__init__() - self.vocab_size = vocab_size - self.embedding_size = embedding_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size), - name='embedding_table') - self.expand = P.ExpandDims() - self.shape_flat = (-1,) - self.gather = P.GatherV2() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = P.Shape() - - def construct(self, input_ids): - input_shape = self.shape(input_ids) - - flat_ids = self.reshape(input_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) - output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) - else: - output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) - - out_shape = input_shape + (self.embedding_size,) - output = self.reshape(output_for_reshape, out_shape) - return output, self.embedding_table - - -def position_encoding(length, - depth, - min_timescale=1, - max_timescale=1e4): - """ - Create Tensor of sinusoids of different frequencies. - - Args: - length (int): Length of the Tensor to create, i.e. Number of steps. - depth (int): Hidden size. - min_timescale (float): Default: 1. - max_timescale (float): Default: 10000. - - Returns: - Tensor of shape (length, depth) - """ - depth = depth // 2 - positions = np.arange(length, dtype=np.float32) - log_timescale_increment = (np.log(max_timescale / min_timescale) / (depth - 1)) - inv_timescales = min_timescale * np.exp(np.arange(depth, dtype=np.float32) * -log_timescale_increment) - scaled_time = np.expand_dims(positions, 1) * np.expand_dims(inv_timescales, 0) - x = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) - return x - - -class EmbeddingPostprocessor(nn.Cell): - """ - Postprocessors apply positional embeddings to word embeddings. - - Args: - embedding_size (int): The size of each embedding vector. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 128. - dropout_prob (float): The dropout probability. Default: 0.1. - """ - def __init__(self, - embedding_size, - use_one_hot_embeddings=False, - initializer_range=0.02, - max_position_embeddings=128, - dropout_prob=0.1): - super(EmbeddingPostprocessor, self).__init__() - self.scores_mul = Tensor([math.sqrt(float(embedding_size))], dtype=mstype.float32) - self.multiply = P.Mul() - self.add = P.TensorAdd() - self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32) - self.use_dropout = dropout_prob > 0 - self.expand_dims = P.ExpandDims() - self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size), - mstype.float32) - self.shape = P.Shape() - - def construct(self, word_embeddings): - input_shape = self.shape(word_embeddings) - input_len = input_shape[1] - - output = self.multiply(word_embeddings, self.scores_mul) - - # add position embeddings - position_embeddings = self.position_embedding_table[0:input_len:1, ::] - position_embeddings = self.expand_dims(position_embeddings, 0) - output = self.add(output, position_embeddings) - - if self.use_dropout: - output = self.dropout(output) - return output - - -class CastWrapper(nn.Cell): - """ - Cast wrapper. - """ - def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): - super(CastWrapper, self).__init__() - self.cast = P.Cast() - self.dst_type = dst_type - - def construct(self, x): - return self.cast(x, self.dst_type) - - -class LayerPreprocess(nn.Cell): - """ - preprocess input of each layer. - """ - def __init__(self, - in_channels=None): - super(LayerPreprocess, self).__init__() - self.layernorm = nn.LayerNorm((in_channels,)) - self.cast = P.Cast() - self.get_dtype = P.DType() - - def construct(self, input_tensor): - output = self.cast(input_tensor, mstype.float32) - output = self.layernorm(output) - output = self.cast(output, self.get_dtype(input_tensor)) - return output - - -class LayerPostprocess(nn.Cell): - """ - postprocess ouput of each layer. - """ - def __init__(self, - dropout_prob=0.1): - super(LayerPostprocess, self).__init__() - self.add = P.TensorAdd() - self.dropout = nn.Dropout(1 - dropout_prob) - self.use_dropout = dropout_prob > 0 - - def construct(self, hidden_tensor, input_tensor): - output = hidden_tensor - if self.use_dropout: - output = self.dropout(output) - output = self.add(output, input_tensor) - return output - - -class MultiheadAttention(nn.Cell): - """ - Apply multi-headed attention from "from_tensor" to "to_tensor". - - Args: - batch_size (int): Batch size of input datasets. - from_tensor_width (int): Size of last dim of from_tensor. - to_tensor_width (int): Size of last dim of to_tensor. - from_seq_length (int): Length of from_tensor sequence. - to_seq_length (int): Length of to_tensor sequence. - num_attention_heads (int): Number of attention heads. Default: 1. - size_per_head (int): Size of each attention head. Default: 512. - query_act (str): Activation function for the query transform. Default: None. - key_act (str): Activation function for the key transform. Default: None. - value_act (str): Activation function for the value transform. Default: None. - has_attention_mask (bool): Specifies whether to use attention mask. Default: False. - attention_probs_dropout_prob (float): The dropout probability for - MultiheadAttention. Default: 0.0. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d - tensor. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - from_tensor_width, - to_tensor_width, - out_tensor_width, - from_seq_length, - to_seq_length, - num_attention_heads=1, - size_per_head=512, - query_act=None, - key_act=None, - value_act=None, - out_act=None, - has_attention_mask=True, - attention_probs_dropout_prob=0.0, - use_one_hot_embeddings=False, - initializer_range=0.02, - do_return_2d_tensor=True, - compute_type=mstype.float32): - super(MultiheadAttention, self).__init__() - self.batch_size = batch_size - self.from_seq_length = from_seq_length - self.to_seq_length = to_seq_length - self.num_attention_heads = num_attention_heads - self.size_per_head = size_per_head - self.has_attention_mask = has_attention_mask - assert has_attention_mask - - self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) - self.reshape = P.Reshape() - self.shape_from_2d = (-1, from_tensor_width) - self.shape_to_2d = (-1, to_tensor_width) - units = num_attention_heads * size_per_head - self.query_layer = nn.Dense(from_tensor_width, - units, - activation=query_act, - has_bias=False, - weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type) - self.key_layer = nn.Dense(to_tensor_width, - units, - activation=key_act, - has_bias=False, - weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type) - self.value_layer = nn.Dense(to_tensor_width, - units, - activation=value_act, - has_bias=False, - weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type) - self.out_layer = nn.Dense(units, - out_tensor_width, - activation=out_act, - has_bias=False, - weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type) - - self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) - self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head) - - self.matmul_trans_b = P.BatchMatMul(transpose_b=True) - self.multiply = P.Mul() - self.transpose = P.Transpose() - self.trans_shape = (0, 2, 1, 3) - self.trans_shape_relative = (2, 0, 1, 3) - self.trans_shape_position = (1, 2, 0, 3) - self.multiply_data = Tensor([-10000.0,], dtype=compute_type) - self.batch_num = batch_size * num_attention_heads - self.matmul = P.BatchMatMul() - - self.softmax = nn.Softmax() - self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) - self.use_dropout = attention_probs_dropout_prob > 0 - - if self.has_attention_mask: - self.expand_dims = P.ExpandDims() - self.sub = P.Sub() - self.add = P.TensorAdd() - self.cast = P.Cast() - self.get_dtype = P.DType() - if do_return_2d_tensor: - self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) - if from_seq_length == -1: - self.shape_return = (-1, num_attention_heads * size_per_head) - else: - self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) - - self.cast_compute_type = CastWrapper(dst_type=compute_type) - self.softmax_cast = P.Cast() - - def construct(self, from_tensor, to_tensor, attention_mask=None): - # reshape 2d/3d input tensors to 2d - from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) - to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) - query_out = self.query_layer(from_tensor_2d) - key_out = self.key_layer(to_tensor_2d) - value_out = self.value_layer(to_tensor_2d) - - query_layer = self.reshape(query_out, self.shape_from) - query_layer = self.transpose(query_layer, self.trans_shape) - key_layer = self.reshape(key_out, self.shape_to) - key_layer = self.transpose(key_layer, self.trans_shape) - - attention_scores = self.matmul_trans_b(query_layer, key_layer) - attention_scores = self.multiply(attention_scores, self.scores_mul) - - if self.has_attention_mask: - attention_mask = self.expand_dims(attention_mask, 1) - multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), - self.cast(attention_mask, self.get_dtype(attention_scores))) - adder = self.multiply(multiply_out, self.multiply_data) - attention_scores = self.add(adder, attention_scores) - - attention_scores = self.softmax_cast(attention_scores, mstype.float32) - attention_probs = self.softmax(attention_scores) - attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key_layer)) - if self.use_dropout: - attention_probs = self.dropout(attention_probs) - - value_layer = self.reshape(value_out, self.shape_to) - value_layer = self.transpose(value_layer, self.trans_shape) - context_layer = self.matmul(attention_probs, value_layer) - - context_layer = self.transpose(context_layer, self.trans_shape) - context_layer = self.reshape(context_layer, self.shape_return) - context_layer = self.out_layer(context_layer) - return context_layer - - -class SelfAttention(nn.Cell): - """ - Apply self-attention. - - Args: - batch_size (int): Batch size of input dataset. - from_seq_length (int): Length of query sequence. - to_seq_length (int): Length of memory sequence. - hidden_size (int): Size of attention layers. - num_attention_heads (int): Number of attention heads. Default: 16. - attention_probs_dropout_prob (float): The dropout probability for - SelfAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. - has_attention_mask (bool): Specifies whether has attention mask. Default: True. - is_encdec_att (bool): Specifies whether query sequence and memory sequence are different. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - from_seq_length, - to_seq_length, - hidden_size, - num_attention_heads=16, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - has_attention_mask=True, - is_encdec_att=False, - compute_type=mstype.float32): - super(SelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError("The hidden size (%d) is not a multiple of the number " - "of attention heads (%d)" % (hidden_size, num_attention_heads)) - self.size_per_head = int(hidden_size / num_attention_heads) - self.is_encdec_att = is_encdec_att - - self.attention = MultiheadAttention( - batch_size=batch_size, - from_tensor_width=hidden_size, - to_tensor_width=hidden_size, - out_tensor_width=hidden_size, - from_seq_length=from_seq_length, - to_seq_length=to_seq_length, - num_attention_heads=num_attention_heads, - size_per_head=self.size_per_head, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - has_attention_mask=has_attention_mask, - do_return_2d_tensor=True, - compute_type=compute_type) - - self.preprocess = LayerPreprocess(in_channels=hidden_size) - self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob) - - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - def construct(self, input_tensor, memory_tensor, attention_mask): - input_tensor = self.reshape(input_tensor, self.shape) - memory_tensor = self.reshape(memory_tensor, self.shape) - - output = self.preprocess(input_tensor) - - if not self.is_encdec_att: - memory_tensor = output - - attention_output = self.attention(output, memory_tensor, attention_mask) - output = self.postprocess(attention_output, input_tensor) - return output - - -class FeedForward(nn.Cell): - """ - Apply two-layer feed forward - - Args: - in_channels (int): Size of the input layer. - hidden_size (int): Size of the hidden layer. - out_channels (int): Size of the output layers. - hidden_act (str): name of the activation function. Default: relu - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. - compute_type (:class:`mindspore.dtype`): Compute type in FeedForward. Default: mstype.float32. - """ - def __init__(self, - in_channels, - hidden_size, - out_channels, - hidden_act="relu", - initializer_range=0.02, - hidden_dropout_prob=0.1, - compute_type=mstype.float32): - super(FeedForward, self).__init__() - - self.conv1 = nn.Dense(in_channels, - hidden_size, - activation=hidden_act, - weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type) - self.conv2 = nn.Dense(hidden_size, - out_channels, - weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type) - - self.preprocess = LayerPreprocess(in_channels=in_channels) - self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob) - - self.reshape = P.Reshape() - self.shape = (-1, in_channels) - self.dropout = nn.Dropout(1 - hidden_dropout_prob) - self.use_dropout = hidden_dropout_prob > 0 - - def construct(self, input_tensor): - input_tensor = self.reshape(input_tensor, self.shape) - output = self.preprocess(input_tensor) - output = self.conv1(output) - if self.use_dropout: - output = self.dropout(output) - output = self.conv2(output) - output = self.postprocess(output, input_tensor) - return output - - -class EncoderCell(nn.Cell): - """ - Encoder cells used in Transformer. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the encoder layers. Default: 1024. - seq_length (int): Length of input sequence. Default: 128. - num_attention_heads (int): Number of attention heads. Default: 16. - intermediate_size (int): Size of intermediate layer. Default: 4096. - attention_probs_dropout_prob (float): The dropout probability for - SelfAttention. Default: 0.02. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.1. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. - hidden_act (str): Activation function. Default: "relu". - compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - hidden_size=1024, - seq_length=128, - num_attention_heads=16, - intermediate_size=4096, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - hidden_act="relu", - compute_type=mstype.float32): - super(EncoderCell, self).__init__() - self.attention = SelfAttention( - batch_size=batch_size, - hidden_size=hidden_size, - from_seq_length=seq_length, - to_seq_length=seq_length, - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - is_encdec_att=False, - compute_type=compute_type) - self.feedforward = FeedForward( - in_channels=hidden_size, - hidden_size=intermediate_size, - out_channels=hidden_size, - hidden_act=hidden_act, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - compute_type=compute_type) - - def construct(self, hidden_states, attention_mask): - # self-attention with ln, res - attention_output = self.attention(hidden_states, hidden_states, attention_mask) - # feed forward with ln, res - output = self.feedforward(attention_output) - return output - - -class TransformerEncoder(nn.Cell): - """ - Multi-layer transformer encoder. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the encoder layers. - seq_length (int): Length of input sequence. - num_hidden_layers (int): Number of hidden layers in encoder cells. - num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. - intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. - attention_probs_dropout_prob (float): The dropout probability for - SelfAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.. - hidden_act (str): Activation function used in the encoder cells. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. - """ - def __init__(self, - batch_size, - hidden_size, - seq_length, - num_hidden_layers, - num_attention_heads=16, - intermediate_size=4096, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - hidden_act="relu", - compute_type=mstype.float32): - super(TransformerEncoder, self).__init__() - self.num_hidden_layers = num_hidden_layers - - layers = [] - for _ in range(num_hidden_layers): - layer = EncoderCell(batch_size=batch_size, - hidden_size=hidden_size, - seq_length=seq_length, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - hidden_act=hidden_act, - compute_type=compute_type) - layers.append(layer) - self.layers = nn.CellList(layers) - - self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) - - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - self.out_shape = (batch_size, seq_length, hidden_size) - - def construct(self, input_tensor, attention_mask): - prev_output = self.reshape(input_tensor, self.shape) - - for layer_module in self.layers: - layer_output = layer_module(prev_output, attention_mask) - prev_output = layer_output - - prev_output = self.layer_preprocess(prev_output) - output = self.reshape(prev_output, self.out_shape) - return output - - -class DecoderCell(nn.Cell): - """ - decoder cells used in Transformer. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the Transformer decoder layers. Default: 1024. - seq_length (int): Length of input sequence. Default: 128. - enc_seq_length (int): Length of source sentences. Default:128 - num_attention_heads (int): Number of attention heads. Default: 12. - intermediate_size (int): Size of intermediate layer. Default: 4096. - attention_probs_dropout_prob (float): The dropout probability for - SelfAttention. Default: 0.02. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. - hidden_act (str): Activation function. Default: "relu". - compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - hidden_size=1024, - seq_length=128, - enc_seq_length=128, - num_attention_heads=12, - intermediate_size=4096, - attention_probs_dropout_prob=0.02, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - hidden_act="relu", - compute_type=mstype.float32): - super(DecoderCell, self).__init__() - self.self_attention = SelfAttention( - batch_size=batch_size, - hidden_size=hidden_size, - from_seq_length=seq_length, - to_seq_length=seq_length, - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - is_encdec_att=False, - hidden_dropout_prob=hidden_dropout_prob, - compute_type=compute_type) - self.cross_attention = SelfAttention( - batch_size=batch_size, - hidden_size=hidden_size, - from_seq_length=seq_length, - to_seq_length=enc_seq_length, - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - is_encdec_att=True, - hidden_dropout_prob=hidden_dropout_prob, - compute_type=compute_type) - self.feedforward = FeedForward( - in_channels=hidden_size, - hidden_size=intermediate_size, - out_channels=hidden_size, - hidden_act=hidden_act, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - compute_type=compute_type) - - def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask): - # self-attention with ln, res - attention_output = self.self_attention(hidden_states, hidden_states, attention_mask) - # cross-attention with ln, res - attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask) - # feed forward with ln, res - output = self.feedforward(attention_output) - return output - - -class TransformerDecoder(nn.Cell): - """ - Multi-layer transformer decoder. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the encoder layers. - seq_length (int): Length of input sequence. - enc_seq_length (int): Length of source sentences. - num_hidden_layers (int): Number of hidden layers in encoder cells. - num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. - intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. - attention_probs_dropout_prob (float): The dropout probability for - SelfAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. - hidden_act (str): Activation function used in the encoder cells. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. - """ - def __init__(self, - batch_size, - hidden_size, - seq_length, - enc_seq_length, - num_hidden_layers, - num_attention_heads=16, - intermediate_size=4096, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - hidden_act="relu", - compute_type=mstype.float32): - super(TransformerDecoder, self).__init__() - self.num_hidden_layers = num_hidden_layers - - layers = [] - for _ in range(num_hidden_layers): - layer = DecoderCell(batch_size=batch_size, - hidden_size=hidden_size, - seq_length=seq_length, - enc_seq_length=enc_seq_length, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - hidden_act=hidden_act, - compute_type=compute_type) - layers.append(layer) - self.layers = nn.CellList(layers) - - self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) - - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - self.out_shape = (batch_size, seq_length, hidden_size) - - def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask): - prev_output = self.reshape(input_tensor, self.shape) - - for layer_module in self.layers: - layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask) - prev_output = layer_output - - prev_output = self.layer_preprocess(prev_output) - output = self.reshape(prev_output, self.out_shape) - return output - - -class CreateAttentionMaskFromInputMask(nn.Cell): - """ - Create attention mask according to input mask. - - Args: - config (:class:`TransformerConfig`): Configuration for Transformer. - """ - def __init__(self): - super(CreateAttentionMaskFromInputMask, self).__init__() - self.cast = P.Cast() - self.reshape = P.Reshape() - self.shape = P.Shape() - self.batch_matmul = P.BatchMatMul() - - def construct(self, input_mask): - input_shape = self.shape(input_mask) - shape_right = (input_shape[0], 1, input_shape[1]) - shape_left = input_shape + (1,) - - input_mask = self.cast(input_mask, mstype.float32) - mask_left = self.reshape(input_mask, shape_left) - mask_right = self.reshape(input_mask, shape_right) - attention_mask = self.batch_matmul(mask_left, mask_right) - - return attention_mask - - -class PredLogProbs(nn.Cell): - """ - Get log probs. - - Args: - batch_size (int): Batch size. - seq_length (int): Length of input sequence. - width (int): Hidden size. - compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. - dtype (:class:`mindspore.dtype`): Compute type to compute log_softmax. Default: mstype.float32. - """ - def __init__(self, - batch_size, - seq_length, - width, - compute_type=mstype.float32, - dtype=mstype.float32): - super(PredLogProbs, self).__init__() - self.batch_size = batch_size - self.seq_length = seq_length - self.width = width - self.compute_type = compute_type - self.dtype = dtype - - self.reshape = P.Reshape() - self.matmul = P.MatMul(transpose_b=True) - self.log_softmax = nn.LogSoftmax(axis=-1) - self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width) - self.cast = P.Cast() - - def construct(self, - input_tensor, - output_weights): - input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) - input_tensor = self.cast(input_tensor, self.compute_type) - output_weights = self.cast(output_weights, self.compute_type) - - logits = self.matmul(input_tensor, output_weights) - logits = self.cast(logits, self.dtype) - - log_probs = self.log_softmax(logits) - return log_probs - - -class TransformerDecoderStep(nn.Cell): - """ - Multi-layer transformer decoder step. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the encoder layers. - max_decode_length (int): Max decode length. - enc_seq_length (int): Length of source sentences. - num_hidden_layers (int): Number of hidden layers in encoder cells. - num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. - intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. - attention_probs_dropout_prob (float): The dropout probability for - SelfAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. - hidden_act (str): Activation function used in the encoder cells. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. - embedding_lookup (:class:`EmbeddingLookup`): Embedding lookup module. - embedding_processor (:class:`EmbeddingPostprocessor`) Embedding postprocessor module. - projection (:class:`PredLogProbs`): PredLogProbs module - """ - def __init__(self, - batch_size, - hidden_size, - enc_seq_length, - max_decode_length, - num_hidden_layers, - num_attention_heads=16, - intermediate_size=4096, - attention_probs_dropout_prob=0.3, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.3, - hidden_act="relu", - compute_type=mstype.float32, - embedding_lookup=None, - embedding_processor=None, - projection=None): - super(TransformerDecoderStep, self).__init__(auto_prefix=False) - self.num_hidden_layers = num_hidden_layers - - self.tfm_embedding_lookup = embedding_lookup - self.tfm_embedding_processor = embedding_processor - self.projection = projection - - self.tfm_decoder = TransformerDecoder( - batch_size=batch_size, - hidden_size=hidden_size, - seq_length=-1, # -1 means length is not fixed - enc_seq_length=enc_seq_length, - num_attention_heads=num_attention_heads, - num_hidden_layers=num_hidden_layers, - intermediate_size=intermediate_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - hidden_act=hidden_act, - compute_type=compute_type) - - self.ones_like = P.OnesLike() - self.shape = P.Shape() - - self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask() - self.expand = P.ExpandDims() - self.multiply = P.Mul() - - ones = np.ones(shape=(max_decode_length, max_decode_length)) - self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) - - self.cast_compute_type = CastWrapper(dst_type=compute_type) - - def construct(self, input_ids, enc_states, enc_attention_mask): - # input_ids: [batch_size * beam_width] - # process embedding - input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids) - input_embedding = self.tfm_embedding_processor(input_embedding) - input_embedding = self.cast_compute_type(input_embedding) - - input_shape = self.shape(input_ids) - input_len = input_shape[1] - future_mask = self.future_mask[0:input_len:1, 0:input_len:1] - - input_mask = self.ones_like(input_ids) - input_mask = self._create_attention_mask_from_input_mask(input_mask) - input_mask = self.multiply(input_mask, self.expand(future_mask, 0)) - input_mask = self.cast_compute_type(input_mask) - - enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::] - - # call TransformerDecoder - decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask) - - # take the last step - decoder_output = decoder_output[::, input_len-1:input_len:1, ::] - - # projection and log_prob - log_probs = self.projection(decoder_output, embedding_tables) - - return log_probs - - -class TransformerModel(nn.Cell): - """ - Transformer with encoder and decoder. - - Args: - config (Class): Configuration for Transformer. - is_training (bool): True for training mode. False for eval mode. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - """ - def __init__(self, - config, - is_training, - use_one_hot_embeddings=False): - super(TransformerModel, self).__init__() - config = copy.deepcopy(config) - self.is_training = is_training - if not is_training: - config.hidden_dropout_prob = 0.0 - config.attention_probs_dropout_prob = 0.0 - - self.input_mask_from_dataset = config.input_mask_from_dataset - self.batch_size = config.batch_size - self.seq_length = config.seq_length - self.hidden_size = config.hidden_size - self.num_hidden_layers = config.num_hidden_layers - self.embedding_size = config.hidden_size - - self.last_idx = self.num_hidden_layers - 1 - - self.tfm_embedding_lookup = EmbeddingLookup( - vocab_size=config.vocab_size, - embedding_size=self.embedding_size, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range) - self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor( - embedding_size=self.embedding_size, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=0.02, - max_position_embeddings=config.max_position_embeddings, - dropout_prob=config.hidden_dropout_prob) - self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor( - embedding_size=self.embedding_size, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=0.02, - max_position_embeddings=config.max_position_embeddings, - dropout_prob=config.hidden_dropout_prob) - self.tfm_encoder = TransformerEncoder( - batch_size=self.batch_size, - hidden_size=self.hidden_size, - seq_length=self.seq_length, - num_attention_heads=config.num_attention_heads, - num_hidden_layers=self.num_hidden_layers, - intermediate_size=config.intermediate_size, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - hidden_act=config.hidden_act, - compute_type=config.compute_type) - - if is_training: - self.projection = PredLogProbs( - batch_size=self.batch_size, - seq_length=self.seq_length, - width=self.hidden_size, - compute_type=config.compute_type, - dtype=config.dtype) - self.tfm_decoder = TransformerDecoder( - batch_size=self.batch_size, - hidden_size=self.hidden_size, - seq_length=self.seq_length, - enc_seq_length=self.seq_length, - num_attention_heads=config.num_attention_heads, - num_hidden_layers=self.num_hidden_layers, - intermediate_size=config.intermediate_size, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - hidden_act=config.hidden_act, - compute_type=config.compute_type) - else: - self.projection = PredLogProbs( - batch_size=self.batch_size * config.beam_width, - seq_length=1, - width=self.hidden_size, - compute_type=config.compute_type, - dtype=config.dtype) - self.tfm_decoder = TransformerDecoderStep( - batch_size=self.batch_size * config.beam_width, - hidden_size=self.hidden_size, - enc_seq_length=self.seq_length, - max_decode_length=config.max_decode_length, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - intermediate_size=config.intermediate_size, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - use_one_hot_embeddings=False, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - hidden_act=config.hidden_act, - compute_type=config.compute_type, - embedding_lookup=self.tfm_embedding_lookup, - embedding_processor=self.tfm_embedding_postprocessor_for_decoder, - projection=self.projection) - self.tfm_decoder = BeamSearchDecoder( - batch_size=config.batch_size, - seq_length=config.seq_length, - vocab_size=config.vocab_size, - decoder=self.tfm_decoder, - beam_width=config.beam_width, - length_penalty_weight=config.length_penalty_weight, - max_decode_length=config.max_decode_length) - self.tfm_decoder.add_flags(loop_can_unroll=True) - - self.cast = P.Cast() - self.dtype = config.dtype - self.cast_compute_type = CastWrapper(dst_type=config.compute_type) - self.expand = P.ExpandDims() - self.multiply = P.Mul() - - self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask() - - if is_training: - ones = np.ones(shape=(self.seq_length, self.seq_length)) - self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) - else: - self.tile_beam = TileBeam(beam_width=config.beam_width) - ones = np.ones(shape=(config.batch_size, config.max_decode_length)) - self.encdec_mask = Tensor(ones, dtype=mstype.float32) - - def construct(self, source_ids, source_mask, target_ids=None, target_mask=None): - # process source sentence - src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids) - src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings) - # attention mask [batch_size, seq_length, seq_length] - enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask) - # transformer encoder - encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output), - self.cast_compute_type(enc_attention_mask)) - - if self.is_training: - # process target sentence - tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids) - tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings) - # attention mask [batch_size, seq_length, seq_length] - tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask) - tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0)) - # transformer decoder - decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output), - self.cast_compute_type(tgt_attention_mask), - encoder_output, enc_attention_mask) - # calculate logits and log_probs - log_probs = self.projection(decoder_output, embedding_tables) - ret = log_probs - else: - beam_encoder_output = self.tile_beam(encoder_output) - - enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1)) - - beam_enc_attention_mask = self.tile_beam(enc_attention_mask) - beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask) - predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask) - ret = predicted_ids - return ret diff --git a/model_zoo/Transformer/train.py b/model_zoo/Transformer/train.py deleted file mode 100644 index ffd6b8c714..0000000000 --- a/model_zoo/Transformer/train.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Transformer training script.""" - -import time -import argparse -import random -import numpy as np - -import mindspore.common.dtype as mstype -from mindspore.common.tensor import Tensor -from mindspore.nn.optim import Adam -from mindspore.train.model import Model -from mindspore.train.loss_scale_manager import DynamicLossScaleManager -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint -from mindspore.train.callback import Callback, TimeMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net -import mindspore.dataset.engine as de -import mindspore.communication.management as D -from mindspore.train.parallel_utils import ParallelMode -from mindspore import context - -from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \ - TransformerTrainOneStepWithLossScaleCell -from src.config import cfg, transformer_net_cfg -from src.dataset import create_transformer_dataset -from src.lr_schedule import create_dynamic_lr - -random_seed = 1 -random.seed(random_seed) -np.random.seed(random_seed) -de.config.set_seed(random_seed) - -def get_ms_timestamp(): - t = time.time() - return int(round(t * 1000)) -time_stamp_init = False -time_stamp_first = 0 - -class LossCallBack(Callback): - """ - Monitor the loss in training. - If the loss is NAN or INF terminating training. - Note: - If per_print_times is 0 do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0.") - self._per_print_times = per_print_times - global time_stamp_init, time_stamp_first - if not time_stamp_init: - time_stamp_first = get_ms_timestamp() - time_stamp_init = True - - def step_end(self, run_context): - global time_stamp_first - time_stamp_current = get_ms_timestamp() - cb_params = run_context.original_args() - print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first, - cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) - with open("./loss.log", "a+") as f: - f.write("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first, - cb_params.cur_epoch_num, - cb_params.cur_step_num, - str(cb_params.net_outputs))) - f.write('\n') - - -def argparse_init(): - """ - Argparse init. - """ - parser = argparse.ArgumentParser(description='transformer') - parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") - parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.") - parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") - parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.") - parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") - parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, " - "default is true.") - parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, " - "default is 2500.") - parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.") - parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, " - "default is ./checkpoint/") - parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path") - return parser - -def run_transformer_train(): - """ - Transformer training. - """ - parser = argparse_init() - args, _ = parser.parse_known_args() - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) - context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False) - - if args.distribute == "true": - device_num = args.device_num - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, - parameter_broadcast=True, device_num=device_num) - D.init() - rank_id = args.device_id % device_num - else: - device_num = 1 - rank_id = 0 - dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num, - rank_id=rank_id, do_shuffle=args.do_shuffle, - enable_data_sink=args.enable_data_sink, - dataset_path=args.data_path) - - netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True) - - if args.checkpoint_path: - parameter_dict = load_checkpoint(args.checkpoint_path) - load_param_into_net(netwithloss, parameter_dict) - - lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay", - training_steps=dataset.get_dataset_size()*args.epoch_size, - learning_rate=cfg.lr_schedule.learning_rate, - warmup_steps=cfg.lr_schedule.warmup_steps, - hidden_size=transformer_net_cfg.hidden_size, - start_decay_step=cfg.lr_schedule.start_decay_step, - min_lr=cfg.lr_schedule.min_lr), mstype.float32) - optimizer = Adam(netwithloss.trainable_params(), lr) - - callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()] - if args.enable_save_ckpt == "true": - if device_num == 1 or (device_num > 1 and rank_id == 0): - ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, - keep_checkpoint_max=args.save_checkpoint_num) - ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) - callbacks.append(ckpoint_cb) - - if args.enable_lossscale == "true": - scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value, - scale_factor=cfg.scale_factor, - scale_window=cfg.scale_window) - update_cell = scale_manager.get_update_cell() - netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, - scale_update_cell=update_cell) - else: - netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer) - - netwithgrads.set_train(True) - model = Model(netwithgrads) - model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) - -if __name__ == '__main__': - run_transformer_train() diff --git a/model_zoo/alexnet/src/alexnet.py b/model_zoo/alexnet/src/alexnet.py deleted file mode 100644 index c528ae39e9..0000000000 --- a/model_zoo/alexnet/src/alexnet.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Alexnet.""" -import mindspore.nn as nn -from mindspore.common.initializer import TruncatedNormal -from mindspore.ops import operations as P - -def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"): - weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode=pad_mode) - -def fc_with_initialize(input_channels, out_channels): - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - -def weight_variable(): - return TruncatedNormal(0.02) # 0.02 - - -class AlexNet(nn.Cell): - """ - Alexnet - """ - def __init__(self, num_classes=10, channel=3): - super(AlexNet, self).__init__() - self.conv1 = conv(channel, 96, 11, stride=4) - self.conv2 = conv(96, 256, 5, pad_mode="same") - self.conv3 = conv(256, 384, 3, pad_mode="same") - self.conv4 = conv(384, 384, 3, pad_mode="same") - self.conv5 = conv(384, 256, 3, pad_mode="same") - self.relu = nn.ReLU() - self.max_pool2d = P.MaxPool(ksize=3, strides=2) - self.flatten = nn.Flatten() - self.fc1 = fc_with_initialize(6*6*256, 4096) - self.fc2 = fc_with_initialize(4096, 4096) - self.fc3 = fc_with_initialize(4096, num_classes) - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv3(x) - x = self.relu(x) - x = self.conv4(x) - x = self.relu(x) - x = self.conv5(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x diff --git a/model_zoo/alexnet/train.py b/model_zoo/alexnet/train.py deleted file mode 100644 index df038d62a2..0000000000 --- a/model_zoo/alexnet/train.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -######################## train alexnet example ######################## -train alexnet and get network model files(.ckpt) : -python train.py --data_path /YourDataPath -""" - -import argparse -from src.config import alexnet_cfg as cfg -from src.dataset import create_dataset_cifar10 -from src.generator_lr import get_lr -from src.alexnet import AlexNet -import mindspore.nn as nn -from mindspore import context -from mindspore import Tensor -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='MindSpore AlexNet Example') - parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], - help='device where the code will be implemented (default: Ascend)') - parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') - parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ - path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') - args = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - - ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size) - network = AlexNet(cfg.num_classes) - loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size())) - opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum) - model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) - time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck) - - print("============== Starting Training ==============") - model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) diff --git a/model_zoo/bert/README.md b/model_zoo/bert/README.md deleted file mode 100644 index 45928da4e3..0000000000 --- a/model_zoo/bert/README.md +++ /dev/null @@ -1,171 +0,0 @@ -# BERT Example -## Description -This example implements pre-training, fine-tuning and evaluation of [BERT-base](https://github.com/google-research/bert)(the base version of BERT model) and [BERT-NEZHA](https://github.com/huawei-noah/Pretrained-Language-Model)(a Chinese pretrained language model developed by Huawei, which introduced a improvement of Functional Relative Positional Encoding as an effective positional encoding scheme). - -## Requirements -- Install [MindSpore](https://www.mindspore.cn/install/en). -- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path. -- Download dataset for fine-tuning and evaluation such as CLUENER, TNEWS, SQuAD v1.1, etc. -> Notes: - If you are running a fine-tuning or evaluation task, prepare a checkpoint from pre-train. - -## Running the Example -### Pre-Training -- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file. - -- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model. - - ``` bash - sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR - ``` -- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model. - - ``` bash - sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH - ``` - -### Fine-Tuning and Evaluation -- Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`. - -- Set task related hyperparameters in scripts/run_XXX.sh. - -- Run `bash scripts/run_XXX.py` for fine-tuning of BERT-base and BERT-NEZHA model. - - ```bash - bash scripts/run_XXX.sh - ``` - -## Usage -### Pre-Training -``` -usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] - [--enable_save_ckpt ENABLE_SAVE_CKPT] - [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] - [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH] - [--save_checkpoint_steps N] [--save_checkpoint_num N] - [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] - -options: - --distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false" - --epoch_size epoch size: N, default is 1 - --device_num number of used devices: N, default is 1 - --device_id device id: N, default is 0 - --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" - --enable_lossscale enable lossscale: "true" | "false", default is "true" - --do_shuffle enable shuffle: "true" | "false", default is "true" - --enable_data_sink enable data sink: "true" | "false", default is "true" - --data_sink_steps set data sink steps: N, default is 1 - --checkpoint_path path to save checkpoint files: PATH, default is "" - --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000 - --save_checkpoint_num number for saving checkpoint files: N, default is 1 - --data_dir path to dataset directory: PATH, default is "" - --schema_dir path to schema.json file, PATH, default is "" -``` -## Options and Parameters -It contains of parameters of BERT model and options for training, which is set in file `config.py`, `finetune_config.py` and `evaluation_config.py` respectively. -### Options: -``` -config.py: - bert_network version of BERT model: base | nezha, default is base - loss_scale_value initial value of loss scale: N, default is 2^32 - scale_factor factor used to update loss scale: N, default is 2 - scale_window steps for once updatation of loss scale: N, default is 1000 - optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb" - -scripts/run_ner.sh: - device_target targeted device to run task: Ascend | GPU - do_train whether to run training on training set: true | false - do_eval whether to run eval on dev set: true | false - assessment_method assessment method to do evaluation: f1 | clue_benchmark - use_crf whether to use crf to calculate loss: true | false - device_id device id to run task - epoch_num total number of training epochs to perform - num_class number of classes to do labeling - vocab_file_path the vocabulary file that the BERT model was trained on - label2id_file_path label to id json file - save_finetune_checkpoint_path path to save generated finetuning checkpoint - load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) - load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval - train_data_file_path ner tfrecord for training. E.g., train.tfrecord - eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result - schema_file_path path to datafile schema file - -scripts/run_squad.sh: - device_target targeted device to run task: Ascend | GPU - do_train whether to run training on training set: true | false - do_eval whether to run eval on dev set: true | false - device_id device id to run task - epoch_num total number of training epochs to perform - num_class number of classes to classify, usually 2 for squad task - vocab_file_path the vocabulary file that the BERT model was trained on - eval_json_path path to squad dev json file - save_finetune_checkpoint_path path to save generated finetuning checkpoint - load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) - load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval - train_data_file_path squad tfrecord for training. E.g., train1.1.tfrecord - eval_data_file_path squad tfrecord for predictions. E.g., dev1.1.tfrecord - schema_file_path path to datafile schema file - -scripts/run_classifier.sh - device_target targeted device to run task: Ascend | GPU - do_train whether to run training on training set: true | false - do_eval whether to run eval on dev set: true | false - assessment_method assessment method to do evaluation: accuracy | f1 | mcc | spearman_correlation - device_id device id to run task - epoch_num total number of training epochs to perform - num_class number of classes to do labeling - save_finetune_checkpoint_path path to save generated finetuning checkpoint - load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) - load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval - train_data_file_path tfrecord for training. E.g., train.tfrecord - eval_data_file_path tfrecord for predictions. E.g., dev.tfrecord - schema_file_path path to datafile schema file - - -``` - -### Parameters: -``` -Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): - batch_size batch size of input dataset: N, default is 16 - seq_length length of input sequence: N, default is 128 - vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136 - hidden_size size of bert encoder layers: N, default is 768 - num_hidden_layers number of hidden layers: N, default is 12 - num_attention_heads number of attention heads: N, default is 12 - intermediate_size size of intermediate layer: N, default is 3072 - hidden_act activation function used: ACTIVATION, default is "gelu" - hidden_dropout_prob dropout probability for BertOutput: Q, default is 0.1 - attention_probs_dropout_prob dropout probability for BertAttention: Q, default is 0.1 - max_position_embeddings maximum length of sequences: N, default is 512 - type_vocab_size size of token type vocab: N, default is 16 - initializer_range initialization value of TruncatedNormal: Q, default is 0.02 - use_relative_positions use relative positions or not: True | False, default is False - input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True - token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True - dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32 - compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 - -Parameters for optimizer: - AdamWeightDecayDynamicLR: - decay_steps steps of the learning rate decay: N - learning_rate value of learning rate: Q - end_learning_rate value of end learning rate: Q, must be positive - power power: Q - warmup_steps steps of the learning rate warm up: N - weight_decay weight decay: Q - eps term added to the denominator to improve numerical stability: Q - - Lamb: - decay_steps steps of the learning rate decay: N - learning_rate value of learning rate: Q - end_learning_rate value of end learning rate: Q - power power: Q - warmup_steps steps of the learning rate warm up: N - weight_decay weight decay: Q - - Momentum: - learning_rate value of learning rate: Q - momentum momentum for the moving average: Q -``` - diff --git a/model_zoo/bert/run_classifier.py b/model_zoo/bert/run_classifier.py deleted file mode 100644 index 4b2801f87c..0000000000 --- a/model_zoo/bert/run_classifier.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -''' -Bert finetune and evaluation script. -''' - -import os -import argparse -from src.bert_for_finetune import BertFinetuneCell, BertCLS -from src.finetune_eval_config import optimizer_cfg, bert_net_cfg -from src.dataset import create_classification_dataset -from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation -from src.utils import make_directory, LossCallBack, LoadNewestCkpt -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore import log as logger -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum -from mindspore.common.tensor import Tensor -from mindspore.train.model import Model -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -_cur_dir = os.getcwd() - -def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): - """ do train """ - if load_checkpoint_path == "": - raise ValueError("Pretrain model missed, finetune task must load pretrain model!") - steps_per_epoch = dataset.get_dataset_size() - epoch_num = dataset.get_repeat_count() - # optimizer - if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), - decay_steps=steps_per_epoch * epoch_num, - learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=optimizer_cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) - elif optimizer_cfg.optimizer == 'Lamb': - optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, - start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, - end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, - power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - decay_filter=optimizer_cfg.Lamb.decay_filter) - elif optimizer_cfg.optimizer == 'Momentum': - optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, - momentum=optimizer_cfg.Momentum.momentum) - else: - raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") - - # load checkpoint into network - ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) - ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config) - param_dict = load_checkpoint(load_checkpoint_path) - load_param_into_net(network, param_dict) - - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) - netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) - model = Model(netwithgrads) - callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] - model.train(epoch_num, dataset, callbacks=callbacks) - -def eval_result_print(assessment_method="accuracy", callback=None): - """ print eval result """ - if assessment_method == "accuracy": - print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, - callback.acc_num / callback.total_num)) - elif assessment_method == "f1": - print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) - print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) - print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) - elif assessment_method == "mcc": - print("MCC {:.6f} ".format(callback.cal())) - elif assessment_method == "spearman_correlation": - print("Spearman Correlation is {:.6f} ".format(callback.cal()[0])) - else: - raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") - -def do_eval(dataset=None, network=None, num_class=2, assessment_method="accuracy", load_checkpoint_path=""): - """ do eval """ - if load_checkpoint_path == "": - raise ValueError("Finetune model missed, evaluation task must load finetune model!") - net_for_pretraining = network(bert_net_cfg, False, num_class) - net_for_pretraining.set_train(False) - param_dict = load_checkpoint(load_checkpoint_path) - load_param_into_net(net_for_pretraining, param_dict) - model = Model(net_for_pretraining) - - if assessment_method == "accuracy": - callback = Accuracy() - elif assessment_method == "f1": - callback = F1(False, num_class) - elif assessment_method == "mcc": - callback = MCC() - elif assessment_method == "spearman_correlation": - callback = Spearman_Correlation() - else: - raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") - - columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] - for data in dataset.create_dict_iterator(): - input_data = [] - for i in columns_list: - input_data.append(Tensor(data[i])) - input_ids, input_mask, token_type_id, label_ids = input_data - logits = model.predict(input_ids, input_mask, token_type_id, label_ids) - callback.update(logits, label_ids) - print("==============================================================") - eval_result_print(assessment_method, callback) - print("==============================================================") - -def run_classifier(): - """run classifier task""" - parser = argparse.ArgumentParser(description="run classifier") - parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") - parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: " - "[MCC, Spearman_correlation, " - "Accuracy], default is accuracy") - parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") - parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") - parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") - parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") - parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--train_data_file_path", type=str, default="", - help="Data path, it is better to use absolute path") - parser.add_argument("--eval_data_file_path", type=str, default="", - help="Data path, it is better to use absolute path") - parser.add_argument("--schema_file_path", type=str, default="", - help="Schema path, it is better to use absolute path") - args_opt = parser.parse_args() - epoch_num = args_opt.epoch_num - assessment_method = args_opt.assessment_method.lower() - load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path - save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path - load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path - - if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": - raise ValueError("At least one of 'do_train' or 'do_eval' must be true") - if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": - raise ValueError("'train_data_file_path' must be set when do finetune task") - if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": - raise ValueError("'eval_data_file_path' must be set when do evaluation task") - - target = args_opt.device_target - if target == "Ascend": - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - elif target == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - if bert_net_cfg.compute_type != mstype.float32: - logger.warning('GPU only support fp32 temporarily, run with fp32.') - bert_net_cfg.compute_type = mstype.float32 - else: - raise Exception("Target error, GPU or Ascend is supported.") - - netwithloss = BertCLS(bert_net_cfg, True, num_labels=args_opt.num_class, dropout_prob=0.1, - assessment_method=assessment_method) - - if args_opt.do_train.lower() == "true": - ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, - assessment_method=assessment_method, - data_file_path=args_opt.train_data_file_path, - schema_file_path=args_opt.schema_file_path) - do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) - - if args_opt.do_eval.lower() == "true": - if save_finetune_checkpoint_path == "": - load_finetune_checkpoint_dir = _cur_dir - else: - load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) - load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, - ds.get_dataset_size(), epoch_num, "classifier") - - if args_opt.do_eval.lower() == "true": - ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, - assessment_method=assessment_method, - data_file_path=args_opt.eval_data_file_path, - schema_file_path=args_opt.schema_file_path) - do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path) - -if __name__ == "__main__": - run_classifier() diff --git a/model_zoo/bert/run_ner.py b/model_zoo/bert/run_ner.py deleted file mode 100644 index a61c96066e..0000000000 --- a/model_zoo/bert/run_ner.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -''' -Bert finetune and evaluation script. -''' - -import os -import json -import argparse -from src.bert_for_finetune import BertFinetuneCell, BertNER -from src.finetune_eval_config import optimizer_cfg, bert_net_cfg -from src.dataset import create_ner_dataset -from src.utils import make_directory, LossCallBack, LoadNewestCkpt -from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore import log as logger -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum -from mindspore.common.tensor import Tensor -from mindspore.train.model import Model -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -_cur_dir = os.getcwd() - - -def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): - """ do train """ - if load_checkpoint_path == "": - raise ValueError("Pretrain model missed, finetune task must load pretrain model!") - steps_per_epoch = dataset.get_dataset_size() - epoch_num = dataset.get_repeat_count() - # optimizer - if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), - decay_steps=steps_per_epoch * epoch_num, - learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=optimizer_cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) - elif optimizer_cfg.optimizer == 'Lamb': - optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, - start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, - end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, - power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - decay_filter=optimizer_cfg.Lamb.decay_filter) - elif optimizer_cfg.optimizer == 'Momentum': - optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, - momentum=optimizer_cfg.Momentum.momentum) - else: - raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") - - # load checkpoint into network - ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) - ckpoint_cb = ModelCheckpoint(prefix="ner", directory=save_checkpoint_path, config=ckpt_config) - param_dict = load_checkpoint(load_checkpoint_path) - load_param_into_net(network, param_dict) - - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) - netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) - model = Model(netwithgrads) - callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] - model.train(epoch_num, dataset, callbacks=callbacks) - -def eval_result_print(assessment_method="accuracy", callback=None): - """print eval result""" - if assessment_method == "accuracy": - print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, - callback.acc_num / callback.total_num)) - elif assessment_method == "f1": - print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) - print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) - print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) - elif assessment_method == "mcc": - print("MCC {:.6f} ".format(callback.cal())) - elif assessment_method == "spearman_correlation": - print("Spearman Correlation is {:.6f} ".format(callback.cal()[0])) - else: - raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") - -def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="", - load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None): - """ do eval """ - if load_checkpoint_path == "": - raise ValueError("Finetune model missed, evaluation task must load finetune model!") - if assessment_method == "clue_benchmark": - bert_net_cfg.batch_size = 1 - net_for_pretraining = network(bert_net_cfg, False, num_class, use_crf=(use_crf.lower() == "true"), - tag_to_index=tag_to_index) - net_for_pretraining.set_train(False) - param_dict = load_checkpoint(load_checkpoint_path) - load_param_into_net(net_for_pretraining, param_dict) - model = Model(net_for_pretraining) - - if assessment_method == "clue_benchmark": - from src.cluener_evaluation import submit - submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file) - else: - if assessment_method == "accuracy": - callback = Accuracy() - elif assessment_method == "f1": - callback = F1((use_crf.lower() == "true"), num_class) - elif assessment_method == "mcc": - callback = MCC() - elif assessment_method == "spearman_correlation": - callback = Spearman_Correlation() - else: - raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") - - columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] - for data in dataset.create_dict_iterator(): - input_data = [] - for i in columns_list: - input_data.append(Tensor(data[i])) - input_ids, input_mask, token_type_id, label_ids = input_data - logits = model.predict(input_ids, input_mask, token_type_id, label_ids) - callback.update(logits, label_ids) - print("==============================================================") - eval_result_print(assessment_method, callback) - print("==============================================================") - -def run_ner(): - """run ner task""" - parser = argparse.ArgumentParser(description="run classifier") - parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") - parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: " - "[F1, clue_benchmark], default is F1") - parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") - parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") - parser.add_argument("--use_crf", type=str, default="false", help="Use crf, default is false") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") - parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") - parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark") - parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark") - parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") - parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--train_data_file_path", type=str, default="", - help="Data path, it is better to use absolute path") - parser.add_argument("--eval_data_file_path", type=str, default="", - help="Data path, it is better to use absolute path") - parser.add_argument("--schema_file_path", type=str, default="", - help="Schema path, it is better to use absolute path") - args_opt = parser.parse_args() - epoch_num = args_opt.epoch_num - assessment_method = args_opt.assessment_method.lower() - load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path - save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path - load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path - - if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": - raise ValueError("At least one of 'do_train' or 'do_eval' must be true") - if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": - raise ValueError("'train_data_file_path' must be set when do finetune task") - if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": - raise ValueError("'eval_data_file_path' must be set when do evaluation task") - if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "": - raise ValueError("'vocab_file_path' must be set to do clue benchmark") - if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "": - raise ValueError("'label2id_file_path' must be set to use crf") - if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "": - raise ValueError("'label2id_file_path' must be set to do clue benchmark") - - target = args_opt.device_target - if target == "Ascend": - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - elif target == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - if bert_net_cfg.compute_type != mstype.float32: - logger.warning('GPU only support fp32 temporarily, run with fp32.') - bert_net_cfg.compute_type = mstype.float32 - else: - raise Exception("Target error, GPU or Ascend is supported.") - - tag_to_index = None - if args_opt.use_crf.lower() == "true": - with open(args_opt.label2id_file_path) as json_file: - tag_to_index = json.load(json_file) - max_val = max(tag_to_index.values()) - tag_to_index[""] = max_val + 1 - tag_to_index[""] = max_val + 2 - number_labels = len(tag_to_index) - else: - number_labels = args_opt.num_class - netwithloss = BertNER(bert_net_cfg, True, num_labels=number_labels, - use_crf=(args_opt.use_crf.lower() == "true"), - tag_to_index=tag_to_index, dropout_prob=0.1) - if args_opt.do_train.lower() == "true": - ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, - assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, - schema_file_path=args_opt.schema_file_path) - do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) - - if args_opt.do_eval.lower() == "true": - if save_finetune_checkpoint_path == "": - load_finetune_checkpoint_dir = _cur_dir - else: - load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) - load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, - ds.get_dataset_size(), epoch_num, "ner") - - if args_opt.do_eval.lower() == "true": - ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, - assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, - schema_file_path=args_opt.schema_file_path) - do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path, - load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index) - -if __name__ == "__main__": - run_ner() diff --git a/model_zoo/bert/run_pretrain.py b/model_zoo/bert/run_pretrain.py deleted file mode 100644 index 7123c942f3..0000000000 --- a/model_zoo/bert/run_pretrain.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -#################pre_train bert example on zh-wiki######################## -python run_pretrain.py -""" - -import os -import argparse -import numpy -import mindspore.communication.management as D -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore.train.model import Model -from mindspore.train.parallel_utils import ParallelMode -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR -from mindspore import log as logger -from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell -from src.dataset import create_bert_dataset -from src.config import cfg, bert_net_cfg -from src.utils import LossCallBack -_current_dir = os.path.dirname(os.path.realpath(__file__)) - - -def run_pretrain(): - """pre-train bert_clue""" - parser = argparse.ArgumentParser(description='bert pre_training') - parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='device where the code will be implemented. (Default: Ascend)') - parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") - parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") - parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.") - parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") - parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") - parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") - parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") - parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " - "default is 1000.") - parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, " - "meaning run all steps according to epoch number.") - parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") - parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") - parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") - - args_opt = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) - context.set_context(reserve_class_name_in_scope=False) - context.set_context(variable_memory_max_size="30GB") - ckpt_save_dir = args_opt.save_checkpoint_path - if args_opt.distribute == "true": - if args_opt.device_target == 'Ascend': - D.init('hccl') - device_num = args_opt.device_num - rank = args_opt.device_id % device_num - else: - D.init('nccl') - device_num = D.get_group_size() - rank = D.get_rank() - ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/' - - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, - device_num=device_num) - from mindspore.parallel._auto_parallel_context import auto_parallel_context - if bert_net_cfg.num_hidden_layers == 12: - if bert_net_cfg.use_relative_positions: - auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) - else: - auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) - elif bert_net_cfg.num_hidden_layers == 24: - if bert_net_cfg.use_relative_positions: - auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) - else: - auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) - else: - rank = 0 - device_num = 1 - - if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32: - logger.warning('Gpu only support fp32 temporarily, run with fp32.') - bert_net_cfg.compute_type = mstype.float32 - - - ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, - args_opt.enable_data_sink, args_opt.data_sink_steps, - args_opt.data_dir, args_opt.schema_dir) - if args_opt.train_steps > 0: - new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) - netwithloss = BertNetworkWithLoss(bert_net_cfg, True) - - if cfg.optimizer == 'Lamb': - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count, - start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate, - power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay, - eps=cfg.Lamb.eps) - elif cfg.optimizer == 'Momentum': - optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, - momentum=cfg.Momentum.momentum) - elif cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), - decay_steps=ds.get_dataset_size() * new_repeat_count, - learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=cfg.AdamWeightDecayDynamicLR.power, - weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=cfg.AdamWeightDecayDynamicLR.eps, - warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps) - else: - raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]". - format(cfg.optimizer)) - callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()] - if args_opt.enable_save_ckpt == "true": - config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, - keep_checkpoint_max=args_opt.save_checkpoint_num) - ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck) - callback.append(ckpoint_cb) - - if args_opt.load_checkpoint_path: - param_dict = load_checkpoint(args_opt.load_checkpoint_path) - load_param_into_net(netwithloss, param_dict) - - if args_opt.enable_lossscale == "true": - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, - scale_factor=cfg.scale_factor, - scale_window=cfg.scale_window) - netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, - scale_update_cell=update_cell) - else: - netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer) - - model = Model(netwithgrads) - model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true")) -if __name__ == '__main__': - numpy.random.seed(0) - run_pretrain() diff --git a/model_zoo/bert/run_squad.py b/model_zoo/bert/run_squad.py deleted file mode 100644 index 083cedac1d..0000000000 --- a/model_zoo/bert/run_squad.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -''' -Bert finetune and evaluation script. -''' -import os -import argparse -import collections -from src.bert_for_finetune import BertSquadCell, BertSquad -from src.finetune_eval_config import optimizer_cfg, bert_net_cfg -from src.dataset import create_squad_dataset -from src import tokenization -from src.create_squad_data import read_squad_examples, convert_examples_to_features -from src.run_squad import write_predictions -from src.utils import make_directory, LossCallBack, LoadNewestCkpt -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore import log as logger -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum -from mindspore.common.tensor import Tensor -from mindspore.train.model import Model -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -_cur_dir = os.getcwd() - -def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): - """ do train """ - if load_checkpoint_path == "": - raise ValueError("Pretrain model missed, finetune task must load pretrain model!") - steps_per_epoch = dataset.get_dataset_size() - epoch_num = dataset.get_repeat_count() - # optimizer - if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), - decay_steps=steps_per_epoch * epoch_num, - learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=optimizer_cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) - elif optimizer_cfg.optimizer == 'Lamb': - optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, - start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, - end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, - power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - decay_filter=optimizer_cfg.Lamb.decay_filter) - elif optimizer_cfg.optimizer == 'Momentum': - optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, - momentum=optimizer_cfg.Momentum.momentum) - else: - raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") - - # load checkpoint into network - ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) - ckpoint_cb = ModelCheckpoint(prefix="squad", directory=save_checkpoint_path, config=ckpt_config) - param_dict = load_checkpoint(load_checkpoint_path) - load_param_into_net(network, param_dict) - - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) - netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell) - model = Model(netwithgrads) - callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] - model.train(epoch_num, dataset, callbacks=callbacks) - - -def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", seq_length=384): - """ do eval """ - if load_checkpoint_path == "": - raise ValueError("Finetune model missed, evaluation task must load finetune model!") - tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True) - eval_examples = read_squad_examples(eval_json, False) - eval_features = convert_examples_to_features( - examples=eval_examples, - tokenizer=tokenizer, - max_seq_length=seq_length, - doc_stride=128, - max_query_length=64, - is_training=False, - output_fn=None, - verbose_logging=False) - - net = BertSquad(bert_net_cfg, False, 2) - net.set_train(False) - param_dict = load_checkpoint(load_checkpoint_path) - load_param_into_net(net, param_dict) - model = Model(net) - output = [] - RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) - columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"] - for data in dataset.create_dict_iterator(): - input_data = [] - for i in columns_list: - input_data.append(Tensor(data[i])) - input_ids, input_mask, segment_ids, unique_ids = input_data - start_positions = Tensor([1], mstype.float32) - end_positions = Tensor([1], mstype.float32) - is_impossible = Tensor([1], mstype.float32) - logits = model.predict(input_ids, input_mask, segment_ids, start_positions, - end_positions, unique_ids, is_impossible) - ids = logits[0].asnumpy() - start = logits[1].asnumpy() - end = logits[2].asnumpy() - - for i in range(bert_net_cfg.batch_size): - unique_id = int(ids[i]) - start_logits = [float(x) for x in start[i].flat] - end_logits = [float(x) for x in end[i].flat] - output.append(RawResult( - unique_id=unique_id, - start_logits=start_logits, - end_logits=end_logits)) - write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", None, None) - -def run_squad(): - """run squad task""" - parser = argparse.ArgumentParser(description="run classifier") - parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") - parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") - parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") - parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") - parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path") - parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json") - parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") - parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") - parser.add_argument("--train_data_file_path", type=str, default="", - help="Data path, it is better to use absolute path") - parser.add_argument("--eval_data_file_path", type=str, default="", - help="Data path, it is better to use absolute path") - parser.add_argument("--schema_file_path", type=str, default="", - help="Schema path, it is better to use absolute path") - args_opt = parser.parse_args() - epoch_num = args_opt.epoch_num - load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path - save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path - load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path - - if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": - raise ValueError("At least one of 'do_train' or 'do_eval' must be true") - if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": - raise ValueError("'train_data_file_path' must be set when do finetune task") - if args_opt.do_eval.lower() == "true": - if args_opt.eval_data_file_path == "": - raise ValueError("'eval_data_file_path' must be set when do evaluation task") - if args_opt.vocab_file_path == "": - raise ValueError("'vocab_file_path' must be set when do evaluation task") - if args_opt.eval_json_path == "": - raise ValueError("'tokenization_file_path' must be set when do evaluation task") - - - target = args_opt.device_target - if target == "Ascend": - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - elif target == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - if bert_net_cfg.compute_type != mstype.float32: - logger.warning('GPU only support fp32 temporarily, run with fp32.') - bert_net_cfg.compute_type = mstype.float32 - else: - raise Exception("Target error, GPU or Ascend is supported.") - - netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) - - if args_opt.do_train.lower() == "true": - ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, - data_file_path=args_opt.train_data_file_path, - schema_file_path=args_opt.schema_file_path) - do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) - if args_opt.do_eval.lower() == "true": - if save_finetune_checkpoint_path == "": - load_finetune_checkpoint_dir = _cur_dir - else: - load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) - load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, - ds.get_dataset_size(), epoch_num, "squad") - - if args_opt.do_eval.lower() == "true": - ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, - data_file_path=args_opt.eval_data_file_path, - schema_file_path=args_opt.schema_file_path, is_training=False) - do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path, - load_finetune_checkpoint_path, bert_net_cfg.seq_length) - -if __name__ == "__main__": - run_squad() diff --git a/model_zoo/bert/scripts/run_classifier.sh b/model_zoo/bert/scripts/run_classifier.sh deleted file mode 100644 index 275324b950..0000000000 --- a/model_zoo/bert/scripts/run_classifier.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash scripts/run_classifier.sh" -echo "for example: bash scripts/run_classifier.sh" -echo "assessment_method include: [MCC, Spearman_correlation ,Accuracy]" -echo "==============================================================================================================" - -mkdir -p ms_log -CUR_DIR=`pwd` -PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) -export GLOG_log_dir=${CUR_DIR}/ms_log -export GLOG_logtostderr=0 -python ${PROJECT_DIR}/../run_classifier.py \ - --device_target="Ascend" \ - --do_train="true" \ - --do_eval="false" \ - --assessment_method="Accuracy" \ - --device_id=0 \ - --epoch_num=1 \ - --num_class=2 \ - --save_finetune_checkpoint_path="" \ - --load_pretrain_checkpoint_path="" \ - --load_finetune_checkpoint_path="" \ - --train_data_file_path="" \ - --eval_data_file_path="" \ - --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/scripts/run_distribute_pretrain.sh b/model_zoo/bert/scripts/run_distribute_pretrain.sh deleted file mode 100644 index eb3a0979d1..0000000000 --- a/model_zoo/bert/scripts/run_distribute_pretrain.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH" -echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json" -echo "It is better to use absolute path." -echo "==============================================================================================================" - -EPOCH_SIZE=$2 -DATA_DIR=$3 -SCHEMA_DIR=$4 -PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) -export RANK_TABLE_FILE=$5 -export RANK_SIZE=$1 -cores=`cat /proc/cpuinfo|grep "processor" |wc -l` -echo "the number of logical core" $cores -avg_core_per_rank=`expr $cores \/ $RANK_SIZE` -core_gap=`expr $avg_core_per_rank \- 1` -echo "avg_core_per_rank" $avg_core_per_rank -echo "core_gap" $core_gap -for((i=0;i env.log - taskset -c $cmdopt python ${PROJECT_DIR}/../run_pretrain.py \ - --distribute="true" \ - --epoch_size=$EPOCH_SIZE \ - --device_id=$DEVICE_ID \ - --device_num=$RANK_SIZE \ - --enable_save_ckpt="true" \ - --enable_lossscale="true" \ - --do_shuffle="true" \ - --enable_data_sink="true" \ - --data_sink_steps=100 \ - --load_checkpoint_path="" \ - --save_checkpoint_steps=10000 \ - --save_checkpoint_num=1 \ - --data_dir=$DATA_DIR \ - --schema_dir=$SCHEMA_DIR > log.txt 2>&1 & - cd ../ -done diff --git a/model_zoo/bert/scripts/run_ner.sh b/model_zoo/bert/scripts/run_ner.sh deleted file mode 100644 index ae401b2462..0000000000 --- a/model_zoo/bert/scripts/run_ner.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash scripts/run_ner.sh" -echo "for example: bash scripts/run_ner.sh" -echo "assessment_method include: [F1, clue_benchmark]" -echo "==============================================================================================================" - -mkdir -p ms_log -CUR_DIR=`pwd` -PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) -export GLOG_log_dir=${CUR_DIR}/ms_log -export GLOG_logtostderr=0 -python ${PROJECT_DIR}/../run_ner.py \ - --device_target="Ascend" \ - --do_train="true" \ - --do_eval="false" \ - --assessment_method="F1" \ - --use_crf="false" \ - --device_id=0 \ - --epoch_num=1 \ - --num_class=2 \ - --vocab_file_path="" \ - --label2id_file_path="" \ - --save_finetune_checkpoint_path="" \ - --load_pretrain_checkpoint_path="" \ - --load_finetune_checkpoint_path="" \ - --train_data_file_path="" \ - --eval_data_file_path="" \ - --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/scripts/run_squad.sh b/model_zoo/bert/scripts/run_squad.sh deleted file mode 100644 index a33950cadb..0000000000 --- a/model_zoo/bert/scripts/run_squad.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash scripts/run_squad.sh" -echo "for example: bash scripts/run_squad.sh" -echo "assessment_method include: [Accuracy]" -echo "==============================================================================================================" - -mkdir -p ms_log -CUR_DIR=`pwd` -PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) -export GLOG_log_dir=${CUR_DIR}/ms_log -export GLOG_logtostderr=0 -python ${PROJECT_DIR}/../run_squad.py \ - --device_target="Ascend" \ - --do_train="true" \ - --do_eval="false" \ - --device_id=0 \ - --epoch_num=1 \ - --num_class=2 \ - --vocab_file_path="" \ - --eval_json_path="" \ - --save_finetune_checkpoint_path="" \ - --load_pretrain_checkpoint_path="" \ - --load_finetune_checkpoint_path="" \ - --train_data_file_path="" \ - --eval_data_file_path="" \ - --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/src/assessment_method.py b/model_zoo/bert/src/assessment_method.py deleted file mode 100644 index ca6579cabf..0000000000 --- a/model_zoo/bert/src/assessment_method.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -''' -Bert evaluation assessment method script. -''' -import math -import numpy as np -from .CRF import postprocess - -class Accuracy(): - ''' - calculate accuracy - ''' - def __init__(self): - self.acc_num = 0 - self.total_num = 0 - def update(self, logits, labels): - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - logits = logits.asnumpy() - logit_id = np.argmax(logits, axis=-1) - self.acc_num += np.sum(labels == logit_id) - self.total_num += len(labels) - print("=========================accuracy is ", self.acc_num / self.total_num) - -class F1(): - ''' - calculate F1 score - ''' - def __init__(self, use_crf=False, num_labels=2): - self.TP = 0 - self.FP = 0 - self.FN = 0 - self.use_crf = use_crf - self.num_labels = num_labels - - def update(self, logits, labels): - ''' - update F1 score - ''' - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - if self.use_crf: - backpointers, best_tag_id = logits - best_path = postprocess(backpointers, best_tag_id) - logit_id = [] - for ele in best_path: - logit_id.extend(ele) - else: - logits = logits.asnumpy() - logit_id = np.argmax(logits, axis=-1) - logit_id = np.reshape(logit_id, -1) - pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)]) - pos_label = np.isin(labels, [i for i in range(1, self.num_labels)]) - self.TP += np.sum(pos_eva&pos_label) - self.FP += np.sum(pos_eva&(~pos_label)) - self.FN += np.sum((~pos_eva)&pos_label) - -class MCC(): - ''' - Calculate Matthews Correlation Coefficient - ''' - def __init__(self): - self.TP = 0 - self.FP = 0 - self.FN = 0 - self.TN = 0 - def update(self, logits, labels): - ''' - MCC update - ''' - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - labels = labels.astype(np.bool) - logits = logits.asnumpy() - logit_id = np.argmax(logits, axis=-1) - logit_id = np.reshape(logit_id, -1) - logit_id = logit_id.astype(np.bool) - ornot = logit_id ^ labels - - self.TP += (~ornot & labels).sum() - self.FP += (ornot & ~labels).sum() - self.FN += (ornot & labels).sum() - self.TN += (~ornot & ~labels).sum() - - def cal(self): - mcc = (self.TP*self.TN - self.FP*self.FN)/math.sqrt((self.TP+self.FP)*(self.TP+self.FN) * - (self.TN+self.FP)*(self.TN+self.FN)) - return mcc - -class Spearman_Correlation(): - ''' - Calculate Spearman Correlation Coefficient - ''' - def __init__(self): - self.label = [] - self.logit = [] - - def update(self, logits, labels): - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - logits = logits.asnumpy() - logits = np.reshape(logits, -1) - self.label.append(labels) - self.logit.append(logits) - - def cal(self): - ''' - Calculate Spearman Correlation - ''' - label = np.concatenate(self.label) - logit = np.concatenate(self.logit) - sort_label = label.argsort()[::-1] - sort_logit = logit.argsort()[::-1] - n = len(label) - d_acc = 0 - for i in range(n): - d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0] - d_acc += d**2 - ps = 1 - 6*d_acc/n/(n**2-1) - return ps diff --git a/model_zoo/bert/src/bert_for_finetune.py b/model_zoo/bert/src/bert_for_finetune.py deleted file mode 100644 index 32ac0823b9..0000000000 --- a/model_zoo/bert/src/bert_for_finetune.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -''' -Bert for finetune script. -''' - -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.ops import composite as C -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.train.parallel_utils import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context -from .bert_for_pre_training import clip_grad -from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel -from .utils import CrossEntropyCalculation - - -GRADIENT_CLIP_TYPE = 1 -GRADIENT_CLIP_VALUE = 1.0 -grad_scale = C.MultitypeFuncGraph("grad_scale") -reciprocal = P.Reciprocal() -@grad_scale.register("Tensor", "Tensor") -def tensor_grad_scale(scale, grad): - return grad * reciprocal(scale) - -_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") -grad_overflow = P.FloatStatus() -@_grad_overflow.register("Tensor") -def _tensor_grad_overflow(grad): - return grad_overflow(grad) - -class BertFinetuneCell(nn.Cell): - """ - Especifically defined for finetuning where only four inputs tensor are needed. - """ - def __init__(self, network, optimizer, scale_update_cell=None): - - super(BertFinetuneCell, self).__init__(auto_prefix=False) - self.network = network - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', - get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("mirror_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") - - def construct(self, - input_ids, - input_mask, - token_type_id, - label_ids, - sens=None): - - - weights = self.weights - init = False - loss = self.network(input_ids, - input_mask, - token_type_id, - label_ids) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - - if not self.gpu_target: - init = self.alloc_status() - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - label_ids, - self.cast(scaling_sens, - mstype.float32)) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - grads = self.grad_reducer(grads) - if not self.gpu_target: - flag = self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond) - return F.depend(ret, succ) - -class BertSquadCell(nn.Cell): - """ - specifically defined for finetuning where only four inputs tensor are needed. - """ - def __init__(self, network, optimizer, scale_update_cell=None): - super(BertSquadCell, self).__init__(auto_prefix=False) - self.network = network - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("mirror_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") - def construct(self, - input_ids, - input_mask, - token_type_id, - start_position, - end_position, - unique_id, - is_impossible, - sens=None): - weights = self.weights - init = self.alloc_status() - loss = self.network(input_ids, - input_mask, - token_type_id, - start_position, - end_position, - unique_id, - is_impossible) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - start_position, - end_position, - unique_id, - is_impossible, - self.cast(scaling_sens, - mstype.float32)) - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - grads = self.grad_reducer(grads) - flag = self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond) - return F.depend(ret, succ) - -class BertCLS(nn.Cell): - """ - Train interface for classification finetuning task. - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False, - assessment_method=""): - super(BertCLS, self).__init__() - self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings, - assessment_method) - self.loss = CrossEntropyCalculation(is_training) - self.num_labels = num_labels - self.assessment_method = assessment_method - self.is_training = is_training - def construct(self, input_ids, input_mask, token_type_id, label_ids): - logits = self.bert(input_ids, input_mask, token_type_id) - if self.assessment_method == "spearman_correlation": - if self.is_training: - loss = self.loss(logits, label_ids) - else: - loss = logits - else: - loss = self.loss(logits, label_ids, self.num_labels) - return loss - - -class BertNER(nn.Cell): - """ - Train interface for sequence labeling finetuning task. - """ - def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, - use_one_hot_embeddings=False): - super(BertNER, self).__init__() - self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) - if use_crf: - if not tag_to_index: - raise Exception("The dict for tag-index mapping should be provided for CRF.") - from src.CRF import CRF - self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) - else: - self.loss = CrossEntropyCalculation(is_training) - self.num_labels = num_labels - self.use_crf = use_crf - def construct(self, input_ids, input_mask, token_type_id, label_ids): - logits = self.bert(input_ids, input_mask, token_type_id) - if self.use_crf: - loss = self.loss(logits, label_ids) - else: - loss = self.loss(logits, label_ids, self.num_labels) - return loss - -class BertSquad(nn.Cell): - ''' - Train interface for SQuAD finetuning task. - ''' - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertSquad, self).__init__() - self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) - self.loss = CrossEntropyCalculation(is_training) - self.num_labels = num_labels - self.seq_length = config.seq_length - self.is_training = is_training - self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num') - self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num') - self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num') - self.sum = P.ReduceSum() - self.equal = P.Equal() - self.argmax = P.ArgMaxWithValue(axis=1) - self.squeeze = P.Squeeze(axis=-1) - - def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible): - logits = self.bert(input_ids, input_mask, token_type_id) - if self.is_training: - unstacked_logits_0 = self.squeeze(logits[:, :, 0:1]) - unstacked_logits_1 = self.squeeze(logits[:, :, 1:2]) - start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length) - end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length) - total_loss = (start_loss + end_loss) / 2.0 - else: - start_logits = self.squeeze(logits[:, :, 0:1]) - end_logits = self.squeeze(logits[:, :, 1:2]) - total_loss = (unique_id, start_logits, end_logits) - return total_loss diff --git a/model_zoo/bert/src/bert_for_pre_training.py b/model_zoo/bert/src/bert_for_pre_training.py deleted file mode 100644 index 802391ee86..0000000000 --- a/model_zoo/bert/src/bert_for_pre_training.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Bert for pretraining.""" -import numpy as np - -import mindspore.nn as nn -from mindspore.common.initializer import initializer, TruncatedNormal -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.ops import composite as C -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.train.parallel_utils import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context -from mindspore.ops import _selected_ops -from .bert_model import BertModel - -GRADIENT_CLIP_TYPE = 1 -GRADIENT_CLIP_VALUE = 1.0 - -clip_grad = C.MultitypeFuncGraph("clip_grad") - - -# pylint: disable=consider-using-in -@clip_grad.register("Number", "Number", "Tensor") -def _clip_grad(clip_type, clip_value, grad): - """ - Clip gradients. - - Inputs: - clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. - clip_value (float): Specifies how much to clip. - grad (tuple[Tensor]): Gradients. - - Outputs: - tuple[Tensor], clipped gradients. - """ - if clip_type != 0 and clip_type != 1: - return grad - dt = F.dtype(grad) - if clip_type == 0: - new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), - F.cast(F.tuple_to_array((clip_value,)), dt)) - else: - new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) - return new_grad - - -class GetMaskedLMOutput(nn.Cell): - """ - Get masked lm output. - - Args: - config (BertConfig): The config of BertModel. - - Returns: - Tensor, masked lm output. - """ - def __init__(self, config): - super(GetMaskedLMOutput, self).__init__() - self.width = config.hidden_size - self.reshape = P.Reshape() - self.gather = P.GatherV2() - - weight_init = TruncatedNormal(config.initializer_range) - self.dense = nn.Dense(self.width, - config.hidden_size, - weight_init=weight_init, - activation=config.hidden_act).to_float(config.compute_type) - self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) - self.output_bias = Parameter( - initializer( - 'zero', - config.vocab_size), - name='output_bias') - self.matmul = P.MatMul(transpose_b=True) - self.log_softmax = nn.LogSoftmax(axis=-1) - self.shape_flat_offsets = (-1, 1) - self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32)) - self.last_idx = (-1,) - self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) - self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32)) - self.cast = P.Cast() - self.compute_type = config.compute_type - self.dtype = config.dtype - - def construct(self, - input_tensor, - output_weights, - positions): - flat_offsets = self.reshape( - self.rng * self.seq_length_tensor, self.shape_flat_offsets) - flat_position = self.reshape(positions + flat_offsets, self.last_idx) - flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) - input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) - input_tensor = self.cast(input_tensor, self.compute_type) - output_weights = self.cast(output_weights, self.compute_type) - input_tensor = self.dense(input_tensor) - input_tensor = self.layernorm(input_tensor) - logits = self.matmul(input_tensor, output_weights) - logits = self.cast(logits, self.dtype) - logits = logits + self.output_bias - log_probs = self.log_softmax(logits) - return log_probs - - -class GetNextSentenceOutput(nn.Cell): - """ - Get next sentence output. - - Args: - config (BertConfig): The config of Bert. - - Returns: - Tensor, next sentence output. - """ - def __init__(self, config): - super(GetNextSentenceOutput, self).__init__() - self.log_softmax = _selected_ops.LogSoftmax() - weight_init = TruncatedNormal(config.initializer_range) - self.dense = nn.Dense(config.hidden_size, 2, - weight_init=weight_init, has_bias=True).to_float(config.compute_type) - self.dtype = config.dtype - self.cast = P.Cast() - - def construct(self, input_tensor): - logits = self.dense(input_tensor) - logits = self.cast(logits, self.dtype) - log_prob = self.log_softmax(logits) - return log_prob - - -class BertPreTraining(nn.Cell): - """ - Bert pretraining network. - - Args: - config (BertConfig): The config of BertModel. - is_training (bool): Specifies whether to use the training mode. - use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. - - Returns: - Tensor, prediction_scores, seq_relationship_score. - """ - def __init__(self, config, is_training, use_one_hot_embeddings): - super(BertPreTraining, self).__init__() - self.bert = BertModel(config, is_training, use_one_hot_embeddings) - self.cls1 = GetMaskedLMOutput(config) - self.cls2 = GetNextSentenceOutput(config) - - def construct(self, input_ids, input_mask, token_type_id, - masked_lm_positions): - sequence_output, pooled_output, embedding_table = \ - self.bert(input_ids, token_type_id, input_mask) - prediction_scores = self.cls1(sequence_output, - embedding_table, - masked_lm_positions) - seq_relationship_score = self.cls2(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPretrainingLoss(nn.Cell): - """ - Provide bert pre-training loss. - - Args: - config (BertConfig): The config of BertModel. - - Returns: - Tensor, total loss. - """ - def __init__(self, config): - super(BertPretrainingLoss, self).__init__() - self.vocab_size = config.vocab_size - self.onehot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.reshape = P.Reshape() - self.last_idx = (-1,) - self.neg = P.Neg() - self.cast = P.Cast() - - def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, - masked_lm_weights, next_sentence_labels): - """Defines the computation performed.""" - label_ids = self.reshape(masked_lm_ids, self.last_idx) - label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) - one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) - - per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) - numerator = self.reduce_sum(label_weights * per_example_loss, ()) - denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) - masked_lm_loss = numerator / denominator - - # next_sentence_loss - labels = self.reshape(next_sentence_labels, self.last_idx) - one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) - per_example_loss = self.neg(self.reduce_sum( - one_hot_labels * seq_relationship_score, self.last_idx)) - next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) - - # total_loss - total_loss = masked_lm_loss + next_sentence_loss - - return total_loss - - -class BertNetworkWithLoss(nn.Cell): - """ - Provide bert pre-training loss through network. - - Args: - config (BertConfig): The config of BertModel. - is_training (bool): Specifies whether to use the training mode. - use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. - - Returns: - Tensor, the loss of the network. - """ - def __init__(self, config, is_training, use_one_hot_embeddings=False): - super(BertNetworkWithLoss, self).__init__() - self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) - self.loss = BertPretrainingLoss(config) - self.cast = P.Cast() - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights): - prediction_scores, seq_relationship_score = \ - self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) - total_loss = self.loss(prediction_scores, seq_relationship_score, - masked_lm_ids, masked_lm_weights, next_sentence_labels) - return self.cast(total_loss, mstype.float32) - - -class BertTrainOneStepCell(nn.Cell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - sens (Number): The adjust parameter. Default: 1.0. - """ - def __init__(self, network, optimizer, sens=1.0): - super(BertTrainOneStepCell, self).__init__(auto_prefix=False) - self.network = network - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) - self.sens = sens - self.reducer_flag = False - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("mirror_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - - self.cast = P.Cast() - self.hyper_map = C.HyperMap() - - def set_sens(self, value): - self.sens = value - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights): - """Defines the computation performed.""" - weights = self.weights - - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(F.tuple_to_array((self.sens,)), - mstype.float32)) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - # apply grad reducer on grads - grads = self.grad_reducer(grads) - succ = self.optimizer(grads) - return F.depend(loss, succ) - - -grad_scale = C.MultitypeFuncGraph("grad_scale") -reciprocal = P.Reciprocal() - - -@grad_scale.register("Tensor", "Tensor") -def tensor_grad_scale(scale, grad): - return grad * reciprocal(scale) - - -class BertTrainOneStepWithLossScaleCell(nn.Cell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - scale_update_cell (Cell): Cell to do the loss scale. Default: None. - """ - def __init__(self, network, optimizer, scale_update_cell=None): - super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', - get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = F.identity - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") - - @C.add_flags(has_effect=True) - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - sens=None): - """Defines the computation performed.""" - weights = self.weights - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - # alloc status and clear should be right before gradoperation - init = self.alloc_status() - self.clear_before_grad(init) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(scaling_sens, - mstype.float32)) - # apply grad reducer on grads - grads = self.grad_reducer(grads) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - # sum overflow flag over devices - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond, scaling_sens) - return F.depend(ret, succ) diff --git a/model_zoo/bert/src/bert_model.py b/model_zoo/bert/src/bert_model.py deleted file mode 100644 index 5cd90ab84b..0000000000 --- a/model_zoo/bert/src/bert_model.py +++ /dev/null @@ -1,948 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Bert model.""" - -import math -import copy -import numpy as np -import mindspore.common.dtype as mstype -import mindspore.nn as nn -import mindspore.ops.functional as F -from mindspore.common.initializer import TruncatedNormal, initializer -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter -from .fused_layer_norm import FusedLayerNorm - - -class BertConfig: - """ - Configuration for `BertModel`. - - Args: - batch_size (int): Batch size of input dataset. - seq_length (int): Length of input sequence. Default: 128. - vocab_size (int): The shape of each embedding vector. Default: 32000. - hidden_size (int): Size of the bert encoder layers. Default: 768. - num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder - cell. Default: 12. - num_attention_heads (int): Number of attention heads in the BertTransformer - encoder cell. Default: 12. - intermediate_size (int): Size of intermediate layer in the BertTransformer - encoder cell. Default: 3072. - hidden_act (str): Activation function used in the BertTransformer encoder - cell. Default: "gelu". - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 512. - type_vocab_size (int): Size of token type vocab. Default: 16. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from - dataset. Default: True. - token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded - from dataset. Default: True. - dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - """ - def __init__(self, - batch_size, - seq_length=128, - vocab_size=32000, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - initializer_range=0.02, - use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float32, - enable_fused_layernorm=False): - self.batch_size = batch_size - self.seq_length = seq_length - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.input_mask_from_dataset = input_mask_from_dataset - self.token_type_ids_from_dataset = token_type_ids_from_dataset - self.use_relative_positions = use_relative_positions - self.dtype = dtype - self.compute_type = compute_type - self.enable_fused_layernorm = enable_fused_layernorm - - -class EmbeddingLookup(nn.Cell): - """ - A embeddings lookup table with a fixed dictionary and size. - - Args: - vocab_size (int): Size of the dictionary of embeddings. - embedding_size (int): The size of each embedding vector. - embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of - each embedding vector. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - """ - def __init__(self, - vocab_size, - embedding_size, - embedding_shape, - use_one_hot_embeddings=False, - initializer_range=0.02): - super(EmbeddingLookup, self).__init__() - self.vocab_size = vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.embedding_table = Parameter(initializer - (TruncatedNormal(initializer_range), - [vocab_size, embedding_size]), - name='embedding_table') - self.expand = P.ExpandDims() - self.shape_flat = (-1,) - self.gather = P.GatherV2() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = tuple(embedding_shape) - - def construct(self, input_ids): - extended_ids = self.expand(input_ids, -1) - flat_ids = self.reshape(extended_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) - output_for_reshape = self.array_mul( - one_hot_ids, self.embedding_table) - else: - output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) - output = self.reshape(output_for_reshape, self.shape) - return output, self.embedding_table - - -class EmbeddingPostprocessor(nn.Cell): - """ - Postprocessors apply positional and token type embeddings to word embeddings. - - Args: - embedding_size (int): The size of each embedding vector. - embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of - each embedding vector. - use_token_type (bool): Specifies whether to use token type embeddings. Default: False. - token_type_vocab_size (int): Size of token type vocab. Default: 16. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 512. - dropout_prob (float): The dropout probability. Default: 0.1. - """ - def __init__(self, - embedding_size, - embedding_shape, - use_relative_positions=False, - use_token_type=False, - token_type_vocab_size=16, - use_one_hot_embeddings=False, - initializer_range=0.02, - max_position_embeddings=512, - dropout_prob=0.1): - super(EmbeddingPostprocessor, self).__init__() - self.use_token_type = use_token_type - self.token_type_vocab_size = token_type_vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.max_position_embeddings = max_position_embeddings - self.embedding_table = Parameter(initializer - (TruncatedNormal(initializer_range), - [token_type_vocab_size, - embedding_size]), - name='embedding_table') - - self.shape_flat = (-1,) - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.1, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = tuple(embedding_shape) - self.layernorm = nn.LayerNorm((embedding_size,)) - self.dropout = nn.Dropout(1 - dropout_prob) - self.gather = P.GatherV2() - self.use_relative_positions = use_relative_positions - self.slice = P.StridedSlice() - self.full_position_embeddings = Parameter(initializer - (TruncatedNormal(initializer_range), - [max_position_embeddings, - embedding_size]), - name='full_position_embeddings') - - def construct(self, token_type_ids, word_embeddings): - output = word_embeddings - if self.use_token_type: - flat_ids = self.reshape(token_type_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, - self.token_type_vocab_size, self.on_value, self.off_value) - token_type_embeddings = self.array_mul(one_hot_ids, - self.embedding_table) - else: - token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) - token_type_embeddings = self.reshape(token_type_embeddings, self.shape) - output += token_type_embeddings - if not self.use_relative_positions: - _, seq, width = self.shape - position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) - position_embeddings = self.reshape(position_embeddings, (1, seq, width)) - output += position_embeddings - output = self.layernorm(output) - output = self.dropout(output) - return output - - -class BertOutput(nn.Cell): - """ - Apply a linear computation to hidden status and a residual computation to input. - - Args: - in_channels (int): Input channels. - out_channels (int): Output channels. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - dropout_prob (float): The dropout probability. Default: 0.1. - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - """ - def __init__(self, - in_channels, - out_channels, - initializer_range=0.02, - dropout_prob=0.1, - compute_type=mstype.float32, - enable_fused_layernorm=False): - super(BertOutput, self).__init__() - self.dense = nn.Dense(in_channels, out_channels, - weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) - self.dropout = nn.Dropout(1 - dropout_prob) - self.dropout_prob = dropout_prob - self.add = P.TensorAdd() - if compute_type == mstype.float16: - self.layernorm = FusedLayerNorm((out_channels,), - use_batch_norm=enable_fused_layernorm).to_float(compute_type) - else: - self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) - self.cast = P.Cast() - - def construct(self, hidden_status, input_tensor): - output = self.dense(hidden_status) - output = self.dropout(output) - output = self.add(input_tensor, output) - output = self.layernorm(output) - return output - - -class RelaPosMatrixGenerator(nn.Cell): - """ - Generates matrix of relative positions between inputs. - - Args: - length (int): Length of one dim for the matrix to be generated. - max_relative_position (int): Max value of relative position. - """ - def __init__(self, length, max_relative_position): - super(RelaPosMatrixGenerator, self).__init__() - self._length = length - self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32) - self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32) - self.range_length = -length + 1 - - self.tile = P.Tile() - self.range_mat = P.Reshape() - self.sub = P.Sub() - self.expanddims = P.ExpandDims() - self.cast = P.Cast() - - def construct(self): - range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) - range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) - tile_row_out = self.tile(range_vec_row_out, (self._length,)) - tile_col_out = self.tile(range_vec_col_out, (1, self._length)) - range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) - transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) - distance_mat = self.sub(range_mat_out, transpose_out) - - distance_mat_clipped = C.clip_by_value(distance_mat, - self._min_relative_position, - self._max_relative_position) - - # Shift values to be >=0. Each integer still uniquely identifies a - # relative position difference. - final_mat = distance_mat_clipped + self._max_relative_position - return final_mat - - -class RelaPosEmbeddingsGenerator(nn.Cell): - """ - Generates tensor of size [length, length, depth]. - - Args: - length (int): Length of one dim for the matrix to be generated. - depth (int): Size of each attention head. - max_relative_position (int): Maxmum value of relative position. - initializer_range (float): Initialization value of TruncatedNormal. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - """ - def __init__(self, - length, - depth, - max_relative_position, - initializer_range, - use_one_hot_embeddings=False): - super(RelaPosEmbeddingsGenerator, self).__init__() - self.depth = depth - self.vocab_size = max_relative_position * 2 + 1 - self.use_one_hot_embeddings = use_one_hot_embeddings - - self.embeddings_table = Parameter( - initializer(TruncatedNormal(initializer_range), - [self.vocab_size, self.depth]), - name='embeddings_for_position') - - self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, - max_relative_position=max_relative_position) - self.reshape = P.Reshape() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.shape = P.Shape() - self.gather = P.GatherV2() # index_select - self.matmul = P.BatchMatMul() - - def construct(self): - relative_positions_matrix_out = self.relative_positions_matrix() - - # Generate embedding for each relative position of dimension depth. - if self.use_one_hot_embeddings: - flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) - one_hot_relative_positions_matrix = self.one_hot( - flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value) - embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) - my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) - embeddings = self.reshape(embeddings, my_shape) - else: - embeddings = self.gather(self.embeddings_table, - relative_positions_matrix_out, 0) - return embeddings - - -class SaturateCast(nn.Cell): - """ - Performs a safe saturating cast. This operation applies proper clamping before casting to prevent - the danger that the value will overflow or underflow. - - Args: - src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. - dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. - """ - def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): - super(SaturateCast, self).__init__() - np_type = mstype.dtype_to_nptype(dst_type) - min_type = np.finfo(np_type).min - max_type = np.finfo(np_type).max - - self.tensor_min_type = Tensor([min_type], dtype=src_type) - self.tensor_max_type = Tensor([max_type], dtype=src_type) - - self.min_op = P.Minimum() - self.max_op = P.Maximum() - self.cast = P.Cast() - self.dst_type = dst_type - - def construct(self, x): - out = self.max_op(x, self.tensor_min_type) - out = self.min_op(out, self.tensor_max_type) - return self.cast(out, self.dst_type) - - -class BertAttention(nn.Cell): - """ - Apply multi-headed attention from "from_tensor" to "to_tensor". - - Args: - batch_size (int): Batch size of input datasets. - from_tensor_width (int): Size of last dim of from_tensor. - to_tensor_width (int): Size of last dim of to_tensor. - from_seq_length (int): Length of from_tensor sequence. - to_seq_length (int): Length of to_tensor sequence. - num_attention_heads (int): Number of attention heads. Default: 1. - size_per_head (int): Size of each attention head. Default: 512. - query_act (str): Activation function for the query transform. Default: None. - key_act (str): Activation function for the key transform. Default: None. - value_act (str): Activation function for the value transform. Default: None. - has_attention_mask (bool): Specifies whether to use attention mask. Default: False. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.0. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d - tensor. Default: False. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - from_tensor_width, - to_tensor_width, - from_seq_length, - to_seq_length, - num_attention_heads=1, - size_per_head=512, - query_act=None, - key_act=None, - value_act=None, - has_attention_mask=False, - attention_probs_dropout_prob=0.0, - use_one_hot_embeddings=False, - initializer_range=0.02, - do_return_2d_tensor=False, - use_relative_positions=False, - compute_type=mstype.float32): - - super(BertAttention, self).__init__() - self.batch_size = batch_size - self.from_seq_length = from_seq_length - self.to_seq_length = to_seq_length - self.num_attention_heads = num_attention_heads - self.size_per_head = size_per_head - self.has_attention_mask = has_attention_mask - self.use_relative_positions = use_relative_positions - - self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) - self.reshape = P.Reshape() - self.shape_from_2d = (-1, from_tensor_width) - self.shape_to_2d = (-1, to_tensor_width) - weight = TruncatedNormal(initializer_range) - units = num_attention_heads * size_per_head - self.query_layer = nn.Dense(from_tensor_width, - units, - activation=query_act, - weight_init=weight).to_float(compute_type) - self.key_layer = nn.Dense(to_tensor_width, - units, - activation=key_act, - weight_init=weight).to_float(compute_type) - self.value_layer = nn.Dense(to_tensor_width, - units, - activation=value_act, - weight_init=weight).to_float(compute_type) - - self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) - self.shape_to = ( - batch_size, to_seq_length, num_attention_heads, size_per_head) - - self.matmul_trans_b = P.BatchMatMul(transpose_b=True) - self.multiply = P.Mul() - self.transpose = P.Transpose() - self.trans_shape = (0, 2, 1, 3) - self.trans_shape_relative = (2, 0, 1, 3) - self.trans_shape_position = (1, 2, 0, 3) - self.multiply_data = Tensor([-10000.0,], dtype=compute_type) - self.batch_num = batch_size * num_attention_heads - self.matmul = P.BatchMatMul() - - self.softmax = nn.Softmax() - self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) - - if self.has_attention_mask: - self.expand_dims = P.ExpandDims() - self.sub = P.Sub() - self.add = P.TensorAdd() - self.cast = P.Cast() - self.get_dtype = P.DType() - if do_return_2d_tensor: - self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) - else: - self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) - - self.cast_compute_type = SaturateCast(dst_type=compute_type) - if self.use_relative_positions: - self._generate_relative_positions_embeddings = \ - RelaPosEmbeddingsGenerator(length=to_seq_length, - depth=size_per_head, - max_relative_position=16, - initializer_range=initializer_range, - use_one_hot_embeddings=use_one_hot_embeddings) - - def construct(self, from_tensor, to_tensor, attention_mask): - # reshape 2d/3d input tensors to 2d - from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) - to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) - query_out = self.query_layer(from_tensor_2d) - key_out = self.key_layer(to_tensor_2d) - value_out = self.value_layer(to_tensor_2d) - - query_layer = self.reshape(query_out, self.shape_from) - query_layer = self.transpose(query_layer, self.trans_shape) - key_layer = self.reshape(key_out, self.shape_to) - key_layer = self.transpose(key_layer, self.trans_shape) - - attention_scores = self.matmul_trans_b(query_layer, key_layer) - - # use_relative_position, supplementary logic - if self.use_relative_positions: - # 'relations_keys' = [F|T, F|T, H] - relations_keys = self._generate_relative_positions_embeddings() - relations_keys = self.cast_compute_type(relations_keys) - # query_layer_t is [F, B, N, H] - query_layer_t = self.transpose(query_layer, self.trans_shape_relative) - # query_layer_r is [F, B * N, H] - query_layer_r = self.reshape(query_layer_t, - (self.from_seq_length, - self.batch_num, - self.size_per_head)) - # key_position_scores is [F, B * N, F|T] - key_position_scores = self.matmul_trans_b(query_layer_r, - relations_keys) - # key_position_scores_r is [F, B, N, F|T] - key_position_scores_r = self.reshape(key_position_scores, - (self.from_seq_length, - self.batch_size, - self.num_attention_heads, - self.from_seq_length)) - # key_position_scores_r_t is [B, N, F, F|T] - key_position_scores_r_t = self.transpose(key_position_scores_r, - self.trans_shape_position) - attention_scores = attention_scores + key_position_scores_r_t - - attention_scores = self.multiply(self.scores_mul, attention_scores) - - if self.has_attention_mask: - attention_mask = self.expand_dims(attention_mask, 1) - multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), - self.cast(attention_mask, self.get_dtype(attention_scores))) - - adder = self.multiply(multiply_out, self.multiply_data) - attention_scores = self.add(adder, attention_scores) - - attention_probs = self.softmax(attention_scores) - attention_probs = self.dropout(attention_probs) - - value_layer = self.reshape(value_out, self.shape_to) - value_layer = self.transpose(value_layer, self.trans_shape) - context_layer = self.matmul(attention_probs, value_layer) - - # use_relative_position, supplementary logic - if self.use_relative_positions: - # 'relations_values' = [F|T, F|T, H] - relations_values = self._generate_relative_positions_embeddings() - relations_values = self.cast_compute_type(relations_values) - # attention_probs_t is [F, B, N, T] - attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) - # attention_probs_r is [F, B * N, T] - attention_probs_r = self.reshape( - attention_probs_t, - (self.from_seq_length, - self.batch_num, - self.to_seq_length)) - # value_position_scores is [F, B * N, H] - value_position_scores = self.matmul(attention_probs_r, - relations_values) - # value_position_scores_r is [F, B, N, H] - value_position_scores_r = self.reshape(value_position_scores, - (self.from_seq_length, - self.batch_size, - self.num_attention_heads, - self.size_per_head)) - # value_position_scores_r_t is [B, N, F, H] - value_position_scores_r_t = self.transpose(value_position_scores_r, - self.trans_shape_position) - context_layer = context_layer + value_position_scores_r_t - - context_layer = self.transpose(context_layer, self.trans_shape) - context_layer = self.reshape(context_layer, self.shape_return) - - return context_layer - - -class BertSelfAttention(nn.Cell): - """ - Apply self-attention. - - Args: - batch_size (int): Batch size of input dataset. - seq_length (int): Length of input sequence. - hidden_size (int): Size of the bert encoder layers. - num_attention_heads (int): Number of attention heads. Default: 12. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - seq_length, - hidden_size, - num_attention_heads=12, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - compute_type=mstype.float32, - enable_fused_layernorm=False): - super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError("The hidden size (%d) is not a multiple of the number " - "of attention heads (%d)" % (hidden_size, num_attention_heads)) - - self.size_per_head = int(hidden_size / num_attention_heads) - - self.attention = BertAttention( - batch_size=batch_size, - from_tensor_width=hidden_size, - to_tensor_width=hidden_size, - from_seq_length=seq_length, - to_seq_length=seq_length, - num_attention_heads=num_attention_heads, - size_per_head=self.size_per_head, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - use_relative_positions=use_relative_positions, - has_attention_mask=True, - do_return_2d_tensor=True, - compute_type=compute_type) - - self.output = BertOutput(in_channels=hidden_size, - out_channels=hidden_size, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - compute_type=compute_type, - enable_fused_layernorm=enable_fused_layernorm) - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - - def construct(self, input_tensor, attention_mask): - input_tensor = self.reshape(input_tensor, self.shape) - attention_output = self.attention(input_tensor, input_tensor, attention_mask) - output = self.output(attention_output, input_tensor) - return output - - -class BertEncoderCell(nn.Cell): - """ - Encoder cells used in BertTransformer. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the bert encoder layers. Default: 768. - seq_length (int): Length of input sequence. Default: 512. - num_attention_heads (int): Number of attention heads. Default: 12. - intermediate_size (int): Size of intermediate layer. Default: 3072. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.02. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - hidden_act (str): Activation function. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. - """ - def __init__(self, - batch_size, - hidden_size=768, - seq_length=512, - num_attention_heads=12, - intermediate_size=3072, - attention_probs_dropout_prob=0.02, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32, - enable_fused_layernorm=False): - super(BertEncoderCell, self).__init__() - self.attention = BertSelfAttention( - batch_size=batch_size, - hidden_size=hidden_size, - seq_length=seq_length, - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - use_relative_positions=use_relative_positions, - compute_type=compute_type, - enable_fused_layernorm=enable_fused_layernorm) - self.intermediate = nn.Dense(in_channels=hidden_size, - out_channels=intermediate_size, - activation=hidden_act, - weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) - self.output = BertOutput(in_channels=intermediate_size, - out_channels=hidden_size, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - compute_type=compute_type, - enable_fused_layernorm=enable_fused_layernorm) - - def construct(self, hidden_states, attention_mask): - # self-attention - attention_output = self.attention(hidden_states, attention_mask) - # feed construct - intermediate_output = self.intermediate(attention_output) - # add and normalize - output = self.output(intermediate_output, attention_output) - return output - - -class BertTransformer(nn.Cell): - """ - Multi-layer bert transformer. - - Args: - batch_size (int): Batch size of input dataset. - hidden_size (int): Size of the encoder layers. - seq_length (int): Length of input sequence. - num_hidden_layers (int): Number of hidden layers in encoder cells. - num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. - intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - hidden_act (str): Activation function used in the encoder cells. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - return_all_encoders (bool): Specifies whether to return all encoders. Default: False. - """ - def __init__(self, - batch_size, - hidden_size, - seq_length, - num_hidden_layers, - num_attention_heads=12, - intermediate_size=3072, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32, - return_all_encoders=False, - enable_fused_layernorm=False): - super(BertTransformer, self).__init__() - self.return_all_encoders = return_all_encoders - - layers = [] - for _ in range(num_hidden_layers): - layer = BertEncoderCell(batch_size=batch_size, - hidden_size=hidden_size, - seq_length=seq_length, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - use_relative_positions=use_relative_positions, - hidden_act=hidden_act, - compute_type=compute_type, - enable_fused_layernorm=enable_fused_layernorm) - layers.append(layer) - - self.layers = nn.CellList(layers) - - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - self.out_shape = (batch_size, seq_length, hidden_size) - - def construct(self, input_tensor, attention_mask): - prev_output = self.reshape(input_tensor, self.shape) - - all_encoder_layers = () - for layer_module in self.layers: - layer_output = layer_module(prev_output, attention_mask) - prev_output = layer_output - - if self.return_all_encoders: - layer_output = self.reshape(layer_output, self.out_shape) - all_encoder_layers = all_encoder_layers + (layer_output,) - - if not self.return_all_encoders: - prev_output = self.reshape(prev_output, self.out_shape) - all_encoder_layers = all_encoder_layers + (prev_output,) - return all_encoder_layers - - -class CreateAttentionMaskFromInputMask(nn.Cell): - """ - Create attention mask according to input mask. - - Args: - config (Class): Configuration for BertModel. - """ - def __init__(self, config): - super(CreateAttentionMaskFromInputMask, self).__init__() - self.input_mask_from_dataset = config.input_mask_from_dataset - self.input_mask = None - - if not self.input_mask_from_dataset: - self.input_mask = initializer( - "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() - - self.cast = P.Cast() - self.reshape = P.Reshape() - self.shape = (config.batch_size, 1, config.seq_length) - self.broadcast_ones = initializer( - "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() - self.batch_matmul = P.BatchMatMul() - - def construct(self, input_mask): - if not self.input_mask_from_dataset: - input_mask = self.input_mask - - attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) - return attention_mask - - -class BertModel(nn.Cell): - """ - Bidirectional Encoder Representations from Transformers. - - Args: - config (Class): Configuration for BertModel. - is_training (bool): True for training mode. False for eval mode. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - """ - def __init__(self, - config, - is_training, - use_one_hot_embeddings=False): - super(BertModel, self).__init__() - config = copy.deepcopy(config) - if not is_training: - config.hidden_dropout_prob = 0.0 - config.attention_probs_dropout_prob = 0.0 - - self.input_mask_from_dataset = config.input_mask_from_dataset - self.token_type_ids_from_dataset = config.token_type_ids_from_dataset - self.batch_size = config.batch_size - self.seq_length = config.seq_length - self.hidden_size = config.hidden_size - self.num_hidden_layers = config.num_hidden_layers - self.embedding_size = config.hidden_size - self.token_type_ids = None - - self.last_idx = self.num_hidden_layers - 1 - output_embedding_shape = [self.batch_size, self.seq_length, - self.embedding_size] - - if not self.token_type_ids_from_dataset: - self.token_type_ids = initializer( - "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() - - self.bert_embedding_lookup = EmbeddingLookup( - vocab_size=config.vocab_size, - embedding_size=self.embedding_size, - embedding_shape=output_embedding_shape, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range) - - self.bert_embedding_postprocessor = EmbeddingPostprocessor( - embedding_size=self.embedding_size, - embedding_shape=output_embedding_shape, - use_relative_positions=config.use_relative_positions, - use_token_type=True, - token_type_vocab_size=config.type_vocab_size, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=0.02, - max_position_embeddings=config.max_position_embeddings, - dropout_prob=config.hidden_dropout_prob) - - self.bert_encoder = BertTransformer( - batch_size=self.batch_size, - hidden_size=self.hidden_size, - seq_length=self.seq_length, - num_attention_heads=config.num_attention_heads, - num_hidden_layers=self.num_hidden_layers, - intermediate_size=config.intermediate_size, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - use_relative_positions=config.use_relative_positions, - hidden_act=config.hidden_act, - compute_type=config.compute_type, - return_all_encoders=True, - enable_fused_layernorm=config.enable_fused_layernorm) - - self.cast = P.Cast() - self.dtype = config.dtype - self.cast_compute_type = SaturateCast(dst_type=config.compute_type) - self.slice = P.StridedSlice() - - self.squeeze_1 = P.Squeeze(axis=1) - self.dense = nn.Dense(self.hidden_size, self.hidden_size, - activation="tanh", - weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) - self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) - - def construct(self, input_ids, token_type_ids, input_mask): - - # embedding - if not self.token_type_ids_from_dataset: - token_type_ids = self.token_type_ids - word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) - embedding_output = self.bert_embedding_postprocessor(token_type_ids, - word_embeddings) - - # attention mask [batch_size, seq_length, seq_length] - attention_mask = self._create_attention_mask_from_input_mask(input_mask) - - # bert encoder - encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), - attention_mask) - - sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) - - # pooler - sequence_slice = self.slice(sequence_output, - (0, 0, 0), - (self.batch_size, 1, self.hidden_size), - (1, 1, 1)) - first_token = self.squeeze_1(sequence_slice) - pooled_output = self.dense(first_token) - pooled_output = self.cast(pooled_output, self.dtype) - - return sequence_output, pooled_output, embedding_tables diff --git a/model_zoo/bert/src/clue_classification_dataset_process.py b/model_zoo/bert/src/clue_classification_dataset_process.py deleted file mode 100755 index 1e27fe0352..0000000000 --- a/model_zoo/bert/src/clue_classification_dataset_process.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -sample script of processing CLUE classification dataset using mindspore.dataset.text for fine-tuning bert -""" - -import os -import numpy as np - -import mindspore.common.dtype as mstype -import mindspore.dataset as ds -import mindspore.dataset.text as text -import mindspore.dataset.transforms.c_transforms as ops - - -def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, - data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64): - """Process TNEWS dataset""" - ### Loading TNEWS from CLUEDataset - assert data_usage in ['train', 'eval', 'test'] - if data_usage == 'train': - dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='TNEWS', - usage=data_usage, shuffle=shuffle_dataset) - elif data_usage == 'eval': - dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='TNEWS', - usage=data_usage, shuffle=shuffle_dataset) - else: - dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='TNEWS', - usage=data_usage, shuffle=shuffle_dataset) - ### Processing label - if data_usage == 'test': - dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"], - columns_order=["id", "label_id", "sentence"], operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0)) - else: - label_vocab = text.Vocab.from_list(label_list) - label_lookup = text.Lookup(label_vocab) - dataset = dataset.map(input_columns="label_desc", output_columns="label_id", operations=label_lookup) - ### Processing sentence - vocab = text.Vocab.from_file(bert_vocab_path) - tokenizer = text.BertTokenizer(vocab, lower_case=True) - lookup = text.Lookup(vocab, unknown_token='[UNK]') - dataset = dataset.map(input_columns=["sentence"], operations=tokenizer) - dataset = dataset.map(input_columns=["sentence"], operations=ops.Slice(slice(0, max_seq_len))) - dataset = dataset.map(input_columns=["sentence"], - operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), - append=np.array(["[SEP]"], dtype='S'))) - dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup) - dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) - dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], - columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) - dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"], - columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0)) - dataset = dataset.batch(batch_size) - label = [] - text_ids = [] - mask_ids = [] - segment_ids = [] - for data in dataset: - label.append(data[0]) - text_ids.append(data[1]) - mask_ids.append(data[2]) - segment_ids.append(data[3]) - return label, text_ids, mask_ids, segment_ids - - -def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, - data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64): - """Process CMNLI dataset""" - ### Loading CMNLI from CLUEDataset - assert data_usage in ['train', 'eval', 'test'] - if data_usage == 'train': - dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='CMNLI', - usage=data_usage, shuffle=shuffle_dataset) - elif data_usage == 'eval': - dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='CMNLI', - usage=data_usage, shuffle=shuffle_dataset) - else: - dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='CMNLI', - usage=data_usage, shuffle=shuffle_dataset) - ### Processing label - if data_usage == 'test': - dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"], - columns_order=["id", "label_id", "sentence1", "sentence2"], operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0)) - else: - label_vocab = text.Vocab.from_list(label_list) - label_lookup = text.Lookup(label_vocab) - dataset = dataset.map(input_columns="label", output_columns="label_id", operations=label_lookup) - ### Processing sentence pairs - vocab = text.Vocab.from_file(bert_vocab_path) - tokenizer = text.BertTokenizer(vocab, lower_case=True) - lookup = text.Lookup(vocab, unknown_token='[UNK]') - ### Tokenizing sentences and truncate sequence pair - dataset = dataset.map(input_columns=["sentence1"], operations=tokenizer) - dataset = dataset.map(input_columns=["sentence2"], operations=tokenizer) - dataset = dataset.map(input_columns=["sentence1", "sentence2"], - operations=text.TruncateSequencePair(max_seq_len-3)) - ### Adding special tokens - dataset = dataset.map(input_columns=["sentence1"], - operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), - append=np.array(["[SEP]"], dtype='S'))) - dataset = dataset.map(input_columns=["sentence2"], - operations=ops.Concatenate(append=np.array(["[SEP]"], dtype='S'))) - ### Generating segment_ids - dataset = dataset.map(input_columns=["sentence1"], output_columns=["sentence1", "type_sentence1"], - columns_order=["sentence1", "type_sentence1", "sentence2", "label_id"], - operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["sentence2"], output_columns=["sentence2", "type_sentence2"], - columns_order=["sentence1", "type_sentence1", "sentence2", "type_sentence2", "label_id"], - operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["type_sentence1"], operations=[lookup, ops.Fill(0)]) - dataset = dataset.map(input_columns=["type_sentence2"], operations=[lookup, ops.Fill(1)]) - dataset = dataset.map(input_columns=["type_sentence1", "type_sentence2"], output_columns=["segment_ids"], - columns_order=["sentence1", "sentence2", "segment_ids", "label_id"], - operations=ops.Concatenate()) - dataset = dataset.map(input_columns=["segment_ids"], operations=ops.PadEnd([max_seq_len], 0)) - ### Generating text_ids - dataset = dataset.map(input_columns=["sentence1", "sentence2"], output_columns=["text_ids"], - columns_order=["text_ids", "segment_ids", "label_id"], - operations=ops.Concatenate()) - dataset = dataset.map(input_columns=["text_ids"], operations=lookup) - dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) - ### Generating mask_ids - dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], - columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate()) - dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) - dataset = dataset.batch(batch_size) - label = [] - text_ids = [] - mask_ids = [] - segment_ids = [] - for data in dataset: - label.append(data[0]) - text_ids.append(data[1]) - mask_ids.append(data[2]) - segment_ids.append(data[3]) - return label, text_ids, mask_ids, segment_ids diff --git a/model_zoo/bert/src/config.py b/model_zoo/bert/src/config.py deleted file mode 100644 index 812f0c2f18..0000000000 --- a/model_zoo/bert/src/config.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in dataset.py, run_pretrain.py -""" -from easydict import EasyDict as edict -import mindspore.common.dtype as mstype -from .bert_model import BertConfig -cfg = edict({ - 'bert_network': 'base', - 'loss_scale_value': 65536, - 'scale_factor': 2, - 'scale_window': 1000, - 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ - 'learning_rate': 3e-5, - 'end_learning_rate': 1e-10, - 'power': 5.0, - 'weight_decay': 1e-5, - 'eps': 1e-6, - 'warmup_steps': 10000, - }), - 'Lamb': edict({ - 'start_learning_rate': 3e-5, - 'end_learning_rate': 1e-10, - 'power': 10.0, - 'warmup_steps': 10000, - 'weight_decay': 0.01, - 'eps': 1e-6, - }), - 'Momentum': edict({ - 'learning_rate': 2e-5, - 'momentum': 0.9, - }), -}) - -''' -Including two kinds of network: \ -base: Goole BERT-base(the base version of BERT model). -large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ - Functional Relative Posetional Encoding as an effective positional encoding scheme). -''' -if cfg.bert_network == 'base': - bert_net_cfg = BertConfig( - batch_size=32, - seq_length=128, - vocab_size=21128, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float16 - ) -if cfg.bert_network == 'nezha': - bert_net_cfg = BertConfig( - batch_size=32, - seq_length=128, - vocab_size=21128, - hidden_size=1024, - num_hidden_layers=24, - num_attention_heads=16, - intermediate_size=4096, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - use_relative_positions=True, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float16 - ) -if cfg.bert_network == 'large': - bert_net_cfg = BertConfig( - batch_size=16, - seq_length=512, - vocab_size=30522, - hidden_size=1024, - num_hidden_layers=24, - num_attention_heads=16, - intermediate_size=4096, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float16, - enable_fused_layernorm=True - ) diff --git a/model_zoo/bert/src/dataset.py b/model_zoo/bert/src/dataset.py deleted file mode 100644 index e530718d4f..0000000000 --- a/model_zoo/bert/src/dataset.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Data operations, will be used in run_pretrain.py -""" -import os -import mindspore.common.dtype as mstype -import mindspore.dataset.engine.datasets as de -import mindspore.dataset.transforms.c_transforms as C -from mindspore import log as logger -from .config import bert_net_cfg - - -def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true", - data_sink_steps=1, data_dir=None, schema_dir=None): - """create train dataset""" - # apply repeat operations - repeat_count = epoch_size - files = os.listdir(data_dir) - data_files = [] - for file_name in files: - if "tfrecord" in file_name: - data_files.append(os.path.join(data_dir, file_name)) - ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, - columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", - "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], - shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, - num_shards=device_num, shard_id=rank, shard_equal_rows=True) - ori_dataset_size = ds.get_dataset_size() - print('origin dataset size: ', ori_dataset_size) - new_size = ori_dataset_size - if enable_data_sink == "true": - new_size = data_sink_steps * bert_net_cfg.batch_size - ds.set_dataset_size(new_size) - new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size()) - type_cast_op = C.TypeCast(mstype.int32) - ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) - ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) - ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - # apply batch operations - ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) - ds = ds.repeat(max(new_repeat_count, repeat_count)) - logger.info("data size: {}".format(ds.get_dataset_size())) - logger.info("repeatcount: {}".format(ds.get_repeat_count())) - return ds, new_repeat_count - - -def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", - data_file_path=None, schema_file_path=None): - """create finetune or evaluation dataset""" - type_cast_op = C.TypeCast(mstype.int32) - ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, - columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) - if assessment_method == "Spearman_correlation": - type_cast_op_float = C.TypeCast(mstype.float32) - ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) - else: - ds = ds.map(input_columns="label_ids", operations=type_cast_op) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.repeat(repeat_count) - # apply shuffle operation - buffer_size = 960 - ds = ds.shuffle(buffer_size=buffer_size) - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - return ds - - -def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", - data_file_path=None, schema_file_path=None): - """create finetune or evaluation dataset""" - type_cast_op = C.TypeCast(mstype.int32) - ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, - columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) - if assessment_method == "Spearman_correlation": - type_cast_op_float = C.TypeCast(mstype.float32) - ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) - else: - ds = ds.map(input_columns="label_ids", operations=type_cast_op) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.repeat(repeat_count) - # apply shuffle operation - buffer_size = 960 - ds = ds.shuffle(buffer_size=buffer_size) - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - return ds - - -def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True): - """create finetune or evaluation dataset""" - type_cast_op = C.TypeCast(mstype.int32) - if is_training: - ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, - columns_list=["input_ids", "input_mask", "segment_ids", - "start_positions", "end_positions", - "unique_ids", "is_impossible"]) - ds = ds.map(input_columns="start_positions", operations=type_cast_op) - ds = ds.map(input_columns="end_positions", operations=type_cast_op) - else: - ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, - columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"]) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.repeat(repeat_count) - # apply shuffle operation - buffer_size = 960 - ds = ds.shuffle(buffer_size=buffer_size) - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - return ds diff --git a/model_zoo/bert/src/finetune_eval_config.py b/model_zoo/bert/src/finetune_eval_config.py deleted file mode 100644 index 4b8e121e09..0000000000 --- a/model_zoo/bert/src/finetune_eval_config.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -config settings, will be used in finetune.py -""" - -from easydict import EasyDict as edict -import mindspore.common.dtype as mstype -from .bert_model import BertConfig - -optimizer_cfg = edict({ - 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ - 'learning_rate': 2e-5, - 'end_learning_rate': 1e-7, - 'power': 1.0, - 'weight_decay': 1e-5, - 'eps': 1e-6, - }), - 'Lamb': edict({ - 'start_learning_rate': 2e-5, - 'end_learning_rate': 1e-7, - 'power': 1.0, - 'weight_decay': 0.01, - 'decay_filter': lambda x: False, - }), - 'Momentum': edict({ - 'learning_rate': 2e-5, - 'momentum': 0.9, - }), -}) - -bert_net_cfg = BertConfig( - batch_size=16, - seq_length=128, - vocab_size=21128, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float16, -) diff --git a/model_zoo/bert/src/utils.py b/model_zoo/bert/src/utils.py deleted file mode 100644 index dfb6ffa5fe..0000000000 --- a/model_zoo/bert/src/utils.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -Functional Cells used in Bert finetune and evaluation. -""" - -import os -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.tensor import Tensor -from mindspore.common import dtype as mstype -from mindspore.train.callback import Callback - - -class CrossEntropyCalculation(nn.Cell): - """ - Cross Entropy loss - """ - def __init__(self, is_training=True): - super(CrossEntropyCalculation, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.reshape = P.Reshape() - self.last_idx = (-1,) - self.neg = P.Neg() - self.cast = P.Cast() - self.is_training = is_training - - def construct(self, logits, label_ids, num_labels): - if self.is_training: - label_ids = self.reshape(label_ids, self.last_idx) - one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value) - per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) - loss = self.reduce_mean(per_example_loss, self.last_idx) - return_value = self.cast(loss, mstype.float32) - else: - return_value = logits * 1.0 - return return_value - - -def make_directory(path: str): - """Make directory.""" - if path is None or not isinstance(path, str) or path.strip() == "": - logger.error("The path(%r) is invalid type.", path) - raise TypeError("Input path is invaild type") - - # convert the relative paths - path = os.path.realpath(path) - logger.debug("The abs path is %r", path) - - # check the path is exist and write permissions? - if os.path.exists(path): - real_path = path - else: - # All exceptions need to be caught because create directory maybe have some limit(permissions) - logger.debug("The directory(%s) doesn't exist, will create it", path) - try: - os.makedirs(path, exist_ok=True) - real_path = path - except PermissionError as e: - logger.error("No write permission on the directory(%r), error = %r", path, e) - raise TypeError("No write permission on the directory.") - return real_path - -class LossCallBack(Callback): - """ - Monitor the loss in training. - If the loss in NAN or INF terminating training. - Note: - if per_print_times is 0 do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0") - self._per_print_times = per_print_times - def step_end(self, run_context): - cb_params = run_context.original_args() - print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) - -def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): - """ - Find the ckpt finetune generated and load it into eval network. - """ - files = os.listdir(load_finetune_checkpoint_dir) - pre_len = len(prefix) - max_num = 0 - for filename in files: - name_ext = os.path.splitext(filename) - if name_ext[-1] != ".ckpt": - continue - #steps_per_epoch = ds.get_dataset_size() - if filename.find(prefix) == 0 and not filename[pre_len].isalpha(): - index = filename[pre_len:].find("-") - if index == 0 and max_num == 0: - load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) - elif index not in (0, -1): - name_split = name_ext[-2].split('_') - if (steps_per_epoch != int(name_split[len(name_split)-1])) \ - or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])): - continue - num = filename[pre_len + 1:pre_len + index] - if int(num) > max_num: - max_num = int(num) - load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) - return load_finetune_checkpoint_path diff --git a/model_zoo/community/README.md b/model_zoo/community/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/deepfm/README.md b/model_zoo/deepfm/README.md deleted file mode 100644 index 47809a54c0..0000000000 --- a/model_zoo/deepfm/README.md +++ /dev/null @@ -1,132 +0,0 @@ -# DeepFM Description - -This is an example of training DeepFM with Criteo dataset in MindSpore. - -[Paper](https://arxiv.org/pdf/1703.04247.pdf) Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He - - -# Model architecture - -The overall network architecture of DeepFM is show below: - -[Link](https://arxiv.org/pdf/1703.04247.pdf) - - -# Requirements -- Install [MindSpore](https://www.mindspore.cn/install/en). -- Download the criteo dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path. -- For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - -# Script description - -## Script and sample code - -```python -├── deepfm - ├── README.md - ├── scripts - │ ├──run_train.sh - │ ├──run_eval.sh - ├── src - │ ├──config.py - │ ├──dataset.py - │ ├──callback.py - │ ├──deepfm.py - ├── train.py - ├── eval.py -``` - -## Training process - -### Usage - -- sh run_train.sh [DEVICE_NUM] [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PAHT] -- python train.py --dataset_path [DATASET_PATH] - -### Launch - -``` -# distribute training example - sh scripts/run_distribute_train.sh 8 /opt/dataset/criteo /opt/mindspore_hccl_file.json -# standalone training example - sh scripts/run_standalone_train.sh 0 /opt/dataset/criteo - or - python train.py --dataset_path /opt/dataset/criteo > output.log 2>&1 & -``` - -### Result - -Training result will be stored in the example path. -Checkpoints will be stored at `./checkpoint` by default, -and training log will be redirected to `./output.log` by default, -and loss log will be redirected to `./loss.log` by default, -and eval log will be redirected to `./auc.log` by default. - - -## Eval process - -### Usage - -- sh run_eval.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH] - -### Launch - -``` -# infer example - sh scripts/run_eval.sh 0 ~/criteo/eval/ ~/train/deepfm-15_41257.ckpt -``` - -> checkpoint can be produced in training process. - -### Result - -Inference result will be stored in the example path, you can find result like the followings in `auc.log`. - -``` -2020-05-27 20:51:35 AUC: 0.80577889065281, eval time: 35.55999s. -``` - -# Model description - -## Performance - -### Training Performance - -| Parameters | DeepFM | -| -------------------------- | ------------------------------------------------------| -| Model Version | | -| Resource | Ascend 910, cpu:2.60GHz 96cores, memory:1.5T | -| uploaded Date | 05/27/2020 | -| MindSpore Version | 0.2.0 | -| Dataset | Criteo | -| Training Parameters | src/config.py | -| Optimizer | Adam | -| Loss Function | SoftmaxCrossEntropyWithLogits | -| outputs | | -| Loss | 0.4234 | -| Accuracy | AUC[0.8055] | -| Total time | 91 min | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | - -#### Inference Performance - -| Parameters | | | -| -------------------------- | ----------------------------- | ------------------------- | -| Model Version | | | -| Resource | Ascend 910 | Ascend 310 | -| uploaded Date | 05/27/2020 | 05/27/2020 | -| MindSpore Version | 0.2.0 | 0.2.0 | -| Dataset | Criteo | | -| batch_size | 1000 | | -| outputs | | | -| Accuracy | AUC[0.8055] | | -| Speed | | | -| Total time | 35.559s | | -| Model for inference | | | - -# ModelZoo Homepage - [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/deepfm/scripts/run_distribute_train.sh b/model_zoo/deepfm/scripts/run_distribute_train.sh deleted file mode 100644 index fb2c3db17c..0000000000 --- a/model_zoo/deepfm/scripts/run_distribute_train.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -echo "Please run the script as: " -echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PAHT" -echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json" -echo "After running the script, the network runs in the background, The log will be generated in logx/output.log" - - -export RANK_SIZE=$1 -DATA_URL=$2 -export MINDSPORE_HCCL_CONFIG_PATH=$3 - -for ((i=0; i env.log - python -u train.py \ - --dataset_path=$DATA_URL \ - --ckpt_path="checkpoint" \ - --eval_file_name='auc.log' \ - --loss_file_name='loss.log' \ - --do_eval=True > output.log 2>&1 & - cd ../ -done diff --git a/model_zoo/deepfm/train.py b/model_zoo/deepfm/train.py deleted file mode 100644 index 228d04c0d3..0000000000 --- a/model_zoo/deepfm/train.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train_criteo.""" -import os -import sys -import argparse - -from mindspore import context, ParallelMode -from mindspore.communication.management import init -from mindspore.train.model import Model -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor - -from src.deepfm import ModelBuilder, AUCMetric -from src.config import DataConfig, ModelConfig, TrainConfig -from src.dataset import create_dataset, DataType -from src.callback import EvalCallBack, LossCallBack - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -parser = argparse.ArgumentParser(description='CTR Prediction') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path') -parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path') -parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path') -parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.') - -args_opt, _ = parser.parse_known_args() -device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) - - -if __name__ == '__main__': - data_config = DataConfig() - model_config = ModelConfig() - train_config = TrainConfig() - - rank_size = int(os.environ.get("RANK_SIZE", 1)) - if rank_size > 1: - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) - init() - rank_id = int(os.environ.get('RANK_ID')) - else: - rank_size = None - rank_id = None - - ds_train = create_dataset(args_opt.dataset_path, - train_mode=True, - epochs=train_config.train_epochs, - batch_size=train_config.batch_size, - data_type=DataType(data_config.data_format), - rank_size=rank_size, - rank_id=rank_id) - - model_builder = ModelBuilder(ModelConfig, TrainConfig) - train_net, eval_net = model_builder.get_train_eval_net() - auc_metric = AUCMetric() - model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) - - time_callback = TimeMonitor(data_size=ds_train.get_dataset_size()) - loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name) - callback_list = [time_callback, loss_callback] - - if train_config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps, - keep_checkpoint_max=train_config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix, - directory=args_opt.ckpt_path, - config=config_ck) - callback_list.append(ckpt_cb) - - if args_opt.do_eval: - ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, - epochs=train_config.train_epochs, - batch_size=train_config.batch_size, - data_type=DataType(data_config.data_format)) - eval_callback = EvalCallBack(model, ds_eval, auc_metric, - eval_file_path=args_opt.eval_file_name) - callback_list.append(eval_callback) - model.train(train_config.train_epochs, ds_train, callbacks=callback_list) diff --git a/model_zoo/deeplabv3/README.md b/model_zoo/deeplabv3/README.md deleted file mode 100644 index c8df3dab8d..0000000000 --- a/model_zoo/deeplabv3/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Deeplab-V3 Example - -## Description -This is an example of training DeepLabv3 with PASCAL VOC 2012 dataset in MindSpore. - -## Requirements -- Install [MindSpore](https://www.mindspore.cn/install/en). -- Download the VOC 2012 dataset for training. - -> Notes: - If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file. - - -## Running the Example -### Training -- Set options in config.py. -- Run `run_standalone_train.sh` for non-distributed training. - ``` bash - sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH - ``` -- Run `run_distribute_train.sh` for distributed training. - ``` bash - sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH - ``` -### Evaluation -Set options in evaluation_config.py. Make sure the 'data_file' and 'finetune_ckpt' are set to your own path. -- Run run_eval.sh for evaluation. - ``` bash - sh scripts/run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH - ``` - -## Options and Parameters -It contains of parameters of Deeplab-V3 model and options for training, which is set in file config.py. - -### Options: -``` -config.py: - learning_rate Learning rate, default is 0.0014. - weight_decay Weight decay, default is 5e-5. - momentum Momentum, default is 0.97. - crop_size Image crop size [height, width] during training, default is 513. - eval_scales The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]. - output_stride The ratio of input to output spatial resolution, default is 16. - ignore_label Ignore label value, default is 255. - seg_num_classes Number of semantic classes, including the background class (if exists). - foreground classes + 1 background class in the PASCAL VOC 2012 dataset, default is 21. - fine_tune_batch_norm Fine tune the batch norm parameters or not, default is False. - atrous_rates Atrous rates for atrous spatial pyramid pooling, default is None. - decoder_output_stride The ratio of input to output spatial resolution when employing decoder - to refine segmentation results, default is None. - image_pyramid Input scales for multi-scale feature extraction, default is None. - epoch_size Epoch size, default is 6. - batch_size batch size of input dataset: N, default is 2. - enable_save_ckpt Enable save checkpoint, default is true. - save_checkpoint_steps Save checkpoint steps, default is 1000. - save_checkpoint_num Save checkpoint numbers, default is 1. -``` - - -### Parameters: -``` -Parameters for dataset and network: - distribute Run distribute, default is false. - data_url Train/Evaluation data url, required. - checkpoint_url Checkpoint path, default is None. -``` \ No newline at end of file diff --git a/model_zoo/deeplabv3/scripts/run_distribute_train.sh b/model_zoo/deeplabv3/scripts/run_distribute_train.sh deleted file mode 100644 index 4dcd8d9768..0000000000 --- a/model_zoo/deeplabv3/scripts/run_distribute_train.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "bash run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH" -echo "for example: bash run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH [PRETRAINED_CKPT_PATH](option)" -echo "It is better to use absolute path." -echo "==============================================================================================================" - -DATA_DIR=$2 - -export MINDSPORE_HCCL_CONFIG_PATH=$1 -export RANK_TABLE_FILE=$1 -export RANK_SIZE=8 -PATH_CHECKPOINT="" -if [ $# == 3 ] -then - PATH_CHECKPOINT=$3 -fi -cores=`cat /proc/cpuinfo|grep "processor" |wc -l` -echo "the number of logical core" $cores -avg_core_per_rank=`expr $cores \/ $RANK_SIZE` -core_gap=`expr $avg_core_per_rank \- 1` -echo "avg_core_per_rank" $avg_core_per_rank -echo "core_gap" $core_gap -for((i=0;i env.log - taskset -c $cmdopt python ../train.py \ - --distribute="true" \ - --device_id=$DEVICE_ID \ - --checkpoint_url=$PATH_CHECKPOINT \ - --data_url=$DATA_DIR > log.txt 2>&1 & - cd ../ -done \ No newline at end of file diff --git a/model_zoo/deeplabv3/src/deeplabv3.py b/model_zoo/deeplabv3/src/deeplabv3.py deleted file mode 100644 index 03bb03ad14..0000000000 --- a/model_zoo/deeplabv3/src/deeplabv3.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""DeepLabv3.""" - -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.ops.composite import add_flags -from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \ - DepthwiseConv2dNative, SpaceToBatch, BatchToSpace - - -class ASPPSampleBlock(nn.Cell): - """ASPP sample block.""" - def __init__(self, feature_shape, scale_size, output_stride): - super(ASPPSampleBlock, self).__init__() - sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1 - sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1 - self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) - - def construct(self, x): - return self.sample(x) - - -class ASPP(nn.Cell): - """ - ASPP model for DeepLabv3. - - Args: - channel (int): Input channel. - depth (int): Output channel. - feature_shape (list): The shape of feature,[h,w]. - scale_sizes (list): Input scales for multi-scale feature extraction. - atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. - output_stride (int): 'The ratio of input to output spatial resolution.' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - - Returns: - Tensor, output tensor. - - Examples: - >>> ASPP(channel=2048,256,[14,14],[1],[6],16) - """ - def __init__(self, channel, depth, feature_shape, scale_sizes, - atrous_rates, output_stride, fine_tune_batch_norm=False): - super(ASPP, self).__init__() - self.aspp0 = _conv_bn_relu(channel, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.atrous_rates = [] - if atrous_rates is not None: - self.atrous_rates = atrous_rates - self.aspp_pointwise = _conv_bn_relu(channel, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel, - channel_multiplier=1, - kernel_size=3, - stride=1, - dilation=1, - pad_mode="valid") - self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm) - self.aspp_depth_relu = nn.ReLU() - self.aspp_depths = [] - self.aspp_depth_spacetobatchs = [] - self.aspp_depth_batchtospaces = [] - - for scale_size in scale_sizes: - aspp_scale_depth_size = np.ceil((feature_shape[0]*scale_size)/16) - if atrous_rates is None: - break - for rate in atrous_rates: - padding = 0 - for j in range(100): - padded_size = rate * j - if padded_size >= aspp_scale_depth_size + 2 * rate: - padding = padded_size - aspp_scale_depth_size - 2 * rate - break - paddings = [[rate, rate + int(padding)], - [rate, rate + int(padding)]] - self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings) - self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch) - crops = [[0, int(padding)], [0, int(padding)]] - self.aspp_depth_batchtospace = BatchToSpace(rate, crops) - self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace) - self.aspp_depths = nn.CellList(self.aspp_depths) - self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs) - self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces) - - self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1]))) - self.global_poolings = [] - for scale_size in scale_sizes: - pooling_h = np.ceil((feature_shape[0]*scale_size)/output_stride) - pooling_w = np.ceil((feature_shape[0]*scale_size)/output_stride) - self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w)))) - self.global_poolings = nn.CellList(self.global_poolings) - self.conv_bn = _conv_bn_relu(channel, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.samples = [] - for scale_size in scale_sizes: - self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride)) - self.samples = nn.CellList(self.samples) - self.feature_shape = feature_shape - self.concat = P.Concat(axis=1) - - @add_flags(loop_can_unroll=True) - def construct(self, x, scale_index=0): - aspp0 = self.aspp0(x) - aspp1 = self.global_poolings[scale_index](x) - aspp1 = self.conv_bn(aspp1) - aspp1 = self.samples[scale_index](aspp1) - output = self.concat((aspp1, aspp0)) - - for i in range(len(self.atrous_rates)): - aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x) - aspp_i = self.aspp_depth_depthwiseconv(aspp_i) - aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i) - aspp_i = self.aspp_depth_bn(aspp_i) - aspp_i = self.aspp_depth_relu(aspp_i) - aspp_i = self.aspp_pointwise(aspp_i) - output = self.concat((output, aspp_i)) - return output - - -class DecoderSampleBlock(nn.Cell): - """Decoder sample block.""" - def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4): - super(DecoderSampleBlock, self).__init__() - sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1 - sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1 - self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) - - def construct(self, x): - return self.sample(x) - - -class Decoder(nn.Cell): - """ - Decode module for DeepLabv3. - Args: - low_level_channel (int): Low level input channel - channel (int): Input channel. - depth (int): Output channel. - feature_shape (list): 'Input image shape, [N,C,H,W].' - scale_sizes (list): 'Input scales for multi-scale feature extraction.' - decoder_output_stride (int): 'The ratio of input to output spatial resolution' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - Returns: - Tensor, output tensor. - Examples: - >>> Decoder(256, 100, [56,56]) - """ - def __init__(self, - low_level_channel, - channel, - depth, - feature_shape, - scale_sizes, - decoder_output_stride, - fine_tune_batch_norm): - super(Decoder, self).__init__() - self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1, - pad_mode="same", use_batch_statistics=fine_tune_batch_norm) - self.decoder_depth0 = _deep_conv_bn_relu(channel + 48, - channel_multiplier=1, - ksize=3, - stride=1, - pad_mode="same", - dilation=1, - use_batch_statistics=fine_tune_batch_norm) - self.decoder_pointwise0 = _conv_bn_relu(channel + 48, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.decoder_depth1 = _deep_conv_bn_relu(depth, - channel_multiplier=1, - ksize=3, - stride=1, - pad_mode="same", - dilation=1, - use_batch_statistics=fine_tune_batch_norm) - self.decoder_pointwise1 = _conv_bn_relu(depth, - depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.depth = depth - self.concat = P.Concat(axis=1) - self.samples = [] - for scale_size in scale_sizes: - self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride)) - self.samples = nn.CellList(self.samples) - - def construct(self, x, low_level_feature, scale_index): - low_level_feature = self.feature_projection(low_level_feature) - low_level_feature = self.samples[scale_index](low_level_feature) - x = self.samples[scale_index](x) - output = self.concat((x, low_level_feature)) - output = self.decoder_depth0(output) - output = self.decoder_pointwise0(output) - output = self.decoder_depth1(output) - output = self.decoder_pointwise1(output) - return output - - -class SingleDeepLabV3(nn.Cell): - """ - DeepLabv3 Network. - Args: - num_classes (int): Class number. - feature_shape (list): Input image shape, [N,C,H,W]. - backbone (Cell): Backbone Network. - channel (int): Resnet output channel. - depth (int): ASPP block depth. - scale_sizes (list): Input scales for multi-scale feature extraction. - atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. - decoder_output_stride (int): 'The ratio of input to output spatial resolution' - output_stride (int): 'The ratio of input to output spatial resolution.' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - Returns: - Tensor, output tensor. - Examples: - >>> SingleDeepLabV3(num_classes=10, - >>> feature_shape=[1,3,224,224], - >>> backbone=resnet50_dl(), - >>> channel=2048, - >>> depth=256) - >>> scale_sizes=[1.0]) - >>> atrous_rates=[6]) - >>> decoder_output_stride=4) - >>> output_stride=16) - """ - - def __init__(self, - num_classes, - feature_shape, - backbone, - channel, - depth, - scale_sizes, - atrous_rates, - decoder_output_stride, - output_stride, - fine_tune_batch_norm=False): - super(SingleDeepLabV3, self).__init__() - self.num_classes = num_classes - self.channel = channel - self.depth = depth - self.scale_sizes = [] - for scale_size in np.sort(scale_sizes): - self.scale_sizes.append(scale_size) - self.net = backbone - self.aspp = ASPP(channel=self.channel, - depth=self.depth, - feature_shape=[feature_shape[2], - feature_shape[3]], - scale_sizes=self.scale_sizes, - atrous_rates=atrous_rates, - output_stride=output_stride, - fine_tune_batch_norm=fine_tune_batch_norm) - - atrous_rates_len = 0 - if atrous_rates is not None: - atrous_rates_len = len(atrous_rates) - self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth, - ksize=1, - stride=1, - use_batch_statistics=fine_tune_batch_norm) - self.fc2 = nn.Conv2d(depth, - num_classes, - kernel_size=1, - stride=1, - has_bias=True) - self.upsample = P.ResizeBilinear((int(feature_shape[2]), - int(feature_shape[3])), - align_corners=True) - self.samples = [] - for scale_size in self.scale_sizes: - self.samples.append(SampleBlock(feature_shape, scale_size)) - self.samples = nn.CellList(self.samples) - self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]), - float(feature_shape[3])] - - self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1))) - self.dropout = nn.Dropout(keep_prob=0.9) - self.shape = P.Shape() - self.decoder_output_stride = decoder_output_stride - if decoder_output_stride is not None: - self.decoder = Decoder(low_level_channel=depth, - channel=depth, - depth=depth, - feature_shape=[feature_shape[2], - feature_shape[3]], - scale_sizes=self.scale_sizes, - decoder_output_stride=decoder_output_stride, - fine_tune_batch_norm=fine_tune_batch_norm) - - def construct(self, x, scale_index=0): - x = (2.0 / 255.0) * x - 1.0 - x = self.pad(x) - low_level_feature, feature_map = self.net(x) - for scale_size in self.scale_sizes: - if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2: - output = self.aspp(feature_map, scale_index) - output = self.fc1(output) - if self.decoder_output_stride is not None: - output = self.decoder(output, low_level_feature, scale_index) - output = self.fc2(output) - output = self.samples[scale_index](output) - return output - scale_index += 1 - return feature_map - - -class SampleBlock(nn.Cell): - """Sample block.""" - def __init__(self, - feature_shape, - scale_size=1.0): - super(SampleBlock, self).__init__() - sample_h = np.ceil(float(feature_shape[2]) * scale_size) - sample_w = np.ceil(float(feature_shape[3]) * scale_size) - self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) - - def construct(self, x): - return self.sample(x) - - -class DeepLabV3(nn.Cell): - """DeepLabV3 model.""" - def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates, - decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid): - super(DeepLabV3, self).__init__() - self.infer_scale_sizes = [] - if infer_scale_sizes is not None: - self.infer_scale_sizes = infer_scale_sizes - - self.infer_scale_sizes = infer_scale_sizes - if image_pyramid is None: - image_pyramid = [1.0] - - self.image_pyramid = image_pyramid - scale_sizes = [] - for pyramid in image_pyramid: - scale_sizes.append(pyramid) - for scale in infer_scale_sizes: - scale_sizes.append(scale) - self.samples = [] - for scale_size in scale_sizes: - self.samples.append(SampleBlock(feature_shape, scale_size)) - self.samples = nn.CellList(self.samples) - self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes, - feature_shape=feature_shape, - backbone=resnet50_dl(fine_tune_batch_norm), - channel=channel, - depth=depth, - scale_sizes=scale_sizes, - atrous_rates=atrous_rates, - decoder_output_stride=decoder_output_stride, - output_stride=output_stride, - fine_tune_batch_norm=fine_tune_batch_norm) - self.softmax = P.Softmax(axis=1) - self.concat = P.Concat(axis=2) - self.expand_dims = P.ExpandDims() - self.reduce_mean = P.ReduceMean() - self.sample_common = P.ResizeBilinear((int(feature_shape[2]), - int(feature_shape[3])), - align_corners=True) - - def construct(self, x): - logits = () - if self.training: - if len(self.image_pyramid) >= 1: - if self.image_pyramid[0] == 1: - logits = self.deeplabv3(x) - else: - x1 = self.samples[0](x) - logits = self.deeplabv3(x1) - logits = self.sample_common(logits) - logits = self.expand_dims(logits, 2) - for i in range(len(self.image_pyramid) - 1): - x_i = self.samples[i + 1](x) - logits_i = self.deeplabv3(x_i) - logits_i = self.sample_common(logits_i) - logits_i = self.expand_dims(logits_i, 2) - logits = self.concat((logits, logits_i)) - logits = self.reduce_mean(logits, 2) - return logits - if len(self.infer_scale_sizes) >= 1: - infer_index = len(self.image_pyramid) - x1 = self.samples[infer_index](x) - logits = self.deeplabv3(x1) - logits = self.sample_common(logits) - logits = self.softmax(logits) - logits = self.expand_dims(logits, 2) - for i in range(len(self.infer_scale_sizes) - 1): - x_i = self.samples[i + 1 + infer_index](x) - logits_i = self.deeplabv3(x_i) - logits_i = self.sample_common(logits_i) - logits_i = self.softmax(logits_i) - logits_i = self.expand_dims(logits_i, 2) - logits = self.concat((logits, logits_i)) - logits = self.reduce_mean(logits, 2) - return logits - - -def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid, - infer_scale_sizes, atrous_rates=None, decoder_output_stride=None, - output_stride=16, fine_tune_batch_norm=False): - """ - ResNet50 based DeepLabv3 network. - - Args: - num_classes (int): Class number. - feature_shape (list): Input image shape, [N,C,H,W]. - image_pyramid (list): Input scales for multi-scale feature extraction. - atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. - infer_scale_sizes (list): 'The scales to resize images for inference. - decoder_output_stride (int): 'The ratio of input to output spatial resolution' - output_stride (int): 'The ratio of input to output spatial resolution.' - fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' - - Returns: - Cell, cell instance of ResNet50 based DeepLabv3 neural network. - - Examples: - >>> deeplabv3_resnet50(100, [1,3,224,224],[1.0],[1.0]) - """ - return DeepLabV3(num_classes=num_classes, - feature_shape=feature_shape, - backbone=resnet50_dl(fine_tune_batch_norm), - channel=2048, - depth=256, - infer_scale_sizes=infer_scale_sizes, - atrous_rates=atrous_rates, - decoder_output_stride=decoder_output_stride, - output_stride=output_stride, - fine_tune_batch_norm=fine_tune_batch_norm, - image_pyramid=image_pyramid) diff --git a/model_zoo/deeplabv3/train.py b/model_zoo/deeplabv3/train.py deleted file mode 100644 index d096613977..0000000000 --- a/model_zoo/deeplabv3/train.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train.""" -import argparse -from mindspore import context -from mindspore.communication.management import init -from mindspore.nn.optim.momentum import Momentum -from mindspore import Model, ParallelMode -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor -from src.md_dataset import create_dataset -from src.losses import OhemLoss -from src.deeplabv3 import deeplabv3_resnet50 -from src.config import config - -parser = argparse.ArgumentParser(description="Deeplabv3 training") -parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") -parser.add_argument('--data_url', required=True, default=None, help='Train data url') -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') - -args_opt = parser.parse_args() -print(args_opt) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) -class LossCallBack(Callback): - """ - Monitor the loss in training. - Note: - if per_print_times is 0 do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0") - self._per_print_times = per_print_times - def step_end(self, run_context): - cb_params = run_context.original_args() - print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) -def model_fine_tune(flags, train_net, fix_weight_layer): - checkpoint_path = flags.checkpoint_url - if checkpoint_path is None: - return - param_dict = load_checkpoint(checkpoint_path) - load_param_into_net(train_net, param_dict) - for para in train_net.trainable_params(): - if fix_weight_layer in para.name: - para.requires_grad = False -if __name__ == "__main__": - if args_opt.distribute == "true": - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) - init() - args_opt.base_size = config.crop_size - args_opt.crop_size = config.crop_size - train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train") - dataset_size = train_dataset.get_dataset_size() - time_cb = TimeMonitor(data_size=dataset_size) - callback = [time_cb, LossCallBack()] - if config.enable_save_ckpt: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, - keep_checkpoint_max=config.save_checkpoint_num) - ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) - callback.append(ckpoint_cb) - net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], - infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, - decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, - fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) - net.set_train() - model_fine_tune(args_opt, net, 'layer') - loss = OhemLoss(config.seg_num_classes, config.ignore_label) - opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) - model = Model(net, loss, opt) - model.train(config.epoch_size, train_dataset, callback) - \ No newline at end of file diff --git a/model_zoo/faster_rcnn/README.md b/model_zoo/faster_rcnn/README.md deleted file mode 100644 index 24ababcfe4..0000000000 --- a/model_zoo/faster_rcnn/README.md +++ /dev/null @@ -1,142 +0,0 @@ -# FasterRcnn Example - -## Description - -FasterRcnn is a two-stage target detection network,This network uses a region proposal network (RPN), which can share the convolution features of the whole image with the detection network, so that the calculation of region proposal is almost cost free. The whole network further combines RPN and FastRcnn into a network by sharing the convolution features. - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Download the dataset COCO2017. - -- We use coco2017 as training dataset in this example by default, and you can also use your own datasets. - - 1. If coco dataset is used. **Select dataset to coco when run script.** - Install Cython and pycocotool, and you can also install mmcv to process data. - - ``` - pip install Cython - - pip install pycocotools - - pip install mmcv - ``` - And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: - - - ``` - . - └─cocodataset - ├─annotations - ├─instance_train2017.json - └─instance_val2017.json - ├─val2017 - └─train2017 - - ``` - - 2. If your own dataset is used. **Select dataset to other when run script.** - Organize the dataset infomation into a TXT file, each row in the file is as follows: - - ``` - train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 - ``` - - Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. - - -## Example structure - -```shell -. -└─FasterRcnn - ├─README.md - ├─scripts - ├─run_download_process_data.sh - ├─run_standalone_train.sh - ├─run_train.sh - └─run_eval.sh - ├─src - ├─FasterRcnn - ├─__init__.py - ├─anchor_generator.py - ├─bbox_assign_sample.py - ├─bbox_assign_sample_stage2.py - ├─faster_rcnn_r50.py - ├─fpn_neck.py - ├─proposal_generator.py - ├─rcnn.py - ├─resnet50.py - ├─roi_align.py - └─rpn.py - ├─config.py - ├─dataset.py - ├─lr_schedule.py - ├─network_define.py - └─util.py - ├─eval.py - └─train.py -``` - -## Running the example - -### Train - -#### Usage - -``` -# distributed training -sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [PRETRAINED_MODEL] - -# standalone training -sh run_standalone_train.sh [PRETRAINED_MODEL] -``` - -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). - -#### Result - -Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in loss.log. - - -``` -# distribute training result(8p) -epoch: 1 step: 7393, rpn_loss: 0.12054, rcnn_loss: 0.40601, rpn_cls_loss: 0.04025, rpn_reg_loss: 0.08032, rcnn_cls_loss: 0.25854, rcnn_reg_loss: 0.14746, total_loss: 0.52655 -epoch: 2 step: 7393, rpn_loss: 0.06561, rcnn_loss: 0.50293, rpn_cls_loss: 0.02587, rpn_reg_loss: 0.03967, rcnn_cls_loss: 0.35669, rcnn_reg_loss: 0.14624, total_loss: 0.56854 -epoch: 3 step: 7393, rpn_loss: 0.06940, rcnn_loss: 0.49658, rpn_cls_loss: 0.03769, rpn_reg_loss: 0.03165, rcnn_cls_loss: 0.36353, rcnn_reg_loss: 0.13318, total_loss: 0.56598 -... -epoch: 10 step: 7393, rpn_loss: 0.03555, rcnn_loss: 0.32666, rpn_cls_loss: 0.00697, rpn_reg_loss: 0.02859, rcnn_cls_loss: 0.16125, rcnn_reg_loss: 0.16541, total_loss: 0.36221 -epoch: 11 step: 7393, rpn_loss: 0.19849, rcnn_loss: 0.47827, rpn_cls_loss: 0.11639, rpn_reg_loss: 0.08209, rcnn_cls_loss: 0.29712, rcnn_reg_loss: 0.18115, total_loss: 0.67676 -epoch: 12 step: 7393, rpn_loss: 0.00691, rcnn_loss: 0.10168, rpn_cls_loss: 0.00529, rpn_reg_loss: 0.00162, rcnn_cls_loss: 0.05426, rcnn_reg_loss: 0.04745, total_loss: 0.10859 -``` - -### Infer - -#### Usage - -``` -# infer -sh run_infer.sh [VALIDATION_DATASET_PATH] [CHECKPOINT_PATH] -``` - -> checkpoint can be produced in training process. - -#### Result - -Inference result will be stored in the example path, whose folder name is "infer". Under this, you can find result like the followings in log. - -``` - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.360 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.586 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.385 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.229 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.402 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.441 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.299 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.487 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.515 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.346 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.562 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631 -``` \ No newline at end of file diff --git a/model_zoo/faster_rcnn/eval.py b/model_zoo/faster_rcnn/eval.py deleted file mode 100644 index d8dd2ed79a..0000000000 --- a/model_zoo/faster_rcnn/eval.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# less required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Evaluation for FasterRcnn""" -import os -import argparse -import time -import random -import numpy as np -from pycocotools.coco import COCO -from mindspore import context, Tensor -from mindspore.train.serialization import load_checkpoint, load_param_into_net -import mindspore.dataset.engine as de - -from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 -from src.config import config -from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset -from src.util import coco_eval, bbox2result_1image, results2json - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description="FasterRcnn evaluation") -parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") -parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.") -parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -args_opt = parser.parse_args() - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - -def FasterRcnn_eval(dataset_path, ckpt_path, ann_file): - """FasterRcnn evaluation.""" - ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size, - repeat_num=1, is_training=False) - net = Faster_Rcnn_Resnet50(config) - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - net.set_train(False) - - eval_iter = 0 - total = ds.get_dataset_size() - outputs = [] - dataset_coco = COCO(ann_file) - - print("\n========================================\n") - print("total images num: ", total) - print("Processing, please wait a moment.") - max_num = 128 - for data in ds.create_dict_iterator(): - eval_iter = eval_iter + 1 - - img_data = data['image'] - img_metas = data['image_shape'] - gt_bboxes = data['box'] - gt_labels = data['label'] - gt_num = data['valid_num'] - - start = time.time() - # run net - output = net(Tensor(img_data), Tensor(img_metas), Tensor(gt_bboxes), Tensor(gt_labels), Tensor(gt_num)) - end = time.time() - print("Iter {} cost time {}".format(eval_iter, end - start)) - - # output - all_bbox = output[0] - all_label = output[1] - all_mask = output[2] - - for j in range(config.test_batch_size): - all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :]) - all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :]) - all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :]) - - all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] - all_labels_tmp_mask = all_label_squee[all_mask_squee] - - if all_bboxes_tmp_mask.shape[0] > max_num: - inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) - inds = inds[:max_num] - all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] - all_labels_tmp_mask = all_labels_tmp_mask[inds] - - outputs_tmp = bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes) - - outputs.append(outputs_tmp) - - eval_types = ["bbox"] - result_files = results2json(dataset_coco, outputs, "./results.pkl") - - coco_eval(result_files, eval_types, dataset_coco, single_result=True) - - -if __name__ == '__main__': - prefix = "FasterRcnn_eval.mindrecord" - mindrecord_dir = config.mindrecord_dir - mindrecord_file = os.path.join(mindrecord_dir, prefix) - if not os.path.exists(mindrecord_file): - if not os.path.isdir(mindrecord_dir): - os.makedirs(mindrecord_dir) - if args_opt.dataset == "coco": - if os.path.isdir(config.coco_root): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("coco_root not exits.") - else: - if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("other", False, prefix, file_num=1) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("IMAGE_DIR or ANNO_PATH not exits.") - - print("Start Eval!") - FasterRcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file) diff --git a/model_zoo/faster_rcnn/scripts/run_distribute_train.sh b/model_zoo/faster_rcnn/scripts/run_distribute_train.sh deleted file mode 100755 index bc6ebd4a18..0000000000 --- a/model_zoo/faster_rcnn/scripts/run_distribute_train.sh +++ /dev/null @@ -1,69 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 2 ] -then - echo "Usage: sh run_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [PRETRAINED_PATH]" -exit 1 -fi - -get_real_path(){ - if [ "${1:0:1}" == "/" ]; then - echo "$1" - else - echo "$(realpath -m $PWD/$1)" - fi -} -PATH1=$(get_real_path $1) -PATH2=$(get_real_path $2) - -echo $PATH1 -echo $PATH2 - -if [ ! -f $PATH1 ] -then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file" -exit 1 -fi - -if [ ! -f $PATH2 ] -then - echo "error: PRETRAINED_PATH=$PATH2 is not a file" -exit 1 -fi - -ulimit -u unlimited -export DEVICE_NUM=8 -export RANK_SIZE=8 -export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 -export RANK_TABLE_FILE=$PATH1 - -for((i=0; i<${DEVICE_NUM}; i++)) -do - export DEVICE_ID=$i - export RANK_ID=$i - rm -rf ./train_parallel$i - mkdir ./train_parallel$i - cp ../*.py ./train_parallel$i - cp *.sh ./train_parallel$i - cp -r ../src ./train_parallel$i - cd ./train_parallel$i || exit - echo "start training for rank $RANK_ID, device $DEVICE_ID" - env > env.log - python train.py --do_train=True --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM \ - --pre_trained=$PATH2 &> log & - cd .. -done diff --git a/model_zoo/faster_rcnn/train.py b/model_zoo/faster_rcnn/train.py deleted file mode 100644 index 7d5f190bab..0000000000 --- a/model_zoo/faster_rcnn/train.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# less required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""train FasterRcnn and get checkpoint files.""" - -import os -import argparse -import random -import numpy as np - -import mindspore.common.dtype as mstype -from mindspore import context, Tensor -from mindspore.communication.management import init -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor -from mindspore.train import Model, ParallelMode -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.nn import SGD -import mindspore.dataset.engine as de - -from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 -from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet -from src.config import config -from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset -from src.lr_schedule import dynamic_lr - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description="FasterRcnn training") -parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " - "Mindrecord, default is false.") -parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") -parser.add_argument("--do_train", type=bool, default=True, help="Do train or not, default is true.") -parser.add_argument("--do_eval", type=bool, default=False, help="Do eval or not, default is false.") -parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") -parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.") -parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") -parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") -parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.") -args_opt = parser.parse_args() - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - -if __name__ == '__main__': - if not args_opt.do_eval and args_opt.run_distribute: - rank = args_opt.rank_id - device_num = args_opt.device_num - context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True, parameter_broadcast=True) - init() - else: - rank = 0 - device_num = 1 - - print("Start create dataset!") - - # It will generate mindrecord file in args_opt.mindrecord_dir, - # and the file name is FasterRcnn.mindrecord0, 1, ... file_num. - prefix = "FasterRcnn.mindrecord" - mindrecord_dir = config.mindrecord_dir - mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") - if not os.path.exists(mindrecord_file): - if not os.path.isdir(mindrecord_dir): - os.makedirs(mindrecord_dir) - if args_opt.dataset == "coco": - if os.path.isdir(config.coco_root): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("coco", True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("coco_root not exits.") - else: - if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("other", True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("IMAGE_DIR or ANNO_PATH not exits.") - - if not args_opt.only_create_dataset: - loss_scale = float(config.loss_scale) - - # When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0. - dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=config.epoch_size, - batch_size=config.batch_size, device_num=device_num, rank_id=rank) - - dataset_size = dataset.get_dataset_size() - print("Create dataset done!") - - net = Faster_Rcnn_Resnet50(config=config) - net = net.set_train() - - load_path = args_opt.pre_trained - if load_path != "": - param_dict = load_checkpoint(load_path) - for item in list(param_dict.keys()): - if not item.startswith('backbone'): - param_dict.pop(item) - load_param_into_net(net, param_dict) - - loss = LossNet() - lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32) - - opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, - weight_decay=config.weight_decay, loss_scale=config.loss_scale) - net_with_loss = WithLossCell(net, loss) - if args_opt.run_distribute: - net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, - mean=True, degree=device_num) - else: - net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) - - time_cb = TimeMonitor(data_size=dataset_size) - loss_cb = LossCallBack() - cb = [time_cb, loss_cb] - if config.save_checkpoint: - ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn', directory=config.save_checkpoint_path, config=ckptconfig) - cb += [ckpoint_cb] - - model = Model(net) - model.train(config.epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/gat/src/gat.py b/model_zoo/gat/src/gat.py deleted file mode 100644 index 3cb3cc1106..0000000000 --- a/model_zoo/gat/src/gat.py +++ /dev/null @@ -1,496 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Aggregator.""" -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore._extends import cell_attr_register -from mindspore import Tensor, Parameter -from mindspore.common.initializer import initializer -from mindspore._checkparam import check_int_positive, check_bool -from mindspore.nn.layer.activation import get_activation - - -class GNNFeatureTransform(nn.Cell): - r""" - The GNN featuren transform layer for input. - - Applies linear transformation for the input feature. This layer implements the operation as: - - .. math:: - \text{outputs} = \text{inputs} * \text{kernel} + \text{bias}, - - where :math:`\text{activation}` is the activation function passed as the activation - argument (if passed in),:math:`\text{activation}` is a weight matrix with the same - data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector - with the same data type as the inputs created by the layer (only if has_bias is True). - - Args: - in_channels (int): The number of channels in the input space. - out_channels (int): The number of channels in the output space. - weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype - is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. - bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is - same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - - Raises: - ValueError: If weight_init or bias_init shape is incorrect. - - Inputs: - - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`, - where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the - size of the last two dimensions. If `transpose_a` is True, its shape should be :math:`(*B, C, N)`. - - Outputs: - Tensor, the shape of the output tensor is :math:`(*B, N, M)`. - - Examples: - >>> net = nn.Dense(3, 4) - >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) - >>> net(input) - [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] - [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] - """ - @cell_attr_register - def __init__(self, - in_channels, - out_channels, - weight_init='normal', - bias_init='zeros', - has_bias=True): - super(GNNFeatureTransform, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) - self.has_bias = check_bool(has_bias) - - if isinstance(weight_init, Tensor): - if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ - weight_init.shape()[1] != in_channels: - raise ValueError("weight_init shape error") - - self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") - - if self.has_bias: - if isinstance(bias_init, Tensor): - if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: - raise ValueError("bias_init shape error") - - self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") - - self.matmul = P.MatMul(transpose_b=True) - self.bias_add = P.BiasAdd() - - def construct(self, x): - tensor_shape = F.shape(x) - input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2])) - output = self.matmul(input_feature, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels)) - return output - - def extend_repr(self): - str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \ - .format(self.in_channels, self.out_channels, self.weight, self.has_bias) - if self.has_bias: - str_info = str_info + ', bias={}'.format(self.bias) - - return str_info - - -class _BaseAggregator(nn.Cell): - """ - Base Aggregator of GNN - - Args: - feature_in_dim (int): Node or edge input feature dim. - feature_out_dim (int): Node or edge outpout feature dim. - use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True - weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype - is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. - bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is - same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None. - activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. - - Examples: - >>> class MyAggregator(_BaseAggregator): - >>> def __init__(self): - >>> super(MyAggregator, self).__init__(self, feature_in_dim, feature_out_dim) - >>> self.reduce_mean = P.ReduceSum() - >>> - >>> def construct(self, x): - >>> return self.reduce_mean(x, 1) - """ - def __init__(self, - feature_in_dim, - feature_out_dim, - use_fc=True, - weight_init="normal", - bias_init="zeros", - has_bias=True, - dropout_ratio=None, - activation=None): - super(_BaseAggregator, self).__init__() - self.in_dim = feature_in_dim - self.out_dim = feature_out_dim - self.use_fc = use_fc - if self.use_fc: - self.weight_init = weight_init - self.bias_init = bias_init - self.has_bias = has_bias - self.fc = GNNFeatureTransform(self.in_dim, - self.out_dim, - weight_init=self.weight_init, - bias_init=self.bias_init, - has_bias=self.has_bias) - self.dropout_ratio = dropout_ratio - if self.dropout_ratio is not None: - self.dropout = nn.Dropout(keep_prob=self.dropout_ratio) - self.dropout_flag = self.dropout_ratio is not None - self.activation = get_activation(activation) - self.activation_flag = self.activation is not None - - def construct(self, **kward): - """Must be overridden by all subclasses.""" - raise NotImplementedError - - -class MeanAggregator(_BaseAggregator): - """ - Mean Aggregator of GNN - - Args: - feature_in_dim (int): Node or edge input feature dim. - feature_out_dim (int): Node or edge outpout feature dim. - use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True - weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype - is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. - bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is - same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None. - activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. - - Examples: - >>> net = MeanAggregator(32, 64, activation="relu", dropout=0.5) - >>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtypy=np.float32)) - >>> output = net(input_data) - """ - def __init__(self, - feature_in_dim, - feature_out_dim, - use_fc=True, - weight_init="normal", - bias_init="zeros", - has_bias=True, - dropout_ratio=None, - activation=None): - super(MeanAggregator, self).__init__( - feature_in_dim, - feature_out_dim, - use_fc, - weight_init, - bias_init, - has_bias, - dropout_ratio, - activation) - self.reduce_mean = P.ReduceMean(keep_dims=False) - - def construct(self, input_feature): - if self.use_fc: - input_feature = self.fc(input_feature) - if self.dropout_flag: - input_feature = self.dropout(input_feature) - if self.activation_flag: - input_feature = self.activation(input_feature) - output_feature = self.reduce_mean(input_feature, 1) - return output_feature - - -class AttentionHead(nn.Cell): - """ - Attention Head for Graph Attention Networks. - - Args: - in_channel (int): The number of input channel, input feature dim. - out_channel (int): The number of output channel, output feature dim. - in_drop_ratio (float): Input feature dropout ratio, default 0.0. - coef_drop_ratio (float): Coefficient dropout ratio, default 0.0. - residual (bool): Whether to use residual connection, default False. - coef_activation (Cell): The attention coefficient activation function, - default nn.LeakyReLU(). - activation (Cell): The output activation function, default nn.ELU(). - - Inputs: - - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim). - - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes). - - Examples: - >>> head = AttentionHead(1433, - 8, - in_drop_ratio=0.6, - coef_drop_ratio=0.6, - residual=False) - >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtypy=np.float32)) - >>> output = net(input_data) - """ - - def __init__(self, - in_channel, - out_channel, - in_drop_ratio=0.0, - coef_drop_ratio=0.0, - residual=False, - coef_activation=nn.LeakyReLU(), - activation=nn.ELU()): - super(AttentionHead, self).__init__() - self.in_channel = check_int_positive(in_channel) - self.out_channel = check_int_positive(out_channel) - self.in_drop_ratio = in_drop_ratio - self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) - self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) - self.feature_transform = GNNFeatureTransform( - in_channels=self.in_channel, - out_channels=self.out_channel, - has_bias=False, - weight_init='XavierUniform') - - self.f_1_transform = GNNFeatureTransform( - in_channels=self.out_channel, - out_channels=1, - weight_init='XavierUniform') - self.f_2_transform = GNNFeatureTransform( - in_channels=self.out_channel, - out_channels=1, - weight_init='XavierUniform') - self.softmax = nn.Softmax() - - self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio) - self.matmul = P.MatMul() - self.bias_add = P.BiasAdd() - self.bias = Parameter(initializer('zeros', self.out_channel), name='bias') - self.residual = check_bool(residual) - if self.residual: - if in_channel != out_channel: - self.residual_transform_flag = True - self.residual_transform = GNNFeatureTransform( - in_channels=self.in_channel, - out_channels=self.out_channel) - else: - self.residual_transform = None - self.coef_activation = coef_activation - self.activation = activation - - def construct(self, input_feature, bias_mat, training=True): - if training is True: - input_feature = self.in_drop(input_feature) - - feature = self.feature_transform(input_feature) - # self attention - f_1 = self.f_1_transform(feature) - f_2 = self.f_2_transform(feature) - logits = f_1 + P.Transpose()(f_2, (0, 2, 1)) - logits = self.coef_activation(logits) + bias_mat - coefs = self.softmax(logits) - if training is True: - coefs = self.coef_drop(coefs) - feature = self.in_drop_2(feature) - - coefs = P.Squeeze(0)(coefs) - feature = P.Squeeze(0)(feature) - - ret = self.matmul(coefs, feature) - ret = self.bias_add(ret, self.bias) - ret = P.ExpandDims()(ret, 0) - # residual connection - if self.residual: - if self.residual_transform_flag: - res = self.residual_transform(input_feature) - ret = ret + res - else: - ret = ret + input_feature - # activation - if self.activation is not None: - ret = self.activation(ret) - return ret - - -class AttentionAggregator(nn.Cell): - """ - Attention Head for Graph Attention Networks,can be regarded as one - GAT layer. - - Args: - in_channel (int): Input channel. - out_channel (int): Output channel. - num_heads (int): Number of attention heads for this layer, default 1. - in_drop_ratio (float): Input feature dropout ratio, default 0.0. - coef_drop_ratio (float): Coefficient dropout ratio, default 0.0. - activation (Cell): The output activation function, default nn.ELU(). - residual (bool): Whether to use residual connection, default False. - output_transform (str['concat', 'sum']): output transform for a layer, - default 'concat' - - Inputs: - - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim). - - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes). - - Examples: - >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32)) - >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32)) - >>> net = AttentionAggregator(1433, - 8, - 8) - >>> net(input_data, biases) - """ - def __init__(self, - in_channels, - out_channels, - num_heads=1, - in_drop=0.0, - coef_drop=0.0, - activation=nn.ELU(), - residual=False, - output_transform='concat'): - super(AttentionAggregator, self).__init__() - self.num_heads = num_heads - self.attns = [] - for _ in range(num_heads): - self.attns.append(AttentionHead(in_channels, - out_channels, - in_drop_ratio=in_drop, - coef_drop_ratio=coef_drop, - activation=activation, - residual=residual)) - self.attns = nn.layer.CellList(self.attns) - if output_transform == 'concat': - self.out_trans = P.Concat(-1) - elif output_transform == 'sum': - self.out_trans = P.AddN() - else: - raise ValueError("output_transform must be either 'concat' or 'sum'") - - def construct(self, input_data, bias_mat, training=True): - res = () - for i in range(self.num_heads): - res += (self.attns[i](input_data, bias_mat, training),) - return self.out_trans(res) - - -class GAT(nn.Cell): - """ - Graph Attention Network - - Args: - ftr_dims (int): Initial feature dimensions. - num_class (int): Num of class to identify. - num_nodes (int): Num of nodes in this graph. - hidden_units (list[int]): Num of hidden units at each layer. - num_heads (list[int]): Num of heads at each layer. - attn_drop (float): Drop out ratio of attention coefficient, - default 0.0. - ftr_drop (float): Drop out ratio of feature, default 0.0. - activation (Cell): Activation Function for output layer, default - nn.Elu(). - residual (bool): Whether to use residual connection between - intermediate layers, default False. - - Examples: - >>> ft_sizes = 1433 - >>> num_class = 7 - >>> num_nodes = 2708 - >>> hid_units = [8] - >>> n_heads = [8, 1] - >>> activation = nn.ELU() - >>> residual = False - >>> input_data = np.array(np.random.rand(1, 2708, 1433)) - >>> biases = np.array(np.random.rand(1, 2708, 2708)) - >>> net = GAT(ft_sizes, - num_class, - num_nodes, - hidden_units=hid_units, - num_heads=n_heads, - attn_drop=0.6, - ftr_drop=0.6, - activation=activation, - residual=residual) - >>> output = net(input_data, biases) - """ - - def __init__(self, - features, - biases, - ftr_dims, - num_class, - num_nodes, - hidden_units, - num_heads, - attn_drop=0.0, - ftr_drop=0.0, - activation=nn.ELU(), - residual=False): - super(GAT, self).__init__() - self.features = Tensor(features) - self.biases = Tensor(biases) - self.ftr_dims = check_int_positive(ftr_dims) - self.num_class = check_int_positive(num_class) - self.num_nodes = check_int_positive(num_nodes) - self.hidden_units = hidden_units - self.num_heads = num_heads - self.attn_drop = attn_drop - self.ftr_drop = ftr_drop - self.activation = activation - self.residual = check_bool(residual) - self.layers = [] - # first layer - self.layers.append(AttentionAggregator( - self.ftr_dims, - self.hidden_units[0], - self.num_heads[0], - self.ftr_drop, - self.attn_drop, - self.activation, - residual=False)) - # intermediate layer - for i in range(1, len(self.hidden_units)): - self.layers.append(AttentionAggregator( - self.hidden_units[i-1]*self.num_heads[i-1], - self.hidden_units[i], - self.num_heads[i], - self.ftr_drop, - self.attn_drop, - self.activation, - residual=self.residual)) - # output layer - self.layers.append(AttentionAggregator( - self.hidden_units[-1]*self.num_heads[-2], - self.num_class, - self.num_heads[-1], - self.ftr_drop, - self.attn_drop, - activation=None, - residual=False, - output_transform='sum')) - self.layers = nn.layer.CellList(self.layers) - - def construct(self, training=True): - input_data = self.features - bias_mat = self.biases - for cell in self.layers: - input_data = cell(input_data, bias_mat, training) - return input_data/self.num_heads[-1] diff --git a/model_zoo/googlenet/README.md b/model_zoo/googlenet/README.md deleted file mode 100644 index 92cdd8af43..0000000000 --- a/model_zoo/googlenet/README.md +++ /dev/null @@ -1,324 +0,0 @@ -# Contents - -- [GoogleNet Description](#googlenet-description) -- [Model Architecture](#model-architecture) -- [Dataset](#dataset) -- [Features](#features) - - [Mixed Precision](#mixed-precision) -- [Environment Requirements](#environment-requirements) -- [Quick Start](#quick-start) -- [Script Description](#script-description) - - [Script and Sample Code](#script-and-sample-code) - - [Script Parameters](#script-parameters) - - [Training Process](#training-process) - - [Training](#training) - - [Distributed Training](#distributed-training) - - [Evaluation Process](#evaluation-process) - - [Evaluation](#evaluation) -- [Model Description](#model-description) - - [Performance](#performance) - - [Evaluation Performance](#evaluation-performance) - - [Inference Performance](#evaluation-performance) - - [How to use](#how-to-use) - - [Inference](#inference) - - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model) - - [Transfer Learning](#transfer-learning) -- [Description of Random Situation](#description-of-random-situation) -- [ModelZoo Homepage](#modelzoo-homepage) - - -# [GoogleNet Description](#contents) - -GoogleNet, a 22 layers deep network, was proposed in 2014 and won the first place in the ImageNet Large-Scale Visual Recognition Challenge 2014 (ILSVRC14). GoogleNet, also called Inception v1, has significant improvement over ZFNet (The winner in 2013) and AlexNet (The winner in 2012), and has relatively lower error rate compared to VGGNet. Typically deeper deep learning network means larger number of parameters, which makes it more prone to overfitting. Furthermore, the increased network size leads to increased use of computational resources. To tackle these issues, GoogleNet adopts 1*1 convolution middle of the network to reduce dimension, and thus further reduce the computation. Global average pooling is used at the end of the network, instead of using fully connected layers. Another technique, called inception module, is to have different sizes of convolutions for the same input and stacking all the outputs. - -[Paper](https://arxiv.org/abs/1409.4842): Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. "Going deeper with convolutions." *Proceedings of the IEEE conference on computer vision and pattern recognition*. 2015. - - -# [Model Architecture](#contents) - -The overall network architecture of GoogleNet is shown below: - -![](https://miro.medium.com/max/3780/1*ZFPOSAted10TPd3hBQU8iQ.png) - -Specifically, the GoogleNet contains numerous inception modules, which are connected together to go deeper. In general, an inception module with dimensionality reduction consists of **1×1 conv**, **3×3 conv**, **5×5 conv**, and **3×3 max pooling**, which are done altogether for the previous input, and stack together again at output. - -![](https://miro.medium.com/max/1108/1*sezFsYW1MyM9YOMa1q909A.png) - - - -# [Dataset](#contents) - -Dataset used: [CIFAR-10]() - -- Dataset size:175M,60,000 32*32 colorful images in 10 classes - - Train:146M,50,000 images - - Test:29.3M,10,000 images -- Data format:binary files - - Note:Data will be processed in dataset.py - - - -# [Features](#contents) - -## Mixed Precision - -The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. -For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. - - - -# [Environment Requirements](#contents) - -- Hardware(Ascend/GPU) - - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. -- Framework - - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) -- For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - - - -# [Quick Start](#contents) - -After installing MindSpore via the official website, you can start training and evaluation as follows: - -```python -# run training example -python train.py > train.log 2>&1 & - -# run distributed training example -sh scripts/run_train.sh rank_table.json - -# run evaluation example -python eval.py > eval.log 2>&1 & OR sh run_eval.sh -``` - - - -# [Script Description](#contents) - -## [Script and Sample Code](#contents) - -``` -├── model_zoo - ├── README.md // descriptions about all the models - ├── googlenet - ├── README.md // descriptions about googlenet - ├── scripts - │ ├──run_train.sh // shell script for distributed - │ ├──run_eval.sh // shell script for evaluation - ├── src - │ ├──dataset.py // creating dataset - │ ├──googlenet.py // googlenet architecture - │ ├──config.py // parameter configuration - ├── train.py // training script - ├── eval.py // evaluation script - ├── export.py // export checkpoint files into geir/onnx -``` - -## [Script Parameters](#contents) - -```python -Major parameters in train.py and config.py are: - ---data_path: The absolute full path to the train and evaluation datasets. ---epoch_size: Total training epochs. ---batch_size: Training batch size. ---lr_init: Initial learning rate. ---num_classes: The number of classes in the training set. ---weight_decay: Weight decay value. ---image_height: Image height used as input to the model. ---image_width: Image width used as input the model. ---pre_trained: Whether training from scratch or training based on the - pre-trained model.Optional values are True, False. ---device_target: Device where the code will be implemented. Optional values - are "Ascend", "GPU". ---device_id: Device ID used to train or evaluate the dataset. Ignore it - when you use run_train.sh for distributed training. ---checkpoint_path: The absolute full path to the checkpoint file saved - after training. ---onnx_filename: File name of the onnx model used in export.py. ---geir_filename: File name of the geir model used in export.py. -``` - - -## [Training Process](#contents) - -### Training - -``` -python train.py > train.log 2>&1 & -``` - -The python command above will run in the background, you can view the results through the file `train.log`. - -After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows: - -``` -# grep "loss is " train.log -epoch: 1 step: 390, loss is 1.4842823 -epcoh: 2 step: 390, loss is 1.0897788 -... -``` - -The model checkpoint will be saved in the current directory. - -### Distributed Training - -``` -sh scripts/run_train.sh rank_table.json -``` - -The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows: - -``` -# grep "result: " train_parallel*/log -train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931 -train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874 -... -train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025 -train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336 -... -... -``` - - -## [Evaluation Process](#contents) - -### Evaluation - -Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt". - -``` -python eval.py > eval.log 2>&1 & -OR -sh scripts/run_eval.sh -``` - -The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: - -``` -# grep "accuracy: " eval.log -accuracy: {'acc': 0.934} -``` - -Note that for evaluation after distributed training, please set the checkpoint_path to be the last saved checkpoint file such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy of the test dataset will be as follows: - -``` -# grep "accuracy: " dist.eval.log -accuracy: {'acc': 0.9217} -``` - - -# [Model Description](#contents) -## [Performance](#contents) - -### Evaluation Performance - -| Parameters | GoogleNet | -| -------------------------- | ----------------------------------------------------------- | -| Model Version | Inception V1 | -| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G | -| uploaded Date | 06/09/2020 (month/day/year) | -| MindSpore Version | 0.3.0-alpha | -| Dataset | CIFAR-10 | -| Training Parameters | epoch=125, steps=390, batch_size = 128, lr=0.1 | -| Optimizer | SGD | -| Loss Function | Softmax Cross Entropy | -| outputs | probability | -| Loss | 0.0016 | -| Speed | 1pc: 79 ms/step; 8pcs: 82 ms/step | -| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins | -| Parameters (M) | 6.8 | -| Checkpoint for Fine tuning | 43.07M (.ckpt file) | -| Model for inference | 21.50M (.onnx file), 21.60M(.geir file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/googlenet | - - -### Inference Performance - -| Parameters | GoogleNet | -| ------------------- | --------------------------- | -| Model Version | Inception V1 | -| Resource | Ascend 910 | -| Uploaded Date | 06/09/2020 (month/day/year) | -| MindSpore Version | 0.3.0-alpha | -| Dataset | CIFAR-10, 10,000 images | -| batch_size | 128 | -| outputs | probability | -| Accuracy | 1pc: 93.4%; 8pcs: 92.17% | -| Model for inference | 21.50M (.onnx file) | - -## [How to use](#contents) -### Inference - -If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example: - -``` -# Load unseen dataset for inference -dataset = dataset.create_dataset(cfg.data_path, 1, False) - -# Define model -net = GoogleNet(num_classes=cfg.num_classes) -opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, - cfg.momentum, weight_decay=cfg.weight_decay) -loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', - is_grad=False) -model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - -# Load pre-trained model -param_dict = load_checkpoint(cfg.checkpoint_path) -load_param_into_net(net, param_dict) -net.set_train(False) - -# Make predictions on the unseen dataset -acc = model.eval(dataset) -print("accuracy: ", acc) -``` - -### Continue Training on the Pretrained Model - -``` -# Load dataset -dataset = create_dataset(cfg.data_path, cfg.epoch_size) -batch_num = dataset.get_dataset_size() - -# Define model -net = GoogleNet(num_classes=cfg.num_classes) -# Continue training if set pre_trained to be True -if cfg.pre_trained: - param_dict = load_checkpoint(cfg.checkpoint_path) - load_param_into_net(net, param_dict) -lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, - steps_per_epoch=batch_num) -opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), - Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) -loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) -model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) - -# Set callbacks -config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, - keep_checkpoint_max=cfg.keep_checkpoint_max) -time_cb = TimeMonitor(data_size=batch_num) -ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", - config=config_ck) -loss_cb = LossMonitor() - -# Start training -model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) -print("train success") -``` - -### Transfer Learning -To be added. - - -# [Description of Random Situation](#contents) - -In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. - - -# [ModelZoo Homepage](#contents) - Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/googlenet/eval.py b/model_zoo/googlenet/eval.py deleted file mode 100644 index fc469879e7..0000000000 --- a/model_zoo/googlenet/eval.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -##############test googlenet example on cifar10################# -python eval.py -""" -import mindspore.nn as nn -from mindspore import context -from mindspore.nn.optim.momentum import Momentum -from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from src.config import cifar_cfg as cfg -from src.dataset import create_dataset -from src.googlenet import GoogleNet - - -if __name__ == '__main__': - context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) - context.set_context(device_id=cfg.device_id) - - net = GoogleNet(num_classes=cfg.num_classes) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - - param_dict = load_checkpoint(cfg.checkpoint_path) - load_param_into_net(net, param_dict) - net.set_train(False) - dataset = create_dataset(cfg.data_path, 1, False) - acc = model.eval(dataset) - print("accuracy: ", acc) diff --git a/model_zoo/googlenet/scripts/run_train.sh b/model_zoo/googlenet/scripts/run_train.sh deleted file mode 100644 index e8c045c8b1..0000000000 --- a/model_zoo/googlenet/scripts/run_train.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 1 ] -then - echo "Usage: sh run_train.sh [MINDSPORE_HCCL_CONFIG_PATH]" -exit 1 -fi - -if [ ! -f $1 ] -then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file" -exit 1 -fi - -ulimit -u unlimited -export DEVICE_NUM=8 -export RANK_SIZE=8 -MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1) -export MINDSPORE_HCCL_CONFIG_PATH -echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}" - -export SERVER_ID=0 -rank_start=$((DEVICE_NUM * SERVER_ID)) -for((i=0; i<${DEVICE_NUM}; i++)) -do - export DEVICE_ID=$i - export RANK_ID=$((rank_start + i)) - rm -rf ./train_parallel$i - mkdir ./train_parallel$i - cp -r ./src ./train_parallel$i - cp ./train.py ./train_parallel$i - echo "start training for rank $RANK_ID, device $DEVICE_ID" - cd ./train_parallel$i ||exit - env > env.log - python train.py --device_id=$i > log 2>&1 & - cd .. -done diff --git a/model_zoo/googlenet/src/dataset.py b/model_zoo/googlenet/src/dataset.py deleted file mode 100644 index a3f74a0617..0000000000 --- a/model_zoo/googlenet/src/dataset.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Data operations, will be used in train.py and eval.py -""" -import os - -import mindspore.common.dtype as mstype -import mindspore.dataset as ds -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as vision -from src.config import cifar_cfg as cfg - - -def create_dataset(data_home, repeat_num=1, training=True): - """Data operations.""" - ds.config.set_seed(1) - data_dir = os.path.join(data_home, "cifar-10-batches-bin") - if not training: - data_dir = os.path.join(data_home, "cifar-10-verify-bin") - - rank_size, rank_id = _get_rank_info() - data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) - - resize_height = cfg.image_height - resize_width = cfg.image_width - - # define map operations - random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT - random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR - normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - changeswap_op = vision.HWC2CHW() - type_cast_op = C.TypeCast(mstype.int32) - - c_trans = [] - if training: - c_trans = [random_crop_op, random_horizontal_op] - c_trans += [resize_op, normalize_op, changeswap_op] - - # apply map operations on images - data_set = data_set.map(input_columns="label", operations=type_cast_op) - data_set = data_set.map(input_columns="image", operations=c_trans) - - # apply repeat operations - data_set = data_set.repeat(repeat_num) - - # apply shuffle operations - data_set = data_set.shuffle(buffer_size=10) - - # apply batch operations - data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) - - return data_set - - -def _get_rank_info(): - """ - get rank size and rank id - """ - rank_size = int(os.environ.get("RANK_SIZE", 1)) - - if rank_size > 1: - from mindspore.communication.management import get_rank, get_group_size - rank_size = get_group_size() - rank_id = get_rank() - else: - rank_size = rank_id = None - - return rank_size, rank_id diff --git a/model_zoo/googlenet/train.py b/model_zoo/googlenet/train.py deleted file mode 100644 index 0129176510..0000000000 --- a/model_zoo/googlenet/train.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -#################train googlent example on cifar10######################## -python train.py -""" -import argparse -import os -import random - -import numpy as np - -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import context -from mindspore.communication.management import init -from mindspore.nn.optim.momentum import Momentum -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.model import Model, ParallelMode -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from src.config import cifar_cfg as cfg -from src.dataset import create_dataset -from src.googlenet import GoogleNet - -random.seed(1) -np.random.seed(1) - - -def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): - """Set learning rate.""" - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] - for i in range(total_steps): - if i < decay_epoch_index[0]: - lr_each_step.append(lr_max) - elif i < decay_epoch_index[1]: - lr_each_step.append(lr_max * 0.1) - elif i < decay_epoch_index[2]: - lr_each_step.append(lr_max * 0.01) - else: - lr_each_step.append(lr_max * 0.001) - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] - - return learning_rate - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Cifar10 classification') - parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') - args_opt = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) - if args_opt.device_id is not None: - context.set_context(device_id=args_opt.device_id) - else: - context.set_context(device_id=cfg.device_id) - - device_num = int(os.environ.get("DEVICE_NUM", 1)) - if device_num > 1: - context.reset_auto_parallel_context() - context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - init() - - dataset = create_dataset(cfg.data_path, cfg.epoch_size) - batch_num = dataset.get_dataset_size() - - net = GoogleNet(num_classes=cfg.num_classes) - # Continue training if set pre_trained to be True - if cfg.pre_trained: - param_dict = load_checkpoint(cfg.checkpoint_path) - load_param_into_net(net, param_dict) - lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) - - config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) - time_cb = TimeMonitor(data_size=batch_num) - ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", config=config_ck) - loss_cb = LossMonitor() - model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) - print("train success") diff --git a/model_zoo/lenet/train.py b/model_zoo/lenet/train.py deleted file mode 100644 index 740b6e8ca3..0000000000 --- a/model_zoo/lenet/train.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -######################## train lenet example ######################## -train lenet and get network model files(.ckpt) : -python train.py --data_path /YourDataPath -""" - -import os -import argparse -from src.config import mnist_cfg as cfg -from src.dataset import create_dataset -from src.lenet import LeNet5 -import mindspore.nn as nn -from mindspore import context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='MindSpore Lenet Example') - parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], - help='device where the code will be implemented (default: Ascend)') - parser.add_argument('--data_path', type=str, default="./Data", - help='path where the dataset is saved') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') - - args = parser.parse_args() - - if args.device_target == "CPU": - args.dataset_sink_mode = False - - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset(os.path.join(args.data_path, "train"), - cfg.batch_size, - cfg.epoch_size) - - network = LeNet5(cfg.num_classes) - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) - model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - - print("============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) diff --git a/model_zoo/lenet_quant/src/loss_monitor.py b/model_zoo/lenet_quant/src/loss_monitor.py deleted file mode 100644 index 59c222d23d..0000000000 --- a/model_zoo/lenet_quant/src/loss_monitor.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""LossMonitor Callback class.""" - -import time -import numpy as np -from mindspore.common.tensor import Tensor -from mindspore.train.callback import Callback - - -class LossMonitor(Callback): - """ - Monitor the loss in training. - - If the loss is NAN or INF, it will terminate training. - - Note: - If per_print_times is 0 do not print loss. - - Args: - per_print_times (int): Print loss every times. Default: 1. - lr_init (numpy array): train learning rate. Default: None. - - Raises: - ValueError: If print_step is not int or less than zero. - - Examples: - >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) - """ - - def __init__(self, per_print_times=1, lr_init=None): - super(LossMonitor, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0.") - self._per_print_times = per_print_times - self.lr_init = lr_init - - def epoch_begin(self, run_context): - self.losses = [] - self.epoch_time = time.time() - - def epoch_end(self, run_context): - cb_params = run_context.original_args() - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / cb_params.batch_num - print("Epoch time: {:5.3f}, per step time: {:5.3f}, " - "avg loss: {:5.3f}".format(epoch_mseconds, - per_step_mseconds, - np.mean(self.losses))) - print("*" * 60) - - def step_begin(self, run_context): - self.step_time = time.time() - - def step_end(self, run_context): - cb_params = run_context.original_args() - step_mseconds = (time.time() - self.step_time) * 1000 - step_loss = cb_params.net_outputs - - if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): - step_loss = step_loss[0] - if isinstance(step_loss, Tensor): - step_loss = np.mean(step_loss.asnumpy()) - - self.losses.append(step_loss) - cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 - - if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): - raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " - "Invalid loss, terminating training.".format( - cb_params.cur_epoch_num - 1, cb_params.epoch_num, - cur_step_in_epoch, cb_params.batch_num)) - - if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: - print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " - "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}ms]".format( - cb_params.cur_epoch_num, cb_params.epoch_num, - cur_step_in_epoch, int(cb_params.batch_num), - step_loss, np.mean(self.losses), - step_mseconds), flush=True) diff --git a/model_zoo/lenet_quant/train.py b/model_zoo/lenet_quant/train.py deleted file mode 100644 index 03e9ff62bd..0000000000 --- a/model_zoo/lenet_quant/train.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -######################## train lenet example ######################## -train lenet and get network model files(.ckpt) : -python train.py --data_path /YourDataPath -""" - -import os -import argparse -import mindspore.nn as nn -from mindspore import context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy -from src.dataset import create_dataset -from src.config import mnist_cfg as cfg -from src.lenet_fusion import LeNet5 as LeNet5Fusion -from src.loss_monitor import LossMonitor - -parser = argparse.ArgumentParser(description='MindSpore MNIST Example') -parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU'], - help='device where the code will be implemented (default: Ascend)') -parser.add_argument('--data_path', type=str, default="./MNIST_Data", - help='path where the dataset is saved') -parser.add_argument('--ckpt_path', type=str, default="", - help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') -args = parser.parse_args() - -if __name__ == "__main__": - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) - step_size = ds_train.get_dataset_size() - - # define fusion network - network = LeNet5Fusion(cfg.num_classes) - # define network loss - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - # define network optimization - net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - - # call back and monitor - config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) - - # define model - model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - - print("============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) - print("============== End Training ==============") diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py deleted file mode 100644 index 3a87ccc70d..0000000000 --- a/model_zoo/lenet_quant/train_quant.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -######################## train lenet example ######################## -train lenet and get network model files(.ckpt) : -python train.py --data_path /YourDataPath -""" - -import os -import argparse -import mindspore.nn as nn -from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy -from mindspore.train.quant import quant -from src.dataset import create_dataset -from src.config import mnist_cfg as cfg -from src.lenet_fusion import LeNet5 as LeNet5Fusion -from src.loss_monitor import LossMonitor - -parser = argparse.ArgumentParser(description='MindSpore MNIST Example') -parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU'], - help='device where the code will be implemented (default: Ascend)') -parser.add_argument('--data_path', type=str, default="./MNIST_Data", - help='path where the dataset is saved') -parser.add_argument('--ckpt_path', type=str, default="", - help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') -args = parser.parse_args() - -if __name__ == "__main__": - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) - step_size = ds_train.get_dataset_size() - - # define fusion network - network = LeNet5Fusion(cfg.num_classes) - - # load quantization aware network checkpoint - param_dict = load_checkpoint(args.ckpt_path) - load_param_into_net(network, param_dict) - - # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) - - # define network loss - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - # define network optimization - net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - - # call back and monitor - config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) - - # define model - model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - - print("============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) - print("============== End Training ==============") diff --git a/model_zoo/lstm/train.py b/model_zoo/lstm/train.py deleted file mode 100644 index 51ae12c685..0000000000 --- a/model_zoo/lstm/train.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -#################train lstm example on aclImdb######################## -python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path -""" -import argparse -import os - -import numpy as np - -from src.config import lstm_cfg as cfg -from src.dataset import convert_to_mindrecord -from src.dataset import lstm_create_dataset -from src.lstm import SentimentNet -from mindspore import Tensor, nn, Model, context -from mindspore.nn import Accuracy -from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor -from mindspore.train.serialization import load_param_into_net, load_checkpoint - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='MindSpore LSTM Example') - parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'], - help='whether to preprocess data.') - parser.add_argument('--aclimdb_path', type=str, default="./aclImdb", - help='path where the dataset is stored.') - parser.add_argument('--glove_path', type=str, default="./glove", - help='path where the GloVe is stored.') - parser.add_argument('--preprocess_path', type=str, default="./preprocess", - help='path where the pre-process data is stored.') - parser.add_argument('--ckpt_path', type=str, default="./", - help='the path to save the checkpoint file.') - parser.add_argument('--pre_trained', type=str, default=None, - help='the pretrained checkpoint file path.') - parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'], - help='the target device to run, support "GPU", "CPU". Default: "GPU".') - args = parser.parse_args() - - context.set_context( - mode=context.GRAPH_MODE, - save_graphs=False, - device_target=args.device_target) - - if args.preprocess == "true": - print("============== Starting Data Pre-processing ==============") - convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path) - - embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32) - network = SentimentNet(vocab_size=embedding_table.shape[0], - embed_size=cfg.embed_size, - num_hiddens=cfg.num_hiddens, - num_layers=cfg.num_layers, - bidirectional=cfg.bidirectional, - num_classes=cfg.num_classes, - weight=Tensor(embedding_table), - batch_size=cfg.batch_size) - # pre_trained - if args.pre_trained: - load_param_into_net(network, load_checkpoint(args.pre_trained)) - - loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum) - loss_cb = LossMonitor() - - model = Model(network, loss, opt, {'acc': Accuracy()}) - - print("============== Starting Training ==============") - ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) - time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - if args.device_target == "CPU": - model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False) - else: - model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb]) - print("============== Training Success ==============") diff --git a/model_zoo/mass/README.md b/model_zoo/mass/README.md deleted file mode 100644 index d6b1c29186..0000000000 --- a/model_zoo/mass/README.md +++ /dev/null @@ -1,592 +0,0 @@ -![](https://www.mindspore.cn/static/img/logo.a3e472c9.png) - - - -- [MASS: Masked Sequence to Sequence Pre-training for Language Generation Description](#googlenet-description) -- [Model architecture](#model-architecture) -- [Dataset](#dataset) -- [Features](#features) -- [Script description](#script-description) - - [Data Preparation](#Data-Preparation) - - [Tokenization](#Tokenization) - - [Byte Pair Encoding](#Byte-Pair-Encoding) - - [Build Vocabulary](#Build-Vocabulary) - - [Generate Dataset](#Generate-Dataset) - - [News Crawl Corpus](#News-Crawl-Corpus) - - [Gigaword Corpus](#Gigaword-Corpus) - - [Cornell Movie Dialog Corpus](#Cornell-Movie-Dialog-Corpus) - - [Configuration](#Configuration) - - [Training & Evaluation process](#Training-&-Evaluation-process) - - [Weights average](#Weights-average) - - [Learning rate scheduler](#Learning-rate-scheduler) -- [Model description](#model-description) - - [Performance](#performance) - - [Results](#results) - - [Training Performance](#training-performance) - - [Inference Performance](#inference-performance) -- [Environment Requirements](#environment-requirements) - - [Platform](#Platform) - - [Requirements](#Requirements) -- [Get started](#get-started) - - [Pre-training](#Pre-training) - - [Fine-tuning](#Fine-tuning) - - [Inference](#Inference) -- [Description of random situation](#description-of-random-situation) -- [others](#others) -- [ModelZoo Homepage](#modelzoo-homepage) - - - - -# MASS: Masked Sequence to Sequence Pre-training for Language Generation Description - -[MASS: Masked Sequence to Sequence Pre-training for Language Generation](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf) was released by MicroSoft in June 2019. - -BERT(Devlin et al., 2018) have achieved SOTA in natural language understanding area by pre-training the encoder part of Transformer(Vaswani et al., 2017) with masked rich-resource text. Likewise, GPT(Raddford et al., 2018) pre-trains the decoder part of Transformer with masked(encoder inputs are masked) rich-resource text. Both of them build a robust language model by pre-training with masked rich-resource text. - -Inspired by BERT, GPT and other language models, MicroSoft addressed [MASS: Masked Sequence to Sequence Pre-training for Language Generation](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf) which combines BERT's and GPT's idea. MASS has an important parameter k, which controls the masked fragment length. BERT and GPT are specicl case when k equals to 1 and sentence length. - -[Introducing MASS – A pre-training method that outperforms BERT and GPT in sequence to sequence language generation tasks](https://www.microsoft.com/en-us/research/blog/introducing-mass-a-pre-training-method-that-outperforms-bert-and-gpt-in-sequence-to-sequence-language-generation-tasks/) - -[Paper](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf): Song, Kaitao, Xu Tan, Tao Qin, Jianfeng Lu and Tie-Yan Liu. “MASS: Masked Sequence to Sequence Pre-training for Language Generation.” ICML (2019). - - -# Model architecture - -The overall network architecture of MASS is shown below, which is Transformer(Vaswani et al., 2017): - -MASS is consisted of 6-layer encoder and 6-layer decoder with 1024 embedding/hidden size, and 4096 intermediate size between feed forward network which has two full connection layers. - -![Transformer architecture](https://cdn.analyticsvidhya.com/wp-content/uploads/2019/06/Screenshot-from-2019-06-17-19-53-10.png) - - -# Dataset - -Dataset used: -- monolingual English data from News Crawl dataset(WMT 2019) for pre-training. -- Gigaword Corpus(Graff et al., 2003) for Text Summarization. -- Cornell movie dialog corpus(DanescuNiculescu-Mizil & Lee, 2011). - -Details about those dataset could be found in [MASS: Masked Sequence to Sequence Pre-training for Language Generation](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf). - - -# Features - -Mass is designed to jointly pre train encoder and decoder to complete the task of language generation. -First of all, through a sequence to sequence framework, mass only predicts the blocked token, which forces the encoder to understand the meaning of the unshielded token, and encourages the decoder to extract useful information from the encoder. -Secondly, by predicting the continuous token of the decoder, the decoder can build better language modeling ability than only predicting discrete token. -Third, by further shielding the input token of the decoder which is not shielded in the encoder, the decoder is encouraged to extract more useful information from the encoder side, rather than using the rich information in the previous token. - - -# Script description - -MASS script and code structure are as follow: - -```text -├── mass - ├── README.md // Introduction of MASS model. - ├── config - │ ├──config.py // Configuration instance definition. - │ ├──config.json // Configuration file. - ├── src - │ ├──dataset - │ ├──bi_data_loader.py // Dataset loader for fine-tune or inferring. - │ ├──mono_data_loader.py // Dataset loader for pre-training. - │ ├──language_model - │ ├──noise_channel_language_model.p // Noisy channel language model for dataset generation. - │ ├──mass_language_model.py // MASS language model according to MASS paper. - │ ├──loose_masked_language_model.py // MASS language model according to MASS released code. - │ ├──masked_language_model.py // Masked language model according to MASS paper. - │ ├──transformer - │ ├──create_attn_mask.py // Generate mask matrix to remove padding positions. - │ ├──transformer.py // Transformer model architecture. - │ ├──encoder.py // Transformer encoder component. - │ ├──decoder.py // Transformer decoder component. - │ ├──self_attention.py // Self-Attention block component. - │ ├──multi_head_attention.py // Multi-Head Self-Attention component. - │ ├──embedding.py // Embedding component. - │ ├──positional_embedding.py // Positional embedding component. - │ ├──feed_forward_network.py // Feed forward network. - │ ├──residual_conn.py // Residual block. - │ ├──beam_search.py // Beam search decoder for inferring. - │ ├──transformer_for_infer.py // Use Transformer to infer. - │ ├──transformer_for_train.py // Use Transformer to train. - │ ├──utils - │ ├──byte_pair_encoding.py // Apply BPE with subword-nmt. - │ ├──dictionary.py // Dictionary. - │ ├──loss_moniter.py // Callback of monitering loss during training step. - │ ├──lr_scheduler.py // Learning rate scheduler. - │ ├──ppl_score.py // Perplexity score based on N-gram. - │ ├──rouge_score.py // Calculate ROUGE score. - │ ├──load_weights.py // Load weights from a checkpoint or NPZ file. - │ ├──initializer.py // Parameters initializer. - ├── vocab - │ ├──all.bpe.codes // BPE codes table(this file should be generated by user). - │ ├──all_en.dict.bin // Learned vocabulary file(this file should be generated by user). - ├── scripts - │ ├──run.sh // Train & evaluate model script. - │ ├──learn_subword.sh // Learn BPE codes. - │ ├──stop_training.sh // Stop training. - ├── requirements.txt // Requirements of third party package. - ├── train.py // Train API entry. - ├── eval.py // Infer API entry. - ├── tokenize_corpus.py // Corpus tokenization. - ├── apply_bpe_encoding.py // Applying bpe encoding. - ├── weights_average.py // Average multi model checkpoints to NPZ format. - ├── news_crawl.py // Create News Crawl dataset for pre-training. - ├── gigaword.py // Create Gigaword Corpus. - ├── cornell_dialog.py // Create Cornell Movie Dialog dataset for conversation response. - -``` - - -## Data Preparation - -The data preparation of a natural language processing task contains data cleaning, tokenization, encoding and vocabulary generation steps. - -In our experiments, using [Byte Pair Encoding(BPE)](https://arxiv.org/abs/1508.07909) could reduce size of vocabulary, and relieve the OOV influence effectively. - -Vocabulary could be created using `src/utils/dictionary.py` with text dictionary which is learnt from BPE. -For more detail about BPE, please refer to [Subword-nmt lib](https://www.cnpython.com/pypi/subword-nmt) or [paper](https://arxiv.org/abs/1508.07909). - -In our experiments, vocabulary was learned based on 1.9M sentences from News Crawl Dataset, size of vocabulary is 45755. - -Here, we have a brief introduction of data preparation scripts. - - -### Tokenization -Using `tokenize_corpus.py` could tokenize corpus whose text files are in format of `.txt`. - -Major parameters in `tokenize_corpus.py`: - -```bash ---corpus_folder: Corpus folder path, if multi-folders are provided, use ',' split folders. ---output_folder: Output folder path. ---tokenizer: Tokenizer to be used, nltk or jieba, if nltk is not installed fully, use jieba instead. ---pool_size: Processes pool size. -``` - -Sample code: -```bash -python tokenize_corpus.py --corpus_folder /{path}/corpus --output_folder /{path}/tokenized_corpus --tokenizer {nltk|jieba} --pool_size 16 -``` - - -### Byte Pair Encoding -After tokenization, BPE is applied to tokenized corpus with provided `all.bpe.codes`. - -Apply BPE script can be found in `apply_bpe_encoding.py`. - -Major parameters in `apply_bpe_encoding.py`: - -```bash ---codes: BPE codes file. ---src_folder: Corpus folders. ---output_folder: Output files folder. ---prefix: Prefix of text file in `src_folder`. ---vocab_path: Generated vocabulary output path. ---threshold: Filter out words that frequency is lower than threshold. ---processes: Size of process pool (to accelerate). Default: 2. -``` - -Sample code: -```bash -python tokenize_corpus.py --codes /{path}/all.bpe.codes \ - --src_folder /{path}/tokenized_corpus \ - --output_folder /{path}/tokenized_corpus/bpe \ - --prefix tokenized \ - --vocab_path /{path}/vocab_en.dict.bin - --processes 32 -``` - - -### Build Vocabulary -Support that you want to create a new vocabulary, there are two options: -1. Learn BPE codes from scratch, and create vocabulary with multi vocabulary files from `subword-nmt`. -2. Create from an existing vocabulary file which lines in the format of `word frequency`. -3. *Optional*, Create a small vocabulary based on `vocab/all_en.dict.bin` with method of `shink` from `src/utils/dictionary.py`. -4. Persistent vocabulary to `vocab` folder with method `persistence()`. - -Major interface of `src/utils/dictionary.py` are as follow: - -1. `shrink(self, threshold=50)`: Shrink the size of vocabulary by filter out words frequency is lower than threshold. It returns a new vocabulary. -2. `load_from_text(cls, filepaths: List[str])`: Load existed text vocabulary which lines in the format of `word frequency`. -3. `load_from_persisted_dict(cls, filepath)`: Load from a persisted binary vocabulary which was saved by calling `persistence()` method. -4. `persistence(self, path)`: Save vocabulary object to binary file. - -Sample code: -```python -from src.utils import Dictionary - -vocabulary = Dictionary.load_from_persisted_dict("vocab/all_en.dict.bin") -tokens = [1, 2, 3, 4, 5] -# Convert ids to symbols. -print([vocabulary[t] for t in tokens]) - -sentence = ["Hello", "world"] -# Convert symbols to ids. -print([vocabulary.index[s] for s in sentence]) -``` - -For more detail, please refer to the source file. - - -### Generate Dataset -As mentioned above, three corpus are used in MASS mode, dataset generation scripts for them are provided. - -#### News Crawl Corpus -Script can be found in `news_crawl.py`. - -Major parameters in `news_crawl.py`: - -```bash -Note that please provide `--existed_vocab` or `--dict_folder` at least one. -A new vocabulary would be created in `output_folder` when pass `--dict_folder`. - ---src_folder: Corpus folders. ---existed_vocab: Optional, persisted vocabulary file. ---mask_ratio: Ratio of mask. ---output_folder: Output dataset files folder path. ---max_len: Maximum sentence length. If a sentence longer than `max_len`, then drop it. ---suffix: Optional, suffix of generated dataset files. ---processes: Optional, size of process pool (to accelerate). Default: 2. -``` - -Sample code: - -```bash -python news_crawl.py --src_folder /{path}/news_crawl \ - --existed_vocab /{path}/mass/vocab/all_en.dict.bin \ - --mask_ratio 0.5 \ - --output_folder /{path}/news_crawl_dataset \ - --max_len 32 \ - --processes 32 -``` - - -#### Gigaword Corpus -Script can be found in `gigaword.py`. - -Major parameters in `gigaword.py`: - -```bash ---train_src: Train source file path. ---train_ref: Train reference file path. ---test_src: Test source file path. ---test_ref: Test reference file path. ---existed_vocab: Persisted vocabulary file. ---output_folder: Output dataset files folder path. ---noise_prob: Optional, add noise prob. Default: 0. ---max_len: Optional, maximum sentence length. If a sentence longer than `max_len`, then drop it. Default: 64. ---format: Optional, dataset format, "mindrecord" or "tfrecord". Default: "tfrecord". -``` - -Sample code: - -```bash -python gigaword.py --train_src /{path}/gigaword/train_src.txt \ - --train_ref /{path}/gigaword/train_ref.txt \ - --test_src /{path}/gigaword/test_src.txt \ - --test_ref /{path}/gigaword/test_ref.txt \ - --existed_vocab /{path}/mass/vocab/all_en.dict.bin \ - --noise_prob 0.1 \ - --output_folder /{path}/gigaword_dataset \ - --max_len 64 -``` - - -#### Cornell Movie Dialog Corpus -Script can be found in `cornell_dialog.py`. - -Major parameters in `cornell_dialog.py`: - -```bash ---src_folder: Corpus folders. ---existed_vocab: Persisted vocabulary file. ---train_prefix: Train source and target file prefix. Default: train. ---test_prefix: Test source and target file prefix. Default: test. ---output_folder: Output dataset files folder path. ---max_len: Maximum sentence length. If a sentence longer than `max_len`, then drop it. ---valid_prefix: Optional, Valid source and target file prefix. Default: valid. -``` - -Sample code: - -```bash -python cornell_dialog.py --src_folder /{path}/cornell_dialog \ - --existed_vocab /{path}/mass/vocab/all_en.dict.bin \ - --train_prefix train \ - --test_prefix test \ - --noise_prob 0.1 \ - --output_folder /{path}/cornell_dialog_dataset \ - --max_len 64 -``` - - -## Configuration -Json file under the path `config/` is the template configuration file. -Almost all of the options and arguments needed could be assigned conveniently, including the training platform, configurations of dataset and model, arguments of optimizer etc. Optional features such as loss scale and checkpoint are also available by setting the options correspondingly. -For more detailed information about the attributes, refer to the file `config/config.py`. - -## Training & Evaluation process -For training a model, the shell script `run.sh` is all you need. In this scripts, the environment variable is set and the training script `train.py` under `mass` is executed. -You may start a task training with single device or multiple devices by assigning the options and run the command in bash: -```bash -sh run.sh [--options] -``` - -The usage is shown as bellow: -```text -Usage: run.sh [-h, --help] [-t, --task ] [-n, --device_num ] - [-i, --device_id ] [-j, --hccl_json ] - [-c, --config ] [-o, --output ] - [-v, --vocab ] - -options: - -h, --help show usage - -t, --task select task: CHAR, 't' for train and 'i' for inference". - -n, --device_num device number used for training: N, default is 1. - -i, --device_id device id used for training with single device: N, 0<=N<=7, default is 0. - -j, --hccl_json rank table file used for training with multiple devices: FILE. - -c, --config configuration file as shown in the path 'mass/config': FILE. - -o, --output assign output file of inference: FILE. - -v, --vocab set the vocabulary" -``` -Notes: Be sure to assign the hccl_json file while running a distributed-training. - -The command followed shows a example for training with 2 devices. -```bash -sh run.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json -``` -ps. Discontinuous device id is not supported in `run.sh` at present, device id in `rank_table.json` must start from 0. - - -If use a single chip, it would be like this: -```bash -sh run.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json -``` - - -## Weights average - -```python -python weights_average.py --input_files your_checkpoint_list --output_file model.npz -``` - -The input_files is a list of you checkpoints file. To use model.npz as the weights, add its path in config.json at "existed_ckpt". -```json -{ - ... - "checkpoint_options": { - "existed_ckpt": "/xxx/xxx/model.npz", - "save_ckpt_steps": 1000, - ... - }, - ... -} -``` - - -## Learning rate scheduler - -Two learning rate scheduler are provided in our model: - -1. [Polynomial decay scheduler](https://towardsdatascience.com/learning-rate-schedules-and-adaptive-learning-rate-methods-for-deep-learning-2c8f433990d1). -2. [Inverse square root scheduler](https://ece.uwaterloo.ca/~dwharder/aads/Algorithms/Inverse_square_root/). - -LR scheduler could be config in `config/config.json`. - -For Polynomial decay scheduler, config could be like: -```json -{ - ... - "learn_rate_config": { - "optimizer": "adam", - "lr": 1e-4, - "lr_scheduler": "poly", - "poly_lr_scheduler_power": 0.5, - "decay_steps": 10000, - "warmup_steps": 2000, - "min_lr": 1e-6 - }, - ... -} -``` - -For Inverse square root scheduler, config could be like: -```json -{ - ... - "learn_rate_config": { - "optimizer": "adam", - "lr": 1e-4, - "lr_scheduler": "isr", - "decay_start_step": 12000, - "warmup_steps": 2000, - "min_lr": 1e-6 - }, - ... -} -``` - -More detail about LR scheduler could be found in `src/utils/lr_scheduler.py`. - - -# Model description - -The MASS network is implemented by Transformer, which has multi-encoder layers and multi-decoder layers. -For pre-training, we use the Adam optimizer and loss-scale to get the pre-trained model. -During fine-turning, we fine-tune this pre-trained model with different dataset according to different tasks. -During testing, we use the fine-turned model to predict the result, and adopt a beam search algorithm to -get the most possible prediction results. - - -![MASS framework](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-Fig-2.png) - - -## Performance - -### Results - -#### Fine-Tuning on Text Summarization -The comparisons between MASS and two other pre-training methods in terms of ROUGE score on the text summarization task -with 3.8M training data are as follows: - -| Method | RG-1(F) | RG-2(F) | RG-L(F) | -|:---------------|:--------------|:-------------|:-------------| -| MASS | Ongoing | Ongoing | Ongoing | - -#### Fine-Tuning on Conversational ResponseGeneration -The comparisons between MASS and other baseline methods in terms of PPL on Cornell Movie Dialog corpus are as follows: - -| Method | Data = 10K | Data = 110K | -|--------------------|------------------|-----------------| -| MASS | Ongoing | Ongoing | - -#### Training Performance - -| Parameters | Masked Sequence to Sequence Pre-training for Language Generation | -|:---------------------------|:--------------------------------------------------------------------------| -| Model Version | v1 | -| Resource | Ascend 910, cpu 2.60GHz, 56cores;memory, 314G | -| uploaded Date | 05/24/2020 | -| MindSpore Version | 0.2.0 | -| Dataset | News Crawl 2007-2017 English monolingual corpus, Gigaword corpus, Cornell Movie Dialog corpus | -| Training Parameters | Epoch=50, steps=XXX, batch_size=192, lr=1e-4 | -| Optimizer | Adam | -| Loss Function | Label smoothed cross-entropy criterion | -| outputs | Sentence and probability | -| Loss | Lower than 2 | -| Accuracy | For conversation response, ppl=23.52, for text summarization, RG-1=29.79. | -| Speed | 611.45 sentences/s | -| Total time | --/-- | -| Params (M) | 44.6M | -| Checkpoint for Fine tuning | ---Mb, --, [A link]() | -| Model for inference | ---Mb, --, [A link]() | -| Scripts | [A link]() | - - -#### Inference Performance - -| Parameters | Masked Sequence to Sequence Pre-training for Language Generation | -|:---------------------------|:-----------------------------------------------------------| -| Model Version | V1 | -| Resource | Huawei 910 | -| uploaded Date | 05/24/2020 | -| MindSpore Version | 0.2.0 | -| Dataset | Gigaword corpus, Cornell Movie Dialog corpus | -| batch_size | --- | -| outputs | Sentence and probability | -| Accuracy | ppl=23.52 for conversation response, RG-1=29.79 for text summarization. | -| Speed | ---- sentences/s | -| Total time | --/-- | -| Model for inference | ---Mb, --, [A link]() | - - -# Environment Requirements - -## Platform - -- Hardware(Ascend) - - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you could get the resources for trial. -- Framework - - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) -- For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - -## Requirements - -```txt -nltk -numpy -subword-nmt -rouge -``` - -https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html - - -# Get started -MASS pre-trains a sequence to sequence model by predicting the masked fragments in an input sequence. After this, downstream tasks including text summarization and conversation response are candidated for fine-tuning the model and for inference. -Here we provide a practice example to demonstrate the basic usage of MASS for pre-training, fine-tuning a model, and the inference process. The overall process is as follows: -1. Download and process the dataset. -2. Modify the `config.json` to config the network. -3. Run a task for pre-training and fine-tuning. -4. Perform inference and validation. - -## Pre-training -For pre-training a model, config the options in `config.json` firstly: -- Assign the `pre_train_dataset` under `dataset_config` node to the dataset path. -- Choose the optimizer('momentum/adam/lamb' is available). -- Assign the 'ckpt_prefix' and 'ckpt_path' under `checkpoint_path` to save the model files. -- Set other arguments including dataset configurations and network configurations. -- If you have a trained model already, assign the `existed_ckpt` to the checkpoint file. - -Run the shell script `run.sh` as followed: - -```bash -sh run.sh -t t -n 1 -i 1 -c /mass/config/config.json -``` -Get the log and output files under the path `./run_mass_*/`, and the model file under the path assigned in the `config/config.json` file. - -## Fine-tuning -For fine-tuning a model, config the options in `config.json` firstly: -- Assign the `fine_tune_dataset` under `dataset_config` node to the dataset path. -- Assign the `existed_ckpt` under `checkpoint_path` node to the existed model file generated by pre-training. -- Choose the optimizer('momentum/adam/lamb' is available). -- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files. -- Set other arguments including dataset configurations and network configurations. - -Run the shell script `run.sh` as followed: -```bash -sh run.sh -t t -n 1 -i 1 -c config/config.json -``` -Get the log and output files under the path `./run_mass_*/`, and the model file under the path assigned in the `config/config.json` file. - -## Inference -If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). -For inference, config the options in `config.json` firstly: -- Assign the `test_dataset` under `dataset_config` node to the dataset path. -- Assign the `existed_ckpt` under `checkpoint_path` node to the model file produced by fine-tuning. -- Choose the optimizer('momentum/adam/lamb' is available). -- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files. -- Set other arguments including dataset configurations and network configurations. - -Run the shell script `run.sh` as followed: - -```bash -sh run.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile} -``` - -# Description of random situation - -MASS model contains dropout operations, if you want to disable dropout, please set related dropout_rate to 0 in `config/config.json`. - - -# others -The model has been validated on Ascend environment, not validated on CPU and GPU. - - -# ModelZoo Homepage - [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/mass/scripts/run.sh b/model_zoo/mass/scripts/run.sh deleted file mode 100644 index 132e38dae2..0000000000 --- a/model_zoo/mass/scripts/run.sh +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -export DEVICE_ID=0 -export RANK_ID=0 -export RANK_SIZE=1 - -options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"` -eval set -- "$options" -echo $options - -echo_help() -{ - echo "Usage:" - echo "bash train.sh [-h] [-t t|i] [-n N] [-i N] [-j FILE] [-c FILE] [-o FILE] [-v FILE]" - echo "options:" - echo " -h --help show usage" - echo " -t --task select task, 't' for training and 'i' for inference" - echo " -n --device_num training with N devices" - echo " -i --device_id training with device i" - echo " -j --hccl_json set the rank table file" - echo " -c --config set the configuration file" - echo " -o --output set the output file of inference" - echo " -v --vocab set the vocabulary" - echo " -m --metric set the metric" -} - -set_hccl_json() -{ - while [ -n "$1" ] - do - if [[ "$1" == "-j" || "$1" == "--hccl_json" ]] - then - export MINDSPORE_HCCL_CONFIG_PATH=$2 - export RANK_TABLE_FILE=$2 - break - fi - shift - done -} -set_device_id() -{ - while [ -n "$1" ] - do - if [[ "$1" == "-i" || "$1" == "--device_id" ]] - then - if [[ $2 -ge 0 && $2 -le 7 ]] - then - export DEVICE_ID=$2 - fi - break - fi - shift - done -} - -while [ -n "$1" ] -do - case "$1" in - -h|--help) - echo_help - shift - ;; - -t|--task) - echo "task:" - if [ "$2" == "t" ] - then - task=train - elif [ "$2" == "i" ] - then - task=infer - fi - shift 2 - ;; - -n|--device_num) - echo "device_num" - if [ $2 -eq 1 ] - then - set_device_id $options - elif [ $2 -gt 1 ] - then - export HCCL_FLAG=1 - export DEPLOY_MODE=0 - - export RANK_SIZE=$2 - set_hccl_json $options - fi - shift 2 - ;; - -i|--device_id) - echo "set device id" - export DEVICE_ID=$2 - shift 2 - ;; - -c|--config) - echo "config"; - configurations=$2 - shift 2 - ;; - -o|--output) - echo "output"; - output=$2 - shift 2 - ;; - -v|--vocab) - echo "vocab"; - vocab=$2 - shift 2 - ;; - -m|--metric) - echo "metric"; - metric=$2 - shift 2 - ;; - --) - shift - break - ;; - *) - shift - ;; -esac -done - -file_path=$(cd "$(dirname $0)" || exit; pwd) -for((i=0; i < $RANK_SIZE; i++)) -do - if [ $RANK_SIZE -gt 1 ] - then - echo $RANK_SIZE - export RANK_ID=$i - export DEVICE_ID=$[i] - fi - echo "Working on device $i" - - cd $file_path || exit - cd ../ || exit - - rm -rf ./run_mass_$DEVICE_ID - mkdir ./run_mass_$DEVICE_ID - - cp train.py ./run_mass_$DEVICE_ID - cp eval.py ./run_mass_$DEVICE_ID - cp $configurations ./run_mass_$DEVICE_ID - - if [ $vocab ] - then - cp $vocab ./run_mass_$DEVICE_ID - fi - - cd ./run_mass_$DEVICE_ID || exit - env > log.log - echo $task - if [ "$task" == "train" ] - then - python train.py --config ${configurations##*/} >>log.log 2>&1 & - elif [ "$task" == "infer" ] - then - python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 & - fi - cd ../ -done diff --git a/model_zoo/mass/src/dataset/load_dataset.py b/model_zoo/mass/src/dataset/load_dataset.py deleted file mode 100644 index 53ad5c7491..0000000000 --- a/model_zoo/mass/src/dataset/load_dataset.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Dataset loader to feed into model.""" -import mindspore.common.dtype as mstype -import mindspore.dataset.engine as de -import mindspore.dataset.transforms.c_transforms as deC - - -def _load_dataset(input_files, batch_size, epoch_count=1, - sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True): - """ - Load dataset according to passed in params. - - Args: - input_files (list): Data files. - batch_size (int): Batch size. - epoch_count (int): Epoch count. - sink_mode (bool): Whether enable sink mode. - sink_step (int): Step to sink. - rank_size (int): Rank size. - rank_id (int): Rank id. - shuffle (bool): Whether shuffle dataset. - - Returns: - Dataset, dataset instance. - """ - if not input_files: - raise FileNotFoundError("Require at least one dataset.") - - if not isinstance(sink_mode, bool): - raise ValueError("`sink` must be type of bool.") - - for datafile in input_files: - print(f" | Loading {datafile}.") - - ds = de.TFRecordDataset( - input_files, - columns_list=[ - "src", "src_padding", - "prev_opt", "prev_padding", - "target", "tgt_padding" - ], - shuffle=shuffle, num_shards=rank_size, shard_id=rank_id, - shard_equal_rows=True, num_parallel_workers=8) - - ori_dataset_size = ds.get_dataset_size() - print(f" | Dataset size: {ori_dataset_size}.") - repeat_count = epoch_count - if sink_mode: - ds.set_dataset_size(sink_step * batch_size) - repeat_count = epoch_count * ori_dataset_size // ds.get_dataset_size() - - type_cast_op = deC.TypeCast(mstype.int32) - ds = ds.map(input_columns="src", operations=type_cast_op) - ds = ds.map(input_columns="src_padding", operations=type_cast_op) - ds = ds.map(input_columns="prev_opt", operations=type_cast_op) - ds = ds.map(input_columns="prev_padding", operations=type_cast_op) - ds = ds.map(input_columns="target", operations=type_cast_op) - ds = ds.map(input_columns="tgt_padding", operations=type_cast_op) - - ds = ds.rename( - input_columns=["src", - "src_padding", - "prev_opt", - "prev_padding", - "target", - "tgt_padding"], - output_columns=["source_eos_ids", - "source_eos_mask", - "target_sos_ids", - "target_sos_mask", - "target_eos_ids", - "target_eos_mask"] - ) - - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_count) - - ds.channel_name = 'transformer' - return ds - - -def load_dataset(data_files: list, batch_size: int, epoch_count: int, - sink_mode: bool, sink_step: int = 1, rank_size: int = 1, rank_id: int = 0, shuffle=True): - """ - Load dataset. - - Args: - data_files (list): Data files. - batch_size (int): Batch size. - epoch_count (int): Epoch count. - sink_mode (bool): Whether enable sink mode. - sink_step (int): Step to sink. - rank_size (int): Rank size. - rank_id (int): Rank id. - shuffle (bool): Whether shuffle dataset. - - Returns: - Dataset, dataset instance. - """ - return _load_dataset(data_files, batch_size, epoch_count, sink_mode, - sink_step, rank_size, rank_id, shuffle=shuffle) diff --git a/model_zoo/mass/src/transformer/transformer_for_infer.py b/model_zoo/mass/src/transformer/transformer_for_infer.py deleted file mode 100644 index 8b1a1c4667..0000000000 --- a/model_zoo/mass/src/transformer/transformer_for_infer.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Transformer for infer.""" -import math -import copy -import numpy as np -import mindspore.common.dtype as mstype -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.tensor import Tensor - -from .beam_search import BeamSearchDecoder, TileBeam -from .embedding import EmbeddingLookup -from .positional_embedding import PositionalEmbedding -from .components import SaturateCast -from .create_attn_mask import CreateAttentionMaskFromInputMask -from .decoder import TransformerDecoder -from .encoder import TransformerEncoder - - -class PredLogProbs(nn.Cell): - """ - Get log probs. - - Args: - batch_size (int): Batch size of input dataset. - seq_length (int): The length of sequences. - width (int): Number of parameters of a layer - compute_type (int): Type of input type. - dtype (int): Type of MindSpore output type. - """ - - def __init__(self, - batch_size, - seq_length, - width, - compute_type=mstype.float32, - dtype=mstype.float32): - super(PredLogProbs, self).__init__() - self.batch_size = batch_size - self.seq_length = seq_length - self.width = width - self.compute_type = compute_type - self.dtype = dtype - - self.reshape = P.Reshape() - self.matmul = P.MatMul(transpose_b=True) - self.log_softmax = nn.LogSoftmax(axis=-1) - self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width) - self.cast = P.Cast() - - def construct(self, input_tensor, output_weights): - """ - Calculate the log_softmax. - - Inputs: - input_tensor (Tensor): A batch of sentences with shape (N, T). - output_weights (Tensor): A batch of masks with shape (N, T). - - Returns: - Tensor, the prediction probability with shape (N, T'). - """ - input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) - input_tensor = self.cast(input_tensor, self.compute_type) - output_weights = self.cast(output_weights, self.compute_type) - - logits = self.matmul(input_tensor, output_weights) - logits = self.cast(logits, self.dtype) - - log_probs = self.log_softmax(logits) - return log_probs - - -class TransformerDecoderStep(nn.Cell): - """ - Multi-layer transformer decoder step. - - Args: - config (TransformerConfig): The config of Transformer. - num_hidden_layers (int): The numbers of hidden layers. - attn_embed_dim (int): Dimensions of attention weights. - num_attn_heads=12 (int): Heads number. - seq_length (int): The length of a sequence. - intermediate_size: Hidden size in FFN. - attn_dropout_prob (float): Dropout rate in attention. Default: 0.1. - initializer_range (float): Initial range. - hidden_dropout_prob (float): Dropout rate in FFN. - hidden_act (str): Activation function in FFN. - compute_type (mstype): Mindspore data type. Default: mstype.float32. - embedding_lookup (function): Embeddings lookup operation. Default: None. - positional_embedding (function): Position Embedding operation. Default: None. - projection (function): Function to get log probs. Default: None. - """ - - def __init__(self, - config, - num_hidden_layers, - attn_embed_dim, - num_attn_heads=12, - seq_length=64, - intermediate_size=3072, - attn_dropout_prob=0.1, - initializer_range=0.02, - hidden_dropout_prob=0.1, - hidden_act="relu", - compute_type=mstype.float32, - embedding_lookup=None, - positional_embedding=None, - projection=None): - super(TransformerDecoderStep, self).__init__(auto_prefix=False) - self.embedding_lookup = embedding_lookup - self.positional_embedding = positional_embedding - self.projection = projection - self.seq_length = seq_length - self.decoder = TransformerDecoder( - attn_embed_dim=attn_embed_dim, - num_attn_heads=num_attn_heads, - decoder_layers=num_hidden_layers, - intermediate_size=intermediate_size, - attn_dropout_prob=attn_dropout_prob, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - hidden_act=hidden_act, - compute_type=compute_type) - - self.ones_like = P.OnesLike() - self.shape = P.Shape() - - self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) - self.expand = P.ExpandDims() - self.multiply = P.Mul() - - ones = np.ones(shape=(seq_length, seq_length)) - self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) - - self.cast_compute_type = SaturateCast(dst_type=compute_type) - self.scale = Tensor([math.sqrt(float(attn_embed_dim))], dtype=mstype.float32) - - def construct(self, input_ids, enc_states, enc_attention_mask): - """ - Get log probs. - - Args: - input_ids: [batch_size * beam_width, m] - enc_states: [batch_size * beam_width, T, D] - enc_attention_mask: [batch_size * beam_width, T, D] - - Returns: - Tensor, the log_probs. [batch_size * beam_width, 1, Vocabulary_Dimension] - """ - - # process embedding. input_embedding: [batch_size * beam_width, m, D], embedding_tables: [V, D] - input_embedding, embedding_tables = self.embedding_lookup(input_ids) - input_embedding = self.multiply(input_embedding, self.scale) - input_embedding = self.positional_embedding(input_embedding) - input_embedding = self.cast_compute_type(input_embedding) - - input_shape = self.shape(input_ids) - input_len = input_shape[1] - # [m,m] - future_mask = self.future_mask[0:input_len:1, 0:input_len:1] - # [batch_size * beam_width, m] - input_mask = self.ones_like(input_ids) - # [batch_size * beam_width, m, m] - input_mask = self._create_attention_mask_from_input_mask(input_mask) - # [batch_size * beam_width, m, m] - input_mask = self.multiply(input_mask, self.expand(future_mask, 0)) - input_mask = self.cast_compute_type(input_mask) - - # [batch_size * beam_width, m, D] - enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::] - - # call TransformerDecoder: [batch_size * beam_width, m, D] - decoder_output = self.decoder(input_embedding, input_mask, enc_states, enc_attention_mask) - - # take the last step, [batch_size * beam_width, 1, D] - decoder_output = decoder_output[::, input_len - 1:input_len:1, ::] - - # projection and log_prob - log_probs = self.projection(decoder_output, embedding_tables) - - # [batch_size * beam_width, 1, vocabulary_size] - return log_probs - - -class TransformerInferModel(nn.Cell): - """ - Transformer Infer. - - Args: - config (TransformerConfig): The config of Transformer. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - """ - - def __init__(self, - config, - use_one_hot_embeddings=False): - super(TransformerInferModel, self).__init__() - config = copy.deepcopy(config) - config.hidden_dropout_prob = 0.0 - config.attention_dropout_prob = 0.0 - - self.input_mask_from_dataset = config.input_mask_from_dataset - self.batch_size = config.batch_size - self.seq_length = config.seq_length - self.hidden_size = config.hidden_size - self.num_hidden_layers = config.num_hidden_layers - self.embedding_size = config.hidden_size - self.attn_embed_dim = config.hidden_size - self.num_layers = config.num_hidden_layers - self.last_idx = self.num_hidden_layers - 1 - - self.embedding_lookup = EmbeddingLookup( - vocab_size=config.vocab_size, - embed_dim=self.embedding_size, - use_one_hot_embeddings=use_one_hot_embeddings) - - self.positional_embedding = PositionalEmbedding( - embedding_size=self.embedding_size, - max_position_embeddings=config.max_position_embeddings) - # use for infer - self.projection = PredLogProbs( - batch_size=config.batch_size * config.beam_width, - seq_length=1, - width=self.hidden_size, - compute_type=config.compute_type) - - self.encoder = TransformerEncoder( - attn_embed_dim=self.attn_embed_dim, - encoder_layers=self.num_layers, - num_attn_heads=config.num_attention_heads, - intermediate_size=config.intermediate_size, - attention_dropout_prob=config.attention_dropout_prob, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - hidden_act=config.hidden_act, - compute_type=config.compute_type) - - decoder_cell = TransformerDecoderStep( - config=config, - num_hidden_layers=config.num_hidden_layers, - attn_embed_dim=self.attn_embed_dim, - seq_length=config.seq_length, - num_attn_heads=config.num_attention_heads, - intermediate_size=config.intermediate_size, - hidden_dropout_prob=config.hidden_dropout_prob, - compute_type=config.compute_type, - initializer_range=config.initializer_range, - hidden_act="relu", - embedding_lookup=self.embedding_lookup, - positional_embedding=self.positional_embedding, - attn_dropout_prob=config.attention_dropout_prob, - projection=self.projection - ) - - # link beam_search after decoder - self.decoder = BeamSearchDecoder( - batch_size=config.batch_size, - seq_length=config.seq_length, - vocab_size=config.vocab_size, - decoder=decoder_cell, - beam_width=config.beam_width, - length_penalty_weight=config.length_penalty_weight, - max_decode_length=config.max_decode_length) - - self.decoder.add_flags(loop_can_unroll=True) - - self.cast = P.Cast() - self.dtype = config.dtype - self.cast_compute_type = SaturateCast(dst_type=config.compute_type) - self.expand = P.ExpandDims() - self.multiply = P.Mul() - - self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) - - # use for infer - self.tile_beam = TileBeam(beam_width=config.beam_width) - ones = np.ones(shape=(config.batch_size, config.max_decode_length)) - self.encode_mask = Tensor(ones, dtype=mstype.float32) - - self.scale = Tensor([math.sqrt(float(self.embedding_size))], - dtype=mstype.float32) - self.reshape = P.Reshape() - - def construct(self, source_ids, source_mask, target_ids=None, target_mask=None): - """ - Process source sentence - - Inputs: - source_ids (Tensor): Source sentences with shape (N, T). - source_mask (Tensor): Source sentences padding mask with shape (N, T), - where 0 indicates padding position. - - Returns: - Tensor, Predictions with shape (N, T'). - """ - # word_embeddings - src_embeddings, _ = self.embedding_lookup(source_ids) - src_embeddings = self.multiply(src_embeddings, self.scale) - # position_embeddings - src_embeddings = self.positional_embedding(src_embeddings) - # attention mask, [batch_size, seq_length, seq_length] - enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask) - # encode - encoder_output = self.encoder(self.cast_compute_type(src_embeddings), - self.cast_compute_type(enc_attention_mask)) - - # bean search for encoder output - beam_encoder_output = self.tile_beam(encoder_output) - # [batch_size, T, D] - enc_attention_mask = self.multiply( - enc_attention_mask[::, 0:1:1, ::], - self.expand(self.encode_mask, -1)) - # [N*batch_size, T, D] - beam_enc_attention_mask = self.tile_beam(enc_attention_mask) - beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask) - predicted_ids, predicted_probs = self.decoder(beam_encoder_output, beam_enc_attention_mask) - predicted_ids = self.reshape(predicted_ids, (self.batch_size, -1)) - return predicted_ids, predicted_probs diff --git a/model_zoo/mass/src/transformer/transformer_for_train.py b/model_zoo/mass/src/transformer/transformer_for_train.py deleted file mode 100644 index eb75e2d7b9..0000000000 --- a/model_zoo/mass/src/transformer/transformer_for_train.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Transformer for training.""" -from mindspore import nn -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.ops import composite as C -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.train.parallel_utils import ParallelMode -from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean - -from .transformer import Transformer -from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients - - -class PredLogProbs(nn.Cell): - """ - Get log probs. - - Args: - config (TransformerConfig): The config of Transformer. - - Returns: - Tensor, masked lm output. - """ - - def __init__(self, config): - super(PredLogProbs, self).__init__() - self.width = config.hidden_size - self.reshape = P.Reshape() - - self.matmul = P.MatMul(transpose_b=True) - self.log_softmax = nn.LogSoftmax(axis=-1) - self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) - self.cast = P.Cast() - self.compute_type = config.compute_type - self.dtype = config.dtype - self.get_shape = P.Shape() - - def construct(self, input_tensor, output_weights): - """ - Construct network. - - Args: - input_tensor (Tensor): Tensor. - output_weights (Tensor): Tensor. - - Returns: - Tensor, masked lm output. - """ - shape = self.get_shape(input_tensor) - - input_tensor = self.reshape(input_tensor, (shape[0] * shape[1], shape[2])) - input_tensor = self.cast(input_tensor, self.compute_type) - output_weights = self.cast(output_weights, self.compute_type) - - logits = self.matmul(input_tensor, output_weights) - logits = self.cast(logits, self.dtype) - - log_probs = self.log_softmax(logits) - return log_probs - - -class TransformerTraining(nn.Cell): - """ - Transformer training network. - - Args: - config (TransformerConfig): The config of Transformer. - is_training (bool): Specifies whether to use the training mode. - use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. - - Returns: - Tensor, prediction_scores, seq_relationship_score. - """ - - def __init__(self, config, is_training, use_one_hot_embeddings): - super(TransformerTraining, self).__init__() - self.transformer = Transformer(config, is_training, use_one_hot_embeddings) - self.projection = PredLogProbs(config) - - def construct(self, source_ids, source_mask, target_ids, target_mask): - """ - Construct network. - - Args: - source_ids (Tensor): Source sentence. - source_mask (Tensor): Source padding mask. - target_ids (Tensor): Target sentence. - target_mask (Tensor): Target padding mask. - - Returns: - Tensor, prediction_scores, seq_relationship_score. - """ - _, decoder_outputs, embedding_table = \ - self.transformer(source_ids, source_mask, target_ids, target_mask) - prediction_scores = self.projection(decoder_outputs, - embedding_table) - return prediction_scores - - -class LabelSmoothedCrossEntropyCriterion(nn.Cell): - """ - Label Smoothed Cross-Entropy Criterion. - - Args: - config (TransformerConfig): The config of Transformer. - - Returns: - Tensor, final loss. - """ - - def __init__(self, config): - super(LabelSmoothedCrossEntropyCriterion, self).__init__() - self.vocab_size = config.vocab_size - self.onehot = P.OneHot() - self.on_value = Tensor(float(1 - config.label_smoothing), mstype.float32) - self.off_value = Tensor(config.label_smoothing / float(self.vocab_size - 1), mstype.float32) - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.reshape = P.Reshape() - self.last_idx = (-1,) - self.flatten = P.Flatten() - self.neg = P.Neg() - self.cast = P.Cast() - self.flat_shape = (config.batch_size * config.seq_length,) - self.get_shape = P.Shape() - - def construct(self, prediction_scores, label_ids, label_weights): - """ - Construct network to calculate loss. - - Args: - prediction_scores (Tensor): Prediction scores. - label_ids (Tensor): Labels. - label_weights (Tensor): Mask tensor. - - Returns: - Tensor, final loss. - """ - label_shape = self.get_shape(label_ids) - - label_ids = self.reshape(label_ids, (label_shape[0] * label_shape[1],)) - label_weights = self.cast( - self.reshape(label_weights, (label_shape[0] * label_shape[1],)), - mstype.float32 - ) - one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) - - per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) - numerator = self.reduce_sum(label_weights * per_example_loss, ()) - denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) - loss = numerator / denominator - - return loss - - -class TransformerNetworkWithLoss(nn.Cell): - """ - Provide transformer training loss through network. - - Args: - config (BertConfig): The config of Transformer. - is_training (bool): Specifies whether to use the training mode. - use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. - - Returns: - Tensor, the loss of the network. - """ - - def __init__(self, config, is_training, use_one_hot_embeddings=False): - super(TransformerNetworkWithLoss, self).__init__() - self.transformer = TransformerTraining(config, is_training, use_one_hot_embeddings) - self.loss = LabelSmoothedCrossEntropyCriterion(config) - self.cast = P.Cast() - - def construct(self, - source_ids, - source_mask, - target_ids, - target_mask, - label_ids, - label_weights): - prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask) - total_loss = self.loss(prediction_scores, label_ids, label_weights) - return self.cast(total_loss, mstype.float32) - - -grad_scale = C.MultitypeFuncGraph("grad_scale") -reciprocal = P.Reciprocal() - - -@grad_scale.register("Tensor", "Tensor") -def tensor_grad_scale(scale, grad): - return grad * F.cast(reciprocal(scale), F.dtype(grad)) - - -class TransformerTrainOneStepWithLossScaleCell(nn.Cell): - """ - Encapsulation class of Transformer network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - Args: - network: Cell. The training network. Note that loss function should have - been added. - optimizer: Optimizer. Optimizer for updating the weights. - - Returns: - Tuple[Tensor, Tensor, Tensor], loss, overflow, sen. - """ - - def __init__(self, network, optimizer, scale_update_cell=None): - - super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.add_flags(defer_inline=True) - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.all_reduce = P.AllReduce() - - self.parallel_mode = _get_parallel_mode() - if self.parallel_mode not in ParallelMode.MODE_LIST: - raise ValueError("Parallel mode does not support: ", self.parallel_mode) - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = _get_mirror_mean() - degree = _get_device_num() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.clip_gradients = ClipGradients() - self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") - self.add_flags(has_effect=True) - - def construct(self, - source_eos_ids, - source_eos_mask, - target_sos_ids, - target_sos_mask, - target_eos_ids, - target_eos_mask, - sens=None): - """ - Construct network. - - Args: - source_eos_ids (Tensor): Source sentence. - source_eos_mask (Tensor): Source padding mask. - target_sos_ids (Tensor): Target sentence. - target_sos_mask (Tensor): Target padding mask. - target_eos_ids (Tensor): Prediction sentence. - target_eos_mask (Tensor): Prediction padding mask. - sens (Tensor): Loss sen. - - Returns: - Tuple[Tensor, Tensor, Tensor], loss, overflow, sen. - """ - source_ids = source_eos_ids - source_mask = source_eos_mask - target_ids = target_sos_ids - target_mask = target_sos_mask - label_ids = target_eos_ids - label_weights = target_eos_mask - - weights = self.weights - loss = self.network(source_ids, - source_mask, - target_ids, - target_mask, - label_ids, - label_weights) - # Alloc status. - init = self.alloc_status() - # Clear overflow buffer. - self.clear_before_grad(init) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - grads = self.grad(self.network, weights)(source_ids, - source_mask, - target_ids, - target_mask, - label_ids, - label_weights, - self.cast(scaling_sens, - mstype.float32)) - - grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) - if self.reducer_flag: - # Apply grad reducer on grads. - grads = self.grad_reducer(grads) - self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - - if self.is_distributed: - # Sum overflow flag over devices. - flag_reduce = self.all_reduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - - ret = (loss, cond, scaling_sens) - return F.depend(ret, succ) diff --git a/model_zoo/mass/src/utils/lr_scheduler.py b/model_zoo/mass/src/utils/lr_scheduler.py deleted file mode 100644 index 44ef397fdd..0000000000 --- a/model_zoo/mass/src/utils/lr_scheduler.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Learning scheduler.""" -from math import ceil - -import numpy as np - - -def square_root_schedule(lr, update_num, decay_start_step, - warmup_steps=2000, - min_lr=1e-5): - """ - Decay the LR based on the ISR(inverse square root). - - During warm-up:: - lrs = np.linspace(0, lr, warmup_steps) - - After warm-up: - decay_factor = lr * sqrt(warmup_steps) - lr = decay_factor / sqrt(step) if step >= decay_start_step else lr - - Args: - lr (float): Init learning rate. - update_num (int): Total steps. - decay_start_step (int): Decay begins after `decay_start_step` steps. - warmup_steps (int): Warm up steps. - min_lr (float): Min learning rate. - - Returns: - np.ndarray, learning rate array. - """ - warmup_end_lr = lr - warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr - - # If warmup_init_lr > lr, then lr_step is negative. - # Otherwise, it's positive. - lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps - decay_factor = lr * warmup_steps ** 0.5 - - lrs = np.empty(shape=update_num, dtype=np.float32) - _start_step = 0 - if 0 < warmup_steps < update_num: - lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps) - _start_step = warmup_steps - - for step in range(_start_step, update_num): - if step < warmup_steps: - _lr = warmup_init_lr + step * lr_step - elif step < decay_start_step: - _lr = lr - else: - _lr = decay_factor * step ** -0.5 - if _lr < min_lr: - _lr = min_lr - lrs[step] = _lr - - return lrs - - -def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0): - """ - Implements of polynomial decay learning rate scheduler which cycles by default. - - Args: - lr (float): Initial learning rate. - warmup_steps (int): Warmup steps. - decay_steps (int): Decay steps. - total_update_num (int): Total update steps. - min_lr (float): Min learning. - power (float): Power factor. - - Returns: - np.ndarray, learning rate of each step. - """ - lrs = np.zeros(shape=total_update_num, dtype=np.float32) - - if decay_steps <= 0: - raise ValueError("`decay_steps` must larger than 1.") - - _start_step = 0 - if 0 < warmup_steps < total_update_num: - warmup_end_lr = lr - warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr - lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps) - _start_step = warmup_steps - - decay_steps = decay_steps - for step in range(_start_step, total_update_num): - _step = step - _start_step # 2999 - ratio = ceil(_step / decay_steps) # 3 - ratio = 1 if ratio < 1 else ratio - _decay_steps = decay_steps * ratio # 3000 - lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr - - return lrs diff --git a/model_zoo/mass/train.py b/model_zoo/mass/train.py deleted file mode 100644 index b58075ba4e..0000000000 --- a/model_zoo/mass/train.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Train api.""" -import os -import argparse -import pickle - -import numpy as np - -import mindspore.common.dtype as mstype -from mindspore.common.tensor import Tensor -from mindspore.nn import Momentum -from mindspore.nn.optim import Adam, Lamb -from mindspore.train.model import Model -from mindspore.train.loss_scale_manager import DynamicLossScaleManager -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint -from mindspore import context, ParallelMode, Parameter -from mindspore.communication import management as MultiAscend -from mindspore.train.serialization import load_checkpoint - -from config import TransformerConfig -from src.dataset import load_dataset -from src.transformer import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell -from src.transformer.infer_mass import infer -from src.utils import LossCallBack -from src.utils import one_weight, zero_weight, weight_variable -from src.utils import square_root_schedule -from src.utils.lr_scheduler import polynomial_decay_scheduler - -parser = argparse.ArgumentParser(description='MASS train entry point.') -parser.add_argument("--config", type=str, required=True, help="model config json file path.") - -device_id = os.getenv('DEVICE_ID', None) -if device_id is None: - raise RuntimeError("`DEVICE_ID` can not be None.") - -device_id = int(device_id) -context.set_context( - mode=context.GRAPH_MODE, - device_target="Ascend", - reserve_class_name_in_scope=False, - device_id=device_id) - - -def get_config(config): - config = TransformerConfig.from_json_file(config) - config.compute_type = mstype.float16 - config.dtype = mstype.float32 - return config - - -def _train(model, config: TransformerConfig, - pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None, - callbacks: list = None): - """ - Train model. - - Args: - model (Model): MindSpore model instance. - config (TransformerConfig): Config of mass model. - pre_training_dataset (Dataset): Pre-training dataset. - fine_tune_dataset (Dataset): Fine-tune dataset. - test_dataset (Dataset): Test dataset. - callbacks (list): A list of callbacks. - """ - callbacks = callbacks if callbacks else [] - - if pre_training_dataset is not None: - print(" | Start pre-training job.") - epoch_size = pre_training_dataset.get_repeat_count() - if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1: - print(f" | Rank {MultiAscend.get_rank()} Call model train.") - model.train(epoch_size, pre_training_dataset, - callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode) - # Test the accuracy of the model. - if test_dataset is not None: - print(" | Start test job.") - result = infer(_config) - with open("validation_res_after_pre_training.bin", "wb") as f: - pickle.dump(result, f, 1) - - if fine_tune_dataset is not None: - print(" | Start fine-tuning job.") - epoch_size = fine_tune_dataset.get_repeat_count() - - model.train(epoch_size, fine_tune_dataset, - callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode) - - # Test the accuracy of the model. - if test_dataset is not None: - print(" | Start test job.") - result = infer(_config) - with open("validation_res_after_pre_training.bin", "wb") as f: - pickle.dump(result, f, 1) - - -def _build_training_pipeline(config: TransformerConfig, - pre_training_dataset=None, - fine_tune_dataset=None, - test_dataset=None): - """ - Build training pipeline. - - Args: - config (TransformerConfig): Config of mass model. - pre_training_dataset (Dataset): Pre-training dataset. - fine_tune_dataset (Dataset): Fine-tune dataset. - test_dataset (Dataset): Test dataset. - """ - net_with_loss = TransformerNetworkWithLoss(config, is_training=True) - net_with_loss.init_parameters_data() - - if config.existed_ckpt: - if config.existed_ckpt.endswith(".npz"): - weights = np.load(config.existed_ckpt) - else: - weights = load_checkpoint(config.existed_ckpt) - for param in net_with_loss.trainable_params(): - weights_name = param.name - if weights_name not in weights: - raise ValueError(f"Param {weights_name} is not found in ckpt file.") - - if isinstance(weights[weights_name], Parameter): - param.default_input = weights[weights_name].default_input - elif isinstance(weights[weights_name], Tensor): - param.default_input = Tensor(weights[weights_name].asnumpy(), config.dtype) - elif isinstance(weights[weights_name], np.ndarray): - param.default_input = Tensor(weights[weights_name], config.dtype) - else: - param.default_input = weights[weights_name] - else: - for param in net_with_loss.trainable_params(): - name = param.name - value = param.default_input - if isinstance(value, Tensor): - if name.endswith(".gamma"): - param.default_input = one_weight(value.asnumpy().shape) - elif name.endswith(".beta") or name.endswith(".bias"): - param.default_input = zero_weight(value.asnumpy().shape) - else: - param.default_input = weight_variable(value.asnumpy().shape) - - dataset = pre_training_dataset if pre_training_dataset is not None \ - else fine_tune_dataset - - if dataset is None: - raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.") - - update_steps = dataset.get_repeat_count() * dataset.get_dataset_size() - if config.lr_scheduler == "isr": - lr = Tensor(square_root_schedule(lr=config.lr, - update_num=update_steps, - decay_start_step=config.decay_start_step, - warmup_steps=config.warmup_steps, - min_lr=config.min_lr), dtype=mstype.float32) - elif config.lr_scheduler == "poly": - lr = Tensor(polynomial_decay_scheduler(lr=config.lr, - min_lr=config.min_lr, - decay_steps=config.decay_steps, - total_update_num=update_steps, - warmup_steps=config.warmup_steps, - power=config.poly_lr_scheduler_power), dtype=mstype.float32) - else: - lr = config.lr - - if config.optimizer.lower() == "adam": - optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98) - elif config.optimizer.lower() == "lamb": - optimizer = Lamb(net_with_loss.trainable_params(), decay_steps=12000, - start_learning_rate=config.lr, end_learning_rate=config.min_lr, - power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01, - eps=1e-6) - elif config.optimizer.lower() == "momentum": - optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9) - else: - raise ValueError(f"optimizer only support `adam` and `momentum` now.") - - # Dynamic loss scale. - scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale, - scale_factor=config.loss_scale_factor, - scale_window=config.scale_window) - net_with_grads = TransformerTrainOneStepWithLossScaleCell( - network=net_with_loss, optimizer=optimizer, - scale_update_cell=scale_manager.get_update_cell() - ) - net_with_grads.set_train(True) - model = Model(net_with_grads) - loss_monitor = LossCallBack(config) - ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps, - keep_checkpoint_max=config.keep_ckpt_max) - - rank_size = os.getenv('RANK_SIZE') - callbacks = [loss_monitor] - if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0: - ckpt_callback = ModelCheckpoint( - prefix=config.ckpt_prefix, - directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))), - config=ckpt_config) - callbacks.append(ckpt_callback) - - if rank_size is None or int(rank_size) == 1: - ckpt_callback = ModelCheckpoint( - prefix=config.ckpt_prefix, - directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))), - config=ckpt_config) - callbacks.append(ckpt_callback) - - print(f" | ALL SET, PREPARE TO TRAIN.") - _train(model=model, config=config, - pre_training_dataset=pre_training_dataset, - fine_tune_dataset=fine_tune_dataset, - test_dataset=test_dataset, - callbacks=callbacks) - - -def _setup_parallel_env(): - context.reset_auto_parallel_context() - MultiAscend.init() - context.set_auto_parallel_context( - parallel_mode=ParallelMode.DATA_PARALLEL, - device_num=MultiAscend.get_group_size(), - parameter_broadcast=True, - mirror_mean=True - ) - - -def train_parallel(config: TransformerConfig): - """ - Train model with multi ascend chips. - - Args: - config (TransformerConfig): Config for MASS model. - """ - _setup_parallel_env() - - print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.") - - pre_train_dataset = load_dataset( - data_files=config.pre_train_dataset, - batch_size=config.batch_size, epoch_count=config.epochs, - sink_mode=config.dataset_sink_mode, - sink_step=config.dataset_sink_step, - rank_size=MultiAscend.get_group_size(), - rank_id=MultiAscend.get_rank() - ) if config.pre_train_dataset else None - fine_tune_dataset = load_dataset( - data_files=config.fine_tune_dataset, - batch_size=config.batch_size, epoch_count=config.epochs, - sink_mode=config.dataset_sink_mode, - sink_step=config.dataset_sink_step, - rank_size=MultiAscend.get_group_size(), - rank_id=MultiAscend.get_rank() - ) if config.fine_tune_dataset else None - test_dataset = load_dataset( - data_files=config.test_dataset, - batch_size=config.batch_size, epoch_count=config.epochs, - sink_mode=config.dataset_sink_mode, - sink_step=config.dataset_sink_step, - rank_size=MultiAscend.get_group_size(), - rank_id=MultiAscend.get_rank() - ) if config.test_dataset else None - - _build_training_pipeline(config=config, - pre_training_dataset=pre_train_dataset, - fine_tune_dataset=fine_tune_dataset, - test_dataset=test_dataset) - - -def train_single(config: TransformerConfig): - """ - Train model on single device. - - Args: - config (TransformerConfig): Config for model. - """ - print(" | Starting training on single device.") - pre_train_dataset = load_dataset(data_files=config.pre_train_dataset, - batch_size=config.batch_size, - epoch_count=config.epochs, - sink_mode=config.dataset_sink_mode, - sink_step=config.dataset_sink_step) if config.pre_train_dataset else None - fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset, - batch_size=config.batch_size, - epoch_count=config.epochs, - sink_mode=config.dataset_sink_mode, - sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None - test_dataset = load_dataset(data_files=config.test_dataset, - batch_size=config.batch_size, - epoch_count=config.epochs, - sink_mode=config.dataset_sink_mode, - sink_step=config.dataset_sink_step) if config.test_dataset else None - - _build_training_pipeline(config=config, - pre_training_dataset=pre_train_dataset, - fine_tune_dataset=fine_tune_dataset, - test_dataset=test_dataset) - - -def _check_args(config): - if not os.path.exists(config): - raise FileNotFoundError("`config` is not existed.") - if not isinstance(config, str): - raise ValueError("`config` must be type of str.") - - -if __name__ == '__main__': - _rank_size = os.getenv('RANK_SIZE') - - args, _ = parser.parse_known_args() - _check_args(args.config) - _config = get_config(args.config) - - np.random.seed(_config.random_seed) - context.set_context(save_graphs=_config.save_graphs) - - if _rank_size is not None and int(_rank_size) > 1: - train_parallel(_config) - else: - train_single(_config) diff --git a/model_zoo/mobilenetv2/Readme.md b/model_zoo/mobilenetv2/Readme.md deleted file mode 100644 index 1687d2cbdc..0000000000 --- a/model_zoo/mobilenetv2/Readme.md +++ /dev/null @@ -1,151 +0,0 @@ -# MobileNetV2 Description - - -MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019. - -[Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. - -# Model architecture - -The overall network architecture of MobileNetV2 is show below: - -[Link](https://arxiv.org/pdf/1905.02244) - -# Dataset - -Dataset used: [imagenet](http://www.image-net.org/) - -- Dataset size: ~125G, 1.2W colorful images in 1000 classes - - Train: 120G, 1.2W images - - Test: 5G, 50000 images -- Data format: RGB images. - - Note: Data will be processed in src/dataset.py - - -# Features - - -# Environment Requirements - -- Hardware(Ascend/GPU) - - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. -- Framework - - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) -- For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - - -# Script description - -## Script and sample code - -```python -├── MobileNetV2 - ├── Readme.md - ├── scripts - │ ├──run_train.sh - │ ├──run_eval.sh - ├── src - │ ├──config.py - │ ├──dataset.py - │ ├──luanch.py - │ ├──lr_generator.py - │ ├──mobilenetV2.py - ├── train.py - ├── eval.py -``` - -## Training process - -### Usage - -- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [CKPT_PATH] -- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] - -### Launch - -``` -# training example - Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt - GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ -``` - -### Result - -Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. - -``` -epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] -epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 -epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] -epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 -``` - -## Eval process - -### Usage - -- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] -- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] - -### Launch - -``` -# infer example - Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt - GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt -``` - -> checkpoint can be produced in training process. - -### Result - -Inference result will be stored in the example path, you can find result like the followings in `val.log`. - -``` -result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt -``` - -# Model description - -## Performance - -### Training Performance - -| Parameters | MobilenetV2 | | -| -------------------------- | ---------------------------------------------------------- | ------------------------- | -| Model Version | | large | -| Resource | Ascend 910, cpu:2.60GHz 56cores, memory:314G | NV SMX2 V100-32G | -| uploaded Date | 05/06/2020 | 05/06/2020 | -| MindSpore Version | 0.3.0 | 0.3.0 | -| Dataset | ImageNet | ImageNet | -| Training Parameters | src/config.py | src/config.py | -| Optimizer | Momentum | Momentum | -| Loss Function | SoftmaxCrossEntropy | SoftmaxCrossEntropy | -| outputs | | | -| Loss | | 1.913 | -| Accuracy | | ACC1[77.09%] ACC5[92.57%] | -| Total time | | | -| Params (M) | | | -| Checkpoint for Fine tuning | | | -| Model for inference | | | - -#### Inference Performance - -| Parameters | | | | -| -------------------------- | ----------------------------- | ------------------------- | -------------------- | -| Model Version | V1 | | | -| Resource | Huawei 910 | NV SMX2 V100-32G | Huawei 310 | -| uploaded Date | 05/06/2020 | 05/22/2020 | | -| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 | -| Dataset | ImageNet, 1.2W | ImageNet, 1.2W | ImageNet, 1.2W | -| batch_size | | 130(8P) | | -| outputs | | | | -| Accuracy | | ACC1[72.07%] ACC5[90.90%] | | -| Speed | | | | -| Total time | | | | -| Model for inference | | | | - -# ModelZoo Homepage - [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) \ No newline at end of file diff --git a/model_zoo/mobilenetv2/scripts/run_train.sh b/model_zoo/mobilenetv2/scripts/run_train.sh deleted file mode 100644 index a6e2a79477..0000000000 --- a/model_zoo/mobilenetv2/scripts/run_train.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -run_ascend() -{ - if [ $2 -lt 1 ] && [ $2 -gt 8 ] - then - echo "error: DEVICE_NUM=$2 is not in (1-8)" - exit 1 - fi - - if [ ! -d $5 ] && [ ! -f $5 ] - then - echo "error: DATASET_PATH=$5 is not a directory or file" - exit 1 - fi - - BASEPATH=$(cd "`dirname $0`" || exit; pwd) - export PYTHONPATH=${BASEPATH}:$PYTHONPATH - export MINDSPORE_HCCL_CONFIG_PATH=$4 - export RANK_TABLE_FILE=$4 - if [ -d "../train" ]; - then - rm -rf ../train - fi - mkdir ../train - cd ../train || exit - python ${BASEPATH}/../src/launch.py \ - --nproc_per_node=$2 \ - --visible_devices=$3 \ - --training_script=${BASEPATH}/../train.py \ - --dataset_path=$5 \ - --pre_trained=$6 \ - --platform=$1 &> ../train.log & # dataset train folder -} - -run_gpu() -{ - if [ $2 -lt 1 ] && [ $2 -gt 8 ] - then - echo "error: DEVICE_NUM=$2 is not in (1-8)" - exit 1 - fi - - if [ ! -d $4 ] - then - echo "error: DATASET_PATH=$4 is not a directory" - exit 1 - fi - - BASEPATH=$(cd "`dirname $0`" || exit; pwd) - export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "../train" ]; - then - rm -rf ../train - fi - mkdir ../train - cd ../train || exit - - export CUDA_VISIBLE_DEVICES="$3" - mpirun -n $2 --allow-run-as-root \ - python ${BASEPATH}/../train.py \ - --dataset_path=$4 \ - --platform=$1 \ - &> ../train.log & # dataset train folder -} - -if [ $# -gt 6 ] || [ $# -lt 4 ] -then - echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [CKPT_PATH]\n \ - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ - " -exit 1 -fi - -if [ $1 = "Ascend" ] ; then - run_ascend "$@" -elif [ $1 = "GPU" ] ; then - run_gpu "$@" -else - echo "not support platform" -fi; - diff --git a/model_zoo/mobilenetv2/src/mobilenetV2.py b/model_zoo/mobilenetv2/src/mobilenetV2.py deleted file mode 100644 index 5b1b4cc5ef..0000000000 --- a/model_zoo/mobilenetv2/src/mobilenetV2.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV2 model define""" -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.ops.operations import TensorAdd -from mindspore import Parameter, Tensor -from mindspore.common.initializer import initializer - -__all__ = ['mobilenet_v2'] - - -def _make_divisible(v, divisor, min_value=None): - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class GlobalAvgPooling(nn.Cell): - """ - Global avg pooling definition. - - Args: - - Returns: - Tensor, output tensor. - - Examples: - >>> GlobalAvgPooling() - """ - - def __init__(self): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(keep_dims=False) - - def construct(self, x): - x = self.mean(x, (2, 3)) - return x - - -class DepthwiseConv(nn.Cell): - """ - Depthwise Convolution warpper definition. - - Args: - in_planes (int): Input channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - pad_mode (str): pad mode in (pad, same, valid) - channel_multiplier (int): Output channel multiplier - has_bias (bool): has bias or not - - Returns: - Tensor, output tensor. - - Examples: - >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) - """ - - def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): - super(DepthwiseConv, self).__init__() - self.has_bias = has_bias - self.in_channels = in_planes - self.channel_multiplier = channel_multiplier - self.out_channels = in_planes * channel_multiplier - self.kernel_size = (kernel_size, kernel_size) - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, - kernel_size=self.kernel_size, - stride=stride, pad_mode=pad_mode, pad=pad) - self.bias_add = P.BiasAdd() - weight_shape = [channel_multiplier, in_planes, *self.kernel_size] - self.weight = Parameter(initializer('ones', weight_shape), name='weight') - - if has_bias: - bias_shape = [channel_multiplier * in_planes] - self.bias = Parameter(initializer('zeros', bias_shape), name='bias') - else: - self.bias = None - - def construct(self, x): - output = self.depthwise_conv(x, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - return output - - -class ConvBNReLU(nn.Cell): - """ - Convolution/Depthwise fused with Batchnorm and ReLU block definition. - - Args: - in_planes (int): Input channel. - out_planes (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size for the first convolutional layer. Default: 1. - groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) - """ - - def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - if groups == 1: - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) - else: - if platform == "Ascend": - conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) - elif platform == "GPU": - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, - group=in_planes, pad_mode='pad', padding=padding) - - layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] - self.features = nn.SequentialCell(layers) - - def construct(self, x): - output = self.features(x) - return output - - -class InvertedResidual(nn.Cell): - """ - Mobilenetv2 residual block definition. - - Args: - inp (int): Input channel. - oup (int): Output channel. - stride (int): Stride size for the first convolutional layer. Default: 1. - expand_ratio (int): expand ration of input channel - - Returns: - Tensor, output tensor. - - Examples: - >>> ResidualBlock(3, 256, 1, 1) - """ - - def __init__(self, platform, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - assert stride in [1, 2] - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = stride == 1 and inp == oup - - layers = [] - if expand_ratio != 1: - layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU(platform, hidden_dim, hidden_dim, - stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2d(hidden_dim, oup, kernel_size=1, - stride=1, has_bias=False), - nn.BatchNorm2d(oup), - ]) - self.conv = nn.SequentialCell(layers) - self.add = TensorAdd() - self.cast = P.Cast() - - def construct(self, x): - identity = x - x = self.conv(x) - if self.use_res_connect: - return self.add(identity, x) - return x - - -class MobileNetV2(nn.Cell): - """ - MobileNetV2 architecture. - - Args: - class_num (Cell): number of classes. - width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. - has_dropout (bool): Is dropout used. Default is false - inverted_residual_setting (list): Inverted residual settings. Default is None - round_nearest (list): Channel round to . Default is 8 - Returns: - Tensor, output tensor. - - Examples: - >>> MobileNetV2(num_classes=1000) - """ - - def __init__(self, platform, num_classes=1000, width_mult=1., - has_dropout=False, inverted_residual_setting=None, round_nearest=8): - super(MobileNetV2, self).__init__() - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - # setting of inverted residual blocks - self.cfgs = inverted_residual_setting - if inverted_residual_setting is None: - self.cfgs = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, round_nearest) - self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features = [ConvBNReLU(platform, 3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in self.cfgs: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 - features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1)) - # make it nn.CellList - self.features = nn.SequentialCell(features) - # mobilenet head - head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else - [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) - self.head = nn.SequentialCell(head) - - self._initialize_weights() - - def construct(self, x): - x = self.features(x) - x = self.head(x) - return x - - def _initialize_weights(self): - """ - Initialize weights. - - Args: - - Returns: - None. - - Examples: - >>> _initialize_weights() - """ - for _, m in self.cells_and_names(): - if isinstance(m, (nn.Conv2d, DepthwiseConv)): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), - m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_parameter_data( - Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_parameter_data( - Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - m.weight.set_parameter_data(Tensor(np.random.normal( - 0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - -def mobilenet_v2(**kwargs): - """ - Constructs a MobileNet V2 model - """ - return MobileNetV2(**kwargs) diff --git a/model_zoo/mobilenetv2/train.py b/model_zoo/mobilenetv2/train.py deleted file mode 100644 index 4ae743f540..0000000000 --- a/model_zoo/mobilenetv2/train.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train_imagenet.""" -import os -import time -import argparse -import random -import numpy as np - -from mindspore import context -from mindspore import Tensor -from mindspore import nn -from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.nn.optim.momentum import Momentum -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.nn.loss.loss import _Loss -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.common import dtype as mstype -from mindspore.train.model import Model, ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback -from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.communication.management import init, get_group_size, get_rank -import mindspore.dataset.engine as de - -from src.dataset import create_dataset -from src.lr_generator import get_lr -from src.config import config_gpu, config_ascend -from src.mobilenetV2 import mobilenet_v2 - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') -parser.add_argument('--platform', type=str, default=None, help='run platform') -args_opt = parser.parse_args() - -if args_opt.platform == "Ascend": - device_id = int(os.getenv('DEVICE_ID')) - rank_id = int(os.getenv('RANK_ID')) - rank_size = int(os.getenv('RANK_SIZE')) - run_distribute = rank_size > 1 - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - device_id=device_id, save_graphs=False) -elif args_opt.platform == "GPU": - context.set_context(mode=context.GRAPH_MODE, - device_target="GPU", - save_graphs=False) - init("nccl") - context.set_auto_parallel_context(device_num=get_group_size(), - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) -else: - raise ValueError("Unsupported device target.") - - -class CrossEntropyWithLabelSmooth(_Loss): - """ - CrossEntropyWith LabelSmooth. - - Args: - smooth_factor (float): smooth factor, default=0. - num_classes (int): num classes - - Returns: - None. - - Examples: - >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000) - """ - - def __init__(self, smooth_factor=0., num_classes=1000): - super(CrossEntropyWithLabelSmooth, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) - self.off_value = Tensor(1.0 * smooth_factor / - (num_classes - 1), mstype.float32) - self.ce = nn.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean(False) - self.cast = P.Cast() - - def construct(self, logit, label): - one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], - self.on_value, self.off_value) - out_loss = self.ce(logit, one_hot_label) - out_loss = self.mean(out_loss, 0) - return out_loss - - -class Monitor(Callback): - """ - Monitor loss and time. - - Args: - lr_init (numpy array): train lr - - Returns: - None - - Examples: - >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) - """ - - def __init__(self, lr_init=None): - super(Monitor, self).__init__() - self.lr_init = lr_init - self.lr_init_len = len(lr_init) - - def epoch_begin(self, run_context): - self.losses = [] - self.epoch_time = time.time() - - def epoch_end(self, run_context): - cb_params = run_context.original_args() - - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / cb_params.batch_num - print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, - per_step_mseconds, - np.mean(self.losses))) - - def step_begin(self, run_context): - self.step_time = time.time() - - def step_end(self, run_context): - cb_params = run_context.original_args() - step_mseconds = (time.time() - self.step_time) * 1000 - step_loss = cb_params.net_outputs - - if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): - step_loss = step_loss[0] - if isinstance(step_loss, Tensor): - step_loss = np.mean(step_loss.asnumpy()) - - self.losses.append(step_loss) - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num - - print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( - cb_params.cur_epoch_num - - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, - np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) - - -if __name__ == '__main__': - if args_opt.platform == "GPU": - # train on gpu - print("train args: ", args_opt) - print("cfg: ", config_gpu) - - # define net - net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") - # define loss - if config_gpu.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth( - smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes) - else: - loss = SoftmaxCrossEntropyWithLogits( - is_grad=False, sparse=True, reduction='mean') - # define dataset - epoch_size = config_gpu.epoch_size - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config_gpu, - platform=args_opt.platform, - repeat_num=epoch_size, - batch_size=config_gpu.batch_size) - step_size = dataset.get_dataset_size() - # resume - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - # define optimizer - loss_scale = FixedLossScaleManager( - config_gpu.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(global_step=0, - lr_init=0, - lr_end=0, - lr_max=config_gpu.lr, - warmup_epochs=config_gpu.warmup_epochs, - total_epochs=epoch_size, - steps_per_epoch=step_size)) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, - config_gpu.weight_decay, config_gpu.loss_scale) - # define model - model = Model(net, loss_fn=loss, optimizer=opt, - loss_scale_manager=loss_scale) - - cb = [Monitor(lr_init=lr.asnumpy())] - ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" - if config_gpu.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config_gpu.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) - cb += [ckpt_cb] - # begin train - model.train(epoch_size, dataset, callbacks=cb) - elif args_opt.platform == "Ascend": - # train on ascend - print("train args: ", args_opt, "\ncfg: ", config_ascend, - "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) - - if run_distribute: - context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, mirror_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) - init() - - epoch_size = config_ascend.epoch_size - net = mobilenet_v2(num_classes=config_ascend.num_classes, platform="Ascend") - net.to_float(mstype.float16) - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.to_float(mstype.float32) - if config_ascend.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth( - smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes) - else: - loss = SoftmaxCrossEntropyWithLogits( - is_grad=False, sparse=True, reduction='mean') - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config_ascend, - platform=args_opt.platform, - repeat_num=epoch_size, - batch_size=config_ascend.batch_size) - step_size = dataset.get_dataset_size() - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - - loss_scale = FixedLossScaleManager( - config_ascend.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(global_step=0, - lr_init=0, - lr_end=0, - lr_max=config_ascend.lr, - warmup_epochs=config_ascend.warmup_epochs, - total_epochs=epoch_size, - steps_per_epoch=step_size)) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum, - config_ascend.weight_decay, config_ascend.loss_scale) - - model = Model(net, loss_fn=loss, optimizer=opt, - loss_scale_manager=loss_scale) - - cb = None - if rank_id == 0: - cb = [Monitor(lr_init=lr.asnumpy())] - if config_ascend.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config_ascend.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint( - prefix="mobilenetV2", directory=config_ascend.save_checkpoint_path, config=config_ck) - cb += [ckpt_cb] - model.train(epoch_size, dataset, callbacks=cb) - else: - raise ValueError("Unsupport platform.") diff --git a/model_zoo/mobilenetv2_quant/Readme.md b/model_zoo/mobilenetv2_quant/Readme.md deleted file mode 100644 index 81be5d519c..0000000000 --- a/model_zoo/mobilenetv2_quant/Readme.md +++ /dev/null @@ -1,142 +0,0 @@ -# MobileNetV2 Quantization Aware Training - -MobileNetV2 is a significant improvement over MobileNetV1 and pushes the state of the art for mobile visual recognition including classification, object detection and semantic segmentation. - -MobileNetV2 builds upon the ideas from MobileNetV1, using depthwise separable convolution as efficient building blocks. However, V2 introduces two new features to the architecture: 1) linear bottlenecks between the layers, and 2) shortcut connections between the bottlenecks1. - -Training MobileNetV2 with ImageNet dataset in MindSpore with quantization aware training. - -This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware. - -In this readme tutorial, you will: - -1. Train a MindSpore fusion MobileNetV2 model for ImageNet from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. -2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. - -[Paper](https://arxiv.org/pdf/1801.04381) Sandler, Mark, et al. "Mobilenetv2: Inverted residuals and linear bottlenecks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018. - -# Dataset - -Dataset use: ImageNet - -- Dataset size: about 125G - - Train: 120G, 1281167 images: 1000 directories - - Test: 5G, 50000 images: images should be classified into 1000 directories firstly, just like train images -- Data format: RGB images. - - Note: Data will be processed in src/dataset.py - -# Environment Requirements - -- Hardware(Ascend) - - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. -- Framework - - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) -- For more information, please check the resources below: - - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - - -# Script description - -## Script and sample code - -```python -├── mobilenetv2_quant - ├── Readme.md - ├── scripts - │ ├──run_train.sh - │ ├──run_infer.sh - │ ├──run_train_quant.sh - │ ├──run_infer_quant.sh - ├── src - │ ├──config.py - │ ├──dataset.py - │ ├──luanch.py - │ ├──lr_generator.py - │ ├──mobilenetV2.py - ├── train.py - ├── eval.py -``` - -## Training process - -### Train MobileNetV2 model - -Train a MindSpore fusion MobileNetV2 model for ImageNet, like: - -- sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] - -You can just run this command instead. - -``` bash ->>> sh run_train.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt -``` - -Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. - -``` ->>> epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] ->>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 ->>> epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] ->>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 -``` - -### Evaluate MobileNetV2 model - -Evaluate a MindSpore fusion MobileNetV2 model for ImageNet, like: - -- sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] - -You can just run this command instead. - -``` bash ->>> sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt -``` - -Inference result will be stored in the example path, you can find result like the followings in `val.log`. - -``` ->>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt -``` - -### Fine-tune for quantization aware training - -Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. - -- sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] - -You can just run this command instead. - -``` bash ->>> sh run_train_quant.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt -``` - -Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. - -``` ->>> epoch: [ 0/60], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] ->>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 ->>> epoch: [ 1/60], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] ->>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 -``` - -### Evaluate quantization aware training model - -Evaluate a MindSpore fusion MobileNetV2 model for ImageNet by applying the quantization aware training, like: - -- sh run_infer_quant.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] - -You can just run this command instead. - -``` bash ->>> sh run_infer_quant.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_625.ckpt -``` - -Inference result will be stored in the example path, you can find result like the followings in `val.log`. - -``` ->>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-60_625.ckpt -``` - -# ModelZoo Homepage - [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/mobilenetv2_quant/scripts/run_train.sh b/model_zoo/mobilenetv2_quant/scripts/run_train.sh deleted file mode 100644 index 59b105f92e..0000000000 --- a/model_zoo/mobilenetv2_quant/scripts/run_train.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -run_ascend() -{ - if [ $2 -lt 1 ] && [ $2 -gt 8 ] - then - echo "error: DEVICE_NUM=$2 is not in (1-9)" - exit 1 - fi - - if [ ! -d $5 ] && [ ! -f $5 ] - then - echo "error: DATASET_PATH=$5 is not a directory or file" - exit 1 - fi - - BASEPATH=$(cd "`dirname $0`" || exit; pwd) - export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "../train" ]; - then - rm -rf ../train - fi - mkdir ../train - cd ../train || exit - python ${BASEPATH}/../src/launch.py \ - --nproc_per_node=$2 \ - --visible_devices=$4 \ - --server_id=$3 \ - --training_script=${BASEPATH}/../train.py \ - --dataset_path=$5 \ - --pre_trained=$6 \ - --device_target=$1 &> train.log & # dataset train folder -} - -if [ $# -gt 6 ] || [ $# -lt 4 ] -then - echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ - " -exit 1 -fi - -if [ $1 = "Ascend" ] ; then - run_ascend "$@" -else - echo "Unsupported device target." -fi; - diff --git a/model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh b/model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh deleted file mode 100644 index c82d1b0da5..0000000000 --- a/model_zoo/mobilenetv2_quant/scripts/run_train_quant.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -run_ascend() -{ - if [ $2 -lt 1 ] && [ $2 -gt 8 ] - then - echo "error: DEVICE_NUM=$2 is not in (1-9)" - exit 1 - fi - - if [ ! -d $5 ] && [ ! -f $5 ] - then - echo "error: DATASET_PATH=$5 is not a directory or file" - exit 1 - fi - - BASEPATH=$(cd "`dirname $0`" || exit; pwd) - export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "../train" ]; - then - rm -rf ../train - fi - mkdir ../train - cd ../train || exit - python ${BASEPATH}/../src/launch.py \ - --nproc_per_node=$2 \ - --visible_devices=$4 \ - --server_id=$3 \ - --training_script=${BASEPATH}/../train.py \ - --dataset_path=$5 \ - --pre_trained=$6 \ - --quantization_aware=True \ - --device_target=$1 &> train.log & # dataset train folder -} - -if [ $# -gt 6 ] || [ $# -lt 4 ] -then - echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ - " -exit 1 -fi - -if [ $1 = "Ascend" ] ; then - run_ascend "$@" -else - echo "Unsupported device target." -fi; - diff --git a/model_zoo/mobilenetv2_quant/src/config.py b/model_zoo/mobilenetv2_quant/src/config.py deleted file mode 100644 index 97fbc52e12..0000000000 --- a/model_zoo/mobilenetv2_quant/src/config.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in train.py and eval.py -""" -from easydict import EasyDict as ed - -config_ascend = ed({ - "num_classes": 1000, - "image_height": 224, - "image_width": 224, - "batch_size": 256, - "data_load_mode": "mindrecord", - "epoch_size": 200, - "start_epoch": 0, - "warmup_epochs": 4, - "lr": 0.4, - "momentum": 0.9, - "weight_decay": 4e-5, - "label_smooth": 0.1, - "loss_scale": 1024, - "save_checkpoint": True, - "save_checkpoint_epochs": 1, - "keep_checkpoint_max": 200, - "save_checkpoint_path": "./checkpoint", - "quantization_aware": False, -}) - -config_ascend_quant = ed({ - "num_classes": 1000, - "image_height": 224, - "image_width": 224, - "batch_size": 192, - "data_load_mode": "mindrecord", - "epoch_size": 60, - "start_epoch": 200, - "warmup_epochs": 1, - "lr": 0.3, - "momentum": 0.9, - "weight_decay": 4e-5, - "label_smooth": 0.1, - "loss_scale": 1024, - "save_checkpoint": True, - "save_checkpoint_epochs": 1, - "keep_checkpoint_max": 200, - "save_checkpoint_path": "./checkpoint", - "quantization_aware": True, -}) diff --git a/model_zoo/mobilenetv2_quant/src/launch.py b/model_zoo/mobilenetv2_quant/src/launch.py deleted file mode 100644 index 08477a363a..0000000000 --- a/model_zoo/mobilenetv2_quant/src/launch.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""launch train script""" -import os -import sys -import json -import subprocess -import shutil -import platform -from argparse import ArgumentParser - - -def parse_args(): - """ - parse args . - - Args: - - Returns: - args. - - Examples: - >>> parse_args() - """ - parser = ArgumentParser(description="mindspore distributed training launch " - "helper utilty that will spawn up " - "multiple distributed processes") - parser.add_argument("--nproc_per_node", type=int, default=1, - help="The number of processes to launch on each node, " - "for D training, this is recommended to be set " - "to the number of D in your system so that " - "each process can be bound to a single D.") - parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", - help="will use the visible devices sequentially") - parser.add_argument("--server_id", type=str, default="", - help="server ip") - parser.add_argument("--training_script", type=str, - help="The full path to the single D training " - "program/script to be launched in parallel, " - "followed by all the arguments for the " - "training script") - # rest from the training program - args, unknown = parser.parse_known_args() - args.training_script_args = unknown - return args - - -def main(): - print("start", __file__) - args = parse_args() - print(args) - visible_devices = args.visible_devices.split(',') - assert os.path.isfile(args.training_script) - assert len(visible_devices) >= args.nproc_per_node - print('visible_devices:{}'.format(visible_devices)) - if not args.server_id: - print('pleaser input server ip!!!') - exit(0) - print('server_id:{}'.format(args.server_id)) - - # construct hccn_table - hccn_configs = open('/etc/hccn.conf', 'r').readlines() - device_ips = {} - for hccn_item in hccn_configs: - hccn_item = hccn_item.strip() - if hccn_item.startswith('address_'): - device_id, device_ip = hccn_item.split('=') - device_id = device_id.split('_')[1] - device_ips[device_id] = device_ip - print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) - hccn_table = {} - arch = platform.processor() - hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch] - hccn_table['chip_info'] = '910' - hccn_table['deploy_mode'] = 'lab' - hccn_table['group_count'] = '1' - hccn_table['group_list'] = [] - instance_list = [] - usable_dev = '' - for instance_id in range(args.nproc_per_node): - instance = {} - instance['devices'] = [] - device_id = visible_devices[instance_id] - device_ip = device_ips[device_id] - usable_dev += str(device_id) - instance['devices'].append({ - 'device_id': device_id, - 'device_ip': device_ip, - }) - instance['rank_id'] = str(instance_id) - instance['server_id'] = args.server_id - instance_list.append(instance) - hccn_table['group_list'].append({ - 'device_num': str(args.nproc_per_node), - 'server_num': '1', - 'group_name': '', - 'instance_count': str(args.nproc_per_node), - 'instance_list': instance_list, - }) - hccn_table['para_plane_nic_location'] = 'device' - hccn_table['para_plane_nic_name'] = [] - for instance_id in range(args.nproc_per_node): - eth_id = visible_devices[instance_id] - hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) - hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) - hccn_table['status'] = 'completed' - - # save hccn_table to file - table_path = os.getcwd() - if not os.path.exists(table_path): - os.mkdir(table_path) - table_fn = os.path.join(table_path, - 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) - with open(table_fn, 'w') as table_fp: - json.dump(hccn_table, table_fp, indent=4) - sys.stdout.flush() - - # spawn the processes - processes = [] - cmds = [] - log_files = [] - env = os.environ.copy() - env['RANK_SIZE'] = str(args.nproc_per_node) - cur_path = os.getcwd() - for rank_id in range(0, args.nproc_per_node): - os.chdir(cur_path) - device_id = visible_devices[rank_id] - device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) - env['RANK_ID'] = str(rank_id) - env['DEVICE_ID'] = str(device_id) - if args.nproc_per_node > 1: - env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn - env['RANK_TABLE_FILE'] = table_fn - if os.path.exists(device_dir): - shutil.rmtree(device_dir) - os.mkdir(device_dir) - os.chdir(device_dir) - cmd = [sys.executable, '-u'] - cmd.append(args.training_script) - cmd.extend(args.training_script_args) - log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') - process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) - processes.append(process) - cmds.append(cmd) - log_files.append(log_file) - for process, cmd, log_file in zip(processes, cmds, log_files): - process.wait() - if process.returncode != 0: - raise subprocess.CalledProcessError(returncode=process, cmd=cmd) - log_file.close() - - -if __name__ == "__main__": - main() diff --git a/model_zoo/mobilenetv2_quant/src/mobilenetV2.py b/model_zoo/mobilenetv2_quant/src/mobilenetV2.py deleted file mode 100644 index 25dccfed10..0000000000 --- a/model_zoo/mobilenetv2_quant/src/mobilenetV2.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV2 Quant model define""" - -import numpy as np - -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore import Tensor - -__all__ = ['mobilenetV2'] - - -def _make_divisible(v, divisor, min_value=None): - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class GlobalAvgPooling(nn.Cell): - """ - Global avg pooling definition. - - Args: - - Returns: - Tensor, output tensor. - - Examples: - >>> GlobalAvgPooling() - """ - - def __init__(self): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(keep_dims=False) - - def construct(self, x): - x = self.mean(x, (2, 3)) - return x - - -class ConvBNReLU(nn.Cell): - """ - Convolution/Depthwise fused with Batchnorm and ReLU block definition. - - Args: - in_planes (int): Input channel. - out_planes (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size for the first convolutional layer. Default: 1. - groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) - """ - - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, - stride=stride, - pad_mode='pad', - padding=padding, - group=groups, - has_bn=True, - activation='relu') - - def construct(self, x): - x = self.conv(x) - return x - - -class InvertedResidual(nn.Cell): - """ - Mobilenetv2 residual block definition. - - Args: - inp (int): Input channel. - oup (int): Output channel. - stride (int): Stride size for the first convolutional layer. Default: 1. - expand_ratio (int): expand ration of input channel - - Returns: - Tensor, output tensor. - - Examples: - >>> ResidualBlock(3, 256, 1, 1) - """ - - def __init__(self, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - assert stride in [1, 2] - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = stride == 1 and inp == oup - - layers = [] - if expand_ratio != 1: - layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) - ]) - self.conv = nn.SequentialCell(layers) - self.add = P.TensorAdd() - - def construct(self, x): - out = self.conv(x) - if self.use_res_connect: - out = self.add(out, x) - return out - - -class mobilenetV2(nn.Cell): - """ - mobilenetV2 fusion architecture. - - Args: - class_num (Cell): number of classes. - width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. - has_dropout (bool): Is dropout used. Default is false - inverted_residual_setting (list): Inverted residual settings. Default is None - round_nearest (list): Channel round to . Default is 8 - Returns: - Tensor, output tensor. - - Examples: - >>> mobilenetV2(num_classes=1000) - """ - - def __init__(self, num_classes=1000, width_mult=1., - has_dropout=False, inverted_residual_setting=None, round_nearest=8): - super(mobilenetV2, self).__init__() - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - # setting of inverted residual blocks - self.cfgs = inverted_residual_setting - if inverted_residual_setting is None: - self.cfgs = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, round_nearest) - self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - - features = [ConvBNReLU(3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in self.cfgs: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 - features.append(block(input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1)) - # make it nn.CellList - self.features = nn.SequentialCell(features) - # mobilenet head - head = ([GlobalAvgPooling(), - nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False) - ] if not has_dropout else - [GlobalAvgPooling(), - nn.Dropout(0.2), - nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False) - ]) - self.head = nn.SequentialCell(head) - - # init weights - self._initialize_weights() - - def construct(self, x): - x = self.features(x) - x = self.head(x) - return x - - def _initialize_weights(self): - """ - Initialize weights. - - Args: - - Returns: - None. - - Examples: - >>> _initialize_weights() - """ - for _, m in self.cells_and_names(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32")) - m.weight.set_parameter_data(w) - if m.bias is not None: - m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) diff --git a/model_zoo/mobilenetv2_quant/train.py b/model_zoo/mobilenetv2_quant/train.py deleted file mode 100644 index 1302c3cf27..0000000000 --- a/model_zoo/mobilenetv2_quant/train.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Train mobilenetV2 on ImageNet""" - -import os -import argparse -import random -import numpy as np - -from mindspore import context -from mindspore import Tensor -from mindspore import nn -from mindspore.train.model import Model, ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.communication.management import init -from mindspore.train.quant import quant -import mindspore.dataset.engine as de - -from src.dataset import create_dataset -from src.lr_generator import get_lr -from src.utils import Monitor, CrossEntropyWithLabelSmooth -from src.config import config_ascend, config_ascend_quant -from src.mobilenetV2 import mobilenetV2 - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path') -parser.add_argument('--device_target', type=str, default=None, help='Run device target') -parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training') -args_opt = parser.parse_args() - -if args_opt.device_target == "Ascend": - device_id = int(os.getenv('DEVICE_ID')) - rank_id = int(os.getenv('RANK_ID')) - rank_size = int(os.getenv('RANK_SIZE')) - run_distribute = rank_size > 1 - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - device_id=device_id, save_graphs=False) -else: - raise ValueError("Unsupported device target.") - -if __name__ == '__main__': - # train on ascend - config = config_ascend_quant if args_opt.quantization_aware else config_ascend - print("training args: {}".format(args_opt)) - print("training configure: {}".format(config)) - print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) - epoch_size = config.epoch_size - - # distribute init - if run_distribute: - context.set_auto_parallel_context(device_num=rank_size, - parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, - mirror_mean=True) - init() - - # define network - network = mobilenetV2(num_classes=config.num_classes) - # define loss - if config.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes) - else: - loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') - # define dataset - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config, - device_target=args_opt.device_target, - repeat_num=epoch_size, - batch_size=config.batch_size) - step_size = dataset.get_dataset_size() - # load pre trained ckpt - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(network, param_dict) - - # convert fusion network to quantization aware network - if config.quantization_aware: - network = quant.convert_quant_network(network, - bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) - - # get learning rate - lr = Tensor(get_lr(global_step=config.start_epoch * step_size, - lr_init=0, - lr_end=0, - lr_max=config.lr, - warmup_epochs=config.warmup_epochs, - total_epochs=epoch_size + config.start_epoch, - steps_per_epoch=step_size)) - - # define optimization - opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum, - config.weight_decay) - # define model - model = Model(network, loss_fn=loss, optimizer=opt) - - print("============== Starting Training ==============") - callback = None - if rank_id == 0: - callback = [Monitor(lr_init=lr.asnumpy())] - if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", - directory=config.save_checkpoint_path, - config=config_ck) - callback += [ckpt_cb] - model.train(epoch_size, dataset, callbacks=callback) - print("============== End Training ==============") diff --git a/model_zoo/mobilenetv3/src/launch.py b/model_zoo/mobilenetv3/src/launch.py deleted file mode 100644 index 48c8159664..0000000000 --- a/model_zoo/mobilenetv3/src/launch.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""launch train script""" -import os -import sys -import json -import subprocess -import shutil -from argparse import ArgumentParser - -def parse_args(): - """ - parse args . - - Args: - - Returns: - args. - - Examples: - >>> parse_args() - """ - parser = ArgumentParser(description="mindspore distributed training launch " - "helper utilty that will spawn up " - "multiple distributed processes") - parser.add_argument("--nproc_per_node", type=int, default=1, - help="The number of processes to launch on each node, " - "for D training, this is recommended to be set " - "to the number of D in your system so that " - "each process can be bound to a single D.") - parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", - help="will use the visible devices sequentially") - parser.add_argument("--server_id", type=str, default="", - help="server ip") - parser.add_argument("--training_script", type=str, - help="The full path to the single D training " - "program/script to be launched in parallel, " - "followed by all the arguments for the " - "training script") - # rest from the training program - args, unknown = parser.parse_known_args() - args.training_script_args = unknown - return args - - -def main(): - print("start", __file__) - args = parse_args() - print(args) - visible_devices = args.visible_devices.split(',') - assert os.path.isfile(args.training_script) - assert len(visible_devices) >= args.nproc_per_node - print('visible_devices:{}'.format(visible_devices)) - if not args.server_id: - print('pleaser input server ip!!!') - exit(0) - print('server_id:{}'.format(args.server_id)) - - # construct hccn_table - hccn_configs = open('/etc/hccn.conf', 'r').readlines() - device_ips = {} - for hccn_item in hccn_configs: - hccn_item = hccn_item.strip() - if hccn_item.startswith('address_'): - device_id, device_ip = hccn_item.split('=') - device_id = device_id.split('_')[1] - device_ips[device_id] = device_ip - print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) - hccn_table = {} - hccn_table['board_id'] = '0x0000' - hccn_table['chip_info'] = '910' - hccn_table['deploy_mode'] = 'lab' - hccn_table['group_count'] = '1' - hccn_table['group_list'] = [] - instance_list = [] - usable_dev = '' - for instance_id in range(args.nproc_per_node): - instance = {} - instance['devices'] = [] - device_id = visible_devices[instance_id] - device_ip = device_ips[device_id] - usable_dev += str(device_id) - instance['devices'].append({ - 'device_id': device_id, - 'device_ip': device_ip, - }) - instance['rank_id'] = str(instance_id) - instance['server_id'] = args.server_id - instance_list.append(instance) - hccn_table['group_list'].append({ - 'device_num': str(args.nproc_per_node), - 'server_num': '1', - 'group_name': '', - 'instance_count': str(args.nproc_per_node), - 'instance_list': instance_list, - }) - hccn_table['para_plane_nic_location'] = 'device' - hccn_table['para_plane_nic_name'] = [] - for instance_id in range(args.nproc_per_node): - eth_id = visible_devices[instance_id] - hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) - hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) - hccn_table['status'] = 'completed' - - # save hccn_table to file - table_path = os.getcwd() - if not os.path.exists(table_path): - os.mkdir(table_path) - table_fn = os.path.join(table_path, - 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) - with open(table_fn, 'w') as table_fp: - json.dump(hccn_table, table_fp, indent=4) - sys.stdout.flush() - - # spawn the processes - processes = [] - cmds = [] - log_files = [] - env = os.environ.copy() - env['RANK_SIZE'] = str(args.nproc_per_node) - cur_path = os.getcwd() - for rank_id in range(0, args.nproc_per_node): - os.chdir(cur_path) - device_id = visible_devices[rank_id] - device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) - env['RANK_ID'] = str(rank_id) - env['DEVICE_ID'] = str(device_id) - if args.nproc_per_node > 1: - env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn - env['RANK_TABLE_FILE'] = table_fn - if os.path.exists(device_dir): - shutil.rmtree(device_dir) - os.mkdir(device_dir) - os.chdir(device_dir) - cmd = [sys.executable, '-u'] - cmd.append(args.training_script) - cmd.extend(args.training_script_args) - log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') - process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) - processes.append(process) - cmds.append(cmd) - log_files.append(log_file) - for process, cmd, log_file in zip(processes, cmds, log_files): - process.wait() - if process.returncode != 0: - raise subprocess.CalledProcessError(returncode=process, cmd=cmd) - log_file.close() - - -if __name__ == "__main__": - main() diff --git a/model_zoo/mobilenetv3/src/mobilenetV3.py b/model_zoo/mobilenetv3/src/mobilenetV3.py deleted file mode 100644 index 61b63f9ea1..0000000000 --- a/model_zoo/mobilenetv3/src/mobilenetV3.py +++ /dev/null @@ -1,390 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV3 model define""" -from functools import partial -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore import Tensor - - -__all__ = ['mobilenet_v3_large', - 'mobilenet_v3_small'] - - -def _make_divisible(x, divisor=8): - return int(np.ceil(x * 1. / divisor) * divisor) - - -class Activation(nn.Cell): - """ - Activation definition. - - Args: - act_func(string): activation name. - - Returns: - Tensor, output tensor. - """ - - def __init__(self, act_func): - super(Activation, self).__init__() - if act_func == 'relu': - self.act = nn.ReLU() - elif act_func == 'relu6': - self.act = nn.ReLU6() - elif act_func in ('hsigmoid', 'hard_sigmoid'): - self.act = nn.HSigmoid() - elif act_func in ('hswish', 'hard_swish'): - self.act = nn.HSwish() - else: - raise NotImplementedError - - def construct(self, x): - return self.act(x) - - -class GlobalAvgPooling(nn.Cell): - """ - Global avg pooling definition. - - Args: - - Returns: - Tensor, output tensor. - - Examples: - >>> GlobalAvgPooling() - """ - - def __init__(self, keep_dims=False): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(keep_dims=keep_dims) - - def construct(self, x): - x = self.mean(x, (2, 3)) - return x - - -class SE(nn.Cell): - """ - SE warpper definition. - - Args: - num_out (int): Output channel. - ratio (int): middle output ratio. - - Returns: - Tensor, output tensor. - - Examples: - >>> SE(4) - """ - - def __init__(self, num_out, ratio=4): - super(SE, self).__init__() - num_mid = _make_divisible(num_out // ratio) - self.pool = GlobalAvgPooling(keep_dims=True) - self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid, - kernel_size=1, has_bias=True, pad_mode='pad') - self.act1 = Activation('relu') - self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out, - kernel_size=1, has_bias=True, pad_mode='pad') - self.act2 = Activation('hsigmoid') - self.mul = P.Mul() - - def construct(self, x): - out = self.pool(x) - out = self.conv1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.act2(out) - out = self.mul(x, out) - return out - - -class Unit(nn.Cell): - """ - Unit warpper definition. - - Args: - num_in (int): Input channel. - num_out (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - padding (int): Padding number. - num_groups (int): Output num group. - use_act (bool): Used activation or not. - act_type (string): Activation type. - - Returns: - Tensor, output tensor. - - Examples: - >>> Unit(3, 3) - """ - - def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1, - use_act=True, act_type='relu'): - super(Unit, self).__init__() - self.conv = nn.Conv2d(in_channels=num_in, - out_channels=num_out, - kernel_size=kernel_size, - stride=stride, - padding=padding, - group=num_groups, - has_bias=False, - pad_mode='pad') - self.bn = nn.BatchNorm2d(num_out) - self.use_act = use_act - self.act = Activation(act_type) if use_act else None - - def construct(self, x): - out = self.conv(x) - out = self.bn(out) - if self.use_act: - out = self.act(out) - return out - - -class ResUnit(nn.Cell): - """ - ResUnit warpper definition. - - Args: - num_in (int): Input channel. - num_mid (int): Middle channel. - num_out (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - act_type (str): Activation type. - use_se (bool): Use SE warpper or not. - - Returns: - Tensor, output tensor. - - Examples: - >>> ResUnit(16, 3, 1, 1) - """ - def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): - super(ResUnit, self).__init__() - self.use_se = use_se - self.first_conv = (num_out != num_mid) - self.use_short_cut_conv = True - - if self.first_conv: - self.expand = Unit(num_in, num_mid, kernel_size=1, - stride=1, padding=0, act_type=act_type) - else: - self.expand = None - self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride, - padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid) - if use_se: - self.se = SE(num_mid) - self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1, - padding=0, act_type=act_type, use_act=False) - if num_in != num_out or stride != 1: - self.use_short_cut_conv = False - self.add = P.TensorAdd() if self.use_short_cut_conv else None - - def construct(self, x): - if self.first_conv: - out = self.expand(x) - else: - out = x - out = self.conv1(out) - if self.use_se: - out = self.se(out) - out = self.conv2(out) - if self.use_short_cut_conv: - out = self.add(x, out) - return out - - def _get_pad(self, kernel_size): - """set the padding number""" - pad = 0 - if kernel_size == 1: - pad = 0 - elif kernel_size == 3: - pad = 1 - elif kernel_size == 5: - pad = 2 - elif kernel_size == 7: - pad = 3 - else: - raise NotImplementedError - return pad - - -class MobileNetV3(nn.Cell): - """ - MobileNetV3 architecture. - - Args: - model_cfgs (Cell): number of classes. - num_classes (int): Output number classes. - multiplier (int): Channels multiplier for round to 8/16 and others. Default is 1. - final_drop (float): Dropout number. - round_nearest (list): Channel round to . Default is 8. - Returns: - Tensor, output tensor. - - Examples: - >>> MobileNetV3(num_classes=1000) - """ - - def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8): - super(MobileNetV3, self).__init__() - self.cfgs = model_cfgs['cfg'] - self.inplanes = 16 - self.features = [] - first_conv_in_channel = 3 - first_conv_out_channel = _make_divisible(multiplier * self.inplanes) - - self.features.append(nn.Conv2d(in_channels=first_conv_in_channel, - out_channels=first_conv_out_channel, - kernel_size=3, padding=1, stride=2, - has_bias=False, pad_mode='pad')) - self.features.append(nn.BatchNorm2d(first_conv_out_channel)) - self.features.append(Activation('hswish')) - for layer_cfg in self.cfgs: - self.features.append(self._make_layer(kernel_size=layer_cfg[0], - exp_ch=_make_divisible(multiplier * layer_cfg[1]), - out_channel=_make_divisible(multiplier * layer_cfg[2]), - use_se=layer_cfg[3], - act_func=layer_cfg[4], - stride=layer_cfg[5])) - output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"]) - self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]), - out_channels=output_channel, - kernel_size=1, padding=0, stride=1, - has_bias=False, pad_mode='pad')) - self.features.append(nn.BatchNorm2d(output_channel)) - self.features.append(Activation('hswish')) - self.features.append(GlobalAvgPooling(keep_dims=True)) - self.features.append(nn.Conv2d(in_channels=output_channel, - out_channels=model_cfgs['cls_ch_expand'], - kernel_size=1, padding=0, stride=1, - has_bias=False, pad_mode='pad')) - self.features.append(Activation('hswish')) - if final_drop > 0: - self.features.append((nn.Dropout(final_drop))) - - # make it nn.CellList - self.features = nn.SequentialCell(self.features) - self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], - out_channels=num_classes, - kernel_size=1, has_bias=True, pad_mode='pad') - self.squeeze = P.Squeeze(axis=(2, 3)) - - self._initialize_weights() - - def construct(self, x): - x = self.features(x) - x = self.output(x) - x = self.squeeze(x) - return x - - def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): - mid_planes = exp_ch - out_planes = out_channel - #num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): - layer = ResUnit(self.inplanes, mid_planes, out_planes, - kernel_size, stride=stride, act_type=act_func, use_se=use_se) - self.inplanes = out_planes - return layer - - def _initialize_weights(self): - """ - Initialize weights. - - Args: - - Returns: - None. - - Examples: - >>> _initialize_weights() - """ - for _, m in self.cells_and_names(): - if isinstance(m, (nn.Conv2d)): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), - m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_parameter_data( - Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_parameter_data( - Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - m.weight.set_parameter_data(Tensor(np.random.normal( - 0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_parameter_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - -def mobilenet_v3(model_name, **kwargs): - """ - Constructs a MobileNet V2 model - """ - model_cfgs = { - "large": { - "cfg": [ - # k, exp, c, se, nl, s, - [3, 16, 16, False, 'relu', 1], - [3, 64, 24, False, 'relu', 2], - [3, 72, 24, False, 'relu', 1], - [5, 72, 40, True, 'relu', 2], - [5, 120, 40, True, 'relu', 1], - [5, 120, 40, True, 'relu', 1], - [3, 240, 80, False, 'hswish', 2], - [3, 200, 80, False, 'hswish', 1], - [3, 184, 80, False, 'hswish', 1], - [3, 184, 80, False, 'hswish', 1], - [3, 480, 112, True, 'hswish', 1], - [3, 672, 112, True, 'hswish', 1], - [5, 672, 160, True, 'hswish', 2], - [5, 960, 160, True, 'hswish', 1], - [5, 960, 160, True, 'hswish', 1]], - "cls_ch_squeeze": 960, - "cls_ch_expand": 1280, - }, - "small": { - "cfg": [ - # k, exp, c, se, nl, s, - [3, 16, 16, True, 'relu', 2], - [3, 72, 24, False, 'relu', 2], - [3, 88, 24, False, 'relu', 1], - [5, 96, 40, True, 'hswish', 2], - [5, 240, 40, True, 'hswish', 1], - [5, 240, 40, True, 'hswish', 1], - [5, 120, 48, True, 'hswish', 1], - [5, 144, 48, True, 'hswish', 1], - [5, 288, 96, True, 'hswish', 2], - [5, 576, 96, True, 'hswish', 1], - [5, 576, 96, True, 'hswish', 1]], - "cls_ch_squeeze": 576, - "cls_ch_expand": 1280, - } - } - return MobileNetV3(model_cfgs[model_name], **kwargs) - - -mobilenet_v3_large = partial(mobilenet_v3, model_name="large") -mobilenet_v3_small = partial(mobilenet_v3, model_name="small") diff --git a/model_zoo/mobilenetv3/train.py b/model_zoo/mobilenetv3/train.py deleted file mode 100644 index 57199ec1a7..0000000000 --- a/model_zoo/mobilenetv3/train.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train_imagenet.""" -import os -import time -import argparse -import random -import numpy as np - -from mindspore import context -from mindspore import Tensor -from mindspore import nn -from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.nn.optim.momentum import Momentum -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.nn.loss.loss import _Loss -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.common import dtype as mstype -from mindspore.train.model import Model, ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback -from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.serialization import load_checkpoint, load_param_into_net -import mindspore.dataset.engine as de -from mindspore.communication.management import init, get_group_size, get_rank - -from src.dataset import create_dataset -from src.lr_generator import get_lr -from src.config import config_gpu, config_ascend -from src.mobilenetV3 import mobilenet_v3_large - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') -parser.add_argument('--platform', type=str, default=None, help='run platform') -args_opt = parser.parse_args() - -if args_opt.platform == "Ascend": - device_id = int(os.getenv('DEVICE_ID')) - rank_id = int(os.getenv('RANK_ID')) - rank_size = int(os.getenv('RANK_SIZE')) - run_distribute = rank_size > 1 - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - device_id=device_id, - save_graphs=False) -elif args_opt.platform == "GPU": - context.set_context(mode=context.GRAPH_MODE, - device_target="GPU", - save_graphs=False) - init("nccl") - context.set_auto_parallel_context(device_num=get_group_size(), - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) -else: - raise ValueError("Unsupport platform.") - - -class CrossEntropyWithLabelSmooth(_Loss): - """ - CrossEntropyWith LabelSmooth. - - Args: - smooth_factor (float): smooth factor, default=0. - num_classes (int): num classes - - Returns: - None. - - Examples: - >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000) - """ - - def __init__(self, smooth_factor=0., num_classes=1000): - super(CrossEntropyWithLabelSmooth, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) - self.off_value = Tensor(1.0 * smooth_factor / - (num_classes - 1), mstype.float32) - self.ce = nn.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean(False) - self.cast = P.Cast() - - def construct(self, logit, label): - one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], - self.on_value, self.off_value) - out_loss = self.ce(logit, one_hot_label) - out_loss = self.mean(out_loss, 0) - return out_loss - - -class Monitor(Callback): - """ - Monitor loss and time. - - Args: - lr_init (numpy array): train lr - - Returns: - None - - Examples: - >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) - """ - - def __init__(self, lr_init=None): - super(Monitor, self).__init__() - self.lr_init = lr_init - self.lr_init_len = len(lr_init) - - def epoch_begin(self, run_context): - self.losses = [] - self.epoch_time = time.time() - - def epoch_end(self, run_context): - cb_params = run_context.original_args() - - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / cb_params.batch_num - print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, - per_step_mseconds, - np.mean(self.losses))) - - def step_begin(self, run_context): - self.step_time = time.time() - - def step_end(self, run_context): - cb_params = run_context.original_args() - step_mseconds = (time.time() - self.step_time) * 1000 - step_loss = cb_params.net_outputs - - if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): - step_loss = step_loss[0] - if isinstance(step_loss, Tensor): - step_loss = np.mean(step_loss.asnumpy()) - - self.losses.append(step_loss) - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num - - print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( - cb_params.cur_epoch_num - - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, - np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) - - -if __name__ == '__main__': - if args_opt.platform == "GPU": - # train on gpu - print("train args: ", args_opt) - print("cfg: ", config_gpu) - - # define net - net = mobilenet_v3_large(num_classes=config_gpu.num_classes) - # define loss - if config_gpu.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth( - smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes) - else: - loss = SoftmaxCrossEntropyWithLogits( - is_grad=False, sparse=True, reduction='mean') - # define dataset - epoch_size = config_gpu.epoch_size - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config_gpu, - platform=args_opt.platform, - repeat_num=epoch_size, - batch_size=config_gpu.batch_size) - step_size = dataset.get_dataset_size() - # resume - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - # define optimizer - loss_scale = FixedLossScaleManager( - config_gpu.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(global_step=0, - lr_init=0, - lr_end=0, - lr_max=config_gpu.lr, - warmup_epochs=config_gpu.warmup_epochs, - total_epochs=epoch_size, - steps_per_epoch=step_size)) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, - config_gpu.weight_decay, config_gpu.loss_scale) - # define model - model = Model(net, loss_fn=loss, optimizer=opt, - loss_scale_manager=loss_scale) - - cb = [Monitor(lr_init=lr.asnumpy())] - ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" - if config_gpu.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config_gpu.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) - cb += [ckpt_cb] - # begine train - model.train(epoch_size, dataset, callbacks=cb) - elif args_opt.platform == "Ascend": - # train on ascend - print("train args: ", args_opt, "\ncfg: ", config_ascend, - "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) - - if run_distribute: - context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, mirror_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) - init() - - epoch_size = config_ascend.epoch_size - net = mobilenet_v3_large(num_classes=config_ascend.num_classes) - net.to_float(mstype.float16) - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.to_float(mstype.float32) - if config_ascend.label_smooth > 0: - loss = CrossEntropyWithLabelSmooth( - smooth_factor=config_ascend.label_smooth, num_classes=config.num_classes) - else: - loss = SoftmaxCrossEntropyWithLogits( - is_grad=False, sparse=True, reduction='mean') - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - config=config_ascend, - platform=args_opt.platform, - repeat_num=epoch_size, - batch_size=config_ascend.batch_size) - step_size = dataset.get_dataset_size() - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - - loss_scale = FixedLossScaleManager( - config_ascend.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(global_step=0, - lr_init=0, - lr_end=0, - lr_max=config_ascend.lr, - warmup_epochs=config_ascend.warmup_epochs, - total_epochs=epoch_size, - steps_per_epoch=step_size)) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum, - config_ascend.weight_decay, config_ascend.loss_scale) - - model = Model(net, loss_fn=loss, optimizer=opt, - loss_scale_manager=loss_scale) - - cb = None - if rank_id == 0: - cb = [Monitor(lr_init=lr.asnumpy())] - if config_ascend.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config_ascend.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint( - prefix="mobilenetV3", directory=config_ascend.save_checkpoint_path, config=config_ck) - cb += [ckpt_cb] - model.train(epoch_size, dataset, callbacks=cb) - else: - raise Exception diff --git a/model_zoo/official/README.md b/model_zoo/official/README.md new file mode 100644 index 0000000000..e921cfe3df --- /dev/null +++ b/model_zoo/official/README.md @@ -0,0 +1,73 @@ +![](https://www.mindspore.cn/static/img/logo.a3e472c9.png) + + +# Welcome to the Model Zoo for MindSpore + +In order to facilitate developers to enjoy the benefits of MindSpore framework, we will continue to add typical networks and some of the related pre-trained models. If you have needs for the model zoo, you can file an issue on [gitee](https://gitee.com/mindspore/mindspore/issues) or [MindSpore](https://bbs.huaweicloud.com/forum/forum-1076-1.html), We will consider it in time. + +- SOTA models using the latest MindSpore APIs + +- The best benefits from MindSpore + +- Officially maintained and supported + + + +# Table of Contents + +- [Models](#models) + - [Computer Vision](#computer-vision) + - [Image Classification](#image-classification) + - [GoogleNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet/README.md) + - [ResNet50[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/README.md) + - [ResNet50_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/resnet_quant/README.md) + - [ResNet101](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet/README.md) + - [ResNext50](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnext50/README.md) + - [VGG16](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg16/README.md) + - [AlexNet](#https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/alexnet/README.md) + - [LeNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/README.md) + - [LeNet](#https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet_quant/README.md) + - [Object Detection and Segmentation](#object-detection-and-segmentation) + - [DeepLabV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3/README.md) + - [FasterRCNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/faster_rcnn/README.md) + - [YoloV3-DarkNet53](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53/README.md) + - [YoloV3-ResNet18](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_resnet18/README.md) + - [MobileNetV2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2/README.md) + - [MobileNetV2_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2_quant/README.md) + - [MobileNetV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv3/README.md) + - [SSD](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd/README.md) + - [Warp-CTC](#https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/warpctc/README.md) + + - [Natural Language Processing](#natural-language-processing) + - [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md) + - [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md) + - [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md) + - [Recommendation](#recommendation) + - [DeepFM](#https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md) + - [Wide&Deep[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md) + - [Graph Neural Networks](#gnn) + - [GAT](#https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gat/README.md) + - [GCN](#https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn//README.md) + + + +# Announcements +| Date | News | +| ------------ | ------------------------------------------------------------ | +| June 30, 2020 | Support [MindSpore v0.6.0-beta](https://www.mindspore.cn/news/newschildren?id=221) | + + + +# Disclaimers + +Mindspore only provides scripts that downloads and preprocesses public datasets. We do not own these datasets and are not responsible for their quality or maintenance. Please make sure you have permission to use the dataset under the dataset’s license. + +To dataset owners: we will remove or update all public content upon request if you don’t want your dataset included on Mindspore, or wish to update it in any way. Please contact us through a Github/Gitee issue. Your understanding and contribution to this community is greatly appreciated. + +MindSpore is Apache 2.0 licensed. Please see the LICENSE file. + + + +# License + +[Apache License 2.0](https://gitee.com/mindspore/mindspore/blob/master/LICENSE) diff --git a/model_zoo/official/audio/.gitkeep b/model_zoo/official/audio/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/alexnet/README.md b/model_zoo/official/cv/alexnet/README.md similarity index 100% rename from model_zoo/alexnet/README.md rename to model_zoo/official/cv/alexnet/README.md diff --git a/model_zoo/alexnet/eval.py b/model_zoo/official/cv/alexnet/eval.py similarity index 100% rename from model_zoo/alexnet/eval.py rename to model_zoo/official/cv/alexnet/eval.py diff --git a/model_zoo/official/cv/alexnet/src/__init__.py b/model_zoo/official/cv/alexnet/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/alexnet/src/alexnet.py b/model_zoo/official/cv/alexnet/src/alexnet.py new file mode 100644 index 0000000000..2e333a0ad5 --- /dev/null +++ b/model_zoo/official/cv/alexnet/src/alexnet.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Alexnet.""" +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"): + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode=pad_mode) + +def fc_with_initialize(input_channels, out_channels): + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + +def weight_variable(): + return TruncatedNormal(0.02) + + +class AlexNet(nn.Cell): + """ + Alexnet + """ + def __init__(self, num_classes=10, channel=3): + super(AlexNet, self).__init__() + self.conv1 = conv(channel, 96, 11, stride=4) + self.conv2 = conv(96, 256, 5, pad_mode="same") + self.conv3 = conv(256, 384, 3, pad_mode="same") + self.conv4 = conv(384, 384, 3, pad_mode="same") + self.conv5 = conv(384, 256, 3, pad_mode="same") + self.relu = nn.ReLU() + self.max_pool2d = P.MaxPool(ksize=3, strides=2) + self.flatten = nn.Flatten() + self.fc1 = fc_with_initialize(6*6*256, 4096) + self.fc2 = fc_with_initialize(4096, 4096) + self.fc3 = fc_with_initialize(4096, num_classes) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.relu(x) + x = self.conv5(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x diff --git a/model_zoo/alexnet/src/config.py b/model_zoo/official/cv/alexnet/src/config.py similarity index 100% rename from model_zoo/alexnet/src/config.py rename to model_zoo/official/cv/alexnet/src/config.py diff --git a/model_zoo/alexnet/src/dataset.py b/model_zoo/official/cv/alexnet/src/dataset.py similarity index 100% rename from model_zoo/alexnet/src/dataset.py rename to model_zoo/official/cv/alexnet/src/dataset.py diff --git a/model_zoo/alexnet/src/generator_lr.py b/model_zoo/official/cv/alexnet/src/generator_lr.py similarity index 100% rename from model_zoo/alexnet/src/generator_lr.py rename to model_zoo/official/cv/alexnet/src/generator_lr.py diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py new file mode 100644 index 0000000000..4512244b92 --- /dev/null +++ b/model_zoo/official/cv/alexnet/train.py @@ -0,0 +1,59 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +######################## train alexnet example ######################## +train alexnet and get network model files(.ckpt) : +python train.py --data_path /YourDataPath +""" + +import argparse +from src.config import alexnet_cfg as cfg +from src.dataset import create_dataset_cifar10 +from src.generator_lr import get_lr +from src.alexnet import AlexNet +import mindspore.nn as nn +from mindspore import context +from mindspore import Tensor +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='MindSpore AlexNet Example') + parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], + help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') + parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ + path where the trained ckpt file') + parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + args = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + + ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1) + network = AlexNet(cfg.num_classes) + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size())) + opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum) + model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck) + + print("============== Starting Training ==============") + model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], + dataset_sink_mode=args.dataset_sink_mode) diff --git a/model_zoo/official/cv/deeplabv3/README.md b/model_zoo/official/cv/deeplabv3/README.md new file mode 100644 index 0000000000..98c9e500fc --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/README.md @@ -0,0 +1,70 @@ +# DeeplabV3 Example + +## Description +This is an example of training DeepLabV3 with PASCAL VOC 2012 dataset in MindSpore. + +## Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download the VOC 2012 dataset for training. + ``` bash + python remove_gt_colormap.py --original_gt_folder GT_FOLDER --output_dir OUTPUT_DIR + + ``` + +> Notes: + If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file. + + +## Running the Example +### Training +- Set options in config.py. +- Run `run_standalone_train.sh` for non-distributed training. + ``` bash + sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH + ``` +- Run `run_distribute_train.sh` for distributed training. + ``` bash + sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH + ``` +### Evaluation +Set options in evaluation_config.py. Make sure the 'data_file' and 'finetune_ckpt' are set to your own path. +- Run run_eval.sh for evaluation. + ``` bash + sh scripts/run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH + ``` + +## Options and Parameters +It contains of parameters of DeeplabV3 model and options for training, which is set in file config.py. + +### Options: +``` +config.py: + learning_rate Learning rate, default is 0.0014. + weight_decay Weight decay, default is 5e-5. + momentum Momentum, default is 0.97. + crop_size Image crop size [height, width] during training, default is 513. + eval_scales The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]. + output_stride The ratio of input to output spatial resolution, default is 16. + ignore_label Ignore label value, default is 255. + seg_num_classes Number of semantic classes, including the background class (if exists). + foreground classes + 1 background class in the PASCAL VOC 2012 dataset, default is 21. + fine_tune_batch_norm Fine tune the batch norm parameters or not, default is False. + atrous_rates Atrous rates for atrous spatial pyramid pooling, default is None. + decoder_output_stride The ratio of input to output spatial resolution when employing decoder + to refine segmentation results, default is None. + image_pyramid Input scales for multi-scale feature extraction, default is None. + epoch_size Epoch size, default is 6. + batch_size batch size of input dataset: N, default is 2. + enable_save_ckpt Enable save checkpoint, default is true. + save_checkpoint_steps Save checkpoint steps, default is 1000. + save_checkpoint_num Save checkpoint numbers, default is 1. +``` + + +### Parameters: +``` +Parameters for dataset and network: + distribute Run distribute, default is false. + data_url Train/Evaluation data url, required. + checkpoint_url Checkpoint path, default is None. +``` \ No newline at end of file diff --git a/model_zoo/deeplabv3/eval.py b/model_zoo/official/cv/deeplabv3/eval.py similarity index 100% rename from model_zoo/deeplabv3/eval.py rename to model_zoo/official/cv/deeplabv3/eval.py diff --git a/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train.sh b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..2a4867548e --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/scripts/run_distribute_train.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH" +echo "for example: bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH [PRETRAINED_CKPT_PATH](option)" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +DATA_DIR=$2 + +export RANK_TABLE_FILE=$1 +export RANK_SIZE=8 +export DEVICE_NUM=8 +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi +cores=`cat /proc/cpuinfo|grep "processor" |wc -l` +echo "the number of logical core" $cores +avg_core_per_rank=`expr $cores \/ $RANK_SIZE` +core_gap=`expr $avg_core_per_rank \- 1` +echo "avg_core_per_rank" $avg_core_per_rank +echo "core_gap" $core_gap +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) +for((i=0;i env.log + taskset -c $cmdopt python ../train.py \ + --distribute="true" \ + --device_id=$DEVICE_ID \ + --checkpoint_url=$PATH_CHECKPOINT \ + --data_url=$DATA_DIR > log.txt 2>&1 & + cd ../ +done \ No newline at end of file diff --git a/model_zoo/deeplabv3/scripts/run_eval.sh b/model_zoo/official/cv/deeplabv3/scripts/run_eval.sh similarity index 100% rename from model_zoo/deeplabv3/scripts/run_eval.sh rename to model_zoo/official/cv/deeplabv3/scripts/run_eval.sh diff --git a/model_zoo/deeplabv3/scripts/run_standalone_train.sh b/model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/deeplabv3/scripts/run_standalone_train.sh rename to model_zoo/official/cv/deeplabv3/scripts/run_standalone_train.sh diff --git a/model_zoo/deeplabv3/src/__init__.py b/model_zoo/official/cv/deeplabv3/src/__init__.py similarity index 100% rename from model_zoo/deeplabv3/src/__init__.py rename to model_zoo/official/cv/deeplabv3/src/__init__.py diff --git a/model_zoo/deeplabv3/src/backbone/__init__.py b/model_zoo/official/cv/deeplabv3/src/backbone/__init__.py similarity index 100% rename from model_zoo/deeplabv3/src/backbone/__init__.py rename to model_zoo/official/cv/deeplabv3/src/backbone/__init__.py diff --git a/model_zoo/deeplabv3/src/backbone/resnet_deeplab.py b/model_zoo/official/cv/deeplabv3/src/backbone/resnet_deeplab.py similarity index 100% rename from model_zoo/deeplabv3/src/backbone/resnet_deeplab.py rename to model_zoo/official/cv/deeplabv3/src/backbone/resnet_deeplab.py diff --git a/model_zoo/deeplabv3/src/config.py b/model_zoo/official/cv/deeplabv3/src/config.py similarity index 100% rename from model_zoo/deeplabv3/src/config.py rename to model_zoo/official/cv/deeplabv3/src/config.py diff --git a/model_zoo/official/cv/deeplabv3/src/deeplabv3.py b/model_zoo/official/cv/deeplabv3/src/deeplabv3.py new file mode 100644 index 0000000000..bbfc4dceb3 --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/deeplabv3.py @@ -0,0 +1,457 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# httpwww.apache.orglicensesLICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DeepLabv3.""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \ + DepthwiseConv2dNative, SpaceToBatch, BatchToSpace + + +class ASPPSampleBlock(nn.Cell): + """ASPP sample block.""" + def __init__(self, feature_shape, scale_size, output_stride): + super(ASPPSampleBlock, self).__init__() + sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1 + sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1 + self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) + + def construct(self, x): + return self.sample(x) + + +class ASPP(nn.Cell): + """ + ASPP model for DeepLabv3. + + Args: + channel (int): Input channel. + depth (int): Output channel. + feature_shape (list): The shape of feature,[h,w]. + scale_sizes (list): Input scales for multi-scale feature extraction. + atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. + output_stride (int): 'The ratio of input to output spatial resolution.' + fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' + + Returns: + Tensor, output tensor. + + Examples: + >>> ASPP(channel=2048,256,[14,14],[1],[6],16) + """ + def __init__(self, channel, depth, feature_shape, scale_sizes, + atrous_rates, output_stride, fine_tune_batch_norm=False): + super(ASPP, self).__init__() + self.aspp0 = _conv_bn_relu(channel, + depth, + ksize=1, + stride=1, + use_batch_statistics=fine_tune_batch_norm) + self.atrous_rates = [] + if atrous_rates is not None: + self.atrous_rates = atrous_rates + self.aspp_pointwise = _conv_bn_relu(channel, + depth, + ksize=1, + stride=1, + use_batch_statistics=fine_tune_batch_norm) + self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel, + channel_multiplier=1, + kernel_size=3, + stride=1, + dilation=1, + pad_mode="valid") + self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm) + self.aspp_depth_relu = nn.ReLU() + self.aspp_depths = [] + self.aspp_depth_spacetobatchs = [] + self.aspp_depth_batchtospaces = [] + + for scale_size in scale_sizes: + aspp_scale_depth_size = np.ceil((feature_shape[0]*scale_size)/16) + if atrous_rates is None: + break + for rate in atrous_rates: + padding = 0 + for j in range(100): + padded_size = rate * j + if padded_size >= aspp_scale_depth_size + 2 * rate: + padding = padded_size - aspp_scale_depth_size - 2 * rate + break + paddings = [[rate, rate + int(padding)], + [rate, rate + int(padding)]] + self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings) + self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch) + crops = [[0, int(padding)], [0, int(padding)]] + self.aspp_depth_batchtospace = BatchToSpace(rate, crops) + self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace) + self.aspp_depths = nn.CellList(self.aspp_depths) + self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs) + self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces) + + self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1]))) + self.global_poolings = [] + for scale_size in scale_sizes: + pooling_h = np.ceil((feature_shape[0]*scale_size)/output_stride) + pooling_w = np.ceil((feature_shape[0]*scale_size)/output_stride) + self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w)))) + self.global_poolings = nn.CellList(self.global_poolings) + self.conv_bn = _conv_bn_relu(channel, + depth, + ksize=1, + stride=1, + use_batch_statistics=fine_tune_batch_norm) + self.samples = [] + for scale_size in scale_sizes: + self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride)) + self.samples = nn.CellList(self.samples) + self.feature_shape = feature_shape + self.concat = P.Concat(axis=1) + + def construct(self, x, scale_index=0): + aspp0 = self.aspp0(x) + aspp1 = self.global_poolings[scale_index](x) + aspp1 = self.conv_bn(aspp1) + aspp1 = self.samples[scale_index](aspp1) + output = self.concat((aspp1, aspp0)) + + for i in range(len(self.atrous_rates)): + aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x) + aspp_i = self.aspp_depth_depthwiseconv(aspp_i) + aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i) + aspp_i = self.aspp_depth_bn(aspp_i) + aspp_i = self.aspp_depth_relu(aspp_i) + aspp_i = self.aspp_pointwise(aspp_i) + output = self.concat((output, aspp_i)) + return output + + +class DecoderSampleBlock(nn.Cell): + """Decoder sample block.""" + def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4): + super(DecoderSampleBlock, self).__init__() + sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1 + sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1 + self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) + + def construct(self, x): + return self.sample(x) + + +class Decoder(nn.Cell): + """ + Decode module for DeepLabv3. + Args: + low_level_channel (int): Low level input channel + channel (int): Input channel. + depth (int): Output channel. + feature_shape (list): 'Input image shape, [N,C,H,W].' + scale_sizes (list): 'Input scales for multi-scale feature extraction.' + decoder_output_stride (int): 'The ratio of input to output spatial resolution' + fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' + Returns: + Tensor, output tensor. + Examples: + >>> Decoder(256, 100, [56,56]) + """ + def __init__(self, + low_level_channel, + channel, + depth, + feature_shape, + scale_sizes, + decoder_output_stride, + fine_tune_batch_norm): + super(Decoder, self).__init__() + self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1, + pad_mode="same", use_batch_statistics=fine_tune_batch_norm) + self.decoder_depth0 = _deep_conv_bn_relu(channel + 48, + channel_multiplier=1, + ksize=3, + stride=1, + pad_mode="same", + dilation=1, + use_batch_statistics=fine_tune_batch_norm) + self.decoder_pointwise0 = _conv_bn_relu(channel + 48, + depth, + ksize=1, + stride=1, + use_batch_statistics=fine_tune_batch_norm) + self.decoder_depth1 = _deep_conv_bn_relu(depth, + channel_multiplier=1, + ksize=3, + stride=1, + pad_mode="same", + dilation=1, + use_batch_statistics=fine_tune_batch_norm) + self.decoder_pointwise1 = _conv_bn_relu(depth, + depth, + ksize=1, + stride=1, + use_batch_statistics=fine_tune_batch_norm) + self.depth = depth + self.concat = P.Concat(axis=1) + self.samples = [] + for scale_size in scale_sizes: + self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride)) + self.samples = nn.CellList(self.samples) + + def construct(self, x, low_level_feature, scale_index): + low_level_feature = self.feature_projection(low_level_feature) + low_level_feature = self.samples[scale_index](low_level_feature) + x = self.samples[scale_index](x) + output = self.concat((x, low_level_feature)) + output = self.decoder_depth0(output) + output = self.decoder_pointwise0(output) + output = self.decoder_depth1(output) + output = self.decoder_pointwise1(output) + return output + + +class SingleDeepLabV3(nn.Cell): + """ + DeepLabv3 Network. + Args: + num_classes (int): Class number. + feature_shape (list): Input image shape, [N,C,H,W]. + backbone (Cell): Backbone Network. + channel (int): Resnet output channel. + depth (int): ASPP block depth. + scale_sizes (list): Input scales for multi-scale feature extraction. + atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. + decoder_output_stride (int): 'The ratio of input to output spatial resolution' + output_stride (int): 'The ratio of input to output spatial resolution.' + fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' + Returns: + Tensor, output tensor. + Examples: + >>> SingleDeepLabV3(num_classes=10, + >>> feature_shape=[1,3,224,224], + >>> backbone=resnet50_dl(), + >>> channel=2048, + >>> depth=256) + >>> scale_sizes=[1.0]) + >>> atrous_rates=[6]) + >>> decoder_output_stride=4) + >>> output_stride=16) + """ + + def __init__(self, + num_classes, + feature_shape, + backbone, + channel, + depth, + scale_sizes, + atrous_rates, + decoder_output_stride, + output_stride, + fine_tune_batch_norm=False): + super(SingleDeepLabV3, self).__init__() + self.num_classes = num_classes + self.channel = channel + self.depth = depth + self.scale_sizes = [] + for scale_size in np.sort(scale_sizes): + self.scale_sizes.append(scale_size) + self.net = backbone + self.aspp = ASPP(channel=self.channel, + depth=self.depth, + feature_shape=[feature_shape[2], + feature_shape[3]], + scale_sizes=self.scale_sizes, + atrous_rates=atrous_rates, + output_stride=output_stride, + fine_tune_batch_norm=fine_tune_batch_norm) + + atrous_rates_len = 0 + if atrous_rates is not None: + atrous_rates_len = len(atrous_rates) + self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth, + ksize=1, + stride=1, + use_batch_statistics=fine_tune_batch_norm) + self.fc2 = nn.Conv2d(depth, + num_classes, + kernel_size=1, + stride=1, + has_bias=True) + self.upsample = P.ResizeBilinear((int(feature_shape[2]), + int(feature_shape[3])), + align_corners=True) + self.samples = [] + for scale_size in self.scale_sizes: + self.samples.append(SampleBlock(feature_shape, scale_size)) + self.samples = nn.CellList(self.samples) + self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]), + float(feature_shape[3])] + + self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1))) + self.dropout = nn.Dropout(keep_prob=0.9) + self.shape = P.Shape() + self.decoder_output_stride = decoder_output_stride + if decoder_output_stride is not None: + self.decoder = Decoder(low_level_channel=depth, + channel=depth, + depth=depth, + feature_shape=[feature_shape[2], + feature_shape[3]], + scale_sizes=self.scale_sizes, + decoder_output_stride=decoder_output_stride, + fine_tune_batch_norm=fine_tune_batch_norm) + + def construct(self, x, scale_index=0): + x = (2.0 / 255.0) * x - 1.0 + x = self.pad(x) + low_level_feature, feature_map = self.net(x) + for scale_size in self.scale_sizes: + if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2: + output = self.aspp(feature_map, scale_index) + output = self.fc1(output) + if self.decoder_output_stride is not None: + output = self.decoder(output, low_level_feature, scale_index) + output = self.fc2(output) + output = self.samples[scale_index](output) + return output + scale_index += 1 + return feature_map + + +class SampleBlock(nn.Cell): + """Sample block.""" + def __init__(self, + feature_shape, + scale_size=1.0): + super(SampleBlock, self).__init__() + sample_h = np.ceil(float(feature_shape[2]) * scale_size) + sample_w = np.ceil(float(feature_shape[3]) * scale_size) + self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True) + + def construct(self, x): + return self.sample(x) + + +class DeepLabV3(nn.Cell): + """DeepLabV3 model.""" + def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates, + decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid): + super(DeepLabV3, self).__init__() + self.infer_scale_sizes = [] + if infer_scale_sizes is not None: + self.infer_scale_sizes = infer_scale_sizes + + self.infer_scale_sizes = infer_scale_sizes + if image_pyramid is None: + image_pyramid = [1.0] + + self.image_pyramid = image_pyramid + scale_sizes = [] + for pyramid in image_pyramid: + scale_sizes.append(pyramid) + for scale in infer_scale_sizes: + scale_sizes.append(scale) + self.samples = [] + for scale_size in scale_sizes: + self.samples.append(SampleBlock(feature_shape, scale_size)) + self.samples = nn.CellList(self.samples) + self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes, + feature_shape=feature_shape, + backbone=resnet50_dl(fine_tune_batch_norm), + channel=channel, + depth=depth, + scale_sizes=scale_sizes, + atrous_rates=atrous_rates, + decoder_output_stride=decoder_output_stride, + output_stride=output_stride, + fine_tune_batch_norm=fine_tune_batch_norm) + self.softmax = P.Softmax(axis=1) + self.concat = P.Concat(axis=2) + self.expand_dims = P.ExpandDims() + self.reduce_mean = P.ReduceMean() + self.sample_common = P.ResizeBilinear((int(feature_shape[2]), + int(feature_shape[3])), + align_corners=True) + + def construct(self, x): + logits = () + if self.training: + if len(self.image_pyramid) >= 1: + if self.image_pyramid[0] == 1: + logits = self.deeplabv3(x) + else: + x1 = self.samples[0](x) + logits = self.deeplabv3(x1) + logits = self.sample_common(logits) + logits = self.expand_dims(logits, 2) + for i in range(len(self.image_pyramid) - 1): + x_i = self.samples[i + 1](x) + logits_i = self.deeplabv3(x_i) + logits_i = self.sample_common(logits_i) + logits_i = self.expand_dims(logits_i, 2) + logits = self.concat((logits, logits_i)) + logits = self.reduce_mean(logits, 2) + return logits + if len(self.infer_scale_sizes) >= 1: + infer_index = len(self.image_pyramid) + x1 = self.samples[infer_index](x) + logits = self.deeplabv3(x1) + logits = self.sample_common(logits) + logits = self.softmax(logits) + logits = self.expand_dims(logits, 2) + for i in range(len(self.infer_scale_sizes) - 1): + x_i = self.samples[i + 1 + infer_index](x) + logits_i = self.deeplabv3(x_i) + logits_i = self.sample_common(logits_i) + logits_i = self.softmax(logits_i) + logits_i = self.expand_dims(logits_i, 2) + logits = self.concat((logits, logits_i)) + logits = self.reduce_mean(logits, 2) + return logits + + +def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid, + infer_scale_sizes, atrous_rates=None, decoder_output_stride=None, + output_stride=16, fine_tune_batch_norm=False): + """ + ResNet50 based DeepLabv3 network. + + Args: + num_classes (int): Class number. + feature_shape (list): Input image shape, [N,C,H,W]. + image_pyramid (list): Input scales for multi-scale feature extraction. + atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. + infer_scale_sizes (list): 'The scales to resize images for inference. + decoder_output_stride (int): 'The ratio of input to output spatial resolution' + output_stride (int): 'The ratio of input to output spatial resolution.' + fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not' + + Returns: + Cell, cell instance of ResNet50 based DeepLabv3 neural network. + + Examples: + >>> deeplabv3_resnet50(100, [1,3,224,224],[1.0],[1.0]) + """ + return DeepLabV3(num_classes=num_classes, + feature_shape=feature_shape, + backbone=resnet50_dl(fine_tune_batch_norm), + channel=2048, + depth=256, + infer_scale_sizes=infer_scale_sizes, + atrous_rates=atrous_rates, + decoder_output_stride=decoder_output_stride, + output_stride=output_stride, + fine_tune_batch_norm=fine_tune_batch_norm, + image_pyramid=image_pyramid) diff --git a/model_zoo/deeplabv3/src/ei_dataset.py b/model_zoo/official/cv/deeplabv3/src/ei_dataset.py similarity index 100% rename from model_zoo/deeplabv3/src/ei_dataset.py rename to model_zoo/official/cv/deeplabv3/src/ei_dataset.py diff --git a/model_zoo/deeplabv3/src/losses.py b/model_zoo/official/cv/deeplabv3/src/losses.py similarity index 100% rename from model_zoo/deeplabv3/src/losses.py rename to model_zoo/official/cv/deeplabv3/src/losses.py diff --git a/model_zoo/deeplabv3/src/md_dataset.py b/model_zoo/official/cv/deeplabv3/src/md_dataset.py similarity index 100% rename from model_zoo/deeplabv3/src/md_dataset.py rename to model_zoo/official/cv/deeplabv3/src/md_dataset.py diff --git a/model_zoo/deeplabv3/src/miou_precision.py b/model_zoo/official/cv/deeplabv3/src/miou_precision.py similarity index 100% rename from model_zoo/deeplabv3/src/miou_precision.py rename to model_zoo/official/cv/deeplabv3/src/miou_precision.py diff --git a/model_zoo/official/cv/deeplabv3/src/utils/__init__.py b/model_zoo/official/cv/deeplabv3/src/utils/__init__.py new file mode 100644 index 0000000000..e30774307c --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/src/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/model_zoo/deeplabv3/src/utils/adapter.py b/model_zoo/official/cv/deeplabv3/src/utils/adapter.py similarity index 100% rename from model_zoo/deeplabv3/src/utils/adapter.py rename to model_zoo/official/cv/deeplabv3/src/utils/adapter.py diff --git a/model_zoo/deeplabv3/src/utils/custom_transforms.py b/model_zoo/official/cv/deeplabv3/src/utils/custom_transforms.py similarity index 100% rename from model_zoo/deeplabv3/src/utils/custom_transforms.py rename to model_zoo/official/cv/deeplabv3/src/utils/custom_transforms.py diff --git a/model_zoo/deeplabv3/src/utils/file_io.py b/model_zoo/official/cv/deeplabv3/src/utils/file_io.py similarity index 100% rename from model_zoo/deeplabv3/src/utils/file_io.py rename to model_zoo/official/cv/deeplabv3/src/utils/file_io.py diff --git a/model_zoo/official/cv/deeplabv3/train.py b/model_zoo/official/cv/deeplabv3/train.py new file mode 100644 index 0000000000..56ef5b02bb --- /dev/null +++ b/model_zoo/official/cv/deeplabv3/train.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train.""" +import argparse +from mindspore import context +from mindspore.communication.management import init +from mindspore.nn.optim.momentum import Momentum +from mindspore import Model, ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor +from src.md_dataset import create_dataset +from src.losses import OhemLoss +from src.deeplabv3 import deeplabv3_resnet50 +from src.config import config + +parser = argparse.ArgumentParser(description="Deeplabv3 training") +parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") +parser.add_argument('--data_url', required=True, default=None, help='Train data url') +parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") +parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') + +args_opt = parser.parse_args() +print(args_opt) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) +class LossCallBack(Callback): + """ + Monitor the loss in training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self._per_print_times = per_print_times + def step_end(self, run_context): + cb_params = run_context.original_args() + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs))) +def model_fine_tune(flags, train_net, fix_weight_layer): + checkpoint_path = flags.checkpoint_url + if checkpoint_path is None: + return + param_dict = load_checkpoint(checkpoint_path) + load_param_into_net(train_net, param_dict) + for para in train_net.trainable_params(): + if fix_weight_layer in para.name: + para.requires_grad = False +if __name__ == "__main__": + if args_opt.distribute == "true": + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) + init() + args_opt.base_size = config.crop_size + args_opt.crop_size = config.crop_size + train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train") + dataset_size = train_dataset.get_dataset_size() + time_cb = TimeMonitor(data_size=dataset_size) + callback = [time_cb, LossCallBack()] + if config.enable_save_ckpt: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, + keep_checkpoint_max=config.save_checkpoint_num) + ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) + callback.append(ckpoint_cb) + net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], + infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, + decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, + fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) + net.set_train() + model_fine_tune(args_opt, net, 'layer') + loss = OhemLoss(config.seg_num_classes, config.ignore_label) + opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) + model = Model(net, loss, opt) + model.train(config.epoch_size, train_dataset, callback) diff --git a/model_zoo/official/cv/faster_rcnn/README.md b/model_zoo/official/cv/faster_rcnn/README.md new file mode 100644 index 0000000000..56be9bfa79 --- /dev/null +++ b/model_zoo/official/cv/faster_rcnn/README.md @@ -0,0 +1,143 @@ +# FasterRcnn Example + +## Description + +FasterRcnn is a two-stage target detection network,This network uses a region proposal network (RPN), which can share the convolution features of the whole image with the detection network, so that the calculation of region proposal is almost cost free. The whole network further combines RPN and FastRcnn into a network by sharing the convolution features. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset COCO2017. + +- We use coco2017 as training dataset in this example by default, and you can also use your own datasets. + + 1. If coco dataset is used. **Select dataset to coco when run script.** + Install Cython and pycocotool, and you can also install mmcv to process data. + + ``` + pip install Cython + + pip install pycocotools + + pip install mmcv + ``` + And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: + + + ``` + . + └─cocodataset + ├─annotations + ├─instance_train2017.json + └─instance_val2017.json + ├─val2017 + └─train2017 + + ``` + + 2. If your own dataset is used. **Select dataset to other when run script.** + Organize the dataset infomation into a TXT file, each row in the file is as follows: + + ``` + train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 + ``` + + Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. + + +## Example structure + +```shell +. +└─FasterRcnn + ├─README.md + ├─scripts + ├─run_download_process_data.sh + ├─run_standalone_train.sh + ├─run_train.sh + └─run_eval.sh + ├─src + ├─FasterRcnn + ├─__init__.py + ├─anchor_generator.py + ├─bbox_assign_sample.py + ├─bbox_assign_sample_stage2.py + ├─faster_rcnn_r50.py + ├─fpn_neck.py + ├─proposal_generator.py + ├─rcnn.py + ├─resnet50.py + ├─roi_align.py + └─rpn.py + ├─config.py + ├─dataset.py + ├─lr_schedule.py + ├─network_define.py + └─util.py + ├─eval.py + └─train.py +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +sh run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL] + +# standalone training +sh run_standalone_train.sh [PRETRAINED_MODEL] +``` + +> Rank_table.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). +> As for PRETRAINED_MODEL,it should be a ResNet50 checkpoint that trained over ImageNet2012. Ready-made pretrained_models are not available now. Stay tuned. + +#### Result + +Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in loss.log. + + +``` +# distribute training result(8p) +epoch: 1 step: 7393, rpn_loss: 0.12054, rcnn_loss: 0.40601, rpn_cls_loss: 0.04025, rpn_reg_loss: 0.08032, rcnn_cls_loss: 0.25854, rcnn_reg_loss: 0.14746, total_loss: 0.52655 +epoch: 2 step: 7393, rpn_loss: 0.06561, rcnn_loss: 0.50293, rpn_cls_loss: 0.02587, rpn_reg_loss: 0.03967, rcnn_cls_loss: 0.35669, rcnn_reg_loss: 0.14624, total_loss: 0.56854 +epoch: 3 step: 7393, rpn_loss: 0.06940, rcnn_loss: 0.49658, rpn_cls_loss: 0.03769, rpn_reg_loss: 0.03165, rcnn_cls_loss: 0.36353, rcnn_reg_loss: 0.13318, total_loss: 0.56598 +... +epoch: 10 step: 7393, rpn_loss: 0.03555, rcnn_loss: 0.32666, rpn_cls_loss: 0.00697, rpn_reg_loss: 0.02859, rcnn_cls_loss: 0.16125, rcnn_reg_loss: 0.16541, total_loss: 0.36221 +epoch: 11 step: 7393, rpn_loss: 0.19849, rcnn_loss: 0.47827, rpn_cls_loss: 0.11639, rpn_reg_loss: 0.08209, rcnn_cls_loss: 0.29712, rcnn_reg_loss: 0.18115, total_loss: 0.67676 +epoch: 12 step: 7393, rpn_loss: 0.00691, rcnn_loss: 0.10168, rpn_cls_loss: 0.00529, rpn_reg_loss: 0.00162, rcnn_cls_loss: 0.05426, rcnn_reg_loss: 0.04745, total_loss: 0.10859 +``` + +### Infer + +#### Usage + +``` +# infer +sh run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH] +``` + +> checkpoint can be produced in training process. + +#### Result + +Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. + +``` + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.360 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.586 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.385 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.229 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.402 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.441 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.299 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.487 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.515 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.346 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.562 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631 +``` diff --git a/model_zoo/official/cv/faster_rcnn/eval.py b/model_zoo/official/cv/faster_rcnn/eval.py new file mode 100644 index 0000000000..2049735046 --- /dev/null +++ b/model_zoo/official/cv/faster_rcnn/eval.py @@ -0,0 +1,133 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Evaluation for FasterRcnn""" +import os +import argparse +import time +import random +import numpy as np +from pycocotools.coco import COCO +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset.engine as de + +from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 +from src.config import config +from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset +from src.util import coco_eval, bbox2result_1image, results2json + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="FasterRcnn evaluation") +parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") +parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.") +parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + +def FasterRcnn_eval(dataset_path, ckpt_path, ann_file): + """FasterRcnn evaluation.""" + ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size, + repeat_num=1, is_training=False) + net = Faster_Rcnn_Resnet50(config) + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + eval_iter = 0 + total = ds.get_dataset_size() + outputs = [] + dataset_coco = COCO(ann_file) + + print("\n========================================\n") + print("total images num: ", total) + print("Processing, please wait a moment.") + max_num = 128 + for data in ds.create_dict_iterator(): + eval_iter = eval_iter + 1 + + img_data = data['image'] + img_metas = data['image_shape'] + gt_bboxes = data['box'] + gt_labels = data['label'] + gt_num = data['valid_num'] + + start = time.time() + # run net + output = net(Tensor(img_data), Tensor(img_metas), Tensor(gt_bboxes), Tensor(gt_labels), Tensor(gt_num)) + end = time.time() + print("Iter {} cost time {}".format(eval_iter, end - start)) + + # output + all_bbox = output[0] + all_label = output[1] + all_mask = output[2] + + for j in range(config.test_batch_size): + all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :]) + all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :]) + all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :]) + + all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] + all_labels_tmp_mask = all_label_squee[all_mask_squee] + + if all_bboxes_tmp_mask.shape[0] > max_num: + inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) + inds = inds[:max_num] + all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] + all_labels_tmp_mask = all_labels_tmp_mask[inds] + + outputs_tmp = bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes) + + outputs.append(outputs_tmp) + + eval_types = ["bbox"] + result_files = results2json(dataset_coco, outputs, "./results.pkl") + + coco_eval(result_files, eval_types, dataset_coco, single_result=True) + + +if __name__ == '__main__': + prefix = "FasterRcnn_eval.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix) + print("CHECKING MINDRECORD FILES ...") + + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord. It may take some time.") + data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + else: + if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): + print("Create Mindrecord. It may take some time.") + data_to_mindrecord_byte_image("other", False, prefix, file_num=1) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("IMAGE_DIR or ANNO_PATH not exits.") + + print("CHECKING MINDRECORD FILES DONE!") + print("Start Eval!") + FasterRcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file) diff --git a/model_zoo/official/cv/faster_rcnn/scripts/run_distribute_train.sh b/model_zoo/official/cv/faster_rcnn/scripts/run_distribute_train.sh new file mode 100755 index 0000000000..6dbf66cbe9 --- /dev/null +++ b/model_zoo/official/cv/faster_rcnn/scripts/run_distribute_train.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 2 ] +then + echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [PRETRAINED_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 + +if [ ! -f $PATH1 ] +then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" +exit 1 +fi + +PATH2=$(get_real_path $2) +echo $PATH2 +if [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --pre_trained=$PATH2 &> log & + cd .. +done diff --git a/model_zoo/official/cv/faster_rcnn/scripts/run_eval.sh b/model_zoo/official/cv/faster_rcnn/scripts/run_eval.sh new file mode 100755 index 0000000000..95199a55bd --- /dev/null +++ b/model_zoo/official/cv/faster_rcnn/scripts/run_eval.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) +echo $PATH1 +echo $PATH2 + +if [ ! -f $PATH1 ] +then + echo "error: ANN_FILE=$PATH1 is not a file" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export RANK_SIZE=$DEVICE_NUM +export DEVICE_ID=0 +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start eval for device $DEVICE_ID" +python eval.py --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 &> log & +cd .. diff --git a/model_zoo/official/cv/faster_rcnn/scripts/run_standalone_train.sh b/model_zoo/official/cv/faster_rcnn/scripts/run_standalone_train.sh new file mode 100755 index 0000000000..92ea15c2cf --- /dev/null +++ b/model_zoo/official/cv/faster_rcnn/scripts/run_standalone_train.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 1 ] +then + echo "Usage: sh run_standalone_train.sh [PRETRAINED_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 + +if [ ! -f $PATH1 ] +then + echo "error: PRETRAINED_PATH=$PATH1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log +python train.py --device_id=$DEVICE_ID --pre_trained=$PATH1 &> log & +cd .. diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/__init__.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/__init__.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/__init__.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/__init__.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/anchor_generator.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/anchor_generator.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/anchor_generator.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/proposal_generator.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/rcnn.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/rcnn.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/resnet50.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/resnet50.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/roi_align.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/roi_align.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/roi_align.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/roi_align.py diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/rpn.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py similarity index 100% rename from model_zoo/faster_rcnn/src/FasterRcnn/rpn.py rename to model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py diff --git a/model_zoo/faster_rcnn/src/config.py b/model_zoo/official/cv/faster_rcnn/src/config.py similarity index 100% rename from model_zoo/faster_rcnn/src/config.py rename to model_zoo/official/cv/faster_rcnn/src/config.py diff --git a/model_zoo/faster_rcnn/src/dataset.py b/model_zoo/official/cv/faster_rcnn/src/dataset.py similarity index 100% rename from model_zoo/faster_rcnn/src/dataset.py rename to model_zoo/official/cv/faster_rcnn/src/dataset.py diff --git a/model_zoo/faster_rcnn/src/lr_schedule.py b/model_zoo/official/cv/faster_rcnn/src/lr_schedule.py similarity index 100% rename from model_zoo/faster_rcnn/src/lr_schedule.py rename to model_zoo/official/cv/faster_rcnn/src/lr_schedule.py diff --git a/model_zoo/faster_rcnn/src/network_define.py b/model_zoo/official/cv/faster_rcnn/src/network_define.py similarity index 100% rename from model_zoo/faster_rcnn/src/network_define.py rename to model_zoo/official/cv/faster_rcnn/src/network_define.py diff --git a/model_zoo/faster_rcnn/src/util.py b/model_zoo/official/cv/faster_rcnn/src/util.py similarity index 100% rename from model_zoo/faster_rcnn/src/util.py rename to model_zoo/official/cv/faster_rcnn/src/util.py diff --git a/model_zoo/official/cv/faster_rcnn/train.py b/model_zoo/official/cv/faster_rcnn/train.py new file mode 100644 index 0000000000..d48466f621 --- /dev/null +++ b/model_zoo/official/cv/faster_rcnn/train.py @@ -0,0 +1,139 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""train FasterRcnn and get checkpoint files.""" + +import os +import time +import argparse +import random +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore import context, Tensor +from mindspore.communication.management import init +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train import Model, ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn import SGD +import mindspore.dataset.engine as de + +from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 +from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet +from src.config import config +from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset +from src.lr_schedule import dynamic_lr + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="FasterRcnn training") +parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default: false.") +parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.") +parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") +parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") +parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.") +parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + +if __name__ == '__main__': + if args_opt.run_distribute: + rank = args_opt.rank_id + device_num = args_opt.device_num + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True, parameter_broadcast=True) + init() + else: + rank = 0 + device_num = 1 + + print("Start create dataset!") + + # It will generate mindrecord file in args_opt.mindrecord_dir, + # and the file name is FasterRcnn.mindrecord0, 1, ... file_num. + prefix = "FasterRcnn.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + print("CHECKING MINDRECORD FILES ...") + + if rank == 0 and not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord. It may take some time.") + data_to_mindrecord_byte_image("coco", True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + else: + if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): + print("Create Mindrecord. It may take some time.") + data_to_mindrecord_byte_image("other", True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("IMAGE_DIR or ANNO_PATH not exits.") + + while not os.path.exists(mindrecord_file + ".db"): + time.sleep(5) + + print("CHECKING MINDRECORD FILES DONE!") + + loss_scale = float(config.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0. + dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1, + batch_size=config.batch_size, device_num=device_num, rank_id=rank) + + dataset_size = dataset.get_dataset_size() + print("Create dataset done!") + + net = Faster_Rcnn_Resnet50(config=config) + net = net.set_train() + + load_path = args_opt.pre_trained + if load_path != "": + param_dict = load_checkpoint(load_path) + for item in list(param_dict.keys()): + if not item.startswith('backbone'): + param_dict.pop(item) + load_param_into_net(net, param_dict) + + loss = LossNet() + lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32) + + opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, + weight_decay=config.weight_decay, loss_scale=config.loss_scale) + net_with_loss = WithLossCell(net, loss) + if args_opt.run_distribute: + net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, + mean=True, degree=device_num) + else: + net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) + + time_cb = TimeMonitor(data_size=dataset_size) + loss_cb = LossCallBack() + cb = [time_cb, loss_cb] + if config.save_checkpoint: + ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn', directory=config.save_checkpoint_path, config=ckptconfig) + cb += [ckpoint_cb] + + model = Model(net) + model.train(config.epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/official/cv/googlenet/README.md b/model_zoo/official/cv/googlenet/README.md new file mode 100644 index 0000000000..67660b61d7 --- /dev/null +++ b/model_zoo/official/cv/googlenet/README.md @@ -0,0 +1,318 @@ +# Contents + +- [GoogleNet Description](#googlenet-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Features](#features) + - [Mixed Precision](#mixed-precision) +- [Environment Requirements](#environment-requirements) +- [Quick Start](#quick-start) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Script Parameters](#script-parameters) + - [Training Process](#training-process) + - [Training](#training) + - [Distributed Training](#distributed-training) + - [Evaluation Process](#evaluation-process) + - [Evaluation](#evaluation) +- [Model Description](#model-description) + - [Performance](#performance) + - [Evaluation Performance](#evaluation-performance) + - [Inference Performance](#evaluation-performance) + - [How to use](#how-to-use) + - [Inference](#inference) + - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model) + - [Transfer Learning](#transfer-learning) +- [Description of Random Situation](#description-of-random-situation) +- [ModelZoo Homepage](#modelzoo-homepage) + + +# [GoogleNet Description](#contents) + +GoogleNet, a 22 layers deep network, was proposed in 2014 and won the first place in the ImageNet Large-Scale Visual Recognition Challenge 2014 (ILSVRC14). GoogleNet, also called Inception v1, has significant improvement over ZFNet (The winner in 2013) and AlexNet (The winner in 2012), and has relatively lower error rate compared to VGGNet. Typically deeper deep learning network means larger number of parameters, which makes it more prone to overfitting. Furthermore, the increased network size leads to increased use of computational resources. To tackle these issues, GoogleNet adopts 1*1 convolution middle of the network to reduce dimension, and thus further reduce the computation. Global average pooling is used at the end of the network, instead of using fully connected layers. Another technique, called inception module, is to have different sizes of convolutions for the same input and stacking all the outputs. + +[Paper](https://arxiv.org/abs/1409.4842): Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. "Going deeper with convolutions." *Proceedings of the IEEE conference on computer vision and pattern recognition*. 2015. + + +# [Model Architecture](#contents) + +Specifically, the GoogleNet contains numerous inception modules, which are connected together to go deeper. In general, an inception module with dimensionality reduction consists of **1×1 conv**, **3×3 conv**, **5×5 conv**, and **3×3 max pooling**, which are done altogether for the previous input, and stack together again at output. + + + +# [Dataset](#contents) + +Dataset used: [CIFAR-10]() + +- Dataset size:175M,60,000 32*32 colorful images in 10 classes + - Train:146M,50,000 images + - Test:29.3M,10,000 images +- Data format:binary files + - Note:Data will be processed in dataset.py + + + +# [Features](#contents) + +## Mixed Precision + +The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. +For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. + + + +# [Environment Requirements](#contents) + +- Hardware(Ascend/GPU) + - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Framework + - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + + + +# [Quick Start](#contents) + +After installing MindSpore via the official website, you can start training and evaluation as follows: + +```python +# run training example +python train.py > train.log 2>&1 & + +# run distributed training example +sh scripts/run_train.sh rank_table.json + +# run evaluation example +python eval.py > eval.log 2>&1 & OR sh run_eval.sh +``` + + + +# [Script Description](#contents) + +## [Script and Sample Code](#contents) + +``` +├── model_zoo + ├── README.md // descriptions about all the models + ├── googlenet + ├── README.md // descriptions about googlenet + ├── scripts + │ ├──run_train.sh // shell script for distributed + │ ├──run_eval.sh // shell script for evaluation + ├── src + │ ├──dataset.py // creating dataset + │ ├──googlenet.py // googlenet architecture + │ ├──config.py // parameter configuration + ├── train.py // training script + ├── eval.py // evaluation script + ├── export.py // export checkpoint files into geir/onnx +``` + +## [Script Parameters](#contents) + +```python +Major parameters in train.py and config.py are: + +--data_path: The absolute full path to the train and evaluation datasets. +--epoch_size: Total training epochs. +--batch_size: Training batch size. +--lr_init: Initial learning rate. +--num_classes: The number of classes in the training set. +--weight_decay: Weight decay value. +--image_height: Image height used as input to the model. +--image_width: Image width used as input the model. +--pre_trained: Whether training from scratch or training based on the + pre-trained model.Optional values are True, False. +--device_target: Device where the code will be implemented. Optional values + are "Ascend", "GPU". +--device_id: Device ID used to train or evaluate the dataset. Ignore it + when you use run_train.sh for distributed training. +--checkpoint_path: The absolute full path to the checkpoint file saved + after training. +--onnx_filename: File name of the onnx model used in export.py. +--geir_filename: File name of the geir model used in export.py. +``` + + +## [Training Process](#contents) + +### Training + +``` +python train.py > train.log 2>&1 & +``` + +The python command above will run in the background, you can view the results through the file `train.log`. + +After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows: + +``` +# grep "loss is " train.log +epoch: 1 step: 390, loss is 1.4842823 +epcoh: 2 step: 390, loss is 1.0897788 +... +``` + +The model checkpoint will be saved in the current directory. + +### Distributed Training + +``` +sh scripts/run_train.sh rank_table.json +``` + +The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows: + +``` +# grep "result: " train_parallel*/log +train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931 +train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874 +... +train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025 +train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336 +... +... +``` + + +## [Evaluation Process](#contents) + +### Evaluation + +Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt". + +``` +python eval.py > eval.log 2>&1 & +OR +sh scripts/run_eval.sh +``` + +The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: + +``` +# grep "accuracy: " eval.log +accuracy: {'acc': 0.934} +``` + +Note that for evaluation after distributed training, please set the checkpoint_path to be the last saved checkpoint file such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy of the test dataset will be as follows: + +``` +# grep "accuracy: " dist.eval.log +accuracy: {'acc': 0.9217} +``` + + +# [Model Description](#contents) +## [Performance](#contents) + +### Evaluation Performance + +| Parameters | GoogleNet | +| -------------------------- | ----------------------------------------------------------- | +| Model Version | Inception V1 | +| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G | +| uploaded Date | 06/09/2020 (month/day/year) | +| MindSpore Version | 0.3.0-alpha | +| Dataset | CIFAR-10 | +| Training Parameters | epoch=125, steps=390, batch_size = 128, lr=0.1 | +| Optimizer | SGD | +| Loss Function | Softmax Cross Entropy | +| outputs | probability | +| Loss | 0.0016 | +| Speed | 1pc: 79 ms/step; 8pcs: 82 ms/step | +| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins | +| Parameters (M) | 13.0 | +| Checkpoint for Fine tuning | 43.07M (.ckpt file) | +| Model for inference | 21.50M (.onnx file), 21.60M(.geir file) | +| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet | + + +### Inference Performance + +| Parameters | GoogleNet | +| ------------------- | --------------------------- | +| Model Version | Inception V1 | +| Resource | Ascend 910 | +| Uploaded Date | 06/09/2020 (month/day/year) | +| MindSpore Version | 0.2.0-alpha | +| Dataset | CIFAR-10, 10,000 images | +| batch_size | 128 | +| outputs | probability | +| Accuracy | 1pc: 93.4%; 8pcs: 92.17% | +| Model for inference | 21.50M (.onnx file) | + +## [How to use](#contents) +### Inference + +If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example: + +``` +# Load unseen dataset for inference +dataset = dataset.create_dataset(cfg.data_path, 1, False) + +# Define model +net = GoogleNet(num_classes=cfg.num_classes) +opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, + cfg.momentum, weight_decay=cfg.weight_decay) +loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', + is_grad=False) +model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + +# Load pre-trained model +param_dict = load_checkpoint(cfg.checkpoint_path) +load_param_into_net(net, param_dict) +net.set_train(False) + +# Make predictions on the unseen dataset +acc = model.eval(dataset) +print("accuracy: ", acc) +``` + +### Continue Training on the Pretrained Model + +``` +# Load dataset +dataset = create_dataset(cfg.data_path, cfg.epoch_size) +batch_num = dataset.get_dataset_size() + +# Define model +net = GoogleNet(num_classes=cfg.num_classes) +# Continue training if set pre_trained to be True +if cfg.pre_trained: + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) +lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, + steps_per_epoch=batch_num) +opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), + Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) +loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) +model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) + +# Set callbacks +config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, + keep_checkpoint_max=cfg.keep_checkpoint_max) +time_cb = TimeMonitor(data_size=batch_num) +ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", + config=config_ck) +loss_cb = LossMonitor() + +# Start training +model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) +print("train success") +``` + +### Transfer Learning +To be added. + + +# [Description of Random Situation](#contents) + +In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. + + +# [ModelZoo Homepage](#contents) + Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/googlenet/eval.py b/model_zoo/official/cv/googlenet/eval.py new file mode 100644 index 0000000000..31646c9713 --- /dev/null +++ b/model_zoo/official/cv/googlenet/eval.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +##############test googlenet example on cifar10################# +python eval.py +""" +import argparse + +import mindspore.nn as nn +from mindspore import context +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.config import cifar_cfg as cfg +from src.dataset import create_dataset +from src.googlenet import GoogleNet + +parser = argparse.ArgumentParser(description='googlenet') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +args_opt = parser.parse_args() + +if __name__ == '__main__': + device_target = cfg.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) + if device_target == "Ascend": + context.set_context(device_id=cfg.device_id) + + net = GoogleNet(num_classes=cfg.num_classes) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, + weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + if device_target == "Ascend": + param_dict = load_checkpoint(cfg.checkpoint_path) + else: # GPU + param_dict = load_checkpoint(args_opt.checkpoint_path) + + load_param_into_net(net, param_dict) + net.set_train(False) + dataset = create_dataset(cfg.data_path, 1, False) + acc = model.eval(dataset) + print("accuracy: ", acc) diff --git a/model_zoo/googlenet/export.py b/model_zoo/official/cv/googlenet/export.py similarity index 100% rename from model_zoo/googlenet/export.py rename to model_zoo/official/cv/googlenet/export.py diff --git a/model_zoo/googlenet/scripts/run_eval.sh b/model_zoo/official/cv/googlenet/scripts/run_eval.sh similarity index 100% rename from model_zoo/googlenet/scripts/run_eval.sh rename to model_zoo/official/cv/googlenet/scripts/run_eval.sh diff --git a/model_zoo/official/cv/googlenet/scripts/run_eval_gpu.sh b/model_zoo/official/cv/googlenet/scripts/run_eval_gpu.sh new file mode 100644 index 0000000000..b2e2a38737 --- /dev/null +++ b/model_zoo/official/cv/googlenet/scripts/run_eval_gpu.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +ulimit -u unlimited + +if [ $# != 1 ] +then + echo "GPU: sh run_eval_gpu.sh [CHECKPOINT_PATH]" +exit 1 +fi + +# check checkpoint file +if [ ! -f $1 ] +then + echo "error: CHECKPOINT_PATH=$1 is not a file" +exit 1 +fi + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +export DEVICE_ID=0 + +if [ -d "../eval" ]; +then + rm -rf ../eval +fi +mkdir ../eval +cd ../eval || exit + +python3 ${BASEPATH}/../eval.py --checkpoint_path=$1 > ./eval.log 2>&1 & diff --git a/model_zoo/official/cv/googlenet/scripts/run_train.sh b/model_zoo/official/cv/googlenet/scripts/run_train.sh new file mode 100644 index 0000000000..ed8a0e5f2a --- /dev/null +++ b/model_zoo/official/cv/googlenet/scripts/run_train.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 1 ] +then + echo "Usage: sh run_train.sh [RANK_TABLE_FILE]" +exit 1 +fi + +if [ ! -f $1 ] +then + echo "error: RANK_TABLE_FILE=$1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +RANK_TABLE_FILE=$(realpath $1) +export RANK_TABLE_FILE +echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}" + +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$((rank_start + i)) + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp -r ./src ./train_parallel$i + cp ./train.py ./train_parallel$i + echo "start training for rank $RANK_ID, device $DEVICE_ID" + cd ./train_parallel$i ||exit + env > env.log + python train.py --device_id=$i > log 2>&1 & + cd .. +done diff --git a/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh b/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh new file mode 100644 index 0000000000..b30160238d --- /dev/null +++ b/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -lt 2 ] +then + echo "Usage:\n \ + sh run_train.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\n \ + " +exit 1 +fi + +if [ $1 -lt 1 ] && [ $1 -gt 8 ] +then + echo "error: DEVICE_NUM=$1 is not in (1-8)" +exit 1 +fi + +export DEVICE_NUM=$1 +export RANK_SIZE=$1 + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +if [ -d "../train" ]; +then + rm -rf ../train +fi +mkdir ../train +cd ../train || exit + +export CUDA_VISIBLE_DEVICES="$2" + +if [ $1 -gt 1 ] +then + mpirun -n $1 --allow-run-as-root \ + python3 ${BASEPATH}/../train.py > train.log 2>&1 & +else + python3 ${BASEPATH}/../train.py > train.log 2>&1 & +fi diff --git a/model_zoo/googlenet/src/config.py b/model_zoo/official/cv/googlenet/src/config.py similarity index 100% rename from model_zoo/googlenet/src/config.py rename to model_zoo/official/cv/googlenet/src/config.py diff --git a/model_zoo/official/cv/googlenet/src/dataset.py b/model_zoo/official/cv/googlenet/src/dataset.py new file mode 100644 index 0000000000..cc33d2e594 --- /dev/null +++ b/model_zoo/official/cv/googlenet/src/dataset.py @@ -0,0 +1,83 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import os + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as vision +from src.config import cifar_cfg as cfg + + +def create_dataset(data_home, repeat_num=1, training=True): + """Data operations.""" + ds.config.set_seed(1) + data_dir = os.path.join(data_home, "cifar-10-batches-bin") + if not training: + data_dir = os.path.join(data_home, "cifar-10-verify-bin") + + rank_size, rank_id = _get_rank_info() + if training: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True) + else: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) + + resize_height = cfg.image_height + resize_width = cfg.image_width + + # define map operations + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_horizontal_op = vision.RandomHorizontalFlip() + resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + rescale_op = vision.Rescale(1.0/255.0, 0.0) + normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + changeswap_op = vision.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + c_trans = [] + if training: + c_trans = [random_crop_op, random_horizontal_op] + c_trans += [resize_op, rescale_op, normalize_op, changeswap_op] + + # apply map operations on images + data_set = data_set.map(input_columns="label", operations=type_cast_op) + data_set = data_set.map(input_columns="image", operations=c_trans) + + # apply batch operations + data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) + + # apply repeat operations + data_set = data_set.repeat(repeat_num) + + return data_set + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + from mindspore.communication.management import get_rank, get_group_size + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = rank_id = None + + return rank_size, rank_id diff --git a/model_zoo/googlenet/src/googlenet.py b/model_zoo/official/cv/googlenet/src/googlenet.py similarity index 100% rename from model_zoo/googlenet/src/googlenet.py rename to model_zoo/official/cv/googlenet/src/googlenet.py diff --git a/model_zoo/official/cv/googlenet/train.py b/model_zoo/official/cv/googlenet/train.py new file mode 100644 index 0000000000..50e3ec7bc2 --- /dev/null +++ b/model_zoo/official/cv/googlenet/train.py @@ -0,0 +1,120 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################train googlent example on cifar10######################## +python train.py +""" +import argparse +import os +import random + +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.communication.management import init, get_rank +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.model import Model, ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.config import cifar_cfg as cfg +from src.dataset import create_dataset +from src.googlenet import GoogleNet + +random.seed(1) +np.random.seed(1) + +def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): + """Set learning rate.""" + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] + for i in range(total_steps): + if i < decay_epoch_index[0]: + lr_each_step.append(lr_max) + elif i < decay_epoch_index[1]: + lr_each_step.append(lr_max * 0.1) + elif i < decay_epoch_index[2]: + lr_each_step.append(lr_max * 0.01) + else: + lr_each_step.append(lr_max * 0.001) + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Cifar10 classification') + parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') + args_opt = parser.parse_args() + + device_target = cfg.device_target + + context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) + device_num = int(os.environ.get("DEVICE_NUM", 1)) + + if device_target == "Ascend": + if args_opt.device_id is not None: + context.set_context(device_id=args_opt.device_id) + else: + context.set_context(device_id=cfg.device_id) + + if device_num > 1: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + init() + elif device_target == "GPU": + init("nccl") + + if device_num > 1: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + else: + raise ValueError("Unsupport platform.") + + dataset = create_dataset(cfg.data_path, 1) + batch_num = dataset.get_dataset_size() + + net = GoogleNet(num_classes=cfg.num_classes) + # Continue training if set pre_trained to be True + if cfg.pre_trained: + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) + lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, + weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + + if device_target == "Ascend": + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) + ckpt_save_dir = "./" + else: # GPU + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None) + ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" + + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) + time_cb = TimeMonitor(data_size=batch_num) + ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck) + loss_cb = LossMonitor() + model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) + print("train success") diff --git a/model_zoo/official/cv/googlenet_quant/.gitkeep b/model_zoo/official/cv/googlenet_quant/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/inceptionv3/README.md b/model_zoo/official/cv/inceptionv3/README.md new file mode 100644 index 0000000000..0d84497ac5 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/README.md @@ -0,0 +1,115 @@ +# Inception-v3 Example + +## Description + +This is an example of training Inception-v3 in MindSpore. + +## Requirements + +- Install [Mindspore](http://www.mindspore.cn/install/en). +- Downlaod the dataset. + +## Structure + +```shell +. +└─Inception-v3 + ├─README.md + ├─scripts + ├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p) + ├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p) + └─run_eval_for_gpu.sh # launch evaluating with gpu platform + ├─src + ├─config.py # parameter configuration + ├─dataset.py # data preprocessing + ├─inception_v3.py # network definition + ├─loss.py # Customized CrossEntropy loss function + ├─lr_generator.py # learning rate generator + ├─eval.py # eval net + ├─export.py # convert checkpoint + └─train.py # train net + +``` + +## Parameter Configuration + +Parameters for both training and evaluating can be set in config.py + +``` +'random_seed': 1, # fix random seed +'rank': 0, # local rank of distributed +'group_size': 1, # world size of distributed +'work_nums': 8, # number of workers to read the data +'decay_method': 'cosine', # learning rate scheduler mode +"loss_scale": 1, # loss scale +'batch_size': 128, # input batchsize +'epoch_size': 250, # total epoch numbers +'num_classes': 1000, # dataset class numbers +'smooth_factor': 0.1, # label smoothing factor +'aux_factor': 0.2, # loss factor of aux logit +'lr_init': 0.00004, # initiate learning rate +'lr_max': 0.4, # max bound of learning rate +'lr_end': 0.000004, # min bound of learning rate +'warmup_epochs': 1, # warmup epoch numbers +'weight_decay': 0.00004, # weight decay +'momentum': 0.9, # momentum +'opt_eps': 1.0, # epsilon +'keep_checkpoint_max': 100, # max numbers to keep checkpoints +'ckpt_path': './checkpoint/', # save checkpoint path +'is_save_on_master': 1 # save checkpoint on rank0, distributed parameters +``` + + + +## Running the example + +### Train + +#### Usage + +``` +# distribute training example(8p) +sh run_distribute_train_for_gpu.sh DATA_DIR +# standalone training +sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR +``` + +#### Launch + +```bash +# distributed training example(8p) for GPU +sh scripts/run_distribute_train_for_gpu.sh /dataset/train +# standalone training example for GPU +sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train +``` + +#### Result + +You can find checkpoint file together with result in log. + +### Evaluation + +#### Usage + +``` +# Evaluation +sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT +``` + +#### Launch + +```bash +# Evaluation with checkpoint +sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/inceptionv3-rank3-247_1251.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. + +``` +acc=78.75%(TOP1) +acc=94.07%(TOP5) +``` \ No newline at end of file diff --git a/model_zoo/official/cv/inceptionv3/eval.py b/model_zoo/official/cv/inceptionv3/eval.py new file mode 100644 index 0000000000..a2f0ade1d3 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/eval.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""evaluate_imagenet""" +import argparse +import os + +import mindspore.nn as nn +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.config import config_gpu as cfg +from src.dataset import create_dataset +from src.inception_v3 import InceptionV3 +from src.loss import CrossEntropy_Val + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='image classification evaluation') + parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)') + parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') + parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') + args_opt = parser.parse_args() + + if args_opt.platform == 'Ascend': + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform) + net = InceptionV3(num_classes=cfg.num_classes, is_training=False) + ckpt = load_checkpoint(args_opt.checkpoint) + load_param_into_net(net, ckpt) + net.set_train(False) + dataset = create_dataset(args_opt.dataset_path, False, 0, 1) + loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes) + eval_metrics = {'Loss': nn.Loss(), + 'Top1-Acc': nn.Top1CategoricalAccuracy(), + 'Top5-Acc': nn.Top5CategoricalAccuracy()} + model = Model(net, loss, optimizer=None, metrics=eval_metrics) + metrics = model.eval(dataset) + print("metric: ", metrics) diff --git a/model_zoo/official/cv/inceptionv3/export.py b/model_zoo/official/cv/inceptionv3/export.py new file mode 100644 index 0000000000..302ff1302a --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/export.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +##############export checkpoint file into geir and onnx models################# +""" +import argparse +import numpy as np + +import mindspore as ms +from mindspore import Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export + +from src.config import config_gpu as cfg +from src.inception_v3 import InceptionV3 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='checkpoint export') + parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)') + args_opt = parser.parse_args() + + net = InceptionV3(num_classes=cfg.num_classes, is_training=False) + param_dict = load_checkpoint(args_opt.checkpoint) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32) + export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX") + export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR") diff --git a/model_zoo/official/cv/inceptionv3/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/inceptionv3/scripts/run_distribute_train_for_gpu.sh new file mode 100644 index 0000000000..305f1dcfff --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/scripts/run_distribute_train_for_gpu.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_DIR=$1 +mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & diff --git a/model_zoo/official/cv/inceptionv3/scripts/run_eval_for_gpu.sh b/model_zoo/official/cv/inceptionv3/scripts/run_eval_for_gpu.sh new file mode 100644 index 0000000000..0ecd63a434 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/scripts/run_eval_for_gpu.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT=$3 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 & diff --git a/model_zoo/official/cv/inceptionv3/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/inceptionv3/scripts/run_standalone_train_for_gpu.sh new file mode 100644 index 0000000000..7b856bbcf9 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/scripts/run_standalone_train_for_gpu.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DEVICE_ID=$1 +DATA_DIR=$2 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & + diff --git a/model_zoo/official/cv/inceptionv3/src/config.py b/model_zoo/official/cv/inceptionv3/src/config.py new file mode 100644 index 0000000000..b465a7543a --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/src/config.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in main.py +""" +from easydict import EasyDict as edict + + +config_gpu = edict({ + 'random_seed': 1, + 'rank': 0, + 'group_size': 1, + 'work_nums': 8, + 'decay_method': 'cosine', + "loss_scale": 1, + 'batch_size': 128, + 'epoch_size': 250, + 'num_classes': 1000, + 'smooth_factor': 0.1, + 'aux_factor': 0.2, + 'lr_init': 0.00004, + 'lr_max': 0.4, + 'lr_end': 0.000004, + 'warmup_epochs': 1, + 'weight_decay': 0.00004, + 'momentum': 0.9, + 'opt_eps': 1.0, + 'keep_checkpoint_max': 100, + 'ckpt_path': './checkpoint/', + 'is_save_on_master': 0 +}) diff --git a/model_zoo/official/cv/inceptionv3/src/dataset.py b/model_zoo/official/cv/inceptionv3/src/dataset.py new file mode 100644 index 0000000000..73c84bc959 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/src/dataset.py @@ -0,0 +1,69 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as C2 +import mindspore.dataset.transforms.vision.c_transforms as C +from src.config import config_gpu as cfg + + +def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + rank (int): The shard ID within num_shards (default=None). + group_size (int): Number of shards that the dataset should be divided into (default=None). + repeat_num(int): the repeat times of dataset. Default: 1. + + Returns: + dataset + """ + if group_size == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True, + num_shards=group_size, shard_id=rank) + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(299, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + ] + else: + trans = [ + C.Decode(), + C.Resize(299), + C.CenterCrop(299) + ] + trans += [ + C.Rescale(1.0 / 255.0, 0.0), + C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + C.HWC2CHW() + ] + type_cast_op = C2.TypeCast(mstype.int32) + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=cfg.work_nums) + ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=cfg.work_nums) + # apply batch operations + ds = ds.batch(cfg.batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + return ds diff --git a/model_zoo/official/cv/inceptionv3/src/inception_v3.py b/model_zoo/official/cv/inceptionv3/src/inception_v3.py new file mode 100644 index 0000000000..f1339b1c88 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/src/inception_v3.py @@ -0,0 +1,257 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Inception-v3 model definition""" +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.initializer import XavierUniform + + +class BasicConv2d(nn.Cell): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, pad_mode='same', padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, + pad_mode=pad_mode, padding=padding, weight_init=XavierUniform(), has_bias=True) + self.bn = nn.BatchNorm2d(out_channel, eps=0.001, momentum=0.9997) + self.relu = nn.ReLU() + + def construct(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Inception_A(nn.Cell): + def __init__(self, in_channels, pool_features): + super(Inception_A, self).__init__() + self.concat = P.Concat(axis=1) + self.branch0 = BasicConv2d(in_channels, 64, kernel_size=1) + self.branch1 = nn.SequentialCell([ + BasicConv2d(in_channels, 48, kernel_size=1), + BasicConv2d(48, 64, kernel_size=5) + ]) + self.branch2 = nn.SequentialCell([ + BasicConv2d(in_channels, 64, kernel_size=1), + BasicConv2d(64, 96, kernel_size=3), + BasicConv2d(96, 96, kernel_size=3) + + ]) + self.branch_pool = nn.SequentialCell([ + nn.AvgPool2d(kernel_size=3, pad_mode='same'), + BasicConv2d(in_channels, pool_features, kernel_size=1) + ]) + + def construct(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + branch_pool = self.branch_pool(x) + out = self.concat((x0, x1, x2, branch_pool)) + return out + + +class Inception_B(nn.Cell): + def __init__(self, in_channels): + super(Inception_B, self).__init__() + self.concat = P.Concat(axis=1) + self.branch0 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2, pad_mode='valid') + self.branch1 = nn.SequentialCell([ + BasicConv2d(in_channels, 64, kernel_size=1), + BasicConv2d(64, 96, kernel_size=3), + BasicConv2d(96, 96, kernel_size=3, stride=2, pad_mode='valid') + + ]) + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def construct(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + branch_pool = self.branch_pool(x) + out = self.concat((x0, x1, branch_pool)) + return out + + +class Inception_C(nn.Cell): + def __init__(self, in_channels, channels_7x7): + super(Inception_C, self).__init__() + self.concat = P.Concat(axis=1) + self.branch0 = BasicConv2d(in_channels, 192, kernel_size=1) + self.branch1 = nn.SequentialCell([ + BasicConv2d(in_channels, channels_7x7, kernel_size=1), + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)), + BasicConv2d(channels_7x7, 192, kernel_size=(7, 1)) + ]) + self.branch2 = nn.SequentialCell([ + BasicConv2d(in_channels, channels_7x7, kernel_size=1), + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)), + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)), + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)), + BasicConv2d(channels_7x7, 192, kernel_size=(1, 7)) + ]) + self.branch_pool = nn.SequentialCell([ + nn.AvgPool2d(kernel_size=3, pad_mode='same'), + BasicConv2d(in_channels, 192, kernel_size=1) + ]) + + def construct(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + branch_pool = self.branch_pool(x) + out = self.concat((x0, x1, x2, branch_pool)) + return out + + +class Inception_D(nn.Cell): + def __init__(self, in_channels): + super(Inception_D, self).__init__() + self.concat = P.Concat(axis=1) + self.branch0 = nn.SequentialCell([ + BasicConv2d(in_channels, 192, kernel_size=1), + BasicConv2d(192, 320, kernel_size=3, stride=2, pad_mode='valid') + ]) + self.branch1 = nn.SequentialCell([ + BasicConv2d(in_channels, 192, kernel_size=1), + BasicConv2d(192, 192, kernel_size=(1, 7)), # check + BasicConv2d(192, 192, kernel_size=(7, 1)), + BasicConv2d(192, 192, kernel_size=3, stride=2, pad_mode='valid') + ]) + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def construct(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + branch_pool = self.branch_pool(x) + out = self.concat((x0, x1, branch_pool)) + return out + + +class Inception_E(nn.Cell): + def __init__(self, in_channels): + super(Inception_E, self).__init__() + self.concat = P.Concat(axis=1) + self.branch0 = BasicConv2d(in_channels, 320, kernel_size=1) + self.branch1 = BasicConv2d(in_channels, 384, kernel_size=1) + self.branch1_a = BasicConv2d(384, 384, kernel_size=(1, 3)) + self.branch1_b = BasicConv2d(384, 384, kernel_size=(3, 1)) + self.branch2 = nn.SequentialCell([ + BasicConv2d(in_channels, 448, kernel_size=1), + BasicConv2d(448, 384, kernel_size=3) + ]) + self.branch2_a = BasicConv2d(384, 384, kernel_size=(1, 3)) + self.branch2_b = BasicConv2d(384, 384, kernel_size=(3, 1)) + self.branch_pool = nn.SequentialCell([ + nn.AvgPool2d(kernel_size=3, pad_mode='same'), + BasicConv2d(in_channels, 192, kernel_size=1) + ]) + + def construct(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x1 = self.concat((self.branch1_a(x1), self.branch1_b(x1))) + x2 = self.branch2(x) + x2 = self.concat((self.branch2_a(x2), self.branch2_b(x2))) + branch_pool = self.branch_pool(x) + out = self.concat((x0, x1, x2, branch_pool)) + return out + + +class Logits(nn.Cell): + def __init__(self, num_classes=10, dropout_keep_prob=0.8): + super(Logits, self).__init__() + self.avg_pool = nn.AvgPool2d(8, pad_mode='valid') + self.dropout = nn.Dropout(keep_prob=dropout_keep_prob) + self.flatten = P.Flatten() + self.fc = nn.Dense(2048, num_classes) + + def construct(self, x): + x = self.avg_pool(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class AuxLogits(nn.Cell): + def __init__(self, in_channels, num_classes=10): + super(AuxLogits, self).__init__() + self.avg_pool = nn.AvgPool2d(5, stride=3, pad_mode='valid') + self.conv2d_0 = nn.Conv2d(in_channels, 128, kernel_size=1) + self.conv2d_1 = nn.Conv2d(128, 768, kernel_size=5, pad_mode='valid') + self.flatten = P.Flatten() + self.fc = nn.Dense(in_channels, num_classes) + + def construct(self, x): + x = self.avg_pool(x) + x = self.conv2d_0(x) + x = self.conv2d_1(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class InceptionV3(nn.Cell): + def __init__(self, num_classes=10, is_training=True): + super(InceptionV3, self).__init__() + self.is_training = is_training + self.Conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2, pad_mode='valid') + self.Conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1, pad_mode='valid') + self.Conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1) + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a = BasicConv2d(80, 192, kernel_size=3, pad_mode='valid') + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = Inception_A(192, pool_features=32) + self.Mixed_5c = Inception_A(256, pool_features=64) + self.Mixed_5d = Inception_A(288, pool_features=64) + self.Mixed_6a = Inception_B(288) + self.Mixed_6b = Inception_C(768, channels_7x7=128) + self.Mixed_6c = Inception_C(768, channels_7x7=160) + self.Mixed_6d = Inception_C(768, channels_7x7=160) + self.Mixed_6e = Inception_C(768, channels_7x7=192) + self.Mixed_7a = Inception_D(768) + self.Mixed_7b = Inception_E(1280) + self.Mixed_7c = Inception_E(2048) + if is_training: + self.aux_logits = AuxLogits(768, num_classes) + self.logits = Logits(num_classes, dropout_keep_prob=0.5) + + def construct(self, x): + x = self.Conv2d_1a(x) + x = self.Conv2d_2a(x) + x = self.Conv2d_2b(x) + x = self.maxpool1(x) + x = self.Conv2d_3b(x) + x = self.Conv2d_4a(x) + x = self.maxpool2(x) + x = self.Mixed_5b(x) + x = self.Mixed_5c(x) + x = self.Mixed_5d(x) + x = self.Mixed_6a(x) + x = self.Mixed_6b(x) + x = self.Mixed_6c(x) + x = self.Mixed_6d(x) + x = self.Mixed_6e(x) + if self.is_training: + aux_logits = self.aux_logits(x) + else: + aux_logits = None + x = self.Mixed_7a(x) + x = self.Mixed_7b(x) + x = self.Mixed_7c(x) + logits = self.logits(x) + if self.is_training: + return logits, aux_logits + return logits diff --git a/model_zoo/official/cv/inceptionv3/src/loss.py b/model_zoo/official/cv/inceptionv3/src/loss.py new file mode 100644 index 0000000000..413e1f0f39 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/src/loss.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network.""" +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class CrossEntropy(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits""" + def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4): + super(CrossEntropy, self).__init__() + self.factor = factor + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logits, label): + logit, aux = logits + one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss_logit = self.ce(logit, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value) + loss_aux = self.ce(aux, one_hot_label_aux) + loss_aux = self.mean(loss_aux, 0) + return loss_logit + self.factor*loss_aux + + +class CrossEntropy_Val(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process""" + def __init__(self, smooth_factor=0, num_classes=1000): + super(CrossEntropy_Val, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logits, label): + one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) + loss_logit = self.ce(logits, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + return loss_logit diff --git a/model_zoo/official/cv/inceptionv3/src/lr_generator.py b/model_zoo/official/cv/inceptionv3/src/lr_generator.py new file mode 100644 index 0000000000..7a057f7251 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/src/lr_generator.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import math +import numpy as np + + +def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(int): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + if lr_decay_mode == 'steps': + decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + if i < decay_epoch_index[0]: + lr = lr_max + elif i < decay_epoch_index[1]: + lr = lr_max * 0.1 + elif i < decay_epoch_index[2]: + lr = lr_max * 0.01 + else: + lr = lr_max * 0.001 + lr_each_step.append(lr) + elif lr_decay_mode == 'steps_decay': + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr = float(lr_init) + inc_each_step * float(i) + else: + decay_nums = math.floor((float(i-warmup_steps)/steps_per_epoch) / 2) + decay_rate = pow(0.94, decay_nums) + lr = float(lr_max)*decay_rate + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + elif lr_decay_mode == 'cosine': + decay_steps = total_steps - warmup_steps + for i in range(total_steps): + if i < warmup_steps: + lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) + lr = float(lr_init) + lr_inc * (i + 1) + else: + cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps)) + lr = (lr_max-lr_end)*cosine_decay + lr_end + lr_each_step.append(lr) + else: + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) + lr_each_step.append(lr) + learning_rate = np.array(lr_each_step).astype(np.float32) + return learning_rate diff --git a/model_zoo/official/cv/inceptionv3/train.py b/model_zoo/official/cv/inceptionv3/train.py new file mode 100644 index 0000000000..bca02e6ee0 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/train.py @@ -0,0 +1,122 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_imagenet.""" +import argparse +import os +import random +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore import ParallelMode +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.nn.optim.rmsprop import RMSProp +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore import dataset as de + +from src.config import config_gpu as cfg +from src.dataset import create_dataset +from src.inception_v3 import InceptionV3 +from src.lr_generator import get_lr +from src.loss import CrossEntropy + +random.seed(cfg.random_seed) +np.random.seed(cfg.random_seed) +de.config.set_seed(cfg.random_seed) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='image classification training') + parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') + parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint') + parser.add_argument('--is_distributed', action='store_true', default=False, + help='distributed training') + parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + + # init distributed + if args_opt.is_distributed: + if args_opt.platform == "Ascend": + init() + else: + init("nccl") + cfg.rank = get_rank() + cfg.group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, + parameter_broadcast=True, mirror_mean=True) + else: + cfg.rank = 0 + cfg.group_size = 1 + + # dataloader + dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size) + batches_per_epoch = dataset.get_dataset_size() + + # network + net = InceptionV3(num_classes=cfg.num_classes) + + # loss + loss = CrossEntropy(smooth_factor=cfg.smooth_factor, num_classes=cfg.num_classes, factor=cfg.aux_factor) + + # learning rate schedule + lr = get_lr(lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, + total_epochs=cfg.epoch_size, steps_per_epoch=batches_per_epoch, lr_decay_mode=cfg.decay_method) + lr = Tensor(lr) + + # optimizer + decayed_params = [] + no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + + group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + optimizer = RMSProp(group_params, lr, decay=0.9, weight_decay=cfg.weight_decay, + momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale) + eval_metrics = {'Loss': nn.Loss(), + 'Top1-Acc': nn.Top1CategoricalAccuracy(), + 'Top5-Acc': nn.Top5CategoricalAccuracy()} + + if args_opt.resume: + ckpt = load_checkpoint(args_opt.resume) + load_param_into_net(net, ckpt) + model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'}) + + print("============== Starting Training ==============") + loss_cb = LossMonitor(per_print_times=batches_per_epoch) + time_cb = TimeMonitor(data_size=batches_per_epoch) + callbacks = [loss_cb, time_cb] + config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv3-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck) + if args_opt.is_distributed & cfg.is_save_on_master: + if cfg.rank == 0: + callbacks.append(ckpoint_cb) + model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True) + else: + callbacks.append(ckpoint_cb) + model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True) + print("train success") diff --git a/model_zoo/lenet/README.md b/model_zoo/official/cv/lenet/README.md similarity index 100% rename from model_zoo/lenet/README.md rename to model_zoo/official/cv/lenet/README.md diff --git a/model_zoo/lenet/eval.py b/model_zoo/official/cv/lenet/eval.py similarity index 100% rename from model_zoo/lenet/eval.py rename to model_zoo/official/cv/lenet/eval.py diff --git a/model_zoo/official/cv/lenet/src/__init__.py b/model_zoo/official/cv/lenet/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/lenet/src/config.py b/model_zoo/official/cv/lenet/src/config.py similarity index 100% rename from model_zoo/lenet/src/config.py rename to model_zoo/official/cv/lenet/src/config.py diff --git a/model_zoo/lenet/src/dataset.py b/model_zoo/official/cv/lenet/src/dataset.py similarity index 100% rename from model_zoo/lenet/src/dataset.py rename to model_zoo/official/cv/lenet/src/dataset.py diff --git a/model_zoo/lenet/src/lenet.py b/model_zoo/official/cv/lenet/src/lenet.py similarity index 100% rename from model_zoo/lenet/src/lenet.py rename to model_zoo/official/cv/lenet/src/lenet.py diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py new file mode 100644 index 0000000000..2282842188 --- /dev/null +++ b/model_zoo/official/cv/lenet/train.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +######################## train lenet example ######################## +train lenet and get network model files(.ckpt) : +python train.py --data_path /YourDataPath +""" + +import os +import argparse +from src.config import mnist_cfg as cfg +from src.dataset import create_dataset +from src.lenet import LeNet5 +import mindspore.nn as nn +from mindspore import context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='MindSpore Lenet Example') + parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--data_path', type=str, default="./Data", + help='path where the dataset is saved') + parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + + args = parser.parse_args() + + if args.device_target == "CPU": + args.dataset_sink_mode = False + + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + ds_train = create_dataset(os.path.join(args.data_path, "train"), + cfg.batch_size) + + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Training ==============") + model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], + dataset_sink_mode=args.dataset_sink_mode) diff --git a/model_zoo/lenet_quant/README.md b/model_zoo/official/cv/lenet_quant/Readme.md similarity index 100% rename from model_zoo/lenet_quant/README.md rename to model_zoo/official/cv/lenet_quant/Readme.md diff --git a/model_zoo/lenet_quant/eval.py b/model_zoo/official/cv/lenet_quant/eval.py similarity index 100% rename from model_zoo/lenet_quant/eval.py rename to model_zoo/official/cv/lenet_quant/eval.py diff --git a/model_zoo/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py similarity index 100% rename from model_zoo/lenet_quant/eval_quant.py rename to model_zoo/official/cv/lenet_quant/eval_quant.py diff --git a/model_zoo/lenet_quant/export.py b/model_zoo/official/cv/lenet_quant/export.py similarity index 100% rename from model_zoo/lenet_quant/export.py rename to model_zoo/official/cv/lenet_quant/export.py diff --git a/model_zoo/lenet_quant/src/config.py b/model_zoo/official/cv/lenet_quant/src/config.py similarity index 100% rename from model_zoo/lenet_quant/src/config.py rename to model_zoo/official/cv/lenet_quant/src/config.py diff --git a/model_zoo/lenet_quant/src/dataset.py b/model_zoo/official/cv/lenet_quant/src/dataset.py similarity index 100% rename from model_zoo/lenet_quant/src/dataset.py rename to model_zoo/official/cv/lenet_quant/src/dataset.py diff --git a/model_zoo/lenet_quant/src/lenet.py b/model_zoo/official/cv/lenet_quant/src/lenet.py similarity index 100% rename from model_zoo/lenet_quant/src/lenet.py rename to model_zoo/official/cv/lenet_quant/src/lenet.py diff --git a/model_zoo/lenet_quant/src/lenet_fusion.py b/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py similarity index 100% rename from model_zoo/lenet_quant/src/lenet_fusion.py rename to model_zoo/official/cv/lenet_quant/src/lenet_fusion.py diff --git a/model_zoo/official/cv/lenet_quant/src/loss_monitor.py b/model_zoo/official/cv/lenet_quant/src/loss_monitor.py new file mode 100644 index 0000000000..4bb8400ae2 --- /dev/null +++ b/model_zoo/official/cv/lenet_quant/src/loss_monitor.py @@ -0,0 +1,104 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""LossMonitor Callback class.""" + +import time +import numpy as np +from mindspore.common.tensor import Tensor +from mindspore.train.callback import Callback + + +class LossMonitor(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF, it will terminate training. + + Note: + If per_print_times is 0 do not print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. + lr_init (numpy array): train learning rate. Default: None. + + Raises: + ValueError: If print_step is not int or less than zero. + + Examples: + >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, per_print_times=1, lr_init=None): + super(LossMonitor, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self.lr_init = lr_init + + def epoch_begin(self, run_context): + """ + epoch begin + """ + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + """ + epoch end + """ + cb_params = run_context.original_args() + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("Epoch time: {:5.3f}, per step time: {:5.3f}, " + "avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + print("*" * 60) + + def step_begin(self, run_context): + """ + step begin + """ + self.step_time = time.time() + + def step_end(self, run_context): + """ + step end + """ + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 + + if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): + raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " + "Invalid loss, terminating training.".format( + cb_params.cur_epoch_num - 1, cb_params.epoch_num, + cur_step_in_epoch, cb_params.batch_num)) + + if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: + print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " + "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}ms]".format( + cb_params.cur_epoch_num, cb_params.epoch_num, + cur_step_in_epoch, int(cb_params.batch_num), + step_loss, np.mean(self.losses), + step_mseconds), flush=True) diff --git a/model_zoo/official/cv/lenet_quant/train.py b/model_zoo/official/cv/lenet_quant/train.py new file mode 100644 index 0000000000..66546b15c0 --- /dev/null +++ b/model_zoo/official/cv/lenet_quant/train.py @@ -0,0 +1,68 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +######################## train lenet example ######################## +train lenet and get network model files(.ckpt) : +python train.py --data_path /YourDataPath +""" + +import os +import argparse +import mindspore.nn as nn +from mindspore import context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy +from src.dataset import create_dataset +from src.config import mnist_cfg as cfg +from src.lenet_fusion import LeNet5 as LeNet5Fusion +from src.loss_monitor import LossMonitor + +parser = argparse.ArgumentParser(description='MindSpore MNIST Example') +parser.add_argument('--device_target', type=str, default="Ascend", + choices=['Ascend', 'GPU'], + help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--data_path', type=str, default="./MNIST_Data", + help='path where the dataset is saved') +parser.add_argument('--ckpt_path', type=str, default="", + help='if mode is test, must provide path where the trained ckpt file') +parser.add_argument('--dataset_sink_mode', type=bool, default=True, + help='dataset_sink_mode is False or True') +args = parser.parse_args() + +if __name__ == "__main__": + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1) + step_size = ds_train.get_dataset_size() + + # define fusion network + network = LeNet5Fusion(cfg.num_classes) + # define network loss + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + # define network optimization + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + + # call back and monitor + config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) + + # define model + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Training ==============") + model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], + dataset_sink_mode=args.dataset_sink_mode) + print("============== End Training ==============") diff --git a/model_zoo/official/cv/lenet_quant/train_quant.py b/model_zoo/official/cv/lenet_quant/train_quant.py new file mode 100644 index 0000000000..33c322f4b5 --- /dev/null +++ b/model_zoo/official/cv/lenet_quant/train_quant.py @@ -0,0 +1,78 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +######################## train lenet example ######################## +train lenet and get network model files(.ckpt) : +python train.py --data_path /YourDataPath +""" + +import os +import argparse +import mindspore.nn as nn +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy +from mindspore.train.quant import quant +from src.dataset import create_dataset +from src.config import mnist_cfg as cfg +from src.lenet_fusion import LeNet5 as LeNet5Fusion +from src.loss_monitor import LossMonitor + +parser = argparse.ArgumentParser(description='MindSpore MNIST Example') +parser.add_argument('--device_target', type=str, default="Ascend", + choices=['Ascend', 'GPU'], + help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--data_path', type=str, default="./MNIST_Data", + help='path where the dataset is saved') +parser.add_argument('--ckpt_path', type=str, default="", + help='if mode is test, must provide path where the trained ckpt file') +parser.add_argument('--dataset_sink_mode', type=bool, default=True, + help='dataset_sink_mode is False or True') +args = parser.parse_args() + +if __name__ == "__main__": + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1) + step_size = ds_train.get_dataset_size() + + # define fusion network + network = LeNet5Fusion(cfg.num_classes) + + # load quantization aware network checkpoint + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(network, param_dict) + + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + + # define network loss + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + # define network optimization + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + + # call back and monitor + config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) + + # define model + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Training ==============") + model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], + dataset_sink_mode=args.dataset_sink_mode) + print("============== End Training ==============") diff --git a/model_zoo/official/cv/maskrcnn/README.md b/model_zoo/official/cv/maskrcnn/README.md new file mode 100644 index 0000000000..fadce5ad32 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/README.md @@ -0,0 +1,161 @@ +# MaskRcnn Example + +## Description + +MaskRcnn is a two-stage target detection network,This network uses a region proposal network (RPN), which can share the convolution features of the whole image with the detection network, so that the calculation of region proposal is almost cost free. The whole network further combines RPN and MaskRcnn into a network by sharing the convolution features. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset COCO2017. + +- We use coco2017 as training dataset in this example by default, and you can also use your own datasets. + + 1. If coco dataset is used. **Select dataset to coco when run script.** + Install Cython and pycocotool, and you can also install mmcv to process data. + + ``` + pip install Cython + + pip install pycocotools + + pip install mmcv + ``` + And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: + + + ``` + . + └─cocodataset + ├─annotations + ├─instance_train2017.json + └─instance_val2017.json + ├─val2017 + └─train2017 + + ``` + Notice that the coco2017 dataset will be converted to MindRecord which is a data format in MindSpore. The dataset conversion may take about 4 hours. + 2. If your own dataset is used. **Select dataset to other when run script.** + Organize the dataset infomation into a TXT file, each row in the file is as follows: + + ``` + train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 + ``` + + Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. + + +## Example structure + +```shell +. +└─MaskRcnn + ├─README.md + ├─scripts + ├─run_download_process_data.sh + ├─run_standalone_train.sh + ├─run_train.sh + └─run_eval.sh + ├─src + ├─MaskRcnn + ├─__init__.py + ├─anchor_generator.py + ├─bbox_assign_sample.py + ├─bbox_assign_sample_stage2.py + ├─mask_rcnn_r50.py + ├─fpn_neck.py + ├─proposal_generator.py + ├─rcnn_cls.py + ├─rcnn_mask.py + ├─resnet50.py + ├─roi_align.py + └─rpn.py + ├─config.py + ├─dataset.py + ├─lr_schedule.py + ├─network_define.py + └─util.py + ├─eval.py + └─train.py +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +sh run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL] + +# standalone training +sh run_standalone_train.sh [PRETRAINED_MODEL] +``` + +> hccl.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). +> As for PRETRAINED_MODEL,if not set, the model will be trained from the very beginning.Ready-made pretrained_models are not available now. Stay tuned. + +#### Result + +Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in loss.log. + + +``` +# distribute training result(8p) +epoch: 1 step: 7393 ,rpn_loss: 0.10626, rcnn_loss: 0.81592, rpn_cls_loss: 0.05862, rpn_reg_loss: 0.04761, rcnn_cls_loss: 0.32642, rcnn_reg_loss: 0.15503, rcnn_mask_loss: 0.33447, total_loss: 0.92218 +epoch: 2 step: 7393 ,rpn_loss: 0.00911, rcnn_loss: 0.34082, rpn_cls_loss: 0.00341, rpn_reg_loss: 0.00571, rcnn_cls_loss: 0.07440, rcnn_reg_loss: 0.05872, rcnn_mask_loss: 0.20764, total_loss: 0.34993 +epoch: 3 step: 7393 ,rpn_loss: 0.02087, rcnn_loss: 0.98633, rpn_cls_loss: 0.00665, rpn_reg_loss: 0.01422, rcnn_cls_loss: 0.35913, rcnn_reg_loss: 0.21375, rcnn_mask_loss: 0.41382, total_loss: 1.00720 +... +epoch: 10 step: 7393 ,rpn_loss: 0.02122, rcnn_loss: 0.55176, rpn_cls_loss: 0.00620, rpn_reg_loss: 0.01503, rcnn_cls_loss: 0.12708, rcnn_reg_loss: 0.10254, rcnn_mask_loss: 0.32227, total_loss: 0.57298 +epoch: 11 step: 7393 ,rpn_loss: 0.03772, rcnn_loss: 0.60791, rpn_cls_loss: 0.03058, rpn_reg_loss: 0.00713, rcnn_cls_loss: 0.23987, rcnn_reg_loss: 0.11743, rcnn_mask_loss: 0.25049, total_loss: 0.64563 +epoch: 12 step: 7393 ,rpn_loss: 0.06482, rcnn_loss: 0.47681, rpn_cls_loss: 0.04770, rpn_reg_loss: 0.01709, rcnn_cls_loss: 0.16492, rcnn_reg_loss: 0.04990, rcnn_mask_loss: 0.26196, total_loss: 0.54163 +``` + +### Evaluation + +#### Usage + +``` +# infer +sh run_eval.sh [VALIDATION_ANN_FILE_JSON] [CHECKPOINT_PATH] +``` +> As for the COCO2017 dataset, VALIDATION_ANN_FILE_JSON is refer to the annotations/instances_val2017.json in the dataset directory. +> checkpoint can be produced and saved in training process, whose folder name begins with "train/checkpoint" or "train_parallel*/checkpoint". + +#### Result + +Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. + +``` +Evaluate annotation type *bbox* +Accumulating evaluation results... + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.366 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.591 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.393 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.241 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.405 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.454 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.304 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.492 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.521 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.372 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.560 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637 + +Evaluate annotation type *segm* +Accumulating evaluation results... + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.318 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.546 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.332 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.165 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.348 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.449 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.272 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.421 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.440 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.292 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.479 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558 +``` \ No newline at end of file diff --git a/model_zoo/official/cv/maskrcnn/eval.py b/model_zoo/official/cv/maskrcnn/eval.py new file mode 100644 index 0000000000..36289e6480 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/eval.py @@ -0,0 +1,135 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Evaluation for MaskRcnn""" +import os +import argparse +import time +import random +import numpy as np +from pycocotools.coco import COCO +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset.engine as de + +from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 +from src.config import config +from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset +from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="MaskRcnn evaluation") +parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") +parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.") +parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id) + +def MaskRcnn_eval(dataset_path, ckpt_path, ann_file): + """MaskRcnn evaluation.""" + ds = create_maskrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False) + + net = Mask_Rcnn_Resnet50(config) + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + eval_iter = 0 + total = ds.get_dataset_size() + outputs = [] + dataset_coco = COCO(ann_file) + + print("\n========================================\n") + print("total images num: ", total) + print("Processing, please wait a moment.") + max_num = 128 + for data in ds.create_dict_iterator(): + eval_iter = eval_iter + 1 + + img_data = data['image'] + img_metas = data['image_shape'] + gt_bboxes = data['box'] + gt_labels = data['label'] + gt_num = data['valid_num'] + gt_mask = data["mask"] + + start = time.time() + # run net + output = net(Tensor(img_data), Tensor(img_metas), Tensor(gt_bboxes), Tensor(gt_labels), Tensor(gt_num), + Tensor(gt_mask)) + end = time.time() + print("Iter {} cost time {}".format(eval_iter, end - start)) + + # output + all_bbox = output[0] + all_label = output[1] + all_mask = output[2] + all_mask_fb = output[3] + + for j in range(config.test_batch_size): + all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :]) + all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :]) + all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :]) + all_mask_fb_squee = np.squeeze(all_mask_fb.asnumpy()[j, :, :, :]) + + all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] + all_labels_tmp_mask = all_label_squee[all_mask_squee] + all_mask_fb_tmp_mask = all_mask_fb_squee[all_mask_squee, :, :] + + if all_bboxes_tmp_mask.shape[0] > max_num: + inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) + inds = inds[:max_num] + all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] + all_labels_tmp_mask = all_labels_tmp_mask[inds] + all_mask_fb_tmp_mask = all_mask_fb_tmp_mask[inds] + + bbox_results = bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes) + segm_results = get_seg_masks(all_mask_fb_tmp_mask, all_bboxes_tmp_mask, all_labels_tmp_mask, img_metas[j], + True, config.num_classes) + outputs.append((bbox_results, segm_results)) + + eval_types = ["bbox", "segm"] + result_files = results2json(dataset_coco, outputs, "./results.pkl") + coco_eval(result_files, eval_types, dataset_coco, single_result=False) + +if __name__ == '__main__': + prefix = "MaskRcnn_eval.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix) + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + else: + if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("other", False, prefix, file_num=1) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("IMAGE_DIR or ANNO_PATH not exits.") + + print("Start Eval!") + MaskRcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file) diff --git a/model_zoo/official/cv/maskrcnn/scripts/run_distribute_train.sh b/model_zoo/official/cv/maskrcnn/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..ab4a172f6e --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/scripts/run_distribute_train.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [PRETRAINED_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +echo $PATH1 +echo $PATH2 + +if [ ! -f $PATH1 ] +then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +echo 3 > /proc/sys/vm/drop_caches + +cpus=`cat /proc/cpuinfo| grep "processor"| wc -l` +avg=`expr $cpus \/ $RANK_SIZE` +gap=`expr $avg \- 1` + +for((i=0; i<${DEVICE_NUM}; i++)) +do + start=`expr $i \* $avg` + end=`expr $start \+ $gap` + cmdopt=$start"-"$end + + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + taskset -c $cmdopt python train.py --do_train=True --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM \ + --pre_trained=$PATH2 &> log & + cd .. +done diff --git a/model_zoo/faster_rcnn/scripts/run_eval.sh b/model_zoo/official/cv/maskrcnn/scripts/run_eval.sh similarity index 100% rename from model_zoo/faster_rcnn/scripts/run_eval.sh rename to model_zoo/official/cv/maskrcnn/scripts/run_eval.sh diff --git a/model_zoo/faster_rcnn/scripts/run_standalone_train.sh b/model_zoo/official/cv/maskrcnn/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/faster_rcnn/scripts/run_standalone_train.sh rename to model_zoo/official/cv/maskrcnn/scripts/run_standalone_train.sh diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/__init__.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/__init__.py new file mode 100644 index 0000000000..619d2d8999 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn Init.""" + +from .resnet50 import ResNetFea, ResidualBlockUsing +from .bbox_assign_sample import BboxAssignSample +from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn +from .fpn_neck import FeatPyramidNeck +from .proposal_generator import Proposal +from .rcnn_cls import RcnnCls +from .rcnn_mask import RcnnMask +from .rpn import RPN +from .roi_align import SingleRoIExtractor +from .anchor_generator import AnchorGenerator + +__all__ = [ + "ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn", + "FeatPyramidNeck", "Proposal", "RcnnCls", "RcnnMask", + "RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing" +] diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/anchor_generator.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/anchor_generator.py new file mode 100644 index 0000000000..0430b6192c --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/anchor_generator.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn anchor generator.""" + +import numpy as np + +class AnchorGenerator(): + """Anchor generator for MasKRcnn.""" + def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): + """Anchor generator init method.""" + self.base_size = base_size + self.scales = np.array(scales) + self.ratios = np.array(ratios) + self.scale_major = scale_major + self.ctr = ctr + self.base_anchors = self.gen_base_anchors() + + def gen_base_anchors(self): + """Generate a single anchor.""" + w = self.base_size + h = self.base_size + if self.ctr is None: + x_ctr = 0.5 * (w - 1) + y_ctr = 0.5 * (h - 1) + else: + x_ctr, y_ctr = self.ctr + + h_ratios = np.sqrt(self.ratios) + w_ratios = 1 / h_ratios + if self.scale_major: + ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1) + hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1) + else: + ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1) + hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1) + + base_anchors = np.stack( + [ + x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) + ], + axis=-1).round() + + return base_anchors + + def _meshgrid(self, x, y, row_major=True): + """Generate grid.""" + xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1) + yy = np.repeat(y, len(x)) + if row_major: + return xx, yy + + return yy, xx + + def grid_anchors(self, featmap_size, stride=16): + """Generate anchor list.""" + base_anchors = self.base_anchors + + feat_h, feat_w = featmap_size + shift_x = np.arange(0, feat_w) * stride + shift_y = np.arange(0, feat_h) * stride + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) + shifts = shifts.astype(base_anchors.dtype) + # first feat_w elements correspond to the first row of shifts + # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get + # shifted anchors (K, A, 4), reshape to (K*A, 4) + + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + all_anchors = all_anchors.reshape(-1, 4) + + return all_anchors diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/bbox_assign_sample.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/bbox_assign_sample.py new file mode 100644 index 0000000000..a282b732bb --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/bbox_assign_sample.py @@ -0,0 +1,164 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn positive and negative sample screening for RPN.""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype + + +class BboxAssignSample(nn.Cell): + """ + Bbox assigner and sampler defination. + + Args: + config (dict): Config. + batch_size (int): Batchsize. + num_bboxes (int): The anchor nums. + add_gt_as_proposals (bool): add gt bboxes as proposals flag. + + Returns: + Tensor, output tensor. + bbox_targets: bbox location, (batch_size, num_bboxes, 4) + bbox_weights: bbox weights, (batch_size, num_bboxes, 1) + labels: label for every bboxes, (batch_size, num_bboxes, 1) + label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1) + + Examples: + BboxAssignSample(config, 2, 1024, True) + """ + + def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): + super(BboxAssignSample, self).__init__() + cfg = config + self.batch_size = batch_size + + self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16) + self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16) + self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16) + self.zero_thr = Tensor(0.0, mstype.float16) + + self.num_bboxes = num_bboxes + self.num_gts = cfg.num_gts + self.num_expected_pos = cfg.num_expected_pos + self.num_expected_neg = cfg.num_expected_neg + self.add_gt_as_proposals = add_gt_as_proposals + + if self.add_gt_as_proposals: + self.label_inds = Tensor(np.arange(1, self.num_gts + 1)) + + self.concat = P.Concat(axis=0) + self.max_gt = P.ArgMaxWithValue(axis=0) + self.max_anchor = P.ArgMaxWithValue(axis=1) + self.sum_inds = P.ReduceSum() + self.iou = P.IOU() + self.greaterequal = P.GreaterEqual() + self.greater = P.Greater() + self.select = P.Select() + self.gatherND = P.GatherNd() + self.squeeze = P.Squeeze() + self.cast = P.Cast() + self.logicaland = P.LogicalAnd() + self.less = P.Less() + self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos) + self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) + self.reshape = P.Reshape() + self.equal = P.Equal() + self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)) + self.scatterNdUpdate = P.ScatterNdUpdate() + self.scatterNd = P.ScatterNd() + self.logicalnot = P.LogicalNot() + self.tile = P.Tile() + self.zeros_like = P.ZerosLike() + + self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) + self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32)) + self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32)) + self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) + self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) + + self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) + self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) + + + def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): + gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ + (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one) + bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \ + (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two) + + overlaps = self.iou(bboxes, gt_bboxes_i) + + max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps) + _, max_overlaps_w_ac = self.max_anchor(overlaps) + + neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \ + self.less(max_overlaps_w_gt, self.neg_iou_thr)) + assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds) + + pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr) + assigned_gt_inds3 = self.select(pos_sample_iou_mask, \ + max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2) + assigned_gt_inds4 = assigned_gt_inds3 + for j in range(self.num_gts): + max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1] + overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::]) + + pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \ + self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j)) + + assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4) + + assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores) + + pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) + + pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) + pos_check_valid = self.sum_inds(pos_check_valid, -1) + valid_pos_index = self.less(self.range_pos_size, pos_check_valid) + pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) + + pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones + pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32) + pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1)) + + neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) + + num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16) + num_pos = self.sum_inds(num_pos, -1) + unvalid_pos_index = self.less(self.range_pos_size, num_pos) + valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) + + pos_bboxes_ = self.gatherND(bboxes, pos_index) + pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index) + pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index) + + pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_) + + valid_pos_index = self.cast(valid_pos_index, mstype.int32) + valid_neg_index = self.cast(valid_neg_index, mstype.int32) + bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4)) + bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,)) + labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,)) + total_index = self.concat((pos_index, neg_index)) + total_valid_index = self.concat((valid_pos_index, valid_neg_index)) + label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,)) + + return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \ + labels_total, self.cast(label_weights_total, mstype.bool_) diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/bbox_assign_sample_stage2.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/bbox_assign_sample_stage2.py new file mode 100644 index 0000000000..0e797577e5 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/bbox_assign_sample_stage2.py @@ -0,0 +1,221 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn tpositive and negative sample screening for Rcnn.""" + +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor + +class BboxAssignSampleForRcnn(nn.Cell): + """ + Bbox assigner and sampler defination. + + Args: + config (dict): Config. + batch_size (int): Batchsize. + num_bboxes (int): The anchor nums. + add_gt_as_proposals (bool): add gt bboxes as proposals flag. + + Returns: + Tensor, multiple output tensors. + + Examples: + BboxAssignSampleForRcnn(config, 2, 1024, True) + """ + + def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): + super(BboxAssignSampleForRcnn, self).__init__() + cfg = config + self.batch_size = batch_size + self.neg_iou_thr = cfg.neg_iou_thr_stage2 + self.pos_iou_thr = cfg.pos_iou_thr_stage2 + self.min_pos_iou = cfg.min_pos_iou_stage2 + self.num_gts = cfg.num_gts + self.num_bboxes = num_bboxes + self.num_expected_pos = cfg.num_expected_pos_stage2 + self.num_expected_neg = cfg.num_expected_neg_stage2 + self.num_expected_total = cfg.num_expected_total_stage2 + + self.add_gt_as_proposals = add_gt_as_proposals + self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32)) + self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts), + dtype=np.int32)) + + self.concat = P.Concat(axis=0) + self.max_gt = P.ArgMaxWithValue(axis=0) + self.max_anchor = P.ArgMaxWithValue(axis=1) + self.sum_inds = P.ReduceSum() + self.iou = P.IOU() + self.greaterequal = P.GreaterEqual() + self.greater = P.Greater() + self.select = P.Select() + self.gatherND = P.GatherNd() + self.squeeze = P.Squeeze() + self.cast = P.Cast() + self.logicaland = P.LogicalAnd() + self.less = P.Less() + self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos) + self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) + self.reshape = P.Reshape() + self.equal = P.Equal() + self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(10.0, 10.0, 5.0, 5.0)) + self.concat_axis1 = P.Concat(axis=1) + self.logicalnot = P.LogicalNot() + self.tile = P.Tile() + + # Check + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) + + # Init tensor + self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) + self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32)) + self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32)) + self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) + self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) + + self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32)) + self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) + self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) + self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16)) + self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8)) + + self.reshape_shape_pos = (self.num_expected_pos, 1) + self.reshape_shape_neg = (self.num_expected_neg, 1) + + self.scalar_zero = Tensor(0.0, dtype=mstype.float16) + self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16) + self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16) + self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16) + + self.expand_dims = P.ExpandDims() + self.split = P.Split(axis=1, output_num=4) + self.concat_last_axis = P.Concat(axis=-1) + self.round = P.Round() + self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16) + self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2) + self.crop_and_resize = P.CropAndResize() + self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1]) + self.squeeze_mask_last = P.Squeeze(axis=-1) + def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i): + gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ + (self.num_gts, 1)), (1, 4)), mstype.bool_), \ + gt_bboxes_i, self.check_gt_one) + bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \ + (self.num_bboxes, 1)), (1, 4)), mstype.bool_), \ + bboxes, self.check_anchor_two) + # 1 dim = gt, 2 dim = bbox + overlaps = self.iou(bboxes, gt_bboxes_i) + + max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps) + _, max_overlaps_w_ac = self.max_anchor(overlaps) + + neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, + self.scalar_zero), + self.less(max_overlaps_w_gt, + self.scalar_neg_iou_thr)) + + assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds) + + pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr) + assigned_gt_inds3 = self.select(pos_sample_iou_mask, \ + max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2) + + for j in range(self.num_gts): + max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1] + overlaps_w_ac_j = overlaps[j:j+1:1, ::] + temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou) + temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j)) + pos_mask_j = self.logicaland(temp1, temp2) + assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3) + + assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores) + + bboxes = self.concat((gt_bboxes_i, bboxes)) + label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores) + label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid + assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5)) + + # Get pos index + pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) + + pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) + pos_check_valid = self.sum_inds(pos_check_valid, -1) + valid_pos_index = self.less(self.range_pos_size, pos_check_valid) + pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) + + num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1) + valid_pos_index = self.cast(valid_pos_index, mstype.int32) + pos_index = self.reshape(pos_index, self.reshape_shape_pos) + valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos) + pos_index = pos_index * valid_pos_index + + pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones + pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos) + pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index + + pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index) + + # Get neg index + neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) + + unvalid_pos_index = self.less(self.range_pos_size, num_pos) + valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) + neg_index = self.reshape(neg_index, self.reshape_shape_neg) + + valid_neg_index = self.cast(valid_neg_index, mstype.int32) + valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg) + neg_index = neg_index * valid_neg_index + + pos_bboxes_ = self.gatherND(bboxes, pos_index) + + neg_bboxes_ = self.gatherND(bboxes, neg_index) + pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos) + pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index) + pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_) + + # assign positive ROIs to gt masks + # Pick the right front and background mask for each ROI + roi_pos_masks_fb = self.gatherND(gt_masks_i, pos_assigned_gt_index) + pos_masks_fb = self.cast(roi_pos_masks_fb, mstype.float32) + # compute mask targets + x1, y1, x2, y2 = self.split(pos_bboxes_) + boxes = self.concat_last_axis((y1, x1, y2, x2)) + # normalized box coordinate + boxes = boxes / self.image_h_w + box_ids = self.range() + pos_masks_fb = self.expand_dims(pos_masks_fb, -1) + boxes = self.cast(boxes, mstype.float32) + pos_masks_fb = self.crop_and_resize(pos_masks_fb, boxes, box_ids, self.mask_shape) + + # Remove the extra dimension from masks. + pos_masks_fb = self.squeeze_mask_last(pos_masks_fb) + + # convert gt masks targets be 0 or 1 to use with binary cross entropy loss. + pos_masks_fb = self.round(pos_masks_fb) + + pos_masks_fb = self.cast(pos_masks_fb, mstype.float16) + total_bboxes = self.concat((pos_bboxes_, neg_bboxes_)) + total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask)) + total_labels = self.concat((pos_gt_labels, self.labels_neg_mask)) + + valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos) + valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg) + total_mask = self.concat((valid_pos_index, valid_neg_index)) + + return total_bboxes, total_deltas, total_labels, total_mask, pos_bboxes_, pos_masks_fb, \ + pos_gt_labels, valid_pos_index diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/fpn_neck.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/fpn_neck.py new file mode 100644 index 0000000000..8524648570 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/fpn_neck.py @@ -0,0 +1,112 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn feature pyramid network.""" + +import numpy as np +import mindspore.nn as nn +from mindspore import context +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + +def bias_init_zeros(shape): + """Bias init method.""" + return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16)) + +def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): + """Conv2D wrapper.""" + shape = (out_channels, in_channels, kernel_size, kernel_size) + weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() + shape_bias = (out_channels,) + biass = bias_init_zeros(shape_bias) + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass) + +class FeatPyramidNeck(nn.Cell): + """ + Feature pyramid network cell, usually uses as network neck. + + Applies the convolution on multiple, input feature maps + and output feature map with same channel size. if required num of + output larger then num of inputs, add extra maxpooling for further + downsampling; + + Args: + in_channels (tuple) - Channel size of input feature maps. + out_channels (int) - Channel size output. + num_outs (int) - Num of output features. + + Returns: + Tuple, with tensors of same channel size. + + Examples: + neck = FeatPyramidNeck([100,200,300], 50, 4) + input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)), + dtype=np.float32) \ + for i, c in enumerate(config.fpn_in_channels)) + x = neck(input_data) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs): + super(FeatPyramidNeck, self).__init__() + self.num_outs = num_outs + self.in_channels = in_channels + self.fpn_layer = len(self.in_channels) + + assert not self.num_outs < len(in_channels) + + self.lateral_convs_list_ = [] + self.fpn_convs_ = [] + + for _, channel in enumerate(in_channels): + l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid') + fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same') + self.lateral_convs_list_.append(l_conv) + self.fpn_convs_.append(fpn_conv) + self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) + self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) + self.interpolate1 = P.ResizeNearestNeighbor((48, 80)) + self.interpolate2 = P.ResizeNearestNeighbor((96, 160)) + self.interpolate3 = P.ResizeNearestNeighbor((192, 320)) + self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same") + + def construct(self, inputs): + x = () + for i in range(self.fpn_layer): + x += (self.lateral_convs_list[i](inputs[i]),) + + y = (x[3],) + y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),) + y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),) + y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),) + + z = () + for i in range(self.fpn_layer - 1, -1, -1): + z = z + (y[i],) + + outs = () + for i in range(self.fpn_layer): + outs = outs + (self.fpn_convs_list[i](z[i]),) + + for i in range(self.num_outs - self.fpn_layer): + outs = outs + (self.maxpool(outs[3]),) + return outs diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/mask_rcnn_r50.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/mask_rcnn_r50.py new file mode 100644 index 0000000000..c7697c04d7 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/mask_rcnn_r50.py @@ -0,0 +1,569 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn based on ResNet50.""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +from mindspore.ops import functional as F +from .resnet50 import ResNetFea, ResidualBlockUsing +from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn +from .fpn_neck import FeatPyramidNeck +from .proposal_generator import Proposal +from .rcnn_cls import RcnnCls +from .rcnn_mask import RcnnMask +from .rpn import RPN +from .roi_align import SingleRoIExtractor +from .anchor_generator import AnchorGenerator + +class Mask_Rcnn_Resnet50(nn.Cell): + """ + MaskRcnn Network. + + Note: + backbone = resnet50 + + Returns: + Tuple, tuple of output tensor. + rpn_loss: Scalar, Total loss of RPN subnet. + rcnn_loss: Scalar, Total loss of RCNN subnet. + rpn_cls_loss: Scalar, Classification loss of RPN subnet. + rpn_reg_loss: Scalar, Regression loss of RPN subnet. + rcnn_cls_loss: Scalar, Classification loss of RCNNcls subnet. + rcnn_reg_loss: Scalar, Regression loss of RCNNcls subnet. + rcnn_mask_loss: Scalar, mask loss of RCNNmask subnet. + + Examples: + net = Mask_Rcnn_Resnet50() + """ + def __init__(self, config): + super(Mask_Rcnn_Resnet50, self).__init__() + self.train_batch_size = config.batch_size + self.num_classes = config.num_classes + self.anchor_scales = config.anchor_scales + self.anchor_ratios = config.anchor_ratios + self.anchor_strides = config.anchor_strides + self.target_means = tuple(config.rcnn_target_means) + self.target_stds = tuple(config.rcnn_target_stds) + + # Anchor generator + anchor_base_sizes = None + self.anchor_base_sizes = list( + self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes + + self.anchor_generators = [] + for anchor_base in self.anchor_base_sizes: + self.anchor_generators.append( + AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios)) + + self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales) + + featmap_sizes = config.feature_shapes + assert len(featmap_sizes) == len(self.anchor_generators) + + self.anchor_list = self.get_anchors(featmap_sizes) + + # Backbone resnet50 + self.backbone = ResNetFea(ResidualBlockUsing, + config.resnet_block, + config.resnet_in_channels, + config.resnet_out_channels, + False) + + # Fpn + self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels, + config.fpn_out_channels, + config.fpn_num_outs) + + # Rpn and rpn loss + self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8)) + self.rpn_with_loss = RPN(config, + self.train_batch_size, + config.rpn_in_channels, + config.rpn_feat_channels, + config.num_anchors, + config.rpn_cls_out_channels) + + # Proposal + self.proposal_generator = Proposal(config, + self.train_batch_size, + config.activate_num_classes, + config.use_sigmoid_cls) + self.proposal_generator.set_train_local(config, True) + self.proposal_generator_test = Proposal(config, + config.test_batch_size, + config.activate_num_classes, + config.use_sigmoid_cls) + self.proposal_generator_test.set_train_local(config, False) + + # Assign and sampler stage two + self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size, + config.num_bboxes_stage2, True) + self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=self.target_means, \ + stds=self.target_stds) + + # Roi + self.roi_align = SingleRoIExtractor(config, + config.roi_layer, + config.roi_align_out_channels, + config.roi_align_featmap_strides, + self.train_batch_size, + config.roi_align_finest_scale, + mask=False) + self.roi_align.set_train_local(config, True) + + self.roi_align_mask = SingleRoIExtractor(config, + config.roi_layer, + config.roi_align_out_channels, + config.roi_align_featmap_strides, + self.train_batch_size, + config.roi_align_finest_scale, + mask=True) + self.roi_align_mask.set_train_local(config, True) + + self.roi_align_test = SingleRoIExtractor(config, + config.roi_layer, + config.roi_align_out_channels, + config.roi_align_featmap_strides, + 1, + config.roi_align_finest_scale, + mask=False) + self.roi_align_test.set_train_local(config, False) + + self.roi_align_mask_test = SingleRoIExtractor(config, + config.roi_layer, + config.roi_align_out_channels, + config.roi_align_featmap_strides, + 1, + config.roi_align_finest_scale, + mask=True) + self.roi_align_mask_test.set_train_local(config, False) + + # Rcnn + self.rcnn_cls = RcnnCls(config, self.train_batch_size, self.num_classes) + self.rcnn_mask = RcnnMask(config, self.train_batch_size, self.num_classes) + + # Op declare + self.squeeze = P.Squeeze() + self.cast = P.Cast() + + self.concat = P.Concat(axis=0) + self.concat_1 = P.Concat(axis=1) + self.concat_2 = P.Concat(axis=2) + self.reshape = P.Reshape() + self.select = P.Select() + self.greater = P.Greater() + self.transpose = P.Transpose() + + # Test mode + self.test_batch_size = config.test_batch_size + self.split = P.Split(axis=0, output_num=self.test_batch_size) + self.split_shape = P.Split(axis=0, output_num=4) + self.split_scores = P.Split(axis=1, output_num=self.num_classes) + self.split_fb_mask = P.Split(axis=1, output_num=self.num_classes) + self.split_cls = P.Split(axis=0, output_num=self.num_classes-1) + self.tile = P.Tile() + self.gather = P.GatherNd() + + self.rpn_max_num = config.rpn_max_num + + self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16)) + self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool) + self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool) + self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask, + self.ones_mask, self.zeros_mask), axis=1)) + self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask, + self.ones_mask, self.ones_mask, self.zeros_mask), axis=1)) + + self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr) + self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0) + self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1) + self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr) + self.test_max_per_img = config.test_max_per_img + self.nms_test = P.NMSWithMask(config.test_iou_thr) + self.softmax = P.Softmax(axis=1) + self.logicand = P.LogicalAnd() + self.oneslike = P.OnesLike() + self.test_topk = P.TopK(sorted=True) + self.test_num_proposal = self.test_batch_size * self.rpn_max_num + + # Improve speed + self.concat_start = min(self.num_classes - 2, 55) + self.concat_end = (self.num_classes - 1) + + # Init tensor + roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i, + dtype=np.float16) for i in range(self.train_batch_size)] + + roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \ + for i in range(self.test_batch_size)] + + self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index)) + self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test)) + + roi_align_index_pos = [np.array(np.ones((config.num_expected_pos_stage2, 1)) * i, + dtype=np.float16) for i in range(self.train_batch_size)] + self.roi_align_index_tensor_pos = Tensor(np.concatenate(roi_align_index_pos)) + + self.rcnn_loss_cls_weight = Tensor(np.array(config.rcnn_loss_cls_weight).astype(np.float16)) + self.rcnn_loss_reg_weight = Tensor(np.array(config.rcnn_loss_reg_weight).astype(np.float16)) + self.rcnn_loss_mask_fb_weight = Tensor(np.array(config.rcnn_loss_mask_fb_weight).astype(np.float16)) + + self.argmax_with_value = P.ArgMaxWithValue(axis=1) + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.onehot = P.OneHot() + self.reducesum = P.ReduceSum() + self.sigmoid = P.Sigmoid() + self.expand_dims = P.ExpandDims() + self.test_mask_fb_zeros = Tensor(np.zeros((self.rpn_max_num, 28, 28)).astype(np.float16)) + self.value = Tensor(1.0, mstype.float16) + def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids, gt_masks): + x = self.backbone(img_data) + x = self.fpn_ncek(x) + + rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x, + img_metas, + self.anchor_list, + gt_bboxes, + self.gt_labels_stage1, + gt_valids) + + if self.training: + proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list) + else: + proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list) + + gt_labels = self.cast(gt_labels, mstype.int32) + gt_valids = self.cast(gt_valids, mstype.int32) + bboxes_tuple = () + deltas_tuple = () + labels_tuple = () + mask_tuple = () + + pos_bboxes_tuple = () + pos_mask_fb_tuple = () + pos_labels_tuple = () + pos_mask_tuple = () + + if self.training: + for i in range(self.train_batch_size): + gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) + + gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) + gt_labels_i = self.cast(gt_labels_i, mstype.uint8) + + gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) + gt_valids_i = self.cast(gt_valids_i, mstype.bool_) + + gt_masks_i = self.squeeze(gt_masks[i:i + 1:1, ::]) + gt_masks_i = self.cast(gt_masks_i, mstype.bool_) + + bboxes, deltas, labels, mask, pos_bboxes, pos_mask_fb, pos_labels, pos_mask = \ + self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i, + gt_labels_i, + proposal_mask[i], + proposal[i][::, 0:4:1], + gt_valids_i, + gt_masks_i) + bboxes_tuple += (bboxes,) + deltas_tuple += (deltas,) + labels_tuple += (labels,) + mask_tuple += (mask,) + + pos_bboxes_tuple += (pos_bboxes,) + pos_mask_fb_tuple += (pos_mask_fb,) + pos_labels_tuple += (pos_labels,) + pos_mask_tuple += (pos_mask,) + + bbox_targets = self.concat(deltas_tuple) + rcnn_labels = self.concat(labels_tuple) + bbox_targets = F.stop_gradient(bbox_targets) + rcnn_labels = F.stop_gradient(rcnn_labels) + rcnn_labels = self.cast(rcnn_labels, mstype.int32) + + rcnn_pos_masks_fb = self.concat(pos_mask_fb_tuple) + rcnn_pos_masks_fb = F.stop_gradient(rcnn_pos_masks_fb) + rcnn_pos_labels = self.concat(pos_labels_tuple) + rcnn_pos_labels = F.stop_gradient(rcnn_pos_labels) + rcnn_pos_labels = self.cast(rcnn_pos_labels, mstype.int32) + else: + mask_tuple += proposal_mask + bbox_targets = proposal_mask + rcnn_labels = proposal_mask + + rcnn_pos_masks_fb = proposal_mask + rcnn_pos_labels = proposal_mask + for p_i in proposal: + bboxes_tuple += (p_i[::, 0:4:1],) + + pos_rois = None + if self.training: + if self.train_batch_size > 1: + bboxes_all = self.concat(bboxes_tuple) + pos_bboxes_all = self.concat(pos_bboxes_tuple) + else: + bboxes_all = bboxes_tuple[0] + pos_bboxes_all = pos_bboxes_tuple[0] + rois = self.concat_1((self.roi_align_index_tensor, bboxes_all)) + pos_rois = self.concat_1((self.roi_align_index_tensor_pos, pos_bboxes_all)) + pos_rois = self.cast(pos_rois, mstype.float32) + pos_rois = F.stop_gradient(pos_rois) + else: + if self.test_batch_size > 1: + bboxes_all = self.concat(bboxes_tuple) + else: + bboxes_all = bboxes_tuple[0] + rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) + + rois = self.cast(rois, mstype.float32) + rois = F.stop_gradient(rois) + + if self.training: + roi_feats = self.roi_align(rois, + self.cast(x[0], mstype.float32), + self.cast(x[1], mstype.float32), + self.cast(x[2], mstype.float32), + self.cast(x[3], mstype.float32)) + else: + roi_feats = self.roi_align_test(rois, + self.cast(x[0], mstype.float32), + self.cast(x[1], mstype.float32), + self.cast(x[2], mstype.float32), + self.cast(x[3], mstype.float32)) + + + roi_feats = self.cast(roi_feats, mstype.float16) + rcnn_masks = self.concat(mask_tuple) + rcnn_masks = F.stop_gradient(rcnn_masks) + rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_)) + + rcnn_pos_masks = self.concat(pos_mask_tuple) + rcnn_pos_masks = F.stop_gradient(rcnn_pos_masks) + rcnn_pos_mask_squeeze = self.squeeze(self.cast(rcnn_pos_masks, mstype.bool_)) + + rcnn_cls_loss, rcnn_reg_loss = self.rcnn_cls(roi_feats, + bbox_targets, + rcnn_labels, + rcnn_mask_squeeze) + + output = () + if self.training: + roi_feats_mask = self.roi_align_mask(pos_rois, + self.cast(x[0], mstype.float32), + self.cast(x[1], mstype.float32), + self.cast(x[2], mstype.float32), + self.cast(x[3], mstype.float32)) + roi_feats_mask = self.cast(roi_feats_mask, mstype.float16) + rcnn_mask_fb_loss = self.rcnn_mask(roi_feats_mask, + rcnn_pos_labels, + rcnn_pos_mask_squeeze, + rcnn_pos_masks_fb) + + rcnn_loss = self.rcnn_loss_cls_weight * rcnn_cls_loss + self.rcnn_loss_reg_weight * rcnn_reg_loss + \ + self.rcnn_loss_mask_fb_weight * rcnn_mask_fb_loss + output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_fb_loss) + else: + mask_fb_pred_all = self.rcnn_mask_test(x, bboxes_all, rcnn_cls_loss, rcnn_reg_loss) + output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, + img_metas, mask_fb_pred_all) + + return output + + def get_det_bboxes(self, cls_logits, reg_logits, mask_logits, rois, img_metas, mask_fb_pred_all): + """Get the actual detection box.""" + scores = self.softmax(cls_logits / self.value) + mask_fb_logits = self.sigmoid(mask_fb_pred_all) + + boxes_all = () + for i in range(self.num_classes): + k = i * 4 + reg_logits_i = self.squeeze(reg_logits[::, k:k+4:1]) + out_boxes_i = self.decode(rois, reg_logits_i) + boxes_all += (out_boxes_i,) + + img_metas_all = self.split(img_metas) + scores_all = self.split(scores) + mask_all = self.split(self.cast(mask_logits, mstype.int32)) + mask_fb_all = self.split(mask_fb_logits) + + boxes_all_with_batchsize = () + for i in range(self.test_batch_size): + scale = self.split_shape(self.squeeze(img_metas_all[i])) + scale_h = scale[2] + scale_w = scale[3] + boxes_tuple = () + for j in range(self.num_classes): + boxes_tmp = self.split(boxes_all[j]) + out_boxes_h = boxes_tmp[i] / scale_h + out_boxes_w = boxes_tmp[i] / scale_w + boxes_tuple += (self.select(self.bbox_mask, out_boxes_w, out_boxes_h),) + boxes_all_with_batchsize += (boxes_tuple,) + + output = self.multiclass_nms(boxes_all_with_batchsize, scores_all, mask_all, mask_fb_all) + + return output + + def multiclass_nms(self, boxes_all, scores_all, mask_all, mask_fb_all): + """Multiscale postprocessing.""" + all_bboxes = () + all_labels = () + all_masks = () + all_masks_fb = () + + for i in range(self.test_batch_size): + bboxes = boxes_all[i] + scores = scores_all[i] + masks = self.cast(mask_all[i], mstype.bool_) + masks_fb = mask_fb_all[i] + _mask_fb_all = self.split_fb_mask(masks_fb) + + res_boxes_tuple = () + res_labels_tuple = () + res_masks_tuple = () + res_masks_fb_tuple = () + + for j in range(self.num_classes - 1): + k = j + 1 + _cls_scores = scores[::, k:k + 1:1] + _bboxes = self.squeeze(bboxes[k]) + _mask_o = self.reshape(masks, (self.rpn_max_num, 1)) + _masks_fb = self.squeeze(_mask_fb_all[k]) + + cls_mask = self.greater(_cls_scores, self.test_score_thresh) + _mask = self.logicand(_mask_o, cls_mask) + + _reg_mask = self.cast(self.tile(self.cast(_mask, mstype.int32), (1, 4)), mstype.bool_) + + _bboxes = self.select(_reg_mask, _bboxes, self.test_box_zeros) + _fb_mask = self.expand_dims(_mask, -1) + _mask_fb_mask = self.cast(self.tile(self.cast(_fb_mask, mstype.int32), (1, 28, 28)), mstype.bool_) + _masks_fb = self.select(_mask_fb_mask, _masks_fb, self.test_mask_fb_zeros) + _cls_scores = self.select(_mask, _cls_scores, self.test_score_zeros) + __cls_scores = self.squeeze(_cls_scores) + scores_sorted, topk_inds = self.test_topk(__cls_scores, self.rpn_max_num) + topk_inds = self.reshape(topk_inds, (self.rpn_max_num, 1)) + scores_sorted = self.reshape(scores_sorted, (self.rpn_max_num, 1)) + _bboxes_sorted = self.gather(_bboxes, topk_inds) + _mask_fb_sorted = self.gather(_masks_fb, topk_inds) + _mask_sorted = self.gather(_mask, topk_inds) + + scores_sorted = self.tile(scores_sorted, (1, 4)) + cls_dets = self.concat_1((_bboxes_sorted, scores_sorted)) + cls_dets = P.Slice()(cls_dets, (0, 0), (self.rpn_max_num, 5)) + + cls_dets, _index, _mask_nms = self.nms_test(cls_dets) + _index = self.reshape(_index, (self.rpn_max_num, 1)) + _mask_nms = self.reshape(_mask_nms, (self.rpn_max_num, 1)) + + _mask_n = self.gather(_mask_sorted, _index) + _mask_n = self.logicand(_mask_n, _mask_nms) + + _mask_fb = self.gather(_mask_fb_sorted, _index) + + cls_labels = self.oneslike(_index) * j + res_boxes_tuple += (cls_dets,) + res_labels_tuple += (cls_labels,) + res_masks_tuple += (_mask_n,) + res_masks_fb_tuple += (_mask_fb,) + + res_boxes_start = self.concat(res_boxes_tuple[:self.concat_start]) + res_labels_start = self.concat(res_labels_tuple[:self.concat_start]) + res_masks_start = self.concat(res_masks_tuple[:self.concat_start]) + res_masks_fb_start = self.concat(res_masks_fb_tuple[:self.concat_start]) + + res_boxes_end = self.concat(res_boxes_tuple[self.concat_start:self.concat_end]) + res_labels_end = self.concat(res_labels_tuple[self.concat_start:self.concat_end]) + res_masks_end = self.concat(res_masks_tuple[self.concat_start:self.concat_end]) + res_masks_fb_end = self.concat(res_masks_fb_tuple[self.concat_start:self.concat_end]) + + res_boxes = self.concat((res_boxes_start, res_boxes_end)) + res_labels = self.concat((res_labels_start, res_labels_end)) + res_masks = self.concat((res_masks_start, res_masks_end)) + res_masks_fb = self.concat((res_masks_fb_start, res_masks_fb_end)) + + reshape_size = (self.num_classes - 1) * self.rpn_max_num + res_boxes = self.reshape(res_boxes, (1, reshape_size, 5)) + res_labels = self.reshape(res_labels, (1, reshape_size, 1)) + res_masks = self.reshape(res_masks, (1, reshape_size, 1)) + res_masks_fb = self.reshape(res_masks_fb, (1, reshape_size, 28, 28)) + + all_bboxes += (res_boxes,) + all_labels += (res_labels,) + all_masks += (res_masks,) + all_masks_fb += (res_masks_fb,) + + all_bboxes = self.concat(all_bboxes) + all_labels = self.concat(all_labels) + all_masks = self.concat(all_masks) + all_masks_fb = self.concat(all_masks_fb) + return all_bboxes, all_labels, all_masks, all_masks_fb + + def get_anchors(self, featmap_sizes): + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + + Returns: + tuple: anchors of each image, valid flags of each image + """ + num_levels = len(featmap_sizes) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = () + for i in range(num_levels): + anchors = self.anchor_generators[i].grid_anchors( + featmap_sizes[i], self.anchor_strides[i]) + multi_level_anchors += (Tensor(anchors.astype(np.float16)),) + + return multi_level_anchors + + def rcnn_mask_test(self, x, rois, cls_pred, reg_pred): + """Prediction masks in an images by the bounding boxes + """ + cls_scores = self.softmax(cls_pred / self.value) + + cls_scores_all = self.split(cls_scores) + reg_pred = self.reshape(reg_pred, (-1, self.num_classes, 4)) + reg_pred_all = self.split(reg_pred) + rois_all = self.split(rois) + boxes_tuple = () + for i in range(self.test_batch_size): + cls_score_max_index, _ = self.argmax_with_value(cls_scores_all[i]) + cls_score_max_index = self.cast(self.onehot(cls_score_max_index, self.num_classes, + self.on_value, self.off_value), mstype.float16) + cls_score_max_index = self.expand_dims(cls_score_max_index, -1) + cls_score_max_index = self.tile(cls_score_max_index, (1, 1, 4)) + reg_pred_max = reg_pred_all[i] * cls_score_max_index + reg_pred_max = self.reducesum(reg_pred_max, 1) + out_boxes_i = self.decode(rois_all[i], reg_pred_max) + boxes_tuple += (out_boxes_i,) + + boxes_all = self.concat(boxes_tuple) + boxes_rois = self.concat_1((self.roi_align_index_test_tensor, boxes_all)) + boxes_rois = self.cast(boxes_rois, mstype.float32) + roi_feats_mask_test = self.roi_align_mask_test(boxes_rois, + self.cast(x[0], mstype.float32), + self.cast(x[1], mstype.float32), + self.cast(x[2], mstype.float32), + self.cast(x[3], mstype.float32)) + roi_feats_mask_test = self.cast(roi_feats_mask_test, mstype.float16) + mask_fb_pred_all = self.rcnn_mask(roi_feats_mask_test) + return mask_fb_pred_all diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/proposal_generator.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/proposal_generator.py new file mode 100644 index 0000000000..ae4c2624c1 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/proposal_generator.py @@ -0,0 +1,199 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn proposal generator.""" + +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore import context + + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + + +class Proposal(nn.Cell): + """ + Proposal subnet. + + Args: + config (dict): Config. + batch_size (int): Batchsize. + num_classes (int) - Class number. + use_sigmoid_cls (bool) - Select sigmoid or softmax function. + target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0). + target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0). + + Returns: + Tuple, tuple of output tensor,(proposal, mask). + + Examples: + Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \ + target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0)) + """ + def __init__(self, + config, + batch_size, + num_classes, + use_sigmoid_cls, + target_means=(.0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0) + ): + super(Proposal, self).__init__() + cfg = config + self.batch_size = batch_size + self.num_classes = num_classes + self.target_means = target_means + self.target_stds = target_stds + self.use_sigmoid_cls = use_sigmoid_cls + + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes - 1 + self.activation = P.Sigmoid() + self.reshape_shape = (-1, 1) + else: + self.cls_out_channels = num_classes + self.activation = P.Softmax(axis=1) + self.reshape_shape = (-1, 2) + + if self.cls_out_channels <= 0: + raise ValueError('num_classes={} is too small'.format(num_classes)) + + self.num_pre = cfg.rpn_proposal_nms_pre + self.min_box_size = cfg.rpn_proposal_min_bbox_size + self.nms_thr = cfg.rpn_proposal_nms_thr + self.nms_post = cfg.rpn_proposal_nms_post + self.nms_across_levels = cfg.rpn_proposal_nms_across_levels + self.max_num = cfg.rpn_proposal_max_num + self.num_levels = cfg.fpn_num_outs + + # Op Define + self.squeeze = P.Squeeze() + self.reshape = P.Reshape() + self.cast = P.Cast() + + self.feature_shapes = cfg.feature_shapes + + self.transpose_shape = (1, 2, 0) + + self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \ + means=self.target_means, \ + stds=self.target_stds) + + self.nms = P.NMSWithMask(self.nms_thr) + self.concat_axis0 = P.Concat(axis=0) + self.concat_axis1 = P.Concat(axis=1) + self.split = P.Split(axis=1, output_num=5) + self.min = P.Minimum() + self.gatherND = P.GatherNd() + self.slice = P.Slice() + self.select = P.Select() + self.greater = P.Greater() + self.transpose = P.Transpose() + self.tile = P.Tile() + self.set_train_local(config, training=True) + + self.multi_10 = Tensor(10.0, mstype.float16) + + def set_train_local(self, config, training=True): + """Set training flag.""" + self.training_local = training + + cfg = config + self.topK_stage1 = () + self.topK_shape = () + total_max_topk_input = 0 + if not self.training_local: + self.num_pre = cfg.rpn_nms_pre + self.min_box_size = cfg.rpn_min_bbox_min_size + self.nms_thr = cfg.rpn_nms_thr + self.nms_post = cfg.rpn_nms_post + self.nms_across_levels = cfg.rpn_nms_across_levels + self.max_num = cfg.rpn_max_num + + for shp in self.feature_shapes: + k_num = min(self.num_pre, (shp[0] * shp[1] * 3)) + total_max_topk_input += k_num + self.topK_stage1 += (k_num,) + self.topK_shape += ((k_num, 1),) + + self.topKv2 = P.TopK(sorted=True) + self.topK_shape_stage2 = (self.max_num, 1) + self.min_float_num = -65536.0 + self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) + + def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): + proposals_tuple = () + masks_tuple = () + for img_id in range(self.batch_size): + cls_score_list = () + bbox_pred_list = () + for i in range(self.num_levels): + rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::]) + rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::]) + + cls_score_list = cls_score_list + (rpn_cls_score_i,) + bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,) + + proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list) + proposals_tuple += (proposals,) + masks_tuple += (masks,) + return proposals_tuple, masks_tuple + + def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors): + """Get proposal boundingbox.""" + mlvl_proposals = () + mlvl_mask = () + for idx in range(self.num_levels): + rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape) + rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape) + anchors = mlvl_anchors[idx] + + rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) + rpn_cls_score = self.activation(rpn_cls_score) + rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16) + + rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16) + + scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx]) + + topk_inds = self.reshape(topk_inds, self.topK_shape[idx]) + + bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) + anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16) + + proposals_decode = self.decode(anchors_sorted, bboxes_sorted) + + proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx]))) + proposals, _, mask_valid = self.nms(proposals_decode) + + mlvl_proposals = mlvl_proposals + (proposals,) + mlvl_mask = mlvl_mask + (mask_valid,) + + proposals = self.concat_axis0(mlvl_proposals) + masks = self.concat_axis0(mlvl_mask) + + _, _, _, _, scores = self.split(proposals) + scores = self.squeeze(scores) + topk_mask = self.cast(self.topK_mask, mstype.float16) + scores_using = self.select(masks, scores, topk_mask) + + _, topk_inds = self.topKv2(scores_using, self.max_num) + + topk_inds = self.reshape(topk_inds, self.topK_shape_stage2) + proposals = self.gatherND(proposals, topk_inds) + masks = self.gatherND(masks, topk_inds) + return proposals, masks diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rcnn_cls.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rcnn_cls.py new file mode 100644 index 0000000000..ac38b76334 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rcnn_cls.py @@ -0,0 +1,178 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn Rcnn classification and box regression network.""" + +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +class DenseNoTranpose(nn.Cell): + """Dense method""" + def __init__(self, input_channels, output_channels, weight_init): + super(DenseNoTranpose, self).__init__() + self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16), + name="weight") + self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias") + self.matmul = P.MatMul(transpose_b=False) + self.bias_add = P.BiasAdd() + + def construct(self, x): + output = self.bias_add(self.matmul(x, self.weight), self.bias) + return output + +class FpnCls(nn.Cell): + """dense layer of classification and box head""" + def __init__(self, input_channels, output_channels, num_classes, pool_size): + super(FpnCls, self).__init__() + representation_size = input_channels * pool_size * pool_size + shape_0 = (output_channels, representation_size) + weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor() + shape_1 = (output_channels, output_channels) + weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor() + self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0) + self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1) + + cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1], + dtype=mstype.float16).to_tensor() + reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1], + dtype=mstype.float16).to_tensor() + self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight) + self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight) + + self.relu = P.ReLU() + self.flatten = P.Flatten() + + def construct(self, x): + # two share fc layer + x = self.flatten(x) + + x = self.relu(self.shared_fc_0(x)) + x = self.relu(self.shared_fc_1(x)) + + # classifier head + cls_scores = self.cls_scores(x) + # bbox head + reg_scores = self.reg_scores(x) + + return cls_scores, reg_scores + +class RcnnCls(nn.Cell): + """ + Rcnn for classification and box regression subnet. + + Args: + config (dict) - Config. + batch_size (int) - Batchsize. + num_classes (int) - Class number. + target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]). + target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2). + + Returns: + Tuple, tuple of output tensor. + + Examples: + RcnnCls(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \ + target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2)) + """ + def __init__(self, + config, + batch_size, + num_classes, + target_means=(0., 0., 0., 0.), + target_stds=(0.1, 0.1, 0.2, 0.2) + ): + super(RcnnCls, self).__init__() + cfg = config + self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16)) + self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16)) + self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels + self.target_means = target_means + self.target_stds = target_stds + self.num_classes = num_classes + self.in_channels = cfg.rcnn_in_channels + self.train_batch_size = batch_size + self.test_batch_size = cfg.test_batch_size + + self.fpn_cls = FpnCls(self.in_channels, self.rcnn_fc_out_channels, self.num_classes, cfg.roi_layer["out_size"]) + self.relu = P.ReLU() + self.logicaland = P.LogicalAnd() + self.loss_cls = P.SoftmaxCrossEntropyWithLogits() + self.loss_bbox = P.SmoothL1Loss(sigma=1.0) + self.loss_mask = P.SigmoidCrossEntropyWithLogits() + self.reshape = P.Reshape() + self.onehot = P.OneHot() + self.greater = P.Greater() + self.cast = P.Cast() + self.sum_loss = P.ReduceSum() + self.tile = P.Tile() + self.expandims = P.ExpandDims() + + self.gather = P.GatherNd() + self.argmax = P.ArgMaxWithValue(axis=1) + + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.value = Tensor(1.0, mstype.float16) + + self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size + + rmv_first = np.ones((self.num_bboxes, self.num_classes)) + rmv_first[:, 0] = np.zeros((self.num_bboxes,)) + self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16)) + + self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size + + def construct(self, featuremap, bbox_targets, labels, mask): + x_cls, x_reg = self.fpn_cls(featuremap) + + if self.training: + bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels + labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16) + bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) + + loss_cls, loss_reg = self.loss(x_cls, x_reg, + bbox_targets, bbox_weights, + labels, + mask) + out = (loss_cls, loss_reg) + else: + out = (x_cls, x_reg) + + return out + + def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights): + """Loss method.""" + # loss_cls + loss_cls, _ = self.loss_cls(cls_score, labels) + weights = self.cast(weights, mstype.float16) + loss_cls = loss_cls * weights + loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,)) + + # loss_reg + bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), + mstype.float16) + bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background + pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4)) + loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets) + loss_reg = self.sum_loss(loss_reg, (2,)) + loss_reg = loss_reg * bbox_weights + loss_reg = loss_reg / self.sum_loss(weights, (0,)) + loss_reg = self.sum_loss(loss_reg, (0, 1)) + + return loss_cls, loss_reg diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rcnn_mask.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rcnn_mask.py new file mode 100644 index 0000000000..7e46f28584 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rcnn_mask.py @@ -0,0 +1,168 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn Rcnn for mask network.""" + +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common.initializer import initializer + +def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): + """Conv2D wrapper.""" + shape = (out_channels, in_channels, kernel_size, kernel_size) + weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() + shape_bias = (out_channels,) + bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16)) + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias) + +def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): + """ConvTranspose wrapper.""" + shape = (out_channels, in_channels, kernel_size, kernel_size) + weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor() + shape_bias = (out_channels,) + bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16)) + return nn.Conv2dTranspose(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias) + +class FpnMask(nn.Cell): + """conv layers of mask head""" + def __init__(self, input_channels, output_channels, num_classes): + super(FpnMask, self).__init__() + self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3, pad_mode="same") + self.mask_relu1 = P.ReLU() + + self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same") + self.mask_relu2 = P.ReLU() + + self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same") + self.mask_relu3 = P.ReLU() + + self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same") + self.mask_relu4 = P.ReLU() + + self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2, stride=2, pad_mode="valid") + self.mask_relu5 = P.ReLU() + self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1, pad_mode="valid") + + def construct(self, x): + x = self.mask_conv1(x) + x = self.mask_relu1(x) + + x = self.mask_conv2(x) + x = self.mask_relu2(x) + + x = self.mask_conv3(x) + x = self.mask_relu3(x) + + x = self.mask_conv4(x) + x = self.mask_relu4(x) + + x = self.mask_deconv5(x) + x = self.mask_relu5(x) + + x = self.mask_conv6(x) + + return x + +class RcnnMask(nn.Cell): + """ + Rcnn for mask subnet. + + Args: + config (dict) - Config. + batch_size (int) - Batchsize. + num_classes (int) - Class number. + target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]). + target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2). + + Returns: + Tuple, tuple of output tensor. + + Examples: + RcnnMask(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \ + target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2)) + """ + def __init__(self, + config, + batch_size, + num_classes, + target_means=(0., 0., 0., 0.), + target_stds=(0.1, 0.1, 0.2, 0.2) + ): + super(RcnnMask, self).__init__() + cfg = config + self.rcnn_loss_mask_fb_weight = Tensor(np.array(cfg.rcnn_loss_mask_fb_weight).astype(np.float16)) + self.rcnn_mask_out_channels = cfg.rcnn_mask_out_channels + self.target_means = target_means + self.target_stds = target_stds + self.num_classes = num_classes + self.in_channels = cfg.rcnn_in_channels + + self.fpn_mask = FpnMask(self.in_channels, self.rcnn_mask_out_channels, self.num_classes) + + self.logicaland = P.LogicalAnd() + self.loss_mask = P.SigmoidCrossEntropyWithLogits() + self.onehot = P.OneHot() + self.greater = P.Greater() + self.cast = P.Cast() + self.sum_loss = P.ReduceSum() + self.tile = P.Tile() + self.expandims = P.ExpandDims() + + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + + self.num_bboxes = cfg.num_expected_pos_stage2 * batch_size + rmv_first = np.ones((self.num_bboxes, self.num_classes)) + rmv_first[:, 0] = np.zeros((self.num_bboxes,)) + self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16)) + self.mean_loss = P.ReduceMean() + + def construct(self, mask_featuremap, labels=None, mask=None, mask_fb_targets=None): + x_mask_fb = self.fpn_mask(mask_featuremap) + + if self.training: + bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels + mask_fb_targets = self.tile(self.expandims(mask_fb_targets, 1), (1, self.num_classes, 1, 1)) + + loss_mask_fb = self.loss(x_mask_fb, bbox_weights, mask, mask_fb_targets) + out = loss_mask_fb + else: + out = x_mask_fb + + return out + + + def loss(self, masks_fb_pred, bbox_weights, weights, masks_fb_targets): + """Loss method.""" + weights = self.cast(weights, mstype.float16) + bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), + mstype.float16) + bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background + + # loss_mask_fb + masks_fb_targets = self.cast(masks_fb_targets, mstype.float16) + loss_mask_fb = self.loss_mask(masks_fb_pred, masks_fb_targets) + loss_mask_fb = self.mean_loss(loss_mask_fb, (2, 3)) + loss_mask_fb = loss_mask_fb * bbox_weights + loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,)) + loss_mask_fb = self.sum_loss(loss_mask_fb, (0, 1)) + + return loss_mask_fb diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/resnet50.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/resnet50.py new file mode 100644 index 0000000000..20d9ee1f34 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/resnet50.py @@ -0,0 +1,248 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Resnet50 backbone.""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.ops import functional as F +from mindspore import context + + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + + +def weight_init_ones(shape): + """Weight init.""" + return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16)) + + +def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'): + """Conv2D wrapper.""" + shape = (out_channels, in_channels, kernel_size, kernel_size) + weights = weight_init_ones(shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + pad_mode=pad_mode, weight_init=weights, has_bias=False) + + +def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True): + """Batchnorm2D wrapper.""" + gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) + beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) + moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16)) + moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16)) + + return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, + beta_init=beta_init, moving_mean_init=moving_mean_init, + moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) + + +class ResNetFea(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + weights_update (bool): Weight update flag. + Returns: + Tensor, output tensor. + + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> False) + """ + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + weights_update=False): + super(ResNetFea, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of " + "layer_num, inchannel, outchannel list must be 4!") + + bn_training = False + self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad') + self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training) + self.relu = P.ReLU() + self.maxpool = P.MaxPool(ksize=3, strides=2, padding="SAME") + self.weights_update = weights_update + + if not self.weights_update: + self.conv1.weight.requires_grad = False + + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=1, + training=bn_training, + weights_update=self.weights_update) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=2, + training=bn_training, + weights_update=True) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=2, + training=bn_training, + weights_update=True) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=2, + training=bn_training, + weights_update=True) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False): + """Make block layer.""" + layers = [] + down_sample = False + if stride != 1 or in_channel != out_channel: + down_sample = True + resblk = block(in_channel, + out_channel, + stride=stride, + down_sample=down_sample, + training=training, + weights_update=weights_update) + layers.append(resblk) + + for _ in range(1, layer_num): + resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update) + layers.append(resblk) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + identity = c2 + if not self.weights_update: + identity = F.stop_gradient(c2) + c3 = self.layer2(identity) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + return identity, c3, c4, c5 + + +class ResidualBlockUsing(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channels (int) - Input channel. + out_channels (int) - Output channel. + stride (int) - Stride size for the initial convolutional layer. Default: 1. + down_sample (bool) - If to do the downsample in block. Default: False. + momentum (float) - Momentum for batchnorm layer. Default: 0.1. + training (bool) - Training flag. Default: False. + weights_updata (bool) - Weights update flag. Default: False. + + Returns: + Tensor, output tensor. + + Examples: + ResidualBlock(3,256,stride=2,down_sample=True) + """ + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1, + down_sample=False, + momentum=0.1, + training=False, + weights_update=False): + super(ResidualBlockUsing, self).__init__() + + self.affine = weights_update + + out_chls = out_channels // self.expansion + self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0) + self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training) + + self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1) + self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training) + + self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0) + self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training) + + if training: + self.bn1 = self.bn1.set_train() + self.bn2 = self.bn2.set_train() + self.bn3 = self.bn3.set_train() + + if not weights_update: + self.conv1.weight.requires_grad = False + self.conv2.weight.requires_grad = False + self.conv3.weight.requires_grad = False + + self.relu = P.ReLU() + self.downsample = down_sample + if self.downsample: + self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0) + self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, + use_batch_statistics=training) + if training: + self.bn_down_sample = self.bn_down_sample.set_train() + if not weights_update: + self.conv_down_sample.weight.requires_grad = False + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample: + identity = self.conv_down_sample(identity) + identity = self.bn_down_sample(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/roi_align.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/roi_align.py new file mode 100644 index 0000000000..1f82e774bf --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/roi_align.py @@ -0,0 +1,186 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn ROIAlign module.""" + +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.nn import layer as L +from mindspore.common.tensor import Tensor + +class ROIAlign(nn.Cell): + """ + Extract RoI features from mulitple feature map. + + Args: + out_size_h (int) - RoI height. + out_size_w (int) - RoI width. + spatial_scale (int) - RoI spatial scale. + sample_num (int) - RoI sample number. + roi_align_mode (int)- RoI align mode + """ + def __init__(self, + out_size_h, + out_size_w, + spatial_scale, + sample_num=0, + roi_align_mode=1): + super(ROIAlign, self).__init__() + + self.out_size = (out_size_h, out_size_w) + self.spatial_scale = float(spatial_scale) + self.sample_num = int(sample_num) + self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1], + self.spatial_scale, self.sample_num, roi_align_mode) + + def construct(self, features, rois): + return self.align_op(features, rois) + + def __repr__(self): + format_str = self.__class__.__name__ + format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format( + self.out_size, self.spatial_scale, self.sample_num) + return format_str + + +class SingleRoIExtractor(nn.Cell): + """ + Extract RoI features from a single level feature map. + + If there are mulitple input feature levels, each RoI is mapped to a level + according to its scale. + + Args: + config (dict): Config + roi_layer (dict): Specify RoI layer type and arguments. + out_channels (int): Output channels of RoI layers. + featmap_strides (int): Strides of input feature maps. + batch_size (int): Batchsize. + finest_scale (int): Scale threshold of mapping to level 0. + mask (bool): Specify ROIAlign for cls or mask branch + """ + + def __init__(self, + config, + roi_layer, + out_channels, + featmap_strides, + batch_size=1, + finest_scale=56, + mask=False): + super(SingleRoIExtractor, self).__init__() + cfg = config + self.train_batch_size = batch_size + self.out_channels = out_channels + self.featmap_strides = featmap_strides + self.num_levels = len(self.featmap_strides) + self.out_size = roi_layer['mask_out_size'] if mask else roi_layer['out_size'] + self.mask = mask + self.sample_num = roi_layer['sample_num'] + self.roi_layers = self.build_roi_layers(self.featmap_strides) + self.roi_layers = L.CellList(self.roi_layers) + + self.sqrt = P.Sqrt() + self.log = P.Log() + self.finest_scale_ = finest_scale + self.clamp = C.clip_by_value + + self.cast = P.Cast() + self.equal = P.Equal() + self.select = P.Select() + + _mode_16 = False + self.dtype = np.float16 if _mode_16 else np.float32 + self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32 + self.set_train_local(cfg, training=True) + + def set_train_local(self, config, training=True): + """Set training flag.""" + self.training_local = training + + cfg = config + # Init tensor + roi_sample_num = cfg.num_expected_pos_stage2 if self.mask else cfg.roi_sample_num + self.batch_size = roi_sample_num if self.training_local else cfg.rpn_max_num + self.batch_size = self.train_batch_size*self.batch_size \ + if self.training_local else cfg.test_batch_size*self.batch_size + self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)) + finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_ + self.finest_scale = Tensor(finest_scale) + self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6)) + self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32)) + self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1)) + self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2) + self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels, + self.out_size, self.out_size)), dtype=self.dtype)) + def num_inputs(self): + return len(self.featmap_strides) + + def init_weights(self): + pass + + def log2(self, value): + return self.log(value) / self.log(self.twos) + + def build_roi_layers(self, featmap_strides): + roi_layers = [] + for s in featmap_strides: + layer_cls = ROIAlign(self.out_size, self.out_size, + spatial_scale=1 / s, + sample_num=self.sample_num, + roi_align_mode=0) + roi_layers.append(layer_cls) + return roi_layers + + def _c_map_roi_levels(self, rois): + """Map rois to corresponding feature levels by scales. + + - scale < finest_scale * 2: level 0 + - finest_scale * 2 <= scale < finest_scale * 4: level 1 + - finest_scale * 4 <= scale < finest_scale * 8: level 2 + - scale >= finest_scale * 8: level 3 + + Args: + rois (Tensor): Input RoIs, shape (k, 5). + num_levels (int): Total level number. + + Returns: + Tensor: Level index (0-based) of each RoI, shape (k, ) + """ + scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \ + self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones) + + target_lvls = self.log2(scale / self.finest_scale + self.epslion) + target_lvls = P.Floor()(target_lvls) + target_lvls = self.cast(target_lvls, mstype.int32) + target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels) + + return target_lvls + + def construct(self, rois, feat1, feat2, feat3, feat4): + feats = (feat1, feat2, feat3, feat4) + res = self.res_ + target_lvls = self._c_map_roi_levels(rois) + for i in range(self.num_levels): + mask = self.equal(target_lvls, P.ScalarToArray()(i)) + mask = P.Reshape()(mask, (-1, 1, 1, 1)) + roi_feats_t = self.roi_layers[i](feats[i], rois) + mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, self.out_size, self.out_size)), + mstype.bool_) + res = self.select(mask, roi_feats_t, res) + + return res diff --git a/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rpn.py b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rpn.py new file mode 100644 index 0000000000..e06f8eb0df --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/MaskRcnn/rpn.py @@ -0,0 +1,311 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""RPN for MaskRCNN""" +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore.ops import functional as F +from mindspore.common.initializer import initializer +from .bbox_assign_sample import BboxAssignSample + + +class RpnRegClsBlock(nn.Cell): + """ + Rpn reg cls block for rpn layer + + Args: + in_channels (int) - Input channels of shared convolution. + feat_channels (int) - Output channels of shared convolution. + num_anchors (int) - The anchor number. + cls_out_channels (int) - Output channels of classification convolution. + weight_conv (Tensor) - weight init for rpn conv. + bias_conv (Tensor) - bias init for rpn conv. + weight_cls (Tensor) - weight init for rpn cls conv. + bias_cls (Tensor) - bias init for rpn cls conv. + weight_reg (Tensor) - weight init for rpn reg conv. + bias_reg (Tensor) - bias init for rpn reg conv. + + Returns: + Tensor, output tensor. + """ + def __init__(self, + in_channels, + feat_channels, + num_anchors, + cls_out_channels, + weight_conv, + bias_conv, + weight_cls, + bias_cls, + weight_reg, + bias_reg): + super(RpnRegClsBlock, self).__init__() + self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same', + has_bias=True, weight_init=weight_conv, bias_init=bias_conv) + self.relu = nn.ReLU() + + self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid', + has_bias=True, weight_init=weight_cls, bias_init=bias_cls) + self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid', + has_bias=True, weight_init=weight_reg, bias_init=bias_reg) + + def construct(self, x): + x = self.relu(self.rpn_conv(x)) + + x1 = self.rpn_cls(x) + x2 = self.rpn_reg(x) + + return x1, x2 + + +class RPN(nn.Cell): + """ + ROI proposal network.. + + Args: + config (dict) - Config. + batch_size (int) - Batchsize. + in_channels (int) - Input channels of shared convolution. + feat_channels (int) - Output channels of shared convolution. + num_anchors (int) - The anchor number. + cls_out_channels (int) - Output channels of classification convolution. + + Returns: + Tuple, tuple of output tensor. + + Examples: + RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024, + num_anchors=3, cls_out_channels=512) + """ + def __init__(self, + config, + batch_size, + in_channels, + feat_channels, + num_anchors, + cls_out_channels): + super(RPN, self).__init__() + cfg_rpn = config + self.num_bboxes = cfg_rpn.num_bboxes + self.slice_index = () + self.feature_anchor_shape = () + self.slice_index += (0,) + index = 0 + for shape in cfg_rpn.feature_shapes: + self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,) + self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,) + index += 1 + + self.num_anchors = num_anchors + self.batch_size = batch_size + self.test_batch_size = cfg_rpn.test_batch_size + self.num_layers = 5 + self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) + + self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels, + num_anchors, cls_out_channels)) + + self.transpose = P.Transpose() + self.reshape = P.Reshape() + self.concat = P.Concat(axis=0) + self.fill = P.Fill() + self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) + + self.trans_shape = (0, 2, 3, 1) + + self.reshape_shape_reg = (-1, 4) + self.reshape_shape_cls = (-1,) + self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) + self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) + self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) + self.num_bboxes = cfg_rpn.num_bboxes + self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) + self.CheckValid = P.CheckValid() + self.sum_loss = P.ReduceSum() + self.loss_cls = P.SigmoidCrossEntropyWithLogits() + self.loss_bbox = P.SmoothL1Loss(sigma=1.0/9.0) + self.squeeze = P.Squeeze() + self.cast = P.Cast() + self.tile = P.Tile() + self.zeros_like = P.ZerosLike() + self.loss = Tensor(np.zeros((1,)).astype(np.float16)) + self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) + self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) + + def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): + """ + make rpn layer for rpn proposal network + + Args: + num_layers (int) - layer num. + in_channels (int) - Input channels of shared convolution. + feat_channels (int) - Output channels of shared convolution. + num_anchors (int) - The anchor number. + cls_out_channels (int) - Output channels of classification convolution. + + Returns: + List, list of RpnRegClsBlock cells. + """ + rpn_layer = [] + + shp_weight_conv = (feat_channels, in_channels, 3, 3) + shp_bias_conv = (feat_channels,) + weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() + bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() + + shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) + shp_bias_cls = (num_anchors * cls_out_channels,) + weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() + bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() + + shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) + shp_bias_reg = (num_anchors * 4,) + weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() + bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() + + for i in range(num_layers): + rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ + weight_conv, bias_conv, weight_cls, \ + bias_cls, weight_reg, bias_reg)) + + for i in range(1, num_layers): + rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight + rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight + rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight + + rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias + rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias + rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias + + return rpn_layer + + def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids): + loss_print = () + rpn_cls_score = () + rpn_bbox_pred = () + rpn_cls_score_total = () + rpn_bbox_pred_total = () + + for i in range(self.num_layers): + x1, x2 = self.rpn_convs_list[i](inputs[i]) + + rpn_cls_score_total = rpn_cls_score_total + (x1,) + rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,) + + x1 = self.transpose(x1, self.trans_shape) + x1 = self.reshape(x1, self.reshape_shape_cls) + + x2 = self.transpose(x2, self.trans_shape) + x2 = self.reshape(x2, self.reshape_shape_reg) + + rpn_cls_score = rpn_cls_score + (x1,) + rpn_bbox_pred = rpn_bbox_pred + (x2,) + + loss = self.loss + clsloss = self.clsloss + regloss = self.regloss + bbox_targets = () + bbox_weights = () + labels = () + label_weights = () + + output = () + if self.training: + for i in range(self.batch_size): + multi_level_flags = () + anchor_list_tuple = () + + for j in range(self.num_layers): + res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])), + mstype.int32) + multi_level_flags = multi_level_flags + (res,) + anchor_list_tuple = anchor_list_tuple + (anchor_list[j],) + + valid_flag_list = self.concat(multi_level_flags) + anchor_using_list = self.concat(anchor_list_tuple) + + gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) + gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) + gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) + + bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i, + gt_labels_i, + self.cast(valid_flag_list, + mstype.bool_), + anchor_using_list, gt_valids_i) + + bbox_weight = self.cast(bbox_weight, mstype.float16) + label = self.cast(label, mstype.float16) + label_weight = self.cast(label_weight, mstype.float16) + + for j in range(self.num_layers): + begin = self.slice_index[j] + end = self.slice_index[j + 1] + stride = 1 + bbox_targets += (bbox_target[begin:end:stride, ::],) + bbox_weights += (bbox_weight[begin:end:stride],) + labels += (label[begin:end:stride],) + label_weights += (label_weight[begin:end:stride],) + + for i in range(self.num_layers): + bbox_target_using = () + bbox_weight_using = () + label_using = () + label_weight_using = () + + for j in range(self.batch_size): + bbox_target_using += (bbox_targets[i + (self.num_layers * j)],) + bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],) + label_using += (labels[i + (self.num_layers * j)],) + label_weight_using += (label_weights[i + (self.num_layers * j)],) + + bbox_target_with_batchsize = self.concat(bbox_target_using) + bbox_weight_with_batchsize = self.concat(bbox_weight_using) + label_with_batchsize = self.concat(label_using) + label_weight_with_batchsize = self.concat(label_weight_using) + + # stop + bbox_target_ = F.stop_gradient(bbox_target_with_batchsize) + bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize) + label_ = F.stop_gradient(label_with_batchsize) + label_weight_ = F.stop_gradient(label_weight_with_batchsize) + + cls_score_i = rpn_cls_score[i] + reg_score_i = rpn_bbox_pred[i] + + loss_cls = self.loss_cls(cls_score_i, label_) + loss_cls_item = loss_cls * label_weight_ + loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total + + loss_reg = self.loss_bbox(reg_score_i, bbox_target_) + bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4)) + loss_reg = loss_reg * bbox_weight_ + loss_reg_item = self.sum_loss(loss_reg, (1,)) + loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total + + loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item + + loss += loss_total + loss_print += (loss_total, loss_cls_item, loss_reg_item) + clsloss += loss_cls_item + regloss += loss_reg_item + + output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print) + else: + output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1) + + return output diff --git a/model_zoo/official/cv/maskrcnn/src/config.py b/model_zoo/official/cv/maskrcnn/src/config.py new file mode 100644 index 0000000000..97028c782e --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/config.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#" :=========================================================================== +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed + +config = ed({ + "img_width": 1280, + "img_height": 768, + "keep_ratio": False, + "flip_ratio": 0.5, + "photo_ratio": 0.5, + "expand_ratio": 1.0, + + "max_instance_count": 128, + "mask_shape": (28, 28), + + # anchor + "feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)], + "anchor_scales": [8], + "anchor_ratios": [0.5, 1.0, 2.0], + "anchor_strides": [4, 8, 16, 32, 64], + "num_anchors": 3, + + # resnet + "resnet_block": [3, 4, 6, 3], + "resnet_in_channels": [64, 256, 512, 1024], + "resnet_out_channels": [256, 512, 1024, 2048], + + # fpn + "fpn_in_channels": [256, 512, 1024, 2048], + "fpn_out_channels": 256, + "fpn_num_outs": 5, + + # rpn + "rpn_in_channels": 256, + "rpn_feat_channels": 256, + "rpn_loss_cls_weight": 1.0, + "rpn_loss_reg_weight": 1.0, + "rpn_cls_out_channels": 1, + "rpn_target_means": [0., 0., 0., 0.], + "rpn_target_stds": [1.0, 1.0, 1.0, 1.0], + + # bbox_assign_sampler + "neg_iou_thr": 0.3, + "pos_iou_thr": 0.7, + "min_pos_iou": 0.3, + "num_bboxes": 245520, + "num_gts": 128, + "num_expected_neg": 256, + "num_expected_pos": 128, + + # proposal + "activate_num_classes": 2, + "use_sigmoid_cls": True, + + # roi_align + "roi_layer": dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2), + "roi_align_out_channels": 256, + "roi_align_featmap_strides": [4, 8, 16, 32], + "roi_align_finest_scale": 56, + "roi_sample_num": 640, + + # bbox_assign_sampler_stage2 + "neg_iou_thr_stage2": 0.5, + "pos_iou_thr_stage2": 0.5, + "min_pos_iou_stage2": 0.5, + "num_bboxes_stage2": 2000, + "num_expected_pos_stage2": 128, + "num_expected_neg_stage2": 512, + "num_expected_total_stage2": 512, + + # rcnn + "rcnn_num_layers": 2, + "rcnn_in_channels": 256, + "rcnn_fc_out_channels": 1024, + "rcnn_mask_out_channels": 256, + "rcnn_loss_cls_weight": 1, + "rcnn_loss_reg_weight": 1, + "rcnn_loss_mask_fb_weight": 1, + "rcnn_target_means": [0., 0., 0., 0.], + "rcnn_target_stds": [0.1, 0.1, 0.2, 0.2], + + # train proposal + "rpn_proposal_nms_across_levels": False, + "rpn_proposal_nms_pre": 2000, + "rpn_proposal_nms_post": 2000, + "rpn_proposal_max_num": 2000, + "rpn_proposal_nms_thr": 0.7, + "rpn_proposal_min_bbox_size": 0, + + # test proposal + "rpn_nms_across_levels": False, + "rpn_nms_pre": 1000, + "rpn_nms_post": 1000, + "rpn_max_num": 1000, + "rpn_nms_thr": 0.7, + "rpn_min_bbox_min_size": 0, + "test_score_thr": 0.05, + "test_iou_thr": 0.5, + "test_max_per_img": 100, + "test_batch_size": 2, + + "rpn_head_loss_type": "CrossEntropyLoss", + "rpn_head_use_sigmoid": True, + "rpn_head_weight": 1.0, + "mask_thr_binary": 0.5, + + # LR + "base_lr": 0.02, + "base_step": 58633, + "total_epoch": 13, + "warmup_step": 500, + "warmup_mode": "linear", + "warmup_ratio": 1/3.0, + "sgd_step": [8, 11], + "sgd_momentum": 0.9, + + # train + "batch_size": 2, + "loss_scale": 1, + "momentum": 0.91, + "weight_decay": 1e-4, + "pretrain_epoch_size": 0, + "epoch_size": 12, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 10, + "save_checkpoint_path": "./checkpoint", + + "mindrecord_dir": "/home/mxw/mask_rcnn/scripts/MindRecord_COCO2017_Train", + "coco_root": "/home/mxw/coco2017/", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instance_set": "annotations/instances_{}.json", + "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81 +}) diff --git a/model_zoo/official/cv/maskrcnn/src/dataset.py b/model_zoo/official/cv/maskrcnn/src/dataset.py new file mode 100644 index 0000000000..e0bbbba70f --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/dataset.py @@ -0,0 +1,522 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""MaskRcnn dataset""" +from __future__ import division + +import os +import numpy as np +from numpy import random + +import mmcv +import mindspore.dataset as de +import mindspore.dataset.transforms.vision.c_transforms as C +from mindspore.mindrecord import FileWriter +from src.config import config +import cv2 + +def bbox_overlaps(bboxes1, bboxes2, mode='iou'): + """Calculate the ious between each bbox of bboxes1 and bboxes2. + + Args: + bboxes1(ndarray): shape (n, 4) + bboxes2(ndarray): shape (k, 4) + mode(str): iou (intersection over union) or iof (intersection + over foreground) + + Returns: + ious(ndarray): shape (n, k) + """ + + assert mode in ['iou', 'iof'] + + bboxes1 = bboxes1.astype(np.float32) + bboxes2 = bboxes2.astype(np.float32) + rows = bboxes1.shape[0] + cols = bboxes2.shape[0] + ious = np.zeros((rows, cols), dtype=np.float32) + if rows * cols == 0: + return ious + exchange = False + if bboxes1.shape[0] > bboxes2.shape[0]: + bboxes1, bboxes2 = bboxes2, bboxes1 + ious = np.zeros((cols, rows), dtype=np.float32) + exchange = True + area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1) + area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1) + for i in range(bboxes1.shape[0]): + x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) + y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) + x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) + y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) + overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( + y_end - y_start + 1, 0) + if mode == 'iou': + union = area1[i] + area2 - overlap + else: + union = area1[i] if not exchange else area2 + ious[i, :] = overlap / union + if exchange: + ious = ious.T + return ious + +class PhotoMetricDistortion: + """Photo Metric Distortion""" + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def __call__(self, img, boxes, labels): + # random brightness + img = img.astype('float32') + + if random.randint(2): + delta = random.uniform(-self.brightness_delta, + self.brightness_delta) + img += delta + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + + # convert color from BGR to HSV + img = mmcv.bgr2hsv(img) + + # random saturation + if random.randint(2): + img[..., 1] *= random.uniform(self.saturation_lower, + self.saturation_upper) + + # random hue + if random.randint(2): + img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + + # convert color from HSV to BGR + img = mmcv.hsv2bgr(img) + + # random contrast + if mode == 0: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + + # randomly swap channels + if random.randint(2): + img = img[..., random.permutation(3)] + + return img, boxes, labels + +class Expand: + """expand image""" + def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): + if to_rgb: + self.mean = mean[::-1] + else: + self.mean = mean + self.min_ratio, self.max_ratio = ratio_range + + def __call__(self, img, boxes, labels, mask): + if random.randint(2): + return img, boxes, labels, mask + + h, w, c = img.shape + ratio = random.uniform(self.min_ratio, self.max_ratio) + expand_img = np.full((int(h * ratio), int(w * ratio), c), + self.mean).astype(img.dtype) + left = int(random.uniform(0, w * ratio - w)) + top = int(random.uniform(0, h * ratio - h)) + expand_img[top:top + h, left:left + w] = img + img = expand_img + boxes += np.tile((left, top), 2) + + mask_count, mask_h, mask_w = mask.shape + expand_mask = np.zeros((mask_count, int(mask_h * ratio), int(mask_w * ratio))).astype(mask.dtype) + expand_mask[:, top:top + h, left:left + w] = mask + mask = expand_mask + + return img, boxes, labels, mask + +def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """rescale operation for image""" + img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True) + if img_data.shape[0] > config.img_height: + img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_height), return_scale=True) + scale_factor = scale_factor*scale_factor2 + + gt_bboxes = gt_bboxes * scale_factor + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_data.shape[1] - 1) + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_data.shape[0] - 1) + + gt_mask_data = np.array([ + mmcv.imrescale(mask, scale_factor, interpolation='nearest') + for mask in gt_mask + ]) + + pad_h = config.img_height - img_data.shape[0] + pad_w = config.img_width - img_data.shape[1] + assert ((pad_h >= 0) and (pad_w >= 0)) + + pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype) + pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data + + mask_count, mask_h, mask_w = gt_mask_data.shape + pad_mask = np.zeros((mask_count, config.img_height, config.img_width)).astype(gt_mask_data.dtype) + pad_mask[:, 0:mask_h, 0:mask_w] = gt_mask_data + + img_shape = (config.img_height, config.img_width, 1.0) + img_shape = np.asarray(img_shape, dtype=np.float32) + + return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, pad_mask) + +def rescale_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """rescale operation for image of eval""" + img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True) + if img_data.shape[0] > config.img_height: + img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_height), return_scale=True) + scale_factor = scale_factor*scale_factor2 + + pad_h = config.img_height - img_data.shape[0] + pad_w = config.img_width - img_data.shape[1] + assert ((pad_h >= 0) and (pad_w >= 0)) + + pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype) + pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data + + img_shape = np.append(img_shape, (scale_factor, scale_factor)) + img_shape = np.asarray(img_shape, dtype=np.float32) + + return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + +def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """resize operation for image""" + img_data = img + img_data, w_scale, h_scale = mmcv.imresize( + img_data, (config.img_width, config.img_height), return_scale=True) + scale_factor = np.array( + [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + img_shape = (config.img_height, config.img_width, 1.0) + img_shape = np.asarray(img_shape, dtype=np.float32) + + gt_bboxes = gt_bboxes * scale_factor + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) # x1, x2 [0, W-1] + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) # y1, y2 [0, H-1] + + gt_mask_data = np.array([ + mmcv.imresize(mask, (config.img_width, config.img_height), interpolation='nearest') + for mask in gt_mask + ]) + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data) + +def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """resize operation for image of eval""" + img_data = img + img_data, w_scale, h_scale = mmcv.imresize( + img_data, (config.img_width, config.img_height), return_scale=True) + img_shape = np.append(img_shape, (h_scale, w_scale)) + img_shape = np.asarray(img_shape, dtype=np.float32) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + +def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """impad operation for image""" + img_data = mmcv.impad(img, (config.img_height, config.img_width)) + img_data = img_data.astype(np.float32) + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + +def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """imnormalize operation for image""" + img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True) + img_data = img_data.astype(np.float32) + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + +def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """flip operation for image""" + img_data = img + img_data = mmcv.imflip(img_data) + flipped = gt_bboxes.copy() + _, w, _ = img_data.shape + + flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 # x1 = W-x2-1 + flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 # x2 = W-x1-1 + + gt_mask_data = np.array([mask[:, ::-1] for mask in gt_mask]) + + return (img_data, img_shape, flipped, gt_label, gt_num, gt_mask_data) + +def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """transpose operation for image""" + img_data = img.transpose(2, 0, 1).copy() + img_data = img_data.astype(np.float16) + img_shape = img_shape.astype(np.float16) + gt_bboxes = gt_bboxes.astype(np.float16) + gt_label = gt_label.astype(np.int32) + gt_num = gt_num.astype(np.bool) + gt_mask_data = gt_mask.astype(np.bool) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data) + +def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """photo crop operation for image""" + random_photo = PhotoMetricDistortion() + img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + +def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """expand operation for image""" + expand = Expand() + img, gt_bboxes, gt_label, gt_mask = expand(img, gt_bboxes, gt_label, gt_mask) + + return (img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + +def pad_to_max(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask, instance_count): + pad_max_number = config.max_instance_count + gt_box_new = np.pad(gt_bboxes, ((0, pad_max_number - instance_count), (0, 0)), mode="constant", constant_values=0) + gt_label_new = np.pad(gt_label, ((0, pad_max_number - instance_count)), mode="constant", constant_values=-1) + gt_iscrowd_new = np.pad(gt_num, ((0, pad_max_number - instance_count)), mode="constant", constant_values=1) + gt_iscrowd_new_revert = ~(gt_iscrowd_new.astype(np.bool)) + gt_mask_new = np.pad(gt_mask, ((0, pad_max_number - instance_count), (0, 0), (0, 0)), mode="constant", + constant_values=0) + + return img, img_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask_new + +def preprocess_fn(image, box, mask, mask_shape, is_training): + """Preprocess function for dataset.""" + def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, + gt_mask_new, instance_count): + image_shape = image_shape[:2] + input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask_new + + if config.keep_ratio: + input_data = rescale_column_test(*input_data) + else: + input_data = resize_column_test(*input_data) + input_data = imnormalize_column(*input_data) + + input_data = pad_to_max(*input_data, instance_count) + output_data = transpose_column(*input_data) + return output_data + + def _data_aug(image, box, mask, mask_shape, is_training): + """Data augmentation function.""" + image_bgr = image.copy() + image_bgr[:, :, 0] = image[:, :, 2] + image_bgr[:, :, 1] = image[:, :, 1] + image_bgr[:, :, 2] = image[:, :, 0] + image_shape = image_bgr.shape[:2] + instance_count = box.shape[0] + gt_box = box[:, :4] + gt_label = box[:, 4] + gt_iscrowd = box[:, 5] + gt_mask = mask.copy() + n, h, w = mask_shape + gt_mask = gt_mask.reshape(n, h, w) + assert n == box.shape[0] + + if not is_training: + return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_iscrowd, gt_mask, instance_count) + + flip = (np.random.rand() < config.flip_ratio) + expand = (np.random.rand() < config.expand_ratio) + + input_data = image_bgr, image_shape, gt_box, gt_label, gt_iscrowd, gt_mask + + if expand: + input_data = expand_column(*input_data) + if config.keep_ratio: + input_data = rescale_column(*input_data) + else: + input_data = resize_column(*input_data) + + input_data = imnormalize_column(*input_data) + if flip: + input_data = flip_column(*input_data) + + input_data = pad_to_max(*input_data, instance_count) + output_data = transpose_column(*input_data) + return output_data + + return _data_aug(image, box, mask, mask_shape, is_training) + +def annToMask(ann, height, width): + """Convert annotation to RLE and then to binary mask.""" + from pycocotools import mask as maskHelper + segm = ann['segmentation'] + if isinstance(segm, list): + rles = maskHelper.frPyObjects(segm, height, width) + rle = maskHelper.merge(rles) + elif isinstance(segm['counts'], list): + rle = maskHelper.frPyObjects(segm, height, width) + else: + rle = ann['segmentation'] + m = maskHelper.decode(rle) + return m + +def create_coco_label(is_training): + """Get image path and annotation from COCO.""" + from pycocotools.coco import COCO + + coco_root = config.coco_root + data_type = config.val_data_type + if is_training: + data_type = config.train_data_type + + #Classes need to train or test. + train_cls = config.coco_classes + train_cls_dict = {} + for i, cls in enumerate(train_cls): + train_cls_dict[cls] = i + + anno_json = os.path.join(coco_root, config.instance_set.format(data_type)) + + coco = COCO(anno_json) + classs_dict = {} + cat_ids = coco.loadCats(coco.getCatIds()) + for cat in cat_ids: + classs_dict[cat["id"]] = cat["name"] + + image_ids = coco.getImgIds() + image_files = [] + image_anno_dict = {} + masks = {} + masks_shape = {} + for img_id in image_ids: + image_info = coco.loadImgs(img_id) + file_name = image_info[0]["file_name"] + anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = coco.loadAnns(anno_ids) + image_path = os.path.join(coco_root, data_type, file_name) + annos = [] + instance_masks = [] + image_height = coco.imgs[img_id]["height"] + image_width = coco.imgs[img_id]["width"] + print("image file name: ", file_name) + if not is_training: + image_files.append(image_path) + image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) + masks[image_path] = np.zeros([1, 1, 1], dtype=np.bool).tobytes() + masks_shape[image_path] = np.array([1, 1, 1], dtype=np.int32) + else: + for label in anno: + bbox = label["bbox"] + class_name = classs_dict[label["category_id"]] + if class_name in train_cls: + # get coco mask + m = annToMask(label, image_height, image_width) + if m.max() < 1: + print("all black mask!!!!") + continue + # Resize mask for the crowd + if label['iscrowd'] and (m.shape[0] != image_height or m.shape[1] != image_width): + m = np.ones([image_height, image_width], dtype=np.bool) + instance_masks.append(m) + + # get coco bbox + x1, x2 = bbox[0], bbox[0] + bbox[2] + y1, y2 = bbox[1], bbox[1] + bbox[3] + annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])]) + else: + print("not in classes: ", class_name) + + image_files.append(image_path) + if annos: + image_anno_dict[image_path] = np.array(annos) + instance_masks = np.stack(instance_masks, axis=0).astype(np.bool) + masks[image_path] = np.array(instance_masks).tobytes() + masks_shape[image_path] = np.array(instance_masks.shape, dtype=np.int32) + else: + print("no annotations for image ", file_name) + image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) + masks[image_path] = np.zeros([1, image_height, image_width], dtype=np.bool).tobytes() + masks_shape[image_path] = np.array([1, image_height, image_width], dtype=np.int32) + + return image_files, image_anno_dict, masks, masks_shape + +def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="maskrcnn.mindrecord", file_num=8): + """Create MindRecord file.""" + mindrecord_dir = config.mindrecord_dir + mindrecord_path = os.path.join(mindrecord_dir, prefix) + + writer = FileWriter(mindrecord_path, file_num) + if dataset == "coco": + image_files, image_anno_dict, masks, masks_shape = create_coco_label(is_training) + else: + print("Error unsupport other dataset") + return + + maskrcnn_json = { + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 6]}, + "mask": {"type": "bytes"}, + "mask_shape": {"type": "int32", "shape": [-1]}, + } + writer.add_schema(maskrcnn_json, "maskrcnn_json") + + for image_name in image_files: + with open(image_name, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[image_name], dtype=np.int32) + mask = masks[image_name] + mask_shape = masks_shape[image_name] + row = {"image": img, "annotation": annos, "mask": mask, "mask_shape": mask_shape} + writer.write_raw_data([row]) + writer.commit() + +def create_maskrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0, + is_training=True, num_parallel_workers=8): + """Create MaskRcnn dataset with MindDataset.""" + cv2.setNumThreads(0) + de.config.set_prefetch_size(8) + ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation", "mask", "mask_shape"], + num_shards=device_num, shard_id=rank_id, + num_parallel_workers=4, shuffle=is_training) + + decode = C.Decode() + ds = ds.map(input_columns=["image"], operations=decode) + compose_map_func = (lambda image, annotation, mask, mask_shape: + preprocess_fn(image, annotation, mask, mask_shape, is_training)) + + if is_training: + ds = ds.map(input_columns=["image", "annotation", "mask", "mask_shape"], + output_columns=["image", "image_shape", "box", "label", "valid_num", "mask"], + columns_order=["image", "image_shape", "box", "label", "valid_num", "mask"], + operations=compose_map_func, + python_multiprocessing=False, + num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + + else: + ds = ds.map(input_columns=["image", "annotation", "mask", "mask_shape"], + output_columns=["image", "image_shape", "box", "label", "valid_num", "mask"], + columns_order=["image", "image_shape", "box", "label", "valid_num", "mask"], + operations=compose_map_func, + num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + + return ds diff --git a/model_zoo/official/cv/maskrcnn/src/lr_schedule.py b/model_zoo/official/cv/maskrcnn/src/lr_schedule.py new file mode 100644 index 0000000000..dae7e2ea82 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/lr_schedule.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr generator for maskrcnn""" +import math + +def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + learning_rate = float(init_lr) + lr_inc * current_step + return learning_rate + +def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): + base = float(current_step - warmup_steps) / float(decay_steps) + learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr + return learning_rate + +def dynamic_lr(config, rank_size=1, start_steps=0): + """dynamic learning rate generator""" + base_lr = config.base_lr + + base_step = (config.base_step // rank_size) + rank_size + total_steps = int(base_step * config.total_epoch) + warmup_steps = int(config.warmup_step) + lr = [] + for i in range(total_steps): + if i < warmup_steps: + lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) + else: + lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) + learning_rate = lr[start_steps:] + return learning_rate diff --git a/model_zoo/official/cv/maskrcnn/src/network_define.py b/model_zoo/official/cv/maskrcnn/src/network_define.py new file mode 100644 index 0000000000..dc18da3956 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/network_define.py @@ -0,0 +1,193 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MaskRcnn training network wrapper.""" + +import time +import numpy as np +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore import ParameterTuple +from mindspore.train.callback import Callback +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer + +time_stamp_init = False +time_stamp_first = 0 + +class LossCallBack(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF terminating training. + + Note: + If per_print_times is 0 do not print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self.count = 0 + self.rpn_loss_sum = 0 + self.rcnn_loss_sum = 0 + self.rpn_cls_loss_sum = 0 + self.rpn_reg_loss_sum = 0 + self.rcnn_cls_loss_sum = 0 + self.rcnn_reg_loss_sum = 0 + self.rcnn_mask_loss_sum = 0 + + global time_stamp_init, time_stamp_first + if not time_stamp_init: + time_stamp_first = time.time() + time_stamp_init = True + + def step_end(self, run_context): + cb_params = run_context.original_args() + rpn_loss = cb_params.net_outputs[0].asnumpy() + rcnn_loss = cb_params.net_outputs[1].asnumpy() + rpn_cls_loss = cb_params.net_outputs[2].asnumpy() + + rpn_reg_loss = cb_params.net_outputs[3].asnumpy() + rcnn_cls_loss = cb_params.net_outputs[4].asnumpy() + rcnn_reg_loss = cb_params.net_outputs[5].asnumpy() + rcnn_mask_loss = cb_params.net_outputs[6].asnumpy() + + self.count += 1 + self.rpn_loss_sum += float(rpn_loss) + self.rcnn_loss_sum += float(rcnn_loss) + self.rpn_cls_loss_sum += float(rpn_cls_loss) + self.rpn_reg_loss_sum += float(rpn_reg_loss) + self.rcnn_cls_loss_sum += float(rcnn_cls_loss) + self.rcnn_reg_loss_sum += float(rcnn_reg_loss) + self.rcnn_mask_loss_sum += float(rcnn_mask_loss) + + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + + if self.count >= 1: + global time_stamp_first + time_stamp_current = time.time() + + rpn_loss = self.rpn_loss_sum/self.count + rcnn_loss = self.rcnn_loss_sum/self.count + rpn_cls_loss = self.rpn_cls_loss_sum/self.count + + rpn_reg_loss = self.rpn_reg_loss_sum/self.count + rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count + rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count + rcnn_mask_loss = self.rcnn_mask_loss_sum/self.count + + total_loss = rpn_loss + rcnn_loss + + loss_file = open("./loss.log", "a+") + loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, " + "rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, rcnn_mask_loss: %.5f, " + "total_loss: %.5f" % + (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, + rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, + rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_loss, total_loss)) + loss_file.write("\n") + loss_file.close() + + self.count = 0 + self.rpn_loss_sum = 0 + self.rcnn_loss_sum = 0 + self.rpn_cls_loss_sum = 0 + self.rpn_reg_loss_sum = 0 + self.rcnn_cls_loss_sum = 0 + self.rcnn_reg_loss_sum = 0 + self.rcnn_mask_loss_sum = 0 + +class LossNet(nn.Cell): + """MaskRcnn loss method""" + def __init__(self): + super(LossNet, self).__init__() + def construct(self, x1, x2, x3, x4, x5, x6, x7): + return x1 + x2 + +class WithLossCell(nn.Cell): + """ + Wrap the network with loss function to compute loss. + + Args: + backbone (Cell): The target network to wrap. + loss_fn (Cell): The loss function used to compute loss. + """ + def __init__(self, backbone, loss_fn): + super(WithLossCell, self).__init__(auto_prefix=False) + self._backbone = backbone + self._loss_fn = loss_fn + + def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask): + loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self._backbone(x, img_shape, gt_bboxe, gt_label, + gt_num, gt_mask) + return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6, loss7) + + @property + def backbone_network(self): + """ + Get the backbone network. + + Returns: + Cell, return backbone network. + """ + return self._backbone + + +class TrainOneStepCell(nn.Cell): + """ + Network training package class. + + Append an optimizer to the training network after that the construct function + can be called to create the backward graph. + + Args: + network (Cell): The training network. + network_backbone (Cell): The forward network. + optimizer (Cell): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default value is 1.0. + reduce_flag (bool): The reduce flag. Default value is False. + mean (bool): Allreduce method. Default value is False. + degree (int): Device number. Default value is None. + """ + def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.backbone = network_backbone + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) + self.reduce_flag = reduce_flag + self.hyper_map = C.HyperMap() + if reduce_flag: + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask): + weights = self.weights + loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self.backbone(x, img_shape, gt_bboxe, gt_label, + gt_num, gt_mask) + grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens) + if self.reduce_flag: + grads = self.grad_reducer(grads) + + return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6, loss7 diff --git a/model_zoo/official/cv/maskrcnn/src/util.py b/model_zoo/official/cv/maskrcnn/src/util.py new file mode 100644 index 0000000000..f7e8a30770 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/src/util.py @@ -0,0 +1,269 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""coco eval for maskrcnn""" +import json +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from pycocotools import mask as maskUtils +import mmcv + +from src.config import config + +_init_value = np.array(0.0) +summary_init = { + 'Precision/mAP': _init_value, + 'Precision/mAP@.50IOU': _init_value, + 'Precision/mAP@.75IOU': _init_value, + 'Precision/mAP (small)': _init_value, + 'Precision/mAP (medium)': _init_value, + 'Precision/mAP (large)': _init_value, + 'Recall/AR@1': _init_value, + 'Recall/AR@10': _init_value, + 'Recall/AR@100': _init_value, + 'Recall/AR@100 (small)': _init_value, + 'Recall/AR@100 (medium)': _init_value, + 'Recall/AR@100 (large)': _init_value, +} + + +def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False): + """coco eval for maskrcnn""" + anns = json.load(open(result_files['bbox'])) + if not anns: + return summary_init + if mmcv.is_str(coco): + coco = COCO(coco) + assert isinstance(coco, COCO) + + for res_type in result_types: + result_file = result_files[res_type] + assert result_file.endswith('.json') + + coco_dets = coco.loadRes(result_file) + gt_img_ids = coco.getImgIds() + det_img_ids = coco_dets.getImgIds() + iou_type = 'bbox' if res_type == 'proposal' else res_type + cocoEval = COCOeval(coco, coco_dets, iou_type) + if res_type == 'proposal': + cocoEval.params.useCats = 0 + cocoEval.params.maxDets = list(max_dets) + + tgt_ids = gt_img_ids if not single_result else det_img_ids + + if single_result: + res_dict = dict() + for id_i in tgt_ids: + cocoEval = COCOeval(coco, coco_dets, iou_type) + if res_type == 'proposal': + cocoEval.params.useCats = 0 + cocoEval.params.maxDets = list(max_dets) + + cocoEval.params.imgIds = [id_i] + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]}) + + cocoEval = COCOeval(coco, coco_dets, iou_type) + if res_type == 'proposal': + cocoEval.params.useCats = 0 + cocoEval.params.maxDets = list(max_dets) + + cocoEval.params.imgIds = tgt_ids + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + summary_metrics = { + 'Precision/mAP': cocoEval.stats[0], + 'Precision/mAP@.50IOU': cocoEval.stats[1], + 'Precision/mAP@.75IOU': cocoEval.stats[2], + 'Precision/mAP (small)': cocoEval.stats[3], + 'Precision/mAP (medium)': cocoEval.stats[4], + 'Precision/mAP (large)': cocoEval.stats[5], + 'Recall/AR@1': cocoEval.stats[6], + 'Recall/AR@10': cocoEval.stats[7], + 'Recall/AR@100': cocoEval.stats[8], + 'Recall/AR@100 (small)': cocoEval.stats[9], + 'Recall/AR@100 (medium)': cocoEval.stats[10], + 'Recall/AR@100 (large)': cocoEval.stats[11], + } + + return summary_metrics + + +def xyxy2xywh(bbox): + _bbox = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0] + 1, + _bbox[3] - _bbox[1] + 1, + ] + +def bbox2result_1image(bboxes, labels, num_classes): + """Convert detection results to a list of numpy arrays. + + Args: + bboxes (Tensor): shape (n, 5) + labels (Tensor): shape (n, ) + num_classes (int): class number, including background class + + Returns: + list(ndarray): bbox results of each class + """ + if bboxes.shape[0] == 0: + result = [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)] + else: + result = [bboxes[labels == i, :] for i in range(num_classes - 1)] + + return result + +def proposal2json(dataset, results): + """convert proposal to json mode""" + img_ids = dataset.getImgIds() + json_results = [] + dataset_len = dataset.get_dataset_size()*2 + for idx in range(dataset_len): + img_id = img_ids[idx] + bboxes = results[idx] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = 1 + json_results.append(data) + return json_results + +def det2json(dataset, results): + """convert det to json mode""" + cat_ids = dataset.getCatIds() + img_ids = dataset.getImgIds() + json_results = [] + dataset_len = len(img_ids) + for idx in range(dataset_len): + img_id = img_ids[idx] + if idx == len(results): break + result = results[idx] + for label, result_label in enumerate(result): + bboxes = result_label + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = cat_ids[label] + json_results.append(data) + return json_results + +def segm2json(dataset, results): + """convert segm to json mode""" + cat_ids = dataset.getCatIds() + img_ids = dataset.getImgIds() + bbox_json_results = [] + segm_json_results = [] + + dataset_len = len(img_ids) + assert dataset_len == len(results) + for idx in range(dataset_len): + img_id = img_ids[idx] + if idx == len(results): break + det, seg = results[idx] + for label, det_label in enumerate(det): + bboxes = det_label + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = cat_ids[label] + bbox_json_results.append(data) + + if len(seg) == 2: + segms = seg[0][label] + mask_score = seg[1][label] + else: + segms = seg[label] + mask_score = [bbox[4] for bbox in bboxes] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['score'] = float(mask_score[i]) + data['category_id'] = cat_ids[label] + segms[i]['counts'] = segms[i]['counts'].decode() + data['segmentation'] = segms[i] + segm_json_results.append(data) + return bbox_json_results, segm_json_results + +def results2json(dataset, results, out_file): + """convert result convert to json mode""" + result_files = dict() + if isinstance(results[0], list): + json_results = det2json(dataset, results) + result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') + result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') + mmcv.dump(json_results, result_files['bbox']) + elif isinstance(results[0], tuple): + json_results = segm2json(dataset, results) + result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') + result_files['segm'] = '{}.{}.json'.format(out_file, 'segm') + mmcv.dump(json_results[0], result_files['bbox']) + mmcv.dump(json_results[1], result_files['segm']) + elif isinstance(results[0], np.ndarray): + json_results = proposal2json(dataset, results) + result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal') + mmcv.dump(json_results, result_files['proposal']) + else: + raise TypeError('invalid type of results') + return result_files + +def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_classes): + """Get segmentation masks from mask_pred and bboxes""" + mask_pred = mask_pred.astype(np.float32) + + cls_segms = [[] for _ in range(num_classes - 1)] + bboxes = det_bboxes[:, :4] + labels = det_labels + 1 + + ori_shape = img_meta[:2].astype(np.int32) + scale_factor = img_meta[2:].astype(np.int32) + + if rescale: + img_h, img_w = ori_shape[:2] + else: + img_h = np.round(ori_shape[0] * scale_factor[0]).astype(np.int32) + img_w = np.round(ori_shape[1] * scale_factor[1]).astype(np.int32) + scale_factor = 1.0 + + for i in range(bboxes.shape[0]): + bbox = (bboxes[i, :] / 1.0).astype(np.int32) + label = labels[i] + w = max(bbox[2] - bbox[0] + 1, 1) + h = max(bbox[3] - bbox[1] + 1, 1) + w = min(w, img_w - bbox[0]) + h = min(h, img_h - bbox[1]) + mask_pred_ = mask_pred[i, :, :] + im_mask = np.zeros((img_h, img_w), dtype=np.uint8) + bbox_mask = mmcv.imresize(mask_pred_, (w, h)) + bbox_mask = (bbox_mask > config.mask_thr_binary).astype(np.uint8) + im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask + + rle = maskUtils.encode( + np.array(im_mask[:, :, np.newaxis], order='F'))[0] + cls_segms[label - 1].append(rle) + + return cls_segms diff --git a/model_zoo/official/cv/maskrcnn/train.py b/model_zoo/official/cv/maskrcnn/train.py new file mode 100644 index 0000000000..8df2fab00e --- /dev/null +++ b/model_zoo/official/cv/maskrcnn/train.py @@ -0,0 +1,140 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""train MaskRcnn and get checkpoint files.""" + +import os +import argparse +import random +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore import context, Tensor +from mindspore.communication.management import init +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train import Model, ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn import SGD +import mindspore.dataset.engine as de + +from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 +from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet +from src.config import config +from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset +from src.lr_schedule import dynamic_lr + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="MaskRcnn training") +parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " + "Mindrecord, default is false.") +parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") +parser.add_argument("--do_train", type=bool, default=True, help="Do train or not, default is true.") +parser.add_argument("--do_eval", type=bool, default=False, help="Do eval or not, default is false.") +parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") +parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.") +parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") +parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") +parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.") +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id) + +if __name__ == '__main__': + print("Start train for maskrcnn!") + if not args_opt.do_eval and args_opt.run_distribute: + rank = args_opt.rank_id + device_num = args_opt.device_num + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True, parameter_broadcast=True) + init() + else: + rank = 0 + device_num = 1 + + print("Start create dataset!") + + # It will generate mindrecord file in args_opt.mindrecord_dir, + # and the file name is MaskRcnn.mindrecord0, 1, ... file_num. + prefix = "MaskRcnn.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("coco", True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + else: + if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("other", True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("IMAGE_DIR or ANNO_PATH not exits.") + + if not args_opt.only_create_dataset: + loss_scale = float(config.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as MaskRcnn.mindrecord0. + dataset = create_maskrcnn_dataset(mindrecord_file, batch_size=config.batch_size, + device_num=device_num, rank_id=rank) + + dataset_size = dataset.get_dataset_size() + print("total images num: ", dataset_size) + print("Create dataset done!") + + net = Mask_Rcnn_Resnet50(config=config) + net = net.set_train() + + load_path = args_opt.pre_trained + if load_path != "": + param_dict = load_checkpoint(load_path) + if config.pretrain_epoch_size == 0: + for item in list(param_dict.keys()): + if not (item.startswith('backbone') or item.startswith('rcnn_mask')): + param_dict.pop(item) + load_param_into_net(net, param_dict) + + loss = LossNet() + lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size), + mstype.float32) + opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, + weight_decay=config.weight_decay, loss_scale=config.loss_scale) + + net_with_loss = WithLossCell(net, loss) + if args_opt.run_distribute: + net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, + mean=True, degree=device_num) + else: + net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) + + time_cb = TimeMonitor(data_size=dataset_size) + loss_cb = LossCallBack() + cb = [time_cb, loss_cb] + if config.save_checkpoint: + ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=config.save_checkpoint_path, config=ckptconfig) + cb += [ckpoint_cb] + + model = Model(net) + model.train(config.epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/official/cv/mobilenetv2/Readme.md b/model_zoo/official/cv/mobilenetv2/Readme.md new file mode 100644 index 0000000000..b39013b07e --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/Readme.md @@ -0,0 +1,151 @@ +# MobileNetV2 Description + + +MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019. + +[Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. + +# Model architecture + +The overall network architecture of MobileNetV2 is show below: + +[Link](https://arxiv.org/pdf/1905.02244) + +# Dataset + +Dataset used: [imagenet](http://www.image-net.org/) + +- Dataset size: ~125G, 1.2W colorful images in 1000 classes + - Train: 120G, 1.2W images + - Test: 5G, 50000 images +- Data format: RGB images. + - Note: Data will be processed in src/dataset.py + + +# Features + + +# Environment Requirements + +- Hardware(Ascend/GPU) + - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Framework + - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + + +# Script description + +## Script and sample code + +```python +├── MobileNetV2 + ├── Readme.md + ├── scripts + │ ├──run_train.sh + │ ├──run_eval.sh + ├── src + │ ├──config.py + │ ├──dataset.py + │ ├──luanch.py + │ ├──lr_generator.py + │ ├──mobilenetV2.py + ├── train.py + ├── eval.py +``` + +## Training process + +### Usage + +- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] +- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] + +### Launch + +``` +# training example + Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt + GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ +``` + +### Result + +Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. + +``` +epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] +epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 +epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] +epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 +``` + +## Eval process + +### Usage + +- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] + +### Launch + +``` +# infer example + Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt + GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt +``` + +> checkpoint can be produced in training process. + +### Result + +Inference result will be stored in the example path, you can find result like the followings in `val.log`. + +``` +result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt +``` + +# Model description + +## Performance + +### Training Performance + +| Parameters | MobilenetV2 | | +| -------------------------- | ---------------------------------------------------------- | ------------------------- | +| Model Version | | large | +| Resource | Ascend 910, cpu:2.60GHz 56cores, memory:314G | NV SMX2 V100-32G | +| uploaded Date | 05/06/2020 | 05/06/2020 | +| MindSpore Version | 0.3.0 | 0.3.0 | +| Dataset | ImageNet | ImageNet | +| Training Parameters | src/config.py | src/config.py | +| Optimizer | Momentum | Momentum | +| Loss Function | SoftmaxCrossEntropy | SoftmaxCrossEntropy | +| outputs | | | +| Loss | | 1.913 | +| Accuracy | | ACC1[77.09%] ACC5[92.57%] | +| Total time | | | +| Params (M) | | | +| Checkpoint for Fine tuning | | | +| Model for inference | | | + +#### Inference Performance + +| Parameters | | | | +| -------------------------- | ----------------------------- | ------------------------- | -------------------- | +| Model Version | V1 | | | +| Resource | Huawei 910 | NV SMX2 V100-32G | Huawei 310 | +| uploaded Date | 05/06/2020 | 05/22/2020 | | +| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 | +| Dataset | ImageNet, 1.2W | ImageNet, 1.2W | ImageNet, 1.2W | +| batch_size | | 130(8P) | | +| outputs | | | | +| Accuracy | | ACC1[72.07%] ACC5[90.90%] | | +| Speed | | | | +| Total time | | | | +| Model for inference | | | | + +# ModelZoo Homepage + [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) \ No newline at end of file diff --git a/model_zoo/mobilenetv2/eval.py b/model_zoo/official/cv/mobilenetv2/eval.py similarity index 100% rename from model_zoo/mobilenetv2/eval.py rename to model_zoo/official/cv/mobilenetv2/eval.py diff --git a/model_zoo/mobilenetv2/scripts/run_infer.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh similarity index 100% rename from model_zoo/mobilenetv2/scripts/run_infer.sh rename to model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh new file mode 100644 index 0000000000..c260aac787 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +run_ascend() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-8)" + exit 1 + fi + + if [ ! -d $5 ] && [ ! -f $5 ] + then + echo "error: DATASET_PATH=$5 is not a directory or file" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + export RANK_TABLE_FILE=$4 + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + python ${BASEPATH}/../src/launch.py \ + --nproc_per_node=$2 \ + --visible_devices=$3 \ + --training_script=${BASEPATH}/../train.py \ + --dataset_path=$5 \ + --pre_trained=$6 \ + --platform=$1 &> ../train.log & # dataset train folder +} + +run_gpu() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-8)" + exit 1 + fi + + if [ ! -d $4 ] + then + echo "error: DATASET_PATH=$4 is not a directory" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + + export CUDA_VISIBLE_DEVICES="$3" + mpirun -n $2 --allow-run-as-root \ + python ${BASEPATH}/../train.py \ + --dataset_path=$4 \ + --platform=$1 \ + &> ../train.log & # dataset train folder +} + +if [ $# -gt 6 ] || [ $# -lt 4 ] +then + echo "Usage:\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ + " +exit 1 +fi + +if [ $1 = "Ascend" ] ; then + run_ascend "$@" +elif [ $1 = "GPU" ] ; then + run_gpu "$@" +else + echo "Unsupported platform." +fi; + diff --git a/model_zoo/mobilenetv2/src/config.py b/model_zoo/official/cv/mobilenetv2/src/config.py similarity index 100% rename from model_zoo/mobilenetv2/src/config.py rename to model_zoo/official/cv/mobilenetv2/src/config.py diff --git a/model_zoo/mobilenetv2/src/dataset.py b/model_zoo/official/cv/mobilenetv2/src/dataset.py similarity index 100% rename from model_zoo/mobilenetv2/src/dataset.py rename to model_zoo/official/cv/mobilenetv2/src/dataset.py diff --git a/model_zoo/mobilenetv2/src/launch.py b/model_zoo/official/cv/mobilenetv2/src/launch.py similarity index 100% rename from model_zoo/mobilenetv2/src/launch.py rename to model_zoo/official/cv/mobilenetv2/src/launch.py diff --git a/model_zoo/mobilenetv2/src/lr_generator.py b/model_zoo/official/cv/mobilenetv2/src/lr_generator.py similarity index 100% rename from model_zoo/mobilenetv2/src/lr_generator.py rename to model_zoo/official/cv/mobilenetv2/src/lr_generator.py diff --git a/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py new file mode 100644 index 0000000000..76fa21acdf --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py @@ -0,0 +1,292 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MobileNetV2 model define""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops.operations import TensorAdd +from mindspore import Parameter, Tensor +from mindspore.common.initializer import initializer + +__all__ = ['mobilenet_v2'] + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class DepthwiseConv(nn.Cell): + """ + Depthwise Convolution warpper definition. + + Args: + in_planes (int): Input channel. + kernel_size (int): Input kernel size. + stride (int): Stride size. + pad_mode (str): pad mode in (pad, same, valid) + channel_multiplier (int): Output channel multiplier + has_bias (bool): has bias or not + + Returns: + Tensor, output tensor. + + Examples: + >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) + """ + + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): + super(DepthwiseConv, self).__init__() + self.has_bias = has_bias + self.in_channels = in_planes + self.channel_multiplier = channel_multiplier + self.out_channels = in_planes * channel_multiplier + self.kernel_size = (kernel_size, kernel_size) + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, + kernel_size=self.kernel_size, + stride=stride, pad_mode=pad_mode, pad=pad) + self.bias_add = P.BiasAdd() + weight_shape = [channel_multiplier, in_planes, *self.kernel_size] + self.weight = Parameter(initializer('ones', weight_shape), name='weight') + + if has_bias: + bias_shape = [channel_multiplier * in_planes] + self.bias = Parameter(initializer('zeros', bias_shape), name='bias') + else: + self.bias = None + + def construct(self, x): + output = self.depthwise_conv(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + return output + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) + else: + if platform == "Ascend": + conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) + elif platform == "GPU": + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, + group=in_planes, pad_mode='pad', padding=padding) + + layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + + def __init__(self, platform, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(platform, hidden_dim, hidden_dim, + stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, kernel_size=1, + stride=1, has_bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.SequentialCell(layers) + self.add = TensorAdd() + self.cast = P.Cast() + + def construct(self, x): + identity = x + x = self.conv(x) + if self.use_res_connect: + return self.add(identity, x) + return x + + +class MobileNetV2(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (Cell): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(num_classes=1000) + """ + + def __init__(self, platform, num_classes=1000, width_mult=1., + has_dropout=False, inverted_residual_setting=None, round_nearest=8): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + # setting of inverted residual blocks + self.cfgs = inverted_residual_setting + if inverted_residual_setting is None: + self.cfgs = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(platform, 3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1)) + # make it nn.CellList + self.features = nn.SequentialCell(features) + # mobilenet head + head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else + [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) + self.head = nn.SequentialCell(head) + + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + x = self.head(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, (nn.Conv2d, DepthwiseConv)): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_parameter_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal( + 0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + +def mobilenet_v2(**kwargs): + """ + Constructs a MobileNet V2 model + """ + return MobileNetV2(**kwargs) diff --git a/model_zoo/official/cv/mobilenetv2/src/mobilenetV2_fusion.py b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2_fusion.py new file mode 100644 index 0000000000..715231d8fc --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2_fusion.py @@ -0,0 +1,239 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# """MobileNetV2 Quant model define""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore import Tensor + +__all__ = ['mobilenetV2'] + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10 %. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, + stride=stride, + pad_mode='pad', + padding=padding, + group=groups, + has_bn=True, + activation='relu') + + def construct(self, x): + x = self.conv(x) + return x + + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) + ]) + self.conv = nn.SequentialCell(layers) + self.add = P.TensorAdd() + + def construct(self, x): + out = self.conv(x) + if self.use_res_connect: + out = self.add(out, x) + return out + + +class mobilenetV2(nn.Cell): + """ + mobilenetV2 fusion architecture. + + Args: + class_num (Cell): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> mobilenetV2(num_classes=1000) + """ + + def __init__(self, num_classes=1000, width_mult=1., + has_dropout=False, inverted_residual_setting=None, round_nearest=8): + super(mobilenetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + # setting of inverted residual blocks + self.cfgs = inverted_residual_setting + if inverted_residual_setting is None: + self.cfgs = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + + features = [ConvBNReLU(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1)) + # make it nn.CellList + self.features = nn.SequentialCell(features) + # mobilenet head + head = ([GlobalAvgPooling(), + nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False) + ] if not has_dropout else + [GlobalAvgPooling(), + nn.Dropout(0.2), + nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False) + ]) + self.head = nn.SequentialCell(head) + + # init weights + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + x = self.head(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32")) + m.weight.set_parameter_data(w) + if m.bias is not None: + m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.Conv2dBnAct): + n = m.conv.kernel_size[0] * m.conv.kernel_size[1] * m.conv.out_channels + w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.conv.weight.data.shape).astype("float32")) + m.conv.weight.set_parameter_data(w) + if m.conv.bias is not None: + m.conv.bias.set_parameter_data(Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.DenseBnAct): + m.dense.weight.set_parameter_data( + Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32"))) + if m.dense.bias is not None: + m.dense.bias.set_parameter_data(Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32"))) diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py new file mode 100644 index 0000000000..8862937a8e --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -0,0 +1,279 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Train mobilenetV2 on ImageNet.""" + +import os +import time +import argparse +import random +import numpy as np + +from mindspore import context +from mindspore import Tensor +from mindspore import nn +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.nn.optim.momentum import Momentum +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype +from mindspore.train.model import Model, ParallelMode +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.communication.management import init, get_group_size, get_rank +import mindspore.dataset.engine as de + +from src.dataset import create_dataset +from src.lr_generator import get_lr +from src.config import config_gpu, config_ascend +from src.mobilenetV2 import mobilenet_v2 + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') +parser.add_argument('--platform', type=str, default=None, help='run platform') +args_opt = parser.parse_args() + +if args_opt.platform == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) + run_distribute = rank_size > 1 + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + device_id=device_id, save_graphs=False) +elif args_opt.platform == "GPU": + context.set_context(mode=context.GRAPH_MODE, + device_target="GPU", + save_graphs=False) + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) +else: + raise ValueError("Unsupported device target.") + + +class CrossEntropyWithLabelSmooth(_Loss): + """ + CrossEntropyWith LabelSmooth. + + Args: + smooth_factor (float): smooth factor, default=0. + num_classes (int): num classes + + Returns: + None. + + Examples: + >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000) + """ + + def __init__(self, smooth_factor=0., num_classes=1000): + super(CrossEntropyWithLabelSmooth, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / + (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + self.cast = P.Cast() + + def construct(self, logit, label): + one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], + self.on_value, self.off_value) + out_loss = self.ce(logit, one_hot_label) + out_loss = self.mean(out_loss, 0) + return out_loss + + +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + + print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( + cb_params.cur_epoch_num - + 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) + + +if __name__ == '__main__': + if args_opt.platform == "GPU": + # train on gpu + print("train args: ", args_opt) + print("cfg: ", config_gpu) + + # define network + net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") + # define loss + if config_gpu.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth, + num_classes=config_gpu.num_classes) + else: + loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + # define dataset + epoch_size = config_gpu.epoch_size + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + config=config_gpu, + platform=args_opt.platform, + repeat_num=1, + batch_size=config_gpu.batch_size) + step_size = dataset.get_dataset_size() + # resume + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + + # get learning rate + loss_scale = FixedLossScaleManager( + config_gpu.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(global_step=0, + lr_init=0, + lr_end=0, + lr_max=config_gpu.lr, + warmup_epochs=config_gpu.warmup_epochs, + total_epochs=epoch_size, + steps_per_epoch=step_size)) + + # define optimization + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, + config_gpu.weight_decay, config_gpu.loss_scale) + # define model + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) + + print("============== Starting Training ==============") + cb = [Monitor(lr_init=lr.asnumpy())] + ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" + if config_gpu.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config_gpu.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + # begin train + model.train(epoch_size, dataset, callbacks=cb) + print("============== End Training ==============") + elif args_opt.platform == "Ascend": + # train on ascend + print("train args: ", args_opt, "\ncfg: ", config_ascend, + "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) + + if run_distribute: + context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, + parameter_broadcast=True, mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + init() + + epoch_size = config_ascend.epoch_size + net = mobilenet_v2(num_classes=config_ascend.num_classes, platform="Ascend") + net.to_float(mstype.float16) + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.to_float(mstype.float32) + if config_ascend.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth( + smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes) + else: + loss = SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction='mean') + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + config=config_ascend, + platform=args_opt.platform, + repeat_num=1, + batch_size=config_ascend.batch_size) + step_size = dataset.get_dataset_size() + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + + loss_scale = FixedLossScaleManager( + config_ascend.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(global_step=0, + lr_init=0, + lr_end=0, + lr_max=config_ascend.lr, + warmup_epochs=config_ascend.warmup_epochs, + total_epochs=epoch_size, + steps_per_epoch=step_size)) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum, + config_ascend.weight_decay, config_ascend.loss_scale) + + model = Model(net, loss_fn=loss, optimizer=opt, + loss_scale_manager=loss_scale) + + cb = None + if rank_id == 0: + cb = [Monitor(lr_init=lr.asnumpy())] + if config_ascend.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config_ascend.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint( + prefix="mobilenetV2", directory=config_ascend.save_checkpoint_path, config=config_ck) + cb += [ckpt_cb] + model.train(epoch_size, dataset, callbacks=cb) + else: + raise ValueError("Unsupport platform.") diff --git a/model_zoo/official/cv/mobilenetv2_quant/Readme.md b/model_zoo/official/cv/mobilenetv2_quant/Readme.md new file mode 100644 index 0000000000..03b755ee3d --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/Readme.md @@ -0,0 +1,144 @@ +# MobileNetV2 Quantization Aware Training + +MobileNetV2 is a significant improvement over MobileNetV1 and pushes the state of the art for mobile visual recognition including classification, object detection and semantic segmentation. + +MobileNetV2 builds upon the ideas from MobileNetV1, using depthwise separable convolution as efficient building blocks. However, V2 introduces two new features to the architecture: 1) linear bottlenecks between the layers, and 2) shortcut connections between the bottlenecks1. + +Training MobileNetV2 with ImageNet dataset in MindSpore with quantization aware training. + +This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware. + +In this readme tutorial, you will: + +1. Train a MindSpore fusion MobileNetV2 model for ImageNet from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. +2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. + +[Paper](https://arxiv.org/pdf/1801.04381) Sandler, Mark, et al. "Mobilenetv2: Inverted residuals and linear bottlenecks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018. + +# Dataset + +Dataset use: ImageNet + +- Dataset size: about 125G + - Train: 120G, 1281167 images: 1000 directories + - Test: 5G, 50000 images: images should be classified into 1000 directories firstly, just like train images +- Data format: RGB images. + - Note: Data will be processed in src/dataset.py + +# Environment Requirements + +- Hardware(Ascend) + - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Framework + - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + + +# Script description + +## Script and sample code + +```python +├── mobilenetv2_quant + ├── Readme.md + ├── scripts + │ ├──run_train.sh + │ ├──run_infer.sh + │ ├──run_train_quant.sh + │ ├──run_infer_quant.sh + ├── src + │ ├──config.py + │ ├──dataset.py + │ ├──luanch.py + │ ├──lr_generator.py + │ ├──mobilenetV2.py + ├── train.py + ├── eval.py +``` + +## Training process + +### Train MobileNetV2 model + +Train a MindSpore fusion MobileNetV2 model for ImageNet, like: + +- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] +- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] + +You can just run this command instead. + +``` bash +>>> Ascend: sh run_train.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt +>>> GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ +``` + +Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. + +``` +>>> epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] +>>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 +>>> epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] +>>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 +``` + +### Evaluate MobileNetV2 model + +Evaluate a MindSpore fusion MobileNetV2 model for ImageNet, like: + +- sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt +``` + +Inference result will be stored in the example path, you can find result like the followings in `val.log`. + +``` +>>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt +``` + +### Fine-tune for quantization aware training + +Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. + +- sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_train_quant.sh Ascend 4 192.168.0.1 0,1,2,3 ~/imagenet/train/ ~/mobilenet.ckpt +``` + +Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. + +``` +>>> epoch: [ 0/60], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] +>>> epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 +>>> epoch: [ 1/60], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] +>>> epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 +``` + +### Evaluate quantization aware training model + +Evaluate a MindSpore fusion MobileNetV2 model for ImageNet by applying the quantization aware training, like: + +- sh run_infer_quant.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] + +You can just run this command instead. + +``` bash +>>> sh run_infer_quant.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_625.ckpt +``` + +Inference result will be stored in the example path, you can find result like the followings in `val.log`. + +``` +>>> result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-60_625.ckpt +``` + +# ModelZoo Homepage + [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/mobilenetv2_quant/eval.py b/model_zoo/official/cv/mobilenetv2_quant/eval.py similarity index 100% rename from model_zoo/mobilenetv2_quant/eval.py rename to model_zoo/official/cv/mobilenetv2_quant/eval.py diff --git a/model_zoo/mobilenetv2_quant/export.py b/model_zoo/official/cv/mobilenetv2_quant/export.py similarity index 100% rename from model_zoo/mobilenetv2_quant/export.py rename to model_zoo/official/cv/mobilenetv2_quant/export.py diff --git a/model_zoo/mobilenetv2_quant/scripts/run_infer.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer.sh similarity index 100% rename from model_zoo/mobilenetv2_quant/scripts/run_infer.sh rename to model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer.sh diff --git a/model_zoo/mobilenetv2_quant/scripts/run_infer_quant.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer_quant.sh similarity index 100% rename from model_zoo/mobilenetv2_quant/scripts/run_infer_quant.sh rename to model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer_quant.sh diff --git a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train.sh new file mode 100644 index 0000000000..ebbcf9109d --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +run_ascend() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-9)" + exit 1 + fi + + if [ ! -d $5 ] && [ ! -f $5 ] + then + echo "error: DATASET_PATH=$5 is not a directory or file" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + python ${BASEPATH}/../src/launch.py \ + --nproc_per_node=$2 \ + --visible_devices=$4 \ + --server_id=$3 \ + --training_script=${BASEPATH}/../train.py \ + --dataset_path=$5 \ + --pre_trained=$6 \ + --device_target=$1 &> train.log & # dataset train folder +} + +run_gpu() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-8)" + exit 1 + fi + + if [ ! -d $4 ] + then + echo "error: DATASET_PATH=$4 is not a directory" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + + export CUDA_VISIBLE_DEVICES="$3" + mpirun -n $2 --allow-run-as-root \ + python ${BASEPATH}/../train.py \ + --dataset_path=$4 \ + --device_target=$1 \ + &> ../train.log & # dataset train folder +} + +if [ $# -gt 6 ] || [ $# -lt 4 ] +then + echo "Usage:\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ + " +exit 1 +fi + +if [ $1 = "Ascend" ] ; then + run_ascend "$@" +elif [ $1 = "GPU" ] ; then + run_gpu "$@" +else + echo "Unsupported device target." +fi; + diff --git a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh new file mode 100644 index 0000000000..e4d41ac9a2 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +run_ascend() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-9)" + exit 1 + fi + + if [ ! -d $5 ] && [ ! -f $5 ] + then + echo "error: DATASET_PATH=$5 is not a directory or file" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + python ${BASEPATH}/../src/launch.py \ + --nproc_per_node=$2 \ + --visible_devices=$4 \ + --server_id=$3 \ + --training_script=${BASEPATH}/../train.py \ + --dataset_path=$5 \ + --pre_trained=$6 \ + --quantization_aware=True \ + --device_target=$1 &> train.log & # dataset train folder +} + +run_gpu() +{ + if [ $2 -lt 1 ] && [ $2 -gt 8 ] + then + echo "error: DEVICE_NUM=$2 is not in (1-8)" + exit 1 + fi + + if [ ! -d $4 ] + then + echo "error: DATASET_PATH=$4 is not a directory" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + + export CUDA_VISIBLE_DEVICES="$3" + mpirun -n $2 --allow-run-as-root \ + python ${BASEPATH}/../train.py \ + --dataset_path=$4 \ + --device_target=$1 \ + --quantization_aware=True \ + &> ../train.log & # dataset train folder +} + +if [ $# -gt 6 ] || [ $# -lt 4 ] +then + echo "Usage:\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ + " +exit 1 +fi + +if [ $1 = "Ascend" ] ; then + run_ascend "$@" +elif [ $1 = "GPU" ] ; then + run_gpu "$@" +else + echo "Unsupported device target." +fi; + diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/config.py b/model_zoo/official/cv/mobilenetv2_quant/src/config.py new file mode 100644 index 0000000000..5b526f816f --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/src/config.py @@ -0,0 +1,98 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed + +config_ascend = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 256, + "data_load_mode": "mindrecord", + "epoch_size": 200, + "start_epoch": 0, + "warmup_epochs": 4, + "lr": 0.4, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 300, + "save_checkpoint_path": "./checkpoint", + "quantization_aware": False, +}) + +config_ascend_quant = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 192, + "data_load_mode": "mindrecord", + "epoch_size": 60, + "start_epoch": 200, + "warmup_epochs": 1, + "lr": 0.3, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 300, + "save_checkpoint_path": "./checkpoint", + "quantization_aware": True, +}) + +config_gpu = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 150, + "epoch_size": 200, + "warmup_epochs": 4, + "lr": 0.8, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 300, + "save_checkpoint_path": "./checkpoint", +}) + +config_gpu_quant = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 134, + "epoch_size": 60, + "start_epoch": 200, + "warmup_epochs": 1, + "lr": 0.3, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 300, + "save_checkpoint_path": "./checkpoint", + "quantization_aware": True, +}) diff --git a/model_zoo/mobilenetv2_quant/src/dataset.py b/model_zoo/official/cv/mobilenetv2_quant/src/dataset.py similarity index 100% rename from model_zoo/mobilenetv2_quant/src/dataset.py rename to model_zoo/official/cv/mobilenetv2_quant/src/dataset.py diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/launch.py b/model_zoo/official/cv/mobilenetv2_quant/src/launch.py new file mode 100644 index 0000000000..0d05ee9ad7 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/src/launch.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""launch train script""" +import os +import sys +import json +import subprocess +import shutil +import platform +from argparse import ArgumentParser + + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for D training, this is recommended to be set " + "to the number of D in your system so that " + "each process can be bound to a single D.") + parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", + help="will use the visible devices sequentially") + parser.add_argument("--server_id", type=str, default="", + help="server ip") + parser.add_argument("--training_script", type=str, + help="The full path to the single D training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + # rest from the training program + args, unknown = parser.parse_known_args() + args.training_script_args = unknown + return args + + +def main(): + print("start", __file__) + args = parse_args() + print(args) + visible_devices = args.visible_devices.split(',') + assert os.path.isfile(args.training_script) + assert len(visible_devices) >= args.nproc_per_node + print('visible_devices:{}'.format(visible_devices)) + if not args.server_id: + print('pleaser input server ip!!!') + exit(0) + print('server_id:{}'.format(args.server_id)) + + # construct hccn_table + hccn_configs = open('/etc/hccn.conf', 'r').readlines() + device_ips = {} + for hccn_item in hccn_configs: + hccn_item = hccn_item.strip() + if hccn_item.startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip + print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) + hccn_table = {} + arch = platform.processor() + hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch] + hccn_table['chip_info'] = '910' + hccn_table['deploy_mode'] = 'lab' + hccn_table['group_count'] = '1' + hccn_table['group_list'] = [] + instance_list = [] + usable_dev = '' + for instance_id in range(args.nproc_per_node): + instance = {} + instance['devices'] = [] + device_id = visible_devices[instance_id] + device_ip = device_ips[device_id] + usable_dev += str(device_id) + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + instance['rank_id'] = str(instance_id) + instance['server_id'] = args.server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': str(args.nproc_per_node), + 'server_num': '1', + 'group_name': '', + 'instance_count': str(args.nproc_per_node), + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in range(args.nproc_per_node): + eth_id = visible_devices[instance_id] + hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) + hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) + hccn_table['status'] = 'completed' + + # save hccn_table to file + table_path = os.getcwd() + if not os.path.exists(table_path): + os.mkdir(table_path) + table_fn = os.path.join(table_path, + 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) + with open(table_fn, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + sys.stdout.flush() + + # spawn the processes + processes = [] + cmds = [] + log_files = [] + env = os.environ.copy() + env['RANK_SIZE'] = str(args.nproc_per_node) + cur_path = os.getcwd() + for rank_id in range(0, args.nproc_per_node): + os.chdir(cur_path) + device_id = visible_devices[rank_id] + device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) + env['RANK_ID'] = str(rank_id) + env['DEVICE_ID'] = str(device_id) + if args.nproc_per_node > 1: + env['RANK_TABLE_FILE'] = table_fn + if os.path.exists(device_dir): + shutil.rmtree(device_dir) + os.mkdir(device_dir) + os.chdir(device_dir) + cmd = [sys.executable, '-u'] + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') + process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) + processes.append(process) + cmds.append(cmd) + log_files.append(log_file) + for process, cmd, log_file in zip(processes, cmds, log_files): + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process, cmd=cmd) + log_file.close() + + +if __name__ == "__main__": + main() diff --git a/model_zoo/mobilenetv2_quant/src/lr_generator.py b/model_zoo/official/cv/mobilenetv2_quant/src/lr_generator.py similarity index 100% rename from model_zoo/mobilenetv2_quant/src/lr_generator.py rename to model_zoo/official/cv/mobilenetv2_quant/src/lr_generator.py diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/mobilenetV2.py b/model_zoo/official/cv/mobilenetv2_quant/src/mobilenetV2.py new file mode 100644 index 0000000000..1b8f029171 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/src/mobilenetV2.py @@ -0,0 +1,244 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MobileNetV2 Quant model define""" + +import numpy as np + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore import Tensor + +__all__ = ['mobilenetV2'] + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, + stride=stride, + pad_mode='pad', + padding=padding, + group=groups, + has_bn=True, + activation='relu') + + def construct(self, x): + x = self.conv(x) + return x + + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True) + ]) + self.conv = nn.SequentialCell(layers) + self.add = P.TensorAdd() + + def construct(self, x): + out = self.conv(x) + if self.use_res_connect: + out = self.add(out, x) + return out + + +class mobilenetV2(nn.Cell): + """ + mobilenetV2 fusion architecture. + + Args: + class_num (Cell): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> mobilenetV2(num_classes=1000) + """ + + def __init__(self, num_classes=1000, width_mult=1., + has_dropout=False, inverted_residual_setting=None, round_nearest=8): + super(mobilenetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + # setting of inverted residual blocks + self.cfgs = inverted_residual_setting + if inverted_residual_setting is None: + self.cfgs = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + + features = [ConvBNReLU(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1)) + # make it nn.CellList + self.features = nn.SequentialCell(features) + # mobilenet head + head = ([GlobalAvgPooling(), + nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False) + ] if not has_dropout else + [GlobalAvgPooling(), + nn.Dropout(0.2), + nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False) + ]) + self.head = nn.SequentialCell(head) + + # init weights + self.init_parameters_data() + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + x = self.head(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32")) + m.weight.set_parameter_data(w) + if m.bias is not None: + m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.Conv2dBnAct): + n = m.conv.kernel_size[0] * m.conv.kernel_size[1] * m.conv.out_channels + w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.conv.weight.data.shape).astype("float32")) + m.conv.weight.set_parameter_data(w) + if m.conv.bias is not None: + m.conv.bias.set_parameter_data(Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.DenseBnAct): + m.dense.weight.set_parameter_data( + Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32"))) + if m.dense.bias is not None: + m.dense.bias.set_parameter_data(Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32"))) diff --git a/model_zoo/mobilenetv2_quant/src/utils.py b/model_zoo/official/cv/mobilenetv2_quant/src/utils.py similarity index 100% rename from model_zoo/mobilenetv2_quant/src/utils.py rename to model_zoo/official/cv/mobilenetv2_quant/src/utils.py diff --git a/model_zoo/official/cv/mobilenetv2_quant/train.py b/model_zoo/official/cv/mobilenetv2_quant/train.py new file mode 100644 index 0000000000..2253fbb5a0 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/train.py @@ -0,0 +1,209 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Train mobilenetV2 on ImageNet""" + +import os +import argparse +import random +import numpy as np + +from mindspore import context +from mindspore import Tensor +from mindspore import nn +from mindspore.train.model import Model, ParallelMode +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.communication.management import init, get_group_size, get_rank +from mindspore.train.quant import quant +import mindspore.dataset.engine as de + +from src.dataset import create_dataset +from src.lr_generator import get_lr +from src.utils import Monitor, CrossEntropyWithLabelSmooth +from src.config import config_ascend_quant, config_ascend, config_gpu_quant, config_gpu +from src.mobilenetV2 import mobilenetV2 + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path') +parser.add_argument('--device_target', type=str, default=None, help='Run device target') +parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training') +args_opt = parser.parse_args() + +if args_opt.device_target == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) + run_distribute = rank_size > 1 + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + device_id=device_id, save_graphs=False) +elif args_opt.platform == "GPU": + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + context.set_context(mode=context.GRAPH_MODE, + device_target="GPU", + save_graphs=False) +else: + raise ValueError("Unsupported device target.") + + +def train_on_ascend(): + config = config_ascend_quant if args_opt.quantization_aware else config_ascend + print("training args: {}".format(args_opt)) + print("training configure: {}".format(config)) + print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) + epoch_size = config.epoch_size + + # distribute init + if run_distribute: + context.set_auto_parallel_context(device_num=rank_size, + parallel_mode=ParallelMode.DATA_PARALLEL, + parameter_broadcast=True, + mirror_mean=True) + init() + + # define network + network = mobilenetV2(num_classes=config.num_classes) + # define loss + if config.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes) + else: + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + # define dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + config=config, + device_target=args_opt.device_target, + repeat_num=1, + batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + # load pre trained ckpt + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(network, param_dict) + + # convert fusion network to quantization aware network + if config.quantization_aware: + network = quant.convert_quant_network(network, + bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + + # get learning rate + lr = Tensor(get_lr(global_step=config.start_epoch * step_size, + lr_init=0, + lr_end=0, + lr_max=config.lr, + warmup_epochs=config.warmup_epochs, + total_epochs=epoch_size + config.start_epoch, + steps_per_epoch=step_size)) + + # define optimization + opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum, + config.weight_decay) + # define model + model = Model(network, loss_fn=loss, optimizer=opt) + + print("============== Starting Training ==============") + callback = None + if rank_id == 0: + callback = [Monitor(lr_init=lr.asnumpy())] + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", + directory=config.save_checkpoint_path, + config=config_ck) + callback += [ckpt_cb] + model.train(epoch_size, dataset, callbacks=callback) + print("============== End Training ==============") + + +def train_on_gpu(): + config = config_gpu_quant if args_opt.quantization_aware else config_gpu + print("training args: {}".format(args_opt)) + print("training configure: {}".format(config)) + + # define network + network = mobilenetV2(num_classes=config.num_classes) + # define loss + if config.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, + num_classes=config.num_classes) + else: + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + # define dataset + epoch_size = config.epoch_size + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + config=config, + device_target=args_opt.device_target, + repeat_num=1, + batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + # resume + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(network, param_dict) + + # convert fusion network to quantization aware network + if config.quantization_aware: + network = quant.convert_quant_network(network, + bn_fold=True, + per_channel=[True, False], + symmetric=[True, True]) + + # get learning rate + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(global_step=config.start_epoch * step_size, + lr_init=0, + lr_end=0, + lr_max=config.lr, + warmup_epochs=config.warmup_epochs, + total_epochs=epoch_size + config.start_epoch, + steps_per_epoch=step_size)) + + # define optimization + opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum, + config.weight_decay, config.loss_scale) + # define model + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) + + print("============== Starting Training ==============") + callback = [Monitor(lr_init=lr.asnumpy())] + ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) + callback += [ckpt_cb] + model.train(epoch_size, dataset, callbacks=callback) + print("============== End Training ==============") + + +if __name__ == '__main__': + if args_opt.device_target == "Ascend": + train_on_ascend() + elif args_opt.platform == "GPU": + train_on_gpu() diff --git a/model_zoo/mobilenetv3/Readme.md b/model_zoo/official/cv/mobilenetv3/Readme.md similarity index 100% rename from model_zoo/mobilenetv3/Readme.md rename to model_zoo/official/cv/mobilenetv3/Readme.md diff --git a/model_zoo/mobilenetv3/eval.py b/model_zoo/official/cv/mobilenetv3/eval.py similarity index 100% rename from model_zoo/mobilenetv3/eval.py rename to model_zoo/official/cv/mobilenetv3/eval.py diff --git a/model_zoo/mobilenetv3/scripts/run_infer.sh b/model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh similarity index 100% rename from model_zoo/mobilenetv3/scripts/run_infer.sh rename to model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh diff --git a/model_zoo/mobilenetv3/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv3/scripts/run_train.sh similarity index 100% rename from model_zoo/mobilenetv3/scripts/run_train.sh rename to model_zoo/official/cv/mobilenetv3/scripts/run_train.sh diff --git a/model_zoo/mobilenetv3/src/config.py b/model_zoo/official/cv/mobilenetv3/src/config.py similarity index 100% rename from model_zoo/mobilenetv3/src/config.py rename to model_zoo/official/cv/mobilenetv3/src/config.py diff --git a/model_zoo/mobilenetv3/src/dataset.py b/model_zoo/official/cv/mobilenetv3/src/dataset.py similarity index 100% rename from model_zoo/mobilenetv3/src/dataset.py rename to model_zoo/official/cv/mobilenetv3/src/dataset.py diff --git a/model_zoo/official/cv/mobilenetv3/src/launch.py b/model_zoo/official/cv/mobilenetv3/src/launch.py new file mode 100644 index 0000000000..df5f4e65f0 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv3/src/launch.py @@ -0,0 +1,162 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""launch train script""" +import os +import sys +import json +import subprocess +import shutil +from argparse import ArgumentParser + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for D training, this is recommended to be set " + "to the number of D in your system so that " + "each process can be bound to a single D.") + parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", + help="will use the visible devices sequentially") + parser.add_argument("--server_id", type=str, default="", + help="server ip") + parser.add_argument("--training_script", type=str, + help="The full path to the single D training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + # rest from the training program + args, unknown = parser.parse_known_args() + args.training_script_args = unknown + return args + + +def main(): + print("start", __file__) + args = parse_args() + print(args) + visible_devices = args.visible_devices.split(',') + assert os.path.isfile(args.training_script) + assert len(visible_devices) >= args.nproc_per_node + print('visible_devices:{}'.format(visible_devices)) + if not args.server_id: + print('pleaser input server ip!!!') + exit(0) + print('server_id:{}'.format(args.server_id)) + + # construct hccn_table + hccn_configs = open('/etc/hccn.conf', 'r').readlines() + device_ips = {} + for hccn_item in hccn_configs: + hccn_item = hccn_item.strip() + if hccn_item.startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip + print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) + hccn_table = {} + hccn_table['board_id'] = '0x0000' + hccn_table['chip_info'] = '910' + hccn_table['deploy_mode'] = 'lab' + hccn_table['group_count'] = '1' + hccn_table['group_list'] = [] + instance_list = [] + usable_dev = '' + for instance_id in range(args.nproc_per_node): + instance = {} + instance['devices'] = [] + device_id = visible_devices[instance_id] + device_ip = device_ips[device_id] + usable_dev += str(device_id) + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + instance['rank_id'] = str(instance_id) + instance['server_id'] = args.server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': str(args.nproc_per_node), + 'server_num': '1', + 'group_name': '', + 'instance_count': str(args.nproc_per_node), + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in range(args.nproc_per_node): + eth_id = visible_devices[instance_id] + hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) + hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) + hccn_table['status'] = 'completed' + + # save hccn_table to file + table_path = os.getcwd() + if not os.path.exists(table_path): + os.mkdir(table_path) + table_fn = os.path.join(table_path, + 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) + with open(table_fn, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + sys.stdout.flush() + + # spawn the processes + processes = [] + cmds = [] + log_files = [] + env = os.environ.copy() + env['RANK_SIZE'] = str(args.nproc_per_node) + cur_path = os.getcwd() + for rank_id in range(0, args.nproc_per_node): + os.chdir(cur_path) + device_id = visible_devices[rank_id] + device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) + env['RANK_ID'] = str(rank_id) + env['DEVICE_ID'] = str(device_id) + if args.nproc_per_node > 1: + env['RANK_TABLE_FILE'] = table_fn + if os.path.exists(device_dir): + shutil.rmtree(device_dir) + os.mkdir(device_dir) + os.chdir(device_dir) + cmd = [sys.executable, '-u'] + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') + process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) + processes.append(process) + cmds.append(cmd) + log_files.append(log_file) + for process, cmd, log_file in zip(processes, cmds, log_files): + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process, cmd=cmd) + log_file.close() + + +if __name__ == "__main__": + main() diff --git a/model_zoo/mobilenetv3/src/lr_generator.py b/model_zoo/official/cv/mobilenetv3/src/lr_generator.py similarity index 100% rename from model_zoo/mobilenetv3/src/lr_generator.py rename to model_zoo/official/cv/mobilenetv3/src/lr_generator.py diff --git a/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py b/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py new file mode 100644 index 0000000000..b84cbeb83c --- /dev/null +++ b/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py @@ -0,0 +1,392 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MobileNetV3 model define""" +from functools import partial +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore import Tensor + + +__all__ = ['mobilenet_v3_large', + 'mobilenet_v3_small'] + + +def _make_divisible(x, divisor=8): + return int(np.ceil(x * 1. / divisor) * divisor) + + +class Activation(nn.Cell): + """ + Activation definition. + + Args: + act_func(string): activation name. + + Returns: + Tensor, output tensor. + """ + + def __init__(self, act_func): + super(Activation, self).__init__() + if act_func == 'relu': + self.act = nn.ReLU() + elif act_func == 'relu6': + self.act = nn.ReLU6() + elif act_func in ('hsigmoid', 'hard_sigmoid'): + self.act = nn.HSigmoid() + elif act_func in ('hswish', 'hard_swish'): + self.act = nn.HSwish() + else: + raise NotImplementedError + + def construct(self, x): + return self.act(x) + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + + def __init__(self, keep_dims=False): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=keep_dims) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class SE(nn.Cell): + """ + SE warpper definition. + + Args: + num_out (int): Output channel. + ratio (int): middle output ratio. + + Returns: + Tensor, output tensor. + + Examples: + >>> SE(4) + """ + + def __init__(self, num_out, ratio=4): + super(SE, self).__init__() + num_mid = _make_divisible(num_out // ratio) + self.pool = GlobalAvgPooling(keep_dims=True) + self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid, + kernel_size=1, has_bias=True, pad_mode='pad') + self.act1 = Activation('relu') + self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out, + kernel_size=1, has_bias=True, pad_mode='pad') + self.act2 = Activation('hsigmoid') + self.mul = P.Mul() + + def construct(self, x): + out = self.pool(x) + out = self.conv1(out) + out = self.act1(out) + out = self.conv2(out) + out = self.act2(out) + out = self.mul(x, out) + return out + + +class Unit(nn.Cell): + """ + Unit warpper definition. + + Args: + num_in (int): Input channel. + num_out (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size. + padding (int): Padding number. + num_groups (int): Output num group. + use_act (bool): Used activation or not. + act_type (string): Activation type. + + Returns: + Tensor, output tensor. + + Examples: + >>> Unit(3, 3) + """ + + def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1, + use_act=True, act_type='relu'): + super(Unit, self).__init__() + self.conv = nn.Conv2d(in_channels=num_in, + out_channels=num_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + group=num_groups, + has_bias=False, + pad_mode='pad') + self.bn = nn.BatchNorm2d(num_out) + self.use_act = use_act + self.act = Activation(act_type) if use_act else None + + def construct(self, x): + out = self.conv(x) + out = self.bn(out) + if self.use_act: + out = self.act(out) + return out + + +class ResUnit(nn.Cell): + """ + ResUnit warpper definition. + + Args: + num_in (int): Input channel. + num_mid (int): Middle channel. + num_out (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size. + act_type (str): Activation type. + use_se (bool): Use SE warpper or not. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResUnit(16, 3, 1, 1) + """ + def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): + super(ResUnit, self).__init__() + self.use_se = use_se + self.first_conv = (num_out != num_mid) + self.use_short_cut_conv = True + + if self.first_conv: + self.expand = Unit(num_in, num_mid, kernel_size=1, + stride=1, padding=0, act_type=act_type) + else: + self.expand = None + self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride, + padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid) + if use_se: + self.se = SE(num_mid) + self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1, + padding=0, act_type=act_type, use_act=False) + if num_in != num_out or stride != 1: + self.use_short_cut_conv = False + self.add = P.TensorAdd() if self.use_short_cut_conv else None + + def construct(self, x): + """construct""" + if self.first_conv: + out = self.expand(x) + else: + out = x + out = self.conv1(out) + if self.use_se: + out = self.se(out) + out = self.conv2(out) + if self.use_short_cut_conv: + out = self.add(x, out) + return out + + def _get_pad(self, kernel_size): + """set the padding number""" + pad = 0 + if kernel_size == 1: + pad = 0 + elif kernel_size == 3: + pad = 1 + elif kernel_size == 5: + pad = 2 + elif kernel_size == 7: + pad = 3 + else: + raise NotImplementedError + return pad + + +class MobileNetV3(nn.Cell): + """ + MobileNetV3 architecture. + + Args: + model_cfgs (Cell): number of classes. + num_classes (int): Output number classes. + multiplier (int): Channels multiplier for round to 8/16 and others. Default is 1. + final_drop (float): Dropout number. + round_nearest (list): Channel round to . Default is 8. + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV3(num_classes=1000) + """ + + def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8): + super(MobileNetV3, self).__init__() + self.cfgs = model_cfgs['cfg'] + self.inplanes = 16 + self.features = [] + first_conv_in_channel = 3 + first_conv_out_channel = _make_divisible(multiplier * self.inplanes) + + self.features.append(nn.Conv2d(in_channels=first_conv_in_channel, + out_channels=first_conv_out_channel, + kernel_size=3, padding=1, stride=2, + has_bias=False, pad_mode='pad')) + self.features.append(nn.BatchNorm2d(first_conv_out_channel)) + self.features.append(Activation('hswish')) + for layer_cfg in self.cfgs: + self.features.append(self._make_layer(kernel_size=layer_cfg[0], + exp_ch=_make_divisible(multiplier * layer_cfg[1]), + out_channel=_make_divisible(multiplier * layer_cfg[2]), + use_se=layer_cfg[3], + act_func=layer_cfg[4], + stride=layer_cfg[5])) + output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"]) + self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]), + out_channels=output_channel, + kernel_size=1, padding=0, stride=1, + has_bias=False, pad_mode='pad')) + self.features.append(nn.BatchNorm2d(output_channel)) + self.features.append(Activation('hswish')) + self.features.append(GlobalAvgPooling(keep_dims=True)) + self.features.append(nn.Conv2d(in_channels=output_channel, + out_channels=model_cfgs['cls_ch_expand'], + kernel_size=1, padding=0, stride=1, + has_bias=False, pad_mode='pad')) + self.features.append(Activation('hswish')) + if final_drop > 0: + self.features.append((nn.Dropout(final_drop))) + + # make it nn.CellList + self.features = nn.SequentialCell(self.features) + self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], + out_channels=num_classes, + kernel_size=1, has_bias=True, pad_mode='pad') + self.squeeze = P.Squeeze(axis=(2, 3)) + + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + x = self.output(x) + x = self.squeeze(x) + return x + + def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): + mid_planes = exp_ch + out_planes = out_channel + #num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False): + layer = ResUnit(self.inplanes, mid_planes, out_planes, + kernel_size, stride=stride, act_type=act_func, use_se=use_se) + self.inplanes = out_planes + return layer + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, (nn.Conv2d)): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_parameter_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal( + 0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + +def mobilenet_v3(model_name, **kwargs): + """ + Constructs a MobileNet V2 model + """ + model_cfgs = { + "large": { + "cfg": [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'relu', 1], + [3, 64, 24, False, 'relu', 2], + [3, 72, 24, False, 'relu', 1], + [5, 72, 40, True, 'relu', 2], + [5, 120, 40, True, 'relu', 1], + [5, 120, 40, True, 'relu', 1], + [3, 240, 80, False, 'hswish', 2], + [3, 200, 80, False, 'hswish', 1], + [3, 184, 80, False, 'hswish', 1], + [3, 184, 80, False, 'hswish', 1], + [3, 480, 112, True, 'hswish', 1], + [3, 672, 112, True, 'hswish', 1], + [5, 672, 160, True, 'hswish', 2], + [5, 960, 160, True, 'hswish', 1], + [5, 960, 160, True, 'hswish', 1]], + "cls_ch_squeeze": 960, + "cls_ch_expand": 1280, + }, + "small": { + "cfg": [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'relu', 2], + [3, 72, 24, False, 'relu', 2], + [3, 88, 24, False, 'relu', 1], + [5, 96, 40, True, 'hswish', 2], + [5, 240, 40, True, 'hswish', 1], + [5, 240, 40, True, 'hswish', 1], + [5, 120, 48, True, 'hswish', 1], + [5, 144, 48, True, 'hswish', 1], + [5, 288, 96, True, 'hswish', 2], + [5, 576, 96, True, 'hswish', 1], + [5, 576, 96, True, 'hswish', 1]], + "cls_ch_squeeze": 576, + "cls_ch_expand": 1280, + } + } + return MobileNetV3(model_cfgs[model_name], **kwargs) + + +mobilenet_v3_large = partial(mobilenet_v3, model_name="large") +mobilenet_v3_small = partial(mobilenet_v3, model_name="small") diff --git a/model_zoo/official/cv/mobilenetv3/train.py b/model_zoo/official/cv/mobilenetv3/train.py new file mode 100644 index 0000000000..5f2a3502ac --- /dev/null +++ b/model_zoo/official/cv/mobilenetv3/train.py @@ -0,0 +1,276 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_imagenet.""" +import os +import time +import argparse +import random +import numpy as np + +from mindspore import context +from mindspore import Tensor +from mindspore import nn +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.nn.optim.momentum import Momentum +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype +from mindspore.train.model import Model, ParallelMode +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset.engine as de +from mindspore.communication.management import init, get_group_size, get_rank + +from src.dataset import create_dataset +from src.lr_generator import get_lr +from src.config import config_gpu, config_ascend +from src.mobilenetV3 import mobilenet_v3_large + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') +parser.add_argument('--platform', type=str, default=None, help='run platform') +args_opt = parser.parse_args() + +if args_opt.platform == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) + run_distribute = rank_size > 1 + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + device_id=device_id, + save_graphs=False) +elif args_opt.platform == "GPU": + context.set_context(mode=context.GRAPH_MODE, + device_target="GPU", + save_graphs=False) + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) +else: + raise ValueError("Unsupport platform.") + + +class CrossEntropyWithLabelSmooth(_Loss): + """ + CrossEntropyWith LabelSmooth. + + Args: + smooth_factor (float): smooth factor, default=0. + num_classes (int): num classes + + Returns: + None. + + Examples: + >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000) + """ + + def __init__(self, smooth_factor=0., num_classes=1000): + super(CrossEntropyWithLabelSmooth, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / + (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + self.cast = P.Cast() + + def construct(self, logit, label): + one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], + self.on_value, self.off_value) + out_loss = self.ce(logit, one_hot_label) + out_loss = self.mean(out_loss, 0) + return out_loss + + +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + + print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( + cb_params.cur_epoch_num - + 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) + + +if __name__ == '__main__': + if args_opt.platform == "GPU": + # train on gpu + print("train args: ", args_opt) + print("cfg: ", config_gpu) + + # define net + net = mobilenet_v3_large(num_classes=config_gpu.num_classes) + # define loss + if config_gpu.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth( + smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes) + else: + loss = SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction='mean') + # define dataset + epoch_size = config_gpu.epoch_size + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + config=config_gpu, + platform=args_opt.platform, + repeat_num=1, + batch_size=config_gpu.batch_size) + step_size = dataset.get_dataset_size() + # resume + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + # define optimizer + loss_scale = FixedLossScaleManager( + config_gpu.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(global_step=0, + lr_init=0, + lr_end=0, + lr_max=config_gpu.lr, + warmup_epochs=config_gpu.warmup_epochs, + total_epochs=epoch_size, + steps_per_epoch=step_size)) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, + config_gpu.weight_decay, config_gpu.loss_scale) + # define model + model = Model(net, loss_fn=loss, optimizer=opt, + loss_scale_manager=loss_scale) + + cb = [Monitor(lr_init=lr.asnumpy())] + ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" + if config_gpu.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config_gpu.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + # begine train + model.train(epoch_size, dataset, callbacks=cb) + elif args_opt.platform == "Ascend": + # train on ascend + print("train args: ", args_opt, "\ncfg: ", config_ascend, + "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) + + if run_distribute: + context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, + parameter_broadcast=True, mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + init() + + epoch_size = config_ascend.epoch_size + net = mobilenet_v3_large(num_classes=config_ascend.num_classes) + net.to_float(mstype.float16) + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.to_float(mstype.float32) + if config_ascend.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth( + smooth_factor=config_ascend.label_smooth, num_classes=config.num_classes) + else: + loss = SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction='mean') + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + config=config_ascend, + platform=args_opt.platform, + repeat_num=1, + batch_size=config_ascend.batch_size) + step_size = dataset.get_dataset_size() + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + + loss_scale = FixedLossScaleManager( + config_ascend.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(global_step=0, + lr_init=0, + lr_end=0, + lr_max=config_ascend.lr, + warmup_epochs=config_ascend.warmup_epochs, + total_epochs=epoch_size, + steps_per_epoch=step_size)) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum, + config_ascend.weight_decay, config_ascend.loss_scale) + + model = Model(net, loss_fn=loss, optimizer=opt, + loss_scale_manager=loss_scale) + + cb = None + if rank_id == 0: + cb = [Monitor(lr_init=lr.asnumpy())] + if config_ascend.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config_ascend.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint( + prefix="mobilenetV3", directory=config_ascend.save_checkpoint_path, config=config_ck) + cb += [ckpt_cb] + model.train(epoch_size, dataset, callbacks=cb) + else: + raise Exception diff --git a/model_zoo/official/cv/resnet/README.md b/model_zoo/official/cv/resnet/README.md new file mode 100644 index 0000000000..a22df320e7 --- /dev/null +++ b/model_zoo/official/cv/resnet/README.md @@ -0,0 +1,254 @@ +# ResNet Example + +## Description + +These are examples of training ResNet-50/ResNet-101 with CIFAR-10/ImageNet2012 dataset in MindSpore. +(Training ResNet-101 with dataset CIFAR-10 is unsupported now.) + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset CIFAR-10 or ImageNet2012 + +CIFAR-10 + +> Unzip the CIFAR-10 dataset to any path you want and the folder structure should include train and eval dataset as follows: +> ``` +> . +> └─dataset +> ├─ cifar-10-batches-bin # train dataset +> └─ cifar-10-verify-bin # evaluate dataset +> ``` + +ImageNet2012 + +> Unzip the ImageNet2012 dataset to any path you want and the folder should include train and eval dataset as follows: +> +> ``` +> . +> └─dataset +> ├─ilsvrc # train dataset +> └─validation_preprocess # evaluate dataset +> ``` + + + +## Structure + +```shell +. +└──resnet + ├── README.md + ├── script + ├── run_distribute_train.sh # launch distributed training(8 pcs) + ├── run_eval.sh # launch evaluation + └── run_standalone_train.sh # launch standalone training(1 pcs) + ├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs) + ├── run_eval_gpu.sh # launch gpu evaluation + └── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs) + ├── src + ├── config.py # parameter configuration + ├── dataset.py # data preprocessing + ├── crossentropy.py # loss definition for ImageNet2012 dataset + ├── lr_generator.py # generate learning rate for each step + └── resnet.py # resnet backbone, including resnet50 and resnet101 + ├── eval.py # eval net + └── train.py # train net +``` + + +## Parameter configuration + +Parameters for both training and evaluation can be set in config.py. + +- config for ResNet-50, CIFAR-10 dataset + +``` +"class_num": 10, # dataset class num +"batch_size": 32, # batch size of input tensor +"loss_scale": 1024, # loss scale +"momentum": 0.9, # momentum +"weight_decay": 1e-4, # weight decay +"epoch_size": 90, # only valid for taining, which is always 1 for inference +"save_checkpoint": True, # whether save checkpoint or not +"save_checkpoint_steps": 195, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step +"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./", # path to save checkpoint +"warmup_epochs": 5, # number of warmup epoch +"lr_decay_mode": "poly" # decay mode can be selected in steps, ploy and default +"lr_init": 0.01, # initial learning rate +"lr_end": 0.00001, # final learning rate +"lr_max": 0.1, # maximum learning rate +``` + +- config for ResNet-50, ImageNet2012 dataset + +``` +"class_num": 1001, # dataset class number +"batch_size": 32, # batch size of input tensor +"loss_scale": 1024, # loss scale +"momentum": 0.9, # momentum optimizer +"weight_decay": 1e-4, # weight decay +"epoch_size": 90, # only valid for taining, which is always 1 for inference +"pretrained_epoch_size": 1, # epoch size that model has been trained before load pretrained checkpoint +"save_checkpoint": True, # whether save checkpoint or not +"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch +"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path +"warmup_epochs": 0, # number of warmup epoch +"lr_decay_mode": "cosine", # decay mode for generating learning rate +"label_smooth": True, # label smooth +"label_smooth_factor": 0.1, # label smooth factor +"lr_init": 0, # initial learning rate +"lr_max": 0.1, # maximum learning rate +``` + +- config for ResNet-101, ImageNet2012 dataset + +``` +"class_num": 1001, # dataset class number +"batch_size": 32, # batch size of input tensor +"loss_scale": 1024, # loss scale +"momentum": 0.9, # momentum optimizer +"weight_decay": 1e-4, # weight decay +"epoch_size": 120, # epoch sizes for training +"pretrain_epoch_size": 0, # epoch size of pretrain checkpoint +"save_checkpoint": True, # whether save checkpoint or not +"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch +"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path +"warmup_epochs": 0, # number of warmup epoch +"lr_decay_mode": "cosine" # decay mode for generating learning rate +"label_smooth": 1, # label_smooth +"label_smooth_factor": 0.1, # label_smooth_factor +"lr": 0.1 # base learning rate +``` + + + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] + [PRETRAINED_CKPT_PATH](optional) + +# standalone training +Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] + [PRETRAINED_CKPT_PATH](optional) +``` + + +#### Launch + +``` +# distribute training example +sh run_distribute_train.sh resnet50 cifar10 rank_table.json ~/cifar-10-batches-bin + +# standalone training example +sh run_standalone_train.sh resnet50 cifar10 ~/cifar-10-batches-bin +``` + +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + +#### Result + +Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. + +- training ResNet-50 with CIFAR-10 dataset + +``` +# distribute training result(8 pcs) +epoch: 1 step: 195, loss is 1.9601055 +epoch: 2 step: 195, loss is 1.8555021 +epoch: 3 step: 195, loss is 1.6707983 +epoch: 4 step: 195, loss is 1.8162166 +epoch: 5 step: 195, loss is 1.393667 +... +``` + +- training ResNet-50 with ImageNet2012 dataset + +``` +# distribute training result(8 pcs) +epoch: 1 step: 5004, loss is 4.8995576 +epoch: 2 step: 5004, loss is 3.9235563 +epoch: 3 step: 5004, loss is 3.833077 +epoch: 4 step: 5004, loss is 3.2795618 +epoch: 5 step: 5004, loss is 3.1978393 +... +``` + +- training ResNet-101 with ImageNet2012 dataset + +``` +# distribute training result(8p) +epoch: 1 step: 5004, loss is 4.805483 +epoch: 2 step: 5004, loss is 3.2121816 +epoch: 3 step: 5004, loss is 3.429647 +epoch: 4 step: 5004, loss is 3.3667371 +epoch: 5 step: 5004, loss is 3.1718972 +... +epoch: 67 step: 5004, loss is 2.2768745 +epoch: 68 step: 5004, loss is 1.7223864 +epoch: 69 step: 5004, loss is 2.0665488 +epoch: 70 step: 5004, loss is 1.8717369 +... +``` + +### Evaluation + +#### Usage + +``` +# evaluation +Usage: sh run_eval.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] +``` + +#### Launch + +``` +# evaluation example +sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. + +- evaluating ResNet-50 with CIFAR-10 dataset + +``` +result: {'acc': 0.91446314102564111} ckpt=~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt +``` + +- evaluating ResNet-50 with ImageNet2012 dataset + +``` +result: {'acc': 0.7671054737516005} ckpt=train_parallel0/resnet-90_5004.ckpt +``` + +- evaluating ResNet-101 with ImageNet2012 dataset + +``` +result: {'top_5_accuracy': 0.9429417413572343, 'top_1_accuracy': 0.7853513124199744} ckpt=train_parallel0/resnet-120_5004.ckpt +``` + +### Running on GPU +``` +# distributed training example +sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) + +# standalone training example +sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) + +# infer example +sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] +``` diff --git a/model_zoo/official/cv/resnet/eval.py b/model_zoo/official/cv/resnet/eval.py new file mode 100755 index 0000000000..7ad67289fe --- /dev/null +++ b/model_zoo/official/cv/resnet/eval.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train resnet.""" +import os +import random +import argparse +import numpy as np +from mindspore import context +from mindspore import dataset as de +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') +parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012') + +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') +args_opt = parser.parse_args() + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +if args_opt.net == "resnet50": + from src.resnet import resnet50 as resnet + + if args_opt.dataset == "cifar10": + from src.config import config1 as config + from src.dataset import create_dataset1 as create_dataset + else: + from src.config import config2 as config + from src.dataset import create_dataset2 as create_dataset +else: + from src.resnet import resnet101 as resnet + from src.config import config3 as config + from src.dataset import create_dataset3 as create_dataset + +if __name__ == '__main__': + target = args_opt.device_target + + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + if target != "GPU": + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, + target=target) + step_size = dataset.get_dataset_size() + + # define net + net = resnet(class_num=config.class_num) + + # load checkpoint + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + # define loss, model + if args_opt.dataset == "imagenet2012": + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + else: + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + + # define model + model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) + + # eval model + res = model.eval(dataset) + print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh new file mode 100755 index 0000000000..58e2cb1c6c --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 4 ] && [ $# != 5 ] +then + echo "Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" +exit 1 +fi + +if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] +then + echo "error: the selected net is neither resnet50 nor resnet101" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" +exit 1 +fi + +if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] +then + echo "error: training resnet101 with cifar10 dataset is unsupported now!" +exit 1 +fi + + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) +PATH2=$(get_real_path $4) + +if [ $# == 5 ] +then + PATH3=$(get_real_path $5) +fi + +if [ ! -f $PATH1 ] +then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" +exit 1 +fi + +if [ ! -d $PATH2 ] +then + echo "error: DATASET_PATH=$PATH2 is not a directory" +exit 1 +fi + +if [ $# == 5 ] && [ ! -f $PATH3 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$((rank_start + i)) + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + if [ $# == 4 ] + then + python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log & + fi + + if [ $# == 5 ] + then + python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log & + fi + + cd .. +done diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train_gpu.sh new file mode 100755 index 0000000000..95e0d7df06 --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,93 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] && [ $# != 4 ] +then + echo "Usage: sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" +exit 1 +fi + +if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] +then + echo "error: the selected net is neither resnet50 nor resnet101" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" +exit 1 +fi + +if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] +then + echo "error: training resnet101 with cifar10 dataset is unsupported now!" +exit 1 +fi + + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) + +if [ $# == 4 ] +then + PATH2=$(get_real_path $4) +fi + + +if [ ! -d $PATH2 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ $# == 5 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 + +rm -rf ./train_parallel +mkdir ./train_parallel +cp ../*.py ./train_parallel +cp *.sh ./train_parallel +cp -r ../src ./train_parallel +cd ./train_parallel || exit + +if [ $# == 3 ] +then + mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py --net=$1 --dataset=$2 --run_distribute=True \ + --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log & +fi + +if [ $# == 4 ] +then + mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py --net=$1 --dataset=$2 --run_distribute=True \ + --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & +fi diff --git a/model_zoo/resnet/scripts/run_eval.sh b/model_zoo/official/cv/resnet/scripts/run_eval.sh similarity index 100% rename from model_zoo/resnet/scripts/run_eval.sh rename to model_zoo/official/cv/resnet/scripts/run_eval.sh diff --git a/model_zoo/official/cv/resnet/scripts/run_eval_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_eval_gpu.sh new file mode 100755 index 0000000000..fc93602f5a --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_eval_gpu.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 4 ] +then + echo "Usage: sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] +then + echo "error: the selected net is neither resnet50 nor resnet101" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" +exit 1 +fi + +if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] +then + echo "error: evaluating resnet101 with cifar10 dataset is unsupported now!" +exit 1 +fi + + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) +PATH2=$(get_real_path $4) + + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start evaluation for device $DEVICE_ID" +python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log & +cd .. diff --git a/model_zoo/resnet/scripts/run_standalone_train.sh b/model_zoo/official/cv/resnet/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/resnet/scripts/run_standalone_train.sh rename to model_zoo/official/cv/resnet/scripts/run_standalone_train.sh diff --git a/model_zoo/official/cv/resnet/scripts/run_standalone_train_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_standalone_train_gpu.sh new file mode 100755 index 0000000000..076bd4b332 --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_standalone_train_gpu.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] && [ $# != 4 ] +then + echo "Usage: sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" +exit 1 +fi + +if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] +then + echo "error: the selected net is neither resnet50 nor resnet101" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" +exit 1 +fi + +if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] +then + echo "error: training resnet101 with cifar10 dataset is unsupported now!" +exit 1 +fi + + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) + +if [ $# == 4 ] +then + PATH2=$(get_real_path $4) +fi + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ $# == 4 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log +if [ $# == 3 ] +then + python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 &> log & +fi + +if [ $# == 4 ] +then + python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & +fi +cd .. diff --git a/model_zoo/resnet/src/config.py b/model_zoo/official/cv/resnet/src/config.py similarity index 100% rename from model_zoo/resnet/src/config.py rename to model_zoo/official/cv/resnet/src/config.py diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py new file mode 100755 index 0000000000..d4a8969ed1 --- /dev/null +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -0,0 +1,208 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +create train or eval dataset. +""" +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 +from mindspore.communication.management import init, get_rank, get_group_size + + +def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): + """ + create a train or evaluate cifar10 dataset for resnet50 + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + + Returns: + dataset + """ + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + init("nccl") + rank_id = get_rank() + device_num = get_group_size() + + if device_num == 1: + ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) + + # define map operations + trans = [] + if do_train: + trans += [ + C.RandomCrop((32, 32), (4, 4, 4, 4)), + C.RandomHorizontalFlip(prob=0.5) + ] + + trans += [ + C.Resize((224, 224)), + C.Rescale(1.0 / 255.0, 0.0), + C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op) + ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): + """ + create a train or eval imagenet2012 dataset for resnet50 + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + + Returns: + dataset + """ + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + init("nccl") + rank_id = get_rank() + device_num = get_group_size() + + if device_num == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) + + image_size = 224 + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans) + ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): + """ + create a train or eval imagenet2012 dataset for resnet101 + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + + Returns: + dataset + """ + device_num, rank_id = _get_rank_info() + + if device_num == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) + image_size = 224 + mean = [0.475 * 255, 0.451 * 255, 0.392 * 255] + std = [0.275 * 255, 0.267 * 255, 0.278 * 255] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(rank_id/ (rank_id +1)), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8) + ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = 1 + rank_id = 0 + + return rank_size, rank_id diff --git a/model_zoo/resnet/src/lr_generator.py b/model_zoo/official/cv/resnet/src/lr_generator.py similarity index 100% rename from model_zoo/resnet/src/lr_generator.py rename to model_zoo/official/cv/resnet/src/lr_generator.py diff --git a/model_zoo/resnet/src/resnet.py b/model_zoo/official/cv/resnet/src/resnet.py similarity index 100% rename from model_zoo/resnet/src/resnet.py rename to model_zoo/official/cv/resnet/src/resnet.py diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py new file mode 100755 index 0000000000..414fa4c7de --- /dev/null +++ b/model_zoo/official/cv/resnet/train.py @@ -0,0 +1,184 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train resnet.""" +import os +import random +import argparse +import numpy as np +from mindspore import context +from mindspore import Tensor +from mindspore import dataset as de +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model, ParallelMode +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.communication.management import init, get_rank, get_group_size +import mindspore.nn as nn +import mindspore.common.initializer as weight_init +from src.lr_generator import get_lr, warmup_cosine_annealing_lr + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') +parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012') +parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') + +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') +parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') +args_opt = parser.parse_args() + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +if args_opt.net == "resnet50": + from src.resnet import resnet50 as resnet + + if args_opt.dataset == "cifar10": + from src.config import config1 as config + from src.dataset import create_dataset1 as create_dataset + else: + from src.config import config2 as config + from src.dataset import create_dataset2 as create_dataset +else: + from src.resnet import resnet101 as resnet + from src.config import config3 as config + from src.dataset import create_dataset3 as create_dataset + +if __name__ == '__main__': + target = args_opt.device_target + ckpt_save_dir = config.save_checkpoint_path + + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + if args_opt.run_distribute: + if target == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id, enable_auto_mixed_precision=True) + context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + if args_opt.net == "resnet50": + auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) + init() + # GPU target + else: + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" + + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, + batch_size=config.batch_size, target=target) + step_size = dataset.get_dataset_size() + + # define net + net = resnet(class_num=config.class_num) + + # init weight + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + else: + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), + cell.weight.shape, + cell.weight.dtype) + if isinstance(cell, nn.Dense): + cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), + cell.weight.shape, + cell.weight.dtype) + + # init lr + if args_opt.net == "resnet50": + if args_opt.dataset == "cifar10": + lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size, + lr_decay_mode='poly') + else: + lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs, + total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine') + else: + lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size, + config.pretrain_epoch_size * step_size) + lr = Tensor(lr) + + # define opt + decayed_params = [] + no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + + group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) + # define loss, model + if target == "Ascend": + if args_opt.dataset == "imagenet2012": + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + else: + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False) + else: + # GPU target + if args_opt.dataset == "imagenet2012": + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + else: + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, + num_classes=config.class_num) + + if args_opt.net == "resnet101": + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, + config.loss_scale) + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + # Mixed precision + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=True) + else: + ## fp32 training + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + # define callbacks + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossMonitor() + cb = [time_cb, loss_cb] + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + + # train model + model.train(config.epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/resnet50_quant/Readme.md b/model_zoo/official/cv/resnet50_quant/Readme.md similarity index 100% rename from model_zoo/resnet50_quant/Readme.md rename to model_zoo/official/cv/resnet50_quant/Readme.md diff --git a/model_zoo/resnet50_quant/eval.py b/model_zoo/official/cv/resnet50_quant/eval.py similarity index 100% rename from model_zoo/resnet50_quant/eval.py rename to model_zoo/official/cv/resnet50_quant/eval.py diff --git a/model_zoo/resnet50_quant/models/resnet_quant.py b/model_zoo/official/cv/resnet50_quant/models/resnet_quant.py similarity index 100% rename from model_zoo/resnet50_quant/models/resnet_quant.py rename to model_zoo/official/cv/resnet50_quant/models/resnet_quant.py diff --git a/model_zoo/resnet50_quant/scripts/run_infer.sh b/model_zoo/official/cv/resnet50_quant/scripts/run_infer.sh similarity index 100% rename from model_zoo/resnet50_quant/scripts/run_infer.sh rename to model_zoo/official/cv/resnet50_quant/scripts/run_infer.sh diff --git a/model_zoo/resnet50_quant/scripts/run_train.sh b/model_zoo/official/cv/resnet50_quant/scripts/run_train.sh similarity index 100% rename from model_zoo/resnet50_quant/scripts/run_train.sh rename to model_zoo/official/cv/resnet50_quant/scripts/run_train.sh diff --git a/model_zoo/resnet50_quant/src/config.py b/model_zoo/official/cv/resnet50_quant/src/config.py similarity index 100% rename from model_zoo/resnet50_quant/src/config.py rename to model_zoo/official/cv/resnet50_quant/src/config.py diff --git a/model_zoo/resnet50_quant/src/crossentropy.py b/model_zoo/official/cv/resnet50_quant/src/crossentropy.py similarity index 100% rename from model_zoo/resnet50_quant/src/crossentropy.py rename to model_zoo/official/cv/resnet50_quant/src/crossentropy.py diff --git a/model_zoo/resnet50_quant/src/dataset.py b/model_zoo/official/cv/resnet50_quant/src/dataset.py similarity index 100% rename from model_zoo/resnet50_quant/src/dataset.py rename to model_zoo/official/cv/resnet50_quant/src/dataset.py diff --git a/model_zoo/official/cv/resnet50_quant/src/launch.py b/model_zoo/official/cv/resnet50_quant/src/launch.py new file mode 100644 index 0000000000..c30869504a --- /dev/null +++ b/model_zoo/official/cv/resnet50_quant/src/launch.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""launch train script""" +import os +import sys +import json +import subprocess +import shutil +import platform +from argparse import ArgumentParser + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for D training, this is recommended to be set " + "to the number of D in your system so that " + "each process can be bound to a single D.") + parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", + help="will use the visible devices sequentially") + parser.add_argument("--server_id", type=str, default="", + help="server ip") + parser.add_argument("--training_script", type=str, + help="The full path to the single D training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + # rest from the training program + args, unknown = parser.parse_known_args() + args.training_script_args = unknown + return args + + +def main(): + print("start", __file__) + args = parse_args() + print(args) + visible_devices = args.visible_devices.split(',') + assert os.path.isfile(args.training_script) + assert len(visible_devices) >= args.nproc_per_node + print('visible_devices:{}'.format(visible_devices)) + if not args.server_id: + print('pleaser input server ip!!!') + exit(0) + print('server_id:{}'.format(args.server_id)) + + # construct hccn_table + hccn_configs = open('/etc/hccn.conf', 'r').readlines() + device_ips = {} + for hccn_item in hccn_configs: + hccn_item = hccn_item.strip() + if hccn_item.startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip + print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) + hccn_table = {} + arch = platform.processor() + hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch] + hccn_table['chip_info'] = '910' + hccn_table['deploy_mode'] = 'lab' + hccn_table['group_count'] = '1' + hccn_table['group_list'] = [] + instance_list = [] + usable_dev = '' + for instance_id in range(args.nproc_per_node): + instance = {} + instance['devices'] = [] + device_id = visible_devices[instance_id] + device_ip = device_ips[device_id] + usable_dev += str(device_id) + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + instance['rank_id'] = str(instance_id) + instance['server_id'] = args.server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': str(args.nproc_per_node), + 'server_num': '1', + 'group_name': '', + 'instance_count': str(args.nproc_per_node), + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in range(args.nproc_per_node): + eth_id = visible_devices[instance_id] + hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) + hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) + hccn_table['status'] = 'completed' + + # save hccn_table to file + table_path = os.getcwd() + if not os.path.exists(table_path): + os.mkdir(table_path) + table_fn = os.path.join(table_path, + 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) + with open(table_fn, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + sys.stdout.flush() + + # spawn the processes + processes = [] + cmds = [] + log_files = [] + env = os.environ.copy() + env['RANK_SIZE'] = str(args.nproc_per_node) + cur_path = os.getcwd() + for rank_id in range(0, args.nproc_per_node): + os.chdir(cur_path) + device_id = visible_devices[rank_id] + device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) + env['RANK_ID'] = str(rank_id) + env['DEVICE_ID'] = str(device_id) + if args.nproc_per_node > 1: + env['RANK_TABLE_FILE'] = table_fn + env['RANK_TABLE_FILE'] = table_fn + if os.path.exists(device_dir): + shutil.rmtree(device_dir) + os.mkdir(device_dir) + os.chdir(device_dir) + cmd = [sys.executable, '-u'] + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') + process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) + processes.append(process) + cmds.append(cmd) + log_files.append(log_file) + for process, cmd, log_file in zip(processes, cmds, log_files): + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process, cmd=cmd) + log_file.close() + + +if __name__ == "__main__": + main() diff --git a/model_zoo/resnet50_quant/src/lr_generator.py b/model_zoo/official/cv/resnet50_quant/src/lr_generator.py similarity index 100% rename from model_zoo/resnet50_quant/src/lr_generator.py rename to model_zoo/official/cv/resnet50_quant/src/lr_generator.py diff --git a/model_zoo/resnet50_quant/src/utils.py b/model_zoo/official/cv/resnet50_quant/src/utils.py similarity index 100% rename from model_zoo/resnet50_quant/src/utils.py rename to model_zoo/official/cv/resnet50_quant/src/utils.py diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py new file mode 100755 index 0000000000..5e1f075615 --- /dev/null +++ b/model_zoo/official/cv/resnet50_quant/train.py @@ -0,0 +1,153 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Train Resnet50 on ImageNet""" + +import os +import argparse + +from mindspore import context +from mindspore import Tensor +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model, ParallelMode +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint +from mindspore.train.quant import quant +from mindspore.communication.management import init +import mindspore.nn as nn +import mindspore.common.initializer as weight_init + +from models.resnet_quant import resnet50_quant +from src.dataset import create_dataset +from src.lr_generator import get_lr +from src.config import quant_set, config_quant, config_noquant +from src.crossentropy import CrossEntropy +from src.utils import _load_param_into_net + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') +parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path') +args_opt = parser.parse_args() +config = config_quant if quant_set.quantization_aware else config_noquant + +if args_opt.device_target == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) + run_distribute = rank_size > 1 + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + save_graphs=False, + device_id=device_id, + enable_auto_mixed_precision=True) +else: + raise ValueError("Unsupported device target.") + +if __name__ == '__main__': + # train on ascend + print("training args: {}".format(args_opt)) + print("training configure: {}".format(config)) + print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) + epoch_size = config.epoch_size + + # distribute init + if run_distribute: + context.set_auto_parallel_context(device_num=rank_size, + parallel_mode=ParallelMode.DATA_PARALLEL, + parameter_broadcast=True, + mirror_mean=True) + init() + context.set_auto_parallel_context(device_num=args_opt.device_num, + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) + + # define network + net = resnet50_quant(class_num=config.class_num) + net.set_train(True) + + # weight init and load checkpoint file + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + _load_param_into_net(net, param_dict) + epoch_size = config.epoch_size - config.pretrained_epoch_size + else: + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), + cell.weight.shape, + cell.weight.dtype) + if isinstance(cell, nn.Dense): + cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), + cell.weight.shape, + cell.weight.dtype) + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + + # define dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, + do_train=True, + repeat_num=1, + batch_size=config.batch_size, + target=args_opt.device_target) + step_size = dataset.get_dataset_size() + + if quant_set.quantization_aware: + # convert fusion network to quantization aware network + net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + + # get learning rate + lr = get_lr(lr_init=config.lr_init, + lr_end=0.0, + lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, + total_epochs=config.epoch_size, + steps_per_epoch=step_size, + lr_decay_mode='cosine') + if args_opt.pre_trained: + lr = lr[config.pretrained_epoch_size * step_size:] + lr = Tensor(lr) + + # define optimization + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, + config.weight_decay, config.loss_scale) + + # define model + if quant_set.quantization_aware: + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) + else: + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, + amp_level="O2") + + print("============== Starting Training ==============") + time_callback = TimeMonitor(data_size=step_size) + loss_callback = LossMonitor() + callbacks = [time_callback, loss_callback] + if rank_id == 0: + if config.save_checkpoint: + config_ckpt = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_callback = ModelCheckpoint(prefix="ResNet50", + directory=config.save_checkpoint_path, + config=config_ckpt) + callbacks += [ckpt_callback] + model.train(epoch_size, dataset, callbacks=callbacks) + print("============== End Training ==============") diff --git a/model_zoo/official/cv/resnet_thor/README.md b/model_zoo/official/cv/resnet_thor/README.md new file mode 100644 index 0000000000..cecd934575 --- /dev/null +++ b/model_zoo/official/cv/resnet_thor/README.md @@ -0,0 +1,128 @@ +# ResNet-50-THOR Example + +## Description + +This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by second-order optimizer THOR. THOR is a novel approximate seond-order optimization method in MindSpore. With fewer iterations, THOR can finish ResNet-50 V1.5 training in 72 minutes to top-1 accuracy of 75.9% using 8 Ascend 910, which is much faster than SGD with Momentum. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset ImageNet2012 + +> Unzip the ImageNet2012 dataset to any path you want and the folder structure should include train and eval dataset as follows: +> ``` +> . +> ├── ilsvrc # train dataset +> └── ilsvrc_eval # infer dataset +> ``` + + +## Example structure + +```shell +. +├── resnet_thor + ├── README.md + ├── src + ├── crossentropy.py # CrossEntropy loss function + ├── config.py # parameter configuration + ├── resnet50.py # resnet50 backbone + ├── dataset_helper.py # dataset help for minddata dataset + ├── grad_reducer_thor.py # grad reducer for thor + ├── model_thor.py # model + ├── resnet_thor.py # resnet50_thor backone + ├── thor.py # thor + ├── thor_layer.py # thor layer + └── dataset_imagenet.py # data preprocessing + ├── scripts + ├── run_distribute_train.sh # launch distributed training(8 pcs) + └── run_eval.sh # launch infering + ├── eval.py # infer script + └── train.py # train script +``` + + +## Parameter configuration + +Parameters for both training and inference can be set in config.py. + +``` +"class_num": 1000, # dataset class number +"batch_size": 32, # batch size of input tensor +"loss_scale": 128, # loss scale +"momentum": 0.9, # momentum of THOR optimizer +"weight_decay": 5e-4, # weight decay +"epoch_size": 45, # only valid for taining, which is always 1 for inference +"buffer_size": 1000, # number of queue size in data preprocessing +"image_height": 224, # image height +"image_width": 224, # image width +"save_checkpoint": True, # whether save checkpoint or not +"save_checkpoint_steps": 5004, # the step interval between two checkpoints. By default, the checkpoint will be saved every epoch +"keep_checkpoint_max": 20, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path +"label_smooth": True, # label smooth +"label_smooth_factor": 0.1, # label smooth factor +"frequency": 834, # the step interval to update second-order information matrix +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM] +``` + + +#### Launch + +```bash +# distributed training example(8 pcs) +sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc +``` + +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + +#### Result + +Training result will be stored in the example path, whose folder name begins with "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. + +``` +# distribute training result(8 pcs) +epoch: 1 step: 5004, loss is 4.4182425 +epoch: 2 step: 5004, loss is 3.740064 +epoch: 3 step: 5004, loss is 4.0546017 +epoch: 4 step: 5004, loss is 3.7598825 +epoch: 5 step: 5004, loss is 3.3744206 +...... +``` + +### Infer + +#### Usage + +``` +# infer +Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` + +#### Launch + +```bash +# infer with checkpoint +sh run_eval.sh dataset/ilsvrc_eval train_parallel0/resnet-42_5004.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Inference result will be stored in the example path, whose folder name is "infer". Under this, you can find result like the followings in log. + +``` +result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt +``` diff --git a/model_zoo/resnet_thor/eval.py b/model_zoo/official/cv/resnet_thor/eval.py similarity index 100% rename from model_zoo/resnet_thor/eval.py rename to model_zoo/official/cv/resnet_thor/eval.py diff --git a/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..63d192bfa1 --- /dev/null +++ b/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] +then + echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM]" +exit 1 +fi + +if [ ! -f $1 ] +then + echo "error: RANK_TABLE_FILE=$1 is not a file" +exit 1 +fi + +if [ ! -d $2 ] +then + echo "error: DATASET_PATH=$2 is not a directory" +exit 1 +fi + +BASE_PATH=$(cd "`dirname $0`" || exit; pwd) +cd $BASE_PATH/../ || exit + +ulimit -u unlimited +export DEVICE_NUM=$3 +export RANK_SIZE=$3 +export RANK_TABLE_FILE=$1 + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp *.py ./train_parallel$i + cp -r ./src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + + env > env.log + python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & + cd .. +done diff --git a/model_zoo/resnet_thor/scripts/run_eval.sh b/model_zoo/official/cv/resnet_thor/scripts/run_eval.sh similarity index 100% rename from model_zoo/resnet_thor/scripts/run_eval.sh rename to model_zoo/official/cv/resnet_thor/scripts/run_eval.sh diff --git a/model_zoo/resnet_thor/src/config.py b/model_zoo/official/cv/resnet_thor/src/config.py similarity index 100% rename from model_zoo/resnet_thor/src/config.py rename to model_zoo/official/cv/resnet_thor/src/config.py diff --git a/model_zoo/resnet_thor/src/crossentropy.py b/model_zoo/official/cv/resnet_thor/src/crossentropy.py similarity index 100% rename from model_zoo/resnet_thor/src/crossentropy.py rename to model_zoo/official/cv/resnet_thor/src/crossentropy.py diff --git a/model_zoo/resnet_thor/src/dataset_helper.py b/model_zoo/official/cv/resnet_thor/src/dataset_helper.py similarity index 100% rename from model_zoo/resnet_thor/src/dataset_helper.py rename to model_zoo/official/cv/resnet_thor/src/dataset_helper.py diff --git a/model_zoo/resnet_thor/src/dataset_imagenet.py b/model_zoo/official/cv/resnet_thor/src/dataset_imagenet.py similarity index 100% rename from model_zoo/resnet_thor/src/dataset_imagenet.py rename to model_zoo/official/cv/resnet_thor/src/dataset_imagenet.py diff --git a/model_zoo/resnet_thor/src/grad_reducer_thor.py b/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py similarity index 100% rename from model_zoo/resnet_thor/src/grad_reducer_thor.py rename to model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py diff --git a/model_zoo/resnet_thor/src/model_thor.py b/model_zoo/official/cv/resnet_thor/src/model_thor.py similarity index 100% rename from model_zoo/resnet_thor/src/model_thor.py rename to model_zoo/official/cv/resnet_thor/src/model_thor.py diff --git a/model_zoo/resnet_thor/src/resnet50.py b/model_zoo/official/cv/resnet_thor/src/resnet50.py similarity index 100% rename from model_zoo/resnet_thor/src/resnet50.py rename to model_zoo/official/cv/resnet_thor/src/resnet50.py diff --git a/model_zoo/resnet_thor/src/resnet_thor.py b/model_zoo/official/cv/resnet_thor/src/resnet_thor.py similarity index 100% rename from model_zoo/resnet_thor/src/resnet_thor.py rename to model_zoo/official/cv/resnet_thor/src/resnet_thor.py diff --git a/model_zoo/resnet_thor/src/thor.py b/model_zoo/official/cv/resnet_thor/src/thor.py similarity index 100% rename from model_zoo/resnet_thor/src/thor.py rename to model_zoo/official/cv/resnet_thor/src/thor.py diff --git a/model_zoo/resnet_thor/src/thor_layer.py b/model_zoo/official/cv/resnet_thor/src/thor_layer.py similarity index 100% rename from model_zoo/resnet_thor/src/thor_layer.py rename to model_zoo/official/cv/resnet_thor/src/thor_layer.py diff --git a/model_zoo/official/cv/resnet_thor/train.py b/model_zoo/official/cv/resnet_thor/train.py new file mode 100644 index 0000000000..b6a84fe136 --- /dev/null +++ b/model_zoo/official/cv/resnet_thor/train.py @@ -0,0 +1,132 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_imagenet.""" +import argparse +import os +import random + +import numpy as np + +from mindspore import Tensor +from mindspore import context +from mindspore.communication.management import init +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.model import ParallelMode +from src.model_thor import Model +from src.resnet_thor import resnet50 +from src.thor import THOR +from src.config import config +from src.crossentropy import CrossEntropy +from src.dataset_imagenet import create_dataset + +random.seed(1) +np.random.seed(1) + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.') +parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') + +args_opt = parser.parse_args() +device_id = int(os.getenv('DEVICE_ID')) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) + + +def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch): + """get_model_lr""" + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + for i in range(total_steps): + epoch = (i + 1) / steps_per_epoch + base = (1.0 - float(epoch) / total_epochs) ** decay + lr_local = lr_init * base + if epoch >= 39: + lr_local = lr_local * 0.5 + if epoch >= 40: + lr_local = lr_local * 0.5 + lr_each_step.append(lr_local) + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + return learning_rate + + +def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch): + """get_model_damping""" + damping_each_step = [] + total_steps = steps_per_epoch * total_epochs + for step in range(total_steps): + epoch = (step + 1) / steps_per_epoch + damping_here = damping_init * (decay_rate ** (epoch / 10)) + damping_each_step.append(damping_here) + + current_step = global_step + damping_each_step = np.array(damping_each_step).astype(np.float32) + damping_now = damping_each_step[current_step:] + return damping_now + + +if __name__ == '__main__': + if not args_opt.do_eval and args_opt.run_distribute: + context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True, parameter_broadcast=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") + + init() + + epoch_size = config.epoch_size + damping = get_model_damping(0, 0.03, 0.87, 50, 5004) + net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale, + frequency=config.frequency) + + if not config.label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + if args_opt.do_train: + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, + batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004)) + opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, + filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), + filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), + filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()), + filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()), + config.weight_decay, config.loss_scale) + + model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale, + keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency) + + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossMonitor() + cb = [time_cb, loss_cb] + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck) + cb += [ckpt_cb] + + model.train(epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/official/cv/resnext50/README.md b/model_zoo/official/cv/resnext50/README.md new file mode 100644 index 0000000000..6119fd7913 --- /dev/null +++ b/model_zoo/official/cv/resnext50/README.md @@ -0,0 +1,134 @@ +# ResNext50 Example + +## Description + +This is an example of training ResNext50 in MindSpore. + +## Requirements + +- Install [Mindspore](http://www.mindspore.cn/install/en). +- Downlaod the dataset. + +## Structure + +```shell +. +└─resnext50 + ├─README.md + ├─scripts + ├─run_standalone_train.sh # launch standalone training(1p) + ├─run_distribute_train.sh # launch distributed training(8p) + └─run_eval.sh # launch evaluating + ├─src + ├─backbone + ├─_init_.py # initalize + ├─resnet.py # resnext50 backbone + ├─utils + ├─_init_.py # initalize + ├─cunstom_op.py # network operation + ├─logging.py # print log + ├─optimizers_init_.py # get parameters + ├─sampler.py # distributed sampler + ├─var_init_.py # calculate gain value + ├─_init_.py # initalize + ├─config.py # parameter configuration + ├─crossentropy.py # CrossEntropy loss function + ├─dataset.py # data preprocessing + ├─head.py # commom head + ├─image_classification.py # get resnet + ├─linear_warmup.py # linear warmup learning rate + ├─warmup_cosine_annealing.py # learning rate each step + ├─warmup_step_lr.py # warmup step learning rate + ├─eval.py # eval net + └─train.py # train net + +``` + +## Parameter Configuration + +Parameters for both training and evaluating can be set in config.py + +``` +"image_height": '224,224' # image size +"num_classes": 1000, # dataset class number +"per_batch_size": 128, # batch size of input tensor +"lr": 0.05, # base learning rate +"lr_scheduler": 'cosine_annealing', # learning rate mode +"lr_epochs": '30,60,90,120', # epoch of lr changing +"lr_gamma": 0.1, # decrease lr by a factor of exponential lr_scheduler +"eta_min": 0, # eta_min in cosine_annealing scheduler +"T_max": 150, # T-max in cosine_annealing scheduler +"max_epoch": 150, # max epoch num to train the model +"backbone": 'resnext50', # backbone metwork +"warmup_epochs" : 1, # warmup epoch +"weight_decay": 0.0001, # weight decay +"momentum": 0.9, # momentum +"is_dynamic_loss_scale": 0, # dynamic loss scale +"loss_scale": 1024, # loss scale +"label_smooth": 1, # label_smooth +"label_smooth_factor": 0.1, # label_smooth_factor +"ckpt_interval": 2000, # ckpt_interval +"ckpt_path": 'outputs/', # checkpoint save location +"is_save_on_master": 1, +"rank": 0, # local rank of distributed +"group_size": 1 # world size of distributed +``` + +## Running the example + +### Train + +#### Usage + +``` +# distribute training example(8p) +sh run_distribute_train.sh RANK_TABLE_FILE DATA_PATH +# standalone training +sh run_standalone_train.sh DEVICE_ID DATA_PATH +``` + +#### Launch + +```bash +# distributed training example(8p) for Ascend +sh scripts/run_distribute_train.sh RANK_TABLE_FILE /dataset/train +# standalone training example for Ascend +sh scripts/run_standalone_train.sh 0 /dataset/train + +# distributed training example(8p) for GPU +sh scripts/run_distribute_train_for_gpu.sh /dataset/train +# standalone training example for GPU +sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train +``` + +#### Result + +You can find checkpoint file together with result in log. + +### Evaluation + +#### Usage + +``` +# Evaluation +sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM +``` +PLATFORM is Ascend or GPU, default is Ascend. + +#### Launch + +```bash +# Evaluation with checkpoint +sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt Ascend +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. + +``` +acc=78.16%(TOP1) +acc=93.88%(TOP5) +``` \ No newline at end of file diff --git a/model_zoo/official/cv/resnext50/eval.py b/model_zoo/official/cv/resnext50/eval.py new file mode 100644 index 0000000000..4dc2aa485a --- /dev/null +++ b/model_zoo/official/cv/resnext50/eval.py @@ -0,0 +1,254 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Eval""" +import os +import time +import argparse +import datetime +import glob +import numpy as np +import mindspore.nn as nn + +from mindspore import Tensor, context +from mindspore.communication.management import init, get_rank, get_group_size, release +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + +from src.utils.logging import get_logger +from src.utils.auto_mixed_precision import auto_mixed_precision +from src.image_classification import get_network +from src.dataset import classification_dataset +from src.config import config + + +class ParameterReduce(nn.Cell): + """ParameterReduce""" + def __init__(self): + super(ParameterReduce, self).__init__() + self.cast = P.Cast() + self.reduce = P.AllReduce() + + def construct(self, x): + one = self.cast(F.scalar_to_array(1.0), mstype.float32) + out = x * one + ret = self.reduce(out) + return ret + + +def parse_args(cloud_args=None): + """parse_args""" + parser = argparse.ArgumentParser('mindspore classification test') + parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform') + + # dataset related + parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir') + parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') + # network related + parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') + parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. ' + 'If it is a direction, it will test all ckpt') + + # logging related + parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log') + parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') + + # roma obs + parser.add_argument('--train_url', type=str, default="", help='train url') + + args, _ = parser.parse_known_args() + args = merge_args(args, cloud_args) + args.image_size = config.image_size + args.num_classes = config.num_classes + args.backbone = config.backbone + args.rank = config.rank + args.group_size = config.group_size + + args.image_size = list(map(int, args.image_size.split(','))) + + return args + + +def get_top5_acc(top5_arg, gt_class): + sub_count = 0 + for top5, gt in zip(top5_arg, gt_class): + if gt in top5: + sub_count += 1 + return sub_count + +def merge_args(args, cloud_args): + """merge_args""" + args_dict = vars(args) + if isinstance(cloud_args, dict): + for key in cloud_args.keys(): + val = cloud_args[key] + if key in args_dict and val: + arg_type = type(args_dict[key]) + if arg_type is not type(None): + val = arg_type(val) + args_dict[key] = val + return args + +def test(cloud_args=None): + """test""" + args = parse_args(cloud_args) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.platform, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + + # init distributed + if args.is_distributed: + if args.platform == "Ascend": + init() + elif args.platform == "GPU": + init("nccl") + args.rank = get_rank() + args.group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, + parameter_broadcast=True, mirror_mean=True) + else: + args.rank = 0 + args.group_size = 1 + + args.outputs_dir = os.path.join(args.log_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + + args.logger = get_logger(args.outputs_dir, args.rank) + args.logger.save_args(args) + + # network + args.logger.important_info('start create network') + if os.path.isdir(args.pretrained): + models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt'))) + print(models) + if args.graph_ckpt: + f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]) + else: + f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1]) + args.models = sorted(models, key=f) + else: + args.models = [args.pretrained,] + + for model in args.models: + de_dataset = classification_dataset(args.data_dir, image_size=args.image_size, + per_batch_size=args.per_batch_size, + max_epoch=1, rank=args.rank, group_size=args.group_size, + mode='eval') + eval_dataloader = de_dataset.create_tuple_iterator() + network = get_network(args.backbone, args.num_classes, platform=args.platform) + if network is None: + raise NotImplementedError('not implement {}'.format(args.backbone)) + + param_dict = load_checkpoint(model) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.'): + continue + elif key.startswith('network.'): + param_dict_new[key[8:]] = values + else: + param_dict_new[key] = values + + load_param_into_net(network, param_dict_new) + args.logger.info('load model {} success'.format(model)) + + img_tot = 0 + top1_correct = 0 + top5_correct = 0 + if args.platform == "Ascend": + network.to_float(mstype.float16) + else: + auto_mixed_precision(network) + network.set_train(False) + t_end = time.time() + it = 0 + for data, gt_classes in eval_dataloader: + output = network(Tensor(data, mstype.float32)) + output = output.asnumpy() + + top1_output = np.argmax(output, (-1)) + top5_output = np.argsort(output)[:, -5:] + + t1_correct = np.equal(top1_output, gt_classes).sum() + top1_correct += t1_correct + top5_correct += get_top5_acc(top5_output, gt_classes) + img_tot += args.per_batch_size + + if args.rank == 0 and it == 0: + t_end = time.time() + it = 1 + if args.rank == 0: + time_used = time.time() - t_end + fps = (img_tot - args.per_batch_size) * args.group_size / time_used + args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps)) + results = [[top1_correct], [top5_correct], [img_tot]] + args.logger.info('before results={}'.format(results)) + if args.is_distributed: + model_md5 = model.replace('/', '') + tmp_dir = '/cache' + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5) + top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5) + img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5) + np.save(top1_correct_npy, top1_correct) + np.save(top5_correct_npy, top5_correct) + np.save(img_tot_npy, img_tot) + while True: + rank_ok = True + for other_rank in range(args.group_size): + top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5) + top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5) + img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5) + if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \ + not os.path.exists(img_tot_npy): + rank_ok = False + if rank_ok: + break + + top1_correct_all = 0 + top5_correct_all = 0 + img_tot_all = 0 + for other_rank in range(args.group_size): + top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5) + top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5) + img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5) + top1_correct_all += np.load(top1_correct_npy) + top5_correct_all += np.load(top5_correct_npy) + img_tot_all += np.load(img_tot_npy) + results = [[top1_correct_all], [top5_correct_all], [img_tot_all]] + results = np.array(results) + else: + results = np.array(results) + + args.logger.info('after results={}'.format(results)) + top1_correct = results[0, 0] + top5_correct = results[1, 0] + img_tot = results[2, 0] + acc1 = 100.0 * top1_correct / img_tot + acc5 = 100.0 * top5_correct / img_tot + args.logger.info('after allreduce eval: top1_correct={}, tot={},' + 'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1)) + args.logger.info('after allreduce eval: top5_correct={}, tot={},' + 'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5)) + if args.is_distributed: + release() + + +if __name__ == "__main__": + test() diff --git a/model_zoo/resnext50/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnext50/scripts/run_distribute_train.sh similarity index 100% rename from model_zoo/resnext50/scripts/run_distribute_train.sh rename to model_zoo/official/cv/resnext50/scripts/run_distribute_train.sh diff --git a/model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh new file mode 100644 index 0000000000..6ab980a0fa --- /dev/null +++ b/model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +DATA_DIR=$1 +export RANK_SIZE=8 +PATH_CHECKPOINT="" +if [ $# == 2 ] +then + PATH_CHECKPOINT=$2 +fi + +mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py \ + --is_distribute=1 \ + --platform="GPU" \ + --pretrained=$PATH_CHECKPOINT \ + --data_dir=$DATA_DIR > log.txt 2>&1 & diff --git a/model_zoo/official/cv/resnext50/scripts/run_eval.sh b/model_zoo/official/cv/resnext50/scripts/run_eval.sh new file mode 100644 index 0000000000..c884180950 --- /dev/null +++ b/model_zoo/official/cv/resnext50/scripts/run_eval.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT=$3 +PLATFORM=Ascend +if [ $# == 4 ] +then + PLATFORM=$4 +fi + +python eval.py \ + --pretrained=$PATH_CHECKPOINT \ + --platform=$PLATFORM \ + --data_dir=$DATA_DIR > log.txt 2>&1 & diff --git a/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh b/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh new file mode 100644 index 0000000000..f10d7a2f57 --- /dev/null +++ b/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi + +python train.py \ + --is_distribute=0 \ + --device_id=$DEVICE_ID \ + --pretrained=$PATH_CHECKPOINT \ + --data_dir=$DATA_DIR > log.txt 2>&1 & + diff --git a/model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh new file mode 100644 index 0000000000..1d1d82fb88 --- /dev/null +++ b/model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi + +python train.py \ + --is_distribute=0 \ + --pretrained=$PATH_CHECKPOINT \ + --platform="GPU" \ + --data_dir=$DATA_DIR > log.txt 2>&1 & + diff --git a/model_zoo/official/cv/resnext50/src/__init__.py b/model_zoo/official/cv/resnext50/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/resnext50/src/backbone/__init__.py b/model_zoo/official/cv/resnext50/src/backbone/__init__.py similarity index 100% rename from model_zoo/resnext50/src/backbone/__init__.py rename to model_zoo/official/cv/resnext50/src/backbone/__init__.py diff --git a/model_zoo/official/cv/resnext50/src/backbone/resnet.py b/model_zoo/official/cv/resnext50/src/backbone/resnet.py new file mode 100644 index 0000000000..9c880154ea --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/backbone/resnet.py @@ -0,0 +1,279 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +ResNet based ResNext +""" +import mindspore.nn as nn +from mindspore.ops.operations import TensorAdd, Split, Concat +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal + +from src.utils.cunstom_op import SEBlock, GroupConv + + +__all__ = ['ResNet', 'resnext50'] + + +def weight_variable(shape, factor=0.1): + return TruncatedNormal(0.02) + + +def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False, groups=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias, + padding=padding, pad_mode="pad", group=groups) + + +def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False, groups=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias, + padding=padding, pad_mode="pad", group=groups) + + +def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False, groups=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias, + padding=padding, pad_mode="pad", group=groups) + + +class _DownSample(nn.Cell): + """ + Downsample for ResNext-ResNet. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + stride (int): Stride size for the 1*1 convolutional layer. + + Returns: + Tensor, output tensor. + + Examples: + >>>DownSample(32, 64, 2) + """ + def __init__(self, in_channels, out_channels, stride): + super(_DownSample, self).__init__() + self.conv = conv1x1(in_channels, out_channels, stride=stride, padding=0) + self.bn = nn.BatchNorm2d(out_channels) + + def construct(self, x): + out = self.conv(x) + out = self.bn(out) + return out + +class BasicBlock(nn.Cell): + """ + ResNet basic block definition. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>>BasicBlock(32, 256, stride=2) + """ + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, + platform="Ascend", **kwargs): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_channels, out_channels, stride=stride) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = P.ReLU() + self.conv2 = conv3x3(out_channels, out_channels, stride=1) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.use_se = use_se + if self.use_se: + self.se = SEBlock(out_channels) + + self.down_sample_flag = False + if down_sample is not None: + self.down_sample = down_sample + self.down_sample_flag = True + + self.add = TensorAdd() + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.use_se: + out = self.se(out) + + if self.down_sample_flag: + identity = self.down_sample(x) + + out = self.add(out, identity) + out = self.relu(out) + return out + +class Bottleneck(nn.Cell): + """ + ResNet Bottleneck block definition. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + stride (int): Stride size for the initial convolutional layer. Default: 1. + + Returns: + Tensor, the ResNet unit's output. + + Examples: + >>>Bottleneck(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, in_channels, out_channels, stride=1, down_sample=None, + base_width=64, groups=1, use_se=False, platform="Ascend", **kwargs): + super(Bottleneck, self).__init__() + + width = int(out_channels * (base_width / 64.0)) * groups + self.groups = groups + self.conv1 = conv1x1(in_channels, width, stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.relu = P.ReLU() + + self.conv3x3s = nn.CellList() + + if platform == "GPU": + self.conv2 = nn.Conv2d(width, width, 3, stride, pad_mode='pad', padding=1, group=groups) + else: + self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups) + + self.op_split = Split(axis=1, output_num=self.groups) + self.op_concat = Concat(axis=1) + + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = conv1x1(width, out_channels * self.expansion, stride=1) + self.bn3 = nn.BatchNorm2d(out_channels * self.expansion) + + self.use_se = use_se + if self.use_se: + self.se = SEBlock(out_channels * self.expansion) + + self.down_sample_flag = False + if down_sample is not None: + self.down_sample = down_sample + self.down_sample_flag = True + + self.cast = P.Cast() + self.add = TensorAdd() + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.use_se: + out = self.se(out) + + if self.down_sample_flag: + identity = self.down_sample(x) + + out = self.add(out, identity) + out = self.relu(out) + return out + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (cell): Block for network. + layers (list): Numbers of block in different layers. + width_per_group (int): Width of every group. + groups (int): Groups number. + + Returns: + Tuple, output tensor tuple. + + Examples: + >>>ResNet() + """ + def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False, platform="Ascend"): + super(ResNet, self).__init__() + self.in_channels = 64 + self.groups = groups + self.base_width = width_per_group + + self.conv = conv7x7(3, self.in_channels, stride=2, padding=3) + self.bn = nn.BatchNorm2d(self.in_channels) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + + self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se, platform=platform) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se, platform=platform) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se, platform=platform) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se, platform=platform) + + self.out_channels = 512 * block.expansion + self.cast = P.Cast() + + def construct(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False, platform="Ascend"): + """_make_layer""" + down_sample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + down_sample = _DownSample(self.in_channels, + out_channels * block.expansion, + stride=stride) + + layers = [] + layers.append(block(self.in_channels, + out_channels, + stride=stride, + down_sample=down_sample, + base_width=self.base_width, + groups=self.groups, + use_se=use_se, + platform=platform)) + self.in_channels = out_channels * block.expansion + for _ in range(1, blocks_num): + layers.append(block(self.in_channels, out_channels, base_width=self.base_width, + groups=self.groups, use_se=use_se, platform=platform)) + + return nn.SequentialCell(layers) + + def get_out_channels(self): + return self.out_channels + + +def resnext50(platform="Ascend"): + return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32, platform=platform) diff --git a/model_zoo/official/cv/resnext50/src/config.py b/model_zoo/official/cv/resnext50/src/config.py new file mode 100644 index 0000000000..0acff08342 --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/config.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""config""" +from easydict import EasyDict as ed + +config = ed({ + "image_size": '224,224', + "num_classes": 1000, + + "lr": 0.4, + "lr_scheduler": 'cosine_annealing', + "lr_epochs": '30,60,90,120', + "lr_gamma": 0.1, + "eta_min": 0, + "T_max": 150, + "max_epoch": 150, + "backbone": 'resnext50', + "warmup_epochs": 1, + + "weight_decay": 0.0001, + "momentum": 0.9, + "is_dynamic_loss_scale": 0, + "loss_scale": 1024, + "label_smooth": 1, + "label_smooth_factor": 0.1, + + "ckpt_interval": 5, + "ckpt_save_max": 5, + "ckpt_path": 'outputs/', + "is_save_on_master": 1, + + "rank": 0, + "group_size": 1 +}) diff --git a/model_zoo/resnext50/src/crossentropy.py b/model_zoo/official/cv/resnext50/src/crossentropy.py similarity index 100% rename from model_zoo/resnext50/src/crossentropy.py rename to model_zoo/official/cv/resnext50/src/crossentropy.py diff --git a/model_zoo/official/cv/resnext50/src/dataset.py b/model_zoo/official/cv/resnext50/src/dataset.py new file mode 100644 index 0000000000..66fc653c47 --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/dataset.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +dataset processing. +""" +import os +from mindspore.common import dtype as mstype +import mindspore.dataset as de +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as V_C +from PIL import Image, ImageFile +from src.utils.sampler import DistributedSampler + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +class TxtDataset(): + """ + create txt dataset. + + Args: + Returns: + de_dataset. + """ + def __init__(self, root, txt_name): + super(TxtDataset, self).__init__() + self.imgs = [] + self.labels = [] + fin = open(txt_name, "r") + for line in fin: + img_name, label = line.strip().split(' ') + self.imgs.append(os.path.join(root, img_name)) + self.labels.append(int(label)) + fin.close() + + def __getitem__(self, index): + img = Image.open(self.imgs[index]).convert('RGB') + return img, self.labels[index] + + def __len__(self): + return len(self.imgs) + + +def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size, + mode='train', + input_mode='folder', + root='', + num_parallel_workers=None, + shuffle=None, + sampler=None, + class_indexing=None, + drop_remainder=True, + transform=None, + target_transform=None): + """ + A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt". + If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images + are written into a textfile. + + Args: + data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"". + Or path of the textfile that contains every image's path of the dataset. + image_size (str): Size of the input images. + per_batch_size (int): the batch size of evey step during training. + max_epoch (int): the number of epochs. + rank (int): The shard ID within num_shards (default=None). + group_size (int): Number of shards that the dataset should be divided + into (default=None). + mode (str): "train" or others. Default: " train". + input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder". + root (str): the images path for "input_mode="txt"". Default: " ". + num_parallel_workers (int): Number of workers to read the data. Default: None. + shuffle (bool): Whether or not to perform shuffle on the dataset + (default=None, performs shuffle). + sampler (Sampler): Object used to choose samples from the dataset. Default: None. + class_indexing (dict): A str-to-int mapping from folder name to index + (default=None, the folder names will be sorted + alphabetically and each class will be given a + unique index starting from 0). + + Examples: + >>> from mindvision.common.datasets.classification import classification_dataset + >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images + >>> dataset_dir = "/path/to/imagefolder_directory" + >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], + >>> per_batch_size=64, max_epoch=100, + >>> rank=0, group_size=4) + >>> # Path of the textfile that contains every image's path of the dataset. + >>> dataset_dir = "/path/to/dataset/images/train.txt" + >>> images_dir = "/path/to/dataset/images" + >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], + >>> per_batch_size=64, max_epoch=100, + >>> rank=0, group_size=4, + >>> input_mode="txt", root=images_dir) + """ + + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + if transform is None: + if mode == 'train': + transform_img = [ + V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + V_C.RandomHorizontalFlip(prob=0.5), + V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4), + V_C.Normalize(mean=mean, std=std), + V_C.HWC2CHW() + ] + else: + transform_img = [ + V_C.Decode(), + V_C.Resize((256, 256)), + V_C.CenterCrop(image_size), + V_C.Normalize(mean=mean, std=std), + V_C.HWC2CHW() + ] + else: + transform_img = transform + + if target_transform is None: + transform_label = [C.TypeCast(mstype.int32)] + else: + transform_label = target_transform + + if input_mode == 'folder': + de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers, + shuffle=shuffle, sampler=sampler, class_indexing=class_indexing, + num_shards=group_size, shard_id=rank) + else: + dataset = TxtDataset(root, data_dir) + sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) + de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) + de_dataset.set_dataset_size(len(sampler)) + + de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers, + operations=transform_img) + de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers, + operations=transform_label) + + columns_to_project = ["image", "label"] + de_dataset = de_dataset.project(columns=columns_to_project) + + de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder) + de_dataset = de_dataset.repeat(max_epoch) + + return de_dataset diff --git a/model_zoo/resnext50/src/head.py b/model_zoo/official/cv/resnext50/src/head.py similarity index 100% rename from model_zoo/resnext50/src/head.py rename to model_zoo/official/cv/resnext50/src/head.py diff --git a/model_zoo/official/cv/resnext50/src/image_classification.py b/model_zoo/official/cv/resnext50/src/image_classification.py new file mode 100644 index 0000000000..6b12ede367 --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/image_classification.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Image classifiation. +""" +import math +import mindspore.nn as nn +from mindspore.common import initializer as init +import src.backbone as backbones +import src.head as heads +from src.utils.var_init import default_recurisive_init, KaimingNormal + + +class ImageClassificationNetwork(nn.Cell): + """ + architecture of image classification network. + + Args: + Returns: + Tensor, output tensor. + """ + def __init__(self, backbone, head): + super(ImageClassificationNetwork, self).__init__() + self.backbone = backbone + self.head = head + + def construct(self, x): + x = self.backbone(x) + x = self.head(x) + return x + +class Resnet(ImageClassificationNetwork): + """ + Resnet architecture. + Args: + backbone_name (string): backbone. + num_classes (int): number of classes. + Returns: + Resnet. + """ + def __init__(self, backbone_name, num_classes, platform="Ascend"): + self.backbone_name = backbone_name + backbone = backbones.__dict__[self.backbone_name](platform=platform) + out_channels = backbone.get_out_channels() + head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) + super(Resnet, self).__init__(backbone, head) + + default_recurisive_init(self) + + for cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer( + KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), + cell.weight.shape, cell.weight.dtype) + elif isinstance(cell, nn.BatchNorm2d): + cell.gamma.default_input = init.initializer('ones', cell.gamma.shape) + cell.beta.default_input = init.initializer('zeros', cell.beta.shape) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + for cell in self.cells_and_names(): + if isinstance(cell, backbones.resnet.Bottleneck): + cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.shape) + elif isinstance(cell, backbones.resnet.BasicBlock): + cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.shape) + + + +def get_network(backbone_name, num_classes, platform="Ascend"): + if backbone_name in ['resnext50']: + return Resnet(backbone_name, num_classes, platform) + return None diff --git a/model_zoo/resnext50/src/linear_warmup.py b/model_zoo/official/cv/resnext50/src/linear_warmup.py similarity index 100% rename from model_zoo/resnext50/src/linear_warmup.py rename to model_zoo/official/cv/resnext50/src/linear_warmup.py diff --git a/model_zoo/official/cv/resnext50/src/utils/__init__.py b/model_zoo/official/cv/resnext50/src/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py new file mode 100644 index 0000000000..f8e27f5b52 --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Auto mixed precision.""" +import mindspore.nn as nn +from mindspore.ops import functional as F +from mindspore._checkparam import Validator as validator +from mindspore.common import dtype as mstype + + +class OutputTo(nn.Cell): + "Cast cell output back to float16 or float32" + + def __init__(self, op, to_type=mstype.float16): + super(OutputTo, self).__init__(auto_prefix=False) + self._op = op + validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None) + self.to_type = to_type + + def construct(self, x): + return F.cast(self._op(x), self.to_type) + + +def auto_mixed_precision(network): + """Do keep batchnorm fp32.""" + cells = network.name_cells() + change = False + network.to_float(mstype.float16) + for name in cells: + subcell = cells[name] + if subcell == network: + continue + elif name == 'fc': + network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32)) + change = True + elif name == 'conv2': + subcell.to_float(mstype.float32) + change = True + elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)): + network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16)) + change = True + else: + auto_mixed_precision(subcell) + if isinstance(network, nn.SequentialCell) and change: + network.cell_list = list(network.cells()) diff --git a/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py b/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py new file mode 100644 index 0000000000..f4062821ef --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py @@ -0,0 +1,104 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network operations +""" +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + + +class GlobalAvgPooling(nn.Cell): + """ + global average pooling feature map. + + Args: + mean (tuple): means for each channel. + """ + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class SEBlock(nn.Cell): + """ + squeeze and excitation block. + + Args: + channel (int): number of feature maps. + reduction (int): weight. + """ + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + + self.avg_pool = GlobalAvgPooling() + self.fc1 = nn.Dense(channel, channel // reduction) + self.relu = P.ReLU() + self.fc2 = nn.Dense(channel // reduction, channel) + self.sigmoid = P.Sigmoid() + self.reshape = P.Reshape() + self.shape = P.Shape() + self.sum = P.Sum() + self.cast = P.Cast() + + def construct(self, x): + b, c = self.shape(x) + y = self.avg_pool(x) + + y = self.reshape(y, (b, c)) + y = self.fc1(y) + y = self.relu(y) + y = self.fc2(y) + y = self.sigmoid(y) + y = self.reshape(y, (b, c, 1, 1)) + return x * y + +class GroupConv(nn.Cell): + """ + group convolution operation. + + Args: + in_channels (int): Input channels of feature map. + out_channels (int): Output channels of feature map. + kernel_size (int): Size of convolution kernel. + stride (int): Stride size for the group convolution layer. + + Returns: + tensor, output tensor. + """ + def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False): + super(GroupConv, self).__init__() + assert in_channels % groups == 0 and out_channels % groups == 0 + self.groups = groups + self.convs = nn.CellList() + self.op_split = P.Split(axis=1, output_num=self.groups) + self.op_concat = P.Concat(axis=1) + self.cast = P.Cast() + for _ in range(groups): + self.convs.append(nn.Conv2d(in_channels//groups, out_channels//groups, + kernel_size=kernel_size, stride=stride, has_bias=has_bias, + padding=pad, pad_mode=pad_mode, group=1)) + + def construct(self, x): + features = self.op_split(x) + outputs = () + for i in range(self.groups): + outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),) + out = self.op_concat(outputs) + return out diff --git a/model_zoo/resnext50/src/utils/logging.py b/model_zoo/official/cv/resnext50/src/utils/logging.py similarity index 100% rename from model_zoo/resnext50/src/utils/logging.py rename to model_zoo/official/cv/resnext50/src/utils/logging.py diff --git a/model_zoo/resnext50/src/utils/optimizers__init__.py b/model_zoo/official/cv/resnext50/src/utils/optimizers__init__.py similarity index 100% rename from model_zoo/resnext50/src/utils/optimizers__init__.py rename to model_zoo/official/cv/resnext50/src/utils/optimizers__init__.py diff --git a/model_zoo/resnext50/src/utils/sampler.py b/model_zoo/official/cv/resnext50/src/utils/sampler.py similarity index 100% rename from model_zoo/resnext50/src/utils/sampler.py rename to model_zoo/official/cv/resnext50/src/utils/sampler.py diff --git a/model_zoo/official/cv/resnext50/src/utils/var_init.py b/model_zoo/official/cv/resnext50/src/utils/var_init.py new file mode 100644 index 0000000000..185072d441 --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/utils/var_init.py @@ -0,0 +1,214 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Initialize. +""" +import math +from functools import reduce +import numpy as np +import mindspore.nn as nn +from mindspore.common import initializer as init + +def _calculate_gain(nonlinearity, param=None): + r""" + Return the recommended gain value for the given nonlinearity function. + + The values are as follows: + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function + param: optional parameter for the non-linear function + + Examples: + >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + +def _assignment(arr, num): + """Assign the value of `num` to `arr`.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + +def _calculate_in_and_out(arr): + """ + Calculate n_in and n_out. + + Args: + arr (Array): Input array. + + Returns: + Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. + """ + dim = len(arr.shape) + if dim < 2: + raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") + + n_in = arr.shape[1] + n_out = arr.shape[0] + + if dim > 2: + counter = reduce(lambda x, y: x * y, arr.shape[2:]) + n_in *= counter + n_out *= counter + return n_in, n_out + +def _select_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_in_and_out(array) + return fan_in if mode == 'fan_in' else fan_out + +class KaimingInit(init.Initializer): + r""" + Base Class. Initialize the array with He kaiming algorithm. + + Args: + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function, recommended to use only with + ``'relu'`` or ``'leaky_relu'`` (default). + """ + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingInit, self).__init__() + self.mode = mode + self.gain = _calculate_gain(nonlinearity, a) + def _initialize(self, arr): + pass + + +class KaimingUniform(KaimingInit): + r""" + Initialize the array with He kaiming uniform algorithm. The resulting tensor will + have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Input: + arr (Array): The array to be assigned. + + Returns: + Array, assigned array. + + Examples: + >>> w = np.empty(3, 5) + >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') + """ + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) + np.random.seed(0) + data = np.random.uniform(-bound, bound, arr.shape) + + _assignment(arr, data) + + +class KaimingNormal(KaimingInit): + r""" + Initialize the array with He kaiming normal algorithm. The resulting tensor will + have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Input: + arr (Array): The array to be assigned. + + Returns: + Array, assigned array. + + Examples: + >>> w = np.empty(3, 5) + >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') + """ + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + std = self.gain / math.sqrt(fan) + np.random.seed(0) + data = np.random.normal(0, std, arr.shape) + + _assignment(arr, data) + + +def default_recurisive_init(custom_cell): + """default_recurisive_init""" + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype) + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + np.random.seed(0) + cell.bias.default_input = init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype) + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + np.random.seed(0) + cell.bias.default_input = init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass diff --git a/model_zoo/resnext50/src/warmup_cosine_annealing_lr.py b/model_zoo/official/cv/resnext50/src/warmup_cosine_annealing_lr.py similarity index 100% rename from model_zoo/resnext50/src/warmup_cosine_annealing_lr.py rename to model_zoo/official/cv/resnext50/src/warmup_cosine_annealing_lr.py diff --git a/model_zoo/resnext50/src/warmup_step_lr.py b/model_zoo/official/cv/resnext50/src/warmup_step_lr.py similarity index 100% rename from model_zoo/resnext50/src/warmup_step_lr.py rename to model_zoo/official/cv/resnext50/src/warmup_step_lr.py diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py new file mode 100644 index 0000000000..6b0eaae03b --- /dev/null +++ b/model_zoo/official/cv/resnext50/train.py @@ -0,0 +1,294 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train ImageNet.""" +import os +import time +import argparse +import datetime + +import mindspore.nn as nn +from mindspore import Tensor, context +from mindspore import ParallelMode +from mindspore.nn.optim import Momentum +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.train.callback import ModelCheckpoint +from mindspore.train.callback import CheckpointConfig, Callback +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.model import Model +from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager + +from src.dataset import classification_dataset +from src.crossentropy import CrossEntropy +from src.warmup_step_lr import warmup_step_lr +from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr +from src.utils.logging import get_logger +from src.utils.optimizers__init__ import get_param_groups +from src.image_classification import get_network +from src.utils.auto_mixed_precision import auto_mixed_precision +from src.config import config + + +class BuildTrainNetwork(nn.Cell): + """build training network""" + def __init__(self, network, criterion): + super(BuildTrainNetwork, self).__init__() + self.network = network + self.criterion = criterion + + def construct(self, input_data, label): + output = self.network(input_data) + loss = self.criterion(output, label) + return loss + +class ProgressMonitor(Callback): + """monitor loss and time""" + def __init__(self, args): + super(ProgressMonitor, self).__init__() + self.me_epoch_start_time = 0 + self.me_epoch_start_step_num = 0 + self.args = args + self.ckpt_history = [] + + def begin(self, run_context): + self.args.logger.info('start network train...') + + def epoch_begin(self, run_context): + pass + + def epoch_end(self, run_context, *me_args): + cb_params = run_context.original_args() + me_step = cb_params.cur_step_num - 1 + + real_epoch = me_step // self.args.steps_per_epoch + time_used = time.time() - self.me_epoch_start_time + fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used + self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}' + 'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean)) + + if self.args.rank_save_ckpt_flag: + import glob + ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt')) + for ckpt in ckpts: + ckpt_fn = os.path.basename(ckpt) + if not ckpt_fn.startswith('{}-'.format(self.args.rank)): + continue + if ckpt in self.ckpt_history: + continue + self.ckpt_history.append(ckpt) + self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},' + 'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn)) + + + self.me_epoch_start_step_num = me_step + self.me_epoch_start_time = time.time() + + def step_begin(self, run_context): + pass + + def step_end(self, run_context, *me_args): + pass + + def end(self, run_context): + self.args.logger.info('end network train...') + + +def parse_args(cloud_args=None): + """parameters""" + parser = argparse.ArgumentParser('mindspore classification training') + parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform') + + # dataset related + parser.add_argument('--data_dir', type=str, default='', help='train data dir') + parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu') + # network related + parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load') + + # distributed related + parser.add_argument('--is_distributed', type=int, default=1, help='if multi device') + # roma obs + parser.add_argument('--train_url', type=str, default="", help='train url') + + args, _ = parser.parse_known_args() + args = merge_args(args, cloud_args) + args.image_size = config.image_size + args.num_classes = config.num_classes + args.lr = config.lr + args.lr_scheduler = config.lr_scheduler + args.lr_epochs = config.lr_epochs + args.lr_gamma = config.lr_gamma + args.eta_min = config.eta_min + args.T_max = config.T_max + args.max_epoch = config.max_epoch + args.backbone = config.backbone + args.warmup_epochs = config.warmup_epochs + args.weight_decay = config.weight_decay + args.momentum = config.momentum + args.is_dynamic_loss_scale = config.is_dynamic_loss_scale + args.loss_scale = config.loss_scale + args.label_smooth = config.label_smooth + args.label_smooth_factor = config.label_smooth_factor + args.ckpt_interval = config.ckpt_interval + args.ckpt_save_max = config.ckpt_save_max + args.ckpt_path = config.ckpt_path + args.is_save_on_master = config.is_save_on_master + args.rank = config.rank + args.group_size = config.group_size + args.lr_epochs = list(map(int, args.lr_epochs.split(','))) + args.image_size = list(map(int, args.image_size.split(','))) + + return args + +def merge_args(args, cloud_args): + """dictionary""" + args_dict = vars(args) + if isinstance(cloud_args, dict): + for key in cloud_args.keys(): + val = cloud_args[key] + if key in args_dict and val: + arg_type = type(args_dict[key]) + if arg_type is not type(None): + val = arg_type(val) + args_dict[key] = val + return args + +def train(cloud_args=None): + """training process""" + args = parse_args(cloud_args) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.platform, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + + # init distributed + if args.is_distributed: + if args.platform == "Ascend": + init() + else: + init("nccl") + args.rank = get_rank() + args.group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, + parameter_broadcast=True, mirror_mean=True) + else: + args.rank = 0 + args.group_size = 1 + + if args.is_dynamic_loss_scale == 1: + args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt + + # select for master rank save ckpt or all rank save, compatiable for model parallel + args.rank_save_ckpt_flag = 0 + if args.is_save_on_master: + if args.rank == 0: + args.rank_save_ckpt_flag = 1 + else: + args.rank_save_ckpt_flag = 1 + + # logger + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + args.logger = get_logger(args.outputs_dir, args.rank) + + # dataloader + de_dataset = classification_dataset(args.data_dir, args.image_size, + args.per_batch_size, 1, + args.rank, args.group_size, num_parallel_workers=8) + de_dataset.map_model = 4 # !!!important + args.steps_per_epoch = de_dataset.get_dataset_size() + + args.logger.save_args(args) + + # network + args.logger.important_info('start create network') + # get network and init + network = get_network(args.backbone, args.num_classes, platform=args.platform) + if network is None: + raise NotImplementedError('not implement {}'.format(args.backbone)) + + # load pretrain model + if os.path.isfile(args.pretrained): + param_dict = load_checkpoint(args.pretrained) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.'): + continue + elif key.startswith('network.'): + param_dict_new[key[8:]] = values + else: + param_dict_new[key] = values + load_param_into_net(network, param_dict_new) + args.logger.info('load model {} success'.format(args.pretrained)) + + # lr scheduler + if args.lr_scheduler == 'exponential': + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma, + ) + elif args.lr_scheduler == 'cosine_annealing': + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + else: + raise NotImplementedError(args.lr_scheduler) + + # optimizer + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + + + # loss + if not args.label_smooth: + args.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) + + if args.is_dynamic_loss_scale == 1: + loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) + else: + loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) + + if args.platform == "Ascend": + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, + metrics={'acc'}, amp_level="O3") + else: + auto_mixed_precision(network) + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'}) + + # checkpoint save + progress_cb = ProgressMonitor(args) + callbacks = [progress_cb,] + if args.rank_save_ckpt_flag: + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch, + keep_checkpoint_max=args.ckpt_save_max) + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=args.outputs_dir, + prefix='{}'.format(args.rank)) + callbacks.append(ckpt_cb) + + model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True) + + +if __name__ == "__main__": + train() diff --git a/model_zoo/ssd/README.md b/model_zoo/official/cv/ssd/README.md similarity index 100% rename from model_zoo/ssd/README.md rename to model_zoo/official/cv/ssd/README.md diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py new file mode 100644 index 0000000000..37b5092206 --- /dev/null +++ b/model_zoo/official/cv/ssd/eval.py @@ -0,0 +1,109 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Evaluation for SSD""" + +import os +import argparse +import time +import numpy as np +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.ssd import SSD300, ssd_mobilenet_v2 +from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord +from src.config import config +from src.coco_eval import metrics + +def ssd_eval(dataset_path, ckpt_path): + """SSD evaluation.""" + batch_size = 1 + ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) + net = SSD300(ssd_mobilenet_v2(), config, is_training=False) + print("Load Checkpoint!") + param_dict = load_checkpoint(ckpt_path) + net.init_parameters_data() + load_param_into_net(net, param_dict) + + net.set_train(False) + i = batch_size + total = ds.get_dataset_size() * batch_size + start = time.time() + pred_data = [] + print("\n========================================\n") + print("total images num: ", total) + print("Processing, please wait a moment.") + for data in ds.create_dict_iterator(): + img_id = data['img_id'] + img_np = data['image'] + image_shape = data['image_shape'] + + output = net(Tensor(img_np)) + for batch_idx in range(img_np.shape[0]): + pred_data.append({"boxes": output[0].asnumpy()[batch_idx], + "box_scores": output[1].asnumpy()[batch_idx], + "img_id": int(np.squeeze(img_id[batch_idx])), + "image_shape": image_shape[batch_idx]}) + percent = round(i / total * 100., 2) + + print(f' {str(percent)} [{i}/{total}]', end='\r') + i += batch_size + cost_time = int((time.time() - start) * 1000) + print(f' 100% [{total}/{total}] cost {cost_time} ms') + mAP = metrics(pred_data) + print("\n========================================\n") + print(f"mAP: {mAP}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SSD evaluation') + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") + parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + + prefix = "ssd_eval.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + if args_opt.dataset == "voc": + config.coco_root = config.voc_root + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("coco", False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + elif args_opt.dataset == "voc": + if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root): + print("Create Mindrecord.") + voc_data_to_mindrecord(mindrecord_dir, False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_root or voc_dir not exits.") + else: + if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("other", False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("IMAGE_DIR or ANNO_PATH not exits.") + + print("Start Eval!") + ssd_eval(mindrecord_file, args_opt.checkpoint_path) diff --git a/model_zoo/official/cv/ssd/scripts/run_distribute_train.sh b/model_zoo/official/cv/ssd/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..7175d22988 --- /dev/null +++ b/model_zoo/official/cv/ssd/scripts/run_distribute_train.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" +echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)" +echo "It is better to use absolute path." +echo "=================================================================================================================" + +if [ $# != 5 ] && [ $# != 7 ] +then + echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \ +[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" + exit 1 +fi + +# Before start distribute train, first create mindrecord files. +BASE_PATH=$(cd "`dirname $0`" || exit; pwd) +cd $BASE_PATH/../ || exit +python train.py --only_create_dataset=1 + +echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" + +export RANK_SIZE=$1 +EPOCH_SIZE=$2 +LR=$3 +DATASET=$4 +PRE_TRAINED=$6 +PRE_TRAINED_EPOCH_SIZE=$7 +export RANK_TABLE_FILE=$5 + +for((i=0;i env.log + if [ $# == 5 ] + then + python train.py \ + --distribute=1 \ + --lr=$LR \ + --dataset=$DATASET \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + fi + + if [ $# == 7 ] + then + python train.py \ + --distribute=1 \ + --lr=$LR \ + --dataset=$DATASET \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --pre_trained=$PRE_TRAINED \ + --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \ + --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + fi + + cd ../ +done diff --git a/model_zoo/official/cv/ssd/src/__init__.py b/model_zoo/official/cv/ssd/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/ssd/src/box_utils.py b/model_zoo/official/cv/ssd/src/box_utils.py similarity index 100% rename from model_zoo/ssd/src/box_utils.py rename to model_zoo/official/cv/ssd/src/box_utils.py diff --git a/model_zoo/ssd/src/coco_eval.py b/model_zoo/official/cv/ssd/src/coco_eval.py similarity index 100% rename from model_zoo/ssd/src/coco_eval.py rename to model_zoo/official/cv/ssd/src/coco_eval.py diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py new file mode 100644 index 0000000000..ff0cd21963 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/config.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#" ============================================================================ + +"""Config parameters for SSD models.""" + +from easydict import EasyDict as ed + +config = ed({ + "img_shape": [300, 300], + "num_ssd_boxes": 1917, + "neg_pre_positive": 3, + "match_thershold": 0.5, + "nms_thershold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + + # learing rate settings + "global_step": 0, + "lr_init": 0.001, + "lr_end_rate": 0.001, + "warmup_epochs": 2, + "momentum": 0.9, + "weight_decay": 1.5e-4, + + # network + "num_default": [3, 6, 6, 6, 6, 6], + "extras_in_channels": [256, 576, 1280, 512, 256, 256], + "extras_out_channels": [576, 1280, 512, 256, 256, 128], + "extras_srides": [1, 1, 2, 2, 2, 2], + "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], + "feature_size": [19, 10, 5, 3, 2, 1], + "min_scale": 0.2, + "max_scale": 0.95, + "aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], + "steps": (16, 32, 64, 100, 150, 300), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.75, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "mindrecord_dir": "/data/MindRecord_COCO", + "coco_root": "/data/coco2017", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "annotations/instances_{}.json", + "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + # The annotation.json position of voc validation dataset. + "voc_root": "", + # voc original dataset. + "voc_dir": "", + # if coco or voc used, `image_dir` and `anno_path` are useless. + "image_dir": "", + "anno_path": "", +}) diff --git a/model_zoo/official/cv/ssd/src/dataset.py b/model_zoo/official/cv/ssd/src/dataset.py new file mode 100644 index 0000000000..d842aef709 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/dataset.py @@ -0,0 +1,412 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SSD dataset""" + +from __future__ import division + +import os +import json +import xml.etree.ElementTree as et +import cv2 +import numpy as np + +import mindspore.dataset as de +import mindspore.dataset.transforms.vision.c_transforms as C +from mindspore.mindrecord import FileWriter +from .config import config +from .box_utils import jaccard_numpy, ssd_bboxes_encode + + +def _rand(a=0., b=1.): + """Generate random.""" + return np.random.rand() * (b - a) + a + +def get_imageId_from_fileName(filename): + """Get imageID from fileName""" + try: + filename = os.path.splitext(filename)[0] + return int(filename) + except: + raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename)) + +def random_sample_crop(image, boxes): + """Random Crop the image and boxes""" + height, width, _ = image.shape + min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9]) + + if min_iou is None: + return image, boxes + + # max trails (50) + for _ in range(50): + image_t = image + + w = _rand(0.3, 1.0) * width + h = _rand(0.3, 1.0) * height + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = _rand() * (width - w) + top = _rand() * (height - h) + + rect = np.array([int(top), int(left), int(top+h), int(left+w)]) + overlap = jaccard_numpy(boxes, rect) + + # dropout some boxes + drop_mask = overlap > 0 + if not drop_mask.any(): + continue + + if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2): + continue + + image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :] + + centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0 + + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 * drop_mask + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # take only matching gt boxes + boxes_t = boxes[mask, :].copy() + + boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2]) + boxes_t[:, :2] -= rect[:2] + boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4]) + boxes_t[:, 2:4] -= rect[:2] + + return image_t, boxes_t + return image, boxes + + +def preprocess_fn(img_id, image, box, is_training): + """Preprocess function for dataset.""" + def _infer_data(image, input_shape): + img_h, img_w, _ = image.shape + input_h, input_w = input_shape + + image = cv2.resize(image, (input_w, input_h)) + + #When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + return img_id, image, np.array((img_h, img_w), np.float32) + + def _data_aug(image, box, is_training, image_size=(300, 300)): + """Data augmentation function.""" + ih, iw, _ = image.shape + w, h = image_size + + if not is_training: + return _infer_data(image, image_size) + + # Random crop + box = box.astype(np.float32) + image, box = random_sample_crop(image, box) + ih, iw, _ = image.shape + + # Resize image + image = cv2.resize(image, (w, h)) + + # Flip image or not + flip = _rand() < .5 + if flip: + image = cv2.flip(image, 1, dst=None) + + # When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + box[:, [0, 2]] = box[:, [0, 2]] / ih + box[:, [1, 3]] = box[:, [1, 3]] / iw + + if flip: + box[:, [1, 3]] = 1 - box[:, [3, 1]] + + box, label, num_match = ssd_bboxes_encode(box) + return image, box, label, num_match + return _data_aug(image, box, is_training, image_size=config.img_shape) + + +def create_voc_label(is_training): + """Get image path and annotation from VOC.""" + voc_dir = config.voc_dir + cls_map = {name: i for i, name in enumerate(config.coco_classes)} + sub_dir = 'train' if is_training else 'eval' + #sub_dir = 'train' + voc_dir = os.path.join(voc_dir, sub_dir) + if not os.path.isdir(voc_dir): + raise ValueError(f'Cannot find {sub_dir} dataset path.') + + image_dir = anno_dir = voc_dir + if os.path.isdir(os.path.join(voc_dir, 'Images')): + image_dir = os.path.join(voc_dir, 'Images') + if os.path.isdir(os.path.join(voc_dir, 'Annotations')): + anno_dir = os.path.join(voc_dir, 'Annotations') + + if not is_training: + data_dir = config.voc_root + json_file = os.path.join(data_dir, config.instances_set.format(sub_dir)) + file_dir = os.path.split(json_file)[0] + if not os.path.isdir(file_dir): + os.makedirs(file_dir) + json_dict = {"images": [], "type": "instances", "annotations": [], + "categories": []} + bnd_id = 1 + + image_files_dict = {} + image_anno_dict = {} + images = [] + for anno_file in os.listdir(anno_dir): + print(anno_file) + if not anno_file.endswith('xml'): + continue + tree = et.parse(os.path.join(anno_dir, anno_file)) + root_node = tree.getroot() + file_name = root_node.find('filename').text + img_id = get_imageId_from_fileName(file_name) + image_path = os.path.join(image_dir, file_name) + print(image_path) + if not os.path.isfile(image_path): + print(f'Cannot find image {file_name} according to annotations.') + continue + + labels = [] + for obj in root_node.iter('object'): + cls_name = obj.find('name').text + if cls_name not in cls_map: + print(f'Label "{cls_name}" not in "{config.coco_classes}"') + continue + bnd_box = obj.find('bndbox') + x_min = int(bnd_box.find('xmin').text) - 1 + y_min = int(bnd_box.find('ymin').text) - 1 + x_max = int(bnd_box.find('xmax').text) - 1 + y_max = int(bnd_box.find('ymax').text) - 1 + labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]]) + + if not is_training: + o_width = abs(x_max - x_min) + o_height = abs(y_max - y_min) + ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \ + img_id, 'bbox': [x_min, y_min, o_width, o_height], \ + 'category_id': cls_map[cls_name], 'id': bnd_id, \ + 'ignore': 0, \ + 'segmentation': []} + json_dict['annotations'].append(ann) + bnd_id = bnd_id + 1 + + if labels: + images.append(img_id) + image_files_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(labels) + + if not is_training: + size = root_node.find("size") + width = int(size.find('width').text) + height = int(size.find('height').text) + image = {'file_name': file_name, 'height': height, 'width': width, + 'id': img_id} + json_dict['images'].append(image) + + for cls_name, cid in cls_map.items(): + cat = {'supercategory': 'none', 'id': cid, 'name': cls_name} + json_dict['categories'].append(cat) + json_fp = open(json_file, 'w') + json_str = json.dumps(json_dict) + json_fp.write(json_str) + json_fp.close() + + return images, image_files_dict, image_anno_dict + +def create_coco_label(is_training): + """Get image path and annotation from COCO.""" + from pycocotools.coco import COCO + + coco_root = config.coco_root + data_type = config.val_data_type + if is_training: + data_type = config.train_data_type + + #Classes need to train or test. + train_cls = config.coco_classes + train_cls_dict = {} + for i, cls in enumerate(train_cls): + train_cls_dict[cls] = i + + anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) + + coco = COCO(anno_json) + classs_dict = {} + cat_ids = coco.loadCats(coco.getCatIds()) + for cat in cat_ids: + classs_dict[cat["id"]] = cat["name"] + + image_ids = coco.getImgIds() + images = [] + image_path_dict = {} + image_anno_dict = {} + + for img_id in image_ids: + image_info = coco.loadImgs(img_id) + file_name = image_info[0]["file_name"] + anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = coco.loadAnns(anno_ids) + image_path = os.path.join(coco_root, data_type, file_name) + annos = [] + iscrowd = False + for label in anno: + bbox = label["bbox"] + class_name = classs_dict[label["category_id"]] + iscrowd = iscrowd or label["iscrowd"] + if class_name in train_cls: + x_min, x_max = bbox[0], bbox[0] + bbox[2] + y_min, y_max = bbox[1], bbox[1] + bbox[3] + annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) + + if not is_training and iscrowd: + continue + if len(annos) >= 1: + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(annos) + + return images, image_path_dict, image_anno_dict + + +def anno_parser(annos_str): + """Parse annotation from string to list.""" + annos = [] + for anno_str in annos_str: + anno = list(map(int, anno_str.strip().split(','))) + annos.append(anno) + return annos + + +def filter_valid_data(image_dir, anno_path): + """Filter valid image file, which both in image_dir and anno_path.""" + images = [] + image_path_dict = {} + image_anno_dict = {} + if not os.path.isdir(image_dir): + raise RuntimeError("Path given is not valid.") + if not os.path.isfile(anno_path): + raise RuntimeError("Annotation file is not valid.") + + with open(anno_path, "rb") as f: + lines = f.readlines() + for img_id, line in enumerate(lines): + line_str = line.decode("utf-8").strip() + line_split = str(line_str).split(' ') + file_name = line_split[0] + image_path = os.path.join(image_dir, file_name) + if os.path.isfile(image_path): + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = anno_parser(line_split[1:]) + + return images, image_path_dict, image_anno_dict + + +def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="ssd.mindrecord", file_num=8): + """Create MindRecord file by image_dir and anno_path.""" + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + images, image_path_dict, image_anno_dict = create_voc_label(is_training) + + ssd_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(ssd_json, "ssd_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + +def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): + """Create MindRecord file.""" + mindrecord_dir = config.mindrecord_dir + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + if dataset == "coco": + images, image_path_dict, image_anno_dict = create_coco_label(is_training) + else: + images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path) + + ssd_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(ssd_json, "ssd_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + +def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, + is_training=True, num_parallel_workers=4): + """Creatr SSD dataset with MindDataset.""" + ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, + shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) + decode = C.Decode() + ds = ds.map(input_columns=["image"], operations=decode) + change_swap_op = C.HWC2CHW() + normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) + color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training)) + if is_training: + output_columns = ["image", "box", "label", "num_match"] + trans = [color_adjust_op, normalize_op, change_swap_op] + else: + output_columns = ["img_id", "image", "image_shape"] + trans = [normalize_op, change_swap_op] + ds = ds.map(input_columns=["img_id", "image", "annotation"], + output_columns=output_columns, columns_order=output_columns, + operations=compose_map_func, python_multiprocessing=is_training, + num_parallel_workers=num_parallel_workers) + ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training, + num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) + return ds diff --git a/model_zoo/official/cv/ssd/src/init_params.py b/model_zoo/official/cv/ssd/src/init_params.py new file mode 100644 index 0000000000..b71ee2c4dc --- /dev/null +++ b/model_zoo/official/cv/ssd/src/init_params.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parameters utils""" + +from mindspore.common.initializer import initializer, TruncatedNormal +import numpy as np + +def init_net_param(network, initialize_mode='TruncatedNormal'): + """Init the parameters in net.""" + params = network.trainable_params() + for p in params: + if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: + np.random.seed(seed=1) + if initialize_mode == 'TruncatedNormal': + p.set_parameter_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype)) + else: + p.set_parameter_data(initialize_mode, p.data.shape, p.data.dtype) + + +def load_backbone_params(network, param_dict): + """Init the parameters from pre-train model, default is mobilenetv2.""" + for _, param in net.parameters_and_names(): + param_name = param.name.replace('network.backbone.', '') + name_split = param_name.split('.') + if 'features_1' in param_name: + param_name = param_name.replace('features_1', 'features') + if 'features_2' in param_name: + param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:]) + if param_name in param_dict: + param.set_parameter_data(param_dict[param_name].data) + +def filter_checkpoint_parameter(param_dict): + """remove useless parameters""" + for key in list(param_dict.keys()): + if 'multi_loc_layers' in key or 'multi_cls_layers' in key: + del param_dict[key] diff --git a/model_zoo/ssd/src/lr_schedule.py b/model_zoo/official/cv/ssd/src/lr_schedule.py similarity index 100% rename from model_zoo/ssd/src/lr_schedule.py rename to model_zoo/official/cv/ssd/src/lr_schedule.py diff --git a/model_zoo/ssd/src/ssd.py b/model_zoo/official/cv/ssd/src/ssd.py similarity index 100% rename from model_zoo/ssd/src/ssd.py rename to model_zoo/official/cv/ssd/src/ssd.py diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py new file mode 100644 index 0000000000..c38026103d --- /dev/null +++ b/model_zoo/official/cv/ssd/train.py @@ -0,0 +1,145 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Train SSD and get checkpoint files.""" + +import os +import argparse +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.communication.management import init +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor +from mindspore.train import Model, ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 +from src.config import config +from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord +from src.lr_schedule import get_lr +from src.init_params import init_net_param, filter_checkpoint_parameter + + +def main(): + parser = argparse.ArgumentParser(description="SSD training") + parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " + "Mindrecord, default is False.") + parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.") + parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") + parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") + parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") + parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") + parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") + parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.") + parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") + parser.add_argument("--filter_weight", type=bool, default=False, help="Filter weight parameters, default is False.") + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + + if args_opt.distribute: + device_num = args_opt.device_num + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=device_num) + init() + rank = args_opt.device_id % device_num + else: + rank = 0 + device_num = 1 + + print("Start create dataset!") + + # It will generate mindrecord file in args_opt.mindrecord_dir, + # and the file name is ssd.mindrecord0, 1, ... file_num. + + prefix = "ssd.mindrecord" + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if args_opt.dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("coco", True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + elif args_opt.dataset == "voc": + if os.path.isdir(config.voc_dir): + print("Create Mindrecord.") + voc_data_to_mindrecord(mindrecord_dir, True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_dir not exits.") + else: + if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("other", True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("image_dir or anno_path not exits.") + + if not args_opt.only_create_dataset: + loss_scale = float(args_opt.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. + dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, + batch_size=args_opt.batch_size, device_num=device_num, rank=rank) + + dataset_size = dataset.get_dataset_size() + print("Create dataset done!") + + backbone = ssd_mobilenet_v2() + ssd = SSD300(backbone=backbone, config=config) + net = SSDWithLossCell(ssd, config) + init_net_param(net) + + # checkpoint + ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) + ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config) + + if args_opt.pre_trained: + if args_opt.pre_trained_epoch_size <= 0: + raise KeyError("pre_trained_epoch_size must be greater than 0.") + param_dict = load_checkpoint(args_opt.pre_trained) + if args_opt.filter_weight: + filter_checkpoint_parameter(param_dict) + load_param_into_net(net, param_dict) + + lr = Tensor(get_lr(global_step=config.global_step, + lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, + warmup_epochs=config.warmup_epochs, + total_epochs=args_opt.epoch_size, + steps_per_epoch=dataset_size)) + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, + config.momentum, config.weight_decay, loss_scale) + net = TrainingWrapper(net, opt, loss_scale) + + callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] + + model = Model(net) + dataset_sink_mode = False + if args_opt.mode == "sink": + print("In sink mode, one epoch return a loss.") + dataset_sink_mode = True + print("Start train SSD, the first epoch will be slower because of the graph compilation.") + model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) + +if __name__ == '__main__': + main() diff --git a/model_zoo/official/cv/vgg16/README.md b/model_zoo/official/cv/vgg16/README.md new file mode 100644 index 0000000000..346738a8fa --- /dev/null +++ b/model_zoo/official/cv/vgg16/README.md @@ -0,0 +1,225 @@ +# VGG16 Example + +## Description + +This example is for VGG16 model training and evaluation. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset CIFAR-10 or ImageNet2012. + +CIFAR-10 + +> Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows: +> ``` +> . +> ├── cifar-10-batches-bin # train dataset +> └── cifar-10-verify-bin # infer dataset +> ``` + +ImageNet2012 + +> Unzip the ImageNet2012 dataset to any path you want and the folder should include train and eval dataset as follows: +> +> ``` +> . +> └─dataset +> ├─ilsvrc # train dataset +> └─validation_preprocess # evaluate dataset +> ``` + +## Parameter configuration + +Parameters for both training and evaluation can be set in config.py. + +- config for vgg16, CIFAR-10 dataset + +``` +"num_classes": 10, # dataset class num +"lr": 0.01, # learning rate +"lr_init": 0.01, # initial learning rate +"lr_max": 0.1, # max learning rate +"lr_epochs": '30,60,90,120', # lr changing based epochs +"lr_scheduler": "step", # learning rate mode +"warmup_epochs": 5, # number of warmup epoch +"batch_size": 64, # batch size of input tensor +"max_epoch": 70, # only valid for taining, which is always 1 for inference +"momentum": 0.9, # momentum +"weight_decay": 5e-4, # weight decay +"loss_scale": 1.0, # loss scale +"label_smooth": 0, # label smooth +"label_smooth_factor": 0, # label smooth factor +"buffer_size": 10, # shuffle buffer size +"image_size": '224,224', # image size +"pad_mode": 'same', # pad mode for conv2d +"padding": 0, # padding value for conv2d +"has_bias": False, # whether has bias in conv2d +"batch_norm": True, # wether has batch_norm in conv2d +"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint +"initialize_mode": "XavierUniform", # conv2d init mode +"has_dropout": True # wether using Dropout layer +``` + +- config for vgg16, ImageNet2012 dataset + +``` +"num_classes": 1000, # dataset class num +"lr": 0.01, # learning rate +"lr_init": 0.01, # initial learning rate +"lr_max": 0.1, # max learning rate +"lr_epochs": '30,60,90,120', # lr changing based epochs +"lr_scheduler": "cosine_annealing", # learning rate mode +"warmup_epochs": 0, # number of warmup epoch +"batch_size": 32, # batch size of input tensor +"max_epoch": 150, # only valid for taining, which is always 1 for inference +"momentum": 0.9, # momentum +"weight_decay": 1e-4, # weight decay +"loss_scale": 1024, # loss scale +"label_smooth": 1, # label smooth +"label_smooth_factor": 0.1, # label smooth factor +"buffer_size": 10, # shuffle buffer size +"image_size": '224,224', # image size +"pad_mode": 'pad', # pad mode for conv2d +"padding": 1, # padding value for conv2d +"has_bias": True, # whether has bias in conv2d +"batch_norm": False, # wether has batch_norm in conv2d +"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint +"initialize_mode": "KaimingNormal", # conv2d init mode +"has_dropout": True # wether using Dropout layer +``` + +## Running the Example + +### Training +**Run vgg16, using CIFAR-10 dataset** + +- Training using single device(1p) +``` +python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 & +``` +The python command above will run in the background, you can view the results through the file `out.train.log`. + +After training, you'll get some checkpoint files in specified ckpt_path, default in ./output directory. + +You will get the loss value as following: +``` +# grep "loss is " out.train.log +epoch: 1 step: 781, loss is 2.093086 +epcoh: 2 step: 781, loss is 1.827582 +... +``` + +- Distribute Training +``` +sh run_distribute_train.sh rank_table.json your_data_path +``` +The above shell script will run distribute training in the background, you can view the results through the file `train_parallel[X]/log`. + +You will get the loss value as following: +``` +# grep "result: " train_parallel*/log +train_parallel0/log:epoch: 1 step: 97, loss is 1.9060308 +train_parallel0/log:epcoh: 2 step: 97, loss is 1.6003821 +... +train_parallel1/log:epoch: 1 step: 97, loss is 1.7095519 +train_parallel1/log:epcoh: 2 step: 97, loss is 1.7133579 +... +... +``` +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + + +**Run vgg16, using imagenet2012 dataset** + +- Training using single device(1p) +``` +python train.py --device_target="GPU" --dataset="imagenet2012" --is_distributed=0 --data_path=$DATA_PATH > output.train.log 2>&1 & +``` + +- Distribute Training +``` +# distributed training(8p) +bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train" +``` + + +### Evaluation + +- Do eval as follows, need to specify dataset type as "cifar10" or "imagenet2012" +``` +# when using cifar10 dataset +python eval.py --data_path=your_data_path --dataset="cifar10" --device_target="Ascend" --pre_trained=./*-70-781.ckpt > out.eval.log 2>&1 & + +# when using imagenet2012 dataset +python eval.py --data_path=your_data_path --dataset="imagenet2012" --device_target="GPU" --pre_trained=./*-150-5004.ckpt > out.eval.log 2>&1 & +``` +- If the using dataset is +The above python command will run in the background, you can view the results through the file `out.eval.log`. + +You will get the accuracy as following: +``` +# when using cifar10 dataset +# grep "result: " out.eval.log +result: {'acc': 0.92} + +# when using the imagenet2012 dataset +after allreduce eval: top1_correct=36636, tot=50000, acc=73.27% +after allreduce eval: top5_correct=45582, tot=50000, acc=91.16% +``` + +## Usage: + +### Training +``` +usage: train.py [--device_target TARGET][--data_path DATA_PATH] + [--dataset DATASET_TYPE][--is_distributed VALUE] + [--device_id DEVICE_ID][--pre_trained PRE_TRAINED] + [--ckpt_path CHECKPOINT_PATH][--ckpt_interval INTERVAL_STEP] + +parameters/options: + --device_target the training backend type, Ascend or GPU, default is Ascend. + --dataset the dataset type, cifar10 or imagenet2012. + --is_distributed the way of traing, whether do distribute traing, value can be 0 or 1. + --data_path the storage path of dataset + --device_id the device which used to train model. + --pre_trained the pretrained checkpoint file path. + --ckpt_path the path to save checkpoint. + --ckpt_interval the epoch interval for saving checkpoint. + +``` + +### Evaluation + +``` +usage: eval.py [--device_target TARGET][--data_path DATA_PATH] + [--dataset DATASET_TYPE][--pre_trained PRE_TRAINED] + [--device_id DEVICE_ID] + +parameters/options: + --device_target the evaluation backend type, Ascend or GPU, default is Ascend. + --dataset the dataset type, cifar10 or imagenet2012. + --data_path the storage path of dataset. + --device_id the device which used to evaluate model. + --pre_trained the checkpoint file path used to evaluate model. +``` + +### Distribute Training +- Train on Ascend. + +``` +Usage: sh script/run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] + +parameters/options: + RANK_TABLE_FILE HCCL configuration file path. + DATA_PATH the storage path of dataset. +``` + +- Train on GPU. +``` +Usage: bash run_distribute_train_gpu.sh [DATA_PATH] + +parameters/options: + DATA_PATH the storage path of dataset. +``` \ No newline at end of file diff --git a/model_zoo/official/cv/vgg16/eval.py b/model_zoo/official/cv/vgg16/eval.py new file mode 100644 index 0000000000..86ce02187b --- /dev/null +++ b/model_zoo/official/cv/vgg16/eval.py @@ -0,0 +1,212 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Eval""" +import os +import time +import argparse +import datetime +import glob +import numpy as np +import mindspore.nn as nn + +from mindspore import Tensor, context +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + +from src.utils.logging import get_logger +from src.vgg import vgg16 +from src.dataset import vgg_create_dataset +from src.dataset import classification_dataset + + +class ParameterReduce(nn.Cell): + """ParameterReduce""" + def __init__(self): + super(ParameterReduce, self).__init__() + self.cast = P.Cast() + self.reduce = P.AllReduce() + + def construct(self, x): + one = self.cast(F.scalar_to_array(1.0), mstype.float32) + out = x * one + ret = self.reduce(out) + return ret + + +def parse_args(cloud_args=None): + """parse_args""" + parser = argparse.ArgumentParser('mindspore classification test') + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + # dataset related + parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10") + parser.add_argument('--data_path', type=str, default='', help='eval data dir') + parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') + # network related + parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') + parser.add_argument('--pre_trained', default='', type=str, help='fully path of pretrained model to load. ' + 'If it is a direction, it will test all ckpt') + + # logging related + parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log') + parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') + parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') + + args_opt = parser.parse_args() + args_opt = merge_args(args_opt, cloud_args) + + if args_opt.dataset == "cifar10": + from src.config import cifar_cfg as cfg + else: + from src.config import imagenet_cfg as cfg + + args_opt.image_size = cfg.image_size + args_opt.num_classes = cfg.num_classes + args_opt.per_batch_size = cfg.batch_size + args_opt.momentum = cfg.momentum + args_opt.weight_decay = cfg.weight_decay + args_opt.buffer_size = cfg.buffer_size + args_opt.pad_mode = cfg.pad_mode + args_opt.padding = cfg.padding + args_opt.has_bias = cfg.has_bias + args_opt.batch_norm = cfg.batch_norm + args_opt.initialize_mode = cfg.initialize_mode + args_opt.has_dropout = cfg.has_dropout + + args_opt.image_size = list(map(int, args_opt.image_size.split(','))) + + return args_opt + + +def get_top5_acc(top5_arg, gt_class): + sub_count = 0 + for top5, gt in zip(top5_arg, gt_class): + if gt in top5: + sub_count += 1 + return sub_count + + +def merge_args(args, cloud_args): + """merge_args""" + args_dict = vars(args) + if isinstance(cloud_args, dict): + for key in cloud_args.keys(): + val = cloud_args[key] + if key in args_dict and val: + arg_type = type(args_dict[key]) + if arg_type is not type(None): + val = arg_type(val) + args_dict[key] = val + return args + + +def test(cloud_args=None): + """test""" + args = parse_args(cloud_args) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.device_target, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + + args.outputs_dir = os.path.join(args.log_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + + args.logger = get_logger(args.outputs_dir, args.rank) + args.logger.save_args(args) + + if args.dataset == "cifar10": + net = vgg16(num_classes=args.num_classes, args=args) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum, + weight_decay=args.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + param_dict = load_checkpoint(args.pre_trained) + load_param_into_net(net, param_dict) + net.set_train(False) + dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False) + res = model.eval(dataset) + print("result: ", res) + else: + # network + args.logger.important_info('start create network') + if os.path.isdir(args.pre_trained): + models = list(glob.glob(os.path.join(args.pre_trained, '*.ckpt'))) + print(models) + if args.graph_ckpt: + f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]) + else: + f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1]) + args.models = sorted(models, key=f) + else: + args.models = [args.pre_trained,] + + for model in args.models: + dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size) + eval_dataloader = dataset.create_tuple_iterator() + network = vgg16(args.num_classes, args, phase="test") + + # pre_trained + load_param_into_net(network, load_checkpoint(model)) + network.add_flags_recursive(fp16=True) + + img_tot = 0 + top1_correct = 0 + top5_correct = 0 + + network.set_train(False) + t_end = time.time() + it = 0 + for data, gt_classes in eval_dataloader: + output = network(Tensor(data, mstype.float32)) + output = output.asnumpy() + + top1_output = np.argmax(output, (-1)) + top5_output = np.argsort(output)[:, -5:] + + t1_correct = np.equal(top1_output, gt_classes).sum() + top1_correct += t1_correct + top5_correct += get_top5_acc(top5_output, gt_classes) + img_tot += args.per_batch_size + + if args.rank == 0 and it == 0: + t_end = time.time() + it = 1 + if args.rank == 0: + time_used = time.time() - t_end + fps = (img_tot - args.per_batch_size) * args.group_size / time_used + args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps)) + results = [[top1_correct], [top5_correct], [img_tot]] + args.logger.info('before results={}'.format(results)) + results = np.array(results) + + args.logger.info('after results={}'.format(results)) + top1_correct = results[0, 0] + top5_correct = results[1, 0] + img_tot = results[2, 0] + acc1 = 100.0 * top1_correct / img_tot + acc5 = 100.0 * top5_correct / img_tot + args.logger.info('after allreduce eval: top1_correct={}, tot={},' + 'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1)) + args.logger.info('after allreduce eval: top5_correct={}, tot={},' + 'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5)) + + +if __name__ == "__main__": + test() diff --git a/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh b/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh new file mode 100755 index 0000000000..1a9e022fd2 --- /dev/null +++ b/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]" +exit 1 +fi + +if [ ! -f $1 ] +then + echo "error: RANK_TABLE_FILEH=$1 is not a file" +exit 1 +fi + +if [ ! -d $2 ] +then + echo "error: DATA_PATH=$2 is not a directory" +exit 1 +fi + +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$1 + +for((i=0;i env.log + python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log & + cd .. +done \ No newline at end of file diff --git a/model_zoo/official/cv/vgg16/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/vgg16/scripts/run_distribute_train_gpu.sh new file mode 100644 index 0000000000..51be33a53a --- /dev/null +++ b/model_zoo/official/cv/vgg16/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_distribute_train_gpu.sh DATA_PATH" +echo "for example: bash run_distribute_train_gpu.sh /path/ImageNet2012/train" +echo "==============================================================================================================" + +DATA_PATH=$1 + +mpirun -n 8 python train.py \ + --device_target="GPU" \ + --dataset="imagenet2012" \ + --is_distributed=1 \ + --data_path=$DATA_PATH > output.train.log 2>&1 & diff --git a/model_zoo/deepfm/__init__.py b/model_zoo/official/cv/vgg16/src/__init__.py similarity index 100% rename from model_zoo/deepfm/__init__.py rename to model_zoo/official/cv/vgg16/src/__init__.py diff --git a/model_zoo/official/cv/vgg16/src/config.py b/model_zoo/official/cv/vgg16/src/config.py new file mode 100755 index 0000000000..0861897ed2 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/config.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as edict + +# config for vgg16, cifar10 +cifar_cfg = edict({ + "num_classes": 10, + "lr": 0.01, + "lr_init": 0.01, + "lr_max": 0.1, + "lr_epochs": '30,60,90,120', + "lr_scheduler": "step", + "warmup_epochs": 5, + "batch_size": 64, + "max_epoch": 70, + "momentum": 0.9, + "weight_decay": 5e-4, + "loss_scale": 1.0, + "label_smooth": 0, + "label_smooth_factor": 0, + "buffer_size": 10, + "image_size": '224,224', + "pad_mode": 'same', + "padding": 0, + "has_bias": False, + "batch_norm": True, + "keep_checkpoint_max": 10, + "initialize_mode": "XavierUniform", + "has_dropout": False +}) + +# config for vgg16, imagenet2012 +imagenet_cfg = edict({ + "num_classes": 1000, + "lr": 0.01, + "lr_init": 0.01, + "lr_max": 0.1, + "lr_epochs": '30,60,90,120', + "lr_scheduler": 'cosine_annealing', + "warmup_epochs": 0, + "batch_size": 32, + "max_epoch": 150, + "momentum": 0.9, + "weight_decay": 1e-4, + "loss_scale": 1024, + "label_smooth": 1, + "label_smooth_factor": 0.1, + "buffer_size": 10, + "image_size": '224,224', + "pad_mode": 'pad', + "padding": 1, + "has_bias": True, + "batch_norm": False, + "keep_checkpoint_max": 10, + "initialize_mode": "KaimingNormal", + "has_dropout": True +}) diff --git a/model_zoo/resnet/src/crossentropy.py b/model_zoo/official/cv/vgg16/src/crossentropy.py similarity index 100% rename from model_zoo/resnet/src/crossentropy.py rename to model_zoo/official/cv/vgg16/src/crossentropy.py diff --git a/model_zoo/official/cv/vgg16/src/dataset.py b/model_zoo/official/cv/vgg16/src/dataset.py new file mode 100644 index 0000000000..c510b49497 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/dataset.py @@ -0,0 +1,197 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +dataset processing. +""" +import os +from mindspore.common import dtype as mstype +import mindspore.dataset as de +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as vision +from PIL import Image, ImageFile +from src.utils.sampler import DistributedSampler + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True): + """Data operations.""" + de.config.set_seed(1) + data_dir = os.path.join(data_home, "cifar-10-batches-bin") + if not training: + data_dir = os.path.join(data_home, "cifar-10-verify-bin") + + data_set = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) + + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_horizontal_op = vision.RandomHorizontalFlip() + resize_op = vision.Resize(image_size) # interpolation default BILINEAR + rescale_op = vision.Rescale(rescale, shift) + normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) + changeswap_op = vision.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + c_trans = [] + if training: + c_trans = [random_crop_op, random_horizontal_op] + c_trans += [resize_op, rescale_op, normalize_op, + changeswap_op] + + # apply map operations on images + data_set = data_set.map(input_columns="label", operations=type_cast_op) + data_set = data_set.map(input_columns="image", operations=c_trans) + + # apply repeat operations + data_set = data_set.repeat(repeat_num) + + # apply shuffle operations + data_set = data_set.shuffle(buffer_size=10) + + # apply batch operations + data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) + + return data_set + + +def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_size=1, + mode='train', + input_mode='folder', + root='', + num_parallel_workers=None, + shuffle=None, + sampler=None, + repeat_num=1, + class_indexing=None, + drop_remainder=True, + transform=None, + target_transform=None): + """ + A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt". + If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images + are written into a textfile. + + Args: + data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"". + Or path of the textfile that contains every image's path of the dataset. + image_size (str): Size of the input images. + per_batch_size (int): the batch size of evey step during training. + rank (int): The shard ID within num_shards (default=None). + group_size (int): Number of shards that the dataset should be divided + into (default=None). + mode (str): "train" or others. Default: " train". + input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder". + root (str): the images path for "input_mode="txt"". Default: " ". + num_parallel_workers (int): Number of workers to read the data. Default: None. + shuffle (bool): Whether or not to perform shuffle on the dataset + (default=None, performs shuffle). + sampler (Sampler): Object used to choose samples from the dataset. Default: None. + repeat_num (int): the num of repeat dataset. + class_indexing (dict): A str-to-int mapping from folder name to index + (default=None, the folder names will be sorted + alphabetically and each class will be given a + unique index starting from 0). + + Examples: + >>> from mindvision.common.datasets.classification import classification_dataset + >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images + >>> dataset_dir = "/path/to/imagefolder_directory" + >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], + >>> per_batch_size=64, rank=0, group_size=4) + >>> # Path of the textfile that contains every image's path of the dataset. + >>> dataset_dir = "/path/to/dataset/images/train.txt" + >>> images_dir = "/path/to/dataset/images" + >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], + >>> per_batch_size=64, rank=0, group_size=4, + >>> input_mode="txt", root=images_dir) + """ + + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + if transform is None: + if mode == 'train': + transform_img = [ + vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0)), + vision.RandomHorizontalFlip(prob=0.5), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + else: + transform_img = [ + vision.Decode(), + vision.Resize((256, 256)), + vision.CenterCrop(image_size), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + else: + transform_img = transform + + if target_transform is None: + transform_label = [C.TypeCast(mstype.int32)] + else: + transform_label = target_transform + + if input_mode == 'folder': + de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers, + shuffle=shuffle, sampler=sampler, class_indexing=class_indexing, + num_shards=group_size, shard_id=rank) + else: + dataset = TxtDataset(root, data_dir) + sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) + de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) + de_dataset.set_dataset_size(len(sampler)) + + de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) + de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) + + columns_to_project = ["image", "label"] + de_dataset = de_dataset.project(columns=columns_to_project) + + de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder) + de_dataset = de_dataset.repeat(repeat_num) + + return de_dataset + + +class TxtDataset: + """ + create txt dataset. + + Args: + Returns: + de_dataset. + """ + def __init__(self, root, txt_name): + super(TxtDataset, self).__init__() + self.imgs = [] + self.labels = [] + fin = open(txt_name, "r") + for line in fin: + img_name, label = line.strip().split(' ') + self.imgs.append(os.path.join(root, img_name)) + self.labels.append(int(label)) + fin.close() + + def __getitem__(self, index): + img = Image.open(self.imgs[index]).convert('RGB') + return img, self.labels[index] + + def __len__(self): + return len(self.imgs) diff --git a/model_zoo/official/cv/vgg16/src/linear_warmup.py b/model_zoo/official/cv/vgg16/src/linear_warmup.py new file mode 100644 index 0000000000..dc926e5ce1 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/linear_warmup.py @@ -0,0 +1,23 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +linear warm up learning rate. +""" + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr diff --git a/model_zoo/official/cv/vgg16/src/utils/logging.py b/model_zoo/official/cv/vgg16/src/utils/logging.py new file mode 100644 index 0000000000..ac37bec4ec --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/logging.py @@ -0,0 +1,82 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +get logger. +""" +import logging +import os +import sys +from datetime import datetime + +class LOGGER(logging.Logger): + """ + set up logging file. + + Args: + logger_name (string): logger name. + log_dir (string): path of logger. + + Returns: + string, logger path + """ + def __init__(self, logger_name, rank=0): + super(LOGGER, self).__init__(logger_name) + if rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, rank=0): + """set up log file""" + self.rank = rank + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + self.info('--> %s: %s', key, args_dict[key]) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + '\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + + +def get_logger(path, rank): + logger = LOGGER("mindversion", rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/model_zoo/official/cv/vgg16/src/utils/sampler.py b/model_zoo/official/cv/vgg16/src/utils/sampler.py new file mode 100644 index 0000000000..5b68f8325e --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/sampler.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +choose samples from the dataset +""" +import math +import numpy as np + +class DistributedSampler(): + """ + sampling the dataset. + + Args: + Returns: + num_samples, number of samples. + """ + def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): + self.dataset = dataset + self.rank = rank + self.group_size = group_size + self.dataset_length = len(self.dataset) + self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size)) + self.total_size = self.num_samples * self.group_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + self.seed = (self.seed + 1) & 0xffffffff + np.random.seed(self.seed) + indices = np.random.permutation(self.dataset_length).tolist() + else: + indices = list(range(len(self.dataset_length))) + + indices += indices[:(self.total_size - len(indices))] + indices = indices[self.rank::self.group_size] + return iter(indices) + + def __len__(self): + return self.num_samples + \ No newline at end of file diff --git a/model_zoo/official/cv/vgg16/src/utils/util.py b/model_zoo/official/cv/vgg16/src/utils/util.py new file mode 100644 index 0000000000..6f84045a89 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/util.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function.""" + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] diff --git a/model_zoo/official/cv/vgg16/src/utils/var_init.py b/model_zoo/official/cv/vgg16/src/utils/var_init.py new file mode 100644 index 0000000000..185072d441 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/utils/var_init.py @@ -0,0 +1,214 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Initialize. +""" +import math +from functools import reduce +import numpy as np +import mindspore.nn as nn +from mindspore.common import initializer as init + +def _calculate_gain(nonlinearity, param=None): + r""" + Return the recommended gain value for the given nonlinearity function. + + The values are as follows: + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function + param: optional parameter for the non-linear function + + Examples: + >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + +def _assignment(arr, num): + """Assign the value of `num` to `arr`.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + +def _calculate_in_and_out(arr): + """ + Calculate n_in and n_out. + + Args: + arr (Array): Input array. + + Returns: + Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. + """ + dim = len(arr.shape) + if dim < 2: + raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") + + n_in = arr.shape[1] + n_out = arr.shape[0] + + if dim > 2: + counter = reduce(lambda x, y: x * y, arr.shape[2:]) + n_in *= counter + n_out *= counter + return n_in, n_out + +def _select_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_in_and_out(array) + return fan_in if mode == 'fan_in' else fan_out + +class KaimingInit(init.Initializer): + r""" + Base Class. Initialize the array with He kaiming algorithm. + + Args: + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function, recommended to use only with + ``'relu'`` or ``'leaky_relu'`` (default). + """ + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingInit, self).__init__() + self.mode = mode + self.gain = _calculate_gain(nonlinearity, a) + def _initialize(self, arr): + pass + + +class KaimingUniform(KaimingInit): + r""" + Initialize the array with He kaiming uniform algorithm. The resulting tensor will + have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Input: + arr (Array): The array to be assigned. + + Returns: + Array, assigned array. + + Examples: + >>> w = np.empty(3, 5) + >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') + """ + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) + np.random.seed(0) + data = np.random.uniform(-bound, bound, arr.shape) + + _assignment(arr, data) + + +class KaimingNormal(KaimingInit): + r""" + Initialize the array with He kaiming normal algorithm. The resulting tensor will + have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Input: + arr (Array): The array to be assigned. + + Returns: + Array, assigned array. + + Examples: + >>> w = np.empty(3, 5) + >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') + """ + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + std = self.gain / math.sqrt(fan) + np.random.seed(0) + data = np.random.normal(0, std, arr.shape) + + _assignment(arr, data) + + +def default_recurisive_init(custom_cell): + """default_recurisive_init""" + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype) + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + np.random.seed(0) + cell.bias.default_input = init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype) + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + np.random.seed(0) + cell.bias.default_input = init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass diff --git a/model_zoo/official/cv/vgg16/src/vgg.py b/model_zoo/official/cv/vgg16/src/vgg.py new file mode 100644 index 0000000000..bd873e4d5c --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/vgg.py @@ -0,0 +1,142 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Image classifiation. +""" +import math +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.common import initializer as init +from mindspore.common.initializer import initializer +from .utils.var_init import default_recurisive_init, KaimingNormal + + +def _make_layer(base, args, batch_norm): + """Make stage network of VGG.""" + layers = [] + in_channels = 3 + for v in base: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + weight_shape = (v, in_channels, 3, 3) + weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() + if args.initialize_mode == "KaimingNormal": + weight = 'normal' + conv2d = nn.Conv2d(in_channels=in_channels, + out_channels=v, + kernel_size=3, + padding=args.padding, + pad_mode=args.pad_mode, + has_bias=args.has_bias, + weight_init=weight) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] + else: + layers += [conv2d, nn.ReLU()] + in_channels = v + return nn.SequentialCell(layers) + + +class Vgg(nn.Cell): + """ + VGG network definition. + + Args: + base (list): Configuration for different layers, mainly the channel number of Conv layer. + num_classes (int): Class numbers. Default: 1000. + batch_norm (bool): Whether to do the batchnorm. Default: False. + batch_size (int): Batch size. Default: 1. + + Returns: + Tensor, infer output tensor. + + Examples: + >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + >>> num_classes=1000, batch_norm=False, batch_size=1) + """ + + def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): + super(Vgg, self).__init__() + _ = batch_size + self.layers = _make_layer(base, args, batch_norm=batch_norm) + self.flatten = nn.Flatten() + dropout_ratio = 0.5 + if not args.has_dropout or phase == "test": + dropout_ratio = 1.0 + self.classifier = nn.SequentialCell([ + nn.Dense(512 * 7 * 7, 4096), + nn.ReLU(), + nn.Dropout(dropout_ratio), + nn.Dense(4096, 4096), + nn.ReLU(), + nn.Dropout(dropout_ratio), + nn.Dense(4096, num_classes)]) + if args.initialize_mode == "KaimingNormal": + default_recurisive_init(self) + self.custom_init_weight() + + def construct(self, x): + x = self.layers(x) + x = self.flatten(x) + x = self.classifier(x) + return x + + def custom_init_weight(self): + """ + Init the weight of Conv2d and Dense in the net. + """ + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer( + KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), + cell.weight.shape, cell.weight.dtype) + if cell.bias is not None: + cell.bias.default_input = init.initializer( + 'zeros', cell.bias.shape, cell.bias.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer( + init.Normal(0.01), cell.weight.shape, cell.weight.dtype) + if cell.bias is not None: + cell.bias.default_input = init.initializer( + 'zeros', cell.bias.shape, cell.bias.dtype) + + +cfg = { + '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg16(num_classes=1000, args=None, phase="train"): + """ + Get Vgg16 neural network with batch normalization. + + Args: + num_classes (int): Class numbers. Default: 1000. + args(namespace): param for net init. + phase(str): train or test mode. + + Returns: + Cell, cell instance of Vgg16 neural network with batch normalization. + + Examples: + >>> vgg16(num_classes=1000, args=args) + """ + + net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) + return net diff --git a/model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py b/model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py new file mode 100644 index 0000000000..5d9fce9af4 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +warm up cosine annealing learning rate. +""" +import math +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """warm up cosine annealing learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) diff --git a/model_zoo/official/cv/vgg16/src/warmup_step_lr.py b/model_zoo/official/cv/vgg16/src/warmup_step_lr.py new file mode 100644 index 0000000000..2ffaa493e9 --- /dev/null +++ b/model_zoo/official/cv/vgg16/src/warmup_step_lr.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +warm up step learning rate. +""" +from collections import Counter +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + """Set learning rate.""" + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr_value = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr_value = float(lr_max) * base * base + if lr_value < 0.0: + lr_value = 0.0 + lr_each_step.append(lr_value) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """warmup_step_lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) diff --git a/model_zoo/official/cv/vgg16/train.py b/model_zoo/official/cv/vgg16/train.py new file mode 100644 index 0000000000..2ddf89e977 --- /dev/null +++ b/model_zoo/official/cv/vgg16/train.py @@ -0,0 +1,293 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################train vgg16 example on cifar10######################## +python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID +""" +import argparse +import datetime +import time +import os +import random + +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig +from mindspore.train.model import Model, ParallelMode +from mindspore.train.serialization import load_param_into_net, load_checkpoint +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from src.dataset import vgg_create_dataset +from src.dataset import classification_dataset + +from src.crossentropy import CrossEntropy +from src.warmup_step_lr import warmup_step_lr +from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr +from src.warmup_step_lr import lr_steps +from src.utils.logging import get_logger +from src.utils.util import get_param_groups +from src.vgg import vgg16 + + +random.seed(1) +np.random.seed(1) + + +class ProgressMonitor(Callback): + """monitor loss and time""" + def __init__(self, args_param): + super(ProgressMonitor, self).__init__() + self.me_epoch_start_time = 0 + self.me_epoch_start_step_num = 0 + self.args = args_param + self.ckpt_history = [] + + def begin(self, run_context): + self.args.logger.info('start network train...') + + def epoch_begin(self, run_context): + pass + + def epoch_end(self, run_context): + """ + Called after each epoch finished. + + Args: + run_context (RunContext): Include some information of the model. + """ + cb_params = run_context.original_args() + me_step = cb_params.cur_step_num - 1 + + real_epoch = me_step // self.args.steps_per_epoch + time_used = time.time() - self.me_epoch_start_time + fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used + self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}' + 'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean)) + + if self.args.rank_save_ckpt_flag: + import glob + ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt')) + for ckpt in ckpts: + ckpt_fn = os.path.basename(ckpt) + if not ckpt_fn.startswith('{}-'.format(self.args.rank)): + continue + if ckpt in self.ckpt_history: + continue + self.ckpt_history.append(ckpt) + self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},' + 'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn)) + + self.me_epoch_start_step_num = me_step + self.me_epoch_start_time = time.time() + + def step_begin(self, run_context): + pass + + def step_end(self, run_context, *me_args): + pass + + def end(self, run_context): + self.args.logger.info('end network train...') + + +def parse_args(cloud_args=None): + """parameters""" + parser = argparse.ArgumentParser('mindspore classification training') + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)') + + # dataset related + parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10") + parser.add_argument('--data_path', type=str, default='', help='train data dir') + + # network related + parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load') + parser.add_argument('--lr_gamma', type=float, default=0.1, + help='decrease lr by a factor of exponential lr_scheduler') + parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') + parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler') + + # logging and checkpoint related + parser.add_argument('--log_interval', type=int, default=100, help='logging interval') + parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') + parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval') + parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank') + + # distributed related + parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') + parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') + parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') + args_opt = parser.parse_args() + args_opt = merge_args(args_opt, cloud_args) + + if args_opt.dataset == "cifar10": + from src.config import cifar_cfg as cfg + else: + from src.config import imagenet_cfg as cfg + + args_opt.label_smooth = cfg.label_smooth + args_opt.label_smooth_factor = cfg.label_smooth_factor + args_opt.lr_scheduler = cfg.lr_scheduler + args_opt.loss_scale = cfg.loss_scale + args_opt.max_epoch = cfg.max_epoch + args_opt.warmup_epochs = cfg.warmup_epochs + args_opt.lr = cfg.lr + args_opt.lr_init = cfg.lr_init + args_opt.lr_max = cfg.lr_max + args_opt.momentum = cfg.momentum + args_opt.weight_decay = cfg.weight_decay + args_opt.per_batch_size = cfg.batch_size + args_opt.num_classes = cfg.num_classes + args_opt.buffer_size = cfg.buffer_size + args_opt.ckpt_save_max = cfg.keep_checkpoint_max + args_opt.pad_mode = cfg.pad_mode + args_opt.padding = cfg.padding + args_opt.has_bias = cfg.has_bias + args_opt.batch_norm = cfg.batch_norm + args_opt.initialize_mode = cfg.initialize_mode + args_opt.has_dropout = cfg.has_dropout + + args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(','))) + args_opt.image_size = list(map(int, cfg.image_size.split(','))) + + return args_opt + + +def merge_args(args_opt, cloud_args): + """dictionary""" + args_dict = vars(args_opt) + if isinstance(cloud_args, dict): + for key_arg in cloud_args.keys(): + val = cloud_args[key_arg] + if key_arg in args_dict and val: + arg_type = type(args_dict[key_arg]) + if arg_type is not None: + val = arg_type(val) + args_dict[key_arg] = val + return args_opt + + +if __name__ == '__main__': + args = parse_args() + + device_num = int(os.environ.get("DEVICE_NUM", 1)) + if args.is_distributed: + if args.device_target == "Ascend": + init() + context.set_context(device_id=args.device_id) + elif args.device_target == "GPU": + init("nccl") + + args.rank = get_rank() + args.group_size = get_group_size() + device_num = args.group_size + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + else: + context.set_context(device_id=args.device_id) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + + # select for master rank save ckpt or all rank save, compatiable for model parallel + args.rank_save_ckpt_flag = 0 + if args.is_save_on_master: + if args.rank == 0: + args.rank_save_ckpt_flag = 1 + else: + args.rank_save_ckpt_flag = 1 + + # logger + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + args.logger = get_logger(args.outputs_dir, args.rank) + + if args.dataset == "cifar10": + dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size) + else: + dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, + args.rank, args.group_size) + + batch_num = dataset.get_dataset_size() + args.steps_per_epoch = dataset.get_dataset_size() + args.logger.save_args(args) + + # network + args.logger.important_info('start create network') + + # get network and init + network = vgg16(args.num_classes, args) + + # pre_trained + if args.pre_trained: + load_param_into_net(network, load_checkpoint(args.pre_trained)) + + # lr scheduler + if args.lr_scheduler == 'exponential': + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma, + ) + elif args.lr_scheduler == 'cosine_annealing': + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + elif args.lr_scheduler == 'step': + lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs, + total_epochs=args.max_epoch, steps_per_epoch=batch_num) + else: + raise NotImplementedError(args.lr_scheduler) + + # optimizer + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + + if args.dataset == "cifar10": + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) + else: + if not args.label_smooth: + args.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) + + loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2") + + # checkpoint save + progress_cb = ProgressMonitor(args) + callbacks = [progress_cb,] + if args.rank_save_ckpt_flag: + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch, + keep_checkpoint_max=args.ckpt_save_max) + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=args.outputs_dir, + prefix='{}'.format(args.rank)) + callbacks.append(ckpt_cb) + + model.train(args.max_epoch, dataset, callbacks=callbacks) diff --git a/model_zoo/official/cv/warpctc/README.md b/model_zoo/official/cv/warpctc/README.md new file mode 100644 index 0000000000..9f59e89f60 --- /dev/null +++ b/model_zoo/official/cv/warpctc/README.md @@ -0,0 +1,150 @@ +# Warpctc Example + +## Description + +These is an example of training Warpctc with self-generated captcha image dataset in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Generate captcha images. + +> The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test dataset by yourself or just run the script `scripts/run_process_data.sh`. By default, the shell script will generate 10000 test images and 50000 train images separately. +> ``` +> $ cd scripts +> $ sh run_process_data.sh +> +> # after execution, you will find the dataset like the follows: +> . +> └─warpctc +> └─data +> ├─ train # train dataset +> └─ test # evaluate dataset +> ... + + +## Structure + +```shell +. +└──warpctc + ├── README.md + ├── script + ├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs) + ├── run_distribute_train_for_gpu.sh # launch distributed training in GPU + ├── run_eval.sh # launch evaluation + ├── run_process_data.sh # launch dataset generation + └── run_standalone_train.sh # launch standalone training(1 pcs) + ├── src + ├── config.py # parameter configuration + ├── dataset.py # data preprocessing + ├── loss.py # ctcloss definition + ├── lr_generator.py # generate learning rate for each step + ├── metric.py # accuracy metric for warpctc network + ├── warpctc.py # warpctc network definition + └── warpctc_for_train.py # warp network with grad, loss and gradient clip + ├── eval.py # eval net + ├── process_data.py # dataset generation script + └── train.py # train net +``` + + +## Parameter configuration + +Parameters for both training and evaluation can be set in config.py. + +``` +"max_captcha_digits": 4, # max number of digits in each +"captcha_width": 160, # width of captcha images +"captcha_height": 64, # height of capthca images +"batch_size": 64, # batch size of input tensor +"epoch_size": 30, # only valid for taining, which is always 1 for inference +"hidden_size": 512, # hidden size in LSTM layers +"learning_rate": 0.01, # initial learning rate +"momentum": 0.9 # momentum of SGD optimizer +"save_checkpoint": True, # whether save checkpoint or not +"save_checkpoint_steps": 97, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step +"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./checkpoint", # path to save checkpoint +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training in Ascend +Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] + +# distributed training in GPU +Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH] + +# standalone training +Usage: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM] +``` + + +#### Launch + +``` +# distribute training example in Ascend +bash run_distribute_train.sh rank_table.json ../data/train + +# distribute training example in GPU +bash run_distribute_train_for_gpu.sh 8 ../data/train + +# standalone training example in Ascend +bash run_standalone_train.sh ../data/train Ascend + +# standalone training example in GPU +bash run_standalone_train.sh ../data/train GPU +``` + +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + +#### Result + +Training result will be stored in folder `scripts`, whose name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. + +``` +# distribute training result(8 pcs) +Epoch: [ 1/ 30], step: [ 97/ 97], loss: [0.5853/0.5853], time: [376813.7944] +Epoch: [ 2/ 30], step: [ 97/ 97], loss: [0.4007/0.4007], time: [75882.0951] +Epoch: [ 3/ 30], step: [ 97/ 97], loss: [0.0921/0.0921], time: [75150.9385] +Epoch: [ 4/ 30], step: [ 97/ 97], loss: [0.1472/0.1472], time: [75135.0193] +Epoch: [ 5/ 30], step: [ 97/ 97], loss: [0.0186/0.0186], time: [75199.5809] +... +``` + + +### Evaluation + +#### Usage + +``` +# evaluation +Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM] +``` + +#### Launch + +``` +# evaluation example in Ascend +bash run_eval.sh ../data/test warpctc-30-97.ckpt Ascend + +# evaluation example in GPU +bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. + +``` +result: {'WarpCTCAccuracy': 0.9901472929936306} +``` diff --git a/model_zoo/official/cv/warpctc/eval.py b/model_zoo/official/cv/warpctc/eval.py new file mode 100755 index 0000000000..bf8e4e9552 --- /dev/null +++ b/model_zoo/official/cv/warpctc/eval.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Warpctc evaluation""" +import os +import math as m +import random +import argparse +import numpy as np +from mindspore import context +from mindspore import dataset as de +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.loss import CTCLoss, CTCLossV2 +from src.config import config as cf +from src.dataset import create_dataset +from src.warpctc import StackedRNN, StackedRNNForGPU +from src.metric import WarpCTCAccuracy + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="Warpctc training") +parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") +parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") +parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='Running platform, choose from Ascend, GPU, and default is Ascend.') +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) +if args_opt.platform == 'Ascend': + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + +if __name__ == '__main__': + max_captcha_digits = cf.max_captcha_digits + input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, + batch_size=cf.batch_size, + device_target=args_opt.platform) + step_size = dataset.get_dataset_size() + if args_opt.platform == 'Ascend': + loss = CTCLoss(max_sequence_length=cf.captcha_width, + max_label_length=max_captcha_digits, + batch_size=cf.batch_size) + net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + else: + loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size) + net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + + # load checkpoint + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + # define model + model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)}) + # start evaluation + res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend') + print("result:", res, flush=True) diff --git a/model_zoo/warpctc/process_data.py b/model_zoo/official/cv/warpctc/process_data.py similarity index 100% rename from model_zoo/warpctc/process_data.py rename to model_zoo/official/cv/warpctc/process_data.py diff --git a/model_zoo/official/cv/warpctc/scripts/run_distribute_train.sh b/model_zoo/official/cv/warpctc/scripts/run_distribute_train.sh new file mode 100755 index 0000000000..bf607d4fc2 --- /dev/null +++ b/model_zoo/official/cv/warpctc/scripts/run_distribute_train.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ ! -f $PATH1 ]; then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" + exit 1 +fi + +if [ ! -d $PATH2 ]; then + echo "error: DATASET_PATH=$PATH2 is not a directory" + exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +for ((i = 0; i < ${DEVICE_NUM}; i++)); do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env >env.log + python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 & + cd .. +done diff --git a/model_zoo/official/cv/warpctc/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/warpctc/scripts/run_distribute_train_for_gpu.sh new file mode 100644 index 0000000000..86d951f663 --- /dev/null +++ b/model_zoo/official/cv/warpctc/scripts/run_distribute_train_for_gpu.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +RANK_SIZE=$1 +DATASET_PATH=$(get_real_path $2) + +if [ ! -d $DATASET_PATH ]; then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" + exit 1 +fi + +if [ -d "distribute_train" ]; then + rm -rf ./distribute_train +fi + +mkdir ./distribute_train +cp ../*.py ./distribute_train +cp -r ../src ./distribute_train +cd ./distribute_train || exit + +mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py \ + --dataset_path=$DATASET_PATH \ + --platform=GPU \ + --run_distribute > log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/warpctc/scripts/run_eval.sh b/model_zoo/official/cv/warpctc/scripts/run_eval.sh new file mode 100755 index 0000000000..cc0e3ce252 --- /dev/null +++ b/model_zoo/official/cv/warpctc/scripts/run_eval.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ]; then + echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) +PLATFORM=$3 + +if [ ! -d $PATH1 ]; then + echo "error: DATASET_PATH=$PATH1 is not a directory" + exit 1 +fi + +if [ ! -f $PATH2 ]; then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" + exit 1 +fi + +run_ascend() { + ulimit -u unlimited + export DEVICE_NUM=1 + export DEVICE_ID=0 + export RANK_SIZE=$DEVICE_NUM + export RANK_ID=0 + + if [ -d "eval" ]; then + rm -rf ./eval + fi + mkdir ./eval + cp ../*.py ./eval + cp -r ../src ./eval + cd ./eval || exit + env >env.log + echo "start evaluation for device $DEVICE_ID" + python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 & + cd .. +} + +run_gpu() { + if [ -d "eval" ]; then + rm -rf ./eval + fi + mkdir ./eval + cp ../*.py ./eval + cp -r ../src ./eval + cd ./eval || exit + env >env.log + python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 & + cd .. +} + +if [ "Ascend" == $PLATFORM ]; then + run_ascend $PATH1 $PATH2 +elif [ "GPU" == $PLATFORM ]; then + run_gpu $PATH1 $PATH2 +else + echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." +fi + diff --git a/model_zoo/official/cv/warpctc/scripts/run_process_data.sh b/model_zoo/official/cv/warpctc/scripts/run_process_data.sh new file mode 100755 index 0000000000..177b63512e --- /dev/null +++ b/model_zoo/official/cv/warpctc/scripts/run_process_data.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +CUR_PATH=$(dirname $PWD/$0) +cd $CUR_PATH/../ && + python process_data.py && + cd - &> /dev/null || exit diff --git a/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh b/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh new file mode 100755 index 0000000000..863683dd00 --- /dev/null +++ b/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PLATFORM=$2 + +if [ ! -d $PATH1 ]; then + echo "error: DATASET_PATH=$PATH1 is not a directory" + exit 1 +fi + +run_ascend() { + ulimit -u unlimited + export DEVICE_NUM=1 + export DEVICE_ID=0 + export RANK_ID=0 + export RANK_SIZE=1 + + echo "start training for device $DEVICE_ID" + env >env.log + python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 & + cd .. +} + +run_gpu() { + env >env.log + python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 & + cd .. +} + +if [ -d "train" ]; then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp -r ../src ./train +cd ./train || exit + +if [ "Ascend" == $PLATFORM ]; then + run_ascend $PATH1 +elif [ "GPU" == $PLATFORM ]; then + run_gpu $PATH1 +else + echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." +fi \ No newline at end of file diff --git a/model_zoo/official/cv/warpctc/src/config.py b/model_zoo/official/cv/warpctc/src/config.py new file mode 100755 index 0000000000..6c937a47b4 --- /dev/null +++ b/model_zoo/official/cv/warpctc/src/config.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Network parameters.""" +from easydict import EasyDict + +config = EasyDict({ + "max_captcha_digits": 4, + "captcha_width": 160, + "captcha_height": 64, + "batch_size": 64, + "epoch_size": 30, + "hidden_size": 512, + "learning_rate": 0.01, + "momentum": 0.9, + "save_checkpoint": True, + "save_checkpoint_steps": 97, + "keep_checkpoint_max": 30, + "save_checkpoint_path": "./checkpoint", +}) diff --git a/model_zoo/official/cv/warpctc/src/dataset.py b/model_zoo/official/cv/warpctc/src/dataset.py new file mode 100755 index 0000000000..784bb4d84d --- /dev/null +++ b/model_zoo/official/cv/warpctc/src/dataset.py @@ -0,0 +1,98 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Dataset preprocessing.""" +import os +import math as m +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as c +import mindspore.dataset.transforms.vision.c_transforms as vc +from PIL import Image +from src.config import config as cf + + +class _CaptchaDataset: + """ + create train or evaluation dataset for warpctc + + Args: + img_root_dir(str): root path of images + max_captcha_digits(int): max number of digits in images. + device_target(str): platform of training, support Ascend and GPU. + """ + + def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'): + if not os.path.exists(img_root_dir): + raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) + self.img_root_dir = img_root_dir + self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] + self.max_captcha_digits = max_captcha_digits + self.target = device_target + self.blank = 10 if self.target == 'Ascend' else 0 + self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names] + + def __len__(self): + return len(self.img_names) + + def __getitem__(self, item): + img_name = self.img_names[item] + im = Image.open(os.path.join(self.img_root_dir, img_name)) + r, g, b = im.split() + im = Image.merge("RGB", (b, g, r)) + image = np.array(im) + label_str = os.path.splitext(img_name)[0] + label_str = label_str[label_str.find('-') + 1:] + if self.target == 'Ascend': + label = [int(i) for i in label_str] + label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) + else: + label = [int(i) + 1 for i in label_str] + length = len(label) + label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) + label.append(length) + label = np.array(label) + return image, label + + +def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): + """ + create train or evaluation dataset for warpctc + + Args: + dataset_path(int): dataset path + batch_size(int): batch size of generated dataset, default is 1 + num_shards(int): number of devices + shard_id(int): rank id + device_target(str): platform of training, support Ascend and GPU + """ + + dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target) + ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id) + ds.set_dataset_size(m.ceil(len(dataset) / num_shards)) + image_trans = [ + vc.Rescale(1.0 / 255.0, 0.0), + vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), + vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), + vc.HWC2CHW() + ] + label_trans = [ + c.TypeCast(mstype.int32) + ] + ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) + ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) + + ds = ds.batch(batch_size, drop_remainder=True) + return ds diff --git a/model_zoo/official/cv/warpctc/src/loss.py b/model_zoo/official/cv/warpctc/src/loss.py new file mode 100755 index 0000000000..4bbb37eaa6 --- /dev/null +++ b/model_zoo/official/cv/warpctc/src/loss.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CTC Loss.""" +import numpy as np +from mindspore.nn.loss.loss import _Loss +from mindspore import Tensor, Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + + +class CTCLoss(_Loss): + """ + CTCLoss definition + + Args: + max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image + width + max_label_length(int): max number of label length for each input. + batch_size(int): batch size of input logits + """ + + def __init__(self, max_sequence_length, max_label_length, batch_size): + super(CTCLoss, self).__init__() + self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32), + name="sequence_length") + labels_indices = [] + for i in range(batch_size): + for j in range(max_label_length): + labels_indices.append([i, j]) + self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices") + self.reshape = P.Reshape() + self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True) + + def construct(self, logit, label): + labels_values = self.reshape(label, (-1,)) + loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) + return loss + + +class CTCLossV2(_Loss): + """ + CTCLoss definition + + Args: + max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width + batch_size(int): batch size of input logits + """ + + def __init__(self, max_sequence_length, batch_size): + super(CTCLossV2, self).__init__() + self.input_length = Tensor(np.array([max_sequence_length] * batch_size), mstype.int32) + self.reshape = P.Reshape() + self.ctc_loss = P.CTCLossV2() + + def construct(self, logit, label): + labels_values = label[:, :-1] + labels_length = label[:, -1] + loss, _ = self.ctc_loss(logit, labels_values, self.input_length, labels_length) + return loss diff --git a/model_zoo/warpctc/src/lr_schedule.py b/model_zoo/official/cv/warpctc/src/lr_schedule.py similarity index 100% rename from model_zoo/warpctc/src/lr_schedule.py rename to model_zoo/official/cv/warpctc/src/lr_schedule.py diff --git a/model_zoo/official/cv/warpctc/src/metric.py b/model_zoo/official/cv/warpctc/src/metric.py new file mode 100755 index 0000000000..c16c199ca5 --- /dev/null +++ b/model_zoo/official/cv/warpctc/src/metric.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Metric for accuracy evaluation.""" +from mindspore import nn + + +class WarpCTCAccuracy(nn.Metric): + """ + Define accuracy metric for warpctc network. + """ + + def __init__(self, device_target='Ascend'): + super(WarpCTCAccuracy).__init__() + self._correct_num = 0 + self._total_num = 0 + self._count = 0 + self.device_target = device_target + self.blank = 10 if device_target == 'Ascend' else 0 + + def clear(self): + self._correct_num = 0 + self._total_num = 0 + + def update(self, *inputs): + if len(inputs) != 2: + raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) + + y_pred = self._convert_data(inputs[0]) + y = self._convert_data(inputs[1]) + if self.device_target == 'GPU': + y = y[:, :-1] + + self._count += 1 + + pred_lbls = self._get_prediction(y_pred) + + for b_idx, target in enumerate(y): + if self._is_eq(pred_lbls[b_idx], target): + self._correct_num += 1 + self._total_num += 1 + + def eval(self): + if self._total_num == 0: + raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') + return self._correct_num / self._total_num + + def _is_eq(self, pred_lbl, target): + """ + check whether predict label is equal to target label + """ + target = target.tolist() + pred_diff = len(target) - len(pred_lbl) + if pred_diff > 0: + # padding by BLANK_LABLE + pred_lbl.extend([self.blank] * pred_diff) + return pred_lbl == target + + def _get_prediction(self, y_pred): + """ + parse predict result to labels + """ + seq_len, batch_size, _ = y_pred.shape + indices = y_pred.argmax(axis=2) + + lens = [seq_len] * batch_size + pred_lbls = [] + for i in range(batch_size): + idx = indices[:, i] + last_idx = self.blank + pred_lbl = [] + for j in range(lens[i]): + cur_idx = idx[j] + if cur_idx not in [last_idx, self.blank]: + pred_lbl.append(cur_idx) + last_idx = cur_idx + pred_lbls.append(pred_lbl) + return pred_lbls diff --git a/model_zoo/official/cv/warpctc/src/warpctc.py b/model_zoo/official/cv/warpctc/src/warpctc.py new file mode 100755 index 0000000000..5ac3c24352 --- /dev/null +++ b/model_zoo/official/cv/warpctc/src/warpctc.py @@ -0,0 +1,139 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Warpctc network definition.""" + +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import functional as F + + +class StackedRNN(nn.Cell): + """ + Define a stacked RNN network which contains two LSTM layers and one full-connect layer. + + Args: + input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for + captcha images. + batch_size(int): batch size of input data, default is 64 + hidden_size(int): the hidden size in LSTM layers, default is 512 + """ + def __init__(self, input_size, batch_size=64, hidden_size=512): + super(StackedRNN, self).__init__() + self.batch_size = batch_size + self.input_size = input_size + self.num_classes = 11 + self.reshape = P.Reshape() + self.cast = P.Cast() + k = (1 / hidden_size) ** 0.5 + self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1)) + .astype(np.float16), name="w1") + self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1)) + .astype(np.float16), name="w2") + self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1") + self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2") + + self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + + self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh") + + self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) + self.fc_bias = np.random.random((self.num_classes)).astype(np.float32) + + self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), + bias_init=Tensor(self.fc_bias)) + + self.fc.to_float(mstype.float32) + self.expand_dims = P.ExpandDims() + self.concat = P.Concat() + self.transpose = P.Transpose() + + def construct(self, x): + x = self.cast(x, mstype.float16) + x = self.transpose(x, (3, 0, 2, 1)) + x = self.reshape(x, (-1, self.batch_size, self.input_size)) + h1 = self.h1 + c1 = self.c1 + h2 = self.h2 + c2 = self.c2 + + c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1) + c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) + + h2_after_fc = self.fc(h2) + output = self.expand_dims(h2_after_fc, 0) + for i in range(1, F.shape(x)[0]): + c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1) + c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) + + h2_after_fc = self.fc(h2) + h2_after_fc = self.expand_dims(h2_after_fc, 0) + output = self.concat((output, h2_after_fc)) + + return output + + +class StackedRNNForGPU(nn.Cell): + """ + Define a stacked RNN network which contains two LSTM layers and one full-connect layer. + + Args: + input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for + captcha images. + batch_size(int): batch size of input data, default is 64 + hidden_size(int): the hidden size in LSTM layers, default is 512 + num_layer(int): the number of layer of LSTM. + """ + def __init__(self, input_size, batch_size=64, hidden_size=512, num_layer=2): + super(StackedRNNForGPU, self).__init__() + self.batch_size = batch_size + self.input_size = input_size + self.num_classes = 11 + self.reshape = P.Reshape() + self.cast = P.Cast() + k = (1 / hidden_size) ** 0.5 + weight_shape = 4 * hidden_size * (input_size + 3 * hidden_size + 4) + self.weight = Parameter(np.random.uniform(-k, k, (weight_shape, 1, 1)).astype(np.float32), name='weight') + self.h = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32)) + self.c = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32)) + + self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2) + self.lstm.weight = self.weight + + self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) + self.fc_bias = np.random.random(self.num_classes).astype(np.float32) + + self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), + bias_init=Tensor(self.fc_bias)) + + self.fc.to_float(mstype.float32) + self.expand_dims = P.ExpandDims() + self.concat = P.Concat() + self.transpose = P.Transpose() + + def construct(self, x): + x = self.transpose(x, (3, 0, 2, 1)) + x = self.reshape(x, (-1, self.batch_size, self.input_size)) + output, _ = self.lstm(x, (self.h, self.c)) + res = () + for i in range(F.shape(x)[0]): + res += (self.expand_dims(self.fc(output[i]), 0),) + res = self.concat(res) + return res diff --git a/model_zoo/official/cv/warpctc/src/warpctc_for_train.py b/model_zoo/official/cv/warpctc/src/warpctc_for_train.py new file mode 100755 index 0000000000..8cbf2f986a --- /dev/null +++ b/model_zoo/official/cv/warpctc/src/warpctc_for_train.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Automatic differentiation with grad clip.""" +from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, + _get_parallel_mode) +from mindspore.train.parallel_utils import ParallelMode +from mindspore.common import dtype as mstype +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.nn.cell import Cell +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +import numpy as np + +compute_norm = C.MultitypeFuncGraph("compute_norm") + + +@compute_norm.register("Tensor") +def _compute_norm(grad): + norm = nn.Norm() + norm = norm(F.cast(grad, mstype.float32)) + ret = F.expand_dims(F.cast(norm, mstype.float32), 0) + return ret + + +grad_div = C.MultitypeFuncGraph("grad_div") + + +@grad_div.register("Tensor", "Tensor") +def _grad_div(val, grad): + div = P.RealDiv() + mul = P.Mul() + grad = mul(grad, 10.0) + ret = div(grad, val) + return ret + + +class TrainOneStepCellWithGradClip(Cell): + """ + Network training package class. + + Wraps the network with an optimizer. The resulting Cell be trained with input data and label. + Backward graph with grad clip will be created in the construct function to do parameter updating. + Different parallel modes are available to run the training. + + Args: + network (Cell): The training network. + optimizer (Cell): Optimizer for updating the weights. + sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. + + Inputs: + - data (Tensor) - Tensor of shape :(N, ...). + - label (Tensor) - Tensor of shape :(N, ...). + + Outputs: + Tensor, a scalar Tensor with shape :math:`()`. + """ + + def __init__(self, network, optimizer, sens=1.0): + super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.hyper_map = C.HyperMap() + self.greater = P.Greater() + self.select = P.Select() + self.norm = nn.Norm(keep_dims=True) + self.dtype = P.DType() + self.cast = P.Cast() + self.concat = P.Concat(axis=0) + self.ten = Tensor(np.array([10.0]).astype(np.float32)) + parallel_mode = _get_parallel_mode() + if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): + self.reducer_flag = True + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, data, label): + weights = self.weights + loss = self.network(data, label) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(data, label, sens) + norm = self.hyper_map(F.partial(compute_norm), grads) + norm = self.concat(norm) + norm = self.norm(norm) + cond = self.greater(norm, self.cast(self.ten, self.dtype(norm))) + clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm))) + grads = self.hyper_map(F.partial(grad_div, clip_val), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) diff --git a/model_zoo/official/cv/warpctc/train.py b/model_zoo/official/cv/warpctc/train.py new file mode 100755 index 0000000000..380308653f --- /dev/null +++ b/model_zoo/official/cv/warpctc/train.py @@ -0,0 +1,106 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Warpctc training""" +import os +import math as m +import random +import argparse +import numpy as np +import mindspore.nn as nn +from mindspore import context +from mindspore import dataset as de +from mindspore.train.model import Model, ParallelMode +from mindspore.nn.wrap import WithLossCell +from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint +from mindspore.communication.management import init, get_group_size, get_rank + +from src.loss import CTCLoss, CTCLossV2 +from src.config import config as cf +from src.dataset import create_dataset +from src.warpctc import StackedRNN, StackedRNNForGPU +from src.warpctc_for_train import TrainOneStepCellWithGradClip +from src.lr_schedule import get_lr + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="Warpctc training") +parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') +parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='Running platform, choose from Ascend, GPU, and default is Ascend.') +parser.set_defaults(run_distribute=False) +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) +if args_opt.platform == 'Ascend': + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + + +if __name__ == '__main__': + lr_scale = 1 + if args_opt.run_distribute: + if args_opt.platform == 'Ascend': + init() + lr_scale = 1 + device_num = int(os.environ.get("RANK_SIZE")) + rank = int(os.environ.get("RANK_ID")) + else: + init('nccl') + lr_scale = 0.5 + device_num = get_group_size() + rank = get_rank() + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + else: + device_num = 1 + rank = 0 + + max_captcha_digits = cf.max_captcha_digits + input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size, + num_shards=device_num, shard_id=rank, device_target=args_opt.platform) + step_size = dataset.get_dataset_size() + # define lr + lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale + lr = get_lr(cf.epoch_size, step_size, lr_init) + if args_opt.platform == 'Ascend': + loss = CTCLoss(max_sequence_length=cf.captcha_width, + max_label_length=max_captcha_digits, + batch_size=cf.batch_size) + net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) + else: + loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size) + net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) + + net = WithLossCell(net, loss) + net = TrainOneStepCellWithGradClip(net, opt).set_train() + # define model + model = Model(net) + # define callbacks + callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] + if cf.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, + keep_checkpoint_max=cf.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path + str(rank), config=config_ck) + callbacks.append(ckpt_cb) + model.train(cf.epoch_size, dataset, callbacks=callbacks) diff --git a/model_zoo/official/cv/yolov3_darknet53/README.md b/model_zoo/official/cv/yolov3_darknet53/README.md new file mode 100644 index 0000000000..fcb3230e9d --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/README.md @@ -0,0 +1,133 @@ +# YOLOV3-DarkNet53 Example + +## Description + +This is an example of training YOLOV3-DarkNet53 with COCO2014 dataset in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset COCO2014. + +> Unzip the COCO2014 dataset to any path you want, the folder should include train and eval dataset as follows: + +``` +. +└─dataset + ├─train2014 + ├─val2014 + └─annotations +``` + +## Structure + +```shell +. +└─yolov3_darknet53 + ├─README.md + ├─scripts + ├─run_standalone_train.sh # launch standalone training(1p) + ├─run_distribute_train.sh # launch distributed training(8p) + └─run_eval.sh # launch evaluating + ├─src + ├─__init__.py # python init file + ├─config.py # parameter configuration + ├─darknet.py # backbone of network + ├─distributed_sampler.py # iterator of dataset + ├─initializer.py # initializer of parameters + ├─logger.py # log function + ├─loss.py # loss function + ├─lr_scheduler.py # generate learning rate + ├─transforms.py # Preprocess data + ├─util.py # util function + ├─yolo.py # yolov3 network + ├─yolo_dataset.py # create dataset for YOLOV3 + ├─eval.py # eval net + └─train.py # train net +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE] + +# standalone training +sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] +``` + +#### Launch + +```bash +# distributed training example(8p) +sh run_distribute_train.sh dataset/coco2014 backbone/backbone.ckpt rank_table_8p.json + +# standalone training example(1p) +sh run_standalone_train.sh dataset/coco2014 backbone/backbone.ckpt +``` + +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + +#### Result + +Training result will be stored in the scripts path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.txt. + +``` +# distribute training result(8p) +epoch[0], iter[0], loss:14623.384766, 1.23 imgs/sec, lr:7.812499825377017e-05 +epoch[0], iter[100], loss:1486.253051, 15.01 imgs/sec, lr:0.007890624925494194 +epoch[0], iter[200], loss:288.579535, 490.41 imgs/sec, lr:0.015703124925494194 +epoch[0], iter[300], loss:153.136754, 531.99 imgs/sec, lr:0.023515624925494194 +epoch[1], iter[400], loss:106.429322, 405.14 imgs/sec, lr:0.03132812678813934 +... +epoch[318], iter[102000], loss:34.135306, 431.06 imgs/sec, lr:9.63797629083274e-06 +epoch[319], iter[102100], loss:35.652469, 449.52 imgs/sec, lr:2.409552052995423e-06 +epoch[319], iter[102200], loss:34.652273, 384.02 imgs/sec, lr:2.409552052995423e-06 +epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e-06 +... +``` + +### Infer + +#### Usage + +``` +# infer +sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` + +#### Launch + +```bash +# infer with checkpoint +sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt + +``` + +> checkpoint can be produced in training process. + + +#### Result + +Inference result will be stored in the scripts path, whose folder name is "eval". Under this, you can find result like the followings in log.txt. + +``` +=============coco eval reulst========= + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.528 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.127 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.323 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.428 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.259 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.398 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.423 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.224 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.442 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551 +``` diff --git a/model_zoo/official/cv/yolov3_darknet53/eval.py b/model_zoo/official/cv/yolov3_darknet53/eval.py new file mode 100644 index 0000000000..f04ed2447c --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/eval.py @@ -0,0 +1,328 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""YoloV3 eval.""" +import os +import argparse +import datetime +import time +import sys +from collections import defaultdict + +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from mindspore import Tensor +from mindspore.train import ParallelMode +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore as ms + +from src.yolo import YOLOV3DarkNet53 +from src.logger import get_logger +from src.yolo_dataset import create_yolo_dataset +from src.config import ConfigYOLOV3DarkNet53 + +devid = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid) + + +class Redirct: + def __init__(self): + self.content = "" + + def write(self, content): + self.content += content + + def flush(self): + self.content = "" + + +class DetectionEngine: + """Detection engine.""" + def __init__(self, args): + self.ignore_threshold = args.ignore_threshold + self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', + 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', + 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', + 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + self.num_classes = len(self.labels) + self.results = {} + self.file_path = '' + self.save_prefix = args.outputs_dir + self.annFile = args.annFile + self._coco = COCO(self.annFile) + self._img_ids = list(sorted(self._coco.imgs.keys())) + self.det_boxes = [] + self.nms_thresh = args.nms_thresh + self.coco_catIds = self._coco.getCatIds() + + def do_nms_for_results(self): + """Get result boxes.""" + for img_id in self.results: + for clsi in self.results[img_id]: + dets = self.results[img_id][clsi] + dets = np.array(dets) + keep_index = self._nms(dets, self.nms_thresh) + + keep_box = [{'image_id': int(img_id), + 'category_id': int(clsi), + 'bbox': list(dets[i][:4].astype(float)), + 'score': dets[i][4].astype(float)} + for i in keep_index] + self.det_boxes.extend(keep_box) + + def _nms(self, predicts, threshold): + """Calculate NMS.""" + # conver xywh -> xmin ymin xmax ymax + x1 = predicts[:, 0] + y1 = predicts[:, 1] + x2 = x1 + predicts[:, 2] + y2 = y1 + predicts[:, 3] + scores = predicts[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + reserved_boxes = [] + while order.size > 0: + i = order[0] + reserved_boxes.append(i) + max_x1 = np.maximum(x1[i], x1[order[1:]]) + max_y1 = np.maximum(y1[i], y1[order[1:]]) + min_x2 = np.minimum(x2[i], x2[order[1:]]) + min_y2 = np.minimum(y2[i], y2[order[1:]]) + + intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1) + intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1) + intersect_area = intersect_w * intersect_h + ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area) + + indexs = np.where(ovr <= threshold)[0] + order = order[indexs + 1] + return reserved_boxes + + def write_result(self): + """Save result to file.""" + import json + t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S') + try: + self.file_path = self.save_prefix + '/predict' + t + '.json' + f = open(self.file_path, 'w') + json.dump(self.det_boxes, f) + except IOError as e: + raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e))) + else: + f.close() + return self.file_path + + def get_eval_result(self): + """Get eval result.""" + cocoGt = COCO(self.annFile) + cocoDt = cocoGt.loadRes(self.file_path) + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + cocoEval.evaluate() + cocoEval.accumulate() + rdct = Redirct() + stdout = sys.stdout + sys.stdout = rdct + cocoEval.summarize() + sys.stdout = stdout + return rdct.content + + def detect(self, outputs, batch, image_shape, image_id): + """Detect boxes.""" + outputs_num = len(outputs) + # output [|32, 52, 52, 3, 85| ] + for batch_id in range(batch): + for out_id in range(outputs_num): + # 32, 52, 52, 3, 85 + out_item = outputs[out_id] + # 52, 52, 3, 85 + out_item_single = out_item[batch_id, :] + # get number of items in one head, [B, gx, gy, anchors, 5+80] + dimensions = out_item_single.shape[:-1] + out_num = 1 + for d in dimensions: + out_num *= d + ori_w, ori_h = image_shape[batch_id] + img_id = int(image_id[batch_id]) + x = out_item_single[..., 0] * ori_w + y = out_item_single[..., 1] * ori_h + w = out_item_single[..., 2] * ori_w + h = out_item_single[..., 3] * ori_h + + conf = out_item_single[..., 4:5] + cls_emb = out_item_single[..., 5:] + + cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1) + x = x.reshape(-1) + y = y.reshape(-1) + w = w.reshape(-1) + h = h.reshape(-1) + cls_emb = cls_emb.reshape(-1, 80) + conf = conf.reshape(-1) + cls_argmax = cls_argmax.reshape(-1) + + x_top_left = x - w / 2. + y_top_left = y - h / 2. + # creat all False + flag = np.random.random(cls_emb.shape) > sys.maxsize + for i in range(flag.shape[0]): + c = cls_argmax[i] + flag[i, c] = True + confidence = cls_emb[flag] * conf + for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax): + if confi < self.ignore_threshold: + continue + if img_id not in self.results: + self.results[img_id] = defaultdict(list) + x_lefti = max(0, x_lefti) + y_lefti = max(0, y_lefti) + wi = min(wi, ori_w) + hi = min(hi, ori_h) + # transform catId to match coco + coco_clsi = self.coco_catIds[clsi] + self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi]) + + +def parse_args(): + """Parse arguments.""" + parser = argparse.ArgumentParser('mindspore coco testing') + + # dataset related + parser.add_argument('--data_dir', type=str, default='', help='train data dir') + parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu') + + # network related + parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load') + + # logging related + parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location') + + # detect_related + parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS') + parser.add_argument('--annFile', type=str, default='', help='path to annotation') + parser.add_argument('--testing_shape', type=str, default='', help='shape for test ') + parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes') + + args, _ = parser.parse_known_args() + + args.data_root = os.path.join(args.data_dir, 'val2014') + args.annFile = os.path.join(args.data_dir, 'annotations/instances_val2014.json') + + return args + + +def conver_testing_shape(args): + """Convert testing shape to list.""" + testing_shape = [int(args.testing_shape), int(args.testing_shape)] + return testing_shape + + +def test(): + """The function of eval.""" + start_time = time.time() + args = parse_args() + + # logger + args.outputs_dir = os.path.join(args.log_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + rank_id = int(os.environ.get('RANK_ID')) + args.logger = get_logger(args.outputs_dir, rank_id) + + context.reset_auto_parallel_context() + parallel_mode = ParallelMode.STAND_ALONE + context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1) + + args.logger.info('Creating Network....') + network = YOLOV3DarkNet53(is_training=False) + + args.logger.info(args.pretrained) + if os.path.isfile(args.pretrained): + param_dict = load_checkpoint(args.pretrained) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.'): + continue + elif key.startswith('yolo_network.'): + param_dict_new[key[13:]] = values + else: + param_dict_new[key] = values + load_param_into_net(network, param_dict_new) + args.logger.info('load_model {} success'.format(args.pretrained)) + else: + args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained)) + assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained)) + exit(1) + + data_root = args.data_root + ann_file = args.annFile + + config = ConfigYOLOV3DarkNet53() + if args.testing_shape: + config.test_img_shape = conver_testing_shape(args) + + ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size, + max_epoch=1, device_num=1, rank=rank_id, shuffle=False, + config=config) + + args.logger.info('testing shape : {}'.format(config.test_img_shape)) + args.logger.info('totol {} images to eval'.format(data_size)) + + network.set_train(False) + + # init detection engine + detection = DetectionEngine(args) + + input_shape = Tensor(tuple(config.test_img_shape), ms.float32) + args.logger.info('Start inference....') + for i, data in enumerate(ds.create_dict_iterator()): + image = Tensor(data["image"]) + + image_shape = Tensor(data["image_shape"]) + image_id = Tensor(data["img_id"]) + + prediction = network(image, input_shape) + output_big, output_me, output_small = prediction + output_big = output_big.asnumpy() + output_me = output_me.asnumpy() + output_small = output_small.asnumpy() + image_id = image_id.asnumpy() + image_shape = image_shape.asnumpy() + + detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id) + if i % 1000 == 0: + args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100)) + + args.logger.info('Calculating mAP...') + detection.do_nms_for_results() + result_file_path = detection.write_result() + args.logger.info('result file path: {}'.format(result_file_path)) + eval_result = detection.get_eval_result() + + cost_time = time.time() - start_time + args.logger.info('\n=============coco eval reulst=========\n' + eval_result) + args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) + + +if __name__ == "__main__": + test() diff --git a/model_zoo/official/cv/yolov3_darknet53/scripts/run_distribute_train.sh b/model_zoo/official/cv/yolov3_darknet53/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..fcc0ef2867 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/scripts/run_distribute_train.sh @@ -0,0 +1,81 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] +then + echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +PRETRAINED_BACKBONE=$(get_real_path $2) +RANK_TABLE_FILE=$(get_real_path $3) +echo $DATASET_PATH +echo $PRETRAINED_BACKBONE +echo $RANK_TABLE_FILE + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $PRETRAINED_BACKBONE ] +then + echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file" +exit 1 +fi + +if [ ! -f $RANK_TABLE_FILE ] +then + echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file" +exit 1 +fi + +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILEH=$RANK_TABLE_FILE + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + python train.py \ + --data_dir=$DATASET_PATH \ + --pretrained_backbone=$PRETRAINED_BACKBONE \ + --is_distributed=1 \ + --lr=0.1 \ + --T_max=320 \ + --max_epoch=320 \ + --warmup_epochs=4 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & + cd .. +done diff --git a/model_zoo/yolov3_darknet53/scripts/run_eval.sh b/model_zoo/official/cv/yolov3_darknet53/scripts/run_eval.sh similarity index 100% rename from model_zoo/yolov3_darknet53/scripts/run_eval.sh rename to model_zoo/official/cv/yolov3_darknet53/scripts/run_eval.sh diff --git a/model_zoo/yolov3_darknet53/scripts/run_standalone_train.sh b/model_zoo/official/cv/yolov3_darknet53/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/yolov3_darknet53/scripts/run_standalone_train.sh rename to model_zoo/official/cv/yolov3_darknet53/scripts/run_standalone_train.sh diff --git a/model_zoo/official/cv/yolov3_darknet53/src/__init__.py b/model_zoo/official/cv/yolov3_darknet53/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/yolov3_darknet53/src/config.py b/model_zoo/official/cv/yolov3_darknet53/src/config.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/config.py rename to model_zoo/official/cv/yolov3_darknet53/src/config.py diff --git a/model_zoo/yolov3_darknet53/src/darknet.py b/model_zoo/official/cv/yolov3_darknet53/src/darknet.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/darknet.py rename to model_zoo/official/cv/yolov3_darknet53/src/darknet.py diff --git a/model_zoo/yolov3_darknet53/src/distributed_sampler.py b/model_zoo/official/cv/yolov3_darknet53/src/distributed_sampler.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/distributed_sampler.py rename to model_zoo/official/cv/yolov3_darknet53/src/distributed_sampler.py diff --git a/model_zoo/official/cv/yolov3_darknet53/src/initializer.py b/model_zoo/official/cv/yolov3_darknet53/src/initializer.py new file mode 100644 index 0000000000..c66cc74acf --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/src/initializer.py @@ -0,0 +1,181 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parameter init.""" +import math +from functools import reduce +import numpy as np +from mindspore.common import initializer as init +from mindspore.common.initializer import Initializer as MeInitializer +import mindspore.nn as nn + + +np.random.seed(5) + + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def _assignment(arr, num): + """Assign the value of 'num' and 'arr'.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + + +def _calculate_correct_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(array) + return fan_in if mode == 'fan_in' else fan_out + + +def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Fills the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + + Examples: + >>> w = np.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + """ + fan = _calculate_correct_fan(arr, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return np.random.uniform(-bound, bound, arr.shape) + + +def _calculate_fan_in_and_fan_out(arr): + """Calculate fan in and fan out.""" + dimensions = len(arr.shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions") + + num_input_fmaps = arr.shape[1] + num_output_fmaps = arr.shape[0] + receptive_field_size = 1 + if dimensions > 2: + receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:]) + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +class KaimingUniform(MeInitializer): + """Kaiming uniform initializer.""" + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingUniform, self).__init__() + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + + def _initialize(self, arr): + tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity) + _assignment(arr, tmp) + + +def default_recurisive_init(custom_cell): + """Initialize parameter.""" + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype) + if cell.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + cell.bias.default_input = init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), + cell.weight.shape, + cell.weight.dtype) + if cell.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) + bound = 1 / math.sqrt(fan_in) + cell.bias.default_input = init.initializer(init.Uniform(bound), + cell.bias.shape, + cell.bias.dtype) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass diff --git a/model_zoo/yolov3_darknet53/src/logger.py b/model_zoo/official/cv/yolov3_darknet53/src/logger.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/logger.py rename to model_zoo/official/cv/yolov3_darknet53/src/logger.py diff --git a/model_zoo/yolov3_darknet53/src/loss.py b/model_zoo/official/cv/yolov3_darknet53/src/loss.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/loss.py rename to model_zoo/official/cv/yolov3_darknet53/src/loss.py diff --git a/model_zoo/yolov3_darknet53/src/lr_scheduler.py b/model_zoo/official/cv/yolov3_darknet53/src/lr_scheduler.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/lr_scheduler.py rename to model_zoo/official/cv/yolov3_darknet53/src/lr_scheduler.py diff --git a/model_zoo/official/cv/yolov3_darknet53/src/transforms.py b/model_zoo/official/cv/yolov3_darknet53/src/transforms.py new file mode 100644 index 0000000000..4756a141f0 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/src/transforms.py @@ -0,0 +1,570 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Preprocess dataset.""" +import random +import threading +import copy + +import numpy as np +from PIL import Image +import cv2 + + +def _rand(a=0., b=1.): + return np.random.rand() * (b - a) + a + + +def bbox_iou(bbox_a, bbox_b, offset=0): + """Calculate Intersection-Over-Union(IOU) of two bounding boxes. + + Parameters + ---------- + bbox_a : numpy.ndarray + An ndarray with shape :math:`(N, 4)`. + bbox_b : numpy.ndarray + An ndarray with shape :math:`(M, 4)`. + offset : float or int, default is 0 + The ``offset`` is used to control the whether the width(or height) is computed as + (right - left + ``offset``). + Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. + + Returns + ------- + numpy.ndarray + An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of + bounding boxes in `bbox_a` and `bbox_b`. + + """ + if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: + raise IndexError("Bounding boxes axis 1 must have at least length 4") + + tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) + br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) + + area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) + area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) + area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) + return area_i / (area_a[:, None] + area_b - area_i) + + +def statistic_normalize_img(img, statistic_norm): + """Statistic normalize images.""" + # img: RGB + if isinstance(img, Image.Image): + img = np.array(img) + img = img/255. + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + if statistic_norm: + img = (img - mean) / std + return img + + +def get_interp_method(interp, sizes=()): + """ + Get the interpolation method for resize functions. + The major purpose of this function is to wrap a random interp method selection + and a auto-estimation method. + + Note: + When shrinking an image, it will generally look best with AREA-based + interpolation, whereas, when enlarging an image, it will generally look best + with Bicubic or Bilinear. + + Args: + interp (int): Interpolation method for all resizing operations. + + - 0: Nearest Neighbors Interpolation. + - 1: Bilinear interpolation. + - 2: Bicubic interpolation over 4x4 pixel neighborhood. + - 3: Nearest Neighbors. Originally it should be Area-based, as we cannot find Area-based, + so we use NN instead. Area-based (resampling using pixel area relation). + It may be a preferred method for image decimation, as it gives moire-free results. + But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default). + - 4: Lanczos interpolation over 8x8 pixel neighborhood. + - 9: Cubic for enlarge, area for shrink, bilinear for others. + - 10: Random select from interpolation method mentioned above. + + sizes (tuple): Format should like (old_height, old_width, new_height, new_width), + if None provided, auto(9) will return Area(2) anyway. Default: () + + Returns: + int, interp method from 0 to 4. + """ + if interp == 9: + if sizes: + assert len(sizes) == 4 + oh, ow, nh, nw = sizes + if nh > oh and nw > ow: + return 2 + if nh < oh and nw < ow: + return 0 + return 1 + return 2 + if interp == 10: + return random.randint(0, 4) + if interp not in (0, 1, 2, 3, 4): + raise ValueError('Unknown interp method %d' % interp) + return interp + + +def pil_image_reshape(interp): + """Reshape pil image.""" + reshape_type = { + 0: Image.NEAREST, + 1: Image.BILINEAR, + 2: Image.BICUBIC, + 3: Image.NEAREST, + 4: Image.LANCZOS, + } + return reshape_type[interp] + + +def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, + max_boxes, label_smooth, label_smooth_factor=0.1): + """Preprocess annotation boxes.""" + anchors = np.array(anchors) + num_layers = anchors.shape[0] // 3 + anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + true_boxes = np.array(true_boxes, dtype='float32') + input_shape = np.array(in_shape, dtype='int32') + boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. + # trans to box center point + boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] + # input_shape is [h, w] + true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] + true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] + # true_boxes = [xywh] + + grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] + # grid_shape [h, w] + y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), + 5 + num_classes), dtype='float32') for l in range(num_layers)] + # y_true [gridy, gridx] + anchors = np.expand_dims(anchors, 0) + anchors_max = anchors / 2. + anchors_min = -anchors_max + valid_mask = boxes_wh[..., 0] > 0 + + wh = boxes_wh[valid_mask] + if wh.size > 0: + wh = np.expand_dims(wh, -2) + boxes_max = wh / 2. + boxes_min = -boxes_max + + intersect_min = np.maximum(boxes_min, anchors_min) + intersect_max = np.minimum(boxes_max, anchors_max) + intersect_wh = np.maximum(intersect_max - intersect_min, 0.) + intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + box_area = wh[..., 0] * wh[..., 1] + anchor_area = anchors[..., 0] * anchors[..., 1] + iou = intersect_area / (box_area + anchor_area - intersect_area) + + best_anchor = np.argmax(iou, axis=-1) + for t, n in enumerate(best_anchor): + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x + + k = anchor_mask[l].index(n) + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + + # lable-smooth + if label_smooth: + sigma = label_smooth_factor/(num_classes-1) + y_true[l][j, i, k, 5:] = sigma + y_true[l][j, i, k, 5+c] = 1-label_smooth_factor + else: + y_true[l][j, i, k, 5 + c] = 1. + + # pad_gt_boxes for avoiding dynamic shape + pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) + + mask0 = np.reshape(y_true[0][..., 4:5], [-1]) + gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) + # gt_box [boxes, [x,y,w,h]] + gt_box0 = gt_box0[mask0 == 1] + # gt_box0: get all boxes which have object + pad_gt_box0[:gt_box0.shape[0]] = gt_box0 + # gt_box0.shape[0]: total number of boxes in gt_box0 + # top N of pad_gt_box0 is real box, and after are pad by zero + + mask1 = np.reshape(y_true[1][..., 4:5], [-1]) + gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) + gt_box1 = gt_box1[mask1 == 1] + pad_gt_box1[:gt_box1.shape[0]] = gt_box1 + + mask2 = np.reshape(y_true[2][..., 4:5], [-1]) + gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) + + gt_box2 = gt_box2[mask2 == 1] + pad_gt_box2[:gt_box2.shape[0]] = gt_box2 + return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 + + +def _reshape_data(image, image_size): + """Reshape image.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + ori_w, ori_h = image.size + ori_image_shape = np.array([ori_w, ori_h], np.int32) + # original image shape fir:H sec:W + h, w = image_size + interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) + image = image.resize((w, h), pil_image_reshape(interp)) + image_data = statistic_normalize_img(image, statistic_norm=True) + if len(image_data.shape) == 2: + image_data = np.expand_dims(image_data, axis=-1) + image_data = np.concatenate([image_data, image_data, image_data], axis=-1) + image_data = image_data.astype(np.float32) + return image_data, ori_image_shape + + +def color_distortion(img, hue, sat, val, device_num): + """Color distortion.""" + hue = _rand(-hue, hue) + sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) + val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) + if device_num != 1: + cv2.setNumThreads(1) + x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) + x = x / 255. + x[..., 0] += hue + x[..., 0][x[..., 0] > 1] -= 1 + x[..., 0][x[..., 0] < 0] += 1 + x[..., 1] *= sat + x[..., 2] *= val + x[x > 1] = 1 + x[x < 0] = 0 + x = x * 255. + x = x.astype(np.uint8) + image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) + return image_data + + +def filp_pil_image(img): + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def convert_gray_to_color(img): + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + img = np.concatenate([img, img, img], axis=-1) + return img + + +def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): + iou = bbox_iou(box, crop_box) + return min_iou <= iou.min() and max_iou >= iou.max() + + +def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): + """Choose candidate by constraints.""" + if use_constraints: + constraints = ( + (0.1, None), + (0.3, None), + (0.5, None), + (0.7, None), + (0.9, None), + (None, 1), + ) + else: + constraints = ( + (None, None), + ) + # add default candidate + candidates = [(0, 0, input_w, input_h)] + for constraint in constraints: + min_iou, max_iou = constraint + min_iou = -np.inf if min_iou is None else min_iou + max_iou = np.inf if max_iou is None else max_iou + + for _ in range(max_trial): + # box_data should have at least one box + new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) + scale = _rand(0.25, 2) + + if new_ar < 1: + nh = int(scale * input_h) + nw = int(nh * new_ar) + else: + nw = int(scale * input_w) + nh = int(nw / new_ar) + + dx = int(_rand(0, input_w - nw)) + dy = int(_rand(0, input_h - nh)) + + if box.size > 0: + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + + crop_box = np.array((0, 0, input_w, input_h)) + if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): + continue + else: + candidates.append((dx, dy, nw, nh)) + else: + raise Exception("!!! annotation box is less than 1") + return candidates + + +def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, + image_h, flip, box, box_data, allow_outside_center): + """Calculate correct boxes.""" + while candidates: + if len(candidates) > 1: + # ignore default candidate which do not crop + candidate = candidates.pop(np.random.randint(1, len(candidates))) + else: + candidate = candidates.pop(np.random.randint(0, len(candidates))) + dx, dy, nw, nh = candidate + t_box = copy.deepcopy(box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy + if flip: + t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] + + if allow_outside_center: + pass + else: + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)] + t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, + (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)] + + # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero + t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 + # recorrect w,h not higher than input size + t_box[:, 2][t_box[:, 2] > input_w] = input_w + t_box[:, 3][t_box[:, 3] > input_h] = input_h + box_w = t_box[:, 2] - t_box[:, 0] + box_h = t_box[:, 3] - t_box[:, 1] + # discard invalid box: w or h smaller than 1 pixel + t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] + + if t_box.shape[0] > 0: + # break if number of find t_box + box_data[: len(t_box)] = t_box + return box_data, candidate + raise Exception('all candidates can not satisfied re-correct bbox') + + +def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, + anchors, num_classes, max_trial=10, device_num=1): + """Crop an image randomly with bounding box constraints. + + This data augmentation is used in training of + Single Shot Multibox Detector [#]_. More details can be found in + data augmentation section of the original paper. + .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, + Scott Reed, Cheng-Yang Fu, Alexander C. Berg. + SSD: Single Shot MultiBox Detector. ECCV 2016.""" + + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + image_w, image_h = image.size + input_h, input_w = image_input_size + + np.random.shuffle(box) + if len(box) > max_boxes: + box = box[:max_boxes] + flip = _rand() < .5 + box_data = np.zeros((max_boxes, 5)) + + candidates = _choose_candidate_by_constraints(use_constraints=False, + max_trial=max_trial, + input_w=input_w, + input_h=input_h, + image_w=image_w, + image_h=image_h, + jitter=jitter, + box=box) + box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, + input_w=input_w, + input_h=input_h, + image_w=image_w, + image_h=image_h, + flip=flip, + box=box, + box_data=box_data, + allow_outside_center=True) + dx, dy, nw, nh = candidate + interp = get_interp_method(interp=10) + image = image.resize((nw, nh), pil_image_reshape(interp)) + # place image, gray color as back graoud + new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) + new_image.paste(image, (dx, dy)) + image = new_image + + if flip: + image = filp_pil_image(image) + + image = np.array(image) + + image = convert_gray_to_color(image) + + image_data = color_distortion(image, hue, sat, val, device_num) + image_data = statistic_normalize_img(image_data, statistic_norm=True) + + image_data = image_data.astype(np.float32) + + return image_data, box_data + + +def preprocess_fn(image, box, config, input_size, device_num): + """Preprocess data function.""" + config_anchors = config.anchor_scales + anchors = np.array([list(x) for x in config_anchors]) + max_boxes = config.max_box + num_classes = config.num_classes + jitter = config.jitter + hue = config.hue + sat = config.saturation + val = config.value + image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, + image_input_size=input_size, max_boxes=max_boxes, + num_classes=num_classes, anchors=anchors, device_num=device_num) + return image, anno + + +def reshape_fn(image, img_id, config): + input_size = config.test_img_shape + image, ori_image_shape = _reshape_data(image, image_size=input_size) + return image, ori_image_shape, img_id + + +class MultiScaleTrans: + """Multi scale transform.""" + def __init__(self, config, device_num): + self.config = config + self.seed = 0 + self.size_list = [] + self.resize_rate = config.resize_rate + self.dataset_size = config.dataset_size + self.size_dict = {} + self.seed_num = int(1e6) + self.seed_list = self.generate_seed_list(seed_num=self.seed_num) + self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) + self.device_num = device_num + + def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): + seed_list = [] + random.seed(init_seed) + for _ in range(seed_num): + seed = random.randint(seed_range[0], seed_range[1]) + seed_list.append(seed) + return seed_list + + def __call__(self, imgs, annos, batchInfo): + epoch_num = batchInfo.get_epoch_num() + size_idx = int(batchInfo.get_batch_num() / self.resize_rate) + seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] + ret_imgs = [] + ret_annos = [] + + if self.size_dict.get(seed_key, None) is None: + random.seed(seed_key) + new_size = random.choice(self.config.multi_scale) + self.size_dict[seed_key] = new_size + seed = seed_key + + input_size = self.size_dict[seed] + for img, anno in zip(imgs, annos): + img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) + ret_imgs.append(img.transpose(2, 0, 1).copy()) + ret_annos.append(anno) + return np.array(ret_imgs), np.array(ret_annos) + + +def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): + """Preprocess true box for multi-thread.""" + i = 0 + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1[result_index + i] = bbox_true_1 + batch_bbox_true_2[result_index + i] = bbox_true_2 + batch_bbox_true_3[result_index + i] = bbox_true_3 + batch_gt_box1[result_index + i] = gt_box1 + batch_gt_box2[result_index + i] = gt_box2 + batch_gt_box3[result_index + i] = gt_box3 + i = i + 1 + + +def batch_preprocess_true_box(annos, config, input_shape): + """Preprocess true box with multi-thread.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + threads = [] + + step = 4 + for index in range(0, len(annos), step): + for _ in range(step): + batch_bbox_true_1.append(None) + batch_bbox_true_2.append(None) + batch_bbox_true_3.append(None) + batch_gt_box1.append(None) + batch_gt_box2.append(None) + batch_gt_box3.append(None) + step_anno = annos[index: index + step] + t = threading.Thread(target=thread_batch_preprocess_true_box, + args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, + batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) + t.start() + threads.append(t) + + for t in threads: + t.join() + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) + + +def batch_preprocess_true_box_single(annos, config, input_shape): + """Preprocess true boxes.""" + batch_bbox_true_1 = [] + batch_bbox_true_2 = [] + batch_bbox_true_3 = [] + batch_gt_box1 = [] + batch_gt_box2 = [] + batch_gt_box3 = [] + for anno in annos: + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, + num_classes=config.num_classes, max_boxes=config.max_box, + label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) + batch_bbox_true_1.append(bbox_true_1) + batch_bbox_true_2.append(bbox_true_2) + batch_bbox_true_3.append(bbox_true_3) + batch_gt_box1.append(gt_box1) + batch_gt_box2.append(gt_box2) + batch_gt_box3.append(gt_box3) + + return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ + np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) diff --git a/model_zoo/official/cv/yolov3_darknet53/src/util.py b/model_zoo/official/cv/yolov3_darknet53/src/util.py new file mode 100644 index 0000000000..f97bdd2548 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/src/util.py @@ -0,0 +1,177 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function.""" +from mindspore.train.serialization import load_checkpoint +import mindspore.nn as nn + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f', tb_writer=None): + self.name = name + self.fmt = fmt + self.reset() + self.tb_writer = tb_writer + self.cur_step = 1 + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + if self.tb_writer is not None: + self.tb_writer.add_scalar(self.name, self.val, self.cur_step) + self.cur_step += 1 + + def __str__(self): + fmtstr = '{name}:{avg' + self.fmt + '}' + return fmtstr.format(**self.__dict__) + + +def load_backbone(net, ckpt_path, args): + """Load darknet53 backbone checkpoint.""" + param_dict = load_checkpoint(ckpt_path) + yolo_backbone_prefix = 'feature_map.backbone' + darknet_backbone_prefix = 'network.backbone' + find_param = [] + not_found_param = [] + net.init_parameters_data() + for name, cell in net.cells_and_names(): + if name.startswith(yolo_backbone_prefix): + name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix) + if isinstance(cell, (nn.Conv2d, nn.Dense)): + darknet_weight = '{}.weight'.format(name) + darknet_bias = '{}.bias'.format(name) + if darknet_weight in param_dict: + cell.weight.default_input = param_dict[darknet_weight].data + find_param.append(darknet_weight) + else: + not_found_param.append(darknet_weight) + if darknet_bias in param_dict: + cell.bias.default_input = param_dict[darknet_bias].data + find_param.append(darknet_bias) + else: + not_found_param.append(darknet_bias) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + darknet_moving_mean = '{}.moving_mean'.format(name) + darknet_moving_variance = '{}.moving_variance'.format(name) + darknet_gamma = '{}.gamma'.format(name) + darknet_beta = '{}.beta'.format(name) + if darknet_moving_mean in param_dict: + cell.moving_mean.default_input = param_dict[darknet_moving_mean].data + find_param.append(darknet_moving_mean) + else: + not_found_param.append(darknet_moving_mean) + if darknet_moving_variance in param_dict: + cell.moving_variance.default_input = param_dict[darknet_moving_variance].data + find_param.append(darknet_moving_variance) + else: + not_found_param.append(darknet_moving_variance) + if darknet_gamma in param_dict: + cell.gamma.default_input = param_dict[darknet_gamma].data + find_param.append(darknet_gamma) + else: + not_found_param.append(darknet_gamma) + if darknet_beta in param_dict: + cell.beta.default_input = param_dict[darknet_beta].data + find_param.append(darknet_beta) + else: + not_found_param.append(darknet_beta) + + args.logger.info('================found_param {}========='.format(len(find_param))) + args.logger.info(find_param) + args.logger.info('================not_found_param {}========='.format(len(not_found_param))) + args.logger.info(not_found_param) + args.logger.info('=====load {} successfully ====='.format(ckpt_path)) + + return net + + +def default_wd_filter(x): + """default weight decay filter.""" + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + return False + if parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + return False + if parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + return False + + return True + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] + + +class ShapeRecord: + """Log image shape.""" + def __init__(self): + self.shape_record = { + 320: 0, + 352: 0, + 384: 0, + 416: 0, + 448: 0, + 480: 0, + 512: 0, + 544: 0, + 576: 0, + 608: 0, + 'total': 0 + } + + def set(self, shape): + if len(shape) > 1: + shape = shape[0] + shape = int(shape) + self.shape_record[shape] += 1 + self.shape_record['total'] += 1 + + def show(self, logger): + for key in self.shape_record: + rate = self.shape_record[key] / float(self.shape_record['total']) + logger.info('shape {}: {:.2f}%'.format(key, rate*100)) diff --git a/model_zoo/yolov3_darknet53/src/yolo.py b/model_zoo/official/cv/yolov3_darknet53/src/yolo.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/yolo.py rename to model_zoo/official/cv/yolov3_darknet53/src/yolo.py diff --git a/model_zoo/yolov3_darknet53/src/yolo_dataset.py b/model_zoo/official/cv/yolov3_darknet53/src/yolo_dataset.py similarity index 100% rename from model_zoo/yolov3_darknet53/src/yolo_dataset.py rename to model_zoo/official/cv/yolov3_darknet53/src/yolo_dataset.py diff --git a/model_zoo/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py similarity index 100% rename from model_zoo/yolov3_darknet53/train.py rename to model_zoo/official/cv/yolov3_darknet53/train.py diff --git a/model_zoo/yolov3_resnet18/README.md b/model_zoo/official/cv/yolov3_resnet18/README.md similarity index 100% rename from model_zoo/yolov3_resnet18/README.md rename to model_zoo/official/cv/yolov3_resnet18/README.md diff --git a/model_zoo/yolov3_resnet18/eval.py b/model_zoo/official/cv/yolov3_resnet18/eval.py similarity index 100% rename from model_zoo/yolov3_resnet18/eval.py rename to model_zoo/official/cv/yolov3_resnet18/eval.py diff --git a/model_zoo/official/cv/yolov3_resnet18/scripts/run_distribute_train.sh b/model_zoo/official/cv/yolov3_resnet18/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..457598c9d9 --- /dev/null +++ b/model_zoo/official/cv/yolov3_resnet18/scripts/run_distribute_train.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "=======================================================================================================================================================" +echo "Please run the scipt as: " +echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" +echo "For example: sh run_distribute_train.sh 8 150 /data/Mindrecord_train /data /data/train.txt /data/hccl.json /opt/yolov3-150.ckpt(optional) 100(optional)" +echo "It is better to use absolute path." +echo "The learning rate is 0.005 as default, if you want other lr, please change the value in this script." +echo "=======================================================================================================================================================" + +if [ $# != 6 ] && [ $# != 8 ] +then + echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [MINDRECORD_DIR] [IMAGE_DIR] [ANNO_PATH] [RANK_TABLE_FILE] \ +[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" + exit 1 +fi + +EPOCH_SIZE=$2 +MINDRECORD_DIR=$3 +IMAGE_DIR=$4 +ANNO_PATH=$5 +PRE_TRAINED=$7 +PRE_TRAINED_EPOCH_SIZE=$8 + +# Before start distribute train, first create mindrecord files. +python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \ +--anno_path=$ANNO_PATH + +echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" + +export RANK_TABLE_FILE=$6 +export RANK_SIZE=$1 + +BASE_PATH=$(cd "`dirname $0`" || exit; pwd) +cd $BASE_PATH/../ || exit + +for((i=0;i env.log + + if [ $# == 6 ] + then + taskset -c $cmdopt python train.py \ + --distribute=1 \ + --lr=0.005 \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --mindrecord_dir=$MINDRECORD_DIR \ + --image_dir=$IMAGE_DIR \ + --epoch_size=$EPOCH_SIZE \ + --anno_path=$ANNO_PATH > log.txt 2>&1 & + fi + + if [ $# == 8 ] + then + taskset -c $cmdopt python train.py \ + --distribute=1 \ + --lr=0.005 \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --mindrecord_dir=$MINDRECORD_DIR \ + --image_dir=$IMAGE_DIR \ + --epoch_size=$EPOCH_SIZE \ + --pre_trained=$PRE_TRAINED \ + --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \ + --anno_path=$ANNO_PATH > log.txt 2>&1 & + fi + + cd ../ +done diff --git a/model_zoo/yolov3_resnet18/scripts/run_eval.sh b/model_zoo/official/cv/yolov3_resnet18/scripts/run_eval.sh similarity index 100% rename from model_zoo/yolov3_resnet18/scripts/run_eval.sh rename to model_zoo/official/cv/yolov3_resnet18/scripts/run_eval.sh diff --git a/model_zoo/yolov3_resnet18/scripts/run_standalone_train.sh b/model_zoo/official/cv/yolov3_resnet18/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/yolov3_resnet18/scripts/run_standalone_train.sh rename to model_zoo/official/cv/yolov3_resnet18/scripts/run_standalone_train.sh diff --git a/model_zoo/yolov3_resnet18/src/config.py b/model_zoo/official/cv/yolov3_resnet18/src/config.py similarity index 100% rename from model_zoo/yolov3_resnet18/src/config.py rename to model_zoo/official/cv/yolov3_resnet18/src/config.py diff --git a/model_zoo/official/cv/yolov3_resnet18/src/dataset.py b/model_zoo/official/cv/yolov3_resnet18/src/dataset.py new file mode 100644 index 0000000000..7c5177a3fe --- /dev/null +++ b/model_zoo/official/cv/yolov3_resnet18/src/dataset.py @@ -0,0 +1,316 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""YOLOv3 dataset""" +from __future__ import division + +import os +import numpy as np +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb +from PIL import Image +import mindspore.dataset as de +from mindspore.mindrecord import FileWriter +import mindspore.dataset.transforms.vision.c_transforms as C +from src.config import ConfigYOLOV3ResNet18 + +iter_cnt = 0 +_NUM_BOXES = 50 + +def preprocess_fn(image, box, is_training): + """Preprocess function for dataset.""" + config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326] + anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2) + do_hsv = False + max_boxes = 20 + num_classes = ConfigYOLOV3ResNet18.num_classes + + def _rand(a=0., b=1.): + return np.random.rand() * (b - a) + a + + def _preprocess_true_boxes(true_boxes, anchors, in_shape=None): + """Get true boxes.""" + num_layers = anchors.shape[0] // 3 + anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + true_boxes = np.array(true_boxes, dtype='float32') + # input_shape = np.array([in_shape, in_shape], dtype='int32') + input_shape = np.array(in_shape, dtype='int32') + boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. + boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] + true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] + true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] + + grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] + y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), + 5 + num_classes), dtype='float32') for l in range(num_layers)] + + anchors = np.expand_dims(anchors, 0) + anchors_max = anchors / 2. + anchors_min = -anchors_max + + valid_mask = boxes_wh[..., 0] >= 1 + + wh = boxes_wh[valid_mask] + + + if len(wh) >= 1: + wh = np.expand_dims(wh, -2) + boxes_max = wh / 2. + boxes_min = -boxes_max + + intersect_min = np.maximum(boxes_min, anchors_min) + intersect_max = np.minimum(boxes_max, anchors_max) + intersect_wh = np.maximum(intersect_max - intersect_min, 0.) + intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + box_area = wh[..., 0] * wh[..., 1] + anchor_area = anchors[..., 0] * anchors[..., 1] + iou = intersect_area / (box_area + anchor_area - intersect_area) + + best_anchor = np.argmax(iou, axis=-1) + for t, n in enumerate(best_anchor): + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') + k = anchor_mask[l].index(n) + + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + y_true[l][j, i, k, 5 + c] = 1. + + pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32) + pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32) + pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32) + + mask0 = np.reshape(y_true[0][..., 4:5], [-1]) + gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) + gt_box0 = gt_box0[mask0 == 1] + pad_gt_box0[:gt_box0.shape[0]] = gt_box0 + + mask1 = np.reshape(y_true[1][..., 4:5], [-1]) + gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) + gt_box1 = gt_box1[mask1 == 1] + pad_gt_box1[:gt_box1.shape[0]] = gt_box1 + + mask2 = np.reshape(y_true[2][..., 4:5], [-1]) + gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) + gt_box2 = gt_box2[mask2 == 1] + pad_gt_box2[:gt_box2.shape[0]] = gt_box2 + + return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 + + def _infer_data(img_data, input_shape, box): + w, h = img_data.size + input_h, input_w = input_shape + scale = min(float(input_w) / float(w), float(input_h) / float(h)) + nw = int(w * scale) + nh = int(h * scale) + img_data = img_data.resize((nw, nh), Image.BICUBIC) + + new_image = np.zeros((input_h, input_w, 3), np.float32) + new_image.fill(128) + img_data = np.array(img_data) + if len(img_data.shape) == 2: + img_data = np.expand_dims(img_data, axis=-1) + img_data = np.concatenate([img_data, img_data, img_data], axis=-1) + + dh = int((input_h - nh) / 2) + dw = int((input_w - nw) / 2) + new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data + new_image /= 255. + new_image = np.transpose(new_image, (2, 0, 1)) + new_image = np.expand_dims(new_image, 0) + return new_image, np.array([h, w], np.float32), box + + def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)): + """Data augmentation function.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + iw, ih = image.size + ori_image_shape = np.array([ih, iw], np.int32) + h, w = image_size + + if not is_training: + return _infer_data(image, image_size, box) + + flip = _rand() < .5 + # correct boxes + box_data = np.zeros((max_boxes, 5)) + while True: + # Prevent the situation that all boxes are eliminated + new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \ + _rand(1 - jitter, 1 + jitter) + scale = _rand(0.25, 2) + + if new_ar < 1: + nh = int(scale * h) + nw = int(nh * new_ar) + else: + nw = int(scale * w) + nh = int(nw / new_ar) + + dx = int(_rand(0, w - nw)) + dy = int(_rand(0, h - nh)) + + if len(box) >= 1: + t_box = box.copy() + np.random.shuffle(t_box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy + if flip: + t_box[:, [0, 2]] = w - t_box[:, [2, 0]] + t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 + t_box[:, 2][t_box[:, 2] > w] = w + t_box[:, 3][t_box[:, 3] > h] = h + box_w = t_box[:, 2] - t_box[:, 0] + box_h = t_box[:, 3] - t_box[:, 1] + t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box + + if len(t_box) >= 1: + box = t_box + break + + box_data[:len(box)] = box + # resize image + image = image.resize((nw, nh), Image.BICUBIC) + # place image + new_image = Image.new('RGB', (w, h), (128, 128, 128)) + new_image.paste(image, (dx, dy)) + image = new_image + + # flip image or not + if flip: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + + # convert image to gray or not + gray = _rand() < .25 + if gray: + image = image.convert('L').convert('RGB') + + # when the channels of image is 1 + image = np.array(image) + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + # distort image + hue = _rand(-hue, hue) + sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) + val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) + image_data = image / 255. + if do_hsv: + x = rgb_to_hsv(image_data) + x[..., 0] += hue + x[..., 0][x[..., 0] > 1] -= 1 + x[..., 0][x[..., 0] < 0] += 1 + x[..., 1] *= sat + x[..., 2] *= val + x[x > 1] = 1 + x[x < 0] = 0 + image_data = hsv_to_rgb(x) # numpy array, 0 to 1 + image_data = image_data.astype(np.float32) + + # preprocess bounding boxes + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(box_data, anchors, image_size) + + return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \ + ori_image_shape, gt_box1, gt_box2, gt_box3 + + if is_training: + images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training) + return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3 + + images, shape, anno = _data_aug(image, box, is_training) + return images, shape, anno + + +def anno_parser(annos_str): + """Parse annotation from string to list.""" + annos = [] + for anno_str in annos_str: + anno = list(map(int, anno_str.strip().split(','))) + annos.append(anno) + return annos + + +def filter_valid_data(image_dir, anno_path): + """Filter valid image file, which both in image_dir and anno_path.""" + image_files = [] + image_anno_dict = {} + if not os.path.isdir(image_dir): + raise RuntimeError("Path given is not valid.") + if not os.path.isfile(anno_path): + raise RuntimeError("Annotation file is not valid.") + + with open(anno_path, "rb") as f: + lines = f.readlines() + for line in lines: + line_str = line.decode("utf-8").strip() + line_split = str(line_str).split(' ') + file_name = line_split[0] + if os.path.isfile(os.path.join(image_dir, file_name)): + image_anno_dict[file_name] = anno_parser(line_split[1:]) + image_files.append(file_name) + return image_files, image_anno_dict + + +def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix, file_num): + """Create MindRecord file by image_dir and anno_path.""" + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + image_files, image_anno_dict = filter_valid_data(image_dir, anno_path) + + yolo_json = { + "image": {"type": "bytes"}, + "annotation": {"type": "int64", "shape": [-1, 5]}, + } + writer.add_schema(yolo_json, "yolo_json") + + for image_name in image_files: + image_path = os.path.join(image_dir, image_name) + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[image_name]) + row = {"image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + +def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=1, rank=0, + is_training=True, num_parallel_workers=8): + """Creatr YOLOv3 dataset with MindDataset.""" + ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank, + num_parallel_workers=num_parallel_workers, shuffle=is_training) + decode = C.Decode() + ds = ds.map(input_columns=["image"], operations=decode) + compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) + + if is_training: + hwc_to_chw = C.HWC2CHW() + ds = ds.map(input_columns=["image", "annotation"], + output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], + columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], + operations=compose_map_func, num_parallel_workers=num_parallel_workers) + ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) + else: + ds = ds.map(input_columns=["image", "annotation"], + output_columns=["image", "image_shape", "annotation"], + columns_order=["image", "image_shape", "annotation"], + operations=compose_map_func, num_parallel_workers=num_parallel_workers) + return ds diff --git a/model_zoo/yolov3_resnet18/src/utils.py b/model_zoo/official/cv/yolov3_resnet18/src/utils.py similarity index 100% rename from model_zoo/yolov3_resnet18/src/utils.py rename to model_zoo/official/cv/yolov3_resnet18/src/utils.py diff --git a/model_zoo/yolov3_resnet18/src/yolov3.py b/model_zoo/official/cv/yolov3_resnet18/src/yolov3.py similarity index 100% rename from model_zoo/yolov3_resnet18/src/yolov3.py rename to model_zoo/official/cv/yolov3_resnet18/src/yolov3.py diff --git a/model_zoo/official/cv/yolov3_resnet18/train.py b/model_zoo/official/cv/yolov3_resnet18/train.py new file mode 100644 index 0000000000..761a94c454 --- /dev/null +++ b/model_zoo/official/cv/yolov3_resnet18/train.py @@ -0,0 +1,162 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +######################## train YOLOv3 example ######################## +train YOLOv3 and get network model files(.ckpt) : +python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train + +If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path. +Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path. +""" + +import os +import argparse +import numpy as np +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.communication.management import init +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor +from mindspore.train import Model, ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common.initializer import initializer + +from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper +from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image +from src.config import ConfigYOLOV3ResNet18 + + +def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False): + """Set learning rate.""" + lr_each_step = [] + for i in range(global_step): + if steps: + lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step))) + else: + lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step))) + lr_each_step = np.array(lr_each_step).astype(np.float32) + lr_each_step = lr_each_step[start_step:] + return lr_each_step + + +def init_net_param(network, init_value='ones'): + """Init:wq the parameters in network.""" + params = network.trainable_params() + for p in params: + if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: + p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype)) + + +def main(): + parser = argparse.ArgumentParser(description="YOLOv3 train") + parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " + "Mindrecord, default is false.") + parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--lr", type=float, default=0.001, help="Learning rate, default is 0.001.") + parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink") + parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") + parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path") + parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size") + parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") + parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") + parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", + help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by" + "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir " + "rather than image_dir and anno_path. Default is ./Mindrecord_train") + parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, " + "the absolute image path is joined by the image_dir " + "and the relative path in anno_path") + parser.add_argument("--anno_path", type=str, default="", help="Annotation path.") + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + if args_opt.distribute: + device_num = args_opt.device_num + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=device_num) + init() + rank = args_opt.device_id % device_num + else: + rank = 0 + device_num = 1 + + print("Start create dataset!") + + # It will generate mindrecord file in args_opt.mindrecord_dir, + # and the file name is yolo.mindrecord0, 1, ... file_num. + if not os.path.isdir(args_opt.mindrecord_dir): + os.makedirs(args_opt.mindrecord_dir) + + prefix = "yolo.mindrecord" + mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0") + if not os.path.exists(mindrecord_file): + if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path): + print("Create Mindrecord.") + data_to_mindrecord_byte_image(args_opt.image_dir, + args_opt.anno_path, + args_opt.mindrecord_dir, + prefix, + 8) + print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir)) + else: + print("image_dir or anno_path not exits.") + + if not args_opt.only_create_dataset: + loss_scale = float(args_opt.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. + dataset = create_yolo_dataset(mindrecord_file, + batch_size=args_opt.batch_size, device_num=device_num, rank=rank) + dataset_size = dataset.get_dataset_size() + print("Create dataset done!") + + net = yolov3_resnet18(ConfigYOLOV3ResNet18()) + net = YoloWithLossCell(net, ConfigYOLOV3ResNet18()) + init_net_param(net, "XavierUniform") + + # checkpoint + ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) + ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config) + + if args_opt.pre_trained: + if args_opt.pre_trained_epoch_size <= 0: + raise KeyError("pre_trained_epoch_size must be greater than 0.") + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + total_epoch_size = 60 + if args_opt.distribute: + total_epoch_size = 160 + lr = Tensor(get_lr(learning_rate=args_opt.lr, start_step=args_opt.pre_trained_epoch_size * dataset_size, + global_step=total_epoch_size * dataset_size, + decay_step=1000, decay_rate=0.95, steps=True)) + opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) + net = TrainingWrapper(net, opt, loss_scale) + + callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] + + model = Model(net) + dataset_sink_mode = False + if args_opt.mode == "sink": + print("In sink mode, one epoch return a loss.") + dataset_sink_mode = True + print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.") + model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) + +if __name__ == '__main__': + main() diff --git a/model_zoo/gat/README.md b/model_zoo/official/gnn/gat/README.md similarity index 100% rename from model_zoo/gat/README.md rename to model_zoo/official/gnn/gat/README.md diff --git a/model_zoo/gat/scripts/run_process_data.sh b/model_zoo/official/gnn/gat/scripts/run_process_data.sh similarity index 100% rename from model_zoo/gat/scripts/run_process_data.sh rename to model_zoo/official/gnn/gat/scripts/run_process_data.sh diff --git a/model_zoo/gat/scripts/run_train.sh b/model_zoo/official/gnn/gat/scripts/run_train.sh similarity index 100% rename from model_zoo/gat/scripts/run_train.sh rename to model_zoo/official/gnn/gat/scripts/run_train.sh diff --git a/model_zoo/official/gnn/gat/src/__init__.py b/model_zoo/official/gnn/gat/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/gat/src/config.py b/model_zoo/official/gnn/gat/src/config.py similarity index 100% rename from model_zoo/gat/src/config.py rename to model_zoo/official/gnn/gat/src/config.py diff --git a/model_zoo/gat/src/dataset.py b/model_zoo/official/gnn/gat/src/dataset.py similarity index 100% rename from model_zoo/gat/src/dataset.py rename to model_zoo/official/gnn/gat/src/dataset.py diff --git a/model_zoo/official/gnn/gat/src/gat.py b/model_zoo/official/gnn/gat/src/gat.py new file mode 100644 index 0000000000..ff0c964e9b --- /dev/null +++ b/model_zoo/official/gnn/gat/src/gat.py @@ -0,0 +1,496 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Aggregator.""" +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore._extends import cell_attr_register +from mindspore import Tensor, Parameter +from mindspore.common.initializer import initializer +from mindspore._checkparam import check_int_positive, check_bool +from mindspore.nn.layer.activation import get_activation + + +class GNNFeatureTransform(nn.Cell): + r""" + The GNN featuren transform layer for input. + + Applies linear transformation for the input feature. This layer implements the operation as: + + .. math:: + \text{outputs} = \text{inputs} * \text{kernel} + \text{bias}, + + where :math:`\text{activation}` is the activation function passed as the activation + argument (if passed in),:math:`\text{activation}` is a weight matrix with the same + data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector + with the same data type as the inputs created by the layer (only if has_bias is True). + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + + Raises: + ValueError: If weight_init or bias_init shape is incorrect. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`, + where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the + size of the last two dimensions. If `transpose_a` is True, its shape should be :math:`(*B, C, N)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(*B, N, M)`. + + Examples: + >>> net = nn.Dense(3, 4) + >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) + >>> net(input) + [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] + [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] + """ + @cell_attr_register + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + has_bias=True): + super(GNNFeatureTransform, self).__init__() + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + self.has_bias = check_bool(has_bias) + + if isinstance(weight_init, Tensor): + if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ + weight_init.shape[1] != in_channels: + raise ValueError("weight_init shape error") + + self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") + + if self.has_bias: + if isinstance(bias_init, Tensor): + if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: + raise ValueError("bias_init shape error") + + self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") + + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + + def construct(self, x): + tensor_shape = F.shape(x) + input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2])) + output = self.matmul(input_feature, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels)) + return output + + def extend_repr(self): + str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \ + .format(self.in_channels, self.out_channels, self.weight, self.has_bias) + if self.has_bias: + str_info = str_info + ', bias={}'.format(self.bias) + + return str_info + + +class _BaseAggregator(nn.Cell): + """ + Base Aggregator of GNN + + Args: + feature_in_dim (int): Node or edge input feature dim. + feature_out_dim (int): Node or edge outpout feature dim. + use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None. + activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + + Examples: + >>> class MyAggregator(_BaseAggregator): + >>> def __init__(self): + >>> super(MyAggregator, self).__init__(self, feature_in_dim, feature_out_dim) + >>> self.reduce_mean = P.ReduceSum() + >>> + >>> def construct(self, x): + >>> return self.reduce_mean(x, 1) + """ + def __init__(self, + feature_in_dim, + feature_out_dim, + use_fc=True, + weight_init="normal", + bias_init="zeros", + has_bias=True, + dropout_ratio=None, + activation=None): + super(_BaseAggregator, self).__init__() + self.in_dim = feature_in_dim + self.out_dim = feature_out_dim + self.use_fc = use_fc + if self.use_fc: + self.weight_init = weight_init + self.bias_init = bias_init + self.has_bias = has_bias + self.fc = GNNFeatureTransform(self.in_dim, + self.out_dim, + weight_init=self.weight_init, + bias_init=self.bias_init, + has_bias=self.has_bias) + self.dropout_ratio = dropout_ratio + if self.dropout_ratio is not None: + self.dropout = nn.Dropout(keep_prob=self.dropout_ratio) + self.dropout_flag = self.dropout_ratio is not None + self.activation = get_activation(activation) + self.activation_flag = self.activation is not None + + def construct(self, **kward): + """Must be overridden by all subclasses.""" + raise NotImplementedError + + +class MeanAggregator(_BaseAggregator): + """ + Mean Aggregator of GNN + + Args: + feature_in_dim (int): Node or edge input feature dim. + feature_out_dim (int): Node or edge outpout feature dim. + use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None. + activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + + Examples: + >>> net = MeanAggregator(32, 64, activation="relu", dropout=0.5) + >>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtypy=np.float32)) + >>> output = net(input_data) + """ + def __init__(self, + feature_in_dim, + feature_out_dim, + use_fc=True, + weight_init="normal", + bias_init="zeros", + has_bias=True, + dropout_ratio=None, + activation=None): + super(MeanAggregator, self).__init__( + feature_in_dim, + feature_out_dim, + use_fc, + weight_init, + bias_init, + has_bias, + dropout_ratio, + activation) + self.reduce_mean = P.ReduceMean(keep_dims=False) + + def construct(self, input_feature): + if self.use_fc: + input_feature = self.fc(input_feature) + if self.dropout_flag: + input_feature = self.dropout(input_feature) + if self.activation_flag: + input_feature = self.activation(input_feature) + output_feature = self.reduce_mean(input_feature, 1) + return output_feature + + +class AttentionHead(nn.Cell): + """ + Attention Head for Graph Attention Networks. + + Args: + in_channel (int): The number of input channel, input feature dim. + out_channel (int): The number of output channel, output feature dim. + in_drop_ratio (float): Input feature dropout ratio, default 0.0. + coef_drop_ratio (float): Coefficient dropout ratio, default 0.0. + residual (bool): Whether to use residual connection, default False. + coef_activation (Cell): The attention coefficient activation function, + default nn.LeakyReLU(). + activation (Cell): The output activation function, default nn.ELU(). + + Inputs: + - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim). + - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes). + + Examples: + >>> head = AttentionHead(1433, + 8, + in_drop_ratio=0.6, + coef_drop_ratio=0.6, + residual=False) + >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtypy=np.float32)) + >>> output = net(input_data) + """ + + def __init__(self, + in_channel, + out_channel, + in_drop_ratio=0.0, + coef_drop_ratio=0.0, + residual=False, + coef_activation=nn.LeakyReLU(), + activation=nn.ELU()): + super(AttentionHead, self).__init__() + self.in_channel = check_int_positive(in_channel) + self.out_channel = check_int_positive(out_channel) + self.in_drop_ratio = in_drop_ratio + self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) + self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) + self.feature_transform = GNNFeatureTransform( + in_channels=self.in_channel, + out_channels=self.out_channel, + has_bias=False, + weight_init='XavierUniform') + + self.f_1_transform = GNNFeatureTransform( + in_channels=self.out_channel, + out_channels=1, + weight_init='XavierUniform') + self.f_2_transform = GNNFeatureTransform( + in_channels=self.out_channel, + out_channels=1, + weight_init='XavierUniform') + self.softmax = nn.Softmax() + + self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio) + self.matmul = P.MatMul() + self.bias_add = P.BiasAdd() + self.bias = Parameter(initializer('zeros', self.out_channel), name='bias') + self.residual = check_bool(residual) + if self.residual: + if in_channel != out_channel: + self.residual_transform_flag = True + self.residual_transform = GNNFeatureTransform( + in_channels=self.in_channel, + out_channels=self.out_channel) + else: + self.residual_transform = None + self.coef_activation = coef_activation + self.activation = activation + + def construct(self, input_feature, bias_mat, training=True): + if training is True: + input_feature = self.in_drop(input_feature) + + feature = self.feature_transform(input_feature) + # self attention + f_1 = self.f_1_transform(feature) + f_2 = self.f_2_transform(feature) + logits = f_1 + P.Transpose()(f_2, (0, 2, 1)) + logits = self.coef_activation(logits) + bias_mat + coefs = self.softmax(logits) + if training is True: + coefs = self.coef_drop(coefs) + feature = self.in_drop_2(feature) + + coefs = P.Squeeze(0)(coefs) + feature = P.Squeeze(0)(feature) + + ret = self.matmul(coefs, feature) + ret = self.bias_add(ret, self.bias) + ret = P.ExpandDims()(ret, 0) + # residual connection + if self.residual: + if self.residual_transform_flag: + res = self.residual_transform(input_feature) + ret = ret + res + else: + ret = ret + input_feature + # activation + if self.activation is not None: + ret = self.activation(ret) + return ret + + +class AttentionAggregator(nn.Cell): + """ + Attention Head for Graph Attention Networks,can be regarded as one + GAT layer. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + num_heads (int): Number of attention heads for this layer, default 1. + in_drop_ratio (float): Input feature dropout ratio, default 0.0. + coef_drop_ratio (float): Coefficient dropout ratio, default 0.0. + activation (Cell): The output activation function, default nn.ELU(). + residual (bool): Whether to use residual connection, default False. + output_transform (str['concat', 'sum']): output transform for a layer, + default 'concat' + + Inputs: + - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim). + - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes). + + Examples: + >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32)) + >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32)) + >>> net = AttentionAggregator(1433, + 8, + 8) + >>> net(input_data, biases) + """ + def __init__(self, + in_channels, + out_channels, + num_heads=1, + in_drop=0.0, + coef_drop=0.0, + activation=nn.ELU(), + residual=False, + output_transform='concat'): + super(AttentionAggregator, self).__init__() + self.num_heads = num_heads + self.attns = [] + for _ in range(num_heads): + self.attns.append(AttentionHead(in_channels, + out_channels, + in_drop_ratio=in_drop, + coef_drop_ratio=coef_drop, + activation=activation, + residual=residual)) + self.attns = nn.layer.CellList(self.attns) + if output_transform == 'concat': + self.out_trans = P.Concat(-1) + elif output_transform == 'sum': + self.out_trans = P.AddN() + else: + raise ValueError("output_transform must be either 'concat' or 'sum'") + + def construct(self, input_data, bias_mat, training=True): + res = () + for i in range(self.num_heads): + res += (self.attns[i](input_data, bias_mat, training),) + return self.out_trans(res) + + +class GAT(nn.Cell): + """ + Graph Attention Network + + Args: + ftr_dims (int): Initial feature dimensions. + num_class (int): Num of class to identify. + num_nodes (int): Num of nodes in this graph. + hidden_units (list[int]): Num of hidden units at each layer. + num_heads (list[int]): Num of heads at each layer. + attn_drop (float): Drop out ratio of attention coefficient, + default 0.0. + ftr_drop (float): Drop out ratio of feature, default 0.0. + activation (Cell): Activation Function for output layer, default + nn.Elu(). + residual (bool): Whether to use residual connection between + intermediate layers, default False. + + Examples: + >>> ft_sizes = 1433 + >>> num_class = 7 + >>> num_nodes = 2708 + >>> hid_units = [8] + >>> n_heads = [8, 1] + >>> activation = nn.ELU() + >>> residual = False + >>> input_data = np.array(np.random.rand(1, 2708, 1433)) + >>> biases = np.array(np.random.rand(1, 2708, 2708)) + >>> net = GAT(ft_sizes, + num_class, + num_nodes, + hidden_units=hid_units, + num_heads=n_heads, + attn_drop=0.6, + ftr_drop=0.6, + activation=activation, + residual=residual) + >>> output = net(input_data, biases) + """ + + def __init__(self, + features, + biases, + ftr_dims, + num_class, + num_nodes, + hidden_units, + num_heads, + attn_drop=0.0, + ftr_drop=0.0, + activation=nn.ELU(), + residual=False): + super(GAT, self).__init__() + self.features = Tensor(features) + self.biases = Tensor(biases) + self.ftr_dims = check_int_positive(ftr_dims) + self.num_class = check_int_positive(num_class) + self.num_nodes = check_int_positive(num_nodes) + self.hidden_units = hidden_units + self.num_heads = num_heads + self.attn_drop = attn_drop + self.ftr_drop = ftr_drop + self.activation = activation + self.residual = check_bool(residual) + self.layers = [] + # first layer + self.layers.append(AttentionAggregator( + self.ftr_dims, + self.hidden_units[0], + self.num_heads[0], + self.ftr_drop, + self.attn_drop, + self.activation, + residual=False)) + # intermediate layer + for i in range(1, len(self.hidden_units)): + self.layers.append(AttentionAggregator( + self.hidden_units[i-1]*self.num_heads[i-1], + self.hidden_units[i], + self.num_heads[i], + self.ftr_drop, + self.attn_drop, + self.activation, + residual=self.residual)) + # output layer + self.layers.append(AttentionAggregator( + self.hidden_units[-1]*self.num_heads[-2], + self.num_class, + self.num_heads[-1], + self.ftr_drop, + self.attn_drop, + activation=None, + residual=False, + output_transform='sum')) + self.layers = nn.layer.CellList(self.layers) + + def construct(self, training=True): + input_data = self.features + bias_mat = self.biases + for cell in self.layers: + input_data = cell(input_data, bias_mat, training) + return input_data/self.num_heads[-1] diff --git a/model_zoo/gat/src/utils.py b/model_zoo/official/gnn/gat/src/utils.py similarity index 100% rename from model_zoo/gat/src/utils.py rename to model_zoo/official/gnn/gat/src/utils.py diff --git a/model_zoo/gat/train.py b/model_zoo/official/gnn/gat/train.py similarity index 100% rename from model_zoo/gat/train.py rename to model_zoo/official/gnn/gat/train.py diff --git a/model_zoo/gcn/README.md b/model_zoo/official/gnn/gcn/README.md similarity index 100% rename from model_zoo/gcn/README.md rename to model_zoo/official/gnn/gcn/README.md diff --git a/model_zoo/gcn/scripts/run_process_data.sh b/model_zoo/official/gnn/gcn/scripts/run_process_data.sh similarity index 100% rename from model_zoo/gcn/scripts/run_process_data.sh rename to model_zoo/official/gnn/gcn/scripts/run_process_data.sh diff --git a/model_zoo/gcn/scripts/run_train.sh b/model_zoo/official/gnn/gcn/scripts/run_train.sh similarity index 100% rename from model_zoo/gcn/scripts/run_train.sh rename to model_zoo/official/gnn/gcn/scripts/run_train.sh diff --git a/model_zoo/gcn/src/config.py b/model_zoo/official/gnn/gcn/src/config.py similarity index 100% rename from model_zoo/gcn/src/config.py rename to model_zoo/official/gnn/gcn/src/config.py diff --git a/model_zoo/gcn/src/dataset.py b/model_zoo/official/gnn/gcn/src/dataset.py similarity index 100% rename from model_zoo/gcn/src/dataset.py rename to model_zoo/official/gnn/gcn/src/dataset.py diff --git a/model_zoo/gcn/src/gcn.py b/model_zoo/official/gnn/gcn/src/gcn.py similarity index 100% rename from model_zoo/gcn/src/gcn.py rename to model_zoo/official/gnn/gcn/src/gcn.py diff --git a/model_zoo/gcn/src/metrics.py b/model_zoo/official/gnn/gcn/src/metrics.py similarity index 100% rename from model_zoo/gcn/src/metrics.py rename to model_zoo/official/gnn/gcn/src/metrics.py diff --git a/model_zoo/gcn/t-SNE_visualization_on_Cora.gif b/model_zoo/official/gnn/gcn/t-SNE_visualization_on_Cora.gif similarity index 100% rename from model_zoo/gcn/t-SNE_visualization_on_Cora.gif rename to model_zoo/official/gnn/gcn/t-SNE_visualization_on_Cora.gif diff --git a/model_zoo/gcn/train.py b/model_zoo/official/gnn/gcn/train.py similarity index 100% rename from model_zoo/gcn/train.py rename to model_zoo/official/gnn/gcn/train.py diff --git a/model_zoo/official/lite/.gitkeep b/model_zoo/official/lite/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/nlp/.gitkeep b/model_zoo/official/nlp/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md new file mode 100644 index 0000000000..95aed5e96a --- /dev/null +++ b/model_zoo/official/nlp/bert/README.md @@ -0,0 +1,214 @@ +# BERT Example +## Description +This example implements pre-training, fine-tuning and evaluation of [BERT-base](https://github.com/google-research/bert) and [BERT-NEZHA](https://github.com/huawei-noah/Pretrained-Language-Model). + +## Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path. +- Download dataset for fine-tuning and evaluation such as CLUENER, TNEWS, SQuAD v1.1, etc. +- Convert dataset files from json format to tfrecord format, please refer to run_classifier.py which in [BERT](https://github.com/google-research/bert) repository. +> Notes: + If you are running a fine-tuning or evaluation task, prepare a checkpoint from pre-train. + +## Running the Example +### Pre-Training +- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file. + +- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model. + + ``` bash + sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR + ``` +- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model. + + ``` bash + sh scripts/run_distribute_pretrain.sh DATA_DIR RANK_TABLE_FILE + ``` + +### Fine-Tuning and Evaluation +- Including three kinds of task: Classification, NER(Named Entity Recognition) and SQuAD(Stanford Question Answering Dataset) + +- Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`. + +- Classification task: Set task related hyperparameters in scripts/run_classifier.sh. +- Run `bash scripts/run_classifier.py` for fine-tuning of BERT-base and BERT-NEZHA model. + + ```bash + bash scripts/run_classifier.sh + ``` + +- NER task: Set task related hyperparameters in scripts/run_ner.sh. +- Run `bash scripts/run_ner.py` for fine-tuning of BERT-base and BERT-NEZHA model. + + ```bash + bash scripts/run_ner.sh + ``` + +- SQuAD task: Set task related hyperparameters in scripts/run_squad.sh. +- Run `bash scripts/run_squad.py` for fine-tuning of BERT-base and BERT-NEZHA model. + + ```bash + bash scripts/run_squad.sh + ``` + +## Usage +### Pre-Training +``` +usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] + [--enable_save_ckpt ENABLE_SAVE_CKPT] + [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] + [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH] + [--save_checkpoint_steps N] [--save_checkpoint_num N] + [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] + +options: + --distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false" + --epoch_size epoch size: N, default is 1 + --device_num number of used devices: N, default is 1 + --device_id device id: N, default is 0 + --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" + --enable_lossscale enable lossscale: "true" | "false", default is "true" + --do_shuffle enable shuffle: "true" | "false", default is "true" + --enable_data_sink enable data sink: "true" | "false", default is "true" + --data_sink_steps set data sink steps: N, default is 1 + --checkpoint_path path to save checkpoint files: PATH, default is "" + --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000 + --save_checkpoint_num number for saving checkpoint files: N, default is 1 + --data_dir path to dataset directory: PATH, default is "" + --schema_dir path to schema.json file, PATH, default is "" +``` +### Fine-Tuning and Evaluation +``` +usage: run_ner.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL] + [--assessment_method ASSESSMENT_METHOD] [--use_crf USE_CRF] + [--device_id N] [--epoch_num N] [--vocab_file_path VOCAB_FILE_PATH] + [--label2id_file_path LABEL2ID_FILE_PATH] + [--save_finetune_checkpoint_path SAVE_FINETUNE_CHECKPOINT_PATH] + [--load_pretrain_checkpoint_path LOAD_PRETRAIN_CHECKPOINT_PATH] + [--train_data_file_path TRAIN_DATA_FILE_PATH] + [--eval_data_file_path EVAL_DATA_FILE_PATH] + [--schema_file_path SCHEMA_FILE_PATH] +options: + --device_target targeted device to run task: Ascend | GPU + --do_train whether to run training on training set: true | false + --do_eval whether to run eval on dev set: true | false + --assessment_method assessment method to do evaluation: f1 | clue_benchmark + --use_crf whether to use crf to calculate loss: true | false + --device_id device id to run task + --epoch_num total number of training epochs to perform + --num_class number of classes to do labeling + --vocab_file_path the vocabulary file that the BERT model was trained on + --label2id_file_path label to id json file + --save_finetune_checkpoint_path path to save generated finetuning checkpoint + --load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) + --load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval + --train_data_file_path ner tfrecord for training. E.g., train.tfrecord + --eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result + --schema_file_path path to datafile schema file + +usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL] + [--device_id N] [--epoch_num N] [--num_class N] + [--vocab_file_path VOCAB_FILE_PATH] + [--eval_json_path EVAL_JSON_PATH] + [--save_finetune_checkpoint_path SAVE_FINETUNE_CHECKPOINT_PATH] + [--load_pretrain_checkpoint_path LOAD_PRETRAIN_CHECKPOINT_PATH] + [--load_finetune_checkpoint_path LOAD_FINETUNE_CHECKPOINT_PATH] + [--train_data_file_path TRAIN_DATA_FILE_PATH] + [--eval_data_file_path EVAL_DATA_FILE_PATH] + [--schema_file_path SCHEMA_FILE_PATH] +options: + --device_target targeted device to run task: Ascend | GPU + --do_train whether to run training on training set: true | false + --do_eval whether to run eval on dev set: true | false + --device_id device id to run task + --epoch_num total number of training epochs to perform + --num_class number of classes to classify, usually 2 for squad task + --vocab_file_path the vocabulary file that the BERT model was trained on + --eval_json_path path to squad dev json file + --save_finetune_checkpoint_path path to save generated finetuning checkpoint + --load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) + --load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval + --train_data_file_path squad tfrecord for training. E.g., train1.1.tfrecord + --eval_data_file_path squad tfrecord for predictions. E.g., dev1.1.tfrecord + --schema_file_path path to datafile schema file + +usage: run_classifier.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL] + [--assessment_method ASSESSMENT_METHOD] [--device_id N] [--epoch_num N] [--num_class N] + [--save_finetune_checkpoint_path SAVE_FINETUNE_CHECKPOINT_PATH] + [--load_pretrain_checkpoint_path LOAD_PRETRAIN_CHECKPOINT_PATH] + [--load_finetune_checkpoint_path LOAD_FINETUNE_CHECKPOINT_PATH] + [--train_data_file_path TRAIN_DATA_FILE_PATH] + [--eval_data_file_path EVAL_DATA_FILE_PATH] + [--schema_file_path SCHEMA_FILE_PATH] +options: + --device_target targeted device to run task: Ascend | GPU + --do_train whether to run training on training set: true | false + --do_eval whether to run eval on dev set: true | false + --assessment_method assessment method to do evaluation: accuracy | f1 | mcc | spearman_correlation + --device_id device id to run task + --epoch_num total number of training epochs to perform + --num_class number of classes to do labeling + --save_finetune_checkpoint_path path to save generated finetuning checkpoint + --load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) + --load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval + --train_data_file_path tfrecord for training. E.g., train.tfrecord + --eval_data_file_path tfrecord for predictions. E.g., dev.tfrecord + --schema_file_path path to datafile schema file +``` +## Options and Parameters +It contains of parameters of BERT model and options for training, which is set in file `config.py` and `finetune_eval_config.py` respectively. +### Options: +``` +config.py: + bert_network version of BERT model: base | nezha, default is base + loss_scale_value initial value of loss scale: N, default is 2^32 + scale_factor factor used to update loss scale: N, default is 2 + scale_window steps for once updatation of loss scale: N, default is 1000 + optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb" +``` + +### Parameters: +``` +Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): + batch_size batch size of input dataset: N, default is 16 + seq_length length of input sequence: N, default is 128 + vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136 + hidden_size size of bert encoder layers: N, default is 768 + num_hidden_layers number of hidden layers: N, default is 12 + num_attention_heads number of attention heads: N, default is 12 + intermediate_size size of intermediate layer: N, default is 3072 + hidden_act activation function used: ACTIVATION, default is "gelu" + hidden_dropout_prob dropout probability for BertOutput: Q, default is 0.1 + attention_probs_dropout_prob dropout probability for BertAttention: Q, default is 0.1 + max_position_embeddings maximum length of sequences: N, default is 512 + type_vocab_size size of token type vocab: N, default is 16 + initializer_range initialization value of TruncatedNormal: Q, default is 0.02 + use_relative_positions use relative positions or not: True | False, default is False + input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True + token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True + dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32 + compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 + +Parameters for optimizer: + AdamWeightDecay: + decay_steps steps of the learning rate decay: N + learning_rate value of learning rate: Q + end_learning_rate value of end learning rate: Q, must be positive + power power: Q + warmup_steps steps of the learning rate warm up: N + weight_decay weight decay: Q + eps term added to the denominator to improve numerical stability: Q + + Lamb: + decay_steps steps of the learning rate decay: N + learning_rate value of learning rate: Q + end_learning_rate value of end learning rate: Q + power power: Q + warmup_steps steps of the learning rate warm up: N + weight_decay weight decay: Q + + Momentum: + learning_rate value of learning rate: Q + momentum momentum for the moving average: Q +``` + diff --git a/model_zoo/bert/pretrain_eval.py b/model_zoo/official/nlp/bert/pretrain_eval.py similarity index 100% rename from model_zoo/bert/pretrain_eval.py rename to model_zoo/official/nlp/bert/pretrain_eval.py diff --git a/model_zoo/official/nlp/bert/run_classifier.py b/model_zoo/official/nlp/bert/run_classifier.py new file mode 100644 index 0000000000..d2278bbc3c --- /dev/null +++ b/model_zoo/official/nlp/bert/run_classifier.py @@ -0,0 +1,209 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation script. +''' + +import os +import argparse +from src.bert_for_finetune import BertFinetuneCell, BertCLS +from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.dataset import create_classification_dataset +from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation +from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import log as logger +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_cur_dir = os.getcwd() + +def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): + """ do train """ + if load_checkpoint_path == "": + raise ValueError("Pretrain model missed, finetune task must load pretrain model!") + steps_per_epoch = dataset.get_dataset_size() + # optimizer + if optimizer_cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.AdamWeightDecay.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + + optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + elif optimizer_cfg.optimizer == 'Lamb': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.Lamb.power) + optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule) + elif optimizer_cfg.optimizer == 'Momentum': + optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, + momentum=optimizer_cfg.Momentum.momentum) + else: + raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") + + # load checkpoint into network + ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(network, param_dict) + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) + netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] + model.train(epoch_num, dataset, callbacks=callbacks) + +def eval_result_print(assessment_method="accuracy", callback=None): + """ print eval result """ + if assessment_method == "accuracy": + print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + elif assessment_method == "f1": + print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) + print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) + print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) + elif assessment_method == "mcc": + print("MCC {:.6f} ".format(callback.cal())) + elif assessment_method == "spearman_correlation": + print("Spearman Correlation is {:.6f} ".format(callback.cal()[0])) + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + +def do_eval(dataset=None, network=None, num_class=2, assessment_method="accuracy", load_checkpoint_path=""): + """ do eval """ + if load_checkpoint_path == "": + raise ValueError("Finetune model missed, evaluation task must load finetune model!") + net_for_pretraining = network(bert_net_cfg, False, num_class) + net_for_pretraining.set_train(False) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(net_for_pretraining, param_dict) + model = Model(net_for_pretraining) + + if assessment_method == "accuracy": + callback = Accuracy() + elif assessment_method == "f1": + callback = F1(False, num_class) + elif assessment_method == "mcc": + callback = MCC() + elif assessment_method == "spearman_correlation": + callback = Spearman_Correlation() + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + logits = model.predict(input_ids, input_mask, token_type_id, label_ids) + callback.update(logits, label_ids) + print("==============================================================") + eval_result_print(assessment_method, callback) + print("==============================================================") + +def run_classifier(): + """run classifier task""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") + parser.add_argument("--assessment_method", type=str, default="accuracy", + help="assessment_method including [MCC, Spearman_correlation, Accuracy], default is accuracy") + parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false") + parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") + parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--train_data_shuffle", type=str, default="true", + help="Enable train data shuffle, default is true") + parser.add_argument("--eval_data_shuffle", type=str, default="false", + help="Enable eval data shuffle, default is false") + parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--schema_file_path", type=str, default="", + help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + epoch_num = args_opt.epoch_num + assessment_method = args_opt.assessment_method.lower() + load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path + save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path + load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path + + if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": + raise ValueError("At least one of 'do_train' or 'do_eval' must be true") + if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": + raise ValueError("'train_data_file_path' must be set when do finetune task") + if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do evaluation task") + + target = args_opt.device_target + if target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") + + netwithloss = BertCLS(bert_net_cfg, True, num_labels=args_opt.num_class, dropout_prob=0.1, + assessment_method=assessment_method) + + if args_opt.do_train.lower() == "true": + ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + assessment_method=assessment_method, + data_file_path=args_opt.train_data_file_path, + schema_file_path=args_opt.schema_file_path, + do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) + do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) + + if args_opt.do_eval.lower() == "true": + if save_finetune_checkpoint_path == "": + load_finetune_checkpoint_dir = _cur_dir + else: + load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) + load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, + ds.get_dataset_size(), epoch_num, "classifier") + + if args_opt.do_eval.lower() == "true": + ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + assessment_method=assessment_method, + data_file_path=args_opt.eval_data_file_path, + schema_file_path=args_opt.schema_file_path, + do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) + do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path) + +if __name__ == "__main__": + run_classifier() diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py new file mode 100644 index 0000000000..b311950315 --- /dev/null +++ b/model_zoo/official/nlp/bert/run_ner.py @@ -0,0 +1,236 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation script. +''' + +import os +import json +import argparse +from src.bert_for_finetune import BertFinetuneCell, BertNER +from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.dataset import create_ner_dataset +from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate +from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import log as logger +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_cur_dir = os.getcwd() + + +def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): + """ do train """ + if load_checkpoint_path == "": + raise ValueError("Pretrain model missed, finetune task must load pretrain model!") + steps_per_epoch = dataset.get_dataset_size() + # optimizer + if optimizer_cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.AdamWeightDecay.power) + params = network.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + elif optimizer_cfg.optimizer == 'Lamb': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.Lamb.power) + optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule) + elif optimizer_cfg.optimizer == 'Momentum': + optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, + momentum=optimizer_cfg.Momentum.momentum) + else: + raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") + + # load checkpoint into network + ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="ner", directory=save_checkpoint_path, config=ckpt_config) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(network, param_dict) + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) + netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] + model.train(epoch_num, dataset, callbacks=callbacks) + +def eval_result_print(assessment_method="accuracy", callback=None): + """print eval result""" + if assessment_method == "accuracy": + print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + elif assessment_method == "f1": + print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) + print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) + print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) + elif assessment_method == "mcc": + print("MCC {:.6f} ".format(callback.cal())) + elif assessment_method == "spearman_correlation": + print("Spearman Correlation is {:.6f} ".format(callback.cal()[0])) + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + +def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="", + load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None): + """ do eval """ + if load_checkpoint_path == "": + raise ValueError("Finetune model missed, evaluation task must load finetune model!") + if assessment_method == "clue_benchmark": + bert_net_cfg.batch_size = 1 + net_for_pretraining = network(bert_net_cfg, False, num_class, use_crf=(use_crf.lower() == "true"), + tag_to_index=tag_to_index) + net_for_pretraining.set_train(False) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(net_for_pretraining, param_dict) + model = Model(net_for_pretraining) + + if assessment_method == "clue_benchmark": + from src.cluener_evaluation import submit + submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file) + else: + if assessment_method == "accuracy": + callback = Accuracy() + elif assessment_method == "f1": + callback = F1((use_crf.lower() == "true"), num_class) + elif assessment_method == "mcc": + callback = MCC() + elif assessment_method == "spearman_correlation": + callback = Spearman_Correlation() + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + logits = model.predict(input_ids, input_mask, token_type_id, label_ids) + callback.update(logits, label_ids) + print("==============================================================") + eval_result_print(assessment_method, callback) + print("==============================================================") + +def run_ner(): + """run ner task""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") + parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: " + "[F1, clue_benchmark], default is F1") + parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") + parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") + parser.add_argument("--use_crf", type=str, default="false", help="Use crf, default is false") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") + parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--train_data_shuffle", type=str, default="true", + help="Enable train data shuffle, default is true") + parser.add_argument("--eval_data_shuffle", type=str, default="false", + help="Enable eval data shuffle, default is false") + parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark") + parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark") + parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--schema_file_path", type=str, default="", + help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + epoch_num = args_opt.epoch_num + assessment_method = args_opt.assessment_method.lower() + load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path + save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path + load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path + + if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": + raise ValueError("At least one of 'do_train' or 'do_eval' must be true") + if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": + raise ValueError("'train_data_file_path' must be set when do finetune task") + if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do evaluation task") + if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "": + raise ValueError("'vocab_file_path' must be set to do clue benchmark") + if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "": + raise ValueError("'label2id_file_path' must be set to use crf") + if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "": + raise ValueError("'label2id_file_path' must be set to do clue benchmark") + + target = args_opt.device_target + if target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") + + tag_to_index = None + if args_opt.use_crf.lower() == "true": + with open(args_opt.label2id_file_path) as json_file: + tag_to_index = json.load(json_file) + max_val = max(tag_to_index.values()) + tag_to_index[""] = max_val + 1 + tag_to_index[""] = max_val + 2 + number_labels = len(tag_to_index) + else: + number_labels = args_opt.num_class + netwithloss = BertNER(bert_net_cfg, True, num_labels=number_labels, + use_crf=(args_opt.use_crf.lower() == "true"), + tag_to_index=tag_to_index, dropout_prob=0.1) + if args_opt.do_train.lower() == "true": + ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, + schema_file_path=args_opt.schema_file_path, + do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) + do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) + + if args_opt.do_eval.lower() == "true": + if save_finetune_checkpoint_path == "": + load_finetune_checkpoint_dir = _cur_dir + else: + load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) + load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, + ds.get_dataset_size(), epoch_num, "ner") + + if args_opt.do_eval.lower() == "true": + ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, + schema_file_path=args_opt.schema_file_path, + do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) + do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path, + load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index) + +if __name__ == "__main__": + run_ner() diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py new file mode 100644 index 0000000000..6b4cb1548a --- /dev/null +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################pre_train bert example on zh-wiki######################## +python run_pretrain.py +""" + +import os +import argparse +import numpy +import mindspore.communication.management as D +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.parallel_utils import ParallelMode +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay +from mindspore import log as logger +from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from src.dataset import create_bert_dataset +from src.config import cfg, bert_net_cfg +from src.utils import LossCallBack, BertLearningRate +_current_dir = os.path.dirname(os.path.realpath(__file__)) + + +def run_pretrain(): + """pre-train bert_clue""" + parser = argparse.ArgumentParser(description='bert pre_training') + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") + parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") + parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.") + parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") + parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") + parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " + "default is 1000.") + parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, " + "meaning run all steps according to epoch number.") + parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") + parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") + + args_opt = parser.parse_args() + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + context.set_context(reserve_class_name_in_scope=False) + ckpt_save_dir = args_opt.save_checkpoint_path + if args_opt.distribute == "true": + if args_opt.device_target == 'Ascend': + D.init('hccl') + device_num = args_opt.device_num + rank = args_opt.device_id % device_num + else: + D.init('nccl') + device_num = D.get_group_size() + rank = D.get_rank() + ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/' + + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=device_num) + from mindspore.parallel._auto_parallel_context import auto_parallel_context + if bert_net_cfg.num_hidden_layers == 12: + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) + elif bert_net_cfg.num_hidden_layers == 24: + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) + else: + rank = 0 + device_num = 1 + + if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32: + logger.warning('Gpu only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + + ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) + net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) + + new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps + if args_opt.train_steps > 0: + new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) + else: + args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() + logger.info("train steps: {}".format(args_opt.train_steps)) + + if cfg.optimizer == 'Lamb': + lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate, + end_learning_rate=cfg.Lamb.end_learning_rate, + warmup_steps=cfg.Lamb.warmup_steps, + decay_steps=args_opt.train_steps, + power=cfg.Lamb.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(cfg.Lamb.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay}, + {'params': other_params}, + {'order_params': params}] + optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps) + elif cfg.optimizer == 'Momentum': + optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, + momentum=cfg.Momentum.momentum) + elif cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, + end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=cfg.AdamWeightDecay.warmup_steps, + decay_steps=args_opt.train_steps, + power=cfg.AdamWeightDecay.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) + else: + raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]". + format(cfg.optimizer)) + callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()] + if args_opt.enable_save_ckpt == "true": + config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, + keep_checkpoint_max=args_opt.save_checkpoint_num) + ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck) + callback.append(ckpoint_cb) + + if args_opt.load_checkpoint_path: + param_dict = load_checkpoint(args_opt.load_checkpoint_path) + load_param_into_net(net_with_loss, param_dict) + + if args_opt.enable_lossscale == "true": + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer, + scale_update_cell=update_cell) + else: + net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer) + + model = Model(net_with_grads) + model.train(new_repeat_count, ds, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps) + + +if __name__ == '__main__': + numpy.random.seed(0) + run_pretrain() diff --git a/model_zoo/official/nlp/bert/run_squad.py b/model_zoo/official/nlp/bert/run_squad.py new file mode 100644 index 0000000000..a026408e7c --- /dev/null +++ b/model_zoo/official/nlp/bert/run_squad.py @@ -0,0 +1,213 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation script. +''' +import os +import argparse +import collections +from src.bert_for_finetune import BertSquadCell, BertSquad +from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.dataset import create_squad_dataset +from src import tokenization +from src.create_squad_data import read_squad_examples, convert_examples_to_features +from src.run_squad import write_predictions +from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import log as logger +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_cur_dir = os.getcwd() + +def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): + """ do train """ + if load_checkpoint_path == "": + raise ValueError("Pretrain model missed, finetune task must load pretrain model!") + steps_per_epoch = dataset.get_dataset_size() + # optimizer + if optimizer_cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.AdamWeightDecay.power) + params = network.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + + optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + elif optimizer_cfg.optimizer == 'Lamb': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.Lamb.power) + optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule) + elif optimizer_cfg.optimizer == 'Momentum': + optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, + momentum=optimizer_cfg.Momentum.momentum) + else: + raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") + + # load checkpoint into network + ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="squad", directory=save_checkpoint_path, config=ckpt_config) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(network, param_dict) + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) + netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] + model.train(epoch_num, dataset, callbacks=callbacks) + + +def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", seq_length=384): + """ do eval """ + if load_checkpoint_path == "": + raise ValueError("Finetune model missed, evaluation task must load finetune model!") + tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True) + eval_examples = read_squad_examples(eval_json, False) + eval_features = convert_examples_to_features( + examples=eval_examples, + tokenizer=tokenizer, + max_seq_length=seq_length, + doc_stride=128, + max_query_length=64, + is_training=False, + output_fn=None, + verbose_logging=False) + + net = BertSquad(bert_net_cfg, False, 2) + net.set_train(False) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(net, param_dict) + model = Model(net) + output = [] + RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) + columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"] + for data in dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, segment_ids, unique_ids = input_data + start_positions = Tensor([1], mstype.float32) + end_positions = Tensor([1], mstype.float32) + is_impossible = Tensor([1], mstype.float32) + logits = model.predict(input_ids, input_mask, segment_ids, start_positions, + end_positions, unique_ids, is_impossible) + ids = logits[0].asnumpy() + start = logits[1].asnumpy() + end = logits[2].asnumpy() + + for i in range(bert_net_cfg.batch_size): + unique_id = int(ids[i]) + start_logits = [float(x) for x in start[i].flat] + end_logits = [float(x) for x in end[i].flat] + output.append(RawResult( + unique_id=unique_id, + start_logits=start_logits, + end_logits=end_logits)) + write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", None, None) + +def run_squad(): + """run squad task""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") + parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") + parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") + parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--train_data_shuffle", type=str, default="true", + help="Enable train data shuffle, default is true") + parser.add_argument("--eval_data_shuffle", type=str, default="false", + help="Enable eval data shuffle, default is false") + parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path") + parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json") + parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--schema_file_path", type=str, default="", + help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + epoch_num = args_opt.epoch_num + load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path + save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path + load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path + + if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": + raise ValueError("At least one of 'do_train' or 'do_eval' must be true") + if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": + raise ValueError("'train_data_file_path' must be set when do finetune task") + if args_opt.do_eval.lower() == "true": + if args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do evaluation task") + if args_opt.vocab_file_path == "": + raise ValueError("'vocab_file_path' must be set when do evaluation task") + if args_opt.eval_json_path == "": + raise ValueError("'tokenization_file_path' must be set when do evaluation task") + + + target = args_opt.device_target + if target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") + + netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) + + if args_opt.do_train.lower() == "true": + ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + data_file_path=args_opt.train_data_file_path, + schema_file_path=args_opt.schema_file_path, + do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) + do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) + if args_opt.do_eval.lower() == "true": + if save_finetune_checkpoint_path == "": + load_finetune_checkpoint_dir = _cur_dir + else: + load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) + load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, + ds.get_dataset_size(), epoch_num, "squad") + + if args_opt.do_eval.lower() == "true": + ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + data_file_path=args_opt.eval_data_file_path, + schema_file_path=args_opt.schema_file_path, is_training=False, + do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) + do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path, + load_finetune_checkpoint_path, bert_net_cfg.seq_length) + +if __name__ == "__main__": + run_squad() diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/README.md b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/README.md new file mode 100644 index 0000000000..b492c4c309 --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/README.md @@ -0,0 +1,48 @@ +# Run distribute pretrain + +## description +The number of D chips can be automatically allocated based on the device_num set in hccl config file, You don not need to specify that. + + +## how to use +For example, if we want to run the distributed training of Bert model on D chip, we can in `/bert/` dir: +``` +python ./scripts/ascend_distributed_launcher/run_distribute_pretrain.py --run_script_dir ./run_pretrain.py --hyper_parameter_config_dir ./scripts/ascend_distributed_launcher/hyper_parameter_config.ini --data_dir /path/dataset/ --hccl_config_dir model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json +``` + +output: + +``` +hccl_config_dir: model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json +the number of logical core: 192 +avg_core_per_rank: 96 +rank_size: 2 + +start training for rank 0, device 5: +rank_id: 0 +device_id: 5 +core nums: 0-95 +epoch_size: 8 +data_dir: /data/small_512/ +schema_dir: +log file dir: ./LOG5/log.txt + +start training for rank 1, device 6: +rank_id: 1 +device_id: 6 +core nums: 96-191 +epoch_size: 8 +data_dir: /data/small_512/ +schema_dir: +log file dir: ./LOG6/log.txt +``` + +## Note + +1. Note that `hccl_2p_56_x.x.x.x.json` can use [hccl_tools.py](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate. + +2. For hyper parameter, please note that you should customize the scripts `hyper_parameter_config.ini`. Please note that these two hyper parameters are not allowed to be configured here: + device_id + device_num + +3. For Other Model, please note that you should customize the option `run_script` and Corresponding `hyper_parameter_config.ini`. diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/__init__.py b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini new file mode 100644 index 0000000000..2298f83509 --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini @@ -0,0 +1,11 @@ +[config] +distribute=true +epoch_size=40 +enable_save_ckpt=true +enable_lossscale=true +do_shuffle=true +enable_data_sink=true +data_sink_steps=100 +save_checkpoint_path=./checkpoint/ +save_checkpoint_steps=10000 +save_checkpoint_num=1 \ No newline at end of file diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/run_distribute_pretrain.py b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/run_distribute_pretrain.py new file mode 100644 index 0000000000..32c3bb8038 --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/run_distribute_pretrain.py @@ -0,0 +1,142 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""distribute pretrain script""" +import os +import json +import configparser +import multiprocessing +from argparse import ArgumentParser + + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training") + + parser.add_argument("--run_script_dir", type=str, default="", + help="Run script path, it is better to use absolute path") + parser.add_argument("--hyper_parameter_config_dir", type=str, default="", + help="Hyper Parameter config path, it is better to use absolute path") + parser.add_argument("--data_dir", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--hccl_config_dir", type=str, default="", + help="Hccl config path, it is better to use absolute path") + + args = parser.parse_args() + return args + + +def distribute_pretrain(): + """ + distribute pretrain scripts. The number of D chips can be automatically allocated + based on the device_num set in hccl config file, You don not need to specify that. + """ + print("start", __file__) + args = parse_args() + + run_script = args.run_script_dir + data_dir = args.data_dir + cf = configparser.ConfigParser() + cf.read(args.hyper_parameter_config_dir) + cfg = dict(cf.items("config")) + + print("hccl_config_dir:", args.hccl_config_dir) + os.environ['RANK_TABLE_FILE'] = args.hccl_config_dir + + cores = multiprocessing.cpu_count() + print("the number of logical core:", cores) + + # get device_ips + device_ips = {} + with open('/etc/hccn.conf', 'r') as fin: + for hccn_item in fin.readlines(): + if hccn_item.strip().startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip.strip() + + with open(args.hccl_config_dir, "r", encoding="utf-8") as fin: + hccl_config = json.loads(fin.read()) + rank_size = 0 + for server in hccl_config["server_list"]: + rank_size += len(server["device"]) + if server["device"][0]["device_ip"] in device_ips.values(): + this_server = server + + os.environ['RANK_SIZE'] = str(rank_size) + print("total rank size:", rank_size) + print("this server rank size:", len(this_server["device"])) + avg_core_per_rank = int(int(cores) / len(this_server["device"])) + core_gap = avg_core_per_rank - 1 + print("avg_core_per_rank:", avg_core_per_rank) + + count = 0 + for instance in this_server["device"]: + device_id = instance["device_id"] + rank_id = instance["rank_id"] + print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":") + print("rank_id:", rank_id) + print("device_id:", device_id) + + start = count * int(avg_core_per_rank) + count += 1 + end = start + core_gap + cmdopt = str(start) + "-" + str(end) + + os.environ["DEVICE_ID"] = device_id + os.environ["RANK_ID"] = rank_id + os.environ["DEPLOY_MODE"] = "0" + os.environ["GE_USE_STATIC_MEMORY"] = "1" + + os.system("rm -rf LOG" + str(device_id)) + os.system("mkdir ./LOG" + str(device_id)) + os.system("cp *.py ./LOG" + str(device_id)) + os.system("mkdir -p ./LOG" + str(device_id) + "/ms_log") + os.system("env > ./LOG" + str(device_id) + "/env.log") + + cur_dir = os.getcwd() + os.environ["GLOG_log_dir"] = cur_dir + "/LOG" + str(device_id) + "/ms_log" + os.environ["GLOG_logtostderr"] = "0" + + print("core_nums:", cmdopt) + print("epoch_size:", str(cfg['epoch_size'])) + print("data_dir:", data_dir) + print("log_file_dir: " + cur_dir + "/LOG" + str(device_id) + "/log.txt") + + os.chdir(cur_dir + "/LOG" + str(device_id)) + cmd = 'taskset -c ' + cmdopt + ' python ' + run_script + " " + opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()]) + if ('device_id' in opt) or ('device_num' in opt) or ('data_dir' in opt): + raise ValueError("hyper_parameter_config.ini can not setting 'device_id'," + " 'device_num' or 'data_dir'! ") + cmd += opt + cmd += " --data_dir=" + data_dir + cmd += ' --device_id=' + str(device_id) + ' --device_num=' \ + + str(rank_size) + ' >./log.txt 2>&1 &' + + os.system(cmd) + os.chdir(cur_dir) + +if __name__ == "__main__": + distribute_pretrain() diff --git a/model_zoo/official/nlp/bert/scripts/run_classifier.sh b/model_zoo/official/nlp/bert/scripts/run_classifier.sh new file mode 100644 index 0000000000..39516fa419 --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/run_classifier.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_classifier.sh" +echo "for example: bash scripts/run_classifier.sh" +echo "assessment_method include: [MCC, Spearman_correlation ,Accuracy]" +echo "==============================================================================================================" + +mkdir -p ms_log +CUR_DIR=`pwd` +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_classifier.py \ + --device_target="Ascend" \ + --do_train="true" \ + --do_eval="false" \ + --assessment_method="Accuracy" \ + --device_id=0 \ + --epoch_num=1 \ + --num_class=2 \ + --train_data_shuffle="true" \ + --eval_data_shuffle="false" \ + --save_finetune_checkpoint_path="" \ + --load_pretrain_checkpoint_path="" \ + --load_finetune_checkpoint_path="" \ + --train_data_file_path="" \ + --eval_data_file_path="" \ + --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/bert/scripts/run_distribute_pretrain.sh b/model_zoo/official/nlp/bert/scripts/run_distribute_pretrain.sh new file mode 100644 index 0000000000..be910fb844 --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/run_distribute_pretrain.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_distribute_pretrain.sh DATA_DIR RANK_TABLE_FILE" +echo "for example: bash run_distribute_pretrain.sh /path/dataset /path/hccl.json" +echo "It is better to use absolute path." +echo "For hyper parameter, please note that you should customize the scripts: + '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' " +echo "==============================================================================================================" +CUR_DIR=`pwd` + +python ${CUR_DIR}/scripts/ascend_distributed_launcher/run_distribute_pretrain.py \ + --run_script_dir=${CUR_DIR}/run_pretrain.py \ + --hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \ + --data_dir=$1 \ + --hccl_config_dir=$2 diff --git a/model_zoo/bert/scripts/run_distribute_pretrain_for_gpu.sh b/model_zoo/official/nlp/bert/scripts/run_distribute_pretrain_for_gpu.sh similarity index 100% rename from model_zoo/bert/scripts/run_distribute_pretrain_for_gpu.sh rename to model_zoo/official/nlp/bert/scripts/run_distribute_pretrain_for_gpu.sh diff --git a/model_zoo/official/nlp/bert/scripts/run_ner.sh b/model_zoo/official/nlp/bert/scripts/run_ner.sh new file mode 100644 index 0000000000..45c37be653 --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/run_ner.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_ner.sh" +echo "for example: bash scripts/run_ner.sh" +echo "assessment_method include: [F1, clue_benchmark]" +echo "==============================================================================================================" + +mkdir -p ms_log +CUR_DIR=`pwd` +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_ner.py \ + --device_target="Ascend" \ + --do_train="true" \ + --do_eval="false" \ + --assessment_method="F1" \ + --use_crf="false" \ + --device_id=0 \ + --epoch_num=1 \ + --num_class=2 \ + --train_data_shuffle="true" \ + --eval_data_shuffle="false" \ + --vocab_file_path="" \ + --label2id_file_path="" \ + --save_finetune_checkpoint_path="" \ + --load_pretrain_checkpoint_path="" \ + --load_finetune_checkpoint_path="" \ + --train_data_file_path="" \ + --eval_data_file_path="" \ + --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/bert/scripts/run_squad.sh b/model_zoo/official/nlp/bert/scripts/run_squad.sh new file mode 100644 index 0000000000..efca61db1d --- /dev/null +++ b/model_zoo/official/nlp/bert/scripts/run_squad.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_squad.sh" +echo "for example: bash scripts/run_squad.sh" +echo "assessment_method include: [Accuracy]" +echo "==============================================================================================================" + +mkdir -p ms_log +CUR_DIR=`pwd` +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_squad.py \ + --device_target="Ascend" \ + --do_train="true" \ + --do_eval="false" \ + --device_id=0 \ + --epoch_num=1 \ + --num_class=2 \ + --train_data_shuffle="true" \ + --eval_data_shuffle="false" \ + --vocab_file_path="" \ + --eval_json_path="" \ + --save_finetune_checkpoint_path="" \ + --load_pretrain_checkpoint_path="" \ + --load_finetune_checkpoint_path="" \ + --train_data_file_path="" \ + --eval_data_file_path="" \ + --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/scripts/run_standalone_pretrain.sh b/model_zoo/official/nlp/bert/scripts/run_standalone_pretrain.sh similarity index 100% rename from model_zoo/bert/scripts/run_standalone_pretrain.sh rename to model_zoo/official/nlp/bert/scripts/run_standalone_pretrain.sh diff --git a/model_zoo/bert/scripts/run_standalone_pretrain_for_gpu.sh b/model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_for_gpu.sh similarity index 100% rename from model_zoo/bert/scripts/run_standalone_pretrain_for_gpu.sh rename to model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_for_gpu.sh diff --git a/model_zoo/bert/src/CRF.py b/model_zoo/official/nlp/bert/src/CRF.py similarity index 100% rename from model_zoo/bert/src/CRF.py rename to model_zoo/official/nlp/bert/src/CRF.py diff --git a/model_zoo/bert/src/__init__.py b/model_zoo/official/nlp/bert/src/__init__.py similarity index 100% rename from model_zoo/bert/src/__init__.py rename to model_zoo/official/nlp/bert/src/__init__.py diff --git a/model_zoo/official/nlp/bert/src/assessment_method.py b/model_zoo/official/nlp/bert/src/assessment_method.py new file mode 100644 index 0000000000..dae4894129 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/assessment_method.py @@ -0,0 +1,133 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert evaluation assessment method script. +''' +import math +import numpy as np +from .CRF import postprocess + +class Accuracy(): + ''' + calculate accuracy + ''' + def __init__(self): + self.acc_num = 0 + self.total_num = 0 + def update(self, logits, labels): + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + self.acc_num += np.sum(labels == logit_id) + self.total_num += len(labels) + +class F1(): + ''' + calculate F1 score + ''' + def __init__(self, use_crf=False, num_labels=2): + self.TP = 0 + self.FP = 0 + self.FN = 0 + self.use_crf = use_crf + self.num_labels = num_labels + + def update(self, logits, labels): + ''' + update F1 score + ''' + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + if self.use_crf: + backpointers, best_tag_id = logits + best_path = postprocess(backpointers, best_tag_id) + logit_id = [] + for ele in best_path: + logit_id.extend(ele) + else: + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)]) + pos_label = np.isin(labels, [i for i in range(1, self.num_labels)]) + self.TP += np.sum(pos_eva&pos_label) + self.FP += np.sum(pos_eva&(~pos_label)) + self.FN += np.sum((~pos_eva)&pos_label) + +class MCC(): + ''' + Calculate Matthews Correlation Coefficient + ''' + def __init__(self): + self.TP = 0 + self.FP = 0 + self.FN = 0 + self.TN = 0 + def update(self, logits, labels): + ''' + MCC update + ''' + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + labels = labels.astype(np.bool) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + logit_id = logit_id.astype(np.bool) + ornot = logit_id ^ labels + + self.TP += (~ornot & labels).sum() + self.FP += (ornot & ~labels).sum() + self.FN += (ornot & labels).sum() + self.TN += (~ornot & ~labels).sum() + + def cal(self): + mcc = (self.TP*self.TN - self.FP*self.FN)/math.sqrt((self.TP+self.FP)*(self.TP+self.FN) * + (self.TN+self.FP)*(self.TN+self.FN)) + return mcc + +class Spearman_Correlation(): + ''' + Calculate Spearman Correlation Coefficient + ''' + def __init__(self): + self.label = [] + self.logit = [] + + def update(self, logits, labels): + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logits = np.reshape(logits, -1) + self.label.append(labels) + self.logit.append(logits) + + def cal(self): + ''' + Calculate Spearman Correlation + ''' + label = np.concatenate(self.label) + logit = np.concatenate(self.logit) + sort_label = label.argsort()[::-1] + sort_logit = logit.argsort()[::-1] + n = len(label) + d_acc = 0 + for i in range(n): + d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0] + d_acc += d**2 + ps = 1 - 6*d_acc/n/(n**2-1) + return ps diff --git a/model_zoo/official/nlp/bert/src/bert_for_finetune.py b/model_zoo/official/nlp/bert/src/bert_for_finetune.py new file mode 100644 index 0000000000..5fbf1d81b9 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/bert_for_finetune.py @@ -0,0 +1,327 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert for finetune script. +''' + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.communication.management import get_group_size +from mindspore import context +from .bert_for_pre_training import clip_grad +from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel +from .utils import CrossEntropyCalculation + + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + +_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") +grad_overflow = P.FloatStatus() +@_grad_overflow.register("Tensor") +def _tensor_grad_overflow(grad): + return grad_overflow(grad) + +class BertFinetuneCell(nn.Cell): + """ + Especifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + + super(BertFinetuneCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.gpu_target = False + if context.get_context("device_target") == "GPU": + self.gpu_target = True + self.float_status = P.FloatStatus() + self.addn = P.AddN() + self.reshape = P.Reshape() + else: + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids, + sens=None): + + + weights = self.weights + init = False + loss = self.network(input_ids, + input_mask, + token_type_id, + label_ids) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + if not self.gpu_target: + init = self.alloc_status() + clear_before_grad = self.clear_before_grad(init) + F.control_depend(loss, init) + self.depend_parameter_use(clear_before_grad, scaling_sens) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + label_ids, + self.cast(scaling_sens, + mstype.float32)) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + grads = self.grad_reducer(grads) + if not self.gpu_target: + flag = self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + F.control_depend(grads, flag) + F.control_depend(flag, flag_sum) + else: + flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) + flag_sum = self.addn(flag_sum) + flag_sum = self.reshape(flag_sum, (())) + if self.is_distributed: + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond) + return F.depend(ret, succ) + +class BertSquadCell(nn.Cell): + """ + specifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertSquadCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + def construct(self, + input_ids, + input_mask, + token_type_id, + start_position, + end_position, + unique_id, + is_impossible, + sens=None): + weights = self.weights + init = self.alloc_status() + loss = self.network(input_ids, + input_mask, + token_type_id, + start_position, + end_position, + unique_id, + is_impossible) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + start_position, + end_position, + unique_id, + is_impossible, + self.cast(scaling_sens, + mstype.float32)) + clear_before_grad = self.clear_before_grad(init) + F.control_depend(loss, init) + self.depend_parameter_use(clear_before_grad, scaling_sens) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + grads = self.grad_reducer(grads) + flag = self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + F.control_depend(grads, flag) + F.control_depend(flag, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond) + return F.depend(ret, succ) + +class BertCLS(nn.Cell): + """ + Train interface for classification finetuning task. + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False, + assessment_method=""): + super(BertCLS, self).__init__() + self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings, + assessment_method) + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.assessment_method = assessment_method + self.is_training = is_training + def construct(self, input_ids, input_mask, token_type_id, label_ids): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.assessment_method == "spearman_correlation": + if self.is_training: + loss = self.loss(logits, label_ids) + else: + loss = logits + else: + loss = self.loss(logits, label_ids, self.num_labels) + return loss + + +class BertNER(nn.Cell): + """ + Train interface for sequence labeling finetuning task. + """ + def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, + use_one_hot_embeddings=False): + super(BertNER, self).__init__() + self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) + if use_crf: + if not tag_to_index: + raise Exception("The dict for tag-index mapping should be provided for CRF.") + from src.CRF import CRF + self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) + else: + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.use_crf = use_crf + def construct(self, input_ids, input_mask, token_type_id, label_ids): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.use_crf: + loss = self.loss(logits, label_ids) + else: + loss = self.loss(logits, label_ids, self.num_labels) + return loss + +class BertSquad(nn.Cell): + ''' + Train interface for SQuAD finetuning task. + ''' + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): + super(BertSquad, self).__init__() + self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.seq_length = config.seq_length + self.is_training = is_training + self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num') + self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num') + self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num') + self.sum = P.ReduceSum() + self.equal = P.Equal() + self.argmax = P.ArgMaxWithValue(axis=1) + self.squeeze = P.Squeeze(axis=-1) + + def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.is_training: + unstacked_logits_0 = self.squeeze(logits[:, :, 0:1]) + unstacked_logits_1 = self.squeeze(logits[:, :, 1:2]) + start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length) + end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length) + total_loss = (start_loss + end_loss) / 2.0 + else: + start_logits = self.squeeze(logits[:, :, 0:1]) + end_logits = self.squeeze(logits[:, :, 1:2]) + total_loss = (unique_id, start_logits, end_logits) + return total_loss diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py new file mode 100644 index 0000000000..1d12ddaf06 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -0,0 +1,438 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert for pretraining.""" +import numpy as np + +import mindspore.nn as nn +from mindspore.common.initializer import initializer, TruncatedNormal +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.communication.management import get_group_size +from mindspore import context +from mindspore.ops import _selected_ops +from .bert_model import BertModel + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") + + +# pylint: disable=consider-using-in +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + + +class GetMaskedLMOutput(nn.Cell): + """ + Get masked lm output. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, masked lm output. + """ + def __init__(self, config): + super(GetMaskedLMOutput, self).__init__() + self.width = config.hidden_size + self.reshape = P.Reshape() + self.gather = P.GatherV2() + + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(self.width, + config.hidden_size, + weight_init=weight_init, + activation=config.hidden_act).to_float(config.compute_type) + self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) + self.output_bias = Parameter( + initializer( + 'zero', + config.vocab_size), + name='output_bias') + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_offsets = (-1, 1) + self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32)) + self.last_idx = (-1,) + self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) + self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32)) + self.cast = P.Cast() + self.compute_type = config.compute_type + self.dtype = config.dtype + + def construct(self, + input_tensor, + output_weights, + positions): + flat_offsets = self.reshape( + self.rng * self.seq_length_tensor, self.shape_flat_offsets) + flat_position = self.reshape(positions + flat_offsets, self.last_idx) + flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + input_tensor = self.dense(input_tensor) + input_tensor = self.layernorm(input_tensor) + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + logits = logits + self.output_bias + log_probs = self.log_softmax(logits) + return log_probs + + +class GetNextSentenceOutput(nn.Cell): + """ + Get next sentence output. + + Args: + config (BertConfig): The config of Bert. + + Returns: + Tensor, next sentence output. + """ + def __init__(self, config): + super(GetNextSentenceOutput, self).__init__() + self.log_softmax = _selected_ops.LogSoftmax() + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(config.hidden_size, 2, + weight_init=weight_init, has_bias=True).to_float(config.compute_type) + self.dtype = config.dtype + self.cast = P.Cast() + + def construct(self, input_tensor): + logits = self.dense(input_tensor) + logits = self.cast(logits, self.dtype) + log_prob = self.log_softmax(logits) + return log_prob + + +class BertPreTraining(nn.Cell): + """ + Bert pretraining network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. + + Returns: + Tensor, prediction_scores, seq_relationship_score. + """ + def __init__(self, config, is_training, use_one_hot_embeddings): + super(BertPreTraining, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cls1 = GetMaskedLMOutput(config) + self.cls2 = GetNextSentenceOutput(config) + + def construct(self, input_ids, input_mask, token_type_id, + masked_lm_positions): + sequence_output, pooled_output, embedding_table = \ + self.bert(input_ids, token_type_id, input_mask) + prediction_scores = self.cls1(sequence_output, + embedding_table, + masked_lm_positions) + seq_relationship_score = self.cls2(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPretrainingLoss(nn.Cell): + """ + Provide bert pre-training loss. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, total loss. + """ + def __init__(self, config): + super(BertPretrainingLoss, self).__init__() + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + + def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, + masked_lm_weights, next_sentence_labels): + """Defines the computation performed.""" + label_ids = self.reshape(masked_lm_ids, self.last_idx) + label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + masked_lm_loss = numerator / denominator + + # next_sentence_loss + labels = self.reshape(next_sentence_labels, self.last_idx) + one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum( + one_hot_labels * seq_relationship_score, self.last_idx)) + next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) + + # total_loss + total_loss = masked_lm_loss + next_sentence_loss + + return total_loss + + +class BertNetworkWithLoss(nn.Cell): + """ + Provide bert pre-training loss through network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. + """ + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(BertNetworkWithLoss, self).__init__() + self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) + self.loss = BertPretrainingLoss(config) + self.cast = P.Cast() + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + prediction_scores, seq_relationship_score = \ + self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) + total_loss = self.loss(prediction_scores, seq_relationship_score, + masked_lm_ids, masked_lm_weights, next_sentence_labels) + return self.cast(total_loss, mstype.float32) + + +class BertTrainOneStepCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + """ + def __init__(self, network, optimizer, sens=1.0): + super(BertTrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + + def set_sens(self, value): + self.sens = value + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Defines the computation performed.""" + weights = self.weights + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + succ = self.optimizer(grads) + return F.depend(loss, succ) + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + + +class BertTrainOneStepWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/bert/src/bert_model.py b/model_zoo/official/nlp/bert/src/bert_model.py new file mode 100644 index 0000000000..8f972f8cec --- /dev/null +++ b/model_zoo/official/nlp/bert/src/bert_model.py @@ -0,0 +1,944 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" + +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from .fused_layer_norm import FusedLayerNorm + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded + from dataset. Default: True. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float32, + enable_fused_layernorm=False): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.input_mask_from_dataset = input_mask_from_dataset + self.token_type_ids_from_dataset = token_type_ids_from_dataset + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + self.enable_fused_layernorm = enable_fused_layernorm + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. + token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + def __init__(self, + embedding_size, + embedding_shape, + use_relative_positions=False, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [token_type_vocab_size, + embedding_size]), + name='embedding_table') + + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.dropout = nn.Dropout(1 - dropout_prob) + self.gather = P.GatherV2() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + self.full_position_embeddings = Parameter(initializer + (TruncatedNormal(initializer_range), + [max_position_embeddings, + embedding_size]), + name='full_position_embeddings') + + def construct(self, token_type_ids, word_embeddings): + output = word_embeddings + if self.use_token_type: + flat_ids = self.reshape(token_type_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, + self.token_type_vocab_size, self.on_value, self.off_value) + token_type_embeddings = self.array_mul(one_hot_ids, + self.embedding_table) + else: + token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) + token_type_embeddings = self.reshape(token_type_embeddings, self.shape) + output += token_type_embeddings + if not self.use_relative_positions: + _, seq, width = self.shape + position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) + position_embeddings = self.reshape(position_embeddings, (1, seq, width)) + output += position_embeddings + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertOutput, self).__init__() + self.dense = nn.Dense(in_channels, out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.dropout_prob = dropout_prob + self.add = P.TensorAdd() + if compute_type == mstype.float16: + self.layernorm = FusedLayerNorm((out_channels,), + use_batch_norm=enable_fused_layernorm).to_float(compute_type) + else: + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(input_tensor, output) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. + """ + def __init__(self, length, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._length = length + self._max_relative_position = max_relative_position + self._min_relative_position = -max_relative_position + self.range_length = -length + 1 + + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self): + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) + tile_row_out = self.tile(range_vec_row_out, (self._length,)) + tile_col_out = self.tile(range_vec_col_out, (1, self._length)) + range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) + transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) + distance_mat = self.sub(range_mat_out, transpose_out) + + distance_mat_clipped = C.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + length, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth]), + name='embeddings_for_position') + + self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, + max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = nn.OneHot(depth=self.vocab_size) + self.shape = P.Shape() + self.gather = P.GatherV2() # index_select + self.matmul = P.BatchMatMul() + + def construct(self): + relative_positions_matrix_out = self.relative_positions_matrix() + + # Generate embedding for each relative position of dimension depth. + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. + """ + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + + self.tensor_min_type = float(np.finfo(np_type).min) + self.tensor_max_type = float(np.finfo(np_type).max) + + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=False, + use_relative_positions=False, + compute_type=mstype.float32): + + super(BertAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + weight_init=weight).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + weight_init=weight).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + weight_init=weight).to_float(compute_type) + + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = ( + batch_size, to_seq_length, num_attention_heads, size_per_head) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = -10000.0 + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(length=to_seq_length, + depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_keys' = [F|T, F|T, H] + relations_keys = self._generate_relative_positions_embeddings() + relations_keys = self.cast_compute_type(relations_keys) + # query_layer_t is [F, B, N, H] + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + # query_layer_r is [F, B * N, H] + query_layer_r = self.reshape(query_layer_t, + (self.from_seq_length, + self.batch_num, + self.size_per_head)) + # key_position_scores is [F, B * N, F|T] + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + # key_position_scores_r is [F, B, N, F|T] + key_position_scores_r = self.reshape(key_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.from_seq_length)) + # key_position_scores_r_t is [B, N, F, F|T] + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + + attention_scores = self.multiply(self.scores_mul, attention_scores) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_values' = [F|T, F|T, H] + relations_values = self._generate_relative_positions_embeddings() + relations_values = self.cast_compute_type(relations_values) + # attention_probs_t is [F, B, N, T] + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + # attention_probs_r is [F, B * N, T] + attention_probs_r = self.reshape( + attention_probs_t, + (self.from_seq_length, + self.batch_num, + self.to_seq_length)) + # value_position_scores is [F, B * N, H] + value_position_scores = self.matmul(attention_probs_r, + relations_values) + # value_position_scores_r is [F, B, N, H] + value_position_scores_r = self.reshape(value_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.size_per_head)) + # value_position_scores_r_t is [B, N, F, H] + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + + return context_layer + + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. + hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + + self.size_per_head = int(hidden_size / num_attention_heads) + + self.attention = BertAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + do_return_2d_tensor=True, + compute_type=compute_type) + + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + input_tensor = self.reshape(input_tensor, self.shape) + attention_output = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the bert encoder layers. Default: 768. + seq_length (int): Length of input sequence. Default: 512. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size=768, + seq_length=512, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.intermediate = nn.Dense(in_channels=hidden_size, + out_channels=intermediate_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + + def construct(self, hidden_states, attention_mask): + # self-attention + attention_output = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. + """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False, + enable_fused_layernorm=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + layers.append(layer) + + self.layers = nn.CellList(layers) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask): + prev_output = self.reshape(input_tensor, self.shape) + + all_encoder_layers = () + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + if self.return_all_encoders: + layer_output = self.reshape(layer_output, self.out_shape) + all_encoder_layers = all_encoder_layers + (layer_output,) + + if not self.return_all_encoders: + prev_output = self.reshape(prev_output, self.out_shape) + all_encoder_layers = all_encoder_layers + (prev_output,) + return all_encoder_layers + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask_from_dataset = config.input_mask_from_dataset + self.input_mask = None + + if not self.input_mask_from_dataset: + self.input_mask = initializer( + "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() + + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = (config.batch_size, 1, config.seq_length) + self.broadcast_ones = initializer( + "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() + self.batch_matmul = P.BatchMatMul() + + def construct(self, input_mask): + if not self.input_mask_from_dataset: + input_mask = self.input_mask + + attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) + return attention_mask + + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + + self.bert_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_relative_positions=config.use_relative_positions, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + + self.bert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, + word_embeddings) + + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + + # bert encoder + encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + + return sequence_output, pooled_output, embedding_tables diff --git a/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py b/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py new file mode 100755 index 0000000000..042bc0a9c6 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py @@ -0,0 +1,135 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +sample script of processing CLUE classification dataset using mindspore.dataset.text for fine-tuning bert +""" + +import os +import numpy as np + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.text as text +import mindspore.dataset.transforms.c_transforms as ops + + +def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False, + max_seq_len=128, batch_size=64, drop_remainder=True): + """Process TNEWS dataset""" + ### Loading TNEWS from CLUEDataset + assert data_usage in ['train', 'eval', 'test'] + if data_usage == 'train': + dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='TNEWS', + usage=data_usage, shuffle=shuffle_dataset) + elif data_usage == 'eval': + dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='TNEWS', + usage=data_usage, shuffle=shuffle_dataset) + else: + dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='TNEWS', + usage=data_usage, shuffle=shuffle_dataset) + ### Processing label + if data_usage == 'test': + dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"], + columns_order=["id", "label_id", "sentence"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0)) + else: + label_vocab = text.Vocab.from_list(label_list) + label_lookup = text.Lookup(label_vocab) + dataset = dataset.map(input_columns="label_desc", output_columns="label_id", operations=label_lookup) + ### Processing sentence + vocab = text.Vocab.from_file(bert_vocab_path) + tokenizer = text.BertTokenizer(vocab, lower_case=True) + lookup = text.Lookup(vocab, unknown_token='[UNK]') + dataset = dataset.map(input_columns=["sentence"], operations=tokenizer) + dataset = dataset.map(input_columns=["sentence"], operations=ops.Slice(slice(0, max_seq_len))) + dataset = dataset.map(input_columns=["sentence"], + operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), + append=np.array(["[SEP]"], dtype='S'))) + dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup) + dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) + dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], + columns_order=["text_ids", "mask_ids", "label_id"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) + dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"], + columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0)) + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + return dataset + + +def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False, + max_seq_len=128, batch_size=64, drop_remainder=True): + """Process CMNLI dataset""" + ### Loading CMNLI from CLUEDataset + assert data_usage in ['train', 'eval', 'test'] + if data_usage == 'train': + dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='CMNLI', + usage=data_usage, shuffle=shuffle_dataset) + elif data_usage == 'eval': + dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='CMNLI', + usage=data_usage, shuffle=shuffle_dataset) + else: + dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='CMNLI', + usage=data_usage, shuffle=shuffle_dataset) + ### Processing label + if data_usage == 'test': + dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"], + columns_order=["id", "label_id", "sentence1", "sentence2"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0)) + else: + label_vocab = text.Vocab.from_list(label_list) + label_lookup = text.Lookup(label_vocab) + dataset = dataset.map(input_columns="label", output_columns="label_id", operations=label_lookup) + ### Processing sentence pairs + vocab = text.Vocab.from_file(bert_vocab_path) + tokenizer = text.BertTokenizer(vocab, lower_case=True) + lookup = text.Lookup(vocab, unknown_token='[UNK]') + ### Tokenizing sentences and truncate sequence pair + dataset = dataset.map(input_columns=["sentence1"], operations=tokenizer) + dataset = dataset.map(input_columns=["sentence2"], operations=tokenizer) + dataset = dataset.map(input_columns=["sentence1", "sentence2"], + operations=text.TruncateSequencePair(max_seq_len-3)) + ### Adding special tokens + dataset = dataset.map(input_columns=["sentence1"], + operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), + append=np.array(["[SEP]"], dtype='S'))) + dataset = dataset.map(input_columns=["sentence2"], + operations=ops.Concatenate(append=np.array(["[SEP]"], dtype='S'))) + ### Generating segment_ids + dataset = dataset.map(input_columns=["sentence1"], output_columns=["sentence1", "type_sentence1"], + columns_order=["sentence1", "type_sentence1", "sentence2", "label_id"], + operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["sentence2"], output_columns=["sentence2", "type_sentence2"], + columns_order=["sentence1", "type_sentence1", "sentence2", "type_sentence2", "label_id"], + operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["type_sentence1"], operations=[lookup, ops.Fill(0)]) + dataset = dataset.map(input_columns=["type_sentence2"], operations=[lookup, ops.Fill(1)]) + dataset = dataset.map(input_columns=["type_sentence1", "type_sentence2"], output_columns=["segment_ids"], + columns_order=["sentence1", "sentence2", "segment_ids", "label_id"], + operations=ops.Concatenate()) + dataset = dataset.map(input_columns=["segment_ids"], operations=ops.PadEnd([max_seq_len], 0)) + ### Generating text_ids + dataset = dataset.map(input_columns=["sentence1", "sentence2"], output_columns=["text_ids"], + columns_order=["text_ids", "segment_ids", "label_id"], + operations=ops.Concatenate()) + dataset = dataset.map(input_columns=["text_ids"], operations=lookup) + dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) + ### Generating mask_ids + dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], + columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + return dataset diff --git a/model_zoo/bert/src/cluener_evaluation.py b/model_zoo/official/nlp/bert/src/cluener_evaluation.py similarity index 100% rename from model_zoo/bert/src/cluener_evaluation.py rename to model_zoo/official/nlp/bert/src/cluener_evaluation.py diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py new file mode 100644 index 0000000000..c8692d1216 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/config.py @@ -0,0 +1,120 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in dataset.py, run_pretrain.py +""" +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .bert_model import BertConfig +cfg = edict({ + 'bert_network': 'base', + 'loss_scale_value': 65536, + 'scale_factor': 2, + 'scale_window': 1000, + 'optimizer': 'Lamb', + 'AdamWeightDecay': edict({ + 'learning_rate': 3e-5, + 'end_learning_rate': 0.0, + 'power': 5.0, + 'weight_decay': 1e-5, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + 'eps': 1e-6, + 'warmup_steps': 10000, + }), + 'Lamb': edict({ + 'learning_rate': 3e-5, + 'end_learning_rate': 0.0, + 'power': 10.0, + 'warmup_steps': 10000, + 'weight_decay': 0.01, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + 'eps': 1e-6, + }), + 'Momentum': edict({ + 'learning_rate': 2e-5, + 'momentum': 0.9, + }), +}) + +''' +Including two kinds of network: \ +base: Goole BERT-base(the base version of BERT model). +large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ + Functional Relative Posetional Encoding as an effective positional encoding scheme). +''' +if cfg.bert_network == 'base': + bert_net_cfg = BertConfig( + batch_size=64, + seq_length=128, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'nezha': + bert_net_cfg = BertConfig( + batch_size=96, + seq_length=128, + vocab_size=21128, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=True, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'large': + bert_net_cfg = BertConfig( + batch_size=24, + seq_length=512, + vocab_size=30522, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=True + ) diff --git a/model_zoo/official/nlp/bert/src/dataset.py b/model_zoo/official/nlp/bert/src/dataset.py new file mode 100644 index 0000000000..8193ef83fa --- /dev/null +++ b/model_zoo/official/nlp/bert/src/dataset.py @@ -0,0 +1,127 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in run_pretrain.py +""" +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as C +from mindspore import log as logger +from .config import bert_net_cfg + + +def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None): + """create train dataset""" + # apply repeat operations + files = os.listdir(data_dir) + data_files = [] + for file_name in files: + if "tfrecord" in file_name: + data_files.append(os.path.join(data_dir, file_name)) + ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], + shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, + num_shards=device_num, shard_id=rank, shard_equal_rows=True) + ori_dataset_size = ds.get_dataset_size() + print('origin dataset size: ', ori_dataset_size) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) + ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) + ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + # apply batch operations + ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) + logger.info("data size: {}".format(ds.get_dataset_size())) + logger.info("repeat count: {}".format(ds.get_repeat_count())) + return ds + + +def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", + data_file_path=None, schema_file_path=None, do_shuffle=True): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle) + if assessment_method == "Spearman_correlation": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", + data_file_path=None, schema_file_path=None, do_shuffle=True): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle) + if assessment_method == "Spearman_correlation": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, + is_training=True, do_shuffle=True): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + if is_training: + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "start_positions", + "end_positions", "unique_ids", "is_impossible"], + shuffle=do_shuffle) + ds = ds.map(input_columns="start_positions", operations=type_cast_op) + ds = ds.map(input_columns="end_positions", operations=type_cast_op) + else: + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"]) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds diff --git a/model_zoo/official/nlp/bert/src/finetune_eval_config.py b/model_zoo/official/nlp/bert/src/finetune_eval_config.py new file mode 100644 index 0000000000..4a9f05a3fc --- /dev/null +++ b/model_zoo/official/nlp/bert/src/finetune_eval_config.py @@ -0,0 +1,66 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +config settings, will be used in finetune.py +""" + +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .bert_model import BertConfig + +optimizer_cfg = edict({ + 'optimizer': 'Lamb', + 'AdamWeightDecay': edict({ + 'learning_rate': 2e-5, + 'end_learning_rate': 1e-7, + 'power': 1.0, + 'weight_decay': 1e-5, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + 'eps': 1e-6, + }), + 'Lamb': edict({ + 'learning_rate': 2e-5, + 'end_learning_rate': 1e-7, + 'power': 1.0, + 'weight_decay': 0.01, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), + 'Momentum': edict({ + 'learning_rate': 2e-5, + 'momentum': 0.9, + }), +}) + +bert_net_cfg = BertConfig( + batch_size=16, + seq_length=128, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, +) diff --git a/model_zoo/bert/src/finetune_eval_model.py b/model_zoo/official/nlp/bert/src/finetune_eval_model.py similarity index 100% rename from model_zoo/bert/src/finetune_eval_model.py rename to model_zoo/official/nlp/bert/src/finetune_eval_model.py diff --git a/model_zoo/bert/src/fused_layer_norm.py b/model_zoo/official/nlp/bert/src/fused_layer_norm.py similarity index 100% rename from model_zoo/bert/src/fused_layer_norm.py rename to model_zoo/official/nlp/bert/src/fused_layer_norm.py diff --git a/model_zoo/bert/src/sample_process.py b/model_zoo/official/nlp/bert/src/sample_process.py similarity index 100% rename from model_zoo/bert/src/sample_process.py rename to model_zoo/official/nlp/bert/src/sample_process.py diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py new file mode 100644 index 0000000000..6463464734 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/utils.py @@ -0,0 +1,156 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Functional Cells used in Bert finetune and evaluation. +""" + +import os +import numpy as np +import mindspore.nn as nn +from mindspore import log as logger +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common import dtype as mstype +from mindspore.train.callback import Callback +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR + + +class CrossEntropyCalculation(nn.Cell): + """ + Cross Entropy loss + """ + def __init__(self, is_training=True): + super(CrossEntropyCalculation, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + self.is_training = is_training + + def construct(self, logits, label_ids, num_labels): + if self.is_training: + label_ids = self.reshape(label_ids, self.last_idx) + one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) + loss = self.reduce_mean(per_example_loss, self.last_idx) + return_value = self.cast(loss, mstype.float32) + else: + return_value = logits * 1.0 + return return_value + + +def make_directory(path: str): + """Make directory.""" + if path is None or not isinstance(path, str) or path.strip() == "": + logger.error("The path(%r) is invalid type.", path) + raise TypeError("Input path is invaild type") + + # convert the relative paths + path = os.path.realpath(path) + logger.debug("The abs path is %r", path) + + # check the path is exist and write permissions? + if os.path.exists(path): + real_path = path + else: + # All exceptions need to be caught because create directory maybe have some limit(permissions) + logger.debug("The directory(%s) doesn't exist, will create it", path) + try: + os.makedirs(path, exist_ok=True) + real_path = path + except PermissionError as e: + logger.error("No write permission on the directory(%r), error = %r", path, e) + raise TypeError("No write permission on the directory.") + return real_path + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self._per_print_times = per_print_times + def step_end(self, run_context): + cb_params = run_context.original_args() + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs))) + +def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): + """ + Find the ckpt finetune generated and load it into eval network. + """ + files = os.listdir(load_finetune_checkpoint_dir) + pre_len = len(prefix) + max_num = 0 + for filename in files: + name_ext = os.path.splitext(filename) + if name_ext[-1] != ".ckpt": + continue + #steps_per_epoch = ds.get_dataset_size() + if filename.find(prefix) == 0 and not filename[pre_len].isalpha(): + index = filename[pre_len:].find("-") + if index == 0 and max_num == 0: + load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) + elif index not in (0, -1): + name_split = name_ext[-2].split('_') + if (steps_per_epoch != int(name_split[len(name_split)-1])) \ + or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])): + continue + num = filename[pre_len + 1:pre_len + index] + if int(num) > max_num: + max_num = int(num) + load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) + return load_finetune_checkpoint_path + + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + decay_lr = self.decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr diff --git a/model_zoo/official/nlp/bert_thor/README.md b/model_zoo/official/nlp/bert_thor/README.md new file mode 100644 index 0000000000..a3df8b73bb --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/README.md @@ -0,0 +1,93 @@ +# BERT Example +## Description +This is an example of training bert by second-order optimizer THOR. THOR is a novel approximate seond-order optimization method in MindSpore. + +## Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path. +- Download dataset for fine-tuning and evaluation such as CLUENER, TNEWS, SQuAD v1.1, etc. +> Notes: + If you are running a fine-tuning or evaluation task, prepare a checkpoint from pre-train. + +## Running the Example +### Pre-Training +- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file. + +- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model. + + ``` bash + sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR + ``` +- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model. + + ``` bash + sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR RANK_TABLE_FILE + ``` + +## Usage +### Pre-Training +``` +usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] + [--enable_save_ckpt ENABLE_SAVE_CKPT] + [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] + [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH] + [--save_checkpoint_steps N] [--save_checkpoint_num N] + [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] + +options: + --distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false" + --epoch_size epoch size: N, default is 1 + --device_num number of used devices: N, default is 1 + --device_id device id: N, default is 0 + --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" + --enable_lossscale enable lossscale: "true" | "false", default is "true" + --do_shuffle enable shuffle: "true" | "false", default is "true" + --enable_data_sink enable data sink: "true" | "false", default is "true" + --data_sink_steps set data sink steps: N, default is 1 + --checkpoint_path path to save checkpoint files: PATH, default is "" + --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000 + --save_checkpoint_num number for saving checkpoint files: N, default is 1 + --data_dir path to dataset directory: PATH, default is "" + --schema_dir path to schema.json file, PATH, default is "" +``` +## Options and Parameters +It contains of parameters of BERT model and options for training, which is set in file `config.py`, `bert_net_config.py` and `evaluation_config.py` respectively. +### Options: +``` +config.py: + bert_network version of BERT model: base | nezha, default is base + optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum | Thor, default is "Thor" + +``` + +### Parameters: +``` +Parameters for dataset and network (Pre-Training/Evaluation): + batch_size batch size of input dataset: N, default is 8 + seq_length length of input sequence: N, default is 128 + vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136 + hidden_size size of bert encoder layers: N, default is 768 + num_hidden_layers number of hidden layers: N, default is 12 + num_attention_heads number of attention heads: N, default is 12 + intermediate_size size of intermediate layer: N, default is 3072 + hidden_act activation function used: ACTIVATION, default is "gelu" + hidden_dropout_prob dropout probability for BertOutput: Q, default is 0.1 + attention_probs_dropout_prob dropout probability for BertAttention: Q, default is 0.1 + max_position_embeddings maximum length of sequences: N, default is 512 + type_vocab_size size of token type vocab: N, default is 16 + initializer_range initialization value of TruncatedNormal: Q, default is 0.02 + use_relative_positions use relative positions or not: True | False, default is False + input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True + token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True + dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32 + compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 + +Parameters for optimizer: + Thor: + momentum momentum for the moving average: Q + weight_decay weight decay: Q + loss_scale loss scale: N + frequency the step interval to update second-order information matrix: N, default is 10 + batch_size batch size of input dataset: N, default is 8 +``` + diff --git a/model_zoo/official/nlp/bert_thor/pretrain_eval.py b/model_zoo/official/nlp/bert_thor/pretrain_eval.py new file mode 100644 index 0000000000..04ad30bb3d --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/pretrain_eval.py @@ -0,0 +1,164 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Bert evaluation script. +""" + +import os + +from src import BertModel, GetMaskedLMOutput +from src.evaluation_config import cfg, bert_net_cfg + +import mindspore.common.dtype as mstype +import mindspore.dataset as de +import mindspore.dataset.transforms.c_transforms as C +import mindspore.nn as nn +from mindspore import context +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.nn.metrics import Metric +from mindspore.ops import operations as P +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + + +class myMetric(Metric): + ''' + Self-defined Metric as a callback. + ''' + + def __init__(self): + super(myMetric, self).__init__() + self.clear() + + def clear(self): + self.total_num = 0 + self.acc_num = 0 + + def update(self, *inputs): + total_num = self._convert_data(inputs[0]) + acc_num = self._convert_data(inputs[1]) + self.total_num = total_num + self.acc_num = acc_num + + def eval(self): + return self.acc_num / self.total_num + + +class GetLogProbs(nn.Cell): + ''' + Get MaskedLM prediction scores + ''' + + def __init__(self, config): + super(GetLogProbs, self).__init__() + self.bert = BertModel(config, False) + self.cls1 = GetMaskedLMOutput(config) + + def construct(self, input_ids, input_mask, token_type_id, masked_pos): + sequence_output, _, embedding_table = self.bert(input_ids, token_type_id, input_mask) + prediction_scores = self.cls1(sequence_output, embedding_table, masked_pos) + return prediction_scores + + +class BertPretrainEva(nn.Cell): + ''' + Evaluate MaskedLM prediction scores + ''' + + def __init__(self, config): + super(BertPretrainEva, self).__init__() + self.bert = GetLogProbs(config) + self.argmax = P.Argmax(axis=-1, output_type=mstype.int32) + self.equal = P.Equal() + self.mean = P.ReduceMean() + self.sum = P.ReduceSum() + self.total = Parameter(Tensor([0], mstype.float32), name='total') + self.acc = Parameter(Tensor([0], mstype.float32), name='acc') + self.reshape = P.Reshape() + self.shape = P.Shape() + self.cast = P.Cast() + + def construct(self, input_ids, input_mask, token_type_id, masked_pos, masked_ids, masked_weights, nsp_label): + """construct of BertPretrainEva""" + bs, _ = self.shape(input_ids) + probs = self.bert(input_ids, input_mask, token_type_id, masked_pos) + index = self.argmax(probs) + index = self.reshape(index, (bs, -1)) + eval_acc = self.equal(index, masked_ids) + eval_acc1 = self.cast(eval_acc, mstype.float32) + real_acc = eval_acc1 * masked_weights + acc = self.sum(real_acc) + total = self.sum(masked_weights) + self.total += total + self.acc += acc + return acc, self.total, self.acc + + +def get_enwiki_512_dataset(batch_size=1, repeat_count=1, distribute_file=''): + ''' + Get enwiki seq_length=512 dataset + ''' + ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids", + "masked_lm_positions", "masked_lm_ids", + "masked_lm_weights", + "next_sentence_labels"]) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) + ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) + ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) + ds = ds.repeat(repeat_count) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def bert_predict(): + ''' + Predict function + ''' + devid = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) + dataset = get_enwiki_512_dataset(bert_net_cfg.batch_size, 1) + net_for_pretraining = BertPretrainEva(bert_net_cfg) + net_for_pretraining.set_train(False) + param_dict = load_checkpoint(cfg.finetune_ckpt) + load_param_into_net(net_for_pretraining, param_dict) + model = Model(net_for_pretraining) + return model, dataset, net_for_pretraining + + +def MLM_eval(): + ''' + Evaluate function + ''' + _, dataset, net_for_pretraining = bert_predict() + net = Model(net_for_pretraining, eval_network=net_for_pretraining, eval_indexes=[0, 1, 2], + metrics={'name': myMetric()}) + res = net.eval(dataset, dataset_sink_mode=False) + print("==============================================================") + for _, v in res.items(): + print("Accuracy is: ") + print(v) + print("==============================================================") + + +if __name__ == "__main__": + MLM_eval() diff --git a/model_zoo/official/nlp/bert_thor/run_pretrain.py b/model_zoo/official/nlp/bert_thor/run_pretrain.py new file mode 100644 index 0000000000..0ec84545db --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/run_pretrain.py @@ -0,0 +1,202 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################pre_train bert example on zh-wiki######################## +python run_pretrain.py +""" + +import argparse +import os + +import numpy +from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from src.bert_net_config import bert_net_cfg +from src.config import cfg +from src.dataset import create_bert_dataset +from src.lr_generator import get_bert_lr, get_bert_damping +from src.model_thor import Model +# from src.thor_for_bert import THOR +from src.thor_for_bert_arg import THOR +from src.utils import LossCallBack, BertLearningRate + +import mindspore.common.dtype as mstype +import mindspore.communication.management as D +from mindspore import context +from mindspore import log as logger +from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.train.parallel_utils import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_current_dir = os.path.dirname(os.path.realpath(__file__)) + + +def run_pretrain(): + """pre-train bert_clue""" + parser = argparse.ArgumentParser(description='bert pre_training') + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") + parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.") + parser.add_argument("--device_id", type=int, default=4, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") + parser.add_argument("--enable_lossscale", type=str, default="false", help="Use lossscale or not, default is not.") + parser.add_argument("--do_shuffle", type=str, default="false", help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") + parser.add_argument("--data_sink_steps", type=int, default="100", help="Sink steps for each epoch, default is 1.") + parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " + "default is 1000.") + parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, " + "meaning run all steps according to epoch number.") + parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") + parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") + + args_opt = parser.parse_args() + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id, + save_graphs=True) + context.set_context(reserve_class_name_in_scope=False) + context.set_context(variable_memory_max_size="30GB") + ckpt_save_dir = args_opt.save_checkpoint_path + if args_opt.distribute == "true": + if args_opt.device_target == 'Ascend': + D.init('hccl') + device_num = args_opt.device_num + rank = args_opt.device_id % device_num + else: + D.init('nccl') + device_num = D.get_group_size() + rank = D.get_rank() + ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/' + + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=device_num) + from mindspore.parallel._auto_parallel_context import auto_parallel_context + if bert_net_cfg.num_hidden_layers == 12: + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217], + "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217], + "hccl_world_groupsum3") + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205], + "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205], + "hccl_world_groupsum3") + elif bert_net_cfg.num_hidden_layers == 24: + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421], + "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421], + "hccl_world_groupsum3") + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397], + "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397], + "hccl_world_groupsum3") + else: + rank = 0 + device_num = 1 + + if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32: + logger.warning('Gpu only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + + ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) + net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) + + new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps + if args_opt.train_steps > 0: + new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) + else: + args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() + logger.info("train steps: {}".format(args_opt.train_steps)) + + if cfg.optimizer == 'Lamb': + lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate, + end_learning_rate=cfg.Lamb.end_learning_rate, + warmup_steps=cfg.Lamb.warmup_steps, + decay_steps=args_opt.train_steps, + power=cfg.Lamb.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(cfg.Lamb.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay}, + {'params': other_params}, + {'order_params': params}] + optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps) + elif cfg.optimizer == 'Momentum': + optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, + momentum=cfg.Momentum.momentum) + elif cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, + end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=cfg.AdamWeightDecay.warmup_steps, + decay_steps=args_opt.train_steps, + power=cfg.AdamWeightDecay.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) + elif cfg.optimizer == "Thor": + lr = get_bert_lr() + damping = get_bert_damping() + optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum, + filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()), + filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()), + filter(lambda x: 'A_inv_max' in x.name, net_with_loss.get_parameters()), + filter(lambda x: 'G_inv_max' in x.name, net_with_loss.get_parameters()), + cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers, + bert_net_cfg.batch_size, damping) + else: + raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]". + format(cfg.optimizer)) + callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()] + if args_opt.enable_save_ckpt == "true": + config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, + keep_checkpoint_max=args_opt.save_checkpoint_num) + ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck) + callback.append(ckpoint_cb) + + if args_opt.load_checkpoint_path: + param_dict = load_checkpoint(args_opt.load_checkpoint_path) + load_param_into_net(net_with_loss, param_dict) + + if args_opt.enable_lossscale == "true": + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer, + scale_update_cell=update_cell) + else: + net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer) + + model = Model(net_with_grads, frequency=cfg.Thor.frequency) + model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"), + sink_size=args_opt.data_sink_steps) + + +if __name__ == '__main__': + numpy.random.seed(0) + run_pretrain() diff --git a/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh b/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh new file mode 100644 index 0000000000..f82151bea0 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/scripts/run_distribute_pretrain.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR RANK_TABLE_FILE" +echo "for example: bash run_distribute_pretrain.sh 8 1 /path/zh-wiki/ /path/Schema.json /path/hccl.json" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +EPOCH_SIZE=$2 +DATA_DIR=$3 +SCHEMA_DIR=$4 + +ulimit -u unlimited +export RANK_TABLE_FILE=$5 +export RANK_SIZE=$1 +export HCCL_CONNECT_TIMEOUT=300 + +for((i=0;i env.log + python ../run_pretrain.py \ + --distribute="true" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --device_num=$RANK_SIZE \ + --enable_save_ckpt="true" \ + --enable_lossscale="false" \ + --do_shuffle="true" \ + --enable_data_sink="true" \ + --data_sink_steps=1000 \ + --load_checkpoint_path="" \ + --save_checkpoint_steps=5000 \ + --save_checkpoint_num=30 \ + --data_dir=$DATA_DIR \ + --schema_dir=$SCHEMA_DIR > log.txt 2>&1 & + cd ../ +done \ No newline at end of file diff --git a/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh b/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh new file mode 100644 index 0000000000..f59eb69601 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/scripts/run_standalone_pretrain.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR" +echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json" +echo "==============================================================================================================" + +DEVICE_ID=$1 +EPOCH_SIZE=$2 +DATA_DIR=$3 +SCHEMA_DIR=$4 + +mkdir -p ms_log +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_pretrain.py \ + --distribute="false" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --enable_save_ckpt="true" \ + --enable_lossscale="true" \ + --do_shuffle="true" \ + --enable_data_sink="true" \ + --data_sink_steps=1 \ + --load_checkpoint_path="" \ + --save_checkpoint_steps=10000 \ + --save_checkpoint_num=1 \ + --data_dir=$DATA_DIR \ + --schema_dir=$SCHEMA_DIR > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/bert_thor/src/__init__.py b/model_zoo/official/nlp/bert_thor/src/__init__.py new file mode 100644 index 0000000000..4f4584a4b4 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert Init.""" +from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ + BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ + BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ + BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ + EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ + SaturateCast, CreateAttentionMaskFromInputMask + +__all__ = [ + "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", + "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell", + "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", + "BertSelfAttention", "BertTransformer", "EmbeddingLookup", + "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", + "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask" +] diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py new file mode 100644 index 0000000000..fb2db14743 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py @@ -0,0 +1,458 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert for pretraining.""" +import numpy as np + +import mindspore.nn as nn +from mindspore import context +from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer, TruncatedNormal +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.communication.management import get_group_size +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.ops import _selected_ops +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.train.parallel_utils import ParallelMode +from .bert_model import BertModel +from .config import cfg +from .lr_generator import get_bert_damping +from .thor_layer import Dense_Thor + +damping = get_bert_damping() +loss_scale = cfg.Thor.loss_scale +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") + + +# pylint: disable=consider-using-in +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + + +class GetMaskedLMOutput(nn.Cell): + """ + Get masked lm output. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, masked lm output. + """ + + def __init__(self, config): + super(GetMaskedLMOutput, self).__init__() + self.width = config.hidden_size + self.reshape = P.Reshape() + self.gather = P.GatherV2() + + weight_init = TruncatedNormal(config.initializer_range) + self.dense = Dense_Thor(in_channels=self.width, + out_channels=config.hidden_size, + weight_init=weight_init, + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation=config.hidden_act, + batch_size=config.batch_size).to_float(config.compute_type) + self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) + self.output_bias = Parameter( + initializer( + 'zero', + config.vocab_size), + name='output_bias') + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_offsets = (-1, 1) + self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32)) + self.last_idx = (-1,) + self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) + self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32)) + self.cast = P.Cast() + self.compute_type = config.compute_type + self.dtype = config.dtype + + def construct(self, + input_tensor, + output_weights, + positions): + """construct of GetMaskedLMOutput""" + flat_offsets = self.reshape( + self.rng * self.seq_length_tensor, self.shape_flat_offsets) + flat_position = self.reshape(positions + flat_offsets, self.last_idx) + flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + input_tensor = self.dense(input_tensor) + input_tensor = self.layernorm(input_tensor) + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + logits = logits + self.output_bias + log_probs = self.log_softmax(logits) + return log_probs + + +class GetNextSentenceOutput(nn.Cell): + """ + Get next sentence output. + + Args: + config (BertConfig): The config of Bert. + + Returns: + Tensor, next sentence output. + """ + + def __init__(self, config): + super(GetNextSentenceOutput, self).__init__() + self.log_softmax = _selected_ops.LogSoftmax() + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(config.hidden_size, 2, + weight_init=weight_init, has_bias=True).to_float(config.compute_type) + self.dtype = config.dtype + self.cast = P.Cast() + + def construct(self, input_tensor): + logits = self.dense(input_tensor) + logits = self.cast(logits, self.dtype) + log_prob = self.log_softmax(logits) + return log_prob + + +class BertPreTraining(nn.Cell): + """ + Bert pretraining network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. + + Returns: + Tensor, prediction_scores, seq_relationship_score. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings): + super(BertPreTraining, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cls1 = GetMaskedLMOutput(config) + self.cls2 = GetNextSentenceOutput(config) + + def construct(self, input_ids, input_mask, token_type_id, + masked_lm_positions): + sequence_output, pooled_output, embedding_table = \ + self.bert(input_ids, token_type_id, input_mask) + prediction_scores = self.cls1(sequence_output, + embedding_table, + masked_lm_positions) + seq_relationship_score = self.cls2(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPretrainingLoss(nn.Cell): + """ + Provide bert pre-training loss. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, total loss. + """ + + def __init__(self, config): + super(BertPretrainingLoss, self).__init__() + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + + def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, + masked_lm_weights, next_sentence_labels): + """Defines the computation performed.""" + label_ids = self.reshape(masked_lm_ids, self.last_idx) + label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + masked_lm_loss = numerator / denominator + + # next_sentence_loss + labels = self.reshape(next_sentence_labels, self.last_idx) + one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum( + one_hot_labels * seq_relationship_score, self.last_idx)) + next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) + + # total_loss + total_loss = masked_lm_loss + next_sentence_loss + + return total_loss + + +class BertNetworkWithLoss(nn.Cell): + """ + Provide bert pre-training loss through network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(BertNetworkWithLoss, self).__init__() + self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) + self.loss = BertPretrainingLoss(config) + self.cast = P.Cast() + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """construct of BertNetworkWithLoss""" + prediction_scores, seq_relationship_score = \ + self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) + total_loss = self.loss(prediction_scores, seq_relationship_score, + masked_lm_ids, masked_lm_weights, next_sentence_labels) + return self.cast(total_loss, mstype.float32) + + +class BertTrainOneStepCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + """ + + def __init__(self, network, optimizer, sens=1.0): + super(BertTrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + + def set_sens(self, value): + self.sens = value + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Defines the computation performed.""" + weights = self.weights + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + succ = self.optimizer(grads) + return F.depend(loss, succ) + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + + +class BertTrainOneStepWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + """ + + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/bert_thor/src/bert_model.py b/model_zoo/official/nlp/bert_thor/src/bert_model.py new file mode 100644 index 0000000000..93b5a9169a --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/bert_model.py @@ -0,0 +1,1027 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" + +import copy +import math + +import numpy as np + +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from .config import cfg +from .fused_layer_norm import FusedLayerNorm +from .lr_generator import get_bert_damping +from .thor_layer import Dense_Thor, Embedding_Thor + +damping = get_bert_damping() +loss_scale = cfg.Thor.loss_scale +batch_size = cfg.Thor.batch_size + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded + from dataset. Default: True. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + + def __init__(self, + batch_size, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float32, + enable_fused_layernorm=False): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.input_mask_from_dataset = input_mask_from_dataset + self.token_type_ids_from_dataset = token_type_ids_from_dataset + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + self.enable_fused_layernorm = enable_fused_layernorm + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + """construct of EmbeddingLookup""" + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. + token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + + def __init__(self, + embedding_size, + embedding_shape, + use_relative_positions=False, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.token_type_embedding = Embedding_Thor( + vocab_size=token_type_vocab_size, + embedding_size=embedding_size, + embedding_shape=embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + name='embedding_table', + is_expand=False, + batch_size=batch_size, + damping=damping, + loss_scale=loss_scale, + frequency=1) + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.dropout = nn.Dropout(1 - dropout_prob) + self.gather = P.GatherV2() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + _, seq, width = self.shape + position_embedding_shape = [1, seq, width] + self.full_position_embedding = Embedding_Thor( + vocab_size=max_position_embeddings, + embedding_size=embedding_size, + embedding_shape=position_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + name='full_position_embeddings', + is_expand=False, + batch_size=batch_size, + damping=damping, + loss_scale=loss_scale, + frequency=1) + self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) + self.layernorm = nn.LayerNorm((embedding_size,)) + + def construct(self, token_type_ids, word_embeddings): + """construct of EmbeddingPostprocessor""" + output = word_embeddings + if self.use_token_type: + token_type_embeddings, _ = self.token_type_embedding(token_type_ids) + output += token_type_embeddings + if not self.use_relative_positions: + position_embeddings, _ = self.full_position_embedding(self.position_ids) + output += position_embeddings + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertOutput, self).__init__() + self.dense = Dense_Thor(in_channels=in_channels, + out_channels=out_channels, + weight_init=TruncatedNormal(initializer_range), + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation=None, + batch_size=batch_size).to_float(compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.dropout_prob = dropout_prob + self.add = P.TensorAdd() + if compute_type == mstype.float16: + self.layernorm = FusedLayerNorm((out_channels,), + use_batch_norm=enable_fused_layernorm).to_float(compute_type) + else: + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + """construct of BertOutput""" + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(input_tensor, output) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. + """ + + def __init__(self, length, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._length = length + self._max_relative_position = max_relative_position + self._min_relative_position = -max_relative_position + self.range_length = -length + 1 + + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self): + """construct of RelaPosMatrixGenerator""" + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) + tile_row_out = self.tile(range_vec_row_out, (self._length,)) + tile_col_out = self.tile(range_vec_col_out, (1, self._length)) + range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) + transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) + distance_mat = self.sub(range_mat_out, transpose_out) + + distance_mat_clipped = C.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + + def __init__(self, + length, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth]), + name='embeddings_for_position') + + self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, + max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = nn.OneHot(depth=self.vocab_size) + self.shape = P.Shape() + self.gather = P.GatherV2() # index_select + self.matmul = P.BatchMatMul() + + def construct(self): + """construct of RelaPosEmbeddingsGenerator""" + relative_positions_matrix_out = self.relative_positions_matrix() + + # Generate embedding for each relative position of dimension depth. + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. + """ + + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + + self.tensor_min_type = float(np.finfo(np_type).min) + self.tensor_max_type = float(np.finfo(np_type).max) + + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + """construct of SaturateCast""" + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=False, + use_relative_positions=False, + compute_type=mstype.float32): + + super(BertAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = Dense_Thor(in_channels=from_tensor_width, + out_channels=units, + weight_init=weight, + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation=query_act, + batch_size=batch_size).to_float(compute_type) + self.key_layer = Dense_Thor(in_channels=to_tensor_width, + out_channels=units, + weight_init=weight, + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation=key_act, + batch_size=batch_size).to_float(compute_type) + self.value_layer = Dense_Thor(in_channels=to_tensor_width, + out_channels=units, + weight_init=weight, + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation=value_act, + batch_size=batch_size).to_float(compute_type) + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = ( + batch_size, to_seq_length, num_attention_heads, size_per_head) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = -10000.0 + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(length=to_seq_length, + depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + """construct of BertAttention""" + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_keys' = [F|T, F|T, H] + relations_keys = self._generate_relative_positions_embeddings() + relations_keys = self.cast_compute_type(relations_keys) + # query_layer_t is [F, B, N, H] + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + # query_layer_r is [F, B * N, H] + query_layer_r = self.reshape(query_layer_t, + (self.from_seq_length, + self.batch_num, + self.size_per_head)) + # key_position_scores is [F, B * N, F|T] + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + # key_position_scores_r is [F, B, N, F|T] + key_position_scores_r = self.reshape(key_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.from_seq_length)) + # key_position_scores_r_t is [B, N, F, F|T] + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + + attention_scores = self.multiply(self.scores_mul, attention_scores) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_values' = [F|T, F|T, H] + relations_values = self._generate_relative_positions_embeddings() + relations_values = self.cast_compute_type(relations_values) + # attention_probs_t is [F, B, N, T] + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + # attention_probs_r is [F, B * N, T] + attention_probs_r = self.reshape( + attention_probs_t, + (self.from_seq_length, + self.batch_num, + self.to_seq_length)) + # value_position_scores is [F, B * N, H] + value_position_scores = self.matmul(attention_probs_r, + relations_values) + # value_position_scores_r is [F, B, N, H] + value_position_scores_r = self.reshape(value_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.size_per_head)) + # value_position_scores_r_t is [B, N, F, H] + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + + return context_layer + + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. + hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + + def __init__(self, + batch_size, + seq_length, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + + self.size_per_head = int(hidden_size / num_attention_heads) + + self.attention = BertAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + do_return_2d_tensor=True, + compute_type=compute_type) + + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + """construct of BertSelfAttention""" + input_tensor = self.reshape(input_tensor, self.shape) + attention_output = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the bert encoder layers. Default: 768. + seq_length (int): Length of input sequence. Default: 512. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + + def __init__(self, + batch_size, + hidden_size=768, + seq_length=512, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.intermediate = Dense_Thor(in_channels=hidden_size, + out_channels=intermediate_size, + weight_init=TruncatedNormal(initializer_range), + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation=hidden_act, + batch_size=batch_size).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + + def construct(self, hidden_states, attention_mask): + """construct of BertEncoderCell""" + # self-attention + attention_output = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. + """ + + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False, + enable_fused_layernorm=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + layers.append(layer) + + self.layers = nn.CellList(layers) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask): + """construct of BertTransformer""" + prev_output = self.reshape(input_tensor, self.shape) + + all_encoder_layers = () + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + if self.return_all_encoders: + layer_output = self.reshape(layer_output, self.out_shape) + all_encoder_layers = all_encoder_layers + (layer_output,) + + if not self.return_all_encoders: + prev_output = self.reshape(prev_output, self.out_shape) + all_encoder_layers = all_encoder_layers + (prev_output,) + return all_encoder_layers + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask_from_dataset = config.input_mask_from_dataset + self.input_mask = None + + if not self.input_mask_from_dataset: + self.input_mask = initializer( + "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() + + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = (config.batch_size, 1, config.seq_length) + self.broadcast_ones = initializer( + "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() + self.batch_matmul = P.BatchMatMul() + + def construct(self, input_mask): + """construct of CreateAttentionMaskFromInputMask""" + if not self.input_mask_from_dataset: + input_mask = self.input_mask + + attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) + return attention_mask + + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + + self.bert_embedding_lookup = Embedding_Thor( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + name='embedding_table', + is_expand=True, + batch_size=batch_size, + damping=damping, + loss_scale=loss_scale, + frequency=1) + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_relative_positions=config.use_relative_positions, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + + self.bert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = Dense_Thor(in_channels=self.hidden_size, + out_channels=self.hidden_size, + weight_init=TruncatedNormal(config.initializer_range), + has_bias=True, + bias_init='zeros', + damping=damping, + loss_scale=loss_scale, + frequency=1, + activation="tanh", + batch_size=batch_size).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """construct of BertModel""" + + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, + word_embeddings) + + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + + # bert encoder + encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + + return sequence_output, pooled_output, embedding_tables diff --git a/model_zoo/official/nlp/bert_thor/src/bert_net_config.py b/model_zoo/official/nlp/bert_thor/src/bert_net_config.py new file mode 100644 index 0000000000..043705f387 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/bert_net_config.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in dataset.py, run_pretrain.py +Including two kinds of network: \ +base: Goole BERT-base(the base version of BERT model). +large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ + Functional Relative Posetional Encoding as an effective positional encoding scheme). +""" +import mindspore.common.dtype as mstype +from .bert_model import BertConfig +from .config import cfg + +if cfg.bert_network == 'base': + bert_net_cfg = BertConfig( + batch_size=cfg.Thor.batch_size, + seq_length=128, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'nezha': + bert_net_cfg = BertConfig( + batch_size=cfg.Thor.batch_size, + seq_length=128, + vocab_size=21128, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=True, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'large': + bert_net_cfg = BertConfig( + batch_size=cfg.Thor.batch_size, + seq_length=512, + vocab_size=30522, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=True + ) diff --git a/model_zoo/official/nlp/bert_thor/src/config.py b/model_zoo/official/nlp/bert_thor/src/config.py new file mode 100644 index 0000000000..9c1d5bf725 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/config.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in dataset.py, run_pretrain.py +""" +from easydict import EasyDict as edict + +cfg = edict({ + 'bert_network': 'large', + 'loss_scale_value': 65536, + 'scale_factor': 2, + 'scale_window': 1000, + 'optimizer': 'Thor', + 'AdamWeightDecay': edict({ + 'learning_rate': 3e-5, + 'end_learning_rate': 1e-10, + 'power': 5.0, + 'weight_decay': 1e-5, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + 'eps': 1e-6, + 'warmup_steps': 10000, + }), + 'Lamb': edict({ + 'learning_rate': 3e-5, + 'end_learning_rate': 1e-10, + 'power': 10.0, + 'warmup_steps': 10000, + 'weight_decay': 0.01, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + 'eps': 1e-6, + }), + 'Momentum': edict({ + 'learning_rate': 2e-5, + 'momentum': 0.9, + }), + 'Thor': edict({ + 'momentum': 0.9, + 'weight_decay': 5e-4, + 'loss_scale': 1, + 'frequency': 10, + 'batch_size': 8, + }), +}) diff --git a/model_zoo/official/nlp/bert_thor/src/dataset.py b/model_zoo/official/nlp/bert_thor/src/dataset.py new file mode 100644 index 0000000000..889e27694a --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/dataset.py @@ -0,0 +1,128 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in run_pretrain.py +""" +import os + +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as C +from mindspore import log as logger +from .bert_net_config import bert_net_cfg + + +def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None): + """create train dataset""" + # apply repeat operations + files = os.listdir(data_dir) + data_files = [] + for file_name in files: + if "tfrecord" in file_name: + data_files.append(os.path.join(data_dir, file_name)) + data_files = sorted(data_files) + ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], + shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, + num_shards=device_num, shard_id=rank, shard_equal_rows=True) + ori_dataset_size = ds.get_dataset_size() + print('origin dataset size: ', ori_dataset_size) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) + ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) + ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + # apply batch operations + ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) + logger.info("data size: {}".format(ds.get_dataset_size())) + logger.info("repeat count: {}".format(ds.get_repeat_count())) + return ds + + +def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", + data_file_path=None, schema_file_path=None): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) + if assessment_method == "Spearman_correlation": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", + data_file_path=None, schema_file_path=None): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) + if assessment_method == "Spearman_correlation": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + if is_training: + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", + "start_positions", "end_positions", + "unique_ids", "is_impossible"]) + ds = ds.map(input_columns="start_positions", operations=type_cast_op) + ds = ds.map(input_columns="end_positions", operations=type_cast_op) + else: + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"]) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds diff --git a/model_zoo/official/nlp/bert_thor/src/dataset_helper.py b/model_zoo/official/nlp/bert_thor/src/dataset_helper.py new file mode 100644 index 0000000000..ff90d27e85 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/dataset_helper.py @@ -0,0 +1,177 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Dataset help for minddata dataset""" +import os + +from mindspore import context +from mindspore._checkparam import check_bool, check_int +from mindspore.parallel._utils import _get_device_num, _need_to_full +from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes + + +def _send_data(dataset, epoch_num): + """Engine dataset to write data to tdt queue.""" + if not hasattr(dataset, '__has_sent__'): + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send(epoch_num) + dataset.__has_sent__ = True + + +def _send_data_no_flag(dataset, epoch_num): + """Engine dataset to write data to tdt queue directly.""" + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send(epoch_num) + + +class DatasetHelper: + """ + Help function to use the Minddata dataset. + + According to different context, change the iter of dataset, to use the same for loop in different context. + + Note: + The iter of DatasetHelper will give one epoch data. + + Args: + dataset (DataSet): The training dataset iterator. + dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True. + sink_size (int): Control the amount of data each sink. + If sink_size=-1, sink the complete dataset each epoch. + If sink_size>0, sink sink_size data each epoch. Default: -1. + + Examples: + >>> dataset_helper = DatasetHelper(dataset) + >>> for inputs in dataset_helper: + >>> outputs = network(*inputs) + """ + + def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0): + check_bool(dataset_sink_mode) + check_int(sink_size) + if sink_size < -1 or sink_size == 0: + raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) + + if dataset_sink_mode: + if context.get_context("enable_ge"): + iterclass = _DatasetIterGE + else: + if context.get_context("device_target") == "Ascend": + iterclass = _DatasetIterMSLoopSink + elif context.get_context("device_target") == "GPU": + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + iterclass = _DatasetIterPSLite + else: + iterclass = _DatasetIterMS + elif context.get_context("device_target") == "CPU": + raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") + self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order) + else: + iterclass = _DatasetIterNormal + self.iter = iterclass(dataset) + + def __iter__(self): + return self.iter.__iter__() + + # A temp solution for loop sink. Delete later + def types_shapes(self): + """Get the types and shapes from dataset on current config.""" + return self.iter.types_shapes() + + def sink_size(self): + """Get sink_size for every iteration.""" + return self.iter.get_sink_size() + + def stop_send(self): + """Free up resources about data sink.""" + self.iter.stop_send() + + +class _DatasetIter: + """Base iter for dataset helper""" + + def __init__(self, dataset, sink_size, epoch_num): + self.dataset = dataset + self.sink_size = sink_size + self.sink_count = 1 + + if not hasattr(dataset, '__TRANSFER_DATASET__'): + if hasattr(dataset, '__loop_size__'): + self.sink_size = dataset.__loop_size__ + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) + dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name + + if not hasattr(dataset, '__no_send__'): + _send_data(dataset, epoch_num) + else: + _send_data_no_flag(dataset, epoch_num) + + self.stop_send = dataset.__TRANSFER_DATASET__.stop_send + self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) + + def __iter__(self): + self.index = 0 + return self + + def __next__(self): + if self.index >= self.sink_count: + raise StopIteration() + self.index += 1 + return self.op() + + def types_shapes(self): + return self.dataset_types, self.dataset_shapes + + def get_sink_count(self, dataset, sink_size, iter_first_order): + sink_count = 1 + if hasattr(dataset, '__loop_size__'): + loop_size = dataset.__loop_size__ + iter_first_order + sink_count = int(sink_size / loop_size) * 2 + return sink_count + + def get_sink_size(self): + """get sink_size to device""" + sink_size = 1 + if hasattr(self.dataset, '__loop_size__'): + sink_size = self.dataset.__loop_size__ + else: + if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend": + if self.sink_size > 0: + sink_size = self.sink_size + else: + sink_size = self.dataset.get_dataset_size() + return sink_size + + +class _DatasetIterMSLoopSink(_DatasetIter): + """Iter for context (device_target=Ascend)""" + + def __init__(self, dataset, sink_size, epoch_num, iter_first_order): + super().__init__(dataset, sink_size, epoch_num) + self.sink_count = self.get_sink_count(dataset, sink_size, iter_first_order) + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + self.sink_count = 1 + # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, + # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for + # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. + if _need_to_full(): + device_num = _get_device_num() + self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num) + + def op(): + return tuple() + + self.op = op diff --git a/model_zoo/official/nlp/bert_thor/src/evaluation_config.py b/model_zoo/official/nlp/bert_thor/src/evaluation_config.py new file mode 100644 index 0000000000..8aab4acf40 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/evaluation_config.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +config settings, will be used in finetune.py +""" + +from easydict import EasyDict as edict + +import mindspore.common.dtype as mstype +from .bert_model import BertConfig + +cfg = edict({ + 'task': 'NER', + 'num_labels': 41, + 'data_file': '', + 'schema_file': None, + 'finetune_ckpt': '', + 'use_crf': False, + 'clue_benchmark': False, +}) + +bert_net_cfg = BertConfig( + batch_size=8 if not cfg.clue_benchmark else 1, + seq_length=512, + vocab_size=30522, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, +) diff --git a/model_zoo/official/nlp/bert_thor/src/fused_layer_norm.py b/model_zoo/official/nlp/bert_thor/src/fused_layer_norm.py new file mode 100644 index 0000000000..882c4c6978 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/fused_layer_norm.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""fused layernorm""" +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.nn.cell import Cell +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.ops.primitive import constexpr + +__all__ = ['FusedLayerNorm'] + + +@constexpr +def get_shape_for_norm(x_shape, begin_norm_axis): + print("input_shape: ", x_shape) + norm_shape = x_shape[begin_norm_axis:] + output_shape = (1, -1, 1, int(np.prod(norm_shape))) + print("output_shape: ", output_shape) + return output_shape + + +class FusedLayerNorm(Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + + Layer normalization is widely used in recurrent neural networks. It applies + normalization over a mini-batch of inputs for each single training case as described + in the paper `Layer Normalization `_. Unlike batch + normalization, layer normalization performs exactly the same computation at training and + testing times. It can be described using the following formula. It is applied across all channels + and pixel but only one batch size. + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Args: + normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis + `begin_norm_axis ... R - 1`. + begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions + `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1. + begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters + will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with + the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'ones'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'zeros'. + use_batch_nrom (bool): Whether use batchnorm to preocess. + + Inputs: + - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`, + and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`. + + Outputs: + Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`. + + Examples: + >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) + >>> shape1 = x.shape[1:] + >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) + >>> m(x) + """ + + def __init__(self, + normalized_shape, + begin_norm_axis=-1, + begin_params_axis=-1, + gamma_init='ones', + beta_init='zeros', + use_batch_norm=False): + super(FusedLayerNorm, self).__init__() + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}." + .format(normalized_shape, type(normalized_shape))) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.gamma = Parameter(initializer( + gamma_init, normalized_shape), name="gamma") + self.beta = Parameter(initializer( + beta_init, normalized_shape), name="beta") + self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis) + + self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5) + self.use_batch_norm = use_batch_norm + + def construct(self, input_x): + """construct of FusedLayerNorm""" + if self.use_batch_norm and self.training: + ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0) + zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0) + shape_x = F.shape(input_x) + norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis) + input_x = F.reshape(input_x, norm_shape) + output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None) + output = F.reshape(output, shape_x) + y = output * self.gamma + self.beta + else: + y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) + return y + + def extend_repr(self): + """Display instance object as string.""" + s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format( + self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta) + return s diff --git a/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor1.py b/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor1.py new file mode 100644 index 0000000000..709b0b73df --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/grad_reducer_thor1.py @@ -0,0 +1,184 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""grad_reducer_thor""" +import mindspore.common.dtype as mstype +from mindspore.communication.management import GlobalComm, get_group_size +from mindspore.nn.cell import Cell +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp + +reduce_opt = C.MultitypeFuncGraph("reduce_opt") + +_all_reduce_G = AllReduce() + + +def _init_optimizer_allreduce(group): + global _all_reduce_G + _all_reduce_G = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) + _all_reduce_G.add_prim_attr('fusion', group) + + +@reduce_opt.register("Function", "Number", "Tensor") +def _tensors_allreduce_mean(mul, degree, grad): + degree = F.scalar_cast(degree, F.dtype(grad)) + grad = _all_reduce_G(grad) + cast_op = P.Cast() + return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) + + +@reduce_opt.register("Bool", "Tensor") +def _tensors_allreduce(allreduce_filter, grad): + if allreduce_filter: + return _all_reduce_G(grad) + return grad + + +_get_datatype = C.MultitypeFuncGraph("_get_datatype") + + +@_get_datatype.register("Tensor") +def _tensors_get_datatype(grad): + """ + Acquire gradient datatype. + + Args: + grad (Tensor): The gradient tensor before operation. + + Returns: + mstype, the datatype of gradient. + """ + return F.dtype(grad) + + +_cast_datatype = C.MultitypeFuncGraph("_cast_datatype") + + +@_cast_datatype.register("TypeType", "Tensor") +def _tensors_cast_datatype(datatype, grad): + """ + Cast gradient to datatype. + + Args: + datatype (mstype): the destination datatype of gradient. + grad (Tensor): The gradient tensor before operation. + + Returns: + Tensor, the gradient tensor after operation. + """ + return F.cast(grad, datatype) + + +class DistributedGradReducerThor1(Cell): + """ + A distributed optimizer. + + Constructs a gradient reducer Cell, which applies communication and average operations on + single-process gradient values. + + Args: + parameters (list): the parameters to be updated. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False. + degree (int): The mean coefficient. Usually it equals to device number. Default: None. + + Raises: + ValueError: If degree is not a int or less than 0. + + Examples: + >>> from mindspore.communication import init, get_group_size + >>> from mindspore.ops import composite as C + >>> from mindspore.ops import operations as P + >>> from mindspore.ops import functional as F + >>> from mindspore import context + >>> from mindspore import nn + >>> from mindspore import ParallelMode, ParameterTuple + >>> + >>> device_id = int(os.environ["DEVICE_ID"]) + >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, + >>> device_id=int(device_id), enable_hccl=True) + >>> init() + >>> context.reset_auto_parallel_context() + >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) + >>> + >>> + >>> class TrainingWrapper(nn.Cell): + >>> def __init__(self, network, optimizer, sens=1.0): + >>> super(TrainingWrapper, self).__init__(auto_prefix=False) + >>> self.network = network + >>> self.network.add_flags(defer_inline=True) + >>> self.weights = ParameterTuple(network.trainable_params()) + >>> self.optimizer = optimizer + >>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + >>> self.sens = sens + >>> self.reducer_flag = False + >>> self.grad_reducer = None + >>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + >>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL, + >>> ParallelMode.HYBRID_PARALLEL]: + >>> self.reducer_flag = True + >>> if self.reducer_flag: + >>> mean = context.get_auto_parallel_context("mirror_mean") + >>> if mean.get_device_num_is_set(): + >>> degree = context.get_auto_parallel_context("device_num") + >>> else: + >>> degree = get_group_size() + >>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + >>> + >>> def construct(self, *args): + >>> weights = self.weights + >>> loss = self.network(*args) + >>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + >>> grads = self.grad(self.network, weights)(*args, sens) + >>> if self.reducer_flag: + >>> # apply grad reducer on grads + >>> grads = self.grad_reducer(grads) + >>> return F.depend(loss, self.optimizer(grads)) + >>> + >>> network = Net() + >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> train_cell = TrainingWrapper(network, optimizer) + >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) + >>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) + >>> grads = train_cell(inputs, label) + """ + + def __init__(self, parameters, group, mean=True, degree=None): + super(DistributedGradReducerThor1, self).__init__(auto_prefix=False) + self.hyper_map = C.HyperMap() + self.mul = P.Mul() + if degree is None: + self.degree = get_group_size() + else: + if not isinstance(degree, int) or degree <= 0: + raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") + self.degree = degree + self.mean = mean + self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) + _init_optimizer_allreduce(group) + + def construct(self, grads): + """construct of DistributedGradReducerThor1""" + # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the + # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce, + # and cast back after the operation. + datatypes = self.hyper_map(F.partial(_get_datatype), grads) + grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) + + if self.mean: + new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) + else: + new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads) + + new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) + return new_grad diff --git a/model_zoo/official/nlp/bert_thor/src/lr_generator.py b/model_zoo/official/nlp/bert_thor/src/lr_generator.py new file mode 100644 index 0000000000..c2416e9b81 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/lr_generator.py @@ -0,0 +1,70 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import numpy as np + +from mindspore.common.tensor import Tensor + + +def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_steps(int): number of warmup epochs + total_steps(int): total epoch of training + poly_power(int): poly learning rate power + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr = float(lr_max - lr_end) * (base ** poly_power) + lr = lr + lr_end + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + learning_rate = np.array(lr_each_step).astype(np.float32) + current_step = global_step + learning_rate = learning_rate[current_step:] + return learning_rate + + +# bert kfac hyperparam setting +def get_bert_lr(): + learning_rate = Tensor( + get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=4e-4, warmup_steps=0, total_steps=30000, + poly_power=1)) + return learning_rate + + +def get_bert_damping(): + damping = Tensor( + get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000, + poly_power=1)) + return damping diff --git a/model_zoo/official/nlp/bert_thor/src/model_thor.py b/model_zoo/official/nlp/bert_thor/src/model_thor.py new file mode 100644 index 0000000000..6a687f2791 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/model_thor.py @@ -0,0 +1,784 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Model.""" +import math +import os +from collections.abc import Iterable + +import numpy as np +from mindspore._c_expression import init_exec_dataset + +from mindspore import context +from mindspore import log as logger +from mindspore import nn +from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int +from mindspore.common import dtype as mstype +from mindspore.common.dtype import pytype_to_dtype +from mindspore.common.tensor import Tensor +from mindspore.nn.metrics import Loss +from mindspore.nn.metrics import get_metrics +from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell +from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ + _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check +from mindspore.parallel._utils import _need_to_full +from mindspore.train import amp +from mindspore.train._utils import _to_full_tensor +from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager +from mindspore.train.parallel_utils import ParallelMode +from .dataset_helper import DatasetHelper + + +def _convert_type(types): + """ + Convert from numpy type to tensor type. + + Args: + types (list): Numpy type list of element in dataset. + + Returns: + list, list of element in dataset. + """ + ms_types = [] + for np_type in types: + ms_type = pytype_to_dtype(np_type) + ms_types.append(ms_type) + return ms_types + + +def _get_types_and_shapes(dataset): + """Get dataset types and shapes.""" + dataset_types = _convert_type(dataset.output_types()) + dataset_shapes = dataset.output_shapes() + return dataset_types, dataset_shapes + + +def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): + """Initialize and execute the dataset graph.""" + batch_size = exec_dataset.get_batch_size() + input_indexs = exec_dataset.input_indexs + + # transform data format + dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) + init_exec_dataset(exec_dataset.__ME_INITED__, + dataset_size, + batch_size, + dataset_types, + dataset_shapes, + input_indexs, + phase=phase, + need_run=False) + + +class Model: + """ + High-Level API for Training or Testing. + + `Model` groups layers into an object with training and inference features. + + Args: + network (Cell): The training or testing network. + loss_fn (Cell): Objective function, if loss_fn is None, the + network should contain the logic of loss and grads calculation, and the logic + of parallel if needed. Default: None. + optimizer (Cell): Optimizer for updating the weights. Default: None. + metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during + training and testing. eg: {'accuracy', 'recall'}. Default: None. + eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as + `eval_network`. Default: None. + eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of + `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three + elements, representing the positions of loss value, predict value and label, the loss + value would be passed to `Loss` metric, predict value and label would be passed to other + metric. Default: None. + amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed + precision training. Supports [O0, O2, O3]. Default: "O0". + + - O0: Do not change. + - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. + - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. + + O2 is recommended on GPU, O3 is recommended on Ascend. + + loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else + scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument. + e.g. Use `loss_scale_manager=None` to set the value. + keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. + + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') + >>> self.bn = nn.BatchNorm2d(64) + >>> self.relu = nn.ReLU() + >>> self.flatten = nn.Flatten() + >>> self.fc = nn.Dense(64*224*224, 12) # padding=0 + >>> + >>> def construct(self, x): + >>> x = self.conv(x) + >>> x = self.bn(x) + >>> x = self.relu(x) + >>> x = self.flatten(x) + >>> out = self.fc(x) + >>> return out + >>> + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + >>> dataset = get_dataset() + >>> model.train(2, dataset) + """ + + def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, + eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs): + self._network = network + self._loss_fn = loss_fn + self._optimizer = optimizer + self._loss_scale_manager = None + self._loss_scale_manager_set = False + self._keep_bn_fp32 = True + self._check_kwargs(kwargs) + self._amp_level = amp_level + self._process_amp_args(kwargs) + self._parallel_mode = _get_parallel_mode() + self._device_number = _get_device_num() + self._global_rank = _get_global_rank() + self._parameter_broadcast = _get_parameter_broadcast() + self._frequency = frequency + self._stop_epoch = stop_epoch + + self._train_network = self._build_train_network() + self._build_eval_network(metrics, eval_network, eval_indexes) + self._build_predict_network() + + def _process_amp_args(self, kwargs): + if self._amp_level in ["O0", "O3"]: + self._keep_bn_fp32 = False + if 'keep_batchnorm_fp32' in kwargs: + self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] + if 'loss_scale_manager' in kwargs: + self._loss_scale_manager = kwargs['loss_scale_manager'] + self._loss_scale_manager_set = True + + def _check_kwargs(self, kwargs): + for arg in kwargs: + if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: + raise ValueError(f"Unsupport arg '{arg}'") + + def _build_train_network(self): + """Build train network""" + network = self._network + if self._optimizer: + if self._loss_scale_manager_set: + network = amp.build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + loss_scale_manager=self._loss_scale_manager, + keep_batchnorm_fp32=self._keep_bn_fp32) + else: + network = amp.build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + keep_batchnorm_fp32=self._keep_bn_fp32) + elif self._loss_fn: + network = nn.WithLossCell(network, self._loss_fn) + # If need to check if loss_fn is not None, but optimizer is None + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() + return network + + def _build_eval_network(self, metrics, eval_network, eval_indexes): + """Build the network for evaluation.""" + self._metric_fns = get_metrics(metrics) + if not self._metric_fns: + return + + if eval_network is not None: + if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3): + raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \ + must be three. But got {}".format(eval_indexes)) + + self._eval_network = eval_network + self._eval_indexes = eval_indexes + else: + if self._loss_fn is None: + raise ValueError("loss_fn can not be None.") + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") + self._eval_indexes = [0, 1, 2] + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + if self._optimizer: + self._eval_network = _VirtualDatasetCell(self._eval_network) + self._eval_network.set_auto_parallel() + + def _build_predict_network(self): + """Build the network for prediction.""" + self._predict_network = self._network + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._predict_network = _VirtualDatasetCell(self._network) + self._predict_network.set_auto_parallel() + + def _clear_metrics(self): + """Clear metrics local values.""" + for metric in self._metric_fns.values(): + metric.clear() + + def _update_metrics(self, outputs): + """Update metrics local values.""" + if not isinstance(outputs, tuple): + raise ValueError("The `outputs` is not tuple.") + + if self._eval_indexes is not None and len(outputs) < 3: + raise ValueError("The length of `outputs` must be greater than or equal to 3, \ + but got {}".format(len(outputs))) + + for metric in self._metric_fns.values(): + if self._eval_indexes is None: + metric.update(*outputs) + else: + if isinstance(metric, Loss): + metric.update(outputs[self._eval_indexes[0]]) + else: + metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]]) + + def _get_metrics(self): + """Get metrics local values.""" + metrics = dict() + for key, value in self._metric_fns.items(): + metrics[key] = value.eval() + return metrics + + def _get_scaling_sens(self): + """get the scaling sens""" + scaling_sens = 1 + if self._loss_scale_manager is not None: + scaling_sens = self._loss_scale_manager.get_loss_scale() + if self._parallel_mode == ParallelMode.DATA_PARALLEL: + scaling_sens /= self._device_number + return scaling_sens + + def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, + iter_first_order=9): + """Initializes dataset.""" + need_wrap = False + if dataset_sink_mode: + # remove later to deal with loop sink + if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ + and not context.get_context("enable_ge"): + need_wrap = True + + if not is_train: + dataset.__loop_size__ = 1 + + dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order) + + # remove later to deal with loop sink + if need_wrap: + network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__) + network.set_train(is_train) + network.phase = phase + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() + + return dataset_helper, network + + def init(self, train_dataset=None, valid_dataset=None): + """ + Initializes compute graphs and data graphs with sink mode. + + Note: + Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. + + Args: + train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be + initialized. Default: None. + valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will + be initialized, and `metrics` in `Model` can not be None. Default: None. + + Examples: + >>> train_dataset = get_train_dataset() + >>> valid_dataset = get_valid_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) + >>> model.init(train_dataset, valid_dataset) + >>> model.train(2, train_dataset) + >>> model.eval(valid_dataset) + """ + if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": + raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') + + if not train_dataset and not valid_dataset: + raise ValueError('Both train_dataset and valid_dataset can not be None or empty.') + + _device_number_check(self._parallel_mode, self._device_number) + + if train_dataset: + _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) + self._train_network.set_train() + self._train_network.phase = 'train' + + if self._parameter_broadcast: + self._train_network.set_broadcast_flag() + train_dataset.__no_send__ = True + train_dataset_helper, train_network = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=True) + self._train_network = train_network + for inputs in train_dataset_helper: + self._train_network.compile(*inputs) + break + + if valid_dataset: + if not self._metric_fns: + raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.') + + self._eval_network.set_train(False) + self._eval_network.phase = 'eval' + valid_dataset.__no_send__ = True + valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=True) + self._eval_network = eval_network + for inputs in valid_dataset_helper: + self._eval_network.compile(*inputs) + break + + def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1): + """ + Training. + + Args: + epoch (int): Total number of iterations on the data. + train_dataset (Dataset): A training dataset iterator. If there is no + loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be + returned and passed to the network. Otherwise, a tuple (data, label) will + be returned, and the data and label are passed to the network and loss + function respectively. + callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None. + dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + Configure pynative mode, the training process will be performed with + dataset not sink. + sink_size (int): Control the amount of data each sink. Default: -1. + """ + epoch = check_int_positive(epoch) + self._train_network.set_train() + + if self._parameter_broadcast: + self._train_network.set_broadcast_flag() + + cb_params = _InternalCallbackParam() + cb_params.train_network = self._train_network + cb_params.epoch_num = epoch + if dataset_sink_mode and sink_size > 0: + cb_params.batch_num = sink_size + else: + cb_params.batch_num = train_dataset.get_dataset_size() + cb_params.mode = "train" + cb_params.loss_fn = self._loss_fn + cb_params.optimizer = self._optimizer + cb_params.parallel_mode = self._parallel_mode + cb_params.device_number = self._device_number + cb_params.train_dataset = train_dataset + cb_params.list_callback = self._transform_callbacks(callbacks) + cb_params.train_dataset_element = None + cb_params.network = self._network + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + epoch = 1 + + # build callback list + with _CallbackManager(callbacks) as list_callback: + if not dataset_sink_mode: + self._train_process(epoch, train_dataset, list_callback, cb_params) + elif context.get_context("mode") == context.PYNATIVE_MODE: + logger.warning("The pynative mode cannot support dataset sink mode currently." + "So the training process will be performed with dataset not sink.") + self._train_process(epoch, train_dataset, list_callback, cb_params) + else: + self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) + + @staticmethod + def _transform_callbacks(callbacks): + """Transform callback to a list.""" + if callbacks is None: + return [] + + if isinstance(callbacks, Iterable): + return list(callbacks) + + return [callbacks] + + def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): + """ + Training process. The data would be passed to network through dataset channel. + + Args: + epoch (int): Total number of iterations on the data. + train_dataset (Dataset): A training dataset iterator. If there is no + loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + returned and passed to the network. Otherwise, a tuple (data, label) should + be returned, and the data and label are passed to the network and loss + function respectively. + list_callback (Callback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + sink_size (int): Control the amount of data each sink. Default: -1. + """ + if sink_size == -1: + epoch_num = epoch + else: + epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) + + iter_first_order = self._frequency - 1 + iter_second_order = 1 + train_dataset.__loop_size__ = iter_second_order + dataset_helper, train_network = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=True, + sink_size=sink_size, + epoch_num=epoch_num, + iter_first_order=iter_first_order) + self._train_network = train_network + cb_params.train_network = self._train_network + cb_params.cur_step_num = 0 + + run_context = RunContext(cb_params) + list_callback.begin(run_context) + + # used to stop training for early stop, such as stopAtTIme or stopATStep + should_stop = False + has_do_dataset_init = False + switch_branch_one = True + train_network_init_flag = True + for i in range(epoch): + cb_params.cur_epoch_num = i + 1 + list_callback.epoch_begin(run_context) + + # for data sink dataset_helper only iter once, other wise iter epoch_size times. + for inputs in dataset_helper: + if _need_to_full(): + inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) + list_callback.step_begin(run_context) + if switch_branch_one: + cb_params.cur_step_num += dataset_helper.sink_size() + if train_network_init_flag: + self._train_network.add_flags_recursive(thor=True) + self._train_network.phase = 'train0' + else: + cb_params.cur_step_num += iter_first_order + if train_network_init_flag: + self._train_network.add_flags_recursive(thor=False) + train_network_init_flag = False + self._train_network.phase = 'train1' + if not has_do_dataset_init: + _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') + has_do_dataset_init = True + switch_branch_one = not switch_branch_one + outputs = self._train_network(*inputs) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + + list_callback.epoch_end(run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + dataset_helper.stop_send() + + list_callback.end(run_context) + + def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None): + """ + Training process. The data would be passed to network directly. + + Args: + epoch (int): Total number of iterations on the data. + train_dataset (Dataset): A training dataset iterator. If there is no + loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + returned and passed to the network. Otherwise, a tuple (data, label) should + be returned, and the data and label are passed to the network and loss + function respectively. + list_callback (Callback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + """ + dataset_helper, _ = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=False) + cb_params.cur_step_num = 0 + run_context = RunContext(cb_params) + list_callback.begin(run_context) + # used to stop training for early stop, such as stopAtTIme or stopATStep + should_stop = False + + for i in range(epoch): + cb_params.cur_epoch_num = i + 1 + + list_callback.epoch_begin(run_context) + + for next_element in dataset_helper: + len_element = len(next_element) + if self._loss_fn and len_element != 2: + raise ValueError("when loss_fn is not None, train_dataset should" + "return two elements, but got {}".format(len_element)) + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + + overflow = False + if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): + scaling_sens = self._get_scaling_sens() + next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) + + cb_params.train_dataset_element = next_element + outputs = self._train_network(*next_element) + cb_params.net_outputs = outputs + if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): + _, overflow, _ = outputs + overflow = np.all(overflow.asnumpy()) + self._loss_scale_manager.update_loss_scale(overflow) + + list_callback.step_end(run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + train_dataset.reset() + + list_callback.epoch_end(run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + list_callback.end(run_context) + + def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1): + """ + Training API where the iteration is controlled by python front-end. + + When setting pynative mode, the training process will be performed with dataset not sink. + + Note: + CPU is not supported when dataset_sink_mode is true. + If dataset_sink_mode is True, epoch of training should be equal to the count of repeat + operation in dataset processing. Otherwise, errors could occur since the amount of data + is not the amount training requires. + If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features + of data will be transferred one by one. The limitation of data transmission per time is 256M. + + Args: + epoch (int): Total number of iterations on the data. + train_dataset (Dataset): A training dataset iterator. If there is no + loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + returned and passed to the network. Otherwise, a tuple (data, label) should + be returned, and the data and label are passed to the network and loss + function respectively. + callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None. + dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + Configure pynative mode, the training process will be performed with + dataset not sink. + sink_size (int): Control the amount of data each sink. + If sink_size=-1, sink the complete dataset each epoch. + If sink_size>0, sink sink_size data each epoch. + If dataset_sink_mode is False, set sink_size invalid. Default: -1. + + Examples: + >>> dataset = get_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> loss_scale_manager = FixedLossScaleManager() + >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) + >>> model.train(2, dataset) + """ + check_bool(dataset_sink_mode) + check_int(sink_size) + if sink_size < -1 or sink_size == 0: + raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) + + _device_number_check(self._parallel_mode, self._device_number) + _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) + + self._train(epoch, + train_dataset, + callbacks=callbacks, + dataset_sink_mode=dataset_sink_mode, + sink_size=sink_size) + + def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None): + """ + Evaluation. The data would be passed to network through dataset channel. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + list_callback (Callback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + + Returns: + Dict, returns the loss value & metrics values for the model in test mode. + """ + run_context = RunContext(cb_params) + + dataset_helper, eval_network = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=True) + self._eval_network = eval_network + cb_params.eval_network = self._eval_network + list_callback.begin(run_context) + + for inputs in dataset_helper: + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + + outputs = self._eval_network(*inputs) + + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + self._update_metrics(outputs) + + metrics = self._get_metrics() + cb_params.metrics = metrics + list_callback.end(run_context) + + return metrics + + def _eval_process(self, valid_dataset, list_callback=None, cb_params=None): + """ + Evaluation. The data would be passed to network directly. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + list_callback (Callback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + + Returns: + Dict, returns the loss value & metrics values for the model in test mode. + """ + run_context = RunContext(cb_params) + list_callback.begin(run_context) + + dataset_helper, _ = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=False) + for next_element in dataset_helper: + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + outputs = self._eval_network(*next_element) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + self._update_metrics(outputs) + + valid_dataset.reset() + + metrics = self._get_metrics() + cb_params.metrics = metrics + list_callback.end(run_context) + return metrics + + def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True): + """ + Evaluation API where the iteration is controlled by python front-end. + + Configure to pynative mode, the evaluation will be performed with dataset non-sink mode. + + Note: + CPU is not supported when dataset_sink_mode is true. + If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features + of data will be transferred one by one. The limitation of data transmission per time is 256M. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + callbacks (list): List of callback object. Callbacks which should be excuted + while training. Default: None. + dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + + Returns: + Dict, returns the loss value & metrics values for the model in test mode. + + Examples: + >>> dataset = get_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) + >>> model.eval(dataset) + """ + check_bool(dataset_sink_mode) + _device_number_check(self._parallel_mode, self._device_number) + if not self._metric_fns: + raise ValueError("metric fn can not be None or empty.") + + cb_params = _InternalCallbackParam() + cb_params.eval_network = self._eval_network + cb_params.valid_dataset = valid_dataset + cb_params.batch_num = valid_dataset.get_dataset_size() + cb_params.mode = "eval" + cb_params.cur_step_num = 0 + cb_params.list_callback = self._transform_callbacks(callbacks) + cb_params.network = self._network + + self._eval_network.set_train(mode=False) + self._eval_network.phase = 'eval' + + self._clear_metrics() + + with _CallbackManager(callbacks) as list_callback: + if dataset_sink_mode: + return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) + return self._eval_process(valid_dataset, list_callback, cb_params) + + def predict(self, *predict_data): + """ + Generates output predictions for the input samples. + + Data could be single tensor, or list of tensor, tuple of tensor. + + Note: + Batch data should be put together in one tensor. + + Args: + predict_data (Tensor): Tensor of predict data. can be array, list or tuple. + + Returns: + Tensor, array(s) of predictions. + + Examples: + >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) + >>> model = Model(Net()) + >>> model.predict(input_data) + """ + self._predict_network.set_train(False) + check_input_data(*predict_data, data_class=Tensor) + result = self._predict_network(*predict_data) + + check_output_data(result) + return result + + +__all__ = ["Model"] diff --git a/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py b/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py new file mode 100644 index 0000000000..b9e9c46ab4 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py @@ -0,0 +1,422 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""momentum""" +import mindspore.common.dtype as mstype +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.common.parameter import ParameterTuple +from mindspore.common.tensor import Tensor +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.ops import functional as F, composite as C, operations as P + +momentum_opt = C.MultitypeFuncGraph("momentum_opt") + + +@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): + """Apply momentum optimizer to the weight parameter using Tensor.""" + success = True + success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) + return success + + +op_add = P.AddN() +apply_decay = C.MultitypeFuncGraph("apply_decay") + + +@apply_decay.register("Number", "Bool", "Tensor", "Tensor") +def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): + """Get grad with weight_decay.""" + if if_apply: + return op_add((weight * weight_decay, gradient)) + return gradient + + +class THOR(Optimizer): + """THOR""" + + def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, + loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10, + decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): + super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) + if isinstance(momentum, float) and momentum < 0.0: + raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") + self.params = self.parameters + self.moments = self.params.clone(prefix="moments", init='zeros') + self.hyper_map = C.HyperMap() + self.opt = P.ApplyMomentum() + self.matrix_A = ParameterTuple(matrix_A) + self.matrix_G = ParameterTuple(matrix_G) + self.A_inv_max = ParameterTuple(A_inv_max) + self.G_inv_max = ParameterTuple(G_inv_max) + self.matmul = P.MatMul() + self.transpose = P.Transpose() + self.shape = P.Shape() + self.reshape = P.Reshape() + self.mul = P.Mul() + self.gather = P.GatherV2() + self.matrix_A_inv = () + self.matrix_G_inv = () + self.matrix_max_inv = () + self.num_hidden_layers = num_hidden_layers + fc_layer_num = num_hidden_layers * 6 + 5 + for i in range(fc_layer_num): + self.matrix_max_inv = self.matrix_max_inv + ( + Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) + self.log = P.Log() + self.exp = P.Exp() + self.sqrt = P.Sqrt() + self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) + self.assign = P.Assign() + self.cast = P.Cast() + self.thor = True + self.weight_decay = weight_decay * loss_scale + self.decay_flags = tuple(decay_filter(x) for x in self.parameters) + self.expand = P.ExpandDims() + self.square = P.Square() + self.inv = P.Inv() + self.batch_size = batch_size + self.damping = damping + self.freq = Tensor(frequency, mstype.int32) + self.one = Tensor(1, mstype.int32) + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + + def construct(self, gradients): + """construct of THOR""" + params = self.params + moments = self.moments + encoder_layers_num = 16 + if self.thor: + new_grads = () + # process embedding layer + for em_idx in range(3): + g = gradients[em_idx] + matrix_idx = em_idx + temp_a_ori = self.matrix_A[matrix_idx] + temp_a = self.expand(temp_a_ori, 1) + temp_g = self.matrix_G[matrix_idx] + G_max = self.G_inv_max[matrix_idx] + temp_g = self.cast(temp_g, mstype.float32) + matrix_G_inv_max = self.log(G_max) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + g = self.mul(temp_a, g) + g = self.cast(g, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.matmul(g, temp_g) + g = self.cast(g, mstype.float32) + g = self.mul(g, G_max) + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (g,) + # process bert_embedding_postprocessor.layernorm + grad_idx = 3 + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.one + damping = self.sqrt(damping_step) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + + for i in range(self.num_hidden_layers): + encoder_begin_idx = encoder_layers_num * i + 5 + for j in range(0, encoder_layers_num, 2): + grad_idx = encoder_begin_idx + j + if j in (8, 14): + # process layernorm layer + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + else: + g = gradients[grad_idx] + offset_idx = 0 + if j in (0, 2, 4, 6): + offset_idx = j // 2 + elif j in (10, 12): + offset_idx = j // 2 - 1 + matrix_idx = 6 * i + offset_idx + 3 + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + temp_a = self.cast(temp_a, mstype.float32) + temp_g = self.cast(temp_g, mstype.float32) + matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) + matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) + matrix_A_inv_max = self.exp(matrix_A_inv_max) + temp_a = self.mul(temp_a, matrix_A_inv_max) + matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, temp_max) + + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (g,) + new_grads = new_grads + (gradients[grad_idx + 1],) + + # process pooler layer + pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5 + matrix_idx = self.num_hidden_layers * 6 + 3 + g = gradients[pooler_layer_idx] + pooler_bias = gradients[pooler_layer_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + temp_a = self.cast(temp_a, mstype.float32) + temp_g = self.cast(temp_g, mstype.float32) + matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) + matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) + matrix_A_inv_max = self.exp(matrix_A_inv_max) + temp_a = self.mul(temp_a, matrix_A_inv_max) + matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, temp_max) + + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (g, pooler_bias) + + # for cls1 fc layer: mlm + mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8 + matrix_idx = self.num_hidden_layers * 6 + 4 + g = gradients[mlm_fc_idx] + mlm_bias = gradients[mlm_fc_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + temp_a = self.cast(temp_a, mstype.float32) + temp_g = self.cast(temp_g, mstype.float32) + matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) + matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) + matrix_A_inv_max = self.exp(matrix_A_inv_max) + temp_a = self.mul(temp_a, matrix_A_inv_max) + matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, temp_max) + + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (gradients[mlm_fc_idx - 1],) + new_grads = new_grads + (g, mlm_bias) + # add bert.cls1.layernorm grad + begin_idx = mlm_fc_idx + 2 + end_idx = mlm_fc_idx + 4 + new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) + new_grads = new_grads + gradients[lenth - 2: lenth] + gradients = new_grads + else: + new_grads = () + # process embedding layer + for em_idx in range(3): + g = gradients[em_idx] + matrix_idx = em_idx + temp_a = self.matrix_A[matrix_idx] + temp_a = self.expand(temp_a, 1) + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + g = self.mul(temp_a, g) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + g = self.matmul(g, temp_g) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + new_grads = new_grads + (g,) + # process bert_embedding_postprocessor.layernorm + grad_idx = 3 + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.one + damping = self.sqrt(damping_step) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + + for i in range(self.num_hidden_layers): + encoder_begin_idx = encoder_layers_num * i + 5 + for j in range(0, encoder_layers_num, 2): + grad_idx = encoder_begin_idx + j + if j in (8, 14): + # process layernorm layer + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + else: + g = gradients[grad_idx] + offset_idx = 0 + if j in (0, 2, 4, 6): + offset_idx = j // 2 + elif j in (10, 12): + offset_idx = j // 2 - 1 + matrix_idx = 6 * i + offset_idx + 3 + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + new_grads = new_grads + (g,) + new_grads = new_grads + (gradients[grad_idx + 1],) + + # process pooler layer + pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5 + matrix_idx = self.num_hidden_layers * 6 + 3 + g = gradients[pooler_layer_idx] + pooler_bias = gradients[pooler_layer_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + new_grads = new_grads + (g, pooler_bias) + + # for cls1 fc layer: mlm + mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8 + matrix_idx = self.num_hidden_layers * 6 + 4 + g = gradients[mlm_fc_idx] + mlm_bias = gradients[mlm_fc_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + # add bert.cls1.output_bias grad + new_grads = new_grads + (gradients[mlm_fc_idx - 1],) + new_grads = new_grads + (g, mlm_bias) + # add bert.cls1.layernorm grad + begin_idx = mlm_fc_idx + 2 + end_idx = mlm_fc_idx + 4 + new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) + new_grads = new_grads + gradients[lenth - 2: lenth] + gradients = new_grads + + if self.weight_decay > 0: + gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, + params, gradients) + gradients = self.scale_grad(gradients) + lr = self.get_lr() + success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments) + return success diff --git a/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py b/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py new file mode 100644 index 0000000000..8cb56d1f70 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py @@ -0,0 +1,429 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""momentum""" +import mindspore.common.dtype as mstype +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.common.parameter import ParameterTuple +from mindspore.common.tensor import Tensor +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.parallel._utils import _get_device_num, _get_mirror_mean +from .grad_reducer_thor1 import DistributedGradReducerThor1 + +momentum_opt = C.MultitypeFuncGraph("momentum_opt") + + +@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): + """Apply momentum optimizer to the weight parameter using Tensor.""" + success = True + success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) + return success + + +op_add = P.AddN() +apply_decay = C.MultitypeFuncGraph("apply_decay") + + +@apply_decay.register("Number", "Bool", "Tensor", "Tensor") +def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): + """Get grad with weight_decay.""" + if if_apply: + return op_add((weight * weight_decay, gradient)) + return gradient + + +class THOR(Optimizer): + """THOR""" + + def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, + loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10, + decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): + super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) + if isinstance(momentum, float) and momentum < 0.0: + raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") + self.params = self.parameters + self.moments = self.params.clone(prefix="moments", init='zeros') + self.hyper_map = C.HyperMap() + self.opt = P.ApplyMomentum() + self.matrix_A = ParameterTuple(matrix_A) + self.matrix_G = ParameterTuple(matrix_G) + self.A_inv_max = ParameterTuple(A_inv_max) + self.G_inv_max = ParameterTuple(G_inv_max) + self.matmul = P.MatMul() + self.transpose = P.Transpose() + self.shape = P.Shape() + self.reshape = P.Reshape() + self.mul = P.Mul() + self.gather = P.GatherV2() + self.matrix_A_inv = () + self.matrix_G_inv = () + self.matrix_max_inv = () + self.num_hidden_layers = num_hidden_layers + fc_layer_num = num_hidden_layers * 6 + 5 + for i in range(fc_layer_num): + self.matrix_max_inv = self.matrix_max_inv + ( + Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) + self.log = P.Log() + self.exp = P.Exp() + self.sqrt = P.Sqrt() + self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) + self.assign = P.Assign() + self.cast = P.Cast() + self.thor = True + self.weight_decay = weight_decay * loss_scale + self.decay_flags = tuple(decay_filter(x) for x in self.parameters) + self.expand = P.ExpandDims() + self.square = P.Square() + self.inv = P.Inv() + self.batch_size = batch_size + self.damping = damping + self.freq = Tensor(frequency, mstype.int32) + self.one = Tensor(1, mstype.int32) + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer_g = DistributedGradReducerThor1(self.parameters, 3, mean, degree) + + def construct(self, gradients): + """construct of THOR""" + params = self.params + moments = self.moments + encoder_layers_num = 16 + if self.thor: + new_grads = () + # process embedding layer + for em_idx in range(3): + g = gradients[em_idx] + matrix_idx = em_idx + temp_a_ori = self.matrix_A[matrix_idx] + temp_a = self.expand(temp_a_ori, 1) + temp_g = self.matrix_G[matrix_idx] + G_max = self.G_inv_max[matrix_idx] + temp_g = self.cast(temp_g, mstype.float32) + matrix_G_inv_max = self.log(G_max) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + g = self.mul(temp_a, g) + g = self.cast(g, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.matmul(g, temp_g) + g = self.cast(g, mstype.float32) + g = self.mul(g, G_max) + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (g,) + # process bert_embedding_postprocessor.layernorm + grad_idx = 3 + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.one + damping = self.sqrt(damping_step) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + + for i in range(self.num_hidden_layers): + encoder_begin_idx = encoder_layers_num * i + 5 + for j in range(0, encoder_layers_num, 2): + grad_idx = encoder_begin_idx + j + if j in (8, 14): + # process layernorm layer + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + else: + g = gradients[grad_idx] + offset_idx = 0 + if j in (0, 2, 4, 6): + offset_idx = j // 2 + elif j in (10, 12): + offset_idx = j // 2 - 1 + matrix_idx = 6 * i + offset_idx + 3 + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + temp_a = self.cast(temp_a, mstype.float32) + temp_g = self.cast(temp_g, mstype.float32) + matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) + matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) + matrix_A_inv_max = self.exp(matrix_A_inv_max) + temp_a = self.mul(temp_a, matrix_A_inv_max) + matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, temp_max) + + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (g,) + new_grads = new_grads + (gradients[grad_idx + 1],) + + # process pooler layer + pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5 + matrix_idx = self.num_hidden_layers * 6 + 3 + g = gradients[pooler_layer_idx] + pooler_bias = gradients[pooler_layer_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + temp_a = self.cast(temp_a, mstype.float32) + temp_g = self.cast(temp_g, mstype.float32) + matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) + matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) + matrix_A_inv_max = self.exp(matrix_A_inv_max) + temp_a = self.mul(temp_a, matrix_A_inv_max) + matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, temp_max) + + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (g, pooler_bias) + + # for cls1 fc layer: mlm + mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8 + matrix_idx = self.num_hidden_layers * 6 + 4 + g = gradients[mlm_fc_idx] + mlm_bias = gradients[mlm_fc_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + temp_a = self.cast(temp_a, mstype.float32) + temp_g = self.cast(temp_g, mstype.float32) + matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) + matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) + matrix_A_inv_max = self.exp(matrix_A_inv_max) + temp_a = self.mul(temp_a, matrix_A_inv_max) + matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) + matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) + matrix_G_inv_max = self.exp(matrix_G_inv_max) + temp_g = self.mul(temp_g, matrix_G_inv_max) + temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, temp_max) + + fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) + fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) + fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + g = F.depend(g, fake_max) + new_grads = new_grads + (gradients[mlm_fc_idx - 1],) + new_grads = new_grads + (g, mlm_bias) + # add bert.cls1.layernorm grad + begin_idx = mlm_fc_idx + 2 + end_idx = mlm_fc_idx + 4 + new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) + new_grads = new_grads + gradients[lenth - 2: lenth] + gradients = new_grads + gradients = self.grad_reducer_g(gradients) + else: + new_grads = () + # process embedding layer + for em_idx in range(3): + g = gradients[em_idx] + matrix_idx = em_idx + temp_a = self.matrix_A[matrix_idx] + temp_a = self.expand(temp_a, 1) + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + g = self.mul(temp_a, g) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + g = self.matmul(g, temp_g) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + new_grads = new_grads + (g,) + # process bert_embedding_postprocessor.layernorm + grad_idx = 3 + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.one + damping = self.sqrt(damping_step) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + + for i in range(self.num_hidden_layers): + encoder_begin_idx = encoder_layers_num * i + 5 + for j in range(0, encoder_layers_num, 2): + grad_idx = encoder_begin_idx + j + if j in (8, 14): + # process layernorm layer + beta_grad = gradients[grad_idx] + gamma_grad = gradients[grad_idx + 1] + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + beta = self.square(beta_grad) + beta_cov = self.mul(beta, 1.0 / normalizer) + beta_cov = beta_cov + damping + beta_inv = self.inv(beta_cov) + gamma = self.square(gamma_grad) + gamma_cov = self.mul(gamma, 1.0 / normalizer) + gamma_cov = gamma_cov + damping + gamma_inv = self.inv(gamma_cov) + beta = self.mul(beta_inv, beta_grad) + gamma = self.mul(gamma_inv, gamma_grad) + new_grads = new_grads + (beta, gamma) + else: + g = gradients[grad_idx] + offset_idx = 0 + if j in (0, 2, 4, 6): + offset_idx = j // 2 + elif j in (10, 12): + offset_idx = j // 2 - 1 + matrix_idx = 6 * i + offset_idx + 3 + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + new_grads = new_grads + (g,) + new_grads = new_grads + (gradients[grad_idx + 1],) + + # process pooler layer + pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5 + matrix_idx = self.num_hidden_layers * 6 + 3 + g = gradients[pooler_layer_idx] + pooler_bias = gradients[pooler_layer_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + new_grads = new_grads + (g, pooler_bias) + + # for cls1 fc layer: mlm + mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8 + matrix_idx = self.num_hidden_layers * 6 + 4 + g = gradients[mlm_fc_idx] + mlm_bias = gradients[mlm_fc_idx + 1] + temp_a = self.matrix_A[matrix_idx] + temp_g = self.matrix_G[matrix_idx] + matrix_max = self.matrix_max_inv[matrix_idx] + temp_a = self.cast(temp_a, mstype.float16) + temp_g = self.cast(temp_g, mstype.float16) + g = self.cast(g, mstype.float16) + + g = self.matmul(temp_g, g) + g = self.matmul(g, temp_a) + g = self.cast(g, mstype.float32) + g = self.mul(g, matrix_max) + # add bert.cls1.output_bias grad + new_grads = new_grads + (gradients[mlm_fc_idx - 1],) + new_grads = new_grads + (g, mlm_bias) + # add bert.cls1.layernorm grad + begin_idx = mlm_fc_idx + 2 + end_idx = mlm_fc_idx + 4 + new_grads = new_grads + gradients[begin_idx: end_idx] + lenth = len(gradients) + new_grads = new_grads + gradients[lenth - 2: lenth] + gradients = new_grads + gradients = self.grad_reducer_g(gradients) + + if self.weight_decay > 0: + gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, + params, gradients) + gradients = self.scale_grad(gradients) + lr = self.get_lr() + success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments) + return success diff --git a/model_zoo/official/nlp/bert_thor/src/thor_layer.py b/model_zoo/official/nlp/bert_thor/src/thor_layer.py new file mode 100644 index 0000000000..8f9e0c0759 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/thor_layer.py @@ -0,0 +1,304 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""thor_layer""" +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore._checkparam import check_bool, check_int_positive +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.nn.cell import Cell +from mindspore.nn.layer.activation import get_activation +from mindspore.ops import operations as P + + +class Embedding_Thor(Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02, + name='embedding_table', + is_expand=False, + batch_size=12, + damping=0.03, + loss_scale=1, + frequency=10, + ): + super(Embedding_Thor, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name=name) + self.thor = True + self.is_expand = is_expand + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.em_shape = tuple(embedding_shape) + self.shape = P.Shape() + self.loss_scale = Tensor(1 / loss_scale, mstype.float16) + self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), name='matrix_A_inv', + requires_grad=False) + self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)), + name="matrix_G_inv", requires_grad=False) + self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) + self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) + self.fused_abs_max = P.CusFusedAbsMax1() + self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)) + self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32)) + self.dampingG = Tensor(np.identity(embedding_size), mstype.float32) + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + self.freq = Tensor(frequency, mstype.int32) + self.axis = 0 + self.damping = damping + self.gather = P.GatherV2() + self.sqrt = P.Sqrt() + self.mul = P.Mul() + self.cast = P.Cast() + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.vector_matmul = P.CusBatchMatMul() + self.cholesky = P.CusCholeskyTrsm() + self.matrix_combine = P.CusMatrixCombine() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.inv = P.Inv() + self.getG = P.InsertGradientOf(self.save_gradient) + self.batch_size = batch_size + + def save_gradient(self, dout): + """save_gradient""" + bs = self.batch_size + bs = self.cast(bs, mstype.float32) + out = dout + dout = self.mul(dout, self.loss_scale) + dout = self.mul(dout, bs) + shape = self.shape(dout) + normalizer = self.cast(shape[0], mstype.float32) + matrix_G = self.cube_matmul(dout, dout) + matrix_G = self.mul(matrix_G, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.freq + damping = self.sqrt(damping_step) + dampingG = self.cast(self.dampingG, mstype.float32) + matrix_G = matrix_G + damping * dampingG + matrix_G_inv = self.cholesky(matrix_G) + matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) + matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) + matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) + self.G_inv_max = matrix_G_inv_max + matrix_G_inv = self.matrix_combine(matrix_G_inv) + matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) + self.matrix_G_inv = matrix_G_inv + return out + + def construct(self, input_ids): + """construct of Embedding_Thor""" + if self.is_expand: + input_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(input_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) + else: + if self.thor: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + matrix_A = self.reduce_sum(one_hot_ids, 0) + normalizer = self.batch_size + normalizer = self.cast(normalizer, mstype.float32) + matrix_A = self.mul(matrix_A, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, self.axis) + damping_step = self.cast(damping_step, mstype.float32) + damping = self.sqrt(damping_step) + dampingA = self.cast(self.dampingA, mstype.float32) + matrix_A = matrix_A + damping * dampingA + matrix_A_inv = self.inv(matrix_A) + self.matrix_A_inv = matrix_A_inv + self.matrix_G_inv = self.fake_G + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output_for_reshape = self.getG(output_for_reshape) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + + output = self.reshape(output_for_reshape, self.em_shape) + return output, self.embedding_table + + +class Dense_Thor(Cell): + """Dense_Thor""" + + # @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + damping=0.03, + loss_scale=1, + frequency=10, + has_bias=False, + activation=None, + batch_size=12): + super(Dense_Thor, self).__init__() + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + self.has_bias = check_bool(has_bias) + self.thor = True + if isinstance(weight_init, Tensor): + if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ + weight_init.shape()[1] != in_channels: + raise ValueError("weight_init shape error") + + self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") + + if self.has_bias: + if isinstance(bias_init, Tensor): + if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: + raise ValueError("bias_init shape error") + + self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") + + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + + self.activation = get_activation(activation) + self.activation_flag = self.activation is not None + self.matrix_A_inv = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float16)), + name='matrix_A_inv', requires_grad=False) + self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)), + name="matrix_G_inv", requires_grad=False) + self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) + self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) + self.fused_abs_max = P.CusFusedAbsMax1() + self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)) + + self.matmul = P.MatMul(transpose_b=True) + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.matrix_combine = P.CusMatrixCombine() + self.cholesky = P.CusCholeskyTrsm() + self.shape = P.Shape() + self.reshape = P.Reshape() + self.transpose = P.Transpose() + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + self.mul = P.Mul() + self.cast = P.Cast() + self.damping = damping + self.loss_scale = Tensor(1 / loss_scale, mstype.float16) + self.vector_matmul = P.CusBatchMatMul() + self.gather = P.GatherV2() + self.assignadd = P.AssignAdd() + self.freq = Tensor(frequency, mstype.int32) + self.axis = 0 + self.abs = P.Abs() + self.reduce_max = P.ReduceMax(keep_dims=False) + self.log = P.Log() + self.exp = P.Exp() + self.dampingA = Tensor(np.identity(in_channels), mstype.float32) + self.dampingG = Tensor(np.identity(out_channels), mstype.float32) + self.sqrt = P.Sqrt() + self.getG = P.InsertGradientOf(self.save_gradient) + self.batch_size = batch_size + + def save_gradient(self, dout): + """save_gradient""" + bs = self.cast(self.batch_size, mstype.float32) + out = dout + dout = self.mul(dout, self.loss_scale) + dout = self.mul(dout, bs) + shape = self.shape(dout) + normalizer = self.cast(shape[0], mstype.float32) + matrix_G = self.cube_matmul(dout, dout) + matrix_G = self.mul(matrix_G, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.freq + damping = self.sqrt(damping_step) + dampingG = self.cast(self.dampingG, mstype.float32) + matrix_G = matrix_G + damping * dampingG + matrix_G_inv = self.cholesky(matrix_G) + matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) + matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) + matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) + self.G_inv_max = matrix_G_inv_max + matrix_G_inv = self.matrix_combine(matrix_G_inv) + matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) + self.matrix_G_inv = matrix_G_inv + return out + + def construct(self, x): + """construct""" + if self.thor: + inputs = self.cube_matmul(x, x) + shape = self.shape(x) + normalizer = self.cast(shape[0], mstype.float32) + matrix_A = self.mul(inputs, 1.0 / normalizer) + + damping_step = self.gather(self.damping, self.cov_step, self.axis) + damping_step = self.cast(damping_step, mstype.float32) + damping = self.sqrt(damping_step) + dampingA = self.cast(self.dampingA, mstype.float32) + matrix_A = matrix_A + damping * dampingA + matrix_A_inv = self.cholesky(matrix_A) + matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv) + matrix_A_inv_max = self.fused_abs_max(matrix_A_inv) + matrix_A_inv_max = self.fused_abs_max(matrix_A_inv_max) + self.A_inv_max = matrix_A_inv_max + matrix_A_inv = self.matrix_combine(matrix_A_inv) + matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) + self.matrix_A_inv = matrix_A_inv + self.matrix_G_inv = self.fake_G + output = self.matmul(x, self.weight) + output = self.getG(output) + else: + output = self.matmul(x, self.weight) + + if self.has_bias: + output = self.bias_add(output, self.bias) + if self.activation_flag: + return self.activation(output) + return output + + def extend_repr(self): + """extend_repr""" + str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \ + .format(self.in_channels, self.out_channels, self.weight, self.has_bias) + if self.has_bias: + str_info = str_info + ', bias={}'.format(self.bias) + + if self.activation_flag: + str_info = str_info + ', activation={}'.format(self.activation) + + return str_info diff --git a/model_zoo/official/nlp/bert_thor/src/utils.py b/model_zoo/official/nlp/bert_thor/src/utils.py new file mode 100644 index 0000000000..9366f71246 --- /dev/null +++ b/model_zoo/official/nlp/bert_thor/src/utils.py @@ -0,0 +1,169 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Functional Cells used in Bert finetune and evaluation. +""" + +import os +import time + +import numpy as np +from src.config import cfg + +import mindspore.nn as nn +from mindspore.common import dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR +from mindspore.ops import operations as P +from mindspore.train.callback import Callback + + +class CrossEntropyCalculation(nn.Cell): + """ + Cross Entropy loss + """ + + def __init__(self, is_training=True): + super(CrossEntropyCalculation, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + self.is_training = is_training + + def construct(self, logits, label_ids, num_labels): + if self.is_training: + label_ids = self.reshape(label_ids, self.last_idx) + one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) + loss = self.reduce_mean(per_example_loss, self.last_idx) + return_value = self.cast(loss, mstype.float32) + else: + return_value = logits * 1.0 + return return_value + + +def make_directory(path: str): + """Make directory.""" + if path is None or not isinstance(path, str) or path.strip() == "": + logger.error("The path(%r) is invalid type.", path) + raise TypeError("Input path is invaild type") + + # convert the relative paths + path = os.path.realpath(path) + logger.debug("The abs path is %r", path) + + # check the path is exist and write permissions? + if os.path.exists(path): + real_path = path + else: + # All exceptions need to be caught because create directory maybe have some limit(permissions) + logger.debug("The directory(%s) doesn't exist, will create it", path) + try: + os.makedirs(path, exist_ok=True) + real_path = path + except PermissionError as e: + logger.error("No write permission on the directory(%r), error = %r", path, e) + raise TypeError("No write permission on the directory.") + return real_path + + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self._per_print_times = per_print_times + self.step_start_time = time.time() + + def step_begin(self, run_context): + self.step_start_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_time_span = time.time() - self.step_start_time + total_time_span = step_time_span + cur_step_num = cb_params.cur_step_num + if cur_step_num % cfg.Thor.frequency == 0: + step_time_span = step_time_span / (cfg.Thor.frequency - 1) + print("epoch: {}, step: {}, outputs are {}, total_time_span is {}, step_time_span is {}".format( + cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs), total_time_span, step_time_span)) + + +def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): + """ + Find the ckpt finetune generated and load it into eval network. + """ + files = os.listdir(load_finetune_checkpoint_dir) + pre_len = len(prefix) + max_num = 0 + for filename in files: + name_ext = os.path.splitext(filename) + if name_ext[-1] != ".ckpt": + continue + # steps_per_epoch = ds.get_dataset_size() + if filename.find(prefix) == 0 and not filename[pre_len].isalpha(): + index = filename[pre_len:].find("-") + if index == 0 and max_num == 0: + load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) + elif index not in (0, -1): + name_split = name_ext[-2].split('_') + if (steps_per_epoch != int(name_split[len(name_split) - 1])) \ + or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])): + continue + num = filename[pre_len + 1:pre_len + index] + if int(num) > max_num: + max_num = int(num) + load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) + return load_finetune_checkpoint_path + + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr diff --git a/model_zoo/lstm/README.md b/model_zoo/official/nlp/lstm/README.md similarity index 100% rename from model_zoo/lstm/README.md rename to model_zoo/official/nlp/lstm/README.md diff --git a/model_zoo/lstm/eval.py b/model_zoo/official/nlp/lstm/eval.py similarity index 100% rename from model_zoo/lstm/eval.py rename to model_zoo/official/nlp/lstm/eval.py diff --git a/model_zoo/deepfm/src/__init__.py b/model_zoo/official/nlp/lstm/src/__init__.py similarity index 100% rename from model_zoo/deepfm/src/__init__.py rename to model_zoo/official/nlp/lstm/src/__init__.py diff --git a/model_zoo/lstm/src/config.py b/model_zoo/official/nlp/lstm/src/config.py similarity index 100% rename from model_zoo/lstm/src/config.py rename to model_zoo/official/nlp/lstm/src/config.py diff --git a/model_zoo/lstm/src/dataset.py b/model_zoo/official/nlp/lstm/src/dataset.py similarity index 100% rename from model_zoo/lstm/src/dataset.py rename to model_zoo/official/nlp/lstm/src/dataset.py diff --git a/model_zoo/lstm/src/imdb.py b/model_zoo/official/nlp/lstm/src/imdb.py similarity index 100% rename from model_zoo/lstm/src/imdb.py rename to model_zoo/official/nlp/lstm/src/imdb.py diff --git a/model_zoo/lstm/src/lstm.py b/model_zoo/official/nlp/lstm/src/lstm.py similarity index 100% rename from model_zoo/lstm/src/lstm.py rename to model_zoo/official/nlp/lstm/src/lstm.py diff --git a/model_zoo/official/nlp/lstm/train.py b/model_zoo/official/nlp/lstm/train.py new file mode 100644 index 0000000000..53c3a89a6a --- /dev/null +++ b/model_zoo/official/nlp/lstm/train.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################train lstm example on aclImdb######################## +python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path +""" +import argparse +import os + +import numpy as np + +from src.config import lstm_cfg as cfg +from src.dataset import convert_to_mindrecord +from src.dataset import lstm_create_dataset +from src.lstm import SentimentNet +from mindspore import Tensor, nn, Model, context +from mindspore.nn import Accuracy +from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_param_into_net, load_checkpoint + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='MindSpore LSTM Example') + parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'], + help='whether to preprocess data.') + parser.add_argument('--aclimdb_path', type=str, default="./aclImdb", + help='path where the dataset is stored.') + parser.add_argument('--glove_path', type=str, default="./glove", + help='path where the GloVe is stored.') + parser.add_argument('--preprocess_path', type=str, default="./preprocess", + help='path where the pre-process data is stored.') + parser.add_argument('--ckpt_path', type=str, default="./", + help='the path to save the checkpoint file.') + parser.add_argument('--pre_trained', type=str, default=None, + help='the pretrained checkpoint file path.') + parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'], + help='the target device to run, support "GPU", "CPU". Default: "GPU".') + args = parser.parse_args() + + context.set_context( + mode=context.GRAPH_MODE, + save_graphs=False, + device_target=args.device_target) + + if args.preprocess == "true": + print("============== Starting Data Pre-processing ==============") + convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path) + + embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32) + network = SentimentNet(vocab_size=embedding_table.shape[0], + embed_size=cfg.embed_size, + num_hiddens=cfg.num_hiddens, + num_layers=cfg.num_layers, + bidirectional=cfg.bidirectional, + num_classes=cfg.num_classes, + weight=Tensor(embedding_table), + batch_size=cfg.batch_size) + # pre_trained + if args.pre_trained: + load_param_into_net(network, load_checkpoint(args.pre_trained)) + + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum) + loss_cb = LossMonitor() + + model = Model(network, loss, opt, {'acc': Accuracy()}) + + print("============== Starting Training ==============") + ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1) + config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + if args.device_target == "CPU": + model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False) + else: + model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb]) + print("============== Training Success ==============") diff --git a/model_zoo/official/nlp/mass/README.md b/model_zoo/official/nlp/mass/README.md new file mode 100644 index 0000000000..cb1a47dc44 --- /dev/null +++ b/model_zoo/official/nlp/mass/README.md @@ -0,0 +1,592 @@ +![](https://www.mindspore.cn/static/img/logo.a3e472c9.png) + + + +- [MASS: Masked Sequence to Sequence Pre-training for Language Generation Description](#googlenet-description) +- [Model architecture](#model-architecture) +- [Dataset](#dataset) +- [Features](#features) +- [Script description](#script-description) + - [Data Preparation](#Data-Preparation) + - [Tokenization](#Tokenization) + - [Byte Pair Encoding](#Byte-Pair-Encoding) + - [Build Vocabulary](#Build-Vocabulary) + - [Generate Dataset](#Generate-Dataset) + - [News Crawl Corpus](#News-Crawl-Corpus) + - [Gigaword Corpus](#Gigaword-Corpus) + - [Cornell Movie Dialog Corpus](#Cornell-Movie-Dialog-Corpus) + - [Configuration](#Configuration) + - [Training & Evaluation process](#Training-&-Evaluation-process) + - [Weights average](#Weights-average) + - [Learning rate scheduler](#Learning-rate-scheduler) +- [Model description](#model-description) + - [Performance](#performance) + - [Results](#results) + - [Training Performance](#training-performance) + - [Inference Performance](#inference-performance) +- [Environment Requirements](#environment-requirements) + - [Platform](#Platform) + - [Requirements](#Requirements) +- [Get started](#get-started) + - [Pre-training](#Pre-training) + - [Fine-tuning](#Fine-tuning) + - [Inference](#Inference) +- [Description of random situation](#description-of-random-situation) +- [others](#others) +- [ModelZoo Homepage](#modelzoo-homepage) + + + + +# MASS: Masked Sequence to Sequence Pre-training for Language Generation Description + +[MASS: Masked Sequence to Sequence Pre-training for Language Generation](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf) was released by MicroSoft in June 2019. + +BERT(Devlin et al., 2018) have achieved SOTA in natural language understanding area by pre-training the encoder part of Transformer(Vaswani et al., 2017) with masked rich-resource text. Likewise, GPT(Raddford et al., 2018) pre-trains the decoder part of Transformer with masked(encoder inputs are masked) rich-resource text. Both of them build a robust language model by pre-training with masked rich-resource text. + +Inspired by BERT, GPT and other language models, MicroSoft addressed [MASS: Masked Sequence to Sequence Pre-training for Language Generation](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf) which combines BERT's and GPT's idea. MASS has an important parameter k, which controls the masked fragment length. BERT and GPT are specicl case when k equals to 1 and sentence length. + +[Introducing MASS – A pre-training method that outperforms BERT and GPT in sequence to sequence language generation tasks](https://www.microsoft.com/en-us/research/blog/introducing-mass-a-pre-training-method-that-outperforms-bert-and-gpt-in-sequence-to-sequence-language-generation-tasks/) + +[Paper](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf): Song, Kaitao, Xu Tan, Tao Qin, Jianfeng Lu and Tie-Yan Liu. “MASS: Masked Sequence to Sequence Pre-training for Language Generation.” ICML (2019). + + +# Model architecture + +The overall network architecture of MASS is shown below, which is Transformer(Vaswani et al., 2017): + +MASS is consisted of 6-layer encoder and 6-layer decoder with 1024 embedding/hidden size, and 4096 intermediate size between feed forward network which has two full connection layers. + +![Transformer architecture](https://cdn.analyticsvidhya.com/wp-content/uploads/2019/06/Screenshot-from-2019-06-17-19-53-10.png) + + +# Dataset + +Dataset used: +- monolingual English data from News Crawl dataset(WMT 2019) for pre-training. +- Gigaword Corpus(Graff et al., 2003) for Text Summarization. +- Cornell movie dialog corpus(DanescuNiculescu-Mizil & Lee, 2011). + +Details about those dataset could be found in [MASS: Masked Sequence to Sequence Pre-training for Language Generation](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf). + + +# Features + +Mass is designed to jointly pre train encoder and decoder to complete the task of language generation. +First of all, through a sequence to sequence framework, mass only predicts the blocked token, which forces the encoder to understand the meaning of the unshielded token, and encourages the decoder to extract useful information from the encoder. +Secondly, by predicting the continuous token of the decoder, the decoder can build better language modeling ability than only predicting discrete token. +Third, by further shielding the input token of the decoder which is not shielded in the encoder, the decoder is encouraged to extract more useful information from the encoder side, rather than using the rich information in the previous token. + + +# Script description + +MASS script and code structure are as follow: + +```text +├── mass + ├── README.md // Introduction of MASS model. + ├── config + │ ├──config.py // Configuration instance definition. + │ ├──config.json // Configuration file. + ├── src + │ ├──dataset + │ ├──bi_data_loader.py // Dataset loader for fine-tune or inferring. + │ ├──mono_data_loader.py // Dataset loader for pre-training. + │ ├──language_model + │ ├──noise_channel_language_model.p // Noisy channel language model for dataset generation. + │ ├──mass_language_model.py // MASS language model according to MASS paper. + │ ├──loose_masked_language_model.py // MASS language model according to MASS released code. + │ ├──masked_language_model.py // Masked language model according to MASS paper. + │ ├──transformer + │ ├──create_attn_mask.py // Generate mask matrix to remove padding positions. + │ ├──transformer.py // Transformer model architecture. + │ ├──encoder.py // Transformer encoder component. + │ ├──decoder.py // Transformer decoder component. + │ ├──self_attention.py // Self-Attention block component. + │ ├──multi_head_attention.py // Multi-Head Self-Attention component. + │ ├──embedding.py // Embedding component. + │ ├──positional_embedding.py // Positional embedding component. + │ ├──feed_forward_network.py // Feed forward network. + │ ├──residual_conn.py // Residual block. + │ ├──beam_search.py // Beam search decoder for inferring. + │ ├──transformer_for_infer.py // Use Transformer to infer. + │ ├──transformer_for_train.py // Use Transformer to train. + │ ├──utils + │ ├──byte_pair_encoding.py // Apply BPE with subword-nmt. + │ ├──dictionary.py // Dictionary. + │ ├──loss_moniter.py // Callback of monitering loss during training step. + │ ├──lr_scheduler.py // Learning rate scheduler. + │ ├──ppl_score.py // Perplexity score based on N-gram. + │ ├──rouge_score.py // Calculate ROUGE score. + │ ├──load_weights.py // Load weights from a checkpoint or NPZ file. + │ ├──initializer.py // Parameters initializer. + ├── vocab + │ ├──all.bpe.codes // BPE codes table(this file should be generated by user). + │ ├──all_en.dict.bin // Learned vocabulary file(this file should be generated by user). + ├── scripts + │ ├──run.sh // Train & evaluate model script. + │ ├──learn_subword.sh // Learn BPE codes. + │ ├──stop_training.sh // Stop training. + ├── requirements.txt // Requirements of third party package. + ├── train.py // Train API entry. + ├── eval.py // Infer API entry. + ├── tokenize_corpus.py // Corpus tokenization. + ├── apply_bpe_encoding.py // Applying bpe encoding. + ├── weights_average.py // Average multi model checkpoints to NPZ format. + ├── news_crawl.py // Create News Crawl dataset for pre-training. + ├── gigaword.py // Create Gigaword Corpus. + ├── cornell_dialog.py // Create Cornell Movie Dialog dataset for conversation response. + +``` + + +## Data Preparation + +The data preparation of a natural language processing task contains data cleaning, tokenization, encoding and vocabulary generation steps. + +In our experiments, using [Byte Pair Encoding(BPE)](https://arxiv.org/abs/1508.07909) could reduce size of vocabulary, and relieve the OOV influence effectively. + +Vocabulary could be created using `src/utils/dictionary.py` with text dictionary which is learnt from BPE. +For more detail about BPE, please refer to [Subword-nmt lib](https://www.cnpython.com/pypi/subword-nmt) or [paper](https://arxiv.org/abs/1508.07909). + +In our experiments, vocabulary was learned based on 1.9M sentences from News Crawl Dataset, size of vocabulary is 45755. + +Here, we have a brief introduction of data preparation scripts. + + +### Tokenization +Using `tokenize_corpus.py` could tokenize corpus whose text files are in format of `.txt`. + +Major parameters in `tokenize_corpus.py`: + +```bash +--corpus_folder: Corpus folder path, if multi-folders are provided, use ',' split folders. +--output_folder: Output folder path. +--tokenizer: Tokenizer to be used, nltk or jieba, if nltk is not installed fully, use jieba instead. +--pool_size: Processes pool size. +``` + +Sample code: +```bash +python tokenize_corpus.py --corpus_folder /{path}/corpus --output_folder /{path}/tokenized_corpus --tokenizer {nltk|jieba} --pool_size 16 +``` + + +### Byte Pair Encoding +After tokenization, BPE is applied to tokenized corpus with provided `all.bpe.codes`. + +Apply BPE script can be found in `apply_bpe_encoding.py`. + +Major parameters in `apply_bpe_encoding.py`: + +```bash +--codes: BPE codes file. +--src_folder: Corpus folders. +--output_folder: Output files folder. +--prefix: Prefix of text file in `src_folder`. +--vocab_path: Generated vocabulary output path. +--threshold: Filter out words that frequency is lower than threshold. +--processes: Size of process pool (to accelerate). Default: 2. +``` + +Sample code: +```bash +python tokenize_corpus.py --codes /{path}/all.bpe.codes \ + --src_folder /{path}/tokenized_corpus \ + --output_folder /{path}/tokenized_corpus/bpe \ + --prefix tokenized \ + --vocab_path /{path}/vocab_en.dict.bin + --processes 32 +``` + + +### Build Vocabulary +Support that you want to create a new vocabulary, there are two options: +1. Learn BPE codes from scratch, and create vocabulary with multi vocabulary files from `subword-nmt`. +2. Create from an existing vocabulary file which lines in the format of `word frequency`. +3. *Optional*, Create a small vocabulary based on `vocab/all_en.dict.bin` with method of `shink` from `src/utils/dictionary.py`. +4. Persistent vocabulary to `vocab` folder with method `persistence()`. + +Major interface of `src/utils/dictionary.py` are as follow: + +1. `shrink(self, threshold=50)`: Shrink the size of vocabulary by filter out words frequency is lower than threshold. It returns a new vocabulary. +2. `load_from_text(cls, filepaths: List[str])`: Load existed text vocabulary which lines in the format of `word frequency`. +3. `load_from_persisted_dict(cls, filepath)`: Load from a persisted binary vocabulary which was saved by calling `persistence()` method. +4. `persistence(self, path)`: Save vocabulary object to binary file. + +Sample code: +```python +from src.utils import Dictionary + +vocabulary = Dictionary.load_from_persisted_dict("vocab/all_en.dict.bin") +tokens = [1, 2, 3, 4, 5] +# Convert ids to symbols. +print([vocabulary[t] for t in tokens]) + +sentence = ["Hello", "world"] +# Convert symbols to ids. +print([vocabulary.index[s] for s in sentence]) +``` + +For more detail, please refer to the source file. + + +### Generate Dataset +As mentioned above, three corpus are used in MASS mode, dataset generation scripts for them are provided. + +#### News Crawl Corpus +Script can be found in `news_crawl.py`. + +Major parameters in `news_crawl.py`: + +```bash +Note that please provide `--existed_vocab` or `--dict_folder` at least one. +A new vocabulary would be created in `output_folder` when pass `--dict_folder`. + +--src_folder: Corpus folders. +--existed_vocab: Optional, persisted vocabulary file. +--mask_ratio: Ratio of mask. +--output_folder: Output dataset files folder path. +--max_len: Maximum sentence length. If a sentence longer than `max_len`, then drop it. +--suffix: Optional, suffix of generated dataset files. +--processes: Optional, size of process pool (to accelerate). Default: 2. +``` + +Sample code: + +```bash +python news_crawl.py --src_folder /{path}/news_crawl \ + --existed_vocab /{path}/mass/vocab/all_en.dict.bin \ + --mask_ratio 0.5 \ + --output_folder /{path}/news_crawl_dataset \ + --max_len 32 \ + --processes 32 +``` + + +#### Gigaword Corpus +Script can be found in `gigaword.py`. + +Major parameters in `gigaword.py`: + +```bash +--train_src: Train source file path. +--train_ref: Train reference file path. +--test_src: Test source file path. +--test_ref: Test reference file path. +--existed_vocab: Persisted vocabulary file. +--output_folder: Output dataset files folder path. +--noise_prob: Optional, add noise prob. Default: 0. +--max_len: Optional, maximum sentence length. If a sentence longer than `max_len`, then drop it. Default: 64. +--format: Optional, dataset format, "mindrecord" or "tfrecord". Default: "tfrecord". +``` + +Sample code: + +```bash +python gigaword.py --train_src /{path}/gigaword/train_src.txt \ + --train_ref /{path}/gigaword/train_ref.txt \ + --test_src /{path}/gigaword/test_src.txt \ + --test_ref /{path}/gigaword/test_ref.txt \ + --existed_vocab /{path}/mass/vocab/all_en.dict.bin \ + --noise_prob 0.1 \ + --output_folder /{path}/gigaword_dataset \ + --max_len 64 +``` + + +#### Cornell Movie Dialog Corpus +Script can be found in `cornell_dialog.py`. + +Major parameters in `cornell_dialog.py`: + +```bash +--src_folder: Corpus folders. +--existed_vocab: Persisted vocabulary file. +--train_prefix: Train source and target file prefix. Default: train. +--test_prefix: Test source and target file prefix. Default: test. +--output_folder: Output dataset files folder path. +--max_len: Maximum sentence length. If a sentence longer than `max_len`, then drop it. +--valid_prefix: Optional, Valid source and target file prefix. Default: valid. +``` + +Sample code: + +```bash +python cornell_dialog.py --src_folder /{path}/cornell_dialog \ + --existed_vocab /{path}/mass/vocab/all_en.dict.bin \ + --train_prefix train \ + --test_prefix test \ + --noise_prob 0.1 \ + --output_folder /{path}/cornell_dialog_dataset \ + --max_len 64 +``` + + +## Configuration +Json file under the path `config/` is the template configuration file. +Almost all of the options and arguments needed could be assigned conveniently, including the training platform, configurations of dataset and model, arguments of optimizer etc. Optional features such as loss scale and checkpoint are also available by setting the options correspondingly. +For more detailed information about the attributes, refer to the file `config/config.py`. + +## Training & Evaluation process +For training a model, the shell script `run.sh` is all you need. In this scripts, the environment variable is set and the training script `train.py` under `mass` is executed. +You may start a task training with single device or multiple devices by assigning the options and run the command in bash: +```bash +sh run.sh [--options] +``` + +The usage is shown as bellow: +```text +Usage: run.sh [-h, --help] [-t, --task ] [-n, --device_num ] + [-i, --device_id ] [-j, --hccl_json ] + [-c, --config ] [-o, --output ] + [-v, --vocab ] + +options: + -h, --help show usage + -t, --task select task: CHAR, 't' for train and 'i' for inference". + -n, --device_num device number used for training: N, default is 1. + -i, --device_id device id used for training with single device: N, 0<=N<=7, default is 0. + -j, --hccl_json rank table file used for training with multiple devices: FILE. + -c, --config configuration file as shown in the path 'mass/config': FILE. + -o, --output assign output file of inference: FILE. + -v, --vocab set the vocabulary" +``` +Notes: Be sure to assign the hccl_json file while running a distributed-training. + +The command followed shows a example for training with 2 devices. +```bash +sh run.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json +``` +ps. Discontinuous device id is not supported in `run.sh` at present, device id in `rank_table.json` must start from 0. + + +If use a single chip, it would be like this: +```bash +sh run.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json +``` + + +## Weights average + +```python +python weights_average.py --input_files your_checkpoint_list --output_file model.npz +``` + +The input_files is a list of you checkpoints file. To use model.npz as the weights, add its path in config.json at "existed_ckpt". +```json +{ + ... + "checkpoint_options": { + "existed_ckpt": "/xxx/xxx/model.npz", + "save_ckpt_steps": 1000, + ... + }, + ... +} +``` + + +## Learning rate scheduler + +Two learning rate scheduler are provided in our model: + +1. [Polynomial decay scheduler](https://towardsdatascience.com/learning-rate-schedules-and-adaptive-learning-rate-methods-for-deep-learning-2c8f433990d1). +2. [Inverse square root scheduler](https://ece.uwaterloo.ca/~dwharder/aads/Algorithms/Inverse_square_root/). + +LR scheduler could be config in `config/config.json`. + +For Polynomial decay scheduler, config could be like: +```json +{ + ... + "learn_rate_config": { + "optimizer": "adam", + "lr": 1e-4, + "lr_scheduler": "poly", + "poly_lr_scheduler_power": 0.5, + "decay_steps": 10000, + "warmup_steps": 2000, + "min_lr": 1e-6 + }, + ... +} +``` + +For Inverse square root scheduler, config could be like: +```json +{ + ... + "learn_rate_config": { + "optimizer": "adam", + "lr": 1e-4, + "lr_scheduler": "isr", + "decay_start_step": 12000, + "warmup_steps": 2000, + "min_lr": 1e-6 + }, + ... +} +``` + +More detail about LR scheduler could be found in `src/utils/lr_scheduler.py`. + + +# Model description + +The MASS network is implemented by Transformer, which has multi-encoder layers and multi-decoder layers. +For pre-training, we use the Adam optimizer and loss-scale to get the pre-trained model. +During fine-turning, we fine-tune this pre-trained model with different dataset according to different tasks. +During testing, we use the fine-turned model to predict the result, and adopt a beam search algorithm to +get the most possible prediction results. + + +![MASS framework](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-Fig-2.png) + + +## Performance + +### Results + +#### Fine-Tuning on Text Summarization +The comparisons between MASS and two other pre-training methods in terms of ROUGE score on the text summarization task +with 3.8M training data are as follows: + +| Method | RG-1(F) | RG-2(F) | RG-L(F) | +|:---------------|:--------------|:-------------|:-------------| +| MASS | Ongoing | Ongoing | Ongoing | + +#### Fine-Tuning on Conversational ResponseGeneration +The comparisons between MASS and other baseline methods in terms of PPL on Cornell Movie Dialog corpus are as follows: + +| Method | Data = 10K | Data = 110K | +|--------------------|------------------|-----------------| +| MASS | Ongoing | Ongoing | + +#### Training Performance + +| Parameters | Masked Sequence to Sequence Pre-training for Language Generation | +|:---------------------------|:--------------------------------------------------------------------------| +| Model Version | v1 | +| Resource | Ascend 910, cpu 2.60GHz, 56cores;memory, 314G | +| uploaded Date | 05/24/2020 | +| MindSpore Version | 0.2.0 | +| Dataset | News Crawl 2007-2017 English monolingual corpus, Gigaword corpus, Cornell Movie Dialog corpus | +| Training Parameters | Epoch=50, steps=XXX, batch_size=192, lr=1e-4 | +| Optimizer | Adam | +| Loss Function | Label smoothed cross-entropy criterion | +| outputs | Sentence and probability | +| Loss | Lower than 2 | +| Accuracy | For conversation response, ppl=23.52, for text summarization, RG-1=29.79. | +| Speed | 611.45 sentences/s | +| Total time | --/-- | +| Params (M) | 44.6M | +| Checkpoint for Fine tuning | ---Mb, --, [A link]() | +| Model for inference | ---Mb, --, [A link]() | +| Scripts | [A link]() | + + +#### Inference Performance + +| Parameters | Masked Sequence to Sequence Pre-training for Language Generation | +|:---------------------------|:-----------------------------------------------------------| +| Model Version | V1 | +| Resource | Huawei 910 | +| uploaded Date | 05/24/2020 | +| MindSpore Version | 0.2.0 | +| Dataset | Gigaword corpus, Cornell Movie Dialog corpus | +| batch_size | --- | +| outputs | Sentence and probability | +| Accuracy | ppl=23.52 for conversation response, RG-1=29.79 for text summarization. | +| Speed | ---- sentences/s | +| Total time | --/-- | +| Model for inference | ---Mb, --, [A link]() | + + +# Environment Requirements + +## Platform + +- Hardware(Ascend) + - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you could get the resources for trial. +- Framework + - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + +## Requirements + +```txt +nltk +numpy +subword-nmt +rouge +``` + +https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html + + +# Get started +MASS pre-trains a sequence to sequence model by predicting the masked fragments in an input sequence. After this, downstream tasks including text summarization and conversation response are candidated for fine-tuning the model and for inference. +Here we provide a practice example to demonstrate the basic usage of MASS for pre-training, fine-tuning a model, and the inference process. The overall process is as follows: +1. Download and process the dataset. +2. Modify the `config.json` to config the network. +3. Run a task for pre-training and fine-tuning. +4. Perform inference and validation. + +## Pre-training +For pre-training a model, config the options in `config.json` firstly: +- Assign the `pre_train_dataset` under `dataset_config` node to the dataset path. +- Choose the optimizer('momentum/adam/lamb' is available). +- Assign the 'ckpt_prefix' and 'ckpt_path' under `checkpoint_path` to save the model files. +- Set other arguments including dataset configurations and network configurations. +- If you have a trained model already, assign the `existed_ckpt` to the checkpoint file. + +Run the shell script `run.sh` as followed: + +```bash +sh run.sh -t t -n 1 -i 1 -c /mass/config/config.json +``` +Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file. + +## Fine-tuning +For fine-tuning a model, config the options in `config.json` firstly: +- Assign the `fine_tune_dataset` under `dataset_config` node to the dataset path. +- Assign the `existed_ckpt` under `checkpoint_path` node to the existed model file generated by pre-training. +- Choose the optimizer('momentum/adam/lamb' is available). +- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files. +- Set other arguments including dataset configurations and network configurations. + +Run the shell script `run.sh` as followed: +```bash +sh run.sh -t t -n 1 -i 1 -c config/config.json +``` +Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file. + +## Inference +If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). +For inference, config the options in `config.json` firstly: +- Assign the `test_dataset` under `dataset_config` node to the dataset path. +- Assign the `existed_ckpt` under `checkpoint_path` node to the model file produced by fine-tuning. +- Choose the optimizer('momentum/adam/lamb' is available). +- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files. +- Set other arguments including dataset configurations and network configurations. + +Run the shell script `run.sh` as followed: + +```bash +sh run.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile} +``` + +# Description of random situation + +MASS model contains dropout operations, if you want to disable dropout, please set related dropout_rate to 0 in `config/config.json`. + + +# others +The model has been validated on Ascend environment, not validated on CPU and GPU. + + +# ModelZoo Homepage + [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/official/nlp/mass/__init__.py b/model_zoo/official/nlp/mass/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/mass/apply_bpe_encoding.py b/model_zoo/official/nlp/mass/apply_bpe_encoding.py similarity index 100% rename from model_zoo/mass/apply_bpe_encoding.py rename to model_zoo/official/nlp/mass/apply_bpe_encoding.py diff --git a/model_zoo/mass/config/__init__.py b/model_zoo/official/nlp/mass/config/__init__.py similarity index 100% rename from model_zoo/mass/config/__init__.py rename to model_zoo/official/nlp/mass/config/__init__.py diff --git a/model_zoo/mass/config/config.json b/model_zoo/official/nlp/mass/config/config.json similarity index 100% rename from model_zoo/mass/config/config.json rename to model_zoo/official/nlp/mass/config/config.json diff --git a/model_zoo/mass/config/config.py b/model_zoo/official/nlp/mass/config/config.py similarity index 100% rename from model_zoo/mass/config/config.py rename to model_zoo/official/nlp/mass/config/config.py diff --git a/model_zoo/mass/cornell_dialog.py b/model_zoo/official/nlp/mass/cornell_dialog.py similarity index 100% rename from model_zoo/mass/cornell_dialog.py rename to model_zoo/official/nlp/mass/cornell_dialog.py diff --git a/model_zoo/mass/eval.py b/model_zoo/official/nlp/mass/eval.py similarity index 100% rename from model_zoo/mass/eval.py rename to model_zoo/official/nlp/mass/eval.py diff --git a/model_zoo/mass/gigaword.py b/model_zoo/official/nlp/mass/gigaword.py similarity index 100% rename from model_zoo/mass/gigaword.py rename to model_zoo/official/nlp/mass/gigaword.py diff --git a/model_zoo/mass/news_crawl.py b/model_zoo/official/nlp/mass/news_crawl.py similarity index 100% rename from model_zoo/mass/news_crawl.py rename to model_zoo/official/nlp/mass/news_crawl.py diff --git a/model_zoo/mass/requirements.txt b/model_zoo/official/nlp/mass/requirements.txt similarity index 100% rename from model_zoo/mass/requirements.txt rename to model_zoo/official/nlp/mass/requirements.txt diff --git a/model_zoo/official/nlp/mass/scripts/__init__.py b/model_zoo/official/nlp/mass/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/mass/scripts/learn_subword.sh b/model_zoo/official/nlp/mass/scripts/learn_subword.sh similarity index 100% rename from model_zoo/mass/scripts/learn_subword.sh rename to model_zoo/official/nlp/mass/scripts/learn_subword.sh diff --git a/model_zoo/official/nlp/mass/scripts/run.sh b/model_zoo/official/nlp/mass/scripts/run.sh new file mode 100644 index 0000000000..6e33550ee8 --- /dev/null +++ b/model_zoo/official/nlp/mass/scripts/run.sh @@ -0,0 +1,174 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"` +eval set -- "$options" +echo $options + +echo_help() +{ + echo "Usage:" + echo "bash train.sh [-h] [-t t|i] [-n N] [-i N] [-j FILE] [-c FILE] [-o FILE] [-v FILE]" + echo "options:" + echo " -h --help show usage" + echo " -t --task select task, 't' for training and 'i' for inference" + echo " -n --device_num training with N devices" + echo " -i --device_id training with device i" + echo " -j --hccl_json set the rank table file" + echo " -c --config set the configuration file" + echo " -o --output set the output file of inference" + echo " -v --vocab set the vocabulary" + echo " -m --metric set the metric" +} + +set_hccl_json() +{ + while [ -n "$1" ] + do + if [[ "$1" == "-j" || "$1" == "--hccl_json" ]] + then + export RANK_TABLE_FILE=$2 + break + fi + shift + done +} +set_device_id() +{ + while [ -n "$1" ] + do + if [[ "$1" == "-i" || "$1" == "--device_id" ]] + then + if [[ $2 -ge 0 && $2 -le 7 ]] + then + export DEVICE_ID=$2 + fi + break + fi + shift + done +} + +while [ -n "$1" ] +do + case "$1" in + -h|--help) + echo_help + shift + ;; + -t|--task) + echo "task:" + if [ "$2" == "t" ] + then + task=train + elif [ "$2" == "i" ] + then + task=infer + fi + shift 2 + ;; + -n|--device_num) + echo "device_num" + if [ $2 -eq 1 ] + then + set_device_id $options + elif [ $2 -gt 1 ] + then + export HCCL_FLAG=1 + export DEPLOY_MODE=0 + + export RANK_SIZE=$2 + set_hccl_json $options + fi + shift 2 + ;; + -i|--device_id) + echo "set device id" + export DEVICE_ID=$2 + shift 2 + ;; + -c|--config) + echo "config"; + configurations=$2 + shift 2 + ;; + -o|--output) + echo "output"; + output=$2 + shift 2 + ;; + -v|--vocab) + echo "vocab"; + vocab=$2 + shift 2 + ;; + -m|--metric) + echo "metric"; + metric=$2 + shift 2 + ;; + --) + shift + break + ;; + *) + shift + ;; +esac +done + +file_path=$(cd "$(dirname $0)" || exit; pwd) +for((i=0; i < $RANK_SIZE; i++)) +do + if [ $RANK_SIZE -gt 1 ] + then + echo $RANK_SIZE + export RANK_ID=$i + export DEVICE_ID=$[i] + fi + echo "Working on device $i" + + cd $file_path || exit + cd ../ || exit + + rm -rf ./${task}_mass_$DEVICE_ID + mkdir ./${task}_mass_$DEVICE_ID + + cp train.py ./${task}_mass_$DEVICE_ID + cp eval.py ./${task}_mass_$DEVICE_ID + cp $configurations ./${task}_mass_$DEVICE_ID + + if [ $vocab ] + then + cp $vocab ./${task}_mass_$DEVICE_ID + fi + + cd ./${task}_mass_$DEVICE_ID || exit + env > log.log + echo $task + if [ "$task" == "train" ] + then + python train.py --config ${configurations##*/} >>log.log 2>&1 & + elif [ "$task" == "infer" ] + then + python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 & + fi + cd ../ +done diff --git a/model_zoo/mass/src/__init__.py b/model_zoo/official/nlp/mass/src/__init__.py similarity index 100% rename from model_zoo/mass/src/__init__.py rename to model_zoo/official/nlp/mass/src/__init__.py diff --git a/model_zoo/mass/src/dataset/__init__.py b/model_zoo/official/nlp/mass/src/dataset/__init__.py similarity index 100% rename from model_zoo/mass/src/dataset/__init__.py rename to model_zoo/official/nlp/mass/src/dataset/__init__.py diff --git a/model_zoo/mass/src/dataset/base.py b/model_zoo/official/nlp/mass/src/dataset/base.py similarity index 100% rename from model_zoo/mass/src/dataset/base.py rename to model_zoo/official/nlp/mass/src/dataset/base.py diff --git a/model_zoo/mass/src/dataset/bi_data_loader.py b/model_zoo/official/nlp/mass/src/dataset/bi_data_loader.py similarity index 100% rename from model_zoo/mass/src/dataset/bi_data_loader.py rename to model_zoo/official/nlp/mass/src/dataset/bi_data_loader.py diff --git a/model_zoo/official/nlp/mass/src/dataset/load_dataset.py b/model_zoo/official/nlp/mass/src/dataset/load_dataset.py new file mode 100644 index 0000000000..be59941374 --- /dev/null +++ b/model_zoo/official/nlp/mass/src/dataset/load_dataset.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Dataset loader to feed into model.""" +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as deC + + +def _load_dataset(input_files, batch_size, epoch_count=1, + sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True): + """ + Load dataset according to passed in params. + + Args: + input_files (list): Data files. + batch_size (int): Batch size. + epoch_count (int): Epoch count. + sink_mode (bool): Whether enable sink mode. + sink_step (int): Step to sink. + rank_size (int): Rank size. + rank_id (int): Rank id. + shuffle (bool): Whether shuffle dataset. + + Returns: + Dataset, dataset instance. + """ + if not input_files: + raise FileNotFoundError("Require at least one dataset.") + + if not isinstance(sink_mode, bool): + raise ValueError("`sink` must be type of bool.") + + for datafile in input_files: + print(f" | Loading {datafile}.") + + ds = de.TFRecordDataset( + input_files, + columns_list=[ + "src", "src_padding", + "prev_opt", "prev_padding", + "target", "tgt_padding" + ], + shuffle=shuffle, num_shards=rank_size, shard_id=rank_id, + shard_equal_rows=True, num_parallel_workers=8) + + ori_dataset_size = ds.get_dataset_size() + print(f" | Dataset size: {ori_dataset_size}.") + repeat_count = epoch_count + + type_cast_op = deC.TypeCast(mstype.int32) + ds = ds.map(input_columns="src", operations=type_cast_op) + ds = ds.map(input_columns="src_padding", operations=type_cast_op) + ds = ds.map(input_columns="prev_opt", operations=type_cast_op) + ds = ds.map(input_columns="prev_padding", operations=type_cast_op) + ds = ds.map(input_columns="target", operations=type_cast_op) + ds = ds.map(input_columns="tgt_padding", operations=type_cast_op) + + ds = ds.rename( + input_columns=["src", + "src_padding", + "prev_opt", + "prev_padding", + "target", + "tgt_padding"], + output_columns=["source_eos_ids", + "source_eos_mask", + "target_sos_ids", + "target_sos_mask", + "target_eos_ids", + "target_eos_mask"] + ) + + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_count) + + ds.channel_name = 'transformer' + return ds + + +def load_dataset(data_files: list, batch_size: int, epoch_count: int, + sink_mode: bool, sink_step: int = 1, rank_size: int = 1, rank_id: int = 0, shuffle=True): + """ + Load dataset. + + Args: + data_files (list): Data files. + batch_size (int): Batch size. + epoch_count (int): Epoch count. + sink_mode (bool): Whether enable sink mode. + sink_step (int): Step to sink. + rank_size (int): Rank size. + rank_id (int): Rank id. + shuffle (bool): Whether shuffle dataset. + + Returns: + Dataset, dataset instance. + """ + return _load_dataset(data_files, batch_size, epoch_count, sink_mode, + sink_step, rank_size, rank_id, shuffle=shuffle) diff --git a/model_zoo/mass/src/dataset/mono_data_loader.py b/model_zoo/official/nlp/mass/src/dataset/mono_data_loader.py similarity index 100% rename from model_zoo/mass/src/dataset/mono_data_loader.py rename to model_zoo/official/nlp/mass/src/dataset/mono_data_loader.py diff --git a/model_zoo/mass/src/dataset/schema.py b/model_zoo/official/nlp/mass/src/dataset/schema.py similarity index 100% rename from model_zoo/mass/src/dataset/schema.py rename to model_zoo/official/nlp/mass/src/dataset/schema.py diff --git a/model_zoo/mass/src/language_model/__init__.py b/model_zoo/official/nlp/mass/src/language_model/__init__.py similarity index 100% rename from model_zoo/mass/src/language_model/__init__.py rename to model_zoo/official/nlp/mass/src/language_model/__init__.py diff --git a/model_zoo/mass/src/language_model/base.py b/model_zoo/official/nlp/mass/src/language_model/base.py similarity index 100% rename from model_zoo/mass/src/language_model/base.py rename to model_zoo/official/nlp/mass/src/language_model/base.py diff --git a/model_zoo/mass/src/language_model/loose_masked_language_model.py b/model_zoo/official/nlp/mass/src/language_model/loose_masked_language_model.py similarity index 100% rename from model_zoo/mass/src/language_model/loose_masked_language_model.py rename to model_zoo/official/nlp/mass/src/language_model/loose_masked_language_model.py diff --git a/model_zoo/mass/src/language_model/masked_language_model.py b/model_zoo/official/nlp/mass/src/language_model/masked_language_model.py similarity index 100% rename from model_zoo/mass/src/language_model/masked_language_model.py rename to model_zoo/official/nlp/mass/src/language_model/masked_language_model.py diff --git a/model_zoo/mass/src/language_model/mass_language_model.py b/model_zoo/official/nlp/mass/src/language_model/mass_language_model.py similarity index 100% rename from model_zoo/mass/src/language_model/mass_language_model.py rename to model_zoo/official/nlp/mass/src/language_model/mass_language_model.py diff --git a/model_zoo/mass/src/language_model/noise_channel_language_model.py b/model_zoo/official/nlp/mass/src/language_model/noise_channel_language_model.py similarity index 100% rename from model_zoo/mass/src/language_model/noise_channel_language_model.py rename to model_zoo/official/nlp/mass/src/language_model/noise_channel_language_model.py diff --git a/model_zoo/mass/src/transformer/__init__.py b/model_zoo/official/nlp/mass/src/transformer/__init__.py similarity index 100% rename from model_zoo/mass/src/transformer/__init__.py rename to model_zoo/official/nlp/mass/src/transformer/__init__.py diff --git a/model_zoo/mass/src/transformer/beam_search.py b/model_zoo/official/nlp/mass/src/transformer/beam_search.py similarity index 100% rename from model_zoo/mass/src/transformer/beam_search.py rename to model_zoo/official/nlp/mass/src/transformer/beam_search.py diff --git a/model_zoo/mass/src/transformer/components.py b/model_zoo/official/nlp/mass/src/transformer/components.py similarity index 100% rename from model_zoo/mass/src/transformer/components.py rename to model_zoo/official/nlp/mass/src/transformer/components.py diff --git a/model_zoo/mass/src/transformer/create_attn_mask.py b/model_zoo/official/nlp/mass/src/transformer/create_attn_mask.py similarity index 100% rename from model_zoo/mass/src/transformer/create_attn_mask.py rename to model_zoo/official/nlp/mass/src/transformer/create_attn_mask.py diff --git a/model_zoo/mass/src/transformer/decoder.py b/model_zoo/official/nlp/mass/src/transformer/decoder.py similarity index 100% rename from model_zoo/mass/src/transformer/decoder.py rename to model_zoo/official/nlp/mass/src/transformer/decoder.py diff --git a/model_zoo/mass/src/transformer/embedding.py b/model_zoo/official/nlp/mass/src/transformer/embedding.py similarity index 100% rename from model_zoo/mass/src/transformer/embedding.py rename to model_zoo/official/nlp/mass/src/transformer/embedding.py diff --git a/model_zoo/mass/src/transformer/encoder.py b/model_zoo/official/nlp/mass/src/transformer/encoder.py similarity index 100% rename from model_zoo/mass/src/transformer/encoder.py rename to model_zoo/official/nlp/mass/src/transformer/encoder.py diff --git a/model_zoo/mass/src/transformer/feed_forward_network.py b/model_zoo/official/nlp/mass/src/transformer/feed_forward_network.py similarity index 100% rename from model_zoo/mass/src/transformer/feed_forward_network.py rename to model_zoo/official/nlp/mass/src/transformer/feed_forward_network.py diff --git a/model_zoo/mass/src/transformer/grad_clip.py b/model_zoo/official/nlp/mass/src/transformer/grad_clip.py similarity index 100% rename from model_zoo/mass/src/transformer/grad_clip.py rename to model_zoo/official/nlp/mass/src/transformer/grad_clip.py diff --git a/model_zoo/mass/src/transformer/infer_mass.py b/model_zoo/official/nlp/mass/src/transformer/infer_mass.py similarity index 100% rename from model_zoo/mass/src/transformer/infer_mass.py rename to model_zoo/official/nlp/mass/src/transformer/infer_mass.py diff --git a/model_zoo/mass/src/transformer/multi_head_attention.py b/model_zoo/official/nlp/mass/src/transformer/multi_head_attention.py similarity index 100% rename from model_zoo/mass/src/transformer/multi_head_attention.py rename to model_zoo/official/nlp/mass/src/transformer/multi_head_attention.py diff --git a/model_zoo/mass/src/transformer/positional_embedding.py b/model_zoo/official/nlp/mass/src/transformer/positional_embedding.py similarity index 100% rename from model_zoo/mass/src/transformer/positional_embedding.py rename to model_zoo/official/nlp/mass/src/transformer/positional_embedding.py diff --git a/model_zoo/mass/src/transformer/residual_conn.py b/model_zoo/official/nlp/mass/src/transformer/residual_conn.py similarity index 100% rename from model_zoo/mass/src/transformer/residual_conn.py rename to model_zoo/official/nlp/mass/src/transformer/residual_conn.py diff --git a/model_zoo/mass/src/transformer/self_attention.py b/model_zoo/official/nlp/mass/src/transformer/self_attention.py similarity index 100% rename from model_zoo/mass/src/transformer/self_attention.py rename to model_zoo/official/nlp/mass/src/transformer/self_attention.py diff --git a/model_zoo/mass/src/transformer/transformer.py b/model_zoo/official/nlp/mass/src/transformer/transformer.py similarity index 100% rename from model_zoo/mass/src/transformer/transformer.py rename to model_zoo/official/nlp/mass/src/transformer/transformer.py diff --git a/model_zoo/official/nlp/mass/src/transformer/transformer_for_infer.py b/model_zoo/official/nlp/mass/src/transformer/transformer_for_infer.py new file mode 100644 index 0000000000..99d56ba3a1 --- /dev/null +++ b/model_zoo/official/nlp/mass/src/transformer/transformer_for_infer.py @@ -0,0 +1,329 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer for infer.""" +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor + +from .beam_search import BeamSearchDecoder, TileBeam +from .embedding import EmbeddingLookup +from .positional_embedding import PositionalEmbedding +from .components import SaturateCast +from .create_attn_mask import CreateAttentionMaskFromInputMask +from .decoder import TransformerDecoder +from .encoder import TransformerEncoder + + +class PredLogProbs(nn.Cell): + """ + Get log probs. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): The length of sequences. + width (int): Number of parameters of a layer + compute_type (int): Type of input type. + dtype (int): Type of MindSpore output type. + """ + + def __init__(self, + batch_size, + seq_length, + width, + compute_type=mstype.float32, + dtype=mstype.float32): + super(PredLogProbs, self).__init__() + self.batch_size = batch_size + self.seq_length = seq_length + self.width = width + self.compute_type = compute_type + self.dtype = dtype + + self.reshape = P.Reshape() + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width) + self.cast = P.Cast() + + def construct(self, input_tensor, output_weights): + """ + Calculate the log_softmax. + + Inputs: + input_tensor (Tensor): A batch of sentences with shape (N, T). + output_weights (Tensor): A batch of masks with shape (N, T). + + Returns: + Tensor, the prediction probability with shape (N, T'). + """ + input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + + log_probs = self.log_softmax(logits) + return log_probs + + +class TransformerDecoderStep(nn.Cell): + """ + Multi-layer transformer decoder step. + + Args: + config (TransformerConfig): The config of Transformer. + num_hidden_layers (int): The numbers of hidden layers. + attn_embed_dim (int): Dimensions of attention weights. + num_attn_heads=12 (int): Heads number. + seq_length (int): The length of a sequence. + intermediate_size: Hidden size in FFN. + attn_dropout_prob (float): Dropout rate in attention. Default: 0.1. + initializer_range (float): Initial range. + hidden_dropout_prob (float): Dropout rate in FFN. + hidden_act (str): Activation function in FFN. + compute_type (mstype): Mindspore data type. Default: mstype.float32. + embedding_lookup (function): Embeddings lookup operation. Default: None. + positional_embedding (function): Position Embedding operation. Default: None. + projection (function): Function to get log probs. Default: None. + """ + + def __init__(self, + config, + num_hidden_layers, + attn_embed_dim, + num_attn_heads=12, + seq_length=64, + intermediate_size=3072, + attn_dropout_prob=0.1, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32, + embedding_lookup=None, + positional_embedding=None, + projection=None): + super(TransformerDecoderStep, self).__init__(auto_prefix=False) + self.embedding_lookup = embedding_lookup + self.positional_embedding = positional_embedding + self.projection = projection + self.seq_length = seq_length + self.decoder = TransformerDecoder( + attn_embed_dim=attn_embed_dim, + num_attn_heads=num_attn_heads, + decoder_layers=num_hidden_layers, + intermediate_size=intermediate_size, + attn_dropout_prob=attn_dropout_prob, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + + self.ones_like = P.OnesLike() + self.shape = P.Shape() + + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + self.expand = P.ExpandDims() + self.multiply = P.Mul() + + ones = np.ones(shape=(seq_length, seq_length)) + self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) + + self.cast_compute_type = SaturateCast(dst_type=compute_type) + self.scale = Tensor([math.sqrt(float(attn_embed_dim))], dtype=mstype.float32) + + def construct(self, input_ids, enc_states, enc_attention_mask): + """ + Get log probs. + + Args: + input_ids: [batch_size * beam_width, m] + enc_states: [batch_size * beam_width, T, D] + enc_attention_mask: [batch_size * beam_width, T, D] + + Returns: + Tensor, the log_probs. [batch_size * beam_width, 1, Vocabulary_Dimension] + """ + + # process embedding. input_embedding: [batch_size * beam_width, m, D], embedding_tables: [V, D] + input_embedding, embedding_tables = self.embedding_lookup(input_ids) + input_embedding = self.multiply(input_embedding, self.scale) + input_embedding = self.positional_embedding(input_embedding) + input_embedding = self.cast_compute_type(input_embedding) + + input_shape = self.shape(input_ids) + input_len = input_shape[1] + # [m,m] + future_mask = self.future_mask[0:input_len:1, 0:input_len:1] + # [batch_size * beam_width, m] + input_mask = self.ones_like(input_ids) + # [batch_size * beam_width, m, m] + input_mask = self._create_attention_mask_from_input_mask(input_mask) + # [batch_size * beam_width, m, m] + input_mask = self.multiply(input_mask, self.expand(future_mask, 0)) + input_mask = self.cast_compute_type(input_mask) + + # [batch_size * beam_width, m, D] + enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::] + + # call TransformerDecoder: [batch_size * beam_width, m, D] + decoder_output = self.decoder(input_embedding, input_mask, enc_states, enc_attention_mask) + + # take the last step, [batch_size * beam_width, 1, D] + decoder_output = decoder_output[::, input_len - 1:input_len:1, ::] + + # projection and log_prob + log_probs = self.projection(decoder_output, embedding_tables) + + # [batch_size * beam_width, 1, vocabulary_size] + return log_probs + + +class TransformerInferModel(nn.Cell): + """ + Transformer Infer. + + Args: + config (TransformerConfig): The config of Transformer. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + + def __init__(self, + config, + use_one_hot_embeddings=False): + super(TransformerInferModel, self).__init__() + config = copy.deepcopy(config) + config.hidden_dropout_prob = 0.0 + config.attention_dropout_prob = 0.0 + + self.input_mask_from_dataset = config.input_mask_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.attn_embed_dim = config.hidden_size + self.num_layers = config.num_hidden_layers + self.last_idx = self.num_hidden_layers - 1 + + self.embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embed_dim=self.embedding_size, + use_one_hot_embeddings=use_one_hot_embeddings) + + self.positional_embedding = PositionalEmbedding( + embedding_size=self.embedding_size, + max_position_embeddings=config.max_position_embeddings) + # use for infer + self.projection = PredLogProbs( + batch_size=config.batch_size * config.beam_width, + seq_length=1, + width=self.hidden_size, + compute_type=config.compute_type) + + self.encoder = TransformerEncoder( + attn_embed_dim=self.attn_embed_dim, + encoder_layers=self.num_layers, + num_attn_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + attention_dropout_prob=config.attention_dropout_prob, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type) + + decoder_cell = TransformerDecoderStep( + config=config, + num_hidden_layers=config.num_hidden_layers, + attn_embed_dim=self.attn_embed_dim, + seq_length=config.seq_length, + num_attn_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + hidden_dropout_prob=config.hidden_dropout_prob, + compute_type=config.compute_type, + initializer_range=config.initializer_range, + hidden_act="relu", + embedding_lookup=self.embedding_lookup, + positional_embedding=self.positional_embedding, + attn_dropout_prob=config.attention_dropout_prob, + projection=self.projection + ) + + # link beam_search after decoder + self.decoder = BeamSearchDecoder( + batch_size=config.batch_size, + seq_length=config.seq_length, + vocab_size=config.vocab_size, + decoder=decoder_cell, + beam_width=config.beam_width, + length_penalty_weight=config.length_penalty_weight, + max_decode_length=config.max_decode_length) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.expand = P.ExpandDims() + self.multiply = P.Mul() + + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + # use for infer + self.tile_beam = TileBeam(beam_width=config.beam_width) + ones = np.ones(shape=(config.batch_size, config.max_decode_length)) + self.encode_mask = Tensor(ones, dtype=mstype.float32) + + self.scale = Tensor([math.sqrt(float(self.embedding_size))], + dtype=mstype.float32) + self.reshape = P.Reshape() + + def construct(self, source_ids, source_mask, target_ids=None, target_mask=None): + """ + Process source sentence + + Inputs: + source_ids (Tensor): Source sentences with shape (N, T). + source_mask (Tensor): Source sentences padding mask with shape (N, T), + where 0 indicates padding position. + + Returns: + Tensor, Predictions with shape (N, T'). + """ + # word_embeddings + src_embeddings, _ = self.embedding_lookup(source_ids) + src_embeddings = self.multiply(src_embeddings, self.scale) + # position_embeddings + src_embeddings = self.positional_embedding(src_embeddings) + # attention mask, [batch_size, seq_length, seq_length] + enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask) + # encode + encoder_output = self.encoder(self.cast_compute_type(src_embeddings), + self.cast_compute_type(enc_attention_mask)) + + # bean search for encoder output + beam_encoder_output = self.tile_beam(encoder_output) + # [batch_size, T, D] + enc_attention_mask = self.multiply( + enc_attention_mask[::, 0:1:1, ::], + self.expand(self.encode_mask, -1)) + # [N*batch_size, T, D] + beam_enc_attention_mask = self.tile_beam(enc_attention_mask) + beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask) + predicted_ids, predicted_probs = self.decoder(beam_encoder_output, beam_enc_attention_mask) + predicted_ids = self.reshape(predicted_ids, (self.batch_size, -1)) + return predicted_ids, predicted_probs diff --git a/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py b/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py new file mode 100644 index 0000000000..656b9e6f40 --- /dev/null +++ b/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py @@ -0,0 +1,348 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer for training.""" +from mindspore import nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean + +from .transformer import Transformer +from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients + + +class PredLogProbs(nn.Cell): + """ + Get log probs. + + Args: + config (TransformerConfig): The config of Transformer. + + Returns: + Tensor, masked lm output. + """ + + def __init__(self, config): + super(PredLogProbs, self).__init__() + self.width = config.hidden_size + self.reshape = P.Reshape() + + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) + self.cast = P.Cast() + self.compute_type = config.compute_type + self.dtype = config.dtype + self.get_shape = P.Shape() + + def construct(self, input_tensor, output_weights): + """ + Construct network. + + Args: + input_tensor (Tensor): Tensor. + output_weights (Tensor): Tensor. + + Returns: + Tensor, masked lm output. + """ + shape = self.get_shape(input_tensor) + + input_tensor = self.reshape(input_tensor, (shape[0] * shape[1], shape[2])) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + + log_probs = self.log_softmax(logits) + return log_probs + + +class TransformerTraining(nn.Cell): + """ + Transformer training network. + + Args: + config (TransformerConfig): The config of Transformer. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. + + Returns: + Tensor, prediction_scores, seq_relationship_score. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings): + super(TransformerTraining, self).__init__() + self.transformer = Transformer(config, is_training, use_one_hot_embeddings) + self.projection = PredLogProbs(config) + + def construct(self, source_ids, source_mask, target_ids, target_mask): + """ + Construct network. + + Args: + source_ids (Tensor): Source sentence. + source_mask (Tensor): Source padding mask. + target_ids (Tensor): Target sentence. + target_mask (Tensor): Target padding mask. + + Returns: + Tensor, prediction_scores, seq_relationship_score. + """ + _, decoder_outputs, embedding_table = \ + self.transformer(source_ids, source_mask, target_ids, target_mask) + prediction_scores = self.projection(decoder_outputs, + embedding_table) + return prediction_scores + + +class LabelSmoothedCrossEntropyCriterion(nn.Cell): + """ + Label Smoothed Cross-Entropy Criterion. + + Args: + config (TransformerConfig): The config of Transformer. + + Returns: + Tensor, final loss. + """ + + def __init__(self, config): + super(LabelSmoothedCrossEntropyCriterion, self).__init__() + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(float(1 - config.label_smoothing), mstype.float32) + self.off_value = Tensor(config.label_smoothing / float(self.vocab_size - 1), mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.flatten = P.Flatten() + self.neg = P.Neg() + self.cast = P.Cast() + self.flat_shape = (config.batch_size * config.seq_length,) + self.get_shape = P.Shape() + + def construct(self, prediction_scores, label_ids, label_weights): + """ + Construct network to calculate loss. + + Args: + prediction_scores (Tensor): Prediction scores. + label_ids (Tensor): Labels. + label_weights (Tensor): Mask tensor. + + Returns: + Tensor, final loss. + """ + label_shape = self.get_shape(label_ids) + + label_ids = self.reshape(label_ids, (label_shape[0] * label_shape[1],)) + label_weights = self.cast( + self.reshape(label_weights, (label_shape[0] * label_shape[1],)), + mstype.float32 + ) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + loss = numerator / denominator + + return loss + + +class TransformerNetworkWithLoss(nn.Cell): + """ + Provide transformer training loss through network. + + Args: + config (BertConfig): The config of Transformer. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(TransformerNetworkWithLoss, self).__init__() + self.transformer = TransformerTraining(config, is_training, use_one_hot_embeddings) + self.loss = LabelSmoothedCrossEntropyCriterion(config) + self.cast = P.Cast() + + def construct(self, + source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights): + prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask) + total_loss = self.loss(prediction_scores, label_ids, label_weights) + return self.cast(total_loss, mstype.float32) + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * F.cast(reciprocal(scale), F.dtype(grad)) + + +class TransformerTrainOneStepWithLossScaleCell(nn.Cell): + """ + Encapsulation class of Transformer network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network: Cell. The training network. Note that loss function should have + been added. + optimizer: Optimizer. Optimizer for updating the weights. + + Returns: + Tuple[Tensor, Tensor, Tensor], loss, overflow, sen. + """ + + def __init__(self, network, optimizer, scale_update_cell=None): + + super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.all_reduce = P.AllReduce() + + self.parallel_mode = _get_parallel_mode() + if self.parallel_mode not in ParallelMode.MODE_LIST: + raise ValueError("Parallel mode does not support: ", self.parallel_mode) + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.clip_gradients = ClipGradients() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + self.add_flags(has_effect=True) + + def construct(self, + source_eos_ids, + source_eos_mask, + target_sos_ids, + target_sos_mask, + target_eos_ids, + target_eos_mask, + sens=None): + """ + Construct network. + + Args: + source_eos_ids (Tensor): Source sentence. + source_eos_mask (Tensor): Source padding mask. + target_sos_ids (Tensor): Target sentence. + target_sos_mask (Tensor): Target padding mask. + target_eos_ids (Tensor): Prediction sentence. + target_eos_mask (Tensor): Prediction padding mask. + sens (Tensor): Loss sen. + + Returns: + Tuple[Tensor, Tensor, Tensor], loss, overflow, sen. + """ + source_ids = source_eos_ids + source_mask = source_eos_mask + target_ids = target_sos_ids + target_mask = target_sos_mask + label_ids = target_eos_ids + label_weights = target_eos_mask + + weights = self.weights + loss = self.network(source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights) + # Alloc status. + init = self.alloc_status() + # Clear overflow buffer. + self.clear_before_grad(init) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + grads = self.grad(self.network, weights)(source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights, + self.cast(scaling_sens, + mstype.float32)) + + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) + if self.reducer_flag: + # Apply grad reducer on grads. + grads = self.grad_reducer(grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + + if self.is_distributed: + # Sum overflow flag over devices. + flag_reduce = self.all_reduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/mass/src/utils/__init__.py b/model_zoo/official/nlp/mass/src/utils/__init__.py similarity index 100% rename from model_zoo/mass/src/utils/__init__.py rename to model_zoo/official/nlp/mass/src/utils/__init__.py diff --git a/model_zoo/mass/src/utils/byte_pair_encoding.py b/model_zoo/official/nlp/mass/src/utils/byte_pair_encoding.py similarity index 100% rename from model_zoo/mass/src/utils/byte_pair_encoding.py rename to model_zoo/official/nlp/mass/src/utils/byte_pair_encoding.py diff --git a/model_zoo/mass/src/utils/dictionary.py b/model_zoo/official/nlp/mass/src/utils/dictionary.py similarity index 100% rename from model_zoo/mass/src/utils/dictionary.py rename to model_zoo/official/nlp/mass/src/utils/dictionary.py diff --git a/model_zoo/mass/src/utils/eval_score.py b/model_zoo/official/nlp/mass/src/utils/eval_score.py similarity index 100% rename from model_zoo/mass/src/utils/eval_score.py rename to model_zoo/official/nlp/mass/src/utils/eval_score.py diff --git a/model_zoo/mass/src/utils/initializer.py b/model_zoo/official/nlp/mass/src/utils/initializer.py similarity index 100% rename from model_zoo/mass/src/utils/initializer.py rename to model_zoo/official/nlp/mass/src/utils/initializer.py diff --git a/model_zoo/mass/src/utils/load_weights.py b/model_zoo/official/nlp/mass/src/utils/load_weights.py similarity index 100% rename from model_zoo/mass/src/utils/load_weights.py rename to model_zoo/official/nlp/mass/src/utils/load_weights.py diff --git a/model_zoo/mass/src/utils/loss_monitor.py b/model_zoo/official/nlp/mass/src/utils/loss_monitor.py similarity index 100% rename from model_zoo/mass/src/utils/loss_monitor.py rename to model_zoo/official/nlp/mass/src/utils/loss_monitor.py diff --git a/model_zoo/official/nlp/mass/src/utils/lr_scheduler.py b/model_zoo/official/nlp/mass/src/utils/lr_scheduler.py new file mode 100644 index 0000000000..16607678e5 --- /dev/null +++ b/model_zoo/official/nlp/mass/src/utils/lr_scheduler.py @@ -0,0 +1,140 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning scheduler.""" +from math import ceil +import numpy as np + +import mindspore.nn.learning_rate_schedule as lr_schedules + + +def square_root_schedule(lr, update_num, decay_start_step, + warmup_steps=2000, + min_lr=1e-5): + """ + Decay the LR based on the ISR(inverse square root). + + During warm-up:: + lrs = np.linspace(0, lr, warmup_steps) + + After warm-up: + decay_factor = lr * sqrt(warmup_steps) + lr = decay_factor / sqrt(step) if step >= decay_start_step else lr + + Args: + lr (float): Init learning rate. + update_num (int): Total steps. + decay_start_step (int): Decay begins after `decay_start_step` steps. + warmup_steps (int): Warm up steps. + min_lr (float): Min learning rate. + + Returns: + np.ndarray, learning rate array. + """ + warmup_end_lr = lr + warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr + + # If warmup_init_lr > lr, then lr_step is negative. + # Otherwise, it's positive. + lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps + decay_factor = lr * warmup_steps ** 0.5 + + lrs = np.empty(shape=update_num, dtype=np.float32) + _start_step = 0 + if 0 < warmup_steps < update_num: + lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps) + _start_step = warmup_steps + + for step in range(_start_step, update_num): + if step < warmup_steps: + _lr = warmup_init_lr + step * lr_step + elif step < decay_start_step: + _lr = lr + else: + _lr = decay_factor * step ** -0.5 + if _lr < min_lr: + _lr = min_lr + lrs[step] = _lr + + return lrs + + +def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0): + """ + Implements of polynomial decay learning rate scheduler which cycles by default. + + Args: + lr (float): Initial learning rate. + warmup_steps (int): Warmup steps. + decay_steps (int): Decay steps. + total_update_num (int): Total update steps. + min_lr (float): Min learning. + power (float): Power factor. + + Returns: + np.ndarray, learning rate of each step. + """ + lrs = np.zeros(shape=total_update_num, dtype=np.float32) + + if decay_steps <= 0: + raise ValueError("`decay_steps` must larger than 1.") + + _start_step = 0 + if 0 < warmup_steps < total_update_num: + warmup_end_lr = lr + warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr + lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps) + _start_step = warmup_steps + + decay_steps = decay_steps + for step in range(_start_step, total_update_num): + _step = step - _start_step # 2999 + ratio = ceil(_step / decay_steps) # 3 + ratio = 1 if ratio < 1 else ratio + _decay_steps = decay_steps * ratio # 3000 + lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr + + return lrs + + +class BertLearningRate(lr_schedules.LearningRateSchedule): + """ + Implements of warmup-polydecay learning rate scheduler. + + Args: + learning_rate (float): The initial value of learning rate. + end_learning_rate (float): The end value of learning rate. + warmup_steps (int): The warm up steps of learning rate. + decay_steps (int): A value used to calculate decayed learning rate. + power (float): A value used to calculate decayed learning rate. + + Returns: + Tensor. The learning rate value for the current step. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr diff --git a/model_zoo/mass/src/utils/ppl_score.py b/model_zoo/official/nlp/mass/src/utils/ppl_score.py similarity index 100% rename from model_zoo/mass/src/utils/ppl_score.py rename to model_zoo/official/nlp/mass/src/utils/ppl_score.py diff --git a/model_zoo/mass/src/utils/preprocess.py b/model_zoo/official/nlp/mass/src/utils/preprocess.py similarity index 100% rename from model_zoo/mass/src/utils/preprocess.py rename to model_zoo/official/nlp/mass/src/utils/preprocess.py diff --git a/model_zoo/mass/src/utils/rouge_score.py b/model_zoo/official/nlp/mass/src/utils/rouge_score.py similarity index 100% rename from model_zoo/mass/src/utils/rouge_score.py rename to model_zoo/official/nlp/mass/src/utils/rouge_score.py diff --git a/model_zoo/mass/tokenize_corpus.py b/model_zoo/official/nlp/mass/tokenize_corpus.py similarity index 100% rename from model_zoo/mass/tokenize_corpus.py rename to model_zoo/official/nlp/mass/tokenize_corpus.py diff --git a/model_zoo/official/nlp/mass/train.py b/model_zoo/official/nlp/mass/train.py new file mode 100644 index 0000000000..07e4469bd5 --- /dev/null +++ b/model_zoo/official/nlp/mass/train.py @@ -0,0 +1,342 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Train api.""" +import os +import argparse +import pickle + +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.nn import Momentum +from mindspore.nn.optim import Adam, Lamb +from mindspore.train.model import Model +from mindspore.train.loss_scale_manager import DynamicLossScaleManager +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint +from mindspore import context, ParallelMode, Parameter +from mindspore.communication import management as MultiAscend +from mindspore.train.serialization import load_checkpoint + +from config import TransformerConfig +from src.dataset import load_dataset +from src.transformer import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell +from src.transformer.infer_mass import infer +from src.utils import LossCallBack +from src.utils import one_weight, zero_weight, weight_variable +from src.utils import square_root_schedule +from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate + +parser = argparse.ArgumentParser(description='MASS train entry point.') +parser.add_argument("--config", type=str, required=True, help="model config json file path.") + +device_id = os.getenv('DEVICE_ID', None) +if device_id is None: + raise RuntimeError("`DEVICE_ID` can not be None.") + +device_id = int(device_id) +context.set_context( + mode=context.GRAPH_MODE, + device_target="Ascend", + reserve_class_name_in_scope=False, + device_id=device_id) + + +def get_config(config): + config = TransformerConfig.from_json_file(config) + config.compute_type = mstype.float16 + config.dtype = mstype.float32 + return config + + +def _train(model, config: TransformerConfig, + pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None, + callbacks: list = None): + """ + Train model. + + Args: + model (Model): MindSpore model instance. + config (TransformerConfig): Config of mass model. + pre_training_dataset (Dataset): Pre-training dataset. + fine_tune_dataset (Dataset): Fine-tune dataset. + test_dataset (Dataset): Test dataset. + callbacks (list): A list of callbacks. + """ + callbacks = callbacks if callbacks else [] + + if pre_training_dataset is not None: + print(" | Start pre-training job.") + epoch_size = config.epochs * pre_training_dataset.get_dataset_size() // config.dataset_sink_step + + if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1: + print(f" | Rank {MultiAscend.get_rank()} Call model train.") + + model.train(epoch_size, pre_training_dataset, + callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, + sink_size=config.dataset_sink_step) + + # Test the accuracy of the model. + if test_dataset is not None: + print(" | Start test job.") + result = infer(_config) + with open("validation_res_after_pre_training.bin", "wb") as f: + pickle.dump(result, f, 1) + + if fine_tune_dataset is not None: + print(" | Start fine-tuning job.") + epoch_size = config.epochs * fine_tune_dataset.get_dataset_size() // config.dataset_sink_step + + model.train(epoch_size, fine_tune_dataset, + callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, + sink_size=config.dataset_sink_step) + + # Test the accuracy of the model. + if test_dataset is not None: + print(" | Start test job.") + result = infer(_config) + with open("validation_res_after_pre_training.bin", "wb") as f: + pickle.dump(result, f, 1) + + +def _build_training_pipeline(config: TransformerConfig, + pre_training_dataset=None, + fine_tune_dataset=None, + test_dataset=None): + """ + Build training pipeline. + + Args: + config (TransformerConfig): Config of mass model. + pre_training_dataset (Dataset): Pre-training dataset. + fine_tune_dataset (Dataset): Fine-tune dataset. + test_dataset (Dataset): Test dataset. + """ + net_with_loss = TransformerNetworkWithLoss(config, is_training=True) + net_with_loss.init_parameters_data() + + if config.existed_ckpt: + if config.existed_ckpt.endswith(".npz"): + weights = np.load(config.existed_ckpt) + else: + weights = load_checkpoint(config.existed_ckpt) + for param in net_with_loss.trainable_params(): + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"Param {weights_name} is not found in ckpt file.") + + if isinstance(weights[weights_name], Parameter): + param.default_input = weights[weights_name].default_input + elif isinstance(weights[weights_name], Tensor): + param.default_input = Tensor(weights[weights_name].asnumpy(), config.dtype) + elif isinstance(weights[weights_name], np.ndarray): + param.default_input = Tensor(weights[weights_name], config.dtype) + else: + param.default_input = weights[weights_name] + else: + for param in net_with_loss.trainable_params(): + name = param.name + value = param.default_input + if isinstance(value, Tensor): + if name.endswith(".gamma"): + param.default_input = one_weight(value.asnumpy().shape) + elif name.endswith(".beta") or name.endswith(".bias"): + param.default_input = zero_weight(value.asnumpy().shape) + else: + param.default_input = weight_variable(value.asnumpy().shape) + + dataset = pre_training_dataset if pre_training_dataset is not None \ + else fine_tune_dataset + + if dataset is None: + raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.") + + update_steps = dataset.get_repeat_count() * dataset.get_dataset_size() + if config.lr_scheduler == "isr": + lr = Tensor(square_root_schedule(lr=config.lr, + update_num=update_steps, + decay_start_step=config.decay_start_step, + warmup_steps=config.warmup_steps, + min_lr=config.min_lr), dtype=mstype.float32) + elif config.lr_scheduler == "poly": + lr = Tensor(polynomial_decay_scheduler(lr=config.lr, + min_lr=config.min_lr, + decay_steps=config.decay_steps, + total_update_num=update_steps, + warmup_steps=config.warmup_steps, + power=config.poly_lr_scheduler_power), dtype=mstype.float32) + else: + lr = config.lr + + if config.optimizer.lower() == "adam": + optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98) + elif config.optimizer.lower() == "lamb": + lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr, + power=10.0, warmup_steps=config.warmup_steps) + decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + net_with_loss.trainable_params())) + other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(), + net_with_loss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}] + + optimizer = Lamb(group_params, lr, eps=1e-6) + elif config.optimizer.lower() == "momentum": + optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9) + else: + raise ValueError(f"optimizer only support `adam` and `momentum` now.") + + # Dynamic loss scale. + scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale, + scale_factor=config.loss_scale_factor, + scale_window=config.scale_window) + net_with_grads = TransformerTrainOneStepWithLossScaleCell( + network=net_with_loss, optimizer=optimizer, + scale_update_cell=scale_manager.get_update_cell() + ) + net_with_grads.set_train(True) + model = Model(net_with_grads) + loss_monitor = LossCallBack(config) + ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps, + keep_checkpoint_max=config.keep_ckpt_max) + + rank_size = os.getenv('RANK_SIZE') + callbacks = [loss_monitor] + if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0: + ckpt_callback = ModelCheckpoint( + prefix=config.ckpt_prefix, + directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))), + config=ckpt_config) + callbacks.append(ckpt_callback) + + if rank_size is None or int(rank_size) == 1: + ckpt_callback = ModelCheckpoint( + prefix=config.ckpt_prefix, + directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))), + config=ckpt_config) + callbacks.append(ckpt_callback) + + print(f" | ALL SET, PREPARE TO TRAIN.") + _train(model=model, config=config, + pre_training_dataset=pre_training_dataset, + fine_tune_dataset=fine_tune_dataset, + test_dataset=test_dataset, + callbacks=callbacks) + + +def _setup_parallel_env(): + context.reset_auto_parallel_context() + MultiAscend.init() + context.set_auto_parallel_context( + parallel_mode=ParallelMode.DATA_PARALLEL, + device_num=MultiAscend.get_group_size(), + parameter_broadcast=True, + mirror_mean=True + ) + + +def train_parallel(config: TransformerConfig): + """ + Train model with multi ascend chips. + + Args: + config (TransformerConfig): Config for MASS model. + """ + _setup_parallel_env() + + print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.") + + pre_train_dataset = load_dataset( + data_files=config.pre_train_dataset, + batch_size=config.batch_size, epoch_count=1, + sink_mode=config.dataset_sink_mode, + sink_step=config.dataset_sink_step, + rank_size=MultiAscend.get_group_size(), + rank_id=MultiAscend.get_rank() + ) if config.pre_train_dataset else None + fine_tune_dataset = load_dataset( + data_files=config.fine_tune_dataset, + batch_size=config.batch_size, epoch_count=1, + sink_mode=config.dataset_sink_mode, + sink_step=config.dataset_sink_step, + rank_size=MultiAscend.get_group_size(), + rank_id=MultiAscend.get_rank() + ) if config.fine_tune_dataset else None + test_dataset = load_dataset( + data_files=config.test_dataset, + batch_size=config.batch_size, epoch_count=1, + sink_mode=config.dataset_sink_mode, + sink_step=config.dataset_sink_step, + rank_size=MultiAscend.get_group_size(), + rank_id=MultiAscend.get_rank() + ) if config.test_dataset else None + + _build_training_pipeline(config=config, + pre_training_dataset=pre_train_dataset, + fine_tune_dataset=fine_tune_dataset, + test_dataset=test_dataset) + + +def train_single(config: TransformerConfig): + """ + Train model on single device. + + Args: + config (TransformerConfig): Config for model. + """ + print(" | Starting training on single device.") + pre_train_dataset = load_dataset(data_files=config.pre_train_dataset, + batch_size=config.batch_size, + epoch_count=1, + sink_mode=config.dataset_sink_mode, + sink_step=config.dataset_sink_step) if config.pre_train_dataset else None + fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset, + batch_size=config.batch_size, + epoch_count=1, + sink_mode=config.dataset_sink_mode, + sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None + test_dataset = load_dataset(data_files=config.test_dataset, + batch_size=config.batch_size, + epoch_count=1, + sink_mode=config.dataset_sink_mode, + sink_step=config.dataset_sink_step) if config.test_dataset else None + + _build_training_pipeline(config=config, + pre_training_dataset=pre_train_dataset, + fine_tune_dataset=fine_tune_dataset, + test_dataset=test_dataset) + + +def _check_args(config): + if not os.path.exists(config): + raise FileNotFoundError("`config` is not existed.") + if not isinstance(config, str): + raise ValueError("`config` must be type of str.") + + +if __name__ == '__main__': + _rank_size = os.getenv('RANK_SIZE') + + args, _ = parser.parse_known_args() + _check_args(args.config) + _config = get_config(args.config) + + np.random.seed(_config.random_seed) + context.set_context(save_graphs=_config.save_graphs) + + if _rank_size is not None and int(_rank_size) > 1: + train_parallel(_config) + else: + train_single(_config) diff --git a/model_zoo/mass/weights_average.py b/model_zoo/official/nlp/mass/weights_average.py similarity index 100% rename from model_zoo/mass/weights_average.py rename to model_zoo/official/nlp/mass/weights_average.py diff --git a/model_zoo/official/nlp/tinybert/README.md b/model_zoo/official/nlp/tinybert/README.md new file mode 100644 index 0000000000..3d1e990223 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/README.md @@ -0,0 +1,129 @@ +# TinyBERT Example +## Description +[TinyBERT](https://github.com/huawei-noah/Pretrained-Model/tree/master/TinyBERT) is 7.5x smalller and 9.4x faster on inference than [BERT-base](https://github.com/google-research/bert) (the base version of BERT model) and achieves competitive performances in the tasks of natural language understanding. It performs a novel transformer distillation at both the pre-training and task-specific learning stages. + +## Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download dataset for general distill and task distill such as GLUE. +- Prepare a pre-trained bert model and a fine-tuned bert model for specific task such as GLUE. + +## Running the Example +### General Distill +- Set options in `src/gd_config.py`, including lossscale, optimizer and network. + +- Set options in `scripts/run_standalone_gd.sh`, including device target, data sink config, checkpoint config and dataset. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file. + +- Run `run_standalone_gd.sh` for non-distributed general distill of BERT-base model. + + ``` bash + bash scripts/run_standalone_gd.sh + ``` +- Run `run_distribute_gd.sh` for distributed general distill of BERT-base model. + + ``` bash + bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE RANK_TABLE_FILE + ``` + +### Task Distill +Task distill has two phases, pre-distill and task distill. +- Set options in `src/td_config.py`, including lossscale, optimizer config of phase 1 and 2, as well as network config. + +- Run `run_standalone_td.py` for task distill of BERT-base model. + + ```bash + bash scripts/run_standalone_td.sh + ``` + +## Usage +### General Distill +``` +usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET] + [--epoch_size N] [--device_id N] + [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] + [--save_checkpoint_steps N] [--max_ckpt_num N] + [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH] + [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] + +options: + --distribute whether to run distributely: "true" | "false" + --device_target target device to run, currently only support "Ascend" + --epoch_size epoch size: N, default is 1 + --device_id device id: N, default is 0 + --enable_data_sink enable data sink: "true" | "false", default is "true" + --data_sink_steps set data sink steps: N, default is 1 + --load_teacher_ckpt_path path of teacher checkpoint to load: PATH, default is "" + --data_dir path to dataset directory: PATH, default is "" + --schema_dir path to schema.json file, PATH, default is "" + +usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET] + [--epoch_size N] [--device_id N] [--device_num N] + [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] + [--save_ckpt_steps N] [--max_ckpt_num N] + [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH] + [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] + +options: + --distribute whether to run distributely: "true" | "false" + --device_target target device to run, currently only support "Ascend" + --epoch_size epoch size: N, default is 1 + --device_id device id: N, default is 0 + --device_num device id to run task + --enable_data_sink enable data sink: "true" | "false", default is "true" + --data_sink_steps set data sink steps: N, default is 1 + --load_teacher_ckpt_path path of teacher checkpoint to load: PATH, default is "" + --data_dir path to dataset directory: PATH, default is "" + --schema_dir path to schema.json file, PATH, default is "" + +``` + +## Options and Parameters +`gd_config.py` and `td_config.py` Contain parameters of BERT model and options for optimizer and lossscale. +### Options: +``` +Parameters for lossscale: + loss_scale_value initial value of loss scale: N, default is 2^8 + scale_factor factor used to update loss scale: N, default is 2 + scale_window steps for once updatation of loss scale: N, default is 50 + +Parameters for task-specific config: + load_teacher_ckpt_path teacher checkpoint to load + load_student_ckpt_path student checkpoint to load + data_dir training data dir + eval_data_dir evaluation data dir + schema_dir data schema path +``` + +### Parameters: +``` +Parameters for bert network: + batch_size batch size of input dataset: N, default is 16 + seq_length length of input sequence: N, default is 128 + vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 30522 + hidden_size size of bert encoder layers: N + num_hidden_layers number of hidden layers: N + num_attention_heads number of attention heads: N, default is 12 + intermediate_size size of intermediate layer: N + hidden_act activation function used: ACTIVATION, default is "gelu" + hidden_dropout_prob dropout probability for BertOutput: Q + attention_probs_dropout_prob dropout probability for BertAttention: Q + max_position_embeddings maximum length of sequences: N, default is 512 + save_ckpt_step number for saving checkponit: N, default is 100 + max_ckpt_num maximum number for saving checkpoint: N, default is 1 + type_vocab_size size of token type vocab: N, default is 2 + initializer_range initialization value of TruncatedNormal: Q, default is 0.02 + use_relative_positions use relative positions or not: True | False, default is False + input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True + token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True + dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32 + compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 + enable_fused_layernorm use batchnorm instead of layernorm to improve performance, default is False + +Parameters for optimizer: + optimizer optimizer used in the network: AdamWeightDecay + learning_rate value of learning rate: Q + end_learning_rate value of end learning rate: Q, must be positive + power power: Q + weight_decay weight decay: Q + eps term added to the denominator to improve numerical stability: Q +``` + diff --git a/model_zoo/official/nlp/tinybert/__init__.py b/model_zoo/official/nlp/tinybert/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py new file mode 100644 index 0000000000..50e586f0af --- /dev/null +++ b/model_zoo/official/nlp/tinybert/run_general_distill.py @@ -0,0 +1,126 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""general distill script""" + +import os +import argparse +import datetime +import numpy +import mindspore.communication.management as D +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.callback import TimeMonitor +from mindspore.train.parallel_utils import ParallelMode +from mindspore.nn.optim import AdamWeightDecay +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from src.dataset import create_tinybert_dataset +from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate +from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg +from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd + +def run_general_distill(): + """ + run general distill + """ + parser = argparse.ArgumentParser(description='tinybert general distill') + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") + parser.add_argument("--epoch_size", type=int, default="3", help="Epoch size, default is 1.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.") + parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.") + parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") + parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.") + parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + context.set_context(reserve_class_name_in_scope=False) + context.set_context(variable_memory_max_size="30GB") + + save_ckpt_dir = os.path.join(args_opt.save_ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + + if not os.path.exists(save_ckpt_dir): + os.makedirs(save_ckpt_dir) + + if args_opt.distribute == "true": + D.init('hccl') + device_num = args_opt.device_num + rank = args_opt.device_id % device_num + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=device_num) + else: + rank = 0 + device_num = 1 + + netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg, + teacher_ckpt=args_opt.load_teacher_ckpt_path, + student_config=bert_student_net_cfg, + is_training=True, use_one_hot_embeddings=False) + + dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank, + args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) + + dataset_size = dataset.get_dataset_size() + print('dataset size: ', dataset_size) + if args_opt.enable_data_sink == "true": + repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps + time_monitor_steps = args_opt.data_sink_steps + else: + repeat_count = args_opt.epoch_size + time_monitor_steps = dataset_size + + lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(dataset_size * args_opt.epoch_size / 10), + decay_steps=int(dataset_size * args_opt.epoch_size), + power=common_cfg.AdamWeightDecay.power) + params = netwithloss.trainable_params() + decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps) + + callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + save_ckpt_dir)] + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value, + scale_factor=common_cfg.scale_factor, + scale_window=common_cfg.scale_window) + + netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + model.train(repeat_count, dataset, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == "true"), + sink_size=args_opt.data_sink_steps) + +if __name__ == '__main__': + numpy.random.seed(0) + run_general_distill() diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py new file mode 100644 index 0000000000..9469c475d2 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -0,0 +1,252 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""task distill script""" + +import os +import re +import argparse +from mindspore import Tensor +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.callback import TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecay +from src.dataset import create_tinybert_dataset +from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate +from src.assessment_method import Accuracy +from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg +from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td +from src.tinybert_model import BertModelCLS + +_cur_dir = os.getcwd() +td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt') +td_phase2_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase2_save_ckpt') +if not os.path.exists(td_phase1_save_ckpt_dir): + os.makedirs(td_phase1_save_ckpt_dir) +if not os.path.exists(td_phase2_save_ckpt_dir): + os.makedirs(td_phase2_save_ckpt_dir) + +def parse_args(): + """ + parse args + """ + parser = argparse.ArgumentParser(description='tinybert task distill') + parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.") + parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.") + parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.") + parser.add_argument("--td_phase1_epoch_size", type=int, default=10, + help="Epoch size for td phase 1, default is 10.") + parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.") + parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") + parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.") + parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.") + parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.") + parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_gd_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_td1_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") + + args = parser.parse_args() + return args + +args_opt = parse_args() +def run_predistill(): + """ + run predistill + """ + cfg = phase1_cfg + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + context.set_context(reserve_class_name_in_scope=False) + load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path + load_student_checkpoint_path = args_opt.load_gd_ckpt_path + netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, + student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, + is_training=True, task_type='classification', + num_labels=args_opt.num_labels, is_predistill=True) + + rank = 0 + device_num = 1 + dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, + device_num, rank, args_opt.do_shuffle, + args_opt.train_data_dir, args_opt.schema_dir) + + dataset_size = dataset.get_dataset_size() + print('td1 dataset size: ', dataset_size) + if args_opt.enable_data_sink == 'true': + repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps + time_monitor_steps = args_opt.data_sink_steps + else: + repeat_count = args_opt.td_phase1_epoch_size + time_monitor_steps = dataset_size + + optimizer_cfg = cfg.optimizer_cfg + + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(dataset_size / 10), + decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size), + power=optimizer_cfg.AdamWeightDecay.power) + params = netwithloss.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + td_phase1_save_ckpt_dir)] + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + model.train(repeat_count, dataset, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == 'true'), + sink_size=args_opt.data_sink_steps) + +def run_task_distill(ckpt_file): + """ + run task distill + """ + if ckpt_file == '': + raise ValueError("Student ckpt file should not be None") + cfg = phase2_cfg + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path + load_student_checkpoint_path = ckpt_file + netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, + student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, + is_training=True, task_type='classification', + num_labels=args_opt.num_labels, is_predistill=False) + + rank = 0 + device_num = 1 + train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, + device_num, rank, args_opt.do_shuffle, + args_opt.train_data_dir, args_opt.schema_dir) + + dataset_size = train_dataset.get_dataset_size() + print('td2 train dataset size: ', dataset_size) + if args_opt.enable_data_sink == 'true': + repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps + time_monitor_steps = args_opt.data_sink_steps + else: + repeat_count = args_opt.td_phase2_epoch_size + time_monitor_steps = dataset_size + + optimizer_cfg = cfg.optimizer_cfg + + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10), + decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size), + power=optimizer_cfg.AdamWeightDecay.power) + params = netwithloss.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + + eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, + device_num, rank, args_opt.do_shuffle, + args_opt.eval_data_dir, args_opt.schema_dir) + + if args_opt.do_eval.lower() == "true": + callback = [TimeMonitor(time_monitor_steps), LossCallBack(), + EvalCallBack(netwithloss.bert, eval_dataset)] + else: + callback = [TimeMonitor(time_monitor_steps), LossCallBack(), + ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + td_phase2_save_ckpt_dir)] + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + + netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + model.train(repeat_count, train_dataset, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == 'true'), + sink_size=args_opt.data_sink_steps) + +def do_eval_standalone(): + """ + do eval standalone + """ + ckpt_file = args_opt.load_td1_ckpt_path + if ckpt_file == '': + raise ValueError("Student ckpt file should not be None") + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student") + param_dict = load_checkpoint(ckpt_file) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('tinybert_', 'bert_', key) + new_key = re.sub('^bert.', '', new_key) + new_param_dict[new_key] = value + load_param_into_net(eval_model, new_param_dict) + eval_model.set_train(False) + + eval_dataset = create_tinybert_dataset('td', batch_size=1, + device_num=1, rank=0, do_shuffle="false", + data_dir=args_opt.eval_data_dir, + schema_dir=args_opt.schema_dir) + callback = Accuracy() + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in eval_dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + logits = eval_model(input_ids, token_type_id, input_mask) + callback.update(logits[3], label_ids) + acc = callback.acc_num / callback.total_num + print("======================================") + print("============== acc is {}".format(acc)) + print("======================================") + +if __name__ == '__main__': + if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true": + raise ValueError("do_train or do eval must have one be true, please confirm your config") + if args_opt.do_train == "true": + # run predistill + run_predistill() + lists = os.listdir(td_phase1_save_ckpt_dir) + if lists: + lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir+'/'+fn)) + name_ext = os.path.splitext(lists[-1]) + if name_ext[-1] != ".ckpt": + raise ValueError("Invalid file, checkpoint file should be .ckpt file") + newest_ckpt_file = os.path.join(td_phase1_save_ckpt_dir, lists[-1]) + # run task distill + run_task_distill(newest_ckpt_file) + else: + raise ValueError("Checkpoint file not exists, please make sure ckpt file has been saved") + else: + do_eval_standalone() diff --git a/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd.sh b/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd.sh new file mode 100644 index 0000000000..e4c15ebf92 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE RANK_TABLE_FILE" +echo "for example: bash scripts/run_distribute_gd.sh 8 40 /path/hccl.json" +echo "It is better to use absolute path." +echo "running....... please see details by LOG{}/log.txt" +echo "==============================================================================================================" + +EPOCH_SIZE=$2 + +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export RANK_TABLE_FILE=$3 +export RANK_SIZE=$1 +cores=`cat /proc/cpuinfo|grep "processor" |wc -l` +echo "the number of logical core" $cores +avg_core_per_rank=`expr $cores \/ $RANK_SIZE` +core_gap=`expr $avg_core_per_rank \- 1` +echo "avg_core_per_rank" $avg_core_per_rank +echo "core_gap" $core_gap +for((i=0;i env.log + taskset -c $cmdopt python ${PROJECT_DIR}/../run_general_distill.py \ + --distribute="true" \ + --device_target="Ascend" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --device_num=$RANK_SIZE \ + --enable_data_sink="true" \ + --data_sink_steps=100 \ + --save_ckpt_step=10000 \ + --max_ckpt_num=1 \ + --load_teacher_ckpt_path="" \ + --data_dir="" \ + --schema_dir="" > log.txt 2>&1 & + cd ../ +done diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd.sh new file mode 100644 index 0000000000..343d1ed7ca --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_standalone_gd.sh" +echo "for example: bash scripts/run_standalone_gd.sh" +echo "running....... please see details by log.txt" +echo "==============================================================================================================" + + +mkdir -p ms_log +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_general_distill.py \ + --distribute="false" \ + --device_target="Ascend" \ + --epoch_size=3 \ + --device_id=0 \ + --enable_data_sink="true" \ + --data_sink_steps=100 \ + --save_ckpt_step=100 \ + --max_ckpt_num=1 \ + --save_ckpt_path="" \ + --load_teacher_ckpt_path="" \ + --data_dir="" \ + --schema_dir="" > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh new file mode 100644 index 0000000000..dcc01163db --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scipts/run_standalone_td.sh" +echo "for example: bash scipts/run_standalone_td.sh" +echo "==============================================================================================================" + +mkdir -p ms_log +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_task_distill.py \ + --device_target="Ascend" \ + --device_id=0 \ + --do_train="true" \ + --do_eval="true" \ + --td_phase1_epoch_size=10 \ + --td_phase2_epoch_size=3 \ + --num_labels=2 \ + --do_shuffle="true" \ + --enable_data_sink="true" \ + --data_sink_steps=100 \ + --save_ckpt_step=100 \ + --max_ckpt_num=1 \ + --load_teacher_ckpt_path="" \ + --load_gd_ckpt_path="" \ + --load_td1_ckpt_path="" \ + --train_data_dir="" \ + --eval_data_dir="" \ + --schema_dir="" > log.txt 2>&1 & + diff --git a/model_zoo/official/nlp/tinybert/src/__init__.py b/model_zoo/official/nlp/tinybert/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/nlp/tinybert/src/assessment_method.py b/model_zoo/official/nlp/tinybert/src/assessment_method.py new file mode 100644 index 0000000000..748666e3ce --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/assessment_method.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""assessment methods""" + +import numpy as np + +class Accuracy(): + """Accuracy""" + def __init__(self): + self.acc_num = 0 + self.total_num = 0 + + def update(self, logits, labels): + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + self.acc_num += np.sum(labels == logit_id) + self.total_num += len(labels) + +class F1(): + """F1""" + def __init__(self): + self.TP = 0 + self.FP = 0 + self.FN = 0 + + def update(self, logits, labels): + """Update F1 score""" + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7]) + pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7]) + self.TP += np.sum(pos_eva & pos_label) + self.FP += np.sum(pos_eva & (~pos_label)) + self.FN += np.sum((~pos_eva) & pos_label) + print("-----------------precision is ", self.TP / (self.TP + self.FP)) + print("-----------------recall is ", self.TP / (self.TP + self.FN)) diff --git a/model_zoo/official/nlp/tinybert/src/dataset.py b/model_zoo/official/nlp/tinybert/src/dataset.py new file mode 100644 index 0000000000..fdc0dfe21e --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/dataset.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""create tinybert dataset""" + +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as C +from mindspore import log as logger + +def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, + do_shuffle="true", data_dir=None, schema_dir=None): + """create tinybert dataset""" + files = os.listdir(data_dir) + data_files = [] + for file_name in files: + if "record" in file_name: + data_files.append(os.path.join(data_dir, file_name)) + if task == "td": + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + else: + columns_list = ["input_ids", "input_mask", "segment_ids"] + + ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, + shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, + shard_equal_rows=True) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + if task == "td": + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + logger.info("data size: {}".format(ds.get_dataset_size())) + logger.info("repeatcount: {}".format(ds.get_repeat_count())) + + return ds diff --git a/model_zoo/official/nlp/tinybert/src/fused_layer_norm.py b/model_zoo/official/nlp/tinybert/src/fused_layer_norm.py new file mode 100644 index 0000000000..d290842c58 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/fused_layer_norm.py @@ -0,0 +1,122 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""fused layernorm""" +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.ops.primitive import constexpr +import mindspore.common.dtype as mstype +from mindspore.nn.cell import Cell + +import numpy as np + + +__all__ = ['FusedLayerNorm'] + +@constexpr +def get_shape_for_norm(x_shape, begin_norm_axis): + print("input_shape: ", x_shape) + norm_shape = x_shape[begin_norm_axis:] + output_shape = (1, -1, 1, int(np.prod(norm_shape))) + print("output_shape: ", output_shape) + return output_shape + +class FusedLayerNorm(Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + + Layer normalization is widely used in recurrent neural networks. It applies + normalization over a mini-batch of inputs for each single training case as described + in the paper `Layer Normalization `_. Unlike batch + normalization, layer normalization performs exactly the same computation at training and + testing times. It can be described using the following formula. It is applied across all channels + and pixel but only one batch size. + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Args: + normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis + `begin_norm_axis ... R - 1`. + begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions + `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1. + begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters + will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with + the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'ones'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'zeros'. + use_batch_nrom (bool): Whether use batchnorm to preocess. + + Inputs: + - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`, + and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`. + + Outputs: + Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`. + + Examples: + >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) + >>> shape1 = x.shape[1:] + >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) + >>> m(x) + """ + def __init__(self, + normalized_shape, + begin_norm_axis=-1, + begin_params_axis=-1, + gamma_init='ones', + beta_init='zeros', + use_batch_norm=False): + super(FusedLayerNorm, self).__init__() + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}." + .format(normalized_shape, type(normalized_shape))) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.gamma = Parameter(initializer( + gamma_init, normalized_shape), name="gamma") + self.beta = Parameter(initializer( + beta_init, normalized_shape), name="beta") + self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis) + + self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5) + self.use_batch_norm = use_batch_norm + + def construct(self, input_x): + """fusedlayernorm""" + if self.use_batch_norm and self.training: + ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0) + zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0) + shape_x = F.shape(input_x) + norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis) + input_x = F.reshape(input_x, norm_shape) + output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None) + output = F.reshape(output, shape_x) + y = output * self.gamma + self.beta + else: + y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) + return y + + def extend_repr(self): + """Display instance object as string.""" + s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format( + self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta) + return s diff --git a/model_zoo/official/nlp/tinybert/src/gd_config.py b/model_zoo/official/nlp/tinybert/src/gd_config.py new file mode 100644 index 0000000000..d2dc09d8fa --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/gd_config.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in dataset.py, run_general_distill.py and run_task_distill.py +""" +import mindspore.common.dtype as mstype +from easydict import EasyDict as edict +from .tinybert_model import BertConfig + +common_cfg = edict({ + 'loss_scale_value': 2 ** 16, + 'scale_factor': 2, + 'scale_window': 1000, + 'AdamWeightDecay': edict({ + 'learning_rate': 5e-5, + 'end_learning_rate': 1e-14, + 'power': 1.0, + 'weight_decay': 1e-4, + 'eps': 1e-6, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), +}) +''' +Including two kinds of network: \ +teacher network: The BERT-base network. +student network: The network which is inherited from teacher network. +''' +bert_teacher_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) +bert_student_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=384, + num_hidden_layers=4, + num_attention_heads=12, + intermediate_size=1536, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) diff --git a/model_zoo/official/nlp/tinybert/src/td_config.py b/model_zoo/official/nlp/tinybert/src/td_config.py new file mode 100644 index 0000000000..2a9046587e --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/td_config.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""config script for task distill""" + +import mindspore.common.dtype as mstype +from easydict import EasyDict as edict +from .tinybert_model import BertConfig + +phase1_cfg = edict({ + 'loss_scale_value': 2 ** 8, + 'scale_factor': 2, + 'scale_window': 50, + 'optimizer_cfg': edict({ + 'AdamWeightDecay': edict({ + 'learning_rate': 5e-5, + 'end_learning_rate': 1e-14, + 'power': 1.0, + 'weight_decay': 1e-4, + 'eps': 1e-6, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), + }), +}) + +phase2_cfg = edict({ + 'loss_scale_value': 2 ** 16, + 'scale_factor': 2, + 'scale_window': 50, + 'optimizer_cfg': edict({ + 'AdamWeightDecay': edict({ + 'learning_rate': 2e-5, + 'end_learning_rate': 1e-14, + 'power': 1.0, + 'weight_decay': 1e-4, + 'eps': 1e-6, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), + }), +}) + +''' +Including two kinds of network: \ +teacher network: The BERT-base network with finetune. +student network: The model which is producted by GD phase. +''' +td_teacher_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) +td_student_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=384, + num_hidden_layers=4, + num_attention_heads=12, + intermediate_size=1536, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py new file mode 100644 index 0000000000..55da0f3db9 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py @@ -0,0 +1,498 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tinybert model""" + +import re +import mindspore.nn as nn +from mindspore import context +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from mindspore.communication.management import get_group_size +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from .tinybert_model import BertModel, TinyBertModel, BertModelCLS + + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") +# pylint: disable=consider-using-in +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + +class ClipGradients(nn.Cell): + """ + Clip gradients. + + Args: + grads (list): List of gradient tuples. + clip_type (Tensor): The way to clip, 'value' or 'norm'. + clip_value (Tensor): Specifies how much to clip. + + Returns: + List, a list of clipped_grad tuples. + """ + def __init__(self): + super(ClipGradients, self).__init__() + self.clip_by_norm = nn.ClipByNorm() + self.cast = P.Cast() + self.dtype = P.DType() + + def construct(self, + grads, + clip_type, + clip_value): + """clip gradients""" + if clip_type != 0 and clip_type != 1: + return grads + new_grads = () + for grad in grads: + dt = self.dtype(grad) + if clip_type == 0: + t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt), + self.cast(F.tuple_to_array((clip_value,)), dt)) + else: + t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt)) + new_grads = new_grads + (t,) + return new_grads + +class SoftCrossEntropy(nn.Cell): + """SoftCrossEntropy loss""" + def __init__(self): + super(SoftCrossEntropy, self).__init__() + self.log_softmax = P.LogSoftmax(axis=-1) + self.softmax = P.Softmax(axis=-1) + self.reduce_mean = P.ReduceMean() + self.cast = P.Cast() + + def construct(self, predicts, targets): + likelihood = self.log_softmax(predicts) + target_prob = self.softmax(targets) + loss = self.reduce_mean(-target_prob * likelihood) + + return self.cast(loss, mstype.float32) + +class BertNetworkWithLoss_gd(nn.Cell): + """ + Provide bert pre-training loss through network. + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + Returns: + Tensor, the loss of the network. + """ + def __init__(self, teacher_config, teacher_ckpt, student_config, is_training, use_one_hot_embeddings=False, + is_att_fit=True, is_rep_fit=True): + super(BertNetworkWithLoss_gd, self).__init__() + # load teacher model + self.teacher = BertModel(teacher_config, False, use_one_hot_embeddings) + param_dict = load_checkpoint(teacher_ckpt) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('^bert.bert.', 'teacher.', key) + new_param_dict[new_key] = value + load_param_into_net(self.teacher, new_param_dict) + # no_grad + self.teacher.set_train(False) + params = self.teacher.trainable_params() + for param in params: + param.requires_grad = False + # student model + self.bert = TinyBertModel(student_config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.fit_dense = nn.Dense(student_config.hidden_size, + teacher_config.hidden_size).to_float(teacher_config.compute_type) + self.teacher_layers_num = teacher_config.num_hidden_layers + self.student_layers_num = student_config.num_hidden_layers + self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num) + self.is_att_fit = is_att_fit + self.is_rep_fit = is_rep_fit + self.loss_mse = nn.MSELoss() + self.select = P.Select() + self.zeroslike = P.ZerosLike() + self.dtype = teacher_config.dtype + + def construct(self, + input_ids, + input_mask, + token_type_id): + """general distill network with loss""" + # teacher model + _, _, _, teacher_seq_output, teacher_att_output = self.teacher(input_ids, token_type_id, input_mask) + # student model + _, _, _, student_seq_output, student_att_output = self.bert(input_ids, token_type_id, input_mask) + total_loss = 0 + if self.is_att_fit: + selected_teacher_att_output = () + selected_student_att_output = () + for i in range(self.student_layers_num): + selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],) + selected_student_att_output += (student_att_output[i],) + att_loss = 0 + for i in range(self.student_layers_num): + student_att = selected_student_att_output[i] + teacher_att = selected_teacher_att_output[i] + student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), self.zeroslike(student_att), + student_att) + teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), self.zeroslike(teacher_att), + teacher_att) + att_loss += self.loss_mse(student_att, teacher_att) + total_loss += att_loss + if self.is_rep_fit: + selected_teacher_seq_output = () + selected_student_seq_output = () + for i in range(self.student_layers_num + 1): + selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],) + fit_dense_out = self.fit_dense(student_seq_output[i]) + fit_dense_out = self.cast(fit_dense_out, self.dtype) + selected_student_seq_output += (fit_dense_out,) + rep_loss = 0 + for i in range(self.student_layers_num + 1): + teacher_rep = selected_teacher_seq_output[i] + student_rep = selected_student_seq_output[i] + rep_loss += self.loss_mse(student_rep, teacher_rep) + total_loss += rep_loss + return self.cast(total_loss, mstype.float32) + +class BertTrainWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) + +class BertNetworkWithLoss_td(nn.Cell): + """ + Provide bert pre-training loss through network. + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + Returns: + Tensor, the loss of the network. + """ + def __init__(self, teacher_config, teacher_ckpt, student_config, student_ckpt, + is_training, task_type, num_labels, use_one_hot_embeddings=False, + is_predistill=True, is_att_fit=True, is_rep_fit=True, + temperature=1.0, dropout_prob=0.1): + super(BertNetworkWithLoss_td, self).__init__() + # load teacher model + self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob, + use_one_hot_embeddings, "teacher") + param_dict = load_checkpoint(teacher_ckpt) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('^bert.', 'teacher.', key) + new_param_dict[new_key] = value + load_param_into_net(self.teacher, new_param_dict) + + # no_grad + self.teacher.set_train(False) + params = self.teacher.trainable_params() + for param in params: + param.requires_grad = False + # load student model + self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob, + use_one_hot_embeddings, "student") + param_dict = load_checkpoint(student_ckpt) + if is_predistill: + new_param_dict = {} + for key, value in param_dict.items(): + # new_key = re.sub('tinybert_', 'bert_', key) + new_key = re.sub('tinybert_', 'bert_', 'bert.' + key) + new_param_dict[new_key] = value + load_param_into_net(self.bert, new_param_dict) + else: + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('tinybert_', 'bert_', key) + # new_key = re.sub('tinybert_', 'bert_', 'bert.'+ key) + new_param_dict[new_key] = value + load_param_into_net(self.bert, new_param_dict) + self.cast = P.Cast() + self.fit_dense = nn.Dense(student_config.hidden_size, + teacher_config.hidden_size).to_float(teacher_config.compute_type) + self.teacher_layers_num = teacher_config.num_hidden_layers + self.student_layers_num = student_config.num_hidden_layers + self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num) + self.is_predistill = is_predistill + self.is_att_fit = is_att_fit + self.is_rep_fit = is_rep_fit + self.task_type = task_type + self.temperature = temperature + self.loss_mse = nn.MSELoss() + self.select = P.Select() + self.zeroslike = P.ZerosLike() + self.dtype = student_config.dtype + self.num_labels = num_labels + self.dtype = teacher_config.dtype + self.soft_cross_entropy = SoftCrossEntropy() + + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids): + """task distill network with loss""" + # teacher model + teacher_seq_output, teacher_att_output, teacher_logits, _ = self.teacher(input_ids, token_type_id, input_mask) + # student model + student_seq_output, student_att_output, student_logits, _ = self.bert(input_ids, token_type_id, input_mask) + total_loss = 0 + if self.is_predistill: + if self.is_att_fit: + selected_teacher_att_output = () + selected_student_att_output = () + for i in range(self.student_layers_num): + selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],) + selected_student_att_output += (student_att_output[i],) + att_loss = 0 + for i in range(self.student_layers_num): + student_att = selected_student_att_output[i] + teacher_att = selected_teacher_att_output[i] + student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), + self.zeroslike(student_att), + student_att) + teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), + self.zeroslike(teacher_att), + teacher_att) + att_loss += self.loss_mse(student_att, teacher_att) + total_loss += att_loss + if self.is_rep_fit: + selected_teacher_seq_output = () + selected_student_seq_output = () + for i in range(self.student_layers_num + 1): + selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],) + fit_dense_out = self.fit_dense(student_seq_output[i]) + fit_dense_out = self.cast(fit_dense_out, self.dtype) + selected_student_seq_output += (fit_dense_out,) + rep_loss = 0 + for i in range(self.student_layers_num + 1): + teacher_rep = selected_teacher_seq_output[i] + student_rep = selected_student_seq_output[i] + rep_loss += self.loss_mse(student_rep, teacher_rep) + total_loss += rep_loss + else: + if self.task_type == "classification": + cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature) + else: + cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1]) + total_loss += cls_loss + return self.cast(total_loss, mstype.float32) + +class BertEvaluationCell(nn.Cell): + """ + Especifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertEvaluationCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + label_ids) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + label_ids, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py new file mode 100644 index 0000000000..cc5477bc4f --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -0,0 +1,1054 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from .fused_layer_norm import FusedLayerNorm + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded + from dataset. Default: True. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float32, + enable_fused_layernorm=False): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.input_mask_from_dataset = input_mask_from_dataset + self.token_type_ids_from_dataset = token_type_ids_from_dataset + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + self.enable_fused_layernorm = enable_fused_layernorm + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + """embedding lookup""" + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. + token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + def __init__(self, + use_relative_positions, + embedding_size, + embedding_shape, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [token_type_vocab_size, + embedding_size]), + name='embedding_table') + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.dropout = nn.Dropout(1 - dropout_prob) + self.gather = P.GatherV2() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + self.full_position_embeddings = Parameter(initializer + (TruncatedNormal(initializer_range), + [max_position_embeddings, + embedding_size]), + name='full_position_embeddings') + + def construct(self, token_type_ids, word_embeddings): + """embedding postprocessor""" + output = word_embeddings + if self.use_token_type: + flat_ids = self.reshape(token_type_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, + self.token_type_vocab_size, self.on_value, self.off_value) + token_type_embeddings = self.array_mul(one_hot_ids, + self.embedding_table) + else: + token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) + token_type_embeddings = self.reshape(token_type_embeddings, self.shape) + output += token_type_embeddings + if not self.use_relative_positions: + _, seq, width = self.shape + position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) + position_embeddings = self.reshape(position_embeddings, (1, seq, width)) + output += position_embeddings + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertOutput, self).__init__() + self.dense = nn.Dense(in_channels, out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.add = P.TensorAdd() + if compute_type == mstype.float16: + self.layernorm = FusedLayerNorm((out_channels,), + use_batch_norm=enable_fused_layernorm).to_float(compute_type) + else: + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + """bert output""" + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(input_tensor, output) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. + """ + def __init__(self, length, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._length = length + self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32) + self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32) + self.range_length = -length + 1 + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self): + """position matrix generator""" + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) + tile_row_out = self.tile(range_vec_row_out, (self._length,)) + tile_col_out = self.tile(range_vec_col_out, (1, self._length)) + range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) + transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) + distance_mat = self.sub(range_mat_out, transpose_out) + distance_mat_clipped = C.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + length, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth]), + name='embeddings_for_position') + self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, + max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.shape = P.Shape() + self.gather = P.GatherV2() # index_select + self.matmul = P.BatchMatMul() + + def construct(self): + """position embedding generation""" + relative_positions_matrix_out = self.relative_positions_matrix() + # Generate embedding for each relative position of dimension depth. + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. + """ + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + min_type = np.finfo(np_type).min + max_type = np.finfo(np_type).max + self.tensor_min_type = Tensor([min_type], dtype=src_type) + self.tensor_max_type = Tensor([max_type], dtype=src_type) + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + """saturate cast""" + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=False, + use_relative_positions=False, + compute_type=mstype.float32): + super(BertAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + weight_init=weight).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + weight_init=weight).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + weight_init=weight).to_float(compute_type) + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = ( + batch_size, to_seq_length, num_attention_heads, size_per_head) + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(length=to_seq_length, + depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + """bert attention""" + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + attention_scores = self.matmul_trans_b(query_layer, key_layer) + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_keys' = [F|T, F|T, H] + relations_keys = self._generate_relative_positions_embeddings() + relations_keys = self.cast_compute_type(relations_keys) + # query_layer_t is [F, B, N, H] + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + # query_layer_r is [F, B * N, H] + query_layer_r = self.reshape(query_layer_t, + (self.from_seq_length, + self.batch_num, + self.size_per_head)) + # key_position_scores is [F, B * N, F|T] + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + # key_position_scores_r is [F, B, N, F|T] + key_position_scores_r = self.reshape(key_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.from_seq_length)) + # key_position_scores_r_t is [B, N, F, F|T] + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + attention_scores = self.multiply(self.scores_mul, attention_scores) + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_values' = [F|T, F|T, H] + relations_values = self._generate_relative_positions_embeddings() + relations_values = self.cast_compute_type(relations_values) + # attention_probs_t is [F, B, N, T] + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + # attention_probs_r is [F, B * N, T] + attention_probs_r = self.reshape( + attention_probs_t, + (self.from_seq_length, + self.batch_num, + self.to_seq_length)) + # value_position_scores is [F, B * N, H] + value_position_scores = self.matmul(attention_probs_r, + relations_values) + # value_position_scores_r is [F, B, N, H] + value_position_scores_r = self.reshape(value_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.size_per_head)) + # value_position_scores_r_t is [B, N, F, H] + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + return context_layer, attention_scores + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. + hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + self.size_per_head = int(hidden_size / num_attention_heads) + self.attention = BertAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + do_return_2d_tensor=True, + compute_type=compute_type) + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + """bert self attention""" + input_tensor = self.reshape(input_tensor, self.shape) + attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output, attention_scores + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the bert encoder layers. Default: 768. + seq_length (int): Length of input sequence. Default: 512. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size=768, + seq_length=512, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.intermediate = nn.Dense(in_channels=hidden_size, + out_channels=intermediate_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + def construct(self, hidden_states, attention_mask): + """bert encoder cell""" + # self-attention + attention_output, attention_scores = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output, attention_scores + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. + """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False, + enable_fused_layernorm=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + layers.append(layer) + self.layers = nn.CellList(layers) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + def construct(self, input_tensor, attention_mask): + """bert transformer""" + prev_output = self.reshape(input_tensor, self.shape) + all_encoder_layers = () + all_encoder_atts = () + all_encoder_outputs = () + all_encoder_outputs += (prev_output,) + for layer_module in self.layers: + layer_output, encoder_att = layer_module(prev_output, attention_mask) + prev_output = layer_output + if self.return_all_encoders: + all_encoder_outputs += (layer_output,) + layer_output = self.reshape(layer_output, self.out_shape) + all_encoder_layers += (layer_output,) + all_encoder_atts += (encoder_att,) + if not self.return_all_encoders: + prev_output = self.reshape(prev_output, self.out_shape) + all_encoder_layers += (prev_output,) + return all_encoder_layers, all_encoder_outputs, all_encoder_atts + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask_from_dataset = config.input_mask_from_dataset + self.input_mask = None + if not self.input_mask_from_dataset: + self.input_mask = initializer( + "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = (config.batch_size, 1, config.seq_length) + self.broadcast_ones = initializer( + "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() + self.batch_matmul = P.BatchMatMul() + def construct(self, input_mask): + if not self.input_mask_from_dataset: + input_mask = self.input_mask + input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) + attention_mask = self.batch_matmul(self.broadcast_ones, input_mask) + return attention_mask + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + self.bert_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + use_relative_positions=config.use_relative_positions, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.bert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """bert model""" + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + # bert encoder + encoder_output, encoder_layers, layer_atts = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + encoder_outputs = () + for output in encoder_layers: + encoder_outputs += (self.cast(output, self.dtype),) + attention_outputs = () + for output in layer_atts: + attention_outputs += (self.cast(output, self.dtype),) + return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs + + +class TinyBertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(TinyBertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + self.tinybert_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + self.tinybert_embedding_postprocessor = EmbeddingPostprocessor( + use_relative_positions=config.use_relative_positions, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.tinybert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """tiny bert model""" + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids) + embedding_output = self.tinybert_embedding_postprocessor(token_type_ids, + word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + # bert encoder + encoder_output, encoder_layers, layer_atts = self.tinybert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + encoder_outputs = () + for output in encoder_layers: + encoder_outputs += (self.cast(output, self.dtype),) + attention_outputs = () + for output in layer_atts: + attention_outputs += (self.cast(output, self.dtype),) + return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs + + +class BertModelCLS(nn.Cell): + """ + This class is responsible for classification task evaluation, + i.e. XNLI(num_labels=3), LCQMC(num_labels=2), Chnsenti(num_labels=2). + The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, + use_one_hot_embeddings=False, phase_type="teacher"): + super(BertModelCLS, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.weight_init = TruncatedNormal(config.initializer_range) + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.dtype + self.num_labels = num_labels + self.phase_type = phase_type + if self.phase_type == "teacher": + self.dense = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + else: + self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + self.dropout = nn.ReLU() + + def construct(self, input_ids, token_type_id, input_mask): + """classification bert model""" + _, pooled_output, _, seq_output, att_output = self.bert(input_ids, token_type_id, input_mask) + cls = self.cast(pooled_output, self.dtype) + cls = self.dropout(cls) + if self.phase_type == "teacher": + logits = self.dense(cls) + else: + logits = self.dense_1(cls) + logits = self.cast(logits, self.dtype) + log_probs = self.log_softmax(logits) + return seq_output, att_output, logits, log_probs diff --git a/model_zoo/official/nlp/tinybert/src/utils.py b/model_zoo/official/nlp/tinybert/src/utils.py new file mode 100644 index 0000000000..d10fb8642e --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/utils.py @@ -0,0 +1,140 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""tinybert utils""" + +import os +import numpy as np +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.train.callback import Callback +from mindspore.train.serialization import _exec_save_checkpoint +from mindspore.ops import operations as P +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR +from .assessment_method import Accuracy + +class ModelSaveCkpt(Callback): + """ + Saves checkpoint. + If the loss in NAN or INF terminating training. + Args: + network (Network): The train network for training. + save_ckpt_num (int): The number to save checkpoint, default is 1000. + max_ckpt_num (int): The max checkpoint number, default is 3. + """ + def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir): + super(ModelSaveCkpt, self).__init__() + self.count = 0 + self.network = network + self.save_ckpt_step = save_ckpt_step + self.max_ckpt_num = max_ckpt_num + self.output_dir = output_dir + + def step_end(self, run_context): + """step end and save ckpt""" + cb_params = run_context.original_args() + if cb_params.cur_step_num % self.save_ckpt_step == 0: + saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step + if saved_ckpt_num > self.max_ckpt_num: + oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num + path = os.path.join(self.output_dir, "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index), + self.save_ckpt_step)) + if os.path.exists(path): + os.remove(path) + _exec_save_checkpoint(self.network, os.path.join(self.output_dir, + "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), + self.save_ckpt_step))) + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self._per_print_times = per_print_times + + def step_end(self, run_context): + """step end and print loss""" + cb_params = run_context.original_args() + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, + cb_params.cur_step_num, + str(cb_params.net_outputs))) + +class EvalCallBack(Callback): + """Evaluation callback""" + def __init__(self, network, dataset): + super(EvalCallBack, self).__init__() + self.network = network + self.global_acc = 0.0 + self.dataset = dataset + + def step_end(self, run_context): + """step end and do evaluation""" + cb_params = run_context.original_args() + if cb_params.cur_step_num % 100 == 0: + callback = Accuracy() + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in self.dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + self.network.set_train(False) + logits = self.network(input_ids, token_type_id, input_mask) + callback.update(logits[3], label_ids) + acc = callback.acc_num / callback.total_num + with open("./eval.log", "a+") as f: + f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + f.write('\n') + + if acc > self.global_acc: + self.global_acc = acc + print("The best acc is {}".format(acc)) + _exec_save_checkpoint(self.network, "eval_model.ckpt") + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + decay_lr = self.decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr diff --git a/model_zoo/official/nlp/transformer/README.md b/model_zoo/official/nlp/transformer/README.md new file mode 100644 index 0000000000..6a969f3eb7 --- /dev/null +++ b/model_zoo/official/nlp/transformer/README.md @@ -0,0 +1,176 @@ +# Transformer Example +## Description +This example implements training and evaluation of Transformer Model, which is introduced in the following paper: +- Ashish Vaswani, Noam Shazeer, Niki Parmar, JakobUszkoreit, Llion Jones, Aidan N Gomez, Ł ukaszKaiser, and Illia Polosukhin. 2017. Attention is all you need. In NIPS 2017, pages 5998–6008. + +## Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download and preprocess the WMT English-German dataset for training and evaluation. + +> Notes:If you are running an evaluation task, prepare the corresponding checkpoint file. + +## Example structure + +```shell +. +└─Transformer + ├─README.md + ├─scripts + ├─process_output.sh + ├─replace-quote.perl + ├─run_distribute_train.sh + └─run_standalone_train.sh + ├─src + ├─__init__.py + ├─beam_search.py + ├─config.py + ├─dataset.py + ├─eval_config.py + ├─lr_schedule.py + ├─process_output.py + ├─tokenization.py + ├─transformer_for_train.py + ├─transformer_model.py + └─weight_init.py + ├─create_data.py + ├─eval.py + └─train.py +``` + +--- + +## Prepare the dataset +- You may use this [shell script](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh) to download and preprocess WMT English-German dataset. Assuming you get the following files: + - train.tok.clean.bpe.32000.en + - train.tok.clean.bpe.32000.de + - vocab.bpe.32000 + - newstest2014.tok.bpe.32000.en + - newstest2014.tok.bpe.32000.de + - newstest2014.tok.de + +- Convert the original data to mindrecord for training: + + ``` bash + paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all + python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128 + ``` +- Convert the original data to mindrecord for evaluation: + + ``` bash + paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all + python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True + ``` + +## Running the example + +### Training +- Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#mindspore) for more information about dataset. + +- Run `run_standalone_train.sh` for non-distributed training of Transformer model. + + ``` bash + sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH + ``` +- Run `run_distribute_train.sh` for distributed training of Transformer model. + + ``` bash + sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_PATH RANK_TABLE_FILE + ``` + +### Evaluation +- Set options in `eval_config.py`. Make sure the 'data_file', 'model_file' and 'output_file' are set to your own path. + +- Run `eval.py` for evaluation of Transformer model. + + ```bash + python eval.py + ``` + +- Run `process_output.sh` to process the output token ids to get the real translation results. + + ```bash + sh scripts/process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE + ``` + You will get two files, REF_DATA.forbleu and EVAL_OUTPUT.forbleu, for BLEU score calculation. + +- Calculate BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run following command to get the BLEU score. + + ```bash + perl multi-bleu.perl REF_DATA.forbleu < EVAL_OUTPUT.forbleu + ``` + +--- + +## Usage + +### Training +``` +usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] + [--enable_save_ckpt ENABLE_SAVE_CKPT] + [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] + [--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N] + [--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH] + [--data_path DATA_PATH] + +options: + --distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false" + --epoch_size epoch size: N, default is 52 + --device_num number of used devices: N, default is 1 + --device_id device id: N, default is 0 + --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" + --enable_lossscale enable lossscale: "true" | "false", default is "true" + --do_shuffle enable shuffle: "true" | "false", default is "true" + --enable_data_sink enable data sink: "true" | "false", default is "false" + --checkpoint_path path to load checkpoint files: PATH, default is "" + --save_checkpoint_steps steps for saving checkpoint files: N, default is 2500 + --save_checkpoint_num number for saving checkpoint files: N, default is 30 + --save_checkpoint_path path to save checkpoint files: PATH, default is "./checkpoint/" + --data_path path to dataset file: PATH, default is "" +``` + +## Options and Parameters +It contains of parameters of Transformer model and options for training and evaluation, which is set in file `config.py` and `evaluation_config.py` respectively. +### Options: +``` +config.py: + transformer_network version of Transformer model: base | large, default is large + init_loss_scale_value initial value of loss scale: N, default is 2^10 + scale_factor factor used to update loss scale: N, default is 2 + scale_window steps for once updatation of loss scale: N, default is 2000 + optimizer optimizer used in the network: Adam, default is "Adam" + +eval_config.py: + transformer_network version of Transformer model: base | large, default is large + data_file data file: PATH + model_file checkpoint file to be loaded: PATH + output_file output file of evaluation: PATH +``` + +### Parameters: +``` +Parameters for dataset and network (Training/Evaluation): + batch_size batch size of input dataset: N, default is 96 + seq_length length of input sequence: N, default is 128 + vocab_size size of each embedding vector: N, default is 36560 + hidden_size size of Transformer encoder layers: N, default is 1024 + num_hidden_layers number of hidden layers: N, default is 6 + num_attention_heads number of attention heads: N, default is 16 + intermediate_size size of intermediate layer: N, default is 4096 + hidden_act activation function used: ACTIVATION, default is "relu" + hidden_dropout_prob dropout probability for TransformerOutput: Q, default is 0.3 + attention_probs_dropout_prob dropout probability for TransformerAttention: Q, default is 0.3 + max_position_embeddings maximum length of sequences: N, default is 128 + initializer_range initialization value of TruncatedNormal: Q, default is 0.02 + label_smoothing label smoothing setting: Q, default is 0.1 + input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True + beam_width beam width setting: N, default is 4 + max_decode_length max decode length in evaluation: N, default is 80 + length_penalty_weight normalize scores of translations according to their length: Q, default is 1.0 + compute_type compute type in Transformer: mstype.float16 | mstype.float32, default is mstype.float16 + +Parameters for learning rate: + learning_rate value of learning rate: Q + warmup_steps steps of the learning rate warm up: N + start_decay_step step of the learning rate to decay: N + min_lr minimal learning rate: Q +``` \ No newline at end of file diff --git a/model_zoo/Transformer/create_data.py b/model_zoo/official/nlp/transformer/create_data.py similarity index 100% rename from model_zoo/Transformer/create_data.py rename to model_zoo/official/nlp/transformer/create_data.py diff --git a/model_zoo/Transformer/eval.py b/model_zoo/official/nlp/transformer/eval.py similarity index 100% rename from model_zoo/Transformer/eval.py rename to model_zoo/official/nlp/transformer/eval.py diff --git a/model_zoo/Transformer/scripts/process_output.sh b/model_zoo/official/nlp/transformer/scripts/process_output.sh similarity index 100% rename from model_zoo/Transformer/scripts/process_output.sh rename to model_zoo/official/nlp/transformer/scripts/process_output.sh diff --git a/model_zoo/Transformer/scripts/replace-quote.perl b/model_zoo/official/nlp/transformer/scripts/replace-quote.perl similarity index 100% rename from model_zoo/Transformer/scripts/replace-quote.perl rename to model_zoo/official/nlp/transformer/scripts/replace-quote.perl diff --git a/model_zoo/official/nlp/transformer/scripts/run_distribute_train.sh b/model_zoo/official/nlp/transformer/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..ea6ea614dc --- /dev/null +++ b/model_zoo/official/nlp/transformer/scripts/run_distribute_train.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH RANK_TABLE_FILE" +echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +rm -rf run_distribute_train +mkdir run_distribute_train +cd run_distribute_train || exit + +EPOCH_SIZE=$2 +DATA_PATH=$3 + +export RANK_TABLE_FILE=$4 +export RANK_SIZE=$1 +export HCCL_FLAG=1 +export DEPLOY_MODE=0 + +for((i=0;i env.log + python train.py \ + --distribute="true" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --device_num=$RANK_SIZE \ + --enable_save_ckpt="true" \ + --enable_lossscale="true" \ + --do_shuffle="true" \ + --enable_data_sink="false" \ + --checkpoint_path="" \ + --save_checkpoint_steps=2500 \ + --save_checkpoint_num=30 \ + --data_path=$DATA_PATH > log.txt 2>&1 & + cd ../ +done +cd .. \ No newline at end of file diff --git a/model_zoo/Transformer/scripts/run_standalone_train.sh b/model_zoo/official/nlp/transformer/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/Transformer/scripts/run_standalone_train.sh rename to model_zoo/official/nlp/transformer/scripts/run_standalone_train.sh diff --git a/model_zoo/official/nlp/transformer/src/__init__.py b/model_zoo/official/nlp/transformer/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/Transformer/src/beam_search.py b/model_zoo/official/nlp/transformer/src/beam_search.py similarity index 100% rename from model_zoo/Transformer/src/beam_search.py rename to model_zoo/official/nlp/transformer/src/beam_search.py diff --git a/model_zoo/Transformer/src/config.py b/model_zoo/official/nlp/transformer/src/config.py similarity index 100% rename from model_zoo/Transformer/src/config.py rename to model_zoo/official/nlp/transformer/src/config.py diff --git a/model_zoo/official/nlp/transformer/src/dataset.py b/model_zoo/official/nlp/transformer/src/dataset.py new file mode 100644 index 0000000000..84dc5427b2 --- /dev/null +++ b/model_zoo/official/nlp/transformer/src/dataset.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Data operations, will be used in train.py.""" + +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as deC +from .config import transformer_net_cfg + +def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true", + dataset_path=None): + """create dataset""" + repeat_count = epoch_count + ds = de.MindDataset(dataset_path, + columns_list=["source_eos_ids", "source_eos_mask", + "target_sos_ids", "target_sos_mask", + "target_eos_ids", "target_eos_mask"], + shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id) + + type_cast_op = deC.TypeCast(mstype.int32) + ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op) + ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op) + ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op) + ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op) + ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op) + ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op) + + # apply batch operations + ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) + ds = ds.repeat(repeat_count) + + return ds diff --git a/model_zoo/Transformer/src/eval_config.py b/model_zoo/official/nlp/transformer/src/eval_config.py similarity index 100% rename from model_zoo/Transformer/src/eval_config.py rename to model_zoo/official/nlp/transformer/src/eval_config.py diff --git a/model_zoo/Transformer/src/lr_schedule.py b/model_zoo/official/nlp/transformer/src/lr_schedule.py similarity index 100% rename from model_zoo/Transformer/src/lr_schedule.py rename to model_zoo/official/nlp/transformer/src/lr_schedule.py diff --git a/model_zoo/Transformer/src/process_output.py b/model_zoo/official/nlp/transformer/src/process_output.py similarity index 100% rename from model_zoo/Transformer/src/process_output.py rename to model_zoo/official/nlp/transformer/src/process_output.py diff --git a/model_zoo/Transformer/src/tokenization.py b/model_zoo/official/nlp/transformer/src/tokenization.py similarity index 100% rename from model_zoo/Transformer/src/tokenization.py rename to model_zoo/official/nlp/transformer/src/tokenization.py diff --git a/model_zoo/Transformer/src/transformer_for_train.py b/model_zoo/official/nlp/transformer/src/transformer_for_train.py similarity index 100% rename from model_zoo/Transformer/src/transformer_for_train.py rename to model_zoo/official/nlp/transformer/src/transformer_for_train.py diff --git a/model_zoo/official/nlp/transformer/src/transformer_model.py b/model_zoo/official/nlp/transformer/src/transformer_model.py new file mode 100644 index 0000000000..fb33f526da --- /dev/null +++ b/model_zoo/official/nlp/transformer/src/transformer_model.py @@ -0,0 +1,1157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer model.""" + +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from .beam_search import BeamSearchDecoder, TileBeam +from .weight_init import normal_weight, weight_variable + +class TransformerConfig: + """ + Configuration for `Transformer`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 36560. + hidden_size (int): Size of the layers. Default: 1024. + num_hidden_layers (int): Number of hidden layers in the Transformer encoder/decoder + cell. Default: 6. + num_attention_heads (int): Number of attention heads in the Transformer + encoder/decoder cell. Default: 16. + intermediate_size (int): Size of intermediate layer in the Transformer + encoder/decoder cell. Default: 4096. + hidden_act (str): Activation function used in the Transformer encoder/decoder + cell. Default: "relu". + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.3. + attention_probs_dropout_prob (float): The dropout probability for + MultiheadAttention. Default: 0.3. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 128. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + label_smoothing (float): label smoothing setting. Default: 0.1 + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + beam_width (int): beam width setting. Default: 4 + max_decode_length (int): max decode length in evaluation. Default: 80 + length_penalty_weight (float): normalize scores of translations according to their length. Default: 1.0 + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length=128, + vocab_size=36560, + hidden_size=1024, + num_hidden_layers=6, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="relu", + hidden_dropout_prob=0.3, + attention_probs_dropout_prob=0.3, + max_position_embeddings=128, + initializer_range=0.02, + label_smoothing=0.1, + input_mask_from_dataset=True, + beam_width=4, + max_decode_length=80, + length_penalty_weight=1.0, + dtype=mstype.float32, + compute_type=mstype.float32): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.label_smoothing = label_smoothing + self.input_mask_from_dataset = input_mask_from_dataset + self.beam_width = beam_width + self.max_decode_length = max_decode_length + self.length_penalty_weight = length_penalty_weight + self.dtype = dtype + self.compute_type = compute_type + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + def __init__(self, + vocab_size, + embedding_size, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = P.Shape() + + def construct(self, input_ids): + input_shape = self.shape(input_ids) + + flat_ids = self.reshape(input_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + + out_shape = input_shape + (self.embedding_size,) + output = self.reshape(output_for_reshape, out_shape) + return output, self.embedding_table + + +def position_encoding(length, + depth, + min_timescale=1, + max_timescale=1e4): + """ + Create Tensor of sinusoids of different frequencies. + + Args: + length (int): Length of the Tensor to create, i.e. Number of steps. + depth (int): Hidden size. + min_timescale (float): Default: 1. + max_timescale (float): Default: 10000. + + Returns: + Tensor of shape (length, depth) + """ + depth = depth // 2 + positions = np.arange(length, dtype=np.float32) + log_timescale_increment = (np.log(max_timescale / min_timescale) / (depth - 1)) + inv_timescales = min_timescale * np.exp(np.arange(depth, dtype=np.float32) * -log_timescale_increment) + scaled_time = np.expand_dims(positions, 1) * np.expand_dims(inv_timescales, 0) + x = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) + return x + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 128. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + def __init__(self, + embedding_size, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=128, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.scores_mul = Tensor([math.sqrt(float(embedding_size))], dtype=mstype.float32) + self.multiply = P.Mul() + self.add = P.TensorAdd() + self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32) + self.use_dropout = dropout_prob > 0 + self.expand_dims = P.ExpandDims() + self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size), + mstype.float32) + self.shape = P.Shape() + + def construct(self, word_embeddings): + input_shape = self.shape(word_embeddings) + input_len = input_shape[1] + + output = self.multiply(word_embeddings, self.scores_mul) + + # add position embeddings + position_embeddings = self.position_embedding_table[0:input_len:1, ::] + position_embeddings = self.expand_dims(position_embeddings, 0) + output = self.add(output, position_embeddings) + + if self.use_dropout: + output = self.dropout(output) + return output + + +class CastWrapper(nn.Cell): + """ + Cast wrapper. + """ + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(CastWrapper, self).__init__() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + return self.cast(x, self.dst_type) + + +class LayerPreprocess(nn.Cell): + """ + preprocess input of each layer. + """ + def __init__(self, + in_channels=None): + super(LayerPreprocess, self).__init__() + self.layernorm = nn.LayerNorm((in_channels,)) + self.cast = P.Cast() + self.get_dtype = P.DType() + + def construct(self, input_tensor): + output = self.cast(input_tensor, mstype.float32) + output = self.layernorm(output) + output = self.cast(output, self.get_dtype(input_tensor)) + return output + + +class LayerPostprocess(nn.Cell): + """ + postprocess ouput of each layer. + """ + def __init__(self, + dropout_prob=0.1): + super(LayerPostprocess, self).__init__() + self.add = P.TensorAdd() + self.dropout = nn.Dropout(1 - dropout_prob) + self.use_dropout = dropout_prob > 0 + + def construct(self, hidden_tensor, input_tensor): + output = hidden_tensor + if self.use_dropout: + output = self.dropout(output) + output = self.add(output, input_tensor) + return output + + +class MultiheadAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + MultiheadAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + out_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + out_act=None, + has_attention_mask=True, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=True, + compute_type=mstype.float32): + super(MultiheadAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + assert has_attention_mask + + self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + has_bias=False, + weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + has_bias=False, + weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + has_bias=False, + weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type) + self.out_layer = nn.Dense(units, + out_tensor_width, + activation=out_act, + has_bias=False, + weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type) + + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + self.use_dropout = attention_probs_dropout_prob > 0 + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + if from_seq_length == -1: + self.shape_return = (-1, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + + self.cast_compute_type = CastWrapper(dst_type=compute_type) + self.softmax_cast = P.Cast() + + def construct(self, from_tensor, to_tensor, attention_mask=None): + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + attention_scores = self.multiply(attention_scores, self.scores_mul) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_scores = self.softmax_cast(attention_scores, mstype.float32) + attention_probs = self.softmax(attention_scores) + attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key_layer)) + if self.use_dropout: + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + context_layer = self.out_layer(context_layer) + return context_layer + + +class SelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + batch_size (int): Batch size of input dataset. + from_seq_length (int): Length of query sequence. + to_seq_length (int): Length of memory sequence. + hidden_size (int): Size of attention layers. + num_attention_heads (int): Number of attention heads. Default: 16. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + has_attention_mask (bool): Specifies whether has attention mask. Default: True. + is_encdec_att (bool): Specifies whether query sequence and memory sequence are different. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_seq_length, + to_seq_length, + hidden_size, + num_attention_heads=16, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + has_attention_mask=True, + is_encdec_att=False, + compute_type=mstype.float32): + super(SelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + self.size_per_head = int(hidden_size / num_attention_heads) + self.is_encdec_att = is_encdec_att + + self.attention = MultiheadAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + out_tensor_width=hidden_size, + from_seq_length=from_seq_length, + to_seq_length=to_seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + has_attention_mask=has_attention_mask, + do_return_2d_tensor=True, + compute_type=compute_type) + + self.preprocess = LayerPreprocess(in_channels=hidden_size) + self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + def construct(self, input_tensor, memory_tensor, attention_mask): + input_tensor = self.reshape(input_tensor, self.shape) + memory_tensor = self.reshape(memory_tensor, self.shape) + + output = self.preprocess(input_tensor) + + if not self.is_encdec_att: + memory_tensor = output + + attention_output = self.attention(output, memory_tensor, attention_mask) + output = self.postprocess(attention_output, input_tensor) + return output + + +class FeedForward(nn.Cell): + """ + Apply two-layer feed forward + + Args: + in_channels (int): Size of the input layer. + hidden_size (int): Size of the hidden layer. + out_channels (int): Size of the output layers. + hidden_act (str): name of the activation function. Default: relu + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in FeedForward. Default: mstype.float32. + """ + def __init__(self, + in_channels, + hidden_size, + out_channels, + hidden_act="relu", + initializer_range=0.02, + hidden_dropout_prob=0.1, + compute_type=mstype.float32): + super(FeedForward, self).__init__() + + self.conv1 = nn.Dense(in_channels, + hidden_size, + activation=hidden_act, + weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type) + self.conv2 = nn.Dense(hidden_size, + out_channels, + weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type) + + self.preprocess = LayerPreprocess(in_channels=in_channels) + self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob) + + self.reshape = P.Reshape() + self.shape = (-1, in_channels) + self.dropout = nn.Dropout(1 - hidden_dropout_prob) + self.use_dropout = hidden_dropout_prob > 0 + + def construct(self, input_tensor): + input_tensor = self.reshape(input_tensor, self.shape) + output = self.preprocess(input_tensor) + output = self.conv1(output) + if self.use_dropout: + output = self.dropout(output) + output = self.conv2(output) + output = self.postprocess(output, input_tensor) + return output + + +class EncoderCell(nn.Cell): + """ + Encoder cells used in Transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. Default: 1024. + seq_length (int): Length of input sequence. Default: 128. + num_attention_heads (int): Number of attention heads. Default: 16. + intermediate_size (int): Size of intermediate layer. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.1. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function. Default: "relu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size=1024, + seq_length=128, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(EncoderCell, self).__init__() + self.attention = SelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + is_encdec_att=False, + compute_type=compute_type) + self.feedforward = FeedForward( + in_channels=hidden_size, + hidden_size=intermediate_size, + out_channels=hidden_size, + hidden_act=hidden_act, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + + def construct(self, hidden_states, attention_mask): + # self-attention with ln, res + attention_output = self.attention(hidden_states, hidden_states, attention_mask) + # feed forward with ln, res + output = self.feedforward(attention_output) + return output + + +class TransformerEncoder(nn.Cell): + """ + Multi-layer transformer encoder. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(TransformerEncoder, self).__init__() + self.num_hidden_layers = num_hidden_layers + + layers = [] + for _ in range(num_hidden_layers): + layer = EncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + layers.append(layer) + self.layers = nn.CellList(layers) + + self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask): + prev_output = self.reshape(input_tensor, self.shape) + + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + prev_output = self.layer_preprocess(prev_output) + output = self.reshape(prev_output, self.out_shape) + return output + + +class DecoderCell(nn.Cell): + """ + decoder cells used in Transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the Transformer decoder layers. Default: 1024. + seq_length (int): Length of input sequence. Default: 128. + enc_seq_length (int): Length of source sentences. Default:128 + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function. Default: "relu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size=1024, + seq_length=128, + enc_seq_length=128, + num_attention_heads=12, + intermediate_size=4096, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(DecoderCell, self).__init__() + self.self_attention = SelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + is_encdec_att=False, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + self.cross_attention = SelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + from_seq_length=seq_length, + to_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + is_encdec_att=True, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + self.feedforward = FeedForward( + in_channels=hidden_size, + hidden_size=intermediate_size, + out_channels=hidden_size, + hidden_act=hidden_act, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + + def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask): + # self-attention with ln, res + attention_output = self.self_attention(hidden_states, hidden_states, attention_mask) + # cross-attention with ln, res + attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask) + # feed forward with ln, res + output = self.feedforward(attention_output) + return output + + +class TransformerDecoder(nn.Cell): + """ + Multi-layer transformer decoder. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + enc_seq_length (int): Length of source sentences. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + enc_seq_length, + num_hidden_layers, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(TransformerDecoder, self).__init__() + self.num_hidden_layers = num_hidden_layers + + layers = [] + for _ in range(num_hidden_layers): + layer = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + layers.append(layer) + self.layers = nn.CellList(layers) + + self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask): + prev_output = self.reshape(input_tensor, self.shape) + + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask) + prev_output = layer_output + + prev_output = self.layer_preprocess(prev_output) + output = self.reshape(prev_output, self.out_shape) + return output + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (:class:`TransformerConfig`): Configuration for Transformer. + """ + def __init__(self): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = P.Shape() + self.batch_matmul = P.BatchMatMul() + + def construct(self, input_mask): + input_shape = self.shape(input_mask) + shape_right = (input_shape[0], 1, input_shape[1]) + shape_left = input_shape + (1,) + + input_mask = self.cast(input_mask, mstype.float32) + mask_left = self.reshape(input_mask, shape_left) + mask_right = self.reshape(input_mask, shape_right) + attention_mask = self.batch_matmul(mask_left, mask_right) + + return attention_mask + + +class PredLogProbs(nn.Cell): + """ + Get log probs. + + Args: + batch_size (int): Batch size. + seq_length (int): Length of input sequence. + width (int): Hidden size. + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + dtype (:class:`mindspore.dtype`): Compute type to compute log_softmax. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length, + width, + compute_type=mstype.float32, + dtype=mstype.float32): + super(PredLogProbs, self).__init__() + self.batch_size = batch_size + self.seq_length = seq_length + self.width = width + self.compute_type = compute_type + self.dtype = dtype + + self.reshape = P.Reshape() + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width) + self.cast = P.Cast() + + def construct(self, + input_tensor, + output_weights): + input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + + log_probs = self.log_softmax(logits) + return log_probs + + +class TransformerDecoderStep(nn.Cell): + """ + Multi-layer transformer decoder step. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + max_decode_length (int): Max decode length. + enc_seq_length (int): Length of source sentences. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + embedding_lookup (:class:`EmbeddingLookup`): Embedding lookup module. + embedding_processor (:class:`EmbeddingPostprocessor`) Embedding postprocessor module. + projection (:class:`PredLogProbs`): PredLogProbs module + """ + def __init__(self, + batch_size, + hidden_size, + enc_seq_length, + max_decode_length, + num_hidden_layers, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.3, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.3, + hidden_act="relu", + compute_type=mstype.float32, + embedding_lookup=None, + embedding_processor=None, + projection=None): + super(TransformerDecoderStep, self).__init__(auto_prefix=False) + self.num_hidden_layers = num_hidden_layers + + self.tfm_embedding_lookup = embedding_lookup + self.tfm_embedding_processor = embedding_processor + self.projection = projection + + self.tfm_decoder = TransformerDecoder( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=-1, # -1 means length is not fixed + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + + self.ones_like = P.OnesLike() + self.shape = P.Shape() + + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask() + self.expand = P.ExpandDims() + self.multiply = P.Mul() + + ones = np.ones(shape=(max_decode_length, max_decode_length)) + self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) + + self.cast_compute_type = CastWrapper(dst_type=compute_type) + + def construct(self, input_ids, enc_states, enc_attention_mask): + # input_ids: [batch_size * beam_width] + # process embedding + input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids) + input_embedding = self.tfm_embedding_processor(input_embedding) + input_embedding = self.cast_compute_type(input_embedding) + + input_shape = self.shape(input_ids) + input_len = input_shape[1] + future_mask = self.future_mask[0:input_len:1, 0:input_len:1] + + input_mask = self.ones_like(input_ids) + input_mask = self._create_attention_mask_from_input_mask(input_mask) + input_mask = self.multiply(input_mask, self.expand(future_mask, 0)) + input_mask = self.cast_compute_type(input_mask) + + enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::] + + # call TransformerDecoder + decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask) + + # take the last step + decoder_output = decoder_output[::, input_len-1:input_len:1, ::] + + # projection and log_prob + log_probs = self.projection(decoder_output, embedding_tables) + + return log_probs + + +class TransformerModel(nn.Cell): + """ + Transformer with encoder and decoder. + + Args: + config (Class): Configuration for Transformer. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(TransformerModel, self).__init__() + config = copy.deepcopy(config) + self.is_training = is_training + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.input_mask_from_dataset = config.input_mask_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + + self.last_idx = self.num_hidden_layers - 1 + + self.tfm_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.tfm_encoder = TransformerEncoder( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type) + + if is_training: + self.projection = PredLogProbs( + batch_size=self.batch_size, + seq_length=self.seq_length, + width=self.hidden_size, + compute_type=config.compute_type, + dtype=config.dtype) + self.tfm_decoder = TransformerDecoder( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + enc_seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type) + else: + self.projection = PredLogProbs( + batch_size=self.batch_size * config.beam_width, + seq_length=1, + width=self.hidden_size, + compute_type=config.compute_type, + dtype=config.dtype) + self.tfm_decoder = TransformerDecoderStep( + batch_size=self.batch_size * config.beam_width, + hidden_size=self.hidden_size, + enc_seq_length=self.seq_length, + max_decode_length=config.max_decode_length, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=False, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + embedding_lookup=self.tfm_embedding_lookup, + embedding_processor=self.tfm_embedding_postprocessor_for_decoder, + projection=self.projection) + self.tfm_decoder = BeamSearchDecoder( + batch_size=config.batch_size, + seq_length=config.seq_length, + vocab_size=config.vocab_size, + decoder=self.tfm_decoder, + beam_width=config.beam_width, + length_penalty_weight=config.length_penalty_weight, + max_decode_length=config.max_decode_length) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = CastWrapper(dst_type=config.compute_type) + self.expand = P.ExpandDims() + self.multiply = P.Mul() + + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask() + + if is_training: + ones = np.ones(shape=(self.seq_length, self.seq_length)) + self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) + else: + self.tile_beam = TileBeam(beam_width=config.beam_width) + ones = np.ones(shape=(config.batch_size, config.max_decode_length)) + self.encdec_mask = Tensor(ones, dtype=mstype.float32) + + def construct(self, source_ids, source_mask, target_ids=None, target_mask=None): + # process source sentence + src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids) + src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask) + # transformer encoder + encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output), + self.cast_compute_type(enc_attention_mask)) + + if self.is_training: + # process target sentence + tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids) + tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask) + tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0)) + # transformer decoder + decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output), + self.cast_compute_type(tgt_attention_mask), + encoder_output, enc_attention_mask) + # calculate logits and log_probs + log_probs = self.projection(decoder_output, embedding_tables) + ret = log_probs + else: + beam_encoder_output = self.tile_beam(encoder_output) + + enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1)) + + beam_enc_attention_mask = self.tile_beam(enc_attention_mask) + beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask) + predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask) + ret = predicted_ids + return ret diff --git a/model_zoo/Transformer/src/weight_init.py b/model_zoo/official/nlp/transformer/src/weight_init.py similarity index 100% rename from model_zoo/Transformer/src/weight_init.py rename to model_zoo/official/nlp/transformer/src/weight_init.py diff --git a/model_zoo/official/nlp/transformer/train.py b/model_zoo/official/nlp/transformer/train.py new file mode 100644 index 0000000000..f84c4214e3 --- /dev/null +++ b/model_zoo/official/nlp/transformer/train.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer training script.""" + +import time +import argparse +import random +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.nn.optim import Adam +from mindspore.train.model import Model +from mindspore.train.loss_scale_manager import DynamicLossScaleManager +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint +from mindspore.train.callback import Callback, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset.engine as de +import mindspore.communication.management as D +from mindspore.train.parallel_utils import ParallelMode +from mindspore import context + +from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \ + TransformerTrainOneStepWithLossScaleCell +from src.config import cfg, transformer_net_cfg +from src.dataset import create_transformer_dataset +from src.lr_schedule import create_dynamic_lr + +random_seed = 1 +random.seed(random_seed) +np.random.seed(random_seed) +de.config.set_seed(random_seed) + +def get_ms_timestamp(): + t = time.time() + return int(round(t * 1000)) +time_stamp_init = False +time_stamp_first = 0 + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss is NAN or INF terminating training. + Note: + If per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + global time_stamp_init, time_stamp_first + if not time_stamp_init: + time_stamp_first = get_ms_timestamp() + time_stamp_init = True + + def step_end(self, run_context): + global time_stamp_first + time_stamp_current = get_ms_timestamp() + cb_params = run_context.original_args() + print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first, + cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs))) + with open("./loss.log", "a+") as f: + f.write("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first, + cb_params.cur_epoch_num, + cb_params.cur_step_num, + str(cb_params.net_outputs))) + f.write('\n') + + +def argparse_init(): + """ + Argparse init. + """ + parser = argparse.ArgumentParser(description='transformer') + parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") + parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") + parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.") + parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.") + parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") + parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, " + "default is true.") + parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, " + "default is 2500.") + parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.") + parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, " + "default is ./checkpoint/") + parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path") + return parser + +def run_transformer_train(): + """ + Transformer training. + """ + parser = argparse_init() + args, _ = parser.parse_known_args() + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) + context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False) + + if args.distribute == "true": + device_num = args.device_num + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + parameter_broadcast=True, device_num=device_num) + D.init() + rank_id = args.device_id % device_num + else: + device_num = 1 + rank_id = 0 + dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num, + rank_id=rank_id, do_shuffle=args.do_shuffle, + enable_data_sink=args.enable_data_sink, + dataset_path=args.data_path) + + netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True) + + if args.checkpoint_path: + parameter_dict = load_checkpoint(args.checkpoint_path) + load_param_into_net(netwithloss, parameter_dict) + + lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay", + training_steps=dataset.get_dataset_size()*args.epoch_size, + learning_rate=cfg.lr_schedule.learning_rate, + warmup_steps=cfg.lr_schedule.warmup_steps, + hidden_size=transformer_net_cfg.hidden_size, + start_decay_step=cfg.lr_schedule.start_decay_step, + min_lr=cfg.lr_schedule.min_lr), mstype.float32) + optimizer = Adam(netwithloss.trainable_params(), lr) + + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()] + if args.enable_save_ckpt == "true": + if device_num == 1 or (device_num > 1 and rank_id == 0): + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, + keep_checkpoint_max=args.save_checkpoint_num) + ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) + callbacks.append(ckpoint_cb) + + if args.enable_lossscale == "true": + scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + update_cell = scale_manager.get_update_cell() + netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, + scale_update_cell=update_cell) + else: + netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer) + + netwithgrads.set_train(True) + model = Model(netwithgrads) + model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) + +if __name__ == '__main__': + run_transformer_train() diff --git a/model_zoo/official/recommend/.gitkeep b/model_zoo/official/recommend/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/recommend/deepfm/README.md b/model_zoo/official/recommend/deepfm/README.md new file mode 100644 index 0000000000..81cb023b97 --- /dev/null +++ b/model_zoo/official/recommend/deepfm/README.md @@ -0,0 +1,134 @@ +# DeepFM Description + +This is an example of training DeepFM with Criteo dataset in MindSpore. + +[Paper](https://arxiv.org/pdf/1703.04247.pdf) Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He + + +# Model architecture + +The overall network architecture of DeepFM is show below: + +[Link](https://arxiv.org/pdf/1703.04247.pdf) + + +# Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download the criteo dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path. +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + +# Script description + +## Script and sample code + +```shell +├── deepfm + ├── README.md + ├── scripts + │ ├──run_distribute_train.sh + │ ├──run_standalone_train.sh + │ ├──run_eval.sh + ├── src + │ ├──__init__.py + │ ├──config.py + │ ├──dataset.py + │ ├──callback.py + │ ├──deepfm.py + ├── train.py + ├── eval.py +``` + +## Training process + +### Usage + +- sh run_train.sh [DEVICE_NUM] [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PAHT] +- python train.py --dataset_path [DATASET_PATH] + +### Launch + +``` +# distribute training example + sh scripts/run_distribute_train.sh 8 /opt/dataset/criteo /opt/mindspore_hccl_file.json +# standalone training example + sh scripts/run_standalone_train.sh 0 /opt/dataset/criteo + or + python train.py --dataset_path /opt/dataset/criteo > output.log 2>&1 & +``` + +### Result + +Training result will be stored in the example path. +Checkpoints will be stored at `./checkpoint` by default, +and training log will be redirected to `./output.log` by default, +and loss log will be redirected to `./loss.log` by default, +and eval log will be redirected to `./auc.log` by default. + + +## Eval process + +### Usage + +- sh run_eval.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH] + +### Launch + +``` +# infer example + sh scripts/run_eval.sh 0 ~/criteo/eval/ ~/train/deepfm-15_41257.ckpt +``` + +> checkpoint can be produced in training process. + +### Result + +Inference result will be stored in the example path, you can find result like the followings in `auc.log`. + +``` +2020-05-27 20:51:35 AUC: 0.80577889065281, eval time: 35.55999s. +``` + +# Model description + +## Performance + +### Training Performance + +| Parameters | DeepFM | +| -------------------------- | ------------------------------------------------------| +| Model Version | | +| Resource | Ascend 910, cpu:2.60GHz 96cores, memory:1.5T | +| uploaded Date | 05/27/2020 | +| MindSpore Version | 0.2.0 | +| Dataset | Criteo | +| Training Parameters | src/config.py | +| Optimizer | Adam | +| Loss Function | SoftmaxCrossEntropyWithLogits | +| outputs | | +| Loss | 0.4234 | +| Accuracy | AUC[0.8055] | +| Total time | 91 min | +| Params (M) | | +| Checkpoint for Fine tuning | | +| Model for inference | | + +#### Inference Performance + +| Parameters | | | +| -------------------------- | ----------------------------- | ------------------------- | +| Model Version | | | +| Resource | Ascend 910 | Ascend 310 | +| uploaded Date | 05/27/2020 | 05/27/2020 | +| MindSpore Version | 0.2.0 | 0.2.0 | +| Dataset | Criteo | | +| batch_size | 1000 | | +| outputs | | | +| Accuracy | AUC[0.8055] | | +| Speed | | | +| Total time | 35.559s | | +| Model for inference | | | + +# ModelZoo Homepage + [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) diff --git a/model_zoo/deepfm/eval.py b/model_zoo/official/recommend/deepfm/eval.py similarity index 100% rename from model_zoo/deepfm/eval.py rename to model_zoo/official/recommend/deepfm/eval.py diff --git a/model_zoo/official/recommend/deepfm/scripts/run_distribute_train.sh b/model_zoo/official/recommend/deepfm/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..6da44819d7 --- /dev/null +++ b/model_zoo/official/recommend/deepfm/scripts/run_distribute_train.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "Please run the script as: " +echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PAHT" +echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json" +echo "After running the script, the network runs in the background, The log will be generated in logx/output.log" + + +export RANK_SIZE=$1 +DATA_URL=$2 +export RANK_TABLE_FILE=$3 + +for ((i=0; i env.log + python -u train.py \ + --dataset_path=$DATA_URL \ + --ckpt_path="checkpoint" \ + --eval_file_name='auc.log' \ + --loss_file_name='loss.log' \ + --do_eval=True > output.log 2>&1 & + cd ../ +done diff --git a/model_zoo/deepfm/scripts/run_eval.sh b/model_zoo/official/recommend/deepfm/scripts/run_eval.sh similarity index 100% rename from model_zoo/deepfm/scripts/run_eval.sh rename to model_zoo/official/recommend/deepfm/scripts/run_eval.sh diff --git a/model_zoo/deepfm/scripts/run_standalone_train.sh b/model_zoo/official/recommend/deepfm/scripts/run_standalone_train.sh similarity index 100% rename from model_zoo/deepfm/scripts/run_standalone_train.sh rename to model_zoo/official/recommend/deepfm/scripts/run_standalone_train.sh diff --git a/model_zoo/lstm/src/__init__.py b/model_zoo/official/recommend/deepfm/src/__init__.py similarity index 100% rename from model_zoo/lstm/src/__init__.py rename to model_zoo/official/recommend/deepfm/src/__init__.py diff --git a/model_zoo/deepfm/src/callback.py b/model_zoo/official/recommend/deepfm/src/callback.py similarity index 100% rename from model_zoo/deepfm/src/callback.py rename to model_zoo/official/recommend/deepfm/src/callback.py diff --git a/model_zoo/deepfm/src/config.py b/model_zoo/official/recommend/deepfm/src/config.py similarity index 100% rename from model_zoo/deepfm/src/config.py rename to model_zoo/official/recommend/deepfm/src/config.py diff --git a/model_zoo/deepfm/src/dataset.py b/model_zoo/official/recommend/deepfm/src/dataset.py similarity index 100% rename from model_zoo/deepfm/src/dataset.py rename to model_zoo/official/recommend/deepfm/src/dataset.py diff --git a/model_zoo/deepfm/src/deepfm.py b/model_zoo/official/recommend/deepfm/src/deepfm.py similarity index 100% rename from model_zoo/deepfm/src/deepfm.py rename to model_zoo/official/recommend/deepfm/src/deepfm.py diff --git a/model_zoo/official/recommend/deepfm/train.py b/model_zoo/official/recommend/deepfm/train.py new file mode 100644 index 0000000000..ff110cd5ab --- /dev/null +++ b/model_zoo/official/recommend/deepfm/train.py @@ -0,0 +1,91 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_criteo.""" +import os +import sys +import argparse + +from mindspore import context, ParallelMode +from mindspore.communication.management import init +from mindspore.train.model import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor + +from src.deepfm import ModelBuilder, AUCMetric +from src.config import DataConfig, ModelConfig, TrainConfig +from src.dataset import create_dataset, DataType +from src.callback import EvalCallBack, LossCallBack + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +parser = argparse.ArgumentParser(description='CTR Prediction') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path') +parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path') +parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path') +parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.') + +args_opt, _ = parser.parse_known_args() +device_id = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) + + +if __name__ == '__main__': + data_config = DataConfig() + model_config = ModelConfig() + train_config = TrainConfig() + + rank_size = int(os.environ.get("RANK_SIZE", 1)) + if rank_size > 1: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) + init() + rank_id = int(os.environ.get('RANK_ID')) + else: + rank_size = None + rank_id = None + + ds_train = create_dataset(args_opt.dataset_path, + train_mode=True, + epochs=1, + batch_size=train_config.batch_size, + data_type=DataType(data_config.data_format), + rank_size=rank_size, + rank_id=rank_id) + + model_builder = ModelBuilder(ModelConfig, TrainConfig) + train_net, eval_net = model_builder.get_train_eval_net() + auc_metric = AUCMetric() + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + time_callback = TimeMonitor(data_size=ds_train.get_dataset_size()) + loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name) + callback_list = [time_callback, loss_callback] + + if train_config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps, + keep_checkpoint_max=train_config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix, + directory=args_opt.ckpt_path, + config=config_ck) + callback_list.append(ckpt_cb) + + if args_opt.do_eval: + ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, + epochs=1, + batch_size=train_config.batch_size, + data_type=DataType(data_config.data_format)) + eval_callback = EvalCallBack(model, ds_eval, auc_metric, + eval_file_path=args_opt.eval_file_name) + callback_list.append(eval_callback) + model.train(train_config.train_epochs, ds_train, callbacks=callback_list) diff --git a/model_zoo/official/recommend/wide_and_deep/README.md b/model_zoo/official/recommend/wide_and_deep/README.md new file mode 100644 index 0000000000..837b856dab --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/README.md @@ -0,0 +1,179 @@ +Recommendation Model +## Overview +This is an implementation of WideDeep as described in the [Wide & Deep Learning for Recommender System](https://arxiv.org/pdf/1606.07792.pdf) paper. + +WideDeep model jointly trained wide linear models and deep neural network, which combined the benefits of memorization and generalization for recommender systems. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the dataset and convert the dataset to mindrecord, command as follows: +``` +python src/preprocess_data.py --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0 + +``` +Arguments: + * `--data_type` {criteo,synthetic}: Currently we support criteo dataset and synthetic dataset.(Default: ./criteo_data/). + * `--data_path` : The path of the data file. + * `--dense_dim` : The number of your continues fields. + * `--slot_dim` : The number of your sparse fields, it can also be called category features. + * `--threshold` : Word frequency below this value will be regarded as OOV. It aims to reduce the vocab size. + * `--train_line_count`: The number of examples in your dataset. + * `--skip_id_convert`: 0 or 1. If set 1, the code will skip the id convert, regarding the original id as the final id. + + +## Dataset +The common used benchmark datasets are used for model training and evaluation. + + +### Generate the synthetic Data + +The following command will generate 40 million lines of click data, in the format of "label\tdense_feature[0]\tdense_feature[1]...\tsparse_feature[0]\tsparse_feature[1]...". +``` +mkdir -p syn_data/origin_data +python src/generate_synthetic_data.py --output_file=syn_data/origin_data/train.txt --number_examples=40000000 --dense_dim=13 --slot_dim=51 --vocabulary_size=2000000000 --random_slot_values=0 +``` +Arguments: + * `--output_file`: The output path of the generated file + * `--label_dim` : The label category + * `--number_examples`: The row numbers of the generated file + * `--dense_dim` : The number of the continue feature. + * `--slot_dim`: The number of the category features + * `--vocabulary_size`: The vocabulary size of the total dataset + * `--random_slot_values`: 0 or 1. If 1, the id is generated by the random. If 0, the id is set by the row_index mod part_size, where + part_size is the vocab size for each slot + +Preprocess the generated data +``` +python src/preprocess_data.py --data_path=./syn_data/ --data_type=synthetic --dense_dim=13 --slot_dim=51 --threshold=0 --train_line_count=40000000 --skip_id_convert=1 +``` + + + +## Running Code + +### Code Structure +The entire code structure is as following: +``` +|--- wide_and_deep/ + train_and_eval.py "Entrance of Wide&Deep model training and evaluation" + eval.py "Entrance of Wide&Deep model evaluation" + train.py "Entrance of Wide&Deep model training" + train_and_eval_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation" + train_and_eval_auto_parallel.py + train_and_eval_parameter_server.py "Entrance of Wide&Deep model parameter server training and evaluation" + |--- src/ "Entrance of training and evaluation" + config.py "Parameters configuration" + dataset.py "Dataset loader class" + process_data.py "Process dataset" + preprocess_data.py "Pre_process dataset" + wide_and_deep.py "Model structure" + callbacks.py "Callback class for training and evaluation" + generate_synthetic_data.py "Generate the synthetic data for benchmark" + metrics.py "Metric class" + |--- script/ "Run shell dir" + run_multinpu_train.sh "Run data parallel" + run_auto_parallel_train.sh "Run auto parallel" + run_parameter_server_train.sh "Run parameter server" +``` + + +### Train and evaluate model +To train and evaluate the model, command as follows: +``` +python train_and_eval.py +``` +Arguments: + * `--device_target`: Device where the code will be implemented (Default: Ascend). + * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. + * `--epochs`: Total train epochs. + * `--batch_size`: Training batch size. + * `--eval_batch_size`: Eval batch size. + * `--field_size`: The number of features. + * `--vocab_size`: The total features of dataset. + * `--emb_dim`: The dense embedding dimension of sparse feature. + * `--deep_layers_dim`: The dimension of all deep layers. + * `--deep_layers_act`: The activation of all deep layers. + * `--dropout_flag`: Whether do dropout. + * `--keep_prob`: The rate to keep in dropout layer. + * `--ckpt_path`:The location of the checkpoint file. + * `--eval_file_name` : Eval output file. + * `--loss_file_name` : Loss output file. + * `--dataset_type` : tfrecord/mindrecord/hd5. + +To train the model in one device, command as follows: +``` +python train.py +``` +Arguments: + * `--device_target`: Device where the code will be implemented (Default: Ascend). + * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. + * `--epochs`: Total train epochs. + * `--batch_size`: Training batch size. + * `--eval_batch_size`: Eval batch size. + * `--field_size`: The number of features. + * `--vocab_size`: The total features of dataset. + * `--emb_dim`: The dense embedding dimension of sparse feature. + * `--deep_layers_dim`: The dimension of all deep layers. + * `--deep_layers_act`: The activation of all deep layers. + * `--dropout_flag`: Whether do dropout. + * `--keep_prob`: The rate to keep in dropout layer. + * `--ckpt_path`:The location of the checkpoint file. + * `--eval_file_name` : Eval output file. + * `--loss_file_name` : Loss output file. + * `--dataset_type` : tfrecord/mindrecord/hd5. + +To train the model in distributed, command as follows: +``` +# configure environment path before training +bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE +``` +``` +# configure environment path before training +bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE +``` + +To train the model in clusters, command as follows:''' +``` +# deploy wide&deep script in clusters +# CLUSTER_CONFIG is a json file, the sample is in script/. +# EXECUTE_PATH is the scripts path after the deploy. +bash deploy_cluster.sh CLUSTER_CONFIG_PATH EXECUTE_PATH + +# enter EXECUTE_PATH, and execute start_cluster.sh as follows. +# MODE: "host_device_mix" +bash start_cluster.sh CLUSTER_CONFIG_PATH EPOCH_SIZE VOCAB_SIZE EMB_DIM + DATASET ENV_SH RANK_TABLE_FILE MODE +``` + +To train and evaluate the model in parameter server mode, command as follows:''' +``` +# SERVER_NUM is the number of parameter servers for this task. +# SCHED_HOST is the IP address of scheduler. +# SCHED_PORT is the port of scheduler. +# The number of workers is the same as RANK_SIZE. +bash run_parameter_server_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE SERVER_NUM SCHED_HOST SCHED_PORT +``` + +To evaluate the model, command as follows: +``` +python eval.py +``` +Arguments: + * `--device_target`: Device where the code will be implemented (Default: Ascend). + * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. + * `--epochs`: Total train epochs. + * `--batch_size`: Training batch size. + * `--eval_batch_size`: Eval batch size. + * `--field_size`: The number of features. + * `--vocab_size`: The total features of dataset. + * `--emb_dim`: The dense embedding dimension of sparse feature. + * `--deep_layers_dim`: The dimension of all deep layers. + * `--deep_layers_act`: The activation of all deep layers. + * `--keep_prob`: The rate to keep in dropout layer. + * `--ckpt_path`:The location of the checkpoint file. + * `--eval_file_name` : Eval output file. + * `--loss_file_name` : Loss output file. + +There are other arguments about models and training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions. diff --git a/model_zoo/official/recommend/wide_and_deep/eval.py b/model_zoo/official/recommend/wide_and_deep/eval.py new file mode 100644 index 0000000000..7f664f8aba --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/eval.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" test_training """ + +import os + +from mindspore import Model, context +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset, DataType +from src.metrics import AUCMetric +from src.config import WideDeepConfig + + +def get_WideDeep_net(config): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(config) + + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + Wide and deep model builder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def test_eval(config): + """ + test evaluate + """ + data_path = config.data_path + batch_size = config.batch_size + if config.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif config.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size, data_type=dataset_type) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + train_net, eval_net = net_builder.get_net(config) + + param_dict = load_checkpoint(config.ckpt_path) + load_param_into_net(eval_net, param_dict) + + auc_metric = AUCMetric() + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + model.eval(ds_eval, callbacks=eval_callback) + + +if __name__ == "__main__": + widedeep_config = WideDeepConfig() + widedeep_config.argparse_init() + + context.set_context(mode=context.GRAPH_MODE, device_target=widedeep_config.device_target) + test_eval(widedeep_config) diff --git a/model_zoo/official/recommend/wide_and_deep/script/cluster_32p.json b/model_zoo/official/recommend/wide_and_deep/script/cluster_32p.json new file mode 100644 index 0000000000..3f6dfb57fb --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/cluster_32p.json @@ -0,0 +1,21 @@ +{ + "rank_size": 32, + "cluster": { + "xx.xx.xx.xx": { + "user": "", + "passwd": "" + }, + "xx.xx.xx.xx": { + "user": "", + "passwd": "" + }, + "xx.xx.xx.xx": { + "user": "", + "passwd": "" + }, + "xx.xx.xx.xx": { + "user": "", + "passwd": "" + } + } +} \ No newline at end of file diff --git a/model_zoo/official/recommend/wide_and_deep/script/common.sh b/model_zoo/official/recommend/wide_and_deep/script/common.sh new file mode 100644 index 0000000000..06164ce42c --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/common.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +SSH="ssh -o StrictHostKeyChecking=no" +SCP="scp -o StrictHostKeyChecking=no" + +error_msg() +{ + local msg="$*" + echo "[ERROR]: $msg" 1>&2 + exit 1 +} + +ssh_pass() +{ + local node="$1" + local user="$2" + local passwd="$3" + shift 3 + local cmd="$*" + sshpass -p "${passwd}" ${SSH} "${user}"@"${node}" ${cmd} +} + +scp_pass() +{ + local node="$1" + local user="$2" + local passwd="$3" + local src="$4" + local target="$5" + sshpass -p "${passwd}" ${SCP} -r "${src}" "${user}"@"${node}":"${target}" +} + +rscp_pass() +{ + local node="$1" + local user="$2" + local passwd="$3" + local src="$4" + local target="$5" + sshpass -p "${passwd}" ${SCP} -r "${user}"@"${node}":"${src}" "${target}" +} + +get_rank_size() +{ + local cluster_config=$1 + cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["rank_size"])' +} + +get_train_dataset() +{ + local cluster_config=$1 + cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["train_dataset"])' +} + +get_cluster_list() +{ + local cluster_config=$1 + cat ${cluster_config} | python3 -c 'import sys,json;[print(node) for node in json.load(sys.stdin)["cluster"].keys()]' | sort +} + +get_node_user() +{ + local cluster_config=$1 + local node=$2 + cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["user"])' +} + +get_node_passwd() +{ + local cluster_config=$1 + local node=$2 + cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["passwd"])' +} + +rsync_sshpass() +{ + local node=$1 + local user="$2" + local passwd="$3" + scp_pass "${node}" "${user}" "${passwd}" /usr/local/bin/sshpass /usr/local/bin/sshpass +} diff --git a/model_zoo/official/recommend/wide_and_deep/script/deploy_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/deploy_cluster.sh new file mode 100644 index 0000000000..291181eb1a --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/deploy_cluster.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +SCRIPTPATH="$( cd "$(dirname "$0")" || exit ; pwd -P )" +# shellcheck source=/dev/null +source $SCRIPTPATH/common.sh +cluster_config_path=$1 +execute_path=$2 +RANK_SIZE=$(get_rank_size ${cluster_config_path}) +RANK_START=0 +node_list=$(get_cluster_list ${cluster_config_path}) + +for node in ${node_list} +do + user=$(get_node_user ${cluster_config_path} ${node}) + passwd=$(get_node_passwd ${cluster_config_path} ${node}) + echo "------------------${user}@${node}---------------------" + ssh_pass ${node} ${user} ${passwd} "rm -rf ${execute_path}" + scp_pass ${node} ${user} ${passwd} $SCRIPTPATH/../../wide_and_deep ${execute_path} + RANK_START=$[RANK_START+8] + if [[ $RANK_START -ge $RANK_SIZE ]]; then + break; + fi +done \ No newline at end of file diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train.sh new file mode 100644 index 0000000000..4f8091fc03 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# bash run_multinpu_train.sh +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export RANK_TABLE_FILE=$4 + +for((i=0;i<$RANK_SIZE;i++)); +do + rm -rf ${execute_path}/device_$i/ + mkdir ${execute_path}/device_$i/ + cd ${execute_path}/device_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & +done diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh new file mode 100644 index 0000000000..f3482a4205 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +echo ${execute_path} +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +echo ${self_path} + +export RANK_SIZE=$1 +RANK_START=$2 +EPOCH_SIZE=$3 +VOCAB_SIZE=$4 +EMB_DIM=$5 +DATASET=$6 +ENV_SH=$7 +MODE=$8 +export MINDSPORE_HCCL_CONFIG=$9 +export RANK_TABLE_FILE=$9 +DEVICE_START=0 +# shellcheck source=/dev/null +source $ENV_SH +for((i=0;i<=7;i++)); +do + export RANK_ID=$[i+RANK_START] + export DEVICE_ID=$[i+DEVICE_START] + rm -rf ${execute_path}/device_$RANK_ID + mkdir ${execute_path}/device_$RANK_ID + cd ${execute_path}/device_$RANK_ID || exit + if [ $MODE == "host_device_mix" ]; then + python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 & + else + python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 & + fi +done \ No newline at end of file diff --git a/model_zoo/wide_and_deep/script/run_multigpu_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_multigpu_train.sh similarity index 100% rename from model_zoo/wide_and_deep/script/run_multigpu_train.sh rename to model_zoo/official/recommend/wide_and_deep/script/run_multigpu_train.sh diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_multinpu_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_multinpu_train.sh new file mode 100644 index 0000000000..a922d6e7c5 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_multinpu_train.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# bash run_multinpu_train.sh +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export RANK_TABLE_FILE=$4 + +for((i=0;i<$RANK_SIZE;i++)); +do + rm -rf ${execute_path}/device_$i/ + mkdir ${execute_path}/device_$i/ + cd ${execute_path}/device_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_distribute.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & +done diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train.sh new file mode 100644 index 0000000000..d7f8d41a52 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export RANK_TABLE_FILE=$4 + +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +export MS_WORKER_NUM=$RANK_SIZE +export MS_SERVER_NUM=$5 +export MS_SCHED_HOST=$6 +export MS_SCHED_PORT=$7 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >sched_$i.log 2>&1 & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >server_$i.log 2>&1 & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >worker_$i.log 2>&1 & +done diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh new file mode 100644 index 0000000000..d2d885d420 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") + +#bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE +# LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM +# SCHED_HOST SCHED_PORT ROLE +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export RANK_TABLE_FILE=$4 + +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +export MS_WORKER_NUM=$RANK_SIZE +export LOCAL_WORKER_NUM=$5 +export LOCAL_SERVER_NUM=$6 +export MS_SERVER_NUM=$7 +export MS_SCHED_HOST=$8 +export MS_SCHED_PORT=$9 +export MS_ROLE=${10} +echo "=====Role is $MS_ROLE======" + + +if [ "$MS_ROLE" == "MS_SCHED" ];then +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >sched_$i.log 2>&1 & +done +fi + +if [ "$MS_ROLE" == "MS_PSERVER" ];then +for((i=0;i<$LOCAL_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >server_$i.log 2>&1 & +done +fi + +if [ "$MS_ROLE" == "MS_WORKER" ];then +for((i=0;i<$LOCAL_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 >worker_$i.log 2>&1 & +done +fi diff --git a/model_zoo/wide_and_deep/script/run_standalone_train_for_gpu.sh b/model_zoo/official/recommend/wide_and_deep/script/run_standalone_train_for_gpu.sh similarity index 100% rename from model_zoo/wide_and_deep/script/run_standalone_train_for_gpu.sh rename to model_zoo/official/recommend/wide_and_deep/script/run_standalone_train_for_gpu.sh diff --git a/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh new file mode 100644 index 0000000000..06c9e4dfb5 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +echo ${execute_path} +script_self=$(readlink -f "$0") +SCRIPTPATH=$(dirname "${script_self}") +echo ${SCRIPTPATH} +# shellcheck source=/dev/null +source $SCRIPTPATH/common.sh +cluster_config_path=$1 +RANK_SIZE=$(get_rank_size ${cluster_config_path}) +RANK_START=0 +node_list=$(get_cluster_list ${cluster_config_path}) +EPOCH_SIZE=$2 +VOCAB_SIZE=$3 +EMB_DIM=$4 +DATASET=$5 +RANK_TABLE_FILE=$6 +ENV_SH=$7 +MODE=$8 + +for node in ${node_list} +do + user=$(get_node_user ${cluster_config_path} ${node}) + passwd=$(get_node_passwd ${cluster_config_path} ${node}) + echo "------------------${user}@${node}---------------------" + if [ $MODE == "host_device_mix" ]; then + ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}" + else + echo "[ERROR] mode is wrong" + exit 1 + fi + RANK_START=$[RANK_START+8] + if [[ $RANK_START -ge $RANK_SIZE ]]; then + break; + fi +done \ No newline at end of file diff --git a/model_zoo/official/recommend/wide_and_deep/src/__init__.py b/model_zoo/official/recommend/wide_and_deep/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py new file mode 100644 index 0000000000..9325292705 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py @@ -0,0 +1,110 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +callbacks +""" +import time +from mindspore.train.callback import Callback +from mindspore import context +from mindspore.train import ParallelMode + +def add_write(file_path, out_str): + """ + add lines to the file + """ + with open(file_path, 'a+', encoding="utf-8") as file_out: + file_out.write(out_str + "\n") + + +class LossCallBack(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF, terminate the training. + + Note: + If per_print_times is 0, do NOT print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, config=None, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("per_print_times must be in and >= 0.") + self._per_print_times = per_print_times + self.config = config + + def step_end(self, run_context): + cb_params = run_context.original_args() + wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy() + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + cur_num = cb_params.cur_step_num + print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True) + + # raise ValueError + if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None: + loss_file = open(self.config.loss_file_name, "a+") + loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % + (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) + loss_file.write("\n") + loss_file.close() + print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % + (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) + + +class EvalCallBack(Callback): + """ + Monitor the loss in evaluating. + + If the loss is NAN or INF, terminate evaluating. + + Note: + If per_print_times is 0, do NOT print loss. + + Args: + print_per_step (int): Print loss every times. Default: 1. + """ + def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False): + super(EvalCallBack, self).__init__() + if not isinstance(print_per_step, int) or print_per_step < 0: + raise ValueError("print_per_step must be int and >= 0.") + self.print_per_step = print_per_step + self.model = model + self.eval_dataset = eval_dataset + self.aucMetric = auc_metric + self.aucMetric.clear() + self.eval_file_name = config.eval_file_name + self.eval_values = [] + self.host_device_mix = host_device_mix + + def epoch_end(self, run_context): + """ + epoch end + """ + self.aucMetric.clear() + parallel_mode = context.get_auto_parallel_context("parallel_mode") + if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + context.set_auto_parallel_context(strategy_ckpt_save_file="", + strategy_ckpt_load_file="./strategy_train.ckpt") + start_time = time.time() + out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix)) + end_time = time.time() + eval_time = int(end_time - start_time) + + time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) + out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time) + print(out_str) + self.eval_values = out.values() + add_write(self.eval_file_name, out_str) diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py new file mode 100644 index 0000000000..54d83e97b9 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/src/config.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" config. """ +import argparse + + +def argparse_init(): + """ + argparse_init + """ + parser = argparse.ArgumentParser(description='WideDeep') + parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], + help="device where the code will be implemented. (Default: Ascend)") + parser.add_argument("--data_path", type=str, default="./test_raw_data/") + parser.add_argument("--epochs", type=int, default=15) + parser.add_argument("--full_batch", type=bool, default=False) + parser.add_argument("--batch_size", type=int, default=16000) + parser.add_argument("--eval_batch_size", type=int, default=16000) + parser.add_argument("--field_size", type=int, default=39) + parser.add_argument("--vocab_size", type=int, default=200000) + parser.add_argument("--emb_dim", type=int, default=80) + parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) + parser.add_argument("--deep_layer_act", type=str, default='relu') + parser.add_argument("--keep_prob", type=float, default=1.0) + parser.add_argument("--dropout_flag", type=int, default=0) + parser.add_argument("--output_path", type=str, default="./output/") + parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") + parser.add_argument("--eval_file_name", type=str, default="eval.log") + parser.add_argument("--loss_file_name", type=str, default="loss.log") + parser.add_argument("--host_device_mix", type=int, default=0) + parser.add_argument("--dataset_type", type=str, default="tfrecord") + parser.add_argument("--parameter_server", type=int, default=0) + return parser + + +class WideDeepConfig(): + """ + WideDeepConfig + """ + def __init__(self): + self.device_target = "Ascend" + self.data_path = "./test_raw_data/" + self.full_batch = False + self.epochs = 15 + self.batch_size = 16000 + self.eval_batch_size = 16000 + self.field_size = 39 + self.vocab_size = 200000 + self.emb_dim = 80 + self.deep_layer_dim = [1024, 512, 256, 128] + self.deep_layer_act = 'relu' + self.weight_bias_init = ['normal', 'normal'] + self.emb_init = 'normal' + self.init_args = [-0.01, 0.01] + self.dropout_flag = False + self.keep_prob = 1.0 + self.l2_coef = 8e-5 + + self.output_path = "./output" + self.eval_file_name = "eval.log" + self.loss_file_name = "loss.log" + self.ckpt_path = "./checkpoints/" + self.host_device_mix = 0 + self.dataset_type = "tfrecord" + self.parameter_server = 0 + + def argparse_init(self): + """ + argparse_init + """ + parser = argparse_init() + args, _ = parser.parse_known_args() + self.device_target = args.device_target + self.data_path = args.data_path + self.epochs = args.epochs + self.full_batch = args.full_batch + self.batch_size = args.batch_size + self.eval_batch_size = args.eval_batch_size + self.field_size = args.field_size + self.vocab_size = args.vocab_size + self.emb_dim = args.emb_dim + self.deep_layer_dim = args.deep_layer_dim + self.deep_layer_act = args.deep_layer_act + self.keep_prob = args.keep_prob + self.weight_bias_init = ['normal', 'normal'] + self.emb_init = 'normal' + self.init_args = [-0.01, 0.01] + self.dropout_flag = bool(args.dropout_flag) + self.l2_coef = 8e-5 + + self.output_path = args.output_path + self.eval_file_name = args.eval_file_name + self.loss_file_name = args.loss_file_name + self.ckpt_path = args.ckpt_path + self.host_device_mix = args.host_device_mix + self.dataset_type = args.dataset_type + self.parameter_server = args.parameter_server diff --git a/model_zoo/wide_and_deep/src/datasets.py b/model_zoo/official/recommend/wide_and_deep/src/datasets.py similarity index 100% rename from model_zoo/wide_and_deep/src/datasets.py rename to model_zoo/official/recommend/wide_and_deep/src/datasets.py diff --git a/model_zoo/official/recommend/wide_and_deep/src/generate_synthetic_data.py b/model_zoo/official/recommend/wide_and_deep/src/generate_synthetic_data.py new file mode 100644 index 0000000000..0a90a6449c --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/src/generate_synthetic_data.py @@ -0,0 +1,95 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Generate the synthetic data for wide&deep model training""" +import time +import argparse +import numpy as np + + +def generate_data(output_path, label_dim, number_examples, dense_dim, slot_dim, vocabulary_size, random_slot_values): + """ + This function generates the synthetic data of the web clicking data. Each row in the output file is as follows + 'label\tdense_feature[0] dense_feature[1] ... sparse_feature[0]...sparse_feature[1]...' + Each value is dilimited by '\t'. + Args: + output_path: string. The output file path of the synthetic data + label_dim: int. The category of the label. For 0-1 clicking problem, it's value is 2 + number_examples: int. The row numbers of the synthetic dataset + dense_dim: int. The number of continue features. + slot_dim: int. The number of the category features + vocabulary_size: int. The value of vocabulary size + random_slot_values: bool. If true, the id is geneted by the random. If false, the id is set by the row_index + mod part_size, where part_size the the vocab size for each slot + """ + + part_size = (vocabulary_size - dense_dim) // slot_dim + + if random_slot_values is True: + print('Each field size is supposed to be {}, so number of examples should be no less than this value'.format( + part_size)) + + start = time.time() + + buffer_data = [] + + with open(output_path, 'w') as fp: + for i in range(number_examples): + example = [] + label = i % label_dim + example.append(label) + + dense_feature = ["{:.3f}".format(j + 0.01 * i % 10) for j in range(dense_dim)] + example.extend(dense_feature) + + if random_slot_values is True: + for j in range(slot_dim): + example.append(dense_dim + np.random.randint(j * part_size, min((j + 1) * part_size, + vocabulary_size - dense_dim - 1))) + else: + sp = i % part_size + example.extend( + [dense_dim + min(sp + j * part_size, vocabulary_size - dense_dim - 1) for j in range(slot_dim)]) + + buffer_data.append("\t".join([str(item) for item in example])) + + if (i + 1) % 10000 == 0: + end = time.time() + speed = 10000 / (end - start) + start = time.time() + print("Processed {} examples with speed {:.2f} examples/s".format(i + 1, speed)) + fp.write('\n'.join(buffer_data) + '\n') + buffer_data = [] + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate Synthetic Data') + + parser.add_argument("--output_file", type=str, default="./train.txt", help='The output path of the generated file') + parser.add_argument("--label_dim", type=int, default=2, help='The label category') + parser.add_argument("--number_examples", type=int, default=4000000, help='The row numbers of the generated file') + parser.add_argument("--dense_dim", type=int, default=13, help='The number of the continue feature.') + parser.add_argument("--slot_dim", type=int, default=26, help="The number of the category features") + parser.add_argument("--vocabulary_size", type=int, default=400000000, + help="The vocabulary size of the total dataset") + parser.add_argument("--random_slot_values", type=int, default=0, + help="If 1, the id is geneted by the random. If false, the id is set by " + "the row_index mod part_size, where part_size the the vocab size for each slot") + args = parser.parse_args() + args.random_slot_values = bool(args.random_slot_values) + + generate_data(output_path=args.output_file, label_dim=args.label_dim, number_examples=args.number_examples, + dense_dim=args.dense_dim, slot_dim=args.slot_dim, vocabulary_size=args.vocabulary_size, + random_slot_values=args.random_slot_values) diff --git a/model_zoo/wide_and_deep/src/metrics.py b/model_zoo/official/recommend/wide_and_deep/src/metrics.py similarity index 100% rename from model_zoo/wide_and_deep/src/metrics.py rename to model_zoo/official/recommend/wide_and_deep/src/metrics.py diff --git a/model_zoo/official/recommend/wide_and_deep/src/preprocess_data.py b/model_zoo/official/recommend/wide_and_deep/src/preprocess_data.py new file mode 100644 index 0000000000..439d16c807 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/src/preprocess_data.py @@ -0,0 +1,309 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Download raw data and preprocessed data.""" +import os +import pickle +import collections +import argparse +import urllib.request +import tarfile +import numpy as np +from mindspore.mindrecord import FileWriter + + +class StatsDict(): + """preprocessed data""" + + def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert): + self.field_size = field_size + self.dense_dim = dense_dim + self.slot_dim = slot_dim + self.skip_id_convert = bool(skip_id_convert) + + self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)] + self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)] + + self.val_min_dict = {col: 0 for col in self.val_cols} + self.val_max_dict = {col: 0 for col in self.val_cols} + + self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols} + + self.oov_prefix = "OOV" + + self.cat2id_dict = {} + self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)}) + self.cat2id_dict.update( + {self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)}) + + def stats_vals(self, val_list): + """Handling weights column""" + assert len(val_list) == len(self.val_cols) + + def map_max_min(i, val): + key = self.val_cols[i] + if val != "": + if float(val) > self.val_max_dict[key]: + self.val_max_dict[key] = float(val) + if float(val) < self.val_min_dict[key]: + self.val_min_dict[key] = float(val) + + for i, val in enumerate(val_list): + map_max_min(i, val) + + def stats_cats(self, cat_list): + """Handling cats column""" + + assert len(cat_list) == len(self.cat_cols) + + def map_cat_count(i, cat): + key = self.cat_cols[i] + self.cat_count_dict[key][cat] += 1 + + for i, cat in enumerate(cat_list): + map_cat_count(i, cat) + + def save_dict(self, dict_path, prefix=""): + with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt: + pickle.dump(self.val_max_dict, file_wrt) + with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt: + pickle.dump(self.val_min_dict, file_wrt) + with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt: + pickle.dump(self.cat_count_dict, file_wrt) + + def load_dict(self, dict_path, prefix=""): + with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt: + self.val_max_dict = pickle.load(file_wrt) + with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt: + self.val_min_dict = pickle.load(file_wrt) + with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt: + self.cat_count_dict = pickle.load(file_wrt) + print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items()))) + print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items()))) + + def get_cat2id(self, threshold=100): + for key, cat_count_d in self.cat_count_dict.items(): + new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items())) + for cat_str, _ in new_cat_count_d.items(): + self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict) + print("cat2id_dict.size:{}".format(len(self.cat2id_dict))) + print("cat2id.dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50])) + + def map_cat2id(self, values, cats): + """Cat to id""" + + def minmax_scale_value(i, val): + max_v = float(self.val_max_dict["val_{}".format(i + 1)]) + return float(val) * 1.0 / max_v + + id_list = [] + weight_list = [] + for i, val in enumerate(values): + if val == "": + id_list.append(i) + weight_list.append(0) + else: + key = "val_{}".format(i + 1) + id_list.append(self.cat2id_dict[key]) + weight_list.append(minmax_scale_value(i, float(val))) + + for i, cat_str in enumerate(cats): + key = "cat_{}".format(i + 1) + "_" + cat_str + if key in self.cat2id_dict: + if self.skip_id_convert is True: + # For the synthetic data, if the generated id is between [0, max_vcoab], but the num examples is l + # ess than vocab_size/ slot_nums the id will still be converted to [0, real_vocab], where real_vocab + # the actually the vocab size, rather than the max_vocab. So a simple way to alleviate this + # problem is skip the id convert, regarding the synthetic data id as the final id. + id_list.append(cat_str) + else: + id_list.append(self.cat2id_dict[key]) + else: + id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)]) + weight_list.append(1.0) + return id_list, weight_list + + +def mkdir_path(file_path): + if not os.path.exists(file_path): + os.makedirs(file_path) + + +def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot_dim=26): + """Preprocess data and save data""" + with open(file_path, encoding="utf-8") as file_in: + errorline_list = [] + count = 0 + for line in file_in: + count += 1 + line = line.strip("\n") + items = line.split("\t") + if len(items) != (dense_dim + slot_dim + 1): + errorline_list.append(count) + print("Found line length: {}, suppose to be {}, the line is {}".format(len(items), + dense_dim + slot_dim + 1, line)) + continue + if count % 1000000 == 0: + print("Have handled {}w lines.".format(count // 10000)) + values = items[1: dense_dim + 1] + cats = items[dense_dim + 1:] + + assert len(values) == dense_dim, "values.size: {}".format(len(values)) + assert len(cats) == slot_dim, "cats.size: {}".format(len(cats)) + criteo_stats_dict.stats_vals(values) + criteo_stats_dict.stats_cats(cats) + criteo_stats_dict.save_dict(dict_output_path) + + +def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000, + line_per_sample=1000, train_line_count=None, + test_size=0.1, seed=2020, dense_dim=13, slot_dim=26): + """Random split data and save mindrecord""" + if train_line_count is None: + raise ValueError("Please provide training file line count") + test_size = int(train_line_count * test_size) + all_indices = [i for i in range(train_line_count)] + np.random.seed(seed) + np.random.shuffle(all_indices) + print("all_indices.size:{}".format(len(all_indices))) + test_indices_set = set(all_indices[:test_size]) + print("test_indices_set.size:{}".format(len(test_indices_set))) + print("-----------------------" * 10 + "\n" * 2) + + train_data_list = [] + test_data_list = [] + ids_list = [] + wts_list = [] + label_list = [] + + writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21) + writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3) + + schema = {"label": {"type": "float32", "shape": [-1]}, "feat_vals": {"type": "float32", "shape": [-1]}, + "feat_ids": {"type": "int32", "shape": [-1]}} + writer_train.add_schema(schema, "CRITEO_TRAIN") + writer_test.add_schema(schema, "CRITEO_TEST") + + with open(input_file_path, encoding="utf-8") as file_in: + items_error_size_lineCount = [] + count = 0 + train_part_number = 0 + test_part_number = 0 + for i, line in enumerate(file_in): + count += 1 + if count % 1000000 == 0: + print("Have handle {}w lines.".format(count // 10000)) + line = line.strip("\n") + items = line.split("\t") + if len(items) != (1 + dense_dim + slot_dim): + items_error_size_lineCount.append(i) + continue + label = float(items[0]) + values = items[1:1 + dense_dim] + cats = items[1 + dense_dim:] + + assert len(values) == dense_dim, "values.size: {}".format(len(values)) + assert len(cats) == slot_dim, "cats.size: {}".format(len(cats)) + + ids, wts = criteo_stats_dict.map_cat2id(values, cats) + + ids_list.extend(ids) + wts_list.extend(wts) + label_list.append(label) + + if count % line_per_sample == 0: + if i not in test_indices_set: + train_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32), + "feat_vals": np.array(wts_list, dtype=np.float32), + "label": np.array(label_list, dtype=np.float32) + }) + else: + test_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32), + "feat_vals": np.array(wts_list, dtype=np.float32), + "label": np.array(label_list, dtype=np.float32) + }) + if train_data_list and len(train_data_list) % part_rows == 0: + writer_train.write_raw_data(train_data_list) + train_data_list.clear() + train_part_number += 1 + + if test_data_list and len(test_data_list) % part_rows == 0: + writer_test.write_raw_data(test_data_list) + test_data_list.clear() + test_part_number += 1 + + ids_list.clear() + wts_list.clear() + label_list.clear() + + if train_data_list: + writer_train.write_raw_data(train_data_list) + if test_data_list: + writer_test.write_raw_data(test_data_list) + writer_train.commit() + writer_test.commit() + + print("-------------" * 10) + print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount))) + print("-------------" * 10) + np.save("items_error_size_lineCount.npy", items_error_size_lineCount) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="criteo data") + parser.add_argument("--data_type", type=str, default='criteo', choices=['criteo', 'synthetic'], + help='Currently we support criteo dataset and synthetic dataset') + parser.add_argument("--data_path", type=str, default="./criteo_data/", help='The path of the data file') + parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields') + parser.add_argument("--slot_dim", type=int, default=26, + help='The number of your sparse fields, it can also be called catelogy features.') + parser.add_argument("--threshold", type=int, default=100, + help='Word frequency below this will be regarded as OOV. It aims to reduce the vocab size') + parser.add_argument("--train_line_count", type=int, help='The number of examples in your dataset') + parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1], + help='Skip the id convert, regarding the original id as the final id.') + + args, _ = parser.parse_known_args() + data_path = args.data_path + + if args.data_type == 'criteo': + download_data_path = data_path + "origin_data/" + mkdir_path(download_data_path) + + url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz" + file_name = download_data_path + '/' + url.split('/')[-1] + urllib.request.urlretrieve(url, filename=file_name) + + tar = tarfile.open(file_name) + names = tar.getnames() + for name in names: + tar.extract(name, path=download_data_path) + tar.close() + target_field_size = args.dense_dim + args.slot_dim + stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim, + skip_id_convert=args.skip_id_convert) + data_file_path = data_path + "origin_data/train.txt" + stats_output_path = data_path + "stats_dict/" + mkdir_path(stats_output_path) + statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim) + + stats.load_dict(dict_path=stats_output_path, prefix="") + stats.get_cat2id(threshold=args.threshold) + + in_file_path = data_path + "origin_data/train.txt" + output_path = data_path + "mindrecord/" + mkdir_path(output_path) + random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000, + train_line_count=args.train_line_count, line_per_sample=1000, + test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim) diff --git a/model_zoo/wide_and_deep/src/process_data.py b/model_zoo/official/recommend/wide_and_deep/src/process_data.py similarity index 100% rename from model_zoo/wide_and_deep/src/process_data.py rename to model_zoo/official/recommend/wide_and_deep/src/process_data.py diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py new file mode 100644 index 0000000000..d66cb0772c --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -0,0 +1,400 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""wide and deep model""" +from mindspore import nn +from mindspore import Parameter, ParameterTuple +import mindspore.common.dtype as mstype +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.nn import Dropout +from mindspore.nn.optim import Adam, FTRL, LazyAdam +# from mindspore.nn.metrics import Metric +from mindspore.common.initializer import Uniform, initializer +# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean +from mindspore.train.parallel_utils import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.communication.management import get_group_size +import numpy as np + +np_type = np.float32 +ms_type = mstype.float32 + + +def init_method(method, shape, name, max_val=1.0): + ''' + parameter init method + ''' + if method in ['uniform']: + params = Parameter(initializer( + Uniform(max_val), shape, ms_type), name=name) + elif method == "one": + params = Parameter(initializer("ones", shape, ms_type), name=name) + elif method == 'zero': + params = Parameter(initializer("zeros", shape, ms_type), name=name) + elif method == "normal": + params = Parameter(initializer("normal", shape, ms_type), name=name) + return params + + +def init_var_dict(init_args, in_vars): + ''' + var init function + ''' + var_map = {} + _, _max_val = init_args + for _, iterm in enumerate(in_vars): + key, shape, method = iterm + if key not in var_map.keys(): + if method in ['random', 'uniform']: + var_map[key] = Parameter(initializer( + Uniform(_max_val), shape, ms_type), name=key) + elif method == "one": + var_map[key] = Parameter(initializer( + "ones", shape, ms_type), name=key) + elif method == "zero": + var_map[key] = Parameter(initializer( + "zeros", shape, ms_type), name=key) + elif method == 'normal': + var_map[key] = Parameter(initializer( + "normal", shape, ms_type), name=key) + return var_map + + +class DenseLayer(nn.Cell): + """ + Dense Layer for Deep Layer of WideDeep Model; + Containing: activation, matmul, bias_add; + Args: + """ + + def __init__(self, input_dim, output_dim, weight_bias_init, act_str, + keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False): + super(DenseLayer, self).__init__() + weight_init, bias_init = weight_bias_init + self.weight = init_method( + weight_init, [input_dim, output_dim], name="weight") + self.bias = init_method(bias_init, [output_dim], name="bias") + self.act_func = self._init_activation(act_str) + self.matmul = P.MatMul(transpose_b=False) + self.bias_add = P.BiasAdd() + self.cast = P.Cast() + self.dropout = Dropout(keep_prob=keep_prob) + self.use_activation = use_activation + self.convert_dtype = convert_dtype + self.drop_out = drop_out + + def _init_activation(self, act_str): + act_str = act_str.lower() + if act_str == "relu": + act_func = P.ReLU() + elif act_str == "sigmoid": + act_func = P.Sigmoid() + elif act_str == "tanh": + act_func = P.Tanh() + return act_func + + def construct(self, x): + ''' + Construct Dense layer + ''' + if self.training and self.drop_out: + x = self.dropout(x) + if self.convert_dtype: + x = self.cast(x, mstype.float16) + weight = self.cast(self.weight, mstype.float16) + bias = self.cast(self.bias, mstype.float16) + wx = self.matmul(x, weight) + wx = self.bias_add(wx, bias) + if self.use_activation: + wx = self.act_func(wx) + wx = self.cast(wx, mstype.float32) + else: + wx = self.matmul(x, self.weight) + wx = self.bias_add(wx, self.bias) + if self.use_activation: + wx = self.act_func(wx) + return wx + + +class WideDeepModel(nn.Cell): + """ + From paper: " Wide & Deep Learning for Recommender Systems" + Args: + config (Class): The default config of Wide&Deep + """ + + def __init__(self, config): + super(WideDeepModel, self).__init__() + self.batch_size = config.batch_size + host_device_mix = bool(config.host_device_mix) + parameter_server = bool(config.parameter_server) + parallel_mode = _get_parallel_mode() + is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + if is_auto_parallel: + self.batch_size = self.batch_size * get_group_size() + self.field_size = config.field_size + self.vocab_size = config.vocab_size + self.emb_dim = config.emb_dim + self.deep_layer_dims_list = config.deep_layer_dim + self.deep_layer_act = config.deep_layer_act + self.init_args = config.init_args + self.weight_init, self.bias_init = config.weight_bias_init + self.weight_bias_init = config.weight_bias_init + self.emb_init = config.emb_init + self.drop_out = config.dropout_flag + self.keep_prob = config.keep_prob + self.deep_input_dims = self.field_size * self.emb_dim + self.layer_dims = self.deep_layer_dims_list + [1] + self.all_dim_list = [self.deep_input_dims] + self.layer_dims + + init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init), + ('V_l2', [self.vocab_size, self.emb_dim], self.emb_init), + ('Wide_b', [1], self.emb_init)] + var_map = init_var_dict(self.init_args, init_acts) + self.wide_w = var_map["Wide_w"] + self.wide_b = var_map["Wide_b"] + self.embedding_table = var_map["V_l2"] + if parameter_server: + self.wide_w.set_param_ps() + self.embedding_table.set_param_ps() + self.dense_layer_1 = DenseLayer(self.all_dim_list[0], + self.all_dim_list[1], + self.weight_bias_init, + self.deep_layer_act, + convert_dtype=True, drop_out=config.dropout_flag) + self.dense_layer_2 = DenseLayer(self.all_dim_list[1], + self.all_dim_list[2], + self.weight_bias_init, + self.deep_layer_act, + convert_dtype=True, drop_out=config.dropout_flag) + self.dense_layer_3 = DenseLayer(self.all_dim_list[2], + self.all_dim_list[3], + self.weight_bias_init, + self.deep_layer_act, + convert_dtype=True, drop_out=config.dropout_flag) + self.dense_layer_4 = DenseLayer(self.all_dim_list[3], + self.all_dim_list[4], + self.weight_bias_init, + self.deep_layer_act, + convert_dtype=True, drop_out=config.dropout_flag) + self.dense_layer_5 = DenseLayer(self.all_dim_list[4], + self.all_dim_list[5], + self.weight_bias_init, + self.deep_layer_act, + use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) + self.wide_mul = P.Mul() + self.deep_mul = P.Mul() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.reshape = P.Reshape() + self.deep_reshape = P.Reshape() + self.square = P.Square() + self.shape = P.Shape() + self.tile = P.Tile() + self.concat = P.Concat(axis=1) + self.cast = P.Cast() + if is_auto_parallel and host_device_mix: + self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),)) + self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) + self.deep_embeddinglookup = nn.EmbeddingLookup() + self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) + self.wide_embeddinglookup = nn.EmbeddingLookup() + self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) + self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) + self.deep_reshape.add_prim_attr("skip_redistribution", True) + self.reduce_sum.add_prim_attr("cross_batch", True) + elif parameter_server: + self.deep_embeddinglookup = nn.EmbeddingLookup() + self.wide_embeddinglookup = nn.EmbeddingLookup() + else: + self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') + self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') + + def construct(self, id_hldr, wt_hldr): + """ + Args: + id_hldr: batch ids; + wt_hldr: batch weights; + """ + mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) + # Wide layer + wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr) + wx = self.wide_mul(wide_id_weight, mask) + wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) + # Deep layer + deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr) + vx = self.deep_mul(deep_id_embs, mask) + deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) + deep_in = self.dense_layer_1(deep_in) + deep_in = self.dense_layer_2(deep_in) + deep_in = self.dense_layer_3(deep_in) + deep_in = self.dense_layer_4(deep_in) + deep_out = self.dense_layer_5(deep_in) + out = wide_out + deep_out + return out, self.embedding_table + + +class NetWithLossClass(nn.Cell): + + """" + Provide WideDeep training loss through network. + Args: + network (Cell): The training network + config (Class): WideDeep config + """ + + def __init__(self, network, config): + super(NetWithLossClass, self).__init__(auto_prefix=False) + host_device_mix = bool(config.host_device_mix) + parameter_server = bool(config.parameter_server) + parallel_mode = _get_parallel_mode() + is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server) + self.network = network + self.l2_coef = config.l2_coef + self.loss = P.SigmoidCrossEntropyWithLogits() + self.square = P.Square() + self.reduceMean_false = P.ReduceMean(keep_dims=False) + if is_auto_parallel: + self.reduceMean_false.add_prim_attr("cross_batch", True) + self.reduceSum_false = P.ReduceSum(keep_dims=False) + + def construct(self, batch_ids, batch_wts, label): + ''' + Construct NetWithLossClass + ''' + predict, embedding_table = self.network(batch_ids, batch_wts) + log_loss = self.loss(predict, label) + wide_loss = self.reduceMean_false(log_loss) + if self.no_l2loss: + deep_loss = wide_loss + else: + l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2 + deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v + + return wide_loss, deep_loss + + +class IthOutputCell(nn.Cell): + def __init__(self, network, output_index): + super(IthOutputCell, self).__init__() + self.network = network + self.output_index = output_index + + def construct(self, x1, x2, x3): + predict = self.network(x1, x2, x3)[self.output_index] + return predict + + +class TrainStepWrap(nn.Cell): + """ + Encapsulation class of WideDeep network training. + Append Adam and FTRL optimizers to the training network after that construct + function can be called to create the backward graph. + Args: + network (Cell): The training network. Note that loss function should have been added. + sens (Number): The adjust parameter. Default: 1024.0 + host_device_mix (Bool): Whether run in host and device mix mode. Default: False + parameter_server (Bool): Whether run in parameter server mode. Default: False + """ + + def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False): + super(TrainStepWrap, self).__init__() + parallel_mode = _get_parallel_mode() + is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + self.network = network + self.network.set_train() + self.trainable_params = network.trainable_params() + weights_w = [] + weights_d = [] + for params in self.trainable_params: + if 'wide' in params.name: + weights_w.append(params) + else: + weights_d.append(params) + self.weights_w = ParameterTuple(weights_w) + self.weights_d = ParameterTuple(weights_d) + + if host_device_mix and is_auto_parallel: + self.optimizer_d = LazyAdam( + self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) + self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, + l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) + self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU") + self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU") + elif parameter_server: + self.optimizer_d = Adam( + self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) + self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, + l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) + self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU") + self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU") + else: + self.optimizer_d = Adam( + self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) + self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, + l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) + self.hyper_map = C.HyperMap() + self.grad_w = C.GradOperation('grad_w', get_by_list=True, + sens_param=True) + self.grad_d = C.GradOperation('grad_d', get_by_list=True, + sens_param=True) + self.sens = sens + self.loss_net_w = IthOutputCell(network, output_index=0) + self.loss_net_d = IthOutputCell(network, output_index=1) + + self.reducer_flag = False + self.grad_reducer_w = None + self.grad_reducer_d = None + parallel_mode = _get_parallel_mode() + self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, + ParallelMode.HYBRID_PARALLEL) + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree) + self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree) + + def construct(self, batch_ids, batch_wts, label): + ''' + Construct wide and deep model + ''' + weights_w = self.weights_w + weights_d = self.weights_d + loss_w, loss_d = self.network(batch_ids, batch_wts, label) + sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) + sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) + grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts, + label, sens_w) + grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts, + label, sens_d) + if self.reducer_flag: + grads_w = self.grad_reducer_w(grads_w) + grads_d = self.grad_reducer_d(grads_d) + return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, + self.optimizer_d(grads_d)) + + +class PredictWithSigmoid(nn.Cell): + def __init__(self, network): + super(PredictWithSigmoid, self).__init__() + self.network = network + self.sigmoid = P.Sigmoid() + + def construct(self, batch_ids, batch_wts, labels): + logits, _, _, = self.network(batch_ids, batch_wts) + pred_probs = self.sigmoid(logits) + return logits, pred_probs, labels diff --git a/model_zoo/official/recommend/wide_and_deep/train.py b/model_zoo/official/recommend/wide_and_deep/train.py new file mode 100644 index 0000000000..4c4e384b6e --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/train.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" test_training """ +import os +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack +from src.datasets import create_dataset, DataType +from src.config import WideDeepConfig + + +def get_WideDeep_net(configure): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(configure) + + loss_net = NetWithLossClass(WideDeep_net, configure) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + Build the model. + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, configure): + return get_WideDeep_net(configure) + + +def test_train(configure): + """ + test_train + """ + data_path = configure.data_path + batch_size = configure.batch_size + epochs = configure.epochs + if configure.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif configure.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size, data_type=dataset_type) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + + net_builder = ModelBuilder() + train_net, _ = net_builder.get_net(configure) + train_net.set_train() + + model = Model(train_net) + callback = LossCallBack(config=configure) + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), + keep_checkpoint_max=5) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) + model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb]) + + +if __name__ == "__main__": + config = WideDeepConfig() + config.argparse_init() + + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + test_train(config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval.py new file mode 100644 index 0000000000..1a255ce9e5 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" test_training """ +import os + +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset, DataType +from src.metrics import AUCMetric +from src.config import WideDeepConfig + + +def get_WideDeep_net(config): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(config) + + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def test_train_eval(config): + """ + test_train_eval + """ + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + if config.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif config.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size, data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size, data_type=dataset_type) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + callback = LossCallBack(config=config) + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) + + out = model.eval(ds_eval) + print("=====" * 5 + "model.eval() initialized: {}".format(out)) + model.train(epochs, ds_train, + callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) + + +if __name__ == "__main__": + wide_deep_config = WideDeepConfig() + wide_deep_config.argparse_init() + + context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) + test_train_eval(wide_deep_config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py new file mode 100644 index 0000000000..a168b84d79 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py @@ -0,0 +1,138 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_multinpu.""" + + +import os +import sys +import mindspore.dataset.engine as de +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.train import ParallelMode +from mindspore.communication.management import get_rank, get_group_size, init +from mindspore.parallel import _cost_model_context as cost_model_context +from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset, DataType +from src.metrics import AUCMetric +from src.config import WideDeepConfig + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) +context.set_context(variable_memory_max_size="24GB") +context.set_context(enable_sparse=True) +cost_model_context.set_cost_model_context(multi_subgraphs=True) +init() + + + +def get_WideDeep_net(config): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(config) + loss_net = NetWithLossClass(WideDeep_net, config) + loss_net = VirtualDatasetCellTriple(loss_net) + train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix)) + eval_net = PredictWithSigmoid(WideDeep_net) + eval_net = VirtualDatasetCellTriple(eval_net) + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def train_and_eval(config): + """ + test_train_eval + """ + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + if config.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif config.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + host_device_mix = bool(config.host_device_mix) + print("epochs is {}".format(epochs)) + if config.full_batch: + context.set_auto_parallel_context(full_batch=True) + de.config.set_seed(1) + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size*get_group_size(), data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size*get_group_size(), data_type=dataset_type) + else: + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix) + + callback = LossCallBack(config=config) + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', + directory=config.ckpt_path, config=ckptconfig) + context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt") + callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] + if not host_device_mix: + callback_list.append(ckpoint_cb) + model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix)) + + +if __name__ == "__main__": + wide_deep_config = WideDeepConfig() + wide_deep_config.argparse_init() + if wide_deep_config.host_device_mix == 1: + context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True) + else: + context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True) + train_and_eval(wide_deep_config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py new file mode 100644 index 0000000000..5a7cf8c718 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py @@ -0,0 +1,129 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_multinpu.""" + + +import os +import sys +import numpy as np +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.train import ParallelMode +from mindspore.communication.management import get_rank, get_group_size, init + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset, DataType +from src.metrics import AUCMetric +from src.config import WideDeepConfig + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def get_WideDeep_net(config): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(config) + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net) + eval_net = PredictWithSigmoid(WideDeep_net) + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def train_and_eval(config): + """ + test_train_eval + """ + np.random.seed(1000) + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + if config.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif config.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + print("epochs is {}".format(epochs)) + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + callback = LossCallBack(config=config) + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + if config.device_target == "Ascend": + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', + directory=config.ckpt_path, config=ckptconfig) + elif config.device_target == "GPU": + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()), + directory=config.ckpt_path, config=ckptconfig) + out = model.eval(ds_eval) + print("=====" * 5 + "model.eval() initialized: {}".format(out)) + model.train(epochs, ds_train, + callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb], + sink_size=ds_train.get_dataset_size()) + + +if __name__ == "__main__": + wide_deep_config = WideDeepConfig() + wide_deep_config.argparse_init() + + context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) + if wide_deep_config.device_target == "Ascend": + init("hccl") + elif wide_deep_config.device_target == "GPU": + init("nccl") + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=get_group_size()) + + train_and_eval(wide_deep_config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py new file mode 100644 index 0000000000..bab19acdc4 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py @@ -0,0 +1,129 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_multinpu.""" + + +import os +import sys +import numpy as np +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.train import ParallelMode +from mindspore.communication.management import get_rank, get_group_size, init + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset, DataType +from src.metrics import AUCMetric +from src.config import WideDeepConfig + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +context.set_context(enable_sparse=True) + + +def get_WideDeep_net(config): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(config) + loss_net = NetWithLossClass(WideDeep_net, config) + train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server)) + eval_net = PredictWithSigmoid(WideDeep_net) + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def train_and_eval(config): + """ + test_train_eval + """ + np.random.seed(1000) + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + if config.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif config.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + parameter_server = bool(config.parameter_server) + print("epochs is {}".format(epochs)) + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) + + callback = LossCallBack(config=config) + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + if config.device_target == "Ascend": + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', + directory=config.ckpt_path, config=ckptconfig) + elif config.device_target == "GPU": + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()), + directory=config.ckpt_path, config=ckptconfig) + model.train(epochs, ds_train, + callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb], + dataset_sink_mode=(not parameter_server)) + + +if __name__ == "__main__": + wide_deep_config = WideDeepConfig() + wide_deep_config.argparse_init() + + context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) + if wide_deep_config.device_target == "Ascend": + init("hccl") + elif wide_deep_config.device_target == "GPU": + init("nccl") + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=get_group_size()) + + train_and_eval(wide_deep_config) diff --git a/model_zoo/official/utils/.gitkeep b/model_zoo/official/utils/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/research/.gitkeep b/model_zoo/research/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/resnet/README.md b/model_zoo/resnet/README.md deleted file mode 100644 index ad93453602..0000000000 --- a/model_zoo/resnet/README.md +++ /dev/null @@ -1,251 +0,0 @@ -# ResNet Example - -## Description - -These are examples of training ResNet-50/ResNet-101 with CIFAR-10/ImageNet2012 dataset in MindSpore. -(Training ResNet-101 with dataset CIFAR-10 is unsupported now.) - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Download the dataset CIFAR-10 or ImageNet2012 - -CIFAR-10 - -> Unzip the CIFAR-10 dataset to any path you want and the folder structure should include train and eval dataset as follows: -> ``` -> . -> └─dataset -> ├─ cifar-10-batches-bin # train dataset -> └─ cifar-10-verify-bin # evaluate dataset -> ``` - -ImageNet2012 - -> Unzip the ImageNet2012 dataset to any path you want and the folder should include train and eval dataset as follows: -> -> ``` -> . -> └─dataset -> ├─ilsvrc # train dataset -> └─validation_preprocess # evaluate dataset -> ``` - - - -## Structure - -```shell -. -└──resnet - ├── README.md - ├── script - ├── run_distribute_train.sh # launch distributed training(8 pcs) - ├── run_eval.sh # launch evaluation - └── run_standalone_train.sh # launch standalone training(1 pcs) - ├── src - ├── config.py # parameter configuration - ├── dataset.py # data preprocessing - ├── crossentropy.py # loss definition for ImageNet2012 dataset - ├── lr_generator.py # generate learning rate for each step - └── resnet.py # resnet backbone, including resnet50 and resnet101 - ├── eval.py # eval net - └── train.py # train net -``` - - -## Parameter configuration - -Parameters for both training and evaluation can be set in config.py. - -- config for ResNet-50, CIFAR-10 dataset - -``` -"class_num": 10, # dataset class num -"batch_size": 32, # batch size of input tensor -"loss_scale": 1024, # loss scale -"momentum": 0.9, # momentum -"weight_decay": 1e-4, # weight decay -"epoch_size": 90, # only valid for taining, which is always 1 for inference -"save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_steps": 195, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step -"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint -"save_checkpoint_path": "./", # path to save checkpoint -"warmup_epochs": 5, # number of warmup epoch -"lr_decay_mode": "poly" # decay mode can be selected in steps, ploy and default -"lr_init": 0.01, # initial learning rate -"lr_end": 0.00001, # final learning rate -"lr_max": 0.1, # maximum learning rate -``` - -- config for ResNet-50, ImageNet2012 dataset - -``` -"class_num": 1001, # dataset class number -"batch_size": 32, # batch size of input tensor -"loss_scale": 1024, # loss scale -"momentum": 0.9, # momentum optimizer -"weight_decay": 1e-4, # weight decay -"epoch_size": 90, # only valid for taining, which is always 1 for inference -"pretrained_epoch_size": 1, # epoch size that model has been trained before load pretrained checkpoint -"save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch -"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint -"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path -"warmup_epochs": 0, # number of warmup epoch -"lr_decay_mode": "cosine", # decay mode for generating learning rate -"label_smooth": True, # label smooth -"label_smooth_factor": 0.1, # label smooth factor -"lr_init": 0, # initial learning rate -"lr_max": 0.1, # maximum learning rate -``` - -- config for ResNet-101, ImageNet2012 dataset - -``` -"class_num": 1001, # dataset class number -"batch_size": 32, # batch size of input tensor -"loss_scale": 1024, # loss scale -"momentum": 0.9, # momentum optimizer -"weight_decay": 1e-4, # weight decay -"epoch_size": 120, # epoch sizes for training -"pretrain_epoch_size": 0, # epoch size of pretrain checkpoint -"save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch -"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint -"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path -"warmup_epochs": 0, # number of warmup epoch -"lr_decay_mode": "cosine" # decay mode for generating learning rate -"label_smooth": 1, # label_smooth -"label_smooth_factor": 0.1, # label_smooth_factor -"lr": 0.1 # base learning rate -``` - - - -## Running the example - -### Train - -#### Usage - -``` -# distributed training -Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] - [PRETRAINED_CKPT_PATH](optional) - -# standalone training -Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] - [PRETRAINED_CKPT_PATH](optional) -``` - - -#### Launch - -``` -# distribute training example -sh run_distribute_train.sh resnet50 cifar10 rank_table.json ~/cifar-10-batches-bin - -# standalone training example -sh run_standalone_train.sh resnet50 cifar10 ~/cifar-10-batches-bin -``` - -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). - -#### Result - -Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. - -- training ResNet-50 with CIFAR-10 dataset - -``` -# distribute training result(8 pcs) -epoch: 1 step: 195, loss is 1.9601055 -epoch: 2 step: 195, loss is 1.8555021 -epoch: 3 step: 195, loss is 1.6707983 -epoch: 4 step: 195, loss is 1.8162166 -epoch: 5 step: 195, loss is 1.393667 -... -``` - -- training ResNet-50 with ImageNet2012 dataset - -``` -# distribute training result(8 pcs) -epoch: 1 step: 5004, loss is 4.8995576 -epoch: 2 step: 5004, loss is 3.9235563 -epoch: 3 step: 5004, loss is 3.833077 -epoch: 4 step: 5004, loss is 3.2795618 -epoch: 5 step: 5004, loss is 3.1978393 -... -``` - -- training ResNet-101 with ImageNet2012 dataset - -``` -# distribute training result(8p) -epoch: 1 step: 5004, loss is 4.805483 -epoch: 2 step: 5004, loss is 3.2121816 -epoch: 3 step: 5004, loss is 3.429647 -epoch: 4 step: 5004, loss is 3.3667371 -epoch: 5 step: 5004, loss is 3.1718972 -... -epoch: 67 step: 5004, loss is 2.2768745 -epoch: 68 step: 5004, loss is 1.7223864 -epoch: 69 step: 5004, loss is 2.0665488 -epoch: 70 step: 5004, loss is 1.8717369 -... -``` - -### Evaluation - -#### Usage - -``` -# evaluation -Usage: sh run_eval.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] -``` - -#### Launch - -``` -# evaluation example -sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt -``` - -> checkpoint can be produced in training process. - -#### Result - -Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. - -- evaluating ResNet-50 with CIFAR-10 dataset - -``` -result: {'acc': 0.91446314102564111} ckpt=~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt -``` - -- evaluating ResNet-50 with ImageNet2012 dataset - -``` -result: {'acc': 0.7671054737516005} ckpt=train_parallel0/resnet-90_5004.ckpt -``` - -- evaluating ResNet-101 with ImageNet2012 dataset - -``` -result: {'top_5_accuracy': 0.9429417413572343, 'top_1_accuracy': 0.7853513124199744} ckpt=train_parallel0/resnet-120_5004.ckpt -``` - -### Running on GPU -``` -# distributed training example -mpirun -n 8 python train.py ---net=resnet50 --dataset=cifar10 -dataset_path=~/cifar-10-batches-bin --device_target="GPU" --run_distribute=True - -# standalone training example -python train.py --net=resnet50 --dataset=cifar10 --dataset_path=~/cifar-10-batches-bin --device_target="GPU" - -# infer example -python eval.py --net=resnet50 --dataset=cifar10 --dataset_path=~/cifar10-10-verify-bin --device_target="GPU" --checkpoint_path=resnet-90_195.ckpt -``` diff --git a/model_zoo/resnet/eval.py b/model_zoo/resnet/eval.py deleted file mode 100755 index 426b8c9f3d..0000000000 --- a/model_zoo/resnet/eval.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train resnet.""" -import os -import random -import argparse -import numpy as np -from mindspore import context -from mindspore import dataset as de -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.crossentropy import CrossEntropy - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') -parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012') - -parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') -args_opt = parser.parse_args() - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -if args_opt.net == "resnet50": - from src.resnet import resnet50 as resnet - - if args_opt.dataset == "cifar10": - from src.config import config1 as config - from src.dataset import create_dataset1 as create_dataset - else: - from src.config import config2 as config - from src.dataset import create_dataset2 as create_dataset -else: - from src.resnet import resnet101 as resnet - from src.config import config3 as config - from src.dataset import create_dataset3 as create_dataset - -if __name__ == '__main__': - target = args_opt.device_target - - # init context - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False, device_id=device_id) - - # create dataset - if args_opt.net == "resnet50": - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, - target=target) - else: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) - step_size = dataset.get_dataset_size() - - # define net - net = resnet(class_num=config.class_num) - - # load checkpoint - param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(net, param_dict) - net.set_train(False) - - # define loss, model - if args_opt.dataset == "imagenet2012": - if not config.use_label_smooth: - config.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - else: - loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - - # define model - model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) - - # eval model - res = model.eval(dataset) - print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/model_zoo/resnet/scripts/run_distribute_train.sh b/model_zoo/resnet/scripts/run_distribute_train.sh deleted file mode 100755 index efcb620cd8..0000000000 --- a/model_zoo/resnet/scripts/run_distribute_train.sh +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 4 ] && [ $# != 5 ] -then - echo "Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" -exit 1 -fi - -if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] -then - echo "error: the selected net is neither resnet50 nor resnet101" -exit 1 -fi - -if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] -then - echo "error: the selected dataset is neither cifar10 nor imagenet2012" -exit 1 -fi - -if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] -then - echo "error: training resnet101 with cifar10 dataset is unsupported now!" -exit 1 -fi - - -get_real_path(){ - if [ "${1:0:1}" == "/" ]; then - echo "$1" - else - echo "$(realpath -m $PWD/$1)" - fi -} - -PATH1=$(get_real_path $3) -PATH2=$(get_real_path $4) - -if [ $# == 5 ] -then - PATH3=$(get_real_path $5) -fi - -if [ ! -f $PATH1 ] -then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file" -exit 1 -fi - -if [ ! -d $PATH2 ] -then - echo "error: DATASET_PATH=$PATH2 is not a directory" -exit 1 -fi - -if [ $# == 5 ] && [ ! -f $PATH3 ] -then - echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file" -exit 1 -fi - -ulimit -u unlimited -export DEVICE_NUM=8 -export RANK_SIZE=8 -export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 -export RANK_TABLE_FILE=$PATH1 - -for((i=0; i<${DEVICE_NUM}; i++)) -do - export DEVICE_ID=$i - export RANK_ID=$i - rm -rf ./train_parallel$i - mkdir ./train_parallel$i - cp ../*.py ./train_parallel$i - cp *.sh ./train_parallel$i - cp -r ../src ./train_parallel$i - cd ./train_parallel$i || exit - echo "start training for rank $RANK_ID, device $DEVICE_ID" - env > env.log - if [ $# == 4 ] - then - python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log & - fi - - if [ $# == 5 ] - then - python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log & - fi - - cd .. -done diff --git a/model_zoo/resnet/src/dataset.py b/model_zoo/resnet/src/dataset.py deleted file mode 100755 index ac0adc4bc9..0000000000 --- a/model_zoo/resnet/src/dataset.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -create train or eval dataset. -""" -import os -import mindspore.common.dtype as mstype -import mindspore.dataset.engine as de -import mindspore.dataset.transforms.vision.c_transforms as C -import mindspore.dataset.transforms.c_transforms as C2 -from mindspore.communication.management import init, get_rank, get_group_size - - -def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): - """ - create a train or evaluate cifar10 dataset for resnet50 - Args: - dataset_path(string): the path of dataset. - do_train(bool): whether dataset is used for train or eval. - repeat_num(int): the repeat times of dataset. Default: 1 - batch_size(int): the batch size of dataset. Default: 32 - target(str): the device target. Default: Ascend - - Returns: - dataset - """ - if target == "Ascend": - device_num = int(os.getenv("DEVICE_NUM")) - rank_id = int(os.getenv("RANK_ID")) - else: - init("nccl") - rank_id = get_rank() - device_num = get_group_size() - - if device_num == 1: - ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True) - else: - ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True, - num_shards=device_num, shard_id=rank_id) - - # define map operations - trans = [] - if do_train: - trans += [ - C.RandomCrop((32, 32), (4, 4, 4, 4)), - C.RandomHorizontalFlip(prob=0.5) - ] - - trans += [ - C.Resize((224, 224)), - C.Rescale(1.0 / 255.0, 0.0), - C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), - C.HWC2CHW() - ] - - type_cast_op = C2.TypeCast(mstype.int32) - - ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op) - ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans) - - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - # apply dataset repeat operation - ds = ds.repeat(repeat_num) - - return ds - - -def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): - """ - create a train or eval imagenet2012 dataset for resnet50 - - Args: - dataset_path(string): the path of dataset. - do_train(bool): whether dataset is used for train or eval. - repeat_num(int): the repeat times of dataset. Default: 1 - batch_size(int): the batch size of dataset. Default: 32 - target(str): the device target. Default: Ascend - - Returns: - dataset - """ - if target == "Ascend": - device_num = int(os.getenv("DEVICE_NUM")) - rank_id = int(os.getenv("RANK_ID")) - else: - init("nccl") - rank_id = get_rank() - device_num = get_group_size() - - if device_num == 1: - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) - else: - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, - num_shards=device_num, shard_id=rank_id) - - image_size = 224 - mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] - std = [0.229 * 255, 0.224 * 255, 0.225 * 255] - - # define map operations - if do_train: - trans = [ - C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), - C.RandomHorizontalFlip(prob=0.5), - C.Normalize(mean=mean, std=std), - C.HWC2CHW() - ] - else: - trans = [ - C.Decode(), - C.Resize((256, 256)), - C.CenterCrop(image_size), - C.Normalize(mean=mean, std=std), - C.HWC2CHW() - ] - - type_cast_op = C2.TypeCast(mstype.int32) - - ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans) - ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op) - - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - - # apply dataset repeat operation - ds = ds.repeat(repeat_num) - - return ds - - -def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): - """ - create a train or eval imagenet2012 dataset for resnet101 - Args: - dataset_path(string): the path of dataset. - do_train(bool): whether dataset is used for train or eval. - repeat_num(int): the repeat times of dataset. Default: 1 - batch_size(int): the batch size of dataset. Default: 32 - - Returns: - dataset - """ - device_num = int(os.getenv("RANK_SIZE")) - rank_id = int(os.getenv("RANK_ID")) - - if device_num == 1: - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) - else: - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, - num_shards=device_num, shard_id=rank_id) - resize_height = 224 - rescale = 1.0 / 255.0 - shift = 0.0 - - # define map operations - decode_op = C.Decode() - - random_resize_crop_op = C.RandomResizedCrop(resize_height, (0.08, 1.0), (0.75, 1.33), max_attempts=100) - horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1)) - resize_op_256 = C.Resize((256, 256)) - center_crop = C.CenterCrop(224) - rescale_op = C.Rescale(rescale, shift) - normalize_op = C.Normalize((0.475, 0.451, 0.392), (0.275, 0.267, 0.278)) - changeswap_op = C.HWC2CHW() - - if do_train: - trans = [decode_op, - random_resize_crop_op, - horizontal_flip_op, - rescale_op, - normalize_op, - changeswap_op] - - else: - trans = [decode_op, - resize_op_256, - center_crop, - rescale_op, - normalize_op, - changeswap_op] - - type_cast_op = C2.TypeCast(mstype.int32) - - ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8) - ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) - - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - # apply dataset repeat operation - ds = ds.repeat(repeat_num) - - return ds diff --git a/model_zoo/resnet/train.py b/model_zoo/resnet/train.py deleted file mode 100755 index 89ce62d733..0000000000 --- a/model_zoo/resnet/train.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train resnet.""" -import os -import random -import argparse -import numpy as np -from mindspore import context -from mindspore import Tensor -from mindspore import dataset as de -from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.nn.optim.momentum import Momentum -from mindspore.train.model import Model, ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.communication.management import init, get_rank, get_group_size -import mindspore.nn as nn -import mindspore.common.initializer as weight_init -from src.lr_generator import get_lr, warmup_cosine_annealing_lr -from src.crossentropy import CrossEntropy - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') -parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012') -parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') -parser.add_argument('--device_num', type=int, default=1, help='Device num.') - -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') -parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') -args_opt = parser.parse_args() - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -if args_opt.net == "resnet50": - from src.resnet import resnet50 as resnet - - if args_opt.dataset == "cifar10": - from src.config import config1 as config - from src.dataset import create_dataset1 as create_dataset - else: - from src.config import config2 as config - from src.dataset import create_dataset2 as create_dataset -else: - from src.resnet import resnet101 as resnet - from src.config import config3 as config - from src.dataset import create_dataset3 as create_dataset - -if __name__ == '__main__': - target = args_opt.device_target - ckpt_save_dir = config.save_checkpoint_path - - # init context - context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) - if args_opt.run_distribute: - if target == "Ascend": - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(device_id=device_id, enable_auto_mixed_precision=True) - context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - if args_opt.net == "resnet50": - auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) - else: - auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) - init() - # GPU target - else: - init("nccl") - context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" - - # create dataset - if args_opt.net == "resnet50": - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size, - batch_size=config.batch_size, target=target) - else: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size, - batch_size=config.batch_size) - step_size = dataset.get_dataset_size() - - # define net - net = resnet(class_num=config.class_num) - - # init weight - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - else: - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Conv2d): - cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if isinstance(cell, nn.Dense): - cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - - # init lr - if args_opt.net == "resnet50": - if args_opt.dataset == "cifar10": - lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max, - warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size, - lr_decay_mode='poly') - else: - lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs, - total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine') - else: - lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, 120, - config.pretrain_epoch_size * step_size) - lr = Tensor(lr) - - # define opt - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, - config.weight_decay, config.loss_scale) - - # define loss, model - if target == "Ascend": - if args_opt.dataset == "imagenet2012": - if not config.use_label_smooth: - config.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - else: - loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False) - else: - # GPU target - loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean') - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - - # define callbacks - time_cb = TimeMonitor(data_size=step_size) - loss_cb = LossMonitor() - cb = [time_cb, loss_cb] - if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) - cb += [ckpt_cb] - - # train model - model.train(config.epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/resnet50_quant/src/launch.py b/model_zoo/resnet50_quant/src/launch.py deleted file mode 100644 index abba92a540..0000000000 --- a/model_zoo/resnet50_quant/src/launch.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""launch train script""" -import os -import sys -import json -import subprocess -import shutil -import platform -from argparse import ArgumentParser - -def parse_args(): - """ - parse args . - - Args: - - Returns: - args. - - Examples: - >>> parse_args() - """ - parser = ArgumentParser(description="mindspore distributed training launch " - "helper utilty that will spawn up " - "multiple distributed processes") - parser.add_argument("--nproc_per_node", type=int, default=1, - help="The number of processes to launch on each node, " - "for D training, this is recommended to be set " - "to the number of D in your system so that " - "each process can be bound to a single D.") - parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", - help="will use the visible devices sequentially") - parser.add_argument("--server_id", type=str, default="", - help="server ip") - parser.add_argument("--training_script", type=str, - help="The full path to the single D training " - "program/script to be launched in parallel, " - "followed by all the arguments for the " - "training script") - # rest from the training program - args, unknown = parser.parse_known_args() - args.training_script_args = unknown - return args - - -def main(): - print("start", __file__) - args = parse_args() - print(args) - visible_devices = args.visible_devices.split(',') - assert os.path.isfile(args.training_script) - assert len(visible_devices) >= args.nproc_per_node - print('visible_devices:{}'.format(visible_devices)) - if not args.server_id: - print('pleaser input server ip!!!') - exit(0) - print('server_id:{}'.format(args.server_id)) - - # construct hccn_table - hccn_configs = open('/etc/hccn.conf', 'r').readlines() - device_ips = {} - for hccn_item in hccn_configs: - hccn_item = hccn_item.strip() - if hccn_item.startswith('address_'): - device_id, device_ip = hccn_item.split('=') - device_id = device_id.split('_')[1] - device_ips[device_id] = device_ip - print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) - hccn_table = {} - arch = platform.processor() - hccn_table['board_id'] = {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch] - hccn_table['chip_info'] = '910' - hccn_table['deploy_mode'] = 'lab' - hccn_table['group_count'] = '1' - hccn_table['group_list'] = [] - instance_list = [] - usable_dev = '' - for instance_id in range(args.nproc_per_node): - instance = {} - instance['devices'] = [] - device_id = visible_devices[instance_id] - device_ip = device_ips[device_id] - usable_dev += str(device_id) - instance['devices'].append({ - 'device_id': device_id, - 'device_ip': device_ip, - }) - instance['rank_id'] = str(instance_id) - instance['server_id'] = args.server_id - instance_list.append(instance) - hccn_table['group_list'].append({ - 'device_num': str(args.nproc_per_node), - 'server_num': '1', - 'group_name': '', - 'instance_count': str(args.nproc_per_node), - 'instance_list': instance_list, - }) - hccn_table['para_plane_nic_location'] = 'device' - hccn_table['para_plane_nic_name'] = [] - for instance_id in range(args.nproc_per_node): - eth_id = visible_devices[instance_id] - hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) - hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) - hccn_table['status'] = 'completed' - - # save hccn_table to file - table_path = os.getcwd() - if not os.path.exists(table_path): - os.mkdir(table_path) - table_fn = os.path.join(table_path, - 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) - with open(table_fn, 'w') as table_fp: - json.dump(hccn_table, table_fp, indent=4) - sys.stdout.flush() - - # spawn the processes - processes = [] - cmds = [] - log_files = [] - env = os.environ.copy() - env['RANK_SIZE'] = str(args.nproc_per_node) - cur_path = os.getcwd() - for rank_id in range(0, args.nproc_per_node): - os.chdir(cur_path) - device_id = visible_devices[rank_id] - device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) - env['RANK_ID'] = str(rank_id) - env['DEVICE_ID'] = str(device_id) - if args.nproc_per_node > 1: - env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn - env['RANK_TABLE_FILE'] = table_fn - if os.path.exists(device_dir): - shutil.rmtree(device_dir) - os.mkdir(device_dir) - os.chdir(device_dir) - cmd = [sys.executable, '-u'] - cmd.append(args.training_script) - cmd.extend(args.training_script_args) - log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') - process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) - processes.append(process) - cmds.append(cmd) - log_files.append(log_file) - for process, cmd, log_file in zip(processes, cmds, log_files): - process.wait() - if process.returncode != 0: - raise subprocess.CalledProcessError(returncode=process, cmd=cmd) - log_file.close() - - -if __name__ == "__main__": - main() diff --git a/model_zoo/resnet50_quant/train.py b/model_zoo/resnet50_quant/train.py deleted file mode 100755 index b026f97278..0000000000 --- a/model_zoo/resnet50_quant/train.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Train Resnet50 on ImageNet""" - -import os -import argparse - -from mindspore import context -from mindspore import Tensor -from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.nn.optim.momentum import Momentum -from mindspore.train.model import Model, ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.serialization import load_checkpoint -from mindspore.train.quant import quant -from mindspore.communication.management import init -import mindspore.nn as nn -import mindspore.common.initializer as weight_init - -from models.resnet_quant import resnet50_quant -from src.dataset import create_dataset -from src.lr_generator import get_lr -from src.config import quant_set, config_quant, config_noquant -from src.crossentropy import CrossEntropy -from src.utils import _load_param_into_net - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') -parser.add_argument('--device_num', type=int, default=1, help='Device num.') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') -parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') -parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path') -args_opt = parser.parse_args() -config = config_quant if quant_set.quantization_aware else config_noquant - -if args_opt.device_target == "Ascend": - device_id = int(os.getenv('DEVICE_ID')) - rank_id = int(os.getenv('RANK_ID')) - rank_size = int(os.getenv('RANK_SIZE')) - run_distribute = rank_size > 1 - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - save_graphs=False, - device_id=device_id, - enable_auto_mixed_precision=True) -else: - raise ValueError("Unsupported device target.") - -if __name__ == '__main__': - # train on ascend - print("training args: {}".format(args_opt)) - print("training configure: {}".format(config)) - print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) - epoch_size = config.epoch_size - - # distribute init - if run_distribute: - context.set_auto_parallel_context(device_num=rank_size, - parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, - mirror_mean=True) - init() - context.set_auto_parallel_context(device_num=args_opt.device_num, - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) - - # define network - net = resnet50_quant(class_num=config.class_num) - net.set_train(True) - - # weight init and load checkpoint file - if args_opt.pre_trained: - param_dict = load_checkpoint(args_opt.pre_trained) - _load_param_into_net(net, param_dict) - epoch_size = config.epoch_size - config.pretrained_epoch_size - else: - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Conv2d): - cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if isinstance(cell, nn.Dense): - cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if not config.use_label_smooth: - config.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - - # define dataset - dataset = create_dataset(dataset_path=args_opt.dataset_path, - do_train=True, - repeat_num=epoch_size, - batch_size=config.batch_size, - target=args_opt.device_target) - step_size = dataset.get_dataset_size() - - if quant_set.quantization_aware: - # convert fusion network to quantization aware network - net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) - - # get learning rate - lr = get_lr(lr_init=config.lr_init, - lr_end=0.0, - lr_max=config.lr_max, - warmup_epochs=config.warmup_epochs, - total_epochs=config.epoch_size, - steps_per_epoch=step_size, - lr_decay_mode='cosine') - if args_opt.pre_trained: - lr = lr[config.pretrained_epoch_size * step_size:] - lr = Tensor(lr) - - # define optimization - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, - config.weight_decay, config.loss_scale) - - # define model - if quant_set.quantization_aware: - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) - else: - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, - amp_level="O2") - - print("============== Starting Training ==============") - time_callback = TimeMonitor(data_size=step_size) - loss_callback = LossMonitor() - callbacks = [time_callback, loss_callback] - if rank_id == 0: - if config.save_checkpoint: - config_ckpt = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_callback = ModelCheckpoint(prefix="ResNet50", - directory=config.save_checkpoint_path, - config=config_ckpt) - callbacks += [ckpt_callback] - model.train(epoch_size, dataset, callbacks=callbacks) - print("============== End Training ==============") diff --git a/model_zoo/resnet_thor/README.md b/model_zoo/resnet_thor/README.md deleted file mode 100644 index 5fb17007ae..0000000000 --- a/model_zoo/resnet_thor/README.md +++ /dev/null @@ -1,128 +0,0 @@ -# ResNet-50-THOR Example - -## Description - -This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by second-order optimizer THOR. THOR is a novel approximate seond-order optimization method in MindSpore. With fewer iterations, THOR can finish ResNet-50 V1.5 training in 72 minutes to top-1 accuracy of 75.9% using 8 Ascend 910, which is much faster than SGD with Momentum. - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Download the dataset ImageNet2012 - -> Unzip the ImageNet2012 dataset to any path you want and the folder structure should include train and eval dataset as follows: -> ``` -> . -> ├── ilsvrc # train dataset -> └── ilsvrc_eval # infer dataset -> ``` - - -## Example structure - -```shell -. -├── resnet_thor - ├── README.md - ├── src - ├── crossentropy.py # CrossEntropy loss function - ├── config.py # parameter configuration - ├── resnet50.py # resnet50 backbone - ├── dataset_helper.py # dataset help for minddata dataset - ├── grad_reducer_thor.py # grad reducer for thor - ├── model_thor.py # model - ├── resnet_thor.py # resnet50_thor backone - ├── thor.py # thor - ├── thor_layer.py # thor layer - └── dataset_imagenet.py # data preprocessing - ├── scripts - ├── run_distribute_train.sh # launch distributed training(8 pcs) - └── run_eval.sh # launch infering - ├── eval.py # infer script - └── train.py # train script -``` - - -## Parameter configuration - -Parameters for both training and inference can be set in config.py. - -``` -"class_num": 1000, # dataset class number -"batch_size": 32, # batch size of input tensor -"loss_scale": 128, # loss scale -"momentum": 0.9, # momentum of THOR optimizer -"weight_decay": 5e-4, # weight decay -"epoch_size": 45, # only valid for taining, which is always 1 for inference -"buffer_size": 1000, # number of queue size in data preprocessing -"image_height": 224, # image height -"image_width": 224, # image width -"save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_steps": 5004, # the step interval between two checkpoints. By default, the checkpoint will be saved every epoch -"keep_checkpoint_max": 20, # only keep the last keep_checkpoint_max checkpoint -"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path -"label_smooth": True, # label smooth -"label_smooth_factor": 0.1, # label smooth factor -"frequency": 834, # the step interval to update second-order information matrix -``` - -## Running the example - -### Train - -#### Usage - -``` -# distributed training -Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [DEVICE_NUM] -``` - - -#### Launch - -```bash -# distributed training example(8 pcs) -sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc -``` - -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). - -#### Result - -Training result will be stored in the example path, whose folder name begins with "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. - -``` -# distribute training result(8 pcs) -epoch: 1 step: 5004, loss is 4.4182425 -epoch: 2 step: 5004, loss is 3.740064 -epoch: 3 step: 5004, loss is 4.0546017 -epoch: 4 step: 5004, loss is 3.7598825 -epoch: 5 step: 5004, loss is 3.3744206 -...... -``` - -### Infer - -#### Usage - -``` -# infer -Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] -``` - -#### Launch - -```bash -# infer with checkpoint -sh run_eval.sh dataset/ilsvrc_eval train_parallel0/resnet-42_5004.ckpt -``` - -> checkpoint can be produced in training process. - -#### Result - -Inference result will be stored in the example path, whose folder name is "infer". Under this, you can find result like the followings in log. - -``` -result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt -``` diff --git a/model_zoo/resnet_thor/scripts/run_distribute_train.sh b/model_zoo/resnet_thor/scripts/run_distribute_train.sh deleted file mode 100644 index 6fa7457227..0000000000 --- a/model_zoo/resnet_thor/scripts/run_distribute_train.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 3 ] -then - echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [DEVICE_NUM]" -exit 1 -fi - -if [ ! -f $1 ] -then - echo "error: DMINDSPORE_HCCL_CONFIG_PATH=$1 is not a file" -exit 1 -fi - -if [ ! -d $2 ] -then - echo "error: DATASET_PATH=$2 is not a directory" -exit 1 -fi - -BASE_PATH=$(cd "`dirname $0`" || exit; pwd) -cd $BASE_PATH/../ || exit - -ulimit -u unlimited -export DEVICE_NUM=$3 -export RANK_SIZE=$3 -export MINDSPORE_HCCL_CONFIG_PATH=$1 - -for((i=0; i<${DEVICE_NUM}; i++)) -do - export DEVICE_ID=$i - export RANK_ID=$i - rm -rf ./train_parallel$i - mkdir ./train_parallel$i - cp *.py ./train_parallel$i - cp -r ./src ./train_parallel$i - cd ./train_parallel$i || exit - echo "start training for rank $RANK_ID, device $DEVICE_ID" - - env > env.log - python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & - cd .. -done diff --git a/model_zoo/resnet_thor/train.py b/model_zoo/resnet_thor/train.py deleted file mode 100644 index 47f56a0676..0000000000 --- a/model_zoo/resnet_thor/train.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train_imagenet.""" -import argparse -import os -import random - -import numpy as np - -from mindspore import Tensor -from mindspore import context -from mindspore.communication.management import init -from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.train.model import ParallelMode -from src.model_thor import Model -from src.resnet_thor import resnet50 -from src.thor import THOR -from src.config import config -from src.crossentropy import CrossEntropy -from src.dataset_imagenet import create_dataset - -random.seed(1) -np.random.seed(1) - -parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') -parser.add_argument('--device_num', type=int, default=1, help='Device num.') -parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.') -parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') - -args_opt = parser.parse_args() -device_id = int(os.getenv('DEVICE_ID')) - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) - - -def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch): - """get_model_lr""" - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - for i in range(total_steps): - epoch = (i + 1) / steps_per_epoch - base = (1.0 - float(epoch) / total_epochs) ** decay - lr_local = lr_init * base - if epoch >= 39: - lr_local = lr_local * 0.5 - if epoch >= 40: - lr_local = lr_local * 0.5 - lr_each_step.append(lr_local) - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] - return learning_rate - - -def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch): - """get_model_damping""" - damping_each_step = [] - total_steps = steps_per_epoch * total_epochs - for step in range(total_steps): - epoch = (step + 1) / steps_per_epoch - damping_here = damping_init * (decay_rate ** (epoch / 10)) - damping_each_step.append(damping_here) - - current_step = global_step - damping_each_step = np.array(damping_each_step).astype(np.float32) - damping_now = damping_each_step[current_step:] - return damping_now - - -if __name__ == '__main__': - if not args_opt.do_eval and args_opt.run_distribute: - context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") - - init() - - epoch_size = config.epoch_size - damping = get_model_damping(0, 0.03, 0.87, 50, 5004) - net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale, - frequency=config.frequency) - - if not config.label_smooth: - config.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - if args_opt.do_train: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, - repeat_num=epoch_size, batch_size=config.batch_size) - step_size = dataset.get_dataset_size() - - loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004)) - opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, - filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), - filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), - filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()), - filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()), - config.weight_decay, config.loss_scale) - - model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale, - keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency) - - time_cb = TimeMonitor(data_size=step_size) - loss_cb = LossMonitor() - cb = [time_cb, loss_cb] - if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck) - cb += [ckpt_cb] - - model.train(epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/resnext50/README.md b/model_zoo/resnext50/README.md deleted file mode 100644 index c44844eecc..0000000000 --- a/model_zoo/resnext50/README.md +++ /dev/null @@ -1,128 +0,0 @@ -# ResNext50 Example - -## Description - -This is an example of training ResNext50 with ImageNet dataset in Mindspore. - -## Requirements - -- Install [Mindspore](http://www.mindspore.cn/install/en). -- Downlaod the dataset ImageNet2012. - -## Structure - -```shell -. -└─resnext50 - ├─README.md - ├─scripts - ├─run_standalone_train.sh # launch standalone training(1p) - ├─run_distribute_train.sh # launch distributed training(8p) - └─run_eval.sh # launch evaluating - ├─src - ├─backbone - ├─_init_.py # initalize - ├─resnet.py # resnext50 backbone - ├─utils - ├─_init_.py # initalize - ├─cunstom_op.py # network operation - ├─logging.py # print log - ├─optimizers_init_.py # get parameters - ├─sampler.py # distributed sampler - ├─var_init_.py # calculate gain value - ├─_init_.py # initalize - ├─config.py # parameter configuration - ├─crossentropy.py # CrossEntropy loss function - ├─dataset.py # data preprocessing - ├─head.py # commom head - ├─image_classification.py # get resnet - ├─linear_warmup.py # linear warmup learning rate - ├─warmup_cosine_annealing.py # learning rate each step - ├─warmup_step_lr.py # warmup step learning rate - ├─eval.py # eval net - └─train.py # train net - -``` - -## Parameter Configuration - -Parameters for both training and evaluating can be set in config.py - -``` -"image_height": '224,224' # image size -"num_classes": 1000, # dataset class number -"per_batch_size": 128, # batch size of input tensor -"lr": 0.05, # base learning rate -"lr_scheduler": 'cosine_annealing', # learning rate mode -"lr_epochs": '30,60,90,120', # epoch of lr changing -"lr_gamma": 0.1, # decrease lr by a factor of exponential lr_scheduler -"eta_min": 0, # eta_min in cosine_annealing scheduler -"T_max": 150, # T-max in cosine_annealing scheduler -"max_epoch": 150, # max epoch num to train the model -"backbone": 'resnext50', # backbone metwork -"warmup_epochs" : 1, # warmup epoch -"weight_decay": 0.0001, # weight decay -"momentum": 0.9, # momentum -"is_dynamic_loss_scale": 0, # dynamic loss scale -"loss_scale": 1024, # loss scale -"label_smooth": 1, # label_smooth -"label_smooth_factor": 0.1, # label_smooth_factor -"ckpt_interval": 2000, # ckpt_interval -"ckpt_path": 'outputs/', # checkpoint save location -"is_save_on_master": 1, -"rank": 0, # local rank of distributed -"group_size": 1 # world size of distributed -``` - -## Running the example - -### Train - -#### Usage - -``` -# distribute training example(8p) -sh run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH -# standalone training -sh run_standalone_train.sh DEVICE_ID DATA_PATH -``` - -#### Launch - -```bash -# distributed training example(8p) -sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /ImageNet/train -# standalone training example -sh scripts/run_standalone_train.sh 0 /ImageNet_Original/train -``` - -#### Result - -You can find checkpoint file together with result in log. - -### Evaluation - -#### Usage - -``` -# Evaluation -sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH -``` - -#### Launch - -```bash -# Evaluation with checkpoint -sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt -``` - -> checkpoint can be produced in training process. - -#### Result - -Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. - -``` -acc=78,16%(TOP1) -acc=93.88%(TOP5) -``` \ No newline at end of file diff --git a/model_zoo/resnext50/eval.py b/model_zoo/resnext50/eval.py deleted file mode 100644 index ff5c83843e..0000000000 --- a/model_zoo/resnext50/eval.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Eval""" -import os -import time -import argparse -import datetime -import glob -import numpy as np -import mindspore.nn as nn - -from mindspore import Tensor, context -from mindspore.communication.management import init, get_rank, get_group_size, release -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.common import dtype as mstype - -from src.utils.logging import get_logger -from src.image_classification import get_network -from src.dataset import classification_dataset -from src.config import config - -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target="Ascend", save_graphs=False, device_id=devid) - - - -class ParameterReduce(nn.Cell): - """ParameterReduce""" - def __init__(self): - super(ParameterReduce, self).__init__() - self.cast = P.Cast() - self.reduce = P.AllReduce() - - def construct(self, x): - one = self.cast(F.scalar_to_array(1.0), mstype.float32) - out = x * one - ret = self.reduce(out) - return ret - - -def parse_args(cloud_args=None): - """parse_args""" - parser = argparse.ArgumentParser('mindspore classification test') - - # dataset related - parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir') - parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') - # network related - parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') - parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. ' - 'If it is a direction, it will test all ckpt') - - # logging related - parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log') - parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') - - # roma obs - parser.add_argument('--train_url', type=str, default="", help='train url') - - args, _ = parser.parse_known_args() - args = merge_args(args, cloud_args) - args.image_size = config.image_size - args.num_classes = config.num_classes - args.backbone = config.backbone - args.rank = config.rank - args.group_size = config.group_size - - args.image_size = list(map(int, args.image_size.split(','))) - - return args - - -def get_top5_acc(top5_arg, gt_class): - sub_count = 0 - for top5, gt in zip(top5_arg, gt_class): - if gt in top5: - sub_count += 1 - return sub_count - -def merge_args(args, cloud_args): - """merge_args""" - args_dict = vars(args) - if isinstance(cloud_args, dict): - for key in cloud_args.keys(): - val = cloud_args[key] - if key in args_dict and val: - arg_type = type(args_dict[key]) - if arg_type is not type(None): - val = arg_type(val) - args_dict[key] = val - return args - -def test(cloud_args=None): - """test""" - args = parse_args(cloud_args) - - # init distributed - if args.is_distributed: - init() - args.rank = get_rank() - args.group_size = get_group_size() - - args.outputs_dir = os.path.join(args.log_path, - datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) - - args.logger = get_logger(args.outputs_dir, args.rank) - args.logger.save_args(args) - - # network - args.logger.important_info('start create network') - if os.path.isdir(args.pretrained): - models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt'))) - print(models) - if args.graph_ckpt: - f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]) - else: - f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1]) - args.models = sorted(models, key=f) - else: - args.models = [args.pretrained,] - - for model in args.models: - de_dataset = classification_dataset(args.data_dir, image_size=args.image_size, - per_batch_size=args.per_batch_size, - max_epoch=1, rank=args.rank, group_size=args.group_size, - mode='eval') - eval_dataloader = de_dataset.create_tuple_iterator() - network = get_network(args.backbone, args.num_classes) - if network is None: - raise NotImplementedError('not implement {}'.format(args.backbone)) - - param_dict = load_checkpoint(model) - param_dict_new = {} - for key, values in param_dict.items(): - if key.startswith('moments.'): - continue - elif key.startswith('network.'): - param_dict_new[key[8:]] = values - else: - param_dict_new[key] = values - - load_param_into_net(network, param_dict_new) - args.logger.info('load model {} success'.format(model)) - - # must add - network.add_flags_recursive(fp16=True) - - img_tot = 0 - top1_correct = 0 - top5_correct = 0 - network.set_train(False) - t_end = time.time() - it = 0 - for data, gt_classes in eval_dataloader: - output = network(Tensor(data, mstype.float32)) - output = output.asnumpy() - - top1_output = np.argmax(output, (-1)) - top5_output = np.argsort(output)[:, -5:] - - t1_correct = np.equal(top1_output, gt_classes).sum() - top1_correct += t1_correct - top5_correct += get_top5_acc(top5_output, gt_classes) - img_tot += args.per_batch_size - - if args.rank == 0 and it == 0: - t_end = time.time() - it = 1 - if args.rank == 0: - time_used = time.time() - t_end - fps = (img_tot - args.per_batch_size) * args.group_size / time_used - args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps)) - results = [[top1_correct], [top5_correct], [img_tot]] - args.logger.info('before results={}'.format(results)) - if args.is_distributed: - model_md5 = model.replace('/', '') - tmp_dir = '/cache' - if not os.path.exists(tmp_dir): - os.mkdir(tmp_dir) - top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5) - top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5) - img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5) - np.save(top1_correct_npy, top1_correct) - np.save(top5_correct_npy, top5_correct) - np.save(img_tot_npy, img_tot) - while True: - rank_ok = True - for other_rank in range(args.group_size): - top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5) - top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5) - img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5) - if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \ - not os.path.exists(img_tot_npy): - rank_ok = False - if rank_ok: - break - - top1_correct_all = 0 - top5_correct_all = 0 - img_tot_all = 0 - for other_rank in range(args.group_size): - top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5) - top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5) - img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5) - top1_correct_all += np.load(top1_correct_npy) - top5_correct_all += np.load(top5_correct_npy) - img_tot_all += np.load(img_tot_npy) - results = [[top1_correct_all], [top5_correct_all], [img_tot_all]] - results = np.array(results) - else: - results = np.array(results) - - args.logger.info('after results={}'.format(results)) - top1_correct = results[0, 0] - top5_correct = results[1, 0] - img_tot = results[2, 0] - acc1 = 100.0 * top1_correct / img_tot - acc5 = 100.0 * top5_correct / img_tot - args.logger.info('after allreduce eval: top1_correct={}, tot={},' - 'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1)) - args.logger.info('after allreduce eval: top5_correct={}, tot={},' - 'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5)) - if args.is_distributed: - release() - - -if __name__ == "__main__": - test() diff --git a/model_zoo/resnext50/scripts/run_eval.sh b/model_zoo/resnext50/scripts/run_eval.sh deleted file mode 100644 index 610faa874e..0000000000 --- a/model_zoo/resnext50/scripts/run_eval.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -DEVICE_ID=$1 -DATA_DIR=$2 -PATH_CHECKPOINT=$3 - -python eval.py \ - --device_id=$DEVICE_ID \ - --pretrained=$PATH_CHECKPOINT \ - --data_dir=$DATA_DIR > log.txt 2>&1 & diff --git a/model_zoo/resnext50/scripts/run_standalone_train.sh b/model_zoo/resnext50/scripts/run_standalone_train.sh deleted file mode 100644 index ca5d8206f3..0000000000 --- a/model_zoo/resnext50/scripts/run_standalone_train.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -DEVICE_ID=$1 -DATA_DIR=$2 -PATH_CHECKPOINT="" -if [ $# == 3 ] -then - PATH_CHECKPOINT=$3 -fi - -python train.py \ - --is_distribute=0 \ - --device_id=$DEVICE_ID \ - --pretrained=$PATH_CHECKPOINT \ - --data_dir=$DATA_DIR > log.txt 2>&1 & - diff --git a/model_zoo/resnext50/src/backbone/resnet.py b/model_zoo/resnext50/src/backbone/resnet.py deleted file mode 100644 index 5b69f9e1f5..0000000000 --- a/model_zoo/resnext50/src/backbone/resnet.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -ResNet based ResNext -""" -import mindspore.nn as nn -from mindspore.ops.operations import TensorAdd, Split, Concat -from mindspore.ops import operations as P -from mindspore.common.initializer import TruncatedNormal - -from src.utils.cunstom_op import SEBlock, GroupConv - - -__all__ = ['ResNet', 'resnext50'] - - -def weight_variable(shape, factor=0.1): - return TruncatedNormal(0.02) - - -def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False, groups=1): - return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias, - padding=padding, pad_mode="pad", group=groups) - - -def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False, groups=1): - return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias, - padding=padding, pad_mode="pad", group=groups) - - -def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False, groups=1): - return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias, - padding=padding, pad_mode="pad", group=groups) - - -class _DownSample(nn.Cell): - """ - Downsample for ResNext-ResNet. - - Args: - in_channels (int): Input channels. - out_channels (int): Output channels. - stride (int): Stride size for the 1*1 convolutional layer. - - Returns: - Tensor, output tensor. - - Examples: - >>>DownSample(32, 64, 2) - """ - def __init__(self, in_channels, out_channels, stride): - super(_DownSample, self).__init__() - self.conv = conv1x1(in_channels, out_channels, stride=stride, padding=0) - self.bn = nn.BatchNorm2d(out_channels) - - def construct(self, x): - out = self.conv(x) - out = self.bn(out) - return out - -class BasicBlock(nn.Cell): - """ - ResNet basic block definition. - - Args: - in_channels (int): Input channels. - out_channels (int): Output channels. - stride (int): Stride size for the first convolutional layer. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>>BasicBlock(32, 256, stride=2) - """ - expansion = 1 - - def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs): - super(BasicBlock, self).__init__() - self.conv1 = conv3x3(in_channels, out_channels, stride=stride) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = P.ReLU() - self.conv2 = conv3x3(out_channels, out_channels, stride=1) - self.bn2 = nn.BatchNorm2d(out_channels) - - self.use_se = use_se - if self.use_se: - self.se = SEBlock(out_channels) - - self.down_sample_flag = False - if down_sample is not None: - self.down_sample = down_sample - self.down_sample_flag = True - - self.add = TensorAdd() - - def construct(self, x): - identity = x - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - - if self.use_se: - out = self.se(out) - - if self.down_sample_flag: - identity = self.down_sample(x) - - out = self.add(out, identity) - out = self.relu(out) - return out - -class Bottleneck(nn.Cell): - """ - ResNet Bottleneck block definition. - - Args: - in_channels (int): Input channels. - out_channels (int): Output channels. - stride (int): Stride size for the initial convolutional layer. Default: 1. - - Returns: - Tensor, the ResNet unit's output. - - Examples: - >>>Bottleneck(3, 256, stride=2) - """ - expansion = 4 - - def __init__(self, in_channels, out_channels, stride=1, down_sample=None, - base_width=64, groups=1, use_se=False, **kwargs): - super(Bottleneck, self).__init__() - - width = int(out_channels * (base_width / 64.0)) * groups - self.groups = groups - self.conv1 = conv1x1(in_channels, width, stride=1) - self.bn1 = nn.BatchNorm2d(width) - self.relu = P.ReLU() - - self.conv3x3s = nn.CellList() - - self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups) - self.op_split = Split(axis=1, output_num=self.groups) - self.op_concat = Concat(axis=1) - - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = conv1x1(width, out_channels * self.expansion, stride=1) - self.bn3 = nn.BatchNorm2d(out_channels * self.expansion) - - self.use_se = use_se - if self.use_se: - self.se = SEBlock(out_channels * self.expansion) - - self.down_sample_flag = False - if down_sample is not None: - self.down_sample = down_sample - self.down_sample_flag = True - - self.cast = P.Cast() - self.add = TensorAdd() - - def construct(self, x): - identity = x - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - out = self.conv3(out) - out = self.bn3(out) - - if self.use_se: - out = self.se(out) - - if self.down_sample_flag: - identity = self.down_sample(x) - - out = self.add(out, identity) - out = self.relu(out) - return out - -class ResNet(nn.Cell): - """ - ResNet architecture. - - Args: - block (cell): Block for network. - layers (list): Numbers of block in different layers. - width_per_group (int): Width of every group. - groups (int): Groups number. - - Returns: - Tuple, output tensor tuple. - - Examples: - >>>ResNet() - """ - def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False): - super(ResNet, self).__init__() - self.in_channels = 64 - self.groups = groups - self.base_width = width_per_group - - self.conv = conv7x7(3, self.in_channels, stride=2, padding=3) - self.bn = nn.BatchNorm2d(self.in_channels) - self.relu = P.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') - - self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se) - - self.out_channels = 512 * block.expansion - self.cast = P.Cast() - - def construct(self, x): - x = self.conv(x) - x = self.bn(x) - x = self.relu(x) - x = self.maxpool(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - return x - - def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False): - """_make_layer""" - down_sample = None - if stride != 1 or self.in_channels != out_channels * block.expansion: - down_sample = _DownSample(self.in_channels, - out_channels * block.expansion, - stride=stride) - - layers = [] - layers.append(block(self.in_channels, - out_channels, - stride=stride, - down_sample=down_sample, - base_width=self.base_width, - groups=self.groups, - use_se=use_se)) - self.in_channels = out_channels * block.expansion - for _ in range(1, blocks_num): - layers.append(block(self.in_channels, out_channels, - base_width=self.base_width, groups=self.groups, use_se=use_se)) - - return nn.SequentialCell(layers) - - def get_out_channels(self): - return self.out_channels - - -def resnext50(): - return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32) diff --git a/model_zoo/resnext50/src/config.py b/model_zoo/resnext50/src/config.py deleted file mode 100644 index c1a12aa14e..0000000000 --- a/model_zoo/resnext50/src/config.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""config""" -from easydict import EasyDict as ed - -config = ed({ - "image_size": '224,224', - "num_classes": 1000, - - "lr": 0.4, - "lr_scheduler": 'cosine_annealing', - "lr_epochs": '30,60,90,120', - "lr_gamma": 0.1, - "eta_min": 0, - "T_max": 150, - "max_epoch": 150, - "backbone": 'resnext50', - "warmup_epochs": 1, - - "weight_decay": 0.0001, - "momentum": 0.9, - "is_dynamic_loss_scale": 0, - "loss_scale": 1024, - "label_smooth": 1, - "label_smooth_factor": 0.1, - - "ckpt_interval": 1250, - "ckpt_path": 'outputs/', - "is_save_on_master": 1, - - "rank": 0, - "group_size": 1 -}) diff --git a/model_zoo/resnext50/src/dataset.py b/model_zoo/resnext50/src/dataset.py deleted file mode 100644 index 9608e3c790..0000000000 --- a/model_zoo/resnext50/src/dataset.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -dataset processing. -""" -import os -from mindspore.common import dtype as mstype -import mindspore.dataset as de -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as V_C -from PIL import Image, ImageFile -from src.utils.sampler import DistributedSampler - -ImageFile.LOAD_TRUNCATED_IMAGES = True - -class TxtDataset(): - """ - create txt dataset. - - Args: - Returns: - de_dataset. - """ - def __init__(self, root, txt_name): - super(TxtDataset, self).__init__() - self.imgs = [] - self.labels = [] - fin = open(txt_name, "r") - for line in fin: - img_name, label = line.strip().split(' ') - self.imgs.append(os.path.join(root, img_name)) - self.labels.append(int(label)) - fin.close() - - def __getitem__(self, index): - img = Image.open(self.imgs[index]).convert('RGB') - return img, self.labels[index] - - def __len__(self): - return len(self.imgs) - - -def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size, - mode='train', - input_mode='folder', - root='', - num_parallel_workers=None, - shuffle=None, - sampler=None, - class_indexing=None, - drop_remainder=True, - transform=None, - target_transform=None): - """ - A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt". - If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images - are written into a textfile. - - Args: - data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"". - Or path of the textfile that contains every image's path of the dataset. - image_size (str): Size of the input images. - per_batch_size (int): the batch size of evey step during training. - max_epoch (int): the number of epochs. - rank (int): The shard ID within num_shards (default=None). - group_size (int): Number of shards that the dataset should be divided - into (default=None). - mode (str): "train" or others. Default: " train". - input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder". - root (str): the images path for "input_mode="txt"". Default: " ". - num_parallel_workers (int): Number of workers to read the data. Default: None. - shuffle (bool): Whether or not to perform shuffle on the dataset - (default=None, performs shuffle). - sampler (Sampler): Object used to choose samples from the dataset. Default: None. - class_indexing (dict): A str-to-int mapping from folder name to index - (default=None, the folder names will be sorted - alphabetically and each class will be given a - unique index starting from 0). - - Examples: - >>> from mindvision.common.datasets.classification import classification_dataset - >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images - >>> dataset_dir = "/path/to/imagefolder_directory" - >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], - >>> per_batch_size=64, max_epoch=100, - >>> rank=0, group_size=4) - >>> # Path of the textfile that contains every image's path of the dataset. - >>> dataset_dir = "/path/to/dataset/images/train.txt" - >>> images_dir = "/path/to/dataset/images" - >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], - >>> per_batch_size=64, max_epoch=100, - >>> rank=0, group_size=4, - >>> input_mode="txt", root=images_dir) - """ - - mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] - std = [0.229 * 255, 0.224 * 255, 0.225 * 255] - - if transform is None: - if mode == 'train': - transform_img = [ - V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), - V_C.RandomHorizontalFlip(prob=0.5), - V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4), - V_C.Normalize(mean=mean, std=std), - V_C.HWC2CHW() - ] - else: - transform_img = [ - V_C.Decode(), - V_C.Resize((256, 256)), - V_C.CenterCrop(image_size), - V_C.Normalize(mean=mean, std=std), - V_C.HWC2CHW() - ] - else: - transform_img = transform - - if target_transform is None: - transform_label = [C.TypeCast(mstype.int32)] - else: - transform_label = target_transform - - if input_mode == 'folder': - de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers, - shuffle=shuffle, sampler=sampler, class_indexing=class_indexing, - num_shards=group_size, shard_id=rank) - else: - dataset = TxtDataset(root, data_dir) - sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) - de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) - de_dataset.set_dataset_size(len(sampler)) - - de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) - de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) - - columns_to_project = ["image", "label"] - de_dataset = de_dataset.project(columns=columns_to_project) - - de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder) - de_dataset = de_dataset.repeat(max_epoch) - - return de_dataset diff --git a/model_zoo/resnext50/src/image_classification.py b/model_zoo/resnext50/src/image_classification.py deleted file mode 100644 index d8003ad200..0000000000 --- a/model_zoo/resnext50/src/image_classification.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Image classifiation. -""" -import math -import mindspore.nn as nn -from mindspore.common import initializer as init -import src.backbone as backbones -import src.head as heads -from src.utils.var_init import default_recurisive_init, KaimingNormal - - -class ImageClassificationNetwork(nn.Cell): - """ - architecture of image classification network. - - Args: - Returns: - Tensor, output tensor. - """ - def __init__(self, backbone, head): - super(ImageClassificationNetwork, self).__init__() - self.backbone = backbone - self.head = head - - def construct(self, x): - x = self.backbone(x) - x = self.head(x) - return x - -class Resnet(ImageClassificationNetwork): - """ - Resnet architecture. - Args: - backbone_name (string): backbone. - num_classes (int): number of classes. - Returns: - Resnet. - """ - def __init__(self, backbone_name, num_classes): - self.backbone_name = backbone_name - backbone = backbones.__dict__[self.backbone_name]() - out_channels = backbone.get_out_channels() - head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) - super(Resnet, self).__init__(backbone, head) - - default_recurisive_init(self) - - for cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): - cell.weight.default_input = init.initializer( - KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), - cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor() - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor() - cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor() - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - for cell in self.cells_and_names(): - if isinstance(cell, backbones.resnet.Bottleneck): - cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.default_input.shape).to_tensor() - elif isinstance(cell, backbones.resnet.BasicBlock): - cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.default_input.shape).to_tensor() - - - -def get_network(backbone_name, num_classes): - if backbone_name in ['resnext50']: - return Resnet(backbone_name, num_classes) - return None diff --git a/model_zoo/resnext50/src/utils/cunstom_op.py b/model_zoo/resnext50/src/utils/cunstom_op.py deleted file mode 100644 index cbe89a1610..0000000000 --- a/model_zoo/resnext50/src/utils/cunstom_op.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network operations -""" -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common import dtype as mstype - - -class GlobalAvgPooling(nn.Cell): - """ - global average pooling feature map. - - Args: - mean (tuple): means for each channel. - """ - def __init__(self): - super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(True) - self.shape = P.Shape() - self.reshape = P.Reshape() - - def construct(self, x): - x = self.mean(x, (2, 3)) - b, c, _, _ = self.shape(x) - x = self.reshape(x, (b, c)) - return x - - -class SEBlock(nn.Cell): - """ - squeeze and excitation block. - - Args: - channel (int): number of feature maps. - reduction (int): weight. - """ - def __init__(self, channel, reduction=16): - super(SEBlock, self).__init__() - - self.avg_pool = GlobalAvgPooling() - self.fc1 = nn.Dense(channel, channel // reduction) - self.relu = P.ReLU() - self.fc2 = nn.Dense(channel // reduction, channel) - self.sigmoid = P.Sigmoid() - self.reshape = P.Reshape() - self.shape = P.Shape() - self.sum = P.Sum() - self.cast = P.Cast() - - def construct(self, x): - b, c = self.shape(x) - y = self.avg_pool(x) - - y = self.reshape(y, (b, c)) - y = self.fc1(y) - y = self.relu(y) - y = self.fc2(y) - y = self.sigmoid(y) - y = self.reshape(y, (b, c, 1, 1)) - return x * y - -class GroupConv(nn.Cell): - """ - group convolution operation. - - Args: - in_channels (int): Input channels of feature map. - out_channels (int): Output channels of feature map. - kernel_size (int): Size of convolution kernel. - stride (int): Stride size for the group convolution layer. - - Returns: - tensor, output tensor. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False): - super(GroupConv, self).__init__() - assert in_channels % groups == 0 and out_channels % groups == 0 - self.groups = groups - self.convs = nn.CellList() - self.op_split = P.Split(axis=1, output_num=self.groups) - self.op_concat = P.Concat(axis=1) - self.cast = P.Cast() - for _ in range(groups): - self.convs.append(nn.Conv2d(in_channels//groups, out_channels//groups, - kernel_size=kernel_size, stride=stride, has_bias=has_bias, - padding=pad, pad_mode=pad_mode, group=1)) - - def construct(self, x): - features = self.op_split(x) - outputs = () - for i in range(self.groups): - outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),) - out = self.op_concat(outputs) - return out diff --git a/model_zoo/resnext50/src/utils/var_init.py b/model_zoo/resnext50/src/utils/var_init.py deleted file mode 100644 index 51fc109990..0000000000 --- a/model_zoo/resnext50/src/utils/var_init.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Initialize. -""" -import math -from functools import reduce -import numpy as np -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common import initializer as init - -def _calculate_gain(nonlinearity, param=None): - r""" - Return the recommended gain value for the given nonlinearity function. - - The values are as follows: - ================= ==================================================== - nonlinearity gain - ================= ==================================================== - Linear / Identity :math:`1` - Conv{1,2,3}D :math:`1` - Sigmoid :math:`1` - Tanh :math:`\frac{5}{3}` - ReLU :math:`\sqrt{2}` - Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` - ================= ==================================================== - - Args: - nonlinearity: the non-linear function - param: optional parameter for the non-linear function - - Examples: - >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 - """ - linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] - if nonlinearity in linear_fns or nonlinearity == 'sigmoid': - return 1 - if nonlinearity == 'tanh': - return 5.0 / 3 - if nonlinearity == 'relu': - return math.sqrt(2.0) - if nonlinearity == 'leaky_relu': - if param is None: - negative_slope = 0.01 - elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): - negative_slope = param - else: - raise ValueError("negative_slope {} not a valid number".format(param)) - return math.sqrt(2.0 / (1 + negative_slope ** 2)) - - raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) - -def _assignment(arr, num): - """Assign the value of `num` to `arr`.""" - if arr.shape == (): - arr = arr.reshape((1)) - arr[:] = num - arr = arr.reshape(()) - else: - if isinstance(num, np.ndarray): - arr[:] = num[:] - else: - arr[:] = num - return arr - -def _calculate_in_and_out(arr): - """ - Calculate n_in and n_out. - - Args: - arr (Array): Input array. - - Returns: - Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. - """ - dim = len(arr.shape) - if dim < 2: - raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") - - n_in = arr.shape[1] - n_out = arr.shape[0] - - if dim > 2: - counter = reduce(lambda x, y: x * y, arr.shape[2:]) - n_in *= counter - n_out *= counter - return n_in, n_out - -def _select_fan(array, mode): - mode = mode.lower() - valid_modes = ['fan_in', 'fan_out'] - if mode not in valid_modes: - raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) - - fan_in, fan_out = _calculate_in_and_out(array) - return fan_in if mode == 'fan_in' else fan_out - -class KaimingInit(init.Initializer): - r""" - Base Class. Initialize the array with He kaiming algorithm. - - Args: - a: the negative slope of the rectifier used after this layer (only - used with ``'leaky_relu'``) - mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` - preserves the magnitude of the variance of the weights in the - forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the - backwards pass. - nonlinearity: the non-linear function, recommended to use only with - ``'relu'`` or ``'leaky_relu'`` (default). - """ - def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): - super(KaimingInit, self).__init__() - self.mode = mode - self.gain = _calculate_gain(nonlinearity, a) - def _initialize(self, arr): - pass - - -class KaimingUniform(KaimingInit): - r""" - Initialize the array with He kaiming uniform algorithm. The resulting tensor will - have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where - - .. math:: - \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} - - Input: - arr (Array): The array to be assigned. - - Returns: - Array, assigned array. - - Examples: - >>> w = np.empty(3, 5) - >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') - """ - - def _initialize(self, arr): - fan = _select_fan(arr, self.mode) - bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) - np.random.seed(0) - data = np.random.uniform(-bound, bound, arr.shape) - - _assignment(arr, data) - - -class KaimingNormal(KaimingInit): - r""" - Initialize the array with He kaiming normal algorithm. The resulting tensor will - have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where - - .. math:: - \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} - - Input: - arr (Array): The array to be assigned. - - Returns: - Array, assigned array. - - Examples: - >>> w = np.empty(3, 5) - >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') - """ - - def _initialize(self, arr): - fan = _select_fan(arr, self.mode) - std = self.gain / math.sqrt(fan) - np.random.seed(0) - data = np.random.normal(0, std, arr.shape) - - _assignment(arr, data) - - -def default_recurisive_init(custom_cell): - """default_recurisive_init""" - for _, cell in custom_cell.cells_and_names(): - if isinstance(cell, nn.Conv2d): - cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if cell.bias is not None: - fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy()) - bound = 1 / math.sqrt(fan_in) - np.random.seed(0) - cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), - cell.bias.default_input.dtype) - elif isinstance(cell, nn.Dense): - cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if cell.bias is not None: - fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy()) - bound = 1 / math.sqrt(fan_in) - np.random.seed(0) - cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), - cell.bias.default_input.dtype) - elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): - pass diff --git a/model_zoo/resnext50/train.py b/model_zoo/resnext50/train.py deleted file mode 100644 index 29ccd9b00c..0000000000 --- a/model_zoo/resnext50/train.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train ImageNet.""" -import os -import time -import argparse -import datetime - -import mindspore.nn as nn -from mindspore import Tensor, context -from mindspore import ParallelMode -from mindspore.nn.optim import Momentum -from mindspore.communication.management import init, get_rank, get_group_size -from mindspore.train.callback import ModelCheckpoint -from mindspore.train.callback import CheckpointConfig, Callback -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.model import Model -from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager - -from src.dataset import classification_dataset -from src.crossentropy import CrossEntropy -from src.warmup_step_lr import warmup_step_lr -from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr -from src.utils.logging import get_logger -from src.utils.optimizers__init__ import get_param_groups -from src.image_classification import get_network -from src.config import config - -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target="Ascend", save_graphs=False, device_id=devid) - -class BuildTrainNetwork(nn.Cell): - """build training network""" - def __init__(self, network, criterion): - super(BuildTrainNetwork, self).__init__() - self.network = network - self.criterion = criterion - - def construct(self, input_data, label): - output = self.network(input_data) - loss = self.criterion(output, label) - return loss - -class ProgressMonitor(Callback): - """monitor loss and time""" - def __init__(self, args): - super(ProgressMonitor, self).__init__() - self.me_epoch_start_time = 0 - self.me_epoch_start_step_num = 0 - self.args = args - self.ckpt_history = [] - - def begin(self, run_context): - self.args.logger.info('start network train...') - - def epoch_begin(self, run_context): - pass - - def epoch_end(self, run_context, *me_args): - cb_params = run_context.original_args() - me_step = cb_params.cur_step_num - 1 - - real_epoch = me_step // self.args.steps_per_epoch - time_used = time.time() - self.me_epoch_start_time - fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used - self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}' - 'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean)) - - if self.args.rank_save_ckpt_flag: - import glob - ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt')) - for ckpt in ckpts: - ckpt_fn = os.path.basename(ckpt) - if not ckpt_fn.startswith('{}-'.format(self.args.rank)): - continue - if ckpt in self.ckpt_history: - continue - self.ckpt_history.append(ckpt) - self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},' - 'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn)) - - - self.me_epoch_start_step_num = me_step - self.me_epoch_start_time = time.time() - - def step_begin(self, run_context): - pass - - def step_end(self, run_context, *me_args): - pass - - def end(self, run_context): - self.args.logger.info('end network train...') - - -def parse_args(cloud_args=None): - """parameters""" - parser = argparse.ArgumentParser('mindspore classification training') - - # dataset related - parser.add_argument('--data_dir', type=str, default='', help='train data dir') - parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu') - # network related - parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load') - - # distributed related - parser.add_argument('--is_distributed', type=int, default=1, help='if multi device') - # roma obs - parser.add_argument('--train_url', type=str, default="", help='train url') - - args, _ = parser.parse_known_args() - args = merge_args(args, cloud_args) - args.image_size = config.image_size - args.num_classes = config.num_classes - args.lr = config.lr - args.lr_scheduler = config.lr_scheduler - args.lr_epochs = config.lr_epochs - args.lr_gamma = config.lr_gamma - args.eta_min = config.eta_min - args.T_max = config.T_max - args.max_epoch = config.max_epoch - args.backbone = config.backbone - args.warmup_epochs = config.warmup_epochs - args.weight_decay = config.weight_decay - args.momentum = config.momentum - args.is_dynamic_loss_scale = config.is_dynamic_loss_scale - args.loss_scale = config.loss_scale - args.label_smooth = config.label_smooth - args.label_smooth_factor = config.label_smooth_factor - args.ckpt_interval = config.ckpt_interval - args.ckpt_path = config.ckpt_path - args.is_save_on_master = config.is_save_on_master - args.rank = config.rank - args.group_size = config.group_size - args.lr_epochs = list(map(int, args.lr_epochs.split(','))) - args.image_size = list(map(int, args.image_size.split(','))) - - return args - -def merge_args(args, cloud_args): - """dictionary""" - args_dict = vars(args) - if isinstance(cloud_args, dict): - for key in cloud_args.keys(): - val = cloud_args[key] - if key in args_dict and val: - arg_type = type(args_dict[key]) - if arg_type is not type(None): - val = arg_type(val) - args_dict[key] = val - return args - -def train(cloud_args=None): - """training process""" - args = parse_args(cloud_args) - - # init distributed - if args.is_distributed: - init() - args.rank = get_rank() - args.group_size = get_group_size() - - if args.is_dynamic_loss_scale == 1: - args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt - - # select for master rank save ckpt or all rank save, compatiable for model parallel - args.rank_save_ckpt_flag = 0 - if args.is_save_on_master: - if args.rank == 0: - args.rank_save_ckpt_flag = 1 - else: - args.rank_save_ckpt_flag = 1 - - # logger - args.outputs_dir = os.path.join(args.ckpt_path, - datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) - args.logger = get_logger(args.outputs_dir, args.rank) - - # dataloader - de_dataset = classification_dataset(args.data_dir, args.image_size, - args.per_batch_size, args.max_epoch, - args.rank, args.group_size) - de_dataset.map_model = 4 # !!!important - args.steps_per_epoch = de_dataset.get_dataset_size() - - args.logger.save_args(args) - - # network - args.logger.important_info('start create network') - # get network and init - network = get_network(args.backbone, args.num_classes) - if network is None: - raise NotImplementedError('not implement {}'.format(args.backbone)) - network.add_flags_recursive(fp16=True) - # loss - if not args.label_smooth: - args.label_smooth_factor = 0.0 - criterion = CrossEntropy(smooth_factor=args.label_smooth_factor, - num_classes=args.num_classes) - - # load pretrain model - if os.path.isfile(args.pretrained): - param_dict = load_checkpoint(args.pretrained) - param_dict_new = {} - for key, values in param_dict.items(): - if key.startswith('moments.'): - continue - elif key.startswith('network.'): - param_dict_new[key[8:]] = values - else: - param_dict_new[key] = values - load_param_into_net(network, param_dict_new) - args.logger.info('load model {} success'.format(args.pretrained)) - - # lr scheduler - if args.lr_scheduler == 'exponential': - lr = warmup_step_lr(args.lr, - args.lr_epochs, - args.steps_per_epoch, - args.warmup_epochs, - args.max_epoch, - gamma=args.lr_gamma, - ) - elif args.lr_scheduler == 'cosine_annealing': - lr = warmup_cosine_annealing_lr(args.lr, - args.steps_per_epoch, - args.warmup_epochs, - args.max_epoch, - args.T_max, - args.eta_min) - else: - raise NotImplementedError(args.lr_scheduler) - - # optimizer - opt = Momentum(params=get_param_groups(network), - learning_rate=Tensor(lr), - momentum=args.momentum, - weight_decay=args.weight_decay, - loss_scale=args.loss_scale) - - - criterion.add_flags_recursive(fp32=True) - - # package training process, adjust lr + forward + backward + optimizer - train_net = BuildTrainNetwork(network, criterion) - if args.is_distributed: - parallel_mode = ParallelMode.DATA_PARALLEL - else: - parallel_mode = ParallelMode.STAND_ALONE - if args.is_dynamic_loss_scale == 1: - loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) - else: - loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) - - # Model api changed since TR5_branch 2020/03/09 - context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, - parameter_broadcast=True, mirror_mean=True) - model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager) - - # checkpoint save - progress_cb = ProgressMonitor(args) - callbacks = [progress_cb,] - if args.rank_save_ckpt_flag: - ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval - ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, - keep_checkpoint_max=ckpt_max_num) - ckpt_cb = ModelCheckpoint(config=ckpt_config, - directory=args.outputs_dir, - prefix='{}'.format(args.rank)) - callbacks.append(ckpt_cb) - - model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True) - - -if __name__ == "__main__": - train() diff --git a/model_zoo/ssd/eval.py b/model_zoo/ssd/eval.py deleted file mode 100644 index 9054bf6f24..0000000000 --- a/model_zoo/ssd/eval.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# less required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Evaluation for SSD""" - -import os -import argparse -import time -import numpy as np -from mindspore import context, Tensor -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.ssd import SSD300, ssd_mobilenet_v2 -from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image -from src.config import config -from src.coco_eval import metrics - -def ssd_eval(dataset_path, ckpt_path): - """SSD evaluation.""" - batch_size = 1 - ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) - net = SSD300(ssd_mobilenet_v2(), config, is_training=False) - print("Load Checkpoint!") - param_dict = load_checkpoint(ckpt_path) - net.init_parameters_data() - load_param_into_net(net, param_dict) - - net.set_train(False) - i = batch_size - total = ds.get_dataset_size() * batch_size - start = time.time() - pred_data = [] - print("\n========================================\n") - print("total images num: ", total) - print("Processing, please wait a moment.") - for data in ds.create_dict_iterator(): - img_id = data['img_id'] - img_np = data['image'] - image_shape = data['image_shape'] - - output = net(Tensor(img_np)) - for batch_idx in range(img_np.shape[0]): - pred_data.append({"boxes": output[0].asnumpy()[batch_idx], - "box_scores": output[1].asnumpy()[batch_idx], - "img_id": int(np.squeeze(img_id[batch_idx])), - "image_shape": image_shape[batch_idx]}) - percent = round(i / total * 100., 2) - - print(f' {str(percent)} [{i}/{total}]', end='\r') - i += batch_size - cost_time = int((time.time() - start) * 1000) - print(f' 100% [{total}/{total}] cost {cost_time} ms') - mAP = metrics(pred_data) - print("\n========================================\n") - print(f"mAP: {mAP}") - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='SSD evaluation') - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") - parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") - args_opt = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - - prefix = "ssd_eval.mindrecord" - mindrecord_dir = config.mindrecord_dir - mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") - if not os.path.exists(mindrecord_file): - if not os.path.isdir(mindrecord_dir): - os.makedirs(mindrecord_dir) - if args_opt.dataset == "coco": - if os.path.isdir(config.coco_root): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("coco", False, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("coco_root not exits.") - else: - if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("other", False, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("IMAGE_DIR or ANNO_PATH not exits.") - - print("Start Eval!") - ssd_eval(mindrecord_file, args_opt.checkpoint_path) diff --git a/model_zoo/ssd/scripts/run_distribute_train.sh b/model_zoo/ssd/scripts/run_distribute_train.sh deleted file mode 100644 index 60eccf2c40..0000000000 --- a/model_zoo/ssd/scripts/run_distribute_train.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "==============================================================================================================" -echo "Please run the scipt as: " -echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" -echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)" -echo "It is better to use absolute path." -echo "=================================================================================================================" - -if [ $# != 5 ] && [ $# != 7 ] -then - echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \ -[MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" - exit 1 -fi - -# Before start distribute train, first create mindrecord files. -BASE_PATH=$(cd "`dirname $0`" || exit; pwd) -cd $BASE_PATH/../ || exit -python train.py --only_create_dataset=1 - -echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" - -export RANK_SIZE=$1 -EPOCH_SIZE=$2 -LR=$3 -DATASET=$4 -PRE_TRAINED=$6 -PRE_TRAINED_EPOCH_SIZE=$7 -export MINDSPORE_HCCL_CONFIG_PATH=$5 - -for((i=0;i env.log - if [ $# == 5 ] - then - python train.py \ - --distribute=1 \ - --lr=$LR \ - --dataset=$DATASET \ - --device_num=$RANK_SIZE \ - --device_id=$DEVICE_ID \ - --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & - fi - - if [ $# == 7 ] - then - python train.py \ - --distribute=1 \ - --lr=$LR \ - --dataset=$DATASET \ - --device_num=$RANK_SIZE \ - --device_id=$DEVICE_ID \ - --pre_trained=$PRE_TRAINED \ - --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \ - --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & - fi - - cd ../ -done diff --git a/model_zoo/ssd/src/config.py b/model_zoo/ssd/src/config.py deleted file mode 100644 index 683b8de31f..0000000000 --- a/model_zoo/ssd/src/config.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#" ============================================================================ - -"""Config parameters for SSD models.""" - -from easydict import EasyDict as ed - -config = ed({ - "img_shape": [300, 300], - "num_ssd_boxes": 1917, - "neg_pre_positive": 3, - "match_thershold": 0.5, - "nms_thershold": 0.6, - "min_score": 0.1, - "max_boxes": 100, - - # learing rate settings - "global_step": 0, - "lr_init": 0.001, - "lr_end_rate": 0.001, - "warmup_epochs": 2, - "momentum": 0.9, - "weight_decay": 1.5e-4, - - # network - "num_default": [3, 6, 6, 6, 6, 6], - "extras_in_channels": [256, 576, 1280, 512, 256, 256], - "extras_out_channels": [576, 1280, 512, 256, 256, 128], - "extras_srides": [1, 1, 2, 2, 2, 2], - "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], - "feature_size": [19, 10, 5, 3, 2, 1], - "min_scale": 0.2, - "max_scale": 0.95, - "aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], - "steps": (16, 32, 64, 100, 150, 300), - "prior_scaling": (0.1, 0.2), - "gamma": 2.0, - "alpha": 0.75, - - # `mindrecord_dir` and `coco_root` are better to use absolute path. - "mindrecord_dir": "/data/MindRecord_COCO", - "coco_root": "/data/coco2017", - "train_data_type": "train2017", - "val_data_type": "val2017", - "instances_set": "annotations/instances_{}.json", - "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', - 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', - 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', - 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', - 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', - 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', - 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', - 'teddy bear', 'hair drier', 'toothbrush'), - "num_classes": 81, - - # if coco used, `image_dir` and `anno_path` are useless. - "image_dir": "", - "anno_path": "", -}) diff --git a/model_zoo/ssd/src/dataset.py b/model_zoo/ssd/src/dataset.py deleted file mode 100644 index 19c66fc598..0000000000 --- a/model_zoo/ssd/src/dataset.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SSD dataset""" - -from __future__ import division - -import os -import cv2 -import numpy as np - -import mindspore.dataset as de -import mindspore.dataset.transforms.vision.c_transforms as C -from mindspore.mindrecord import FileWriter -from .config import config -from .box_utils import jaccard_numpy, ssd_bboxes_encode - - -def _rand(a=0., b=1.): - """Generate random.""" - return np.random.rand() * (b - a) + a - - -def random_sample_crop(image, boxes): - """Random Crop the image and boxes""" - height, width, _ = image.shape - min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9]) - - if min_iou is None: - return image, boxes - - # max trails (50) - for _ in range(50): - image_t = image - - w = _rand(0.3, 1.0) * width - h = _rand(0.3, 1.0) * height - - # aspect ratio constraint b/t .5 & 2 - if h / w < 0.5 or h / w > 2: - continue - - left = _rand() * (width - w) - top = _rand() * (height - h) - - rect = np.array([int(top), int(left), int(top+h), int(left+w)]) - overlap = jaccard_numpy(boxes, rect) - - # dropout some boxes - drop_mask = overlap > 0 - if not drop_mask.any(): - continue - - if overlap[drop_mask].min() < min_iou: - continue - - image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :] - - centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0 - - m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) - m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) - - # mask in that both m1 and m2 are true - mask = m1 * m2 * drop_mask - - # have any valid boxes? try again if not - if not mask.any(): - continue - - # take only matching gt boxes - boxes_t = boxes[mask, :].copy() - - boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2]) - boxes_t[:, :2] -= rect[:2] - boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4]) - boxes_t[:, 2:4] -= rect[:2] - - return image_t, boxes_t - return image, boxes - - -def preprocess_fn(img_id, image, box, is_training): - """Preprocess function for dataset.""" - def _infer_data(image, input_shape): - img_h, img_w, _ = image.shape - input_h, input_w = input_shape - - image = cv2.resize(image, (input_w, input_h)) - - #When the channels of image is 1 - if len(image.shape) == 2: - image = np.expand_dims(image, axis=-1) - image = np.concatenate([image, image, image], axis=-1) - - return img_id, image, np.array((img_h, img_w), np.float32) - - def _data_aug(image, box, is_training, image_size=(300, 300)): - """Data augmentation function.""" - ih, iw, _ = image.shape - w, h = image_size - - if not is_training: - return _infer_data(image, image_size) - - # Random crop - box = box.astype(np.float32) - image, box = random_sample_crop(image, box) - ih, iw, _ = image.shape - - # Resize image - image = cv2.resize(image, (w, h)) - - # Flip image or not - flip = _rand() < .5 - if flip: - image = cv2.flip(image, 1, dst=None) - - # When the channels of image is 1 - if len(image.shape) == 2: - image = np.expand_dims(image, axis=-1) - image = np.concatenate([image, image, image], axis=-1) - - box[:, [0, 2]] = box[:, [0, 2]] / ih - box[:, [1, 3]] = box[:, [1, 3]] / iw - - if flip: - box[:, [1, 3]] = 1 - box[:, [3, 1]] - - box, label, num_match = ssd_bboxes_encode(box) - return image, box, label, num_match - return _data_aug(image, box, is_training, image_size=config.img_shape) - - -def create_coco_label(is_training): - """Get image path and annotation from COCO.""" - from pycocotools.coco import COCO - - coco_root = config.coco_root - data_type = config.val_data_type - if is_training: - data_type = config.train_data_type - - #Classes need to train or test. - train_cls = config.coco_classes - train_cls_dict = {} - for i, cls in enumerate(train_cls): - train_cls_dict[cls] = i - - anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) - - coco = COCO(anno_json) - classs_dict = {} - cat_ids = coco.loadCats(coco.getCatIds()) - for cat in cat_ids: - classs_dict[cat["id"]] = cat["name"] - - image_ids = coco.getImgIds() - images = [] - image_path_dict = {} - image_anno_dict = {} - - for img_id in image_ids: - image_info = coco.loadImgs(img_id) - file_name = image_info[0]["file_name"] - anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) - anno = coco.loadAnns(anno_ids) - image_path = os.path.join(coco_root, data_type, file_name) - annos = [] - iscrowd = False - for label in anno: - bbox = label["bbox"] - class_name = classs_dict[label["category_id"]] - iscrowd = iscrowd or label["iscrowd"] - if class_name in train_cls: - x_min, x_max = bbox[0], bbox[0] + bbox[2] - y_min, y_max = bbox[1], bbox[1] + bbox[3] - annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) - - if not is_training and iscrowd: - continue - if len(annos) >= 1: - images.append(img_id) - image_path_dict[img_id] = image_path - image_anno_dict[img_id] = np.array(annos) - - return images, image_path_dict, image_anno_dict - - -def anno_parser(annos_str): - """Parse annotation from string to list.""" - annos = [] - for anno_str in annos_str: - anno = list(map(int, anno_str.strip().split(','))) - annos.append(anno) - return annos - - -def filter_valid_data(image_dir, anno_path): - """Filter valid image file, which both in image_dir and anno_path.""" - images = [] - image_path_dict = {} - image_anno_dict = {} - if not os.path.isdir(image_dir): - raise RuntimeError("Path given is not valid.") - if not os.path.isfile(anno_path): - raise RuntimeError("Annotation file is not valid.") - - with open(anno_path, "rb") as f: - lines = f.readlines() - for img_id, line in enumerate(lines): - line_str = line.decode("utf-8").strip() - line_split = str(line_str).split(' ') - file_name = line_split[0] - image_path = os.path.join(image_dir, file_name) - if os.path.isfile(image_path): - images.append(img_id) - image_path_dict[img_id] = image_path - image_anno_dict[img_id] = anno_parser(line_split[1:]) - - return images, image_path_dict, image_anno_dict - - -def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): - """Create MindRecord file.""" - mindrecord_dir = config.mindrecord_dir - mindrecord_path = os.path.join(mindrecord_dir, prefix) - writer = FileWriter(mindrecord_path, file_num) - if dataset == "coco": - images, image_path_dict, image_anno_dict = create_coco_label(is_training) - else: - images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path) - - ssd_json = { - "img_id": {"type": "int32", "shape": [1]}, - "image": {"type": "bytes"}, - "annotation": {"type": "int32", "shape": [-1, 5]}, - } - writer.add_schema(ssd_json, "ssd_json") - - for img_id in images: - image_path = image_path_dict[img_id] - with open(image_path, 'rb') as f: - img = f.read() - annos = np.array(image_anno_dict[img_id], dtype=np.int32) - img_id = np.array([img_id], dtype=np.int32) - row = {"img_id": img_id, "image": img, "annotation": annos} - writer.write_raw_data([row]) - writer.commit() - - -def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, - is_training=True, num_parallel_workers=4): - """Creatr SSD dataset with MindDataset.""" - ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, - shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) - decode = C.Decode() - ds = ds.map(input_columns=["image"], operations=decode) - change_swap_op = C.HWC2CHW() - normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) - color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) - compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training)) - if is_training: - output_columns = ["image", "box", "label", "num_match"] - trans = [color_adjust_op, normalize_op, change_swap_op] - else: - output_columns = ["img_id", "image", "image_shape"] - trans = [normalize_op, change_swap_op] - ds = ds.map(input_columns=["img_id", "image", "annotation"], - output_columns=output_columns, columns_order=output_columns, - operations=compose_map_func, python_multiprocessing=is_training, - num_parallel_workers=num_parallel_workers) - ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training, - num_parallel_workers=num_parallel_workers) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) - return ds diff --git a/model_zoo/ssd/src/init_params.py b/model_zoo/ssd/src/init_params.py deleted file mode 100644 index 6e1f8869b3..0000000000 --- a/model_zoo/ssd/src/init_params.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Parameters utils""" - -from mindspore import Tensor -from mindspore.common.initializer import initializer, TruncatedNormal - -def init_net_param(network, initialize_mode='TruncatedNormal'): - """Init the parameters in net.""" - params = network.trainable_params() - for p in params: - if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: - if initialize_mode == 'TruncatedNormal': - p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape, p.data.dtype)) - else: - p.set_parameter_data(initialize_mode, p.data.shape, p.data.dtype) - - -def load_backbone_params(network, param_dict): - """Init the parameters from pre-train model, default is mobilenetv2.""" - for _, param in net.parameters_and_names(): - param_name = param.name.replace('network.backbone.', '') - name_split = param_name.split('.') - if 'features_1' in param_name: - param_name = param_name.replace('features_1', 'features') - if 'features_2' in param_name: - param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:]) - if param_name in param_dict: - param.set_parameter_data(param_dict[param_name].data) diff --git a/model_zoo/ssd/train.py b/model_zoo/ssd/train.py deleted file mode 100644 index 27f0e7ad0f..0000000000 --- a/model_zoo/ssd/train.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# less required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Train SSD and get checkpoint files.""" - -import os -import argparse -import mindspore.nn as nn -from mindspore import context, Tensor -from mindspore.communication.management import init -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor -from mindspore.train import Model, ParallelMode -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 -from src.config import config -from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image -from src.lr_schedule import get_lr -from src.init_params import init_net_param - - -def main(): - parser = argparse.ArgumentParser(description="SSD training") - parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " - "Mindrecord, default is False.") - parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.") - parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") - parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") - parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.") - parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") - parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") - parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") - parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.") - parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") - args_opt = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - - if args_opt.distribute: - device_num = args_opt.device_num - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, - device_num=device_num) - init() - rank = args_opt.device_id % device_num - else: - rank = 0 - device_num = 1 - - print("Start create dataset!") - - # It will generate mindrecord file in args_opt.mindrecord_dir, - # and the file name is ssd.mindrecord0, 1, ... file_num. - - prefix = "ssd.mindrecord" - mindrecord_dir = config.mindrecord_dir - mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") - if not os.path.exists(mindrecord_file): - if not os.path.isdir(mindrecord_dir): - os.makedirs(mindrecord_dir) - if args_opt.dataset == "coco": - if os.path.isdir(config.coco_root): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("coco", True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("coco_root not exits.") - else: - if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("other", True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("image_dir or anno_path not exits.") - - if not args_opt.only_create_dataset: - loss_scale = float(args_opt.loss_scale) - - # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. - dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, - batch_size=args_opt.batch_size, device_num=device_num, rank=rank) - - dataset_size = dataset.get_dataset_size() - print("Create dataset done!") - - backbone = ssd_mobilenet_v2() - ssd = SSD300(backbone=backbone, config=config) - net = SSDWithLossCell(ssd, config) - init_net_param(net) - - # checkpoint - ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) - ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config) - - if args_opt.pre_trained: - if args_opt.pre_trained_epoch_size <= 0: - raise KeyError("pre_trained_epoch_size must be greater than 0.") - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - - lr = Tensor(get_lr(global_step=config.global_step, - lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, - warmup_epochs=config.warmup_epochs, - total_epochs=args_opt.epoch_size, - steps_per_epoch=dataset_size)) - opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, - config.momentum, config.weight_decay, loss_scale) - net = TrainingWrapper(net, opt, loss_scale) - - callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] - - model = Model(net) - dataset_sink_mode = False - if args_opt.mode == "sink": - print("In sink mode, one epoch return a loss.") - dataset_sink_mode = True - print("Start train SSD, the first epoch will be slower because of the graph compilation.") - model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) - -if __name__ == '__main__': - main() diff --git a/model_zoo/utils/ascend_distributed_launcher/README.md b/model_zoo/utils/ascend_distributed_launcher/README.md new file mode 100644 index 0000000000..cefdaee3e8 --- /dev/null +++ b/model_zoo/utils/ascend_distributed_launcher/README.md @@ -0,0 +1,48 @@ +# Run distribute pretrain + +## description +The number of D chips can be automatically allocated based on the device_num set in hccl config file, You don not need to specify that. + + +## how to use +For example, if we want to run the distributed training of Bert model on D chip, we can in `/bert/` dir: +``` +python model_zoo/utils/ascend_distributed_launcher/run_distribute_pretrain.py --run_script_dir ./run_pretrain.py --hyper_parameter_config_dir model_zoo/utils/ascend_distributed_launcher/hyper_parameter_config.ini --data_dir /path/dataset/ --hccl_config_dir model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json +``` + +output: + +``` +hccl_config_dir: model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json +the number of logical core: 192 +avg_core_per_rank: 96 +rank_size: 2 + +start training for rank 0, device 5: +rank_id: 0 +device_id: 5 +core nums: 0-95 +epoch_size: 8 +data_dir: /data/small_512/ +schema_dir: +log file dir: ./LOG5/log.txt + +start training for rank 1, device 6: +rank_id: 1 +device_id: 6 +core nums: 96-191 +epoch_size: 8 +data_dir: /data/small_512/ +schema_dir: +log file dir: ./LOG6/log.txt +``` + +## Note + +1. Note that `hccl_2p_56_x.x.x.x.json` can use [hccl_tools.py](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate. + +2. For hyper parameter, please note that you should customize the scripts `hyper_parameter_config.ini`. Please note that these two hyper parameters are not allowed to be configured here: + device_id + device_num + +3. For Other Model, please note that you should customize the option `run_script` and Corresponding `hyper_parameter_config.ini`. diff --git a/model_zoo/utils/ascend_distributed_launcher/__init__.py b/model_zoo/utils/ascend_distributed_launcher/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/utils/ascend_distributed_launcher/hyper_parameter_config.ini b/model_zoo/utils/ascend_distributed_launcher/hyper_parameter_config.ini new file mode 100644 index 0000000000..2298f83509 --- /dev/null +++ b/model_zoo/utils/ascend_distributed_launcher/hyper_parameter_config.ini @@ -0,0 +1,11 @@ +[config] +distribute=true +epoch_size=40 +enable_save_ckpt=true +enable_lossscale=true +do_shuffle=true +enable_data_sink=true +data_sink_steps=100 +save_checkpoint_path=./checkpoint/ +save_checkpoint_steps=10000 +save_checkpoint_num=1 \ No newline at end of file diff --git a/model_zoo/utils/ascend_distributed_launcher/run_distribute_pretrain.py b/model_zoo/utils/ascend_distributed_launcher/run_distribute_pretrain.py new file mode 100644 index 0000000000..efc97e0fbe --- /dev/null +++ b/model_zoo/utils/ascend_distributed_launcher/run_distribute_pretrain.py @@ -0,0 +1,141 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""distribute pretrain script""" +import os +import json +import configparser +import multiprocessing +from argparse import ArgumentParser + + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training") + + parser.add_argument("--run_script_dir", type=str, default="", + help="Run script path, it is better to use absolute path") + parser.add_argument("--hyper_parameter_config_dir", type=str, default="", + help="Hyper Parameter config path, it is better to use absolute path") + parser.add_argument("--data_dir", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--hccl_config_dir", type=str, default="", + help="Hccl config path, it is better to use absolute path") + + args = parser.parse_args() + return args + + +def distribute_pretrain(): + """ + distribute pretrain scripts. The number of D chips can be automatically allocated + based on the device_num set in hccl config file, You don not need to specify that. + """ + print("start", __file__) + args = parse_args() + + run_script = args.run_script_dir + data_dir = args.data_dir + cf = configparser.ConfigParser() + cf.read(args.hyper_parameter_config_dir) + cfg = dict(cf.items("config")) + + print("hccl_config_dir:", args.hccl_config_dir) + os.environ['RANK_TABLE_FILE'] = args.hccl_config_dir + + cores = multiprocessing.cpu_count() + print("the number of logical core:", cores) + + # get device_ips + device_ips = {} + with open('/etc/hccn.conf', 'r') as fin: + for hccn_item in fin.readlines(): + if hccn_item.strip().startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip.strip() + + with open(args.hccl_config_dir, "r", encoding="utf-8") as fin: + hccl_config = json.loads(fin.read()) + rank_size = 0 + for server in hccl_config["server_list"]: + rank_size += len(server["device"]) + if server["device"][0]["device_ip"] in device_ips.values(): + this_server = server + + os.environ['RANK_SIZE'] = str(rank_size) + print("total rank size:", rank_size) + print("this server rank size:", len(this_server["device"])) + avg_core_per_rank = int(int(cores) / len(this_server["device"])) + core_gap = avg_core_per_rank - 1 + print("avg_core_per_rank:", avg_core_per_rank) + + count = 0 + for instance in this_server["device"]: + device_id = instance["device_id"] + rank_id = instance["rank_id"] + print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":") + print("rank_id:", rank_id) + print("device_id:", device_id) + + start = count * int(avg_core_per_rank) + count += 1 + end = start + core_gap + cmdopt = str(start) + "-" + str(end) + + os.environ["DEVICE_ID"] = device_id + os.environ["RANK_ID"] = rank_id + os.environ["DEPLOY_MODE"] = "0" + os.environ["GE_USE_STATIC_MEMORY"] = "1" + + os.system("rm -rf LOG" + str(device_id)) + os.system("mkdir ./LOG" + str(device_id)) + os.system("cp *.py ./LOG" + str(device_id)) + os.system("mkdir -p ./LOG" + str(device_id) + "/ms_log") + os.system("env > ./LOG" + str(device_id) + "/env.log") + + cur_dir = os.getcwd() + os.environ["GLOG_log_dir"] = cur_dir + "/LOG" + str(device_id) + "/ms_log" + os.environ["GLOG_logtostderr"] = "0" + + print("core_nums:", cmdopt) + print("epoch_size:", str(cfg['epoch_size'])) + print("data_dir:", data_dir) + print("log_file_dir: ./LOG" + str(device_id) + "/log.txt") + + cmd = 'taskset -c ' + cmdopt + ' python ' + run_script + " " + opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()]) + if ('device_id' in opt) or ('device_num' in opt) or ('data_dir' in opt): + raise ValueError("hyper_parameter_config.ini can not setting 'device_id'," + " 'device_num' or 'data_dir'! ") + cmd += opt + cmd += " --data_dir=" + data_dir + cmd += ' --device_id=' + str(device_id) + ' --device_num=' \ + + str(rank_size) + ' >./LOG' + str(device_id) + '/log.txt 2>&1 &' + + os.system(cmd) + + +if __name__ == "__main__": + distribute_pretrain() diff --git a/model_zoo/utils/cv_to_mindrecord/ImageNet_Similar_Perf/README.md b/model_zoo/utils/cv_to_mindrecord/ImageNet_Similar_Perf/README.md index 8bdcb9e25d..d3bd5fdc18 100644 --- a/model_zoo/utils/cv_to_mindrecord/ImageNet_Similar_Perf/README.md +++ b/model_zoo/utils/cv_to_mindrecord/ImageNet_Similar_Perf/README.md @@ -28,12 +28,12 @@ This example provides an efficient way to generate MindRecord. Users only need t Store the downloaded ImageNet dataset in a folder. The folder contains all images and a mapping file that records labels of the images. - In the mapping file, there are three columns, which are separated by spaces. They indicate image classes, label IDs, and label names. The following is an example of the mapping file: + In the mapping file, there are three columns, which are separated by spaces. They indicate image classes and label IDs. The following is an example of the mapping file: ``` - n02119789 1 pen - n02100735 2 notbook - n02110185 3 mouse - n02096294 4 orange + n02119789 0 + n02100735 1 + n02110185 2 + n02096294 3 ``` 2. Edit run_imagenet.sh and modify the parameters diff --git a/model_zoo/utils/hccl_tools/hccl_tools.py b/model_zoo/utils/hccl_tools/hccl_tools.py index ac4114c0a8..5afcb23159 100644 --- a/model_zoo/utils/hccl_tools/hccl_tools.py +++ b/model_zoo/utils/hccl_tools/hccl_tools.py @@ -17,7 +17,6 @@ import os import sys import json import socket -import platform from argparse import ArgumentParser from typing import Dict, Any @@ -114,40 +113,25 @@ def main(): device_id = device_id.split('_')[1] device_ips[device_id] = device_ip.strip() - arch = platform.processor() - hccn_table = {'board_id': {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch], - 'chip_info': '910', - 'deploy_mode': 'lab', - 'group_count': '1', - 'group_list': []} - instance_list = [] + hccn_table = {'version': '1.0', + 'server_count': '1', + 'server_list': []} + device_list = [] rank_id = 0 for instance_id in device_num_list: - instance = {'devices': []} device_id = visible_devices[instance_id] device_ip = device_ips[device_id] - instance['devices'].append({ - 'device_id': device_id, - 'device_ip': device_ip, - }) + device = {'device_id': device_id, + 'device_ip': device_ip, + 'rank_id': str(rank_id)} print('rank_id:{}, device_id:{}, device_ip:{}'.format(rank_id, device_id, device_ip)) - instance['rank_id'] = str(rank_id) rank_id += 1 - instance['server_id'] = server_id - instance_list.append(instance) - hccn_table['group_list'].append({ - 'device_num': str(len(device_num_list)), - 'server_num': '1', - 'group_name': '', - 'instance_count': str(len(device_num_list)), - 'instance_list': instance_list, + device_list.append(device) + hccn_table['server_list'].append({ + 'server_id': server_id, + 'device': device_list, + 'host_nic_ip': 'reserve' }) - hccn_table['para_plane_nic_location'] = 'device' - hccn_table['para_plane_nic_name'] = [] - for instance_id in device_num_list: - eth_id = visible_devices[instance_id] - hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) - hccn_table['para_plane_nic_num'] = str(len(device_num_list)) hccn_table['status'] = 'completed' # save hccn_table to file diff --git a/model_zoo/vgg16/README.md b/model_zoo/vgg16/README.md deleted file mode 100644 index 53eb05f66d..0000000000 --- a/model_zoo/vgg16/README.md +++ /dev/null @@ -1,107 +0,0 @@ -# VGG16 Example - -## Description - -This example is for VGG16 model training and evaluation. - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Download the CIFAR-10 binary version dataset. - -> Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows: -> ``` -> . -> ├── cifar-10-batches-bin # train dataset -> └── cifar-10-verify-bin # infer dataset -> ``` - -## Running the Example - -### Training - -``` -python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 & -``` -The python command above will run in the background, you can view the results through the file `out.train.log`. - -After training, you'll get some checkpoint files under the script folder by default. - -You will get the loss value as following: -``` -# grep "loss is " out.train.log -epoch: 1 step: 781, loss is 2.093086 -epcoh: 2 step: 781, loss is 1.827582 -... -``` - -### Evaluation - -``` -python eval.py --data_path=your_data_path --device_id=6 --checkpoint_path=./train_vgg_cifar10-70-781.ckpt > out.eval.log 2>&1 & -``` -The above python command will run in the background, you can view the results through the file `out.eval.log`. - -You will get the accuracy as following: -``` -# grep "result: " out.eval.log -result: {'acc': 0.92} -``` - -### Distribute Training -``` -sh run_distribute_train.sh rank_table.json your_data_path -``` -The above shell script will run distribute training in the background, you can view the results through the file `train_parallel[X]/log`. - -You will get the loss value as following: -``` -# grep "result: " train_parallel*/log -train_parallel0/log:epoch: 1 step: 97, loss is 1.9060308 -train_parallel0/log:epcoh: 2 step: 97, loss is 1.6003821 -... -train_parallel1/log:epoch: 1 step: 97, loss is 1.7095519 -train_parallel1/log:epcoh: 2 step: 97, loss is 1.7133579 -... -... -``` -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). - -## Usage: - -### Training -``` -usage: train.py [--device_target TARGET][--data_path DATA_PATH] - [--device_id DEVICE_ID][--pre_trained PRE_TRAINED] - -parameters/options: - --device_target the training backend type, default is Ascend. - --data_path the storage path of dataset - --device_id the device which used to train model. - --pre_trained the pretrained checkpoint file path. - -``` - -### Evaluation - -``` -usage: eval.py [--device_target TARGET][--data_path DATA_PATH] - [--device_id DEVICE_ID][--checkpoint_path CKPT_PATH] - -parameters/options: - --device_target the evaluation backend type, default is Ascend. - --data_path the storage path of datasetd - --device_id the device which used to evaluate model. - --checkpoint_path the checkpoint file path used to evaluate model. -``` - -### Distribute Training - -``` -Usage: sh script/run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH] - -parameters/options: - MINDSPORE_HCCL_CONFIG_PATH HCCL configuration file path. - DATA_PATH the storage path of dataset. -``` diff --git a/model_zoo/vgg16/eval.py b/model_zoo/vgg16/eval.py deleted file mode 100644 index 8cdcc86031..0000000000 --- a/model_zoo/vgg16/eval.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -##############test vgg16 example on cifar10################# -python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID -""" -import argparse - -import mindspore.nn as nn -from mindspore import context -from mindspore.nn.optim.momentum import Momentum -from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.config import cifar_cfg as cfg -from src.dataset import vgg_create_dataset -from src.vgg import vgg16 - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Cifar10 classification') - parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='device where the code will be implemented. (Default: Ascend)') - parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved') - parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.') - parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') - args_opt = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) - context.set_context(device_id=args_opt.device_id) - - net = vgg16(num_classes=cfg.num_classes) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - - param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(net, param_dict) - net.set_train(False) - dataset = vgg_create_dataset(args_opt.data_path, 1, False) - res = model.eval(dataset) - print("result: ", res) diff --git a/model_zoo/vgg16/scripts/run_distribute_train.sh b/model_zoo/vgg16/scripts/run_distribute_train.sh deleted file mode 100755 index ca4c993ded..0000000000 --- a/model_zoo/vgg16/scripts/run_distribute_train.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 2 ] -then - echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]" -exit 1 -fi - -if [ ! -f $1 ] -then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file" -exit 1 -fi - -if [ ! -d $2 ] -then - echo "error: DATA_PATH=$2 is not a directory" -exit 1 -fi - -export DEVICE_NUM=8 -export RANK_SIZE=8 -export MINDSPORE_HCCL_CONFIG_PATH=$1 - -for((i=0;i env.log - python train.py --data_path=$2 --device_id=$i &> log & - cd .. -done \ No newline at end of file diff --git a/model_zoo/vgg16/src/__init__.py b/model_zoo/vgg16/src/__init__.py deleted file mode 100644 index 301ef9dcb7..0000000000 --- a/model_zoo/vgg16/src/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the License); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# httpwww.apache.orglicensesLICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an AS IS BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ diff --git a/model_zoo/vgg16/src/config.py b/model_zoo/vgg16/src/config.py deleted file mode 100644 index a34cf7a1d3..0000000000 --- a/model_zoo/vgg16/src/config.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in main.py -""" -from easydict import EasyDict as edict - -cifar_cfg = edict({ - 'num_classes': 10, - 'lr_init': 0.01, - 'lr_max': 0.1, - 'warmup_epochs': 5, - 'batch_size': 64, - 'epoch_size': 70, - 'momentum': 0.9, - 'weight_decay': 5e-4, - 'buffer_size': 10, - 'image_height': 224, - 'image_width': 224, - 'keep_checkpoint_max': 10 -}) diff --git a/model_zoo/vgg16/src/dataset.py b/model_zoo/vgg16/src/dataset.py deleted file mode 100644 index b08659fb5e..0000000000 --- a/model_zoo/vgg16/src/dataset.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Data operations, will be used in train.py and eval.py -""" -import os - -import mindspore.common.dtype as mstype -import mindspore.dataset as ds -import mindspore.dataset.transforms.c_transforms as C -import mindspore.dataset.transforms.vision.c_transforms as vision -from .config import cifar_cfg as cfg - - -def vgg_create_dataset(data_home, repeat_num=1, training=True): - """Data operations.""" - ds.config.set_seed(1) - data_dir = os.path.join(data_home, "cifar-10-batches-bin") - if not training: - data_dir = os.path.join(data_home, "cifar-10-verify-bin") - - rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None - rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None - data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) - - resize_height = cfg.image_height - resize_width = cfg.image_width - rescale = 1.0 / 255.0 - shift = 0.0 - - # define map operations - random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT - random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR - rescale_op = vision.Rescale(rescale, shift) - normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) - changeswap_op = vision.HWC2CHW() - type_cast_op = C.TypeCast(mstype.int32) - - c_trans = [] - if training: - c_trans = [random_crop_op, random_horizontal_op] - c_trans += [resize_op, rescale_op, normalize_op, - changeswap_op] - - # apply map operations on images - data_set = data_set.map(input_columns="label", operations=type_cast_op) - data_set = data_set.map(input_columns="image", operations=c_trans) - - # apply repeat operations - data_set = data_set.repeat(repeat_num) - - # apply shuffle operations - data_set = data_set.shuffle(buffer_size=10) - - # apply batch operations - data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) - - return data_set diff --git a/model_zoo/vgg16/src/vgg.py b/model_zoo/vgg16/src/vgg.py deleted file mode 100644 index 55130871cc..0000000000 --- a/model_zoo/vgg16/src/vgg.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""VGG.""" -import mindspore.nn as nn -from mindspore.common.initializer import initializer -import mindspore.common.dtype as mstype - -def _make_layer(base, batch_norm): - """Make stage network of VGG.""" - layers = [] - in_channels = 3 - for v in base: - if v == 'M': - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] - else: - weight_shape = (v, in_channels, 3, 3) - weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() - conv2d = nn.Conv2d(in_channels=in_channels, - out_channels=v, - kernel_size=3, - padding=0, - pad_mode='same', - weight_init=weight) - if batch_norm: - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] - else: - layers += [conv2d, nn.ReLU()] - in_channels = v - return nn.SequentialCell(layers) - - -class Vgg(nn.Cell): - """ - VGG network definition. - - Args: - base (list): Configuration for different layers, mainly the channel number of Conv layer. - num_classes (int): Class numbers. Default: 1000. - batch_norm (bool): Whether to do the batchnorm. Default: False. - batch_size (int): Batch size. Default: 1. - - Returns: - Tensor, infer output tensor. - - Examples: - >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - >>> num_classes=1000, batch_norm=False, batch_size=1) - """ - - def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): - super(Vgg, self).__init__() - _ = batch_size - self.layers = _make_layer(base, batch_norm=batch_norm) - self.flatten = nn.Flatten() - self.classifier = nn.SequentialCell([ - nn.Dense(512 * 7 * 7, 4096), - nn.ReLU(), - nn.Dense(4096, 4096), - nn.ReLU(), - nn.Dense(4096, num_classes)]) - - def construct(self, x): - x = self.layers(x) - x = self.flatten(x) - x = self.classifier(x) - return x - - -cfg = { - '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], -} - - -def vgg16(num_classes=1000): - """ - Get Vgg16 neural network with batch normalization. - - Args: - num_classes (int): Class numbers. Default: 1000. - - Returns: - Cell, cell instance of Vgg16 neural network with batch normalization. - - Examples: - >>> vgg16(num_classes=1000) - """ - - net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) - return net diff --git a/model_zoo/vgg16/train.py b/model_zoo/vgg16/train.py deleted file mode 100644 index 33a4f0310c..0000000000 --- a/model_zoo/vgg16/train.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -#################train vgg16 example on cifar10######################## -python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID -""" -import argparse -import os -import random - -import numpy as np - -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import context -from mindspore.communication.management import init -from mindspore.nn.optim.momentum import Momentum -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.model import Model, ParallelMode -from mindspore.train.serialization import load_param_into_net, load_checkpoint -from src.config import cifar_cfg as cfg -from src.dataset import vgg_create_dataset -from src.vgg import vgg16 - -random.seed(1) -np.random.seed(1) - - -def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): - """Set learning rate.""" - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - warmup_steps = steps_per_epoch * warmup_epochs - if warmup_steps != 0: - inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) - else: - inc_each_step = 0 - for i in range(total_steps): - if i < warmup_steps: - lr_value = float(lr_init) + inc_each_step * float(i) - else: - base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) - lr_value = float(lr_max) * base * base - if lr_value < 0.0: - lr_value = 0.0 - lr_each_step.append(lr_value) - - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] - - return learning_rate - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Cifar10 classification') - parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='device where the code will be implemented. (Default: Ascend)') - parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved') - parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') - parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.') - args_opt = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) - context.set_context(device_id=args_opt.device_id) - - device_num = int(os.environ.get("DEVICE_NUM", 1)) - if device_num > 1: - context.reset_auto_parallel_context() - context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - init() - - dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size) - batch_num = dataset.get_dataset_size() - - net = vgg16(num_classes=cfg.num_classes) - # pre_trained - if args_opt.pre_trained: - load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) - - lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, - total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) - - config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) - time_cb = TimeMonitor(data_size=batch_num) - ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck) - loss_cb = LossMonitor() - model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) - print("train success") diff --git a/model_zoo/warpctc/README.md b/model_zoo/warpctc/README.md deleted file mode 100644 index cb941255bf..0000000000 --- a/model_zoo/warpctc/README.md +++ /dev/null @@ -1,137 +0,0 @@ -# Warpctc Example - -## Description - -These is an example of training Warpctc with self-generated captcha image dataset in MindSpore. - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Generate captcha images. - -> The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test dataset by yourself or just run the script `scripts/run_process_data.sh`. By default, the shell script will generate 10000 test images and 50000 train images separately. -> ``` -> $ cd scripts -> $ sh run_process_data.sh -> -> # after execution, you will find the dataset like the follows: -> . -> └─warpctc -> └─data -> ├─ train # train dataset -> └─ test # evaluate dataset -> ... - - -## Structure - -```shell -. -└──warpct - ├── README.md - ├── script - ├── run_distribute_train.sh # launch distributed training(8 pcs) - ├── run_eval.sh # launch evaluation - ├── run_process_data.sh # launch dataset generation - └── run_standalone_train.sh # launch standalone training(1 pcs) - ├── src - ├── config.py # parameter configuration - ├── dataset.py # data preprocessing - ├── loss.py # ctcloss definition - ├── lr_generator.py # generate learning rate for each step - ├── metric.py # accuracy metric for warpctc network - ├── warpctc.py # warpctc network definition - └── warpctc_for_train.py # warp network with grad, loss and gradient clip - ├── eval.py # eval net - ├── process_data.py # dataset generation script - └── train.py # train net -``` - - -## Parameter configuration - -Parameters for both training and evaluation can be set in config.py. - -``` -"max_captcha_digits": 4, # max number of digits in each -"captcha_width": 160, # width of captcha images -"captcha_height": 64, # height of capthca images -"batch_size": 64, # batch size of input tensor -"epoch_size": 30, # only valid for taining, which is always 1 for inference -"hidden_size": 512, # hidden size in LSTM layers -"learning_rate": 0.01, # initial learning rate -"momentum": 0.9 # momentum of SGD optimizer -"save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_steps": 98, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step -"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint -"save_checkpoint_path": "./", # path to save checkpoint -``` - -## Running the example - -### Train - -#### Usage - -``` -# distributed training -Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] - -# standalone training -Usage: sh run_standalone_train.sh [DATASET_PATH] -``` - - -#### Launch - -``` -# distribute training example -sh run_distribute_train.sh rank_table.json ../data/train - -# standalone training example -sh run_standalone_train.sh ../data/train -``` - -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). - -#### Result - -Training result will be stored in folder `scripts`, whose name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. - -``` -# distribute training result(8 pcs) -Epoch: [ 1/ 30], step: [ 98/ 98], loss: [0.5853/0.5853], time: [376813.7944] -Epoch: [ 2/ 30], step: [ 98/ 98], loss: [0.4007/0.4007], time: [75882.0951] -Epoch: [ 3/ 30], step: [ 98/ 98], loss: [0.0921/0.0921], time: [75150.9385] -Epoch: [ 4/ 30], step: [ 98/ 98], loss: [0.1472/0.1472], time: [75135.0193] -Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809] -... -``` - - -### Evaluation - -#### Usage - -``` -# evaluation -Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] -``` - -#### Launch - -``` -# evaluation example -sh run_eval.sh ../data/test warpctc-30-98.ckpt -``` - -> checkpoint can be produced in training process. - -#### Result - -Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. - -``` -result: {'WarpCTCAccuracy': 0.9901472929936306} -``` diff --git a/model_zoo/warpctc/eval.py b/model_zoo/warpctc/eval.py deleted file mode 100755 index df62c7c755..0000000000 --- a/model_zoo/warpctc/eval.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Warpctc evaluation""" -import os -import math as m -import random -import argparse -import numpy as np -from mindspore import context -from mindspore import dataset as de -from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from src.loss import CTCLoss -from src.config import config as cf -from src.dataset import create_dataset -from src.warpctc import StackedRNN -from src.metric import WarpCTCAccuracy - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description="Warpctc training") -parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") -parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") -args_opt = parser.parse_args() - -device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - save_graphs=False, - device_id=device_id) - -if __name__ == '__main__': - max_captcha_digits = cf.max_captcha_digits - input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 - # create dataset - dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) - step_size = dataset.get_dataset_size() - # define loss - loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) - # define net - net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) - # load checkpoint - param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(net, param_dict) - net.set_train(False) - # define model - model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()}) - # start evaluation - res = model.eval(dataset) - print("result:", res, flush=True) diff --git a/model_zoo/warpctc/scripts/run_distribute_train.sh b/model_zoo/warpctc/scripts/run_distribute_train.sh deleted file mode 100755 index 3cebf6d195..0000000000 --- a/model_zoo/warpctc/scripts/run_distribute_train.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 2 ]; then - echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]" - exit 1 -fi - -get_real_path() { - if [ "${1:0:1}" == "/" ]; then - echo "$1" - else - echo "$(realpath -m $PWD/$1)" - fi -} - -PATH1=$(get_real_path $1) -PATH2=$(get_real_path $2) - -if [ ! -f $PATH1 ]; then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file" - exit 1 -fi - -if [ ! -d $PATH2 ]; then - echo "error: DATASET_PATH=$PATH2 is not a directory" - exit 1 -fi - -ulimit -u unlimited -export DEVICE_NUM=8 -export RANK_SIZE=8 -export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 -export RANK_TABLE_FILE=$PATH1 - -for ((i = 0; i < ${DEVICE_NUM}; i++)); do - export DEVICE_ID=$i - export RANK_ID=$i - rm -rf ./train_parallel$i - mkdir ./train_parallel$i - cp ../*.py ./train_parallel$i - cp *.sh ./train_parallel$i - cp -r ../src ./train_parallel$i - cd ./train_parallel$i || exit - echo "start training for rank $RANK_ID, device $DEVICE_ID" - env >env.log - python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log & - cd .. -done diff --git a/model_zoo/warpctc/scripts/run_eval.sh b/model_zoo/warpctc/scripts/run_eval.sh deleted file mode 100755 index 659de6d72a..0000000000 --- a/model_zoo/warpctc/scripts/run_eval.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 2 ]; then - echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" - exit 1 -fi - -get_real_path() { - if [ "${1:0:1}" == "/" ]; then - echo "$1" - else - echo "$(realpath -m $PWD/$1)" - fi -} - -PATH1=$(get_real_path $1) -PATH2=$(get_real_path $2) - -if [ ! -d $PATH1 ]; then - echo "error: DATASET_PATH=$PATH1 is not a directory" - exit 1 -fi - -if [ ! -f $PATH2 ]; then - echo "error: CHECKPOINT_PATH=$PATH2 is not a file" - exit 1 -fi - -ulimit -u unlimited -export DEVICE_NUM=1 -export DEVICE_ID=0 -export RANK_SIZE=$DEVICE_NUM -export RANK_ID=0 - -if [ -d "eval" ]; then - rm -rf ./eval -fi -mkdir ./eval -cp ../*.py ./eval -cp *.sh ./eval -cp -r ../src ./eval -cd ./eval || exit -env >env.log -echo "start evaluation for device $DEVICE_ID" -python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log & -cd .. diff --git a/model_zoo/warpctc/scripts/run_process_data.sh b/model_zoo/warpctc/scripts/run_process_data.sh deleted file mode 100755 index 56b89f1a72..0000000000 --- a/model_zoo/warpctc/scripts/run_process_data.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -CUR_PATH=$(dirname $PWD/$0) -cd $CUR_PATH/../ && - python process_data.py && - cd - || exit \ No newline at end of file diff --git a/model_zoo/warpctc/scripts/run_standalone_train.sh b/model_zoo/warpctc/scripts/run_standalone_train.sh deleted file mode 100755 index 22a16ef4c8..0000000000 --- a/model_zoo/warpctc/scripts/run_standalone_train.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 1 ]; then - echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" - exit 1 -fi - -get_real_path() { - if [ "${1:0:1}" == "/" ]; then - echo "$1" - else - echo "$(realpath -m $PWD/$1)" - fi -} - -PATH1=$(get_real_path $1) - -if [ ! -d $PATH1 ]; then - echo "error: DATASET_PATH=$PATH1 is not a directory" - exit 1 -fi - -ulimit -u unlimited -export DEVICE_NUM=1 -export DEVICE_ID=0 -export RANK_ID=0 -export RANK_SIZE=1 - -if [ -d "train" ]; then - rm -rf ./train -fi -mkdir ./train -cp ../*.py ./train -cp *.sh ./train -cp -r ../src ./train -cd ./train || exit -echo "start training for device $DEVICE_ID" -env >env.log -python train.py --dataset=$PATH1 &>log & -cd .. diff --git a/model_zoo/warpctc/src/config.py b/model_zoo/warpctc/src/config.py deleted file mode 100755 index ed9c2968de..0000000000 --- a/model_zoo/warpctc/src/config.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Network parameters.""" -from easydict import EasyDict - -config = EasyDict({ - "max_captcha_digits": 4, - "captcha_width": 160, - "captcha_height": 64, - "batch_size": 64, - "epoch_size": 30, - "hidden_size": 512, - "learning_rate": 0.01, - "momentum": 0.9, - "save_checkpoint": True, - "save_checkpoint_steps": 98, - "keep_checkpoint_max": 30, - "save_checkpoint_path": "./", -}) diff --git a/model_zoo/warpctc/src/dataset.py b/model_zoo/warpctc/src/dataset.py deleted file mode 100755 index 76e592b906..0000000000 --- a/model_zoo/warpctc/src/dataset.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Dataset preprocessing.""" -import os -import math as m -import numpy as np -import mindspore.common.dtype as mstype -import mindspore.dataset.engine as de -import mindspore.dataset.transforms.c_transforms as c -import mindspore.dataset.transforms.vision.c_transforms as vc -from PIL import Image -from src.config import config as cf - - -class _CaptchaDataset(): - """ - create train or evaluation dataset for warpctc - - Args: - img_root_dir(str): root path of images - max_captcha_digits(int): max number of digits in images. - blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label - length is less than max_captcha_digits, the remaining labels are padding with blank. - """ - - def __init__(self, img_root_dir, max_captcha_digits, blank=10): - if not os.path.exists(img_root_dir): - raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) - self.img_root_dir = img_root_dir - self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] - self.max_captcha_digits = max_captcha_digits - self.blank = blank - - def __len__(self): - return len(self.img_names) - - def __getitem__(self, item): - img_name = self.img_names[item] - im = Image.open(os.path.join(self.img_root_dir, img_name)) - r, g, b = im.split() - im = Image.merge("RGB", (b, g, r)) - image = np.array(im) - label_str = os.path.splitext(img_name)[0] - label_str = label_str[label_str.find('-') + 1:] - label = [int(i) for i in label_str] - label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) - label = np.array(label) - return image, label - - -def create_dataset(dataset_path, repeat_num=1, batch_size=1): - """ - create train or evaluation dataset for warpctc - - Args: - dataset_path(int): dataset path - repeat_num(int): dataset repetition num, default is 1 - batch_size(int): batch size of generated dataset, default is 1 - """ - rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1 - rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0 - - dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits) - ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id) - ds.set_dataset_size(m.ceil(len(dataset) / rank_size)) - image_trans = [ - vc.Rescale(1.0 / 255.0, 0.0), - vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), - vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), - vc.HWC2CHW() - ] - label_trans = [ - c.TypeCast(mstype.int32) - ] - ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) - ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) - - ds = ds.batch(batch_size) - ds = ds.repeat(repeat_num) - return ds diff --git a/model_zoo/warpctc/src/loss.py b/model_zoo/warpctc/src/loss.py deleted file mode 100755 index 8ea4c20e94..0000000000 --- a/model_zoo/warpctc/src/loss.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CTC Loss.""" -import numpy as np -from mindspore.nn.loss.loss import _Loss -from mindspore import Tensor, Parameter -from mindspore.common import dtype as mstype -from mindspore.ops import operations as P - - -class CTCLoss(_Loss): - """ - CTCLoss definition - - Args: - max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image - width - max_label_length(int): max number of label length for each input. - batch_size(int): batch size of input logits - """ - - def __init__(self, max_sequence_length, max_label_length, batch_size): - super(CTCLoss, self).__init__() - self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32), - name="sequence_length") - labels_indices = [] - for i in range(batch_size): - for j in range(max_label_length): - labels_indices.append([i, j]) - self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices") - self.reshape = P.Reshape() - self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True) - - def construct(self, logit, label): - labels_values = self.reshape(label, (-1,)) - loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) - return loss diff --git a/model_zoo/warpctc/src/metric.py b/model_zoo/warpctc/src/metric.py deleted file mode 100755 index d1060d0781..0000000000 --- a/model_zoo/warpctc/src/metric.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Metric for accuracy evaluation.""" -from mindspore import nn - -BLANK_LABLE = 10 - - -class WarpCTCAccuracy(nn.Metric): - """ - Define accuracy metric for warpctc network. - """ - - def __init__(self): - super(WarpCTCAccuracy).__init__() - self._correct_num = 0 - self._total_num = 0 - self._count = 0 - - def clear(self): - self._correct_num = 0 - self._total_num = 0 - - def update(self, *inputs): - if len(inputs) != 2: - raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) - - y_pred = self._convert_data(inputs[0]) - y = self._convert_data(inputs[1]) - - self._count += 1 - - pred_lbls = self._get_prediction(y_pred) - - for b_idx, target in enumerate(y): - if self._is_eq(pred_lbls[b_idx], target): - self._correct_num += 1 - self._total_num += 1 - - def eval(self): - if self._total_num == 0: - raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') - return self._correct_num / self._total_num - - @staticmethod - def _is_eq(pred_lbl, target): - """ - check whether predict label is equal to target label - """ - target = target.tolist() - pred_diff = len(target) - len(pred_lbl) - if pred_diff > 0: - # padding by BLANK_LABLE - pred_lbl.extend([BLANK_LABLE] * pred_diff) - return pred_lbl == target - - @staticmethod - def _get_prediction(y_pred): - """ - parse predict result to labels - """ - seq_len, batch_size, _ = y_pred.shape - indices = y_pred.argmax(axis=2) - - lens = [seq_len] * batch_size - pred_lbls = [] - for i in range(batch_size): - idx = indices[:, i] - last_idx = BLANK_LABLE - pred_lbl = [] - for j in range(lens[i]): - cur_idx = idx[j] - if cur_idx not in [last_idx, BLANK_LABLE]: - pred_lbl.append(cur_idx) - last_idx = cur_idx - pred_lbls.append(pred_lbl) - return pred_lbls diff --git a/model_zoo/warpctc/src/warpctc.py b/model_zoo/warpctc/src/warpctc.py deleted file mode 100755 index 9669fc4bfd..0000000000 --- a/model_zoo/warpctc/src/warpctc.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Warpctc network definition.""" - -import numpy as np -import mindspore.nn as nn -from mindspore import Tensor, Parameter -from mindspore.common import dtype as mstype -from mindspore.ops import operations as P -from mindspore.ops import functional as F - - -class StackedRNN(nn.Cell): - """ - Define a stacked RNN network which contains two LSTM layers and one full-connect layer. - - Args: - input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for - captcha images. - batch_size(int): batch size of input data, default is 64 - hidden_size(int): the hidden size in LSTM layers, default is 512 - """ - def __init__(self, input_size, batch_size=64, hidden_size=512): - super(StackedRNN, self).__init__() - self.batch_size = batch_size - self.input_size = input_size - self.num_classes = 11 - self.reshape = P.Reshape() - self.cast = P.Cast() - k = (1 / hidden_size) ** 0.5 - self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) - self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) - self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1)) - .astype(np.float16), name="w1") - self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1)) - .astype(np.float16), name="w2") - self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1") - self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2") - - self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) - self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) - - self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh") - - self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) - self.fc_bias = np.random.random((self.num_classes)).astype(np.float32) - - self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), - bias_init=Tensor(self.fc_bias)) - - self.fc.to_float(mstype.float32) - self.expand_dims = P.ExpandDims() - self.concat = P.Concat() - self.transpose = P.Transpose() - - def construct(self, x): - x = self.cast(x, mstype.float16) - x = self.transpose(x, (3, 0, 2, 1)) - x = self.reshape(x, (-1, self.batch_size, self.input_size)) - h1 = self.h1 - c1 = self.c1 - h2 = self.h2 - c2 = self.c2 - - c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1) - c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) - - h2_after_fc = self.fc(h2) - output = self.expand_dims(h2_after_fc, 0) - for i in range(1, F.shape(x)[0]): - c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1) - c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) - - h2_after_fc = self.fc(h2) - h2_after_fc = self.expand_dims(h2_after_fc, 0) - output = self.concat((output, h2_after_fc)) - - return output diff --git a/model_zoo/warpctc/src/warpctc_for_train.py b/model_zoo/warpctc/src/warpctc_for_train.py deleted file mode 100755 index d847f47c62..0000000000 --- a/model_zoo/warpctc/src/warpctc_for_train.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Automatic differentiation with grad clip.""" -from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, - _get_parallel_mode) -from mindspore.train.parallel_utils import ParallelMode -from mindspore.common import dtype as mstype -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore.ops import operations as P -from mindspore.nn.cell import Cell -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -import mindspore.nn as nn -from mindspore.common.tensor import Tensor -import numpy as np - -compute_norm = C.MultitypeFuncGraph("compute_norm") - - -@compute_norm.register("Tensor") -def _compute_norm(grad): - norm = nn.Norm() - norm = norm(F.cast(grad, mstype.float32)) - ret = F.expand_dims(F.cast(norm, mstype.float32), 0) - return ret - - -grad_div = C.MultitypeFuncGraph("grad_div") - - -@grad_div.register("Tensor", "Tensor") -def _grad_div(val, grad): - div = P.Div() - mul = P.Mul() - grad = mul(grad, 10.0) - ret = div(grad, val) - return ret - - -class TrainOneStepCellWithGradClip(Cell): - """ - Network training package class. - - Wraps the network with an optimizer. The resulting Cell be trained with input data and label. - Backward graph with grad clip will be created in the construct function to do parameter updating. - Different parallel modes are available to run the training. - - Args: - network (Cell): The training network. - optimizer (Cell): Optimizer for updating the weights. - sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. - - Inputs: - - data (Tensor) - Tensor of shape :(N, ...). - - label (Tensor) - Tensor of shape :(N, ...). - - Outputs: - Tensor, a scalar Tensor with shape :math:`()`. - """ - - def __init__(self, network, optimizer, sens=1.0): - super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.network.add_flags(defer_inline=True) - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) - self.sens = sens - self.reducer_flag = False - self.grad_reducer = None - self.hyper_map = C.HyperMap() - self.greater = P.Greater() - self.select = P.Select() - self.norm = nn.Norm(keep_dims=True) - self.dtype = P.DType() - self.cast = P.Cast() - self.concat = P.Concat(axis=0) - self.ten = Tensor(np.array([10.0]).astype(np.float32)) - parallel_mode = _get_parallel_mode() - if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): - self.reducer_flag = True - if self.reducer_flag: - mean = _get_mirror_mean() - degree = _get_device_num() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - - def construct(self, data, label): - weights = self.weights - loss = self.network(data, label) - sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - grads = self.grad(self.network, weights)(data, label, sens) - norm = self.hyper_map(F.partial(compute_norm), grads) - norm = self.concat(norm) - norm = self.norm(norm) - cond = self.greater(norm, self.cast(self.ten, self.dtype(norm))) - clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm))) - grads = self.hyper_map(F.partial(grad_div, clip_val), grads) - if self.reducer_flag: - # apply grad reducer on grads - grads = self.grad_reducer(grads) - return F.depend(loss, self.optimizer(grads)) diff --git a/model_zoo/warpctc/train.py b/model_zoo/warpctc/train.py deleted file mode 100755 index 651d2a73a4..0000000000 --- a/model_zoo/warpctc/train.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Warpctc training""" -import os -import math as m -import random -import argparse -import numpy as np -import mindspore.nn as nn -from mindspore import context -from mindspore import dataset as de -from mindspore.train.model import Model, ParallelMode -from mindspore.nn.wrap import WithLossCell -from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint -from mindspore.communication.management import init - -from src.loss import CTCLoss -from src.config import config as cf -from src.dataset import create_dataset -from src.warpctc import StackedRNN -from src.warpctc_for_train import TrainOneStepCellWithGradClip -from src.lr_schedule import get_lr - -random.seed(1) -np.random.seed(1) -de.config.set_seed(1) - -parser = argparse.ArgumentParser(description="Warpctc training") -parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") -parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.') -parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') -args_opt = parser.parse_args() - -device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - save_graphs=False, - device_id=device_id) - -if __name__ == '__main__': - if args_opt.run_distribute: - context.reset_auto_parallel_context() - context.set_auto_parallel_context(device_num=args_opt.device_num, - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - init() - max_captcha_digits = cf.max_captcha_digits - input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 - # create dataset - dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size) - step_size = dataset.get_dataset_size() - # define lr - lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num - lr = get_lr(cf.epoch_size, step_size, lr_init) - # define loss - loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) - # define net - net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) - # define opt - opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) - net = WithLossCell(net, loss) - net = TrainOneStepCellWithGradClip(net, opt).set_train() - # define model - model = Model(net) - # define callbacks - callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] - if cf.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, - keep_checkpoint_max=cf.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck) - callbacks.append(ckpt_cb) - model.train(cf.epoch_size, dataset, callbacks=callbacks) diff --git a/model_zoo/wide_and_deep/README.md b/model_zoo/wide_and_deep/README.md deleted file mode 100644 index 000e6a5335..0000000000 --- a/model_zoo/wide_and_deep/README.md +++ /dev/null @@ -1,107 +0,0 @@ -Recommendation Model -## Overview -This is an implementation of WideDeep as described in the [Wide & Deep Learning for Recommender System](https://arxiv.org/pdf/1606.07792.pdf) paper. - -WideDeep model jointly trained wide linear models and deep neural network, which combined the benefits of memorization and generalization for recommender systems. - -## Dataset -The Criteo datasets are used for model training and evaluation. - -## Running Code - -### Code Structure -The entire code structure is as following: -``` -|--- wide_and_deep/ - train_and_eval.py "Entrance of Wide&Deep model training and evaluation" - eval.py "Entrance of Wide&Deep model evaluation" - train.py "Entrance of Wide&Deep model training" - train_and_eval_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation" - train_and_eval_auto_parallel.py - |--- src/ "Entrance of training and evaluation" - config.py "Parameters configuration" - dataset.py "Dataset loader class" - process_data.py "Process dataset" - preprocess_data.py "Pre_process dataset" - wide_and_deep.py "Model structure" - callbacks.py "Callback class for training and evaluation" - metrics.py "Metric class" - |--- script/ "Run shell dir" - run_multinpu_train.sh "Run data parallel" - run_auto_parallel_train.sh "Run auto parallel" -``` - -### Train and evaluate model -To train and evaluate the model, command as follows: -``` -python train_and_eval.py -``` -Arguments: - * `--device_target`: Device where the code will be implemented (Default: Ascend). - * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. - * `--epochs`: Total train epochs. - * `--batch_size`: Training batch size. - * `--eval_batch_size`: Eval batch size. - * `--field_size`: The number of features. - * `--vocab_size`: The total features of dataset. - * `--emb_dim`: The dense embedding dimension of sparse feature. - * `--deep_layers_dim`: The dimension of all deep layers. - * `--deep_layers_act`: The activation of all deep layers. - * `--dropout_flag`: Whether do dropout. - * `--keep_prob`: The rate to keep in dropout layer. - * `--ckpt_path`:The location of the checkpoint file. - * `--eval_file_name` : Eval output file. - * `--loss_file_name` : Loss output file. - -To train the model in one device, command as follows: -``` -python train.py -``` -Arguments: - * `--device_target`: Device where the code will be implemented (Default: Ascend). - * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. - * `--epochs`: Total train epochs. - * `--batch_size`: Training batch size. - * `--eval_batch_size`: Eval batch size. - * `--field_size`: The number of features. - * `--vocab_size`: The total features of dataset. - * `--emb_dim`: The dense embedding dimension of sparse feature. - * `--deep_layers_dim`: The dimension of all deep layers. - * `--deep_layers_act`: The activation of all deep layers. - * `--dropout_flag`: Whether do dropout. - * `--keep_prob`: The rate to keep in dropout layer. - * `--ckpt_path`:The location of the checkpoint file. - * `--eval_file_name` : Eval output file. - * `--loss_file_name` : Loss output file. - -To train the model in distributed, command as follows: -``` -# configure environment path before training -bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE -``` -``` -# configure environment path before training -bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE -``` - -To evaluate the model, command as follows: -``` -python eval.py -``` -Arguments: - * `--device_target`: Device where the code will be implemented (Default: Ascend). - * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. - * `--epochs`: Total train epochs. - * `--batch_size`: Training batch size. - * `--eval_batch_size`: Eval batch size. - * `--field_size`: The number of features. - * `--vocab_size`: The total features of dataset. - * `--emb_dim`: The dense embedding dimension of sparse feature. - * `--deep_layers_dim`: The dimension of all deep layers. - * `--deep_layers_act`: The activation of all deep layers. - * `--keep_prob`: The rate to keep in dropout layer. - * `--ckpt_path`:The location of the checkpoint file. - * `--eval_file_name` : Eval output file. - * `--loss_file_name` : Loss output file. - -There are other arguments about models and training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions. diff --git a/model_zoo/wide_and_deep/eval.py b/model_zoo/wide_and_deep/eval.py deleted file mode 100644 index bc3846533f..0000000000 --- a/model_zoo/wide_and_deep/eval.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" test_training """ - -import os - -from mindspore import Model, context -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel -from src.callbacks import LossCallBack, EvalCallBack -from src.datasets import create_dataset -from src.metrics import AUCMetric -from src.config import WideDeepConfig - - -def get_WideDeep_net(config): - """ - Get network of wide&deep model. - """ - WideDeep_net = WideDeepModel(config) - - loss_net = NetWithLossClass(WideDeep_net, config) - train_net = TrainStepWrap(loss_net) - eval_net = PredictWithSigmoid(WideDeep_net) - - return train_net, eval_net - - -class ModelBuilder(): - """ - Wide and deep model builder - """ - def __init__(self): - pass - - def get_hook(self): - pass - - def get_train_hook(self): - hooks = [] - callback = LossCallBack() - hooks.append(callback) - - if int(os.getenv('DEVICE_ID')) == 0: - pass - return hooks - - def get_net(self, config): - return get_WideDeep_net(config) - - -def test_eval(config): - """ - test evaluate - """ - data_path = config.data_path - batch_size = config.batch_size - ds_eval = create_dataset(data_path, train_mode=False, epochs=2, - batch_size=batch_size) - print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) - - net_builder = ModelBuilder() - train_net, eval_net = net_builder.get_net(config) - - param_dict = load_checkpoint(config.ckpt_path) - load_param_into_net(eval_net, param_dict) - - auc_metric = AUCMetric() - model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) - - eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) - - model.eval(ds_eval, callbacks=eval_callback) - - -if __name__ == "__main__": - widedeep_config = WideDeepConfig() - widedeep_config.argparse_init() - - context.set_context(mode=context.GRAPH_MODE, device_target=widedeep_config.device_target) - test_eval(widedeep_config) diff --git a/model_zoo/wide_and_deep/script/run_auto_parallel_train.sh b/model_zoo/wide_and_deep/script/run_auto_parallel_train.sh deleted file mode 100644 index 9e9226a23a..0000000000 --- a/model_zoo/wide_and_deep/script/run_auto_parallel_train.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# bash run_multinpu_train.sh -execute_path=$(pwd) -script_self=$(readlink -f "$0") -self_path=$(dirname "${script_self}") -export RANK_SIZE=$1 -export EPOCH_SIZE=$2 -export DATASET=$3 -export RANK_TABLE_FILE=$4 -export MINDSPORE_HCCL_CONFIG_PATH=$4 - -for((i=0;i<$RANK_SIZE;i++)); -do - rm -rf ${execute_path}/device_$i/ - mkdir ${execute_path}/device_$i/ - cd ${execute_path}/device_$i/ || exit - export RANK_ID=$i - export DEVICE_ID=$i - python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & -done diff --git a/model_zoo/wide_and_deep/script/run_multinpu_train.sh b/model_zoo/wide_and_deep/script/run_multinpu_train.sh deleted file mode 100644 index 4b642bc196..0000000000 --- a/model_zoo/wide_and_deep/script/run_multinpu_train.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# bash run_multinpu_train.sh -execute_path=$(pwd) -script_self=$(readlink -f "$0") -self_path=$(dirname "${script_self}") -export RANK_SIZE=$1 -export EPOCH_SIZE=$2 -export DATASET=$3 -export RANK_TABLE_FILE=$4 -export MINDSPORE_HCCL_CONFIG_PATH=$4 - -for((i=0;i<$RANK_SIZE;i++)); -do - rm -rf ${execute_path}/device_$i/ - mkdir ${execute_path}/device_$i/ - cd ${execute_path}/device_$i/ || exit - export RANK_ID=$i - export DEVICE_ID=$i - python -s ${self_path}/../train_and_eval_distribute.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & -done diff --git a/model_zoo/wide_and_deep/src/callbacks.py b/model_zoo/wide_and_deep/src/callbacks.py deleted file mode 100644 index 4c2f9c700e..0000000000 --- a/model_zoo/wide_and_deep/src/callbacks.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -callbacks -""" -import time -from mindspore.train.callback import Callback -from mindspore import context -from mindspore.train import ParallelMode - -def add_write(file_path, out_str): - """ - add lines to the file - """ - with open(file_path, 'a+', encoding="utf-8") as file_out: - file_out.write(out_str + "\n") - - -class LossCallBack(Callback): - """ - Monitor the loss in training. - - If the loss is NAN or INF, terminate the training. - - Note: - If per_print_times is 0, do NOT print loss. - - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, config=None, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("per_print_times must be in and >= 0.") - self._per_print_times = per_print_times - self.config = config - - def step_end(self, run_context): - cb_params = run_context.original_args() - wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy() - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 - cur_num = cb_params.cur_step_num - print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss) - - # raise ValueError - if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None: - loss_file = open(self.config.loss_file_name, "a+") - loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % - (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) - loss_file.write("\n") - loss_file.close() - print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % - (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) - - -class EvalCallBack(Callback): - """ - Monitor the loss in evaluating. - - If the loss is NAN or INF, terminate evaluating. - - Note: - If per_print_times is 0, do NOT print loss. - - Args: - print_per_step (int): Print loss every times. Default: 1. - """ - def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): - super(EvalCallBack, self).__init__() - if not isinstance(print_per_step, int) or print_per_step < 0: - raise ValueError("print_per_step must be int and >= 0.") - self.print_per_step = print_per_step - self.model = model - self.eval_dataset = eval_dataset - self.aucMetric = auc_metric - self.aucMetric.clear() - self.eval_file_name = config.eval_file_name - self.eval_values = [] - - def epoch_end(self, run_context): - """ - epoch end - """ - self.aucMetric.clear() - parallel_mode = context.get_auto_parallel_context("parallel_mode") - if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - context.set_auto_parallel_context(strategy_ckpt_save_file="", - strategy_ckpt_load_file="./strategy_train.ckpt") - start_time = time.time() - out = self.model.eval(self.eval_dataset) - end_time = time.time() - eval_time = int(end_time - start_time) - - time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) - out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time) - print(out_str) - self.eval_values = out.values() - add_write(self.eval_file_name, out_str) diff --git a/model_zoo/wide_and_deep/src/config.py b/model_zoo/wide_and_deep/src/config.py deleted file mode 100644 index f8a2c84743..0000000000 --- a/model_zoo/wide_and_deep/src/config.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" config. """ -import argparse - - -def argparse_init(): - """ - argparse_init - """ - parser = argparse.ArgumentParser(description='WideDeep') - parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], - help="device where the code will be implemented. (Default: Ascend)") - parser.add_argument("--data_path", type=str, default="./test_raw_data/") - parser.add_argument("--epochs", type=int, default=15) - parser.add_argument("--full_batch", type=bool, default=False) - parser.add_argument("--batch_size", type=int, default=16000) - parser.add_argument("--eval_batch_size", type=int, default=16000) - parser.add_argument("--field_size", type=int, default=39) - parser.add_argument("--vocab_size", type=int, default=200000) - parser.add_argument("--emb_dim", type=int, default=80) - parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) - parser.add_argument("--deep_layer_act", type=str, default='relu') - parser.add_argument("--keep_prob", type=float, default=1.0) - parser.add_argument("--dropout_flag", type=int, default=0) - parser.add_argument("--output_path", type=str, default="./output/") - parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") - parser.add_argument("--eval_file_name", type=str, default="eval.log") - parser.add_argument("--loss_file_name", type=str, default="loss.log") - return parser - - -class WideDeepConfig(): - """ - WideDeepConfig - """ - def __init__(self): - self.device_target = "Ascend" - self.data_path = "./test_raw_data/" - self.full_batch = False - self.epochs = 15 - self.batch_size = 16000 - self.eval_batch_size = 16000 - self.field_size = 39 - self.vocab_size = 200000 - self.emb_dim = 80 - self.deep_layer_dim = [1024, 512, 256, 128] - self.deep_layer_act = 'relu' - self.weight_bias_init = ['normal', 'normal'] - self.emb_init = 'normal' - self.init_args = [-0.01, 0.01] - self.dropout_flag = False - self.keep_prob = 1.0 - self.l2_coef = 8e-5 - - self.output_path = "./output" - self.eval_file_name = "eval.log" - self.loss_file_name = "loss.log" - self.ckpt_path = "./checkpoints/" - - def argparse_init(self): - """ - argparse_init - """ - parser = argparse_init() - args, _ = parser.parse_known_args() - self.device_target = args.device_target - self.data_path = args.data_path - self.epochs = args.epochs - self.full_batch = args.full_batch - self.batch_size = args.batch_size - self.eval_batch_size = args.eval_batch_size - self.field_size = args.field_size - self.vocab_size = args.vocab_size - self.emb_dim = args.emb_dim - self.deep_layer_dim = args.deep_layer_dim - self.deep_layer_act = args.deep_layer_act - self.keep_prob = args.keep_prob - self.weight_bias_init = ['normal', 'normal'] - self.emb_init = 'normal' - self.init_args = [-0.01, 0.01] - self.dropout_flag = bool(args.dropout_flag) - self.l2_coef = 8e-5 - - self.output_path = args.output_path - self.eval_file_name = args.eval_file_name - self.loss_file_name = args.loss_file_name - self.ckpt_path = args.ckpt_path diff --git a/model_zoo/wide_and_deep/src/preprocess_data.py b/model_zoo/wide_and_deep/src/preprocess_data.py deleted file mode 100644 index 75562aa71b..0000000000 --- a/model_zoo/wide_and_deep/src/preprocess_data.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Download raw data and preprocessed data.""" -import os -import pickle -import collections -import argparse -import urllib.request -import tarfile -import numpy as np -from mindspore.mindrecord import FileWriter - -TRAIN_LINE_COUNT = 45840617 -TEST_LINE_COUNT = 6042135 - - -class CriteoStatsDict(): - """preprocessed data""" - - def __init__(self): - self.field_size = 39 - self.val_cols = ["val_{}".format(i + 1) for i in range(13)] - self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)] - - self.val_min_dict = {col: 0 for col in self.val_cols} - self.val_max_dict = {col: 0 for col in self.val_cols} - - self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols} - - self.oov_prefix = "OOV" - - self.cat2id_dict = {} - self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)}) - self.cat2id_dict.update( - {self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)}) - - def stats_vals(self, val_list): - """Handling weights column""" - assert len(val_list) == len(self.val_cols) - - def map_max_min(i, val): - key = self.val_cols[i] - if val != "": - if float(val) > self.val_max_dict[key]: - self.val_max_dict[key] = float(val) - if float(val) < self.val_min_dict[key]: - self.val_min_dict[key] = float(val) - - for i, val in enumerate(val_list): - map_max_min(i, val) - - def stats_cats(self, cat_list): - """Handling cats column""" - - assert len(cat_list) == len(self.cat_cols) - - def map_cat_count(i, cat): - key = self.cat_cols[i] - self.cat_count_dict[key][cat] += 1 - - for i, cat in enumerate(cat_list): - map_cat_count(i, cat) - - def save_dict(self, dict_path, prefix=""): - with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt: - pickle.dump(self.val_max_dict, file_wrt) - with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt: - pickle.dump(self.val_min_dict, file_wrt) - with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt: - pickle.dump(self.cat_count_dict, file_wrt) - - def load_dict(self, dict_path, prefix=""): - with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt: - self.val_max_dict = pickle.load(file_wrt) - with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt: - self.val_min_dict = pickle.load(file_wrt) - with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt: - self.cat_count_dict = pickle.load(file_wrt) - print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items()))) - print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items()))) - - def get_cat2id(self, threshold=100): - for key, cat_count_d in self.cat_count_dict.items(): - new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items())) - for cat_str, _ in new_cat_count_d.items(): - self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict) - print("cat2id_dict.size:{}".format(len(self.cat2id_dict))) - print("cat2id.dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50])) - - def map_cat2id(self, values, cats): - """Cat to id""" - - def minmax_scale_value(i, val): - max_v = float(self.val_max_dict["val_{}".format(i + 1)]) - return float(val) * 1.0 / max_v - - id_list = [] - weight_list = [] - for i, val in enumerate(values): - if val == "": - id_list.append(i) - weight_list.append(0) - else: - key = "val_{}".format(i + 1) - id_list.append(self.cat2id_dict[key]) - weight_list.append(minmax_scale_value(i, float(val))) - - for i, cat_str in enumerate(cats): - key = "cat_{}".format(i + 1) + "_" + cat_str - if key in self.cat2id_dict: - id_list.append(self.cat2id_dict[key]) - else: - id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)]) - weight_list.append(1.0) - return id_list, weight_list - - -def mkdir_path(file_path): - if not os.path.exists(file_path): - os.makedirs(file_path) - - -def statsdata(file_path, dict_output_path, criteo_stats_dict): - """Preprocess data and save data""" - with open(file_path, encoding="utf-8") as file_in: - errorline_list = [] - count = 0 - for line in file_in: - count += 1 - line = line.strip("\n") - items = line.split("\t") - if len(items) != 40: - errorline_list.append(count) - print("line: {}".format(line)) - continue - if count % 1000000 == 0: - print("Have handled {}w lines.".format(count // 10000)) - values = items[1:14] - cats = items[14:] - - assert len(values) == 13, "values.size: {}".format(len(values)) - assert len(cats) == 26, "cats.size: {}".format(len(cats)) - criteo_stats_dict.stats_vals(values) - criteo_stats_dict.stats_cats(cats) - criteo_stats_dict.save_dict(dict_output_path) - - -def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000, - line_per_sample=1000, - test_size=0.1, seed=2020): - """Random split data and save mindrecord""" - test_size = int(TRAIN_LINE_COUNT * test_size) - all_indices = [i for i in range(TRAIN_LINE_COUNT)] - np.random.seed(seed) - np.random.shuffle(all_indices) - print("all_indices.size:{}".format(len(all_indices))) - test_indices_set = set(all_indices[:test_size]) - print("test_indices_set.size:{}".format(len(test_indices_set))) - print("-----------------------" * 10 + "\n" * 2) - - train_data_list = [] - test_data_list = [] - ids_list = [] - wts_list = [] - label_list = [] - - writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21) - writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3) - - schema = {"label": {"type": "float32", "shape": [-1]}, "feat_vals": {"type": "float32", "shape": [-1]}, - "feat_ids": {"type": "int32", "shape": [-1]}} - writer_train.add_schema(schema, "CRITEO_TRAIN") - writer_test.add_schema(schema, "CRITEO_TEST") - - with open(input_file_path, encoding="utf-8") as file_in: - items_error_size_lineCount = [] - count = 0 - train_part_number = 0 - test_part_number = 0 - for i, line in enumerate(file_in): - count += 1 - if count % 1000000 == 0: - print("Have handle {}w lines.".format(count // 10000)) - line = line.strip("\n") - items = line.split("\t") - if len(items) != 40: - items_error_size_lineCount.append(i) - continue - label = float(items[0]) - values = items[1:14] - cats = items[14:] - - assert len(values) == 13, "values.size: {}".format(len(values)) - assert len(cats) == 26, "cats.size: {}".format(len(cats)) - - ids, wts = criteo_stats_dict.map_cat2id(values, cats) - - ids_list.extend(ids) - wts_list.extend(wts) - label_list.append(label) - - if count % line_per_sample == 0: - if i not in test_indices_set: - train_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32), - "feat_vals": np.array(wts_list, dtype=np.float32), - "label": np.array(label_list, dtype=np.float32) - }) - else: - test_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32), - "feat_vals": np.array(wts_list, dtype=np.float32), - "label": np.array(label_list, dtype=np.float32) - }) - if train_data_list and len(train_data_list) % part_rows == 0: - writer_train.write_raw_data(train_data_list) - train_data_list.clear() - train_part_number += 1 - - if test_data_list and len(test_data_list) % part_rows == 0: - writer_test.write_raw_data(test_data_list) - test_data_list.clear() - test_part_number += 1 - - ids_list.clear() - wts_list.clear() - label_list.clear() - - if train_data_list: - writer_train.write_raw_data(train_data_list) - if test_data_list: - writer_test.write_raw_data(test_data_list) - writer_train.commit() - writer_test.commit() - - print("-------------" * 10) - print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount))) - print("-------------" * 10) - np.save("items_error_size_lineCount.npy", items_error_size_lineCount) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="criteo data") - parser.add_argument("--data_path", type=str, default="./criteo_data/") - - args, _ = parser.parse_known_args() - data_path = args.data_path - - download_data_path = data_path + "origin_data/" - mkdir_path(download_data_path) - - url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz" - file_name = download_data_path + '/' + url.split('/')[-1] - urllib.request.urlretrieve(url, filename=file_name) - - tar = tarfile.open(file_name) - names = tar.getnames() - for name in names: - tar.extract(name, path=download_data_path) - tar.close() - - criteo_stats = CriteoStatsDict() - data_file_path = data_path + "origin_data/train.txt" - stats_output_path = data_path + "stats_dict/" - mkdir_path(stats_output_path) - statsdata(data_file_path, stats_output_path, criteo_stats) - - criteo_stats.load_dict(dict_path=stats_output_path, prefix="") - criteo_stats.get_cat2id(threshold=100) - - in_file_path = data_path + "origin_data/train.txt" - output_path = data_path + "mindrecord/" - mkdir_path(output_path) - random_split_trans2mindrecord(in_file_path, output_path, criteo_stats, part_rows=2000000, line_per_sample=1000, - test_size=0.1, seed=2020) diff --git a/model_zoo/wide_and_deep/src/wide_and_deep.py b/model_zoo/wide_and_deep/src/wide_and_deep.py deleted file mode 100644 index 048bf3c66d..0000000000 --- a/model_zoo/wide_and_deep/src/wide_and_deep.py +++ /dev/null @@ -1,339 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""wide and deep model""" -from mindspore import nn -from mindspore import Parameter, ParameterTuple -import mindspore.common.dtype as mstype -from mindspore.ops import functional as F -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from mindspore.nn import Dropout -from mindspore.nn.optim import Adam, FTRL -# from mindspore.nn.metrics import Metric -from mindspore.common.initializer import Uniform, initializer -# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean -from mindspore.train.parallel_utils import ParallelMode -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.communication.management import get_group_size -import numpy as np - -np_type = np.float32 -ms_type = mstype.float32 - - -def init_method(method, shape, name, max_val=1.0): - ''' - parameter init method - ''' - if method in ['uniform']: - params = Parameter(initializer( - Uniform(max_val), shape, ms_type), name=name) - elif method == "one": - params = Parameter(initializer("ones", shape, ms_type), name=name) - elif method == 'zero': - params = Parameter(initializer("zeros", shape, ms_type), name=name) - elif method == "normal": - params = Parameter(initializer("normal", shape, ms_type), name=name) - return params - - -def init_var_dict(init_args, in_vars): - ''' - var init function - ''' - var_map = {} - _, _max_val = init_args - for _, iterm in enumerate(in_vars): - key, shape, method = iterm - if key not in var_map.keys(): - if method in ['random', 'uniform']: - var_map[key] = Parameter(initializer( - Uniform(_max_val), shape, ms_type), name=key) - elif method == "one": - var_map[key] = Parameter(initializer( - "ones", shape, ms_type), name=key) - elif method == "zero": - var_map[key] = Parameter(initializer( - "zeros", shape, ms_type), name=key) - elif method == 'normal': - var_map[key] = Parameter(initializer( - "normal", shape, ms_type), name=key) - return var_map - - -class DenseLayer(nn.Cell): - """ - Dense Layer for Deep Layer of WideDeep Model; - Containing: activation, matmul, bias_add; - Args: - """ - - def __init__(self, input_dim, output_dim, weight_bias_init, act_str, - keep_prob=0.7, use_activation=True, convert_dtype=True, drop_out=False): - super(DenseLayer, self).__init__() - weight_init, bias_init = weight_bias_init - self.weight = init_method( - weight_init, [input_dim, output_dim], name="weight") - self.bias = init_method(bias_init, [output_dim], name="bias") - self.act_func = self._init_activation(act_str) - self.matmul = P.MatMul(transpose_b=False) - self.bias_add = P.BiasAdd() - self.cast = P.Cast() - self.dropout = Dropout(keep_prob=keep_prob) - self.use_activation = use_activation - self.convert_dtype = convert_dtype - self.drop_out = drop_out - - def _init_activation(self, act_str): - act_str = act_str.lower() - if act_str == "relu": - act_func = P.ReLU() - elif act_str == "sigmoid": - act_func = P.Sigmoid() - elif act_str == "tanh": - act_func = P.Tanh() - return act_func - - def construct(self, x): - if self.training and self.drop_out: - x = self.dropout(x) - if self.convert_dtype: - x = self.cast(x, mstype.float16) - weight = self.cast(self.weight, mstype.float16) - bias = self.cast(self.bias, mstype.float16) - wx = self.matmul(x, weight) - wx = self.bias_add(wx, bias) - if self.use_activation: - wx = self.act_func(wx) - wx = self.cast(wx, mstype.float32) - else: - wx = self.matmul(x, self.weight) - wx = self.bias_add(wx, self.bias) - if self.use_activation: - wx = self.act_func(wx) - return wx - - -class WideDeepModel(nn.Cell): - """ - From paper: " Wide & Deep Learning for Recommender Systems" - Args: - config (Class): The default config of Wide&Deep - """ - - def __init__(self, config): - super(WideDeepModel, self).__init__() - self.batch_size = config.batch_size - parallel_mode = _get_parallel_mode() - if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - self.batch_size = self.batch_size * get_group_size() - self.field_size = config.field_size - self.vocab_size = config.vocab_size - self.emb_dim = config.emb_dim - self.deep_layer_dims_list = config.deep_layer_dim - self.deep_layer_act = config.deep_layer_act - self.init_args = config.init_args - self.weight_init, self.bias_init = config.weight_bias_init - self.weight_bias_init = config.weight_bias_init - self.emb_init = config.emb_init - self.drop_out = config.dropout_flag - self.keep_prob = config.keep_prob - self.deep_input_dims = self.field_size * self.emb_dim - self.layer_dims = self.deep_layer_dims_list + [1] - self.all_dim_list = [self.deep_input_dims] + self.layer_dims - - init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init), - ('V_l2', [self.vocab_size, self.emb_dim], self.emb_init), - ('Wide_b', [1], self.emb_init)] - var_map = init_var_dict(self.init_args, init_acts) - self.wide_w = var_map["Wide_w"] - self.wide_b = var_map["Wide_b"] - self.embedding_table = var_map["V_l2"] - self.dense_layer_1 = DenseLayer(self.all_dim_list[0], - self.all_dim_list[1], - self.weight_bias_init, - self.deep_layer_act, - convert_dtype=True, drop_out=config.dropout_flag) - self.dense_layer_2 = DenseLayer(self.all_dim_list[1], - self.all_dim_list[2], - self.weight_bias_init, - self.deep_layer_act, - convert_dtype=True, drop_out=config.dropout_flag) - self.dense_layer_3 = DenseLayer(self.all_dim_list[2], - self.all_dim_list[3], - self.weight_bias_init, - self.deep_layer_act, - convert_dtype=True, drop_out=config.dropout_flag) - self.dense_layer_4 = DenseLayer(self.all_dim_list[3], - self.all_dim_list[4], - self.weight_bias_init, - self.deep_layer_act, - convert_dtype=True, drop_out=config.dropout_flag) - self.dense_layer_5 = DenseLayer(self.all_dim_list[4], - self.all_dim_list[5], - self.weight_bias_init, - self.deep_layer_act, - use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) - - self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE') - self.mul = P.Mul() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.reshape = P.Reshape() - self.square = P.Square() - self.shape = P.Shape() - self.tile = P.Tile() - self.concat = P.Concat(axis=1) - self.cast = P.Cast() - - def construct(self, id_hldr, wt_hldr): - """ - Args: - id_hldr: batch ids; - wt_hldr: batch weights; - """ - mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) - # Wide layer - wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr) - wx = self.mul(wide_id_weight, mask) - wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) - # Deep layer - deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr) - vx = self.mul(deep_id_embs, mask) - deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) - deep_in = self.dense_layer_1(deep_in) - deep_in = self.dense_layer_2(deep_in) - deep_in = self.dense_layer_3(deep_in) - deep_in = self.dense_layer_4(deep_in) - deep_out = self.dense_layer_5(deep_in) - out = wide_out + deep_out - return out, self.embedding_table - - -class NetWithLossClass(nn.Cell): - - """" - Provide WideDeep training loss through network. - Args: - network (Cell): The training network - config (Class): WideDeep config - """ - - def __init__(self, network, config): - super(NetWithLossClass, self).__init__(auto_prefix=False) - self.network = network - self.l2_coef = config.l2_coef - self.loss = P.SigmoidCrossEntropyWithLogits() - self.square = P.Square() - self.reduceMean_false = P.ReduceMean(keep_dims=False) - self.reduceSum_false = P.ReduceSum(keep_dims=False) - - def construct(self, batch_ids, batch_wts, label): - predict, embedding_table = self.network(batch_ids, batch_wts) - log_loss = self.loss(predict, label) - wide_loss = self.reduceMean_false(log_loss) - l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2 - deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v - - return wide_loss, deep_loss - - -class IthOutputCell(nn.Cell): - def __init__(self, network, output_index): - super(IthOutputCell, self).__init__() - self.network = network - self.output_index = output_index - - def construct(self, x1, x2, x3): - predict = self.network(x1, x2, x3)[self.output_index] - return predict - - -class TrainStepWrap(nn.Cell): - """ - Encapsulation class of WideDeep network training. - Append Adam and FTRL optimizers to the training network after that construct - function can be called to create the backward graph. - Args: - network (Cell): the training network. Note that loss function should have been added. - sens (Number): The adjust parameter. Default: 1000.0 - """ - - def __init__(self, network, sens=1024.0): - super(TrainStepWrap, self).__init__() - self.network = network - self.network.set_train() - self.trainable_params = network.trainable_params() - weights_w = [] - weights_d = [] - for params in self.trainable_params: - if 'wide' in params.name: - weights_w.append(params) - else: - weights_d.append(params) - self.weights_w = ParameterTuple(weights_w) - self.weights_d = ParameterTuple(weights_d) - self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, - l1=1e-8, l2=1e-8, initial_accum=1.0) - self.optimizer_d = Adam( - self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) - self.hyper_map = C.HyperMap() - self.grad_w = C.GradOperation('grad_w', get_by_list=True, - sens_param=True) - self.grad_d = C.GradOperation('grad_d', get_by_list=True, - sens_param=True) - self.sens = sens - self.loss_net_w = IthOutputCell(network, output_index=0) - self.loss_net_d = IthOutputCell(network, output_index=1) - - self.reducer_flag = False - self.grad_reducer_w = None - self.grad_reducer_d = None - parallel_mode = _get_parallel_mode() - self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, - ParallelMode.HYBRID_PARALLEL) - if self.reducer_flag: - mean = _get_mirror_mean() - degree = _get_device_num() - self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree) - self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree) - - def construct(self, batch_ids, batch_wts, label): - weights_w = self.weights_w - weights_d = self.weights_d - loss_w, loss_d = self.network(batch_ids, batch_wts, label) - sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) - sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) - grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts, - label, sens_w) - grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts, - label, sens_d) - if self.reducer_flag: - grads_w = self.grad_reducer_w(grads_w) - grads_d = self.grad_reducer_d(grads_d) - return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, - self.optimizer_d(grads_d)) - - -class PredictWithSigmoid(nn.Cell): - def __init__(self, network): - super(PredictWithSigmoid, self).__init__() - self.network = network - self.sigmoid = P.Sigmoid() - - def construct(self, batch_ids, batch_wts, labels): - logits, _, _, = self.network(batch_ids, batch_wts) - pred_probs = self.sigmoid(logits) - return logits, pred_probs, labels diff --git a/model_zoo/wide_and_deep/train.py b/model_zoo/wide_and_deep/train.py deleted file mode 100644 index a043be3dc6..0000000000 --- a/model_zoo/wide_and_deep/train.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" test_training """ -import os -from mindspore import Model, context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor -from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel -from src.callbacks import LossCallBack -from src.datasets import create_dataset -from src.config import WideDeepConfig - - -def get_WideDeep_net(configure): - """ - Get network of wide&deep model. - """ - WideDeep_net = WideDeepModel(configure) - - loss_net = NetWithLossClass(WideDeep_net, configure) - train_net = TrainStepWrap(loss_net) - eval_net = PredictWithSigmoid(WideDeep_net) - - return train_net, eval_net - - -class ModelBuilder(): - """ - Build the model. - """ - def __init__(self): - pass - - def get_hook(self): - pass - - def get_train_hook(self): - hooks = [] - callback = LossCallBack() - hooks.append(callback) - if int(os.getenv('DEVICE_ID')) == 0: - pass - return hooks - - def get_net(self, configure): - return get_WideDeep_net(configure) - - -def test_train(configure): - """ - test_train - """ - data_path = configure.data_path - batch_size = configure.batch_size - epochs = configure.epochs - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) - print("ds_train.size: {}".format(ds_train.get_dataset_size())) - - net_builder = ModelBuilder() - train_net, _ = net_builder.get_net(configure) - train_net.set_train() - - model = Model(train_net) - callback = LossCallBack(config=configure) - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), - keep_checkpoint_max=5) - ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) - model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb]) - - -if __name__ == "__main__": - config = WideDeepConfig() - config.argparse_init() - - context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) - test_train(config) diff --git a/model_zoo/wide_and_deep/train_and_eval.py b/model_zoo/wide_and_deep/train_and_eval.py deleted file mode 100644 index e0ab6b2e9e..0000000000 --- a/model_zoo/wide_and_deep/train_and_eval.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" test_training """ -import os - -from mindspore import Model, context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor - -from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel -from src.callbacks import LossCallBack, EvalCallBack -from src.datasets import create_dataset -from src.metrics import AUCMetric -from src.config import WideDeepConfig - - -def get_WideDeep_net(config): - """ - Get network of wide&deep model. - """ - WideDeep_net = WideDeepModel(config) - - loss_net = NetWithLossClass(WideDeep_net, config) - train_net = TrainStepWrap(loss_net) - eval_net = PredictWithSigmoid(WideDeep_net) - - return train_net, eval_net - - -class ModelBuilder(): - """ - ModelBuilder - """ - def __init__(self): - pass - - def get_hook(self): - pass - - def get_train_hook(self): - hooks = [] - callback = LossCallBack() - hooks.append(callback) - - if int(os.getenv('DEVICE_ID')) == 0: - pass - return hooks - - def get_net(self, config): - return get_WideDeep_net(config) - - -def test_train_eval(config): - """ - test_train_eval - """ - data_path = config.data_path - batch_size = config.batch_size - epochs = config.epochs - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size) - print("ds_train.size: {}".format(ds_train.get_dataset_size())) - print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) - - net_builder = ModelBuilder() - - train_net, eval_net = net_builder.get_net(config) - train_net.set_train() - auc_metric = AUCMetric() - - model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) - - eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) - - callback = LossCallBack(config=config) - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) - ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) - - out = model.eval(ds_eval) - print("=====" * 5 + "model.eval() initialized: {}".format(out)) - model.train(epochs, ds_train, - callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) - - -if __name__ == "__main__": - wide_deep_config = WideDeepConfig() - wide_deep_config.argparse_init() - - context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) - test_train_eval(wide_deep_config) diff --git a/model_zoo/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/wide_and_deep/train_and_eval_auto_parallel.py deleted file mode 100644 index 4c86931b2e..0000000000 --- a/model_zoo/wide_and_deep/train_and_eval_auto_parallel.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train_multinpu.""" - - -import os -import sys -import mindspore.dataset.engine as de -from mindspore import Model, context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor -from mindspore.train import ParallelMode -from mindspore.communication.management import get_rank, get_group_size, init -from mindspore.parallel import _cost_model_context as cost_model_context -from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple - -from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel -from src.callbacks import LossCallBack, EvalCallBack -from src.datasets import create_dataset -from src.metrics import AUCMetric -from src.config import WideDeepConfig - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) -context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True) -cost_model_context.set_cost_model_context(multi_subgraphs=True) -init() - - - -def get_WideDeep_net(config): - """ - Get network of wide&deep model. - """ - WideDeep_net = WideDeepModel(config) - loss_net = NetWithLossClass(WideDeep_net, config) - loss_net = VirtualDatasetCellTriple(loss_net) - train_net = TrainStepWrap(loss_net) - eval_net = PredictWithSigmoid(WideDeep_net) - eval_net = VirtualDatasetCellTriple(eval_net) - return train_net, eval_net - - -class ModelBuilder(): - """ - ModelBuilder - """ - def __init__(self): - pass - - def get_hook(self): - pass - - def get_train_hook(self): - hooks = [] - callback = LossCallBack() - hooks.append(callback) - if int(os.getenv('DEVICE_ID')) == 0: - pass - return hooks - - def get_net(self, config): - return get_WideDeep_net(config) - - -def train_and_eval(config): - """ - test_train_eval - """ - data_path = config.data_path - batch_size = config.batch_size - epochs = config.epochs - print("epochs is {}".format(epochs)) - if config.full_batch: - context.set_auto_parallel_context(full_batch=True) - de.config.set_seed(1) - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, - batch_size=batch_size*get_group_size()) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, - batch_size=batch_size*get_group_size()) - else: - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, - batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, - batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) - print("ds_train.size: {}".format(ds_train.get_dataset_size())) - print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) - - net_builder = ModelBuilder() - - train_net, eval_net = net_builder.get_net(config) - train_net.set_train() - auc_metric = AUCMetric() - - model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) - - eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) - - callback = LossCallBack(config=config) - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) - ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', - directory=config.ckpt_path, config=ckptconfig) - context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt") - model.train(epochs, ds_train, - callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) - - -if __name__ == "__main__": - wide_deep_config = WideDeepConfig() - wide_deep_config.argparse_init() - train_and_eval(wide_deep_config) diff --git a/model_zoo/wide_and_deep/train_and_eval_distribute.py b/model_zoo/wide_and_deep/train_and_eval_distribute.py deleted file mode 100644 index 71f2b11cba..0000000000 --- a/model_zoo/wide_and_deep/train_and_eval_distribute.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""train_multinpu.""" - - -import os -import sys -import numpy as np -from mindspore import Model, context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor -from mindspore.train import ParallelMode -from mindspore.communication.management import get_rank, get_group_size, init - -from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel -from src.callbacks import LossCallBack, EvalCallBack -from src.datasets import create_dataset -from src.metrics import AUCMetric -from src.config import WideDeepConfig - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - - -def get_WideDeep_net(config): - """ - Get network of wide&deep model. - """ - WideDeep_net = WideDeepModel(config) - loss_net = NetWithLossClass(WideDeep_net, config) - train_net = TrainStepWrap(loss_net) - eval_net = PredictWithSigmoid(WideDeep_net) - return train_net, eval_net - - -class ModelBuilder(): - """ - ModelBuilder - """ - def __init__(self): - pass - - def get_hook(self): - pass - - def get_train_hook(self): - hooks = [] - callback = LossCallBack() - hooks.append(callback) - if int(os.getenv('DEVICE_ID')) == 0: - pass - return hooks - - def get_net(self, config): - return get_WideDeep_net(config) - - -def train_and_eval(config): - """ - test_train_eval - """ - np.random.seed(1000) - data_path = config.data_path - batch_size = config.batch_size - epochs = config.epochs - print("epochs is {}".format(epochs)) - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, - batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, - batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) - print("ds_train.size: {}".format(ds_train.get_dataset_size())) - print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) - - net_builder = ModelBuilder() - - train_net, eval_net = net_builder.get_net(config) - train_net.set_train() - auc_metric = AUCMetric() - - model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) - - eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) - - callback = LossCallBack(config=config) - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) - if config.device_target == "Ascend": - ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', - directory=config.ckpt_path, config=ckptconfig) - elif config.device_target == "GPU": - ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()), - directory=config.ckpt_path, config=ckptconfig) - out = model.eval(ds_eval) - print("=====" * 5 + "model.eval() initialized: {}".format(out)) - model.train(epochs, ds_train, - callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) - - -if __name__ == "__main__": - wide_deep_config = WideDeepConfig() - wide_deep_config.argparse_init() - - context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) - if wide_deep_config.device_target == "Ascend": - init("hccl") - elif wide_deep_config.device_target == "GPU": - init("nccl") - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, - device_num=get_group_size()) - - train_and_eval(wide_deep_config) diff --git a/model_zoo/yolov3_darknet53/README.md b/model_zoo/yolov3_darknet53/README.md deleted file mode 100644 index 9a5bfed7d3..0000000000 --- a/model_zoo/yolov3_darknet53/README.md +++ /dev/null @@ -1,132 +0,0 @@ -# YOLOV3-DarkNet53 Example - -## Description - -This is an example of training YOLOV3-DarkNet53 with COCO2014 dataset in MindSpore. - -## Requirements - -- Install [MindSpore](https://www.mindspore.cn/install/en). - -- Download the dataset COCO2014. - -> Unzip the COCO2014 dataset to any path you want, the folder should include train and eval dataset as follows: - -``` -. -└─dataset - ├─train2014 - ├─val2014 - └─annotations -``` - -## Structure - -```shell -. -└─yolov3_darknet53 - ├─README.md - ├─scripts - ├─run_standalone_train.sh # launch standalone training(1p) - ├─run_distribute_train.sh # launch distributed training(8p) - └─run_eval.sh # launch evaluating - ├─src - ├─config.py # parameter configuration - ├─darknet.py # backbone of network - ├─distributed_sampler.py # iterator of dataset - ├─initializer.py # initializer of parameters - ├─logger.py # log function - ├─loss.py # loss function - ├─lr_scheduler.py # generate learning rate - ├─transforms.py # Preprocess data - ├─util.py # util function - ├─yolo.py # yolov3 network - ├─yolo_dataset.py # create dataset for YOLOV3 - ├─eval.py # eval net - └─train.py # train net -``` - -## Running the example - -### Train - -#### Usage - -``` -# distributed training -sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [MINDSPORE_HCCL_CONFIG_PATH] - -# standalone training -sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] -``` - -#### Launch - -```bash -# distributed training example(8p) -sh run_distribute_train.sh dataset/coco2014 backbone/backbone.ckpt rank_table_8p.json - -# standalone training example(1p) -sh run_standalone_train.sh dataset/coco2014 backbone/backbone.ckpt -``` - -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). - -#### Result - -Training result will be stored in the scripts path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.txt. - -``` -# distribute training result(8p) -epoch[0], iter[0], loss:14623.384766, 1.23 imgs/sec, lr:7.812499825377017e-05 -epoch[0], iter[100], loss:1486.253051, 15.01 imgs/sec, lr:0.007890624925494194 -epoch[0], iter[200], loss:288.579535, 490.41 imgs/sec, lr:0.015703124925494194 -epoch[0], iter[300], loss:153.136754, 531.99 imgs/sec, lr:0.023515624925494194 -epoch[1], iter[400], loss:106.429322, 405.14 imgs/sec, lr:0.03132812678813934 -... -epoch[318], iter[102000], loss:34.135306, 431.06 imgs/sec, lr:9.63797629083274e-06 -epoch[319], iter[102100], loss:35.652469, 449.52 imgs/sec, lr:2.409552052995423e-06 -epoch[319], iter[102200], loss:34.652273, 384.02 imgs/sec, lr:2.409552052995423e-06 -epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e-06 -... -``` - -### Infer - -#### Usage - -``` -# infer -sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] -``` - -#### Launch - -```bash -# infer with checkpoint -sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt - -``` - -> checkpoint can be produced in training process. - - -#### Result - -Inference result will be stored in the scripts path, whose folder name is "eval". Under this, you can find result like the followings in log.txt. - -``` -=============coco eval reulst========= - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.528 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.127 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.323 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.428 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.259 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.398 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.423 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.224 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.442 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551 -``` diff --git a/model_zoo/yolov3_darknet53/eval.py b/model_zoo/yolov3_darknet53/eval.py deleted file mode 100644 index 6680b10476..0000000000 --- a/model_zoo/yolov3_darknet53/eval.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""YoloV3 eval.""" -import os -import argparse -import datetime -import time -import sys -from collections import defaultdict - -import numpy as np -from pycocotools.coco import COCO -from pycocotools.cocoeval import COCOeval - -from mindspore import Tensor -from mindspore.train import ParallelMode -from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net -import mindspore as ms - -from src.yolo import YOLOV3DarkNet53 -from src.logger import get_logger -from src.yolo_dataset import create_yolo_dataset -from src.config import ConfigYOLOV3DarkNet53 - -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid) - - -class Redirct: - def __init__(self): - self.content = "" - - def write(self, content): - self.content += content - - def flush(self): - self.content = "" - - -class DetectionEngine: - """Detection engine.""" - def __init__(self, args): - self.ignore_threshold = args.ignore_threshold - self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', - 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', - 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', - 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', - 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', - 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', - 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] - self.num_classes = len(self.labels) - self.results = {} - self.file_path = '' - self.save_prefix = args.outputs_dir - self.annFile = args.annFile - self._coco = COCO(self.annFile) - self._img_ids = list(sorted(self._coco.imgs.keys())) - self.det_boxes = [] - self.nms_thresh = args.nms_thresh - self.coco_catIds = self._coco.getCatIds() - - def do_nms_for_results(self): - """Get result boxes.""" - for img_id in self.results: - for clsi in self.results[img_id]: - dets = self.results[img_id][clsi] - dets = np.array(dets) - keep_index = self._nms(dets, self.nms_thresh) - - keep_box = [{'image_id': int(img_id), - 'category_id': int(clsi), - 'bbox': list(dets[i][:4].astype(float)), - 'score': dets[i][4].astype(float)} - for i in keep_index] - self.det_boxes.extend(keep_box) - - def _nms(self, dets, thresh): - """Calculate NMS.""" - # conver xywh -> xmin ymin xmax ymax - x1 = dets[:, 0] - y1 = dets[:, 1] - x2 = x1 + dets[:, 2] - y2 = y1 + dets[:, 3] - scores = dets[:, 4] - - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - order = scores.argsort()[::-1] - - keep = [] - while order.size > 0: - i = order[0] - keep.append(i) - xx1 = np.maximum(x1[i], x1[order[1:]]) - yy1 = np.maximum(y1[i], y1[order[1:]]) - xx2 = np.minimum(x2[i], x2[order[1:]]) - yy2 = np.minimum(y2[i], y2[order[1:]]) - - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - ovr = inter / (areas[i] + areas[order[1:]] - inter) - - inds = np.where(ovr <= thresh)[0] - order = order[inds + 1] - return keep - - def write_result(self): - """Save result to file.""" - import json - t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S') - try: - self.file_path = self.save_prefix + '/predict' + t + '.json' - f = open(self.file_path, 'w') - json.dump(self.det_boxes, f) - except IOError as e: - raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e))) - else: - f.close() - return self.file_path - - def get_eval_result(self): - """Get eval result.""" - cocoGt = COCO(self.annFile) - cocoDt = cocoGt.loadRes(self.file_path) - cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') - cocoEval.evaluate() - cocoEval.accumulate() - rdct = Redirct() - stdout = sys.stdout - sys.stdout = rdct - cocoEval.summarize() - sys.stdout = stdout - return rdct.content - - def detect(self, outputs, batch, image_shape, image_id): - """Detect boxes.""" - outputs_num = len(outputs) - # output [|32, 52, 52, 3, 85| ] - for batch_id in range(batch): - for out_id in range(outputs_num): - # 32, 52, 52, 3, 85 - out_item = outputs[out_id] - # 52, 52, 3, 85 - out_item_single = out_item[batch_id, :] - # get number of items in one head, [B, gx, gy, anchors, 5+80] - dimensions = out_item_single.shape[:-1] - out_num = 1 - for d in dimensions: - out_num *= d - ori_w, ori_h = image_shape[batch_id] - img_id = int(image_id[batch_id]) - x = out_item_single[..., 0] * ori_w - y = out_item_single[..., 1] * ori_h - w = out_item_single[..., 2] * ori_w - h = out_item_single[..., 3] * ori_h - - conf = out_item_single[..., 4:5] - cls_emb = out_item_single[..., 5:] - - cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1) - x = x.reshape(-1) - y = y.reshape(-1) - w = w.reshape(-1) - h = h.reshape(-1) - cls_emb = cls_emb.reshape(-1, 80) - conf = conf.reshape(-1) - cls_argmax = cls_argmax.reshape(-1) - - x_top_left = x - w / 2. - y_top_left = y - h / 2. - # creat all False - flag = np.random.random(cls_emb.shape) > sys.maxsize - for i in range(flag.shape[0]): - c = cls_argmax[i] - flag[i, c] = True - confidence = cls_emb[flag] * conf - for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax): - if confi < self.ignore_threshold: - continue - if img_id not in self.results: - self.results[img_id] = defaultdict(list) - x_lefti = max(0, x_lefti) - y_lefti = max(0, y_lefti) - wi = min(wi, ori_w) - hi = min(hi, ori_h) - # transform catId to match coco - coco_clsi = self.coco_catIds[clsi] - self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi]) - - -def parse_args(): - """Parse arguments.""" - parser = argparse.ArgumentParser('mindspore coco testing') - - # dataset related - parser.add_argument('--data_dir', type=str, default='', help='train data dir') - parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu') - - # network related - parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load') - - # logging related - parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location') - - # detect_related - parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS') - parser.add_argument('--annFile', type=str, default='', help='path to annotation') - parser.add_argument('--testing_shape', type=str, default='', help='shape for test ') - parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes') - - args, _ = parser.parse_known_args() - - args.data_root = os.path.join(args.data_dir, 'val2014') - args.annFile = os.path.join(args.data_dir, 'annotations/instances_val2014.json') - - return args - - -def conver_testing_shape(args): - """Convert testing shape to list.""" - testing_shape = [int(args.testing_shape), int(args.testing_shape)] - return testing_shape - - -def test(): - """The function of eval.""" - start_time = time.time() - args = parse_args() - - # logger - args.outputs_dir = os.path.join(args.log_path, - datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) - rank_id = int(os.environ.get('RANK_ID')) - args.logger = get_logger(args.outputs_dir, rank_id) - - context.reset_auto_parallel_context() - parallel_mode = ParallelMode.STAND_ALONE - context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1) - - args.logger.info('Creating Network....') - network = YOLOV3DarkNet53(is_training=False) - - args.logger.info(args.pretrained) - if os.path.isfile(args.pretrained): - param_dict = load_checkpoint(args.pretrained) - param_dict_new = {} - for key, values in param_dict.items(): - if key.startswith('moments.'): - continue - elif key.startswith('yolo_network.'): - param_dict_new[key[13:]] = values - else: - param_dict_new[key] = values - load_param_into_net(network, param_dict_new) - args.logger.info('load_model {} success'.format(args.pretrained)) - else: - args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained)) - assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained)) - exit(1) - - data_root = args.data_root - ann_file = args.annFile - - config = ConfigYOLOV3DarkNet53() - if args.testing_shape: - config.test_img_shape = conver_testing_shape(args) - - ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size, - max_epoch=1, device_num=1, rank=rank_id, shuffle=False, - config=config) - - args.logger.info('testing shape : {}'.format(config.test_img_shape)) - args.logger.info('totol {} images to eval'.format(data_size)) - - network.set_train(False) - - # init detection engine - detection = DetectionEngine(args) - - input_shape = Tensor(tuple(config.test_img_shape), ms.float32) - args.logger.info('Start inference....') - for i, data in enumerate(ds.create_dict_iterator()): - image = Tensor(data["image"]) - - image_shape = Tensor(data["image_shape"]) - image_id = Tensor(data["img_id"]) - - prediction = network(image, input_shape) - output_big, output_me, output_small = prediction - output_big = output_big.asnumpy() - output_me = output_me.asnumpy() - output_small = output_small.asnumpy() - image_id = image_id.asnumpy() - image_shape = image_shape.asnumpy() - - detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id) - if i % 1000 == 0: - args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100)) - - args.logger.info('Calculating mAP...') - detection.do_nms_for_results() - result_file_path = detection.write_result() - args.logger.info('result file path: {}'.format(result_file_path)) - eval_result = detection.get_eval_result() - - cost_time = time.time() - start_time - args.logger.info('\n=============coco eval reulst=========\n' + eval_result) - args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) - - -if __name__ == "__main__": - test() diff --git a/model_zoo/yolov3_darknet53/scripts/run_distribute_train.sh b/model_zoo/yolov3_darknet53/scripts/run_distribute_train.sh deleted file mode 100644 index c6e83ae8f8..0000000000 --- a/model_zoo/yolov3_darknet53/scripts/run_distribute_train.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [ $# != 3 ] -then - echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [MINDSPORE_HCCL_CONFIG_PATH]" -exit 1 -fi - -get_real_path(){ - if [ "${1:0:1}" == "/" ]; then - echo "$1" - else - echo "$(realpath -m $PWD/$1)" - fi -} - -DATASET_PATH=$(get_real_path $1) -PRETRAINED_BACKBONE=$(get_real_path $2) -MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3) -echo $DATASET_PATH -echo $PRETRAINED_BACKBONE -echo $MINDSPORE_HCCL_CONFIG_PATH - -if [ ! -d $DATASET_PATH ] -then - echo "error: DATASET_PATH=$DATASET_PATH is not a directory" -exit 1 -fi - -if [ ! -f $PRETRAINED_BACKBONE ] -then - echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file" -exit 1 -fi - -if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ] -then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file" -exit 1 -fi - -export DEVICE_NUM=8 -export RANK_SIZE=8 -export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH - -for((i=0; i<${DEVICE_NUM}; i++)) -do - export DEVICE_ID=$i - export RANK_ID=$i - rm -rf ./train_parallel$i - mkdir ./train_parallel$i - cp ../*.py ./train_parallel$i - cp -r ../src ./train_parallel$i - cd ./train_parallel$i || exit - echo "start training for rank $RANK_ID, device $DEVICE_ID" - env > env.log - python train.py \ - --data_dir=$DATASET_PATH \ - --pretrained_backbone=$PRETRAINED_BACKBONE \ - --is_distributed=1 \ - --lr=0.1 \ - --T_max=320 \ - --max_epoch=320 \ - --warmup_epochs=4 \ - --lr_scheduler=cosine_annealing > log.txt 2>&1 & - cd .. -done diff --git a/model_zoo/yolov3_darknet53/src/initializer.py b/model_zoo/yolov3_darknet53/src/initializer.py deleted file mode 100644 index f3c03a8ad1..0000000000 --- a/model_zoo/yolov3_darknet53/src/initializer.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Parameter init.""" -import math -import numpy as np -from mindspore.common import initializer as init -from mindspore.common.initializer import Initializer as MeInitializer -import mindspore.nn as nn -from mindspore import Tensor - - -np.random.seed(5) - - -def calculate_gain(nonlinearity, param=None): - r"""Return the recommended gain value for the given nonlinearity function. - The values are as follows: - - ================= ==================================================== - nonlinearity gain - ================= ==================================================== - Linear / Identity :math:`1` - Conv{1,2,3}D :math:`1` - Sigmoid :math:`1` - Tanh :math:`\frac{5}{3}` - ReLU :math:`\sqrt{2}` - Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` - ================= ==================================================== - - Args: - nonlinearity: the non-linear function (`nn.functional` name) - param: optional parameter for the non-linear function - - Examples: - >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 - """ - linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] - if nonlinearity in linear_fns or nonlinearity == 'sigmoid': - return 1 - if nonlinearity == 'tanh': - return 5.0 / 3 - if nonlinearity == 'relu': - return math.sqrt(2.0) - if nonlinearity == 'leaky_relu': - if param is None: - negative_slope = 0.01 - elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): - # True/False are instances of int, hence check above - negative_slope = param - else: - raise ValueError("negative_slope {} not a valid number".format(param)) - return math.sqrt(2.0 / (1 + negative_slope ** 2)) - - raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) - - -def _assignment(arr, num): - """Assign the value of 'num' and 'arr'.""" - if arr.shape == (): - arr = arr.reshape((1)) - arr[:] = num - arr = arr.reshape(()) - else: - if isinstance(num, np.ndarray): - arr[:] = num[:] - else: - arr[:] = num - return arr - - -def _calculate_correct_fan(array, mode): - mode = mode.lower() - valid_modes = ['fan_in', 'fan_out'] - if mode not in valid_modes: - raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) - - fan_in, fan_out = _calculate_fan_in_and_fan_out(array) - return fan_in if mode == 'fan_in' else fan_out - - -def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): - r"""Fills the input `Tensor` with values according to the method - described in `Delving deep into rectifiers: Surpassing human-level - performance on ImageNet classification` - He, K. et al. (2015), using a - uniform distribution. The resulting tensor will have values sampled from - :math:`\mathcal{U}(-\text{bound}, \text{bound})` where - - .. math:: - \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} - - Also known as He initialization. - - Args: - tensor: an n-dimensional `Tensor` - a: the negative slope of the rectifier used after this layer (only - used with ``'leaky_relu'``) - mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` - preserves the magnitude of the variance of the weights in the - forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the - backwards pass. - nonlinearity: the non-linear function (`nn.functional` name), - recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). - - Examples: - >>> w = np.empty(3, 5) - >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') - """ - fan = _calculate_correct_fan(arr, mode) - gain = calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - return np.random.uniform(-bound, bound, arr.shape) - - -def _calculate_fan_in_and_fan_out(arr): - """Calculate fan in and fan out.""" - dimensions = len(arr.shape) - if dimensions < 2: - raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions") - - num_input_fmaps = arr.shape[1] - num_output_fmaps = arr.shape[0] - receptive_field_size = 1 - if dimensions > 2: - receptive_field_size = arr[0][0].size - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - - return fan_in, fan_out - - -class KaimingUniform(MeInitializer): - """Kaiming uniform initializer.""" - def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): - super(KaimingUniform, self).__init__() - self.a = a - self.mode = mode - self.nonlinearity = nonlinearity - - def _initialize(self, arr): - tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity) - _assignment(arr, tmp) - - -def default_recurisive_init(custom_cell): - """Initialize parameter.""" - for _, cell in custom_cell.cells_and_names(): - if isinstance(cell, nn.Conv2d): - cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if cell.bias is not None: - fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy()) - bound = 1 / math.sqrt(fan_in) - cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), - cell.bias.default_input.dtype) - elif isinstance(cell, nn.Dense): - cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), - cell.weight.default_input.shape, - cell.weight.default_input.dtype).to_tensor() - if cell.bias is not None: - fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy()) - bound = 1 / math.sqrt(fan_in) - cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape), - cell.bias.default_input.dtype) - elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): - pass diff --git a/model_zoo/yolov3_darknet53/src/transforms.py b/model_zoo/yolov3_darknet53/src/transforms.py deleted file mode 100644 index 837d1a25ea..0000000000 --- a/model_zoo/yolov3_darknet53/src/transforms.py +++ /dev/null @@ -1,577 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Preprocess dataset.""" -import random -import threading -import copy - -import numpy as np -from PIL import Image -import cv2 - - -def _rand(a=0., b=1.): - return np.random.rand() * (b - a) + a - - -def bbox_iou(bbox_a, bbox_b, offset=0): - """Calculate Intersection-Over-Union(IOU) of two bounding boxes. - - Parameters - ---------- - bbox_a : numpy.ndarray - An ndarray with shape :math:`(N, 4)`. - bbox_b : numpy.ndarray - An ndarray with shape :math:`(M, 4)`. - offset : float or int, default is 0 - The ``offset`` is used to control the whether the width(or height) is computed as - (right - left + ``offset``). - Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. - - Returns - ------- - numpy.ndarray - An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of - bounding boxes in `bbox_a` and `bbox_b`. - - """ - if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: - raise IndexError("Bounding boxes axis 1 must have at least length 4") - - tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) - br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) - - area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) - area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) - area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) - return area_i / (area_a[:, None] + area_b - area_i) - - -def statistic_normalize_img(img, statistic_norm): - """Statistic normalize images.""" - # img: RGB - if isinstance(img, Image.Image): - img = np.array(img) - img = img/255. - mean = np.array([0.485, 0.456, 0.406]) - std = np.array([0.229, 0.224, 0.225]) - if statistic_norm: - img = (img - mean) / std - return img - - -def get_interp_method(interp, sizes=()): - """Get the interpolation method for resize functions. - The major purpose of this function is to wrap a random interp method selection - and a auto-estimation method. - - Parameters - ---------- - interp : int - interpolation method for all resizing operations - - Possible values: - 0: Nearest Neighbors Interpolation. - 1: Bilinear interpolation. - 2: Bicubic interpolation over 4x4 pixel neighborhood. - 3: Nearest Neighbors. [Originally it should be Area-based, - as we cannot find Area-based, so we use NN instead. - Area-based (resampling using pixel area relation). It may be a - preferred method for image decimation, as it gives moire-free - results. But when the image is zoomed, it is similar to the Nearest - Neighbors method. (used by default). - 4: Lanczos interpolation over 8x8 pixel neighborhood. - 9: Cubic for enlarge, area for shrink, bilinear for others - 10: Random select from interpolation method metioned above. - Note: - When shrinking an image, it will generally look best with AREA-based - interpolation, whereas, when enlarging an image, it will generally look best - with Bicubic (slow) or Bilinear (faster but still looks OK). - More details can be found in the documentation of OpenCV, please refer to - http://docs.opencv.org/master/da/d54/group__imgproc__transform.html. - sizes : tuple of int - (old_height, old_width, new_height, new_width), if None provided, auto(9) - will return Area(2) anyway. - - Returns - ------- - int - interp method from 0 to 4 - """ - if interp == 9: - if sizes: - assert len(sizes) == 4 - oh, ow, nh, nw = sizes - if nh > oh and nw > ow: - return 2 - if nh < oh and nw < ow: - return 0 - return 1 - return 2 - if interp == 10: - return random.randint(0, 4) - if interp not in (0, 1, 2, 3, 4): - raise ValueError('Unknown interp method %d' % interp) - return interp - - -def pil_image_reshape(interp): - """Reshape pil image.""" - reshape_type = { - 0: Image.NEAREST, - 1: Image.BILINEAR, - 2: Image.BICUBIC, - 3: Image.NEAREST, - 4: Image.LANCZOS, - } - return reshape_type[interp] - - -def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, - max_boxes, label_smooth, label_smooth_factor=0.1): - """Preprocess annotation boxes.""" - anchors = np.array(anchors) - num_layers = anchors.shape[0] // 3 - anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - true_boxes = np.array(true_boxes, dtype='float32') - input_shape = np.array(in_shape, dtype='int32') - boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. - # trans to box center point - boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] - # input_shape is [h, w] - true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] - true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] - # true_boxes = [xywh] - - grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] - # grid_shape [h, w] - y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), - 5 + num_classes), dtype='float32') for l in range(num_layers)] - # y_true [gridy, gridx] - anchors = np.expand_dims(anchors, 0) - anchors_max = anchors / 2. - anchors_min = -anchors_max - valid_mask = boxes_wh[..., 0] > 0 - - wh = boxes_wh[valid_mask] - if wh.size > 0: - wh = np.expand_dims(wh, -2) - boxes_max = wh / 2. - boxes_min = -boxes_max - - intersect_min = np.maximum(boxes_min, anchors_min) - intersect_max = np.minimum(boxes_max, anchors_max) - intersect_wh = np.maximum(intersect_max - intersect_min, 0.) - intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] - box_area = wh[..., 0] * wh[..., 1] - anchor_area = anchors[..., 0] * anchors[..., 1] - iou = intersect_area / (box_area + anchor_area - intersect_area) - - best_anchor = np.argmax(iou, axis=-1) - for t, n in enumerate(best_anchor): - for l in range(num_layers): - if n in anchor_mask[l]: - i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y - j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x - - k = anchor_mask[l].index(n) - c = true_boxes[t, 4].astype('int32') - y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] - y_true[l][j, i, k, 4] = 1. - - # lable-smooth - if label_smooth: - sigma = label_smooth_factor/(num_classes-1) - y_true[l][j, i, k, 5:] = sigma - y_true[l][j, i, k, 5+c] = 1-label_smooth_factor - else: - y_true[l][j, i, k, 5 + c] = 1. - - # pad_gt_boxes for avoiding dynamic shape - pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) - pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) - pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) - - mask0 = np.reshape(y_true[0][..., 4:5], [-1]) - gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) - # gt_box [boxes, [x,y,w,h]] - gt_box0 = gt_box0[mask0 == 1] - # gt_box0: get all boxes which have object - pad_gt_box0[:gt_box0.shape[0]] = gt_box0 - # gt_box0.shape[0]: total number of boxes in gt_box0 - # top N of pad_gt_box0 is real box, and after are pad by zero - - mask1 = np.reshape(y_true[1][..., 4:5], [-1]) - gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) - gt_box1 = gt_box1[mask1 == 1] - pad_gt_box1[:gt_box1.shape[0]] = gt_box1 - - mask2 = np.reshape(y_true[2][..., 4:5], [-1]) - gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) - - gt_box2 = gt_box2[mask2 == 1] - pad_gt_box2[:gt_box2.shape[0]] = gt_box2 - return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 - - -def _reshape_data(image, image_size): - """Reshape image.""" - if not isinstance(image, Image.Image): - image = Image.fromarray(image) - ori_w, ori_h = image.size - ori_image_shape = np.array([ori_w, ori_h], np.int32) - # original image shape fir:H sec:W - h, w = image_size - interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) - image = image.resize((w, h), pil_image_reshape(interp)) - image_data = statistic_normalize_img(image, statistic_norm=True) - if len(image_data.shape) == 2: - image_data = np.expand_dims(image_data, axis=-1) - image_data = np.concatenate([image_data, image_data, image_data], axis=-1) - image_data = image_data.astype(np.float32) - return image_data, ori_image_shape - - -def color_distortion(img, hue, sat, val, device_num): - """Color distortion.""" - hue = _rand(-hue, hue) - sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) - val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) - if device_num != 1: - cv2.setNumThreads(1) - x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) - x = x / 255. - x[..., 0] += hue - x[..., 0][x[..., 0] > 1] -= 1 - x[..., 0][x[..., 0] < 0] += 1 - x[..., 1] *= sat - x[..., 2] *= val - x[x > 1] = 1 - x[x < 0] = 0 - x = x * 255. - x = x.astype(np.uint8) - image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) - return image_data - - -def filp_pil_image(img): - return img.transpose(Image.FLIP_LEFT_RIGHT) - - -def convert_gray_to_color(img): - if len(img.shape) == 2: - img = np.expand_dims(img, axis=-1) - img = np.concatenate([img, img, img], axis=-1) - return img - - -def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): - iou = bbox_iou(box, crop_box) - return min_iou <= iou.min() and max_iou >= iou.max() - - -def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): - """Choose candidate by constraints.""" - if use_constraints: - constraints = ( - (0.1, None), - (0.3, None), - (0.5, None), - (0.7, None), - (0.9, None), - (None, 1), - ) - else: - constraints = ( - (None, None), - ) - # add default candidate - candidates = [(0, 0, input_w, input_h)] - for constraint in constraints: - min_iou, max_iou = constraint - min_iou = -np.inf if min_iou is None else min_iou - max_iou = np.inf if max_iou is None else max_iou - - for _ in range(max_trial): - # box_data should have at least one box - new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) - scale = _rand(0.25, 2) - - if new_ar < 1: - nh = int(scale * input_h) - nw = int(nh * new_ar) - else: - nw = int(scale * input_w) - nh = int(nw / new_ar) - - dx = int(_rand(0, input_w - nw)) - dy = int(_rand(0, input_h - nh)) - - if box.size > 0: - t_box = copy.deepcopy(box) - t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx - t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy - - crop_box = np.array((0, 0, input_w, input_h)) - if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): - continue - else: - candidates.append((dx, dy, nw, nh)) - else: - raise Exception("!!! annotation box is less than 1") - return candidates - - -def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, - image_h, flip, box, box_data, allow_outside_center): - """Calculate correct boxes.""" - while candidates: - if len(candidates) > 1: - # ignore default candidate which do not crop - candidate = candidates.pop(np.random.randint(1, len(candidates))) - else: - candidate = candidates.pop(np.random.randint(0, len(candidates))) - dx, dy, nw, nh = candidate - t_box = copy.deepcopy(box) - t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx - t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy - if flip: - t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] - - if allow_outside_center: - pass - else: - t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)] - t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, - (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)] - - # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero - t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 - # recorrect w,h not higher than input size - t_box[:, 2][t_box[:, 2] > input_w] = input_w - t_box[:, 3][t_box[:, 3] > input_h] = input_h - box_w = t_box[:, 2] - t_box[:, 0] - box_h = t_box[:, 3] - t_box[:, 1] - # discard invalid box: w or h smaller than 1 pixel - t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] - - if t_box.shape[0] > 0: - # break if number of find t_box - box_data[: len(t_box)] = t_box - return box_data, candidate - raise Exception('all candidates can not satisfied re-correct bbox') - - -def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, - anchors, num_classes, max_trial=10, device_num=1): - """Crop an image randomly with bounding box constraints. - - This data augmentation is used in training of - Single Shot Multibox Detector [#]_. More details can be found in - data augmentation section of the original paper. - .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, - Scott Reed, Cheng-Yang Fu, Alexander C. Berg. - SSD: Single Shot MultiBox Detector. ECCV 2016.""" - - if not isinstance(image, Image.Image): - image = Image.fromarray(image) - - image_w, image_h = image.size - input_h, input_w = image_input_size - - np.random.shuffle(box) - if len(box) > max_boxes: - box = box[:max_boxes] - flip = _rand() < .5 - box_data = np.zeros((max_boxes, 5)) - - candidates = _choose_candidate_by_constraints(use_constraints=False, - max_trial=max_trial, - input_w=input_w, - input_h=input_h, - image_w=image_w, - image_h=image_h, - jitter=jitter, - box=box) - box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, - input_w=input_w, - input_h=input_h, - image_w=image_w, - image_h=image_h, - flip=flip, - box=box, - box_data=box_data, - allow_outside_center=True) - dx, dy, nw, nh = candidate - interp = get_interp_method(interp=10) - image = image.resize((nw, nh), pil_image_reshape(interp)) - # place image, gray color as back graoud - new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) - new_image.paste(image, (dx, dy)) - image = new_image - - if flip: - image = filp_pil_image(image) - - image = np.array(image) - - image = convert_gray_to_color(image) - - image_data = color_distortion(image, hue, sat, val, device_num) - image_data = statistic_normalize_img(image_data, statistic_norm=True) - - image_data = image_data.astype(np.float32) - - return image_data, box_data - - -def preprocess_fn(image, box, config, input_size, device_num): - """Preprocess data function.""" - config_anchors = config.anchor_scales - anchors = np.array([list(x) for x in config_anchors]) - max_boxes = config.max_box - num_classes = config.num_classes - jitter = config.jitter - hue = config.hue - sat = config.saturation - val = config.value - image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, - image_input_size=input_size, max_boxes=max_boxes, - num_classes=num_classes, anchors=anchors, device_num=device_num) - return image, anno - - -def reshape_fn(image, img_id, config): - input_size = config.test_img_shape - image, ori_image_shape = _reshape_data(image, image_size=input_size) - return image, ori_image_shape, img_id - - -class MultiScaleTrans: - """Multi scale transform.""" - def __init__(self, config, device_num): - self.config = config - self.seed = 0 - self.size_list = [] - self.resize_rate = config.resize_rate - self.dataset_size = config.dataset_size - self.size_dict = {} - self.seed_num = int(1e6) - self.seed_list = self.generate_seed_list(seed_num=self.seed_num) - self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) - self.device_num = device_num - - def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): - seed_list = [] - random.seed(init_seed) - for _ in range(seed_num): - seed = random.randint(seed_range[0], seed_range[1]) - seed_list.append(seed) - return seed_list - - def __call__(self, imgs, annos, batchInfo): - epoch_num = batchInfo.get_epoch_num() - size_idx = int(batchInfo.get_batch_num() / self.resize_rate) - seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] - ret_imgs = [] - ret_annos = [] - - if self.size_dict.get(seed_key, None) is None: - random.seed(seed_key) - new_size = random.choice(self.config.multi_scale) - self.size_dict[seed_key] = new_size - seed = seed_key - - input_size = self.size_dict[seed] - for img, anno in zip(imgs, annos): - img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) - ret_imgs.append(img.transpose(2, 0, 1).copy()) - ret_annos.append(anno) - return np.array(ret_imgs), np.array(ret_annos) - - -def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, - batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): - """Preprocess true box for multi-thread.""" - i = 0 - for anno in annos: - bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ - _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, - num_classes=config.num_classes, max_boxes=config.max_box, - label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) - batch_bbox_true_1[result_index + i] = bbox_true_1 - batch_bbox_true_2[result_index + i] = bbox_true_2 - batch_bbox_true_3[result_index + i] = bbox_true_3 - batch_gt_box1[result_index + i] = gt_box1 - batch_gt_box2[result_index + i] = gt_box2 - batch_gt_box3[result_index + i] = gt_box3 - i = i + 1 - - -def batch_preprocess_true_box(annos, config, input_shape): - """Preprocess true box with multi-thread.""" - batch_bbox_true_1 = [] - batch_bbox_true_2 = [] - batch_bbox_true_3 = [] - batch_gt_box1 = [] - batch_gt_box2 = [] - batch_gt_box3 = [] - threads = [] - - step = 4 - for index in range(0, len(annos), step): - for _ in range(step): - batch_bbox_true_1.append(None) - batch_bbox_true_2.append(None) - batch_bbox_true_3.append(None) - batch_gt_box1.append(None) - batch_gt_box2.append(None) - batch_gt_box3.append(None) - step_anno = annos[index: index + step] - t = threading.Thread(target=thread_batch_preprocess_true_box, - args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, - batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) - t.start() - threads.append(t) - - for t in threads: - t.join() - - return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ - np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) - - -def batch_preprocess_true_box_single(annos, config, input_shape): - """Preprocess true boxes.""" - batch_bbox_true_1 = [] - batch_bbox_true_2 = [] - batch_bbox_true_3 = [] - batch_gt_box1 = [] - batch_gt_box2 = [] - batch_gt_box3 = [] - for anno in annos: - bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ - _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, - num_classes=config.num_classes, max_boxes=config.max_box, - label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) - batch_bbox_true_1.append(bbox_true_1) - batch_bbox_true_2.append(bbox_true_2) - batch_bbox_true_3.append(bbox_true_3) - batch_gt_box1.append(gt_box1) - batch_gt_box2.append(gt_box2) - batch_gt_box3.append(gt_box3) - - return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ - np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) diff --git a/model_zoo/yolov3_darknet53/src/util.py b/model_zoo/yolov3_darknet53/src/util.py deleted file mode 100644 index 1a3da99181..0000000000 --- a/model_zoo/yolov3_darknet53/src/util.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Util class or function.""" -from mindspore.train.serialization import load_checkpoint -import mindspore.nn as nn - - -class AverageMeter: - """Computes and stores the average and current value""" - - def __init__(self, name, fmt=':f', tb_writer=None): - self.name = name - self.fmt = fmt - self.reset() - self.tb_writer = tb_writer - self.cur_step = 1 - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - if self.tb_writer is not None: - self.tb_writer.add_scalar(self.name, self.val, self.cur_step) - self.cur_step += 1 - - def __str__(self): - fmtstr = '{name}:{avg' + self.fmt + '}' - return fmtstr.format(**self.__dict__) - - -def load_backbone(net, ckpt_path, args): - """Load darknet53 backbone checkpoint.""" - param_dict = load_checkpoint(ckpt_path) - yolo_backbone_prefix = 'feature_map.backbone' - darknet_backbone_prefix = 'network.backbone' - find_param = [] - not_found_param = [] - - for name, cell in net.cells_and_names(): - if name.startswith(yolo_backbone_prefix): - name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix) - if isinstance(cell, (nn.Conv2d, nn.Dense)): - darknet_weight = '{}.weight'.format(name) - darknet_bias = '{}.bias'.format(name) - if darknet_weight in param_dict: - cell.weight.default_input = param_dict[darknet_weight].data - find_param.append(darknet_weight) - else: - not_found_param.append(darknet_weight) - if darknet_bias in param_dict: - cell.bias.default_input = param_dict[darknet_bias].data - find_param.append(darknet_bias) - else: - not_found_param.append(darknet_bias) - elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): - darknet_moving_mean = '{}.moving_mean'.format(name) - darknet_moving_variance = '{}.moving_variance'.format(name) - darknet_gamma = '{}.gamma'.format(name) - darknet_beta = '{}.beta'.format(name) - if darknet_moving_mean in param_dict: - cell.moving_mean.default_input = param_dict[darknet_moving_mean].data - find_param.append(darknet_moving_mean) - else: - not_found_param.append(darknet_moving_mean) - if darknet_moving_variance in param_dict: - cell.moving_variance.default_input = param_dict[darknet_moving_variance].data - find_param.append(darknet_moving_variance) - else: - not_found_param.append(darknet_moving_variance) - if darknet_gamma in param_dict: - cell.gamma.default_input = param_dict[darknet_gamma].data - find_param.append(darknet_gamma) - else: - not_found_param.append(darknet_gamma) - if darknet_beta in param_dict: - cell.beta.default_input = param_dict[darknet_beta].data - find_param.append(darknet_beta) - else: - not_found_param.append(darknet_beta) - - args.logger.info('================found_param {}========='.format(len(find_param))) - args.logger.info(find_param) - args.logger.info('================not_found_param {}========='.format(len(not_found_param))) - args.logger.info(not_found_param) - args.logger.info('=====load {} successfully ====='.format(ckpt_path)) - - return net - - -def default_wd_filter(x): - """default weight decay filter.""" - parameter_name = x.name - if parameter_name.endswith('.bias'): - # all bias not using weight decay - return False - if parameter_name.endswith('.gamma'): - # bn weight bias not using weight decay, be carefully for now x not include BN - return False - if parameter_name.endswith('.beta'): - # bn weight bias not using weight decay, be carefully for now x not include BN - return False - - return True - - -def get_param_groups(network): - """Param groups for optimizer.""" - decay_params = [] - no_decay_params = [] - for x in network.trainable_params(): - parameter_name = x.name - if parameter_name.endswith('.bias'): - # all bias not using weight decay - no_decay_params.append(x) - elif parameter_name.endswith('.gamma'): - # bn weight bias not using weight decay, be carefully for now x not include BN - no_decay_params.append(x) - elif parameter_name.endswith('.beta'): - # bn weight bias not using weight decay, be carefully for now x not include BN - no_decay_params.append(x) - else: - decay_params.append(x) - - return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] - - -class ShapeRecord: - """Log image shape.""" - def __init__(self): - self.shape_record = { - 320: 0, - 352: 0, - 384: 0, - 416: 0, - 448: 0, - 480: 0, - 512: 0, - 544: 0, - 576: 0, - 608: 0, - 'total': 0 - } - - def set(self, shape): - if len(shape) > 1: - shape = shape[0] - shape = int(shape) - self.shape_record[shape] += 1 - self.shape_record['total'] += 1 - - def show(self, logger): - for key in self.shape_record: - rate = self.shape_record[key] / float(self.shape_record['total']) - logger.info('shape {}: {:.2f}%'.format(key, rate*100)) diff --git a/model_zoo/yolov3_resnet18/scripts/run_distribute_train.sh b/model_zoo/yolov3_resnet18/scripts/run_distribute_train.sh deleted file mode 100644 index eeda5077e9..0000000000 --- a/model_zoo/yolov3_resnet18/scripts/run_distribute_train.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -echo "=======================================================================================================================================================" -echo "Please run the scipt as: " -echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" -echo "For example: sh run_distribute_train.sh 8 150 /data/Mindrecord_train /data /data/train.txt /data/hccl.json /opt/yolov3-150.ckpt(optional) 100(optional)" -echo "It is better to use absolute path." -echo "The learning rate is 0.005 as default, if you want other lr, please change the value in this script." -echo "=======================================================================================================================================================" - -if [ $# != 6 ] && [ $# != 8 ] -then - echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [MINDRECORD_DIR] [IMAGE_DIR] [ANNO_PATH] [MINDSPORE_HCCL_CONFIG_PATH] \ -[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" - exit 1 -fi - -EPOCH_SIZE=$2 -MINDRECORD_DIR=$3 -IMAGE_DIR=$4 -ANNO_PATH=$5 -PRE_TRAINED=$7 -PRE_TRAINED_EPOCH_SIZE=$8 - -# Before start distribute train, first create mindrecord files. -python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \ ---anno_path=$ANNO_PATH - -echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" - -export MINDSPORE_HCCL_CONFIG_PATH=$6 -export RANK_SIZE=$1 - -BASE_PATH=$(cd "`dirname $0`" || exit; pwd) -cd $BASE_PATH/../ || exit - -for((i=0;i env.log - - if [ $# == 6 ] - then - taskset -c $cmdopt python train.py \ - --distribute=1 \ - --lr=0.005 \ - --device_num=$RANK_SIZE \ - --device_id=$DEVICE_ID \ - --mindrecord_dir=$MINDRECORD_DIR \ - --image_dir=$IMAGE_DIR \ - --epoch_size=$EPOCH_SIZE \ - --anno_path=$ANNO_PATH > log.txt 2>&1 & - fi - - if [ $# == 8 ] - then - taskset -c $cmdopt python train.py \ - --distribute=1 \ - --lr=0.005 \ - --device_num=$RANK_SIZE \ - --device_id=$DEVICE_ID \ - --mindrecord_dir=$MINDRECORD_DIR \ - --image_dir=$IMAGE_DIR \ - --epoch_size=$EPOCH_SIZE \ - --pre_trained=$PRE_TRAINED \ - --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \ - --anno_path=$ANNO_PATH > log.txt 2>&1 & - fi - - cd ../ -done diff --git a/model_zoo/yolov3_resnet18/src/dataset.py b/model_zoo/yolov3_resnet18/src/dataset.py deleted file mode 100644 index f85b442209..0000000000 --- a/model_zoo/yolov3_resnet18/src/dataset.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""YOLOv3 dataset""" -from __future__ import division - -import os -import numpy as np -from matplotlib.colors import rgb_to_hsv, hsv_to_rgb -from PIL import Image -import mindspore.dataset as de -from mindspore.mindrecord import FileWriter -import mindspore.dataset.transforms.vision.c_transforms as C -from src.config import ConfigYOLOV3ResNet18 - -iter_cnt = 0 -_NUM_BOXES = 50 - -def preprocess_fn(image, box, is_training): - """Preprocess function for dataset.""" - config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326] - anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2) - do_hsv = False - max_boxes = 20 - num_classes = ConfigYOLOV3ResNet18.num_classes - - def _rand(a=0., b=1.): - return np.random.rand() * (b - a) + a - - def _preprocess_true_boxes(true_boxes, anchors, in_shape=None): - """Get true boxes.""" - num_layers = anchors.shape[0] // 3 - anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - true_boxes = np.array(true_boxes, dtype='float32') - # input_shape = np.array([in_shape, in_shape], dtype='int32') - input_shape = np.array(in_shape, dtype='int32') - boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. - boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] - true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] - true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] - - grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] - y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), - 5 + num_classes), dtype='float32') for l in range(num_layers)] - - anchors = np.expand_dims(anchors, 0) - anchors_max = anchors / 2. - anchors_min = -anchors_max - - valid_mask = boxes_wh[..., 0] >= 1 - - wh = boxes_wh[valid_mask] - - - if len(wh) >= 1: - wh = np.expand_dims(wh, -2) - boxes_max = wh / 2. - boxes_min = -boxes_max - - intersect_min = np.maximum(boxes_min, anchors_min) - intersect_max = np.minimum(boxes_max, anchors_max) - intersect_wh = np.maximum(intersect_max - intersect_min, 0.) - intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] - box_area = wh[..., 0] * wh[..., 1] - anchor_area = anchors[..., 0] * anchors[..., 1] - iou = intersect_area / (box_area + anchor_area - intersect_area) - - best_anchor = np.argmax(iou, axis=-1) - for t, n in enumerate(best_anchor): - for l in range(num_layers): - if n in anchor_mask[l]: - i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') - j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') - k = anchor_mask[l].index(n) - - c = true_boxes[t, 4].astype('int32') - y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] - y_true[l][j, i, k, 4] = 1. - y_true[l][j, i, k, 5 + c] = 1. - - pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32) - pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32) - pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32) - - mask0 = np.reshape(y_true[0][..., 4:5], [-1]) - gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) - gt_box0 = gt_box0[mask0 == 1] - pad_gt_box0[:gt_box0.shape[0]] = gt_box0 - - mask1 = np.reshape(y_true[1][..., 4:5], [-1]) - gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) - gt_box1 = gt_box1[mask1 == 1] - pad_gt_box1[:gt_box1.shape[0]] = gt_box1 - - mask2 = np.reshape(y_true[2][..., 4:5], [-1]) - gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) - gt_box2 = gt_box2[mask2 == 1] - pad_gt_box2[:gt_box2.shape[0]] = gt_box2 - - return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 - - def _infer_data(img_data, input_shape, box): - w, h = img_data.size - input_h, input_w = input_shape - scale = min(float(input_w) / float(w), float(input_h) / float(h)) - nw = int(w * scale) - nh = int(h * scale) - img_data = img_data.resize((nw, nh), Image.BICUBIC) - - new_image = np.zeros((input_h, input_w, 3), np.float32) - new_image.fill(128) - img_data = np.array(img_data) - if len(img_data.shape) == 2: - img_data = np.expand_dims(img_data, axis=-1) - img_data = np.concatenate([img_data, img_data, img_data], axis=-1) - - dh = int((input_h - nh) / 2) - dw = int((input_w - nw) / 2) - new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data - new_image /= 255. - new_image = np.transpose(new_image, (2, 0, 1)) - new_image = np.expand_dims(new_image, 0) - return new_image, np.array([h, w], np.float32), box - - def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)): - """Data augmentation function.""" - if not isinstance(image, Image.Image): - image = Image.fromarray(image) - - iw, ih = image.size - ori_image_shape = np.array([ih, iw], np.int32) - h, w = image_size - - if not is_training: - return _infer_data(image, image_size, box) - - flip = _rand() < .5 - # correct boxes - box_data = np.zeros((max_boxes, 5)) - while True: - # Prevent the situation that all boxes are eliminated - new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \ - _rand(1 - jitter, 1 + jitter) - scale = _rand(0.25, 2) - - if new_ar < 1: - nh = int(scale * h) - nw = int(nh * new_ar) - else: - nw = int(scale * w) - nh = int(nw / new_ar) - - dx = int(_rand(0, w - nw)) - dy = int(_rand(0, h - nh)) - - if len(box) >= 1: - t_box = box.copy() - np.random.shuffle(t_box) - t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx - t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy - if flip: - t_box[:, [0, 2]] = w - t_box[:, [2, 0]] - t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 - t_box[:, 2][t_box[:, 2] > w] = w - t_box[:, 3][t_box[:, 3] > h] = h - box_w = t_box[:, 2] - t_box[:, 0] - box_h = t_box[:, 3] - t_box[:, 1] - t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box - - if len(t_box) >= 1: - box = t_box - break - - box_data[:len(box)] = box - # resize image - image = image.resize((nw, nh), Image.BICUBIC) - # place image - new_image = Image.new('RGB', (w, h), (128, 128, 128)) - new_image.paste(image, (dx, dy)) - image = new_image - - # flip image or not - if flip: - image = image.transpose(Image.FLIP_LEFT_RIGHT) - - # convert image to gray or not - gray = _rand() < .25 - if gray: - image = image.convert('L').convert('RGB') - - # when the channels of image is 1 - image = np.array(image) - if len(image.shape) == 2: - image = np.expand_dims(image, axis=-1) - image = np.concatenate([image, image, image], axis=-1) - - # distort image - hue = _rand(-hue, hue) - sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) - val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) - image_data = image / 255. - if do_hsv: - x = rgb_to_hsv(image_data) - x[..., 0] += hue - x[..., 0][x[..., 0] > 1] -= 1 - x[..., 0][x[..., 0] < 0] += 1 - x[..., 1] *= sat - x[..., 2] *= val - x[x > 1] = 1 - x[x < 0] = 0 - image_data = hsv_to_rgb(x) # numpy array, 0 to 1 - image_data = image_data.astype(np.float32) - - # preprocess bounding boxes - bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ - _preprocess_true_boxes(box_data, anchors, image_size) - - return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \ - ori_image_shape, gt_box1, gt_box2, gt_box3 - - if is_training: - images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training) - return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3 - - images, shape, anno = _data_aug(image, box, is_training) - return images, shape, anno - - -def anno_parser(annos_str): - """Parse annotation from string to list.""" - annos = [] - for anno_str in annos_str: - anno = list(map(int, anno_str.strip().split(','))) - annos.append(anno) - return annos - - -def filter_valid_data(image_dir, anno_path): - """Filter valid image file, which both in image_dir and anno_path.""" - image_files = [] - image_anno_dict = {} - if not os.path.isdir(image_dir): - raise RuntimeError("Path given is not valid.") - if not os.path.isfile(anno_path): - raise RuntimeError("Annotation file is not valid.") - - with open(anno_path, "rb") as f: - lines = f.readlines() - for line in lines: - line_str = line.decode("utf-8").strip() - line_split = str(line_str).split(' ') - file_name = line_split[0] - if os.path.isfile(os.path.join(image_dir, file_name)): - image_anno_dict[file_name] = anno_parser(line_split[1:]) - image_files.append(file_name) - return image_files, image_anno_dict - - -def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix, file_num): - """Create MindRecord file by image_dir and anno_path.""" - mindrecord_path = os.path.join(mindrecord_dir, prefix) - writer = FileWriter(mindrecord_path, file_num) - image_files, image_anno_dict = filter_valid_data(image_dir, anno_path) - - yolo_json = { - "image": {"type": "bytes"}, - "annotation": {"type": "int64", "shape": [-1, 5]}, - } - writer.add_schema(yolo_json, "yolo_json") - - for image_name in image_files: - image_path = os.path.join(image_dir, image_name) - with open(image_path, 'rb') as f: - img = f.read() - annos = np.array(image_anno_dict[image_name]) - row = {"image": img, "annotation": annos} - writer.write_raw_data([row]) - writer.commit() - - -def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0, - is_training=True, num_parallel_workers=8): - """Creatr YOLOv3 dataset with MindDataset.""" - ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank, - num_parallel_workers=num_parallel_workers, shuffle=is_training) - decode = C.Decode() - ds = ds.map(input_columns=["image"], operations=decode) - compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) - - if is_training: - hwc_to_chw = C.HWC2CHW() - ds = ds.map(input_columns=["image", "annotation"], - output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], - columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], - operations=compose_map_func, num_parallel_workers=num_parallel_workers) - ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) - else: - ds = ds.map(input_columns=["image", "annotation"], - output_columns=["image", "image_shape", "annotation"], - columns_order=["image", "image_shape", "annotation"], - operations=compose_map_func, num_parallel_workers=num_parallel_workers) - return ds diff --git a/model_zoo/yolov3_resnet18/train.py b/model_zoo/yolov3_resnet18/train.py deleted file mode 100644 index 0a15066ed3..0000000000 --- a/model_zoo/yolov3_resnet18/train.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# less required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -######################## train YOLOv3 example ######################## -train YOLOv3 and get network model files(.ckpt) : -python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train - -If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path. -Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path. -""" - -import os -import argparse -import numpy as np -import mindspore.nn as nn -from mindspore import context, Tensor -from mindspore.communication.management import init -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor -from mindspore.train import Model, ParallelMode -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.common.initializer import initializer - -from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper -from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image -from src.config import ConfigYOLOV3ResNet18 - - -def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False): - """Set learning rate.""" - lr_each_step = [] - for i in range(global_step): - if steps: - lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step))) - else: - lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step))) - lr_each_step = np.array(lr_each_step).astype(np.float32) - lr_each_step = lr_each_step[start_step:] - return lr_each_step - - -def init_net_param(network, init_value='ones'): - """Init:wq the parameters in network.""" - params = network.trainable_params() - for p in params: - if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: - p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype)) - - -def main(): - parser = argparse.ArgumentParser(description="YOLOv3 train") - parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " - "Mindrecord, default is false.") - parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") - parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") - parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--lr", type=float, default=0.001, help="Learning rate, default is 0.001.") - parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink") - parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10") - parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") - parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path") - parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size") - parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") - parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") - parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", - help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by" - "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir " - "rather than image_dir and anno_path. Default is ./Mindrecord_train") - parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, " - "the absolute image path is joined by the image_dir " - "and the relative path in anno_path") - parser.add_argument("--anno_path", type=str, default="", help="Annotation path.") - args_opt = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) - if args_opt.distribute: - device_num = args_opt.device_num - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, - device_num=device_num) - init() - rank = args_opt.device_id % device_num - else: - rank = 0 - device_num = 1 - - print("Start create dataset!") - - # It will generate mindrecord file in args_opt.mindrecord_dir, - # and the file name is yolo.mindrecord0, 1, ... file_num. - if not os.path.isdir(args_opt.mindrecord_dir): - os.makedirs(args_opt.mindrecord_dir) - - prefix = "yolo.mindrecord" - mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0") - if not os.path.exists(mindrecord_file): - if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path): - print("Create Mindrecord.") - data_to_mindrecord_byte_image(args_opt.image_dir, - args_opt.anno_path, - args_opt.mindrecord_dir, - prefix=prefix, - file_num=8) - print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir)) - else: - print("image_dir or anno_path not exits.") - - if not args_opt.only_create_dataset: - loss_scale = float(args_opt.loss_scale) - - # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. - dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, - batch_size=args_opt.batch_size, device_num=device_num, rank=rank) - dataset_size = dataset.get_dataset_size() - print("Create dataset done!") - - net = yolov3_resnet18(ConfigYOLOV3ResNet18()) - net = YoloWithLossCell(net, ConfigYOLOV3ResNet18()) - init_net_param(net, "XavierUniform") - - # checkpoint - ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) - ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config) - - if args_opt.pre_trained: - if args_opt.pre_trained_epoch_size <= 0: - raise KeyError("pre_trained_epoch_size must be greater than 0.") - param_dict = load_checkpoint(args_opt.pre_trained) - load_param_into_net(net, param_dict) - total_epoch_size = 60 - if args_opt.distribute: - total_epoch_size = 160 - lr = Tensor(get_lr(learning_rate=args_opt.lr, start_step=args_opt.pre_trained_epoch_size * dataset_size, - global_step=total_epoch_size * dataset_size, - decay_step=1000, decay_rate=0.95, steps=True)) - opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) - net = TrainingWrapper(net, opt, loss_scale) - - callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] - - model = Model(net) - dataset_sink_mode = False - if args_opt.mode == "sink": - print("In sink mode, one epoch return a loss.") - dataset_sink_mode = True - print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.") - model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) - -if __name__ == '__main__': - main() diff --git a/predict/.gitignore b/predict/.gitignore deleted file mode 100644 index caf7aec495..0000000000 --- a/predict/.gitignore +++ /dev/null @@ -1,14 +0,0 @@ -# git ignore file for predict - -#flatbuf generated file -schema/*_generated.h -schema/inner/*_generated.h -module/tvm_module/lite/include/*_generated.h - -#tvm fbs files -module/tvm_module/lite/tune/convert/*.fbs - -#doTest dir -test/doTest/ - - diff --git a/predict/CMakeLists.txt b/predict/CMakeLists.txt deleted file mode 100755 index 39ca6b27e8..0000000000 --- a/predict/CMakeLists.txt +++ /dev/null @@ -1,79 +0,0 @@ -cmake_minimum_required(VERSION 3.12.1) -project (mindspore-predict) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") -set(CMAKE_BUILD_TYPE "Release") - -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") -set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s") - -option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs" OFF) -option(ENABLE_PREDICT_ARM64 "predict arm64" OFF) -option(ENABLE_PREDICT_ARM32 "predict arm32" OFF) - -set(PREDICT_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(PREDICT_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/build) -set(3RD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../third_party) -set(DOTEST_DIR ${PREDICT_BUILD_DIR}/test/doTest) - -include_directories(${3RD_DIR}) -include_directories(${3RD_DIR}/flatbuffers/include/) -include_directories(${3RD_DIR}/protobuf/build/include/) -include_directories(${3RD_DIR}/googletest/googletest/include/) -include_directories(${3RD_DIR}/googletest/googlemock/include/) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/module/tvm_kernel/lite/include/) -include_directories(${PREDICT_DIR}/module/tvm_kernel/incubator-tvm/3rdparty/dlpack/include) -include_directories(common) - -if(ENABLE_PREDICT_ARM64 OR ENABLE_PREDICT_ARM32) - message("*********************predict compile arm*********************") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_USE_ARM=1") - set(ANDROID_NDK $ENV{ANDROID_NDK}) - if(ANDROID_NDK) - add_subdirectory(${3RD_DIR}/googletest ${CMAKE_BINARY_DIR}/googletest) - link_directories(${PREDICT_BUILD_DIR}/googletest/googlemock/gtest) - - add_subdirectory(${3RD_DIR}/securec ${CMAKE_BINARY_DIR}/securec) - link_directories(${PREDICT_BUILD_DIR}/securec/src) - else() - message(FATAL_ERROR "please set ANDROID_NDK in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r16b/") - endif() - - include_directories(${ANDROID_SYSROOT}/usr/include/) - if(${ANDROID_ABI} STREQUAL "armeabi-v7a") - include_directories(${ANDROID_SYSROOT}/usr/include/arm-linux-androideabi) - elseif(${ANDROID_ABI} STREQUAL "arm64-v8a") - include_directories(${ANDROID_SYSROOT}/usr/include/aarch64-linux-android) - else() - include_directories(${ANDROID_SYSROOT}/usr/include/arm-linux-androideabi) - endif() - -else() - # include libsecurec.a x86 - message("*********************predict compile x86*********************") - if(EXISTS "${PREDICT_DIR}/../build/mindspore/securec/src/libsecurec.a") - link_directories(${PREDICT_DIR}/../build/mindspore/securec/src) - else() - include(${PREDICT_DIR}/../cmake/dependency_securec.cmake) - link_directories(${PREDICT_BUILD_DIR}/securec/src) - endif() - - # include libgtest.so x86 - if(EXISTS "${PREDICT_DIR}/../build/googletest/googlemock/gtest/libgtest.so") - link_directories(${PREDICT_DIR}/../build/googletest/googlemock/gtest) - else() - include(${PREDICT_DIR}/../cmake/dependency_gtest.cmake) - link_directories(${PREDICT_BUILD_DIR}/googletest/googlemock/gtest) - endif() -endif() - -if (CODE_COVERAGE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage -O0") -endif() - -add_subdirectory(common) -add_subdirectory(src) -add_subdirectory(benchmark) -add_subdirectory(test) -add_subdirectory(module) diff --git a/predict/benchmark/CMakeLists.txt b/predict/benchmark/CMakeLists.txt deleted file mode 100755 index 22f87d8a97..0000000000 --- a/predict/benchmark/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ - -cmake_minimum_required(VERSION 3.12) -project(benchmark) - -set(CMAKE_CXX_STANDARD 14) -set(CMAKE_BUILD_TYPE "Debug") - -#include 3rd -include_directories(${3RD_DIR}/protobuf/build/include) -include_directories(${3RD_DIR}/securec/include) -include_directories(${3RD_DIR}/flatbuffers/include) -include_directories(${3RD_DIR}/googletest/googletest/include) -include_directories(${3RD_DIR}/googletest/googlemock/include) -include_directories(${PREDICT_DIR}/module/tvm_kernel/incubator-tvm/3rdparty/dlpack/include) -include_directories(${3RD_DIR}/flatbuffers/include) -include_directories(${3RD_DIR}/securec/include) - -#include ms -include_directories(.) -include_directories(${PREDICT_DIR}) - -set(COMMON_SRC ${PREDICT_DIR}/common/flag_parser.cc - ${PREDICT_DIR}/common/file_utils.cc - ${PREDICT_DIR}/common/func_utils.cc - ${PREDICT_DIR}/common/mslog.cc - ${PREDICT_DIR}/common/utils.cc) - -link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../output/lib/) - -add_executable(benchmark main.cc benchmark.cc ${COMMON_SRC}) - -target_link_libraries(benchmark mspredict libsecurec.a) -add_dependencies(benchmark tvm_kernel) -add_dependencies(benchmark securec) - -add_custom_command(TARGET benchmark POST_BUILD - COMMAND mkdir -pv ${DOTEST_DIR} - COMMAND cp ${PREDICT_BUILD_DIR}/benchmark/benchmark ${DOTEST_DIR}) diff --git a/predict/benchmark/benchmark.cc b/predict/benchmark/benchmark.cc deleted file mode 100644 index c55d03e450..0000000000 --- a/predict/benchmark/benchmark.cc +++ /dev/null @@ -1,451 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "benchmark/benchmark.h" -#include -#include -#include -#include -#include -#include "include/session.h" - -namespace mindspore { -namespace predict { -STATUS Benchmark::GenerateRandomData(size_t size, void *data) { - MS_ASSERT(data != nullptr); - char *castedData = static_cast(data); - for (size_t i = 0; i < size; i++) { - castedData[i] = static_cast(i); - } - return RET_OK; -} - -STATUS Benchmark::GenerateInputData() { - for (Tensor *tensor : msInputs) { - MS_ASSERT(tensor != nullptr); - auto ret = tensor->MallocData(); - if (ret != RET_OK) { - MS_LOGE("MallocData for inTensor failed %d", ret); - return ret; - } - MS_ASSERT(tensor->GetData() != nullptr); - auto tensorByteSize = tensor->GetDataSize(); - auto status = GenerateRandomData(tensorByteSize, tensor->GetData()); - if (status != RET_OK) { - MS_LOGE("GenerateRandomData for inTensor failed %d", status); - return status; - } - } - return RET_OK; -} - -STATUS Benchmark::LoadInput() { - size_t size = 0; - char *graphBuf = ReadFile(_flags->modelPath.c_str(), &size); - if (graphBuf == nullptr) { - MS_LOGE("Load graph failed, path %s", _flags->modelPath.c_str()); - return RET_ERROR; - } - - this->msInputs = session->GetInput(); - - if (_flags->inDataPath.empty()) { - auto status = GenerateInputData(); - if (status != RET_OK) { - delete graphBuf; - MS_LOGE("Generate input data error %d", status); - return status; - } - } else { - auto status = ReadInputFile(); - if (status != RET_OK) { - delete graphBuf; - MS_LOGE("ReadInputFile error, %d", status); - return status; - } - } - delete graphBuf; - return RET_OK; -} - -STATUS Benchmark::ReadInputFile() { - MS_ASSERT(msInputs.size() <= 1); - if (msInputs.empty()) { - return RET_OK; - } - Tensor *inTensor = msInputs.at(0); - MS_ASSERT(inTensor != nullptr); - - size_t size; - char *binBuf = ReadFile(_flags->inDataPath.c_str(), &size); - if (binBuf == nullptr) { - return RET_ERROR; - } - auto tensorDataSize = inTensor->GetDataSize(); - if (size != tensorDataSize) { - MS_LOGE("Input binary file size error, required: %zu, in fact: %zu", tensorDataSize, size); - delete binBuf; - return RET_ERROR; - } - inTensor->SetData(binBuf); - binBuf = nullptr; - - return RET_OK; -} - -// calibData is FP32 -STATUS Benchmark::ReadCalibData() { - const char *calibDataPath = _flags->calibDataPath.c_str(); - // read calib data - std::ifstream inFile(calibDataPath); - if (!inFile.good()) { - MS_LOGE("file: %s is not exist", calibDataPath); - return RET_PARAM_INVALID; - } - - if (!inFile.is_open()) { - MS_LOGE("file: %s open failed", calibDataPath); - inFile.close(); - return RET_PARAM_INVALID; - } - - std::string line; - MS_LOGI("Start reading calibData file"); - std::string tensorName; - while (!inFile.eof()) { - getline(inFile, line); - std::stringstream stringLine1(line); - size_t dim = 0; - stringLine1 >> tensorName >> dim; - std::vector dims; - size_t shapeSize = 1; - for (size_t i = 0; i < dim; i++) { - size_t tmpDim; - stringLine1 >> tmpDim; - dims.push_back(tmpDim); - shapeSize *= tmpDim; - } - - getline(inFile, line); - std::stringstream stringLine2(line); - std::vector tensorData; - for (size_t i = 0; i < shapeSize; i++) { - float tmpData; - stringLine2 >> tmpData; - tensorData.push_back(tmpData); - } - - std::unique_ptr checkTensor(new CheckTensor(dims, tensorData)); - this->calibData.insert(std::make_pair(tensorName, checkTensor.release())); - } - inFile.close(); - MS_LOGI("Finish reading calibData file"); - return RET_OK; -} - -// tensorData need to be converter first -float Benchmark::CompareData(const std::string &nodeName, std::vector msShape, float *msTensorData) { - auto iter = this->calibData.find(nodeName); - if (iter != this->calibData.end()) { - std::vector castedMSShape; - size_t shapeSize = 1; - for (int64_t dim : msShape) { - castedMSShape.push_back(size_t(dim)); - shapeSize *= dim; - } - - CheckTensor *calibTensor = iter->second; - if (calibTensor->shape != castedMSShape) { - std::ostringstream oss; - oss << "Shape of mslite output("; - for (auto dim : castedMSShape) { - oss << dim << ","; - } - oss << ") and shape source model output("; - for (auto dim : calibTensor->shape) { - oss << dim << ","; - } - oss << ") are different"; - MS_LOGE("%s", oss.str().c_str()); - return -1; - } - - float meanBias = 0; - std::ostringstream outputData; - outputData << "Data of node " << nodeName << " : "; - for (size_t j = 0; j < shapeSize; j++) { - if (j < printNum) { - outputData << msTensorData[j] << " "; - } - if (fabs(calibTensor->data.at(j)) > minFloatThr) { - double bias = fabs(msTensorData[j] - calibTensor->data.at(j)) / fabs(calibTensor->data.at(j)); - meanBias += bias; - } - } - meanBias /= shapeSize; - MS_LOGI("%s", outputData.str().c_str()); - - if (meanBias <= minFloatThr) { - MS_LOGI("Mean bias of node %s : 0%%", nodeName.c_str()); - } else { - MS_LOGI("Mean bias of node %s : %f%%", nodeName.c_str(), meanBias * percentage); - } - return meanBias; - } else { - MS_LOGI("%s is not in Source Model output", nodeName.c_str()); - return -1; - } -} - -STATUS Benchmark::CompareOutput(const std::map> &msOutputs) { - float totalBias = 0; - int totalSize = 0; - bool hasError = false; - for (const auto &msOutput : msOutputs) { - std::string nodeName = msOutput.first; - auto tensors = msOutput.second; - for (auto tensor : tensors) { - MS_ASSERT(tensor->GetData() != nullptr); - float bias = CompareData(nodeName, tensor->GetDims(), static_cast(tensor->GetData())); - if (bias >= 0) { - totalBias += bias; - totalSize++; - } else { - hasError = true; - break; - } - } - } - - if (!hasError) { - float meanBias; - if (totalSize != 0) { - meanBias = totalBias / totalSize * percentage; - } else { - meanBias = 0; - } - - MS_LOGI("Mean bias all node : %f%%", meanBias); - - if (meanBias > 1) { - MS_LOGE("Mean bias of all nodes is too big: %f%%", meanBias); - return RET_ERROR; - } else { - return RET_OK; - } - } else { - MS_LOGE("Error in CompareData"); - return RET_ERROR; - } -} - -STATUS Benchmark::MarkPerformance() { - MS_LOGI("Running warm up loops..."); - for (int i = 0; i < _flags->warmUpLoopCount; i++) { - auto status = session->Run(msInputs); - if (status != RET_OK) { - MS_LOGE("Inference error %d", status); - return status; - } - } - - MS_LOGI("Running benchmark loops..."); - uint64_t timeMin = maxTimeThr; - uint64_t timeMax = 0; - uint64_t timeAvg = 0; - for (int i = 0; i < _flags->loopCount; i++) { - uint64_t start = GetTimeUs(); - auto status = session->Run(msInputs); - if (status != RET_OK) { - MS_LOGE("Inference error %d", status); - return status; - } - - uint64_t end = GetTimeUs(); - uint64_t time = end - start; - timeMin = std::min(timeMin, time); - timeMax = std::max(timeMax, time); - timeAvg += time; - - msOutputs = session->GetAllOutput(); - if (cleanData) { - for (auto &msOutput : msOutputs) { - for (auto &outputTensor : msOutput.second) { - delete outputTensor; - } - } - msOutputs.clear(); - } - } - if (_flags->loopCount > 0) { - timeAvg /= _flags->loopCount; - MS_LOGI("MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms", timeMin / US2MS, timeMax / US2MS, - timeAvg / US2MS); - } - return RET_OK; -} - -STATUS Benchmark::MarkAccuracy() { - MS_LOGI("MarkAccuracy"); - - auto status = session->Run(msInputs); - if (status != RET_OK) { - MS_LOGE("Inference error %d", status); - return status; - } - msOutputs = session->GetAllOutput(); - - ReadCalibData(); - status = CompareOutput(msOutputs); - if (cleanData) { - for (auto &msOutput : msOutputs) { - for (auto &outputTensor : msOutput.second) { - delete outputTensor; - } - } - msOutputs.clear(); - } - return status; -} - -STATUS Benchmark::CleanData() { - if (cleanData) { - for (auto &msInput : msInputs) { - delete msInput; - } - msInputs.clear(); - for (auto &data : calibData) { - data.second->shape.clear(); - data.second->data.clear(); - delete data.second; - } - calibData.clear(); - } - return RET_OK; -} - -STATUS Benchmark::RunBenchmark() { - // Load graph - std::string comment = modelName; - - MS_LOGI("start reading model file"); - size_t size = 0; - char *graphBuf = ReadFile(_flags->modelPath.c_str(), &size); - if (graphBuf == nullptr) { - MS_LOGE("Load graph failed while running %s", comment.c_str()); - return RET_ERROR; - } - - uint64_t startPrepareTime = GetTimeUs(); - session = CreateSession(graphBuf, size, ctx); - if (session == nullptr) { - delete graphBuf; - MS_LOGE("new session failed while running %s", comment.c_str()); - return RET_ERROR; - } - uint64_t endPrepareTime = GetTimeUs(); - MS_LOGI("PrepareTime = %f ms, ", (endPrepareTime - startPrepareTime) / US2MS); - - // Load input - MS_LOGI("start generate input data"); - auto status = LoadInput(); - if (status != RET_OK) { - delete graphBuf; - MS_LOGE("Generate input data error"); - return status; - } - - if (!_flags->calibDataPath.empty()) { - status = MarkAccuracy(); - if (status != RET_OK) { - delete graphBuf; - MS_LOGE("Run MarkAccuracy error: %d", status); - return status; - } - } else { - status = MarkPerformance(); - if (status != RET_OK) { - delete graphBuf; - MS_LOGE("Run MarkPerformance error: %d", status); - return status; - } - } - - CleanData(); - delete graphBuf; - return RET_OK; -} - -STATUS Benchmark::Init() { - if (this->_flags == nullptr) { - return RET_ERROR; - } - MS_LOGI("ModelPath = %s", this->_flags->modelPath.c_str()); - MS_LOGI("InDataPath = %s", this->_flags->inDataPath.c_str()); - MS_LOGI("TensorDataType = %s", this->_flags->tensorDataTypeIn.c_str()); - MS_LOGI("LoopCount = %d", this->_flags->loopCount); - MS_LOGI("WarmUpLoopCount = %d", this->_flags->warmUpLoopCount); - MS_LOGI("NumThreads = %d", this->_flags->numThreads); - MS_LOGI("calibDataPath = %s", this->_flags->calibDataPath.c_str()); - - this->_flags->inDataType = this->_flags->inDataTypeIn == "img" ? kImage : kBinary; - if (this->_flags->tensorDataTypeIn == "float") { - this->_flags->tensorDataType = DataType_DT_FLOAT; - } - - if (_flags->modelPath.empty()) { - MS_LOGE("modelPath is required"); - return RET_ERROR; - } - - modelName = _flags->modelPath.substr(_flags->modelPath.find_last_of("/") + 1); - - return RET_OK; -} - -int RunBenchmark(int argc, const char **argv) { - BenchmarkFlags flags; - Option err = flags.ParseFlags(argc, argv); - - if (err.IsSome()) { - std::cerr << err.Get() << std::endl; - std::cerr << flags.Usage() << std::endl; - return -1; - } - - if (flags.help) { - std::cerr << flags.Usage() << std::endl; - return 0; - } - - Benchmark mBenchmark(&flags); - auto status = mBenchmark.Init(); - if (status != RET_OK) { - MS_LOGE("Benchmark init Error : %d", status); - return 1; - } - - status = mBenchmark.RunBenchmark(); - if (status != RET_OK) { - MS_LOGE("Run Benchmark Error : %d", status); - return 1; - } - - MS_LOGI("end of benchmark"); - return 0; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/benchmark/benchmark.h b/predict/benchmark/benchmark.h deleted file mode 100644 index 03cd117df0..0000000000 --- a/predict/benchmark/benchmark.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_BENCHMARK_BENCHMARK_H_ -#define PREDICT_BENCHMARK_BENCHMARK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/flag_parser.h" -#include "common/file_utils.h" -#include "common/func_utils.h" -#include "common/mslog.h" -#include "common/utils.h" -#include "include/errorcode.h" -#include "include/session.h" -#include "include/tensor.h" -#include "schema/inner/ms_generated.h" -#include "src/graph.h" -#include "src/graph_execution.h" -#include "src/op.h" - -namespace mindspore { -namespace predict { -enum InDataType { kImage = 0, kBinary = 1 }; - -struct CheckTensor { - CheckTensor(const std::vector &shape, const std::vector &data) { - this->shape = shape; - this->data = data; - } - std::vector shape; - std::vector data; -}; - -class BenchmarkFlags : public virtual FlagParser { - public: - BenchmarkFlags() { - // common - AddFlag(&BenchmarkFlags::modelPath, "modelPath", "Input model path", ""); - AddFlag(&BenchmarkFlags::tensorDataTypeIn, "tensorDataType", "Data type of input Tensor. float", "float"); - AddFlag(&BenchmarkFlags::inDataPath, "inDataPath", "Input data path, if not set, use random input", ""); - // MarkPerformance - AddFlag(&BenchmarkFlags::loopCount, "loopCount", "Run loop count", 10); - AddFlag(&BenchmarkFlags::numThreads, "numThreads", "Run threads number", 2); - AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3); - // MarkAccuracy - AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", ""); - } - - ~BenchmarkFlags() override = default; - - public: - // common - std::string modelPath; - std::string inDataPath; - InDataType inDataType; - std::string inDataTypeIn; - DataType tensorDataType; - std::string tensorDataTypeIn; - // MarkPerformance - int loopCount; - int numThreads; - int warmUpLoopCount; - // MarkAccuracy - std::string calibDataPath; -}; - -class Benchmark { - public: - explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {} - - virtual ~Benchmark() = default; - - STATUS Init(); - STATUS RunBenchmark(); - - private: - // call GenerateInputData or ReadInputFile to init inputTensors - STATUS LoadInput(); - - // call GenerateRandomData to fill inputTensors - STATUS GenerateInputData(); - - STATUS GenerateRandomData(size_t size, void *data); - - STATUS ReadInputFile(); - - STATUS ReadCalibData(); - - STATUS CleanData(); - - STATUS CompareOutput(const std::map> &msOutputs); - - float CompareData(const std::string &nodeName, std::vector msShape, float *msTensorData); - - STATUS MarkPerformance(); - - STATUS MarkAccuracy(); - - private: - BenchmarkFlags *_flags; - std::shared_ptr session; - Context ctx; - std::vector msInputs; - std::map> msOutputs; - std::unordered_map calibData; - std::string modelName = ""; - bool cleanData = true; - - const float US2MS = 1000.0f; - const float percentage = 100.0f; - const int printNum = 50; - const float minFloatThr = 0.0000001f; - - const uint64_t maxTimeThr = 1000000; -}; - -int RunBenchmark(int argc, const char **argv); -} // namespace predict -} // namespace mindspore -#endif // PREDICT_BENCHMARK_BENCHMARK_H_ diff --git a/predict/benchmark/main.cc b/predict/benchmark/main.cc deleted file mode 100644 index 66e473a42a..0000000000 --- a/predict/benchmark/main.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "benchmark/benchmark.h" - -int main(int argc, const char **argv) { - signal(SIGSEGV, mindspore::predict::CoreDumpTraceFunc); - return mindspore::predict::RunBenchmark(argc, argv); -} diff --git a/predict/common/CMakeLists.txt b/predict/common/CMakeLists.txt deleted file mode 100755 index 3734c26bc0..0000000000 --- a/predict/common/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../include) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) - -add_compile_options(-fPIC) - -add_library(common_mid OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/common.h - ${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc - ${CMAKE_CURRENT_SOURCE_DIR}/file_utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/flag_parser.cc - ${CMAKE_CURRENT_SOURCE_DIR}/func_utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/module_registry.cc - ${CMAKE_CURRENT_SOURCE_DIR}/mslog.cc - ${CMAKE_CURRENT_SOURCE_DIR}/storage.cc - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc) diff --git a/predict/common/common.h b/predict/common/common.h deleted file mode 100644 index d93139abae..0000000000 --- a/predict/common/common.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_COMMON_H_ -#define PREDICT_COMMON_COMMON_H_ - -#include -#include "schema/inner/ms_generated.h" - -namespace mindspore { -namespace predict { -enum NCHW_SHAPE { NCHW_N = 0, NCHW_C = 1, NCHW_H = 2, NCHW_W = 3 }; -enum NHWC_SHAPE { NHWC_N = 0, NHWC_H = 1, NHWC_W = 2, NHWC_C = 3 }; -enum HWCK_SHAPE { HWCK_H = 0, HWCK_W = 1, HWCK_C = 2, HWCK_K = 3 }; -enum KCHW_SHAPE { KCHW_K = 0, KCHW_C = 1, KCHW_H = 2, KCHW_W = 3 }; -enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 }; -enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 }; - -static constexpr int TENSOR_MAX_REFCOUNT = 999; - -static const char *DELIM_COLON = ":"; -static const char *DELIM_COMMA = ","; -static const char *DELIM_SLASH = "/"; -static const char *DELIM_DOUBLE_BACKSLASH = "\\"; - -// quantization relative -static const char QUANTIZED_UINT8[] = "QUANTIZED_UINT8"; -static const char QUANTIZED_INT8[] = "QUANTIZED_INT8"; -static const char QUANTIZED_INT16[] = "QUANTIZED_INT16"; -static const char QUANTIZED_UINT16[] = "QUANTIZED_UINT16"; -static const char QUANTIZED_FLOAT16[] = "FLOAT16"; -static const char QUANTIZED_FLOAT32[] = "FLOAT32"; -static const char QUANTIZATION_TYPE_DYNAMIC[] = "DYNAMIC"; -static const char QUANTIZATION_TYPE_STATIC[] = "STATIC"; -static const char CALIB_NORM[] = "NORM"; - -// dims -static const int32_t DIM_DEFAULT_SIZE = 4; - -static const Format DEFAULT_FORMAT = Format_NCHW; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_COMMON_H_ diff --git a/predict/common/file_utils.cc b/predict/common/file_utils.cc deleted file mode 100644 index 94adf0f7ac..0000000000 --- a/predict/common/file_utils.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/file_utils.h" -#include - -namespace mindspore { -namespace predict { -char *ReadFile(const char *file, size_t *size) { - if (file == nullptr) { - MS_LOGE("file is nullptr"); - return nullptr; - } - MS_ASSERT(size != nullptr); - std::ifstream ifs(RealPath(file)); - if (!ifs.good()) { - MS_LOGE("file: %s is not exist", file); - return nullptr; - } - - if (!ifs.is_open()) { - MS_LOGE("file: %s open failed", file); - return nullptr; - } - - ifs.seekg(0, std::ios::end); - *size = ifs.tellg(); - std::unique_ptr buf(new (std::nothrow) char[*size]); - if (buf == nullptr) { - MS_LOGE("malloc buf failed, file:%s", file); - ifs.close(); - return nullptr; - } - - ifs.seekg(0, std::ios::beg); - ifs.read(buf.get(), *size); - ifs.close(); - - return buf.release(); -} - -std::string RealPath(const char *path) { - if (path == nullptr) { - MS_LOGE("path is nullptr"); - return ""; - } - if ((strlen(path)) >= PATH_MAX) { - MS_LOGE("path is too long"); - return ""; - } - - std::shared_ptr resolvedPath(new (std::nothrow) char[PATH_MAX]{0}); - if (resolvedPath == nullptr) { - MS_LOGE("new resolvedPath failed"); - return ""; - } - - auto ret = realpath(path, resolvedPath.get()); - if (ret == nullptr) { - MS_LOGE("realpath failed"); - return ""; - } - return resolvedPath.get(); -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/file_utils.h b/predict/common/file_utils.h deleted file mode 100644 index e67c1cf9f1..0000000000 --- a/predict/common/file_utils.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_FILE_UTILS_H_ -#define PREDICT_COMMON_FILE_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "common/mslog.h" -#include "include/tensor.h" - -namespace mindspore { -namespace predict { -char *ReadFile(const char *file, size_t *size); - -std::string RealPath(const char *path); -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_FILE_UTILS_H_ diff --git a/predict/common/flag_parser.cc b/predict/common/flag_parser.cc deleted file mode 100644 index 37482dc409..0000000000 --- a/predict/common/flag_parser.cc +++ /dev/null @@ -1,179 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/flag_parser.h" - -namespace mindspore { -namespace predict { -// parse flags read from command line -Option FlagParser::ParseFlags(int argc, const char *const *argv, bool supportUnknown, - bool supportDuplicate) { - MS_ASSERT(argv != nullptr); - const int FLAG_PREFIX_LEN = 2; - // Get binary name - binName = GetFileName(argv[0]); - - std::multimap> keyValues; - for (int i = 1; i < argc; i++) { - std::string tmp = argv[i]; - Trim(&tmp); - const std::string flagItem(tmp); - - if (flagItem == "--") { - break; - } - - if (flagItem.find("--") == std::string::npos) { - continue; - } - - std::string key; - Option value = Option(None()); - - size_t pos = flagItem.find_first_of("="); - if (pos == std::string::npos && flagItem.find("--no-") != std::string::npos) { - key = flagItem.substr(FLAG_PREFIX_LEN); - } else if (pos == std::string::npos) { - key = flagItem.substr(FLAG_PREFIX_LEN); - } else { - key = flagItem.substr(FLAG_PREFIX_LEN, pos - FLAG_PREFIX_LEN); - value = Option(flagItem.substr(pos + 1)); - } - - keyValues.insert(std::pair>(key, value)); - } - - Option ret = Option(InnerParseFlags(&keyValues)); - if (ret.IsSome()) { - return Option(ret.Get()); - } - - return Option(None()); -} - -bool FlagParser::GetRealFlagName(const std::string &oriFlagName, std::string *flagName) { - MS_ASSERT(flagName != nullptr); - const int BOOL_TYPE_FLAG_PREFIX_LEN = 3; - bool opaque = false; - if (StartsWithPrefix(oriFlagName, "no-")) { - *flagName = oriFlagName.substr(BOOL_TYPE_FLAG_PREFIX_LEN); - opaque = true; - } else { - *flagName = oriFlagName; - } - return opaque; -} - -// Inner parse function -Option FlagParser::InnerParseFlags(std::multimap> *keyValues) { - MS_ASSERT(keyValues != nullptr); - for (auto it = keyValues->begin(); it != keyValues->end(); ++it) { - std::string flagName; - bool opaque = GetRealFlagName((*it).first, &flagName); - Option flagValue = (*it).second; - - auto item = flags.find(flagName); - if (item == flags.end()) { - return Option(std::string(flagName + " is not a valid flag")); - } - FlagInfo *flag = &(item->second); - if (flag == nullptr) { - return Option("Failed: flag is nullptr"); - } - if (flag->isParsed) { - return Option("Failed: already parsed flag: " + flagName); - } - std::string tmpValue; - if (!flag->isBoolean) { - if (opaque) { - return Option(flagName + " is not a boolean type"); - } - if (flagValue.IsNone()) { - return Option("No value provided for non-boolean type: " + flagName); - } - tmpValue = flagValue.Get(); - } else { - if (flagValue.IsNone() || flagValue.Get().empty()) { - tmpValue = !opaque ? "true" : "false"; - } else if (!opaque) { - tmpValue = flagValue.Get(); - } else { - return Option(std::string("Boolean flag can not have non-empty value")); - } - } - // begin to parse value - Option ret = flag->parse(this, tmpValue); - if (ret.IsNone()) { - return Option("Failed to parse value for: " + flag->flagName); - } - flag->isParsed = true; - } - - // to check flags not given in command line but added as in constructor - for (auto &flag : flags) { - if (flag.second.isRequired && !flag.second.isParsed) { - return Option("Error, value of '" + flag.first + "' not provided"); - } - } - - return Option(None()); -} - -void Replaceall(std::string *str, const std::string &oldValue, const std::string &newValue) { - if (str == nullptr) { - MS_LOGE("Input str is nullptr"); - return; - } - while (true) { - std::string::size_type pos(0); - if ((pos = str->find(oldValue)) != std::string::npos) { - str->replace(pos, oldValue.length(), newValue); - } else { - break; - } - } -} - -std::string FlagParser::Usage(const Option &usgMsg) const { - // first line, brief of the usage - std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; - // usage of bin name - usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; - // help line of help message, usageLine:message of parametors - std::string helpLine = ""; - std::string usageLine = ""; - uint32_t i = 0; - for (auto flag = flags.begin(); flag != flags.end(); flag++) { - std::string flagName = flag->second.flagName; - std::string helpInfo = flag->second.helpInfo; - // parameter line - std::string thisLine = flag->second.isBoolean ? " --[no-]" + flagName : " --" + flagName + "=VALUE"; - if (++i < flags.size()) { - // add paramter help message of each line - thisLine += " " + helpInfo; - Replaceall(&helpInfo, "\n\r", "\n"); - usageLine += thisLine + "\n"; - } else { - // brief help message - helpLine = thisLine + " " + helpInfo + "\n"; - } - } - // total usage is brief of usage+ brief of bin + help message + brief of - // paramters - return usageString + helpLine + usageLine; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/flag_parser.h b/predict/common/flag_parser.h deleted file mode 100644 index f01b8df71e..0000000000 --- a/predict/common/flag_parser.h +++ /dev/null @@ -1,291 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_FLAG_PARSER_H_ -#define PREDICT_COMMON_FLAG_PARSER_H_ - -#include -#include -#include -#include - -#include "common/utils.h" -#include "common/option.h" - -namespace mindspore { -namespace predict { -struct FlagInfo; - -struct Nothing {}; - -class FlagParser { - public: - FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", false); } - - virtual ~FlagParser() = default; - - // only support read flags from command line - virtual Option ParseFlags(int argc, const char *const *argv, bool supportUnknown = false, - bool supportDuplicate = false); - std::string Usage(const Option &usgMsg = Option(None())) const; - - template - void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); - - template - void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); - - template - void AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo); - - // Option-type fields - template - void AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo); - bool help; - - protected: - std::string binName; - Option usageMsg; - - private: - struct FlagInfo { - std::string flagName; - bool isRequired; - bool isBoolean; - std::string helpInfo; - bool isParsed; - std::function(FlagParser *, const std::string &)> parse; - }; - - inline void AddFlag(const FlagInfo &flag); - - // construct a temporary flag - template - void ConstructFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); - - // construct a temporary flag - template - void ConstructFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); - - Option InnerParseFlags(std::multimap> *values); - - bool GetRealFlagName(const std::string &oriFlagName, std::string *flagName); - - std::map flags; -}; - -// convert to std::string -template -Option ConvertToString(T Flags::*t, const FlagParser &baseFlag) { - const Flags *flag = dynamic_cast(&baseFlag); - if (flag != nullptr) { - return std::to_string(flag->*t); - } - - return Option(None()); -} - -// construct for a Option-type flag -template -void FlagParser::ConstructFlag(Option Flags::*t1, const std::string &flagName, const std::string &helpInfo, - FlagInfo *flag) { - if (flag == nullptr) { - MS_LOGE("FlagInfo is nullptr"); - return; - } - flag->flagName = flagName; - flag->helpInfo = helpInfo; - flag->isBoolean = typeid(T) == typeid(bool); - flag->isParsed = false; -} - -// construct a temporary flag -template -void FlagParser::ConstructFlag(T Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) { - if (flag == nullptr) { - MS_LOGE("FlagInfo is nullptr"); - return; - } - if (t1 == nullptr) { - MS_LOGE("t1 is nullptr"); - return; - } - flag->flagName = flagName; - flag->helpInfo = helpInfo; - flag->isBoolean = typeid(T) == typeid(bool); - flag->isParsed = false; -} - -inline void FlagParser::AddFlag(const FlagInfo &flagItem) { flags[flagItem.flagName] = flagItem; } - -template -void FlagParser::AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo) { - if (t == nullptr) { - MS_LOGE("t1 is nullptr"); - return; - } - - Flags *flag = dynamic_cast(this); - if (flag == nullptr) { - MS_LOGI("dynamic_cast failed"); - return; - } - - FlagInfo flagItem; - - // flagItem is as a output parameter - ConstructFlag(t, flagName, helpInfo, &flagItem); - flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option { - Flags *flag = dynamic_cast(base); - if (base != nullptr) { - Option ret = Option(GenericParseValue(value)); - if (ret.IsNone()) { - return Option(None()); - } else { - flag->*t = ret.Get(); - } - } - - return Option(Nothing()); - }; - - flagItem.isRequired = true; - flagItem.helpInfo += - !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; - flagItem.helpInfo += ")"; - - // add this flag to a std::map - AddFlag(flagItem); -} - -template -void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { - if (t1 == nullptr) { - MS_LOGE("t1 is nullptr"); - return; - } - - FlagInfo flagItem; - - // flagItem is as a output parameter - ConstructFlag(t1, flagName, helpInfo, flagItem); - flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { - if (base != nullptr) { - Option ret = Option(GenericParseValue(value)); - if (ret.IsNone()) { - return Option(None()); - } else { - *t1 = ret.Get(); - } - } - - return Option(Nothing()); - }; - - flagItem.isRequired = false; - *t1 = t2; - - flagItem.helpInfo += - !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; - flagItem.helpInfo += ToString(t2).Get(); - flagItem.helpInfo += ")"; - - // add this flag to a std::map - AddFlag(flagItem); -} - -template -void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { - if (t1 == nullptr) { - MS_LOGE("t1 is nullptr"); - return; - } - - Flags *flag = dynamic_cast(this); - if (flag == nullptr) { - MS_LOGI("dynamic_cast failed"); - return; - } - - FlagInfo flagItem; - - // flagItem is as a output parameter - ConstructFlag(t1, flagName, helpInfo, &flagItem); - flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { - Flags *flag = dynamic_cast(base); - if (base != nullptr) { - Option ret = Option(GenericParseValue(value)); - if (ret.IsNone()) { - return Option(None()); - } else { - flag->*t1 = ret.Get(); - } - } - - return Option(Nothing()); - }; - - flagItem.isRequired = false; - flag->*t1 = t2; - - flagItem.helpInfo += - !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; - flagItem.helpInfo += ToString(t2).Get(); - flagItem.helpInfo += ")"; - - // add this flag to a std::map - AddFlag(flagItem); -} - -// option-type add flag -template -void FlagParser::AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo) { - if (t == nullptr) { - MS_LOGE("t is nullptr"); - return; - } - - Flags *flag = dynamic_cast(this); - if (flag == nullptr) { - MS_LOGE("dynamic_cast failed"); - return; - } - - FlagInfo flagItem; - // flagItem is as a output parameter - ConstructFlag(t, flagName, helpInfo, &flagItem); - flagItem.isRequired = false; - flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option { - Flags *flag = dynamic_cast(base); - if (base != nullptr) { - Option ret = Option(GenericParseValue(value)); - if (ret.IsNone()) { - return Option(None()); - } else { - flag->*t = Option(Some(ret.Get())); - } - } - - return Option(Nothing()); - }; - - // add this flag to a std::map - AddFlag(flagItem); -} -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_FLAG_PARSER_H_ diff --git a/predict/common/func_utils.cc b/predict/common/func_utils.cc deleted file mode 100644 index d2aeb8d941..0000000000 --- a/predict/common/func_utils.cc +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/func_utils.h" - -namespace mindspore { -namespace predict { -#if MS_USE_ARM -_Unwind_Reason_Code PrintTraceArm(_Unwind_Context *ctx, void *d) { - MS_ASSERT(ctx != nullptr); - MS_ASSERT(d != nullptr); - Dl_info info; - int *depth = static_cast(d); - auto ipAddr = static_cast(_Unwind_GetIP(ctx)); - if (dladdr(reinterpret_cast(ipAddr), &info)) { - const char *symbol = ""; - const char *dlfile = ""; - if (info.dli_sname) { - symbol = info.dli_sname; - } - if (info.dli_fname) { - dlfile = info.dli_fname; - } - MS_PRINT_ERROR("#%d: (%08lx) %s %s ", *depth, ipAddr, dlfile, symbol); - } - - (*depth)++; - return _URC_NO_REASON; -} -#endif - -void CoreDumpTraceFunc(int iSignum) { - MS_PRINT_ERROR("----- start get backtrace info -----"); -#if MS_USE_ARM - int depth = 0; - _Unwind_Backtrace(&PrintTraceArm, &depth); -#else - const auto maxDeep = 32; - const auto maxStringLen = 100; - void *apBuffer[maxStringLen]; - char **ppStrings; - - auto iStackDepth = backtrace(apBuffer, maxDeep); - if (0 > iStackDepth) { - KillProcess("Get backtrace depth failed"); - return; - } - MS_PRINT_ERROR("Current stack depth is %d", iStackDepth); - ppStrings = backtrace_symbols(apBuffer, iStackDepth); - if (nullptr == ppStrings) { - KillProcess("Get backtrace_symbols failed"); - return; - } - - for (int iLoop = 0; iLoop < iStackDepth; iLoop++) { - MS_PRINT_ERROR("%s \n", ppStrings[iLoop]); - } -#endif - MS_PRINT_ERROR("----- finish get backtrace info -----"); - KillProcess("Exit after core dump"); - return; // try exit 1 -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/func_utils.h b/predict/common/func_utils.h deleted file mode 100644 index da0389a584..0000000000 --- a/predict/common/func_utils.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_FUNC_UTILS_H_ -#define PREDICT_COMMON_FUNC_UTILS_H_ - -#if MS_USE_ARM -#include -#include -#else -#include -#endif -#include "include/errorcode.h" -#include "common/mslog.h" - -namespace mindspore { -namespace predict { -void CoreDumpTraceFunc(int iSignum); -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_FUNC_UTILS_H_ diff --git a/predict/common/graph_util.cc b/predict/common/graph_util.cc deleted file mode 100644 index 6394731bc6..0000000000 --- a/predict/common/graph_util.cc +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/graph_util.h" -#include -#include -#include "common/mslog.h" -#include "include/errorcode.h" - -namespace mindspore { -namespace predict { -OpGraph *OpGraph::Build(const SubGraphDef &subGraphDef) { - auto graph = std::unique_ptr(new OpGraph()); - if (graph == nullptr) { - MS_LOGE("malloc opgraph failed"); - return nullptr; - } - - auto nodeDefs = subGraphDef.nodes(); - if (nodeDefs == nullptr) { - MS_LOGE("nodeDefs from subGraphDef is nullptr"); - return nullptr; - } - - uint32_t opCount = nodeDefs->size(); - for (uint32_t i = 0; i < opCount; i++) { - auto nodeDef = nodeDefs->GetAs(i); - MS_ASSERT(nodeDef != nullptr); - auto ret = graph->AddEdge(*nodeDef, *nodeDefs); - if (ret != RET_OK) { - MS_LOGE("%s add edge failed. ret:%d", nodeDef->opDef()->name()->c_str(), ret); - return nullptr; - } - } - - return graph.release(); -} - -int OpGraph::AddEdge(const NodeDef &srcNodeDef, const flatbuffers::Vector> &nodeDefs) { - MS_ASSERT(srcNodeDef.opDef() != nullptr); - MS_ASSERT(srcNodeDef.opDef()->name() != nullptr); - NODE_ID srcId = std::string(srcNodeDef.opDef()->name()->c_str()); - uint32_t opCount = nodeDefs.size(); - - MS_ASSERT(srcNodeDef.opDef()->outputIndex() != nullptr); - for (auto index : *(srcNodeDef.opDef()->outputIndex())) { - for (uint32_t i = 0; i < opCount; i++) { - auto dstNodeDef = nodeDefs.GetAs(i); - bool find = false; - MS_ASSERT(dstNodeDef != nullptr); - MS_ASSERT(dstNodeDef->opDef() != nullptr); - auto inputIndex = dstNodeDef->opDef()->inputIndex(); - MS_ASSERT(inputIndex != nullptr); - if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) { - find = true; - } - - if (!find) { - continue; - } - MS_ASSERT(dstNodeDef->opDef()->name() != nullptr); - NODE_ID dstId = std::string(dstNodeDef->opDef()->name()->c_str()); - auto ret = AddEdge(srcId, dstId); - if (ret != RET_OK) { - return ret; - } - } - } - - return RET_OK; -} - -int OpGraph::AddEdge(const NODE_ID &srcId, const NODE_ID &dstId) { - auto srcNode = AddNode(srcId); - if (srcNode == nullptr) { - MS_LOGE("add srcNode failed"); - return RET_ERROR; - } - srcNode->AddOutEdge(dstId); - auto dstNode = AddNode(dstId); - if (dstNode == nullptr) { - MS_LOGE("add dstNode failed"); - return RET_ERROR; - } - dstNode->AddInEdge(srcId); - return RET_OK; -} - -OpNode *OpGraph::GetNode(const NODE_ID &nodeId) { - auto node = nodes.find(nodeId); - if (node == nodes.end()) { - return nullptr; - } - return node->second; -} - -OpNode *OpGraph::AddNode(const NODE_ID &nodeId) { - auto node = GetNode(nodeId); - if (node != nullptr) { - return node; - } - node = new (std::nothrow) OpNode(nodeId); - if (node == nullptr) { - MS_LOGE("new node failed"); - return nullptr; - } - nodes[nodeId] = node; - return node; -} - -std::unordered_set OpGraph::GetInputNode() { - std::unordered_set inputNodes; - for (const auto &iter : nodes) { - auto node = iter.second; - MS_ASSERT(node != nullptr); - if (node->GetAllInEdge().empty()) { - inputNodes.insert(node->ID()); - } - } - return inputNodes; -} - -std::unordered_set OpGraph::GetOutputNode() { - std::unordered_set outputNodes; - for (const auto &iter : nodes) { - auto node = iter.second; - MS_ASSERT(node != nullptr); - if (node->GetAllOutEdge().empty()) { - outputNodes.insert(node->ID()); - } - } - return outputNodes; -} - -OpGraph::~OpGraph() { - for (auto iter : nodes) { - if (iter.second != nullptr) { - delete iter.second; - } - } - nodes.clear(); -} - -NODE_ID OpNode::ID() { return id; } - -void OpNode::AddInEdge(const NODE_ID &nodeId) { inEdges.insert(nodeId); } - -void OpNode::AddOutEdge(const NODE_ID &nodeId) { outEdges.insert(nodeId); } - -std::unordered_set OpNode::GetAllInEdge() { return inEdges; } - -std::unordered_set OpNode::GetAllOutEdge() { return outEdges; } -} // namespace predict -} // namespace mindspore diff --git a/predict/common/graph_util.h b/predict/common/graph_util.h deleted file mode 100644 index 9797edadf6..0000000000 --- a/predict/common/graph_util.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_GRAPH_UTIL_H_ -#define PREDICT_COMMON_GRAPH_UTIL_H_ - -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "schema/inner/ms_generated.h" - -namespace mindspore { -namespace predict { -using NODE_ID = std::string; - -class OpNode { - public: - explicit OpNode(NODE_ID nodeId) : id(std::move(nodeId)) {} - NODE_ID ID(); - void AddInEdge(const NODE_ID &nodeId); - void AddOutEdge(const NODE_ID &nodeId); - std::unordered_set GetAllInEdge(); - std::unordered_set GetAllOutEdge(); - - protected: - NODE_ID id; - std::unordered_set inEdges; - std::unordered_set outEdges; -}; - -class OpGraph { - public: - OpGraph() = default; - - ~OpGraph(); - - static OpGraph *Build(const SubGraphDef &subGraphDef); - - OpNode *GetNode(const NODE_ID &nodeId); - OpNode *AddNode(const NODE_ID &nodeId); - std::unordered_set GetInputNode(); - std::unordered_set GetOutputNode(); - - private: - int AddEdge(const NODE_ID &srcId, const NODE_ID &dstId); - int AddEdge(const NodeDef &srcNodeDef, const flatbuffers::Vector> &nodeDefs); - - protected: - std::unordered_map nodes; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_GRAPH_UTIL_H_ diff --git a/predict/common/module_registry.cc b/predict/common/module_registry.cc deleted file mode 100644 index da8992bb66..0000000000 --- a/predict/common/module_registry.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/module_registry.h" - -namespace mindspore { -namespace predict { -ModuleRegistry *GetRegistryInstance() { - static ModuleRegistry registry; - return ®istry; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/module_registry.h b/predict/common/module_registry.h deleted file mode 100644 index 9d7587e74a..0000000000 --- a/predict/common/module_registry.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_MODULE_REGISTRY_H_ -#define PREDICT_COMMON_MODULE_REGISTRY_H_ -#include -#include -#include -#include "common/mslog.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -class ModuleBase { - public: - virtual ~ModuleBase() = default; -}; - -template -class Module; - -class ModuleRegistry { - public: - ModuleRegistry() = default; - - virtual ~ModuleRegistry() = default; - - template - bool Register(const std::string &name, const T &t) { - modules[name] = &t; - return true; - } - - template - std::shared_ptr Create(const std::string &name) { - auto it = modules.find(name); - if (it == modules.end()) { - return nullptr; - } - auto *module = (Module *)it->second; - if (module == nullptr) { - return nullptr; - } else { - return module->Create(); - } - } - - template - T *GetInstance(const std::string &name) { - auto it = modules.find(name); - if (it == modules.end()) { - return nullptr; - } - auto *module = (Module *)it->second; - if (module == nullptr) { - return nullptr; - } else { - return module->GetInstance(); - } - } - - protected: - std::unordered_map modules; -}; - -ModuleRegistry *GetRegistryInstance() MSPREDICT_API; - -template -class ModuleRegistrar { - public: - ModuleRegistrar(const std::string &name, const T &module) { - auto registryInstance = GetRegistryInstance(); - if (registryInstance == nullptr) { - MS_LOGW("registryInstance is nullptr."); - } else { - registryInstance->Register(name, module); - } - } -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_MODULE_REGISTRY_H_ diff --git a/predict/common/mslog.cc b/predict/common/mslog.cc deleted file mode 100644 index a1b61bbc3d..0000000000 --- a/predict/common/mslog.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/mslog.h" -#include -#include -#include -#include -#include "include/errorcode.h" - -namespace mindspore { -namespace predict { -std::string GetEnv(const std::string &envvar) { - const char *value = std::getenv(envvar.c_str()); - if (value == nullptr) { - return std::string(); - } - return std::string(value); -} - -bool IsPrint(int level) { - auto envString = GetEnv("MSLOG"); - static int env = static_cast(std::strtol(!envString.empty() ? envString.c_str() : "3", nullptr, 0)); - if (env == INT_MIN || env == INT_MAX) { - env = WARN; - // enable the SP for binscope checking - std::string errorStr = "env exceeded the value that type int is able to represent"; - MS_LOGE("%s", errorStr.c_str()); - } - - return level >= env; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/mslog.h b/predict/common/mslog.h deleted file mode 100644 index a48d87f1fa..0000000000 --- a/predict/common/mslog.h +++ /dev/null @@ -1,230 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_MSLOG_H_ -#define PREDICT_COMMON_MSLOG_H_ - -#include -#include -#include -#include -#include -#include - -#if defined(__ANDROID__) || defined(ANDROID) -#include -#endif -namespace mindspore { -namespace predict { -constexpr const char *TAG = "MS_PREDICT"; - -constexpr int DEBUG = 1; -constexpr int INFO = 2; -constexpr int WARN = 3; -constexpr int ERROR = 4; - -#define MSPREDICT_API __attribute__((visibility("default"))) - -bool MSPREDICT_API IsPrint(int level); - -#if !defined(__ANDROID__) && !defined(ANDROID) - -#if LOG_TO_FILE -#define MS_LOGD(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::DEBUG)) { \ - syslog(LOG_DEBUG, "%s|%d|%s[%d]|: " #fmt, mindspore::predict::TAG, \getpid(), __func__, __LINE__, ##args); \ - } \ - } -#define MS_LOGI(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::INFO)) { \ - syslog(LOG_INFO, "%s|%d|%s[%d]|: " #fmt, mindspore::predict::TAG, \getpid(), __func__, __LINE__, ##args); \ - } \ - } -#define MS_LOGW(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::WARN)) { \ - syslog(LOG_WARNING, "%s|%d|%s[%d]|: " #fmt, mindspore::predict::TAG, \getpid(), __func__, __LINE__, ##args); \ - } \ - } -#define MS_LOGE(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::ERROR)) { \ - syslog(LOG_ERR, "%s|%d|%s[%d]|: " #fmt, mindspore::predict::TAG, getpid(), __func__, __LINE__, ##args); \ - } \ - } -#else - -#define MS_LOGD(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::DEBUG)) { \ - printf("[DEBUG] %s|%d|%s|%s[%d]|: " #fmt "\r\n", mindspore::predict::TAG, getpid(), __FILE__, __func__, \ - __LINE__, ##args); \ - } \ - } -#define MS_LOGI(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::INFO)) { \ - printf("[INFO] %s|%d|%s|%s[%d]|: " #fmt "\r\n", mindspore::predict::TAG, getpid(), __FILE__, __func__, \ - __LINE__, ##args); \ - } \ - } -#define MS_LOGW(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::WARN)) { \ - printf("[WARN] %s|%d|%s|%s[%d]|: " #fmt "\r\n", mindspore::predict::TAG, getpid(), __FILE__, __func__, \ - __LINE__, ##args); \ - } \ - } -#define MS_LOGE(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::ERROR)) { \ - printf("[ERROR] %s|%d|%s|%s[%d]|: " #fmt "\r\n", mindspore::predict::TAG, getpid(), __FILE__, __func__, \ - __LINE__, ##args); \ - } \ - } -#endif - -#else - -#define MS_LOGD(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::DEBUG)) \ - __android_log_print(ANDROID_LOG_DEBUG, mindspore::predict::TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, \ - __LINE__, ##args); \ - } - -#define MS_LOGI(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::INFO)) \ - __android_log_print(ANDROID_LOG_INFO, mindspore::predict::TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, \ - __LINE__, ##args); \ - } - -#define MS_LOGW(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::WARN)) \ - __android_log_print(ANDROID_LOG_WARN, mindspore::predict::TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, \ - __LINE__, ##args); \ - } - -#define MS_LOGE(fmt, args...) \ - { \ - if (mindspore::predict::IsPrint(mindspore::predict::ERROR)) \ - __android_log_print(ANDROID_LOG_ERROR, mindspore::predict::TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, \ - __LINE__, ##args); \ - } - -#endif - -#define MS_LOG(severity) std::cout << std::endl -#define MS_DLOG(verboselevel) std::cout << std::endl -// Kill the process for safe exiting. -inline void KillProcess(const std::string &ret) { - MS_LOG(ERROR) << "mindspore Exit Tip:" << ret; - if (raise(SIGKILL) != 0) { - MS_LOGE("Send SIGKILL to kill process failed"); - } -} -} // namespace predict -} // namespace mindspore - -#define MS_ASSERT(expression) \ - do { \ - if (!(expression)) { \ - std::stringstream ss; \ - ss << "Assertion failed: " << #expression << ", file: " << __FILE__ << ", line: " << __LINE__; \ - mindspore::predict::KillProcess(ss.str()); \ - } \ - } while (0) - -#define MS_EXIT(ret) \ - do { \ - std::stringstream ss; \ - ss << (ret) << " ( file: " << __FILE__ << ", line: " << __LINE__ << " )."; \ - mindspore::predict::KillProcess(ss.str()); \ - } while (0) - -#define MS_PRINT_ERROR(fmt, args...) \ - printf(#fmt "\n", ##args); \ - MS_LOGE(fmt, ##args); - -#define MS_PRINT_INFO(fmt, args...) \ - printf(fmt "\n", ##args); \ - MS_LOGI(fmt, ##args); - -constexpr int LOG_CHECK_EVERY_FIRSTNUM = 10; -constexpr int LOG_CHECK_EVERY_NUM1 = 10; -constexpr int LOG_CHECK_EVERY_NUM2 = 100; -constexpr int LOG_CHECK_EVERY_NUM3 = 1000; -constexpr int LOG_CHECK_EVERY_NUM4 = 10000; - -#define LOG_CHECK_ID_CONCAT(word1, word2) word1##word2 - -#define LOG_CHECK_ID LOG_CHECK_ID_CONCAT(__FUNCTION__, __LINE__) - -#define LOG_CHECK_FIRST_N \ - [](uint32_t firstNum) { \ - static uint32_t LOG_CHECK_ID = 0; \ - ++LOG_CHECK_ID; \ - return (LOG_CHECK_ID <= firstNum); \ - } - -#define LOG_CHECK_EVERY_N1 \ - [](uint32_t firstNum, uint32_t num) { \ - static uint32_t LOG_CHECK_ID = 0; \ - ++LOG_CHECK_ID; \ - return ((LOG_CHECK_ID <= firstNum) || (LOG_CHECK_ID % num == 0)); \ - } - -#define LOG_CHECK_EVERY_N2 \ - [](uint32_t firstNum, uint32_t num1, uint32_t num2) { \ - static uint32_t LOG_CHECK_ID = 0; \ - ++LOG_CHECK_ID; \ - return ((LOG_CHECK_ID <= firstNum) || (LOG_CHECK_ID < num2 && LOG_CHECK_ID % num1 == 0) || \ - (LOG_CHECK_ID % num2 == 0)); \ - } - -#define LOG_CHECK_EVERY_N3 \ - [](uint32_t firstNum, uint32_t num1, uint32_t num2, uint32_t num3) { \ - static uint32_t LOG_CHECK_ID = 0; \ - ++LOG_CHECK_ID; \ - return ((LOG_CHECK_ID <= firstNum) || (LOG_CHECK_ID < num2 && LOG_CHECK_ID % num1 == 0) || \ - (LOG_CHECK_ID < num3 && LOG_CHECK_ID % num2 == 0) || (LOG_CHECK_ID % num3 == 0)); \ - } - -#define LOG_CHECK_EVERY_N4 \ - [](uint32_t firstNum, uint32_t num1, uint32_t num2, uint32_t num3, uint32_t num4) { \ - static uint32_t LOG_CHECK_ID = 0; \ - ++LOG_CHECK_ID; \ - return ((LOG_CHECK_ID <= firstNum) || (LOG_CHECK_ID < num2 && LOG_CHECK_ID % num1 == 0) || \ - (LOG_CHECK_ID < num3 && LOG_CHECK_ID % num2 == 0) || (LOG_CHECK_ID < num4 && LOG_CHECK_ID % num3 == 0) || \ - (LOG_CHECK_ID % num4 == 0)); \ - } - -#define LOG_CHECK_EVERY_N \ - []() { \ - static uint32_t LOG_CHECK_ID = 0; \ - ++LOG_CHECK_ID; \ - return ((LOG_CHECK_ID <= LOG_CHECK_EVERY_FIRSTNUM) || \ - (LOG_CHECK_ID < LOG_CHECK_EVERY_NUM2 && LOG_CHECK_ID % LOG_CHECK_EVERY_NUM1 == 0) || \ - (LOG_CHECK_ID < LOG_CHECK_EVERY_NUM3 && LOG_CHECK_ID % LOG_CHECK_EVERY_NUM2 == 0) || \ - (LOG_CHECK_ID < LOG_CHECK_EVERY_NUM4 && LOG_CHECK_ID % LOG_CHECK_EVERY_NUM3 == 0) || \ - (LOG_CHECK_ID % LOG_CHECK_EVERY_NUM4 == 0)); \ - } - -#endif // PREDICT_COMMON_MSLOG_H_ diff --git a/predict/common/op_utils.h b/predict/common/op_utils.h deleted file mode 100644 index 35f01edce3..0000000000 --- a/predict/common/op_utils.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_OP_UTILS_H_ -#define PREDICT_COMMON_OP_UTILS_H_ - -#include -#include -#include "schema/inner/ms_generated.h" - -namespace mindspore { -namespace predict { -inline OpT GetOpType(const OpDef &opDef) { return opDef.attr_type(); } - -inline OpT GetOpType(const NodeDef &nodeDef) { return GetOpType(*(nodeDef.opDef())); } - -inline std::string GetOpTypeName(const NodeDef &nodeDef) { return EnumNameOpT(GetOpType(nodeDef)); } - -inline std::string GetOpTypeName(const OpDef &opDef) { return EnumNameOpT(GetOpType(opDef)); } - -inline OpT GetOpType(const OpDefT &opDefT) { return opDefT.attr.type; } - -inline OpT GetOpType(const NodeDefT &nodeDefT) { return GetOpType(*(nodeDefT.opDef.get())); } - -inline std::string GetOpTypeName(const NodeDefT &nodeDefT) { return EnumNameOpT(GetOpType(nodeDefT)); } - -inline std::string GetOpTypeName(const OpDefT &opDefT) { return EnumNameOpT(GetOpType(opDefT)); } -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_OP_UTILS_H_ diff --git a/predict/common/option.h b/predict/common/option.h deleted file mode 100644 index ca72dde29b..0000000000 --- a/predict/common/option.h +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_OPTION_H_ -#define PREDICT_COMMON_OPTION_H_ - -#include -#include -#include "common/mslog.h" - -namespace mindspore { -namespace predict { -template -struct InnerSome { - explicit InnerSome(const T &t) : _t(std::move(t)) {} - - T _t; -}; - -template -InnerSome::type> Some(T &&t) { - return InnerSome::type>(std::forward(t)); -} - -struct None {}; - -template -class Option { - public: - Option() : state(NONE) {} - - explicit Option(const T &t) : data(t), state(SOME) {} - - explicit Option(T &&t) : data(std::move(t)), state(SOME) {} - - explicit Option(const InnerSome &some) : data(some._t), state(SOME) {} - - explicit Option(const None &none) : state(NONE) {} - - Option(const Option &that) : state(that.state) { - if (that.IsSome()) { - new (&data) T(that.data); - } - } - - virtual ~Option() = default; - - bool IsNone() const { return state == NONE; } - - bool IsSome() const { return state == SOME; } - - const T &Get() const & { - MS_ASSERT(IsSome()); - return data; - } - - T &Get() & { - MS_ASSERT(IsSome()); - return data; - } - - T &&Get() && { - MS_ASSERT(IsSome()); - return std::move(data); - } - - const T &&Get() const && { - MS_ASSERT(IsSome()); - return std::move(data); - } - - // oprerator override - Option &operator=(const Option &that) { - if (&that != this) { - if (IsSome()) { - data.~T(); - } - state = that.state; - if (that.IsSome()) { - new (&data) T(that.data); - } - } - - return *this; - } - - bool operator==(const Option &that) const { - return (IsNone() && that.IsNone()) || (IsSome() && that.IsSome() && data == that.data); - } - - bool operator!=(const Option &that) const { return !(*this == that); } - - bool operator==(const T &that) const { return IsSome() && data == that; } - - bool operator!=(const T &that) const { return !(*this == that); } - - private: - enum State { NONE = 0, SOME = 1 }; - - T data; - State state; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_OPTION_H_ diff --git a/predict/common/storage.cc b/predict/common/storage.cc deleted file mode 100644 index ade5861c74..0000000000 --- a/predict/common/storage.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/storage.h" -#include "flatbuffers/flatbuffers.h" -#include "common/mslog.h" -#include "common/file_utils.h" - -namespace mindspore { -namespace predict { -int Storage::Save(const GraphDefT &graph, const std::string &outputPath) { - flatbuffers::FlatBufferBuilder builder(flatSize); - auto offset = GraphDef::Pack(builder, &graph); - builder.Finish(offset); - int size = builder.GetSize(); - auto content = builder.GetBufferPointer(); - if (content == nullptr) { - MS_LOGE("GetBufferPointer nullptr"); - return RET_ERROR; - } - std::string realPath = RealPath(outputPath.c_str()); - if (realPath.empty()) { - MS_LOGE("Output file path '%s' is not valid", outputPath.c_str()); - return RET_ERROR; - } - - std::ofstream output(realPath, std::ofstream::binary); - if (!output.is_open()) { - MS_LOGE("ofstream open failed"); - return RET_ERROR; - } - output.write((const char *)content, size); - output.close(); - return RET_OK; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/storage.h b/predict/common/storage.h deleted file mode 100644 index fc612ffb6b..0000000000 --- a/predict/common/storage.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_STORAGE_H_ -#define PREDICT_COMMON_STORAGE_H_ - -#include -#include -#include "include/errorcode.h" -#include "flatbuffers/flatbuffers.h" -#include "schema/inner/ms_generated.h" - -namespace mindspore { -namespace predict { -class Storage { - public: - int Save(const GraphDefT &graph, const std::string &outputPath); - const int flatSize = 1024; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_STORAGE_H_ diff --git a/predict/common/utils.cc b/predict/common/utils.cc deleted file mode 100644 index b186a4b8d8..0000000000 --- a/predict/common/utils.cc +++ /dev/null @@ -1,228 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/utils.h" - -namespace mindspore { -namespace predict { -uint64_t GetTimeUs() { - struct timespec ts = {0, 0}; - if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { - return 0; - } - // USECS_IN_SEC *NSECS_IN_USEC; - auto retval = static_cast((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); - return retval; -} - -static const unsigned int FP32_BIT_SIZE = 32; -static const unsigned int FP32_EXPONENT_BIAS = 127; -static const unsigned int FP32_SIGNIFICAND = 23; - -static const unsigned int FP32_EXPONENT_MAX = 255; - -static const unsigned int FP16_BIT_SIZE = 16; -static const unsigned int FP16_EXPONENT_BIAS = 15; -static const unsigned int FP16_SIGNIFICAND = 10; - -static const int FP16_EXPONENT_MAX = 30; -static const int FP16_EXPONENT_MIN = -10; - -float ShortToFloat32(int16_t srcValue) { - uint16_t expHalf16 = srcValue & 0x7C00; - int exp1 = static_cast(expHalf16); - uint16_t mantissa16 = srcValue & 0x03FF; - int mantissa1 = static_cast(mantissa16); - int sign = static_cast(srcValue & 0x8000); - sign = sign << FP16_BIT_SIZE; - - // nan or inf - if (expHalf16 == 0x7C00) { - // nan - if (mantissa16 > 0) { - int res = (0x7FC00000 | sign); - int *iRes = &res; - MS_ASSERT(iRes != nullptr); - auto fres = static_cast(*iRes); - return fres; - } - // inf - int res = (0x7F800000 | sign); - int *iRes = &res; - MS_ASSERT(iRes != nullptr); - auto fres = static_cast(*iRes); - return fres; - } - if (expHalf16 != 0) { - exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND); // exponents converted to float32 bias - int res = (exp1 | mantissa1); - res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND); - res = (res | sign); - int *iRes = &res; - - auto fres = static_cast(*iRes); - return fres; - } - - int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND); - xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND) - << FP32_SIGNIFICAND); // add the bias difference to xmm1 - xmm1 = xmm1 | sign; // Combine with the sign mask - - auto res = static_cast(mantissa1); // Convert mantissa to float - res *= static_cast(xmm1); - - return res; -} - -int16_t Float32ToShort(float srcValue) { - auto srcValueBit = static_cast(srcValue); - int sign = srcValueBit >> (FP32_BIT_SIZE - 1); - int mantissa = srcValueBit & 0x007FFFFF; - // exponent - int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS; - int16_t res; - if (exp > 0 && exp < FP16_EXPONENT_MAX) { - // use rte rounding mode, round the significand, combine sign, exponent and significand into a short. - res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) | - ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } else if (srcValueBit == 0) { - res = 0; - } else { - if (exp <= 0) { - if (exp < FP16_EXPONENT_MIN) { - // value is less than min half float point - res = 0; - } else { - // normalized single, magnitude is less than min normal half float point. - mantissa = (mantissa | 0x00800000) >> (1 - exp); - // round to nearest - if ((mantissa & 0x00001000) > 0) { - mantissa = mantissa + 0x00002000; - } - // combine sign & mantissa (exp is zero to get denormalized number) - res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } - } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) { - if (mantissa == 0) { - // input float is infinity, return infinity half - res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; - } else { - // input float is NaN, return half NaN - res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } - } else { - // exp > 0, normalized single, round to nearest - if ((mantissa & 0x00001000) > 0) { - mantissa = mantissa + 0x00002000; - if ((mantissa & 0x00800000) > 0) { - mantissa = 0; - exp = exp + 1; - } - } - if (exp > FP16_EXPONENT_MAX) { - // exponent overflow - return infinity half - res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; - } else { - // combine sign, exp and mantissa into normalized half - res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) | - (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } - } - } - return res; -} -std::string Remove(const std::string &from, const std::string &subStr, Mode mode) { - std::string result = from; - if (mode == PREFIX) { - if (from.substr(0, subStr.length()) == subStr) { - result = from.substr(subStr.size()); - } - } else if (mode == SUFFIX) { - if (from.rfind(subStr) == from.size() - subStr.size()) { - result = from.substr(0, from.size() - subStr.size()); - } - } else { - size_t index; - while ((index = result.find(subStr)) != std::string::npos) { - result = result.erase(index, subStr.size()); - } - } - - return result; -} - -std::vector StrSplit(const std::string &str, const std::string &pattern) { - std::string::size_type pos; - std::vector result; - std::string tmpStr(str + pattern); - std::string::size_type size = tmpStr.size(); - - for (std::string::size_type i = 0; i < size; i++) { - pos = tmpStr.find(pattern, i); - if (pos < size) { - std::string s = tmpStr.substr(i, pos - i); - result.push_back(s); - i = pos + pattern.size() - 1; - } - } - return result; -} - -std::vector Tokenize(const std::string &src, const std::string &delimiters, - const Option &maxTokenNum) { - if (maxTokenNum.IsSome() && maxTokenNum.Get() == 0) { - return {}; - } - - std::vector tokens; - size_t offset = 0; - - while (true) { - size_t nonDelimiter = src.find_first_not_of(delimiters, offset); - if (nonDelimiter == std::string::npos) { - break; - } - size_t delimiter = src.find_first_of(delimiters, nonDelimiter); - if (delimiter == std::string::npos || (maxTokenNum.IsSome() && tokens.size() == maxTokenNum.Get() - 1)) { - tokens.push_back(src.substr(nonDelimiter)); - break; - } - - tokens.push_back(src.substr(nonDelimiter, delimiter - nonDelimiter)); - offset = delimiter; - } - return tokens; -} - -void ShortToFloat32(const int16_t *srcdata, float *dstdata, size_t elementSize) { - MS_ASSERT(srcdata != nullptr); - MS_ASSERT(dstdata != nullptr); - for (size_t i = 0; i < elementSize; i++) { - dstdata[i] = ShortToFloat32(srcdata[i]); - } -} - -void Float32ToShort(const float *srcdata, int16_t *dstdata, size_t elementSize) { - MS_ASSERT(srcdata != nullptr); - MS_ASSERT(dstdata != nullptr); - for (size_t i = 0; i < elementSize; i++) { - dstdata[i] = Float32ToShort(srcdata[i]); - } -} -} // namespace predict -} // namespace mindspore diff --git a/predict/common/utils.h b/predict/common/utils.h deleted file mode 100644 index e7d44fe982..0000000000 --- a/predict/common/utils.h +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_COMMON_UTILS_H_ -#define PREDICT_COMMON_UTILS_H_ - -#include -#include -#include -#include -#include -#include "common/mslog.h" -#include "common/option.h" -#include "include/errorcode.h" - -namespace mindspore { -namespace predict { -const int USEC = 1000000; -const int MSEC = 1000; - -uint64_t GetTimeUs(); - -int16_t Float32ToShort(float srcValue); - -float ShortToFloat32(int16_t srcValue); - -void ShortToFloat32(const int16_t *srcData, float *dstData, size_t elementSize); - -void Float32ToShort(const float *srcData, int16_t *dstData, size_t elementSize); - -template -bool IsContain(const std::vector &vec, T element) { - for (auto iter = vec.begin(); iter != vec.end(); iter++) { - if (*iter == element) { - return true; - } - } - return false; -} - -const char WHITESPACE[] = "\t\n\v\f\r "; -const char STR_TRUE[] = "true"; -const char STR_FALSE[] = "false"; - -template -Option ToString(T t) { - std::ostringstream out; - out << t; - if (!out.good()) { - return Option(None()); - } - - return Option(out.str()); -} - -template <> -inline Option ToString(bool value) { - return value ? Option(STR_TRUE) : Option(STR_FALSE); -} - -// get the file name from a given path -// for example: "/usr/bin", we will get "bin" -inline std::string GetFileName(const std::string &path) { - char delim = '/'; - - size_t i = path.rfind(delim, path.length()); - if (i != std::string::npos) { - return (path.substr(i + 1, path.length() - i)); - } - - return ""; -} - -// trim the white space character in a string -// see also: macro WHITESPACE defined above -inline void Trim(std::string *input) { - if (input == nullptr) { - return; - } - if (input->empty()) { - return; - } - - input->erase(0, input->find_first_not_of(WHITESPACE)); - input->erase(input->find_last_not_of(WHITESPACE) + 1); -} - -// to judge whether a string is starting with prefix -// for example: "hello world" is starting with "hello" -inline bool StartsWithPrefix(const std::string &source, const std::string &prefix) { - if (source.length() < prefix.length()) { - return false; - } - - return (source.compare(0, prefix.length(), prefix) == 0); -} - -// split string -std::vector StrSplit(const std::string &str, const std::string &pattern); - -// tokenize string -std::vector Tokenize(const std::string &src, const std::string &delimiters, - const Option &maxTokenNum = Option(None())); - -enum Mode { PREFIX, SUFFIX, ANY }; - -// remove redundant character -std::string Remove(const std::string &from, const std::string &subStr, Mode mode = ANY); - -template -inline Option GenericParseValue(const std::string &value) { - T ret; - std::istringstream input(value); - input >> ret; - - if (input && input.eof()) { - return Option(ret); - } - - return Option(None()); -} - -template <> -inline Option GenericParseValue(const std::string &value) { - return Option(value); -} - -template <> -inline Option GenericParseValue(const std::string &value) { - if (value == "true") { - return Option(true); - } else if (value == "false") { - return Option(false); - } - - return Option(None()); -} -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_COMMON_UTILS_H_ diff --git a/predict/include/context.h b/predict/include/context.h deleted file mode 100644 index cb30f1f8c2..0000000000 --- a/predict/include/context.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_INCLUDE_CONTEXT_H_ -#define PREDICT_INCLUDE_CONTEXT_H_ - -#include -#include "dlpack/dlpack.h" -#include "include/tensor.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -///\brief Resource management definition of MindSpore predict. -class MSPREDICT_API Context { - public: - ///\brief Constructor of MindSpore predict context using default value for parameters. - /// - ///\return Instance of MindSpore predict context. - Context(); - - ///\brief Custum constructor of MindSpore predict context using input value for parameters. - /// - ///\param[in] threadNum The number of thread during the runtime. - ///\param[in] allocator The memory management during the runtime - ///\param[in] deviceCtx The device information during the runtime. - /// - ///\return Instance of MindSpore predict context. - Context(int threadNum, std::shared_ptr allocator, DLContext deviceCtx); - - ///\brief Destructor of MindSpore predict context. - virtual ~Context(); - - public: - DLContext deviceCtx; - int threadNum = 1; - std::shared_ptr allocator; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_INCLUDE_CONTEXT_H_ diff --git a/predict/include/errorcode.h b/predict/include/errorcode.h deleted file mode 100755 index 5487673f16..0000000000 --- a/predict/include/errorcode.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_INCLUDE_ERRORCODE_H_ -#define PREDICT_INCLUDE_ERRORCODE_H_ - -namespace mindspore { -namespace predict { -using STATUS = int; - -/* Success */ -constexpr int RET_OK = 0; /**< No error occurs. */ - -/* Common error code, range: [-1, -100]*/ -constexpr int RET_ERROR = -1; /**< Common error code. */ -constexpr int RET_NULL_PTR = -2; /**< NULL pointer returned.*/ -constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/ -constexpr int RET_NO_CHANGE = -4; /**< No change. */ - -/* Executor error code, range: [-101,-200] */ -constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to checking range. */ -constexpr int RET_INPUT_TENSOR_ERROR = -102; /**< Failed to checking input tensor. */ -constexpr int RET_REENTRANT_ERROR = -103; /**< Exist executor running. */ - -/* Graph error code, range: [-201,-300] */ -constexpr int RET_GRAPH_FILE_ERR = -201; /**< Failed to verify graph file. */ - -/* Node error code, range: [-301,-400] */ -constexpr int RET_NOT_FIND_OP = -301; /**< Failed to find OP. */ -constexpr int RET_INVALID_OP_NAME = -302; /**< Invalid OP name. */ -constexpr int RET_INVALID_OP_ATTR = -303; /**< Invalid OP attr. */ -constexpr int RET_OP_EXECUTE_FAILURE = -304; /**< Failed to execution OP. */ - -/* Tensor error code, range: [-401,-500] */ -constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */ -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_INCLUDE_ERRORCODE_H_ diff --git a/predict/include/session.h b/predict/include/session.h deleted file mode 100644 index 76fa2c4d6e..0000000000 --- a/predict/include/session.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_INCLUDE_SESSION_H_ -#define PREDICT_INCLUDE_SESSION_H_ - -#include -#include -#include -#include -#include -#include "include/context.h" -#include "include/tensor.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -using NODE_ID = std::string; - -///\brief Graph defined by MindSpore predict. -/// -///\note -/// The caller does not need to care about detailed implementation of this class, so just list the class name here. -class Graph; - -///\brief GraphExecution defined by MindSpore predict. -/// -///\note -/// The caller does not need to care about detailed implementation of this class, so just list the class name here. -class GraphExecution; - -///\brief MindSpore predict session. -/// -/// This class represents session of MindSpore predict. -/// -///\note -/// The caller needs to allocate and free memory of inputs and outputs. -/// New Session is not suggested, please use CreateSession function to create new session class. -class MSPREDICT_API Session { - public: - ///\brief Constructor of MindSpore predict session. - /// - ///\param[in] ctx The context of the session. - /// - ///\return Instance of MindSpore predict session. - explicit Session(const Context &ctx); - - ///\brief Destructor of MindSpore predict session. - ~Session(); - - ///\brief Init the session. - /// - ///\param[in] ctx The context of the session. - ///\param[in] size The size of the session. - ///\param[in] graphBuf The buffer of the graph, used for build session. - /// - ///\return Return RET_OK if the initialization is success, otherwhise return RET_ERROR. - int Init(const char *graphBuf, size_t size); - - ///\brief Get the input of session. - /// - ///\return Input node's input tensors if found, empty vector otherwise. - /// - ///\note - /// The caller needs to allocate and free memory of inputs. - std::vector GetInput(); - - ///\brief Run the session. - /// - ///\param[in] inputs The input of the session. - /// - ///\return Return RET_OK if run success, otherwhise return RET_ERROR. - ///\note - /// Currently input tensors' data format only support FORMAT_NCHW. - /// Currently input tensors' data type only support FLOAT. - int Run(const std::vector &inputs); - - ///\brief Get the output of session. - /// - ///\param[in] nodeName Given output node name. - /// - ///\return Output node's output tensors if found, empty vector otherwise. - /// - ///\note - /// The caller needs to free memory of outputs. - std::vector GetOutput(const std::string &nodeName); - - ///\brief Get the all output of session. - /// - ///\return Every output node's output tensors. - /// - ///\note - /// The caller needs to free memory of outputs. - std::map> GetAllOutput(); - - protected: - ///\brief Init the executor. - /// - ///\return Return RET_OK if the initialization is success, otherwhise return RET_ERROR. - int InitExecutor(); - - const Context &_ctx; - Graph *_graph = nullptr; - GraphExecution *_executor = nullptr; - bool reinitExecutor = true; -}; - -///\brief MindSpore predict neural network session create function -/// -/// This function used to create MindSpore predict neural network session, which will be used to run the neural network. -/// -///\param[in] sessionName The name of the session. -///\param[in] graphBuf The buffer of the graph, used for build session. -///\param[in] size The size of the session. -///\param[in] ctx The context of the session. -/// -///\return Instance of MindSpore predict session. -/// -///\note -/// The caller needs to allocate and free memory of graph buffer. -std::shared_ptr MSPREDICT_API CreateSession(const char *graphBuf, size_t size, const Context &ctx); -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_INCLUDE_SESSION_H_ diff --git a/predict/include/tensor.h b/predict/include/tensor.h deleted file mode 100644 index 8a608b486d..0000000000 --- a/predict/include/tensor.h +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_INCLUDE_TENSOR_H_ -#define PREDICT_INCLUDE_TENSOR_H_ - -#include -#include -#include "dlpack/dlpack.h" -#include "schema/inner/ms_generated.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -///\brief Allocator definition of MindSpore predict. -class Allocator; - -///\brief Tensor definition of MindSpore predict. -class MSPREDICT_API Tensor { - public: - ///\brief Constructor of MindSpore predict tensor. - /// - ///\param[in] tensor Define the parameters of the tensor. - ///\param[in] copyData Malloc data for the tensor, and copy origin data from - /// input tensor. - /// - ///\return Instance of MindSpore predict tensor. - Tensor(const Tensor &tensor, bool copyData = false); - - ///\brief Constructor of MindSpore predict tensor. - /// - ///\param[in] dt Data Type of the tensor, see introduction to 'enum DataType' - /// for supported type. - ///\param[in] dims Dimension Values such as height and width, which defined - /// the shape of the tensor. - ///\param[in] format Tensor format, see introduction to 'enum Format' for - /// supported format. - ///\param[in] data Data of the tensor. - /// - ///\return Instance of MindSpore predict tensor. - /// - ///\note - /// Length of data should align with dt, format and dims, otherwise the - /// application might run into unexpected error, - /// such as segment fault. - /// For example, dt is DT_FLOAT, format is FORMAT_NCHW, dims is [1,3,300,300], - /// then minimum length of data should - /// be 1 * 3 * 300 * 300 * sizeof(float). - Tensor(DataType dt, const std::vector &dims, Format format, void *data); - - ///\brief Destructor of MindSpore predict tensor. - ~Tensor(); - - ///\brief Get MindSpore predict tensor. - /// - ///\param[in] Definition of the tensor. - /// - ///\return Address of MindSpore predict tensor. - static Tensor *CopyFromTensorDef(const TensorDef &tensordef); - - ///\brief Get dtype of MindSpore predict tensor. - /// - ///\return Dtype of MindSpore predict tensor. - DLDataType GetTensorDtype() const; - - ///\brief Get data of MindSpore predict tensor. - /// - ///\return Address of MindSpore predict tensor data. - void *GetData() const; - - ///\brief Set data of MindSpore predict tensor. - /// - ///\param[in] data Address for data of the MindSpore predict tensor instance. - /// - ///\note - /// Length of data should align with dt, format and dims, otherwise the - /// application might run into unexpected error, - /// such as segment fault. - /// For example, dt is DT_FLOAT, format is FORMAT_NCHW, dims is [1,3,300,300], - /// then minimum length of data should - /// be 1 * 3 * 300 * 300 * sizeof(float). - void SetData(void *data); - - ///\brief Get data type of MindSpore predict tensor. - /// - ///\return Data Type of the tensor. - DataType GetDataType() const; - - ///\brief Set data type of MindSpore predict tensor. - /// - ///\param[in] dt Data Type of the tensor, see introduction to 'enum DataType' - /// for supported type. - void SetDataType(DataType dt); - - ///\brief Get number of dimension of MindSpore predict tensor. - /// - ///\return Number of dimension of the MindSpore predict tensor. - int GetNDim() const; - - ///\brief Get dimension of MindSpore predict tensor. - /// - ///\return Dimension of the MindSpore predict tensor. - std::vector GetDims() const; - - ///\brief Set dimension of MindSpore predict tensor. - /// - ///\param[in] dims Vector that has values of dimension. - void SetDims(const std::vector &dims); - - ///\brief Get format of MindSpore predict tensor. - /// - ///\return Format of the MindSpore predict tensor. - Format GetFormat() const { return format; } - - ///\brief Set format of MindSpore predict tensor. - /// - ///\param[in] format Format of the tensor. - void SetFormat(Format format) { this->format = format; } - - ///\brief Get reference count of MindSpore predict tensor. - /// - ///\return Reference count of the MindSpore predict tensor. - int RefCount() { return refCount; } - - ///\brief Increase reference count of MindSpore predict tensor. - /// - ///\param[in] ref The increase of the reference count. - void AddRef(int ref) { refCount += ref; } - - ///\brief Decrease reference count of MindSpore predict tensor. - /// - ///\param[in] ref The decrease of the reference count. - void DefRef(int ref) { refCount -= ref; } - - ///\brief Get element size of MindSpore predict tensor. - /// - ///\return Element size of MindSpore predict tensor. - size_t GetElementSize() const; - - ///\brief Get data size of MindSpore predict tensor. - /// - ///\return Data size of MindSpore predict tensor. - size_t GetDataSize() const; - - ///\brief Get element size of MindSpore predict tensor in NC4HW4 format. - /// - ///\param[in] isNhwc Whether the current format is NHWC. - /// - ///\return Element size of MindSpore predict tensor in NC4HW4 format. - size_t GetNC4HW4ElementSize(bool isNhwc); - - ///\brief Get data size of MindSpore predict tensor in NC4HW4 format. - /// - ///\param[in] isNhwc Whether the current format is NHWC. - /// - ///\return Data size of MindSpore predict tensor in NC4HW4 format. - size_t GetNC4HW4DataSize(bool isNhwc); - - ///\brief Malloc data for the MindSpore predict tensor. - /// - ///\param[in] allocator The malloc source for data. - ///\param[in] refCount The reference count of the data. - /// - ///\return Return RET_OK if the data is successfully allocated, otherwhise return RET_ERROR. - int MallocData(std::shared_ptr allocator = nullptr, int refCount = 0); - - ///\brief Free the MindSpore predict tensor. - void FreeTensor(); - - ///\brief Free the data of MindSpore predict tensor. - void ForceFreeData(); - - ///\brief Free the data of MindSpore predict tensor. - void FreeData(); - - ///\brief Compare data size of MindSpore predict tensor in NC4HW4 format. - /// - ///\param[in] dst The compare tensor. - /// - ///\return The result of fuction. - bool CompareShape(const Tensor &dst); - - ///\brief Compare shape of MindSpore predict tensor with another shape. - /// - ///\param[in] other The compare shape information. - /// - ///\return The result of function. - bool CompareShape(const std::vector &other); - - ///\brief Get instance of MindSpore predict tensor. - /// - ///\return Instance of MindSpore predict dlTensor. - DLTensor *GetDLTensor() { return &dlTensor; } - - ///\brief Get height of MindSpore predict tensor. - /// - ///\return Height of MindSpore predict tensor. - int64_t Height() const; - - ///\brief Get width of MindSpore predict tensor. - /// - ///\return Width of MindSpore predict tensor. - int64_t Width() const; - - ///\brief Get channel of MindSpore predict tensor. - /// - ///\return Channel of MindSpore predict tensor. - int64_t Channel() const; - - ///\brief Get batch of MindSpore predict tensor. - /// - ///\return Batch of MindSpore predict tensor. - int64_t Batch() const; - - ///\brief Get stride of MindSpore predict tensor. - /// - ///\param[in] index the index of stride. - /// - ///\return Stride of MindSpore predict tensor. - int64_t Stride(int index) const; - - ///\brief Set stride of MindSpore predict tensor by input. - /// - ///\param[in] index Index of stride - ///\param[in] stride The stride to set - void SetStride(int index, int64_t stride); - - ///\brief Set stride of MindSpore predict tensor by dims. - void SetStride(); - void SetScale(bool isScale = true); - - private: - bool isScale = false; - int refCount = 0; - int isConst; - Format format; - DLTensor dlTensor; - std::shared_ptr allocator = nullptr; - std::vector scale; - std::vector zeroPoint; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_INCLUDE_TENSOR_H_ diff --git a/predict/module/CMakeLists.txt b/predict/module/CMakeLists.txt deleted file mode 100644 index 9b978f1172..0000000000 --- a/predict/module/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(tvm_kernel) diff --git a/predict/module/tvm_kernel/.gitignore b/predict/module/tvm_kernel/.gitignore deleted file mode 100644 index 3b552d75da..0000000000 --- a/predict/module/tvm_kernel/.gitignore +++ /dev/null @@ -1,27 +0,0 @@ -# Created by .ignore support plugin -# - -# filter python -*.pyc - -# filter build -*.so -*.o - -# filter coverage -coverage/ - -# filter report -*.xml - -# filter tvm -3rdparty/ - -# filter build -build/ -cmake-build-debug/ -.idea/ -TFLite_Detection_PostProcess_CI -app_run -output -tvm diff --git a/predict/module/tvm_kernel/.gitmodules b/predict/module/tvm_kernel/.gitmodules deleted file mode 100644 index 8a987193a4..0000000000 --- a/predict/module/tvm_kernel/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "3rdparty/incubator-tvm"] - path = 3rdparty/incubator-tvm - url = https://github.com/dmlc/tvm.git - branch = v0.5 diff --git a/predict/module/tvm_kernel/CMakeLists.txt b/predict/module/tvm_kernel/CMakeLists.txt deleted file mode 100755 index b99e257c3e..0000000000 --- a/predict/module/tvm_kernel/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -cmake_minimum_required(VERSION 3.12.1) -project(autotensor LANGUAGES CXX) -set (MINDSPORE "${PROJECT_SOURCE_DIR}/../../..") -set (TVM_KERNEL_LITE "${PROJECT_SOURCE_DIR}/lite") -set (THIRDPARTY "${MINDSPORE}/third_party") -set (TVM_CLEAN_SOURCE "${THIRDPARTY}/incubator-tvm") -set (TVM_BUILD_SOURCE "${PROJECT_SOURCE_DIR}/incubator-tvm") -set (BUILD_DIR "${PROJECT_SOURCE_DIR}") -set (TVM_KERNEL_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) -set (TVM_OUTPUT_DIR ${TVM_KERNEL_OUTPUT_DIR}/incubator-tvm) - -set (LLVM_CONFIG $ENV{LLVM_PATH}) -if (NOT LLVM_CONFIG) - message(FATAL_ERROR "please set LLVM_PATH in env") -endif() -set (CMAKE_BUILD_TYPE "Release") - -include(${TVM_BUILD_SOURCE}/cmake/util/Util.cmake) -include(${TVM_BUILD_SOURCE}/cmake/util/FindLLVM.cmake) -if(EXISTS ${TVM_BUILD_SOURCE}/cmake/config.cmake) - include(${TVM_BUILD_SOURCE}/cmake/config.cmake) -endif() -add_subdirectory(${TVM_KERNEL_LITE}) -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - diff --git a/predict/module/tvm_kernel/lite/CMakeLists.txt b/predict/module/tvm_kernel/lite/CMakeLists.txt deleted file mode 100755 index 50e2bf5c9b..0000000000 --- a/predict/module/tvm_kernel/lite/CMakeLists.txt +++ /dev/null @@ -1,140 +0,0 @@ -cmake_minimum_required(VERSION 3.12) -set(CMAKE_CXX_STANDARD 14) - -if(ENABLE_PREDICT_ARM64) - set(TARGS "arm64") -elseif(ENABLE_PREDICT_ARM32) - set(TARGS "arm32") -else() - set(TARGS "x86") -endif() -message("TARGET is set to ${TARGS}") - -set(CMAKE_VERBOSE_MAKEFILE ON) -set(CMAKE_SKIP_RPATH TRUE) - -if(MSVC) - message("not support MSVC") -else(MSVC) - include(CheckCXXCompilerFlag) - check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) - if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("Build in Debug mode") - set(CMAKE_C_FLAGS "-O0 -g -Wall -Werror -fPIC [${CMAKE_C_FLAGS} -rdynamic") - set(CMAKE_CXX_FLAGS "-O0 -g -Wall -Werror -fPIC -std=c++11 ${CMAKE_CXX_FLAGS} -rdynamic") - else() - set(CMAKE_C_FLAGS "-D_FORTIFY_SOURCE=2 -O2 -fno-rtti -fvisibility=hidden -Wall -Werror -fPIC -fstack-protector-strong ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "-D_FORTIFY_SOURCE=2 -O2 -fno-rtti -fvisibility=hidden -Wall -Werror -fPIC -fstack-protector-strong -std=c++11 ${CMAKE_CXX_FLAGS}") - set(CMAKE_EXE_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack") - endif () - if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) - set(CMAKE_CXX_FLAGS "-Wall -Werror -faligned-new ${CMAKE_CXX_FLAGS}") - endif() - if (CODE_COVERAGE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -fprofile-arcs -ftest-coverage -O0") - endif() -endif(MSVC) - - -if("${TARGS}" STREQUAL "x86") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__x86_64__ -fno-strict-aliasing") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__x86_64__ -fno-strict-aliasing") -endif() - - -set(PRJ_SRC_DIR "${PROJECT_SOURCE_DIR}") -set(PRJ_KLIB_DIR "${PROJECT_SOURCE_DIR}") -set(PRJ_LITE_DIR "${PROJECT_SOURCE_DIR}/lite") - -# include directories -message("current PRJ DIR: ${PROJECT_SOURCE_DIR}") -message("current SUB_PRJ DIR: ${PRJ_SRC_DIR}") -message("current KLIB DIR: ${PRJ_KLIB_DIR}") -message("current PRJ_LITE_DIR: ${PRJ_LITE_DIR}") -message("CMAKE_CURRENT_BINARY_DIR: ${CMAKE_CURRENT_BINARY_DIR}") -set(DMLC_CORE "${TVM_BUILD_SOURCE}/3rdparty/dmlc-core") -set(DLPACK "${TVM_BUILD_SOURCE}/3rdparty/dlpack") -set(PREDICT "${PRJ_SRC_DIR}/../../") -set(SECUREC "${PRJ_SRC_DIR}/../../../third_party/securec") -message("include dir: ${DLPACK}/include") -include_directories(${DLPACK}/include) -include_directories(${DMLC_CORE}/include) -include_directories(${TVM_BUILD_SOURCE}/include) -include_directories(${TVM_BUILD_SOURCE}/src/pass) -include_directories(${PRJ_LITE_DIR}) -include_directories(${PRJ_LITE_DIR}/include) -include_directories(${PRJ_LITE_DIR}/../../..) -include_directories(${PRJ_LITE_DIR}/../../../include) -include_directories(${PRJ_LITE_DIR}/../../../src/runtime) -include_directories(${PRJ_LITE_DIR}/../../../common) -include_directories(${SECUREC}) -message("SECUREC: " "${SECUREC}/build/src") -include_directories(${PREDICT}) -include_directories(${PREDICT}/src) -include_directories(${PRJ_SRC_DIR}/../../../third_party/flatbuffers/include) -include_directories(${PRJ_SRC_DIR}/../../../third_party) -# Source file lists -file(GLOB_RECURSE TVM_KERNEL_SRC - src/api/*.cc - src/tflite/TFLite_Detection_PostProcess.cc) - -set (TVM_RUNTIME_FLG $ENV{TVM_RUNTIME_ON}) -if ("${TVM_RUNTIME_FLG}" STREQUAL "true") - message("Using TVM runtime function") - file(GLOB TVM_RUNTIME_SRCS - ${TVM_ROOT}/apps/howto_deploy/tvm_runtime_pack.cc) -else() - message("Using LITE runtime function") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DLITE_RUNTIME_ON -DTVM_RUNTIME_HEADER_ONLY -DLITE_THREAD_POOL_SHARED") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLITE_RUNTIME_ON -DTVM_RUNTIME_HEADER_ONLY -DLITE_THREAD_POOL_SHARED") - file(GLOB_RECURSE TVM_RUNTIME_SRCS - ${PREDICT}/src/runtime/*.cc) -endif() - -if("${TARGS}" STREQUAL "arm32" OR "${TARGS}" STREQUAL "arm64") - set(CMAKE_SKIP_BUILD_RPATH TRUE) - set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) - set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) -endif() - -set(LIB_X86_PATH "${PRJ_KLIB_DIR}/build/lib_x86") -set(LIB_ARM64_PATH "${PRJ_KLIB_DIR}/build/lib_arm64") -set(LIB_ARM32_PATH "${PRJ_KLIB_DIR}/build/lib_arm32") -if("${TARGS}" STREQUAL "x86") - set(KLIBS_PATH "${LIB_X86_PATH}") -elseif("${TARGS}" STREQUAL "arm64") - set(KLIBS_PATH "${LIB_ARM64_PATH}") -elseif("${TARGS}" STREQUAL "arm32") - set(KLIBS_PATH "${LIB_ARM32_PATH}") -else() - message(ERROR " not suport ${TARGS}") -endif() - -file(GLOB_RECURSE KERNEL_LIBS "${KLIBS_PATH}/*.o") -message("KERNEL_PATH= ${KLIBS_PATH}") - -add_compile_options(-DTVM_CUDA_RUNTIM=0) -add_compile_options(-DTVM_METAL_RUNTIM=0) -add_compile_options(-DTVM_OPENCL_RUNTIM=0) - -link_directories(${KLIBS_PATH}) - -add_library(tvm_runtime_pack STATIC ${TVM_RUNTIME_SRCS}) -add_library(kernel_manager STATIC ${TVM_KERNEL_SRC}) -add_library(tvm_kernel_static STATIC ${TVM_KERNEL_SRC} ${KERNEL_LIBS}) -add_library(tvm_kernel SHARED ${TVM_KERNEL_SRC} ${KERNEL_LIBS}) -set_target_properties(tvm_kernel PROPERTIES LINK_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack") - -set(KERNEL_LD_LIB tvm_runtime_pack dl) - -if("${TARGS}" STREQUAL "x86") - set(KERNEL_LD_LIB ${KERNEL_LD_LIB} pthread) -else() - set(ANDROID_ALLOW_UNDEFINED_SYMBOLS TRUE) -endif() - -target_link_libraries(tvm_kernel ${KERNEL_LD_LIB} libsecurec.a) -target_link_libraries(tvm_kernel_static OBJECT tvm_runtime_pack libsecurec.a) - -add_dependencies(tvm_kernel securec) diff --git a/predict/module/tvm_kernel/lite/include/lite/api/km_api.h b/predict/module/tvm_kernel/lite/include/lite/api/km_api.h deleted file mode 100644 index 7ccd4964cb..0000000000 --- a/predict/module/tvm_kernel/lite/include/lite/api/km_api.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this ${file} except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef PREDICT_MODULE_TVM_KERNEL_LITE_INCLUDE_LITE_API_KM_API_H_ -#define PREDICT_MODULE_TVM_KERNEL_LITE_INCLUDE_LITE_API_KM_API_H_ - -#include -#include -#include -#include -#include "schema/inner/ms_generated.h" -#include "schema/inner/op_generated.h" - -#define PUBLIC __attribute__((visibility("default"))) - -/*! - * \brief Call tvm kernel. - * \param fid tvm kernel id. - * \param tensors tvm kernel arguments. - * \return 0 if SUCCESS. - */ -PUBLIC int CallKernel(const std::string &fid, const std::vector &tensors); - -/*! - * \brief Get tvm kernel by id. - * \param fid tvm kernel id. - * \return std::function if SUCCESS else nullptr. - */ -PUBLIC std::function &)> GetKernel(const std::string &fid); - -/*! - * \brief Get tvm kernel by OpDef. - * \param opdef defined by predict schema. - * \param tensors. - * \param option. - * \return std::function if SUCCESS else nullptr. - */ -struct PUBLIC KernelOption { - int numThreads = 0; - std::string device; -}; - -PUBLIC std::function &)> GetKernel(const mindspore::predict::OpDef &opdef, - const std::vector &tensors, - const KernelOption &option); - -/*! - * \brief load TVM Kernel lib - * \param mode 0 indicate shared lib - * \param fname shared lib path when mode equals 0 - * \return 0 if SUCCESS - */ -PUBLIC void InitKernelManager(int mode, const std::string &fname); - -/* - * \brief config ThreadPool using mode - * \param mode: -1 using mid speed cpu first, 1 using higher speed cpu first - * \param nthreads: threads num to be used, can't exceed cpu num - * if mode==-1 bind mid cpu first - * if mode==1 bind higher cpu first - * if mode==0 no bind - * \param execute_self: cur thread do arithmetic or not - * execute_self: true cur thread do arithmetic work - * execute_self: false cur thread not do arithmetic work - */ -PUBLIC void ConfigThreadPool(int mode = -1, int nthreads = 2, bool execute_self = true); - -/* - * \brief provid simple api for mslite, mslite not care mode - */ -inline void CfgThreadPool(int nthread) { ConfigThreadPool(-1, nthread, true); } - -/* - * the Callback function to do cpu bind for master thread. - */ -PUBLIC void DoMasterThreadBind(bool bindflg); - -PUBLIC void DoAllThreadBind(bool ifBind); - -#undef PUBLIC - -#endif // PREDICT_MODULE_TVM_KERNEL_LITE_INCLUDE_LITE_API_KM_API_H_ diff --git a/predict/module/tvm_kernel/lite/python/__init__.py b/predict/module/tvm_kernel/lite/python/__init__.py deleted file mode 100644 index 5a51943fbe..0000000000 --- a/predict/module/tvm_kernel/lite/python/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Neural network operators""" -# from . import arm_cpu -# from . import at_ops diff --git a/predict/module/tvm_kernel/lite/python/arm_cpu/__init__.py b/predict/module/tvm_kernel/lite/python/arm_cpu/__init__.py deleted file mode 100644 index dce9d5e96c..0000000000 --- a/predict/module/tvm_kernel/lite/python/arm_cpu/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Schedule for ARM CPU""" - -from . import conv2d diff --git a/predict/module/tvm_kernel/lite/python/arm_cpu/conv2d.py b/predict/module/tvm_kernel/lite/python/arm_cpu/conv2d.py deleted file mode 100644 index ded792f689..0000000000 --- a/predict/module/tvm_kernel/lite/python/arm_cpu/conv2d.py +++ /dev/null @@ -1,470 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Conv2D schedule for ARM CPU""" -from __future__ import absolute_import as _abs - -import functools - -import tvm -from tvm import autotvm -import tvm.contrib.nnpack - -from topi.generic import schedule_conv2d_nchw -from topi.util import traverse_inline, get_const_tuple -from topi.nn import pad, conv2d -from topi.nn.util import get_const_int, get_pad_tuple - - -@autotvm.register_topi_compute(conv2d, "arm_cpu", ["asm"]) -def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype): - """TOPI compute callback for conv2d - - Parameters - ---------- - cfg: ConfigEntity - The config for this template - - data : tvm.Tensor - 4-D with shape [batch, in_channel, in_height, in_width] - - kernel : tvm.Tensor - 4-D with shape [num_filter, in_channel, filter_height, filter_width] or - pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height, - filter_width, num_filter_block] - - strides : list of two ints - [stride_height, stride_width] - - padding : list of two ints - [pad_height, pad_width] - - dilation : list of two ints - [dilation_height, dilation_width] - - out_dtype: str - The output type. This is used for mixed precision. - - Returns - ------- - output : tvm.Tensor - 4-D with shape [batch, out_channel, out_height, out_width] - """ - args = _gen_cfg(cfg, data, kernel, strides, padding, dilation, num_tile=2) - return _conv_spatial_pack_asm( - args, data, kernel, strides, padding, dilation, out_dtype - ) - - -@autotvm.register_topi_schedule(schedule_conv2d_nchw, "arm_cpu", ["asm"]) -def schedule_conv2d_nchw_arm_cpu(outs): - """TOPI schedule callback for conv2d - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of conv2d - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for conv2d. - """ - s = _conv_schedule_asm(outs) - return s - - -def _gen_cfg(cfg, data, kernel, strides, padding, dilation, num_tile): - """_gen_cfg""" - if len(kernel.shape) == 4: - co_, _, kh_, kw_ = get_const_tuple(kernel.shape) - else: # kernel tensor is pre packed - co_, _, kh_, kw_, vc_ = get_const_tuple(kernel.shape) - co_ = co_ * vc_ - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - n_, ci_, ih_, iw_ = get_const_tuple(data.shape) - - dilated_kernel_h = (kh_ - 1) * dilation_h + 1 - dilated_kernel_w = (kw_ - 1) * dilation_w + 1 - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w) - ) - hstr, wstr = strides if isinstance(strides, (tuple, list)) else (strides, strides) - oh_ = (ih_ + pad_top + pad_bottom - dilated_kernel_h) // hstr + 1 - ow_ = (iw_ + pad_left + pad_right - dilated_kernel_w) // wstr + 1 - - n, co, oh, ow = cfg.axis(n_), cfg.axis(co_), cfg.axis(oh_), cfg.axis(ow_) - ci, kh, kw = cfg.reduce_axis(ci_), cfg.reduce_axis(kh_), cfg.reduce_axis(kw_) - - if num_tile == 2: # for arm cpu - candidate_vc = [] - for iv in range(3, co_): - if co_ % iv == 0: - candidate_vc.append([co_ // iv, iv]) - candidate_vc.append([1, co_]) - co, vc = cfg.define_split( - "tile_co", co, num_outputs=2, policy="candidate", candidate=candidate_vc - ) - oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2) - ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2) - elif num_tile == 3: # for mali gpu - co, _, vc = cfg.define_split("tile_co", co, num_outputs=3) - oh, _, vh = cfg.define_split("tile_oh", oh, num_outputs=3) - ow, _, vw = cfg.define_split("tile_ow", ow, num_outputs=3) - else: - raise RuntimeError("Invalid num_tile") - - cfg.define_reorder( - "reorder_0", - [n, co, oh, ow, ci, kh, kw, vh, vw, vc], - policy="candidate", - candidate=[[n, co, oh, ow, ci, kh, kw, vh, vw, vc],], - ) - - vc_ = cfg["tile_co"].size[-1] - vh_ = cfg["tile_oh"].size[-1] - vw_ = cfg["tile_ow"].size[-1] - is_var = False - return (is_var, vh_, vw_, vc_) - -def _conv_spatial_pack_asm(args, data, kernel, strides, padding, - dilation, out_dtype): - """_conv_spatial_pack_asm""" - is_var, vh_, vw_, vc_ = args - - # create workload according to raw arguments - out_dtype = out_dtype or data.dtype - n_, ci_, ih_, iw_ = data.shape if is_var else get_const_tuple(data.shape) - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - if len(kernel.shape) == 4: - pre_packed = False - co_, _, kh_, kw_ = kernel.shape if is_var else get_const_tuple(kernel.shape) - else: # kernel tensor is pre packed - pre_packed = True - co_, _, kh_, kw_, vc_ = kernel.shape if is_var else get_const_tuple(kernel.shape) - co_ = co_ * vc_ - - dilated_kernel_h = (kh_ - 1) * dilation_h + 1 - dilated_kernel_w = (kw_ - 1) * dilation_w + 1 - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w) - ) - hstr, wstr = strides if isinstance(strides, (tuple, list)) else (strides, strides) - oh_ = (ih_ + pad_top + pad_bottom - dilated_kernel_h) // hstr + 1 - ow_ = (iw_ + pad_left + pad_right - dilated_kernel_w) // wstr + 1 - data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right]) - - oh_div = oh_ // vh_ - ow_div = ow_ // vw_ - kvshape = (co_ // vc_, ci_, kh_, kw_, vc_) - ovshape = (n_, co_ // vc_, oh_div, ow_div, vh_, vw_, vc_) - oshape = (n_, co_, oh_div * vh_, ow_div * vw_) - - if dilation_h != 1 or dilation_w != 1: - # undilate input data - dvshape = (n_, oh_ // vh_, ow_ // vw_, kh_, kw_, vh_, vw_, ci_) - data_vec = tvm.compute( - dvshape, - lambda n, h, w, kh, kw, vh, vw, ci: data_pad[n][ci][ - (h * vh_ + vh) * hstr + kh * dilation_h - ][(w * vw_ + vw) * wstr + kw * dilation_w], - name="data_vec_undilated", - ) - else: - dvshape = ( - n_, - oh_ // vh_, - ow_ // vw_, - (vh_ - 1) * hstr + kh_, - (vw_ - 1) * wstr + kw_, - ci_, - ) - data_vec = tvm.compute( - dvshape, - lambda n, h, w, vh, vw, ci: data_pad[n][ci][h * vh_ * hstr + vh][ - w * vw_ * wstr + vw - ], - name="data_vec", - ) - - if pre_packed: - kernel_vec = kernel - else: - kernel_vec = tvm.compute( - kvshape, - lambda co, ci, kh, kw, vc: kernel[co * vc_ + vc][ci][kh][kw], - name="kernel_vec", - ) - - ci = tvm.reduce_axis((0, ci_), name="ci") - kh = tvm.reduce_axis((0, kh_), name="kh") - kw = tvm.reduce_axis((0, kw_), name="kw") - - # asm begin---- - type_map = { - "int8": "int32", - "uint8": "uint32", - "float32": "float32", - "float16": "float16", - } - acum_dtype = type_map[data.dtype] - attrs = { - "SH": hstr, - "SW": wstr, - "PH": pad_top, - "PW": pad_left, - "DILA_H": dilation_h, - "DILA_W": dilation_w, - "VH": vh_, - "VW": vw_, - "VC": vc_, - "ACUM_DTYPE": acum_dtype, - } - # asm end---- - - if dilation_h != 1 or dilation_w != 1: - conv = tvm.compute( - ovshape, - lambda n, co, h, w, vh, vw, vc: tvm.sum( - data_vec[n, h, w, kh, kw, vh, vw, ci].astype(out_dtype) - * kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), - axis=[ci, kh, kw], - ), - name="conv", - attrs=attrs, - ) - else: - conv = tvm.compute( - ovshape, - lambda n, co, h, w, vh, vw, vc: tvm.sum( - data_vec[n, h, w, vh * hstr + kh, vw * wstr + kw, ci].astype(out_dtype) - * kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), - axis=[ci, kh, kw], - ), - name="conv", - attrs=attrs, - ) - - output = tvm.compute( - oshape, - lambda n, co, h, w: conv[n][co // vc_][h // vh_][w // vw_][h % vh_][w % vw_][ - co % vc_ - ], - name="output_unpack", - tag="asm_conv2d_output", - ) - - return output - - -def intrin_conv(args): - """intrin_conv""" - ( - ci_, - vh_, - vw_, - vc_, - kh_, - kw_, - sh_, - sw_, - dila_h, - dila_w, - dtype, - acum_dtype, - opname, - core_id, - ) = args - hstr, wstr = sh_, sw_ - ci_ = tvm.var("ci_") if ci_ is None else ci_ - kvshape = (ci_, kh_, kw_, vc_) - ovshape = (vh_, vw_, vc_) - if dila_h != 1 or dila_w != 1: - dvshape = (kh_, kw_, vh_, vw_, ci_) - else: - dvshape = ((vh_ - 1) * hstr + kh_, (vw_ - 1) * wstr + kw_, ci_) - - data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype) - kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype) - ci = tvm.reduce_axis((0, ci_), name="ci") - kh = tvm.reduce_axis((0, kh_), name="kh") - kw = tvm.reduce_axis((0, kw_), name="kw") - if dila_h != 1 or dila_w != 1: - conv = tvm.compute( - ovshape, - lambda vh, vw, vc: tvm.sum( - data_vec[kh, kw, vh, vw, ci].astype(acum_dtype) - * kernel_vec[ci, kh, kw, vc].astype(acum_dtype), - axis=[ci, kh, kw], - ), - name="conv", - ) - else: - conv = tvm.compute( - ovshape, - lambda vh, vw, vc: tvm.sum( - data_vec[vh * hstr + kh, vw * wstr + kw, ci].astype(acum_dtype) - * kernel_vec[ci, kh, kw, vc].astype(acum_dtype), - axis=[ci, kh, kw], - ), - name="conv", - ) - - stride_a = [ - functools.reduce(lambda x, y: x * y, dvshape[i + 1: len(dvshape)]) - for i in range(0, len(dvshape) - 1) - ] - stride_a.append(1) - stride_b = [ - functools.reduce(lambda x, y: x * y, kvshape[i + 1: len(kvshape)]) - for i in range(0, len(kvshape) - 1) - ] - stride_b.append(1) - stride_c = [ - functools.reduce(lambda x, y: x * y, ovshape[i + 1: len(ovshape)]) - for i in range(0, len(ovshape) - 1) - ] - stride_c.append(1) - - a_buffer = tvm.decl_buffer( - data_vec.shape, data_vec.dtype, name="A", offset_factor=1, strides=stride_a - ) - b_buffer = tvm.decl_buffer( - kernel_vec.shape, kernel_vec.dtype, name="B", offset_factor=1, strides=stride_b - ) - c_buffer = tvm.decl_buffer( - conv.shape, conv.dtype, name="C", offset_factor=1, strides=stride_c - ) - - def intrin_func(ins, outs): - aa, bb = ins - cc = outs[0] - - def _body(): - ib = tvm.ir_builder.create() - ib.emit( - tvm.call_extern( - "int32", - opname, - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"), - ci_, - vh_, - vw_, - vc_, - kh_, - sh_, - core_id, - ) - ) - return ib.get() - - return _body() - - return tvm.decl_tensor_intrin( - conv.op, intrin_func, binds={data_vec: a_buffer, kernel_vec: b_buffer, conv: c_buffer} - ) - - -def _schedule_asm(s, data_vec, kernel_vec, conv, output, last): - """schedule implementation""" - n, co, oh, ow, vh, vw, vc = s[conv].op.axis - - axis_extent = [] - for i in (vh, vw, vc): - axis_extent.append(get_const_int(i.dom.extent)) - reduce_extent = [] - for i in s[conv].op.reduce_axis[1:]: - reduce_extent.append(get_const_int(i.dom.extent)) - vh_, vw_, vc_ = axis_extent - - # schedule fusion - n, co, h, w = s[last].op.axis - co, vc = s[last].split(co, vc_) - oh, vh = s[last].split(h, vh_) - ow, vw = s[last].split(w, vw_) - s[last].reorder(n, co, oh, ow, vh, vw, vc) - if last != output: - s[output].compute_inline() - - s[conv].compute_at(s[last], ow) - - # mark parallel - s[last].parallel(co) - - if data_vec.op.name == "data_vec_undilated": - _, h, _, _, _, _, _, _ = s[data_vec].op.axis - else: - _, h, _, _, _, _ = s[data_vec].op.axis - s[data_vec].parallel(h) - - if kernel_vec.op.name == "kernel_vec": - co, _, _, _, _ = s[kernel_vec].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(co, "debug_skip_region") - else: - s[kernel_vec].parallel(co) - elif kernel_vec.op.name == "kernel_vec_conv2d_transpose": # for conv2d transpose - co, _, _, _, _ = s[kernel_vec].op.axis - s[kernel_vec].parallel(co) - - return s - - -def _conv_schedule_asm(outs): - """_conv_schedule_asm""" - s = tvm.create_schedule([x.op for x in outs]) - - def _callback(op): - if "asm_conv2d_output" in op.tag: - # schedule conv2d - output = op.output(0) - conv = op.input_tensors[0] - - sidx = 0 - if conv.op.input_tensors[0].name == "attr": - sidx = 1 - data_vec = conv.op.input_tensors[sidx] - data_pad = data_vec.op.input_tensors[0] - s[data_pad].compute_inline() - - kernel_vec = conv.op.input_tensors[sidx + 1] - if kernel_vec.op.name == "kernel_vec": - kernel = kernel_vec.op.input_tensors[0] - else: - kernel = kernel_vec - if (isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag): - s[kernel].compute_inline() - - if conv.op.input_tensors[0].name == "attr": - _schedule_asm(s, data_vec, kernel_vec, conv, output, outs[0]) - else: - _schedule_asm(s, data_vec, kernel_vec, conv, output, outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s diff --git a/predict/module/tvm_kernel/lite/python/arm_cpu/deconv.py b/predict/module/tvm_kernel/lite/python/arm_cpu/deconv.py deleted file mode 100644 index 4ed29900a2..0000000000 --- a/predict/module/tvm_kernel/lite/python/arm_cpu/deconv.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Conv2D_transpose of stride=2, kernel=2*2 schedule for ARM CPU""" -from __future__ import absolute_import as _abs - -import functools - -import tvm -from tvm import autotvm -import tvm.contrib.nnpack - -from topi.generic import schedule_conv2d_nchw -from topi.util import traverse_inline, get_const_tuple -from topi.nn import conv2d - - -@autotvm.register_topi_compute(conv2d, "arm_cpu", ["deconv"]) -def conv2d_arm_cpu_deconv(cfg, data, kernel, out_dtype): - """TOPI compute callback for conv2d - - Parameters - ---------- - cfg: ConfigEntity - The config for this template - - data : tvm.Tensor - 4-D with shape [batch, in_channel, in_height, in_width] - - kernel : tvm.Tensor - 4-D with shape [num_filter, in_channel, filter_height, filter_width] or - pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height, - filter_width, num_filter_block] - - out_dtype: str - The output type. This is used for mixed precision. - - Returns - ------- - output : tvm.Tensor - 4-D with shape [batch, out_channel, out_height, out_width] - """ - args = _gen_cfg_deconv(cfg, data, kernel, num_tile=2) - return _conv_spatial_pack_deconv( - args, data, kernel, out_dtype - ) - - -@autotvm.register_topi_schedule(schedule_conv2d_nchw, "arm_cpu", ["deconv"]) -def schedule_conv2d_nchw_arm_cpu_deconv(cfg, outs): - """TOPI schedule callback for conv2d - - Parameters - ---------- - cfg: ConfigEntity - The config for this template - - outs: Array of Tensor - The computation graph description of conv2d - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for conv2d. - """ - s = _conv_schedule_deconv(cfg, outs) - return s - - -def _gen_cfg_deconv(cfg, data, kernel, num_tile): - """generation config from input args""" - if len(kernel.shape) == 4: - co_, _, _, _ = get_const_tuple(kernel.shape) - else: # kernel tensor is pre packed - co_, _, _, _, vc_ = get_const_tuple(kernel.shape) - co_ = co_ * vc_ - - if len(data.shape) == 4: - _, ci_, ih_, iw_ = get_const_tuple(data.shape) - c4 = 4 - ci_ = ci_ // 4 - else: - _, ci_, ih_, iw_, c4 = get_const_tuple(data.shape) - - oh_ = ih_ * 2 - ow_ = iw_ * 2 - - co, oh, ow = cfg.axis(co_), cfg.axis(oh_), cfg.axis(ow_) - ci, ki = cfg.reduce_axis(ci_), cfg.reduce_axis(c4) - - if num_tile == 2: # for arm cpu - candidate_vc = [[co_ // c4, c4]] - co, vc = cfg.define_split( - "tile_co", co, num_outputs=2, policy="candidate", candidate=candidate_vc - ) - candidate_vw = [] - for iv in range(4, ow_ + 1): # [4, 6, 8, 12, 16, 24, 32, 40]: - if iv % 4 == 0 and (ow_ % iv == 0): - candidate_vw.append([ow_ // iv, iv]) - ow, vw = cfg.define_split( - "tile_ow", ow, num_outputs=2, policy="candidate", candidate=candidate_vw - ) - candidate_vh = [[1, 2]] - oh, vh = cfg.define_split( - "tile_oh", oh, num_outputs=2, policy="candidate", candidate=candidate_vh - ) - elif num_tile == 3: # for mali gpu - co, _, vc = cfg.define_split("tile_co", co, num_outputs=3) - oh, _, vh = cfg.define_split("tile_oh", oh, num_outputs=3) - ow, _, vw = cfg.define_split("tile_ow", ow, num_outputs=3) - else: - raise RuntimeError("Invalid num_tile") - - cfg.define_annotate("ann_reduce", [ci, ki], policy="try_unroll") - cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec") - - vc_ = cfg["tile_co"].size[-1] - vh_ = cfg["tile_oh"].size[-1] - vw_ = cfg["tile_ow"].size[-1] - is_var = False - return (is_var, vh_, vw_, vc_) - - -def _conv_spatial_pack_deconv(args, data, kernel, out_dtype): - """conv2d_arm_cpu_deconv inner implement""" - is_var, vh_, vw_, vc_ = args - # create workload according to raw arguments - out_dtype = out_dtype or data.dtype - if len(data.shape) == 4: - n_, ci_, ih_, iw_ = data.shape if is_var else get_const_tuple(data.shape) - c4 = 4 - ci_ = ci_ // c4 - else: - n_, ci_, ih_, iw_, c4 = data.shape if is_var else get_const_tuple(data.shape) - - if len(kernel.shape) == 4: - pre_packed = False - _, co_, kh_, kw_ = kernel.shape if is_var else get_const_tuple(kernel.shape) - else: # kernel tensor is pre packed - pre_packed = True - _, co_, kh_, kw_, vc_ = kernel.shape if is_var else get_const_tuple(kernel.shape) - co_ = co_ * c4 - - oh_ = ih_ * 2 - ow_ = iw_ * 2 - ow_div = ow_ // vw_ - oh_div = oh_ // vh_ - kvshape = (co_ // vc_, kh_, kw_, ci_, c4, c4) - ovshape = (n_, co_ // vc_, oh_div, ow_div, vh_, vw_, c4) - - dvshape = (n_, ih_ // (vh_ // 2), iw_ // (vw_ // 2), vh_ // 2, ci_, vw_ // 2, c4) - if len(data.shape) == 4: - data_vec = tvm.compute( - dvshape, - lambda n, h, w, vh, ci, vw, ki: data[n][ci * c4 + ki][h * vh_ // 2 + vh][ - w * vw_ // 2 + vw - ], - name="data_vec", - ) - else: - data_vec = tvm.compute( - dvshape, - lambda n, h, w, vh, ci, vw, ki: data[n][ci][h * vh_ // 2 + vh][ - w * vw_ // 2 + vw - ][ki], - name="data_vec", - ) - - if pre_packed: - kernel_vec = kernel - else: - kernel_vec = tvm.compute( - kvshape, - lambda co, kh, kw, ci, ki, vc: kernel[ci * c4 + ki][co * vc_ + vc][kh][kw], - name="kernel_vec", - ) - - ci = tvm.reduce_axis((0, ci_), name="ci") - ki = tvm.reduce_axis((0, c4), name="ki") - - type_map = { - "int8": "int32", - "uint8": "uint32", - "float32": "float32", - "float16": "float16", - } - acum_dtype = type_map[data.dtype] - attrs = { - "SH": 2, - "SW": 2, - "PH": 0, - "PW": 0, - "DILA_H": 1, - "DILA_W": 1, - "VH": vh_, - "VW": vw_, - "VC": vc_, - "ACUM_DTYPE": acum_dtype, - } - - conv = tvm.compute( - ovshape, - lambda n, co, h, w, vh, vw, vc: tvm.sum( - data_vec[n, h, w, vh // 2, ci, vw // 2, ki].astype(out_dtype) - * kernel_vec[co, (h * vh_ + vh) % 2, (w * vw_ + vw) % 2, ci, ki, vc].astype( - out_dtype - ), - axis=[ci, ki], - ), - name="conv", - attrs=attrs, - ) - if len(data.shape) == 4: - osshape = (n_, co_, oh_, ow_div * vw_) - output = tvm.compute( - osshape, - lambda n, co, h, w: conv[n][co // c4][h][w // vw_][w % vw_][co % c4], - name="output_unpack", - tag="deconv_conv2d_output", - ) - else: - osshape = (n_, co_ // c4, oh_, ow_div * vw_, c4) - output = tvm.compute( - osshape, - lambda n, co, h, w, vc: conv[n][co][h // vh_][w // vw_][h % vh_][w % vw_][vc], - name="output_unpack", - tag="deconv_conv2d_output", - ) - - return output - - -def intrin_deconv(args): - """deconv inner implement""" - ( - ci_, - vh_, - vw_, - vc_, - kh_, - kw_, - sh_, - sw_, - dila_h, - dila_w, - dtype, - acum_dtype, - opname, - core_id, - ) = args - hstr, wstr = sh_, sw_ - ci_ = tvm.var("ci_") if ci_ is None else ci_ - kvshape = (ci_, kh_, kw_, vc_) - ovshape = (vh_, vw_, vc_) - if dila_h != 1 or dila_w != 1: - dvshape = (kh_, kw_, vh_, vw_, ci_) - else: - dvshape = ((vh_ - 1) * hstr + kh_, (vw_ - 1) * wstr + kw_, ci_) - - data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype) - kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype) - ci = tvm.reduce_axis((0, ci_), name="ci") - kh = tvm.reduce_axis((0, kh_), name="kh") - kw = tvm.reduce_axis((0, kw_), name="kw") - if DILA_H != 1 or dila_w != 1: - conv = tvm.compute( - ovshape, - lambda vh, vw, vc: tvm.sum( - data_vec[kh, kw, vh, vw, ci].astype(acum_dtype) - * kernel_vec[ci, kh, kw, vc].astype(acum_dtype), - axis=[ci, kh, kw], - ), - name="conv", - ) - else: - conv = tvm.compute( - ovshape, - lambda vh, vw, vc: tvm.sum( - data_vec[vh * hstr + kh, vw * wstr + kw, ci].astype(acum_dtype) - * kernel_vec[ci, kh, kw, vc].astype(acum_dtype), - axis=[ci, kh, kw], - ), - name="conv", - ) - - stride_a = [ - functools.reduce(lambda x, y: x * y, dvshape[i + 1: len(dvshape)]) - for i in range(0, len(dvshape) - 1) - ] - stride_a.append(1) - stride_b = [ - functools.reduce(lambda x, y: x * y, kvshape[i + 1: len(kvshape)]) - for i in range(0, len(kvshape) - 1) - ] - stride_b.append(1) - stride_c = [ - functools.reduce(lambda x, y: x * y, ovshape[i + 1: len(ovshape)]) - for i in range(0, len(ovshape) - 1) - ] - stride_c.append(1) - - a_buffer = tvm.decl_buffer( - data_vec.shape, data_vec.dtype, name="A", offset_factor=1, strides=stride_a - ) - b_buffer = tvm.decl_buffer( - kernel_vec.shape, kernel_vec.dtype, name="B", offset_factor=1, strides=stride_b - ) - c_buffer = tvm.decl_buffer( - conv.shape, conv.dtype, name="C", offset_factor=1, strides=stride_c - ) - - def intrin_func(ins, outs): - aa, bb = ins - cc = outs[0] - - def _body(): - ib = tvm.ir_builder.create() - ib.emit( - tvm.call_extern( - "int32", - opname, - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"), - ci_, - vh_, - vw_, - vc_, - kh_, - sh_, - core_id, - ) - ) - return ib.get() - - return _body() - - return tvm.decl_tensor_intrin( - conv.op, intrin_func, binds={data_vec: a_buffer, kernel_vec: b_buffer, conv: c_buffer} - ) - - -def _schedule_deconv(cfg, s, data_vec, kernel_vec, conv, output, last): - """schedule implementation""" - is_tune = bool(isinstance(cfg, (tvm.autotvm.ConfigEntity, tvm.autotvm.ConfigSpace))) - if is_tune: - vh_ = cfg["tile_oh"].size[-1] - vw_ = cfg["tile_ow"].size[-1] - vc_ = cfg["tile_co"].size[-1] - cfg = { - "ci_": tvm.var("ci_"), - "VH": vh_, - "VW": vw_, - "VC": vc_, - "tile_oh": vh_, - "tile_ow": vw_, - "tile_co": vc_, - "tile_ci": 4, - "ann_reduce": cfg["ann_reduce"].anns, - "ann_spatial": cfg["ann_spatial"].anns, - } # ,'reorder_0':cfg['reorder_0'].perm} - else: - pass - n, co, oh, ow, vh, vw, vc = s[conv].op.axis - ci, ki = s[conv].op.reduce_axis - s[conv].reorder(n, co, oh, ow, ci, vw, ki, vc) - if cfg["ann_reduce"][0] == "unroll": - s[conv].unroll(ci) - elif cfg["ann_reduce"][0] == "vec": - s[conv].vectorize(ci) - if cfg["ann_reduce"][1] == "unroll": - s[conv].unroll(ki) - elif cfg["ann_reduce"][1] == "vec": - s[conv].vectorize(ki) - if cfg["ann_spatial"][0] == "vec": - s[conv].vectorize(vh) - elif cfg["ann_spatial"][0] == "unroll": - s[conv].unroll(vh) - if cfg["ann_spatial"][1] == "vec": - s[conv].vectorize(vw) - elif cfg["ann_spatial"][1] == "unroll": - s[conv].unroll(vw) - if cfg["ann_spatial"][2] == "vec": - s[conv].vectorize(vc) - elif cfg["ann_spatial"][2] == "unroll": - s[conv].unroll(vc) - - # schedule conv - attrs = conv.op.attrs - vh_, vw_, vc_ = (attrs["VH"].value, attrs["VW"].value, attrs["VC"].value) - - # schedule fusion - if len(s[last].op.axis) == 4: - n, co, h, w = s[last].op.axis - co, vc = s[last].split(co, vc_) - ow, vw = s[last].split(w, vw_) - oh, vh = s[last].split(h, vh_) - s[last].reorder(n, co, oh, ow, vh, vw, vc) - else: - n, co, h, w, vc = s[last].op.axis - oh, vh = s[last].split(h, vh_) - ow, vw = s[last].split(w, vw_) - s[last].reorder(n, co, oh, ow, vh, vw, vc) - if last != output and isinstance(output.op, tvm.tensor.ComputeOp): - s[output].compute_inline() - if cfg["ann_spatial"][0] == "vec": - s[last].vectorize(vh) - elif cfg["ann_spatial"][0] == "unroll": - s[last].unroll(vh) - if cfg["ann_spatial"][1] == "vec": - s[last].vectorize(vw) - elif cfg["ann_spatial"][1] == "unroll": - s[last].unroll(vw) - if cfg["ann_spatial"][2] == "vec": - s[last].vectorize(vc) - elif cfg["ann_spatial"][2] == "unroll": - s[last].unroll(vc) - - s[conv].compute_at(s[last], ow) - - # mark parallel - s[last].parallel(co) - - if data_vec.op.name == "data_vec_undilated": - _, h, _, _, _, _, _, _, _ = s[data_vec].op.axis - else: - _, h, _, _, _, _, _ = s[data_vec].op.axis - s[data_vec].parallel(h) - - co, _, _, _, _, vc = s[kernel_vec].op.axis - s[kernel_vec].parallel(co) - if cfg["ann_spatial"][2] == "vec": - s[kernel_vec].vectorize(vc) - elif cfg["ann_spatial"][2] == "unroll": - s[kernel_vec].unroll(vc) - return s - - -def _conv_schedule_deconv(cfg, outs): - """schedule_conv2d_nchw_arm_cpu_deconv inner implementation""" - s = tvm.create_schedule([x.op for x in outs]) - - def _callback(op): - if "deconv_conv2d_output" in op.tag: - # schedule conv2d - output = op.output(0) - conv = op.input_tensors[0] - - sidx = 0 - if conv.op.input_tensors[0].name == "attr": - sidx = 1 - data_vec = conv.op.input_tensors[sidx] - - kernel_vec = conv.op.input_tensors[sidx + 1] - if kernel_vec.op.name == "kernel_vec": - kernel = kernel_vec.op.input_tensors[0] - else: - kernel = kernel_vec - if (isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag): - s[kernel].compute_inline() - - _schedule_deconv(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s diff --git a/predict/module/tvm_kernel/lite/python/arm_cpu/depthwise_conv2d.py b/predict/module/tvm_kernel/lite/python/arm_cpu/depthwise_conv2d.py deleted file mode 100644 index f54076eb84..0000000000 --- a/predict/module/tvm_kernel/lite/python/arm_cpu/depthwise_conv2d.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Depthwise convolution schedule for ARM CPU""" - -import tvm -from tvm import autotvm - -from topi.generic import schedule_depthwise_conv2d_nchw -from topi.nn import depthwise_conv2d_nchw, pad -from topi.util import traverse_inline, get_const_tuple -from topi.nn.util import get_pad_tuple - -# register customized schedule for arm cpu. -@autotvm.register_topi_schedule( - schedule_depthwise_conv2d_nchw, ["arm_cpu", "cpu"], ["custom"] -) -def schedule_depthwise_conv2d_nchw_arm(cfg, outs): - """Schedule depthwise conv2d - - Parameters - ---------- - cfg: ConfigEntity - The configuration of this template - outs: Array of Tensor - The computation graph description of depthwise convolution2d - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for depthwise_conv2d nchw. - """ - s = _depthwise_schedule_spatial_pack(cfg, outs) - return s - - -@autotvm.register_topi_compute(depthwise_conv2d_nchw, ["arm_cpu", "cpu"], ["custom"]) -def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype): - """TOPI compute callback for depthwise_conv2d nchw - - Parameters - ---------- - cfg: ConfigEntity - The config for this template - - data : tvm.Tensor - 4-D with shape [batch, in_channel, in_height, in_width] - - kernel : tvm.Tensor - 4-D with shape [num_filter, multiplier, filter_height, filter_width] or - pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height, - filter_width, num_filter_block] - - strides : list of two ints - [stride_height, stride_width] - - padding : list of two ints - [pad_height, pad_width] - - dilation : list of two ints - [dilation_height, dilation_width] - - out_dtype: str - The output type. This is used for mixed precision. - - Returns - ------- - output : tvm.Tensor - 4-D with shape [batch, out_channel, out_height, out_width] - """ - - return _depthwise_spatial_pack( - cfg, data, kernel, strides, padding, dilation, out_dtype - ) - - -def _depthwise_spatial_pack(args, data, kernel, strides, padding, dilation, out_dtype): - """depthwise_conv2d_arm_cpu's inner implement""" - is_var, u_vh, u_vw, u_vc = args - out_dtype = out_dtype or data.dtype - - u_n, u_c, ih, iw = data.shape if is_var else get_const_tuple(data.shape) - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - if len(kernel.shape) == 4: - pre_packed = False - u_c, um, ukh, ukw = kernel.shape if is_var else get_const_tuple(kernel.shape) - else: # kernel tensor is pre packed - pre_packed = True - u_c, um, ukh, ukw, u_vc = kernel.shape if is_var else get_const_tuple(kernel.shape) - u_c = u_c * u_vc - - dilated_kernel_h = (ukh - 1) * dilation_h + 1 - dilated_kernel_w = (ukw - 1) * dilation_w + 1 - - pad_top, pad_left, pad_down, pad_right = get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w) - ) - hstr, wstr = strides if isinstance(strides, (tuple, list)) else (strides, strides) - u_oh = (ih + pad_top + pad_down - dilated_kernel_h) // hstr + 1 - u_ow = (iw + pad_left + pad_right - dilated_kernel_w) // wstr + 1 - # pack data - hpad = pad_top + pad_down - wpad = pad_left + pad_right - dopad = hpad != 0 or wpad != 0 - if dopad: - data_pad = pad( - data, - (0, 0, pad_top, pad_left), - (0, 0, pad_down, pad_right), - name="data_pad", - ) - else: - data_pad = data - - oh_div = u_oh // u_vh - ow_div = u_ow // u_vw - kvshape = (u_c // u_vc, um, ukh, ukw, u_vc) - ovshape = (u_n, u_c * um // u_vc, oh_div, u_ow // u_vw, u_vh, u_vw, u_vc) - oshape = (u_n, u_c * um, oh_div * u_vh, ow_div * u_vw) - - if dilation_h != 1 or dilation_w != 1: - # undilate input data - dvshape = (u_n, oh_div, ow_div, u_c, ukh, ukw, u_vh, u_vw) - data_vec = tvm.compute( - dvshape, - lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][ - (h * u_vh + vh) * hstr + kh * dilation_h - ][(w * u_vw + vw) * wstr + kw * dilation_w], - name="data_vec_undilated", - ) - else: - dvshape = (u_n, oh_div, ow_div, u_c, u_vh * hstr + ukh - 1, u_vw * wstr + ukw - 1) - data_vec = tvm.compute( - dvshape, - lambda n, h, w, c, vh, vw: data_pad[n][c][h * u_vh * hstr + vh][ - w * u_vw * wstr + vw - ], - name="data_vec", - ) - - if pre_packed: - kernel_vec = kernel - else: - kernel_vec = tvm.compute( - kvshape, - lambda co, m, kh, kw, vc: kernel[co * u_vc + vc][m][kh][kw], - name="kernel_vec", - ) - - kh = tvm.reduce_axis((0, ukh), name="kh") - kw = tvm.reduce_axis((0, ukw), name="kw") - - if dilation_h != 1 or dilation_w != 1: - conv = tvm.compute( - ovshape, - lambda n, co, h, w, vh, vw, vc: tvm.sum( - data_vec[n, h, w, (co * u_vc + vc) // um, kh, kw, vh, vw].astype(out_dtype) - * kernel_vec[co // um, co % um, kh, kw, vc].astype(out_dtype), - axis=[kh, kw], - ), - name="depthwise_conv", - ) - else: - conv = tvm.compute( - ovshape, - lambda n, co, h, w, vh, vw, vc: tvm.sum( - data_vec[ - n, h, w, (co * u_vc + vc) // um, vh * hstr + kh, vw * wstr + kw - ].astype(out_dtype) - * kernel_vec[co // um, co % um, kh, kw, vc].astype(out_dtype), - axis=[kh, kw], - ), - name="depthwise_conv", - ) - - output = tvm.compute( - oshape, - lambda n, co, h, w: conv[n][co // u_vc][h // u_vh][w // u_vw][h % u_vh][w % u_vw][ - co % u_vc - ], - name="output_unpack", - tag="spatial_depthwise_conv_nchw_output", - ) - return output - - -def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last): - """schedule implementation""" - u_vc = cfg["tile_co"].size[-1] if not isinstance(cfg, dict) else cfg["VC"] - u_vh = cfg["tile_oh"].size[-1] if not isinstance(cfg, dict) else cfg["VH"] - u_vw = cfg["tile_ow"].size[-1] if not isinstance(cfg, dict) else cfg["VW"] - - n, co, oh, ow, vh, vw, vc = s[conv].op.axis - kh, kw = s[conv].op.reduce_axis - - if data_vec.op.name == "data_vec_undilated": - _, _, dv_ow, _, _, _, _, _ = s[data_vec].op.axis - else: - _, _, dv_ow, _, _, _ = s[data_vec].op.axis - - data_pad = data_vec.op.input_tensors[0] - - if isinstance(data_pad.op, tvm.tensor.ComputeOp): - s[data_pad].vectorize(list(s[data_pad].op.axis)[-1]) - s[data_pad].compute_at(s[data_vec], dv_ow) - - s[data_vec].vectorize(list(s[data_vec].op.axis)[-1]) - s[data_vec].compute_at(s[conv], ow) - - # schedule conv - s[conv].reorder(n, co, oh, ow, kh, kw, vh, vw, vc) - s[conv].unroll(kh) - s[conv].unroll(vh) - s[conv].vectorize(vw) - s[conv].unroll(vc) - s[conv].parallel(co) - - n, co, h, w = s[last].op.axis - co, vc = s[last].split(co, u_vc) - oh, vh = s[last].split(h, u_vh) - ow, vw = s[last].split(w, u_vw) - if last != output: - s[output].compute_inline() - s[last].vectorize(vw) - s[last].unroll(vc) - else: - s[last].vectorize(vw) - s[conv].compute_at(s[last], oh) - - # mark parallel - s[last].parallel(co) - - if data_vec.op.name == "data_vec_undilated": - _, h, _, _, _, _, _, _ = s[data_vec].op.axis - else: - _, h, _, _, _, _ = s[data_vec].op.axis - s[data_vec].parallel(h) - - if kernel_vec.op.name == "kernel_vec": - co, _, _, _, _ = s[kernel_vec].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compliation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(co, "debug_skip_region") - else: - s[kernel_vec].parallel(co) - - return s - - -def _depthwise_schedule_spatial_pack(cfg, outs): - """schedule_depthwise_conv2d_nchw_arm's inner implement""" - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "spatial_depthwise_conv_nchw_output": - output = op.output(0) - conv = op.input_tensors[0] - data_vec = conv.op.input_tensors[0] - kernel_vec = conv.op.input_tensors[1] - if kernel_vec.op.name == "kernel_vec": - kernel = kernel_vec.op.input_tensors[0] - else: - kernel = kernel_vec - if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - - _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s diff --git a/predict/module/tvm_kernel/lite/python/arm_cpu/matmul.py b/predict/module/tvm_kernel/lite/python/arm_cpu/matmul.py deleted file mode 100644 index 6430f24f6f..0000000000 --- a/predict/module/tvm_kernel/lite/python/arm_cpu/matmul.py +++ /dev/null @@ -1,472 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Conv2D schedule for ARM CPU""" -from __future__ import absolute_import as _abs - -import functools - -import tvm -from tvm import autotvm -import tvm.contrib.nnpack - -from topi.generic import schedule_conv2d_nchw -from topi.util import traverse_inline -from topi.nn import conv2d - - -@autotvm.register_topi_compute(conv2d, "arm_cpu", ["matmul"]) -def matmul_arm_cpu(cfg, a_, b_, layout, out_dtype): - """TOPI compute callback for - - Parameters - ---------- - cfg: ConfigEntity - The config for this template - - a_ : tvm.Tensor - 2-D with shape [M, k_] - - b_ : tvm.Tensor - 2-D with shape [k_, N] - - out_dtype: str - The output type. This is used for mixed precision. - - Returns - ------- - output : tvm.Tensor - 4-D with shape [batch, out_channel, out_height, out_width] - """ - args = _gen_cfg(cfg, a_, b_) - return _matmul_spatial_pack_asm(args, a_, b_, layout, out_dtype) - - -@autotvm.register_topi_schedule(schedule_conv2d_nchw, "arm_cpu", ["matmul"]) -def schedule_matmul_arm_cpu(cfg, outs): - """TOPI schedule callback for conv2d - - Parameters - ---------- - cfg: ConfigEntity - The config for this template - - outs: Array of Tensor - The computation graph description of conv2d - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for conv2d. - """ - s = _matmul_schedule_asm(cfg, outs) - return s - - -def _gen_cfg(cfg, a_, b_): - """get best loginfo from cfg""" - if len(a_.shape) == 2: - w_, ci_ = get_const_tuple(a_.shape) - h_ = 1 - elif len(a_.shape) == 3: - _, ci_, w_ = get_const_tuple(a_.shape) - h_ = 1 - elif len(a_.shape) == 4: - _, ci_, h_, w_ = get_const_tuple(a_.shape) - else: - raise ValueError("not support shape: " + a_.shape) - - co_, k_ = get_const_tuple(b_.shape) - - oh, ow = cfg.axis(h_), cfg.axis(w_) - co = cfg.axis(co_) - k = cfg.reduce_axis(k_) - - oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2) - ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2) - oc, vc = cfg.define_split("tile_co", co, num_outputs=2) - - cfg.define_reorder( - "reorder_0", - [oc, oh, ow, k, vh, vw, vc], - policy="candidate", - candidate=[[oc, oh, ow, k, vh, vw, vc],], - ) - - vh_ = cfg["tile_oh"].size[-1] - vw_ = cfg["tile_ow"].size[-1] - vc_ = cfg["tile_co"].size[-1] - is_var = False - is_transpose = False - return (is_var, is_transpose, ci_, vh_, vw_, vc_) - - -def _matmul_spatial_pack_asm(args, a_, b_, layout, out_dtype): - """matmul_spatial_pack_asm's inner interace""" - is_var, is_transpose, ci_, vh_, vw_, vc_ = args - - # create workload according to raw arguments - out_dtype = out_dtype or a_.dtype - if layout == "NCHW": - batch, k_, h_, w_ = a_.shape if is_var else get_const_tuple(a_.shape) - n_, _ = b_.shape if is_var else get_const_tuple(b_.shape) - elif layout == "NCH": - batch, k_, h_ = a_.shape if is_var else get_const_tuple(a_.shape) - n_, _ = b_.shape if is_var else get_const_tuple(b_.shape) - w_ = 1 - elif layout == "NC": - w_, k_ = a_.shape if is_var else get_const_tuple(a_.shape) - n_, _ = b_.shape if is_var else get_const_tuple(b_.shape) - h_ = 1 - else: - raise ValueError("not support layout: " + layout) - - ki = tvm.reduce_axis((0, k_), name="ki") - type_map = { - "int8": "int32", - "uint8": "uint32", - "float32": "float32", - "float16": "float16", - } - acum_dtype = type_map[a_.dtype] - attrs = {"ci_": ci_, "vh_": vh_, "vw_": vw_, "vc_": vc_, "ACUM_DTYPE": acum_dtype} - - if layout == "NCHW": - h_div = h_ // vh_ - w_div = w_ // vw_ - n_div = n_ // vc_ - avshape = (batch, h_div, w_div, vh_, vw_, k_) - bvshape = (n_div, k_, vc_) - ovshape = (batch, n_div, h_div, w_div, vh_, vw_, vc_) - - a_vec = tvm.compute( - avshape, - lambda n, oh, ow, vh, vw, ci: a_[n][ci][oh * vh_ + vh][ow * vw_ + vw], - name="a_vec", - ) - b_vec = tvm.compute( - bvshape, lambda oc, ci, vc: b_[oc * vc_ + vc][ci], name="b_vec" - ) - - ma = tvm.compute( - ovshape, - lambda n, oc, oh, ow, vh, vw, vc: tvm.sum( - a_vec[n, oh, ow, vh, vw, ki].astype(out_dtype) - * b_vec[oc, ki, vc].astype(out_dtype), - axis=[ki], - ), - name="matmul", - attrs=attrs, - ) - - if is_transpose: - oshape = (batch, h_div * vh_, w_div * vw_, n_div * vc_) - - output = tvm.compute( - oshape, - lambda n, h, w, c: ma[n][c // vc_][h // vh_][w // vw_][h % vh_][w % vw_][ - c % vc_ - ], - name="output_unpack", - tag="asm_matmul_output", - ) - else: - oshape = (batch, n_div * vc_, h_div * vh_, w_div * vw_) - output = tvm.compute( - oshape, - lambda n, c, h, w: ma[n][c // vc_][h // vh_][w // vw_][h % vh_][w % vw_][ - c % vc_ - ], - name="output_unpack", - tag="asm_matmul_output", - ) - elif layout == "NCH": - w_div = w_ // vw_ - n_div = n_ // vc_ - avshape = (batch, w_div, vw_, k_) - bvshape = (n_div, k_, vc_) - ovshape = (batch, n_div, w_div, vw_, vc_) - oshape = (batch, n_div * vc_, w_div * vw_) - - a_vec = tvm.compute( - avshape, lambda b, om, vw, ci: a_[b][ci][om * vw_ + vw], name="a_vec" - ) - b_vec = tvm.compute( - bvshape, lambda on, ci, vc: b_[on * vc_ + vc][ci], name="b_vec" - ) - - ma = tvm.compute( - ovshape, - lambda b, on, om, vm, vn: tvm.sum( - a_vec[b, om, vm, ki].astype(out_dtype) - * b_vec[on, ki, vn].astype(out_dtype), - axis=[ki], - ), - name="matmul", - attrs=attrs, - ) - - output = tvm.compute( - oshape, - lambda b, n, m: ma[b][n // vc_][m // vw_][m % vw_][n % vc_], - name="output_unpack", - tag="asm_matmul_output", - ) - elif layout == "NC": - w_div = w_ // vw_ - n_div = n_ // vc_ - avshape = (w_div, vw_, k_) - bvshape = (n_div, k_, vc_) - ovshape = (w_div, n_div, vw_, vc_) - oshape = (w_div * vw_, n_div * vc_) - - a_vec = tvm.compute( - avshape, lambda om, vw, ci: a_[om * vw_ + vw][ci], name="a_vec" - ) - b_vec = tvm.compute( - bvshape, lambda on, ci, vc: b_[on * vc_ + vc][ci], name="b_vec" - ) - - ma = tvm.compute( - ovshape, - lambda om, on, vm, vn: tvm.sum( - a_vec[om, vm, ki].astype(out_dtype) - * b_vec[on, ki, vn].astype(out_dtype), - axis=[ki], - ), - name="matmul", - attrs=attrs, - ) - - output = tvm.compute( - oshape, - lambda m, n: ma[m // vw_][n // vc_][m % vw_][n % vc_], - name="output_unpack", - tag="asm_matmul_output", - ) - else: - raise ValueError("not support layout: " + layout) - - return output - - -def intrin_conv(args): - """intrin_conv is a conv inner interface""" - ( - ndim, - ci_, - vh_, - vw_, - vc_, - _, - _, - _, - _, - _, - _, - _, - _, - dtype, - acum_dtype, - opname, - core_id, - ) = args - ci_ = tvm.var("ci_") if ci_ is None else ci_ - kvshape = (ci_, vc_) - if ndim == 2: - dvshape = (vw_, ci_) - ovshape = (vw_, vc_) - - data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype) - kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype) - ci = tvm.reduce_axis((0, ci_), name="ci") - conv = tvm.compute( - ovshape, - lambda vw, vc: tvm.sum( - data_vec[vw, ci].astype(acum_dtype) - * kernel_vec[ci, vc].astype(acum_dtype), - axis=[ci], - ), - name="conv", - ) - else: - dvshape = (vh_, vw_, ci_) - ovshape = (vh_, vw_, vc_) - - data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype) - kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype) - ci = tvm.reduce_axis((0, ci_), name="ci") - conv = tvm.compute( - ovshape, - lambda vh, vw, vc: tvm.sum( - data_vec[vh, vw, ci].astype(acum_dtype) - * kernel_vec[ci, vc].astype(acum_dtype), - axis=[ci], - ), - name="conv", - ) - - stride_a = [ - functools.reduce(lambda x, y: x * y, dvshape[i + 1: len(dvshape)]) - for i in range(0, len(dvshape) - 1) - ] - stride_a.append(1) - stride_b = [ - functools.reduce(lambda x, y: x * y, kvshape[i + 1: len(kvshape)]) - for i in range(0, len(kvshape) - 1) - ] - stride_b.append(1) - stride_c = [ - functools.reduce(lambda x, y: x * y, ovshape[i + 1: len(ovshape)]) - for i in range(0, len(ovshape) - 1) - ] - stride_c.append(1) - - ab_ = tvm.decl_buffer( - data_vec.shape, data_vec.dtype, name="a_", offset_factor=1, strides=stride_a - ) - bb_ = tvm.decl_buffer( - kernel_vec.shape, kernel_vec.dtype, name="b_", offset_factor=1, strides=stride_b - ) - cb_ = tvm.decl_buffer( - conv.shape, conv.dtype, name="C", offset_factor=1, strides=stride_c - ) - - def intrin_func(ins, outs): - aa, bb = ins - cc = outs[0] - - def _body(): - b_ = tvm.ir_builder.create() - b_.emit( - tvm.call_extern( - "int32", - opname, - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"), - ci_, - vh_, - vw_, - vc_, - core_id, - ) - ) - return b_.get() - - return _body() - - return tvm.decl_tensor_intrin( - conv.op, intrin_func, binds={data_vec: ab_, kernel_vec: bb_, conv: cb_} - ) - - -def _schedule_asm(cfg, s, a_vec, b_vec, mat, output, last): - """schedule implementation""" - is_transpose = 0 if not isinstance(cfg, dict) else cfg["is_transpose"] - attrs = mat.op.attrs - vh_, vw_, vc_ = (attrs["vh_"].value, attrs["vw_"].value, attrs["vc_"].value) - - # axis split and reorder - if len(a_vec.shape) == 3: - ow, oc = s[last].op.axis - oc, vc = s[last].split(oc, vc_) - ow, vw = s[last].split(ow, vw_) - s[last].reorder(ow, oc, vw, vc) - s[last].vectorize(vc) - oh = ow = oc - elif len(a_vec.shape) == 4: - n, oc, ow, vw, vc = s[last].op.axis - oc, vc = s[last].split(oc, vc_) - ow, vw = s[last].split(ow, vw_) - s[last].reorder(n, oc, ow, vw, vc) - elif len(a_vec.shape) == 6: - if is_transpose: - n, oh, ow, oc = s[last].op.axis - else: - n, oc, oh, ow = s[last].op.axis - oc, vc = s[last].split(oc, vc_) - oh, vh = s[last].split(oh, vh_) - ow, vw = s[last].split(ow, vw_) - s[last].reorder(n, oc, oh, ow, vh, vw, vc) - else: - raise ValueError("not support a_vec: " + str(len(a_vec.shape))) - if last != output and isinstance(output.op, tvm.tensor.ComputeOp): - s[output].compute_inline() - - s[mat].compute_at(s[last], ow) - s[mat].vectorize(s[mat].op.axis[-1]) - - # mark parallel - s[last].parallel(oh) - - if len(a_vec.shape) == 3: - om, _, _ = s[a_vec].op.axis - s[a_vec].compute_at(s[last], ow) - s[a_vec].parallel(om) - elif len(a_vec.shape) == 4: - _, om, _, _ = s[a_vec].op.axis - s[a_vec].compute_at(s[last], ow) - s[a_vec].parallel(om) - else: - _, oh, _, _, _, _ = s[a_vec].op.axis - s[a_vec].parallel(oh) - s[a_vec].vectorize(s[a_vec].op.axis[-1]) - s[a_vec].compute_inline() - - oc, _, _ = s[b_vec].op.axis - s[b_vec].parallel(oc) - s[b_vec].vectorize(s[b_vec].op.axis[-1]) - s[b_vec].compute_inline() - return s - - -def _matmul_schedule_asm(cfg, outs): - """schedule_conv2d_nchw schedule implementation""" - s = tvm.create_schedule([x.op for x in outs]) - - def _callback(op): - if "asm_matmul_output" in op.tag: - # schedule conv2d - output = op.output(0) - mat = op.input_tensors[0] - - sidx = 0 - if mat.op.input_tensors[0].name == "attr": - sidx = 1 - a_vec = mat.op.input_tensors[sidx] - b_vec = mat.op.input_tensors[sidx + 1] - - def recurs_inline(a_): - if a_.op.input_tensors: - a1 = a_.op.input_tensors[0] - if a1.shape == a_.shape: - s[a1].compute_inline() - recurs_inline(a1) - - def recurs_inline_(a_): - if isinstance(a_, tvm.tensor.ComputeOp): - if a_.op.input_tensors: - a1 = a_.op.input_tensors[0] - s[a1].compute_inline() - recurs_inline_(a1) - - recurs_inline_(a_vec) - recurs_inline_(b_vec) - - _schedule_asm(cfg, s, a_vec, b_vec, mat, output, outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s diff --git a/predict/module/tvm_kernel/lite/python/at_ops/__init__.py b/predict/module/tvm_kernel/lite/python/at_ops/__init__.py deleted file mode 100644 index 274ad9a7e5..0000000000 --- a/predict/module/tvm_kernel/lite/python/at_ops/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Neural network operators""" -# from .at_lib import * -# from .at_gen import * diff --git a/predict/module/tvm_kernel/lite/python/at_ops/at_gen_strip.py b/predict/module/tvm_kernel/lite/python/at_ops/at_gen_strip.py deleted file mode 100644 index 519740c6fe..0000000000 --- a/predict/module/tvm_kernel/lite/python/at_ops/at_gen_strip.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -This module is rule to generation tvm operate. you can use it like: -python3 at_gen_strip.py [x86:arm64:arm32] -""" -import os -import sys -import itertools -from functools import partial -from at_ops.at_lib import Deconv, tvm, ConvVar, BatchNorm, Eltwise, Resize, CaffeCrop, CaffePReLU -from at_ops.at_lib import FullConnection, Power, ArgMax, Concat, Pad, Pooling, Mean, MatMul, Softmax -from at_ops.at_lib import Activation, Exp, Split, Cast, ExpandDims, Tile, Range -from at_rt import at_runtime_reset - - -check_correctness = False -ARCH_TYPE = sys.argv[1] - -dtypes = ("float32",) # "float16", "uint8", "int8", "uint32", "int32" - -device_map = { - "x86": "llvm", - "arm64": "llvm -device=arm_cpu -model=kirin970 -target=arm64-linux-android", - "arm32": "llvm -device=arm_cpu -model=kirin970 -target=armv7a-linux-eabi -mfloat-abi=soft", -} - -lib_path_map = { - "x86": "../../../build/lib_x86/", - "arm64": "../../../build/lib_arm64/", - "arm32": "../../../build/lib_arm32/", -} - -best_log_map = { - "x86": None, - "arm64": None, - "arm32": None, -} - -lib_path = lib_path_map[ARCH_TYPE] -device = device_map[ARCH_TYPE] -if ARCH_TYPE == "arm64": - if dtypes[0] == "float16": - device += " -mattr=+fp16fml" - else: - device += " -mattr=+neon" -best_log = best_log_map[ARCH_TYPE] - -kwargs = { - "device": device, - "lib_path": lib_path, - "check_correctness": check_correctness, -} - -use_arm32 = ARCH_TYPE == "arm32" - -MAX_DIMS = 5 -const_op_list = [ - ( - "Deconvolution", - partial(Deconv, optype="Deconvolution"), - { - "ndim": (5,), - "dtype": dtypes, - "kernels": ((2, 2),), - "strides": ((2, 2),), - "pad": ((0, 0, 0, 0),), - "dilations": ((1, 1),), - "hasbias": (False, True), - "activation_type": ("NO_ACTIVATION",), - "cfg": [ - { - "CI": tvm.var("CI"), - "VH": 2, - "VW": 12, - "VC": 4, - "VI": 4, - "tile_oh": 2, - "tile_ow": 12, - "tile_co": 4, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - }, - { - "CI": tvm.var("CI"), - "VH": 2, - "VW": 10, - "VC": 4, - "VI": 4, - "tile_oh": 2, - "tile_ow": 10, - "tile_co": 4, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - }, - { - "CI": tvm.var("CI"), - "VH": 2, - "VW": 16, - "VC": 4, - "VI": 4, - "tile_oh": 2, - "tile_ow": 16, - "tile_co": 4, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - }, - { - "CI": tvm.var("CI"), - "VH": 2, - "VW": 8, - "VC": 4, - "VI": 4, - "tile_oh": 2, - "tile_ow": 8, - "tile_co": 4, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - }, - { - "CI": tvm.var("CI"), - "VH": 2, - "VW": 4, - "VC": 4, - "VI": 4, - "tile_oh": 2, - "tile_ow": 4, - "tile_co": 4, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - }, - { - "CI": tvm.var("CI"), - "VH": 2, - "VW": 2, - "VC": 4, - "VI": 4, - "tile_oh": 2, - "tile_ow": 2, - "tile_co": 4, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - }, - ], - }, - ), - ( - "Convolution", - partial(ConvVar, optype="Convolution"), - { - "ndim": (4,), - "layout": ("NCHW",), - "dtype": dtypes, - "kernels": ((1, 1), (3, 3), (5, 5),), - "strides": ((1, 1), (2, 2)), - "pad": ((1, 1, 1, 1), (0, 0, 0, 0), (2, 2, 2, 2)), - "dilations": ((1, 1),), - "hasbias": (False, True), - "activation_type": ("NO_ACTIVATION", "RELU"), - "cfg": [ - { - "CI": tvm.var("CI"), - "VH": 1, - "VW": 1, - "VC": 1, - "VI": 1, - "tile_oh": 1, - "tile_ow": 1, - "tile_co": 1, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - "core_id": 0, - }, - ], - }, - ), - ( - "ConvolutionDepthwise", - partial(ConvVar, optype="ConvolutionDepthwise"), - { - "ndim": (4,), - "layout": ("NCHW",), - "dtype": dtypes, - "kernels": ((2, 2), (3, 3),), - "strides": ((1, 1),), - "pad": ((0, 0, 0, 0), (0, 1, 0, 1), (1, 0, 1, 0), (1, 1, 1, 1),), - "dilations": ((1, 1),), - "hasbias": (False, True), - "activation_type": ("NO_ACTIVATION", "RELU"), - "channel_multiplier": (1,), - "cfg": [ - { - "CI": tvm.var("CI"), - "VH": 1, - "VW": 1, - "VC": 1, - "VI": 1, - "tile_oh": 1, - "tile_ow": 1, - "tile_co": 1, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - "core_id": 0, - }, - ], - }, - ), - ( - "DeConvolutionDepthwise", - partial(ConvVar, optype="DeConvolutionDepthwise"), - { - "ndim": (4,), - "layout": ("NCHW",), - "dtype": dtypes, - "kernels": ((1, 1), (2, 2), (3, 3),), - "strides": ((1, 1), (2, 2),), - "pad": ((0, 0, 0, 0), (1, 0, 1, 0), (1, 1, 1, 1),), - "dilations": ((1, 1),), - "hasbias": (False, True), - "activation_type": ("NO_ACTIVATION", "RELU"), - "channel_multiplier": (1,), - "cfg": [ - { - "CI": tvm.var("CI"), - "VH": 1, - "VW": 1, - "VC": 1, - "VI": 1, - "tile_oh": 1, - "tile_ow": 1, - "tile_co": 1, - "ann_reduce": ["none", "unroll"], - "ann_spatial": ["unroll", "unroll", "vec"], - "core_id": 0, - }, - ], - }, - ), - ( - "BatchNorm", - BatchNorm, - {"ndim": (4,), "dtype": dtypes, "optype": ("TFBatchNorm",), "axis": (1, 3,)}, - ), - ( - "BiasAdd", - BatchNorm, - {"ndim": (2, 4), "dtype": dtypes, "optype": ("TFBiasAdd",), "axis": (1, 3)}, - ), - ( - "CaffeBatchNorm", - BatchNorm, - {"ndim": (2, 4), "dtype": dtypes, "optype": ("CaffeBatchNorm",), "axis": (1, 3)}, - ), - ( - "Scale", - BatchNorm, - {"ndim": (2, 4), "dtype": dtypes, "optype": ("CaffeScale",), "axis": (1,)}, - ), - ( - "Eltwise", - Eltwise, - { - "ndim_a": tuple(range(0, MAX_DIMS + 1)), - "ndim_b": tuple(range(0, MAX_DIMS + 1)), - "dtype": dtypes, - "mode": ("add", "subtract", "multiply", "divide", "maximum"), - }, - ), - ( - "Add", - Eltwise, - { - "ndim_a": tuple(range(0, MAX_DIMS + 1)), - "ndim_b": tuple(range(0, MAX_DIMS + 1)), - "dtype": dtypes, - "mode": ("add",), - }, - ), - ( - "Sub", - Eltwise, - { - "ndim_a": tuple(range(0, MAX_DIMS + 1)), - "ndim_b": tuple(range(0, MAX_DIMS + 1)), - "dtype": dtypes, - "mode": ("subtract",), - }, - ), - ( - "Mul", - Eltwise, - { - "ndim_a": tuple(range(0, MAX_DIMS + 1)), - "ndim_b": tuple(range(0, MAX_DIMS + 1)), - "dtype": dtypes, - "mode": ("multiply",), - }, - ), - ( - "RealDiv", - Eltwise, - { - "ndim_a": tuple(range(0, MAX_DIMS + 1)), - "ndim_b": tuple(range(0, MAX_DIMS + 1)), - "dtype": dtypes, - "mode": ("divide",), - }, - ), - ( - "Maximum", - Eltwise, - { - "ndim_a": tuple(range(0, MAX_DIMS + 1)), - "ndim_b": tuple(range(0, MAX_DIMS + 1)), - "dtype": dtypes, - "mode": ("maximum",), - }, - ), - ( - "ResizeBilinear", - Resize, - { - "ndim": (4,), - "dtype": dtypes, - "method": ("bilinear",), # "bicubic" - "align_corners": (True, False), - }, - ), - ( - "ResizeNearestNeighbor", - Resize, - { - "ndim": (4,), - "dtype": dtypes, - "method": ("nearest_neighbor",), # "bicubic" - "align_corners": (True, False), - }, - ), - ( - "CaffeCrop", - CaffeCrop, - {"ndim": (4,), "dtype": dtypes, "axis": tuple(range(0, 4))}, - ), - ( - "CaffePReLU", - CaffePReLU, - {"ndim": (2, 4), "dtype": dtypes, "channel_shared": (True, False)}, - ), - ( - "FullConnection", - FullConnection, - {"ndim_a": (2, 4), "dtype": dtypes, "has_bias": (True, False)}, - ), - ("Power", Power, {"ndim": tuple(range(1, MAX_DIMS + 1)), "dtype": dtypes}), - ( - "ArgMax", - ArgMax, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "dtype": dtypes, - "axis": tuple(range(0, MAX_DIMS)), # not support None - "keep_dims": (True, False), - "top_k": (1,), - "out_dtype": dtypes, - }, - ), - ( - "Concat", - Concat, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "dtype": dtypes, - "input_num": tuple(range(2, 6 + 1)), - "axis": tuple(range(0, MAX_DIMS)), - }, - ), - ( - "Pad", - Pad, - { - "ndim": tuple(range(2, MAX_DIMS + 1)), - "dtype": dtypes, - "paddingmode": ("CONSTANT", "REFLECT", "SYMMETRIC"), - }, - ), - ( - "Pooling", - Pooling, - { - "ndim": (4,), - "dtype": dtypes, - "pooling_mode": ("max", "avg"), - "caffe_mode": (True, False), - "kernel": ((1, 1), (2, 2), (3, 3), (5, 5)), - "stride": ((1, 1), (2, 2), (3, 3)), - "pad": ((0, 0, 0, 0), (0, 1, 0, 1), (1, 1, 1, 1)), - "use_global": (True, False), - }, - ), - ( - "Mean", - Mean, - { - "ndim": (4,), - "dtype": dtypes, - "axis": ( - (0,), - (1,), - (2,), - (3,), - (0, 1), - (0, 2), - (0, 3), - (1, 2), - (1, 3), - (2, 3), - (0, 1, 2), - (0, 1, 3), - (0, 2, 3), - (1, 2, 3), - (0, 1, 2, 3), - ), - "keep_dims": (True, False), - }, - ), - ( - "MatMul", - MatMul, - { - "ndim_a": (2,), - "ndim_b": (2,), - "dtype": dtypes, - "transpose_a": (True, False), - "transpose_b": (True, False), - }, - ), - ( - "Softmax", - Softmax, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "dtype": dtypes, - "axis": tuple(range(0, MAX_DIMS)), - }, - ), - ( - "Activation", - Activation, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "dtype": dtypes, - "optype": ("NO_ACTIVATION", "RELU", "RELU6", "SIGMOID"), - }, - ), - ("Exp", Exp, {"ndim": tuple(range(1, MAX_DIMS + 1)), "dtype": dtypes}), - ( - "Split", - Split, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "dtype": dtypes, - "output_num": tuple(range(1, 5)), - "axis": tuple(range(0, MAX_DIMS)), - }, - ), - ( - "Cast", - Cast, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "src_dtype": dtypes, - "dst_dtype": dtypes, - }, - ), - ( - "ExpandDims", - ExpandDims, - { - "ndim": tuple(range(1, MAX_DIMS + 1)), - "dtype": dtypes, - "axis": tuple(range(0, MAX_DIMS)), - }, - ), - ("Tile", Tile, {"ndim": tuple(range(1, MAX_DIMS + 1)), "dtype": dtypes}), - ("Range", Range, {"out_dtype": ("float32", "uint32", "int32")}), -] - - -def gen_const_libs(some_op=None): - for optype, func, attr in const_op_list: - if some_op and some_op != optype: - continue - for values in itertools.product(*attr.values()): - args = dict((k, v) for k, v in zip(attr.keys(), values)) - func(device=device, lib_path=lib_path, **args) - - -if __name__ == "__main__": - if not os.path.exists(lib_path): - os.makedirs(lib_path) - # skip best_history log: - with tvm.target.create(device): - with at_runtime_reset.AtRuntimeReset(): - gen_const_libs() diff --git a/predict/module/tvm_kernel/lite/python/at_ops/at_lib.py b/predict/module/tvm_kernel/lite/python/at_ops/at_lib.py deleted file mode 100644 index 655064b29e..0000000000 --- a/predict/module/tvm_kernel/lite/python/at_ops/at_lib.py +++ /dev/null @@ -1,1193 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -This module is rule to generation tvm operate, call by at_gen_strip.py -""" -import numpy as np -import tvm -import topi -from topi.image import resize -from topi.nn import mirror_pad -from topi import tag -import topi.testing - -from arm_cpu.deconv import _conv_spatial_pack_deconv, schedule_conv2d_nchw_arm_cpu_deconv -from arm_cpu.conv2d import _conv_spatial_pack_asm, schedule_conv2d_nchw_arm_cpu -from arm_cpu.matmul import _matmul_spatial_pack_asm, _matmul_schedule_asm -from arm_cpu.depthwise_conv2d import _depthwise_spatial_pack, schedule_depthwise_conv2d_nchw_arm -from config_tool import activation_enum_map - -map_conv = { - 'Convolution': "Conv2D", - 'ConvolutionDepthwise': "DepthwiseConv2D", - 'Deconvolution': "DeConv2D", - 'DeConvolutionDepthwise': "DeDepthwiseConv2D", -} - - -def Genlib(sch, tensor_list, device, opname, lib_path, print_lower=False): - if print_lower: - print(tvm.lower(sch, tensor_list, simple_mode=True)) - ctx = tvm.context(device, 0) - func_o = tvm.build(sch, tensor_list, device + " --system-lib", name=opname) - func_so = tvm.build(sch, tensor_list, device, name=opname) - func_o.save(lib_path + opname + ".o", "o") - return func_o, func_so, ctx - - -def AsType(as_input, dtype): - if as_input.dtype == dtype: - return as_input - return tvm.compute(as_input.shape, - lambda *i: as_input(*i).astype(dtype), - tag="injective") - - -@tvm.tag_scope(tag=tag.ELEMWISE) -def TopiNNrelu6(x): - return tvm.compute(x.shape, lambda *i: tvm.min(tvm.max(x(*i), tvm.const(0, x.dtype)), tvm.const(6, x.dtype))) - - -def TopiActivation(in_tensor, a_type, memcpy=False): - ''' - activativation - Args: - in_tensor: - a_type: - memcpy: - - Returns: - ''' - if a_type == 'NO_ACTIVATION': - if memcpy: - return tvm.compute(in_tensor.shape, lambda *i: in_tensor[i], tag=tag.ELEMWISE) - return in_tensor - if a_type == 'RELU': - return topi.nn.relu(in_tensor) - if a_type == 'RELU6': - return TopiNNrelu6(in_tensor) - if a_type == 'SIGMOID': - if in_tensor.dtype in ["uint8", "int8", "uint32", "int32"]: - a_fp32 = AsType(in_tensor, 'float32') - out_tensor = topi.sigmoid(a_fp32) - return AsType(out_tensor, in_tensor.dtype) - return topi.sigmoid(in_tensor) - raise ValueError("not support activation type" + a_type) - - -def Deconv(device="llvm", lib_path="./", optype=None, - ndim=None, dtype=None, kernels=None, - strides=None, pad=None, dilations=None, - hasbias=None, activation_type=None, - config_entity=None, impl_dtype=None, - use_arm32=False, cfg=None): - ''' - Deconvolution - Args: - device: - lib_path: - optype: - ndim: - dtype: - kernels: - strides: - pad: - dilations: - hasbias: - activationType: - configEntity: - impl_dtype: - use_arm32: - cfg: - - Returns: - ''' - if cfg is None: - cfg = {'CI': tvm.var('ci'), 'VH': 2, 'VW': 2, 'VC': 4, 'VI': 4, - 'tile_oh': 2, 'tile_ow': 2, 'tile_co': 4, - 'ann_reduce': ['none', 'none'], - "ann_spatial": ['none', 'none', 'none'] - } - has_bias = hasbias - batch = tvm.var("batch") - in_channel = tvm.var("in_channel") - in_height, in_width = tvm.var("in_height"), tvm.var("in_width") - kh, kw = kernels - ow = cfg['VW'] - oh = cfg['VH'] - oc = cfg['VC'] - op_name = "%s_ndim%d_%s_k%d_s%d_p%d%d%d%d_d%d_act%d_vc%d_vh%d_vw%d_hasbias%d" % (\ - map_conv[optype], ndim, dtype,\ - kh, strides[0], pad[0], pad[1], pad[2], pad[3], dilations[0],\ - activation_enum_map[activation_type], oc, oh, ow, hasbias) - opname = op_name - print("DEconv", opname, config_entity) - - if impl_dtype is None: - impl_dtype = dtype - - out_channel = tvm.var("out_channel") - - # define placeholder - input_tensor = in_tensor = tvm.placeholder((batch, in_channel, in_height, in_width, 4), \ - dtype=dtype, name='in_tensor') - temp_tensor = kernel_tensor = tvm.placeholder((in_channel*4, out_channel, kh, kw), dtype=dtype, \ - name='kernel_tensor') - if has_bias: - bias = tvm.placeholder((out_channel,), dtype=dtype, name='bias') - bias1 = topi.reshape(bias, (out_channel, 1, 1)) - - if impl_dtype != dtype: - input_tensor = AsType(input_tensor, impl_dtype) - temp_tensor = AsType(temp_tensor, impl_dtype) - if has_bias: - bias1 = AsType(bias1, impl_dtype) - - # define compute & schedule - cfg1 = (True, 1, 1, 1) if cfg is None else (True, cfg["tile_oh"], cfg["tile_ow"], cfg["tile_co"]) - out_tensor = _conv_spatial_pack_deconv(cfg1, input_tensor, temp_tensor, out_dtype=impl_dtype) - - if has_bias: - out_tensor = tvm.compute(out_tensor.shape, lambda n, co, h, w, c4: \ - out_tensor[n, co, h, w, c4] + bias1[co*4 + c4][0][0], tag="injective") - out_tensor = TopiActivation(out_tensor, activation_type) - if impl_dtype != dtype: - out_tensor = AsType(out_tensor, dtype) - - # create schedule - if use_arm32: - s = tvm.create_schedule(out_tensor.op) - else: - s = schedule_conv2d_nchw_arm_cpu_deconv(cfg, [out_tensor]) - - attr = [batch, in_channel, in_height, in_width, out_channel, in_tensor, kernel_tensor] - if has_bias: attr.append(bias) - attr.append(out_tensor) - tensor_list = attr - - Genlib(s, tensor_list, device, opname, lib_path) - - -def ConvVar(device="llvm", lib_path="./", optype=None,\ - ndim=None, layout=None, dtype=None, kernels=None,\ - strides=None, pad=None, dilations=None,\ - hasbias=None, activation_type=None,\ - config_entity=None, impl_dtype=None, channel_multiplier=None,\ - use_arm32=False, cfg=None): - ''' - convolution - Args: - device: - lib_path: - optype: - ndim: - layout: - dtype: - kernels: - strides: - pad: - dilations: - hasbias: - activationType: - configEntity: - impl_dtype: - channel_multiplier: - use_arm32: - cfg: - - Returns: - ''' - use_depthwise = optype == 'ConvolutionDepthwise' - use_deconv = optype == 'Deconvolution' - use_deconv_depthwise = optype == 'DeConvolutionDepthwise' - has_bias = hasbias - - ow = 1 if cfg is None else cfg['VW'] - oh = 1 if cfg is None else cfg['VH'] - oc = 1 if cfg is None else cfg['VC'] - kh, kw = kernels - op_name = "%s_ndim%d_%s_k%d_s%d_p%d%d%d%d_d%d_act%d_vc%d_vh%d_vw%d_hasbias%d" % ( \ - map_conv[optype], ndim, dtype, \ - kh, strides[0], pad[0], pad[1], pad[2], pad[3], dilations[0], \ - activation_enum_map[activation_type], oc, oh, ow, hasbias) - batch = tvm.var("batch") - in_channel = tvm.var("in_channel") - in_height, in_width = tvm.var("in_height"), tvm.var("in_width") - pad_up, pad_down, pad_left, pad_right = pad - opname = op_name - - print("Conv", opname, config_entity) - - if impl_dtype is None: - impl_dtype = dtype - - if use_depthwise: - multiplier = channel_multiplier - out_channel = in_channel * multiplier - elif use_deconv_depthwise: - multiplier = channel_multiplier - out_channel = in_channel * multiplier - else: - out_channel = tvm.var("out_channel") - - # define placeholder - input_tensor = in_tensor = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=dtype, name='in_tensor') - - if use_depthwise: - temp_tensor = kernel_tensor = tvm.placeholder((in_channel, multiplier, kh, kw), dtype=dtype,\ - name='kernel_tensor') - elif use_deconv: - temp_tensor = kernel_tensor = tvm.placeholder((in_channel, out_channel, kh, kw), dtype=dtype,\ - name='kernel_tensor') - elif use_deconv_depthwise: - temp_tensor = kernel_tensor = tvm.placeholder((in_channel, multiplier, kh, kw), dtype=dtype,\ - name='kernel_tensor') - else: - temp_tensor = kernel_tensor = tvm.placeholder((out_channel, in_channel, kh, kw), dtype=dtype,\ - name='kernel_tensor') - if has_bias: - bias = tvm.placeholder((out_channel,), dtype=dtype, name='bias') - bias1 = topi.reshape(bias, (out_channel, 1, 1)) - - if impl_dtype != dtype: - input_tensor = AsType(input_tensor, impl_dtype) - temp_tensor = AsType(temp_tensor, impl_dtype) - if has_bias: - bias1 = AsType(bias1, impl_dtype) - - # define compute & schedule - if pad_up != pad_down or pad_left != pad_right: - input_tensor = topi.nn.pad(input_tensor, [0, 0, pad_up, pad_left], [0, 0, pad_down, pad_right], name='data_pad') - padding = 0, 0 - else: - padding = pad_up, pad_left - if use_depthwise: - cfg1 = (True, 1, 1, 1) if cfg is None else (True, cfg["tile_oh"], cfg["tile_ow"], cfg["tile_co"]) - out_tensor = _depthwise_spatial_pack(cfg1, input_tensor, temp_tensor, strides, padding, dilations,\ - out_dtype=impl_dtype) - elif use_deconv: - - def GetInput(input_tensor, temp_tensor, padding): - _, out_c, filter_h, filter_w = temp_tensor.shape - if out_c is None: - print("temp_tensor.shape err") - stride_h, stride_w = strides - # dilate stage - dilated_input = topi.nn.dilate(input_tensor, [1, 1, stride_h, stride_w], - name='DilatedInput') - # padding stage - fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(padding, ( - filter_h, filter_w)) - bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom - bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right - padded_input = topi.nn.pad(dilated_input, \ - [0, 0, bpad_top, bpad_left], \ - [0, 0, bpad_bottom, bpad_right], \ - name='PaddedInput') - return padded_input - - special_deconv = kh == 2 and kw == 2 and strides[0] == 2 and strides[1] == 2 - # special_deconv = False - if special_deconv: - out_tensor = OptimalOut(input_tensor, temp_tensor, in_channel) - else: - out_tensor = BaseImplementation(input_tensor, temp_tensor, GetInput, layout, padding) - elif use_deconv_depthwise: - def GetInput(input_tensor, temp_tensor, padding): - _, out_c, filter_h, filter_w = temp_tensor.shape - if out_c is None: - print("temp_tensor.shape err") - stride_h, stride_w = strides - # dilate stage - dilated_input = topi.nn.dilate(input_tensor, [1, 1, stride_h, stride_w], - name='DilatedInput') - # padding stage - fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(padding, ( - filter_h, filter_w)) - bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom - bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right - padded_input = topi.nn.pad(dilated_input, \ - [0, 0, bpad_top, bpad_left], \ - [0, 0, bpad_bottom, bpad_right], \ - name='PaddedInput') - return padded_input - - temp_tensor = topi.flip(temp_tensor, axis=-1) - temp_tensor = topi.flip(temp_tensor, axis=-2) - out_tensor = topi.nn.depthwise_conv2d_nchw(GetInput(input_tensor, temp_tensor, padding), temp_tensor, (1, 1), \ - padding, (1, 1), out_dtype=input_tensor.dtype) - else: - cfg1 = (True, 1, 1, 1) if cfg is None else (True, cfg["tile_oh"], cfg["tile_ow"], cfg["tile_co"]) - out_tensor = _conv_spatial_pack_asm(cfg1, input_tensor, temp_tensor, strides, padding, dilations,\ - out_dtype=impl_dtype) - - if has_bias: - out_tensor = tvm.compute(out_tensor.shape, lambda n, co, h, w: out_tensor[n, co, h, w] + bias1[co][0][0],\ - tag="injective") - out_tensor = TopiActivation(out_tensor, activation_type) - if impl_dtype != dtype: - out_tensor = AsType(out_tensor, dtype) - - # create schedule - if use_arm32: - s = tvm.create_schedule(out_tensor.op) - elif use_depthwise: - s = schedule_depthwise_conv2d_nchw_arm(cfg, [out_tensor]) - elif use_deconv: - if special_deconv: - s = tvm.create_schedule([out_tensor.op]) - else: - s = topi.generic.schedule_conv2d_nchw([out_tensor]) - elif use_deconv_depthwise: - s = tvm.create_schedule([out_tensor.op]) - else: - s = schedule_conv2d_nchw_arm_cpu([out_tensor]) - - # generate lib - attr = [batch, in_channel, in_height, in_width, out_channel, in_tensor, kernel_tensor] - tensor_list = [*attr, bias, out_tensor] if has_bias else [*attr, out_tensor] - Genlib(s, tensor_list, device, opname, lib_path) - - -def BaseImplementation(input_tensor, temp_tensor, get_input, layout, padding): - temp_tensor = topi.flip(temp_tensor, axis=-1) - temp_tensor = topi.flip(temp_tensor, axis=-2) - temp_tensor = topi.transpose(temp_tensor, axes=(1, 0, 2, 3)) - out_tensor = topi.nn.conv2d(get_input(input_tensor, temp_tensor, padding), temp_tensor, (1, 1), padding, (1, 1), - layout=layout, out_dtype=input_tensor.dtype) - return out_tensor - - -def OptimalOut(input_tensor, temp_tensor, in_channel): - ''' - deconv compute - Args: - input_tensor: - temp_tensor: - in_channel: - - Returns: - ''' - temp_tensor = topi.transpose(temp_tensor, axes=(1, 0, 2, 3)) - out_shape = [] - for i in range(len(input_tensor.shape)): - if i == 0: - out_shape.append(input_tensor.shape[i]) - continue - if i == 1: - out_shape.append(temp_tensor.shape[0]) - continue - out_shape.append(2 * input_tensor.shape[i]) - rc = tvm.reduce_axis((0, in_channel), name='rc') - return tvm.compute(out_shape, lambda i, j, k, l:\ - tvm.sum(input_tensor[i, rc, k // 2, l // 2].astype(input_tensor.dtype) *\ - temp_tensor[j, rc, k % 2, l % 2].astype(input_tensor.dtype), axis=[rc])) - - -def Concat(device="llvm", lib_path="./", - ndim=None, dtype=None, input_num=None, axis=None): - ''' - concat - Args: - device: - lib_path: - all_tensors: - ndim: - dtype: - input_num: - axis: - - Returns: - ''' - if axis >= ndim: - return - shapes = [] - for i in range(input_num): - shape = [] - for j in range(ndim): - if j == axis: - shape.append(tvm.var("axis" + str(i))) - else: - shape.append(tvm.var("n" + str(j))) - shapes.append(shape) - in_tensor = [tvm.placeholder(shape, dtype=dtype, name='in_tensor%d' % i) for i, shape in enumerate(shapes)] - opname = "Concat_ndim%d_%s_input_num%d_axis%d" % (ndim, dtype, input_num, axis) - print(opname) - - # define compute - out_tensor = topi.concatenate(tuple(in_tensor), axis) - tensor_list = in_tensor + [out_tensor] - if ndim < 5: - s = topi.generic.schedule_concatenate(out_tensor) - else: - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Activation(device="llvm", lib_path="./", - ndim=None, dtype=None, optype=None): - ''' - activation - Args: - device: - lib_path: - ndim: - dtype: - optype: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - opname = "Activation_ndim%d_%s_%s" % (ndim, dtype, optype) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - out_tensor = TopiActivation(in_tensor, optype, memcpy=True) - tensor_list = [in_tensor, out_tensor] - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def BatchNorm(device="llvm", lib_path="./", - ndim=None, dtype=None, optype=False, axis=None): - ''' - batchnorm - Args: - device: - lib_path: - ndim: - dtype: - optype: - axis: - - Returns: - ''' - if axis >= ndim: - return - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - channel = shape[axis] - eps = tvm.var("epsilon", dtype="float32") - opname = optype + ("_ndim%d_%s_axis%d" % (ndim, dtype, axis)) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - mean = tvm.placeholder((channel,), dtype=dtype, name='mean') - variance = tvm.placeholder((channel,), dtype=dtype, name='var') - scale = tvm.placeholder((channel,), dtype=dtype, name='scale') - offset = tvm.placeholder((channel,), dtype=dtype, name='offset') - - variance_sqrt = tvm.compute((channel,), lambda i: tvm.sqrt(variance[i] + eps.astype(dtype))) - if optype == "TFBatchNorm": - out_tensor = tvm.compute(shape, lambda *idx: ((in_tensor[idx] - mean[idx[axis]]) / variance_sqrt[idx[axis]]) *\ - scale[idx[axis]] + offset[idx[axis]]) - tensor_list = [eps, in_tensor, scale, offset, mean, variance, out_tensor] - elif optype == "CaffeBatchNorm": - out_tensor = tvm.compute(shape, lambda *idx: (in_tensor[idx] - mean[idx[axis]]) / variance_sqrt[idx[axis]]) - tensor_list = [eps, in_tensor, mean, variance, out_tensor] - elif optype == "CaffeScale": - out_tensor = tvm.compute(shape, lambda *idx: in_tensor[idx] * scale[idx[axis]] + offset[idx[axis]]) - tensor_list = [in_tensor, scale, offset, out_tensor] - elif optype == "TFBiasAdd": - out_tensor = tvm.compute(shape, lambda *idx: in_tensor[idx] + offset[idx[axis]]) - tensor_list = [in_tensor, offset, out_tensor] - else: - raise RuntimeError("no support for {}".format(optype)) - - # define schedule & generate lib - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Pooling(device="llvm", lib_path="./", - ndim=None, dtype=None, pooling_mode=None, kernel=None, stride=None, pad=None, caffe_mode=None, - use_global=False): - ''' - pooling - Args: - device: - lib_path: - ndim: - dtype: - pooling_mode: - kernel: - stride: - pad: - caffe_mode: - use_global: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(0, ndim)] - layout = 'NCHW' - if use_global: - opname = "GlobalPooling_ndim%d_%s_%s" % (ndim, dtype, pooling_mode) - else: - kernel_h, kernel_w = kernel - stride_h, stride_w = stride - pad_up, pad_down, pad_left, pad_right = pad - if pad_up == 0 and pad_down == 0 and pad_left == 0 and pad_right == 0 and caffe_mode: - caffe_mode = False - opname = "Pooling_ndim%d_%s_%s_kernel%d%d_stride%d%d_pad%d%d%d%d%s" \ - % (ndim, dtype, pooling_mode, kernel_h, kernel_w, stride_h, stride_w, - pad_up, pad_down, pad_left, pad_right, "_caffe" if caffe_mode else "") - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - if use_global: - out_tensor = topi.nn.global_pool(in_tensor, pool_type=pooling_mode, layout=layout) - sch = topi.generic.schedule_adaptive_pool(out_tensor) - else: - out_tensor = topi.nn.pool(in_tensor, - kernel=(kernel_h, kernel_w), - stride=(stride_h, stride_w), - padding=(pad_up, pad_left, pad_down, pad_right), - pool_type=pooling_mode, - ceil_mode=False, - layout=layout, - count_include_pad=False) - sch = topi.generic.schedule_pool(out_tensor, layout) - tensor_list = [in_tensor, out_tensor] - Genlib(sch, tensor_list, device, opname, lib_path, print_lower=False) - - -def Eltwise(device="llvm", lib_path="./", - ndim_a=None, ndim_b=None, dtype=None, mode=None): - ''' - eltwise - Args: - device: - lib_path: - ndim_a: - ndim_b: - dtype: - mode: - - Returns: - ''' - ndim_max = max(ndim_a, ndim_b) - shape = [tvm.var("n" + str(i)) for i in range(ndim_max)] - shape_b1 = [dim if i == 1 else 1 for i, dim in enumerate(shape)] - shape_a = shape[ndim_max - ndim_a:] if ndim_a else (1,) - shape_b = shape[ndim_max - ndim_b:] if ndim_b == ndim_a else shape_b1 if ndim_b == 1 else (1,) - opname = "Eltwise_%s_ndimA%d_ndimB%d_%s" % (mode, ndim_a, ndim_b, dtype) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape_a, dtype=dtype, name='in_tensor') - b_tensor = tvm.placeholder(shape_b, dtype=dtype, name='b_tensor') - - topi_funs = { - 'add': topi.add, - 'subtract': topi.subtract, - 'multiply': topi.multiply, - 'divide': topi.divide, - 'maximum': topi.maximum, - 'minimum': topi.minimum, - } - - out_tensor = topi_funs[mode](in_tensor, b_tensor) - tensor_list = [in_tensor, b_tensor, out_tensor] - s = topi.generic.schedule_elemwise(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Softmax(device="llvm", lib_path="./", - ndim=None, dtype=None, axis=None): - ''' - softmax - Args: - device: - lib_path: - ndim: - dtype: - axis: - - Returns: - ''' - if axis >= ndim: - return - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - opname = "Softmax_ndim%d_%s_axis%s" % (ndim, dtype, axis) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - out_tensor = topi.nn.softmax(in_tensor, axis) - tensor_list = [in_tensor, out_tensor] - s = topi.generic.schedule_elemwise(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Resize(device="llvm", lib_path="./", - ndim=None, dtype=None, method=None, align_corners=None): - ''' - resize - Args: - device: - lib_path: - ndim: - dtype: - method: - align_corners: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - new_height = tvm.var("newHeight") - new_width = tvm.var("new_width") - opname = "Resize_ndim%d_%s_%s_%s" % (ndim, dtype, method, "Align" if align_corners else "NotAlign") - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - out_tensor = resize(in_tensor, [new_height, new_width], align_corners=align_corners, method=method) - tensor_list = [new_height, new_width, in_tensor, out_tensor] - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Mean(device="llvm", lib_path="./", - ndim=None, dtype=None, axis=None, keep_dims=None): - ''' - mean - Args: - device: - lib_path: - ndim: - dtype: - axis: - keepDims: - - Returns: - ''' - if axis[-1] >= ndim: - return - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - axis_str = "" - for dim in axis: - axis_str += str(dim) - opname = "Mean_ndim%d_%s_axis%s_%s" % (ndim, dtype, axis_str, "keepDims" if keep_dims else "notkeepDims") - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - c_shape = shape[:] - reduced_num = 1 - for dim in axis: - c_shape[dim] = 1 - reduced_num *= shape[dim] - - def _ComputeSum(*b_idx): - reduce_axis = [tvm.reduce_axis((0, shape[dim])) for dim in axis] - a_idx = list(b_idx) - for i, dim in enumerate(axis): - a_idx[dim] = reduce_axis[i] - a_idx = tuple(a_idx) - return tvm.sum(in_tensor[a_idx], axis=reduce_axis) - - out_tensor = tvm.compute(c_shape, _ComputeSum) - out_tensor = tvm.compute(c_shape, lambda *i: out_tensor(*i) / reduced_num) - if not keep_dims: - out_tensor = topi.squeeze(out_tensor, axis) - - # define schedule & generate lib - tensor_list = [in_tensor, out_tensor] - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def CaffeCrop(device="llvm", lib_path="./", - ndim=None, dtype=None, axis=None): - ''' - caffe crop op - Args: - device: - lib_path: - ndim: - dtype: - axis: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(axis)] - shape_a = shape[:] - shape_b = shape[:] - offsets = [] - for i in range(axis, ndim): - shape_a.append(tvm.var("nA" + str(i))) - shape_b.append(tvm.var("nB" + str(i))) - offsets.append(tvm.var("offset" + str(i))) - opname = "CaffeCrop_ndim%d_%s_axis%d" % (ndim, dtype, axis) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape_a, dtype=dtype, name='in_tensor') - b_tensor = tvm.placeholder(shape_b, dtype=dtype, name='b_tensor') - begin = [0] * axis + offsets - end = shape_a[:] - for i in range(axis, len(shape_a)): - end[i] = offsets[i - axis] + shape_b[i] - shape_c = [end[i] - begin[i] for i in range(ndim)] - - def _Compute(*C_idx): - a_idx = [idx + begin[i] for i, idx in enumerate(list(C_idx))] - a_idx = tuple(a_idx) - return in_tensor[a_idx] - - out_tensor = tvm.compute(shape_c, _Compute) - tensor_list = offsets + [in_tensor, b_tensor, out_tensor] - - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def FullConnection(device="llvm", lib_path="./", - ndim_a=None, dtype=None, has_bias=None): - ''' - full connection - Args: - device: - lib_path: - ndim_a: - dtype: - hasBias: - - Returns: - ''' - n_dim, ci, h_dim, kernel_tensor = (tvm.var("n_dim"), tvm.var("out_tensor"), tvm.var("h_dim"), \ - tvm.var("kernel_tensor")) - co = tvm.var("co") - if ndim_a == 4: - shape_a = (n_dim, ci, h_dim, kernel_tensor) - chw = ci * h_dim * kernel_tensor - else: - shape_a = (n_dim, ci) - chw = ci - shape_w = (co, chw) - opname = "FullConnection_ndimA%d_%s_%s" % (ndim_a, dtype, "hasBias" if has_bias else "notHasBias") - is_var = True - vh, vw, vc = 1, 1, 1 - print(opname) - - in_tensor = tvm.placeholder(shape_a, dtype=dtype, name='in_tensor') - kernel_tensor = tvm.placeholder(shape_w, dtype=dtype, name='kernel_tensor') - input_tensor = topi.reshape(in_tensor, (n_dim, chw)) if len(shape_a) == 4 else in_tensor - - out_tensor = _matmul_spatial_pack_asm((is_var, 0, ci, vh, vw, vc), input_tensor, kernel_tensor, \ - layout='NC', out_dtype=dtype) - if has_bias: - bias = tvm.placeholder((co,), dtype=dtype, name='bias') - out_tensor = tvm.compute((n_dim, co), lambda n, co: out_tensor[n, co] + bias[co], tag='injective') - - tensor_list = [in_tensor, kernel_tensor, bias, out_tensor] if has_bias else [in_tensor, kernel_tensor, out_tensor] - cfg = {'is_var': is_var, 'is_transpose': 0, 'core_id': 0, 'CI': ci, 'VH': vh, 'VW': vw, 'VC': vc} - s = _matmul_schedule_asm(cfg, [out_tensor]) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Power(device="llvm", lib_path="./", - ndim=None, dtype=None): - ''' - power - Args: - device: - lib_path: - ndim: - dtype: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - power = tvm.var("power", dtype="float32") - scale = tvm.var("scale", dtype="float32") - shift = tvm.var("shift", dtype="float32") - opname = "Power_ndim%d_%s" % (ndim, dtype) - print(opname) - - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - out_tensor = tvm.compute(shape, lambda *i: tvm.power(in_tensor[i] * scale.astype(in_tensor.dtype) + \ - shift.astype(in_tensor.dtype), \ - power.astype(in_tensor.dtype))) - tensor_list = [power, scale, shift, in_tensor, out_tensor] - - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def CaffePReLU(device="llvm", lib_path="./", - ndim=None, dtype=None, channel_shared=None): - ''' - caffe prelu - Args: - device: - lib_path: - ndim: - dtype: - channel_shared: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - channel = 1 if channel_shared else shape[1] - opname = "CaffePReLU_ndim%d_%s_%s" % (ndim, dtype, - "channelShared" if channel_shared else "channelNotShared") - print(opname) - - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - slope = tvm.placeholder((channel,), dtype=dtype, name='slope') - if channel_shared: - out_tensor = tvm.compute(shape, lambda *idx: tvm.if_then_else(in_tensor[idx] >= 0, in_tensor[idx],\ - in_tensor[idx] * slope[0])) - else: - out_tensor = tvm.compute(shape, lambda *idx: tvm.if_then_else(in_tensor[idx] >= 0, in_tensor[idx],\ - in_tensor[idx] * slope[idx[1]])) - - tensor_list = [in_tensor, slope, out_tensor] - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Pad(device="llvm", lib_path="./", - ndim=None, dtype=None, paddingmode=None): - ''' - pad - Args: - device: - lib_path: - ndim: - dtype: - paddingmode: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - pad_before = [tvm.var("pad_before" + str(i)) for i in range(ndim)] - pad_after = [tvm.var("pad_after" + str(i)) for i in range(ndim)] - pad_before_const = [0, 0] + pad_before[2:] - pad_after_const = [0, 0] + pad_after[2:] - paddings = [None] * 2 * len(shape) - paddings[0:: 2] = pad_before - paddings[1:: 2] = pad_after - pad_value = 0 - opname = "Pad_ndim%d_%s_%s" % (ndim, dtype, paddingmode) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - if paddingmode == "CONSTANT": - out_tensor = topi.nn.pad(in_tensor, pad_before_const, pad_after_const, pad_value=pad_value, name='out_tensor') - else: - out_tensor = mirror_pad(in_tensor, pad_before_const, pad_after_const, mode=paddingmode, name='out_tensor') - tensor_list = paddings + [in_tensor, out_tensor] - def SchedulePad(inputs): - s = tvm.create_schedule(inputs.op) - if s[inputs].op.axis: - s[inputs].parallel(s[inputs].op.axis[1]) - return s - - s = SchedulePad(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def MatMul(device="llvm", lib_path="./", - ndim_a=None, ndim_b=None, dtype=None, transpose_a=None, transpose_b=None): - ''' - matmul - Args: - device: - lib_path: - ndim_a: - ndim_b: - dtype: - transpose_a: - transpose_b: - - Returns: - ''' - m, k, n_dim = tvm.var("m"), tvm.var("k"), tvm.var("n_dim") - a_shape = (m, k) if not transpose_a else (k, m) - b_shape = (k, n_dim) if not transpose_b else (n_dim, k) - opname = "MatMul_ndimA%d_ndimB%d_%s_%d_%d" % (ndim_a, ndim_b, dtype, transpose_a, transpose_b) - print(opname) - - # define compute - in_tensor = tvm.placeholder(a_shape, dtype=dtype, name='in_tensor') - b_tensor = tvm.placeholder(b_shape, dtype=dtype, name='b_tensor') - out_tensor = topi.matmul(in_tensor, b_tensor, transpose_a, transpose_b) - tensor_list = [in_tensor, b_tensor, out_tensor] - s = topi.generic.schedule_elemwise(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Stack(device="llvm", lib_path="./", - ndim=None, dtype=None, input_num=None, axis=None): - ''' - stack - Args: - device: - lib_path: - ndim: - dtype: - input_num: - axis: - - Returns: - ''' - if axis > ndim: - return - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - shapes = [shape] * input_num - in_tensor = [tvm.placeholder(shape, dtype=dtype, name='in_tensor%d' % i) for i, shape in enumerate(shapes)] - opname = "Stack_ndim%d_%s_input_num%d_axis%d" % (ndim, dtype, input_num, axis) - print(opname) - - input_tensor = [topi.expand_dims(ai, axis) for ai in in_tensor] - out_tensor = topi.concatenate(tuple(input_tensor), axis=axis) - tensor_list = in_tensor + [out_tensor] - if ndim < 4: - s = topi.generic.schedule_concatenate(out_tensor) - else: - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def ArgMax(device="llvm", lib_path="./", - ndim=None, dtype=None, axis=None, keep_dims=None, top_k=None, - out_dtype=None): - ''' - argmax - Args: - device: - lib_path: - ndim: - dtype: - axis: - keepDims: - top_k: - out_dtype: - - Returns: - ''' - if axis >= ndim: - return - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - opname = "ArgMax_ndim%d_%s_axis%d_%s_top%d_%s" \ - % (ndim, dtype, axis, "keepDims" if keep_dims else "notKeepDims", top_k, out_dtype) - print(opname) - - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - out_tensor = topi.argmax(in_tensor, axis=axis, keepdims=keep_dims) - out_tensor = AsType(out_tensor, out_dtype) - tensor_list = [in_tensor, out_tensor] - s = tvm.create_schedule(out_tensor.op) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Exp(device="llvm", lib_path="./", - ndim=None, dtype=None): - ''' - exp - Args: - device: - lib_path: - ndim: - dtype: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - opname = "Exp_ndim%d_%s" % (ndim, dtype) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - if 'int' in dtype: - input_tensor = AsType(in_tensor, 'float32') - out_tensor = topi.exp(input_tensor) - out_tensor = AsType(out_tensor, in_tensor.dtype) - else: - out_tensor = topi.exp(in_tensor) - tensor_list = [in_tensor, out_tensor] - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Cast(device="llvm", lib_path="./", - ndim=None, src_dtype=None, dst_dtype=None): - ''' - cast - Args: - device: - lib_path: - ndim: - src_dtype: - dst_dtype: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - opname = "Cast_ndim%d_%s_%s" % (ndim, src_dtype, dst_dtype) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=src_dtype, name='in_tensor') - out_tensor = topi.cast(in_tensor, dst_dtype) - tensor_list = [in_tensor, out_tensor] - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def ExpandDims(device="llvm", lib_path="./", - ndim=None, axis=None, dtype=None): - ''' - expand dims - Args: - device: - lib_path: - ndim: - axis: - dtype: - - Returns: - ''' - if axis > ndim: - return - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - opname = "ExpandDim_ndim%d_%s_axis%d" % (ndim, dtype, axis) - print(opname) - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') - out_tensor = topi.expand_dims(in_tensor, axis=axis) - tensor_list = [in_tensor, out_tensor] - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Tile(device="llvm", lib_path="./", - ndim=None, dtype=None): - ''' - tile - Args: - device: - lib_path: - ndim: - dtype: - - Returns: - ''' - shape = [tvm.var("n" + str(i)) for i in range(ndim)] - multiples = [tvm.var("k" + str(i)) for i in range(ndim)] - opname = "Tile_ndim%d_%s" % (ndim, dtype) - print(opname) - - def _Compute(*C_idx): - a_idx = [tvm.floordiv(idx, multiples[i]) for i, idx in enumerate(list(C_idx))] - a_idx = tuple(a_idx) - return in_tensor[a_idx] - - # define compute - in_tensor = tvm.placeholder(shape, dtype=dtype, name='in_tensor') # tvm 0.6-dev: topi.tile - shape_c = (np.array(shape) * np.array(multiples)).tolist() - out_tensor = tvm.compute(shape_c, _Compute) - - tensor_list = multiples + [in_tensor, out_tensor] - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Range(device="llvm", lib_path="./", - out_dtype=None): - ''' - range - Args: - device: - lib_path: - out_dtype: - - Returns: - ''' - start = tvm.var("start") - delta = tvm.var("delta") - opname = "Range_ndim_" + out_dtype - print(opname) - - out_tensor = tvm.compute((tvm.var("n0"),), lambda i: start.astype(out_dtype) + delta.astype(out_dtype) * i, \ - name='out_tensor') - out_tensor = AsType(out_tensor, out_dtype) - tensor_list = [start, delta, out_tensor] - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) - - -def Split(device="llvm", lib_path="./", - ndim=None, dtype=None, output_num=None, axis=None): - ''' - split - Args: - device: - lib_path: - ndim: - dtype: - output_num: - axis: - - Returns: - ''' - if axis >= ndim: - return - size_splits = [tvm.var("split" + str(i)) for i in range(output_num)] - a_shape = [tvm.var("n" + str(i)) for i in range(axis)] \ - + [np.sum(size_splits)] \ - + [tvm.var("n" + str(i)) for i in range(axis + 1, ndim)] - c_shapes = [] - for i in range(output_num): - c_shape = [] - for j in range(ndim): - if j == axis: - c_shape.append(tvm.var("split" + str(i))) - else: - c_shape.append(tvm.var("n" + str(j))) - c_shapes.append(c_shape) - indices_or_sections = np.cumsum(size_splits).tolist()[:-1] - opname = "Split_ndim%d_%s_output_num%d_axis%d" % (ndim, dtype, output_num, axis) - print(opname) - - # define compute - in_tensor = tvm.placeholder(a_shape, dtype=dtype, name='in_tensor') - - def _Compute(*C_idx): - a_idx = list(C_idx) - a_idx[axis] += idx_shift - a_idx = tuple(a_idx) - return in_tensor[a_idx] - - indices_or_sections_add0 = [0] + indices_or_sections - out_tensor = [] - for i in range(output_num): - idx_shift = indices_or_sections_add0[i] - ci = tvm.compute(c_shapes[i], _Compute) - out_tensor.append(ci) - tensor_list = size_splits + [in_tensor] + out_tensor - - s = topi.generic.schedule_injective(out_tensor) - Genlib(s, tensor_list, device, opname, lib_path) diff --git a/predict/module/tvm_kernel/lite/python/at_ops/config_tool.py b/predict/module/tvm_kernel/lite/python/at_ops/config_tool.py deleted file mode 100644 index b3006b1174..0000000000 --- a/predict/module/tvm_kernel/lite/python/at_ops/config_tool.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -This module is define some data struct for tvm kernel. -""" -import tvm -import topi - -format_map = {"NCHW": 0, "NHWC": 1} - -pool_map = {"max_pool": 0, "avg_pool": 1, "global_pool": 2} - -activation_map = { - "no_activation": 0, - "relu": 1, - "sigmoid": 2, - "relu6": 3, - "elu": 4, - "leaky_relu": 5, - "abs": 6, - "relu1": 7, - "softsign": 8, - "softplus": 9, - "tanh ": 10, -} -activation_enum_map = { - "NO_ACTIVATION": 0, - "RELU": 1, - "SIGMOID": 2, - "RELU6": 3, - "elu": 4, - "leaky_relu": 5, - "abs": 6, - "relu1": 7, - "softsign": 8, - "softplus": 9, - "tanh ": 10, -} - -padmode_map = {"NOTSET": 0, "SAME": 1, "VALID": 2} - -mslite_datatype_map = { - "float16": 1, - "float32": 0, - "double": 11, - "int8": 2, - "int16": 6, - "int32": 3, - "int64": 9, - "uint8": 4, - "uint16": 7, - "uint32": 8, - "uint64": 10, -} - - -def get_key_by_value(dicts, value): - for k, v in dicts.items(): - if v == value: - return k - return None - - -def relu6(x): - return tvm.compute( - x.shape, - lambda *i: tvm.min( - tvm.max(x(*i), tvm.const(0, x.dtype)), tvm.const(6, x.dtype) - ), - ) - - -activation_topi_funs = {"NO_ACTIVATION": None, "RELU": topi.nn.relu, "RELU6": relu6} - -name_funcs = { - "Concat": ( - lambda opname, x: ( - opname + "_%d_%d" + "_%d" + "_%d" * x["ndim"] + "_%d" * len(x["shapeAxis"]) - ) - % ( - format_map[x["format"]], - x["ndim"], - x["axis"], - *x["shapeOut"], - *x["shapeAxis"], - ) - ), - "Softmax": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" * x["ndim"] + "_%d") - % (format_map[x["format"]], x["ndim"], *x["shape"], x["axis"]) - ), - "Activation": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" + "_%d" * x["ndim"]) - % (format_map[x["format"]], x["ndim"], activation_map[x["type"]], *x["shape"]) - ), - "Add": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" * x["ndim"]) - % (format_map[x["format"]], x["ndim"], *x["shape"]) - ), - "Convolution": ( - lambda opname, x: ( - opname + "_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d" - ) - % ( - format_map[x["format"]], - x["ndim"], - x["batch"], - x["in_channel"], - *x["in_size"], - x["num_filter"], - *x["filter_size"], - *x["pad"], - *x["stride"], - x["dilation"], - x["hasbias"], - activation_map[x["activation_type"]], - ) - ), - "Identity": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" * x["ndim"]) - % (format_map[x["format"]], x["ndim"], *x["shape"]) - ), - "BatchNorm": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" * x["ndim"] + "_%d") - % (format_map[x["format"]], x["ndim"], *x["shape"], x["epsilon"]) - ), - "Squeeze": ( - lambda opname, x: ( - opname + "_%d_%d" + "_%d" * x["ndim"] + "_%d" * len(x["axis"]) - ) - % (format_map[x["format"]], x["ndim"], *x["shape"], *x["axis"]) - ), - "BiasAdd": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" * x["ndim"] + "_%d") - % (format_map[x["format"]], x["ndim"], *x["shape"], x["axis"]) - ), - "Pooling": ( - lambda opname, x: (opname + "_%d_%d_%d" + "_%d" * x["ndim"] + "_%d_%d_%d") - % ( - format_map[x["format"]], - x["ndim"], - pool_map[x["type"]], - *x["shape"], - x["kernel"], - x["stride"], - x["pad"], - ) - ), - "ConvolutionDepthwise": ( - lambda opname, x: ( - opname + "_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d" - ) - % ( - format_map[x["format"]], - x["ndim"], - x["batch"], - x["in_channel"], - *x["in_size"], - x["in_channel"] * x["channel_multiplier"], - *x["filter_size"], - *x["pad"], - *x["stride"], - x["dilation"], - x["hasbias"], - activation_map[x["activation_type"]], - ) - ), - "Reshape": ( - lambda opname, x: ( - opname + "_%d_%d" + "_%d" * x["ndimA"] + "_%d" * len(x["shapeB"]) - ) - % (format_map[x["format"]], x["ndimA"], *x["shapeA"], *x["shapeB"]) - ), - "Shape": ( - lambda opname, x: (opname + "_%d_%d" + "_%d" * x["ndim"]) - % (format_map[x["format"]], x["ndim"], *x["shape"]) - ), - "RealDiv": ( - lambda opname, x: ( - opname + "_%d_%d" + "_%d" * x["ndim"] + "_%d" * len(x["shapeB"]) - ) - % (format_map[x["format"]], x["ndim"], *x["shapeA"], *x["shapeB"]) - ), - "ResizeBilinear": (lambda opname, x: "ResizeBilinear"), - "TFLite_Detection_PostProcess": (lambda opname, x: "TFLite_Detection_PostProcess"), -} - -config_dict = {op_type: [] for op_type in name_funcs} - - -def config_dict_append(op_type, config, opname=None): - if opname is None: - config["opname"] = name_funcs[op_type](op_type, config) - else: - config["opname"] = opname - duplicate = [True for x in config_dict[op_type] if config == x] - - if duplicate: - config_dict[op_type].append(config) diff --git a/predict/module/tvm_kernel/lite/python/at_rt/at_runtime_reset.py b/predict/module/tvm_kernel/lite/python/at_rt/at_runtime_reset.py deleted file mode 100644 index cc0cc72885..0000000000 --- a/predict/module/tvm_kernel/lite/python/at_rt/at_runtime_reset.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -This module is Using to make Lite Runtime funcitons instead TVM Runtime funcitons while codegen. -""" - -import os -from tvm import codegen - -class AtRuntimeReset(): - """Using this class to make Lite Runtime funcitons instead TVM Runtime funcitons while codegen - Usage like: - with at_runtime_reset.AtRuntimeReset(): - fadd = tvm.build(s, [A, B], tgt, target_host = tgt_host, name = "myadd") - then the module fadd will using Lite runtime functions. - """ - - def __enter__(self): - if os.getenv("TVM_RUNTIME_ON") is not None: - return - codegen.SetRTFuncTransPair( - "TVMBackendAllocWorkspace", "LiteBackendAllocWorkspace" - ) - codegen.SetRTFuncTransPair( - "TVMBackendFreeWorkspace", "LiteBackendFreeWorkspace" - ) - codegen.SetRTFuncTransPair("TVMAPISetLastError", "LiteAPISetLastError") - codegen.SetRTFuncTransPair( - "TVMBackendParallelLaunch", "LiteBackendParallelLaunch" - ) - codegen.SetRTFuncTransPair( - "TVMBackendParallelBarrier", "LiteBackendParallelBarrier" - ) - codegen.SetRTFuncTransPair( - "TVMBackendRegisterSystemLibSymbol", "LiteBackendRegisterSystemLibSymbol" - ) - codegen.SetRTFuncTransPair("TVMFuncCall", "LiteFuncCall") - codegen.SetRTFuncTransPair( - "TVMBackendGetFuncFromEnv", "LiteBackendGetFuncFromEnv" - ) - - def __exit__(self, ptype, value, trace): - codegen.DelRTFuncTransPair("TVMBackendAllocWorkspace") - codegen.DelRTFuncTransPair("TVMBackendFreeWorkspace") - codegen.DelRTFuncTransPair("TVMAPISetLastError") - codegen.DelRTFuncTransPair("TVMBackendParallelLaunch") - codegen.DelRTFuncTransPair("TVMBackendParallelBarrier") - codegen.DelRTFuncTransPair("TVMBackendRegisterSystemLibSymbol") - codegen.DelRTFuncTransPair("TVMFuncCall") - codegen.DelRTFuncTransPair("TVMBackendGetFuncFromEnv") diff --git a/predict/module/tvm_kernel/lite/src/api/kernel_manager.cc b/predict/module/tvm_kernel/lite/src/api/kernel_manager.cc deleted file mode 100644 index b349ae6019..0000000000 --- a/predict/module/tvm_kernel/lite/src/api/kernel_manager.cc +++ /dev/null @@ -1,1772 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this ${file} except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LITE_RUNTIME_ON - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/mslog.h" - -const char *LIB_INFO = "libtvm_kernel version: master (c66c6b28dc991c9d705e1b983aab7385c337128d)"; -namespace km { -class KernelManager { - public: - int CallKernel(const std::string &fid, TVMArgs args) { - tvm::runtime::Module *mod = this->GetModule(); - CHECK(mod != nullptr) << "Failed to get Module!"; - const std::string name = fid; - tvm::runtime::PackedFunc f = mod->GetFunction(name, false); - CHECK(f != nullptr) << "Can't find kernel func " << fid; - TVMRetValue rv; - f.CallPacked(args, &rv); - return 0; - } - - void InitKernelManager(int mode, const std::string &fname) { return this->Init(mode, fname); } - - static KernelManager *Global() { - static KernelManager inst; - return &inst; - } - - tvm::runtime::Module *GetModule() const { return &g_modLib; } - - private: - KernelManager() = default; - - ~KernelManager() = default; - - void Init(int mode, std::string fpath) { - std::call_once(init_flag, &KernelManager::InitLib, mode, fpath); - return; - } - - static void InitLib(int mode, std::string fpath) { - if (mode) { - const PackedFunc *ptr = tvm::runtime::Registry::Get("module._GetSystemLib"); - CHECK(ptr != nullptr) << "Failed to get systemlib"; - g_modLib = (*ptr)(); - } else { - g_modLib = tvm::runtime::Module::LoadFromFile(fpath); - } - } - static tvm::runtime::Module g_modLib; - std::once_flag init_flag; -}; - -tvm::runtime::Module KernelManager::g_modLib; -} // namespace km - -std::function &)> GetKernel(const std::string &fid) { - km::KernelManager *inst = km::KernelManager::Global(); - CHECK(inst != nullptr) << "Failed to get KernelManager instance!"; - tvm::runtime::Module *mod = inst->GetModule(); - CHECK(mod != nullptr) << "Failed to get Module!"; - tvm::runtime::PackedFunc f = mod->GetFunction(fid, false); - if (f == nullptr) { - MS_LOGE("GetFunction return nullptr"); - return nullptr; - } - auto runner = [f](const std::vector &tensors) -> int { - int argLen = tensors.size(); - CHECK(argLen) << "Input tensors num=0 !"; - std::vector values(argLen); - std::vector codes(argLen); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (int i = 0; i < argLen; ++i) { - setter(i, tensors.at(i)); - } - tvm::runtime::TVMArgs targs(values.data(), codes.data(), argLen); - TVMRetValue rv; - f.CallPacked(targs, &rv); - return 0; - }; - return runner; -} - -int CallKernel(const std::string &fid, const std::vector &tensors) { - km::KernelManager *inst = km::KernelManager::Global(); - CHECK(inst != nullptr) << "Failed to get KernelManager instance!"; - int argLen = tensors.size(); - CHECK(argLen) << "Input tensors num=0 !"; - std::vector values(argLen); - std::vector codes(argLen); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (int i = 0; i < argLen; ++i) { - setter(i, tensors.at(i)); - } - tvm::runtime::TVMArgs targs(values.data(), codes.data(), argLen); - inst->CallKernel(fid, targs); - return 0; -} - -int InitKernelManager(int mode, const std::string &fname) { - km::KernelManager *inst = km::KernelManager::Global(); - CHECK(inst != nullptr) << "Failed to get KernelManager instance!"; - inst->InitKernelManager(mode, fname); - return 0; -} - -// just for api compatible, tvm/lite has same api -void ConfigThreadPool(int mode = 1, int nthreads = 0, bool execute_self = true) {} - -#else - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "flatbuffers/flatbuffers.h" -#include "schema/inner/ms_generated.h" -#include "include/securec.h" -#include "src/runtime/runtime_api.h" -#include "common/mslog.h" - -using runnerType = std::function &)>; - -const char *LIB_INFO = "libtvm_kernel version: master (c66c6b28dc991c9d705e1b983aab7385c337128d)"; - -namespace lite { -namespace runtime { -extern "C" { -// Function signature for generated packed function in shared library -typedef int (*BackendPackedCFunc)(const void *args, int *type_codes, int num_args); -} // extern "C" - -class LiteFuncPool { - public: - LiteFuncPool() = default; - - ~LiteFuncPool() = default; - - void GetFunction(const std::string &name, void **func_addr) { - auto it = tbl_.find(name); - if (func_addr == nullptr) { - MS_LOGW("input func_addr is nullptr"); - return; - } - *func_addr = (it != tbl_.end() ? it->second : nullptr); - } - - void RegisterSymbol(const std::string &name, void *ptr) { - std::lock_guard lock(mutex_); - auto it = tbl_.find(name); - if (it != tbl_.end() && ptr != it->second) { - MS_LOGW("Lite symbol %s get overriden to a different address %p->%p", name.c_str(), ptr, it->second); - } - tbl_[name] = ptr; - } - - static LiteFuncPool *Global() { - static LiteFuncPool inst; - return &inst; - } - - private: - // Internal mutex - std::mutex mutex_; - // Internal symbol table - std::unordered_map tbl_; -}; -} // namespace runtime -} // namespace lite - -using LiteFuncPool = lite::runtime::LiteFuncPool; -using BackendPackedCFunc = lite::runtime::BackendPackedCFunc; - -int LiteBackendRegisterSystemLibSymbol(const char *name, void *ptr) { - MS_ASSERT(LiteFuncPool::Global() != nullptr); - LiteFuncPool::Global()->RegisterSymbol(name, ptr); - return 0; -} - -// do nothing, api compatible with TVM_RUNTIME_ON API -void InitKernelManager(int mode, const std::string &fname) { return; } - -static inline void *GetFunction(const std::string &fid) { - void *f = nullptr; - MS_ASSERT(LiteFuncPool::Global() != nullptr); - LiteFuncPool::Global()->GetFunction(fid, &f); - if (f == nullptr) { - return nullptr; - } - return f; -} - -runnerType __attribute__((noinline)) GetKernel(const std::string &fid) { - auto f = GetFunction(fid); - if (f == nullptr) { - return nullptr; - } - auto runner = [f](const std::vector &tensors) -> int { - if (tensors.empty()) { - MS_LOGE("Input tensors num = 0 !"); - return -1; - } - std::vector values(tensors.size()); - std::vector codes(tensors.size()); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (size_t i = 0; i < tensors.size(); ++i) { - setter(i, tensors.at(i)); - } - auto passfunc = reinterpret_cast(f); - return passfunc(values.data(), codes.data(), tensors.size()); - }; - return runner; -} - -namespace auto_tensor { -constexpr int TENSOR_NUM_MAX = 10; -constexpr bool STORE_MODE = true; -constexpr bool RESUME_MODE = false; -const char *NOT_SUPPORT = "NOT SUPPORT"; -const int NCHW_N = 0; -const int NCHW_C = 1; -const int NCHW_H = 2; -const int NCHW_W = 3; -const int tile = 4; - -void store_shape(const std::vector &tensors, int (&ndim)[TENSOR_NUM_MAX], int64_t *(&shape)[TENSOR_NUM_MAX], - int64_t *(&strides)[TENSOR_NUM_MAX], bool mode = STORE_MODE) { - if (mode == STORE_MODE) { - for (size_t i = 0; i < tensors.size(); ++i) { - ndim[i] = tensors[i]->ndim; - shape[i] = tensors[i]->shape; - strides[i] = tensors[i]->strides; - } - } else { - for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i]->ndim = ndim[i]; - tensors[i]->shape = shape[i]; - tensors[i]->strides = strides[i]; - } - } -} - -static std::string get_dtype(const DLTensor &tensor) { - auto dtype = tensor.dtype; - if (dtype.code == kDLFloat) { - if (dtype.bits == 16) - return "float16"; - else if (dtype.bits == 32) - return "float32"; - else if (dtype.bits == 64) - return "float64"; - } else if (dtype.code == kDLInt) { - if (dtype.bits == 8) - return "int8"; - else if (dtype.bits == 16) - return "int16"; - else if (dtype.bits == 32) - return "int32"; - else if (dtype.bits == 64) - return "int64"; - } else if (dtype.code == kDLUInt) { - if (dtype.bits == 8) - return "uint8"; - else if (dtype.bits == 16) - return "uint16"; - else if (dtype.bits == 32) - return "uint32"; - else if (dtype.bits == 64) - return "uint64"; - } - return std::string(NOT_SUPPORT); -} - -struct OpCommonAttr { - std::string optype = ""; - std::string fid = ""; - uint32_t ndim = 0; - std::string dtype = "float32"; - - OpCommonAttr(const mindspore::predict::OpDef &opdef, const std::vector &tensors) { - auto opT = mindspore::predict::EnumNameOpT(opdef.attr_type()); - this->optype = opT; - MS_ASSERT(opdef.name() != nullptr); - this->fid = opdef.name()->str(); - if (!tensors.empty()) { - MS_ASSERT(tensors.front() != nullptr); - ndim = tensors.front()->ndim; - dtype = get_dtype(*tensors.front()); - } - } -}; - -template -static void NCHW2NHWC(DLTensor *src) { - if (src == nullptr) { - MS_LOGW("input src is nullptr"); - return; - } - T *src_data = static_cast(src->data); - std::unique_ptr tmp(new (std::nothrow) - T[src->shape[NCHW_N] * src->shape[NCHW_C] * src->shape[NCHW_H] * src->shape[NCHW_W]]); - if (tmp == nullptr) { - MS_LOGW("new tmp buf failed"); - return; - } - int N = src->shape[NCHW_N]; - int C = src->shape[NCHW_C]; - int H = src->shape[NCHW_H]; - int W = src->shape[NCHW_W]; - - // NCHW -> NHWC - int k = 0; - for (int n = 0; n < N; n++) - for (int h = 0; h < H; h++) - for (int w = 0; w < W; w++) - for (int c = 0; c < C; c++) { - tmp[k++] = src_data[n * C * H * W + c * H * W + h * W + w]; - } - - int sizes = N * C * H * W * sizeof(T); - errno_t ret = memcpy_s(src_data, sizes, tmp.get(), sizes); - if (ret != 0) { - MS_LOGW("memcpy_s failed: %d", ret); - return; - } -} - -static void transpose_shape(DLTensor *tensor, std::vector axis) { - if (tensor == nullptr) { - MS_LOGW("input tensor is nullptr"); - return; - } - int ndim = tensor->ndim; - std::vector origin_shape(tensor->shape, tensor->shape + ndim); - - for (int i = ndim - 1; i >= 0; --i) { - tensor->shape[i] = origin_shape[axis[i]]; - } -} - -static runnerType Pack_NCHW2NHWC(runnerType fun) { - if (fun == nullptr) { - MS_LOGE("input fun is nullptr"); - return nullptr; - } - auto runner = [fun](const std::vector &tensors) -> int { - if (tensors.back() == nullptr) { - MS_LOGE("tensors.back() is nullptr"); - return 1; - } - transpose_shape(tensors.back(), {0, 3, 1, 2}); // NHWC -> NCHW - fun(tensors); - - auto output = tensors.back(); - if (output == nullptr) { - MS_LOGE("tensors.back() after func is nullptr"); - return 1; - } - if (output->dtype.bits == 8) { - NCHW2NHWC(output); - } else if (output->dtype.bits == 16) { - NCHW2NHWC(output); - } else if (output->dtype.bits == 32) { - NCHW2NHWC(output); - } else if (output->dtype.bits == 64) { - NCHW2NHWC(output); - } else { - MS_LOGE("conv NCHW2NHWC output.dtype.bits=%d invalid, only support (8, 16, 32, 64)", output->dtype.bits); - return 1; - } - - if (tensors.back() == nullptr) { - MS_LOGE("tensors.back() is nullptr"); - return 1; - } - transpose_shape(tensors.back(), {0, 2, 3, 1}); // NCHW -> NHWC - return 0; - }; - return runner; -} - -runnerType __attribute__((noinline)) GetKernel_Insert_vector_int32(const std::string &fid, - const std::vector &vec) { - auto f = GetFunction(fid); - if (f == nullptr) { - MS_LOGE("GetFunction return nullptr"); - return nullptr; - } - auto runner = [f, vec](const std::vector &tensors) -> int { - std::vector values(vec.size() + tensors.size()); - std::vector codes(values.size()); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (size_t i = 0; i < vec.size(); ++i) { - setter(i, vec.at(i)); - } - for (size_t i = 0; i < tensors.size(); ++i) { - setter(i + vec.size(), tensors.at(i)); - } - auto passfunc = reinterpret_cast(f); - return passfunc(values.data(), codes.data(), values.size()); - }; - return runner; -} - -runnerType __attribute__((noinline)) GetKernel_Insert_vector_float(const std::string &fid, - const std::vector &vec) { - auto f = GetFunction(fid); - if (f == nullptr) { - MS_LOGE("GetFunction return nullptr"); - return nullptr; - } - auto runner = [f, vec](const std::vector &tensors) -> int { - std::vector values(vec.size() + tensors.size()); - std::vector codes(values.size()); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (size_t i = 0; i < vec.size(); ++i) { - setter(i, vec.at(i)); - } - for (size_t i = 0; i < tensors.size(); ++i) { - setter(i + vec.size(), tensors.at(i)); - } - auto passfunc = reinterpret_cast(f); - return passfunc(values.data(), codes.data(), values.size()); - }; - return runner; -} - -static runnerType GetKernel_Conv(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - if (tensors.at(0) == nullptr) { - MS_LOGE("input tensors.at(0) is nullptr"); - return nullptr; - } - int n = tensors.at(0)->shape[NCHW_N]; - int ci = tensors.at(0)->shape[NCHW_C]; - int h = tensors.at(0)->shape[NCHW_H]; - int w = tensors.at(0)->shape[NCHW_W]; - std::vector arg_const{n, ci, h, w}; - const OpCommonAttr opAttr(opdef, tensors); - - std::string fid; - if (opdef.attr_as_Conv2D() != nullptr) { - auto op = opdef.attr_as_Conv2D(); - fid = std::string(mindspore::predict::EnumNameOpT(opdef.attr_type())) + "_ndim" + std::to_string(opAttr.ndim) + - "_" + opAttr.dtype + "_k" + std::to_string(op->kernelH()) + "_s" + std::to_string(op->strideH()) + "_p" + - std::to_string(op->padUp()) + std::to_string(op->padDown()) + std::to_string(op->padLeft()) + - std::to_string(op->padRight()) + "_d" + std::to_string(op->dilateH()) + "_act" + - std::to_string(static_cast(op->activationType())) + "_vc" + std::to_string(1) + "_vh" + - std::to_string(1) + "_vw" + std::to_string(1) + "_hasbias" + std::to_string(op->hasBias()); - if (tensors.at(1) == nullptr) { - MS_LOGE("input tensors.at(1) is nullptr"); - return nullptr; - } - int co = tensors.at(1)->shape[NCHW_N]; - arg_const.push_back(co); - } else if (opdef.attr_as_DepthwiseConv2D() != nullptr) { - auto op = opdef.attr_as_DepthwiseConv2D(); - fid = std::string(mindspore::predict::EnumNameOpT(opdef.attr_type())) + "_ndim" + std::to_string(opAttr.ndim) + - "_" + opAttr.dtype + "_k" + std::to_string(op->kernelH()) + "_s" + std::to_string(op->strideH()) + "_p" + - std::to_string(op->padUp()) + std::to_string(op->padDown()) + std::to_string(op->padLeft()) + - std::to_string(op->padRight()) + "_d" + std::to_string(op->dilateH()) + "_act" + - std::to_string(static_cast(op->activationType())) + "_vc" + std::to_string(1) + "_vh" + - std::to_string(1) + "_vw" + std::to_string(1) + "_hasbias" + std::to_string(op->hasBias()); - int co = tensors.at(0)->shape[NCHW_C] * op->channelMultiplier(); - arg_const.push_back(co); - } else if (opdef.attr_as_DeDepthwiseConv2D() != nullptr) { - auto op = opdef.attr_as_DeDepthwiseConv2D(); - fid = std::string(mindspore::predict::EnumNameOpT(opdef.attr_type())) + "_ndim" + std::to_string(opAttr.ndim) + - "_" + opAttr.dtype + "_k" + std::to_string(op->kernelH()) + "_s" + std::to_string(op->strideH()) + "_p" + - std::to_string(op->padUp()) + std::to_string(op->padDown()) + std::to_string(op->padLeft()) + - std::to_string(op->padRight()) + "_d" + std::to_string(op->dilateH()) + "_act" + - std::to_string(static_cast(op->activationType())) + "_vc" + std::to_string(1) + "_vh" + - std::to_string(1) + "_vw" + std::to_string(1) + "_hasbias" + std::to_string(op->hasBias()); - int co = tensors.at(0)->shape[NCHW_C] * op->channelMultiplier(); - arg_const.push_back(co); - } - auto fun = GetKernel(fid); - if (fun == nullptr) { - MS_LOGE("GetKernel return nullptr"); - return nullptr; - } - - auto f = GetFunction(fid); - if (f == nullptr) { - MS_LOGE("GetFunction return nullptr"); - return nullptr; - } - auto runner = [f, arg_const](const std::vector &tensors) -> int { - int ndim[TENSOR_NUM_MAX]; - int64_t *shapes[TENSOR_NUM_MAX]; - int64_t *strides[TENSOR_NUM_MAX]; - store_shape(tensors, ndim, shapes, strides, STORE_MODE); - - std::vector values(arg_const.size() + tensors.size()); - std::vector codes(values.size()); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (size_t i = 0; i < arg_const.size(); ++i) { - setter(i, arg_const.at(i)); - } - for (size_t i = 0; i < tensors.size(); ++i) { - setter(i + arg_const.size(), tensors.at(i)); - } - auto passfunc = reinterpret_cast(f); - passfunc(values.data(), codes.data(), values.size()); - store_shape(tensors, ndim, shapes, strides, RESUME_MODE); - return 0; - }; - fun = runner; - - if (opdef.isLastConv()) { - return Pack_NCHW2NHWC(fun); - } - return fun; -} - -void update_shape_NC4HW4(const std::vector &tensors, int64_t (&shapeA)[TENSOR_NUM_MAX], - int64_t (&shapeC)[TENSOR_NUM_MAX]) { - auto inputA = tensors.front(); - auto output = tensors.back(); - if (inputA == nullptr) { - MS_LOGW("input tensors.front() is nullptr"); - return; - } - if (output == nullptr) { - MS_LOGW("input tensors.back() is nullptr"); - return; - } - shapeA[inputA->ndim] = tile; - for (int32_t i = 0; i < inputA->ndim; ++i) { - if (i == 1) { - shapeA[i] = inputA->shape[i] >> 2; - } else { - shapeA[i] = inputA->shape[i]; - } - } - { - inputA->ndim = inputA->ndim + 1; - inputA->shape = shapeA; - inputA->strides = nullptr; - } - shapeC[output->ndim] = tile; - for (int32_t i = 0; i < output->ndim; ++i) { - if (i == 1) { - shapeC[i] = output->shape[i] >> 2; - } else { - shapeC[i] = output->shape[i]; - } - } - { - output->ndim = output->ndim + 1; - output->shape = shapeC; - output->strides = nullptr; - } -} - -static runnerType GetKernel_Conv_var(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto fun = GetKernel(opAttr.fid); - if (tensors.at(0) == nullptr) { - MS_LOGE("input tensors.at(0) is nullptr"); - return nullptr; - } - std::string fid = opAttr.fid.substr(0, opAttr.fid.find('_')); - int n = tensors.at(0)->shape[NCHW_N]; - int ci = tensors.at(0)->shape[NCHW_C]; - int h = tensors.at(0)->shape[NCHW_H]; - int w = tensors.at(0)->shape[NCHW_W]; - int co = tensors.at(1)->shape[NCHW_C]; - std::vector arg_const{n, ci >> 2, h, w, co}; - if (fun == nullptr) { - auto fd = [](int h, std::vector &res) { - for (int i = 2; i <= h; i += 2) { - if ((h % i) == 0) res.emplace_back(i); - } - }; - int outidx = tensors.size() - 1; - std::vector vw; - if (tensors.at(outidx) == nullptr) { - MS_LOGE("input tensors.at(%d) is nullptr", outidx); - return nullptr; - } - fd(tensors.at(outidx)->shape[NCHW_W], vw); - - auto op = opdef.attr_as_DeConv2D(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_DeConv2D() is nullptr"); - return nullptr; - } - std::string fids; - for (auto iter = vw.rbegin(); iter != vw.rend(); iter++) { - fids = fid + "_ndim" + std::to_string(opAttr.ndim + 1) + "_" + opAttr.dtype + "_k" + - std::to_string(op->kernelH()) + "_s" + std::to_string(op->strideH()) + "_p" + std::to_string(op->padUp()) + - std::to_string(op->padDown()) + std::to_string(op->padLeft()) + std::to_string(op->padRight()) + "_d" + - std::to_string(op->dilateH()) + "_act" + std::to_string(static_cast(op->activationType())) + "_vc" + - std::to_string(4) + "_vh" + std::to_string(2) + "_vw" + std::to_string(*iter) + "_hasbias" + - std::to_string(op->hasBias()); - fun = GetKernel(fids); - if (fun != nullptr) { - break; - } - } - fid = fids; - if (fun == nullptr) { - MS_LOGE("fun is nullptr"); - return nullptr; - } - auto f = GetFunction(fid); - if (f == nullptr) { - MS_LOGE("GetFunction return nullptr"); - return nullptr; - } - auto runner = [f, arg_const](const std::vector &tensors) -> int { - int ndim[TENSOR_NUM_MAX]; - int64_t *shapes[TENSOR_NUM_MAX]; - int64_t *strides[TENSOR_NUM_MAX]; - int64_t shapeA[TENSOR_NUM_MAX]; - int64_t shapeC[TENSOR_NUM_MAX]; - store_shape(tensors, ndim, shapes, strides, STORE_MODE); - update_shape_NC4HW4(tensors, shapeA, shapeC); - - std::vector values(arg_const.size() + tensors.size()); - std::vector codes(values.size()); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (size_t i = 0; i < arg_const.size(); ++i) { - setter(i, arg_const.at(i)); - } - for (size_t i = 0; i < tensors.size(); ++i) { - setter(i + arg_const.size(), tensors.at(i)); - } - auto passfunc = reinterpret_cast(f); - passfunc(values.data(), codes.data(), values.size()); - store_shape(tensors, ndim, shapes, strides, RESUME_MODE); - return 0; - }; - fun = runner; - } - - if (opdef.isLastConv()) { - return Pack_NCHW2NHWC(fun); - } - return fun; -} - -enum reahpeCHW_Mode { FusedCHW, ExpandCHW }; - -void update_shape_reahpeCHW(const std::vector &tensors, reahpeCHW_Mode mode, int64_t (&shape)[4], - int64_t (&strides)[4], bool reahpe_output = false) { - auto input = tensors.front(); - auto output = tensors.back(); - if (input == nullptr) { - MS_LOGW("input tensors.front() is nullptr"); - return; - } - if (output == nullptr) { - MS_LOGW("input tensors.back() is nullptr"); - return; - } - int ndim; - if (mode == FusedCHW) { - ndim = 2; - int64_t CHW = 1; - for (int32_t i = 1; i < input->ndim; ++i) { - CHW *= input->shape[i]; - } - shape[NCHW_N] = input->shape[NCHW_N]; - shape[NCHW_C] = CHW; - strides[1] = 1; - strides[0] = CHW; - } else { - ndim = 4; - shape[NCHW_N] = input->shape[NCHW_N]; - shape[NCHW_C] = input->shape[NCHW_C]; - shape[NCHW_H] = 1; - shape[NCHW_W] = 1; - strides[3] = 1; - strides[2] = 1; - strides[1] = 1; - strides[0] = input->shape[NCHW_C]; - } - - input->ndim = ndim; - input->shape = shape; - input->strides = strides; - if (reahpe_output) { - output->ndim = ndim; - output->shape = shape; - output->strides = strides; - } -} - -static runnerType Pack_reahpeCHW(const runnerType &fun, const std::vector &tensors, reahpeCHW_Mode mode, - bool reahpe_output = false) { - if (fun == nullptr) { - MS_LOGE("input fun is nullptr"); - return nullptr; - } - if (tensors.front() == nullptr) { - MS_LOGE("input tensors.front() is nullptr"); - return nullptr; - } - if ((tensors.front()->ndim == 2 && mode == FusedCHW) || (tensors.front()->ndim == 4 && mode == ExpandCHW)) { - return fun; - } - - auto runner = [fun, mode, reahpe_output](const std::vector &tensors) -> int { - int ndim[TENSOR_NUM_MAX]; - int64_t *shape[TENSOR_NUM_MAX]; - int64_t *strides[TENSOR_NUM_MAX]; - int64_t shape_R[4]; - int64_t strides_R[4]; - store_shape(tensors, ndim, shape, strides, STORE_MODE); - update_shape_reahpeCHW(tensors, mode, shape_R, strides_R, reahpe_output); - fun(tensors); - store_shape(tensors, ndim, shape, strides, RESUME_MODE); - return 0; - }; - return runner; -} - -static runnerType GetKernel_BatchNorm(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - std::string fid; - std::vector epsilon(1, 0.001); - if (opAttr.optype == "BatchNorm") { - fid = "TFBatchNorm_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis1"; - epsilon.front() = opdef.attr_as_FusedBatchNorm()->epsilon(); - return GetKernel_Insert_vector_float(fid, epsilon); - } else if (opAttr.optype == "CaffeBatchNorm") { - fid = "CaffeBatchNorm_ndim4_" + opAttr.dtype + "_axis1"; - epsilon.front() = opdef.attr_as_CaffeBatchNorm()->epsilon(); - auto fun = GetKernel_Insert_vector_float(fid, epsilon); - if (fun == nullptr) { - MS_LOGE("GetKernel_Insert_vector_float return nullptr"); - return nullptr; - } - bool reahpe_output = true; - return Pack_reahpeCHW(fun, tensors, ExpandCHW, reahpe_output); - } else if (opAttr.optype == "BiasAdd") { - auto op = opdef.attr_as_BiasAdd(); - fid = "TFBiasAdd_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis" + - std::to_string(op->axis()->Get(0)); - return GetKernel(fid); - } else if (opAttr.optype == "Scale") { - fid = "CaffeScale_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis1"; - return GetKernel(fid); - } - return nullptr; -} - -void update_shape_flatten(const std::vector &tensors, int64_t *shape, int64_t *strides) { - auto inputA = tensors.back(); - if (inputA == nullptr) { - MS_LOGW("input tensors.back() is nullptr"); - return; - } - for (int32_t i = 0; i < inputA->ndim; ++i) { - *shape *= inputA->shape[i]; - } - for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i]->ndim = 1; - tensors[i]->shape = shape; - tensors[i]->strides = strides; - } -} - -std::string GetEltwiseMode(const OpCommonAttr &opAttr, const mindspore::predict::OpDef &opdef) { - const auto optype = opAttr.optype; - std::string mode = "add"; - if (optype == "Eltwise") { - auto op_mode = opdef.attr_as_Eltwise()->mode(); - if (mindspore::predict::EltwiseMode_PROD == op_mode) { - mode = "multiply"; - } else if (mindspore::predict::EltwiseMode_SUM == op_mode) { - mode = "add"; - } else if (mindspore::predict::EltwiseMode_MAXIMUM == op_mode) { - mode = "maximum"; - } - } else { - if ("Add" == optype) { - mode = "add"; - } else if ("Sub" == optype) { - mode = "subtract"; - } else if ("Mul" == optype) { - mode = "multiply"; - } else if ("RealDiv" == optype) { - mode = "divide"; - } else if ("Maximum" == optype) { - mode = "maximum"; - } - } - return mode; -} - -bool IsSwap(const std::vector &tensors) { - auto CalShape = [](DLTensor *tensor) -> int { - int res = 1; - if (tensor == nullptr) { - MS_LOGE("input DLTensor is nullptr"); - return -1; - } - for (int i = 0; i < tensor->ndim; ++i) { - res *= tensor->shape[i]; - } - return res; - }; - - MS_ASSERT(tensors[0] != nullptr); - MS_ASSERT(tensors[1] != nullptr); - auto ndimA = tensors[0]->ndim; - auto ndimB = tensors[1]->ndim; - bool isSwap = false; - - if (ndimA <= ndimB) { - auto AShape = CalShape(tensors[0]); - auto BShape = CalShape(tensors[1]); - if (AShape < BShape) { - isSwap = true; - } - } - return isSwap; -} - -static runnerType GetKernel_Eltwise(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - std::string mode = GetEltwiseMode(opAttr, opdef); - - // make fid - int indexA = 0; - int indexB = 1; - MS_ASSERT(tensors[0] != nullptr); - MS_ASSERT(tensors[1] != nullptr); - auto ndimA = tensors[0]->ndim; - auto ndimB = tensors[1]->ndim; - - bool isSwap = IsSwap(tensors); - if (isSwap) { - std::swap(ndimA, ndimB); - std::swap(indexA, indexB); - } - - MS_ASSERT(tensors[indexA] != nullptr); - MS_ASSERT(tensors[indexB] != nullptr); - if (ndimA == 1 && tensors[indexA]->shape[NCHW_N] == 1) { - ndimA = 0; - } - if (ndimB == 1 && tensors[indexB]->shape[NCHW_N] == 1) { - ndimB = 0; - } - bool is_same = ndimA == ndimB && ndimA > 1; - for (int i = 0; i < tensors[indexB]->ndim && is_same; ++i) { - if (tensors[indexB]->shape[i] != tensors[indexA]->shape[i]) { - is_same = false; - } - } - for (int i = 0; i < tensors[indexB]->ndim && ndimB > 1 && is_same == false; ++i) { - if (tensors[indexB]->shape[i] == 1) { - ndimB--; - } - } - - if (ndimA == ndimB && ndimA >= 1) { - std::string fid = "Eltwise_" + mode + "_ndimA1_ndimB1" + "_" + opAttr.dtype; - auto fun = GetKernel(fid); - if (fun == nullptr) { - MS_LOGE("GetKernel return nullptr"); - return nullptr; - } - auto runner = [fun, isSwap](const std::vector &tensors) -> int { - std::vector tensorsCopy(tensors); - if (isSwap) { - iter_swap(tensorsCopy.begin(), tensorsCopy.begin() + 1); - } - int ndim[TENSOR_NUM_MAX]; - int64_t *shapes[TENSOR_NUM_MAX]; - int64_t *strides[TENSOR_NUM_MAX]; - int64_t shape = 1; - int64_t stride = 1; - - store_shape(tensorsCopy, ndim, shapes, strides, STORE_MODE); - update_shape_flatten(tensorsCopy, &shape, &stride); - fun(tensorsCopy); - store_shape(tensorsCopy, ndim, shapes, strides, RESUME_MODE); - return 0; - }; - return runner; - } else { - std::string fid = - "Eltwise_" + mode + "_ndimA" + std::to_string(ndimA) + "_ndimB" + std::to_string(ndimB) + "_" + opAttr.dtype; - auto fun = GetKernel(fid); - if (fun == nullptr) { - MS_LOGE("GetKernel return nullptr"); - return nullptr; - } - auto runner = [fun, isSwap](const std::vector &tensors) -> int { - std::vector tensorsCopy(tensors); - if (isSwap) { - iter_swap(tensorsCopy.begin(), tensorsCopy.begin() + 1); - } - - fun(tensorsCopy); - return 0; - }; - return runner; - } -} - -static runnerType GetKernel_Resize(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - if (tensors.size() != 2) { - MS_LOGE("Input tensors num should be 2 !"); - return nullptr; - } - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Resize(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Resize() is nullptr"); - return nullptr; - } - std::string fid = "Resize_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype; - if (op->method() == mindspore::predict::ResizeMethod::ResizeMethod_NEAREST_NEIGHBOR) { - fid += "_nearest_neighbor"; - } else if (op->method() == mindspore::predict::ResizeMethod::ResizeMethod_BILINEAR) { - fid += "_bilinear"; - } - fid += (op->alignCorners()) ? "_Align" : "_NotAlign"; - std::vector HeightWidth = { - static_cast(op->newHeight()), - static_cast(op->newWidth()), - }; - return GetKernel_Insert_vector_int32(fid, HeightWidth); -} - -static runnerType GetKernel_DataCarry(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - auto runner = [](const std::vector &tensors) -> int { - auto input = tensors.front(); - auto output = tensors.back(); - if (input == nullptr) { - MS_LOGE("input tensors.front() is nullptr"); - return 1; - } - if (output == nullptr) { - MS_LOGE("input tensors.back() is nullptr"); - return 1; - } - uint64_t input_num = 1; - for (int i = 0; i < input->ndim; ++i) { - input_num *= input->shape[i]; - } - uint64_t input_byte_num = input_num * input->dtype.lanes * input->dtype.bits / 8; - - uint64_t output_num = 1; - for (int i = 0; i < output->ndim; ++i) { - output_num *= output->shape[i]; - } - uint64_t output_byte_num = output_num * output->dtype.lanes * output->dtype.bits / 8; - - errno_t ret = memcpy_s(output->data, output_byte_num, input->data, input_byte_num); - if (ret != 0) { - MS_LOGE("memset_s failed."); - return ret; - } - return 0; - }; - return runner; -} - -static runnerType GetKernel_Shape(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - auto runner = [](const std::vector &tensors) -> int { - auto input = tensors.front(); - auto output = tensors.back(); - if (input == nullptr) { - MS_LOGE("input tensors.front() is nullptr"); - return 1; - } - if (output == nullptr) { - MS_LOGE("input tensors.back() is nullptr"); - return 1; - } - for (int i = 0; i < input->ndim; ++i) { - reinterpret_cast(output->data)[i] = static_cast(input->shape[i]); - } - return 0; - }; - return runner; -} - -void StridedSliceArgs(const std::vector &input_shape, std::vector *begin, std::vector *end, - std::vector *stride, uint32_t begin_mask, uint32_t end_mask, uint32_t ellipsis_mask, - uint32_t new_axis_mask, uint32_t shrink_axis_mask) { - MS_ASSERT(begin != nullptr); - MS_ASSERT(end != nullptr); - MS_ASSERT(stride != nullptr); - constexpr int support_dims = 8; - std::bitset begin_list(begin_mask); - std::bitset end_list(end_mask); - std::bitset ellipsis_list(ellipsis_mask); - std::bitset new_axis_list(new_axis_mask); - std::bitset shrink_list(shrink_axis_mask); - - std::string begin_list_s = begin_list.to_string().substr(support_dims - begin->size()); - reverse(begin_list_s.begin(), begin_list_s.end()); - - std::string end_list_s = end_list.to_string().substr(support_dims - end->size()); - reverse(end_list_s.begin(), end_list_s.end()); - - std::string ellipsis_list_s = ellipsis_list.to_string().substr(support_dims - end->size()); - reverse(ellipsis_list_s.begin(), ellipsis_list_s.end()); - - std::string new_axis_list_s = new_axis_list.to_string().substr(support_dims - end->size()); - reverse(new_axis_list_s.begin(), new_axis_list_s.end()); - - std::string shrink_list_s = shrink_list.to_string().substr(support_dims - end->size()); - reverse(shrink_list_s.begin(), shrink_list_s.end()); - - int new_axis_count = new_axis_list.count(); - if (ellipsis_list.any()) { - auto idx = 0; // ellipsis_list._Find_first(); - // the 1 is ellipsis - int ellipsis_length = input_shape.size() - (begin->size() - 1 - new_axis_count); - begin->erase(begin->begin() + idx); - end->erase(end->begin() + idx); - stride->erase(stride->begin() + idx); - - begin_list_s.erase(idx, 1); - end_list_s.erase(idx, 1); - ellipsis_list_s.erase(idx, 1); - new_axis_list_s.erase(idx, 1); - shrink_list_s.erase(idx, 1); - - if (ellipsis_length > 0) { - begin->insert(begin->begin() + idx, ellipsis_length, 0); - end->insert(end->begin() + idx, ellipsis_length, 0); - stride->insert(stride->begin() + idx, ellipsis_length, 1); - begin_list_s.insert(idx, ellipsis_length, '1'); - end_list_s.insert(idx, ellipsis_length, '1'); - ellipsis_list_s.insert(idx, ellipsis_length, '0'); - new_axis_list_s.insert(idx, ellipsis_length, '0'); - shrink_list_s.insert(idx, ellipsis_length, '0'); - } - } - - if (new_axis_count) { - for (int i = static_cast(new_axis_list_s.size()) - 1; i >= 0; i--) { - if (new_axis_list_s[i] == '1') { - begin->erase(begin->begin() + i); - end->erase(end->begin() + i); - stride->erase(stride->begin() + i); - begin_list_s.erase(i, 1); - end_list_s.erase(i, 1); - shrink_list_s.erase(i, 1); - } - } - } - - unsigned int size = begin->size(); - for (unsigned int i = 0; i < size; i++) { - if (shrink_list_s[i] == '1') { - auto beginItr = (begin->begin() + i); - auto endItr = (end->begin() + i); - auto strideItr = (stride->begin() + i); - *endItr = *beginItr + 1; - *strideItr = 1; - continue; - } - if (begin_list_s[i] == '1') { - auto beginItr = (begin->begin() + i); - *beginItr = 0; - } - if (end_list_s[i] == '1') { - auto endItr = (end->begin() + i); - *endItr = input_shape[i]; - } - } -} - -#define MAXDIMS 10 -template -int StridedSlice(const std::vector &input_shape, T *input, T *output, int *start, int *end, int *stride, - const int &output_size) { - MS_ASSERT(input != nullptr); - MS_ASSERT(output != nullptr); - MS_ASSERT(start != nullptr); - MS_ASSERT(end != nullptr); - MS_ASSERT(stride != nullptr); - int dimension = input_shape.size(); - if (dimension == 1) { - if (*stride == 1) { - int sizes = (*end - *start) * sizeof(T); - errno_t ret = memcpy_s(output, output_size * sizeof(T), input + *start, sizes); - if (ret != 0) { - MS_LOGE("memset_s failed: %d", ret); - return ret; - } - return 0; - } - for (int j = *start, i = 0; j < *end; j += (*stride), i++) { - output[i] = input[j]; - } - return 0; - } - - // adapt higher dimension - int dimensionArray[MAXDIMS]; - int factorArray[MAXDIMS]; - int totalElement = 0; - - for (int i = 0; i < dimension; i++) { - dimensionArray[i] = input_shape[i]; - factorArray[i] = i ? factorArray[i - 1] * dimensionArray[i] : dimensionArray[i]; - totalElement = i ? totalElement * dimensionArray[i] : dimensionArray[i]; - } - - int j = 0; - for (int k = 0; k < totalElement; k++) { - bool isValid = true; - for (int i = 0; i < dimension; i++) { - int tmp = (k / (totalElement / factorArray[i])) % dimensionArray[i]; - if (tmp < start[i] || tmp >= end[i]) { - isValid = false; - break; - } - isValid = isValid && ((tmp - start[i]) % stride[i] == 0); - } - if (isValid) { - output[j++] = input[k]; - } - } - - return 0; -} - -static runnerType GetKernel_StridedSlice(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto ndim = opAttr.ndim; - - auto op = opdef.attr_as_StridedSlice(); - if (op == nullptr) { - MS_LOGE("op is nullptr"); - return nullptr; - } - uint32_t begin_mask = op->beginMask(); - uint32_t end_mask = op->endMask(); - uint32_t ellipsis_mask = op->ellipsisMask(); - uint32_t new_axis_mask = op->newAxisMask(); - uint32_t shrink_axis_mask = op->shrinkAxisMask(); - std::vector begin; - std::vector end; - std::vector stride; - for (uint32_t i = 0; i < ndim; ++i) { - begin.push_back(op->begin()->Get(i)); - end.push_back(op->end()->Get(i)); - stride.push_back(op->stride()->Get(i)); - } - - auto runner = [begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, begin, end, - stride](const std::vector &tensors) mutable -> int { - auto input = tensors.front(); - auto output = tensors.back(); - std::vector input_shape; - for (int i = 0; i < input->ndim; ++i) { - input_shape.push_back(input->shape[i]); - } - - int output_size = 1; - for (int i = 0; i < output->ndim; ++i) { - output_size *= output->shape[i]; - } - - StridedSliceArgs(input_shape, &begin, &end, &stride, begin_mask, end_mask, ellipsis_mask, new_axis_mask, - shrink_axis_mask); - - if (input->dtype.lanes != 1) { - MS_LOGE("StridedSlice input.dtype.lanes=%d invalid, only support 1", input->dtype.lanes); - return 1; - } - - if (input->dtype.bits == 16) { - StridedSlice(input_shape, reinterpret_cast(input->data), reinterpret_cast(output->data), - begin.data(), end.data(), stride.data(), output_size); - } else if (input->dtype.bits == 32) { - StridedSlice(input_shape, reinterpret_cast(input->data), reinterpret_cast(output->data), - begin.data(), end.data(), stride.data(), output_size); - } else { - MS_LOGE("StridedSlice input.dtype.bits=%d invalid, only support (16, 32)", input->dtype.bits); - return 1; - } - return 0; - }; - return runner; -} - -template -static void Permute4d(DLTensor *src, DLTensor *dst, const std::vector &shape, - const std::vector &strides) { - MS_ASSERT(src != nullptr); - MS_ASSERT(dst != nullptr); - int64_t N = shape[NCHW_N]; - int64_t C = shape[NCHW_C]; - int64_t H = shape[NCHW_H]; - int64_t W = shape[NCHW_W]; - auto src_data = reinterpret_cast(src->data); - auto dst_data = reinterpret_cast(dst->data); - int k = 0; - for (int n = 0; n < N; n++) - for (int c = 0; c < C; c++) - for (int h = 0; h < H; h++) - for (int w = 0; w < W; w++) { - dst_data[k++] = src_data[n * strides[0] + c * strides[1] + h * strides[2] + w * strides[3]]; - } -} - -static runnerType GetKernel_CaffeCrop(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_CaffeCrop(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_CaffeCrop() is nullptr"); - return nullptr; - } - std::string fid = - "CaffeCrop_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis" + std::to_string(op->axis()); - - std::vector offsets(op->offsets()->size()); - for (size_t i = 0; i < offsets.size(); ++i) { - offsets[i] = op->offsets()->Get(i); - } - return GetKernel_Insert_vector_int32(fid, offsets); -} - -static runnerType GetKernel_CaffePReLU(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_CaffePReLU(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_CaffePReLU() is nullptr"); - return nullptr; - } - std::string fid = "CaffePReLU_ndim4_" + opAttr.dtype; - fid += (op->channelShared()) ? "_channelShared" : "_channelNotShared"; - auto fun = GetKernel(fid); - if (fun == nullptr) { - return nullptr; - } - bool reahpe_output = true; - return Pack_reahpeCHW(fun, tensors, ExpandCHW, reahpe_output); -} - -static runnerType GetKernel_FullConnection(const mindspore::predict::OpDef &opdef, - const std::vector &tensors, const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_FullConnection(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_FullConnection() is nullptr"); - return nullptr; - } - std::string fid = "FullConnection_ndimA2_" + opAttr.dtype; - fid += (op->hasBias()) ? "_hasBias" : "_notHasBias"; - auto fun = GetKernel(fid); - if (fun == nullptr) { - return nullptr; - } - bool reahpe_output = false; - return Pack_reahpeCHW(fun, tensors, FusedCHW, reahpe_output); -} - -static runnerType GetKernel_Power(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Power(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Power() is nullptr"); - return nullptr; - } - std::string fid = "Power_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype; - std::vector pss; - pss.push_back(op->power()); - pss.push_back(op->scale()); - pss.push_back(op->shift()); - return GetKernel_Insert_vector_float(fid, pss); -} - -static runnerType GetKernel_ArgMax(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_ArgMax(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_ArgMax() is nullptr"); - return nullptr; - } - std::string fid = - "ArgMax_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis" + std::to_string(op->axis()); - fid += (op->keepDims()) ? "_keepDims" : "_notKeepDims"; - fid += "_top1"; - if (tensors.back() == nullptr) { - MS_LOGE("tensors.back() is nullptr"); - return nullptr; - } - fid += "_" + get_dtype(*tensors.back()); - return GetKernel(fid); -} - -void update_shape_Concat(const std::vector &tensors, int32_t axis, int64_t (&shape)[TENSOR_NUM_MAX][3], - int64_t (&strides)[TENSOR_NUM_MAX][3]) { - int64_t shape_low_dim = 1; - int64_t shape_high_dim = 1; - auto output = tensors.back(); - if (output == nullptr) { - MS_LOGW("tensors.back() is nullptr"); - return; - } - for (int32_t i = 0; i < axis; ++i) { - shape_high_dim *= output->shape[i]; - } - for (int32_t i = axis + 1; i < output->ndim; ++i) { - shape_low_dim *= output->shape[i]; - } - for (size_t i = 0; i < tensors.size(); ++i) { - shape[i][0] = shape_high_dim; - shape[i][1] = tensors[i]->shape[axis]; - shape[i][2] = shape_low_dim; - - strides[i][2] = 1; - strides[i][1] = shape[i][2]; - strides[i][0] = shape[i][2] * shape[i][1]; - - tensors[i]->ndim = 3; - tensors[i]->shape = shape[i]; - tensors[i]->strides = strides[i]; - } -} - -static runnerType GetKernel_Concat(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - if (tensors.size() < 2) { - MS_LOGE("Concat should have at least two tensors"); - return nullptr; - } - if (tensors.at(0) == nullptr) { - MS_LOGE("0th tensors of Concat is nullptr"); - return nullptr; - } - auto ndim = tensors.at(0)->ndim; - if (opdef.attr_as_Concat() == nullptr) { - MS_LOGE("opdef.attr_as_Concat() is nullptr"); - return nullptr; - } - auto axis = opdef.attr_as_Concat()->axis(); - if (axis < 0) { - axis += ndim; - } - std::string fid = - "Concat_ndim3_" + opAttr.dtype + "_input_num" + std::to_string(static_cast(tensors.size()) - 1) + "_axis1"; - auto fun = GetKernel(fid); - if (fun == nullptr) { - MS_LOGE("GetKernel return nullptr"); - return nullptr; - } - auto runner = [fun, axis](const std::vector &tensors) -> int { - int ndim[TENSOR_NUM_MAX]; - int64_t *shape[TENSOR_NUM_MAX]; - int64_t *strides[TENSOR_NUM_MAX]; - int64_t shape_C[TENSOR_NUM_MAX][3]; - int64_t strides_C[TENSOR_NUM_MAX][3]; - store_shape(tensors, ndim, shape, strides, STORE_MODE); - update_shape_Concat(tensors, axis, shape_C, strides_C); - fun(tensors); - store_shape(tensors, ndim, shape, strides, RESUME_MODE); - return 0; - }; - return runner; -} - -template -void Stack_ScaleNumber(const std::vector &tensors) { - if (tensors.empty()) { - MS_LOGW("input tensors is nullptr"); - return; - } - auto output = tensors.back(); - if (output != nullptr) { - MS_LOGW("tensors.back() is nullptr"); - return; - } - for (int i = 0; i < static_cast(tensors.size()) - 1; i++) { - reinterpret_cast(output->data)[i] = reinterpret_cast(tensors.at(i)->data)[0]; - } -} - -static runnerType GetKernel_Stack(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Stack(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Stack() is nullptr"); - return nullptr; - } - if (op->isScale()->Get(0)) { - auto runner = [](const std::vector &tensors) -> int { - auto input = tensors.front(); - if (input->dtype.bits == 8) { - Stack_ScaleNumber(tensors); - } else if (input->dtype.bits == 16) { - Stack_ScaleNumber(tensors); - } else if (input->dtype.bits == 32) { - Stack_ScaleNumber(tensors); - } else if (input->dtype.bits == 64) { - Stack_ScaleNumber(tensors); - } else { - MS_LOGE("StridedSlice input.dtype.bits=%d invalid, only support (8, 16, 32, 64)", input->dtype.bits); - return 1; - } - return 0; - }; - return runner; - } - std::string fid = "Stack_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_input_num" + - std::to_string(static_cast(tensors.size()) - 1) + "_axis" + std::to_string(op->axis()); - return GetKernel(fid); -} - -static runnerType GetKernel_Pad(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Pad(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Pad() is nullptr"); - return nullptr; - } - std::string fid = "Pad_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_" + - mindspore::predict::EnumNamePaddingMode(op->paddingmode()); - std::vector paddings(op->paddings()->size()); - for (size_t i = 0; i < paddings.size(); ++i) { - paddings[i] = op->paddings()->Get(i); - } - return GetKernel_Insert_vector_int32(fid, paddings); -} - -static runnerType GetKernel_Pooling(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Pooling(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Pooling() is nullptr"); - return nullptr; - } - auto H = tensors.front()->shape[NCHW_H]; - auto W = tensors.front()->shape[NCHW_W]; - auto padUp = op->padUp(); - auto padDown = op->padDown(); - auto padLeft = op->padLeft(); - auto padRight = op->padRight(); - bool useGlobal = false; - if (H == op->windowH() && W == op->windowW()) { - useGlobal = true; - } - if (op->padMode() != mindspore::predict::PadMode_VALID) { - int64_t outputHeight = tensors.back()->shape[NCHW_H]; - int64_t outputWidth = tensors.back()->shape[NCHW_W]; - - int64_t dHeight = (outputHeight - 1) * op->strideH() + op->windowH() - H; - int64_t dWidth = (outputWidth - 1) * op->strideW() + op->windowW() - W; - padUp = dHeight / 2; - padDown = dHeight - dHeight / 2; - padLeft = dWidth / 2; - padRight = dWidth - dWidth / 2; - if (padDown < 0) { - padDown = 0; - } - if (padRight < 0) { - padRight = 0; - } - } - std::string poolingMode = mindspore::predict::EnumNamesPoolMode()[op->poolingMode()]; - if (poolingMode != "MAX_POOLING" && poolingMode != "MEAN_POOLING") { - MS_LOGE("Pooling op not support poolingMode=%s", poolingMode.c_str()); - return nullptr; - } - - std::string fid; - fid += useGlobal ? "GlobalPooling" : "Pooling"; - fid += "_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype; - fid += (poolingMode == "MAX_POOLING") ? "_max" : "_avg"; - if (!useGlobal) { - fid += "_kernel" + std::to_string(op->windowH()) + std::to_string(op->windowW()); - fid += "_stride" + std::to_string(op->strideH()) + std::to_string(op->strideW()); - fid += - "_pad" + std::to_string(padUp) + std::to_string(padDown) + std::to_string(padLeft) + std::to_string(padRight); - if (op->caffeMode() && (padUp || padDown || padLeft || padRight)) fid += "_caffe"; - } - auto fun = GetKernel(fid); - if (fun == nullptr) { - MS_LOGE("GetKernel return nullptr"); - return nullptr; - } - return (opdef.isLastConv()) ? Pack_NCHW2NHWC(fun) : fun; -} - -static runnerType GetKernel_Mean(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Mean(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Mean() is nullptr"); - return nullptr; - } - std::string axis_str = ""; - for (uint32_t i = 0; i < op->axis()->size(); ++i) { - axis_str += std::to_string(op->axis()->Get(i)); - } - std::string fid = "Mean_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis" + axis_str; - fid += (op->keepDims()) ? "_keepDims" : "_notkeepDims"; - return GetKernel(fid); -} - -static runnerType GetKernel_MatMul(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_MatMul(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_MatMul() is nullptr"); - return nullptr; - } - std::string fid = "MatMul_ndimA2_ndimB2_" + opAttr.dtype; - fid += (op->transposeA()) ? "_1" : "_0"; - fid += (op->transposeB()) ? "_1" : "_0"; - return GetKernel(fid); -} - -static runnerType GetKernel_Softmax(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_SoftMax(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_SoftMax() is nullptr"); - return nullptr; - } - std::string fid = - "Softmax_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis" + std::to_string(op->axis()->Get(0)); - return GetKernel(fid); -} - -void update_shape_Activation(const std::vector &tensors, int64_t *shape, int64_t *strides) { - auto input = tensors.front(); - MS_ASSERT(input != nullptr); - for (int32_t i = 0; i < input->ndim; ++i) { - *shape *= input->shape[i]; - } - for (size_t i = 0; i < tensors.size(); ++i) { - MS_ASSERT(tensors[i] != nullptr); - tensors[i]->ndim = 1; - tensors[i]->shape = shape; - tensors[i]->strides = strides; - } -} - -static runnerType GetKernel_Activation(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Activation(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Activation() is nullptr"); - return nullptr; - } - std::string fid = - "Activation_ndim1_" + opAttr.dtype + "_" + std::string(mindspore::predict::EnumNameActivationType(op->type())); - - auto fun = GetKernel(fid); - if (fun == nullptr) { - MS_LOGE("GetKernel return nullptr"); - return nullptr; - } - auto runner = [fun](const std::vector &tensors) -> int { - int ndim[TENSOR_NUM_MAX]; - int64_t *shapes[TENSOR_NUM_MAX]; - int64_t *strides[TENSOR_NUM_MAX]; - int64_t shape = 1; - int64_t stride = 1; - - store_shape(tensors, ndim, shapes, strides, STORE_MODE); - update_shape_Activation(tensors, &shape, &stride); - fun(tensors); - store_shape(tensors, ndim, shapes, strides, RESUME_MODE); - return 0; - }; - return runner; -} - -static runnerType GetKernel_Exp(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - std::string fid = "Exp_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype; - return GetKernel(fid); -} - -static runnerType GetKernel_Cast(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - MS_ASSERT(tensors.front() != nullptr); - MS_ASSERT(tensors.back() != nullptr); - auto src_dtype = get_dtype(*tensors.front()); - auto dst_dtype = get_dtype(*tensors.back()); - std::string fid = "Cast_ndim" + std::to_string(tensors.front()->ndim) + "_" + src_dtype + "_" + dst_dtype; - return GetKernel(fid); -} - -static runnerType GetKernel_ExpandDims(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_ExpandDims(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_ExpandDims() is nullptr"); - return nullptr; - } - std::string fid = - "ExpandDims_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype + "_axis" + std::to_string(op->dim()); - return GetKernel(fid); -} - -static runnerType GetKernel_Tile(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Tile(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Tile() is nullptr"); - return nullptr; - } - std::string fid = "Tile_ndim" + std::to_string(opAttr.ndim) + "_" + opAttr.dtype; - std::vector multiples; - for (size_t i = 0; i < op->multiples()->size(); ++i) { - multiples.push_back(op->multiples()->Get(i)); - } - return GetKernel_Insert_vector_int32(fid, multiples); -} - -static runnerType GetKernel_Range(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - const OpCommonAttr opAttr(opdef, tensors); - auto op = opdef.attr_as_Range(); - if (op == nullptr) { - MS_LOGE("opdef.attr_as_Range() is nullptr"); - return nullptr; - } - std::string fid = "Range_ndim_" + opAttr.dtype; - std::vector vec = {static_cast(op->start()), static_cast(op->delta())}; - - auto f = GetFunction(fid); - if (f == nullptr) { - MS_LOGE("GetFunction returu nullptr"); - return nullptr; - } - auto runner = [f, vec](const std::vector &tensors_origin) -> int { - // remove 3 input, only remain output - const std::vector tensors = {tensors_origin.back()}; - std::vector values(vec.size() + tensors.size()); - std::vector codes(values.size()); - tvm::runtime::TVMArgsSetter setter(values.data(), codes.data()); - for (size_t i = 0; i < vec.size(); ++i) { - setter(i, vec.at(i)); - } - for (size_t i = 0; i < tensors.size(); ++i) { - setter(i + vec.size(), tensors.at(i)); - } - auto passfunc = reinterpret_cast(f); - return passfunc(values.data(), codes.data(), values.size()); - }; - return runner; -} - -using GetKernelFunType = std::function &tensors, const KernelOption &option)>; - -static const std::unordered_map g_kernel_op_list = { - {"Conv2D", GetKernel_Conv}, - {"DepthwiseConv2D", GetKernel_Conv}, - {"DeDepthwiseConv2D", GetKernel_Conv}, - {"DeConv2D", GetKernel_Conv_var}, - {"BatchNorm", GetKernel_BatchNorm}, - {"CaffeBatchNorm", GetKernel_BatchNorm}, - {"BiasAdd", GetKernel_BatchNorm}, - {"Scale", GetKernel_BatchNorm}, - {"Eltwise", GetKernel_Eltwise}, - {"Add", GetKernel_Eltwise}, - {"Sub", GetKernel_Eltwise}, - {"Mul", GetKernel_Eltwise}, - {"RealDiv", GetKernel_Eltwise}, - {"Maximum", GetKernel_Eltwise}, - {"ResizeBilinear", GetKernel_Resize}, - {"ResizeNearestNeighbor", GetKernel_Resize}, - {"Squeeze", GetKernel_DataCarry}, - {"Reshape", GetKernel_DataCarry}, - {"Shape", GetKernel_Shape}, - {"StridedSlice", GetKernel_StridedSlice}, - {"CaffeCrop", GetKernel_CaffeCrop}, - {"CaffePReLU", GetKernel_CaffePReLU}, - {"FullConnection", GetKernel_FullConnection}, - {"Power", GetKernel_Power}, - {"ArgMax", GetKernel_ArgMax}, - {"Concat", GetKernel_Concat}, - {"Stack", GetKernel_Stack}, - {"Pad", GetKernel_Pad}, - {"Pooling", GetKernel_Pooling}, - {"Mean", GetKernel_Mean}, - {"MatMul", GetKernel_MatMul}, - {"SoftMax", GetKernel_Softmax}, - {"Activation", GetKernel_Activation}, - {"Exp", GetKernel_Exp}, - {"Cast", GetKernel_Cast}, - {"ExpandDims", GetKernel_ExpandDims}, - {"Tile", GetKernel_Tile}, - {"Range", GetKernel_Range}, -}; - -GetKernelFunType Get_GetKernelFun(const std::string &optype) { - auto it = g_kernel_op_list.find(optype); - return (it != g_kernel_op_list.end() ? it->second : nullptr); -} -} // namespace auto_tensor - -runnerType GetKernel(const mindspore::predict::OpDef &opdef, const std::vector &tensors, - const KernelOption &option) { - std::string optype = mindspore::predict::EnumNameOpT(opdef.attr_type()); - auto GetKernelFun = auto_tensor::Get_GetKernelFun(optype); - if (GetKernelFun != nullptr) { - return GetKernelFun(opdef, tensors, option); - } else { - return nullptr; - } -} - -int CallKernel(const std::string &fid, const std::vector &tensors) { - auto runner = GetKernel(fid); - return runner(tensors); -} - -#endif // LITE_RUNTIME_ON diff --git a/predict/module/tvm_kernel/lite/src/api/tvm_op_module.cc b/predict/module/tvm_kernel/lite/src/api/tvm_op_module.cc deleted file mode 100644 index c6a6fc93f5..0000000000 --- a/predict/module/tvm_kernel/lite/src/api/tvm_op_module.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this ${file} except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "src/api/tvm_op_module.h" -#include "common/op_utils.h" -#include "lite/api/km_api.h" -#include "src/op.h" -namespace mindspore { -namespace predict { -using OpFunc = std::function &)>; - -class TVMOperator : public OpBase { - public: - explicit TVMOperator(OpFunc func) : opfunc(std::move(func)) {} - ~TVMOperator() override = default; - int Init(const std::vector &inputs, const std::vector &outputs) override { return 0; } - int Execute(const std::vector &inputs, const std::vector &outputs) override { - std::vector dlT; - for (auto input : inputs) { - MS_ASSERT(input != nullptr); - dlT.push_back(input->GetDLTensor()); - } - - for (auto output : outputs) { - MS_ASSERT(output != nullptr); - dlT.push_back(output->GetDLTensor()); - } - return opfunc(dlT); - } - - static OpBase *CreateOp(const std::vector &inputs, const std::vector &outputs, const OpDef &opDef, - const Context &ctx, const OpDesc &desc) { - std::vector dlT; - for (auto input : inputs) { - MS_ASSERT(input != nullptr); - dlT.push_back(input->GetDLTensor()); - } - - for (auto output : outputs) { - MS_ASSERT(output != nullptr); - dlT.push_back(output->GetDLTensor()); - } - - KernelOption option; - option.numThreads = ctx.threadNum; - OpFunc opFunc = GetKernel(opDef, dlT, option); - if (opFunc != nullptr) { - auto op = std::unique_ptr(new (std::nothrow) TVMOperator(opFunc)); - if (op == nullptr) { - MS_LOGE("new TVMOperator failed"); - } - return op.release(); - } - return nullptr; - } - - private: - OpFunc opfunc; - std::vector dltensors; -}; - -TVMOpRegistry::TVMOpRegistry() = default; - -mindspore::predict::OpCreator TVMOpRegistry::GetOpCreator(const mindspore::predict::OpDesc &desc) { - return TVMOperator::CreateOp; -} - -OpRegistry *TVMOpModule::GetInstance() { - static TVMOpRegistry tvmOpRegistry; - return &tvmOpRegistry; -} - -static TVMOpModule tvmOpModule; - -static ModuleRegistrar> g_tvmOpReg(MODULE_REG_NAME_OP_REGISTRY, tvmOpModule); -} // namespace predict -} // namespace mindspore diff --git a/predict/module/tvm_kernel/lite/src/api/tvm_op_module.h b/predict/module/tvm_kernel/lite/src/api/tvm_op_module.h deleted file mode 100644 index 3f0f33142d..0000000000 --- a/predict/module/tvm_kernel/lite/src/api/tvm_op_module.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this ${file} except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef PREDICT_MODULE_TVM_KERNEL_LITE_SRC_API_TVM_OP_MODULE_H_ -#define PREDICT_MODULE_TVM_KERNEL_LITE_SRC_API_TVM_OP_MODULE_H_ - -#include "src/op_registry.h" -namespace mindspore { -namespace predict { -class TVMOpRegistry : public OpRegistry { - public: - TVMOpRegistry(); - OpCreator GetOpCreator(const OpDesc &desc) override; -}; - -class TVMOpModule : public Module { - public: - OpRegistry *GetInstance() override; -}; -} // namespace predict -} // namespace mindspore -#endif // PREDICT_MODULE_TVM_KERNEL_LITE_SRC_API_TVM_OP_MODULE_H_ diff --git a/predict/module/tvm_kernel/lite/src/codegen/llvm/lite_rtfunc_reset.cc b/predict/module/tvm_kernel/lite/src/codegen/llvm/lite_rtfunc_reset.cc deleted file mode 100644 index df19676b40..0000000000 --- a/predict/module/tvm_kernel/lite/src/codegen/llvm/lite_rtfunc_reset.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this ${file} except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "tvm/api_registry.h" - -namespace tvm { -namespace codegen { -class LiteRTFuncReseter { - public: - LiteRTFuncReseter() {} - ~LiteRTFuncReseter() {} - int InsertFuncPair(std::string sfunc, std::string dfunc) { - CHECK_NE(sfunc.size(), 0); - CHECK_NE(dfunc.size(), 0); - func_map_[sfunc] = dfunc; - return 0; - } - - /* - * the llvm::Function::Create need a longe life scope const char* as input - * so here not return block life scopte tmp std::string. - */ - const char* GetResetFunc(std::string sfunc) { - CHECK_NE(sfunc.size(), 0); - auto it_dfunc = func_map_.find(sfunc); - if (it_dfunc != func_map_.end()) { - return it_dfunc->second.c_str(); - } else { - func_map_[sfunc] = sfunc; - return func_map_[sfunc].c_str(); - } - } - - /* - * not real delete item paire, just set orig function pair - */ - int DeleteFuncPair(std::string sfunc) { - CHECK_NE(sfunc.size(), 0); - func_map_[sfunc] = sfunc; - return 0; - } - static LiteRTFuncReseter* GetRTFuncReseter() { - static LiteRTFuncReseter inst; - return &inst; - } - - private: - std::map func_map_; -}; - -TVM_REGISTER_API("codegen.SetRTFuncTransPair").set_body([](const TVMArgs& targs, TVMRetValue* rv) { - *rv = LiteRTFuncReseter::GetRTFuncReseter()->InsertFuncPair(targs[0], targs[1]); -}); - -TVM_REGISTER_API("codegen.DelRTFuncTransPair").set_body([](const TVMArgs& targs, TVMRetValue* rv) { - *rv = LiteRTFuncReseter::GetRTFuncReseter()->DeleteFuncPair(targs[0]); -}); - -/* - * now no operator=(const char* ) provide for TVMRetValue - * here using explicit operator call function to make sure not using operator=(int) - */ -TVM_REGISTER_API("codegen.GetTransRTFunc").set_body([](const TVMArgs& targs, TVMRetValue* rv) { - (*rv).operator=( - reinterpret_cast(const_cast(LiteRTFuncReseter::GetRTFuncReseter()->GetResetFunc(targs[0])))); -}); -} // namespace codegen -} // namespace tvm diff --git a/predict/schema/inner/README b/predict/schema/inner/README deleted file mode 100644 index f397efe4c5..0000000000 --- a/predict/schema/inner/README +++ /dev/null @@ -1 +0,0 @@ -for flatbuf auto generated file diff --git a/predict/schema/ms.fbs b/predict/schema/ms.fbs deleted file mode 100644 index f66abf8a86..0000000000 --- a/predict/schema/ms.fbs +++ /dev/null @@ -1,153 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -include "op.fbs"; - -namespace mindspore.predict; - -enum DataType : int { - DT_FLOAT = 0, - DT_FLOAT16 = 1, - DT_INT8 = 2, - DT_INT32 = 3, - DT_UINT8 = 4, - DT_UINT32 = 8, - DT_UNDEFINED = 16 -} - -enum Format : int { - NCHW = 0, - NHWC, - NC4HW4 = 100, - NUM_OF_FORMAT -} - -enum MSConst: int { - WEIGHT_REFCOUNT = 999 -} - -table QuantizationDef { - // Quantized value q, corresponding float value r: - // r = scale * (q - zero_point), where scale = (rmax - rmin) / (qmax - qmin) - min: [float]; - max: [float]; - scale: [float]; - zero_point: [long]; - - // Tensor shape of the specifies dimension. - dimension: int; -} - -table TensorDef { - // data type - dataType: DataType; - // shape - dims: [int]; - format: Format; - refCount: int; - offset: int; - data: [ubyte]; - quantization: QuantizationDef; -} - -union OpT { - Concat, - SoftMax, - Activation, - Conv2D, - FusedBatchNorm, - CaffeBatchNorm, - Squeeze, - BiasAdd, - Pooling, - DepthwiseConv2D, - DeDepthwiseConv2D, - Resize, - DetectionPostProcess, - FullConnection, - Mean, - DeConv2D, - Scale, - Reshape, - Eltwise, - NetOutput, - Add, - MatMul, - StridedSlice, - Power, - Slice, - Stack, - Mul, - Pad, - Maximum, - CaffePReLU, - ArgMax, - Exp, - CaffeCrop, - Range, - ExpandDims, - Tile, - Cast -// Split -} - -enum QuantType: int { - QUANT_NONE, - QUANT_INT8 -} - -table OpDef { - name: string; - attr: OpT; - inputIndex: [uint]; - outputIndex: [uint]; - isLastConv: bool; - quantType: QuantType = QUANT_NONE; -} - - -enum FmkType: int { - TF, - CAFFE -} - -table NodeDef { - fmkType: FmkType; - opDef: OpDef; -} - - -table SubGraphDef { - name: string; - inputIndex: [uint]; - outputIndex: [uint]; - mempoolSize: uint; - nodes: [NodeDef]; - allTensors: [TensorDef]; // weight + input + output -} - -table MempoolCfg { - size: uint; - shiftFactor: uint; -} - -table GraphDef { - name: string; - mempoolCfg: MempoolCfg; - subgraphs: [SubGraphDef]; -} - -root_type GraphDef; diff --git a/predict/schema/op.fbs b/predict/schema/op.fbs deleted file mode 100755 index 4d01bb9c3b..0000000000 --- a/predict/schema/op.fbs +++ /dev/null @@ -1,351 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -namespace mindspore.predict; - -enum ResizeMethod: byte { - UNKNOW = -1, - BILINEAR = 0, - NEAREST_NEIGHBOR = 1 -} - -enum DataFormatType : byte { - UNKNOW = -1, - NCHW = 0, - NHWC = 1, - HWC = 2, // for input image or resize - CHW = 3, // for input image or resize -} - -enum ActivationType : byte { - NO_ACTIVATION = 0, - RELU = 1, - SIGMOID = 2, - RELU6 = 3, - ELU = 4, - LEAKY_RELU = 5, - ABS = 6, - RELU1 = 7, - SOFTSIGN = 8, - SOFTPLUS = 9, - TANH = 10, - UNKNOW = 11 -} - -enum PoolMode : byte { - MAX_POOLING = 0, - MEAN_POOLING = 1, - GLOBAL_POOING = 2 -} - -enum EltwiseMode : byte { - PROD = 0, - SUM = 1, - MAXIMUM = 2 -} - -enum PadMode : byte { - NOTSET=0, - SAME=1, - VALID=2, - CAFFE_CEIL_NEW=4 -} - -enum PaddingMode : byte { - CONSTANT = 0, - REFLECT = 1, - SYMMETRIC = 2, - MODE_RESERVED = 3 -} - -table Pad { - paddingmode: PaddingMode; - paddings: [int]; -} - -table Maximum { - format: DataFormatType = 0; -} - -table Concat { - axis: int; - n: int; -} - -table SoftMax { - axis: [int]; -} - -table Activation { - type: ActivationType = 0; -} - -table Conv2D { - format: DataFormatType = 0; - group: int; - channelIn: int; - channelOut: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - -table FusedBatchNorm { - epsilon: float; // eg. epsilon=0.001 -} - -table CaffeBatchNorm { - epsilon: float; // eg. epsilon=0.001 -} - -table Squeeze { - axis: [int]; -} - -table BiasAdd { - axis: [int]; -} - -table Pooling { - format: DataFormatType = 0; - poolingMode: PoolMode; - windowW: int; - windowH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - caffeMode: bool = false; -} - -table DepthwiseConv2D { - format: DataFormatType = 0; - channelIn: int; - channelMultiplier: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - -table DeDepthwiseConv2D { - format: DataFormatType = 0; - channelIn: int; - channelMultiplier: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - - -table Resize { - format: DataFormatType = 0; - method: ResizeMethod; - newHeight: long; - newWidth: long; - alignCorners: bool = false; - preserveAspectRatio: bool = false; -} - -table DetectionPostProcess { - format: DataFormatType = 0; - inputSize: int; - hScale: float; - wScale: float; - xScale: float; - yScale: float; - NmsIouThreshold: float; - NmsScoreThreshold: float; - MaxDetections: long; - DetectionsPreClass: long; - MaxClassesPreDetection: long; - NumClasses: long; - UseRegularNms: bool; -} - -table FullConnection { - format: DataFormatType = 0; - hasBias: bool; - axis: int; -} - -// Mean(input_tensor, axis, keep_dims) -table Mean { - axis: [int]; - keepDims: bool = false; -} - -table DeConv2D { - format: DataFormatType = 0; - group: int; - channelIn: int; - channelOut: int; - kernelW: int; - kernelH: int; - strideW: int; - strideH: int; - padMode: PadMode; - padUp: int; - padDown: int; - padLeft: int; - padRight: int; - dilateW: int; - dilateH: int; - hasBias: bool = false; - activationType: ActivationType = 0; -} - -table Scale { - format: DataFormatType = 0; -} - -table Eltwise { - format: DataFormatType = 0; - mode: EltwiseMode; -} - -table Add { - format: DataFormatType = 0; -} - -table Slice { - format: DataFormatType = 0; - begin: [int]; - end: [int]; - stride: [int]; -} - -table Mul { -} - -table Exp { -} - -table Reshape { - format: DataFormatType = 0; - shape: [long]; -} - -table Power { - power: float; - scale: float; - shift: float; -} - -table ArgMax { - axis: int; - outMaxValue: bool; - topK: int; - keepDims: bool; - axisType: int; -} - -table NetOutput { - format: DataFormatType = 0; -} - -table MatMul { - transposeA : bool = false; - transposeB : bool = false; -} - -table CaffePReLU { - channelShared : bool = false; -} - -table StridedSlice { - beginMask: int; - endMask: int; - ellipsisMask: int; - newAxisMask: int; - shrinkAxisMask: int; - begin: [int]; - end: [int]; - stride: [int]; - isScale: [int]; -} - -table Stack { - axis: int; - n: int; - isScale: [int]; -} - -table Range { - start: int; - limit: int; - delta: int; -} - -table ExpandDims { - dim: int; -} - -table Tile { - multiples: [int]; -} - -table Cast { - srcT: int; - dstT: int; -} - -table Split { - numberSplit: int; - sizeSplits: [int]; - splitDim: int; -} - -table CaffeCrop { - axis : long; - offsets : [long]; -} - -table Permute { - order: [long]; -} diff --git a/predict/src/CMakeLists.txt b/predict/src/CMakeLists.txt deleted file mode 100644 index 92c45473d7..0000000000 --- a/predict/src/CMakeLists.txt +++ /dev/null @@ -1,69 +0,0 @@ -cmake_minimum_required(VERSION 3.12) -project(mspredict) - -set(CMAKE_CXX_STANDARD 11) - -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/..) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../module/tvm_kernel/incubator-tvm/3rdparty/dlpack/include/) - -link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../output/lib/) - -if (ENABLE_ASAN) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -o0 -fsanitize=address -fno-omit-frame-pointer -fsanitize=undefined") -endif() - -set(MSPREDICT_SRC - runtime/allocator.cc - runtime/allocator.h - runtime/thread_pool.cc - runtime/thread_pool.h - runtime/workspace_pool.cc - runtime/workspace_pool.h - runtime/runtime_api.cc - runtime/runtime_api.h - context.cc - graph.cc - graph.h - graph_execution.cc - graph_execution.h - node.cc - node.h - op.cc - op.h - op_factory.cc - op_factory.h - op_registry.cc - op_registry.h - session.cc - tensor.cc - ${CMAKE_CURRENT_SOURCE_DIR}/operator/cpu/common/op_func_comm.cc) - -set(MSPREDICT_SRC ${MSPREDICT_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../common/mslog.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../common/module_registry.cc) - -add_library(mspredict SHARED ${MSPREDICT_SRC}) - -if(ENABLE_PREDICT_ARM64 OR ENABLE_PREDICT_ARM32) - target_link_libraries(mspredict android log tvm_kernel libsecurec.a) -else() - target_link_libraries(mspredict pthread tvm_kernel libsecurec.a) -endif() - -add_dependencies(mspredict tvm_kernel) -add_dependencies(mspredict securec) -add_dependencies(mspredict gtest) - -add_custom_command(TARGET mspredict POST_BUILD - COMMAND mkdir -pv ${PREDICT_DIR}/output/lib - COMMAND cp ${PREDICT_BUILD_DIR}/src/libmspredict.so ${PREDICT_DIR}/output/lib/ - COMMAND cp ${PREDICT_BUILD_DIR}/module/tvm_kernel/lite/libtvm_kernel.so ${PREDICT_DIR}/output/lib/ - COMMAND mkdir -pv ${PREDICT_DIR}/output/include - COMMAND cp -r ${PREDICT_DIR}/include/* ${PREDICT_DIR}/output/include - COMMAND mkdir -pv ${PREDICT_DIR}/output/include/schema/inner - COMMAND cp ${PREDICT_DIR}/schema/ms_generated.h ${PREDICT_DIR}/output/include/schema/inner - COMMAND cp ${PREDICT_DIR}/schema/op_generated.h ${PREDICT_DIR}/output/include/schema/inner - COMMAND mkdir -pv ${PREDICT_DIR}/output/include/dlpack/ - COMMAND cp ${PREDICT_DIR}/module/tvm_kernel/incubator-tvm/3rdparty/dlpack/include/dlpack/dlpack.h ${PREDICT_DIR}/output/include/dlpack/) diff --git a/predict/src/context.cc b/predict/src/context.cc deleted file mode 100644 index 45ec719db6..0000000000 --- a/predict/src/context.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "include/context.h" -#include "include/tensor.h" -#include "src/runtime/allocator.h" - -namespace mindspore { -namespace predict { -Context::Context() { allocator = Allocator::Create(); } - -Context::~Context() {} - -Context::Context(int threadNum, std::shared_ptr allocator, DLContext deviceCtx) { - this->allocator = allocator; - this->threadNum = threadNum; - this->deviceCtx = deviceCtx; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/graph.cc b/predict/src/graph.cc deleted file mode 100644 index 5e1aab6a56..0000000000 --- a/predict/src/graph.cc +++ /dev/null @@ -1,378 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/graph.h" -#include -#include -#include -#include -#include "schema/ms_generated.h" -#include "common/graph_util.h" -#include "common/mslog.h" -#include "include/errorcode.h" -#include "src/graph_execution.h" - -namespace mindspore { -namespace predict { -static const uint32_t G_MAX_OP_COUNT = 10000; - -Graph *Graph::CreateFromBuf(const char *buf, size_t size, const Context &ctx) { - if (buf == nullptr) { - MS_LOGE("the input buffer is nullptr"); - return nullptr; - } - - flatbuffers::Verifier verify((const uint8_t *)buf, size); - if (!VerifyGraphDefBuffer(verify)) { - MS_LOGE("the buffer is invalid and fail to create graph"); - return nullptr; - } - - auto graphDef = GetGraphDef(buf); - std::unique_ptr graph(new (std::nothrow) Graph()); - if (graph == nullptr) { - MS_LOGE("graph malloc fail"); - return nullptr; - } - auto ret = graph->Build(*graphDef, ctx); - if (ret != RET_OK) { - MS_LOGE("build graph fail"); - return nullptr; - } - return graph.release(); -} - -Graph::Graph() = default; - -Graph::~Graph() { - for (auto &subgraph : subgraphs) { - delete subgraph; - } - subgraphs.clear(); -} - -int Graph::Build(const GraphDef &graphDef, const Context &ctx) { - MS_ASSERT(graphDef.subgraphs() != nullptr); - for (size_t i = 0; i < graphDef.subgraphs()->size(); i++) { - MS_ASSERT(graphDef.subgraphs()->GetAs(i) != nullptr); - SubGraph *subGraph = SubGraph::CreateSubGraph(*(graphDef.subgraphs()->GetAs(i)), ctx); - if (subGraph == nullptr) { - MS_LOGE("converter subgraph failed"); - return RET_ERROR; - } - subgraphs.push_back(subGraph); - auto subDepends = subGraph->GetDepends(); - depends.insert(subDepends.begin(), subDepends.end()); - } - - auto iter = depends.begin(); - while (iter != depends.end()) { - if (iter->second.empty()) { - readyQue.push_back(iter->first); - iter = depends.erase(iter); - } else { - iter++; - } - } - - return RET_OK; -} - -std::vector Graph::GetInputs() { - MS_ASSERT(subgraphs.front() != nullptr); - return subgraphs.front()->GetInputs(); -} - -std::vector Graph::GetOutputs() { - MS_ASSERT(subgraphs.back() != nullptr); - return subgraphs.back()->GetOutputs(); -} - -std::map> &Graph::GetOutputsMap() { - MS_ASSERT(subgraphs.back() != nullptr); - return subgraphs.back()->GetOutputsMap(); -} - -void Graph::FreeAllTensors() { - for (auto iter : subgraphs) { - iter->FreeAllTensors(); - } -} - -std::vector *Graph::Subgraphs() { return &subgraphs; } - -SubGraph::SubGraph() = default; - -SubGraph::~SubGraph() { - for (auto iter = nodes.begin(); iter != nodes.end();) { - if (iter->second != nullptr) { - delete iter->second; - } - iter = nodes.erase(iter); - } - nodes.clear(); - - for (auto &allTensor : allTensors) { - if (allTensor != nullptr) { - delete allTensor; - } - } - allTensors.clear(); -} - -SubGraph *SubGraph::CreateSubGraph(const SubGraphDef &subGraphDef, const Context &ctx) { - std::unique_ptr subGraph(new (std::nothrow) SubGraph()); - if (subGraph == nullptr) { - MS_LOGE("subGraph malloc fail"); - return nullptr; - } - - auto ret = subGraph->Build(subGraphDef, ctx); - if (ret != RET_OK) { - MS_LOGE("subGraph Build fail"); - return nullptr; - } - - return subGraph.release(); -} - -int SubGraph::Build(const SubGraphDef &subGraphDef, const Context &ctx) { - int ret; - MS_ASSERT(subGraphDef.inputIndex() != nullptr); - ret = ConverterIndex(*(subGraphDef.inputIndex()), &inputIndices); - if (ret != RET_OK) { - MS_LOGE("ConverterIndex failed: %d", ret); - return ret; - } - MS_LOGD("converter inputIndex succ"); - - MS_ASSERT(subGraphDef.outputIndex() != nullptr); - ret = ConverterIndex(*(subGraphDef.outputIndex()), &outputIndices); - if (ret != RET_OK) { - MS_LOGE("ConverterIndex failed: %d", ret); - return ret; - } - MS_LOGD("converter outputIndex succ"); - MS_ASSERT(subGraphDef.allTensors() != nullptr); - ret = ConverterAllTensor(*(subGraphDef.allTensors())); - if (ret != RET_OK) { - MS_LOGE("ConverterAllTensor failed: %d", ret); - return ret; - } - MS_LOGD("converter AllTensor succ"); - MS_ASSERT(subGraphDef.nodes() != nullptr); - ret = ConverterNodes(*(subGraphDef.nodes()), ctx); - if (ret != RET_OK) { - MS_LOGE("ConverterNodes failed: %d", ret); - return ret; - } - MS_LOGD("converter nodes succ"); - - ret = ConverterEdges(subGraphDef); - if (ret != RET_OK) { - MS_LOGE("ConverterEdges failed: %d", ret); - return ret; - } - MS_LOGD("converter edges succ"); - - ret = InitOutputsMap(); - if (ret != RET_OK) { - MS_LOGE("InitOutputsMap failed: %d", ret); - return ret; - } - MS_LOGD("init outputs map succ"); - - MS_LOGD("build graph succ"); - return RET_OK; -} - -int SubGraph::ConverterIndex(const flatbuffers::Vector &srcIndex, std::vector *dstIndex) { - if (dstIndex == nullptr) { - MS_LOGE("input dstIndex is nullptr"); - return RET_PARAM_INVALID; - } - dstIndex->resize(srcIndex.size()); - std::copy(srcIndex.begin(), srcIndex.end(), dstIndex->begin()); - return RET_OK; -} - -int SubGraph::ConverterAllTensor(const flatbuffers::Vector> &srcTensors) { - uint32_t tensorsSize = srcTensors.size(); - - allTensors.clear(); - allTensors.reserve(tensorsSize); - for (uint32_t i = 0; i < tensorsSize; i++) { - auto tensorDef = srcTensors.GetAs(i); - if (tensorDef == nullptr) { - MS_LOGE("%ud th tensordef is null", i); - return RET_ERROR; - } - auto tensor = Tensor::CopyFromTensorDef(*tensorDef); - if (tensor == nullptr) { - return RET_ERROR; - } - allTensors.push_back(tensor); - } - - return RET_OK; -} - -int SubGraph::ConverterNodes(const flatbuffers::Vector> &nodeDefs, const Context &ctx) { - uint32_t opCount = nodeDefs.size(); - // for dfx - if (opCount > G_MAX_OP_COUNT) { - MS_LOGE("opCount(%u) bigger than maxOpCount(%u)", opCount, G_MAX_OP_COUNT); - return RET_ERROR; - } - - nodes.clear(); - - for (uint32_t i = 0; i < opCount; i++) { - auto nodeDef = nodeDefs.GetAs(i); - MS_ASSERT(nodeDef != nullptr); - auto node = std::unique_ptr(new (std::nothrow) Node(nodeDef)); - if (node == nullptr) { - MS_LOGE("new node failed"); - return RET_NULL_PTR; - } - - node->SetTensors(*nodeDef, allTensors); - - auto ret = node->InitOp(*(nodeDef->opDef()), ctx); - if (ret != RET_OK) { - MS_LOGE("node (%s) InitOP failed. ret:%d", node->ID().c_str(), ret); - return ret; - } - - auto nodeId = node->ID(); - nodes[nodeId] = node.release(); - MS_LOGD("add node succ, id:%s", nodeId.c_str()); - } - - return RET_OK; -} - -int SubGraph::ConverterEdges(const SubGraphDef &subGraphDef) { - auto opGraph = OpGraph::Build(subGraphDef); - if (opGraph == nullptr) { - MS_LOGE("opGraph Build fail"); - return RET_ERROR; - } - - for (auto nodeIter : nodes) { - auto node = opGraph->GetNode(nodeIter.first); - if (node == nullptr) { - MS_LOGI("node %s not found", nodeIter.first.c_str()); - continue; - } - for (const auto &edge : node->GetAllInEdge()) { - MS_ASSERT(nodeIter.second != nullptr); - nodeIter.second->AddInEdge(GetNode(edge)); - } - for (const auto &edge : node->GetAllOutEdge()) { - MS_ASSERT(nodeIter.second != nullptr); - nodeIter.second->AddOutEdge(GetNode(edge)); - } - } - delete opGraph; - return RET_OK; -} - -int SubGraph::InitOutputsMap() { - if (nodes.empty()) { - MS_LOGE("nodes are empty"); - return RET_ERROR; - } - for (auto node : nodes) { - NODE_ID realNodeName = node.second->ID(); - MS_ASSERT(node.second != nullptr); - if (node.second->GetAllOutEdges().empty()) { - auto nodeType = node.second->Type(); - if (nodeType == "Nhwc2Nchw" || nodeType == "Nchw2Nhwc") { - auto dependNode = *(this->GetDepends().at(this->GetNode(realNodeName)).begin()); - realNodeName = dependNode->ID(); - } - this->outputsMap.emplace( - std::pair>(realNodeName, node.second->GetOutputTensors())); - } - } - return RET_OK; -} - -std::unordered_map> SubGraph::GetDepends() { - std::unordered_map> depends; - for (auto nodeIter : nodes) { - MS_ASSERT(nodeIter.second != nullptr); - depends[nodeIter.second] = nodeIter.second->GetAllInEdges(); - } - return depends; -} - -Node *SubGraph::GetNode(const NODE_ID &id) { - auto node = nodes.find(id); - if (node == nodes.end()) { - return nullptr; - } - return node->second; -} - -std::vector SubGraph::GetInputs() { - std::vector inputTensor; - inputTensor.resize(inputIndices.size()); - std::transform(inputIndices.begin(), inputIndices.end(), inputTensor.begin(), - [this](int i) { return this->allTensors[i]; }); - - return inputTensor; -} - -std::vector SubGraph::GetOutputs() { - std::vector outputTensor; - outputTensor.resize(outputIndices.size()); - std::transform(outputIndices.begin(), outputIndices.end(), outputTensor.begin(), - [this](int i) { return this->allTensors[i]; }); - - return outputTensor; -} - -std::map> &SubGraph::GetOutputsMap() { return outputsMap; } - -void SubGraph::FreeAllTensors() { - for (auto &allTensor : allTensors) { - if (allTensor != nullptr) { - auto refcount = allTensor->RefCount(); - if (refcount != MSConst_WEIGHT_REFCOUNT) { - allTensor->DefRef(refcount); - allTensor->FreeData(); - } - } - } -} - -const std::vector *SubGraph::GetInputIndices() const { return &inputIndices; } - -const std::vector *SubGraph::GetOutputIndices() const { return &outputIndices; } - -bool SubGraph::IsInputIndex(uint32_t i) { - auto iter = std::find(inputIndices.begin(), inputIndices.end(), i); - return !(iter == inputIndices.end()); -} - -bool SubGraph::IsOutputIndex(uint32_t i) { - auto iter = std::find(outputIndices.begin(), outputIndices.end(), i); - return !(iter == outputIndices.end()); -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/graph.h b/predict/src/graph.h deleted file mode 100755 index f02c46f94e..0000000000 --- a/predict/src/graph.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_GRAPH_H_ -#define PREDICT_SRC_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "common/graph_util.h" -#include "include/tensor.h" -#include "src/node.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -class SubGraph { - public: - SubGraph(); - ~SubGraph(); - static SubGraph *CreateSubGraph(const SubGraphDef &subGraphDef, const Context &ctx); - int Build(const SubGraphDef &subGraphDef, const Context &ctx); - bool IsInputIndex(uint32_t i); - bool IsOutputIndex(uint32_t i); - - const std::vector *GetInputIndices() const; - const std::vector *GetOutputIndices() const; - - std::vector GetInputs(); - std::vector GetOutputs(); - std::map> &GetOutputsMap(); - void FreeAllTensors(); - - Node *GetNode(const NODE_ID &id); - - std::unordered_map> GetDepends(); - - private: - int ConverterIndex(const flatbuffers::Vector &srcIndex, std::vector *dstIndex); - - int ConverterAllTensor(const flatbuffers::Vector> &srcTensors); - - int ConverterNodes(const flatbuffers::Vector> &opDefs, const Context &ctx); - - int ConverterEdges(const SubGraphDef &subGraphDef); - - int InitOutputsMap(); - - protected: - std::unordered_map nodes; - std::vector inputIndices; - std::vector outputIndices; - std::vector allTensors; // weight + input + output - std::map> outputsMap; -}; - -class MSPREDICT_API Graph { - public: - Graph(); - ~Graph(); - static Graph *CreateFromBuf(const char *buf, size_t size, const Context &ctx); - - std::vector GetInputs(); - std::vector GetOutputs(); - - std::map> &GetOutputsMap(); - - void FreeAllTensors(); - - int Build(const GraphDef &def, const Context &ctx); - std::vector *Subgraphs(); - - protected: - friend class GraphExecution; - - std::vector subgraphs; - std::unordered_map> depends; // records the dependencies - std::deque readyQue; // the nodes which can execute without any dependencies -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_GRAPH_H_ diff --git a/predict/src/graph_execution.cc b/predict/src/graph_execution.cc deleted file mode 100644 index 4a71d4359a..0000000000 --- a/predict/src/graph_execution.cc +++ /dev/null @@ -1,293 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/graph_execution.h" -#include -#include -#include - -namespace mindspore { -namespace predict { -GraphExecution::GraphExecution(const Context &ctx) : graph(nullptr), _ctx(ctx) {} -GraphExecution::GraphExecution(const Context &ctx, Graph *staticGraph) : _ctx(ctx) { - graph = staticGraph; - if (graph != nullptr) { - depends = graph->depends; - readyQue = graph->readyQue; - outputTensors = graph->GetOutputs(); - inputTensors = graph->GetInputs(); - } -} - -GraphExecution::~GraphExecution() = default; - -int GraphExecution::TransInputDataToNc4hw4(const Tensor &src, Tensor *dst) { - MS_ASSERT(dst != nullptr); - if (dst->GetData() == nullptr) { - auto ret = dst->MallocData(nullptr, MSConst_WEIGHT_REFCOUNT); - if (ret != RET_OK) { - MS_LOGE("Malloc inputTensors failed: %d", ret); - return ret; - } - } - auto ret = NchwToNc4hw4(&src, dst); - if (ret != RET_OK) { - MS_LOGE("NchwToNc4hw4 failed"); - return ret; - } - return RET_OK; -} - -int GraphExecution::SetInputTensors(const std::vector &inputs) { - size_t num = inputs.size(); - if (num != inputTensors.size()) { - MS_LOGE("input num %zu != model input num %zu", num, inputTensors.size()); - return RET_INPUT_TENSOR_ERROR; - } - - for (size_t i = 0; i < num; i++) { - MS_ASSERT(inputs[i] != nullptr); - // The input Tensor desc must be equivalent with the model tensor - if (inputs[i]->GetData() == nullptr) { - MS_LOGE("input tensor data is null!"); - return RET_INPUT_TENSOR_ERROR; - } - if (inputTensors[i] == nullptr) { - MS_LOGE("inputTensors[%zu] is nullptr", i); - return RET_ERROR; - } - - if (!inputs[i]->CompareShape(*inputTensors[i])) { - MS_LOGE("tensor shape in graph and executor are different!"); - return RET_INPUT_TENSOR_ERROR; - } - - if (inputs[i]->GetDataType() != inputTensors[i]->GetDataType()) { - MS_LOGE("tensor datatype in graph and executor are different!"); - return RET_INPUT_TENSOR_ERROR; - } - - if (inputs[i]->GetFormat() != Format_NCHW) { - MS_LOGE("input format not support. only nchw is supported now"); - return RET_INPUT_TENSOR_ERROR; - } - - if (inputs[i]->GetFormat() == inputTensors[i]->GetFormat()) { - auto data = inputs[i]->GetData(); - if (data == nullptr) { - MS_LOGE("data of input tensor is null!"); - return RET_INPUT_TENSOR_ERROR; - } - inputTensors[i]->SetData(data); - } else if (inputTensors[i]->GetFormat() == Format_NC4HW4) { - auto ret = TransInputDataToNc4hw4(*inputs[i], inputTensors[i]); - if (ret != RET_OK) { - MS_LOGE("TransInputDataToNc4hw4 failed"); - return ret; - } - } else { - MS_LOGE("graphDef inputTensors format is invalid: %d", inputTensors[i]->GetFormat()); - return RET_ERROR; - } - } - return RET_OK; -} - -int GraphExecution::MallocOutput() { - for (auto tensor : outputTensors) { - auto ret = tensor->MallocData(); - if (ret != RET_OK) { - MS_LOGE("malloc output data failed"); - return RET_ERROR; - } - } - return RET_OK; -} - -void GraphExecution::FreeTensors(std::vector *tensors) { - for (auto &tensor : (*tensors)) { - delete tensor; - } - tensors->clear(); -} - -void GraphExecution::FreeOutputMap(std::map> *map) { - MS_ASSERT(map != nullptr); - for (auto &m : *map) { - FreeTensors(&(m.second)); - } - map->clear(); -} - -int GraphExecution::CopyOutputTensors(const std::vector &refOutputs, std::vector *outputs) { - for (auto tensor : refOutputs) { - if (tensor == nullptr) { - MS_LOGE("tensor in refOutputs is nullptr"); - return RET_INPUT_TENSOR_ERROR; - } - std::unique_ptr t(new Tensor(*tensor)); - if (t == nullptr) { - MS_LOGE("new Tensor failed."); - if (outputs != nullptr) { - FreeTensors(outputs); - } - return RET_ERROR; - } - - if (tensor->GetFormat() == Format_NC4HW4) { - t->SetFormat(Format_NCHW); - auto ret = t->MallocData(); - if (ret != RET_OK) { - MS_LOGE("malloc data failed.") - FreeTensors(outputs); - return ret; - } - - ret = Nc4hw4ToNchw(tensor, t.get()); - if (ret != RET_OK) { - MS_LOGE("Nc4hw4ToNchw failed"); - return ret; - } - tensor->FreeData(); - } else { - t->SetData(tensor->GetData()); - tensor->SetData(nullptr); - } - outputs->push_back(t.release()); - } - return RET_OK; -} - -std::map> GraphExecution::GetAllOutput() { - std::map> outputs{}; - for (auto &outputNode : graph->GetOutputsMap()) { - std::vector outputNodeTensors{}; - auto ret = this->CopyOutputTensors(outputNode.second, &outputNodeTensors); - if (ret != RET_OK) { - MS_LOGE("copy output failed."); - FreeOutputMap(&outputs); - return outputs; - } - outputs.emplace(std::pair>(outputNode.first, outputNodeTensors)); - } - return outputs; -} - -std::vector GraphExecution::GetOutput(const NODE_ID &nodeName) { - std::vector outputNodeTensors{}; - auto iter = graph->GetOutputsMap().find(nodeName); - if (iter == graph->GetOutputsMap().end()) { - MS_LOGE("node name is not in output."); - return outputNodeTensors; - } - auto ret = this->CopyOutputTensors(iter->second, &outputNodeTensors); - if (ret != RET_OK) { - MS_LOGE("copy output failed."); - } - return outputNodeTensors; -} - -std::vector GraphExecution::GetInput() { - std::vector inputs{}; - for (auto refInput : graph->GetInputs()) { - if (refInput == nullptr) { - MS_LOGE("tensor from graph->GetInputs() is nullptr"); - return inputs; - } - std::unique_ptr t(new Tensor(refInput->GetDataType(), refInput->GetDims(), Format_NCHW, nullptr)); - if (t == nullptr) { - MS_LOGE("new Tensor failed.") - FreeTensors(&inputs); - return inputs; - } - inputs.push_back(t.release()); - } - return inputs; -} - -void GraphExecution::ResetInputData() { - for (auto tensor : inputTensors) { - if (tensor == nullptr) { - MS_LOGW("tensor in inputTensors is nullptr"); - continue; - } - if (tensor->GetFormat() == Format_NC4HW4) { - if (tensor->GetData() != nullptr) { - free(tensor->GetData()); - tensor->SetData(nullptr); - } - continue; - } - tensor->SetData(nullptr); - } -} - -void GraphExecution::FreeAllTensors() { graph->FreeAllTensors(); } - -int GraphExecution::Run(const std::vector &inputs) { - if (inputs.empty()) { - MS_LOGE("input is empty"); - return RET_ERROR; - } - - int ret; - - if (readyQue.empty()) { - MS_LOGE("readyQue is empty"); - return RET_ERROR; - } - - ret = SetInputTensors(inputs); - if (ret != RET_OK) { - MS_LOGE("SetInputTensors failed: %d", ret); - ResetInputData(); - return ret; - } - ret = MallocOutput(); - if (ret != RET_OK) { - MS_LOGE("MallocOutput failed: %d", ret); - ResetInputData(); - return ret; - } - - while (!readyQue.empty()) { - auto *node = readyQue.front(); - readyQue.pop_front(); - - ret = node->Run(_ctx); - if (ret != RET_OK) { - MS_LOGE("node (%s) failed to run op (%s). error code:%d", node->ID().c_str(), node->Type().c_str(), ret); - ResetInputData(); - FreeAllTensors(); - return ret; - } - - for (auto outNode : node->GetAllOutEdges()) { - auto nodeDepend = depends.find(outNode); - nodeDepend->second.erase(node); - if (nodeDepend->second.empty()) { - depends.erase(nodeDepend); - readyQue.push_back(outNode); - } - } - } - - ResetInputData(); - - return RET_OK; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/graph_execution.h b/predict/src/graph_execution.h deleted file mode 100644 index 022be21865..0000000000 --- a/predict/src/graph_execution.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_GRAPH_EXECUTION_H_ -#define PREDICT_SRC_GRAPH_EXECUTION_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "common/mslog.h" -#include "src/graph.h" -#include "include/errorcode.h" -#include "schema/inner/ms_generated.h" -#include "src/operator/cpu/include/op_func_comm.h" -#include "src/node.h" - -namespace mindspore { -namespace predict { -class GraphExecution { - public: - explicit GraphExecution(const Context &ctx); - GraphExecution(const Context &ctx, Graph *staticGraph); - virtual ~GraphExecution(); - - virtual std::vector GetInput(); - virtual int SetInputTensors(const std::vector &inputs); - - virtual int Run(const std::vector &inputs); - - virtual std::map> GetAllOutput(); - virtual std::vector GetOutput(const NODE_ID &nodeName); - - private: - void ResetInputData(); - int MallocOutput(); - void FreeTensors(std::vector *tensors); - int TransInputDataToNc4hw4(const Tensor &src, Tensor *dst); - int CopyOutputTensors(const std::vector &refOutputs, std::vector *outputs); - void FreeOutputMap(std::map> *map); - void FreeAllTensors(); - - protected: - Graph *graph; - const Context &_ctx; - std::vector inputTensors; - std::vector outputTensors; - std::unordered_map> depends; // records the dependencies - std::deque readyQue; // the nodes which can execute without any dependencies -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_GRAPH_EXECUTION_H_ diff --git a/predict/src/node.cc b/predict/src/node.cc deleted file mode 100644 index 7128dde209..0000000000 --- a/predict/src/node.cc +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/node.h" -#include -#include -#include -#include -#include -#include "schema/inner/ms_generated.h" -#include "common/mslog.h" -#include "common/op_utils.h" -#include "include/errorcode.h" -#include "src/op_factory.h" - -namespace mindspore { -namespace predict { -Node::Node(const NodeDef *nodeDef) - : id(std::string(nodeDef->opDef()->name()->c_str())), type(GetOpTypeName(*nodeDef)) {} - -Node::~Node() { - if (op != nullptr) { - delete op; - } -} - -NODE_ID Node::ID() { return id; } - -std::string Node::Type() { return type; } - -void Node::SetTensors(const NodeDef &nodeDef, const std::vector &allTensors) { - if (nodeDef.opDef() == nullptr) { - MS_LOGE("nodeDef is null"); - return; - } - - auto inputIndex = nodeDef.opDef()->inputIndex(); - MS_ASSERT(inputIndex != nullptr); - inputs.resize(inputIndex->size()); - std::transform(inputIndex->begin(), inputIndex->end(), inputs.begin(), [allTensors](int i) { return allTensors[i]; }); - - auto outputIndex = nodeDef.opDef()->outputIndex(); - MS_ASSERT(outputIndex != nullptr); - outputs.resize(outputIndex->size()); - std::transform(outputIndex->begin(), outputIndex->end(), outputs.begin(), - [allTensors](int i) { return allTensors[i]; }); -} - -void Node::SetDepends(const std::unordered_set &deps) { depends = deps; } - -std::unordered_set Node::GetDepends() { return depends; } - -void Node::AddInEdge(Node *node) { - if (node == nullptr) { - MS_LOGE("node is null"); - return; - } - inEdges.insert(node); -} - -void Node::AddOutEdge(Node *node) { - if (node == nullptr) { - MS_LOGE("node is null"); - return; - } - outEdges.insert(node); -} - -std::unordered_set &Node::GetAllInEdges() { return inEdges; } - -std::unordered_set &Node::GetAllOutEdges() { return outEdges; } - -std::vector &Node::GetOutputTensors() { return outputs; } -std::vector &Node::GetInputTensors() { return inputs; } - -int Node::InitOp(const OpDef &opDef, const Context &ctx) { - OpDesc dst; - dst.type = GetOpType(opDef); - dst.arch = X86_FP32; - MS_ASSERT(OpFactory::GetInstance() != nullptr); - op = OpFactory::GetInstance()->GetOp(inputs, outputs, opDef, ctx, dst); - if (op == nullptr) { - MS_LOGE("Can't find opName: %s, type: %s ", id.c_str(), type.c_str()); - return RET_ERROR; - } - return RET_OK; -} - -int Node::Run(const Context &ctx) { - MS_LOGD("%s run start", id.c_str()); - auto ret = MallocOutput(ctx); - if (ret != RET_OK) { - MS_LOGE("MallocOutput failed: %d", ret); - return ret; - } - if (op == nullptr) { - MS_LOGE("op is nullptr."); - return RET_ERROR; - } - ret = op->Execute(inputs, outputs); - if (ret != RET_OK) { - return ret; - } - FreeInput(); - return RET_OK; -} - -int Node::MallocOutput(const Context &ctx) { - size_t refCount = outEdges.size(); - for (auto tensor : outputs) { - if (tensor == nullptr) { - MS_LOGE("tensor in outputs is nullptr"); - return RET_ERROR; - } - auto ret = tensor->MallocData(ctx.allocator, refCount); - if (ret != RET_OK) { - return ret; - } - } - return RET_OK; -} - -void Node::FreeInput() { - for (auto tensor : inputs) { - if (tensor == nullptr) { - MS_LOGW("tensor in inputs is nullptr"); - return; - } - if (tensor->RefCount() != MSConst_WEIGHT_REFCOUNT) { - tensor->FreeData(); - } - } -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/node.h b/predict/src/node.h deleted file mode 100644 index eebb1b4321..0000000000 --- a/predict/src/node.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_NODE_H_ -#define PREDICT_SRC_NODE_H_ - -#include -#include -#include -#include "include/session.h" -#include "src/op.h" - -namespace mindspore { -namespace predict { -using NODE_ID = std::string; - -class Node { - public: - Node() = default; - explicit Node(const NodeDef *nodeDef); - virtual ~Node(); - NODE_ID ID(); - std::string Type(); - void SetTensors(const NodeDef &nodeDef, const std::vector &allTensors); - void SetDepends(const std::unordered_set &deps); - std::unordered_set GetDepends(); - - void AddInEdge(Node *node); - void AddOutEdge(Node *node); - std::unordered_set &GetAllOutEdges(); - std::unordered_set &GetAllInEdges(); - - std::vector &GetOutputTensors(); - std::vector &GetInputTensors(); - - int InitOp(const OpDef &opDef, const Context &ctx); - int Run(const Context &ctx); - int MallocOutput(const Context &ctx); - void FreeInput(); - - protected: - friend class GraphExecution; - NODE_ID id; - std::string type; - OpBase *op{}; - std::vector inputs; - std::vector outputs; - std::unordered_set depends; - std::unordered_set inEdges; - std::unordered_set outEdges; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_NODE_H_ diff --git a/predict/src/op.cc b/predict/src/op.cc deleted file mode 100644 index ca99b7fdff..0000000000 --- a/predict/src/op.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/op.h" - -namespace mindspore { -namespace predict { -OpBase::OpBase() : desc(nullptr) {} - -OpBase::~OpBase() = default; -} // namespace predict -} // namespace mindspore diff --git a/predict/src/op.h b/predict/src/op.h deleted file mode 100644 index a07ce21952..0000000000 --- a/predict/src/op.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_OP_H_ -#define PREDICT_SRC_OP_H_ - -#include -#include -#include "include/context.h" -#include "include/tensor.h" -#include "include/errorcode.h" -#include "schema/inner/ms_generated.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -enum OP_ARCH { X86_FP32, X86_INT8, ARM_FP32, ARM_FP16, ARM_INT8, GPU }; - -struct MSPREDICT_API OpDesc { - OP_ARCH arch; - OpT type; - - bool operator<(const OpDesc &dst) const { return (arch < dst.arch) || (type < dst.type); } -}; - -class MSPREDICT_API OpBase { - public: - OpBase(); - virtual ~OpBase(); - - virtual int Execute(const std::vector &inputs, const std::vector &outputs) = 0; - virtual int Init(const std::vector &inputs, const std::vector &outputs) = 0; - - protected: - const OpDesc *desc; - std::string name; -}; - -typedef OpBase *(*OpCreator)(const std::vector &inputs, const std::vector &outputs, - const OpDef &opDef, const Context &ctx, const OpDesc &desc); -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_OP_H_ diff --git a/predict/src/op_common.h b/predict/src/op_common.h deleted file mode 100644 index c5eb69bd57..0000000000 --- a/predict/src/op_common.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_OP_COMMON_H_ -#define PREDICT_SRC_OP_COMMON_H_ -#include - -namespace mindspore { -namespace predict { -static inline size_t AlignSize(size_t size, size_t align) { return (size + align - 1) & -align; } - -template -inline void Nchw2Nhwc(const Tsrc *in, Tdst *out, size_t h, size_t w, size_t c) { - MS_ASSERT(in != nullptr && out != nullptr); - const size_t sz = w * h; - - for (size_t cc = 0; cc < c; ++cc) { - auto pi = in + sz * cc; - - for (size_t el = 0; el < sz; ++el) { - out[cc + el * c] = (Tdst)pi[el]; - } - } -} - -template -inline void Nhwc2Nchw(const Tsrc *in, Tdst *out, size_t h, size_t w, size_t c) { - MS_ASSERT(in != nullptr && out != nullptr); - const size_t sz = w * h; - - for (auto cc = 0; cc < c; ++cc) { - auto po = out + sz * cc; - - for (size_t el = 0; el < sz; ++el) { - po[el] = (Tdst)in[cc + el * c]; - } - } -} - -template -inline void InverseQuantization(const Tsrc *srcdata, Tdst *dstdata, size_t datanum, float *parms) { - MS_ASSERT(srcdata != nullptr && dstdata != nullptr); - float scale = parms[2]; - float zeroPoint = parms[3]; - for (size_t i = 0; i < datanum; ++i) { - dstdata = (scale == 0) ? (0) : (Tdst)((srcdata[i] - zeroPoint) * scale); - } -} - -template -inline void Astype(const Tsrc *srcdata, Tdst *dstdata, size_t datanum) { - MS_ASSERT(srcdata != nullptr && dstdata != nullptr); - for (size_t i = 0; i < datanum; ++i) { - dstdata[i] = (Tdst)srcdata[i]; - } -} -#define MSMIN(x, y) ((x) < (y) ? (x) : (y)) -#define MSMAX(x, y) ((x) > (y) ? (x) : (y)) - -#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -#define DOWN_DIV(x, y) (((x) - (y) + (1)) / (y)) -#define ROUND_UP(x, y) (((x) + (y) - (1)) / (y) * (y)) -#define ALIGN_UP4(x) ROUND_UP((x), 4) -#define ALIGN_UP8(x) ROUND_UP((x), 8) - -#define MAX_MALLOC_SIZE 100 * 1024 * 1024 -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_OP_COMMON_H_ diff --git a/predict/src/op_factory.cc b/predict/src/op_factory.cc deleted file mode 100644 index 2506db68d8..0000000000 --- a/predict/src/op_factory.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/op_factory.h" - -namespace mindspore { -namespace predict { -OpFactory::OpFactory() { InitKernelManager(0, ""); } - -OpFactory::~OpFactory() = default; - -OpFactory *OpFactory::GetInstance() { - static OpFactory instance; - return &instance; -} - -OpBase *OpFactory::GetOp(const std::vector &inputs, const std::vector &outputs, const OpDef &opDef, - const Context &ctx, const OpDesc &desc) { - MS_ASSERT(GetRegistryInstance() != nullptr); - auto *reg = GetRegistryInstance()->GetInstance(MODULE_REG_NAME_OP_REGISTRY); - if (reg != nullptr) { - auto creator = reg->GetOpCreator(desc); - if (creator) { - return creator(inputs, outputs, opDef, ctx, desc); - } - } - MS_ASSERT(OpRegistry::GetInstance() != nullptr); - auto creator = OpRegistry::GetInstance()->GetOpCreator(desc); - if (creator) { - return creator(inputs, outputs, opDef, ctx, desc); - } - return nullptr; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/op_factory.h b/predict/src/op_factory.h deleted file mode 100644 index 583a605d8c..0000000000 --- a/predict/src/op_factory.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_OP_FACTORY_H_ -#define PREDICT_SRC_OP_FACTORY_H_ - -#include -#include "lite/api/km_api.h" -#include "src/op.h" -#include "src/op_registry.h" - -namespace mindspore { -namespace predict { -class OpFactory { - public: - OpFactory(); - virtual ~OpFactory(); - - static OpFactory *GetInstance(); - OpBase *GetOp(const std::vector &inputs, const std::vector &outputs, const OpDef &opDef, - const Context &ctx, const OpDesc &desc); -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_OP_FACTORY_H_ diff --git a/predict/src/op_registry.cc b/predict/src/op_registry.cc deleted file mode 100644 index 14cd810d54..0000000000 --- a/predict/src/op_registry.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/op_registry.h" - -namespace mindspore { -namespace predict { -OpRegistry::OpRegistry() = default; - -OpRegistry::~OpRegistry() = default; - -OpRegistry *OpRegistry::GetInstance() { - static OpRegistry instance; - return &instance; -} - -OpCreator OpRegistry::GetOpCreator(const OpDesc &desc) { - auto it = creators.find(desc); - if (it != creators.end()) { - return it->second; - } - return nullptr; -} - -void OpRegistry::RegOp(const OpDesc desc, OpCreator creator) { creators[desc] = creator; } - -void OpRegistry::RegOp(const OP_ARCH arch, const OpT type, OpCreator creator) { - OpDesc desc = {arch, type}; - creators[desc] = creator; -} - -bool OpRegistry::Merge(const std::unordered_map &newCreators) { return false; } - -const std::map &OpRegistry::GetOpCreators() { return creators; } -} // namespace predict -} // namespace mindspore diff --git a/predict/src/op_registry.h b/predict/src/op_registry.h deleted file mode 100644 index bb1d957fec..0000000000 --- a/predict/src/op_registry.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_OP_REGISTRY_H_ -#define PREDICT_SRC_OP_REGISTRY_H_ - -#include -#include -#include -#include "common/mslog.h" -#include "common/module_registry.h" -#include "src/op.h" - -#define MSPREDICT_API __attribute__((visibility("default"))) - -namespace mindspore { -namespace predict { -class MSPREDICT_API OpRegistry { - public: - OpRegistry(); - virtual ~OpRegistry(); - - static OpRegistry *GetInstance(); - virtual OpCreator GetOpCreator(const OpDesc &desc); - - const std::map &GetOpCreators(); - - void RegOp(OpDesc desc, OpCreator creator); - void RegOp(OP_ARCH arch, OpT type, OpCreator creator); - static bool Merge(const std::unordered_map &newCreators); - - protected: - std::map creators; -}; - -template <> -class Module : public ModuleBase { - public: - virtual OpRegistry *GetInstance() = 0; -}; - -const char MODULE_REG_NAME_OP_REGISTRY[] = "op_registry"; - -class OpRegistrar { - public: - OpRegistrar(const OpDesc &desc, OpCreator creator) { OpRegistry::GetInstance()->RegOp(desc, creator); } - - OpRegistrar(const OP_ARCH arch, const OpT type, OpCreator creator) { - MS_ASSERT(OpRegistry::GetInstance() != nullptr); - OpRegistry::GetInstance()->RegOp(arch, type, creator); - } -}; - -#define REG_OP(arch, type, opCreater) static OpRegistrar g_##arch##type##OpReg(arch, type, opCreater); -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_OP_REGISTRY_H_ diff --git a/predict/src/operator/cpu/common/op_func_comm.cc b/predict/src/operator/cpu/common/op_func_comm.cc deleted file mode 100755 index f05e224ec3..0000000000 --- a/predict/src/operator/cpu/common/op_func_comm.cc +++ /dev/null @@ -1,422 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/operator/cpu/include/op_func_comm.h" -#include "include/errorcode.h" -#include "include/tensor.h" -#include "common/mslog.h" -#include "securec/include/securec.h" - -namespace mindspore { -namespace predict { -#ifndef MS_USE_NEON -#ifndef MS_USE_SSE - -void MSAddBias(float *srcPtr, const float *bias, size_t unitSize, size_t count) { - if (srcPtr == nullptr || bias == nullptr) { - MS_LOGW("srcPtr or bias is nullptr"); - return; - } - for (size_t stepU = 0; stepU < count; stepU++) { - float *tmpPtr = srcPtr + unitSize * CAL_STEP * stepU; - const float *biasPtr = bias + CAL_STEP * stepU; - for (size_t step = 0; step < unitSize; step++) { - float *dstPtr = tmpPtr + CAL_STEP * step; - for (int i = 0; i < CAL_STEP; i++) { - dstPtr[i] += biasPtr[i]; - } - } - } -} - -void MSAddBiasRelu(float *srcPtr, const float *bias, size_t unitSize, size_t count) { - if (srcPtr == nullptr || bias == nullptr) { - MS_LOGW("srcPtr or bias is nullptr"); - return; - } - for (size_t stepU = 0; stepU < count; stepU++) { - float *tmpPtr = srcPtr + unitSize * CAL_STEP * stepU; - const float *biasPtr = bias + CAL_STEP * stepU; - for (size_t step = 0; step < unitSize; step++) { - float *dstPtr = tmpPtr + CAL_STEP * step; - for (int i = 0; i < CAL_STEP; i++) { - dstPtr[i] += biasPtr[i]; - dstPtr[i] = (dstPtr[i] < 0) ? 0 : dstPtr[i]; - } - } - } -} - -void MSAddBiasRelu6(float *srcPtr, const float *bias, size_t unitSize, size_t count) { - if (srcPtr == nullptr || bias == nullptr) { - MS_LOGW("srcPtr or bias is nullptr"); - return; - } - for (size_t stepU = 0; stepU < count; stepU++) { - float *tmpPtr = srcPtr + unitSize * CAL_STEP * stepU; - const float *biasPtr = bias + CAL_STEP * stepU; - for (size_t step = 0; step < unitSize; step++) { - float *dstPtr = tmpPtr + CAL_STEP * step; - for (int i = 0; i < CAL_STEP; i++) { - dstPtr[i] += biasPtr[i]; - dstPtr[i] = (dstPtr[i] < 0) ? 0 : dstPtr[i]; - dstPtr[i] = (dstPtr[i] > 6.0f) ? 6.0f : dstPtr[i]; - } - } - } -} - -void MSCopyC4WithStride(const float *srcPtr, float *dstPtr, size_t srcStride, size_t dstStride, size_t count) { - if (srcPtr == nullptr || dstPtr == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - for (size_t stepU = 0; stepU < count; stepU++) { - auto sPtr = srcPtr + stepU * srcStride; - auto dPtr = dstPtr + stepU * dstStride; - int tmpC = 0; - while (tmpC < CAL_STEP) { - dPtr[tmpC] = sPtr[tmpC]; - tmpC++; - } - } -} -#endif // MS_USE_SSE - -int MSPackC4(float *dstPtr, const float *srcPtr, size_t area, size_t depth) { - if (dstPtr == nullptr || srcPtr == nullptr) { - MS_LOGE("srcPtr or dstPtr is nullptr"); - return RET_ERROR; - } - int cur = 0; - size_t size = area * UP_DIV(depth, CAL_STEP) * CAL_STEP * sizeof(float); - auto ret = memset_s(dstPtr, size, 0, size); - if (ret != EOK) { - MS_LOGE("memset_s failed!"); - return RET_ERROR; - } - for (size_t step = 0; step < depth; step++) { - auto plane = step / CAL_STEP; - auto offset = step % CAL_STEP; - auto dstPlane = plane * area * CAL_STEP + dstPtr; - for (size_t i = 0; i < area; i++) { - dstPlane[CAL_STEP * i + offset] = srcPtr[cur++]; - } - } - return RET_OK; -} - -void MSUnpackC4(float *dstPtr, const float *srcPtr, size_t area, size_t depth) { - if (dstPtr == nullptr || srcPtr == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - int cur = 0; - for (size_t step = 0; step < depth; step++) { - auto plane = step / CAL_STEP; - auto offset = step % CAL_STEP; - auto srcPlane = plane * area * CAL_STEP + srcPtr; - for (size_t i = 0; i < area; i++) { - dstPtr[cur++] = srcPlane[CAL_STEP * i + offset]; - } - } -} - -void MSUInt8ToInt16WithOffsetC4Common(int16_t *dstPtr, const uint8_t *srcPtr, size_t zeroPoint, size_t sizeQuad, - size_t dstStride, size_t srcStride) { - if (dstPtr == nullptr || srcPtr == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - for (size_t step = 0; step < sizeQuad; step++) { - auto dstZ = dstPtr + (dstStride / sizeof(int16_t)) * step; - auto srcZ = srcPtr + (srcStride / sizeof(uint8_t)) * step; - for (int i = 0; i < CAL_STEP; i++) { - dstZ[i] = (int16_t)((int32_t)srcZ[i] - (int32_t)zeroPoint); - } - } -} - -void MSUInt8ToInt16WithOffsetC4Fast(int16_t *colAddr, const uint8_t *srcStart, size_t zeroPoint, size_t sizeQuad, - size_t depthQuad, size_t dstZStep, size_t srcZStep) { - if (colAddr == nullptr || srcStart == nullptr) { - MS_LOGW("colAddr or srcStart is nullptr"); - return; - } - for (size_t step = 0; step < depthQuad; step++) { - auto dstZ = colAddr + (dstZStep / sizeof(int16_t)) * step; - auto srcZ = srcStart + (srcZStep / sizeof(uint8_t)) * step; - MSUInt8ToInt16WithOffsetC4Common(dstZ, srcZ, zeroPoint, sizeQuad, CAL_STEP * sizeof(int16_t), - CAL_STEP * sizeof(uint8_t)); - } -} -#endif - -void MSPackC4Uint8(uint8_t *dstPtr, const uint8_t *srcPtr, size_t area, size_t depth) { - if (dstPtr == nullptr || srcPtr == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - int cur = 0; - size_t size = area * UP_DIV(depth, CAL_STEP) * CAL_STEP * sizeof(uint8_t); - auto ret = memset_s(dstPtr, size, 0, size); - if (ret != EOK) { - MS_LOGE("memset_s failed!"); - return; - } - for (size_t step = 0; step < depth; step++) { - auto plane = step / CAL_STEP; - auto offset = step % CAL_STEP; - auto dstPlane = plane * area * CAL_STEP + dstPtr; - for (size_t x = 0; x < area; ++x) { - dstPlane[CAL_STEP * x + offset] = srcPtr[cur++]; - } - } -} - -void MSUnpackC4Uint8(uint8_t *dstPtr, const uint8_t *srcPtr, size_t area, size_t depth) { - if (dstPtr == nullptr || srcPtr == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - int cur = 0; - for (size_t step = 0; step < depth; step++) { - auto srcPlane = (step / CAL_STEP) * area * CAL_STEP + srcPtr; - for (size_t i = 0; i < area; i++) { - dstPtr[cur++] = srcPlane[CAL_STEP * i + (step % CAL_STEP)]; - } - } -} - -#ifdef MS_USE_NEON -static void MSTensorConvertNCHWToNC4HW4Depth(float *dst, const float *src, size_t area, size_t depth) { - if (dstPtr == nullptr || srcPtr == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - if (1 == depth) { - auto zeroValue = vmovq_n_f32(0.0f); - int areaC4 = static_cast(area / CAL_STEP); - int remain = areaC4 * CAL_STEP; - for (int i = 0; i < areaC4; ++i) { - auto srcCur = src + CAL_STEP * i; - auto dstCur = dst + CAL_STEP * CAL_STEP * i; - auto srcValue = vld1q_f32(srcCur); - float32x4x4_t dstValue; - dstValue.val[0] = srcValue; - dstValue.val[1] = zeroValue; - dstValue.val[2] = zeroValue; - dstValue.val[3] = zeroValue; - vst4q_f32(dstCur, dstValue); - } - for (int i = remain; i < area; ++i) { - dst[CAL_STEP * i + 0] = src[i]; - dst[CAL_STEP * i + 1] = 0.0f; - dst[CAL_STEP * i + 2] = 0.0f; - dst[CAL_STEP * i + 3] = 0.0f; - } - } else if (3 == depth) { - auto zeroValue = vmovq_n_f32(0.0f); - int areaC4 = static_cast(area / CAL_STEP); - int remain = areaC4 * CAL_STEP; - for (int i = 0; i < areaC4; ++i) { - auto srcCur = src + 12 * i; - auto dstCur = dst + 16 * i; - auto srcValue = vld3q_f32(srcCur); - float32x4x4_t dstValue; - dstValue.val[0] = srcValue.val[0]; - dstValue.val[1] = srcValue.val[1]; - dstValue.val[2] = srcValue.val[2]; - dstValue.val[3] = zeroValue; - vst4q_f32(dstCur, dstValue); - } - for (int i = remain; i < area; ++i) { - dst[CAL_STEP * i + 0] = src[3 * i + 0]; - dst[CAL_STEP * i + 1] = src[3 * i + 1]; - dst[CAL_STEP * i + 2] = src[3 * i + 2]; - dst[CAL_STEP * i + 3] = 0.0f; - } - } -} -#endif - -void MSTensorConvertNHWCToNC4HW4(float *dst, const float *src, size_t area, size_t depth) { - if (dst == nullptr || src == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } -#ifdef MS_USE_NEON - MSTensorConvertNCHWToNC4HW4Depth(dst, src, area, depth); - return; -#endif - int c = static_cast(depth); - int cDiv4 = c / CAL_STEP; - int cMod4 = c % CAL_STEP; - int cAlign = cDiv4 * CAL_STEP; - for (int hi = 0; hi < area; ++hi) { - auto srcHeight = src + hi * c; - auto dstHeight = dst + hi * CAL_STEP; - for (int ci = 0; ci < cDiv4; ++ci) { -#ifdef MS_USE_NEON - vst1q_f32(dstHeight + CAL_STEP * ci * area, vld1q_f32(srcHeight + CAL_STEP * ci)); -#else - for (int i = 0; i < CAL_STEP; ++i) { - dstHeight[ci * area * CAL_STEP + i] = srcHeight[CAL_STEP * ci + i]; - } -#endif - } - } - - if (cMod4 == 0) { - MS_LOGW("depth should be multiple of four"); - return; - } - - auto srcAlign = src + cAlign; - auto dstAlign = dst + area * cAlign; - -#ifdef MS_USE_NEON - auto zeroVector = vdupq_n_f32(0.0f); -#endif - - for (int hi = 0; hi < area; ++hi) { - auto srcHeight = srcAlign + hi * c; - auto dstHeight = dstAlign + hi * CAL_STEP; -#ifdef MS_USE_NEON - vst1q_f32(dstHeight, zeroVector); -#else - for (int i = 0; i < CAL_STEP; ++i) { - dstHeight[i] = 0; - } -#endif - for (int ci = 0; ci < cMod4; ++ci) { - dstHeight[ci] = srcHeight[ci]; - } - } -} - -void MSTensorConvertNC4HW4ToNHWC(float *dst, const float *src, size_t area, size_t depth) { - if (dst == nullptr || src == nullptr) { - MS_LOGW("srcPtr or dstPtr is nullptr"); - return; - } - int c = static_cast(depth); - int cDiv4 = c / CAL_STEP; - int cMod4 = c % CAL_STEP; - int cAlign = cDiv4 * CAL_STEP; - for (int hi = 0; hi < area; ++hi) { - auto srcHeight = src + hi * CAL_STEP; - auto dstHeight = dst + hi * c; - for (int ci = 0; ci < cDiv4; ++ci) { -#ifdef MS_USE_NEON - vst1q_f32(dstHeight + CAL_STEP * ci, vld1q_f32(srcHeight + CAL_STEP * ci * area)); -#else - for (int i = 0; i < CAL_STEP; ++i) { - dstHeight[ci * CAL_STEP + i] = srcHeight[CAL_STEP * ci * area + i]; - } -#endif - } - } - - if (cMod4 == 0) { - MS_LOGW("depth should be multiple of four"); - return; - } - - auto srcAlign = src + area * cAlign; - auto dstAlign = dst + cAlign; - - for (int hi = 0; hi < area; ++hi) { - auto srcHeight = srcAlign + hi * CAL_STEP; - auto dstHeight = dstAlign + hi * c; - - for (int ci = 0; ci < cMod4; ++ci) { - dstHeight[ci] = srcHeight[ci]; - } - } -} - -int NchwToNc4hw4(const Tensor *input, Tensor *output) { - if (input == nullptr || output == nullptr) { - MS_LOGE("input or output is nullptr"); - return RET_ERROR; - } - int batch = static_cast(input->Batch()); - int channel = static_cast(input->Channel()); - MS_ASSERT(batch > 0); - MS_ASSERT(channel > 0); - int area = static_cast(input->Width()) * static_cast(input->Height()); - int inputStride = input->GetElementSize() / batch; - int outputStride = output->GetElementSize() / batch; - DataType dt = input->GetDataType(); - - MS_ASSERT(input->GetData()); - MS_ASSERT(output->GetData()); - - if (dt == DataType_DT_FLOAT) { - for (int i = 0; i < batch; ++i) { - auto ret = MSPackC4(reinterpret_cast(output->GetData()) + outputStride * i, - (const float *)input->GetData() + inputStride * i, area, channel); - if (ret != RET_OK) { - MS_LOGE("MSPackC4 failed: %d", ret); - return RET_ERROR; - } - } - } else if (dt == DataType_DT_UINT8) { - for (int i = 0; i < batch; ++i) { - MSPackC4Uint8(reinterpret_cast(output->GetData()) + outputStride * i, - (const uint8_t *)input->GetData() + inputStride * i, area, channel); - } - } else { - MS_LOGE("Unsupported dataType: %d", dt); - return RET_ERROR; - } - return RET_OK; -} - -int Nc4hw4ToNchw(const Tensor *input, Tensor *output) { - if (input == nullptr || output == nullptr) { - MS_LOGE("input tensor or output tensor is nullptr"); - return RET_ERROR; - } - - int batch = static_cast(input->Batch()); - int channel = static_cast(input->Channel()); - MS_ASSERT(batch > 0); - MS_ASSERT(channel > 0); - int area = static_cast(input->Width()) * static_cast(input->Height()); - int inputStride = input->GetElementSize() / batch; - int outputStride = output->GetElementSize() / batch; - DataType dt = input->GetDataType(); - if (dt == DataType_DT_FLOAT) { - for (int i = 0; i < batch; ++i) { - MSUnpackC4(reinterpret_cast(output->GetData()) + outputStride * i, - (const float *)input->GetData() + inputStride * i, area, channel); - } - } else if (dt == DataType_DT_UINT8) { - for (int i = 0; i < batch; ++i) { - MSUnpackC4Uint8(reinterpret_cast(output->GetData()) + outputStride * i, - (const uint8_t *)input->GetData() + inputStride * i, area, channel); - } - } else { - MS_LOGE("Unsupported dataType: %d", dt); - return RET_ERROR; - } - - return RET_OK; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/operator/cpu/include/op_func_comm.h b/predict/src/operator/cpu/include/op_func_comm.h deleted file mode 100644 index 884803d669..0000000000 --- a/predict/src/operator/cpu/include/op_func_comm.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_OPERATOR_CPU_INCLUDE_OP_FUNC_COMM_H_ -#define PREDICT_SRC_OPERATOR_CPU_INCLUDE_OP_FUNC_COMM_H_ - -#include -#include -#include -#include -#include -#include "src/op_common.h" -#include "include/tensor.h" - -#ifdef MS_USE_NEON -#include -#endif // MS_USE_NEON - -namespace mindspore { -namespace predict { -#ifdef __cplusplus -extern "C" { -#endif -#define CAL_STEP 4 -void MSAddBias(float *dst, const float *bias, size_t planeNumber, size_t biasNumber); -void MSAddBiasRelu(float *dst, const float *bias, size_t planeNumber, size_t biasNumber); -void MSAddBiasRelu6(float *dst, const float *bias, size_t planeNumber, size_t biasNumber); -void MSPackC4Uint8(uint8_t *dst, const uint8_t *src, size_t area, size_t depth); -void MSUnpackC4(float *dst, const float *src, size_t area, size_t depth); -void MSUnpackC4Uint8(uint8_t *dst, const uint8_t *src, size_t area, size_t depth); -void MSTensorConvertNHWCToNC4HW4(float *dst, const float *src, size_t area, size_t depth); -void MSTensorConvertNC4HW4ToNHWC(float *dst, const float *src, size_t area, size_t depth); -void MSUnpackC4(float *dst, const float *src, size_t area, size_t depth); -void MSCopyC4WithStride(const float *source, float *dest, size_t srcStride, size_t dstStride, size_t count); -void MSUInt8ToInt16WithOffsetC4Common(int16_t *dst, const uint8_t *src, size_t zeroPoint, size_t sizeQuad, - size_t dstStride, size_t srcStride); -void MSUInt8ToInt16WithOffsetC4Fast(int16_t *dst, const uint8_t *src, size_t zeroPoint, size_t sizeQuad, - size_t depthQuad, size_t dstZStep, size_t srcZStep); - -int MSPackC4(float *dst, const float *src, size_t area, size_t depth); -int NchwToNc4hw4(const Tensor *input, Tensor *output); -int Nc4hw4ToNchw(const Tensor *input, Tensor *output); -#ifdef __cplusplus -} -#endif -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_OPERATOR_CPU_INCLUDE_OP_FUNC_COMM_H_ diff --git a/predict/src/runtime/allocator.cc b/predict/src/runtime/allocator.cc deleted file mode 100644 index cb94af5df9..0000000000 --- a/predict/src/runtime/allocator.cc +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/runtime/allocator.h" -#include "common/module_registry.h" -#include "src/op_common.h" - -namespace mindspore { -namespace predict { -std::shared_ptr Allocator::Create() { - auto alloc = GetRegistryInstance()->Create(MODULE_REG_NAME_ALLOCATOR); - if (alloc != nullptr) { - return alloc; - } - - // default allocator - return std::shared_ptr(new DefaultAllocator()); -} - -DefaultAllocator::DefaultAllocator() = default; - -DefaultAllocator::~DefaultAllocator() { Clear(); } - -void DefaultAllocator::SetContext(const AllocatorContext &ctx) { - lockFlag = ctx.lockFlag; - shiftFactor = ctx.shiftFactor; -} - -void DefaultAllocator::Lock() { - if (lockFlag) { - lock.lock(); - } -} - -void DefaultAllocator::UnLock() { - if (lockFlag) { - lock.unlock(); - } -} - -void *DefaultAllocator::Malloc(size_t size) { - if (size > MAX_MALLOC_SIZE) { - return nullptr; - } - Lock(); - auto it = freeList.begin(); - for (; it != freeList.end(); it++) { - auto membuf = *it; - - if ((membuf->size >= size) && (membuf->size < (size << shiftFactor))) { - freeList.erase(it); - allocatedList.push_back(membuf); - UnLock(); - return membuf->buf; - } - } - std::unique_ptr membuf(reinterpret_cast(malloc(sizeof(MemBuf) + size))); - if (membuf == nullptr) { - UnLock(); - return nullptr; - } - membuf->size = size; - membuf->buf = reinterpret_cast(membuf.get()) + sizeof(MemBuf); - auto bufPtr = membuf->buf; - allocatedList.push_back(membuf.release()); - UnLock(); - return bufPtr; -} - -void DefaultAllocator::Free(void *buf) { - if (buf == nullptr) { - return; - } - Lock(); - auto it = allocatedList.begin(); - for (; it != allocatedList.end(); it++) { - auto membuf = *it; - - if (membuf->buf == buf) { - allocatedList.erase(it); - freeList.push_back(membuf); - UnLock(); - return; - } - } - UnLock(); - free(buf); -} - -size_t DefaultAllocator::GetTotalSize() { - Lock(); - size_t totalSize = 0; - auto it = allocatedList.begin(); - for (; it != allocatedList.end(); it++) { - auto membuf = *it; - totalSize += membuf->size; - } - it = freeList.begin(); - for (; it != freeList.end(); it++) { - auto membuf = *it; - totalSize += membuf->size; - } - UnLock(); - return totalSize; -} - -void DefaultAllocator::Clear() { - Lock(); - auto it = allocatedList.begin(); - for (; it != allocatedList.end(); it++) { - free(*it); - } - allocatedList.clear(); - it = freeList.begin(); - for (; it != freeList.end(); it++) { - free(*it); - } - freeList.clear(); - UnLock(); -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/runtime/allocator.h b/predict/src/runtime/allocator.h deleted file mode 100644 index a9d72fbc9d..0000000000 --- a/predict/src/runtime/allocator.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_RUNTIME_ALLOCATOR_H_ -#define PREDICT_SRC_RUNTIME_ALLOCATOR_H_ - -#include -#include -#include -#include -#include "common/module_registry.h" - -namespace mindspore { -namespace predict { -struct AllocatorContext { - int shiftFactor; - bool lockFlag; -}; - -class Allocator { - public: - Allocator() : name("default") {} - virtual ~Allocator() {} - virtual void *Malloc(size_t size) = 0; - virtual void Free(void *ptr) = 0; - virtual void SetContext(const AllocatorContext &ctx) {} - virtual size_t GetTotalSize() { return 0; } - virtual void Clear() {} - static std::shared_ptr Create(); - std::string name; -}; - -class DefaultAllocator : public Allocator { - public: - DefaultAllocator(); - ~DefaultAllocator() override; - void SetContext(const AllocatorContext &ctx) override; - void *Malloc(size_t size) override; - void Free(void *ptr) override; - size_t GetTotalSize() override; - void Clear() override; - - private: - void Lock(); - void UnLock(); - struct MemBuf { - size_t size; - void *buf; - }; - - std::mutex lock; - std::vector allocatedList; - std::vector freeList; - int shiftFactor = 0; - bool lockFlag = false; -}; - -// these declaration are for module integration, refer to sample_allocator -const char MODULE_REG_NAME_ALLOCATOR[] = "allocator"; - -template <> class Module : public ModuleBase { - public: - virtual std::shared_ptr Create() = 0; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_RUNTIME_ALLOCATOR_H_ diff --git a/predict/src/runtime/runtime_api.cc b/predict/src/runtime/runtime_api.cc deleted file mode 100644 index 2091c808ff..0000000000 --- a/predict/src/runtime/runtime_api.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/runtime/runtime_api.h" -#include -#include -#include "src/runtime/workspace_pool.h" -#include "src/runtime/thread_pool.h" -#include "common/mslog.h" - -static std::mutex gWorkspaceMutex; -#ifdef __cplusplus -extern "C" { -#endif -void LiteAPISetLastError(const char *msg) { MS_LOGE("The lite api set last error is %s.", msg); } - -void *LiteBackendAllocWorkspace(int deviceType, int deviceId, uint64_t size, int dtypeCode, int dtypeBits) { - std::lock_guard lock(gWorkspaceMutex); - auto p = mindspore::predict::WorkspacePool::GetInstance(); - if (p == nullptr) { - MS_LOGE("get ThreadPool install failed"); - return nullptr; - } - return p->AllocWorkSpaceMem(size); -} - -int LiteBackendFreeWorkspace(int deviceType, int deviceId, void *ptr) { - std::lock_guard lock(gWorkspaceMutex); - auto p = mindspore::predict::WorkspacePool::GetInstance(); - if (p == nullptr) { - MS_LOGE("get ThreadPool install failed"); - return -1; - } - p->FreeWorkSpaceMem(ptr); - return 0; -} - -void ConfigThreadPool(int mode, int nthreads) { - auto p = mindspore::predict::ThreadPool::GetInstance(); - if (p == nullptr) { - MS_LOGE("get ThreadPool install failed"); - return; - } - p->ConfigThreadPool(mode, nthreads); -} - -int LiteBackendParallelLaunch(FTVMParallelLambda flambda, void *cdata, int num_task) { - auto p = mindspore::predict::ThreadPool::GetInstance(); - if (p == nullptr) { - MS_LOGE("get ThreadPool install failed"); - return -1; - } - if (!p->LaunchThreadPoolTask()) { - MS_LOGE("get ThreadPool or thread bind failed"); - return -1; - } - if (!p->AddTask(flambda, cdata, num_task)) { - MS_LOGE("AddTask failed"); - return -1; - } - return 0; -} - -#ifdef __cplusplus -} -#endif diff --git a/predict/src/runtime/runtime_api.h b/predict/src/runtime/runtime_api.h deleted file mode 100644 index 01aa782cf8..0000000000 --- a/predict/src/runtime/runtime_api.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef PREDICT_SRC_RUNTIME_RUNTIME_API_H_ -#define PREDICT_SRC_RUNTIME_RUNTIME_API_H_ -#include - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - void *sync_handle; - int32_t num_task; -} TVMParallelGroupEnv; -typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv *penv, void *cdata); -void LiteAPISetLastError(const char *msg); -void *LiteBackendAllocWorkspace(int deviceType, int deviceId, uint64_t size, int dtypeCode, int dtypeBits); -int LiteBackendFreeWorkspace(int deviceType, int deviceId, void *ptr); -void ConfigThreadPool(int mode, int nthreads); -int LiteBackendParallelLaunch(FTVMParallelLambda flambda, void *cdata, int num_task); -int LiteBackendRegisterSystemLibSymbol(const char *name, void *ptr); - -#ifdef __cplusplus -} -#endif -#endif // PREDICT_SRC_RUNTIME_RUNTIME_API_H_ diff --git a/predict/src/runtime/thread_pool.cc b/predict/src/runtime/thread_pool.cc deleted file mode 100644 index 6018927a18..0000000000 --- a/predict/src/runtime/thread_pool.cc +++ /dev/null @@ -1,447 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/runtime/thread_pool.h" -#include -#include "common/mslog.h" - -namespace mindspore { -namespace predict { -static constexpr int kThreadPoolMaxThreads = 8; -static const int kCoreNumThr = 4; -static const int kMidCoreNum = 2; -static const int kBigCoreNum = 2; -bool LiteQueue::Enqueue(const ThreadPoolTask &task) { - const int tailIndex = tail.load(std::memory_order_relaxed); - // queue full - auto next = (tailIndex + 1) % kSingleThreadMaxTask; - if (next == head.load(std::memory_order_acquire)) { - return false; - } - buffer[tailIndex] = task; - tail.store(next, std::memory_order_release); - taskSize.fetch_add(1); - return true; -} - -bool LiteQueue::Dequeue(ThreadPoolTask *out) { - if (out == nullptr) { - MS_LOGE("ThreadPoolTask is nullptr"); - return false; - } - if (taskSize.load() == 0) { - return false; - } - // queue empty - const int headIndex = head.load(std::memory_order_relaxed); - if (headIndex == tail.load(std::memory_order_acquire)) { - return false; - } - *out = buffer[headIndex]; - head.store((headIndex + 1) % kSingleThreadMaxTask, std::memory_order_release); - return true; -} - -bool LiteThreadBind::Bind(int numThreads, int mode) { - InitSortedCpuId(); - if (numThreads > static_cast(sortedCpuIds.size())) { - MS_LOGE("thread num %d is larger than cores %lu in the system", numThreads, sortedCpuIds.size()); - return false; - } - threadNums = numThreads + 1; - bindModel = static_cast(mode); - if (bindModel == NO_BIND) { - if (!BindAllThread(false)) { - MS_LOGE("unbind %d threads failed", threadNums); - return false; - } - MS_LOGD("unbind %d threads successful", threadNums); - } else { - if (!BindAllThread(true)) { - MS_LOGE("bind %d threads failed", threadNums); - return false; - } - MS_LOGD("bind %d threads successful", threadNums); - } - return true; -} - -void LiteThreadBind::InitSortedCpuId() { - int numCores = static_cast(std::thread::hardware_concurrency()); - if (numCores < kCoreNumThr) { - bigCore = 0; - midCore = numCores; - } else { - bigCore = kBigCoreNum; - midCore = kMidCoreNum; - } - if (numCores > kCoreNumThr) { - numCores = bigCore + midCore; - } - sortedCpuIds.resize(numCores); - sortedCpuIds.clear(); - for (int i = numCores - 1; i >= 0; --i) { - sortedCpuIds.emplace_back(i); - } -} - -bool LiteThreadBind::BindAllThread(bool bindFlag) { - if (threadNums <= 0) { - MS_LOGE("no thread pool find, current threadNums %d", threadNums); - return false; - } - if (!BindThreads(bindFlag)) { - MS_LOGE("bind threads failed"); - return false; - } - return true; -} - -bool LiteThreadBind::BindMasterThread(bool bindFlag, int mode) { - std::vector cpu; - cpu.resize(sortedCpuIds.size()); - cpu.clear(); - if (bindFlag) { - int cpuIndex = (mode == MID_CORE) ? (threadNums - 1) : 0; - auto materCpuId = sortedCpuIds.at(cpuIndex); - cpu.emplace_back(materCpuId); - } else { - // unbind master - cpu.assign(sortedCpuIds.begin(), sortedCpuIds.end()); - } - cpu_set_t cpuSet; - CPU_ZERO(&cpuSet); - for (auto coreId : cpu) { - CPU_SET(coreId, &cpuSet); - } - if (!SetCPUBind(pthread_self(), cpuSet)) { - MS_LOGE("do master bind failed. mode: %d", mode); - return false; - } - return true; -} - -bool LiteThreadBind::BindThreads(bool bindFlag) { - if (bindFlag) { - if (bindModel != NO_BIND) { - size_t bindNums = std::min(sortedCpuIds.size(), threadIdList.size()); - size_t coreIndex; - cpu_set_t cpuSet; - for (size_t i = 0; i < bindNums; ++i) { - if (bindModel == MID_CORE) { - coreIndex = sortedCpuIds.size() - i - 1; - } else { - coreIndex = i; - } - CPU_ZERO(&cpuSet); - CPU_SET(sortedCpuIds[coreIndex], &cpuSet); - if (!threadIdList[i].second) { - MS_LOGD("threadIdList[%lu]=%lu, sortedCpuIds[%lu]=%d", i, threadIdList[i].first, coreIndex, - sortedCpuIds[coreIndex]); - if (!SetCPUBind(threadIdList[i].first, cpuSet)) { - MS_LOGE("do SetCPUBind failed"); - return false; - } - } - threadIdList[i].second = true; - } - } - } else { - // unbind - size_t bindNums = std::min(sortedCpuIds.size(), threadIdList.size()); - cpu_set_t cpuSet; - CPU_ZERO(&cpuSet); - for (auto coreId : sortedCpuIds) { - CPU_SET(coreId, &cpuSet); - } - for (size_t i = 0; i < bindNums; ++i) { - if (!SetCPUBind(threadIdList[i].first, cpuSet)) { - MS_LOGE("do SetCPUBind failed"); - return false; - } - threadIdList[i].second = false; - } - } - return true; -} - -bool LiteThreadBind::SetCPUBind(pthread_t threadId, const cpu_set_t &cpuSet) { -#if defined(__ANDROID__) -#if __ANDROID_API__ >= 21 - int ret = sched_setaffinity(pthread_gettid_np(threadId), sizeof(cpu_set_t), &cpuSet); - if (ret != 0) { - MS_LOGE("bind thread %ld to cpu failed.ERROR %d", threadId, ret); - } -#endif -#else - int ret = pthread_setaffinity_np(threadId, sizeof(cpu_set_t), &cpuSet); - if (ret != 0) { - MS_LOGE("bind thread %ld to cpu failed.ERROR %d", threadId, ret); - return false; - } -#endif - return true; -} - -LiteThreadPool::LiteThreadPool(int numThreads) { - queueList.resize(kThreadPoolMaxThreads); - queueList.clear(); - AddNewThread(numThreads); -} - -void LiteThreadPool::AddNewThread(int newNums) { - for (int i = curThreadNums, j = 0; j < newNums; ++j, ++i) { - queueList.push_back(std::unique_ptr(new LiteQueue())); - threadList.emplace_back([this, i]() { - ThreadPoolTask task; - while (!destroy) { - while (running != 0) { - MS_LOGD("i = %d, thread id = %lu, taskSize = %d", i, pthread_self(), queueList[i]->taskSize.load()); - while (queueList[i]->taskSize.load() > 0 && queueList[i]->Dequeue(&task)) { - auto ret = task.first(task.second.taskId, task.second.tvmParam, task.second.cdata); - if (ret != 0) { - errorInfo.emplace_back(std::make_pair(task.second.taskId, std::make_pair(false, ret))); - } - queueList[i]->taskSize.fetch_sub(1); - } - std::this_thread::yield(); - } - std::unique_lock queueLock(tMutex); - queueReady.wait(queueLock, [this] { return destroy || running != 0; }); - } - }); - } - MS_LOGI("%d new thread create", newNums); - curThreadNums += newNums; -} - -bool LiteThreadPool::DistributeTask(ThreadPoolTask task, int numTask) { - // wake up - errorInfo.clear(); - if (!AddRunReference()) { - MS_LOGE("add reference failed"); - return false; - } - bool kSuccFlag; - for (int i = 1; i < numTask; ++i) { - task.second.taskId = i; - do { - kSuccFlag = false; - for (auto &queue : queueList) { - MS_ASSERT(queue != nullptr); - if (queue->Enqueue(task)) { - kSuccFlag = true; - break; - } - } - std::this_thread::yield(); - } while (!kSuccFlag); - } - MS_LOGI("add %d task successful", numTask); - // master thread - int ret = task.first(0, task.second.tvmParam, task.second.cdata); - if (ret != 0) { - errorInfo.emplace_back(std::make_pair(0, std::make_pair(false, ret))); - } - kSuccFlag = false; - while (!kSuccFlag) { - kSuccFlag = true; - for (auto iter = queueList.begin(); iter != queueList.end(); ++iter) { - if ((*iter)->taskSize.load() != 0) { - kSuccFlag = false; - break; - } - } - std::this_thread::yield(); - } - // hibernate - if (!SubRunReference()) { - MS_LOGE("sub reference failed"); - return false; - } - MS_LOGI("finish %d task successful", numTask); - return CheckResult(); -} - -bool LiteThreadPool::AddRunReference() { - running.fetch_add(1); - std::lock_guard queueLock(tMutex); - queueReady.notify_all(); - return true; -} - -bool LiteThreadPool::SubRunReference() { - running.fetch_sub(1); - return true; -} - -bool LiteThreadPool::CheckResult() { - bool kSuccFlag = true; - for (auto result : errorInfo) { - if (result.second.first) { - MS_LOGE("task %d failed, error code is %d", result.first, result.second.second); - kSuccFlag = false; - } - } - return kSuccFlag; -} - -int ThreadPool::GetThreadNum(int numThreads) { - if (numThreads <= 0 || numThreads > kThreadPoolMaxThreads) { - MS_LOGE("numThreads %d, must be greater than 0 or less than or equal to %d", numThreads, kThreadPoolMaxThreads); - return -1; - } else { - if (numThreads > totalThreadNum) { - return (numThreads - totalThreadNum); - } else { - MS_LOGD("%d threads have been already created", numThreads); - return 0; - } - } -} - -void ThreadPool::GetThreadIdList() { - if (gThreadPool != nullptr) { - for (int i = 0; i < totalThreadNum; ++i) { - bool kSuccFlag = false; - pthread_t threadHandle; - do { - kSuccFlag = false; - threadHandle = gThreadPool->threadList[i].native_handle(); - if (threadHandle != 0) { - kSuccFlag = true; - } - std::this_thread::yield(); - } while (!kSuccFlag); - - auto iter = std::find_if(std::begin(gThreadBind->threadIdList), std::end(gThreadBind->threadIdList), - [threadHandle](std::pair id) { return id.first == threadHandle; }); - if (iter == std::end(gThreadBind->threadIdList)) { - gThreadBind->threadIdList.emplace_back(std::make_pair(threadHandle, false)); - } - } - } - MS_ASSERT(gThreadBind != nullptr); - gThreadBind->threadIdList.emplace_back(std::make_pair(pthread_self(), false)); -} - -bool ThreadPool::SetThreadCpulBind(int mode) { - if (totalThreadNum <= 0) { - MS_LOGE("no threads need to be bind, totalThreadNum : %d", totalThreadNum); - return false; - } - std::lock_guard bMutex(gPoolMutex); - if (gThreadBind == nullptr) { - gThreadBind = std::unique_ptr(new (std::nothrow) LiteThreadBind()); - if (gThreadBind == nullptr) { - MS_LOGE("new LiteThreadBind failed"); - return false; - } - gThreadBind->threadIdList.resize(kThreadPoolMaxThreads + 1); - gThreadBind->threadIdList.clear(); - } - GetThreadIdList(); - - if (!gThreadBind->Bind(totalThreadNum, mode)) { - MS_LOGE("BindCore failed"); - return false; - } - return true; -} - -bool ThreadPool::SetThreadPool(int numThreads) { - std::lock_guard Lock(gPoolMutex); - int realNums = GetThreadNum(numThreads); - if (realNums < -1) { - return false; - } - if (realNums == 0) { - return true; - } - if (gThreadPool == nullptr) { - gThreadPool = std::unique_ptr(new (std::nothrow) LiteThreadPool(realNums)); - if (gThreadPool == nullptr) { - MS_LOGE("%d threads create failed", realNums); - return false; - } - } else { - gThreadPool->AddNewThread(realNums); - } - MS_LOGD("%d threads create successful", realNums); - return true; -} - -ThreadPool *ThreadPool::GetInstance() { - static ThreadPool instance; - return &instance; -} - -void ThreadPool::ConfigThreadPool(int mode, int numThreads) { - bindMode = mode; - totalThreadNum = numThreads; -} - -bool ThreadPool::LaunchThreadPoolTask() { - if (gThreadPool == nullptr) { - if (!SetThreadPool(totalThreadNum)) { - MS_LOGE("create %d threads failed", totalThreadNum); - return false; - } - } - - if (gThreadBind == nullptr) { - if (!SetThreadCpulBind(bindMode)) { - MS_LOGE("create bind mode %d failed", bindMode); - return false; - } - } - return true; -} - -bool ThreadPool::AddTask(const WorkFun &worker, void *cdata, int numTask) { - if (numTask <= 0) { - numTask = totalThreadNum; - } - // single task, run master thread - if (numTask <= 1) { - TvmEnv env{}; - env.num_task = numTask; - int ret = worker(0, &env, cdata); - if (ret != 0) { - MS_LOGE("task 0 failed, error code is %d", ret); - return false; - } - MS_LOGD("task 0 successful"); - return true; - } - ThreadPoolTask task; - task.first = worker; - task.second.cdata = cdata; - return gThreadPool->DistributeTask(task, numTask); -} - -LiteThreadPool::~LiteThreadPool() { - destroy.store(true); - running.store(0); - queueReady.notify_all(); - for (auto &thread : threadList) { - if (thread.joinable()) { - thread.join(); - } - } -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/runtime/thread_pool.h b/predict/src/runtime/thread_pool.h deleted file mode 100644 index 53e4c1ec88..0000000000 --- a/predict/src/runtime/thread_pool.h +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_RUNTIME_THREAD_POOL_H_ -#define PREDICT_SRC_RUNTIME_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "src/runtime/runtime_api.h" - -namespace mindspore { -namespace predict { -constexpr int kSingleThreadMaxTask = 4; -using TvmEnv = TVMParallelGroupEnv; -using WorkFun = FTVMParallelLambda; -using TaskParam = struct Param { - void *cdata; - int32_t taskId; - TvmEnv *tvmParam; -}; -using ThreadPoolTask = std::pair; - -class LiteQueue { - public: - LiteQueue() = default; - ~LiteQueue() = default; - - bool Enqueue(const ThreadPoolTask &task); - bool Dequeue(ThreadPoolTask *out); - std::atomic taskSize{0}; - - private: - std::atomic head{0}; - std::atomic tail{0}; - ThreadPoolTask buffer[kSingleThreadMaxTask]{}; -}; - -class LiteThreadBind { - public: - LiteThreadBind() = default; - ~LiteThreadBind() = default; - bool Bind(int numThreads, int mode); - std::vector> threadIdList; - - private: - enum AffinityMode : int { BIG_CORE = 1, MID_CORE = -1, NO_BIND = 0 }; - void InitSortedCpuId(); - bool BindAllThread(bool bindFlag); - bool BindMasterThread(bool bindFlag, int mode = MID_CORE); - bool BindThreads(bool bindFlag); - bool SetCPUBind(pthread_t threadId, const cpu_set_t &cpuSet); - int bigCore{0}; - int midCore{0}; - int threadNums{0}; - std::vector sortedCpuIds{}; - AffinityMode bindModel{MID_CORE}; -}; - -class LiteThreadPool { - public: - LiteThreadPool() = default; - explicit LiteThreadPool(int numThreads); - ~LiteThreadPool(); - - void AddNewThread(int newNums); - bool DistributeTask(ThreadPoolTask task, int numTask); - std::vector threadList{}; - - private: - using errCode = std::pair; - bool AddRunReference(); - bool SubRunReference(); - bool CheckResult(); - int curThreadNums{0}; - std::vector> queueList; - std::atomic_int running{0}; - std::mutex tMutex; - std::condition_variable queueReady; - std::atomic destroy = {false}; - std::vector> errorInfo{}; -}; - -class ThreadPool { - public: - static ThreadPool *GetInstance(); - void ConfigThreadPool(int mode, int numThreads); - bool LaunchThreadPoolTask(); - bool AddTask(const WorkFun &worker, void *cdata, int numTask); - - ThreadPool(const ThreadPool &) = delete; - ThreadPool &operator=(const ThreadPool &) = delete; - - private: - ThreadPool() = default; - ~ThreadPool() = default; - int GetThreadNum(int numThreads); - void GetThreadIdList(); - bool SetThreadPool(int numThreads = 1); - bool SetThreadCpulBind(int mode); - std::unique_ptr gThreadPool{nullptr}; - std::unique_ptr gThreadBind{nullptr}; - std::mutex gPoolMutex; - int totalThreadNum{1}; - int bindMode{-1}; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_SRC_RUNTIME_THREAD_POOL_H_ diff --git a/predict/src/runtime/workspace_pool.cc b/predict/src/runtime/workspace_pool.cc deleted file mode 100644 index 6cafe7482e..0000000000 --- a/predict/src/runtime/workspace_pool.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/runtime/workspace_pool.h" -#include -#include -#include "common/mslog.h" - -namespace mindspore { -namespace predict { -static constexpr size_t kWorkspacePageSize = 4096; -static constexpr int kTempAllocaAlignment = 64; -WorkspacePool *WorkspacePool::GetInstance() { - static WorkspacePool instance; - return &instance; -} - -void *WorkspacePool::AllocWorkSpaceMem(size_t size) { - size_t nbytes = (size + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize; - if (nbytes == 0) { - nbytes = kWorkspacePageSize; - } - std::pair alloc; - // fist alloc - if (freeList.empty()) { - alloc.first = nbytes; - alloc.second = memalign(kTempAllocaAlignment, nbytes); - } else if (freeList.size() == 1) { // one element - alloc = *(freeList.begin()); - freeList.erase(freeList.begin()); - if (alloc.first < nbytes) { - free(alloc.second); - alloc.first = nbytes; - alloc.second = memalign(kTempAllocaAlignment, nbytes); - } - } else { - if ((*(freeList.begin())).first >= nbytes) { - auto iter = freeList.begin(); - for (; iter != freeList.end(); ++iter) { - if ((*iter).first < size) { - alloc = *(--iter); - freeList.erase(iter); - break; - } - } - if (iter == freeList.end()) { - alloc = *(freeList.rbegin()); - freeList.erase(--freeList.end()); - } - } else { - alloc = *(freeList.begin()); - freeList.erase(freeList.begin()); - free(alloc.second); - alloc.first = nbytes; - alloc.second = memalign(kTempAllocaAlignment, nbytes); - } - } - allocList.emplace_back(alloc); - return alloc.second; -} - -void WorkspacePool::FreeWorkSpaceMem(void *ptr) { - if (ptr == nullptr) { - return; - } - std::pair alloc; - if (allocList.empty()) { - MS_LOGE("no mem have been alloc"); - return; - } else if (allocList.back().second == ptr) { - alloc = allocList.back(); - allocList.pop_back(); - } else { - auto iter = allocList.begin(); - for (; iter != allocList.end(); ++iter) { - if ((*iter).second == ptr) { - alloc = *iter; - allocList.erase(iter); - break; - } - } - if (iter == allocList.end()) { - MS_LOGE("no value ptr have been alloc"); - return; - } - } - freeList.insert(alloc); -} - -WorkspacePool::~WorkspacePool() { - for (auto &a : allocList) { - free(a.second); - } - allocList.clear(); - for (auto &f : freeList) { - free(f.second); - } - freeList.clear(); -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/runtime/workspace_pool.h b/predict/src/runtime/workspace_pool.h deleted file mode 100644 index ce8a5ca3ab..0000000000 --- a/predict/src/runtime/workspace_pool.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_SRC_RUNTIME_WORKSPACE_POOL_H_ -#define PREDICT_SRC_RUNTIME_WORKSPACE_POOL_H_ -#include -#include -#include -#include -#include -#include - -namespace mindspore { -namespace predict { -class WorkspacePool { - public: - WorkspacePool() = default; - ~WorkspacePool(); - WorkspacePool(const WorkspacePool &) = delete; - WorkspacePool &operator=(const WorkspacePool &) = delete; - static WorkspacePool *GetInstance(); - void *AllocWorkSpaceMem(size_t size); - void FreeWorkSpaceMem(void *ptr); - - private: - std::vector> allocList{}; - std::set, std::greater>> freeList{}; -}; -} // namespace predict -} // namespace mindspore -#endif // PREDICT_SRC_RUNTIME_WORKSPACE_POOL_H_ diff --git a/predict/src/session.cc b/predict/src/session.cc deleted file mode 100644 index b808ec7c6b..0000000000 --- a/predict/src/session.cc +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "include/session.h" -#include -#include -#include "include/errorcode.h" -#include "common/mslog.h" -#include "src/graph.h" -#include "src/graph_execution.h" - -namespace mindspore { -namespace predict { -Context m_ctx; -bool m_isConfig = false; - -// In 32bits, this evaluates to 2GB - 1 -static constexpr auto MAX_BUFFER_SIZE = ((1ULL << (sizeof(int32_t) * 8 - 1)) - 1); - -std::shared_ptr CreateSession(const char *graphBuf, size_t size, const Context &ctx) { - if (graphBuf == nullptr) { - MS_LOGE("the graphBuf is nullptr"); - return nullptr; - } - if (size > MAX_BUFFER_SIZE) { - MS_LOGE("the size is invalid"); - return nullptr; - } - auto session = std::make_shared(ctx); - MS_ASSERT(session != nullptr); - auto ret = session->Init(graphBuf, size); - if (ret != RET_OK) { - MS_LOGE("Init session failed."); - return nullptr; - } - return session; -} -Session::Session(const Context &ctx) : _ctx(ctx) { - Context cfgCtx; - cfgCtx = ctx; - if (cfgCtx.threadNum > m_ctx.threadNum) { - cfgCtx.threadNum = m_ctx.threadNum; - } -} - -int Session::Init(const char *graphBuf, size_t size) { - _graph = Graph::CreateFromBuf(graphBuf, size, _ctx); - if (_graph == nullptr) { - MS_LOGE("Graph create from buf failed."); - return RET_NULL_PTR; - } - - auto ret = this->InitExecutor(); - if (ret != RET_OK) { - MS_LOGE("Init Executor failed"); - return ret; - } - return ret; -} - -int Session::InitExecutor() { - if (_executor != nullptr) { - delete _executor; - _executor = nullptr; - } - if (_graph != nullptr) { - _executor = new (std::nothrow) GraphExecution(_ctx, _graph); - if (_executor == nullptr) { - MS_LOGE("new GraphExecution fail"); - return RET_ERROR; - } - return RET_OK; - } else { - MS_LOGE("the graph is nullptr"); - return RET_ERROR; - } -} - -Session::~Session() { - if (_executor != nullptr) { - delete _executor; - } - if (_graph != nullptr) { - delete _graph; - } -} - -int Session::Run(const std::vector &inputs) { - auto ret = RET_OK; - if (reinitExecutor) { - ret = this->InitExecutor(); - if (ret != RET_OK) { - MS_LOGE("Init Executor failed"); - return ret; - } - } - if (_executor == nullptr) { - MS_LOGE("_executor is nullptr"); - return ret; - } - ret = _executor->Run(inputs); - return ret; -} - -std::vector Session::GetInput() { - if (_executor == nullptr) { - MS_LOGE("_executor is nullptr"); - return std::vector{}; - } - auto inputs = _executor->GetInput(); - if (inputs.empty()) { - MS_LOGI("output is empty."); - } - return inputs; -} - -std::vector Session::GetOutput(const std::string &nodeName) { - if (_executor == nullptr) { - MS_LOGE("graph's executor is nullptr."); - return std::vector{}; - } - auto outputs = _executor->GetOutput(nodeName); - if (outputs.empty()) { - MS_LOGI("output is empty."); - } - return outputs; -} - -std::map> Session::GetAllOutput() { - if (_executor == nullptr) { - MS_LOGE("graph's executor is nullptr."); - return std::map>{}; - } - auto outputs = _executor->GetAllOutput(); - if (outputs.empty()) { - MS_LOGI("outputs is empty."); - } - return outputs; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/src/tensor.cc b/predict/src/tensor.cc deleted file mode 100644 index de758f3407..0000000000 --- a/predict/src/tensor.cc +++ /dev/null @@ -1,517 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "include/tensor.h" -#include "common/mslog.h" -#include "src/op_common.h" -#include "include/errorcode.h" -#include "securec/include/securec.h" -#include "common/common.h" -#include "src/runtime/allocator.h" - -namespace mindspore { -namespace predict { -Tensor *Tensor::CopyFromTensorDef(const TensorDef &tensorDef) { - std::vector dims; - - if (tensorDef.dims() == nullptr) { - MS_LOGD("tensorDef->dims is nullptr"); - } else { - MS_ASSERT(tensorDef.dims()->data() != nullptr); - for (uint32_t j = 0; j < tensorDef.dims()->size(); j++) { - dims.push_back(tensorDef.dims()->data()[j]); - } - } - auto tensor = - std::unique_ptr(new (std::nothrow) Tensor(tensorDef.dataType(), dims, tensorDef.format(), nullptr)); - if (tensor == nullptr) { - MS_LOGE("new Tensor failed"); - return nullptr; - } - - if (tensorDef.refCount() == MSConst_WEIGHT_REFCOUNT && tensorDef.data() != nullptr && tensorDef.data()->size() > 0) { - if (dims.size() < 1) { - tensor->SetDims({1}); - } - auto ret = tensor->MallocData(); - if (ret != RET_OK) { - MS_LOGE("malloc data fail,datasize %zu", tensor->GetDataSize()); - return nullptr; - } - auto tensorData = tensorDef.data()->data(); - ret = memcpy_sp(tensor->GetData(), tensor->GetDataSize(), tensorData, tensorDef.data()->size()); - if (ret != RET_OK) { - MS_LOGE("copy data fail,dst size %zu, src size %u", tensor->GetDataSize(), tensorDef.data()->size()); - return nullptr; - } - } - tensor->refCount = tensorDef.refCount(); - return tensor.release(); -} - -Tensor::Tensor(const Tensor &tensor, bool copyData) { - format = tensor.format; - dlTensor.data = nullptr; - dlTensor.ctx.device_type = tensor.dlTensor.ctx.device_type; - dlTensor.ctx.device_id = tensor.dlTensor.ctx.device_id; - dlTensor.strides = nullptr; - dlTensor.byte_offset = tensor.dlTensor.byte_offset; - dlTensor.dtype.code = tensor.dlTensor.dtype.code; - dlTensor.dtype.bits = tensor.dlTensor.dtype.bits; - dlTensor.dtype.lanes = tensor.dlTensor.dtype.lanes; - - dlTensor.ndim = tensor.dlTensor.ndim; - if (dlTensor.ndim > 0) { - dlTensor.shape = new (std::nothrow) int64_t[dlTensor.ndim]; - if (dlTensor.shape != nullptr) { - for (int i = 0; i < dlTensor.ndim; i++) { - dlTensor.shape[i] = tensor.dlTensor.shape[i]; - } - } else { - MS_LOGW("new shape fail,ndim %d", dlTensor.ndim); - } - } else { - dlTensor.shape = nullptr; - } - if (copyData) { - allocator = tensor.allocator; - refCount = tensor.refCount; - auto ret = MallocData(); - if (ret != RET_OK) { - return; - } - size_t datasize = GetDataSize(); - ret = memcpy_sp(dlTensor.data, datasize, tensor.dlTensor.data, datasize); - if (ret != RET_OK) { - return; - } - } -} - -Tensor::Tensor(DataType dt, const std::vector &dims, Format format, void *data) { - this->format = format; - dlTensor.data = data; - dlTensor.ctx.device_type = DLDeviceType::kDLCPU; - dlTensor.ctx.device_id = 0; - dlTensor.strides = nullptr; - dlTensor.byte_offset = 0; - - dlTensor.ndim = static_cast(dims.size()); - if (dlTensor.ndim > 0) { - dlTensor.shape = new (std::nothrow) int64_t[dlTensor.ndim]; - if (dlTensor.shape != nullptr) { - for (int i = 0; i < dlTensor.ndim; i++) { - dlTensor.shape[i] = dims[i]; - } - } else { - MS_LOGW("new shape fail,ndim %d", dlTensor.ndim); - } - } else { - dlTensor.shape = nullptr; - } - - SetDataType(dt); -} - -Tensor::~Tensor() { FreeTensor(); } - -DLDataType Tensor::GetTensorDtype() const { return dlTensor.dtype; } - -void *Tensor::GetData() const { return dlTensor.data; } - -void Tensor::SetData(void *data) { dlTensor.data = data; } - -DataType Tensor::GetDataType() const { - DataType dataType = DataType_DT_UNDEFINED; - switch (dlTensor.dtype.code) { - case kDLFloat: - if (dlTensor.dtype.bits == 32) { - dataType = DataType_DT_FLOAT; - } else if (dlTensor.dtype.bits == 16) { - dataType = DataType_DT_FLOAT16; - } - break; - case kDLInt: - if (dlTensor.dtype.bits == 32) { - dataType = DataType_DT_INT32; - } else if (dlTensor.dtype.bits == 8) { - dataType = DataType_DT_INT8; - } - break; - case kDLUInt: - if (dlTensor.dtype.bits == 32) { - dataType = DataType_DT_UINT32; - } else if (dlTensor.dtype.bits == 8) { - dataType = DataType_DT_UINT8; - } - break; - default: - break; - } - return dataType; -} - -void Tensor::SetDataType(DataType dt) { - switch (dt) { - case DataType_DT_FLOAT: - dlTensor.dtype.code = kDLFloat; - dlTensor.dtype.bits = 32; - dlTensor.dtype.lanes = 1; - break; - case DataType_DT_FLOAT16: - dlTensor.dtype.code = kDLFloat; - dlTensor.dtype.bits = 16; - dlTensor.dtype.lanes = 1; - break; - case DataType_DT_INT8: - dlTensor.dtype.code = kDLInt; - dlTensor.dtype.bits = 8; - dlTensor.dtype.lanes = 1; - break; - case DataType_DT_UINT8: - dlTensor.dtype.code = kDLUInt; - dlTensor.dtype.bits = 8; - dlTensor.dtype.lanes = 1; - break; - case DataType_DT_INT32: - dlTensor.dtype.code = kDLInt; - dlTensor.dtype.bits = 32; - dlTensor.dtype.lanes = 1; - break; - case DataType_DT_UINT32: - dlTensor.dtype.code = kDLUInt; - dlTensor.dtype.bits = 32; - dlTensor.dtype.lanes = 1; - break; - default: - MS_LOGW(" DataType %d is not implemented.", dt); - MS_LOGW(" DataType DT_FLOAT is used."); - dlTensor.dtype.code = kDLFloat; - dlTensor.dtype.bits = 32; - dlTensor.dtype.lanes = 1; - return; - } -} - -int Tensor::GetNDim() const { return dlTensor.ndim; } - -std::vector Tensor::GetDims() const { - std::vector dims; - for (int i = 0; i < dlTensor.ndim; i++) { - dims.push_back(dlTensor.shape[i]); - } - return dims; -} - -size_t Tensor::GetElementSize() const { - const int tile = 4; - if (format == Format_NC4HW4) { - size_t size = 1; - for (int i = 0; i < dlTensor.ndim; i++) { - auto var = static_cast(dlTensor.shape[i]); - if (i == 1) { - var = UP_DIV(var, tile) * tile; - } - size *= var; - } - return size; - } else { - size_t size = 1; - for (int i = 0; i < dlTensor.ndim; i++) { - size *= static_cast(dlTensor.shape[i]); - } - - return size; - } -} - -size_t Tensor::GetDataSize() const { - size_t size = GetElementSize(); - - const int BYTES = 8; - const int GAP = 7; - size *= (dlTensor.dtype.bits * dlTensor.dtype.lanes + GAP) / BYTES; - return size; -} - -int Tensor::MallocData(std::shared_ptr allocator, int refCount) { - if (dlTensor.data != nullptr) { - this->refCount += refCount; - return RET_OK; - } - this->refCount = refCount; - - size_t size = GetDataSize(); - if (allocator) { - this->allocator = allocator; - dlTensor.data = allocator->Malloc(size); - } else { - if (size > MAX_MALLOC_SIZE) { - return RET_ERROR; - } - dlTensor.data = malloc(size); - } - if (dlTensor.data == nullptr) { - return RET_ERROR; - } - return RET_OK; -} - -void Tensor::ForceFreeData() { - if (allocator) { - allocator->Free(dlTensor.data); - } else { - free(dlTensor.data); - } - dlTensor.data = nullptr; -} - -void Tensor::FreeData() { - --refCount; - if (refCount <= 0) { - ForceFreeData(); - } -} - -bool Tensor::CompareShape(const Tensor &dst) { - if (dlTensor.ndim != dst.dlTensor.ndim || dlTensor.shape == nullptr || dst.dlTensor.shape == nullptr) { - MS_LOGE("param error, one.ndim: %d, other.ndim: %d, one shape %p,other shape %p", dlTensor.ndim, dst.dlTensor.ndim, - dlTensor.shape, dst.dlTensor.shape); - return false; - } - - for (int i = 0; i < dlTensor.ndim; i++) { - if (dlTensor.shape[i] != dst.dlTensor.shape[i]) { - MS_LOGE("one.shape[%d]: %ld, other.shape[%d]: %ld", i, dlTensor.shape[i], i, dst.dlTensor.shape[i]); - return false; - } - } - return true; -} - -bool Tensor::CompareShape(const std::vector &other) { - if (dlTensor.ndim != other.size() || dlTensor.shape == nullptr) { - return false; - } - - for (int i = 0; i < dlTensor.ndim; i++) { - if (dlTensor.shape[i] != other[i]) { - return false; - } - } - return true; -} - -int64_t Tensor::Height() const { - if (dlTensor.shape == nullptr) { - MS_LOGE("shape is null"); - } - if (dlTensor.ndim != DIM_DEFAULT_SIZE) { - MS_LOGE("Tensor should be 4 dimensional."); - return -1; - } - switch (this->format) { - case Format_NCHW: - case Format_NC4HW4: - return dlTensor.shape[NCHW_H]; - case Format_NHWC: - return dlTensor.shape[NHWC_H]; - default: - MS_LOGE("Unsupported format: %d", this->format); - return -1; - } -} - -int64_t Tensor::Width() const { - if (dlTensor.shape == nullptr) { - MS_LOGE("shape is null"); - } - if (dlTensor.ndim != DIM_DEFAULT_SIZE) { - MS_LOGE("Tensor should be 4 dimensional."); - return -1; - } - switch (this->format) { - case Format_NCHW: - case Format_NC4HW4: - return dlTensor.shape[NCHW_W]; - case Format_NHWC: - return dlTensor.shape[NHWC_W]; - default: - MS_LOGE("Unsupported format: %d", this->format); - return -1; - } -} - -int64_t Tensor::Channel() const { - if (dlTensor.shape == nullptr) { - MS_LOGE("shape is null"); - } - if (dlTensor.ndim != DIM_DEFAULT_SIZE) { - MS_LOGE("Tensor should be 4 dimensional."); - return -1; - } - switch (this->format) { - case Format_NCHW: - case Format_NC4HW4: - return dlTensor.shape[NCHW_C]; - case Format_NHWC: - return dlTensor.shape[NHWC_C]; - default: - MS_LOGE("Unsupported format: %d", this->format); - return -1; - } -} - -int64_t Tensor::Batch() const { - if (dlTensor.shape == nullptr) { - MS_LOGE("shape is null"); - } - if (dlTensor.ndim != DIM_DEFAULT_SIZE) { - MS_LOGE("Tensor should be 4 dimensional."); - return -1; - } - switch (this->format) { - case Format_NCHW: - case Format_NC4HW4: - case Format_NHWC: - return dlTensor.shape[NCHW_N]; - default: - MS_LOGE("Unsupported format: %d", this->format); - return -1; - } -} - -int64_t Tensor::Stride(int index) const { - if (dlTensor.strides) { - return dlTensor.strides[index]; - } - if (dlTensor.shape == nullptr) { - MS_LOGE("shape is null"); - return -1; - } - int64_t stride = 1; - for (int i = index + 1; i < dlTensor.ndim; i++) { - stride *= dlTensor.shape[i]; - } - return stride; -} - -void Tensor::SetStride() { - if (dlTensor.strides == nullptr) { - if (dlTensor.ndim < 1) { - MS_LOGE("dims of dlTensor is empty."); - return; - } - dlTensor.strides = new (std::nothrow) int64_t[dlTensor.ndim - 1]; - if (dlTensor.strides == nullptr) { - MS_LOGW("new stride fail, ndim %d.", dlTensor.ndim); - return; - } - } - - for (int idx = 0; idx < dlTensor.ndim - 1; idx++) { - int64_t stride = 1; - if (dlTensor.ndim <= idx + 1) { - MS_LOGE("out of for loop upper limit."); - return; - } - for (int i = idx + 1; i < dlTensor.ndim; i++) { - stride *= dlTensor.shape[i]; - } - dlTensor.strides[idx] = stride; - } -} -void Tensor::SetScale(bool isScale) { this->isScale = isScale; } - -void Tensor::SetStride(int index, int64_t stride) { - if (index >= dlTensor.ndim) { - return; - } - - if (dlTensor.strides == nullptr) { - SetStride(); - } - - dlTensor.strides[index] = stride; - return; -} - -void Tensor::SetDims(const std::vector &dims) { - if (dlTensor.shape != nullptr) { - delete[] dlTensor.shape; - } - dlTensor.ndim = static_cast(dims.size()); - if (dlTensor.ndim > 0) { - dlTensor.shape = new (std::nothrow) int64_t[dlTensor.ndim]; - if (dlTensor.shape != nullptr) { - for (int i = 0; i < dlTensor.ndim; i++) { - dlTensor.shape[i] = dims[i]; - } - } else { - MS_LOGW("new shape fail,ndim %d", dlTensor.ndim); - } - } else { - dlTensor.shape = nullptr; - } -} - -void Tensor::FreeTensor() { - if (dlTensor.shape != nullptr) { - delete[] dlTensor.shape; - dlTensor.shape = nullptr; - } - - if (dlTensor.strides != nullptr) { - delete[] dlTensor.strides; - dlTensor.strides = nullptr; - } - - dlTensor.ndim = 0; - - if (allocator != nullptr) { - allocator->Free(dlTensor.data); - } else { - free(dlTensor.data); - } - dlTensor.data = nullptr; -} - -size_t Tensor::GetNC4HW4ElementSize(bool isNhwc) { - int alignIndex = 1; - if (isNhwc) { - alignIndex = 3; - } - - size_t size = 1; - for (int i = 0; i < dlTensor.ndim; i++) { - auto var = static_cast(dlTensor.shape[i]); - if (i == alignIndex) { - var = ALIGN_UP4(var); - } - size *= var; - } - return size; -} - -size_t Tensor::GetNC4HW4DataSize(bool isNhwc) { - size_t size = GetNC4HW4ElementSize(isNhwc); - const int BYTES = 8; - const int GAP = 7; - size *= (dlTensor.dtype.bits * dlTensor.dtype.lanes + GAP) / BYTES; - return size; -} -} // namespace predict -} // namespace mindspore diff --git a/predict/test/CMakeLists.txt b/predict/test/CMakeLists.txt deleted file mode 100755 index 9370ff7ce0..0000000000 --- a/predict/test/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -cmake_minimum_required(VERSION 3.12) -project(ms-test) - -set(CMAKE_CXX_STANDARD 11) - -#include 3rd -include_directories(${3RD_DIR}/securec/include) -include_directories(${3RD_DIR}/flatbuffers/include) -include_directories(${3RD_DIR}/googletest/googletest/include) -include_directories(${3RD_DIR}/googletest/googlemock/include) -include_directories(${3RD_DIR}/securec/include) - -#include ms -include_directories(.) -include_directories(..) - -link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../output/lib/) - -set(COMMON_SRC ${PREDICT_DIR}/common/flag_parser.cc - ${PREDICT_DIR}/common/file_utils.cc - ${PREDICT_DIR}/common/mslog.cc - ${PREDICT_DIR}/common/storage.cc - ${PREDICT_DIR}/common/utils.cc) - -#tools src -file(GLOB_RECURSE TOOLS_SRC ../tools/*.cpp) - -add_executable(ms-test - ${COMMON_SRC} - ${TOOLS_SRC} - src/graph_tests.cc - benchmark/benchmark_tests.cc - ${CMAKE_SOURCE_DIR}/benchmark/benchmark.cc - ${TF_PROTO_SRC} - ${MS_CONVERTER_SRC} - test_context.h - test_context.cc - main.cc) - -target_link_libraries(ms-test mspredict gtest libsecurec.a) -add_dependencies(ms-test securec) -add_dependencies(ms-test gtest) - -# copy test file -add_custom_command(TARGET ms-test POST_BUILD - COMMAND mkdir -pv ${DOTEST_DIR} - COMMAND cp ${PREDICT_BUILD_DIR}/test/ms-test ${DOTEST_DIR} - COMMAND cp ${PREDICT_DIR}/test/run_tests.sh ${PREDICT_BUILD_DIR}/test/ - COMMAND cp -r ${PREDICT_DIR}/test/data/ ${PREDICT_BUILD_DIR}/test/doTest/) diff --git a/predict/test/benchmark/benchmark_tests.cc b/predict/test/benchmark/benchmark_tests.cc deleted file mode 100644 index e1e218e851..0000000000 --- a/predict/test/benchmark/benchmark_tests.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include "test/test_context.h" -#include "benchmark/benchmark.h" - -#define LENET_ARGS 2 -#define MS_ARGS 4 - -namespace mindspore { -namespace predict { -class BenchmarkTest : public ::testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - std::string root; -}; - -TEST_F(BenchmarkTest, BenchmarkRun) { - const char* args[LENET_ARGS]; - args[0] = "./benchmark"; - args[1] = "--modelPath=./data/lenet/lenet.ms"; - - int errorcode = mindspore::predict::RunBenchmark(LENET_ARGS, args); - EXPECT_EQ(0, errorcode); -} - -TEST_F(BenchmarkTest, LenetRun) { - const char* args[MS_ARGS]; - args[0] = "./benchmark"; - args[1] = "--modelPath=./data/ms/mindspore.ms"; - args[2] = "--inDataPath=./data/ms/mindspore.bin"; - args[3] = "--calibDataPath=./data/ms/mindspore.out"; - - int errorcode = mindspore::predict::RunBenchmark(MS_ARGS, args); - EXPECT_EQ(0, errorcode); -} - -TEST_F(BenchmarkTest, MindSporeRun) { - const char* args[4]; - args[0] = "./benchmark"; - args[1] = "--modelPath=./data/lenet/lenet.ms"; - args[2] = "--inDataPath=./data/lenet/lenet.bin"; - args[3] = "--calibDataPath=./data/lenet/lenet.out"; - - int errorcode = mindspore::predict::RunBenchmark(4, args); - EXPECT_EQ(0, errorcode); -} -} // namespace predict -} // namespace mindspore diff --git a/predict/test/data/lenet/lenet.bin b/predict/test/data/lenet/lenet.bin deleted file mode 100755 index 6aef53bd64..0000000000 Binary files a/predict/test/data/lenet/lenet.bin and /dev/null differ diff --git a/predict/test/data/lenet/lenet.ms b/predict/test/data/lenet/lenet.ms deleted file mode 100755 index c66948dd67..0000000000 Binary files a/predict/test/data/lenet/lenet.ms and /dev/null differ diff --git a/predict/test/data/lenet/lenet.out b/predict/test/data/lenet/lenet.out deleted file mode 100644 index 31a1d61b8b..0000000000 --- a/predict/test/data/lenet/lenet.out +++ /dev/null @@ -1,2 +0,0 @@ -prob 2 1 10 -0.0 0.9999994 5.4061115e-07 0.0 0.0 0.0 0.0 5.690875e-08 1.1269122e-34 0.0 diff --git a/predict/test/data/ms/mindspore.bin b/predict/test/data/ms/mindspore.bin deleted file mode 100755 index 77981f8b7e..0000000000 --- a/predict/test/data/ms/mindspore.bin +++ /dev/null @@ -1,5 +0,0 @@ -p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پkI<6,?6,?6,?6,?6,?BL>p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پy>b??<4@<4@<4@<4@@L?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ`9?>`9?>`9?>`9?>`9?>`9?>`9?>N?<4@<4@<4@!2@@Yy @ -f>p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ?<4@<4@<4@<4@<4@<4@<4@<4@<4@<4@!2@R@lnp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ?<4@<4@<4@<4@<4@<4@<4@<4@<4@<4@@lnҾp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پh?Yy @w#@<4@<4@<4@<4@<4@<4@6,?`9?>`9?>`9?>dד=p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پYy @<4@<4@<4@<4@p@b?b?b?=ٽp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پYy @<4@<4@<4@@;?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پYy @<4@<4@<4@,?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ ?5 @<4@<4@@b??ݜp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ ?w#@<4@<4@<4@:)@T>`9?>`9?> p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ?b>-@<4@<4@<4@<4@<4@<4@4?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ4?<4@<4@<4@<4@<4@<4@b>-@?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ `9?>`9?>T>:)@<4@<4@<4@w#@ ?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پݜ?b?@<4@<4@5 @ ?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پL?BA@<4@<4@w#@?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ?<4@<4@<4@'@ap2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ -f>b?kI-@@!2@<4@<4@<4@<4@<4@4@y>~žp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ -f>@<4@<4@<4@<4@<4@<4@<4@<4@<4@<4@@ -f>p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پL?6,?6,?6,?6,?6,?6,?6,?6,?6,?6,?L?p2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پp2پ \ No newline at end of file diff --git a/predict/test/data/ms/mindspore.ms b/predict/test/data/ms/mindspore.ms deleted file mode 100755 index 654e4060aa..0000000000 Binary files a/predict/test/data/ms/mindspore.ms and /dev/null differ diff --git a/predict/test/data/ms/mindspore.out b/predict/test/data/ms/mindspore.out deleted file mode 100644 index a7325da0e3..0000000000 --- a/predict/test/data/ms/mindspore.out +++ /dev/null @@ -1,2 +0,0 @@ -Default/fc3-Dense/BiasAdd-op14 2 1 10 --2.1191406 -8.4140625 -13.625 2.8222656 -11.4453125 30.734375 7.515625 -9.921875 2.5371094 2.9238281 diff --git a/predict/test/main.cc b/predict/test/main.cc deleted file mode 100644 index 7e18f9a10d..0000000000 --- a/predict/test/main.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "test/test_context.h" -#include "common/mslog.h" - -int main(int argc, char **argv) { - // Initialize Google Test. - testing::InitGoogleTest(&argc, argv); - - for (size_t i = 0; i < argc; i++) { - std::string arg = std::string(argv[i]); - if (arg.find("--testRoot") != std::string::npos) { - auto testContext = - std::shared_ptr(new (std::nothrow) mindspore::predict::TestContext()); - if (testContext == nullptr) { - MS_LOGE("new testContext failed"); - return 1; - } - testContext->SetTestRoot(arg.substr(arg.find("--testRoot=") + 11)); - break; - } - } - - int result = RUN_ALL_TESTS(); - - return result; -} diff --git a/predict/test/run_tests.sh b/predict/test/run_tests.sh deleted file mode 100755 index e5a94e70f7..0000000000 --- a/predict/test/run_tests.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -CUR_DIR=$(dirname "$(readlink -f "$0")") -echo "$CUR_DIR" -DOTEST_DIR="$CUR_DIR"/doTest - -cd "$DOTEST_DIR" -./ms-test -if [ $? -ne 0 ]; then - echo "run ./ms-test failed !" - exit 1 -fi diff --git a/predict/test/src/graph_tests.cc b/predict/test/src/graph_tests.cc deleted file mode 100644 index 8fbaf689f3..0000000000 --- a/predict/test/src/graph_tests.cc +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "schema/inner/ms_generated.h" -#include "src/graph.h" -#include "common/file_utils.h" -#include "test/test_context.h" -#include "include/session.h" - -namespace mindspore { -namespace predict { -class GraphTest : public ::testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - - std::string root; -}; - -void InitMsGraphAllTensor(SubGraphDefT *msSubgraph) { - ASSERT_NE(msSubgraph, nullptr); - std::unique_ptr tensor (new (std::nothrow) TensorDefT); - ASSERT_NE(tensor, nullptr); - tensor->refCount = MSConst_WEIGHT_REFCOUNT; - tensor->format = Format_NCHW; - tensor->dataType = DataType_DT_FLOAT; - tensor->dims = {1, 1, 1, 2}; - tensor->offset = -1; - tensor->data.resize(0); - msSubgraph->allTensors.emplace_back(std::move(tensor)); - - std::unique_ptr tensor2(new (std::nothrow) TensorDefT); - ASSERT_NE(tensor2, nullptr); - tensor2->refCount = MSConst_WEIGHT_REFCOUNT; - tensor2->format = Format_NCHW; - tensor2->dataType = DataType_DT_FLOAT; - tensor2->dims = {1, 1, 1, 2}; - tensor2->offset = -1; - tensor2->data.resize(0); - msSubgraph->allTensors.emplace_back(std::move(tensor2)); - - std::unique_ptr tensor3(new (std::nothrow) TensorDefT); - ASSERT_NE(tensor3, nullptr); - tensor3->refCount = 0; - tensor3->format = Format_NCHW; - tensor3->dataType = DataType_DT_FLOAT; - tensor3->dims = {1, 1, 1, 2}; - tensor3->offset = -1; - tensor3->data.resize(0); - msSubgraph->allTensors.emplace_back(std::move(tensor3)); -} - -void FreeOutputs(std::map> *outputs) { - for (auto &output : (*outputs)) { - for (auto &outputTensor : output.second) { - delete outputTensor; - } - } - outputs->clear(); -} - -void FreeInputs(std::vector *inputs) { - for (auto &input : *inputs) { - input->SetData(nullptr); - delete input; - } - inputs->clear(); - return; -} - -TEST_F(GraphTest, CreateFromFileAdd) { - auto msGraph = std::unique_ptr(new (std::nothrow) GraphDefT()); - ASSERT_NE(msGraph, nullptr); - msGraph->name = "test1"; - auto msSubgraph = std::unique_ptr(new (std::nothrow) SubGraphDefT()); - ASSERT_NE(msSubgraph, nullptr); - msSubgraph->name = msGraph->name + "_1"; - msSubgraph->inputIndex = {0, 1}; - msSubgraph->outputIndex = {2}; - - std::unique_ptr node(new (std::nothrow) NodeDefT); - ASSERT_NE(node, nullptr); - std::unique_ptr opDef(new (std::nothrow) OpDefT); - ASSERT_NE(opDef, nullptr); - node->opDef = std::move(opDef); - node->opDef->isLastConv = false; - node->opDef->inputIndex = {static_cast(0), 1}; - node->opDef->outputIndex = {static_cast(2)}; - node->opDef->name = msSubgraph->name + std::to_string(0); - node->fmkType = FmkType_CAFFE; - - auto attr = std::unique_ptr(new (std::nothrow) AddT()); - ASSERT_NE(attr, nullptr); - attr->format = DataFormatType_NCHW; - node->opDef->attr.type = OpT_Add; - node->opDef->attr.value = attr.release(); - - msSubgraph->nodes.emplace_back(std::move(node)); - - InitMsGraphAllTensor(msSubgraph.get()); - msGraph->subgraphs.emplace_back(std::move(msSubgraph)); - - flatbuffers::FlatBufferBuilder builder(1024); - auto offset = mindspore::predict::GraphDef::Pack(builder, msGraph.get()); - builder.Finish(offset); - int size = builder.GetSize(); - void *content = builder.GetBufferPointer(); - - Context ctx; - auto session = CreateSession(static_cast(content), size, ctx); - - std::vector tmpT = {1, 2}; - void *in1Data = tmpT.data(); - std::vector tmpT2 = {3, 5}; - void *in2Data = tmpT2.data(); - - auto inputs = session->GetInput(); - inputs[0]->SetData(in1Data); - inputs[1]->SetData(in2Data); - - auto ret = session->Run(inputs); - EXPECT_EQ(0, ret); - auto outputs = session->GetAllOutput(); - EXPECT_EQ(4, reinterpret_cast(outputs.begin()->second.front()->GetData())[0]); - EXPECT_EQ(7, reinterpret_cast(outputs.begin()->second.front()->GetData())[1]); - - FreeOutputs(&outputs); - FreeInputs(&inputs); -} -} // namespace predict -} // namespace mindspore diff --git a/predict/test/test_context.cc b/predict/test/test_context.cc deleted file mode 100644 index ca8f36d3a8..0000000000 --- a/predict/test/test_context.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "test/test_context.h" - -namespace mindspore { -namespace predict { -std::string TestContext::GetTestRoot() { return this->testRoot; } - -void TestContext::SetTestRoot(const std::string &testRoot) { this->testRoot = testRoot; } -} // namespace predict -} // namespace mindspore diff --git a/predict/test/test_context.h b/predict/test/test_context.h deleted file mode 100644 index 16f439d6e6..0000000000 --- a/predict/test/test_context.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_TEST_TEST_CONTEXT_H_ -#define PREDICT_TEST_TEST_CONTEXT_H_ - -#include - -namespace mindspore { -namespace predict { -class TestContext { - public: - TestContext() = default; - std::string GetTestRoot(); - void SetTestRoot(const std::string &testRoot); - - private: - std::string testRoot = "./"; -}; -} // namespace predict -} // namespace mindspore - -#endif // PREDICT_TEST_TEST_CONTEXT_H_ diff --git a/requirements.txt b/requirements.txt index 4038e63ea7..5fe70c0492 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ setuptools >= 40.8.0 matplotlib >= 3.1.3 # for ut test opencv-python >= 4.1.2.30 # for ut test sklearn >= 0.0 # for st test -pandas >= 1.0.2 # for ut test \ No newline at end of file +pandas >= 1.0.2 # for ut test +bs4 diff --git a/scripts/check_clang_format.sh b/scripts/check_clang_format.sh index 6ed4f7e5de..8cd900dec9 100755 --- a/scripts/check_clang_format.sh +++ b/scripts/check_clang_format.sh @@ -86,12 +86,12 @@ cd "${SCRIPTS_PATH}/.." || exit 1 CHECK_LIST_FILE='__checked_files_list__' if [ "X${mode}" == "Xall" ]; then - find mindspore/ccsrc -type f -name "*" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true + find mindspore/{ccsrc,core} -type f -name "*" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true elif [ "X${mode}" == "Xchanged" ]; then # --diff-filter=ACMRTUXB will ignore deleted files in commit - git diff --diff-filter=ACMRTUXB --name-only | grep "mindspore/ccsrc" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true + git diff --diff-filter=ACMRTUXB --name-only | grep "mindspore/ccsrc\|mindspore/core" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true else # "X${mode}" == "Xlastcommit" - git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "mindspore/ccsrc" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true + git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "mindspore/ccsrc\|mindspore/core" | grep "\.h$\|\.cc$" > "${CHECK_LIST_FILE}" || true fi CHECK_RESULT_FILE=__code_format_check_result__ diff --git a/scripts/format_source_code.sh b/scripts/format_source_code.sh index 39c5df7ce5..d9ef0626c9 100755 --- a/scripts/format_source_code.sh +++ b/scripts/format_source_code.sh @@ -86,11 +86,11 @@ cd "${SCRIPTS_PATH}/.." || exit 1 FMT_FILE_LIST='__format_files_list__' if [[ "X${mode}" == "Xall" ]]; then - find mindspore/ccsrc -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true + find mindspore/{ccsrc,core} -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true elif [[ "X${mode}" == "Xchanged" ]]; then - git diff --name-only | grep "mindspore/ccsrc" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true + git diff --name-only | grep "mindspore/ccsrc\|mindspore/core" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true else # "X${mode}" == "Xlastcommit" - git diff --name-only HEAD~ HEAD | grep "mindspore/ccsrc" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true + git diff --name-only HEAD~ HEAD | grep "mindspore/ccsrc\|mindspore/core" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true fi while read line; do diff --git a/serving/CMakeLists.txt b/serving/CMakeLists.txt index 4529323fe1..8b60168228 100644 --- a/serving/CMakeLists.txt +++ b/serving/CMakeLists.txt @@ -13,19 +13,19 @@ add_library(protobuf::libprotobuf ALIAS protobuf::protobuf) add_executable(protobuf::libprotoc ALIAS protobuf::protoc) set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) -if(CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING) find_program(_PROTOBUF_PROTOC protoc) -else() +else () set(_PROTOBUF_PROTOC $) -endif() +endif () # Find gRPC installation # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. if (EXISTS ${grpc_ROOT}/lib64) set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc") -else() +else () set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc") -endif() +endif () message("serving using grpc_DIR : " ${gPRC_DIR}) find_package(gRPC CONFIG REQUIRED) @@ -34,44 +34,81 @@ message(STATUS "Using gRPC ${gRPC_VERSION}") set(_GRPC_GRPCPP gRPC::grpc++) set(_REFLECTION gRPC::grpc++_reflection) -if(CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING) find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) -else() + find_program(_GRPC_PYTHON_PLUGIN_EXECUTABLE grpc_python_plugin) +else () set(_GRPC_CPP_PLUGIN_EXECUTABLE $) -endif() + set(_GRPC_PYTHON_PLUGIN_EXECUTABLE $) +endif () # Proto file get_filename_component(hw_proto "ms_service.proto" ABSOLUTE) get_filename_component(hw_proto_path "${hw_proto}" PATH) - # Generated sources set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc") set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h") set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc") set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h") +set(hw_py_pb2 "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2.py") +set(hw_py_pb2_grpc "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2_grpc.py") add_custom_command( - OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" "${hw_py_pb2}" "${hw_py_pb2_grpc}" COMMAND ${_PROTOBUF_PROTOC} ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${hw_proto_path}" --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" "${hw_proto}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --python_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_PYTHON_PLUGIN_EXECUTABLE}" + "${hw_proto}" DEPENDS "${hw_proto}") # Include generated *.pb.h files include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/core" - "${PROJECT_SOURCE_DIR}/mindspore/ccsrc") + "${PROJECT_SOURCE_DIR}/mindspore/ccsrc" "${PROJECT_SOURCE_DIR}/mindspore/core") file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "core/*.cc" "core/util/*.cc" "core/version_control/*.cc") list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST}) +option(ENABLE_ACL "enable acl" OFF) + +if (ENABLE_ACL) + if (DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) + else () + set(ASCEND_PATH /usr/local/Ascend) + endif () + set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/) + MESSAGE("acl lib dir " ${ACL_LIB_DIR}) + + include_directories(${ACL_LIB_DIR}/include/) + file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "acl/*.cc") + list(APPEND SERVING_SRC ${ACL_SESSION_SRC_LIST}) +endif () + include_directories(${CMAKE_BINARY_DIR}) add_executable(ms_serving ${SERVING_SRC}) -target_link_libraries(ms_serving inference mindspore_gvar) + target_link_libraries(ms_serving ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF} pthread) if (ENABLE_D) add_compile_definitions(ENABLE_D) target_link_libraries(ms_serving ${RUNTIME_LIB}) -endif() +endif () + +if (ENABLE_ACL) + add_compile_definitions(ENABLE_ACL) + add_compile_definitions(ENABLE_DVPP_INTERFACE) + set(ALC_LIB_SO ${ACL_LIB_DIR}/lib64/libruntime.so ${ACL_LIB_DIR}/lib64/libascendcl.so + ${ACL_LIB_DIR}/lib64/libacl_retr.so ${ACL_LIB_DIR}/lib64/libacl_cblas.so + ${ACL_LIB_DIR}/lib64/libacl_dvpp.so) + target_link_libraries(ms_serving ${ALC_LIB_SO}) + target_link_libraries(ms_serving jpeg_turbo::jpeg) +else () + target_link_libraries(ms_serving inference mindspore_gvar) +endif () diff --git a/serving/README.en.md b/serving/README.en.md deleted file mode 100644 index 830b94537a..0000000000 --- a/serving/README.en.md +++ /dev/null @@ -1,36 +0,0 @@ -# serving - -#### Description -A flexible, high-performance serving system for deep learning models - -#### Software Architecture -Software architecture description - -#### Installation - -1. xxxx -2. xxxx -3. xxxx - -#### Instructions - -1. xxxx -2. xxxx -3. xxxx - -#### Contribution - -1. Fork the repository -2. Create Feat_xxx branch -3. Commit your code -4. Create Pull Request - - -#### Gitee Feature - -1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md -2. Gitee blog [blog.gitee.com](https://blog.gitee.com) -3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore) -4. The most valuable open source project [GVP](https://gitee.com/gvp) -5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help) -6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) diff --git a/serving/README.md b/serving/README.md deleted file mode 100644 index b26b9a6887..0000000000 --- a/serving/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# serving - -#### 介绍 -A flexible, high-performance serving system for deep learning models - -#### 软件架构 -软件架构说明 - - -#### 安装教程 - -1. xxxx -2. xxxx -3. xxxx - -#### 使用说明 - -1. xxxx -2. xxxx -3. xxxx - -#### 参与贡献 - -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request - - -#### 码云特技 - -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. 码云官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解码云上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是码云最有价值开源项目,是码云综合评定出的优秀开源项目 -5. 码云官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. 码云封面人物是一档用来展示码云会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) diff --git a/serving/README_CN.md b/serving/README_CN.md new file mode 100644 index 0000000000..cd38554cb3 --- /dev/null +++ b/serving/README_CN.md @@ -0,0 +1,117 @@ +# 基于MindSpore部署预测服务 + + + +- [基于MindSpore部署预测服务](#基于mindspore部署预测服务) + - [概述](#概述) + - [启动Serving服务](#启动serving服务) + - [应用示例](#应用示例) + - [导出模型](#导出模型) + - [启动Serving推理服务](#启动serving推理服务) + - [客户端示例](#客户端示例) + + +## 概述 + +MindSpore Serving是一个轻量级、高性能的服务模块,旨在帮助MindSpore开发者在生产环境中高效部署在线预测服务。当用户使用MindSpore完成模型训练后,导出MindSpore模型,即可使用MindSpore Serving创建该模型的预测服务。当前Serving仅支持Ascend 910。 + + +## 启动Serving服务 +通过pip安装MindSpore后,Serving可执行程序位于`/{your python path}/lib/python3.7/site-packages/mindspore/ms_serving` 。 +启动Serving服务命令如下 +```bash +ms_serving [--help] [--model_path ] [--model_name ] + [--port ] [--device_id ] +``` +参数含义如下 + +|参数名|属性|功能描述|参数类型|默认值|取值范围| +|---|---|---|---|---|---| +|`--help`|可选|显示启动命令的帮助信息。|-|-|-| +|`--model_path `|必选|指定待加载模型的存放路径。|str|空|-| +|`--model_name `|必选|指定待加载模型的文件名。|str|空|-| +|`--port `|可选|指定Serving对外的端口号。|int|5500|1~65535| +|`--device_id `|可选|指定使用的设备号|int|0|0~7| + + > 执行启动命令前,需将`/{your python path}/lib:/{your python path}/lib/python3.7/site-packages/mindspore/lib`对应的路径加入到环境变量LD_LIBRARY_PATH中 。 + +## 应用示例 +下面以一个简单的网络为例,演示MindSpore Serving如何使用。 + +### 导出模型 +使用[add_model.py](https://gitee.com/mindspore/mindspore/blob/master/serving/example/export_model/add_model.py),构造一个只有Add算子的网络,并导出MindSpore推理部署模型。 + +```python +python add_model.py +``` +执行脚本,生成add.pb文件,该模型的输入为两个shape为[4]的一维Tensor,输出结果是两个输入Tensor之和。 + +### 启动Serving推理服务 +```bash +ms_serving --model_path={current path} --model_name=add.pb +``` +当服务端打印日志`MS Serving Listening on 0.0.0.0:5500`时,表示Serving服务已加载推理模型完毕。 + +### 客户端示例 +执行如下命令,编译一个客户端示例程序,并向Serving服务发送推理请求。 +```bash +cd mindspore/serving/example/cpp_client +mkdir build +cmake .. +make +./ms_client --target=localhost:5500 +``` +显示如下返回值说明Serving服务已正确执行Add网络的推理。 +``` +Compute [1, 2, 3, 4] + [1, 2, 3, 4] +Add result is [2, 4, 6, 8] +client received: RPC OK +``` + > 编译客户端要求用户本地已安装c++版本的[gRPC](https://gRPC.io),并将对应路径加入到环境变量`PATH`中。 + +客户端代码主要包含以下几个部分: + +1. 基于MSService::Stub实现Client,并创建Client实例。 + ``` + class MSClient { + public: + explicit MSClient(std::shared_ptr channel) : stub_(MSService::NewStub(channel)) {} + private: + std::unique_ptr stub_; + };MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); + + MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); + + ``` +2. 根据网络的实际输入构造请求的入参Request、出参Reply和gRPC的客户端Context。 + ``` + PredictRequest request; + PredictReply reply; + ClientContext context; + + //construct tensor + Tensor data; + + //set shape + TensorShape shape; + shape.add_dims(4); + *data.mutable_tensor_shape() = shape; + + //set type + data.set_tensor_type(ms_serving::MS_FLOAT32); + std::vector input_data{1, 2, 3, 4}; + + //set datas + data.set_data(input_data.data(), input_data.size()); + + //add tensor to request + *request.add_data() = data; + *request.add_data() = data; + ``` +3. 调用gRPC接口和已经启动的Serving服务通信,并取回返回值。 + ``` + Status status = stub_->Predict(&context, request, &reply); + ``` + +完整代码参考[ms_client](https://gitee.com/mindspore/mindspore/blob/master/serving/example/cpp_client/ms_client.cc)。 + diff --git a/serving/acl/acl_session.cc b/serving/acl/acl_session.cc new file mode 100644 index 0000000000..92444924eb --- /dev/null +++ b/serving/acl/acl_session.cc @@ -0,0 +1,243 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "serving/acl/acl_session.h" +#include "include/infer_log.h" + +namespace mindspore::inference { + +std::shared_ptr InferSession::CreateSession(const std::string &device, uint32_t device_id) { + try { + auto session = std::make_shared(); + auto ret = session->InitEnv(device, device_id); + if (ret != SUCCESS) { + return nullptr; + } + return session; + } catch (std::exception &e) { + MSI_LOG_ERROR << "Inference CreatSession failed"; + return nullptr; + } +} + +Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) { + Status ret = model_process_.LoadModelFromFile(file_name, model_id); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Load model from file failed, model file " << file_name; + return FAILED; + } + std::string dvpp_config_file; + auto index = file_name.rfind("."); + if (index == std::string::npos) { + dvpp_config_file = file_name; + } else { + dvpp_config_file = file_name.substr(0, index); + } + dvpp_config_file += "_dvpp_config.json"; + std::ifstream fp(dvpp_config_file); + if (!fp.is_open()) { + MSI_LOG_INFO << "Dvpp config file not exist, model will execute with tensors as inputs, dvpp config file " + << dvpp_config_file; + return SUCCESS; + } + fp.close(); + if (dvpp_process_.InitWithJsonConfig(dvpp_config_file) != SUCCESS) { + MSI_LOG_ERROR << "Dvpp config file parse error, dvpp config file " << dvpp_config_file; + return FAILED; + } + execute_with_dvpp_ = true; + MSI_LOG_INFO << "Dvpp config success"; + return SUCCESS; +} + +Status AclSession::UnloadModel(uint32_t /*model_id*/) { + model_process_.UnLoad(); + return SUCCESS; +} + +Status AclSession::ExecuteModel(uint32_t /*model_id*/, const RequestBase &request, + ReplyBase &reply) { // set d context + aclError rt_ret = aclrtSetCurrentContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "set the ascend device context failed"; + return FAILED; + } + return model_process_.Execute(request, reply); +} + +Status AclSession::PreProcess(uint32_t /*model_id*/, const InferImagesBase *images_input, + ImagesDvppOutput &dvpp_output) { + if (images_input == nullptr) { + MSI_LOG_ERROR << "images input is nullptr"; + return FAILED; + } + auto batch_size = images_input->batch_size(); + if (batch_size <= 0) { + MSI_LOG_ERROR << "invalid batch size " << images_input->batch_size(); + return FAILED; + } + std::vector pic_buffer_list; + std::vector pic_size_list; + for (size_t i = 0; i < batch_size; i++) { + const void *pic_buffer = nullptr; + uint32_t pic_size = 0; + if (!images_input->get(i, pic_buffer, pic_size) || pic_buffer == nullptr || pic_size == 0) { + MSI_LOG_ERROR << "Get request " << 0 << "th buffer failed"; + return FAILED; + } + pic_buffer_list.push_back(pic_buffer); + pic_size_list.push_back(pic_size); + } + auto ret = dvpp_process_.Process(pic_buffer_list, pic_size_list, dvpp_output.buffer_device, dvpp_output.buffer_size); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "dvpp process failed"; + return ret; + } + return SUCCESS; +} + +Status AclSession::ExecuteModel(uint32_t model_id, const ImagesRequestBase &images_inputs, // images for preprocess + const RequestBase &request, ReplyBase &reply) { + if (!execute_with_dvpp_) { + MSI_LOG_ERROR << "Unexpected images as inputs, DVPP not config"; + return INFER_STATUS(INVALID_INPUTS) << "Unexpected images as inputs, DVPP not config"; + } + aclError rt_ret = aclrtSetCurrentContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "set the ascend device context failed"; + return FAILED; + } + if (images_inputs.size() != 1) { + MSI_LOG_ERROR << "Only support one input to do DVPP preprocess"; + return INFER_STATUS(INVALID_INPUTS) << "Only support one input to do DVPP preprocess"; + } + if (images_inputs[0] == nullptr) { + MSI_LOG_ERROR << "Get first images input failed"; + return FAILED; + } + if (images_inputs[0]->batch_size() != model_process_.GetBatchSize()) { + MSI_LOG_ERROR << "Input batch size " << images_inputs[0]->batch_size() << " not match Model batch size " + << model_process_.GetBatchSize(); + return INFER_STATUS(INVALID_INPUTS) << "Input batch size " << images_inputs[0]->batch_size() + << " not match Model batch size " << model_process_.GetBatchSize(); + } + if (request.size() != 0) { + MSI_LOG_ERROR << "only support one input, images input size is 1, tensor inputs is not 0 " << request.size(); + return INFER_STATUS(INVALID_INPUTS) << "only support one input, images input size is 1, tensor inputs is not 0 " + << request.size(); + } + ImagesDvppOutput dvpp_output; + Status ret = PreProcess(model_id, images_inputs[0], dvpp_output); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "DVPP preprocess failed"; + return ret; + } + ret = model_process_.Execute(dvpp_output.buffer_device, dvpp_output.buffer_size, reply); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Execute model failed"; + return ret; + } + return SUCCESS; +} + +Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) { + device_type_ = device_type; + device_id_ = device_id; + auto ret = aclInit(nullptr); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Execute aclInit Failed"; + return FAILED; + } + MSI_LOG_INFO << "acl init success"; + + ret = aclrtSetDevice(device_id_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acl open device " << device_id_ << " failed"; + return FAILED; + } + MSI_LOG_INFO << "open device " << device_id_ << " success"; + + ret = aclrtCreateContext(&context_, device_id_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acl create context failed"; + return FAILED; + } + MSI_LOG_INFO << "create context success"; + + ret = aclrtCreateStream(&stream_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acl create stream failed"; + return FAILED; + } + MSI_LOG_INFO << "create stream success"; + + aclrtRunMode run_mode; + ret = aclrtGetRunMode(&run_mode); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acl get run mode failed"; + return FAILED; + } + bool is_device = (run_mode == ACL_DEVICE); + model_process_.SetIsDevice(is_device); + MSI_LOG_INFO << "get run mode success is device input/output " << is_device; + + if (dvpp_process_.InitResource(stream_) != SUCCESS) { + MSI_LOG_ERROR << "dvpp init resource failed"; + return FAILED; + } + MSI_LOG_INFO << "Init acl success, device id " << device_id_; + return SUCCESS; +} + +Status AclSession::FinalizeEnv() { + dvpp_process_.Finalize(); + aclError ret; + if (stream_ != nullptr) { + ret = aclrtDestroyStream(stream_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "destroy stream failed"; + } + stream_ = nullptr; + } + MSI_LOG_INFO << "end to destroy stream"; + if (context_ != nullptr) { + ret = aclrtDestroyContext(context_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "destroy context failed"; + } + context_ = nullptr; + } + MSI_LOG_INFO << "end to destroy context"; + + ret = aclrtResetDevice(device_id_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "reset devie " << device_id_ << " failed"; + } + MSI_LOG_INFO << "end to reset device " << device_id_; + + ret = aclFinalize(); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "finalize acl failed"; + } + MSI_LOG_INFO << "end to finalize acl"; + return SUCCESS; +} + +AclSession::AclSession() = default; +} // namespace mindspore::inference diff --git a/serving/acl/acl_session.h b/serving/acl/acl_session.h new file mode 100644 index 0000000000..c1ae025df2 --- /dev/null +++ b/serving/acl/acl_session.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_SERVING_ACL_SESSION_H +#define MINDSPORE_SERVING_ACL_SESSION_H + +#include +#include +#include +#include +#include +#include + +#include "include/inference.h" +#include "serving/acl/model_process.h" +#include "serving/acl/dvpp_process.h" + +namespace mindspore { +namespace inference { + +class AclSession : public InferSession { + public: + AclSession(); + + Status InitEnv(const std::string &device_type, uint32_t device_id) override; + Status FinalizeEnv() override; + Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override; + Status UnloadModel(uint32_t model_id) override; + Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override; + Status ExecuteModel(uint32_t model_id, const ImagesRequestBase &images_inputs, // images for preprocess + const RequestBase &request, ReplyBase &reply) override; + + private: + std::string device_type_; + int32_t device_id_; + aclrtStream stream_ = nullptr; + aclrtContext context_ = nullptr; + ModelProcess model_process_; + bool execute_with_dvpp_ = false; + DvppProcess dvpp_process_; + + Status PreProcess(uint32_t model_id, const InferImagesBase *images_input, ImagesDvppOutput &dvpp_output); +}; +} // namespace inference +} // namespace mindspore +#endif // MINDSPORE_SERVING_ACL_SESSION_H diff --git a/serving/acl/dvpp_process.cc b/serving/acl/dvpp_process.cc new file mode 100644 index 0000000000..1fedaf6406 --- /dev/null +++ b/serving/acl/dvpp_process.cc @@ -0,0 +1,1139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "serving/acl/dvpp_process.h" +#include +#include +#include +#include +#include "include/infer_log.h" + +namespace mindspore { +namespace inference { + +DvppProcess::DvppProcess() {} + +DvppProcess::~DvppProcess() {} + +static uint32_t ToEven(uint32_t num) { return (num + 1) / 2 * 2; } + +static uint32_t ToOdd(uint32_t num) { + if (num == 0) { + return 1; + } + return (num + 1) / 2 * 2 - 1; +} + +class DvppJsonConfigParser { + public: + DvppJsonConfigParser() = default; + ~DvppJsonConfigParser() = default; + + Status InitWithJsonConfig(const std::string &json_config); + DvppDecodePara GetDecodePara() const { return decode_para_; } + DvppResizePara GetResizePara() const { return resize_para_; } + DvppCropPara GetCropPara() const { return crop_para_; } + DvppCropAndPastePara GetCropAndPastePara() const { return crop_and_paste_para_; } + bool HasResizeConfig() const { return resize_flag_; } + bool HasCropConfig() const { return crop_flag_; } + bool HasCropAndPasteConfig() const { return crop_and_paste_flag_; } + + private: + DvppDecodePara decode_para_; + DvppResizePara resize_para_; + DvppCropPara crop_para_; + DvppCropAndPastePara crop_and_paste_para_; + bool resize_flag_ = false; + bool crop_flag_ = false; + bool crop_and_paste_flag_ = false; + + Status GetStringValue(const nlohmann::json &json_item, const std::string &key, std::string &val); + Status GetIntValue(const nlohmann::json &json_item, const std::string &key, uint32_t &val); + Status ParseInputPara(const nlohmann::json &preprocess_item); + Status ParseDecodePara(const nlohmann::json &preprocess_item); + Status ParseResizePara(const nlohmann::json &json_item); + Status ParseCropPara(const nlohmann::json &json_item); + Status ParseCropAndPastePara(const nlohmann::json &json_item); + Status InitWithJsonConfigImp(const std::string &json_config); +}; + +Status DvppProcess::InitResource(aclrtStream stream) { + stream_ = stream; + aclError acl_ret; + dvpp_channel_desc_ = acldvppCreateChannelDesc(); + if (dvpp_channel_desc_ == nullptr) { + MSI_LOG_ERROR << "acldvppCreateChannelDesc failed"; + return FAILED; + } + acl_ret = acldvppCreateChannel(dvpp_channel_desc_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppCreateChannel failed, acl return " << acl_ret; + return FAILED; + } + MSI_LOG_INFO << "End init dvpp process resource"; + return SUCCESS; +} + +void DvppProcess::DestroyResource() { + if (dvpp_channel_desc_ != nullptr) { + auto acl_ret = acldvppDestroyChannel(dvpp_channel_desc_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppDestroyChannel failed, acl return " << acl_ret; + } + acl_ret = acldvppDestroyChannelDesc(dvpp_channel_desc_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppDestroyChannelDesc failed, acl return " << acl_ret; + } + dvpp_channel_desc_ = nullptr; + } +} + +void DvppProcess::Finalize() { + DestroyDecodeDesc(); + DestroyVpcOutputDesc(); + DestroyResource(); + if (resize_config_ != nullptr) { + acldvppDestroyResizeConfig(resize_config_); + resize_config_ = nullptr; + } + if (crop_area_ != nullptr) { + acldvppDestroyRoiConfig(crop_area_); + crop_area_ = nullptr; + } + if (paste_area_ != nullptr) { + acldvppDestroyRoiConfig(paste_area_); + paste_area_ = nullptr; + } + if (input_pic_dev_buffer_ != nullptr) { + acldvppFree(input_pic_dev_buffer_); + } + input_pic_buffer_size_ = 0; + MSI_LOG_INFO << "End dvpp process finalize"; +} + +Status DvppProcess::InitJpegDecodePara(const DvppDecodePara &decode_para) { + decode_para_ = decode_para; + MSI_LOG_INFO << "Init decode para, pixel_format " << decode_para_.pixel_format; + return SUCCESS; +} + +Status DvppProcess::InitResizePara(const DvppResizePara &resize_para) { + resize_para_ = resize_para; + MSI_LOG_INFO << "Init resize para, " + << "output_width " << resize_para_.output_width << ", output_height " << resize_para_.output_height; + to_resize_flag_ = true; + to_crop_flag_ = false; + to_crop_and_paste_flag_ = false; + Status ret = InitResizeOutputDesc(); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "InitResizeOutputDesc failed"; + } + return ret; +} + +Status DvppProcess::InitCommonCropPara(DvppCropInfo &crop_info, uint32_t output_width, uint32_t output_height) { + if (crop_info.crop_type == kDvppCropTypeOffset) { + if (CheckAndAdjustRoiArea(crop_info.crop_area) != SUCCESS) { + MSI_LOG_ERROR << "Check and adjust crop area failed"; + return FAILED; + } + MSI_LOG_INFO << "Init common crop para, crop type offset " + << ", left " << crop_info.crop_area.left << ", right " << crop_info.crop_area.right << ", top " + << crop_info.crop_area.top << ", bottom " << crop_info.crop_area.bottom << ", output_width " + << output_width << ", output_height " << output_height; + } else { + crop_info.crop_width = ToEven(crop_info.crop_width); + crop_info.crop_height = ToEven(crop_info.crop_height); + if (CheckRoiAreaWidthHeight(crop_info.crop_width, crop_info.crop_height) != SUCCESS) { + MSI_LOG_ERROR << "Check crop area width and height failed, actually width " << crop_info.crop_width << " height " + << crop_info.crop_height; + return FAILED; + } + MSI_LOG_INFO << "Init common crop para, crop type centre " + << ", crop_width " << crop_info.crop_width << ", crop_height " << crop_info.crop_height + << ", output_width " << output_width << ", output_height " << output_height; + } + return SUCCESS; +} + +Status DvppProcess::InitCropPara(const DvppCropPara &crop_para) { + crop_para_ = crop_para; + if (InitCommonCropPara(crop_para_.crop_info, crop_para_.output_width, crop_para_.output_height) != SUCCESS) { + MSI_LOG_ERROR << "Init common crop para failed in InitCropPara"; + return FAILED; + } + to_crop_flag_ = true; + to_resize_flag_ = false; + to_crop_and_paste_flag_ = false; + Status ret = InitCropOutputDesc(); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "InitCropOutputDesc failed"; + } + return ret; +} + +Status DvppProcess::InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para) { + crop_and_paste_para_ = crop_and_paste_para; + if (InitCommonCropPara(crop_and_paste_para_.crop_info, crop_and_paste_para_.output_width, + crop_and_paste_para_.output_height) != SUCCESS) { + MSI_LOG_ERROR << "Init common crop para failed in InitCropAndPastePara"; + return FAILED; + } + auto &paste_area = crop_and_paste_para_.paste_area; + if (CheckAndAdjustRoiArea(paste_area) != SUCCESS) { + MSI_LOG_ERROR << "Check and adjust paste area failed"; + return FAILED; + } + MSI_LOG_INFO << "Init crop and paste para, paste info: " + << ", left " << paste_area.left << ", right " << paste_area.right << ", top " << paste_area.top + << ", bottom " << paste_area.bottom; + + to_crop_and_paste_flag_ = true; + to_crop_flag_ = false; + to_resize_flag_ = false; + Status ret = InitCropAndPasteOutputDesc(); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "InitCropAndPasteOutputDesc failed"; + } + return ret; +} + +Status DvppProcess::InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size) { + aclError acl_ret; + if (pic_buffer_size != input_pic_buffer_size_) { + acldvppFree(input_pic_dev_buffer_); + input_pic_buffer_size_ = 0; + acl_ret = acldvppMalloc(&input_pic_dev_buffer_, pic_buffer_size); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppMalloc input picture buffer on device failed, buffer size " << pic_buffer_size; + return FAILED; + } + input_pic_buffer_size_ = pic_buffer_size; + } + acl_ret = + aclrtMemcpy(input_pic_dev_buffer_, input_pic_buffer_size_, pic_buffer, pic_buffer_size, ACL_MEMCPY_HOST_TO_DEVICE); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "aclrtMemcpy input picture buffer to device, buffer size " << pic_buffer_size; + return FAILED; + } + return SUCCESS; +} + +static void JpegErrorExitCustom(j_common_ptr cinfo) { + char jpeg_last_error_msg[JMSG_LENGTH_MAX]; + if (cinfo != nullptr && cinfo->err != nullptr && cinfo->err->format_message != nullptr) { + (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg); + } + throw std::runtime_error(jpeg_last_error_msg); +} + +Status DvppProcess::GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t &image_width, + uint32_t &image_height) { + struct jpeg_decompress_struct jpeg_header; + struct jpeg_error_mgr jpeg_error; + jpeg_header.err = jpeg_std_error(&jpeg_error); + jpeg_error.error_exit = JpegErrorExitCustom; + try { + jpeg_create_decompress(&jpeg_header); + jpeg_mem_src(&jpeg_header, reinterpret_cast(pic_buffer), pic_buffer_size); + (void)jpeg_read_header(&jpeg_header, TRUE); + } catch (std::runtime_error &e) { + jpeg_destroy_decompress(&jpeg_header); + MSI_LOG_ERROR << "jpeg images read failed, " << e.what(); + return INFER_STATUS(INVALID_INPUTS) << "jpeg images decode failed"; + } + image_width = jpeg_header.image_width; + image_height = jpeg_header.image_height; + + if (jpeg_header.jpeg_color_space != JCS_YCbCr) { + MSI_LOG_ERROR << "Expect color space YUV(YCbCr), current " << jpeg_header.jpeg_color_space; + jpeg_destroy_decompress(&jpeg_header); + return INFER_STATUS(INVALID_INPUTS) << "Expect color space YUV(YCbCr), current " << jpeg_header.jpeg_color_space; + } + if (jpeg_header.dc_huff_tbl_ptrs[0] == nullptr) { + MSI_LOG_ERROR << "Only support Huffman code"; + jpeg_destroy_decompress(&jpeg_header); + return INFER_STATUS(INVALID_INPUTS) << "Only support Huffman code"; + } + jpeg_destroy_decompress(&jpeg_header); + + const uint32_t min_width = 32; + const uint32_t max_width = 8192; + const uint32_t min_height = 32; + const uint32_t max_height = 8192; + if (image_width < min_width || image_width > max_width) { + MSI_LOG_ERROR << "expect image width [" << min_width << ", " << max_width << "], the real image width is " + << image_width; + return INFER_STATUS(INVALID_INPUTS) << "expect image width [" << min_width << ", " << max_width + << "], the real image width is " << image_width; + } + if (image_height < min_height || image_height > max_height) { + MSI_LOG_ERROR << "expect image height [" << min_height << ", " << max_height << "], the real image height is " + << image_height; + return INFER_STATUS(INVALID_INPUTS) << "expect image height [" << min_height << ", " << max_height + << "], the real image height is " << image_height; + } + return SUCCESS; +} + +Status DvppProcess::Process(const void *pic_buffer, size_t pic_buffer_size, void *&output_device_buffer, + size_t &output_size) { + if (dvpp_channel_desc_ == nullptr) { + MSI_LOG_ERROR << "Process failed, dvpp not inited"; + return FAILED; + } + uint32_t image_width = 0; + uint32_t image_height = 0; + Status ret = GetJpegWidthHeight(pic_buffer, pic_buffer_size, image_width, image_height); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Get jpeg image height and width failed"; + return ret; + } + MSI_LOG_INFO << "Get jpeg width " << image_width << ", height " << image_height; + ret = InitDecodeOutputDesc(image_width, image_height); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "InitDecodeOutputDesc failed"; + return FAILED; + } + ret = UpdateCropArea(image_width, image_height); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Update crop area failed"; + return ret; + } + ret = CheckResizeImageInfo(image_width, image_height); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Check resize para failed"; + return ret; + } + if (InputInputBuffer(pic_buffer, pic_buffer_size) != SUCCESS) { + MSI_LOG_ERROR << "InputInputBuffer failed"; + return FAILED; + } + if (ProcessDecode() != SUCCESS) { + MSI_LOG_ERROR << "Process Decode failed"; + return INFER_STATUS(INVALID_INPUTS) << "Decode image failed"; + } + MSI_LOG_INFO << "Process Decode success"; + if (to_resize_flag_) { + if (ProcessResize() != SUCCESS) { + MSI_LOG_ERROR << "Process Resize failed"; + return INFER_STATUS(FAILED) << "Resize image failed"; + } + MSI_LOG_INFO << "Process Resize success"; + } else if (to_crop_flag_) { + if (ProcessCrop() != SUCCESS) { + MSI_LOG_ERROR << "Process Crop failed"; + return INFER_STATUS(FAILED) << "Crop image failed"; + } + MSI_LOG_INFO << "Process Crop success"; + } else if (to_crop_and_paste_flag_) { + if (ProcessCropAndPaste() != SUCCESS) { + MSI_LOG_ERROR << "Process Crop And Paste failed"; + return INFER_STATUS(FAILED) << "Crop And Paste image failed"; + } + MSI_LOG_INFO << "Process Crop And Paste success"; + } + if (vpc_output_buffer_dev_ == nullptr) { + output_device_buffer = decode_output_buffer_dev_; + output_size = decode_output_buffer_size_; + } else { + output_device_buffer = vpc_output_buffer_dev_; + output_size = vpc_output_buffer_size_; + } + MSI_LOG_INFO << "Process dvpp success"; + return SUCCESS; +} + +Status DvppProcess::Process(const std::vector &pic_buffer_list, + const std::vector &pic_buffer_size_list, void *&output_device_buffer, + size_t &output_size) { + auto batch_size = pic_buffer_list.size(); + if (batch_size == 0 || batch_size != pic_buffer_size_list.size()) { + MSI_LOG_ERROR << "invalid batch size " << batch_size << ", pic size count" << pic_buffer_size_list.size(); + return FAILED; + } + MSI_LOG_INFO << "Begin dvpp process, batch size " << batch_size; + if (batch_size == 1) { + return Process(pic_buffer_list[0], pic_buffer_size_list[0], output_device_buffer, output_size); + } + size_t total_buffer_size = vpc_output_buffer_size_ * batch_size; + if (batch_size_ != batch_size) { + if (batch_vpc_output_buffer_dev_ != nullptr) { + acldvppFree(batch_vpc_output_buffer_dev_); + batch_vpc_output_buffer_dev_ = nullptr; + } + batch_size_ = batch_size; + auto acl_rt = acldvppMalloc(&batch_vpc_output_buffer_dev_, total_buffer_size); + if (acl_rt != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppMalloc failed, buffer size " << total_buffer_size; + return FAILED; + } + } + for (size_t i = 0; i < batch_size; i++) { + const void *pic_buffer = pic_buffer_list[i]; + uint32_t pic_size = pic_buffer_size_list[i]; + if (pic_buffer == nullptr || pic_size == 0) { + MSI_LOG_ERROR << "Get " << 0 << "th images failed"; + return FAILED; + } + void *output_dev_buffer_tmp = nullptr; + size_t output_buffer_size_tmp = 0; + Status ret = Process(pic_buffer, pic_size, output_dev_buffer_tmp, output_buffer_size_tmp); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "dvpp process failed"; + return ret; + } + aclrtMemcpy(static_cast(batch_vpc_output_buffer_dev_) + vpc_output_buffer_size_ * i, + total_buffer_size - vpc_output_buffer_size_ * i, output_dev_buffer_tmp, vpc_output_buffer_size_, + ACL_MEMCPY_DEVICE_TO_DEVICE); + + MSI_LOG_INFO << "Dvpp process " << i << " th images success, input pic size " << pic_size << " output buffer size " + << output_buffer_size_tmp; + } + output_device_buffer = batch_vpc_output_buffer_dev_; + output_size = total_buffer_size; + MSI_LOG_INFO << "End dvpp process, batch size " << batch_size << ", output size " << output_size; + return SUCCESS; +} + +uint32_t DvppProcess::AlignmentHelper(uint32_t org_size, uint32_t alignment) const { + if (alignment == 0) { + return 0; + } + return (org_size + alignment - 1) / alignment * alignment; +} + +uint32_t DvppProcess::GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, + acldvppPixelFormat pixel_format) const { + if (stride_height == 0 || stride_width == 0) { + MSI_LOG_ERROR << "invalid stride height or width, stride_width " << stride_width << " stride_height " + << stride_height; + return 0; + } + if (UINT32_MAX / 3 < stride_height || UINT32_MAX / (3 * stride_height) < stride_width) { + MSI_LOG_ERROR << "invalid stride height or width, stride_width " << stride_width << " stride_height " + << stride_height; + return 0; + } + if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_420 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + return stride_width * stride_height * 3 / 2; // 420 + } else if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_422 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_422) { + return stride_width * stride_height * 2; // 422 + } else if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_444 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_444) { + return stride_width * stride_height * 3; // 444 + } + MSI_LOG_ERROR << "Not support pixel format " << pixel_format; + return 0; +} + +Status DvppProcess::GetPicDescStride(uint32_t width, uint32_t height, uint32_t &stride_width, uint32_t &stride_height) { + const uint32_t width_alignment = 16; + const uint32_t height_alignment = 2; + const uint32_t stride_width_minimum = 32; + const uint32_t stride_width_maximum = 4096; + const uint32_t stride_height_minimum = 6; + const uint32_t stride_height_maximum = 4096; + + stride_width = AlignmentHelper(width, width_alignment); + stride_height = AlignmentHelper(height, height_alignment); + if (stride_width == 0 || stride_height == 0) { + MSI_LOG_ERROR << "Init VPC output desc failed, get stride width or height failed"; + return FAILED; + } + if (stride_width < stride_width_minimum || stride_width > stride_width_maximum) { + MSI_LOG_ERROR << "Expect stride width [" << stride_width_minimum << ", " << stride_width_maximum + << "], current stride width " << stride_width << " given width " << width; + return FAILED; + } + if (stride_height < stride_height_minimum || stride_height > stride_height_maximum) { + MSI_LOG_ERROR << "Expect stride height [" << stride_height_minimum << ", " << stride_height_maximum + << "], current stride height " << stride_height << " given height " << height; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t &stride_width, + uint32_t &stride_height) { + const uint32_t width_alignment = 128; + const uint32_t height_alignment = 16; + const uint32_t width_minimum = 32; + const uint32_t width_maximum = 4096; // decode support 8192, dvpp(resize/crop/crop&paste) support 4096 + const uint32_t height_minimum = 32; + const uint32_t height_maximum = 4096; // decode support 8192, dvpp(resize/crop/crop&paste) support 4096 + if (width < width_minimum || width > width_maximum) { + MSI_LOG_ERROR << "Expect width [" << width_minimum << ", " << width_maximum << "], current width " << width; + return INFER_STATUS(INVALID_INPUTS) << "Expect width [" << width_minimum << ", " << width_maximum + << "], current width " << width; + } + if (height < height_minimum || height > height_maximum) { + MSI_LOG_ERROR << "Expect height [" << height_minimum << ", " << height_maximum << "], current height " << height; + return INFER_STATUS(INVALID_INPUTS) << "Expect height [" << height_minimum << ", " << height_maximum + << "], current height " << height; + } + stride_width = AlignmentHelper(width, width_alignment); + stride_height = AlignmentHelper(height, height_alignment); + if (stride_width == 0 || stride_height == 0) { + MSI_LOG_ERROR << "Init decode output desc failed, get stride width or height failed"; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, acldvppPixelFormat pixel_format) { + DestroyVpcOutputDesc(); + uint32_t vpc_stride_width = 0; + uint32_t vpc_stride_height = 0; + if (GetPicDescStride(output_width, output_height, vpc_stride_width, vpc_stride_height) != SUCCESS) { + MSI_LOG_ERROR << "Init VPC output desc failed, get VPC output stride width/height failed"; + return FAILED; + } + vpc_output_buffer_size_ = GetImageBufferSize(vpc_stride_width, vpc_stride_height, pixel_format); + if (vpc_output_buffer_size_ == 0) { + MSI_LOG_ERROR << "Init VPC output desc failed, get image buffer size failed"; + return FAILED; + } + auto acl_ret = acldvppMalloc(&vpc_output_buffer_dev_, vpc_output_buffer_size_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Init VPC output desc failed, malloc dvpp memory failed"; + return FAILED; + } + vpc_output_desc_ = acldvppCreatePicDesc(); + if (vpc_output_desc_ == nullptr) { + MSI_LOG_ERROR << "Init VPC output desc failed, create pic desc failed"; + return FAILED; + } + acldvppSetPicDescData(vpc_output_desc_, vpc_output_buffer_dev_); + acldvppSetPicDescSize(vpc_output_desc_, vpc_output_buffer_size_); + acldvppSetPicDescFormat(vpc_output_desc_, pixel_format); + acldvppSetPicDescWidth(vpc_output_desc_, output_width); + acldvppSetPicDescHeight(vpc_output_desc_, output_height); + acldvppSetPicDescWidthStride(vpc_output_desc_, vpc_stride_width); + acldvppSetPicDescHeightStride(vpc_output_desc_, vpc_stride_height); + MSI_LOG_INFO << "Init VPC output desc success"; + return SUCCESS; +} + +void DvppProcess::DestroyVpcOutputDesc() { + if (vpc_output_desc_ != nullptr) { + acldvppDestroyPicDesc(vpc_output_desc_); + vpc_output_desc_ = nullptr; + } + if (vpc_output_buffer_dev_ != nullptr) { + acldvppFree(vpc_output_buffer_dev_); + vpc_output_buffer_dev_ = nullptr; + } + if (batch_vpc_output_buffer_dev_ != nullptr) { + acldvppFree(batch_vpc_output_buffer_dev_); + batch_vpc_output_buffer_dev_ = nullptr; + } + vpc_output_buffer_size_ = 0; + MSI_LOG_INFO << "End destroy vpc desc"; +} + +Status DvppProcess::InitDecodeOutputDesc(uint32_t image_width, uint32_t image_height) { + if (decode_output_buffer_dev_ != nullptr && image_width == pic_width_ && image_height == pic_height_) { + return SUCCESS; + } + DestroyDecodeDesc(); + + pic_width_ = image_width; + pic_height_ = image_height; + + uint32_t stride_width = 0; + uint32_t stride_height = 0; + Status ret = GetPicDescStrideDecode(pic_width_, pic_height_, stride_width, stride_height); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Init VPC output desc failed, get VPC output stride width/height failed"; + return ret; + } + + decode_output_buffer_size_ = GetImageBufferSize(stride_width, stride_height, decode_para_.pixel_format); + if (decode_output_buffer_size_ == 0) { + MSI_LOG_ERROR << "Init decode output desc failed, get image buffer size failed"; + return FAILED; + } + auto acl_ret = acldvppMalloc(&decode_output_buffer_dev_, decode_output_buffer_size_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Init decode output desc failed, malloc dvpp memory failed"; + return FAILED; + } + decode_output_desc_ = acldvppCreatePicDesc(); + if (decode_output_desc_ == nullptr) { + MSI_LOG_ERROR << "Init decode output desc failed, create pic desc failed"; + return FAILED; + } + acldvppSetPicDescData(decode_output_desc_, decode_output_buffer_dev_); + acldvppSetPicDescSize(decode_output_desc_, decode_output_buffer_size_); + acldvppSetPicDescFormat(decode_output_desc_, decode_para_.pixel_format); + acldvppSetPicDescWidth(decode_output_desc_, pic_width_); + acldvppSetPicDescHeight(decode_output_desc_, pic_height_); + acldvppSetPicDescWidthStride(decode_output_desc_, stride_width); + acldvppSetPicDescHeightStride(decode_output_desc_, stride_height); + MSI_LOG_INFO << "Init decode output desc success"; + return SUCCESS; +} + +Status DvppProcess::CheckRoiAreaWidthHeight(uint32_t width, uint32_t height) { + const uint32_t min_crop_width = 10; + const uint32_t max_crop_width = 4096; + const uint32_t min_crop_height = 6; + const uint32_t max_crop_height = 4096; + + if (width < min_crop_width || width > max_crop_width) { + MSI_LOG_ERROR << "Expect roi area width in [" << min_crop_width << ", " << max_crop_width << "], actually " + << width; + return FAILED; + } + if (height < min_crop_height || height > max_crop_height) { + MSI_LOG_ERROR << "Expect roi area height in [" << min_crop_height << ", " << max_crop_height << "], actually " + << height; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::CheckAndAdjustRoiArea(DvppRoiArea &area) { + if (area.right < area.left) { + MSI_LOG_ERROR << "check roi area failed, left " << area.left << ", right " << area.right; + return FAILED; + } + if (area.bottom < area.top) { + MSI_LOG_ERROR << "check roi area failed, top " << area.top << ", bottom " << area.bottom; + return FAILED; + } + + area.left = ToEven(area.left); + area.top = ToEven(area.top); + area.right = ToOdd(area.right); + area.bottom = ToOdd(area.bottom); + + auto width = area.right - area.left + 1; + auto height = area.bottom - area.top + 1; + if (CheckRoiAreaWidthHeight(width, height) != SUCCESS) { + MSI_LOG_ERROR << "Check roi area width and height failed," + << " actually width " << width << " left " << area.left << ", right " << area.right + << " actually height " << height << " top " << area.top << ", bottom " << area.bottom; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::UpdateCropArea(uint32_t image_width, uint32_t image_height) { + DvppCropInfo *crop_info = nullptr; + if (to_crop_flag_) { + crop_info = &crop_para_.crop_info; + } else if (to_crop_and_paste_flag_) { + crop_info = &crop_and_paste_para_.crop_info; + } else { + return SUCCESS; + } + if (crop_info->crop_type != kDvppCropTypeCentre) { + return SUCCESS; + } + if (image_width < crop_info->crop_width) { + MSI_LOG_ERROR << "Image width " << image_width << "smaller than crop width " << crop_info->crop_width; + return INFER_STATUS(INVALID_INPUTS) << "Image width " << image_width << "smaller than crop width " + << crop_info->crop_width; + } + if (image_height < crop_info->crop_height) { + MSI_LOG_ERROR << "Image height " << image_height << "smaller than crop height " << crop_info->crop_height; + return INFER_STATUS(INVALID_INPUTS) << "Image width " << image_width << "smaller than crop width " + << crop_info->crop_width; + } + uint32_t left = ToEven((image_width - crop_info->crop_width) / 2); + uint32_t top = ToEven((image_height - crop_info->crop_height) / 2); + uint32_t right = ToOdd(left + crop_info->crop_width); + uint32_t bottom = ToOdd(top + crop_info->crop_height); + + auto acl_ret = acldvppSetRoiConfig(crop_area_, left, right, top, bottom); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Update Crop Area failed"; + return FAILED; + } + MSI_LOG_INFO << "Update crop area, crop type centre, crop info: " + << ", left " << left << ", right " << right << ", top " << top << ", bottom " << bottom; + return SUCCESS; +} + +Status DvppProcess::CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const { + if (!to_resize_flag_) { + return SUCCESS; + } + // resize ratio required [1/32, 16] + auto check_resize_ratio = [](uint32_t before_resize, uint32_t after_resize) { + if (before_resize == 0 || after_resize == 0) { + return false; + } + if (before_resize / after_resize > 32) { + return false; + } + if (after_resize / before_resize > 16) { + return false; + } + return true; + }; + if (!check_resize_ratio(image_width, resize_para_.output_width)) { + MSI_LOG_ERROR << "Resize ratio required [1/32, 16], current width resize from " << image_width << " to " + << resize_para_.output_width; + return INFER_STATUS(INVALID_INPUTS) << "Resize ratio required [1/32, 16], current width resize from " << image_width + << " to " << resize_para_.output_width; + } + if (!check_resize_ratio(image_height, resize_para_.output_height)) { + MSI_LOG_ERROR << "Resize ratio required [1/32, 16], current height resize from " << image_height << " to " + << resize_para_.output_height; + return INFER_STATUS(INVALID_INPUTS) << "Resize ratio required [1/32, 16], current height resize from " + << image_height << " to " << resize_para_.output_height; + } + return SUCCESS; +} + +void DvppProcess::DestroyDecodeDesc() { + if (decode_output_desc_ != nullptr) { + acldvppDestroyPicDesc(decode_output_desc_); + decode_output_desc_ = nullptr; + } + if (decode_output_buffer_dev_ != nullptr) { + acldvppFree(decode_output_buffer_dev_); + decode_output_buffer_dev_ = nullptr; + } + decode_output_buffer_size_ = 0; + MSI_LOG_INFO << "End destroy decode desc"; +} + +Status DvppProcess::InitResizeOutputDesc() { + if (InitVpcOutputDesc(resize_para_.output_width, resize_para_.output_height, decode_para_.pixel_format) != SUCCESS) { + MSI_LOG_ERROR << "Init VPC output desc failed"; + return FAILED; + } + if (resize_config_ == nullptr) { + resize_config_ = acldvppCreateResizeConfig(); + if (resize_config_ == nullptr) { + MSI_LOG_ERROR << "Create Resize config failed"; + return FAILED; + } + } + return SUCCESS; +} + +Status DvppProcess::InitRoiAreaConfig(acldvppRoiConfig *&roi_area, const DvppRoiArea &init_para) { + if (roi_area == nullptr) { + roi_area = acldvppCreateRoiConfig(init_para.left, init_para.right, init_para.top, init_para.bottom); + if (roi_area == nullptr) { + MSI_LOG_ERROR << "Create Roi config failed"; + return FAILED; + } + } else { + auto acl_ret = acldvppSetRoiConfig(roi_area, init_para.left, init_para.right, init_para.top, init_para.bottom); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Set Roi config failed"; + return FAILED; + } + } + return SUCCESS; +} + +Status DvppProcess::InitCropOutputDesc() { + if (InitVpcOutputDesc(crop_para_.output_width, crop_para_.output_height, decode_para_.pixel_format) != SUCCESS) { + MSI_LOG_ERROR << "Init VPC output desc failed"; + return FAILED; + } + if (InitRoiAreaConfig(crop_area_, crop_para_.crop_info.crop_area) != SUCCESS) { + MSI_LOG_ERROR << "Init crop area failed"; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::InitCropAndPasteOutputDesc() { + if (InitVpcOutputDesc(crop_and_paste_para_.output_width, crop_and_paste_para_.output_height, + decode_para_.pixel_format) != SUCCESS) { + MSI_LOG_ERROR << "Init VPC output desc failed"; + return FAILED; + } + if (InitRoiAreaConfig(crop_area_, crop_and_paste_para_.crop_info.crop_area) != SUCCESS) { + MSI_LOG_ERROR << "Init crop area failed"; + return FAILED; + } + if (InitRoiAreaConfig(paste_area_, crop_and_paste_para_.paste_area) != SUCCESS) { + MSI_LOG_ERROR << "Init paste area failed"; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessDecode() { + aclError acl_ret; + acl_ret = acldvppJpegDecodeAsync(dvpp_channel_desc_, input_pic_dev_buffer_, input_pic_buffer_size_, + decode_output_desc_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppJpegDecodeAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessResize() { + aclError acl_ret; + acl_ret = acldvppVpcResizeAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, resize_config_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppVpcResizeAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessCrop() { + aclError acl_ret; + acl_ret = acldvppVpcCropAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, crop_area_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppVpcCropAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessCropAndPaste() { + aclError acl_ret; + acl_ret = acldvppVpcCropAndPasteAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, crop_area_, + paste_area_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "acldvppVpcCropAndPasteAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppJsonConfigParser::GetStringValue(const nlohmann::json &json_item, const std::string &key, std::string &val) { + auto it = json_item.find(key); + if (it == json_item.end()) { + MSI_LOG_ERROR << "get string item " << key << " failed"; + return FAILED; + } + if (!it->is_string()) { + MSI_LOG_ERROR << "item " << key << " value is not string type"; + return FAILED; + } + val = it->get(); + return SUCCESS; +} + +Status DvppJsonConfigParser::GetIntValue(const nlohmann::json &json_item, const std::string &key, uint32_t &val) { + auto it = json_item.find(key); + if (it == json_item.end()) { + MSI_LOG_ERROR << "get string item " << key << " failed"; + return FAILED; + } + if (!it->is_number_integer()) { + MSI_LOG_ERROR << "item " << key << " value is not integer type"; + return FAILED; + } + val = it->get(); + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseInputPara(const nlohmann::json &preprocess_item) { + auto input = preprocess_item.find("input"); + if (input == preprocess_item.end()) { + MSI_LOG_ERROR << "get input failed"; + return FAILED; + } + if (!input->is_object()) { + MSI_LOG_ERROR << "input is not object"; + return FAILED; + } + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseDecodePara(const nlohmann::json &preprocess_item) { + auto decode_para = preprocess_item.find("decode_para"); + if (decode_para == preprocess_item.end()) { + MSI_LOG_ERROR << "get input failed"; + return FAILED; + } + if (!decode_para->is_object()) { + MSI_LOG_ERROR << "input is not object"; + return FAILED; + } + const std::unordered_map pixel_format_map = { + {"YUV420SP", PIXEL_FORMAT_YUV_SEMIPLANAR_420}, {"YVU420SP", PIXEL_FORMAT_YVU_SEMIPLANAR_420}, + {"YUV422SP", PIXEL_FORMAT_YUV_SEMIPLANAR_422}, {"YVU422SP", PIXEL_FORMAT_YVU_SEMIPLANAR_422}, + {"YUV444SP", PIXEL_FORMAT_YUV_SEMIPLANAR_444}, {"YVU444SP", PIXEL_FORMAT_YVU_SEMIPLANAR_444}, + }; + std::string pixel_format; + if (GetStringValue(*decode_para, "out_pixel_format", pixel_format) != SUCCESS) { + MSI_LOG_ERROR << "get op out_pixel_format failed"; + return FAILED; + } + auto format = pixel_format_map.find(pixel_format); + if (format == pixel_format_map.end()) { + MSI_LOG_ERROR << "unsupported out_pixel_format " << pixel_format; + return FAILED; + } + decode_para_.pixel_format = format->second; + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseResizePara(const nlohmann::json &json_item) { + if (GetIntValue(json_item, "out_width", resize_para_.output_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "out_height", resize_para_.output_height) != SUCCESS) { + return FAILED; + } + resize_flag_ = true; + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseCropPara(const nlohmann::json &json_item) { + if (GetIntValue(json_item, "out_width", crop_para_.output_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "out_height", crop_para_.output_height) != SUCCESS) { + return FAILED; + } + auto &crop_info = crop_para_.crop_info; + std::string crop_type = "crop_type"; + if (GetStringValue(json_item, "crop_type", crop_type) != SUCCESS) { + return FAILED; + } + if (crop_type == "offset") { + MSI_LOG_INFO << "Crop type is 'offset'"; + crop_info.crop_type = kDvppCropTypeOffset; + auto &crop_area = crop_info.crop_area; + if (GetIntValue(json_item, "crop_left", crop_area.left) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_top", crop_area.top) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_right", crop_area.right) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_bottom", crop_area.bottom) != SUCCESS) { + return FAILED; + } + } else if (crop_type == "centre") { + MSI_LOG_INFO << "Crop type is 'centre'"; + if (GetIntValue(json_item, "crop_width", crop_info.crop_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_height", crop_info.crop_height) != SUCCESS) { + return FAILED; + } + crop_info.crop_type = kDvppCropTypeCentre; + } else { + MSI_LOG_ERROR << "Invalid crop type " << crop_type << ", expect offset or centre"; + return FAILED; + } + crop_flag_ = true; + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseCropAndPastePara(const nlohmann::json &json_item) { + // crop info + if (GetIntValue(json_item, "out_width", crop_and_paste_para_.output_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "out_height", crop_and_paste_para_.output_height) != SUCCESS) { + return FAILED; + } + auto &crop_info = crop_and_paste_para_.crop_info; + std::string crop_type = "crop_type"; + if (GetStringValue(json_item, "crop_type", crop_type) != SUCCESS) { + return FAILED; + } + if (crop_type == "offset") { + MSI_LOG_INFO << "Crop type is 'offset'"; + crop_info.crop_type = kDvppCropTypeOffset; + auto &crop_area = crop_info.crop_area; + if (GetIntValue(json_item, "crop_left", crop_area.left) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_top", crop_area.top) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_right", crop_area.right) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_bottom", crop_area.bottom) != SUCCESS) { + return FAILED; + } + } else if (crop_type == "centre") { + MSI_LOG_INFO << "Crop type is 'centre'"; + if (GetIntValue(json_item, "crop_width", crop_info.crop_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_height", crop_info.crop_height) != SUCCESS) { + return FAILED; + } + crop_info.crop_type = kDvppCropTypeCentre; + } else { + MSI_LOG_ERROR << "Invalid crop type " << crop_type << ", expect offset or centre"; + return FAILED; + } + // paste info + auto &paste_area = crop_and_paste_para_.paste_area; + if (GetIntValue(json_item, "paste_left", paste_area.left) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "paste_top", paste_area.top) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "paste_right", paste_area.right) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "paste_bottom", paste_area.bottom) != SUCCESS) { + return FAILED; + } + crop_and_paste_flag_ = true; + return SUCCESS; +} + +Status DvppJsonConfigParser::InitWithJsonConfigImp(const std::string &json_config) { + std::ifstream fp(json_config); + if (!fp.is_open()) { + MSI_LOG_ERROR << "read json config file failed"; + return FAILED; + } + const auto &model_info = nlohmann::json::parse(fp); + auto preprocess_list = model_info.find("preprocess"); + if (preprocess_list == model_info.end()) { + MSI_LOG_ERROR << "get preprocess failed"; + return FAILED; + } + if (!preprocess_list->is_array()) { + MSI_LOG_ERROR << "preprocess is not array"; + return FAILED; + } + if (preprocess_list->empty()) { + MSI_LOG_ERROR << "preprocess size is 0"; + return FAILED; + } + auto &preprocess = preprocess_list->at(0); + // input + if (ParseInputPara(preprocess) != SUCCESS) { + MSI_LOG_ERROR << "parse input failed"; + return FAILED; + } + // decode para + if (ParseDecodePara(preprocess) != SUCCESS) { + MSI_LOG_ERROR << "parse decode failed"; + return FAILED; + } + // ops + auto dvpp_process = preprocess.find("dvpp_process"); + if (dvpp_process == preprocess.end()) { + MSI_LOG_ERROR << "get dvpp_process failed"; + return FAILED; + } + if (!dvpp_process->is_object()) { + MSI_LOG_ERROR << "dvpp_process is not array"; + return FAILED; + } + const auto &item = *dvpp_process; + std::string op_name; + if (GetStringValue(item, "op_name", op_name) != SUCCESS) { + return FAILED; + } + if (op_name == "resize") { + if (ParseResizePara(item) != SUCCESS) { + MSI_LOG_ERROR << "Parse resize para failed"; + return FAILED; + } + } else if (op_name == "crop") { + if (ParseCropPara(item) != SUCCESS) { + MSI_LOG_ERROR << "Parse crop para failed"; + return FAILED; + } + } else if (op_name == "crop_and_paste") { + if (ParseCropAndPastePara(item) != SUCCESS) { + MSI_LOG_ERROR << "Parse decode para failed"; + return FAILED; + } + } else { + MSI_LOG_ERROR << "Unsupport op name " << op_name << ", expect resize, crop or crop_and_paste"; + return FAILED; + } + return SUCCESS; +} + +Status DvppJsonConfigParser::InitWithJsonConfig(const std::string &json_config) { + try { + auto ret = InitWithJsonConfigImp(json_config); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "init dvpp with json config failed, json config " << json_config; + return FAILED; + } + } catch (nlohmann::json::exception &e) { + MSI_LOG_ERROR << "init dvpp with json config failed, json config " << json_config << ", error: " << e.what(); + return FAILED; + } + MSI_LOG_INFO << "Init with json config " << json_config << " success"; + return SUCCESS; +} + +Status DvppProcess::InitWithJsonConfig(const std::string &json_config) { + DvppJsonConfigParser parser; + if (parser.InitWithJsonConfig(json_config) != SUCCESS) { + MSI_LOG_ERROR << "init json config failed"; + return FAILED; + } + if (InitJpegDecodePara(parser.GetDecodePara()) != SUCCESS) { + MSI_LOG_ERROR << "init decode para failed"; + return FAILED; + } + if (parser.HasResizeConfig()) { + if (InitResizePara(parser.GetResizePara()) != SUCCESS) { + MSI_LOG_ERROR << "init resize para failed"; + return FAILED; + } + } else if (parser.HasCropConfig()) { + if (InitCropPara(parser.GetCropPara()) != SUCCESS) { + MSI_LOG_ERROR << "init crop para failed"; + return FAILED; + } + } else if (parser.HasCropAndPasteConfig()) { + if (InitCropAndPastePara(parser.GetCropAndPastePara()) != SUCCESS) { + MSI_LOG_ERROR << "init crop and paste para failed"; + return FAILED; + } + } + return SUCCESS; +} + +} // namespace inference +} // namespace mindspore diff --git a/serving/acl/dvpp_process.h b/serving/acl/dvpp_process.h new file mode 100644 index 0000000000..da0275118d --- /dev/null +++ b/serving/acl/dvpp_process.h @@ -0,0 +1,159 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_DVPP_PROCESS_ACL +#define INC_DVPP_PROCESS_ACL +#include +#include +#include "acl/acl.h" +#include "acl/acl_mdl.h" +#include "acl/acl_rt.h" +#include "acl/ops/acl_dvpp.h" +#include "include/inference.h" + +namespace mindspore::inference { + +struct DvppDecodePara { + acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; +}; + +struct DvppResizePara { + uint32_t output_width = 0; + uint32_t output_height = 0; +}; + +enum DvppCropType { + // crop left,top,right,bottom is given in config + kDvppCropTypeOffset = 0, + // crop left,top,right,bottom is calculated by image width/height and output crop width/height + kDvppCropTypeCentre = 1, +}; + +struct DvppRoiArea { + uint32_t left = 0; + uint32_t top = 0; + uint32_t right = 0; + uint32_t bottom = 0; +}; + +struct DvppCropInfo { + DvppCropType crop_type = kDvppCropTypeOffset; + DvppRoiArea crop_area; // when kDvppCropTypeOffset + uint32_t crop_width = 0; // when kDvppCropTypeCentre + uint32_t crop_height = 0; // when kDvppCropTypeCentre +}; + +struct DvppCropPara { + DvppCropInfo crop_info; + uint32_t output_width = 0; + uint32_t output_height = 0; +}; + +struct DvppCropAndPastePara { + DvppCropInfo crop_info; + DvppRoiArea paste_area; + uint32_t output_width = 0; + uint32_t output_height = 0; +}; + +class DvppProcess { + public: + DvppProcess(); + ~DvppProcess(); + + Status InitResource(aclrtStream stream); + void Finalize(); + Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + (resize | crop) + Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize + Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop + Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste + + Status InitWithJsonConfig(const std::string &json_config); + + // output device buffer will be destroy by DvppProcess itself. + Status Process(const void *pic_buffer, size_t pic_buffer_size, void *&output_device_buffer, size_t &output_size); + Status Process(const std::vector &pic_buffer_list, const std::vector &pic_buffer_size_list, + void *&output_device_buffer, size_t &output_size); + + private: + uint32_t pic_width_ = 0; + uint32_t pic_height_ = 0; + + DvppDecodePara decode_para_; + DvppResizePara resize_para_; + DvppCropPara crop_para_; + DvppCropAndPastePara crop_and_paste_para_; + // only one of the resize or crop flag can be true + bool to_resize_flag_ = false; + bool to_crop_flag_ = false; + bool to_crop_and_paste_flag_ = false; + + void *input_pic_dev_buffer_ = nullptr; + uint32_t input_pic_buffer_size_ = 0; + + uint32_t decode_output_buffer_size_ = 0; + void *decode_output_buffer_dev_ = nullptr; + acldvppPicDesc *decode_output_desc_ = nullptr; + + acldvppResizeConfig *resize_config_ = nullptr; + acldvppRoiConfig *crop_area_ = nullptr; + acldvppRoiConfig *paste_area_ = nullptr; + + acldvppPicDesc *vpc_output_desc_ = nullptr; + void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length + uint32_t vpc_output_buffer_size_ = 0; + + void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length + uint32_t batch_size_ = 0; + + aclrtStream stream_ = nullptr; + acldvppChannelDesc *dvpp_channel_desc_ = nullptr; + + uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const; + uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const; + Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t &stride_width, uint32_t &stride_height); + Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t &stride_width, uint32_t &stride_height); + Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size); + Status InitDecodeOutputDesc(uint32_t image_width, + uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_ + Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height); + Status CheckAndAdjustRoiArea(DvppRoiArea &area); + Status UpdateCropArea(uint32_t image_width, uint32_t image_height); + Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const; + void DestroyDecodeDesc(); + + Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, + acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_ + Status InitRoiAreaConfig(acldvppRoiConfig *&roi_area, const DvppRoiArea &init_para); + Status InitCommonCropPara(DvppCropInfo &crop_info, uint32_t out_width, uint32_t out_height); + Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config + Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_ + Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_ + void DestroyVpcOutputDesc(); + + Status ProcessDecode(); + Status ProcessResize(); + Status ProcessCrop(); + Status ProcessCropAndPaste(); + void DestroyResource(); + + Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t &image_width, + uint32_t &image_height); +}; + +} // namespace mindspore::inference + +#endif // INC_DVPP_PROCESS_ACL diff --git a/serving/acl/model_info_example.json b/serving/acl/model_info_example.json new file mode 100644 index 0000000000..e6d37048d0 --- /dev/null +++ b/serving/acl/model_info_example.json @@ -0,0 +1,68 @@ +{ + "preprocess": [ + { + "input": { + "index": 0 + }, + "decode_para": { + "out_pixel_format": "YUV420SP" + }, + "dvpp_process": { + "op_name": "resize", + "out_width": 224, + "out_height": 224 + }, + "sample of dvpp_process content": [ + { + "op_name": "resize", + "out_width": 224, + "out_height": 224 + }, + { + "op_name": "crop", + "crop_type": "offset", + "crop_left": 10, + "crop_top": 10, + "crop_right": 100, + "crop_bottom": 200, + "out_width": 224, + "out_height": 224 + }, + { + "op_name": "crop", + "crop_type": "centre", + "crop_width": 100, + "crop_height": 100, + "out_width": 224, + "out_height": 224 + }, + { + "op_name": "crop_and_paste", + "crop_type": "offset", + "crop_left": 10, + "crop_top": 10, + "crop_right": 100, + "crop_bottom": 200, + "paste_left": 10, + "paste_top": 10, + "paste_right": 100, + "paste_bottom": 200, + "out_width": 224, + "out_height": 224 + }, + { + "op_name": "crop_and_paste", + "crop_type": "centre", + "crop_width": 100, + "crop_height": 100, + "paste_left": 10, + "paste_top": 10, + "paste_right": 100, + "paste_bottom": 200, + "out_width": 224, + "out_height": 224 + } + ] + } + ] +} \ No newline at end of file diff --git a/serving/acl/model_process.cc b/serving/acl/model_process.cc new file mode 100644 index 0000000000..4016d1c00b --- /dev/null +++ b/serving/acl/model_process.cc @@ -0,0 +1,433 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "serving/acl/model_process.h" +#include +#include + +#include "include/infer_log.h" + +namespace mindspore { +namespace inference { + +Status ModelProcess::PreInitModelResource() { + model_desc_ = aclmdlCreateDesc(); + aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Read model desc failed"; + return FAILED; + } + Status ret = InitInputsBuffer(); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Create input buffer failed"; + return FAILED; + } + ret = InitOutputsBuffer(); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Create output buffer failed"; + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) { + aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), &model_id); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Read model file failed, file name is " << file_name; + return FAILED; + } + MSI_LOG_INFO << "Load model success " << file_name; + model_id_ = model_id; + if (PreInitModelResource() != SUCCESS) { + aclmdlUnload(model_id_); + MSI_LOG_ERROR << "Pre init model resource failed, file name is " << file_name; + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::InitInputsBuffer() { + aclError ret; + size_t input_size = aclmdlGetNumInputs(model_desc_); + + for (size_t i = 0; i < input_size; ++i) { + auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i); + void *data_mem_buffer = nullptr; + if (!is_run_on_device_) { // need to copy input/output to/from device + ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Malloc device input buffer faild , input size " << buffer_size; + return FAILED; + } + } + + aclmdlIODims dims; + ret = aclmdlGetInputDims(model_desc_, i, &dims); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Get input shape failed"; + if (!is_run_on_device_) { + aclrtFree(data_mem_buffer); + } + return FAILED; + } + aclDataType data_type = aclmdlGetInputDataType(model_desc_, i); + std::vector shape(dims.dims, dims.dims + dims.dimCount); + input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape}); + } + MSI_LOG_INFO << "Create model inputs success"; + return SUCCESS; +} + +Status ModelProcess::CreateDataBuffer(void *&data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) { + aclError ret; + auto free_data_buffer = [this](void *dataMemBuffer) { + if (!is_run_on_device_) { + aclrtFree(dataMemBuffer); + } else { + aclrtFreeHost(dataMemBuffer); + } + }; + if (!is_run_on_device_) { + ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Malloc device buffer faild , buffer size " << buffer_size; + return FAILED; + } + } else { + ret = aclrtMallocHost(&data_mem_buffer, buffer_size); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Malloc device buffer faild , buffer size " << buffer_size; + return FAILED; + } + } + + auto data_buffer = aclCreateDataBuffer(data_mem_buffer, buffer_size); + if (data_buffer == nullptr) { + MSI_LOG_ERROR << "Create Data Buffer failed"; + free_data_buffer(data_mem_buffer); + return FAILED; + } + ret = aclmdlAddDatasetBuffer(dataset, data_buffer); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "add data buffer failed"; + free_data_buffer(data_mem_buffer); + aclDestroyDataBuffer(data_buffer); + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::InitOutputsBuffer() { + aclError ret; + outputs_ = aclmdlCreateDataset(); + if (outputs_ == nullptr) { + MSI_LOG_ERROR << "Create input dataset failed"; + return FAILED; + } + size_t output_size = aclmdlGetNumOutputs(model_desc_); + for (size_t i = 0; i < output_size; ++i) { + auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i); + + void *data_mem_buffer = nullptr; + if (CreateDataBuffer(data_mem_buffer, buffer_size, outputs_) != SUCCESS) { + MSI_LOG_ERROR << "add output data buffer failed, buffer size " << buffer_size; + return FAILED; + } + aclmdlIODims dims; + ret = aclmdlGetOutputDims(model_desc_, i, &dims); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Get input shape failed"; + if (!is_run_on_device_) { + aclrtFree(data_mem_buffer); + } else { + aclrtFreeHost(data_mem_buffer); + } + return FAILED; + } + aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i); + std::vector shape(dims.dims, dims.dims + dims.dimCount); + output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape}); + } + MSI_LOG_INFO << "Create model output success"; + return SUCCESS; +} + +void ModelProcess::DestroyInputsDataset() { + if (inputs_ == nullptr) { + return; + } + for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) { + auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i); + aclDestroyDataBuffer(dataBuffer); + } + aclmdlDestroyDataset(inputs_); + inputs_ = nullptr; +} + +void ModelProcess::DestroyInputsDataMem() { + if (!is_run_on_device_) { + for (const auto &item : input_infos_) { + aclrtFree(item.device_data); + } + } + input_infos_.clear(); +} + +void ModelProcess::DestroyInputsBuffer() { + DestroyInputsDataMem(); + DestroyInputsDataset(); +} + +void ModelProcess::DestroyOutputsBuffer() { + for (const auto &item : output_infos_) { + if (!is_run_on_device_) { + aclrtFree(item.device_data); + } else { + aclrtFreeHost(item.device_data); + } + } + output_infos_.clear(); + + if (outputs_ == nullptr) { + return; + } + for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) { + auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i); + aclDestroyDataBuffer(dataBuffer); + } + aclmdlDestroyDataset(outputs_); + outputs_ = nullptr; +} + +void ModelProcess::UnLoad() { + auto ret = aclmdlUnload(model_id_); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Unload model failed"; + } + if (model_desc_ != nullptr) { + aclmdlDestroyDesc(model_desc_); + model_desc_ = nullptr; + } + DestroyInputsBuffer(); + DestroyOutputsBuffer(); + MSI_LOG_INFO << "End unload model " << model_id_; +} + +Status ModelProcess::CheckAndInitInput(const RequestBase &request) { + aclError ret; + inputs_ = aclmdlCreateDataset(); + // check inputs + if (request.size() != input_infos_.size()) { + MSI_LOG_ERROR << "inputs count not match, required count " << input_infos_.size() << ", given count " + << request.size(); + return INFER_STATUS(INVALID_INPUTS) << "inputs count not match, required count " << input_infos_.size() + << ", given count " << request.size(); + } + for (size_t i = 0; i < input_infos_.size(); i++) { + if (request[i] == nullptr) { + MSI_LOG_ERROR << "input " << i << " cannot be null"; + return FAILED; + } + if (request[i]->data_size() != input_infos_[i].buffer_size) { + MSI_LOG_ERROR << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size + << ", given count " << request[i]->data_size(); + return INFER_STATUS(INVALID_INPUTS) << "input " << i << " data size not match, required size " + << input_infos_[i].buffer_size << ", given count " << request[i]->data_size(); + } + } + // copy inputs + for (size_t i = 0; i < input_infos_.size(); i++) { + void *input_buffer = nullptr; + auto &info = input_infos_[i]; + const void *data = request[i]->data(); + if (!is_run_on_device_) { + ret = aclrtMemcpy(info.device_data, info.buffer_size, data, request[i]->data_size(), ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "memcpy input " << i << " data to device failed, buffer size " << request[i]->data_size(); + return FAILED; + } + input_buffer = info.device_data; + } else { + input_buffer = const_cast(data); + } + auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size); + if (data_buffer == nullptr) { + MSI_LOG_ERROR << "Create Data Buffer failed"; + return FAILED; + } + ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "add data buffer failed"; + aclDestroyDataBuffer(data_buffer); + return FAILED; + } + } + return SUCCESS; +} + +Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, + size_t input_index) { + aclError ret; + inputs_ = aclmdlCreateDataset(); + // check inputs + if (input_index >= input_infos_.size()) { + MSI_LOG_ERROR << "inputs count not match, required count " << input_infos_.size() << ", given index " + << input_index; + return INFER_STATUS(INVALID_INPUTS) << "inputs count not match, required count " << input_infos_.size() + << ", given index " << input_index; + } + if (dvpp_outputs_buffer_dev == nullptr) { + MSI_LOG_ERROR << "input " << 0 << " cannot be null"; + return FAILED; + } + if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) { + MSI_LOG_ERROR << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size + << ", given count " << dvpp_outputs_buffer_size; + return INFER_STATUS(INVALID_INPUTS) << "input " << 0 << " data size not match, required size " + << input_infos_[input_index].buffer_size << ", given count " + << dvpp_outputs_buffer_size; + } + // copy inputs + auto &info = input_infos_[input_index]; + auto data_buffer = aclCreateDataBuffer(const_cast(dvpp_outputs_buffer_dev), info.buffer_size); + if (data_buffer == nullptr) { + MSI_LOG_ERROR << "Create Data Buffer failed"; + return FAILED; + } + ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "add data buffer failed"; + aclDestroyDataBuffer(data_buffer); + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::BuildOutputs(ReplyBase &reply) { + aclError ret; + // copy outputs + reply.clear(); + + std::unordered_map data_type_map = { + {ACL_FLOAT16, inference::kMSI_Float16}, {ACL_FLOAT, inference::kMSI_Float32}, {ACL_DOUBLE, inference::kMSI_Float64}, + {ACL_INT8, inference::kMSI_Int8}, {ACL_INT16, inference::kMSI_Int16}, {ACL_INT32, inference::kMSI_Int32}, + {ACL_INT64, inference::kMSI_Int64}, {ACL_UINT8, inference::kMSI_Uint8}, {ACL_UINT16, inference::kMSI_Uint16}, + {ACL_UINT32, inference::kMSI_Uint32}, {ACL_UINT64, inference::kMSI_Uint64}, {ACL_BOOL, inference::kMSI_Bool}, + }; + auto trans_to_serving_type = [&data_type_map](aclDataType data_type) { + auto it = data_type_map.find(data_type); + if (it == data_type_map.end()) { + return inference::kMSI_Unknown; + } else { + return it->second; + } + }; + for (size_t i = 0; i < output_infos_.size(); i++) { + auto &info = output_infos_[i]; + auto output = reply.add(); + if (output == nullptr) { + MSI_LOG_ERROR << "add new output failed"; + return FAILED; + } + output->set_data_type(trans_to_serving_type(info.data_type)); + output->set_shape(info.dims); + if (!output->resize_data(info.buffer_size)) { + MSI_LOG_ERROR << "new output data buffer failed, data size " << info.buffer_size; + return FAILED; + } + if (!is_run_on_device_) { + ret = aclrtMemcpy(output->mutable_data(), output->data_size(), info.device_data, info.buffer_size, + ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Memcpy output " << i << " to host failed, memory size " << info.buffer_size; + return FAILED; + } + } else { + ret = aclrtMemcpy(output->mutable_data(), output->data_size(), info.device_data, info.buffer_size, + ACL_MEMCPY_HOST_TO_HOST); + if (ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Memcpy output " << i << " to host failed, memory size " << info.buffer_size; + return FAILED; + } + } + } + return SUCCESS; +} + +Status ModelProcess::Execute(const RequestBase &request, ReplyBase &reply) { + aclError acl_ret; + Status ret = CheckAndInitInput(request); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "check or init input failed"; + DestroyInputsDataset(); + return ret; // forward status error + } + acl_ret = aclmdlExecute(model_id_, inputs_, outputs_); + DestroyInputsDataset(); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Execute Model Failed"; + return FAILED; + } + ret = BuildOutputs(reply); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Build outputs faield"; + return FAILED; + } + MSI_LOG_INFO << "excute model success"; + return SUCCESS; +} + +Status ModelProcess::Execute(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, ReplyBase &reply) { + aclError acl_ret; + if (input_infos_.size() != 1) { + MSI_LOG_ERROR << "can only support input size 1, now model inputs size is " << input_infos_.size(); + return INFER_STATUS(INVALID_INPUTS) << "can only support input size 1, now model inputs size is " + << input_infos_.size(); + } + Status ret = CheckAndInitDvppInput(dvpp_outputs_buffer_dev, dvpp_outputs_buffer_size, 0); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "check or init input failed"; + DestroyInputsDataset(); + return ret; // forward status msg + } + acl_ret = aclmdlExecute(model_id_, inputs_, outputs_); + DestroyInputsDataset(); + if (acl_ret != ACL_ERROR_NONE) { + MSI_LOG_ERROR << "Execute Model Failed"; + return INFER_STATUS(FAILED) << "Execute Model Failed"; + } + ret = BuildOutputs(reply); + if (ret != SUCCESS) { + MSI_LOG_ERROR << "Build outputs faield"; + return FAILED; + } + MSI_LOG_INFO << "excute model success"; + return SUCCESS; +} + +size_t ModelProcess::GetBatchSize() const { + if (input_infos_.empty()) { + MSI_LOG_ERROR << "Model is not loaded"; + return 0; + } + if (input_infos_[0].dims.empty()) { + return 1; + } + return static_cast(input_infos_[0].dims[0]); +} + +} // namespace inference +} // namespace mindspore diff --git a/serving/acl/model_process.h b/serving/acl/model_process.h new file mode 100644 index 0000000000..ae716404ff --- /dev/null +++ b/serving/acl/model_process.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_MODEL_PROCESS_ACL +#define INC_MODEL_PROCESS_ACL +#include +#include +#include "acl/acl.h" +#include "acl/acl_mdl.h" +#include "acl/acl_rt.h" +#include "include/inference.h" + +namespace mindspore { +namespace inference { + +struct AclTensorInfo { + void *device_data; + size_t buffer_size; + aclDataType data_type; + std::vector dims; +}; + +struct ImagesDvppOutput { + void *buffer_device = nullptr; + size_t buffer_size = 0; + size_t input_index = 0; +}; + +class ModelProcess { + public: + ModelProcess() {} + ~ModelProcess() {} + + Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id); + void UnLoad(); + + // override this method to avoid request/reply data copy + Status Execute(const RequestBase &request, ReplyBase &reply); + Status Execute(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, ReplyBase &reply); + void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } + + size_t GetBatchSize() const; + + private: + uint32_t model_id_ = 0xffffffff; + // if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device + bool is_run_on_device_ = false; + aclmdlDesc *model_desc_ = nullptr; + aclmdlDataset *inputs_ = nullptr; + aclmdlDataset *outputs_ = nullptr; + std::vector input_infos_; + std::vector output_infos_; + + Status PreInitModelResource(); + Status CreateDataBuffer(void *&data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset); + Status CheckAndInitInput(const RequestBase &request); + Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, + size_t input_index); + Status BuildOutputs(ReplyBase &reply); + + Status InitInputsBuffer(); + Status InitOutputsBuffer(); + void DestroyInputsDataset(); + void DestroyInputsDataMem(); + void DestroyInputsBuffer(); + void DestroyOutputsBuffer(); +}; + +} // namespace inference +} // namespace mindspore + +#endif diff --git a/serving/core/ms_service_pb2.py b/serving/core/ms_service_pb2.py deleted file mode 100644 index 9feec026f9..0000000000 --- a/serving/core/ms_service_pb2.py +++ /dev/null @@ -1,318 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: ms_service.proto - -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='ms_service.proto', - package='ms_serving', - syntax='proto3', - serialized_options=None, - serialized_pb=b'\n\x10ms_service.proto\x12\nms_serving\"2\n\x0ePredictRequest\x12 \n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"2\n\x0cPredictReply\x12\"\n\x06result\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"\x1b\n\x0bTensorShape\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\"p\n\x06Tensor\x12-\n\x0ctensor_shape\x18\x01 \x01(\x0b\x32\x17.ms_serving.TensorShape\x12)\n\x0btensor_type\x18\x02 \x01(\x0e\x32\x14.ms_serving.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c*\xc9\x01\n\x08\x44\x61taType\x12\x0e\n\nMS_UNKNOWN\x10\x00\x12\x0b\n\x07MS_BOOL\x10\x01\x12\x0b\n\x07MS_INT8\x10\x02\x12\x0c\n\x08MS_UINT8\x10\x03\x12\x0c\n\x08MS_INT16\x10\x04\x12\r\n\tMS_UINT16\x10\x05\x12\x0c\n\x08MS_INT32\x10\x06\x12\r\n\tMS_UINT32\x10\x07\x12\x0c\n\x08MS_INT64\x10\x08\x12\r\n\tMS_UINT64\x10\t\x12\x0e\n\nMS_FLOAT16\x10\n\x12\x0e\n\nMS_FLOAT32\x10\x0b\x12\x0e\n\nMS_FLOAT64\x10\x0c\x32\x8e\x01\n\tMSService\x12\x41\n\x07Predict\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x12>\n\x04Test\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x62\x06proto3' -) - -_DATATYPE = _descriptor.EnumDescriptor( - name='DataType', - full_name='ms_serving.DataType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='MS_UNKNOWN', index=0, number=0, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_BOOL', index=1, number=1, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_INT8', index=2, number=2, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_UINT8', index=3, number=3, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_INT16', index=4, number=4, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_UINT16', index=5, number=5, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_INT32', index=6, number=6, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_UINT32', index=7, number=7, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_INT64', index=8, number=8, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_UINT64', index=9, number=9, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_FLOAT16', index=10, number=10, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_FLOAT32', index=11, number=11, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MS_FLOAT64', index=12, number=12, - serialized_options=None, - type=None), - ], - containing_type=None, - serialized_options=None, - serialized_start=280, - serialized_end=481, -) -_sym_db.RegisterEnumDescriptor(_DATATYPE) - -DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) -MS_UNKNOWN = 0 -MS_BOOL = 1 -MS_INT8 = 2 -MS_UINT8 = 3 -MS_INT16 = 4 -MS_UINT16 = 5 -MS_INT32 = 6 -MS_UINT32 = 7 -MS_INT64 = 8 -MS_UINT64 = 9 -MS_FLOAT16 = 10 -MS_FLOAT32 = 11 -MS_FLOAT64 = 12 - - - -_PREDICTREQUEST = _descriptor.Descriptor( - name='PredictRequest', - full_name='ms_serving.PredictRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='data', full_name='ms_serving.PredictRequest.data', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=32, - serialized_end=82, -) - - -_PREDICTREPLY = _descriptor.Descriptor( - name='PredictReply', - full_name='ms_serving.PredictReply', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='result', full_name='ms_serving.PredictReply.result', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=84, - serialized_end=134, -) - - -_TENSORSHAPE = _descriptor.Descriptor( - name='TensorShape', - full_name='ms_serving.TensorShape', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dims', full_name='ms_serving.TensorShape.dims', index=0, - number=1, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=136, - serialized_end=163, -) - - -_TENSOR = _descriptor.Descriptor( - name='Tensor', - full_name='ms_serving.Tensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tensor_shape', full_name='ms_serving.Tensor.tensor_shape', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor_type', full_name='ms_serving.Tensor.tensor_type', index=1, - number=2, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='data', full_name='ms_serving.Tensor.data', index=2, - number=3, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=b"", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=165, - serialized_end=277, -) - -_PREDICTREQUEST.fields_by_name['data'].message_type = _TENSOR -_PREDICTREPLY.fields_by_name['result'].message_type = _TENSOR -_TENSOR.fields_by_name['tensor_shape'].message_type = _TENSORSHAPE -_TENSOR.fields_by_name['tensor_type'].enum_type = _DATATYPE -DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST -DESCRIPTOR.message_types_by_name['PredictReply'] = _PREDICTREPLY -DESCRIPTOR.message_types_by_name['TensorShape'] = _TENSORSHAPE -DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR -DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTREQUEST, - '__module__' : 'ms_service_pb2' - # @@protoc_insertion_point(class_scope:ms_serving.PredictRequest) - }) -_sym_db.RegisterMessage(PredictRequest) - -PredictReply = _reflection.GeneratedProtocolMessageType('PredictReply', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTREPLY, - '__module__' : 'ms_service_pb2' - # @@protoc_insertion_point(class_scope:ms_serving.PredictReply) - }) -_sym_db.RegisterMessage(PredictReply) - -TensorShape = _reflection.GeneratedProtocolMessageType('TensorShape', (_message.Message,), { - 'DESCRIPTOR' : _TENSORSHAPE, - '__module__' : 'ms_service_pb2' - # @@protoc_insertion_point(class_scope:ms_serving.TensorShape) - }) -_sym_db.RegisterMessage(TensorShape) - -Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSOR, - '__module__' : 'ms_service_pb2' - # @@protoc_insertion_point(class_scope:ms_serving.Tensor) - }) -_sym_db.RegisterMessage(Tensor) - - - -_MSSERVICE = _descriptor.ServiceDescriptor( - name='MSService', - full_name='ms_serving.MSService', - file=DESCRIPTOR, - index=0, - serialized_options=None, - serialized_start=484, - serialized_end=626, - methods=[ - _descriptor.MethodDescriptor( - name='Predict', - full_name='ms_serving.MSService.Predict', - index=0, - containing_service=None, - input_type=_PREDICTREQUEST, - output_type=_PREDICTREPLY, - serialized_options=None, - ), - _descriptor.MethodDescriptor( - name='Test', - full_name='ms_serving.MSService.Test', - index=1, - containing_service=None, - input_type=_PREDICTREQUEST, - output_type=_PREDICTREPLY, - serialized_options=None, - ), -]) -_sym_db.RegisterServiceDescriptor(_MSSERVICE) - -DESCRIPTOR.services_by_name['MSService'] = _MSSERVICE - -# @@protoc_insertion_point(module_scope) diff --git a/serving/core/ms_service_pb2_grpc.py b/serving/core/ms_service_pb2_grpc.py deleted file mode 100644 index e6f21a0de3..0000000000 --- a/serving/core/ms_service_pb2_grpc.py +++ /dev/null @@ -1,96 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - -import ms_service_pb2 as ms__service__pb2 - - -class MSServiceStub(object): - """Missing associated documentation comment in .proto file""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Predict = channel.unary_unary( - '/ms_serving.MSService/Predict', - request_serializer=ms__service__pb2.PredictRequest.SerializeToString, - response_deserializer=ms__service__pb2.PredictReply.FromString, - ) - self.Test = channel.unary_unary( - '/ms_serving.MSService/Test', - request_serializer=ms__service__pb2.PredictRequest.SerializeToString, - response_deserializer=ms__service__pb2.PredictReply.FromString, - ) - - -class MSServiceServicer(object): - """Missing associated documentation comment in .proto file""" - - def Predict(self, request, context): - """Missing associated documentation comment in .proto file""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Test(self, request, context): - """Missing associated documentation comment in .proto file""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_MSServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Predict': grpc.unary_unary_rpc_method_handler( - servicer.Predict, - request_deserializer=ms__service__pb2.PredictRequest.FromString, - response_serializer=ms__service__pb2.PredictReply.SerializeToString, - ), - 'Test': grpc.unary_unary_rpc_method_handler( - servicer.Test, - request_deserializer=ms__service__pb2.PredictRequest.FromString, - response_serializer=ms__service__pb2.PredictReply.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'ms_serving.MSService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. -class MSService(object): - """Missing associated documentation comment in .proto file""" - - @staticmethod - def Predict(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Predict', - ms__service__pb2.PredictRequest.SerializeToString, - ms__service__pb2.PredictReply.FromString, - options, channel_credentials, - call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def Test(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Test', - ms__service__pb2.PredictRequest.SerializeToString, - ms__service__pb2.PredictReply.FromString, - options, channel_credentials, - call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/serving/core/server.cc b/serving/core/server.cc index 5ba7ad36a7..61a3f1558a 100644 --- a/serving/core/server.cc +++ b/serving/core/server.cc @@ -23,14 +23,14 @@ #include #include #include +#include -#include "mindspore/ccsrc/utils/log_adapter.h" +#include "include/infer_log.h" #include "serving/ms_service.grpc.pb.h" #include "core/util/option_parser.h" #include "core/version_control/version_controller.h" -#include "mindspore/ccsrc/utils/context/ms_context.h" #include "core/util/file_system_operation.h" -#include "graphengine/third_party/fwkacllib/inc/runtime/context.h" +#include "core/serving_tensor.h" using ms_serving::MSService; using ms_serving::PredictReply; @@ -38,12 +38,19 @@ using ms_serving::PredictRequest; namespace mindspore { namespace serving { -using MSTensorPtr = std::shared_ptr; + +#define MSI_TIME_STAMP_START(name) auto time_start_##name = std::chrono::steady_clock::now(); +#define MSI_TIME_STAMP_END(name) \ + { \ + auto time_end_##name = std::chrono::steady_clock::now(); \ + auto time_cost = std::chrono::duration(time_end_##name - time_start_##name).count(); \ + MSI_LOG_INFO << #name " Time Cost # " << time_cost << " ms ---------------------"; \ + } Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) { - session_ = inference::MSSession::CreateSession(device, device_id); + session_ = inference::InferSession::CreateSession(device, device_id); if (session_ == nullptr) { - MS_LOG(ERROR) << "Creat Session Failed"; + MSI_LOG(ERROR) << "Creat Session Failed"; return FAILED; } device_type_ = device; @@ -55,48 +62,67 @@ Session &Session::Instance() { return instance; } -Status Session::Predict(const std::vector &inputs, inference::MultiTensor *outputs) { - if (last_graph_ == nullptr) { - MS_LOG(ERROR) << "the model has not loaded"; +Status Session::Predict(const PredictRequest &request, PredictReply &reply) { + if (!model_loaded_) { + MSI_LOG(ERROR) << "the model has not loaded"; return FAILED; } if (session_ == nullptr) { - MS_LOG(ERROR) << "the inference session has not be initialized"; + MSI_LOG(ERROR) << "the inference session has not be initialized"; return FAILED; } std::lock_guard lock(mutex_); - MS_LOG(INFO) << "run Predict"; + MSI_LOG(INFO) << "run Predict"; + + if (request.images_size() > 0) { + ServingImagesRequest serving_images(request); + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + Status ret = session_->ExecuteModel(graph_id_, serving_images, serving_request, serving_reply); + if (ret != SUCCESS) { + MSI_LOG(ERROR) << "execute model with images return failed"; + return ret; + } + } else if (request.data_size() > 0) { + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + Status ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply); + if (ret != SUCCESS) { + MSI_LOG(ERROR) << "execute model with datas return failed"; + return ret; + } + } - *outputs = session_->RunGraph(graph_id_, inputs); - MS_LOG(INFO) << "run Predict finished"; + MSI_LOG(INFO) << "run Predict finished"; return SUCCESS; } Status Session::Warmup(const MindSporeModelPtr model) { if (session_ == nullptr) { - MS_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup"; + MSI_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup"; return FAILED; } std::lock_guard lock(mutex_); - size_t size = 0; std::string file_name = model->GetModelPath() + '/' + model->GetModelName(); - char *graphBuf = ReadFile(file_name.c_str(), &size); - if (graphBuf == nullptr) { - MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); - return FAILED; - } - last_graph_ = inference::LoadModel(graphBuf, size, device_type_); - if (last_graph_ == nullptr) { - MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); - return FAILED; - } - graph_id_ = session_->CompileGraph(last_graph_); - MS_LOG(INFO) << "Session Warmup finished"; + model_loaded_ = false; + MSI_TIME_STAMP_START(LoadModelFromFile) + auto ret = session_->LoadModelFromFile(file_name, graph_id_); + MSI_TIME_STAMP_END(LoadModelFromFile) + if (ret != SUCCESS) { + MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + return ret; + } + model_loaded_ = true; + MSI_LOG(INFO) << "Session Warmup finished"; return SUCCESS; } Status Session::Clear() { - session_ = nullptr; + if (session_ != nullptr) { + session_->UnloadModel(graph_id_); + session_->FinalizeEnv(); + session_ = nullptr; + } return SUCCESS; } @@ -104,121 +130,45 @@ namespace { static const uint32_t uint32max = 0x7FFFFFFF; std::promise exit_requested; -const std::map type2id_map{ - {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool}, - {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8}, - {ms_serving::MS_INT16, TypeId::kNumberTypeInt16}, {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16}, - {ms_serving::MS_INT32, TypeId::kNumberTypeInt32}, {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32}, - {ms_serving::MS_INT64, TypeId::kNumberTypeInt64}, {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64}, - {ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32}, - {ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64}, -}; - -const std::map id2type_map{ - {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL}, - {TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8}, - {TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16}, - {TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32}, - {TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64}, - {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32}, - {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64}, -}; -const std::map length_map{ - {ms_serving::MS_UNKNOWN, 0}, - {ms_serving::MS_BOOL, sizeof(bool)}, - {ms_serving::MS_INT8, sizeof(int8_t)}, - {ms_serving::MS_UINT8, sizeof(uint8_t)}, - {ms_serving::MS_INT16, sizeof(int16_t)}, - {ms_serving::MS_UINT16, sizeof(uint16_t)}, - {ms_serving::MS_INT32, sizeof(int32_t)}, - {ms_serving::MS_UINT32, sizeof(uint32_t)}, - {ms_serving::MS_INT64, sizeof(int64_t)}, - {ms_serving::MS_UINT64, sizeof(uint64_t)}, - {ms_serving::MS_FLOAT16, 2}, - {ms_serving::MS_FLOAT32, 4}, - {ms_serving::MS_FLOAT64, 8}, -}; -MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) { - std::vector shape; - for (auto dim : tensor.tensor_shape().dims()) { - shape.push_back(static_cast(dim)); - } - auto iter = type2id_map.find(tensor.tensor_type()); - if (iter == type2id_map.end()) { - MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type(); - return nullptr; - } - TypeId type = iter->second; - auto ms_tensor = std::shared_ptr(inference::MSTensor::CreateTensor(type, shape)); - memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size()); - return ms_tensor; -} +void ClearEnv() { Session::Instance().Clear(); } +void HandleSignal(int sig) { exit_requested.set_value(); } -ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) { - ms_serving::Tensor tensor; - ms_serving::TensorShape shape; - for (auto dim : ms_tensor->shape()) { - shape.add_dims(dim); - } - *tensor.mutable_tensor_shape() = shape; - auto iter = id2type_map.find(ms_tensor->data_type()); - if (iter == id2type_map.end()) { - MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type(); - return tensor; +grpc::Status CreatGRPCStatus(const Status &status) { + switch (status.StatusCode()) { + case SUCCESS: + return grpc::Status::OK; + case FAILED: + return grpc::Status::CANCELLED; + case INVALID_INPUTS: { + auto status_msg = status.StatusMessage(); + if (status_msg.empty()) { + status_msg = "The Predict Inputs do not match the Model Request!"; + } + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, status_msg); + } + default: + return grpc::Status::CANCELLED; } - tensor.set_tensor_type(iter->second); - tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size()); - return tensor; } -void ClearEnv() { - Session::Instance().Clear(); - inference::ExitInference(); -} -void HandleSignal(int sig) { exit_requested.set_value(); } - -#ifdef ENABLE_D -static rtContext_t g_ctx = nullptr; -#endif } // namespace // Service Implement class MSServiceImpl final : public MSService::Service { grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override { std::lock_guard lock(mutex_); -#ifdef ENABLE_D - if (g_ctx == nullptr) { - MS_LOG(ERROR) << "rtCtx is nullptr"; - return grpc::Status::CANCELLED; + MSI_TIME_STAMP_START(Predict) + auto res = Session::Instance().Predict(*request, *reply); + MSI_TIME_STAMP_END(Predict) + if (res != inference::SUCCESS) { + return CreatGRPCStatus(res); } - rtError_t rt_ret = rtCtxSetCurrent(g_ctx); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "set Ascend rtCtx failed"; - } -#endif - std::vector inputs; - inference::MultiTensor outputs; - for (int i = 0; i < request->data_size(); i++) { - auto input = ServingTensor2MSTensor(request->data(i)); - if (input == nullptr) { - MS_LOG(ERROR) << "Tensor convert failed"; - return grpc::Status::CANCELLED; - } - inputs.push_back(input); - } - auto res = Session::Instance().Predict(inputs, &outputs); - if (res != SUCCESS) { - return grpc::Status::CANCELLED; - } - for (const auto &tensor : outputs) { - *reply->add_result() = MSTensor2ServingTensor(tensor); - } - MS_LOG(INFO) << "Finish call service Eval"; + MSI_LOG(INFO) << "Finish call service Eval"; return grpc::Status::OK; } grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override { - MS_LOG(INFO) << "TestService call"; + MSI_LOG(INFO) << "TestService call"; return grpc::Status::OK; } std::mutex mutex_; @@ -237,28 +187,17 @@ Status Server::BuildAndStart() { auto device_id = option_args->device_id; res = Session::Instance().CreatDeviceSession(device_type, device_id); if (res != SUCCESS) { - MS_LOG(ERROR) << "creat session failed"; + MSI_LOG(ERROR) << "creat session failed"; ClearEnv(); return res; } VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name); res = version_controller.Run(); if (res != SUCCESS) { - MS_LOG(ERROR) << "load model failed"; + MSI_LOG(ERROR) << "load model failed"; ClearEnv(); return res; } -#ifdef ENABLE_D - // set d context - rtContext_t ctx = nullptr; - rtError_t rt_ret = rtCtxGetCurrent(&ctx); - if (rt_ret != RT_ERROR_NONE || ctx == nullptr) { - MS_LOG(ERROR) << "the ascend device context is null"; - ClearEnv(); - return FAILED; - } - g_ctx = ctx; -#endif MSServiceImpl ms_service; grpc::EnableDefaultHealthCheckService(true); grpc::reflection::InitProtoReflectionServerBuilderPlugin(); @@ -271,13 +210,13 @@ Status Server::BuildAndStart() { serverBuilder.RegisterService(&ms_service); std::unique_ptr server(serverBuilder.BuildAndStart()); if (server == nullptr) { - MS_LOG(ERROR) << "The serving server create failed"; + MSI_LOG(ERROR) << "The serving server create failed"; ClearEnv(); return FAILED; } auto grpc_server_run = [&server]() { server->Wait(); }; std::thread serving_thread(grpc_server_run); - MS_LOG(INFO) << "MS Serving listening on " << server_address; + MSI_LOG(INFO) << "MS Serving listening on " << server_address; auto exit_future = exit_requested.get_future(); exit_future.wait(); ClearEnv(); diff --git a/serving/core/server.h b/serving/core/server.h index f1927e9946..f97db84fce 100644 --- a/serving/core/server.h +++ b/serving/core/server.h @@ -23,14 +23,25 @@ #include "util/status.h" #include "version_control/model.h" #include "include/inference.h" -#include "mindspore/ccsrc/debug/info.h" +#include "serving/ms_service.pb.h" +#include "serving/ms_service.grpc.pb.h" + namespace mindspore { namespace serving { + +using ms_serving::PredictReply; +using ms_serving::PredictRequest; +using inference::Status; +using inference::SUCCESS; +using inference::FAILED; +using inference::INVALID_INPUTS; + class Session { public: static Session &Instance(); Status CreatDeviceSession(const std::string &device, uint32_t device_id); - Status Predict(const std::vector> &inputs, inference::MultiTensor *output); + // Status Predict(const inference::MultiTensor &inputs, inference::MultiTensor &output); + Status Predict(const PredictRequest &request, PredictReply &reply); Status Warmup(const MindSporeModelPtr model); Status Clear(); @@ -38,8 +49,8 @@ class Session { Session() = default; ~Session() = default; int sesseion_id_{0}; - std::shared_ptr session_{nullptr}; - FuncGraphPtr last_graph_{nullptr}; + std::shared_ptr session_{nullptr}; + bool model_loaded_ = false; uint32_t graph_id_{0}; std::mutex mutex_; std::string device_type_; diff --git a/serving/core/serving_tensor.cc b/serving/core/serving_tensor.cc new file mode 100644 index 0000000000..72225f3754 --- /dev/null +++ b/serving/core/serving_tensor.cc @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "core/serving_tensor.h" +#include +#include +#include +#include +#include "include/infer_log.h" + +using std::string; +using std::unordered_map; +using std::vector; + +namespace mindspore { +namespace serving { + +using inference::DataType; +using inference::InferTensorBase; + +const size_t kMaxShapeElementCount = INT32_MAX; +const size_t kMaxDataBufferSize = UINT32_MAX; + +ServingTensor::ServingTensor(ms_serving::Tensor &other) : tensor_(other) {} + +ServingTensor::~ServingTensor() {} + +DataType ServingTensor::data_type() const { + const std::unordered_map type2id_map{ + {ms_serving::MS_UNKNOWN, inference::kMSI_Unknown}, {ms_serving::MS_BOOL, inference::kMSI_Bool}, + {ms_serving::MS_INT8, inference::kMSI_Int8}, {ms_serving::MS_UINT8, inference::kMSI_Uint8}, + {ms_serving::MS_INT16, inference::kMSI_Int16}, {ms_serving::MS_UINT16, inference::kMSI_Uint16}, + {ms_serving::MS_INT32, inference::kMSI_Int32}, {ms_serving::MS_UINT32, inference::kMSI_Uint32}, + {ms_serving::MS_INT64, inference::kMSI_Int64}, {ms_serving::MS_UINT64, inference::kMSI_Uint64}, + {ms_serving::MS_FLOAT16, inference::kMSI_Float16}, {ms_serving::MS_FLOAT32, inference::kMSI_Float32}, + {ms_serving::MS_FLOAT64, inference::kMSI_Float64}, + }; + auto it = type2id_map.find(tensor_.tensor_type()); + if (it == type2id_map.end()) { + MSI_LOG_WARNING << "failed to get data type, undefined data type " << tensor_.tensor_type(); + return inference::kMSI_Unknown; + } else { + return it->second; + } +} + +void ServingTensor::set_data_type(DataType data_type) { + const std::unordered_map id2type_map{ + {inference::kMSI_Unknown, ms_serving::MS_UNKNOWN}, {inference::kMSI_Bool, ms_serving::MS_BOOL}, + {inference::kMSI_Float64, ms_serving::MS_FLOAT64}, {inference::kMSI_Int8, ms_serving::MS_INT8}, + {inference::kMSI_Uint8, ms_serving::MS_UINT8}, {inference::kMSI_Int16, ms_serving::MS_INT16}, + {inference::kMSI_Uint16, ms_serving::MS_UINT16}, {inference::kMSI_Int32, ms_serving::MS_INT32}, + {inference::kMSI_Uint32, ms_serving::MS_UINT32}, {inference::kMSI_Int64, ms_serving::MS_INT64}, + {inference::kMSI_Uint64, ms_serving::MS_UINT64}, {inference::kMSI_Float16, ms_serving::MS_FLOAT16}, + {inference::kMSI_Float32, ms_serving::MS_FLOAT32}, + }; + auto it = id2type_map.find(data_type); + if (it == id2type_map.end()) { + MSI_LOG_WARNING << "failed to set data type, undefined data type " << data_type; + tensor_.set_tensor_type(ms_serving::MS_UNKNOWN); + } else { + tensor_.set_tensor_type(it->second); + } +} + +std::vector ServingTensor::shape() const { + std::vector result; + auto dims = tensor_.tensor_shape().dims(); + std::transform(dims.begin(), dims.end(), std::back_inserter(result), [](const int64_t dim) { return dim; }); + return result; +} + +void ServingTensor::set_shape(const std::vector &shape) { + auto tensor_shape = tensor_.mutable_tensor_shape(); + tensor_shape->Clear(); + size_t element_count = 1; + for (auto dim : shape) { + if (dim <= 0 || element_count > kMaxShapeElementCount / dim) { + MSI_LOG_ERROR << "failed to set shape, invalid dim num " << dim; + tensor_shape->Clear(); + return; + } + element_count *= dim; + tensor_shape->add_dims(dim); + } +} + +bool ServingTensor::resize_data(size_t data_len) { + string *buffer = tensor_.mutable_data(); + if (buffer == nullptr) { + MSI_LOG_ERROR << "invalid buffer data"; + return false; + } + buffer->resize(data_len); + return true; +} + +size_t ServingTensor::data_size() const { return tensor_.data().size(); } + +void *ServingTensor::mutable_data() { return const_cast(tensor_.mutable_data()->data()); } + +const void *ServingTensor::data() const { return tensor_.data().data(); } + +ServingRequest::ServingRequest(const ms_serving::PredictRequest &request) : request_(request) { + auto &data = request_.data(); + std::transform(data.begin(), data.end(), std::back_inserter(cache_), + [](const ms_serving::Tensor &item) { return ServingTensor(const_cast(item)); }); +} + +size_t ServingRequest::size() const { return cache_.size(); } + +const InferTensorBase *ServingRequest::operator[](size_t index) const { + if (index >= cache_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size(); + return nullptr; + } + return &(cache_[index]); +} + +ServingImages::ServingImages(const ms_serving::Images &images) : images_(images) {} + +size_t ServingImages::batch_size() const { return images_.images_size(); } + +bool ServingImages::get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const { + if (index >= static_cast(images_.images_size())) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << images_.images_size(); + return false; + } + pic_buffer = images_.images(index).data(); + pic_size = images_.images(index).size(); + return true; +} + +size_t ServingImages::input_index() const { return static_cast(images_.input_index()); } + +size_t ServingReply::size() const { return cache_.size(); } + +InferTensorBase *ServingReply::operator[](size_t index) { + if (index >= cache_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size(); + return nullptr; + } + return &(cache_[index]); +} + +const InferTensorBase *ServingReply::operator[](size_t index) const { + if (index >= cache_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size(); + return nullptr; + } + return &(cache_[index]); +} + +InferTensorBase *ServingReply::add() { + auto new_item = reply_.add_result(); + if (new_item == nullptr) { + MSI_LOG_ERROR << "add new item failed, current total size " << cache_.size(); + return nullptr; + } + cache_.push_back(ServingTensor(*new_item)); + return &(cache_.back()); +} + +void ServingReply::clear() { reply_.mutable_result()->Clear(); } + +ServingImagesRequest::ServingImagesRequest(const ms_serving::PredictRequest &request) : request_(request) { + auto &images_inputs = request_.images(); + std::transform(images_inputs.begin(), images_inputs.end(), std::back_inserter(cache_), + [](const ms_serving::Images &item) { return ServingImages(const_cast(item)); }); +} + +size_t ServingImagesRequest::size() const { return cache_.size(); } + +const inference::InferImagesBase *ServingImagesRequest::operator[](size_t index) const { + if (index >= cache_.size()) { + MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size(); + return nullptr; + } + return &(cache_[index]); +} + +} // namespace serving +} // namespace mindspore diff --git a/serving/core/serving_tensor.h b/serving/core/serving_tensor.h new file mode 100644 index 0000000000..b55112fd64 --- /dev/null +++ b/serving/core/serving_tensor.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_TENSOR_H_ +#define MINDSPORE_SERVING_TENSOR_H_ + +#include +#include +#include +#include "include/infer_tensor.h" +#include "serving/ms_service.pb.h" + +namespace mindspore { +namespace serving { + +class MS_API ServingTensor : public inference::InferTensorBase { + public: + // the other's lifetime must longer than this object + explicit ServingTensor(ms_serving::Tensor &other); + ~ServingTensor(); + + inference::DataType data_type() const override; + void set_data_type(inference::DataType type) override; + std::vector shape() const override; + void set_shape(const std::vector &shape) override; + const void *data() const override; + size_t data_size() const override; + bool resize_data(size_t data_len) override; + void *mutable_data() override; + + private: + // if tensor_ is reference from other ms_serving::Tensor, the other's lifetime must + // longer than this object + ms_serving::Tensor &tensor_; +}; + +class ServingImages : public inference::InferImagesBase { + public: + explicit ServingImages(const ms_serving::Images &images); + + size_t batch_size() const override; + bool get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const override; + size_t input_index() const override; + + private: + const ms_serving::Images &images_; +}; + +class ServingRequest : public inference::RequestBase { + public: + explicit ServingRequest(const ms_serving::PredictRequest &request); + + size_t size() const override; + const inference::InferTensorBase *operator[](size_t index) const override; + + private: + const ms_serving::PredictRequest &request_; + std::vector cache_; +}; + +class ServingReply : public inference::ReplyBase { + public: + explicit ServingReply(ms_serving::PredictReply &reply) : reply_(reply) {} + + size_t size() const override; + inference::InferTensorBase *operator[](size_t index) override; + const inference::InferTensorBase *operator[](size_t index) const override; + inference::InferTensorBase *add() override; + void clear() override; + + private: + ms_serving::PredictReply &reply_; + std::vector cache_; +}; + +class ServingImagesRequest : public inference::ImagesRequestBase { + public: + explicit ServingImagesRequest(const ms_serving::PredictRequest &request); + + size_t size() const override; + const inference::InferImagesBase *operator[](size_t index) const override; + + private: + const ms_serving::PredictRequest &request_; + std::vector cache_; +}; + +} // namespace serving +} // namespace mindspore +#endif // MINDSPORE_SERVING_TENSOR_H_ diff --git a/serving/core/util/file_system_operation.cc b/serving/core/util/file_system_operation.cc index 1af512a54c..66bbde3414 100644 --- a/serving/core/util/file_system_operation.cc +++ b/serving/core/util/file_system_operation.cc @@ -25,43 +25,10 @@ #include #include #include -#include "mindspore/ccsrc/utils/log_adapter.h" +#include "include/infer_log.h" namespace mindspore { namespace serving { -char *ReadFile(const char *file, size_t *size) { - if (file == nullptr) { - MS_LOG(ERROR) << "file is nullptr"; - return nullptr; - } - MS_ASSERT(size != nullptr); - std::string realPath = file; - std::ifstream ifs(realPath); - if (!ifs.good()) { - MS_LOG(ERROR) << "file: " << realPath << " is not exist"; - return nullptr; - } - - if (!ifs.is_open()) { - MS_LOG(ERROR) << "file: " << realPath << "open failed"; - return nullptr; - } - - ifs.seekg(0, std::ios::end); - *size = ifs.tellg(); - std::unique_ptr buf(new (std::nothrow) char[*size]); - if (buf == nullptr) { - MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; - ifs.close(); - return nullptr; - } - - ifs.seekg(0, std::ios::beg); - ifs.read(buf.get(), *size); - ifs.close(); - - return buf.release(); -} bool DirOrFileExist(const std::string &file_path) { int ret = access(file_path.c_str(), 0); @@ -74,7 +41,7 @@ std::vector GetAllSubDirs(const std::string &dir_path) { std::vector SubDirs; if ((dir = opendir(dir_path.c_str())) == NULL) { - MS_LOG(ERROR) << "Open " << dir_path << " error!"; + MSI_LOG(ERROR) << "Open " << dir_path << " error!"; return std::vector(); } diff --git a/serving/core/util/option_parser.cc b/serving/core/util/option_parser.cc index c7f00e3733..df2047c30f 100644 --- a/serving/core/util/option_parser.cc +++ b/serving/core/util/option_parser.cc @@ -19,10 +19,11 @@ #include #include #include -#include "mindspore/ccsrc/utils/log_adapter.h" +#include "include/infer_log.h" namespace mindspore { namespace serving { + bool StartWith(const std::string &str, const std::string &expected) { return expected.empty() || (str.size() >= expected.size() && memcmp(str.data(), expected.data(), expected.size()) == 0); diff --git a/serving/core/util/status.h b/serving/core/util/status.h index 5f97f9b0b7..e1f1df6874 100644 --- a/serving/core/util/status.h +++ b/serving/core/util/status.h @@ -15,10 +15,14 @@ */ #ifndef MINDSPORE_STATUS_H #define MINDSPORE_STATUS_H +#include "include/inference.h" + namespace mindspore { namespace serving { -using Status = uint32_t; -enum ServingStatus { SUCCESS = 0, FAILED }; +using inference::Status; +using inference::SUCCESS; +using inference::FAILED; +using inference::INVALID_INPUTS; } // namespace serving } // namespace mindspore diff --git a/serving/core/version_control/model.cc b/serving/core/version_control/model.cc index 8e3942b926..a656e9af91 100644 --- a/serving/core/version_control/model.cc +++ b/serving/core/version_control/model.cc @@ -15,18 +15,19 @@ */ #include "core/version_control/model.h" #include -#include "mindspore/ccsrc/utils/log_adapter.h" +#include "include/infer_log.h" namespace mindspore { namespace serving { + MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path, const std::string &model_version, const time_t &last_update_time) : model_name_(model_name), model_path_(model_path), model_version_(model_version), last_update_time_(last_update_time) { - MS_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_ - << ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_; + MSI_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_ + << ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_; } } // namespace serving } // namespace mindspore diff --git a/serving/core/version_control/version_controller.cc b/serving/core/version_control/version_controller.cc index 71aba923d5..1dc6b1b2bb 100644 --- a/serving/core/version_control/version_controller.cc +++ b/serving/core/version_control/version_controller.cc @@ -20,11 +20,12 @@ #include #include #include "util/file_system_operation.h" -#include "mindspore/ccsrc/utils/log_adapter.h" +#include "include/infer_log.h" #include "core/server.h" namespace mindspore { namespace serving { + volatile bool stop_poll = false; std::string GetVersionFromPath(const std::string &path) { @@ -96,7 +97,7 @@ Status VersionController::Run() { Status VersionController::CreateInitModels() { if (!DirOrFileExist(models_path_)) { - MS_LOG(ERROR) << "Model Path Not Exist!" << std::endl; + MSI_LOG(ERROR) << "Model Path Not Exist!" << std::endl; return FAILED; } std::vector SubDirs = GetAllSubDirs(models_path_); @@ -115,7 +116,7 @@ Status VersionController::CreateInitModels() { } } if (valid_models_.empty()) { - MS_LOG(ERROR) << "There is no valid model for serving"; + MSI_LOG(ERROR) << "There is no valid model for serving"; return FAILED; } auto ret = Session::Instance().Warmup(valid_models_.back()); diff --git a/serving/cpp_example/CMakeLists.txt b/serving/cpp_example/CMakeLists.txt deleted file mode 100644 index aaf0277880..0000000000 --- a/serving/cpp_example/CMakeLists.txt +++ /dev/null @@ -1,71 +0,0 @@ -cmake_minimum_required(VERSION 3.5.1) - -project(HelloWorld C CXX) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) - -find_package(Threads REQUIRED) - -# This branch assumes that gRPC and all its dependencies are already installed - # on this system, so they can be located by find_package(). - - # Find Protobuf installation - # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. - set(protobuf_MODULE_COMPATIBLE TRUE) - find_package(Protobuf CONFIG REQUIRED) - message(STATUS "Using protobuf ${protobuf_VERSION}") - - set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) - set(_REFLECTION gRPC::grpc++_reflection) - if(CMAKE_CROSSCOMPILING) - find_program(_PROTOBUF_PROTOC protoc) - else() - set(_PROTOBUF_PROTOC $) - endif() - - # Find gRPC installation - # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. - find_package(gRPC CONFIG REQUIRED) - message(STATUS "Using gRPC ${gRPC_VERSION}") - - set(_GRPC_GRPCPP gRPC::grpc++) - if(CMAKE_CROSSCOMPILING) - find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) - else() - set(_GRPC_CPP_PLUGIN_EXECUTABLE $) - endif() - -# Proto file -get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE) -get_filename_component(hw_proto_path "${hw_proto}" PATH) - -# Generated sources -set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc") -set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h") -set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc") -set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h") -add_custom_command( - OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" - COMMAND ${_PROTOBUF_PROTOC} - ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" - --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" - -I "${hw_proto_path}" - --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" - "${hw_proto}" - DEPENDS "${hw_proto}") - -# Include generated *.pb.h files -include_directories("${CMAKE_CURRENT_BINARY_DIR}") - -# Targets greeter_[async_](client|server) -foreach(_target - ms_client ms_server) - add_executable(${_target} "${_target}.cc" - ${hw_proto_srcs} - ${hw_grpc_srcs}) - target_link_libraries(${_target} - ${_REFLECTION} - ${_GRPC_GRPCPP} - ${_PROTOBUF_LIBPROTOBUF}) -endforeach() diff --git a/serving/cpp_example/ms_client.cc b/serving/cpp_example/ms_client.cc deleted file mode 100644 index 3a9cac77e4..0000000000 --- a/serving/cpp_example/ms_client.cc +++ /dev/null @@ -1,323 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include "./ms_service.grpc.pb.h" - -using grpc::Channel; -using grpc::ClientContext; -using grpc::Status; -using ms_serving::MSService; -using ms_serving::PredictReply; -using ms_serving::PredictRequest; -using ms_serving::Tensor; -using ms_serving::TensorShape; - -enum TypeId : int { - kTypeUnknown = 0, - kMetaTypeBegin = kTypeUnknown, - kMetaTypeType, // Type - kMetaTypeAnything, - kMetaTypeObject, - kMetaTypeTypeType, // TypeType - kMetaTypeProblem, - kMetaTypeExternal, - kMetaTypeNone, - kMetaTypeNull, - kMetaTypeEllipsis, - kMetaTypeEnd, - // - // Object types - // - kObjectTypeBegin = kMetaTypeEnd, - kObjectTypeNumber, - kObjectTypeString, - kObjectTypeList, - kObjectTypeTuple, - kObjectTypeSlice, - kObjectTypeKeyword, - kObjectTypeTensorType, - kObjectTypeClass, - kObjectTypeDictionary, - kObjectTypeFunction, - kObjectTypeJTagged, - kObjectTypeSymbolicKeyType, - kObjectTypeEnvType, - kObjectTypeRefKey, - kObjectTypeRef, - kObjectTypeEnd, - // - // Number Types - // - kNumberTypeBegin = kObjectTypeEnd, - kNumberTypeBool, - kNumberTypeInt, - kNumberTypeInt8, - kNumberTypeInt16, - kNumberTypeInt32, - kNumberTypeInt64, - kNumberTypeUInt, - kNumberTypeUInt8, - kNumberTypeUInt16, - kNumberTypeUInt32, - kNumberTypeUInt64, - kNumberTypeFloat, - kNumberTypeFloat16, - kNumberTypeFloat32, - kNumberTypeFloat64, - kNumberTypeEnd -}; - -std::string RealPath(const char *path) { - if (path == nullptr) { - std::cout << "path is nullptr"; - return ""; - } - if ((strlen(path)) >= PATH_MAX) { - std::cout << "path is too long"; - return ""; - } - - std::shared_ptr resolvedPath(new (std::nothrow) char[PATH_MAX]{0}); - if (resolvedPath == nullptr) { - std::cout << "new resolvedPath failed"; - return ""; - } - - auto ret = realpath(path, resolvedPath.get()); - if (ret == nullptr) { - std::cout << "realpath failed"; - return ""; - } - return resolvedPath.get(); -} - -char *ReadFile(const char *file, size_t *size) { - if (file == nullptr) { - std::cout << "file is nullptr" << std::endl; - return nullptr; - } - if (size == nullptr) { - std::cout << "size should not be nullptr" << std::endl; - return nullptr; - } - std::ifstream ifs(RealPath(file)); - if (!ifs.good()) { - std::cout << "file: " << file << "is not exist"; - return nullptr; - } - - if (!ifs.is_open()) { - std::cout << "file: " << file << "open failed"; - return nullptr; - } - - ifs.seekg(0, std::ios::end); - *size = ifs.tellg(); - std::unique_ptr buf(new (std::nothrow) char[*size]); - if (buf == nullptr) { - std::cout << "malloc buf failed, file: " << file; - ifs.close(); - return nullptr; - } - - ifs.seekg(0, std::ios::beg); - ifs.read(buf.get(), *size); - ifs.close(); - - return buf.release(); -} -const std::map id2type_map{ - {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL}, - {TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8}, - {TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16}, - {TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32}, - {TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64}, - {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32}, - {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64}, -}; - -int WriteFile(const void *buf, size_t size) { - auto fd = fopen("output.json", "a+"); - if (fd == NULL) { - std::cout << "fd is null and open file fail" << std::endl; - return 0; - } - fwrite(buf, size, 1, fd); - fclose(fd); - return 0; -} - -PredictRequest ReadBertInput() { - size_t size; - auto buf = ReadFile("input206.json", &size); - if (buf == nullptr) { - std::cout << "read file failed" << std::endl; - return PredictRequest(); - } - PredictRequest request; - auto cur = buf; - while (size > 0) { - if (request.data_size() == 4) { - break; - } - Tensor data; - TensorShape shape; - // set type - int type = *(reinterpret_cast(cur)); - cur = cur + sizeof(int); - size = size - sizeof(int); - ms_serving::DataType dataType = id2type_map.at(TypeId(type)); - data.set_tensor_type(dataType); - - // set shape - size_t dims = *(reinterpret_cast(cur)); - cur = cur + sizeof(size_t); - size = size - sizeof(size_t); - - for (size_t i = 0; i < dims; i++) { - int dim = *(reinterpret_cast(cur)); - shape.add_dims(dim); - cur = cur + sizeof(int); - size = size - sizeof(int); - } - *data.mutable_tensor_shape() = shape; - - // set data - size_t data_len = *(reinterpret_cast(cur)); - cur = cur + sizeof(size_t); - size = size - sizeof(size_t); - data.set_data(cur, data_len); - cur = cur + data_len; - size = size - data_len; - *request.add_data() = data; - } - return request; -} - -class MSClient { - public: - explicit MSClient(std::shared_ptr channel) : stub_(MSService::NewStub(channel)) {} - ~MSClient() = default; - - std::string Predict(const std::string &type) { - // Data we are sending to the server. - PredictRequest request; - if (type == "add") { - Tensor data; - TensorShape shape; - shape.add_dims(1); - shape.add_dims(1); - shape.add_dims(2); - shape.add_dims(2); - *data.mutable_tensor_shape() = shape; - data.set_tensor_type(ms_serving::MS_FLOAT32); - std::vector input_data{1.1, 2.1, 3.1, 4.1}; - data.set_data(input_data.data(), input_data.size()); - *request.add_data() = data; - *request.add_data() = data; - } else if (type == "bert") { - request = ReadBertInput(); - } else { - std::cout << "type only support bert or add, but input is " << type << std::endl; - } - std::cout << "intput tensor size is " << request.data_size() << std::endl; - // Container for the data we expect from the server. - PredictReply reply; - - // Context for the client. It could be used to convey extra information to - // the server and/or tweak certain RPC behaviors. - ClientContext context; - - // The actual RPC. - Status status = stub_->Predict(&context, request, &reply); - - for (int i = 0; i < reply.result_size(); i++) { - WriteFile(reply.result(i).data().data(), reply.result(i).data().size()); - } - - std::cout << "the return result size is " << reply.result_size() << std::endl; - - // Act upon its status. - if (status.ok()) { - return "RPC OK"; - } else { - std::cout << status.error_code() << ": " << status.error_message() << std::endl; - return "RPC failed"; - } - } - - private: - std::unique_ptr stub_; -}; - -int main(int argc, char **argv) { - // Instantiate the client. It requires a channel, out of which the actual RPCs - // are created. This channel models a connection to an endpoint specified by - // the argument "--target=" which is the only expected argument. - // We indicate that the channel isn't authenticated (use of - // InsecureChannelCredentials()). - std::string target_str; - std::string arg_target_str("--target"); - std::string type; - std::string arg_type_str("--type"); - if (argc > 2) { - { - // parse target - std::string arg_val = argv[1]; - size_t start_pos = arg_val.find(arg_target_str); - if (start_pos != std::string::npos) { - start_pos += arg_target_str.size(); - if (arg_val[start_pos] == '=') { - target_str = arg_val.substr(start_pos + 1); - } else { - std::cout << "The only correct argument syntax is --target=" << std::endl; - return 0; - } - } else { - target_str = "localhost:5500"; - } - } - - { - // parse type - std::string arg_val2 = argv[2]; - size_t start_pos = arg_val2.find(arg_type_str); - if (start_pos != std::string::npos) { - start_pos += arg_type_str.size(); - if (arg_val2[start_pos] == '=') { - type = arg_val2.substr(start_pos + 1); - } else { - std::cout << "The only correct argument syntax is --target=" << std::endl; - return 0; - } - } else { - type = "add"; - } - } - } else { - target_str = "localhost:5500"; - type = "add"; - } - MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); - std::string reply = client.Predict(type); - std::cout << "client received: " << reply << std::endl; - - return 0; -} diff --git a/serving/cpp_example/ms_server.cc b/serving/cpp_example/ms_server.cc deleted file mode 100644 index f6021ef000..0000000000 --- a/serving/cpp_example/ms_server.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include - -#include "./ms_service.grpc.pb.h" - -using grpc::Server; -using grpc::ServerBuilder; -using grpc::ServerContext; -using grpc::Status; -using ms_serving::MSService; -using ms_serving::PredictReply; -using ms_serving::PredictRequest; - -// Logic and data behind the server's behavior. -class MSServiceImpl final : public MSService::Service { - Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override { - std::cout << "server eval" << std::endl; - return Status::OK; - } -}; - -void RunServer() { - std::string server_address("0.0.0.0:50051"); - MSServiceImpl service; - - grpc::EnableDefaultHealthCheckService(true); - grpc::reflection::InitProtoReflectionServerBuilderPlugin(); - auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); - - ServerBuilder builder; - builder.SetOption(std::move(option)); - // Listen on the given address without any authentication mechanism. - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - // Register "service" as the instance through which we'll communicate with - // clients. In this case it corresponds to an *synchronous* service. - builder.RegisterService(&service); - // Finally assemble the server. - std::unique_ptr server(builder.BuildAndStart()); - std::cout << "Server listening on " << server_address << std::endl; - - // Wait for the server to shutdown. Note that some other thread must be - // responsible for shutting down the server for this call to ever return. - server->Wait(); -} - -int main(int argc, char **argv) { - RunServer(); - - return 0; -} diff --git a/serving/example/cpp_client/CMakeLists.txt b/serving/example/cpp_client/CMakeLists.txt new file mode 100644 index 0000000000..44df268e58 --- /dev/null +++ b/serving/example/cpp_client/CMakeLists.txt @@ -0,0 +1,87 @@ +cmake_minimum_required(VERSION 3.5.1) + +project(MSClient C CXX) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +find_package(Threads REQUIRED) + +# This branch assumes that gRPC and all its dependencies are already installed +# on this system, so they can be located by find_package(). + +# Find Protobuf installation +# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. +option(GRPC_PATH "set grpc path") +if(GRPC_PATH) + set(CMAKE_PREFIX_PATH ${GRPC_PATH}) + set(protobuf_MODULE_COMPATIBLE TRUE) + find_package(Protobuf CONFIG REQUIRED) + message(STATUS "Using protobuf ${protobuf_VERSION}, CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}") +elseif(NOT GRPC_PATH) + if (EXISTS ${grpc_ROOT}/lib64) + set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc") + elseif(EXISTS ${grpc_ROOT}/lib) + set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc") + endif() + add_library(protobuf::libprotobuf ALIAS protobuf::protobuf) + add_executable(protobuf::libprotoc ALIAS protobuf::protoc) + message(STATUS "serving using grpc_DIR : " ${gRPC_DIR}) +elseif(NOT gRPC_DIR AND NOT GRPC_PATH) + message("please check gRPC. If the client is compiled separately,you can use the command: cmake -D GRPC_PATH=xxx") + message("XXX is the gRPC installation path") +endif() + +set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) +set(_REFLECTION gRPC::grpc++_reflection) +if(CMAKE_CROSSCOMPILING) + find_program(_PROTOBUF_PROTOC protoc) +else() + set(_PROTOBUF_PROTOC $) +endif() + +# Find gRPC installation +# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. +find_package(gRPC CONFIG REQUIRED) +message(STATUS "Using gRPC ${gRPC_VERSION}") + +set(_GRPC_GRPCPP gRPC::grpc++) +if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) +else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) +endif() + +# Proto file +get_filename_component(hw_proto "../../ms_service.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) + +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h") +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") + +# Include generated *.pb.h files +include_directories("${CMAKE_CURRENT_BINARY_DIR}") + +# Targets greeter_[async_](client|server) +foreach(_target + ms_client) + add_executable(${_target} "${_target}.cc" + ${hw_proto_srcs} + ${hw_grpc_srcs}) + target_link_libraries(${_target} + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) +endforeach() diff --git a/serving/example/cpp_client/ms_client.cc b/serving/example/cpp_client/ms_client.cc new file mode 100644 index 0000000000..720f1cf704 --- /dev/null +++ b/serving/example/cpp_client/ms_client.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include "./ms_service.grpc.pb.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using ms_serving::MSService; +using ms_serving::PredictReply; +using ms_serving::PredictRequest; +using ms_serving::Tensor; +using ms_serving::TensorShape; + +class MSClient { + public: + explicit MSClient(std::shared_ptr channel) : stub_(MSService::NewStub(channel)) {} + + ~MSClient() = default; + + std::string Predict() { + // Data we are sending to the server. + PredictRequest request; + + Tensor data; + TensorShape shape; + shape.add_dims(4); + *data.mutable_tensor_shape() = shape; + data.set_tensor_type(ms_serving::MS_FLOAT32); + std::vector input_data{1, 2, 3, 4}; + data.set_data(input_data.data(), input_data.size() * sizeof(float)); + *request.add_data() = data; + *request.add_data() = data; + std::cout << "intput tensor size is " << request.data_size() << std::endl; + // Container for the data we expect from the server. + PredictReply reply; + + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + ClientContext context; + + // The actual RPC. + Status status = stub_->Predict(&context, request, &reply); + std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl; + + // Act upon its status. + if (status.ok()) { + std::cout << "Add result is"; + for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) { + std::cout << " " << (reinterpret_cast(reply.mutable_result(0)->mutable_data()->data()))[i]; + } + std::cout << std::endl; + return "RPC OK"; + } else { + std::cout << status.error_code() << ": " << status.error_message() << std::endl; + return "RPC failed"; + } + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char **argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint specified by + // the argument "--target=" which is the only expected argument. + // We indicate that the channel isn't authenticated (use of + // InsecureChannelCredentials()). + std::string target_str; + std::string arg_target_str("--target"); + if (argc > 1) { + // parse target + std::string arg_val = argv[1]; + size_t start_pos = arg_val.find(arg_target_str); + if (start_pos != std::string::npos) { + start_pos += arg_target_str.size(); + if (arg_val[start_pos] == '=') { + target_str = arg_val.substr(start_pos + 1); + } else { + std::cout << "The only correct argument syntax is --target=" << std::endl; + return 0; + } + } else { + target_str = "localhost:5500"; + } + } else { + target_str = "localhost:5500"; + } + MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); + std::string reply = client.Predict(); + std::cout << "client received: " << reply << std::endl; + + return 0; +} diff --git a/serving/example/export_model/add_model.py b/serving/example/export_model/add_model.py new file mode 100644 index 0000000000..caf354cc5a --- /dev/null +++ b/serving/example/export_model/add_model.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore.train.serialization import export + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.add = P.TensorAdd() + + def construct(self, x_, y_): + return self.add(x_, y_) + +x = np.ones(4).astype(np.float32) +y = np.ones(4).astype(np.float32) + +def export_net(): + add = Net() + output = add(Tensor(x), Tensor(y)) + export(add, Tensor(x), Tensor(y), file_name='tensor_add.pb', file_format='BINARY') + print(x) + print(y) + print(output.asnumpy()) + +if __name__ == "__main__": + export_net() + \ No newline at end of file diff --git a/serving/example/python_client/ms_client.py b/serving/example/python_client/ms_client.py new file mode 100644 index 0000000000..8ea64916d9 --- /dev/null +++ b/serving/example/python_client/ms_client.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import grpc +import numpy as np +import ms_service_pb2 +import ms_service_pb2_grpc + + +def run(): + if len(sys.argv) > 2: + sys.exit("input error") + channel_str = "" + if len(sys.argv) == 2: + split_args = sys.argv[1].split('=') + if len(split_args) > 1: + channel_str = split_args[1] + else: + channel_str = 'localhost:5500' + else: + channel_str = 'localhost:5500' + + channel = grpc.insecure_channel(channel_str) + stub = ms_service_pb2_grpc.MSServiceStub(channel) + request = ms_service_pb2.PredictRequest() + + x = request.data.add() + x.tensor_shape.dims.extend([4]) + x.tensor_type = ms_service_pb2.MS_FLOAT32 + x.data = (np.ones([4]).astype(np.float32)).tobytes() + + y = request.data.add() + y.tensor_shape.dims.extend([4]) + y.tensor_type = ms_service_pb2.MS_FLOAT32 + y.data = (np.ones([4]).astype(np.float32)).tobytes() + + try: + result = stub.Predict(request) + print(result) + result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) + print("ms client received: ") + print(result_np) + except grpc.RpcError as e: + print(e.details()) + status_code = e.code() + print(status_code.name) + print(status_code.value) + +if __name__ == '__main__': + run() diff --git a/serving/ms_service.proto b/serving/ms_service.proto index f11bc235e7..26cedad91b 100644 --- a/serving/ms_service.proto +++ b/serving/ms_service.proto @@ -20,17 +20,19 @@ syntax = "proto3"; package ms_serving; service MSService { - rpc Predict(PredictRequest) returns (PredictReply) {} - rpc Test(PredictRequest) returns (PredictReply) {} + rpc Predict(PredictRequest) returns (PredictReply) {} + rpc Test(PredictRequest) returns (PredictReply) {} } message PredictRequest { - repeated Tensor data = 1; + repeated Tensor data = 1; + repeated Images images = 2; } message PredictReply { - repeated Tensor result = 1; + repeated Tensor result = 1; } + enum DataType { MS_UNKNOWN = 0; MS_BOOL = 1; @@ -62,3 +64,7 @@ message Tensor { bytes data = 3; } +message Images{ + repeated bytes images = 1; + uint32 input_index = 2; +} diff --git a/serving/python_example/ms_client.py b/serving/python_example/ms_client.py deleted file mode 100644 index d567d089b8..0000000000 --- a/serving/python_example/ms_client.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import grpc -import numpy as np -import ms_service_pb2 -import ms_service_pb2_grpc - - -def run(): - channel = grpc.insecure_channel('localhost:50051') - stub = ms_service_pb2_grpc.MSServiceStub(channel) - # request = ms_service_pb2.PredictRequest() - # request.name = 'haha' - # response = stub.Eval(request) - # print("ms client received: " + response.message) - - request = ms_service_pb2.PredictRequest() - request.data.tensor_shape.dims.extend([32, 1, 32, 32]) - request.data.tensor_type = ms_service_pb2.MS_FLOAT32 - request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes() - - request.label.tensor_shape.dims.extend([32]) - request.label.tensor_type = ms_service_pb2.MS_INT32 - request.label.data = np.ones([32]).astype(np.int32).tobytes() - - result = stub.Predict(request) - #result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims) - print("ms client received: ") - #print(result_np) - - # future_list = [] - # times = 1000 - # for i in range(times): - # async_future = stub.Eval.future(request) - # future_list.append(async_future) - # print("async call, future list add item " + str(i)); - # - # for i in range(len(future_list)): - # async_result = future_list[i].result() - # print("ms client async get result of item " + str(i)) - - - -if __name__ == '__main__': - run() diff --git a/serving/python_example/ms_client_test_call.py b/serving/python_example/ms_client_test_call.py deleted file mode 100644 index 56643d8351..0000000000 --- a/serving/python_example/ms_client_test_call.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import grpc -import numpy as np -import ms_service_pb2 -import ms_service_pb2_grpc - - -def run(): - channel = grpc.insecure_channel('localhost:50051') - stub = ms_service_pb2_grpc.MSServiceStub(channel) - # request = ms_service_pb2.EvalRequest() - # request.name = 'haha' - # response = stub.Eval(request) - # print("ms client received: " + response.message) - - request = ms_service_pb2.PredictRequest() - request.data.tensor_shape.dims.extend([32, 1, 32, 32]) - request.data.tensor_type = ms_service_pb2.MS_FLOAT32 - request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes() - - request.label.tensor_shape.dims.extend([32]) - request.label.tensor_type = ms_service_pb2.MS_INT32 - request.label.data = np.ones([32]).astype(np.int32).tobytes() - - result = stub.Test(request) - #result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims) - print("ms client test call received: ") - #print(result_np) - - - -if __name__ == '__main__': - run() diff --git a/serving/python_example/ms_server.py b/serving/python_example/ms_server.py deleted file mode 100644 index f538856804..0000000000 --- a/serving/python_example/ms_server.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from concurrent import futures -import time -import grpc -import numpy as np -import ms_service_pb2 -import ms_service_pb2_grpc -import test_cpu_lenet -from mindspore import Tensor - -class MSService(ms_service_pb2_grpc.MSServiceServicer): - def Predict(self, request, context): - request_data = request.data - request_label = request.label - - data_from_buffer = np.frombuffer(request_data.data, dtype=np.float32) - data_from_buffer = data_from_buffer.reshape(request_data.tensor_shape.dims) - data = Tensor(data_from_buffer) - - label_from_buffer = np.frombuffer(request_label.data, dtype=np.int32) - label_from_buffer = label_from_buffer.reshape(request_label.tensor_shape.dims) - label = Tensor(label_from_buffer) - - result = test_cpu_lenet.test_lenet(data, label) - result_reply = ms_service_pb2.PredictReply() - result_reply.result.tensor_shape.dims.extend(result.shape()) - result_reply.result.data = result.asnumpy().tobytes() - return result_reply - -def serve(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) - ms_service_pb2_grpc.add_MSServiceServicer_to_server(MSService(), server) - server.add_insecure_port('[::]:50051') - server.start() - try: - while True: - time.sleep(60*60*24) # one day in seconds - except KeyboardInterrupt: - server.stop(0) - -if __name__ == '__main__': - serve() diff --git a/serving/python_example/test_cpu_lenet.py b/serving/python_example/test_cpu_lenet.py deleted file mode 100644 index a609c9b924..0000000000 --- a/serving/python_example/test_cpu_lenet.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Momentum -from mindspore.ops import operations as P -import ms_service_pb2 - - -class LeNet(nn.Cell): - def __init__(self): - super(LeNet, self).__init__() - self.relu = P.ReLU() - self.batch_size = 32 - - self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() - self.fc1 = nn.Dense(400, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, 10) - - def construct(self, input_x): - output = self.conv1(input_x) - output = self.relu(output) - output = self.pool(output) - output = self.conv2(output) - output = self.relu(output) - output = self.pool(output) - output = self.reshape(output, (self.batch_size, -1)) - output = self.fc1(output) - output = self.relu(output) - output = self.fc2(output) - output = self.relu(output) - output = self.fc3(output) - return output - - -def train(net, data, label): - learning_rate = 0.01 - momentum = 0.9 - - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) - criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer - train_network.set_train() - res = train_network(data, label) - print("+++++++++Loss+++++++++++++") - print(res) - print("+++++++++++++++++++++++++++") - assert res - return res - -def test_lenet(data, label): - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - net = LeNet() - return train(net, data, label) - -if __name__ == '__main__': - tensor = ms_service_pb2.Tensor() - tensor.tensor_shape.dim.extend([32, 1, 32, 32]) - # tensor.tensor_shape.dim.add() = 1 - # tensor.tensor_shape.dim.add() = 32 - # tensor.tensor_shape.dim.add() = 32 - tensor.tensor_type = ms_service_pb2.MS_FLOAT32 - tensor.data = np.ones([32, 1, 32, 32]).astype(np.float32).tobytes() - - data_from_buffer = np.frombuffer(tensor.data, dtype=np.float32) - print(tensor.tensor_shape.dim) - data_from_buffer = data_from_buffer.reshape(tensor.tensor_shape.dim) - print(data_from_buffer.shape) - input_data = Tensor(data_from_buffer * 0.01) - input_label = Tensor(np.ones([32]).astype(np.int32)) - test_lenet(input_data, input_label) diff --git a/setup.py b/setup.py index bf16c9106b..2836a24c31 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ from setuptools import setup, find_packages from setuptools.command.egg_info import egg_info from setuptools.command.build_py import build_py -version = '0.5.0' +version = '0.6.0' backend_policy = os.getenv('BACKEND_POLICY') commit_id = os.getenv('COMMIT_ID').replace("\n", "") @@ -146,7 +146,7 @@ class BuildPy(build_py): super().run() mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore') update_permissions(mindspore_dir) - mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', '_akg') + mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'akg') update_permissions(mindspore_dir) diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index 237e38a9d3..84f0ae00da 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -50,11 +50,18 @@ class MindData: def input_indexs(self): return self._input_indexs - def device_que(self): + def device_que(self, send_epoch_end=True): self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' + self.send_epoch_end = send_epoch_end return self - def send(self): + def create_tuple_iterator(self): + return self.__iter__() + + def send(self, num_epochs=-1): + pass + + def stop_send(self): pass def __len__(self): diff --git a/tests/mindspore_test_framework/apps/bert_attention_submodules.py b/tests/mindspore_test_framework/apps/bert_attention_submodules.py index 4ce72ffc84..83729d9e70 100644 --- a/tests/mindspore_test_framework/apps/bert_attention_submodules.py +++ b/tests/mindspore_test_framework/apps/bert_attention_submodules.py @@ -108,7 +108,7 @@ class BertAttentionRelativePositionKeys(nn.Cell): self.trans_shape_position = (1, 2, 0, 3) self.trans_shape_relative = (2, 0, 1, 3) - self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype) + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) self.reshape = P.Reshape() self.multiply = P.Mul() @@ -301,7 +301,7 @@ class BertAttentionRelativePositionValues(nn.Cell): self.trans_shape_position = (1, 2, 0, 3) self.trans_shape_relative = (2, 0, 1, 3) - self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype) + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) self.trans_shape = (0, 2, 1, 3) self.reshape = P.Reshape() diff --git a/tests/mindspore_test_framework/apps/test_lamb_check_loss.py b/tests/mindspore_test_framework/apps/test_lamb_check_loss.py index 4498959620..11e13261e6 100644 --- a/tests/mindspore_test_framework/apps/test_lamb_check_loss.py +++ b/tests/mindspore_test_framework/apps/test_lamb_check_loss.py @@ -30,7 +30,7 @@ verification_set = [ 'block': { 'model': network, 'loss': SquaredLoss(), - 'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01), + 'opt': Lamb(network.trainable_params(), 0.02, weight_decay=0.01), 'num_epochs': num_epochs, 'loss_upper_bound': 0.3, }, diff --git a/tests/mindspore_test_framework/pipeline/gradient/check_training.py b/tests/mindspore_test_framework/pipeline/gradient/check_training.py index 135b162ec7..61ed61af6d 100644 --- a/tests/mindspore_test_framework/pipeline/gradient/check_training.py +++ b/tests/mindspore_test_framework/pipeline/gradient/check_training.py @@ -31,7 +31,7 @@ Example: 'block': { 'model': network, 'loss': SquaredLoss(), - 'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01), + 'opt': Lamb(network.trainable_params(), lr=0.02, weight_decay=0.01), 'num_epochs': num_epochs, 'loss_upper_bound': 0.3, }, diff --git a/tests/mindspore_test_framework/utils/bprop_util.py b/tests/mindspore_test_framework/utils/bprop_util.py index 1fbbd4dded..1990c1d0df 100644 --- a/tests/mindspore_test_framework/utils/bprop_util.py +++ b/tests/mindspore_test_framework/utils/bprop_util.py @@ -37,7 +37,7 @@ class Bprop(Cell): self.grad = grad_op self.sens = sens self.with_sens = False - if sens: + if sens is not None: self.with_sens = True def construct(self, *inputs): @@ -71,10 +71,10 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list func.set_train() with_sens_param = False - if grads_wrt_outputs: + if grads_wrt_outputs is not None: with_sens_param = True - if not wrt: + if wrt is None: wrt = [] wrt_inputs = False if 'inputs' in wrt: diff --git a/tests/mindspore_test_framework/utils/facade_util.py b/tests/mindspore_test_framework/utils/facade_util.py index c039fa4e5e..7e50b6125a 100644 --- a/tests/mindspore_test_framework/utils/facade_util.py +++ b/tests/mindspore_test_framework/utils/facade_util.py @@ -63,7 +63,7 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex sampling_times, reduce_output, init_param_with, \ split_outputs, exception, error_keywords = get_function_config(block_config[-1]) - if block: + if block is not None: func_list.append({ keyword.id: tid, keyword.group: group, diff --git a/tests/perf_test/bert/test_bert_train.py b/tests/perf_test/bert/test_bert_train.py index 096571adea..705318c283 100644 --- a/tests/perf_test/bert/test_bert_train.py +++ b/tests/perf_test/bert/test_bert_train.py @@ -22,12 +22,15 @@ import os import mindspore.common.dtype as mstype import mindspore.context as context from mindspore import Tensor -from mindspore.nn.optim import AdamWeightDecayDynamicLR +from mindspore.ops import operations as P +from mindspore.nn.optim import AdamWeightDecay from mindspore.train.loss_scale_manager import DynamicLossScaleManager -from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from mindspore.nn import learning_rate_schedule as lr_schedules +from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from ...dataset_mock import MindData from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph + _current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data" context.set_context(mode=context.GRAPH_MODE) @@ -98,6 +101,25 @@ def get_config(version='base', batch_size=1): return BertConfig(batch_size=batch_size) +class BertLearningRate(lr_schedules.LearningRateSchedule): + def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr + + def test_bert_train(): """ the main function @@ -123,7 +145,8 @@ def test_bert_train(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=False) @@ -147,7 +170,8 @@ def test_bert_withlossscale_train(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=True) @@ -173,7 +197,8 @@ def bert_withlossscale_manager_train(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=True) @@ -200,7 +225,8 @@ def bert_withlossscale_manager_train_feed(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=True) diff --git a/tests/st/control/test_multigraph_sink.py b/tests/st/control/test_multigraph_sink.py index 9f7d24a80a..ba0f45b06c 100644 --- a/tests/st/control/test_multigraph_sink.py +++ b/tests/st/control/test_multigraph_sink.py @@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor def setup_module(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") c1 = Tensor([2], mstype.int32) diff --git a/tests/st/gnn/gcn/test_gcn.py b/tests/st/gnn/gcn/test_gcn.py index afe123cdb0..09789a00e5 100644 --- a/tests/st/gnn/gcn/test_gcn.py +++ b/tests/st/gnn/gcn/test_gcn.py @@ -17,10 +17,10 @@ import time import pytest import numpy as np from mindspore import context -from model_zoo.gcn.src.gcn import GCN -from model_zoo.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper -from model_zoo.gcn.src.config import ConfigGCN -from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask +from model_zoo.official.gnn.gcn.src.gcn import GCN +from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper +from model_zoo.official.gnn.gcn.src.config import ConfigGCN +from model_zoo.official.gnn.gcn.src.dataset import get_adj_features_labels, get_mask DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr' diff --git a/tests/st/heterogeneous_excutor/test_control.py b/tests/st/heterogeneous_excutor/test_control.py new file mode 100644 index 0000000000..189441f1f9 --- /dev/null +++ b/tests/st/heterogeneous_excutor/test_control.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net1(nn.Cell): + def __init__(self): + super(Net1, self).__init__() + self.relu1 = P.ReLU() + self.relu2 = P.ReLU() + self.mul = P.Mul() + self.control = P.ControlDepend() + + def construct(self, x, y): + a = self.relu1(x) + b = self.relu2(y) + c = self.mul(a, b) + e = self.control(a, b) + return c, e + + +class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.relu1 = P.ReLU() + self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU") + self.mul = P.Mul() + self.control = P.ControlDepend() + + def construct(self, x, y): + a = self.relu1(x) + b = self.relu2(y) + c = self.mul(a, b) + e = self.control(a, b) + return c, e + + +def test_net(): + x = np.random.randn(2, 3, 3, 4).astype(np.float32) + y = np.random.randn(2, 3, 3, 4).astype(np.float32) + net1 = Net1() + output1 = net1(Tensor(x), Tensor(y)) + + context.set_context(save_graphs=True) + net2 = Net2() + output2 = net2(Tensor(x), Tensor(y)) + assert np.allclose(output1[0].asnumpy(), output2[0].asnumpy()) + print("##success##") + + +if __name__ == "__main__": + test_net() diff --git a/tests/st/host_device/test_host_device_lenet.py b/tests/st/host_device/test_host_device_lenet.py new file mode 100644 index 0000000000..0a312a3422 --- /dev/null +++ b/tests/st/host_device/test_host_device_lenet.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class LeNet(nn.Cell): + def __init__(self): + super(LeNet, self).__init__() + self.relu = P.ReLU() + self.batch_size = 32 + + self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + self.fc1 = nn.Dense(400, 120) + self.fc1.matmul.add_prim_attr("primitive_target", "CPU") + self.fc1.bias_add.add_prim_attr("primitive_target", "CPU") + self.fc2 = nn.Dense(120, 84) + self.fc2.matmul.add_prim_attr("primitive_target", "CPU") + self.fc2.bias_add.add_prim_attr("primitive_target", "CPU") + self.fc3 = nn.Dense(84, 10) + self.fc3.matmul.add_prim_attr("primitive_target", "CPU") + self.fc3.bias_add.add_prim_attr("primitive_target", "CPU") + + def construct(self, input_x): + output = self.conv1(input_x) + output = self.relu(output) + output = self.pool(output) + output = self.conv2(output) + output = self.relu(output) + output = self.pool(output) + output = self.reshape(output, (self.batch_size, -1)) + output = self.fc1(output) + output = self.relu(output) + output = self.fc2(output) + output = self.relu(output) + output = self.fc3(output) + return output + + +def train(net, data, label): + learning_rate = 0.01 + momentum = 0.9 + + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) + criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + res = train_network(data, label) + print("+++++++++Loss+++++++++++++") + print(res) + print("+++++++++++++++++++++++++++") + diff = res.asnumpy()[0] - 2.3025851 + assert np.all(diff < 1.e-6) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_lenet(): + data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + train(net, data, label) diff --git a/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py b/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py index 73931a8046..878ab37812 100644 --- a/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py +++ b/tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py @@ -73,7 +73,7 @@ if __name__ == "__main__": epoch_size = 3 args_opt.base_size = config.crop_size args_opt.crop_size = config.crop_size - train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size, + train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train", shuffle=False) dataset_size = train_dataset.get_dataset_size() callback = LossCallBack(dataset_size) diff --git a/tests/st/model_zoo_tests/transformer/test_transformer.py b/tests/st/model_zoo_tests/transformer/test_transformer.py index ebfdbbbb7e..5e413edf74 100644 --- a/tests/st/model_zoo_tests/transformer/test_transformer.py +++ b/tests/st/model_zoo_tests/transformer/test_transformer.py @@ -25,12 +25,12 @@ from mindspore.train.model import Model from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.callback import Callback from mindspore import context -from model_zoo.Transformer.src.transformer_model import TransformerConfig -from model_zoo.Transformer.src.transformer_for_train import TransformerNetworkWithLoss, \ - TransformerTrainOneStepWithLossScaleCell -from model_zoo.Transformer.src.config import cfg -from model_zoo.Transformer.src.dataset import create_transformer_dataset -from model_zoo.Transformer.src.lr_schedule import create_dynamic_lr +from model_zoo.official.nlp.transformer.src.transformer_model import TransformerConfig +from model_zoo.official.nlp.transformer.src.transformer_for_train import TransformerNetworkWithLoss, \ + TransformerTrainOneStepWithLossScaleCell +from model_zoo.official.nlp.transformer.src.config import cfg +from model_zoo.official.nlp.transformer.src.dataset import create_transformer_dataset +from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"] @@ -120,10 +120,10 @@ def test_transformer(): batch_size = 96 epoch_size = 3 config = get_config(version=version, batch_size=batch_size) - dataset, repeat_count = create_transformer_dataset(epoch_count=epoch_size, - do_shuffle="false", - enable_data_sink="false", - dataset_path=DATA_DIR) + dataset = create_transformer_dataset(epoch_count=1, + do_shuffle="false", + enable_data_sink="false", + dataset_path=DATA_DIR) netwithloss = TransformerNetworkWithLoss(config, True) @@ -146,7 +146,7 @@ def test_transformer(): netwithgrads.set_train(True) time_monitor_callback = TimeMonitor(dataset.get_dataset_size()) model = Model(netwithgrads) - model.train(repeat_count, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False) + model.train(epoch_size, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False) # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) diff --git a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py index 0aca7d1e75..930d7d6aaa 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py +++ b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py @@ -79,9 +79,9 @@ def test_train_eval(): batch_size = config.batch_size epochs = config.epochs print("epochs is {}".format(epochs)) - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size, + ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size, data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size()) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size, + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size, data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size()) print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) diff --git a/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh b/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh index 189014ce91..0cdd327212 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh +++ b/tests/st/model_zoo_tests/wide_and_deep/run_wide_and_deep_auto_parallel.sh @@ -21,10 +21,10 @@ export RANK_SIZE=$DEVICE_NUM unset SLOG_PRINT_TO_STDOUT export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json CODE_DIR="./" -if [ -d ${BASE_PATH}/../../../../model_zoo/wide_and_deep ]; then - CODE_DIR=${BASE_PATH}/../../../../model_zoo/wide_and_deep -elif [ -d ${BASE_PATH}/../../model_zoo/wide_and_deep ]; then - CODE_DIR=${BASE_PATH}/../../model_zoo/wide_and_deep +if [ -d ${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep ]; then + CODE_DIR=${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep +elif [ -d ${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep ]; then + CODE_DIR=${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep else echo "[ERROR] code dir is not found" fi diff --git a/tests/st/model_zoo_tests/wide_and_deep/train_and_test_multinpu_ci_data_parallel.py b/tests/st/model_zoo_tests/wide_and_deep/train_and_test_multinpu_ci_data_parallel.py index e39562c92f..0f909b0236 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/train_and_test_multinpu_ci_data_parallel.py +++ b/tests/st/model_zoo_tests/wide_and_deep/train_and_test_multinpu_ci_data_parallel.py @@ -76,9 +76,9 @@ def test_train_eval(): batch_size = config.batch_size epochs = config.epochs print("epochs is {}".format(epochs)) - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, + ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) diff --git a/tests/st/model_zoo_tests/yolov3/test_yolov3.py b/tests/st/model_zoo_tests/yolov3/test_yolov3.py index 126c66a6f3..f95e913096 100644 --- a/tests/st/model_zoo_tests/yolov3/test_yolov3.py +++ b/tests/st/model_zoo_tests/yolov3/test_yolov3.py @@ -113,14 +113,13 @@ def test_yolov3(): loss_scale = float(loss_scale) # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. - dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size, + dataset = create_yolo_dataset(mindrecord_file, repeat_num=1, batch_size=batch_size, device_num=device_num, rank=rank) dataset_size = dataset.get_dataset_size() print("Create dataset done!") net = yolov3_resnet18(ConfigYOLOV3ResNet18()) net = YoloWithLossCell(net, ConfigYOLOV3ResNet18()) - init_net_param(net) total_epoch_size = 60 lr = Tensor(get_lr(learning_rate=lr_init, start_step=0, @@ -146,12 +145,12 @@ def test_yolov3(): assert loss_value[2] < expect_loss_value[2] epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2] - expect_epoch_mseconds = 950 + expect_epoch_mseconds = 2000 print("epoch mseconds: {}".format(epoch_mseconds)) assert epoch_mseconds <= expect_epoch_mseconds per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2] - expect_per_step_mseconds = 110 + expect_per_step_mseconds = 220 print("per step mseconds: {}".format(per_step_mseconds)) assert per_step_mseconds <= expect_per_step_mseconds print("yolov3 test case passed.") diff --git a/tests/st/networks/models/bert/__init__.py b/tests/st/networks/models/bert/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/st/networks/models/bert/src/bert_model.py b/tests/st/networks/models/bert/src/bert_model.py index 310d330daa..c9ecf3c064 100644 --- a/tests/st/networks/models/bert/src/bert_model.py +++ b/tests/st/networks/models/bert/src/bert_model.py @@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell): def __init__(self, length, max_relative_position): super(RelaPosMatrixGenerator, self).__init__() self._length = length - self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32) - self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32) + self._max_relative_position = max_relative_position + self._min_relative_position = -max_relative_position self.range_length = -length + 1 self.tile = P.Tile() @@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell): self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, max_relative_position=max_relative_position) self.reshape = P.Reshape() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) + self.one_hot = nn.OneHot(depth=self.vocab_size) self.shape = P.Shape() self.gather = P.GatherV2() # index_select self.matmul = P.BatchMatMul() @@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell): if self.use_one_hot_embeddings: flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) one_hot_relative_positions_matrix = self.one_hot( - flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value) + flat_relative_positions_matrix) embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) embeddings = self.reshape(embeddings, my_shape) @@ -372,11 +370,11 @@ class SaturateCast(nn.Cell): def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): super(SaturateCast, self).__init__() np_type = mstype.dtype_to_nptype(dst_type) - min_type = np.finfo(np_type).min - max_type = np.finfo(np_type).max + min_type = float(np.finfo(np_type).min) + max_type = float(np.finfo(np_type).max) - self.tensor_min_type = Tensor([min_type], dtype=src_type) - self.tensor_max_type = Tensor([max_type], dtype=src_type) + self.tensor_min_type = min_type + self.tensor_max_type = max_type self.min_op = P.Minimum() self.max_op = P.Maximum() @@ -442,7 +440,7 @@ class BertAttention(nn.Cell): self.has_attention_mask = has_attention_mask self.use_relative_positions = use_relative_positions - self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) self.reshape = P.Reshape() self.shape_from_2d = (-1, from_tensor_width) self.shape_to_2d = (-1, to_tensor_width) @@ -471,7 +469,7 @@ class BertAttention(nn.Cell): self.trans_shape = (0, 2, 1, 3) self.trans_shape_relative = (2, 0, 1, 3) self.trans_shape_position = (1, 2, 0, 3) - self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.multiply_data = -10000.0 self.batch_num = batch_size * num_attention_heads self.matmul = P.BatchMatMul() diff --git a/tests/st/networks/models/bert/src/config.py b/tests/st/networks/models/bert/src/config.py index 812f0c2f18..0aef2bc8c9 100644 --- a/tests/st/networks/models/bert/src/config.py +++ b/tests/st/networks/models/bert/src/config.py @@ -24,7 +24,7 @@ cfg = edict({ 'scale_factor': 2, 'scale_window': 1000, 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ + 'AdamWeightDecay': edict({ 'learning_rate': 3e-5, 'end_learning_rate': 1e-10, 'power': 5.0, @@ -33,7 +33,7 @@ cfg = edict({ 'warmup_steps': 10000, }), 'Lamb': edict({ - 'start_learning_rate': 3e-5, + 'learning_rate': 3e-5, 'end_learning_rate': 1e-10, 'power': 10.0, 'warmup_steps': 10000, diff --git a/tests/st/networks/models/bert/src/finetune_config.py b/tests/st/networks/models/bert/src/finetune_config.py index e92842489b..466676fd2e 100644 --- a/tests/st/networks/models/bert/src/finetune_config.py +++ b/tests/st/networks/models/bert/src/finetune_config.py @@ -32,7 +32,7 @@ cfg = edict({ 'pre_training_ckpt': '/your/path/pre_training.ckpt', 'use_crf': False, 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ + 'AdamWeightDecay': edict({ 'learning_rate': 2e-5, 'end_learning_rate': 1e-7, 'power': 1.0, @@ -40,7 +40,7 @@ cfg = edict({ 'eps': 1e-6, }), 'Lamb': edict({ - 'start_learning_rate': 2e-5, + 'learning_rate': 2e-5, 'end_learning_rate': 1e-7, 'power': 1.0, 'decay_filter': lambda x: False, diff --git a/tests/st/networks/models/bert/test_bert_graph_kernel.py b/tests/st/networks/models/bert/test_bert_graph_kernel.py index 4c9673e076..576c7a32c5 100644 --- a/tests/st/networks/models/bert/test_bert_graph_kernel.py +++ b/tests/st/networks/models/bert/test_bert_graph_kernel.py @@ -29,9 +29,11 @@ from mindspore.nn.optim import Lamb from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.model import Model +from mindspore.nn import learning_rate_schedule as lr_schedules from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell from src.bert_model import BertConfig + DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" @@ -111,6 +113,25 @@ def weight_variable(shape): return Tensor(ones) +class BertLearningRate(lr_schedules.LearningRateSchedule): + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr + + class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() @@ -134,9 +155,15 @@ def test_bert_tdt(): ds = me_de_train_dataset() config = get_config(version='large', batch_size=16) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), - start_learning_rate=5e-5, end_learning_rate=1e-9, - power=10.0, warmup_steps=0, weight_decay=0.01) + lr = BertLearningRate(decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), learning_rate=5e-5, + end_learning_rate=1e-9, power=10.0, warmup_steps=0) + decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() + no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower() + decay_params = list(filter(decay_filter, net_with_loss.trainable_params())) + other_params = list(filter(no_decay_filter, net_with_loss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}] + optimizer = Lamb(group_params, lr) scale_window = 3 scale_manager = DynamicLossScaleManager(262144, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, @@ -144,9 +171,9 @@ def test_bert_tdt(): netwithgrads.set_train(True) model = Model(netwithgrads) callback = ModelCallback() + netwithloss.init_parameters_data() params = netwithloss.trainable_params() for param in params: - param.init_data() value = param.default_input name = param.name if isinstance(value, Tensor): diff --git a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py index d4c56edbc1..150bc0fd41 100644 --- a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py @@ -28,11 +28,13 @@ import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C from mindspore import context from mindspore import log as logger +from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.nn.optim import Lamb from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.model import Model +import mindspore.nn.learning_rate_schedule as lr_schedules _current_dir = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] @@ -91,6 +93,7 @@ def me_de_train_dataset(sink_mode=False): """test me de train dataset""" # apply repeat operations repeat_count = 1 + sink_size = -1 batch_size = 16 ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", @@ -98,12 +101,8 @@ def me_de_train_dataset(sink_mode=False): type_cast_op = C.TypeCast(mstype.int32) new_repeat_count = repeat_count if sink_mode: - repeat_count = 30 - sink_steps = 100 - ori_dataaet_size = ds.get_dataset_size() - new_size = sink_steps * batch_size - ds.set_dataset_size(new_size) - new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size()) + sink_size = 100 + new_repeat_count = 3 ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) @@ -112,10 +111,9 @@ def me_de_train_dataset(sink_mode=False): ds = ds.map(input_columns="input_ids", operations=type_cast_op) # apply batch operations ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_count) logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("repeat_count: {}".format(ds.get_repeat_count())) - return ds, new_repeat_count + return ds, new_repeat_count, sink_size def weight_variable(shape): @@ -125,6 +123,31 @@ def weight_variable(shape): return Tensor(ones) +class BertLearningRate(lr_schedules.LearningRateSchedule): + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + decay_lr = self.decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr + + class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() @@ -154,17 +177,29 @@ class TimeMonitor(Callback): self.epoch_mseconds_list.append(epoch_mseconds) self.per_step_mseconds_list.append(epoch_mseconds / self.data_size) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_bert_percision(): """test bert percision""" context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) - ds, new_repeat_count = me_de_train_dataset() + ds, new_repeat_count, _ = me_de_train_dataset() version = os.getenv('VERSION', 'large') batch_size = 16 config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*new_repeat_count, - start_learning_rate=5e-5, end_learning_rate=1e-9, - power=10.0, warmup_steps=0, weight_decay=0.01) + lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count, + learning_rate=5e-5, end_learning_rate=1e-9, + power=10.0, warmup_steps=0) + decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() + no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower() + decay_params = list(filter(decay_filter, netwithloss.trainable_params())) + other_params = list(filter(no_decay_filter, netwithloss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}, + {'order_params': netwithloss.trainable_params()}] + optimizer = Lamb(group_params, lr) scale_window = 3 scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, @@ -174,7 +209,6 @@ def test_bert_percision(): callback = ModelCallback() params = netwithloss.trainable_params() for param in params: - param.init_data() value = param.default_input name = param.name if isinstance(value, Tensor): @@ -212,17 +246,31 @@ def test_bert_percision(): print("loss scale: {}".format(loss_scale)) assert np.allclose(loss_scale, expect_loss_scale, 0, 0) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_bert_performance(): """test bert performance""" context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) - ds, new_repeat_count = me_de_train_dataset(sink_mode=True) + ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True) version = os.getenv('VERSION', 'large') batch_size = 16 config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*new_repeat_count, - start_learning_rate=5e-5, end_learning_rate=1e-9, - power=10.0, warmup_steps=0, weight_decay=0.01) + + lr = BertLearningRate(decay_steps=sink_size * new_repeat_count, + learning_rate=5e-5, end_learning_rate=1e-9, + power=10.0, warmup_steps=0) + decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() + no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower() + decay_params = list(filter(decay_filter, netwithloss.trainable_params())) + other_params = list(filter(no_decay_filter, netwithloss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}, + {'order_params': netwithloss.trainable_params()}] + optimizer = Lamb(group_params, lr) + scale_window = 3 scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, @@ -232,7 +280,6 @@ def test_bert_performance(): callback = ModelCallback() params = netwithloss.trainable_params() for param in params: - param.init_data() value = param.default_input name = param.name if isinstance(value, Tensor): @@ -249,9 +296,9 @@ def test_bert_performance(): else: logger.info("***************** BERT param name is 3 {}".format(name)) param.default_input = weight_variable(value.asnumpy().shape) - time_monitor_callback = TimeMonitor(ds.get_dataset_size()) + time_monitor_callback = TimeMonitor(sink_size) model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback], - dataset_sink_mode=True) + dataset_sink_mode=True, sink_size=sink_size) # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) diff --git a/tests/st/networks/models/deeplabv3/src/deeplabv3.py b/tests/st/networks/models/deeplabv3/src/deeplabv3.py index 906a207302..bbfc4dceb3 100644 --- a/tests/st/networks/models/deeplabv3/src/deeplabv3.py +++ b/tests/st/networks/models/deeplabv3/src/deeplabv3.py @@ -276,7 +276,7 @@ class SingleDeepLabV3(nn.Cell): atrous_rates=atrous_rates, output_stride=output_stride, fine_tune_batch_norm=fine_tune_batch_norm) - self.aspp.add_flags(loop_can_unroll=True) + atrous_rates_len = 0 if atrous_rates is not None: atrous_rates_len = len(atrous_rates) diff --git a/tests/st/networks/models/deeplabv3/test_deeplabv3.py b/tests/st/networks/models/deeplabv3/test_deeplabv3.py index d033a991e9..8ee6000eee 100644 --- a/tests/st/networks/models/deeplabv3/test_deeplabv3.py +++ b/tests/st/networks/models/deeplabv3/test_deeplabv3.py @@ -63,6 +63,7 @@ class LossCallBack(Callback): str(cb_params.net_outputs))) def model_fine_tune(train_net, fix_weight_layer): + train_net.init_parameters_data() for para in train_net.trainable_params(): para.set_parameter_data(Tensor(np.ones(para.data.shape).astype(np.float32) * 0.02)) if fix_weight_layer in para.name: @@ -79,7 +80,7 @@ def test_deeplabv3_1p(): args_opt.base_size = config.crop_size args_opt.crop_size = config.crop_size args_opt.batch_size = config.batch_size - train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size, + train_dataset = create_dataset(args_opt, data_url, 1, config.batch_size, usage="eval") dataset_size = train_dataset.get_dataset_size() callback = LossCallBack(dataset_size) diff --git a/tests/st/networks/models/resnet50/src_thor/dataset_helper.py b/tests/st/networks/models/resnet50/src_thor/dataset_helper.py index 1ca4d388f7..0206335a6c 100644 --- a/tests/st/networks/models/resnet50/src_thor/dataset_helper.py +++ b/tests/st/networks/models/resnet50/src_thor/dataset_helper.py @@ -15,11 +15,16 @@ """Dataset help for minddata dataset""" from mindspore._checkparam import check_bool from mindspore.parallel._utils import _get_device_num, _get_parallel_mode -from mindspore.train.dataset_helper import _send_data from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ _to_full_shapes from mindspore.train.parallel_utils import ParallelMode +def _send_data(dataset): + """Engine dataset to write data to tdt queue.""" + if not hasattr(dataset, '__has_sent__'): + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send() + dataset.__has_sent__ = True class DatasetHelper: """ diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index e721b62c58..220b986208 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -155,7 +155,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): # train dataset dataset = create_dataset(dataset_path=dataset_path, do_train=True, - repeat_num=epoch_size, batch_size=config.batch_size) + repeat_num=1, batch_size=config.batch_size) step_size = dataset.get_dataset_size() eval_interval = config.eval_interval @@ -163,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): # evalutation dataset eval_dataset = create_dataset(dataset_path=eval_path, do_train=False, - repeat_num=epoch_size, batch_size=config.eval_batch_size) + repeat_num=1, batch_size=config.eval_batch_size) # loss scale loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) @@ -174,9 +174,14 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): steps_per_epoch=step_size, lr_decay_mode=config.lr_decay_mode)) # optimizer - decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, - net.trainable_params())) - no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] + decayed_params = [] + no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, {'params': no_decayed_params, 'weight_decay': 0.0}, {'order_params': net.trainable_params()}] @@ -260,14 +265,14 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): # train dataset dataset = create_dataset(dataset_path=dataset_path, do_train=True, - repeat_num=epoch_size, batch_size=thor_config.batch_size) + repeat_num=1, batch_size=thor_config.batch_size) step_size = dataset.get_dataset_size() eval_interval = thor_config.eval_interval # evalutation dataset eval_dataset = create_dataset(dataset_path=eval_path, do_train=False, - repeat_num=epoch_size, batch_size=thor_config.eval_batch_size) + repeat_num=1, batch_size=thor_config.eval_batch_size) # loss scale loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False) diff --git a/tests/st/networks/test_cell_bprop.py b/tests/st/networks/test_cell_bprop.py new file mode 100644 index 0000000000..84d8d54d8f --- /dev/null +++ b/tests/st/networks/test_cell_bprop.py @@ -0,0 +1,419 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_cell_bprop """ +import numpy as np +import pytest + +import mindspore as ms +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore import Parameter, ParameterTuple +from mindspore import context +from mindspore.common.initializer import initializer +from mindspore.common.tensor import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class MulAdd(nn.Cell): + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result + return 2 * dout, 2 * y + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_mul_add(): + mul_add = MulAdd() + x = Tensor(1, dtype=ms.int32) + y = Tensor(2, dtype=ms.int32) + assert C.grad_all(mul_add)(x, y) == (2, 4) + + +class InlineMulADD(nn.Cell): + def __init__(self): + super(InlineMulADD, self).__init__() + self.mul_add = MulAdd() + self.param = 2 + + def construct(self, x, y): + return self.mul_add(x, y) + x + self.param * y + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_inline_mul_add(): + inline_mul_add = InlineMulADD() + x = Tensor(1, dtype=ms.int32) + y = Tensor(2, dtype=ms.int32) + assert C.grad_all(inline_mul_add)(x, y) == (3, 6) + + +class WithParameter(nn.Cell): + def __init__(self): + super(WithParameter, self).__init__() + self.param1 = Parameter(1, 'param1') + self.param2 = Parameter(2, 'param2') + + def construct(self, x, y): + return self.param1 * self.param2 * x + y + + def bprop(self, x, y, out, dout): + # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result + return self.param1 * self.param2 * dout, 2 * y + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_with_param(): + with_param = WithParameter() + with pytest.raises(RuntimeError): + C.grad_all(with_param)(1, 2) + + +class WithNoBprop(nn.Cell): + def construct(self, x, y): + return 2 * x + y + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_with_no_bprop(): + with_no_bprop = WithNoBprop() + x = Tensor(1, dtype=ms.int32) + y = Tensor(2, dtype=ms.int32) + assert C.grad_all(with_no_bprop)(x, y) == (2, 1) + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_in_bprop_1(): + class GradInBprop_1(nn.Cell): + def __init__(self): + super(GradInBprop_1, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + return self.relu(x) + + class GradInBprop_2(nn.Cell): + def __init__(self): + super(GradInBprop_2, self).__init__() + self.f = GradInBprop_1() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return out[1][0], grads[1] + + class GradInBprop_3(nn.Cell): + def __init__(self): + super(GradInBprop_3, self).__init__() + self.f = GradInBprop_2() + + def construct(self, x, y): + return self.f(x, y) + + grad_in_bprop = GradInBprop_3() + grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), + Tensor(np.ones([2, 2]).astype(np.float32))) + assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_in_bprop_2(): + class GradInBprop_1(nn.Cell): + def __init__(self): + super(GradInBprop_1, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + return self.relu(x) + + def bprop(self, x, y, out, dout): + return x * y, y + x + + class GradInBprop_2(nn.Cell): + def __init__(self): + super(GradInBprop_2, self).__init__() + self.f = GradInBprop_1() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return out[1][0], grads[1] + + class GradInBprop_3(nn.Cell): + def __init__(self): + super(GradInBprop_3, self).__init__() + self.f = GradInBprop_2() + + def construct(self, x, y): + return self.f(x, y) + + grad_in_bprop = GradInBprop_3() + grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), + Tensor(np.ones([2, 2]).astype(np.float32))) + assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_in_bprop_3(): + class GradInBprop_1(nn.Cell): + def __init__(self): + super(GradInBprop_1, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + return self.relu(x) + + class GradInBprop_2(nn.Cell): + def __init__(self): + super(GradInBprop_2, self).__init__() + self.f = GradInBprop_1() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return out[1][0], grads[1] + + class GradInBprop_3(nn.Cell): + def __init__(self): + super(GradInBprop_3, self).__init__() + self.f = GradInBprop_2() + + def construct(self, x, y): + return self.f(x, y) + + def bprop(self, x, y, out, dout): + return x + y + y + out[0], x + x + y + y + dout[0] + + grad_in_bprop = GradInBprop_3() + grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), + Tensor(np.ones([2, 2]).astype(np.float32))) + assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all() + + +class OneInputBprop(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.ReLU() + + def construct(self, x): + return self.op(x) + + def bprop(self, x, out, dout): + return (5 * x,) + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_one_input_bprop(): + net = OneInputBprop() + input1 = Tensor(np.ones([2, 2]).astype(np.float32)) + grad = C.grad_all(net)(input1) + assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all() + + +class TwoInput(nn.Cell): + def construct(self, x, y): + return x * y + + +class InlineBpropTwoInput(nn.Cell): + def __init__(self): + super().__init__() + self.f = TwoInput() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return grads[0] * 2, grads[1] * 2 + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_inline_bprop_two_input(): + net = InlineBpropTwoInput() + input1 = Tensor(np.ones([2, 2]).astype(np.float32)) + input2 = Tensor(np.ones([2, 2]).astype(np.float32)) + grads = C.grad_all(net)(input1, input2) + assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all() + assert len(grads) == 2 + + +class TwoInputBprop(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + + def construct(self, x, y): + return self.op(x, y) + + def bprop(self, x, y, out, dout): + return 5 * x, 8 * y + + +class TwoInputWithParameter(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") + + def construct(self, x, y): + x = self.inputdata + x + return self.op(x, y) + + +class TwoInputWithOnlyInitParameterBprop(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") + + def construct(self, x, y): + return self.op(x, y) + + def bprop(self, x, y, out, dout): + return 5 * x, 8 * y + + +class InlineMutilTwoInputParameterCell(nn.Cell): + def __init__(self): + super().__init__() + self.f1 = TwoInputBprop() + self.f2 = TwoInput() + self.f3 = TwoInputWithParameter() + self.f4 = TwoInputWithOnlyInitParameterBprop() + + def construct(self, x, y): + output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) + return output + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_inline_bprop_multi_input(): + net = InlineMutilTwoInputParameterCell() + input1 = Tensor(np.ones([2, 2]).astype(np.float32)) + input2 = Tensor(np.ones([2, 2]).astype(np.float32)) + net.init_parameters_data() + grads = C.grad_all(net)(input1, input2) + assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all() + assert len(grads) == 2 + + +class MulAddWithParam(nn.Cell): + def __init__(self): + super(MulAddWithParam, self).__init__() + self.mul_add = MulAdd() + self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param') + + def construct(self, x): + return self.mul_add(self.param, x) + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_refkey_bprop(): + grad_by_list = C.GradOperation('get_by_list', get_all=True, get_by_list=True) + class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) + def construct(self, x): + weights = self.weights + grads = grad_by_list(self.network, weights)(x) + return grads + network = GradWrap(MulAddWithParam()) + input_data = Tensor(np.array([2, 2], np.float32)) + grads = network(input_data) + assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() + assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() + + +class MulAddWithWrongOutputNum(nn.Cell): + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + return (2 * dout,) + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_mul_add_with_wrong_output_num(): + context.set_context(check_bprop=True) + mul_add = MulAddWithWrongOutputNum() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, 2) + + +class MulAddWithWrongOutputType(nn.Cell): + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + return 2 * dout, 2 + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_mul_add_with_wrong_output_type(): + context.set_context(check_bprop=True) + mul_add = MulAddWithWrongOutputType() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) + + +class MulAddWithWrongOutputShape(nn.Cell): + def __init__(self): + super(MulAddWithWrongOutputShape, self).__init__() + self.ones = Tensor(np.ones([2,])) + + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + return 2, self.ones + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_mul_add_with_wrong_output_shape(): + context.set_context(check_bprop=True) + mul_add = MulAddWithWrongOutputShape() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) diff --git a/tests/st/networks/test_gpu_resnet.py b/tests/st/networks/test_gpu_resnet.py index 6bd947c712..d440c5cacb 100644 --- a/tests/st/networks/test_gpu_resnet.py +++ b/tests/st/networks/test_gpu_resnet.py @@ -355,7 +355,7 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=170): +def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=338): net = resnet50(num_classes) lr = 0.1 momentum = 0.9 diff --git a/tests/st/ops/ascend/test_addn.py b/tests/st/ops/ascend/test_addn.py index 6d0d5b5be0..fa97fcc973 100644 --- a/tests/st/ops/ascend/test_addn.py +++ b/tests/st/ops/ascend/test_addn.py @@ -16,6 +16,7 @@ import numpy as np import mindspore.context as context import mindspore.nn as nn +import mindspore.ops.composite as C from mindspore import Tensor from mindspore.ops import operations as P @@ -45,3 +46,17 @@ def test_net(): add = Net() output = add(x, y) assert output == expect + + +def test_grad_addn_with_list(): + grad_op = C.GradOperation('get_all', get_all=True) + class AddN(nn.Cell): + def __init__(self): + super().__init__() + self.add_n = P.AddN() + + def construct(self, a, b): + return self.add_n([a, b]) + + inp = Tensor(np.ones([128, 96]).astype(np.float32)) + grad_op(AddN())(inp, inp) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_cast.py b/tests/st/ops/ascend/test_aicpu_ops/test_cast.py index c236c866c0..8c2687796b 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_cast.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_cast.py @@ -32,8 +32,8 @@ class Net(nn.Cell): return self.cast(self.x, self.dtype) def test_net_f32_bool(): - x = np.random.randn(3,4).astype(np.float32) - x[:,1] = 0 + x = np.random.randn(3, 4).astype(np.float32) + x[:, 1] = 0 net = Net(Tensor(x), mstype.bool_) output = net() print(output.asnumpy()) @@ -41,8 +41,8 @@ def test_net_f32_bool(): print(output.dtype) def test_net_f16_bool(): - x = np.random.randn(3,4).astype(np.float16) - x[:,1] = 0 + x = np.random.randn(3, 4).astype(np.float16) + x[:, 1] = 0 net = Net(Tensor(x), mstype.bool_) output = net() print(output.asnumpy()) @@ -50,8 +50,8 @@ def test_net_f16_bool(): print(output.dtype) def test_net_f64_bool(): - x = np.random.randn(3,4).astype(np.float64) - x[:,1] = 0 + x = np.random.randn(3, 4).astype(np.float64) + x[:, 1] = 0 net = Net(Tensor(x), mstype.bool_) output = net() print(output.asnumpy()) @@ -59,7 +59,7 @@ def test_net_f64_bool(): print(output.dtype) def test_net_int16_float16(): - x = np.random.randint(-512, 512, size=(3,4)).astype(np.int16) + x = np.random.randint(-512, 512, size=(3, 4)).astype(np.int16) net = Net(Tensor(x), mstype.float16) output = net() print(output.asnumpy()) @@ -67,7 +67,7 @@ def test_net_int16_float16(): print(output.dtype) def test_net_int64_float16(): - x = np.random.randint(-512, 512, size=(3,4)).astype(np.int64) + x = np.random.randint(-512, 512, size=(3, 4)).astype(np.int64) net = Net(Tensor(x), mstype.float16) output = net() print(output.asnumpy()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py b/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py index 61bb3f8476..c891c7f863 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py @@ -53,4 +53,3 @@ def test_net_ND(): talpha, tbeta = Tensor(alpha), Tensor(beta) output = net(talpha, tbeta) assert output.shape == (3, 2, 2) - diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py b/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py index 1cfd6c0ede..75e207c451 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_laplace.py @@ -54,4 +54,4 @@ def test_net_ND(): tmean, tlambda_param = Tensor(mean), Tensor(lambda_param) output = net(tmean, tlambda_param) print(output.asnumpy()) - assert output.shape == (3, 2, 2) \ No newline at end of file + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py b/tests/st/ops/ascend/test_aicpu_ops/test_normal.py index 70c9c2c68f..01ecbf5ec4 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_normal.py @@ -55,4 +55,5 @@ def test_net_ND(): net = Net(shape, seed) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) output = net(tmean, tstddev) - assert output.shape == (3, 2, 2) \ No newline at end of file + assert output.shape == (3, 2, 2) + diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py b/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py index dd5ada2712..7720d303d6 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py @@ -48,4 +48,5 @@ def test_net_2(): net = Net(shape=shape) tmean = Tensor(mean) output = net(tmean) - assert output.shape == (4, 2) \ No newline at end of file + print(output.asnumpy()) + assert output.shape == (4, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py b/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py index 6304e8b111..a581636ad4 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py @@ -12,27 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import numpy as np import mindspore -from mindspore import Tensor -from mindspore.ops import operations as P import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self, num_sample): - super(Net, self).__init__() - self.random_categorical = P.RandomCategorical(mindspore.int64) - self.num_sample = num_sample + def __init__(self, num_sample): + super(Net, self).__init__() + self.random_categorical = P.RandomCategorical(mindspore.int64) + self.num_sample = num_sample - def construct(self, logits, seed=0): - return self.random_categorical(logits, self.num_sample, seed) + def construct(self, logits, seed=0): + return self.random_categorical(logits, self.num_sample, seed) def test_net(): - x = np.random.random((10, 5)).astype(np.float32) - net = Net(8) - output = net(Tensor(x)) - print(x) - print(output.asnumpy()) - print(output.dtype()) + x = np.random.random((10, 5)).astype(np.float32) + net = Net(8) + output = net(Tensor(x)) + print(x) + print(output.asnumpy()) + #print(output.dtype()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py b/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py index c7e2df07f8..705aa0e594 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py @@ -12,32 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import mindspore as ms -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function import numpy as np +import mindspore.nn as nn import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.rnnt_loss = P.RNNTLoss(blank_label=0) + def __init__(self): + super(Net, self).__init__() + self.rnnt_loss = P.RNNTLoss(blank_label=0) - def construct(self, acts, labels, act_lens, label_lens): - return self.rnnt_loss(acts, labels, act_lens, label_lens) + def construct(self, acts, labels, act_lens, label_lens): + return self.rnnt_loss(acts, labels, act_lens, label_lens) def test_net(): - B, T, U, V = 1, 2, 3, 5 - acts = np.random.random((B, T, U, V)).astype(np.float32) - labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32) - input_length = np.array([T] * B).astype(np.int32) - label_length = np.array([len(l) for l in labels]).astype(np.int32) - - rnnt_loss = Net() - costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) - print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) - print(costs.asnumpy()) - print(grads.asnumpy()) + B, T, U, V = 1, 2, 3, 5 + acts = np.random.random((B, T, U, V)).astype(np.float32) + labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32) + input_length = np.array([T] * B).astype(np.int32) + label_length = np.array([len(l) for l in labels]).astype(np.int32) + rnnt_loss = Net() + costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + print(costs.asnumpy()) + print(grads.asnumpy()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py b/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py index f45cb462cc..818ae092b1 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py @@ -13,13 +13,8 @@ # limitations under the License. # ============================================================================ -import numpy as np -import pytest - import mindspore.context as context import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common import dtype as mstype from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -43,5 +38,4 @@ def test_net(): shape = (3, 2, 4) net = Net(shape, seed, seed2) output = net() - print(output.asnumpy()) assert output.shape == (3, 2, 4) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py b/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py index 422572857e..f3d30e059d 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py @@ -34,8 +34,8 @@ class Net(nn.Cell): return self.strided_slice(input, self.begin, self.end, self.strides) -input_x = np.array([[[0, 1, 2], [3, 4, 5]], - [[6, 7, 8], [9, 10, 11]], +input_x = np.array([[[0, 1, 2], [3, 4, 5]], + [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]] ]).astype(np.float32) begin = (1, 0, 0) @@ -48,4 +48,4 @@ def test_net(): tinput = Tensor(input_x) output = net(tinput) print(output.asnumpy()) - assert np.all([[[6, 8], [9, 11]]] == output.asnumpy()) \ No newline at end of file + assert np.all([[[6, 8], [9, 11]]] == output.asnumpy()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py b/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py index 116ceb7451..54ba40d91d 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py @@ -47,7 +47,7 @@ def test_net(): tdy = Tensor(dy) output = net(tdy) print(output.asnumpy()) - assert np.all([[[0, 0, 0], [0, 0, 0]], - [[6, 0, 8], [9, 0, 11]], - [[0, 0, 0], [0, 0, 0]] - ] == output.asnumpy()) \ No newline at end of file + assert np.all([[[0, 0, 0], [0, 0, 0]], + [[6, 0, 8], [9, 0, 11]], + [[0, 0, 0], [0, 0, 0]] + ] == output.asnumpy()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py index fef34fad07..5777aec5fb 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py @@ -41,3 +41,15 @@ def test_net_1D(): ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32) output = net(ta, tb) assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 2, 1) + a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32) + b = np.array([10]).astype(np.int32) + net = Net(shape, seed) + ta, tb = Tensor(a), Tensor(b) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py index 7ab0b42e11..d5e643b3f9 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py @@ -36,3 +36,15 @@ def test_net(): net = Net(shape, seed=seed) output = net() assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 2, 1) + a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.float32) + b = np.array([10]).astype(np.float32) + net = Net(shape, seed) + ta, tb = Tensor(a), Tensor(b) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_bijector/test_exp.py b/tests/st/ops/ascend/test_bijector/test_exp.py new file mode 100644 index 0000000000..7e3f16a9a8 --- /dev/null +++ b/tests/st/ops/ascend/test_bijector/test_exp.py @@ -0,0 +1,105 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for exp""" +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: forward pass of bijector. + """ + def __init__(self): + super(Net, self).__init__() + self.bijector = msb.Exp() + + def construct(self, x_): + forward = self.bijector.forward(x_) + return forward + +def test_forward(): + x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + tx = Tensor(x, dtype=dtype.float32) + forward = Net() + ans = forward(tx) + expected = np.exp(x) + tol = 1e-5 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net1(nn.Cell): + """ + Test class: inverse pass of bijector. + """ + def __init__(self): + super(Net1, self).__init__() + self.bijector = msb.Exp() + + def construct(self, y_): + inverse = self.bijector.inverse(y_) + return inverse + +def test_inverse(): + y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + ty = Tensor(y, dtype=dtype.float32) + inverse = Net1() + ans = inverse(ty) + expected = np.log(y) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net2(nn.Cell): + """ + Test class: Forward Jacobian. + """ + def __init__(self): + super(Net2, self).__init__() + self.bijector = msb.Exp() + + def construct(self, x_): + return self.bijector.forward_log_jacobian(x_) + +def test_forward_jacobian(): + x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + tx = Tensor(x, dtype=dtype.float32) + forward_jacobian = Net2() + ans = forward_jacobian(tx) + expected = x + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net3(nn.Cell): + """ + Test class: Backward Jacobian. + """ + def __init__(self): + super(Net3, self).__init__() + self.bijector = msb.Exp() + + def construct(self, y_): + return self.bijector.inverse_log_jacobian(y_) + +def test_inverse_jacobian(): + y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + ty = Tensor(y, dtype=dtype.float32) + inverse_jacobian = Net3() + ans = inverse_jacobian(ty) + expected = -np.log(y) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() diff --git a/tests/st/ops/ascend/test_bijector/test_power_transform.py b/tests/st/ops/ascend/test_bijector/test_power_transform.py new file mode 100644 index 0000000000..d76b67a8ee --- /dev/null +++ b/tests/st/ops/ascend/test_bijector/test_power_transform.py @@ -0,0 +1,109 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for powertransform""" +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: forward pass of bijector. + """ + def __init__(self, power): + super(Net, self).__init__() + self.bijector = msb.PowerTransform(power=power) + + def construct(self, x_): + forward = self.bijector.forward(x_) + return forward + +def test_forward(): + power = 2 + x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + tx = Tensor(x, dtype=dtype.float32) + forward = Net(power=power) + ans = forward(tx) + expected = np.exp(np.log1p(x * power) / power) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net1(nn.Cell): + """ + Test class: inverse pass of bijector. + """ + def __init__(self, power): + super(Net1, self).__init__() + self.bijector = msb.PowerTransform(power=power) + + def construct(self, y_): + inverse = self.bijector.inverse(y_) + return inverse + +def test_inverse(): + power = 2 + y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + ty = Tensor(y, dtype=dtype.float32) + inverse = Net1(power=power) + ans = inverse(ty) + expected = np.expm1(np.log(y) * power) / power + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net2(nn.Cell): + """ + Test class: Forward Jacobian. + """ + def __init__(self, power): + super(Net2, self).__init__() + self.bijector = msb.PowerTransform(power=power) + + def construct(self, x_): + return self.bijector.forward_log_jacobian(x_) + +def test_forward_jacobian(): + power = 2 + x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + tx = Tensor(x, dtype=dtype.float32) + forward_jacobian = Net2(power=power) + ans = forward_jacobian(tx) + expected = (1 / power - 1) * np.log1p(x * power) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net3(nn.Cell): + """ + Test class: Backward Jacobian. + """ + def __init__(self, power): + super(Net3, self).__init__() + self.bijector = msb.PowerTransform(power=power) + + def construct(self, y_): + return self.bijector.inverse_log_jacobian(y_) + +def test_inverse_jacobian(): + power = 2 + y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32) + ty = Tensor(y, dtype=dtype.float32) + inverse_jacobian = Net3(power=power) + ans = inverse_jacobian(ty) + expected = (power - 1) * np.log(y) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py index 5652d536c7..2dc2300f58 100644 --- a/tests/st/ops/ascend/test_distribution/test_bernoulli.py +++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py @@ -12,77 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""test cases for bernoulli distribution""" +"""test cases for Bernoulli distribution""" import numpy as np from scipy import stats import mindspore.context as context import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -class Net(nn.Cell): +class Prob(nn.Cell): """ - Test class: probability of bernoulli distribution. + Test class: probability of Bernoulli distribution. """ def __init__(self): - super(Net, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + super(Prob, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) - @ms_function def construct(self, x_): - return self.b('prob', x_) - -class Net1(nn.Cell): - """ - Test class: log probability of bernoulli distribution. - """ - def __init__(self): - super(Net1, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) - - @ms_function - def construct(self, x_): - return self.b('log_prob', x_) - -class Net2(nn.Cell): - """ - Test class: kl_loss between bernoulli distributions. - """ - def __init__(self): - super(Net2, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) - - @ms_function - def construct(self, x_): - return self.b('kl_loss', 'Bernoulli', x_) - -class Net3(nn.Cell): - """ - Test class: mean/sd of bernoulli distribution. - """ - def __init__(self): - super(Net3, self).__init__() - self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32) - - @ms_function - def construct(self): - return self.b('mean'), self.b('sd') - -class Net4(nn.Cell): - """ - Test class: log probability of bernoulli distribution. - """ - def __init__(self, shape, seed=0): - super(Net4, self).__init__() - self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) - self.shape = shape - - @ms_function - def construct(self, probs=None): - return self.b('sample', self.shape, probs) + return self.b.prob(x_) def test_pmf(): """ @@ -90,24 +40,47 @@ def test_pmf(): """ bernoulli_benchmark = stats.bernoulli(0.7) expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) - pdf = Net() + pmf = Prob() x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) - output = pdf(x_) + output = pmf(x_) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Bernoulli distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b.log_prob(x_) + def test_log_likelihood(): """ Test log_pmf. """ bernoulli_benchmark = stats.bernoulli(0.7) expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) - logprob = Net1() + logprob = LogProb() x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) output = logprob(x_) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() +class KL(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b.kl_loss('Bernoulli', x_) + def test_kl_loss(): """ Test kl_loss. @@ -117,31 +90,194 @@ def test_kl_loss(): probs0_a = 1 - probs1_a probs0_b = 1 - probs1_b expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) - kl_loss = Net2() + kl_loss = KL() output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Bernoulli distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32) + + def construct(self): + return self.b.mean(), self.b.sd(), self.b.mode() + def test_basics(): """ - Test mean/standard deviation and probs. + Test mean/standard deviation/mode. """ - basics = Net3() - mean, sd = basics() - expect_mean = [0.5, 0.5] - assert (mean.asnumpy() == expect_mean).all() - assert (sd.asnumpy() == expect_mean).all() - b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) - probs = b.probs() - expect_probs = [0.7, 0.5] + basics = Basics() + mean, sd, mode = basics() + expect_mean = [0.3, 0.5, 0.7] + expect_sd = np.sqrt(np.multiply([0.7, 0.5, 0.3], [0.3, 0.5, 0.7])) + expect_mode = [0.0, 0.0, 1.0] tol = 1e-6 - assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: log probability of Bernoulli distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) + self.shape = shape + + def construct(self, probs=None): + return self.b.sample(self.shape, probs) def test_sample(): """ Test sample. """ shape = (2, 3) - sample = Net4(shape) + sample = Sampling(shape) output = sample() assert output.shape == (2, 3, 2) + +class CDF(nn.Cell): + """ + Test class: cdf of bernoulli distributions. + """ + def __init__(self): + super(CDF, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b.cdf(x_) + +def test_cdf(): + """ + Test cdf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32) + x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) + cdf = CDF() + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log cdf of bernoulli distributions. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b.log_cdf(x_) + +def test_logcdf(): + """ + Test log_cdf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_logcdf = bernoulli_benchmark.logcdf([0, 0, 1, 0, 1]).astype(np.float32) + x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) + logcdf = LogCDF() + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + + +class SF(nn.Cell): + """ + Test class: survival function of Bernoulli distributions. + """ + def __init__(self): + super(SF, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b.survival_function(x_) + +def test_survival(): + """ + Test survival funciton. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_survival = bernoulli_benchmark.sf([0, 1, 1, 0, 0]).astype(np.float32) + x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32) + sf = SF() + output = sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + + +class LogSF(nn.Cell): + """ + Test class: log survival function of Bernoulli distributions. + """ + def __init__(self): + super(LogSF, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b.log_survival(x_) + +def test_log_survival(): + """ + Test log survival funciton. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_logsurvival = bernoulli_benchmark.logsf([-1, 0.9, 0, 0, 0]).astype(np.float32) + x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]).astype(np.float32), dtype=dtype.float32) + log_sf = LogSF() + output = log_sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Bernoulli distributions. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self): + return self.b.entropy() + +def test_entropy(): + """ + Test entropy. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_entropy = bernoulli_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between bernoulli distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.b = msd.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + entropy = self.b.entropy() + kl_loss = self.b.kl_loss('Bernoulli', x_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.b.cross_entropy('Bernoulli', x_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + prob = Tensor([0.3], dtype=dtype.float32) + diff = cross_entropy(prob) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_exponential.py b/tests/st/ops/ascend/test_distribution/test_exponential.py new file mode 100644 index 0000000000..ba1689c6f9 --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_exponential.py @@ -0,0 +1,280 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Exponential distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Exponential distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self, x_): + return self.e.prob(x_) + +def test_pdf(): + """ + Test pdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_pdf = expon_benchmark.pdf([-1.0, 0.0, 1.0]).astype(np.float32) + pdf = Prob() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Exponential distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self, x_): + return self.e.log_prob(x_) + +def test_log_likelihood(): + """ + Test log_pdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_logpdf = expon_benchmark.logpdf([0.5, 1.0, 2.0]).astype(np.float32) + logprob = LogProb() + x_ = Tensor(np.array([0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between Exponential distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.e = msd.Exponential([1.5], dtype=dtype.float32) + + def construct(self, x_): + return self.e.kl_loss('Exponential', x_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + rate_a = 1.5 + rate_b = np.array([0.5, 2.0]).astype(np.float32) + expect_kl_loss = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0 + kl = KL() + output = kl(Tensor(rate_b, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Exponential distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.e = msd.Exponential([0.5], dtype=dtype.float32) + + def construct(self): + return self.e.mean(), self.e.sd(), self.e.mode() + +def test_basics(): + """ + Test mean/standard/mode deviation. + """ + basics = Basics() + mean, sd, mode = basics() + expect_mean = 2. + expect_sd = 2. + expect_mode = 0. + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Exponential distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) + self.shape = shape + + def construct(self, rate=None): + return self.e.sample(self.shape, rate) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(rate) + assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Exponential distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self, x_): + return self.e.cdf(x_) + +def test_cdf(): + """ + Test cdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_cdf = expon_benchmark.cdf([-1.0, 0.0, 1.0]).astype(np.float32) + cdf = CDF() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Exponential distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self, x_): + return self.e.log_cdf(x_) + +def test_log_cdf(): + """ + Test log_cdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_logcdf = expon_benchmark.logcdf([0.5, 1.0, 2.5]).astype(np.float32) + logcdf = LogCDF() + x_ = Tensor(np.array([0.5, 1.0, 2.5]).astype(np.float32), dtype=dtype.float32) + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Exponential distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self, x_): + return self.e.survival_function(x_) + +def test_survival(): + """ + Test survival function. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_survival = expon_benchmark.sf([-1.0, 0.0, 1.0]).astype(np.float32) + survival = SF() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = survival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Exponential distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self, x_): + return self.e.log_survival(x_) + +def test_log_survival(): + """ + Test log survival function. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_logsurvival = expon_benchmark.logsf([-1.0, 0.0, 1.0]).astype(np.float32) + logsurvival = LogSF() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = logsurvival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Exponential distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + def construct(self): + return self.e.entropy() + +def test_entropy(): + """ + Test entropy. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_entropy = expon_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Exponential distribution. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.e = msd.Exponential([1.0], dtype=dtype.float32) + + def construct(self, x_): + entropy = self.e.entropy() + kl_loss = self.e.kl_loss('Exponential', x_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.e.cross_entropy('Exponential', x_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + rate = Tensor([0.5], dtype=dtype.float32) + diff = cross_entropy(rate) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_geometric.py b/tests/st/ops/ascend/test_distribution/test_geometric.py new file mode 100644 index 0000000000..6b2a5ba84d --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_geometric.py @@ -0,0 +1,280 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Geometric distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Geometric distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.prob(x_) + +def test_pmf(): + """ + Test pmf. + """ + geom_benchmark = stats.geom(0.7) + expect_pmf = geom_benchmark.pmf([0, 1, 2, 3, 4]).astype(np.float32) + pdf = Prob() + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Geometric distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.log_prob(x_) + +def test_log_likelihood(): + """ + Test log_pmf. + """ + geom_benchmark = stats.geom(0.7) + expect_logpmf = geom_benchmark.logpmf([1, 2, 3, 4, 5]).astype(np.float32) + logprob = LogProb() + x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between Geometric distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.kl_loss('Geometric', x_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + probs1_a = 0.7 + probs1_b = 0.5 + probs0_a = 1 - probs1_a + probs0_b = 1 - probs1_b + expect_kl_loss = np.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * np.log(probs0_a / probs0_b) + kl_loss = KL() + output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Geometric distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32) + + def construct(self): + return self.g.mean(), self.g.sd(), self.g.mode() + +def test_basics(): + """ + Test mean/standard deviation/mode. + """ + basics = Basics() + mean, sd, mode = basics() + expect_mean = [1.0, 1.0] + expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5]))) + expect_mode = [0.0, 0.0] + tol = 1e-6 + assert (np.abs(mean.asnumpy()- expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: log probability of bernoulli distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) + self.shape = shape + + def construct(self, probs=None): + return self.g.sample(self.shape, probs) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + sample = Sampling(shape) + output = sample() + assert output.shape == (2, 3, 2) + +class CDF(nn.Cell): + """ + Test class: cdf of Geometric distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.cdf(x_) + +def test_cdf(): + """ + Test cdf. + """ + geom_benchmark = stats.geom(0.7) + expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32) + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) + cdf = CDF() + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log cdf of Geometric distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.log_cdf(x_) + +def test_logcdf(): + """ + Test log_cdf. + """ + geom_benchmark = stats.geom(0.7) + expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32) + x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) + logcdf = LogCDF() + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survial funciton of Geometric distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.survival_function(x_) + +def test_survival(): + """ + Test survival function. + """ + geom_benchmark = stats.geom(0.7) + expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32) + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) + sf = SF() + output = sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survial funciton of Geometric distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.g.log_survival(x_) + +def test_log_survival(): + """ + Test log_survival function. + """ + geom_benchmark = stats.geom(0.7) + expect_logsurvival = geom_benchmark.logsf([0, 1, 2, 3, 4]).astype(np.float32) + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) + log_sf = LogSF() + output = log_sf(x_) + tol = 5e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Geometric distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self): + return self.g.entropy() + +def test_entropy(): + """ + Test entropy. + """ + geom_benchmark = stats.geom(0.7) + expect_entropy = geom_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Geometric distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.g = msd.Geometric(0.7, dtype=dtype.int32) + + def construct(self, x_): + entropy = self.g.entropy() + kl_loss = self.g.kl_loss('Geometric', x_) + h_sum_kl = entropy + kl_loss + ans = self.g.cross_entropy('Geometric', x_) + return h_sum_kl - ans + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + prob = Tensor([0.5], dtype=dtype.float32) + diff = cross_entropy(prob) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py index 52bb1173ee..ee851281ef 100644 --- a/tests/st/ops/ascend/test_distribution/test_normal.py +++ b/tests/st/ops/ascend/test_distribution/test_normal.py @@ -12,77 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""test cases for normal distribution""" +"""test cases for Normal distribution""" import numpy as np from scipy import stats import mindspore.context as context import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -class Net(nn.Cell): +class Prob(nn.Cell): """ - Test class: probability of normal distribution. + Test class: probability of Normal distribution. """ def __init__(self): - super(Net, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + super(Prob, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - @ms_function def construct(self, x_): - return self.n('prob', x_) - -class Net1(nn.Cell): - """ - Test class: log probability of normal distribution. - """ - def __init__(self): - super(Net1, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - - @ms_function - def construct(self, x_): - return self.n('log_prob', x_) - -class Net2(nn.Cell): - """ - Test class: kl_loss of normal distribution. - """ - def __init__(self): - super(Net2, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) - - @ms_function - def construct(self, x_, y_): - return self.n('kl_loss', 'Normal', x_, y_) - -class Net3(nn.Cell): - """ - Test class: mean/sd of normal distribution. - """ - def __init__(self): - super(Net3, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) - - @ms_function - def construct(self): - return self.n('mean'), self.n('sd') - -class Net4(nn.Cell): - """ - Test class: mean/sd of normal distribution. - """ - def __init__(self, shape, seed=0): - super(Net4, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) - self.shape = shape - - @ms_function - def construct(self, mean=None, sd=None): - return self.n('sample', self.shape, mean, sd) + return self.n.prob(x_) def test_pdf(): """ @@ -90,22 +40,46 @@ def test_pdf(): """ norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) - pdf = Net() + pdf = Prob() output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() +class LogProb(nn.Cell): + """ + Test class: log probability of Normal distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.n.log_prob(x_) + def test_log_likelihood(): """ Test log_pdf. """ norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) - logprob = Net1() + logprob = LogProb() output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss of Normal distribution. + """ + def __init__(self): + super(KL, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + return self.n.kl_loss('Normal', x_, y_) + + def test_kl_loss(): """ Test kl_loss. @@ -120,25 +94,49 @@ def test_kl_loss(): squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale - kl_loss = Net2() + kl_loss = KL() mean = Tensor(mean_b, dtype=dtype.float32) sd = Tensor(sd_b, dtype=dtype.float32) output = kl_loss(mean, sd) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Normal distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) + + def construct(self): + return self.n.mean(), self.n.sd(), self.n.mode() + def test_basics(): """ - Test mean/standard deviation. + Test mean/standard deviation/mode. """ - basics = Net3() - mean, sd = basics() + basics = Basics() + mean, sd, mode = basics() expect_mean = [3.0, 3.0] expect_sd = [2.0, 4.0] tol = 1e-6 assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mean) < tol).all() assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() +class Sampling(nn.Cell): + """ + Test class: sample of Normal distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) + self.shape = shape + + def construct(self, mean=None, sd=None): + return self.n.sample(self.shape, mean, sd) + def test_sample(): """ Test sample. @@ -147,6 +145,180 @@ def test_sample(): seed = 10 mean = Tensor([2.0], dtype=dtype.float32) sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) - sample = Net4(shape, seed=seed) + sample = Sampling(shape, seed=seed) output = sample(mean, sd) assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Normal distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.n.cdf(x_) + + +def test_cdf(): + """ + Test cdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_cdf = norm_benchmark.cdf([1.0, 2.0]).astype(np.float32) + cdf = CDF() + output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Mormal distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.n.log_cdf(x_) + +def test_log_cdf(): + """ + Test log cdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_logcdf = norm_benchmark.logcdf([1.0, 2.0]).astype(np.float32) + logcdf = LogCDF() + output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 5e-5 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Normal distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.n.survival_function(x_) + +def test_survival(): + """ + Test log_survival. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_survival = norm_benchmark.sf([1.0, 2.0]).astype(np.float32) + survival_function = SF() + output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Normal distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.n.log_survival(x_) + +def test_log_survival(): + """ + Test log_survival. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_log_survival = norm_benchmark.logsf([1.0, 2.0]).astype(np.float32) + log_survival = LogSF() + output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Normal distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self): + return self.n.entropy() + +def test_entropy(): + """ + Test entropy. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_entropy = norm_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Normal distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + entropy = self.n.entropy() + kl_loss = self.n.kl_loss('Normal', x_, y_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.n.cross_entropy('Normal', x_, y_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + mean = Tensor([1.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + diff = cross_entropy(mean, sd) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() + +class Net(nn.Cell): + """ + Test class: expand single distribution instance to multiple graphs + by specifying the attributes. + """ + + def __init__(self): + super(Net, self).__init__() + self.normal = msd.Normal(0., 1., dtype=dtype.float32) + + def construct(self, x_, y_): + kl = self.normal.kl_loss('Normal', x_, y_) + prob = self.normal.prob(kl) + return prob + +def test_multiple_graphs(): + """ + Test multiple graphs case. + """ + prob = Net() + mean_a = np.array([0.0]).astype(np.float32) + sd_a = np.array([1.0]).astype(np.float32) + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + ans = prob(Tensor(mean_b), Tensor(sd_b)) + + diff_log_scale = np.log(sd_a) - np.log(sd_b) + squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) + expect_kl_loss = 0.5 * squared_diff + 0.5 * \ + np.expm1(2 * diff_log_scale) - diff_log_scale + + norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0])) + expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32) + + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expect_prob) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_uniform.py b/tests/st/ops/ascend/test_distribution/test_uniform.py new file mode 100644 index 0000000000..5e54f2cdcc --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_uniform.py @@ -0,0 +1,282 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Uniform distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Uniform distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) + + def construct(self, x_): + return self.u.prob(x_) + +def test_pdf(): + """ + Test pdf. + """ + uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]]) + expect_pdf = uniform_benchmark.pdf([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32) + pdf = Prob() + x_ = Tensor(np.array([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Uniform distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) + + def construct(self, x_): + return self.u.log_prob(x_) + +def test_log_likelihood(): + """ + Test log_pdf. + """ + uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]]) + expect_logpdf = uniform_benchmark.logpdf([0.5]).astype(np.float32) + logprob = LogProb() + x_ = Tensor(np.array([0.5]).astype(np.float32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between Uniform distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32) + + def construct(self, x_, y_): + return self.u.kl_loss('Uniform', x_, y_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + low_a = 0.0 + high_a = 1.5 + low_b = -1.0 + high_b = 2.0 + expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a) + kl = KL() + output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd of Uniform distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32) + + def construct(self): + return self.u.mean(), self.u.sd() + +def test_basics(): + """ + Test mean/standard deviation. + """ + basics = Basics() + mean, sd = basics() + expect_mean = [1.5] + expect_sd = np.sqrt([0.75]) + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Uniform distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) + self.shape = shape + + def construct(self, low=None, high=None): + return self.u.sample(self.shape, low, high) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + low = Tensor([1.0], dtype=dtype.float32) + high = Tensor([2.0, 3.0, 4.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(low, high) + assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Uniform distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) + + def construct(self, x_): + return self.u.cdf(x_) + +def test_cdf(): + """ + Test cdf. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_cdf = uniform_benchmark.cdf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32) + cdf = CDF() + x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Uniform distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) + + def construct(self, x_): + return self.u.log_cdf(x_) + +class SF(nn.Cell): + """ + Test class: survival function of Uniform distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) + + def construct(self, x_): + return self.u.survival_function(x_) + +class LogSF(nn.Cell): + """ + Test class: log survival function of Uniform distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) + + def construct(self, x_): + return self.u.log_survival(x_) + +class EntropyH(nn.Cell): + """ + Test class: entropy of Uniform distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32) + + def construct(self): + return self.u.entropy() + +def test_entropy(): + """ + Test entropy. + """ + uniform_benchmark = stats.uniform([0.0], [1.0, 2.0]) + expect_entropy = uniform_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross_entropy between Uniform distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32) + + def construct(self, x_, y_): + entropy = self.u.entropy() + kl_loss = self.u.kl_loss('Uniform', x_, y_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.u.cross_entropy('Uniform', x_, y_) + return h_sum_kl - cross_entropy + +def test_log_cdf(): + """ + Test log_cdf. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_logcdf = uniform_benchmark.logcdf([0.5, 0.8, 2.0]).astype(np.float32) + logcdf = LogCDF() + x_ = Tensor(np.array([0.5, 0.8, 2.0]).astype(np.float32), dtype=dtype.float32) + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +def test_survival(): + """ + Test survival function. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_survival = uniform_benchmark.sf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32) + survival = SF() + x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) + output = survival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +def test_log_survival(): + """ + Test log survival function. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_logsurvival = uniform_benchmark.logsf([0.5, 0.8, -2.0]).astype(np.float32) + logsurvival = LogSF() + x_ = Tensor(np.array([0.5, 0.8, -2.0]).astype(np.float32), dtype=dtype.float32) + output = logsurvival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + low_b = -1.0 + high_b = 2.0 + diff = cross_entropy(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/cpu/test_embedding_look_up_op.py b/tests/st/ops/cpu/test_embedding_look_up_op.py new file mode 100644 index 0000000000..e7fb713fe5 --- /dev/null +++ b/tests/st/ops/cpu/test_embedding_look_up_op.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(nn.Cell): + def __init__(self, offset): + super(Net, self).__init__() + self.embedding = P.EmbeddingLookup() + self.offset = offset + + def construct(self, param, index): + return self.embedding(param, index, self.offset) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_embedding_look_up0(): + params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32) + indices = Tensor(np.array([5, 2, 8, 5]), mstype.int32) + offset = 4 + embedding = Net(offset) + out = embedding(params, indices) + expect = np.array([[10, 11], [0, 0], [0, 0], [10, 11]]).astype(np.float32) + assert (out.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_embedding_look_up1(): + params = Tensor(np.array([[8, 9], [10, 11]]), mstype.float32) + indices = Tensor(np.array([2, 2, 1, 0]), mstype.int32) + offset = 0 + embedding = Net(offset) + out = embedding(params, indices) + expect = np.array([[0, 0], [0, 0], [10, 11], [8, 9]]).astype(np.float32) + assert (out.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_embedding_look_up2(): + params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32) + indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32) + offset = 4 + embedding = Net(offset) + out = embedding(params, indices) + expect = np.array([[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).astype(np.float32) + assert (out.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_embedding_look_up3(): + params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32) + indices = Tensor(np.array([[[5], [2]], [[8], [5]]]), mstype.int32) + offset = 4 + embedding = Net(offset) + out = embedding(params, indices) + expect = np.array([[[[10, 11]], [[0, 0]]], [[[0, 0]], [[10, 11]]]]).astype(np.float32) + assert (out.asnumpy() == expect).all() diff --git a/tests/st/ops/cpu/test_lstm_op.py b/tests/st/ops/cpu/test_lstm_op.py index 7992bfbf0a..c8174a5f90 100644 --- a/tests/st/ops/cpu/test_lstm_op.py +++ b/tests/st/ops/cpu/test_lstm_op.py @@ -23,7 +23,7 @@ from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import ParameterTuple, Parameter -context.set_context(device_target='CPU') +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') class LstmNet(nn.Cell): diff --git a/tests/st/ops/cpu/test_sparse_apply_adam_op.py b/tests/st/ops/cpu/test_sparse_apply_adam_op.py index 887f76313c..e57b8b515d 100644 --- a/tests/st/ops/cpu/test_sparse_apply_adam_op.py +++ b/tests/st/ops/cpu/test_sparse_apply_adam_op.py @@ -14,6 +14,7 @@ # ============================================================================ import numpy as np +import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor @@ -43,6 +44,9 @@ class Net(nn.Cell): return out +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard def test_net(): gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32)) indices = Tensor([0, 1, 2], mstype.int32) diff --git a/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py b/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py index 3071c54f04..31826b3fab 100644 --- a/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py +++ b/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py @@ -14,6 +14,7 @@ # ============================================================================ import numpy as np +import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor @@ -35,6 +36,9 @@ class Net(nn.Cell): return out +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard def test_net(): gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32)) indices = Tensor([0, 1, 2], mstype.int32) diff --git a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py index 67510a73be..696f3bf016 100644 --- a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py +++ b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py @@ -14,6 +14,7 @@ # ============================================================================ import numpy as np +import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor @@ -37,6 +38,9 @@ class Net(nn.Cell): return out +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard def test_net(): gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32)) indices = Tensor([0, 1, 2], mstype.int32) diff --git a/tests/st/ops/gpu/test_binary_cross_entropy_op.py b/tests/st/ops/gpu/test_binary_cross_entropy_op.py new file mode 100644 index 0000000000..724188314d --- /dev/null +++ b/tests/st/ops/gpu/test_binary_cross_entropy_op.py @@ -0,0 +1,83 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, reduction="none"): + super(Net, self).__init__() + self.BinaryCrossEntropy = P.BinaryCrossEntropy("none") + + def construct(self, x, y, weight): + return self.BinaryCrossEntropy(x, y, weight) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_binary_cross_entropy_loss(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + net = Net() + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [0.09555826, 1.2861121, 0.03518666, 0.6969416, 0.24313456, 0.99062896, + 0.19205657, 0.5465214, 0.36964455, 0.21999404, 2.2953863, 2.2566645, + 1.5803775, 1.3266402, 0.9883408, 1.2997618, 0.05439841, 0.14389999, + 0.03405444, 0.23934692] + assert np.allclose(loss.asnumpy(), expect) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens, weight): + gout = self.grad(self.network)(x1, x2, sens, weight) + return gout + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_binary_cross_entropy_loss_grad(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + sens = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + grad = Grad(Net()) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens), Tensor(weight)) + + dx1_expect = [-4.80516590e-02, 2.32625079e+00, 6.38972521e-02, 3.13642323e-01, + -1.65661633e-01, -1.71821892e+00, -1.13685496e-01, 1.26669514e+00, + 1.47891801e-03, 5.83921909e-01, -2.17992840e+01, 4.21899414e+00, + 2.85430793e-02, -3.21346498e+00, -2.22674108e+00, -2.80453944e+00, + -1.19787852e-04, 2.48514321e-02, -1.66696273e-02, -2.71965731e-02] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) diff --git a/tests/st/ops/gpu/test_boundingbox_decode_op.py b/tests/st/ops/gpu/test_boundingbox_decode_op.py new file mode 100644 index 0000000000..8400ee02b9 --- /dev/null +++ b/tests/st/ops/gpu/test_boundingbox_decode_op.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetBoundingBoxDecode(nn.Cell): + def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): + super(NetBoundingBoxDecode, self).__init__() + self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=means, stds=stds, + wh_ratio_clip=0.016) + + def construct(self, anchor, groundtruth): + return self.decode(anchor, groundtruth) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_boundingbox_decode(): + anchor = np.array([[4, 1, 2, 1], [2, 2, 2, 3]], np.float32) + deltas = np.array([[3, 1, 2, 2], [1, 2, 1, 4]], np.float32) + means = (0.1, 0.1, 0.2, 0.2) + stds = (2.0, 2.0, 3.0, 3.0) + anchor_box = Tensor(anchor, mindspore.float32) + deltas_box = Tensor(deltas, mindspore.float32) + expect_deltas = np.array([[28.6500, 0.0000, 0.0000, 33.8500], + [0.0000, 0.0000, 15.8663, 72.7000]], np.float32) + + error = np.ones(shape=[2, 4]) * 1.0e-4 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + boundingbox_decode = NetBoundingBoxDecode(means, stds) + output = boundingbox_decode(anchor_box, deltas_box) + diff = output.asnumpy() - expect_deltas + assert np.all(abs(diff) < error) + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + boundingbox_decode = NetBoundingBoxDecode(means, stds) + output = boundingbox_decode(anchor_box, deltas_box) + diff = output.asnumpy() - expect_deltas + assert np.all(abs(diff) < error) diff --git a/tests/st/ops/gpu/test_boundingbox_encode_op.py b/tests/st/ops/gpu/test_boundingbox_encode_op.py new file mode 100644 index 0000000000..c34e0e0e8e --- /dev/null +++ b/tests/st/ops/gpu/test_boundingbox_encode_op.py @@ -0,0 +1,80 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetBoundingBoxEncode(nn.Cell): + def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): + super(NetBoundingBoxEncode, self).__init__() + self.encode = P.BoundingBoxEncode(means=means, stds=stds) + + def construct(self, anchor, groundtruth): + return self.encode(anchor, groundtruth) + +def bbox2delta(proposals, gt, means, stds): + px = (proposals[..., 0] + proposals[..., 2]) * 0.5 + py = (proposals[..., 1] + proposals[..., 3]) * 0.5 + pw = proposals[..., 2] - proposals[..., 0] + 1.0 + ph = proposals[..., 3] - proposals[..., 1] + 1.0 + + gx = (gt[..., 0] + gt[..., 2]) * 0.5 + gy = (gt[..., 1] + gt[..., 3]) * 0.5 + gw = gt[..., 2] - gt[..., 0] + 1.0 + gh = gt[..., 3] - gt[..., 1] + 1.0 + + dx = (gx - px) / pw + dy = (gy - py) / ph + dw = np.log(gw / pw) + dh = np.log(gh / ph) + means = np.array(means, np.float32) + stds = np.array(stds, np.float32) + deltas = np.stack([(dx - means[0]) / stds[0], (dy - means[1]) / stds[1], + (dw - means[2]) / stds[2], (dh - means[3]) / stds[3]], axis=-1) + + return deltas + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_boundingbox_encode(): + anchor = np.array([[4, 1, 6, 9], [2, 5, 5, 9]]).astype(np.float32) + gt = np.array([[3, 2, 7, 7], [1, 5, 5, 8]]).astype(np.float32) + means = (0.1, 0.1, 0.2, 0.2) + stds = (2.0, 2.0, 3.0, 3.0) + anchor_box = Tensor(anchor, mindspore.float32) + groundtruth_box = Tensor(gt, mindspore.float32) + expect_deltas = bbox2delta(anchor, gt, means, stds) + + error = np.ones(shape=[2, 4]) * 1.0e-6 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + boundingbox_encode = NetBoundingBoxEncode(means, stds) + output = boundingbox_encode(anchor_box, groundtruth_box) + diff = output.asnumpy() - expect_deltas + assert np.all(abs(diff) < error) + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + boundingbox_encode = NetBoundingBoxEncode(means, stds) + output = boundingbox_encode(anchor_box, groundtruth_box) + diff = output.asnumpy() - expect_deltas + assert np.all(abs(diff) < error) diff --git a/tests/st/ops/gpu/test_broadcast_op.py b/tests/st/ops/gpu/test_broadcast_op.py index 3f97a229e8..202517729a 100644 --- a/tests/st/ops/gpu/test_broadcast_op.py +++ b/tests/st/ops/gpu/test_broadcast_op.py @@ -29,6 +29,8 @@ def test_nobroadcast(): x1_np = np.random.rand(10, 20).astype(np.float32) x2_np = np.random.rand(10, 20).astype(np.float32) + x1_np_int32 = np.random.randint(0, 100, (10, 20)).astype(np.int32) + x2_np_int32 = np.random.randint(0, 100, (10, 20)).astype(np.int32) output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np)) output_np = np.minimum(x1_np, x2_np) @@ -45,6 +47,9 @@ def test_nobroadcast(): output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np)) output_np = x1_np < x2_np assert np.allclose(output_ms.asnumpy(), output_np) + output_ms = P.Less()(Tensor(x1_np_int32), Tensor(x2_np_int32)) + output_np = x1_np_int32 < x2_np_int32 + assert np.allclose(output_ms.asnumpy(), output_np) output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) output_np = np.power(x1_np, x2_np) @@ -71,6 +76,8 @@ def test_broadcast(): x1_np = np.random.rand(3, 1, 5, 1).astype(np.float32) x2_np = np.random.rand(1, 4, 1, 6).astype(np.float32) + x1_np_int32 = np.random.randint(0, 100, (3, 1, 5, 1)).astype(np.int32) + x2_np_int32 = np.random.randint(0, 100, (3, 1, 5, 1)).astype(np.int32) output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np)) output_np = np.minimum(x1_np, x2_np) @@ -87,6 +94,9 @@ def test_broadcast(): output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np)) output_np = x1_np < x2_np assert np.allclose(output_ms.asnumpy(), output_np) + output_ms = P.Less()(Tensor(x1_np_int32), Tensor(x2_np_int32)) + output_np = x1_np_int32 < x2_np_int32 + assert np.allclose(output_ms.asnumpy(), output_np) output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) output_np = np.power(x1_np, x2_np) @@ -113,6 +123,8 @@ def test_broadcast_diff_dims(): x1_np = np.random.rand(2).astype(np.float32) x2_np = np.random.rand(2, 1).astype(np.float32) + x1_np_int32 = np.random.randint(0, 100, (2)).astype(np.int32) + x2_np_int32 = np.random.randint(0, 100, (2, 1)).astype(np.int32) output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np)) output_np = np.minimum(x1_np, x2_np) @@ -129,6 +141,9 @@ def test_broadcast_diff_dims(): output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np)) output_np = x1_np < x2_np assert np.allclose(output_ms.asnumpy(), output_np) + output_ms = P.Less()(Tensor(x1_np_int32), Tensor(x2_np_int32)) + output_np = x1_np_int32 < x2_np_int32 + assert np.allclose(output_ms.asnumpy(), output_np) output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) output_np = np.power(x1_np, x2_np) diff --git a/tests/st/ops/gpu/test_broadcast_to_ops.py b/tests/st/ops/gpu/test_broadcast_to_ops.py new file mode 100644 index 0000000000..137f271519 --- /dev/null +++ b/tests/st/ops/gpu/test_broadcast_to_ops.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_broadcast(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + x_np = np.random.rand(3, 1, 5, 1).astype(np.float32) + shape = (3, 4, 5, 6) + + output = P.BroadcastTo(shape)(Tensor(x_np)) + expect = np.broadcast_to(x_np, shape) + assert np.allclose(output.asnumpy(), expect) + + x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16) + output = P.BroadcastTo(shape)(Tensor(x1_np)) + expect = np.broadcast_to(x1_np, shape) + assert np.allclose(output.asnumpy(), expect) + + x1_np = np.random.rand(4, 5).astype(np.float32) + shape = (2, 3, 4, 5) + output = P.BroadcastTo(shape)(Tensor(x1_np)) + expect = np.broadcast_to(x1_np, shape) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_cast_op.py b/tests/st/ops/gpu/test_cast_op.py index 793d92d7bc..d3c8543101 100644 --- a/tests/st/ops/gpu/test_cast_op.py +++ b/tests/st/ops/gpu/test_cast_op.py @@ -70,3 +70,292 @@ def test_cast1(): assert type0 == 'float32' type1 = output[1].asnumpy().dtype assert type1 == 'float32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast2(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float16)) + t0 = mstype.int32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float16)) + t1 = mstype.float64 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int32' + type1 = output[1].asnumpy().dtype + assert type1 == 'float64' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast3(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int64)) + t0 = mstype.int32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32)) + t1 = mstype.int32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int32' + type1 = output[1].asnumpy().dtype + assert type1 == 'int32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast4(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32)) + t0 = mstype.float16 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32)) + t1 = mstype.int8 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float16' + type1 = output[1].asnumpy().dtype + assert type1 == 'int8' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast5(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32)) + t0 = mstype.uint8 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32)) + t1 = mstype.bool_ + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'uint8' + type1 = output[1].asnumpy().dtype + assert type1 == 'bool' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast6(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t0 = mstype.float64 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t1 = mstype.float32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float64' + type1 = output[1].asnumpy().dtype + assert type1 == 'float32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast7(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t0 = mstype.float32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t1 = mstype.float16 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float32' + type1 = output[1].asnumpy().dtype + assert type1 == 'float16' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast8(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t0 = mstype.int32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t1 = mstype.int16 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int32' + type1 = output[1].asnumpy().dtype + assert type1 == 'int16' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast9(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8)) + t0 = mstype.int64 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool)) + t1 = mstype.float16 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int64' + type1 = output[1].asnumpy().dtype + assert type1 == 'float16' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast10(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool)) + t0 = mstype.int8 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool)) + t1 = mstype.float64 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int8' + type1 = output[1].asnumpy().dtype + assert type1 == 'float64' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast11(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool)) + t0 = mstype.int16 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool)) + t1 = mstype.int32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int16' + type1 = output[1].asnumpy().dtype + assert type1 == 'int32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast12(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool)) + t0 = mstype.int64 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.uint8)) + t1 = mstype.float32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int64' + type1 = output[1].asnumpy().dtype + assert type1 == 'float32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast13(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.uint8)) + t0 = mstype.int32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.uint8)) + t1 = mstype.float16 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'int32' + type1 = output[1].asnumpy().dtype + assert type1 == 'float16' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast14(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t0 = mstype.float64 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t1 = mstype.float32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float64' + type1 = output[1].asnumpy().dtype + assert type1 == 'float32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast15(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t0 = mstype.float16 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t1 = mstype.int32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float16' + type1 = output[1].asnumpy().dtype + assert type1 == 'int32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast16(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t0 = mstype.float16 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int64)) + t1 = mstype.float64 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float16' + type1 = output[1].asnumpy().dtype + assert type1 == 'float64' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast17(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t0 = mstype.float32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16)) + t1 = mstype.float16 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float32' + type1 = output[1].asnumpy().dtype + assert type1 == 'float16' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast18(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int64)) + t0 = mstype.float32 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int64)) + t1 = mstype.float16 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = Net(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'float32' + type1 = output[1].asnumpy().dtype + assert type1 == 'float16' diff --git a/tests/st/ops/gpu/test_check_valid_op.py b/tests/st/ops/gpu/test_check_valid_op.py new file mode 100644 index 0000000000..2f30ecfc6e --- /dev/null +++ b/tests/st/ops/gpu/test_check_valid_op.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetCheckValid(nn.Cell): + def __init__(self): + super(NetCheckValid, self).__init__() + self.valid = P.CheckValid() + + def construct(self, anchor, image_metas): + return self.valid(anchor, image_metas) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_boundingbox_decode(): + anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], np.float32) + image_metas = np.array([768, 1280, 1], np.float32) + anchor_box = Tensor(anchor, mindspore.float32) + image_metas_box = Tensor(image_metas, mindspore.float32) + expect = np.array([True, False, False], np.bool_) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + boundingbox_decode = NetCheckValid() + output = boundingbox_decode(anchor_box, image_metas_box) + diff = (output.asnumpy() == expect) + assert (diff == 1).all() + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + boundingbox_decode = NetCheckValid() + output = boundingbox_decode(anchor_box, image_metas_box) + diff = (output.asnumpy() == expect) + assert (diff == 1).all() diff --git a/tests/st/ops/gpu/test_cumsum_op.py b/tests/st/ops/gpu/test_cumsum_op.py new file mode 100644 index 0000000000..c639c2952d --- /dev/null +++ b/tests/st/ops/gpu/test_cumsum_op.py @@ -0,0 +1,132 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +x0 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis0 = 3 + +x1 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis1 = 3 + +x2 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis2 = 2 + +x3 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis3 = 2 + +x4 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis4 = 1 + +x5 = np.random.rand(2, 3).astype(np.float32) +axis5 = 1 + +x6 = np.random.rand(1, 1, 1, 1).astype(np.float32) +axis6 = 0 + +context.set_context(device_target='GPU') + + +class CumSum(nn.Cell): + def __init__(self): + super(CumSum, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + + self.x4 = Tensor(x4) + self.axis4 = axis4 + + self.x5 = Tensor(x5) + self.axis5 = axis5 + + self.x6 = Tensor(x6) + self.axis6 = axis6 + + @ms_function + def construct(self): + return (P.CumSum()(self.x0, self.axis0), + P.CumSum()(self.x1, self.axis1), + P.CumSum()(self.x2, self.axis2), + P.CumSum()(self.x3, self.axis3), + P.CumSum()(self.x4, self.axis4), + P.CumSum()(self.x5, self.axis5), + P.CumSum()(self.x6, self.axis6)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_CumSum(): + cumsum = CumSum() + output = cumsum() + + expect0 = np.cumsum(x0, axis=axis0) + diff0 = abs(output[0].asnumpy() - expect0) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = np.cumsum(x1, axis=axis1) + diff1 = abs(output[1].asnumpy() - expect1) + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = np.cumsum(x2, axis=axis2) + diff2 = abs(output[2].asnumpy() - expect2) + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape + + expect3 = np.cumsum(x3, axis=axis3) + diff3 = abs(output[3].asnumpy() - expect3) + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output[3].shape == expect3.shape + + expect4 = np.cumsum(x4, axis=axis4) + diff4 = abs(output[4].asnumpy() - expect4) + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output[4].shape == expect4.shape + + expect5 = np.cumsum(x5, axis=axis5) + diff5 = abs(output[5].asnumpy() - expect5) + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output[5].shape == expect5.shape + + expect6 = np.cumsum(x6, axis=axis6) + diff6 = abs(output[6].asnumpy() - expect6) + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output[6].shape == expect6.shape diff --git a/tests/st/ops/gpu/test_equal_op.py b/tests/st/ops/gpu/test_equal_op.py index 6dc08b08bb..5d38a3abd0 100644 --- a/tests/st/ops/gpu/test_equal_op.py +++ b/tests/st/ops/gpu/test_equal_op.py @@ -60,6 +60,11 @@ def test_equal(): y1_np = np.array([0, 1, -3]).astype(np.float32) y1 = Tensor(y1_np) expect1 = np.equal(x1_np, y1_np) + x2_np = np.array([0, 1, 3]).astype(np.int32) + x2 = Tensor(x2_np) + y2_np = np.array([0, 1, -3]).astype(np.int32) + y2 = Tensor(y2_np) + expect2 = np.equal(x2_np, y2_np) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") equal = NetEqual() @@ -69,6 +74,9 @@ def test_equal(): output1 = equal(x1, y1) assert np.all(output1.asnumpy() == expect1) assert output1.shape == expect1.shape + output2 = equal(x2, y2) + assert np.all(output2.asnumpy() == expect2) + assert output2.shape == expect2.shape context.set_context(mode=context.GRAPH_MODE, device_target="GPU") equal = NetEqual() @@ -78,6 +86,9 @@ def test_equal(): output1 = equal(x1, y1) assert np.all(output1.asnumpy() == expect1) assert output1.shape == expect1.shape + output2 = equal(x2, y2) + assert np.all(output2.asnumpy() == expect2) + assert output2.shape == expect2.shape @pytest.mark.level0 diff --git a/tests/st/ops/gpu/test_floordiv_op.py b/tests/st/ops/gpu/test_floordiv_op.py new file mode 100644 index 0000000000..dc7d76807f --- /dev/null +++ b/tests/st/ops/gpu/test_floordiv_op.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +class NetFloorDiv(nn.Cell): + def __init__(self): + super(NetFloorDiv, self).__init__() + self.floordiv = P.FloorDiv() + + def construct(self, x, y): + return self.floordiv(x, y) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_floor_div(): + x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32) + x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32) + y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + x3_np = np.random.randint(1, 5, 1).astype(np.float32) + y3_np = np.random.randint(1, 5, 1).astype(np.float32) + x4_np = np.array(768).astype(np.float32) + y4_np = np.array(3072.5).astype(np.float32) + x5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) + y5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) + x6_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32) + y6_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32) + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + x5 = Tensor(x5_np) + y5 = Tensor(y5_np) + x6 = Tensor(x6_np) + y6 = Tensor(y6_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + floor_div = NetFloorDiv() + output0 = floor_div(x0, y0) + expect0 = np.floor_divide(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = floor_div(x1, y1) + expect1 = np.floor_divide(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = floor_div(x2, y2) + expect2 = np.floor_divide(x2_np, y2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + output3 = floor_div(x3, y3) + expect3 = np.floor_divide(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = floor_div(x4, y4) + expect4 = np.floor_divide(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + output5 = floor_div(x5, y5) + expect5 = np.floor_divide(x5_np, y5_np) + diff5 = output5.asnumpy() - expect5 + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output5.shape == expect5.shape + + output6 = floor_div(x6, y6) + expect6 = np.floor_divide(x6_np, y6_np) + diff6 = output6.asnumpy() - expect6 + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output6.shape == expect6.shape diff --git a/tests/st/ops/gpu/test_gathernd_op.py b/tests/st/ops/gpu/test_gathernd_op.py new file mode 100644 index 0000000000..c901eb08f2 --- /dev/null +++ b/tests/st/ops/gpu/test_gathernd_op.py @@ -0,0 +1,151 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.context as context + +class GatherNdNet(nn.Cell): + def __init__(self): + super(GatherNdNet, self).__init__() + self.gathernd = P.GatherNd() + + def construct(self, x, indices): + return self.gathernd(x, indices) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gathernd0(): + x = Tensor(np.arange(3 * 2, dtype=np.float32).reshape(3, 2)) + indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32)) + expect = np.array([3., 1.]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_gathernd1(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)] + for k in range(3)] for l in range(2)], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNdNet() + output = gather(x, indices) + + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_gathernd2(): + x = Tensor(np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [9., 8., 4., 3., 6.], + [0., 4., 2., 2., 8.], + [1., 8., 6., 2., 8.], + [8., 1., 9., 7., 3.], + [7., 9., 2., 5., 7.], + [9., 8., 6., 8., 5.], + [3., 7., 2., 7., 4.], + [4., 2., 8., 2., 9.]]).astype(np.float16)) + + indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32)) + expect = np.array([[0., 0., 0., 0., 0.], + [4., 9., 5., 6., 4.], + [0., 0., 0., 0., 0.]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_gathernd3(): + x = Tensor(np.array([[4, 5, 4, 1, 5], + [4, 9, 5, 6, 4], + [9, 8, 4, 3, 6], + [0, 4, 2, 2, 8], + [1, 8, 6, 2, 8], + [8, 1, 9, 7, 3], + [7, 9, 2, 5, 7], + [9, 8, 6, 8, 5], + [3, 7, 2, 7, 4], + [4, 2, 8, 2, 9]] + ).astype(np.int32)) + + indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32)) + expect = np.array([[0, 0, 0, 0, 0], + [4, 9, 5, 6, 4], + [0, 0, 0, 0, 0]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) diff --git a/tests/st/ops/gpu/test_iou_op.py b/tests/st/ops/gpu/test_iou_op.py new file mode 100644 index 0000000000..17812f2d17 --- /dev/null +++ b/tests/st/ops/gpu/test_iou_op.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetIOU(nn.Cell): + def __init__(self, mode): + super(NetIOU, self).__init__() + self.encode = P.IOU(mode=mode) + + def construct(self, anchor, groundtruth): + return self.encode(anchor, groundtruth) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_iou(): + pos1 = [101, 169, 246, 429] + pos2 = [121, 138, 304, 374] + mode = "iou" + pos1_box = Tensor(np.array(pos1).reshape(1, 4), mindspore.float32) + pos2_box = Tensor(np.array(pos2).reshape(1, 4), mindspore.float32) + expect_result = np.array(0.46551168, np.float32) + + error = np.ones(shape=[1]) * 1.0e-6 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + overlaps = NetIOU(mode) + output = overlaps(pos1_box, pos2_box) + diff = output.asnumpy() - expect_result + assert np.all(abs(diff) < error) + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + overlaps = NetIOU(mode) + output = overlaps(pos1_box, pos2_box) + diff = output.asnumpy() - expect_result + assert np.all(abs(diff) < error) diff --git a/tests/st/ops/gpu/test_kl_div_op.py b/tests/st/ops/gpu/test_kl_div_op.py new file mode 100644 index 0000000000..e5b8fcd079 --- /dev/null +++ b/tests/st/ops/gpu/test_kl_div_op.py @@ -0,0 +1,86 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, reduction="none"): + super(Net, self).__init__() + self.KLDivLoss = P.KLDivLoss("none") + + def construct(self, x, y): + return self.KLDivLoss(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_binary_cross_entropy_loss(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + net = Net() + loss = net(Tensor(prediction), Tensor(target)) + expect = [-0.5297444, -0.40738472, -0.5733339, -0.58720195, -0.42922008, -0.31237593, + -0.3332863, -0.78742254, -0.6662671, -0.17546377, -0.31526336, -0.46702948, + -0.23191005, -0.2512708, -0.20934652, -0.32021108, -0.45477402, -0.278453, + -0.5551879, -0.48938933] + assert np.allclose(loss.asnumpy(), expect) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens): + gout = self.grad(self.network)(x1, x2, sens) + return gout + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_binary_cross_entropy_loss_grad(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + sens = np.random.rand(20).astype(np.float32) + grad = Grad(Net()) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) + + dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656, + -0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884, + -0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206, + -0.03094601, -0.14319494] + + dx2_expect = [0.0163771, -0.950962, -0.03309895, -0.5481312, 0.01523498, 0.39894313, + -0.20858267, -0.27628726, -0.06815486, -0.5134226, 0.46645382, -1.3477919, + -2.409831, 0.65787154, 0.4682768, 0.55671424, -0.04362264, -0.36274382, + 0.00852979, -0.03639247] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) + assert np.allclose(dx[1].asnumpy(), dx2_expect) diff --git a/tests/st/ops/gpu/test_layer_norm_grad_op.py b/tests/st/ops/gpu/test_layer_norm_grad_op.py index 032dee50ac..f7a91e7cdf 100644 --- a/tests/st/ops/gpu/test_layer_norm_grad_op.py +++ b/tests/st/ops/gpu/test_layer_norm_grad_op.py @@ -141,3 +141,81 @@ def test_layernormgrad2(): assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernormgrad3(): + begin_norm_axis = -1 + begin_params_axis = -1 + x_np = np.random.randn(32, 64).astype(np.float32) + dy_np = np.random.randn(32, 64).astype(np.float32) + gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + epsilon = 10e-12 + dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, + begin_params_axis) + + dy_ms = Tensor(dy_np) + x_ms = Tensor(x_np) + var_ms = Tensor(var_np) + mean_ms = Tensor(mean_np) + gamma_ms = Tensor(gamma_np) + + net = LayerNormGradNet(begin_norm_axis, begin_params_axis) + dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms) + assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) + assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) + assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernormgrad4(): + begin_norm_axis = -1 + begin_params_axis = -1 + x_np = np.random.randn(32, 64).astype(np.float32) + dy_np = np.random.randn(32, 64).astype(np.float32) + gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + epsilon = 10e-12 + dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, + begin_params_axis) + + dy_ms = Tensor(dy_np) + x_ms = Tensor(x_np) + var_ms = Tensor(var_np) + mean_ms = Tensor(mean_np) + gamma_ms = Tensor(gamma_np) + + net = LayerNormGradNet(begin_norm_axis, begin_params_axis) + dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms) + assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) + assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) + assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernormgrad5(): + begin_norm_axis = 2 + begin_params_axis = 1 + x_np = np.random.randn(128, 2, 16, 32).astype(np.float32) + dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32) + gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + epsilon = 10e-12 + dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, + begin_params_axis) + + dy_ms = Tensor(dy_np) + x_ms = Tensor(x_np) + var_ms = Tensor(var_np) + mean_ms = Tensor(mean_np) + gamma_ms = Tensor(gamma_np) + + net = LayerNormGradNet(begin_norm_axis, begin_params_axis) + dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms) + assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) + assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) + assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) diff --git a/tests/st/ops/gpu/test_layer_norm_op.py b/tests/st/ops/gpu/test_layer_norm_op.py index 776201735b..040bc2c1bc 100644 --- a/tests/st/ops/gpu/test_layer_norm_op.py +++ b/tests/st/ops/gpu/test_layer_norm_op.py @@ -133,3 +133,67 @@ def test_layernorm3d_2(): assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernorm2d_2(): + begin_norm_axis = -1 + begin_params_axis = 1 + x_np = np.random.randn(64, 32).astype(np.float32) + gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) + + x_ms = Tensor(x_np) + gamma_ms = Tensor(gamma_np) + beta_ms = Tensor(beta_np) + net = LayerNormNet(begin_norm_axis, begin_params_axis) + y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) + assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) + assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) + assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernorm2d_3(): + begin_norm_axis = -1 + begin_params_axis = 1 + x_np = np.random.randn(128, 128).astype(np.float32) + gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) + + x_ms = Tensor(x_np) + gamma_ms = Tensor(gamma_np) + beta_ms = Tensor(beta_np) + net = LayerNormNet(begin_norm_axis, begin_params_axis) + y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) + assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) + assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) + assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernorm2d_4(): + begin_norm_axis = 2 + begin_params_axis = 1 + np.random.seed(42) + x_np = np.random.randn(128, 2, 16, 32).astype(np.float32) + gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) + y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) + + x_ms = Tensor(x_np) + gamma_ms = Tensor(gamma_np) + beta_ms = Tensor(beta_np) + net = LayerNormNet(begin_norm_axis, begin_params_axis) + y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) + assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) + assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) + assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) diff --git a/tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py b/tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py new file mode 100644 index 0000000000..a2ff401738 --- /dev/null +++ b/tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py @@ -0,0 +1,147 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +class Net_Pool(nn.Cell): + def __init__(self): + super(Net_Pool, self).__init__() + self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID") + + def construct(self, x): + return self.maxpool_fun(x) + + +class Net_Pool2(nn.Cell): + def __init__(self): + super(Net_Pool2, self).__init__() + self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME") + + def construct(self, x): + return self.maxpool_fun(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_maxpool_with_argmax_2d(): + x = Tensor(np.array([[[ + [0, 1, 2, 3, -4, -5], + [6, 7, 8, 9, -10, -11], + [12, 13, 14, -15, -16, -17], + [18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29], + [30, 31, 32, 33, 34, 35] + ]]]).astype(np.float32)) + expect_result = (np.array([[[ + [7, 9, -4], + [19, 21, 23], + [31, 33, 35] + ]]])) + expect_result2 = (np.array([[[ + [14, 14, -4], + [26, 28, 29], + [32, 34, 35] + ]]])) + expect_index_result = (np.array([[[ + [7, 9, 4], + [19, 21, 23], + [31, 33, 35] + ]]])) + expect__index_result2 = (np.array([[[ + [14, 14, 4], + [26, 28, 29], + [32, 34, 35] + ]]])) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + maxpool2d = Net_Pool() + maxpool2d2 = Net_Pool2() + output2, index2 = maxpool2d2(x) + output, index = maxpool2d(x) + assert (output.asnumpy() == expect_result).all() + assert (output2.asnumpy() == expect_result2).all() + assert (index.asnumpy() == expect_index_result).all() + assert (index2.asnumpy() == expect__index_result2).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + maxpool2d = Net_Pool() + maxpool2d2 = Net_Pool2() + output2, index2 = maxpool2d2(x) + output, index = maxpool2d(x) + assert (output.asnumpy() == expect_result).all() + assert (output2.asnumpy() == expect_result2).all() + assert (index.asnumpy() == expect_index_result).all() + assert (index2.asnumpy() == expect__index_result2).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_maxpool_with_argmax_2d_fp16(): + x = Tensor(np.array([[[ + [0, 1, 2, 3, -4, -5], + [6, 7, 8, 9, -10, -11], + [12, 13, 14, -15, -16, -17], + [18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29], + [30, 31, 32, 33, 34, 35] + ]]]).astype(np.float16)) + expect_result = (np.array([[[ + [7, 9, -4], + [19, 21, 23], + [31, 33, 35] + ]]])) + expect_result2 = (np.array([[[ + [14, 14, -4], + [26, 28, 29], + [32, 34, 35] + ]]])) + expect_index_result = (np.array([[[ + [7, 9, 4], + [19, 21, 23], + [31, 33, 35] + ]]])) + expect__index_result2 = (np.array([[[ + [14, 14, 4], + [26, 28, 29], + [32, 34, 35] + ]]])) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + maxpool2d = Net_Pool() + maxpool2d2 = Net_Pool2() + output2, index2 = maxpool2d2(x) + output, index = maxpool2d(x) + assert (output.asnumpy() == expect_result).all() + assert (output2.asnumpy() == expect_result2).all() + assert (index.asnumpy() == expect_index_result).all() + assert (index2.asnumpy() == expect__index_result2).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + maxpool2d = Net_Pool() + maxpool2d2 = Net_Pool2() + output2, index2 = maxpool2d2(x) + output, index = maxpool2d(x) + assert (output.asnumpy() == expect_result).all() + assert (output2.asnumpy() == expect_result2).all() + assert (index.asnumpy() == expect_index_result).all() + assert (index2.asnumpy() == expect__index_result2).all() + \ No newline at end of file diff --git a/tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py b/tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py new file mode 100644 index 0000000000..a9ea790ffa --- /dev/null +++ b/tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py @@ -0,0 +1,115 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class Net_Pool_Grad(nn.Cell): + def __init__(self): + super(Net_Pool_Grad, self).__init__() + self.maxpool_grad_fun = G.MaxPoolGradWithArgmax(padding="VALID", ksize=2, strides=2) + + def construct(self, x, dy, index): + return self.maxpool_grad_fun(x, dy, index) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_maxpool2d_grad(): + x = Tensor(np.array([[[ + [0, 1, 2, 3, 4, 5], + [6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17], + [18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29], + [30, 31, 32, 33, 34, 35] + ]]]).astype(np.float32)) + dy = Tensor(np.array([[[ + [0.7, 0.9, 0.11], + [0.19, 0.21, 0.23], + [0.31, 0.33, 0.35] + ]]]).astype(np.float32)) + index = Tensor(np.array([[[ + [7, 9, 11], + [19, 21, 23], + [31, 33, 35] + ]]]).astype(np.int32)) + expect_result = (np.array([[[ + [0., 0., 0., 0., 0., 0.], + [0., 0.7, 0., 0.9, 0., 0.11], + [0., 0., 0., 0., 0., 0.], + [0., 0.19, 0., 0.21, 0., 0.23], + [0., 0., 0., 0., 0., 0.], + [0., 0.31, 0., 0.33, 0., 0.35] + ]]])) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + maxpool2d_grad = Net_Pool_Grad() + output = maxpool2d_grad(x, dy, index) + assert np.allclose(expect_result, output.asnumpy()) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + maxpool2d_grad = Net_Pool_Grad() + output = maxpool2d_grad(x, dy, index) + assert np.allclose(expect_result, output.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_maxpool2d_grad_fp16(): + x = Tensor(np.array([[[ + [0, 1, 2, 3, 4, 5], + [6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17], + [18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29], + [30, 31, 32, 33, 34, 35] + ]]]).astype(np.float16)) + dy = Tensor(np.array([[[ + [0.7, 0.9, 0.11], + [0.19, 0.21, 0.23], + [0.31, 0.33, 0.35] + ]]]).astype(np.float16)) + index = Tensor(np.array([[[ + [7, 9, 11], + [19, 21, 23], + [31, 33, 35] + ]]]).astype(np.int32)) + expect_result = np.array([[[ + [0., 0., 0., 0., 0., 0.], + [0., 0.7, 0., 0.9, 0., 0.11], + [0., 0., 0., 0., 0., 0.], + [0., 0.19, 0., 0.21, 0., 0.23], + [0., 0., 0., 0., 0., 0.], + [0., 0.31, 0., 0.33, 0., 0.35] + ]]]).astype(np.float16) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + maxpool2d_grad = Net_Pool_Grad() + output = maxpool2d_grad(x, dy, index) + assert np.allclose(expect_result, output.asnumpy()) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + maxpool2d_grad = Net_Pool_Grad() + output = maxpool2d_grad(x, dy, index) + assert np.allclose(expect_result, output.asnumpy()) diff --git a/tests/st/ops/gpu/test_mirror_pad.py b/tests/st/ops/gpu/test_mirror_pad.py new file mode 100644 index 0000000000..9e6613d744 --- /dev/null +++ b/tests/st/ops/gpu/test_mirror_pad.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np + +import mindspore +import mindspore.nn as nn +import mindspore.context as context + +from mindspore import Tensor +from mindspore.ops.composite import GradOperation + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_mirror_pad(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + test1_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]] + test_1_paddings = ((0, 0), (0, 0), (1, 1), (2, 2)) + test1_arr_exp = [[[[6, 5, 4, 5, 6, 5, 4], [3, 2, 1, 2, 3, 2, 1], [6, 5, 4, 5, 6, 5, 4], + [9, 8, 7, 8, 9, 8, 7], [6, 5, 4, 5, 6, 5, 4]]]] + + test2_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]] + test_2_paddings = ((0, 0), (0, 0), (1, 1), (2, 2)) + test2_arr_exp = [[[[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], [5, 4, 4, 5, 6, 6, 5], + [8, 7, 7, 8, 9, 9, 8], [8, 7, 7, 8, 9, 9, 8]]]] + + reflectOp = nn.Pad(mode='REFLECT', paddings=test_1_paddings) + symmOp = nn.Pad(mode='SYMMETRIC', paddings=test_2_paddings) + + x_test_1 = Tensor(np.array(test1_arr_in), dtype=mindspore.float32) + x_test_2 = Tensor(np.array(test2_arr_in), dtype=mindspore.float32) + + y_test_1 = reflectOp(x_test_1).asnumpy() + y_test_2 = symmOp(x_test_2).asnumpy() + + print(np.array(test1_arr_in)) + print(y_test_1) + + np.testing.assert_equal(np.array(test1_arr_exp), y_test_1) + np.testing.assert_equal(np.array(test2_arr_exp), y_test_2) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + def construct(self, input_, output_grad): + return self.grad(self.network)(input_, output_grad) + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.pad = nn.Pad(mode="REFLECT", paddings=((0, 0), (0, 0), (1, 0), (0, 2))) + def construct(self, x): + return self.pad(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_mirror_pad_backprop(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]] # size -> 3*3 + test_arr_in = Tensor(test_arr_in, dtype=mindspore.float32) + dy = (np.ones((1, 1, 4, 5)) * 0.1).astype(np.float32) + expected_dx = np.array([[[[0.2, 0.2, 0.1], + [0.4, 0.4, 0.2], + [0.2, 0.2, 0.1]]]]) + net = Grad(Net()) + dx = net(test_arr_in, Tensor(dy)) + dx = dx[0].asnumpy() + np.testing.assert_array_almost_equal(dx, expected_dx) diff --git a/tests/st/ops/gpu/test_nms_with_mask_op.py b/tests/st/ops/gpu/test_nms_with_mask_op.py new file mode 100644 index 0000000000..210a14b3ff --- /dev/null +++ b/tests/st/ops/gpu/test_nms_with_mask_op.py @@ -0,0 +1,154 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore +from mindspore import Tensor +from mindspore.ops import operations as P + +def manualNMS(bbox, overlap_val_iou): + mask = [True] * len(bbox) + for box_a_index, _ in enumerate(bbox): + if not mask[box_a_index]: + continue # ignore if not in list + box_a = bbox[box_a_index] # select box for value extraction + for box_b_index in range(box_a_index + 1, len(bbox)): + if not mask[box_b_index]: + continue # ignore if not in list + box_b = bbox[box_b_index] + areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]) + areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) + overlap_x1 = max(box_a[0], box_b[0]) + overlap_y1 = max(box_a[1], box_b[1]) + overlap_x2 = min(box_a[2], box_b[2]) + overlap_y2 = min(box_a[3], box_b[3]) + width = max((overlap_x2 - overlap_x1), 0) + height = max((overlap_y2 - overlap_y1), 0) + # generate IOU decision + mask[box_b_index] = not ( + (width * height)/(areaA + areaB - (width * height))) > overlap_val_iou + return mask + + +def runMSRun(op, bbox): + inputs = Tensor(bbox, mindspore.float32) + box, _, mask = op(inputs) + box = box.asnumpy() + mask = mask.asnumpy() + sel_idx = np.where(mask) + sel_rows = box[sel_idx][:, 0:4] + sel_score = box[sel_idx][:, -1] + return sel_rows, sel_score + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nms_with_mask_check_order(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + nms_op = P.NMSWithMask(0.5) + for _ in range(500): + count = 20 + box = np.random.randint(1, 100, size=(count, 4)) + box[:, 2] = box[:, 0] + box[:, 2] + box[:, 3] = box[:, 1] + box[:, 3] + unsorted_scores = np.random.rand(count, 1) + bbox = np.hstack((box, unsorted_scores)) + bbox = Tensor(bbox, dtype=mindspore.float32) + prop, _, _ = nms_op(bbox) + ms_sorted_scores = (prop.asnumpy()[:, -1]) # select just scores + np_sorted_scores = (np.sort(unsorted_scores, axis=0)[::-1][:, 0]) # sort manually + np.testing.assert_array_almost_equal( + ms_sorted_scores, np_sorted_scores) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nms_with_masl_check_result(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_count = 500 + for x in range(1, test_count+1): + count = 20 # size of bbox lists + nms_op = P.NMSWithMask(x * 0.002) # will test full range b/w 0 and 1 + box = np.random.randint(1, 100, size=(count, 4)) + box[:, 2] = box[:, 0] + box[:, 2] + box[:, 3] = box[:, 1] + box[:, 3] + unsorted_scores = np.random.rand(count, 1) + sorted_scores = np.sort(unsorted_scores, axis=0)[::-1] + bbox = np.hstack((box, sorted_scores)) + bbox = Tensor(bbox, dtype=mindspore.float32) + _, _, mask = nms_op(bbox) + mask = mask.asnumpy() + manual_mask = manualNMS(box, x * 0.002) + np.testing.assert_array_equal(mask, np.array(manual_mask)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nms_with_mask_edge_case_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + # CASE 1 - FULL OVERLAP BOXES - Every box is duplicated and has a different score + nms_op1 = P.NMSWithMask(0.3) + bbox1 = [[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1], [20, 10, 45, 26, 0.9], [15, 17, 35, 38, 0.5], + [10, 20, 30, 40, 0.4], [35, 35, 89, 90, 0.8], [12, 4, 33, 17, 0.3], [20, 11, 38, 23, 0.2], + [20, 10, 45, 26, 0.1], [15, 17, 35, 38, 0.8], [10, 20, 30, 40, 0.41], [35, 35, 89, 90, 0.82]] + expected_bbox = np.array([[20., 10., 45., 26.], + [35., 35., 89., 90.], + [15., 17., 35., 38.], + [12., 4., 33., 17.]]) + expected_score = np.array([0.9, 0.82, 0.8, 0.6]) + + sel_rows, sel_score = runMSRun(nms_op1, bbox1) + np.testing.assert_almost_equal(sel_rows, expected_bbox) + np.testing.assert_almost_equal(sel_score, expected_score) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nms_with_mask_edge_case_2(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + # CASE 2 - 0 value boxes - with valid scores + nms_op2 = P.NMSWithMask(0.5) + bbox2 = [[0, 0, 0, 0, 0.6], [0, 0, 0, 0, 0.1]] + expected_bbox = np.array([[0., 0., 0., 0.], + [0., 0., 0., 0.]]) + expected_score = np.array([0.6, 0.1]) + + sel_rows, sel_score = runMSRun(nms_op2, bbox2) + np.testing.assert_almost_equal(sel_rows, expected_bbox) + np.testing.assert_almost_equal(sel_score, expected_score) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nms_with_mask_edge_case_3(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + # CASE 3 - x2/x1 and y2/y1 sequence out of place + nms_op3 = P.NMSWithMask(0.7) + bbox3 = [[70, 70, 45, 75, 0.6], [30, 33, 43, 29, 0.1]] + expected_bbox = np.array([[70., 70., 45., 75.], + [30., 33., 43., 29.]]) + expected_score = np.array([0.6, 0.1]) + + sel_rows, sel_score = runMSRun(nms_op3, bbox3) + np.testing.assert_almost_equal(sel_rows, expected_bbox) + np.testing.assert_almost_equal(sel_score, expected_score) diff --git a/tests/st/ops/gpu/test_oneslike_op.py b/tests/st/ops/gpu/test_oneslike_op.py new file mode 100644 index 0000000000..e721d15729 --- /dev/null +++ b/tests/st/ops/gpu/test_oneslike_op.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + + +class NetOnesLike(nn.Cell): + def __init__(self): + super(NetOnesLike, self).__init__() + self.ones_like = P.OnesLike() + + def construct(self, x): + return self.ones_like(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_OnesLike(): + x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32) + x1_np = np.random.uniform(-2, 2, 1).astype(np.float16) + x2_np = np.zeros([3, 3, 3], dtype=np.int32) + + x0 = Tensor(x0_np) + x1 = Tensor(x1_np) + x2 = Tensor(x2_np) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + ones_like = NetOnesLike() + output0 = ones_like(x0) + expect0 = np.ones_like(x0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = ones_like(x1) + expect1 = np.ones_like(x1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ones_like = NetOnesLike() + output0 = ones_like(x0) + expect0 = np.ones_like(x0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = ones_like(x1) + expect1 = np.ones_like(x1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = ones_like(x2) + expect2 = np.ones_like(x2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape diff --git a/tests/st/ops/gpu/test_random_choice_with_mask.py b/tests/st/ops/gpu/test_random_choice_with_mask.py new file mode 100644 index 0000000000..3ca12e7dd6 --- /dev/null +++ b/tests/st/ops/gpu/test_random_choice_with_mask.py @@ -0,0 +1,86 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +class RCWM_count_in(nn.Cell): + def __init__(self): + super(RCWM_count_in, self).__init__() + self.RCWM_count_in = P.RandomChoiceWithMask(count=4, seed=1) + + def construct(self, x): + return self.RCWM_count_in(x) + +class RCWM_count_out(nn.Cell): + def __init__(self): + super(RCWM_count_out, self).__init__() + self.RCWM_count_out = P.RandomChoiceWithMask(count=10, seed=1) + + def construct(self, x): + return self.RCWM_count_out(x) + +class RCWM_3D(nn.Cell): + def __init__(self): + super(RCWM_3D, self).__init__() + self.RCWM_3D = P.RandomChoiceWithMask(count=10, seed=1) + + def construct(self, x): + return self.RCWM_3D(x) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_3D(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool)) + expect1 = [[0, 1, 1], [0, 2, 1], [0, 2, 2], [1, 0, 1], [0, 1, 3], [0, 3, 0], [1, 3, 2], \ + [0, 0, 0], [1, 1, 2], [1, 3, 4]] + expect2 = [True, True, True, True, True, True, True, True, True, True] + rcwm = RCWM_3D() + output1, output2 = rcwm(input_tensor) + assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) + assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_count_out(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + expect1 = [[0, 2], [2, 2], [2, 1], [2, 0], [0, 0], [3, 3], [2, 3], [1, 3], [0, 0], [0, 0]] + expect2 = [True, True, True, True, True, True, True, True, False, False] + rcwm = RCWM_count_out() + output1, output2 = rcwm(input_tensor) + assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) + assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_count_in(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + expect1 = [[0, 2], [2, 2], [2, 1], [2, 0]] + expect2 = [True, True, True, True] + rcwm = RCWM_count_in() + output1, output2 = rcwm(input_tensor) + assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) + assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) diff --git a/tests/st/ops/gpu/test_reduce_min_op.py b/tests/st/ops/gpu/test_reduce_min_op.py new file mode 100644 index 0000000000..c502c549ae --- /dev/null +++ b/tests/st/ops/gpu/test_reduce_min_op.py @@ -0,0 +1,177 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +x0 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis0 = 3 +keep_dims0 = True + +x1 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis1 = 3 +keep_dims1 = False + +x2 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis2 = 2 +keep_dims2 = True + +x3 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis3 = 2 +keep_dims3 = False + +x4 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis4 = () +np_axis4 = None +keep_dims4 = True + +x5 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis5 = () +np_axis5 = None +keep_dims5 = False + +x6 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis6 = -2 +keep_dims6 = False + +x7 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis7 = (-2, -1) +keep_dims7 = True + +x8 = np.random.rand(1, 1, 1, 1).astype(np.float32) +axis8 = () +np_axis8 = None +keep_dims8 = True + +context.set_context(device_target='GPU') + + +class ReduceMin(nn.Cell): + def __init__(self): + super(ReduceMin, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + self.keep_dims0 = keep_dims0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + self.keep_dims1 = keep_dims1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + self.keep_dims2 = keep_dims2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + self.keep_dims3 = keep_dims3 + + self.x4 = Tensor(x4) + self.axis4 = axis4 + self.keep_dims4 = keep_dims4 + + self.x5 = Tensor(x5) + self.axis5 = axis5 + self.keep_dims5 = keep_dims5 + + self.x6 = Tensor(x6) + self.axis6 = axis6 + self.keep_dims6 = keep_dims6 + + self.x7 = Tensor(x7) + self.axis7 = axis7 + self.keep_dims7 = keep_dims7 + + self.x8 = Tensor(x8) + self.axis8 = axis8 + self.keep_dims8 = keep_dims8 + + @ms_function + def construct(self): + return (P.ReduceMin(self.keep_dims0)(self.x0, self.axis0), + P.ReduceMin(self.keep_dims1)(self.x1, self.axis1), + P.ReduceMin(self.keep_dims2)(self.x2, self.axis2), + P.ReduceMin(self.keep_dims3)(self.x3, self.axis3), + P.ReduceMin(self.keep_dims4)(self.x4, self.axis4), + P.ReduceMin(self.keep_dims5)(self.x5, self.axis5), + P.ReduceMin(self.keep_dims6)(self.x6, self.axis6), + P.ReduceMin(self.keep_dims7)(self.x7, self.axis7), + P.ReduceMin(self.keep_dims8)(self.x8, self.axis8)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ReduceMin(): + reduce_min = ReduceMin() + output = reduce_min() + + expect0 = np.min(x0, axis=axis0, keepdims=keep_dims0) + diff0 = abs(output[0].asnumpy() - expect0) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = np.min(x1, axis=axis1, keepdims=keep_dims1) + diff1 = abs(output[1].asnumpy() - expect1) + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = np.min(x2, axis=axis2, keepdims=keep_dims2) + diff2 = abs(output[2].asnumpy() - expect2) + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape + + expect3 = np.min(x3, axis=axis3, keepdims=keep_dims3) + diff3 = abs(output[3].asnumpy() - expect3) + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output[3].shape == expect3.shape + + expect4 = np.min(x4, axis=np_axis4, keepdims=keep_dims4) + diff4 = abs(output[4].asnumpy() - expect4) + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output[4].shape == expect4.shape + + expect5 = np.min(x5, axis=np_axis5, keepdims=keep_dims5) + diff5 = abs(output[5].asnumpy() - expect5) + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output[5].shape == expect5.shape + + expect6 = np.min(x6, axis=axis6, keepdims=keep_dims6) + diff6 = abs(output[6].asnumpy() - expect6) + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output[6].shape == expect6.shape + + expect7 = np.min(x7, axis=axis7, keepdims=keep_dims7) + diff7 = abs(output[7].asnumpy() - expect7) + error7 = np.ones(shape=expect7.shape) * 1.0e-5 + assert np.all(diff7 < error7) + + expect8 = np.min(x8, axis=np_axis8, keepdims=keep_dims8) + diff8 = abs(output[8].asnumpy() - expect8) + error8 = np.ones(shape=expect8.shape) * 1.0e-5 + assert np.all(diff8 < error8) diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py new file mode 100644 index 0000000000..0203c87204 --- /dev/null +++ b/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py @@ -0,0 +1,87 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +class ResizeNearestNeighborGradAlignCornerT(nn.Cell): + def __init__(self): + super(ResizeNearestNeighborGradAlignCornerT, self).__init__() + self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad(align_corners=True) + + def construct(self, dy, size): + return self.ResizeNearestNeighborGradAlignCornerT(dy, size) + +class ResizeNearestNeighborGradAlignCornerF(nn.Cell): + def __init__(self): + super(ResizeNearestNeighborGradAlignCornerF, self).__init__() + self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad(align_corners=False) + + def construct(self, dy, size): + return self.ResizeNearestNeighborGradAlignCornerF(dy, size) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborGradAlignCornerT(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) + size = (4, 4) + expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float32) + rnn = ResizeNearestNeighborGradAlignCornerT() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16) + size = (4, 4) + expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerT() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.int32) + size = (4, 4) + expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.int32) + rnn = ResizeNearestNeighborGradAlignCornerT() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborGradAlignCornerF(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float32) + rnn = ResizeNearestNeighborGradAlignCornerF() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerF() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.int32) + rnn = ResizeNearestNeighborGradAlignCornerF() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py new file mode 100644 index 0000000000..1438d69637 --- /dev/null +++ b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py @@ -0,0 +1,81 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class ResizeNearestNeighborAlignCornerT(nn.Cell): + def __init__(self, size): + super(ResizeNearestNeighborAlignCornerT, self).__init__() + self.ResizeNearestNeighborAlignCornerT = P.ResizeNearestNeighbor(size, align_corners=True) + + def construct(self, x): + return self.ResizeNearestNeighborAlignCornerT(x) + +class ResizeNearestNeighborAlignCornerF(nn.Cell): + def __init__(self, size): + super(ResizeNearestNeighborAlignCornerF, self).__init__() + self.ResizeNearestNeighborAlignCornerF = P.ResizeNearestNeighbor(size, align_corners=False) + + def construct(self, x): + return self.ResizeNearestNeighborAlignCornerF(x) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborAlignCornerT(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + rnn = ResizeNearestNeighborAlignCornerT((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + rnn = ResizeNearestNeighborAlignCornerT((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + rnn = ResizeNearestNeighborAlignCornerT((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborAlignCornerF(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + rnn = ResizeNearestNeighborAlignCornerF((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + rnn = ResizeNearestNeighborAlignCornerF((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + rnn = ResizeNearestNeighborAlignCornerF((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/gpu/test_roi_align_grad_op.py b/tests/st/ops/gpu/test_roi_align_grad_op.py new file mode 100644 index 0000000000..2231085259 --- /dev/null +++ b/tests/st/ops/gpu/test_roi_align_grad_op.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class NetROIAlignGrad(nn.Cell): + def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num): + super(NetROIAlignGrad, self).__init__() + self.roiAlignGrad = G.ROIAlignGrad( + xdiff_shape, + pooled_height, + pooled_width, + spatial_scale, + sample_num) + + def construct(self, dy, rois): + return self.roiAlignGrad(dy, rois) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_roi_align_grad(): + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + + dy = Tensor(np.array([[[ + [.1, .2, .3], + [.1, .2, .3], + [.1, .2, .3] + ]]], np.float32)) + + xdiff_shape = (1, 1, 6, 6) + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + roi_align_grad = NetROIAlignGrad( + xdiff_shape, + pooled_height, + pooled_width, + spatial_scale, + sample_num) + output = roi_align_grad(dy, rois) + print(output) + expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]]) + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) diff --git a/tests/st/ops/gpu/test_roi_align_op.py b/tests/st/ops/gpu/test_roi_align_op.py new file mode 100644 index 0000000000..e31d54ef05 --- /dev/null +++ b/tests/st/ops/gpu/test_roi_align_op.py @@ -0,0 +1,95 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_roi_align(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = Tensor(np.array([[ + [[1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30], + [31, 32, 33, 34, 35, 36]] + ]], np.float32)) + + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + + # test case 1 + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[2.75, 4.5, 6.5], + [13.25, 15., 17.], + [25.25, 27., 29.]]]] + assert (output.asnumpy() == expect).all() + + # test case 1 + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[2.75, 4.5, 6.5], + [13.25, 15., 17.], + [25.25, 27., 29.]]]] + assert (output.asnumpy() == expect).all() + + # test case 2 + pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[1.2333, 2.1000, 3.3000, 4.5000], + [6.4333, 7.3000, 8.5000, 9.7000], + [13.6333, 14.5000, 15.7000, 16.9000], + [20.8333, 21.7000, 22.9000, 24.1000]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + + # test case 3 + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.3, 3 + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0], + [0, 1.0, 0.0, 19.0, 18.0]], + np.float32)) + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[3.3333, 5.5000, 7.6667], + [16.3333, 18.5000, 20.6667], + [29.3333, 31.5000, 33.6667]]], + [[[4.5000, 6.3000, 8.1000], + [14.9000, 16.7000, 18.5000], + [25.7000, 27.5000, 29.3000]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + + # test case 4 + pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1 + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[4.625, 0.], + [0., 0.]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) diff --git a/tests/st/ops/gpu/test_scatter_nd.py b/tests/st/ops/gpu/test_scatter_nd.py new file mode 100644 index 0000000000..b201c7be2c --- /dev/null +++ b/tests/st/ops/gpu/test_scatter_nd.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class Net(nn.Cell): + def __init__(self, _shape): + super(Net, self).__init__() + self.shape = _shape + self.scatternd = P.ScatterNd() + + def construct(self, indices, update): + return self.scatternd(indices, update, self.shape) + +def scatternd_net(indices, update, _shape, expect): + scatternd = Net(_shape) + output = scatternd(Tensor(indices), Tensor(update)) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_scatternd(): + arr_indices = np.array([[0, 1], [1, 1]]).astype(np.int32) + arr_update = np.array([3.2, 1.1]).astype(np.float32) + shape = (2, 2) + expect = np.array([[0., 3.2], + [0., 1.1]]) + scatternd_net(arr_indices, arr_update, shape, expect) diff --git a/tests/st/ops/gpu/test_sgd_op.py b/tests/st/ops/gpu/test_sgd_op.py new file mode 100644 index 0000000000..85d470f50d --- /dev/null +++ b/tests/st/ops/gpu/test_sgd_op.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import SGD +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetSGD(nn.Cell): + def __init__(self): + super(NetSGD, self).__init__() + self.batch_size = 1 + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_SGD(): + epoch = 3 + net = NetSGD() + learning_rate = 0.1 + momentum = 0.9 + dampening = 0.0 + weight_decay = 0.0 + nesterov = True + loss_scale = 1.0 + + optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening, + weight_decay, nesterov, loss_scale) + criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses.append(loss.asnumpy()) + + last_loss = 100.0 + for loss in losses: + assert last_loss > loss + last_loss = loss + return losses diff --git a/tests/st/ops/gpu/test_split.py b/tests/st/ops/gpu/test_split.py new file mode 100644 index 0000000000..f9e3cfce2f --- /dev/null +++ b/tests/st/ops/gpu/test_split.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +import mindspore.nn as nn +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, axis=0, out_nums=1): + super(Net, self).__init__() + self.split = P.Split(axis, out_nums) + + def construct(self, x): + return self.split(x) + + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_split(): + x = np.array([[[1, -1, 1], [2, -2, 2]], + [[3, -3, 3], [4, -4, 4]], + [[5, -5, 5], [6, -6, 6]]]).astype(np.float32) + + split_op = Net(0, 3) + outputs = split_op(Tensor(x)) + for i, out in enumerate(outputs): + assert (out.asnumpy() == x[i]).all() + + +def test_split_4d(): + x_np = np.random.randn(2, 6, 4, 4).astype(np.float32) + y = np.split(x_np, 3, axis=1) + + split_op = Net(1, 3) + outputs = split_op(Tensor(x_np)) + + for i, out in enumerate(outputs): + assert (out.asnumpy() == y[i]).all() diff --git a/tests/st/ops/gpu/test_standard_normal.py b/tests/st/ops/gpu/test_standard_normal.py new file mode 100644 index 0000000000..efa4a99d74 --- /dev/null +++ b/tests/st/ops/gpu/test_standard_normal.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.context as context +import mindspore.nn as nn +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0, seed2=0): + super(Net, self).__init__() + self.shape = shape + self.seed = seed + self.seed2 = seed2 + self.stdnormal = P.StandardNormal(seed, seed2) + + def construct(self): + return self.stdnormal(self.shape) + + +def test_net(): + seed = 10 + seed2 = 10 + shape = (3, 2, 4) + net = Net(shape, seed, seed2) + output = net() + assert output.shape == (3, 2, 4) diff --git a/tests/st/ops/gpu/test_stridedslice_grad_op.py b/tests/st/ops/gpu/test_stridedslice_grad_op.py index 2ab8136857..39d31c53cb 100644 --- a/tests/st/ops/gpu/test_stridedslice_grad_op.py +++ b/tests/st/ops/gpu/test_stridedslice_grad_op.py @@ -19,31 +19,292 @@ import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor -from mindspore.common.api import ms_function from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C context.set_context(mode=context.GRAPH_MODE, device_target='GPU') -class StridedSliceGrad(nn.Cell): - def __init__(self): - super(StridedSliceGrad, self).__init__() - self.ssg = G.StridedSliceGrad() - self.shape = P.Shape() +class StridedSliceNet(nn.Cell): + def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0): + super(StridedSliceNet, self).__init__() + self.begin = begin + self.end = end + self.strides = stride + self.slice = P.StridedSlice(begin_mask, end_mask, ellipsis_mask) - @ms_function - def construct(self, dy, x): - return self.ssg(dy, self.shape(x), (2, 0, 0), (3, 2, 3), (1, 1, 1)) + def construct(self, x): + return self.slice(x, self.begin, self.end, self.strides) +class GradData(nn.Cell): + def __init__(self, network): + super(GradData, self).__init__() + self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=False) + self.network = network + + def construct(self, x): + return self.grad(self.network)(x) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_slice(): - x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) - dy = Tensor(np.array([[[5., 1., 5.], [6., 1., 8.]]]).astype(np.float32)) - ssg = StridedSliceGrad() - output = ssg(dy, x) - expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]] - assert (output.asnumpy() == expect).all() +def test_strided_slice_grad(): + x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32)) + net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) + dx = GradData(net)(x) + expect = np.array([[[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]], + + + [[[0., 0., 1., 1., 0.], + [0., 0., 1., 1., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 1., 1., 0.], + [0., 0., 1., 1., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]]]) + assert np.allclose(dx[0].asnumpy(), expect) + + net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2)) + dx = GradData(net)(x) + expect = np.array([[[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]], + + + [[[0., 0., 1., 0., 1.], + [0., 0., 1., 0., 1.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 1., 0., 1.], + [0., 0., 1., 0., 1.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]]]) + assert np.allclose(dx[0].asnumpy(), expect) + + + net = StridedSliceNet((1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1)) + dx = GradData(net)(x) + expect = np.array([[[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]], + + + [[[0., 0., 1., 1., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 1., 1., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]]]) + assert np.allclose(dx[0].asnumpy(), expect) + + # ME infer fault + # y = GradData()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2)) + # expect = np.array([[[[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]]], + + + # [[[0., 0., 0., 0., 0.], + # [0., 1., 0., 1., 0.], + # [0., 1., 0., 1., 0.], + # [0., 1., 0., 1., 0.]], + + # [[0., 0., 0., 0., 0.], + # [0., 1., 0., 1., 0.], + # [0., 1., 0., 1., 0.], + # [0., 1., 0., 1., 0.]],begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100 + + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]]]]) + # assert np.allclose(y.asnumpy(), expect) + + # y = Grad(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) + # expect = np.array([[[[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]]], + + + # [[[0., 0., 1., 1., 0.], + # [0., 0., 1., 1., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + + # [[0., 0., 1., 1., 0.], + # [0., 0., 1., 1., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]]]]) + # assert np.allclose(y.asnumpy(), expect) + + + net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1), + begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100) + dx = GradData(net)(x) + expect = np.array([[[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]], + + + [[[1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.]], + + [[1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.]], + + [[1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.]]]]) + assert np.allclose(dx[0].asnumpy(), expect) + + x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32)) + net = StridedSliceNet((1, 0, 0), (2, -3, 3), (1, 1, 3)) + dx = GradData(net)(x) + expect = np.array([[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[1., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]]) + assert np.allclose(dx[0].asnumpy(), expect) + + x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(np.float32)) + net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1)) + dx = GradData(net)(x) + expect = np.array([[[[[[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]], + + [[[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 1., 1., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 1., 1., 0.], + [0., 0., 0., 0., 0.]]]]]]]) + assert np.allclose(dx[0].asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_stridedslice_op.py b/tests/st/ops/gpu/test_stridedslice_op.py index d6e4776931..098d18c9cb 100644 --- a/tests/st/ops/gpu/test_stridedslice_op.py +++ b/tests/st/ops/gpu/test_stridedslice_op.py @@ -17,29 +17,91 @@ import numpy as np import pytest import mindspore.context as context -import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, device_target='GPU') -class StridedSlice(nn.Cell): - def __init__(self): - super(StridedSlice, self).__init__() - self.stridedslice = P.StridedSlice() - - def construct(self, x): - return self.stridedslice(x, (2, 0, 0), (3, 2, 3), (1, 1, 1)) - - @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_slice(): - x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.int32)) - stridedslice = StridedSlice() - output = stridedslice(x) - expect = [[[5., 5., 5.], - [6., 7., 8.]]] - assert (output.asnumpy() == expect).all() +def test_stridedslice(): + x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32)) + y = P.StridedSlice()(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) + expect = np.array([[[[62, 63], + [67, 68]], + [[82, 83], + [87, 88]]]]) + assert np.allclose(y.asnumpy(), expect) + + y = P.StridedSlice()(x, (1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2)) + expect = np.array([[[[64, 62], + [69, 67]], + [[84, 82], + [89, 87]]]]) + assert np.allclose(y.asnumpy(), expect) + + y = P.StridedSlice()(x, (1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1)) + expect = np.array([[[[64, 63, 62], + [69, 68, 67]], + [[84, 83, 82], + [89, 88, 87]]]]) + assert np.allclose(y.asnumpy(), expect) + + # ME infer fault + # y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2)) + # expect = np.array([[[[78, 76], + # [73, 71], + # [68, 66]], + # [[98, 96], + # [93, 91], + # [88, 86]]]]) + # assert np.allclose(y.asnumpy(), expect) + + # y = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) + # expect = np.array([[[[ 62, 63], + # [ 67, 68]], + # [[ 82, 83], + # [ 87, 88]], + # [[102, 103], + # [107, 108]]]]) + # assert np.allclose(y.asnumpy(), expect) + + op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100) + y = op(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) + expect = np.array([[[[60, 61, 62, 63], + [65, 66, 67, 68], + [70, 71, 72, 73], + [75, 76, 77, 78]], + [[80, 81, 82, 83], + [85, 86, 87, 88], + [90, 91, 92, 93], + [95, 96, 97, 98]], + [[100, 101, 102, 103], + [105, 106, 107, 108], + [110, 111, 112, 113], + [115, 116, 117, 118]]]]) + assert np.allclose(y.asnumpy(), expect) + + x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32)) + y = P.StridedSlice()(x, (1, 0, 0), (2, -3, 3), (1, 1, 3)) + expect = np.array([[[20]]]) + assert np.allclose(y.asnumpy(), expect) + + x_np = np.arange(0, 4*5).reshape(4, 5).astype(np.float32) + y = Tensor(x_np)[:, ::-1] + expect = x_np[:, ::-1] + assert np.allclose(y.asnumpy(), expect) + + x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(np.float32)) + y = P.StridedSlice()(x, (1, 0, 0, 2, 1, 2, 0), (2, 2, 2, 4, 2, 3, 2), (1, 1, 1, 1, 1, 1, 2)) + expect = np.array([[[[[[[1498.]]], + [[[1522.]]]], + [[[[1618.]]], + [[[1642.]]]]], + [[[[[1978.]]], + [[[2002.]]]], + [[[[2098.]]], + [[[2122.]]]]]]]) + assert np.allclose(y.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_topk_op.py b/tests/st/ops/gpu/test_topk_op.py new file mode 100644 index 0000000000..83cd8e6403 --- /dev/null +++ b/tests/st/ops/gpu/test_topk_op.py @@ -0,0 +1,82 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_topk(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + x_np = np.random.rand(3, 4).astype(np.float32) + k = 4 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + x_np = np.random.rand(3, 4).astype(np.float32) + k = 4 + ms_output = P.TopK(False)(Tensor(x_np), k) + assert np.allclose(ms_output[0].asnumpy(), x_np) + + x_np = np.random.rand(2, 3, 4).astype(np.float32) + k = 2 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + x_np = np.random.rand(512, 1024).astype(np.float32) + k = 512 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + # sorted elements num greater than max thread per block + x_np = np.random.rand(512, 2048).astype(np.float32) + k = 1 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + x_np = np.random.rand(512, 2048).astype(np.float32) + k = 2048 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + # sorted elements num greater than max share memory per block + x_np = np.random.rand(512, 40960).astype(np.float32) + k = 1 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + x_np = np.random.rand(512, 40960).astype(np.float32) + k = 40960 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + x_np = np.random.rand(512, 40960).astype(np.float32) + k = 40960 + ms_output = P.TopK(False)(Tensor(x_np), k) + assert np.allclose(ms_output[0].asnumpy(), x_np) diff --git a/tests/st/ops/gpu/test_uniform_real.py b/tests/st/ops/gpu/test_uniform_real.py new file mode 100644 index 0000000000..8fa4b0eb0b --- /dev/null +++ b/tests/st/ops/gpu/test_uniform_real.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.uniformreal = P.UniformReal(seed=seed) + self.shape = shape + + def construct(self, a, b): + return self.uniformreal(self.shape, a, b) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + a = 0.0 + b = 1.0 + net = Net(shape, seed) + ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 4) diff --git a/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh b/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh new file mode 100644 index 0000000000..e1afc1dc14 --- /dev/null +++ b/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +self_path=$(dirname "${script_self}") +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +DEVICE_TARGET=$1 +export MS_WORKER_NUM=$2 +export MS_SERVER_NUM=$3 +export MS_SCHED_HOST=$4 +export MS_SCHED_PORT=$5 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + python ${self_path}/../test_cmp_sparse_embedding.py & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + python ${self_path}/../test_cmp_sparse_embedding.py & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + python ${self_path}/../test_cmp_sparse_embedding.py & +done + +wait $! +exit $? diff --git a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py new file mode 100644 index 0000000000..c08b5b9936 --- /dev/null +++ b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py @@ -0,0 +1,106 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import argparse +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Adam +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore import Parameter + +parser = argparse.ArgumentParser(description="test_sparse_embedding") +parser.add_argument("--device_target", type=str, default="Ascend") +args, _ = parser.parse_known_args() +device_target = args.device_target +context.set_context( + mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True +) + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10): + super(LeNet5, self).__init__() + self.cast = P.Cast() + self.flatten = nn.Flatten() + self.embedding_table = Parameter( + initializer("normal", (16, 4), mstype.float32), name="embedding_table" + ) + self.embedding = nn.EmbeddingLookup() + self.relu = nn.ReLU() + self.fc = fc_with_initialize(12, num_class) + + def construct(self, x): + x = self.cast(x, mstype.int32) + x = self.embedding(self.embedding_table, x) + x = self.flatten(x) + x = self.fc(x) + return x + + +def do_sparse_embedding(ps=False): + epoch = 10 + net = LeNet5(10) + if ps: + net.embedding_table.set_param_ps() + + optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters())) + optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") + criterion = nn.SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction="mean" + ) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.random.randint(0, 15, (32, 3), np.int32)) + label = Tensor(np.random.randint(0, 9, (32), np.int32)) + loss = train_network(data, label).asnumpy() + losses.append(loss) + print(losses) + return losses + + +envs = os.environ +if __name__ == "__main__": + np.random.seed(0) + ps_loss = do_sparse_embedding(True) + + if envs.get("MS_ROLE") == "MS_WORKER": + envs["MS_ROLE"] = "" + np.random.seed(0) + no_ps_loss = do_sparse_embedding() + envs["MS_ROLE"] = "MS_WORKER" + + assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6) diff --git a/tests/st/ps/cmp_sparse_embedding/test_entry_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_entry_cmp_sparse_embedding.py new file mode 100644 index 0000000000..bc400c963c --- /dev/null +++ b/tests/st/ps/cmp_sparse_embedding/test_entry_cmp_sparse_embedding.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_cmp_sparse_embedding(): + return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8081") + assert return_code == 0 diff --git a/tests/st/ps/full_ps/shell_run_test.sh b/tests/st/ps/full_ps/shell_run_test.sh new file mode 100644 index 0000000000..8222e76888 --- /dev/null +++ b/tests/st/ps/full_ps/shell_run_test.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +self_path=$(dirname "${script_self}") +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +DEVICE_TARGET=$1 +DATASET_PATH=$2 +export MS_WORKER_NUM=$3 +export MS_SERVER_NUM=$4 +export MS_SCHED_HOST=$5 +export MS_SCHED_PORT=$6 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + python ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET --dataset_path=$DATASET_PATH & +done + +wait $! +exit $? diff --git a/tests/st/ps/full_ps/test_entry_full_ps_lenet.py b/tests/st/ps/full_ps/test_entry_full_ps_lenet.py new file mode 100644 index 0000000000..9d11a52bc6 --- /dev/null +++ b/tests/st/ps/full_ps/test_entry_full_ps_lenet.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_full_ps_ascend_lenet(): + return_code = os.system( + "bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082" + ) + assert return_code == 0 diff --git a/tests/st/ps/full_ps/test_full_ps_lenet.py b/tests/st/ps/full_ps/test_full_ps_lenet.py new file mode 100644 index 0000000000..fbf48e5fb8 --- /dev/null +++ b/tests/st/ps/full_ps/test_full_ps_lenet.py @@ -0,0 +1,137 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import argparse + +import mindspore.context as context +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV +import mindspore.nn as nn +from mindspore.common import dtype as mstype +from mindspore.dataset.transforms.vision import Inter +from mindspore.nn.metrics import Accuracy +from mindspore.train import Model +from mindspore.train.callback import LossMonitor +from mindspore.common.initializer import TruncatedNormal + +parser = argparse.ArgumentParser(description='test_ps_lenet') +parser.add_argument("--device_target", type=str, default="Ascend") +parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist") +args, _ = parser.parse_known_args() +device_target = args.device_target +dataset_path = args.dataset_path +context.set_context(mode=context.GRAPH_MODE, device_target=device_target) + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, channel=1): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + +def create_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1): + """ + create dataset for train or test + """ + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + rescale_nml = 1 / 0.3081 + shift_nml = -1 * 0.1307 / 0.3081 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode + rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + # apply map operations on images + mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + buffer_size = 10000 + mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script + mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) + mnist_ds = mnist_ds.repeat(repeat_size) + + return mnist_ds + +if __name__ == "__main__": + network = LeNet5(10) + network.set_param_ps() + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + ds_train = create_dataset(os.path.join(dataset_path, "train"), 32, 1) + model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False) + + ds_eval = create_dataset(os.path.join(dataset_path, "test"), 32, 1) + acc = model.eval(ds_eval, dataset_sink_mode=False) + + print("Accuracy:", acc['Accuracy']) + assert acc['Accuracy'] > 0.93 diff --git a/tests/st/ps/multi_worker_full_ps/entry.py b/tests/st/ps/multi_worker_full_ps/entry.py new file mode 100644 index 0000000000..e54623144d --- /dev/null +++ b/tests/st/ps/multi_worker_full_ps/entry.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os + +# @pytest.mark.level0 +# @pytest.mark.platform_arm_ascend_training +# @pytest.mark.platform_x86_ascend_training +# @pytest.mark.env_single +def test_multi_worker_full_ps_ascend_lenet(): + return_code = os.system("bash shell_run_test.sh Ascend 8 1 127.0.0.1 8088") + assert return_code == 0 + + +# @pytest.mark.level0 +# @pytest.mark.platform_arm_ascend_training +# @pytest.mark.platform_x86_ascend_training +# @pytest.mark.env_onecard +def test_full_ps_ascend_lenet(): + return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8088") + assert return_code == 0 diff --git a/tests/st/ps/multi_worker_full_ps/shell_run_test.sh b/tests/st/ps/multi_worker_full_ps/shell_run_test.sh new file mode 100644 index 0000000000..47cb5a4dcf --- /dev/null +++ b/tests/st/ps/multi_worker_full_ps/shell_run_test.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +self_path=$(dirname "${script_self}") +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +DEVICE_TARGET=$1 +export MS_WORKER_NUM=$2 +export MS_SERVER_NUM=$3 +export MS_SCHED_HOST=$4 +export MS_SCHED_PORT=$5 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & +done + +wait $! +exit $? diff --git a/tests/st/ps/multi_worker_full_ps/test_multi_worker_full_ps_lenet.py b/tests/st/ps/multi_worker_full_ps/test_multi_worker_full_ps_lenet.py new file mode 100644 index 0000000000..c08f923e0d --- /dev/null +++ b/tests/st/ps/multi_worker_full_ps/test_multi_worker_full_ps_lenet.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell + +parser = argparse.ArgumentParser(description="test_ps_lenet") +parser.add_argument("--device_target", type=str, default="Ascend") +args, _ = parser.parse_known_args() +device_target = args.device_target +context.set_context(mode=context.GRAPH_MODE, device_target=device_target) + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_init=weight, + has_bias=False, + pad_mode="valid", + ) + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, channel=3): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +if __name__ == "__main__": + epoch = 5 + np.random.seed(0) + network = LeNet5(10) + network.set_param_ps() + criterion = nn.SoftmaxCrossEntropyWithLogits( + is_grad=False, sparse=True, reduction="mean" + ) + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + + net_with_criterion = WithLossCell(network, criterion) + train_network = TrainOneStepCell(net_with_criterion, net_opt) + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) + label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32)) + loss = train_network(data, label).asnumpy() + losses.append(loss) + print(losses) diff --git a/tests/st/pynative/test_ops.py b/tests/st/pynative/test_ops.py index 3cec24fb10..c43e626be5 100644 --- a/tests/st/pynative/test_ops.py +++ b/tests/st/pynative/test_ops.py @@ -28,4 +28,4 @@ def test_cast(): type_dst = ms.float32 cast = P.Cast() result = cast(input_x, type_dst) - assert result.dtype() == type_dst + assert result.dtype == type_dst diff --git a/tests/st/pynative/test_pynative_resnet50.py b/tests/st/pynative/test_pynative_resnet50.py index de9ecebb9c..720dad3ec3 100644 --- a/tests/st/pynative/test_pynative_resnet50.py +++ b/tests/st/pynative/test_pynative_resnet50.py @@ -413,6 +413,7 @@ def test_pynative_resnet50(): step = 0 max_step = 20 + exceed_num = 0 data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size) for element in data_set.create_dict_iterator(): step = step + 1 @@ -427,6 +428,7 @@ def test_pynative_resnet50(): end_time = time.time() cost_time = end_time - start_time print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) - if step > 1: - assert cost_time < 0.3 - \ No newline at end of file + if step > 1 and cost_time > 0.21: + exceed_num = exceed_num + 1 + assert exceed_num < 10 + \ No newline at end of file diff --git a/tests/st/serving/__init__.py b/tests/st/serving/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/st/serving/client_example.py b/tests/st/serving/client_example.py new file mode 100644 index 0000000000..ae203aea39 --- /dev/null +++ b/tests/st/serving/client_example.py @@ -0,0 +1,98 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import random +import grpc +import numpy as np +import ms_service_pb2 +import ms_service_pb2_grpc +import mindspore.dataset as de +from mindspore import Tensor, context +from mindspore import log as logger +from tests.st.networks.models.bert.src.bert_model import BertModel +from .generate_model import AddNet, bert_net_cfg + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +def test_add(): + channel = grpc.insecure_channel('localhost:5500') + stub = ms_service_pb2_grpc.MSServiceStub(channel) + request = ms_service_pb2.PredictRequest() + + x = request.data.add() + x.tensor_shape.dims.extend([4]) + x.tensor_type = ms_service_pb2.MS_FLOAT32 + x.data = (np.ones([4]).astype(np.float32)).tobytes() + + y = request.data.add() + y.tensor_shape.dims.extend([4]) + y.tensor_type = ms_service_pb2.MS_FLOAT32 + y.data = (np.ones([4]).astype(np.float32)).tobytes() + + result = stub.Predict(request) + result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) + print("ms client received: ") + print(result_np) + + net = AddNet() + net_out = net(Tensor(np.ones([4]).astype(np.float32)), Tensor(np.ones([4]).astype(np.float32))) + print("add net out: ") + print(net_out) + assert np.allclose(net_out.asnumpy(), result_np, 0.001, 0.001, equal_nan=True) + +def test_bert(): + MAX_MESSAGE_LENGTH = 0x7fffffff + input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32) + segment_ids = np.zeros((2, 32), dtype=np.int32) + input_mask = np.zeros((2, 32), dtype=np.int32) + channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH), + ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)]) + stub = ms_service_pb2_grpc.MSServiceStub(channel) + request = ms_service_pb2.PredictRequest() + + x = request.data.add() + x.tensor_shape.dims.extend([2, 32]) + x.tensor_type = ms_service_pb2.MS_INT32 + x.data = input_ids.tobytes() + + y = request.data.add() + y.tensor_shape.dims.extend([2, 32]) + y.tensor_type = ms_service_pb2.MS_INT32 + y.data = segment_ids.tobytes() + + z = request.data.add() + z.tensor_shape.dims.extend([2, 32]) + z.tensor_type = ms_service_pb2.MS_INT32 + z.data = input_mask.tobytes() + + result = stub.Predict(request) + result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) + print("ms client received: ") + print(result_np) + + net = BertModel(bert_net_cfg, False) + bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask)) + print("bert out: ") + print(bert_out) + bert_out_size = len(bert_out) + for i in range(bert_out_size): + result_np = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims) + logger.info("i:{}, result_np:{}, bert_out:{}". + format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape)) + assert np.allclose(bert_out[i].asnumpy(), result_np, 0.001, 0.001, equal_nan=True) diff --git a/tests/st/serving/generate_model.py b/tests/st/serving/generate_model.py new file mode 100644 index 0000000000..f7d47392f6 --- /dev/null +++ b/tests/st/serving/generate_model.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import random +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +import mindspore.dataset as de +from mindspore import Tensor, context +from mindspore.ops import operations as P +from mindspore.train.serialization import export +from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig + +bert_net_cfg = BertConfig( + batch_size=2, + seq_length=32, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 +) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +class AddNet(nn.Cell): + def __init__(self): + super(AddNet, self).__init__() + self.add = P.TensorAdd() + + def construct(self, x_, y_): + return self.add(x_, y_) + +def export_add_model(): + net = AddNet() + x = np.ones(4).astype(np.float32) + y = np.ones(4).astype(np.float32) + export(net, Tensor(x), Tensor(y), file_name='add.pb', file_format='BINARY') + +def export_bert_model(): + net = BertModel(bert_net_cfg, False) + input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32) + segment_ids = np.zeros((2, 32), dtype=np.int32) + input_mask = np.zeros((2, 32), dtype=np.int32) + export(net, Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask), file_name='bert.pb', file_format='BINARY') + +if __name__ == '__main__': + export_add_model() + export_bert_model() diff --git a/tests/st/serving/serving.sh b/tests/st/serving/serving.sh new file mode 100644 index 0000000000..dd40293062 --- /dev/null +++ b/tests/st/serving/serving.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +export GLOG_v=1 +export DEVICE_ID=1 + +MINDSPORE_INSTALL_PATH=$1 +CURRPATH=$(cd $(dirname $0); pwd) +CURRUSER=$(whoami) +PROJECT_PATH=${CURRPATH}/../../../ +ENV_DEVICE_ID=$DEVICE_ID +echo "MINDSPORE_INSTALL_PATH:" ${MINDSPORE_INSTALL_PATH} +echo "CURRPATH:" ${CURRPATH} +echo "CURRUSER:" ${CURRUSER} +echo "PROJECT_PATH:" ${PROJECT_PATH} +echo "ENV_DEVICE_ID:" ${ENV_DEVICE_ID} + +MODEL_PATH=${CURRPATH}/model +export LD_LIBRARY_PATH=${MINDSPORE_INSTALL_PATH}/lib:/usr/local/python/python375/lib/:${LD_LIBRARY_PATH} +export PYTHONPATH=${MINDSPORE_INSTALL_PATH}/../:${PYTHONPATH} + +echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} +echo "PYTHONPATH: " ${PYTHONPATH} +echo "-------------show MINDSPORE_INSTALL_PATH----------------" +ls -l ${MINDSPORE_INSTALL_PATH} +echo "------------------show /usr/lib64/----------------------" +ls -l /usr/local/python/python375/lib/ + +clean_pid() +{ + ps aux | grep 'ms_serving' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 + if [ $? -ne 0 ] + then + echo "clean pip failed" + fi + sleep 6 +} + +prepare_model() +{ + echo "### begin to generate mode for serving test ###" + python3 generate_model.py &> generate_model_serving.log + echo "### end to generate mode for serving test ###" + result=`ls -l | grep -E '*pb' | grep -v ".log" | wc -l` + if [ ${result} -ne 2 ] + then + cat generate_model_serving.log + echo "### generate model for serving test failed ###" && exit 1 + clean_pid + fi + rm -rf model + mkdir model + mv *.pb ${CURRPATH}/model + cp ${MINDSPORE_INSTALL_PATH}/ms_serving ./ +} + +start_service() +{ + ${CURRPATH}/ms_serving --port=$1 --model_path=${MODEL_PATH} --model_name=$2 --device_id=$3 > $2_service.log 2>&1 & + if [ $? -ne 0 ] + then + echo "$2 faile to start." + fi + + result=`grep -E 'MS Serving listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l` + count=0 + while [[ ${result} -ne 1 && ${count} -lt 150 ]] + do + sleep 1 + count=$(($count+1)) + result=`grep -E 'MS Serving listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l` + done + + if [ ${count} -eq 150 ] + then + clean_pid + cat $2_service.log + echo "start serving service failed!" && exit 1 + fi + echo "### start serving service end ###" +} + +pytest_serving() +{ + unset http_proxy https_proxy + CLIENT_DEVICE_ID=$((${ENV_DEVICE_ID}+1)) + export DEVICE_ID=${CLIENT_DEVICE_ID} + local test_client_name=$1 + echo "### $1 client start ###" + python3 -m pytest -v -s client_example.py::${test_client_name} > ${test_client_name}_client.log 2>&1 + if [ $? -ne 0 ] + then + clean_pid + cat ${test_client_name}_client.log + echo "client $1 faile to start." + fi + echo "### $1 client end ###" +} + +test_add_model() +{ + start_service 5500 add.pb ${ENV_DEVICE_ID} + pytest_serving test_add + clean_pid +} + +test_bert_model() +{ + start_service 5500 bert.pb ${ENV_DEVICE_ID} + pytest_serving test_bert + clean_pid +} + +echo "-----serving start-----" +rm -rf ms_serving *.log *.pb *.dat ${CURRPATH}/model ${CURRPATH}/kernel_meta +prepare_model +test_add_model +test_bert_model diff --git a/tests/st/serving/test_serving.py b/tests/st/serving/test_serving.py new file mode 100644 index 0000000000..be94610ae5 --- /dev/null +++ b/tests/st/serving/test_serving.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import sys +import pytest +import numpy as np + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_single +def test_serving(): + """test_serving""" + sh_path = os.path.split(os.path.realpath(__file__))[0] + python_path_folders = [] + for python_path in sys.path: + if os.path.isdir(python_path): + python_path_folders += [python_path] + folders = [] + for folder in python_path_folders: + folders += [os.path.join(folder, x) for x in os.listdir(folder) \ + if os.path.isdir(os.path.join(folder, x)) and '/site-packages/mindspore' in os.path.join(folder, x)] + ret = os.system(f"sh {sh_path}/serving.sh {folders[0].split('mindspore', 1)[0] + 'mindspore'}") + assert np.allclose(ret, 0, 0.0001, 0.0001) + +if __name__ == '__main__': + test_serving() diff --git a/tests/st/tbe_networks/resnet_cifar.py b/tests/st/tbe_networks/resnet_cifar.py index cf9eb59400..c6b1ee0a78 100644 --- a/tests/st/tbe_networks/resnet_cifar.py +++ b/tests/st/tbe_networks/resnet_cifar.py @@ -136,7 +136,7 @@ if __name__ == '__main__': model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) if args_opt.do_train: - dataset = create_dataset(epoch_size) + dataset = create_dataset(1) batch_num = dataset.get_dataset_size() config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10) ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck) diff --git a/tests/st/tbe_networks/test_resnet_cifar_1p.py b/tests/st/tbe_networks/test_resnet_cifar_1p.py index 672d17c72b..8ef48b8774 100644 --- a/tests/st/tbe_networks/test_resnet_cifar_1p.py +++ b/tests/st/tbe_networks/test_resnet_cifar_1p.py @@ -140,7 +140,7 @@ def train_process(epoch_size, num_classes, batch_size): model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - dataset = create_dataset(epoch_size, training=True, batch_size=batch_size) + dataset = create_dataset(1, training=True, batch_size=batch_size) loss_cb = LossGet() model.train(epoch_size, dataset, callbacks=[loss_cb]) diff --git a/tests/st/tbe_networks/test_resnet_cifar_8p.py b/tests/st/tbe_networks/test_resnet_cifar_8p.py index a13f367b9f..56d6a91d64 100644 --- a/tests/st/tbe_networks/test_resnet_cifar_8p.py +++ b/tests/st/tbe_networks/test_resnet_cifar_8p.py @@ -164,7 +164,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size, model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - dataset = create_dataset(epoch_size, training=True, + dataset = create_dataset(1, training=True, batch_size=batch_size, rank_id=device_id, rank_size=device_num, enable_hccl=enable_hccl) diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 880a281037..eddf8f66ba 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -39,6 +39,7 @@ if(ENABLE_MINDDATA) dataset/filter_op_test.cc dataset/voc_op_test.cc dataset/manifest_op_test.cc + dataset/sentence_piece_vocab_op_test.cc ) list(REMOVE_ITEM UT_SRCS ${PYTHON_RELATED_SRCS}) endif() @@ -50,11 +51,16 @@ else() endif() endforeach () endif() +# removing serving ut +file(GLOB_RECURSE SERVING_ACL_UT_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} serving/*.cc) +list(REMOVE_ITEM UT_SRCS ${SERVING_ACL_UT_SRCS}) +add_subdirectory(serving) file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/core/base/*.cc" "../../../mindspore/core/abstract/*.cc" "../../../mindspore/core/ir/*.cc" + "../../../mindspore/core/utils/*.cc" "../../../mindspore/ccsrc/common/*.cc" "../../../mindspore/ccsrc/utils/*.cc" "../../../mindspore/ccsrc/pipeline/jit/parse/*.cc" @@ -76,6 +82,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/backend/session/kernel_graph.cc" "../../../mindspore/ccsrc/backend/session/session_basic.cc" "../../../mindspore/ccsrc/backend/session/session_factory.cc" + "../../../mindspore/ccsrc/backend/session/kernel_build_client.cc" "../../../mindspore/ccsrc/vm/*.cc" "../../../mindspore/ccsrc/pipeline/pynative/*.cc" "../../../mindspore/ccsrc/pybind_api/*.cc" @@ -122,6 +129,11 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc" ) +if (CMAKE_SYSTEM_NAME MATCHES "Windows") + list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/duplex_pipe.cc") +else() + list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/duplex_pipe_win.cc") +endif() list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/dump_proto.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/core/ir/lite/tensor.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") @@ -155,7 +167,7 @@ file(GLOB_RECURSE UT_SUTB_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "stub/ge/*.cc" ) -add_executable(ut_tests ${UT_SRCS} ${MINDSPORE_SRC_LIST} ${UT_SUTB_SRC_LIST}) +add_executable(ut_tests ${UT_SRCS} ${MINDSPORE_SRC_LIST} ${UT_SUTB_SRC_LIST} $) if (ENABLE_GE) if(ENABLE_TRAIN) @@ -180,3 +192,14 @@ if (USE_GLOG) endif() target_link_libraries(ut_tests PRIVATE securec graph) + +# link grpc +if (EXISTS ${grpc_ROOT}/lib64) + set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc") +else () + set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc") +endif () +find_package(gRPC CONFIG REQUIRED) +target_link_libraries(ut_tests PRIVATE gRPC::grpc++) +target_link_libraries(ut_tests PRIVATE gRPC::grpc++_reflection) +target_link_libraries(ut_tests PRIVATE protobuf::libprotobuf) \ No newline at end of file diff --git a/tests/ut/cpp/common/backend_common_test.h b/tests/ut/cpp/common/backend_common_test.h index f5bfc9d6dd..aeeccef2bc 100644 --- a/tests/ut/cpp/common/backend_common_test.h +++ b/tests/ut/cpp/common/backend_common_test.h @@ -16,7 +16,7 @@ #ifndef TESTS_UT_CPP_COMMON_UT_BACKEND_COMMON_H_ #define TESTS_UT_CPP_COMMON_UT_BACKEND_COMMON_H_ #include "common/common_test.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/session/kernel_graph.h" namespace mindspore { diff --git a/tests/ut/cpp/common/common_test.h b/tests/ut/cpp/common/common_test.h index a293584d7b..8490046f13 100644 --- a/tests/ut/cpp/common/common_test.h +++ b/tests/ut/cpp/common/common_test.h @@ -16,6 +16,9 @@ #ifndef TESTS_UT_COMMON_UT_COMMON_H_ #define TESTS_UT_COMMON_UT_COMMON_H_ +#include +#include +#include #include "gtest/gtest.h" namespace UT { class Common : public testing::Test { @@ -27,6 +30,47 @@ class Common : public testing::Test { // every TEST_F macro will enter one virtual void SetUp(); virtual void TearDown(); + + template + void PrintData(std::string name, T *output_data, int size) { + std::cout << "The " << name << " is as follows:" << std::endl; + if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { + for (size_t i = 0; i < std::min(size, 100); i++) { + std::cout << (int)output_data[i] << " "; + } + } else { + for (size_t i = 0; i < std::min(size, 100); i++) { + std::cout << output_data[i] << " "; + } + } + std::cout << std::endl; + } + + template + static void CompareOutputData(T *output_data, T *correct_data, int size, float err_bound) { + for (size_t i = 0; i < size; i++) { + T abs = fabs(output_data[i] - correct_data[i]); + ASSERT_LE(abs, err_bound); + } + } + + void ReadFile(const char *file, size_t *size, char **buf) { + ASSERT_NE(nullptr, file); + ASSERT_NE(nullptr, size); + ASSERT_NE(nullptr, buf); + std::string path = std::string(file); + std::ifstream ifs(path); + ASSERT_EQ(true, ifs.good()); + ASSERT_EQ(true, ifs.is_open()); + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + *buf = new char[*size]; + + ifs.seekg(0, std::ios::beg); + ifs.read(*buf, *size); + ifs.close(); + } }; } // namespace UT #endif // TESTS_UT_COMMON_UT_COMMON_H_ diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index d864842760..ae9467cef1 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -72,6 +72,7 @@ class PyFuncGraphFetcher { mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { std::shared_ptr manager = mindspore::Manage(func_graph, false); + mindspore::parse::python_adapter::set_use_signature_in_resolve(false); mindspore::parse::ResolveAll(manager); } return func_graph; diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 8bbf42a640..bc6d772ad8 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -55,7 +55,7 @@ SET(DE_UT_SRCS resize_bilinear_op_test.cc resize_op_test.cc resize_with_bbox_op_test.cc - schema_test.cc + schema_test.cc shuffle_op_test.cc stand_alone_samplers_test.cc status_test.cc @@ -71,14 +71,12 @@ SET(DE_UT_SRCS subset_random_sampler_test.cc weighted_random_sampler_test.cc mnist_op_test.cc - manifest_op_test.cc - voc_op_test.cc cifar_op_test.cc celeba_op_test.cc take_op_test.cc clue_op_test.cc + csv_op_test.cc text_file_op_test.cc - filter_op_test.cc concat_op_test.cc jieba_tokenizer_op_test.cc tokenizer_op_test.cc @@ -91,9 +89,24 @@ SET(DE_UT_SRCS cyclic_array_test.cc perf_data_test.cc c_api_test.cc - tensor_op_fusion_pass_test.cc + tensor_op_fusion_pass_test.cc + sliding_window_op_test.cc + epoch_ctrl_op_test.cc + sentence_piece_vocab_op_test.cc + swap_red_blue_test.cc + distributed_sampler_test.cc ) +if (ENABLE_PYTHON) + set(DE_UT_SRCS + ${DE_UT_SRCS} + filter_op_test.cc + manifest_op_test.cc + voc_op_test.cc + sentence_piece_vocab_op_test.cc + ) +endif () + add_executable(de_ut_tests ${DE_UT_SRCS}) set_target_properties(de_ut_tests PROPERTIES INSTALL_RPATH "$ORIGIN/../lib:$ORIGIN/../lib64") diff --git a/tests/ut/cpp/dataset/batch_op_test.cc b/tests/ut/cpp/dataset/batch_op_test.cc index 3e1f3c0b32..69090ec6b6 100644 --- a/tests/ut/cpp/dataset/batch_op_test.cc +++ b/tests/ut/cpp/dataset/batch_op_test.cc @@ -18,7 +18,7 @@ #include #include "minddata/dataset/core/client.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "minddata/dataset/core/global_context.h" #include "utils/log_adapter.h" @@ -78,7 +78,10 @@ std::shared_ptr Build(std::vector &op = Batch(12); + EXPECT_EQ(op->Name(), "BatchOp"); + + auto tree = Build({TFReader(schema_file), op}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -90,8 +93,8 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) { rc = di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); std::shared_ptr t; - rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t); EXPECT_TRUE(rc.IsOk()); // verify the actual data in Tensor is correct EXPECT_EQ(*t == *tensor_map["col_sint64"], true); @@ -119,14 +122,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2, t3; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 7)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7), &t2); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 2)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 2), &t3); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -164,17 +167,17 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2, t3, t4; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 7)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7), &t2); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 2)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 2), &t3); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t4, TensorImpl::kFlexible, de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 9)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 9), &t4); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -216,11 +219,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 7)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7), &t2); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -262,11 +265,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 5)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 5), &t2); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -300,7 +303,7 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) { std::shared_ptr op; PadInfo m; std::shared_ptr pad_value; - Tensor::CreateTensor(&pad_value, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)); + Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_value); pad_value->SetItemAt({}, -1); m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)}); de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op); @@ -359,8 +362,8 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) { -1, -1}; std::shared_ptr t; - rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t); de::DatasetIterator di(tree); TensorMap tensor_map; rc = di.GetNextAsMap(&tensor_map); diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 902bc9a43b..11e84f7de4 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -20,7 +20,7 @@ #include #include "utils/log_adapter.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/common.h" #include "gtest/gtest.h" #include "securec.h" @@ -29,15 +29,23 @@ #include "minddata/dataset/include/transforms.h" #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/include/samplers.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" + using namespace mindspore::dataset::api; using mindspore::MsLogLevel::ERROR; using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; using mindspore::dataset::Tensor; +using mindspore::dataset::TensorShape; +using mindspore::dataset::TensorImpl; +using mindspore::dataset::DataType; using mindspore::dataset::Status; using mindspore::dataset::BorderType; +using mindspore::dataset::dsize_t; class MindDataTestPipeline : public UT::DatasetOpTesting { @@ -46,25 +54,27 @@ class MindDataTestPipeline : public UT::DatasetOpTesting { TEST_F(MindDataTestPipeline, TestBatchAndRepeat) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestBatchAndRepeat."; + // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; std::shared_ptr ds = Mnist(folder_path, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 2; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -78,43 +88,53 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 10); + EXPECT_EQ(i, 10); // Manually terminate the pipeline iter->Stop(); } +TEST_F(MindDataTestPipeline, TestMnistFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFail1."; + + // Create a Mnist Dataset + std::shared_ptr ds = Mnist("", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTensorOpsAndMap."; + // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; std::shared_ptr ds = Mnist(folder_path, RandomSampler(false, 20)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr resize_op = vision::Resize({30, 30}); - EXPECT_TRUE(resize_op != nullptr); + EXPECT_NE(resize_op, nullptr); std::shared_ptr center_crop_op = vision::CenterCrop({16, 16}); - EXPECT_TRUE(center_crop_op != nullptr); + EXPECT_NE(center_crop_op, nullptr); // Create a Map operation on ds ds = ds->Map({resize_op, center_crop_op}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -128,44 +148,46 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 40); + EXPECT_EQ(i, 40); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUniformAugWithOps."; + // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; std::shared_ptr ds = Mnist(folder_path, RandomSampler(false, 20)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 1; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr resize_op = vision::Resize({30, 30}); - EXPECT_TRUE(resize_op != nullptr); + EXPECT_NE(resize_op, nullptr); std::shared_ptr random_crop_op = vision::RandomCrop({28, 28}); - EXPECT_TRUE(random_crop_op != nullptr); + EXPECT_NE(random_crop_op, nullptr); std::shared_ptr center_crop_op = vision::CenterCrop({16, 16}); - EXPECT_TRUE(center_crop_op != nullptr); + EXPECT_NE(center_crop_op, nullptr); std::shared_ptr uniform_aug_op = vision::UniformAugment({random_crop_op, center_crop_op}, 2); - EXPECT_TRUE(uniform_aug_op != nullptr); + EXPECT_NE(uniform_aug_op, nullptr); // Create a Map operation on ds ds = ds->Map({resize_op, uniform_aug_op}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -179,43 +201,45 @@ TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestRandomFlip) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomFlip."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr random_vertical_flip_op = vision::RandomVerticalFlip(0.5); - EXPECT_TRUE(random_vertical_flip_op != nullptr); + EXPECT_NE(random_vertical_flip_op, nullptr); std::shared_ptr random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5); - EXPECT_TRUE(random_horizontal_flip_op != nullptr); + EXPECT_NE(random_horizontal_flip_op, nullptr); // Create a Map operation on ds ds = ds->Map({random_vertical_flip_op, random_horizontal_flip_op}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -229,32 +253,34 @@ TEST_F(MindDataTestPipeline, TestRandomFlip) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderBatchAndRepeat."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 2; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -268,13 +294,23 @@ TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 10); + EXPECT_EQ(i, 10); // Manually terminate the pipeline iter->Stop(); } +TEST_F(MindDataTestPipeline, TestImageFolderFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderFail1."; + + // Create an ImageFolder Dataset + std::shared_ptr ds = ImageFolder("", true, nullptr); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderWithSamplers."; + std::shared_ptr sampl = DistributedSampler(2, 1); EXPECT_NE(sampl, nullptr); @@ -327,46 +363,58 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 12); + EXPECT_EQ(i, 12); // Manually terminate the pipeline iter->Stop(); } +TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { + std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; + std::shared_ptr sampl1 = SubsetRandomSampler(indices); + EXPECT_FALSE(indices.empty()); + EXPECT_NE(sampl1->Build(), nullptr); + std::shared_ptr sampl2 = SubsetRandomSampler(std::move(indices)); + EXPECT_TRUE(indices.empty()); + EXPECT_NE(sampl2->Build(), nullptr); +} + TEST_F(MindDataTestPipeline, TestPad) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr pad_op1 = vision::Pad({1, 2, 3, 4}, {0}, BorderType::kSymmetric); - EXPECT_TRUE(pad_op1 != nullptr); + EXPECT_NE(pad_op1, nullptr); std::shared_ptr pad_op2 = vision::Pad({1}, {1, 1, 1}, BorderType::kEdge); - EXPECT_TRUE(pad_op2 != nullptr); + EXPECT_NE(pad_op2, nullptr); std::shared_ptr pad_op3 = vision::Pad({1, 4}); - EXPECT_TRUE(pad_op3 != nullptr); + EXPECT_NE(pad_op3, nullptr); // Create a Map operation on ds ds = ds->Map({pad_op1, pad_op2, pad_op3}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -380,43 +428,45 @@ TEST_F(MindDataTestPipeline, TestPad) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestCutOut) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCutOut."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr cut_out1 = vision::CutOut(30, 5); - EXPECT_TRUE(cut_out1!= nullptr); + EXPECT_NE(cut_out1, nullptr); std::shared_ptr cut_out2 = vision::CutOut(30); - EXPECT_TRUE(cut_out2 != nullptr); + EXPECT_NE(cut_out2, nullptr); // Create a Map operation on ds ds = ds->Map({cut_out1, cut_out2}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -430,40 +480,42 @@ TEST_F(MindDataTestPipeline, TestCutOut) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestNormalize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalize."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr normalize = vision::Normalize({121.0, 115.0, 100.0}, {70.0, 68.0, 71.0}); - EXPECT_TRUE(normalize != nullptr); + EXPECT_NE(normalize, nullptr); // Create a Map operation on ds ds = ds->Map({normalize}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -477,40 +529,42 @@ TEST_F(MindDataTestPipeline, TestNormalize) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestDecode) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDecode."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, false, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr decode = vision::Decode(true); - EXPECT_TRUE(decode != nullptr); + EXPECT_NE(decode, nullptr); // Create a Map operation on ds ds = ds->Map({decode}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -530,30 +584,157 @@ TEST_F(MindDataTestPipeline, TestDecode) { } TEST_F(MindDataTestPipeline, TestShuffleDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestShuffleDataset."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Shuffle operation on ds int32_t shuffle_size = 10; ds = ds->Shuffle(shuffle_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 2; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestSkipDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Skip operation on ds + int32_t count = 3; + ds = ds->Skip(count); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + MS_LOG(INFO) << "Number of rows: " << i; + + // Expect 10-3=7 rows + EXPECT_EQ(i, 7); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestSkipDatasetError1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Skip operation on ds with invalid count input + int32_t count = -1; + ds = ds->Skip(count); + // Expect nullptr for invalid input skip_count + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestTakeDatasetDefault) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeDatasetDefault."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 7)); + EXPECT_NE(ds, nullptr); + + // Create a Take operation on ds, dafault count = -1 + ds = ds->Take(); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + MS_LOG(INFO) << "Number of rows: " << i; + + // Expect 7 rows + EXPECT_EQ(i, 7); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTakeDatasetNormal) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeDatasetNormal."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 8)); + EXPECT_NE(ds, nullptr); + + // Create a Take operation on ds + ds = ds->Take(5); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -566,39 +747,50 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) { MS_LOG(INFO) << "Tensor image shape: " << image->shape(); iter->GetNextRow(&row); } + MS_LOG(INFO) << "Number of rows: " << i; - EXPECT_TRUE(i == 10); + // Expect 5 rows + EXPECT_EQ(i, 5); // Manually terminate the pipeline iter->Stop(); } +TEST_F(MindDataTestPipeline, TestTakeDatasetError1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeDatasetError1."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Take operation on ds with invalid count input + int32_t count = -5; + ds = ds->Take(count); + // Expect nullptr for invalid input take_count + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestCifar10Dataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10Dataset."; // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, 0, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); - - // Create a Repeat operation on ds - int32_t repeat_num = 2; - ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); - - // Create a Batch operation on ds - int32_t batch_size = 2; - ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; iter->GetNextRow(&row); + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + uint64_t i = 0; while (row.size() != 0) { i++; @@ -607,51 +799,104 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 10); + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1."; + + // Create a Cifar10 Dataset + std::shared_ptr ds = Cifar10("", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestCifar100Dataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset."; + + // Create a Cifar100 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; + std::shared_ptr ds = Cifar100(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("coarse_label"), row.end()); + EXPECT_NE(row.find("fine_label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 10); // Manually terminate the pipeline iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1."; + + // Create a Cifar100 Dataset + std::shared_ptr ds = Cifar100("", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomColorAdjust."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5}); - EXPECT_TRUE(random_color_adjust1 != nullptr); + EXPECT_NE(random_color_adjust1, nullptr); std::shared_ptr random_color_adjust2 = vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5}, {0.5, 0.5}); - EXPECT_TRUE(random_color_adjust2 != nullptr); + EXPECT_NE(random_color_adjust2, nullptr); std::shared_ptr random_color_adjust3 = vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5}, {0.25, 0.5}); - EXPECT_TRUE(random_color_adjust3 != nullptr); + EXPECT_NE(random_color_adjust3, nullptr); std::shared_ptr random_color_adjust4 = vision::RandomColorAdjust(); - EXPECT_TRUE(random_color_adjust4 != nullptr); + EXPECT_NE(random_color_adjust4, nullptr); // Create a Map operation on ds ds = ds->Map({random_color_adjust1, random_color_adjust2, random_color_adjust3, random_color_adjust4}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -665,40 +910,42 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestRandomRotation) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomRotation."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr random_rotation_op = vision::RandomRotation({-180, 180}); - EXPECT_TRUE(random_rotation_op != nullptr); + EXPECT_NE(random_rotation_op, nullptr); // Create a Map operation on ds ds = ds->Map({random_rotation_op}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -712,45 +959,47 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); } TEST_F(MindDataTestPipeline, TestProjectMap) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectMap."; + // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds int32_t repeat_num = 2; ds = ds->Repeat(repeat_num); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create objects for the tensor ops std::shared_ptr random_vertical_flip_op = vision::RandomVerticalFlip(0.5); - EXPECT_TRUE(random_vertical_flip_op != nullptr); + EXPECT_NE(random_vertical_flip_op, nullptr); // Create a Map operation on ds ds = ds->Map({random_vertical_flip_op}, {}, {}, {"image", "label"}); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Project operation on ds std::vector column_project = {"image"}; ds = ds->Project(column_project); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create a Batch operation on ds int32_t batch_size = 1; ds = ds->Batch(batch_size); - EXPECT_TRUE(ds != nullptr); + EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. std::shared_ptr iter = ds->CreateIterator(); - EXPECT_TRUE(iter != nullptr); + EXPECT_NE(iter, nullptr); // Iterate the dataset and get each row std::unordered_map> row; @@ -764,8 +1013,925 @@ TEST_F(MindDataTestPipeline, TestProjectMap) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestZipSuccess) { + // Testing the member zip() function + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipSuccess."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Project operation on ds + std::vector column_project = {"image"}; + ds = ds->Project(column_project); + EXPECT_NE(ds, nullptr); + + // Create an ImageFolder Dataset + std::shared_ptr ds1 = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds1, nullptr); + + // Create a Rename operation on ds (so that the 3 datasets we are going to zip have distinct column names) + ds1 = ds1->Rename({"image", "label"}, {"col1", "col2"}); + EXPECT_NE(ds1, nullptr); + + folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds2 = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds2, nullptr); + + // Create a Project operation on ds + column_project = {"label"}; + ds2 = ds2->Project(column_project); + EXPECT_NE(ds2, nullptr); + + // Create a Zip operation on the datasets + ds = ds->Zip({ds1, ds2}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check zipped column names + EXPECT_EQ(row.size(), 4); + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + EXPECT_NE(row.find("col1"), row.end()); + EXPECT_NE(row.find("col2"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestZipSuccess2) { + // Testing the static zip() function + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipSuccess2."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 9)); + EXPECT_NE(ds, nullptr); + std::shared_ptr ds2 = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds2, nullptr); + + // Create a Rename operation on ds (so that the 2 datasets we are going to zip have distinct column names) + ds = ds->Rename({"image", "label"}, {"col1", "col2"}); + EXPECT_NE(ds, nullptr); + + // Create a Zip operation on the datasets + ds = Zip({ds, ds2}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check zipped column names + EXPECT_EQ(row.size(), 4); + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + EXPECT_NE(row.find("col1"), row.end()); + EXPECT_NE(row.find("col2"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 9); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestZipFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail."; + // We expect this test to fail because we are the both datasets we are zipping have "image" and "label" columns + // and zip doesn't accept datasets with same column names + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create an ImageFolder Dataset + std::shared_ptr ds1 = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds1, nullptr); + + // Create a Zip operation on the datasets + ds = Zip({ds, ds1}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} + +TEST_F(MindDataTestPipeline, TestZipFail2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail2."; + // This case is expected to fail because the input dataset is empty. + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Zip operation on the datasets + // Input dataset to zip is empty + ds = Zip({}); + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestRenameSuccess) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameSuccess."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create a Rename operation on ds + ds = ds->Rename({"image", "label"}, {"col1", "col2"}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + EXPECT_NE(row.find("col1"), row.end()); + EXPECT_NE(row.find("col2"), row.end()); + EXPECT_EQ(row.find("image"), row.end()); + EXPECT_EQ(row.find("label"), row.end()); + + while (row.size() != 0) { + i++; + auto image = row["col1"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRenameFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail."; + // We expect this test to fail because input and output in Rename are not the same size + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create a Rename operation on ds + ds = ds->Rename({"image", "label"}, {"col2"}); + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestVOCSegmentation) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCSegmentation."; + + // Create a VOC Dataset + std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; + std::shared_ptr ds = VOC(folder_path, "Segmentation", "train", {}, false, SequentialSampler(0, 3)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if VOCOp read correct images/targets + using Tensor = mindspore::dataset::Tensor; + std::string expect_file[] = {"32", "33", "39", "32", "33", "39"}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto target = row["target"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Tensor target shape: " << target->shape(); + + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + "/JPEGImages/" + expect_file[i] + ".jpg", &expect_image); + EXPECT_EQ(*image, *expect_image); + + std::shared_ptr expect_target; + Tensor::CreateFromFile(folder_path + "/SegmentationClass/" + expect_file[i] + ".png", &expect_target); + EXPECT_EQ(*target, *expect_target); + + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestVOCSegmentationError1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCSegmentationError1."; + + // Create a VOC Dataset + std::map class_index; + class_index["car"] = 0; + std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; + std::shared_ptr ds = VOC(folder_path, "Segmentation", "train", class_index, false, RandomSampler(false, 6)); + + // Expect nullptr for segmentation task with class_index + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestVOCInvalidTaskOrMode) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCInvalidTaskOrMode."; + + // Create a VOC Dataset + std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; + std::shared_ptr ds_1 = VOC(folder_path, "Classification", "train", {}, false, SequentialSampler(0, 3)); + // Expect nullptr for invalid task + EXPECT_EQ(ds_1, nullptr); + + std::shared_ptr ds_2 = VOC(folder_path, "Segmentation", "validation", {}, false, RandomSampler(false, 4)); + // Expect nullptr for invalid mode + EXPECT_EQ(ds_2, nullptr); +} + +TEST_F(MindDataTestPipeline, TestVOCDetection) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCDetection."; + + // Create a VOC Dataset + std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; + std::shared_ptr ds = VOC(folder_path, "Detection", "train", {}, false, SequentialSampler(0, 4)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if VOCOp read correct images/labels + std::string expect_file[] = {"15", "32", "33", "39"}; + uint32_t expect_num[] = {5, 5, 4, 3}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Tensor label shape: " << label->shape(); + + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + "/JPEGImages/" + expect_file[i] + ".jpg", &expect_image); + EXPECT_EQ(*image, *expect_image); + + std::shared_ptr expect_label; + Tensor::CreateFromMemory(TensorShape({1, 1}), DataType(DataType::DE_UINT32), nullptr, &expect_label); + expect_label->SetItemAt({0, 0}, expect_num[i]); + EXPECT_EQ(*label, *expect_label); + + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 4); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestVOCClassIndex) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCClassIndex."; + + // Create a VOC Dataset + std::string folder_path = datasets_root_path_ + "/testVOC2012_2"; + std::map class_index; + class_index["car"] = 0; + class_index["cat"] = 1; + class_index["train"] = 9; + + std::shared_ptr ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if VOCOp read correct labels + // When we provide class_index, label of ["car","cat","train"] become [0,1,9] + std::shared_ptr expect_label; + Tensor::CreateFromMemory(TensorShape({1, 1}), DataType(DataType::DE_UINT32), nullptr, &expect_label); + + uint32_t expect[] = {9, 9, 9, 1, 1, 0}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Tensor label shape: " << label->shape(); + expect_label->SetItemAt({0, 0}, expect[i]); + EXPECT_EQ(*label, *expect_label); + + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCocoDetection) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoDetection."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file, "Detection", false, SequentialSampler(0, 6)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + std::string expect_file[] = {"000000391895", "000000318219", "000000554625", "000000574769", "000000060623", + "000000309022"}; + std::vector> expect_bbox_vector = {{10.0, 10.0, 10.0, 10.0, 70.0, 70.0, 70.0, 70.0}, + {20.0, 20.0, 20.0, 20.0, 80.0, 80.0, 80.0, 80.0}, + {30.0, 30.0, 30.0, 30.0}, {40.0, 40.0, 40.0, 40.0}, + {50.0, 50.0, 50.0, 50.0}, {60.0, 60.0, 60.0, 60.0}}; + std::vector> expect_catagoryid_list = {{1, 7}, {2, 8}, {3}, {4}, {5}, {6}}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto bbox = row["bbox"]; + auto category_id = row["category_id"]; + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + "/" + expect_file[i] + ".jpg", &expect_image); + EXPECT_EQ(*image, *expect_image); + std::shared_ptr expect_bbox; + dsize_t bbox_num = static_cast(expect_bbox_vector[i].size() / 4); + Tensor::CreateFromVector(expect_bbox_vector[i], TensorShape({bbox_num, 4}), &expect_bbox); + EXPECT_EQ(*bbox, *expect_bbox); + std::shared_ptr expect_categoryid; + Tensor::CreateFromVector(expect_catagoryid_list[i], TensorShape({bbox_num, 1}), &expect_categoryid); + EXPECT_EQ(*category_id, *expect_categoryid); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCocoStuff) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoStuff."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file, "Stuff", false, SequentialSampler(0, 6)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + std::string expect_file[] = {"000000391895", "000000318219", "000000554625", "000000574769", "000000060623", + "000000309022"}; + std::vector> expect_segmentation_vector = + {{10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, + 70.0, 72.0, 73.0, 74.0, 75.0, -1.0, -1.0, -1.0, -1.0, -1.0}, + {20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, + 10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0}, + {40.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 40.0, 41.0, 42.0}, + {50.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}, + {60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}, + {60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}}; + std::vector> expect_size = {{2, 10}, {2, 11}, {1, 12}, {1, 13}, {1, 14}, {2, 7}}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto segmentation = row["segmentation"]; + auto iscrowd = row["iscrowd"]; + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + "/" + expect_file[i] + ".jpg", &expect_image); + EXPECT_EQ(*image, *expect_image); + std::shared_ptr expect_segmentation; + Tensor::CreateFromVector(expect_segmentation_vector[i], TensorShape(expect_size[i]), &expect_segmentation); + EXPECT_EQ(*segmentation, *expect_segmentation); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCocoKeypoint) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoKeypoint."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/key_point.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file, "Keypoint", false, SequentialSampler(0, 2)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + std::string expect_file[] = {"000000391895", "000000318219"}; + std::vector> expect_keypoint_vector = + {{368.0, 61.0, 1.0, 369.0, 52.0, 2.0, 0.0, 0.0, 0.0, 382.0, 48.0, 2.0, 0.0, 0.0, 0.0, 368.0, 84.0, 2.0, 435.0, + 81.0, 2.0, 362.0, 125.0, 2.0, 446.0, 125.0, 2.0, 360.0, 153.0, 2.0, 0.0, 0.0, 0.0, 397.0, 167.0, 1.0, 439.0, + 166.0, 1.0, 369.0, 193.0, 2.0, 461.0, 234.0, 2.0, 361.0, 246.0, 2.0, 474.0, 287.0, 2.0}, + {244.0, 139.0, 2.0, 0.0, 0.0, 0.0, 226.0, 118.0, 2.0, 0.0, 0.0, 0.0, 154.0, 159.0, 2.0, 143.0, 261.0, 2.0, 135.0, + 312.0, 2.0, 271.0, 423.0, 2.0, 184.0, 530.0, 2.0, 261.0, 280.0, 2.0, 347.0, 592.0, 2.0, 0.0, 0.0, 0.0, 123.0, + 596.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; + std::vector> expect_size = {{1, 51}, {1, 51}}; + std::vector> expect_num_keypoints_list = {{14}, {10}}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto keypoints = row["keypoints"]; + auto num_keypoints = row["num_keypoints"]; + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + "/" + expect_file[i] + ".jpg", &expect_image); + EXPECT_EQ(*image, *expect_image); + std::shared_ptr expect_keypoints; + dsize_t keypoints_size = expect_size[i][0]; + Tensor::CreateFromVector(expect_keypoint_vector[i], TensorShape(expect_size[i]), &expect_keypoints); + EXPECT_EQ(*keypoints, *expect_keypoints); + std::shared_ptr expect_num_keypoints; + Tensor::CreateFromVector(expect_num_keypoints_list[i], TensorShape({keypoints_size, 1}), &expect_num_keypoints); + EXPECT_EQ(*num_keypoints, *expect_num_keypoints); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCocoPanoptic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoPanoptic."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/panoptic.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file, "Panoptic", false, SequentialSampler(0, 2)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + std::string expect_file[] = {"000000391895", "000000574769"}; + std::vector> expect_bbox_vector = {{472, 173, 36, 48, 340, 22, 154, 301, 486, 183, 30, 35}, + {103, 133, 229, 422, 243, 175, 93, 164}}; + std::vector> expect_categoryid_vector = {{1, 1, 2}, {1, 3}}; + std::vector> expect_iscrowd_vector = {{0, 0, 0}, {0, 0}}; + std::vector> expect_area_vector = {{705, 14062, 626}, {43102, 6079}}; + std::vector> expect_size = {{3, 4}, {2, 4}}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto bbox = row["bbox"]; + auto category_id = row["category_id"]; + auto iscrowd = row["iscrowd"]; + auto area = row["area"]; + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + "/" + expect_file[i] + ".jpg", &expect_image); + EXPECT_EQ(*image, *expect_image); + std::shared_ptr expect_bbox; + dsize_t bbox_size = expect_size[i][0]; + Tensor::CreateFromVector(expect_bbox_vector[i], TensorShape(expect_size[i]), &expect_bbox); + EXPECT_EQ(*bbox, *expect_bbox); + std::shared_ptr expect_categoryid; + Tensor::CreateFromVector(expect_categoryid_vector[i], TensorShape({bbox_size, 1}), &expect_categoryid); + EXPECT_EQ(*category_id, *expect_categoryid); + std::shared_ptr expect_iscrowd; + Tensor::CreateFromVector(expect_iscrowd_vector[i], TensorShape({bbox_size, 1}), &expect_iscrowd); + EXPECT_EQ(*iscrowd, *expect_iscrowd); + std::shared_ptr expect_area; + Tensor::CreateFromVector(expect_area_vector[i], TensorShape({bbox_size, 1}), &expect_area); + EXPECT_EQ(*area, *expect_area); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCocoDefault) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoDefault."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto bbox = row["bbox"]; + auto category_id = row["category_id"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Tensor bbox shape: " << bbox->shape(); + MS_LOG(INFO) << "Tensor category_id shape: " << category_id->shape(); + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCocoException) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoException."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; + std::string invalid_folder_path = "./NotExist"; + std::string invalid_annotation_file = "./NotExistFile"; + + std::shared_ptr ds = Coco(invalid_folder_path, annotation_file); + EXPECT_EQ(ds, nullptr); + + std::shared_ptr ds1 = Coco(folder_path, invalid_annotation_file); + EXPECT_EQ(ds1, nullptr); + + std::shared_ptr ds2 = Coco(folder_path, annotation_file, "valid_mode"); + EXPECT_EQ(ds2, nullptr); +} + +TEST_F(MindDataTestPipeline, TestConcatSuccess) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess."; + + // Create an ImageFolder Dataset + // Column names: {"image", "label"} + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Cifar10 Dataset + // Column names: {"image", "label"} + folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds2 = Cifar10(folder_path, RandomSampler(false, 9)); + EXPECT_NE(ds2, nullptr); + + // Create a Project operation on ds + ds = ds->Project({"image"}); + EXPECT_NE(ds, nullptr); + ds2 = ds2->Project({"image"}); + EXPECT_NE(ds, nullptr); + + // Create a Concat operation on the ds + ds = ds->Concat({ds2}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 19); + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestConcatSuccess2) { + // Test "+" operator to concat two datasets + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess2."; + + // Create an ImageFolder Dataset + // Column names: {"image", "label"} + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Cifar10 Dataset + // Column names: {"image", "label"} + folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds2 = Cifar10(folder_path, RandomSampler(false, 9)); + EXPECT_NE(ds2, nullptr); + + // Create a Project operation on ds + ds = ds->Project({"image"}); + EXPECT_NE(ds, nullptr); + ds2 = ds2->Project({"image"}); + EXPECT_NE(ds, nullptr); + + // Create a Concat operation on the ds + ds = ds + ds2; + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 19); + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestConcatFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail1."; + // This case is expected to fail because the input column names of concatenated datasets are not the same + + // Create an ImageFolder Dataset + // Column names: {"image", "label"} + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + std::shared_ptr ds2 = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Rename operation on ds + ds2 = ds2->Rename({"image", "label"}, {"col1", "col2"}); + EXPECT_NE(ds, nullptr); + + // Create a Project operation on the ds + // Name of datasets to concat doesn't not match + ds = ds->Concat({ds2}); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} + +TEST_F(MindDataTestPipeline, TestConcatFail2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail2."; + // This case is expected to fail because the input dataset is empty. + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Project operation on the ds + // Input dataset to concat is empty + ds = ds->Concat({}); + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestCelebADataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebADataset."; + + // Create a CelebA Dataset + std::string folder_path = datasets_root_path_ + "/testCelebAData/"; + std::shared_ptr ds = CelebA(folder_path, "all", SequentialSampler(0, 2), false, {}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if CelebAOp read correct images/attr + std::string expect_file[] = {"1.JPEG", "2.jpg"}; + std::vector> expect_attr_vector = + {{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, + 1, 0, 0, 1}, {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 1}}; + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto attr = row["attr"]; + + std::shared_ptr expect_image; + Tensor::CreateFromFile(folder_path + expect_file[i], &expect_image); + EXPECT_EQ(*image, *expect_image); + + std::shared_ptr expect_attr; + Tensor::CreateFromVector(expect_attr_vector[i], TensorShape({40}), &expect_attr); + EXPECT_EQ(*attr, *expect_attr); + + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCelebADefault) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebADefault."; + + // Create a CelebA Dataset + std::string folder_path = datasets_root_path_ + "/testCelebAData/"; + std::shared_ptr ds = CelebA(folder_path); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if CelebAOp read correct images/attr + uint64_t i = 0; + while (row.size() != 0) { + auto image = row["image"]; + auto attr = row["attr"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + MS_LOG(INFO) << "Tensor attr shape: " << attr->shape(); + + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCelebAException) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAException."; + + // Create a CelebA Dataset + std::string folder_path = datasets_root_path_ + "/testCelebAData/"; + std::string invalid_folder_path = "./testNotExist"; + std::string invalid_dataset_type = "invalid_type"; + std::shared_ptr ds = CelebA(invalid_folder_path); + EXPECT_EQ(ds, nullptr); + std::shared_ptr ds1 = CelebA(folder_path, invalid_dataset_type); + EXPECT_EQ(ds1, nullptr); } \ No newline at end of file diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index bdb7c861b2..26db41ef66 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -75,7 +75,8 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) { EXPECT_TRUE(rc.IsOk()); // Create a tensor, take a snapshot and restore it back, and compare. - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); t->SetItemAt({0, 0}, 1); t->SetItemAt({0, 1}, 2); t->SetItemAt({0, 2}, 3); @@ -129,7 +130,8 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { rc = myClient.CreateCache(1, true); EXPECT_TRUE(rc.IsOk()); std::cout << myClient << std::endl; - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); t->SetItemAt({0, 0}, 1); t->SetItemAt({0, 1}, 2); t->SetItemAt({0, 2}, 3); @@ -397,23 +399,17 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { std::shared_ptr myClient = std::make_shared(1, 0, true); - std::shared_ptr myMergeOp; - rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build( - &myMergeOp); - EXPECT_TRUE(rc.IsOk()); + // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op. + // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by + // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and + // replace it with the required tree structures for cache lookup op and cache merge op. - std::shared_ptr myLookupOp; - rc = CacheLookupOp::Builder() - .SetNumWorkers(3) - .SetOpConnectorSize(3) - .SetClient(myClient) - .SetSampler(seq_sampler) - .Build(&myLookupOp); - EXPECT_TRUE(rc.IsOk()); + std::shared_ptr myCacheOp; + rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); std::shared_ptr so; ImageFolderOp::Builder builder; - builder.SetSampler(myLookupOp) + builder.SetSampler(std::move(seq_sampler)) .SetOpConnectorSize(3) .SetNumWorkers(3) .SetRowsPerBuffer(2) @@ -432,20 +428,18 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { auto myTree = std::make_shared(); rc = myTree->AssociateNode(so); EXPECT_TRUE(rc.IsOk()); - rc = myTree->AssociateNode(myLookupOp); - EXPECT_TRUE(rc.IsOk()); - rc = myTree->AssociateNode(myMergeOp); + + rc = myTree->AssociateNode(myCacheOp); EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); EXPECT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); EXPECT_TRUE(rc.IsOk()); - rc = myRepeatOp->AddChild(myMergeOp); - EXPECT_TRUE(rc.IsOk()); - rc = myMergeOp->AddChild(myLookupOp); + rc = myRepeatOp->AddChild(myCacheOp); EXPECT_TRUE(rc.IsOk()); - rc = myMergeOp->AddChild(so); + rc = myCacheOp->AddChild(so); EXPECT_TRUE(rc.IsOk()); rc = myTree->Prepare(); diff --git a/tests/ut/cpp/dataset/center_crop_op_test.cc b/tests/ut/cpp/dataset/center_crop_op_test.cc index cd0f362f64..20e2be3991 100644 --- a/tests/ut/cpp/dataset/center_crop_op_test.cc +++ b/tests/ut/cpp/dataset/center_crop_op_test.cc @@ -20,17 +20,17 @@ #include "utils/log_adapter.h" using namespace mindspore::dataset; -using mindspore::MsLogLevel::INFO; -using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; class MindDataTestCenterCropOp : public UT::CVOP::CVOpCommon { public: MindDataTestCenterCropOp() : CVOpCommon() {} }; -TEST_F(MindDataTestCenterCropOp, TestOp) { - MS_LOG(INFO) << "Doing MindDataTestCenterCropOp::TestOp."; +TEST_F(MindDataTestCenterCropOp, TestOp1) { + MS_LOG(INFO) << "Doing MindDataTestCenterCropOp::TestOp1."; std::shared_ptr output_tensor; int het = 256; int wid = 128; @@ -42,3 +42,16 @@ TEST_F(MindDataTestCenterCropOp, TestOp) { EXPECT_EQ(wid, output_tensor->shape()[1]); std::shared_ptr p = CVTensor::AsCVTensor(output_tensor); } + +TEST_F(MindDataTestCenterCropOp, TestOp2) { + MS_LOG(INFO) << "MindDataTestCenterCropOp::TestOp2. Cap valid crop size at 10 times the input size"; + std::shared_ptr output_tensor; + + int64_t wid = input_tensor_->shape()[0] * 10 + 1; + int64_t het = input_tensor_->shape()[1] * 10 + 1; + + std::unique_ptr op(new CenterCropOp(het, wid)); + Status s = op->Compute(input_tensor_, &output_tensor); + EXPECT_TRUE(s.IsError()); + ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError); +} diff --git a/tests/ut/cpp/dataset/channel_swap_test.cc b/tests/ut/cpp/dataset/channel_swap_test.cc index 2000de15b2..0fae8417e1 100644 --- a/tests/ut/cpp/dataset/channel_swap_test.cc +++ b/tests/ut/cpp/dataset/channel_swap_test.cc @@ -36,7 +36,7 @@ TEST_F(MindDataTestChannelSwap, TestOp) { int size_buffer = s[0] * s[1] * s[2]; std::unique_ptr output_buffer(new uchar[size_buffer]); - std::shared_ptr output_tensor(new Tensor(s, DataType(DataType::DE_UINT8))); + std::shared_ptr output_tensor; // Decoding std::unique_ptr op(new HwcToChwOp()); diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc index ed22f4f347..0b6b4099a4 100644 --- a/tests/ut/cpp/dataset/cifar_op_test.cc +++ b/tests/ut/cpp/dataset/cifar_op_test.cc @@ -19,7 +19,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/cifar_op.h" diff --git a/tests/ut/cpp/dataset/circular_pool_test.cc b/tests/ut/cpp/dataset/circular_pool_test.cc index d06f846684..ecad6e4d7b 100644 --- a/tests/ut/cpp/dataset/circular_pool_test.cc +++ b/tests/ut/cpp/dataset/circular_pool_test.cc @@ -19,7 +19,7 @@ #include "minddata/dataset/util/circular_pool.h" #include "minddata/dataset/util/services.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "utils/log_adapter.h" #include "./securec.h" diff --git a/tests/ut/cpp/dataset/clue_op_test.cc b/tests/ut/cpp/dataset/clue_op_test.cc index 0935434a06..170e54b7f2 100644 --- a/tests/ut/cpp/dataset/clue_op_test.cc +++ b/tests/ut/cpp/dataset/clue_op_test.cc @@ -19,7 +19,7 @@ #include "minddata/dataset/core/client.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/dataset/engine/datasetops/source/clue_op.h" diff --git a/tests/ut/cpp/dataset/coco_op_test.cc b/tests/ut/cpp/dataset/coco_op_test.cc index 6e6d3c26e5..3c786e6f81 100644 --- a/tests/ut/cpp/dataset/coco_op_test.cc +++ b/tests/ut/cpp/dataset/coco_op_test.cc @@ -19,7 +19,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h" diff --git a/tests/ut/cpp/dataset/common/bboxop_common.cc b/tests/ut/cpp/dataset/common/bboxop_common.cc index 62c9f85348..b16c4c0615 100644 --- a/tests/ut/cpp/dataset/common/bboxop_common.cc +++ b/tests/ut/cpp/dataset/common/bboxop_common.cc @@ -25,7 +25,7 @@ #include "./tinyxml2.h" #include "opencv2/opencv.hpp" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/cv_tensor.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/core/constants.h" @@ -163,8 +163,11 @@ void BBoxOpCommon::CompareActualAndExpected(const std::string &op_name) { // after comparison is done remove temporary file EXPECT_TRUE(remove(actual_path.c_str()) == 0); // compare using ==operator by Tensor + std::shared_ptr expect_img_t, actual_img_t; + CVTensor::CreateFromMat(expect_img, &expect_img_t); + CVTensor::CreateFromMat(actual_img, &actual_img_t); if (actual_img.data) { - EXPECT_EQ(CVTensor(expect_img) == CVTensor(actual_img), true); + EXPECT_EQ(*expect_img_t == *actual_img_t, true); } else { MS_LOG(ERROR) << "Not pass verification! Image data is null."; EXPECT_EQ(0, 1); @@ -223,7 +226,7 @@ bool BBoxOpCommon::LoadAnnotationFile(const std::string &path, std::shared_ptrNextSiblingElement("object"); // Read next BBox if exists } std::shared_ptr ret_value; - Status s = Tensor::CreateTensor(&ret_value, return_value_list, TensorShape({bbox_count, bbox_val_count})); + Status s = Tensor::CreateFromVector(return_value_list, TensorShape({bbox_count, bbox_val_count}), &ret_value); EXPECT_TRUE(s.IsOk()); (*target_BBox) = ret_value; // load bbox from file into return return true; diff --git a/tests/ut/cpp/dataset/common/cvop_common.cc b/tests/ut/cpp/dataset/common/cvop_common.cc index 48d69564fd..28d0c07764 100644 --- a/tests/ut/cpp/dataset/common/cvop_common.cc +++ b/tests/ut/cpp/dataset/common/cvop_common.cc @@ -19,7 +19,7 @@ #include #include "cvop_common.h" #include "minddata/dataset/core/constants.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" #include @@ -52,9 +52,11 @@ std::string CVOpCommon::GetFilename() { void CVOpCommon::GetInputImage(std::string filename) { try { - Tensor::CreateTensor(&raw_input_tensor_, filename); + Tensor::CreateFromFile(filename, &raw_input_tensor_); raw_cv_image_ = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); - input_tensor_ = std::dynamic_pointer_cast(std::make_shared(raw_cv_image_)); + std::shared_ptr input_cv_tensor; + CVTensor::CreateFromMat(raw_cv_image_, &input_cv_tensor); + input_tensor_ = std::dynamic_pointer_cast(input_cv_tensor); SwapRedAndBlue(input_tensor_, &input_tensor_); if (raw_cv_image_.data) { MS_LOG(INFO) << "Reading was successful. Height:" << raw_cv_image_.rows << " Width: " << raw_cv_image_.cols diff --git a/tests/ut/cpp/dataset/concat_op_test.cc b/tests/ut/cpp/dataset/concat_op_test.cc index 9e991ce0d3..7c6d847540 100644 --- a/tests/ut/cpp/dataset/concat_op_test.cc +++ b/tests/ut/cpp/dataset/concat_op_test.cc @@ -18,7 +18,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/concatenate_op_test.cc b/tests/ut/cpp/dataset/concatenate_op_test.cc index dc2fc69266..4e1e29b2be 100644 --- a/tests/ut/cpp/dataset/concatenate_op_test.cc +++ b/tests/ut/cpp/dataset/concatenate_op_test.cc @@ -28,15 +28,14 @@ class MindDataTestConcatenateOp : public UT::Common { }; TEST_F(MindDataTestConcatenateOp, TestOp) { - MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp."; - uint64_t labels[3] = {1, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp-SingleRowinput."; + std::vector labels = {1, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - uint64_t append_labels[3] = {4, 4, 4}; - std::shared_ptr append = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(append_labels)); + std::vector append_labels = {4, 4, 4}; + std::shared_ptr append; + Tensor::CreateFromVector(append_labels, &append); std::shared_ptr output; std::unique_ptr op(new ConcatenateOp(0, nullptr, append)); @@ -44,23 +43,84 @@ TEST_F(MindDataTestConcatenateOp, TestOp) { in.push_back(input); TensorRow out_row; Status s = op->Compute(in, &out_row); - uint64_t out[6] = {1, 1, 2, 4, 4, 4}; + std::vector out = {1, 1, 2, 4, 4, 4}; + + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); - std::shared_ptr expected = - std::make_shared(TensorShape{6}, DataType(DataType::DE_UINT64), reinterpret_cast(out)); output = out_row[0]; EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); ASSERT_TRUE(output->type() == expected->type()); MS_LOG(DEBUG) << *output << std::endl; MS_LOG(DEBUG) << *expected << std::endl; + ASSERT_TRUE(*output == *expected); +} + +TEST_F(MindDataTestConcatenateOp, TestOp2) { + MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp2-MultiInput."; + std::vector labels = {1, 12, 2}; + std::shared_ptr row_1; + Tensor::CreateFromVector(labels, &row_1); + + std::shared_ptr row_2; + Tensor::CreateFromVector(labels, &row_2); + + std::vector append_labels = {4, 4, 4}; + std::shared_ptr append; + Tensor::CreateFromVector(append_labels, &append); + + TensorRow tensor_list; + tensor_list.push_back(row_1); + tensor_list.push_back(row_2); + + std::shared_ptr output; + std::unique_ptr op(new ConcatenateOp(0, nullptr, append)); + + TensorRow out_row; + Status s = op->Compute(tensor_list, &out_row); + std::vector out = {1, 12, 2, 1, 12, 2, 4, 4, 4}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); + + output = out_row[0]; + EXPECT_TRUE(s.IsOk()); + ASSERT_TRUE(output->shape() == expected->shape()); + ASSERT_TRUE(output->type() == expected->type()); + MS_LOG(DEBUG) << *output << std::endl; + MS_LOG(DEBUG) << *expected << std::endl; ASSERT_TRUE(*output == *expected); +} + +TEST_F(MindDataTestConcatenateOp, TestOp3) { + MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp3-Strings."; + std::vector labels = {"hello", "bye"}; + std::shared_ptr row_1; + Tensor::CreateFromVector(labels, &row_1); + + std::vector append_labels = {"1", "2", "3"}; + std::shared_ptr append; + Tensor::CreateFromVector(append_labels, &append); + + TensorRow tensor_list; + tensor_list.push_back(row_1); + + std::shared_ptr output; + std::unique_ptr op(new ConcatenateOp(0, nullptr, append)); + + TensorRow out_row; + Status s = op->Compute(tensor_list, &out_row); + std::vector out = {"hello", "bye", "1", "2", "3"}; + + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); - // std::vector inputs = {TensorShape({3})}; - // std::vector outputs = {}; - // s = op->OutputShape(inputs, outputs); - // EXPECT_TRUE(s.IsOk()); - // ASSERT_TRUE(outputs[0] == TensorShape{6}); - // MS_LOG(INFO) << "MindDataTestConcatenateOp-TestOp end."; + output = out_row[0]; + EXPECT_TRUE(s.IsOk()); + ASSERT_TRUE(output->shape() == expected->shape()); + ASSERT_TRUE(output->type() == expected->type()); + MS_LOG(DEBUG) << *output << std::endl; + MS_LOG(DEBUG) << *expected << std::endl; + ASSERT_TRUE(*output == *expected); } diff --git a/tests/ut/cpp/dataset/crop_op_test.cc b/tests/ut/cpp/dataset/crop_op_test.cc new file mode 100644 index 0000000000..0f365558a6 --- /dev/null +++ b/tests/ut/cpp/dataset/crop_op_test.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "common/cvop_common.h" +#include "minddata/dataset/kernels/image/crop_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestCropOp : public UT::CVOP::CVOpCommon { + protected: + MindDataTestCropOp() : CVOpCommon() {} + + std::shared_ptr output_tensor_; +}; + +TEST_F(MindDataTestCropOp, TestOp1) { + MS_LOG(INFO) << "Doing testCrop."; + // Crop params + int crop_height = 18; + int crop_width = 12; + std::unique_ptr op(new CropOp(0, 0, crop_height, crop_width)); + EXPECT_TRUE(op->OneToOne()); + Status s = op->Compute(input_tensor_, &output_tensor_); + size_t actual = 0; + if (s == Status::OK()) { + actual = output_tensor_->shape()[0] * output_tensor_->shape()[1] * output_tensor_->shape()[2]; + } + EXPECT_EQ(crop_height, output_tensor_->shape()[1]); + EXPECT_EQ(actual, crop_height * crop_width * 3); + EXPECT_EQ(s, Status::OK()); +} + +TEST_F(MindDataTestCropOp, TestOp2) { + MS_LOG(INFO) << "Doing testCrop negative coordinates."; + // Crop params + unsigned int crop_height = 10; + unsigned int crop_width = 10; + + std::unique_ptr op( + new CropOp(-10, -10, crop_height, crop_width)); + EXPECT_TRUE(op->OneToOne()); + Status s = op->Compute(input_tensor_, &output_tensor_); + EXPECT_EQ(false, s.IsOk()); + MS_LOG(INFO) << "testCrop coordinate exception end."; +} + +TEST_F(MindDataTestCropOp, TestOp3) { + MS_LOG(INFO) << "Doing testCrop size too large."; + // Crop params + unsigned int crop_height = 1200000; + unsigned int crop_width = 1200000; + + std::unique_ptr op( + new CropOp(0, 0, crop_height, crop_width)); + EXPECT_TRUE(op->OneToOne()); + Status s = op->Compute(input_tensor_, &output_tensor_); + EXPECT_EQ(false, s.IsOk()); + MS_LOG(INFO) << "testCrop size exception end."; +} + diff --git a/tests/ut/cpp/dataset/csv_op_test.cc b/tests/ut/cpp/dataset/csv_op_test.cc new file mode 100644 index 0000000000..4c01af9654 --- /dev/null +++ b/tests/ut/cpp/dataset/csv_op_test.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "minddata/dataset/core/client.h" +#include "common/common.h" +#include "utils/ms_utils.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/util/status.h" + + +namespace common = mindspore::common; + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestCSVOp : public UT::DatasetOpTesting { + +}; + +TEST_F(MindDataTestCSVOp, TestCSVBasic) { + // Start with an empty execution tree + auto tree = std::make_shared(); + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/testCSV/1.csv"; + + std::vector> column_default_list; + column_default_list.push_back(std::make_shared>(CsvOp::INT, 0)); + column_default_list.push_back(std::make_shared>(CsvOp::INT, 0)); + column_default_list.push_back(std::make_shared>(CsvOp::INT, 0)); + column_default_list.push_back(std::make_shared>(CsvOp::INT, 0)); + std::shared_ptr op; + CsvOp::Builder builder; + builder.SetCsvFilesList({dataset_path}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16) + .SetShuffleFiles(false) + .SetOpConnectorSize(2) + .SetFieldDelim(',') + .SetColumDefault(column_default_list) + .SetColumName({"col1", "col2", "col3", "col4"}); + + Status rc = builder.Build(&op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssociateNode(op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssignRoot(op); + ASSERT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration."; + rc = tree->Prepare(); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->Launch(); + ASSERT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + + int row_count = 0; + while (!tensor_list.empty()) { + // Display the tensor by calling the printer on it + for (int i = 0; i < tensor_list.size(); i++) { + std::ostringstream ss; + ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; + MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; + } + + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + row_count++; + } + + ASSERT_EQ(row_count, 3); +} + +TEST_F(MindDataTestCSVOp, TestTotalRows) { + std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv"; + std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv"; + std::vector files; + files.push_back(csv_file1); + int64_t total_rows = 0; + CsvOp::CountAllFileRows(files, false, &total_rows); + ASSERT_EQ(total_rows, 3); + files.clear(); + + files.push_back(csv_file2); + CsvOp::CountAllFileRows(files, false, &total_rows); + ASSERT_EQ(total_rows, 5); + files.clear(); + + files.push_back(csv_file1); + files.push_back(csv_file2); + CsvOp::CountAllFileRows(files, false, &total_rows); + ASSERT_EQ(total_rows, 8); + files.clear(); +} diff --git a/tests/ut/cpp/dataset/distributed_sampler_test.cc b/tests/ut/cpp/dataset/distributed_sampler_test.cc new file mode 100644 index 0000000000..9f68e67079 --- /dev/null +++ b/tests/ut/cpp/dataset/distributed_sampler_test.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "gtest/gtest.h" + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "utils/log_adapter.h" + +#include +#include + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestDistributedSampler : public UT::Common { + public: + class DummyRandomAccessOp : public RandomAccessOp { + public: + DummyRandomAccessOp(uint64_t num_rows) { + // row count is in base class as protected member + // GetNumRowsInDataset does not need an override, the default from base class is fine. + num_rows_ = num_rows; + } + }; +}; + +TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) { + // num samples to draw. + uint64_t num_samples = 7; + + // create sampler with replacement = true + DistributedSampler m_sampler(num_samples, 2, 0, false, 0, false); + DummyRandomAccessOp dummyRandomAccessOp(num_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); + + std::unique_ptr db; + TensorRow row; + std::vector out; + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); + db->PopRow(&row); + for (const auto &t : row) { + for (auto it = t->begin(); it != t->end(); it++) { + out.push_back(*it); + } + } + + ASSERT_EQ(4, out.size()); + + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); + ASSERT_EQ(db->eoe(), true); +} + +TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { + // num samples to draw. + uint64_t num_samples = 7; + + // create sampler with replacement = true + DistributedSampler m_sampler(num_samples, 2, 1, false, 0, false); + DummyRandomAccessOp dummyRandomAccessOp(num_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); + + std::unique_ptr db; + TensorRow row; + std::vector out; + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); + db->PopRow(&row); + for (const auto &t : row) { + for (auto it = t->begin(); it != t->end(); it++) { + out.push_back(*it); + } + } + + ASSERT_EQ(3, out.size()); + + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); + ASSERT_EQ(db->eoe(), true); +} + +TEST_F(MindDataTestDistributedSampler, TestThreeShards) { + // num samples to draw. + uint64_t num_samples = 2; + + // create sampler with replacement = true + DistributedSampler m_sampler(num_samples, 3, 2, false, 0, false); + DummyRandomAccessOp dummyRandomAccessOp(num_samples); + m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); + + std::unique_ptr db; + TensorRow row; + std::vector out; + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); + db->PopRow(&row); + for (const auto &t : row) { + for (auto it = t->begin(); it != t->end(); it++) { + out.push_back(*it); + } + } + + ASSERT_EQ(0, out.size()); + + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); + ASSERT_EQ(db->eoe(), true); +} + diff --git a/tests/ut/cpp/dataset/duplicate_op_test.cc b/tests/ut/cpp/dataset/duplicate_op_test.cc index 93779b084d..afad66f620 100644 --- a/tests/ut/cpp/dataset/duplicate_op_test.cc +++ b/tests/ut/cpp/dataset/duplicate_op_test.cc @@ -32,9 +32,9 @@ class MindDataTestDuplicateOp : public UT::Common { TEST_F(MindDataTestDuplicateOp, Basics) { std::shared_ptr t; - Tensor::CreateTensor(&t, std::vector({1, 2, 3, 4, 5, 6})); + Tensor::CreateFromVector(std::vector({1, 2, 3, 4, 5, 6}), &t); std::shared_ptr v; - Tensor::CreateTensor(&v, std::vector({3}), TensorShape::CreateScalar()); + Tensor::CreateFromVector(std::vector({3}), TensorShape::CreateScalar(), &v); std::shared_ptr op = std::make_shared(); TensorRow in; in.push_back(t); diff --git a/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc b/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc new file mode 100644 index 0000000000..2fc5f3c047 --- /dev/null +++ b/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc @@ -0,0 +1,639 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, + bool shuf = false, std::shared_ptr sampler = nullptr, + std::map map = {}, bool decode = false); + +std::shared_ptr Build(std::vector> ops); + +class MindDataTestEpochCtrlOp : public UT::DatasetOpTesting { +public: + void SetUp() override { + DatasetOpTesting::SetUp(); + folder_path = datasets_root_path_ + "/testPK/data"; + + GlobalInit(); + + // Start with an empty execution tree + my_tree_ = std::make_shared(); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); + rc = my_tree_->Prepare(); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + int32_t i = 0; + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + golden_imgs.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + } + + std::shared_ptr my_tree_; + Status rc; + std::string golden_imgs; + std::string folder_path; + int32_t label = 0; + std::string result; + int32_t img_class[4] = {0, 1, 2, 3}; + +}; + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_AutoInjectEpoch) { + MS_LOG(WARNING) << "Doing ImageFolder_AutoInjectEpoch."; + + int32_t num_epoch = 2 + std::rand() % 5; + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); + rc = my_tree_->Prepare(); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch; + std::string golden = golden_imgs; + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + uint64_t i = 0; + for (int epoch = 0; epoch < num_epoch; epoch++) { + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_TRUE(result == golden); + result.clear(); + + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + + EXPECT_TRUE(i == 44 * num_epoch); + + // Try to fetch data beyond the specified number of epochs. + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch) { + MS_LOG(WARNING) << "Doing ImageFolder_Epoch."; + + int32_t num_epoch = 2 + std::rand() % 5; + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch; + std::string golden = golden_imgs; + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + uint64_t i = 0; + for (int epoch = 0; epoch < num_epoch; epoch++) { + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_TRUE(result == golden); + result.clear(); + + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + + EXPECT_TRUE(i == 44 * num_epoch); + + // Try to fetch data beyond the specified number of epochs. + rc = di.GetNextAsMap(&tensor_map); + EXPECT_FALSE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch) { + MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch."; + + int32_t num_epoch = 2 + std::rand() % 5; + + int32_t num_repeats = 2; + std::shared_ptr repeat_op; + rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); + EXPECT_TRUE(rc.IsOk()); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats; + std::string golden = golden_imgs; + for (int i = 1; i < num_repeats; i++) { + golden += golden_imgs; + } + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + uint64_t i = 0; + for (int epoch = 0; epoch < num_epoch; epoch++) { + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_TRUE(result == golden); + result.clear(); + + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + + EXPECT_TRUE(i == 44 * num_repeats * num_epoch); + + // Try to fetch data beyond the specified number of epochs. + rc = di.GetNextAsMap(&tensor_map); + EXPECT_FALSE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch) { + MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch."; + + int32_t num_epoch = 2 + std::rand() % 5; + + int32_t num_repeats = 2; + std::shared_ptr repeat_op; + rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); + EXPECT_TRUE(rc.IsOk()); + + int32_t num_repeats_2 = 3; + std::shared_ptr repeat_op_2; + rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2); + EXPECT_TRUE(rc.IsOk()); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2; + std::string golden; + for (int j = 0; j < num_repeats_2; j++) { + for (int i = 0; i < num_repeats; i++) { + golden += golden_imgs; + } + } + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + uint64_t i = 0; + for (int epoch = 0; epoch < num_epoch; epoch++) { + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_EQ(result.size(), golden.size()); + EXPECT_TRUE(result == golden); + result.clear(); + + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + + EXPECT_EQ(i, 44 * num_epoch * num_repeats * num_repeats_2); + + // Try to fetch data beyond the specified number of epochs. + rc = di.GetNextAsMap(&tensor_map); + EXPECT_FALSE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf) { + MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf."; + + // if num_epoch == -1, it means infinity. + int32_t num_epoch = -1; + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + uint64_t i = 0; + + // For this test, we stop at stop_at_epoch number. + int32_t stop_at_epoch = 2 + std::rand() % 6; + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; + for (int epoch = 0; epoch < stop_at_epoch; epoch++) { + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_EQ(result, golden_imgs); + result.clear(); + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + EXPECT_TRUE(i == 44 * stop_at_epoch); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_Inf) { + MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf."; + + // if num_epoch == -1, it means infinity. + int32_t num_epoch = -1; + + int32_t num_repeats = 2; + std::shared_ptr repeat_op; + rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); + EXPECT_TRUE(rc.IsOk()); + + int32_t num_repeats_2 = 3; + std::shared_ptr repeat_op_2; + rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2); + EXPECT_TRUE(rc.IsOk()); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2; + std::string golden; + for (int j = 0; j < num_repeats_2; j++) { + for (int i = 0; i < num_repeats; i++) { + golden += golden_imgs; + } + } + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(my_tree_); + TensorMap tensor_map; + uint64_t i = 0; + + // For this test, we stop at stop_at_epoch number. + int32_t stop_at_epoch = 2 + std::rand() % 6; + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; + for (int epoch = 0; epoch < stop_at_epoch; epoch++) { + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + while (tensor_map.size() != 0) { + tensor_map["label"]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size()); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_EQ(result, golden); + result.clear(); + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats * num_repeats_2); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_ChildItr) { + MS_LOG(WARNING) << "Doing ImageFolder_Epoch_ChildItr."; + + int32_t num_epoch = 2 + std::rand() % 5; + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "num_epoch: " << num_epoch; + + // Start the loop of reading tensors from our pipeline + ChildIterator ci(my_tree_->root().get(), 0, 0); + TensorRow tensor_row; + uint64_t total_sample = 0; + uint64_t i = 0; + uint32_t epoch = 0; + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + while(!ci.eof_handled()) { + i = 0; + while (tensor_row.size() != 0) { + tensor_row[1]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + i++; + } + + epoch++; + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + EXPECT_TRUE(result == golden_imgs); + result.clear(); + EXPECT_TRUE(i == 44); + total_sample += i; + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + } + EXPECT_TRUE(total_sample == 44 * num_epoch); + + // Try to fetch data after last epoch ends. + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(tensor_row.empty()); + EXPECT_FALSE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_ChildItr) { + MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_ChildItr."; + + int32_t num_epoch = 2 + std::rand() % 5; + + int32_t num_repeats = 2; + std::shared_ptr repeat_op; + rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); + EXPECT_TRUE(rc.IsOk()); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats; + std::string golden; + for (int i = 0; i < num_repeats; i++) { + golden += golden_imgs; + } + + // Start the loop of reading tensors from our pipeline + ChildIterator ci(my_tree_->root().get(), 0, 0); + TensorRow tensor_row; + uint64_t total_sample = 0; + uint64_t i = 0; + uint32_t epoch = 0; + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + while(!ci.eof_handled()) { + i = 0; + while (tensor_row.size() != 0) { + tensor_row[1]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + i++; + } + + epoch++; + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + EXPECT_TRUE(result == golden); + result.clear(); + EXPECT_TRUE(i == 44 * num_repeats); + total_sample += i; + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + } + EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats); + + // Try to fetch data after last epoch ends. + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(tensor_row.empty()); + EXPECT_FALSE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_ChildItr) { + MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch_ChildItr."; + + int32_t num_epoch = 2 + std::rand() % 5; + + int32_t num_repeats = 2; + std::shared_ptr repeat_op; + rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); + EXPECT_TRUE(rc.IsOk()); + + int32_t num_repeats_2 = 3; + std::shared_ptr repeat_op_2; + rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2); + EXPECT_TRUE(rc.IsOk()); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2; + std::string golden; + for (int j = 0; j < num_repeats_2; j++) { + for (int i = 0; i < num_repeats; i++) { + golden += golden_imgs; + } + } + + // Start the loop of reading tensors from our pipeline + ChildIterator ci(my_tree_->root().get(), 0, 0); + TensorRow tensor_row; + uint64_t total_sample = 0; + uint64_t i = 0; + uint32_t epoch = 0; + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + while(!ci.eof_handled()) { + i = 0; + while (tensor_row.size() != 0) { + tensor_row[1]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + i++; + } + + epoch++; + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + EXPECT_TRUE(result == golden); + result.clear(); + EXPECT_TRUE(i == 44 * num_repeats * num_repeats_2); + total_sample += i; + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + } + EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats * num_repeats_2); + + // Try to fetch data after last epoch ends. + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(tensor_row.empty()); + EXPECT_FALSE(rc.IsOk()); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf_ChildItr) { + MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf_ChildItr."; + + // if num_epoch == -1, it means infinity. + int32_t num_epoch = -1; + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + ChildIterator ci(my_tree_->root().get(), 0, 0); + TensorRow tensor_row; + uint64_t i = 0; + + // For this test, we stop at a random number between 0 - 100 epochs. + int32_t stop_at_epoch = 2 + std::rand() % 5; + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; + for (int epoch = 0; epoch < stop_at_epoch; epoch++) { + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + while (tensor_row.size() != 0) { + tensor_row[1]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_TRUE(result == golden_imgs); + result.clear(); + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + EXPECT_TRUE(i == 44 * stop_at_epoch); +} + +TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_Inf_ChildItr) { + MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf_ChildItr."; + + // if num_epoch == -1, it means infinity. + int32_t num_epoch = -1; + int32_t num_repeats = 2; + std::shared_ptr repeat_op; + rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); + EXPECT_TRUE(rc.IsOk()); + + my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op}); + rc = my_tree_->Prepare(num_epoch); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree_->Launch(); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats; + std::string golden; + for (int i = 0; i < num_repeats; i++) { + golden += golden_imgs; + } + + // Start the loop of reading tensors from our pipeline + ChildIterator ci(my_tree_->root().get(), 0, 0); + TensorRow tensor_row; + uint64_t i = 0; + + // For this test, we stop at a random number between 0 - 100 epochs. + int32_t stop_at_epoch = 2 + std::rand() % 5; + MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; + for (int epoch = 0; epoch < stop_at_epoch; epoch++) { + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + while (tensor_row.size() != 0) { + tensor_row[1]->GetItemAt(&label, {}); + MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; + EXPECT_TRUE(img_class[(i % 44) / 11] == label); + // Dump all the image into string, to be used as a comparison later. + result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); + rc = ci.FetchNextTensorRow(&tensor_row); + EXPECT_TRUE(rc.IsOk()); + i++; + } + EXPECT_TRUE(result == golden); + result.clear(); + MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; + } + EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats); +} diff --git a/tests/ut/cpp/dataset/fill_op_test.cc b/tests/ut/cpp/dataset/fill_op_test.cc index 20e323cc8d..795db705af 100644 --- a/tests/ut/cpp/dataset/fill_op_test.cc +++ b/tests/ut/cpp/dataset/fill_op_test.cc @@ -29,23 +29,20 @@ class MindDataTestFillOp : public UT::Common { TEST_F(MindDataTestFillOp, TestOp) { MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp."; - uint64_t labels[3] = {1, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {1, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - TensorShape fill_shape({}); - std::shared_ptr fill_tensor = std::make_shared(fill_shape, DataType(DataType::DE_UINT64)); - fill_tensor->SetItemAt({}, 4); + std::shared_ptr fill_tensor; + Tensor::CreateScalar(4, &fill_tensor); std::shared_ptr output; std::unique_ptr op(new FillOp(fill_tensor)); Status s = op->Compute(input, &output); - uint64_t out[3] = {4, 4, 4}; - - std::shared_ptr expected = - std::make_shared(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast(out)); + std::vector out = {4, 4, 4}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); @@ -59,23 +56,20 @@ TEST_F(MindDataTestFillOp, TestOp) { TEST_F(MindDataTestFillOp, TestCasting) { MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - TensorShape fill_shape({}); - std::shared_ptr fill_tensor = std::make_shared(fill_shape, DataType(DataType::DE_FLOAT32)); - fill_tensor->SetItemAt({}, 2.0); + std::shared_ptr fill_tensor; + Tensor::CreateScalar(2.0, &fill_tensor); std::shared_ptr output; std::unique_ptr op(new FillOp(fill_tensor)); Status s = op->Compute(input, &output); - uint64_t out[3] = {2, 2, 2}; - - std::shared_ptr expected = - std::make_shared(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast(out)); + std::vector out = {2, 2, 2}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); ASSERT_TRUE(output->shape() == expected->shape()); ASSERT_TRUE(output->type() == expected->type()); @@ -90,15 +84,15 @@ TEST_F(MindDataTestFillOp, TestCasting) { TEST_F(MindDataTestFillOp, ScalarFill) { MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); TensorShape fill_shape({2}); - uint64_t fill_labels[3] = {0, 1}; - std::shared_ptr fill_tensor = - std::make_shared(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast(fill_labels)); + std::vector fill_labels = {0, 1}; + std::shared_ptr fill_tensor; + Tensor::CreateFromVector(fill_labels, &fill_tensor); + std::shared_ptr output; std::unique_ptr op(new FillOp(fill_tensor)); Status s = op->Compute(input, &output); @@ -112,12 +106,11 @@ TEST_F(MindDataTestFillOp, ScalarFill) { TEST_F(MindDataTestFillOp, StringFill) { MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill."; std::vector strings = {"xyzzy", "plugh", "abracadabra"}; - TensorShape shape({3}); - std::shared_ptr input = std::make_shared(strings, shape); + std::shared_ptr input; + Tensor::CreateFromVector(strings, &input); - TensorShape fill_shape({}); - std::string fill_string = "hello"; - std::shared_ptr fill_tensor = std::make_shared(fill_string); + std::shared_ptr fill_tensor; + Tensor::CreateScalar("hello", &fill_tensor); std::shared_ptr output; @@ -125,8 +118,8 @@ TEST_F(MindDataTestFillOp, StringFill) { Status s = op->Compute(input, &output); std::vector expected_strings = {"hello", "hello", "hello"}; - TensorShape expected_shape({3}); - std::shared_ptr expected = std::make_shared(expected_strings, expected_shape); + std::shared_ptr expected; + Tensor::CreateFromVector(expected_strings, &expected); EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); @@ -142,12 +135,11 @@ TEST_F(MindDataTestFillOp, StringFill) { TEST_F(MindDataTestFillOp, NumericToString) { MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString."; std::vector strings = {"xyzzy", "plugh", "abracadabra"}; - TensorShape shape({3}); - std::shared_ptr input = std::make_shared(strings, shape); + std::shared_ptr input; + Tensor::CreateFromVector(strings, &input); - TensorShape fill_shape({}); - std::shared_ptr fill_tensor = std::make_shared(fill_shape, DataType(DataType::DE_FLOAT32)); - fill_tensor->SetItemAt({}, 2.0); + std::shared_ptr fill_tensor; + Tensor::CreateScalar(2.0, &fill_tensor); std::shared_ptr output; @@ -162,14 +154,12 @@ TEST_F(MindDataTestFillOp, NumericToString) { TEST_F(MindDataTestFillOp, StringToNumeric) { MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); - - TensorShape fill_shape({}); - std::string fill_string = "hello"; - std::shared_ptr fill_tensor = std::make_shared(fill_string); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); + + std::shared_ptr fill_tensor; + Tensor::CreateScalar("hello", &fill_tensor); std::shared_ptr output; diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc index 3168efa196..2cce023dcf 100644 --- a/tests/ut/cpp/dataset/image_folder_op_test.cc +++ b/tests/ut/cpp/dataset/image_folder_op_test.cc @@ -18,7 +18,7 @@ #include #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" @@ -68,8 +68,7 @@ std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int6 Status Create1DTensor(std::shared_ptr *sample_ids, int64_t num_elements, unsigned char *data = nullptr, DataType::Type data_type = DataType::DE_UINT32) { TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, TensorImpl::kFlexible, shape, DataType(data_type), data)); - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes()); // allocate memory in case user forgets! + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, DataType(data_type), data, sample_ids)); return Status::OK(); } diff --git a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc index 85b3384d36..21f2483ffb 100644 --- a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc +++ b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc @@ -42,7 +42,8 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opFuntions) { TensorRow input, output; std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); - std::shared_ptr input_tensor = std::make_shared("今天天气太好了我们一起去外面玩吧"); + std::shared_ptr input_tensor; + Tensor::CreateScalar("今天天气太好了我们一起去外面玩吧", &input_tensor); input.push_back(input_tensor); Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -66,7 +67,8 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opAdd) { std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); op->AddWord("男默女泪"); - std::shared_ptr input_tensor = std::make_shared("男默女泪"); + std::shared_ptr input_tensor; + Tensor::CreateScalar("男默女泪", &input_tensor); input.push_back(input_tensor); Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -84,7 +86,8 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opEmpty) { std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); op->AddWord("男默女泪"); - std::shared_ptr input_tensor = std::make_shared(""); + std::shared_ptr input_tensor; + Tensor::CreateScalar("", &input_tensor); input.push_back(input_tensor); Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc index a6eef4aaa2..0d6621bfa2 100644 --- a/tests/ut/cpp/dataset/manifest_op_test.cc +++ b/tests/ut/cpp/dataset/manifest_op_test.cc @@ -19,7 +19,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/manifest_op.h" @@ -71,9 +71,9 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(res[i] == label); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -101,9 +101,9 @@ TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) { rc = di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); i++; di.GetNextAsMap(&tensor_map); EXPECT_EQ(label, 1); @@ -131,9 +131,9 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(label == res[i]); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -160,9 +160,9 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(0 == label); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -176,7 +176,7 @@ TEST_F(MindDataTestManifest, MindDataTestManifestEval) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; int64_t num_samples = 1; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})}); tree->Prepare(); Status rc = tree->Launch(); @@ -189,9 +189,9 @@ TEST_F(MindDataTestManifest, MindDataTestManifestEval) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(0 == label); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc index 4e9cfe9ec9..0cee75b264 100644 --- a/tests/ut/cpp/dataset/map_op_test.cc +++ b/tests/ut/cpp/dataset/map_op_test.cc @@ -645,16 +645,14 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { map_decode_builder.SetInColNames({"image"}) .SetOutColNames({}) .SetTensorFuncs(func_list) - .SetNumWorkers(14) - .SetPerformanceMode(false); + .SetNumWorkers(14); rc = map_decode_builder.Build(&map_decode_map); EXPECT_TRUE(rc.IsOk()); map_resize_builder.SetInColNames({"image"}) .SetOutColNames({}) .SetTensorFuncs(func_list2) - .SetNumWorkers(15) - .SetPerformanceMode(false); + .SetNumWorkers(15); rc = map_resize_builder.Build(&map_resize_op); EXPECT_TRUE(rc.IsOk()); @@ -739,5 +737,3 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) { } EXPECT_TRUE(i == 88); } - - diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc index 609d5bf447..c6279acdb9 100644 --- a/tests/ut/cpp/dataset/mask_test.cc +++ b/tests/ut/cpp/dataset/mask_test.cc @@ -38,9 +38,9 @@ class MindDataTestMaskOp : public UT::Common { TEST_F(MindDataTestMaskOp, Basics) { std::shared_ptr t; - Tensor::CreateTensor(&t, std::vector({1, 2, 3, 4, 5, 6})); + Tensor::CreateFromVector(std::vector({1, 2, 3, 4, 5, 6}), &t); std::shared_ptr v; - Tensor::CreateTensor(&v, std::vector({3}), TensorShape::CreateScalar()); + Tensor::CreateFromVector(std::vector({3}), TensorShape::CreateScalar(), &v); std::shared_ptr op = std::make_shared(RelationalOp::kEqual, v, DataType(DataType::DE_UINT16)); std::shared_ptr out; ASSERT_TRUE(op->Compute(t, &out).IsOk()); diff --git a/tests/ut/cpp/dataset/mind_record_op_test.cc b/tests/ut/cpp/dataset/mind_record_op_test.cc index c9067535d6..bed97a740d 100644 --- a/tests/ut/cpp/dataset/mind_record_op_test.cc +++ b/tests/ut/cpp/dataset/mind_record_op_test.cc @@ -18,7 +18,7 @@ #include #include "minddata/dataset/core/client.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "minddata/mindrecord/include/shard_category.h" #include "minddata/mindrecord/include/shard_error.h" @@ -435,7 +435,6 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) { .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) - .SetBlockReader() .SetColumnsToLoad(column_list); rc = builder.Build(&my_mindrecord_op); ASSERT_TRUE(rc.IsOk()); diff --git a/tests/ut/cpp/dataset/mnist_op_test.cc b/tests/ut/cpp/dataset/mnist_op_test.cc index dfceeaa06a..c40086e20f 100644 --- a/tests/ut/cpp/dataset/mnist_op_test.cc +++ b/tests/ut/cpp/dataset/mnist_op_test.cc @@ -18,7 +18,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "common/common.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" diff --git a/tests/ut/cpp/dataset/one_hot_op_test.cc b/tests/ut/cpp/dataset/one_hot_op_test.cc index 2617ae4536..9dd5139dac 100644 --- a/tests/ut/cpp/dataset/one_hot_op_test.cc +++ b/tests/ut/cpp/dataset/one_hot_op_test.cc @@ -29,19 +29,17 @@ class MindDataTestOneHotOp : public UT::Common { TEST_F(MindDataTestOneHotOp, TestOp) { MS_LOG(INFO) << "Doing MindDataTestOneHotOp."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = std::make_shared(shape, DataType(DataType::DE_UINT64), - reinterpret_cast (labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); std::shared_ptr output; std::unique_ptr op(new OneHotOp(5)); Status s = op->Compute(input, &output); - uint64_t out[15] = {1, 0, 0, 0, 0, - 0, 1, 0, 0, 0, - 0, 0, 1, 0, 0}; - std::shared_ptr expected = std::make_shared(TensorShape{3, 5}, DataType(DataType::DE_UINT64), - reinterpret_cast (out)); + std::vector out = {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, TensorShape{3, 5}, &expected); + EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); ASSERT_TRUE(output->type() == expected->type()); diff --git a/tests/ut/cpp/dataset/pad_end_op_test.cc b/tests/ut/cpp/dataset/pad_end_op_test.cc index 1c838da8e8..b4bd993f38 100644 --- a/tests/ut/cpp/dataset/pad_end_op_test.cc +++ b/tests/ut/cpp/dataset/pad_end_op_test.cc @@ -35,44 +35,40 @@ TEST_F(MindDataTestPadEndOp, TestOp) { TensorShape pad_data_shape({1}); // prepare input tensor - float_t orig1[4] = {1, 1, 1, 1}; + std::vector orig1 = {1, 1, 1, 1}; TensorShape input_shape1({2, 2}); std::vector input_shape1_vector = {input_shape1}; - std::shared_ptr input1 = - std::make_shared(input_shape1, DataType(DataType::DE_FLOAT32), reinterpret_cast(orig1)); + std::shared_ptr input1; + Tensor::CreateFromVector(orig1, input_shape1, &input1); // pad_shape TensorShape pad_shape1[3] = {TensorShape({3, 3}), TensorShape({2, 4}), TensorShape({4, 2})}; // value to pad - float_t pad_data1[3][1] = {0, 3.5, 3.5}; + std::vector> pad_data1 = {{0}, {3.5}, {3.5}}; std::shared_ptr expected1[3]; // expected tensor output for testunit 1 - float_t out1[9] = {1, 1, 0, 1, 1, 0, 0, 0, 0}; - - expected1[0] = - std::make_shared(pad_shape1[0], DataType(DataType::DE_FLOAT32), reinterpret_cast(out1)); + std::vector out1 = {1, 1, 0, 1, 1, 0, 0, 0, 0}; + Tensor::CreateFromVector(out1, pad_shape1[0], &(expected1[0])); // expected tensor output for testunit 2 - float_t out2[8] = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5}; - - expected1[1] = - std::make_shared(pad_shape1[1], DataType(DataType::DE_FLOAT32), reinterpret_cast(out2)); + std::vector out2 = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5}; + Tensor::CreateFromVector(out2, pad_shape1[1], &(expected1[1])); // expected tensor output for testunit 3 - float_t out3[8] = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5}; - - expected1[2] = - std::make_shared(pad_shape1[2], DataType(DataType::DE_FLOAT32), reinterpret_cast(out3)); + std::vector out3 = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5}; + Tensor::CreateFromVector(out3, pad_shape1[2], &(expected1[2])); // run the PadEndOp for (auto i = 0; i < 3; i++) { std::shared_ptr output; std::vector output_shape = {TensorShape({})}; - std::shared_ptr pad_value1 = std::make_shared(pad_data_shape, DataType(DataType::DE_FLOAT32), - reinterpret_cast(pad_data1[i])); + + std::shared_ptr pad_value1; + Tensor::CreateFromVector(pad_data1[i], pad_data_shape, &pad_value1); + std::unique_ptr op(new PadEndOp(pad_shape1[i], pad_value1)); Status s = op->Compute(input1, &output); @@ -96,7 +92,7 @@ TEST_F(MindDataTestPadEndOp, TestOp) { TensorShape input_shape2({2}); std::vector input_shape2_vector = {input_shape2}; std::shared_ptr input2; - Tensor::CreateTensor(&input2, orig2, input_shape2); + Tensor::CreateFromVector(orig2, input_shape2, &input2); // pad_shape TensorShape pad_shape2[3] = {TensorShape({5}), TensorShape({2}), TensorShape({10})}; @@ -112,7 +108,7 @@ TEST_F(MindDataTestPadEndOp, TestOp) { for (auto i = 0; i < 3; i++) { // pad value - Tensor::CreateTensor(&pad_value2[i], pad_data2[i], pad_data_shape); + Tensor::CreateFromVector(pad_data2[i], pad_data_shape, &pad_value2[i]); std::shared_ptr output; std::vector output_shape = {TensorShape({})}; @@ -121,7 +117,7 @@ TEST_F(MindDataTestPadEndOp, TestOp) { Status s = op->Compute(input2, &output); - Tensor::CreateTensor(&expected2[i], outstring[i], pad_shape2[i]); + Tensor::CreateFromVector(outstring[i], pad_shape2[i], &expected2[i]); EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected2[i]->shape()); diff --git a/tests/ut/cpp/dataset/project_op_test.cc b/tests/ut/cpp/dataset/project_op_test.cc index 45ef11b88f..5bdb2923d4 100644 --- a/tests/ut/cpp/dataset/project_op_test.cc +++ b/tests/ut/cpp/dataset/project_op_test.cc @@ -18,7 +18,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc index fcf8ba2605..b3b46f09ec 100644 --- a/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc +++ b/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc @@ -66,6 +66,7 @@ TEST_F(MindDataTestRandomCropWithBBoxOp, TestOp1) { } GlobalContext::config_manager()->set_seed(current_seed); } + MS_LOG(INFO) << "testRandomCropWithBBoxOp1 end."; } TEST_F(MindDataTestRandomCropWithBBoxOp, TestOp2) { @@ -87,5 +88,22 @@ TEST_F(MindDataTestRandomCropWithBBoxOp, TestOp2) { EXPECT_EQ(s, Status::OK()); EXPECT_EQ(4, output_tensor_row_[1]->shape()[1]); // check for existence of 4 columns } - MS_LOG(INFO) << "testRandomCropWithBBoxOp end."; + MS_LOG(INFO) << "testRandomCropWithBBoxOp2 end."; } + +TEST_F(MindDataTestRandomCropWithBBoxOp, TestOp3) { + MS_LOG(INFO) << "Doing testRandomCropWithBBoxOp3."; + // Crop params + unsigned int crop_height = 1280; + unsigned int crop_width = 1280; + std::unique_ptr op(new RandomCropWithBBoxOp(crop_height, crop_width, crop_height * 3 + 1, + crop_height * 3 + 1, crop_width * 3 + 1, + crop_width * 3 + 1, BorderType::kConstant, false)); + + for (auto tensor_row_ : images_and_annotations_) { + Status s = op->Compute(tensor_row_, &output_tensor_row_); + EXPECT_TRUE(s.IsError()); + ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError); + } + MS_LOG(INFO) << "testRandomCropWithBBoxOp3 end."; +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/rename_op_test.cc b/tests/ut/cpp/dataset/rename_op_test.cc index ac64346c26..6a1de176da 100644 --- a/tests/ut/cpp/dataset/rename_op_test.cc +++ b/tests/ut/cpp/dataset/rename_op_test.cc @@ -19,10 +19,9 @@ #include #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/constants.h" -#include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/rename_op.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/engine/data_buffer.h" #include "gtest/gtest.h" #include "minddata/dataset/core/global_context.h" diff --git a/tests/ut/cpp/dataset/repeat_op_test.cc b/tests/ut/cpp/dataset/repeat_op_test.cc index 74d494c0dc..c74aee06ab 100644 --- a/tests/ut/cpp/dataset/repeat_op_test.cc +++ b/tests/ut/cpp/dataset/repeat_op_test.cc @@ -46,7 +46,8 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) { ASSERT_TRUE(rc.IsOk()); rc = my_tree->AssociateNode(my_tfreader_op); ASSERT_TRUE(rc.IsOk()); - my_tree->AssociateNode(parent_op); + rc = my_tree->AssociateNode(parent_op); + ASSERT_TRUE(rc.IsOk()); ASSERT_NE(parent_op, nullptr); ASSERT_NE(my_tfreader_op, nullptr); parent_op->AddChild(std::move(my_tfreader_op)); diff --git a/tests/ut/cpp/dataset/schema_test.cc b/tests/ut/cpp/dataset/schema_test.cc index 95b9c75d9e..d4c47cdf51 100644 --- a/tests/ut/cpp/dataset/schema_test.cc +++ b/tests/ut/cpp/dataset/schema_test.cc @@ -18,7 +18,7 @@ #include #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/data_schema.h" diff --git a/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc b/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc new file mode 100644 index 0000000000..19f7291079 --- /dev/null +++ b/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common.h" +#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" +#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h" +#include "minddata/dataset/text/sentence_piece_vocab.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/status.h" + +using namespace mindspore::dataset; + +class MindDataTestSentencePieceVocabOp : public UT::DatasetOpTesting { + public: + void CheckEqual(const std::shared_ptr &o, const std::vector &index, const std::string &expect) { + std::string_view str; + Status s = o->GetItemAt(&str, index); + EXPECT_TRUE(s.IsOk()); + EXPECT_EQ(str, expect); + } +}; + +TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) { + MS_LOG(INFO) << "Doing MindDataTestSentencePieceVocabOp TestSentencePieceFromDatasetFuntions."; + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/test_sentencepiece/botchan.txt"; + auto tree = std::make_shared(); + + std::shared_ptr file_op; + TextFileOp::Builder builder_file; + builder_file.SetTextFilesList({dataset_path}).SetRowsPerBuffer(1).SetNumWorkers(1).SetOpConnectorSize(2); + + Status rc = builder_file.Build(&file_op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssociateNode(file_op); + ASSERT_TRUE(rc.IsOk()); + + std::shared_ptr spm = std::make_unique(); + std::shared_ptr spv_op; + BuildSentencePieceVocabOp::Builder builder_spv; + std::vector cols; + std::unordered_map m_params; + builder_spv.SetVocab(spm) + .SetVocabSize(5000) + .SetColumnNames(cols) + .SetCharacterCoverage(0.9995) + .SetModelType(SentencePieceModel::kUnigram) + .SetParams(m_params) + .SetOpConnectorSize(2); + + rc = builder_spv.Build(&spv_op); + ASSERT_TRUE(rc.IsOk()); + rc = tree->AssociateNode(spv_op); + ASSERT_TRUE(rc.IsOk()); + + rc = spv_op->AddChild(file_op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssignRoot(spv_op); + ASSERT_TRUE(rc.IsOk()); + rc = tree->Prepare(); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->Launch(); + ASSERT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + + while (!tensor_list.empty()) { + rc = di.FetchNextTensorRow(&tensor_list); + } + ASSERT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromFileFuntions) { + MS_LOG(INFO) << "Doing MindDataTestSentencePieceVocabOp TestSentencePieceFromFileFuntions."; + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/test_sentencepiece/botchan.txt"; + std::vector path_list; + path_list.emplace_back(dataset_path); + std::unordered_map param_map; + std::shared_ptr spm = std::make_unique(); + Status rc = SentencePieceVocab::BuildFromFile(path_list, 5000, 0.9995, SentencePieceModel::kUnigram, param_map, &spm); + ASSERT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) { + MS_LOG(INFO) << "Doing MindDataTestSentencePieceVocabOp TestSentencePieceTokenizerFuntions."; + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/test_sentencepiece/botchan.txt"; + auto tree = std::make_shared(); + + std::shared_ptr file_op; + TextFileOp::Builder builder_file; + builder_file.SetTextFilesList({dataset_path}).SetRowsPerBuffer(1).SetNumWorkers(1).SetOpConnectorSize(2); + + Status rc = builder_file.Build(&file_op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssociateNode(file_op); + ASSERT_TRUE(rc.IsOk()); + + std::shared_ptr spm = std::make_unique(); + std::shared_ptr spv_op; + BuildSentencePieceVocabOp::Builder builder_spv; + std::vector cols; + std::unordered_map m_params; + + builder_spv.SetVocab(spm) + .SetVocabSize(5000) + .SetColumnNames(cols) + .SetCharacterCoverage(0.9995) + .SetModelType(SentencePieceModel::kUnigram) + .SetParams(m_params) + .SetOpConnectorSize(2); + + rc = builder_spv.Build(&spv_op); + ASSERT_TRUE(rc.IsOk()); + rc = tree->AssociateNode(spv_op); + ASSERT_TRUE(rc.IsOk()); + + rc = spv_op->AddChild(file_op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssignRoot(spv_op); + ASSERT_TRUE(rc.IsOk()); + rc = tree->Prepare(); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->Launch(); + ASSERT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + + while (!tensor_list.empty()) { + rc = di.FetchNextTensorRow(&tensor_list); + } + std::shared_ptr output_tensor; + std::unique_ptr op( + new SentencePieceTokenizerOp(spm, SPieceTokenizerLoadType::kModel, SPieceTokenizerOutType::kString)); + std::shared_ptr input_tensor; + Tensor::CreateScalar("I saw a girl with a telescope.", &input_tensor); + Status s = op->Compute(input_tensor, &output_tensor); + + std::vector expect; + expect.push_back("▁I"); + expect.push_back("▁sa"); + expect.push_back("w"); + expect.push_back("▁a"); + expect.push_back("▁girl"); + expect.push_back("▁with"); + expect.push_back("▁a"); + expect.push_back("▁te"); + expect.push_back("les"); + expect.push_back("co"); + expect.push_back("pe"); + expect.push_back("."); + ASSERT_TRUE(output_tensor->Size() == expect.size()); + for (int i = 0; i < output_tensor->Size(); i++) { + std::string_view str; + output_tensor->GetItemAt(&str, {i}); + std::string sentence{str}; + ASSERT_TRUE(sentence == expect[i]); + } +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/shuffle_op_test.cc b/tests/ut/cpp/dataset/shuffle_op_test.cc index 98b4878efb..45d2d7f608 100644 --- a/tests/ut/cpp/dataset/shuffle_op_test.cc +++ b/tests/ut/cpp/dataset/shuffle_op_test.cc @@ -15,7 +15,7 @@ */ #include "minddata/dataset/core/client.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include diff --git a/tests/ut/cpp/dataset/sliding_window_op_test.cc b/tests/ut/cpp/dataset/sliding_window_op_test.cc new file mode 100644 index 0000000000..b39a131460 --- /dev/null +++ b/tests/ut/cpp/dataset/sliding_window_op_test.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/text/kernels/sliding_window_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestSlidingWindowOp : public UT::Common { + protected: + MindDataTestSlidingWindowOp() {} +}; + +TEST_F(MindDataTestSlidingWindowOp, Compute) { + MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute."; + std::vector strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; + TensorShape shape({static_cast(strings.size())}); + std::shared_ptr input; + Tensor::CreateFromVector(strings, shape, &input); + std::shared_ptr output; + + std::unique_ptr op(new SlidingWindowOp(3, 0)); + Status s = op->Compute(input, &output); + + std::vector out = {"one", "two", "three", "two", "three", "four", "three", "four", "five", + "four", "five", "six", "five", "six", "seven", "six", "seven", "eight"}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, TensorShape({6, 3}), &expected); + + ASSERT_TRUE(output->shape() == expected->shape()); + ASSERT_TRUE(output->type() == expected->type()); + MS_LOG(DEBUG) << *output << std::endl; + MS_LOG(DEBUG) << *expected << std::endl; + ASSERT_TRUE(*output == *expected); + + MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; +} + +TEST_F(MindDataTestSlidingWindowOp, OutputShape) { + MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape."; + std::vector strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; + TensorShape shape({static_cast(strings.size())}); + std::shared_ptr input; + Tensor::CreateFromVector(strings, shape, &input); + std::vector input_shape = {input->shape()}; + std::vector output_shape = {TensorShape({})}; + + std::unique_ptr op(new SlidingWindowOp(3, 0)); + Status s = op->OutputShape(input_shape, output_shape); + + MS_LOG(DEBUG) << "input_shape" << input_shape[0]; + MS_LOG(DEBUG) << "output_shape" << output_shape[0]; + ASSERT_TRUE(output_shape[0] == TensorShape({6, 3})); + + MS_LOG(INFO) << "MindDataTestSlidingWindowOp end."; +} diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc index 96e9652bbc..79464b732b 100644 --- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc +++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc @@ -30,8 +30,7 @@ using namespace mindspore::dataset; Status CreateINT64Tensor(std::shared_ptr *sample_ids, int64_t num_elements, unsigned char *data = nullptr) { TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, TensorImpl::kFlexible, shape, DataType(DataType::DE_INT64), data)); - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes()); // allocate memory in case user forgets! + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, DataType(DataType::DE_INT64), data, sample_ids)); return Status::OK(); } @@ -54,8 +53,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { {0, 17, 4, 10, 14, 8, 15}, {13, 9, 16, 3, 2, 19, 12}, {1, 11, 6, 18, 7, 5, 0}}; for (int i = 0; i < 6; i++) { std::shared_ptr t; - Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape({7}), - DataType(DataType::DE_INT64), (unsigned char *)(res[i])); + Tensor::CreateFromMemory(TensorShape({7}), DataType(DataType::DE_INT64), (unsigned char *)(res[i]), &t); row.push_back(t); } MockStorageOp mock(20); diff --git a/tests/ut/cpp/dataset/swap_red_blue_test.cc b/tests/ut/cpp/dataset/swap_red_blue_test.cc new file mode 100644 index 0000000000..7a760ee5b1 --- /dev/null +++ b/tests/ut/cpp/dataset/swap_red_blue_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "common/cvop_common.h" +#include "minddata/dataset/kernels/image/swap_red_blue_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestSwapRedBlueOp : public UT::CVOP::CVOpCommon { + protected: + MindDataTestSwapRedBlueOp() : CVOpCommon() {} + + std::shared_ptr output_tensor_; +}; + +TEST_F(MindDataTestSwapRedBlueOp, TestOp1) { + MS_LOG(INFO) << "Doing testSwapRedBlue."; + // SwapRedBlue params + std::unique_ptr op(new SwapRedBlueOp()); + EXPECT_TRUE(op->OneToOne()); + Status s = op->Compute(input_tensor_, &output_tensor_); + size_t actual = 0; + if (s == Status::OK()) { + actual = output_tensor_->shape()[0] * output_tensor_->shape()[1] * output_tensor_->shape()[2]; + } + EXPECT_EQ(actual, input_tensor_->shape()[0] * input_tensor_->shape()[1] * 3); + EXPECT_EQ(s, Status::OK()); +} + diff --git a/tests/ut/cpp/dataset/take_op_test.cc b/tests/ut/cpp/dataset/take_op_test.cc index a8bfe40b10..156a76e9c1 100644 --- a/tests/ut/cpp/dataset/take_op_test.cc +++ b/tests/ut/cpp/dataset/take_op_test.cc @@ -18,7 +18,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/tensor_string_test.cc b/tests/ut/cpp/dataset/tensor_string_test.cc index fe336a34c5..232fefc2ae 100644 --- a/tests/ut/cpp/dataset/tensor_string_test.cc +++ b/tests/ut/cpp/dataset/tensor_string_test.cc @@ -35,13 +35,15 @@ class MindDataTestStringTensorDE : public UT::Common { }; TEST_F(MindDataTestStringTensorDE, Basics) { - std::shared_ptr t = std::make_shared("Hi"); + std::shared_ptr t; + Tensor::CreateScalar("Hi", &t); ASSERT_TRUE(t->shape() == TensorShape({})); std::string_view s = ""; t->GetItemAt(&s, {}); ASSERT_TRUE(s == "Hi"); - std::shared_ptr t2 = std::make_shared(std::vector{"Hi", "Bye"}); + std::shared_ptr t2; + Tensor::CreateFromVector(std::vector{"Hi", "Bye"}, &t2); ASSERT_TRUE(t2->shape() == TensorShape({2})); t2->GetItemAt(&s, {0}); ASSERT_TRUE(s == "Hi"); @@ -49,7 +51,9 @@ TEST_F(MindDataTestStringTensorDE, Basics) { ASSERT_TRUE(s == "Bye"); std::vector strings{"abc", "defg", "hi", "klmno", "123", "789"}; - std::shared_ptr t3 = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t3; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t3); + ASSERT_TRUE(t3->shape() == TensorShape({2, 3})); uint32_t index = 0; for (uint32_t i = 0; i < 2; i++) { @@ -62,8 +66,10 @@ TEST_F(MindDataTestStringTensorDE, Basics) { } TEST_F(MindDataTestStringTensorDE, Basics2) { - std::shared_ptr t = - std::make_shared(std::vector{"abc", "defg", "hi", "klmno", "123", "789"}, TensorShape({2, 3})); + std::shared_ptr t; + Tensor::CreateFromVector(std::vector{"abc", "defg", "hi", "klmno", "123", "789"}, TensorShape({2, 3}), + &t); + ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 20 + 4); std::vector offsets = {0, 4, 9, 12, 18, 22, 26}; uint32_t ctr = 0; @@ -86,7 +92,8 @@ TEST_F(MindDataTestStringTensorDE, Basics2) { TEST_F(MindDataTestStringTensorDE, Empty) { std::vector strings{"abc", "defg", "", "", "123", ""}; - std::shared_ptr t = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t); // abc_defg___123__ // 0123456789012345 ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 10 + 4); @@ -112,7 +119,9 @@ TEST_F(MindDataTestStringTensorDE, Empty) { TEST_F(MindDataTestStringTensorDE, SetItem) { std::vector strings{"abc", "defg", "hi", "klmno", "123", "789"}; - std::shared_ptr t3 = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t3; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t3); + ASSERT_TRUE(t3->shape() == TensorShape({2, 3})); t3->SetItemAt({0, 1}, std::string{"xyzz"}); @@ -136,7 +145,8 @@ TEST_F(MindDataTestStringTensorDE, SetItem) { TEST_F(MindDataTestStringTensorDE, Iterator) { std::vector strings{"abc", "defg", "hi", "klmno", "123", "789"}; - std::shared_ptr t = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t); uint32_t index = 0; auto itr = t->begin(); for (; itr != t->end(); itr++) { diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index fce4652b47..758b194835 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -35,8 +35,9 @@ class MindDataTestTensorDE : public UT::Common { }; TEST_F(MindDataTestTensorDE, Basics) { - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); - ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk()); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); + ASSERT_EQ(t->shape(), TensorShape({2, 3})); ASSERT_EQ(t->type(), DataType::DE_UINT64); ASSERT_EQ(t->SizeInBytes(), 2 * 3 * 8); @@ -67,28 +68,30 @@ TEST_F(MindDataTestTensorDE, Basics) { ASSERT_EQ(t->ToString(), "Tensor (shape: <2,3>, Type: uint64)\n[[1,2,3],[4,5,6]]"); std::vector x = {1, 2, 3, 4, 5, 6}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x, TensorShape({2, 3})); + Tensor::CreateFromVector(x, TensorShape({2, 3}), &t2); ASSERT_EQ(*t == *t2, true); ASSERT_EQ(*t != *t2, false); } TEST_F(MindDataTestTensorDE, Fill) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32), &t); t->Fill(2.5); std::vector x = {2.5, 2.5, 2.5, 2.5}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x, TensorShape({2, 2})); + Tensor::CreateFromVector(x, TensorShape({2, 2}), &t2); ASSERT_EQ(*t == *t2, true); } TEST_F(MindDataTestTensorDE, Reshape) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->Fill(254); t->Reshape(TensorShape({4})); std::vector x = {254, 254, 254, 254}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x); + Tensor::CreateFromVector(x, &t2); ASSERT_EQ(*t == *t2, true); Status rc = t->Reshape(TensorShape({5})); @@ -102,7 +105,8 @@ TEST_F(MindDataTestTensorDE, Reshape) { } TEST_F(MindDataTestTensorDE, CopyTensor) { - std::shared_ptr t = std::make_shared(TensorShape({}), DataType(DataType::DE_INT16)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({}), DataType(DataType::DE_INT16), &t); t->SetItemAt({}, -66); ASSERT_EQ(t->shape(), TensorShape({})); ASSERT_EQ(t->type(), DataType::DE_INT16); @@ -125,30 +129,31 @@ TEST_F(MindDataTestTensorDE, CopyTensor) { } TEST_F(MindDataTestTensorDE, InsertTensor) { - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_FLOAT64), &t); std::vector x = {1.1, 2.1, 3.1}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x); + Tensor::CreateFromVector(x, &t2); std::vector y = {1.2, 2.2, 3.2}; std::shared_ptr t3; - Tensor::CreateTensor(&t3, y); + Tensor::CreateFromVector(y, &t3); ASSERT_TRUE(t->InsertTensor({0}, t2).OK()); ASSERT_TRUE(t->InsertTensor({1}, t3).OK()); std::vector z = {1.1, 2.1, 3.1, 1.2, 2.2, 3.2}; std::shared_ptr t4; - Tensor::CreateTensor(&t4, z, TensorShape({2, 3})); + Tensor::CreateFromVector(z, TensorShape({2, 3}), &t4); ASSERT_EQ(*t == *t4, true); std::shared_ptr t5; - Tensor::CreateTensor(&t5, 0); + Tensor::CreateScalar(0, &t5); ASSERT_TRUE(t->InsertTensor({1, 2}, t5).OK()); z[5] = 0; std::shared_ptr t6; - Tensor::CreateTensor(&t6, z, TensorShape({2, 3})); + Tensor::CreateFromVector(z, TensorShape({2, 3}), &t6); ASSERT_EQ(*t == *t6, true); ASSERT_EQ(t->InsertTensor({2}, t5).get_code(), StatusCode::kUnexpectedError); @@ -161,7 +166,8 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { // Test the bug of Tensor::ToString will exec failed for Tensor which store bool values TEST_F(MindDataTestTensorDE, BoolTensor) { - std::shared_ptr t = std::make_shared(TensorShape({2}), DataType(DataType::DE_BOOL)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_BOOL), &t); t->SetItemAt({0}, true); t->SetItemAt({1}, true); std::string out = t->ToString(); @@ -169,7 +175,8 @@ TEST_F(MindDataTestTensorDE, BoolTensor) { } TEST_F(MindDataTestTensorDE, GetItemAt) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->Fill(254); uint64_t o1; t->GetItemAt(&o1, {0, 0}); @@ -183,7 +190,8 @@ TEST_F(MindDataTestTensorDE, GetItemAt) { uint8_t o4; t->GetItemAt(&o4, {1, 1}); ASSERT_EQ(o4, 254); - std::shared_ptr t2 = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_INT8)); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_INT8), &t2); t2->Fill(-10); int64_t o5; t2->GetItemAt(&o5, {0, 0}); @@ -197,7 +205,8 @@ TEST_F(MindDataTestTensorDE, GetItemAt) { int8_t o8; t2->GetItemAt(&o8, {1, 1}); ASSERT_EQ(o8, -10); - std::shared_ptr t3 = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32)); + std::shared_ptr t3; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32), &t3); t3->Fill(1.1); double o9; t3->GetItemAt(&o9, {0, 0}); @@ -208,9 +217,11 @@ TEST_F(MindDataTestTensorDE, GetItemAt) { } TEST_F(MindDataTestTensorDE, OperatorAssign) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->Fill(1); - std::shared_ptr t2 = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t2); *t2 = std::move(*t); uint8_t o; t2->GetItemAt(&o, {0, 0}); @@ -224,18 +235,20 @@ TEST_F(MindDataTestTensorDE, OperatorAssign) { } TEST_F(MindDataTestTensorDE, Strides) { - std::shared_ptr t = std::make_shared(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT8), &t); std::vector x1 = t->Strides(); std::vector x2 = {4, 2, 1}; ASSERT_EQ(x1, x2); - t = std::make_shared(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT32)); + Tensor::CreateEmpty(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT32), &t); x1 = t->Strides(); x2 = {16, 8, 4}; ASSERT_EQ(x1, x2); } void checkCvMat(TensorShape shape, DataType type) { - std::shared_ptr t = std::make_shared(shape, type); + std::shared_ptr t; + CVTensor::CreateEmpty(shape, type, &t); cv::Mat m = t->mat(); ASSERT_EQ(m.data, t->GetBuffer()); ASSERT_EQ(static_cast(m.type()) & static_cast(CV_MAT_DEPTH_MASK), type.AsCVType()); @@ -289,8 +302,10 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) { m.at(0, 1) = 20; m.at(1, 0) = 30; m.at(1, 1) = 40; - std::shared_ptr cvt = std::make_shared(m); - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr cvt; + CVTensor::CreateFromMat(m, &cvt); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->SetItemAt({0, 0}, 10); t->SetItemAt({0, 1}, 20); t->SetItemAt({1, 0}, 30); @@ -302,8 +317,10 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) { m2.at(1) = 20; m2.at(2) = 30; m2.at(3) = 40; - std::shared_ptr cvt2 = std::make_shared(m2); - std::shared_ptr t2 = std::make_shared(TensorShape({4}), DataType(DataType::DE_UINT8)); + std::shared_ptr cvt2; + CVTensor::CreateFromMat(m2, &cvt2); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({4}), DataType(DataType::DE_UINT8), &t2); t2->SetItemAt({0}, 10); t2->SetItemAt({1}, 20); t2->SetItemAt({2}, 30); @@ -313,10 +330,12 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) { } TEST_F(MindDataTestTensorDE, CVTensorAs) { - std::shared_ptr t = std::make_shared(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64), &t); t->Fill(2.2); const unsigned char *addr = t->GetBuffer(); - std::shared_ptr t2 = std::make_shared(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64)); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64), &t2); t2->Fill(4.4); std::shared_ptr ctv = CVTensor::AsCVTensor(t); ASSERT_EQ(t->GetBuffer(), nullptr); @@ -326,6 +345,10 @@ TEST_F(MindDataTestTensorDE, CVTensorAs) { ASSERT_EQ(ctv->GetBuffer(), addr); ASSERT_TRUE(*t2 == *ctv); MS_LOG(DEBUG) << *t2 << std::endl << *ctv; + cv::Mat m2 = ctv->matCopy(); + m2 = 2 * m2; + ASSERT_EQ(ctv->GetBuffer(), addr); + ASSERT_TRUE(*t2 == *ctv); } TEST_F(MindDataTestTensorDE, CVTensorMatSlice) { @@ -336,23 +359,26 @@ TEST_F(MindDataTestTensorDE, CVTensorMatSlice) { m.at(1, 0) = 40; m.at(1, 1) = 50; m.at(1, 2) = 60; - std::shared_ptr cvt = std::make_shared(m); + std::shared_ptr cvt; + CVTensor::CreateFromMat(m, &cvt); cv::Mat mat; - cvt->Mat({1}, &mat); + cvt->MatAtIndex({1}, &mat); cv::Mat m2(3, 1, CV_32S); m2.at(0) = 40; m2.at(1) = 50; m2.at(2) = 60; - std::shared_ptr cvt2 = std::make_shared(mat); - std::shared_ptr cvt3 = std::make_shared(m2); + std::shared_ptr cvt2; + CVTensor::CreateFromMat(mat, &cvt2); + std::shared_ptr cvt3; + CVTensor::CreateFromMat(m2, &cvt3); ASSERT_TRUE(*cvt2 == *cvt3); - cvt->Mat({0}, &mat); + cvt->MatAtIndex({0}, &mat); m2.at(0) = 10; m2.at(1) = 20; m2.at(2) = 30; - cvt2 = std::make_shared(mat); - cvt3 = std::make_shared(m2); + CVTensor::CreateFromMat(mat, &cvt2); + CVTensor::CreateFromMat(m2, &cvt3); ASSERT_TRUE(*cvt2 == *cvt3); } @@ -361,7 +387,7 @@ TEST_F(MindDataTestTensorDE, TensorIterator) { std::vector values2 = {2, 3, 4, 5, 6, 7}; std::shared_ptr t; - Tensor::CreateTensor(&t, values); + Tensor::CreateFromVector(values, &t); auto i = t->begin(); auto j = values.begin(); @@ -395,31 +421,31 @@ TEST_F(MindDataTestTensorDE, TensorIterator) { TEST_F(MindDataTestTensorDE, TensorSlice) { std::shared_ptr t; - Tensor::CreateTensor(&t, std::vector{0, 1, 2, 3, 4}); + Tensor::CreateFromVector(std::vector{0, 1, 2, 3, 4}, &t); std::shared_ptr t2; auto x = std::vector{0, 3, 4}; std::shared_ptr expected; - Tensor::CreateTensor(&expected, x); + Tensor::CreateFromVector(x, &expected); t->Slice(&t2, x); ASSERT_EQ(*t2, *expected); t->Slice(&t2, std::vector{0, 1, 2, 3, 4}); ASSERT_EQ(*t2, *t); } -TEST_F(MindDataTestTensorDE, TensorConcatenate) { +TEST_F(MindDataTestTensorDE, TensorPartialInsert) { std::vector values1 = {1, 2, 3, 0, 0, 0}; std::vector values2 = {4, 5, 6}; std::vector expected = {1, 2, 3, 4, 5, 6}; std::shared_ptr t1; - Tensor::CreateTensor(&t1, values1); + Tensor::CreateFromVector(values1, &t1); std::shared_ptr t2; - Tensor::CreateTensor(&t2, values2); + Tensor::CreateFromVector(values2, &t2); std::shared_ptr out; - Tensor::CreateTensor(&out, expected); - Status s = t1->Concatenate({3}, t2); + Tensor::CreateFromVector(expected, &out); + Status s = t1->InsertTensor({3}, t2, true); EXPECT_TRUE(s.IsOk()); auto i = out->begin(); @@ -429,20 +455,85 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) { } // should fail if the concatenated vector is too large - s = t1->Concatenate({5}, t2); + s = t1->InsertTensor({5}, t2, true); EXPECT_FALSE(s.IsOk()); } TEST_F(MindDataTestTensorDE, TensorEmpty) { - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); - ASSERT_TRUE(t->HasData()); -} + TensorPtr t; + Status rc = Tensor::CreateEmpty(TensorShape({0}), DataType(DataType::DE_UINT64), &t); + ASSERT_TRUE(rc.IsOk()); -TEST_F(MindDataTestTensorDE, TensorEmptyInvalidate) { - std::vector values1 = {1, 2, 3, 0, 0, 0}; - std::shared_ptr t; - Tensor::CreateTensor(&t, values1); - t->Invalidate(); - ASSERT_TRUE(t->HasData()); -} + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_UINT64); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = t->SetItemAt({0}, 7); + ASSERT_TRUE(rc.IsError()); + + rc = Tensor::CreateEmpty(TensorShape({1, 0}), DataType(DataType::DE_STRING), &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({1, 0})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + std::vector data; + rc = Tensor::CreateFromVector(data, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_UINT16); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + std::vector data2; + rc = Tensor::CreateFromVector(data2, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + rc = Tensor::CreateFromVector(data, TensorShape({0, 2}), &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0, 2})); + ASSERT_EQ(t->type(), DataType::DE_UINT16); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromVector(data2, TensorShape({0, 0, 6}), &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0, 0, 6})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromMemory(TensorShape({0}), DataType(DataType::DE_INT8), nullptr, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_INT8); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromMemory(TensorShape({0}), DataType(DataType::DE_STRING), nullptr, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + + std::vector values = {1, 2, 3, 0, 0, 0}; + std::shared_ptr t2; + Tensor::CreateFromVector(values, &t2); + ASSERT_TRUE(t2->HasData()); + t2->Invalidate(); + ASSERT_TRUE(!t2->HasData()); +} diff --git a/tests/ut/cpp/dataset/tensorshape_test.cc b/tests/ut/cpp/dataset/tensorshape_test.cc index 65ab386db0..7eec95b081 100644 --- a/tests/ut/cpp/dataset/tensorshape_test.cc +++ b/tests/ut/cpp/dataset/tensorshape_test.cc @@ -20,7 +20,7 @@ #include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/engine/data_schema.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/text_file_op_test.cc b/tests/ut/cpp/dataset/text_file_op_test.cc index bc2674a6a3..9450ac7b98 100644 --- a/tests/ut/cpp/dataset/text_file_op_test.cc +++ b/tests/ut/cpp/dataset/text_file_op_test.cc @@ -19,7 +19,7 @@ #include "minddata/dataset/core/client.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/dataset/engine/datasetops/source/text_file_op.h" diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc index 30fde33ff9..9f0919cd96 100644 --- a/tests/ut/cpp/dataset/tfReader_op_test.cc +++ b/tests/ut/cpp/dataset/tfReader_op_test.cc @@ -20,7 +20,7 @@ #include "minddata/dataset/core/client.h" #include "minddata/dataset/engine/data_schema.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/tokenizer_op_test.cc b/tests/ut/cpp/dataset/tokenizer_op_test.cc index cc2d7473ff..df3a5435de 100644 --- a/tests/ut/cpp/dataset/tokenizer_op_test.cc +++ b/tests/ut/cpp/dataset/tokenizer_op_test.cc @@ -46,8 +46,8 @@ class MindDataTestTokenizerOp : public UT::Common { TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp."; std::unique_ptr op(new UnicodeCharTokenizerOp(true)); - std::shared_ptr input = std::make_shared("Hello World!"); - TensorRow output; +std::shared_ptr input; + Tensor::CreateScalar("Hello World!", &input); TensorRow output; Status s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 12); @@ -66,7 +66,7 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { CheckEqual(output[0], {10}, "d"); CheckEqual(output[0], {11}, "!"); - input = std::make_shared("中国 你好!"); + Tensor::CreateScalar("中国 你好!", &input); output.clear(); s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); @@ -80,38 +80,38 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { CheckEqual(output[0], {4}, "好"); CheckEqual(output[0], {5}, "!"); - input = std::make_shared("中"); - output.clear(); + Tensor::CreateScalar("中", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); CheckEqual(output[0], {0}, "中"); - input = std::make_shared("H"); - output.clear(); + Tensor::CreateScalar("H", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); CheckEqual(output[0], {0}, "H"); - input = std::make_shared(" "); - output.clear(); + Tensor::CreateScalar(" ", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 2); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); CheckEqual(output[0], {0}, " "); CheckEqual(output[0], {1}, " "); - input = std::make_shared(""); - output.clear(); + Tensor::CreateScalar("", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor6: " << output[0]->ToString(); @@ -121,10 +121,10 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { TEST_F(MindDataTestTokenizerOp, TestWhitespaceTokenizerOp) { MS_LOG(INFO) << "Doing TestWhitespaceTokenizerOp."; std::unique_ptr op(new WhitespaceTokenizerOp(true)); - std::shared_ptr input = std::make_shared("Welcome to China."); - TensorRow output; +std::shared_ptr input; + Tensor::CreateScalar("Welcome to China.", &input); TensorRow output; Status s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 3); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor1: " << output[0]->ToString(); @@ -132,37 +132,37 @@ TEST_F(MindDataTestTokenizerOp, TestWhitespaceTokenizerOp) { CheckEqual(output[0], {1}, "to"); CheckEqual(output[0], {2}, "China."); - input = std::make_shared(" hello"); - output.clear(); + Tensor::CreateScalar(" hello", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor2: " << output[0]->ToString(); CheckEqual(output[0], {0}, "hello"); - input = std::make_shared("hello"); - output.clear(); + Tensor::CreateScalar("hello", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); CheckEqual(output[0], {0}, "hello"); - input = std::make_shared("hello "); - output.clear(); + Tensor::CreateScalar("hello ", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); CheckEqual(output[0], {0}, "hello"); - input = std::make_shared(" "); - output.clear(); + Tensor::CreateScalar(" ", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); @@ -174,8 +174,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { std::unique_ptr keep_whitespace_op(new UnicodeScriptTokenizerOp(true, true)); std::unique_ptr skip_whitespace_op(new UnicodeScriptTokenizerOp(false, true)); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); - TensorRow output; + std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. \n 中国\t北京", &input); + TensorRow output; Status s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 10); @@ -204,10 +205,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { CheckEqual(output[0], {4}, "中国"); CheckEqual(output[0], {5}, "北京"); - input = std::make_shared(" Welcome to 中国. "); - output.clear(); - s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + Tensor::CreateScalar(" Welcome to 中国. ", &input); + output.clear(); + s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 4); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); @@ -230,25 +230,23 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { CheckEqual(output[0], {6}, "."); CheckEqual(output[0], {7}, " "); - input = std::make_shared("Hello"); - output.clear(); - s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + Tensor::CreateScalar("Hello", &input); +output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); CheckEqual(output[0], {0}, "Hello"); - input = std::make_shared("H"); - output.clear(); - s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + Tensor::CreateScalar("H", &input); +output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor6: " << output[0]->ToString(); CheckEqual(output[0], {0}, "H"); - input = std::make_shared(""); + Tensor::CreateScalar("", &input); output.clear(); s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); @@ -257,10 +255,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { MS_LOG(INFO) << "Out tensor7: " << output[0]->ToString(); CheckEqual(output[0], {0}, ""); - input = std::make_shared("Hello中国Hello世界"); - output.clear(); - s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output[0]->Size(), 4); + Tensor::CreateScalar("Hello中国Hello世界", &input); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 4); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor8: " << output[0]->ToString(); CheckEqual(output[0], {0}, "Hello"); @@ -268,15 +265,15 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { CheckEqual(output[0], {2}, "Hello"); CheckEqual(output[0], {3}, "世界"); - input = std::make_shared(" "); - output.clear(); + Tensor::CreateScalar(" ", &input); + output.clear(); s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor10: " << output[0]->ToString(); CheckEqual(output[0], {0}, " "); - input = std::make_shared(" "); + Tensor::CreateScalar(" ", &input); output.clear(); s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); @@ -289,7 +286,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { TEST_F(MindDataTestTokenizerOp, TestCaseFold) { MS_LOG(INFO) << "Doing TestCaseFold."; std::unique_ptr case_fold_op(new CaseFoldOp()); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); + std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. \n 中国\t北京", &input); + std::shared_ptr output; Status s = case_fold_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -305,7 +304,8 @@ TEST_F(MindDataTestTokenizerOp, TestNormalize) { std::unique_ptr nfkc_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfkc)); std::unique_ptr nfd_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfd)); std::unique_ptr nfkd_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfkd)); - std::shared_ptr input = std::make_shared("ṩ"); + std::shared_ptr input; + Tensor::CreateScalar("ṩ", &input); std::shared_ptr output; Status s = nfc_normalize_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -327,7 +327,8 @@ TEST_F(MindDataTestTokenizerOp, TestNormalize) { TEST_F(MindDataTestTokenizerOp, TestRegexReplace) { MS_LOG(INFO) << "Doing TestRegexReplace."; std::unique_ptr regex_replace_op(new RegexReplaceOp("\\s+", "_", true)); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); + std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. \n 中国\t北京", &input); std::shared_ptr output; Status s = regex_replace_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -340,19 +341,20 @@ TEST_F(MindDataTestTokenizerOp, TestRegexReplace) { TEST_F(MindDataTestTokenizerOp, TestRegexTokenizer) { MS_LOG(INFO) << "Doing TestRegexTokenizerOp."; std::unique_ptr regex_tokenizer_op(new RegexTokenizerOp("\\p{Cc}|\\p{Cf}|\\s+", "", true)); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); - TensorRow output; +std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. \n 中国\t北京", &input); + TensorRow output; Status s = regex_tokenizer_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); } TEST_F(MindDataTestTokenizerOp, TestBasicTokenizer) { MS_LOG(INFO) << "Doing TestBasicTokenizer."; - //bool lower_case, bool keep_whitespace, + // bool lower_case, bool keep_whitespace, // NormalizeForm normalization_form, bool preserve_unused_token - std::unique_ptr basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false, - true)); - std::shared_ptr input = std::make_shared("Welcome to China. 中国\t北京"); + std::unique_ptr basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false,true)); +std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. 中国\t北京", &input); TensorRow output; Status s = basic_tokenizer->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); diff --git a/tests/ut/cpp/dataset/trucate_pair_test.cc b/tests/ut/cpp/dataset/trucate_pair_test.cc index af7e61c16a..48d30cf2f5 100644 --- a/tests/ut/cpp/dataset/trucate_pair_test.cc +++ b/tests/ut/cpp/dataset/trucate_pair_test.cc @@ -35,17 +35,17 @@ class MindDataTestTruncatePairOp : public UT::Common { TEST_F(MindDataTestTruncatePairOp, Basics) { std::shared_ptr t1; - Tensor::CreateTensor(&t1, std::vector({1, 2, 3})); + Tensor::CreateFromVector(std::vector({1, 2, 3}), &t1); std::shared_ptr t2; - Tensor::CreateTensor(&t2, std::vector({4, 5})); + Tensor::CreateFromVector(std::vector({4, 5}), &t2); TensorRow in({t1, t2}); std::shared_ptr op = std::make_shared(4); TensorRow out; ASSERT_TRUE(op->Compute(in, &out).IsOk()); std::shared_ptr out1; - Tensor::CreateTensor(&out1, std::vector({1, 2})); + Tensor::CreateFromVector(std::vector({1, 2}), &out1); std::shared_ptr out2; - Tensor::CreateTensor(&out2, std::vector({4, 5})); + Tensor::CreateFromVector(std::vector({4, 5}), &out2); ASSERT_EQ(*out1, *out[0]); ASSERT_EQ(*out2, *out[1]); } diff --git a/tests/ut/cpp/dataset/type_cast_op_test.cc b/tests/ut/cpp/dataset/type_cast_op_test.cc index a94a7fedba..371b589c57 100644 --- a/tests/ut/cpp/dataset/type_cast_op_test.cc +++ b/tests/ut/cpp/dataset/type_cast_op_test.cc @@ -43,16 +43,15 @@ class MindDataTestTypeCast : public UT::Common { template void testCast(std::vector values, const DataType &from, const DataType &to) { - std::shared_ptr t = std::make_shared(TensorShape({static_cast(values.size())}), - DataType(from), - reinterpret_cast(&values[0])); + std::shared_ptr t; + Tensor::CreateFromVector(values, &t); std::unique_ptr op(new TypeCastOp(to)); EXPECT_TRUE(op->OneToOne()); std::shared_ptr output; EXPECT_TRUE(op->Compute(t, &output)); ASSERT_TRUE(t->shape() == output->shape()); - ASSERT_TRUE(DataType(to)==output->type()); + ASSERT_TRUE(DataType(to) == output->type()); MS_LOG(DEBUG) << *output << std::endl; auto out = output->begin(); auto v = values.begin(); diff --git a/tests/ut/cpp/dataset/voc_op_test.cc b/tests/ut/cpp/dataset/voc_op_test.cc index 4bb212ffc7..294703d608 100644 --- a/tests/ut/cpp/dataset/voc_op_test.cc +++ b/tests/ut/cpp/dataset/voc_op_test.cc @@ -19,7 +19,7 @@ #include #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/voc_op.h" diff --git a/tests/ut/cpp/dataset/zip_op_test.cc b/tests/ut/cpp/dataset/zip_op_test.cc index 8d74cb0969..b55578f672 100644 --- a/tests/ut/cpp/dataset/zip_op_test.cc +++ b/tests/ut/cpp/dataset/zip_op_test.cc @@ -23,12 +23,11 @@ #include #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/constants.h" -#include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/config_manager.h" #include "common/common.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "minddata/dataset/engine/data_buffer.h" #include "gtest/gtest.h" #include "minddata/dataset/core/global_context.h" diff --git a/tests/ut/cpp/ir/clone_test.cc b/tests/ut/cpp/ir/clone_test.cc index 20da3fb8b5..1929cf599e 100644 --- a/tests/ut/cpp/ir/clone_test.cc +++ b/tests/ut/cpp/ir/clone_test.cc @@ -22,7 +22,7 @@ #include "utils/log_adapter.h" #include "ir/func_graph_cloner.h" #include "pipeline/jit/parse/parse.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "debug/draw.h" #include "./common.h" diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 3e6d1a312c..060151109d 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -22,7 +22,7 @@ #include "frontend/operator/ops.h" #include "utils/log_adapter.h" #include "debug/draw.h" -#include "debug/label.h" +#include "utils/label.h" #include "./common.h" namespace mindspore { diff --git a/tests/ut/cpp/ir/meta_tensor_test.cc b/tests/ut/cpp/ir/meta_tensor_test.cc index 537d4c460e..928e90a1a1 100644 --- a/tests/ut/cpp/ir/meta_tensor_test.cc +++ b/tests/ut/cpp/ir/meta_tensor_test.cc @@ -22,7 +22,7 @@ #include "securec/include/securec.h" #include "ir/tensor.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" using mindspore::tensor::TensorPy; @@ -225,6 +225,27 @@ TEST_F(TestTensor, EqualTest) { ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c()); } +TEST_F(TestTensor, ValueEqualTest) { + py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6); + TensorPtr t1 = TensorPy::MakeTensor(py::array(tuple), kInt32); + TensorPtr t2 = TensorPy::MakeTensor(py::array(tuple), kInt32); + ASSERT_TRUE(t1->ValueEqual(*t1)); + ASSERT_TRUE(t1->ValueEqual(*t2)); + + std::vector shape = {6}; + TensorPtr t3 = std::make_shared(kInt32->type_id(), shape); + TensorPtr t4 = std::make_shared(kInt32->type_id(), shape); + ASSERT_TRUE(t3->ValueEqual(*t3)); + ASSERT_FALSE(t3->ValueEqual(*t4)); + ASSERT_FALSE(t3->ValueEqual(*t1)); + ASSERT_FALSE(t1->ValueEqual(*t3)); + + memcpy_s(t3->data_c(), t3->data().nbytes(), t1->data_c(), t1->data().nbytes()); + ASSERT_TRUE(t1->ValueEqual(*t3)); + ASSERT_FALSE(t3->ValueEqual(*t4)); + ASSERT_FALSE(t4->ValueEqual(*t3)); +} + TEST_F(TestTensor, PyArrayTest) { py::array_t input({2, 3}); auto array = input.mutable_unchecked(); diff --git a/tests/ut/cpp/kernel/common_utils_test.cc b/tests/ut/cpp/kernel/common_utils_test.cc index 83f7c59e52..4e016cd495 100644 --- a/tests/ut/cpp/kernel/common_utils_test.cc +++ b/tests/ut/cpp/kernel/common_utils_test.cc @@ -25,7 +25,7 @@ class CommonUtilTest : public UT::Common { CommonUtilTest() = default; }; -TEST_F(CommonUtilTest, DeduplicateIndexedSlicesTest1) { +TEST_F(CommonUtilTest, BucketReduceSparseGradient1) { // The indices is a vector and the grad is a tensor with shape (6, 2) /* 0 * 0 @@ -46,20 +46,39 @@ TEST_F(CommonUtilTest, DeduplicateIndexedSlicesTest1) { for (int i = 0; i < 6 * 2; i++) { grad.push_back(i); } - std::vector unique_indices(3); - std::vector summed_grad(6); - SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 0}); - ReduceSparseGradient(SparseGradient({grad.data(), indices.data(), 6}), &unique_grad, 6, 2); + std::vector unique_indices(6); + std::vector summed_grad(12); + std::vector tmp_indices(6); + std::vector tmp_grad(12); + + SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6}); + SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6}); + SparseGradient input_grad({grad.data(), indices.data(), 6}); + + ReduceSparseGradientParam param; + param.input_grad_ = &input_grad; + param.workspace_grad_ = &workspace_grad; + param.output_grad_ = &unique_grad; + param.max_index_ = 6; + param.value_stride_ = 2; + BucketReduceSparseGradient(param); + EXPECT_EQ(unique_grad.indices_size_, 3); - EXPECT_EQ(unique_indices, std::vector({0, 1, 3})); + std::vector expect_indices({0, 1, 3}); + for (size_t i = 0; i < unique_grad.indices_size_; ++i) { + EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]); + } /* 10 13 * 10 12 * 10 11 */ - EXPECT_EQ(summed_grad, std::vector({10, 13, 10, 12, 10, 11})); + std::vector expect_value({10, 13, 10, 12, 10, 11}); + for (size_t i = 0; i < unique_grad.indices_size_ * 2; ++i) { + EXPECT_EQ(unique_grad.value_[i], expect_value[i]); + } } -TEST_F(CommonUtilTest, DeduplicateIndexedSlicesTest2) { +TEST_F(CommonUtilTest, BucketReduceSparseGradient2) { // The indices is a vector and the grad is a tensor with shape (6, 2) /* 0 * 0 @@ -80,16 +99,36 @@ TEST_F(CommonUtilTest, DeduplicateIndexedSlicesTest2) { for (int i = 0; i < 6 * 2; i++) { grad.push_back(i); } - std::vector unique_indices(2); - std::vector summed_grad(4); - SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 0}); - ReduceSparseGradient(SparseGradient({grad.data(), indices.data(), 6}), &unique_grad, 6, 2); + std::vector unique_indices(6); + std::vector summed_grad(12); + std::vector tmp_indices(6); + std::vector tmp_grad(12); + SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6}); + SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6}); + SparseGradient input_grad({grad.data(), indices.data(), 6}); + + ReduceSparseGradientParam param; + param.input_grad_ = &input_grad; + param.workspace_grad_ = &workspace_grad; + param.output_grad_ = &unique_grad; + param.max_index_ = 6; + param.value_stride_ = 2; + BucketReduceSparseGradient(param); + EXPECT_EQ(unique_grad.indices_size_, 2); - EXPECT_EQ(unique_indices, std::vector({0, 1})); + + std::vector expect_indices({0, 1}); + for (size_t i = 0; i < unique_grad.indices_size_; ++i) { + EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]); + } + /* 10 13 * 10 12 */ - EXPECT_EQ(summed_grad, std::vector({10, 13, 10, 12})); + std::vector expect_value({10, 13, 10, 12}); + for (size_t i = 0; i < unique_grad.indices_size_ * 2; ++i) { + EXPECT_EQ(unique_grad.value_[i], expect_value[i]); + } } } // namespace kernel } // namespace mindspore diff --git a/tests/ut/cpp/mindrecord/ut_common.h b/tests/ut/cpp/mindrecord/ut_common.h index ee943ab88e..79c76c0a04 100644 --- a/tests/ut/cpp/mindrecord/ut_common.h +++ b/tests/ut/cpp/mindrecord/ut_common.h @@ -22,7 +22,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/mindrecord/include/shard_index.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 4501ea0800..2137fb4a13 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/mindrecord/include/shard_category.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index 8b5eb2cf69..fb0e8470ce 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/mindrecord/include/shard_reader.h" @@ -94,31 +94,6 @@ TEST_F(TestShardReader, TestShardReaderSample) { dataset.Close(); } -TEST_F(TestShardReader, TestShardReaderBlock) { - MS_LOG(INFO) << FormatInfo("Test read imageNet with block way"); - std::string file_name = "./imagenet.shard01"; - auto column_list = std::vector{"label"}; - - std::vector> ops; - ops.push_back(std::make_shared(3)); - ShardReader dataset; - const bool kBlockReader = true; - dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader); - dataset.Launch(); - - while (true) { - auto x = dataset.GetBlockNext(); - if (x.empty()) break; - for (auto &j : x) { - for (auto &item : std::get<1>(j).items()) { - MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); - } - } - } - dataset.Finish(); - dataset.Close(); -} - TEST_F(TestShardReader, TestShardReaderEasy) { MS_LOG(INFO) << FormatInfo("Test read imageNet"); std::string file_name = "./imagenet.shard01"; diff --git a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc index 6b99e44d89..a4900a51f2 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc @@ -27,7 +27,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/mindrecord/include/shard_segment.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 046b4f93d5..a8abe5e98d 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -21,7 +21,7 @@ #include #include -#include "common/utils.h" +#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "minddata/mindrecord/include/shard_reader.h" diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index a2108998bc..9912e0c4e8 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -21,7 +21,7 @@ #include "frontend/operator/composite/composite.h" #include "frontend/operator/ops.h" #include "pipeline/jit/static_analysis/prim.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "debug/trace.h" namespace mindspore { diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 789b1cab25..796bad8053 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -19,7 +19,7 @@ #include "common/common_test.h" #include "ir/value.h" -#include "ir/primitive_py.h" +#include "utils/primitive_py.h" #include "frontend/operator/ops.h" #include "./common.h" @@ -267,11 +267,6 @@ TEST_F(TestOps, BroadCastShapeTest) { ASSERT_EQ(prim->name(), kPrimBroadcastShape->name()); } -TEST_F(TestOps, ShapeTest) { - auto prim = std::make_shared("Shape"); - ASSERT_EQ(prim->name(), kPrimShape->name()); -} - TEST_F(TestOps, ArrayMapTest) { auto prim = std::make_shared("array_map"); ASSERT_EQ(prim->name(), kPrimArrayMap->name()); @@ -454,8 +449,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) { ASSERT_TRUE(conv2d_ptr); if (nullptr != conv2d_ptr) { MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name(); - auto func = conv2d_ptr->GetComputeFunction(); - if (py::isinstance(func)) { + if(!conv2d_ptr->HasComputeFunction()){ MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented"; } diff --git a/tests/ut/cpp/optimizer/ad/ad_test.cc b/tests/ut/cpp/optimizer/ad/ad_test.cc index 3f861d3604..0cd1f1d705 100644 --- a/tests/ut/cpp/optimizer/ad/ad_test.cc +++ b/tests/ut/cpp/optimizer/ad/ad_test.cc @@ -23,7 +23,7 @@ #include "ir/value.h" #include "ir/func_graph_cloner.h" #include "utils/log_adapter.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/parse/parse.h" #include "debug/draw.h" diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 751b301283..faf03505c7 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -24,7 +24,6 @@ #include "ir/func_graph_cloner.h" #include "ir/manager.h" #include "ir/value.h" -#include "ir/visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "pipeline/jit/resource.h" @@ -41,6 +40,9 @@ class TestOptLib : public UT::Common { void SetUp() { UT::InitPythonPath(); parse::data_converter::ClearObjectCache(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_execution_mode(kGraphMode); } FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { equiv_node.clear(); @@ -604,14 +606,27 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); } -TEST_F(TestOptLib, test_indexed_slices) { - FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices"); - FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices"); - FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values"); - FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values"); - FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape"); - FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape"); - auto patterns = std::vector({irpass.indexed_slices_eliminate_}); +TEST_F(TestOptLib, test_row_tensor) { + FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "before_get_indices"); + FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "after_get_indices"); + FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_row_tensor", "before_get_values"); + FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_row_tensor", "after_get_values"); + FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "before_get_dense_shape"); + FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "after_get_dense_shape"); + auto patterns = std::vector({irpass.row_tensor_eliminate_}); + ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); + ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); + ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); +} + +TEST_F(TestOptLib, test_sparse_tensor) { + FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices"); + FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices"); + FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values"); + FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values"); + FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape"); + FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape"); + auto patterns = std::vector({irpass.sparse_tensor_eliminate_}); ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index c329adc4a5..df1996d3e5 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -23,6 +23,7 @@ #include "ir/visitor.h" #include "ir/func_graph_cloner.h" #include "frontend/optimizer/opt.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/arithmetic_simplify.h" @@ -77,7 +78,7 @@ class TestOptOpt : public UT::Common { }; void SetUp() { - elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); + elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); elim_R = MakeSubstitution(std::make_shared(R), "elim_R", R); idempotent_P = MakeSubstitution(std::make_shared(), "idempotent_P", P); Qct_to_P = MakeSubstitution(std::make_shared(), "Qct_to_P", Q); diff --git a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc index a500afc859..8840e76d85 100644 --- a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc @@ -154,13 +154,13 @@ class TestDPAlgo : public UT::Common { void TestDPAlgo::SetUp() { cost_graph = std::make_shared(); cost_graph->SetDeviceMemoryAndCostParameter(); - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 10; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); stage_map.push_back(2); @@ -1327,8 +1327,8 @@ TEST_F(TestDPAlgo, test_GetStrategy_for_DoubleStarGraph) { for (auto &op : cost_graph->GetOperators()) { StrategyPtr s_strategy = op->selected_strategy(); - std::vector strategy_0 = s_strategy->GetInputDim()[0]; - std::vector strategy_1 = s_strategy->GetInputDim()[1]; + Dimensions strategy_0 = s_strategy->GetInputDim()[0]; + Dimensions strategy_1 = s_strategy->GetInputDim()[1]; std::string string_strategy_0 = "["; for (size_t i = 0; i < strategy_0.size(); ++i) { diff --git a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc index 190a189a2d..c5c13851b1 100644 --- a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc @@ -43,13 +43,13 @@ class TestEdgeCostModel : public UT::Common { }; void TestEdgeCostModel::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 10; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); stage_map.push_back(2); diff --git a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc index 7d63f03179..f471775327 100644 --- a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc @@ -53,13 +53,13 @@ class TestCostGraph : public UT::Common { void TestCostGraph::SetUp() { cost_graph.SetDeviceMemoryAndCostParameter(); - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 10; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); stage_map.push_back(2); diff --git a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc index b9b6bb67d9..13c278bebd 100644 --- a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc @@ -33,13 +33,13 @@ class TestMatMulCost : public UT::Common { void TestMatMulCost::SetUp() { mmcost_ = MatMulCost(); - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -90,13 +90,13 @@ class TestActivationCost : public UT::Common { void TestActivationCost::SetUp() { ac_cost_ = ActivationCost(); - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -142,13 +142,13 @@ class TestPReLUCost : public UT::Common { void TestPReLUCost::SetUp() { prelu_cost_ = PReLUCost(); - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); diff --git a/tests/ut/cpp/parallel/device_manager_test.cc b/tests/ut/cpp/parallel/device_manager_test.cc index 0c048d647b..3962372bcf 100644 --- a/tests/ut/cpp/parallel/device_manager_test.cc +++ b/tests/ut/cpp/parallel/device_manager_test.cc @@ -69,8 +69,8 @@ void TestDeviceManager::TearDown() { } TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) { - std::vector dev_list; - std::vector stage_map; + RankList dev_list; + RankList stage_map; int32_t local_dev = 0; dev_list.push_back(5); @@ -85,12 +85,12 @@ TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) { ASSERT_EQ(dm_.DeviceNum(), 4); ASSERT_EQ(dm_.GetStageNum(), (int32_t)(2)); - std::vector dev_list_0 = dm_.GetDeviceListByStageId(0); - std::vector dev_list_1 = dm_.GetDeviceListByStageId(1); + RankList dev_list_0 = dm_.GetDeviceListByStageId(0); + RankList dev_list_1 = dm_.GetDeviceListByStageId(1); ASSERT_EQ(dev_list_0.size(), 2); ASSERT_EQ(dev_list_1.size(), 2); - std::vector::iterator it = dev_list_0.begin(); + RankList::iterator it = dev_list_0.begin(); ASSERT_EQ((*it), int32_t(5)); it++; ASSERT_EQ((*it), int32_t(3)); @@ -112,7 +112,7 @@ TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) { TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) { std::vector dev_list; - std::vector rlist; + RankList rlist; rlist.push_back(int32_t(2)); rlist.push_back(int32_t(1)); dev_list = dm_.CreateDeviceListByRankList(rlist); diff --git a/tests/ut/cpp/parallel/ops_info/activation_info_test.cc b/tests/ut/cpp/parallel/ops_info/activation_info_test.cc index 5f09de9e48..0dbae89d52 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_info_test.cc @@ -38,13 +38,13 @@ class TestActivationInfo : public UT::Common { }; void TestActivationInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -64,18 +64,18 @@ void TestActivationInfo::SetUp() { } TEST_F(TestActivationInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); activation->Init(strategy); - std::vector dev_matrix_shape = activation->dev_matrix_shape(); + Shape dev_matrix_shape = activation->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 16}; + Shape expect = {2, 4, 8, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestActivationInfo, InferSliceShape1) { - std::vector str = {{2, 4, 8, 16}}; + Strategys str = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, str); activation->Init(strategy); @@ -96,7 +96,7 @@ TEST_F(TestActivationInfo, InferSliceShape1) { } TEST_F(TestActivationInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 8, 16}}; + Strategys str = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, str); activation->Init(strategy); @@ -117,7 +117,7 @@ TEST_F(TestActivationInfo, GetTensorLayout1) { } TEST_F(TestActivationInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); activation->Init(strategy); @@ -128,7 +128,7 @@ TEST_F(TestActivationInfo, GetForwardOp1) { } TEST_F(TestActivationInfo, GetMirrorOPs1) { - std::vector inputs = {{1, 4, 8, 16}}; + Strategys inputs = {{1, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); activation->Init(strategy); @@ -148,7 +148,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs1) { } TEST_F(TestActivationInfo, GetMirrorOPs2) { - std::vector inputs = {{2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); activation->Init(strategy); @@ -161,7 +161,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs2) { TEST_F(TestActivationInfo, CheckStrategy1) { // Success: {{2,4,8,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = activation->Init(strategy); @@ -170,7 +170,7 @@ TEST_F(TestActivationInfo, CheckStrategy1) { TEST_F(TestActivationInfo, CheckStrategy2) { // Success: {{2,4,8,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = activation->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/activation_test.cc b/tests/ut/cpp/parallel/ops_info/activation_test.cc index 9d129b7a18..4442c1d4ff 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_test.cc @@ -40,13 +40,13 @@ class TestActivation : public UT::Common { }; void TestActivation::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -101,7 +101,7 @@ TEST_F(TestActivation, test_softmax_strategies) { ASSERT_NE(sp, nullptr); Cost cost = *(swc->cost_list[0]); - std::vector stra = sp->GetInputDim(); + Strategys stra = sp->GetInputDim(); ASSERT_GT(stra.size(), 0); Dimensions input0_stra = stra[0]; ASSERT_GT(input0_stra.size(), 2); diff --git a/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc b/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc index e49ed4e79d..e0e9424ac2 100644 --- a/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc @@ -38,13 +38,13 @@ class TestGeluInfo : public UT::Common { }; void TestGeluInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 130; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(128); stage_map.push_back(2); @@ -63,18 +63,18 @@ void TestGeluInfo::SetUp() { } TEST_F(TestGeluInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); gelu->Init(strategy); - std::vector dev_matrix_shape = gelu->dev_matrix_shape(); + Shape dev_matrix_shape = gelu->dev_matrix_shape(); - std::vector expect = {2, 4, 1, 16}; + Shape expect = {2, 4, 1, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestGeluInfo, InferSliceShape1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); gelu->Init(strategy); @@ -95,7 +95,7 @@ TEST_F(TestGeluInfo, InferSliceShape1) { } TEST_F(TestGeluInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); gelu->Init(strategy); @@ -116,7 +116,7 @@ TEST_F(TestGeluInfo, GetTensorLayout1) { } TEST_F(TestGeluInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); gelu->Init(strategy); @@ -127,7 +127,7 @@ TEST_F(TestGeluInfo, GetForwardOp1) { } TEST_F(TestGeluInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); gelu->Init(strategy); @@ -140,7 +140,7 @@ TEST_F(TestGeluInfo, GetMirrorOPs1) { TEST_F(TestGeluInfo, CheckStrategy1) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = gelu->Init(strategy); @@ -149,7 +149,7 @@ TEST_F(TestGeluInfo, CheckStrategy1) { TEST_F(TestGeluInfo, CheckStrategy2) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = gelu->Init(strategy); @@ -158,7 +158,7 @@ TEST_F(TestGeluInfo, CheckStrategy2) { TEST_F(TestGeluInfo, CheckStrategy3) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = gelu->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc b/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc index 125723868a..487099915a 100644 --- a/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc +++ b/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc @@ -34,13 +34,13 @@ class TestGenerateStrategy : public UT::Common { }; void TestGenerateStrategy::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 10; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); stage_map.push_back(2); diff --git a/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc b/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc index 029e0f2dc6..505ab29a5a 100644 --- a/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc @@ -38,13 +38,13 @@ class TestGetNextInfo : public UT::Common { }; void TestGetNextInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 8; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); int32_t local_dev = 0; // create a new g_device_manager @@ -65,16 +65,16 @@ void TestGetNextInfo::SetUp() { } TEST_F(TestGetNextInfo, InferDevMatrixShape1) { - std::vector inputs = {{}, {}}; + Strategys inputs = {{}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); get_next->Init(strategy); - std::vector dev_matrix_shape = get_next->dev_matrix_shape(); - std::vector expect = {8, 1}; + Shape dev_matrix_shape = get_next->dev_matrix_shape(); + Shape expect = {8, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestGetNextInfo, InferSliceShape1) { - std::vector str = {{}, {}}; + Strategys str = {{}, {}}; StrategyPtr strategy = NewStrategy(0, str); get_next->Init(strategy); @@ -90,7 +90,7 @@ TEST_F(TestGetNextInfo, InferSliceShape1) { } TEST_F(TestGetNextInfo, GetTensorLayout1) { - std::vector str = {{}, {}}; + Strategys str = {{}, {}}; StrategyPtr strategy = NewStrategy(0, str); get_next->Init(strategy); std::vector outputs = get_next->outputs_tensor_info(); @@ -106,14 +106,14 @@ TEST_F(TestGetNextInfo, GetTensorLayout1) { } TEST_F(TestGetNextInfo, CheckStrategy1) { - std::vector inputs = {}; + Strategys inputs = {}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = get_next->Init(strategy); ASSERT_EQ(ret, SUCCESS); } TEST_F(TestGetNextInfo, CheckStrategy2) { - std::vector inputs = {{8, 1}, {8}}; + Strategys inputs = {{8, 1}, {8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = get_next->Init(strategy); ASSERT_EQ(ret, FAILED); diff --git a/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc b/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc index 7037a85699..6cbdefd123 100644 --- a/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc @@ -38,13 +38,13 @@ class TestL2NormalizeInfo : public UT::Common { }; void TestL2NormalizeInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 34; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(32); stage_map.push_back(2); @@ -64,18 +64,18 @@ void TestL2NormalizeInfo::SetUp() { } TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) { - std::vector inputs = {{4, 1, 8}}; + Strategys inputs = {{4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); norm->Init(strategy); - std::vector dev_matrix_shape = norm->dev_matrix_shape(); + Shape dev_matrix_shape = norm->dev_matrix_shape(); - std::vector expect = {4, 1, 8}; + Shape expect = {4, 1, 8}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestL2NormalizeInfo, InferSliceShape1) { - std::vector str = {{4, 1, 8}}; + Strategys str = {{4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, str); norm->Init(strategy); @@ -96,7 +96,7 @@ TEST_F(TestL2NormalizeInfo, InferSliceShape1) { } TEST_F(TestL2NormalizeInfo, GetTensorLayout1) { - std::vector str = {{4, 1, 8}}; + Strategys str = {{4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, str); norm->Init(strategy); @@ -117,7 +117,7 @@ TEST_F(TestL2NormalizeInfo, GetTensorLayout1) { } TEST_F(TestL2NormalizeInfo, GetForwardOp1) { - std::vector inputs = {{4, 1, 8}}; + Strategys inputs = {{4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); norm->Init(strategy); @@ -128,7 +128,7 @@ TEST_F(TestL2NormalizeInfo, GetForwardOp1) { } TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) { - std::vector inputs = {{4, 1, 8}}; + Strategys inputs = {{4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); norm->Init(strategy); @@ -140,7 +140,7 @@ TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) { } TEST_F(TestL2NormalizeInfo, CheckStrategy1) { - std::vector inputs = {{4, 1, 8}, {4, 1, 8}}; + Strategys inputs = {{4, 1, 8}, {4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = norm->Init(strategy); @@ -148,7 +148,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy1) { } TEST_F(TestL2NormalizeInfo, CheckStrategy2) { - std::vector inputs = {{4, 2, 3}}; + Strategys inputs = {{4, 2, 3}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = norm->Init(strategy); @@ -156,7 +156,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy2) { } TEST_F(TestL2NormalizeInfo, CheckStrategy3) { - std::vector inputs = {{4, 2, 3, 4}}; + Strategys inputs = {{4, 2, 3, 4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = norm->Init(strategy); @@ -164,7 +164,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy3) { } TEST_F(TestL2NormalizeInfo, CheckStrategy4) { - std::vector inputs = {{4, 1, 8}}; + Strategys inputs = {{4, 1, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = norm->Init(strategy); @@ -172,7 +172,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy4) { } TEST_F(TestL2NormalizeInfo, mirror_ops) { - std::vector inputs = {{2, 1, 8}}; + Strategys inputs = {{2, 1, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); norm->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc index 8de5c07226..5803a4c325 100644 --- a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc @@ -38,13 +38,13 @@ class TestLogSoftmaxInfo : public UT::Common { }; void TestLogSoftmaxInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 130; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(128); stage_map.push_back(2); @@ -64,18 +64,18 @@ void TestLogSoftmaxInfo::SetUp() { } TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); log_softmax->Init(strategy); - std::vector dev_matrix_shape = log_softmax->dev_matrix_shape(); + Shape dev_matrix_shape = log_softmax->dev_matrix_shape(); - std::vector expect = {2, 4, 1, 16}; + Shape expect = {2, 4, 1, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestLogSoftmaxInfo, InferSliceShape1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); log_softmax->Init(strategy); @@ -96,7 +96,7 @@ TEST_F(TestLogSoftmaxInfo, InferSliceShape1) { } TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); log_softmax->Init(strategy); @@ -117,7 +117,7 @@ TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) { } TEST_F(TestLogSoftmaxInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); log_softmax->Init(strategy); @@ -128,7 +128,7 @@ TEST_F(TestLogSoftmaxInfo, GetForwardOp1) { } TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); log_softmax->Init(strategy); @@ -141,7 +141,7 @@ TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) { TEST_F(TestLogSoftmaxInfo, CheckStrategy1) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = log_softmax->Init(strategy); @@ -150,7 +150,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy1) { TEST_F(TestLogSoftmaxInfo, CheckStrategy2) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = log_softmax->Init(strategy); @@ -159,7 +159,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy2) { TEST_F(TestLogSoftmaxInfo, CheckStrategy3) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = log_softmax->Init(strategy); @@ -167,7 +167,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy3) { } TEST_F(TestLogSoftmaxInfo, GetDeviceList1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); log_softmax->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc index 2d5676f211..d5fc6f2a2e 100644 --- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc @@ -42,13 +42,13 @@ class TestMatmulInfo : public UT::Common { }; void TestMatmulInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -94,77 +94,77 @@ void TestMatmulInfo::SetUp() { } TEST_F(TestMatmulInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); - std::vector dev_matrix_shape = matmul1->dev_matrix_shape(); + Shape dev_matrix_shape = matmul1->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 16, 1}; + Shape expect = {2, 4, 8, 16, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestMatmulInfo, InferDevMatrixShape2) { - std::vector inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}}; + Strategys inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); - std::vector dev_matrix_shape = matmul1->dev_matrix_shape(); + Shape dev_matrix_shape = matmul1->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 8, 2}; + Shape expect = {2, 4, 8, 8, 2}; ASSERT_EQ(dev_matrix_shape, expect); } // matmul2 TEST_F(TestMatmulInfo, InferDevMatrixShape3) { - std::vector inputs = {{2, 4, 8, 16}, {1, 16}}; + Strategys inputs = {{2, 4, 8, 16}, {1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul2->Init(strategy); - std::vector dev_matrix_shape = matmul2->dev_matrix_shape(); + Shape dev_matrix_shape = matmul2->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 16, 1}; + Shape expect = {2, 4, 8, 16, 1}; ASSERT_EQ(dev_matrix_shape, expect); } // matmul2 TEST_F(TestMatmulInfo, InferDevMatrixShape4) { - std::vector inputs = {{2, 4, 8, 8}, {2, 8}}; + Strategys inputs = {{2, 4, 8, 8}, {2, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul2->Init(strategy); - std::vector dev_matrix_shape = matmul2->dev_matrix_shape(); + Shape dev_matrix_shape = matmul2->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 8, 2}; + Shape expect = {2, 4, 8, 8, 2}; ASSERT_EQ(dev_matrix_shape, expect); } // matmul3 TEST_F(TestMatmulInfo, InferDevMatrixShape5) { - std::vector inputs = {{8, 16}, {2, 4, 1, 16}}; + Strategys inputs = {{8, 16}, {2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul3->Init(strategy); - std::vector dev_matrix_shape = matmul3->dev_matrix_shape(); + Shape dev_matrix_shape = matmul3->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 16, 1}; + Shape expect = {2, 4, 8, 16, 1}; ASSERT_EQ(dev_matrix_shape, expect); } // matmul3 TEST_F(TestMatmulInfo, InferDevMatrixShape6) { - std::vector inputs = {{8, 8}, {2, 4, 2, 8}}; + Strategys inputs = {{8, 8}, {2, 4, 2, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul3->Init(strategy); - std::vector dev_matrix_shape = matmul3->dev_matrix_shape(); + Shape dev_matrix_shape = matmul3->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 8, 2}; + Shape expect = {2, 4, 8, 8, 2}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestMatmulInfo, InferTensorMap1) { - std::vector str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, str); matmul1->Init(strategy); @@ -190,7 +190,7 @@ TEST_F(TestMatmulInfo, InferTensorMap1) { // matmul2 TEST_F(TestMatmulInfo, InferTensorMap2) { - std::vector str = {{2, 4, 8, 16}, {1, 16}}; + Strategys str = {{2, 4, 8, 16}, {1, 16}}; StrategyPtr strategy = NewStrategy(0, str); matmul2->Init(strategy); @@ -216,7 +216,7 @@ TEST_F(TestMatmulInfo, InferTensorMap2) { // matmul3 TEST_F(TestMatmulInfo, InferTensorMap3) { - std::vector str = {{8, 16}, {2, 4, 1, 16}}; + Strategys str = {{8, 16}, {2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); matmul3->Init(strategy); @@ -241,7 +241,7 @@ TEST_F(TestMatmulInfo, InferTensorMap3) { } TEST_F(TestMatmulInfo, InferSliceShape1) { - std::vector str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, str); matmul1->Init(strategy); @@ -267,7 +267,7 @@ TEST_F(TestMatmulInfo, InferSliceShape1) { // matmul2 TEST_F(TestMatmulInfo, InferSliceShape2) { - std::vector str = {{2, 4, 8, 16}, {1, 16}}; + Strategys str = {{2, 4, 8, 16}, {1, 16}}; StrategyPtr strategy = NewStrategy(0, str); matmul2->Init(strategy); @@ -293,7 +293,7 @@ TEST_F(TestMatmulInfo, InferSliceShape2) { // matmul3 TEST_F(TestMatmulInfo, InferSliceShape3) { - std::vector str = {{8, 16}, {2, 4, 1, 16}}; + Strategys str = {{8, 16}, {2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); matmul3->Init(strategy); @@ -319,7 +319,7 @@ TEST_F(TestMatmulInfo, InferSliceShape3) { // matmul3 TEST_F(TestMatmulInfo, GetTensorLayout3) { - std::vector str = {{8, 16}, {2, 4, 1, 16}}; + Strategys str = {{8, 16}, {2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); matmul3->Init(strategy); @@ -344,7 +344,7 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) { } TEST_F(TestMatmulInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); @@ -370,7 +370,7 @@ TEST_F(TestMatmulInfo, GetForwardOp1) { } TEST_F(TestMatmulInfo, GetForwardOp2) { - std::vector inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); @@ -380,7 +380,7 @@ TEST_F(TestMatmulInfo, GetForwardOp2) { } TEST_F(TestMatmulInfo, GetVirtualDivOp1) { - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); @@ -399,7 +399,7 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) { } TEST_F(TestMatmulInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); @@ -419,7 +419,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) { // matmul2 TEST_F(TestMatmulInfo, GetMirrorOPs2) { - std::vector inputs = {{2, 4, 1, 16}, {8, 16}}; + Strategys inputs = {{2, 4, 1, 16}, {8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul2->Init(strategy); @@ -439,7 +439,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) { // matmul3 TEST_F(TestMatmulInfo, GetMirrorOPs3) { - std::vector inputs = {{8, 16}, {2, 4, 1, 16}}; + Strategys inputs = {{8, 16}, {2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul3->Init(strategy); @@ -457,7 +457,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) { } TEST_F(TestMatmulInfo, GetMirrorOPs4) { - std::vector inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}}; + Strategys inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); matmul1->Init(strategy); @@ -467,7 +467,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) { } TEST_F(TestMatmulInfo, InitTwice) { - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); // init twice @@ -489,7 +489,7 @@ TEST_F(TestMatmulInfo, InitTwice) { TEST_F(TestMatmulInfo, CheckStrategy1) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -498,7 +498,7 @@ TEST_F(TestMatmulInfo, CheckStrategy1) { TEST_F(TestMatmulInfo, CheckStrategy2) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{2, 4, 8, 16}, {4, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -507,7 +507,7 @@ TEST_F(TestMatmulInfo, CheckStrategy2) { TEST_F(TestMatmulInfo, CheckStrategy3) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -516,7 +516,7 @@ TEST_F(TestMatmulInfo, CheckStrategy3) { TEST_F(TestMatmulInfo, CheckStrategy4) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -525,7 +525,7 @@ TEST_F(TestMatmulInfo, CheckStrategy4) { TEST_F(TestMatmulInfo, CheckStrategy5) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -534,7 +534,7 @@ TEST_F(TestMatmulInfo, CheckStrategy5) { TEST_F(TestMatmulInfo, CheckStrategy6) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -543,7 +543,7 @@ TEST_F(TestMatmulInfo, CheckStrategy6) { TEST_F(TestMatmulInfo, CheckStrategy7) { // Success: {{2,4,8,16}, {2,4,16,1}} - std::vector inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul1->Init(strategy); @@ -552,7 +552,7 @@ TEST_F(TestMatmulInfo, CheckStrategy7) { TEST_F(TestMatmulInfo, InitFailed) { // matmul4 attr is wrong - std::vector inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = matmul4->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc index 074e4582f0..6efac9598b 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc @@ -38,13 +38,13 @@ class TestOneHotInfo : public UT::Common { }; void TestOneHotInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 10; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); stage_map.push_back(2); @@ -64,43 +64,43 @@ void TestOneHotInfo::SetUp() { } TEST_F(TestOneHotInfo, InferDevMatrixShape1) { - std::vector inputs = {{8, 1}, {}, {}}; + Strategys inputs = {{8, 1}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info->Init(strategy); ASSERT_EQ(status, SUCCESS); - std::vector dev_matrix_shape = onehot_info->dev_matrix_shape(); + Shape dev_matrix_shape = onehot_info->dev_matrix_shape(); - std::vector expect = {8, 1}; + Shape expect = {8, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestOneHotInfo, InferDevMatrixShape2) { - std::vector inputs = {{4, 1}, {}, {}}; + Strategys inputs = {{4, 1}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info->Init(strategy); ASSERT_EQ(status, SUCCESS); - std::vector dev_matrix_shape = onehot_info->dev_matrix_shape(); + Shape dev_matrix_shape = onehot_info->dev_matrix_shape(); - std::vector expect = {2, 4, 1}; + Shape expect = {2, 4, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestOneHotInfo, InferDevMatrixShape3) { - std::vector inputs = {{4, 2}, {}, {}}; + Strategys inputs = {{4, 2}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info->Init(strategy); ASSERT_EQ(status, FAILED); - std::vector dev_matrix_shape = onehot_info->dev_matrix_shape(); + Shape dev_matrix_shape = onehot_info->dev_matrix_shape(); - std::vector expect = {4, 2}; + Shape expect = {4, 2}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestOneHotInfo, InferTensorMap2) { - std::vector str = {{8, 1}, {}, {}}; + Strategys str = {{8, 1}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info->Init(strategy); @@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo, InferTensorMap2) { } TEST_F(TestOneHotInfo, InferSliceShape1) { - std::vector str = {{8, 1}, {}, {}}; + Strategys str = {{8, 1}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info->Init(strategy); @@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo, InferSliceShape1) { } TEST_F(TestOneHotInfo, InferSliceShape2) { - std::vector str = {{4, 2}, {}, {}}; + Strategys str = {{4, 2}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info->Init(strategy); @@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) { } TEST_F(TestOneHotInfo, InferSliceShape3) { - std::vector str = {{2, 2}, {}, {}}; + Strategys str = {{2, 2}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info->Init(strategy); @@ -188,7 +188,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) { } TEST_F(TestOneHotInfo, GetMirrorOPs1) { - std::vector inputs = {{8, 1}, {}, {}}; + Strategys inputs = {{8, 1}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info->Init(strategy); @@ -199,7 +199,7 @@ TEST_F(TestOneHotInfo, GetMirrorOPs1) { } TEST_F(TestOneHotInfo, CheckStrategy1) { - std::vector inputs = {{16}, {}, {}}; + Strategys inputs = {{16}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = onehot_info->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc index 769d5bec45..239a7299cd 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc @@ -38,13 +38,13 @@ class TestOneHotInfo2 : public UT::Common { }; void TestOneHotInfo2::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 10; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(8); stage_map.push_back(2); @@ -64,43 +64,43 @@ void TestOneHotInfo2::SetUp() { } TEST_F(TestOneHotInfo2, InferDevMatrixShape1) { - std::vector inputs = {{1, 8}, {}, {}}; + Strategys inputs = {{1, 8}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info2->Init(strategy); ASSERT_EQ(status, SUCCESS); - std::vector dev_matrix_shape = onehot_info2->dev_matrix_shape(); + Shape dev_matrix_shape = onehot_info2->dev_matrix_shape(); - std::vector expect = {8, 1}; + Shape expect = {8, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestOneHotInfo2, InferDevMatrixShape2) { - std::vector inputs = {{1, 4}, {}, {}}; + Strategys inputs = {{1, 4}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info2->Init(strategy); ASSERT_EQ(status, SUCCESS); - std::vector dev_matrix_shape = onehot_info2->dev_matrix_shape(); + Shape dev_matrix_shape = onehot_info2->dev_matrix_shape(); - std::vector expect = {2, 4, 1}; + Shape expect = {2, 4, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestOneHotInfo2, InferDevMatrixShape3) { - std::vector inputs = {{2, 4}, {}, {}}; + Strategys inputs = {{2, 4}, {}, {}}; StrategyPtr strategy = NewStrategy(0, inputs); Status status = onehot_info2->Init(strategy); ASSERT_EQ(status, FAILED); - std::vector dev_matrix_shape = onehot_info2->dev_matrix_shape(); + Shape dev_matrix_shape = onehot_info2->dev_matrix_shape(); - std::vector expect = {4, 2}; + Shape expect = {4, 2}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestOneHotInfo2, InferTensorMap2) { - std::vector str = {{1, 8}, {}, {}}; + Strategys str = {{1, 8}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info2->Init(strategy); @@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo2, InferTensorMap2) { } TEST_F(TestOneHotInfo2, InferSliceShape1) { - std::vector str = {{1, 8}, {}, {}}; + Strategys str = {{1, 8}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info2->Init(strategy); @@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape1) { } TEST_F(TestOneHotInfo2, InferSliceShape2) { - std::vector str = {{2, 4}, {}, {}}; + Strategys str = {{2, 4}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info2->Init(strategy); @@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) { } TEST_F(TestOneHotInfo2, InferSliceShape3) { - std::vector str = {{2, 2}, {}, {}}; + Strategys str = {{2, 2}, {}, {}}; StrategyPtr strategy = NewStrategy(0, str); Status status = onehot_info2->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc index f582640db8..726f3e2307 100644 --- a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc @@ -38,13 +38,13 @@ class TestPowInfo : public UT::Common { }; void TestPowInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 66; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(64); stage_map.push_back(2); @@ -63,18 +63,18 @@ void TestPowInfo::SetUp() { } TEST_F(TestPowInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); pow->Init(strategy); - std::vector dev_matrix_shape = pow->dev_matrix_shape(); + Shape dev_matrix_shape = pow->dev_matrix_shape(); - std::vector expect = {2, 4, 8}; + Shape expect = {2, 4, 8}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestPowInfo, InferSliceShape1) { - std::vector str = {{2, 4, 8}, {2, 4, 8}}; + Strategys str = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, str); pow->Init(strategy); @@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) { } TEST_F(TestPowInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 8}, {2, 4, 8}}; + Strategys str = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, str); pow->Init(strategy); @@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) { } TEST_F(TestPowInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); pow->Init(strategy); @@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) { } TEST_F(TestPowInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); pow->Init(strategy); @@ -139,7 +139,7 @@ TEST_F(TestPowInfo, GetMirrorOPs1) { } TEST_F(TestPowInfo, CheckStrategy1) { - std::vector inputs = {{2, 2, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 2, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = pow->Init(strategy); @@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) { } TEST_F(TestPowInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = pow->Init(strategy); @@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) { } TEST_F(TestPowInfo, CheckStrategy3) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = pow->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/prelu_test.cc b/tests/ut/cpp/parallel/ops_info/prelu_test.cc index 1d4cf5eff0..b92392234e 100644 --- a/tests/ut/cpp/parallel/ops_info/prelu_test.cc +++ b/tests/ut/cpp/parallel/ops_info/prelu_test.cc @@ -39,13 +39,13 @@ class TestPReLUInfo : public UT::Common { }; void TestPReLUInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); int32_t local_dev = 0; @@ -64,18 +64,18 @@ void TestPReLUInfo::SetUp() { } TEST_F(TestPReLUInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 1, 8, 16}, {1}}; + Strategys inputs = {{2, 1, 8, 16}, {1}}; StrategyPtr strategy = NewStrategy(0, inputs); prelu->Init(strategy); - std::vector dev_matrix_shape = prelu->dev_matrix_shape(); + Shape dev_matrix_shape = prelu->dev_matrix_shape(); - std::vector expect = {4, 2, 1, 8, 16}; + Shape expect = {4, 2, 1, 8, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestPReLUInfo, InferSliceShape1) { - std::vector str = {{2, 1, 8, 16}, {1}}; + Strategys str = {{2, 1, 8, 16}, {1}}; StrategyPtr strategy = NewStrategy(0, str); prelu->Init(strategy); @@ -98,7 +98,7 @@ TEST_F(TestPReLUInfo, InferSliceShape1) { } TEST_F(TestPReLUInfo, GetTensorLayout1) { - std::vector str = {{2, 1, 8, 16}, {1}}; + Strategys str = {{2, 1, 8, 16}, {1}}; StrategyPtr strategy = NewStrategy(0, str); prelu->Init(strategy); @@ -122,7 +122,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) { } TEST_F(TestPReLUInfo, GetMirrorOPs1) { - std::vector str = {{2, 1, 2, 2}, {1}}; + Strategys str = {{2, 1, 2, 2}, {1}}; StrategyPtr strategy = NewStrategy(0, str); prelu->Init(strategy); MirrorOps mirror_ops = prelu->mirror_ops(); @@ -139,14 +139,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs1) { TEST_F(TestPReLUInfo, CheckStrategy1) { // Success: {{2,1,8,16},{1}} - std::vector inputs = {{2, 1, 8, 16}}; + Strategys inputs = {{2, 1, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = prelu->Init(strategy); ASSERT_EQ(ret, FAILED); } TEST_F(TestPReLUInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8, 16}, {4}}; + Strategys inputs = {{2, 4, 8, 16}, {4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = prelu->Init(strategy); ASSERT_EQ(ret, SUCCESS); @@ -169,18 +169,18 @@ TEST_F(TestPReLUInfo, AutoStrategy1) { } TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) { - std::vector inputs = {{128, 1}, {1}}; + Strategys inputs = {{128, 1}, {1}}; StrategyPtr strategy = NewStrategy(0, inputs); prelu_2d->Init(strategy); - std::vector dev_matrix_shape = prelu_2d->dev_matrix_shape(); + Shape dev_matrix_shape = prelu_2d->dev_matrix_shape(); - std::vector expect = {8, 128, 1}; + Shape expect = {8, 128, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestPReLUInfo, InferSliceShape_2d1) { - std::vector str = {{128, 1}, {1}}; + Strategys str = {{128, 1}, {1}}; StrategyPtr strategy = NewStrategy(0, str); prelu_2d->Init(strategy); @@ -203,7 +203,7 @@ TEST_F(TestPReLUInfo, InferSliceShape_2d1) { } TEST_F(TestPReLUInfo, GetTensorLayout_2d1) { - std::vector str = {{128, 1}, {1}}; + Strategys str = {{128, 1}, {1}}; StrategyPtr strategy = NewStrategy(0, str); prelu_2d->Init(strategy); @@ -227,7 +227,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) { } TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) { - std::vector str = {{128, 1}, {1}}; + Strategys str = {{128, 1}, {1}}; StrategyPtr strategy = NewStrategy(0, str); prelu_2d->Init(strategy); MirrorOps mirror_ops = prelu_2d->mirror_ops(); @@ -244,14 +244,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) { TEST_F(TestPReLUInfo, CheckStrategy_2d1) { // Success: {{2,1,8,16},{1}} - std::vector inputs = {{128, 1}}; + Strategys inputs = {{128, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = prelu_2d->Init(strategy); ASSERT_EQ(ret, FAILED); } TEST_F(TestPReLUInfo, CheckStrategy_2d2) { - std::vector inputs = {{128, 4}, {4}}; + Strategys inputs = {{128, 4}, {4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = prelu_2d->Init(strategy); ASSERT_EQ(ret, SUCCESS); diff --git a/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc b/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc index 64ba6af70b..69d830db0f 100644 --- a/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc @@ -39,13 +39,13 @@ class TestReduceSumInfo : public UT::Common { void TestReduceSumInfo::SetUp() { UT::InitPythonPath(); - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 34; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(32); stage_map.push_back(2); @@ -68,18 +68,18 @@ void TestReduceSumInfo::SetUp() { } TEST_F(TestReduceSumInfo, InferDevMatrixShape1) { - std::vector inputs = {{4, 8, 1}}; + Strategys inputs = {{4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reduce_sum->Init(strategy); - std::vector dev_matrix_shape = reduce_sum->dev_matrix_shape(); + Shape dev_matrix_shape = reduce_sum->dev_matrix_shape(); - std::vector expect = {4, 8, 1}; + Shape expect = {4, 8, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestReduceSumInfo, InferSliceShape1) { - std::vector str = {{4, 8, 1}}; + Strategys str = {{4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, str); reduce_sum->Init(strategy); @@ -100,7 +100,7 @@ TEST_F(TestReduceSumInfo, InferSliceShape1) { } TEST_F(TestReduceSumInfo, GetTensorLayout1) { - std::vector str = {{4, 8, 1}}; + Strategys str = {{4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, str); reduce_sum->Init(strategy); @@ -121,7 +121,7 @@ TEST_F(TestReduceSumInfo, GetTensorLayout1) { } TEST_F(TestReduceSumInfo, GetForwardOp1) { - std::vector inputs = {{4, 8, 1}}; + Strategys inputs = {{4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reduce_sum->Init(strategy); @@ -132,7 +132,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp1) { } TEST_F(TestReduceSumInfo, GetForwardOp2) { - std::vector inputs = {{4, 4, 2}}; + Strategys inputs = {{4, 4, 2}}; StrategyPtr strategy = NewStrategy(0, inputs); reduce_sum->Init(strategy); @@ -156,7 +156,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp2) { } TEST_F(TestReduceSumInfo, GetMirrorOPs1) { - std::vector inputs = {{4, 8, 1}}; + Strategys inputs = {{4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reduce_sum->Init(strategy); @@ -168,7 +168,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs1) { } TEST_F(TestReduceSumInfo, GetMirrorOPs2) { - std::vector inputs = {{4, 4, 1}}; + Strategys inputs = {{4, 4, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reduce_sum->Init(strategy); @@ -187,7 +187,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs2) { } TEST_F(TestReduceSumInfo, CheckStrategy1) { - std::vector inputs = {{2, 2, 8, 16}}; + Strategys inputs = {{2, 2, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reduce_sum->Init(strategy); @@ -195,7 +195,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy1) { } TEST_F(TestReduceSumInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reduce_sum->Init(strategy); @@ -203,7 +203,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy2) { } TEST_F(TestReduceSumInfo, CheckStrategy3) { - std::vector inputs = {{4, 4, 2}}; + Strategys inputs = {{4, 4, 2}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reduce_sum->Init(strategy); @@ -211,7 +211,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy3) { } TEST_F(TestReduceSumInfo, CheckStrategy4) { - std::vector inputs = {{4, 8, 1}}; + Strategys inputs = {{4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reduce_sum->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/reshape_test.cc b/tests/ut/cpp/parallel/ops_info/reshape_test.cc index 8cc8390e9a..71c793cf56 100644 --- a/tests/ut/cpp/parallel/ops_info/reshape_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reshape_test.cc @@ -38,13 +38,13 @@ class TestReshapeInfo : public UT::Common { }; void TestReshapeInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 34; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(32); stage_map.push_back(2); @@ -68,29 +68,29 @@ void TestReshapeInfo::SetUp() { } TEST_F(TestReshapeInfo, InferDevMatrixShape1) { - std::vector inputs = {{4, 1, 1, 1}}; + Strategys inputs = {{4, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reshape->Init(strategy); - std::vector dev_matrix_shape = reshape->dev_matrix_shape(); + Shape dev_matrix_shape = reshape->dev_matrix_shape(); - std::vector expect = {8, 4}; + Shape expect = {8, 4}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestReshapeInfo, InferDevMatrixShape2) { - std::vector inputs = {{32, 1, 1, 1}}; + Strategys inputs = {{32, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reshape->Init(strategy); - std::vector dev_matrix_shape = reshape->dev_matrix_shape(); + Shape dev_matrix_shape = reshape->dev_matrix_shape(); - std::vector expect = {32}; + Shape expect = {32}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestReshapeInfo, InferSliceShape1) { - std::vector str = {{4, 1, 1, 1}}; + Strategys str = {{4, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, str); reshape->Init(strategy); @@ -111,7 +111,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) { } TEST_F(TestReshapeInfo, InferSliceShape2) { - std::vector str = {{32, 1, 1, 1}}; + Strategys str = {{32, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, str); reshape->Init(strategy); @@ -132,7 +132,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) { } TEST_F(TestReshapeInfo, GetTensorLayout1) { - std::vector str = {{4, 1, 1, 1}}; + Strategys str = {{4, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, str); reshape->Init(strategy); @@ -153,7 +153,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) { } TEST_F(TestReshapeInfo, GetTensorLayout2) { - std::vector str = {{32, 1, 1, 1}}; + Strategys str = {{32, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, str); reshape->Init(strategy); @@ -174,7 +174,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) { } TEST_F(TestReshapeInfo, GetForwardOp1) { - std::vector inputs = {{4, 1, 1, 1}}; + Strategys inputs = {{4, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reshape->Init(strategy); @@ -185,7 +185,7 @@ TEST_F(TestReshapeInfo, GetForwardOp1) { } TEST_F(TestReshapeInfo, GetMirrorOPs1) { - std::vector inputs = {{4, 1, 1, 1}}; + Strategys inputs = {{4, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); reshape->Init(strategy); @@ -197,7 +197,7 @@ TEST_F(TestReshapeInfo, GetMirrorOPs1) { } TEST_F(TestReshapeInfo, CheckStrategy1) { - std::vector inputs = {{1, 4, 8}}; + Strategys inputs = {{1, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reshape->Init(strategy); @@ -205,7 +205,7 @@ TEST_F(TestReshapeInfo, CheckStrategy1) { } TEST_F(TestReshapeInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reshape->Init(strategy); @@ -213,7 +213,7 @@ TEST_F(TestReshapeInfo, CheckStrategy2) { } TEST_F(TestReshapeInfo, CheckStrategy3) { - std::vector inputs = {{4, 1, 1, 1}}; + Strategys inputs = {{4, 1, 1, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = reshape->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc b/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc index d370c168c9..0368ae5c1d 100644 --- a/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc @@ -38,13 +38,13 @@ class TestSoftmaxLoss : public UT::Common { }; void TestSoftmaxLoss::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 65; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(64); stage_map.push_back(1); @@ -64,18 +64,18 @@ void TestSoftmaxLoss::SetUp() { } TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; + Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); loss->Init(strategy); - std::vector dev_matrix_shape = loss->dev_matrix_shape(); + Shape dev_matrix_shape = loss->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 1}; + Shape expect = {2, 4, 8, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestSoftmaxLoss, InferSliceShape1) { - std::vector str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; + Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, str); loss->Init(strategy); @@ -104,7 +104,7 @@ TEST_F(TestSoftmaxLoss, InferSliceShape1) { } TEST_F(TestSoftmaxLoss, GetTensorLayout1) { - std::vector str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; + Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, str); loss->Init(strategy); @@ -133,7 +133,7 @@ TEST_F(TestSoftmaxLoss, GetTensorLayout1) { } TEST_F(TestSoftmaxLoss, GetForwardOp1) { - std::vector inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; + Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); loss->Init(strategy); @@ -144,7 +144,7 @@ TEST_F(TestSoftmaxLoss, GetForwardOp1) { } TEST_F(TestSoftmaxLoss, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; + Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); loss->Init(strategy); @@ -156,7 +156,7 @@ TEST_F(TestSoftmaxLoss, GetMirrorOPs1) { } TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) { - std::vector inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}}; + Strategys inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); loss->Init(strategy); @@ -176,7 +176,7 @@ TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) { TEST_F(TestSoftmaxLoss, CheckStrategy1) { // Success: {{2,4,8,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = loss->Init(strategy); @@ -185,7 +185,7 @@ TEST_F(TestSoftmaxLoss, CheckStrategy1) { TEST_F(TestSoftmaxLoss, CheckStrategy2) { // Success: {{2,4,8,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = loss->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc index 9c4205672b..7be44ef2b9 100644 --- a/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc @@ -39,13 +39,13 @@ class TestSoftmaxInfo : public UT::Common { }; void TestSoftmaxInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 130; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(128); stage_map.push_back(2); @@ -68,18 +68,18 @@ void TestSoftmaxInfo::SetUp() { } TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); softmax->Init(strategy); - std::vector dev_matrix_shape = softmax->dev_matrix_shape(); + Shape dev_matrix_shape = softmax->dev_matrix_shape(); - std::vector expect = {2, 4, 1, 16}; + Shape expect = {2, 4, 1, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestSoftmaxInfo, InferSliceShape1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); softmax->Init(strategy); @@ -100,7 +100,7 @@ TEST_F(TestSoftmaxInfo, InferSliceShape1) { } TEST_F(TestSoftmaxInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); softmax->Init(strategy); @@ -121,7 +121,7 @@ TEST_F(TestSoftmaxInfo, GetTensorLayout1) { } TEST_F(TestSoftmaxInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); softmax->Init(strategy); @@ -132,7 +132,7 @@ TEST_F(TestSoftmaxInfo, GetForwardOp1) { } TEST_F(TestSoftmaxInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); softmax->Init(strategy); @@ -145,7 +145,7 @@ TEST_F(TestSoftmaxInfo, GetMirrorOPs1) { TEST_F(TestSoftmaxInfo, CheckStrategy1) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = softmax->Init(strategy); @@ -154,7 +154,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy1) { TEST_F(TestSoftmaxInfo, CheckStrategy2) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = softmax->Init(strategy); @@ -163,7 +163,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy2) { TEST_F(TestSoftmaxInfo, CheckStrategy3) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = softmax->Init(strategy); @@ -172,7 +172,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy3) { TEST_F(TestSoftmaxInfo, InitFailed1) { // softmax2's axis is wrong - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = softmax2->Init(strategy); @@ -181,7 +181,7 @@ TEST_F(TestSoftmaxInfo, InitFailed1) { TEST_F(TestSoftmaxInfo, InitFailed2) { // dev num is wrong - std::vector inputs = {{2, 4, 1, 100}}; + Strategys inputs = {{2, 4, 1, 100}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = softmax2->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc b/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc index 2be6c5bf7f..6dadb1c3a1 100644 --- a/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc @@ -38,13 +38,13 @@ class TestTanhInfo : public UT::Common { }; void TestTanhInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 130; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(128); stage_map.push_back(2); @@ -63,18 +63,18 @@ void TestTanhInfo::SetUp() { } TEST_F(TestTanhInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); tanh->Init(strategy); - std::vector dev_matrix_shape = tanh->dev_matrix_shape(); + Shape dev_matrix_shape = tanh->dev_matrix_shape(); - std::vector expect = {2, 4, 1, 16}; + Shape expect = {2, 4, 1, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestTanhInfo, InferSliceShape1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); tanh->Init(strategy); @@ -95,7 +95,7 @@ TEST_F(TestTanhInfo, InferSliceShape1) { } TEST_F(TestTanhInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 1, 16}}; + Strategys str = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, str); tanh->Init(strategy); @@ -116,7 +116,7 @@ TEST_F(TestTanhInfo, GetTensorLayout1) { } TEST_F(TestTanhInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); tanh->Init(strategy); @@ -127,7 +127,7 @@ TEST_F(TestTanhInfo, GetForwardOp1) { } TEST_F(TestTanhInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); tanh->Init(strategy); @@ -140,7 +140,7 @@ TEST_F(TestTanhInfo, GetMirrorOPs1) { TEST_F(TestTanhInfo, CheckStrategy1) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tanh->Init(strategy); @@ -149,7 +149,7 @@ TEST_F(TestTanhInfo, CheckStrategy1) { TEST_F(TestTanhInfo, CheckStrategy2) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tanh->Init(strategy); @@ -158,7 +158,7 @@ TEST_F(TestTanhInfo, CheckStrategy2) { TEST_F(TestTanhInfo, CheckStrategy3) { // Success: {{2,4,1,16}} - std::vector inputs = {{2, 4, 1, 16}}; + Strategys inputs = {{2, 4, 1, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tanh->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc index b523652fcb..731b5caf28 100644 --- a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc @@ -38,13 +38,13 @@ class TestTensorAddInfo : public UT::Common { }; void TestTensorAddInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 34; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(32); stage_map.push_back(2); @@ -66,18 +66,18 @@ void TestTensorAddInfo::SetUp() { } TEST_F(TestTensorAddInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 4}, {2, 4, 4}}; + Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; StrategyPtr strategy = NewStrategy(0, inputs); tensor_add->Init(strategy); - std::vector dev_matrix_shape = tensor_add->dev_matrix_shape(); + Shape dev_matrix_shape = tensor_add->dev_matrix_shape(); - std::vector expect = {2, 4, 4}; + Shape expect = {2, 4, 4}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestTensorAddInfo, InferSliceShape1) { - std::vector str = {{2, 4, 4}, {2, 4, 4}}; + Strategys str = {{2, 4, 4}, {2, 4, 4}}; StrategyPtr strategy = NewStrategy(0, str); tensor_add->Init(strategy); @@ -101,7 +101,7 @@ TEST_F(TestTensorAddInfo, InferSliceShape1) { } TEST_F(TestTensorAddInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 4}, {2, 4, 4}}; + Strategys str = {{2, 4, 4}, {2, 4, 4}}; StrategyPtr strategy = NewStrategy(0, str); tensor_add->Init(strategy); @@ -125,7 +125,7 @@ TEST_F(TestTensorAddInfo, GetTensorLayout1) { } TEST_F(TestTensorAddInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 4}, {2, 4, 4}}; + Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; StrategyPtr strategy = NewStrategy(0, inputs); tensor_add->Init(strategy); @@ -136,7 +136,7 @@ TEST_F(TestTensorAddInfo, GetForwardOp1) { } TEST_F(TestTensorAddInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 4}, {2, 4, 4}}; + Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; StrategyPtr strategy = NewStrategy(0, inputs); tensor_add->Init(strategy); @@ -148,7 +148,7 @@ TEST_F(TestTensorAddInfo, GetMirrorOPs1) { } TEST_F(TestTensorAddInfo, CheckStrategy1) { - std::vector inputs = {{2, 4, 4}, {2, 6, 4}}; + Strategys inputs = {{2, 4, 4}, {2, 6, 4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tensor_add->Init(strategy); @@ -156,7 +156,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy1) { } TEST_F(TestTensorAddInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tensor_add->Init(strategy); @@ -164,7 +164,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy2) { } TEST_F(TestTensorAddInfo, CheckStrategy3) { - std::vector inputs = {{2, 4, 6}}; + Strategys inputs = {{2, 4, 6}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tensor_add->Init(strategy); @@ -172,7 +172,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy3) { } TEST_F(TestTensorAddInfo, CheckStrategy4) { - std::vector inputs = {{2, 4, 4}, {2, 4, 4}}; + Strategys inputs = {{2, 4, 4}, {2, 4, 4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = tensor_add->Init(strategy); @@ -224,7 +224,7 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) { } TEST_F(TestTensorAddInfo, mirror_ops) { - std::vector inputs = {{1, 8}, {4, 1}}; + Strategys inputs = {{1, 8}, {4, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); tensor_add1->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc index 461a27d4ed..16967754c8 100644 --- a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc @@ -19,6 +19,7 @@ #include "frontend/parallel/device_manager.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/ops_info/tmp_identity_info.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { @@ -26,7 +27,6 @@ namespace parallel { class TmpIdentityInfo; using TmpIdentityInfoPtr = std::shared_ptr; TmpIdentityInfoPtr identity_ptr; -using TensorMap = std::vector; class TestTmpIdentityInfo : public UT::Common { public: @@ -38,13 +38,13 @@ class TestTmpIdentityInfo : public UT::Common { }; void TestTmpIdentityInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -65,18 +65,18 @@ void TestTmpIdentityInfo::SetUp() { } TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 8, 16}}; + Strategys inputs = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); identity_ptr->Init(strategy); - std::vector dev_matrix_shape = identity_ptr->dev_matrix_shape(); + Shape dev_matrix_shape = identity_ptr->dev_matrix_shape(); - std::vector expect = {2, 4, 8, 16}; + Shape expect = {2, 4, 8, 16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestTmpIdentityInfo, InferSliceShape1) { - std::vector str = {{2, 4, 8, 16}}; + Strategys str = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, str); identity_ptr->Init(strategy); @@ -97,7 +97,7 @@ TEST_F(TestTmpIdentityInfo, InferSliceShape1) { } TEST_F(TestTmpIdentityInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 8, 16}}; + Strategys str = {{2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, str); identity_ptr->Init(strategy); @@ -119,7 +119,7 @@ TEST_F(TestTmpIdentityInfo, GetTensorLayout1) { TEST_F(TestTmpIdentityInfo, CheckStrategy1) { // Success: {{2,4,8,16}} - std::vector inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; + Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = identity_ptr->Init(strategy); @@ -128,7 +128,7 @@ TEST_F(TestTmpIdentityInfo, CheckStrategy1) { TEST_F(TestTmpIdentityInfo, CheckStrategy2) { // Success: {{2,4,8,16}} - std::vector inputs = {{2, 4, 8}}; + Strategys inputs = {{2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = identity_ptr->Init(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/transpose_test.cc b/tests/ut/cpp/parallel/ops_info/transpose_test.cc index fe5cbb01b3..149e49e854 100644 --- a/tests/ut/cpp/parallel/ops_info/transpose_test.cc +++ b/tests/ut/cpp/parallel/ops_info/transpose_test.cc @@ -38,13 +38,13 @@ class TestTransposeInfo : public UT::Common { }; void TestTransposeInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 34; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(32); stage_map.push_back(2); @@ -68,29 +68,29 @@ void TestTransposeInfo::SetUp() { } TEST_F(TestTransposeInfo, InferDevMatrixShape1) { - std::vector inputs = {{4, 8}}; + Strategys inputs = {{4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); transpose->Init(strategy); - std::vector dev_matrix_shape = transpose->dev_matrix_shape(); + Shape dev_matrix_shape = transpose->dev_matrix_shape(); - std::vector expect = {4, 8}; + Shape expect = {4, 8}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestTransposeInfo, InferDevMatrixShape2) { - std::vector inputs = {{4, 1}}; + Strategys inputs = {{4, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); transpose->Init(strategy); - std::vector dev_matrix_shape = transpose->dev_matrix_shape(); + Shape dev_matrix_shape = transpose->dev_matrix_shape(); - std::vector expect = {8, 4, 1}; + Shape expect = {8, 4, 1}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestTransposeInfo, InferSliceShape1) { - std::vector str = {{4, 8}}; + Strategys str = {{4, 8}}; StrategyPtr strategy = NewStrategy(0, str); transpose->Init(strategy); @@ -111,7 +111,7 @@ TEST_F(TestTransposeInfo, InferSliceShape1) { } TEST_F(TestTransposeInfo, GetTensorLayout1) { - std::vector str = {{4, 8}}; + Strategys str = {{4, 8}}; StrategyPtr strategy = NewStrategy(0, str); transpose->Init(strategy); @@ -132,7 +132,7 @@ TEST_F(TestTransposeInfo, GetTensorLayout1) { } TEST_F(TestTransposeInfo, GetForwardOp1) { - std::vector inputs = {{4, 8}}; + Strategys inputs = {{4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); transpose->Init(strategy); @@ -143,7 +143,7 @@ TEST_F(TestTransposeInfo, GetForwardOp1) { } TEST_F(TestTransposeInfo, GetMirrorOPs1) { - std::vector inputs = {{4, 8}}; + Strategys inputs = {{4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); transpose->Init(strategy); @@ -155,7 +155,7 @@ TEST_F(TestTransposeInfo, GetMirrorOPs1) { } TEST_F(TestTransposeInfo, CheckStrategy1) { - std::vector inputs = {{1, 4, 8}}; + Strategys inputs = {{1, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = transpose->Init(strategy); @@ -163,7 +163,7 @@ TEST_F(TestTransposeInfo, CheckStrategy1) { } TEST_F(TestTransposeInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; + Strategys inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = transpose->Init(strategy); @@ -171,7 +171,7 @@ TEST_F(TestTransposeInfo, CheckStrategy2) { } TEST_F(TestTransposeInfo, CheckStrategy3) { - std::vector inputs = {{4, 8}}; + Strategys inputs = {{4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = transpose->Init(strategy); diff --git a/tests/ut/cpp/parallel/step_auto_parallel_test.cc b/tests/ut/cpp/parallel/step_auto_parallel_test.cc index 6cf7ec66c6..1a93981acc 100644 --- a/tests/ut/cpp/parallel/step_auto_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_auto_parallel_test.cc @@ -32,13 +32,13 @@ class TestStepAutoParallel : public UT::Common { }; void TestStepAutoParallel::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 20; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(16); stage_map.push_back(4); @@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { StrategyPtr strategyPtr; std::shared_ptr matmul_info = NewOperatorInstance(prim, attrs, shape); - node->set_operator_info(matmul_info); + node->set_user_data(matmul_info); std::string name_expect = "MatMulInfo00"; std::string name_test = matmul_info->name(); ASSERT_EQ(name_expect, name_test); diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index 5657db8790..80b8f6be0c 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -34,13 +34,13 @@ class TestStepParallel : public UT::Common { void TestStepParallel::SetUp() { UT::InitPythonPath(); } void Init_Device_Manager() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 20; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(16); stage_map.push_back(4); @@ -52,17 +52,26 @@ void Init_Device_Manager() { } CNodePtr Make_Node(Shape x, Shape y, Shape out, int condition = 0) { + std::vector x_shape; + std::vector y_shape; + std::vector out_shape; FuncGraphPtr func_graph = std::make_shared(); ParameterPtr param1 = func_graph->add_parameter(); ParameterPtr param2 = func_graph->add_parameter(); + (void)std::transform(x.begin(), x.end(), std::back_inserter(x_shape), + [](const int64_t &value) { return static_cast(value); }); + (void)std::transform(y.begin(), y.end(), std::back_inserter(y_shape), + [](const int64_t &value) { return static_cast(value); }); + (void)std::transform(out.begin(), out.end(), std::back_inserter(out_shape), + [](const int64_t &value) { return static_cast(value); }); param1->set_name("x"); param2->set_name("y"); BaseShapePtr shape1 = std::make_shared(x); BaseShapePtr shape2 = std::make_shared(y); BaseShapePtr shape3 = std::make_shared(out); - std::shared_ptr inputs_x = std::make_shared(kNumberTypeInt32, x); - std::shared_ptr inputs_y = std::make_shared(kNumberTypeInt32, y); - std::shared_ptr inputs_out = std::make_shared(kNumberTypeInt32, out); + std::shared_ptr inputs_x = std::make_shared(kNumberTypeInt32, x_shape); + std::shared_ptr inputs_y = std::make_shared(kNumberTypeInt32, y_shape); + std::shared_ptr inputs_out = std::make_shared(kNumberTypeInt32, out_shape); AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true); AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true); AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true); @@ -112,11 +121,11 @@ CNodePtr Make_Node(Shape x, Shape y, Shape out, int condition = 0) { } FuncGraphManagerPtr Make_Manager(int condition = 0) { - Shape inputs_x = {64, 32}; - Shape inputs_y = {32, 64}; - Shape inputs_z = {64, 128}; - Shape outputs_1 = {64, 64}; - Shape outputs_2 = {64, 128}; + std::vector inputs_x = {64, 32}; + std::vector inputs_y = {32, 64}; + std::vector inputs_z = {64, 128}; + std::vector outputs_1 = {64, 64}; + std::vector outputs_2 = {64, 128}; FuncGraphPtr func_graph = std::make_shared(); ParameterPtr param1 = func_graph->add_parameter(); ParameterPtr param2 = func_graph->add_parameter(); @@ -134,8 +143,8 @@ FuncGraphManagerPtr Make_Manager(int condition = 0) { param1->set_abstract(abstract_x); param2->set_abstract(abstract_y); param3->set_abstract(abstract_z); - std::vector v1 = {2, 2}; - std::vector v2 = {2, 4}; + Dimensions v1 = {2, 2}; + Dimensions v2 = {2, 4}; std::vector elements = {MakeValue(v1), MakeValue(v2)}; ValueTuplePtr var = std::make_shared(elements); std::vector inputs; @@ -153,8 +162,8 @@ FuncGraphManagerPtr Make_Manager(int condition = 0) { prim1->AddAttr("instance_name", MakeValue("matmul1")); prim1->AddAttr("strategy", var); inputs.clear(); - std::vector v3 = {2, 2}; - std::vector v4 = {2, 4}; + Dimensions v3 = {2, 2}; + Dimensions v4 = {2, 4}; std::vector elements2 = {MakeValue(v3), MakeValue(v4)}; ValueTuplePtr var2 = std::make_shared(elements2); inputs.push_back(NewValueNode(prim::kPrimMatMul)); @@ -186,8 +195,8 @@ FuncGraphManagerPtr Make_Manager(int condition = 0) { break; } case 3: { - std::vector vt1 = {2, 4}; - std::vector vt2 = {2, 4}; + Dimensions vt1 = {2, 4}; + Dimensions vt2 = {2, 4}; std::vector elements_t2 = {MakeValue(vt1), MakeValue(vt2)}; ValueTuplePtr var_t2 = std::make_shared(elements_t2); prim1->set_attr("strategy", var_t2); @@ -224,9 +233,9 @@ TEST_F(TestStepParallel, ExtractStrategy) { std::vector elements = {val1, val2}; ValueTuplePtr strategy_tuple = std::make_shared(elements); attrs["strategy"] = strategy_tuple; - std::vector strategy_expect = {v1, v2}; + Strategys strategy_expect = {v1, v2}; StrategyPtr strategy = ExtractStrategy(attrs); - std::vector strategy_test = strategy->GetInputDim(); + Strategys strategy_test = strategy->GetInputDim(); ASSERT_EQ(strategy_expect, strategy_test); } @@ -294,10 +303,6 @@ TEST_F(TestStepParallel, CreatOpInstance) { ASSERT_TRUE(allreduce_ptr); if (nullptr != allreduce_ptr) { MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name(); - auto func = allreduce_ptr->GetComputeFunction(); - if (py::isinstance(func)) { - MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented"; - } std::vector arglist; (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist), @@ -357,7 +362,7 @@ TEST_F(TestStepParallel, OperatorInstance) { prim->set_attr("transpose_b", transpose_b); auto attrs = prim->attrs(); // creat strategy - std::vector strategy = {{2, 2}, {2, 4}}; + Strategys strategy = {{2, 2}, {2, 4}}; StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy); // creat shape Shapes inputs_shape = std::vector{{64, 32}, {32, 64}}; @@ -518,7 +523,7 @@ TEST_F(TestStepParallel, GetTensorInLayout) { prim->set_attr("transpose_b", transpose_b); auto attrs = prim->attrs(); // creat strategy - std::vector strategy = {{2, 2}, {2, 4}}; + Strategys strategy = {{2, 2}, {2, 4}}; StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy); // creat shape Shapes inputs_shape = std::vector{{64, 32}, {32, 64}}; @@ -526,12 +531,12 @@ TEST_F(TestStepParallel, GetTensorInLayout) { std::vector shape = {inputs_shape, outputs_shape}; OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); matmul_info->Init(strategyPtr); - node->set_operator_info(matmul_info); - OperatorInfoPtr distribute_operator_pre = node->operator_info(); + node->set_user_data(matmul_info); + OperatorInfoPtr distribute_operator_pre = node->user_data(); TensorLayout tensorlayout_e; - std::vector array = {64, 64}; + Shape array = {64, 64}; TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); - std::vector tensor_shape_test = tensorlayout.tensor_shape().array(); + Shape tensor_shape_test = tensorlayout.tensor_shape().array(); ASSERT_EQ(array, tensor_shape_test); } diff --git a/tests/ut/cpp/parallel/strategy_test.cc b/tests/ut/cpp/parallel/strategy_test.cc index c13b71944e..58011b84b6 100644 --- a/tests/ut/cpp/parallel/strategy_test.cc +++ b/tests/ut/cpp/parallel/strategy_test.cc @@ -33,9 +33,9 @@ class TestStrategy : public UT::Common { TEST_F(TestStrategy, GetInputNumber) { int32_t number = 2; int32_t stage = 1; - std::vector dimension1 = {2, 4}; - std::vector dimension2 = {2, 2}; - std::vector> inputs = {dimension1, dimension2}; + Dimensions dimension1 = {2, 4}; + Dimensions dimension2 = {2, 2}; + Strategys inputs = {dimension1, dimension2}; Strategy strategy(stage, inputs); int32_t number_test = strategy.GetInputNumber(); @@ -44,9 +44,9 @@ TEST_F(TestStrategy, GetInputNumber) { TEST_F(TestStrategy, GetInputStage) { int32_t stage = 1; - std::vector dimension1 = {2, 4}; - std::vector dimension2 = {2, 2}; - std::vector> inputs = {dimension1, dimension2}; + Dimensions dimension1 = {2, 4}; + Dimensions dimension2 = {2, 2}; + Strategys inputs = {dimension1, dimension2}; Strategy strategy(stage, inputs); int32_t stage_test = strategy.GetInputStage(); @@ -55,23 +55,23 @@ TEST_F(TestStrategy, GetInputStage) { TEST_F(TestStrategy, GetInputDim) { int32_t stage = 1; - std::vector dimension1 = {2, 4}; - std::vector dimension2 = {2, 2}; - std::vector> inputs = {dimension1, dimension2}; + Dimensions dimension1 = {2, 4}; + Dimensions dimension2 = {2, 2}; + Strategys inputs = {dimension1, dimension2}; Strategy strategy(stage, inputs); - std::vector> inputs_test = strategy.GetInputDim(); + Strategys inputs_test = strategy.GetInputDim(); ASSERT_EQ(inputs, inputs_test); } TEST_F(TestStrategy, IsEqual) { int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0; - std::vector dimension1 = {8, 1}; - std::vector dimension2 = {1, 8}; - std::vector> inputs1 = {dimension1}; - std::vector> inputs2 = {dimension1}; - std::vector> inputs3 = {dimension2}; - std::vector> inputs4 = {dimension1, dimension2}; + Dimensions dimension1 = {8, 1}; + Dimensions dimension2 = {1, 8}; + Strategys inputs1 = {dimension1}; + Strategys inputs2 = {dimension1}; + Strategys inputs3 = {dimension2}; + Strategys inputs4 = {dimension1, dimension2}; StrategyPtr stra1 = std::make_shared(stage1, inputs1); StrategyPtr stra2 = std::make_shared(stage2, inputs2); diff --git a/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc b/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc index b80f199035..61df1a2461 100644 --- a/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc @@ -39,12 +39,12 @@ class TestConstructOperator : public UT::Common { }; void TestConstructOperator::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); @@ -62,7 +62,7 @@ void TestConstructOperator::SetUp() { MatMulInfoPtr matmul = std::make_shared("matmul_info", inputs_shape_1, outputs_shape_1, attr_1); - std::vector str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; + Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; StrategyPtr strategy = NewStrategy(0, str); matmul->Init(strategy); Shape tensor_shape = {512, 1024}; @@ -79,8 +79,8 @@ TEST_F(TestConstructOperator, TestReshapeOP) { TEST_F(TestConstructOperator, TestStridedSliceOP) { Args args = {1, 2, 3}; - int32_t split_count = args[0]; - int32_t split_dim = args[1]; + int64_t split_count = args[0]; + int64_t split_dim = args[1]; Shape device_arrangement = {8, 4}; Arrangement dev_mat; dev_mat.Init(device_arrangement); @@ -98,12 +98,18 @@ TEST_F(TestConstructOperator, TestStridedSliceOP) { OperatorParams params = op.second.second; ValuePtr begin_ptr = params[0].first.second; ValuePtr end_ptr = params[1].first.second; - Shape begin = GetValue>(begin_ptr); - Shape end = GetValue>(end_ptr); + std::vector begin_int = GetValue>(begin_ptr); + std::vector end_int = GetValue>(end_ptr); + Shape begin; + Shape end; + (void)std::transform(begin_int.begin(), begin_int.end(), std::back_inserter(begin), + [](const int32_t &value) { return static_cast(value); }); + (void)std::transform(end_int.begin(), end_int.end(), std::back_inserter(end), + [](const int32_t &value) { return static_cast(value); }); for (size_t i = 0; i < begin.size(); i++) { - int32_t diff = end[i] - begin[i]; - int32_t num = shape[i]; - if (SizeToInt(i) != split_dim) { + int64_t diff = end[i] - begin[i]; + int64_t num = shape[i]; + if (SizeToLong(i) != split_dim) { ASSERT_EQ(diff, shape[i]); } else { ASSERT_EQ(diff, num / split_count); diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc index 4ddc130a45..b1ecca82b4 100644 --- a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc @@ -20,14 +20,11 @@ #include "frontend/parallel/tensor_layout/tensor_layout.h" #include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" #include "util_layout_gen_test.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { -using DeviceArrangement = std::vector; -using TensorMap = std::vector; -using TensorShape = std::vector; - class TestRedistributionLayoutTransfer : public UT::Common { public: TestRedistributionLayoutTransfer() {} @@ -245,13 +242,13 @@ void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangeme unified_out_tensor_map, unified_tensor_shape); } -void ValidRedistributionLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow_size, - int32_t max_device_dim, int32_t max_shape_dim) { +void ValidRedistributionLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, + int64_t max_device_dim, int64_t max_shape_dim) { std::vector> layout_list; GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim, &layout_list); - for (uint32_t in = 0; in < layout_list.size(); in++) { - for (uint32_t out = 0; out < layout_list.size(); out++) { + for (size_t in = 0; in < layout_list.size(); in++) { + for (size_t out = 0; out < layout_list.size(); out++) { DeviceArrangement in_device_arrangement = std::get<0>(layout_list[in]); TensorMap in_tensor_map = std::get<1>(layout_list[in]); TensorShape in_tensor_shape = std::get<2>(layout_list[in]); @@ -273,15 +270,15 @@ void ValidRedistributionLayoutCheckAll(int32_t device_pow_size, int32_t tensor_p } TEST_F(TestRedistributionLayoutTransfer, RedistributionLayoutTransferCheckAll) { - int32_t device_pow_size_max = 4; - int32_t tensor_pow_size_max = 4; - int32_t device_pow_size_min = 1; - int32_t tensor_pow_size_min = 1; - const int32_t max_device_dim = 5; - const int32_t max_shape_dim = 5; - int32_t device_pow_size = device_pow_size_min; + int64_t device_pow_size_max = 4; + int64_t tensor_pow_size_max = 4; + int64_t device_pow_size_min = 1; + int64_t tensor_pow_size_min = 1; + const int64_t max_device_dim = 5; + const int64_t max_shape_dim = 5; + int64_t device_pow_size = device_pow_size_min; while (device_pow_size <= device_pow_size_max) { - int32_t tensor_pow_size = tensor_pow_size_min; + int64_t tensor_pow_size = tensor_pow_size_min; while (tensor_pow_size <= tensor_pow_size_max) { ValidRedistributionLayoutCheckAll(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim); tensor_pow_size++; diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc index f6caad2f9d..f002735c1f 100644 --- a/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc @@ -28,13 +28,13 @@ class TestRedistributionOperatorInfer : public UT::Common { TestRedistributionOperatorInfer() {} void SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 1050; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(1024); stage_map.push_back(26); diff --git a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc index 11f471ea33..a2faab3706 100644 --- a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc @@ -21,14 +21,11 @@ #include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" #include "util_layout_gen_test.h" #include "utils/log_adapter.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { -using DeviceArrangement = std::vector; -using TensorMap = std::vector; -using TensorShape = std::vector; - class TestReshapeLayoutTransfer : public UT::Common { public: TestReshapeLayoutTransfer() {} @@ -260,13 +257,13 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheck11) { ValidUnifiedLayoutCheck(device_arrangement, in_tensor_map, in_tensor_shape, out_tensor_map, out_tensor_shape); } -void ValidInferUnifiedLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow_size, - int32_t max_device_dim, int32_t max_shape_dim) { +void ValidInferUnifiedLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, + int64_t max_device_dim, int64_t max_shape_dim) { std::vector> layout_list; GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim, &layout_list); - for (uint32_t in = 0; in < layout_list.size(); in++) { - for (uint32_t out = 0; out < layout_list.size(); out++) { + for (size_t in = 0; in < layout_list.size(); in++) { + for (size_t out = 0; out < layout_list.size(); out++) { DeviceArrangement in_device_arrangement = std::get<0>(layout_list[in]); TensorMap in_tensor_map = std::get<1>(layout_list[in]); TensorShape in_tensor_shape = std::get<2>(layout_list[in]); @@ -287,15 +284,15 @@ void ValidInferUnifiedLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow } TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheckAll) { - int32_t device_pow_size_max = 4; - int32_t tensor_pow_size_max = 4; - int32_t device_pow_size_min = 1; - int32_t tensor_pow_size_min = 1; - const int32_t max_device_dim = 5; - const int32_t max_shape_dim = 5; - int32_t device_pow_size = device_pow_size_min; + int64_t device_pow_size_max = 4; + int64_t tensor_pow_size_max = 4; + int64_t device_pow_size_min = 1; + int64_t tensor_pow_size_min = 1; + const int64_t max_device_dim = 5; + const int64_t max_shape_dim = 5; + int64_t device_pow_size = device_pow_size_min; while (device_pow_size <= device_pow_size_max) { - int32_t tensor_pow_size = tensor_pow_size_min; + int64_t tensor_pow_size = tensor_pow_size_min; while (tensor_pow_size <= tensor_pow_size_max) { ValidInferUnifiedLayoutCheckAll(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim); tensor_pow_size++; @@ -305,15 +302,15 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheckAll) { } TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheckAll2) { - int32_t device_pow_size_max = 1; - int32_t tensor_pow_size_max = 2; - int32_t device_pow_size_min = 1; - int32_t tensor_pow_size_min = 2; - const int32_t max_device_dim = 5; - const int32_t max_shape_dim = 5; - int32_t device_pow_size = device_pow_size_min; + int64_t device_pow_size_max = 1; + int64_t tensor_pow_size_max = 2; + int64_t device_pow_size_min = 1; + int64_t tensor_pow_size_min = 2; + const int64_t max_device_dim = 5; + const int64_t max_shape_dim = 5; + int64_t device_pow_size = device_pow_size_min; while (device_pow_size <= device_pow_size_max) { - int32_t tensor_pow_size = tensor_pow_size_min; + int64_t tensor_pow_size = tensor_pow_size_min; while (tensor_pow_size <= tensor_pow_size_max) { ValidInferUnifiedLayoutCheckAll(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim); tensor_pow_size++; diff --git a/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc b/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc index 824ab876cd..10cc712a8a 100644 --- a/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc @@ -26,7 +26,7 @@ namespace parallel { * shape_accum = [2, 2 * 8, 2 * 8 * 32] */ TEST(ShapeUtilTest, ShapeToAccumulateProduct) { - std::vector shape = {2, 8, 32}; + Shape shape = {2, 8, 32}; std::vector shape_accum; Status status = ShapeToAccumulateProduct(shape, &shape_accum); ASSERT_EQ(Status::SUCCESS, status); @@ -39,7 +39,7 @@ TEST(ShapeUtilTest, ShapeToAccumulateProduct) { * shape_accum = [2 * 8 * 32, 8 * 32, 32] */ TEST(ShapeUtilTest, ShapeToAccumulateProductReverse) { - std::vector shape = {2, 8, 32}; + Shape shape = {2, 8, 32}; std::vector shape_accum; Status status = ShapeToAccumulateProductReverse(shape, &shape_accum); ASSERT_EQ(Status::SUCCESS, status); @@ -53,10 +53,10 @@ TEST(ShapeUtilTest, ShapeToAccumulateProductReverse) { */ TEST(ShapeUtilTest, AccumulateProductToShape) { std::vector shape_accum = {2, 2 * 8, 2 * 8 * 32}; - std::vector shape; + Shape shape; Status status = AccumulateProductToShape(shape_accum, &shape); ASSERT_EQ(Status::SUCCESS, status); - std::vector shape_expect = {2, 8, 32}; + Shape shape_expect = {2, 8, 32}; ASSERT_EQ(shape_expect, shape); } @@ -66,10 +66,10 @@ TEST(ShapeUtilTest, AccumulateProductToShape) { */ TEST(ShapeUtilTest, AccumulateProductReverseToShape) { std::vector shape_accum = {2 * 8 * 32, 8 * 32, 32}; - std::vector shape; + Shape shape; Status status = AccumulateProductReverseToShape(shape_accum, &shape); ASSERT_EQ(Status::SUCCESS, status); - std::vector shape_expect = {2, 8, 32}; + Shape shape_expect = {2, 8, 32}; ASSERT_EQ(shape_expect, shape); } @@ -94,12 +94,12 @@ TEST(ShapeUtilTest, UnifyAccumulateProduct) { * out = [2, 2, 2] */ TEST(ShapeUtilTest, UnifyShape1) { - std::vector in1 = {2, 4}; - std::vector in2 = {4, 2}; - std::vector out; + Shape in1 = {2, 4}; + Shape in2 = {4, 2}; + Shape out; Status status = UnifyShape(in1, in2, &out); ASSERT_EQ(Status::SUCCESS, status); - std::vector out_expect = {2, 2, 2}; + Shape out_expect = {2, 2, 2}; ASSERT_EQ(out_expect, out); } @@ -109,12 +109,12 @@ TEST(ShapeUtilTest, UnifyShape1) { * out = [2, 4, 4] */ TEST(ShapeUtilTest, UnifyShape2) { - std::vector in1 = {8, 4}; - std::vector in2 = {2, 16}; - std::vector out; + Shape in1 = {8, 4}; + Shape in2 = {2, 16}; + Shape out; Status status = UnifyShape(in1, in2, &out); ASSERT_EQ(Status::SUCCESS, status); - std::vector out_expect = {2, 4, 4}; + Shape out_expect = {2, 4, 4}; ASSERT_EQ(out_expect, out); } @@ -184,12 +184,12 @@ TEST(ShapeUtilTest, ExpandAccumulateProduct4) { * out = [2, 8, 4, 8] */ TEST(ShapeUtilTest, ExpandShape1) { - std::vector in = {2, 8, 32}; - std::vector expand = {16, 4, 8}; - std::vector out; + Shape in = {2, 8, 32}; + Shape expand = {16, 4, 8}; + Shape out; Status status = ExpandShape(in, expand, &out); ASSERT_EQ(Status::SUCCESS, status); - std::vector out_expect = {2, 8, 4, 8}; + Shape out_expect = {2, 8, 4, 8}; ASSERT_EQ(out_expect, out); } @@ -199,12 +199,12 @@ TEST(ShapeUtilTest, ExpandShape1) { * out = [2, 8, 4, 8] */ TEST(ShapeUtilTest, ExpandShape2) { - std::vector in = {2, 8, 32}; - std::vector expand = {2, 4, 8}; - std::vector out; + Shape in = {2, 8, 32}; + Shape expand = {2, 4, 8}; + Shape out; Status status = ExpandShape(in, expand, &out); ASSERT_EQ(Status::SUCCESS, status); - std::vector out_expect = {2, 4, 2, 4, 8}; + Shape out_expect = {2, 4, 2, 4, 8}; ASSERT_EQ(out_expect, out); } diff --git a/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc b/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc index 15fb16f088..7ab896794b 100644 --- a/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc @@ -18,6 +18,7 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { @@ -31,12 +32,12 @@ class TestTensorLayout : public UT::Common { virtual void TearDown() {} }; -void ReshapeExpandDeviceArrangementTestFunction(const std::vector& in_device_arrangement_shape, - const std::vector& in_tensor_map_shape, - const std::vector& in_tensor_shape_shape, - const std::vector& out_device_arrangement_shape, - const std::vector& out_tensor_map_shape, - const std::vector& out_tensor_shape_shape) { +void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape, + const TensorMap& in_tensor_map_shape, + const TensorShape& in_tensor_shape_shape, + const DeviceArrangement& out_device_arrangement_shape, + const TensorMap& out_tensor_map_shape, + const TensorShape& out_tensor_shape_shape) { Arrangement device_arrangement; Status status = device_arrangement.Init(in_device_arrangement_shape); ASSERT_EQ(Status::SUCCESS, status); @@ -70,12 +71,12 @@ void ReshapeExpandDeviceArrangementTestFunction(const std::vector& in_d * */ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement1) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {1, 0}; - std::vector tensor_shape = {512, 1024}; - std::vector device_arrangement_new = {4, 2, 2, 2}; - std::vector tensor_map_expect = {3, 2, 1, 0}; - std::vector tensor_shape_expect = {4, 128, 2, 512}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {1, 0}; + TensorShape tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement_new = {4, 2, 2, 2}; + TensorMap tensor_map_expect = {3, 2, 1, 0}; + TensorShape tensor_shape_expect = {4, 128, 2, 512}; ReshapeExpandDeviceArrangementTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_new, tensor_map_expect, tensor_shape_expect); } @@ -91,12 +92,12 @@ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement1) { * out_tensor_shape = [2, 256, 4, 256] */ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement2) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {0, 1}; - std::vector tensor_shape = {512, 1024}; - std::vector device_arrangement_new = {4, 2, 2, 2}; - std::vector tensor_map_expect = {1, 0, 3, 2}; - std::vector tensor_shape_expect = {2, 256, 4, 256}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {0, 1}; + TensorShape tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement_new = {4, 2, 2, 2}; + TensorMap tensor_map_expect = {1, 0, 3, 2}; + TensorShape tensor_shape_expect = {2, 256, 4, 256}; ReshapeExpandDeviceArrangementTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_new, tensor_map_expect, tensor_shape_expect); } @@ -111,12 +112,12 @@ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement2) { * out_tensor_shape = [4, 128, 1024] */ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement3) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {1, -1}; - std::vector tensor_shape = {512, 1024}; - std::vector device_arrangement_new = {4, 2, 2, 2}; - std::vector tensor_map_expect = {3, 2, -1}; - std::vector tensor_shape_expect = {4, 128, 1024}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {1, -1}; + TensorShape tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement_new = {4, 2, 2, 2}; + TensorMap tensor_map_expect = {3, 2, -1}; + TensorShape tensor_shape_expect = {4, 128, 1024}; ReshapeExpandDeviceArrangementTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_new, tensor_map_expect, tensor_shape_expect); } @@ -132,33 +133,33 @@ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement3) { * out_tensor_shape = [512, 4, 256] */ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement4) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {0, 1}; - std::vector tensor_shape = {512, 1024}; - std::vector device_arrangement_new = {4, 2, 4}; - std::vector tensor_map_expect = {0, 2, 1}; - std::vector tensor_shape_expect = {512, 4, 256}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {0, 1}; + TensorShape tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement_new = {4, 2, 4}; + TensorMap tensor_map_expect = {0, 2, 1}; + TensorShape tensor_shape_expect = {512, 4, 256}; ReshapeExpandDeviceArrangementTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_new, tensor_map_expect, tensor_shape_expect); } TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement5) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {1, -1, 0}; - std::vector tensor_shape = {128, 4, 1024}; - std::vector device_arrangement_new = {8, 4}; - std::vector tensor_map_expect = {1, -1, 0}; - std::vector tensor_shape_expect = {128, 4, 1024}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {1, -1, 0}; + TensorShape tensor_shape = {128, 4, 1024}; + DeviceArrangement device_arrangement_new = {8, 4}; + TensorMap tensor_map_expect = {1, -1, 0}; + TensorShape tensor_shape_expect = {128, 4, 1024}; ReshapeExpandDeviceArrangementTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_new, tensor_map_expect, tensor_shape_expect); } -void ExpandTensorShapeTestFunction(const std::vector& in_device_arrangement_shape, - const std::vector& in_tensor_map_shape, - const std::vector& in_tensor_shape_shape, - const std::vector& out_device_arrangement_shape, - const std::vector& out_tensor_map_shape, - const std::vector& out_tensor_shape_shape) { +void ExpandTensorShapeTestFunction(const DeviceArrangement& in_device_arrangement_shape, + const TensorMap& in_tensor_map_shape, + const TensorShape& in_tensor_shape_shape, + const DeviceArrangement& out_device_arrangement_shape, + const TensorMap& out_tensor_map_shape, + const TensorShape& out_tensor_shape_shape) { Arrangement device_arrangement; Status status = device_arrangement.Init(in_device_arrangement_shape); ASSERT_EQ(Status::SUCCESS, status); @@ -193,31 +194,31 @@ void ExpandTensorShapeTestFunction(const std::vector& in_device_arrange * out_tensor_map = [2, 1, 0], */ TEST_F(TestTensorLayout, ExpandTensorShape1) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {1, 0}; - std::vector tensor_shape = {512, 1024}; - std::vector device_arrangement_expect = {4, 2, 4}; - std::vector tensor_map_expect = {2, 1, 0}; - std::vector tensor_shape_new = {4, 128, 1024}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {1, 0}; + TensorShape tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement_expect = {4, 2, 4}; + TensorMap tensor_map_expect = {2, 1, 0}; + TensorShape tensor_shape_new = {4, 128, 1024}; ExpandTensorShapeTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); } TEST_F(TestTensorLayout, ExpandTensorShape2) { - std::vector device_arrangement = {8, 4}; - std::vector tensor_map = {1, 0}; - std::vector tensor_shape = {128, 4096}; - std::vector device_arrangement_expect = {8, 4}; - std::vector tensor_map_expect = {1, 0, -1}; - std::vector tensor_shape_new = {128, 4, 1024}; + DeviceArrangement device_arrangement = {8, 4}; + TensorMap tensor_map = {1, 0}; + TensorShape tensor_shape = {128, 4096}; + DeviceArrangement device_arrangement_expect = {8, 4}; + TensorMap tensor_map_expect = {1, 0, -1}; + TensorShape tensor_shape_new = {128, 4, 1024}; ExpandTensorShapeTestFunction(device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); } TEST_F(TestTensorLayout, GetSliceShape) { - std::vector in_device_arrangement = {8, 4}; - std::vector in_tensor_map = {1, -1}; - std::vector in_tensor_shape = {512, 1024}; + DeviceArrangement in_device_arrangement = {8, 4}; + TensorMap in_tensor_map = {1, -1}; + TensorShape in_tensor_shape = {512, 1024}; Arrangement device_arrangement; device_arrangement.Init(in_device_arrangement); Map tensor_map; @@ -233,9 +234,9 @@ TEST_F(TestTensorLayout, GetSliceShape) { } TEST_F(TestTensorLayout, UpdateTensorMap) { - std::vector in_device_arrangement = {8, 4}; - std::vector in_tensor_map = {1, -1}; - std::vector in_tensor_shape = {512, 1024}; + DeviceArrangement in_device_arrangement = {8, 4}; + TensorMap in_tensor_map = {1, -1}; + TensorShape in_tensor_shape = {512, 1024}; Arrangement device_arrangement; device_arrangement.Init(in_device_arrangement); Map tensor_map; @@ -250,12 +251,12 @@ TEST_F(TestTensorLayout, UpdateTensorMap) { ASSERT_EQ(in_tensor_map, new_tensor_map); } -void RemoveElementEqualToOneInDeviceArrangementTestFunction(const std::vector& in_device_arrangement_shape, - const std::vector& in_tensor_map_shape, - const std::vector& in_tensor_shape_shape, - const std::vector& out_device_arrangement_shape, - const std::vector& out_tensor_map_shape, - const std::vector& out_tensor_shape_shape) { +void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape, + const TensorMap& in_tensor_map_shape, + const TensorShape& in_tensor_shape_shape, + const DeviceArrangement& out_device_arrangement_shape, + const TensorMap& out_tensor_map_shape, + const TensorShape& out_tensor_shape_shape) { Arrangement device_arrangement; Status status = device_arrangement.Init(in_device_arrangement_shape); ASSERT_EQ(Status::SUCCESS, status); @@ -277,45 +278,45 @@ void RemoveElementEqualToOneInDeviceArrangementTestFunction(const std::vector device_arrangement = {2, 2, 1}; - std::vector tensor_map = {2, 1}; - std::vector tensor_shape = {128, 4096}; - std::vector device_arrangement_expect = {2, 2}; - std::vector tensor_map_expect = {1, 0}; - std::vector tensor_shape_new = {128, 4096}; + DeviceArrangement device_arrangement = {2, 2, 1}; + TensorMap tensor_map = {2, 1}; + TensorShape tensor_shape = {128, 4096}; + DeviceArrangement device_arrangement_expect = {2, 2}; + TensorMap tensor_map_expect = {1, 0}; + TensorShape tensor_shape_new = {128, 4096}; RemoveElementEqualToOneInDeviceArrangementTestFunction( device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); } TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement2) { - std::vector device_arrangement = {16, 1, 1}; - std::vector tensor_map = {2, 0}; - std::vector tensor_shape = {128, 4096}; - std::vector device_arrangement_expect = {16}; - std::vector tensor_map_expect = {0, -1}; - std::vector tensor_shape_new = {128, 4096}; + DeviceArrangement device_arrangement = {16, 1, 1}; + TensorMap tensor_map = {2, 0}; + TensorShape tensor_shape = {128, 4096}; + DeviceArrangement device_arrangement_expect = {16}; + TensorMap tensor_map_expect = {0, -1}; + TensorShape tensor_shape_new = {128, 4096}; RemoveElementEqualToOneInDeviceArrangementTestFunction( device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); } TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement3) { - std::vector device_arrangement = {1, 16, 1}; - std::vector tensor_map = {2, 1}; - std::vector tensor_shape = {128, 4096}; - std::vector device_arrangement_expect = {16}; - std::vector tensor_map_expect = {-1, 0}; - std::vector tensor_shape_new = {128, 4096}; + DeviceArrangement device_arrangement = {1, 16, 1}; + TensorMap tensor_map = {2, 1}; + TensorShape tensor_shape = {128, 4096}; + DeviceArrangement device_arrangement_expect = {16}; + TensorMap tensor_map_expect = {-1, 0}; + TensorShape tensor_shape_new = {128, 4096}; RemoveElementEqualToOneInDeviceArrangementTestFunction( device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); } TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement4) { - std::vector device_arrangement = {1, 1, 1}; - std::vector tensor_map = {2, 1}; - std::vector tensor_shape = {128, 4096}; - std::vector device_arrangement_expect = {}; - std::vector tensor_map_expect = {-1, -1}; - std::vector tensor_shape_new = {128, 4096}; + DeviceArrangement device_arrangement = {1, 1, 1}; + TensorMap tensor_map = {2, 1}; + TensorShape tensor_shape = {128, 4096}; + DeviceArrangement device_arrangement_expect = {}; + TensorMap tensor_map_expect = {-1, -1}; + TensorShape tensor_shape_new = {128, 4096}; RemoveElementEqualToOneInDeviceArrangementTestFunction( device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); } diff --git a/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc b/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc index 40a4017c4b..97e2240a9b 100644 --- a/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc @@ -18,6 +18,7 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { @@ -33,7 +34,7 @@ class TestTensorRedistribution : public UT::Common { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(16); stage_map.push_back(4); @@ -49,9 +50,9 @@ class TestTensorRedistribution : public UT::Common { // Redistribution: Reshape -> SplitByAxis -> ConcatByAxis -> SplitByAxis -> Reshape TEST_F(TestTensorRedistribution, TestInferRedistribution1) { - std::vector device_arrangement = {2, 4, 2}; - std::vector tensor_map = {2, 0}; - std::vector tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement = {2, 4, 2}; + TensorMap tensor_map = {2, 0}; + TensorShape tensor_shape = {512, 1024}; Arrangement in_device_arrangement; Status status = in_device_arrangement.Init(device_arrangement); @@ -102,9 +103,9 @@ TEST_F(TestTensorRedistribution, TestInferRedistribution1) { // Redistribution: AlltoAll TEST_F(TestTensorRedistribution, TestInferRedistribution2) { - std::vector device_arrangement = {16, 1, 1}; - std::vector tensor_map = {2, 0}; - std::vector tensor_shape = {512, 1024}; + DeviceArrangement device_arrangement = {16, 1, 1}; + TensorMap tensor_map = {2, 0}; + TensorShape tensor_shape = {512, 1024}; Arrangement in_device_arrangement; Status status = in_device_arrangement.Init(device_arrangement); @@ -154,9 +155,9 @@ TEST_F(TestTensorRedistribution, TestInferRedistribution2) { // Redistribution: Reshape TEST_F(TestTensorRedistribution, TestInferRedistribution3) { - std::vector device_arrangement = {8}; - std::vector tensor_map = {0, -1, -1, -1}; - std::vector tensor_shape = {128, 64, 1, 1}; + DeviceArrangement device_arrangement = {8}; + TensorMap tensor_map = {0, -1, -1, -1}; + TensorShape tensor_shape = {128, 64, 1, 1}; Arrangement in_device_arrangement; Status status = in_device_arrangement.Init(device_arrangement); diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc index 330b571ae7..85ccd673d8 100644 --- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc @@ -28,21 +28,21 @@ using std::pow; namespace mindspore { namespace parallel { -std::vector> combine(const std::vector& in, int32_t target) { - std::vector> output; - for (int32_t i = 0; i < pow(2, in.size()); i++) { - int32_t temp = 0; - int32_t count = 0; - std::vector left; - for (int32_t j = 0; j < in.size(); j++) { +std::vector combine(const Shape& in, int64_t target) { + std::vector output; + for (int64_t i = 0; i < pow(2, in.size()); i++) { + size_t temp = 0; + size_t count = 0; + Shape left; + for (size_t j = 0; j < in.size(); j++) { if ((i & (1 << j)) != 0) { left.push_back(j); count++; } } if (count == target) { - std::vector one_case; - for (int32_t j = 0; j < count; j++) { + Shape one_case; + for (size_t j = 0; j < count; j++) { temp = in.size() - 1 - left[j]; one_case.push_back(in[temp]); } @@ -54,24 +54,23 @@ std::vector> combine(const std::vector& in, int32_ return output; } -void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim, - std::vector>* out) { +void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector* out) { out->clear(); - std::vector in; - for (int32_t i = 1; i < pow_size; i++) { + Shape in; + for (int64_t i = 1; i < pow_size; i++) { in.push_back(i); } - std::vector> combine_result; + std::vector combine_result; combine_result = combine(in, dim - 1); if (combine_result.size() == 0) { - int32_t size = exp2(pow_size); - std::vector item = {size}; + int64_t size = exp2(pow_size); + Shape item = {size}; out->push_back(item); } - for (uint32_t i = 0; i < combine_result.size(); i++) { - std::vector item; - int32_t prev = 0; - for (int32_t j = combine_result[i].size() - 1; j >= 0; j--) { + for (size_t i = 0; i < combine_result.size(); i++) { + Shape item; + int64_t prev = 0; + for (int64_t j = combine_result[i].size() - 1; j >= 0; j--) { item.push_back(exp2(combine_result[i][j] - prev)); prev = combine_result[i][j]; } @@ -81,22 +80,21 @@ void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim, return; } -void GenerateValidShapeBySize(int32_t pow_size, std::vector>* out) { +void GenerateValidShapeBySize(int64_t pow_size, std::vector* out) { out->clear(); - for (int32_t dim = 1; dim <= pow_size; dim++) { - std::vector> combine_result; + for (int64_t dim = 1; dim <= pow_size; dim++) { + std::vector combine_result; GenerateValidShapeBySizeAndDim(pow_size, dim, &combine_result); - for (uint32_t i = 0; i < combine_result.size(); i++) { + for (size_t i = 0; i < combine_result.size(); i++) { out->push_back(combine_result[i]); } } return; } -std::vector GenerateTensorMap(const uint32_t& map_size, const std::vector& pos_index, - const std::vector& pos_value) { - std::vector tensor_map(map_size, -1); - for (uint32_t i = 0; i < pos_index.size() && i < pos_value.size(); i++) { +TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, const Shape& pos_value) { + TensorMap tensor_map(map_size, -1); + for (size_t i = 0; i < pos_index.size() && i < pos_value.size(); i++) { if (pos_index[i] >= map_size) { continue; } @@ -105,43 +103,43 @@ std::vector GenerateTensorMap(const uint32_t& map_size, const std::vect return tensor_map; } -void GenerateValidTensorMap(const std::vector& device_arrangement, const std::vector& tensor_shape, - std::vector>* tensor_map_list) { +void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const TensorShape& tensor_shape, + std::vector* tensor_map_list) { tensor_map_list->clear(); - int32_t device_size = device_arrangement.size(); - int32_t shape_size = tensor_shape.size(); - std::vector pos_ind_combine_in; - for (int32_t i = 0; i < shape_size; i++) { + int64_t device_size = device_arrangement.size(); + int64_t shape_size = tensor_shape.size(); + Shape pos_ind_combine_in; + for (int64_t i = 0; i < shape_size; i++) { pos_ind_combine_in.push_back(i); } - std::vector dev_ind_combine_in; - for (int32_t i = 0; i < device_size; i++) { + Shape dev_ind_combine_in; + for (int64_t i = 0; i < device_size; i++) { dev_ind_combine_in.push_back(i); } - std::vector none_map(tensor_shape.size(), -1); + TensorMap none_map(tensor_shape.size(), -1); tensor_map_list->push_back(none_map); - for (uint32_t pos_num = 1; (pos_num <= shape_size) && (pos_num <= device_size); pos_num++) { - std::vector> pos_index; + for (int64_t pos_num = 1; (pos_num <= shape_size) && (pos_num <= device_size); pos_num++) { + std::vector pos_index; pos_index = combine(pos_ind_combine_in, pos_num); - std::vector> dev_index; + std::vector dev_index; dev_index = combine(dev_ind_combine_in, pos_num); - for (int l = 0; l < dev_index.size(); l++) { - std::vector pos_value_combine_in; + for (size_t l = 0; l < dev_index.size(); l++) { + Shape pos_value_combine_in; for (int32_t i = dev_index[l].size() - 1; i >= 0; i--) { pos_value_combine_in.push_back(dev_index[l][i]); } - std::vector> pos_value; - std::vector::iterator it = pos_value_combine_in.begin(); + std::vector pos_value; + Shape::iterator it = pos_value_combine_in.begin(); do { - std::vector pos_value_item; - for (uint32_t m = 0; m < pos_num; m++) { + Shape pos_value_item; + for (size_t m = 0; m < pos_num; m++) { pos_value_item.push_back(pos_value_combine_in[m]); } pos_value.push_back(pos_value_item); } while (next_permutation(it, it + pos_num)); - for (uint32_t j = 0; j < pos_index.size(); j++) { - for (uint32_t k = 0; k < pos_value.size(); k++) { - std::vector tensor_map = GenerateTensorMap(shape_size, pos_index[j], pos_value[k]); + for (size_t j = 0; j < pos_index.size(); j++) { + for (size_t k = 0; k < pos_value.size(); k++) { + TensorMap tensor_map = GenerateTensorMap(shape_size, pos_index[j], pos_value[k]); tensor_map_list->push_back(tensor_map); } } @@ -151,19 +149,19 @@ void GenerateValidTensorMap(const std::vector& device_arrangement, cons } void GenerateValidLayoutByDeviceSizeAndTensorSize( - int32_t device_pow_size, int32_t tensor_pow_size, int32_t max_device_dim, - int32_t max_shape_dim, - std::vector, std::vector, std::vector>>* layout_list) { + int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, + int64_t max_shape_dim, + std::vector>* layout_list) { layout_list->clear(); - std::vector> device_arrangement_list; + std::vector device_arrangement_list; GenerateValidShapeBySize(device_pow_size, &device_arrangement_list); - std::vector> tensor_shape_list; + std::vector tensor_shape_list; GenerateValidShapeBySize(tensor_pow_size, &tensor_shape_list); - for (uint32_t device_idx = 0; device_idx < device_arrangement_list.size(); device_idx++) { - for (uint32_t shape_idx = 0; shape_idx < tensor_shape_list.size(); shape_idx++) { - std::vector> tensor_map_list; + for (size_t device_idx = 0; device_idx < device_arrangement_list.size(); device_idx++) { + for (size_t shape_idx = 0; shape_idx < tensor_shape_list.size(); shape_idx++) { + std::vector tensor_map_list; GenerateValidTensorMap(device_arrangement_list[device_idx], tensor_shape_list[shape_idx], &tensor_map_list); - for (uint32_t map_idx = 0; map_idx < tensor_map_list.size(); map_idx++) { + for (size_t map_idx = 0; map_idx < tensor_map_list.size(); map_idx++) { if (!CheckLayoutValid(device_arrangement_list[device_idx], tensor_map_list[map_idx], tensor_shape_list[shape_idx])) { continue; @@ -176,8 +174,8 @@ void GenerateValidLayoutByDeviceSizeAndTensorSize( return; } -bool CheckLayoutValid(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape) { +bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, + const TensorShape& tensor_shape) { bool flag = false; if ((tensor_map.size() - ComputeNoneNumber(tensor_map)) > device_arrangement.size()) { return flag; @@ -188,9 +186,9 @@ bool CheckLayoutValid(const std::vector& device_arrangement, const std: return true; } -uint32_t ComputeNoneNumber(const std::vector& tensor_map) { - uint32_t num = 0; - for (uint32_t i = 0; i < tensor_map.size(); i++) { +size_t ComputeNoneNumber(const TensorMap& tensor_map) { + size_t num = 0; + for (size_t i = 0; i < tensor_map.size(); i++) { if (tensor_map[i] == -1) { num++; } @@ -198,14 +196,14 @@ uint32_t ComputeNoneNumber(const std::vector& tensor_map) { return num; } -bool ShapeIsDividedByDevice(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape) { +bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, + const TensorShape& tensor_shape) { bool flag = false; for (uint32_t i = 0; i < tensor_map.size() && i < tensor_shape.size(); i++) { if (tensor_map[i] == -1) { continue; } - int32_t dim = device_arrangement[device_arrangement.size() - 1 - tensor_map[i]]; + int64_t dim = device_arrangement[device_arrangement.size() - 1 - tensor_map[i]]; if (tensor_shape[i] % dim != 0) { return flag; } @@ -213,8 +211,8 @@ bool ShapeIsDividedByDevice(const std::vector& device_arrangement, cons return true; } -bool IsExpended(const std::vector& in1, const std::vector& in2) { - int32_t size = 1; +bool IsExpended(const Shape& in1, const Shape& in2) { + int64_t size = 1; uint32_t ind = 0; for (uint32_t i = 0; i < in1.size(); i++) { size *= in1[i]; @@ -236,9 +234,9 @@ bool IsExpended(const std::vector& in1, const std::vector& in2 return true; } -void ComputeAccumDeviceTOAccumShapeMap(const std::vector& device_arrangement, - const std::vector& tensor_map, const std::vector& tensor_shape, - std::map* accum_device_to_accum_shape_map) { +void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement& device_arrangement, + const TensorMap& tensor_map, const TensorShape& tensor_shape, + std::map* accum_device_to_accum_shape_map) { accum_device_to_accum_shape_map->clear(); std::vector shape_accum_reverse; Status status = ShapeToAccumulateProductReverse(tensor_shape, &shape_accum_reverse); @@ -258,42 +256,42 @@ void ComputeAccumDeviceTOAccumShapeMap(const std::vector& device_arrang return; } -void IsLinearValue(int32_t small, int32_t big, int32_t small_value, int32_t big_value, int32_t middle, - int32_t middle_value) { +void IsLinearValue(int64_t small, int64_t big, int64_t small_value, int64_t big_value, int64_t middle, + int64_t middle_value) { ASSERT_NE(big, small); - int32_t value = (middle - small) * (big_value - small_value) / (big - small) + small_value; + int64_t value = (middle - small) * (big_value - small_value) / (big - small) + small_value; ASSERT_EQ(middle_value, value); } -void LayoutTransferValidLayoutChangeCheck(const std::vector& in_device_arrangement, - const std::vector& in_tensor_map, - const std::vector& in_tensor_shape, - const std::vector& out_device_arrangement, - const std::vector& out_tensor_map, - const std::vector& out_tensor_shape) { +void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, + const TensorMap& in_tensor_map, + const TensorShape& in_tensor_shape, + const DeviceArrangement& out_device_arrangement, + const TensorMap& out_tensor_map, + const TensorShape& out_tensor_shape) { bool is_expended = IsExpended(out_device_arrangement, in_device_arrangement); ASSERT_EQ(true, is_expended); is_expended = IsExpended(out_tensor_shape, in_tensor_shape); ASSERT_EQ(true, is_expended); - std::map out_accum_device_to_accum_shape_map; + std::map out_accum_device_to_accum_shape_map; ComputeAccumDeviceTOAccumShapeMap(out_device_arrangement, out_tensor_map, out_tensor_shape, &out_accum_device_to_accum_shape_map); - std::map in_accum_device_to_accum_shape_map; + std::map in_accum_device_to_accum_shape_map; ComputeAccumDeviceTOAccumShapeMap(in_device_arrangement, in_tensor_map, in_tensor_shape, &in_accum_device_to_accum_shape_map); - std::map::iterator in_iter = in_accum_device_to_accum_shape_map.begin(); + std::map::iterator in_iter = in_accum_device_to_accum_shape_map.begin(); while (in_iter != in_accum_device_to_accum_shape_map.end()) { if (in_iter->second != out_accum_device_to_accum_shape_map[in_iter->first]) { continue; } in_iter++; } - std::map::iterator out_iter = out_accum_device_to_accum_shape_map.begin(); + std::map::iterator out_iter = out_accum_device_to_accum_shape_map.begin(); while (out_iter != out_accum_device_to_accum_shape_map.end()) { if (out_accum_device_to_accum_shape_map.find(out_iter->first) == out_accum_device_to_accum_shape_map.end()) { in_iter = in_accum_device_to_accum_shape_map.begin(); - int32_t small = 1; - int32_t big = 1; + int64_t small = 1; + int64_t big = 1; while (in_iter != in_accum_device_to_accum_shape_map.end()) { if (in_iter->first < out_iter->first) { small = in_iter->second; @@ -311,18 +309,18 @@ void LayoutTransferValidLayoutChangeCheck(const std::vector& in_device_ if (big == 1) { ASSERT_EQ(true, false); } - int32_t small_value = in_accum_device_to_accum_shape_map[small]; - int32_t big_value = in_accum_device_to_accum_shape_map[big]; + int64_t small_value = in_accum_device_to_accum_shape_map[small]; + int64_t big_value = in_accum_device_to_accum_shape_map[big]; IsLinearValue(small, big, small_value, big_value, out_iter->first, out_iter->second); } out_iter++; } } -void ValidLayoutChangeCheck(const std::vector& in_device_arrangement, - const std::vector& in_tensor_map, const std::vector& in_tensor_shape, - const std::vector& out_device_arrangement, - const std::vector& out_tensor_map, const std::vector& out_tensor_shape) { +void ValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, + const TensorMap& in_tensor_map, const TensorShape& in_tensor_shape, + const DeviceArrangement& out_device_arrangement, + const TensorMap& out_tensor_map, const TensorShape& out_tensor_shape) { LayoutTransferValidLayoutChangeCheck(in_device_arrangement, in_tensor_map, in_tensor_shape, out_device_arrangement, out_tensor_map, out_tensor_shape); } diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h index c16a1fc6d4..e0b56fe0a4 100644 --- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h +++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h @@ -21,51 +21,50 @@ #include #include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { -std::vector> combine(const std::vector& in, int32_t target); +std::vector combine(const Shape& in, int64_t target); -void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim, - std::vector>* out); +void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector* out); -void GenerateValidShapeBySize(int32_t pow_size, std::vector>* out); +void GenerateValidShapeBySize(int64_t pow_size, std::vector* out); -std::vector GenerateTensorMap(const uint32_t& map_size, const std::vector& pos_index, - const std::vector& pos_value); +TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, const Shape& pos_value); -void GenerateValidTensorMap(const std::vector& device_arrangement, const std::vector& tensor_shape, - std::vector>* tensor_map_list); +void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const TensorMap& tensor_shape, + std::vector* tensor_map_list); void GenerateValidLayoutByDeviceSizeAndTensorSize( - int32_t device_pow_size, int32_t tensor_pow_size, int32_t max_device_dim, - int32_t max_shape_dim, - std::vector, std::vector, std::vector>>* layout_list); + int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, + int64_t max_shape_dim, + std::vector>* layout_list); -uint32_t ComputeNoneNumber(const std::vector& tensor_map); +size_t ComputeNoneNumber(const TensorMap& tensor_map); -bool ShapeIsDividedByDevice(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape); +bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, + const TensorShape& tensor_shape); -bool CheckLayoutValid(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape); +bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, + const TensorShape& tensor_shape); -void ComputeAccumDeviceTOAccumShapeMap(const std::vector& device_arrangement, - const std::vector& tensor_map, const std::vector& tensor_shape, - std::map* accum_device_to_accum_shape_map); +void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement& device_arrangement, + const TensorMap& tensor_map, const TensorShape& tensor_shape, + std::map* accum_device_to_accum_shape_map); -void LayoutTransferValidLayoutChangeCheck(const std::vector& in_device_arrangement, - const std::vector& in_tensor_map, - const std::vector& in_tensor_shape, - const std::vector& out_device_arrangement, - const std::vector& out_tensor_map, - const std::vector& out_tensor_shape); +void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, + const TensorMap& in_tensor_map, + const TensorShape& in_tensor_shape, + const DeviceArrangement& out_device_arrangement, + const TensorMap& out_tensor_map, + const TensorShape& out_tensor_shape); -void ValidLayoutChangeCheck(const std::vector& in_device_arrangement, - const std::vector& in_tensor_map, const std::vector& in_tensor_shape, - const std::vector& out_device_arrangement, - const std::vector& out_tensor_map, const std::vector& out_tensor_shape); +void ValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, + const TensorMap& in_tensor_map, const TensorShape& in_tensor_shape, + const DeviceArrangement& out_device_arrangement, + const TensorMap& out_tensor_map, const TensorShape& out_tensor_shape); } // namespace parallel } // namespace mindspore diff --git a/tests/ut/cpp/parallel/virtual_dataset_test.cc b/tests/ut/cpp/parallel/virtual_dataset_test.cc index 4cafdebc17..dfa18bccd3 100644 --- a/tests/ut/cpp/parallel/virtual_dataset_test.cc +++ b/tests/ut/cpp/parallel/virtual_dataset_test.cc @@ -37,13 +37,13 @@ class TestVirtualDatasetInfo : public UT::Common { }; void TestVirtualDatasetInfo::SetUp() { - std::vector dev_list; + RankList dev_list; for (int32_t i = 0; i < 130; i++) { dev_list.push_back(i); } - std::vector stage_map; + RankList stage_map; stage_map.push_back(16); stage_map.push_back(114); @@ -62,27 +62,27 @@ void TestVirtualDatasetInfo::SetUp() { } TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) { - std::vector inputs = {{16, 1}, {16, 1}, {16, 1}}; + Strategys inputs = {{16, 1}, {16, 1}, {16, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); virtual_dataset->Init(strategy); - std::vector dev_matrix_shape = virtual_dataset->dev_matrix_shape(); + Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape(); - std::vector expect = {16}; + Shape expect = {16}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape2) { - std::vector inputs = {{8, 1}, {8, 1}, {8, 1}}; + Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); virtual_dataset->Init(strategy); - std::vector dev_matrix_shape = virtual_dataset->dev_matrix_shape(); + Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape(); - std::vector expect = {8, 2}; + Shape expect = {8, 2}; ASSERT_EQ(dev_matrix_shape, expect); } TEST_F(TestVirtualDatasetInfo, InferSliceShape1) { - std::vector str = {{8, 1}, {8, 1}, {8, 1}}; + Strategys str = {{8, 1}, {8, 1}, {8, 1}}; StrategyPtr strategy = NewStrategy(0, str); virtual_dataset->Init(strategy); @@ -127,7 +127,7 @@ TEST_F(TestVirtualDatasetInfo, InferSliceShape1) { } TEST_F(TestVirtualDatasetInfo, GetTensorLayout1) { - std::vector str = {{8, 1}, {8, 1}, {8, 1}}; + Strategys str = {{8, 1}, {8, 1}, {8, 1}}; StrategyPtr strategy = NewStrategy(0, str); virtual_dataset->Init(strategy); @@ -148,7 +148,7 @@ TEST_F(TestVirtualDatasetInfo, GetTensorLayout1) { } TEST_F(TestVirtualDatasetInfo, GetForwardOp1) { - std::vector inputs = {{8, 1}, {8, 1}, {8, 1}}; + Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); virtual_dataset->Init(strategy); @@ -159,7 +159,7 @@ TEST_F(TestVirtualDatasetInfo, GetForwardOp1) { } TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) { - std::vector inputs = {{8, 1}, {8, 1}, {8, 1}}; + Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; StrategyPtr strategy = NewStrategy(0, inputs); virtual_dataset->Init(strategy); diff --git a/tests/ut/cpp/pipeline/resource_test.cc b/tests/ut/cpp/pipeline/resource_test.cc index b6be393652..f6fe8e5242 100644 --- a/tests/ut/cpp/pipeline/resource_test.cc +++ b/tests/ut/cpp/pipeline/resource_test.cc @@ -36,23 +36,23 @@ class TestResource : public UT::Common { void TearDown() {} }; -TEST_F(TestResource, test_standard_method_map) { - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt8)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt16)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt32)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt64)); +TEST_F(TestResource, test_built_in_type_map) { + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt8)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt16)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt32)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt64)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat16)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat32)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat64)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat16)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat32)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat64)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeBool)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeUInt)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTuple)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeList)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTensorType)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeBool)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeUInt)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTuple)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeList)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTensorType)); MethodMap& map = GetMethodMap(); for (auto& iter : map) { diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc index 8ebea4d212..d037a9019c 100644 --- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc @@ -467,24 +467,6 @@ TEST_F(TestPrim, test_env_add) { ASSERT_TRUE(*res == *exp); } -TEST_F(TestPrim, test_shape) { - PrimitivePtr shap = std::make_shared("Shape"); - FuncGraphPtr func_graph = MakeFuncGraph(shap, 1); - - auto a = UTPrimUtils::ArrayFloat64Of({2, 3}); - - AbstractBasePtrList args_spec_list = {a}; - - AbstractTuplePtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred->abstract()); - auto ret = res->BuildValue()->cast()->value(); - - std::vector element_list = {MakeValue(2), MakeValue(3)}; - ASSERT_TRUE(ret.size() == element_list.size()); - for (int i = 0; i < element_list.size(); i++) { - ASSERT_TRUE(*ret[i] == *element_list[i]); - } -} - TEST_F(TestPrim, test_relu) { PrimitivePtr relu = prim::kPrimRelu; relu->AddAttr("T", MakeValue(static_cast(kNumberTypeFloat64))); diff --git a/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc b/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc index e32a86d9be..1a320d72ed 100644 --- a/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc @@ -24,7 +24,7 @@ #include "pipeline/jit/static_analysis/program_specialize.h" #include "pipeline/static_analysis/helper.h" #include "utils/log_adapter.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "utils/misc.h" #include "debug/draw.h" diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc index 103d0f21a4..f5c3902983 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc @@ -22,6 +22,7 @@ #include "utils/utils.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "backend/optimizer/common/optimizer.h" +#include "ir/param_value.h" #define private public #define protected public #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" @@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { ~MockInsertMemcpyForHcclKernelQuery() override = default; bool IsTbeRef(const AnfNodePtr &node) override { MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr) { + if (!node->isa()) { return false; } - auto name = AnfAlgo::GetCNodeName(cnode); - return name == "ApplyMomentum"; + return AnfAlgo::GetCNodeName(node->cast()) == "ApplyMomentum"; } }; @@ -105,6 +104,10 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { AbstractBasePtrList args_spec_list{x_abstract}; auto kg = GetKernelGraph(g, args_spec_list); EXPECT_NE(kg, nullptr); + for (auto p : kg->parameters()) { + auto param = p->cast(); + EXPECT_NE(param, nullptr); + } auto optimizer = std::make_shared(); auto pm = std::make_shared(); @@ -146,10 +149,15 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { ASSERT_TRUE(g != nullptr); std::vector shp_x{1, 64, 112, 112}; auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; + AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; auto kg = GetKernelGraph(g, args_spec_list); EXPECT_NE(kg, nullptr); + for (auto p : kg->parameters()) { + auto param = p->cast(); + EXPECT_NE(param, nullptr); + } + auto optimizer = std::make_shared(); auto pm = std::make_shared(); auto pass = std::make_shared(); @@ -161,5 +169,33 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWInsertMemcpyForHccl, test_cond5) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + + for (auto p : kg->parameters()) { + auto param = p->cast(); + EXPECT_NE(param, nullptr); + } + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->kernel_query_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kg); + kg->SetExecOrderByDefault(); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc index 2b61a49048..c8909455d3 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc @@ -26,7 +26,7 @@ #include "backend/optimizer/ascend/format_type/insert_cast.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc index 0a5cf3dd9e..fbb3733c9b 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc @@ -21,7 +21,7 @@ #include "backend/optimizer/common/pass_manager.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #define private public #define protected public @@ -50,6 +50,8 @@ class TestHWInsertTransOp : public BackendCommon { KernelBuildInfoBuilder builder; builder.SetInputsFormat({format, format}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); + builder.SetInputsReshapeType({{},{}}); + builder.SetOutputsReshapeType({}); builder.SetOutputsFormat({format}); builder.SetOutputsDeviceType({kFloat16->type_id()}); add->set_kernel_info(std::make_shared()); @@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon { EXPECT_NE(ret->input(1)->cast()->input(1)->cast()->input(1), nullptr); auto max_pool = ret->input(1)->cast()->input(1)->cast()->input(1); KernelBuildInfoBuilder builder; + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{},{}}); builder.SetInputsFormat({kOpFormat_DEFAULT}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({format, format}); @@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { ~MockInsertTransOpKernelSelectTrans4Dto5D() override = default; void SelectKernel(const CNodePtr &cnode) override { KernelBuildInfoBuilder builder; + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); builder.SetInputsFormat({"NCHW"}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc new file mode 100644 index 0000000000..d4a69e70a5 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "debug/anf_ir_dump.h" +#include "common/py_func_graph_fetcher.h" +#include "backend/optimizer/ascend/format_type/remove_internal_output.h" + +#define private public +#define protected public +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; + +class TestHWRemoveInternalOutput : public BackendCommon { + public: + TestHWRemoveInternalOutput() : getPyFun_("gtest_input.pre_activate.remove_internal_output_test", true) {} + ~TestHWRemoveInternalOutput() override = default; + + AnfNodePtr GetMakeTuple(const KernelGraphPtr &kg) { + auto ret = kg->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto make_tuple = ret->input(1); + return make_tuple; + } + + KernelGraphPtr GetSingleOutputGraph(const std::string &func_name, const std::string &sub_func_name) { + FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + auto make_tuple = GetMakeTuple(kg); + auto add = make_tuple->cast()->input(1); + MS_EXCEPTION_IF_NULL(add); + kg->AddInternalOutput(add, add); + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); + builder.SetInputsReshapeType({{}, {}}); + builder.SetOutputsReshapeType({{}}); + builder.SetOutputsFormat({kOpFormat_NC1HWC0}); + builder.SetOutputsDeviceType({kFloat16->type_id()}); + add->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get()); + return kg; + } + + KernelGraphPtr GetMutilpleOutputGraph(const std::string &func_name, const std::string &sub_func_name) { + FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + auto output_make_tuple = GetMakeTuple(kg); + auto make_tuple = output_make_tuple->cast()->input(1); + MS_EXCEPTION_IF_NULL(make_tuple); + auto tuple_getitem1 = make_tuple->cast()->input(1); + MS_EXCEPTION_IF_NULL(tuple_getitem1); + auto tuple_getitem2 = make_tuple->cast()->input(2); + MS_EXCEPTION_IF_NULL(tuple_getitem2); + auto max_pool = tuple_getitem1->cast()->input(1); + MS_EXCEPTION_IF_NULL(max_pool); + kg->AddInternalOutput(tuple_getitem1, max_pool); + kg->AddInternalOutput(tuple_getitem2, max_pool); + KernelBuildInfoBuilder builder; + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}, {}}); + builder.SetInputsFormat({kOpFormat_DEFAULT}); + builder.SetInputsDeviceType({kFloat32->type_id()}); + builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); + builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); + max_pool->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), max_pool.get()); + return kg; + } + UT::PyFuncGraphFetcher getPyFun_; +}; + +class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect { + public: + MockRemoveInternalOutputTransOpKernelSelect() = default; + ~MockRemoveInternalOutputTransOpKernelSelect() override = default; + void SelectKernel(const CNodePtr &cnode) override { + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({kOpFormat_NC1HWC0}); + builder.SetInputsDeviceType({kFloat16->type_id()}); + builder.SetOutputsFormat({kOpFormat_DEFAULT}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); + } +}; + +TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_execution_mode(kGraphMode); + auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before"); + // insert trans op for output + auto graph_optimizer = std::make_shared(); + auto pass_manager = std::make_shared(); + auto insert_trans_op_pass = std::make_shared(); + insert_trans_op_pass->kernel_select_ = std::make_shared(); + pass_manager->AddPass(insert_trans_op_pass); + graph_optimizer->AddPassManager(pass_manager); + auto new_g = graph_optimizer->Optimize(kg); + FuncGraphPtr g_after = + getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output", "after_insert_trans_op"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_g)); + + auto make_tuple = GetMakeTuple(kg); + auto trans_data = make_tuple->cast()->input(1); + EXPECT_TRUE(kg->IsInternalOutput(trans_data, 0)); + + // remove trans op for internal output + auto graph_optimizer1 = std::make_shared(); + auto pass_manager1 = std::make_shared(); + auto remove_internal_output_trans_op_pass = std::make_shared(); + pass_manager1->AddPass(remove_internal_output_trans_op_pass); + graph_optimizer1->AddPassManager(pass_manager1); + auto new_g1 = graph_optimizer1->Optimize(new_g); + FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output", + "after_remove_internal_output_trans_op"); + EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1)); +} + +TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_multiple_output) { + auto kg = GetMutilpleOutputGraph("test_remove_internal_output_trans_op_for_multiple_output", "before"); + // insert trans op for output + auto graph_optimizer = std::make_shared(); + auto pass_manager = std::make_shared(); + auto insert_trans_op_pass = std::make_shared(); + insert_trans_op_pass->kernel_select_ = std::make_shared(); + pass_manager->AddPass(insert_trans_op_pass); + graph_optimizer->AddPassManager(pass_manager); + auto new_g = graph_optimizer->Optimize(kg); + FuncGraphPtr g_after = + getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output", "after_insert_trans_op"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_g)); + + auto output_make_tuple = GetMakeTuple(kg); + auto make_tuple = output_make_tuple->cast()->input(1); + auto tuple_getitem = make_tuple->cast()->input(1); + auto make_tuple1 = tuple_getitem->cast()->input(1); + auto trans_data1 = make_tuple1->cast()->input(1); + auto trans_data2 = make_tuple1->cast()->input(2); + EXPECT_TRUE(kg->IsInternalOutput(trans_data1, 0)); + EXPECT_TRUE(kg->IsInternalOutput(trans_data2, 0)); + + // remove trans op for internal output + auto graph_optimizer1 = std::make_shared(); + auto pass_manager1 = std::make_shared(); + auto remove_internal_output_trans_op_pass = std::make_shared(); + pass_manager1->AddPass(remove_internal_output_trans_op_pass); + graph_optimizer1->AddPassManager(pass_manager1); + auto new_g1 = graph_optimizer1->Optimize(new_g); + FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output", + "after_remove_internal_output_trans_op"); + EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/concat_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/concat_fission_test.cc new file mode 100644 index 0000000000..0198a99a59 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/concat_fission_test.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#define private public +#define protected public +#include "backend/optimizer/ascend/ir_fission/concat_fission.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWConcatFission : public BackendCommon { + public: + TestHWConcatFission() : get_py_fun_("gtest_input.pre_activate.concat_fission_test", true) {} + ~TestHWConcatFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 2; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 3; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 4; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 8; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} + +TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto concat_fission = std::make_shared(); + concat_fission->inputs_divisor_ = 9; + pm->AddPass(concat_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_9"); + EXPECT_NE(g_after, nullptr); + auto kg_after = GetKernelGraph(g_after, args_spec_list); + EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/pack_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/pack_fission_test.cc new file mode 100644 index 0000000000..d22e55c927 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/pack_fission_test.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#define private public +#define protected public +#include "backend/optimizer/ascend/ir_fission/pack_fission.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWPackFission : public BackendCommon { + public: + TestHWPackFission() : get_py_fun_("gtest_input.pre_activate.pack_fission_test", true) {} + ~TestHWPackFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWPackFission, test_pack_fission_divided_by_3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pack_fission = std::make_shared(); + pack_fission->inputs_divisor_ = 3; + pm->AddPass(pack_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_3"); + EXPECT_NE(g_after, nullptr); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWPackFission, test_pack_fission_divided_by_4) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 9; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pack_fission = std::make_shared(); + pack_fission->inputs_divisor_ = 4; + pm->AddPass(pack_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_4"); + EXPECT_NE(g_after, nullptr); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/reduce_min_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/reduce_min_fission_test.cc new file mode 100644 index 0000000000..e1cec41c96 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/reduce_min_fission_test.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "debug/anf_ir_dump.h" +#define private public +#define protected public +#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWOptReduceMinFission : public BackendCommon { + public: + TestHWOptReduceMinFission() : get_py_fun_("gtest_input.pre_activate.reduce_min_fission_test", true) {} + ~TestHWOptReduceMinFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWOptReduceMinFission, test_fission) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reduce_min_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{32, 32, 32, 32}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + args_spec_list.push_back(x_abstract); + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto split_fission = std::make_shared(); + pm->AddPass(split_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reduce_min_fission", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc index 220e45f10a..f1a1dbe026 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc @@ -20,7 +20,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "debug/anf_ir_dump.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #define private public #define protected public @@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); + builder.SetInputsReshapeType({}); + builder.SetOutputsReshapeType({}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } else { KernelBuildInfoBuilder builder; @@ -58,7 +60,10 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); + builder.SetInputsReshapeType({}); + builder.SetOutputsReshapeType({}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); + } } }; @@ -74,6 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NCHW"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } else { KernelBuildInfoBuilder builder; @@ -81,6 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NCHW"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } } @@ -116,6 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) { builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetFusionType(kernel::FusionType::ELEMWISE); builder.SetProcessor(kernel::Processor::AICORE); + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); auto kernel_info = std::make_shared(); kernel_info->set_select_kernel_build_info(builder.Build()); transpose->set_kernel_info(kernel_info); @@ -162,6 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) { builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetFusionType(kernel::FusionType::ELEMWISE); builder.SetProcessor(kernel::Processor::AICORE); + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); auto kernel_info = std::make_shared(); kernel_info->set_select_kernel_build_info(builder.Build()); transpose->set_kernel_info(kernel_info); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index d156959c4c..5f16016d2c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -19,7 +19,7 @@ #include "runtime/device/kernel_info.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/oplib/oplib.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #define private public #define protected public #include "backend/optimizer/ascend/format_type/insert_trans_op.h" @@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); + builder.SetInputsReshapeType({}); + builder.SetOutputsReshapeType({}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } else { KernelBuildInfoBuilder builder; @@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); + builder.SetInputsReshapeType({}); + builder.SetOutputsReshapeType({}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } } @@ -93,6 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { EXPECT_NE(transpose, nullptr); KernelBuildInfoBuilder builder; + builder.SetInputsReshapeType({}); + builder.SetOutputsReshapeType({}); builder.SetInputsFormat({"NCHW"}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); diff --git a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc index 12030433fc..e1efa3baaf 100644 --- a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc +++ b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc @@ -22,7 +22,7 @@ #include "common/common_test.h" #include "backend/optimizer/common/pattern_engine.h" #include "backend/optimizer/common/visit.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" #include "ir/anf.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc index 2a6904658e..69fd649f8c 100644 --- a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc +++ b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc @@ -146,7 +146,7 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator_split_membuf) { TEST_F(TestMemReuseAllocator, mem_reuse_allocator_align) { auto best_fit_mem_reuse = std::make_shared(); - auto size = best_fit_mem_reuse->AlignMemorySize(510); + auto size = best_fit_mem_reuse->AlignCommonMemorySize(510); ASSERT_EQ(size, 1024); } } // namespace memreuse diff --git a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc index 31ae923c0a..5bccf52077 100644 --- a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc +++ b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc @@ -26,7 +26,7 @@ #include "frontend/operator/ops.h" #include "utils/log_adapter.h" #include "backend/session/anf_runtime_algorithm.h" -#include "common/utils.h" +#include "utils/ms_utils.h" #include "pipeline/jit/resource.h" #include "backend/optimizer/mem_reuse/mem_reuse.h" @@ -225,7 +225,6 @@ TEST_F(TestMemReuseWithPy, KernelRef) { ASSERT_EQ(kernel_ref_count_ptr->size_, 512); KernelDefPtr kernel_def_ptr = std::make_shared(); ASSERT_NE(kernel_def_ptr, nullptr); - ASSERT_EQ(kernel_def_ptr->dirty, false); MembufPtr membuf_ptr = std::make_shared(); ASSERT_NE(membuf_ptr, nullptr); } diff --git a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc index 02e1865a82..35b767d4dd 100644 --- a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc @@ -26,7 +26,7 @@ #include "backend/optimizer/common/pass_manager.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc b/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc index cfcc34970b..bbe9d45d66 100644 --- a/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc +++ b/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc @@ -26,7 +26,7 @@ #include "backend/optimizer/pass/common_subexpression_elimination.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc index 25e4b3c111..59549d3c3b 100644 --- a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc +++ b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc @@ -22,7 +22,7 @@ #include "backend/optimizer/common/pass_manager.h" #include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" #include "utils/utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc index ac3272317a..2b595f1560 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc @@ -22,7 +22,7 @@ #include "backend/optimizer/common/pass_manager.h" #include "backend/optimizer/pass/convert_const_input_to_attr.h" #include "utils/utils.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc index 07bef7a042..60d99d2dad 100644 --- a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc @@ -27,10 +27,10 @@ #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/pass_manager.h" #include "utils/utils.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #define private public #define protected public @@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect { ~MockEliminate5To4And4To5KernelSelect() override = default; void SelectKernel(const CNodePtr &cnode) override { KernelBuildInfoBuilder builder; + builder.SetInputsReshapeType({{}}); + builder.SetOutputsReshapeType({{}}); builder.SetInputsFormat({"NCHW"}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); @@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) { builder.SetOutputsFormat({"NC1HWC0"}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - + builder.SetInputsReshapeType({{}, {}}); + builder.SetOutputsReshapeType({{}}); sub->set_kernel_info(std::make_shared()); add->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); @@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) { builder.SetOutputsFormat({"NC1HWC0"}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - + builder.SetInputsReshapeType({{}, {}}); + builder.SetOutputsReshapeType({{}, {}}); sub->set_kernel_info(std::make_shared()); add->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); @@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) { builder.SetOutputsFormat({"NC1HWC0"}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - + builder.SetInputsReshapeType({{}, {}}); + builder.SetOutputsReshapeType({{}}); sub->set_kernel_info(std::make_shared()); add->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc index c5f25ca484..c3c848423b 100644 --- a/tests/ut/cpp/pynative/pynative_execute_test.cc +++ b/tests/ut/cpp/pynative/pynative_execute_test.cc @@ -20,7 +20,7 @@ #include "pipeline/jit/parse/data_converter.h" #include "frontend/operator/ops.h" #include "pipeline/pynative/pynative_execute.h" -#include "utils/context/ms_context.h" +#include "utils/ms_context.h" #include "utils/utils.h" namespace py = pybind11; @@ -65,27 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() { py::none py_none; py::args args = py::make_tuple(conv_obj, op_name, op_inputs); py::list args_input = args[PY_INPUTS]; - return GenerateOpExecInfo(args, &args_input); -} - -TEST_F(TestPynativeExecute, TestRunOpInVM) { - py::tuple result; - PynativeStatusCode status; - auto op_exec_info_ptr = ConstructOpExecInfo(); - result = pynative::RunOpInVM(op_exec_info_ptr, &status); - ASSERT_EQ(status, PYNATIVE_SUCCESS); -} - -TEST_F(TestPynativeExecute, TestRunOp) { - py::none py_none; - auto op_exec_info_ptr = ConstructOpExecInfo(); - py::tuple outputs = pynative::RunOp( - py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name, op_exec_info_ptr->op_inputs)); - if (outputs.size() == 0) { - FAIL(); - } else { - SUCCEED(); - } + return GenerateOpExecInfo(args); } TEST_F(TestPynativeExecute, TestCreateContext) { diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index 16c557adbe..161538f5ad 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -1130,17 +1130,17 @@ def test_adjust_allreduce_mul_add(tag): return fns[tag] -def test_indexed_slices(tag): +def test_row_tensor(tag): """ test_add_zero """ fns = FnDict() - make_indexed_slices = Primitive('MakeIndexedSlices') - indexed_slices_get_values = Primitive('IndexedSlicesGetValues') - indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') - indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') + make_row_tensor = Primitive('MakeRowTensor') + row_tensor_get_values = Primitive('RowTensorGetValues') + row_tensor_get_indices = Primitive('RowTensorGetIndices') + row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape') @fns def before_get_indices(x, y, z): - return indexed_slices_get_indices(make_indexed_slices(x, y, z)) + return row_tensor_get_indices(make_row_tensor(x, y, z)) @fns def after_get_indices(x, y, z): @@ -1148,7 +1148,7 @@ def test_indexed_slices(tag): @fns def before_get_values(x, y, z): - return indexed_slices_get_values(make_indexed_slices(x, y, z)) + return row_tensor_get_values(make_row_tensor(x, y, z)) @fns def after_get_values(x, y, z): @@ -1156,7 +1156,42 @@ def test_indexed_slices(tag): @fns def before_get_dense_shape(x, y, z): - return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z)) + return row_tensor_get_dense_shape(make_row_tensor(x, y, z)) + + @fns + def after_get_dense_shape(x, y, z): + return z + + return fns[tag] + + +def test_sparse_tensor(tag): + """ test_add_zero """ + fns = FnDict() + make_sparse_tensor = Primitive('MakeSparseTensor') + sparse_tensor_get_values = Primitive('SparseTensorGetValues') + sparse_tensor_get_indices = Primitive('SparseTensorGetIndices') + sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape') + + @fns + def before_get_indices(x, y, z): + return sparse_tensor_get_indices(make_sparse_tensor(x, y, z)) + + @fns + def after_get_indices(x, y, z): + return x + + @fns + def before_get_values(x, y, z): + return sparse_tensor_get_values(make_sparse_tensor(x, y, z)) + + @fns + def after_get_values(x, y, z): + return y + + @fns + def before_get_dense_shape(x, y, z): + return sparse_tensor_get_dense_shape(make_sparse_tensor(x, y, z)) @fns def after_get_dense_shape(x, y, z): diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/concat_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/concat_fission_test.py new file mode 100644 index 0000000000..f1f01999eb --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/concat_fission_test.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import operations as P + +concat = P.Concat() + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_concat_fission(tag): + """ test_adam_apply_one_with_decay_rule """ + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, input7, input8): + return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8)) + + @fns + def after_divided_by_2(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1)) + b = concat((input2, input3)) + c = concat((input4, input5)) + d = concat((input6, input7)) + f = concat((a, b)) + g = concat((c, d)) + i = concat((f, g)) + return concat((i, input8)) + + @fns + def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1, input2)) + b = concat((input3, input4, input5)) + c = concat((input6, input7, input8)) + return concat((a, b, c)) + + @fns + def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1, input2, input3)) + b = concat((input4, input5, input6, input7)) + return concat((a, b, input8)) + + @fns + def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): + a = concat((input0, input1, input2, input3, input4, input5, input6, input7)) + return concat((a, input8)) + + @fns + def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): + return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8)) + + return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py index 61310d186f..444cf8282d 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py @@ -44,12 +44,14 @@ def test_getnext_memcpy_elimination(tag): res = get_next() res = memcpy_async_attr(res) res = cast(res) + res = add(res) return res @fns def after(): res = get_next() res = cast(res) + res = add(res) return res return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py index 7ffcfd0578..082c8144f5 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py @@ -17,6 +17,7 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P all_reduce = P.AllReduce() +broadcast = P.Broadcast(1) memcpy_async = Primitive('memcpy_async') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') @@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag): fns = FnDict() @fns - def before(a, b, c, d, e): - res1 = apply_momentun(a, b, c, d, e) - res2 = all_reduce(a) - res = control_depend(res1, res2) - res = make_tuple(res, res2) + def before(a, b): + x = relu(a) + y = all_reduce(b) + res = control_depend(x, y) return res @fns - def after(a, b, c, d, e): - res1 = apply_momentun(a, b, c, d, e) - res2 = memcpy_async(a) - res3 = all_reduce(res2) - res = control_depend(res1, res2) - res = make_tuple(res, res3) + def after(a, b): + x = relu(a) + y1 = memcpy_async(b) + y2 = all_reduce(y1) + res = control_depend(x, make_tuple(y1, y2)) + return make_tuple(res) + + return fns[tag] + + +def test_insert_memcpy_async_for_hccl_op_cond5(tag): + fns = FnDict() + + @fns + def before(a, b, c): + x = relu(a) + y = broadcast((b, c)) + res = control_depend(x, y) + return res + + @fns + def after(a, b, c): + x = relu(a) + m1 = memcpy_async(b) + m2 = memcpy_async(c) + y = broadcast(m1, m2) + res = control_depend(x, make_tuple(m1, m2, y)) return make_tuple(res) return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/pack_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/pack_fission_test.py new file mode 100644 index 0000000000..8678c6273c --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/pack_fission_test.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import operations as P +from mindspore.ops import Primitive + +pack = P.Pack() +concat = P.Concat() +make_tuple = Primitive('make_tuple') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_pack_fission(tag): + """ test_adam_apply_one_with_decay_rule """ + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, input5, input6, input7, input8): + return pack((input0, input1, input2, input3, input4, input5, input6, input7, input8)) + + @fns + def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): + pack1 = pack(input0, input1, input2) + pack2 = pack(input3, input4, input5) + pack3 = pack(input6, input7, input8) + return make_tuple(concat(pack1, pack2, pack3)) + + @fns + def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): + pack1 = pack(input0, input1, input2, input3) + pack2 = pack(input4, input5, input6, input7) + pack3 = pack(input8) + return make_tuple(concat(pack1, pack2, pack3)) + + return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/reduce_min_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/reduce_min_fission_test.py new file mode 100644 index 0000000000..7690023e01 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/reduce_min_fission_test.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import Primitive +from mindspore.ops import operations as P + +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') +reduce_min = P.ReduceMin(keep_dims=False) +reduce_min1 = Primitive('ReduceMin') +reduce_min2 = Primitive('ReduceMin') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_reduce_min_fission(tag): + fns = FnDict() + + @fns + def before(x): + res = reduce_min(x, (2, 3)) + return res + + @fns + def after(x): + res = reduce_min1(x) + res = reduce_min2(res) + return make_tuple(res) + + return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py new file mode 100644 index 0000000000..0c02864816 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py @@ -0,0 +1,83 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.ops import Primitive +from mindspore.ops import operations as P + +tuple_getitem = Primitive('tuple_getitem') +add = P.TensorAdd() +max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) +make_tuple = Primitive('make_tuple') +trans_data = Primitive("TransData") + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_remove_internal_output_trans_op_for_single_output(tag): + fns = FnDict() + + @fns + def before(x, y): + res = add(x, y) + return res + + @fns + def after_insert_trans_op(x, y): + output = add(x, y) + res = trans_data(output) + return make_tuple(res) + + @fns + def after_remove_internal_output_trans_op(x, y): + res = add(x, y) + return make_tuple(res) + + return fns[tag] + + +def test_remove_internal_output_trans_op_for_multiple_output(tag): + fns = FnDict() + + @fns + def before(x): + max_pool_res = max_pool(x) + res = make_tuple(tuple_getitem(max_pool_res, 0), tuple_getitem(max_pool_res, 1)) + return res + + @fns + def after_insert_trans_op(x): + output = max_pool(x) + trans_data0 = trans_data(tuple_getitem(output, 0)) + trans_data1 = trans_data(tuple_getitem(output, 1)) + new_make_tuple = make_tuple(trans_data0, trans_data1) + res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1)) + return make_tuple(res) + + @fns + def after_remove_internal_output_trans_op(x): + output = max_pool(x) + new_make_tuple = make_tuple(tuple_getitem(output, 0), tuple_getitem(output, 1)) + res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1)) + return make_tuple(res) + + return fns[tag] diff --git a/tests/ut/cpp/serving/CMakeLists.txt b/tests/ut/cpp/serving/CMakeLists.txt new file mode 100644 index 0000000000..42c682e002 --- /dev/null +++ b/tests/ut/cpp/serving/CMakeLists.txt @@ -0,0 +1,90 @@ +find_package(Threads REQUIRED) + +# This branch assumes that gRPC and all its dependencies are already installed +# on this system, so they can be located by find_package(). + +# Find Protobuf installation +# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. + +#set(protobuf_MODULE_COMPATIBLE TRUE) +#find_package(Protobuf CONFIG REQUIRED) +#message(STATUS "Using protobuf ${protobuf_VERSION}") +add_library(protobuf::libprotobuf ALIAS protobuf::protobuf) +add_executable(protobuf::libprotoc ALIAS protobuf::protoc) + +set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) +if (CMAKE_CROSSCOMPILING) + find_program(_PROTOBUF_PROTOC protoc) +else () + set(_PROTOBUF_PROTOC $) +endif () + +# Find gRPC installation +# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. +if (EXISTS ${grpc_ROOT}/lib64) + set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc") +else () + set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc") +endif () +message("serving ut using grpc_DIR : " ${gPRC_DIR}) + +find_package(gRPC CONFIG REQUIRED) +message(STATUS "Using gRPC ${gRPC_VERSION}") + +set(_GRPC_GRPCPP gRPC::grpc++) +set(_REFLECTION gRPC::grpc++_reflection) + +if (CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + find_program(_GRPC_PYTHON_PLUGIN_EXECUTABLE grpc_python_plugin) +else () + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + set(_GRPC_PYTHON_PLUGIN_EXECUTABLE $) +endif () + +# Proto file +get_filename_component(hw_proto "ms_service.proto" ABSOLUTE) +get_filename_component(hw_proto_path ${hw_proto} PATH) +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h") +set(hw_py_pb2 "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2.py") +set(hw_py_pb2_grpc "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2_grpc.py") +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" "${hw_py_pb2}" "${hw_py_pb2_grpc}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --python_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_PYTHON_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") + +list(APPEND SERVING_SRC_TEST ${hw_proto_srcs} ${hw_grpc_srcs}) + +file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "../../../../serving/acl/*.cc" + "../../../../serving/core/*.cc") +list(APPEND SERVING_SRC_TEST ${ACL_SESSION_SRC_LIST}) + +# utest files +file(GLOB_RECURSE ACL_UTEST_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +list(APPEND SERVING_SRC_TEST ${ACL_UTEST_SRC_LIST}) + +include_directories(${CMAKE_SOURCE_DIR}/serving/core) +include_directories(${CMAKE_SOURCE_DIR}/serving/acl) +include_directories(${CMAKE_SOURCE_DIR}/serving) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/../) +add_library(ut_serving_obj OBJECT ${SERVING_SRC_TEST}) + + diff --git a/tests/ut/cpp/serving/acl/acl.h b/tests/ut/cpp/serving/acl/acl.h new file mode 100644 index 0000000000..0719ac8a0c --- /dev/null +++ b/tests/ut/cpp/serving/acl/acl.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACL_STUB_INC_ACL_H +#define ACL_STUB_INC_ACL_H +#include "acl_base.h" +#include "acl_mdl.h" +#include "acl_rt.h" + +aclError aclInit(const char *configPath); +aclError aclFinalize(); + +#endif // ACL_STUB_INC_ACL_H \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl/acl_base.h b/tests/ut/cpp/serving/acl/acl_base.h new file mode 100644 index 0000000000..3bab6a4a7b --- /dev/null +++ b/tests/ut/cpp/serving/acl/acl_base.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACL_STUB_INC_ACL_BASE +#define ACL_STUB_INC_ACL_BASE +#include +#include + +typedef void *aclrtStream; +typedef void *aclrtEvent; +typedef void *aclrtContext; +typedef int aclError; +typedef uint16_t aclFloat16; +typedef struct aclDataBuffer aclDataBuffer; +typedef struct aclTensorDesc aclTensorDesc; + +const int ACL_ERROR_NONE = 0; + +typedef enum { + ACL_DT_UNDEFINED = -1, + ACL_FLOAT = 0, + ACL_FLOAT16 = 1, + ACL_INT8 = 2, + ACL_INT32 = 3, + ACL_UINT8 = 4, + ACL_INT16 = 6, + ACL_UINT16 = 7, + ACL_UINT32 = 8, + ACL_INT64 = 9, + ACL_UINT64 = 10, + ACL_DOUBLE = 11, + ACL_BOOL = 12, +} aclDataType; + +typedef enum { + ACL_FORMAT_UNDEFINED = -1, + ACL_FORMAT_NCHW = 0, + ACL_FORMAT_NHWC = 1, + ACL_FORMAT_ND = 2, + ACL_FORMAT_NC1HWC0 = 3, + ACL_FORMAT_FRACTAL_Z = 4, + ACL_FORMAT_FRACTAL_NZ = 29, + +} aclFormat; + +typedef enum { + ACL_DEBUG, + ACL_INFO, + ACL_WARNING, + ACL_ERROR, +} aclLogLevel; + +aclDataBuffer *aclCreateDataBuffer(void *data, size_t size); +aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffer); +void *aclGetDataBufferAddr(const aclDataBuffer *dataBuffer); +uint32_t aclGetDataBufferSize(const aclDataBuffer *dataBuffer); +size_t aclDataTypeSize(aclDataType dataType); + +aclTensorDesc *aclCreateTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); +void aclDestroyTensorDesc(const aclTensorDesc *desc); +aclDataType aclGetTensorDescType(const aclTensorDesc *desc); +aclFormat aclGetTensorDescFormat(const aclTensorDesc *desc); +size_t aclGetTensorDescSize(const aclTensorDesc *desc); +size_t aclGetTensorDescElementCount(const aclTensorDesc *desc); +size_t aclGetTensorDescNumDims(const aclTensorDesc *desc); +int64_t aclGetTensorDescDim(const aclTensorDesc *desc, size_t index); + +void aclAppLog(aclLogLevel logLevel, const char *func, const char *file, uint32_t line, const char *fmt, ...); + +#define ACL_APP_LOG(level, fmt, ...) aclAppLog(level, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__) + +#endif \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl/acl_mdl.h b/tests/ut/cpp/serving/acl/acl_mdl.h new file mode 100644 index 0000000000..ccec700f59 --- /dev/null +++ b/tests/ut/cpp/serving/acl/acl_mdl.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACL_STUB_INC_ACL_MDL +#define ACL_STUB_INC_ACL_MDL +#include "acl_base.h" + +#define ACL_MAX_DIM_CNT 128 +#define ACL_MAX_TENSOR_NAME_LEN 128 +#define ACL_MAX_BATCH_NUM 128 +#define ACL_MAX_HW_NUM 128 +#define ACL_MAX_SHAPE_COUNT 128 + +typedef struct aclmdlDataset aclmdlDataset; +typedef struct aclmdlDesc aclmdlDesc; + +typedef struct aclmdlIODims { + char name[ACL_MAX_TENSOR_NAME_LEN]; + size_t dimCount; + int64_t dims[ACL_MAX_DIM_CNT]; +} aclmdlIODims; + +aclmdlDesc *aclmdlCreateDesc(); +aclError aclmdlDestroyDesc(aclmdlDesc *modelDesc); +aclError aclmdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId); + +size_t aclmdlGetNumInputs(aclmdlDesc *modelDesc); +size_t aclmdlGetNumOutputs(aclmdlDesc *modelDesc); +size_t aclmdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index); +size_t aclmdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index); + +aclmdlDataset *aclmdlCreateDataset(); +aclError aclmdlDestroyDataset(const aclmdlDataset *dataSet); +aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataSet, aclDataBuffer *dataBuffer); +size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *dataSet); +aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataSet, size_t index); + +aclError aclmdlLoadFromFile(const char *modelPath, uint32_t *modelId); +aclError aclmdlLoadFromMem(const void *model, size_t modelSize, uint32_t *modelId); +aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, size_t workSize, + void *weightPtr, size_t weightSize); +aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, uint32_t *modelId, void *workPtr, + size_t workSize, void *weightPtr, size_t weightSize); + +aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output); +aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, aclrtStream stream); +aclError aclmdlUnload(uint32_t modelId); + +aclError aclmdlQuerySize(const char *fileName, size_t *workSize, size_t *weightSize); +aclError aclmdlQuerySizeFromMem(const void *model, size_t modelSize, size_t *workSize, size_t *weightSize); + +aclError aclmdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims); +aclError aclmdlGetOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims); +aclError aclmdlGetCurOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims); + +aclFormat aclmdlGetInputFormat(const aclmdlDesc *modelDesc, size_t index); +aclFormat aclmdlGetOutputFormat(const aclmdlDesc *modelDesc, size_t index); + +aclDataType aclmdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index); +aclDataType aclmdlGetOutputDataType(const aclmdlDesc *modelDesc, size_t index); + +#endif \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl/acl_rt.h b/tests/ut/cpp/serving/acl/acl_rt.h new file mode 100644 index 0000000000..e36e4cf45a --- /dev/null +++ b/tests/ut/cpp/serving/acl/acl_rt.h @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACL_STUB_INC_ACL_RT_H +#define ACL_STUB_INC_ACL_RT_H +#include "acl_base.h" + +typedef enum aclrtRunMode { + ACL_DEVICE, + ACL_HOST, +} aclrtRunMode; + +typedef enum aclrtTsId { + ACL_TS_ID_AICORE, + ACL_TS_ID_AIVECTOR, + ACL_TS_ID_RESERVED, +} aclrtTsId; + +typedef enum aclrtEventStatus { + ACL_EVENT_STATUS_COMPLETE, + ACL_EVENT_STATUS_NOT_READY, + ACL_EVENT_STATUS_RESERVED, +} aclrtEventStatus; + +typedef enum aclrtCallbackBlockType { + ACL_CALLBACK_NO_BLOCK, + ACL_CALLBACK_BLOCK, +} aclrtCallbackBlockType; + +typedef enum aclrtMemcpyKind { + ACL_MEMCPY_HOST_TO_HOST, + ACL_MEMCPY_HOST_TO_DEVICE, + ACL_MEMCPY_DEVICE_TO_HOST, + ACL_MEMCPY_DEVICE_TO_DEVICE, +} aclrtMemcpyKind; + +typedef enum aclrtMemMallocPolicy { + ACL_MEM_MALLOC_HUGE_FIRST, + ACL_MEM_MALLOC_HUGE_ONLY, + ACL_MEM_MALLOC_NORMAL_ONLY, +} aclrtMemMallocPolicy; + +typedef struct rtExceptionInfo aclrtExceptionInfo; +typedef void (*aclrtCallback)(void *userData); +typedef void (*aclrtExceptionInfoCallback)(aclrtExceptionInfo *exceptionInfo); + +aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId); +aclError aclrtDestroyContext(aclrtContext context); +aclError aclrtSetCurrentContext(aclrtContext context); +aclError aclrtGetCurrentContext(aclrtContext *context); +aclError aclrtSetDevice(int32_t deviceId); +aclError aclrtResetDevice(int32_t deviceId); +aclError aclrtGetDevice(int32_t *deviceId); +aclError aclrtGetRunMode(aclrtRunMode *runMode); +aclError aclrtSynchronizeDevice(void); +aclError aclrtSetTsDevice(aclrtTsId tsId); +aclError aclrtGetDeviceCount(uint32_t *count); + +aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy); +aclError aclrtFree(void *devPtr); + +aclError aclrtMallocHost(void **hostPtr, size_t size); +aclError aclrtFreeHost(void *hostPtr); + +aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind); +aclError aclrtMemset(void *devPtr, size_t maxCount, int32_t value, size_t count); +aclError aclrtMemcpyAsync(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind, + aclrtStream stream); +aclError aclrtMemsetAsync(void *devPtr, size_t maxCount, int32_t value, size_t count, aclrtStream stream); + +aclError aclrtCreateStream(aclrtStream *stream); +aclError aclrtDestroyStream(aclrtStream stream); +aclError aclrtSynchronizeStream(aclrtStream stream); +aclError aclrtStreamWaitEvent(aclrtStream stream, aclrtEvent event); + +#endif \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl/ops/acl_dvpp.h b/tests/ut/cpp/serving/acl/ops/acl_dvpp.h new file mode 100644 index 0000000000..b66ca8ec21 --- /dev/null +++ b/tests/ut/cpp/serving/acl/ops/acl_dvpp.h @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACL_STUB_INC_ACL_DVPP_H +#define ACL_STUB_INC_ACL_DVPP_H +#include +#include "acl/acl.h" +#include "acl/acl_base.h" + +typedef struct acldvppPicDesc acldvppPicDesc; +typedef struct acldvppRoiConfig acldvppRoiConfig; +typedef struct acldvppResizeConfig acldvppResizeConfig; +typedef struct acldvppChannelDesc acldvppChannelDesc; +typedef struct acldvppStreamDesc acldvppStreamDesc; +typedef struct acldvppBatchPicDesc acldvppBatchPicDesc; + +enum acldvppPixelFormat { + PIXEL_FORMAT_YUV_400 = 0, + PIXEL_FORMAT_YUV_SEMIPLANAR_420 = 1, // YUV + PIXEL_FORMAT_YVU_SEMIPLANAR_420 = 2, // YVU + PIXEL_FORMAT_YUV_SEMIPLANAR_422 = 3, // YUV + PIXEL_FORMAT_YVU_SEMIPLANAR_422 = 4, // YVU + PIXEL_FORMAT_YUV_SEMIPLANAR_444 = 5, // YUV + PIXEL_FORMAT_YVU_SEMIPLANAR_444 = 6, // YVU + +}; + +enum acldvppStreamFormat { + H265_MAIN_LEVEL = 0, + H254_BASELINE_LEVEL = 1, + H254_MAIN_LEVEL, + H254_HIGH_LEVEL, +}; + +enum acldvppChannelMode { DVPP_CHNMODE_VPC = 1, DVPP_CHNMODE_JPEGD = 2, DVPP_CHNMODE_JPEGE = 4 }; + +aclError acldvppMalloc(void **devPtr, size_t size); +aclError acldvppFree(void *devPtr); +acldvppChannelDesc *acldvppCreateChannelDesc(); +aclError acldvppDestroyChannelDesc(acldvppChannelDesc *channelDesc); +acldvppPicDesc *acldvppCreatePicDesc(); +aclError acldvppDestroyPicDesc(acldvppPicDesc *picDesc); +aclError acldvppSetPicDescSize(acldvppPicDesc *picDesc, uint32_t size); +aclError acldvppSetPicDescFormat(acldvppPicDesc *picDesc, acldvppPixelFormat format); +aclError acldvppSetPicDescWidth(acldvppPicDesc *picDesc, uint32_t width); +aclError acldvppSetPicDescHeight(acldvppPicDesc *picDesc, uint32_t height); +aclError acldvppSetPicDescData(acldvppPicDesc *picDesc, void *dataDev); +aclError acldvppSetPicDescWidthStride(acldvppPicDesc *picDesc, uint32_t widthStride); +aclError acldvppSetPicDescHeightStride(acldvppPicDesc *picDesc, uint32_t heightStride); +aclError acldvppSetPicDescRetCode(acldvppPicDesc *picDesc, uint32_t retCode); + +uint32_t acldvppGetPicDescSize(acldvppPicDesc *picDesc); +acldvppPixelFormat acldvppGetPicDescFormat(acldvppPicDesc *picDesc); +uint32_t acldvppGetPicDescWidth(acldvppPicDesc *picDesc); +uint32_t acldvppGetPicDescHeight(acldvppPicDesc *picDesc); +void *acldvppGetPicDescData(acldvppPicDesc *picDesc); +uint32_t acldvppGetPicDescWidthStride(acldvppPicDesc *picDesc); +uint32_t acldvppGetPicDescHeightStride(acldvppPicDesc *picDesc); +uint32_t acldvppGetPicDescRetCode(acldvppPicDesc *picDesc); + +acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom); +aclError acldvppDestroyRoiConfig(acldvppRoiConfig *roiConfig); +aclError acldvppSetRoiConfigLeft(acldvppRoiConfig *roiConfig, uint32_t left); +aclError acldvppSetRoiConfigRight(acldvppRoiConfig *roiConfig, uint32_t right); +aclError acldvppSetRoiConfigTop(acldvppRoiConfig *roiConfig, uint32_t top); +aclError acldvppSetRoiConfigBottom(acldvppRoiConfig *roiConfig, uint32_t bottom); +aclError acldvppSetRoiConfig(acldvppRoiConfig *roiConfig, uint32_t left, uint32_t right, uint32_t top, uint32_t bottom); + +acldvppResizeConfig *acldvppCreateResizeConfig(); +aclError acldvppDestroyResizeConfig(acldvppResizeConfig *resizeConfig); + +aclError acldvppJpegPredictDecSize(const void *data, uint32_t dataSize, acldvppPixelFormat ouputPixelFormat, + uint32_t *decSize); + +aclError acldvppCreateChannel(acldvppChannelDesc *channelDesc); +aclError acldvppDestroyChannel(acldvppChannelDesc *channelDesc); + +aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, + acldvppResizeConfig *resizeConfig, aclrtStream stream); + +aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, + acldvppRoiConfig *cropArea, aclrtStream stream); + +aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, + acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, + acldvppRoiConfig *pasteArea, aclrtStream stream); + +aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchDesc, uint32_t *roiNums, + uint32_t size, acldvppBatchPicDesc *dstBatchDesc, acldvppRoiConfig *cropAreas[], + aclrtStream stream); + +aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size, + acldvppPicDesc *outputDesc, aclrtStream stream); + +acldvppBatchPicDesc *acldvppCreateBatchPicDesc(uint32_t batchSize); +acldvppPicDesc *acldvppGetPicDesc(acldvppBatchPicDesc *batchPicDesc, uint32_t index); +aclError acldvppDestroyBatchPicDesc(acldvppBatchPicDesc *batchPicDesc); + +#endif // ACL_STUB_INC_ACL_DVPP_H \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_session_test_add.cc b/tests/ut/cpp/serving/acl_session_test_add.cc new file mode 100644 index 0000000000..3156798037 --- /dev/null +++ b/tests/ut/cpp/serving/acl_session_test_add.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl_session_test_common.h" + +using namespace std; + +namespace mindspore { +namespace serving { + +class AclSessionAddTest : public AclSessionTest { + public: + AclSessionAddTest() = default; + void SetUp() override { + AclSessionTest::SetUp(); + aclmdlDesc model_desc; + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + model_desc.outputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + mock_model_desc_ = MockModelDesc(model_desc); + g_acl_model_desc = &mock_model_desc_; + g_acl_model = &add_mock_model_; + } + void CreateDefaultRequest(PredictRequest &request) { + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + auto input1 = request.add_data(); + CreateTensor(*input1, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + + auto input0_data = reinterpret_cast(input0->mutable_data()->data()); + auto input1_data = reinterpret_cast(input1->mutable_data()->data()); + for (int i = 0; i < 2 * 24 * 24 * 3; i++) { + input0_data[i] = i % 1024; + input1_data[i] = i % 1024 + 1; + } + } + + void CheckDefaultReply(const PredictReply &reply) { + EXPECT_TRUE(reply.result().size() == 1); + if (reply.result().size() == 1) { + CheckTensorItem(reply.result(0), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + auto &output = reply.result(0).data(); + EXPECT_EQ(output.size(), 2 * 24 * 24 * 3 * sizeof(float)); + if (output.size() == 2 * 24 * 24 * 3 * sizeof(float)) { + auto output_data = reinterpret_cast(output.data()); + for (int i = 0; i < 2 * 24 * 24 * 3; i++) { + EXPECT_EQ(output_data[i], (i % 1024) + (i % 1024 + 1)); + if (output_data[i] != (i % 1024) + (i % 1024 + 1)) { + break; + } + } + } + } + } + MockModelDesc mock_model_desc_; + AddMockAclModel add_mock_model_; +}; + +TEST_F(AclSessionAddTest, TestAclSession_OneTime_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionAddTest, TestAclSession_MutilTimes_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + for (int i = 0; i < 10; i++) { + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionAddTest, TestAclSession_DeviceRunMode_OneTime_Success) { + SetDeviceRunMode(); + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionAddTest, TestAclSession_DeviceRunMode_MutilTimes_Success) { + SetDeviceRunMode(); + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + for (int i = 0; i < 10; i++) { + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +} // namespace serving +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_session_test_common.h b/tests/ut/cpp/serving/acl_session_test_common.h new file mode 100644 index 0000000000..39a0514bca --- /dev/null +++ b/tests/ut/cpp/serving/acl_session_test_common.h @@ -0,0 +1,191 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ACL_SESSION_TEST_COMMON_H +#define MINDSPORE_ACL_SESSION_TEST_COMMON_H + +#include "common/common_test.h" +#include "serving/core/server.h" +#include "include/inference.h" +#include "include/infer_tensor.h" +#include "serving/core/serving_tensor.h" +#include "serving/acl/acl_session.h" +#include "serving/acl/model_process.h" +#include "serving/acl/dvpp_process.h" +#include "acl_stub.h" + +class MockDeviceRunMode : public AclRunMode { + public: + aclError aclrtGetRunMode(aclrtRunMode *runMode) override { + *runMode = aclrtRunMode::ACL_DEVICE; + return ACL_ERROR_NONE; + } +}; + +class AclSessionTest : public testing::Test { + public: + AclSessionTest() = default; + void SetUp() override { + g_acl_data_buffer = &g_acl_data_buffer_default; + g_acl_env = &g_acl_env_default; + g_acl_dataset = &g_acl_dataset_default; + g_acl_model = &g_acl_model_default; + g_acl_model_desc = &g_acl_model_desc_default; + g_acl_device_context_stream = &g_acl_device_context_stream_default; + g_acl_memory = &g_acl_memory_default; + g_acl_dvpp_pic_desc = &g_acl_dvpp_pic_desc_default; + g_acl_dvpp_roi_config = &g_acl_dvpp_roi_config_default; + g_acl_dvpp_resize_config = &g_acl_dvpp_resize_config_default; + g_acl_dvpp_channel_desc = &g_acl_dvpp_channel_desc_default; + g_acl_dvpp_process = &g_acl_dvpp_process_default; + g_acl_run_mode = &acl_run_mode_default; + g_acl_jpeg_lib = &acl_jpeg_lib_default; + } + void TearDown() override { + EXPECT_TRUE(g_acl_data_buffer->Check()); + EXPECT_TRUE(g_acl_env->Check()); + EXPECT_TRUE(g_acl_dataset->Check()); + EXPECT_TRUE(g_acl_model->Check()); + EXPECT_TRUE(g_acl_model_desc->Check()); + EXPECT_TRUE(g_acl_device_context_stream->Check()); + EXPECT_TRUE(g_acl_memory->Check()); + EXPECT_TRUE(g_acl_dvpp_pic_desc->Check()); + EXPECT_TRUE(g_acl_dvpp_roi_config->Check()); + EXPECT_TRUE(g_acl_dvpp_resize_config->Check()); + EXPECT_TRUE(g_acl_dvpp_channel_desc->Check()); + EXPECT_TRUE(g_acl_dvpp_process->Check()); + EXPECT_TRUE(g_acl_jpeg_lib->Check()); + } + + AclDataBuffer g_acl_data_buffer_default; + AclEnv g_acl_env_default; + AclDataSet g_acl_dataset_default; + AclModel g_acl_model_default; + AclModelDesc g_acl_model_desc_default; + AclDeviceContextStream g_acl_device_context_stream_default; + AclMemory g_acl_memory_default; + AclDvppPicDesc g_acl_dvpp_pic_desc_default; + AclDvppRoiConfig g_acl_dvpp_roi_config_default; + AclDvppResizeConfig g_acl_dvpp_resize_config_default; + AclDvppChannelDesc g_acl_dvpp_channel_desc_default; + AclDvppProcess g_acl_dvpp_process_default; + AclRunMode acl_run_mode_default; + MockDeviceRunMode acl_device_run_mode; + AclJpegLib acl_jpeg_lib_default = AclJpegLib(0, 0); + + void SetDeviceRunMode() { g_acl_run_mode = &acl_device_run_mode; } + void CreateTensor(ms_serving::Tensor &tensor, const std::vector &shape, ms_serving::DataType data_type, + std::size_t data_size = INT64_MAX) { + if (data_size == INT64_MAX) { + data_size = GetDataTypeSize(data_type); + for (auto item : shape) { + data_size *= item; + } + } + tensor.set_data(std::string(data_size, 0)); + tensor.set_tensor_type(data_type); + auto tensor_shape = tensor.mutable_tensor_shape(); + for (auto item : shape) { + tensor_shape->add_dims(item); + } + } + + size_t GetDataTypeSize(ms_serving::DataType data_type) { + const std::map type_size_map{ + {ms_serving::DataType::MS_BOOL, sizeof(bool)}, {ms_serving::DataType::MS_INT8, sizeof(int8_t)}, + {ms_serving::DataType::MS_UINT8, sizeof(uint8_t)}, {ms_serving::DataType::MS_INT16, sizeof(int16_t)}, + {ms_serving::DataType::MS_UINT16, sizeof(uint16_t)}, {ms_serving::DataType::MS_INT32, sizeof(int32_t)}, + {ms_serving::DataType::MS_UINT32, sizeof(uint32_t)}, {ms_serving::DataType::MS_INT64, sizeof(int64_t)}, + {ms_serving::DataType::MS_UINT64, sizeof(uint64_t)}, {ms_serving::DataType::MS_FLOAT16, 2}, + {ms_serving::DataType::MS_FLOAT32, sizeof(float)}, {ms_serving::DataType::MS_FLOAT64, sizeof(double)}, + }; + auto it = type_size_map.find(data_type); + if (it == type_size_map.end()) { + EXPECT_TRUE(false); + return 0; + } + return it->second; + } + + void CheckTensorItem(const ms_serving::Tensor &tensor, const std::vector &expect_shape, + ms_serving::DataType expect_data_type) { + std::vector tensor_shape; + for (auto item : tensor.tensor_shape().dims()) { + tensor_shape.push_back(item); + } + EXPECT_EQ(expect_shape, tensor_shape); + EXPECT_EQ(expect_data_type, tensor.tensor_type()); + int64_t elem_cnt = 1; + for (auto item : expect_shape) { + elem_cnt *= item; + } + auto data_size = GetDataTypeSize(expect_data_type); + EXPECT_EQ(data_size * elem_cnt, tensor.data().size()); + } +}; + +class MockModelDesc : public AclModelDesc { + public: + MockModelDesc() {} + MockModelDesc(const aclmdlDesc &mock_model_desc) : mock_model_desc_(mock_model_desc) {} + aclmdlDesc *aclmdlCreateDesc() override { + aclmdlDesc *model_desc = AclModelDesc::aclmdlCreateDesc(); + *model_desc = mock_model_desc_; + return model_desc; + } + aclmdlDesc mock_model_desc_; +}; + +class AddMockAclModel : public AclModel { + public: + aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) override { + if (AclModel::aclmdlExecute(modelId, input, output) != ACL_ERROR_NONE) { + return 1; + } + if (input->data_buffers.size() != 2) { + return 1; + } + auto &input0 = input->data_buffers[0]; + auto &input1 = input->data_buffers[1]; + std::size_t expect_count = input0->size / sizeof(float); + if (input0->size != expect_count * sizeof(float) || input1->size != expect_count * sizeof(float)) { + return 1; + } + + if (output->data_buffers.size() != 1) { + return 1; + } + auto &output0 = output->data_buffers[0]; + if (output0->size != expect_count * sizeof(float)) { + return 1; + } + + auto input0_data = reinterpret_cast(input0->data); + auto input1_data = reinterpret_cast(input1->data); + auto output0_data = reinterpret_cast(output0->data); + for (size_t i = 0; i < expect_count; i++) { + output0_data[i] = input0_data[i] + input1_data[i]; + } + return ACL_ERROR_NONE; + } + + aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, + aclrtStream stream) override { + return aclmdlExecute(modelId, input, output); + } +}; + +#endif // MINDSPORE_ACL_SESSION_TEST_COMMON_H diff --git a/tests/ut/cpp/serving/acl_session_test_dvpp.cc b/tests/ut/cpp/serving/acl_session_test_dvpp.cc new file mode 100644 index 0000000000..6a3f01e02c --- /dev/null +++ b/tests/ut/cpp/serving/acl_session_test_dvpp.cc @@ -0,0 +1,1055 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl_session_test_common.h" +#include +#include +#include + +using namespace std; +using namespace mindspore::inference; + +namespace mindspore { +namespace serving { + +class MockDvppProces : public AclDvppProcess { + public: + aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, + acldvppResizeConfig *resizeConfig, aclrtStream stream) override { + if (resize_fail_list_.empty()) { + return AclDvppProcess::acldvppVpcResizeAsync(channelDesc, inputDesc, outputDesc, resizeConfig, stream); + } + bool val = resize_fail_list_.front(); + resize_fail_list_.erase(resize_fail_list_.begin()); + if (!val) { + return 1; + } + return AclDvppProcess::acldvppVpcResizeAsync(channelDesc, inputDesc, outputDesc, resizeConfig, stream); + } + aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, + acldvppRoiConfig *cropArea, aclrtStream stream) override { + if (crop_fail_list_.empty()) { + return AclDvppProcess::acldvppVpcCropAsync(channelDesc, inputDesc, outputDesc, cropArea, stream); + } + bool val = crop_fail_list_.front(); + crop_fail_list_.erase(crop_fail_list_.begin()); + if (!val) { + return 1; + } + return AclDvppProcess::acldvppVpcCropAsync(channelDesc, inputDesc, outputDesc, cropArea, stream); + } + aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, + acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, + acldvppRoiConfig *pasteArea, aclrtStream stream) override { + if (crop_and_paste_fail_list_.empty()) { + return AclDvppProcess::acldvppVpcCropAndPasteAsync(channelDesc, inputDesc, outputDesc, cropArea, pasteArea, + stream); + } + bool val = crop_and_paste_fail_list_.front(); + crop_and_paste_fail_list_.erase(crop_and_paste_fail_list_.begin()); + if (!val) { + return 1; + } + return AclDvppProcess::acldvppVpcCropAndPasteAsync(channelDesc, inputDesc, outputDesc, cropArea, pasteArea, stream); + } + aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size, + acldvppPicDesc *outputDesc, aclrtStream stream) override { + if (decode_fail_list_.empty()) { + return AclDvppProcess::acldvppJpegDecodeAsync(channelDesc, data, size, outputDesc, stream); + } + bool val = decode_fail_list_.front(); + decode_fail_list_.erase(decode_fail_list_.begin()); + if (!val) { + return 1; + } + return AclDvppProcess::acldvppJpegDecodeAsync(channelDesc, data, size, outputDesc, stream); + } + vector decode_fail_list_; + vector resize_fail_list_; + vector crop_fail_list_; + vector crop_and_paste_fail_list_; +}; + +class AclSessionDvppTest : public AclSessionTest { + public: + AclSessionDvppTest() = default; + void SetUp() override { AclSessionTest::SetUp(); } + void InitModelDesc(uint32_t batch_size) { + batch_size_ = batch_size; + aclmdlDesc model_desc; + model_desc.inputs.push_back( // 32-> 16 align, 24->2 align + AclTensorDesc{ + .dims = {batch_size_, 32, 24, 3}, .data_type = ACL_FLOAT, .size = batch_size_ * 32 * 24 * 3 / 2}); // YUV420SP + + model_desc.outputs.push_back(AclTensorDesc{ + .dims = {batch_size_, 24, 24, 3}, .data_type = ACL_FLOAT, .size = batch_size_ * 24 * 24 * 3 * sizeof(float)}); + + model_desc.outputs.push_back(AclTensorDesc{ + .dims = {batch_size_, 24, 24, 3}, .data_type = ACL_FLOAT, .size = batch_size_ * 24 * 24 * 3 * sizeof(float)}); + + mock_model_desc_ = MockModelDesc(model_desc); + g_acl_model_desc = &mock_model_desc_; + g_acl_dvpp_process = &mock_dvpp_process_; + } + void TearDown() override { + AclSessionTest::TearDown(); + remove(dvpp_config_file_path_.c_str()); + } + void CreateDefaultRequest(PredictRequest &request, uint32_t image_size = 1) { + auto input0 = request.add_images(); + for (uint32_t i = 0; i < batch_size_; i++) { + input0->add_images(std::string(image_size, '\0')); // any length data + } + } + + void CheckDefaultReply(const PredictReply &reply) { + EXPECT_TRUE(reply.result().size() == 2); + if (reply.result().size() == 2) { + CheckTensorItem(reply.result(0), {batch_size_, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + CheckTensorItem(reply.result(1), {batch_size_, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + } + } + + void WriteDvppConfig(const std::string &dvpp_config_context) { + std::ofstream fp(dvpp_config_file_path_); + ASSERT_TRUE(fp.is_open()); + if (fp.is_open()) { + fp << dvpp_config_context; + } + } + + void SetJpegLib(uint32_t image_width, uint32_t image_height, J_COLOR_SPACE color_space = JCS_YCbCr) { + acl_jpeg_lib_default.image_width_ = image_width; + acl_jpeg_lib_default.image_height_ = image_height; + acl_jpeg_lib_default.color_space_ = color_space; + } + + void CreateDvppConfig() { + nlohmann::json dvpp_config; + auto &preprocess_list = dvpp_config["preprocess"]; + auto &preprocess = preprocess_list[0]; + preprocess["input"]["index"] = 0; + preprocess["decode_para"]["out_pixel_format"] = pixel_format_; + auto &dvpp_process = preprocess["dvpp_process"]; + + if (to_resize_flag_) { + dvpp_process["op_name"] = "resize"; + dvpp_process["out_width"] = resize_para_.output_width; + dvpp_process["out_height"] = resize_para_.output_height; + } else if (to_crop_flag_) { + auto &crop_info = crop_para_.crop_info; + auto &crop_area = crop_info.crop_area; + dvpp_process["op_name"] = "crop"; + dvpp_process["out_width"] = crop_para_.output_width; + dvpp_process["out_height"] = crop_para_.output_height; + if (crop_info.crop_type == kDvppCropTypeOffset) { + dvpp_process["crop_type"] = "offset"; + dvpp_process["crop_left"] = crop_area.left; + dvpp_process["crop_top"] = crop_area.top; + dvpp_process["crop_right"] = crop_area.right; + dvpp_process["crop_bottom"] = crop_area.bottom; + } else { + dvpp_process["crop_type"] = "centre"; + dvpp_process["crop_width"] = crop_info.crop_width; + dvpp_process["crop_height"] = crop_info.crop_height; + } + } else if (to_crop_and_paste_flag_) { + auto &crop_info = crop_paste_para_.crop_info; + auto &crop_area = crop_info.crop_area; + dvpp_process["op_name"] = "crop_and_paste"; + dvpp_process["out_width"] = crop_paste_para_.output_width; + dvpp_process["out_height"] = crop_paste_para_.output_height; + + dvpp_process["paste_left"] = crop_paste_para_.paste_area.left; + dvpp_process["paste_right"] = crop_paste_para_.paste_area.right; + dvpp_process["paste_top"] = crop_paste_para_.paste_area.top; + dvpp_process["paste_bottom"] = crop_paste_para_.paste_area.bottom; + + if (crop_info.crop_type == kDvppCropTypeOffset) { + dvpp_process["crop_type"] = "offset"; + dvpp_process["crop_left"] = crop_area.left; + dvpp_process["crop_top"] = crop_area.top; + dvpp_process["crop_right"] = crop_area.right; + dvpp_process["crop_bottom"] = crop_area.bottom; + } else { + dvpp_process["crop_type"] = "centre"; + dvpp_process["crop_width"] = crop_info.crop_width; + dvpp_process["crop_height"] = crop_info.crop_height; + } + } + stringstream output; + output << dvpp_config; + WriteDvppConfig(output.str()); + } + uint32_t batch_size_ = 1; + MockModelDesc mock_model_desc_; + MockDvppProces mock_dvpp_process_; + + const std::string model_file_path_ = "/tmp/acl_model_fake_path.om"; + const std::string dvpp_config_file_path_ = "/tmp/acl_model_fake_path_dvpp_config.json"; + inference::DvppResizePara resize_para_; + inference::DvppCropPara crop_para_; + inference::DvppCropAndPastePara crop_paste_para_; + bool to_resize_flag_ = false; + bool to_crop_flag_ = false; + bool to_crop_and_paste_flag_ = false; + std::string pixel_format_; +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize1_Success) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->resize_call_times_, 1); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize3_Success) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->resize_call_times_, 3); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 3); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCrop_BatchSize1_Success) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + crop_para_.output_width = 24; // align to 32 + crop_para_.output_height = 24; + crop_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_para_.crop_info.crop_area.left = 0; + crop_para_.crop_info.crop_area.right = 64; + crop_para_.crop_info.crop_area.top = 0; + crop_para_.crop_info.crop_area.bottom = 64; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); + EXPECT_EQ(g_acl_dvpp_process->crop_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropPaste_BatchSize1_Success) { + pixel_format_ = "YUV420SP"; + to_crop_and_paste_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + crop_paste_para_.output_width = 24; // align to 32 + crop_paste_para_.output_height = 24; + crop_paste_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_paste_para_.crop_info.crop_area.left = 0; + crop_paste_para_.crop_info.crop_area.right = 64; + crop_paste_para_.crop_info.crop_area.top = 0; + crop_paste_para_.crop_info.crop_area.bottom = 64; + crop_paste_para_.paste_area.left = 0; + crop_paste_para_.paste_area.right = 64; + crop_paste_para_.paste_area.top = 0; + crop_paste_para_.paste_area.bottom = 64; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); + EXPECT_EQ(g_acl_dvpp_process->crop_paste_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize3_MultiTime_Success) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + for (int i = 0; i < 3; i++) { + // create inputs + PredictRequest request; + CreateDefaultRequest(request, i + 1); // image size + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + } + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->resize_call_times_, 3 * 3); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 3 * 3); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize3_MultiTime_SameImageSize_Success) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + for (int i = 0; i < 3; i++) { + // create inputs + PredictRequest request; + CreateDefaultRequest(request, 1); // image size + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + } + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->resize_call_times_, 3 * 3); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 3 * 3); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_InvalidImageDim_Fail) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + SetJpegLib(31, 31); // 32*32 ~ 4096*4096 + { + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + } + SetJpegLib(4097, 4097); // 32*32 ~ 4096*4096 + { + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->resize_call_times_, 0); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 0); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_InvalidResizeWidth_Fail) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + resize_para_.output_width = 15; // align to 16 16n minimum 32 + resize_para_.output_height = 24; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_width failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->resize_call_times_, 0); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 0); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_InvalidResizeHeight_Fail) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + resize_para_.output_width = 32; // align to 32 16n, minimum 32 + resize_para_.output_height = 3; // align to 4 2n, minimum 6 + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropOffset_CropMini_Success) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + crop_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_para_.output_height = 6; // align to 6 2n, minimum 6 + crop_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_para_.crop_info.crop_area.left = 4; + crop_para_.crop_info.crop_area.right = 13; + crop_para_.crop_info.crop_area.top = 4; + crop_para_.crop_info.crop_area.bottom = 9; + + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropCentre_CropMini_Success) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + crop_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_para_.output_height = 24; // align to 24 2n, minimum 6 + crop_para_.crop_info.crop_type = kDvppCropTypeCentre; + crop_para_.crop_info.crop_width = 10; + crop_para_.crop_info.crop_height = 6; + + SetJpegLib(127, 127); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->crop_call_times_, 1); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropOffset_InvalidCropWidth_Fail) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + crop_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_para_.output_height = 6; // align to 6 2n, minimum 6 + crop_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_para_.crop_info.crop_area.left = 4; + crop_para_.crop_info.crop_area.right = 11; // minimum 10*6 + crop_para_.crop_info.crop_area.top = 4; + crop_para_.crop_info.crop_area.bottom = 9; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check crop width failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropOffset_InvalidCropHeight_Fail) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + crop_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_para_.output_height = 6; // align to 6 2n, minimum 6 + crop_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_para_.crop_info.crop_area.left = 4; + crop_para_.crop_info.crop_area.right = 13; + crop_para_.crop_info.crop_area.top = 4; + crop_para_.crop_info.crop_area.bottom = 7; // minimum 10*6 + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check crop height failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropCentre_InvalidCropHeight_Fail) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + crop_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_para_.output_height = 24; // align to 24 2n, minimum 6 + crop_para_.crop_info.crop_type = kDvppCropTypeCentre; + crop_para_.crop_info.crop_width = 10; // minimum 10*6 + crop_para_.crop_info.crop_height = 4; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check crop_height failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropPasteOffset_CropMini_Success) { + pixel_format_ = "YUV420SP"; + to_crop_and_paste_flag_ = true; + crop_paste_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_paste_para_.output_height = 24; // align to 24 2n, minimum 6 + crop_paste_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_paste_para_.crop_info.crop_area.left = 4; + crop_paste_para_.crop_info.crop_area.right = 13; + crop_paste_para_.crop_info.crop_area.top = 4; + crop_paste_para_.crop_info.crop_area.bottom = 9; + crop_paste_para_.paste_area.left = 4; + crop_paste_para_.paste_area.right = 13; + crop_paste_para_.paste_area.top = 4; + crop_paste_para_.paste_area.bottom = 9; + + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->crop_paste_call_times_, 1); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropPasteCentre_CropMini_Success) { + pixel_format_ = "YUV420SP"; + to_crop_and_paste_flag_ = true; + crop_paste_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_paste_para_.output_height = 24; // align to 24 2n, minimum 6 + crop_paste_para_.crop_info.crop_type = kDvppCropTypeCentre; + crop_paste_para_.crop_info.crop_width = 10; + crop_paste_para_.crop_info.crop_height = 6; + crop_paste_para_.paste_area.left = 4; + crop_paste_para_.paste_area.right = 13; + crop_paste_para_.paste_area.top = 4; + crop_paste_para_.paste_area.bottom = 9; + + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->crop_paste_call_times_, 1); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropPasteCentre_InvalidPasteWidth_Fail) { + pixel_format_ = "YUV420SP"; + to_crop_and_paste_flag_ = true; + crop_paste_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_paste_para_.output_height = 24; // align to 24 2n, minimum 6 + crop_paste_para_.crop_info.crop_type = kDvppCropTypeCentre; + crop_paste_para_.crop_info.crop_width = 10; + crop_paste_para_.crop_info.crop_height = 6; + crop_paste_para_.paste_area.left = 4; + crop_paste_para_.paste_area.right = 11; + crop_paste_para_.paste_area.top = 4; + crop_paste_para_.paste_area.bottom = 9; + + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropPasteCentre_InvalidPasteHeight_Fail) { + pixel_format_ = "YUV420SP"; + to_crop_and_paste_flag_ = true; + crop_paste_para_.output_width = 24; // align to 32 16n, minimum 32 + crop_paste_para_.output_height = 24; // align to 24 2n, minimum 6 + crop_paste_para_.crop_info.crop_type = kDvppCropTypeCentre; + crop_paste_para_.crop_info.crop_width = 10; + crop_paste_para_.crop_info.crop_height = 6; + crop_paste_para_.paste_area.left = 4; + crop_paste_para_.paste_area.right = 13; + crop_paste_para_.paste_area.top = 4; + crop_paste_para_.paste_area.bottom = 7; + + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + // load config, check output_height failed + EXPECT_FALSE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +// dvpp proces fail, test resource release ok +TEST_F(AclSessionDvppTest, TestAclSession_DvppDecode_BatchSize1_DvppFail) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + mock_dvpp_process_.decode_fail_list_.push_back(false); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize1_DvppFail) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + mock_dvpp_process_.resize_fail_list_.push_back(false); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize3_DvppFail0) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + mock_dvpp_process_.resize_fail_list_.push_back(false); // image 0 fail + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize3_DvppFail1) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + mock_dvpp_process_.resize_fail_list_.push_back(true); // image 0 success + mock_dvpp_process_.resize_fail_list_.push_back(false); // image 1 fail + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 2); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppResize_BatchSize3_DvppFail2) { + pixel_format_ = "YUV420SP"; + to_resize_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + resize_para_.output_width = 24; // align to 32 + resize_para_.output_height = 24; + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + mock_dvpp_process_.resize_fail_list_.push_back(true); // image 0 success + mock_dvpp_process_.resize_fail_list_.push_back(true); // image 1 success + mock_dvpp_process_.resize_fail_list_.push_back(false); // image 2 fail + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 3); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCrop_BatchSize1_DvppFail) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + crop_para_.output_width = 24; // align to 32 + crop_para_.output_height = 24; + crop_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_para_.crop_info.crop_area.left = 0; + crop_para_.crop_info.crop_area.right = 64; + crop_para_.crop_info.crop_area.top = 0; + crop_para_.crop_info.crop_area.bottom = 64; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + mock_dvpp_process_.crop_fail_list_.push_back(false); + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCropPaste_BatchSize1_DvppFail) { + pixel_format_ = "YUV420SP"; + to_crop_and_paste_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + crop_paste_para_.output_width = 24; // align to 32 + crop_paste_para_.output_height = 24; + crop_paste_para_.crop_info.crop_type = kDvppCropTypeOffset; + crop_paste_para_.crop_info.crop_area.left = 0; + crop_paste_para_.crop_info.crop_area.right = 64; + crop_paste_para_.crop_info.crop_area.top = 0; + crop_paste_para_.crop_info.crop_area.bottom = 64; + crop_paste_para_.paste_area.left = 0; + crop_paste_para_.paste_area.right = 64; + crop_paste_para_.paste_area.top = 0; + crop_paste_para_.paste_area.bottom = 64; + CreateDvppConfig(); + InitModelDesc(1); // batch_size=1 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + mock_dvpp_process_.crop_and_paste_fail_list_.push_back(false); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 1); +}; + +TEST_F(AclSessionDvppTest, TestAclSession_DvppCrop_BatchSize3_MultiTime_DvppFail) { + pixel_format_ = "YUV420SP"; + to_crop_flag_ = true; + SetJpegLib(128, 128); // 32*32 ~ 4096*4096 + crop_para_.output_width = 24; // align to 32 + crop_para_.output_height = 24; + crop_para_.crop_info.crop_type = kDvppCropTypeCentre; + crop_para_.crop_info.crop_width = 10; + crop_para_.crop_info.crop_height = 6; + + CreateDvppConfig(); + InitModelDesc(3); // batch_size=3 + + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile(model_file_path_, model_id) == SUCCESS); + + for (int i = 0; i < 3; i++) { + mock_dvpp_process_.crop_fail_list_.push_back(false); + // create inputs + PredictRequest request; + CreateDefaultRequest(request, i + 1); // image size + + PredictReply reply; + ServingRequest serving_request(request); + ServingImagesRequest serving_images(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_images, serving_request, serving_reply) == SUCCESS); + } + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); + EXPECT_EQ(g_acl_dvpp_process->decode_call_times_, 3); +}; + +} // namespace serving +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_session_test_model_load.cc b/tests/ut/cpp/serving/acl_session_test_model_load.cc new file mode 100644 index 0000000000..7fe406fd55 --- /dev/null +++ b/tests/ut/cpp/serving/acl_session_test_model_load.cc @@ -0,0 +1,342 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl_session_test_common.h" + +using namespace std; + +namespace mindspore { +namespace serving { + +class MockFailAclDeviceContextStream : public AclDeviceContextStream { + public: + aclError aclrtSetDevice(int32_t deviceId) override { + if (set_device_fail_list_.empty()) { + return AclDeviceContextStream::aclrtSetDevice(deviceId); + } + auto val = set_device_fail_list_.front(); + set_device_fail_list_.erase(set_device_fail_list_.begin()); + if (val) { + return AclDeviceContextStream::aclrtSetDevice(deviceId); + } + return 1; + } + + aclError aclrtResetDevice(int32_t deviceId) override { + auto ret = AclDeviceContextStream::aclrtResetDevice(deviceId); + if (ret != ACL_ERROR_NONE) { + return ret; + } + if (reset_device_fail_list_.empty()) { + return ret; + } + auto val = reset_device_fail_list_.front(); + reset_device_fail_list_.erase(reset_device_fail_list_.begin()); + return val ? ACL_ERROR_NONE : 1; + } + + aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId) override { + if (create_context_fail_list_.empty()) { + return AclDeviceContextStream::aclrtCreateContext(context, deviceId); + } + auto val = create_context_fail_list_.front(); + create_context_fail_list_.erase(create_context_fail_list_.begin()); + if (val) { + return AclDeviceContextStream::aclrtCreateContext(context, deviceId); + } + return 1; + } + + aclError aclrtDestroyContext(aclrtContext context) override { + auto ret = AclDeviceContextStream::aclrtDestroyContext(context); + if (ret != ACL_ERROR_NONE) { + return ret; + } + if (destroy_context_fail_list_.empty()) { + return ret; + } + auto val = destroy_context_fail_list_.front(); + destroy_context_fail_list_.erase(destroy_context_fail_list_.begin()); + return val ? ACL_ERROR_NONE : 1; + } + + aclError aclrtCreateStream(aclrtStream *stream) override { + if (create_stream_fail_list_.empty()) { + return AclDeviceContextStream::aclrtCreateStream(stream); + } + auto val = create_stream_fail_list_.front(); + create_stream_fail_list_.erase(create_stream_fail_list_.begin()); + if (val) { + return AclDeviceContextStream::aclrtCreateStream(stream); + } + return 1; + } + + aclError aclrtDestroyStream(aclrtStream stream) override { + auto ret = AclDeviceContextStream::aclrtDestroyStream(stream); + if (ret != ACL_ERROR_NONE) { + return ret; + } + if (destroy_stream_fail_list_.empty()) { + return ret; + } + auto val = destroy_stream_fail_list_.front(); + destroy_stream_fail_list_.erase(destroy_stream_fail_list_.begin()); + return val ? ACL_ERROR_NONE : 1; + } + std::vector set_device_fail_list_; + std::vector reset_device_fail_list_; + std::vector create_context_fail_list_; + std::vector destroy_context_fail_list_; + std::vector create_stream_fail_list_; + std::vector destroy_stream_fail_list_; +}; + +class MockFailAclMemory : public AclMemory { + public: + aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) override { + if (device_mem_fail_list_.empty()) { + return AclMemory::aclrtMalloc(devPtr, size, policy); + } + auto val = device_mem_fail_list_.front(); + device_mem_fail_list_.erase(device_mem_fail_list_.begin()); + if (val) { + return AclMemory::aclrtMalloc(devPtr, size, policy); + } + return 1; + } + aclError aclrtMallocHost(void **hostPtr, size_t size) override { + if (host_mem_fail_list_.empty()) { + return AclMemory::aclrtMallocHost(hostPtr, size); + } + auto val = host_mem_fail_list_.front(); + host_mem_fail_list_.erase(host_mem_fail_list_.begin()); + if (val) { + return AclMemory::aclrtMallocHost(hostPtr, size); + } + return 1; + } + aclError acldvppMalloc(void **devPtr, size_t size) override { + if (dvpp_mem_fail_list_.empty()) { + return AclMemory::acldvppMalloc(devPtr, size); + } + auto val = dvpp_mem_fail_list_.front(); + dvpp_mem_fail_list_.erase(dvpp_mem_fail_list_.begin()); + if (val) { + return AclMemory::acldvppMalloc(devPtr, size); + } + return 1; + } + + std::vector device_mem_fail_list_; + std::vector host_mem_fail_list_; + std::vector dvpp_mem_fail_list_; +}; + +class AclSessionModelLoadTest : public AclSessionTest { + public: + AclSessionModelLoadTest() = default; + void SetUp() override { + AclSessionTest::SetUp(); + aclmdlDesc model_desc; + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + model_desc.outputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + model_desc.outputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + mock_model_desc_ = MockModelDesc(model_desc); + g_acl_model_desc = &mock_model_desc_; + g_acl_device_context_stream = &fail_acl_device_context_stream_; + g_acl_memory = &fail_acl_memory_; + } + void CreateDefaultRequest(PredictRequest &request) { + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + auto input1 = request.add_data(); + CreateTensor(*input1, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + } + + void CheckDefaultReply(const PredictReply &reply) { + EXPECT_TRUE(reply.result().size() == 2); + if (reply.result().size() == 2) { + CheckTensorItem(reply.result(0), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + CheckTensorItem(reply.result(1), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + } + } + MockModelDesc mock_model_desc_; + /* Test Resource will be release on something wrong happens*/ + MockFailAclDeviceContextStream fail_acl_device_context_stream_; + MockFailAclMemory fail_acl_memory_; +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_OneTime_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_SetDeviceFail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + fail_acl_device_context_stream_.set_device_fail_list_.push_back(false); + EXPECT_FALSE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_CreateContextFail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + fail_acl_device_context_stream_.create_context_fail_list_.push_back(false); + EXPECT_FALSE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_CreateStreamFail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + fail_acl_device_context_stream_.create_stream_fail_list_.push_back(false); + EXPECT_FALSE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_ResetDeviceFail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + fail_acl_device_context_stream_.reset_device_fail_list_.push_back(false); + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + acl_session.FinalizeEnv(); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_DestroyContextFail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + fail_acl_device_context_stream_.destroy_context_fail_list_.push_back(false); + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + acl_session.FinalizeEnv(); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_DestroyStreamFail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + fail_acl_device_context_stream_.destroy_stream_fail_list_.push_back(false); + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + acl_session.FinalizeEnv(); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail0_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + fail_acl_memory_.device_mem_fail_list_.push_back(false); // input0 buffer + EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail1_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + fail_acl_memory_.device_mem_fail_list_.push_back(true); // input0 buffer + fail_acl_memory_.device_mem_fail_list_.push_back(false); // input1 buffer + EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail2_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + fail_acl_memory_.device_mem_fail_list_.push_back(true); // input0 buffer + fail_acl_memory_.device_mem_fail_list_.push_back(true); // input1 buffer + fail_acl_memory_.device_mem_fail_list_.push_back(false); // output0 buffer + EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail3_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + fail_acl_memory_.device_mem_fail_list_.push_back(true); // input0 buffer + fail_acl_memory_.device_mem_fail_list_.push_back(true); // input1 buffer + fail_acl_memory_.device_mem_fail_list_.push_back(true); // output0 buffer + fail_acl_memory_.device_mem_fail_list_.push_back(false); // output1 buffer + EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_RunOnDevice_MallocFail0_Success) { + SetDeviceRunMode(); + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + fail_acl_memory_.host_mem_fail_list_.push_back(false); // output0 buffer + EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionModelLoadTest, TestAclSession_RunOnDevice_MallocFail1_Success) { + SetDeviceRunMode(); + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + fail_acl_memory_.host_mem_fail_list_.push_back(true); // output0 buffer + fail_acl_memory_.host_mem_fail_list_.push_back(false); // output1 buffer + EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +} // namespace serving +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_session_test_one_input_output.cc b/tests/ut/cpp/serving/acl_session_test_one_input_output.cc new file mode 100644 index 0000000000..0f04756f6b --- /dev/null +++ b/tests/ut/cpp/serving/acl_session_test_one_input_output.cc @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl_session_test_common.h" + +using namespace std; + +namespace mindspore { +namespace serving { + +class AclSessionOneInputOneOutputTest : public AclSessionTest { + public: + AclSessionOneInputOneOutputTest() = default; + void SetUp() override { + AclSessionTest::SetUp(); + aclmdlDesc model_desc; + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + model_desc.outputs.push_back( + AclTensorDesc{.dims = {2, 8, 8, 3}, .data_type = ACL_FLOAT, .size = 2 * 8 * 8 * 3 * sizeof(float)}); + mock_model_desc_ = MockModelDesc(model_desc); + g_acl_model_desc = &mock_model_desc_; + } + void CreateDefaultRequest(PredictRequest &request) { + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + } + + void CreateInvalidDataSizeRequest(PredictRequest &request) { + auto input0 = request.add_data(); + // data size invalid, not match model input required + CreateTensor(*input0, {2, 24, 24, 2}, ::ms_serving::DataType::MS_FLOAT32); + } + + void CheckDefaultReply(const PredictReply &reply) { + EXPECT_TRUE(reply.result().size() == 1); + if (reply.result().size() == 1) { + CheckTensorItem(reply.result(0), {2, 8, 8, 3}, ::ms_serving::DataType::MS_FLOAT32); + } + } + + MockModelDesc mock_model_desc_; +}; + +TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_OneTime_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_MutilTimes_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + for (int i = 0; i < 10; i++) { + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_InvalidDataSize_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_InvalidDataSize_MultiTimes_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + for (int i = 0; i < 10; i++) { + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +} // namespace serving +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_session_test_two_input_output.cc b/tests/ut/cpp/serving/acl_session_test_two_input_output.cc new file mode 100644 index 0000000000..08e150cb96 --- /dev/null +++ b/tests/ut/cpp/serving/acl_session_test_two_input_output.cc @@ -0,0 +1,226 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl_session_test_common.h" + +using namespace std; + +namespace mindspore { +namespace serving { + +class AclSessionTwoInputTwoOutputTest : public AclSessionTest { + public: + AclSessionTwoInputTwoOutputTest() = default; + void SetUp() override { + AclSessionTest::SetUp(); + aclmdlDesc model_desc; + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)}); + + model_desc.inputs.push_back( + AclTensorDesc{.dims = {2, 32}, .data_type = ACL_INT32, .size = 2 * 32 * sizeof(int32_t)}); + + model_desc.outputs.push_back( + AclTensorDesc{.dims = {2, 8, 8, 3}, .data_type = ACL_FLOAT, .size = 2 * 8 * 8 * 3 * sizeof(float)}); + + model_desc.outputs.push_back( + AclTensorDesc{.dims = {2, 1024}, .data_type = ACL_BOOL, .size = 2 * 1024 * sizeof(bool)}); + + mock_model_desc_ = MockModelDesc(model_desc); + g_acl_model_desc = &mock_model_desc_; + } + void CreateDefaultRequest(PredictRequest &request) { + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + auto input1 = request.add_data(); + CreateTensor(*input1, {2, 32}, ::ms_serving::DataType::MS_INT32); + } + + void CreateInvalidDataSizeRequest0(PredictRequest &request) { + auto input0 = request.add_data(); + // data size invalid, not match model input required + CreateTensor(*input0, {2, 24, 24, 2}, ::ms_serving::DataType::MS_FLOAT32); + + auto input1 = request.add_data(); + CreateTensor(*input1, {2, 32}, ::ms_serving::DataType::MS_INT32); + } + + void CreateInvalidDataSizeRequest1(PredictRequest &request) { + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + auto input1 = request.add_data(); + // data size invalid, not match model input required + CreateTensor(*input1, {2, 16}, ::ms_serving::DataType::MS_INT32); + } + + void CreateInvalidDataSizeRequestOneInput0(PredictRequest &request) { + // only has one input for input0 + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32); + } + + void CreateInvalidDataSizeRequestOneInput1(PredictRequest &request) { + // only has one input for input1 + auto input0 = request.add_data(); + CreateTensor(*input0, {2, 32}, ::ms_serving::DataType::MS_INT32); + } + + void CheckDefaultReply(const PredictReply &reply) { + EXPECT_TRUE(reply.result().size() == 2); + if (reply.result().size() == 2) { + CheckTensorItem(reply.result(0), {2, 8, 8, 3}, ::ms_serving::DataType::MS_FLOAT32); + CheckTensorItem(reply.result(1), {2, 1024}, ::ms_serving::DataType::MS_BOOL); + } + } + + MockModelDesc mock_model_desc_; +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_OneTime_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_MutilTimes_Success) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + for (int i = 0; i < 10; i++) { + // create inputs + PredictRequest request; + CreateDefaultRequest(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + CheckDefaultReply(reply); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_Input0_InvalidDataSize_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequest0(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_Input1_InvalidDataSize_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequest1(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_OnlyInput0_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequestOneInput0(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_OnlyInput1_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequestOneInput1(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_InvalidDataSize_MultiTimes_Fail) { + inference::AclSession acl_session; + uint32_t device_id = 1; + EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS); + uint32_t model_id = 0; + EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS); + for (int i = 0; i < 10; i++) { + // create inputs + PredictRequest request; + CreateInvalidDataSizeRequest0(request); + + PredictReply reply; + ServingRequest serving_request(request); + ServingReply serving_reply(reply); + EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS); + } + EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS); + EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS); +}; + +} // namespace serving +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_stub.cc b/tests/ut/cpp/serving/acl_stub.cc new file mode 100644 index 0000000000..f11829befd --- /dev/null +++ b/tests/ut/cpp/serving/acl_stub.cc @@ -0,0 +1,323 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl_stub.h" +#include + +AclDataBuffer *g_acl_data_buffer = nullptr; +AclEnv *g_acl_env = nullptr; +AclDataSet *g_acl_dataset = nullptr; +AclModel *g_acl_model = nullptr; +AclModelDesc *g_acl_model_desc = nullptr; +AclDeviceContextStream *g_acl_device_context_stream = nullptr; +AclMemory *g_acl_memory = nullptr; +AclDvppPicDesc *g_acl_dvpp_pic_desc = nullptr; +AclDvppRoiConfig *g_acl_dvpp_roi_config = nullptr; +AclDvppResizeConfig *g_acl_dvpp_resize_config = nullptr; +AclDvppChannelDesc *g_acl_dvpp_channel_desc = nullptr; +AclDvppProcess *g_acl_dvpp_process = nullptr; +AclRunMode *g_acl_run_mode = nullptr; +AclJpegLib *g_acl_jpeg_lib = nullptr; + +aclDataBuffer *aclCreateDataBuffer(void *data, size_t size) { + return g_acl_data_buffer->aclCreateDataBuffer(data, size); +} + +aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffer) { + return g_acl_data_buffer->aclDestroyDataBuffer(dataBuffer); +} + +void *aclGetDataBufferAddr(const aclDataBuffer *dataBuffer) { + return g_acl_data_buffer->aclGetDataBufferAddr(dataBuffer); +} + +uint32_t aclGetDataBufferSize(const aclDataBuffer *dataBuffer) { + return g_acl_data_buffer->aclGetDataBufferSize(dataBuffer); +} + +size_t aclDataTypeSize(aclDataType dataType) { + std::unordered_map dataTypeMap = { + {ACL_FLOAT16, 2}, {ACL_FLOAT, 4}, {ACL_DOUBLE, 8}, {ACL_INT8, 1}, {ACL_INT16, 2}, {ACL_INT32, 4}, + {ACL_INT64, 8}, {ACL_UINT8, 1}, {ACL_UINT16, 2}, {ACL_UINT32, 4}, {ACL_UINT64, 8}, {ACL_BOOL, 1}, + }; + auto it = dataTypeMap.find(dataType); + if (it == dataTypeMap.end()) { + return 0; + } else { + return it->second; + } +} + +void aclAppLog(aclLogLevel logLevel, const char *func, const char *file, uint32_t line, const char *fmt, ...) { + if (logLevel == ACL_ERROR) { + // std::cout << file << ":" << line << "," << func << ": " << fmt << std::endl; + } +} + +aclError aclInit(const char *configPath) { return g_acl_env->aclInit(configPath); } + +aclError aclFinalize() { return g_acl_env->aclFinalize(); } + +// dataset +aclmdlDataset *aclmdlCreateDataset() { return g_acl_dataset->aclmdlCreateDataset(); } + +aclError aclmdlDestroyDataset(const aclmdlDataset *dataSet) { return g_acl_dataset->aclmdlDestroyDataset(dataSet); } + +aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataSet, aclDataBuffer *dataBuffer) { + return g_acl_dataset->aclmdlAddDatasetBuffer(dataSet, dataBuffer); +} + +size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *dataSet) { + return g_acl_dataset->aclmdlGetDatasetNumBuffers(dataSet); +} + +aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataSet, size_t index) { + return g_acl_dataset->aclmdlGetDatasetBuffer(dataSet, index); +} + +// model +aclError aclmdlLoadFromFile(const char *modelPath, uint32_t *modelId) { + return g_acl_model->aclmdlLoadFromFile(modelPath, modelId); +} + +aclError aclmdlLoadFromMem(const void *model, size_t modelSize, uint32_t *modelId) { + return g_acl_model->aclmdlLoadFromMem(model, modelSize, modelId); +} + +aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, size_t workSize, + void *weightPtr, size_t weightSize) { + return g_acl_model->aclmdlLoadFromFileWithMem(modelPath, modelId, workPtr, workSize, weightPtr, weightSize); +} + +aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, uint32_t *modelId, void *workPtr, + size_t workSize, void *weightPtr, size_t weightSize) { + return g_acl_model->aclmdlLoadFromMemWithMem(model, modelSize, modelId, workPtr, workSize, weightPtr, weightSize); +} + +aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) { + return g_acl_model->aclmdlExecute(modelId, input, output); +} + +aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, aclrtStream stream) { + return g_acl_model->aclmdlExecuteAsync(modelId, input, output, stream); +} + +aclError aclmdlUnload(uint32_t modelId) { return g_acl_model->aclmdlUnload(modelId); } + +// model desc +aclmdlDesc *aclmdlCreateDesc() { return g_acl_model_desc->aclmdlCreateDesc(); } + +aclError aclmdlDestroyDesc(aclmdlDesc *modelDesc) { return g_acl_model_desc->aclmdlDestroyDesc(modelDesc); } + +aclError aclmdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId) { + return g_acl_model_desc->aclmdlGetDesc(modelDesc, modelId); +} + +size_t aclmdlGetNumInputs(aclmdlDesc *modelDesc) { return g_acl_model_desc->aclmdlGetNumInputs(modelDesc); } + +size_t aclmdlGetNumOutputs(aclmdlDesc *modelDesc) { return g_acl_model_desc->aclmdlGetNumOutputs(modelDesc); } + +size_t aclmdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { + return g_acl_model_desc->aclmdlGetInputSizeByIndex(modelDesc, index); +} + +size_t aclmdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { + return g_acl_model_desc->aclmdlGetOutputSizeByIndex(modelDesc, index); +} + +aclError aclmdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + return g_acl_model_desc->aclmdlGetInputDims(modelDesc, index, dims); +} + +aclError aclmdlGetOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + return g_acl_model_desc->aclmdlGetOutputDims(modelDesc, index, dims); +} + +aclError aclmdlGetCurOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + return g_acl_model_desc->aclmdlGetCurOutputDims(modelDesc, index, dims); +} + +aclFormat aclmdlGetInputFormat(const aclmdlDesc *modelDesc, size_t index) { + return g_acl_model_desc->aclmdlGetInputFormat(modelDesc, index); +} + +aclFormat aclmdlGetOutputFormat(const aclmdlDesc *modelDesc, size_t index) { + return g_acl_model_desc->aclmdlGetOutputFormat(modelDesc, index); +} + +aclDataType aclmdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index) { + return g_acl_model_desc->aclmdlGetInputDataType(modelDesc, index); +} + +aclDataType aclmdlGetOutputDataType(const aclmdlDesc *modelDesc, size_t index) { + return g_acl_model_desc->aclmdlGetOutputDataType(modelDesc, index); +} + +// device, context, stream + +aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId) { + return g_acl_device_context_stream->aclrtCreateContext(context, deviceId); +} + +aclError aclrtDestroyContext(aclrtContext context) { return g_acl_device_context_stream->aclrtDestroyContext(context); } + +aclError aclrtSetCurrentContext(aclrtContext context) { + return g_acl_device_context_stream->aclrtSetCurrentContext(context); +} + +aclError aclrtSetDevice(int32_t deviceId) { return g_acl_device_context_stream->aclrtSetDevice(deviceId); } + +aclError aclrtResetDevice(int32_t deviceId) { return g_acl_device_context_stream->aclrtResetDevice(deviceId); } + +aclError aclrtGetRunMode(aclrtRunMode *runMode) { return g_acl_run_mode->aclrtGetRunMode(runMode); } + +aclError aclrtCreateStream(aclrtStream *stream) { return g_acl_device_context_stream->aclrtCreateStream(stream); } + +aclError aclrtDestroyStream(aclrtStream stream) { return g_acl_device_context_stream->aclrtDestroyStream(stream); } + +aclError aclrtSynchronizeStream(aclrtStream stream) { + return g_acl_device_context_stream->aclrtSynchronizeStream(stream); +} + +// memory +aclError acldvppMalloc(void **devPtr, size_t size) { return g_acl_memory->acldvppMalloc(devPtr, size); } +aclError acldvppFree(void *devPtr) { return g_acl_memory->acldvppFree(devPtr); } + +aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { + return g_acl_memory->aclrtMalloc(devPtr, size, policy); +} + +aclError aclrtFree(void *devPtr) { return g_acl_memory->aclrtFree(devPtr); } + +aclError aclrtMallocHost(void **hostPtr, size_t size) { return g_acl_memory->aclrtMallocHost(hostPtr, size); } + +aclError aclrtFreeHost(void *hostPtr) { return g_acl_memory->aclrtFreeHost(hostPtr); } + +aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind) { + return g_acl_memory->aclrtMemcpy(dst, destMax, src, count, kind); +} + +acldvppPicDesc *acldvppCreatePicDesc() { return g_acl_dvpp_pic_desc->acldvppCreatePicDesc(); } +aclError acldvppDestroyPicDesc(acldvppPicDesc *picDesc) { return g_acl_dvpp_pic_desc->acldvppDestroyPicDesc(picDesc); } + +aclError acldvppSetPicDescSize(acldvppPicDesc *picDesc, uint32_t size) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescSize(picDesc, size); +} + +aclError acldvppSetPicDescFormat(acldvppPicDesc *picDesc, acldvppPixelFormat format) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescFormat(picDesc, format); +} + +aclError acldvppSetPicDescWidth(acldvppPicDesc *picDesc, uint32_t width) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescWidth(picDesc, width); +} + +aclError acldvppSetPicDescHeight(acldvppPicDesc *picDesc, uint32_t height) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescHeight(picDesc, height); +} + +aclError acldvppSetPicDescData(acldvppPicDesc *picDesc, void *dataDev) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescData(picDesc, dataDev); +} + +aclError acldvppSetPicDescWidthStride(acldvppPicDesc *picDesc, uint32_t widthStride) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescWidthStride(picDesc, widthStride); +} + +aclError acldvppSetPicDescHeightStride(acldvppPicDesc *picDesc, uint32_t heightStride) { + return g_acl_dvpp_pic_desc->acldvppSetPicDescHeightStride(picDesc, heightStride); +} + +acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom) { + return g_acl_dvpp_roi_config->acldvppCreateRoiConfig(left, right, top, bottom); +} + +aclError acldvppDestroyRoiConfig(acldvppRoiConfig *roiConfig) { + return g_acl_dvpp_roi_config->acldvppDestroyRoiConfig(roiConfig); +} + +aclError acldvppSetRoiConfig(acldvppRoiConfig *roiConfig, uint32_t left, uint32_t right, uint32_t top, + uint32_t bottom) { + return g_acl_dvpp_roi_config->acldvppSetRoiConfig(roiConfig, left, right, top, bottom); +} + +acldvppResizeConfig *acldvppCreateResizeConfig() { return g_acl_dvpp_resize_config->acldvppCreateResizeConfig(); } + +aclError acldvppDestroyResizeConfig(acldvppResizeConfig *resizeConfig) { + return g_acl_dvpp_resize_config->acldvppDestroyResizeConfig(resizeConfig); +} + +aclError acldvppCreateChannel(acldvppChannelDesc *channelDesc) { + return g_acl_dvpp_channel_desc->acldvppCreateChannel(channelDesc); +} + +aclError acldvppDestroyChannel(acldvppChannelDesc *channelDesc) { + return g_acl_dvpp_channel_desc->acldvppDestroyChannel(channelDesc); +} + +acldvppChannelDesc *acldvppCreateChannelDesc() { return g_acl_dvpp_channel_desc->acldvppCreateChannelDesc(); } + +aclError acldvppDestroyChannelDesc(acldvppChannelDesc *channelDesc) { + return g_acl_dvpp_channel_desc->acldvppDestroyChannelDesc(channelDesc); +} + +aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, + acldvppResizeConfig *resizeConfig, aclrtStream stream) { + return g_acl_dvpp_process->acldvppVpcResizeAsync(channelDesc, inputDesc, outputDesc, resizeConfig, stream); +} + +aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, + acldvppRoiConfig *cropArea, aclrtStream stream) { + return g_acl_dvpp_process->acldvppVpcCropAsync(channelDesc, inputDesc, outputDesc, cropArea, stream); +} + +aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, + acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, + acldvppRoiConfig *pasteArea, aclrtStream stream) { + return g_acl_dvpp_process->acldvppVpcCropAndPasteAsync(channelDesc, inputDesc, outputDesc, cropArea, pasteArea, + stream); +} + +aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchDesc, uint32_t *roiNums, + uint32_t size, acldvppBatchPicDesc *dstBatchDesc, acldvppRoiConfig *cropAreas[], + aclrtStream stream) { + return g_acl_dvpp_process->acldvppVpcBatchCropAsync(channelDesc, srcBatchDesc, roiNums, size, dstBatchDesc, cropAreas, + stream); +} + +aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size, + acldvppPicDesc *outputDesc, aclrtStream stream) { + return g_acl_dvpp_process->acldvppJpegDecodeAsync(channelDesc, data, size, outputDesc, stream); +} + +// jpeg lib +void jpeg_CreateDecompress(j_decompress_ptr cinfo, int version, size_t structsize) { + g_acl_jpeg_lib->jpeg_CreateDecompress(cinfo, version, structsize); +} + +void jpeg_mem_src(j_decompress_ptr cinfo, const unsigned char *inbuffer, unsigned long insize) { + g_acl_jpeg_lib->jpeg_mem_src(cinfo, inbuffer, insize); +} + +int jpeg_read_header(j_decompress_ptr cinfo, boolean require_image) { + return g_acl_jpeg_lib->jpeg_read_header(cinfo, require_image); +} + +void jpeg_destroy_decompress(j_decompress_ptr cinfo) { g_acl_jpeg_lib->jpeg_destroy_decompress(cinfo); } + +struct jpeg_error_mgr *jpeg_std_error(struct jpeg_error_mgr *err) { + return err; +} \ No newline at end of file diff --git a/tests/ut/cpp/serving/acl_stub.h b/tests/ut/cpp/serving/acl_stub.h new file mode 100644 index 0000000000..3113d27bfc --- /dev/null +++ b/tests/ut/cpp/serving/acl_stub.h @@ -0,0 +1,857 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ACL_STUB_H +#define MINDSPORE_ACL_STUB_H + +#include "acl/acl_base.h" +#include "acl/acl.h" +#include "acl/acl_mdl.h" +#include "acl/acl_rt.h" +#include "acl/ops/acl_dvpp.h" +#include +#include +#include +#include +#include +#include +#include +#include "jpeglib.h" + +struct aclDataBuffer { + void *data = nullptr; + size_t size = 0; +}; + +struct aclmdlDataset { + std::vector data_buffers; +}; + +struct aclTensorDesc {}; + +struct AclTensorDesc { + std::vector dims; + aclDataType data_type = ACL_DT_UNDEFINED; + size_t size = 0; +}; + +struct aclmdlDesc { + std::vector inputs; + std::vector outputs; +}; + +struct acldvppPicDesc { + uint32_t size = 0; + acldvppPixelFormat format = PIXEL_FORMAT_YUV_400; + uint32_t width = 0; + uint32_t height = 0; + void *dataDev = nullptr; + uint32_t widthStride = 0; + uint32_t heightStride = 0; +}; + +struct acldvppRoiConfig { + uint32_t left = 0; + uint32_t right = 0; + uint32_t top = 0; + uint32_t bottom = 0; +}; + +struct acldvppResizeConfig { + uint32_t id; +}; + +struct acldvppChannelDesc { + bool channel_valid_flag = false; +}; + +class AclModel; +extern AclModel *g_acl_model; + +template +aclError AclItemOnDestroy( + std::vector &live, std::vector &destroy, const Type *destroy_item, + std::function func_release = [](Type &list_item) {}) { + for (auto it = live.begin(); it != live.end(); it++) { + if (&(*it) == destroy_item) { + func_release(*it); + destroy.push_back(*it); + live.erase(it); + return ACL_ERROR_NONE; + } + } + return 1; +} + +template ::value, int>::type = 0> +class ResourceBase { + public: + using Type = typename std::remove_pointer::type; + ResourceBase() = default; + virtual ~ResourceBase() { Clear(); } + void Clear() { + for (auto item : resource_live_) { + delete item; + } + resource_live_.clear(); + resource_destroy_.clear(); + } + template + Type *OnCreate(Args &&... args) { + auto item = new Type(std::forward(args)...); + resource_live_.push_back(item); + return item; + } + aclError OnDestroy( + const Type *item, std::function func_release = [](Type &list_item) {}) { + auto it = std::find(resource_live_.begin(), resource_live_.end(), item); + if (it == resource_live_.end()) { + return 1; + } + func_release(**it); // Type& + resource_destroy_.push_back(*it); // Type* + resource_live_.erase(it); + delete item; + return ACL_ERROR_NONE; + } + size_t LiveSize() const { return resource_live_.size(); } + bool Check() const { return resource_live_.empty(); } + std::vector resource_live_; + std::vector resource_destroy_; +}; + +class AclDataBuffer { + public: + AclDataBuffer() {} + virtual ~AclDataBuffer() { Clear(); } + virtual void Clear() { data_buffer_.Clear(); } + bool Check() { return data_buffer_.Check(); } + + virtual aclDataBuffer *aclCreateDataBuffer(void *data, size_t size) { + aclDataBuffer data_buffer; + data_buffer.data = data; + data_buffer.size = size; + return data_buffer_.OnCreate(data_buffer); + } + + virtual aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffer) { return data_buffer_.OnDestroy(dataBuffer); } + + virtual void *aclGetDataBufferAddr(const aclDataBuffer *dataBuffer) { + if (dataBuffer == nullptr) { + return nullptr; + } + return dataBuffer->data; + } + + virtual uint32_t aclGetDataBufferSize(const aclDataBuffer *dataBuffer) { + if (dataBuffer == nullptr) { + return 0; + } + return dataBuffer->size; + } + ResourceBase data_buffer_; +}; + +class AclDataSet { + public: + AclDataSet() {} + virtual ~AclDataSet() { Clear(); } + virtual void Clear() { dataset_.Clear(); } + bool Check() { return dataset_.Check(); } + + public: + virtual aclmdlDataset *aclmdlCreateDataset() { return dataset_.OnCreate(); } + virtual aclError aclmdlDestroyDataset(const aclmdlDataset *dataSet) { return dataset_.OnDestroy(dataSet); } + virtual aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataSet, aclDataBuffer *dataBuffer) { + if (dataSet == nullptr) { + return 1; + } + dataSet->data_buffers.push_back(dataBuffer); + return ACL_ERROR_NONE; + } + virtual size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *dataSet) { + if (dataSet == nullptr) { + return 0; + } + return dataSet->data_buffers.size(); + } + virtual aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataSet, size_t index) { + if (dataSet == nullptr || index >= dataSet->data_buffers.size()) { + return nullptr; + } + return dataSet->data_buffers[index]; + } + ResourceBase dataset_; +}; + +class AclEnv { + public: + virtual aclError aclInit(const char *configPath) { + is_init = true; + return ACL_ERROR_NONE; + } + virtual aclError aclFinalize() { + is_init = false; + return ACL_ERROR_NONE; + } + bool Check() { return is_init == false; } + bool is_init = false; +}; + +class AclModel { + public: + bool Check() { return model_live_.empty(); } + virtual aclError aclmdlLoadFromFile(const char *modelPath, uint32_t *modelId) { + model_live_.push_back(cur_max_model_id_); + *modelId = cur_max_model_id_; + cur_max_model_id_++; + return ACL_ERROR_NONE; + } + + virtual aclError aclmdlLoadFromMem(const void *model, size_t modelSize, uint32_t *modelId) { + return aclmdlLoadFromFile("fake_path", modelId); + } + + virtual aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, size_t workSize, + void *weightPtr, size_t weightSize) { + return aclmdlLoadFromFile(modelPath, modelId); + } + + virtual aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, uint32_t *modelId, void *workPtr, + size_t workSize, void *weightPtr, size_t weightSize) { + return aclmdlLoadFromMem(model, modelSize, modelId); + } + + virtual aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) { + if (std::find(model_live_.begin(), model_live_.end(), modelId) == model_live_.end()) { + return 1; + } + if (input == nullptr || output == nullptr) { + return false; + } + // auto& model_desc = model_live_[modelId]; + return ACL_ERROR_NONE; + } + + virtual aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, + aclrtStream stream) { + return ACL_ERROR_NONE; + } + + virtual aclError aclmdlUnload(uint32_t modelId) { + auto it = std::find(model_live_.begin(), model_live_.end(), modelId); + if (it == model_live_.end()) { + return 1; + } + model_live_.erase(it); + model_destroy_.push_back(modelId); + return ACL_ERROR_NONE; + } + uint32_t cur_max_model_id_ = 0; + std::vector model_live_; + std::vector model_destroy_; +}; + +class AclModelDesc { + public: + AclModelDesc() {} + virtual ~AclModelDesc() { Clear(); } + virtual void Clear() { model_desc_.Clear(); } + bool Check() { return model_desc_.Check(); } + + public: + virtual aclmdlDesc *aclmdlCreateDesc() { return model_desc_.OnCreate(); } + aclError aclmdlDestroyDesc(aclmdlDesc *modelDesc) { return model_desc_.OnDestroy(modelDesc); } + + aclError aclmdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId) { + auto &model_live = g_acl_model->model_live_; + auto it = std::find(model_live.begin(), model_live.end(), modelId); + if (it == model_live.end()) { + return 1; + } + return ACL_ERROR_NONE; + } + + size_t aclmdlGetNumInputs(aclmdlDesc *modelDesc) { return modelDesc->inputs.size(); } + + size_t aclmdlGetNumOutputs(aclmdlDesc *modelDesc) { return modelDesc->outputs.size(); } + + size_t aclmdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { return modelDesc->inputs[index].size; } + + size_t aclmdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { return modelDesc->outputs[index].size; } + + aclError aclmdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + auto &input = modelDesc->inputs[index]; + dims->dimCount = input.dims.size(); + for (size_t i = 0; i < dims->dimCount; i++) { + dims->dims[i] = input.dims[i]; + } + return ACL_ERROR_NONE; + } + + aclError aclmdlGetOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + auto &input = modelDesc->outputs[index]; + dims->dimCount = input.dims.size(); + for (size_t i = 0; i < dims->dimCount; i++) { + dims->dims[i] = input.dims[i]; + } + return ACL_ERROR_NONE; + } + + aclError aclmdlGetCurOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + return aclmdlGetOutputDims(modelDesc, index, dims); + } + + aclFormat aclmdlGetInputFormat(const aclmdlDesc *modelDesc, size_t index) { return ACL_FORMAT_NCHW; } + aclFormat aclmdlGetOutputFormat(const aclmdlDesc *modelDesc, size_t index) { return ACL_FORMAT_NCHW; } + + aclDataType aclmdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index) { + return modelDesc->inputs[index].data_type; + } + + aclDataType aclmdlGetOutputDataType(const aclmdlDesc *modelDesc, size_t index) { + return modelDesc->outputs[index].data_type; + } + + ResourceBase model_desc_; +}; + +class AclRunMode { + public: + virtual aclError aclrtGetRunMode(aclrtRunMode *runMode) { + *runMode = aclrtRunMode::ACL_HOST; + return ACL_ERROR_NONE; + } +}; + +class AclDeviceContextStream { + public: + AclDeviceContextStream() {} + ~AclDeviceContextStream() { Clear(); } + virtual void Clear() { + for (auto context : context_live_) { + delete (int *)context; + } + context_live_.clear(); + context_destroy_.clear(); + device_id_live_.clear(); + device_id_destroy_.clear(); + for (auto item : stream_live_) { + delete (int *)item; + } + stream_live_.clear(); + stream_destroy_.clear(); + } + bool Check() { return context_live_.empty() && device_id_live_.empty() && stream_live_.empty(); } + virtual aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId) { + context_live_.push_back(new int()); + *context = context_live_.back(); + return ACL_ERROR_NONE; + } + virtual aclError aclrtDestroyContext(aclrtContext context) { + for (auto it = context_live_.begin(); it != context_live_.end(); ++it) { + if (*it == context) { + context_live_.erase(it); + context_destroy_.push_back(context); + delete (int *)context; + return ACL_ERROR_NONE; + } + } + return 1; + } + aclError aclrtSetCurrentContext(aclrtContext context) { return ACL_ERROR_NONE; } + aclError aclrtGetCurrentContext(aclrtContext *context) { return ACL_ERROR_NONE; } + virtual aclError aclrtSetDevice(int32_t deviceId) { + device_id_live_.push_back(deviceId); + return ACL_ERROR_NONE; + } + virtual aclError aclrtResetDevice(int32_t deviceId) { + for (auto it = device_id_live_.begin(); it != device_id_live_.end(); ++it) { + if (*it == deviceId) { + device_id_live_.erase(it); + device_id_destroy_.push_back(deviceId); + return ACL_ERROR_NONE; + } + } + return 1; + } + aclError aclrtGetDevice(int32_t *deviceId) { + *deviceId = 0; + return ACL_ERROR_NONE; + } + aclError aclrtSynchronizeDevice(void) { return ACL_ERROR_NONE; } + aclError aclrtSetTsDevice(aclrtTsId tsId) { return ACL_ERROR_NONE; } + aclError aclrtGetDeviceCount(uint32_t *count) { + *count = 1; + return ACL_ERROR_NONE; + } + virtual aclError aclrtCreateStream(aclrtStream *stream) { + stream_live_.push_back(new int()); + *stream = stream_live_.back(); + return ACL_ERROR_NONE; + } + virtual aclError aclrtDestroyStream(aclrtStream stream) { + for (auto it = stream_live_.begin(); it != context_live_.end(); ++it) { + if (*it == stream) { + stream_live_.erase(it); + stream_destroy_.push_back(stream); + delete (int *)stream; + return ACL_ERROR_NONE; + } + } + return 1; + } + aclError aclrtSynchronizeStream(aclrtStream stream) { + for (auto it = stream_live_.begin(); it != context_live_.end(); ++it) { + if (*it == stream) { + return ACL_ERROR_NONE; + } + } + return 1; + } + std::vector device_id_live_; + std::vector device_id_destroy_; + std::vector context_live_; + std::vector context_destroy_; + std::vector stream_live_; + std::vector stream_destroy_; +}; + +class AclMemory { + public: + AclMemory() {} + ~AclMemory() { Clear(); } + void Clear() { + for (auto item : device_buffer_live_) { + delete[] item; + } + for (auto item : host_buffer_live_) { + delete[] item; + } + for (auto item : dvpp_buffer_live_) { + delete[] item; + } + device_buffer_live_.clear(); + device_buffer_destroy_.clear(); + host_buffer_live_.clear(); + host_buffer_destroy_.clear(); + dvpp_buffer_live_.clear(); + dvpp_buffer_destroy_.clear(); + } + bool Check() { return device_buffer_live_.empty() && host_buffer_live_.empty() && dvpp_buffer_live_.empty(); } + virtual aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { + auto buffer = new uint8_t[size]; + *devPtr = buffer; + device_buffer_live_.push_back(buffer); + memory_len_[buffer] = size; + return ACL_ERROR_NONE; + } + aclError aclrtFree(void *devPtr) { + auto it = std::find(device_buffer_live_.begin(), device_buffer_live_.end(), devPtr); + if (it != device_buffer_live_.end()) { + delete[](*it); + device_buffer_live_.erase(it); + device_buffer_destroy_.push_back(*it); + return ACL_ERROR_NONE; + } + return 1; + } + + virtual aclError aclrtMallocHost(void **hostPtr, size_t size) { + auto buffer = new uint8_t[size]; + *hostPtr = buffer; + host_buffer_live_.push_back(buffer); + memory_len_[buffer] = size; + return ACL_ERROR_NONE; + } + + aclError aclrtFreeHost(void *hostPtr) { + auto it = std::find(host_buffer_live_.begin(), host_buffer_live_.end(), hostPtr); + if (it != host_buffer_live_.end()) { + delete[](*it); + host_buffer_live_.erase(it); + host_buffer_destroy_.push_back(*it); + return ACL_ERROR_NONE; + } + return 1; + } + + aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind) { + auto is_device_memory = [this](const void *memory, uint32_t use_size) { + for (auto it = device_buffer_live_.begin(); it != device_buffer_live_.end(); it++) { + auto size = memory_len_[*it]; + if (memory >= *it && static_cast(memory) + use_size <= (*it) + size) { + return true; + } + } + for (auto it = dvpp_buffer_live_.begin(); it != dvpp_buffer_live_.end(); it++) { + auto size = memory_len_[*it]; + if (memory >= *it && static_cast(memory) + use_size <= (*it) + size) { + return true; + } + } + return false; + }; + if (kind == ACL_MEMCPY_HOST_TO_HOST) { + if (is_device_memory(dst, destMax) || is_device_memory(src, count)) { + return 1; + } + } else if (kind == ACL_MEMCPY_HOST_TO_DEVICE) { + if (!is_device_memory(dst, destMax) || is_device_memory(src, count)) { + return 1; + } + } else if (kind == ACL_MEMCPY_DEVICE_TO_HOST) { + if (is_device_memory(dst, destMax) || !is_device_memory(src, count)) { + return 1; + } + } else if (kind == ACL_MEMCPY_DEVICE_TO_DEVICE) { + if (!is_device_memory(dst, destMax) || !is_device_memory(src, count)) { + return 1; + } + } else { + return 1; + } + memcpy(dst, src, count); + return ACL_ERROR_NONE; + } + + virtual aclError acldvppMalloc(void **devPtr, size_t size) { + auto buffer = new uint8_t[size]; + *devPtr = buffer; + dvpp_buffer_live_.push_back(buffer); + memory_len_[buffer] = size; + return ACL_ERROR_NONE; + } + aclError acldvppFree(void *devPtr) { + auto it = std::find(dvpp_buffer_live_.begin(), dvpp_buffer_live_.end(), devPtr); + if (it != dvpp_buffer_live_.end()) { + delete[](*it); + dvpp_buffer_live_.erase(it); + dvpp_buffer_destroy_.push_back(*it); + return ACL_ERROR_NONE; + } + return 1; + } + + std::vector device_buffer_live_; + std::vector device_buffer_destroy_; + std::vector host_buffer_live_; + std::vector host_buffer_destroy_; + std::vector dvpp_buffer_live_; + std::vector dvpp_buffer_destroy_; + std::map memory_len_; +}; + +class AclDvppPicDesc { + public: + bool Check() { return pic_desc_.Check(); } + acldvppPicDesc *acldvppCreatePicDesc() { return pic_desc_.OnCreate(); } + + aclError acldvppDestroyPicDesc(acldvppPicDesc *picDesc) { return pic_desc_.OnDestroy(picDesc); } + + aclError acldvppSetPicDescSize(acldvppPicDesc *picDesc, uint32_t size) { + picDesc->size = size; + return ACL_ERROR_NONE; + } + + aclError acldvppSetPicDescFormat(acldvppPicDesc *picDesc, acldvppPixelFormat format) { + picDesc->format = format; + return ACL_ERROR_NONE; + } + + aclError acldvppSetPicDescWidth(acldvppPicDesc *picDesc, uint32_t width) { + picDesc->width = width; + return ACL_ERROR_NONE; + } + + aclError acldvppSetPicDescHeight(acldvppPicDesc *picDesc, uint32_t height) { + picDesc->height = height; + return ACL_ERROR_NONE; + } + + aclError acldvppSetPicDescData(acldvppPicDesc *picDesc, void *dataDev) { + picDesc->dataDev = dataDev; + return ACL_ERROR_NONE; + } + + aclError acldvppSetPicDescWidthStride(acldvppPicDesc *picDesc, uint32_t widthStride) { + picDesc->widthStride = widthStride; + return ACL_ERROR_NONE; + } + + aclError acldvppSetPicDescHeightStride(acldvppPicDesc *picDesc, uint32_t heightStride) { + picDesc->heightStride = heightStride; + return ACL_ERROR_NONE; + } + ResourceBase pic_desc_; +}; + +class AclDvppRoiConfig { + public: + bool Check() { return roi_config_.Check(); } + acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom) { + return roi_config_.OnCreate(acldvppRoiConfig{.left = left, .right = right, .top = top, .bottom = bottom}); + } + + aclError acldvppDestroyRoiConfig(acldvppRoiConfig *roiConfig) { return roi_config_.OnDestroy(roiConfig); } + + aclError acldvppSetRoiConfig(acldvppRoiConfig *roiConfig, uint32_t left, uint32_t right, uint32_t top, + uint32_t bottom) { + roiConfig->left = left; + roiConfig->right = right; + roiConfig->top = top; + roiConfig->bottom = bottom; + return ACL_ERROR_NONE; + } + ResourceBase roi_config_; +}; + +class AclDvppResizeConfig { + public: + bool Check() { return resize_config_.Check(); } + acldvppResizeConfig *acldvppCreateResizeConfig() { return resize_config_.OnCreate(acldvppResizeConfig{}); } + + aclError acldvppDestroyResizeConfig(acldvppResizeConfig *resizeConfig) { + return resize_config_.OnDestroy(resizeConfig); + } + ResourceBase resize_config_; +}; + +class AclDvppChannelDesc { + public: + bool Check() { return channel_desc_.Check(); } + aclError acldvppCreateChannel(acldvppChannelDesc *channelDesc) { + channelDesc->channel_valid_flag = true; + return ACL_ERROR_NONE; + } + aclError acldvppDestroyChannel(acldvppChannelDesc *channelDesc) { + channelDesc->channel_valid_flag = false; + return ACL_ERROR_NONE; + } + acldvppChannelDesc *acldvppCreateChannelDesc() { return channel_desc_.OnCreate(); } + aclError acldvppDestroyChannelDesc(acldvppChannelDesc *channelDesc) { + if (channelDesc->channel_valid_flag) { + return 1; + } + return channel_desc_.OnDestroy(channelDesc); + } + ResourceBase channel_desc_; +}; + +class AclDvppProcess { + public: + bool Check() { return true; } + virtual aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, + acldvppPicDesc *outputDesc, acldvppResizeConfig *resizeConfig, + aclrtStream stream) { + resize_call_times_++; + if (channelDesc == nullptr || inputDesc == nullptr || outputDesc == nullptr || resizeConfig == nullptr || + stream == nullptr) { + return 1; + } + if (CheckPicDesc(inputDesc) != ACL_ERROR_NONE) { + return 1; + } + if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) { + return 1; + } + return ACL_ERROR_NONE; + } + + virtual aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, + acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, aclrtStream stream) { + crop_call_times_++; + if (channelDesc == nullptr || inputDesc == nullptr || outputDesc == nullptr || cropArea == nullptr || + stream == nullptr) { + return 1; + } + if (CheckPicDesc(inputDesc) != ACL_ERROR_NONE) { + return 1; + } + if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) { + return 1; + } + if (CheckCropArea(cropArea) != ACL_ERROR_NONE) { + return 1; + } + return ACL_ERROR_NONE; + } + + virtual aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, + acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, + acldvppRoiConfig *pasteArea, aclrtStream stream) { + crop_paste_call_times_++; + if (channelDesc == nullptr || inputDesc == nullptr || outputDesc == nullptr || cropArea == nullptr || + pasteArea == nullptr || stream == nullptr) { + return 1; + } + if (CheckPicDesc(inputDesc) != ACL_ERROR_NONE) { + return 1; + } + if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) { + return 1; + } + if (CheckCropArea(cropArea) != ACL_ERROR_NONE) { + return 1; + } + if (CheckCropArea(pasteArea) != ACL_ERROR_NONE) { + return 1; + } + return ACL_ERROR_NONE; + } + + aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchDesc, + uint32_t *roiNums, uint32_t size, acldvppBatchPicDesc *dstBatchDesc, + acldvppRoiConfig *cropAreas[], aclrtStream stream) { + return ACL_ERROR_NONE; + } + + virtual aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size, + acldvppPicDesc *outputDesc, aclrtStream stream) { + decode_call_times_++; + if (channelDesc == nullptr || data == nullptr || size == 0 || outputDesc == nullptr || stream == nullptr) { + return 1; + } + if (outputDesc->widthStride % 128 != 0) { + return 1; + } + if (outputDesc->heightStride % 16 != 0) { + return 1; + } + if (outputDesc->widthStride < 32 || outputDesc->widthStride > 8192) { + return 1; + } + if (outputDesc->heightStride < 32 || outputDesc->heightStride > 8192) { + return 1; + } + if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) { + return 1; + } + return ACL_ERROR_NONE; + } + aclError CheckCropArea(acldvppRoiConfig *crop_area) { + if (crop_area->left % 2 != 0 || crop_area->top % 2 != 0) { + return 1; + } + if (crop_area->right % 2 != 1 || crop_area->bottom % 2 != 1) { + return 1; + } + auto crop_width = crop_area->right - crop_area->left + 1; + if (crop_width < 10 || crop_width > 4096) { + return 1; + } + auto crop_heigth = crop_area->bottom - crop_area->top + 1; + if (crop_heigth < 6 || crop_heigth > 4096) { + return 1; + } + return ACL_ERROR_NONE; + } + aclError CheckPicDesc(acldvppPicDesc *pic_desc) { + if (pic_desc->width == 0 || pic_desc->height == 0) { + return 1; + } + if (pic_desc->widthStride % 16 != 0 || pic_desc->widthStride < pic_desc->width) { + return 1; + } + if (pic_desc->heightStride % 2 != 0 || pic_desc->heightStride < pic_desc->height) { + return 1; + } + if (pic_desc->widthStride < 32 || pic_desc->widthStride > 4096) { + return 1; + } + if (pic_desc->heightStride < 6 || pic_desc->heightStride > 4096) { + return 1; + } + if (pic_desc->dataDev == nullptr) { + return 1; + } + auto size = pic_desc->size; + auto ele_cnt = pic_desc->widthStride * pic_desc->heightStride; + switch (pic_desc->format) { + case PIXEL_FORMAT_YUV_SEMIPLANAR_420: + case PIXEL_FORMAT_YVU_SEMIPLANAR_420: + if (ele_cnt * 3 / 2 != size) { + return 1; + } + break; + case PIXEL_FORMAT_YUV_SEMIPLANAR_422: + case PIXEL_FORMAT_YVU_SEMIPLANAR_422: + if (ele_cnt * 2 != size) { + return 1; + } + break; + case PIXEL_FORMAT_YUV_SEMIPLANAR_444: + case PIXEL_FORMAT_YVU_SEMIPLANAR_444: + if (ele_cnt * 3 != size) { + return 1; + } + break; + default: + return 1; + } + return ACL_ERROR_NONE; + } + uint32_t decode_call_times_ = 0; + uint32_t resize_call_times_ = 0; + uint32_t crop_call_times_ = 0; + uint32_t crop_paste_call_times_ = 0; +}; + +class AclJpegLib { + public: + bool Check() { return jpeg_live_.empty(); } + AclJpegLib(uint32_t width, uint32_t height) : image_width_(width), image_height_(height) {} + + void jpeg_CreateDecompress(j_decompress_ptr cinfo, int version, size_t structsize) { jpeg_live_.push_back(cinfo); } + void jpeg_mem_src(j_decompress_ptr cinfo, const unsigned char *inbuffer, unsigned long insize) {} + int jpeg_read_header(j_decompress_ptr cinfo, boolean require_image) { + static JHUFF_TBL tal; + cinfo->image_width = image_width_; + cinfo->image_height = image_height_; + cinfo->jpeg_color_space = color_space_; + for (int i = 0; i < NUM_HUFF_TBLS; i++) { + cinfo->ac_huff_tbl_ptrs[i] = &tal; + cinfo->dc_huff_tbl_ptrs[i] = &tal; + } + return 0; + } + void jpeg_destroy_decompress(j_decompress_ptr cinfo) { + auto it = std::find(jpeg_live_.begin(), jpeg_live_.end(), cinfo); + if (it != jpeg_live_.end()) { + jpeg_live_.erase(it); + } + } + uint32_t image_width_; + uint32_t image_height_; + J_COLOR_SPACE color_space_ = JCS_YCbCr; + std::vector jpeg_live_; +}; + +extern AclDataBuffer *g_acl_data_buffer; +extern AclEnv *g_acl_env; +extern AclDataSet *g_acl_dataset; +extern AclModelDesc *g_acl_model_desc; +extern AclDeviceContextStream *g_acl_device_context_stream; +extern AclMemory *g_acl_memory; +extern AclDvppPicDesc *g_acl_dvpp_pic_desc; +extern AclDvppRoiConfig *g_acl_dvpp_roi_config; +extern AclDvppResizeConfig *g_acl_dvpp_resize_config; +extern AclDvppChannelDesc *g_acl_dvpp_channel_desc; +extern AclDvppProcess *g_acl_dvpp_process; +extern AclRunMode *g_acl_run_mode; +extern AclJpegLib *g_acl_jpeg_lib; + +#endif // MINDSPORE_ACL_STUB_H diff --git a/tests/ut/cpp/serving/ms_service.proto b/tests/ut/cpp/serving/ms_service.proto new file mode 120000 index 0000000000..dd846c368f --- /dev/null +++ b/tests/ut/cpp/serving/ms_service.proto @@ -0,0 +1 @@ +../../../../serving/ms_service.proto \ No newline at end of file diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index ac38e5427e..1ddc5d7b14 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -766,7 +766,7 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) { auto kernel_graph = std::make_shared(); auto parameter_node = kernel_graph->add_parameter(); MS_EXCEPTION_IF_NULL(parameter_node); - auto param_value_new = std::make_shared(); + auto param_value_new = std::make_shared(int64_t(0), kInt32); parameter_node->set_default_param(param_value_new); EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node)); EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error); diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc index f24036b4aa..3cf7189a08 100644 --- a/tests/ut/cpp/session/kernel_graph_test.cc +++ b/tests/ut/cpp/session/kernel_graph_test.cc @@ -60,7 +60,7 @@ TEST_F(KernelGraphTest, NewParameter) { auto anf_graph = std::make_shared(); auto kernel_graph = std::make_shared(); // test nullptr as input - auto new_paramter = kernel_graph->NewParameter(nullptr); + auto new_paramter = kernel_graph->NewParameter(); EXPECT_NE(new_paramter, nullptr); EXPECT_TRUE(new_paramter->isa()); EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT); @@ -82,7 +82,7 @@ TEST_F(KernelGraphTest, NewParameter) { // test weight parameter node as input auto weight_parameter_node = anf_graph->add_parameter(); MS_EXCEPTION_IF_NULL(weight_parameter_node); - auto param_value_new = std::make_shared(); + auto param_value_new = std::make_shared(kNumberTypeFloat32, shape); weight_parameter_node->set_default_param(param_value_new); weight_parameter_node->set_abstract(x_abstract); auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node); diff --git a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc index 234ffdaf6b..e2502baf8b 100644 --- a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc +++ b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc @@ -38,6 +38,10 @@ bool ModelRunner::RunModel(uint32_t model_id, const ge::InputData &input_data, g return true; } +void *ModelRunner::GetModelHandle(uint32_t model_id) const { return nullptr; } + +bool ModelRunner::DistributeTask(uint32_t model_id) { return true; } + const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const { static std::vector task_id_list; return task_id_list; diff --git a/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc b/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc index 87ab543c7c..0b91bbdd35 100755 --- a/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc +++ b/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc @@ -15,7 +15,7 @@ */ #include "backend/kernel_compiler/kernel_fusion.h" #include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" -#include "common/utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace kernel { diff --git a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc index f6f2f45092..6ae883cfbd 100644 --- a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc +++ b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc @@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { return SUCCESS; } +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, + ManualShapeMap *manual_shape_map) { return SUCCESS; } } // namespace parallel } // namespace mindspore diff --git a/tests/ut/cpp/stub/tasksink/task_sink_stub.cc b/tests/ut/cpp/stub/tasksink/task_sink_stub.cc index 0b12a3862c..4f909ec2d5 100644 --- a/tests/ut/cpp/stub/tasksink/task_sink_stub.cc +++ b/tests/ut/cpp/stub/tasksink/task_sink_stub.cc @@ -15,6 +15,7 @@ */ #include "runtime/device/ascend/tasksink/task_generator.h" +#include "runtime/device/ascend/dump/data_dumper.h" namespace mindspore { namespace device { @@ -25,6 +26,11 @@ bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::ve return true; } } // namespace tasksink +void DataDumper::LoadDumpInfo() {} +void DataDumper::UnloadDumpInfo() {} +void DataDumper::OpDebugRegister() {} +void DataDumper::OpDebugUnregister() {} +DataDumper::~DataDumper() {} } // namespace ascend } // namespace device } // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/transform/graph_runner_test.cc b/tests/ut/cpp/transform/graph_runner_test.cc index b91ec959d2..fed34b1c62 100644 --- a/tests/ut/cpp/transform/graph_runner_test.cc +++ b/tests/ut/cpp/transform/graph_runner_test.cc @@ -18,7 +18,7 @@ #include #include "common/common_test.h" #include "ir/dtype.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" #include "transform/transform_base_test.h" #include "common/py_func_graph_fetcher.h" #include "pipeline/jit/static_analysis/static_analysis.h" diff --git a/tests/ut/cpp/transform/transform_base_test.cc b/tests/ut/cpp/transform/transform_base_test.cc index 50227bc53c..86c4e7b8e0 100644 --- a/tests/ut/cpp/transform/transform_base_test.cc +++ b/tests/ut/cpp/transform/transform_base_test.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" #include "transform/transform_base_test.h" -#include "ir/tensor_py.h" +#include "utils/tensor_py.h" using mindspore::tensor::TensorPy; diff --git a/tests/ut/cpp/utils/baseref_test.cc b/tests/ut/cpp/utils/baseref_test.cc index 4e1556d819..4a5a267d9d 100644 --- a/tests/ut/cpp/utils/baseref_test.cc +++ b/tests/ut/cpp/utils/baseref_test.cc @@ -19,7 +19,7 @@ #include "common/common_test.h" #include "ir/anf.h" -#include "utils/base_ref.h" +#include "base/base_ref.h" namespace mindspore { namespace utils { diff --git a/tests/ut/cpp/utils/graph_utils_test.cc b/tests/ut/cpp/utils/graph_utils_test.cc index 35fa9cdc6a..60bc3f75ec 100644 --- a/tests/ut/cpp/utils/graph_utils_test.cc +++ b/tests/ut/cpp/utils/graph_utils_test.cc @@ -22,8 +22,8 @@ #include "common/py_func_graph_fetcher.h" #include "ir/anf.h" -#include "utils/graph_utils.h" - +#include "ir/graph_utils.h" +#include "utils/convert_utils.h" #include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/parse/parse.h" diff --git a/tests/ut/cpp/vm/segment_runner_test.cc b/tests/ut/cpp/vm/segment_runner_test.cc index c83b1b3434..22d3e6857d 100644 --- a/tests/ut/cpp/vm/segment_runner_test.cc +++ b/tests/ut/cpp/vm/segment_runner_test.cc @@ -21,7 +21,7 @@ #include "utils/log_adapter.h" #include "ir/func_graph_cloner.h" #include "pipeline/jit/parse/parse.h" -#include "utils/graph_utils.h" +#include "ir/graph_utils.h" #include "pipeline/jit/resource.h" #include "debug/draw.h" #include "frontend/operator/ops.h" @@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { std::vector todos(splits.size()); auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef& seg) -> bool { return utils::isa(seg); }); + [](const BaseRef &seg) -> bool { return utils::isa(seg); }); todos.resize(std::distance(todos.begin(), it)); ASSERT_EQ(todos.size(), 1); - AnfNodePtrList anf_list; + AnfNodePtrList anf_list; for (auto &item : utils::cast(todos[0])) { anf_list.push_back(utils::cast(item)); } @@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { std::vector todos(splits.size()); auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef& seg) -> bool { return utils::isa(seg); }); + [](const BaseRef &seg) -> bool { return utils::isa(seg); }); todos.resize(std::distance(todos.begin(), it)); ASSERT_EQ(todos.size(), 1); - AnfNodePtrList anf_list; + AnfNodePtrList anf_list; for (auto &item : utils::cast(todos[0])) { anf_list.push_back(utils::cast(item)); } @@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) { std::vector todos(splits.size()); auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef& seg) -> bool { return utils::isa(seg); }); + [](const BaseRef &seg) -> bool { return utils::isa(seg); }); todos.resize(std::distance(todos.begin(), it)); ASSERT_EQ(todos.size(), 1); - AnfNodePtrList anf_list; + AnfNodePtrList anf_list; for (auto &item : utils::cast(todos[0])) { anf_list.push_back(utils::cast(item)); } @@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) { TEST_F(TestCompileSegmentRunner, test_RunOperation1) { VectorRef args({1}); - auto res = RunOperation(prim::kPrimIdentity, args); + auto res = RunOperation(std::make_shared(py::str(prim::kPrimIdentity->name()), py::none()), args); ASSERT_EQ(py::cast(BaseRefToPyData(res)), 1); } TEST_F(TestCompileSegmentRunner, test_RunOperation2) { VectorRef args({1, 2}); - auto res = RunOperation(prim::kPrimScalarGt, args); + auto res = RunOperation(std::make_shared(py::str(prim::kPrimScalarGt->name()), py::none()), args); ASSERT_EQ(py::cast(BaseRefToPyData(res)), false); } } // namespace compile diff --git a/tests/ut/data/dataset/golden/autcontrast_01_result_c.npz b/tests/ut/data/dataset/golden/autcontrast_01_result_c.npz new file mode 100644 index 0000000000..062fc9a48b Binary files /dev/null and b/tests/ut/data/dataset/golden/autcontrast_01_result_c.npz differ diff --git a/tests/ut/data/dataset/golden/autcontrast_01_result_py.npz b/tests/ut/data/dataset/golden/autcontrast_01_result_py.npz new file mode 100644 index 0000000000..252aee514b Binary files /dev/null and b/tests/ut/data/dataset/golden/autcontrast_01_result_py.npz differ diff --git a/tests/ut/data/dataset/golden/batch_01_result.npz b/tests/ut/data/dataset/golden/batch_01_result.npz index 7da040cd58..b2dd3bd71e 100644 Binary files a/tests/ut/data/dataset/golden/batch_01_result.npz and b/tests/ut/data/dataset/golden/batch_01_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_02_result.npz b/tests/ut/data/dataset/golden/batch_02_result.npz index 1d126c043e..671e5161c7 100644 Binary files a/tests/ut/data/dataset/golden/batch_02_result.npz and b/tests/ut/data/dataset/golden/batch_02_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_03_result.npz b/tests/ut/data/dataset/golden/batch_03_result.npz index 3bb486428e..3d4601cdaf 100644 Binary files a/tests/ut/data/dataset/golden/batch_03_result.npz and b/tests/ut/data/dataset/golden/batch_03_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_04_result.npz b/tests/ut/data/dataset/golden/batch_04_result.npz index 39198c5692..aed34bf1e7 100644 Binary files a/tests/ut/data/dataset/golden/batch_04_result.npz and b/tests/ut/data/dataset/golden/batch_04_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_05_result.npz b/tests/ut/data/dataset/golden/batch_05_result.npz index 24ab9b0836..865b99825c 100644 Binary files a/tests/ut/data/dataset/golden/batch_05_result.npz and b/tests/ut/data/dataset/golden/batch_05_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_06_result.npz b/tests/ut/data/dataset/golden/batch_06_result.npz index 6e8e923eb9..5b1f3e7971 100644 Binary files a/tests/ut/data/dataset/golden/batch_06_result.npz and b/tests/ut/data/dataset/golden/batch_06_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_07_result.npz b/tests/ut/data/dataset/golden/batch_07_result.npz index b25854e179..c5fca2c73a 100644 Binary files a/tests/ut/data/dataset/golden/batch_07_result.npz and b/tests/ut/data/dataset/golden/batch_07_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_08_result.npz b/tests/ut/data/dataset/golden/batch_08_result.npz index b02f0eb324..27fa114d57 100644 Binary files a/tests/ut/data/dataset/golden/batch_08_result.npz and b/tests/ut/data/dataset/golden/batch_08_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_09_result.npz b/tests/ut/data/dataset/golden/batch_09_result.npz index 6e8e923eb9..5b1f3e7971 100644 Binary files a/tests/ut/data/dataset/golden/batch_09_result.npz and b/tests/ut/data/dataset/golden/batch_09_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_10_result.npz b/tests/ut/data/dataset/golden/batch_10_result.npz index 5568b7e0cc..426e156780 100644 Binary files a/tests/ut/data/dataset/golden/batch_10_result.npz and b/tests/ut/data/dataset/golden/batch_10_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_11_result.npz b/tests/ut/data/dataset/golden/batch_11_result.npz index 2035b10c94..c7e8d6988a 100644 Binary files a/tests/ut/data/dataset/golden/batch_11_result.npz and b/tests/ut/data/dataset/golden/batch_11_result.npz differ diff --git a/tests/ut/data/dataset/golden/batch_12_result.npz b/tests/ut/data/dataset/golden/batch_12_result.npz index 24ab9b0836..865b99825c 100644 Binary files a/tests/ut/data/dataset/golden/batch_12_result.npz and b/tests/ut/data/dataset/golden/batch_12_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz index 14ddc166e2..92a3970641 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz index 07ae4e5892..602cd0d9d7 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz index a72643457b..a46d24f5dc 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz index 9a6ae1cb99..efd2abe4ed 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/equalize_01_result_c.npz b/tests/ut/data/dataset/golden/equalize_01_result_c.npz new file mode 100644 index 0000000000..2c3a37eb4d Binary files /dev/null and b/tests/ut/data/dataset/golden/equalize_01_result_c.npz differ diff --git a/tests/ut/data/dataset/golden/invert_01_result_c.npz b/tests/ut/data/dataset/golden/invert_01_result_c.npz new file mode 100644 index 0000000000..0a819192d2 Binary files /dev/null and b/tests/ut/data/dataset/golden/invert_01_result_c.npz differ diff --git a/tests/ut/data/dataset/golden/invert_01_result.npz b/tests/ut/data/dataset/golden/invert_01_result_py.npz similarity index 100% rename from tests/ut/data/dataset/golden/invert_01_result.npz rename to tests/ut/data/dataset/golden/invert_01_result_py.npz diff --git a/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz index bb33f1bece..864ff47dbb 100644 Binary files a/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz index 416223ff4d..b256b3b8f7 100644 Binary files a/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz index 75f4447ded..8d6fa8f41f 100644 Binary files a/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz and b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz index aa9778bd39..95425f55a9 100644 Binary files a/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz index e0e0eb2823..2efdcd5c3b 100644 Binary files a/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/resize_01_result.npz b/tests/ut/data/dataset/golden/resize_01_result.npz new file mode 100644 index 0000000000..b3a52243a4 Binary files /dev/null and b/tests/ut/data/dataset/golden/resize_01_result.npz differ diff --git a/tests/ut/data/dataset/golden/resize_02_result.npz b/tests/ut/data/dataset/golden/resize_02_result.npz new file mode 100644 index 0000000000..b3a52243a4 Binary files /dev/null and b/tests/ut/data/dataset/golden/resize_02_result.npz differ diff --git a/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz b/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz index ca64884937..71c6a36a99 100644 Binary files a/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz and b/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz differ diff --git a/tests/ut/data/dataset/golden/shuffle_01_result.npz b/tests/ut/data/dataset/golden/shuffle_01_result.npz index 467ea74a4c..589afc1271 100644 Binary files a/tests/ut/data/dataset/golden/shuffle_01_result.npz and b/tests/ut/data/dataset/golden/shuffle_01_result.npz differ diff --git a/tests/ut/data/dataset/golden/shuffle_02_result.npz b/tests/ut/data/dataset/golden/shuffle_02_result.npz index 27eb0a470d..03540388d3 100644 Binary files a/tests/ut/data/dataset/golden/shuffle_02_result.npz and b/tests/ut/data/dataset/golden/shuffle_02_result.npz differ diff --git a/tests/ut/data/dataset/golden/shuffle_03_result.npz b/tests/ut/data/dataset/golden/shuffle_03_result.npz index 6a6e62f3ff..297b54d9ca 100644 Binary files a/tests/ut/data/dataset/golden/shuffle_03_result.npz and b/tests/ut/data/dataset/golden/shuffle_03_result.npz differ diff --git a/tests/ut/data/dataset/golden/shuffle_04_result.npz b/tests/ut/data/dataset/golden/shuffle_04_result.npz index a3b9469f9c..704cc82389 100644 Binary files a/tests/ut/data/dataset/golden/shuffle_04_result.npz and b/tests/ut/data/dataset/golden/shuffle_04_result.npz differ diff --git a/tests/ut/data/dataset/golden/shuffle_05_result.npz b/tests/ut/data/dataset/golden/shuffle_05_result.npz index 27eb0a470d..03540388d3 100644 Binary files a/tests/ut/data/dataset/golden/shuffle_05_result.npz and b/tests/ut/data/dataset/golden/shuffle_05_result.npz differ diff --git a/tests/ut/data/dataset/golden/test_2ops_batch_repeat.npz b/tests/ut/data/dataset/golden/test_2ops_batch_repeat.npz index b4346fd796..cba3a7fa01 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_batch_repeat.npz and b/tests/ut/data/dataset/golden/test_2ops_batch_repeat.npz differ diff --git a/tests/ut/data/dataset/golden/test_2ops_batch_shuffle.npz b/tests/ut/data/dataset/golden/test_2ops_batch_shuffle.npz index e3273425d3..54ff4435e0 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_batch_shuffle.npz and b/tests/ut/data/dataset/golden/test_2ops_batch_shuffle.npz differ diff --git a/tests/ut/data/dataset/golden/test_2ops_repeat_batch.npz b/tests/ut/data/dataset/golden/test_2ops_repeat_batch.npz index 7a9a70861c..40b0489a59 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_repeat_batch.npz and b/tests/ut/data/dataset/golden/test_2ops_repeat_batch.npz differ diff --git a/tests/ut/data/dataset/golden/test_2ops_repeat_shuffle.npz b/tests/ut/data/dataset/golden/test_2ops_repeat_shuffle.npz index b0ab3fb798..2ae358d16b 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_repeat_shuffle.npz and b/tests/ut/data/dataset/golden/test_2ops_repeat_shuffle.npz differ diff --git a/tests/ut/data/dataset/golden/test_2ops_shuffle_batch.npz b/tests/ut/data/dataset/golden/test_2ops_shuffle_batch.npz index 579431fcc9..2939b29786 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_shuffle_batch.npz and b/tests/ut/data/dataset/golden/test_2ops_shuffle_batch.npz differ diff --git a/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz b/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz index 5c39c0e64b..6e2486024f 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz and b/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz differ diff --git a/tests/ut/data/dataset/testCSV/1.csv b/tests/ut/data/dataset/testCSV/1.csv new file mode 100644 index 0000000000..13fbfd70a1 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/1.csv @@ -0,0 +1,3 @@ +1,2,3,4 +5,6,7,8 +9,10,11,12 diff --git a/tests/ut/data/dataset/testCSV/2.csv b/tests/ut/data/dataset/testCSV/2.csv new file mode 100644 index 0000000000..b96a0a4ed3 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/2.csv @@ -0,0 +1,8 @@ +,"222",3,"4""" +"5",6,,"8" +9,10,"1""1",12 +,,"", +,,, + +a,b,c,"" +a,b,c,d diff --git a/tests/ut/data/dataset/testCSV/chinese.csv b/tests/ut/data/dataset/testCSV/chinese.csv new file mode 100644 index 0000000000..9445c52704 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/chinese.csv @@ -0,0 +1 @@ +大家,早上好,中午好,下午好,晚上好 diff --git a/tests/ut/data/dataset/testCSV/embedded.csv b/tests/ut/data/dataset/testCSV/embedded.csv new file mode 100644 index 0000000000..c7e10b0136 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/embedded.csv @@ -0,0 +1,2 @@ +"a,b","c""d","e +f"," g " diff --git a/tests/ut/data/dataset/testCSV/exception.csv b/tests/ut/data/dataset/testCSV/exception.csv new file mode 100644 index 0000000000..da5357efa5 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/exception.csv @@ -0,0 +1,3 @@ +1,2,3,4 +5,6,7,8 +a,"c",d,"e diff --git a/tests/ut/data/dataset/testCSV/header.csv b/tests/ut/data/dataset/testCSV/header.csv new file mode 100644 index 0000000000..bf14e15263 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/header.csv @@ -0,0 +1,2 @@ +col1,col2,col3,col4 +a,b,c,d \ No newline at end of file diff --git a/tests/ut/data/dataset/testCSV/number.csv b/tests/ut/data/dataset/testCSV/number.csv new file mode 100644 index 0000000000..2d3a7ec4c4 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/number.csv @@ -0,0 +1 @@ +3,0.3,4,55.5 diff --git a/tests/ut/data/dataset/testCSV/quoted.csv b/tests/ut/data/dataset/testCSV/quoted.csv new file mode 100644 index 0000000000..5391bb9cc7 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/quoted.csv @@ -0,0 +1 @@ +"a","b","c","d" diff --git a/tests/ut/data/dataset/testCSV/separated.csv b/tests/ut/data/dataset/testCSV/separated.csv new file mode 100644 index 0000000000..6a8e0ec28a --- /dev/null +++ b/tests/ut/data/dataset/testCSV/separated.csv @@ -0,0 +1 @@ +a|b|c|d diff --git a/tests/ut/data/dataset/testCSV/size.csv b/tests/ut/data/dataset/testCSV/size.csv new file mode 100644 index 0000000000..6ba3b2ba71 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/size.csv @@ -0,0 +1,10 @@ +1,2,3,4 +"a","b","c +","d +e" +5,6,7,8 +9,10,11,12 +a,"b +",c,"d +e" + diff --git a/tests/ut/data/dataset/testImageNetData2/dataDistributionAll.json b/tests/ut/data/dataset/testImageNetData2/dataDistributionAll.json deleted file mode 100644 index 3ebc4c989c..0000000000 --- a/tests/ut/data/dataset/testImageNetData2/dataDistributionAll.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "deviceNum":4, - "deviceId": 2, - "shardConfig":"ALL", - "shuffle":"ON", - "seed": 0, - "epoch": 2 -} diff --git a/tests/ut/data/dataset/testImageNetData2/dataDistributionRandom.json b/tests/ut/data/dataset/testImageNetData2/dataDistributionRandom.json deleted file mode 100644 index a0f468f91d..0000000000 --- a/tests/ut/data/dataset/testImageNetData2/dataDistributionRandom.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "deviceNum":4, - "deviceId": 2, - "shardConfig":"RANDOM", - "shuffle":"ON", - "seed": 0, - "epoch": 1 -} diff --git a/tests/ut/data/dataset/testImageNetData2/dataDistributionUnique.json b/tests/ut/data/dataset/testImageNetData2/dataDistributionUnique.json deleted file mode 100644 index a4eeddd9ae..0000000000 --- a/tests/ut/data/dataset/testImageNetData2/dataDistributionUnique.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "deviceNum":4, - "deviceId": 2, - "shardConfig":"UNIQUE", - "shuffle":"ON", - "seed": 0, - "epoch": 3 -} diff --git a/tests/ut/data/dataset/testPK/distribution.json b/tests/ut/data/dataset/testPK/distribution.json deleted file mode 100644 index 33f869f653..0000000000 --- a/tests/ut/data/dataset/testPK/distribution.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "deviceNum":1, - "deviceId": 0, - "shardConfig":"RANDOM", - "shuffle":"OFF", - "seed": 0 -} diff --git a/tests/ut/data/dataset/testTokenizerData/sentencepiece_tokenizer.txt b/tests/ut/data/dataset/testTokenizerData/sentencepiece_tokenizer.txt new file mode 100644 index 0000000000..ee7f25cba1 --- /dev/null +++ b/tests/ut/data/dataset/testTokenizerData/sentencepiece_tokenizer.txt @@ -0,0 +1 @@ +I saw a girl with a telescope. \ No newline at end of file diff --git a/tests/ut/data/dataset/test_sentencepiece/botchan.txt b/tests/ut/data/dataset/test_sentencepiece/botchan.txt new file mode 100644 index 0000000000..71e3c2f26f --- /dev/null +++ b/tests/ut/data/dataset/test_sentencepiece/botchan.txt @@ -0,0 +1,4288 @@ +Project Gutenberg's Botchan (Master Darling), by Kin-nosuke Natsume +This eBook is for the use of anyone anywhere at no cost and with +almost no restrictions whatsoever. You may copy it, give it away or +re-use it under the terms of the Project Gutenberg License included +with this eBook or online at www.gutenberg.org +Title: Botchan (Master Darling) +Author: Kin-nosuke Natsume +Translator: Yasotaro Morri +Posting Date: October 14, 2012 [EBook #8868] +Release Date: September, 2005 +First Posted: August 17, 2003 +Language: English +*** START OF THIS PROJECT GUTENBERG EBOOK BOTCHAN (MASTER DARLING) *** +Produced by David Starner and the Online Distributed Proofreading Team +BOTCHAN (MASTER DARLING) +By The Late Mr. Kin-nosuke Natsume +TRANSLATED By Yasotaro Morri +Revised by J. R. KENNEDY +1919 +A NOTE BY THE TRANSLATOR +No translation can expect to equal, much less to excel, the original. +The excellence of a translation can only be judged by noting how far it +has succeeded in reproducing the original tone, colors, style, the +delicacy of sentiment, the force of inert strength, the peculiar +expressions native to the language with which the original is written, +or whatever is its marked characteristic. The ablest can do no more, and +to want more than this will be demanding something impossible. Strictly +speaking, the only way one can derive full benefit or enjoyment from a +foreign work is to read the original, for any intelligence at +second-hand never gives the kind of satisfaction which is possible only +through the direct touch with the original. Even in the best translated +work is probably wanted the subtle vitality natural to the original +language, for it defies an attempt, however elaborate, to transmit all +there is in the original. Correctness of diction may be there, but +spontaneity is gone; it cannot be helped. +The task of the translator becomes doubly hazardous in case of +translating a European language into Japanese, or vice versa. Between +any of the European languages and Japanese there is no visible kinship +in word-form, significance, grammatical system, rhetorical arrangements. +It may be said that the inspiration of the two languages is totally +different. A want of similarity of customs, habits, traditions, national +sentiments and traits makes the work of translation all the more +difficult. A novel written in Japanese which had attained national +popularity might, when rendered into English, lose its captivating +vividness, alluring interest and lasting appeal to the reader. +These remarks are made not in way of excuse for any faulty dictions that +may be found in the following pages. Neither are they made out of +personal modesty nor of a desire to add undue weight to the present +work. They are made in the hope that whoever is good enough to go +through the present translation will remember, before he may venture to +make criticisms, the kind and extent of difficulties besetting him in +his attempts so as not to judge the merit of the original by this +translation. Nothing would afford the translator a greater pain than any +unfavorable comment on the original based upon this translation. If +there be any deserving merits in the following pages the credit is due +to the original. Any fault found in its interpretation or in the English +version, the whole responsibility is on the translator. +For the benefit of those who may not know the original, it must be +stated that "Botchan" by the late Mr. K. Natsume was an epoch-making +piece of work. On its first appearance, Mr. Natsume's place and name as +the foremost in the new literary school were firmly established. He had +written many other novels of more serious intent, of heavier thoughts +and of more enduring merits, but it was this "Botchan" that secured him +the lasting fame. Its quaint style, dash and vigor in its narration +appealed to the public who had become somewhat tired of the stereotyped +sort of manner with which all stories had come to be handled. +In its simplest understanding, "Botchan" may be taken as an episode in +the life of a son born in Tokyo, hot-blooded, simple-hearted, pure as +crystal and sturdy as a towering rock, honest and straight to a fault, +intolerant of the least injustice and a volunteer ever ready to champion +what he considers right and good. Children may read it as a "story of +man who tried to be honest." It is a light, amusing and, at the name +time, instructive story, with no tangle of love affairs, no scheme of +blood-curdling scenes or nothing startling or sensational in the plot or +characters. The story, however, may be regarded as a biting sarcasm on a +hypocritical society in which a gang of instructors of dark character at +a middle school in a backwoods town plays a prominent part. The hero of +the story is made a victim of their annoying intrigues, but finally +comes out triumphant by smashing the petty red tapism, knocking down the +sham pretentions and by actual use of the fist on the Head Instructor +and his henchman. +The story will be found equally entertaining as a means of studying the +peculiar traits of the native of Tokyo which are characterised by their +quick temper, dashing spirit, generosity and by their readiness to +resist even the lordly personage if convinced of their own justness, or +to kneel down even to a child if they acknowledge their own wrong. +Incidently the touching devotion of the old maid servant Kiyo to the +hero will prove a standing reproach to the inconstant, unfaithful +servants of which the number is ever increasing these days in Tokyo. The +story becomes doubly interesting by the fact that Mr. K. Natsume, when +quite young, held a position of teacher of English at a middle school +somewhere about the same part of the country described in the story, +while he himself was born and brought up in Tokyo. +It may be added that the original is written in an autobiographical +style. It is profusely interladed with spicy, catchy colloquials patent +to the people of Tokyo for the equals of which we may look to the +rattling speeches of notorious Chuck Conners of the Bowery of New York. +It should be frankly stated that much difficulty was experienced in +getting the corresponding terms in English for those catchy expressions. +Strictly speaking, some of them have no English equivalents. Care has +been exercised to select what has been thought most appropriate in the +judgment or the translator in converting those expressions into English +but some of them might provoke disapproval from those of the "cultured" +class with "refined" ears. The slangs in English in this translation +were taken from an American magazine of world-wide reputation editor of +which was not afraid to print of "damn" when necessary, by scorning the +timid, conventional way of putting it as "d--n." If the propriety of +printing such short ugly words be questioned, the translator is sorry to +say that no means now exists of directly bringing him to account for he +met untimely death on board the Lusitania when it was sunk by the German +submarine. +Thanks are due to Mr. J. R. Kennedy, General Manager, and Mr. Henry +Satoh, Editor-in-Chief, both of the Kokusai Tsushin-sha (the +International News Agency) of Tokyo and a host of personal friends of +the translator whose untiring assistance and kind suggestions have made +the present translation possible. Without their sympathetic interests, +this translation may not have seen the daylight. +Tokyo, September, 1918. +BOTCHAN (MASTER DARLING) +CHAPTER I +Because of an hereditary recklessness, I have been playing always a +losing game since my childhood. During my grammar school days, I was +once laid up for about a week by jumping from the second story of the +school building. Some may ask why I committed such a rash act. There was +no particular reason for doing such a thing except I happened to be +looking out into the yard from the second floor of the newly-built +school house, when one of my classmates, joking, shouted at me; "Say, +you big bluff, I'll bet you can't jump down from there! O, you +chicken-heart, ha, ha!" So I jumped down. The janitor of the school had +to carry me home on his back, and when my father saw me, he yelled +derisively, "What a fellow you are to go and get your bones dislocated +by jumping only from a second story!" +"I'll see I don't get dislocated next time," I answered. +One of my relatives once presented me with a pen-knife. I was showing it +to my friends, reflecting its pretty blades against the rays of the sun, +when one of them chimed in that the blades gleamed all right, but seemed +rather dull for cutting with. +"Rather dull? See if they don't cut!" I retorted. +"Cut your finger, then," he challenged. And with "Finger nothing! Here +goes!" I cut my thumb slant-wise. Fortunately the knife was small and +the bone of the thumb hard enough, so the thumb is still there, but the +scar will be there until my death. +About twenty steps to the east edge of our garden, there was a +moderate-sized vegetable yard, rising toward the south, and in the +centre of which stood a chestnut tree which was dearer to me than life. +In the season when the chestnuts were ripe, I used to slip out of the +house from the back door early in the morning to pick up the chestnuts +which had fallen during the night, and eat them at the school. On the +west side of the vegetable yard was the adjoining garden of a pawn shop +called Yamashiro-ya. This shopkeeper's son was a boy about 13 or 14 +years old named Kantaro. Kantaro was, it happens, a mollycoddle. +Nevertheless he had the temerity to come over the fence to our yard and +steal my chestnuts. +One certain evening I hid myself behind a folding-gate of the fence and +caught him in the act. Having his retreat cut off he grappled with me in +desperation. He was about two years older than I, and, though +weak-kneed, was physically the stronger. While I wallopped him, he +pushed his head against my breast and by chance it slipped inside my +sleeve. As this hindered the free action of my arm, I tried to shake him +loose, though, his head dangled the further inside, and being no longer +able to stand the stifling combat, he bit my bare arm. It was painful. I +held him fast against the fence, and by a dexterous foot twist sent him +down flat on his back. Kantaro broke the fence and as the ground +belonging to Yamashiro-ya was about six feet lower than the vegetable +yard, he fell headlong to his own territory with a thud. As he rolled +off he tore away the sleeve in which his head had been enwrapped, and my +arm recovered a sudden freedom of movement. That night when my mother +went to Yamashiro-ya to apologize, she brought back that sleeve. +Besides the above, I did many other mischiefs. With Kaneko of a +carpenter shop and Kaku of a fishmarket, I once ruined a carrot patch of +one Mosaku. The sprouts were just shooting out and the patch was covered +with straws to ensure their even healthy growth. Upon this straw-covered +patch, we three wrestled for fully half a day, and consequently +thoroughly smashed all the sprouts. Also I once filled up a well which +watered some rice fields owned by one Furukawa, and he followed me with +kicks. The well was so devised that from a large bamboo pole, sunk deep +into the ground, the water issued and irrigated the rice fields. +Ignorant of the mechanical side of this irrigating method at that time, +I stuffed the bamboo pole with stones and sticks, and satisfied that no +more water came up, I returned home and was eating supper when Furukawa, +fiery red with anger, burst into our house with howling protests. I +believe the affair was settled on our paying for the damage. +Father did not like me in the least, and mother always sided with my big +brother. This brother's face was palish white, and he had a fondness for +taking the part of an actress at the theatre. +"This fellow will never amount to much," father used to remark when +he saw me. +"He's so reckless that I worry about his future," I often heard mother +say of me. Exactly; I have never amounted to much. I am just as you see +me; no wonder my future used to cause anxiety to my mother. I am living +without becoming but a jailbird. +Two or three days previous to my mother's death, I took it into my head +to turn a somersault in the kitchen, and painfully hit my ribs against +the corner of the stove. Mother was very angry at this and told me not +to show my face again, so I went to a relative to stay with. While +there, I received the news that my mother's illness had become very +serious, and that after all efforts for her recovery, she was dead. I +came home thinking that I should have behaved better if I had known the +conditions were so serious as that. Then that big brother of mine +denounced me as wanting in filial piety, and that I had caused her +untimely death. Mortified at this, I slapped his face, and thereupon +received a sound scolding from father. +After the death of mother, I lived with father and brother. Father did +nothing, and always said "You're no good" to my face. What he meant by +"no good" I am yet to understand. A funny dad he was. My brother was to +be seen studying English hard, saying that he was going to be a +businessman. He was like a girl by nature, and so "sassy" that we two +were never on good terms, and had to fight it out about once every ten +days. When we played a chess game one day, he placed a chessman as a +"waiter,"--a cowardly tactic this,--and had hearty laugh on me by seeing +me in a fix. His manner was so trying that time that I banged a chessman +on his forehead which was injured a little bit and bled. He told all +about this to father, who said he would disinherit me. +Then I gave up myself for lost, and expected to be really disinherited. +But our maid Kiyo, who had been with us for ten years or so, interceded +on my behalf, and tearfully apologized for me, and by her appeal my +father's wrath was softened. I did not regard him, however, as one to be +afraid of in any way, but rather felt sorry for our Kiyo. I had heard +that Kiyo was of a decent, well-to-do family, but being driven to +poverty at the time of the Restoration, had to work as a servant. So she +was an old woman by this time. This old woman,--by what affinity, as +the Buddhists say, I don't know,--loved me a great deal. Strange, +indeed! She was almost blindly fond of me,--me, whom mother, became +thoroughly disgusted with three days before her death; whom father +considered a most aggravating proposition all the year round, and whom +the neighbors cordially hated as the local bully among the youngsters. I +had long reconciled myself to the fact that my nature was far from being +attractive to others, and so didn't mind if I were treated as a piece of +wood; so I thought it uncommon that Kiyo should pet me like that. +Sometimes in the kitchen, when there was nobody around, she would praise +me saying that I was straightforward and of a good disposition. What she +meant by that exactly, was not clear to me, however. If I were of so +good a nature as she said, I imagined those other than Kiyo should +accord me a better treatment. So whenever Kiyo said to me anything of +the kind, I used to answer that I did not like passing compliments. Then +she would remark; "That's the very reason I say you are of a good +disposition," and would gaze at me with absorbing tenderness. She seemed +to recreate me by her own imagination, and was proud of the fact. I felt +even chilled through my marrow at her constant attention to me. +After my mother was dead, Kiyo loved me still more. In my simple +reasoning, I wondered why she had taken such a fancy to me. Sometimes I +thought it quite futile on her part, that she had better quit that sort +of thing, which was bad for her. But she loved me just the same. Once +in, a while she would buy, out of her own pocket, some cakes or +sweetmeats for me. When the night was cold, she would secretly buy some +noodle powder, and bring all unawares hot noodle gruel to my bed; or +sometimes she would even buy a bowl of steaming noodles from the +peddler. Not only with edibles, but she was generous alike with socks, +pencils, note books, etc. And she even furnished me,--this happened some +time later,--with about three yen, I did not ask her for the money; she +offered it from her own good will by bringing it to my room, saying that +I might be in need of some cash. This, of course, embarrassed me, but as +she was so insistent I consented to borrow it. I confess I was really +glad of the money. I put it in a bag, and carried it in my pocket. While +about the house, I happened to drop the bag into a cesspool. Helpless, I +told Kiyo how I had lost the money, and at once she fetched a bamboo +stick, and said she will get it for me. After a while I heard a +splashing sound of water about our family well, and going there, saw +Kiyo washing the bag strung on the end of the stick. I opened the bag +and found the edict of the three one-yen bills turned to faint yellow +and designs fading. Kiyo dried them at an open fire and handed them over +to me, asking if they were all right. I smelled them and said; "They +stink yet." +"Give them to me; I'll get them changed." She took those three bills, +and,--I do not know how she went about it,--brought three yen in silver. +I forget now upon what I spent the three yen. "I'll pay you back soon," +I said at the time, but didn't. I could not now pay it back even if I +wished to do so with ten times the amount. +When Kiyo gave me anything she did so always when both father and +brother were out. Many things I do not like, but what I most detest is +the monopolizing of favors behind some one else's back. Bad as my +relations were with my brother, still I did not feel justified in +accepting candies or color-pencils from Kiyo without my brother's +knowledge. "Why do you give those things only to me and not to my +brother also?" I asked her once, and she answered quite unconcernedly +that my brother may be left to himself as his father bought him +everything. That was partiality; father was obstinate, but I am sure he +was not a man who would indulge in favoritism. To Kiyo, however, he +might have looked that way. There is no doubt that Kiyo was blind to the +extent of her undue indulgence with me. She was said to have come from a +well-to-do family, but the poor soul was uneducated, and it could not be +helped. All the same, you cannot tell how prejudice will drive one to +the extremes. Kiyo seemed quite sure that some day I would achieve high +position in society and become famous. Equally she was sure that my +brother, who was spending his hours studiously, was only good for his +white skin, and would stand no show in the future. Nothing can beat an +old woman for this sort of thing, I tell you. She firmly believed that +whoever she liked would become famous, while whoever she hated would +not. I did not have at that time any particular object in my life. But +the persistency with which Kiyo declared that I would be a great man +some day, made me speculate myself that after all I might become one. +How absurd it seems to me now when I recall those days. I asked her once +what kind of a man I should be, but she seemed to have formed no +concrete idea as to that; only she said that I was sure to live in a +house with grand entrance hall, and ride in a private rikisha. +And Kiyo seemed to have decided for herself to live with me when I +became independent and occupy my own house. "Please let me live with +you,"--she repeatedly asked of me. Feeling somewhat that I should +eventually be able to own a house, I answered her "Yes," as far as such +an answer went. This woman, by the way, was strongly imaginative. She +questioned me what place I liked,--Kojimachi-ku or Azabu-ku?--and +suggested that I should have a swing in our garden, that one room be +enough for European style, etc., planning everything to suit her own +fancy. I did not then care a straw for anything like a house; so neither +Japanese nor European style was much of use to me, and I told her to +that effect. Then she would praise me as uncovetous and clean of heart. +Whatever I said, she had praise for me. +I lived, after the death of mother, in this fashion for five or six +years. I had kicks from father, had rows with brother, and had candies +and praise from Kiyo. I cared for nothing more; I thought this was +enough. I imagined all other boys were leading about the same kind of +life. As Kiyo frequently told me, however, that I was to be pitied, and +was unfortunate, I imagined that that might be so. There was nothing +that particularly worried me except that father was too tight with my +pocket money, and this was rather hard on me. +In January of the 6th year after mother's death, father died of +apoplexy. In April of the same year, I graduated from a middle school, +and two months later, my brother graduated from a business college. Soon +he obtained a job in the Kyushu branch of a certain firm and had to go +there, while I had to remain in Tokyo and continue my study. He proposed +the sale of our house and the realization of our property, to which I +answered "Just as you like it." I had no intention of depending upon him +anyway. Even were he to look after me, I was sure of his starting +something which would eventually end in a smash-up as we were prone to +quarrel on the least pretext. It was because in order to receive his +protection that I should have to bow before such a fellow, that I +resolved that I would live by myself even if I had to do milk delivery. +Shortly afterwards he sent for a second-hand dealer and sold for a song +all the bric-a-bric which had been handed down from ages ago in our +family. Our house and lot were sold, through the efforts of a middleman +to a wealthy person. This transaction seemed to have netted a goodly sum +to him, but I know nothing as to the detail. +For one month previous to this, I had been rooming in a boarding house +in Kanda-ku, pending a decision as to my future course. Kiyo was greatly +grieved to see the house in which she had lived so many years change +ownership, but she was helpless in the matter. +"If you were a little older, you might have inherited this house," she +once remarked in earnest. +If I could have inherited the house through being a little older, I +ought to have been able to inherit the house right then. She knew +nothing, and believed the lack of age only prevented my coming into the +possession of the house. +Thus I parted from my brother, but the disposal of Kiyo was a difficult +proposition. My brother was, of course, unable to take her along, nor +was there any danger of her following him so far away as Kyushu, while I +was in a small room of a boarding house, and might have to clear out +anytime at that. There was no way out, so I asked her if she intended to +work somewhere else. Finally she answered me definitely that she would +go to her nephew's and wait until I started my own house and get +married. This nephew was a clerk in the Court of Justice, and being +fairly well off, had invited Kiyo before more than once to come and live +with him, but Kiyo preferred to stay with us, even as a servant, since +she had become well used to our family. But now I think she thought it +better to go over to her nephew than to start a new life as servant in a +strange house. Be that as it may, she advised me to have my own +household soon, or get married, so she would come and help me in +housekeeping. I believe she liked me more than she did her own kin. +My brother came to me, two days previous to his departure for Kyushu, +and giving me 600 yen, said that I might begin a business with it, or go +ahead with my study, or spend it in any way I liked, but that that would +be the last he could spare. It was a commendable act for my brother. +What! about only 600 yen! I could get along without it, I thought, but +as this unusually simple manner appealed to me, I accepted the offer +with thanks. Then he produced 50 yen, requesting me to give it to Kiyo +next time I saw her, which I readily complied with. Two days after, I +saw him off at the Shimbashi Station, and have not set my eyes on him +ever since. +Lying in my bed, I meditated on the best way to spend that 600 yen. A +business is fraught with too much trouble, and besides it was not my +calling. Moreover with only 600 yen no one could open a business worth +the name. Were I even able to do it, I was far from being educated, and +after all, would lose it. Better let investments alone, but study more +with the money. Dividing the 600 yen into three, and by spending 200 yen +a year, I could study for three years. If I kept at one study with +bull-dog tenacity for three years, I should be able to learn something. +Then the selection of a school was the next problem. By nature, there is +no branch of study whatever which appeals to my taste. Nix on languages +or literature! The new poetry was all Greek to me; I could not make out +one single line of twenty. Since I detested every kind of study, any +kind of study should have been the same to me. Thinking thus, I happened +to pass front of a school of physics, and seeing a sign posted for the +admittance of more students, I thought this might be a kind of +"affinity," and having asked for the prospectus, at once filed my +application for entrance. When I think of it now, it was a blunder due +to my hereditary recklessness. +For three years I studied about as diligently as ordinary fellows, but +not being of a particularly brilliant quality, my standing in the class +was easier to find by looking up from the bottom. Strange, isn't it, +that when three years were over, I graduated? I had to laugh at myself, +but there being no reason for complaint, I passed out. +Eight days after my graduation, the principal of the school asked me to +come over and see him. I wondered what he wanted, and went. A middle +school in Shikoku was in need of a teacher of mathematics for forty yen +a month, and he sounded me to see if I would take it. I had studied for +three years, but to tell the truth, I had no intention of either +teaching or going to the country. Having nothing in sight, however, +except teaching, I readily accepted the offer. This too was a blunder +due to hereditary recklessness. +I accepted the position, and so must go there. The three years of my +school life I had seen confined in a small room, but with no kick coming +or having no rough house. It was a comparatively easy going period in my +life. But now I had to pack up. Once I went to Kamakura on a picnic with +my classmates while I was in the grammar school, and that was the first +and last, so far, that I stepped outside of Tokyo since I could +remember. This time I must go darn far away, that it beats Kamakura by a +mile. The prospective town is situated on the coast, and looked the size +of a needle-point on the map. It would not be much to look at anyway. I +knew nothing about the place or the people there. It did not worry me or +cause any anxiety. I had simply to travel there and that was the +annoying part. +Once in a while, since our house was no more, I went to Kiyo's +nephew's to see her. Her nephew was unusually good-natured, and +whenever I called upon her, he treated me well if he happened to be at +home. Kiyo would boost me sky-high to her nephew right to my face. She +went so far once as to say that when I had graduated from school, I +would purchase a house somewhere in Kojimachi-ku and get a position in +a government office. She decided everything in her own way, and talked +of it aloud, and I was made an unwilling and bashful listener. I do +not know how her nephew weighed her tales of self-indulgence on me. +Kiyo was a woman of the old type, and seemed, as if it was still the +days of Feudal Lords, to regard her nephew equally under obligation to +me even as she was herself. +After settling about my new position, I called upon her three days +previous to my departure. She was sick abed in a small room, but, on +seeing me she got up and immediately inquired; +"Master Darling, when do you begin housekeeping?" +She evidently thought as soon as a fellow finishes school, money comes +to his pocket by itself. But then how absurd to call such a "great man" +"Darling." I told her simply that I should let the house proposition go +for some time, as I had to go to the country. She looked greatly +disappointed, and blankly smoothed her gray-haired sidelocks. I felt +sorry for her, and said comfortingly; "I am going away but will come +back soon. I'll return in the vacation next summer, sure." Still as she +appeared not fully satisfied, I added; +"Will bring you back a surprise. What do you like?" +She wished to eat "sasa-ame"[1] of Echigo province. I had never heard of +"sasa-ame" of Echigo. To begin with, the location is entirely different. +[Footnote 1: Sasa-ame is a kind of rice-jelly wrapped with sasa, or the +bamboo leaves, well-known as a product of Echigo province.] +"There seems to be no 'sasa-ame' in the country where I'm going," I +explained, and she rejoined; "Then, in what direction?" I answered +"westward" and she came back with "Is it on the other side of Hakone?" +This give-and-take conversation proved too much for me. +On the day of my departure, she came to my room early in the morning and +helped me to pack up. She put into my carpet-bag tooth powder, +tooth-brush and towels which she said she had bought at a dry goods +store on her way. I protested that I did not want them, but she was +insistent.[A] We rode in rikishas to the station. Coming up the +platform, she gazed at me from outside the car, and said in a low voice; +"This may be our last good-by. Take care of yourself." +Her eyes were full of tears. I did not cry, but was almost going to. +After the train had run some distance, thinking it would be all right +now, I poked my head out of the window and looked back. She was still +there. She looked very small. +CHAPTER II. +With a long, sonorous whistle the steamer which I was aboard came to a +standstill, and a boat was seen making toward us from the shore. The man +rowing the boat was stark naked, except for a piece of red cloth girt +round his loins. A barbarous place, this! though he may have been +excused for it in such hot weather as it was. The sun's rays were strong +and the water glimmered in such strange colors as to dazzle one's sight +if gazed at it for long. I had been told by a clerk of the ship that I +was to get off here. The place looked like a fishing village about the +size of Omori. Great Scott! I wouldn't stay in such a hole, I thought, +but I had to get out. So, down I jumped first into the boat, and I think +five or six others followed me. After loading about four large boxes +besides, the red-cloth rowed us ashore. When the boat struck the sand, I +was again the first to jump out, and right away I accosted a skinny +urchin standing nearby, asking him where the middle school was. The kid +answered blankly that he did not know. Confound the dull-head! Not to +know where the middle school was, living in such a tiny bit of a town. +Then a man wearing a rig with short, queer shaped sleeves approached me +and bade me follow. I walked after him and was taken to an inn called +Minato-ya. The maids of the inn, who gave me a disagreeable impression, +chorused at sight of me; "Please step inside." This discouraged me in +proceeding further, and I asked them, standing at the door-way, to show +me the middle school. On being told that the middle school was about +four miles away by rail, I became still more discouraged at putting up +there. I snatched my two valises from the man with queer-shaped [B] +sleeves who had guided me so far, and strode away. The people of the inn +looked after me with a dazed expression. +The station was easily found, and a ticket bought without any fuss. The +coach I got in was about as dignified as a match-box. The train rambled +on for about five minutes, and then I had to get off. No wonder the fare +was cheap; it cost only three sen. I then hired a rikisha and arrived at +the middle school, but school was already over and nobody was there. The +teacher on night-duty was out just for a while, said the janitor,--the +night-watch was taking life easy, sure. I thought of visiting the +principal, but being tired, ordered the rikishaman to take me to a +hotel. He did this with much alacrity and led me to a hotel called +Yamashiro-ya. I felt it rather amusing to find the name Yamashiro-ya the +same as that of Kantaro's house. +They ushered me to a dark room below the stairway. No one could stay in +such a hot place! I said I did not like such a warm room, but the maid +dumped my valises on the floor and left me, mumbling that all the other +rooms were occupied. So I took the room though it took some resolution +to stand the weltering heat. After a while the maid said the bath was +ready, and I took one: On my way back from the bathroom, I peeped about, +and found many rooms, which looked much cooler than mine, vacant. +Sunnovagun! They had lied. By'm-by, she fetched my supper. Although the +room was hot, the meal was a deal better than the kind I used to have in +my boarding house. While waiting on me, she questioned me where I was +from, and I said, "from Tokyo." Then she asked; "Isn't Tokyo a nice +place?" and I shot back, "Bet 'tis." About the time the maid had reached +the kitchen, loud laughs were heard. There was nothing doing, so I went +to bed, but could not sleep. Not only was it hot, but noisy,--about five +times noisier than my boarding house. While snoozing, I dreamed of Kiyo. +She was eating "sasa-ame" of Echigo province without taking off the +wrapper of bamboo leaves. I tried to stop her, saying bamboo leaves may +do her harm, but she replied, "O, no, these leaves are very helpful for +the health," and ate them with much relish. Astounded, I laughed "Ha, +ha, ha!"--and so awoke. The maid was opening the outside shutters. The +weather was just as clear as the previous day. +I had heard once before that when travelling, one should give "tea +money" to the hotel or inn where he stops; that unless this "tea +money" is given, the hostelry would accord him rather rough treatment. +It must have been on account of my being slow in the fork over of this +"tea money" that they had huddled me into such a narrow, dark room. +Likewise my shabby clothes and the carpet bags and satin umbrella must +have been accountable for it. Took me for a piker, eh? those hayseeds! +I would give them a knocker with "tea money." I left Tokyo with about +30 yen in my pocket, which remained from my school expenses. Taking +off the railway and steamship fare, and other incidental expenses, I +had still about 14 yen in my pocket. I could give them all I +had;--what did I care, I was going to get a salary now. All country +folk are tight-wads, and one 5-yen bill would hit them square. Now +watch and see. Having washed myself, I returned to my room and waited, +and the maid of the night before brought in my breakfast. Waiting on +me with a tray, she looked at me with a sort of sulphuric smile. Rude! +Is any parade marching on my face? I should say. Even my face is far +better than that of the maid. I intended of giving "tea money" after +breakfast, but I became disgusted, and taking out one 5-yen bill told +her to take it to the office later. The face of the maid became then +shy and awkward. After the meal, I left for the school. The maid did +not have my shoes polished. +I had had vague idea of the direction of the school as I rode to it the +previous day, so turning two or three corners, I came to the front gate. +From the gate to the entrance the walk was paved with granite. When I +had passed to the entrance in the rikisha, this walk made so +outlandishly a loud noise that I had felt coy. On my way to the school, +I met a number of the students in uniforms of cotton drill and they all +entered this gate. Some of them were taller than I and looked much +stronger. When I thought of teaching fellows of this ilk, I was +impressed with a queer sort of uneasiness. My card was taken to the +principal, to whose room I was ushered at once. With scant mustache, +dark-skinned and big-eyed, the principal was a man who looked like a +badger. He studiously assumed an air of superiority, and saying he would +like to see me do my best, handed the note of appointment, stamped big, +in a solemn manner. This note I threw away into the sea on my way back +to Tokyo. He said he would introduce me to all my fellow teachers, and I +was to show to each one of them the note of appointment. What a bother! +It would be far better to stick this note up in the teachers' room for +three days instead of going through such a monkey process. +The teachers would not be all in the room until the bugle for the first +hour was sounded. There was plenty of time. The principal took out his +watch, and saying that he would acquaint me particularly with the school +by-and-bye, he would only furnish me now with general matters, and +started a long lecture on the spirit of education. For a while I +listened to him with my mind half away somewhere else, but about half +way through his lecture, I began to realize that I should soon be in a +bad fix. I could not do, by any means, all he expected of me. He +expected that I should make myself an example to the students, should +become an object of admiration for the whole school or should exert my +moral influence, besides teaching technical knowledge in order to +become a real educator, or something ridiculously high-sounding. No man +with such admirable qualities would come so far away for only 40 yen a +month! Men are generally alike. If one gets excited, one is liable to +fight, I thought, but if things are to be kept on in the way the +principal says, I could hardly open my mouth to utter anything, nor take +a stroll around the place. If they wanted me to fill such an onerous +post, they should have told all that before. I hate to tell a lie; I +would give it up as having been cheated, and get out of this mess like a +man there and then. I had only about 9 yen left in my pocket after +tipping the hotel 5 yen. Nine yen would not take me back to Tokyo. I had +better not have tipped the hotel; what a pity! However, I would be able +to manage it somehow. I considered it better to run short in my return +expenses than to tell a lie. +"I cannot do it the way you want me to. I return this appointment." +I shoved back the note. The principal winked his badger-like eyes and +gazed at me. Then he said; +"What I have said just now is what I desire of you. I know well that you +cannot do all I want, So don't worry." +And he laughed. If he knew it so well already, what on earth did he +scare me for? +Meanwhile the bugle sounded, being followed by bustling noises in the +direction of the class rooms. All the teachers would be now ready, I was +told, and I followed the principal to the teachers' room. In a spacious +rectangular room, they sat each before a table lined along the walls. +When I entered the room, they all glanced at me as if by previous +agreement. Did they think my face was for a show? Then, as per +instructions, I introduced myself and showed the note to each one of +them. Most of them left their chairs and made a slight bow of +acknowledgment. But some of the more painfully polite took the note and +read it and respectfully returned it to me, just like the cheap +performances at a rural show! When I came to the fifteenth, who was the +teacher of physical training, I became impatient at repeating the same +old thing so often. The other side had to do it only once, but my side +had to do it fifteen times. They ought to have had some sympathy. +Among those I met in the room there was Mr. Blank who was head teacher. +Said he was a Bachelor of Arts. I suppose he was a great man since he +was a graduate from Imperial University and had such a title. He talked +in a strangely effeminate voice like a woman. But what surprised me most +was that he wore a flannel shirt. However thin it might be, flannel is +flannel and must have been pretty warm at that time of the year. What +painstaking dress is required which will be becoming to a B.A.! And it +was a red shirt; wouldn't that kill you! I heard afterwards that he +wears a red shirt all the year round. What a strange affliction! +According to his own explanation, he has his shirts made to order for +the sake of his health as the red color is beneficial to the physical +condition. Unnecessary worry, this, for that being the case, he should +have had his coat and hakama also in red. And there was one Mr. Koga, +teacher of English, whose complexion was very pale. Pale-faced people +are usually thin, but this man was pale and fat. When I was attending +grammar school, there was one Tami Asai in our class, and his father was +just as pale as this Koga. Asai was a farmer, and I asked Kiyo if one's +face would become pale if he took up farming. Kiyo said it was not so; +Asai ate always Hubbard squash of "uranari" [2] and that was the reason. +Thereafter when I saw any man pale and fat, I took it for granted that +it was the result of his having eaten too much of squash of "uranari." +This English teacher was surely subsisting upon squash. However, what +the meaning of "uranari" is, I do not know. I asked Kiyo once, but she +only laughed. Probably she did not know. Among the teachers of +mathematics, there was one named Hotta. This was a fellow of massive +body, with hair closely cropped. He looked like one of the old-time +devilish priests who made the Eizan temple famous. I showed him the note +politely, but he did not even look at it, and blurted out; +"You're the man newly appointed, eh? Come and see me sometime, +ha, ha, ha!" +[Footnote 2: Means the last crop.] +Devil take his "Ha, ha, ha!" Who would go to see a fellow so void of the +sense of common decency! I gave this priest from this time the nickname +of Porcupine. +The Confucian teacher was strict in his manner as becoming to his +profession. "Arrived yesterday? You must be tired. Start teaching +already? Working hard, indeed!"--and so on. He was an old man, quite +sociable and talkative. +The teacher of drawing was altogether like a cheap actor. He wore a +thin, flappy haori of sukiya, and, toying with a fan, he giggled; "Where +from? eh? Tokyo? Glad to hear that. You make another of our group. I'm a +Tokyo kid myself." +If such a fellow prided himself on being a Tokyo kid, I wished I had +never been born in Tokyo. I might go on writing about each one of +them, for there are many, but I stop here otherwise there will be no +end to it. +When my formal introduction was over, the principal said that I might go +for the day, but I should make arrangements as to the class hours, etc., +with the head teacher of mathematics and begin teaching from the day +after the morrow. Asked who was the head teacher of mathematics, I found +that he was no other than that Porcupine. Holy smokes! was I to serve +under him? I was disappointed. +"Say, where are you stopping? Yamashiro-ya? Well, I'll come and +talk it over." +So saying, Porcupine, chalk in hand, left the room to his class. That +was rather humiliating for a head-teacher to come over and see his +subordinate, but it was better than to call me over to him. +After leaving the school, I thought of returning straight to the hotel, +but as there was nothing to do, I decided to take in a little of the +town, and started walking about following my nose. I saw prefectural +building; it was an old structure of the last century. Also I saw the +barracks; they were less imposing than those of the Azabu Regiment, +Tokyo. I passed through the main street. The width of the street is +about one half that of Kagurazaka, and its aspect is inferior. What +about a castle-town of 250,000-koku Lord! Pity the fellows who get +swell-headed in such a place as a castle-town! +While I walked about musing like this, I found myself in front of +Yamashiro-ya. The town was much narrower than I had been led to believe. +"I think I have seen nearly all. Guess I'll return and eat." And I +entered the gate. The mistress of the hotel who was sitting at the +counter, jumped out of her place at my appearance and with "Are you +back, Sire!" scraped the floor with her forehead. When I took my shoes +off and stepped inside, the maid took me to an upstairs room that had +became vacant. It was a front room of 15 mats (about 90 square feet). I +had never before lived in so splendid a room as this. As it was quite +uncertain when I should again be able to occupy such a room in future, I +took off my European dress, and with only a single Japanese summer coat +on, sprawled in the centre of the room in the shape of the Japanese +letter "big" (arms stretched out and legs spread wide[D]). I found it +very refreshing. +After luncheon I at once wrote a letter to Kiyo. I hate most to write +letters because I am poor at sentence-making and also poor in my stock +of words. Neither did I have any place to which to address my letters. +However, Kiyo might be getting anxious. It would not do to let her worry +lest she think the steamer which I boarded had been wrecked and I was +drowned,--so I braced up and wrote a long one. The body of the letter +was as follows: + "Arrived yesterday. A dull place. Am sleeping in a room of 15 mats. + Tipped the hotel five yen as tea money. The house-wife of the hotel + scraped the floor with her forehead. Couldn't sleep last night. + Dreamed Kiyo eat sasa-ame together with the bamboo-leaf wrappers. Will + return next summer. Went to the school to-day, and nicknamed all the + fellows. 'Badger' for the principal, 'Red Shirt' for the head-teacher, + 'Hubbard Squash' for the teacher of English, 'Porcupine' the teacher + of mathematics and 'Clown' for that of drawing. Will write you many + other things soon. Good bye." +When I finished writing the letter, I felt better and sleepy. So I slept +in the centre of the room, as I had done before, in the letter "big" +shape ([D]). No dream this time, and I had a sound sleep. +"Is this the room?"--a loud voice was heard,--a voice which woke me up, +and Porcupine entered. +"How do you do? What you have to do in the school----" he began talking +shop as soon as I got up and rattled me much. On learning my duties in +the school, there seemed to be no difficulty, and I decided to accept. +If only such were what was expected of me, I would not be surprised were +I told to start not only two days hence but even from the following day. +The talk on business over, Porcupine said that he did not think it was +my intention to stay in such a hotel all the time, that he would find a +room for me in a good boarding house, and that I should move. +"They wouldn't take in another from anybody else but I can do it +right away. The sooner the better. Go and look at the room to-day, +move tomorrow and start teaching from the next day. That'll be all +nice and settled." +He seemed satisfied by arranging all by himself. Indeed, I should not be +able to occupy such a room for long. I might have to blow in all of my +salary for the hotel bill and yet be short of squaring it. It was pity +to leave the hotel so soon after I had just shone with a 5-yen tip. +However, it being decidedly convenient to move and get settled early if +I had to move at all, I asked Porcupine to get that room for me. He told +me then to come over with him and see the house at any rate, and I did. +The house was situated mid-way up a hill at the end of the town, and was +a quiet. The boss was said to be a dealer in antique curios, called +Ikagin, and his wife was about four years his senior. I learned the +English word "witch" when I was in middle school, and this woman looked +exactly like one. But as she was another man's wife, what did I care if +she was a witch. Finally I decided to live in the house from the next +day. On our way back Porcupine treated me to a cup of ice-water. When I +first met him in the school, I thought him a disgustingly overbearing +fellow, but judging by the way he had looked after me so far, he +appeared not so bad after all. Only he seemed, like me, impatient by +nature and of quick-temper. I heard afterward that he was liked most by +all the students in the school. +CHAPTER III. +My teaching began at last. When I entered the class-room and stepped +upon the platform for the first time, I felt somewhat strange. While +lecturing, I wondered if a fellow like me could keep up the profession +of public instructor. The students were noisy. Once in a while, they +would holler "Teacher!" "Teacher,"--it was "going some." I had been +calling others "teacher" every day so far, in the school of physics, but +in calling others "teacher" and being called one, there is a wide gap of +difference. It made me feel as if some one was tickling my soles. I am +not a sneakish fellow, nor a coward; only--it's a pity--I lack audacity. +If one calls me "teacher" aloud, it gives me a shock similar to that of +hearing the noon-gun in Marunouchi when I was hungry. The first hour +passed away in a dashing manner. And it passed away without encountering +any knotty questions. As I returned to the teachers' room, Porcupine +asked me how it was. I simply answered "well," and he seemed satisfied. +When I left the teachers' room, chalk in hand, for the second hour +class, I felt as if I was invading the enemy's territory. On entering +the room, I found the students for this hour were all big fellows. I am +a Tokyo kid, delicately built and small, and did not appear very +impressive even in my elevated position. If it comes to a scraping, I +can hold my own even with wrestlers, but I had no means of appearing +awe-inspiring[E], merely by the aid of my tongue, to so many as forty +such big chaps before me. Believing, however, that it would set a bad +precedent to show these country fellows any weakness, I lectured rather +loudly and in brusque tone. During the first part the students were +taken aback and listened literally with their mouths open. "That's one +on you!" I thought. Elated by my success, I kept on in this tone, when +one who looked the strongest, sitting in the middle of the front row, +stood up suddenly, and called "Teacher!" There it goes!--I thought, and +asked him what it was. +"A-ah sa-ay, you talk too quick. A-ah ca-an't you make it a leetle slow? +A-ah?" "A-ah ca-an't you?" "A-ah?" was altogether dull. +"If I talk too fast, I'll make it slow, but I'm a Tokyo fellow, and +can't talk the way you do. If you don't understand it, better wait +until you do." +So I answered him. In this way the second hour was closed better than I +had expected. Only, as I was about to leave the class, one of the +students asked me, "A-ah say, won't you please do them for me?" and +showed me some problems in geometry which I was sure I could not solve. +This proved to be somewhat a damper on me. But, helpless, I told him I +could not make them out, and telling him that I would show him how next +time, hastily got out of the room. And all of them raised "Whee--ee!" +Some of them were heard saying "He doesn't know much." Don't take a +teacher for an encyclopaedia! If I could work out such hard questions as +these easily, I would not be in such a backwoods town for forty yen a +month. I returned to the teachers' room. +"How was it this time?" asked Porcupine. I said "Umh." But not satisfied +with "Umh" only, I added that all the students in this school were +boneheads. He put up a whimsical face. +The third and the fourth hour and the first hour in the afternoon were +more or less the same. In all the classes I attended, I made some kind +of blunder. I realised that the profession of teaching not quite so easy +a calling as might have appeared. My teaching for the day was finished +but I could not get away. I had to wait alone until three o'clock. I +understood that at three o'clock the students of my classes would finish +cleaning up the rooms and report to me, whereupon I would go over the +rooms. Then I would run through the students' roll, and then be free to +go home. Outrageous, indeed, to keep on chained to the school, staring +at the empty space when he had nothing more to do, even though he was +"bought" by a salary! Other fellow teachers, however, meekly submitted +to the regulation, and believing it not well for me,--a new comer--to +fuss about it, I stood it. On my way home, I appealed to Porcupine as to +the absurdity of keeping me there till three o'clock regardless of my +having nothing to do in the school. He said "Yes" and laughed. But he +became serious and in an advisory manner told me not to make many +complaints about the school. +"Talk to me only, if you want to. There are some queer guys around." +As we parted at the next corner, I did not have time to hear more from +him. +On reaching my room, the boss of the house came to me saying, "Let me +serve you tea." I expected he was going to treat me to some good tea +since he said "Let me serve you," but he simply made himself at home +and drank my own tea. Judging by this, I thought he might be +practising "Let me serve you" during my absence. The boss said that he +was fond of antique drawings and curios and finally had decided to +start in that business. +"You look like one quite taken about art. Suppose you begin patronizing +my business just for fun as er--connoisseur of art?" +It was the least expected kind of solicitation. Two years ago, I went to +the Imperial Hotel (Tokyo) on an errand, and I was taken for a +locksmith. When I went to see the Daibutsu at Kamakura, haying wrapped +up myself from head to toe with a blanket, a rikisha man addressed me as +"Gov'ner." I have been mistaken on many occasions for as many things, +but none so far has counted on me as a probable connoisseur of art. One +should know better by my appearance. Any one who aspires to be a patron +of art is usually pictured,--you may see in any drawing,--with either a +hood on his head, or carrying a tanzaku[3] in his hand. The fellow who +calls me a connoisseur of art and pretends to mean it, may be surely as +crooked as a dog's hind legs. I told him I did not like such art-stuff, +which is usually favored by retired people. He laughed, and remarking +that that nobody liked it at first, but once in it, will find it so +fascinating that he will hardly get over it, served tea for himself and +drank it in a grotesque manner. I may say that I had asked him the night +before to buy some tea for me, but I did not like such a bitter, heavy +kind. One swallow seemed to act right on my stomach. I told him to buy a +kind not so bitter as that, and he answered "All right, Sir," and drank +another cup. The fellow seemed never to know of having enough of +anything so long as it was another man's. After he left the room, I +prepared for the morrow and went to bed. +[Footnote 3: A tanzaku is a long, narrow strip of stiff paper on which a +Japanese poem is written.] +Everyday thereafter I attended at the school and worked as per +regulations. Every day on my return, the boss came to my room with the +same old "Let me serve you tea." In about a week I understood the school +in a general way, and had my own idea as to the personality of the boss +and his wife. I heard from one of my fellow teachers that the first week +to one month after the receipt of the appointment worried them most as +to whether they had been favorably received among the students. I never +felt anything on that score. Blunders in the class room once in a while +caused me chagrin, but in about half an hour everything would clear out +of my head. I am a fellow who, by nature, can't be worrying long +about[F] anything even if I try to. I was absolutely indifferent as how +my blunders in the class room affected the students, or how much further +they affected the principal or the head-teacher. As I mentioned before, +I am not a fellow of much audacity to speak of, but I am quick to give +up anything when I see its finish. +I had resolved to go elsewhere at once if the school did not suit me. In +consequence, neither Badger nor Red Shirt wielded any influence over me. +And still less did I feel like coaxing or coddling the youngsters in the +class room. +So far it was O.K. with the school, but not so easy as that at my +boarding house. I could have stood it if it had been only the boss +coming to my room after my tea. But he would fetch many things to my +room. First time he brought in seals.[4] He displayed about ten of them +before me and persuaded me to buy them for three yen, which was very +cheap, he said. Did he take me for a third rate painter making a round +of the country? I told him I did not want them. Next time he brought in +a panel picture of flowers and birds, drawn by one Kazan or somebody. He +hung it against the wall of the alcove and asked me if it was not well +done, and I echoed it looked well done. Then he started lecturing about +Kazan, that there are two Kazans, one is Kazan something and the other +is Kazan anything, and that this picture was the work of that Kazan +something. After this nonsensical lecture, he insisted that he would +make it fifteen yen for me to buy it. I declined the offer saying that I +was shy of the money. +[Footnote 4: Artists have several seals of stone with which to stamp on +the picture they draw as a guarantee of their personal work or for +identification. The shape and kind of seals are quite a hobby among +artists, and sales or exchange are of common occurrence.] +"You can pay any time." He was insistent. I settled him by telling him +of my having no intention of purchasing it even if I had the necessary +money. Again next time, he yanked in a big writing stone slab about the +size of a ridge-tile. +"This is a tankei,"[5] he said. As he "tankeied" two or three times, I +asked for fun what was a tankei. Right away he commenced lecturing on +the subject. "There are the upper, the middle and the lower stratum in +tankei," he said. "Most of tankei slabs to-day are made from the upper +stratum," he continued, "but this one is surely from the middle +stratum. Look at this 'gan.'[6] 'Tis certainly rare to have three +'gans' like this. The ink-cake grates smoothly on it. Try it, +sir,"--and he pushed it towards me. I asked him how much, and he +answered that on account of its owner having brought it from China and +wishing to sell if as soon as possible, he would make it very cheap, +that I could have it for thirty yen. I was sure he was a fool. I seemed +to be able to get through the school somehow, but I would soon give out +if this "curio siege" kept on long. +[Footnote 5: Tankei is the name of a place in China where a certain kind +of stone suitable for writing purposes was produced.] +[Footnote 6: "Gan" may be understood as a kind of natural mark on the +stone peculiar to the stone from Tankei.] +Shortly afterwards, I began to get sick of the school. One certain +night, while I was strolling about a street named Omachi, I happened to +notice a sign of noodles below of which was annotated "Tokyo" in the +house next to the post office. I am very fond of noodles. While I was in +Tokyo, if I passed by a noodle house and smelled the seasoning spices, I +felt uncontrollable temptation to go inside at any cost. Up to this time +I had forgotten the noodle on account of mathematics and antique curios, +but since I had seen thus the sign of noodles, I could hardly pass it by +unnoticed. So availing myself of this opportunity, I went in. It was not +quite up to what I had judged by the sign. Since it claimed to follow +the Tokyo style, they should have tidied up a little bit about the room. +They did not either know Tokyo or have the means,--I did not know which, +but the room was miserably dirty. The floor-mats had all seen better +days and felt shaggy with sandy dust. The sootcovered walls defied the +blackest black. The ceiling was not only smoked by the lamp black, but +was so low as to force one involuntarily bend down his neck. Only the +price-list, on which was glaringly written "Noodles" and which was +pasted on the wall, was entirely new. I was certain that they bought an +old house and opened the business just two or three days before. At the +head of the price-list appeared "tempura" (noodles served with shrimp +fried in batter). +"Say, fetch me some tempura," I ordered in a loud voice. Then three +fellows who had been making a chewing noise together in a corner, looked +in my direction. As the room was dark I did not notice them at first. +But when we looked at each other, I found them all to be boys in our +school. They "how d'ye do'd" me and I acknowledged it. That night, +having come across the noodle after so long a time, it tasted so fine +that I ate four bowls. +The next day as I entered the class room quite unconcernedly, I saw on +the black board written in letters so large as to take up the whole +space; "Professor Tempura." The boys all glanced at my face and made +merry hee-haws at my cost. It was so absurd that I asked them if it was +in any way funny for me to eat tempura noodle. Thereupon one of them +said,--"But four bowls is too much." What did they care if I ate four +bowls or five as long as I paid it with my own money,--and speedily +finishing up my class, I returned to the teachers' room. After ten +minutes' recess, I went to the next class, and there on the black board +was newly written quite as large as before; "Four bowls of tempura +noodles, but don't laugh." +The first one did not arouse any ill-temper in me, but this time it made +me feel irritating mad. A joke carried too far becomes mischievous. It +is like the undue jealousy of some women who, like coal, look black and +suggest flames. Nobody likes it. These country simpletons, unable to +differentiate upon so delicate a boundary, would seem to be bent on +pushing everything to the limit. As they lived in such a narrow town +where one has no more to see if he goes on strolling about for one hour, +and as they were capable of doing nothing better, they were trumpeting +aloud this tempura incident in quite as serious a manner as the +Russo-Japanese war. What a bunch of miserable pups! It is because they +are raised in this fashion from their boyhood that there are many punies +who, like the dwarf maple tree in the flower pot, mature gnarled and +twisted. I have no objection to laugh myself with others over innocent +jokes. But how's this? Boys as they are, they showed a "poisonous +temper." Silently erasing off "tempura" from the board, I questioned +them if they thought such mischief interesting, that this was a cowardly +joke and if they knew the meaning of "cowardice." Some of them answered +that to get angry on being laughed at over one's own doing, was +cowardice. What made them so disgusting as this? I pitied myself for +coming from far off Tokyo to teach such a lot. +"Keep your mouth shut, and study hard," I snapped, and started the +class. In the next class again there was written: "When one eats tempura +noodles it makes him drawl nonsense." There seemed no end to it. I was +thoroughly aroused with anger, and declaring that I would not teach such +sassies, went home straight. The boys were glad of having an unexpected +holiday, so I heard. When things had come to this pass, the antique +curious seemed far more preferable to the school. +My return home and sleep over night greatly rounded off my rugged temper +over the tempura affair. I went to the school, and they were there also. +I could not tell what was what. The three days thereafter were pacific, +and on the night of the fourth day, I went to a suburb called Sumida and +ate "dango" (small balls made of glutinous rice, dressed with +sugar-paste). Sumida is a town where there are restaurants, hot-springs +bath houses and a park, and in addition, the "tenderloin." The dango +shop where I went was near the entrance to the tenderloin, and as the +dango served there was widely known for its nice taste, I dropped in on +my way back from my bath. As I did not meet any students this time, I +thought nobody knew of it, but when I entered the first hour class next +day, I found written on the black board; "Two dishes of dango--7 sen." +It is true that I ate two dishes and paid seven sen. Troublesome kids! I +declare. I expected with certainty that there would be something at the +second hour, and there it was; "The dango in the tenderloin taste fine." +Stupid wretches! +No sooner I thought, the dango incident closed than the red towel became +the topic for widespread gossip. Inquiry as to the story revealed it to +be something unusually absurd. Since, my arrival here, I had made it a +part of my routine to take in the hot springs bath every day. While +there was nothing in this town which compared favorably with Tokyo, the +hot springs were worthy of praise. So long as I was in the town, I +decided that I would have a dip every day, and went there walking, +partly for physical exercise, before my supper. And whenever I went +there I used to carry a large-size European towel dangling from my hand. +Added to somewhat reddish color the towel had acquired by its having +been soaked in the hot-springs, the red color on its border, which was +not fast enough, streaked about so that the towel now looked as if it +were dyed red. This towel hung down from my hand on both ways whether +afoot or riding in the train. For this reason, the students nicknamed me +Red Towel. Honest, it is exasperating to live in a little town. +There is some more. The bath house I patronized was a newly built +three-story house, and for the patrons of the first class the house +provided a bath-robe, in addition to an attendant, and the cost was only +eight sen. On top of that, a maid would serve tea in a regular polite +fashion. I always paid the first class. Then those gossipy spotters +started saying that for one who made only forty yen a month to take a +first class bath every day was extravagant. Why the devil should they +care? It was none of their business. +There is still some more. The bath-tub,--or the tank in this case,--was +built of granite, and measured about thirty square feet. Usually there +were thirteen or fourteen people in the tank, but sometimes there was +none. As the water came up clear to the breast, I enjoyed, for athletic +purposes, swimming in the tank. I delighted in swimming in this +30-square feet tank, taking chances of the total absence of other +people. Once, going downstairs from the third story with a light heart, +and peeping through the entrance of the tank to see if I should be able +to swim, I noticed a sign put up in which was boldly written: "No +swimming allowed in the tank." As there may not have been many who swam +in the tank, this notice was probably put up particularly for my sake. +After that I gave up swimming. But although I gave up swimming, I was +surprised, when I went to the school, to see on the board, as usual, +written: "No swimming allowed in the tank." It seemed as if all the +students united in tracking me everywhere. They made me sick. I was not +a fellow to stop doing whatever I had started upon no matter what +students might say, but I became thoroughly disgusted when I meditated +on why I had come to such a narrow, suffocating place. And, then, when I +returned home, the "antique curio siege" was still going on. +CHAPTER IV +For us teachers there was a duty of night watch in the school, and we +had to do it in turn. But Badger and Red Shirt were not in it. On +asking why these two were exempt from this duty, I was told that they +were accorded by the government treatment similar to officials of +"Sonin" rank. Oh, fudge! They were paid more, worked less, and were +then excused from this night watch. It was not fair. They made +regulations to suit their convenience and seemed to regard all this as +a matter of course. How could they be so brazen faced as this! I was +greatly dissatisfied relative to this question, but according to the +opinion of Porcupine, protests by a single person, with what insistency +they may be made, will not be heard. They ought to be heard whether +they are made by one person or by two if they are just. Porcupine +remonstrated with me by quoting "Might is right" in English. I did not +catch his point, so I asked him again, and he told me that it meant the +right of the stronger. If it was the right of the stronger I had known +it for long, and did not require Porcupine explain that to me at this +time. The right of the stronger was a question different from that of +the night watch. Who would agree that Badger and Red Shirt were the +stronger? But argument or no argument, the turn of this night watch at +last fell upon me. Being quite fastidious, I never enjoyed sound sleep +unless I slept comfortably in my own bedding. From my childhood, I +never stayed out overnight. When I did not find sleeping under the roof +of my friends inviting, night watch in the school, you may be sure, was +still worse. However repulsive, if this was a part of the forty yen a +month, there was no alternative. I had to do it. +To remain alone in the school after the faculty and students had gone +home, was something particularly awkward. The room for the night watch +was in the rear of the school building at the west end of the dormitory. +I stepped inside to see how it was, and finding it squarely facing the +setting sun, I thought I would melt. In spite of autumn having already +set in, the hot spell still lingered, quite in keeping with the +dilly-dally atmosphere of the country. I ordered the same kind of meal +as served for the students, and finished my supper. The meal was +unspeakably poor. It was a wonder they could subsist on such miserable +stuff and keep on "roughing it" in that lively fashion. Not only that, +they were always hungry for supper, finishing it at 4.30 in the +afternoon. They must be heroes in a sense. I had thus my supper, but the +sun being still high, could not go to bed yet. I felt like going to the +hot-springs. I did not know the wrong or right of night watch going out, +but it was oppressively trying to stand a life akin to heavy +imprisonment. When I called at the school the first time and inquired +about night watch, I was told by the janitor that he had just gone out +and I thought it strange. But now by taking the turn of night watch +myself, I could fathom the situation; it was right for any night watch +to go out. I told the janitor that I was going out for a minute. He +asked me "on business?" and I answered "No," but to take a bath at the +hot springs, and went out straight. It was too bad that I had left my +red towel at home, but I would borrow one over there for to-day. +I took plenty of time in dipping in the bath and as it became dark at +last, I came to the Furumachi Station on a train. It was only about four +blocks to the school; I could cover it in no time. When I started +walking schoolwards, Badger was seen coming from the opposite direction. +Badger, I presumed, was going to the hot springs by this train. He came +with brisk steps, and as we passed by, I nodded my courtesy. Then +Badger, with a studiously owlish countenance, asked: +"Am I wrong to understand that you are night watch?" +Chuck that "Am-I-wrong-to-understand"! Two hours ago, did he not say to +me "You're on first night watch to-night. Now, take care of yourself?" +What makes one use such a roundabout, twisted way of saying anything +when he becomes a principal? I was far from smiling. +"Yes, Sir," I said, "I'm night watch to-night, and as I am night watch I +will return to the school and stay there overnight, sure." With this +parting shot, I left him where we met. Coming then to the cross-streets +of Katamachi, I met Porcupine. This is a narrow place, I tell you. +Whenever one ventures out, he is sure to come across some familiar face. +"Say, aren't you night watch?" he hallooed, and I said "Yes, I am." "Tis +wrong for night watch to leave his post at his pleasure," he added, and +to this I blurted out with a bold front; "Nothing wrong at all. It is +wrong not to go out." +"Say, old man, your slap-dash is going to the limit. Wouldn't look well +for the principal or the head teacher to see you out like this." +The submissive tone of his remark was contrary to Porcupine as I had +known him so far, so I cut him short by saying: +"I have met the principal just now. Why, he approved my taking a stroll +about the town. Said it would be hard on night watch unless he took a +walk when it is hot." Then I made a bee-line for the school. +Soon it was night. I called the janitor to my room and had a chat for +about two hours. I grew tired of this, and thought I would get into bed +anyway, even if I could not sleep. I put on my night shirt, lifted the +mosquito-net, rolled off the red blanket and fell down flat on my back +with a bang. The making of this bumping noise when I go to bed is my +habit from my boyhood. "It is a bad habit," once declared a student of a +law school who lived on the ground floor, and I on the second, when I +was in the boarding house at Ogawa-machi, Kanda-ku, and who brought +complaints to my room in person. Students of law schools, weaklings as +they are, have double the ability of ordinary persons when it comes to +talking. As this student of law dwelt long on absurd accusations, I +downed him by answering that the noise made when I went to bed was not +the fault of my hip, but that of the house which was not built on a +solid base, and that if he had any fuss to make, make it to the house, +not to me. This room for night watch was not on the second floor, so +nobody cared how much I banged. I do not feel well-rested unless I go to +bed with the loudest bang I can make. +"This is bully!" and I straightened out my feet, when something jumped +and clung to them. They felt coarse, and seemed not to be fleas. I was a +bit surprised, and shook my feet inside the blanket two or three times. +Instantly the blamed thing increased,--five or six of them on my legs, +two or three on the thighs, one crushed beneath my hip and another clear +up to my belly. The shock became greater. Up I jumped, took off the +blanket, and about fifty to sixty grasshoppers flew out. I was more or +less uneasy until I found out what they were, but now I saw they were +grasshoppers, they set me on the war path. "You insignificant +grasshoppers, startling a man! See what's coming to you!" With this I +slapped them with my pillow twice or thrice, but the objects being so +small, the effect was out of proportion to the force with which the +blows were administered. I adopted a different plan. In the manner of +beating floor-mats with rolled matting at house-cleaning, I sat up in +bed and began beating them with the pillow. Many of them flew up by the +force of the pillow; some desperately clung on or shot against my nose +or head. I could not very well hit those on my head with the pillow; I +grabbed such, and dashed them on the floor. What was more provoking was +that no matter how hard I dashed them, they landed on the mosquito-net +where they made a fluffy jerk and remained, far from being dead. At +last, in about half an hour the slaughter of the grasshoppers was ended. +I fetched a broom and swept them out. The janitor came along and asked +what was the matter. +"Damn the matter! Where in thunder are the fools who keep grasshoppers +in bed! You pumpkinhead!" +The janitor answered by explaining that he did not know anything about +it. "You can't get away with Did-not-know," and I followed this +thundering by throwing away the broom. The awe-struck janitor shouldered +the broom and faded away. +At once I summoned three of the students to my room as the +"representatives," and six of them reported. Six or ten made no +difference; I rolled up the sleeves of my night-shirt and fired away. +"What do you mean by putting grasshoppers in my bed!" +"Grasshoppers? What are they?" said one in front, in a tone disgustingly +quiet. In this school, not only the principal, but the students as well, +were addicted to using twisted-round expressions. +"Don't know grasshoppers! You shall see!" To my chagrin, there was none; +I had swept them all out. I called the janitor again and told him to +fetch those grasshoppers he had taken away. The janitor said he had +thrown them into the garbage box, but that he would pick them out again. +"Yes, hurry up," I said, and he sped away. After a while he brought back +about ten grasshoppers on a white paper, remarking: +"I'm sorry, Sir. It's dark outside and I can't find out more. I'll find +some tomorrow." All fools here, down to the janitor. I showed one +grasshopper to the students. +"This is a grasshopper. What's the matter for as big idiots as you not +to know a grasshopper." Then the one with a round face sitting on the +left saucily shot back: +"A-ah say, that's a locust, a-ah----." +"Shut up. They're the same thing. In the first place, what do you +mean by answering your teacher 'A-ah say'? Ah-Say or Ah-Sing is a +Chink's name!" +For this counter-shot, he answered: +"A-ah say and Ah-Sing is different,--A-ah say." They never got rid of +"A-ah say." +"Grasshoppers or locusts, why did you put them into my bed? When I +asked you to?" +"Nobody put them in." +"If not, how could they get into the bed?" +"Locusts are fond of warm places and probably they got in there +respectfully by themselves." +"You fools! Grasshoppers getting into bed respectfully! I should smile +at them getting in there respectfully! Now, what's the reason for doing +this mischief? Speak out." +"But there is no way to explain it because we didn't do it." +Shrimps! If they were afraid of making a clean breast of their own deed, +they should not have done it at all. They looked defiant, and appeared +to insist on their innocence as long as no evidence was brought up. I +myself did some mischief while in the middle school, but when the +culprit was sought after, I was never so cowardly, not even once, to +back out. What one has done, has been done; what he has not, has not +been,--that's the black and white of it. I, for one have been game and +square, no matter how much mischief I might have done. If I wished to +dodge the punishment, I would not start it. Mischief and punishment are +bound to go together. We can enjoy mischief-making with some show of +spirit because it is accompanied by certain consequences. Where does one +expect to see the dastardly spirit which hungers for mischief-making +without punishment, in vogue? The fellows who like to borrow money but +not pay it back, are surely such as these students here after they are +graduated. What did these fellows come to this middle school for, +anyway? They enter a school, tattle round lies, play silly jokes behind +some one by sneaking and cheating and get wrongly swell-headed when they +finish the school thinking they have received an education. A common lot +of jackasses they are. +My hatred of talking with these scamps became intense, so I dismissed +them by saying: +"If you fellows have nothing to say, let it go at that. You deserve +pity for not knowing the decent from the vulgar after coming to a +middle school." +I am not very decent in my own language or manner, but am sure that my +moral standard is far more decent than that of these gangs. Those six +boys filed out leisurely. Outwardly they appeared more dignified than I +their teacher, it was the more repulsive for their calm behavior. I have +no temerity equal to theirs. Then I went to bed again, and found the +inside of the net full of merry crowds of mosquitoes. I could not bother +myself to burn one by one with a candle flame. So I took the net off the +hooks, folded it the lengthwise, and shook it crossways, up and down the +room. One of the rings of the net, flying round, accidentally hit the +back of my hand, the effect of which I did not soon forget. When I went +to bed for the third time, I cooled off a little, but could not sleep +easily. My watch showed it was half past ten. Well, as I thought it +over, I realized myself as having come to a dirty pit. If all teachers +of middle schools everywhere have to handle fellows like these in this +school, those teachers have my sympathy. It is wonderful that teachers +never run short. I believe there are many boneheads of extraordinary +patience; but me for something else. In this respect, Kiyo is worthy of +admiration. She is an old woman, with neither education nor social +position, but as a human, she does more to command our respect. Until +now, I have been a trouble to her without appreciating her goodness, but +having come alone to such a far-off country, I now appreciated, for the +first time, her kindness. If she is fond of sasa-ame of Echigo province, +and if I go to Echigo for the purpose of buying that sweetmeat to let +her eat it, she is fully worth that trouble. Kiyo has been praising me +as unselfish and straight, but she is a person of sterling qualities far +more than I whom she praises. I began to feel like meeting her. +While I was thus meditating about Kiyo, all of a sudden, on the floor +above my head, about thirty to forty people, if I guess by the number, +started stamping the floor with bang, bang, bang that well threatened to +bang down the floor. This was followed by proportionately loud whoops. +The noise surprised me, and I popped up. The moment I got up I became +aware that the students were starting a rough house to get even with me. +What wrong one has committed, he has to confess, or his offence is never +atoned for. They are just to ask for themselves what crimes they have +done. It should be proper that they repent their folly after going to +bed and to come and beg me pardon the next morning. Even if they could +not go so far as to apologize they should have kept quiet. Then what +does this racket mean? Where we keeping hogs in our dormitory? +"This crazy thing got to stop. See what you get!" +I ran out of the room in my night shirt, and flew upstairs in three and +half steps. Then, strange to say, thunderous rumbling, of which I was +sure of hearing in the act, was hushed. Not only a whisper but even +footsteps were not heard. This was funny. The lamp was already blown +out and although I could not see what was what in the dark, nevertheless +could tell by instinct whether there was somebody around or not. In the +long corridor running from the east to the west, there was not hiding +even a mouse. From other end of the corridor the moonlight flooded in +and about there it was particularly light. The scene was somewhat +uncanny. I have had the habit from my boyhood of frequently dreaming and +of flying out of bed and of muttering things which nobody understood, +affording everybody a hearty laugh. One night, when I was sixteen or +seventeen, I dreamed that I picked up a diamond, and getting up, +demanded of my brother who was sleeping close to me what he had done +with that diamond. The demand was made with such force that for about +three days all in the house chaffed me about the fatal loss of precious +stone, much to my humiliation. Maybe this noise which I heard was but a +dream, although I was sure it was real. I was wondering thus in the +middle of the corridor, when at the further end where it was moonlit, a +roar was raised, coming from about thirty or forty throats, "One, two, +three,--Whee-ee!" The roar had hardly subsided, when, as before, the +stamping of the floor commenced with furious rhythm. Ah, it was not a +dream, but a real thing! +"Quit making the noise! 'Tis midnight!" +I shouted to beat the band, and started in their direction. My passage +was dark; the moonlight yonder was only my guide. About twelve feet +past, I stumbled squarely against some hard object; ere the "Ouch!" has +passed clear up to my head, I was thrown down. I called all kinds of +gods, but could not run. My mind urged me on to hurry up, but my leg +would not obey the command. Growing impatient, I hobbled on one foot, +and found both voice and stamping already ceased and perfectly quiet. +Men can be cowards but I never expected them capable of becoming such +dastardly cowards as this. They challenged hogs. +Now the situation having developed to this pretty mess, I would not give +it up until I had dragged them out from hiding and forced them to +apologize. With this determination, I tried to open one of the doors and +examine inside, but it would not open. It was locked or held fast with a +pile of tables or something; to my persistent efforts the door stood +unyielding. Then I tried one across the corridor on the northside, but +it was also locked. While this irritating attempt at door-opening was +going on, again on the east end of the corridor the whooping roar and +rhythmic stamping of feet were heard. The fools at both ends were bent +on making a goose of me. I realized this, but then I was at a loss what +to do. I frankly confess that I have not quite as much tact as dashing +spirit. In such a case I am wholly at the mercy of swaying circumstances +without my own way of getting through it. Nevertheless, I do not expect +to play the part of underdog. If I dropped the affair then and there, it +would reflect upon my dignity. It would be mortifying to have them think +that they had one on the Tokyo-kid and that Tokyo-kid was wanting in +tenacity. To have it on record that I had been guyed by these +insignificant spawn when on night watch, and had to give in to their +impudence because I could not handle them,--this would be an indelible +disgrace on my life. Mark ye,--I am descendant of a samurai of the +"hatamato" class. The blood of the "hatamoto" samurai could be traced to +Mitsunaka Tada, who in turn could claim still a nobler ancestor. I am +different from, and nobler than, these manure-smelling louts. The only +pity is that I am rather short of tact; that I do not know what to do in +such a case. That is the trouble. But I would not throw up the sponge; +not on your life! I only do not know how because I am honest. Just +think,--if the honest does not win, what else is there in this world +that will win? If I cannot beat them to-night, I will tomorrow; if not +tomorrow, then the day after tomorrow. If not the day after tomorrow, I +will sit down right here, get my meals from my home until I beat them. +Thus resolved, I squatted in the middle of the corridor and waited for +the dawn. Myriads of mosquitoes swarmed about me, but I did not mind +them. I felt my leg where I hit it a while ago; it seemed bespattered +with something greasy. I thought it was bleeding. Let it bleed all it +cares! Meanwhile, exhausted by these unwonted affairs, I fell asleep. +When I awoke, up I jumped with a curse. The door on my right was half +opened, and two students were standing in front of me. The moment I +recovered my senses from the drowsy lull, I grabbed a leg of one of them +nearest to me, and yanked it with all my might. He fell down prone. Look +at what you're getting now! I flew at the other fellow, who was much +confused; gave him vigorous shaking twice or thrice, and he only kept +open his bewildering eyes. +"Come up to my room." Evidently they were mollycoddles, for they obeyed +my command without a murmur. The day had become already clear. +I began questioning those two in my room, but,--you cannot pound out the +leopard's spots no matter how you may try,--they seemed determined to +push it through by an insistent declaration of "not guilty," that they +would not confess. While this questioning was going on, the students +upstairs came down, one by one, and began congregating in my room. I +noticed all their eyes were swollen from want of sleep. +"Blooming nice faces you got for not sleeping only one night. And you +call yourselves men! Go, wash your face and come back to hear what I've +got to tell you." +I hurled this shot at them, but none of them went to wash his face. For +about one hour, I had been talking and back-talking with about fifty +students when suddenly Badger put in his appearance. I heard afterward +that the janitor ran to Badger for the purpose of reporting to him that +there was a trouble in the school. What a weak-knee of the janitor to +fetch the principal for so trifling an affair as this! No wonder he +cannot see better times than a janitor. +The principal listened to my explanation, and also to brief remarks from +the students. "Attend school as usual till further notice. Hurry up with +washing your face and breakfast; there isn't much time left." So the +principal let go all the students. Decidedly slow way of handling, this. +If I were the principal, I would expel them right away. It is because +the school accords them such luke-warm treatment that they get "fresh" +and start "guying" the night watch. +He said to me that it must have been trying on my nerves, and that +I might be tired, and also that I need not teach that day. To this +I replied: +"No, Sir, no worrying at all. Such things may happen every night, +but it would not disturb me in the least as long as I breathe. I +will do the teaching. If I were not able to teach on account of lack +of sleep for only one single night, I would make a rebate of my +salary to the school." +I do not know how this impressed him, but he gazed at me for a while, +and called my attention to the fact that my face was rather swollen. +Indeed, I felt it heavy. Besides, it itched all over. I was sure the +mosquitoes must have stung me there to their hearts' content. I +further added: +"My face may be swollen, but I can talk all right; so I will teach;" +thus scratching my face with some warmth. The principal smiled and +remarked, "Well, you have the strength." To tell the truth, he did not +intend remark to be a compliment, but, I think, a sneer. +CHAPTER V. +"Won't you go fishing?" asked Red Shirt He talks in a strangely womanish +voice. One would not be able to tell whether he was a man or a woman. As +a man he should talk like one. Is he not a college graduate? I can talk +man-like enough, and am a graduate from a school of physics at that. It +is a shame for a B.A. to have such a squeak. +I answered with the smallest enthusiasm, whereupon he further asked me +an impolite question if I ever did fishing. I told him not much, that I +once caught three gibels when I was a boy, at a fishing game pond at +Koume, and that I also caught a carp about eight inches long, at a +similar game at the festival of Bishamon at Kagurazaka;--the carp, just +as I was coaxing it out of the water, splashed back into it, and when I +think of the incident I feel mortified at the loss even now. Red Shirt +stuck out his chin and laughed "ho, ho." Why could he not laugh just +like an ordinary person? "Then you are not well acquainted with the +spirit of the game," he cried. "I'll show you if you like." He seemed +highly elated. +Not for me! I take it this way that generally those who are fond of +fishing or shooting have cruel hearts. Otherwise, there is no reason why +they could derive pleasure in murdering innocent creatures. Surely, fish +and birds would prefer living to getting killed. Except those who make +fishing or shooting their calling, it is nonsense for those who are well +off to say that they cannot sleep well unless they seek the lives of +fish or birds. This was the way I looked at the question, but as he was +a B. A. and would have a better command of language when it came to +talking, I kept mum, knowing he would beat me in argument. Red Shirt +mistook my silence for my surrender, and began to induce me to join him +right away, saying he would show me some fish and I should come with him +if I was not busy, because he and Mr. Yoshikawa were lonesome when +alone. Mr. Yoshikawa is the teacher of drawing whom I had nicknamed +Clown. I don't know what's in the mind of this Clown, but he was a +constant visitor at the house of Red Shirt, and wherever he went, Clown +was sure to be trailing after him. They appeared more like master and +servant than two fellow teachers. As Clown used to follow Red Shirt like +a shadow, it would be natural to see them go off together now, but when +those two alone would have been well off, why should they invite +me,--this brusque, unaesthetic fellow,--was hard to understand. +Probably, vain of his fishing ability, he desired to show his skill, but +he aimed at the wrong mark, if that was his intention, as nothing of the +kind would touch me. I would not be chagrined if he fishes out two or +three tunnies. I am a man myself and poor though I may be in the art, I +would hook something if I dropped a line. If I declined his invitation, +Red Shirt would suspect that I refused not because of my lack of +interest in the game but because of my want of skill of fishing. I +weighed the matter thus, and accepted his invitation. After the school, +I returned home and got ready, and having joined Red Shirt and Clown at +the station, we three started to the shore. There was only one boatman +to row; the boat was long and narrow, a kind we do not have in Tokyo. I +looked for fishing rods but could find none. +"How can we fish without rods? How are we going to manage it?" I asked +Clown and he told me with the air of a professional fisherman that no +rods were needed in the deep-sea fishing, but only lines. I had better +not asked him if I was to be talked down in this way. +The boatman was rowing very slowly, but his skill was something +wonderful. We had already come far out to sea, and on turning back, saw +the shore minimized, fading in far distance. The five-storied pagoda of +Tosho Temple appeared above the surrounding woods like a needle-point. +Yonder stood Aoshima (Blue Island). Nobody was living on this island +which a closer view showed to be covered with stones and pine trees. No +wonder no one could live there. Red Shirt was intently surveying about +and praising the general view as fine. Clown also termed it "an +absolutely fine view." I don't know whether it is so fine as to be +absolute, but there was no doubt as to the exhilarating air. I realized +it as the best tonic to be thus blown by the fresh sea breeze upon a +wide expanse of water. I felt hungry. +"Look at that pine; its trunk is straight and spreads its top branches +like an umbrella. Isn't it a Turnersque picture?" said Red Shirt. "Yes, +just like Turner's," responded Clown, "Isn't the way it curves just +elegant? Exactly the touch of Turner," he added with some show of pride. +I didn't know what Turner was, but as I could get along without knowing +it, I kept silent. The boat turned to the left with the island on the +right. The sea was so perfectly calm as to tempt one to think he was not +on the deep sea. The pleasant occasion was a credit to Red Shirt. As I +wished, if possible, to land on the island, I asked the boatman if our +boat could not be made to it. Upon this Red Shirt objected, saying that +we could do so but it was not advisable to go too close the shore for +fishing. I kept still for a while. Then Clown made the unlooked-for +proposal that the island be named Turner Island. "That's good; We shall +call it so hereafter," seconded Red Shirt. If I was included in that +"We," it was something I least cared for. Aoshima was good enough for +me. "By the way, how would it look," said Clown, "if we place Madonna by +Raphael upon that rock? It would make a fine picture." +"Let's quit talking about Madonna, ho, ho, ho," and Red Shirt emitted a +spooky laugh. +"That's all right. Nobody's around," remarked Clown as he glanced at me, +and turning his face to other direction significantly, smiled +devilishly. I felt sickened. +As it was none of my business whether it was a Madonna or a kodanna +(young master), they let pose there any old way, but it was vulgar to +feign assurance that one's subject is in no danger of being understood +so long as others did not know the subject. Clown claims himself as a +Yedo kid. I thought that the person called Madonna was no other than a +favorite geisha of Red Shirt. I should smile at the idea of his gazing +at his tootsy-wootsy standing beneath a pine tree. It would be better +if Clown would make an oil painting of the scene and exhibit it for +the public. +"This will be about the best place." So saying the boatman stopped +rowing the boat and dropped an anchor. +"How deep is it?" asked Red Shirt, and was told about six fathoms. +"Hard to fish sea-breams in six fathoms," said Red Shirt as he dropped a +line into the water. The old sport appeared to expect to fetch some +bream. Bravo! +"It wouldn't be hard for you. Besides it is calm," Clown fawningly +remarked, and he too dropped a line. The line had only a tiny bit of +lead that looked like a weight. It had no float. To fish without a float +seemed as nearly reasonable as to measure the heat without a +thermometer, which was something impossible for me. So I looked on. They +then told me to start, and asked me if I had any line. I told them I had +more than I could use, but that I had no float. +"To say that one is unable to fish without a float shows that he is a +novice," piped up Clown. +"See? When the line touches the bottom, you just manage it with your +finger on the edge. If a fish bites, you could tell in a minute. There +it goes," and Red Shirt hastily started taking out the line. I wondered +what he had got, but I saw no fish, only the bait was gone. Ha, good for +you, Gov'nur! +"Wasn't it too bad! I'm sure it was a big one. If you miss that way, +with your ability, we would have to keep a sharper watch to-day. But, +say, even if we miss the fish, it's far better than staring at a float, +isn't it? Just like saying he can't ride a bike without a brake." Clown +has been getting rather gay, and I was almost tempted to swat him. I'm +just as good as they are. The sea isn't leased by Red Shirt, and there +might be one obliging bonito which might get caught by my line. I +dropped my line then, and toyed it with my finger carelessly. +After a while something shook my line with successive jerks. I thought +it must be a fish. Unless it was something living, it would not give +that tremulous shaking. Good! I have it, and I commenced drawing in the +line, while Clown jibed me "What? Caught one already? Very remarkable, +indeed!" I had drawn in nearly all the line, leaving only about five +feet in the water. I peeped over and saw a fish that looked like a gold +fish with stripes was coming up swimming to right and left. It was +interesting. On taking it out of the water, it wriggled and jumped, and +covered my face with water. After some effort, I had it and tried to +detach the hook, but it would not come out easily. My hands became +greasy and the sense was anything but pleasing. I was irritated; I swung +the line and banged the fish against the bottom of the boat. It speedily +died. Red Shirt and Clown watched me with surprise. I washed my hands in +the water but they still smelled "fishy." No more for me! I don't care +what fish I might get, I don't want to grab a fish. And I presume the +fish doesn't want to be grabbed either. I hastily rolled up the line. +"Splendid for the first honor, but that's goruki," Clown again made a +"fresh" remark. +"Goruki sounds like the name of a Russian literator," said Red Shirt. +"Yes, just like a Russian literator," Clown at once seconded Red Shirt. +Gorky for a Russian literator, Maruki a photographer of Shibaku, and +komeno-naruki (rice) a life-giver, eh? This Red Shirt has a bad hobby of +marshalling before anybody the name of foreigners. Everybody has his +specialty. How could a teacher of mathematics like me tell whether it is +a Gorky or shariki (rikishaman). Red Shirt should have been a little +more considerate. And if he wants to mention such names at all, let him +mention "Autobiography of Ben Franklin," or "Pushing to the Front," or +something we all know. Red Shirt has been seen once in a while bringing +a magazine with a red cover entitled Imperial Literature to the school +and poring over it with reverence. I heard it from Porcupine that Red +Shirt gets his supply of all foreign names from that magazine. Well, I +should say! +For some time, Red Shirt and Clown fished assiduously and within about +an hour they caught about fifteen fish. The funny part of it was that +all they caught were goruki; of sea-bream there was not a sign. +"This is a day of bumper crop of Russian literature," Red Shirt said, +and Clown answered: +"When one as skilled as you gets nothing but goruki, it's natural for me +to get nothing else." +The boatman told me that this small-sized fish goruki has too many +tiny bones and tastes too poor to be fit for eating, but they could be +used for fertilising. So Red Shirt and Clown were fishing fertilisers +with vim and vigor. As for me, one goruki was enough and I laid down +myself on the bottom, and looked up at the sky. This was far more +dandy than fishing. +Then the two began whispering. I could not hear well, nor did I care to. +I was looking up at the sky and thinking about Kiyo. If I had enough of +money, I thought, and came with Kiyo to such a picturesque place, how +joyous it would be. No matter how picturesque the scene might be, it +would be flat in the company of Clown or of his kind. Kiyo is a poor +wrinkled woman, but I am not ashamed to take her to any old place. Clown +or his likes, even in a Victoria or a yacht, or in a sky-high position, +would not be worthy to come within her shadow. If I were the head +teacher, and Red Shirt I, Clown would be sure to fawn on me and jeer at +Red Shirt. They say Yedo kids are flippant. Indeed, if a fellow like +Clown was to travel the country and repeatedly declare "I am a Yedo +kid," no wonder the country folk would decide that the flippant are Yedo +kids and Yedo kids are flippant. While I was meditating like this, I +heard suppressed laughter. Between their laughs they talked something, +but I could not make out what they were talking about. "Eh? I don't +know......" "...... That's true ...... he doesn't know ...... isn't it +pity, though ......." "Can that be......." "With grasshoppers ...... +that's a fact." +I did not listen to what they were talking, but when I heard Clown say +"grasshoppers," I cocked my ear instinctively. Clown emphasized, for +what reason I do not know the word "grasshopers" so that it would be +sure to reach my ear plainly, and he blurred the rest on purpose. I did +not move, and kept on listening. "That same old Hotta," "that may be the +case...." "Tempura ...... ha, ha, ha ......" "...... incited ......" +"...... dango also? ......" +The words were thus choppy, but judging by their saying "grasshoppers," +"tempura" or "dango," I was sure they were secretly talking something +about me. If they wanted to talk, they should do it louder. If they +wanted to discuss something secret, why in thunder did they invite me? +What damnable blokes! Grasshoppers or glass-stoppers, I was not in the +wrong; I have kept quiet to save the face of Badger because the +principle asked me to leave the matter to him. Clown has been making +unnecessary criticisms; out with your old paint-brushes there! Whatever +concerns me, I will settle it myself sooner or later, and they had just +to keep off my toes. But remarks such as "the same old Hotta" or "...... +incited ......" worried me a bit. I could not make out whether they +meant that Hotta incited me to extend the circle of the trouble, or that +he incited the students to get at me. As I gazed at the blue sky, the +sunlight gradually waned and chilly winds commenced stirring. The clouds +that resembled the streaky smokes of joss sticks were slowly extending +over a clear sky, and by degrees they were absorbed, melted and changed +to a faint fog. +"Well, let's be going," said Red Shirt suddenly. "Yes, this is the time +we were going. See your Madonna to-night?" responded Clown. "Cut out +nonsense ...... might mean a serious trouble," said Red Shirt who was +reclining against the edge of the boat, now raising himself. "O, that's +all right if he hears.......," and when Clown, so saying, turned himself +my way, I glared squarely in his face. Clown turned back as if to keep +away from a dazzling light, and with "Ha, this is going some," shrugged +his shoulders and scratched his head. +The boat was now being rowed shore-ward over the calm sea. "You don't +seem much fond of fishing," asked Red Shirt. "No, I'd rather prefer +lying and looking at the sky," I answered, and threw the stub of +cigarette I had been smoking into the water; it sizzled and floated on +the waves parted by the oar. +"The students are all glad because you have come. So we want you do your +best." Red Shirt this time started something quite alien to fishing. "I +don't think they are," I said. "Yes; I don't mean it as flattery. They +are, sure. Isn't it so, Mr. Yoshikawa?" +"I should say they are. They're crazy over it," said Clown with an +unctuous smile. Strange that whatever Clown says, it makes me itching +mad. "But, if you don't look out, there is danger," warned Red Shirt. +"I am fully prepared for all dangers," I replied. In fact, I had made up +my mind either to get fired or to make all the students in the dormitory +apologize to me. +"If you talk that way, that cuts everything out. Really, as a head +teacher, I've been considering what is good for you, and wouldn't like +you to mistake it." +"The head teacher is really your friend. And I'm doing what I can for +you, though mighty little, because you and I are Yedo kids, and I would +like to have you stay with us as long as possible and we can help each +other." So said Clown and it sounded almost human. I would sooner hang +myself than to get helped by Clown. +"And the students are all glad because you had come, but there are many +circumstances," continued Red Shirt. "You may feel angry sometimes but +be patient for the present, and I will never do anything to hurt your +interests." +"You say 'many circumstances'; what are they?" +"They're rather complicated. Well, they'll be clear to you by and by. +You'll understand them naturally without my talking them over. What do +you say, Mr. Yoshikawa?" +"Yes, they're pretty complicated; hard to get them cleared up in a +jiffy. But they'll become clear by-the-bye. Will be understood naturally +without my explaining them," Clown echoed Red Shirt. +"If they're such a bother, I don't mind not hearing them. I only asked +you because you sprang the subject." +"That's right. I may seem irresponsible in not concluding the thing I +had started. Then this much I'll tell you. I mean no offense, but you +are fresh from school, and teaching is a new experience. And a school is +a place where somewhat complicated private circumstances are common and +one cannot do everything straight and simple". +"If can't get it through straight and simple, how does it go?" +"Well, there you are so straight as that. As I was saying, you're short +of experience........" +"I should be. As I wrote it down in my record-sheet, I'm 23 years and +four months." +"That's it. So you'd be done by some one in unexpected quarter." +"I'm not afraid who might do me as long as I'm honest." +"Certainly not. No need be afraid, but I do say you look sharp; your +predecessor was done." +I noticed Clown had become quiet, and turning round, saw him at the +stern talking with the boatman. Without Clown, I found our conversation +running smoothly. +"By whom was my predecessor done?" +"If I point out the name, it would reflect on the honor of that person, +so I can't mention it. Besides there is no evidence to prove it and I +may be in a bad fix if I say it. At any rate, since you're here, my +efforts will prove nothing if you fail. Keep a sharp look-out, please." +"You say look-out, but I can't be more watchful than I'm now. If I don't +do anything wrong, after all, that's all right isn't it?" +Red Shirt laughed. I did not remember having said anything provocative +of laughter. Up to this very minute, I have been firm in my conviction +that I'm right. When I come to consider the situation, it appears that a +majority of people are encouraging others to become bad. They seem to +believe that one must do wrong in order to succeed. If they happen to +see some one honest and pure, they sneer at him as "Master Darling" or +"kiddy." What's the use then of the instructors of ethics at grammar +schools or middle schools teaching children not to tell a lie or to be +honest. Better rather make a bold departure and teach at schools the +gentle art of lying or the trick of distrusting others, or show pupils +how to do others. That would be beneficial for the person thus taught +and for the public as well. When Red Shirt laughed, he laughed at my +simplicity. My word! what chances have the simple-hearted or the pure in +a society where they are made objects of contempt! Kiyo would never +laugh at such a time; she would listen with profound respect. Kiyo is +far superior to Red Shirt. +"Of course, that't all right as long as you don't do anything wrong. But +although you may not do anything wrong, they will do you just the same +unless you can see the wrong of others. There are fellows you have got +to watch,--the fellows who may appear off-hand, simple and so kind as to +get boarding house for you...... Getting rather cold. 'Tis already +autumn, isn't it. The beach looks beer-color in the fog. A fine view. +Say, Mr. Yoshikawa, what do you think of the scene along the +beach?......" This in a loud voice was addressed to Clown. +"Indeed, this is a fine view. I'd get a sketch of it if I had time. +Seems a pity to leave it there," answered Clown. +A light was seen upstairs at Minato-ya, and just as the whistle of a +train was sounded, our boat pushed its nose deep into the sand. "Well, +so you're back early," courtesied the wife of the boatman as she stepped +upon the sand. I stood on the edge of the boat; and whoop! I jumped out +to the beach. +CHAPTER VI. +I heartily despise Clown. It would be beneficial for Japan if such a +fellow were tied to a quernstone and dumped into the sea. As to Red +Shirt, his voice did not suit my fancy. I believe he suppresses his +natural tones to put on airs and assume genteel manner. He may put on +all kinds of airs, but nothing good will come of it with that type of +face. If anything falls in love with him, perhaps the Madonna will be +about the limit. As a head-teacher, however, he is more serious than +Clown. As he did not say definitely, I cannot get to the point, but it +appears that he warned me to look-out for Porcupine as he is crooked. If +that was the case, he should have declared it like a man. And if +Porcupine is so bad a teacher as that, it would be better to discharge +him. What a lack of backbone for a head teacher and a Bachelor of Arts! +As he is a fellow so cautious as to be unable to mention the name of the +other even in a whisper, he is surely a mollycoddle. All mollycoddles +are kind, and that Red Shirt may be as kind as a woman. His kindness is +one thing, and his voice quite another, and it would be wrong to +disregard his kindness on account of his voice. But then, isn't this +world a funny place! The fellow I don't like is kind to me, and the +friend whom I like is crooked,--how absurd! Probably everything here +goes in opposite directions as it is in the country, the contrary holds +in Tokyo. A dangerous place, this. By degrees, fires may get frozen and +custard pudding petrified. But it is hardly believable that Porcupine +would incite the students, although he might do most anything he wishes +as he is best liked among them. Instead of taking in so roundabout a +way, in the first place, it would have saved him a lot of trouble if he +came direct to me and got at me for a fight. If I am in his way, he had +better tell me so, and ask me to resign because I am in his way. There +is nothing that cannot be settled by talking it over. If what he says +sounds reasonable, I would resign even tomorrow. This is not the only +town where I can get bread and butter; I ought not to die homeless +wherever I go. I thought Porcupine was a better sport. +When I came here, Porcupine was the first to treat me to ice water. To +be treated by such a fellow, even if it is so trifling a thing as ice +water, affects my honor. I had only one glass then and had him pay only +one sen and a half. But one sen or half sen, I shall not die in peace if +I accept a favor from a swindler. I will pay it back tomorrow when I go +to the school. I borrowed three yen from Kiyo. That three yen is not +paid yet to-day, though it is five years since. Not that I could not +pay, but that I did not want to. Kiyo never looks to my pocket thinking +I shall pay it back by-the-bye. Not by any means. I myself do not expect +to fulfill cold obligation like a stranger by meditating on returning +it. The more I worry about paying it back, the more I may be doubting +the honest heart of Kiyo. It would be the same as traducing her pure +mind. I have not paid her back that three yen not because I regard her +lightly, but because I regard her as part of myself. Kiyo and Porcupine +cannot be compared, of course, but whether it be ice water or tea, the +fact that I accept another's favor without saying anything is an act of +good-will, taking the other on his par value, as a decent fellow. +Instead of chipping in my share, and settling each account, to receive +munificence with grateful mind is an acknowledgment which no amount of +money can purchase. I have neither title nor official position but I am +an independent fellow, and to have an independent fellow kowtow to you +in acknowledgment of the favor you extend him should be considered as +far more than a return acknowledgment with a million yen. I made +Porcupine blow one sen and a half, and gave him my gratitude which is +more costly than a million yen. He ought to have been thankful for that. +And then what an outrageous fellow to plan a cowardly action behind my +back! I will give him back that one sen and a half tomorrow, and all +will be square. Then I will land him one. When I thought thus far, I +felt sleepy and slept like a log. The next day, as I had something in my +mind, I went to the school earlier than usual and waited for Porcupine, +but he did not appear for a considerable time. "Confucius" was there, so +was Clown, and finally Red Shirt, but for Porcupine there was a piece of +chalk on his desk but the owner was not there. I had been thinking of +paying that one sen and a half as soon as I entered the room, and had +brought the coppers to the school grasped in my hand. My hands get +easily sweaty, and when I opened my hand, I found them wet. Thinking +that Porcupine might say something if wet coins were given him, I placed +them upon my desk, and cooled them by blowing in them. Then Red Shirt +came to me and said he was sorry to detain me yesterday, thought I have +been annoyed. I told him I was not annoyed at all, only I was hungry. +Thereupon Red Shirt put his elbows upon the desk, brought his +sauce-pan-like face close to my nose, and said; "Say, keep dark what I +told you yesterday in the boat. You haven't told it anybody, have you?" +He seems quite a nervous fellow as becoming one who talks in a feminish +voice. It was certain that I had not told it to anybody, but as I was in +the mood to tell it and had already one sen and a half in my hand, I +would be a little rattled if a gag was put on me. To the devil with Red +Shirt! Although he had not mentioned the name "Porcupine," he had given +me such pointers as to put me wise as to who the objective was, and now +he requested me not to blow the gaff!--it was an irresponsibility least +to be expected from a head teacher. In the ordinary run of things, he +should step into the thick of the fight between Porcupine and me, and +side with me with all his colors flying. By so doing, he might be worthy +the position of the head teacher, and vindicate the principle of wearing +red shirts. +I told the head teacher that I had not divulged the secret to anybody +but was going to fight it out with Porcupine. Red Shirt was greatly +perturbed, and stuttered out; "Say, don't do anything so rash as that. I +don't remember having stated anything plainly to you about Mr. +Hotta....... if you start a scrimmage here, I'll be greatly +embarrassed." And he asked the strangely outlandish question if I had +come to the school to start trouble? Of course not, I said, the school +would not stand for my making trouble and pay me salary for it. Red +Shirt then, perspiring, begged me to keep the secret as mere reference +and never mention it. "All right, then," I assured him, "this robs me +shy, but since you're so afraid of it, I'll keep it all to myself." "Are +you sure?" repeated Red Shirt. There was no limit to his womanishness. +If Red Shirt was typical of Bachelors of Arts, I did not see much in +them. He appeared composed after having requested me to do something +self-contradictory and wanting logic, and on top of that suspects my +sincerity. +"Don't you mistake," I said to myself, "I'm a man to the marrow, and +haven't the idea of breaking my own promises; mark that!" +Meanwhile the occupants of the desks on both my sides came to the room, +and Red Shirt hastily withdrew to his own desk. Red Shirt shows some air +even in his walk. In stepping about the room, he places down his shoes +so as to make no sound. For the first time I came to know that making no +sound in one's walk was something satisfactory to one's vanity. He was +not training himself for a burglar, I suppose. He should cut out such +nonsense before it gets worse. Then the bugle for the opening of classes +was heard. Porcupine did not appear after all. There was no other way +but to leave the coins upon the desk and attend the class. +When I returned to the room a little late after the first hour class, +all the teachers were there at their desks, and Porcupine too was +there. The moment Porcupine saw my face, he said that he was late on +my account, and I should pay him a fine. I took out that one sen and a +half, and saying it was the price of the ice water, shoved it on his +desk and told him to take it. "Don't josh me," he said, and began +laughing, but as I appeared unusually serious, he swept the coins back +to my desk, and flung back, "Quit fooling." So he really meant to +treat me, eh? +"No fooling; I mean it," I said. "I have no reason to accept your treat, +and that's why I pay you back. Why don't you take it?" +"If you're so worried about that one sen and a half, I will take it, but +why do you pay it at this time so suddenly?" +"This time or any time, I want to pay it back. I pay it back because I +don't like you treat me." +Porcupine coldly gazed at me and ejaculated "H'm." If I had not been +requested by Red Shirt, here was the chance to show up his cowardice and +make it hot for him. But since I had promised not to reveal the secret, +I could do nothing. What the deuce did he mean by "H'm" when I was red +with anger. +"I'll take the price of the ice water, but I want you leave your +boarding house." +"Take that coin; that's all there is to it. To leave or not,--that's my +pleasure." +"But that is not your pleasure. The boss of your boarding house came to +me yesterday and wanted me to tell you leave the house, and when I heard +his explanation, what he said was reasonable. And I dropped there on my +way here this morning to hear more details and make sure of everything." +What Porcupine was trying to get at was all dark to me. +"I don't care a snap what the boss was damn well pleased to tell you," I +cried. "What do you mean by deciding everything by yourself! If there is +any reason, tell me first. What's the matter with you, deciding what the +boss says is reasonable without hearing me." +"Then you shall hear," he said. "You're too tough and been regarded +a nuisance over there. Say, the wife of a boarding house is a wife, +not a maid, and you've been such a four-flusher as to make her wipe +your feet." +"When did I make her wipe my feet?" I asked. +"I don't know whether you did or did not, but anyway they're pretty sore +about you. He said he can make ten or fifteen yen easily if he sell a +roll of panel-picture." +"Damn the chap! Why did he take me for a boarder then!" +"I don't know why. They took you but they want you leave because they +got tired of you. So you'd better get out." +"Sure, I will. Who'd stay in such a house even if they beg me on their +knees. You're insolent to have induced me to go to such a false accuser +in the first place." +"Might be either I'm insolent or you're tough." Porcupine is no less +hot-tempered than I am, and spoke with equally loud voice. All the other +teachers in the room, surprised, wondering what has happened, looked in +our direction and craned their necks. I was not conscious of having done +anything to be ashamed of, so I stood up and looked around. Clown alone +was laughing amused. The moment he met my glaring stare as if to say +"You too want to fight?" he suddenly assumed a grave face and became +serious. He seemed to be a little cowed. Meanwhile the bugle was heard, +and Porcupine and I stopped the quarrel and went to the class rooms. +In the afternoon, a meeting of the teachers was going to be held to +discuss the question of punishment of those students in the dormitory +who offended me the other night. This meeting was a thing I had to +attend for the first time in my life, and I was totally ignorant about +it. Probably it was where the teachers gathered to blow about their own +opinions and the principal bring them to compromise somehow. To +compromise is a method used when no decision can be delivered as to the +right or wrong of either side. It seemed to me a waste of time to hold a +meeting over an affair in which the guilt of the other side was plain as +daylight. No matter who tried to twist it round, there was no ground for +doubting the facts. It would have been better if the principal had +decided at once on such a plain case; he is surely wanting in decision. +If all principals are like this, a principal is a synonym of a +"dilly-dally." +The meeting hall was a long, narrow room next to that of the principal, +and was used for dining room. About twenty chairs, with black leather +seat, were lined around a narrow table, and the whole scene looked like +a restaurant in Kanda. At one end of the table the principal took his +seat, and next to him Red Shirt. All the rest shifted for themselves, +but the gymnasium teacher is said always to take the seat farthest down +out of modesty. The situation was new to me, so I sat down between the +teachers of natural history and of Confucius. Across the table sat +Porcupine and Clown. Think how I might, the face of Clown was a +degrading type. That of Porcupine was far more charming, even if I was +now on bad terms with him. The panel picture which hung in the alcove of +the reception hall of Yogen temple where I went to the funeral of my +father, looked exactly like this Porcupine. A priest told me the picture +was the face of a strange creature called Idaten. To-day he was pretty +sore, and frequently stared at me with his fiery eyes rolling. "You +can't bulldoze me with that," I thought, and rolled my own in defiance +and stared back at him. My eyes are not well-shaped but their large size +is seldom beaten by others. Kiyo even once suggested that I should make +a fine actor because I had big eyes. +"All now here?" asked the principal, and the clerk named Kawamura +counted one, two, three and one was short. "Just one more," said the +clerk, and it ought to be; Hubbard Squash was not there. I don't know +what affinity there is between Hubbard Squash and me, but I can never +forget his face. When I come to the teachers' room, his face attracts me +first; while walking out in the street, his manners are recalled to my +mind. When I go to the hot springs, sometimes I meet him with a +pale-face in the bath, and if I hallooed to him, he would raise his +trembling head, making me feel sorry for him. In the school there is no +teacher so quiet as he. He seldom, if ever, laughs or talks. I knew the +word "gentleman" from books, and thought it was found only in the +dictionary, but not a thing alive. But since I met Hubbard Squash, I was +impressed for the first time that the word represented a real substance. +As he is a man so attached to me, I had noticed his absence as soon as I +entered the meeting hall. To tell the truth, I came to the hall with the +intention of sitting next to him. The principal said that the absentee +may appear shortly, and untied a package he had before him, taking out +some hectograph sheets and began reading them. Red Shirt began polishing +his amber pipe with a silk handkerchief. This was his hobby, which was +probably becoming to him. Others whispered with their neighbors. Still +others were writing nothings upon the table with the erasers at the end +of their pencils. Clown talked to Porcupine once in a while, but he was +not responsive. He only said "Umh" or "Ahm," and stared at me with +wrathful eyes. I stared back with equal ferocity. +Then the tardy Hubbard Squash apologetically entered, and politely +explained that he was unavoidably detained. "Well, then the meeting is +called to order," said Badger. On these sheets was printed, first the +question of the punishment of the offending students, second that of +superintending the students, and two or three other matters. Badger, +putting on airs as usual, as if he was an incarnation of education, +spoke to the following effect. +"Any misdeeds or faults among the teachers or the students in this +school are due to the lack of virtues in my person, and whenever +anything happens, I inwardly feel ashamed that a man like me could hold +his position. Unfortunately such an affair has taken place again, and I +have to apologize from my heart. But since it has happened, it cannot be +helped; we must settle it one way or other. The facts are as you already +know, and I ask you gentlemen to state frankly the best means by which +the affair may be settled." +When I heard the principal speak, I was impressed that indeed the +principal, or Badger, was saying something "grand." If the principal was +willing to assume all responsibilities, saying it was his fault or his +lack of virtues, it would have been better stop punishing the students +and get himself fired first. Then there will be no need of holding such +thing as a meeting. In the first place, just consider it by common +sense. I was doing my night duty right, and the students started +trouble. The wrong doer is neither the principal nor I. If Porcupine +incited them, then it would be enough to get rid of the students and +Porcupine. Where in thunder would be a peach of damfool who always +swipes other people's faults and says "these are mine?" It was a stunt +made possible only by Badger. Having made such an illogical statement, +he glanced at the teachers in a highly pleased manner. But no one opened +his mouth. The teacher of natural history was gazing at the crow which +had hopped on the roof of the nearby building. The teacher of Confucius +was folding and unfolding the hectograph sheet. Porcupine was still +staring at me. If a meeting was so nonsensical an affair as this, I +would have been better absent taking a nap at home. +I became irritated, and half raised myself, intending to make a +convincing speech, but just then Red Shirt began saying something and I +stopped. I saw him say something, having put away his pipe, and wiping +his face with a striped silk handkerchief. I'm sure he copped that +handkerchief from the Madonna; men should use white linen. He said: +"When I heard of the rough affairs in the dormitory, I was greatly +ashamed as the head teacher of my lack of discipline and influence. When +such an affair takes place there is underlying cause somewhere. Looking +at the affair itself, it may seem that the students were wrong, but in a +closer study of the facts, we may find the responsibility resting with +the School. Therefore, I'm afraid it might affect us badly in the future +if we administer too severe a punishment on the strength of what has +been shown on the surface. As they are youngsters, full of life and +vigor, they might half-consciously commit some youthful pranks, without +due regard as to their good or bad. As to the mode of punishment itself, +I have no right to suggest since it is a matter entirely in the hand of +the principal, but I should ask, considering these points, that some +leniency be shown toward the students." +Well, as Badger, so was Red Shirt. He declares the "Rough Necks" among +the students is not their fault but the fault of the teachers. A crazy +person beats other people because the beaten are wrong. Very grateful, +indeed. If the students were so full of life and vigor, shovel them out +into the campus and let them wrestle their heads off. Who would have +grasshoppers put into his bed unconsciously! If things go on like this, +they may stab some one asleep, and get freed as having done the deed +unconsciously. +Having figured it out in this wise, I thought I would state my own views +on the matter, but I wanted to give them an eloquent speech and fairly +take away their breath. I have an affection of the windpipe which clog +after two or three words when I am excited. Badger and Red Shirt are +below my standing in their personality, but they were skilled in +speech-making, and it would not do to have them see my awkwardness. I'll +make a rough note of composition first, I thought, and started mentally +making a sentence, when, to my surprise, Clown stood up suddenly. It was +unusual for Clown to state his opinion. He spoke in his flippant tone: +"Really the grasshopper incident and the whoop-la affair are peculiar +happenings which are enough to make us doubt our own future. We teachers +at this time must strive to clear the atmosphere of the school. And +what the principal and the head teacher have said just now are fit and +proper. I entirely agree with their opinions. I wish the punishment be +moderate." +In what Clown had said there were words but no meaning. It was a +juxtaposition of high-flown words making no sense. All that I understood +was the words, "I entirely agree with their opinions." +Clown's meaning was not clear to me, but as I was thoroughly angered, I +rose without completing my rough note. +"I am entirely opposed to......." I said, but the rest did not come at +once. ".......I don't like such a topsy-turvy settlement," I added and +the fellows began laughing. "The students are absolutely wrong from the +beginning. It would set a bad precedent if we don't make them apologize +....... What do we care if we kick them all out ....... darn the kids +trying to guy a new comer......." and I sat down. Then the teacher of +natural history who sat on my right whined a weak opinion, saying "The +students may be wrong, but if we punish them too severely, they may +start a reaction and would make it rather bad. I am for the moderate +side, as the head teacher suggested." The teacher of Confucius on my +left expressed his agreement with the moderate side, and so did the +teacher of history endorse the views of the head teacher. Dash those +weak-knees! Most of them belonged to the coterie of Red Shirt. It would +make a dandy school if such fellows run it. I had decided in my mind +that it must be either the students apologize to me or I resign, and if +the opinion of Red Shirt prevailed, I had determined to return home and +pack up. I had no ability of out-talking such fellows, or even if I had, +I was in no humor to keeping their company for long. Since I don't +expect to remain in the school, the devil may take care of the rest. If +I said anything, they would only laugh; so I shut my mouth tight. +Porcupine, who up to this time had been listening to the others, stood +up with some show of spirit. Ha, the fellow was going to endorse the +views of Red Shirt, eh? You and I got to fight it out anyway, I thought, +so do any way you darn please. Porcupine spoke in a thunderous voice: +"I entirely differ from the opinions of the head teacher and other +gentlemen. Because, viewed from whatever angle, this incident cannot be +other than an attempt by those fifty students in the dormitory to make +a fool of a new teacher. The head teacher seems to trace the cause of +the trouble to the personality of that teacher himself, but, begging +his pardon, I think he is mistaken. The night that new teacher was on +night duty was not long after his arrival, not more than twenty days +after he had come into contact with the students. During those short +twenty days, the students could have no reason to criticise his +knowledges or his person. If he was insulted for some cause which +deserved insult, there may be reasons in our considering the act of the +students, but if we show undue leniency toward the frivolous students +who would insult a new teacher without cause, it would affect the +dignity of this school. The spirit of education is not only in +imparting technical knowledges, but also in encouraging honest, +ennobling and samurai-like virtues, while eliminating the evil tendency +to vulgarity and roughness. If we are afraid of reaction or further +trouble, and satisfy ourselves with make-shifts, there is no telling +when we can ever get rid of this evil atmosphere[G]. We are here to +eradicate this very evil. If we mean to countenance it, we had better +not accepted our positions here. For these reasons, I believe it proper +to punish the students in the dormitory to the fullest extent and also +make them apologize to that teacher in the open." +All were quiet. Red Shirt again began polishing his pipe. I was greatly +elated. He spoke almost what I had wanted to. I'm such a simple-hearted +fellow that I forgot all about the bickerings with Porcupine, and looked +at him with a grateful face, but he appeared to take no notice of me. +After a while, Porcupine again stood up, and said. "I forgot to mention +just now, so I wish to add. The teacher on night duty that night seems +to have gone to the hot springs during his duty hours, and I think it a +blunder. It is a matter of serious misconduct to take the advantage of +being in sole charge of the school, to slip out to a hot springs. The +bad behavior of the students is one thing; this blunder is another, and +I wish the principal to call attention of the responsible person to +that matter." +A strange fellow! No sooner had he backed me up than he began talking me +down. I knew the other night watch went out during his duty hours, and +thought it was a custom, so I went as far out as to the hot springs +without considering the situation seriously. But when it was pointed out +like this, I realised that I had been wrong. Thereupon I rose again and +said; "I really went to the hot springs. It was wrong and I apologize." +Then all again laughed. Whatever I say, they laugh. What a lot of boobs! +See if you fellows can make a clean breast of your own fault like this! +You fellows laugh because you can't talk straight. +After that the principal said that since it appeared that there will be +no more opinions, he will consider the matter well and administer what +he may deem a proper punishment. I may here add the result of the +meeting. The students in the dormitory were given one week's +confinement, and in addition to that, apologized to me. If they had not +apologized, I intended to resign and go straight home, but as it was it +finally resulted in a bigger and still worse affair, of which more +later. The principal then at the meeting said something to the effect +that the manners of the students should be directed rightly by the +teachers' influence, and as the first step, no teacher should patronize, +if possible, the shops where edibles and drinks were served, excepting, +however, in case of farewell party or such social gatherings. He said he +would like no teacher to go singly to eating houses of lower kind--for +instance, noodle-house or dango shop.... And again all laughed. Clown +looked at Porcupine, said "tempura" and winked his eyes, but Porcupine +regarded him in silence. Good! +My "think box" is not of superior quality, so things said by Badger were +not clear to me, but I thought if a fellow can't hold the job of teacher +in a middle school because he patronizes a noodle-house or dango shop, +the fellow with bear-like appetite like me will never be able to hold +it. If it was the case, they ought to have specified when calling for a +teacher one who does not eat noodle and dango. To give an appointment +without reference to the matter at first, and then to proclaim that +noodle or dango should not be eaten was a blow to a fellow like me who +has no other petty hobby. Then Red Shirt again opened his mouth. +"Teachers of the middle school belong to the upper class of society and +they should not be looking after material pleasures only, for it would +eventually have effect upon their personal character. But we are human, +and it would be intolerable in a small town like this to live without +any means of affording some pleasure to ourselves, such as fishing, +reading literary products, composing new style poems, or haiku +(17-syllable poem). We should seek mental consolation of higher order." +There seemed no prospect that he would quit the hot air. If it was a +mental consolation to fish fertilisers on the sea, have goruki for +Russian literature, or to pose a favorite geisha beneath pine tree, it +would be quite as much a mental consolation to eat dempura noodle and +swallow dango. Instead of dwelling on such sham consolations, he would +find his time better spent by washing his red shirts. I became so +exasperated that I asked; "Is it also a mental consolation to meet the +Madonna?" No one laughed this time and looked at each other with queer +faces, and Red Shirt himself hung his head, apparently embarrassed. Look +at that! A good shot, eh? Only I was sorry for Hubbard Squash who, +having heard the remark, became still paler. +CHAPTER VII. +That very night I left the boarding house. While I was packing up, the +boss came to me and asked if there was anything wrong in the way I was +treated. He said he would be pleased to correct it and suit me if I was +sore at anything. This beats me, sure. How is it possible for so many +boneheads to be in this world! I could not tell whether they wanted me +to stay or get out. They're crazy. It would be disgrace for a Yedo kid +to fuss about with such a fellow; so I hired a rikishaman and speedily +left the house. +I got out of the house all right, but had no place to go. The rikishaman +asked me where I was going. I told him to follow me with his mouth shut, +then he shall see and I kept on walking. I thought of going to +Yamashiro-ya to avoid the trouble of hunting up a new boarding house, +but as I had no prospect of being able to stay there long, I would have +to renew the hunt sooner or later, so I gave up the idea. If I continued +walking this way, I thought I might strike a house with the sign of +"boarders taken" or something similar, and I would consider the first +house with the sign the one provided for me by Heaven. I kept on going +round and round through the quiet, decent part of the town when I found +myself at Kajimachi. This used to be former samurai quarters where one +had the least chance of finding any boarding house, and I was going to +retreat to a more lively part of the town when a good idea occurred to +me. Hubbard Squash whom I respected lived in this part of the town. He +is a native of the town, and has lived in the house inherited from his +great grandfather. He must be, I thought, well informed about nearly +everything in this town. If I call on him for his help, he will perhaps +find me a good boarding house. Fortunately, I called at his house once +before, and there was no trouble in finding it out. I knocked at the +door of a house, which I knew must be his, and a woman about fifty years +old with an old fashioned paper-lantern in hand, appeared at the door. I +do not despise young women, but when I see an aged woman, I feel much +more solicitous. This is probably because I am so fond of Kiyo. This +aged lady, who looked well-refined, was certainly mother of Hubbard +Squash whom she resembled. She invited me inside, but I asked her to +call him out for me. When he came I told him all the circumstances, and +asked him if he knew any who would take me for a boarder. Hubbard Squash +thought for a moment in a sympathetic mood, then said there was an old +couple called Hagino, living in the rear of the street, who had asked +him sometime ago to get some boarders for them as there are only two in +the house and they had some vacant rooms. Hubbard Squash was kind enough +to go along with me and find out if the rooms were vacant. They were. +From that night I boarded at the house of the Haginos. What surprised me +was that on the day after I left the house of Ikagin, Clown stepped in +and took the room I had been occupying. Well used to all sorts of tricks +and crooks as I might have been, this audacity fairly knocked me off my +feet. It was sickening. +I saw that I would be an easy mark for such people unless I brace up +and try to come up, or down, to their level. It would be a high time +indeed for me to be alive if it were settled that I would not get three +meals a day without living on the spoils of pick pockets. Nevertheless, +to hang myself,--healthy and vigorous as I am,--would be not only +inexcusable before my ancestors but a disgrace before the public. Now I +think it over, it would have been better for me to have started +something like a milk delivery route with that six hundred yen as +capital, instead of learning such a useless stunt as mathematics at the +School of Physics. If I had done so, Kiyo could have stayed with me, +and I could have lived without worrying about her so far a distance +away. While I was with her I did not notice it, but separated thus I +appreciated Kiyo as a good-natured old woman. One could not find a +noble natured woman like Kiyo everywhere. She was suffering from a +slight cold when I left Tokyo and I wondered how she was getting on +now? Kiyo must have been pleased when she received the letter from me +the other day. By the way, I thought it was the time I was in receipt +of answer from her. I spent two or three days with things like this in +my mind. I was anxious about the answer, and asked the old lady of the +house if any letter came from Tokyo for me, and each time she would +appear sympathetic and say no. The couple here, being formerly of +samurai class, unlike the Ikagin couple, were both refined. The old +man's recital of "utai" in a queer voice at night was somewhat telling +on my nerves, but it was much easier on me as he did not frequent my +room like Ikagin with the remark of "let me serve you tea." +The old lady once in a while would come to my room and chat on many +things. She questioned me why I had not brought my wife with me. I asked +her if I looked like one married, reminding her that I was only twenty +four yet. Saying "it is proper for one to get married at twenty four" as +a beginning, she recited that Mr. Blank married when he was twenty, that +Mr. So-and-So has already two children at twenty two, and marshalled +altogether about half a dozen examples,--quite a damper on my youthful +theory. I will then get marred at twenty four, I said, and requested her +to find me a good wife, and she asked me if I really meant it. +"Really? You bet! I can't help wanting to get married." +"I should suppose so. Everybody is just like that when young." This +remark was a knocker; I could not say anything to that. +"But I'm sure you have a Madam already. I have seen to that with my +own eyes." +"Well, they are sharp eyes. How have you seen it?" +"How? Aren't you often worried to death, asking if there's no letter +from Tokyo?" +"By Jupiter! This beats me!" +"Hit the mark, haven't I?" +"Well, you probably have." +"But the girls of these days are different from what they used to be and +you need a sharp look-out on them. So you'd better be careful." +"Do you mean that my Madam in Tokyo is behaving badly?" +"No, your Madam is all right." +"That makes me feel safe. Then about what shall I be careful?" +"Yours is all right. Though yours is all right......." +"Where is one not all right?" +"Rather many right in this town. You know the daughter of the Toyamas? +"No, I do not." +"You don't know her yet? She is the most beautiful girl about here. She +is so beautiful that the teachers in the school call her Madonna. You +haven't heard that? +"Ah, the Madonna! I thought it was the name of a geisha." +"No, Sir. Madonna is a foreign word and means a beautiful girl, +doesn't it?" +"That may be. I'm surprised." +"Probably the name was given by the teacher of drawing." +"Was it the work of Clown?" +"No, it was given by Professor Yoshikawa." +"Is that Madonna not all right?" +"That Madonna-san is a Madonna not all right." +"What a bore! We haven't any decent woman among those with nicknames +from old days. I should suppose the Madonna is not all right." +"Exactly. We have had awful women such as O-Matsu the Devil or Ohyaku +the Dakki. +"Does the Madonna belong to that ring?" +"That Madonna-san, you know, was engaged to Professor Koga,--who brought +you here,--yes, was promised to him." +"Ha, how strange! I never knew our friend Hubbard Squash was a fellow of +such gallantry. We can't judge a man by his appearance. I'll be a bit +more careful." +"The father of Professor Koga died last year,--up to that time they had +money and shares in a bank and were well off,--but since then things +have grown worse, I don't know why. Professor Koga was too good-natured, +in short, and was cheated, I presume. The wedding was delayed by one +thing or another and there appeared the head teacher who fell in love +with the Madonna head over heels and wanted to many her." +"Red Shirt? He ought be hanged. I thought that shirt was not an ordinary +kind of shirt. Well?" +"The head-teacher proposed marriage through a go-between, but the +Toyamas could not give a definite answer at once on account of their +relations with the Kogas. They replied that they would consider the +matter or something like that. Then Red Shirt-san worked up some ways +and started visiting the Toyamas and has finally won the heart of the +Miss. Red Shirt-san is bad, but so is Miss Toyama; they all talk bad of +them. She had agreed to be married to Professor Koga and changed her +mind because a Bachelor of Arts began courting her,--why, that would be +an offense to the God of To-day." +"Of course. Not only of To-day but also of tomorrow and the day after; +in fact, of time without end." +"So Hotta-san a friend of Koga-san, felt sorry for him and went to the +head teacher to remonstrate with him. But Red Shirt-san said that he had +no intention of taking away anybody who is promised to another. He may +get married if the engagement is broken, he said, but at present he was +only being acquainted with the Toyamas and he saw nothing wrong in his +visiting the Toyamas. Hotta-san couldn't do anything and returned. Since +then they say Red Shirt-san and Hotta-san are on bad terms." +"You do know many things, I should say. How did you get such details? +I'm much impressed." +"The town is so small that I can know everything." +Yes, everything seems to be known more than one cares. Judging by her +way, this woman probably knows about my tempura and dango affairs. Here +was a pot that would make peas rattle! The meaning of the Madonna, the +relations between Porcupine and Red Shirt became clear and helped me a +deal. Only what puzzled me was the uncertainty as to which of the two +was wrong. A fellow simple-hearted like me could not tell which side he +should help unless the matter was presented in black and white. +"Of Red Shirt and Porcupine, which is a better fellow?" +"What is Porcupine, Sir?" +"Porcupine means Hotta." +"Well, Hotta-san is physically strong, as strength goes, but Red +Shirt-san is a Bachelor of Arts and has more ability. And Red Shirt-san +is more gentle, as gentleness goes, but Hotta-san is more popular among +the students." +"After all, which is better?" +"After all, the one who gets a bigger salary is greater, I suppose?" +There was no use of going on further in this way, and I closed the talk. +Two or three days after this, when I returned from the school, the old +lady with a beaming smile, brought me a letter, saying, "Here you are +Sir, at last. Take your time and enjoy it." I took it up and found it +was from Kiyo. On the letter were two or three retransmission slips, and +by these I saw the letter was sent from Yamashiro-ya to the Iagins, then +to the Haginos. Besides, it stayed at Yamashiro-ya for about one week; +even letters seemed to stop in a hotel. I opened it, and it was a very +long letter. +"When I received the letter from my Master Darling, I intended to write +an answer at once. But I caught cold and was sick abed for about one +week and the answer was delayed for which I beg your pardon. I am not +well-used to writing or reading like girls in these days, and it +required some efforts to get done even so poorly written a letter as +this. I was going to ask my nephew to write it for me, but thought it +inexcusable to my Master Darling when I should take special pains for +myself. So I made a rough copy once, and then a clean copy. I finished +the clean copy, in two days, but the rough copy took me four days. It +may be difficult for you to read, but as I have written this letter with +all my might, please read it to the end." +This was the introductory part of the letter in which, about four feet +long, were written a hundred and one things. Well, it was difficult to +read. Not only was it poorly written but it was a sort of juxtaposition +of simple syllables that racked one's brain to make it clear where it +stopped or where it began. I am quick-tempered and would refuse to read +such a long, unintelligible letter for five yen, but I read this +seriously from the first to the last. It is a fact that I read it +through. My efforts were mostly spent in untangling letters and +sentences; so I started reading it over again. The room had become a +little dark, and this rendered it harder to read it; so finally I +stepped out to the porch where I sat down and went over it carefully. +The early autumn breeze wafted through the leaves of the banana trees, +bathed me with cool evening air, rustled the letter I was holding and +would have blown it clear to the hedge if I let it go. I did not mind +anything like this, but kept on reading. +"Master Darling is simple and straight like a split bamboo by +disposition," it says, "only too explosive. That's what worries me. If +you brand other people with nicknames you will only make enemies of +them; so don't use them carelessly; if you coin new ones, just tell them +only to Kiyo in your letters. The countryfolk are said to be bad, and I +wish you to be careful not have them do you. The weather must be worse +than in Tokyo, and you should take care not to catch cold. Your letter +is too short that I can't tell how things are going on with you. Next +time write me a letter at least half the length of this one. Tipping the +hotel with five yen is all right, but were you not short of money +afterward? Money is the only thing one can depend upon when in the +country and you should economize and be prepared for rainy days. I'm +sending you ten yen by postal money order. I have that fifty yen my +Master Darling gave me deposited in the Postal Savings to help you start +housekeeping when you return to Tokyo, and taking out this ten, I have +still forty yen left,--quite safe." +I should say women are very particular on many things. +When I was meditating with the letter flapping in my hand on the porch, +the old lady opened the sliding partition and brought in my supper. +"Still poring over the letter? Must be a very long one, I +imagine," she said. +"Yes, this is an important letter, so I'm reading it with the wind +blowing it about," I replied--the reply which was nonsense even for +myself,--and I sat down for supper. I looked in the dish on the tray, +and saw the same old sweet potatoes again to-night. This new boarding +house was more polite and considerate and refined than the Ikagins, but +the grub was too poor stuff and that was one drawback. It was sweet +potato yesterday, so it was the day before yesterday, and here it is +again to-night. True, I declared myself very fond of sweet potatoes, but +if I am fed with sweet potatoes with such insistency, I may soon have to +quit this dear old world. I can't be laughing at Hubbard Squash; I shall +become Sweet Potato myself before long. If it were Kiyo she would surely +serve me with my favorite sliced tunny or fried kamaboko, but nothing +doing with a tight, poor samurai. It seems best that I live with Kiyo. +If I have to stay long in the school, I believe I would call her from +Tokyo. Don't eat tempura, don't eat dango, and then get turned yellow by +feeding on sweet potatoes only, in the boarding house. That's for an +educator, and his place is really a hard one. I think even the priests +of the Zen sect are enjoying better feed. I cleaned up the sweet +potatoes, then took out two raw eggs from the drawer of my desk, broke +them on the edge of the rice bowl, to tide it over. I have to get +nourishment by eating raw eggs or something, or how can I stand the +teaching of twenty one hours a week? +I was late for my bath to-day on account of the letter from Kiyo. But I +would not like to drop off a single day since I had been there everyday. +I thought I would take a train to-day, and coming to the station with +the same old red towel dangling out of my hand, I found the train had +just left two or three minutes ago, and had to wait for some time. While +I was smoking a cigarette on a bench, my friend Hubbard Squash happened +to come in. Since I heard the story about him from the old lady my +sympathy for him had become far greater than ever. His reserve always +appeared to me pathetic. It was no longer a case of merely pathetic; +more than that. I was wishing to get his salary doubled, if possible, +and have him marry Miss Toyama and send them to Tokyo for about one +month on a pleasure trip. Seeing him, therefore, I motioned him to a +seat beside me, addressing him cheerfully: +"Hello[H], going to bath? Come and sit down here." +Hubbard Squash, appearing much awe-struck, said; "Don't mind me, +Sir," and whether out of polite reluctance or I don't know what, +remained standing. +"You have to wait for a little while before the next train starts; sit +down; you'll be tired," I persuaded him again. In fact, I was so +sympathetic for him that I wished to have him sit down by me somehow. +Then with a "Thank you, Sir," he at last sat down. A fellow like Clown, +always fresh, butts in where he is not wanted; or like Porcupine +swaggers about with a face which says "Japan would be hard up without +me," or like Red Shirt, self-satisfied in the belief of being the +wholesaler of gallantry and of cosmetics. Or like Badger who appears to +say; "If 'Education' were alive and put on a frockcoat, it would look +like me." One and all in one way or other have bravado, but I have +never seen any one like this Hubbard Squash, so quiet and resigned, +like a doll taken for a ransom. His face is rather swollen but for the +Madonna to cast off such a splendid fellow and give preference to Red +Shirt, was frivolous beyond my understanding. Put how many dozens of +Red Shirt you like together, it will not make one husband of stuff to +beat Hubbard Squash. +"Is anything wrong with you? You look quite fatigued," I asked. +"No, I have no particular ailments......." +"That's good. Poor health is the worst thing one can get." +"You appear very strong." +"Yes, I'm thin, but never got sick. That's something I don't like." +Hubbard Squash smiled at my words. Just then I heard some young girlish +laughs at the entrance, and incidentally looking that way, I saw a +"peach." A beautiful girl, tall, white-skinned, with her head done up +in "high-collared" style, was standing with a woman of about forty-five +or six, in front of the ticket window. I am not a fellow given to +describing a belle, but there was no need to repeat asserting that she +was beautiful. I felt as if I had warmed a crystal ball with perfume +and held it in my hand. The older woman was shorter, but as she +resembled the younger, they might be mother and daughter. The moment I +saw them, I forgot all about Hubbard Squash, and was intently gazing at +the young beauty. Then I was a bit startled to see Hubbard Squash +suddenly get up and start walking slowly toward them. I wondered if she +was not the Madonna. The three were courtesying in front of the ticket +window, some distance away from me, and I could not hear what they were +talking about. +The clock at the station showed the next train to start in five +minutes. Having lost my partner, I became impatient and longed for the +train to start as soon as possible, when a fellow rushed into the +station excited. It was Red Shirt. He had on some fluffy clothes, +loosely tied round with a silk-crepe girdle, and wound to it the same +old gold chain. That gold chain is stuffed. Red Shirt thinks nobody +knows it and is making a big show of it, but I have been wise. Red +Shirt stopped short, stared around, and then after bowing politely to +the three still in front of the ticket window, made a remark or two, +and hastily turned toward me. He came up to me, walking in his usual +cat's style, and hallooed. +"You too going to bath? I was afraid of missing the train and +hurried up, but we have three or four minutes yet. Wonder if that +clock is right?" +He took out his gold watch, and remarking it wrong about two minutes sat +down beside me. He never turned toward the belle, but with his chin on +the top of a cane, steadily looked straight before him. The older woman +would occasionally glance toward Red Shirt, but the younger kept her +profile away. Surely she was the Madonna. +The train now arrived with a shrill whistle and the passengers hastened +to board. Red Shirt jumped into the first class coach ahead of all. One +cannot brag much about boarding the first class coach here. It cost only +five sen for the first and three sen for the second to Sumida; even I +paid for the first and a white ticket. The country fellows, however, +being all close, seemed to regard the expenditure of the extra two sen a +serious matter and mostly boarded the second class. Following Red Shirt, +the Madonna and her mother entered the first class. Hubbard Squash +regularly rides in the second class. He stood at the door of a second +class coach and appeared somewhat hesitating, but seeing me coming, took +decisive steps and jumped into the second. I felt sorry for him--I do +not know why--and followed him into the same coach. Nothing wrong in +riding on the second with a ticket for the first, I believe. +At the hot springs, going down from the third floor to the bath room in +bathing gown, again I met Hubbard Squash. I feel my throat clogged up +and unable to speak at a formal gathering, but otherwise I am rather +talkative; so I opened conversation with him. He was so pathetic and my +compassion was aroused to such an extent that I considered it the duty +of a Yedo kid to console him to the best of my ability. But Hubbard +Squash was not responsive. Whatever I said, he would only answer "eh?" +or "umh," and even these with evident effort. Finally I gave up my +sympathetic attempt and cut off the conversation. +I did not meet Red Shirt at the bath. There are many bath rooms, and one +does not necessarily meet the fellows at the same bath room though he +might come on the same train. I thought it nothing strange. When I got +out of the bath, I found the night bright with the moon. On both sides +of the street stood willow trees which cast their shadows on the road. I +would take a little stroll, I thought. Coming up toward north, to the +end of the town, one sees a large gate to the left. Opposite the gate +stands a temple and both sides of the approach to the temple are lined +with houses with red curtains. A tenderloin inside a temple gate is an +unheard-of phenomenon. I wanted to go in and have a look at the place, +but for fear I might get another kick from Badger, I passed it by. A +flat house with narrow lattice windows and black curtain at the +entrance, near the gate, is the place where I ate dango and committed +the blunder. A round lantern with the signs of sweet meats hung outside +and its light fell on the trunk of a willow tree close by. I hungered to +have a bite of dango, but went away forbearing. +To be unable to eat dango one is so fond of eating, is tragic. But to +have one's betrothed change her love to another, would be more tragic. +When I think of Hubbard Squash, I believe that I should, not complain if +I cannot eat dango or anything else for three days. Really there is +nothing so unreliable a creature as man. As far as her face goes, she +appears the least likely to commit so stony-hearted an act as this. But +the beautiful person is cold-blooded and Koga-san who is swollen like a +pumpkin soaked in water, is a gentleman to the core,--that's where we +have to be on the look-out. Porcupine whom I had thought candid was said +to have incited the students and he whom then I regarded an agitator, +demanded of the principal a summary punishment of the students. The +disgustingly snobbish Red Shirt is unexpectedly considerate and warns me +in ways more than one, but then he won the Madonna by crooked means. He +denies, however, having schemed anything crooked about the Madonna, and +says he does not care to marry her unless her engagement with Koga is +broken. When Ikagin beat me out of his house, Clown enters and takes my +room. Viewed from any angle, man is unreliable. If I write these things +to Kiyo, it would surprise her. She would perhaps say that because it is +the west side of Hakone that the town had all the freaks and crooks +dumped in together.[7] +[Footnote 7: An old saying goes that east of the Hakone pass, there are +no apparitions or freaks.] +I do not by nature worry about little things, and had come so far +without minding anything. But hardly a month had passed since I came +here, and I have begun to regard the world quite uneasily. I have not +met with any particularly serious affairs, but I feel as if I had grown +five or six years older. Better say "good by" to this old spot soon and +return to Tokyo, I thought. While strolling thus thinking on various +matters, I had passed the stone bridge and come up to the levy of the +Nozeri river. The word river sounds too big; it is a shallow stream of +about six feet wide. If one goes on along the levy for about twelve +blocks, he reaches the Aioi village where there is a temple of Kwanon. +Looking back at the town of the hot springs, I see red lights gleaming +amid the pale moon beams. Where the sound of the drum is heard must be +the tenderloin. The stream is shallow but fast, whispering incessantly. +When I had covered about three blocks walking leisurely upon the bank, +I perceived a shadow ahead. Through the light of the moon, I found +there were two shadows. They were probably village youngsters returning +from the hot springs, though they did not sing, and were exceptionally +quiet for that. +I kept on walking, and I was faster than they. The two shadows became +larger. One appeared like a woman. When I neared them within about sixty +feet, the man, on hearing my footsteps, turned back. The moon was +shining from behind me. I could see the manner of the man then and +something queer struck me. They resumed their walk as before. And I +chased them on a full speed. The other party, unconscious, walked +slowly. I could now hear their voice distinctly. The levy was about six +feet wide, and would allow only three abreast. I easily passed them, and +turning back gazed squarely into the face of the man. The moon +generously bathed my face with its beaming light. The fellow uttered a +low "ah," and suddenly turning sideway, said to the woman "Let's go +back." They traced their way back toward the hot springs town. +Was it the intention of Red Shirt to hush the matter up by pretending +ignorance, or was it lack of nerve? I was not the only fellow who +suffered the consequence of living in a small narrow town. +CHAPTER VIII. +On my way back from the fishing to which I was invited by Red Shirt, and +since then, I began to suspect Porcupine. When the latter wanted me to +get out of Ikagin's house on sham pretexts, I regarded him a decidedly +unpleasant fellow. But as Porcupine, at the teachers' meeting, contrary +to my expectation, stood firmly for punishing the students to the +fullest extent of the school regulations, I thought it queer. When I +heard from the old lady about Porcupine volunteering himself for the +sake of Hubbard Squash to stop Red Shirt meddling with the Madonna, I +clapped my hands and hoorayed for him. Judging by these facts, I began +to wonder if the wrong-doer might be not Porcupine, but Red Shirt the +crooked one. He instilled into my head some flimsy hearsay plausibly and +in a roundabout-way. At this juncture I saw Red Shirt taking a walk with +the Madonna on the levy of the Nozeri river, and I decided that Red +Shirt may be a scoundrel. I am not sure of his being really scoundrel at +heart, but at any rate he is not a good fellow. He is a fellow with a +double face. A man deserves no confidence unless he is as straight as +the bamboo. One may fight a straight fellow, and feel satisfied. We +cannot lose sight of the fact that Red Shirt or his kind who is kind, +gentle, refined, and takes pride in his pipe had to be looked sharp, for +I could not be too careful in getting into a scrap with the fellow of +this type. I may fight, but I would not get square games like the +wrestling matches it the Wrestling Amphitheatre in Tokyo. Come to think +of it, Porcupine who turned against me and startled the whole teachers' +room over the amount of one sen and a half is far more like a man. When +he stared at me with owlish eyes at the teachers' meeting, I branded him +as a spiteful guy, but as I consider the matter now, he is better than +the feline voice of Red Shirt. To tell the truth, I tried to get +reconciled with Porcupine, and after the meeting, spoke a word or two to +him, but he shut up like a clam and kept glaring at me. So I became +sore, and let it go at that. +Porcupine has not spoken to me since. The one sen and a half which I +paid him back upon the desk, is still there, well covered with dust. I +could not touch it, nor would Porcupine take it. This one sen and a +half has become a barrier between us two. We two were cursed with this +one sen and a half. Later indeed I got sick of its sight that I hated +to see it. +While Porcupine and I were thus estranged, Red Shirt and I continued +friendly relations and associated together. On the day following my +accidental meeting with him near the Nozeri river, for instance, Red +Shirt came to my desk as soon as he came to the school, and asked me how +I liked the new boarding house. He said we would go together for fishing +Russian literature again, and talked on many things. I felt a bit +piqued, and said, "I saw you twice last night," and he answered, "Yes, +at the station. Do you go there at that time every day? Isn't it late?" +I startled him with the remark; "I met you on the levy of the Nozeri +river too, didn't I?" and he replied, "No, I didn't go in that +direction. I returned right after my bath." +What is the use of trying to keep it dark. Didn't we meet actually face +to face? He tells too many lies. If one can hold the job of a head +teacher and act in this fashion, I should be able to run the position of +Chancellor of a university. From this time on, my confidence in Red +Shirt became still less. I talk with Red Shirt whom I do not trust, and +I keep silent with Porcupine whom I respect. Funny things do happen in +this world. +One day Red Shirt asked me to come over to his house as he had something +to tell me, and much as I missed the trip to the hot springs, I started +for his house at about 4 o'clock. Red Shirt is single, but in keeping +with the dignity of a head teacher, he gave up the boarding house life +long ago, and lives in a fine house. The house rent, I understood, was +nine yen and fifty sen. The front entrance was so attractive that I +thought if one can live in such a splendid house at nine yen and a half +in the country, it would be a good game to call Kiyo from Tokyo and make +her heart glad. The younger brother of Red Shirt answered my bell. This +brother gets his lessons on algebra and mathematics from me at the +school. He stands no show in his school work, and being a "migratory +bird" is more wicked than the native boys. +I met Red Shirt. Smoking the same old unsavory amber pipe, he said +something to the following effect: +"Since you've been with us, our work has been more satisfactory than it +was under your predecessor, and the principal is very glad to have got +the right person in the right place. I wish you to work as hard as you +can, for the school is depending upon you." +"Well, is that so. I don't think I can work any harder than now......." +"What you're doing now is enough. Only don't forget what I told you the +other day." +"Meaning that one who helps me find a boarding house is dangerous?" +"If you state it so baldly, there is no meaning to it....... But that's +all right,...... I believe you understand the spirit of my advice. And +if you keep on in the way you're going to-day ...... We have not been +blind ...... we might offer you a better treatment later on if we can +manage it." +"In salary? I don't care about the salary, though the more the better." +"And fortunately there is going to be one teacher transferred,...... +however, I can't guarantee, of course, until I talk it over with the +principal ...... and we might give you something out of his salary." +"Thank you. Who is going to be transferred?" +"I think I may tell you now; 'tis going to be Announced soon. Koga +is the man." +"But isn't Koga-san a native of this town?" +"Yes, he is. But there are some circumstances ...... and it is partly by +his own preference." +"Where is he going?" +"To Nobeoka in Hiuga province. As the place is so far away, he is going +there with his salary raised a grade higher." +"Is some one coming to take his place?" +"His successor is almost decided upon." +"Well, that's fine, though I'm not very anxious to have my salary +raised." +"I'm going to talk to the principal about that anyway. And, we may have +to ask you to work more some time later ...... and the principal appears +to be of the same opinion....... I want you to go[I] ahead with that in +your mind." +"Going to increase my working hours?" +"No. The working hours may be reduced......" +"The working hours shortened and yet work more? Sounds funny." +"It does sound funny ...... I can't say definitely just yet ...... it +means that we way have to ask you to assume more responsibility." +I could not make out what he meant. To assume more responsibility might +mean my appointment to the senior instructor of mathematics, but +Porcupine is the senior instructor and there is no danger of his +resigning. Besides, he is so very popular among the students that his +transfer or discharge would be inadvisable. Red Shirt always misses the +point. And though he did not get to the point, the object of my visit +was ended. We talked a while on sundry matters, Red Shirt proposing a +farewell dinner party for Hubbard Squash, asking me if I drink liquor +and praising Hubbard Squash as an amiable gentleman, etc. Finally he +changed the topic and asked me if I take an interest in "haiku"[8] Here +is where I beat it, I thought, and, saying "No, I don't, good by," +hastily left the house. The "haiku" should be a diversion of Baseo[9] or +the boss of a barbershop. It would not do for the teacher of mathematics +to rave over the old wooden bucket and the morning glory.[10] +[Footnote 8: The 17-syllable poem] +[Footnote 9: A famous composer of the poem.] +[Footnote 10: There is a well-known 17-syllable poem describing the +scene of morning glories entwining around the wooden bucket.] +I returned home and thought it over. Here is a man whose mental process +defies a layman's understanding. He is going to court hardships in a +strange part of the country in preference of his home and the school +where he is working,--both of which should satisfy most +anybody,--because he is tired of them. That may be all right if the +strange place happens to be a lively metropolis where electric cars +run,--but of all places, why Nobeoka in Hiuga province? This town here +has a good steamship connection, yet I became sick of it and longed for +home before one month had passed. Nobeoka is situated in the heart of a +most mountainous country. According to Red Shirt, one has to make an +all-day ride in a wagonette to Miyazaki, after he had left the vessel, +and from Miyazaki another all-day ride in a rikisha to Nobeoka. Its name +alone does not commend itself as civilized. It sounds like a town +inhabited by men and monkeys in equal numbers. However sage-like Hubbard +Squash might be I thought he would not become a friend of monkeys of his +own choice. What a curious slant! +Just then the old lady brought in my supper--"Sweet potatoes again?" I +asked, and she said, "No, Sir, it is tofu to-night." They are about the +same thing. +"Say, I understand Koga-san is going to Nobeoka." +"Isn't it too bad?" +"Too bad? But it can't be helped if he goes there by his own +preference." +"Going there by his own preference? Who, Sir?" +"Who? Why, he! Isn't Professor Koga going there by his own choice?" +"That's wrong Mr. Wright, Sir." +"Ha, Mr. Wright, is it? But Red Shirt told me so just now. If that's +wrong Mr. Wright, then Red Shirt is blustering Mr. Bluff." +"What the head-teacher says is believable, but so Koga-san does not +wish to go." +"Our old lady is impartial, and that is good. Well, what's the matter?" +"The mother of Koga-san was here this morning, and told me all the +circumstances." +"Told you what circumstances?" +"Since the father of Koga-san died, they have not been quite well off as +we might have supposed, and the mother asked the principal if his salary +could not be raised a little as Koga-san has been in service for four +years. See?" +"Well?" +"The principal said that he would consider the matter, and she felt +satisfied and expected the announcement of the increase before long. She +hoped for its coming this month or next. Then the principal called +Koga-san to his office one day and said that he was sorry but the school +was short of money and could not raise his salary. But he said there is +an opening in Nobeoka which would give him five yen extra a month and he +thought that would suit his purpose, and the principal had made all +arrangements and told Koga-san he had better go......." +"That wasn't a friendly talk but a command. Wasn't it?" +"Yes, Sir, Koga-san told the principal that he liked to stay here better +at the old salary than go elsewhere on an increased salary, because he +has his own house and is living with his mother. But the matter has all +been settled, and his successor already appointed and it couldn't be +helped, said the principal." +"Hum, that's a jolly good trick, I should say. Then Koga-san has no +liking to go there? No wonder I thought it strange. We would have to go +a long way to find any blockhead to do a job in such a mountain village +and get acquainted with monkeys for five yen extra." +"What is a blockhead, Sir?" +"Well, let go at that. It was all the scheme of Red Shirt. Deucedly +underhand scheme, I declare. It was a stab from behind. And he means to +raise my salary by that; that's not right. I wouldn't take that raise. +Let's see if he can raise it." +"Is your salary going to be raised, Sir?" +"Yes, they said they would raise mine, but I'm thinking of refusing it." +"Why do you refuse?" +"Why or no why, it's going to be refused. Say, Red Shirt is a fool; he +is a coward." +"He may be a coward, but if he raises your salary, it would be best for +you to make no fuss, but accept it. One is apt to get grouchy when +young, but will always repent when he is grown up and thinks that it was +pity he hadn't been a little more patient. Take an old woman's advice +for once, and if Red Shirt-san says he will raise your salary, just take +it with thanks." +"It's none of business of you old people." +The old lady withdrew in silence. The old man is heard singing "utai" in +the off-key voice. "Utai," I think, is a stunt which purposely makes a +whole show a hard nut to crack by giving to it difficult tunes, whereas +one could better understand it by reading it. I cannot fathom what is in +the mind of the old man who groans over it every night untired. But I'm +not in a position to be fooling with "utai." Red Shirt said he would +have my salary raised, and though I did not care much about it, I +accepted it because there was no use of leaving the money lying around. +But I cannot, for the love of Mike, be so inconsiderate as to skin the +salary of a fellow teacher who is being transferred against his will. +What in thunder do they mean by sending him away so far as Nobeoka when +the fellow prefers to remain in his old position? Even +Dazai-no-Gonnosutsu did not have to go farther than about Hakata; even +Matagoro Kawai [11] stopped at Sagara. I shall not feel satisfied unless +I see Red Shirt and tell him I refuse the raise. +[Footnote 11: The persons in exile, well-known in Japanese history.] +I dressed again and went to his house. The same younger brother of Red +Shirt again answered the bell, and looked at me with eyes which plainly +said, "You here again?" I will come twice or thrice or as many times as +I want to if there is business. I might rouse them out of their beds at +midnight;--it is possible, who knows. Don't mistake me for one coming to +coax the head teacher. I was here to give back my salary. The younger +brother said that there is a visitor just now, and I told him the front +door will do; won't take more than a minute, and he went in. Looking +about my feet, I found a pair of thin, matted wooden clogs, and I heard +some one in the house saying, "Now we're banzai." I noticed that the +visitor was Clown. Nobody but Clown could make such a squeaking voice +and wear such clogs as are worn by cheap actors. +After a while Red Shirt appeared at the door with a lamp in his hand, +and said, "Come in; it's no other than Mr. Yoshikawa." +"This is good enough," I said, "it won't take long." I looked at his +face which was the color of a boiled lobster. He seemed to have been +drinking with Clown. +"You told me that you would raise my salary, but I've changed my mind, +and have come here to decline the offer." +Red Shirt, thrusting out the lamp forward, and intently staring at me, +was unable to answer at the moment. He appeared blank. Did he think it +strange that here was one fellow, only one in the world, who does not +want his salary raised, or was he taken aback that I should come back so +soon even if I wished to decline it, or was it both combined, he stood +there silent with his mouth in a queer shape. +"I accepted your offer because I understood that Mr. Koga was being +transferred by his own preference......." +"Mr. Koga is really going to be transferred by his own preference." +"No, Sir. He would like to stay here. He doesn't mind his present salary +if he can stay." +"Have you heard it from Mr. Koga himself?" +"No, not from him." +"Then, from who?" +"The old lady in my boarding house told me what she heard from the +mother of Mr. Koga." +"Then the old woman in your boarding house told you so?" +"Well, that's about the size of it." +"Excuse me, but I think you are wrong. According to what you say, it +seems as if you believe what the old woman in the boarding house tells +you, but would not believe what your head teacher tells you. Am I right +to understand it that way?" +I was stuck. A Bachelor of Arts is confoundedly good in oratorical +combat. He gets hold of unexpected point, and pushes the other backward. +My father used to tell me that I am too careless and no good, and now +indeed I look that way. I ran out of the house on the moment's impulse +when I heard the story from the old lady, and in fact I had not heard +the story from either Hubbard Squash or his mother. In consequence, when +I was challenged in this Bachelor-of-Arts fashion, it was a bit +difficult to defend myself. +I could not defend his frontal attack, but I had already declared in my +mind a lack of confidence on Red Shirt. The old lady in the boarding +house may be tight and a grabber, I do not doubt it, but she is a woman +who tells no lie. She is not double faced like Red Shirt, I was +helpless, so I answered. +"What you say might be right,--anyway, I decline the raise." +"That's still funnier. I thought your coming here now was because you +had found a certain reason for which you could not accept the raise. +Then it is hard to understand to see you still insisting on declining +the raise in spite of the reason having been eradicated by my +explanation." +"It may be hard to understand, but anyway I don't want it." +"If you don't like it so much, I wouldn't force it on you. But if you +change your mind within two or three hours with no particular reason, it +would affect your credit in future." +"I don't care if it does affect it." +"That can't be. Nothing is more important than credit for us. Supposing, +the boss of the boarding house......." +"Not the boss, but the old lady." +"Makes no difference,--suppose what the old woman in the boarding house +told you was true, the raise of your salary is not to be had by reducing +the income of Mr. Koga, is it? Mr. Koga is going to Nobeoka; his +successor is coming. He comes on a salary a little less than that of Mr. +Koga, and we propose to add the surplus money to your salary, and you +need not be shy. Mr. Koga will be promoted; the successor is to start on +less pay, and if you could be raised, I think everything be satisfactory +to all concerned. If you don't like it, that's all right, but suppose +you think it over once more at home?" +My brain is not of the best stuff, and if another fellow flourishes his +eloquence like this, I usually think, "Well, perhaps I was wrong," and +consider myself defeated, but not so to-night. From the time I came to +this town I felt prejudiced against Red Shirt. Once I had thought of him +in a different light, taking him for a fellow kind-hearted and +feminished. His kindness, however, began to look like anything but +kindness, and as a result, I have been getting sick of him. So no matter +how he might glory himself in logical grandiloquence, or how he might +attempt to out-talk me in a head-teacher-style, I don't care a snap. One +who shines in argument is not necessarily a good fellow, while the other +who is out-talked is not necessarily a bad fellow, either. Red Shirt is +very, very reasonable as far as his reasoning goes, but however graceful +he may appear, he cannot win my respect. If money, authority or +reasoning can command admiration, loansharks, police officers or college +professors should be liked best by all. I cannot be moved in the least +by the logic by so insignificant a fellow as the head teacher of a +middle school. Man works by preference, not by logic. +"What you say is right, but I have begun to dislike the raise, so I +decline. It will be the same if I think it over. Good by." And I left +the house of Red Shirt. The solitary milky way hung high in the sky. +CHAPTER IX. +When I went to the school, in the morning of the day the farewell dinner +party was to be held, Porcupine suddenly spoke to me; +"The other day I asked you to quit the Ikagins because Ikagin begged of +me to have you leave there as you were too tough, and I believed him. +But I heard afterward that Ikagin is a crook and often passes imitation +of famous drawings for originals. I think what he told me about you must +be a lie. He tried to sell pictures and curios to you, but as you shook +him off, he told some false stories on you. I did very wrong by you +because I did not know his character, and wish you would forgive me." +And he offered me a lengthy apology. +Without saying a word, I took up the one sen and a half which was lying +on the desk of Porcupine, and put it into my purse. He asked me in a +wondering tone, if I meant to take it back. I explained, "Yes. I didn't +like to have you treat me and expected to pay this back at all hazard, +but as I think about it, I would rather have you treated me after all; +so I'm going to take it back." +Porcupine laughed heartily and asked me why I had not taken it back +sooner. I told him that I wanted to more than once, in fact, but somehow +felt shy and left it there. I was sick of that one sen and a half these +days that I shunned the sight of it when I came to the school, I said. +He said "You're a deucedly unyielding sport," and I answered "You're +obstinate." Then ensued the following give-and-take between us two; +"Where were you born anyway?" +"I'm a Yedo kid." +"Ah, a Yedo kid, eh? No wonder I thought you a pretty stiff neck." +"And you?" +"I'm from Aizu." +"Ha, Aizu guy, eh? You've got reason to be obstinate. Going to the +farewell dinner to-day?" +"Sure. You?" +"Of course I am. I intend to go down to the beach to see Koga-san off +when he leaves." +"The farewell dinner should be a big blow-out. You come and see. I'm +going to get soused to the neck." +"You get loaded all you want. I quit the place right after I finish my +plates. Only fools fight booze." +"You're a fellow who picks up a fight too easy. It shows up the +characteristic of the Yedo kid well." +"I don't care. Say, before you go to the farewell dinner, come to see +me. I want to tell you something." +Porcupine came to my room as promised. I had been in full sympathy with +Hubbard Squash these days, and when it came to his farewell dinner, my +pity for him welled up so much that I wished I could go to Nobeoka for +him myself. I thought of making a parting address of burning eloquence +at the dinner to grace the occasion, but my speech which rattles off +like that of the excited spieler of New York would not become the place. +I planned to take the breath out of Red Shirt by employing Porcupine who +has a thunderous voice. Hence my invitation to him before we started for +the party. +I commenced by explaining the Madonna affair, but Porcupine, needless to +say, knew more about it than I. Telling about my meeting Red Shirt on +the Nozeri river, I called him a fool. Porcupine then said; "You call +everybody a fool. You called me a fool to-day at the school. If I'm a +fool, Red Shirt isn't," and insisted that he was not in the same group +with Red Shirt. "Then Red Shirt may be a four-flusher," I said and he +approved this new alias with enthusiasm. Porcupine is physically strong, +but when it comes to such terms, he knows less than I do. I guess all +Aizu guys are about the same. +Then, when I disclosed to him about the raise of my salary and the +advance hint on my promotion by Red Shirt, Porcupine pished, and said, +"Then he means to discharge me." "Means to discharge you? But you mean +to get discharged?" I asked. "Bet you, no. If I get fired, Red Shirt +will have to go with me," he remarked with a lordly air. I insisted on +knowing how he was going to get Red Shirt kicked out with him, and he +answered that he had not thought so far yet. Yes, Porcupine looks +strong, but seems to be possessed of no abundance of brain power. I told +him about my refusal of the raise of my salary, and the Gov'nur was much +pleased, praising me with the remark, "That's the stuff for Yedo kids." +"If Hubbard Squash does not like to go down to Nobeoka, why didn't you +do something to enable him remain here," I asked, and Porcupine said +that when he heard the story from Hubbard Squash, everything had been +settled already, but he had asked the principal twice and Red Shirt once +to have the transfer order cancelled, but to no purpose. Porcupine +bitterly condemned Hubbard Squash for being too good-natured. If Hubbard +Squash, he said, had either flatly refused or delayed the answer on the +pretext of considering it, when Red Shirt raised the question of +transfer, it would have been better for him. But he was fooled by the +oily tongue of Red Shirt, had accepted the transfer outright, and all +efforts by Porcupine who was moved by the tearful appeal of the mother, +proved unavailing. +I said; "The transfer of Koga is nothing but a trick of Red Shirt to cop +the Madonna by sending Hubbard Squash away." +"Yes," said Porcupine "That must be. Red Shirt looks gentle, but plays +nasty tricks. He is a sonovagun for when some one finds fault with him, +he has excuses prepared already. Nothing but a sound thumping will be +effective for fellows like him." +He rolled up his sleeves over his plump arms as he spoke. I asked him, +by the way, if he knew jiujitsu, because his arms looked powerful. Then +he put force in his forearm, and told me to touch it. I felt its swelled +muscle which was hard as the pumic stone in the public bathhouse. +I was deeply impressed by his massive strength, and asked him if he +could not knock five or six of Red Shirt in a bunch. "Of course," he +said, and as he extended and bent back the arm, the lumpy muscle rolled +round and round, which was very amusing. According to the statement of +Porcupine himself, this muscle, if he bends the arm back with force, +would snap a paper-string wound around it twice. I said I might do the +same thing if it were a paper-string, and he challenged me. "No, you +can't," he said. "See if you can." As it would not look well if I +failed, I did not try. +"Say, after you have drunk all you want to-night at the dinner, take a +fall out of Red Shirt and Clown, eh?" I suggested to him for fun. +Porcupine thought for a moment and said, "Not to-night, I guess." I +wanted to know why, and he pointed out that it would be bad for Koga. +"Besides, if I'm going to give it to them at all, I've to get them red +handed in their dirty scheme, or all the blame will be on me," he added +discretely. Even Porcupine seems to have wiser judgment than I. +"Then make a speech and praise Mr. Koga sky-high. My speech becomes sort +of jumpy, wanting dignity. And at any formal gathering, I get lumpy in +my throat, and can't speak. So I leave it to you," I said. +"That's a strange disease. Then you can't speak in the presence of other +people? It would be awkward, I suppose," he said, and I told him not +quite as much awkward as he might think. +About then, the time for the farewell dinner party arrived, and I went +to the hall with Porcupine. The dinner party was to be held at +Kashin-tei which is said to be the leading restaurant in the town, but I +had never been in the house before. This restaurant, I understood, was +formerly the private residence of the chief retainer of the daimyo of +the province, and its condition seemed to confirm the story. The +residence of a chief retainer transformed into a restaurant was like +making a saucepan out of warrior's armor. +When we two came there, about all of the guests were present. They +formed two or three groups in the spacious room of fifty mats. The +alcove in this room, in harmony with its magnificence, was very large. +The alcove in the fifteen-mat room which I occupied at Yamashiro-ya made +a small showing beside it. I measured it and found it was twelve feet +wide. On the right, in the alcove, there was a seto-ware flower vase, +painted with red designs, in which was a large branch of pine tree. Why +the pine twigs, I did not know, except that they are in no danger of +withering for many a month to come, and are economical. I asked the +teacher of natural history where that seto-ware flower vase is made. He +told me it was not a seto-ware but an imari. Isn't imari seto-ware? I +wondered audibly, and the natural history man laughed. I heard afterward +that we call it a seto-ware because it is made in Seto. I'm a Yedo kid, +and thought all china was seto-wares. In the center of the alcove was +hung a panel on which were written twenty eight letters, each letter as +large as my face. It was poorly written; so poorly indeed that I +enquired of the teacher of Confucius why such a poor work be hung in +apparent show of pride. He explained that it was written by Kaioku a +famous artist in the writing, but Kaioku or anyone else, I still declare +the work poorly done. +By and by, Kawamura, the clerk, requested all to be seated. I chose one +in front of a pillar so I could lean against it. Badger sat in front of +the panel of Kaioku in Japanese full dress. On his left sat Red Shirt +similarly dressed, and on his right Hubbard Squash, as the guest of +honor, in the same kind of dress. I was dressed in a European suit, and +being unable to sit down, squatted on my legs at once. The teacher of +physical culture next to me, though in the same kind of rags as mine, +sat squarely in Japanese fashion. As a teacher of his line he appeared +to have well trained himself. Then the dinner trays were served and the +bottles placed beside them. The manager of the day stood up and made a +brief opening address. He was followed by Badger and Red Shirt. These +two made farewell addresses, and dwelt at length on Hubbard Squash being +an ideal teacher and gentleman, expressing their regret, saying his +departure was a great loss not only to the school but to them in person. +They concluded that it could not be helped, however, since the transfer +was due to his own earnest desire and for his own convenience. They +appeared to be ashamed not in the least by telling such a lie at a +farewell dinner. Particularly, Red Shirt, of these three, praised Hubard +Squash in lavish terms. He went so far as to declare that to lose this +true friend was a great personal loss to him. Moreover, his tone was so +impressive in its same old gentle tone that one who listens to him for +the first time would be sure to be misled. Probably he won the Madonna +by this same trick. While Red Shirt was uttering his farewell buncomb, +Porcupine who sat on the other side across me, winked at me. As an +answer of this, I "snooked" at him. +No sooner had Red Shirt sat down than Porcupine stood up, and highly +rejoiced, I clapped hands. At this Badger and others glanced at me, and +I felt that I blushed a little. +"Our principal and other gentlemen," he said, "particularly the head +teacher, expressed their sincere regret at Mr. Koga's transfer. I am of +a different opinion, and hope to see him leave the town at the earliest +possible moment. Nobeoka is an out-of-the-way, backwoods town, and +compared with this town, it may have more material inconveniences, but +according to what I have heard, Nobeoka is said to be a town where the +customs are simple and untainted, and the teachers and students still +strong in the straightforward characteristics of old days. I am +convinced that in Nobeoka there is not a single high-collared guy who +passes round threadbare remarks, or who with smooth face, entraps +innocent people. I am sure that a man like Mr. Koga, gentle and honest, +will surely be received with an enthusiastic welcome there. I heartily +welcome this transfer for the sake of Mr. Koga. In concluding, I hope +that when he is settled down at Nobeoka, he will find a lady qualified +to become his wife, and form a sweet home at an early date and +incidentally let the inconstant, unchaste sassy old wench die ashamed +...... a'hum, a'hum!" +He coughed twice significantly and sat down. I thought of clapping my +hands again, but as it would draw attention, I refrained. When +Porcupine finished his speech, Hubbard Squash arose politely, slipped +out of his seat, went to the furthest end of the room, and having bowed +to all in a most respectful manner, acknowledged the compliments in the +following way; +"On the occasion of my going to Kyushu for my personal convenience, I am +deeply impressed and appreciate the way my friends have honored me with +this magnificent dinner....... The farewell addresses by our principal +and other gentlemen will be long held in my fondest recollection....... +I am going far away now, but I hope my name be included in the future as +in the past in the list of friends of the gentlemen here to-night." +Then again bowing, he returned to his seat. There was no telling how far +the "good-naturedness" of Hubbard Squash might go. He had respectfully +thanked the principal and the head teacher who had been fooling him. And +it was not a formal, cut-and-dried reply he made, either; by his manner, +tone and face, he appeared to have been really grateful from his heart. +Badger and Red Shirt should have blushed when they were addressed so +seriously by so good a man as Hubbard Squash, but they only listened +with long faces. +After the exchange of addresses, a sizzling sound was heard here and +there, and I too tried the soup which tasted like anything but soup. +There was kamaboko in the kuchitori dish, but instead of being snow +white as it should be, it looked grayish, and was more like a poorly +cooked chikuwa. The sliced tunny was there, but not having been sliced +fine, passed the throat like so many pieces of chopped raw tunny. Those +around me, however, ate with ravenous appetite. They have not tasted, I +guess, the real Yedo dinner. +Meanwhile the bottles began passing round, and all became more or less +"jacked up." Clown proceeded to the front of the principal and +submissively drank to his health. A beastly fellow, this! Hubbard Squash +made a round of all the guests, drinking to their health. A very onerous +job, indeed. When he came to me and proposed my health, I abandoned the +squatting posture and sat up straight. +"Too bad to see you go away so soon. When are you going? I want to see +you off at the beach," I said. +"Thank you, Sir. But never mind that. You're busy," he declined. He +might decline, but I was determined to get excused for the day and give +him a rousing send-off. +Within about an hour from this, the room became pretty lively. +"Hey, have another, hic; ain't goin', hic, have one on me?" One or two +already in a pickled state appeared on the scene. I was little tired, +and going out to the porch, was looking at the old fashioned garden by +the dim star light, when Porcupine came. +"How did you like my speech? Wasn't it grand, though!" he remarked in a +highly elated tone. I protested that while I approved 99 per cent, of +his speech, there was one per cent, that I did not. "What's that one per +cent?" he asked. +"Well, you said,...... there is not a single high-collared guy who with +smooth face entraps innocent people......." +"Yes." +"A 'high-collared guy' isn't enough." +"Then what should I say?" +"Better say,--'a high-collared guy; swindler, bastard, +super-swanker, doubleface, bluffer, totempole, spotter, who looks +like a dog as he yelps.'" +"I can't get my tongue to move so fast. You're eloquent. In the first +place, you know a great many simple words. Strange that you can't make +a speech." +"I reserve these words for use when I chew the rag. If it comes to +speech-making, they don't come out so smoothly." +"Is that so? But they simply come a-running. Repeat that again for me." +"As many times as you like. Listen,--a high-collared guy, swindler, +bastard, super-swanker ..." +While I was repeating this, two shaky fellows came out of the room +hammering the floor. +"Hey, you two gents, if won't do to run away. Won't let you off while +I'm here. Come and have a drink. Bastard? That's fine. Bastardly fine. +Now, come on." +And they pulled Porcupine and me away. These two fellows really had come +to the lavatory, but soaked as they were, in booze bubbles, they +apparently forgot to proceed to their original destination, and were +pulling us hard. All booze fighters seem to be attracted by whatever +comes directly under their eyes for the moment and forget what they had +been proposing to do. +"Say, fellows, we've got bastards. Make them drink. Get them loaded. You +gents got to stay here." +And they pushed me who never attempted to escape against the wall. +Surveying the scene, I found there was no dish in which any edibles were +left. Some one had eaten all his share, and gone on a foraging +expedition. The principal was not there,--I did not know when he left. +At that time, preceded by a coquetish voice, three or four geishas +entered the room. I was a bit surprised, but having been pushed against +the wall, I had to look on quietly. At the instant, Red Shirt who had +been leaning against a pillar with the same old amber pipe stuck into +his mouth with some pride, suddenly got up and started to leave the +room. One of the geishas who was advancing toward him smiled and +courtesied at him as she passed by him. The geisha was the youngest and +prettiest of the bunch. They were some distance away from me and I could +not see very well, but it seemed that she might have said "Good +evening." Red Shirt brushed past as if unconscious, and never showed +again. Probably he followed the principal. +The sight of the geishas set the room immediately in a buzz and it +became noisy as they all raised howls of welcome. Some started the game +of "nanko" with a force that beat the sword-drawing practice. Others +began playing morra, and the way they shook their hands, intently +absorbed in the game, was a better spectacle than a puppet show. +One in the corner was calling "Hey, serve me here," but shaking the +bottle, corrected it to "Hey, fetch me more sake." The whole room +became so infernally noisy that I could scarcely stand it. Amid this +orgy, one, like a fish out of water, sat down with his head bowed. It +was Hubbard Squash. The reason they have held this farewell dinner +party was not in order to bid him a farewell, but because they wanted +to have a jolly good time for themselves with John Barleycorn. He had +come to suffer only. Such a dinner party would have been better had it +not been started at all. +After a while, they began singing ditties in outlandish voices. One of +the geishas came in front of me, and taking up a samisen, asked me to +sing something. I told her I didn't sing, but I'd like to hear, and she +droned out: +"If one can go round and meet the one he wants, banging gongs and drums +...... bang, bang, bang, bang, bing, shouting after wandering Santaro, +there is some one I'd like to meet by banging round gongs and drums +...... bang, bang, bang, bang, b-i-n-g." +She dashed this off in two breaths, and sighed, "O, dear!" She should +have sung something easier. +Clown who had come near us meanwhile, remarked in his flippant tone: +"Hello, dear Miss Su-chan, too bad to see your beau go away so soon." +The geisha pouted, "I don't know." Clown, regardless, began imitating +"gidayu" with a dismal voice,--"What a luck, when she met her sweet +heart by a rare chance...." +The geisha slapped the lap of Clown with a "Cut that out," and Clown +gleefully laughed. This geisha is the one who made goo-goo eyes[J] at +Red Shirt. What a simpleton, to be pleased by the slap of a geisha, this +Clown. He said: +"Say, Su-chan, strike up the string. I'm going to dance the Kiino-kuni." +He seemed yet to dance. +On other side of the room, the old man of Confucius, twisting round his +toothless mouth, had finished as far as "...... dear Dembei-san" and is +asking a geisha who sat in front of him to couch him for the rest. Old +people seem to need polishing up their memorizing system. One geisha is +talking to the teacher of natural history: +"Here's the latest. I'll sing it. Just listen. 'Margaret, the +high-collared head with a white ribbon; she rides on a bike, plays a +violin, and talks in broken English,--I am glad to see you.'" Natural +history appears impressed, and says; +"That's an interesting piece. English in it too." +Porcupine called "geisha, geisha," in a loud voice, and commanded; "Bang +your samisen; I'm going to dance a sword-dance." +His manner was so rough that the geishas were startled and did not +answer. Porcupine, unconcerned, brought out a cane, and began performing +the sword-dance in the center of the room. Then Clown, having danced the +Kii-no-kuni, the Kap-pore[K] and the Durhma-san on the Shelf, almost +stark-naked, with a palm-fibre broom, began turkey-trotting about the +room, shouting "The Sino-Japanese negotiations came to a break......." +The whole was a crazy sight. +I had been feeling sorry for Hubbard Squash, who up to this time had sat +up straight in his full dress. Even were this a farewell dinner held in +his honor, I thought he was under no obligation to look patiently in a +formal dress at the naked dance. So I went to him and persuaded him with +"Say, Koga-san, let's go home." Hubbard Squash said the dinner was in +his honor, and it would be improper for him to leave the room before the +guests. He seemed to be determined to remain. +"What do you care!" I said, "If this is a farewell dinner, make it like +one. Look at those fellows; they're just like the inmates of a lunatic +asylum. Let's go." +And having forced hesitating Hubbard Squash to his feet, we were +just leaving the room, when Clown, marching past, brandishing the +broom, saw us. +"This won't do for the guest of honor to leave before us," he hollered, +"this is the Sino-Japanese negotiations. Can't let you off." He enforced +his declaration by holding the broom across our way. My temper had been +pretty well aroused for some time, and I felt impatient. +"The Sino-Japanese negotiation, eh? Then you're a Chink," and I whacked +his head with a knotty fist. +This sudden blow left Clown staring blankly speechless for a second or +two; then he stammered out: +"This is going some! Mighty pity to knock my head. What a blow on this +Yoshikawa! This makes the Sino-Japanese negotiations the sure stuff." +While Clown was mumbling these incoherent remarks, Porcupine, believing +some kind of row had been started, ceased his sword-dance and came +running toward us. On seeing us, he grabbed the neck of Clown and +pulled him back. +"The Sino-Japane......ouch!......ouch! This is outrageous," and Clown +writhed under the grip of Porcupine who twisted him sideways and threw +him down on the floor with a bang. I do not know the rest. I parted from +Hubbard Squash on the way, and it was past eleven when I returned home. +CHAPTER X. +The town is going to celebrate a Japanese victory to-day, and there is +no school. The celebration is to be held at the parade ground, and +Badger is to take out all the students and attend the ceremony. As one +of the instructors, I am to go with them. The streets are everywhere +draped with flapping national flags almost enough to dazzle the eyes. +There were as many as eight hundred students in all, and it was +arranged, under the direction of the teacher of physical culture to +divide them into sections with one teacher or two to lead them. The +arrangement itself was quite commendable, but in its actual operation +the whole thing went wrong. All students are mere kiddies who, ever too +fresh, regard it as beneath their dignity not to break all regulations. +This rendered the provision of teachers among them practically useless. +They would start marching songs without being told to, and if they +ceased the marching songs, they would raise devilish shouts without +cause. Their behavior would have done credit to the gang of tramps +parading the streets demanding work. When they neither sing nor shout, +they tee-hee and giggle. Why they cannot walk without these disorder, +passes my understanding, but all Japanese are born with their mouths +stuck out, and no kick will ever be strong enough to stop it. Their +chatter is not only of simple nature, but about the teachers when their +back is turned. What a degraded bunch! I made the students apologize to +me on the dormitory affair, and considered the incident closed. But I +was mistaken. To borrow the words of the old lady in the boarding house, +I was surely wrong Mr. Wright. The apology they offered was not prompted +by repentance in their hearts. They had kowtowed as a matter of form by +the command of the principal. Like the tradespeople who bow their heads +low but never give up cheating the public, the students apologize but +never stop their mischiefs. Society is made up, I think it probable, of +people just like those students. One may be branded foolishly honest if +he takes seriously the apologies others might offer. We should regard +all apologies a sham and forgiving also as a sham; then everything would +be all right. If one wants to make another apologize from his heart, he +has to pound him good and strong until he begs for mercy from his heart. +As I walked along between the sections, I could hear constantly the +voices mentioning "tempura" or "dango." And as there were so many of +them, I could not tell which one mentioned it. Even if I succeeded in +collaring the guilty one I was sure of his saying, "No, I didn't mean +you in saying tempura or dango. I fear you suffer from nervousness and +make wrong inferences." This dastardly spirit has been fostered from the +time of the feudal lords, and is deep-rooted. No amount of teaching or +lecturing will cure it. If I stay in a town like this for one year or +so, I may be compelled to follow their example, who knows,--clean and +honest though I have been. I do not propose to make a fool of myself by +remaining quiet when others attempt to play games on me, with all their +excuses ready-made. They are men and so am I--students or kiddies or +whatever they may be. They are bigger than I, and unless I get even with +them by punishment, I would cut a sorry figure. But in the attempt to +get even, if I resort to ordinary means, they are sure to make it a +boomerang. If I tell them, "You're wrong," they will start an eloquent +defence, because they are never short of the means of sidestepping. +Having defended themselves, and made themselves appear suffering +martyrs, they would begin attacking me. As the incident would have been +started by my attempting to get even with them, my defence would not be +a defence until I can prove their wrong. So the quarrel, which they had +started, might be mistaken, after all, as one begun by me. But the more +I keep silent the more they would become insolent, which, speaking +seriously, could not be permitted for the sake of public morale. In +consequence, I am obliged to adopt an identical policy so they cannot +catch men in playing it back on them. If the situation comes to that, it +would be the last day of the Yedo kid. Even so, if I am to be subjected +to these pin-pricking[L] tricks, I am a man and got to risk losing off +the last remnant of the honor of the Yedo kid. I became more convinced +of the advisability of returning to Tokyo quickly and living with Kiyo. +To live long in such a countrytown would be like degrading myself for a +purpose. Newspaper delivering would be preferable to being degraded so +far as that. +I walked along with a sinking heart, thinking like this, when the head +of our procession became suddenly noisy, and the whole came to a full +stop. I thought something has happened, stepped to the right out of the +ranks, and looked toward the direction of the noise. There on the corner +of Otemachi, turning to Yakushimachi, I saw a mass packed full like +canned sardines, alternately pushing back and forth. The teacher of +physical culture came down the line hoarsely shouting to all to be +quiet. I asked him what was the matter, and he said the middle school +and the normal had come to a clash at the corner. +The middle school and the normal, I understood, are as much friendly as +dogs and monkeys. It is not explained why but their temper was +hopelessly crossed, and each would try to knock the chip off the +shoulder of the other on all occasions. I presume they quarrel so much +because life gets monotonous in this backwoods town. I am fond of +fighting, and hearing of the clash, darted forward to make the most of +the fun. Those foremost in the line are jeering, "Get out of the way, +you country tax!"[12] while those in the rear are hollowing "Push them +out!" I passed through the students, and was nearing the corner, when I +heard a sharp command of "Forward!" and the line of the normal school +began marching on. The clash which had resulted from contending for the +right of way was settled, but it was settled by the middle school giving +way to the normal. From the point of school-standing the normal is said +to rank above the middle. +[Footnote 12: The normal school in the province maintains the students +mostly on the advance-expense system, supported by the country tax.] +The ceremony was quite simple. The commander of the local brigade read a +congratulatory address, and so did the governor, and the audience +shouted banzais. That was all. The entertainments were scheduled for the +afternoon, and I returned home once and started writing to Kiyo an +answer which had been in my mind for some days. Her request had been +that I should write her a letter with more detailed news; so I must get +it done with care. But as I took up the rolled letter-paper, I did not +know with what I should begin, though I have many things to write about. +Should I begin with that? That is too much trouble. Or with this? It is +not interesting. Isn't there something which will come out smoothly, I +reflected, without taxing my head too much, and which will interest +Kiyo. There seemed, however, no such item as I wanted I grated the +ink-cake, wetted the writing brush, stared at the letter-paper--stared +at the letter-paper, wetted the writing brush, grated the ink-cake--and, +having repeated the same thing several times, I gave up the letter +writing as not in my line, and covered the lid of the stationery box. To +write a letter was a bother. It would be much simpler to go back to +Tokyo and see Kiyo. Not that I am unconcerned about the anxiety of Kiyo, +but to get up a letter to please the fancy of Kiyo is a harder job than +to fast for three weeks. +I threw down the brush and letter-paper, and lying down with my bent +arms as a pillow, gazed at the garden. But the thought of the letter to +Kiyo would come back in my mind. Then I thought this way; If I am +thinking of her from my heart, even at such a distance, my sincerity +would find responsive appreciation in Kiyo. If it does find response, +there is no need of sending letters. She will regard the absence of +letters from me as a sign of my being in good health. If I write in case +of illness or when something unusual happens, that will be sufficient. +The garden is about thirty feet square, with no particular plants worthy +of name. There is one orange tree which is so tall as to be seen above +the board fence from outside. Whenever I returned from the school I used +to look at this orange tree. For to those who had not been outside of +Tokyo, oranges on the tree are rather a novel sight. Those oranges now +green will ripen by degrees and turn to yellow, when the tree would +surely be beautiful. There are some already ripened. The old lady told +me that they are juicy, sweet oranges. "They will all soon be ripe, and +then help yourself to all you want," she said. I think I will enjoy a +few every day. They will be just right in about three weeks. I do not +think I will have to leave the town in so short a time as three weeks. +While my attention was centered on the oranges, Porcupine[M] came in. +"Say, to-day being the celebration[N] of victory, I thought I would get +something good to eat with you, and bought some beef." +So saying, he took out a package covered with a bamboo-wrapper, and +threw it down in the center of the room. I had been denied the pleasure +of patronizing the noodle house or dango shop, on top of getting sick of +the sweet potatoes and tofu, and I welcomed the suggestion with "That's +fine," and began cooking it with a frying pan and some sugar borrowed +from the old lady. +Porcupine, munching the beef to the full capacity of his mouth, asked me +if I knew Red Shirt having a favorite geisha. I asked if that was not +one of the geishas who came to our dinner the other night, and he +answered, "Yes, I got the wind of the fact only recently; you're sharp." +"Red Shirt always speaks of refinement of character or of mental +consolation, but he is making a fool of himself by chasing round a +geisha. What a dandy rogue. We might let that go if he wouldn't make +fuss about others making fools of themselves. I understand through the +principal he stopped your going even to noodle houses or dango shops as +unbecoming to the dignity of the school, didn't he?" +"According to his idea, running after a geisha is a mental consolation +but tempura or dango is a material pleasure, I guess. If that's mental +consolation, why doesn't the fool do it above board? You ought to see +the jacknape skipping out of the room when the geisha came into it the +other night,--I don't like his trying to deceive us, but if one were to +point it out for him, he would deny it or say it was the Russian +literature or that the haiku is a half-brother of the new poetry, and +expect to hush it up by twaddling soft nonsense. A weak-knee like him is +not a man. I believe he lived the life of a court-maid in former life. +Perhaps his daddy might have been a kagema at Yushima in old days." +"What is a kagema?" +"I suppose something very unmanly,--sort of emasculated chaps. Say, that +part isn't cooked enough. It might give you tape worm." +"So? I think it's all right. And, say, Red Shirt is said to frequent +Kadoya at the springs town and meet his geisha there, but he keeps +it in dark." +"Kadoya? That hotel?" +"Also a restaurant. So we've got to catch him there with his geisha and +make it hot for him right to his face." +"Catch him there? Suppose we begin a kind of night watch?" +"Yes, you know there is a rooming house called Masuya in front of +Kadoya. We'll rent one room upstairs of the house, and keep peeping +through a loophole we could make in the shoji." +"Will he come when we keep peeping at him?" +"He may. We will have to do it more than one night. Must expect to keep +it up for at least two weeks." +"Say, that would make one pretty well tired, I tell you. I sat up every +night for about one week attending my father when he died, and it left +me thoroughly down and out for some time afterward." +"I don't care if I do get tired some. A crook like Red Shirt should not +go unpunished that way for the honor of Japan, and I am going to +administer a chastisement in behalf of heaven." +"Hooray! If things are decided upon that way, I am game. And we are +going to start from to-night?" +"I haven't rented a room at Masuya yet, so can't start it to-night." +"Then when?" +"Will start before long. I'll let you know, and want you help me." +"Right-O. I will help you any time. I am not much myself at scheming, +but I am IT when it comes to fighting." +While Porcupine and I were discussing the plan of subjugating Red Shirt, +the old lady appeared at the door, announcing that a student was wanting +to see Professor Hotta. The student had gone to his house, but seeing +him out, had come here as probable to find him. Porcupine went to the +front door himself, and returning to the room after a while, said: +"Say, the boy came to invite us to go and see the entertainment of the +celebration. He says there is a big bunch of dancers from Kochi to dance +something, and it would be a long time before we could see the like of +it again. Let's go." +Porcupine seemed enthusiastic over the prospect of seeing that dance, +and induced me to go with him. I have seen many kinds of dance in Tokyo. +At the annual festival of the Hachiman Shrine, moving stages come around +the district, and I have seen the Shiokukmi and almost any other +variety. I was little inclined to see that dance by the sturdy fellows +from Tosa province, but as Porcupine was so insistent, I changed my mind +and followed him out. I did not know the student who came to invite +Porcupine, but found he was the younger brother of Red Shirt. Of all +students, what a strange choice for a messenger! +The celebration ground was decorated, like the wrestling amphitheater at +Ryogoku during the season, or the annual festivity of the Hommonji +temple, with long banners planted here and there, and on the ropes that +crossed and recrossed in the mid-air were strung the colors of all +nations, as if they were borrowed from as many nations for the occasion +and the large roof presented unusually cheerful aspect. On the eastern +corner there was built a temporary stage upon which the dance of Koehi +was to be performed. For about half a block, with the stage on the +right, there was a display of flowers and plant settings arranged on +shelves sheltered with reed screens. Everybody was looking at the +display seemingly much impressed, but it failed to impress me. If +twisted grasses or bamboos afforded so much pleasure, the gallantry of a +hunchback or the husband of a wrong pair should give as much pleasure to +their eyes. +In the opposite direction, aerial bombs and fire works were steadily +going on. A balloon shot out on which was written "Long Live the +Empire!" It floated leisurely over the pine trees near the castle +tower, and fell down inside the compound of the barracks. Bang! A black +ball shot up against the serene autumn sky; burst open straight above +my head, streams of luminous green smoke ran down in an umbrella-shape, +and finally faded. Then another balloon. It was red with "Long Live the +Army and Navy" in white. The wind slowly carried it from the town +toward the Aioi village. Probably it would fall into the yard of Kwanon +temple there. +At the formal celebration this morning there were not quite so many as +here now. It was surging mass that made me wonder how so many people +lived in the place. There were not many attractive faces among the +crowd, but as far as the numerical strength went, it was a formidable +one. In the meantime that dance had begun. I took it for granted that +since they call it a dance, it would be something similar to the kind of +dance by the Fujita troupe, but I was greatly mistaken. +Thirty fellows, dressed up in a martial style, in three rows of ten +each, stood with glittering drawn swords. The sight was an eye-opener, +indeed. The space between the rows measured about two feet, and that +between the men might have been even less. One stood apart from the +group. He was similarly dressed but instead of a drawn sword, he carried +a drum hung about his chest. This fellow drawled out signals the tone of +which suggested a mighty easy-life, and then croaking a strange song, he +would strike the drum. The tune was outlandishly unfamiliar. One might +form the idea by thinking it a combination of the Mikawa Banzai and the +Fudarakuya. +The song was drowsy, and like syrup in summer is dangling and slovenly. +He struck the drum to make stops at certain intervals. The tune was kept +with regular rhythmical order, though it appeared to have neither head +nor tail. In response to this tune, the thirty drawn swords flash, with +such dexterity and speed that the sight made the spectator almost +shudder. With live men within two feet of their position, the sharp +drawn blades, each flashing them in the same manner, they looked as if +they might make a bloody mess unless they were perfectly accurate in +their movements. If it had been brandishing swords alone without moving +themselves, the chances of getting slashed or cut might have been less, +but sometimes they would turn sideways together, or clear around, or +bend their knees. Just one second's difference in the movement, either +too quick or too late, on the part of the next fellow, might have meant +sloughing off a nose or slicing off the head of the next fellow. The +drawn swords moved in perfect freedom, but the sphere of action was +limited to about two feet square, and to cap it all, each had to keep +moving with those in front and back, at right and left, in the same +direction at the same speed. This beats me! The dance of the Shiokumi or +the Sekinoto would make no show compared with this! I heard them say the +dance requires much training, and it could not be an easy matter to make +so many dancers move in a unison like this. Particularly difficult part +in the dance was that of the fellow with drum stuck to his chest. The +movement of feet, action of hands, or bending of knees of those thirty +fellows were entirely directed by the tune with which he kept them +going. To the spectators this fellow's part appeared the easiest. He +sang in a lazy tune, but it was strange that he was the fellow who takes +the heaviest responsibility. +While Porcupine and I, deeply impressed, were looking at the dance with +absorbing interest, a sudden hue and cry was raised about half a block +off. A commotion was started among those who had been quietly enjoying +the sights and all ran pell-mell in every direction. Some one was heard +saying "fight!" Then the younger brother of Red Shirt came running +forward through the crowd. +"Please, Sir," he panted, "a row again! The middles are going to get +even with the normals and have just begun fighting. Come quick, Sir!" +And he melted somewhere into the crowd. +"What troublesome brats! So they're at it again, eh? Why can't +they stop it!" +Porcupine, as he spoke, dashed forward, dodging among the running crowd. +He meant, I think, to stop the fight, because he could not be an idle +spectator once he was informed of the fact. I of course had no intention +of turning tail, and hastened on the heels of Porcupine. The fight was +in its fiercest. There were about fifty to sixty normals, and the +middles numbered by some ninety. The normals wore uniform, but the +middles had discarded their uniform and put on Japanese civilian +clothes, which made the distinction between the two hostile camps easy. +But they were so mixed up, and wrangling with such violence, that we did +not know how and where we could separate them. +Porcupine, apparently at a loss what to do, looked at the wild scene +awhile, then turned to me, saying: +"Let's jump in and separate them. It will be hell if cops get on them." +I did not answer, but rushed to the spot where the scuffle appeared +most violent. +"Stop there! Cut this out! You're ruining the name of the school! Stop +this, dash you!" +Shouting at the top of my voice, I attempted to penetrate the line which +seemed to separate the hostile sides, but this attempt did not succeed. +When about ten feet into the turmoil, I could neither advance nor +retreat. Right in my front, a comparatively large normal was grappling +with a middle about sixteen years of ago. +"Stop that!" +I grabbed the shoulder of the normal and tried to force them apart when +some one whacked my feet. On this sudden attack, I let go the normal and +fell down sideways. Some one stepped on my back with heavy shoes. With +both hands and knees upon the ground, I jumped up and the fellow on my +back rolled off to my right. I got up, and saw the big body of Porcupine +about twenty feet away, sandwiched between the students, being pushed +back and forth, shouting, "Stop the fight! Stop that!" +"Say, we can't do anything!" I hollered at him, but unable to hear, I +think, he did not answer. +A pebble-stone whiffled through the air and hit squarely on my cheek +bone; the same moment some one banged my back with a heavy stick +from behind. +"Profs mixing in!" "Knock them down!" was shouted. +"Two of them; big one and small. Throw stones at them!" Another shout. +"Drat you fresh jackanapes!" I cried as I wallopped the head of a normal +nearby. Another stone grazed my head, and passed behind me. I did not +know what had become of Porcupine, I could not find him. Well, I could +not help it but jumped into the teapot to stop the tempest. I wasn't[O] +a Hottentot to skulk away on being shot at with pebble-stones. What did +they think I was anyway! I've been through all kinds of fighting in +Tokyo, and can take in all fights one may care to give me. I slugged, +jabbed and banged the stuffing out of the fellow nearest to me. Then +some one cried, "Cops! Cops! Cheese it! Beat it!" At that moment, as if +wading through a pond of molasses, I could hardly move, but the next I +felt suddenly released and both sides scampered off simultaneously. Even +the country fellows do creditable work when it comes to retreating, more +masterly than General Kuropatkin, I might say. +I searched for Porcupine who, I found his overgown torn to shreds, was +wiping his nose. He bled considerably, and his nose having swollen was a +sight. My clothes were pretty well massed with dirt, but I had not +suffered quite as much damage as Porcupine. I felt pain in my cheek and +as Porcupine said, it bled some. +About sixteen police officers arrived at the scene but, all the students +having beat it in opposite directions, all they were able to catch were +Porcupine and me. We gave them our names and explained the whole story. +The officers requested us to follow them to the police station which we +did, and after stating to the chief of police what had happened, we +returned home. +CHAPTER XI. +The next morning on awakening I felt pains all over my body, due, I +thought, to having had no fight for a long time. This is not creditable +to my fame as regards fighting, so I thought while in bed, when the old +lady brought me a copy of the Shikoku Shimbun. I felt so weak as to need +some effort even reaching for the paper. But what should be man so +easily upset by such a trifling affair,--so I forced myself to turn in +bed, and, opening its second page, I was surprised. There was the whole +story of the fight of yesterday in print. Not that I was surprised by +the news of the fight having been published, but it said that one +teacher Hotta of the Middle School and one certain saucy Somebody, +recently from Tokyo, of the same institution, not only started this +trouble by inciting the students, but were actually present at the scene +of the trouble, directing the students and engaged themselves against +the students of the Normal School. On top of this, something of the +following effect was added. +"The Middle School in this prefecture has been an object of admiration +by all other schools for its good and ideal behavior. But since this +long-cherished honor has been sullied by these two irresponsible +persons, and this city made to suffer the consequent indignity, we have +to bring the perpetrators to full account. We trust that before we take +any step in this matter, the authorities will have those 'toughs' +properly punished, barring them forever from our educational circles." +All the types were italicized, as if they meant to administer +typographical chastisement upon us. "What the devil do I care!" I +shouted, and up I jumped out of bed. Strange to say, the pain in my +joints became tolerable. +I rolled up the newspaper and threw it into the garden. Not satisfied, I +took that paper to the cesspool and dumped it there. Newspapers tell +such reckless lies. There is nothing so adept, I believe, as the +newspaper in circulating lies. It has said what I should have said. And +what does it mean by "one saucy Somebody who is recently from Tokyo?" Is +there any one in this wide world with the name of Somebody? Don't +forget, I have a family and personal name of my own which I am proud of. +If they want to look at my family-record, they will bow before every one +of my ancestors from Mitsunaka Tada down. Having washed my face, my +cheek began suddenly smarting. I asked the old lady for a mirror, and +she asked if I had read the paper of this morning. "Yes," I said, "and +dumped it in the cesspool; go and pick it up if you want it,"--and she +withdrew with a startled look. Looking in the mirror, I saw bruises on +my cheek. Mine is a precious face to me. I get my face bruised, and am +called a saucy Somebody as if I were nobody. That is enough. +It will be a reflection on my honor to the end of my days if it is said +that I shunned the public gaze and kept out of the school on account of +the write-up in the paper. So, after the breakfast, I attended the +school ahead of all. One after the other, all coming to the school would +grin at my face. What is there to laugh about! This face is my own, +gotten up, I am sure, without the least obligation on their part. By and +by, Clown appeared. +"Ha, heroic action yesterday. Wounds of honor, eh?" +He made this sarcastic remark, I suppose, in revenge for the knock he +received on his head from me at the farewell dinner. +"Cut out nonsense; you get back there and suck your old drawing +brushes!" Then he answered "that was going some," and enquired if it +pained much? +"Pain or no pain, this is my face. That's none of your business," I +snapped back in a furious temper. Then Clown took his seat on the other +side, and still keeping his eye on me, whispered and laughed with the +teacher of history next to him. +Then came Porcupine. His nose had swollen and was purple,--it was a +tempting object for a surgeon's knife. His face showed far worse (is it +my conceit that make this comparison?) than mine. I and Porcupine are +chums with desks next to each other, and moreover, as ill-luck would +have it, the desks are placed right facing the door. Thus were two +strange faces placed together. The other fellows, when in want of +something to divert them, would gaze our way with regularity. They say +"too bad," but they are surely laughing in their minds as "ha, these +fools!" If that is not so, there is no reason for their whispering +together and grinning like that. In the class room, the boys clapped +their hands when I entered; two or three of them banzaied. I could not +tell whether it was an enthusiastic approval or open insult. While I and +Porcupine were thus being made the cynosures of the whole school, Red +Shirt came to me as usual. +"Too bad, my friend; I am very sorry indeed for you gentlemen," he said +in a semi-apologetic manner. "I've talked with the principal in regard +to the story in the paper, and have arranged to demand that the paper +retract the report, so you needn't worry on that score. You were plunged +into the trouble because my brother invited Mr. Hotta, and I don't know +how I can apologize you! I'm going to do my level best in this matter; +you gentlemen please depend on that." At the third hour recess the +principal came out of his room, and seemed more or less perturbed, +saying, "The paper made a bad mess of it, didn't it? I hope the matter +will not become serious." +As to anxiety, I have none. If they propose to relieve me, I intend +to tender my resignation before I get fired,--that's all. However, if +I resign with no fault on my part, I would be simply giving the paper +advantage. I thought it proper to make the paper take back what it +had said, and stick to my position. I was going to the newspaper +office to give them a piece of my mind on my way back but having been +told that the school had already taken steps to have the story +retracted, I did not. +Porcupine and I saw the principal and Red Shirt at a convenient hour, +giving them a faithful version of the incident. The principal and Red +Shirt agreed that the incident must have been as we said and that the +paper bore some grudge against the school and purposely published such a +story. Red Shirt made a round of personal visits on each teacher in the +room, defending and explaining our action in the affair. Particularly he +dwelt upon the fact that his brother invited Porcupine and it was his +fault. All teachers denounced the paper as infamous and agreed that we +two deserved sympathy. +On our way home, Porcupine warned me that Red Shirt smelt suspicious, +and we would be done unless we looked out. I said he had been smelling +some anyway,--it was not necessarily so just from to-day. Then he said +that it was his trick to have us invited and mixed in the fight +yesterday,--"Aren't you on to that yet?" Well, I was not. Porcupine was +quite a Grobian but he was endowed, I was impressed, with a better +brain than I. +"He made us mix into the trouble, and slipped behind and contrived to +have the paper publish the story. What a devil!" +"Even the newspaper in the band wagon of Red Shirt? That surprises me. +But would the paper listen to Red Shirt so easily?" +"Wouldn't it, though. Darn easy thing if one has friends in the +paper."[P] +"Has he any?" +"Suppose he hasn't, still that's easy. Just tell lies and say such and +such are facts, and the paper will take it up." +"A startling revelation, this. If that was really a trick of Red Shirt, +we're likely to be discharged on account of this affair." +"Quite likely we may be discharged." +"Then I'll tender my resignation tomorrow, and back to Tokyo I go. I am +sick of staying in such a wretched hole." +"Your resignation wouldn't make Red Shirt squeal." +"That's so. How can he be made to squeal?" +"A wily guy like him always plots not to leave any trace behind, and it +would be difficult to follow his track." +"What a bore! Then we have to stand in a false light, eh? Damn it! I +call all kinds of god to witness if this is just and right!" +"Let's wait for two or three days and see how it turns out. And if +we can't do anything else, we will have to catch him at the hot +springs town." +"Leaving this fight affair a separate case?" +"Yes. We'll have to his hit weak spot with our own weapon." +"That may be good. I haven't much to say in planning it out; I leave it +to you and will do anything at your bidding." +I parted from Porcupine then. If Red Shirt was really instrumental in +bringing us two into the trouble as Porcupine supposed, he certainly +deserves to be called down. Red Shirt outranks us in brainy work. And +there is no other course open but to appeal to physical force. No wonder +we never see the end of war in the world. Among individuals, it is, +after all, the question of superiority of the fist. +Next day I impatiently glanced over the paper, the arrival of which I +had been waiting with eagerness, but not a correction of the news or +even a line of retraction could be found. I pressed the matter on +Badger when I went to the school, and he said it might probably appear +tomorrow. On that "tomorrow" a line of retraction was printed in tiny +types. But the paper did not make any correction of the story. I called +the attention of Badger to the fact, and he replied that that was about +all that could be done under the circumstance. The principal, with the +face like a badger and always swaggering, is surprisingly, wanting in +influence. He has not even as much power as to bring down a country +newspaper, which had printed a false story. I was so thoroughly +indignant that I declared I would go alone to the office and see the +editor-in-chief on the subject, but Badger said no. +"If you go there and have a blowup with the editor," he continued, "it +would only mean of your being handed out worse stuff in the paper again. +Whatever is published in a paper, right or wrong, nothing can be done +with it." And he wound up with a remark that sounded like a piece of +sermon by a Buddhist bonze that "We must be contented by speedily +despatching the matter from our minds and forgetting it." +If newspapers are of that character, it would be beneficial for us all +to have them suspended,--the sooner the better. The similarity of the +unpleasant sensation of being written-up in a paper and being +bitten-down by a turtle became plain for the first time by the +explanation of Badger. +About three days afterward, Porcupine came to me excited, and said that +the time has now come, that he proposes to execute that thing we had +planned out. Then I will do so, I said, and readily agreed to join him. +But Porcupine jerked his head, saying that I had better not. I asked him +why, and he asked if I had been requested by the principal to tender my +resignation. No, I said, and asked if he had. He told me that he was +called by the principal who was very, very sorry for him but under the +circumstance requested him to decide to resign. +"That isn't fair. Badger probably had been pounding his belly-drum too +much and his stomach is upside down," I said, "you and I went to the +celebration, looked at the glittering sword dance together, and jumped +into the fight together to stop it. Wasn't it so? If he wants you to +tender your resignation, he should be impartial and should have asked me +to also. What makes everything in the country school so dull-head. This +is irritating!" +"That's wire-pulling by Red Shirt," he said. "I and Red Shirt cannot go +along together, but they think you can be left as harmless." +"I wouldn't get along with that Red Shirt either. Consider me harmless, +eh? They're getting too gay with me." +"You're so simple and straight that they think they can handle you in +any old way." +"Worse still. I wouldn't get along with him, I tell you." +"Besides, since the departure of Koga, his successor has not arrived. +Furthermore, if they fire me and you together, there will be blank spots +in the schedule hours at the school." +"Then they expect me to play their game. Darn the fellow! See if they +can make me." +On going to the school next day I made straightway for the room of the +principal and started firing; +"Why don't you ask me to put in my resignation?" I said. +"Eh?" Badger stared blankly. +"You requested Hotta to resign, but not me. Is that right?" +"That is on account of the condition of the school......" +"That condition is wrong, I dare say. If I don't have to resign, there +should be no necessity for Hotta to resign either." +"I can't offer a detailed explanation about that......as to Hotta, it +cannot be helped if he goes...... ......we see no need of your +resigning." +Indeed, he is a badger. He jabbers something, dodging the point, but +appears complacent. So I had to say: +"Then, I will tender my resignation. You might have thought that I +would remain peacefully while Mr. Hotta is forced to resign, but I +cannot do it" +"That leaves us in a bad fix. If Hotta goes away and you follow him, we +can't teach mathematics here." +"None of my business if you can't." +"Say, don't be so selfish. You ought to consider the condition of the +school. Besides, if it is said that you resigned within one month of +starting a new job, it would affect your record in the future. You +should consider that point also." +"What do I care about my record. Obligation is more important +than record." +"That's right. What you say is right, but be good enough to take our +position into consideration. If you insist on resigning, then resign, +but please stay until we get some one to take your place. At any rate, +think the matter over once more, please." +The reason was so plain as to discourage any attempt to think it over, +but as I took some pity on Badger whose face reddened or paled +alternately as he spoke, I withdrew on the condition that I would think +the matter over. I did not talk with Red Shirt. If I have to land him +one, it was better, I thought, to have it bunched together and make it +hot and strong. +I acquainted Porcupine with the details of my meeting with Badger. He +said he had expected it to be about so, and added that the matter of +resignation can be left alone without causing me any embarrassment +until the time comes. So I followed his advice. Porcupine appears +somewhat smarter than I, and I have decided to accept whatever advices +he may give. +Porcupine finally tendered his resignation, and having bidden farewell +of all the fellow teachers, went down to Minato-ya on the beach. But he +stealthily returned to the hot springs town, and having rented a front +room upstairs of Masuya, started peeping through the hole he fingered +out in the shoji. I am the only person who knows of this. If Red Shirt +comes round, it would be night anyway, and as he is liable to be seen by +students or some others during the early part in the evening, it would +surely be after nine. For the first two nights, I was on the watch till +about 11 o'clock, but no sight of Red Shirt was seen. On the third +night, I kept peeping through from nine to ten thirty, but he did not +come. Nothing made me feel more like a fool than returning to the +boarding house at midnight after a fruitless watch. In four or five +days, our old lady began worrying about me and advised me to quit night +prowling,--being married. My night prowling is different from that kind +of night prowling. Mine is that of administering a deserved +chastisement. But then, when no encouragement is in sight after one +week, it becomes tiresome. I am quick tempered, and get at it with all +zeal when my interest is aroused, and would sit up all night to work it +out, but I have never shone in endurance. However loyal a member of the +heavenly-chastisement league I may be, I cannot escape monotony. On the +sixth night I was a little tired, and on the seventh thought I would +quit. Porcupine, however, stuck to it with bull-dog tenacity. From early +in the evening up to past twelve, he would glue his eye to the shoji and +keep steadily watching under the gas globe of Kadoya. He would surprise +me, when I come into the room, with figures showing how many patrons +there were to-day, how many stop-overs and how many women, etc. Red +Shirt seems never to be coming, I said, and he would fold his arms, +audibly sighing, "Well, he ought to." If Red Shirt would not come just +for once, Porcupine would be deprived of the chance of handing out a +deserved and just punishment. +I left my boarding house about 7 o'clock on the eighth night and after +having enjoyed my bath, I bought eight raw eggs. This would counteract +the attack of sweet potatoes by the old lady. I put the eggs into my +right and left pockets, four in each, with the same old red towel hung +over my shoulder, my hands inside my coat, went to Masuya. I opened the +shoji of the room and Porcupine greeted me with his Idaten-like face +suddenly radiant, saying: +"Say, there's hope! There's hope!" Up to last night, he had been +downcast, and even I felt gloomy. But at his cheerful countenance, I too +became cheerful, and before hearing anything, I cried, "Hooray! Hooray!" +"About half past seven this evening," he said, "that geisha named Kosuzu +has gone into Kadoya." +"With Red Shirt?" +"No." +"That's no good then." +"There were two geishas......seems to me somewhat hopeful." +"How?" +"How? Why, the sly old fox is likely to send his girls ahead[Q], and +sneak round behind later." +"That may be the case. About nine now, isn't it?" +"About twelve minutes past nine," said he, pulling out a watch with +a nickel case, "and, say put out the light. It would be funny to +have two silhouettes of bonze heads on the shoji. The fox is too +ready to suspect." +I blew out the lamp which stood upon the lacquer-enameled table. The +shoji alone was dimly plain by the star light. The moon has not come up +yet. I and Porcupine put our faces close to the shoji, watching almost +breathless. A wall clock somewhere rang half past nine. +"Say, will he come to-night, do you think? If he doesn't show up, I +quit." +"I'm going to keep this up while my money lasts." +"Money? How much have you?" +"I've paid five yen and sixty sen up to to-day for eight days. I pay my +bill every night, so I can jump out anytime." +"That's well arranged. The people of this hotel must have been rather +put out, I suppose." +"That's all right with the hotel; only I can't take my mind off +the house." +"But you take some sleep in daytime." +"Yes, I take a nap, but it's nuisance because I can't go out." +"Heavenly chastisement is a hard job, I'm sure," I said. "If he gives +us the slip after giving us such trouble, it would have been a +thankless task." +"Well, I'm sure he will come to-night...--... Look, look!" His voice +changed to whisper and I was alert in a moment. A fellow with a black +hat looked up at the gas light of Kadoya and passed on into the +darkness. No, it was not Red Shirt. Disappointing, this! Meanwhile the +clock at the office below merrily tinkled off ten. It seems to be +another bum watch to-night. +The streets everywhere had become quiet. The drum playing in the +tenderloin reached our ears distinctively. The moon had risen from +behind the hills of the hot springs. It is very light outside. Then +voices were heard below. We could not poke our heads out of the window, +so were unable to see the owners of the voices, but they were evidently +coming nearer. The dragging of komageta (a kind of wooden footwear) was +heard. They approached so near we could see their shadows. +"Everything is all right now. We've got rid of the stumbling block." It +was undoubtedly the voice of Clown. +"He only glories in bullying but has no tact." This from Red Shirt. +"He is like that young tough, isn't he? Why, as to that young tough, he +is a winsome, sporty Master Darling." +"I don't want my salary raised, he says, or I want to tender +resignation,--I'm sure something is wrong with his nerves." +I was greatly inclined to open the window, jump out of the second story +and make them see more stars than they cared to, but I restrained myself +with some effort. The two laughed, and passed below the gas light, and +into Kadoya. +"Say." +"Well." +"He's here." +"Yes, he has come at last." +"I feel quite easy now." +"Damned Clown called me a sporty Master Darling." +"The stumbling[R] block means me. Hell!" +I and Porcupine had to waylay them on their return. But we knew no more +than the man in the moon when they would come out. Porcupine went down +to the hotel office, notifying them to the probability of our going out +at midnight, and requesting them to leave the door unfastened so we +could get out anytime. As I think about it now, it is wonderful how the +hotel people complied with our request. In most cases, we would have +been taken for burglars. +It was trying to wait for the coming of Red Shirt, but it was still more +trying to wait for his coming out again. We could not go to sleep, nor +could we remain with our faces stuck to the shoji all the time our minds +constantly in a state of feverish agitation. In all my life, I never +passed such fretful, mortifying hours. I suggested that we had better go +right into his room and catch him but Porcupine rejected the proposal +outright. If we get in there at this time of night, we are likely to be +prevented from preceding much further, he said, and if we ask to see +him, they will either answer that he is not there or will take us into a +different room. Supposing we do break into a room, we cannot tell of all +those many rooms, where we can find him. There is no other way but to +wait for him to come out, however tiresome it may be. So we sat up till +five in the morning. +The moment we saw them emerging from Kadoya, I and Porcupine followed +them. It was some time before the first train started and they had to +walk up to town. Beyond the limit of the hot springs town, there is a +road for about one block running through the rice fields, both sides of +which are lined with cedar trees. Farther on are thatch-roofed farm +houses here and there, and then one comes upon a dyke leading straight +to the town through the fields. We can catch them anywhere outside the +town, but thinking it would be better to get them, if possible, on the +road lined with cedar trees where we may not be seen by others, we +followed them cautiously. Once out of the town limit, we darted on a +double-quick time, and caught up with them. Wondering what was coming +after them, they turned back, and we grabbed their shoulders. We cried, +"Wait!" Clown, greatly rattled, attempted to escape, but I stepped in +front of him to cut off his retreat. +"What makes one holding the job of a head teacher stay over night at +Kadoya!" Porcupine directly fired the opening gun. +"Is there any rule that a head teacher should not stay over night at +Kadoya?" Red Shirt met the attack in a polite manner. He looked a +little pale. +"Why the one who is so strict as to forbid others from going even to +noodle house or dango shop as unbecoming to instructors, stayed over +night at a hotel with a geisha!" +Clown was inclined to run at the first opportunity; so kept I +before him. +"What's that Master Darling of a young tough!" I roared. +"I didn't mean you. Sir. No, Sir, I didn't mean you, sure." He insisted +on this brazen excuse. I happened to notice at that moment that I had +held my pockets with both hands. The eggs in both pockets jerked so when +I ran, that I had been holding them, I thrust my hand into the pocket, +took out two and dashed them on the face of Clown. The eggs crushed, and +from the tip of his nose the yellow streamed down. Clown was taken +completely surprised, and uttering a hideous cry, he fell down on the +ground and begged for mercy. I had bought those eggs to eat, but had not +carried them for the purpose of making "Irish Confetti" of them. +Thoroughly roused, in the moment of passion, I had dashed them at him +before I knew what I was doing. But seeing Clown down and finding my +hand grenade successful, I banged the rest of the eggs on him, +intermingled with "Darn you, you sonovagun!" The face of Clown was +soaked in yellow. +While I was bombarding Clown with the eggs, Porcupine was firing at +Red[S] Shirt. +"Is there any evidence that I stayed there over night with a geisha?" +"I saw your favorite old chicken go there early in the evening, and am +telling you so. You can't fool me!" +"No need for us of fooling anybody. I stayed there with Mr. Yoshikawa, +and whether any geisha had gone there early in the evening or not, +that's none of my business." +"Shut up!" Porcupine wallopped him one. Red Shirt tottered. +"This is outrageous! It is rough to resort to force before deciding the +right or wrong of it!" +"Outrageous indeed!" Another clout. "Nothing but wallopping will be +effective on you scheming guys." The remark was followed by a shower +of blows. I soaked Clown at the same time, and made him think he saw +the way to the Kingdom-Come. Finally the two crawled and crouched at +the foot of a cedar tree, and either from inability to move or to +see, because their eyes had become hazy, they did not even attempt to +break away. +"Want more? If so, here goes some more!" With that we gave him more +until he cried enough. "Want more? You?" we turned to Clown, and he +answered "Enough, of course." +"This is the punishment of heaven on you grovelling wretches. Keep +this in your head and be more careful hereafter. You can never talk +down justice." +The two said nothing. They were so thoroughly cowed that they could +not speak. +"I'm going to neither, run away nor hide. You'll find me at Minato-ya on +the beach up to five this evening. Bring police officers or any old +thing you want," said Porcupine. +"I'm not going to run away or hide either. Will wait for you at the same +place with Hotta. Take the case to the police station if you like, or do +as you damn please," I said, and we two walked our own way. +It was a little before seven when I returned to my room. I started +packing as soon as I was in the room, and the astonished old lady asked +me what I was trying to do. I'm going to Tokyo to fetch my Madam, I +said, and paid my bill. I boarded a train and came to Minato-ya on the +beach and found Porcupine asleep upstairs. I thought of writing my +resignation, but not knowing how, just scribbled off that "because of +personal affairs, I have to resign and return, to Tokyo. Yours truly," +and addressed and mailed it to the principal. +The steamer leaves the harbor at six in the evening. Porcupine and I, +tired out, slept like logs, and when we awoke it was two o'clock. We +asked the maid if the police had called on us, and she said no. Red +Shirt and Clown had not taken it to the police, eh? We laughed. +That night I and Porcupine left the town. The farther the vessel steamed +away from the shore, the more refreshed we felt. From Kobe to Tokyo we +boarded a through train and when we made Shimbashi, we breathed as if we +were once more in congenial human society. I parted from Porcupine at +the station, and have not had the chance of meeting him since. +I forgot to tell you about Kiyo. On my arrival at Tokyo, I rushed into +her house swinging my valise, before going to a hotel, with "Hello, +Kiyo, I'm back!" +"How good of you to return so soon!" she cried and hot tears streamed +down her cheeks. I was overjoyed, and declared that I would not go to +the country any more but would start housekeeping with Kiyo in Tokyo. +Some time afterward, some one helped me to a job as assistant engineer +at the tram car office. The salary was 25 yen a month, and the house +rent six. Although the house had not a magnificent front entrance, Kiyo +seemed quite satisfied, but, I am sorry to say, she was a victim of +pneumonia and died in February this year. On the day preceding her +death, she asked me to bedside, and said, "Please, Master Darling, if +Kiyo is dead, bury me in the temple yard of Master Darling. I will be +glad to wait in the grave for my Master Darling." +So Kiyo's grave is in the Yogen temple at Kobinata. +--(THE END)-- +[A: Insitent] +[B: queershaped] +[C: The original just had the Japanese character, Unicode U+5927, sans + description] +[D: aweinspiring] +[E: about about] +[F: atomosphere] +[G: Helloo] +[H: you go] +[I: goo-goo eyes] +[J: proper hyphenation unknown] +[K: pin-princking] +[L: Procupine] +[M: celabration] +[N: wans't] +[O: paper.] +[P: girl shead] +[Q: stumblieg] +[R: Rad] +End of Project Gutenberg's Botchan (Master Darling), by Kin-nosuke Natsume +*** END OF THIS PROJECT GUTENBERG EBOOK BOTCHAN (MASTER DARLING) *** +***** This file should be named 8868.txt or 8868.zip ***** +This and all associated files of various formats will be found in: + http://www.gutenberg.org/8/8/6/8868/ +Produced by David Starner and the Online Distributed Proofreading Team +Updated editions will replace the previous one--the old editions +will be renamed. +Creating the works from public domain print editions means that no +one owns a United States copyright in these works, so the Foundation +(and you!) can copy and distribute it in the United States without +permission and without paying copyright royalties. Special rules, +set forth in the General Terms of Use part of this license, apply to +copying and distributing Project Gutenberg-tm electronic works to +protect the PROJECT GUTENBERG-tm concept and trademark. Project +Gutenberg is a registered trademark, and may not be used if you +charge for the eBooks, unless you receive specific permission. If you +do not charge anything for copies of this eBook, complying with the +rules is very easy. You may use this eBook for nearly any purpose +such as creation of derivative works, reports, performances and +research. They may be modified and printed and given away--you may do +practically ANYTHING with public domain eBooks. Redistribution is +subject to the trademark license, especially commercial +redistribution. +*** START: FULL LICENSE *** +THE FULL PROJECT GUTENBERG LICENSE +PLEASE READ THIS BEFORE YOU DISTRIBUTE OR USE THIS WORK +To protect the Project Gutenberg-tm mission of promoting the free +distribution of electronic works, by using or distributing this work +(or any other work associated in any way with the phrase "Project +Gutenberg"), you agree to comply with all the terms of the Full Project +Gutenberg-tm License available with this file or online at + www.gutenberg.org/license. +Section 1. General Terms of Use and Redistributing Project Gutenberg-tm +electronic works +1.A. By reading or using any part of this Project Gutenberg-tm +electronic work, you indicate that you have read, understand, agree to +and accept all the terms of this license and intellectual property +(trademark/copyright) agreement. If you do not agree to abide by all +the terms of this agreement, you must cease using and return or destroy +all copies of Project Gutenberg-tm electronic works in your possession. +If you paid a fee for obtaining a copy of or access to a Project +Gutenberg-tm electronic work and you do not agree to be bound by the +terms of this agreement, you may obtain a refund from the person or +entity to whom you paid the fee as set forth in paragraph 1.E.8. +1.B. "Project Gutenberg" is a registered trademark. It may only be +used on or associated in any way with an electronic work by people who +agree to be bound by the terms of this agreement. There are a few +things that you can do with most Project Gutenberg-tm electronic works +even without complying with the full terms of this agreement. See +paragraph 1.C below. There are a lot of things you can do with Project +Gutenberg-tm electronic works if you follow the terms of this agreement +and help preserve free future access to Project Gutenberg-tm electronic +works. See paragraph 1.E below. +1.C. The Project Gutenberg Literary Archive Foundation ("the Foundation" +or PGLAF), owns a compilation copyright in the collection of Project +Gutenberg-tm electronic works. Nearly all the individual works in the +collection are in the public domain in the United States. If an +individual work is in the public domain in the United States and you are +located in the United States, we do not claim a right to prevent you from +copying, distributing, performing, displaying or creating derivative +works based on the work as long as all references to Project Gutenberg +are removed. Of course, we hope that you will support the Project +Gutenberg-tm mission of promoting free access to electronic works by +freely sharing Project Gutenberg-tm works in compliance with the terms of +this agreement for keeping the Project Gutenberg-tm name associated with +the work. You can easily comply with the terms of this agreement by +keeping this work in the same format with its attached full Project +Gutenberg-tm License when you share it without charge with others. +1.D. The copyright laws of the place where you are located also govern +what you can do with this work. Copyright laws in most countries are in +a constant state of change. If you are outside the United States, check +the laws of your country in addition to the terms of this agreement +before downloading, copying, displaying, performing, distributing or +creating derivative works based on this work or any other Project +Gutenberg-tm work. The Foundation makes no representations concerning +the copyright status of any work in any country outside the United +States. +1.E. Unless you have removed all references to Project Gutenberg: +1.E.1. The following sentence, with active links to, or other immediate +access to, the full Project Gutenberg-tm License must appear prominently +whenever any copy of a Project Gutenberg-tm work (any work on which the +phrase "Project Gutenberg" appears, or with which the phrase "Project +Gutenberg" is associated) is accessed, displayed, performed, viewed, +copied or distributed: +This eBook is for the use of anyone anywhere at no cost and with +almost no restrictions whatsoever. You may copy it, give it away or +re-use it under the terms of the Project Gutenberg License included +with this eBook or online at www.gutenberg.org +1.E.2. If an individual Project Gutenberg-tm electronic work is derived +from the public domain (does not contain a notice indicating that it is +posted with permission of the copyright holder), the work can be copied +and distributed to anyone in the United States without paying any fees +or charges. If you are redistributing or providing access to a work +with the phrase "Project Gutenberg" associated with or appearing on the +work, you must comply either with the requirements of paragraphs 1.E.1 +through 1.E.7 or obtain permission for the use of the work and the +Project Gutenberg-tm trademark as set forth in paragraphs 1.E.8 or +1.E.9. +1.E.3. If an individual Project Gutenberg-tm electronic work is posted +with the permission of the copyright holder, your use and distribution +must comply with both paragraphs 1.E.1 through 1.E.7 and any additional +terms imposed by the copyright holder. Additional terms will be linked +to the Project Gutenberg-tm License for all works posted with the +permission of the copyright holder found at the beginning of this work. +1.E.4. Do not unlink or detach or remove the full Project Gutenberg-tm +License terms from this work, or any files containing a part of this +work or any other work associated with Project Gutenberg-tm. +1.E.5. Do not copy, display, perform, distribute or redistribute this +electronic work, or any part of this electronic work, without +prominently displaying the sentence set forth in paragraph 1.E.1 with +active links or immediate access to the full terms of the Project +Gutenberg-tm License. +1.E.6. You may convert to and distribute this work in any binary, +compressed, marked up, nonproprietary or proprietary form, including any +word processing or hypertext form. However, if you provide access to or +distribute copies of a Project Gutenberg-tm work in a format other than +"Plain Vanilla ASCII" or other format used in the official version +posted on the official Project Gutenberg-tm web site (www.gutenberg.org), +you must, at no additional cost, fee or expense to the user, provide a +copy, a means of exporting a copy, or a means of obtaining a copy upon +request, of the work in its original "Plain Vanilla ASCII" or other +form. Any alternate format must include the full Project Gutenberg-tm +License as specified in paragraph 1.E.1. +1.E.7. Do not charge a fee for access to, viewing, displaying, +performing, copying or distributing any Project Gutenberg-tm works +unless you comply with paragraph 1.E.8 or 1.E.9. +1.E.8. You may charge a reasonable fee for copies of or providing +access to or distributing Project Gutenberg-tm electronic works provided +that +- You pay a royalty fee of 20% of the gross profits you derive from + the use of Project Gutenberg-tm works calculated using the method + you already use to calculate your applicable taxes. The fee is + owed to the owner of the Project Gutenberg-tm trademark, but he + has agreed to donate royalties under this paragraph to the + Project Gutenberg Literary Archive Foundation. Royalty payments + must be paid within 60 days following each date on which you + prepare (or are legally required to prepare) your periodic tax + returns. Royalty payments should be clearly marked as such and + sent to the Project Gutenberg Literary Archive Foundation at the + address specified in Section 4, "Information about donations to + the Project Gutenberg Literary Archive Foundation." +- You provide a full refund of any money paid by a user who notifies + you in writing (or by e-mail) within 30 days of receipt that s/he + does not agree to the terms of the full Project Gutenberg-tm + License. You must require such a user to return or + destroy all copies of the works possessed in a physical medium + and discontinue all use of and all access to other copies of + Project Gutenberg-tm works. +- You provide, in accordance with paragraph 1.F.3, a full refund of any + money paid for a work or a replacement copy, if a defect in the + electronic work is discovered and reported to you within 90 days + of receipt of the work. +- You comply with all other terms of this agreement for free + distribution of Project Gutenberg-tm works. +1.E.9. If you wish to charge a fee or distribute a Project Gutenberg-tm +electronic work or group of works on different terms than are set +forth in this agreement, you must obtain permission in writing from +both the Project Gutenberg Literary Archive Foundation and Michael +Hart, the owner of the Project Gutenberg-tm trademark. Contact the +Foundation as set forth in Section 3 below. +1.F. +1.F.1. Project Gutenberg volunteers and employees expend considerable +effort to identify, do copyright research on, transcribe and proofread +public domain works in creating the Project Gutenberg-tm +collection. Despite these efforts, Project Gutenberg-tm electronic +works, and the medium on which they may be stored, may contain +"Defects," such as, but not limited to, incomplete, inaccurate or +corrupt data, transcription errors, a copyright or other intellectual +property infringement, a defective or damaged disk or other medium, a +computer virus, or computer codes that damage or cannot be read by +your equipment. +1.F.2. LIMITED WARRANTY, DISCLAIMER OF DAMAGES - Except for the "Right +of Replacement or Refund" described in paragraph 1.F.3, the Project +Gutenberg Literary Archive Foundation, the owner of the Project +Gutenberg-tm trademark, and any other party distributing a Project +Gutenberg-tm electronic work under this agreement, disclaim all +liability to you for damages, costs and expenses, including legal +fees. YOU AGREE THAT YOU HAVE NO REMEDIES FOR NEGLIGENCE, STRICT +LIABILITY, BREACH OF WARRANTY OR BREACH OF CONTRACT EXCEPT THOSE +PROVIDED IN PARAGRAPH 1.F.3. YOU AGREE THAT THE FOUNDATION, THE +TRADEMARK OWNER, AND ANY DISTRIBUTOR UNDER THIS AGREEMENT WILL NOT BE +LIABLE TO YOU FOR ACTUAL, DIRECT, INDIRECT, CONSEQUENTIAL, PUNITIVE OR +INCIDENTAL DAMAGES EVEN IF YOU GIVE NOTICE OF THE POSSIBILITY OF SUCH +DAMAGE. +1.F.3. LIMITED RIGHT OF REPLACEMENT OR REFUND - If you discover a +defect in this electronic work within 90 days of receiving it, you can +receive a refund of the money (if any) you paid for it by sending a +written explanation to the person you received the work from. If you +received the work on a physical medium, you must return the medium with +your written explanation. The person or entity that provided you with +the defective work may elect to provide a replacement copy in lieu of a +refund. If you received the work electronically, the person or entity +providing it to you may choose to give you a second opportunity to +receive the work electronically in lieu of a refund. If the second copy +is also defective, you may demand a refund in writing without further +opportunities to fix the problem. +1.F.4. Except for the limited right of replacement or refund set forth +in paragraph 1.F.3, this work is provided to you 'AS-IS', WITH NO OTHER +WARRANTIES OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +WARRANTIES OF MERCHANTABILITY OR FITNESS FOR ANY PURPOSE. +1.F.5. Some states do not allow disclaimers of certain implied +warranties or the exclusion or limitation of certain types of damages. +If any disclaimer or limitation set forth in this agreement violates the +law of the state applicable to this agreement, the agreement shall be +interpreted to make the maximum disclaimer or limitation permitted by +the applicable state law. The invalidity or unenforceability of any +provision of this agreement shall not void the remaining provisions. +1.F.6. INDEMNITY - You agree to indemnify and hold the Foundation, the +trademark owner, any agent or employee of the Foundation, anyone +providing copies of Project Gutenberg-tm electronic works in accordance +with this agreement, and any volunteers associated with the production, +promotion and distribution of Project Gutenberg-tm electronic works, +harmless from all liability, costs and expenses, including legal fees, +that arise directly or indirectly from any of the following which you do +or cause to occur: (a) distribution of this or any Project Gutenberg-tm +work, (b) alteration, modification, or additions or deletions to any +Project Gutenberg-tm work, and (c) any Defect you cause. +Section 2. Information about the Mission of Project Gutenberg-tm +Project Gutenberg-tm is synonymous with the free distribution of +electronic works in formats readable by the widest variety of computers +including obsolete, old, middle-aged and new computers. It exists +because of the efforts of hundreds of volunteers and donations from +people in all walks of life. +Volunteers and financial support to provide volunteers with the +assistance they need are critical to reaching Project Gutenberg-tm's +goals and ensuring that the Project Gutenberg-tm collection will +remain freely available for generations to come. In 2001, the Project +Gutenberg Literary Archive Foundation was created to provide a secure +and permanent future for Project Gutenberg-tm and future generations. +To learn more about the Project Gutenberg Literary Archive Foundation +and how your efforts and donations can help, see Sections 3 and 4 +and the Foundation information page at www.gutenberg.org +Section 3. Information about the Project Gutenberg Literary Archive +Foundation +The Project Gutenberg Literary Archive Foundation is a non profit +501(c)(3) educational corporation organized under the laws of the +state of Mississippi and granted tax exempt status by the Internal +Revenue Service. The Foundation's EIN or federal tax identification +number is 64-6221541. Contributions to the Project Gutenberg +Literary Archive Foundation are tax deductible to the full extent +permitted by U.S. federal laws and your state's laws. +The Foundation's principal office is located at 4557 Melan Dr. S. +Fairbanks, AK, 99712., but its volunteers and employees are scattered +throughout numerous locations. Its business office is located at 809 +North 1500 West, Salt Lake City, UT 84116, (801) 596-1887. Email +contact links and up to date contact information can be found at the +Foundation's web site and official page at www.gutenberg.org/contact +For additional contact information: + Dr. Gregory B. Newby + Chief Executive and Director + gbnewby@pglaf.org +Section 4. Information about Donations to the Project Gutenberg +Literary Archive Foundation +Project Gutenberg-tm depends upon and cannot survive without wide +spread public support and donations to carry out its mission of +increasing the number of public domain and licensed works that can be +freely distributed in machine readable form accessible by the widest +array of equipment including outdated equipment. Many small donations +($1 to $5,000) are particularly important to maintaining tax exempt +status with the IRS. +The Foundation is committed to complying with the laws regulating +charities and charitable donations in all 50 states of the United +States. Compliance requirements are not uniform and it takes a +considerable effort, much paperwork and many fees to meet and keep up +with these requirements. We do not solicit donations in locations +where we have not received written confirmation of compliance. To +SEND DONATIONS or determine the status of compliance for any +particular state visit www.gutenberg.org/donate +While we cannot and do not solicit contributions from states where we +have not met the solicitation requirements, we know of no prohibition +against accepting unsolicited donations from donors in such states who +approach us with offers to donate. +International donations are gratefully accepted, but we cannot make +any statements concerning tax treatment of donations received from +outside the United States. U.S. laws alone swamp our small staff. +Please check the Project Gutenberg Web pages for current donation +methods and addresses. Donations are accepted in a number of other +ways including checks, online payments and credit card donations. +To donate, please visit: www.gutenberg.org/donate +Section 5. General Information About Project Gutenberg-tm electronic +works. +Professor Michael S. Hart was the originator of the Project Gutenberg-tm +concept of a library of electronic works that could be freely shared +with anyone. For forty years, he produced and distributed Project +Gutenberg-tm eBooks with only a loose network of volunteer support. +Project Gutenberg-tm eBooks are often created from several printed +editions, all of which are confirmed as Public Domain in the U.S. +unless a copyright notice is included. Thus, we do not necessarily +keep eBooks in compliance with any particular paper edition. +Most people start at our Web site which has the main PG search facility: + www.gutenberg.org +This Web site includes information about Project Gutenberg-tm, +including how to make donations to the Project Gutenberg Literary +Archive Foundation, how to help produce our new eBooks, and how to +subscribe to our email newsletter to hear about new eBooks. \ No newline at end of file diff --git a/tests/ut/data/mindrecord/testImageNetDataWhole/labels_map.txt b/tests/ut/data/mindrecord/testImageNetDataWhole/labels_map.txt index 73e9759b89..82082e1c74 100644 --- a/tests/ut/data/mindrecord/testImageNetDataWhole/labels_map.txt +++ b/tests/ut/data/mindrecord/testImageNetDataWhole/labels_map.txt @@ -1,4 +1,4 @@ -n00000005 0 data_line -n00000006 1 small_iron_box -n00000007 2 plastic_toothpicks -n00000002 3 orange +n00000005 0 +n00000006 1 +n00000007 2 +n00000002 3 diff --git a/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord b/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord new file mode 100644 index 0000000000..da4f853e2d Binary files /dev/null and b/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord differ diff --git a/tests/ut/data/profiler_data/JOB1/Framework.host.vm.graph_desc_info.0.slice_0 b/tests/ut/data/profiler_data/JOB1/Framework.host.vm.graph_desc_info.0.slice_0 new file mode 100644 index 0000000000..9b3e7b322c --- /dev/null +++ b/tests/ut/data/profiler_data/JOB1/Framework.host.vm.graph_desc_info.0.slice_0 @@ -0,0 +1,4 @@ +op_name:Default/Cast-op6 op_type:Cast input_id:0 input_format:DefaultFormat input_data_type:40 input_shape:"32,3,224,224" output_id:0 output_format:DefaultFormat output_data_type:39 output_shape:"32,3,224,224" +op_name:Default/TransData-op7 op_type:TransData input_id:0 input_format:DefaultFormat input_data_type:39 input_shape:"32,3,224,224" output_id:0 output_format:NC1HWC0 output_data_type:39 output_shape:"32,1,224,224,16" +op_name:Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 op_type:Cast input_id:0 input_format:FracZ input_data_type:40 input_shape:"49,4,16,16" output_id:0 output_format:FracZ output_data_type:39 output_shape:"49,4,16,16" +op_name:Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 op_type:Cast input_id:0 input_format:FracZ input_data_type:40 input_shape:"4,4,16,16" output_id:0 output_format:FracZ output_data_type:39 output_shape:"4,4,16,16" diff --git a/tests/ut/data/profiler_data/JOB1/Framework.host.vm.graph_desc_info.0.slice_0.done b/tests/ut/data/profiler_data/JOB1/Framework.host.vm.graph_desc_info.0.slice_0.done new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/ut/data/profiler_data/JOB1/Framework.host.vm.task_desc_info.0.slice_0 b/tests/ut/data/profiler_data/JOB1/Framework.host.vm.task_desc_info.0.slice_0 new file mode 100644 index 0000000000..e786b35e22 --- /dev/null +++ b/tests/ut/data/profiler_data/JOB1/Framework.host.vm.task_desc_info.0.slice_0 @@ -0,0 +1,4 @@ +Default/Cast-op6 32 51517 0 +Default/TransData-op7 32 51518 0 +Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 32 51519 0 +Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 4 51522 0 diff --git a/tests/ut/data/profiler_data/JOB1/pipeline_profiling_0.json b/tests/ut/data/profiler_data/JOB1/pipeline_profiling_0.json new file mode 100644 index 0000000000..2f5fdca724 --- /dev/null +++ b/tests/ut/data/profiler_data/JOB1/pipeline_profiling_0.json @@ -0,0 +1,55 @@ +{ + "sampling_interval": 10, + "op_info": [ + { + "op_id": 4, + "op_type": "TFReader", + "num_workers": 4, + "metrics": null, + "children": [3] + }, + { + "op_id": 3, + "op_type": "TFReader", + "num_workers": 4, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": null + }, + { + "op_id": 2, + "op_type": "TFReader", + "num_workers": 4, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": null + }, + { + "op_id": 1, + "op_type": "Shuffle", + "num_workers": 1, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": [2, 4] + }, + { + "op_id": 0, + "op_type": "Batch", + "num_workers": 4, + "metrics": null, + "children": [1] + } + ] +} \ No newline at end of file diff --git a/tests/ut/data/profiler_data/JOB1/training_trace.46.dev.profiler_default_tag.0.slice_0 b/tests/ut/data/profiler_data/JOB1/training_trace.46.dev.profiler_default_tag.0.slice_0 new file mode 100644 index 0000000000..2fd8ba0bc9 Binary files /dev/null and b/tests/ut/data/profiler_data/JOB1/training_trace.46.dev.profiler_default_tag.0.slice_0 differ diff --git a/tests/ut/data/profiler_data/JOB2/pipeline_profiling_0.json b/tests/ut/data/profiler_data/JOB2/pipeline_profiling_0.json new file mode 100644 index 0000000000..1357a35c83 --- /dev/null +++ b/tests/ut/data/profiler_data/JOB2/pipeline_profiling_0.json @@ -0,0 +1,48 @@ +{ + "sampling_interval": 10, + "op_info": [ + { + "op_id": 3, + "op_type": "TFReader", + "num_workers": 4, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": null + }, + { + "op_id": 2, + "op_type": "TFReader", + "num_workers": 4, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": null + }, + { + "op_id": 1, + "op_type": "Shuffle", + "num_workers": 1, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": [2, 3] + }, + { + "op_id": 0, + "op_type": "Batch", + "num_workers": 4, + "metrics": null, + "children": [1] + } + ] +} \ No newline at end of file diff --git a/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.graph_desc_info.0.slice_0 b/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.graph_desc_info.0.slice_0 new file mode 100644 index 0000000000..9b3e7b322c --- /dev/null +++ b/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.graph_desc_info.0.slice_0 @@ -0,0 +1,4 @@ +op_name:Default/Cast-op6 op_type:Cast input_id:0 input_format:DefaultFormat input_data_type:40 input_shape:"32,3,224,224" output_id:0 output_format:DefaultFormat output_data_type:39 output_shape:"32,3,224,224" +op_name:Default/TransData-op7 op_type:TransData input_id:0 input_format:DefaultFormat input_data_type:39 input_shape:"32,3,224,224" output_id:0 output_format:NC1HWC0 output_data_type:39 output_shape:"32,1,224,224,16" +op_name:Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 op_type:Cast input_id:0 input_format:FracZ input_data_type:40 input_shape:"49,4,16,16" output_id:0 output_format:FracZ output_data_type:39 output_shape:"49,4,16,16" +op_name:Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 op_type:Cast input_id:0 input_format:FracZ input_data_type:40 input_shape:"4,4,16,16" output_id:0 output_format:FracZ output_data_type:39 output_shape:"4,4,16,16" diff --git a/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.point.0.slice_0 b/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.point.0.slice_0 new file mode 100644 index 0000000000..01bcf6f3ca --- /dev/null +++ b/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.point.0.slice_0 @@ -0,0 +1,2 @@ +1 Default/Cast-op6 +2 Default/TransData-op7 diff --git a/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.task_desc_info.0.slice_0 b/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.task_desc_info.0.slice_0 new file mode 100644 index 0000000000..e49673789f --- /dev/null +++ b/tests/ut/data/profiler_data/JOB4/data/Framework.host.vm.task_desc_info.0.slice_0 @@ -0,0 +1,4 @@ +Default/Cast-op6 32 1 0 +Default/TransData-op7 32 2 0 +Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 32 3 0 +Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 4 4 0 diff --git a/tests/ut/data/profiler_data/JOB_AICPU/data/DATA_PREPROCESS.dev.AICPU.0.slice_0 b/tests/ut/data/profiler_data/JOB_AICPU/data/DATA_PREPROCESS.dev.AICPU.0.slice_0 new file mode 100644 index 0000000000..30fe89d6e9 Binary files /dev/null and b/tests/ut/data/profiler_data/JOB_AICPU/data/DATA_PREPROCESS.dev.AICPU.0.slice_0 differ diff --git a/tests/ut/data/profiler_data/JOB_AICPU/expect/output_data_preprocess_aicpu_0.txt b/tests/ut/data/profiler_data/JOB_AICPU/expect/output_data_preprocess_aicpu_0.txt new file mode 100644 index 0000000000..fe83966d02 --- /dev/null +++ b/tests/ut/data/profiler_data/JOB_AICPU/expect/output_data_preprocess_aicpu_0.txt @@ -0,0 +1,5 @@ +serial_number node_type_name total_time(ms) dispatch_time(ms) run_start run_end +1 InitData 1.567 0.1 2298200409 2298200538 +2 GetNext 0.989 0.087 2302769932 2302769980 +3 TruncatedNormal 1.566 0.105 4098200409 4098200538 +AI CPU Total Time(ms): 4.122000 diff --git a/tests/ut/data/profiler_data/container/0/data/Framework.host.vm.graph_desc_info.0.JOB2.slice_0 b/tests/ut/data/profiler_data/container/0/data/Framework.host.vm.graph_desc_info.0.JOB2.slice_0 new file mode 100644 index 0000000000..9b3e7b322c --- /dev/null +++ b/tests/ut/data/profiler_data/container/0/data/Framework.host.vm.graph_desc_info.0.JOB2.slice_0 @@ -0,0 +1,4 @@ +op_name:Default/Cast-op6 op_type:Cast input_id:0 input_format:DefaultFormat input_data_type:40 input_shape:"32,3,224,224" output_id:0 output_format:DefaultFormat output_data_type:39 output_shape:"32,3,224,224" +op_name:Default/TransData-op7 op_type:TransData input_id:0 input_format:DefaultFormat input_data_type:39 input_shape:"32,3,224,224" output_id:0 output_format:NC1HWC0 output_data_type:39 output_shape:"32,1,224,224,16" +op_name:Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 op_type:Cast input_id:0 input_format:FracZ input_data_type:40 input_shape:"49,4,16,16" output_id:0 output_format:FracZ output_data_type:39 output_shape:"49,4,16,16" +op_name:Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 op_type:Cast input_id:0 input_format:FracZ input_data_type:40 input_shape:"4,4,16,16" output_id:0 output_format:FracZ output_data_type:39 output_shape:"4,4,16,16" diff --git a/tests/ut/data/profiler_data/container/0/data/Framework.host.vm.task_desc_info.0.JOB2.slice_0 b/tests/ut/data/profiler_data/container/0/data/Framework.host.vm.task_desc_info.0.JOB2.slice_0 new file mode 100644 index 0000000000..e786b35e22 --- /dev/null +++ b/tests/ut/data/profiler_data/container/0/data/Framework.host.vm.task_desc_info.0.JOB2.slice_0 @@ -0,0 +1,4 @@ +Default/Cast-op6 32 51517 0 +Default/TransData-op7 32 51518 0 +Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5 32 51519 0 +Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28 4 51522 0 diff --git a/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv new file mode 100755 index 0000000000..ec117c5346 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv @@ -0,0 +1,11 @@ +full_op_time,execution_time +Default/AtomicAddrClean-op104,0.00133 +Default/AtomicAddrClean-op105,0.000987 +Default/AtomicAddrClean-op106,0.001129 +Default/Cast-op10,0.00466 +Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op12,0.002366 +Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Cast-op53,0.004879 +Default/TransData-op11,0.006366 +Gradients/Default/network-WithLossCell/_backbone-LeNet5/gradReshape/TransData-op44,0.006782 +Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op13,0.05651 +Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/MatMul-op9,0.370864 diff --git a/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv new file mode 100755 index 0000000000..56bf368a6c --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv @@ -0,0 +1,6 @@ +op_type,execution_time,execution_frequency,percent +AtomicAddrClean,0.007283,6,0.49 +Cast,0.053395,13,3.63 +TransData,0.121800,5,8.23 +Conv2D,0.063656,2,4.33 +MatMul,1.085982,9,73.80 diff --git a/tests/ut/data/profiler_data/profiler/framework_raw_0.csv b/tests/ut/data/profiler_data/profiler/framework_raw_0.csv new file mode 100755 index 0000000000..762bc69346 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/framework_raw_0.csv @@ -0,0 +1,5 @@ +task_id,stream_id,block_dim,full_op_name,op_name,op_type,subgraph,op_info +51517,0,32,Default/Cast-op6,Cast-op6,Cast,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,3,224,224""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,3,224,224""}}" +51518,0,32,Default/TransData-op7,TransData-op7,TransData,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,3,224,224""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,224,224,16""}}" +51519,0,32,Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5,Cast-op5,Cast,Default,"{""input_0"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""49,4,16,16""}, ""output_0"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""49,4,16,16""}}" +51522,0,4,Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28,Cast-op28,Cast,Default,"{""input_0"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""4,4,16,16""}, ""output_0"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""4,4,16,16""}}" diff --git a/tests/ut/data/profiler_data/profiler/framework_raw_1.csv b/tests/ut/data/profiler_data/profiler/framework_raw_1.csv new file mode 100755 index 0000000000..5b60ecef49 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/framework_raw_1.csv @@ -0,0 +1,11 @@ +task_id,stream_id,block_dim,full_op_name,op_name,op_type,subgraph,op_info +30290,0,1,Default/AtomicAddrClean-op104,AtomicAddrClean-op104,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": """"}}" +30295,0,1,Default/AtomicAddrClean-op105,AtomicAddrClean-op105,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""10""}}" +30300,0,1,Default/AtomicAddrClean-op106,AtomicAddrClean-op106,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84""}}" +30268,0,32,Default/Cast-op10,Cast-op10,Cast,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,1,32,32""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}}" +30271,0,9,Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op12,Cast-op12,Cast,Default,"{""input_0"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""25,1,16,16""}, ""output_0"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""25,1,16,16""}}" +30320,0,32,Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Cast-op53,Cast-op53,Cast,Gradients,"{""input_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,1,28,28,16""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,28,28,16""}}" +30269,0,32,Default/TransData-op11,TransData-op11,TransData,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}}" +30308,0,32,Gradients/Default/network-WithLossCell/_backbone-LeNet5/gradReshape/TransData-op44,TransData-op44,TransData,Gradients,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,16,5,5""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,5,5,16""}}" +30272,0,32,Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op13,Conv2D-op13,Conv2D,Default,"{""input_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32,16""}, ""input_1"": {""format"": ""FracZ"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""25,1,16,16""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,28,28,16""}}" +30286,0,1,Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/MatMul-op9,MatMul-op9,MatMul,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,120""}, ""input_1"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84,120""}, ""input_2"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,84""}}" diff --git a/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv b/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv new file mode 100644 index 0000000000..57f8ddaf68 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv @@ -0,0 +1,5 @@ +op_id,op_type,num_workers,output_queue_size,output_queue_average_size,output_queue_length,output_queue_usage_rate,sample_interval,parent_id,children_id +0,Batch,4,,,,,10,,[1] +1,Shuffle,1,"[10, 20, 30]",20.0,64,0.3125,10,0,"[2, 3]" +2,TFReader,4,"[10, 20, 30]",20.0,64,0.3125,10,1, +3,TFReader,4,"[10, 20, 30]",20.0,64,0.3125,10,1, diff --git a/tests/ut/data/profiler_data/profiler/pipeline_profiling_1.json b/tests/ut/data/profiler_data/profiler/pipeline_profiling_1.json new file mode 100644 index 0000000000..2f5fdca724 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/pipeline_profiling_1.json @@ -0,0 +1,55 @@ +{ + "sampling_interval": 10, + "op_info": [ + { + "op_id": 4, + "op_type": "TFReader", + "num_workers": 4, + "metrics": null, + "children": [3] + }, + { + "op_id": 3, + "op_type": "TFReader", + "num_workers": 4, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": null + }, + { + "op_id": 2, + "op_type": "TFReader", + "num_workers": 4, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": null + }, + { + "op_id": 1, + "op_type": "Shuffle", + "num_workers": 1, + "metrics": { + "output_queue": { + "size": [10, 20, 30], + "length": 64 + } + }, + "children": [2, 4] + }, + { + "op_id": 0, + "op_type": "Batch", + "num_workers": 4, + "metrics": null, + "children": [1] + } + ] +} \ No newline at end of file diff --git a/tests/ut/data/profiler_data/profiler/step_trace_raw_0_detail_time.csv b/tests/ut/data/profiler_data/profiler/step_trace_raw_0_detail_time.csv new file mode 100755 index 0000000000..0e923ccad8 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/step_trace_raw_0_detail_time.csv @@ -0,0 +1,22 @@ +step_num,start_point,end_point,total,fp_point,bp_point,iteration_interval,fp_and_bp,tail,stream_10_parallel_0_start_point,stream_10_parallel_0_end_point,stream_10_parallel_0,stream_10_parallel_1_start_point,stream_10_parallel_1_end_point,stream_10_parallel_1,stream_10_parallel_2_start_point,stream_10_parallel_2_end_point,stream_10_parallel_2,stream_11_parallel_0_start_point,stream_11_parallel_0_end_point,stream_11_parallel_0 +1,45000030081,45004033128,4003047,45000030081,45001733025,0,1702944,2300103,45000042679,45000060275,17596,45001048152,45001346254,298102,45002247411,45002448354,200943,45000049687,45000075987,26300 +2,45004033128,45017085658,13052530,45013070937,45014785314,9037809,1714377,2300344,45013085379,45013105429,20050,45014087119,45014385136,298017,45015297166,45015504449,207283,45013084925,45013118334,33409 +3,45017085658,45030119392,13033734,45026116231,45027818443,9030573,1702212,2300949,45026131909,45026150554,18645,45027134392,45027430418,296026,45028337093,45028537767,200674,45026129217,45026160937,31720 +4,45030119392,45043158607,13039215,45039152348,45040856975,9032956,1704627,2301632,45039169890,45039188966,19076,45040169338,45040466770,297432,45041374122,45041567754,193632,45039171681,45039193865,22184 +5,45043158607,45056198128,13039521,45052190932,45053898028,9032325,1707096,2300100,45052207675,45052222642,14967,45053204442,45053505540,301098,45054413207,45054616536,203329,45052201931,45052237599,35668 +6,45056198128,45069239564,13041436,45065233106,45066939463,9034978,1706357,2300101,45065245482,45065272534,27052,45066248423,45066546419,297996,45067455113,45067659145,204032,45065245817,45065279896,34079 +7,45069239564,45082281383,13041819,45078274997,45079980193,9035433,1705196,2301190,45078293910,45078312935,19025,45079287754,45079593841,306087,45080492957,45080691395,198438,45078292067,45078322277,30210 +8,45082281383,45095336378,13054995,45091321488,45093036084,9040105,1714596,2300294,45091338628,45091359138,20510,45092338469,45092638994,300525,45093554195,45093747470,193275,45091341356,45091369667,28311 +9,45095336378,45108372225,13035847,45104363079,45106071009,9026701,1707930,2301216,45104374524,45104400088,25564,45105378751,45105683029,304278,45106587481,45106785336,197855,45104382131,45104410852,28721 +10,45108372225,45121412413,13040188,45117401873,45119111301,9029648,1709428,2301112,45117417721,45117439668,21947,45118413083,45118718050,304967,45119629347,45119829996,200649,45117421502,45117446718,25216 +11,45121412413,45134477662,13065249,45130459598,45132175723,9047185,1716125,2301939,45130478168,45130498936,20768,45131477957,45131775220,297263,45132691645,45132893707,202062,45130470285,45130501652,31367 +12,45134477662,45147533298,13055636,45143521860,45145232553,9044198,1710693,2300745,45143533787,45143557293,23506,45144533554,45144841545,307991,45145744997,45145952255,207258,45143537383,45143563466,26083 +13,45147533298,45160588134,13054836,45156570201,45158286694,9036903,1716493,2301440,45156581069,45156609506,28437,45157581617,45157880841,299224,45158806166,45158999875,193709,45156589050,45156615664,26614 +14,45160588134,45173640064,13051930,45169625906,45171339426,9037772,1713520,2300638,45169637432,45169661754,24322,45170639482,45170940949,301467,45171853721,45172056606,202885,45169644605,45169673410,28805 +15,45173640064,45186671634,13031570,45182666696,45184371430,9026632,1704734,2300204,45182678355,45182698471,20116,45183679568,45183981082,301514,45184887156,45185083035,195879,45182680062,45182708455,28393 +16,45186671634,45199720448,13048814,45195714716,45197420410,9043082,1705694,2300038,45195728993,45195754646,25653,45196732493,45197028048,295555,45197934921,45198139237,204316,45195733069,45195764102,31033 +17,45199720448,45212762605,13042157,45208758416,45210460864,9037968,1702448,2301741,45208771010,45208790367,19357,45209773548,45210074988,301440,45210978277,45211173577,195300,45208773143,45208803280,30137 +18,45212762605,45225814601,13051996,45221801814,45223514580,9039209,1712766,2300021,45221815911,45221839644,23733,45222819211,45223114544,295333,45224031469,45224234043,202574,45221812106,45221849103,36997 +19,45225814601,45238848430,13033829,45234842015,45236548356,9027414,1706341,2300074,45234855444,45234876469,21025,45235853358,45236160825,307467,45237063061,45237260964,197903,45234857141,45234882976,25835 +20,45238848430,45251899738,13051308,45247879385,45249598280,9030955,1718895,2301458,45247896725,45247917316,20591,45248896361,45249193681,297320,45250117916,45250315651,197735,45247894228,45247926723,32495 +-,45121436513,45134482124,13045611,45130471874,45132181322,9035360,1709449,2300802,45130486422,45130508229,21808,45131486785,45131787364,300579,45132697369,45132897305,199936,45130487458,45130517315,29857 diff --git a/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv b/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv new file mode 100755 index 0000000000..9e97e499b2 --- /dev/null +++ b/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv @@ -0,0 +1,42 @@ +step_num,start_point,end_point,total,fp_point,bp_point,iteration_interval,fp_and_bp,tail,stream_10_parallel_0_start_point,stream_10_parallel_0_end_point,stream_10_parallel_0,stream_10_parallel_1_start_point,stream_10_parallel_1_end_point,stream_10_parallel_1,stream_10_parallel_2_start_point,stream_10_parallel_2_end_point,stream_10_parallel_2,stream_11_parallel_0_start_point,stream_11_parallel_0_end_point,stream_11_parallel_0 +1,45000025226,45004034753,4009527,45000025226,45001734362,0,1709136,2300391,45000044023,45000060886,16863,45001043581,45001343373,299792,45002254048,45002452830,198782,45000043807,45000065736,21929 +2,45004034753,45017091420,13056667,45013073790,45014789509,9039037,1715719,2301911,45013085205,45013104210,19005,45014086339,45014393261,306922,45015299546,45015501808,202262,45013085040,45013119810,34770 +3,45017091420,45030144372,13052952,45026123867,45027843651,9032447,1719784,2300721,45026138546,45026154524,15978,45027135742,45027437486,301744,45028363120,45028560901,197781,45026136046,45026171363,35317 +4,45030144372,45043184486,13040114,45039173149,45040883087,9028777,1709938,2301399,45039190927,45039209948,19021,45040185915,45040484897,298982,45041399754,45041594775,195021,45039192768,45039221423,28655 +5,45043184486,45056241064,13056578,45052223555,45053940709,9039069,1717154,2300355,45052241736,45052262186,20450,45053239605,45053540866,301261,45054452604,45054654505,201901,45052233932,45052265774,31842 +6,45056241064,45069291346,13050282,45065278144,45066991121,9037080,1712977,2300225,45065293660,45065316136,22476,45066289480,45066589910,300430,45067511002,45067701731,190729,45065293679,45065321296,27617 +7,45069291346,45082344927,13053581,45078335376,45080043268,9044030,1707892,2301659,45078353164,45078365382,12218,45079354748,45079648384,293636,45080557453,45080760374,202921,45078353030,45078384530,31500 +8,45082344927,45095382554,13037627,45091368697,45093080797,9023770,1712100,2301757,45091381244,45091405208,23964,45092382630,45092684285,301655,45093590961,45093796698,205737,45091381199,45091413840,32641 +9,45095382554,45108433947,13051393,45104419947,45106132133,9037393,1712186,2301814,45104432587,45104457476,24889,45105431458,45105735476,304018,45106651213,45106845305,194092,45104435207,45104466677,31470 +10,45108433947,45121486591,13052644,45117469353,45119185969,9035406,1716616,2300622,45117483627,45117504869,21242,45118483411,45118788540,305129,45119696660,45119898575,201915,45117485587,45117510985,25398 +11,45121486591,45134546571,13059980,45130528618,45132244809,9042027,1716191,2301762,45130539730,45130561122,21392,45131538695,45131846715,308020,45132759789,45132960848,201059,45130545378,45130569412,24034 +12,45134546571,45147608222,13061651,45143597023,45145307273,9050452,1710250,2300949,45143615771,45143631460,15689,45144610592,45144910736,300144,45145818642,45146024326,205684,45143613528,45143640223,26695 +13,45147608222,45160663790,13055568,45156648696,45158362923,9040474,1714227,2300867,45156663193,45156685466,22273,45157661576,45157963074,301498,45158881212,45159074431,193219,45156667038,45156694912,27874 +14,45160663790,45173707626,13043836,45169694535,45171407246,9030745,1712711,2300380,45169710667,45169727936,17269,45170705802,45171013806,308004,45171924100,45172120273,196173,45169708524,45169739038,30514 +15,45173707626,45186754860,13047234,45182750254,45184454036,9042628,1703782,2300824,45182765445,45182789799,24354,45183761335,45184065169,303834,45184973312,45185170444,197132,45182769451,45182799598,30147 +16,45186754860,45199798718,13043858,45195792271,45197497908,9037411,1705637,2300810,45195804771,45195827915,23144,45196804016,45197108243,304227,45198013357,45198209858,196501,45195806656,45195841674,35018 +17,45199798718,45212854993,13056275,45208834355,45210553378,9035637,1719023,2301615,45208850179,45208865588,15409,45209851018,45210151436,300418,45211073169,45211271792,198623,45208847052,45208876998,29946 +18,45212854993,45225893712,13038719,45221888939,45223593704,9033946,1704765,2300008,45221901732,45221924983,23251,45222908795,45223203590,294795,45224105803,45224313354,207551,45221899792,45221938802,39010 +19,45225893712,45238941242,13047530,45234926295,45236640454,9032583,1714159,2300788,45234938628,45234957237,18609,45235942710,45236239983,297273,45237159532,45237356140,196608,45234938330,45234976170,37840 +20,45238941242,45251979177,13037935,45247977674,45249678116,9036432,1700442,2301061,45247990919,45248013476,22557,45248991451,45249294742,303291,45250195733,45250395760,200027,45247988950,45248024969,36019 +21,45251979177,45265018752,13039575,45261005416,45262718472,9026239,1713056,2300280,0,0,0,0,0,0,0,0,0,0,0,0 +22,45265018752,45278062782,13044030,45274047185,45275762095,9028433,1714910,2300687,0,0,0,0,0,0,0,0,0,0,0,0 +23,45278062782,45291105708,13042926,45287094000,45288805223,9031218,1711223,2300485,0,0,0,0,0,0,0,0,0,0,0,0 +24,45291105708,45304155918,13050210,45300150844,45301854040,9045136,1703196,2301878,0,0,0,0,0,0,0,0,0,0,0,0 +25,45304155918,45317206695,13050777,45313191948,45314905714,9036030,1713766,2300981,0,0,0,0,0,0,0,0,0,0,0,0 +26,45317206695,45330265105,13058410,45326256021,45327964581,9049326,1708560,2300524,0,0,0,0,0,0,0,0,0,0,0,0 +27,45330265105,45343324012,13058907,45339305124,45341023739,9040019,1718615,2300273,0,0,0,0,0,0,0,0,0,0,0,0 +28,45343324012,45356374571,13050559,45352366211,45354073401,9042199,1707190,2301170,0,0,0,0,0,0,0,0,0,0,0,0 +29,45356374571,45369429514,13054943,45365417827,45367128283,9043256,1710456,2301231,0,0,0,0,0,0,0,0,0,0,0,0 +30,45369429514,45382479199,13049685,45378476397,45380177297,9046883,1700900,2301902,0,0,0,0,0,0,0,0,0,0,0,0 +31,45382479199,45395530376,13051177,45391510137,45393229377,9030938,1719240,2300999,0,0,0,0,0,0,0,0,0,0,0,0 +32,45395530376,45408571765,13041389,45404559082,45406270720,9028706,1711638,2301045,0,0,0,0,0,0,0,0,0,0,0,0 +33,45408571765,45421635175,13063410,45417619223,45419334221,9047458,1714998,2300954,0,0,0,0,0,0,0,0,0,0,0,0 +34,45421635175,45434672219,13037044,45430669445,45432371312,9034270,1701867,2300907,0,0,0,0,0,0,0,0,0,0,0,0 +35,45434672219,45447714036,13041817,45443704548,45445413852,9032329,1709304,2300184,0,0,0,0,0,0,0,0,0,0,0,0 +36,45447714036,45460765153,13051117,45456753675,45458463701,9039639,1710026,2301452,0,0,0,0,0,0,0,0,0,0,0,0 +37,45460765153,45473829105,13063952,45469808281,45471527400,9043128,1719119,2301705,0,0,0,0,0,0,0,0,0,0,0,0 +38,45473829105,45486884190,13055085,45482867237,45484583534,9038132,1716297,2300656,0,0,0,0,0,0,0,0,0,0,0,0 +39,45486884190,45499928571,13044381,45495917628,45497627921,9033438,1710293,2300650,0,0,0,0,0,0,0,0,0,0,0,0 +40,45499928571,45512973815,13045244,45508968990,45510673699,9040419,1704709,2300116,0,0,0,0,0,0,0,0,0,0,0,0 +-,45251983006,45265032725,13049720,45261020353,45262731761,9037347,1711408,2300964,21986676455,21986686280,9825,21987163213,21987310272,147058,21987754537,21987851587,97050,21986676441,21986691731,15290 diff --git a/tests/ut/python/automl/case.py b/tests/ut/python/automl/case.py deleted file mode 100644 index 745376277c..0000000000 --- a/tests/ut/python/automl/case.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Test case.""" -import numpy as np - -import mindspore -import mindspore.nn as nn -from mindspore import Tensor, context - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 3, 3) - self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True) - self.layers = (self.conv1, self.conv2) - - def construct(self, x, index): - x = self.layers[index](x) - y = self.conv1(x) - return x + y - - -def test_case(): - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) - net = Net() - data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32) - idx = Tensor(1, mindspore.int32) - net(data, idx) diff --git a/tests/ut/python/automl/test_case.py b/tests/ut/python/automl/test_case.py new file mode 100644 index 0000000000..39bcebca02 --- /dev/null +++ b/tests/ut/python/automl/test_case.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test case.""" +import numpy as np + +import mindspore +import mindspore.nn as nn +from mindspore import Tensor, context + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 3, 3) + self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True) + self.layers = (self.conv1, self.conv2) + + def construct(self, x, index): + x = self.layers[index](x) + return 2 + x + + +def test_case(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + net = Net() + data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32) + idx = Tensor(1, mindspore.int32) + net(data, idx) diff --git a/tests/ut/python/dataset/test_2ops.py b/tests/ut/python/dataset/test_2ops.py index ef60a42e27..cf781d6dfd 100644 --- a/tests/ut/python/dataset/test_2ops.py +++ b/tests/ut/python/dataset/test_2ops.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from util import save_and_check - import mindspore.dataset as ds from mindspore import log as logger +from util import save_and_check_dict DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json" -COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", - "col_sint16", "col_sint32", "col_sint64"] GENERATE_GOLDEN = False @@ -33,9 +30,6 @@ def test_2ops_repeat_shuffle(): repeat_count = 2 buffer_size = 5 seed = 0 - parameters = {"params": {'repeat_count': repeat_count, - 'buffer_size': buffer_size, - 'seed': seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) @@ -44,7 +38,7 @@ def test_2ops_repeat_shuffle(): data1 = data1.shuffle(buffer_size=buffer_size) filename = "test_2ops_repeat_shuffle.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_2ops_shuffle_repeat(): @@ -56,10 +50,6 @@ def test_2ops_shuffle_repeat(): repeat_count = 2 buffer_size = 5 seed = 0 - parameters = {"params": {'repeat_count': repeat_count, - 'buffer_size': buffer_size, - 'reshuffle_each_iteration': False, - 'seed': seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) @@ -68,7 +58,7 @@ def test_2ops_shuffle_repeat(): data1 = data1.repeat(repeat_count) filename = "test_2ops_shuffle_repeat.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_2ops_repeat_batch(): @@ -79,8 +69,6 @@ def test_2ops_repeat_batch(): # define parameters repeat_count = 2 batch_size = 5 - parameters = {"params": {'repeat_count': repeat_count, - 'batch_size': batch_size}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) @@ -88,7 +76,7 @@ def test_2ops_repeat_batch(): data1 = data1.batch(batch_size, drop_remainder=True) filename = "test_2ops_repeat_batch.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_2ops_batch_repeat(): @@ -99,8 +87,6 @@ def test_2ops_batch_repeat(): # define parameters repeat_count = 2 batch_size = 5 - parameters = {"params": {'repeat_count': repeat_count, - 'batch_size': batch_size}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) @@ -108,7 +94,7 @@ def test_2ops_batch_repeat(): data1 = data1.repeat(repeat_count) filename = "test_2ops_batch_repeat.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_2ops_batch_shuffle(): @@ -120,9 +106,6 @@ def test_2ops_batch_shuffle(): buffer_size = 5 seed = 0 batch_size = 2 - parameters = {"params": {'buffer_size': buffer_size, - 'seed': seed, - 'batch_size': batch_size}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) @@ -131,7 +114,7 @@ def test_2ops_batch_shuffle(): data1 = data1.shuffle(buffer_size=buffer_size) filename = "test_2ops_batch_shuffle.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_2ops_shuffle_batch(): @@ -143,9 +126,6 @@ def test_2ops_shuffle_batch(): buffer_size = 5 seed = 0 batch_size = 2 - parameters = {"params": {'buffer_size': buffer_size, - 'seed': seed, - 'batch_size': batch_size}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) @@ -154,7 +134,7 @@ def test_2ops_shuffle_batch(): data1 = data1.batch(batch_size, drop_remainder=True) filename = "test_2ops_shuffle_batch.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_Tensor.py b/tests/ut/python/dataset/test_Tensor.py index 6235a567b4..f159f46a22 100644 --- a/tests/ut/python/dataset/test_Tensor.py +++ b/tests/ut/python/dataset/test_Tensor.py @@ -31,13 +31,13 @@ def test_basic(): arr[0] = 0 x = np.array([0, 2, 3, 4, 5]) - assert np.array_equal(x, arr) + np.testing.assert_array_equal(x, arr) assert n.type() == cde.DataType("int64") arr2 = n.as_array() arr[0] = 2 x = np.array([2, 2, 3, 4, 5]) - assert np.array_equal(x, arr2) + np.testing.assert_array_equal(x, arr2) assert n.type() == cde.DataType("int64") assert arr.__array_interface__['data'] == arr2.__array_interface__['data'] @@ -47,12 +47,12 @@ def test_strides(): n1 = cde.Tensor(x[:, 1]) arr = np.array(n1, copy=False) - assert np.array_equal(x[:, 1], arr) + np.testing.assert_array_equal(x[:, 1], arr) n2 = cde.Tensor(x.transpose()) arr = np.array(n2, copy=False) - assert np.array_equal(x.transpose(), arr) + np.testing.assert_array_equal(x.transpose(), arr) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_apply.py b/tests/ut/python/dataset/test_apply.py index 3e963780bb..e731ddcbc1 100644 --- a/tests/ut/python/dataset/test_apply.py +++ b/tests/ut/python/dataset/test_apply.py @@ -41,7 +41,7 @@ def test_apply_generator_case(): data2 = data2.batch(4) for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - assert np.array_equal(item1["data"], item2["data"]) + np.testing.assert_array_equal(item1["data"], item2["data"]) def test_apply_imagefolder_case(): @@ -64,7 +64,7 @@ def test_apply_imagefolder_case(): data2 = data2.repeat(2) for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - assert np.array_equal(item1["image"], item2["image"]) + np.testing.assert_array_equal(item1["image"], item2["image"]) def test_apply_flow_case_0(id_=0): diff --git a/tests/ut/python/dataset/test_autocontrast.py b/tests/ut/python/dataset/test_autocontrast.py index d212994e6e..6c3fc671d7 100644 --- a/tests/ut/python/dataset/test_autocontrast.py +++ b/tests/ut/python/dataset/test_autocontrast.py @@ -16,20 +16,22 @@ Testing AutoContrast op in DE """ import numpy as np - import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.py_transforms as F +import mindspore.dataset.transforms.vision.c_transforms as C from mindspore import log as logger -from util import visualize_list, diff_mse +from util import visualize_list, diff_mse, save_and_check_md5 DATA_DIR = "../data/dataset/testImageNetData/train/" +GENERATE_GOLDEN = False + -def test_auto_contrast(plot=False): +def test_auto_contrast_py(plot=False): """ Test AutoContrast """ - logger.info("Test AutoContrast") + logger.info("Test AutoContrast Python Op") # Original Images ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) @@ -56,7 +58,7 @@ def test_auto_contrast(plot=False): transforms_auto_contrast = F.ComposeOp([F.Decode(), F.Resize((224, 224)), - F.AutoContrast(), + F.AutoContrast(cutoff=10.0, ignore=[10, 20]), F.ToTensor()]) ds_auto_contrast = ds.map(input_columns="image", @@ -78,9 +80,260 @@ def test_auto_contrast(plot=False): mse[i] = diff_mse(images_auto_contrast[i], images_original[i]) logger.info("MSE= {}".format(str(np.mean(mse)))) + # Compare with expected md5 from images + filename = "autcontrast_01_result_py.npz" + save_and_check_md5(ds_auto_contrast, filename, generate_golden=GENERATE_GOLDEN) + if plot: visualize_list(images_original, images_auto_contrast) +def test_auto_contrast_c(plot=False): + """ + Test AutoContrast C Op + """ + logger.info("Test AutoContrast C Op") + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224))]) + python_op = F.AutoContrast(cutoff=10.0, ignore=[10, 20]) + c_op = C.AutoContrast(cutoff=10.0, ignore=[10, 20]) + transforms_op = F.ComposeOp([lambda img: F.ToPIL()(img.astype(np.uint8)), + python_op, + np.array])() + + ds_auto_contrast_py = ds.map(input_columns="image", + operations=transforms_op) + + ds_auto_contrast_py = ds_auto_contrast_py.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_py): + if idx == 0: + images_auto_contrast_py = image + else: + images_auto_contrast_py = np.append(images_auto_contrast_py, + image, + axis=0) + + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224))]) + + ds_auto_contrast_c = ds.map(input_columns="image", + operations=c_op) + + ds_auto_contrast_c = ds_auto_contrast_c.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_c): + if idx == 0: + images_auto_contrast_c = image + else: + images_auto_contrast_c = np.append(images_auto_contrast_c, + image, + axis=0) + + num_samples = images_auto_contrast_c.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_auto_contrast_c[i], images_auto_contrast_py[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + np.testing.assert_equal(np.mean(mse), 0.0) + + # Compare with expected md5 from images + filename = "autcontrast_01_result_c.npz" + save_and_check_md5(ds_auto_contrast_c, filename, generate_golden=GENERATE_GOLDEN) + + if plot: + visualize_list(images_auto_contrast_c, images_auto_contrast_py, visualize_mode=2) + + +def test_auto_contrast_one_channel_c(plot=False): + """ + Test AutoContrast C op with one channel + """ + logger.info("Test AutoContrast C Op With One Channel Images") + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224))]) + python_op = F.AutoContrast() + c_op = C.AutoContrast() + # not using F.ToTensor() since it converts to floats + transforms_op = F.ComposeOp([lambda img: (np.array(img)[:, :, 0]).astype(np.uint8), + F.ToPIL(), + python_op, + np.array])() + + ds_auto_contrast_py = ds.map(input_columns="image", + operations=transforms_op) + + ds_auto_contrast_py = ds_auto_contrast_py.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_py): + if idx == 0: + images_auto_contrast_py = image + else: + images_auto_contrast_py = np.append(images_auto_contrast_py, + image, + axis=0) + + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + + ds_auto_contrast_c = ds.map(input_columns="image", + operations=c_op) + + ds_auto_contrast_c = ds_auto_contrast_c.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_c): + if idx == 0: + images_auto_contrast_c = image + else: + images_auto_contrast_c = np.append(images_auto_contrast_c, + image, + axis=0) + + num_samples = images_auto_contrast_c.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_auto_contrast_c[i], images_auto_contrast_py[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + np.testing.assert_equal(np.mean(mse), 0.0) + + if plot: + visualize_list(images_auto_contrast_c, images_auto_contrast_py, visualize_mode=2) + + +def test_auto_contrast_invalid_ignore_param_c(): + """ + Test AutoContrast C Op with invalid ignore parameter + """ + logger.info("Test AutoContrast C Op with invalid ignore parameter") + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + # invalid ignore + ds = ds.map(input_columns="image", + operations=C.AutoContrast(ignore=255.5)) + except TypeError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Argument ignore with value 255.5 is not of type" in str(error) + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + # invalid ignore + ds = ds.map(input_columns="image", + operations=C.AutoContrast(ignore=(10, 100))) + except TypeError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Argument ignore with value (10,100) is not of type" in str(error) + + +def test_auto_contrast_invalid_cutoff_param_c(): + """ + Test AutoContrast C Op with invalid cutoff parameter + """ + logger.info("Test AutoContrast C Op with invalid cutoff parameter") + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + # invalid ignore + ds = ds.map(input_columns="image", + operations=C.AutoContrast(cutoff=-10.0)) + except ValueError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + # invalid ignore + ds = ds.map(input_columns="image", + operations=C.AutoContrast(cutoff=120.0)) + except ValueError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) + + +def test_auto_contrast_invalid_ignore_param_py(): + """ + Test AutoContrast python Op with invalid ignore parameter + """ + logger.info("Test AutoContrast python Op with invalid ignore parameter") + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[F.ComposeOp([F.Decode(), + F.Resize((224, 224)), + F.AutoContrast(ignore=255.5), + F.ToTensor()])]) + except TypeError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Argument ignore with value 255.5 is not of type" in str(error) + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[F.ComposeOp([F.Decode(), + F.Resize((224, 224)), + F.AutoContrast(ignore=(10, 100)), + F.ToTensor()])]) + except TypeError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Argument ignore with value (10,100) is not of type" in str(error) + + +def test_auto_contrast_invalid_cutoff_param_py(): + """ + Test AutoContrast python Op with invalid cutoff parameter + """ + logger.info("Test AutoContrast python Op with invalid cutoff parameter") + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[F.ComposeOp([F.Decode(), + F.Resize((224, 224)), + F.AutoContrast(cutoff=-10.0), + F.ToTensor()])]) + except ValueError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[F.ComposeOp([F.Decode(), + F.Resize((224, 224)), + F.AutoContrast(cutoff=120.0), + F.ToTensor()])]) + except ValueError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) + + if __name__ == "__main__": - test_auto_contrast(plot=True) + test_auto_contrast_py(plot=True) + test_auto_contrast_c(plot=True) + test_auto_contrast_one_channel_c(plot=True) + test_auto_contrast_invalid_ignore_param_c() + test_auto_contrast_invalid_ignore_param_py() + test_auto_contrast_invalid_cutoff_param_c() + test_auto_contrast_invalid_cutoff_param_py() diff --git a/tests/ut/python/dataset/test_batch.py b/tests/ut/python/dataset/test_batch.py index 9b9baeec33..1220d98344 100644 --- a/tests/ut/python/dataset/test_batch.py +++ b/tests/ut/python/dataset/test_batch.py @@ -14,7 +14,7 @@ # ============================================================================== import mindspore.dataset as ds from mindspore import log as logger -from util import save_and_check +from util import save_and_check_dict # Note: Number of rows in test.data dataset: 12 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] @@ -29,8 +29,6 @@ def test_batch_01(): # define parameters batch_size = 2 drop_remainder = True - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -38,7 +36,7 @@ def test_batch_01(): assert sum([1 for _ in data1]) == 6 filename = "batch_01_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_02(): @@ -49,8 +47,6 @@ def test_batch_02(): # define parameters batch_size = 5 drop_remainder = True - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -58,7 +54,7 @@ def test_batch_02(): assert sum([1 for _ in data1]) == 2 filename = "batch_02_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_03(): @@ -69,8 +65,6 @@ def test_batch_03(): # define parameters batch_size = 3 drop_remainder = False - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -78,7 +72,7 @@ def test_batch_03(): assert sum([1 for _ in data1]) == 4 filename = "batch_03_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_04(): @@ -89,8 +83,6 @@ def test_batch_04(): # define parameters batch_size = 7 drop_remainder = False - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -98,7 +90,7 @@ def test_batch_04(): assert sum([1 for _ in data1]) == 2 filename = "batch_04_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_05(): @@ -108,7 +100,6 @@ def test_batch_05(): logger.info("test_batch_05") # define parameters batch_size = 1 - parameters = {"params": {'batch_size': batch_size}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -116,7 +107,7 @@ def test_batch_05(): assert sum([1 for _ in data1]) == 12 filename = "batch_05_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_06(): @@ -127,8 +118,6 @@ def test_batch_06(): # define parameters batch_size = 12 drop_remainder = False - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -136,7 +125,7 @@ def test_batch_06(): assert sum([1 for _ in data1]) == 1 filename = "batch_06_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_07(): @@ -148,9 +137,6 @@ def test_batch_07(): batch_size = 4 drop_remainder = False num_parallel_workers = 2 - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder, - 'num_parallel_workers': num_parallel_workers}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -159,7 +145,7 @@ def test_batch_07(): assert sum([1 for _ in data1]) == 3 filename = "batch_07_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_08(): @@ -170,8 +156,6 @@ def test_batch_08(): # define parameters batch_size = 6 num_parallel_workers = 1 - parameters = {"params": {'batch_size': batch_size, - 'num_parallel_workers': num_parallel_workers}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -179,7 +163,7 @@ def test_batch_08(): assert sum([1 for _ in data1]) == 2 filename = "batch_08_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_09(): @@ -190,8 +174,6 @@ def test_batch_09(): # define parameters batch_size = 13 drop_remainder = False - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -199,7 +181,7 @@ def test_batch_09(): assert sum([1 for _ in data1]) == 1 filename = "batch_09_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_10(): @@ -210,8 +192,6 @@ def test_batch_10(): # define parameters batch_size = 99 drop_remainder = True - parameters = {"params": {'batch_size': batch_size, - 'drop_remainder': drop_remainder}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -219,7 +199,7 @@ def test_batch_10(): assert sum([1 for _ in data1]) == 0 filename = "batch_10_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_11(): @@ -229,7 +209,6 @@ def test_batch_11(): logger.info("test_batch_11") # define parameters batch_size = 1 - parameters = {"params": {'batch_size': batch_size}} # apply dataset operations # Use schema file with 1 row @@ -239,7 +218,7 @@ def test_batch_11(): assert sum([1 for _ in data1]) == 1 filename = "batch_11_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_12(): @@ -249,7 +228,6 @@ def test_batch_12(): logger.info("test_batch_12") # define parameters batch_size = True - parameters = {"params": {'batch_size': batch_size}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -257,7 +235,7 @@ def test_batch_12(): assert sum([1 for _ in data1]) == 12 filename = "batch_12_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_batch_exception_01(): diff --git a/tests/ut/python/dataset/test_bounding_box_augment.py b/tests/ut/python/dataset/test_bounding_box_augment.py index 8924af968c..90bfae7bb8 100644 --- a/tests/ut/python/dataset/test_bounding_box_augment.py +++ b/tests/ut/python/dataset/test_bounding_box_augment.py @@ -49,9 +49,9 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False): test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "bounding_box_augment_rotation_c_result.npz" @@ -88,9 +88,9 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False): test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "bounding_box_augment_crop_c_result.npz" @@ -126,10 +126,11 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False): test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) # Add column for "bbox" + filename = "bounding_box_augment_valid_ratio_c_result.npz" save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) @@ -193,20 +194,20 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False): test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1) # map to apply ops - # Add column for "annotation" - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + # Add column for "bbox" + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=lambda img, bbox: (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=lambda img, bbox: (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "bounding_box_augment_valid_edge_c_result.npz" save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) @@ -237,10 +238,10 @@ def test_bounding_box_augment_invalid_ratio_c(): # ratio range is from 0 - 1 test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) # Add column for "bbox" except ValueError as error: logger.info("Got an exception in DE: {}".format(str(error))) assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error) diff --git a/tests/ut/python/dataset/test_c_compose.py b/tests/ut/python/dataset/test_c_compose.py new file mode 100644 index 0000000000..906d787f21 --- /dev/null +++ b/tests/ut/python/dataset/test_c_compose.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as ops +import mindspore.dataset.transforms.py_transforms as py_ops + + +def test_compose(): + ds.config.set_seed(0) + + def test_config(arr, op_list): + try: + data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False) + data = data.map(input_columns=["col"], operations=ops.Compose(op_list)) + res = [] + for i in data.create_dict_iterator(): + res.append(i["col"].tolist()) + return res + except (TypeError, ValueError) as e: + return str(e) + + # test simple compose with only 1 op, this would generate a warning + assert test_config([[1, 0], [3, 4]], [ops.Fill(2)]) == [[2, 2], [2, 2]] + # test 1 column -> 2columns -> 1 -> 2 -> 1 + assert test_config([[1, 0]], [ops.Duplicate(), ops.Concatenate(), ops.Duplicate(), ops.Concatenate()]) == [ + [1, 0] * 4] + # test one python transform followed by a C transform. type after oneHot is float (mixed use-case) + assert test_config([1, 0], [py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)]) == [[[0, 1]], [[1, 0]]] + # test exceptions. compose, randomApply randomChoice use the same validator + assert "op_list[0] is not a c_transform op" in test_config([1, 0], [1, ops.TypeCast(mstype.int32)]) + # test empty op list + assert "op_list can not be empty." in test_config([1, 0], []) + + +if __name__ == "__main__": + test_compose() diff --git a/tests/ut/python/dataset/test_c_random_apply.py b/tests/ut/python/dataset/test_c_random_apply.py new file mode 100644 index 0000000000..8b4851aab5 --- /dev/null +++ b/tests/ut/python/dataset/test_c_random_apply.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as ops + + +def test_random_apply(): + ds.config.set_seed(0) + + def test_config(arr, op_list, prob=0.5): + try: + data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False) + data = data.map(input_columns=["col"], operations=ops.RandomApply(op_list, prob)) + res = [] + for i in data.create_dict_iterator(): + res.append(i["col"].tolist()) + return res + except (TypeError, ValueError) as e: + return str(e) + + res1 = test_config([[0, 1]], [ops.Duplicate(), ops.Concatenate()]) + assert res1 in [[[0, 1]], [[0, 1, 0, 1]]] + # test single nested compose + assert test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate(), ops.Slice([0, 1, 2])])]) == [ + [0, 1, 2]] + # test exception + assert "is not of type (" in test_config([1, 0], ops.TypeCast(mstype.int32)) + assert "Input prob is not within the required interval" in test_config([0, 1], [ops.Slice([0, 1])], 1.1) + assert "is not of type (" in test_config([1, 0], [ops.TypeCast(mstype.int32)], None) + assert "op_list with value None is not of type (" in test_config([1, 0], None) + + +if __name__ == "__main__": + test_random_apply() diff --git a/tests/ut/python/dataset/test_c_random_choice.py b/tests/ut/python/dataset/test_c_random_choice.py new file mode 100644 index 0000000000..3faedeb26e --- /dev/null +++ b/tests/ut/python/dataset/test_c_random_choice.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as ops + + +def test_random_choice(): + ds.config.set_seed(0) + + def test_config(arr, op_list): + try: + data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False) + data = data.map(input_columns=["col"], operations=ops.RandomChoice(op_list)) + res = [] + for i in data.create_dict_iterator(): + res.append(i["col"].tolist()) + return res + except (TypeError, ValueError) as e: + return str(e) + + # test whether a op would be randomly chosen. In order to prevent random failure, both results need to be checked + res1 = test_config([[0, 1, 2]], [ops.PadEnd([4], 0), ops.Slice([0, 2])]) + assert res1 in [[[0, 1, 2, 0]], [[0, 2]]] + + # test nested structure + res2 = test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate()]), + ops.Compose([ops.Slice([0, 1]), ops.OneHot(2)])]) + assert res2 in [[[[1, 0], [0, 1]]], [[0, 1, 2, 0, 1, 2]]] + # test random_choice where there is only 1 op + assert test_config([[4, 3], [2, 1]], [ops.Slice([0])]) == [[4], [2]] + + +if __name__ == "__main__": + test_random_choice() diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 0e42b422aa..154a4208a0 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -24,6 +24,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" GENERATE_GOLDEN = False + def test_cache_map_basic1(): """ Test mappable leaf with cache op right over the leaf @@ -104,9 +105,36 @@ def test_cache_map_basic3(): decode_op = c_vision.Decode() ds1 = ds1.repeat(4) ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + logger.info("ds1.dataset_size is ", ds1.get_dataset_size()) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + logger.info("get data from dataset") + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info('test_cache_basic3 Ended.\n') + +def test_cache_map_basic4(): + """ + Test different rows result in core dump + """ + logger.info("Test cache basic 4") + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.repeat(4) + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + logger.info("ds1.dataset_size is ", ds1.get_dataset_size()) + shape = ds1.output_shapes() + logger.info(shape) num_iter = 0 for _ in ds1.create_dict_iterator(): + logger.info("get data from dataset") num_iter += 1 logger.info("Number of data in ds1: {} ".format(num_iter)) @@ -150,8 +178,15 @@ def test_cache_map_failure1(): assert num_iter == 0 logger.info('test_cache_failure1 Ended.\n') + if __name__ == '__main__': test_cache_map_basic1() + logger.info("test_cache_map_basic1 success.") test_cache_map_basic2() + logger.info("test_cache_map_basic2 success.") test_cache_map_basic3() + logger.info("test_cache_map_basic3 success.") + test_cache_map_basic4() + logger.info("test_cache_map_basic3 success.") test_cache_map_failure1() + logger.info("test_cache_map_failure1 success.") diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 39e00c0621..4a00cc5488 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -376,6 +376,44 @@ def test_cache_nomap_allowed_share3(): logger.info("test_cache_nomap_allowed_share3 Ended.\n") +def test_cache_nomap_allowed_share4(): + """ + It is allowed to share the cache between the following two trees: + + Cache Cache + | | + Map(decode, num_parallel_workers=1) Map(decode, num_parallel_workers=2) + | | + TFReader TFReader + """ + + logger.info("Test cache nomap allowed share 4") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True) + decode_op = c_vision.Decode() + + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache, num_parallel_workers=1) + + ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache, num_parallel_workers=2) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 3 + + num_iter = 0 + for _ in ds2.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds2: {} ".format(num_iter)) + assert num_iter == 3 + + logger.info("test_cache_nomap_allowed_share4 Ended.\n") + + def test_cache_nomap_disallowed_share1(): """ It is not allowed to share the cache between the following two trees: @@ -426,4 +464,5 @@ if __name__ == '__main__': test_cache_nomap_allowed_share1() test_cache_nomap_allowed_share2() test_cache_nomap_allowed_share3() + test_cache_nomap_allowed_share4() test_cache_nomap_disallowed_share1() diff --git a/tests/ut/python/dataset/test_center_crop.py b/tests/ut/python/dataset/test_center_crop.py index 6dfa9fc7c3..03b8079e1e 100644 --- a/tests/ut/python/dataset/test_center_crop.py +++ b/tests/ut/python/dataset/test_center_crop.py @@ -138,6 +138,17 @@ def test_crop_grayscale(height=375, width=375): assert (c_image.ndim == 3 and c_image.shape[2] == 1) +def test_center_crop_errors(): + """ + Test that CenterCropOp errors with bad input + """ + try: + test_center_crop_op(16777216, 16777216) + except RuntimeError as e: + assert "Unexpected error. CenterCropOp padding size is too big, it's more than 3 times the original size." in \ + str(e) + + if __name__ == "__main__": test_center_crop_op(600, 600, plot=True) test_center_crop_op(300, 600) diff --git a/tests/ut/python/dataset/test_concatenate_op.py b/tests/ut/python/dataset/test_concatenate_op.py index f7a432e471..d60cff06c5 100644 --- a/tests/ut/python/dataset/test_concatenate_op.py +++ b/tests/ut/python/dataset/test_concatenate_op.py @@ -130,7 +130,7 @@ def test_concatenate_op_incorrect_dim(): def gen(): yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),) - prepend_tensor = np.array([3, 5], dtype=np.float) + prepend_tensor = np.array(["ss", "ss"], dtype='S') concatenate_op = data_trans.Concatenate(0, prepend_tensor) data = ds.GeneratorDataset(gen, column_names=["col"]) diff --git a/tests/ut/python/dataset/test_datasets_cifarop.py b/tests/ut/python/dataset/test_datasets_cifarop.py index d6d3029b53..2b66f32665 100644 --- a/tests/ut/python/dataset/test_datasets_cifarop.py +++ b/tests/ut/python/dataset/test_datasets_cifarop.py @@ -87,6 +87,13 @@ def test_cifar10_basic(): """ logger.info("Test Cifar10Dataset Op") + # case 0: test loading the whole dataset + data0 = ds.Cifar10Dataset(DATA_DIR_10) + num_iter0 = 0 + for _ in data0.create_dict_iterator(): + num_iter0 += 1 + assert num_iter0 == 10000 + # case 1: test num_samples data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100) num_iter1 = 0 diff --git a/tests/ut/python/dataset/test_datasets_clue.py b/tests/ut/python/dataset/test_datasets_clue.py index e1959acb42..0d8a60f5d1 100644 --- a/tests/ut/python/dataset/test_datasets_clue.py +++ b/tests/ut/python/dataset/test_datasets_clue.py @@ -356,9 +356,13 @@ def test_clue_to_device(): if __name__ == "__main__": test_clue() + test_clue_num_shards() + test_clue_num_samples() + test_textline_dataset_get_datasetsize() test_clue_afqmc() test_clue_cmnli() test_clue_csl() test_clue_iflytek() test_clue_tnews() test_clue_wsc() + test_clue_to_device() diff --git a/tests/ut/python/dataset/test_datasets_coco.py b/tests/ut/python/dataset/test_datasets_coco.py index f5bf7caa6c..fd7430ccd2 100644 --- a/tests/ut/python/dataset/test_datasets_coco.py +++ b/tests/ut/python/dataset/test_datasets_coco.py @@ -17,6 +17,7 @@ import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as vision DATA_DIR = "../data/dataset/testCOCO/train/" +DATA_DIR_2 = "../data/dataset/testCOCO/train" ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" KEYPOINT_FILE = "../data/dataset/testCOCO/annotations/key_point.json" PANOPTIC_FILE = "../data/dataset/testCOCO/annotations/panoptic.json" @@ -43,18 +44,18 @@ def test_coco_detection(): assert image_shape[3] == (642, 675, 3) assert image_shape[4] == (2268, 4032, 3) assert image_shape[5] == (2268, 4032, 3) - assert np.array_equal(np.array([[10., 10., 10., 10.], [70., 70., 70., 70.]]), bbox[0]) - assert np.array_equal(np.array([[20., 20., 20., 20.], [80., 80., 80.0, 80.]]), bbox[1]) - assert np.array_equal(np.array([[30.0, 30.0, 30.0, 30.]]), bbox[2]) - assert np.array_equal(np.array([[40., 40., 40., 40.]]), bbox[3]) - assert np.array_equal(np.array([[50., 50., 50., 50.]]), bbox[4]) - assert np.array_equal(np.array([[60., 60., 60., 60.]]), bbox[5]) - assert np.array_equal(np.array([[1], [7]]), category_id[0]) - assert np.array_equal(np.array([[2], [8]]), category_id[1]) - assert np.array_equal(np.array([[3]]), category_id[2]) - assert np.array_equal(np.array([[4]]), category_id[3]) - assert np.array_equal(np.array([[5]]), category_id[4]) - assert np.array_equal(np.array([[6]]), category_id[5]) + np.testing.assert_array_equal(np.array([[10., 10., 10., 10.], [70., 70., 70., 70.]]), bbox[0]) + np.testing.assert_array_equal(np.array([[20., 20., 20., 20.], [80., 80., 80.0, 80.]]), bbox[1]) + np.testing.assert_array_equal(np.array([[30.0, 30.0, 30.0, 30.]]), bbox[2]) + np.testing.assert_array_equal(np.array([[40., 40., 40., 40.]]), bbox[3]) + np.testing.assert_array_equal(np.array([[50., 50., 50., 50.]]), bbox[4]) + np.testing.assert_array_equal(np.array([[60., 60., 60., 60.]]), bbox[5]) + np.testing.assert_array_equal(np.array([[1], [7]]), category_id[0]) + np.testing.assert_array_equal(np.array([[2], [8]]), category_id[1]) + np.testing.assert_array_equal(np.array([[3]]), category_id[2]) + np.testing.assert_array_equal(np.array([[4]]), category_id[3]) + np.testing.assert_array_equal(np.array([[5]]), category_id[4]) + np.testing.assert_array_equal(np.array([[6]]), category_id[5]) def test_coco_stuff(): data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff", @@ -75,25 +76,26 @@ def test_coco_stuff(): assert image_shape[3] == (642, 675, 3) assert image_shape[4] == (2268, 4032, 3) assert image_shape[5] == (2268, 4032, 3) - assert np.array_equal(np.array([[10., 12., 13., 14., 15., 16., 17., 18., 19., 20.], - [70., 72., 73., 74., 75., -1., -1., -1., -1., -1.]]), - segmentation[0]) - assert np.array_equal(np.array([[0], [0]]), iscrowd[0]) - assert np.array_equal(np.array([[20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], - [10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0]]), - segmentation[1]) - assert np.array_equal(np.array([[0], [1]]), iscrowd[1]) - assert np.array_equal(np.array([[40., 42., 43., 44., 45., 46., 47., 48., 49., 40., 41., 42.]]), segmentation[2]) - assert np.array_equal(np.array([[0]]), iscrowd[2]) - assert np.array_equal(np.array([[50., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63.]]), - segmentation[3]) - assert np.array_equal(np.array([[0]]), iscrowd[3]) - assert np.array_equal(np.array([[60., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.]]), - segmentation[4]) - assert np.array_equal(np.array([[0]]), iscrowd[4]) - assert np.array_equal(np.array([[60., 62., 63., 64., 65., 66., 67.], [68., 69., 70., 71., 72., 73., 74.]]), - segmentation[5]) - assert np.array_equal(np.array([[0]]), iscrowd[5]) + np.testing.assert_array_equal(np.array([[10., 12., 13., 14., 15., 16., 17., 18., 19., 20.], + [70., 72., 73., 74., 75., -1., -1., -1., -1., -1.]]), + segmentation[0]) + np.testing.assert_array_equal(np.array([[0], [0]]), iscrowd[0]) + np.testing.assert_array_equal(np.array([[20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], + [10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0]]), + segmentation[1]) + np.testing.assert_array_equal(np.array([[0], [1]]), iscrowd[1]) + np.testing.assert_array_equal(np.array([[40., 42., 43., 44., 45., 46., 47., 48., 49., 40., 41., 42.]]), + segmentation[2]) + np.testing.assert_array_equal(np.array([[0]]), iscrowd[2]) + np.testing.assert_array_equal(np.array([[50., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63.]]), + segmentation[3]) + np.testing.assert_array_equal(np.array([[0]]), iscrowd[3]) + np.testing.assert_array_equal(np.array([[60., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.]]), + segmentation[4]) + np.testing.assert_array_equal(np.array([[0]]), iscrowd[4]) + np.testing.assert_array_equal(np.array([[60., 62., 63., 64., 65., 66., 67.], [68., 69., 70., 71., 72., 73., 74.]]), + segmentation[5]) + np.testing.assert_array_equal(np.array([[0]]), iscrowd[5]) def test_coco_keypoint(): data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint", @@ -110,16 +112,17 @@ def test_coco_keypoint(): assert num_iter == 2 assert image_shape[0] == (2268, 4032, 3) assert image_shape[1] == (561, 595, 3) - assert np.array_equal(np.array([[368., 61., 1., 369., 52., 2., 0., 0., 0., 382., 48., 2., 0., 0., 0., 368., 84., 2., - 435., 81., 2., 362., 125., 2., 446., 125., 2., 360., 153., 2., 0., 0., 0., 397., - 167., 1., 439., 166., 1., 369., 193., 2., 461., 234., 2., 361., 246., 2., 474., - 287., 2.]]), keypoints[0]) - assert np.array_equal(np.array([[14]]), num_keypoints[0]) - assert np.array_equal(np.array([[244., 139., 2., 0., 0., 0., 226., 118., 2., 0., 0., 0., 154., 159., 2., 143., 261., - 2., 135., 312., 2., 271., 423., 2., 184., 530., 2., 261., 280., 2., 347., 592., 2., - 0., 0., 0., 123., 596., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), - keypoints[1]) - assert np.array_equal(np.array([[10]]), num_keypoints[1]) + np.testing.assert_array_equal(np.array([[368., 61., 1., 369., 52., 2., 0., 0., 0., 382., 48., 2., 0., 0., 0., 368., + 84., 2., 435., 81., 2., 362., 125., 2., 446., 125., 2., 360., 153., 2., 0., + 0., 0., 397., 167., 1., 439., 166., 1., 369., 193., 2., 461., 234., 2., + 361., 246., 2., 474., 287., 2.]]), keypoints[0]) + np.testing.assert_array_equal(np.array([[14]]), num_keypoints[0]) + np.testing.assert_array_equal(np.array([[244., 139., 2., 0., 0., 0., 226., 118., 2., 0., 0., 0., 154., 159., 2., + 143., 261., 2., 135., 312., 2., 271., 423., 2., 184., 530., 2., 261., 280., + 2., 347., 592., 2., 0., 0., 0., 123., 596., 2., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0.]]), + keypoints[1]) + np.testing.assert_array_equal(np.array([[10]]), num_keypoints[1]) def test_coco_panoptic(): data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", decode=True, shuffle=False) @@ -138,15 +141,15 @@ def test_coco_panoptic(): num_iter += 1 assert num_iter == 2 assert image_shape[0] == (2268, 4032, 3) - assert np.array_equal(np.array([[472, 173, 36, 48], [340, 22, 154, 301], [486, 183, 30, 35]]), bbox[0]) - assert np.array_equal(np.array([[1], [1], [2]]), category_id[0]) - assert np.array_equal(np.array([[0], [0], [0]]), iscrowd[0]) - assert np.array_equal(np.array([[705], [14062], [626]]), area[0]) + np.testing.assert_array_equal(np.array([[472, 173, 36, 48], [340, 22, 154, 301], [486, 183, 30, 35]]), bbox[0]) + np.testing.assert_array_equal(np.array([[1], [1], [2]]), category_id[0]) + np.testing.assert_array_equal(np.array([[0], [0], [0]]), iscrowd[0]) + np.testing.assert_array_equal(np.array([[705], [14062], [626]]), area[0]) assert image_shape[1] == (642, 675, 3) - assert np.array_equal(np.array([[103, 133, 229, 422], [243, 175, 93, 164]]), bbox[1]) - assert np.array_equal(np.array([[1], [3]]), category_id[1]) - assert np.array_equal(np.array([[0], [0]]), iscrowd[1]) - assert np.array_equal(np.array([[43102], [6079]]), area[1]) + np.testing.assert_array_equal(np.array([[103, 133, 229, 422], [243, 175, 93, 164]]), bbox[1]) + np.testing.assert_array_equal(np.array([[1], [3]]), category_id[1]) + np.testing.assert_array_equal(np.array([[0], [0]]), iscrowd[1]) + np.testing.assert_array_equal(np.array([[43102], [6079]]), area[1]) def test_coco_detection_classindex(): data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True) @@ -202,6 +205,17 @@ def test_coco_case_2(): num_iter += 1 assert num_iter == 24 +def test_coco_case_3(): + data1 = ds.CocoDataset(DATA_DIR_2, annotation_file=ANNOTATION_FILE, task="Detection", decode=True) + resize_op = vision.Resize((224, 224)) + + data1 = data1.map(input_columns=["image"], operations=resize_op) + data1 = data1.repeat(4) + num_iter = 0 + for _ in data1.__iter__(): + num_iter += 1 + assert num_iter == 24 + def test_coco_case_exception(): try: data1 = ds.CocoDataset("path_not_exist/", annotation_file=ANNOTATION_FILE, task="Detection") @@ -271,4 +285,5 @@ if __name__ == '__main__': test_coco_case_0() test_coco_case_1() test_coco_case_2() + test_coco_case_3() test_coco_case_exception() diff --git a/tests/ut/python/dataset/test_datasets_csv.py b/tests/ut/python/dataset/test_datasets_csv.py new file mode 100644 index 0000000000..f998e9774d --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_csv.py @@ -0,0 +1,238 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds +import numpy as np +import pytest + +DATA_FILE = '../data/dataset/testCSV/1.csv' + + +def test_csv_dataset_basic(): + """ + Test CSV with repeat, skip and so on + """ + TRAIN_FILE = '../data/dataset/testCSV/1.csv' + + buffer = [] + data = ds.CSVDataset( + TRAIN_FILE, + column_defaults=["0", 0, 0.0, "0"], + column_names=['1', '2', '3', '4'], + shuffle=False) + data = data.repeat(2) + data = data.skip(2) + for d in data.create_dict_iterator(): + buffer.append(d) + assert len(buffer) == 4 + + +def test_csv_dataset_one_file(): + data = ds.CSVDataset( + DATA_FILE, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.append(d) + assert len(buffer) == 3 + + +def test_csv_dataset_all_file(): + APPEND_FILE = '../data/dataset/testCSV/2.csv' + data = ds.CSVDataset( + [DATA_FILE, APPEND_FILE], + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.append(d) + assert len(buffer) == 10 + + +def test_csv_dataset_num_samples(): + data = ds.CSVDataset( + DATA_FILE, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False, num_samples=2) + count = 0 + for _ in data.create_dict_iterator(): + count += 1 + assert count == 2 + + +def test_csv_dataset_distribution(): + TEST_FILE = '../data/dataset/testCSV/1.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False, num_shards=2, shard_id=0) + count = 0 + for _ in data.create_dict_iterator(): + count += 1 + assert count == 2 + + +def test_csv_dataset_quoted(): + TEST_FILE = '../data/dataset/testCSV/quoted.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a', 'b', 'c', 'd'] + + +def test_csv_dataset_separated(): + TEST_FILE = '../data/dataset/testCSV/separated.csv' + data = ds.CSVDataset( + TEST_FILE, + field_delim='|', + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a', 'b', 'c', 'd'] + + +def test_csv_dataset_embedded(): + TEST_FILE = '../data/dataset/testCSV/embedded.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a,b', 'c"d', 'e\nf', ' g '] + + +def test_csv_dataset_chinese(): + TEST_FILE = '../data/dataset/testCSV/chinese.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4', 'col5'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8"), + d['col5'].item().decode("utf8")]) + assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好'] + + +def test_csv_dataset_header(): + TEST_FILE = '../data/dataset/testCSV/header.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a', 'b', 'c', 'd'] + + +def test_csv_dataset_number(): + TEST_FILE = '../data/dataset/testCSV/number.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=[0.0, 0.0, 0, 0.0], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item(), + d['col2'].item(), + d['col3'].item(), + d['col4'].item()]) + assert np.allclose(buffer, [3.0, 0.3, 4, 55.5]) + + +def test_csv_dataset_size(): + TEST_FILE = '../data/dataset/testCSV/size.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=[0.0, 0.0, 0, 0.0], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + assert data.get_dataset_size() == 5 + + +def test_csv_dataset_exception(): + TEST_FILE = '../data/dataset/testCSV/exception.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + with pytest.raises(Exception) as err: + for _ in data.create_dict_iterator(): + pass + assert "Failed to parse file" in str(err.value) + + +def test_csv_dataset_type_error(): + TEST_FILE = '../data/dataset/testCSV/exception.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", 0, "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + with pytest.raises(Exception) as err: + for _ in data.create_dict_iterator(): + pass + assert "type does not match" in str(err.value) + + +if __name__ == "__main__": + test_csv_dataset_basic() + test_csv_dataset_one_file() + test_csv_dataset_all_file() + test_csv_dataset_num_samples() + test_csv_dataset_distribution() + test_csv_dataset_quoted() + test_csv_dataset_separated() + test_csv_dataset_embedded() + test_csv_dataset_chinese() + test_csv_dataset_header() + test_csv_dataset_number() + test_csv_dataset_size() + test_csv_dataset_exception() + test_csv_dataset_type_error() diff --git a/tests/ut/python/dataset/test_datasets_generator.py b/tests/ut/python/dataset/test_datasets_generator.py new file mode 100644 index 0000000000..5f3bc998f3 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_generator.py @@ -0,0 +1,665 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import pytest + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +from mindspore import log as logger + + +# Generate 1d int numpy array from 0 - 63 +def generator_1d(): + for i in range(64): + yield (np.array([i]),) + + +def test_generator_0(): + """ + Test 1D Generator + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + + +# Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]] +def generator_md(): + for i in range(64): + yield (np.array([[i, i + 1], [i + 2, i + 3]]),) + + +def test_generator_1(): + """ + Test MD Generator + """ + logger.info("Test MD Generator : 0 - 63, with shape [2, 2]") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_md, ["data"]) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + + +# Generate two columns, the first column is from Generator1D, the second column is from GeneratorMD +def generator_mc(maxid=64): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + + +def test_generator_2(): + """ + Test multi column generator + """ + logger.info("Test multi column generator") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"]) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["col0"], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item["col1"], golden) + i = i + 1 + + +def test_generator_3(): + """ + Test 1D Generator + repeat(4) + """ + logger.info("Test 1D Generator : 0 - 63 + Repeat(4)") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + data1 = data1.repeat(4) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + if i == 64: + i = 0 + + +def test_generator_4(): + """ + Test fixed size 1D Generator + batch + """ + logger.info("Test 1D Generator : 0 - 63 + batch(4)") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + data1 = data1.batch(4) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 4 + + +def generator_with_type(t): + for i in range(64): + yield (np.array([i], dtype=t),) + + +def type_tester(t): + logger.info("Test with Type {}".format(t.__name__)) + + # apply dataset operations + data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"]) + + data1 = data1.batch(4) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) + np.testing.assert_array_equal(item["data"], golden) + i = i + 4 + + +def test_generator_5(): + """ + Test 1D Generator on different data type + """ + logger.info("Test 1D Generator on all data types") + + types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, np.float64] + + for t in types: + type_tester(t) + + +def type_tester_with_type_check(t, c): + logger.info("Test with Type {}".format(t.__name__)) + + # apply dataset operations + data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c]) + + data1 = data1.batch(4) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) + np.testing.assert_array_equal(item["data"], golden) + i = i + 4 + + +def test_generator_6(): + """ + Test 1D Generator on different data type with type check + """ + logger.info("Test 1D Generator on all data types with type check") + + np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, + np.float64] + de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32, + mstype.uint64, mstype.float32, mstype.float64] + + for i, _ in enumerate(np_types): + type_tester_with_type_check(np_types[i], de_types[i]) + + +def generator_with_type_2c(t): + for i in range(64): + yield (np.array([i], dtype=t), np.array([i], dtype=t)) + + +def type_tester_with_type_check_2c(t, c): + logger.info("Test with Type {}".format(t.__name__)) + + # apply dataset operations + data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c) + + data1 = data1.batch(4) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) + np.testing.assert_array_equal(item["data0"], golden) + i = i + 4 + + +def test_generator_7(): + """ + Test 2 column Generator on different data type with type check + """ + logger.info("Test 2 column Generator on all data types with type check") + + np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, + np.float64] + de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32, + mstype.uint64, mstype.float32, mstype.float64] + + for i, _ in enumerate(np_types): + type_tester_with_type_check_2c(np_types[i], [None, de_types[i]]) + + +def test_generator_8(): + """ + Test multi column generator with few mapops + """ + logger.info("Test multi column generator with mapops to check the order too") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) + data1 = data1.map(input_columns="col0", output_columns="out0", operations=(lambda x: x * 3), + num_parallel_workers=2) + data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x * 7, x)), + num_parallel_workers=2, columns_order=["out0", "out1", "out2"]) + data1 = data1.map(input_columns="out2", output_columns="out2", operations=(lambda x: x + 1), + num_parallel_workers=2) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i * 3]) + np.testing.assert_array_equal(item["out0"], golden) + golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]]) + np.testing.assert_array_equal(item["out1"], golden) + golden = np.array([[i + 1, i + 2], [i + 3, i + 4]]) + np.testing.assert_array_equal(item["out2"], golden) + i = i + 1 + + +def test_generator_9(): + """ + Test map column order when len(input_columns) == len(output_columns). + """ + logger.info("Test map column order when len(input_columns) == len(output_columns).") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"]) + data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"]) + data1 = data1.map(input_columns="label", operations=(lambda x: x * 3), + num_parallel_workers=4) + data2 = data2.map(input_columns="label", operations=(lambda x: x * 3), + num_parallel_workers=4) + + # Expected column order is not changed. + # data1 = data[0] is "image" and data[1] is "label" + # data2 = data[0] is "label" and data[1] is "image" + i = 0 + for data1, data2 in zip(data1, data2): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(data1[0], golden) + golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]]) + np.testing.assert_array_equal(data1[1], golden) + + golden = np.array([i * 3]) + np.testing.assert_array_equal(data2[0], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(data2[1], golden) + i = i + 1 + + +def test_generator_10(): + """ + Test map column order when len(input_columns) != len(output_columns). + """ + logger.info("Test map column order when len(input_columns) != len(output_columns).") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) + data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)), + columns_order=['col0', 'out1', 'out2'], num_parallel_workers=2) + + # Expected column order is |col0|out1|out2| + i = 0 + for item in data1.create_tuple_iterator(): + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item[1], golden) + golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]]) + np.testing.assert_array_equal(item[2], golden) + i = i + 1 + + +def test_generator_11(): + """ + Test map column order when len(input_columns) != len(output_columns). + """ + logger.info("Test map column order when len(input_columns) != len(output_columns), " + "and columns_order drops some columns.") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) + data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)), + columns_order=['out1', 'out2'], num_parallel_workers=2) + + # Expected column order is |out1|out2| + i = 0 + for item in data1.create_tuple_iterator(): + # len should be 2 because col0 is dropped (not included in columns_order) + assert len(item) == 2 + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item[0], golden) + golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]]) + np.testing.assert_array_equal(item[1], golden) + i = i + 1 + + +def test_generator_12(): + """ + Test map column order when input_columns and output_columns are None. + """ + logger.info("Test map column order when input_columns and output_columns are None.") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) + data1 = data1.map(operations=(lambda x: (x * 5)), num_parallel_workers=2) + + # Expected column order is |col0|col1| + i = 0 + for item in data1.create_tuple_iterator(): + assert len(item) == 2 + golden = np.array([i * 5]) + np.testing.assert_array_equal(item[0], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item[1], golden) + i = i + 1 + + data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) + data1 = data1.map(operations=(lambda x: (x * 5)), columns_order=["col1", "col0"], num_parallel_workers=2) + + # Expected column order is |col0|col1| + i = 0 + for item in data1.create_tuple_iterator(): + assert len(item) == 2 + golden = np.array([i * 5]) + np.testing.assert_array_equal(item[1], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + + +def test_generator_13(): + """ + Test map column order when input_columns is None. + """ + logger.info("Test map column order when input_columns is None.") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) + data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2) + + # Expected column order is |out0|col1| + i = 0 + for item in data1.create_tuple_iterator(): + assert len(item) == 2 + golden = np.array([i * 5]) + np.testing.assert_array_equal(item[0], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item[1], golden) + i = i + 1 + + for item in data1.create_dict_iterator(): # each data is a dictionary + # len should be 2 because col0 is dropped (not included in columns_order) + assert len(item) == 2 + golden = np.array([i * 5]) + np.testing.assert_array_equal(item["out0"], golden) + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + np.testing.assert_array_equal(item["col1"], golden) + i = i + 1 + + +def test_generator_14(): + """ + Test 1D Generator MP + CPP sampler + """ + logger.info("Test 1D Generator MP : 0 - 63") + + source = [(np.array([x]),) for x in range(256)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(data["data"], golden) + i = i + 1 + if i == 256: + i = 0 + + +def test_generator_15(): + """ + Test 1D Generator MP + Python sampler + """ + logger.info("Test 1D Generator MP : 0 - 63") + + sampler = [x for x in range(256)] + source = [(np.array([x]),) for x in range(256)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(data["data"], golden) + i = i + 1 + if i == 256: + i = 0 + + +def test_generator_16(): + """ + Test multi column generator Mp + CPP sampler + """ + logger.info("Test multi column generator") + + source = [(np.array([x]), np.array([x + 1])) for x in range(256)] + # apply dataset operations + data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler()) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["col0"], golden) + golden = np.array([i + 1]) + np.testing.assert_array_equal(item["col1"], golden) + i = i + 1 + + +def test_generator_17(): + """ + Test multi column generator Mp + Python sampler + """ + logger.info("Test multi column generator") + + sampler = [x for x in range(256)] + source = [(np.array([x]), np.array([x + 1])) for x in range(256)] + # apply dataset operations + data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["col0"], golden) + golden = np.array([i + 1]) + np.testing.assert_array_equal(item["col1"], golden) + i = i + 1 + + +def test_generator_error_1(): + def generator_np(): + for i in range(64): + yield (np.array([{i}]),) + + with pytest.raises(RuntimeError) as info: + data1 = ds.GeneratorDataset(generator_np, ["data"]) + for _ in data1: + pass + assert "Invalid data type" in str(info.value) + + +def test_generator_error_2(): + def generator_np(): + for i in range(64): + yield ({i},) + + with pytest.raises(RuntimeError) as info: + data1 = ds.GeneratorDataset(generator_np, ["data"]) + for _ in data1: + pass + assert "Generator should return a tuple of numpy arrays" in str(info.value) + + +def test_generator_error_3(): + with pytest.raises(ValueError) as info: + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"]) + data1 = data1.map(input_columns=["label"], output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)), + num_parallel_workers=2) + + for _ in data1: + pass + assert "When (len(input_columns) != len(output_columns)), columns_order must be specified." in str(info.value) + + +def test_generator_error_4(): + with pytest.raises(RuntimeError) as info: + # apply dataset operations + data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"]) + data1 = data1.map(input_columns=["label"], operations=(lambda x: (x, x * 5)), + num_parallel_workers=2) + + for _ in data1: + pass + assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value) + + +def test_generator_sequential_sampler(): + source = [(np.array([x]),) for x in range(64)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler()) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(data["data"], golden) + i = i + 1 + + +def test_generator_random_sampler(): + source = [(np.array([x]),) for x in range(64)] + ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True) + for _ in ds1.create_dict_iterator(): # each data is a dictionary + pass + + +def test_generator_distributed_sampler(): + source = [(np.array([x]),) for x in range(64)] + for sid in range(8): + ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid) + i = sid + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(data["data"], golden) + i = i + 8 + + +def test_generator_num_samples(): + source = [(np.array([x]),) for x in range(64)] + num_samples = 32 + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples)) + ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples) + ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples) + + count = 0 + for _ in ds1.create_dict_iterator(): + count = count + 1 + assert count == num_samples + + count = 0 + for _ in ds2.create_dict_iterator(): + count = count + 1 + assert count == num_samples + + count = 0 + for _ in ds3.create_dict_iterator(): + count = count + 1 + assert count == num_samples + + +def test_generator_num_samples_underflow(): + source = [(np.array([x]),) for x in range(64)] + num_samples = 256 + ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples) + ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples) + + count = 0 + for _ in ds2.create_dict_iterator(): + count = count + 1 + assert count == 64 + + count = 0 + for _ in ds3.create_dict_iterator(): + count = count + 1 + assert count == 64 + + +def type_tester_with_type_check_2c_schema(t, c): + logger.info("Test with Type {}".format(t.__name__)) + + schema = ds.Schema() + schema.add_column("data0", c[0]) + schema.add_column("data1", c[1]) + + # apply dataset operations + data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema) + + data1 = data1.batch(4) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) + np.testing.assert_array_equal(item["data0"], golden) + i = i + 4 + + +def test_generator_schema(): + """ + Test 2 column Generator on different data type with type check with schema input + """ + logger.info("Test 2 column Generator on all data types with type check") + + np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, + np.float64] + de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32, + mstype.uint64, mstype.float32, mstype.float64] + + for i, _ in enumerate(np_types): + type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]]) + + +def manual_test_generator_keyboard_interrupt(): + """ + Test keyboard_interrupt + """ + logger.info("Test 1D Generator MP : 0 - 63") + + class MyDS(): + def __getitem__(self, item): + while True: + pass + + def __len__(self): + return 1024 + + ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2) + for _ in ds1.create_dict_iterator(): # each data is a dictionary + pass + + +if __name__ == "__main__": + test_generator_0() + test_generator_1() + test_generator_2() + test_generator_3() + test_generator_4() + test_generator_5() + test_generator_6() + test_generator_7() + test_generator_8() + test_generator_9() + test_generator_10() + test_generator_11() + test_generator_12() + test_generator_13() + test_generator_14() + test_generator_15() + test_generator_16() + test_generator_17() + test_generator_error_1() + test_generator_error_2() + test_generator_error_3() + test_generator_error_4() + test_generator_sequential_sampler() + test_generator_distributed_sampler() + test_generator_random_sampler() + test_generator_num_samples() + test_generator_num_samples_underflow() + test_generator_schema() diff --git a/tests/ut/python/dataset/test_datasets_mnist.py b/tests/ut/python/dataset/test_datasets_mnist.py new file mode 100644 index 0000000000..dfd6f7c6fc --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_mnist.py @@ -0,0 +1,238 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test Mnist dataset operators +""" +import os +import pytest +import numpy as np +import matplotlib.pyplot as plt +import mindspore.dataset as ds +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testMnistData" + + +def load_mnist(path): + """ + load Mnist data + """ + labels_path = os.path.join(path, 't10k-labels-idx1-ubyte') + images_path = os.path.join(path, 't10k-images-idx3-ubyte') + with open(labels_path, 'rb') as lbpath: + lbpath.read(8) + labels = np.fromfile(lbpath, dtype=np.uint8) + with open(images_path, 'rb') as imgpath: + imgpath.read(16) + images = np.fromfile(imgpath, dtype=np.uint8) + images = images.reshape(-1, 28, 28, 1) + images[images > 0] = 255 # Perform binarization to maintain consistency with our API + return images, labels + + +def visualize_dataset(images, labels): + """ + Helper function to visualize the dataset samples + """ + num_samples = len(images) + for i in range(num_samples): + plt.subplot(1, num_samples, i + 1) + plt.imshow(images[i].squeeze(), cmap=plt.cm.gray) + plt.title(labels[i]) + plt.show() + + +def test_mnist_content_check(): + """ + Validate MnistDataset image readings + """ + logger.info("Test MnistDataset Op with content check") + data1 = ds.MnistDataset(DATA_DIR, num_samples=100, shuffle=False) + images, labels = load_mnist(DATA_DIR) + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + image_list, label_list = [], [] + for i, data in enumerate(data1.create_dict_iterator()): + image_list.append(data["image"]) + label_list.append("label {}".format(data["label"])) + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 100 + + +def test_mnist_basic(): + """ + Validate MnistDataset + """ + logger.info("Test MnistDataset Op") + + # case 1: test loading whole dataset + data1 = ds.MnistDataset(DATA_DIR) + num_iter1 = 0 + for _ in data1.create_dict_iterator(): + num_iter1 += 1 + assert num_iter1 == 10000 + + # case 2: test num_samples + data2 = ds.MnistDataset(DATA_DIR, num_samples=500) + num_iter2 = 0 + for _ in data2.create_dict_iterator(): + num_iter2 += 1 + assert num_iter2 == 500 + + # case 3: test repeat + data3 = ds.MnistDataset(DATA_DIR, num_samples=200) + data3 = data3.repeat(5) + num_iter3 = 0 + for _ in data3.create_dict_iterator(): + num_iter3 += 1 + assert num_iter3 == 1000 + + # case 4: test batch with drop_remainder=False + data4 = ds.MnistDataset(DATA_DIR, num_samples=100) + assert data4.get_dataset_size() == 100 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=7) # drop_remainder is default to be False + assert data4.get_dataset_size() == 15 + assert data4.get_batch_size() == 7 + num_iter4 = 0 + for _ in data4.create_dict_iterator(): + num_iter4 += 1 + assert num_iter4 == 15 + + # case 5: test batch with drop_remainder=True + data5 = ds.MnistDataset(DATA_DIR, num_samples=100) + assert data5.get_dataset_size() == 100 + assert data5.get_batch_size() == 1 + data5 = data5.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped + assert data5.get_dataset_size() == 14 + assert data5.get_batch_size() == 7 + num_iter5 = 0 + for _ in data5.create_dict_iterator(): + num_iter5 += 1 + assert num_iter5 == 14 + + +def test_mnist_pk_sampler(): + """ + Test MnistDataset with PKSampler + """ + logger.info("Test MnistDataset Op with PKSampler") + golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, + 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9] + sampler = ds.PKSampler(3) + data = ds.MnistDataset(DATA_DIR, sampler=sampler) + num_iter = 0 + label_list = [] + for item in data.create_dict_iterator(): + label_list.append(item["label"]) + num_iter += 1 + np.testing.assert_array_equal(golden, label_list) + assert num_iter == 30 + + +def test_mnist_sequential_sampler(): + """ + Test MnistDataset with SequentialSampler + """ + logger.info("Test MnistDataset Op with SequentialSampler") + num_samples = 50 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.MnistDataset(DATA_DIR, sampler=sampler) + data2 = ds.MnistDataset(DATA_DIR, shuffle=False, num_samples=num_samples) + label_list1, label_list2 = [], [] + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): + label_list1.append(item1["label"]) + label_list2.append(item2["label"]) + num_iter += 1 + np.testing.assert_array_equal(label_list1, label_list2) + assert num_iter == num_samples + + +def test_mnist_exception(): + """ + Test error cases for MnistDataset + """ + logger.info("Test error cases for MnistDataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.MnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.MnistDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.MnistDataset(DATA_DIR, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.MnistDataset(DATA_DIR, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.MnistDataset(DATA_DIR, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=65) + with pytest.raises(ValueError, match=error_msg_6): + ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.MnistDataset(DATA_DIR, num_shards=2, shard_id="0") + + +def test_mnist_visualize(plot=False): + """ + Visualize MnistDataset results + """ + logger.info("Test MnistDataset visualization") + + data1 = ds.MnistDataset(DATA_DIR, num_samples=10, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in data1.create_dict_iterator(): + image = item["image"] + label = item["label"] + image_list.append(image) + label_list.append("label {}".format(label)) + assert isinstance(image, np.ndarray) + assert image.shape == (28, 28, 1) + assert image.dtype == np.uint8 + assert label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 10 + if plot: + visualize_dataset(image_list, label_list) + + +if __name__ == '__main__': + test_mnist_content_check() + test_mnist_basic() + test_mnist_pk_sampler() + test_mnist_sequential_sampler() + test_mnist_exception() + test_mnist_visualize(plot=True) diff --git a/tests/ut/python/dataset/test_datasets_sharding.py b/tests/ut/python/dataset/test_datasets_sharding.py index 94c39fb34c..ce6a30077f 100644 --- a/tests/ut/python/dataset/test_datasets_sharding.py +++ b/tests/ut/python/dataset/test_datasets_sharding.py @@ -200,7 +200,7 @@ def test_cifar10_shardings(print_res=False): logger.info("labels of dataset: {}".format(res)) return res - # 60000 rows in total. CIFAR reads everything in memory which would make each test case very slow + # 10000 rows in total. CIFAR reads everything in memory which would make each test case very slow # therefore, only 2 test cases for now. assert sharding_config(10000, 9999, 7, False, 1) == [9] assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0] diff --git a/tests/ut/python/dataset/test_datasets_tfrecord.py b/tests/ut/python/dataset/test_datasets_tfrecord.py new file mode 100644 index 0000000000..36791ac4c6 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_tfrecord.py @@ -0,0 +1,314 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test TFRecordDataset Ops +""" +import numpy as np +import pytest + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +from mindspore import log as logger +from util import save_and_check_dict + +FILES = ["../data/dataset/testTFTestAllTypes/test.data"] +DATASET_ROOT = "../data/dataset/testTFTestAllTypes/" +SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json" +DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] +SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" +GENERATE_GOLDEN = False + + +def test_tfrecord_shape(): + logger.info("test_tfrecord_shape") + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + ds1 = ds1.batch(2) + for data in ds1.create_dict_iterator(): + logger.info(data) + output_shape = ds1.output_shapes() + assert len(output_shape[-1]) == 1 + + +def test_tfrecord_read_all_dataset(): + logger.info("test_tfrecord_read_all_dataset") + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + assert ds1.get_dataset_size() == 12 + count = 0 + for _ in ds1.create_tuple_iterator(): + count += 1 + assert count == 12 + + +def test_tfrecord_num_samples(): + logger.info("test_tfrecord_num_samples") + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" + ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) + assert ds1.get_dataset_size() == 8 + count = 0 + for _ in ds1.create_dict_iterator(): + count += 1 + assert count == 8 + + +def test_tfrecord_num_samples2(): + logger.info("test_tfrecord_num_samples2") + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + assert ds1.get_dataset_size() == 7 + count = 0 + for _ in ds1.create_dict_iterator(): + count += 1 + assert count == 7 + + +def test_tfrecord_shape2(): + logger.info("test_tfrecord_shape2") + ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) + ds1 = ds1.batch(2) + output_shape = ds1.output_shapes() + assert len(output_shape[-1]) == 2 + + +def test_tfrecord_files_basic(): + logger.info("test_tfrecord_files_basic") + + data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + filename = "tfrecord_files_basic.npz" + save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) + + +def test_tfrecord_no_schema(): + logger.info("test_tfrecord_no_schema") + + data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES) + filename = "tfrecord_no_schema.npz" + save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) + + +def test_tfrecord_pad(): + logger.info("test_tfrecord_pad") + + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json" + data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES) + filename = "tfrecord_pad_bytes10.npz" + save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) + + +def test_tfrecord_read_files(): + logger.info("test_tfrecord_read_files") + pattern = DATASET_ROOT + "/test.data" + data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + assert sum([1 for _ in data]) == 12 + + pattern = DATASET_ROOT + "/test2.data" + data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + assert sum([1 for _ in data]) == 12 + + pattern = DATASET_ROOT + "/*.data" + data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES) + assert sum([1 for _ in data]) == 24 + + pattern = DATASET_ROOT + "/*.data" + data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=3, shuffle=ds.Shuffle.FILES) + assert sum([1 for _ in data]) == 3 + + data = ds.TFRecordDataset([DATASET_ROOT + "/test.data", DATASET_ROOT + "/test2.data"], + SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES) + assert sum([1 for _ in data]) == 24 + + +def test_tfrecord_multi_files(): + logger.info("test_tfrecord_multi_files") + data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False) + data1 = data1.repeat(1) + num_iter = 0 + for _ in data1.create_dict_iterator(): + num_iter += 1 + + assert num_iter == 12 + + +def test_tfrecord_schema(): + logger.info("test_tfrecord_schema") + schema = ds.Schema() + schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) + schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) + schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2]) + schema.add_column('col_binary', de_type=mstype.uint8, shape=[1]) + schema.add_column('col_float', de_type=mstype.float32, shape=[1]) + schema.add_column('col_sint16', de_type=mstype.int64, shape=[1]) + schema.add_column('col_sint32', de_type=mstype.int64, shape=[1]) + schema.add_column('col_sint64', de_type=mstype.int64, shape=[1]) + data1 = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES) + + data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + + for d1, d2 in zip(data1, data2): + for t1, t2 in zip(d1, d2): + np.testing.assert_array_equal(t1, t2) + + +def test_tfrecord_shuffle(): + logger.info("test_tfrecord_shuffle") + ds.config.set_seed(1) + data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL) + data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + data2 = data2.shuffle(10000) + + for d1, d2 in zip(data1, data2): + for t1, t2 in zip(d1, d2): + np.testing.assert_array_equal(t1, t2) + + +def test_tfrecord_shard(): + logger.info("test_tfrecord_shard") + tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", + "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] + + def get_res(shard_id, num_repeats): + data1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3, + shuffle=ds.Shuffle.FILES) + data1 = data1.repeat(num_repeats) + res = list() + for item in data1.create_dict_iterator(): + res.append(item["scalars"][0]) + return res + + # get separate results from two workers. the 2 results need to satisfy 2 criteria + # 1. two workers always give different results in same epoch (e.g. wrkr1:f1&f3, wrkr2:f2&f4 in one epoch) + # 2. with enough epochs, both workers will get the entire dataset (e,g. ep1_wrkr1: f1&f3, ep2,_wrkr1 f2&f4) + worker1_res = get_res(0, 16) + worker2_res = get_res(1, 16) + # Confirm each worker gets 3x16=48 rows + assert len(worker1_res) == 48 + assert len(worker1_res) == len(worker2_res) + # check criteria 1 + for i, _ in enumerate(worker1_res): + assert worker1_res[i] != worker2_res[i] + # check criteria 2 + assert set(worker2_res) == set(worker1_res) + + +def test_tfrecord_shard_equal_rows(): + logger.info("test_tfrecord_shard_equal_rows") + tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", + "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] + + def get_res(num_shards, shard_id, num_repeats): + ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True) + ds1 = ds1.repeat(num_repeats) + res = list() + for data in ds1.create_dict_iterator(): + res.append(data["scalars"][0]) + return res + + worker1_res = get_res(3, 0, 2) + worker2_res = get_res(3, 1, 2) + worker3_res = get_res(3, 2, 2) + # check criteria 1 + for i, _ in enumerate(worker1_res): + assert worker1_res[i] != worker2_res[i] + assert worker2_res[i] != worker3_res[i] + # Confirm each worker gets same number of rows + assert len(worker1_res) == 28 + assert len(worker1_res) == len(worker2_res) + assert len(worker2_res) == len(worker3_res) + + worker4_res = get_res(1, 0, 1) + assert len(worker4_res) == 40 + + +def test_tfrecord_no_schema_columns_list(): + logger.info("test_tfrecord_no_schema_columns_list") + data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"]) + row = data.create_dict_iterator().__next__() + assert row["col_sint16"] == [-32768] + + with pytest.raises(KeyError) as info: + _ = row["col_sint32"] + assert "col_sint32" in str(info.value) + + +def test_tfrecord_schema_columns_list(): + logger.info("test_tfrecord_schema_columns_list") + schema = ds.Schema() + schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) + schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) + schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2]) + schema.add_column('col_binary', de_type=mstype.uint8, shape=[1]) + schema.add_column('col_float', de_type=mstype.float32, shape=[1]) + schema.add_column('col_sint16', de_type=mstype.int64, shape=[1]) + schema.add_column('col_sint32', de_type=mstype.int64, shape=[1]) + schema.add_column('col_sint64', de_type=mstype.int64, shape=[1]) + data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"]) + row = data.create_dict_iterator().__next__() + assert row["col_sint16"] == [-32768] + + with pytest.raises(KeyError) as info: + _ = row["col_sint32"] + assert "col_sint32" in str(info.value) + + +def test_tfrecord_invalid_files(): + logger.info("test_tfrecord_invalid_files") + valid_file = "../data/dataset/testTFTestAllTypes/test.data" + invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" + files = [invalid_file, valid_file, SCHEMA_FILE] + + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + + with pytest.raises(RuntimeError) as info: + _ = data.create_dict_iterator().get_next() + assert "cannot be opened" in str(info.value) + assert "not valid tfrecord files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file in str(info.value) + assert SCHEMA_FILE in str(info.value) + + nonexistent_file = "this/file/does/not/exist" + files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file] + + with pytest.raises(ValueError) as info: + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + assert "did not match any files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file not in str(info.value) + assert SCHEMA_FILE not in str(info.value) + assert nonexistent_file in str(info.value) + + +if __name__ == '__main__': + test_tfrecord_shape() + test_tfrecord_read_all_dataset() + test_tfrecord_num_samples() + test_tfrecord_num_samples2() + test_tfrecord_shape2() + test_tfrecord_files_basic() + test_tfrecord_no_schema() + test_tfrecord_pad() + test_tfrecord_read_files() + test_tfrecord_multi_files() + test_tfrecord_schema() + test_tfrecord_shuffle() + test_tfrecord_shard() + test_tfrecord_shard_equal_rows() + test_tfrecord_no_schema_columns_list() + test_tfrecord_schema_columns_list() + test_tfrecord_invalid_files() diff --git a/tests/ut/python/dataset/test_datasets_voc.py b/tests/ut/python/dataset/test_datasets_voc.py index 37f4a8c123..1978b7005f 100644 --- a/tests/ut/python/dataset/test_datasets_voc.py +++ b/tests/ut/python/dataset/test_datasets_voc.py @@ -36,8 +36,8 @@ def test_voc_detection(): count = [0, 0, 0, 0, 0, 0] for item in data1.create_dict_iterator(): assert item["image"].shape[0] == IMAGE_SHAPE[num] - for bbox in item["annotation"]: - count[int(bbox[6])] += 1 + for label in item["label"]: + count[label[0]] += 1 num += 1 assert num == 9 assert count == [3, 2, 1, 2, 4, 3] @@ -54,9 +54,9 @@ def test_voc_class_index(): num = 0 count = [0, 0, 0, 0, 0, 0] for item in data1.create_dict_iterator(): - for bbox in item["annotation"]: - assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 5) - count[int(bbox[6])] += 1 + for label in item["label"]: + count[label[0]] += 1 + assert label[0] in (0, 1, 5) num += 1 assert num == 6 assert count == [3, 2, 0, 0, 0, 3] @@ -72,10 +72,9 @@ def test_voc_get_class_indexing(): num = 0 count = [0, 0, 0, 0, 0, 0] for item in data1.create_dict_iterator(): - for bbox in item["annotation"]: - assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 2 or int(bbox[6]) == 3 - or int(bbox[6]) == 4 or int(bbox[6]) == 5) - count[int(bbox[6])] += 1 + for label in item["label"]: + count[label[0]] += 1 + assert label[0] in (0, 1, 2, 3, 4, 5) num += 1 assert num == 9 assert count == [3, 2, 1, 2, 4, 3] diff --git a/tests/ut/python/dataset/test_deviceop_cpu.py b/tests/ut/python/dataset/test_deviceop_cpu.py index 1c701c3e40..b5f18665e0 100644 --- a/tests/ut/python/dataset/test_deviceop_cpu.py +++ b/tests/ut/python/dataset/test_deviceop_cpu.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import time + import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore import log as logger @@ -35,6 +37,8 @@ def test_case_0(): data = data.device_que() data.send() + time.sleep(0.1) + data.stop_send() def test_case_1(): @@ -58,6 +62,8 @@ def test_case_1(): data = data.device_que() data.send() + time.sleep(0.1) + data.stop_send() def test_case_2(): @@ -84,6 +90,8 @@ def test_case_2(): data = data.device_que() assert data.get_repeat_count() == 2 data.send() + time.sleep(0.1) + data.stop_send() def test_case_3(): @@ -109,13 +117,17 @@ def test_case_3(): data = data.device_que() data.send() + time.sleep(0.1) + data.stop_send() def test_case_tf_file(): data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - data = data.to_device(num_batch=10) + data = data.to_device() data.send() + time.sleep(0.1) + data.stop_send() if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_epoch_ctrl.py b/tests/ut/python/dataset/test_epoch_ctrl.py new file mode 100644 index 0000000000..3a5ddb3b8c --- /dev/null +++ b/tests/ut/python/dataset/test_epoch_ctrl.py @@ -0,0 +1,701 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing Epoch Control op in DE +""" +import itertools +import cv2 +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +from mindspore import log as logger + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + + +def diff_mse(in1, in2): + """ + diff_mse + """ + mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean() + return mse * 100 + +def test_cifar10(): + """ + dataset parameter + """ + logger.info("Test dataset parameter") + data_dir_10 = "../data/dataset/testCifar10Data" + num_repeat = 2 + batch_size = 32 + limit_dataset = 100 + # apply dataset operations + data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset) + data1 = data1.repeat(num_repeat) + data1 = data1.batch(batch_size, True) + num_epoch = 5 + # iter1 will always assume there is a next epoch and never shutdown. + iter1 = data1.create_tuple_iterator() + epoch_count = 0 + sample_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + # in this example, each dictionary has keys "image" and "label" + row_count += 1 + assert row_count == int(limit_dataset * num_repeat / batch_size) + logger.debug("row_count: ", row_count) + epoch_count += 1 + sample_count += row_count + assert epoch_count == num_epoch + logger.debug("total epochs: ", epoch_count) + assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch + logger.debug("total sample: ", sample_count) + + +def test_decode_op(): + """ + Test Decode op + """ + logger.info("test_decode_op") + + # Decode with rgb format set to True + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + + # Serialize and Load dataset requires using vision.Decode instead of vision.Decode(). + data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)]) + + # Second dataset + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + + num_epoch = 5 + # iter1 will always assume there is a next epoch and never shutdown. + iter1 = data1.create_dict_iterator() + # iter 2 will stop and shutdown pipeline after num_epoch + iter2 = data2.create_dict_iterator(num_epoch) + for _ in range(num_epoch): + i = 0 + for item1, item2 in itertools.zip_longest(iter1, iter2): + actual = item1["image"] + expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR) + expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB) + assert actual.shape == expected.shape + diff = actual - expected + mse = np.sum(np.power(diff, 2)) + assert mse == 0 + i = i + 1 + assert i == 3 + + # Users have the option to manually stop the iterator, or rely on garbage collector. + iter1.stop() + # Expect a AttributeError since iter1 has been stopped. + with pytest.raises(AttributeError) as info: + iter1.__next__() + assert "object has no attribute 'depipeline'" in str(info.value) + + with pytest.raises(RuntimeError) as info: + iter2.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + + +# Generate 1d int numpy array from 0 - 63 +def generator_1d(): + """ + generator + """ + for i in range(64): + yield (np.array([i]),) + + +def test_generator_dict_0(): + """ + test generator dict 0 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + i = 0 + # create the iterator inside the loop declaration + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + +def test_generator_dict_1(): + """ + test generator dict 1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + for _ in range(10): + i = 0 + # BAD. Do not create iterator every time inside. + # Create iterator outside the epoch for loop. + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + +def test_generator_dict_2(): + """ + test generator dict 2 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_dict_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + + # iter1 is still alive and running. + item1 = iter1.__next__() + assert item1 + # rely on garbage collector to destroy iter1 + +def test_generator_dict_3(): + """ + test generator dict 3 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_dict_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + # optional + iter1.stop() + # Expect a AttributeError since iter1 has been stopped. + with pytest.raises(AttributeError) as info: + iter1.__next__() + assert "object has no attribute 'depipeline'" in str(info.value) + + +def test_generator_dict_4(): + """ + test generator dict 4 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_dict_iterator(num_epochs=10) + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + +def test_generator_dict_4_1(): + """ + test generator dict 4_1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + # epoch ctrl op will not be injected if num_epochs is 1. + iter1 = data1.create_dict_iterator(num_epochs=1) + for _ in range(1): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + +def test_generator_dict_4_2(): + """ + test generator dict 4_2 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + # repeat will not be injected when num repeat is 1. + data1 = data1.repeat(1) + # epoch ctrl op will not be injected if num_epochs is 1. + iter1 = data1.create_dict_iterator(num_epochs=1) + for _ in range(1): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + +def test_generator_dict_5(): + """ + test generator dict 5 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_dict_iterator(num_epochs=11) + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + + # still one more epoch left in the iter1. + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + assert i == 64 + + # now iter1 has been exhausted, c++ pipeline has been shut down. + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + +# Test tuple iterator + +def test_generator_tuple_0(): + """ + test generator tuple 0 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + i = 0 + # create the iterator inside the loop declaration + for item in data1.create_tuple_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + +def test_generator_tuple_1(): + """ + test generator tuple 1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + for _ in range(10): + i = 0 + # BAD. Do not create iterator every time inside. + # Create iterator outside the epoch for loop. + for item in data1.create_tuple_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 + +def test_generator_tuple_2(): + """ + test generator tuple 2 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_tuple_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 + + # iter1 is still alive and running. + item1 = iter1.__next__() + assert item1 + # rely on garbage collector to destroy iter1 + +def test_generator_tuple_3(): + """ + test generator tuple 3 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_tuple_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 + # optional + iter1.stop() + # Expect a AttributeError since iter1 has been stopped. + with pytest.raises(AttributeError) as info: + iter1.__next__() + assert "object has no attribute 'depipeline'" in str(info.value) + + +def test_generator_tuple_4(): + """ + test generator tuple 4 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_tuple_iterator(num_epochs=10) + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 + + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + + +def test_generator_tuple_5(): + """ + test generator tuple 5 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + iter1 = data1.create_tuple_iterator(num_epochs=11) + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 + + # still one more epoch left in the iter1. + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 + + # now iter1 has been exhausted, c++ pipeline has been shut down. + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + +# Test with repeat +def test_generator_tuple_repeat_1(): + """ + test generator tuple repeat 1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(2) + iter1 = data1.create_tuple_iterator(num_epochs=11) + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 + + # still one more epoch left in the iter1. + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 + + # now iter1 has been exhausted, c++ pipeline has been shut down. + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + + +# Test with repeat +def test_generator_tuple_repeat_repeat_1(): + """ + test generator tuple repeat repeat 1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(2) + data1 = data1.repeat(3) + iter1 = data1.create_tuple_iterator(num_epochs=11) + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 * 3 + + # still one more epoch left in the iter1. + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 * 3 + + # now iter1 has been exhausted, c++ pipeline has been shut down. + with pytest.raises(RuntimeError) as info: + iter1.__next__() + err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs." + assert err_msg in str(info.value) + + +def test_generator_tuple_repeat_repeat_2(): + """ + test generator tuple repeat repeat 2 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(2) + data1 = data1.repeat(3) + iter1 = data1.create_tuple_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 * 3 + # optional + iter1.stop() + # Expect a AttributeError since iter1 has been stopped. + with pytest.raises(AttributeError) as info: + iter1.__next__() + assert "object has no attribute 'depipeline'" in str(info.value) + +def test_generator_tuple_repeat_repeat_3(): + """ + test generator tuple repeat repeat 3 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(2) + data1 = data1.repeat(3) + iter1 = data1.create_tuple_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 * 3 + + for _ in range(5): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 * 3 + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_1(): + """ + test generator tuple infinite repeat repeat 1 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat() + data1 = data1.repeat(3) + iter1 = data1.create_tuple_iterator(num_epochs=11) + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_2(): + """ + test generator tuple infinite repeat repeat 2 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(3) + data1 = data1.repeat() + iter1 = data1.create_tuple_iterator(num_epochs=11) + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_3(): + """ + test generator tuple infinite repeat repeat 3 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat() + data1 = data1.repeat() + iter1 = data1.create_tuple_iterator(num_epochs=11) + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_tuple_infinite_repeat_repeat_4(): + """ + test generator tuple infinite repeat repeat 4 + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat() + data1 = data1.repeat() + iter1 = data1.create_tuple_iterator() + + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + if i == 100: + break + + # rely on garbage collector to destroy iter1 + + +def test_generator_reusedataset(): + """ + test generator reusedataset + """ + logger.info("Test 1D Generator : 0 - 63") + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(2) + iter1 = data1.create_tuple_iterator() + for _ in range(10): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 + + data1 = data1.repeat(3) + iter1 = data1.create_tuple_iterator() + for _ in range(5): + i = 0 + for item in iter1: # each data is a dictionary + golden = np.array([i % 64]) + np.testing.assert_array_equal(item[0], golden) + i = i + 1 + assert i == 64 * 2 * 3 + + data1 = data1.batch(2) + iter1 = data1.create_dict_iterator() + for _ in range(5): + i = 0 + sample = 0 + for item in iter1: # each data is a dictionary + golden = np.array([[i % 64], [(i + 1) % 64]]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 2 + sample = sample + 1 + assert sample == 64 * 3 + + # rely on garbage collector to destroy iter1 diff --git a/tests/ut/python/dataset/test_equalize.py b/tests/ut/python/dataset/test_equalize.py index 0a5f2f93d5..26102ae809 100644 --- a/tests/ut/python/dataset/test_equalize.py +++ b/tests/ut/python/dataset/test_equalize.py @@ -18,6 +18,7 @@ Testing Equalize op in DE import numpy as np import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.vision.py_transforms as F from mindspore import log as logger from util import visualize_list, diff_mse, save_and_check_md5 @@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" GENERATE_GOLDEN = False -def test_equalize(plot=False): +def test_equalize_py(plot=False): """ - Test Equalize + Test Equalize py op """ logger.info("Test Equalize") @@ -83,9 +84,141 @@ def test_equalize(plot=False): visualize_list(images_original, images_equalize) -def test_equalize_md5(): +def test_equalize_c(plot=False): """ - Test Equalize with md5 check + Test Equalize Cpp op + """ + logger.info("Test Equalize cpp op") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = [C.Decode(), C.Resize(size=[224, 224])] + + ds_original = ds.map(input_columns="image", + operations=transforms_original) + + ds_original = ds_original.batch(512) + + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, + image, + axis=0) + + # Equalize Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transform_equalize = [C.Decode(), C.Resize(size=[224, 224]), + C.Equalize()] + + ds_equalize = ds.map(input_columns="image", + operations=transform_equalize) + + ds_equalize = ds_equalize.batch(512) + + for idx, (image, _) in enumerate(ds_equalize): + if idx == 0: + images_equalize = image + else: + images_equalize = np.append(images_equalize, + image, + axis=0) + if plot: + visualize_list(images_original, images_equalize) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_equalize[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def test_equalize_py_c(plot=False): + """ + Test Equalize Cpp op and python op + """ + logger.info("Test Equalize cpp and python op") + + # equalize Images in cpp + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), C.Resize((224, 224))]) + + ds_c_equalize = ds.map(input_columns="image", + operations=C.Equalize()) + + ds_c_equalize = ds_c_equalize.batch(512) + + for idx, (image, _) in enumerate(ds_c_equalize): + if idx == 0: + images_c_equalize = image + else: + images_c_equalize = np.append(images_c_equalize, + image, + axis=0) + + # Equalize images in python + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), C.Resize((224, 224))]) + + transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8), + F.ToPIL(), + F.Equalize(), + np.array]) + + ds_p_equalize = ds.map(input_columns="image", + operations=transforms_p_equalize()) + + ds_p_equalize = ds_p_equalize.batch(512) + + for idx, (image, _) in enumerate(ds_p_equalize): + if idx == 0: + images_p_equalize = image + else: + images_p_equalize = np.append(images_p_equalize, + image, + axis=0) + + num_samples = images_c_equalize.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2) + + +def test_equalize_one_channel(): + """ + Test Equalize cpp op with one channel image + """ + logger.info("Test Equalize C Op With One Channel Images") + + c_op = C.Equalize() + + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + + ds.map(input_columns="image", + operations=c_op) + + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "The shape" in str(e) + + +def test_equalize_md5_py(): + """ + Test Equalize py op with md5 check """ logger.info("Test Equalize") @@ -101,6 +234,31 @@ def test_equalize_md5(): save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) +def test_equalize_md5_c(): + """ + Test Equalize cpp op with md5 check + """ + logger.info("Test Equalize cpp op with md5 check") + + # Generate dataset + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_equalize = [C.Decode(), + C.Resize(size=[224, 224]), + C.Equalize(), + F.ToTensor()] + + data = ds.map(input_columns="image", operations=transforms_equalize) + # Compare with expected md5 from images + filename = "equalize_01_result_c.npz" + save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) + + if __name__ == "__main__": - test_equalize(plot=True) - test_equalize_md5() + test_equalize_py(plot=False) + test_equalize_c(plot=False) + test_equalize_py_c(plot=False) + test_equalize_one_channel() + test_equalize_md5_py() + test_equalize_md5_c() + \ No newline at end of file diff --git a/tests/ut/python/dataset/test_five_crop.py b/tests/ut/python/dataset/test_five_crop.py index ef2e376c0f..86f52bdcd7 100644 --- a/tests/ut/python/dataset/test_five_crop.py +++ b/tests/ut/python/dataset/test_five_crop.py @@ -87,7 +87,7 @@ def test_five_crop_error_msg(): data = data.map(input_columns=["image"], operations=transform()) with pytest.raises(RuntimeError) as info: - data.create_tuple_iterator().get_next() + data.create_tuple_iterator().__next__() error_msg = "TypeError: img should be PIL Image or Numpy array. Got " # error msg comes from ToTensor() diff --git a/tests/ut/python/dataset/test_from_dataset.py b/tests/ut/python/dataset/test_from_dataset.py index 983052ea08..7b6333ba65 100644 --- a/tests/ut/python/dataset/test_from_dataset.py +++ b/tests/ut/python/dataset/test_from_dataset.py @@ -134,7 +134,7 @@ def test_from_dataset_exceptions(): test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") test_config("text", (2, 3), 1.2345, "Argument top_k with value 1.2345 is not of type (, )") - test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (,)") + test_config(23, (2, 3), 1.2345, "Argument col[0] with value 23 is not of type (,)") test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") test_config("text", (2, 3), 0, "top_k must be greater than 0") test_config([123], (2, 3), -1, "top_k must be greater than 0") diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py deleted file mode 100644 index 926b84a7f4..0000000000 --- a/tests/ut/python/dataset/test_generator.py +++ /dev/null @@ -1,665 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy as np -import pytest - -import mindspore.common.dtype as mstype -import mindspore.dataset as ds -from mindspore import log as logger - - -# Generate 1d int numpy array from 0 - 63 -def generator_1d(): - for i in range(64): - yield (np.array([i]),) - - -def test_case_0(): - """ - Test 1D Generator - """ - logger.info("Test 1D Generator : 0 - 63") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_1d, ["data"]) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(item["data"], golden) - i = i + 1 - - -# Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]] -def generator_md(): - for i in range(64): - yield (np.array([[i, i + 1], [i + 2, i + 3]]),) - - -def test_case_1(): - """ - Test MD Generator - """ - logger.info("Test MD Generator : 0 - 63, with shape [2, 2]") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_md, ["data"]) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["data"], golden) - i = i + 1 - - -# Generate two columns, the first column is from Generator1D, the second column is from GeneratorMD -def generator_mc(maxid=64): - for i in range(maxid): - yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) - - -def test_case_2(): - """ - Test multi column generator - """ - logger.info("Test multi column generator") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"]) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(item["col0"], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["col1"], golden) - i = i + 1 - - -def test_case_3(): - """ - Test 1D Generator + repeat(4) - """ - logger.info("Test 1D Generator : 0 - 63 + Repeat(4)") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_1d, ["data"]) - - data1 = data1.repeat(4) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(item["data"], golden) - i = i + 1 - if i == 64: - i = 0 - - -def test_case_4(): - """ - Test fixed size 1D Generator + batch - """ - logger.info("Test 1D Generator : 0 - 63 + batch(4)") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_1d, ["data"]) - - data1 = data1.batch(4) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([[i], [i + 1], [i + 2], [i + 3]]) - assert np.array_equal(item["data"], golden) - i = i + 4 - - -def generator_with_type(t): - for i in range(64): - yield (np.array([i], dtype=t),) - - -def type_tester(t): - logger.info("Test with Type {}".format(t.__name__)) - - # apply dataset operations - data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"]) - - data1 = data1.batch(4) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) - assert np.array_equal(item["data"], golden) - i = i + 4 - - -def test_case_5(): - """ - Test 1D Generator on different data type - """ - logger.info("Test 1D Generator on all data types") - - types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, np.float64] - - for t in types: - type_tester(t) - - -def type_tester_with_type_check(t, c): - logger.info("Test with Type {}".format(t.__name__)) - - # apply dataset operations - data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c]) - - data1 = data1.batch(4) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) - assert np.array_equal(item["data"], golden) - i = i + 4 - - -def test_case_6(): - """ - Test 1D Generator on different data type with type check - """ - logger.info("Test 1D Generator on all data types with type check") - - np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, - np.float64] - de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32, - mstype.uint64, mstype.float32, mstype.float64] - - for i, _ in enumerate(np_types): - type_tester_with_type_check(np_types[i], de_types[i]) - - -def generator_with_type_2c(t): - for i in range(64): - yield (np.array([i], dtype=t), np.array([i], dtype=t)) - - -def type_tester_with_type_check_2c(t, c): - logger.info("Test with Type {}".format(t.__name__)) - - # apply dataset operations - data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c) - - data1 = data1.batch(4) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) - assert np.array_equal(item["data0"], golden) - i = i + 4 - - -def test_case_7(): - """ - Test 2 column Generator on different data type with type check - """ - logger.info("Test 2 column Generator on all data types with type check") - - np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, - np.float64] - de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32, - mstype.uint64, mstype.float32, mstype.float64] - - for i, _ in enumerate(np_types): - type_tester_with_type_check_2c(np_types[i], [None, de_types[i]]) - - -def test_case_8(): - """ - Test multi column generator with few mapops - """ - logger.info("Test multi column generator with mapops to check the order too") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) - data1 = data1.map(input_columns="col0", output_columns="out0", operations=(lambda x: x * 3), - num_parallel_workers=2) - data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x * 7, x)), - num_parallel_workers=2, columns_order=["out0", "out1", "out2"]) - data1 = data1.map(input_columns="out2", output_columns="out2", operations=(lambda x: x + 1), - num_parallel_workers=2) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i * 3]) - assert np.array_equal(item["out0"], golden) - golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]]) - assert np.array_equal(item["out1"], golden) - golden = np.array([[i + 1, i + 2], [i + 3, i + 4]]) - assert np.array_equal(item["out2"], golden) - i = i + 1 - - -def test_case_9(): - """ - Test map column order when len(input_columns) == len(output_columns). - """ - logger.info("Test map column order when len(input_columns) == len(output_columns).") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"]) - data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"]) - data1 = data1.map(input_columns="label", operations=(lambda x: x * 3), - num_parallel_workers=4) - data2 = data2.map(input_columns="label", operations=(lambda x: x * 3), - num_parallel_workers=4) - - # Expected column order is not changed. - # data1 = data[0] is "image" and data[1] is "label" - # data2 = data[0] is "label" and data[1] is "image" - i = 0 - for data1, data2 in zip(data1, data2): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(data1[0], golden) - golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]]) - assert np.array_equal(data1[1], golden) - - golden = np.array([i * 3]) - assert np.array_equal(data2[0], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(data2[1], golden) - i = i + 1 - - -def test_case_10(): - """ - Test map column order when len(input_columns) != len(output_columns). - """ - logger.info("Test map column order when len(input_columns) != len(output_columns).") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) - data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)), - columns_order=['col0', 'out1', 'out2'], num_parallel_workers=2) - - # Expected column order is |col0|out1|out2| - i = 0 - for item in data1.create_tuple_iterator(): - golden = np.array([i]) - assert np.array_equal(item[0], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item[1], golden) - golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]]) - assert np.array_equal(item[2], golden) - i = i + 1 - - -def test_case_11(): - """ - Test map column order when len(input_columns) != len(output_columns). - """ - logger.info("Test map column order when len(input_columns) != len(output_columns), " - "and columns_order drops some columns.") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) - data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)), - columns_order=['out1', 'out2'], num_parallel_workers=2) - - # Expected column order is |out1|out2| - i = 0 - for item in data1.create_tuple_iterator(): - # len should be 2 because col0 is dropped (not included in columns_order) - assert len(item) == 2 - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item[0], golden) - golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]]) - assert np.array_equal(item[1], golden) - i = i + 1 - - -def test_case_12(): - """ - Test map column order when input_columns and output_columns are None. - """ - logger.info("Test map column order when input_columns and output_columns are None.") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) - data1 = data1.map(operations=(lambda x: (x * 5)), num_parallel_workers=2) - - # Expected column order is |col0|col1| - i = 0 - for item in data1.create_tuple_iterator(): - assert len(item) == 2 - golden = np.array([i * 5]) - assert np.array_equal(item[0], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item[1], golden) - i = i + 1 - - data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) - data1 = data1.map(operations=(lambda x: (x * 5)), columns_order=["col1", "col0"], num_parallel_workers=2) - - # Expected column order is |col0|col1| - i = 0 - for item in data1.create_tuple_iterator(): - assert len(item) == 2 - golden = np.array([i * 5]) - assert np.array_equal(item[1], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item[0], golden) - i = i + 1 - - -def test_case_13(): - """ - Test map column order when input_columns is None. - """ - logger.info("Test map column order when input_columns is None.") - - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"]) - data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2) - - # Expected column order is |out0|col1| - i = 0 - for item in data1.create_tuple_iterator(): - assert len(item) == 2 - golden = np.array([i * 5]) - assert np.array_equal(item[0], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item[1], golden) - i = i + 1 - - for item in data1.create_dict_iterator(): # each data is a dictionary - # len should be 2 because col0 is dropped (not included in columns_order) - assert len(item) == 2 - golden = np.array([i * 5]) - assert np.array_equal(item["out0"], golden) - golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["col1"], golden) - i = i + 1 - - -def test_case_14(): - """ - Test 1D Generator MP + CPP sampler - """ - logger.info("Test 1D Generator MP : 0 - 63") - - source = [(np.array([x]),) for x in range(256)] - ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2) - i = 0 - for data in ds1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(data["data"], golden) - i = i + 1 - if i == 256: - i = 0 - - -def test_case_15(): - """ - Test 1D Generator MP + Python sampler - """ - logger.info("Test 1D Generator MP : 0 - 63") - - sampler = [x for x in range(256)] - source = [(np.array([x]),) for x in range(256)] - ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2) - i = 0 - for data in ds1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(data["data"], golden) - i = i + 1 - if i == 256: - i = 0 - - -def test_case_16(): - """ - Test multi column generator Mp + CPP sampler - """ - logger.info("Test multi column generator") - - source = [(np.array([x]), np.array([x + 1])) for x in range(256)] - # apply dataset operations - data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler()) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(item["col0"], golden) - golden = np.array([i + 1]) - assert np.array_equal(item["col1"], golden) - i = i + 1 - - -def test_case_17(): - """ - Test multi column generator Mp + Python sampler - """ - logger.info("Test multi column generator") - - sampler = [x for x in range(256)] - source = [(np.array([x]), np.array([x + 1])) for x in range(256)] - # apply dataset operations - data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(item["col0"], golden) - golden = np.array([i + 1]) - assert np.array_equal(item["col1"], golden) - i = i + 1 - - -def test_case_error_1(): - def generator_np(): - for i in range(64): - yield (np.array([{i}]),) - - with pytest.raises(RuntimeError) as info: - data1 = ds.GeneratorDataset(generator_np, ["data"]) - for _ in data1: - pass - assert "Invalid data type" in str(info.value) - - -def test_case_error_2(): - def generator_np(): - for i in range(64): - yield ({i},) - - with pytest.raises(RuntimeError) as info: - data1 = ds.GeneratorDataset(generator_np, ["data"]) - for _ in data1: - pass - assert "Generator should return a tuple of numpy arrays" in str(info.value) - - -def test_case_error_3(): - with pytest.raises(ValueError) as info: - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"]) - data1 = data1.map(input_columns=["label"], output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)), - num_parallel_workers=2) - - for _ in data1: - pass - assert "When (len(input_columns) != len(output_columns)), columns_order must be specified." in str(info.value) - - -def test_case_error_4(): - with pytest.raises(RuntimeError) as info: - # apply dataset operations - data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"]) - data1 = data1.map(input_columns=["label"], operations=(lambda x: (x, x * 5)), - num_parallel_workers=2) - - for _ in data1: - pass - assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value) - - -def test_sequential_sampler(): - source = [(np.array([x]),) for x in range(64)] - ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler()) - i = 0 - for data in ds1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(data["data"], golden) - i = i + 1 - - -def test_random_sampler(): - source = [(np.array([x]),) for x in range(64)] - ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True) - for _ in ds1.create_dict_iterator(): # each data is a dictionary - pass - - -def test_distributed_sampler(): - source = [(np.array([x]),) for x in range(64)] - for sid in range(8): - ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid) - i = sid - for data in ds1.create_dict_iterator(): # each data is a dictionary - golden = np.array([i]) - assert np.array_equal(data["data"], golden) - i = i + 8 - - -def test_num_samples(): - source = [(np.array([x]),) for x in range(64)] - num_samples = 32 - ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples)) - ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples) - ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples) - - count = 0 - for _ in ds1.create_dict_iterator(): - count = count + 1 - assert count == num_samples - - count = 0 - for _ in ds2.create_dict_iterator(): - count = count + 1 - assert count == num_samples - - count = 0 - for _ in ds3.create_dict_iterator(): - count = count + 1 - assert count == num_samples - - -def test_num_samples_underflow(): - source = [(np.array([x]),) for x in range(64)] - num_samples = 256 - ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples) - ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples) - - count = 0 - for _ in ds2.create_dict_iterator(): - count = count + 1 - assert count == 64 - - count = 0 - for _ in ds3.create_dict_iterator(): - count = count + 1 - assert count == 64 - - -def type_tester_with_type_check_2c_schema(t, c): - logger.info("Test with Type {}".format(t.__name__)) - - schema = ds.Schema() - schema.add_column("data0", c[0]) - schema.add_column("data1", c[1]) - - # apply dataset operations - data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema) - - data1 = data1.batch(4) - - i = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) - assert np.array_equal(item["data0"], golden) - i = i + 4 - - -def test_schema(): - """ - Test 2 column Generator on different data type with type check with schema input - """ - logger.info("Test 2 column Generator on all data types with type check") - - np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, - np.float64] - de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32, - mstype.uint64, mstype.float32, mstype.float64] - - for i, _ in enumerate(np_types): - type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]]) - - -def manual_test_keyborad_interrupt(): - """ - Test keyborad_interrupt - """ - logger.info("Test 1D Generator MP : 0 - 63") - - class MyDS(): - def __getitem__(self, item): - while True: - pass - - def __len__(self): - return 1024 - - ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2) - for _ in ds1.create_dict_iterator(): # each data is a dictionary - pass - - -if __name__ == "__main__": - test_case_0() - test_case_1() - test_case_2() - test_case_3() - test_case_4() - test_case_5() - test_case_6() - test_case_7() - test_case_8() - test_case_9() - test_case_10() - test_case_11() - test_case_12() - test_case_13() - test_case_14() - test_case_15() - test_case_16() - test_case_17() - test_case_error_1() - test_case_error_2() - test_case_error_3() - test_case_error_4() - test_sequential_sampler() - test_distributed_sampler() - test_random_sampler() - test_num_samples() - test_num_samples_underflow() - test_schema() diff --git a/tests/ut/python/dataset/test_get_size.py b/tests/ut/python/dataset/test_get_size.py index ba4162788c..1dce312a32 100644 --- a/tests/ut/python/dataset/test_get_size.py +++ b/tests/ut/python/dataset/test_get_size.py @@ -41,18 +41,18 @@ def test_case1(): assert data.get_batch_size() == 2 assert data.get_repeat_count() == 1 data = data.repeat(10) - assert data.get_dataset_size() == 6 + assert data.get_dataset_size() == 60 assert data.get_batch_size() == 2 assert data.get_repeat_count() == 10 data = data.project(["new_column"]) - assert data.get_dataset_size() == 6 + assert data.get_dataset_size() == 60 assert data.get_batch_size() == 2 assert data.get_repeat_count() == 10 data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) data1 = data.zip(data2) - assert data1.get_dataset_size() == 6 + assert data1.get_dataset_size() == 60 def test_case2(): @@ -65,14 +65,14 @@ def test_case2(): data = data.rename("col_sint64", "new_column") assert data.get_dataset_size() == 3 data = data.repeat(10) - assert data.get_dataset_size() == 3 + assert data.get_dataset_size() == 30 data = data.project(["new_column"]) - assert data.get_dataset_size() == 3 + assert data.get_dataset_size() == 30 data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10) data1 = data.zip(data2) - assert data1.get_dataset_size() == 3 + assert data1.get_dataset_size() == 30 def test_case3(): @@ -94,11 +94,11 @@ def test_case4(): data2 = data2.shuffle(100) assert data2.get_dataset_size() == 6 data2 = data2.repeat(3) - assert data2.get_dataset_size() == 6 + assert data2.get_dataset_size() == 18 data3 = ds.zip((data1, data2)) - assert data3.get_dataset_size() == 6 + assert data3.get_dataset_size() == 18 def test_case5(): diff --git a/tests/ut/python/dataset/test_invert.py b/tests/ut/python/dataset/test_invert.py index f366553c6e..4f70c5a7ee 100644 --- a/tests/ut/python/dataset/test_invert.py +++ b/tests/ut/python/dataset/test_invert.py @@ -19,18 +19,20 @@ import numpy as np import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.py_transforms as F +import mindspore.dataset.transforms.vision.c_transforms as C from mindspore import log as logger -from util import visualize_list, save_and_check_md5 +from util import visualize_list, save_and_check_md5, diff_mse DATA_DIR = "../data/dataset/testImageNetData/train/" GENERATE_GOLDEN = False -def test_invert(plot=False): + +def test_invert_py(plot=False): """ - Test Invert + Test Invert python op """ - logger.info("Test Invert") + logger.info("Test Invert Python op") # Original Images ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) @@ -52,7 +54,7 @@ def test_invert(plot=False): np.transpose(image, (0, 2, 3, 1)), axis=0) - # Color Inverted Images + # Color Inverted Images ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) transforms_invert = F.ComposeOp([F.Decode(), @@ -83,11 +85,143 @@ def test_invert(plot=False): visualize_list(images_original, images_invert) -def test_invert_md5(): +def test_invert_c(plot=False): + """ + Test Invert Cpp op + """ + logger.info("Test Invert cpp op") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = [C.Decode(), C.Resize(size=[224, 224])] + + ds_original = ds.map(input_columns="image", + operations=transforms_original) + + ds_original = ds_original.batch(512) + + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, + image, + axis=0) + + # Invert Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transform_invert = [C.Decode(), C.Resize(size=[224, 224]), + C.Invert()] + + ds_invert = ds.map(input_columns="image", + operations=transform_invert) + + ds_invert = ds_invert.batch(512) + + for idx, (image, _) in enumerate(ds_invert): + if idx == 0: + images_invert = image + else: + images_invert = np.append(images_invert, + image, + axis=0) + if plot: + visualize_list(images_original, images_invert) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_invert[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def test_invert_py_c(plot=False): + """ + Test Invert Cpp op and python op + """ + logger.info("Test Invert cpp and python op") + + # Invert Images in cpp + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), C.Resize((224, 224))]) + + ds_c_invert = ds.map(input_columns="image", + operations=C.Invert()) + + ds_c_invert = ds_c_invert.batch(512) + + for idx, (image, _) in enumerate(ds_c_invert): + if idx == 0: + images_c_invert = image + else: + images_c_invert = np.append(images_c_invert, + image, + axis=0) + + # invert images in python + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), C.Resize((224, 224))]) + + transforms_p_invert = F.ComposeOp([lambda img: img.astype(np.uint8), + F.ToPIL(), + F.Invert(), + np.array]) + + ds_p_invert = ds.map(input_columns="image", + operations=transforms_p_invert()) + + ds_p_invert = ds_p_invert.batch(512) + + for idx, (image, _) in enumerate(ds_p_invert): + if idx == 0: + images_p_invert = image + else: + images_p_invert = np.append(images_p_invert, + image, + axis=0) + + num_samples = images_c_invert.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_p_invert[i], images_c_invert[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize_list(images_c_invert, images_p_invert, visualize_mode=2) + + +def test_invert_one_channel(): """ - Test Invert with md5 check + Test Invert cpp op with one channel image + """ + logger.info("Test Invert C Op With One Channel Images") + + c_op = C.Invert() + + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + + ds.map(input_columns="image", + operations=c_op) + + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "The shape" in str(e) + + +def test_invert_md5_py(): """ - logger.info("Test Invert with md5 check") + Test Invert python op with md5 check + """ + logger.info("Test Invert python op with md5 check") # Generate dataset ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) @@ -98,10 +232,34 @@ def test_invert_md5(): data = ds.map(input_columns="image", operations=transforms_invert()) # Compare with expected md5 from images - filename = "invert_01_result.npz" + filename = "invert_01_result_py.npz" + save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) + + +def test_invert_md5_c(): + """ + Test Invert cpp op with md5 check + """ + logger.info("Test Invert cpp op with md5 check") + + # Generate dataset + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_invert = [C.Decode(), + C.Resize(size=[224, 224]), + C.Invert(), + F.ToTensor()] + + data = ds.map(input_columns="image", operations=transforms_invert) + # Compare with expected md5 from images + filename = "invert_01_result_c.npz" save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN) if __name__ == "__main__": - test_invert(plot=True) - test_invert_md5() + test_invert_py(plot=False) + test_invert_c(plot=False) + test_invert_py_c(plot=False) + test_invert_one_channel() + test_invert_md5_py() + test_invert_md5_c() diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index af5a66e89e..70da93a0cc 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -33,7 +33,7 @@ def check(project_columns): assert all([np.array_equal(d1, d2) for d1, d2 in zip(data_actual, data_expected)]) -def test_case_iterator(): +def test_iterator_create_tuple(): """ Test creating tuple iterator """ @@ -73,7 +73,7 @@ def test_iterator_weak_ref(): _cleanup() with pytest.raises(AttributeError) as info: - itr2.get_next() + itr2.__next__() assert "object has no attribute 'depipeline'" in str(info.value) del itr1 @@ -95,7 +95,9 @@ class MyDict(dict): def test_tree_copy(): - # Testing copying the tree with a pyfunc that cannot be pickled + """ + Testing copying the tree with a pyfunc that cannot be pickled + """ data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS) data1 = data.map(operations=[MyDict()]) @@ -110,4 +112,6 @@ def test_tree_copy(): if __name__ == '__main__': + test_iterator_create_tuple() + test_iterator_weak_ref() test_tree_copy() diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index 7d613d414f..465475c2a2 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -46,58 +46,71 @@ def add_and_remove_cv_file(): """add/remove cv file""" paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_cv_data" + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) - cv_schema_json = {"id": {"type": "int32"}, - "file_name": {"type": "string"}, - "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "img_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - yield "yield_cv_data" - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) - @pytest.fixture def add_and_remove_nlp_file(): """add/remove nlp file""" paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(NLP_FILE_NAME, FILES_NUM) + data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)] + nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"}, + "rating": {"type": "float32"}, + "input_ids": {"type": "int64", + "shape": [-1]}, + "input_mask": {"type": "int64", + "shape": [1, -1]}, + "segment_ids": {"type": "int64", + "shape": [2, -1]} + } + writer.set_header_size(1 << 14) + writer.set_page_size(1 << 15) + writer.add_schema(nlp_schema_json, "nlp_schema") + writer.add_index(["id", "rating"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_nlp_data" + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(NLP_FILE_NAME, FILES_NUM) - data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)] - nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"}, - "rating": {"type": "float32"}, - "input_ids": {"type": "int64", - "shape": [-1]}, - "input_mask": {"type": "int64", - "shape": [1, -1]}, - "segment_ids": {"type": "int64", - "shape": [2, -1]} - } - writer.set_header_size(1 << 14) - writer.set_page_size(1 << 15) - writer.add_schema(nlp_schema_json, "nlp_schema") - writer.add_index(["id", "rating"]) - writer.write_raw_data(data) - writer.commit() - yield "yield_nlp_data" - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) @pytest.fixture @@ -105,44 +118,51 @@ def add_and_remove_nlp_compress_file(): """add/remove nlp file""" paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(NLP_FILE_NAME, FILES_NUM) + data = [] + for row_id in range(16): + data.append({ + "label": row_id, + "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, + 255, 256, -32768, 32767, -32769, 32768, -2147483648, + 2147483647], dtype=np.int32), [-1]), + "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, + 256, -32768, 32767, -32769, 32768, + -2147483648, 2147483647, -2147483649, 2147483649, + -922337036854775808, 9223372036854775807]), [1, -1]), + "array_c": str.encode("nlp data"), + "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) + }) + nlp_schema_json = {"label": {"type": "int32"}, + "array_a": {"type": "int32", + "shape": [-1]}, + "array_b": {"type": "int64", + "shape": [1, -1]}, + "array_c": {"type": "bytes"}, + "array_d": {"type": "int64", + "shape": [2, -1]} + } + writer.set_header_size(1 << 14) + writer.set_page_size(1 << 15) + writer.add_schema(nlp_schema_json, "nlp_schema") + writer.write_raw_data(data) + writer.commit() + yield "yield_nlp_data" + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(NLP_FILE_NAME, FILES_NUM) - data = [] - for row_id in range(16): - data.append({ - "label": row_id, - "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, - 255, 256, -32768, 32767, -32769, 32768, -2147483648, - 2147483647], dtype=np.int32), [-1]), - "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, - 256, -32768, 32767, -32769, 32768, - -2147483648, 2147483647, -2147483649, 2147483649, - -922337036854775808, 9223372036854775807]), [1, -1]), - "array_c": str.encode("nlp data"), - "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) - }) - nlp_schema_json = {"label": {"type": "int32"}, - "array_a": {"type": "int32", - "shape": [-1]}, - "array_b": {"type": "int64", - "shape": [1, -1]}, - "array_c": {"type": "bytes"}, - "array_d": {"type": "int64", - "shape": [2, -1]} - } - writer.set_header_size(1 << 14) - writer.set_page_size(1 << 15) - writer.add_schema(nlp_schema_json, "nlp_schema") - writer.write_raw_data(data) - writer.commit() - yield "yield_nlp_data" - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) def test_nlp_compress_data(add_and_remove_nlp_compress_file): @@ -199,22 +219,29 @@ def test_cv_minddataset_writer_tutorial(): """tutorial for cv dataset writer.""" paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) - cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "img_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): @@ -238,6 +265,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): assert partitions(5) == 2 assert partitions(9) == 2 +def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file): + """tutorial for cv minddataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + + def partitions(num_shards): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, num_samples=1) + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- partition : {} ------------------------".format(partition_id)) + logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) + num_iter += 1 + return num_iter + + assert partitions(4) == 1 + assert partitions(5) == 1 + assert partitions(9) == 1 + +def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file): + """tutorial for cv minddataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + + def partitions(num_shards): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, num_samples=2) + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- partition : {} ------------------------".format(partition_id)) + logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) + num_iter += 1 + return num_iter + + assert partitions(4) == 2 + assert partitions(5) == 2 + assert partitions(9) == 2 + +def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file): + """tutorial for cv minddataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + + def partitions(num_shards): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, num_samples=3) + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- partition : {} ------------------------".format(partition_id)) + logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) + num_iter += 1 + return num_iter + + assert partitions(4) == 3 + assert partitions(5) == 2 + assert partitions(9) == 2 + def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file): """tutorial for cv minddataset.""" @@ -498,49 +591,6 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file): assert num_iter == 18 -def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): - """tutorial for cv minddataset.""" - columns_list = ["data", "label"] - num_readers = 4 - data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, block_reader=True) - assert data_set.get_dataset_size() == 10 - repeat_num = 2 - data_set = data_set.repeat(repeat_num) - num_iter = 0 - for item in data_set.create_dict_iterator(): - logger.info( - "-------------- block reader repeat tow {} -----------------".format(num_iter)) - logger.info( - "-------------- item[label]: {} ----------------------------".format(item["label"])) - logger.info( - "-------------- item[data]: {} -----------------------------".format(item["data"])) - num_iter += 1 - assert num_iter == 20 - - -def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file): - """tutorial for cv minddataset.""" - columns_list = ["id", "data", "label"] - num_readers = 4 - data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, - block_reader=True) - assert data_set.get_dataset_size() == 10 - repeat_num = 2 - data_set = data_set.repeat(repeat_num) - num_iter = 0 - for item in data_set.create_dict_iterator(): - logger.info( - "-------------- block reader repeat tow {} -----------------".format(num_iter)) - logger.info( - "-------------- item[id]: {} ----------------------------".format(item["id"])) - logger.info( - "-------------- item[label]: {} ----------------------------".format(item["label"])) - logger.info( - "-------------- item[data]: {} -----------------------------".format(item["data"])) - num_iter += 1 - assert num_iter == 20 - - def test_cv_minddataset_reader_file_list(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] @@ -588,106 +638,124 @@ def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file): def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file): """tutorial for cv minderdataset.""" - if os.path.exists(CV1_FILE_NAME): - os.remove(CV1_FILE_NAME) - if os.path.exists("{}.db".format(CV1_FILE_NAME)): - os.remove("{}.db".format(CV1_FILE_NAME)) - if os.path.exists(CV2_FILE_NAME): - os.remove(CV2_FILE_NAME) - if os.path.exists("{}.db".format(CV2_FILE_NAME)): - os.remove("{}.db".format(CV2_FILE_NAME)) - writer = FileWriter(CV1_FILE_NAME, 1) - data = get_data(CV_DIR_NAME) - cv_schema_json = {"id": {"type": "int32"}, - "file_name": {"type": "string"}, - "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "CV1_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - - writer = FileWriter(CV2_FILE_NAME, 1) - data = get_data(CV_DIR_NAME) - cv_schema_json = {"id": {"type": "int32"}, - "file_name": {"type": "string"}, - "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "CV2_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - columns_list = ["data", "file_name", "label"] - num_readers = 4 - data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], - columns_list, num_readers) - assert data_set.get_dataset_size() == 30 - num_iter = 0 - for item in data_set.create_dict_iterator(): - logger.info( - "-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info( - "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info( - "-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info( - "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info( - "-------------- item[label]: {} ----------------------------".format(item["label"])) - num_iter += 1 - assert num_iter == 30 - if os.path.exists(CV1_FILE_NAME): - os.remove(CV1_FILE_NAME) - if os.path.exists("{}.db".format(CV1_FILE_NAME)): - os.remove("{}.db".format(CV1_FILE_NAME)) - if os.path.exists(CV2_FILE_NAME): - os.remove(CV2_FILE_NAME) - if os.path.exists("{}.db".format(CV2_FILE_NAME)): - os.remove("{}.db".format(CV2_FILE_NAME)) - + try: + if os.path.exists(CV1_FILE_NAME): + os.remove(CV1_FILE_NAME) + if os.path.exists("{}.db".format(CV1_FILE_NAME)): + os.remove("{}.db".format(CV1_FILE_NAME)) + if os.path.exists(CV2_FILE_NAME): + os.remove(CV2_FILE_NAME) + if os.path.exists("{}.db".format(CV2_FILE_NAME)): + os.remove("{}.db".format(CV2_FILE_NAME)) + writer = FileWriter(CV1_FILE_NAME, 1) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "CV1_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + + writer = FileWriter(CV2_FILE_NAME, 1) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "CV2_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + columns_list = ["data", "file_name", "label"] + num_readers = 4 + data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], + columns_list, num_readers) + assert data_set.get_dataset_size() == 30 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter == 30 + except Exception as error: + if os.path.exists(CV1_FILE_NAME): + os.remove(CV1_FILE_NAME) + if os.path.exists("{}.db".format(CV1_FILE_NAME)): + os.remove("{}.db".format(CV1_FILE_NAME)) + if os.path.exists(CV2_FILE_NAME): + os.remove(CV2_FILE_NAME) + if os.path.exists("{}.db".format(CV2_FILE_NAME)): + os.remove("{}.db".format(CV2_FILE_NAME)) + raise error + else: + if os.path.exists(CV1_FILE_NAME): + os.remove(CV1_FILE_NAME) + if os.path.exists("{}.db".format(CV1_FILE_NAME)): + os.remove("{}.db".format(CV1_FILE_NAME)) + if os.path.exists(CV2_FILE_NAME): + os.remove(CV2_FILE_NAME) + if os.path.exists("{}.db".format(CV2_FILE_NAME)): + os.remove("{}.db".format(CV2_FILE_NAME)) def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(CV1_FILE_NAME, FILES_NUM) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "CV1_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + + columns_list = ["data", "file_name", "label"] + num_readers = 4 + data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + + [CV1_FILE_NAME + str(x) for x in range(2, 4)], + columns_list, num_readers) + assert data_set.get_dataset_size() < 20 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info( + "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info( + "-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info( + "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter < 20 + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(CV1_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) - cv_schema_json = {"id": {"type": "int32"}, - "file_name": {"type": "string"}, - "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "CV1_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - - columns_list = ["data", "file_name", "label"] - num_readers = 4 - data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)], - columns_list, num_readers) - assert data_set.get_dataset_size() < 20 - num_iter = 0 - for item in data_set.create_dict_iterator(): - logger.info( - "-------------- cv reader basic: {} ------------------------".format(num_iter)) - logger.info( - "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) - logger.info( - "-------------- item[data]: {} -----------------------------".format(item["data"])) - logger.info( - "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) - logger.info( - "-------------- item[label]: {} ----------------------------".format(item["label"])) - num_iter += 1 - assert num_iter < 20 - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) - def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): """tutorial for cv minderdataset.""" @@ -1020,770 +1088,870 @@ def inputs(vectors, maxlen=50): def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): mindrecord_file_name = "test.mindrecord" - if os.path.exists("{}".format(mindrecord_file_name)): + try: + if os.path.exists("{}".format(mindrecord_file_name)): + os.remove("{}".format(mindrecord_file_name)) + if os.path.exists("{}.db".format(mindrecord_file_name)): + os.remove("{}.db".format(mindrecord_file_name)) + data = [{"file_name": "001.jpg", "label": 4, + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8'), + "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, + {"file_name": "002.jpg", "label": 5, + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8'), + "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, + {"file_name": "003.jpg", "label": 6, + "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8'), + "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, + {"file_name": "004.jpg", "label": 7, + "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8'), + "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, + {"file_name": "005.jpg", "label": 8, + "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8'), + "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, + {"file_name": "006.jpg", "label": 9, + "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8'), + "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} + ] + + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "source_sos_ids": {"type": "int64", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "image3": {"type": "bytes"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]}, + "label": {"type": "int32"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # change data value to list + data_value_to_list = [] + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(item["file_name"], dtype='S') + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + new_data['source_sos_ids'] = item["source_sos_ids"] + new_data['source_sos_mask'] = item["source_sos_mask"] + new_data['target_sos_ids'] = item["target_sos_ids"] + new_data['target_sos_mask'] = item["target_sos_mask"] + new_data['target_eos_ids'] = item["target_eos_ids"] + new_data['target_eos_mask'] = item["target_eos_mask"] + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 13 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["source_sos_ids", + "source_sos_mask", "target_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data[num_iter][field]).all() + else: + assert item[field] == data[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 1 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 4 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 3 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_sos_ids", + "image4", "source_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 3 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_sos_ids", "image5", + "image4", "image3", "source_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 1 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_mask", "image5", + "image2", "source_sos_mask", "label"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", + "source_sos_mask", "image2", "image4", "image3", + "source_sos_ids", "image5", "file_name"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 11 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + except Exception as error: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + raise error + else: os.remove("{}".format(mindrecord_file_name)) - if os.path.exists("{}.db".format(mindrecord_file_name)): os.remove("{}.db".format(mindrecord_file_name)) - data = [{"file_name": "001.jpg", "label": 4, - "image1": bytes("image1 bytes abc", encoding='UTF-8'), - "image2": bytes("image1 bytes def", encoding='UTF-8'), - "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "image3": bytes("image1 bytes ghi", encoding='UTF-8'), - "image4": bytes("image1 bytes jkl", encoding='UTF-8'), - "image5": bytes("image1 bytes mno", encoding='UTF-8'), - "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, - {"file_name": "002.jpg", "label": 5, - "image1": bytes("image2 bytes abc", encoding='UTF-8'), - "image2": bytes("image2 bytes def", encoding='UTF-8'), - "image3": bytes("image2 bytes ghi", encoding='UTF-8'), - "image4": bytes("image2 bytes jkl", encoding='UTF-8'), - "image5": bytes("image2 bytes mno", encoding='UTF-8'), - "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, - {"file_name": "003.jpg", "label": 6, - "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "image1": bytes("image3 bytes abc", encoding='UTF-8'), - "image2": bytes("image3 bytes def", encoding='UTF-8'), - "image3": bytes("image3 bytes ghi", encoding='UTF-8'), - "image4": bytes("image3 bytes jkl", encoding='UTF-8'), - "image5": bytes("image3 bytes mno", encoding='UTF-8'), - "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, - {"file_name": "004.jpg", "label": 7, - "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "image1": bytes("image4 bytes abc", encoding='UTF-8'), - "image2": bytes("image4 bytes def", encoding='UTF-8'), - "image3": bytes("image4 bytes ghi", encoding='UTF-8'), - "image4": bytes("image4 bytes jkl", encoding='UTF-8'), - "image5": bytes("image4 bytes mno", encoding='UTF-8'), - "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, - {"file_name": "005.jpg", "label": 8, - "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), - "image1": bytes("image5 bytes abc", encoding='UTF-8'), - "image2": bytes("image5 bytes def", encoding='UTF-8'), - "image3": bytes("image5 bytes ghi", encoding='UTF-8'), - "image4": bytes("image5 bytes jkl", encoding='UTF-8'), - "image5": bytes("image5 bytes mno", encoding='UTF-8'), - "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, - {"file_name": "006.jpg", "label": 9, - "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), - "image1": bytes("image6 bytes abc", encoding='UTF-8'), - "image2": bytes("image6 bytes def", encoding='UTF-8'), - "image3": bytes("image6 bytes ghi", encoding='UTF-8'), - "image4": bytes("image6 bytes jkl", encoding='UTF-8'), - "image5": bytes("image6 bytes mno", encoding='UTF-8'), - "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} - ] - - writer = FileWriter(mindrecord_file_name) - schema = {"file_name": {"type": "string"}, - "image1": {"type": "bytes"}, - "image2": {"type": "bytes"}, - "source_sos_ids": {"type": "int64", "shape": [-1]}, - "source_sos_mask": {"type": "int64", "shape": [-1]}, - "image3": {"type": "bytes"}, - "image4": {"type": "bytes"}, - "image5": {"type": "bytes"}, - "target_sos_ids": {"type": "int64", "shape": [-1]}, - "target_sos_mask": {"type": "int64", "shape": [-1]}, - "target_eos_ids": {"type": "int64", "shape": [-1]}, - "target_eos_mask": {"type": "int64", "shape": [-1]}, - "label": {"type": "int32"}} - writer.add_schema(schema, "data is so cool") - writer.write_raw_data(data) - writer.commit() - - # change data value to list - data_value_to_list = [] - for item in data: - new_data = {} - new_data['file_name'] = np.asarray(item["file_name"], dtype='S') - new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) - new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) - new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) - new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) - new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) - new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) - new_data['source_sos_ids'] = item["source_sos_ids"] - new_data['source_sos_mask'] = item["source_sos_mask"] - new_data['target_sos_ids'] = item["target_sos_ids"] - new_data['target_sos_mask'] = item["target_sos_mask"] - new_data['target_eos_ids'] = item["target_eos_ids"] - new_data['target_eos_mask'] = item["target_eos_mask"] - data_value_to_list.append(new_data) - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 13 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["source_sos_ids", - "source_sos_mask", "target_sos_ids"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 3 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == data[num_iter][field]).all() - else: - assert item[field] == data[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 1 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=[ - "image2", "source_sos_mask", "image3", "target_sos_ids"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 4 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 3 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_sos_ids", - "image4", "source_sos_ids"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 3 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 3 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_sos_ids", "image5", - "image4", "image3", "source_sos_ids"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 5 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 1 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_eos_mask", "image5", - "image2", "source_sos_mask", "label"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 5 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", "source_sos_mask", - "image2", "image4", "image3", "source_sos_ids", "image5", "file_name"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 11 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - os.remove("{}".format(mindrecord_file_name)) - os.remove("{}.db".format(mindrecord_file_name)) def test_write_with_multi_bytes_and_MindDataset(): mindrecord_file_name = "test.mindrecord" - data = [{"file_name": "001.jpg", "label": 43, - "image1": bytes("image1 bytes abc", encoding='UTF-8'), - "image2": bytes("image1 bytes def", encoding='UTF-8'), - "image3": bytes("image1 bytes ghi", encoding='UTF-8'), - "image4": bytes("image1 bytes jkl", encoding='UTF-8'), - "image5": bytes("image1 bytes mno", encoding='UTF-8')}, - {"file_name": "002.jpg", "label": 91, - "image1": bytes("image2 bytes abc", encoding='UTF-8'), - "image2": bytes("image2 bytes def", encoding='UTF-8'), - "image3": bytes("image2 bytes ghi", encoding='UTF-8'), - "image4": bytes("image2 bytes jkl", encoding='UTF-8'), - "image5": bytes("image2 bytes mno", encoding='UTF-8')}, - {"file_name": "003.jpg", "label": 61, - "image1": bytes("image3 bytes abc", encoding='UTF-8'), - "image2": bytes("image3 bytes def", encoding='UTF-8'), - "image3": bytes("image3 bytes ghi", encoding='UTF-8'), - "image4": bytes("image3 bytes jkl", encoding='UTF-8'), - "image5": bytes("image3 bytes mno", encoding='UTF-8')}, - {"file_name": "004.jpg", "label": 29, - "image1": bytes("image4 bytes abc", encoding='UTF-8'), - "image2": bytes("image4 bytes def", encoding='UTF-8'), - "image3": bytes("image4 bytes ghi", encoding='UTF-8'), - "image4": bytes("image4 bytes jkl", encoding='UTF-8'), - "image5": bytes("image4 bytes mno", encoding='UTF-8')}, - {"file_name": "005.jpg", "label": 78, - "image1": bytes("image5 bytes abc", encoding='UTF-8'), - "image2": bytes("image5 bytes def", encoding='UTF-8'), - "image3": bytes("image5 bytes ghi", encoding='UTF-8'), - "image4": bytes("image5 bytes jkl", encoding='UTF-8'), - "image5": bytes("image5 bytes mno", encoding='UTF-8')}, - {"file_name": "006.jpg", "label": 37, - "image1": bytes("image6 bytes abc", encoding='UTF-8'), - "image2": bytes("image6 bytes def", encoding='UTF-8'), - "image3": bytes("image6 bytes ghi", encoding='UTF-8'), - "image4": bytes("image6 bytes jkl", encoding='UTF-8'), - "image5": bytes("image6 bytes mno", encoding='UTF-8')} - ] - writer = FileWriter(mindrecord_file_name) - schema = {"file_name": {"type": "string"}, - "image1": {"type": "bytes"}, - "image2": {"type": "bytes"}, - "image3": {"type": "bytes"}, - "label": {"type": "int32"}, - "image4": {"type": "bytes"}, - "image5": {"type": "bytes"}} - writer.add_schema(schema, "data is so cool") - writer.write_raw_data(data) - writer.commit() - - # change data value to list - data_value_to_list = [] - for item in data: - new_data = {} - new_data['file_name'] = np.asarray(item["file_name"], dtype='S') - new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) - new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) - new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) - new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) - new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) - new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) - data_value_to_list.append(new_data) - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 7 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image1", "image2", "image5"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 3 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image2", "image4"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 2 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image5", "image2"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 2 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image5", "image2", "label"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 3 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["image4", "image5", - "image2", "image3", "file_name"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 5 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - os.remove("{}".format(mindrecord_file_name)) - os.remove("{}.db".format(mindrecord_file_name)) - + try: + data = [{"file_name": "001.jpg", "label": 43, + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "label": {"type": "int32"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # change data value to list + data_value_to_list = [] + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(item["file_name"], dtype='S') + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 7 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image1", "image2", "image5"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image2", "image4"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image5", "image2"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image5", "image2", "label"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image4", "image5", + "image2", "image3", "file_name"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + except Exception as error: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + raise error + else: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) def test_write_with_multi_array_and_MindDataset(): mindrecord_file_name = "test.mindrecord" - data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64), - "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), - "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, - {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64), - "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), - "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, - {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64), - "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), - "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, - {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64), - "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), - "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, - {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64), - "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), - "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, - {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), - "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), - "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64), - "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), - "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), - "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), - "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), - "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} - ] - writer = FileWriter(mindrecord_file_name) - schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, - "source_sos_mask": {"type": "int64", "shape": [-1]}, - "source_eos_ids": {"type": "int64", "shape": [-1]}, - "source_eos_mask": {"type": "int64", "shape": [-1]}, - "target_sos_ids": {"type": "int64", "shape": [-1]}, - "target_sos_mask": {"type": "int64", "shape": [-1]}, - "target_eos_ids": {"type": "int64", "shape": [-1]}, - "target_eos_mask": {"type": "int64", "shape": [-1]}} - writer.add_schema(schema, "data is so cool") - writer.write_raw_data(data) - writer.commit() - - # change data value to list - do none - data_value_to_list = [] - for item in data: - new_data = {} - new_data['source_sos_ids'] = item["source_sos_ids"] - new_data['source_sos_mask'] = item["source_sos_mask"] - new_data['source_eos_ids'] = item["source_eos_ids"] - new_data['source_eos_mask'] = item["source_eos_mask"] - new_data['target_sos_ids'] = item["target_sos_ids"] - new_data['target_sos_mask'] = item["target_sos_mask"] - new_data['target_eos_ids'] = item["target_eos_ids"] - new_data['target_eos_mask'] = item["target_eos_mask"] - data_value_to_list.append(new_data) - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 8 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["source_eos_ids", "source_eos_mask", - "target_sos_ids", "target_sos_mask", - "target_eos_ids", "target_eos_mask"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 6 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["source_sos_ids", - "target_sos_ids", - "target_eos_mask"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 3 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_eos_mask", - "source_eos_mask", - "source_sos_mask"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 3 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_eos_ids"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 1 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - num_readers = 1 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["target_eos_mask", "target_eos_ids", - "target_sos_mask", "target_sos_ids", - "source_eos_mask", "source_eos_ids", - "source_sos_mask", "source_sos_ids"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 6 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 8 - for field in item: - if isinstance(item[field], np.ndarray): - assert (item[field] == - data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 6 - - os.remove("{}".format(mindrecord_file_name)) - os.remove("{}.db".format(mindrecord_file_name)) - -def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(): - mindrecord_file_name = "test.mindrecord" - data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32), - "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471, - 123414314.2141243, 87.1212122], dtype=np.float64), - "float32": 3456.12345, - "float64": 1987654321.123456785, - "int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32), - "int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64), - "int32": 3456, - "int64": 947654321123}, - {"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32), - "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471, - 123414314.2141243, 87.1212122], dtype=np.float64), - "float32": 3456.12445, - "float64": 1987654321.123456786, - "int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32), - "int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64), - "int32": 3466, - "int64": 957654321123}, - {"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32), - "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471, - 123414314.2141243, 87.1212122], dtype=np.float64), - "float32": 3456.12545, - "float64": 1987654321.123456787, - "int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32), - "int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64), - "int32": 3476, - "int64": 967654321123}, - {"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32), - "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471, - 123414314.2141243, 87.1212122], dtype=np.float64), - "float32": 3456.12645, - "float64": 1987654321.123456788, - "int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32), - "int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64), - "int32": 3486, - "int64": 977654321123}, - {"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), - "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, - 123414314.2141243, 87.1212122], dtype=np.float64), - "float32": 3456.12745, - "float64": 1987654321.123456789, - "int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32), - "int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64), - "int32": 3496, - "int64": 987654321123}, - ] - writer = FileWriter(mindrecord_file_name) - schema = {"float32_array": {"type": "float32", "shape": [-1]}, - "float64_array": {"type": "float64", "shape": [-1]}, - "float32": {"type": "float32"}, - "float64": {"type": "float64"}, - "int32_array": {"type": "int32", "shape": [-1]}, - "int64_array": {"type": "int64", "shape": [-1]}, - "int32": {"type": "int32"}, - "int64": {"type": "int64"}} - writer.add_schema(schema, "data is so cool") - writer.write_raw_data(data) - writer.commit() - - # change data value to list - do none - data_value_to_list = [] - for item in data: - new_data = {} - new_data['float32_array'] = item["float32_array"] - new_data['float64_array'] = item["float64_array"] - new_data['float32'] = item["float32"] - new_data['float64'] = item["float64"] - new_data['int32_array'] = item["int32_array"] - new_data['int64_array'] = item["int64_array"] - new_data['int32'] = item["int32"] - new_data['int64'] = item["int64"] - data_value_to_list.append(new_data) - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 5 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 8 - for field in item: - if isinstance(item[field], np.ndarray): - if item[field].dtype == np.float32: + try: + data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "source_eos_ids": {"type": "int64", "shape": [-1]}, + "source_eos_mask": {"type": "int64", "shape": [-1]}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # change data value to list - do none + data_value_to_list = [] + for item in data: + new_data = {} + new_data['source_sos_ids'] = item["source_sos_ids"] + new_data['source_sos_mask'] = item["source_sos_mask"] + new_data['source_eos_ids'] = item["source_eos_ids"] + new_data['source_eos_mask'] = item["source_eos_mask"] + new_data['target_sos_ids'] = item["target_sos_ids"] + new_data['target_sos_mask'] = item["target_sos_mask"] + new_data['target_eos_ids'] = item["target_eos_ids"] + new_data['target_eos_mask'] = item["target_eos_mask"] + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 8 + for field in item: + if isinstance(item[field], np.ndarray): assert (item[field] == - np.array(data_value_to_list[num_iter][field], np.float32)).all() + data_value_to_list[num_iter][field]).all() else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["source_eos_ids", "source_eos_mask", + "target_sos_ids", "target_sos_mask", + "target_eos_ids", "target_eos_mask"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 6 + for field in item: + if isinstance(item[field], np.ndarray): assert (item[field] == data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 5 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["float32", "int32"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 5 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 2 - for field in item: - if isinstance(item[field], np.ndarray): - if item[field].dtype == np.float32: + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["source_sos_ids", + "target_sos_ids", + "target_eos_mask"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): assert (item[field] == - np.array(data_value_to_list[num_iter][field], np.float32)).all() + data_value_to_list[num_iter][field]).all() else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_mask", + "source_eos_mask", + "source_sos_mask"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): assert (item[field] == data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 5 - - num_readers = 2 - data_set = ds.MindDataset(dataset_file=mindrecord_file_name, - columns_list=["float64", "int64"], - num_parallel_workers=num_readers, - shuffle=False) - assert data_set.get_dataset_size() == 5 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 2 - for field in item: - if isinstance(item[field], np.ndarray): - if item[field].dtype == np.float32: + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 1 + for field in item: + if isinstance(item[field], np.ndarray): assert (item[field] == - np.array(data_value_to_list[num_iter][field], np.float32)).all() - elif item[field].dtype == np.float64: - assert math.isclose(item[field], - np.array(data_value_to_list[num_iter][field], np.float64), rel_tol=1e-14) + data_value_to_list[num_iter][field]).all() else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 1 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_mask", "target_eos_ids", + "target_sos_mask", "target_sos_ids", + "source_eos_mask", "source_eos_ids", + "source_sos_mask", "source_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 8 + for field in item: + if isinstance(item[field], np.ndarray): assert (item[field] == data_value_to_list[num_iter][field]).all() - else: - assert item[field] == data_value_to_list[num_iter][field] - num_iter += 1 - assert num_iter == 5 + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + except Exception as error: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + raise error + else: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + +def test_numpy_generic(): + paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) + for x in range(FILES_NUM)] + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + cv_schema_json = {"label1": {"type": "int32"}, "label2": {"type": "int64"}, + "label3": {"type": "float32"}, "label4": {"type": "float64"}} + data = [] + for idx in range(10): + row = {} + row['label1'] = np.int32(idx) + row['label2'] = np.int64(idx*10) + row['label3'] = np.float32(idx+0.12345) + row['label4'] = np.float64(idx+0.12345789) + data.append(row) + writer.add_schema(cv_schema_json, "img_schema") + writer.write_raw_data(data) + writer.commit() + + num_readers = 4 + data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, shuffle=False) + assert data_set.get_dataset_size() == 10 + idx = 0 + for item in data_set.create_dict_iterator(): + assert item['label1'] == item['label1'] + assert item['label2'] == item['label2'] + assert item['label3'] == item['label3'] + assert item['label4'] == item['label4'] + idx += 1 + assert idx == 10 + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + + +def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(): + mindrecord_file_name = "test.mindrecord" + try: + data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12345, + "float64": 1987654321.123456785, + "int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32), + "int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64), + "int32": 3456, + "int64": 947654321123}, + {"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12445, + "float64": 1987654321.123456786, + "int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32), + "int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64), + "int32": 3466, + "int64": 957654321123}, + {"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12545, + "float64": 1987654321.123456787, + "int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32), + "int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64), + "int32": 3476, + "int64": 967654321123}, + {"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12645, + "float64": 1987654321.123456788, + "int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32), + "int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64), + "int32": 3486, + "int64": 977654321123}, + {"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12745, + "float64": 1987654321.123456789, + "int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32), + "int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64), + "int32": 3496, + "int64": 987654321123}, + ] + writer = FileWriter(mindrecord_file_name) + schema = {"float32_array": {"type": "float32", "shape": [-1]}, + "float64_array": {"type": "float64", "shape": [-1]}, + "float32": {"type": "float32"}, + "float64": {"type": "float64"}, + "int32_array": {"type": "int32", "shape": [-1]}, + "int64_array": {"type": "int64", "shape": [-1]}, + "int32": {"type": "int32"}, + "int64": {"type": "int64"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # change data value to list - do none + data_value_to_list = [] + for item in data: + new_data = {} + new_data['float32_array'] = item["float32_array"] + new_data['float64_array'] = item["float64_array"] + new_data['float32'] = item["float32"] + new_data['float64'] = item["float64"] + new_data['int32_array'] = item["int32_array"] + new_data['int64_array'] = item["int64_array"] + new_data['int32'] = item["int32"] + new_data['int64'] = item["int64"] + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 5 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 8 + for field in item: + if isinstance(item[field], np.ndarray): + if item[field].dtype == np.float32: + assert (item[field] == + np.array(data_value_to_list[num_iter][field], np.float32)).all() + else: + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 5 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["float32", "int32"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 5 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + if item[field].dtype == np.float32: + assert (item[field] == + np.array(data_value_to_list[num_iter][field], np.float32)).all() + else: + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 5 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["float64", "int64"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 5 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + if item[field].dtype == np.float32: + assert (item[field] == + np.array(data_value_to_list[num_iter][field], np.float32)).all() + elif item[field].dtype == np.float64: + assert math.isclose(item[field], + np.array(data_value_to_list[num_iter][field], np.float64), + rel_tol=1e-14) + else: + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 5 + except Exception as error: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + raise error + else: + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) - os.remove("{}".format(mindrecord_file_name)) - os.remove("{}.db".format(mindrecord_file_name)) +if __name__ == '__main__': + test_nlp_compress_data(add_and_remove_nlp_compress_file) + test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file) + test_cv_minddataset_writer_tutorial() + test_cv_minddataset_partition_tutorial(add_and_remove_cv_file) + test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file) + test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file) + test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file) + test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file) + test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file) + test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file) + test_cv_minddataset_dataset_size(add_and_remove_cv_file) + test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file) + test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file) + test_cv_minddataset_issue_888(add_and_remove_cv_file) + test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file) + test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file) + test_cv_minddataset_reader_file_list(add_and_remove_cv_file) + test_cv_minddataset_reader_one_partition(add_and_remove_cv_file) + test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file) + test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file) + test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file) + test_nlp_minddataset_reader_basic_tutorial(add_and_remove_cv_file) + test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file) + test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file) + test_cv_minddataset_reader_no_columns(add_and_remove_cv_file) + test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file) + test_write_with_multi_bytes_and_array_and_read_by_MindDataset() + test_write_with_multi_bytes_and_MindDataset() + test_write_with_multi_array_and_MindDataset() + test_numpy_generic() + test_write_with_float32_float64_float32_array_float64_array_and_MindDataset() diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index 0b4d0dfc8f..51621750c8 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -99,8 +99,13 @@ def test_invalid_mindrecord(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert num_iter == 0 - os.remove('dummy.mindrecord') + try: + assert num_iter == 0 + except Exception as error: + os.remove('dummy.mindrecord') + raise error + else: + os.remove('dummy.mindrecord') def test_minddataset_lack_db(): @@ -113,8 +118,13 @@ def test_minddataset_lack_db(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert num_iter == 0 - os.remove(CV_FILE_NAME) + try: + assert num_iter == 0 + except Exception as error: + os.remove(CV_FILE_NAME) + raise error + else: + os.remove(CV_FILE_NAME) def test_cv_minddataset_pk_sample_error_class_column(): @@ -189,10 +199,16 @@ def test_minddataset_invalidate_num_shards(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) + try: + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) + except Exception as error: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + raise error + else: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) - os.remove(CV_FILE_NAME) - os.remove("{}.db".format(CV_FILE_NAME)) def test_minddataset_invalidate_shard_id(): create_cv_mindrecord(1) @@ -203,9 +219,15 @@ def test_minddataset_invalidate_shard_id(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) - os.remove(CV_FILE_NAME) - os.remove("{}.db".format(CV_FILE_NAME)) + try: + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) + except Exception as error: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + raise error + else: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) def test_minddataset_shard_id_bigger_than_num_shard(): @@ -217,14 +239,65 @@ def test_minddataset_shard_id_bigger_than_num_shard(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) + try: + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) + except Exception as error: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + raise error with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) + try: + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) + except Exception as error: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + raise error + else: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) - os.remove(CV_FILE_NAME) - os.remove("{}.db".format(CV_FILE_NAME)) + +def test_cv_minddataset_partition_num_samples_equals_0(): + """tutorial for cv minddataset.""" + create_cv_mindrecord(1) + columns_list = ["data", "label"] + num_readers = 4 + + def partitions(num_shards): + for partition_id in range(num_shards): + data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, + num_shards=num_shards, + shard_id=partition_id, num_samples=0) + num_iter = 0 + for _ in data_set.create_dict_iterator(): + num_iter += 1 + with pytest.raises(Exception) as error_info: + partitions(5) + try: + assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value) + except Exception as error: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + raise error + else: + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + +if __name__ == '__main__': + test_cv_lack_json() + test_cv_lack_mindrecord() + test_invalid_mindrecord() + test_minddataset_lack_db() + test_cv_minddataset_pk_sample_error_class_column() + test_cv_minddataset_pk_sample_exclusive_shuffle() + test_cv_minddataset_reader_different_schema() + test_cv_minddataset_reader_different_page_size() + test_minddataset_invalidate_num_shards() + test_minddataset_invalidate_shard_id() + test_minddataset_shard_id_bigger_than_num_shard() + test_cv_minddataset_partition_num_samples_equals_0() diff --git a/tests/ut/python/dataset/test_minddataset_multi_images_and_ndarray.py b/tests/ut/python/dataset/test_minddataset_multi_images_and_ndarray.py index c9c9388e65..5ef3a7adcb 100644 --- a/tests/ut/python/dataset/test_minddataset_multi_images_and_ndarray.py +++ b/tests/ut/python/dataset/test_minddataset_multi_images_and_ndarray.py @@ -27,54 +27,64 @@ CV_FILE_NAME = "./complex.mindrecord" def test_cv_minddataset_reader_multi_image_and_ndarray_tutorial(): - writer = FileWriter(CV_FILE_NAME, FILES_NUM) - cv_schema_json = {"id": {"type": "int32"}, - "image_0": {"type": "bytes"}, - "image_2": {"type": "bytes"}, - "image_3": {"type": "bytes"}, - "image_4": {"type": "bytes"}, - "input_mask": {"type": "int32", "shape": [-1]}, - "segments": {"type": "float32", "shape": [2, 3]}} - writer.add_schema(cv_schema_json, "two_images_schema") - with open("../data/mindrecord/testImageNetData/images/image_00010.jpg", "rb") as file_reader: - img_data = file_reader.read() - ndarray_1 = np.array([1, 2, 3, 4, 5], np.int32) - ndarray_2 = np.array(([2, 3, 1], [7, 9, 0]), np.float32) - data = [] - for i in range(5): - item = {"id": i, "image_0": img_data, "image_2": img_data, "image_3": img_data, "image_4": img_data, - "input_mask": ndarray_1, "segments": ndarray_2} - data.append(item) - writer.write_raw_data(data) - writer.commit() - assert os.path.exists(CV_FILE_NAME) - assert os.path.exists(CV_FILE_NAME + ".db") + try: + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + cv_schema_json = {"id": {"type": "int32"}, + "image_0": {"type": "bytes"}, + "image_2": {"type": "bytes"}, + "image_3": {"type": "bytes"}, + "image_4": {"type": "bytes"}, + "input_mask": {"type": "int32", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 3]}} + writer.add_schema(cv_schema_json, "two_images_schema") + with open("../data/mindrecord/testImageNetData/images/image_00010.jpg", "rb") as file_reader: + img_data = file_reader.read() + ndarray_1 = np.array([1, 2, 3, 4, 5], np.int32) + ndarray_2 = np.array(([2, 3, 1], [7, 9, 0]), np.float32) + data = [] + for i in range(5): + item = {"id": i, "image_0": img_data, "image_2": img_data, "image_3": img_data, "image_4": img_data, + "input_mask": ndarray_1, "segments": ndarray_2} + data.append(item) + writer.write_raw_data(data) + writer.commit() + assert os.path.exists(CV_FILE_NAME) + assert os.path.exists(CV_FILE_NAME + ".db") - # tutorial for minderdataset. - columns_list = ["id", "image_0", "image_2", "image_3", "image_4", "input_mask", "segments"] - num_readers = 1 - data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) - assert data_set.get_dataset_size() == 5 - num_iter = 0 - for item in data_set.create_dict_iterator(): - assert len(item) == 7 - logger.info("item: {}".format(item)) - assert item["image_0"].dtype == np.uint8 - assert (item["image_0"] == item["image_2"]).all() - assert (item["image_3"] == item["image_4"]).all() - assert (item["image_0"] == item["image_4"]).all() - assert item["image_2"].dtype == np.uint8 - assert item["image_3"].dtype == np.uint8 - assert item["image_4"].dtype == np.uint8 - assert item["id"].dtype == np.int32 - assert item["input_mask"].shape == (5,) - assert item["input_mask"].dtype == np.int32 - assert item["segments"].shape == (2, 3) - assert item["segments"].dtype == np.float32 - num_iter += 1 - assert num_iter == 5 + # tutorial for minderdataset. + columns_list = ["id", "image_0", "image_2", "image_3", "image_4", "input_mask", "segments"] + num_readers = 1 + data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) + assert data_set.get_dataset_size() == 5 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 7 + logger.info("item: {}".format(item)) + assert item["image_0"].dtype == np.uint8 + assert (item["image_0"] == item["image_2"]).all() + assert (item["image_3"] == item["image_4"]).all() + assert (item["image_0"] == item["image_4"]).all() + assert item["image_2"].dtype == np.uint8 + assert item["image_3"].dtype == np.uint8 + assert item["image_4"].dtype == np.uint8 + assert item["id"].dtype == np.int32 + assert item["input_mask"].shape == (5,) + assert item["input_mask"].dtype == np.int32 + assert item["segments"].shape == (2, 3) + assert item["segments"].dtype == np.float32 + num_iter += 1 + assert num_iter == 5 + except Exception as error: + if os.path.exists("{}".format(CV_FILE_NAME + ".db")): + os.remove(CV_FILE_NAME + ".db") + if os.path.exists("{}".format(CV_FILE_NAME)): + os.remove(CV_FILE_NAME) + raise error + else: + if os.path.exists("{}".format(CV_FILE_NAME + ".db")): + os.remove(CV_FILE_NAME + ".db") + if os.path.exists("{}".format(CV_FILE_NAME)): + os.remove(CV_FILE_NAME) - if os.path.exists("{}".format(CV_FILE_NAME + ".db")): - os.remove(CV_FILE_NAME + ".db") - if os.path.exists("{}".format(CV_FILE_NAME)): - os.remove(CV_FILE_NAME) +if __name__ == '__main__': + test_cv_minddataset_reader_multi_image_and_ndarray_tutorial() diff --git a/tests/ut/python/dataset/test_minddataset_padded.py b/tests/ut/python/dataset/test_minddataset_padded.py index c0724e3236..a05879ab01 100644 --- a/tests/ut/python/dataset/test_minddataset_padded.py +++ b/tests/ut/python/dataset/test_minddataset_padded.py @@ -44,24 +44,31 @@ def add_and_remove_cv_file(): """add/remove cv file""" paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None - os.remove("{}.db".format(x)) if os.path.exists( - "{}.db".format(x)) else None - writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) - cv_schema_json = {"id": {"type": "int32"}, - "file_name": {"type": "string"}, - "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "img_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - yield "yield_cv_data" - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) + try: + for x in paths: + os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None + os.remove("{}.db".format(x)) if os.path.exists( + "{}.db".format(x)) else None + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_cv_data" + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) @pytest.fixture @@ -69,32 +76,39 @@ def add_and_remove_nlp_file(): """add/remove nlp file""" paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(NLP_FILE_NAME, FILES_NUM) + data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)] + nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"}, + "rating": {"type": "float32"}, + "input_ids": {"type": "int64", + "shape": [-1]}, + "input_mask": {"type": "int64", + "shape": [1, -1]}, + "segment_ids": {"type": "int64", + "shape": [2, -1]} + } + writer.set_header_size(1 << 14) + writer.set_page_size(1 << 15) + writer.add_schema(nlp_schema_json, "nlp_schema") + writer.add_index(["id", "rating"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_nlp_data" + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(NLP_FILE_NAME, FILES_NUM) - data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)] - nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"}, - "rating": {"type": "float32"}, - "input_ids": {"type": "int64", - "shape": [-1]}, - "input_mask": {"type": "int64", - "shape": [1, -1]}, - "segment_ids": {"type": "int64", - "shape": [2, -1]} - } - writer.set_header_size(1 << 14) - writer.set_page_size(1 << 15) - writer.add_schema(nlp_schema_json, "nlp_schema") - writer.add_index(["id", "rating"]) - writer.write_raw_data(data) - writer.commit() - yield "yield_nlp_data" - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file): """tutorial for cv minderdataset.""" @@ -119,7 +133,7 @@ def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file): encoding='utf8') assert item['label'] == padded_sample['label'] assert (item['data'] == np.array(list(padded_sample['data']))).all() - num_iter += 1 + num_iter += 1 assert num_padded_iter == 5 assert num_iter == 15 @@ -636,3 +650,17 @@ def inputs(vectors, maxlen=50): mask = [1] * length + [0] * (maxlen - length) segment = [0] * maxlen return input_, mask, segment + +if __name__ == '__main__': + test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_no_dividsible(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_dataset_size_no_divisible(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_no_equal_column_list(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_no_column_list(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_no_num_padded(add_and_remove_cv_file) + test_cv_minddataset_partition_padded_samples_no_padded_samples(add_and_remove_cv_file) + test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file) + test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_nlp_file) + test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_result_per_epoch(add_and_remove_nlp_file) diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 8d099f1af2..9c110c0e1f 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -34,26 +34,32 @@ def add_and_remove_cv_file(): """add/remove cv file""" paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] - for x in paths: - if os.path.exists("{}".format(x)): + try: + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + data = get_data(CV_DIR_NAME, True) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + yield "yield_cv_data" + except Exception as error: + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x)) + raise error + else: + for x in paths: os.remove("{}".format(x)) - if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) - writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME, True) - cv_schema_json = {"id": {"type": "int32"}, - "file_name": {"type": "string"}, - "label": {"type": "int32"}, - "data": {"type": "bytes"}} - writer.add_schema(cv_schema_json, "img_schema") - writer.add_index(["file_name", "label"]) - writer.write_raw_data(data) - writer.commit() - yield "yield_cv_data" - for x in paths: - os.remove("{}".format(x)) - os.remove("{}.db".format(x)) - def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): """tutorial for cv minderdataset.""" @@ -626,3 +632,24 @@ def get_data(dir_name, sampler=False): except FileNotFoundError: continue return data_list + +if __name__ == '__main__': + test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file) + test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file) + test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file) + test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file) + test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file) + test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file) + test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file) + test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file) + test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file) + test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file) + test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file) + test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file) + test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file) + test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file) + test_cv_minddataset_split_basic(add_and_remove_cv_file) + test_cv_minddataset_split_exact_percent(add_and_remove_cv_file) + test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file) + test_cv_minddataset_split_deterministic(add_and_remove_cv_file) + test_cv_minddataset_split_sharding(add_and_remove_cv_file) diff --git a/tests/ut/python/dataset/test_opt_pass.py b/tests/ut/python/dataset/test_opt_pass.py index 480bfcbeab..d89ceab73e 100644 --- a/tests/ut/python/dataset/test_opt_pass.py +++ b/tests/ut/python/dataset/test_opt_pass.py @@ -67,7 +67,7 @@ def test_shuffle(): for d1, d2 in zip(data1, data2): for t1, t2 in zip(d1, d2): - assert np.array_equal(t1, t2) + np.testing.assert_array_equal(t1, t2) ds.config.set_seed(1) DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*" @@ -77,7 +77,7 @@ def test_shuffle(): for d1, d2 in zip(data1, data2): for t1, t2 in zip(d1, d2): - assert np.array_equal(t1, t2) + np.testing.assert_array_equal(t1, t2) ds.config.set_seed(1) TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' @@ -87,7 +87,7 @@ def test_shuffle(): for d1, d2 in zip(data1, data2): for t1, t2 in zip(d1, d2): - assert np.array_equal(t1, t2) + np.testing.assert_array_equal(t1, t2) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_pad.py b/tests/ut/python/dataset/test_pad.py index a3038a4b91..d2c4e60dc5 100644 --- a/tests/ut/python/dataset/test_pad.py +++ b/tests/ut/python/dataset/test_pad.py @@ -148,8 +148,20 @@ def test_pad_md5(): filename2 = "pad_01_py_result.npz" save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN) +def test_pad_exception(): + try: + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + pad_op = c_vision.Pad(150) + data1 = data1.map(input_columns=["image"], operations=pad_op) + for _ in data1.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "Pad error: invalid image shape, only support 3 channels image" in str(e) + if __name__ == "__main__": test_pad_op() test_pad_grayscale() test_pad_md5() + test_pad_exception() diff --git a/tests/ut/python/dataset/test_pad_batch.py b/tests/ut/python/dataset/test_pad_batch.py index 314e6f4a53..cea3427604 100644 --- a/tests/ut/python/dataset/test_pad_batch.py +++ b/tests/ut/python/dataset/test_pad_batch.py @@ -63,8 +63,8 @@ def test_batch_padding_01(): data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([2, 2], -2), "col1d": ([2], -1)}) data1 = data1.repeat(2) for data in data1.create_dict_iterator(): - assert np.array_equal([[0, -1], [1, -1]], data["col1d"]) - assert np.array_equal([[[100, -2], [200, -2]], [[101, -2], [201, -2]]], data["col2d"]) + np.testing.assert_array_equal([[0, -1], [1, -1]], data["col1d"]) + np.testing.assert_array_equal([[[100, -2], [200, -2]], [[101, -2], [201, -2]]], data["col2d"]) def test_batch_padding_02(): @@ -72,8 +72,8 @@ def test_batch_padding_02(): data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], -2)}) data1 = data1.repeat(2) for data in data1.create_dict_iterator(): - assert np.array_equal([[0], [1]], data["col1d"]) - assert np.array_equal([[[100, -2]], [[101, -2]]], data["col2d"]) + np.testing.assert_array_equal([[0], [1]], data["col1d"]) + np.testing.assert_array_equal([[[100, -2]], [[101, -2]]], data["col2d"]) def test_batch_padding_03(): @@ -83,10 +83,10 @@ def test_batch_padding_03(): res = dict() for ind, data in enumerate(data1.create_dict_iterator()): res[ind] = data["col"].copy() - assert np.array_equal(res[0], [[0, -1], [0, 1]]) - assert np.array_equal(res[1], [[0, 1, 2, -1], [0, 1, 2, 3]]) - assert np.array_equal(res[2], [[0, -1], [0, 1]]) - assert np.array_equal(res[3], [[0, 1, 2, -1], [0, 1, 2, 3]]) + np.testing.assert_array_equal(res[0], [[0, -1], [0, 1]]) + np.testing.assert_array_equal(res[1], [[0, 1, 2, -1], [0, 1, 2, 3]]) + np.testing.assert_array_equal(res[2], [[0, -1], [0, 1]]) + np.testing.assert_array_equal(res[3], [[0, 1, 2, -1], [0, 1, 2, 3]]) def test_batch_padding_04(): @@ -94,8 +94,8 @@ def test_batch_padding_04(): data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={}) # pad automatically data1 = data1.repeat(2) for data in data1.create_dict_iterator(): - assert np.array_equal(data["col1"], [[0, 0], [0, 1]]) - assert np.array_equal(data["col2"], [[100, 0], [100, 101]]) + np.testing.assert_array_equal(data["col1"], [[0, 0], [0, 1]]) + np.testing.assert_array_equal(data["col2"], [[100, 0], [100, 101]]) def test_batch_padding_05(): @@ -103,9 +103,9 @@ def test_batch_padding_05(): data1 = data1.batch(batch_size=3, drop_remainder=False, pad_info={"col2": ([2, None], -2), "col1": (None, -1)}) # pad automatically for data in data1.create_dict_iterator(): - assert np.array_equal(data["col1"], [[[0, -1, -1]], [[0, 1, -1]], [[0, 1, 2]]]) - assert np.array_equal(data["col2"], [[[100, -2, -2], [-2, -2, -2]], [[100, 101, -2], [-2, -2, -2]], - [[100, 101, 102], [-2, -2, -2]]]) + np.testing.assert_array_equal(data["col1"], [[[0, -1, -1]], [[0, 1, -1]], [[0, 1, 2]]]) + np.testing.assert_array_equal(data["col2"], [[[100, -2, -2], [-2, -2, -2]], [[100, 101, -2], [-2, -2, -2]], + [[100, 101, 102], [-2, -2, -2]]]) def batch_padding_performance_3d(): @@ -197,7 +197,7 @@ def test_pad_via_map(): res_from_batch = pad_batch_config() assert len(res_from_batch) == len(res_from_batch) for i, _ in enumerate(res_from_map): - assert np.array_equal(res_from_map[i], res_from_batch[i]) + np.testing.assert_array_equal(res_from_map[i], res_from_batch[i]) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_pair_truncate.py b/tests/ut/python/dataset/test_pair_truncate.py index 6b1138e5a9..8cc40ee126 100644 --- a/tests/ut/python/dataset/test_pair_truncate.py +++ b/tests/ut/python/dataset/test_pair_truncate.py @@ -16,7 +16,6 @@ Testing Mask op in DE """ import numpy as np -import pytest import mindspore.dataset as ds import mindspore.dataset.text as text @@ -55,9 +54,7 @@ def test_basics_str(): def test_exceptions(): - with pytest.raises(RuntimeError) as info: - compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=1, out1=[1, 2], out2=[5]) - assert "Indices are empty, generated tensor would be empty" in str(info.value) + compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=1, out1=[1], out2=[]) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py index 7e41f1b7fd..b512de5230 100644 --- a/tests/ut/python/dataset/test_pyfunc.py +++ b/tests/ut/python/dataset/test_pyfunc.py @@ -39,7 +39,7 @@ def test_case_0(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out"], golden) + np.testing.assert_array_equal(item["out"], golden) i = i + 4 @@ -60,9 +60,9 @@ def test_case_1(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["out0"], golden) + np.testing.assert_array_equal(item["out0"], golden) golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out1"], golden) + np.testing.assert_array_equal(item["out1"], golden) i = i + 4 @@ -84,7 +84,7 @@ def test_case_2(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out"], golden) + np.testing.assert_array_equal(item["out"], golden) i = i + 4 @@ -106,11 +106,11 @@ def test_case_3(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["out0"], golden) + np.testing.assert_array_equal(item["out0"], golden) golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out1"], golden) + np.testing.assert_array_equal(item["out1"], golden) golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]]) - assert np.array_equal(item["out2"], golden) + np.testing.assert_array_equal(item["out2"], golden) i = i + 4 @@ -132,11 +132,11 @@ def test_case_4(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["out0"], golden) + np.testing.assert_array_equal(item["out0"], golden) golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out1"], golden) + np.testing.assert_array_equal(item["out1"], golden) golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]]) - assert np.array_equal(item["out2"], golden) + np.testing.assert_array_equal(item["out2"], golden) i = i + 4 @@ -159,7 +159,7 @@ def test_case_5(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[1, 1], [1, 1]]) - assert np.array_equal(item["out"], golden) + np.testing.assert_array_equal(item["out"], golden) def test_case_6(): @@ -178,7 +178,7 @@ def test_case_6(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i * 4, (i + 1) * 4], [(i + 2) * 4, (i + 3) * 4]]) - assert np.array_equal(item["out"], golden) + np.testing.assert_array_equal(item["out"], golden) i = i + 4 @@ -198,7 +198,7 @@ def test_case_7(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out"], golden) + np.testing.assert_array_equal(item["out"], golden) i = i + 4 @@ -221,11 +221,11 @@ def test_case_8(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i, i + 1], [i + 2, i + 3]]) - assert np.array_equal(item["out0"], golden) + np.testing.assert_array_equal(item["out0"], golden) golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) - assert np.array_equal(item["out1"], golden) + np.testing.assert_array_equal(item["out1"], golden) golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]]) - assert np.array_equal(item["out2"], golden) + np.testing.assert_array_equal(item["out2"], golden) i = i + 4 @@ -246,7 +246,7 @@ def test_case_9(): for item in data1.create_dict_iterator(): # each data is a dictionary # In this test, the dataset is 2x2 sequential tensors golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]]) - assert np.array_equal(item["out"], golden) + np.testing.assert_array_equal(item["out"], golden) i = i + 4 diff --git a/tests/ut/python/dataset/test_random_crop_and_resize.py b/tests/ut/python/dataset/test_random_crop_and_resize.py index 486d2cd5ed..58ae31f4d2 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize.py @@ -332,11 +332,37 @@ def test_random_crop_and_resize_comp(plot=False): image_c_cropped.append(c_image) image_py_cropped.append(py_image) mse = diff_mse(c_image, py_image) - assert mse < 0.02 # rounding error + assert mse < 0.02 # rounding error if plot: visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2) +def test_random_crop_and_resize_06(): + """ + Test RandomCropAndResize with c_transforms: invalid values for scale, + expected to raise ValueError + """ + logger.info("test_random_crop_and_resize_05_c") + + # Generate dataset + data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + decode_op = c_vision.Decode() + try: + random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), scale="", ratio=(1, 0.5)) + data = data.map(input_columns=["image"], operations=decode_op) + data.map(input_columns=["image"], operations=random_crop_and_resize_op) + except TypeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Argument scale with value \"\" is not of type (,)" in str(e) + + try: + random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), scale=(1, "2"), ratio=(1, 0.5)) + data = data.map(input_columns=["image"], operations=decode_op) + data.map(input_columns=["image"], operations=random_crop_and_resize_op) + except TypeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Argument scale[1] with value 2 is not of type (, )." in str(e) + if __name__ == "__main__": test_random_crop_and_resize_op_c(True) test_random_crop_and_resize_op_py(True) @@ -347,4 +373,5 @@ if __name__ == "__main__": test_random_crop_and_resize_04_py() test_random_crop_and_resize_05_c() test_random_crop_and_resize_05_py() + test_random_crop_and_resize_06() test_random_crop_and_resize_comp(True) diff --git a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py index 599acc9560..026808e9de 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py @@ -48,9 +48,9 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False): test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "random_resized_crop_with_bbox_01_c_result.npz" @@ -114,15 +114,15 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) # maps to convert data into valid edge case data - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) # Test Op added to list of Operations here - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) unaugSamp, augSamp = [], [] @@ -149,9 +149,9 @@ def test_random_resized_crop_with_bbox_op_invalid_c(): test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) for _ in dataVoc2.create_dict_iterator(): @@ -175,9 +175,9 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) for _ in dataVoc2.create_dict_iterator(): @@ -206,9 +206,9 @@ def test_random_resized_crop_with_bbox_op_bad_c(): if __name__ == "__main__": - test_random_resized_crop_with_bbox_op_c(plot_vis=True) - test_random_resized_crop_with_bbox_op_coco_c(plot_vis=True) - test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True) + test_random_resized_crop_with_bbox_op_c(plot_vis=False) + test_random_resized_crop_with_bbox_op_coco_c(plot_vis=False) + test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False) test_random_resized_crop_with_bbox_op_invalid_c() test_random_resized_crop_with_bbox_op_invalid2_c() test_random_resized_crop_with_bbox_op_bad_c() diff --git a/tests/ut/python/dataset/test_random_crop_with_bbox.py b/tests/ut/python/dataset/test_random_crop_with_bbox.py index b93c638f41..28a68a7c38 100644 --- a/tests/ut/python/dataset/test_random_crop_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_with_bbox.py @@ -46,10 +46,10 @@ def test_random_crop_with_bbox_op_c(plot_vis=False): test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) # Add column for "bbox" unaugSamp, augSamp = [], [] @@ -108,9 +108,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "random_crop_with_bbox_01_c_result.npz" @@ -145,9 +145,9 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) unaugSamp, augSamp = [], [] @@ -175,16 +175,18 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) # maps to convert data into valid edge case data - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[lambda img, bboxes: ( + img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) # Test Op added to list of Operations here - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[lambda img, bboxes: ( + img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) unaugSamp, augSamp = [], [] @@ -210,10 +212,10 @@ def test_random_crop_with_bbox_op_invalid_c(): test_op = c_vision.RandomCropWithBBox([512, 512, 375]) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) # Add column for "bbox" for _ in dataVoc2.create_dict_iterator(): break @@ -239,6 +241,43 @@ def test_random_crop_with_bbox_op_bad_c(): check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") +def test_random_crop_with_bbox_op_bad_padding(): + """ + Test RandomCropWithBBox Op on invalid constructor parameters, expected to raise ValueError + """ + logger.info("test_random_crop_with_bbox_op_invalid_c") + + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + + try: + test_op = c_vision.RandomCropWithBBox([512, 512], padding=-1) + + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + for _ in dataVoc2.create_dict_iterator(): + break + except ValueError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "Input padding is not within the required interval of (0 to 2147483647)." in str(err) + + try: + test_op = c_vision.RandomCropWithBBox([512, 512], padding=[16777216, 16777216, 16777216, 16777216]) + + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + for _ in dataVoc2.create_dict_iterator(): + break + except RuntimeError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "RandomCropBBoxOp padding size is too big, it\'s more than 3 times the original size." in str(err) + + if __name__ == "__main__": test_random_crop_with_bbox_op_c(plot_vis=True) test_random_crop_with_bbox_op_coco_c(plot_vis=True) @@ -247,3 +286,4 @@ if __name__ == "__main__": test_random_crop_with_bbox_op_edge_c(plot_vis=True) test_random_crop_with_bbox_op_invalid_c() test_random_crop_with_bbox_op_bad_c() + test_random_crop_with_bbox_op_bad_padding() diff --git a/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py b/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py index 4fd51a7a03..64c8de1c5e 100644 --- a/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py @@ -45,9 +45,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False): test_op = c_vision.RandomHorizontalFlipWithBBox(1) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) unaugSamp, augSamp = [], [] @@ -111,9 +111,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False): test_op = c_vision.RandomHorizontalFlipWithBBox(0.6) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "random_horizontal_flip_with_bbox_01_c_result.npz" @@ -146,20 +146,20 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False): test_op = c_vision.RandomHorizontalFlipWithBBox(1) # map to apply ops - # Add column for "annotation" - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + # Add column for "bbox" + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=lambda img, bbox: (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=lambda img, bbox: (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) unaugSamp, augSamp = [], [] @@ -184,10 +184,10 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): # Note: Valid range of prob should be [0.0, 1.0] test_op = c_vision.RandomHorizontalFlipWithBBox(1.5) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) # Add column for "bbox" except ValueError as error: logger.info("Got an exception in DE: {}".format(str(error))) assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error) diff --git a/tests/ut/python/dataset/test_random_resize_with_bbox.py b/tests/ut/python/dataset/test_random_resize_with_bbox.py index 94f9d12427..439a6dc89d 100644 --- a/tests/ut/python/dataset/test_random_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_resize_with_bbox.py @@ -48,9 +48,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False): test_op = c_vision.RandomResizeWithBBox(100) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "random_resize_with_bbox_op_01_c_voc_result.npz" @@ -129,15 +129,15 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.RandomResizeWithBBox(500) # maps to convert data into valid edge case data - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: ( img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: ( img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) diff --git a/tests/ut/python/dataset/test_random_select_subpolicy.py b/tests/ut/python/dataset/test_random_select_subpolicy.py new file mode 100644 index 0000000000..4248f9d048 --- /dev/null +++ b/tests/ut/python/dataset/test_random_select_subpolicy.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as ops +import mindspore.dataset.transforms.vision.c_transforms as visions + + +def test_random_select_subpolicy(): + ds.config.set_seed(0) + + def test_config(arr, policy): + try: + data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False) + data = data.map(input_columns=["col"], operations=visions.RandomSelectSubpolicy(policy)) + res = [] + for i in data.create_dict_iterator(): + res.append(i["col"].tolist()) + return res + except (TypeError, ValueError) as e: + return str(e) + + # 3 possible outcomes + policy1 = [[(ops.PadEnd([4], 0), 0.5), (ops.Compose([ops.Duplicate(), ops.Concatenate()]), 1)], + [(ops.Slice([0, 1]), 0.5), (ops.Duplicate(), 1), (ops.Concatenate(), 1)]] + res1 = test_config([[1, 2, 3]], policy1) + assert res1 in [[[1, 2, 1, 2]], [[1, 2, 3, 1, 2, 3]], [[1, 2, 3, 0, 1, 2, 3, 0]]] + + # test exceptions + assert "policy can not be empty." in test_config([[1, 2, 3]], []) + assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]]) + assert "op of (op, prob) in policy[1][0] is not a c_transform op (TensorOp) nor a callable pyfunc" in test_config( + [[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]]) + assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [ + [(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]]) + + +if __name__ == "__main__": + test_random_select_subpolicy() diff --git a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py index 490dc3e419..1447c31c76 100644 --- a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py @@ -46,9 +46,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): test_op = c_vision.RandomVerticalFlipWithBBox(1) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) unaugSamp, augSamp = [], [] @@ -111,9 +111,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): test_op = c_vision.RandomVerticalFlipWithBBox(0.8) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "random_vertical_flip_with_bbox_01_c_result.npz" @@ -148,15 +148,15 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.RandomVerticalFlipWithBBox(1) # maps to convert data into valid edge case data - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) # Test Op added to list of Operations here - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) unaugSamp, augSamp = [], [] @@ -181,9 +181,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): test_op = c_vision.RandomVerticalFlipWithBBox(2) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) for _ in dataVoc2.create_dict_iterator(): diff --git a/tests/ut/python/dataset/test_repeat.py b/tests/ut/python/dataset/test_repeat.py index ca4702ff8c..a059fc3a9c 100644 --- a/tests/ut/python/dataset/test_repeat.py +++ b/tests/ut/python/dataset/test_repeat.py @@ -167,7 +167,7 @@ def test_nested_repeat5(): data = data.repeat(3) for _, d in enumerate(data): - assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) + np.testing.assert_array_equal(d[0], np.asarray([[0], [1], [2]])) assert sum([1 for _ in data]) == 6 @@ -180,7 +180,7 @@ def test_nested_repeat6(): data = data.repeat(3) for _, d in enumerate(data): - assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) + np.testing.assert_array_equal(d[0], np.asarray([[0], [1], [2]])) assert sum([1 for _ in data]) == 6 @@ -193,7 +193,7 @@ def test_nested_repeat7(): data = data.batch(3) for _, d in enumerate(data): - assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) + np.testing.assert_array_equal(d[0], np.asarray([[0], [1], [2]])) assert sum([1 for _ in data]) == 6 @@ -207,9 +207,9 @@ def test_nested_repeat8(): for i, d in enumerate(data): if i % 2 == 0: - assert np.array_equal(d[0], np.asarray([[0], [1]])) + np.testing.assert_array_equal(d[0], np.asarray([[0], [1]])) else: - assert np.array_equal(d[0], np.asarray([[2]])) + np.testing.assert_array_equal(d[0], np.asarray([[2]])) assert sum([1 for _ in data]) == 6 * 2 @@ -251,6 +251,49 @@ def test_nested_repeat11(): assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3 +def test_repeat_count1(): + data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) + data1_size = data1.get_dataset_size() + logger.info("dataset size is {}".format(data1_size)) + batch_size = 2 + repeat_count = 4 + resize_height, resize_width = 32, 32 + decode_op = vision.Decode() + resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR) + data1 = data1.map(input_columns=["image"], operations=decode_op) + data1 = data1.map(input_columns=["image"], operations=resize_op) + data1 = data1.repeat(repeat_count) + data1 = data1.batch(batch_size, drop_remainder=False) + dataset_size = data1.get_dataset_size() + logger.info("dataset repeat then batch's size is {}".format(dataset_size)) + num1_iter = 0 + for _ in data1.create_dict_iterator(): + num1_iter += 1 + + assert data1_size == 3 + assert dataset_size == num1_iter == 6 + +def test_repeat_count2(): + data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) + data1_size = data1.get_dataset_size() + logger.info("dataset size is {}".format(data1_size)) + batch_size = 2 + repeat_count = 4 + resize_height, resize_width = 32, 32 + decode_op = vision.Decode() + resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR) + data1 = data1.map(input_columns=["image"], operations=decode_op) + data1 = data1.map(input_columns=["image"], operations=resize_op) + data1 = data1.batch(batch_size, drop_remainder=False) + data1 = data1.repeat(repeat_count) + dataset_size = data1.get_dataset_size() + logger.info("dataset batch then repeat's size is {}".format(dataset_size)) + num1_iter = 0 + for _ in data1.create_dict_iterator(): + num1_iter += 1 + + assert data1_size == 3 + assert dataset_size == num1_iter == 8 if __name__ == "__main__": test_tf_repeat_01() @@ -268,3 +311,5 @@ if __name__ == "__main__": test_nested_repeat9() test_nested_repeat10() test_nested_repeat11() + test_repeat_count1() + test_repeat_count2() diff --git a/tests/ut/python/dataset/test_resize.py b/tests/ut/python/dataset/test_resize.py new file mode 100644 index 0000000000..a187e0c53c --- /dev/null +++ b/tests/ut/python/dataset/test_resize.py @@ -0,0 +1,117 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing Resize op in DE +""" +import pytest +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +from mindspore.dataset.transforms.vision.utils import Inter +from mindspore import log as logger +from util import visualize_list, save_and_check_md5, \ + config_get_set_seed, config_get_set_num_parallel_workers + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + +GENERATE_GOLDEN = False + + +def test_resize_op(plot=False): + def test_resize_op_parameters(test_name, size, plot): + """ + Test resize_op + """ + logger.info("Test resize: {0}".format(test_name)) + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + + # define map operations + decode_op = vision.Decode() + resize_op = vision.Resize(size) + + # apply map operations on images + data1 = data1.map(input_columns=["image"], operations=decode_op) + + data2 = data1.map(input_columns=["image"], operations=resize_op) + image_original = [] + image_resized = [] + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): + image_1 = item1["image"] + image_2 = item2["image"] + image_original.append(image_1) + image_resized.append(image_2) + if plot: + visualize_list(image_original, image_resized) + + test_resize_op_parameters("Test single int for size", 10, plot=False) + test_resize_op_parameters("Test tuple for size", (10, 15), plot=False) + + +def test_resize_md5(plot=False): + def test_resize_md5_parameters(test_name, size, filename, seed, plot): + """ + Test Resize with md5 check + """ + logger.info("Test Resize with md5 check: {0}".format(test_name)) + original_seed = config_get_set_seed(seed) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # Generate dataset + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + decode_op = vision.Decode() + resize_op = vision.Resize(size) + data1 = data1.map(input_columns=["image"], operations=decode_op) + data2 = data1.map(input_columns=["image"], operations=resize_op) + image_original = [] + image_resized = [] + # Compare with expected md5 from images + save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) + + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): + image_1 = item1["image"] + image_2 = item2["image"] + image_original.append(image_1) + image_resized.append(image_2) + if plot: + visualize_list(image_original, image_resized) + + # Restore configuration + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + test_resize_md5_parameters("Test single int for size", 5, "resize_01_result.npz", 5, plot) + test_resize_md5_parameters("Test tuple for size", (5, 7), "resize_02_result.npz", 7, plot) + + +def test_resize_op_invalid_input(): + def test_invalid_input(test_name, size, interpolation, error, error_msg): + logger.info("Test Resize with bad input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + vision.Resize(size, interpolation) + assert error_msg in str(error_info.value) + + test_invalid_input("invalid size parameter type as a single number", 4.5, Inter.LINEAR, TypeError, + "Size should be a single integer or a list/tuple (h, w) of length 2.") + test_invalid_input("invalid size parameter shape", (2, 3, 4), Inter.LINEAR, TypeError, + "Size should be a single integer or a list/tuple (h, w) of length 2.") + test_invalid_input("invalid size parameter type in a tuple", (2.3, 3), Inter.LINEAR, TypeError, + "incompatible constructor arguments.") + test_invalid_input("invalid Interpolation value", (2.3, 3), None, KeyError, "None") + + +if __name__ == "__main__": + test_resize_op(plot=True) + test_resize_md5(plot=True) + test_resize_op_invalid_input() diff --git a/tests/ut/python/dataset/test_resize_with_bbox.py b/tests/ut/python/dataset/test_resize_with_bbox.py index 3bb731ee97..af10ed9449 100644 --- a/tests/ut/python/dataset/test_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_resize_with_bbox.py @@ -16,9 +16,10 @@ Testing the resize with bounding boxes op in DE """ import numpy as np +import pytest + import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision - from mindspore import log as logger from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ save_and_check_md5 @@ -47,9 +48,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False): test_op = c_vision.ResizeWithBBox(100) # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[test_op]) filename = "resize_with_bbox_op_01_c_voc_result.npz" @@ -118,15 +119,15 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.ResizeWithBBox(500) # maps to convert data into valid edge case data - dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: ( img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], + dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], operations=[lambda img, bboxes: ( img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) @@ -172,6 +173,18 @@ def test_resize_with_bbox_op_bad_c(): check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") +def test_resize_with_bbox_op_params_outside_of_interpolation_dict(): + """ + Test passing in a invalid key for interpolation + """ + logger.info("test_resize_with_bbox_op_params_outside_of_interpolation_dict") + + size = (500, 500) + more_para = None + with pytest.raises(KeyError, match="None"): + c_vision.ResizeWithBBox(size, more_para) + + if __name__ == "__main__": test_resize_with_bbox_op_voc_c(plot_vis=False) test_resize_with_bbox_op_coco_c(plot_vis=False) diff --git a/tests/ut/python/dataset/test_save_op.py b/tests/ut/python/dataset/test_save_op.py new file mode 100644 index 0000000000..07f374130e --- /dev/null +++ b/tests/ut/python/dataset/test_save_op.py @@ -0,0 +1,434 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This is the test module for saveOp. +""" +import os +from string import punctuation +import mindspore.dataset as ds +from mindspore import log as logger +from mindspore.mindrecord import FileWriter +import numpy as np +import pytest + +CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" +CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" +TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord" +FILES_NUM = 1 +num_readers = 1 + + +@pytest.fixture(name="add_and_remove_cv_file") +def fixture_remove(): + """add/remove cv file""" + if os.path.exists("{}".format(CV_FILE_NAME1)): + os.remove("{}".format(CV_FILE_NAME1)) + if os.path.exists("{}.db".format(CV_FILE_NAME1)): + os.remove("{}.db".format(CV_FILE_NAME1)) + + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + yield "yield_cv_data" + if os.path.exists("{}".format(CV_FILE_NAME1)): + os.remove("{}".format(CV_FILE_NAME1)) + if os.path.exists("{}.db".format(CV_FILE_NAME1)): + os.remove("{}.db".format(CV_FILE_NAME1)) + + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + + +def test_case_00(add_and_remove_cv_file): # only bin data + data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')}] + schema = { + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer = FileWriter(CV_FILE_NAME1, FILES_NUM) + writer.add_schema(schema, "schema") + writer.write_raw_data(data) + writer.commit() + + d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) + d1.save(CV_FILE_NAME2, FILES_NUM) + data_value_to_list = [] + + for item in data: + new_data = {} + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + data_value_to_list.append(new_data) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + assert d2.get_dataset_size() == 5 + num_iter = 0 + for item in d2.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 5 + + +def test_case_01(add_and_remove_cv_file): # only raw data + data = [{"file_name": "001.jpg", "label": 43}, + {"file_name": "002.jpg", "label": 91}, + {"file_name": "003.jpg", "label": 61}, + {"file_name": "004.jpg", "label": 29}, + {"file_name": "005.jpg", "label": 78}, + {"file_name": "006.jpg", "label": 37}] + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"} + } + + writer = FileWriter(CV_FILE_NAME1, FILES_NUM) + writer.add_schema(schema, "schema") + writer.write_raw_data(data) + writer.commit() + + d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) + d1.save(CV_FILE_NAME2, FILES_NUM) + + data_value_to_list = [] + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(item["file_name"], dtype='S') + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + data_value_to_list.append(new_data) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + assert d2.get_dataset_size() == 6 + num_iter = 0 + for item in d2.create_dict_iterator(): + logger.info(item) + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + +def test_case_02(add_and_remove_cv_file): # muti-bytes + data = [{"file_name": "001.jpg", "label": 43, + "float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12345, + "float64": 1987654321.123456785, + "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, + "float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12445, + "float64": 1987654321.123456786, + "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, + "float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12545, + "float64": 1987654321.123456787, + "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, + "float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12645, + "float64": 1987654321.123456788, + "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, + "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12745, + "float64": 1987654321.123456789, + "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, + "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12745, + "float64": 1987654321.123456789, + "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')} + ] + schema = {"file_name": {"type": "string"}, + "float32_array": {"type": "float32", "shape": [-1]}, + "float64_array": {"type": "float64", "shape": [-1]}, + "float32": {"type": "float32"}, + "float64": {"type": "float64"}, + "source_sos_ids": {"type": "int32", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "label": {"type": "int32"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer = FileWriter(CV_FILE_NAME1, FILES_NUM) + writer.add_schema(schema, "schema") + writer.write_raw_data(data) + writer.commit() + + d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) + d1.save(CV_FILE_NAME2, FILES_NUM) + data_value_to_list = [] + + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(item["file_name"], dtype='S') + new_data['float32_array'] = item["float32_array"] + new_data['float64_array'] = item["float64_array"] + new_data['float32'] = item["float32"] + new_data['float64'] = item["float64"] + new_data['source_sos_ids'] = item["source_sos_ids"] + new_data['source_sos_mask'] = item["source_sos_mask"] + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + data_value_to_list.append(new_data) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + assert d2.get_dataset_size() == 6 + num_iter = 0 + for item in d2.create_dict_iterator(): + assert len(item) == 13 + for field in item: + if isinstance(item[field], np.ndarray): + if item[field].dtype == np.float32: + assert (item[field] == + np.array(data_value_to_list[num_iter][field], np.float32)).all() + else: + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + +def generator_1d(): + for i in range(10): + yield (np.array([i]),) + + +def test_case_03(add_and_remove_cv_file): + + # apply dataset operations + d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) + + d1.save(CV_FILE_NAME2) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + + i = 0 + for item in d2.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + np.testing.assert_array_equal(item["data"], golden) + i = i + 1 + + +def generator_with_type(t): + for i in range(64): + yield (np.array([i], dtype=t),) + + +def type_tester(t): + logger.info("Test with Type {}".format(t.__name__)) + + # apply dataset operations + data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False) + + data1 = data1.batch(4) + + data1 = data1.repeat(3) + + data1.save(CV_FILE_NAME2) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + + i = 0 + num_repeat = 0 + for item in d2.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) + logger.info(item) + np.testing.assert_array_equal(item["data"], golden) + i = i + 4 + if i == 64: + i = 0 + num_repeat += 1 + assert num_repeat == 3 + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + + +def test_case_04(): + # uint8 will drop shape as mindrecord store uint8 as bytes + types = [np.int8, np.int16, np.int32, np.int64, + np.uint16, np.uint32, np.float32, np.float64] + + for t in types: + type_tester(t) + + +def test_case_05(add_and_remove_cv_file): + + d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) + + with pytest.raises(Exception, match="num_files should between 1 and 1000."): + d1.save(CV_FILE_NAME2, 0) + + +def test_case_06(add_and_remove_cv_file): + + d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) + + with pytest.raises(Exception, match="tfrecord dataset format is not supported."): + d1.save(CV_FILE_NAME2, 1, "tfrecord") + + +def cast_name(key): + """ + Cast schema names which containing special characters to valid names. + """ + special_symbols = set('{}{}'.format(punctuation, ' ')) + special_symbols.remove('_') + new_key = ['_' if x in special_symbols else x for x in key] + casted_key = ''.join(new_key) + return casted_key + + +def test_case_07(): + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False) + tf_data = [] + for x in d1.create_dict_iterator(): + tf_data.append(x) + d1.save(CV_FILE_NAME2, FILES_NUM) + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + mr_data = [] + for x in d2.create_dict_iterator(): + mr_data.append(x) + count = 0 + for x in tf_data: + for k, v in x.items(): + if isinstance(v, np.ndarray): + assert (v == mr_data[count][cast_name(k)]).all() + else: + assert v == mr_data[count][cast_name(k)] + count += 1 + assert count == 10 + + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) diff --git a/tests/ut/python/dataset/test_sentencepiece_tokenizer.py b/tests/ut/python/dataset/test_sentencepiece_tokenizer.py new file mode 100644 index 0000000000..d50ed01e7b --- /dev/null +++ b/tests/ut/python/dataset/test_sentencepiece_tokenizer.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import copy +import mindspore.dataset.text as text +import mindspore.dataset as ds +from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType + +VOCAB_FILE = "../data/dataset/test_sentencepiece/botchan.txt" +DATA_FILE = "../data/dataset/testTokenizerData/sentencepiece_tokenizer.txt" + + +def test_from_vocab_to_str_UNIGRAM(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_from_vocab_to_str_BPE(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.BPE, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = ['▁I', '▁saw', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'c', 'ope', '.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_from_vocab_to_str_CHAR(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.CHAR, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = ['▁', 'I', '▁', 's', 'a', 'w', '▁', 'a', '▁', 'g', 'i', 'r', 'l', '▁', 'w', 'i', 't', 'h',\ + '▁', 'a', '▁', 't', 'e', 'l', 'e', 's', 'c', 'o', 'p', 'e', '.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_from_vocab_to_str_WORD(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.WORD, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = ['▁I', '▁saw', '▁a', '▁girl', '▁with', '▁a', '▁telescope.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_from_vocab_to_int(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.INT) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = [6, 329, 183, 8, 945, 23, 8, 3783, 4382, 4641, 1405, 4] + for i in dataset.create_dict_iterator(): + ret = i["text"] + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_from_file_to_str(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) + text.SentencePieceVocab.save_model(vocab, "./", "m.model") + tokenizer = text.SentencePieceTokenizer("./m.model", out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_from_file_to_int(): + vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) + text.SentencePieceVocab.save_model(vocab, "./", "m.model") + tokenizer = text.SentencePieceTokenizer("./m.model", out_type=SPieceTokenizerOutType.INT) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = [6, 329, 183, 8, 945, 23, 8, 3783, 4382, 4641, 1405, 4] + for i in dataset.create_dict_iterator(): + ret = i["text"] + for key, value in enumerate(ret): + assert value == expect[key] + + +def test_build_from_dataset(): + data = ds.TextFileDataset(VOCAB_FILE, shuffle=False) + vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer) + expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def apply_func(dataset): + input_columns = ['text'] + output_columns = ['text2'] + dataset = dataset.rename(input_columns, output_columns) + return dataset + + +def zip_test(dataset): + dataset_1 = copy.deepcopy(dataset) + dataset_2 = copy.deepcopy(dataset) + dataset_1 = dataset_1.apply(apply_func) + dataset_zip = ds.zip((dataset_1, dataset_2)) + expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.'] + for i in dataset_zip.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + + +def concat_test(dataset): + dataset_1 = copy.deepcopy(dataset) + dataset = dataset.concat(dataset_1) + expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.'] + for i in dataset.create_dict_iterator(): + ret = to_str(i["text"]) + for key, value in enumerate(ret): + assert value == expect[key] + +def test_with_zip_concat(): + data = ds.TextFileDataset(VOCAB_FILE, shuffle=False) + vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) + tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + dataset = dataset.map(operations=tokenizer, num_parallel_workers=2) + zip_test(dataset) + concat_test(dataset) + + +if __name__ == "__main__": + test_from_vocab_to_str_UNIGRAM() + test_from_vocab_to_str_BPE() + test_from_vocab_to_str_CHAR() + test_from_vocab_to_str_WORD() + test_from_vocab_to_int() + test_from_file_to_str() + test_from_file_to_int() + test_build_from_dataset() + test_with_zip_concat() diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py index 48775f15a2..d041f44bc1 100644 --- a/tests/ut/python/dataset/test_serdes_dataset.py +++ b/tests/ut/python/dataset/test_serdes_dataset.py @@ -79,12 +79,12 @@ def test_imagefolder(remove_json_files=True): # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(), data3.create_dict_iterator(), data4.create_dict_iterator()): - assert np.array_equal(item1['image'], item2['image']) - assert np.array_equal(item1['image'], item3['image']) - assert np.array_equal(item1['label'], item2['label']) - assert np.array_equal(item1['label'], item3['label']) - assert np.array_equal(item3['image'], item4['image']) - assert np.array_equal(item3['label'], item4['label']) + np.testing.assert_array_equal(item1['image'], item2['image']) + np.testing.assert_array_equal(item1['image'], item3['image']) + np.testing.assert_array_equal(item1['label'], item2['label']) + np.testing.assert_array_equal(item1['label'], item3['label']) + np.testing.assert_array_equal(item3['image'], item4['image']) + np.testing.assert_array_equal(item3['label'], item4['label']) num_samples += 1 logger.info("Number of data in data1: {}".format(num_samples)) @@ -119,10 +119,10 @@ def test_mnist_dataset(remove_json_files=True): num = 0 for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(), data3.create_dict_iterator()): - assert np.array_equal(data1['image'], data2['image']) - assert np.array_equal(data1['image'], data3['image']) - assert np.array_equal(data1['label'], data2['label']) - assert np.array_equal(data1['label'], data3['label']) + np.testing.assert_array_equal(data1['image'], data2['image']) + np.testing.assert_array_equal(data1['image'], data3['image']) + np.testing.assert_array_equal(data1['label'], data2['label']) + np.testing.assert_array_equal(data1['label'], data3['label']) num += 1 logger.info("mnist total num samples is {}".format(str(num))) @@ -160,10 +160,10 @@ def test_zip_dataset(remove_json_files=True): num_cols = len(d0) offset = 0 for t1 in d0: - assert np.array_equal(t1, d3[offset]) - assert np.array_equal(t1, d3[offset + num_cols]) - assert np.array_equal(t1, d4[offset]) - assert np.array_equal(t1, d4[offset + num_cols]) + np.testing.assert_array_equal(t1, d3[offset]) + np.testing.assert_array_equal(t1, d3[offset + num_cols]) + np.testing.assert_array_equal(t1, d4[offset]) + np.testing.assert_array_equal(t1, d4[offset + num_cols]) offset += 1 rows += 1 assert rows == 12 @@ -199,7 +199,7 @@ def test_random_crop(): for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(), data2.create_dict_iterator()): - assert np.array_equal(item1['image'], item1_1['image']) + np.testing.assert_array_equal(item1['image'], item1_1['image']) _ = item2["image"] # Restore configuration num_parallel_workers diff --git a/tests/ut/python/dataset/test_shuffle.py b/tests/ut/python/dataset/test_shuffle.py index 460c491ca1..6da7a1c885 100644 --- a/tests/ut/python/dataset/test_shuffle.py +++ b/tests/ut/python/dataset/test_shuffle.py @@ -13,10 +13,9 @@ # limitations under the License. # ============================================================================== import numpy as np -from util import save_and_check - import mindspore.dataset as ds from mindspore import log as logger +from util import save_and_check_dict # Note: Number of rows in test.data dataset: 12 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] @@ -31,7 +30,6 @@ def test_shuffle_01(): # define parameters buffer_size = 5 seed = 1 - parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -39,7 +37,7 @@ def test_shuffle_01(): data1 = data1.shuffle(buffer_size=buffer_size) filename = "shuffle_01_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_shuffle_02(): @@ -50,7 +48,6 @@ def test_shuffle_02(): # define parameters buffer_size = 12 seed = 1 - parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -58,7 +55,7 @@ def test_shuffle_02(): data1 = data1.shuffle(buffer_size=buffer_size) filename = "shuffle_02_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_shuffle_03(): @@ -69,7 +66,6 @@ def test_shuffle_03(): # define parameters buffer_size = 2 seed = 1 - parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -77,7 +73,7 @@ def test_shuffle_03(): data1 = data1.shuffle(buffer_size) filename = "shuffle_03_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_shuffle_04(): @@ -88,7 +84,6 @@ def test_shuffle_04(): # define parameters buffer_size = 2 seed = 1 - parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, num_samples=2) @@ -96,7 +91,7 @@ def test_shuffle_04(): data1 = data1.shuffle(buffer_size=buffer_size) filename = "shuffle_04_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_shuffle_05(): @@ -107,7 +102,6 @@ def test_shuffle_05(): # define parameters buffer_size = 13 seed = 1 - parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) @@ -115,7 +109,7 @@ def test_shuffle_05(): data1 = data1.shuffle(buffer_size=buffer_size) filename = "shuffle_05_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_shuffle_06(): diff --git a/tests/ut/python/dataset/test_skip.py b/tests/ut/python/dataset/test_skip.py index 5dd7faa66a..87e4122f84 100644 --- a/tests/ut/python/dataset/test_skip.py +++ b/tests/ut/python/dataset/test_skip.py @@ -13,9 +13,12 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as vision +from mindspore import log as logger + DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" @@ -196,6 +199,29 @@ def test_skip_filter_2(): assert buf == [5, 6, 7, 8, 9, 10] +def test_skip_exception_1(): + data1 = ds.GeneratorDataset(generator_md, ["data"]) + + try: + data1 = data1.skip(count=-1) + num_iter = 0 + for _ in data1.create_dict_iterator(): + num_iter += 1 + + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Skip count must be positive integer or 0." in str(e) + + +def test_skip_exception_2(): + ds1 = ds.GeneratorDataset(generator_md, ["data"]) + + with pytest.raises(ValueError) as e: + ds1 = ds1.skip(-2) + assert "Input count is not within the required interval" in str(e.value) + + + if __name__ == "__main__": test_tf_skip() test_generator_skip() @@ -208,3 +234,5 @@ if __name__ == "__main__": test_skip_take_2() test_skip_filter_1() test_skip_filter_2() + test_skip_exception_1() + test_skip_exception_2() diff --git a/tests/ut/python/dataset/test_slice_op.py b/tests/ut/python/dataset/test_slice_op.py index 6e81133a2a..72417bff71 100644 --- a/tests/ut/python/dataset/test_slice_op.py +++ b/tests/ut/python/dataset/test_slice_op.py @@ -121,21 +121,10 @@ def test_slice_exceptions(): slice_compare([1, 2, 3, 4, 5], 5) assert "Index 5 is out of bounds [0,5)" in str(info.value) - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(0)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) + slice_compare([1, 2, 3, 4, 5], slice(0)) + slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1)) + slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) + slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1)) def test_slice_all_str(): @@ -198,21 +187,10 @@ def test_slice_exceptions_str(): slice_compare([b"1", b"2", b"3", b"4", b"5"], 5) assert "Index 5 is out of bounds [0,5)" in str(info.value) - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1)) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_sliding_window.py b/tests/ut/python/dataset/test_sliding_window.py new file mode 100644 index 0000000000..4fdd7a25c0 --- /dev/null +++ b/tests/ut/python/dataset/test_sliding_window.py @@ -0,0 +1,105 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing SlidingWindow in mindspore.dataset +""" +import numpy as np +import mindspore.dataset as ds +import mindspore.dataset.text as text + +def test_sliding_window_string(): + """ test sliding_window with string type""" + inputs = [["大", "家", "早", "上", "好"]] + expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']]) + + dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) + dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0)) + + result = [] + for data in dataset.create_dict_iterator(): + for i in range(data['text'].shape[0]): + result.append([]) + for j in range(data['text'].shape[1]): + result[i].append(data['text'][i][j].decode('utf8')) + result = np.array(result) + np.testing.assert_array_equal(result, expect) + +def test_sliding_window_number(): + inputs = [1] + expect = np.array([[1]]) + + def gen(nums): + yield (np.array(nums),) + + dataset = ds.GeneratorDataset(gen(inputs), column_names=["number"]) + dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(1, -1)) + + for data in dataset.create_dict_iterator(): + np.testing.assert_array_equal(data['number'], expect) + +def test_sliding_window_big_width(): + inputs = [[1, 2, 3, 4, 5]] + expect = np.array([]) + + dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False) + dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(30, 0)) + + for data in dataset.create_dict_iterator(): + np.testing.assert_array_equal(data['number'], expect) + +def test_sliding_window_exception(): + try: + _ = text.SlidingWindow(0, 0) + assert False + except ValueError: + pass + + try: + _ = text.SlidingWindow("1", 0) + assert False + except TypeError: + pass + + try: + _ = text.SlidingWindow(1, "0") + assert False + except TypeError: + pass + + try: + inputs = [[1, 2, 3, 4, 5]] + dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) + dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(3, -100)) + for _ in dataset.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "axis supports 0 or -1 only for now." in str(e) + + try: + inputs = ["aa", "bb", "cc"] + dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False) + dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0)) + for _ in dataset.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "SlidingWindosOp supports 1D Tensors only for now." in str(e) + +if __name__ == '__main__': + test_sliding_window_string() + test_sliding_window_number() + test_sliding_window_big_width() + test_sliding_window_exception() diff --git a/tests/ut/python/dataset/test_tensor_empty.py b/tests/ut/python/dataset/test_tensor_empty.py new file mode 100644 index 0000000000..f681055544 --- /dev/null +++ b/tests/ut/python/dataset/test_tensor_empty.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np + +import mindspore.dataset as ds + + +def test_tensor_empty(): + def gen(): + for _ in range(4): + (yield np.array([], dtype=np.int64), np.array([], dtype='S').reshape([0, 4]), np.array([1], + dtype=np.float64)) + + data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]) + + for d in data: + np.testing.assert_array_equal(np.array([], dtype=np.int64), d[0]) + np.testing.assert_array_equal(np.array([], dtype='S').reshape([0, 4]), d[1]) + np.testing.assert_array_equal(np.array([1], dtype=np.float64), d[2]) + + +def test_tensor_empty_map(): + def gen(): + for _ in range(4): + (yield np.array([], dtype=np.int64), np.array([], dtype='S'), np.array([1], dtype=np.float64)) + + data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]) + + def func(x, y, z): + x = np.array([1], dtype=np.int64) + y = np.array(["Hi"], dtype='S') + z = np.array([], dtype=np.float64) + return x, y, z + + data = data.map(input_columns=["col1", "col2", "col3"], operations=func) + + for d in data: + np.testing.assert_array_equal(np.array([1], dtype=np.int64), d[0]) + np.testing.assert_array_equal(np.array(["Hi"], dtype='S'), d[1]) + np.testing.assert_array_equal(np.array([], dtype=np.float64), d[2]) + + +def test_tensor_empty_batch(): + def gen(): + for _ in range(4): + (yield np.array([], dtype=np.int64), np.array([], dtype='S').reshape([0, 4]), np.array([1], + dtype=np.float64)) + + data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]).batch(2) + + for d in data: + np.testing.assert_array_equal(np.array([], dtype=np.int64).reshape([2, 0]), d[0]) + np.testing.assert_array_equal(np.array([], dtype='S').reshape([2, 0, 4]), d[1]) + np.testing.assert_array_equal(np.array([[1], [1]], dtype=np.float64), d[2]) + + +if __name__ == '__main__': + test_tensor_empty() + test_tensor_empty_map() + test_tensor_empty_batch() diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py deleted file mode 100644 index f57c387b35..0000000000 --- a/tests/ut/python/dataset/test_tfreader_op.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test TFRecordDataset Ops -""" -import numpy as np -import pytest - -import mindspore.common.dtype as mstype -import mindspore.dataset as ds -from mindspore import log as logger -from util import save_and_check_dict - -FILES = ["../data/dataset/testTFTestAllTypes/test.data"] -DATASET_ROOT = "../data/dataset/testTFTestAllTypes/" -SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json" -DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] -SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" -GENERATE_GOLDEN = False - - -def test_tfrecord_shape(): - logger.info("test_tfrecord_shape") - schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json" - ds1 = ds.TFRecordDataset(FILES, schema_file) - ds1 = ds1.batch(2) - for data in ds1.create_dict_iterator(): - logger.info(data) - output_shape = ds1.output_shapes() - assert len(output_shape[-1]) == 1 - - -def test_tfrecord_read_all_dataset(): - logger.info("test_tfrecord_read_all_dataset") - schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" - ds1 = ds.TFRecordDataset(FILES, schema_file) - assert ds1.get_dataset_size() == 12 - count = 0 - for _ in ds1.create_tuple_iterator(): - count += 1 - assert count == 12 - - -def test_tfrecord_num_samples(): - logger.info("test_tfrecord_num_samples") - schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" - ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) - assert ds1.get_dataset_size() == 8 - count = 0 - for _ in ds1.create_dict_iterator(): - count += 1 - assert count == 8 - - -def test_tfrecord_num_samples2(): - logger.info("test_tfrecord_num_samples2") - schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" - ds1 = ds.TFRecordDataset(FILES, schema_file) - assert ds1.get_dataset_size() == 7 - count = 0 - for _ in ds1.create_dict_iterator(): - count += 1 - assert count == 7 - - -def test_tfrecord_shape2(): - logger.info("test_tfrecord_shape2") - ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) - ds1 = ds1.batch(2) - output_shape = ds1.output_shapes() - assert len(output_shape[-1]) == 2 - - -def test_tfrecord_files_basic(): - logger.info("test_tfrecord_files_basic") - - data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - filename = "tfrecord_files_basic.npz" - save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) - - -def test_tfrecord_no_schema(): - logger.info("test_tfrecord_no_schema") - - data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES) - filename = "tfrecord_no_schema.npz" - save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) - - -def test_tfrecord_pad(): - logger.info("test_tfrecord_pad") - - schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json" - data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES) - filename = "tfrecord_pad_bytes10.npz" - save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) - - -def test_tfrecord_read_files(): - logger.info("test_tfrecord_read_files") - pattern = DATASET_ROOT + "/test.data" - data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - assert sum([1 for _ in data]) == 12 - - pattern = DATASET_ROOT + "/test2.data" - data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - assert sum([1 for _ in data]) == 12 - - pattern = DATASET_ROOT + "/*.data" - data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES) - assert sum([1 for _ in data]) == 24 - - pattern = DATASET_ROOT + "/*.data" - data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=3, shuffle=ds.Shuffle.FILES) - assert sum([1 for _ in data]) == 3 - - data = ds.TFRecordDataset([DATASET_ROOT + "/test.data", DATASET_ROOT + "/test2.data"], - SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES) - assert sum([1 for _ in data]) == 24 - - -def test_tfrecord_multi_files(): - logger.info("test_tfrecord_multi_files") - data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False) - data1 = data1.repeat(1) - num_iter = 0 - for _ in data1.create_dict_iterator(): - num_iter += 1 - - assert num_iter == 12 - - -def test_tfrecord_schema(): - logger.info("test_tfrecord_schema") - schema = ds.Schema() - schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) - schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) - schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2]) - schema.add_column('col_binary', de_type=mstype.uint8, shape=[1]) - schema.add_column('col_float', de_type=mstype.float32, shape=[1]) - schema.add_column('col_sint16', de_type=mstype.int64, shape=[1]) - schema.add_column('col_sint32', de_type=mstype.int64, shape=[1]) - schema.add_column('col_sint64', de_type=mstype.int64, shape=[1]) - data1 = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES) - - data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - - for d1, d2 in zip(data1, data2): - for t1, t2 in zip(d1, d2): - assert np.array_equal(t1, t2) - - -def test_tfrecord_shuffle(): - logger.info("test_tfrecord_shuffle") - ds.config.set_seed(1) - data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL) - data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - data2 = data2.shuffle(10000) - - for d1, d2 in zip(data1, data2): - for t1, t2 in zip(d1, d2): - assert np.array_equal(t1, t2) - - -def test_tfrecord_shard(): - logger.info("test_tfrecord_shard") - tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", - "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] - - def get_res(shard_id, num_repeats): - data1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3, - shuffle=ds.Shuffle.FILES) - data1 = data1.repeat(num_repeats) - res = list() - for item in data1.create_dict_iterator(): - res.append(item["scalars"][0]) - return res - - # get separate results from two workers. the 2 results need to satisfy 2 criteria - # 1. two workers always give different results in same epoch (e.g. wrkr1:f1&f3, wrkr2:f2&f4 in one epoch) - # 2. with enough epochs, both workers will get the entire dataset (e,g. ep1_wrkr1: f1&f3, ep2,_wrkr1 f2&f4) - worker1_res = get_res(0, 16) - worker2_res = get_res(1, 16) - # Confirm each worker gets 3x16=48 rows - assert len(worker1_res) == 48 - assert len(worker1_res) == len(worker2_res) - # check criteria 1 - for i, _ in enumerate(worker1_res): - assert worker1_res[i] != worker2_res[i] - # check criteria 2 - assert set(worker2_res) == set(worker1_res) - - -def test_tfrecord_shard_equal_rows(): - logger.info("test_tfrecord_shard_equal_rows") - tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", - "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] - - def get_res(num_shards, shard_id, num_repeats): - ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True) - ds1 = ds1.repeat(num_repeats) - res = list() - for data in ds1.create_dict_iterator(): - res.append(data["scalars"][0]) - return res - - worker1_res = get_res(3, 0, 2) - worker2_res = get_res(3, 1, 2) - worker3_res = get_res(3, 2, 2) - # check criteria 1 - for i, _ in enumerate(worker1_res): - assert worker1_res[i] != worker2_res[i] - assert worker2_res[i] != worker3_res[i] - # Confirm each worker gets same number of rows - assert len(worker1_res) == 28 - assert len(worker1_res) == len(worker2_res) - assert len(worker2_res) == len(worker3_res) - - worker4_res = get_res(1, 0, 1) - assert len(worker4_res) == 40 - - -def test_tfrecord_no_schema_columns_list(): - logger.info("test_tfrecord_no_schema_columns_list") - data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"]) - row = data.create_dict_iterator().get_next() - assert row["col_sint16"] == [-32768] - - with pytest.raises(KeyError) as info: - _ = row["col_sint32"] - assert "col_sint32" in str(info.value) - - -def test_tfrecord_schema_columns_list(): - logger.info("test_tfrecord_schema_columns_list") - schema = ds.Schema() - schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) - schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) - schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2]) - schema.add_column('col_binary', de_type=mstype.uint8, shape=[1]) - schema.add_column('col_float', de_type=mstype.float32, shape=[1]) - schema.add_column('col_sint16', de_type=mstype.int64, shape=[1]) - schema.add_column('col_sint32', de_type=mstype.int64, shape=[1]) - schema.add_column('col_sint64', de_type=mstype.int64, shape=[1]) - data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"]) - row = data.create_dict_iterator().get_next() - assert row["col_sint16"] == [-32768] - - with pytest.raises(KeyError) as info: - _ = row["col_sint32"] - assert "col_sint32" in str(info.value) - - -def test_tfrecord_invalid_files(): - logger.info("test_tfrecord_invalid_files") - valid_file = "../data/dataset/testTFTestAllTypes/test.data" - invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" - files = [invalid_file, valid_file, SCHEMA_FILE] - - data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - - with pytest.raises(RuntimeError) as info: - _ = data.create_dict_iterator().get_next() - assert "cannot be opened" in str(info.value) - assert "not valid tfrecord files" in str(info.value) - assert valid_file not in str(info.value) - assert invalid_file in str(info.value) - assert SCHEMA_FILE in str(info.value) - - nonexistent_file = "this/file/does/not/exist" - files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file] - - with pytest.raises(ValueError) as info: - data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - assert "did not match any files" in str(info.value) - assert valid_file not in str(info.value) - assert invalid_file not in str(info.value) - assert SCHEMA_FILE not in str(info.value) - assert nonexistent_file in str(info.value) - - -if __name__ == '__main__': - test_tfrecord_shape() - test_tfrecord_read_all_dataset() - test_tfrecord_num_samples() - test_tfrecord_num_samples2() - test_tfrecord_shape2() - test_tfrecord_files_basic() - test_tfrecord_no_schema() - test_tfrecord_pad() - test_tfrecord_read_files() - test_tfrecord_multi_files() - test_tfrecord_schema() - test_tfrecord_shuffle() - test_tfrecord_shard() - test_tfrecord_shard_equal_rows() - test_tfrecord_no_schema_columns_list() - test_tfrecord_schema_columns_list() - test_tfrecord_invalid_files() diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py index e5b66696ea..a52da41e20 100644 --- a/tests/ut/python/dataset/test_uniform_augment.py +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -124,7 +124,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2): C.RandomColorAdjust(), C.RandomRotation(degrees=45)] - uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) transforms_all = [C.Decode(), C.Resize(size=[224, 224]), uni_aug, @@ -166,10 +166,10 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): F.Invert()] with pytest.raises(TypeError) as e: - _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) logger.info("Got an exception in DE: {}".format(str(e))) - assert "Argument tensor_op_5 with value" \ + assert "Argument tensor_ops[5] with value" \ " ,)" in str(e.value) @@ -187,7 +187,7 @@ def test_cpp_uniform_augment_exception_large_numops(num_ops=6): C.RandomRotation(degrees=45)] try: - _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + _ = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) @@ -207,7 +207,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): C.RandomRotation(degrees=45)] try: - _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + _ = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) @@ -227,7 +227,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): C.RandomRotation(degrees=45)] try: - _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + _ = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) @@ -248,7 +248,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): C.RandomCrop(size=[224, 224]), C.RandomHorizontalFlip() ] - uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) + uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) ds1 = ds1.map(input_columns="image", operations=uni_aug) # apply DatasetOps diff --git a/tests/ut/python/dataset/test_var_batch_map.py b/tests/ut/python/dataset/test_var_batch_map.py index 37db1a0898..75979457ce 100644 --- a/tests/ut/python/dataset/test_var_batch_map.py +++ b/tests/ut/python/dataset/test_var_batch_map.py @@ -36,22 +36,22 @@ def test_batch_corner_cases(): tst1, tst2, tst3, tst4 = [], [], [], [] # case 1 & 2, where batch_size is greater than the entire epoch, with drop equals to both val test_repeat_batch(gen_num=2, repeats=4, batch_size=7, drop=False, res=tst1) - assert np.array_equal(np.array([[0], [1], [0], [1], [0], [1], [0]]), tst1[0]), "\nATTENTION BATCH FAILED\n" - assert np.array_equal(np.array([[1]]), tst1[1]), "\nATTENTION TEST BATCH FAILED\n" + np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0], [1], [0]]), tst1[0], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(np.array([[1]]), tst1[1], "\nATTENTION TEST BATCH FAILED\n") assert len(tst1) == 2, "\nATTENTION TEST BATCH FAILED\n" test_repeat_batch(gen_num=2, repeats=4, batch_size=5, drop=True, res=tst2) - assert np.array_equal(np.array([[0], [1], [0], [1], [0]]), tst2[0]), "\nATTENTION BATCH FAILED\n" + np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0]]), tst2[0], "\nATTENTION BATCH FAILED\n") assert len(tst2) == 1, "\nATTENTION TEST BATCH FAILED\n" # case 3 & 4, batch before repeat with different drop test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=True, res=tst3) - assert np.array_equal(np.array([[0], [1], [2], [3]]), tst3[0]), "\nATTENTION BATCH FAILED\n" - assert np.array_equal(tst3[0], tst3[1]), "\nATTENTION BATCH FAILED\n" + np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst3[0], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst3[0], tst3[1], "\nATTENTION BATCH FAILED\n") assert len(tst3) == 2, "\nATTENTION BATCH FAILED\n" test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=False, res=tst4) - assert np.array_equal(np.array([[0], [1], [2], [3]]), tst4[0]), "\nATTENTION BATCH FAILED\n" - assert np.array_equal(tst4[0], tst4[2]), "\nATTENTION BATCH FAILED\n" - assert np.array_equal(tst4[1], np.array([[4]])), "\nATTENTION BATCH FAILED\n" - assert np.array_equal(tst4[1], tst4[3]), "\nATTENTION BATCH FAILED\n" + np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst4[0], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst4[0], tst4[2], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst4[1], np.array([[4]]), "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst4[1], tst4[3], "\nATTENTION BATCH FAILED\n") assert len(tst4) == 4, "\nATTENTION BATCH FAILED\n" diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py index 0545181360..cf3d457e31 100644 --- a/tests/ut/python/dataset/test_vocab.py +++ b/tests/ut/python/dataset/test_vocab.py @@ -133,7 +133,9 @@ def test_from_file(): assert test_config("w1 w2 w3 s1 s2 s3", 3, ["s1", "s2", "s3"], False) == [0, 1, 2, 3, 4, 5] # text exception special_words contains duplicate words assert "special_tokens contains duplicate" in test_config("w1", None, ["s1", "s1"], True) - + # test exception when vocab_size is negative + assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True) + assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True) if __name__ == '__main__': test_from_dict_exception() diff --git a/tests/ut/python/dataset/test_zip.py b/tests/ut/python/dataset/test_zip.py index a00a0823d4..ebfab86aff 100644 --- a/tests/ut/python/dataset/test_zip.py +++ b/tests/ut/python/dataset/test_zip.py @@ -252,14 +252,14 @@ def test_zip_exception_06(): if __name__ == '__main__': test_zip_01() - test_zip_02() - test_zip_03() - test_zip_04() - test_zip_05() - test_zip_06() - test_zip_exception_01() - test_zip_exception_02() - test_zip_exception_03() - test_zip_exception_04() - test_zip_exception_05() - test_zip_exception_06() + #test_zip_02() + #test_zip_03() + #test_zip_04() + #test_zip_05() + #test_zip_06() + #test_zip_exception_01() + #test_zip_exception_02() + #test_zip_exception_03() + #test_zip_exception_04() + #test_zip_exception_05() + #test_zip_exception_06() diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py index 11c5735406..65ea55824c 100644 --- a/tests/ut/python/dataset/util.py +++ b/tests/ut/python/dataset/util.py @@ -24,9 +24,6 @@ import numpy as np import mindspore.dataset as ds from mindspore import log as logger -# These are the column names defined in the testTFTestAllTypes dataset -COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", - "col_sint16", "col_sint32", "col_sint64"] # These are list of plot title in different visualize modes PLOT_TITLE_DICT = { 1: ["Original image", "Transformed image"], @@ -59,7 +56,7 @@ def _compare_to_golden(golden_ref_dir, result_dict): """ test_array = np.array(list(result_dict.values())) golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0'] - assert np.array_equal(test_array, golden_array) + np.testing.assert_array_equal(test_array, golden_array) def _compare_to_golden_dict(golden_ref_dir, result_dict): @@ -82,39 +79,6 @@ def _save_json(filename, parameters, result_dict): fout.write(jsbeautifier.beautify(json.dumps(out_dict), options)) -def save_and_check(data, parameters, filename, generate_golden=False): - """ - Save the dataset dictionary and compare (as numpy array) with golden file. - Use create_dict_iterator to access the dataset. - Note: save_and_check() is deprecated; use save_and_check_dict(). - """ - num_iter = 0 - result_dict = {} - for column_name in COLUMNS: - result_dict[column_name] = [] - - for item in data.create_dict_iterator(): # each data is a dictionary - for data_key in list(item.keys()): - if data_key not in result_dict: - result_dict[data_key] = [] - result_dict[data_key].append(item[data_key].tolist()) - num_iter += 1 - - logger.info("Number of data in data1: {}".format(num_iter)) - - cur_dir = os.path.dirname(os.path.realpath(__file__)) - golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename) - if generate_golden: - # Save as the golden result - _save_golden(cur_dir, golden_ref_dir, result_dict) - - _compare_to_golden(golden_ref_dir, result_dict) - - if SAVE_JSON: - # Save result to a json file for inspection - _save_json(filename, parameters, result_dict) - - def save_and_check_dict(data, filename, generate_golden=False): """ Save the dataset dictionary and compare (as dictionary) with golden file. @@ -203,6 +167,29 @@ def save_and_check_tuple(data, parameters, filename, generate_golden=False): _save_json(filename, parameters, result_dict) +def config_get_set_seed(seed_new): + """ + Get and return the original configuration seed value. + Set the new configuration seed value. + """ + seed_original = ds.config.get_seed() + ds.config.set_seed(seed_new) + logger.info("seed: original = {} new = {} ".format(seed_original, seed_new)) + return seed_original + + +def config_get_set_num_parallel_workers(num_parallel_workers_new): + """ + Get and return the original configuration num_parallel_workers value. + Set the new configuration num_parallel_workers value. + """ + num_parallel_workers_original = ds.config.get_num_parallel_workers() + ds.config.set_num_parallel_workers(num_parallel_workers_new) + logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original, + num_parallel_workers_new)) + return num_parallel_workers_original + + def diff_mse(in1, in2): mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean() return mse * 100 @@ -265,36 +252,13 @@ def visualize_image(image_original, image_de, mse=None, image_lib=None): plt.show() -def config_get_set_seed(seed_new): - """ - Get and return the original configuration seed value. - Set the new configuration seed value. - """ - seed_original = ds.config.get_seed() - ds.config.set_seed(seed_new) - logger.info("seed: original = {} new = {} ".format(seed_original, seed_new)) - return seed_original - - -def config_get_set_num_parallel_workers(num_parallel_workers_new): - """ - Get and return the original configuration num_parallel_workers value. - Set the new configuration num_parallel_workers value. - """ - num_parallel_workers_original = ds.config.get_num_parallel_workers() - ds.config.set_num_parallel_workers(num_parallel_workers_new) - logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original, - num_parallel_workers_new)) - return num_parallel_workers_original - - -def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3): +def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3): """ - Take a list of un-augmented and augmented images with "annotation" bounding boxes + Take a list of un-augmented and augmented images with "bbox" bounding boxes Plot images to compare test correct BBox augment functionality :param orig: list of original images and bboxes (without aug) :param aug: list of augmented images and bboxes - :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "annotation" (VOC) + :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "bbox" (VOC) :param plot_rows: number of rows on plot (rows = samples on one plot) :return: None """ @@ -373,7 +337,7 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error): :return: None """ - def add_bad_annotation(img, bboxes, invalid_bbox_type_): + def add_bad_bbox(img, bboxes, invalid_bbox_type_): """ Used to generate erroneous bounding box examples on given img. :param img: image where the bounding boxes are. @@ -402,15 +366,15 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error): try: # map to use selected invalid bounding box type - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=lambda img, bboxes: add_bad_annotation(img, bboxes, invalid_bbox_type)) + data = data.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type)) # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + data = data.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) # Add column for "bbox" for _, _ in enumerate(data.create_dict_iterator()): break except RuntimeError as error: diff --git a/tests/ut/python/dtype/test_list.py b/tests/ut/python/dtype/test_list.py index ffd8c8eae3..c63763e295 100644 --- a/tests/ut/python/dtype/test_list.py +++ b/tests/ut/python/dtype/test_list.py @@ -18,8 +18,10 @@ import numpy as np import pytest import mindspore.nn as nn import mindspore.context as context +import mindspore as ms from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import composite as C from mindspore.common import dtype as mstype from tests.ut.python.ut_filter import non_graph_engine from tests.mindspore_test_framework.mindspore_test import mindspore_test @@ -48,7 +50,7 @@ def test_list_equal(): ret = net(x, y) print(ret.asnumpy()) - assert ret == x + assert np.all(ret.asnumpy() == x.asnumpy()) assert ret.dtype == mstype.int32 assert ret.shape == (6, 8, 10) @@ -70,7 +72,7 @@ def test_list_not_equal(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = [1, 2, 3] net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_list_expansion(): @@ -91,7 +93,7 @@ def test_list_expansion(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = [1, 2, 3] net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_list_append(): @@ -114,7 +116,7 @@ def test_list_append(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = [1, 2, 3] net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_class_member_list_append(): @@ -282,3 +284,26 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) def test_exec(): context.set_context(mode=context.GRAPH_MODE) return test_exec_case + + +def test_grad_make_list(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, idx, x): + return x[idx, :, :] + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, x) diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index f6b60b3d2e..e44f824ce2 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids): # pylint: disable=unused-argument def destroy_group(group): pass + + +# pylint: disable=unused-argument +def set_fusion_strategy_by_idx(): + pass + + +# pylint: disable=unused-argument +def set_fusion_strategy_by_size(): + pass diff --git a/tests/ut/python/ir/test_dtype.py b/tests/ut/python/ir/test_dtype.py index 1523a77ea3..49f834092e 100644 --- a/tests/ut/python/ir/test_dtype.py +++ b/tests/ut/python/ir/test_dtype.py @@ -134,3 +134,11 @@ def test_dtype(): with pytest.raises(NotImplementedError): x = 1.5 dtype.get_py_obj_dtype(type(type(x))) + + +def test_type_equal(): + t1 = (dtype.int32, dtype.int32) + valid_types = [dtype.float16, dtype.float32] + assert t1 not in valid_types + assert dtype.int32 not in valid_types + assert dtype.float32 in valid_types diff --git a/tests/ut/python/ir/test_indexed_slices.py b/tests/ut/python/ir/test_indexed_slices.py deleted file mode 100644 index 36dfe464cb..0000000000 --- a/tests/ut/python/ir/test_indexed_slices.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -@File : test_indexed_slices.py -@Author: -@Date : 2020-06-08 -@Desc : test mindspore indexed_slices's operation -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore.ops import operations as P -from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like -from mindspore.ops.primitive import constexpr -from mindspore.ops._grad.grad_base import bprop_getters -from mindspore import Tensor, IndexedSlices, context -from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.common import dtype as mstype -from mindspore._checkparam import Validator as validator -from mindspore._checkparam import Rel -from mindspore.nn import Optimizer -from mindspore.nn import TrainOneStepCell, WithLossCell - -context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) - -reduce_sum = P.ReduceSum() -unsorted_segment_sum = P.UnsortedSegmentSum() -transpose = P.Transpose() -shape_op = P.Shape() -reshape = P.Reshape() -size_op = P.Size() -invert_permutation = P.InvertPermutation() -logical_and = P.LogicalAnd() - -@constexpr -def _generate_shape_index(out_shape, indices_shape, axis): - out_rank = len(out_shape) - ind_rank = len(indices_shape) - if axis < 0: - axis += out_rank - ind_rank + 1 - perm_part1 = tuple(range(axis, axis + ind_rank)) - index = tuple(range(out_rank)) - perm = perm_part1 + index[:axis] + index[axis + ind_rank:] - return perm - -@constexpr -def _generate_inverse_index(x_shape, axis): - x_rank = len(x_shape) - index = tuple(range(x_rank)) - if axis < 0: - axis += x_rank - perm = index[1:1 + axis] + (0,) + index[1 + axis:] - return perm - -class MySparseGatherV2(P.GatherV2): - """ - For test - """ - -@bprop_getters.register(MySparseGatherV2) -def get_bprop_sparse_gather_v2(self): - """Generate bprop for MySparseGatherV2""" - - def bprop(x, indices, axis, out, dout): - x_shp = shape_op(x) - if axis == 0: - indices_size = (size_op(indices),) - x_tail_shp = x_shp[1:] - values_shape = indices_size + x_tail_shp - values = reshape(dout, values_shape) - indices = reshape(indices, indices_size) - return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis) - if F.rank(dout) == 0: - dout = P.ExpandDims()(dout, -1) - if F.rank(indices) == 0: - indices = P.ExpandDims()(indices, -1) - out_shp = shape_op(dout) - ind_shp = shape_op(indices) - # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) - perm_1 = _generate_shape_index(out_shp, ind_shp, axis) - values_transpose = transpose(dout, perm_1) - params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) - # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) - perm_2 = _generate_inverse_index(x_shp, axis) - params_grad = transpose(params_grad, perm_2) - return params_grad, zeros_like(indices), zeros_like(axis) - - return bprop - -adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") -@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool") -def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param, - m, v, gradient, decay_flag): - return gradient.values() - -@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, - m, v, gradient, decay_flag): - op_mul = P.Mul() - op_square = P.Square() - op_sqrt = P.Sqrt() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() - - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) - - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) - - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - - beta2, op_square(gradient_fp32)) - - update = next_m / (op_sqrt(next_v) + eps) - if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) - - update_with_lr = op_mul(lr, update) - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) - - next_v = F.depend(next_v, F.assign(param, next_param)) - next_v = F.depend(next_v, F.assign(m, next_m)) - next_v = F.depend(next_v, F.assign(v, next_v)) - return next_v - - -def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): - """Check the type of inputs.""" - validator.check_value_type("beta1", beta1, [float], prim_name) - validator.check_value_type("beta2", beta2, [float], prim_name) - validator.check_value_type("eps", eps, [float], prim_name) - validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - - -class AdamWeightDecaySparse(Optimizer): - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(AdamWeightDecaySparse, self).__init__(learning_rate, params) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) - self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) - self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) - self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) - - self.params = self.parameters - self.moments1 = self.params.clone(prefix="adam_m", init='zeros') - self.moments2 = self.params.clone(prefix="adam_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) - self.map = C.Map() - - def construct(self, gradients): - lr = self.get_lr() - updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) - return updated_velocity - - -def test_indexed_slices_make_indexed_slices(): - class MakeIndexedSlices(nn.Cell): - def __init__(self): - super(MakeIndexedSlices, self).__init__() - self.dense_shape = (3, 4) - def construct(self, indices, values): - ret = (IndexedSlices(indices, values, self.dense_shape),) - return ret[0] - indices = Tensor([[0, 0], [1, 2]]) - values = Tensor([1, 2], dtype=ms.float32) - MakeIndexedSlices()(indices, values) - - -def test_indexed_slices_attr(): - class IndexedSlicesGetAttr(nn.Cell): - def __init__(self): - super(IndexedSlicesGetAttr, self).__init__() - self.dense_shape = (3, 4) - def construct(self, indices, values): - x = IndexedSlices(indices, values, self.dense_shape) - return x.values(), x.indices(), x.dense_shape() - indices = Tensor([[0, 0], [1, 2]]) - values = Tensor([1, 2], dtype=ms.float32) - IndexedSlicesGetAttr()(indices, values) - - -def test_indexed_slices_sparse_gatherv2_grad_all(): - grad_all = C.GradOperation('get_all', get_all=True) - class GradWrap(nn.Cell): - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - def construct(self, x, y): - grad = grad_all(self.network)(x, y) - return grad, grad[0], grad[1] - class SparseGatherV2(nn.Cell): - def __init__(self): - super(SparseGatherV2, self).__init__() - self.sparse_gatherv2 = MySparseGatherV2() - self.axis = 0 - def construct(self, params, indices): - return self.sparse_gatherv2(params, indices, self.axis) - params = Tensor(np.ones([3, 1, 2]).astype(np.int32)) - indices = Tensor(np.array([0, 1]).astype(np.int32)) - GradWrap(SparseGatherV2())(params, indices) - - -def test_indexed_slices_sparse_gatherv2_grad_with_pram(): - grad_by_list = C.GradOperation('get_by_list', get_by_list=True) - class GradWrap(nn.Cell): - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) - def construct(self, x): - weights = self.weights - grad = grad_by_list(self.network, weights)(x) - x = grad[0] - return x, x.values(), x.indices(), x.dense_shape() - class SparseGatherV2(nn.Cell): - def __init__(self): - super(SparseGatherV2, self).__init__() - self.sparse_gatherv2 = MySparseGatherV2() - self.axis = 0 - self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params") - def construct(self, indices): - return self.sparse_gatherv2(self.params, indices, self.axis) - indices = Tensor(np.array([0, 1]).astype(np.int32)) - network = GradWrap(SparseGatherV2()) - network(indices) - - -def test_indexed_slices_env_get(): - class Loss(nn.Cell): - def __init__(self): - super(Loss, self).__init__() - def construct(self, base, target): - return base - class NetWithSparseGatherV2(nn.Cell): - def __init__(self): - super(NetWithSparseGatherV2, self).__init__() - self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") - self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") - self.gatherv2 = MySparseGatherV2() - self.axis = 0 - def construct(self, indices): - return self.gatherv2(self.w1, indices, self.axis) * self.w2 - - inputs = Tensor(np.array([0, 1]).astype(np.int32)) - label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) - net = NetWithSparseGatherV2() - net.set_train() - loss = Loss() - optimizer = AdamWeightDecaySparse(net.trainable_params()) - - net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepCell(net_with_loss, optimizer) - train_network(inputs, label) diff --git a/tests/ut/python/ir/test_row_tensor.py b/tests/ut/python/ir/test_row_tensor.py new file mode 100644 index 0000000000..c2097a8316 --- /dev/null +++ b/tests/ut/python/ir/test_row_tensor.py @@ -0,0 +1,453 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +@File : test_row_tensor.py +@Author: +@Date : 2020-06-08 +@Desc : test mindspore row_tensor's operation +""" +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like +from mindspore.ops.primitive import constexpr +from mindspore.ops._grad.grad_base import bprop_getters +from mindspore import Tensor, RowTensor, context +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.nn import Optimizer +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum +from mindspore.train import Model +from ....dataset_mock import MindData + +context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) + +reduce_sum = P.ReduceSum() +unsorted_segment_sum = P.UnsortedSegmentSum() +transpose = P.Transpose() +shape_op = P.Shape() +reshape = P.Reshape() +size_op = P.Size() +invert_permutation = P.InvertPermutation() +logical_and = P.LogicalAnd() + +def get_axis(x): + shape = shape_op(x) + length = F.tuple_len(shape) + perm = F.make_range(0, length) + return perm + +class MSELoss(nn.Cell): + def __init__(self): + super(MSELoss, self).__init__() + self.reduce_sum = P.ReduceSum() + self.square = P.Square() + self.reduce_mean = P.ReduceMean() + + def construct(self, data, label): + diff = data - label + return self.reduce_mean(self.square(diff), get_axis(diff)) + + +class MindDataSet(MindData): + def __init__(self, dataset_types, dataset_shapes): + super(MindDataSet, self).__init__(size=2, batch_size=32, + np_types=dataset_types, + output_shapes=dataset_shapes, + input_indexs=(0, 1)) + def __next__(self): + if self._size < self._iter_num: + raise StopIteration + self._iter_num += 1 + lst = [] + for shape_, type_ in zip(self._output_shapes, self._np_types): + lst.append(Tensor(np.ones(shape_).astype(type_))) + return tuple(lst) + + +@constexpr +def _generate_shape_index(out_shape, indices_shape, axis): + out_rank = len(out_shape) + ind_rank = len(indices_shape) + if axis < 0: + axis += out_rank - ind_rank + 1 + perm_part1 = tuple(range(axis, axis + ind_rank)) + index = tuple(range(out_rank)) + perm = perm_part1 + index[:axis] + index[axis + ind_rank:] + return perm + +@constexpr +def _generate_inverse_index(x_shape, axis): + x_rank = len(x_shape) + index = tuple(range(x_rank)) + if axis < 0: + axis += x_rank + perm = index[1:1 + axis] + (0,) + index[1 + axis:] + return perm + +class MySparseGatherV2(P.GatherV2): + """ + For test + """ + +@bprop_getters.register(MySparseGatherV2) +def get_bprop_sparse_gather_v2(self): + """Generate bprop for MySparseGatherV2""" + + def bprop(x, indices, axis, out, dout): + x_shp = shape_op(x) + if axis == 0: + indices_size = (size_op(indices),) + x_tail_shp = x_shp[1:] + values_shape = indices_size + x_tail_shp + values = reshape(dout, values_shape) + indices = reshape(indices, indices_size) + return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis) + if F.rank(dout) == 0: + dout = P.ExpandDims()(dout, -1) + if F.rank(indices) == 0: + indices = P.ExpandDims()(indices, -1) + out_shp = shape_op(dout) + ind_shp = shape_op(indices) + # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) + perm_1 = _generate_shape_index(out_shp, ind_shp, axis) + values_transpose = transpose(dout, perm_1) + params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) + # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) + perm_2 = _generate_inverse_index(x_shp, axis) + params_grad = transpose(params_grad, perm_2) + return params_grad, zeros_like(indices), zeros_like(axis) + + return bprop + +adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") +@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "RowTensor", "Bool") +def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, + m, v, gradient, decay_flag): + return gradient.values + +@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, + m, v, gradient, decay_flag): + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) + + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta2, op_square(gradient_fp32)) + + update = next_m / (op_sqrt(next_v) + eps) + if decay_flag: + update = update + op_mul(weight_decay_tensor, param_fp32) + + update_with_lr = op_mul(lr, update) + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + + next_v = F.depend(next_v, F.assign(param, next_param)) + next_v = F.depend(next_v, F.assign(m, next_m)) + next_v = F.depend(next_v, F.assign(v, next_v)) + return next_v + + +def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) + + +class AdamWeightDecaySparse(Optimizer): + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, + decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): + super(AdamWeightDecaySparse, self).__init__(learning_rate, params) + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) + + self.params = self.parameters + self.moments1 = self.params.clone(prefix="adam_m", init='zeros') + self.moments2 = self.params.clone(prefix="adam_v", init='zeros') + self.decay_flag = tuple(decay_filter(x) for x in self.params) + self.map = C.Map() + + def construct(self, gradients): + lr = self.get_lr() + updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor), + self.params, self.moments1, self.moments2, gradients, self.decay_flag) + return updated_velocity + + +def test_row_tensor_make_row_tensor(): + class MakeRowTensor(nn.Cell): + def __init__(self): + super(MakeRowTensor, self).__init__() + self.dense_shape = (3, 2) + def construct(self, indices, values): + ret = (RowTensor(indices, values, self.dense_shape),) + return ret[0] + indices = Tensor([1, 2]) + values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) + MakeRowTensor()(indices, values) + + +class RowTensorGetAttr(nn.Cell): + def __init__(self, dense_shape): + super(RowTensorGetAttr, self).__init__() + self.dense_shape = dense_shape + def construct(self, indices, values): + x = RowTensor(indices, values, self.dense_shape) + return x.values, x.indices, x.dense_shape + + +def test_row_tensor_attr(): + indices = Tensor([0]) + values = Tensor([[1, 2]], dtype=ms.float32) + RowTensorGetAttr((3, 2))(indices, values) + + +def test_row_tensor_sparse_gatherv2_grad_all(): + grad_all = C.GradOperation('get_all', get_all=True) + class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + def construct(self, x, y): + grad = grad_all(self.network)(x, y) + return grad[0].indices, grad[0].values, grad[0].dense_shape + class SparseGatherV2(nn.Cell): + def __init__(self): + super(SparseGatherV2, self).__init__() + self.sparse_gatherv2 = MySparseGatherV2() + self.axis = 0 + def construct(self, params, indices): + return self.sparse_gatherv2(params, indices, self.axis) + params = Tensor(np.ones([3, 1, 2]).astype(np.int32)) + indices = Tensor(np.array([0, 1]).astype(np.int32)) + GradWrap(SparseGatherV2())(params, indices) + + +def test_row_tensor_sparse_gatherv2_grad_with_pram(): + grad_by_list = C.GradOperation('get_by_list', get_by_list=True) + class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) + def construct(self, x): + weights = self.weights + grad = grad_by_list(self.network, weights)(x) + x = grad[0] + return x.values, x.indices, x.dense_shape + class SparseGatherV2(nn.Cell): + def __init__(self): + super(SparseGatherV2, self).__init__() + self.sparse_gatherv2 = MySparseGatherV2() + self.axis = 0 + self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params") + def construct(self, indices): + return self.sparse_gatherv2(self.params, indices, self.axis) + indices = Tensor(np.array([0, 1]).astype(np.int32)) + network = GradWrap(SparseGatherV2()) + network(indices) + + +def test_row_tensor_env_get(): + class Loss(nn.Cell): + def __init__(self): + super(Loss, self).__init__() + def construct(self, base, target): + return base + class NetWithSparseGatherV2(nn.Cell): + def __init__(self): + super(NetWithSparseGatherV2, self).__init__() + self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") + self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") + self.gatherv2 = MySparseGatherV2() + self.axis = 0 + def construct(self, indices): + return self.gatherv2(self.w1, indices, self.axis) * self.w2 + + inputs = Tensor(np.array([0, 1]).astype(np.int32)) + label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) + net = NetWithSparseGatherV2() + net.set_train() + loss = Loss() + optimizer = AdamWeightDecaySparse(net.trainable_params()) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + train_network(inputs, label) + + +def test_row_tensor_model_train(): + class Net(nn.Cell): + def __init__(self, in_features, out_features): + super(Net, self).__init__() + self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight") + self.add = P.TensorAdd() + self.cast = P.Cast() + self.flag = True + + def construct(self, inputs, label): + x = self.add(inputs, self.weight) + if self.flag: + x = self.cast(x, mstype.float32) + return x + + dataset_types = (np.float32, np.float32) + dataset_shapes = ((16, 16), (16, 16)) + dataset = MindDataSet(dataset_types, dataset_shapes) + net = Net(16, 16) + net.set_train() + + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False) + + +def test_row_tensor_values_dim_greater_than_dense_shape_dim(): + indices = Tensor(np.array([0, 1], dtype=np.int32)) + values = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) + dense_shape = (3, 4) + with pytest.raises(TypeError): + RowTensorGetAttr(dense_shape)(indices, values) + + +def test_row_tensor_values_dim_less_than_dense_shape_dim(): + indices = Tensor(np.array([0, 1], dtype=np.int32)) + values = Tensor(np.random.randn(2, 4).astype(np.float32)) + dense_shape = (3, 4, 5) + with pytest.raises(TypeError): + RowTensorGetAttr(dense_shape)(indices, values) + + +def test_row_tensor_value_and_dense_shape_illegal(): + indices = Tensor(np.array([0, 1], dtype=np.int32)) + values = Tensor(np.random.randn(2, 4).astype(np.float32)) + dense_shape = (3, 5) + with pytest.raises(TypeError): + RowTensorGetAttr(dense_shape)(indices, values) + + +class RowTensorValuesDouble(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + indices = x.indices + values = x.values * 2 + dense_shape = x.dense_shape + return RowTensor(indices, values, dense_shape) + + +class RowTensorValuesAdd2(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + indices = x.indices + values = x.values + 2 + dense_shape = x.dense_shape + return RowTensor(indices, values, dense_shape) + + +class RowTensorWithControlIf(nn.Cell): + def __init__(self, dense_shape): + super().__init__() + self.op1 = RowTensorValuesDouble() + self.op2 = RowTensorValuesAdd2() + self.dense_shape = dense_shape + + def construct(self, a, b, indices, values): + x = RowTensor(indices, values, self.dense_shape) + if a > b: + x = self.op1(x) + else: + x = self.op2(x) + return x.indices, x.values + + +def test_row_tensor_with_control_flow_if(): + a = Tensor(np.array(0).astype(np.int32)) + b = Tensor(np.array(2).astype(np.int32)) + indices = Tensor(np.array([0, 2]).astype(np.int32)) + values = Tensor(np.ones([2, 2]).astype(np.float32)) + dense_shape = (5, 2) + + net = RowTensorWithControlIf(dense_shape) + net(a, b, indices, values) + + +class EmbeddingLookUpBnNet(nn.Cell): + def __init__(self, param_np, target='CPU'): + super().__init__() + self.param = Parameter(Tensor(param_np), name="w1") + self.embedding_lookup = nn.EmbeddingLookup(target=target) + self.bn = nn.BatchNorm2d(num_features=3) + self.mul = P.Mul() + self.reshape = P.Reshape() + self.relu = nn.PReLU() + + def construct(self, indices): + x = self.embedding_lookup(self.param, indices) + x = self.reshape(x, (2, 3, 2, 2)) + x = self.relu(x) + x = self.bn(x) + return x + + +def test_embedding_lookup_with_mix_precision(): + param_np = np.ones([8, 8]).astype(np.float32) + data = Tensor(np.array([0, 1, 2]).astype(np.int32)) + label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32)) + net = EmbeddingLookUpBnNet(param_np, target='CPU') + + criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') + optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) + optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") + train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2") + train_network.set_train() + for _ in range(2): + train_network(data, label) diff --git a/tests/ut/python/ir/test_sparse_tensor.py b/tests/ut/python/ir/test_sparse_tensor.py new file mode 100644 index 0000000000..76f53f2e13 --- /dev/null +++ b/tests/ut/python/ir/test_sparse_tensor.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +@File : test_sparse_tensor.py +@Author: +@Date : 2020-07-16 +@Desc : test mindspore sparse_tensor's operation +""" +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from mindspore.ops import composite as C +from mindspore import Tensor, SparseTensor, context + +context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) + +grad_op = C.GradOperation('get_all', get_all=True) + +class MakeSparseTensor(nn.Cell): + def __init__(self, dense_shape): + super(MakeSparseTensor, self).__init__() + self.dense_shape = dense_shape + def construct(self, indices, values): + ret = (SparseTensor(indices, values, self.dense_shape),) + return ret[0] + + +def test_sparse_tensor_make_sparse_tensor(): + indices = Tensor([[0, 1], [1, 2]]) + values = Tensor([1, 2], dtype=ms.float32) + MakeSparseTensor((3, 4))(indices, values) + + +def test_sparse_tensor_attr(): + class SparseTensorGetAttr(nn.Cell): + def __init__(self): + super(SparseTensorGetAttr, self).__init__() + self.dense_shape = (3, 4) + def construct(self, indices, values): + x = SparseTensor(indices, values, self.dense_shape) + return x.values, x.indices, x.dense_shape + + indices = Tensor([[0, 1], [1, 2]]) + values = Tensor([1, 2], dtype=ms.float32) + SparseTensorGetAttr()(indices, values) + grad_op(SparseTensorGetAttr())(indices, values) + + +def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim(): + indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32)) + values = Tensor(np.array([100, 200], dtype=np.float32)) + dense_shape = (2, 2) + with pytest.raises(TypeError): + MakeSparseTensor(dense_shape)(indices, values) + + +def test_sparse_tensor_indices_dim_less_than_dense_shape_dim(): + indices = Tensor(np.array([[0, 0], [0, 1]], dtype=np.int32)) + values = Tensor(np.array([100, 200], dtype=np.float32)) + dense_shape = (2, 2, 2) + with pytest.raises(TypeError): + MakeSparseTensor(dense_shape)(indices, values) + + +def test_sparse_tensor_to_tensor(): + class SparseToDenseCell(nn.Cell): + def __init__(self, dense_shape): + super(SparseToDenseCell, self).__init__() + self.dense_shape = dense_shape + self.sparse_to_dense = nn.SparseToDense() + def construct(self, indices, values): + sparse = SparseTensor(indices, values, self.dense_shape) + return self.sparse_to_dense(sparse) + + indices = Tensor([[0, 1], [1, 2]]) + values = Tensor([1, 2], dtype=ms.float32) + dense_shape = (3, 4) + SparseToDenseCell(dense_shape)(indices, values) + grad_op(SparseToDenseCell(dense_shape))(indices, values) diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 72100b2715..762e5b175a 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -225,7 +225,7 @@ def test_div(): @non_graph_engine def test_parameter(): x = Parameter(initializer(1, [1], ms.float32), name="beta1_power") - x.init_data() + x = x.init_data() z = x / 2 print(z) @@ -472,3 +472,7 @@ def test_tensor_operation(): assert np.all(x.asnumpy() == np.ones((3, 3))) with pytest.raises(ValueError): res = x * (2, 3) + res = 5 % x + assert np.all(x.asnumpy() == np.ones((3, 3))) + res = 5 // x + assert np.all(x.asnumpy() == np.ones((3, 3))) diff --git a/tests/ut/python/mindrecord/test_mindrecord_exception.py b/tests/ut/python/mindrecord/test_mindrecord_exception.py index 46c2371b24..e37d9692a4 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_exception.py +++ b/tests/ut/python/mindrecord/test_mindrecord_exception.py @@ -15,6 +15,8 @@ """test mindrecord exception""" import os import pytest + +import numpy as np from utils import get_data from mindspore import log as logger @@ -341,3 +343,532 @@ def test_mindpage_filename_not_exist(fixture_cv_file): _ = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] + +def test_invalid_schema(): + mindrecord_file_name = "test.mindrecord" + writer = FileWriter(mindrecord_file_name) + + # string => str + schema = {"file_name": {"type": "str"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # int32 => np.int32 + schema = {"file_name": {"type": "string"}, + "label": {"type": "np.int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # float64 => np.float64 + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "np.float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # int64 => int8 + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int8", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # int64 => uint64 + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "uint64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # bytes => byte + schema = {"file_name": {"type": "strint"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "byte"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # float32 => float3 + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float3", "shape": [2, 88]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # string with shape + schema = {"file_name": {"type": "string", "shape": [-1]}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + + # bytes with shape + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes", "shape": [100]}} + with pytest.raises(Exception, match="Schema format is error"): + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + +def test_write_with_invalid_data(): + mindrecord_file_name = "test.mindrecord" + + # field: file_name => filename + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"filename": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"filename": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"filename": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"filename": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"filename": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"filename": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # field: mask => masks + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "masks": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, "score": 5.4, "masks": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, "score": 6.4, "masks": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, "score": 8.1, "masks": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, "score": 7.7, "masks": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, "score": 9.4, "masks": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # field: data => image + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "image": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "image": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "image": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "image": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "image": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "image": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # field: label => lable + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "lable": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "lable": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "lable": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "lable": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "lable": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "lable": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # field: score => scores + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": 43, "scores": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, "scores": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, "scores": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, "scores": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, "scores": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, "scores": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # string type with int value + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": 1, "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": 2, "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": 3, "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": 4, "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": 5, "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": 6, "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # field with int64 type, but the real data is string + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": "cat", "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": "dog", "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.6], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": "bird", "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": "mouse", "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": "tiger", "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": "lion", "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # bytes field is string + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": "image bytes abc"}, + {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": "image bytes def"}, + {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": "image bytes ghi"}, + {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": "image bytes jkl"}, + {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": "image bytes mno"}, + {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": "image bytes pqr"} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # field is not numpy type + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": [3, 6, 9], + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": [1, 4, 7], + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": [7, 6, 3], + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": [2, 8, 0], + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": [3, 1, 2], + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": [7, 6, 7], + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # not enough field + with pytest.raises(Exception, match="Failed to write dataset"): + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8')}, + {"file_name": "002.jpg", "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8')}, + {"file_name": "003.jpg", "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8')}, + {"file_name": "004.jpg", "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8')}, + {"file_name": "005.jpg", "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # more field is ok + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") + + data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), + "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), + "data": bytes("image bytes abc", encoding='UTF-8'), "test": 0}, + {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), + "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), + "data": bytes("image bytes def", encoding='UTF-8'), "test": 1}, + {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), + "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), + "data": bytes("image bytes ghi", encoding='UTF-8'), "test": 2}, + {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), + "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), + "data": bytes("image bytes jkl", encoding='UTF-8'), "test": 3}, + {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), + "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), + "data": bytes("image bytes mno", encoding='UTF-8'), "test": 4}, + {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), + "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), + "data": bytes("image bytes pqr", encoding='UTF-8'), "test": 5} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"}, + "score": {"type": "float64"}, + "mask": {"type": "int64", "shape": [-1]}, + "segments": {"type": "float32", "shape": [2, 2]}, + "data": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + remove_one_file(mindrecord_file_name) + remove_one_file(mindrecord_file_name + ".db") diff --git a/tests/ut/python/model/test_vgg.py b/tests/ut/python/model/test_vgg.py index ed8a217e51..fe7f3a76c1 100644 --- a/tests/ut/python/model/test_vgg.py +++ b/tests/ut/python/model/test_vgg.py @@ -17,13 +17,14 @@ import numpy as np import pytest from mindspore import Tensor -from model_zoo.vgg16.src.vgg import vgg16 +from model_zoo.official.cv.vgg16.src.vgg import vgg16 +from model_zoo.official.cv.vgg16.src.config import cifar_cfg as cfg from ..ut_filter import non_graph_engine @non_graph_engine def test_vgg16(): inputs = Tensor(np.random.rand(1, 3, 112, 112).astype(np.float32)) - net = vgg16() + net = vgg16(args=cfg) with pytest.raises(ValueError): print(net.construct(inputs)) diff --git a/tests/ut/python/nn/bijector/test_exp.py b/tests/ut/python/nn/bijector/test_exp.py new file mode 100644 index 0000000000..13e3e09a34 --- /dev/null +++ b/tests/ut/python/nn/bijector/test_exp.py @@ -0,0 +1,71 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for exp""" +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +def test_init(): + b = msb.Exp() + assert isinstance(b, msb.Bijector) + b = msb.Exp(1.0) + assert isinstance(b, msb.Bijector) + +class Net(nn.Cell): + """ + Test class: forward and inverse pass of bijector. + """ + def __init__(self): + super(Net, self).__init__() + self.b1 = msb.Exp() + self.b2 = msb.Exp() + + def construct(self, x_): + forward = self.b1.forward(x_) + inverse = self.b1.inverse(forward) + return x_ - inverse + +def test1(): + """ + Test forward and inverse pass of exp bijector. + """ + net = Net() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + +class Jacobian(nn.Cell): + """ + Test class: forward and inverse pass of bijector. + """ + def __init__(self): + super(Jacobian, self).__init__() + self.b1 = msb.Exp() + self.b2 = msb.Exp() + + def construct(self, x_): + ans1 = self.b1.forward_log_jacobian(x_) + ans2 = self.b1.inverse_log_jacobian(x_) + return ans1 + ans2 + +def test2(): + """ + Test jacobians of exp bijector. + """ + net = Jacobian() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/bijector/test_power_transform.py b/tests/ut/python/nn/bijector/test_power_transform.py new file mode 100644 index 0000000000..50ea5dbd44 --- /dev/null +++ b/tests/ut/python/nn/bijector/test_power_transform.py @@ -0,0 +1,73 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for powertransform""" +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +def test_init(): + b = msb.PowerTransform() + assert isinstance(b, msb.Bijector) + b = msb.PowerTransform(1) + assert isinstance(b, msb.Bijector) + +class Net(nn.Cell): + """ + Test class: forward and inverse pass of bijector. + """ + def __init__(self): + super(Net, self).__init__() + self.b1 = msb.PowerTransform(power=0) + self.b2 = msb.PowerTransform() + + def construct(self, x_): + ans1 = self.b1.inverse(self.b1.forward(x_)) + ans2 = self.b2.inverse(self.b2.forward(x_)) + return ans1 - ans2 + +def test1(): + """ + Test forward and inverse pass of powertransform bijector. + """ + net = Net() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + +class Jacobian(nn.Cell): + """ + Test class: forward and inverse pass of bijector. + """ + def __init__(self): + super(Jacobian, self).__init__() + self.b1 = msb.PowerTransform(power=0) + self.b2 = msb.PowerTransform() + + def construct(self, x_): + ans1 = self.b1.forward_log_jacobian(x_) + ans2 = self.b2.forward_log_jacobian(x_) + ans3 = self.b1.inverse_log_jacobian(x_) + ans4 = self.b2.inverse_log_jacobian(x_) + return ans1 - ans2 + ans3 - ans4 + +def test2(): + """ + Test jacobians of powertransform bijector. + """ + net = Jacobian() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py new file mode 100644 index 0000000000..d34455dbe4 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -0,0 +1,191 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.Bernoulli. +""" +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + +def test_arguments(): + """ + Args passing during initialization. + """ + b = msd.Bernoulli() + assert isinstance(b, msd.Distribution) + b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) + assert isinstance(b, msd.Distribution) + +def test_prob(): + """ + Invalid probability. + """ + with pytest.raises(ValueError): + msd.Bernoulli([-0.1], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Bernoulli([1.1], dtype=dtype.int32) + +class BernoulliProb(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliProb, self).__init__() + self.b = msd.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self, value): + prob = self.b.prob(value) + log_prob = self.b.log_prob(value) + cdf = self.b.cdf(value) + log_cdf = self.b.log_cdf(value) + sf = self.b.survival_function(value) + log_sf = self.b.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_bernoulli_prob(): + """ + Test probability functions: passing value through construct. + """ + net = BernoulliProb() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class BernoulliProb1(nn.Cell): + """ + Bernoulli distribution: initialize without probs. + """ + def __init__(self): + super(BernoulliProb1, self).__init__() + self.b = msd.Bernoulli(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.b.prob(value, probs) + log_prob = self.b.log_prob(value, probs) + cdf = self.b.cdf(value, probs) + log_cdf = self.b.log_cdf(value, probs) + sf = self.b.survival_function(value, probs) + log_sf = self.b.log_survival(value, probs) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_bernoulli_prob1(): + """ + Test probability functions: passing value/probs through construct. + """ + net = BernoulliProb1() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + +class BernoulliKl(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + """ + def __init__(self): + super(BernoulliKl, self).__init__() + self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) + self.b2 = msd.Bernoulli(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + kl1 = self.b1.kl_loss('Bernoulli', probs_b) + kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss function. + """ + ber_net = BernoulliKl() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = ber_net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class BernoulliCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Bernoulli distribution. + """ + def __init__(self): + super(BernoulliCrossEntropy, self).__init__() + self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) + self.b2 = msd.Bernoulli(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + h1 = self.b1.cross_entropy('Bernoulli', probs_b) + h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Bernoulli distributions. + """ + net = BernoulliCrossEntropy() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class BernoulliBasics(nn.Cell): + """ + Test class: basic mean/sd/var/mode/entropy function. + """ + def __init__(self): + super(BernoulliBasics, self).__init__() + self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) + + def construct(self): + mean = self.b.mean() + sd = self.b.sd() + var = self.b.var() + mode = self.b.mode() + entropy = self.b.entropy() + return mean + sd + var + mode + entropy + +def test_bascis(): + """ + Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. + """ + net = BernoulliBasics() + ans = net() + assert isinstance(ans, Tensor) + +class BernoulliConstruct(nn.Cell): + """ + Bernoulli distribution: going through construct. + """ + def __init__(self): + super(BernoulliConstruct, self).__init__() + self.b = msd.Bernoulli(0.5, dtype=dtype.int32) + self.b1 = msd.Bernoulli(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.b('prob', value) + prob1 = self.b('prob', value, probs) + prob2 = self.b1('prob', value, probs) + return prob + prob1 + prob2 + +def test_bernoulli_construct(): + """ + Test probability function going through construct. + """ + net = BernoulliConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_exponential.py b/tests/ut/python/nn/distribution/test_exponential.py new file mode 100644 index 0000000000..43aa428277 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_exponential.py @@ -0,0 +1,193 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.Exponential. +""" +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + + +def test_arguments(): + """ + Args passing during initialization. + """ + e = msd.Exponential() + assert isinstance(e, msd.Distribution) + e = msd.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32) + assert isinstance(e, msd.Distribution) + +def test_rate(): + """ + Invalid rate. + """ + with pytest.raises(ValueError): + msd.Exponential([-0.1], dtype=dtype.float32) + with pytest.raises(ValueError): + msd.Exponential([0.0], dtype=dtype.float32) + +class ExponentialProb(nn.Cell): + """ + Exponential distribution: initialize with rate. + """ + def __init__(self): + super(ExponentialProb, self).__init__() + self.e = msd.Exponential(0.5, dtype=dtype.float32) + + def construct(self, value): + prob = self.e.prob(value) + log_prob = self.e.log_prob(value) + cdf = self.e.cdf(value) + log_cdf = self.e.log_cdf(value) + sf = self.e.survival_function(value) + log_sf = self.e.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_exponential_prob(): + """ + Test probability functions: passing value through construct. + """ + net = ExponentialProb() + value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class ExponentialProb1(nn.Cell): + """ + Exponential distribution: initialize without rate. + """ + def __init__(self): + super(ExponentialProb1, self).__init__() + self.e = msd.Exponential(dtype=dtype.float32) + + def construct(self, value, rate): + prob = self.e.prob(value, rate) + log_prob = self.e.log_prob(value, rate) + cdf = self.e.cdf(value, rate) + log_cdf = self.e.log_cdf(value, rate) + sf = self.e.survival_function(value, rate) + log_sf = self.e.log_survival(value, rate) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_exponential_prob1(): + """ + Test probability functions: passing value/rate through construct. + """ + net = ExponentialProb1() + value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32) + rate = Tensor([0.5], dtype=dtype.float32) + ans = net(value, rate) + assert isinstance(ans, Tensor) + +class ExponentialKl(nn.Cell): + """ + Test class: kl_loss between Exponential distributions. + """ + def __init__(self): + super(ExponentialKl, self).__init__() + self.e1 = msd.Exponential(0.7, dtype=dtype.float32) + self.e2 = msd.Exponential(dtype=dtype.float32) + + def construct(self, rate_b, rate_a): + kl1 = self.e1.kl_loss('Exponential', rate_b) + kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss function. + """ + net = ExponentialKl() + rate_b = Tensor([0.3], dtype=dtype.float32) + rate_a = Tensor([0.7], dtype=dtype.float32) + ans = net(rate_b, rate_a) + assert isinstance(ans, Tensor) + +class ExponentialCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Exponential distribution. + """ + def __init__(self): + super(ExponentialCrossEntropy, self).__init__() + self.e1 = msd.Exponential(0.3, dtype=dtype.float32) + self.e2 = msd.Exponential(dtype=dtype.float32) + + def construct(self, rate_b, rate_a): + h1 = self.e1.cross_entropy('Exponential', rate_b) + h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Exponential distributions. + """ + net = ExponentialCrossEntropy() + rate_b = Tensor([0.3], dtype=dtype.float32) + rate_a = Tensor([0.7], dtype=dtype.float32) + ans = net(rate_b, rate_a) + assert isinstance(ans, Tensor) + +class ExponentialBasics(nn.Cell): + """ + Test class: basic mean/sd/mode/entropy function. + """ + def __init__(self): + super(ExponentialBasics, self).__init__() + self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32) + + def construct(self): + mean = self.e.mean() + sd = self.e.sd() + var = self.e.var() + mode = self.e.mode() + entropy = self.e.entropy() + return mean + sd + var + mode + entropy + +def test_bascis(): + """ + Test mean/sd/var/mode/entropy functionality of Exponential distribution. + """ + net = ExponentialBasics() + ans = net() + assert isinstance(ans, Tensor) + + +class ExpConstruct(nn.Cell): + """ + Exponential distribution: going through construct. + """ + def __init__(self): + super(ExpConstruct, self).__init__() + self.e = msd.Exponential(0.5, dtype=dtype.float32) + self.e1 = msd.Exponential(dtype=dtype.float32) + + def construct(self, value, rate): + prob = self.e('prob', value) + prob1 = self.e('prob', value, rate) + prob2 = self.e1('prob', value, rate) + return prob + prob1 + prob2 + +def test_exp_construct(): + """ + Test probability function going through construct. + """ + net = ExpConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_geometric.py b/tests/ut/python/nn/distribution/test_geometric.py new file mode 100644 index 0000000000..b705aae781 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_geometric.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.Geometric. +""" +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + + +def test_arguments(): + """ + Args passing during initialization. + """ + g = msd.Geometric() + assert isinstance(g, msd.Distribution) + g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) + assert isinstance(g, msd.Distribution) + +def test_prob(): + """ + Invalid probability. + """ + with pytest.raises(ValueError): + msd.Geometric([-0.1], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Geometric([1.1], dtype=dtype.int32) + +class GeometricProb(nn.Cell): + """ + Geometric distribution: initialize with probs. + """ + def __init__(self): + super(GeometricProb, self).__init__() + self.g = msd.Geometric(0.5, dtype=dtype.int32) + + def construct(self, value): + prob = self.g.prob(value) + log_prob = self.g.log_prob(value) + cdf = self.g.cdf(value) + log_cdf = self.g.log_cdf(value) + sf = self.g.survival_function(value) + log_sf = self.g.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_geometric_prob(): + """ + Test probability functions: passing value through construct. + """ + net = GeometricProb() + value = Tensor([3, 4, 5, 6, 7], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class GeometricProb1(nn.Cell): + """ + Geometric distribution: initialize without probs. + """ + def __init__(self): + super(GeometricProb1, self).__init__() + self.g = msd.Geometric(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.g.prob(value, probs) + log_prob = self.g.log_prob(value, probs) + cdf = self.g.cdf(value, probs) + log_cdf = self.g.log_cdf(value, probs) + sf = self.g.survival_function(value, probs) + log_sf = self.g.log_survival(value, probs) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_geometric_prob1(): + """ + Test probability functions: passing value/probs through construct. + """ + net = GeometricProb1() + value = Tensor([3, 4, 5, 6, 7], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + + +class GeometricKl(nn.Cell): + """ + Test class: kl_loss between Geometric distributions. + """ + def __init__(self): + super(GeometricKl, self).__init__() + self.g1 = msd.Geometric(0.7, dtype=dtype.int32) + self.g2 = msd.Geometric(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + kl1 = self.g1.kl_loss('Geometric', probs_b) + kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss function. + """ + ber_net = GeometricKl() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = ber_net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class GeometricCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Geometric distribution. + """ + def __init__(self): + super(GeometricCrossEntropy, self).__init__() + self.g1 = msd.Geometric(0.3, dtype=dtype.int32) + self.g2 = msd.Geometric(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + h1 = self.g1.cross_entropy('Geometric', probs_b) + h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Geometric distributions. + """ + net = GeometricCrossEntropy() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class GeometricBasics(nn.Cell): + """ + Test class: basic mean/sd/mode/entropy function. + """ + def __init__(self): + super(GeometricBasics, self).__init__() + self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32) + + def construct(self): + mean = self.g.mean() + sd = self.g.sd() + var = self.g.var() + mode = self.g.mode() + entropy = self.g.entropy() + return mean + sd + var + mode + entropy + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of Geometric distribution. + """ + net = GeometricBasics() + ans = net() + assert isinstance(ans, Tensor) + + +class GeoConstruct(nn.Cell): + """ + Bernoulli distribution: going through construct. + """ + def __init__(self): + super(GeoConstruct, self).__init__() + self.g = msd.Geometric(0.5, dtype=dtype.int32) + self.g1 = msd.Geometric(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.g('prob', value) + prob1 = self.g('prob', value, probs) + prob2 = self.g1('prob', value, probs) + return prob + prob1 + prob2 + +def test_geo_construct(): + """ + Test probability function going through construct. + """ + net = GeoConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_normal.py b/tests/ut/python/nn/distribution/test_normal.py new file mode 100644 index 0000000000..f569aa67a5 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_normal.py @@ -0,0 +1,199 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.Normal. +""" +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + +def test_normal_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + msd.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + + +def test_arguments(): + """ + args passing during initialization. + """ + n = msd.Normal() + assert isinstance(n, msd.Distribution) + n = msd.Normal([3.0], [4.0], dtype=dtype.float32) + assert isinstance(n, msd.Distribution) + + +class NormalProb(nn.Cell): + """ + Normal distribution: initialize with mean/sd. + """ + def __init__(self): + super(NormalProb, self).__init__() + self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + prob = self.normal.prob(value) + log_prob = self.normal.log_prob(value) + cdf = self.normal.cdf(value) + log_cdf = self.normal.log_cdf(value) + sf = self.normal.survival_function(value) + log_sf = self.normal.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_normal_prob(): + """ + Test probability functions: passing value through construct. + """ + net = NormalProb() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + + +class NormalProb1(nn.Cell): + """ + Normal distribution: initialize without mean/sd. + """ + def __init__(self): + super(NormalProb1, self).__init__() + self.normal = msd.Normal() + + def construct(self, value, mean, sd): + prob = self.normal.prob(value, mean, sd) + log_prob = self.normal.log_prob(value, mean, sd) + cdf = self.normal.cdf(value, mean, sd) + log_cdf = self.normal.log_cdf(value, mean, sd) + sf = self.normal.survival_function(value, mean, sd) + log_sf = self.normal.log_survival(value, mean, sd) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_normal_prob1(): + """ + Test probability functions: passing mean/sd, value through construct. + """ + net = NormalProb1() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + ans = net(value, mean, sd) + assert isinstance(ans, Tensor) + +class NormalKl(nn.Cell): + """ + Test class: kl_loss of Normal distribution. + """ + def __init__(self): + super(NormalKl, self).__init__() + self.n1 = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.n2 = msd.Normal(dtype=dtype.float32) + + def construct(self, mean_b, sd_b, mean_a, sd_a): + kl1 = self.n1.kl_loss('Normal', mean_b, sd_b) + kl2 = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss. + """ + net = NormalKl() + mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(mean_b, sd_b, mean_a, sd_a) + assert isinstance(ans, Tensor) + +class NormalCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Normal distribution. + """ + def __init__(self): + super(NormalCrossEntropy, self).__init__() + self.n1 = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.n2 = msd.Normal(dtype=dtype.float32) + + def construct(self, mean_b, sd_b, mean_a, sd_a): + h1 = self.n1.cross_entropy('Normal', mean_b, sd_b) + h2 = self.n2.cross_entropy('Normal', mean_b, sd_b, mean_a, sd_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross entropy between Normal distributions. + """ + net = NormalCrossEntropy() + mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(mean_b, sd_b, mean_a, sd_a) + assert isinstance(ans, Tensor) + +class NormalBasics(nn.Cell): + """ + Test class: basic mean/sd function. + """ + def __init__(self): + super(NormalBasics, self).__init__() + self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self): + mean = self.n.mean() + sd = self.n.sd() + mode = self.n.mode() + entropy = self.n.entropy() + return mean + sd + mode + entropy + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of Normal. + """ + net = NormalBasics() + ans = net() + assert isinstance(ans, Tensor) + + +class NormalConstruct(nn.Cell): + """ + Normal distribution: going through construct. + """ + def __init__(self): + super(NormalConstruct, self).__init__() + self.normal = msd.Normal(3.0, 4.0) + self.normal1 = msd.Normal() + + def construct(self, value, mean, sd): + prob = self.normal('prob', value) + prob1 = self.normal('prob', value, mean, sd) + prob2 = self.normal1('prob', value, mean, sd) + return prob + prob1 + prob2 + +def test_normal_construct(): + """ + Test probability function going through construct. + """ + net = NormalConstruct() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + ans = net(value, mean, sd) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_uniform.py b/tests/ut/python/nn/distribution/test_uniform.py new file mode 100644 index 0000000000..a631998e83 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_uniform.py @@ -0,0 +1,208 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.Uniform. +""" +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + +def test_uniform_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + + +def test_arguments(): + """ + Args passing during initialization. + """ + u = msd.Uniform() + assert isinstance(u, msd.Distribution) + u = msd.Uniform([3.0], [4.0], dtype=dtype.float32) + assert isinstance(u, msd.Distribution) + + +def test_invalid_range(): + """ + Test range of uniform distribution. + """ + with pytest.raises(ValueError): + msd.Uniform(0.0, 0.0, dtype=dtype.float32) + with pytest.raises(ValueError): + msd.Uniform(1.0, 0.0, dtype=dtype.float32) + + +class UniformProb(nn.Cell): + """ + Uniform distribution: initialize with low/high. + """ + def __init__(self): + super(UniformProb, self).__init__() + self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + prob = self.u.prob(value) + log_prob = self.u.log_prob(value) + cdf = self.u.cdf(value) + log_cdf = self.u.log_cdf(value) + sf = self.u.survival_function(value) + log_sf = self.u.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_uniform_prob(): + """ + Test probability functions: passing value through construct. + """ + net = UniformProb() + value = Tensor([3.1, 3.2, 3.3, 3.4], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class UniformProb1(nn.Cell): + """ + Uniform distribution: initialize without low/high. + """ + def __init__(self): + super(UniformProb1, self).__init__() + self.u = msd.Uniform(dtype=dtype.float32) + + def construct(self, value, low, high): + prob = self.u.prob(value, low, high) + log_prob = self.u.log_prob(value, low, high) + cdf = self.u.cdf(value, low, high) + log_cdf = self.u.log_cdf(value, low, high) + sf = self.u.survival_function(value, low, high) + log_sf = self.u.log_survival(value, low, high) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_uniform_prob1(): + """ + Test probability functions: passing low/high, value through construct. + """ + net = UniformProb1() + value = Tensor([0.1, 0.2, 0.3, 0.9], dtype=dtype.float32) + low = Tensor([0.0], dtype=dtype.float32) + high = Tensor([1.0], dtype=dtype.float32) + ans = net(value, low, high) + assert isinstance(ans, Tensor) + +class UniformKl(nn.Cell): + """ + Test class: kl_loss of Uniform distribution. + """ + def __init__(self): + super(UniformKl, self).__init__() + self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.u2 = msd.Uniform(dtype=dtype.float32) + + def construct(self, low_b, high_b, low_a, high_a): + kl1 = self.u1.kl_loss('Uniform', low_b, high_b) + kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss. + """ + net = UniformKl() + low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) + high_b = Tensor(np.array([5.0]).astype(np.float32), dtype=dtype.float32) + low_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + high_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(low_b, high_b, low_a, high_a) + assert isinstance(ans, Tensor) + +class UniformCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Uniform distribution. + """ + def __init__(self): + super(UniformCrossEntropy, self).__init__() + self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.u2 = msd.Uniform(dtype=dtype.float32) + + def construct(self, low_b, high_b, low_a, high_a): + h1 = self.u1.cross_entropy('Uniform', low_b, high_b) + h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Unifrom distributions. + """ + net = UniformCrossEntropy() + low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) + high_b = Tensor(np.array([5.0]).astype(np.float32), dtype=dtype.float32) + low_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + high_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(low_b, high_b, low_a, high_a) + assert isinstance(ans, Tensor) + +class UniformBasics(nn.Cell): + """ + Test class: basic mean/sd/var/mode/entropy function. + """ + def __init__(self): + super(UniformBasics, self).__init__() + self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) + + def construct(self): + mean = self.u.mean() + sd = self.u.sd() + var = self.u.var() + entropy = self.u.entropy() + return mean + sd + var + entropy + +def test_bascis(): + """ + Test mean/sd/var/mode/entropy functionality of Uniform. + """ + net = UniformBasics() + ans = net() + assert isinstance(ans, Tensor) + + +class UniConstruct(nn.Cell): + """ + Unifrom distribution: going through construct. + """ + def __init__(self): + super(UniConstruct, self).__init__() + self.u = msd.Uniform(-4.0, 4.0) + self.u1 = msd.Uniform() + + def construct(self, value, low, high): + prob = self.u('prob', value) + prob1 = self.u('prob', value, low, high) + prob2 = self.u1('prob', value, low, high) + return prob + prob1 + prob2 + +def test_uniform_construct(): + """ + Test probability function going through construct. + """ + net = UniConstruct() + value = Tensor([-5.0, 0.0, 1.0, 5.0], dtype=dtype.float32) + low = Tensor([-1.0], dtype=dtype.float32) + high = Tensor([1.0], dtype=dtype.float32) + ans = net(value, low, high) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index 03a73893c5..bebbc00880 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -20,8 +20,10 @@ import mindspore.nn as nn from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR +from mindspore.nn.optim import Adam, AdamWeightDecay from mindspore.ops import operations as P +import mindspore.nn.learning_rate_schedule as lr_schedules +from mindspore.nn.dynamic_lr import polynomial_decay_lr context.set_context(enable_sparse=True) @@ -112,6 +114,62 @@ def test_sparse_adam_compile(): _executor.compile(train_network, indices, label) +def test_adam_group1(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + poly_decay_lr = polynomial_decay_lr(0.01, 0.0001, total_step=10, step_per_epoch=1, decay_epoch=3, power=1.0) + + group_params = [{'params': [all_params[0]], 'lr': poly_decay_lr, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.Adam(group_params, learning_rate=0.1) + + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_adam_group2(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0) + group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.Adam(group_params, learning_rate=schedule_lr) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_adamweightdecay_group(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0) + group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.AdamWeightDecay(group_params, learning_rate=schedule_lr) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + def test_AdamWeightDecay_beta1(): net = Net() print("**********", net.get_parameters()) @@ -131,20 +189,6 @@ def test_AdamWeightDecay_e(): AdamWeightDecay(net.get_parameters(), eps=-0.1, learning_rate=0.1) -def test_AdamWeightDecayDynamicLR(): - """ test_AdamWeightDecayDynamicLR """ - inputs = Tensor(np.ones([1, 64]).astype(np.float32)) - label = Tensor(np.zeros([1, 10]).astype(np.float32)) - net = Net() - net.set_train() - loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1) - - net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepCell(net_with_loss, optimizer) - _executor.compile(train_network, inputs, label) - - def test_adam_mindspore_with_empty_params(): net = nn.Flatten() with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): diff --git a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py deleted file mode 100644 index 23aad24c47..0000000000 --- a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test adam """ -import numpy as np - -import mindspore.nn as nn -from mindspore import Tensor, Parameter, context -from mindspore.common.api import _executor -from mindspore.common import dtype as mstype -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Optimizer -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore._checkparam import Validator as validator -from mindspore._checkparam import Rel - -context.set_context(enable_sparse=True) - -adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") -@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): - op_mul = P.Mul() - op_square = P.Square() - op_sqrt = P.Sqrt() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() - - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) - - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) - - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - - beta2, op_square(gradient_fp32)) - - update = next_m / (op_sqrt(next_v) + eps) - if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) - - update_with_lr = op_mul(lr, update) - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) - - next_v = F.depend(next_v, F.assign(param, next_param)) - next_v = F.depend(next_v, F.assign(m, next_m)) - next_v = F.depend(next_v, F.assign(v, next_v)) - return next_v - - -@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tuple", "Bool") -def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): - return gradient[2][2] - -def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): - """Check the type of inputs.""" - validator.check_value_type("beta1", beta1, [float], prim_name) - validator.check_value_type("beta2", beta2, [float], prim_name) - validator.check_value_type("eps", eps, [float], prim_name) - validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - - -class AdamWeightDecaySparse(Optimizer): - """ - Implements Adam algorithm weight decay fix. - - Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. Default: 1e-3. - beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. - Should be in range (0.0, 1.0). - beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. - Should be in range (0.0, 1.0). - eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. - Should be greater than 0. - weight_decay (float): Weight decay (L2 penalty). Default: 0.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. - - Inputs: - - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`, - and might be in sparse format. - - Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. - - Examples: - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - """ - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(AdamWeightDecaySparse, self).__init__(learning_rate, params) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) - self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) - self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) - self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) - - self.params = self.parameters - self.moments1 = self.params.clone(prefix="adam_m", init='zeros') - self.moments2 = self.params.clone(prefix="adam_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) - - self.map = C.Map() - - def construct(self, gradients): - lr = self.get_lr() - updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) - - return updated_velocity - - -def test_AdamWeightDecaySparse(): - """ test_AdamWeightDecaySparse """ - context.set_context(mode=context.GRAPH_MODE) - class Loss(nn.Cell): - def __init__(self): - super(Loss, self).__init__() - def construct(self, base, target): - return base - class NetWithSparseGatherV2(nn.Cell): - def __init__(self): - super(NetWithSparseGatherV2, self).__init__() - self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") - self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") - self.gatherv2 = P.SparseGatherV2() - self.axis = 0 - def construct(self, indices): - return self.gatherv2(self.w1, indices, self.axis) * self.w2 - - inputs = Tensor(np.array([0, 1]).astype(np.int32)) - label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) - net = NetWithSparseGatherV2() - net.set_train() - loss = Loss() - optimizer = AdamWeightDecaySparse(net.trainable_params()) - - net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepCell(net_with_loss, optimizer) - _executor.compile(train_network, inputs, label) diff --git a/tests/ut/python/nn/optim/test_lamb.py b/tests/ut/python/nn/optim/test_lamb.py index 4d229f0837..b2963fc950 100644 --- a/tests/ut/python/nn/optim/test_lamb.py +++ b/tests/ut/python/nn/optim/test_lamb.py @@ -14,7 +14,6 @@ # ============================================================================ """ test lamb """ import numpy as np -import pytest import mindspore.nn as nn from mindspore import Tensor, Parameter @@ -22,6 +21,27 @@ from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Lamb from mindspore.ops import operations as P +import mindspore.common.dtype as mstype +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR + + +class LambLearningRate(LearningRateSchedule): + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(LambLearningRate, self).__init__() + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr class Net(nn.Cell): @@ -51,27 +71,49 @@ class NetWithoutWeight(nn.Cell): return x -def test_lamb_compile(): +def test_lamb_compile_dynamic_lr(): """ test_Lamb_compile """ inputs = Tensor(np.ones([1, 64]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32)) net = Net() net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = Lamb(net.trainable_params(), decay_steps=10) + warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0) + optimizer = Lamb(net.trainable_params(), warmup_decay_lr) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) -def test_lamb_error(): +def test_lamb_compile(): + """ test_Lamb_compile """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) net = Net() - with pytest.raises(TypeError): - Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0) + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() - with pytest.raises(TypeError): - Lamb(net.get_parameters(), decay_steps=1.0) + optimizer = Lamb(net.trainable_params(), 0.02, 0.9) - with pytest.raises(ValueError): - Lamb(net.get_parameters(), decay_steps=0) + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_lamb_group(): + """ test_Lamb_group_compile """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0) + all_params = net.trainable_params() + group_params = [{'params': [all_params[0]], 'lr': warmup_decay_lr, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = Lamb(group_params, 0.02) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) diff --git a/tests/ut/python/nn/optim/test_optimizer.py b/tests/ut/python/nn/optim/test_optimizer.py index 70b79e97d7..32d9c5b4fe 100644 --- a/tests/ut/python/nn/optim/test_optimizer.py +++ b/tests/ut/python/nn/optim/test_optimizer.py @@ -18,7 +18,7 @@ import pytest from mindspore import Tensor from mindspore.common.parameter import Parameter -from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR +from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay class IterableObjc: @@ -81,10 +81,6 @@ class TestNullParam(): with pytest.raises(ValueError): AdamWeightDecay(None) - def test_AdamWeightDecayDynamicLR_init(self): - with pytest.raises(ValueError): - AdamWeightDecayDynamicLR(None, 10) - def test_Sgd_init(self): with pytest.raises(ValueError): SGD(None) @@ -101,10 +97,6 @@ class TestUnsupportParam(): with pytest.raises(TypeError): AdamWeightDecay(9) - def test_AdamWeightDecayDynamicLR_init(self): - with pytest.raises(TypeError): - AdamWeightDecayDynamicLR(0.5, 10) - def test_Sgd_init(self): with pytest.raises(TypeError): paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x") diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py index 3077896fed..d88c55fd70 100644 --- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py +++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py @@ -37,6 +37,7 @@ class Net(nn.Cell): x = self.biasAdd(self.matmul(x, self.weight), self.bias) return x + class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): diff --git a/tests/ut/python/nn/test_distribution.py b/tests/ut/python/nn/test_distribution.py deleted file mode 100644 index 845c64a110..0000000000 --- a/tests/ut/python/nn/test_distribution.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Test nn.Distribution. - -Including Normal Distribution and Bernoulli Distribution. -""" -import pytest -import numpy as np - -import mindspore.nn as nn -from mindspore import dtype -from mindspore import Tensor - -def test_normal_shape_errpr(): - """ - Invalid shapes. - """ - with pytest.raises(ValueError): - nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) - -def test_no_arguments(): - """ - No args passed in during initialization. - """ - n = nn.Normal() - assert isinstance(n, nn.Distribution) - b = nn.Bernoulli() - assert isinstance(b, nn.Distribution) - -def test_with_arguments(): - """ - Args passed in during initialization. - """ - n = nn.Normal([3.0], [4.0], dtype=dtype.float32) - assert isinstance(n, nn.Distribution) - b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) - assert isinstance(b, nn.Distribution) - -class NormalProb(nn.Cell): - """ - Normal distribution: initialize with mean/sd. - """ - def __init__(self): - super(NormalProb, self).__init__() - self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) - - def construct(self, value): - x = self.normal('prob', value) - y = self.normal('log_prob', value) - return x, y - -def test_normal_prob(): - """ - Test pdf/log_pdf: passing value through construct. - """ - net = NormalProb() - value = Tensor([0.5, 1.0], dtype=dtype.float32) - pdf, log_pdf = net(value) - assert isinstance(pdf, Tensor) - assert isinstance(log_pdf, Tensor) - -class NormalProb1(nn.Cell): - """ - Normal distribution: initialize without mean/sd. - """ - def __init__(self): - super(NormalProb1, self).__init__() - self.normal = nn.Normal() - - def construct(self, value, mean, sd): - x = self.normal('prob', value, mean, sd) - y = self.normal('log_prob', value, mean, sd) - return x, y - -def test_normal_prob1(): - """ - Test pdf/logpdf: passing mean/sd, value through construct. - """ - net = NormalProb1() - value = Tensor([0.5, 1.0], dtype=dtype.float32) - mean = Tensor([0.0], dtype=dtype.float32) - sd = Tensor([1.0], dtype=dtype.float32) - pdf, log_pdf = net(value, mean, sd) - assert isinstance(pdf, Tensor) - assert isinstance(log_pdf, Tensor) - -class NormalProb2(nn.Cell): - """ - Normal distribution: initialize with mean/sd. - """ - def __init__(self): - super(NormalProb2, self).__init__() - self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) - - def construct(self, value, mean, sd): - x = self.normal('prob', value, mean, sd) - y = self.normal('log_prob', value, mean, sd) - return x, y - -def test_normal_prob2(): - """ - Test pdf/log_pdf: passing mean/sd through construct. - Overwrite original mean/sd. - """ - net = NormalProb2() - value = Tensor([0.5, 1.0], dtype=dtype.float32) - mean = Tensor([0.0], dtype=dtype.float32) - sd = Tensor([1.0], dtype=dtype.float32) - pdf, log_pdf = net(value, mean, sd) - assert isinstance(pdf, Tensor) - assert isinstance(log_pdf, Tensor) - -class BernoulliProb(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliProb, self).__init__() - self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) - - def construct(self, value): - return self.bernoulli('prob', value) - -class BernoulliLogProb(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliLogProb, self).__init__() - self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) - - def construct(self, value): - return self.bernoulli('log_prob', value) - - -def test_bernoulli_prob(): - """ - Test pmf/log_pmf: passing value through construct. - """ - net = BernoulliProb() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - pmf = net(value) - assert isinstance(pmf, Tensor) - -def test_bernoulli_log_prob(): - """ - Test pmf/log_pmf: passing value through construct. - """ - net = BernoulliLogProb() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - log_pmf = net(value) - assert isinstance(log_pmf, Tensor) - -class BernoulliProb1(nn.Cell): - """ - Bernoulli distribution: initialize without probs. - """ - def __init__(self): - super(BernoulliProb1, self).__init__() - self.bernoulli = nn.Bernoulli() - - def construct(self, value, probs): - return self.bernoulli('prob', value, probs) - -class BernoulliLogProb1(nn.Cell): - """ - Bernoulli distribution: initialize without probs. - """ - def __init__(self): - super(BernoulliLogProb1, self).__init__() - self.bernoulli = nn.Bernoulli() - - def construct(self, value, probs): - return self.bernoulli('log_prob', value, probs) - - -def test_bernoulli_prob1(): - """ - Test pmf/log_pmf: passing probs through construct. - """ - net = BernoulliProb1() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - pmf = net(value, probs) - assert isinstance(pmf, Tensor) - -def test_bernoulli_log_prob1(): - """ - Test pmf/log_pmf: passing probs through construct. - """ - net = BernoulliLogProb1() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - log_pmf = net(value, probs) - assert isinstance(log_pmf, Tensor) - -class BernoulliProb2(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliProb2, self).__init__() - self.bernoulli = nn.Bernoulli(0.5) - - def construct(self, value, probs): - return self.bernoulli('prob', value, probs) - -class BernoulliLogProb2(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliLogProb2, self).__init__() - self.bernoulli = nn.Bernoulli(0.5) - - def construct(self, value, probs): - return self.bernoulli('log_prob', value, probs) - - -def test_bernoulli_prob2(): - """ - Test pmf/log_pmf: passing probs/value through construct. - Overwrite original probs. - """ - net = BernoulliProb2() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - pmf = net(value, probs) - assert isinstance(pmf, Tensor) - -def test_bernoulli_log_prob2(): - """ - Test pmf/log_pmf: passing probs/value through construct. - Overwrite original probs. - """ - net = BernoulliLogProb2() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - log_pmf = net(value, probs) - assert isinstance(log_pmf, Tensor) - - -class NormalKl(nn.Cell): - """ - Test class: kl_loss of Normal distribution. - """ - def __init__(self): - super(NormalKl, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) - - def construct(self, x_, y_): - return self.n('kl_loss', 'Normal', x_, y_) - -class BernoulliKl(nn.Cell): - """ - Test class: kl_loss between Bernoulli distributions. - """ - def __init__(self): - super(BernoulliKl, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) - - def construct(self, x_): - return self.b('kl_loss', 'Bernoulli', x_) - -def test_kl(): - """ - Test kl_loss function. - """ - nor_net = NormalKl() - mean_b = np.array([1.0]).astype(np.float32) - sd_b = np.array([1.0]).astype(np.float32) - mean = Tensor(mean_b, dtype=dtype.float32) - sd = Tensor(sd_b, dtype=dtype.float32) - loss = nor_net(mean, sd) - assert isinstance(loss, Tensor) - - ber_net = BernoulliKl() - probs_b = Tensor([0.3], dtype=dtype.float32) - loss = ber_net(probs_b) - assert isinstance(loss, Tensor) - - -class NormalKlNoArgs(nn.Cell): - """ - Test class: kl_loss of Normal distribution. - No args during initialization. - """ - def __init__(self): - super(NormalKlNoArgs, self).__init__() - self.n = nn.Normal(dtype=dtype.float32) - - def construct(self, x_, y_, w_, v_): - return self.n('kl_loss', 'Normal', x_, y_, w_, v_) - -class BernoulliKlNoArgs(nn.Cell): - """ - Test class: kl_loss between Bernoulli distributions. - No args during initialization. - """ - def __init__(self): - super(BernoulliKlNoArgs, self).__init__() - self.b = nn.Bernoulli(dtype=dtype.int32) - - def construct(self, x_, y_): - return self.b('kl_loss', 'Bernoulli', x_, y_) - -def test_kl_no_args(): - """ - Test kl_loss function. - """ - nor_net = NormalKlNoArgs() - mean_b = np.array([1.0]).astype(np.float32) - sd_b = np.array([1.0]).astype(np.float32) - mean_a = np.array([2.0]).astype(np.float32) - sd_a = np.array([3.0]).astype(np.float32) - mean_b = Tensor(mean_b, dtype=dtype.float32) - sd_b = Tensor(sd_b, dtype=dtype.float32) - mean_a = Tensor(mean_a, dtype=dtype.float32) - sd_a = Tensor(sd_a, dtype=dtype.float32) - loss = nor_net(mean_b, sd_b, mean_a, sd_a) - assert isinstance(loss, Tensor) - - ber_net = BernoulliKlNoArgs() - probs_b = Tensor([0.3], dtype=dtype.float32) - probs_a = Tensor([0.7], dtype=dtype.float32) - loss = ber_net(probs_b, probs_a) - assert isinstance(loss, Tensor) - - - -class NormalBernoulli(nn.Cell): - """ - Test class: basic mean/sd function. - """ - def __init__(self): - super(NormalBernoulli, self).__init__() - self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) - self.b = nn.Bernoulli(0.5, dtype=dtype.int32) - - def construct(self): - normal_mean = self.n('mean') - normal_sd = self.n('sd') - bernoulli_mean = self.b('mean') - bernoulli_sd = self.b('sd') - return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd - -def test_bascis(): - """ - Test mean/sd functionality of Normal and Bernoulli. - """ - net = NormalBernoulli() - normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net() - assert isinstance(normal_mean, Tensor) - assert isinstance(normal_sd, Tensor) - assert isinstance(bernoulli_mean, Tensor) - assert isinstance(bernoulli_sd, Tensor) diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index c53f28d5f7..44a803a219 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -28,7 +28,7 @@ decay_epoch = 2 min_lr = 0.01 max_lr = 0.1 power = 0.5 - +warmup_epoch = 2 class TestInputs: def test_milestone1(self): @@ -234,3 +234,8 @@ def test_polynomial_decay(): lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, True) assert len(lr2) == total_step + + +def test_warmup(): + lr1 = dr.warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch) + assert len(lr1) == total_step diff --git a/tests/ut/python/nn/test_learning_rate_schedule.py b/tests/ut/python/nn/test_learning_rate_schedule.py new file mode 100644 index 0000000000..74f261a02e --- /dev/null +++ b/tests/ut/python/nn/test_learning_rate_schedule.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Test Dynamic Learning Rate """ +import pytest + +from mindspore import Tensor, Parameter +from mindspore.nn import learning_rate_schedule as lr_schedules +from mindspore.common.api import _executor +import mindspore.common.dtype as mstype + + +learning_rate = 0.1 +end_learning_rate = 0.01 +decay_rate = 0.9 +decay_steps = 4 +warmup_steps = 2 +min_lr = 0.01 +max_lr = 0.1 +power = 0.5 +global_step = Parameter(Tensor(2, mstype.int32), 'global_step') + + +class TestInit: + def test_learning_rate_type(self): + lr = True + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps) + + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power) + + def test_learning_rate_value(self): + lr = -1.0 + with pytest.raises(ValueError): + lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps) + + with pytest.raises(ValueError): + lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power) + + def test_end_learning_rate_type(self): + lr = True + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power) + + def test_end_learning_rate_value(self): + lr = -1.0 + with pytest.raises(ValueError): + lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power) + + def test_decay_rate_type(self): + rate = 'a' + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps) + + def test_decay_rate_value(self): + rate = -1.0 + with pytest.raises(ValueError): + lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps) + + def test_decay_steps_type(self): + decay_steps_e = 'm' + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e) + + with pytest.raises(TypeError): + lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e) + + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power) + + def test_decay_steps_value(self): + decay_steps_e = -2 + with pytest.raises(ValueError): + lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e) + + with pytest.raises(ValueError): + lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e) + + with pytest.raises(ValueError): + lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power) + + def test_is_stair(self): + is_stair = 1 + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, is_stair) + + def test_min_lr_type(self): + min_lr1 = True + with pytest.raises(TypeError): + lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps) + + def test_min_lr_value(self): + min_lr1 = -1.0 + with pytest.raises(ValueError): + lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps) + + def test_max_lr_type(self): + max_lr1 = 'a' + with pytest.raises(TypeError): + lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps) + + def test_max_lr_value(self): + max_lr1 = -1.0 + with pytest.raises(ValueError): + lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps) + + def test_power(self): + power1 = True + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power1) + + +def test_exponential_decay(): + lr_schedule = lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, True) + _executor.compile(lr_schedule, global_step) + + +def test_enatural_exp_decay(): + lr_schedule = lr_schedules.NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True) + _executor.compile(lr_schedule, global_step) + + +def test_inverse_decay(): + lr_schedule = lr_schedules.InverseDecayLR(learning_rate, decay_rate, decay_steps, True) + _executor.compile(lr_schedule, global_step) + + +def test_cosine_decay(): + lr_schedule = lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps) + _executor.compile(lr_schedule, global_step) + + +def test_polynomial_decay(): + lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + _executor.compile(lr_schedule, global_step) + + +def test_polynomial_decay2(): + lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power, True) + _executor.compile(lr_schedule, global_step) + + +def test_warmup(): + lr_schedule = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + _executor.compile(lr_schedule, global_step) diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 3c66e0a6df..f4ab8734f8 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from mindspore import Tensor, Parameter, ParameterTuple +from mindspore import context, Tensor, Parameter, ParameterTuple from mindspore._checkparam import _check_str_by_regular from mindspore.common import dtype as mstype from mindspore.common.initializer import initializer @@ -43,11 +43,11 @@ def test_parameter_tuple_illegal(): ParameterTuple(ptuple) with pytest.raises(TypeError): ParameterTuple(p1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): ParameterTuple(plist2) - with pytest.raises(ValueError): + with pytest.raises(TypeError): ParameterTuple(ptuple_str) - with pytest.raises(ValueError): + with pytest.raises(TypeError): ParameterTuple(pstr) with pytest.raises(TypeError): ParameterTuple(pnum) @@ -134,3 +134,40 @@ def test_check_str_by_regular(): _check_str_by_regular(str5) with pytest.raises(ValueError): _check_str_by_regular(str6) + +def test_parameter_lazy_init(): + # support lazy init in SEMI_AUTO_PARALLEL mode + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8) + # Call init_data() without set default_input. + para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1') + assert not isinstance(para.default_input, Tensor) + para = para.init_data() + assert isinstance(para.default_input, Tensor) + assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) + + # Call init_data() after default_input is set. + para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') + assert not isinstance(para.default_input, Tensor) + # expect type error when not init + with pytest.raises(TypeError): + para.default_input = Tensor(np.zeros((1, 2, 3))) + # init then assign + para = para.init_data() + # check the type + with pytest.raises(ValueError): + para.default_input = Tensor(np.zeros((1, 2, 3))) + # check the shape + with pytest.raises(ValueError): + para.default_input = Tensor(np.zeros((1, 2))) + # expect change ok + para.default_input = Tensor(np.zeros((1, 2, 3)).astype(np.float32)) + assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) + para.default_input = initializer('ones', [1, 2, 3], mstype.float32) + assert isinstance(para.default_input, Tensor) + # same object and has inited + assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) + # expect no effect. + para.init_data() + assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) + context.reset_auto_parallel_context() diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 53b42b8f66..369fe5f9b1 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -20,6 +20,7 @@ import mindspore as ms from mindspore import Tensor from mindspore import context from mindspore import nn +from mindspore.common import dtype as mstype from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -115,8 +116,7 @@ def test_if_none(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = None net = Net(z) - assert net(x, y) == y - + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_if_str_is_not_none_right(): class Net(nn.Cell): @@ -136,7 +136,7 @@ def test_if_str_is_not_none_right(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = "ok" net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_if_str_is_not_none_left(): @@ -157,7 +157,7 @@ def test_if_str_is_not_none_left(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = "ok" net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_if_none_equal_none(): @@ -178,7 +178,7 @@ def test_if_none_equal_none(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = None net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_if_str_is_null(): @@ -199,7 +199,7 @@ def test_if_str_is_null(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = "" net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_if_str_is_true(): @@ -220,7 +220,7 @@ def test_if_str_is_true(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = "ok" net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_if_str_equal(): @@ -241,7 +241,7 @@ def test_if_str_equal(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = "ok" net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_if_tuple_is_null(): @@ -262,7 +262,7 @@ def test_if_tuple_is_null(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = () net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_if_tuple_is_not_null(): @@ -283,7 +283,7 @@ def test_if_tuple_is_not_null(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = (1, 2, 3) net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_if_dict_is_null(): @@ -304,7 +304,7 @@ def test_if_dict_is_null(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = {} net = Net(z) - assert net(x, y) == y + assert np.all(net(x, y).asnumpy() == y.asnumpy()) def test_if_dict_is_not_null(): @@ -325,7 +325,7 @@ def test_if_dict_is_not_null(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = {"one": 1, "two": 2} net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_if_else_assign(): @@ -355,7 +355,7 @@ def test_if_else_assign(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = [1, 2] net = Net(z) - assert net(x, y) == x + assert np.all(net(x, y).asnumpy() == x.asnumpy()) def test_if_compile_true(): @@ -398,7 +398,7 @@ def test_switch_layer(): ret = F.switch_layer(index, self.layers)(x) * self.z3 return ret - index = Tensor(0) + index = Tensor(0, dtype=mstype.int32) net = SwitchLayerCell() net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, @@ -436,7 +436,7 @@ def test_index_to_switch_layer(): ret = self.layers[index](x) * self.z3 return ret - index = Tensor(0) + index = Tensor(0, dtype=mstype.int32) net = SwitchLayerCell() net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, @@ -639,3 +639,33 @@ def test_large_for_loop_with_continue_break(): t = Tensor(np.ones([2, 3], dtype=np.float32)) net = Net() net(t) + + +def test_mixed_precision_cast(): + x = Tensor(np.ones([2, 3], dtype=np.float32)) + z = F.mixed_precision_cast(mstype.float16, x) + assert z.dtype == mstype.float16 + + +def test_while_concat(): + class Net(nn.Cell): + def __init__(self, data): + super(Net, self).__init__() + self.start = Tensor(0, dtype=mstype.int32) + self.end = Tensor(2, dtype=mstype.int32) + self.out = Tensor(np.zeros([2, 3], dtype=np.float32)) + self.concat = P.Concat() + + def construct(self, inputs): + idx = self.start + end = self.end + out = self.out + while idx < end: + xi = inputs[idx, :, :] + out = self.concat((out, xi)) + idx = idx + 1 + return out + + x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32)) + net = Net(x) + net(x) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 7f53d40469..8b2e7ab432 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -237,6 +237,44 @@ class ScatterAdd(nn.Cell): return out +class ScatterNonAliasingAdd(nn.Cell): + """ScatterNonAliasingAdd net definition""" + + def __init__(self, ref_shape, dtype=np.float32): + super(ScatterNonAliasingAdd, self).__init__() + self.scatter_no_aliasing_add = P.ScatterNonAliasingAdd() + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_no_aliasing_add(self.ref, indices, updates) + return out + + +class ScatterNdSub(nn.Cell): + """ScatterNdSub net definition""" + + def __init__(self, ref_shape, dtype=np.float32): + super(ScatterNdSub, self).__init__() + self.scatter_nd_sub = P.ScatterNdSub() + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_nd_sub(self.ref, indices, updates) + return out + +class ScatterNdAdd(nn.Cell): + """ScatterNdAdd net definition""" + + def __init__(self, ref_shape, dtype=np.float32): + super(ScatterNdAdd, self).__init__() + self.scatter_nd_add = P.ScatterNdAdd() + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_nd_add(self.ref, indices, updates) + return out + + class ScatterSub(nn.Cell): """ScatterSub net definition""" @@ -810,6 +848,10 @@ test_case_math_ops = [ 'block': P.Asinh(), 'desc_inputs': [[3, 4, 5]], 'desc_bprop': [[3, 4, 5]]}), + ('Tan', { + 'block': P.Tan(), + 'desc_inputs': [[2, 3]], + 'desc_bprop': [[2, 3]]}), ('Reciprocal', { 'block': P.Reciprocal(), 'desc_inputs': [[2, 3, 3, 5]], @@ -912,6 +954,14 @@ test_case_math_ops = [ 'block': P.FloorMod(), 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], 'desc_bprop': [[2, 3, 4, 5]]}), + ('TruncateDiv', { + 'block': P.TruncateDiv(), + 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), + ('TruncateMod', { + 'block': P.TruncateMod(), + 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), ('identity', { 'block': ops.functional.identity, 'desc_inputs': [[2, 2]], @@ -948,6 +998,18 @@ test_case_math_ops = [ 'desc_const': [(0, 3, 1, 2)], 'desc_inputs': [], 'skip': ['backward']}), + ('Xdivy', { + 'block': P.Xdivy(), + 'desc_inputs': [[4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), + ('Xlogy', { + 'block': P.Xlogy(), + 'desc_inputs': [[4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), + ('SquaredDifference', { + 'block': P.SquaredDifference(), + 'desc_inputs': [[4, 5], [2, 3, 4, 5]], + 'desc_bprop': [[2, 3, 4, 5]]}), ('Square', { 'block': P.Square(), 'desc_inputs': [[4]], @@ -1111,7 +1173,8 @@ test_case_math_ops = [ 'block': P.SquareSumAll(), 'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)), Tensor(np.array([1, 1, 3, 7]).astype(np.float32))], - 'skip': ['backward']}), + 'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)), + Tensor(np.array(0.1).astype(np.float32))]}), ('Cos', { 'block': P.Cos(), 'desc_inputs': [[2, 3]], @@ -1121,6 +1184,11 @@ test_case_math_ops = [ 'desc_const': [1], 'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))], 'desc_bprop': []}), + ('ReduceAny', { + 'block': P.ReduceAny(), + 'desc_const': [1], + 'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))], + 'desc_bprop': []}), ('BesselI0e', { 'block': P.BesselI0e(), 'desc_inputs': [[2, 3]], @@ -1246,13 +1314,6 @@ test_case_nn_ops = [ 'block': P.AvgPool(ksize=(2, 2), strides=(2, 2), padding="VALID"), 'desc_inputs': [[100, 3, 28, 28]], 'desc_bprop': [[100, 3, 14, 14]]}), - ('AvgPoolGrad', { - 'block': G.AvgPoolGrad(ksize=(2, 2), strides=(2, 2), padding="VALID"), - 'desc_const': [(3, 4, 6, 6)], - 'const_first': True, - 'desc_inputs': [[3, 4, 6, 6]], - 'desc_bprop': [[3, 4, 6, 6]], - 'skip': ['backward']}), ('MaxPoolWithArgmax', { 'block': P.MaxPoolWithArgmax(ksize=2, strides=2), 'desc_inputs': [[128, 32, 32, 64]], @@ -1372,14 +1433,12 @@ test_case_nn_ops = [ 'block': P.UnsortedSegmentSum(), 'desc_const': [1280], 'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))], - 'desc_bprop': [[8192, 1024]], - 'skip': ['backward']}), + 'desc_bprop': [[1280, 1024]]}), ('UnsortedSegmentSum_1', { 'block': P.UnsortedSegmentSum(), 'desc_const': [4], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], - 'desc_bprop': [[4, 1, 3]], - 'skip': ['backward']}), + 'desc_bprop': [[4, 1, 3]]}), ('UnsortedSegmentMin', { 'block': P.UnsortedSegmentMin(), 'desc_const': [4], @@ -1719,13 +1778,11 @@ test_case_array_ops = [ ('AddN', { 'block': NetForTupleInput(P.AddN()), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], - 'desc_bprop': [[2, 3, 3, 5]], - 'skip': ['backward']}), + 'desc_bprop': [[2, 3, 3, 5]]}), ('AccumulateNV2', { 'block': NetForTupleInput(P.AccumulateNV2()), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], - 'desc_bprop': [[2, 3, 3, 5]], - 'skip': ['backward']}), + 'desc_bprop': [[2, 3, 3, 5]]}), ('Shape', { 'block': P.Shape(), 'desc_inputs': [[3, 3, 2, 2]], @@ -1786,6 +1843,14 @@ test_case_array_ops = [ 'desc_const': [(2, 1, 1, 2)], 'desc_inputs': [[2, 2, 2]], 'desc_bprop': [[2, 2, 2, 4]]}), + ('ReverseV2', { + 'block': P.ReverseV2(axis=[1]), + 'desc_inputs': [(Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)))], + 'desc_bprop': [(Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)))]}), + ('Rint', { + 'block': P.Rint(), + 'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))], + 'skip': ['backward']}), ('ConcatV2_0', { 'block': P.Concat(), 'desc_inputs': [ @@ -2049,6 +2114,21 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), Tensor(np.array([2.0, 3.0, 4.0], np.float32))), 'skip': ['backward']}), + ('ScatterNonAliasingAdd_1d', { + 'block': ScatterNonAliasingAdd((8,)), + 'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))), + 'skip': ['backward']}), + ('ScatterNdAdd', { + 'block': ScatterNdAdd((8,)), + 'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))), + 'skip': ['backward']}), + ('ScatterNdSub', { + 'block': ScatterNdAdd((8,)), + 'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))), + 'skip': ['backward']}), ('ScatterAdd', { 'block': ScatterAdd((6,)), 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), @@ -2228,36 +2308,36 @@ test_case_other_ops = [ ] test_case_quant_ops = [ - ('AscendQuant_1', { - 'block': inner.AscendQuant(0.5, 0.0, False, "Round"), + ('Quant_1', { + 'block': inner.Quant(0.5, 0.0, False, "Round"), 'desc_inputs': [Tensor(np.random.rand(1, 2, 4, 4), mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_2', { - 'block': inner.AscendQuant(80.0, 10.0, True, "Round"), + ('Quant_2', { + 'block': inner.Quant(80.0, 10.0, True, "Round"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_3', { - 'block': inner.AscendQuant(80.0, 0.0, False, "Floor"), + ('Quant_3', { + 'block': inner.Quant(80.0, 0.0, False, "Floor"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_4', { - 'block': inner.AscendQuant(80.0, 0.0, False, "Ceil"), + ('Quant_4', { + 'block': inner.Quant(80.0, 0.0, False, "Ceil"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_5', { - 'block': inner.AscendQuant(80.0, 0.0, False, "Trunc"), + ('Quant_5', { + 'block': inner.Quant(80.0, 0.0, False, "Trunc"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_6', { - 'block': inner.AscendQuant(-80.0, 10.0, False, "Round"), + ('Quant_6', { + 'block': inner.Quant(-80.0, 10.0, False, "Round"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_7', { - 'block': inner.AscendQuant(80.0, -10.0, False, "Round"), + ('Quant_7', { + 'block': inner.Quant(80.0, -10.0, False, "Round"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)], 'skip': ['backward']}), - ('AscendQuant_8', { - 'block': inner.AscendQuant(80.0, 10.0, False, "Round"), + ('Quant_8', { + 'block': inner.Quant(80.0, 10.0, False, "Round"), 'desc_inputs': [Tensor([100.0, 200.0], mstype.float16)], 'skip': ['backward']}), ] diff --git a/tests/ut/python/ops/test_ops_attr_infer.py b/tests/ut/python/ops/test_ops_attr_infer.py index 6f18710558..0408937368 100644 --- a/tests/ut/python/ops/test_ops_attr_infer.py +++ b/tests/ut/python/ops/test_ops_attr_infer.py @@ -15,6 +15,7 @@ """ test nn ops """ import numpy as np from numpy.random import normal +import pytest import mindspore.nn as nn import mindspore.context as context @@ -311,6 +312,7 @@ def test_op_with_arg_as_input(): # The partial application used as argument is not supported yet # because of the limit of inference specialize system +@pytest.mark.skip("poly in infer") def test_partial_as_arg(): class PartialArgNet(nn.Cell): def __init__(self): diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 66590945da..ec8b9957a2 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -971,7 +971,7 @@ raise_error_set = [ Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], }), ('TensorGetItemByMixedTensorsTypeError', { - 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}), + 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}), 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)], diff --git a/tests/ut/python/ops/test_tuple_slice.py b/tests/ut/python/ops/test_tuple_slice.py index 1475c177f4..bfa573b8df 100644 --- a/tests/ut/python/ops/test_tuple_slice.py +++ b/tests/ut/python/ops/test_tuple_slice.py @@ -114,13 +114,13 @@ test_cases = [ test_cases_for_verify_exception = [ ('SliceStartCross', { - 'block': (NetWork_3(), {'exception': RuntimeError}), + 'block': (NetWork_3(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32))], }), ('SliceStepZero', { - 'block': (NetWork_3(), {'exception': RuntimeError}), + 'block': (NetWork_3(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32))], diff --git a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py index ea60f1f09b..cccaf8e3b8 100644 --- a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py +++ b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py @@ -152,7 +152,7 @@ def test_compile_fp16_overflow(): net = NetFP16(16, 16) loss = MSELoss() - optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5) + optimizer = Lamb(net.trainable_params(), learning_rate=0.01) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) train_network.set_train() @@ -276,7 +276,7 @@ def test_compile_fp16_lr_overflow_dynamic_graph(): print("the result is ", output) -def test_adam_compile(): +def adam_compile(loss_scale=1.0): inputs = Tensor(np.ones([15, 1]).astype(np.float32)) label = Tensor(np.zeros([15, 1]).astype(np.float32)) scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32) @@ -284,10 +284,17 @@ def test_adam_compile(): loss = MSELoss() optimizer = Adam(net.trainable_params(), learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, - use_nesterov=False, weight_decay=0.0, loss_scale=1.0) + use_nesterov=False, weight_decay=0.0, loss_scale=loss_scale) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) train_network.set_train() output = train_network(inputs, label, scaling_sens) print("the result is ", output) + +def test_adam_compile(): + adam_compile() + +def test_adam_loss_scale_compile(): + """ test setting loss_scale to 1e-40 """ + adam_compile(loss_scale=1e-40) diff --git a/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py index 2f93eb6186..fccccafb0f 100644 --- a/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py +++ b/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py @@ -71,7 +71,7 @@ def test_group_lr(): assert opt.dynamic_lr is False assert opt.is_group_params_ordered is True for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()): - if param in conv_params: + if 'conv' in param.name: assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy()) else: assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy()) @@ -103,10 +103,12 @@ def test_group_dynamic_1(): assert opt.dynamic_lr is True assert opt.is_group_params_ordered is True for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()): - if param in conv_params: - assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy()) + if 'conv' in param.name: + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy()) else: - assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy()) + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy()) assert param.name == order_param.name @@ -133,10 +135,12 @@ def test_group_dynamic_2(): assert opt.is_group is True assert opt.dynamic_lr is True for lr, param in zip(opt.learning_rate, opt.parameters): - if param in conv_params: - assert np.all(lr.data.asnumpy() == Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy()) + if 'conv' in param.name: + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy()) else: - assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy()) + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy()) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, opt) @@ -157,7 +161,7 @@ def test_group_dynamic_no_same_size(): def test_group_not_float_lr(): net = LeNet5() - conv_lr = 1 + conv_lr = np.array(1) default_lr = 0.3 conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) @@ -169,7 +173,7 @@ def test_group_not_float_lr(): def test_group_not_float_weight_decay(): net = LeNet5() - conv_weight_decay = 1 + conv_weight_decay = np.array(1) conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, @@ -199,7 +203,7 @@ def test_weight_decay(): assert opt.is_group_params_ordered is True for weight_decay, decay_flags, param, order_param in zip( opt.weight_decay, opt.decay_flags, opt.parameters, net.trainable_params()): - if param in conv_params: + if 'conv' in param.name: assert weight_decay == conv_weight_decay assert decay_flags is True else: @@ -238,11 +242,15 @@ def test_get_lr_parameter_with_group(): assert opt.is_group_lr is True for param in opt.parameters: lr = opt.get_lr_parameter(param) - assert lr.name == 'lr_' + param.name + if 'conv' in param.name: + cur_name = 'learning_rate_group_' + '0' + else: + cur_name = 'learning_rate_group_' + '1' + assert lr.name == cur_name lr_list = opt.get_lr_parameter(conv_params) for lr, param in zip(lr_list, conv_params): - assert lr.name == 'lr_' + param.name + assert lr.name == 'learning_rate_group_' + '0' def test_get_lr_parameter_with_order_group(): @@ -256,7 +264,11 @@ def test_get_lr_parameter_with_order_group(): assert opt.is_group_lr is True for param in opt.parameters: lr = opt.get_lr_parameter(param) - assert lr.name == 'lr_' + param.name + if 'conv' in param.name: + cur_name = 'learning_rate_group_' + '0' + else: + cur_name = 'learning_rate' + assert lr.name == cur_name def test_get_lr_parameter_with_no_group(): @@ -271,7 +283,7 @@ def test_get_lr_parameter_with_no_group(): assert opt.is_group_lr is False for param in opt.parameters: lr = opt.get_lr_parameter(param) - assert lr.name == opt.learning_rate.name + assert lr.name == 'learning_rate' params_error = [1, 2, 3] with pytest.raises(TypeError): @@ -291,11 +303,11 @@ def test_order_params_1(): assert opt.is_group_params_ordered is True for weight_decay, decay_flags, lr, param, order_param in zip( opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, bias_params+conv_params): - if param in conv_params: + if 'conv' in param.name: assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy()) assert weight_decay == 0.01 assert decay_flags is True - elif param in bias_params: + elif 'bias' in param.name: assert np.all(lr.data.asnumpy() == Tensor(0.01, mstype.float32).asnumpy()) assert weight_decay == 0.0 assert decay_flags is False @@ -305,7 +317,11 @@ def test_order_params_1(): assert decay_flags is False assert param.name == order_param.name - assert lr.name == 'lr_' + param.name + if 'conv' in param.name: + assert lr.name == 'learning_rate' + elif 'bias' in param.name: + assert lr.name == 'learning_rate_group_' + '1' + def test_order_params_2(): @@ -323,13 +339,14 @@ def test_order_params_2(): assert opt.is_group is True assert opt.is_group_lr is True assert opt.is_group_params_ordered is True + all_lr = opt.get_lr_parameter(fc1_params+conv_params) for weight_decay, decay_flags, lr, param, order_param in zip( - opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, fc1_params+conv_params): - if param in conv_params: + opt.weight_decay, opt.decay_flags, all_lr, opt.parameters, fc1_params+conv_params): + if 'conv' in param.name: assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy()) assert weight_decay == conv_weight_decay assert decay_flags is True - elif param in fc1_params: + elif 'fc1' in param.name: assert np.all(lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy()) assert weight_decay == default_wd assert decay_flags is False @@ -339,8 +356,10 @@ def test_order_params_2(): assert decay_flags is False assert param.name == order_param.name - assert lr.name == 'lr_' + param.name - + if 'conv' in param.name: + assert lr.name == 'learning_rate' + elif 'fc1' in param.name: + assert lr.name == 'learning_rate_group_' + '0' def test_get_order_params_with_not_same(): net = LeNet5() diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py index c3ce3d6c4e..8728120ff1 100644 --- a/tests/ut/python/optimizer/test_python_pass.py +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -22,10 +22,11 @@ from mindspore.ops import operations as P from mindspore.common.python_pass_register import registe_pass, PyPassManager from mindspore.common.api import _generate_pip_args from mindspore._c_expression import generate_key, Executor_ +from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor context.set_context(mode=context.GRAPH_MODE) -def get_func_graph(obj, *args, phase="predict"): +def get_func_graph(obj, *args, phase="validate"): args_names, args_list = _generate_pip_args(obj, *args) dic = dict(zip(args_names, args_list)) key = generate_key(phase, dic) @@ -47,14 +48,11 @@ def test_softmax_relu(): @registe_pass(run_only_once=True) def softmax_relu_pass(): - softmax = P.Softmax() - relu = P.ReLU() - def pattern(x): - x = softmax(x) - return x - def target(x): - x = relu(x) - return x + x = AnyPattern() + softmax_pattern = IsPrimTypeOf(P.Softmax()) + pattern = CallWith(softmax_pattern, inputs=[x]) + relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) + target = CallWith(relu_pattern, inputs=[x]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) @@ -62,3 +60,128 @@ def test_softmax_relu(): ppm.unregiste(softmax_relu_pass) assert "ReLU" in transformed_repr assert "Softmax" not in transformed_repr + +def test_isin_pattern(): + """ + Test IsIn pattern which expresses the IsIn/OneOf semantics. + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_relu_pass(): + x = AnyPattern() + softmax_pattern = IsPrimTypeOf(P.Softmax()) + call_softmax = CallWith(softmax_pattern, inputs=[x]) + relu_pattern = IsPrimTypeOf(P.ReLU()) + call_relu = CallWith(relu_pattern, inputs=[x]) + + pattern = IsIn([call_softmax, call_relu]) + relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False) + target = CallWith(relu6_pattern, inputs=[x]) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) + ppm = PyPassManager() + ppm.unregiste(softmax_relu_pass) + assert "ReLU6" in transformed_repr + assert "Softmax" not in transformed_repr + +def test_isnot_pattern_0(): + """ + Test IsNot pattern which expresses the IsNot semantics. + Case: IsNot pass failed to match + """ + class ConvBN(nn.Cell): + def __init__(self): + super(ConvBN, self).__init__() + self.conv = P.Conv2D(32, 3) + self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32) + self.scale = Tensor(np.ones([32]), mindspore.float32) + self.bias = Tensor(np.ones([32]), mindspore.float32) + self.mean = Tensor(np.ones([32]), mindspore.float32) + self.variance = Tensor(np.ones([32]), mindspore.float32) + self.bn = P.BatchNorm() + def construct(self, x): + x = self.conv(x, self.conv_weight) + x = self.bn(x, self.scale, self.bias, self.mean, self.variance) + return x + inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32) + conv_bn_model = ConvBN() + + @registe_pass(run_only_once=True) + def single_bn_pass(): + """ + Sub a BN which does NOT take Conv as inputs to ReLU6. + """ + conv2d_prim = IsPrimTypeOf("Conv2D") + conv2d = CallWith(conv2d_prim) + pattern_0 = IsNot(conv2d) + pattern = CallWith(P.BatchNorm(), inputs=[pattern_0]) + target = CallWith(P.ReLU6(), inputs=[pattern_0]) + return pattern, target + + @registe_pass(run_only_once=True) + def bn_pass(): + """ + Sub a BN to Softmax. + """ + bn = P.BatchNorm() + pattern = CallWith(bn) + softmax = P.Softmax() + target = CallWith(softmax, should_replace=False) + return pattern, target + + transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) + ppm = PyPassManager() + ppm.unregiste(single_bn_pass) + ppm.unregiste(bn_pass) + assert "ReLU6" not in transformed_repr + assert "Softmax" in transformed_repr + +def test_isnot_pattern_1(): + """ + Test IsNot pattern which expresses the IsNot semantics. + Case: IsNot pattern matches with the graph + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def single_bn_pass(): + """ + Sub a BN which does NOT take MatMul as inputs to ReLU6. + """ + matmul = IsPrimTypeOf("MatMul") + pattern_0 = IsNot(matmul) + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[pattern_0]) + relu6 = P.ReLU6() + target = CallWith(relu6, inputs=[pattern_0], should_replace=False) + return pattern, target + + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) + ppm = PyPassManager() + ppm.unregiste(single_bn_pass) + assert "ReLU6" in transformed_repr + assert "Softmax" not in transformed_repr + +def test_newtensor_pattern(): + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_addn_pass(): + x = AnyPattern() + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[x]) + + weight_tensor = Tensor(np.zeros([42]), mindspore.float16) + new_weight = NewTensor(weight_tensor) + addn_ops = P.AddN() + target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) + ppm = PyPassManager() + ppm.unregiste(softmax_addn_pass) + assert "AddN" in transformed_repr + assert "Softmax" not in transformed_repr diff --git a/tests/ut/python/parallel/test_auto_parallel_double_sources.py b/tests/ut/python/parallel/test_auto_parallel_double_sources.py new file mode 100644 index 0000000000..188c962f26 --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_double_sources.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y, z, w, a): + predict = self.network(x, y, z, w, a) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y, z, w, a): + return C.grad_all(self.network)(x, y, z, w, a) + + # model_parallel test + + +def test_double_source_graph(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.matmul1 = P.MatMul() + self.matmul2 = P.MatMul() + self.matmul3 = P.MatMul() + self.matmul4 = P.MatMul() + self.matmul5 = P.MatMul() + + def construct(self, x, y, z, w, a): + m1_result = self.matmul1(x, y) + m2_result = self.matmul2(z, w) + m3_result = self.matmul3(m2_result, m1_result) + m4_result = self.matmul4(m2_result, m1_result) + out = self.matmul5(m3_result, m4_result) + + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([32, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 32]), dtype=ms.float32) + z = Tensor(np.ones([32, 32]), dtype=ms.float32) + w = Tensor(np.ones([32, 32]), dtype=ms.float32) + a = Tensor(np.ones([32, 32]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x, y, z, w, a) + + +def test_double_source_complex_graph(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.matmul1 = P.MatMul() + self.matmul2 = P.MatMul() + self.matmul3 = P.MatMul() + self.matmul4 = P.MatMul() + self.matmul5 = P.MatMul() + self.matmul6 = P.MatMul() + + def construct(self, x, y, z, w, a): + m1_result = self.matmul1(x, y) + m6_result = self.matmul6(m1_result, a) + m2_result = self.matmul2(z, w) + m3_result = self.matmul3(m2_result, m6_result) + m4_result = self.matmul4(m2_result, m1_result) + out = self.matmul5(m3_result, m4_result) + + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([32, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 32]), dtype=ms.float32) + z = Tensor(np.ones([32, 32]), dtype=ms.float32) + w = Tensor(np.ones([32, 32]), dtype=ms.float32) + a = Tensor(np.ones([32, 32]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x, y, z, w, a) diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py index fdba571e70..5a117bbb1a 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py @@ -274,6 +274,9 @@ class DatasetLenet(): def get_repeat_count(self): return 1 + def create_tuple_iterator(self): + return self + def test_train_32k_8p(batch_size=32, num_classes=32768): dev_num = 8 @@ -675,3 +678,56 @@ def test_train_64k_8p(batch_size=32, num_classes=65536): # 1048576 #131072 #327 assert v == [[1, 1], [dev_num, 1]] elif re.search('ReduceSum-op', k) is not None: assert v == [[1, dev_num]] + + +def test_train_8k_8p_gpu(batch_size=32, num_classes=8192): + dev_num = 8 + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) + set_algo_parameters(elementwise_op_strategy_follow=True) + resset_op_id() + np.random.seed(6) + input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32) + label_np = np.zeros([batch_size]).astype(np.int32) + for i in range(0, batch_size): + label_np[i] = i % num_classes + dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1) + net = resnet50(num_classes) + loss = SoftmaxCrossEntropyExpand(sparse=True) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) + model = Model(net, loss_fn=loss, optimizer=opt) + model.train(5, dataset, dataset_sink_mode=False) + strategies = _executor._get_strategy(model._train_network) + for (k, v) in strategies.items(): + if re.search('Conv2D-op', k) is not None: + assert v[0][0] == dev_num + elif re.search('MatMul-op', k) is not None: + assert v == [[1, 1], [dev_num, 1]] + elif re.search('ReduceSum-op', k) is not None: + assert v == [[1, dev_num]] + +def test_train_4k_8p_gpu(batch_size=32, num_classes=4096): + dev_num = 8 + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) + set_algo_parameters(elementwise_op_strategy_follow=True) + resset_op_id() + np.random.seed(6) + input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32) + label_np = np.zeros([batch_size]).astype(np.int32) + for i in range(0, batch_size): + label_np[i] = i % num_classes + dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1) + net = resnet50(num_classes) + loss = SoftmaxCrossEntropyExpand(sparse=True) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) + model = Model(net, loss_fn=loss, optimizer=opt) + model.train(5, dataset, dataset_sink_mode=False) + strategies = _executor._get_strategy(model._train_network) + for (k, v) in strategies.items(): + if re.search('Conv2D-op', k) is not None: + assert v[0][0] == dev_num + elif re.search('MatMul-op', k) is not None: + assert v == [[dev_num, 1], [1, 1]] + elif re.search('ReduceSum-op', k) is not None: + assert v == [[dev_num, 1]] diff --git a/tests/ut/python/parallel/test_auto_parallel_tuple_depend.py b/tests/ut/python/parallel/test_auto_parallel_tuple_depend.py index ab6fdcb5ce..a80ccb550a 100644 --- a/tests/ut/python/parallel/test_auto_parallel_tuple_depend.py +++ b/tests/ut/python/parallel/test_auto_parallel_tuple_depend.py @@ -46,7 +46,7 @@ class GradWrap(nn.Cell): def bn_with_initialize(out_channels): - bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5) + bn = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) return bn diff --git a/tests/ut/python/parallel/test_auto_parallel_two_bn.py b/tests/ut/python/parallel/test_auto_parallel_two_bn.py index 3c73290b1e..029d85ab3c 100644 --- a/tests/ut/python/parallel/test_auto_parallel_two_bn.py +++ b/tests/ut/python/parallel/test_auto_parallel_two_bn.py @@ -40,7 +40,7 @@ class NetWithLoss(nn.Cell): class Blockcell(nn.Cell): def __init__(self): super(Blockcell, self).__init__() - self.bn = nn.BatchNorm2d(64, momentum=0.9) + self.bn = nn.BatchNorm1d(64, momentum=0.9) def construct(self, x): out = self.bn(x) diff --git a/tests/ut/python/parallel/test_bias_add.py b/tests/ut/python/parallel/test_bias_add.py index 321810b1ae..573efde125 100644 --- a/tests/ut/python/parallel/test_bias_add.py +++ b/tests/ut/python/parallel/test_bias_add.py @@ -61,6 +61,9 @@ class DatasetLenet(): def get_repeat_count(self): return 1 + def create_tuple_iterator(self): + return self + class Net(nn.Cell): def __init__(self): diff --git a/tests/ut/python/parallel/test_dataset_util.py b/tests/ut/python/parallel/test_dataset_util.py index f3c861dd68..28d70fd287 100644 --- a/tests/ut/python/parallel/test_dataset_util.py +++ b/tests/ut/python/parallel/test_dataset_util.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + import mindspore as ms from mindspore import Tensor from mindspore.train._utils import _to_full_shapes, _to_full_tensor @@ -33,7 +35,7 @@ def test_to_full_tensor_1(): expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]]) expect_tensor = Tensor(expect, dtype=ms.float32) - assert full_tensor[0] == expect_tensor + assert np.all(full_tensor[0].asnumpy() == expect_tensor.asnumpy()) def test_to_full_tensor_2(): @@ -50,7 +52,8 @@ def test_to_full_tensor_2(): expect_tensor1 = Tensor(expect1, dtype=ms.int32) expect_tensors = (expect_tensor0, expect_tensor1) - assert full_tensor == expect_tensors + assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy()) + assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy()) def test_to_full_tensor_sens_2(): @@ -68,4 +71,6 @@ def test_to_full_tensor_sens_2(): expect_tensor_sens = Tensor(0.1, dtype=ms.float32) expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens) - assert full_tensor == expect_tensors + assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy()) + assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy()) + assert np.all(full_tensor[2].asnumpy() == expect_tensors[2].asnumpy()) diff --git a/tests/ut/python/parallel/test_gather_v2_primitive.py b/tests/ut/python/parallel/test_gather_v2_primitive.py index 8aa093a24e..e6f269e2db 100644 --- a/tests/ut/python/parallel/test_gather_v2_primitive.py +++ b/tests/ut/python/parallel/test_gather_v2_primitive.py @@ -58,6 +58,9 @@ class Dataset(): def get_repeat_count(self): return 1 + def create_tuple_iterator(self): + return self + class GatherV2(_Loss): def __init__(self, index_dim, strategy, index_size=16): diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py index a34ee94840..23649b5f0c 100644 --- a/tests/ut/python/parallel/test_get_parameter_layout.py +++ b/tests/ut/python/parallel/test_get_parameter_layout.py @@ -49,8 +49,8 @@ def test_get_parameter_layout(): net.set_auto_parallel() exe = me._executor exe.compile(net, x, phase='train', auto_parallel_mode=True) - x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1] - weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1] + x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1] + weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1] expect_dict = {'x': x_layout, 'w1': weight_layout} # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut assert net.parameter_layout_dict == expect_dict diff --git a/tests/ut/python/parallel/test_gpu_dropout.py b/tests/ut/python/parallel/test_gpu_dropout.py new file mode 100644 index 0000000000..f5ad20b4ee --- /dev/null +++ b/tests/ut/python/parallel/test_gpu_dropout.py @@ -0,0 +1,99 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return C.grad_all(self.network)(x, y) + + +class Net(nn.Cell): + def __init__(self, strategy1=None, strategy2=None): + super().__init__() + self.dropout = P.Dropout(keep_prob=0.6).set_strategy(strategy1) + self.matmul = P.MatMul().set_strategy(strategy2) + + def construct(self, x, y): + out = self.matmul(x, y) + out, _ = self.dropout(out) + return out + + +def test_dropout_semi_auto(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + net = GradWrap(NetWithLoss(Net())) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 128]), dtype=ms.float32) + _executor.compile(net, x, y) + + +def test_dropout_semi_auto2(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((8, 1),) + strategy2 = ((4, 2), (2, 1)) + net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 128]), dtype=ms.float32) + _executor.compile(net, x, y) + + +def test_dropout_semi_auto3(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((2, 4),) + strategy2 = ((4, 2), (2, 1)) + net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 128]), dtype=ms.float32) + _executor.compile(net, x, y) + + +def test_dropout_auto(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") + net = GradWrap(NetWithLoss(Net())) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 128]), dtype=ms.float32) + _executor.compile(net, x, y) diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py index 6663e34871..ca5fe0ac3e 100644 --- a/tests/ut/python/parallel/test_parallel_optimizer.py +++ b/tests/ut/python/parallel/test_parallel_optimizer.py @@ -20,11 +20,10 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb +from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb from mindspore.ops import operations as P -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore import context - +from mindspore.parallel._auto_parallel_context import auto_parallel_context class Net(nn.Cell): """Net definition""" @@ -52,63 +51,64 @@ class Net(nn.Cell): return s -def test_AdamWeightDecayDynamicLR(): - """ test_AdamWeightDecayDynamicLR """ - auto_parallel_context().set_enable_parallel_optimizer(True) - context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2) +def test_AdamWeightDecay(): + """ test_AdamWeightDecay """ + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net() net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1) + optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) + context.reset_auto_parallel_context() -def test_AdamWeightDecay(): - """ test_AdamWeightDecayDynamicLR """ - auto_parallel_context().set_enable_parallel_optimizer(True) - context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2) +def test_lamb_compile(): + """ test_Lamb_compile """ + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net() net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1) + optimizer = Lamb(net.trainable_params(), learning_rate=0.1) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) + context.reset_auto_parallel_context() -def test_lamb_compile(): - """ test_Lamb_compile """ - auto_parallel_context().set_enable_parallel_optimizer(True) - context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2) +def test_lamb_split_fusion(): + """ test_Lamb_split_fusion """ + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8]) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net() net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = Lamb(net.trainable_params(), decay_steps=10) + optimizer = Lamb(net.trainable_params(), learning_rate=0.1) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) - + context.reset_auto_parallel_context() def test_edge_case(): """ test_edge_case """ - auto_parallel_context().set_enable_parallel_optimizer(True) + context.set_auto_parallel_context(enable_parallel_optimizer=True) net = Net() with pytest.raises(RuntimeError): context.set_auto_parallel_context(parallel_mode="stand_alone") - Lamb(net.trainable_params(), decay_steps=10) + Lamb(net.trainable_params(), learning_rate=0.1) with pytest.raises(RuntimeError): Adam(net.trainable_params(), learning_rate=0.1) with pytest.raises(RuntimeError): context.set_auto_parallel_context(device_num=16) - Lamb(net.trainable_params(), decay_steps=10) + Lamb(net.trainable_params(), learning_rate=0.1) + context.reset_auto_parallel_context() diff --git a/tests/ut/python/parallel/test_reshape_skip_redistribution.py b/tests/ut/python/parallel/test_reshape_skip_redistribution.py new file mode 100644 index 0000000000..cbaf20d113 --- /dev/null +++ b/tests/ut/python/parallel/test_reshape_skip_redistribution.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.common.api import _executor +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(Cell): + def __init__(self, matmul_weight, strategy1=None): + super().__init__() + self.gatherv2 = P.GatherV2().set_strategy(strategy1) + self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True) + self.matmul = P.MatMul(transpose_b=False) + self.index = Tensor(np.ones([64, 64]), dtype=ms.int32) + self.matmul_weight = Parameter(matmul_weight, "w1") + self.axis = 0 + + def construct(self, x, b): + out = self.gatherv2(x, self.index, self.axis) + out = self.reshape(out, (64, -1)) + out = self.matmul(out, self.matmul_weight) + return out + + +_w1 = Tensor(np.ones([4096, 32]), dtype=ms.float32) +_x = Tensor(np.ones([64, 64]), dtype=ms.float32) +_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) + +def compile_net(net): + context.set_context(save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_reshape_skip_redistribution(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8), (1, 1)) + net = Net(_w1, strategy1) + compile_net(net) diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py index c476b0cebc..19187cb262 100644 --- a/tests/ut/python/parallel/test_set_auto_parallel_context.py +++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py @@ -81,8 +81,8 @@ def test_set_auto_parallel_context(): with pytest.raises(ValueError): set_algo_parameters(tensor_slice_align_size=1025) - auto_parallel_context().set_enable_parallel_optimizer(True) - assert auto_parallel_context().get_enable_parallel_optimizer() is True + context.set_auto_parallel_context(enable_parallel_optimizer=True) + assert context.get_auto_parallel_context("enable_parallel_optimizer") assert not auto_parallel_context().get_all_reduce_fusion_split_indices() diff --git a/tests/ut/python/parallel/test_sparse_feature_bprop.py b/tests/ut/python/parallel/test_sparse_feature_bprop.py index cd58261dbd..4070442020 100644 --- a/tests/ut/python/parallel/test_sparse_feature_bprop.py +++ b/tests/ut/python/parallel/test_sparse_feature_bprop.py @@ -18,16 +18,13 @@ import numpy as np import mindspore as ms import mindspore.nn as nn from mindspore import context -from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore.ops import composite as C -from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator -from mindspore.ops._grad.grad_base import bprop_getters -from mindspore._checkparam import Validator as validator -from mindspore._checkparam import Rel -from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer +from mindspore.ops import composite as C, operations as P +from mindspore.ops.operations.comm_ops import AllReduce from mindspore.common.api import _executor -from mindspore.communication.management import HCCL_WORLD_COMM_GROUP +from mindspore.nn import TrainOneStepCell, Adam + class GradWrap(nn.Cell): def __init__(self, network): @@ -37,40 +34,9 @@ class GradWrap(nn.Cell): def construct(self, x): return C.grad_all(self.network)(x) -class VirtualGatherV2(PrimitiveWithInfer): - @prim_attr_register - def __init__(self): - """init index_select""" - super(VirtualGatherV2, self).__init__('VirtualGatherV2') - self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) - - def __infer__(self, params, indices, axis): - validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) - validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) - axis_v = axis['value'] - params_shp = params['shape'] - rank = len(params_shp) - validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) - if axis_v < 0: - axis_v += rank - out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] - out = {'shape': out_shape, - 'dtype': params['dtype'], - 'value': None} - return out - -@bprop_getters.register(VirtualGatherV2) -def get_bprop_gather_v2(self): - """Generate bprop for GatherV2""" - - def bprop(x, indices, axis, out, dout): - return (indices, dout, x), axis, out - - return bprop - def test_bprop_with_sparse_feature_allreduce(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel") + context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self, axis=0, shape=None): @@ -78,7 +44,7 @@ def test_bprop_with_sparse_feature_allreduce(): if shape is None: shape = [8, 8] self.all_reduce = AllReduce() - self.gatherv2 = VirtualGatherV2() + self.gatherv2 = P.SparseGatherV2() self.index = Tensor(np.ones(shape), dtype=ms.int32) self.axis = axis @@ -93,26 +59,66 @@ def test_bprop_with_sparse_feature_allreduce(): _executor.compile(net, x) + def test_bprop_with_sparse_feature_mirror(): - context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel") + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + context.set_context(enable_sparse=True) + + class Net(nn.Cell): + def __init__(self, shape=None): + super(Net, self).__init__() + if shape is None: + shape = [8, 8] + weight = Tensor(np.ones([64, 64]), dtype=ms.float32) + self.weight = Parameter(weight, "w") + self.index = Tensor(np.ones(shape), dtype=ms.int32) + self.embeddinglookup = nn.EmbeddingLookup() + self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1))) + + def construct(self, x, b): + out = self.embeddinglookup(self.weight, self.index) + + return out + + _x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32) + _b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32) + + def compile_net(net): + optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9) + train_net = TrainOneStepCell(net, optimizer) + _executor.compile(train_net, _x, _b) + + net = Net() + compile_net(net) + + +def test_bprop_with_sparse_feature_dataparallel(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel") + context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self, axis=0, shape=None): super(Net, self).__init__() if shape is None: shape = [8, 8] - self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP) - self.gatherv2 = VirtualGatherV2() + weight = Tensor(np.ones([64, 64]), dtype=ms.float32) + self.weight = Parameter(weight, "w") self.index = Tensor(np.ones(shape), dtype=ms.int32) self.axis = axis + self.gatherv2 = P.SparseGatherV2() - def construct(self, x): - out = self.mirror(x) - out = self.gatherv2(out, self.index, self.axis) + def construct(self, x, b): + out = self.gatherv2(self.weight, self.index, self.axis) return out - net = GradWrap(Net()) - x = Tensor(np.ones([64, 64]), dtype=ms.float32) + _x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32) + _b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32) - _executor.compile(net, x) + def compile_net(net): + optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9) + train_net = TrainOneStepCell(net, optimizer) + _executor.compile(train_net, _x, _b) + + net = Net() + compile_net(net) diff --git a/tests/ut/python/parallel/test_stridedslice.py b/tests/ut/python/parallel/test_stridedslice.py new file mode 100644 index 0000000000..9ee190b14a --- /dev/null +++ b/tests/ut/python/parallel/test_stridedslice.py @@ -0,0 +1,164 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.common.api import _executor +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(Cell): + def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.strided_slice = P.StridedSlice(begin_mask=mask).set_strategy(strategy2) + if is_parameter: + self.weight = Parameter(weight, "w1") + else: + self.weight = weight + self.mul2 = P.Mul() + self.weight2 = Parameter(w2, "w2") + self.begin = begin + self.end = end + self.strides = strides + + def construct(self, x, b): + out = self.strided_slice(self.weight, self.begin, self.end, self.strides) + out = self.mul(x, out) + out = self.mul2(out, self.weight2) + return out + + +class Net2(Cell): + def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.strided_slice = P.StridedSlice().set_strategy(strategy2) + self.weight2 = Parameter(weight2, "w2") + self.begin = begin + self.end = end + self.strides = strides + + def construct(self, x, b): + out = self.mul(x, self.weight2) + out = self.strided_slice(out, self.begin, self.end, self.strides) + return out + + +_x = Tensor(np.ones([128, 64, 1]), dtype=ms.float32) +_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32) +_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32) +_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_stridedslice_no_fully_fetch_split_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((2, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_stridedslice_strides_no_1_split_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2), strategy1, strategy2, is_parameter=True) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_stridedslice_mask_no_0_split_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, mask=1) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_stridedslice_begin_size_smaller(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 4, 2),) + net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1), strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_stridedslice_parameter(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 4, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_stridedslice_tensor(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 4, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False) + compile_net(net) + + +def test_stridedslice_parameter_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 1), (1, 4, 2)) + strategy2 = ((1, 2, 2),) + net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_stridedslice_output(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8, 1), (1, 8, 1)) + strategy2 = ((1, 8, 1),) + net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2) + compile_net(net) + + +def test_stridedslice_output_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8, 1), (1, 8, 1)) + strategy2 = ((1, 4, 1),) + net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2) + compile_net(net) + + +def test_stridedslice_no_strategy(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8, 1), (1, 8, 1)) + strategy2 = None + net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2) + compile_net(net) + + +def test_stridedslice_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1)) + compile_net(net) diff --git a/tests/ut/python/parallel/test_tile.py b/tests/ut/python/parallel/test_tile.py new file mode 100644 index 0000000000..22832460ba --- /dev/null +++ b/tests/ut/python/parallel/test_tile.py @@ -0,0 +1,128 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.common.api import _executor +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(Cell): + def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.tile = P.Tile().set_strategy(strategy2) + if is_parameter: + self.weight = Parameter(weight, "w1") + else: + self.weight = weight + self.mul2 = P.Mul() + self.weight2 = Parameter(weight2, "w2") + + def construct(self, x, b): + out = self.tile(self.weight, (8, 4, 2)) + out = self.mul(x, out) + out = self.mul2(out, self.weight2) + return out + + +class Net2(Cell): + def __init__(self, weight2, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.tile = P.Tile().set_strategy(strategy2) + self.weight2 = Parameter(weight2, "w2") + + def construct(self, x, b): + out = self.mul(x, self.weight2) + out = self.tile(out, (8, 8, 4, 2)) + return out + + +_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) +_w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32) +_w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) +_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_tile_parameter(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((2, 2, 2),) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_tile_parameter_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((2, 2, 1),) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True) + compile_net(net) + + +def test_tile_tensor(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((2, 2, 2),) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False) + compile_net(net) + + +def test_tile_tensor_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((2, 2, 1),) + net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False) + compile_net(net) + + +def test_tile_output(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 2, 2),) + net = Net2(_w2, strategy1, strategy2) + compile_net(net) + +def test_tile_output_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = ((1, 2, 1, 2),) + net = Net2(_w2, strategy1, strategy2) + compile_net(net) + + +def test_tile_no_strategy(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2), (2, 2, 2)) + strategy2 = None + net = Net2(_w2, strategy1, strategy2) + compile_net(net) + +def test_tile_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net2(_w2) + compile_net(net) diff --git a/tests/ut/python/parameter_feature/test_parameter.py b/tests/ut/python/parameter_feature/test_parameter.py index 289fd35e81..3088cd7172 100644 --- a/tests/ut/python/parameter_feature/test_parameter.py +++ b/tests/ut/python/parameter_feature/test_parameter.py @@ -47,7 +47,7 @@ def test_parser_three_default_mixed_args_subnet(): tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32)) tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32)) net = NetOut() - assert net(tensor1, tensor2) == tensor1 + assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy()) # pylint: disable=keyword-arg-before-vararg diff --git a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py index c292e3662d..1e08b5c869 100644 --- a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py +++ b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py @@ -53,4 +53,7 @@ def test_hypermap_specialize_param(): expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32))) ret = hypermap_specialize_param() - assert ret == (expected_ret, list(expected_ret)) + assert ret[0][0].asnumpy() == expected_ret[0].asnumpy() + assert np.all(ret[0][1].asnumpy() == expected_ret[1].asnumpy()) + assert ret[1][0].asnumpy() == list(expected_ret[0].asnumpy()) + assert np.all(ret[1][1].asnumpy() == list(expected_ret[1].asnumpy())) diff --git a/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py b/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py new file mode 100644 index 0000000000..9781164038 --- /dev/null +++ b/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test interface 'all' and 'any' of tensor """ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + + +def test_all_and_any_of_tensor_in_graph(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + all_ = x.all() + any_ = x.any() + all_0 = x.all(0, True) + any_0 = x.any(0, True) + return all_, any_, all_0, any_0 + + net = Net() + x = Tensor(np.array([[True, False, False], [True, False, False]])) + context.set_context(mode=context.GRAPH_MODE) + net(x) + + +def test_all_and_any_of_tensor_in_pynative(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + all_ = x.all() + any_ = x.any() + all_0 = x.all(0, True) + any_0 = x.any(0, True) + return all_, any_, all_0, any_0 + + net = Net() + x = Tensor(np.array([[True, False, True], [True, False, False]])) + context.set_context(mode=context.PYNATIVE_MODE) + ret = net(x) + assert ret[0].asnumpy() == np.array(False) + assert ret[1].asnumpy() == np.array(True) + assert ret[2].asnumpy().shape == np.array([[True, False, False]]).shape + assert (ret[2].asnumpy() == np.array([[True, False, False]])).all() + assert ret[3].shape == Tensor(np.array([[True, False, True]])).shape + assert (ret[3] == Tensor(np.array([[True, False, True]]))).all() diff --git a/tests/ut/python/pipeline/infer/test_net_infer.py b/tests/ut/python/pipeline/infer/test_net_infer.py index 9c19f213f5..51bfcf87cd 100644 --- a/tests/ut/python/pipeline/infer/test_net_infer.py +++ b/tests/ut/python/pipeline/infer/test_net_infer.py @@ -66,5 +66,63 @@ def test_assign_in_while(): input_shape = (1024, 512) z = Tensor(np.random.randn(*input_shape).astype(np.float32)) net = Net(input_shape) - ret = net(x, y, z) - assert ret == z + net(x, y, z) + + +def test_dup_context(): + ''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and + Evaluator. + ''' + context.set_context(mode=context.GRAPH_MODE) + + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + def identity(f): + return f + + def func_with_fv(): + return x + + def net1(): + local_func = identity(func_with_fv) + out = local_func() + 20.0 + return out + + def net2(): + local_func = identity(func_with_fv) + out = local_func() + 15.0 + return out + + return net1() + net2() + + Net()(5.0) + + +def test_maybe_poly_func(): + ''' different func_with_fv in net1 and net2 may produce poly node. ''' + context.set_context(mode=context.GRAPH_MODE) + + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x, y, z): + def identity(f, inp): + return f(inp) + + def func_with_fv(yy): + return (x, yy) + + def make_call(): + out1 = identity(func_with_fv, y) + out2 = identity(func_with_fv, z) + return (out1, out2) + + return make_call() + + y_input = Tensor(np.array([1, 2]).astype(np.int32)) + z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) + Net()(1, y_input, z_input) diff --git a/tests/ut/python/pipeline/parse/test_cell_bprop.py b/tests/ut/python/pipeline/parse/test_cell_bprop.py deleted file mode 100644 index e896ddc9ac..0000000000 --- a/tests/ut/python/pipeline/parse/test_cell_bprop.py +++ /dev/null @@ -1,405 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test_cell_bprop """ -import numpy as np -import pytest - -import mindspore as ms -import mindspore.common.dtype as mstype -import mindspore.nn as nn -from mindspore import Parameter -from mindspore import context -from mindspore.common.initializer import initializer -from mindspore.common.tensor import Tensor -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from .....mindspore_test_framework.utils.bprop_util import bprop - - -def setup_module(module): - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - -def teardown_module(module): - context.set_context(device_target="Ascend") - -class MulAdd(nn.Cell): - def __init__(self): - super(MulAdd, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result - return 2 * dout, 2 * y - - -def test_grad_mul_add(): - mul_add = MulAdd() - x = Tensor(1, dtype=ms.int32) - y = Tensor(2, dtype=ms.int32) - assert C.grad_all(mul_add)(x, y) == (2, 4) - - -class InlineMulADD(nn.Cell): - def __init__(self): - super(InlineMulADD, self).__init__() - self.mul_add = MulAdd() - self.param = 2 - - def construct(self, x, y): - return self.mul_add(x, y) + x + self.param * y - - -def test_grad_inline_mul_add(): - inline_mul_add = InlineMulADD() - x = Tensor(1, dtype=ms.int32) - y = Tensor(2, dtype=ms.int32) - assert C.grad_all(inline_mul_add)(x, y) == (3, 6) - - -class WithParameter(nn.Cell): - def __init__(self): - super(WithParameter, self).__init__() - self.param1 = Parameter(1, 'param1') - self.param2 = Parameter(2, 'param2') - - def construct(self, x, y): - return self.param1 * self.param2 * x + y - - def bprop(self, x, y, out, dout): - # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result - return self.param1 * self.param2 * dout, 2 * y - - -def test_with_param(): - with_param = WithParameter() - with pytest.raises(RuntimeError): - C.grad_all(with_param)(1, 2) - - -class WithNoBprop(nn.Cell): - def __init__(self): - super(WithNoBprop, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - -def test_with_no_bprop(): - with_no_bprop = WithNoBprop() - x = Tensor(1, dtype=ms.int32) - y = Tensor(2, dtype=ms.int32) - assert C.grad_all(with_no_bprop)(x, y) == (2, 1) - - -def test_grad_in_bprop_1(): - class GradInBprop_1(nn.Cell): - def __init__(self): - super(GradInBprop_1, self).__init__() - self.relu = P.ReLU() - - def construct(self, x, y): - return self.relu(x) - - class GradInBprop_2(nn.Cell): - def __init__(self): - super(GradInBprop_2, self).__init__() - self.f = GradInBprop_1() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return out[1][0], grads[1] - - class GradInBprop_3(nn.Cell): - def __init__(self): - super(GradInBprop_3, self).__init__() - self.f = GradInBprop_2() - - def construct(self, x, y): - return self.f(x, y) - - grad_in_bprop = GradInBprop_3() - grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), - Tensor(np.ones([2, 2]).astype(np.float32))) - assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() - - -def test_grad_in_bprop_2(): - class GradInBprop_1(nn.Cell): - def __init__(self): - super(GradInBprop_1, self).__init__() - self.relu = P.ReLU() - - def construct(self, x, y): - return self.relu(x) - - def bprop(self, x, y, out, dout): - return x * y, y + x - - class GradInBprop_2(nn.Cell): - def __init__(self): - super(GradInBprop_2, self).__init__() - self.f = GradInBprop_1() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return out[1][0], grads[1] - - class GradInBprop_3(nn.Cell): - def __init__(self): - super(GradInBprop_3, self).__init__() - self.f = GradInBprop_2() - - def construct(self, x, y): - return self.f(x, y) - - grad_in_bprop = GradInBprop_3() - grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), - Tensor(np.ones([2, 2]).astype(np.float32))) - assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() - - -def test_grad_in_bprop_3(): - class GradInBprop_1(nn.Cell): - def __init__(self): - super(GradInBprop_1, self).__init__() - self.relu = P.ReLU() - - def construct(self, x, y): - return self.relu(x) - - class GradInBprop_2(nn.Cell): - def __init__(self): - super(GradInBprop_2, self).__init__() - self.f = GradInBprop_1() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return out[1][0], grads[1] - - class GradInBprop_3(nn.Cell): - def __init__(self): - super(GradInBprop_3, self).__init__() - self.f = GradInBprop_2() - - def construct(self, x, y): - return self.f(x, y) - - def bprop(self, x, y, out, dout): - return x + y + y + out[0], x + x + y + y + dout[0] - - grad_in_bprop = GradInBprop_3() - grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), - Tensor(np.ones([2, 2]).astype(np.float32))) - assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all() - - -class OneInputBprop(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.ReLU() - - def construct(self, x): - return self.op(x) - - def bprop(self, x, out, dout): - return (5 * x,) - - -def test_grad_one_input_bprop(): - net = OneInputBprop() - input1 = Tensor(np.ones([2, 2]).astype(np.float32)) - grad = C.grad_all(net)(input1) - assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all() - - -class TwoInput(nn.Cell): - def __init__(self): - super().__init__() - - def construct(self, x, y): - return x * y - - -class InlineBpropTwoInput(nn.Cell): - def __init__(self): - super().__init__() - self.f = TwoInput() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return grads[0] * 2, grads[1] * 2 - - -def test_grad_inline_bprop_two_input(): - net = InlineBpropTwoInput() - input1 = Tensor(np.ones([2, 2]).astype(np.float32)) - input2 = Tensor(np.ones([2, 2]).astype(np.float32)) - grads = C.grad_all(net)(input1, input2) - assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all() - assert len(grads) == 2 - - -class TwoInputBprop(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - - def construct(self, x, y): - return self.op(x, y) - - def bprop(self, x, y, out, dout): - return 5 * x, 8 * y - - -class TwoInputWithParameter(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") - - def construct(self, x, y): - x = self.inputdata + x - return self.op(x, y) - - -class TwoInputWithOnlyInitParameterBprop(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") - - def construct(self, x, y): - return self.op(x, y) - - def bprop(self, x, y, out, dout): - return 5 * x, 8 * y - - -class InlineMutilTwoInputParameterCell(nn.Cell): - def __init__(self): - super().__init__() - self.f1 = TwoInputBprop() - self.f2 = TwoInput() - self.f3 = TwoInputWithParameter() - self.f4 = TwoInputWithOnlyInitParameterBprop() - - def construct(self, x, y): - output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) - return output - - -def test_grad_inline_bprop_multi_input(): - net = InlineMutilTwoInputParameterCell() - input1 = Tensor(np.ones([2, 2]).astype(np.float32)) - input2 = Tensor(np.ones([2, 2]).astype(np.float32)) - net.init_parameters_data() - grads = C.grad_all(net)(input1, input2) - assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all() - assert len(grads) == 2 - - -class MulAddWithParam(nn.Cell): - def __init__(self): - super(MulAddWithParam, self).__init__() - self.mul_add = MulAdd() - self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param') - - def construct(self, x): - return self.mul_add(self.param, x) - - -def test_refkey_bprop(): - net = MulAddWithParam() - input_data = Tensor(np.array([2, 2], np.float32)) - grads = bprop(net, input_data, - grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))), - wrt=['params', 'inputs'], - params=net.trainable_params()) - assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() - assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() - - -class MulAddWithWrongOutputNum(nn.Cell): - def __init__(self): - super(MulAddWithWrongOutputNum, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - return (2 * dout,) - - -def test_grad_mul_add_with_wrong_output_num(): - context.set_context(check_bprop=True) - mul_add = MulAddWithWrongOutputNum() - with pytest.raises(TypeError): - C.grad_all(mul_add)(1, 2) - - -class MulAddWithWrongOutputType(nn.Cell): - def __init__(self): - super(MulAddWithWrongOutputType, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - return 2 * dout, 2 - - -def test_grad_mul_add_with_wrong_output_type(): - context.set_context(check_bprop=True) - mul_add = MulAddWithWrongOutputType() - with pytest.raises(TypeError): - C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) - - -class MulAddWithWrongOutputShape(nn.Cell): - def __init__(self): - super(MulAddWithWrongOutputShape, self).__init__() - self.ones = Tensor(np.ones([2,])) - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - return 2, self.ones - - -def test_grad_mul_add_with_wrong_output_shape(): - context.set_context(check_bprop=True) - mul_add = MulAddWithWrongOutputShape() - with pytest.raises(TypeError): - C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) diff --git a/tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py b/tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py new file mode 100644 index 0000000000..f46242d7f6 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test dtype and shape as attr""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore import dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE) + + +def test_dtype_and_shape_as_attr(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + shape = x.shape + dtype = x.dtype + return shape, dtype + + net = Net() + x = Tensor(np.ones([1, 2, 3], np.int32)) + ret = net(x) + assert ret == ((1, 2, 3), mstype.int32) + + +def test_dtype_and_shape_as_attr_to_new_tensor(): + class Net(nn.Cell): + def __init__(self, value): + super(Net, self).__init__() + self.fill = P.Fill() + self.value = value + + def construct(self, x): + dtype = x.dtype + shape = x.shape + y = self.fill(dtype, shape, self.value) + return y + + net = Net(2.2) + x = Tensor(np.ones([1, 2, 3], np.float32)) + ret = net(x) + assert (ret.asnumpy() == (np.zeros([1, 2, 3], np.float32) + 2.2)).all() + + +def test_type_not_have_the_attr(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + shape = x.shapes + return shape + + net = Net() + x = Tensor(np.ones([1, 2, 3], np.int32)) + with pytest.raises(RuntimeError) as ex: + net(x) + assert "The object of type: Tensor[Int32] has no method or attr: shapes" in str(ex.value) + + +def test_type_not_have_the_method(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + shape = x.dtypes() + return shape + + net = Net() + x = Tensor(np.ones([1, 2, 3], np.int32)) + with pytest.raises(RuntimeError) as ex: + net(x) + assert "The object of type: Tensor[Int32] has no method or attr: dtypes" in str(ex.value) diff --git a/tests/ut/python/pipeline/parse/test_parse.py b/tests/ut/python/pipeline/parse/test_parse.py index b295adcbec..5f5e3659b0 100644 --- a/tests/ut/python/pipeline/parse/test_parse.py +++ b/tests/ut/python/pipeline/parse/test_parse.py @@ -175,7 +175,7 @@ def test_bprop_with_wrong_output_num(): def construct(self, x, y): return BpropWithWrongOutputNum()(x, y) - with pytest.raises(TypeError): + with pytest.raises(ValueError): C.grad_all(BpropWithWrongOutputNumCell())(1, 2) def test_bprop_with_wrong_output_type(): @@ -247,7 +247,7 @@ def test_bprop_with_wrong_output_shape(): def construct(self, x): return BpropWithWrongOutputShape()(x) - with pytest.raises(TypeError): + with pytest.raises(ValueError): net = BpropWithWrongOutputShapeCell() net.set_grad() C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) diff --git a/tests/ut/python/pipeline/parse/test_super.py b/tests/ut/python/pipeline/parse/test_super.py new file mode 100644 index 0000000000..f8734584ad --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_super.py @@ -0,0 +1,144 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test super""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +class FatherNet(nn.Cell): + def __init__(self, x): + super(FatherNet, self).__init__(x) + self.x = x + + def construct(self, x, y): + return self.x * x + + def test_father(self, x): + return self.x + x + + +class MatherNet(nn.Cell): + def __init__(self, y): + super(MatherNet, self).__init__() + self.y = y + + def construct(self, x, y): + return self.y * y + + def test_mather(self, y): + return self.y + y + + +class SingleSubNet(FatherNet): + def __init__(self, x, z): + super(SingleSubNet, self).__init__(x) + self.z = z + + def construct(self, x, y): + ret_father_construct = super().construct(x, y) + ret_father_test = super(SingleSubNet, self).test_father(x) + ret_father_x = super(SingleSubNet, self).x + ret_sub_z = self.z + + return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z + + +class MulSubNet(FatherNet, MatherNet): + def __init__(self, x, y, z): + super(MulSubNet, self).__init__(x) + super(FatherNet, self).__init__(y) + self.z = z + + def construct(self, x, y): + ret_father_construct = super().construct(x, y) + ret_father_test = super(MulSubNet, self).test_father(x) + ret_father_x = super(MulSubNet, self).x + ret_mather_construct = super(FatherNet, self).construct(x, y) + ret_mather_test = super(FatherNet, self).test_mather(y) + ret_mather_y = super(FatherNet, self).y + ret_sub_z = self.z + + return ret_father_construct, ret_father_test, ret_father_x, \ + ret_mather_construct, ret_mather_test, ret_mather_y, ret_sub_z + + +class Net(nn.Cell): + def __init__(self, x): + super(Net, self).__init__() + self.x = x + + def construct(self, x, y): + ret = super(Net, self).construct(x, y) + return ret + + +def test_single_super(): + single_net = SingleSubNet(2, 3) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + single_net(x, y) + + +def test_mul_super(): + mul_net = MulSubNet(2, 3, 4) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + mul_net(x, y) + + +def test_super_cell(): + net = Net(2) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + with pytest.raises(RuntimeError) as er: + net(x, y) + assert "Unsupported syntax 'Raise'" in str(er.value) + + +def test_single_super_in(): + class FatherNetIn(nn.Cell): + def __init__(self, x): + super(FatherNetIn, self).__init__(x) + self.x = x + + def construct(self, x, y): + return self.x * x + + def test_father(self, x): + return self.x + x + + class SingleSubNetIN(FatherNetIn): + def __init__(self, x, z): + super(SingleSubNetIN, self).__init__(x) + self.z = z + + def construct(self, x, y): + ret_father_construct = super().construct(x, y) + ret_father_test = super(SingleSubNetIN, self).test_father(x) + ret_father_x = super(SingleSubNetIN, self).x + ret_sub_z = self.z + + return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z + + single_net_in = SingleSubNetIN(2, 3) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + single_net_in(x, y) diff --git a/tests/ut/python/predict/test_predict_save_model.py b/tests/ut/python/predict/test_predict_save_model.py deleted file mode 100644 index f57875d073..0000000000 --- a/tests/ut/python/predict/test_predict_save_model.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Function: - test network -Usage: - python test_predict_save_model.py --path ./ -""" - -import argparse -import os -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations as P -from mindspore.common.tensor import Tensor -from mindspore.train.serialization import export, load_checkpoint, load_param_into_net - - -class LeNet(nn.Cell): - def __init__(self): - super(LeNet, self).__init__() - self.relu = P.ReLU() - self.batch_size = 32 - - self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() - self.fc1 = nn.Dense(400, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, 10) - - def construct(self, input_x): - output = self.conv1(input_x) - output = self.relu(output) - output = self.pool(output) - output = self.conv2(output) - output = self.relu(output) - output = self.pool(output) - output = self.reshape(output, (self.batch_size, -1)) - output = self.fc1(output) - output = self.relu(output) - output = self.fc2(output) - output = self.relu(output) - output = self.fc3(output) - return output - - -parser = argparse.ArgumentParser(description='MindSpore Model Save') -parser.add_argument('--path', default='./lenet_model.ms', type=str, help='model save path') - -if __name__ == '__main__': - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - print("test lenet predict start") - seed = 0 - np.random.seed(seed) - batch = 32 - channel = 1 - input_h = 32 - input_w = 32 - origin_data = np.random.uniform(low=0, high=255, size=(batch, channel, input_h, input_w)).astype(np.float32) - origin_data.tofile("lenet_input_data.bin") - - input_data = Tensor(origin_data) - print(input_data.asnumpy()) - net = LeNet() - ckpt_file_path = "./tests/ut/python/predict/checkpoint_lenet.ckpt" - predict_args = parser.parse_args() - model_path_name = predict_args.path - - is_ckpt_exist = os.path.exists(ckpt_file_path) - if is_ckpt_exist: - param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path) - load_param_into_net(net, param_dict) - export(net, input_data, file_name=model_path_name, file_format='LITE') - print("test lenet predict success.") - else: - print("checkpoint file is not exist.") diff --git a/tests/ut/python/profiler/__init__.py b/tests/ut/python/profiler/__init__.py new file mode 100644 index 0000000000..586589132c --- /dev/null +++ b/tests/ut/python/profiler/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Unit test for profiler.""" +import os + +RAW_DATA_BASE = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../data/profiler_data')) +RAW_DATA = os.path.realpath(os.path.join(RAW_DATA_BASE, 'JOB1')) +RAW_DATA_JOB2 = os.path.realpath(os.path.join(RAW_DATA_BASE, 'JOB2')) +PROFILER_DIR = os.path.realpath(os.path.join(RAW_DATA_BASE, 'profiler')) diff --git a/tests/ut/python/profiler/parser/__init__.py b/tests/ut/python/profiler/parser/__init__.py new file mode 100644 index 0000000000..e30774307c --- /dev/null +++ b/tests/ut/python/profiler/parser/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/tests/ut/python/profiler/parser/test_aicpu_parser.py b/tests/ut/python/profiler/parser/test_aicpu_parser.py new file mode 100644 index 0000000000..47e3bc13aa --- /dev/null +++ b/tests/ut/python/profiler/parser/test_aicpu_parser.py @@ -0,0 +1,74 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test the aicpu parser.""" +import os +import tempfile +import shutil + +from unittest import TestCase + +from mindspore.profiler.parser.aicpu_data_parser import DataPreProcessParser + + +def get_result(file_path): + """ + Get result from the aicpu file. + + Args: + file_path (str): The aicpu file path. + + Returns: + list[list], the parsed aicpu information. + """ + result = [] + try: + file = open(file_path, 'r') + result.append(file.read()) + return result + finally: + if file: + file.close() + + +class TestAicpuParser(TestCase): + """Test the class of Aicpu Parser.""" + + def setUp(self) -> None: + """Initialization before test case execution.""" + self.profiling_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), + '../../../data/profiler_data/' + 'JOB_AICPU/data')) + self.expect_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), + '../../../data/profiler_data/' + 'JOB_AICPU/expect')) + self.output_path = tempfile.mkdtemp(prefix='output_data_preprocess_aicpu_') + self.output_file = os.path.join(self.output_path, 'output_data_preprocess_aicpu_0.txt') + self.expect_file = os.path.join(self.expect_dir, 'output_data_preprocess_aicpu_0.txt') + + def test_aicpu_parser(self): + """Test the class of Aicpu Parser.""" + data = DataPreProcessParser(self.profiling_dir, self.output_file) + data.execute() + expect_result = get_result(self.expect_file) + result = get_result(self.output_file) + shutil.rmtree(self.output_path) + assert expect_result == result + + def test_aicpu_parser_file_not_exist(self): + """Test the class of Aicpu Parser.""" + profiling_dir = os.path.realpath(os.path.join(self.profiling_dir, 'data')) + data = DataPreProcessParser(profiling_dir, self.output_file) + data.execute() + shutil.rmtree(self.output_path) diff --git a/tests/ut/python/profiler/parser/test_framework_parser.py b/tests/ut/python/profiler/parser/test_framework_parser.py new file mode 100644 index 0000000000..d37bd19dd9 --- /dev/null +++ b/tests/ut/python/profiler/parser/test_framework_parser.py @@ -0,0 +1,128 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test the framework parser module.""" +import csv +import os +import shutil +import tempfile +from unittest import mock + +import pytest + +from mindspore.profiler.common.exceptions.exceptions import \ + ProfilerFileNotFoundException +from mindspore.profiler.parser.framework_parser import FrameworkParser +from tests.ut.python.profiler import PROFILER_DIR, RAW_DATA_BASE + + +def get_framework_result(file_path): + """ + Get framework result from the framework file. + + Args: + file_path (str): The framework file path. + + Returns: + list[list], the parsed framework information. + """ + result = [] + with open(file_path, 'r') as file: + csv_reader = csv.reader(file) + for row in csv_reader: + result.append(row) + return result + + +class TestFrameworkParser: + """Test the class of `FrameworkParser`.""" + def setup_method(self): + """Initialization before test case execution.""" + with mock.patch.object(FrameworkParser, '_raw_data_dir', RAW_DATA_BASE): + self._output_path_1 = tempfile.mkdtemp(prefix='test_framework_parser_') + self._parser_1 = FrameworkParser('JOB1', '0', self._output_path_1) + self._output_path_2 = tempfile.mkdtemp(prefix='test_framework_parser_') + self._parser_2 = FrameworkParser('JOB2', '0', self._output_path_2) + self._output_path_4 = tempfile.mkdtemp(prefix='test_framework_parser_') + self._parser_4 = FrameworkParser('JOB4', '0', self._output_path_4) + + def teardown_method(self) -> None: + """Clear up after test case execution.""" + shutil.rmtree(self._output_path_1) + shutil.rmtree(self._output_path_2) + shutil.rmtree(self._output_path_4) + + def test_save_path(self): + """Test the querying save path function.""" + expect_result = os.path.join(self._output_path_1, 'framework_raw_0.csv') + assert expect_result == self._parser_1.save_path + + expect_result = os.path.join(self._output_path_2, 'framework_raw_0.csv') + assert expect_result == self._parser_2.save_path + + def test_point_info(self): + """Test the querying point info function.""" + expect_result = { + 1: 'Default/Cast-op6', + 2: 'Default/TransData-op7' + } + assert expect_result == self._parser_4.point_info + + def test_to_task_id_full_op_name_dict(self): + """Test the querying task id and full operator name dict function.""" + expect_result = { + '51517': 'Default/Cast-op6', + '51518': 'Default/TransData-op7', + '51519': 'Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5', + '51522': 'Default/network-WithLossCell/_backbone-ResNet/' + 'layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28' + } + assert expect_result == self._parser_1.to_task_id_full_op_name_dict() + assert expect_result == self._parser_2.to_task_id_full_op_name_dict() + + expect_result = { + '0_1': 'Default/Cast-op6', + '0_2': 'Default/TransData-op7', + '0_3': 'Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5', + '0_4': 'Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/' + '0-ResidualBlock/conv1-Conv2d/Cast-op28' + } + assert expect_result == self._parser_4.to_task_id_full_op_name_dict() + + def test_parse(self): + """Test the parse function.""" + expect_framework_file = os.path.join(PROFILER_DIR, 'framework_raw_0.csv') + expect_framework_file = os.path.realpath(expect_framework_file) + expect_result = get_framework_result(expect_framework_file) + + self._parser_1.parse() + framework_file = os.path.join(self._output_path_1, 'framework_raw_0.csv') + result = get_framework_result(framework_file) + assert expect_result == result + + self._parser_2.parse() + framework_file = os.path.join(self._output_path_2, 'framework_raw_0.csv') + result = get_framework_result(framework_file) + assert expect_result == result + + @mock.patch('os.listdir') + @mock.patch('os.path.isdir') + def test_create_framework_parser_fail_1(self, *args): + """Test the function of fail to create framework parser.""" + args[0].return_value = True + args[1].return_value = [] + with pytest.raises(ProfilerFileNotFoundException) as exc_info: + FrameworkParser('JOB1', '0') + assert exc_info.value.error_code == '50546084' + assert exc_info.value.message == 'The file not found.' diff --git a/tests/ut/python/profiler/parser/test_minddata_pipeline_parser.py b/tests/ut/python/profiler/parser/test_minddata_pipeline_parser.py new file mode 100644 index 0000000000..65ca7d7717 --- /dev/null +++ b/tests/ut/python/profiler/parser/test_minddata_pipeline_parser.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test the minddata pipeline parser module.""" +import csv +import os +import shutil +import tempfile + +from mindspore.profiler.parser.minddata_pipeline_parser import \ + MinddataPipelineParser +from tests.ut.python.profiler import PROFILER_DIR, RAW_DATA, RAW_DATA_JOB2 + + +def get_minddata_pipeline_result(file_path): + """ + Get minddata pipeline result from the minddata pipeline file. + + Args: + file_path (str): The minddata pipeline file path. + + Returns: + list[list], the parsed minddata pipeline information. + """ + result = [] + with open(file_path, 'r') as file: + csv_reader = csv.reader(file) + for row in csv_reader: + result.append(row) + return result + + +class TestMinddataPipelineParser: + """Test the class of `MinddataPipelineParser`.""" + def setup_method(self): + """Initialization before test case execution.""" + self._output_path_1 = tempfile.mkdtemp( + prefix='test_minddata_pipeline_parser_' + ) + self._parser_1 = MinddataPipelineParser( + RAW_DATA, '0', self._output_path_1 + ) + + self._output_path_2 = tempfile.mkdtemp( + prefix='test_minddata_pipeline_parser_' + ) + self._parser_2 = MinddataPipelineParser( + RAW_DATA_JOB2, '0', self._output_path_2 + ) + + def teardown_method(self) -> None: + """Clear up after test case execution.""" + shutil.rmtree(self._output_path_1) + shutil.rmtree(self._output_path_2) + + def test_save_path(self): + """Test the querying save path function.""" + expect_result = os.path.join( + self._output_path_1, 'minddata_pipeline_raw_0.csv' + ) + assert expect_result == self._parser_1.save_path + + def test_parse(self): + """Test the parse function.""" + expect_pipeline_file = os.path.join( + PROFILER_DIR, 'minddata_pipeline_raw_0.csv' + ) + expect_result = get_minddata_pipeline_result(expect_pipeline_file) + + self._parser_1.parse() + pipeline_file = os.path.join( + self._output_path_1, 'minddata_pipeline_raw_0.csv' + ) + result = get_minddata_pipeline_result(pipeline_file) + assert expect_result == result + + self._parser_2.parse() + pipeline_file = os.path.join( + self._output_path_2, 'minddata_pipeline_raw_0.csv' + ) + result = get_minddata_pipeline_result(pipeline_file) + assert expect_result == result diff --git a/tests/ut/python/pynative_mode/ge/ops/test_tensor_add.py b/tests/ut/python/pynative_mode/ge/ops/test_tensor_add.py index 0ea8ff7250..e530aa201a 100644 --- a/tests/ut/python/pynative_mode/ge/ops/test_tensor_add.py +++ b/tests/ut/python/pynative_mode/ge/ops/test_tensor_add.py @@ -39,5 +39,5 @@ def test_tensor_orign_ops(): assert np.all(z.asnumpy() - (x.asnumpy() + y.asnumpy()) < 0.0001) z = x * y assert np.all(z.asnumpy() - (x.asnumpy() * y.asnumpy()) < 0.0001) - assert x == y + assert np.all(x.asnumpy() == y.asnumpy()) assert x != 'zero' diff --git a/tests/ut/python/pynative_mode/nn/test_batchnorm.py b/tests/ut/python/pynative_mode/nn/test_batchnorm.py index 08f84f2fe3..a0fda8bd5e 100644 --- a/tests/ut/python/pynative_mode/nn/test_batchnorm.py +++ b/tests/ut/python/pynative_mode/nn/test_batchnorm.py @@ -67,10 +67,10 @@ def test_bn2d(): def test_bn1d(): """ut of nn.BatchNorm1d""" bn = nn.BatchNorm1d(3) - input_data = Tensor(np.random.randint(0, 1, [1, 3, 100, 100]).astype(np.float32)) + input_data = Tensor(np.random.randint(0, 1, [1, 3]).astype(np.float32)) output = bn(input_data) output_np = output.asnumpy() - assert isinstance(output_np[0][0][0][0], (np.float32, np.float64)) + assert isinstance(output_np[0][0], (np.float32, np.float64)) def test_bn2d_train(): diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py index f028e91beb..475b48f524 100644 --- a/tests/ut/python/pynative_mode/ops/test_grad.py +++ b/tests/ut/python/pynative_mode/ops/test_grad.py @@ -17,7 +17,7 @@ import numpy as np import mindspore as ms import mindspore.ops.operations as P -from mindspore import Tensor +from mindspore import Tensor, context from mindspore.common.api import ms_function from mindspore.common.dtype import get_py_obj_dtype from mindspore.ops import composite as C @@ -25,6 +25,9 @@ from mindspore.ops import functional as F from mindspore.ops.composite import grad_all_with_sens from ...ut_filter import non_graph_engine +# pylint: disable=unused-argument +def setup_module(module): + context.set_context(mode=context.PYNATIVE_MODE) def mul(x, y): return x * y @@ -41,7 +44,7 @@ def test_grad(): @non_graph_engine -def test_expand_dims_grad(): +def Xtest_expand_dims_grad(): """ test_expand_dims_grad """ input_tensor = Tensor(np.array([[2, 2], [2, 2]])) expand_dims = P.ExpandDims() diff --git a/tests/ut/python/pynative_mode/ops/test_multitype.py b/tests/ut/python/pynative_mode/ops/test_multitype.py index 4ceff1d4ac..24d7edcc0b 100644 --- a/tests/ut/python/pynative_mode/ops/test_multitype.py +++ b/tests/ut/python/pynative_mode/ops/test_multitype.py @@ -57,7 +57,7 @@ def test_multitype_tuple(): params1 = Parameter(tensor1, name="params1") tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) output = op_add((params1, tensor2)) - assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32')) + assert np.all(output.asnumpy() == np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32')) def test_multitype_scalar(): diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index 3b99d0dc5f..71c72491bb 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -380,7 +380,7 @@ def test_while_net(): x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32)) z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32)) res = t1_while(x, y, z) - assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0) + assert np.all(res.asnumpy() == np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0) @ms_function @@ -403,7 +403,7 @@ def test_if_while(): x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32)) z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32)) res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z) - assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0) + assert np.all(res.asnumpy() == np.ones([64, 10]).astype(np.float32) * 4.0) def _while(x): @@ -550,7 +550,7 @@ def test_zeros(): """ test_zeros """ x = Tensor(np.ones([2, 3]).astype(np.int32)) res = zero_like_tensor(x) - assert res == Tensor(np.zeros([2, 3]).astype(np.int32)) + assert np.all(res.asnumpy() == np.zeros([2, 3]).astype(np.int32)) @ms_function @@ -811,7 +811,7 @@ def test_while_sp(): z = Tensor(np.ones([1, 3]).astype(np.float32)) x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0) res = while_sp(x, y, z) - assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0) + assert np.all(res.asnumpy() == np.ones([1, 3]).astype(np.float32) * 1024.0) def grad_refactor_simple_1(x, y): @@ -1030,7 +1030,7 @@ def test_grad_if_defer_inline(): network.add_flags(defer_inline=False) inp = Tensor(np.ones([128, 96]).astype(np.float32)) grads = C.grad_all(network)(inp) - assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) + assert np.all(grads[0].asnumpy() == np.full([128, 96], 0.6, dtype=np.float32)) def test_dict_const(): diff --git a/tests/ut/python/pynative_mode/test_high_order_grad.py b/tests/ut/python/pynative_mode/test_high_order_grad.py index 97fe7c3b68..71a7dda94d 100644 --- a/tests/ut/python/pynative_mode/test_high_order_grad.py +++ b/tests/ut/python/pynative_mode/test_high_order_grad.py @@ -19,7 +19,7 @@ from mindspore.ops.composite import grad, grad_all, grad_all_with_sens def setup_module(module): - context.set_context(mode=context.PYNATIVE_MODE) + context.set_context(mode=context.PYNATIVE_MODE, check_bprop=False) def single(x): diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py index f34a81ab5c..8c5fc9d42e 100644 --- a/tests/ut/python/pynative_mode/test_hook.py +++ b/tests/ut/python/pynative_mode/test_hook.py @@ -152,7 +152,7 @@ def test_hook(): assert cell_hook_done assert var_hook_done assert cell_bprop_done - print(loss_output.asnumpy().shape) + print(loss_output.asnumpy()) bprop_debug = False diff --git a/tests/ut/python/pynative_mode/test_implicit_conversion.py b/tests/ut/python/pynative_mode/test_implicit_conversion.py index ecaffd87f2..3a19732462 100644 --- a/tests/ut/python/pynative_mode/test_implicit_conversion.py +++ b/tests/ut/python/pynative_mode/test_implicit_conversion.py @@ -14,6 +14,7 @@ # ============================================================================ """ test implicit conversion """ import numpy as np +import pytest from mindspore import Tensor, nn from mindspore.ops import composite as C @@ -90,6 +91,30 @@ def test_float_tensor_and_bool_tensors_add(): assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() +def test_float_tensor_and_str_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = "ok" + with pytest.raises(TypeError) as er: + ret = x + y + assert "For 'TensorAdd', the 1th input is a not support type: str" in str(er.value) + + +def test_float_tensor_and_tuple_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = (1, 2, 3) + with pytest.raises(TypeError) as er: + ret = x + y + assert "For 'TensorAdd', the 1th input is a not support type: tuple" in str(er.value) + + +def test_float_tensor_and_list_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = [1, 2, 3] + with pytest.raises(TypeError) as er: + ret = x + y + assert "For 'TensorAdd', the 1th input is a not support type: list" in str(er.value) + + def test_float_tensor_and_bool_tensors_add_grad(): class Net(nn.Cell): def __init__(self): @@ -104,7 +129,6 @@ def test_float_tensor_and_bool_tensors_add_grad(): self.net = net def construct(self, x, y, sens): - return C.grad_all_with_sens(self.net)(x, y, sens) x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) @@ -133,7 +157,6 @@ def test_float_tensor_and_int_tensors_sub_grad(): self.net = net def construct(self, x, y, sens): - return C.grad_all_with_sens(self.net)(x, y, sens) x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) @@ -163,7 +186,6 @@ def test_float16_tensor_and_float32_tensors_sub_grad(): self.net = net def construct(self, x, y, sens): - return C.grad_all_with_sens(self.net)(x, y, sens) x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.int32)) diff --git a/tests/ut/python/pynative_mode/test_multigraph_sink.py b/tests/ut/python/pynative_mode/test_multigraph_sink.py index e8ebe03797..c4ef44ef5a 100644 --- a/tests/ut/python/pynative_mode/test_multigraph_sink.py +++ b/tests/ut/python/pynative_mode/test_multigraph_sink.py @@ -131,3 +131,26 @@ def test_while_in_while(): output = while_in_while(c1, c2, c3) expect = Tensor([1274], mstype.int32) assert output == expect + + +@ms_function +def while_by_while_in_while(x, y, z): + out = c4 + while x < c2: + y = c4 + c4 + while y < c2: + y = y + 1 + out = out + y + z = c4 + c4 + while z < c2: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + return out + + +def test_while_by_while_in_while(): + output = while_by_while_in_while(c1, c2, c3) + expect = Tensor([350], mstype.int32) + assert output == expect diff --git a/tests/ut/python/pynative_mode/test_parse_method.py b/tests/ut/python/pynative_mode/test_parse_method.py index f189b825e9..0a8c1767db 100644 --- a/tests/ut/python/pynative_mode/test_parse_method.py +++ b/tests/ut/python/pynative_mode/test_parse_method.py @@ -304,6 +304,29 @@ def test_access(): """ test_access """ invoke_dataclass(1, 2) +@dataclass +class Access2: + a: int + b: int + + def max(self): + if self.a > self.b: + return self.c + return self.b + + +@ms_function +def invoke_dataclass2(x, y): + """ invoke_dataclass """ + acs = Access2(x, y) + return acs.max() + + +def test_access_attr_error(): + """ test_access """ + with pytest.raises(AttributeError): + invoke_dataclass2(1, 2) + def myfunc(x): """ myfunc """ diff --git a/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py b/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py index 5cc2ce35cc..35e0687b9d 100644 --- a/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py +++ b/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py @@ -20,7 +20,6 @@ from numpy.random import normal from mindspore import Tensor from mindspore import context from mindspore.common.api import ms_function -from mindspore.ops.composite import core def setup_module(module): @@ -34,7 +33,6 @@ def test_remove_phi_and_fv(): """ test_remove_phi_and_fv """ @ms_function - @core(loop_can_unroll=True) def loop(x, input_data): def fv_func(y): return x * y @@ -60,7 +58,6 @@ def test_remove_multiple_phi(): """ test_remove_multiple_phi """ @ms_function - @core(loop_can_unroll=True) def loop(x): def mul(a, b): return a * b @@ -83,7 +80,6 @@ def test_remove_multiple_phi_recursive(): """ test_remove_multiple_phi_recursive """ @ms_function - @core(loop_can_unroll=True) def loop(x): def mul(a, b): return a * b diff --git a/tests/ut/python/pynative_mode/test_sparse_pynative.py b/tests/ut/python/pynative_mode/test_sparse_pynative.py new file mode 100644 index 0000000000..4d9db16cb7 --- /dev/null +++ b/tests/ut/python/pynative_mode/test_sparse_pynative.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +@File : test_sparse_pynative.py +@Author: +@Date : 2020-08-04 +@Desc : test mindspore sparse pynative +""" +import mindspore as ms +import mindspore.nn as nn +from mindspore import context, Tensor, RowTensor, SparseTensor +from mindspore.ops import composite as C + +context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) + + +grad_all = C.GradOperation('get_all', get_all=True) +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + def construct(self, *args): + grad = grad_all(self.network)(*args) + return grad + + +def test_row_tensor_attr(): + class RowTensorGetAttr(nn.Cell): + def __init__(self, dense_shape): + super(RowTensorGetAttr, self).__init__() + self.dense_shape = dense_shape + def construct(self, indices, values): + x = RowTensor(indices, values, self.dense_shape) + return x.values, x.indices, x.dense_shape + indices = Tensor([0]) + values = Tensor([[1, 2]], dtype=ms.float32) + RowTensorGetAttr((3, 2))(indices, values) + GradWrap(RowTensorGetAttr((3, 2)))(indices, values) + + +def test_sparse_tensor_attr(): + class SparseTensorGetAttr(nn.Cell): + def __init__(self): + super(SparseTensorGetAttr, self).__init__() + self.dense_shape = (3, 4) + def construct(self, indices, values): + x = SparseTensor(indices, values, self.dense_shape) + return x.values, x.indices, x.dense_shape + + indices = Tensor([[0, 1], [1, 2]]) + values = Tensor([1, 2], dtype=ms.float32) + SparseTensorGetAttr()(indices, values) + GradWrap(SparseTensorGetAttr())(indices, values) diff --git a/tests/ut/python/pynative_mode/test_stop_gradient.py b/tests/ut/python/pynative_mode/test_stop_gradient.py index 09e4f25c54..36ada9f5c2 100644 --- a/tests/ut/python/pynative_mode/test_stop_gradient.py +++ b/tests/ut/python/pynative_mode/test_stop_gradient.py @@ -256,7 +256,7 @@ def test_stop_gradient_4(): def stop_test(x): return stop_gradient(x) - assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,) + assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,) def test_stop_gradient_5(): diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index 39e887170c..10e7250fcf 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -20,7 +20,7 @@ import mindspore.context as context from mindspore import Tensor from mindspore import nn from mindspore.train.quant import quant as qat -from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 +from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -69,17 +69,24 @@ def test_qat_lenet(): net = qat.convert_quant_network( net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) # should load the checkpoint. mock here - for param in net.get_parameters(): - param.init_data() + net.init_parameters_data() qat.export(net, img, file_name="quant.pb") @pytest.mark.skip(reason="no `te.lang.cce` in ut env") -def test_qat_mobile(): +def test_qat_mobile_per_channel_tf(): network = mobilenetV2(num_classes=1000) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) # should load the checkpoint. mock here - for param in network.get_parameters(): - param.init_data() + network.init_parameters_data() + qat.export(network, img, file_name="quant.pb") + +@pytest.mark.skip(reason="no `te.lang.cce` in ut env") +def test_qat_mobile_per_channel_ff(): + network = mobilenetV2(num_classes=1000) + img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) + network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False]) + # should load the checkpoint. mock here + network.init_parameters_data() qat.export(network, img, file_name="quant.pb") diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index 31552e44bd..48d79a80dc 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -294,10 +294,7 @@ class TestSummaryCollector: summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) assert summary_collector._is_parse_loss_success - assert summary_collector._get_loss(cb_params) == expected_loss - if expected_loss is None: - assert not summary_collector._is_parse_loss_success def test_get_optimizer_from_cb_params_success(self): """Test get optimizer success from cb params.""" @@ -381,7 +378,6 @@ class TestSummaryCollector: result = get_value() assert PluginEnum.HISTOGRAM.value == result[0][0] assert expected_names == [data[1] for data in result] - assert expected_values == [data[2] for data in result] @pytest.mark.parametrize("specified_data, action, expected_result", [ (None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA), diff --git a/tests/ut/python/train/test_dataset_helper.py b/tests/ut/python/train/test_dataset_helper.py new file mode 100644 index 0000000000..6540adfe12 --- /dev/null +++ b/tests/ut/python/train/test_dataset_helper.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test dataset helper.""" + +import pytest +import numpy as np +import mindspore.context as context +from mindspore.communication.management import init +from mindspore.train.dataset_helper import DatasetHelper +from ....dataset_mock import MindData + + +def get_dataset(batch_size=1): + dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32) + dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1), + (batch_size, 20), (batch_size, 20), (batch_size, 20)) + + dataset = MindData(size=2, batch_size=batch_size, np_types=dataset_types, + output_shapes=dataset_shapes, input_indexs=(0, 1)) + return dataset + + +def test_dataset_helper_dataset_sink_mode_str(): + dataset = get_dataset(32) + with pytest.raises(TypeError): + DatasetHelper(dataset, dataset_sink_mode="True") + + +def test_dataset_helper_dataset_sink_mode_int(): + dataset = get_dataset(32) + with pytest.raises(TypeError): + DatasetHelper(dataset, dataset_sink_mode=1) + + +def test_dataset_helper_sink_size_bool(): + dataset = get_dataset(32) + with pytest.raises(TypeError): + DatasetHelper(dataset, dataset_sink_mode=True, sink_size=True) + + +def test_dataset_helper_sink_size_float(): + dataset = get_dataset(32) + with pytest.raises(TypeError): + DatasetHelper(dataset, dataset_sink_mode=True, sink_size=1.0) + + +def test_dataset_helper_sink_size_negative(): + dataset = get_dataset(32) + with pytest.raises(ValueError): + DatasetHelper(dataset, dataset_sink_mode=True, sink_size=-2) + + +def test_dataset_iter_normal(): + dataset = get_dataset(32) + dataset_helper = DatasetHelper(dataset, dataset_sink_mode=False) + count = 0 + for _ in range(2): + for _ in dataset_helper: + count += 1 + dataset.reset() + assert count == 6 + + +@pytest.mark.skipif('not context.get_context("enable_ge")') +def test_dataset_iter_ge(): + init() + dataset = get_dataset(32) + dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) + count = 0 + for _ in range(2): + for _ in dataset_helper: + count += 1 + assert count == 2 + + +@pytest.mark.skipif('context.get_context("enable_ge")') +def test_dataset_iter_ms_loop_sink(): + init() + context.set_context(enable_loop_sink=True) + dataset = get_dataset(32) + dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) + count = 0 + for _ in range(2): + for inputs in dataset_helper: + count += 1 + assert inputs == tuple() + assert count == 2 + + +@pytest.mark.skipif('context.get_context("enable_ge")') +def test_dataset_iter_ms(): + init() + context.set_context(enable_loop_sink=False) + dataset = get_dataset(32) + DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) diff --git a/tests/vm_impl/array_ops_vm_impl.py b/tests/vm_impl/array_ops_vm_impl.py index 91493a4c07..921d5c5182 100644 --- a/tests/vm_impl/array_ops_vm_impl.py +++ b/tests/vm_impl/array_ops_vm_impl.py @@ -194,7 +194,19 @@ def vm_impl_all(self): def vm_impl(x, axis): x = x.asnumpy() - out = vm.all(x, axis) + out = vm.all(x, axis, self.keep_dims) + return Tensor(out) + + return vm_impl + + +@vm_impl_getters.register(P.ReduceAny) +def vm_impl_any(self): + """Generate vm_impl function for Any""" + + def vm_impl(x, axis): + x = x.asnumpy() + out = vm.any(x, axis, self.keep_dims) return Tensor(out) return vm_impl diff --git a/tests/vm_impl/vm_interface.py b/tests/vm_impl/vm_interface.py index 2400b78902..fd4ab30d96 100644 --- a/tests/vm_impl/vm_interface.py +++ b/tests/vm_impl/vm_interface.py @@ -67,3 +67,5 @@ setattr(vm, "tanh", tanh) setattr(vm, "sigmoid", sigmoid) setattr(vm, 'maximum', maximum) setattr(vm, 'minimum', minimum) +setattr(vm, 'all', all_) +setattr(vm, 'any', any_) diff --git a/tests/vm_impl/vm_me.py b/tests/vm_impl/vm_me.py index 7216ec613b..58558ffa0f 100644 --- a/tests/vm_impl/vm_me.py +++ b/tests/vm_impl/vm_me.py @@ -554,9 +554,7 @@ def softmax_cross_entropy_with_logits(logits, labels): sample_num = labels.shape[0] prob = softmax(logits) log_likelihood = -np.log(prob[range(sample_num)]) * labels - # loss = np.sum(log_likelihood) - loss = log_likelihood - + loss = np.sum(log_likelihood) dx = prob.copy() dx[range(sample_num)] -= labels return loss, dx @@ -842,3 +840,35 @@ def minimum(x, y): numpy.ndarray, has the same type as x. """ return np.minimum(x, y) + + +def all_(x, axis=(), keep_dims=False): + """ + Check all array elements along a given axis evaluate to True. + + Args: + x (numpy.ndarray): An array to be reduced. + axis (Union[None, int, tuple(int)): Dimensions of reduction. + keep_dims (bool): Whether to keep the reduced dimensions. + + Returns: + numpy.ndarray, has the same type as x. + """ + axis = None if axis == () else axis + return np.all(x, axis, keepdims=keep_dims) + + +def any_(x, axis=(), keep_dims=False): + """ + Check any array element along a given axis evaluate to True. + + Args: + x (numpy.ndarray): An array to be reduced. + axis (Union[None, int, tuple(int)): Dimensions of reduction. + keep_dims (bool): Whether to keep the reduced dimensions. + + Returns: + numpy.ndarray, has the same type as x. + """ + axis = None if axis == () else axis + return np.any(x, axis, keepdims=keep_dims) diff --git a/third_party/OpenCL-CLHPP b/third_party/OpenCL-CLHPP new file mode 160000 index 0000000000..524f5ca96c --- /dev/null +++ b/third_party/OpenCL-CLHPP @@ -0,0 +1 @@ +Subproject commit 524f5ca96c3b9775f9d1debbdbcc2666bcce5c07 diff --git a/third_party/OpenCL-Headers b/third_party/OpenCL-Headers new file mode 160000 index 0000000000..879576679f --- /dev/null +++ b/third_party/OpenCL-Headers @@ -0,0 +1 @@ +Subproject commit 879576679f7cf319d4f27ed6639a85207dd91e07 diff --git a/third_party/apply_patches.sh b/third_party/apply_patches.sh deleted file mode 100755 index fbd06b68b6..0000000000 --- a/third_party/apply_patches.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -PWD_PATH=`pwd` -THIRD_PARTY_PATH=$(cd "$(dirname $0)"; pwd) -if [ $# -lt 1 ]; then - echo "Usage: sh apply_patches.sh [build_dir]" - echo " build_dir is the directory where you type \"cmake\"" - echo " Open source software incubator-tvm will be copied to build_dir" - echo " where patches will be applied on." - exit 1 -fi -BUILD_PATH=$1 - -if [ -d ${BUILD_PATH}/incubator-tvm ]; then - rm -rf ${BUILD_PATH}/incubator-tvm -fi -DLPACK_PATH=$2 -DMLC_PATH=$3 -RANG_PATH=$4 -TVM_PATH=$5 -mkdir ${BUILD_PATH}/incubator-tvm -cp -rf ${TVM_PATH}/* ${BUILD_PATH}/incubator-tvm/ -cp -rf ${DLPACK_PATH}/* ${BUILD_PATH}/incubator-tvm/3rdparty/dlpack/ -cp -rf ${DMLC_PATH}/* ${BUILD_PATH}/incubator-tvm/3rdparty/dmlc-core/ -cp -rf ${RANG_PATH}/* ${BUILD_PATH}/incubator-tvm/3rdparty/rang/ - -check_dir_not_empty() -{ - if [ ! $# -eq 1 ]; then - echo "Usage: check_dir_not_empty dir_path" - exit 1 - fi - - if [ ! -d $1 ]; then - echo "Directory $1 does not exist." - exit 1 - fi - - fileCounts=`ls $1 | wc -l` - if [ ${fileCounts} -eq 0 ]; then - echo "Directory $1 is empty." - exit 1 - fi -} - -apply_patch() -{ - if [ ! $# -eq 1 ]; then - echo "Usage: apply_patch patch_name" - exit 1 - fi - - if [ ! -f $1 ]; then - echo "Patch $1 does not exist." - exit 1 - fi - - patch -p1 < $1 - if [ $? -eq 0 ]; then - echo "Patch $1 applied successfully." - else - echo "Patch $1 not applied." - fi -} - -# apply patches on tvm -TVM_PATH=${BUILD_PATH}/incubator-tvm -TVM_PATCH_PATH=${THIRD_PARTY_PATH}/patch/incubator-tvm -check_dir_not_empty "${TVM_PATH}" -check_dir_not_empty "${TVM_PATCH_PATH}" -cd ${TVM_PATH} -apply_patch "${TVM_PATCH_PATH}/cmake.patch" -apply_patch "${TVM_PATCH_PATH}/find_library.patch" -apply_patch "${TVM_PATCH_PATH}/include.patch" -apply_patch "${TVM_PATCH_PATH}/src_pass.patch" - -cd ${PWD_PATH} diff --git a/third_party/eigen b/third_party/eigen new file mode 160000 index 0000000000..daf9bbeca2 --- /dev/null +++ b/third_party/eigen @@ -0,0 +1 @@ +Subproject commit daf9bbeca26e98da2eed0058835cbb04e0a30ad8 diff --git a/third_party/libjpeg-turbo b/third_party/libjpeg-turbo new file mode 160000 index 0000000000..b443c541b9 --- /dev/null +++ b/third_party/libjpeg-turbo @@ -0,0 +1 @@ +Subproject commit b443c541b9a6fdcac214f9f003de0aa13e480ac1 diff --git a/third_party/opencv b/third_party/opencv new file mode 160000 index 0000000000..bda89a6469 --- /dev/null +++ b/third_party/opencv @@ -0,0 +1 @@ +Subproject commit bda89a6469aa79ecd8713967916bd754bff1d931 diff --git a/third_party/patch/incubator-tvm/CMakeLists.txt b/third_party/patch/incubator-tvm/CMakeLists.txt deleted file mode 100644 index d8964579cd..0000000000 --- a/third_party/patch/incubator-tvm/CMakeLists.txt +++ /dev/null @@ -1,100 +0,0 @@ -cmake_minimum_required(VERSION 3.2) -project(tvm C CXX) -set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -# Utility functions -include(${TVM_DIR}/cmake/util/Util.cmake) -include(${TVM_DIR}/cmake/util/FindCUDA.cmake) - -# include directories -include_directories(AFTER "${TVM_DIR}/include") -include_directories(AFTER "${TVM_DIR}/src") -include_directories(AFTER "${TVM_DIR}") -include_directories(AFTER "${TVM_DIR}/src/schedule") - -include_directories(AFTER "${TVM_DIR}/3rdparty/dmlc-core/include") -include_directories(AFTER "${TVM_DIR}/3rdparty/dlpack/include") -include_directories(AFTER "${TVM_DIR}/3rdparty/compiler-rt") -include_directories(AFTER "${TVM_DIR}/3rdparty/rang/include") - -# lib contain dlopen and dlclose -set(TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) - -# add source group -file(GLOB_RECURSE GROUP_SOURCE "${TVM_DIR}/src/*.cc" "src/*.cc") -file(GLOB_RECURSE GROUP_INCLUDE "${TVM_DIR}/src/*.h" - "${TVM_DIR}/include/*.h" "src/*.h" "include/*.h") -assign_source_group("Source" ${GROUP_SOURCE}) -assign_source_group("Include" ${GROUP_INCLUDE}) - -file(GLOB COMPILER_SRCS - "pre_activate/gpu/*.cc" - ${TVM_DIR}/src/api/*.cc - ${TVM_DIR}/src/arithmetic/*.cc - ${TVM_DIR}/src/autotvm/*.cc - ${TVM_DIR}/src/codegen/*.cc - ${TVM_DIR}/src/lang/*.cc - ${TVM_DIR}/src/pass/*.cc - ${TVM_DIR}/src/op/*.cc - ${TVM_DIR}/src/node/*.cc - ${TVM_DIR}/src/schedule/*.cc - ${TVM_DIR}/src/runtime/*.cc - ${TVM_DIR}/src/runtime/vm/*.cc - ${TVM_DIR}/src/runtime/vm/profiler/*.cc - ${TVM_DIR}/src/codegen/stackvm/*.cc) - -file(GLOB_RECURSE RELAY_SRCS ${TVM_DIR}/src/relay/*.cc) -list(APPEND COMPILER_SRCS ${RELAY_SRCS}) - -file(GLOB DATATYPE_SRCS ${TVM_DIR}/src/codegen/datatype/*.cc) -list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) - -file(GLOB COMPILER_VERILOG_SRCS ${TVM_DIR}/src/codegen/verilog/*.cc) -list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) - -file(GLOB TOPI_SRCS ${TVM_DIR}/topi/src/*.cc) - -file(GLOB RUNTIME_SRCS - ${TVM_DIR}/src/runtime/*.cc - ${TVM_DIR}/src/runtime/vm/*.cc - ${TVM_DIR}/src/runtime/stub/*.cc - ${TVM_DIR}/src/runtime/stackvm/*.cc) - - -file(GLOB COMPILER_OFF_SRCS - ${TVM_DIR}/src/codegen/opt/build_*_off.cc) - -list(REMOVE_ITEM COMPILER_OFF_SRCS - ${TVM_DIR}/src/codegen/opt/build_cuda_off.cc) -set(USE_CUDA "ON") -list(APPEND COMPILER_SRCS ${COMPILER_OFF_SRCS}) -# Module rules -include(${TVM_DIR}/cmake/modules/CUDA.cmake) - -set(CMAKE_C_FLAGS_AKG -pipe -Wall -fPIC -fstack-protector-all) -set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) - -set(CMAKE_CXX_FLAGS_AKG -std=c++11 -pipe -Wall -fPIC -fstack-protector-all) -set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) - -if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("-- Build in Debug mode") - set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O0 -g -rdynamic) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O0 -g -rdynamic) -else() - message("-- Build in Release mode") - set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O2 -Werror) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O2 -Werror) -endif() -if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION - VERSION_GREATER 7.0) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -faligned-new) -endif() - -add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS} ${TOPI_SRCS}) - -target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) -target_compile_options(tvm PRIVATE - $<$:${CMAKE_C_FLAGS_AKG}> - $<$:${CMAKE_CXX_FLAGS_AKG}>) -target_include_directories(tvm PRIVATE "${TVM_DIR}/topi/include") -install(TARGETS tvm) \ No newline at end of file diff --git a/third_party/patch/incubator-tvm/cmake.patch b/third_party/patch/incubator-tvm/cmake.patch deleted file mode 100644 index 820c7e24fd..0000000000 --- a/third_party/patch/incubator-tvm/cmake.patch +++ /dev/null @@ -1,201 +0,0 @@ -diff -Npur tvm/cmake/modules/ANTLR.cmake tvm_new/cmake/modules/ANTLR.cmake ---- tvm/cmake/modules/ANTLR.cmake 2019-12-14 15:11:37.562418441 +0800 -+++ tvm_new/cmake/modules/ANTLR.cmake 2019-12-14 11:28:49.161977599 +0800 -@@ -14,12 +14,15 @@ - # KIND, either express or implied. See the License for the - # specific language governing permissions and limitations - # under the License. -+ -+# 2019.12.30 - Modify current directory of tvm. -+ - if(USE_ANTLR) - find_antlr(${USE_ANTLR}) - if(ANTLR4) - - set(RELAY_PARSER_DIR -- ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) -+ ${TVM_DIR}/python/tvm/relay/grammar) - - set(RELAY_PARSER - ${RELAY_PARSER_DIR}/py3/RelayVisitor.py -diff -Npur tvm/cmake/modules/CUDA.cmake tvm_new/cmake/modules/CUDA.cmake ---- tvm/cmake/modules/CUDA.cmake 2019-12-14 15:11:37.562418441 +0800 -+++ tvm_new/cmake/modules/CUDA.cmake 2019-12-14 11:28:49.161977599 +0800 -@@ -15,6 +15,8 @@ - # specific language governing permissions and limitations - # under the License. - -+# 2019.12.30 - Modify current directory of tvm. -+ - # CUDA Module - find_cuda(${USE_CUDA}) - -@@ -29,9 +31,9 @@ if(USE_CUDA) - message(FATAL_ERROR "Cannot find CUDA, USE_CUDA=" ${USE_CUDA}) - endif() - message(STATUS "Build with CUDA support") -- file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc) -+ file(GLOB RUNTIME_CUDA_SRCS ${TVM_DIR}/src/runtime/cuda/*.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS}) -- list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_on.cc) -+ list(APPEND COMPILER_SRCS ${TVM_DIR}/src/codegen/opt/build_cuda_on.cc) - - list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY}) - list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY}) -@@ -40,18 +42,18 @@ if(USE_CUDA) - - if(USE_CUDNN) - message(STATUS "Build with cuDNN support") -- file(GLOB CONTRIB_CUDNN_SRCS src/runtime/contrib/cudnn/*.cc) -+ file(GLOB CONTRIB_CUDNN_SRCS ${TVM_DIR}/src/runtime/contrib/cudnn/*.cc) - list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_SRCS}) - list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY}) - endif(USE_CUDNN) - - if(USE_CUBLAS) - message(STATUS "Build with cuBLAS support") -- file(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc) -+ file(GLOB CONTRIB_CUBLAS_SRCS ${TVM_DIR}/src/runtime/contrib/cublas/*.cc) - list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS}) - list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY}) - endif(USE_CUBLAS) - - else(USE_CUDA) -- list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_off.cc) -+ list(APPEND COMPILER_SRCS ${TVM_DIR}/src/codegen/opt/build_cuda_off.cc) - endif(USE_CUDA) -diff -Npur tvm/cmake/modules/LLVM.cmake tvm_new/cmake/modules/LLVM.cmake ---- tvm/cmake/modules/LLVM.cmake 2019-12-14 15:11:37.562418441 +0800 -+++ tvm_new/cmake/modules/LLVM.cmake 2019-12-14 11:28:49.161977599 +0800 -@@ -15,6 +15,8 @@ - # specific language governing permissions and limitations - # under the License. - -+# 2019.12.30 - Modify current directory of tvm. -+ - # LLVM rules - add_definitions(-DDMLC_USE_FOPEN64=0) - -@@ -26,7 +28,7 @@ if(NOT USE_LLVM STREQUAL "OFF") - message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) - # Set flags that are only needed for LLVM target - add_definitions(-DTVM_LLVM_VERSION=${TVM_LLVM_VERSION}) -- file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc) -+ file(GLOB COMPILER_LLVM_SRCS ${TVM_DIR}/src/codegen/llvm/*.cc) - list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS}) - list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS}) - if(NOT MSVC) -diff -Npur tvm/cmake/modules/Micro.cmake tvm_new/cmake/modules/Micro.cmake ---- tvm/cmake/modules/Micro.cmake 2019-12-14 15:11:37.562418441 +0800 -+++ tvm_new/cmake/modules/Micro.cmake 2019-12-14 11:28:49.161977599 +0800 -@@ -15,8 +15,10 @@ - # specific language governing permissions and limitations - # under the License. - -+# 2019.12.30 - Modify current directory of tvm. -+ - if(USE_MICRO) - message(STATUS "Build with Micro support") -- file(GLOB RUNTIME_MICRO_SRCS src/runtime/micro/*.cc) -+ file(GLOB RUNTIME_MICRO_SRCS ${TVM_DIR}/src/runtime/micro/*.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_MICRO_SRCS}) - endif(USE_MICRO) -diff -Npur tvm/cmake/modules/VTA.cmake tvm_new/cmake/modules/VTA.cmake ---- tvm/cmake/modules/VTA.cmake 2019-12-14 15:11:37.562418441 +0800 -+++ tvm_new/cmake/modules/VTA.cmake 2019-12-14 14:42:32.358381133 +0800 -@@ -15,17 +15,19 @@ - # specific language governing permissions and limitations - # under the License. - -+# 2019.12.30 - Modify current directory of tvm. -+ - # CMake Build rules for VTA - find_program(PYTHON NAMES python python3 python3.6) - - if(MSVC) - message(STATUS "VTA build is skipped in Windows..") - elseif(PYTHON) -- set(VTA_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/vta/config/vta_config.py) -+ set(VTA_CONFIG ${PYTHON} ${TVM_DIR}/vta/config/vta_config.py) - - if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/vta_config.json) - message(STATUS "Use VTA config " ${CMAKE_CURRENT_BINARY_DIR}/vta_config.json) -- set(VTA_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/vta/config/vta_config.py -+ set(VTA_CONFIG ${PYTHON} ${TVM_DIR}/vta/config/vta_config.py - --use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json) - endif() - -@@ -40,18 +42,18 @@ elseif(PYTHON) - # Fast simulator driver build - if(USE_VTA_FSIM) - # Add fsim driver sources -- file(GLOB FSIM_RUNTIME_SRCS vta/src/*.cc) -- list(APPEND FSIM_RUNTIME_SRCS vta/src/sim/sim_driver.cc) -- list(APPEND FSIM_RUNTIME_SRCS vta/src/vmem/virtual_memory.cc vta/src/vmem/virtual_memory.h) -- list(APPEND FSIM_RUNTIME_SRCS vta/src/sim/sim_tlpp.cc) -+ file(GLOB FSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/*.cc) -+ list(APPEND FSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/sim/sim_driver.cc) -+ list(APPEND FSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/vmem/virtual_memory.cc ${TVM_DIR}/vta/src/vmem/virtual_memory.h) -+ list(APPEND FSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/sim/sim_tlpp.cc) - # Target lib: vta_fsim - add_library(vta_fsim SHARED ${FSIM_RUNTIME_SRCS}) -- target_include_directories(vta_fsim PUBLIC vta/include) -+ target_include_directories(vta_fsim PUBLIC ${TVM_DIR}/vta/include) - foreach(__def ${VTA_DEFINITIONS}) - string(SUBSTRING ${__def} 3 -1 __strip_def) - target_compile_definitions(vta_fsim PUBLIC ${__strip_def}) - endforeach() -- include_directories("vta/include") -+ include_directories("${TVM_DIR}/vta/include") - if(APPLE) - set_target_properties(vta_fsim PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") - endif(APPLE) -@@ -61,18 +63,18 @@ elseif(PYTHON) - # Cycle accurate simulator driver build - if(USE_VTA_TSIM) - # Add tsim driver sources -- file(GLOB TSIM_RUNTIME_SRCS vta/src/*.cc) -- list(APPEND TSIM_RUNTIME_SRCS vta/src/tsim/tsim_driver.cc) -- list(APPEND TSIM_RUNTIME_SRCS vta/src/dpi/module.cc) -- list(APPEND TSIM_RUNTIME_SRCS vta/src/vmem/virtual_memory.cc vta/src/vmem/virtual_memory.h) -+ file(GLOB TSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/*.cc) -+ list(APPEND TSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/tsim/tsim_driver.cc) -+ list(APPEND TSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/dpi/module.cc) -+ list(APPEND TSIM_RUNTIME_SRCS ${TVM_DIR}/vta/src/vmem/virtual_memory.cc ${TVM_DIR}/vta/src/vmem/virtual_memory.h) - # Target lib: vta_tsim - add_library(vta_tsim SHARED ${TSIM_RUNTIME_SRCS}) -- target_include_directories(vta_tsim PUBLIC vta/include) -+ target_include_directories(vta_tsim PUBLIC ${TVM_DIR}/vta/include) - foreach(__def ${VTA_DEFINITIONS}) - string(SUBSTRING ${__def} 3 -1 __strip_def) - target_compile_definitions(vta_tsim PUBLIC ${__strip_def}) - endforeach() -- include_directories("vta/include") -+ include_directories("${TVM_DIR}/vta/include") - if(APPLE) - set_target_properties(vta_tsim PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") - endif(APPLE) -@@ -80,19 +82,19 @@ elseif(PYTHON) - - # VTA FPGA driver sources - if(USE_VTA_FPGA) -- file(GLOB FPGA_RUNTIME_SRCS vta/src/*.cc) -+ file(GLOB FPGA_RUNTIME_SRCS ${TVM_DIR}/vta/src/*.cc) - # Rules for Zynq-class FPGAs with pynq OS support (see pynq.io) - if(${VTA_TARGET} STREQUAL "pynq" OR - ${VTA_TARGET} STREQUAL "ultra96") -- list(APPEND FPGA_RUNTIME_SRCS vta/src/pynq/pynq_driver.cc) -+ list(APPEND FPGA_RUNTIME_SRCS ${TVM_DIR}/vta/src/pynq/pynq_driver.cc) - # Rules for Pynq v2.4 - find_library(__cma_lib NAMES cma PATH /usr/lib) - elseif(${VTA_TARGET} STREQUAL "de10nano") # DE10-Nano rules -- file(GLOB FPGA_RUNTIME_SRCS vta/src/de10nano/*.cc vta/src/*.cc) -+ file(GLOB FPGA_RUNTIME_SRCS ${TVM_DIR}/vta/src/de10nano/*.cc ${TVM_DIR}/vta/src/*.cc) - endif() - # Target lib: vta - add_library(vta SHARED ${FPGA_RUNTIME_SRCS}) -- target_include_directories(vta PUBLIC vta/include) -+ target_include_directories(vta PUBLIC ${TVM_DIR}/vta/include) - foreach(__def ${VTA_DEFINITIONS}) - string(SUBSTRING ${__def} 3 -1 __strip_def) - target_compile_definitions(vta PUBLIC ${__strip_def}) diff --git a/third_party/patch/incubator-tvm/find_library.patch b/third_party/patch/incubator-tvm/find_library.patch deleted file mode 100644 index f7b2f9af0a..0000000000 --- a/third_party/patch/incubator-tvm/find_library.patch +++ /dev/null @@ -1,71 +0,0 @@ ---- tvm/python/tvm/_ffi/base.py 2020-03-12 16:17:39.089828527 +0800 -+++ tvm_new/python/tvm/_ffi/base.py 2020-03-12 16:17:16.829829558 +0800 -@@ -16,6 +16,9 @@ - # under the License. - # coding: utf-8 - # pylint: disable=invalid-name -+ -+# 2019.12.30 - Modify _load_lib function. -+ - """Base library for TVM FFI.""" - from __future__ import absolute_import - -@@ -47,8 +50,18 @@ else: - - - def _load_lib(): -- """Load libary by searching possible path.""" -- lib_path = libinfo.find_lib_path() -+ """Load library by searching possible path.""" -+ pwd = os.path.dirname(os.path.realpath(__file__)) -+ path = os.path.realpath(pwd+"/../../../mindspore/lib") -+ lib_path = [] -+ files = os.listdir(path) -+ for f in files: -+ if f.startswith("libtvm.") and f.endswith(".so"): -+ lib_path.append(path+"/"+f) -+ break -+ if not lib_path: -+ raise RuntimeError("mindspore library cannot find.") -+ - lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) - # DMatrix functions - lib.TVMGetLastError.restype = ctypes.c_char_p -diff -Npur tvm/topi/python/topi/cpp/impl.py tvm_new/topi/python/topi/cpp/impl.py ---- tvm/topi/python/topi/cpp/impl.py 2020-03-12 16:17:39.129828525 +0800 -+++ tvm_new/topi/python/topi/cpp/impl.py 2020-03-12 16:17:16.873829556 +0800 -@@ -14,6 +14,9 @@ - # KIND, either express or implied. See the License for the - # specific language governing permissions and limitations - # under the License. -+ -+# 2019.12.30 - Modify _load_lib function. -+ - """Load Lib for C++ TOPI ops and schedules""" - import sys - import os -@@ -30,12 +33,18 @@ def _get_lib_names(): - return ['libtvm_topi.so', 'tvm_topi.so'] - - def _load_lib(): -- """Load libary by searching possible path.""" -- curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) -- lib_search = curr_path -- lib_path = libinfo.find_lib_path(_get_lib_names(), lib_search, optional=True) -- if lib_path is None: -- return None, None -+ """Load library by searching possible path.""" -+ pwd = os.path.dirname(os.path.realpath(__file__)) -+ path = os.path.realpath(pwd+"/../../../mindspore/lib") -+ lib_path = [] -+ files = os.listdir(path) -+ for f in files: -+ if f.startswith("libtvm.") and f.endswith(".so"): -+ lib_path.append(path+"/"+f) -+ break -+ if not lib_path: -+ raise RuntimeError("mindspore library cannot find.") -+ - lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) - return lib, os.path.basename(lib_path[0]) - diff --git a/third_party/patch/incubator-tvm/include.patch b/third_party/patch/incubator-tvm/include.patch deleted file mode 100644 index 270c7a0f39..0000000000 --- a/third_party/patch/incubator-tvm/include.patch +++ /dev/null @@ -1,55 +0,0 @@ -diff -Npur tvm/include/tvm/expr_operator.h tvm_new/include/tvm/expr_operator.h ---- tvm/include/tvm/expr_operator.h 2019-12-28 10:11:27.369814744 +0800 -+++ tvm_new/include/tvm/expr_operator.h 2019-12-28 10:11:27.209812391 +0800 -@@ -25,6 +25,11 @@ - * when the type is int32 or int64 for simplifying the index expressions. - */ - // Acknowledgement: Most operator APIs originate from Halide. -+ -+/* -+ * 2019.12.30 - Add new operator for expr. -+ */ -+ - #ifndef TVM_EXPR_OPERATOR_H_ - #define TVM_EXPR_OPERATOR_H_ - -@@ -217,6 +222,16 @@ TVM_DLL Expr operator*(Expr a, Expr b); - */ - TVM_DLL Expr operator/(Expr a, Expr b); - /*! -+ * \brief mod operator -+ * -+ * \param a left operand -+ * \param b right operand -+ * \return The result expression. -+ * \note this function does eager constant folding for -+ * index types(int32, int64) when possible. -+ */ -+TVM_DLL Expr operator%(Expr a, Expr b); -+/*! - * \brief left shift operator - * - * \param a left operand -diff -Npur tvm/include/tvm/lowered_func.h tvm_new/include/tvm/lowered_func.h ---- tvm/include/tvm/lowered_func.h 2019-12-28 10:11:27.369814744 +0800 -+++ tvm_new/include/tvm/lowered_func.h 2019-12-28 10:11:27.209812391 +0800 -@@ -22,6 +22,11 @@ - * \brief Information about a lowered TVM function. - * This data structure is final step toward codegen. - */ -+ -+/* -+ * 2019.12.30 - Add new var array for args_real. -+ */ -+ - #ifndef TVM_LOWERED_FUNC_H_ - #define TVM_LOWERED_FUNC_H_ - -@@ -74,6 +79,7 @@ class LoweredFuncNode : public ir::Funct - * This function can only take pod type(int, float) and void* as arguments. - */ - Array args; -+ Array args_real; - /*! - * \brief The IterVar axis of threads - * Each axis need host function to specify a size. diff --git a/third_party/patch/incubator-tvm/src_pass.patch b/third_party/patch/incubator-tvm/src_pass.patch deleted file mode 100644 index 5450ca142e..0000000000 --- a/third_party/patch/incubator-tvm/src_pass.patch +++ /dev/null @@ -1,120 +0,0 @@ -diff -Npur tvm/src/pass/make_api.cc tvm_new/src/pass/make_api.cc ---- tvm/src/pass/make_api.cc 2019-12-14 15:11:37.626419432 +0800 -+++ tvm_new/src/pass/make_api.cc 2019-12-14 14:58:46.562493287 +0800 -@@ -20,6 +20,11 @@ - /*! - * \file make_api.cc Build API function. - */ -+ -+/* -+ * 2019.12.30 - Define new function to push buffer node from api_args to args_real. -+ */ -+ - #include - #include - #include -@@ -40,6 +45,17 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr - return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0)); - } - -+Array Param ( Array api_args,Array args_real) { -+ int num_args = static_cast(api_args.size()); -+ for (int i = 0; i < num_args; i++) { -+ const BufferNode *v = api_args[i].as(); -+ if(v) { -+ args_real.push_back(v->data); -+ } -+ } -+ return args_real; -+} -+ - LoweredFunc MakeAPI(Stmt body, - std::string name, - Array api_args, -@@ -47,6 +63,8 @@ LoweredFunc MakeAPI(Stmt body, - bool is_restricted) { - const Stmt nop = Evaluate::make(0); - int num_args = static_cast(api_args.size()); -+ Array args_real; -+ args_real = Param (api_args, args_real); - CHECK_LE(num_unpacked_args, num_args); - int num_packed_args = num_args - num_unpacked_args; - // Data field definitions -@@ -170,6 +188,7 @@ LoweredFunc MakeAPI(Stmt body, - NodePtr n = make_node(); - n->name = name; - n->args = args; -+ n->args_real = args_real; - n->handle_data_type = binder.def_handle_dtype(); - n->is_packed_func = num_unpacked_args == 0; - n->is_restricted = is_restricted; -diff -Npur tvm/src/pass/split_host_device.cc tvm_new/src/pass/split_host_device.cc ---- tvm/src/pass/split_host_device.cc 2019-12-14 15:11:37.626419432 +0800 -+++ tvm_new/src/pass/split_host_device.cc 2019-12-14 11:28:49.293979656 +0800 -@@ -21,6 +21,11 @@ - * \file split_host_device.cc - * \brief Split device function from host. - */ -+ -+/* -+ * 2019.12.30 - Add new implements for host device splitter. -+ */ -+ - #include - #include - #include -@@ -38,6 +43,7 @@ class IRUseDefAnalysis : public IRMutato - Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { - if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast(op->node); -+ iv = IterVarNode::make(Range(0, op->value), iv->var, iv->iter_type, iv->thread_tag); - CHECK_NE(iv->thread_tag.length(), 0U); - // thread_extent can appear multiple times - // use the first appearance as def. -@@ -186,6 +192,7 @@ class HostDeviceSplitter : public IRMuta - name_ = f->name; - NodePtr n = - make_node(*f.operator->()); -+ args_real = n->args_real; - n->body = this->Mutate(f->body); - n->func_type = kHostFunc; - Array ret{LoweredFunc(n)}; -@@ -196,6 +203,7 @@ class HostDeviceSplitter : public IRMuta - } - - private: -+ Array args_real; - Stmt SplitDeviceFunc(Stmt body) { - std::ostringstream os; - os << name_ << "_kernel" << device_funcs_.size(); -@@ -223,6 +231,30 @@ class HostDeviceSplitter : public IRMuta - n->args.push_back(v); - } - } -+std::shared_ptr na = std::make_shared(); -+ for (unsigned i = 0; i < (unsigned)args_real.size(); i++) { -+ bool match = false; -+ for (unsigned j = 0; j < (unsigned)n->args.size(); j++) { -+ if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) { -+ na->args.push_back(n->args[j]); -+ match = true; -+ break; -+ } else { -+ continue; -+ } -+ } -+ -+ if (!match) { -+ na->args.push_back(args_real[i]); -+ // mark handle data type. -+ for (auto kv : handle_data_type_) { -+ if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) { -+ n->handle_data_type.Set(args_real[i], kv.second); -+ } -+ } -+ } -+ } -+ n->args = na->args; - LoweredFunc f_device(n); - Array call_args; - call_args.push_back(StringImm::make(f_device->name)); diff --git a/third_party/patch/jpeg_turbo/jpeg_turbo.patch001 b/third_party/patch/jpeg_turbo/jpeg_turbo.patch001 new file mode 100644 index 0000000000..7f13ca5ac5 --- /dev/null +++ b/third_party/patch/jpeg_turbo/jpeg_turbo.patch001 @@ -0,0 +1,39 @@ +diff -Npur libjpeg-turbo-2.0.4/ChangeLog.md libjpeg-turbo-2.0.4-new/ChangeLog.md +--- libjpeg-turbo-2.0.4/ChangeLog.md 2019-12-31 15:10:30.000000000 +0800 ++++ libjpeg-turbo-2.0.4-new/ChangeLog.md 2020-07-29 19:12:06.259357156 +0800 +@@ -562,10 +562,10 @@ application was linked against. + + 3. Fixed a couple of issues in the PPM reader that would cause buffer overruns + in cjpeg if one of the values in a binary PPM/PGM input file exceeded the +-maximum value defined in the file's header. libjpeg-turbo 1.4.2 already +-included a similar fix for ASCII PPM/PGM files. Note that these issues were +-not security bugs, since they were confined to the cjpeg program and did not +-affect any of the libjpeg-turbo libraries. ++maximum value defined in the file's header and that maximum value was greater ++than 255. libjpeg-turbo 1.4.2 already included a similar fix for ASCII PPM/PGM ++files. Note that these issues were not security bugs, since they were confined ++to the cjpeg program and did not affect any of the libjpeg-turbo libraries. + + 4. Fixed an issue whereby attempting to decompress a JPEG file with a corrupt + header using the `tjDecompressToYUV2()` function would cause the function to +diff -Npur libjpeg-turbo-2.0.4/rdppm.c libjpeg-turbo-2.0.4-new/rdppm.c +--- libjpeg-turbo-2.0.4/rdppm.c 2019-12-31 15:10:30.000000000 +0800 ++++ libjpeg-turbo-2.0.4-new/rdppm.c 2020-07-29 17:55:33.129123386 +0800 +@@ -5,7 +5,7 @@ + * Copyright (C) 1991-1997, Thomas G. Lane. + * Modified 2009 by Bill Allombert, Guido Vollbeding. + * libjpeg-turbo Modifications: +- * Copyright (C) 2015-2017, D. R. Commander. ++ * Copyright (C) 2015-2017, 2020, D. R. Commander. + * For conditions of distribution and use, see the accompanying README.ijg + * file. + * +@@ -720,7 +720,7 @@ start_input_ppm(j_compress_ptr cinfo, cj + /* On 16-bit-int machines we have to be careful of maxval = 65535 */ + source->rescale = (JSAMPLE *) + (*cinfo->mem->alloc_small) ((j_common_ptr)cinfo, JPOOL_IMAGE, +- (size_t)(((long)maxval + 1L) * ++ (size_t)(((long)MAX(maxval, 255) + 1L) * + sizeof(JSAMPLE))); + half_maxval = maxval / 2; + for (val = 0; val <= (long)maxval; val++) { diff --git a/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch b/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch deleted file mode 100644 index 5977c943ef..0000000000 --- a/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch +++ /dev/null @@ -1,203 +0,0 @@ -diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h -index d668984..3676a61 100644 ---- a/include/tvm/runtime/registry.h -+++ b/include/tvm/runtime/registry.h -@@ -319,6 +319,19 @@ class Registry { - #define TVM_REGISTER_EXT_TYPE(T) \ - TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::runtime::ExtTypeVTable::Register_() -+/* -+ * Macro transfer TVM runtime API to custom runtime API -+ */ -+#define TVM_RT_FUNC_TRANS(OrigFuncStr) ({ \ -+ const runtime::PackedFunc* trans_func = runtime::Registry::Get("codegen.GetTransRTFunc");\ -+ const char* dst_func_str = nullptr; \ -+ if( trans_func != nullptr){ \ -+ dst_func_str = ((*trans_func)(OrigFuncStr)).ptr(); \ -+ }else{ \ -+ dst_func_str = OrigFuncStr; \ -+ } \ -+ dst_func_str; \ -+}) - - } // namespace runtime - } // namespace tvm -diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc -index 0ba0c58..2850ad4 100644 ---- a/src/codegen/llvm/codegen_cpu.cc -+++ b/src/codegen/llvm/codegen_cpu.cc -@@ -99,26 +99,26 @@ void CodeGenCPU::Init(const std::string& module_name, - // We will need this in environment for backward registration. - f_tvm_register_system_symbol_ = llvm::Function::Create( - llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false), -- llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get()); -+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendRegisterSystemLibSymbol"), module_.get()); - } else { - f_tvm_register_system_symbol_ = nullptr; - } - if (dynamic_lookup || system_lib) { - f_tvm_func_call_ = llvm::Function::Create( - ftype_tvm_func_call_, -- llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get()); -+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMFuncCall"), module_.get()); - f_tvm_get_func_from_env_ = llvm::Function::Create( - ftype_tvm_get_func_from_env_, - llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get()); - f_tvm_api_set_last_error_ = llvm::Function::Create( - ftype_tvm_api_set_last_error_, -- llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); -+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMAPISetLastError"), module_.get()); - f_tvm_parallel_launch_ = llvm::Function::Create( - ftype_tvm_parallel_launch_, -- llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get()); -+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelLaunch"), module_.get()); - f_tvm_parallel_barrier_ = llvm::Function::Create( - ftype_tvm_parallel_barrier_, -- llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); -+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelBarrier"), module_.get()); - } - this->InitGlobalContext(dynamic_lookup); - } -@@ -461,11 +461,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) { - } - std::swap(function_, fcompute); - std::swap(new_vmap, var_map_); -+ std::stack br_ret_flg; -+ std::swap(br_ret_flg, br_ret_flg_); - BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_); - builder_->SetInsertPoint(compute_entry); - this->VisitStmt(op->body); - builder_->CreateRet(ConstInt32(0)); - // swap the var map back, now we are back on track. -+ std::swap(br_ret_flg, br_ret_flg_); - std::swap(new_vmap, var_map_); - std::swap(function_, fcompute); - builder_->SetInsertPoint(compute_call_end); -@@ -542,9 +545,12 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { - std::swap(function_, f); - std::swap(parallel_env_, par_env); - std::swap(var_map_, new_vmap); -+ std::stack br_ret_flg; -+ std::swap(br_ret_flg, br_ret_flg_); - this->VisitStmt(body); - builder_->CreateRet(ConstInt32(0)); - // swap the var map back, now we are back on track. -+ std::swap(br_ret_flg, br_ret_flg_); - std::swap(var_map_, new_vmap); - std::swap(parallel_env_, par_env); - std::swap(function_, f); -@@ -794,7 +800,9 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { - } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) { - return CreateStaticHandle(); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { -- builder_->CreateRet(ConstInt32(-1)); -+ llvm::Value* pRetCode = (op->args.size() == 0) ? ConstInt32(-1) : MakeValue(op->args[0]); -+ builder_->CreateRet(pRetCode); -+ CodeGenLLVM::SetRetTrFlg(true); - return ConstInt32(-1); - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { - CHECK_EQ(op->args.size(), 3U); -diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc -index 2cff88b..e26812d 100644 ---- a/src/codegen/llvm/codegen_llvm.cc -+++ b/src/codegen/llvm/codegen_llvm.cc -@@ -1110,23 +1110,37 @@ void CodeGenLLVM::VisitStmt_(const IfThenElse* op) { - *ctx_, "if_then", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); -+ // define ret terminitor exist flg for this Stmt -+ bool cur_br_ret_flg = false; -+ br_ret_flg_.push(&cur_br_ret_flg); - if (op->else_case.defined()) { - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); - builder_->CreateCondBr(cond, then_block, else_block); - builder_->SetInsertPoint(then_block); -+ cur_br_ret_flg = false; - this->VisitStmt(op->then_case); - builder_->CreateBr(end_block); -+ if ( !cur_br_ret_flg ){ -+ builder_->CreateBr(end_block); -+ } - builder_->SetInsertPoint(else_block); -+ cur_br_ret_flg = false; - this->VisitStmt(op->else_case); -- builder_->CreateBr(end_block); -+ if ( !cur_br_ret_flg ){ -+ builder_->CreateBr(end_block); -+ } - } else { - builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_); - builder_->SetInsertPoint(then_block); -+ cur_br_ret_flg = false; - this->VisitStmt(op->then_case); -- builder_->CreateBr(end_block); -+ if ( !cur_br_ret_flg ){ -+ builder_->CreateBr(end_block); -+ } - } - builder_->SetInsertPoint(end_block); -+ br_ret_flg_.pop(); - } - - -diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h -index b7d091b..6fba863 100644 ---- a/src/codegen/llvm/codegen_llvm.h -+++ b/src/codegen/llvm/codegen_llvm.h -@@ -143,6 +143,11 @@ class CodeGenLLVM : - void VisitStmt_(const Block* op) override; - void VisitStmt_(const Evaluate* op) override; - void VisitStmt_(const ProducerConsumer* op) override; -+ //Set IfThelElse branch exist Return flg -+ void SetRetTrFlg(bool RetFlg){ -+ if( !br_ret_flg_.empty() ) -+ *(br_ret_flg_.top()) = RetFlg; -+ } - - protected: - /*! \brief The storage information */ -@@ -304,6 +309,12 @@ class CodeGenLLVM : - * initializes file and compilation_unit_ to TVM defaults. - */ - static std::unique_ptr CreateDebugInfo(llvm::Module* module); -+ -+ /* -+ * IfThenElse stmt branch return flg store stack -+ * if a branch already return, can't add br terminator again -+ */ -+ std::stack br_ret_flg_; - }; - } // namespace codegen - } // namespace tvm -diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc -index e73956c..3a7b46c 100644 ---- a/src/pass/lower_tvm_builtin.cc -+++ b/src/pass/lower_tvm_builtin.cc -@@ -104,7 +104,7 @@ class BuiltinLower : public IRMutator { - CHECK(device_type_.defined()) << "Unknown device type in current IR"; - CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = Evaluate::make(Call::make(Int(32), -- intrinsic::tvm_throw_last_error, {}, -+ intrinsic::tvm_throw_last_error, {(Int(32), 1001)}, - Call::Intrinsic)); - - Stmt body = Block::make( -@@ -117,7 +117,7 @@ class BuiltinLower : public IRMutator { - Stmt alloca = LetStmt::make( - op->buffer_var, - Call::make(op->buffer_var.type(), -- "TVMBackendAllocWorkspace", -+ TVM_RT_FUNC_TRANS("TVMBackendAllocWorkspace"), - {cast(Int(32), device_type_), - cast(Int(32), device_id_), - cast(UInt(64), total_bytes), -@@ -127,7 +127,7 @@ class BuiltinLower : public IRMutator { - body); - - Expr free_op = Call::make(Int(32), -- "TVMBackendFreeWorkspace", -+ TVM_RT_FUNC_TRANS("TVMBackendFreeWorkspace"), - {cast(Int(32), device_type_), - cast(Int(32), device_id_), - op->buffer_var}, diff --git a/third_party/patch/sentencepiece/sentencepiece.patch001 b/third_party/patch/sentencepiece/sentencepiece.patch001 new file mode 100644 index 0000000000..f86ea6de8d --- /dev/null +++ b/third_party/patch/sentencepiece/sentencepiece.patch001 @@ -0,0 +1,91 @@ +diff -Npur sentencepiece-0.1.92/src/CMakeLists.txt sentencepiece-0.1.92_bak/src/CMakeLists.txt +--- sentencepiece-0.1.92/src/CMakeLists.txt 2020-06-08 16:25:01.000000000 +0800 ++++ sentencepiece-0.1.92_bak/src/CMakeLists.txt 2020-07-02 17:42:33.306933546 +0800 +@@ -11,6 +11,46 @@ + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License.! ++add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) ++ ++ ++function(protobuf_generate c_var h_var) ++ if(NOT ARGN) ++ message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files") ++ return() ++ endif() ++ ++ set(${c_var}) ++ set(${h_var}) ++ ++ find_program(PROTOC_EXE NAMES "protoc" PATHS ${PROTOBUF_INC}/../bin NO_DEFAULT_PATH) ++ ++ foreach(file ${ARGN}) ++ get_filename_component(abs_file ${file} ABSOLUTE) ++ get_filename_component(file_name ${file} NAME_WE) ++ get_filename_component(file_dir ${abs_file} PATH) ++ file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir}) ++ ++ list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${file_name}.pb.cc") ++ list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${file_name}.pb.h") ++ ++ add_custom_command( ++ OUTPUT "${CMAKE_BINARY_DIR}/${file_name}.pb.cc" ++ "${CMAKE_BINARY_DIR}/${file_name}.pb.h" ++ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} ++ COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}" ++ COMMAND ${PROTOC_EXE} -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR} ${abs_file} ++ DEPENDS ${PROTOC_EXE} ${abs_file} ++ COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM) ++ endforeach() ++ ++ set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) ++ set(${c_var} ${${c_var}} PARENT_SCOPE) ++ set(${h_var} ${${h_var}} PARENT_SCOPE) ++ ++endfunction() ++ ++ + + if (SPM_USE_BUILTIN_PROTOBUF) + set(SPM_PROTO_HDRS builtin_pb/sentencepiece.pb.h) +@@ -52,12 +92,9 @@ if (SPM_USE_BUILTIN_PROTOBUF) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite) + include_directories(builtin_pb) + else() +- find_package(Protobuf REQUIRED) +- include_directories(${Protobuf_INCLUDE_DIRS}) +- protobuf_generate_cpp(SPM_PROTO_SRCS SPM_PROTO_HDRS sentencepiece.proto) +- protobuf_generate_cpp(SPM_MODEL_PROTO_SRCS SPM_MODEL_PROTO_HDRS sentencepiece_model.proto) +- set(PROTOBUF_LITE_SRCS "") +- include_directories(${PROTOBUF_INCLUDE_DIR}) ++ include_directories(${PROTOBUF_INC}) ++ protobuf_generate(SPM_PROTO_SRCS SPM_PROTO_HDRS sentencepiece.proto) ++ protobuf_generate(SPM_MODEL_PROTO_SRCS SPM_MODEL_PROTO_HDRS sentencepiece_model.proto) + endif() + + include_directories(${CMAKE_CURRENT_BINARY_DIR}) +@@ -191,11 +228,13 @@ endif() + add_library(sentencepiece-static STATIC ${SPM_SRCS}) + add_library(sentencepiece_train-static STATIC ${SPM_TRAIN_SRCS}) + +-target_link_libraries(sentencepiece-static INTERFACE ${SPM_LIBS}) ++find_library(PROTO_LIB NAMES "libprotobuf.a" PATHS ${PROTOBUF_INC}/../lib NO_DEFAULT_PATH) ++ ++target_link_libraries(sentencepiece-static INTERFACE ${PROTO_LIB} ${SPM_LIBS}) + target_link_libraries(sentencepiece_train-static INTERFACE sentencepiece-static ${SPM_LIBS}) + + if (SPM_ENABLE_SHARED) +- target_link_libraries(sentencepiece ${SPM_LIBS}) ++ target_link_libraries(sentencepiece ${SPM_LIBS} ${PROTO_LIB}) + target_link_libraries(sentencepiece_train ${SPM_LIBS} sentencepiece) + set(SPM_INSTALLTARGETS sentencepiece sentencepiece_train sentencepiece-static sentencepiece_train-static) + set_target_properties(sentencepiece sentencepiece_train PROPERTIES SOVERSION 0 VERSION 0.0.0) +@@ -265,7 +304,7 @@ install(TARGETS ${SPM_INSTALLTARGETS} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +-install(FILES sentencepiece_trainer.h sentencepiece_processor.h ++install(FILES sentencepiece_trainer.h sentencepiece_processor.h "${CMAKE_BINARY_DIR}/sentencepiece_model.pb.h" + DESTINATION ${CMAKE_INSTALL_INCDIR}) + + file(TO_NATIVE_PATH "${PROJECT_SOURCE_DIR}/data" data_dir) \ No newline at end of file